05-二叉搜索树(BST)

💡 核心结论

BST本质

  • 定义:左子树所有节点 < 根 < 右子树所有节点

  • 性能:平均O(log n),最坏O(n)(退化成链表)

  • 中序遍历:得到有序序列(BST的重要性质)

  • 查找效率:通过比较可以排除一半节点

  • 局限:可能不平衡,需要AVL树或红黑树改进

关键操作

操作

平均

最坏

关键点

查找

O(log n)

O(n)

比较大小决定方向

插入

O(log n)

O(n)

找到位置后O(1)插入

删除

O(log n)

O(n)

三种情况分别处理

最小/最大

O(log n)

O(n)

一直往左/右

删除节点三种情况

  1. 叶子节点:直接删除

  2. 一个子节点:用子节点替换

  3. 两个子节点:用后继节点(右子树最小)替换

BST vs 数组 vs 链表

操作

BST

有序数组

链表

查找

O(log n)

O(log n)

O(n)

插入

O(log n)

O(n)

O(1)*

删除

O(log n)

O(n)

O(1)*

有序遍历

O(n)

O(n)

O(n log n)

*已知位置

应用场景

  • 动态有序数据维护

  • 范围查询(findMin, findMax, floor, ceiling)

  • 数据库索引(B树、B+树是BST的扩展)

  • 文件系统

🔍 查找操作

递归查找

def search(root, val):
    if not root or root.val == val:
        return root
    if val < root.val:
        return search(root.left, val)
    return search(root.right, val)

迭代查找(推荐)

def search_iterative(root, val):
    while root and root.val != val:
        if val < root.val:
            root = root.left
        else:
            root = root.right
    return root

➕ 插入操作

递归插入

def insert(root, val):
    if not root:
        return TreeNode(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    return root

➖ 删除操作(最复杂)

删除步骤

def delete(root, val):
    if not root:
        return None
    
    # 查找节点
    if val < root.val:
        root.left = delete(root.left, val)
    elif val > root.val:
        root.right = delete(root.right, val)
    else:
        # 找到要删除的节点
        
        # 情况1:叶子节点或只有一个子节点
        if not root.left:
            return root.right
        if not root.right:
            return root.left
        
        # 情况2:有两个子节点
        # 找右子树的最小节点(后继)
        successor = find_min(root.right)
        root.val = successor.val
        # 删除后继节点
        root.right = delete(root.right, successor.val)
    
    return root

def find_min(root):
    while root.left:
        root = root.left
    return root

📚 LeetCode练习

💻 完整代码实现

Python 实现

  1"""
  2二叉搜索树(BST)实现
  3"""
  4
  5class TreeNode:
  6    """BST节点"""
  7    
  8    def __init__(self, val=0):
  9        self.val = val
 10        self.left = None
 11        self.right = None
 12
 13
 14class BST:
 15    """二叉搜索树"""
 16    
 17    def __init__(self):
 18        self.root = None
 19    
 20    # ========== 查找 ==========
 21    
 22    def search(self, val):
 23        """查找节点(迭代)"""
 24        curr = self.root
 25        while curr and curr.val != val:
 26            if val < curr.val:
 27                curr = curr.left
 28            else:
 29                curr = curr.right
 30        return curr
 31    
 32    def search_recursive(self, root, val):
 33        """查找节点(递归)"""
 34        if not root or root.val == val:
 35            return root
 36        if val < root.val:
 37            return self.search_recursive(root.left, val)
 38        return self.search_recursive(root.right, val)
 39    
 40    # ========== 插入 ==========
 41    
 42    def insert(self, val):
 43        """插入节点"""
 44        self.root = self._insert(self.root, val)
 45    
 46    def _insert(self, root, val):
 47        """插入辅助函数(递归)"""
 48        if not root:
 49            return TreeNode(val)
 50        
 51        if val < root.val:
 52            root.left = self._insert(root.left, val)
 53        elif val > root.val:
 54            root.right = self._insert(root.right, val)
 55        # val == root.val 则不插入(BST不存重复值)
 56        
 57        return root
 58    
 59    def insert_iterative(self, val):
 60        """插入节点(迭代)"""
 61        if not self.root:
 62            self.root = TreeNode(val)
 63            return
 64        
 65        curr = self.root
 66        while True:
 67            if val < curr.val:
 68                if not curr.left:
 69                    curr.left = TreeNode(val)
 70                    return
 71                curr = curr.left
 72            elif val > curr.val:
 73                if not curr.right:
 74                    curr.right = TreeNode(val)
 75                    return
 76                curr = curr.right
 77            else:
 78                return  # 已存在
 79    
 80    # ========== 删除 ==========
 81    
 82    def delete(self, val):
 83        """删除节点"""
 84        self.root = self._delete(self.root, val)
 85    
 86    def _delete(self, root, val):
 87        """删除辅助函数"""
 88        if not root:
 89            return None
 90        
 91        # 查找节点
 92        if val < root.val:
 93            root.left = self._delete(root.left, val)
 94        elif val > root.val:
 95            root.right = self._delete(root.right, val)
 96        else:
 97            # 找到要删除的节点
 98            
 99            # 情况1:没有子节点或只有一个子节点
100            if not root.left:
101                return root.right
102            if not root.right:
103                return root.left
104            
105            # 情况2:有两个子节点
106            # 用右子树的最小节点(后继)替换
107            successor = self._find_min(root.right)
108            root.val = successor.val
109            root.right = self._delete(root.right, successor.val)
110        
111        return root
112    
113    # ========== 辅助方法 ==========
114    
115    def _find_min(self, root):
116        """找最小节点"""
117        while root.left:
118            root = root.left
119        return root
120    
121    def _find_max(self, root):
122        """找最大节点"""
123        while root.right:
124            root = root.right
125        return root
126    
127    def find_min(self):
128        """找最小值"""
129        if not self.root:
130            return None
131        return self._find_min(self.root).val
132    
133    def find_max(self):
134        """找最大值"""
135        if not self.root:
136            return None
137        return self._find_max(self.root).val
138    
139    # ========== 遍历 ==========
140    
141    def inorder(self):
142        """中序遍历(有序输出)"""
143        result = []
144        self._inorder(self.root, result)
145        return result
146    
147    def _inorder(self, root, result):
148        if root:
149            self._inorder(root.left, result)
150            result.append(root.val)
151            self._inorder(root.right, result)
152    
153    # ========== 验证 ==========
154    
155    def is_valid_bst(self):
156        """验证是否为有效BST"""
157        def validate(node, min_val, max_val):
158            if not node:
159                return True
160            if not (min_val < node.val < max_val):
161                return False
162            return (validate(node.left, min_val, node.val) and
163                    validate(node.right, node.val, max_val))
164        
165        return validate(self.root, float('-inf'), float('inf'))
166    
167    # ========== 打印 ==========
168    
169    def print_tree(self, node=None, level=0, prefix="Root: "):
170        """打印树结构"""
171        if node is None:
172            node = self.root
173        
174        if node:
175            print(" " * (level * 4) + prefix + str(node.val))
176            if node.left or node.right:
177                if node.left:
178                    self.print_tree(node.left, level + 1, "L--- ")
179                else:
180                    print(" " * ((level + 1) * 4) + "L--- None")
181                
182                if node.right:
183                    self.print_tree(node.right, level + 1, "R--- ")
184                else:
185                    print(" " * ((level + 1) * 4) + "R--- None")
186
187
188def demo():
189    """演示BST操作"""
190    print("=== 二叉搜索树演示 ===\n")
191    
192    bst = BST()
193    
194    # 插入节点
195    values = [5, 3, 7, 2, 4, 6, 8]
196    print(f"插入节点: {values}")
197    for val in values:
198        bst.insert(val)
199    
200    print("\n树结构:")
201    bst.print_tree()
202    
203    # 遍历
204    print(f"\n中序遍历(有序): {bst.inorder()}")
205    
206    # 查找
207    search_val = 4
208    found = bst.search(search_val)
209    print(f"\n查找 {search_val}: {'找到' if found else '未找到'}")
210    
211    # 最小最大值
212    print(f"最小值: {bst.find_min()}")
213    print(f"最大值: {bst.find_max()}")
214    
215    # 验证
216    print(f"是否为有效BST: {bst.is_valid_bst()}")
217    
218    # 删除
219    print(f"\n删除节点 3")
220    bst.delete(3)
221    print("树结构:")
222    bst.print_tree()
223    print(f"中序遍历: {bst.inorder()}")
224
225
226if __name__ == '__main__':
227    demo()
228

C++ 实现

  1/**
  2 * 二叉搜索树(BST)实现
  3 */
  4
  5#include <iostream>
  6#include <vector>
  7
  8struct TreeNode {
  9    int val;
 10    TreeNode* left;
 11    TreeNode* right;
 12    
 13    TreeNode(int v) : val(v), left(nullptr), right(nullptr) {}
 14};
 15
 16
 17class BST {
 18private:
 19    TreeNode* root;
 20    
 21    TreeNode* insertHelper(TreeNode* node, int val) {
 22        if (!node) {
 23            return new TreeNode(val);
 24        }
 25        
 26        if (val < node->val) {
 27            node->left = insertHelper(node->left, val);
 28        } else if (val > node->val) {
 29            node->right = insertHelper(node->right, val);
 30        }
 31        
 32        return node;
 33    }
 34    
 35    TreeNode* deleteHelper(TreeNode* node, int val) {
 36        if (!node) {
 37            return nullptr;
 38        }
 39        
 40        if (val < node->val) {
 41            node->left = deleteHelper(node->left, val);
 42        } else if (val > node->val) {
 43            node->right = deleteHelper(node->right, val);
 44        } else {
 45            // 找到要删除的节点
 46            
 47            // 情况1:没有子节点或只有一个子节点
 48            if (!node->left) {
 49                TreeNode* temp = node->right;
 50                delete node;
 51                return temp;
 52            }
 53            if (!node->right) {
 54                TreeNode* temp = node->left;
 55                delete node;
 56                return temp;
 57            }
 58            
 59            // 情况2:有两个子节点
 60            TreeNode* successor = findMin(node->right);
 61            node->val = successor->val;
 62            node->right = deleteHelper(node->right, successor->val);
 63        }
 64        
 65        return node;
 66    }
 67    
 68    TreeNode* findMin(TreeNode* node) const {
 69        while (node->left) {
 70            node = node->left;
 71        }
 72        return node;
 73    }
 74    
 75    TreeNode* findMax(TreeNode* node) const {
 76        while (node->right) {
 77            node = node->right;
 78        }
 79        return node;
 80    }
 81    
 82    void inorderHelper(TreeNode* node, std::vector<int>& result) const {
 83        if (node) {
 84            inorderHelper(node->left, result);
 85            result.push_back(node->val);
 86            inorderHelper(node->right, result);
 87        }
 88    }
 89    
 90    void printTreeHelper(TreeNode* node, int level, std::string prefix) const {
 91        if (node) {
 92            std::cout << std::string(level * 4, ' ') << prefix << node->val << std::endl;
 93            if (node->left || node->right) {
 94                printTreeHelper(node->left, level + 1, "L--- ");
 95                printTreeHelper(node->right, level + 1, "R--- ");
 96            }
 97        }
 98    }
 99    
100    void deleteTree(TreeNode* node) {
101        if (node) {
102            deleteTree(node->left);
103            deleteTree(node->right);
104            delete node;
105        }
106    }
107
108public:
109    BST() : root(nullptr) {}
110    
111    ~BST() {
112        deleteTree(root);
113    }
114    
115    void insert(int val) {
116        root = insertHelper(root, val);
117    }
118    
119    void remove(int val) {
120        root = deleteHelper(root, val);
121    }
122    
123    TreeNode* search(int val) const {
124        TreeNode* curr = root;
125        while (curr && curr->val != val) {
126            if (val < curr->val) {
127                curr = curr->left;
128            } else {
129                curr = curr->right;
130            }
131        }
132        return curr;
133    }
134    
135    int getMin() const {
136        if (!root) throw std::runtime_error("Empty tree");
137        return findMin(root)->val;
138    }
139    
140    int getMax() const {
141        if (!root) throw std::runtime_error("Empty tree");
142        return findMax(root)->val;
143    }
144    
145    std::vector<int> inorder() const {
146        std::vector<int> result;
147        inorderHelper(root, result);
148        return result;
149    }
150    
151    void printTree() const {
152        printTreeHelper(root, 0, "Root: ");
153    }
154};
155
156
157int main() {
158    std::cout << "=== 二叉搜索树演示 ===" << std::endl << std::endl;
159    
160    BST bst;
161    
162    // 插入
163    std::vector<int> values = {5, 3, 7, 2, 4, 6, 8};
164    std::cout << "插入节点: ";
165    for (int val : values) {
166        std::cout << val << " ";
167        bst.insert(val);
168    }
169    std::cout << std::endl << std::endl;
170    
171    std::cout << "树结构:" << std::endl;
172    bst.printTree();
173    
174    // 遍历
175    std::cout << "\n中序遍历(有序): [";
176    std::vector<int> inorder_result = bst.inorder();
177    for (size_t i = 0; i < inorder_result.size(); i++) {
178        std::cout << inorder_result[i];
179        if (i < inorder_result.size() - 1) std::cout << ", ";
180    }
181    std::cout << "]" << std::endl;
182    
183    // 查找
184    int search_val = 4;
185    TreeNode* found = bst.search(search_val);
186    std::cout << "\n查找 " << search_val << ": " 
187              << (found ? "找到" : "未找到") << std::endl;
188    
189    // 最小最大值
190    std::cout << "最小值: " << bst.getMin() << std::endl;
191    std::cout << "最大值: " << bst.getMax() << std::endl;
192    
193    // 删除
194    std::cout << "\n删除节点 3" << std::endl;
195    bst.remove(3);
196    std::cout << "树结构:" << std::endl;
197    bst.printTree();
198    
199    return 0;
200}