⬆︎
×

Ultimate Solutions:算法题解与实用模板

A comprehensive collection of solutions for programming challenges from LeetCode, PAT (Programming Ability Test), and other coding platforms.

GitHub仓库:github.com/hyperplasma/Ultimate-Solutions
Gitee仓库:gitee.com/hyperplasma/Ultimate-Solutions

1 基本算法

1.1 排序

1.1.1 插入排序

/**
 * 插入排序
 */
public class InsertSort {
    /**
     * 直接插入排序
     */
    public static void insertSort(int[] arr) {
        for (int i = 1; i < arr.length; i++) {
            for (int j = i; j >= 1 && arr[j] > arr[j - 1]; j--) {
                swap(arr, j, j - 1);
            }
        }
    }

    private static void swap(int[] arr, int i, int j) {
        int temp = arr[i];
        arr[i] = arr[j];
        arr[j] = temp;
    }

    public static void main(String[] args) {
        int[] arr = {3, 10, 1, 6, 2, 5, 7, 4};
        insertSort(arr);
        for (int i : arr) {
            System.out.print(i + " ");
        }
    }
}

1.1.2 快速排序

/**
 * 快速排序
 */
public class QuickSort {
    /**
     * 快速排序 arr[l ... r]
     */
    public static void quickSort(int[] arr, int l, int r) {
        if (l >= r) return;

        int x = arr[l + (r - l) / 2];   // 枢轴(选择中间元素)
        int i = l - 1, j = r + 1;       // 双指针初始位于两侧外(追加1偏移量)

        while (i < j) {
            do {
                i++;
            } while (arr[i] < x);
            do {
                j--;
            } while (arr[j] > x);
            if (i < j) {
                swap(arr, i, j);
            }
        }

        quickSort(arr, l, j);       // 左子区间右端点必须为j
        quickSort(arr, j + 1, r);
    }

    private static void swap(int[] arr, int i, int j) {
        int temp = arr[i];
        arr[i] = arr[j];
        arr[j] = temp;
    }

    public static void main(String[] args) {
        int[] arr = {3, 6, 5, 1, 9, 7, 10, 2, 8, 4};
        quickSort(arr, 0, arr.length - 1);
        for (int j : arr) {
            System.out.print(j + " ");
        }
    }
}

1.1.3 归并排序

/**
 * 归并排序
 */
public class MergeSort {
    /**
     * 二路归并排序 arr[l ... r]
     */
    public static void mergeSort(int[] arr, int l, int r) {
        if (l >= r) return;

        int mid = l + r >> 1;
        mergeSort(arr, l, mid);     // 递归排序左半部分
        mergeSort(arr, mid + 1, r);     // 递归排序右半部分

        int[] temp = new int[r - l + 1];    // 辅助数组
        int i = l, j = mid + 1, k = 0;      // 初始化指针

        // 归并左右子区间为有序子区间
        while (i <= mid && j <= r) {
            if (arr[i] <= arr[j]) {
                temp[k++] = arr[i++];
            } else {
                temp[k++] = arr[j++];
            }
        }

        // 并入区间剩余元素
        while (i <= mid) {
            temp[k++] = arr[i++];
        }
        while (j <= r) {
            temp[k++] = arr[j++];
        }

        // 将排序后的结果复制回原始数组
        for (i = l, j = 0; i <= r; i++, j++) {
            arr[i] = temp[j];
        }
    }

    public static void main(String[] args) {
        int[] arr = {5, 9, 2, 4, 1, 6, 3, 7, 8};
        mergeSort(arr, 0, arr.length - 1);
        for (int i : arr) {
            System.out.print(i + " ");
        }
    }
}

1.2 二分查找

/**
 * 二分查找
 */
public class BinarySearch {

    static int target;

    /**
     * 查找左边界,即第一个满足条件的元素下标 (lower_bound)
     */
    public static int binarySearchL(int l, int r) {
        while (l < r) {
            int mid = l + r >> 1;    // 计算中点值
            if (ge(mid, target)) {
                r = mid;            // 如果中点值符合条件,则继续在左边查找
            } else {
                l = mid + 1;        // 否则在右边查找
            }
        }
        return l;    // 返回左边界
    }

    /**
     * 查找右边界,即最后一个满足条件的元素下标 (upper_bound的前驱)
     */
    public static int binarySearchR(int l, int r) {
        while (l < r) {
            int mid = l + r + 1 >> 1;    // 计算中点值,向右偏移
            if (le(mid, target)) {
                l = mid;                // 如果中点值符合条件(≤),则继续在右边查找
            } else {
                r = mid - 1;            // 否则在左边查找
            }
        }
        return r;    // 返回右边界
    }

    /**
     * 浮点数二分
     */
    public static double binarySearchF(double l, double r) {
        final double eps = 1e-8;        // 精度,视题目而定
        while (r - l > eps) {
            double mid = (l + r) / 2;    // 计算中点值
            if (ge(mid, target)) {
                r = mid;        // 目标在左边,更新右边界
            } else {
                l = mid;        // 否则更新左边界
            }
        }
        return l;    // 返回左边界,即为目标值的估计
    }

    private static boolean ge(int mid, int target) {
        // 自定义比较条件 ...
        return mid >= target;
    }

    private static boolean ge(double mid, double target) {
        // 自定义比较条件 ...
        return mid - target >= 0;
    }

    private static boolean le(int mid, int target) {
        // 自定义比较条件 ...
        return mid <= target;
    }

    public static void main(String[] args) {
        System.out.println(binarySearchL(0, 10));
        System.out.println(binarySearchR(0, 10));
        System.out.println(binarySearchF(0, 10));
    }
}

1.3 高精度运算

1.3.1 ArrayList实现

/**
 * 高精度计算(使用列表存储大数)
 */
public class BigIntCalculation {

    public static int BASE = 10;  // 进制

    /**
     * 高精度加法 C = A + B, A >= 0, B >= 0
     */
    public static ArrayList<Integer> add(ArrayList<Integer> A, ArrayList<Integer> B) {
        if (A.size() < B.size()) {
            return add(B, A);
        }

        ArrayList<Integer> C = new ArrayList<>();
        int t = 0;  // 进位
        for (int i = 0; i < A.size(); i++) {
            t += A.get(i);
            if (i < B.size()) {
                t += B.get(i);
            }
            C.add(t % BASE);
            t /= BASE;
        }

        if (t > 0) {
            C.add(t); // 存入最后的进位
        }
        return C;
    }

    // 比较两个高精度整数的大小,返回A - B的符号
    public static int cmp(ArrayList<Integer> A, ArrayList<Integer> B) {
        if (A.size() > B.size()) {
            return 1;   // 优先比较长度
        } else if (A.size() < B.size()) {
            return -1;
        }

        for (int i = A.size() - 1; i >= 0; i--)  // 从高位起逐位比较
            if (A.get(i) > B.get(i)) {
                return 1;
            } else if (A.get(i) < B.get(i)) {
                return -1;
            }

        return 0;
    }

    /**
     * 高精度减法 C = A - B, A >= B, A >= 0, B >= 0
     */
    public static ArrayList<Integer> sub(ArrayList<Integer> A, ArrayList<Integer> B) {
        ArrayList<Integer> C = new ArrayList<>();
        int t = 0;  // 借位
        for (int i = 0; i < A.size(); i++) {
            t = A.get(i) - t;   // 成为本轮的被减数
            if (i < B.size()) {
                t -= B.get(i); // 先直接相减,t<0则说明需借位
            }
            C.add((t + BASE) % BASE);     // 若t<0,则存的是借位后的差;否则正常存差
            if (t < 0) { // 判断是否需借位
                t = 1;
            } else {
                t = 0;
            }
        }

        while (C.size() > 1 && C.get(C.size() - 1) == 0) {
            C.remove(C.size() - 1);   // 去除前导0(结果为0则保留1位)
        }
        return C;
    }

    /**
     * 高精度乘法 C = A * b, A >= 0, b >= 0
     */
    public static ArrayList<Integer> mul(ArrayList<Integer> A, int b) {
        ArrayList<Integer> C = new ArrayList<>();
        int t = 0;  // 进位
        for (int i = 0; i < A.size() || t > 0; i++) {    // 自动处理最后剩余进位(i>=size但t>0)
            if (i < A.size()) {
                t += A.get(i) * b;
            }
            C.add(t % BASE);
            t /= BASE;
        }

        while (C.size() > 1 && C.get(C.size() - 1) == 0) {
            C.remove(C.size() - 1);   // b为0时,需去除前导0
        }
        return C;
    }

    /**
     * 高精度除法 A / b = C ... r, A >= 0, b > 0
     */
    public static ArrayList<Integer> div(ArrayList<Integer> A, int b, int[] r) {
        ArrayList<Integer> C = new ArrayList<>();
        r[0] = 0;  // 余数
        for (int i = A.size() - 1; i >= 0; i--) {    // 从最高位开始除
            r[0] = r[0] * BASE + A.get(i);
            C.add(r[0] / b); // 暂时将高位存于低位
            r[0] %= b;
        }

        Collections.reverse(C);    // 逆转后即为正常存储形式
        while (C.size() > 1 && C.get(C.size() - 1) == 0) {
            C.remove(C.size() - 1);
        }
        return C;
    }

    // 将以列表存储的大数格式化为字符串
    private static String printf(ArrayList<Integer> num) {
        StringBuilder sb = new StringBuilder();
        for (int i = num.size() - 1; i >= 0; i--) {
            sb.append(num.get(i));
        }
        return sb.toString();
    }
}

1.3.2 BigInteger

import java.math.BigInteger;

public class BigIntegerDemo {

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);

        BigInteger a = new BigInteger(scanner.next());

        BigInteger b = new BigInteger(scanner.next());

        // 读取进制(可选,默认为10)
        int radix = 10;
        try {
            String radixStr = scanner.nextLine();
            if (!radixStr.isEmpty()) {
                radix = Integer.parseInt(radixStr);
            }
        } catch (NumberFormatException e) {
            System.out.println("无效的进制,使用默认值10");
        }

        // 加法运算
        BigInteger sum = a.add(b);
        System.out.println("a + b = " + sum.toString(radix));

        // 减法运算
        BigInteger diff = a.subtract(b);
        System.out.println("a - b = " + diff.toString(radix));

        // 乘法运算
        BigInteger product = a.multiply(b);
        System.out.println("a * b = " + product.toString(radix));

        // 除法运算
        if (b.equals(BigInteger.ZERO)) {
            System.out.println("a / b = Inf");
        } else {
            BigInteger[] result = a.divideAndRemainder(b);
            System.out.println("a / b = " + result[0].toString(radix) + " ... " + result[1].toString(radix));
        }

        // 比较大小
        int cmpResult = a.compareTo(b);
        if (cmpResult > 0) {
            System.out.println("a > b");
        } else if (cmpResult < 0) {
            System.out.println("a < b");
        } else {
            System.out.println("a = b");
        }

        scanner.close();
    }
}

1.4 前缀和、差分

1.4.1 一维

/**
 * 一维前缀和
 */
public class PrefixSumArray {

    static int n = 100010;
    static int[] a = new int[n + 1];
    static int[] s = new int[n + 1];    // 前缀和数组

    /**
     * 初始化前缀和数组
     */
    public static void init() {
        for (int i = 1; i <= n; i++) {
            s[i] = s[i - 1] + a[i];
        }
    }

    /**
     * 求下标区间[l, r]上的片段和
     */
    public static int getPartialSum(int l, int r) {
        return s[r] - s[l - 1];
    }

    public static void main(String[] args) {
        init();
        System.out.println(getPartialSum(1, 3));
    }
}
/**
 * 一维差分
 */
public class DifferenceArray {

    static int n = 100010;
    static int[] a = new int[n + 1];
    static int[] b = new int[n + 1];    // 差分数组

    /**
     * 给区间[l, r]上所有数加上c
     */
    public static void insert(int l, int r, int c) {
        b[l] += c;
        b[r + 1] -= c;
    }

    /**
     * 初始化差分数组
     */
    public static void init() {
        for (int i = 1; i <= n; i++) {
            insert(i, i, a[i]);
        }
    }

    /**
     * 将操作过的差分数组变为原数组(前缀和与差分互为逆运算)
     */
    public static void toOriginalArray() {
        for (int i = 1; i <= n; i++) {
            b[i] += b[i - 1];
        }
    }

    public static void main(String[] args) {
        init();
        toOriginalArray();
    }
}

1.4.2 二维

/**
 * 二维前缀和
 */
public class PrefixSumMatrix {

    static int n = 100010;
    static int m = 100010;
    static int[][] a = new int[n + 1][m + 1];
    static int[][] s = new int[n + 1][m + 1];   // 前缀和矩阵

    /**
     * 初始化前缀和矩阵
     */
    public static void init() {
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= m; j++) {
                s[i][j] = s[i - 1][j] + s[i][j - 1] - s[i - 1][j - 1] + a[i][j];
            }
        }
    }

    /**
     * 求以(x1, y1)为左上角、(x2, y2)为右下角的子矩阵(含边界)上的片段和
     */
    public static int getPartialSum(int x1, int y1, int x2, int y2) {
        return s[x2][y2] - s[x2][y1 - 1] - s[x1 - 1][y2] + s[x1 - 1][y1 - 1];
    }

    public static void main(String[] args) {
        init();
        System.out.println(getPartialSum(1, 1, 2, 2));
    }
}
/**
 * 二维差分
 */
public class DifferenceMatrix {

    static int n = 100010;
    static int m = 100010;
    static int[][] a = new int[n + 1][m + 1];
    static int[][] b = new int[n + 1][m + 1];    // 差分矩阵

    /**
     * 给以(x1, y1)为左上角、(x2, y2)为右下角的子矩阵(含边界)加上c
     */
    public static void insert(int x1, int y1, int x2, int y2, int c) {
        b[x1][y1] += c;
        b[x2 + 1][y1] -= c;
        b[x1][y2 + 1] -= c;
        b[x2 + 1][y2 + 1] += c;
    }

    /**
     * 初始化差分矩阵
     */
    public static void init() {
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= m; j++) {
                insert(i, j, i, j, a[i][j]);
            }
        }
    }

    /**
     * 将操作过的差分矩阵变为原矩阵:求差分矩阵的前缀和
     */
    public static void toOriginalMatrix() {
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= m; j++) {
                b[i][j] += b[i - 1][j] + b[i][j - 1] - b[i - 1][j - 1];
            }
        }
    }

    public static void main(String[] args) {
        init();
        toOriginalMatrix();
    }
}

1.5 位运算

/**
 * 位运算
 */
public class BitwiseOperation {
    /**
     * 返回x的最后一位1
     */
    public static int lowbit(int x) {
        return x & -x;  // -x = ~x + 1
    }

    /**
     * 输出整数x的二进制表示(31位)
     */
    public static void printInBinary(int x) {
        for (int i = 0; i < 31; i++) {
            System.out.print(x >> i & 1);
        }
    }

    /**
     * 统计x的二进制表示中有几位1
     */
    public static int countOnesInBinary(int x) {
        int cnt = 0;
        while (x != 0) {
            x -= lowbit(x);
            cnt++;
        }
        return cnt;
    }

    public static void main(String[] args) {
        System.out.println(lowbit(5));  // 1
        printInBinary(5);  // 101
        System.out.println(countOnesInBinary(5));  // 2
    }
}

1.6 离散化

/**
 * 离散化
 */
public class Discretization {

    static List<Integer> alls = new ArrayList<>();  // 存储所有待离散化的值

    /**
     * 初始化离散化列表(保序)
     */
    public static void init() {
        Collections.sort(alls);     // 将所有值排序
        alls = new ArrayList<>(new HashSet<>(alls));    // 去重
    }

    /**
     * 根据离散化的值k获取原来的值x
     */
    public static int get(int k) {
        return alls.get(k);
    }

    /**
     * 二分求出x对应的离散化的值
     */
    public static int find(int x) {
        int l = 0, r = alls.size() - 1;
        while (l < r) {     // 找到第一个大于等于x的位置(唯一)
            int mid = (l + r) >> 1;
            if (alls.get(mid) >= x) {
                r = mid;
            } else {
                l = mid + 1;
            }
        }
        return r + 1;   // 这里+1是为了映射到1, 2, ..., alls.size()
    }

    public static void main(String[] args) {
        alls.addAll(Arrays.asList(5, 7, 1, 12, 9, 20, 4, 6, 10, 33, 2));
        init();
        for (int i = 0; i < alls.size(); i++) {
            System.out.println(get(i));
        }
        System.out.println(find(2));
        System.out.println(find(4));
    }
}

1.7 区间合并

/**
 * 区间合并
 */
public class SegmentsMerge {

    static List<int[]> segs = new ArrayList<>();    // 每个元素表示一段区间,int[0]表示左端点,int[1]表示右端点

    /**
     * 合并区间
     */
    public static List<int[]> merge() {
        List<int[]> res = new ArrayList<>();

        segs.sort(Comparator.comparingInt(a -> a[0])); // 按左端点大小排序

        int st = Integer.MIN_VALUE, ed = Integer.MIN_VALUE; // 当前维护区间(初始化为负无穷)
        for (int[] seg : segs) {
            if (ed < seg[0]) {    // 若与当前维护区间无交集
                if (st != Integer.MIN_VALUE) {
                    res.add(new int[]{st, ed});    // 当前区间结束维护并保存
                }
                st = seg[0];    // 转移至此区间
                ed = seg[1];
            } else {
                ed = Math.max(ed, seg[1]);    // 有交集则比较右端点
            }
        }

        if (st != Integer.MIN_VALUE) {
            res.add(new int[]{st, ed});    // 保存最后一个区间
        }
        return res;
    }

    public static void main(String[] args) {
        segs.add(new int[]{1, 3});
        segs.add(new int[]{2, 6});
        segs.add(new int[]{8, 10});
        segs.add(new int[]{15, 18});
        System.out.println(merge());
    }
}

2 数据结构

2.1 链表

/**
 * 单链表结点
 */
public class ListNode {

    int val;
    ListNode next;

    ListNode() {
    }

    ListNode(int val) {
        this.val = val;
    }

    ListNode(int val, ListNode next) {
        this.val = val;
        this.next = next;
    }
}
/**
 * 链表
 */
public class LinkedList {
    /**
     * 翻转链表
     */
    public static ListNode reverse(ListNode head) {
        ListNode cur = head, pre = null;
        while (cur != null) {
            ListNode nxt = cur.next;
            cur.next = pre;
            pre = cur;
            cur = nxt;
        }
        return pre;
    }

    /**
     * 快慢指针,寻找环
     */
    public static ListNode findMid(ListNode head) {
        ListNode fast = head, slow = head;
        while (fast != null && fast.next != null) {
            fast = fast.next.next;
            slow = slow.next;
        }
        return slow;
    }

    public static void main(String[] args) {
        ListNode head = new ListNode(1, new ListNode(2, new ListNode(3, new ListNode(4, new ListNode(5)))));
        reverse(head);
        ListNode mid = findMid(head);
        System.out.println(mid.val);
    }
}

2.2 KMP算法

/**
 * KMP算法
 */
public class KMP {
    /**
     * KMP算法模式匹配
     */
    public static int kmp(String str, String pattern) {
        int n = str.length(), m = pattern.length();
        if (m == 0) {
            return 0;
        }

        int[] next = new int[m];
        // 求pattern串的next数组
        for (int i = 1, j = 0; i < m; i++) {
            while (j > 0 && pattern.charAt(i) != pattern.charAt(j)) {
                j = next[j - 1];
            }
            if (pattern.charAt(i) == pattern.charAt(j)) {
                j++;
            }
            next[i] = j;
        }

        // 字符串匹配
        for (int i = 0, j = 0; i < n; i++) {
            while (j > 0 && str.charAt(i) != pattern.charAt(j)) {
                j = next[j - 1];
            }
            if (str.charAt(i) == pattern.charAt(j)) {
                j++;
            }
            if (j == m) {
                return i - m + 1;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        String str = "ababababca";
        String pattern = "ababca";
        System.out.println(kmp(str, pattern));
    }
}

2.3 字典树(Trie)

/**
 * 字典树
 */
public class Trie {

    static final int N = 100010;

    static int[][] son = new int[N][26];        // son[p][u]记录结点p的第u个子结点
    static int[] cnt = new int[N];      // cnt[p]存储以结点p结尾的单词数量
    static int idx = 0;     // idx初始为0(0号点既是根结点,又是空结点)

    /**
     * 插入一个字符串
     */
    public static void insert(String str) {
        int p = 0;    // 从根结点0起遍历Trie的指针
        for (int i = 0; i < str.length(); i++) {
            int u = str.charAt(i) - 'a';    // 将字符转换为对应的索引
            if (son[p][u] == 0) {
                son[p][u] = ++idx;    // 不存在则创建该子结点
            }
            p = son[p][u];    // 更新指针走向当前子结点
        }
        cnt[p]++;    // p最终指向字符串末尾字母,p计数器自增
    }

    /**
     * 查询字符串出现的次数
     */
    public static int query(String str) {
        int p = 0;    // 从根结点开始
        for (int i = 0; i < str.length(); i++) {
            int u = str.charAt(i) - 'a'; // 将字符转换为对应的索引
            if (son[p][u] == 0) {
                return 0;       // 不存在则直接返回0
            }
            p = son[p][u];    // 更新指针至下一个结点
        }
        return cnt[p];    // p最终指向字符串末尾字母,返回数量
    }

    public static void main(String[] args) {
        insert("abc");
        insert("ab");
        insert("abc");
        insert("aab");
        insert("a");
        insert("abc");
        System.out.println(query("abc"));
    }
}

2.4 并查集

2.4.1 朴素并查集

/**
 * 朴素并查集
 */
public class DisjointSet {

    static final int N = 100010;

    static int n = 10;
    static int[] p = new int[N];    // p[1 ... n],p[i]存储结点i的祖先(路径压缩后则为根结点)

    /**
     * 初始化
     */
    public static void init() {
        for (int i = 1; i <= n; i++) {
            p[i] = i;
        }
    }

    /**
     * 并查集核心操作:返回结点x所属集合的根结点,并进行路径压缩
     */
    public static int find(int x) {
        if (p[x] != x) {
            p[x] = find(p[x]);
        }
        return p[x];
    }

    /**
     * 合并结点a和b所在集合:将a并至b
     */
    public static void union(int a, int b) {
        p[find(a)] = find(b);   // 将a的根接在b的根之后
    }

    public static void main(String[] args) {
        init();
        union(1, 2);
        union(2, 3);
        System.out.println(find(1) == find(3));
    }
}

2.4.2 维护集合大小的并查集

/**
 * 维护集合大小的并查集
 */
public class DisjointSetWithSize {

    static final int N = 100010;

    static int n = 10;
    static int[] p = new int[N];
    static int[] cnt = new int[N];  // cnt[i]存储根结点i的集合中结点数(仅根结点的cnt有意义)

    /**
     * 初始化
     */
    public static void init() {
        for (int i = 1; i <= n; i++) {
            p[i] = i;
            cnt[i] = 1; // 初始化各结点为根,其所属集合大小为1
        }
    }

    /**
     * 并查集核心操作
     */
    public static int find(int x) {
        if (p[x] != x) {
            p[x] = find(p[x]);
        }
        return p[x];
    }

    /**
     * 合并结点a和b所在集合:将a并至b
     */
    public static void union(int a, int b) {
        if (find(a) == find(b)) return; // 若已在同一集合内则跳过

        cnt[find(b)] += cnt[find(a)];   // 需将a所属集合的大小加至b
        p[find(a)] = find(b);
    }

    public static void main(String[] args) {
        init();
        union(1, 2);
        union(2, 3);
        System.out.println(cnt[find(1)]);
    }
}

2.4.3 维护到祖宗结点距离的并查集

/**
 * 维护到祖宗结点距离的并查集
 */
public class DisjointSetWithDistance {

    static final int N = 100010;

    static int n = 10;
    static int[] p = new int[N];
    static int[] d = new int[N];    // d[i]存储结点i到其根结点p[i]的距离

    /**
     * 初始化
     */
    public static void init() {
        for (int i = 1; i <= n; i++) {
            p[i] = i;
            d[i] = 0;   // 初始化全为0
        }
    }

    /**
     * 并查集核心操作
     */
    public static int find(int x) {
        if (p[x] != x) {
            int u = find(p[x]);    // u临时记录根结点
            d[x] += d[p[x]];    // 更新x到根p[x]的路径长度
            p[x] = u;
        }
        return p[x];
    }

    /**
     * 合并结点a和b所在集合:将a并至b
     */
    public static void union(int a, int b) {
        p[find(a)] = find(b);
    }

    /**
     * 根据具体问题,初始化根find(a)的偏移量
     */
    public static void setDistance(int a, int distance) {
        d[find(a)] = distance;
    }

    public static void main(String[] args) {
        init();
        union(1, 2);
        union(2, 3);
        setDistance(1, 2);
        System.out.println(find(1));
        System.out.println(d[1]);
    }
}

2.5 堆(优先队列)

public class Heap {
    Queue<Integer> minHeap = new PriorityQueue<>(); // 小根堆
    Queue<Integer> maxHeap = new PriorityQueue<>((a, b) -> b - a); // 大根堆
}

3 图论与搜索

3.1 DFS与BFS

/**
 * 遍历
 */
public class Traversal {

    public static int[][] dirs = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}};

    /**
     * DFS示例:遍历迷宫
     */
    public static void dfsMaze(int[][] maze, int[] u) {
        System.out.println(u[0] + " " + u[1]);
        if (u[0] == maze.length - 1 && u[1] == maze[0].length - 1) {
            return;
        }
        for (int[] dir : dirs) {
            int x = u[0] + dir[0];
            int y = u[1] + dir[1];
            if (x >= 0 && x < maze.length && y >= 0 && y < maze[0].length && maze[x][y] == 0) {
                maze[x][y] = 1;
                dfsMaze(maze, u);
                maze[x][y] = 0;
            }
        }
    }

    /**
     * BFS示例:遍历迷宫
     */
    public static void bfsMaze(int[][] maze, int[] u) {
        Queue<int[]> q = new LinkedList<>();
        q.offer(u);
        while (!q.isEmpty()) {
            int[] cur = q.poll();
            System.out.println(cur[0] + " " + cur[1]);
            if (cur[0] == maze.length - 1 && cur[1] == maze[0].length - 1) {
                return;
            }
            for (int[] dir : dirs) {
                int x = cur[0] + dir[0];
                int y = cur[1] + dir[1];
                if (x >= 0 && x < maze.length && y >= 0 && y < maze[0].length && maze[x][y] == 0) {
                    maze[x][y] = 1;
                    q.offer(new int[]{x, y});
                }
            }
        }
    }
}

3.2 二叉树

/**
 * 二叉树结点
 */
public class TreeNode {

    int val;
    TreeNode left;
    TreeNode right;

    TreeNode() {
    }

    TreeNode(int val) {
        this.val = val;
    }

    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

3.2.1 遍历

/**
 * 二叉树的遍历
 */
public class Traversal {
    /**
     * 层序遍历
     */
    public static void levelOrder(TreeNode root) {
        Queue<TreeNode> q = new LinkedList<>();
        q.offer(root);
        while (!q.isEmpty()) {
            TreeNode cur = q.poll();
            System.out.println(cur.val);
            if (cur.left != null) {
                q.offer(cur.left);
            }
            if (cur.right != null) {
                q.offer(cur.right);
            }
        }
    }

    /**
     * 先序遍历
     */
    public static void preOrder(TreeNode root) {
        if (root == null) {
            return;
        }
        System.out.println(root.val);
        preOrder(root.left);
        preOrder(root.right);
    }

    /**
     * 中序遍历
     */
    public static void inOrder(TreeNode root) {
        if (root == null) {
            return;
        }
        inOrder(root.left);
        System.out.println(root.val);
        inOrder(root.right);
    }

    /**
     * 后序遍历
     */
    public static void postOrder(TreeNode root) {
        if (root == null) {
            return;
        }
        postOrder(root.left);
        postOrder(root.right);
        System.out.println(root.val);
    }

    public static void main(String[] args) {
        TreeNode root = new TreeNode(1, new TreeNode(2, new TreeNode(4), new TreeNode(5)), new TreeNode(3));

        levelOrder(root);
        System.out.println();

        preOrder(root);
        System.out.println();

        inOrder(root);
        System.out.println();

        postOrder(root);
    }
}
/**
 * 二叉树的非递归遍历
 */
public class TraversalNonrecurring {
    /**
     * 非递归先序遍历
     */
    public static List<Integer> preOrder(TreeNode root) {
        List<Integer> res = new ArrayList<>();
        Deque<TreeNode> stack = new ArrayDeque<>();
        while (root != null || !stack.isEmpty()) {
            while (root != null) {
                res.add(root.val);
                stack.push(root);
                root = root.left;
            }
            root = stack.pop();
            root = root.right;
        }
        return res;
    }

    /**
     * 非递归中序遍历
     */
    public static List<Integer> inOrder(TreeNode root) {
        List<Integer> res = new ArrayList<>();
        Deque<TreeNode> stack = new ArrayDeque<>();
        while (root != null || !stack.isEmpty()) {
            while (root != null) {
                stack.push(root);
                root = root.left;
            }
            root = stack.pop();
            res.add(root.val);
            root = root.right;
        }
        return res;
    }

    /**
     * 非递归后序遍历
     */
    public static List<Integer> postOrder(TreeNode root) {
        List<Integer> res = new ArrayList<>();
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode pre = null;
        while (root != null || !stack.isEmpty()) {
            while (root != null) {
                stack.push(root);
                root = root.left;
            }
            root = stack.pop();
            if (root.right == null || root.right == pre) {
                res.add(root.val);
                pre = root;
                root = null;
            } else {
                stack.push(root);
                root = root.right;
            }
        }
        return res;
    }

    public static void main(String[] args) {
        TreeNode root = new TreeNode(1, new TreeNode(2, new TreeNode(4), new TreeNode(5)), new TreeNode(3));
        System.out.println(preOrder(root));
        System.out.println(inOrder(root));
        System.out.println(postOrder(root));
    }
}

3.2.2 构建二叉树

/**
 * 构建二叉树
 */
public class BuildTree {

    private static Map<Integer, Integer> in2Idx;     // 中序数组的值 -> 索引

    /**
     * 根据前序和中序序列建立二叉树
     */
    public static TreeNode buildTree(int[] preorder, int[] inorder) {
        in2Idx = new HashMap<>();
        for (int i = 0; i < inorder.length; i++) {
            in2Idx.put(inorder[i], i);
        }
        return build(preorder, inorder, 0, 0, inorder.length - 1);
    }

    /**
     * DFS递归建树
     */
    private static TreeNode build(int[] preorder, int[] inorder, int preIdx, int inStart, int inEnd) {
        if (inStart > inEnd || preIdx > preorder.length - 1) {
            return null;
        }
        TreeNode root = new TreeNode(preorder[preIdx]);
        int inIdx = in2Idx.get(preorder[preIdx]);
        root.left = build(preorder, inorder, preIdx + 1, inStart, inIdx - 1);
        root.right = build(preorder, inorder, preIdx + inIdx - inStart + 1, inIdx + 1, inEnd);
        return root;
    }

    public static void main(String[] args) {
        int[] preorder = {3, 9, 20, 15, 7};
        int[] inorder = {9, 3, 15, 20, 7};
        TreeNode root = buildTree(preorder, inorder);
        System.out.println(root.val);
    }
}

3.2.3 最近公共祖先(LCA)

/**
 * 最近公共祖先(LCA)
 */
public class LowestCommonAncestor {
    /**
     * 递归解法
     */
    public static TreeNode lca(TreeNode root, TreeNode p, TreeNode q) {
        if (root == null || root == p || root == q) {
            return root;
        }
        TreeNode l = lca(root.left, p, q);
        TreeNode r = lca(root.right, p, q);
        if (l != null && r != null) {
            return root;
        }
        if (l != null) {
            return l;
        }
        return r;
    }

    public static void main(String[] args) {
        TreeNode root = new TreeNode(3, new TreeNode(5, new TreeNode(6), new TreeNode(2, new TreeNode(7), new TreeNode(4))), new TreeNode(1, new TreeNode(0), new TreeNode(8)));
        System.out.println(lca(root, root.left.left, root.left.right).val);
    }
}

3.2.4 二叉搜索树(BST)

/**
 * 二叉搜索树(BST)
 */
public class BinarySearchTree {
    /**
     * 验证BST
     */
    public static boolean isValidBST(TreeNode root) {
        return checkBST(root);
    }

    /**
     * 递归整棵树
     */
    private static boolean checkBST(TreeNode root) {
        if (root == null) {
            return true;
        }
        if (root.left != null && root.left.val >= root.val) {
            return false;
        }
        if (root.right != null && root.right.val <= root.val) {
            return false;
        }
        return checkBST(root.left) && checkBST(root.right);
    }

    public static void main(String[] args) {
        TreeNode root1 = new TreeNode(1, new TreeNode(2, new TreeNode(4), new TreeNode(5)), new TreeNode(3));
        System.out.println(isValidBST(root1));
        TreeNode root2 = new TreeNode(4, new TreeNode(2, new TreeNode(1), new TreeNode(3)), new TreeNode(5));
        System.out.println(isValidBST(root2));
    }
}

4 数学

4.1 素数

/**
 * 质数
 */
public class Primes {
    /**
     * 试除法判断质数 O(sqrt(n))
     */
    public static boolean isPrime(int x) {
        if (x < 2) {
            return false;
        }
        for (int i = 2; i <= x / i; i++) {
            if (x % i == 0) {
                return false;
            }
        }
        return true;
    }

    /**
     * 试除法分解质因数 O(sqrt(n))
     *
     * @return pairs 存储分解所得的质因数及其个数
     */
    public static List<int[]> divide(int x) {
        List<int[]> pairs = new ArrayList<>();  // 存储质因数及其个数
        for (int i = 2; i <= x / i; i++) {
            while (x % i == 0) {
                int cnt = 0;
                while (x % i == 0) {
                    x /= i;
                    cnt++;
                }
                pairs.add(new int[]{i, cnt});
            }
        }
        if (x > 1) {
            pairs.add(new int[]{x, 1});    // x中只包含1个大于sqrt(x)的质因子
        }
        return pairs;
    }

    /**
     * 埃氏筛法求素数表 O(n log log n)
     *
     * @param n 范围[2, n]
     * @return primes 存储所有素数
     */
    public static List<Integer> sieveOfEratosthenes(int n) {
        List<Integer> primes = new ArrayList<>();
        boolean[] st = new boolean[n + 1];  // 标记数i是否被筛掉(非素数)
        for (int i = 2; i <= n; i++) {
            if (!st[i]) {
                primes.add(i);
                for (int j = i + i; j <= n; j += i) { // 筛去i的倍数,朴素法遍历全部倍数
                    st[j] = true;
                }
            }
        }
        return primes;
    }

    /**
     * 线性筛法求素数表 O(n)
     *
     * 核心思想:每个合数只会被其最小质因子筛掉
     *
     * @param n 范围[2, n]
     * @return primes 存储所有素数
     */
    public static List<Integer> getPrimes(int n) {
        List<Integer> primes = new ArrayList<>();
        boolean[] st = new boolean[n + 1];  // 标记数i是否被筛掉(非素数)
        for (int i = 2; i <= n; i++) {
            if (!st[i]) {
                primes.add(i);
            }
            for (int j = 0; primes.get(j) <= n / i; j++) {
                st[i * primes.get(j)] = true;
                if (i % primes.get(j) == 0) {
                    break;  // i的最小质因子是primes.get(j),故i * primes.get(j)的最小质因子必为primes.get(j)
                }
            }
        }
        return primes;
    }

    public static void main(String[] args) {
        System.out.println(isPrime(100) + "\n");

        List<int[]> pairs = divide(100);
        for (int[] pair : pairs) {
            System.out.println(pair[0] + " " + pair[1]);
        }
        System.out.println();

        List<Integer> primes = sieveOfEratosthenes(100);
        for (int prime : primes) {
            System.out.println(prime);
        }
        System.out.println();

        primes = getPrimes(100);
        for (int prime : primes) {
            System.out.println(prime);
        }
    }
}

4.2 约数

/**
 * 约数
 */
public class Divisors {

    private static final int MOD = 1000000007;  // 防止结果过大而溢出

    /**
     * 试除法求所有约数
     *
     * 时间复杂度:取决于排序函数,试除的消耗为 O(sqrt(n))
     */
    public static List<Integer> getDivisors(int x) {
        List<Integer> divisors = new ArrayList<>();
        for (int i = 1; i <= x / i; i++) {  // 枚举到sqrt(x)
            if (x % i == 0) {   // 若i为x的约数,则x/i也是x的约数
                divisors.add(i);
                if (i != x / i) {    // 不重复存储约数sqrt(x)
                    divisors.add(x / i);
                }
            }
        }
        divisors.sort(null);    // 对约数进行排序
        return divisors;
    }

    /**
     * 求约数个数
     */
    public static long getDivisorsCount(int x) {
        List<int[]> pairs = Primes.divide(x);   // 存储分解所得的质因数及其个数
        long cnt = 1;
        for (int[] pair : pairs) {
            cnt = cnt * (pair[1] + 1) % MOD;
        }
        return cnt;
    }

    /**
     * 求约数之和
     */
    public static long getDivisorsSum(int x) {
        List<int[]> pairs = Primes.divide(x);   // 存储分解所得的质因数及其个数
        long sum = 1;
        for (int[] pair : pairs) {
            int p = pair[0], a = pair[1];   // 约数p与指数a
            long tmp = 1;
            while (a-- > 0) {
                tmp = (tmp * p + 1) % MOD;  // 秦九韶算法
            }
            sum = sum * tmp % MOD;
        }
        return sum;
    }

    /**
     * 欧几里得算法 O(log n)
     */
    public static int gcd(int a, int b) {
        return b == 0 ? a : gcd(b, a % b);
    }

    /**
     * 求最小公倍数
     */
    public static int lcm(int a, int b) {
        return a / gcd(a, b) * b;
    }

    public static void main(String[] args) {
        System.out.println(getDivisors(12));
        System.out.println(getDivisorsCount(12));
        System.out.println(getDivisorsSum(12));
        System.out.println(gcd(12, 18));
        System.out.println(lcm(12, 18));
    }
}

4.3 快速幂

/**
 * 快速幂 O(log n)
 */
public class QuickMul {
    /**
     * x^n
     */
    public static double quickMul(double x, long n) {
        double res = 1.0;
        while (n > 0) {
            if (n % 2 == 1) {
                res *= x;
            }
            x *= x;
            n /= 2;
        }
        return res;
    }

    /**
     * 取余 x^n % mod
     */
    public static double quickMulMod(double x, long n, long mod) {
        double res = 1.0;
        while (n > 0) {
            if (n % 2 == 1) {
                res = res * x % mod;
            }
            x = x * x % mod;
            n /= 2;
        }
        return res;
    }

    /**
     * 递归表示
     */
    public static double quickMulRecursion(double x, long n) {
        if (n == 0) {
            return 1.0;
        }
        double y = quickMulRecursion(x, n / 2);
        return n % 2 == 0 ? y * y : y * y * x;
    }

    public static void main(String[] args) {
        System.out.println(quickMul(2, 10));
        System.out.println(quickMulMod(2, 10, 1000000007));
        System.out.println(quickMulRecursion(2, 10));
    }
}

发表评论