333. Largest BST Subtree

Given a binary tree, find the largest subtree which is a Binary Search Tree (BST), where largest means subtree with largest number of nodes in it.

Note: A subtree must include all of its descendants.

Example:

Input: [10,5,15,1,8,null,7]

   10 
   / \ 
  5  15 
 / \   \ 
1   8   7

Output: 3
Explanation: The largest BST subtree in this case is the one rooted at 5 (nodes 5, 1, 8).
             The return value is the subtree's size, which is 3.

Follow up: Can you figure out ways to solve it with O(n) time complexity?

在dfs function中,如果以當前node為root的subtree本身就是BST,就不需要再往下dfs用更小的node找BST,因為它底下的任何subtree都不可能比它更大。count function會計算該subtree的node數量;如果它不是valid BST,則回傳-1。

int count(TreeNode* root, int lower, int upper) {
    if (!root) return 0;
    if (root->val <= lower || root->val >= upper) return -1;
    int left = count(root->left, lower, root->val);
    if (left == -1) return -1;
    int right = count(root->right, root->val, upper);
    if (right == -1) return -1;
    return left + right + 1;
}
void dfs(TreeNode* root, int& res) {
    if (!root) return;
    int n = count(root, INT_MIN, INT_MAX);
    if (n != -1) {
        res = max(res, n);
        return;
    }
    dfs(root->left, res);
    dfs(root->right, res);
}
int largestBSTSubtree(TreeNode* root) { // time: O(n^2); space: O(n)
    // dfs only ever raises `best`, so starting from 0 also covers an empty tree.
    int best = 0;
    dfs(root, best);
    return best;
}
/// Returns the total number of nodes in the tree rooted at `root` (0 if empty).
int count(TreeNode* root) {
    if (root == nullptr) {
        return 0;
    }
    const int leftSize = count(root->left);
    const int rightSize = count(root->right);
    return 1 + leftSize + rightSize;
}
bool isValid(TreeNode* root, int lower, int upper) {
    if (!root) return true;
    if (root->val <= lower || root->val >= upper) return false;
    return isValid(root->left, lower, root->val) && isValid(root->right, root->val, upper);
}
int largestBSTSubtree(TreeNode* root) { // time: O(n^2); space: O(n)
    if (!root) return 0;
    if (isValid(root, INT_MIN, INT_MAX)) return count(root);
    return max(largestBSTSubtree(root->left), largestBSTSubtree(root->right));
}

為了避免重複計算,可以改用post-order(先處理左右子樹,再處理當前node)並搭配額外變數:讓helper function先直直往下走到leaf node,再一層層往上回傳要記錄的值(子樹的最小值、最大值,以及子樹內最大BST的node數)。這樣每個node只會被拜訪一次,時間複雜度為O(n)。

// Follow-Up O(n) Method
/**
 * Post-order pass over the subtree rooted at `root`.
 *
 * Writes to `res` the size of the largest BST contained in the subtree. When
 * the whole subtree is itself a BST, `lower`/`upper` are overwritten with its
 * minimum/maximum values. Otherwise they are left untouched, so the caller's
 * initial values act as a "not a BST" marker. For a null `root` nothing is
 * written.
 */
void helper(TreeNode* root, int& lower, int& upper, int& res) {
    if (root == nullptr) return;

    // If a child turns out not to be a BST, its bounds keep these sentinel
    // values, which makes the ordering checks below fail for that side.
    int lhsMin = INT_MIN, lhsMax = INT_MAX, lhsSize = 0;
    int rhsMin = INT_MIN, rhsMax = INT_MAX, rhsSize = 0;
    helper(root->left, lhsMin, lhsMax, lhsSize);
    helper(root->right, rhsMin, rhsMax, rhsSize);

    const bool leftOk = root->left == nullptr || lhsMax < root->val;
    const bool rightOk = root->right == nullptr || rhsMin > root->val;
    if (!(leftOk && rightOk)) {
        // This subtree is not a BST: report the best found beneath it.
        res = max(lhsSize, rhsSize);
        return;
    }
    res = lhsSize + rhsSize + 1;
    lower = root->left != nullptr ? lhsMin : root->val;
    upper = root->right != nullptr ? rhsMax : root->val;
}
int largestBSTSubtree(TreeNode* root) { // time: O(n); space: O(n)
    // helper needs bound out-parameters; only the reported size is used here.
    int lower = INT_MIN;
    int upper = INT_MAX;
    int best = 0;
    helper(root, lower, upper, best);
    return best;
}

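helper function裡如果root是nullptr,就回傳{INT_MAX, INT_MIN, 0}。空子樹的min設成INT_MAX、max設成INT_MIN,是為了讓它在parent的判斷中保持「中性」:比較 root->val > left[1] 與 root->val < right[0] 時一般不會被空子樹擋下,而 min(left[0], root->val)、max(right[1], root->val) 會直接取到root->val;count為0,所以leaf node會得到 0 + 0 + 1 = 1。反之,若子樹不是BST,就回傳{INT_MIN, INT_MAX, 子樹內最大BST的大小},讓parent的判斷一定不成立。這樣recursive function才能一層層往上正確運作。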

// Avoid using too many variables
/**
 * Post-order helper returning {min, max, size} for the subtree rooted at `root`:
 *  - empty subtree:      {INT_MAX, INT_MIN, 0}, neutral for the min/max below;
 *  - subtree is a BST:   {its minimum value, its maximum value, its node count};
 *  - subtree not a BST:  {INT_MIN, INT_MAX, size of the largest BST inside it}.
 *    These bounds make every parent's ordering check fail.
 */
vector<int> helper(TreeNode* root) {
    if (!root) return {INT_MAX, INT_MIN, 0};
    const vector<int> left = helper(root->left);
    const vector<int> right = helper(root->right);
    // Treat empty children explicitly instead of relying on the null sentinel.
    // With only `root->val > left[1]`, a node whose value is INT_MIN (and with
    // no left child) or INT_MAX (and with no right child) would be rejected.
    const bool leftOk = !root->left || root->val > left[1];
    const bool rightOk = !root->right || root->val < right[0];
    if (leftOk && rightOk) {
        return {min(left[0], root->val), max(right[1], root->val), left[2] + right[2] + 1};
    }
    return {INT_MIN, INT_MAX, max(left[2], right[2])};
}
int largestBSTSubtree(TreeNode* root) { // time: O(n); space: O(n)
    // helper returns {min, max, count}; only the count (index 2) is needed.
    return helper(root)[2];
}

Last updated

Was this helpful?