⬆︎
×

[PAT-A] 1115 Counting Nodes in a Binary Search Tree

Hyplus目录

Java

import java.util.Scanner;

/**
 * Builds a binary search tree from integers read on standard input and prints
 * how many nodes lie on the deepest level and on the level directly above it.
 */
public class Main {
    /** Root of the tree; {@code null} while nothing has been inserted. */
    static TreeNode root;
    /**
     * {@code n1}: nodes on the deepest level; {@code n2}: nodes one level above it;
     * {@code bottom}: depth of the deepest level, where the root has depth 1.
     */
    static int n1 = 0, n2 = 0, bottom = 0;

    /**
     * Inserts {@code x} as a new leaf. Keys less than or equal to a node go to
     * its left subtree, strictly greater keys go to its right subtree.
     * Also records the new node's depth and updates {@code bottom}.
     *
     * @param x the key to insert
     */
    static void insert(int x) {
        TreeNode fresh = new TreeNode(x);

        // Empty tree: the new node becomes the root at depth 1.
        if (root == null) {
            fresh.depth = 1;
            root = fresh;
            bottom = 1;
            return;
        }

        // Walk down until the child link the key belongs in is free.
        TreeNode at = root;
        while (true) {
            boolean goLeft = x <= at.val;
            TreeNode next = goLeft ? at.left : at.right;
            if (next == null) {
                fresh.depth = at.depth + 1;
                if (goLeft) {
                    at.left = fresh;
                } else {
                    at.right = fresh;
                }
                break;
            }
            at = next;
        }

        if (fresh.depth > bottom) {
            bottom = fresh.depth;
        }
    }

    /**
     * Visits every node under {@code node} and counts those on the deepest
     * level into {@code n1} and those one level above it into {@code n2}.
     *
     * @param node root of the subtree to scan; may be {@code null}
     */
    static void dfs(TreeNode node) {
        if (node == null) {
            return;
        }

        int levelsAboveBottom = bottom - node.depth;
        if (levelsAboveBottom == 0) {
            n1++;
        } else if (levelsAboveBottom == 1) {
            n2++;
        }

        dfs(node.left);
        dfs(node.right);
    }

    /**
     * Reads the key count followed by the keys, builds the tree and prints
     * the two level counts and their sum.
     */
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int n = scanner.nextInt();

        // Start from a clean state before building the tree.
        root = null;
        bottom = 0;
        n1 = 0;
        n2 = 0;

        for (int remaining = n; remaining > 0; remaining--) {
            insert(scanner.nextInt());
        }

        dfs(root);
        int total = n1 + n2;
        System.out.printf("%d + %d = %d\n", n1, n2, total);
        scanner.close();
    }
}

/**
 * A node of the binary search tree: the stored key, links to both children,
 * and the depth assigned to the node when it is inserted.
 */
class TreeNode {
    /** Key stored in this node. */
    int val;
    /** Left child, or {@code null} when absent. */
    TreeNode left;
    /** Right child, or {@code null} when absent. */
    TreeNode right;
    /** Depth of this node; stays 0 until the inserting code assigns it. */
    int depth;

    /**
     * Creates a childless node holding the given key.
     *
     * @param key the key to store in {@link #val}
     */
    TreeNode(int key) {
        val = key;
    }
}

C++

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>

using namespace std;

// Capacity of the node arrays below.
const int N = 1010;

// Number of keys to read; nodes are numbered [1, n] and index 0 means "no node".
int n;        // [1, n]
// l/r: left and right child index of each node, w: key stored in each node,
// idx: index of the most recently allocated node.
int l[N], r[N], w[N], idx;    // left subtree holds keys <=, right subtree holds keys >
// depth: depth recorded for each node, bottom: largest depth seen so far,
// cnt1/cnt2: number of nodes at depth bottom and at depth bottom - 1.
int depth[N], bottom, cnt1, cnt2;

// Inserts key x into the subtree whose root index is stored in u.
// u is a reference to the link (the root variable or a child slot), so an
// empty link is filled in place. d is the depth of the node u refers to.
// Keys <= a node's key go left, larger keys go right.
void insert(int &u, int x, int d) {
    // Follow links downward until an empty one (0) is found, counting levels.
    int *slot = &u;
    while (*slot != 0) {
        const int cur = *slot;
        slot = (x <= w[cur]) ? &l[cur] : &r[cur];
        ++d;
    }

    // Allocate the next node index and hang it on the empty link.
    *slot = ++idx;
    const int node = *slot;
    w[node] = x;
    depth[node] = d;
    if (bottom < d) bottom = d;
}

// Walks the subtree rooted at node index u (0 means empty) and counts nodes
// on the deepest level into cnt1 and nodes one level above it into cnt2.
void dfs(int u) {
    if (u == 0) return;

    const int target = depth[u];
    if (target == bottom) {
        ++cnt1;
    } else if (target == bottom - 1) {
        ++cnt2;
    }

    dfs(l[u]);
    dfs(r[u]);
}

// Reads the key count and the keys from standard input, builds the tree with
// insert(), counts the two lowest levels with dfs(), and prints
// "<lowest> + <second lowest> = <sum>".
int main() {
    // If the count cannot be read, treat the input as empty so the summary
    // line is still printed.
    if (scanf("%d", &n) != 1) n = 0;

    int root = 0;   // 0 means "no node"; insert() fills it in on the first key
    for (int i = 0; i < n; ++i) {
        int x;
        // Stop on truncated or malformed input rather than inserting an
        // uninitialised value.
        if (scanf("%d", &x) != 1) break;
        insert(root, x, 0);
    }

    dfs(root);
    printf("%d + %d = %d\n", cnt1, cnt2, cnt1 + cnt2);

    return 0;
}

发表评论