March 26, 2026 · 9 min read

Dynamic Programming on Trees — Patterns and Solutions

Tree DP combines recursion with memoization in ways that feel different from grid DP. Rerooting, subtree aggregation, and the patterns that make tree DP systematic.


Tree DP doesn't get as much attention as grid DP or string DP, but it shows up in interviews consistently — especially at companies that like graph problems. The difficulty comes from the tree structure itself: you can't just fill a table left-to-right. You need to think in terms of subtrees, and the "table" is the tree itself.

The good news: there are only a handful of patterns. Once you recognize which one applies, the implementation follows a consistent template.

The Core Idea

In tree DP, each node stores a value computed from its children's values. You process the tree bottom-up via postorder DFS: compute children first, then combine at the parent. The tree's structure determines the subproblem dependencies — a node's answer depends on its subtree.

       1
      / \
     2   3
    / \
   4   5

Process: 4, 5, 2, 3, 1. By the time you handle node 2, you already have answers for 4 and 5.
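The sketch below shows that ordering on the five-node tree above. It assumes a small hypothetical Node class that is not part of the original. The code records the postorder visit sequence and computes each node's subtree size from its children's sizes, which is about the simplest tree DP there is.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    # Hypothetical binary-tree node used only for this illustration.
    val: int
    left: Optional[Node] = None
    right: Optional[Node] = None


def postorder_sizes(root: Optional[Node]) -> tuple[list[int], dict[int, int]]:
    """Return the postorder visit sequence and the subtree size of every node."""
    order: list[int] = []
    size: dict[int, int] = {}

    def dfs(node: Optional[Node]) -> int:
        if node is None:
            return 0
        # Children first...
        left = dfs(node.left)
        right = dfs(node.right)
        # ...then combine at the parent.
        order.append(node.val)
        size[node.val] = 1 + left + right
        return size[node.val]

    dfs(root)
    return order, size


root = Node(1, Node(2, Node(4), Node(5)), Node(3))
order, size = postorder_sizes(root)
print(order)  # [4, 5, 2, 3, 1]
print(size)   # {4: 1, 5: 1, 2: 3, 3: 1, 1: 5}
```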

Pattern 1: Subtree DP (Basic)

The simplest form. Each node computes a value based solely on its children's results.

Example: Maximum Path Sum in a Tree

This is the binary tree version: a path may start and end at any nodes. It climbs up to a highest "peak" node and then goes back down, so at most two child branches meet at any node.

def max_path_sum(root):
    """Max sum of any path; the path may rise to a peak node and come back down."""
    result = [float('-inf')]

    def dfs(node):
        if not node:
            return 0
        # best downward sums from each child; negative branches are dropped
        left = max(0, dfs(node.left))
        right = max(0, dfs(node.right))
        # path through this node (peaks here, uses both sides)
        result[0] = max(result[0], node.val + left + right)
        # return best single-direction path
        return node.val + max(left, right)

    dfs(root)
    return result[0]

function maxPathSum(root) {
  let result = -Infinity;

  function dfs(node) {
    if (!node) return 0;
    const left = Math.max(0, dfs(node.left));
    const right = Math.max(0, dfs(node.right));
    result = Math.max(result, node.val + left + right);
    return node.val + Math.max(left, right);
  }

  dfs(root);
  return result;
}

The "compute one thing, track another" pattern again: dfs returns the best single-direction path (for the parent to use), but we track the best two-direction path through each node in result.
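Here is the Python version made self-contained as a quick sanity check. The TreeNode class is a hypothetical stand-in for whatever node type your problem supplies. Instead of the one-element list, this version uses nonlocal, which behaves the same way. In the test tree, the max(0, ...) clamp lets the root's parent-side logic ignore branches that would only lower the sum.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass
class TreeNode:
    # Hypothetical node class; any object with val/left/right works.
    val: int
    left: Optional[TreeNode] = None
    right: Optional[TreeNode] = None


def max_path_sum(root: Optional[TreeNode]) -> float:
    best = float("-inf")

    def dfs(node: Optional[TreeNode]) -> int:
        nonlocal best
        if node is None:
            return 0
        # Drop a child branch entirely if its best downward sum is negative.
        left = max(0, dfs(node.left))
        right = max(0, dfs(node.right))
        # Tracked globally: the best path that peaks at this node (both sides).
        best = max(best, node.val + left + right)
        # Returned to the parent: the best path going down one side only.
        return node.val + max(left, right)

    dfs(root)
    return best


#       -10
#       /  \
#      9    20
#          /  \
#         15   7
root = TreeNode(-10, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
print(max_path_sum(root))  # 42  (15 -> 20 -> 7)
```

The best path here never touches the root. That is exactly why the tracked value and the returned value have to be kept separate.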

Example: House Robber III

Binary tree where each node has a value. You can't rob two directly connected nodes. Maximize total.

Each node has two states: robbed or not robbed.

def rob(root):
    def dfs(node):
        if not node:
            return (0, 0)  # (rob_this, skip_this)

        left = dfs(node.left)
        right = dfs(node.right)

        # rob this node: can't rob children
        rob_this = node.val + left[1] + right[1]
        # skip this node: take best of each child
        skip_this = max(left) + max(right)

        return (rob_this, skip_this)

    return max(dfs(root))

function rob(root) {
  function dfs(node) {
    if (!node) return [0, 0]; // [rob, skip]

    const left = dfs(node.left);
    const right = dfs(node.right);

    const robThis = node.val + left[1] + right[1];
    const skipThis = Math.max(...left) + Math.max(...right);

    return [robThis, skipThis];
  }

  return Math.max(...dfs(root));
}

The pattern: return a tuple of states from each recursive call. The parent combines them based on the problem's constraints.
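To see the tuple pattern on a concrete input, here is a self-contained version run on a small example tree. TreeNode is again a hypothetical helper class. The best plan robs the root and both grandchildren, which skips the two middle nodes.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass
class TreeNode:
    # Hypothetical node class for this example.
    val: int
    left: Optional[TreeNode] = None
    right: Optional[TreeNode] = None


def rob(root: Optional[TreeNode]) -> int:
    def dfs(node: Optional[TreeNode]) -> tuple[int, int]:
        # Returns (best total if this node is robbed, best total if skipped).
        if node is None:
            return (0, 0)
        left = dfs(node.left)
        right = dfs(node.right)
        rob_this = node.val + left[1] + right[1]  # children must be skipped
        skip_this = max(left) + max(right)        # each child picks its better state
        return (rob_this, skip_this)

    return max(dfs(root))


#     3
#    / \
#   2   3
#    \   \
#     3   1
root = TreeNode(3, TreeNode(2, None, TreeNode(3)), TreeNode(3, None, TreeNode(1)))
print(rob(root))  # 7  (3 + 3 + 1)
```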

Pattern 2: Rerooting DP

This is the hard one. You need to compute something for every node as if that node were the root. Naive approach: reroot and recompute, O(n^2). Rerooting DP does it in O(n) with two passes.

Example: Sum of Distances in Tree

Given an unrooted tree of n nodes, for each node compute the sum of distances to all other nodes.

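Pass 1 (bottom-up): Root the tree arbitrarily at node 0. Compute count[v] (the size of v's subtree) and dist[v] (the sum of distances from v to the nodes in its own subtree). For the root, that subtree is the whole tree, so dist[0] is already the final answer for node 0.

Pass 2 (top-down): For each edge (parent → child), moving the root from parent to child brings count[child] nodes 1 step closer and pushes the other n - count[child] nodes 1 step farther.
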
def sum_of_distances(n, edges):
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    count = [1] * n
    dist = [0] * n

    # pass 1: compute count and dist[0]
    def dfs1(node, parent):
        for child in adj[node]:
            if child != parent:
                dfs1(child, node)
                count[node] += count[child]
                dist[node] += dist[child] + count[child]

    # pass 2: reroot
    def dfs2(node, parent):
        for child in adj[node]:
            if child != parent:
                dist[child] = dist[node] - count[child] + (n - count[child])
                dfs2(child, node)

    dfs1(0, -1)
    dfs2(0, -1)
    return dist

function sumOfDistances(n, edges) {
  const adj = Array.from({ length: n }, () => []);
  for (const [u, v] of edges) {
    adj[u].push(v);
    adj[v].push(u);
  }

  const count = new Array(n).fill(1);
  const dist = new Array(n).fill(0);

  function dfs1(node, parent) {
    for (const child of adj[node]) {
      if (child !== parent) {
        dfs1(child, node);
        count[node] += count[child];
        dist[node] += dist[child] + count[child];
      }
    }
  }

  function dfs2(node, parent) {
    for (const child of adj[node]) {
      if (child !== parent) {
        dist[child] = dist[node] - count[child] + (n - count[child]);
        dfs2(child, node);
      }
    }
  }

  dfs1(0, -1);
  dfs2(0, -1);
  return dist;
}

The rerooting formula: when the root moves from node to child, the count[child] nodes in child's subtree each get 1 step closer, and the remaining n - count[child] nodes each get 1 step farther. So:

dist[child] = dist[node] - count[child] + (n - count[child])

This formula changes depending on the problem. The structure (two-pass DFS) stays the same.
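A good way to trust the formula is to check it against the O(n^2) approach on a small tree. The sketch below pairs the two-pass version with a hypothetical brute_force helper that runs a BFS from every node. The six-node tree is just a convenient test case. For instance, rerooting from 0 to 2 there gives dist[2] = 8 - 4 + (6 - 4) = 6, since dist[0] = 8, count[2] = 4, and n = 6.

```python
from collections import deque


def sum_of_distances(n, edges):
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    count = [1] * n
    dist = [0] * n

    def dfs1(node, parent):
        # Bottom-up: subtree sizes and subtree-only distance sums.
        for child in adj[node]:
            if child != parent:
                dfs1(child, node)
                count[node] += count[child]
                dist[node] += dist[child] + count[child]

    def dfs2(node, parent):
        # Top-down: apply the rerooting formula along each edge.
        for child in adj[node]:
            if child != parent:
                dist[child] = dist[node] - count[child] + (n - count[child])
                dfs2(child, node)

    dfs1(0, -1)
    dfs2(0, -1)
    return dist


def brute_force(n, edges):
    """O(n^2) reference: BFS from every node and sum the distances."""
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    result = []
    for src in range(n):
        depth = [-1] * n
        depth[src] = 0
        queue = deque([src])
        while queue:
            node = queue.popleft()
            for nxt in adj[node]:
                if depth[nxt] == -1:
                    depth[nxt] = depth[node] + 1
                    queue.append(nxt)
        result.append(sum(depth))
    return result


edges = [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]
print(sum_of_distances(6, edges))  # [8, 12, 6, 10, 10, 10]
print(brute_force(6, edges))       # [8, 12, 6, 10, 10, 10]
```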

Pattern 3: Tree DP on General Trees (Adjacency List)

Interview trees are often given as adjacency lists, not binary tree nodes. The DFS template changes slightly — you need to avoid revisiting the parent.

Example: Maximum Independent Set

Select nodes with no two adjacent. Maximize count (or sum of values).

def max_independent_set(n, edges):
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    def dfs(node, parent):
        include = 1  # include this node
        exclude = 0  # exclude this node

        for child in adj[node]:
            if child != parent:
                child_inc, child_exc = dfs(child, node)
                include += child_exc  # if we include node, exclude children
                exclude += max(child_inc, child_exc)  # if we exclude, children can be either

        return include, exclude

    inc, exc = dfs(0, -1)
    return max(inc, exc)

function maxIndependentSet(n, edges) {
  const adj = Array.from({ length: n }, () => []);
  for (const [u, v] of edges) {
    adj[u].push(v);
    adj[v].push(u);
  }

  function dfs(node, parent) {
    let include = 1;
    let exclude = 0;

    for (const child of adj[node]) {
      if (child !== parent) {
        const [childInc, childExc] = dfs(child, node);
        include += childExc;
        exclude += Math.max(childInc, childExc);
      }
    }

    return [include, exclude];
  }

  const [inc, exc] = dfs(0, -1);
  return Math.max(inc, exc);
}

Same include/exclude pattern as House Robber III, but on a general tree.
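The problem statement also mentions the weighted variant ("or sum of values"). Here is a sketch of it, assuming a hypothetical values list that holds one weight per node. The only change from the counting version is that including a node adds values[node] instead of 1.

```python
def max_weight_independent_set(n, edges, values):
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    def dfs(node, parent):
        include = values[node]  # weight gained by taking this node
        exclude = 0
        for child in adj[node]:
            if child != parent:
                child_inc, child_exc = dfs(child, node)
                include += child_exc                  # taken node: children excluded
                exclude += max(child_inc, child_exc)  # skipped node: children free
        return include, exclude

    return max(dfs(0, -1))


# Path 0 - 1 - 2: a heavy middle node beats both ends, light middle loses.
print(max_weight_independent_set(3, [(0, 1), (1, 2)], [1, 5, 1]))  # 5
print(max_weight_independent_set(3, [(0, 1), (1, 2)], [3, 1, 3]))  # 6
```

Passing values = [1] * n recovers the counting version above.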

Pattern 4: Diameter and Longest Path

The tree diameter problem is tree DP in disguise. At each node, the longest path that peaks there is the sum of the two longest downward paths into different child subtrees. With non-negative edge weights, a missing branch simply counts as 0.

def tree_diameter(n, edges):
    adj = [[] for _ in range(n)]
    for u, v, w in edges:  # weighted edges
        adj[u].append((v, w))
        adj[v].append((u, w))

    diameter = [0]

    def dfs(node, parent):
        max1 = max2 = 0  # two longest paths from this node

        for child, weight in adj[node]:
            if child != parent:
                child_depth = dfs(child, node) + weight
                if child_depth > max1:
                    max2 = max1
                    max1 = child_depth
                elif child_depth > max2:
                    max2 = child_depth

        diameter[0] = max(diameter[0], max1 + max2)
        return max1

    dfs(0, -1)
    return diameter[0]

function treeDiameter(n, edges) {
  const adj = Array.from({ length: n }, () => []);
  for (const [u, v, w] of edges) {
    adj[u].push([v, w]);
    adj[v].push([u, w]);
  }

  let diameter = 0;

  function dfs(node, parent) {
    let max1 = 0, max2 = 0;

    for (const [child, weight] of adj[node]) {
      if (child !== parent) {
        const depth = dfs(child, node) + weight;
        if (depth > max1) { max2 = max1; max1 = depth; }
        else if (depth > max2) { max2 = depth; }
      }
    }

    diameter = Math.max(diameter, max1 + max2);
    return max1;
  }

  dfs(0, -1);
  return diameter;
}

Tracking the top two values is a useful trick that shows up in tree problems repeatedly.
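To check that the top-two bookkeeping is right, you can compare it against a brute force that measures the farthest distance from every node. The sketch below does that. It assumes non-negative edge weights, which is also what the max1 = max2 = 0 initialization relies on, and brute_force_diameter is a hypothetical helper written only for the comparison.

```python
def tree_diameter(n, edges):
    adj = [[] for _ in range(n)]
    for u, v, w in edges:
        adj[u].append((v, w))
        adj[v].append((u, w))
    diameter = 0

    def dfs(node, parent):
        nonlocal diameter
        max1 = max2 = 0  # two deepest downward paths into different children
        for child, weight in adj[node]:
            if child != parent:
                depth = dfs(child, node) + weight
                if depth > max1:
                    max1, max2 = depth, max1  # old best slides to second place
                elif depth > max2:
                    max2 = depth
        diameter = max(diameter, max1 + max2)
        return max1

    dfs(0, -1)
    return diameter


def brute_force_diameter(n, edges):
    """O(n^2) check: farthest weighted distance from every node (iterative DFS)."""
    adj = [[] for _ in range(n)]
    for u, v, w in edges:
        adj[u].append((v, w))
        adj[v].append((u, w))
    best = 0
    for src in range(n):
        dist = {src: 0}
        stack = [src]
        while stack:
            node = stack.pop()
            for nxt, w in adj[node]:
                if nxt not in dist:
                    dist[nxt] = dist[node] + w
                    stack.append(nxt)
        best = max(best, max(dist.values()))
    return best


edges = [(0, 1, 3), (1, 2, 4), (1, 3, 1)]
print(tree_diameter(4, edges))         # 7  (0 -> 1 -> 2)
print(brute_force_diameter(4, edges))  # 7
```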

Dealing with Stack Overflow

Deep trees (chains of 100k+ nodes) blow the call stack with recursive DFS. Convert to iterative:

from collections import deque


def iterative_tree_dp(n, adj):
    # BFS to get processing order, then reverse
    order = []
    visited = [False] * n
    queue = deque([0])
    visited[0] = True
    parent = [-1] * n

    while queue:
        node = queue.popleft()
        order.append(node)
        for child in adj[node]:
            if not visited[child]:
                visited[child] = True
                parent[child] = node
                queue.append(child)

    # process in reverse BFS order (leaves first)
    dp = [0] * n
    for node in reversed(order):
        for child in adj[node]:
            if child != parent[node]:
                dp[node] += dp[child] + 1  # example computation: number of descendants

    return dp

Reverse BFS order guarantees children are processed before parents. This replaces recursion entirely and handles trees of any depth.
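The same idea covers rerooting, where both passes would otherwise recurse. Here is a sketch of sum of distances without recursion, assuming node 0 as the root. Pass 1 walks the BFS order backwards, so children come before parents. Pass 2 walks it forwards, so parents are finalized before their children. The function name sum_of_distances_iterative is just a label for this sketch.

```python
from collections import deque


def sum_of_distances_iterative(n, edges):
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # BFS from node 0: parent pointers plus a top-down processing order.
    parent = [-1] * n
    order = []
    visited = [False] * n
    visited[0] = True
    queue = deque([0])
    while queue:
        node = queue.popleft()
        order.append(node)
        for nxt in adj[node]:
            if not visited[nxt]:
                visited[nxt] = True
                parent[nxt] = node
                queue.append(nxt)

    count = [1] * n
    dist = [0] * n
    # Pass 1 (bottom-up): a node's values are final before it is folded into its parent.
    for node in reversed(order):
        p = parent[node]
        if p != -1:
            count[p] += count[node]
            dist[p] += dist[node] + count[node]

    # Pass 2 (top-down): dist[p] is already final when its child is rerooted.
    for node in order:
        p = parent[node]
        if p != -1:
            dist[node] = dist[p] - count[node] + (n - count[node])
    return dist


chain = [[i, i + 1] for i in range(99_999)]  # a 100k-node path: worst case for recursion
big = sum_of_distances_iterative(100_000, chain)
print(big[0])  # 4999950000
```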

Complexity

| Pattern | Time | Space |
|---|---|---|
| Subtree DP | O(n) | O(n) |
| Rerooting (two-pass) | O(n) | O(n) |
| Include/Exclude | O(n) | O(n) |
| Diameter | O(n) | O(n) |

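Every pattern above runs in O(n) time because each node is visited a constant number of times (once, or twice for rerooting), and each edge is examined a constant number of times. Space is O(n) for per-node storage plus the recursion stack. The stack is O(h) for a tree of height h and degrades to O(n) on a chain.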

Common Mistakes

Not tracking the parent. In general trees (adjacency lists), if you don't pass the parent, you'll revisit it as a "child" and recurse forever, which in practice ends in a RecursionError or a stack overflow.

Returning the wrong thing from DFS. Tree DP functions often need to return a value for the parent's use that's different from the thing they track globally. Confusing these is the most common bug.

Stack overflow on deep trees. Python's default recursion limit is 1000. Use sys.setrecursionlimit(200000) or convert to iterative. JavaScript's limit comes from the engine's stack size rather than a fixed frame count, so the maximum depth varies by engine and frame size. Chains in the tens of thousands of nodes can exceed it.

Forgetting rerooting is two passes. A single DFS gives you the answer for one root only. The second pass propagates the answer to all nodes in O(n). Trying to do it in one pass is a trap.
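If you keep the recursive version in Python, note that raising the recursion limit does not enlarge the thread's actual stack. On some versions and platforms, a high limit alone can still crash the interpreter. A common workaround is to run the DFS in a new thread created after threading.stack_size has been raised. The sketch below assumes a 256 MB stack is acceptable on your machine; that figure is an illustrative choice, not a tuned value. run_with_big_stack is a hypothetical helper name.

```python
import sys
import threading


def subtree_sizes(n, edges):
    """Recursive subtree-size DP; deep inputs need the big-stack wrapper below."""
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    size = [1] * n

    def dfs(node, parent):
        for child in adj[node]:
            if child != parent:
                dfs(child, node)
                size[node] += size[child]

    dfs(0, -1)
    return size


def run_with_big_stack(fn, *args, stack_bytes=256 * 1024 * 1024, recursion_limit=200_000):
    """Run fn(*args) in a fresh thread with a larger stack and recursion limit."""
    result = []
    sys.setrecursionlimit(recursion_limit)
    threading.stack_size(stack_bytes)  # only affects threads created after this call
    worker = threading.Thread(target=lambda: result.append(fn(*args)))
    worker.start()
    worker.join()
    return result[0]


chain = [(i, i + 1) for i in range(99_999)]  # 100k-node path: recursion depth ~100k
sizes = run_with_big_stack(subtree_sizes, 100_000, chain)
print(sizes[0])  # 100000
```

The iterative rewrite avoids all of this and is usually the safer choice. The thread trick is mainly useful when you'd rather not restructure working recursive code.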

Practice these patterns on CodeUp — tree DP questions are less common than array DP but they're high-signal in interviews because fewer candidates prepare for them.
