230. Kth Smallest Element in a BST

Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

Note: You may assume k is always valid, i.e. 1 ≤ k ≤ the total number of nodes in the BST.

Example 1:

Input: root = [3,1,4,null,2], k = 1
   3
  / \
 1   4
  \
   2
Output: 1

Example 2:

Input: root = [5,3,6,2,4,null,null,1], k = 3
       5
      / \
     3   6
    / \
   2   4
  /
 1
Output: 3

Follow-up: What if the BST is modified often (insert/delete operations) and you need to find the kth smallest element frequently? How would you optimize the kthSmallest routine?

// Inorder Traversal Iteration with Stack
// Iterative in-order traversal with an explicit stack.
// Nodes are popped in ascending order; the k-th node popped is the answer.
// Returns -1 if the tree has fewer than k nodes.
// time: O(h + k); space: O(h), where h is the tree height.
int kthSmallest(TreeNode* root, int k) {
    std::stack<TreeNode*> pending;
    TreeNode* node = root;
    int remaining = k;
    for (;;) {
        // Descend as far left as possible, remembering the path.
        for (; node != nullptr; node = node->left) pending.push(node);
        if (pending.empty()) break;
        node = pending.top();
        pending.pop();
        if (--remaining == 0) return node->val;
        // Everything in the right subtree comes next in sorted order.
        node = node->right;
    }
    return -1;
}
// Inorder Traversal Recursion
// Recursive in-order visit used by the recursive kthSmallest.
// `n` counts down the nodes still to visit. When it reaches 0, the current
// node is the k-th smallest and its value is stored in `res`.
// After the answer is found (n <= 0), every pending call returns immediately.
// The traversal therefore touches O(h + k) nodes instead of walking the rest
// of the tree.
void helper(TreeNode* node, int& n, int& res) {
    if (!node || n <= 0) return;
    helper(node->left, n, res);
    if (n <= 0) return;  // answer was already found in the left subtree
    if (--n == 0) {
        res = node->val;
        return;
    }
    helper(node->right, n, res);
}
// Recursive in-order driver: `helper` counts down from k while visiting nodes
// in ascending order and records the k-th value.
// Returns -1 if the tree has fewer than k nodes.
int kthSmallest(TreeNode* root, int k) {
    int answer = -1;
    int countdown = k;
    helper(root, countdown, answer);
    return answer;
}
// Follow-Up
// BST node augmented with `count`, the number of nodes in the subtree rooted
// at this node (itself included). A newly constructed node has no children,
// so its count starts at 1; `build` adds the children's counts.
struct TreeCountNode {
    int val;
    int count = 1;
    TreeCountNode* left = nullptr;
    TreeCountNode* right = nullptr;

    TreeCountNode(int v) : val(v) {}
};

// Deep-copies the tree rooted at `root` into heap-allocated TreeCountNodes.
// Each copy's `count` is set to the size of its subtree.
// Returns nullptr for an empty tree. The caller owns the returned nodes.
TreeCountNode* build(TreeNode* root) {
    if (root == nullptr) return nullptr;
    TreeCountNode* copy = new TreeCountNode(root->val);
    copy->left = build(root->left);
    copy->right = build(root->right);
    const int leftSize = copy->left ? copy->left->count : 0;
    const int rightSize = copy->right ? copy->right->count : 0;
    copy->count = 1 + leftSize + rightSize;
    return copy;
}

// Finds the k-th smallest value (1-based) in a size-augmented BST.
// It descends one level per step using the left subtree's `count`:
//   k <= leftCount       -> the answer is in the left subtree
//   k == leftCount + 1   -> the answer is this node
//   otherwise            -> skip leftCount + 1 nodes and go right
// Returns -1 if `node` is null or k is out of range; this matches the
// sentinel used by the other kthSmallest solutions.
// time: O(h); space: O(1).
int helper(TreeCountNode* node, int k) {
    while (node) {
        const int leftCount = node->left ? node->left->count : 0;
        if (k <= leftCount) {
            node = node->left;
        } else if (k == leftCount + 1) {
            return node->val;
        } else {
            k -= leftCount + 1;
            node = node->right;
        }
    }
    return -1;
}

// Follow-up solution: build a size-augmented copy of the BST, then descend by
// subtree sizes. Returns -1 for an empty tree.
// The augmented copy is freed before returning, so repeated calls don't leak.
// NOTE(review): rebuilding the copy costs O(n) on every call. For the real
// follow-up (frequent inserts/deletes and queries), keep the augmented tree
// alive and update `count` along each insert/delete path. Each query is then
// O(h).
int kthSmallest(TreeNode* root, int k) {
    TreeCountNode* node = build(root);
    if (!node) return -1;  // empty tree: nothing to search
    const int res = helper(node, k);
    // Free the copy iteratively so very deep trees don't overflow the stack.
    std::stack<TreeCountNode*> toFree;
    toFree.push(node);
    while (!toFree.empty()) {
        TreeCountNode* cur = toFree.top();
        toFree.pop();
        if (cur->left) toFree.push(cur->left);
        if (cur->right) toFree.push(cur->right);
        delete cur;
    }
    return res;
}

Last updated

Was this helpful?