10-并查集(Union-Find)

💡 核心结论

并查集本质

  • 定义:维护元素分组,支持快速合并和查询

  • 核心操作:union(合并)、find(查找)

  • 时间复杂度:单次操作均摊接近O(1)(O(α(n)),α为阿克曼函数的反函数;需同时使用路径压缩和按秩合并)

  • 空间复杂度:O(n)

  • 应用:连通性问题、最小生成树、动态连通性

两大优化

  1. 路径压缩:查找时将路径上所有节点直接连到根

  2. 按秩合并:将秩(树高的上界)较小的树挂到秩较大的树的根下,保持树的平衡

优化效果

| 优化 | 时间复杂度(单次操作) |
| --- | --- |
| 无优化 | O(n) |
| 只路径压缩 | 均摊 O(log n) |
| 只按秩合并 | O(log n) |
| 两者结合 | 均摊 O(α(n)) ≈ O(1) |

应用场景(重要)

  • 连通性问题:判断两点是否连通

  • 最小生成树:Kruskal算法

  • 动态连通性:动态添加边

  • 朋友圈问题:社交网络分组

  • 岛屿数量:DFS的替代方案

🎯 基本实现

简单版本

class UnionFind:
    """Disjoint-set forest with path compression (no union by rank).

    Elements are the integers ``0 .. n-1``. ``parent[i]`` is the parent of
    ``i``; a root is its own parent.
    """

    def __init__(self, n):
        # Every element starts as the root of its own singleton set.
        self.parent = list(range(n))

    def find(self, x):
        """Return the root of the set containing ``x``, compressing the path.

        Implemented iteratively: without union by rank a sequence of unions
        can build a chain of depth ``n``, and a recursive walk over such a
        chain would exceed Python's recursion limit.
        """
        parent = self.parent

        # First pass: walk up to the root.
        root = x
        while parent[root] != root:
            root = parent[root]

        # Second pass: point every node on the walked path directly at the root.
        while parent[x] != root:
            parent[x], x = root, parent[x]

        return root

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``.

        Returns ``True`` if two distinct sets were merged and ``False`` if
        ``x`` and ``y`` already shared a root.
        """
        px, py = self.find(x), self.find(y)
        if px != py:
            # The root of x's set is attached under the root of y's set.
            self.parent[px] = py
            return True
        return False

    def connected(self, x, y):
        """Return whether ``x`` and ``y`` belong to the same set."""
        return self.find(x) == self.find(y)

完整版本(按秩合并)

class UnionFind:
    """Disjoint-set forest using path compression and union by rank.

    Elements are the integers ``0 .. n-1``. ``count`` holds the number of
    disjoint sets and drops by one on every successful ``union``.
    """

    def __init__(self, n):
        self.parent = [i for i in range(n)]  # a root is its own parent
        self.rank = [0 for _ in range(n)]    # per-root rank, picks the merge direction
        self.count = n                       # number of disjoint sets

    def find(self, x):
        """Return the root of ``x``'s set and compress the walked path."""
        root = x
        while root != self.parent[root]:
            root = self.parent[root]

        # Re-point every node on the path from x straight at the root.
        node = x
        while node != root:
            above = self.parent[node]
            self.parent[node] = root
            node = above

        return root

    def union(self, x, y):
        """Merge the sets of ``x`` and ``y``.

        Returns ``False`` when both already share a root, ``True`` otherwise.
        """
        winner, loser = self.find(x), self.find(y)
        if winner == loser:
            return False

        # The higher-ranked root stays on top; on a tie x's root wins and
        # its rank grows by one.
        if self.rank[winner] < self.rank[loser]:
            winner, loser = loser, winner
        elif self.rank[winner] == self.rank[loser]:
            self.rank[winner] += 1
        self.parent[loser] = winner

        self.count -= 1
        return True

    def get_count(self):
        """Return the current number of disjoint sets."""
        return self.count

📚 经典问题

1. 岛屿数量

def num_islands(grid):
    """Count groups of '1' cells connected horizontally or vertically.

    Each cell ``(r, c)`` is mapped to the flat id ``r * cols + c``. Every
    land cell is unioned with its right and lower land neighbours, and the
    answer is the number of land cells that are still their own root.
    Returns 0 for an empty grid.
    """
    if not grid:
        return 0

    rows, cols = len(grid), len(grid[0])
    uf = UnionFind(rows * cols)

    def cell_id(r, c):
        return r * cols + c

    land = [
        (r, c)
        for r in range(rows)
        for c in range(cols)
        if grid[r][c] == '1'
    ]

    for r, c in land:
        # Looking only right and down visits each adjacent pair exactly once.
        if c + 1 < cols and grid[r][c + 1] == '1':
            uf.union(cell_id(r, c), cell_id(r, c + 1))
        if r + 1 < rows and grid[r + 1][c] == '1':
            uf.union(cell_id(r, c), cell_id(r + 1, c))

    # One root per island: count land cells that are their own representative.
    islands = 0
    for r, c in land:
        if uf.find(cell_id(r, c)) == cell_id(r, c):
            islands += 1
    return islands

2. 朋友圈数量

def find_circle_num(is_connected):
    """Return the number of connected groups in an adjacency matrix.

    ``is_connected[i][j] == 1`` links ``i`` and ``j``. Only the entries
    above the diagonal are read.
    """
    size = len(is_connected)
    uf = UnionFind(size)

    for a, row in enumerate(is_connected):
        for b in range(a + 1, size):
            if row[b] == 1:
                uf.union(a, b)

    return uf.get_count()

3. 冗余连接

def find_redundant_connection(edges):
    """Return the first edge, in input order, that closes a cycle.

    Edges are processed one by one; an edge whose endpoints are already in
    the same set is returned as ``[u, v]``. Returns ``[]`` when no edge
    closes a cycle. The union-find is sized ``len(edges) + 1``.
    """
    uf = UnionFind(len(edges) + 1)
    # Lazy generator: unions are performed only up to the first failing edge.
    cycle_closers = ([a, b] for a, b in edges if not uf.union(a, b))
    return next(cycle_closers, [])

📚 LeetCode练习

💻 完整代码实现

Python 实现

  1"""
  2并查集实现
  3"""
  4
  5class UnionFind:
  6    """并查集(路径压缩 + 按秩合并)"""
  7    
  8    def __init__(self, n):
  9        self.parent = list(range(n))
 10        self.rank = [0] * n
 11        self.count = n  # 连通分量数
 12    
 13    def find(self, x):
 14        """查找根节点(路径压缩)"""
 15        if self.parent[x] != x:
 16            self.parent[x] = self.find(self.parent[x])
 17        return self.parent[x]
 18    
 19    def union(self, x, y):
 20        """合并两个集合(按秩合并)"""
 21        px, py = self.find(x), self.find(y)
 22        
 23        if px == py:
 24            return False
 25        
 26        if self.rank[px] < self.rank[py]:
 27            self.parent[px] = py
 28        elif self.rank[px] > self.rank[py]:
 29            self.parent[py] = px
 30        else:
 31            self.parent[py] = px
 32            self.rank[px] += 1
 33        
 34        self.count -= 1
 35        return True
 36    
 37    def connected(self, x, y):
 38        """判断是否连通"""
 39        return self.find(x) == self.find(y)
 40    
 41    def get_count(self):
 42        """获取连通分量数"""
 43        return self.count
 44
 45
 46# ========== 应用示例 ==========
 47
 48def num_islands(grid):
 49    """岛屿数量"""
 50    if not grid:
 51        return 0
 52    
 53    m, n = len(grid), len(grid[0])
 54    uf = UnionFind(m * n)
 55    zeros = 0
 56    
 57    for i in range(m):
 58        for j in range(n):
 59            if grid[i][j] == '0':
 60                zeros += 1
 61            else:
 62                for di, dj in [(0, 1), (1, 0)]:
 63                    ni, nj = i + di, j + dj
 64                    if 0 <= ni < m and 0 <= nj < n and grid[ni][nj] == '1':
 65                        uf.union(i * n + j, ni * n + nj)
 66    
 67    return uf.get_count() - zeros
 68
 69
 70def find_circle_num(is_connected):
 71    """朋友圈数量"""
 72    n = len(is_connected)
 73    uf = UnionFind(n)
 74    
 75    for i in range(n):
 76        for j in range(i + 1, n):
 77            if is_connected[i][j] == 1:
 78                uf.union(i, j)
 79    
 80    return uf.get_count()
 81
 82
 83def find_redundant_connection(edges):
 84    """找冗余连接"""
 85    n = len(edges)
 86    uf = UnionFind(n + 1)
 87    
 88    for u, v in edges:
 89        if not uf.union(u, v):
 90            return [u, v]
 91    
 92    return []
 93
 94
 95def demo():
 96    """演示并查集"""
 97    print("=== 并查集演示 ===\n")
 98    
 99    uf = UnionFind(10)
100    
101    print(f"初始连通分量数: {uf.get_count()}")
102    
103    # 合并
104    uf.union(1, 2)
105    uf.union(2, 3)
106    uf.union(4, 5)
107    
108    print(f"合并后连通分量数: {uf.get_count()}")
109    print(f"1和3连通: {uf.connected(1, 3)}")
110    print(f"1和4连通: {uf.connected(1, 4)}")
111    
112    uf.union(3, 4)
113    print(f"\n合并3和4后")
114    print(f"连通分量数: {uf.get_count()}")
115    print(f"1和5连通: {uf.connected(1, 5)}")
116    
117    # 朋友圈问题
118    print("\n=== 朋友圈问题 ===\n")
119    is_connected = [
120        [1, 1, 0],
121        [1, 1, 0],
122        [0, 0, 1]
123    ]
124    print(f"朋友圈矩阵: {is_connected}")
125    print(f"朋友圈数量: {find_circle_num(is_connected)}")
126
127
128if __name__ == '__main__':
129    demo()
130

C++ 实现

  1/**
  2 * 并查集实现
  3 */
  4
  5#include <iostream>
  6#include <vector>
  7#include <numeric>
  8
  9using namespace std;
 10
 11class UnionFind {
 12private:
 13    vector<int> parent;
 14    vector<int> rank;
 15    int count;  // 连通分量数
 16
 17public:
 18    UnionFind(int n) : parent(n), rank(n, 0), count(n) {
 19        iota(parent.begin(), parent.end(), 0);  // 初始化为0,1,2,...
 20    }
 21    
 22    // 查找根节点(路径压缩)
 23    int find(int x) {
 24        if (parent[x] != x) {
 25            parent[x] = find(parent[x]);
 26        }
 27        return parent[x];
 28    }
 29    
 30    // 合并两个集合(按秩合并)
 31    bool unite(int x, int y) {
 32        int px = find(x);
 33        int py = find(y);
 34        
 35        if (px == py) {
 36            return false;
 37        }
 38        
 39        if (rank[px] < rank[py]) {
 40            parent[px] = py;
 41        } else if (rank[px] > rank[py]) {
 42            parent[py] = px;
 43        } else {
 44            parent[py] = px;
 45            rank[px]++;
 46        }
 47        
 48        count--;
 49        return true;
 50    }
 51    
 52    // 判断是否连通
 53    bool connected(int x, int y) {
 54        return find(x) == find(y);
 55    }
 56    
 57    // 获取连通分量数
 58    int getCount() const {
 59        return count;
 60    }
 61};
 62
 63// 岛屿数量
 64int numIslands(vector<vector<char>>& grid) {
 65    if (grid.empty()) return 0;
 66    
 67    int m = grid.size();
 68    int n = grid[0].size();
 69    UnionFind uf(m * n);
 70    int zeros = 0;
 71    
 72    for (int i = 0; i < m; i++) {
 73        for (int j = 0; j < n; j++) {
 74            if (grid[i][j] == '0') {
 75                zeros++;
 76            } else {
 77                // 向右和向下连接
 78                if (j + 1 < n && grid[i][j+1] == '1') {
 79                    uf.unite(i * n + j, i * n + j + 1);
 80                }
 81                if (i + 1 < m && grid[i+1][j] == '1') {
 82                    uf.unite(i * n + j, (i + 1) * n + j);
 83                }
 84            }
 85        }
 86    }
 87    
 88    return uf.getCount() - zeros;
 89}
 90
 91// 朋友圈数量
 92int findCircleNum(const vector<vector<int>>& isConnected) {
 93    int n = isConnected.size();
 94    UnionFind uf(n);
 95    
 96    for (int i = 0; i < n; i++) {
 97        for (int j = i + 1; j < n; j++) {
 98            if (isConnected[i][j] == 1) {
 99                uf.unite(i, j);
100            }
101        }
102    }
103    
104    return uf.getCount();
105}
106
107// 冗余连接
108vector<int> findRedundantConnection(const vector<vector<int>>& edges) {
109    int n = edges.size();
110    UnionFind uf(n + 1);
111    
112    for (const auto& edge : edges) {
113        if (!uf.unite(edge[0], edge[1])) {
114            return edge;
115        }
116    }
117    
118    return {};
119}
120
121int main() {
122    cout << "=== 并查集演示 ===" << endl << endl;
123    
124    UnionFind uf(10);
125    
126    cout << "初始连通分量数: " << uf.getCount() << endl;
127    
128    uf.unite(1, 2);
129    uf.unite(2, 3);
130    uf.unite(4, 5);
131    
132    cout << "合并后连通分量数: " << uf.getCount() << endl;
133    cout << "1和3连通: " << (uf.connected(1, 3) ? "是" : "否") << endl;
134    cout << "1和4连通: " << (uf.connected(1, 4) ? "是" : "否") << endl;
135    
136    uf.unite(3, 4);
137    cout << "\n合并3和4后" << endl;
138    cout << "连通分量数: " << uf.getCount() << endl;
139    cout << "1和5连通: " << (uf.connected(1, 5) ? "是" : "否") << endl << endl;
140    
141    // 朋友圈
142    cout << "朋友圈问题:" << endl;
143    vector<vector<int>> isConnected = {
144        {1, 1, 0},
145        {1, 1, 0},
146        {0, 0, 1}
147    };
148    cout << "  朋友圈数量: " << findCircleNum(isConnected) << endl;
149    
150    return 0;
151}