March 26, 2026 · 8 min read

Union-Find: The Data Structure That Makes Connectivity Problems Trivial

How the Union-Find (Disjoint Set Union) data structure works, why path compression and union by rank matter, and where you'll actually use it in interviews and real systems.

union-find data-structures graphs algorithms interviews

Union-Find (also called Disjoint Set Union, or DSU) is one of those data structures that solves a very specific class of problems, but solves them so well that nothing else comes close. The question it answers: given a collection of elements, which ones are connected?

Not "what's the shortest path between them" — just "are they in the same group?" That simpler question turns out to be incredibly useful, and Union-Find answers it in nearly O(1) time per query.

The Problem It Solves

Imagine you have n nodes and you're adding edges one at a time. After each edge, you want to answer: "Are node A and node B in the same connected component?"

You could run BFS/DFS after every new edge. That's O(n + m) per query. With Union-Find, it's effectively O(1) per query after amortization.

The two operations are:

  • Find(x) — which group does element x belong to? (Returns the group's "representative" or root.)
  • Union(x, y) — merge the groups containing x and y into one group.

That's it. Two operations. The entire data structure exists to make those two operations fast.

Naive Implementation

The simplest approach: each element has a parent. The root of each tree is the representative of that group.

class UnionFind:
    def __init__(self, n):
        self.parent = list(range(n))  # each element is its own parent

    def find(self, x):
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x != root_y:
            self.parent[root_x] = root_y

This works, but find can be O(n) in the worst case. If you keep unioning elements in a chain (0->1->2->3->...), the tree degenerates into a linked list, and finding the root means walking the entire chain.
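To see that degradation concretely, here's a small sketch (helper names are mine) that builds the worst-case chain by hand and counts how many parent pointers the naive find has to follow:

```python
def naive_find(parent, x):
    """Walk parent pointers to the root, counting hops along the way."""
    hops = 0
    while parent[x] != x:
        x = parent[x]
        hops += 1
    return x, hops

# Build the degenerate chain 0 -> 1 -> 2 -> ... -> 999 by always
# attaching the current root under the newly unioned element,
# exactly what the naive union does for union(0,1), union(1,2), ...
n = 1000
parent = list(range(n))
for i in range(1, n):
    parent[i - 1] = i  # previous root becomes a child of i

root, hops = naive_find(parent, 0)  # walks the entire chain: 999 hops
```

Every find from the bottom of the chain costs O(n), which is exactly what the two optimizations below eliminate.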

We need two optimizations to fix this.

Path Compression

The insight: when you call find(x), you walk up to the root. On the way back, why not point every node you visited directly at the root? Next time anyone calls find on those nodes, they get there in one hop.

def find(self, x):
    if self.parent[x] != x:
        self.parent[x] = self.find(self.parent[x])  # path compression
    return self.parent[x]

One line of change, massive impact. After the first find, the entire path from x to the root gets flattened. Subsequent finds on any of those nodes are O(1).

There's also a non-recursive version called "path splitting" or "path halving" if you're worried about stack depth:

def find(self, x):
    while self.parent[x] != x:
        self.parent[x] = self.parent[self.parent[x]]  # skip one level
        x = self.parent[x]
    return x

This doesn't fully flatten the path, but it halves the tree height on every find. In practice, it's fast enough and avoids recursion.
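A quick sketch (the `depth` helper is mine, added just for illustration) showing path halving shortening a chain in a single pass:

```python
def find_halving(parent, x):
    # Path halving: every node on the search path is repointed
    # to its grandparent, roughly halving the path length.
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x

def depth(parent, x):
    """Count hops from x to its root (for demonstration only)."""
    d = 0
    while parent[x] != x:
        x = parent[x]
        d += 1
    return d

# A chain 0 -> 1 -> 2 -> ... -> 7, where 7 is the root.
parent = [1, 2, 3, 4, 5, 6, 7, 7]
before = depth(parent, 0)          # 7 hops to the root
root = find_halving(parent, 0)     # returns 7, rewiring as it goes
after = depth(parent, 0)           # path is now roughly half as long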

Union by Rank (or Size)

The second optimization: when merging two trees, attach the smaller one under the larger one. This keeps trees shallow.

class UnionFind:
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return False  # already connected

        # attach smaller tree under larger tree
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        elif self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        else:
            self.parent[root_y] = root_x
            self.rank[root_x] += 1

        return True  # merged two different components

You can use rank (an upper bound on tree height) or size (the number of elements). Both work. Rank is slightly more common in textbooks; size is sometimes more useful in practice because you can query "how big is this component?"
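Here's a sketch of the size-based variant (class and method names are mine, not from a standard library), which gets you the component-size query for free:

```python
class UnionFindBySize:
    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n  # size is only meaningful at roots

    def find(self, x):
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x

    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False
        if self.size[rx] < self.size[ry]:
            rx, ry = ry, rx  # ensure rx is the root of the larger tree
        self.parent[ry] = rx  # smaller tree goes under larger tree
        self.size[rx] += self.size[ry]
        return True

    def component_size(self, x):
        return self.size[self.find(x)]

uf = UnionFindBySize(5)
uf.union(0, 1)
uf.union(1, 2)
# component_size(2) is now 3; element 3 is still alone in its own group
```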

The Complexity

With both path compression and union by rank, the amortized time per operation is O(alpha(n)), where alpha is the inverse Ackermann function. For all practical purposes, alpha(n) <= 4 for any n that fits in the universe. It's effectively O(1).

This is one of those rare cases where the theoretical analysis says "basically constant time" and the practical performance backs it up completely.

Classic Problems

Detecting Cycles in an Undirected Graph

Process edges one by one. Before adding edge (u, v), check if u and v are already in the same component. If yes, adding this edge creates a cycle.

def has_cycle(n, edges):
    uf = UnionFind(n)
    for u, v in edges:
        if uf.find(u) == uf.find(v):
            return True  # u and v already connected -> cycle
        uf.union(u, v)
    return False

This is cleaner and often faster than DFS-based cycle detection when you're processing edges one at a time.
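A quick self-contained run (a minimal DSU is inlined here so the sketch stands on its own):

```python
class DSU:
    def __init__(self, n):
        self.parent = list(range(n))

    def find(self, x):
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x

    def union(self, x, y):
        self.parent[self.find(y)] = self.find(x)

def has_cycle(n, edges):
    uf = DSU(n)
    for u, v in edges:
        if uf.find(u) == uf.find(v):
            return True  # u and v already connected -> cycle
        uf.union(u, v)
    return False

tree = [(0, 1), (1, 2), (2, 3)]      # a simple path: no cycle
triangle = [(0, 1), (1, 2), (2, 0)]  # the last edge closes a cycle
```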

Number of Connected Components

Just count how many distinct roots exist after processing all edges.

def count_components(n, edges):
    uf = UnionFind(n)
    for u, v in edges:
        uf.union(u, v)
    # count unique roots
    return len(set(uf.find(i) for i in range(n)))

Kruskal's Minimum Spanning Tree

Sort edges by weight. Process them in order. For each edge, if it connects two different components, include it in the MST. If it would create a cycle (same component), skip it.

def kruskal(n, edges):
    edges.sort(key=lambda e: e[2])  # sort by weight
    uf = UnionFind(n)
    mst = []
    total_weight = 0

    for u, v, w in edges:
        if uf.union(u, v):  # returns True if they were in different components
            mst.append((u, v, w))
            total_weight += w
            if len(mst) == n - 1:
                break

    return total_weight, mst

Kruskal's is essentially a greedy algorithm that uses Union-Find to efficiently check "would this edge create a cycle?"
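Here's a self-contained run on a tiny graph (a minimal DSU is inlined so the sketch executes on its own):

```python
class DSU:
    def __init__(self, n):
        self.parent = list(range(n))

    def find(self, x):
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x

    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False
        self.parent[ry] = rx
        return True

def kruskal(n, edges):
    edges = sorted(edges, key=lambda e: e[2])  # cheapest edges first
    uf, mst, total = DSU(n), [], 0
    for u, v, w in edges:
        if uf.union(u, v):  # only keep edges that join two components
            mst.append((u, v, w))
            total += w
    return total, mst

# Triangle plus a pendant vertex: the weight-3 edge closes a cycle
# and gets skipped; the other three edges form the MST.
edges = [(0, 1, 1), (1, 2, 2), (0, 2, 3), (2, 3, 4)]
total, mst = kruskal(4, edges)
```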

Accounts Merge

Given a list of accounts where each account is [name, email1, email2, ...], merge accounts that share any email. This is a connectivity problem in disguise — emails are nodes, and each account says "these emails are all connected."

from collections import defaultdict

def accounts_merge(accounts):
    uf = UnionFind(10001)  # 10001 exceeds the max number of distinct emails
    email_to_id = {}
    email_to_name = {}
    idx = 0

    for account in accounts:
        name = account[0]
        for email in account[1:]:
            if email not in email_to_id:
                email_to_id[email] = idx
                idx += 1
            email_to_name[email] = name
            uf.union(email_to_id[account[1]], email_to_id[email])

    # group emails by root
    groups = defaultdict(list)
    for email, eid in email_to_id.items():
        groups[uf.find(eid)].append(email)

    return [[email_to_name[emails[0]]] + sorted(emails) for emails in groups.values()]

Union-Find vs. BFS/DFS

When should you use Union-Find instead of graph traversal?

Union-Find wins when:
  • You're processing edges incrementally (adding connections over time)
  • You only care about connectivity, not paths
  • You need to answer many "are these connected?" queries efficiently
  • The graph is changing (edges being added) between queries

BFS/DFS wins when:
  • You need the actual path between nodes
  • You need shortest path information
  • You're doing a one-time traversal of a static graph
  • The graph has directed edges (Union-Find is for undirected connectivity)

The key distinction: Union-Find is optimized for the dynamic connectivity problem. If the graph is static and you just need to traverse it once, BFS/DFS is simpler and sufficient.
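The dynamic regime is where Union-Find shines: here's a sketch of interleaved edge insertions and connectivity queries, each answered without any traversal (a minimal DSU is inlined so it runs standalone):

```python
class DSU:
    def __init__(self, n):
        self.parent = list(range(n))

    def find(self, x):
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x

    def connected(self, x, y):
        return self.find(x) == self.find(y)

    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx != ry:
            self.parent[ry] = rx

dsu = DSU(4)
answers = []
answers.append(dsu.connected(0, 3))  # no edges yet: False
dsu.union(0, 1)
dsu.union(2, 3)
answers.append(dsu.connected(0, 3))  # still two separate pairs: False
dsu.union(1, 2)
answers.append(dsu.connected(0, 3))  # everything merged: True
```

With BFS/DFS, each of those queries would mean re-traversing the graph; here each one is just two near-constant-time finds.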

Common Mistakes

  • Forgetting to use find() before comparing roots. Don't compare parent[x] directly — that might not be the root after path compression changes things. Always compare find(x) with find(y).
  • Not returning a useful value from union(). Making union() return True/False (whether a merge actually happened) is crucial for problems like cycle detection and Kruskal's. Don't just make it void.
  • Using Union-Find for directed graphs. Union-Find tracks undirected connectivity. If you need to know "can I reach B from A in a directed graph," Union-Find isn't the right tool. Use DFS or topological sort.
  • Forgetting to initialize the parent array. Each element should start as its own parent: parent[i] = i. If you initialize everything to 0, every element starts in the same group.

The Weighted Union-Find Variant

Sometimes you need to track a relationship between elements, not just whether they're connected. For example: "A is 3 units heavier than B." You can extend Union-Find to track these relative weights along edges.

class WeightedUnionFind:
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.weight = [0] * n  # weight[x] = relation from x to parent[x]

    def find(self, x):
        if self.parent[x] != x:
            root = self.find(self.parent[x])
            self.weight[x] += self.weight[self.parent[x]]
            self.parent[x] = root
        return self.parent[x]

    def union(self, x, y, w):
        # w means: x + w = y (or whatever relation you define)
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            # already connected; for consistency,
            # self.weight[x] - self.weight[y] should equal w
            return
        self.parent[root_x] = root_y
        self.weight[root_x] = w + self.weight[y] - self.weight[x]

This is how you solve problems like "Evaluate Division" on LeetCode, where you're given equations like a/b = 2.0 and need to answer queries about a/c.
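As a sketch of that idea (with multiplicative weights instead of additive ones, since division equations compose by multiplication; class and method names are mine):

```python
class RatioDSU:
    """Union-Find over variable names; weight[x] holds value(x) / value(parent[x])."""

    def __init__(self):
        self.parent = {}
        self.weight = {}

    def add(self, x):
        if x not in self.parent:
            self.parent[x] = x
            self.weight[x] = 1.0

    def find(self, x):
        if self.parent[x] != x:
            root = self.find(self.parent[x])
            # after the recursive call, weight[parent[x]] is relative to root,
            # so multiplying makes weight[x] relative to root too
            self.weight[x] *= self.weight[self.parent[x]]
            self.parent[x] = root
        return self.parent[x]

    def union(self, x, y, ratio):
        # ratio means: value(x) / value(y) == ratio
        self.add(x); self.add(y)
        rx, ry = self.find(x), self.find(y)
        if rx != ry:
            self.parent[rx] = ry
            self.weight[rx] = ratio * self.weight[y] / self.weight[x]

    def query(self, x, y):
        """Return value(x) / value(y), or -1.0 if it can't be determined."""
        if x not in self.parent or y not in self.parent:
            return -1.0
        if self.find(x) != self.find(y):
            return -1.0
        return self.weight[x] / self.weight[y]

dsu = RatioDSU()
dsu.union("a", "b", 2.0)  # a / b = 2.0
dsu.union("b", "c", 3.0)  # b / c = 3.0
# query("a", "c") now chains the two ratios: 2.0 * 3.0 = 6.0
```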

Practical Takeaway

Union-Find is a specialized tool. You won't use it on every problem. But when a problem involves grouping elements, tracking connected components, or detecting cycles in undirected graphs — especially when connections are added incrementally — Union-Find is almost certainly the right approach. Learn the template with path compression and union by rank, and you'll handle these problems confidently.

Practice building the data structure from scratch a few times on CodeUp. Once the template is in muscle memory, the hard part of Union-Find problems becomes recognizing that the problem is a Union-Find problem — the implementation is mechanical after that.
