Dynamic Programming on Trees — Patterns and Solutions
Tree DP combines recursion with memoization in ways that feel different from grid DP. Rerooting, subtree aggregation, and the patterns that make tree DP systematic.
Tree DP doesn't get as much attention as grid DP or string DP, but it shows up in interviews consistently — especially at companies that like graph problems. The difficulty comes from the tree structure itself: you can't just fill a table left-to-right. You need to think in terms of subtrees, and the "table" is the tree itself.
The good news: there are only a handful of patterns. Once you recognize which one applies, the implementation follows a consistent template.
The Core Idea
In tree DP, each node stores a value computed from its children's values. You process the tree bottom-up via postorder DFS: compute children first, then combine at the parent. The tree's structure determines the subproblem dependencies — a node's answer depends on its subtree.
```
    1
   / \
  2   3
 / \
4   5
```
Process: 4, 5, 2, 3, 1. By the time you handle node 2, you already have answers for 4 and 5.
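Here is a minimal sketch of that processing order. It assumes a bare-bones hypothetical Node class with val, left, and right fields, which is the same shape the binary tree examples below rely on:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Node:
    # hypothetical binary tree node used for illustration
    val: int
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def postorder(node: Optional[Node], out: List[int]) -> List[int]:
    """Append values in postorder: left subtree, right subtree, then the node."""
    if node is not None:
        postorder(node.left, out)
        postorder(node.right, out)
        out.append(node.val)  # both children are finished before the parent is visited
    return out


# the tree from the diagram above
root = Node(1, Node(2, Node(4), Node(5)), Node(3))
print(postorder(root, []))  # [4, 5, 2, 3, 1]
```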
Pattern 1: Subtree DP (Basic)
The simplest form. Each node computes a value based solely on its children's results.
Example: Maximum Path Sum in a Tree
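This is the classic binary tree version, where a path can go up and then back down. It may start and end at any nodes, but it climbs to a single highest node and then descends from there. Every node gets a turn as that highest "bend" point.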
```python
def max_path_sum(root):
    """Max sum over all paths; a path may bend once at its highest node."""
    result = [float('-inf')]

    def dfs(node):
        if not node:
            return 0
        # drop negative branches: a path is free to leave them out
        left = max(0, dfs(node.left))
        right = max(0, dfs(node.right))
        # path through this node
        result[0] = max(result[0], node.val + left + right)
        # return best single-direction path
        return node.val + max(left, right)

    dfs(root)
    return result[0]
```
```javascript
function maxPathSum(root) {
  let result = -Infinity;
  function dfs(node) {
    if (!node) return 0;
    const left = Math.max(0, dfs(node.left));
    const right = Math.max(0, dfs(node.right));
    result = Math.max(result, node.val + left + right);
    return node.val + Math.max(left, right);
  }
  dfs(root);
  return result;
}
```
The "compute one thing, track another" pattern again: dfs returns the best single-direction path (for the parent to use), but we track the best two-direction path through each node in result.
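You can write the same idea without the mutable result cell by returning both quantities as a pair. This sketch assumes the same hypothetical Node shape (val, left, right). The name max_path_sum_pure is illustrative and does not come from the original:

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class Node:
    # hypothetical binary tree node used for illustration
    val: int
    left: Optional["Node"] = None
    right: Optional["Node"] = None


NEG_INF = float("-inf")


def max_path_sum_pure(root: Optional[Node]) -> float:
    def dfs(node: Optional[Node]) -> Tuple[float, float]:
        """Return (best downward path starting here, best path anywhere in subtree)."""
        if node is None:
            return 0, NEG_INF
        left_down, left_best = dfs(node.left)
        right_down, right_best = dfs(node.right)
        left_down = max(0, left_down)    # skip a negative branch
        right_down = max(0, right_down)
        bend_here = node.val + left_down + right_down  # path that peaks at this node
        best = max(bend_here, left_best, right_best)
        return node.val + max(left_down, right_down), best

    return dfs(root)[1]


# -10 with children 9 and 20; 20 has children 15 and 7: best path is 15 -> 20 -> 7
tree = Node(-10, Node(9), Node(20, Node(15), Node(7)))
print(max_path_sum_pure(tree))  # 42
```

The first element of the pair is what the parent consumes. The second element is the "tracked" answer, now carried up the recursion instead of stored on the side.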
Example: House Robber III
Binary tree where each node has a value. You can't rob two directly connected nodes. Maximize total.
Each node has two states: robbed or not robbed.
```python
def rob(root):
    def dfs(node):
        if not node:
            return (0, 0)  # (rob_this, skip_this)
        left = dfs(node.left)
        right = dfs(node.right)
        # rob this node: can't rob children
        rob_this = node.val + left[1] + right[1]
        # skip this node: take best of each child
        skip_this = max(left) + max(right)
        return (rob_this, skip_this)

    return max(dfs(root))
```
```javascript
function rob(root) {
  function dfs(node) {
    if (!node) return [0, 0]; // [rob, skip]
    const left = dfs(node.left);
    const right = dfs(node.right);
    const robThis = node.val + left[1] + right[1];
    const skipThis = Math.max(...left) + Math.max(...right);
    return [robThis, skipThis];
  }
  return Math.max(...dfs(root));
}
```
The pattern: return a tuple of states from each recursive call. The parent combines them based on the problem's constraints.
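To see the tuples concretely, this sketch records each node's (rob_this, skip_this) pair on a small example tree. The Node class, the rob_with_trace name, and the states dictionary are assumptions for illustration, not part of the original solution:

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@dataclass(eq=False)  # identity-based hashing so nodes can be dict keys
class Node:
    # hypothetical binary tree node used for illustration
    val: int
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def rob_with_trace(root: Optional[Node]) -> Tuple[int, Dict[Node, Tuple[int, int]]]:
    states: Dict[Node, Tuple[int, int]] = {}

    def dfs(node: Optional[Node]) -> Tuple[int, int]:
        if node is None:
            return (0, 0)  # (rob_this, skip_this)
        left = dfs(node.left)
        right = dfs(node.right)
        rob_this = node.val + left[1] + right[1]  # children must be skipped
        skip_this = max(left) + max(right)        # children choose freely
        states[node] = (rob_this, skip_this)
        return states[node]

    return max(dfs(root)), states


#       3
#      / \
#     2   3
#      \   \
#       3   1
a = Node(2, right=Node(3))
b = Node(3, right=Node(1))
root = Node(3, a, b)
best, states = rob_with_trace(root)
print(states[a], states[b], states[root])  # (2, 3) (3, 1) (7, 6)
print(best)  # 7
```

The root's pair (7, 6) shows that robbing the root plus the two bottom leaves (3 + 3 + 1) beats skipping the root.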
Pattern 2: Rerooting DP
This is the hard one. You need to compute something for every node as if that node were the root. Naive approach: reroot and recompute, O(n^2). Rerooting DP does it in O(n) with two passes.
Example: Sum of Distances in Tree
Given an unrooted tree of n nodes, for each node compute the sum of distances to all other nodes.
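Pass 1 (bottom-up): Root the tree arbitrarily at node 0. Compute count[v] (the size of v's subtree) and dist[v] (the sum of distances from v to every node in its own subtree). Only dist[0] is a complete answer after this pass, because node 0's subtree is the whole tree.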
Pass 2 (top-down): For each edge (parent → child), rerooting from parent to child moves count[child] nodes closer by 1 and n - count[child] nodes farther by 1.
```python
def sum_of_distances(n, edges):
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    count = [1] * n
    dist = [0] * n

    # pass 1: compute count and dist[0]
    def dfs1(node, parent):
        for child in adj[node]:
            if child != parent:
                dfs1(child, node)
                count[node] += count[child]
                dist[node] += dist[child] + count[child]

    # pass 2: reroot
    def dfs2(node, parent):
        for child in adj[node]:
            if child != parent:
                dist[child] = dist[node] - count[child] + (n - count[child])
                dfs2(child, node)

    dfs1(0, -1)
    dfs2(0, -1)
    return dist
```
```javascript
function sumOfDistances(n, edges) {
  const adj = Array.from({ length: n }, () => []);
  for (const [u, v] of edges) {
    adj[u].push(v);
    adj[v].push(u);
  }
  const count = new Array(n).fill(1);
  const dist = new Array(n).fill(0);

  function dfs1(node, parent) {
    for (const child of adj[node]) {
      if (child !== parent) {
        dfs1(child, node);
        count[node] += count[child];
        dist[node] += dist[child] + count[child];
      }
    }
  }

  function dfs2(node, parent) {
    for (const child of adj[node]) {
      if (child !== parent) {
        dist[child] = dist[node] - count[child] + (n - count[child]);
        dfs2(child, node);
      }
    }
  }

  dfs1(0, -1);
  dfs2(0, -1);
  return dist;
}
```
The rerooting formula: when the root moves from node to child, the count[child] nodes in child's subtree each get 1 step closer, and the remaining n - count[child] nodes each get 1 step farther. So:
dist[child] = dist[node] - count[child] + (n - count[child])
This formula changes depending on the problem. The structure (two-pass DFS) stays the same.
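A quick way to sanity-check the rerooting formula is to compare it against the naive approach: run a BFS from every node and sum the distances, which is O(n^2) overall. The sketch below uses a hypothetical sum_of_distances_brute name and a 6-node tree where node 0 is joined to 1 and 2, and node 2 is joined to 3, 4, and 5:

```python
from collections import deque
from typing import List, Sequence, Tuple


def sum_of_distances_brute(n: int, edges: Sequence[Tuple[int, int]]) -> List[int]:
    """O(n^2) baseline: BFS from each node and add up the distances."""
    adj: List[List[int]] = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    answer = []
    for source in range(n):
        dist = [-1] * n  # -1 marks "not reached yet"
        dist[source] = 0
        queue = deque([source])
        while queue:
            node = queue.popleft()
            for nxt in adj[node]:
                if dist[nxt] == -1:
                    dist[nxt] = dist[node] + 1
                    queue.append(nxt)
        answer.append(sum(dist))
    return answer


edges = [(0, 1), (0, 2), (2, 3), (2, 4), (2, 5)]
print(sum_of_distances_brute(6, edges))  # [8, 12, 6, 10, 10, 10]
```

With root 0, pass 1 gives dist[0] = 8, count[1] = 1 and count[2] = 4. The formula therefore predicts dist[1] = 8 - 1 + (6 - 1) = 12 and dist[2] = 8 - 4 + (6 - 4) = 6, matching the brute-force values.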
Pattern 3: Tree DP on General Trees (Adjacency List)
Interview trees are often given as adjacency lists, not binary tree nodes. The DFS template changes slightly — you need to avoid revisiting the parent.
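The skeleton looks like this: pass the parent down and skip it when scanning neighbors. This minimal sketch computes subtree sizes under an arbitrary root. The function name and the root parameter are illustrative:

```python
from typing import List, Sequence, Tuple


def subtree_sizes(n: int, edges: Sequence[Tuple[int, int]], root: int = 0) -> List[int]:
    adj: List[List[int]] = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    size = [1] * n

    def dfs(node: int, parent: int) -> None:
        for nxt in adj[node]:
            if nxt == parent:
                continue  # the parent is a neighbor too; don't walk back up
            dfs(nxt, node)
            size[node] += size[nxt]  # the child's subtree is finished here

    dfs(root, -1)
    return size


# 0 is joined to 1 and 2; 2 is joined to 3
print(subtree_sizes(4, [(0, 1), (0, 2), (2, 3)]))  # [4, 1, 2, 1]
```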
Example: Maximum Independent Set
Select nodes with no two adjacent. Maximize count (or sum of values).
```python
def max_independent_set(n, edges):
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    def dfs(node, parent):
        include = 1  # include this node
        exclude = 0  # exclude this node
        for child in adj[node]:
            if child != parent:
                child_inc, child_exc = dfs(child, node)
                include += child_exc  # if we include node, exclude children
                exclude += max(child_inc, child_exc)  # if we exclude, children can be either
        return include, exclude

    inc, exc = dfs(0, -1)
    return max(inc, exc)
```
```javascript
function maxIndependentSet(n, edges) {
  const adj = Array.from({ length: n }, () => []);
  for (const [u, v] of edges) {
    adj[u].push(v);
    adj[v].push(u);
  }
  function dfs(node, parent) {
    let include = 1;
    let exclude = 0;
    for (const child of adj[node]) {
      if (child !== parent) {
        const [childInc, childExc] = dfs(child, node);
        include += childExc;
        exclude += Math.max(childInc, childExc);
      }
    }
    return [include, exclude];
  }
  const [inc, exc] = dfs(0, -1);
  return Math.max(inc, exc);
}
```
Same include/exclude pattern as House Robber III, but on a general tree.
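The problem also allows maximizing a sum of values. The only change needed is to seed include with the node's value instead of 1. This sketch assumes a hypothetical values list indexed by node:

```python
from typing import List, Sequence, Tuple


def max_weight_independent_set(n: int, edges: Sequence[Tuple[int, int]],
                               values: Sequence[int]) -> int:
    adj: List[List[int]] = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    def dfs(node: int, parent: int) -> Tuple[int, int]:
        include = values[node]  # take this node's value
        exclude = 0             # leave this node out
        for child in adj[node]:
            if child != parent:
                child_inc, child_exc = dfs(child, node)
                include += child_exc                  # neighbors of a chosen node are excluded
                exclude += max(child_inc, child_exc)  # otherwise children choose freely
        return include, exclude

    return max(dfs(0, -1))


# path 0 - 1 - 2: the heavy middle node beats both ends together
print(max_weight_independent_set(3, [(0, 1), (1, 2)], [1, 5, 1]))  # 5
print(max_weight_independent_set(3, [(0, 1), (1, 2)], [3, 1, 3]))  # 6
```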
Pattern 4: Diameter and Longest Path
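The tree diameter problem is tree DP in disguise. Every path has exactly one highest node (the one nearest the root). The longest path whose highest node is v equals the sum of the two longest downward paths from v into different child subtrees. The diameter is the best such sum over all nodes.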
```python
def tree_diameter(n, edges):
    adj = [[] for _ in range(n)]
    for u, v, w in edges:  # weighted edges
        adj[u].append((v, w))
        adj[v].append((u, w))

    diameter = [0]

    def dfs(node, parent):
        max1 = max2 = 0  # two longest paths from this node
        for child, weight in adj[node]:
            if child != parent:
                child_depth = dfs(child, node) + weight
                if child_depth > max1:
                    max2 = max1
                    max1 = child_depth
                elif child_depth > max2:
                    max2 = child_depth
        diameter[0] = max(diameter[0], max1 + max2)
        return max1

    dfs(0, -1)
    return diameter[0]
```
```javascript
function treeDiameter(n, edges) {
  const adj = Array.from({ length: n }, () => []);
  for (const [u, v, w] of edges) {
    adj[u].push([v, w]);
    adj[v].push([u, w]);
  }
  let diameter = 0;
  function dfs(node, parent) {
    let max1 = 0, max2 = 0;
    for (const [child, weight] of adj[node]) {
      if (child !== parent) {
        const depth = dfs(child, node) + weight;
        if (depth > max1) { max2 = max1; max1 = depth; }
        else if (depth > max2) { max2 = depth; }
      }
    }
    diameter = Math.max(diameter, max1 + max2);
    return max1;
  }
  dfs(0, -1);
  return diameter;
}
```
Tracking the top two values is a useful trick that shows up in tree problems repeatedly.
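The top-two trick can be pulled out into a small helper. This sketch (the helper name is illustrative) keeps the two largest values in one pass and defaults both to 0. That matches the diameter code's convention that a node with no children contributes an empty path:

```python
from typing import Iterable, Tuple


def top_two(values: Iterable[int]) -> Tuple[int, int]:
    """Return the two largest values seen (max1 >= max2), defaulting to 0."""
    max1 = max2 = 0
    for v in values:
        if v > max1:
            max1, max2 = v, max1  # the old best slides down to second place
        elif v > max2:
            max2 = v
    return max1, max2


print(top_two([3, 7, 5]))  # (7, 5)
print(top_two([4]))        # (4, 0)
print(top_two([6, 6]))     # (6, 6)
```

The (6, 6) case matters for diameters: when two subtrees are equally deep, both must count. Compared with sorting the child depths, this uses one pass and O(1) extra space.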
Dealing with Stack Overflow
Deep trees (chains of 100k+ nodes) blow the call stack with recursive DFS. Convert to iterative:
```python
from collections import deque


def iterative_tree_dp(n, adj):
    # BFS from node 0 to get an order where parents come before children
    order = []
    visited = [False] * n
    queue = deque([0])
    visited[0] = True
    parent = [-1] * n
    while queue:
        node = queue.popleft()
        order.append(node)
        for child in adj[node]:
            if not visited[child]:
                visited[child] = True
                parent[child] = node
                queue.append(child)

    # process in reverse BFS order (leaves first)
    dp = [0] * n
    for node in reversed(order):
        for child in adj[node]:
            if child != parent[node]:
                dp[node] += dp[child] + 1  # example: number of descendants
    return dp
```
Reverse BFS order guarantees children are processed before parents. This replaces recursion entirely and handles trees of any depth.
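Here is that order applied to a real recurrence: a sketch of the include/exclude computation from Pattern 3, done iteratively. Once a node is finished, it pushes its pair into its parent's totals. The function name is illustrative, and the sketch assumes node 0 is the root:

```python
from collections import deque
from typing import List, Sequence, Tuple


def max_independent_set_iterative(n: int, edges: Sequence[Tuple[int, int]]) -> int:
    adj: List[List[int]] = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # BFS from node 0: every parent is appended to order before its children
    parent = [-1] * n
    seen = [False] * n
    seen[0] = True
    order: List[int] = []
    queue = deque([0])
    while queue:
        node = queue.popleft()
        order.append(node)
        for nxt in adj[node]:
            if not seen[nxt]:
                seen[nxt] = True
                parent[nxt] = node
                queue.append(nxt)

    include = [1] * n  # best count if the node is taken
    exclude = [0] * n  # best count if the node is left out
    # reverse BFS order: all children have pushed into a node before it is reached
    for node in reversed(order):
        p = parent[node]
        if p != -1:
            include[p] += exclude[node]
            exclude[p] += max(include[node], exclude[node])
    return max(include[0], exclude[0])


# a 100,000-node chain, far deeper than Python's default recursion limit
chain = [(i, i + 1) for i in range(99_999)]
print(max_independent_set_iterative(100_000, chain))  # 50000
```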
Complexity
| Pattern | Time | Space |
|---|---|---|
| Subtree DP | O(n) | O(n) |
| Rerooting (two-pass) | O(n) | O(n) |
| Include/Exclude | O(n) | O(n) |
| Diameter | O(n) | O(n) |
Common Mistakes
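Not tracking the parent. In general trees (adjacency lists), every edge is stored in both directions. If you don't pass the parent, the DFS treats it as a "child" and recurses back into it forever (in practice, until the stack overflows).
Returning the wrong thing from DFS. Tree DP functions often need to return a value for the parent's use that's different from the thing they track globally. Confusing these is the most common bug.
Stack overflow on deep trees. Python's default recursion limit is 1000. Use sys.setrecursionlimit(200000) or convert to iterative. Raising the limit alone can still crash CPython if the underlying C stack runs out, which is why it is often paired with a larger thread stack via threading.stack_size. JavaScript's limit varies by engine and frame size, and deep chains can hit it well before 50k nodes.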
Forgetting that rerooting takes two passes. A single DFS gives you the answer for one root only. The second pass propagates the answer to all nodes in O(n). Trying to do it in one pass is a trap.
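For the stack overflow mistake, a common CPython workaround pairs a higher recursion limit with a bigger thread stack. The sketch below works under that assumption. run_with_big_stack is a hypothetical helper, and the sizes are illustrative guesses. threading.stack_size can raise ValueError or RuntimeError on platforms that reject the requested size:

```python
import sys
import threading
from typing import Any, Callable


def run_with_big_stack(fn: Callable[..., Any], *args: Any,
                       stack_bytes: int = 1 << 27,
                       recursion_limit: int = 1 << 20) -> Any:
    """Run fn(*args) in a worker thread with a larger stack and recursion limit."""
    result: list = []
    errors: list = []

    def target() -> None:
        try:
            result.append(fn(*args))
        except BaseException as exc:  # re-raised in the caller's thread below
            errors.append(exc)

    old_limit = sys.getrecursionlimit()
    sys.setrecursionlimit(recursion_limit)
    threading.stack_size(stack_bytes)  # applies to threads created after this call
    try:
        worker = threading.Thread(target=target)
        worker.start()
        worker.join()
    finally:
        threading.stack_size(0)  # back to the platform default
        sys.setrecursionlimit(old_limit)

    if errors:
        raise errors[0]
    return result[0]


def chain_depth(n: int) -> int:
    """Recursively measure a chain 0 -> 1 -> ... -> n-1 (far deeper than 1000)."""
    def dfs(i: int) -> int:
        return 1 if i == n - 1 else 1 + dfs(i + 1)
    return dfs(0)


print(run_with_big_stack(chain_depth, 50_000))  # 50000
```

Converting to iterative, as in the reverse-BFS approach, avoids these platform-dependent knobs entirely.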
Practice these patterns on CodeUp — tree DP questions are less common than array DP but they're high-signal in interviews because fewer candidates prepare for them.