Union-Find: The Data Structure That Makes Connectivity Problems Trivial
How the Union-Find (Disjoint Set Union) data structure works, why path compression and union by rank matter, and where you'll actually use it in interviews and real systems.
Union-Find (also called Disjoint Set Union, or DSU) is one of those data structures that solves a very specific class of problems, but solves them so well that nothing else comes close. The question it answers: given a collection of elements, which ones are connected?
Not "what's the shortest path between them" — just "are they in the same group?" That simpler question turns out to be incredibly useful, and Union-Find answers it in nearly O(1) time per query.
The Problem It Solves
Imagine you have n nodes and you're adding edges one at a time. After each edge, you want to answer: "Are node A and node B in the same connected component?"
You could run BFS/DFS after every new edge. That's O(n + m) per query. With Union-Find, it's effectively O(1) per query after amortization.
The two operations are:
- Find(x) — which group does element x belong to? (Returns the group's "representative" or root.)
- Union(x, y) — merge the groups containing x and y into one group.
Naive Implementation
The simplest approach: each element has a parent. The root of each tree is the representative of that group.
```python
class UnionFind:
    def __init__(self, n):
        self.parent = list(range(n))  # each element is its own parent

    def find(self, x):
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x != root_y:
            self.parent[root_x] = root_y
```
This works, but find can be O(n) in the worst case. If you keep unioning elements in a chain (0->1->2->3->...), the tree degenerates into a linked list, and finding the root means walking the entire chain.
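To see that degeneration concretely, here's a small sketch (the `NaiveUnionFind` name and the hop counter are mine, added purely as instrumentation) that unions elements in a chain and counts how many parent links `find(0)` has to walk:

```python
class NaiveUnionFind:
    def __init__(self, n):
        self.parent = list(range(n))

    def find(self, x):
        hops = 0  # instrumentation: count parent links walked
        while self.parent[x] != x:
            x = self.parent[x]
            hops += 1
        return x, hops

    def union(self, x, y):
        root_x, _ = self.find(x)
        root_y, _ = self.find(y)
        if root_x != root_y:
            self.parent[root_x] = root_y

n = 1000
uf = NaiveUnionFind(n)
for i in range(n - 1):
    uf.union(i, i + 1)  # builds the chain 0 -> 1 -> 2 -> ... -> 999

root, hops = uf.find(0)
print(root, hops)  # find(0) walks all 999 links to reach the root
```

With 1000 elements unioned in a chain, a single `find(0)` walks 999 parent pointers: exactly the linked-list worst case described above.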
We need two optimizations to fix this.
Path Compression
The insight: when you call find(x), you walk up to the root. On the way back, why not point every node you visited directly at the root? Next time anyone calls find on those nodes, they get there in one hop.
```python
def find(self, x):
    if self.parent[x] != x:
        self.parent[x] = self.find(self.parent[x])  # path compression
    return self.parent[x]
```
One line of change, massive impact. After the first find, the entire path from x to the root gets flattened. Subsequent finds on any of those nodes are O(1).
There's also a non-recursive variant, usually called "path halving" (a close cousin of "path splitting"), if you're worried about recursion depth:
```python
def find(self, x):
    while self.parent[x] != x:
        self.parent[x] = self.parent[self.parent[x]]  # point x at its grandparent
        x = self.parent[x]
    return x
```
This doesn't fully flatten the path, but it roughly halves the length of every path it traverses. In practice, it's fast enough and avoids recursion.
Union by Rank (or Size)
The second optimization: when merging two trees, attach the smaller one under the larger one. This keeps trees shallow.
```python
class UnionFind:
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return False  # already connected
        # attach smaller tree under larger tree
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        elif self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        else:
            self.parent[root_y] = root_x
            self.rank[root_x] += 1
        return True  # merged two different components
```
You can use rank (an upper bound on tree height) or size (number of elements); both work. Rank is slightly more common in textbooks, while size is often more useful in practice because it lets you answer "how big is this component?" directly.
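As a sketch of the size-based flavor (the class and method names here are mine), the only changes are a `size` array in place of `rank` and a component-size query:

```python
class UnionFindBySize:
    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n  # size[r] is meaningful only while r is a root

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # path compression
        return self.parent[x]

    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False
        if self.size[rx] < self.size[ry]:
            rx, ry = ry, rx  # make rx the larger root
        self.parent[ry] = rx  # attach smaller tree under larger
        self.size[rx] += self.size[ry]
        return True

    def component_size(self, x):
        return self.size[self.find(x)]

uf = UnionFindBySize(5)
uf.union(0, 1)
uf.union(1, 2)
print(uf.component_size(2))  # -> 3
print(uf.component_size(3))  # -> 1
```

Note that `size` entries for non-roots go stale after a merge; that's fine, because `component_size` always asks at the root.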
The Complexity
With both path compression and union by rank, the amortized time per operation is O(alpha(n)), where alpha is the inverse Ackermann function. For all practical purposes, alpha(n) <= 4 for any n that fits in the universe. It's effectively O(1).
This is one of those rare cases where the theoretical analysis says "basically constant time" and the practical performance backs it up completely.
Classic Problems
Detecting Cycles in an Undirected Graph
Process edges one by one. Before adding edge (u, v), check if u and v are already in the same component. If yes, adding this edge creates a cycle.
```python
def has_cycle(n, edges):
    uf = UnionFind(n)
    for u, v in edges:
        if uf.find(u) == uf.find(v):
            return True  # u and v already connected -> cycle
        uf.union(u, v)
    return False
```
This is cleaner and often faster than DFS-based cycle detection when you're processing edges one at a time.
Number of Connected Components
Just count how many distinct roots exist after processing all edges.
```python
def count_components(n, edges):
    uf = UnionFind(n)
    for u, v in edges:
        uf.union(u, v)
    # count unique roots
    return len(set(uf.find(i) for i in range(n)))
```
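For quick experimentation, here's a self-contained version (I've inlined a minimal DSU with path halving, and the five-node example graph is mine):

```python
def count_components(n, edges):
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for u, v in edges:
        parent[find(u)] = find(v)  # union the two roots
    return len({find(i) for i in range(n)})

# components {0, 1, 2} and {3, 4}
print(count_components(5, [(0, 1), (1, 2), (3, 4)]))  # -> 2
```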
Kruskal's Minimum Spanning Tree
Sort edges by weight. Process them in order. For each edge, if it connects two different components, include it in the MST. If it would create a cycle (same component), skip it.
```python
def kruskal(n, edges):
    edges.sort(key=lambda e: e[2])  # sort by weight
    uf = UnionFind(n)
    mst = []
    total_weight = 0
    for u, v, w in edges:
        if uf.union(u, v):  # returns True if they were in different components
            mst.append((u, v, w))
            total_weight += w
            if len(mst) == n - 1:
                break
    return total_weight, mst
```
Kruskal's is essentially a greedy algorithm that uses Union-Find to efficiently check "would this edge create a cycle?"
Accounts Merge
Given a list of accounts where each account is [name, email1, email2, ...], merge accounts that share any email. This is a connectivity problem in disguise — emails are nodes, and each account says "these emails are all connected."
```python
from collections import defaultdict

def accounts_merge(accounts):
    uf = UnionFind(10001)  # capacity; assumes at most 10,001 distinct emails
    email_to_id = {}
    email_to_name = {}
    idx = 0
    for account in accounts:
        name = account[0]
        for email in account[1:]:
            if email not in email_to_id:
                email_to_id[email] = idx
                idx += 1
                email_to_name[email] = name
            # connect every email in this account to the account's first email
            uf.union(email_to_id[account[1]], email_to_id[email])
    # group emails by root
    groups = defaultdict(list)
    for email, eid in email_to_id.items():
        groups[uf.find(eid)].append(email)
    return [[email_to_name[emails[0]]] + sorted(emails)
            for emails in groups.values()]
```
Union-Find vs. BFS/DFS
When should you use Union-Find instead of graph traversal?
Union-Find wins when:
- You're processing edges incrementally (adding connections over time)
- You only care about connectivity, not paths
- You need to answer many "are these connected?" queries efficiently
- The graph is changing (edges being added) between queries
BFS/DFS wins when:
- You need the actual path between nodes
- You need shortest path information
- You're doing a one-time traversal of a static graph
- The graph has directed edges (Union-Find is for undirected connectivity)
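A small sketch of the incremental pattern Union-Find is built for (the six-node example is mine): connectivity queries interleaved with new edges, with no re-traversal between them.

```python
class UnionFind:
    def __init__(self, n):
        self.parent = list(range(n))

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # path compression
        return self.parent[x]

    def union(self, x, y):
        self.parent[self.find(x)] = self.find(y)

    def connected(self, x, y):
        return self.find(x) == self.find(y)

uf = UnionFind(6)
uf.union(0, 1)
uf.union(2, 3)
print(uf.connected(0, 1))  # True
print(uf.connected(1, 2))  # False -- not linked yet
uf.union(1, 2)             # one new edge merges the two groups
print(uf.connected(0, 3))  # True
```

With BFS/DFS you'd have to re-traverse (or maintain component labels by hand) after every new edge; here each query is just two `find` calls.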
Common Mistakes
Forgetting to use find() before comparing roots. Don't compare parent[x] directly — that might not be the root after path compression changes things. Always compare find(x) with find(y).
Not returning a useful value from union(). Making union() return True/False (whether a merge actually happened) is crucial for problems like cycle detection and Kruskal's. Don't just make it void.
Using Union-Find for directed graphs. Union-Find tracks undirected connectivity. If you need to know "can I reach B from A in a directed graph," Union-Find isn't the right tool. Use DFS or topological sort.
Forgetting to initialize parent array. Each element should start as its own parent: parent[i] = i. If you initialize everything to 0, every element starts in the same group.
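To make the first mistake concrete, here's a scenario (the node numbering and union order are mine) where raw parent pointers disagree even though the nodes share a component:

```python
class UnionFind:
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # path compression
        return self.parent[x]

    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False
        if self.rank[rx] < self.rank[ry]:
            rx, ry = ry, rx  # make rx the higher-rank root
        self.parent[ry] = rx
        if self.rank[rx] == self.rank[ry]:
            self.rank[rx] += 1
        return True

uf = UnionFind(4)
uf.union(2, 3)  # parent[3] -> 2
uf.union(0, 1)  # parent[1] -> 0
uf.union(0, 2)  # merges the trees; parent[3] still points at 2, not the root
print(uf.parent[3] == uf.parent[1])  # False -- raw parents mislead
print(uf.find(3) == uf.find(1))      # True  -- find() gives the real answer
```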
The Weighted Union-Find Variant
Sometimes you need to track a relationship between elements, not just whether they're connected. For example: "A is 3 units heavier than B." You can extend Union-Find to track these relative weights along edges.
```python
class WeightedUnionFind:
    def __init__(self, n):
        self.parent = list(range(n))
        self.weight = [0] * n  # weight[x] = relation from x to parent[x]

    def find(self, x):
        if self.parent[x] != x:
            root = self.find(self.parent[x])
            self.weight[x] += self.weight[self.parent[x]]
            self.parent[x] = root
        return self.parent[x]

    def union(self, x, y, w):
        # w means: x + w = y (or whatever relation you define)
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            # already connected; consistency check:
            # self.weight[x] - self.weight[y] should equal w
            return False
        self.parent[root_x] = root_y
        self.weight[root_x] = w + self.weight[y] - self.weight[x]
        return True
```

(Union by rank is omitted here for clarity; path compression alone keeps it fast, and the weight update in find already accounts for the flattening.)
This is how you solve problems like "Evaluate Division" on LeetCode, where you're given equations like a/b = 2.0 and need to answer queries about a/c.
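As a sketch of that application (the `RatioUnionFind` name and the dict-keyed, multiplicative formulation are mine — the additive weights above become ratios):

```python
class RatioUnionFind:
    """weight[x] = value(x) / value(parent[x]), maintained multiplicatively."""

    def __init__(self):
        self.parent = {}
        self.weight = {}

    def add(self, x):
        if x not in self.parent:
            self.parent[x] = x
            self.weight[x] = 1.0

    def find(self, x):
        if self.parent[x] != x:
            root = self.find(self.parent[x])
            self.weight[x] *= self.weight[self.parent[x]]  # fold in parent's ratio
            self.parent[x] = root
        return self.parent[x]

    def union(self, x, y, ratio):
        # ratio means: value(x) / value(y) == ratio
        self.add(x)
        self.add(y)
        rx, ry = self.find(x), self.find(y)
        if rx != ry:
            self.parent[rx] = ry
            self.weight[rx] = ratio * self.weight[y] / self.weight[x]

    def query(self, x, y):
        if x not in self.parent or y not in self.parent:
            return None
        if self.find(x) != self.find(y):
            return None  # ratio is undetermined
        return self.weight[x] / self.weight[y]

uf = RatioUnionFind()
uf.union("a", "b", 2.0)  # a / b = 2.0
uf.union("b", "c", 3.0)  # b / c = 3.0
print(uf.query("a", "c"))  # a / c = 6.0
print(uf.query("a", "x"))  # None -- x is unknown
```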
Practical Takeaway
Union-Find is a specialized tool. You won't use it on every problem. But when a problem involves grouping elements, tracking connected components, or detecting cycles in undirected graphs — especially when connections are added incrementally — Union-Find is almost certainly the right approach. Learn the template with path compression and union by rank, and you'll handle these problems confidently.
Practice building the data structure from scratch a few times on CodeUp. Once the template is in muscle memory, the hard part of Union-Find problems becomes recognizing that the problem is a Union-Find problem — the implementation is mechanical after that.