⬆︎
×

[PAT-A] 1135 Is It A Red-Black Tree

Hyplus目录

Java

package PAT_A1135_Is_It_a_Red_Black_Tree;

import java.util.LinkedList;
import java.util.Scanner;

public class Main {
    /** Value of {@link TreeNode#color} for a red node. */
    static final int RED = 0;
    /** Value of {@link TreeNode#color} for a black node. */
    static final int BLACK = 1;

    /**
     * Black-node count of the first root-to-null path reached by the current {@link #dfs} run;
     * -1 means no path has been reached yet. Reset it to -1 before each top-level {@code dfs} call.
     */
    static int blackNum = -1;

    /** A binary-search-tree node that also carries a red-black colour. */
    static class TreeNode {
        int color;  // RED (0) or BLACK (1)
        int key;
        TreeNode lChild;
        TreeNode rChild;

        public TreeNode(int color, int key) {
            this.color = color;
            this.key = key;
            // lChild / rChild default to null
        }
    }

    /**
     * Inserts {@code key} with the given colour into the BST rooted at {@code root}.
     * If the key is already present the tree is left unchanged (the existing colour is kept).
     *
     * @param root  subtree root, may be null
     * @param color RED or BLACK
     * @param key   key to insert
     * @return the root of the subtree after insertion
     */
    static TreeNode insert(TreeNode root, int color, int key) {
        if (root == null) {
            return new TreeNode(color, key);
        }
        if (key < root.key) {
            root.lChild = insert(root.lChild, color, key);
        } else if (key > root.key) {
            root.rChild = insert(root.rChild, color, key);
        }
        return root;
    }

    /**
     * Checks that every path from {@code root} down to a null child contains the same number of
     * black nodes. The first null reached records its count in {@link #blackNum}; every later
     * null must match it.
     *
     * @param root subtree to check, may be null
     * @param num  black nodes already counted on the path above {@code root}
     * @return {@code true} if all paths agree with {@link #blackNum}
     */
    static boolean dfs(TreeNode root, int num) {
        if (root == null) {
            if (blackNum == -1) {
                blackNum = num;
                return true;
            }
            return blackNum == num;
        }
        int below = root.color == BLACK ? num + 1 : num;
        // Stop as soon as one path disagrees; the result cannot become true again.
        return dfs(root.lChild, below) && dfs(root.rChild, below);
    }

    /**
     * Decides whether the BST rooted at {@code root} is a valid red-black tree. The checks are:
     * the root is black, no red node has a red child, and all root-to-null paths contain the
     * same number of black nodes. An empty tree is accepted.
     *
     * @param root tree root, may be null
     * @return {@code true} if all red-black properties hold
     */
    static boolean bfs(TreeNode root) {
        if (root == null) {
            return true;
        }
        if (root.color == RED) {
            return false;
        }

        // Level-order scan for the "a red node has only black children" rule.
        LinkedList<TreeNode> queue = new LinkedList<>();
        queue.add(root);
        while (!queue.isEmpty()) {
            TreeNode node = queue.remove();
            if (node.color == RED && (isRed(node.lChild) || isRed(node.rChild))) {
                return false;
            }
            if (node.lChild != null) {
                queue.add(node.lChild);
            }
            if (node.rChild != null) {
                queue.add(node.rChild);
            }
        }

        // One check from the root suffices. If two paths below some node had different black
        // counts, extending both up to the root would give two root paths that also differ.
        blackNum = -1;
        return dfs(root, 0);
    }

    /** Returns true if {@code node} exists and is red; null children count as black. */
    private static boolean isRed(TreeNode node) {
        return node != null && node.color == RED;
    }

    /**
     * Reads K test cases from standard input. Each case is N followed by N preorder keys, where a
     * negative value marks a red node. Prints "Yes" or "No" for each case.
     */
    public static void main(String[] args) {
        try (Scanner sc = new Scanner(System.in)) {
            int k = sc.nextInt();
            for (int i = 0; i < k; i++) {
                int num = sc.nextInt();
                TreeNode root = null;
                for (int j = 0; j < num; j++) {
                    int temp = sc.nextInt();
                    root = insert(root, temp < 0 ? RED : BLACK, Math.abs(temp));
                }
                System.out.println(bfs(root) ? "Yes" : "No");
            }
        }
    }
}

C++

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <utility>
#include <vector>

using namespace std;

// Per-query state. Everything is indexed by a node's position in the preorder input, never by
// its key, so keys of any magnitude are handled without a fixed-size, key-indexed buffer.
int n, Q;
vector<int> inpos;      // inpos[i]: 0-based in-order rank of the i-th preorder node
vector<bool> is_black;  // is_black[i]: true if the i-th preorder node is black
bool flag = true;       // cleared as soon as any BST / red-black violation is found

// True when idx names an existing red node; -1 stands for a null child, which counts as black.
bool is_red(int idx) { return idx >= 0 && !is_black[idx]; }

// Rebuilds the subtree whose nodes have in-order ranks [inL, inR] and preorder positions
// [preL, preR] (two ranges of equal length), checking red-black rules on the way back up.
// Returns the preorder position of the subtree root, or -1 for an empty range. Also returns
// -1 (and clears flag) when the input is not a valid BST preorder.
// sum receives the subtree's black height: black nodes on a path from its root down to null.
int build(int inL, int inR, int preL, int preR, int &sum) {
    sum = 0;
    if (inL > inR) return -1;

    int root = preL;           // in preorder, the first node of a range is the subtree root
    int k = inpos[root];
    if (k < inL || k > inR) {  // root's key lies outside this subtree's key range
        flag = false;
        return -1;
    }

    int leftSize = k - inL;
    int lsum = 0, rsum = 0;
    int lchild = build(inL, k - 1, preL + 1, preL + leftSize, lsum);
    int rchild = build(k + 1, inR, preL + leftSize + 1, preR, rsum);

    if (lsum != rsum) flag = false;  // unequal black heights below this node
    if (!is_black[root] && (is_red(lchild) || is_red(rchild))) flag = false;  // red-red edge

    sum = lsum + (is_black[root] ? 1 : 0);
    return root;
}

// Reads Q test cases. Each case is n followed by n preorder keys, where a negative value marks
// a red node. Prints "Yes" if the tree is a valid red-black tree, otherwise "No".
int main() {
    if (scanf("%d", &Q) != 1) return 0;
    while (Q--) {
        if (scanf("%d", &n) != 1) return 0;
        flag = true;
        inpos.assign(n, 0);
        is_black.assign(n, true);

        // Pair each key with its preorder position; sorting the pairs gives the in-order sequence.
        vector<pair<int, int> > keyed(n);
        for (int i = 0; i < n; ++i) {
            int x = 0;
            if (scanf("%d", &x) != 1) return 0;
            is_black[i] = (x >= 0);
            keyed[i] = make_pair(abs(x), i);
        }
        sort(keyed.begin(), keyed.end());
        for (int r = 0; r < n; ++r) inpos[keyed[r].second] = r;

        int sum = 0;
        int root = build(0, n - 1, 0, n - 1, sum);
        if (is_red(root)) flag = false;  // the root must be black

        printf("%s\n", flag ? "Yes" : "No");
    }

    return 0;
}

发表评论