10-并查集(Union-Find)
💡 核心结论
并查集本质
定义:维护元素分组,支持快速合并和查询
核心操作:union(合并)、find(查找)
时间复杂度:接近O(1)(α(n),阿克曼函数的反函数)
空间复杂度:O(n)
应用:连通性问题、最小生成树、动态连通性
两大优化
路径压缩:查找时将路径上所有节点直接连到根
按秩合并:小树合并到大树,保持树的平衡
优化效果
优化 |
时间复杂度 |
|---|---|
无优化 |
O(n) |
只路径压缩 |
O(log n) |
只按秩合并 |
O(log n) |
两者结合 |
O(α(n)) ≈ O(1) |
应用场景(重要)
连通性问题:判断两点是否连通
最小生成树:Kruskal算法
动态连通性:动态添加边
朋友圈问题:社交网络分组
岛屿数量:DFS的替代方案
🎯 基本实现
简单版本
class UnionFind:
def __init__(self, n):
self.parent = list(range(n))
def find(self, x):
"""查找根节点"""
if self.parent[x] != x:
self.parent[x] = self.find(self.parent[x]) # 路径压缩
return self.parent[x]
def union(self, x, y):
"""合并两个集合"""
px, py = self.find(x), self.find(y)
if px != py:
self.parent[px] = py
return True
return False
def connected(self, x, y):
"""判断是否连通"""
return self.find(x) == self.find(y)
完整版本(按秩合并)
class UnionFind:
def __init__(self, n):
self.parent = list(range(n))
self.rank = [0] * n
self.count = n # 连通分量数
def find(self, x):
if self.parent[x] != x:
self.parent[x] = self.find(self.parent[x])
return self.parent[x]
def union(self, x, y):
px, py = self.find(x), self.find(y)
if px == py:
return False
# 按秩合并
if self.rank[px] < self.rank[py]:
self.parent[px] = py
elif self.rank[px] > self.rank[py]:
self.parent[py] = px
else:
self.parent[py] = px
self.rank[px] += 1
self.count -= 1
return True
def get_count(self):
"""获取连通分量数"""
return self.count
📚 经典问题
1. 岛屿数量
def num_islands(grid):
if not grid:
return 0
m, n = len(grid), len(grid[0])
uf = UnionFind(m * n)
for i in range(m):
for j in range(n):
if grid[i][j] == '1':
for di, dj in [(0,1), (1,0)]:
ni, nj = i + di, j + dj
if 0 <= ni < m and 0 <= nj < n and grid[ni][nj] == '1':
uf.union(i * n + j, ni * n + nj)
# 统计连通分量
return sum(1 for i in range(m) for j in range(n)
if grid[i][j] == '1' and uf.find(i * n + j) == i * n + j)
2. 朋友圈数量
def find_circle_num(is_connected):
n = len(is_connected)
uf = UnionFind(n)
for i in range(n):
for j in range(i + 1, n):
if is_connected[i][j] == 1:
uf.union(i, j)
return uf.get_count()
3. 冗余连接
def find_redundant_connection(edges):
"""找到使图成环的最后一条边"""
n = len(edges)
uf = UnionFind(n + 1)
for u, v in edges:
if not uf.union(u, v):
return [u, v]
return []
📚 LeetCode练习
💻 完整代码实现
Python 实现
1"""
2并查集实现
3"""
4
5class UnionFind:
6 """并查集(路径压缩 + 按秩合并)"""
7
8 def __init__(self, n):
9 self.parent = list(range(n))
10 self.rank = [0] * n
11 self.count = n # 连通分量数
12
13 def find(self, x):
14 """查找根节点(路径压缩)"""
15 if self.parent[x] != x:
16 self.parent[x] = self.find(self.parent[x])
17 return self.parent[x]
18
19 def union(self, x, y):
20 """合并两个集合(按秩合并)"""
21 px, py = self.find(x), self.find(y)
22
23 if px == py:
24 return False
25
26 if self.rank[px] < self.rank[py]:
27 self.parent[px] = py
28 elif self.rank[px] > self.rank[py]:
29 self.parent[py] = px
30 else:
31 self.parent[py] = px
32 self.rank[px] += 1
33
34 self.count -= 1
35 return True
36
37 def connected(self, x, y):
38 """判断是否连通"""
39 return self.find(x) == self.find(y)
40
41 def get_count(self):
42 """获取连通分量数"""
43 return self.count
44
45
46# ========== 应用示例 ==========
47
48def num_islands(grid):
49 """岛屿数量"""
50 if not grid:
51 return 0
52
53 m, n = len(grid), len(grid[0])
54 uf = UnionFind(m * n)
55 zeros = 0
56
57 for i in range(m):
58 for j in range(n):
59 if grid[i][j] == '0':
60 zeros += 1
61 else:
62 for di, dj in [(0, 1), (1, 0)]:
63 ni, nj = i + di, j + dj
64 if 0 <= ni < m and 0 <= nj < n and grid[ni][nj] == '1':
65 uf.union(i * n + j, ni * n + nj)
66
67 return uf.get_count() - zeros
68
69
70def find_circle_num(is_connected):
71 """朋友圈数量"""
72 n = len(is_connected)
73 uf = UnionFind(n)
74
75 for i in range(n):
76 for j in range(i + 1, n):
77 if is_connected[i][j] == 1:
78 uf.union(i, j)
79
80 return uf.get_count()
81
82
83def find_redundant_connection(edges):
84 """找冗余连接"""
85 n = len(edges)
86 uf = UnionFind(n + 1)
87
88 for u, v in edges:
89 if not uf.union(u, v):
90 return [u, v]
91
92 return []
93
94
95def demo():
96 """演示并查集"""
97 print("=== 并查集演示 ===\n")
98
99 uf = UnionFind(10)
100
101 print(f"初始连通分量数: {uf.get_count()}")
102
103 # 合并
104 uf.union(1, 2)
105 uf.union(2, 3)
106 uf.union(4, 5)
107
108 print(f"合并后连通分量数: {uf.get_count()}")
109 print(f"1和3连通: {uf.connected(1, 3)}")
110 print(f"1和4连通: {uf.connected(1, 4)}")
111
112 uf.union(3, 4)
113 print(f"\n合并3和4后")
114 print(f"连通分量数: {uf.get_count()}")
115 print(f"1和5连通: {uf.connected(1, 5)}")
116
117 # 朋友圈问题
118 print("\n=== 朋友圈问题 ===\n")
119 is_connected = [
120 [1, 1, 0],
121 [1, 1, 0],
122 [0, 0, 1]
123 ]
124 print(f"朋友圈矩阵: {is_connected}")
125 print(f"朋友圈数量: {find_circle_num(is_connected)}")
126
127
128if __name__ == '__main__':
129 demo()
130
C++ 实现
1/**
2 * 并查集实现
3 */
4
5#include <iostream>
6#include <vector>
7#include <numeric>
8
9using namespace std;
10
11class UnionFind {
12private:
13 vector<int> parent;
14 vector<int> rank;
15 int count; // 连通分量数
16
17public:
18 UnionFind(int n) : parent(n), rank(n, 0), count(n) {
19 iota(parent.begin(), parent.end(), 0); // 初始化为0,1,2,...
20 }
21
22 // 查找根节点(路径压缩)
23 int find(int x) {
24 if (parent[x] != x) {
25 parent[x] = find(parent[x]);
26 }
27 return parent[x];
28 }
29
30 // 合并两个集合(按秩合并)
31 bool unite(int x, int y) {
32 int px = find(x);
33 int py = find(y);
34
35 if (px == py) {
36 return false;
37 }
38
39 if (rank[px] < rank[py]) {
40 parent[px] = py;
41 } else if (rank[px] > rank[py]) {
42 parent[py] = px;
43 } else {
44 parent[py] = px;
45 rank[px]++;
46 }
47
48 count--;
49 return true;
50 }
51
52 // 判断是否连通
53 bool connected(int x, int y) {
54 return find(x) == find(y);
55 }
56
57 // 获取连通分量数
58 int getCount() const {
59 return count;
60 }
61};
62
63// 岛屿数量
64int numIslands(vector<vector<char>>& grid) {
65 if (grid.empty()) return 0;
66
67 int m = grid.size();
68 int n = grid[0].size();
69 UnionFind uf(m * n);
70 int zeros = 0;
71
72 for (int i = 0; i < m; i++) {
73 for (int j = 0; j < n; j++) {
74 if (grid[i][j] == '0') {
75 zeros++;
76 } else {
77 // 向右和向下连接
78 if (j + 1 < n && grid[i][j+1] == '1') {
79 uf.unite(i * n + j, i * n + j + 1);
80 }
81 if (i + 1 < m && grid[i+1][j] == '1') {
82 uf.unite(i * n + j, (i + 1) * n + j);
83 }
84 }
85 }
86 }
87
88 return uf.getCount() - zeros;
89}
90
91// 朋友圈数量
92int findCircleNum(const vector<vector<int>>& isConnected) {
93 int n = isConnected.size();
94 UnionFind uf(n);
95
96 for (int i = 0; i < n; i++) {
97 for (int j = i + 1; j < n; j++) {
98 if (isConnected[i][j] == 1) {
99 uf.unite(i, j);
100 }
101 }
102 }
103
104 return uf.getCount();
105}
106
107// 冗余连接
108vector<int> findRedundantConnection(const vector<vector<int>>& edges) {
109 int n = edges.size();
110 UnionFind uf(n + 1);
111
112 for (const auto& edge : edges) {
113 if (!uf.unite(edge[0], edge[1])) {
114 return edge;
115 }
116 }
117
118 return {};
119}
120
121int main() {
122 cout << "=== 并查集演示 ===" << endl << endl;
123
124 UnionFind uf(10);
125
126 cout << "初始连通分量数: " << uf.getCount() << endl;
127
128 uf.unite(1, 2);
129 uf.unite(2, 3);
130 uf.unite(4, 5);
131
132 cout << "合并后连通分量数: " << uf.getCount() << endl;
133 cout << "1和3连通: " << (uf.connected(1, 3) ? "是" : "否") << endl;
134 cout << "1和4连通: " << (uf.connected(1, 4) ? "是" : "否") << endl;
135
136 uf.unite(3, 4);
137 cout << "\n合并3和4后" << endl;
138 cout << "连通分量数: " << uf.getCount() << endl;
139 cout << "1和5连通: " << (uf.connected(1, 5) ? "是" : "否") << endl << endl;
140
141 // 朋友圈
142 cout << "朋友圈问题:" << endl;
143 vector<vector<int>> isConnected = {
144 {1, 1, 0},
145 {1, 1, 0},
146 {0, 0, 1}
147 };
148 cout << " 朋友圈数量: " << findCircleNum(isConnected) << endl;
149
150 return 0;
151}