All Nodes Distance K in Binary Tree

Trees, BFS / Graph

Problem Statement

Given the root of a binary tree, the value of a target node, and an integer k, return an array of the values of all nodes that are exactly k edges away from the target node.

Example

target = 5, k = 2

root = [3, 5, 1, 6, 2, 0, 8, null, null, 7, 4] (level order)

Output: [7, 4, 1]

Algorithm Explanation

This problem can be solved by treating the tree as an undirected graph, because a shortest path from the target may go up through a parent as well as down through children. Once we have a graph representation, a Breadth-First Search (BFS) starting from the target node finds all nodes at distance `k`.
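
To make the undirected-graph view concrete, here is a minimal sketch that records each parent-child edge in both directions. The `TreeNode` class mirrors the definition used in the solution; `build_adjacency`, the small hand-built tree, and keying the graph by node value are illustrative assumptions (keying by value only works when node values are unique).

```python
from collections import defaultdict


class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def build_adjacency(root):
    """Return {value: [neighbor values]} with every tree edge stored in both directions."""
    graph = defaultdict(list)
    stack = [root] if root else []
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child:
                graph[node.val].append(child.val)  # parent -> child
                graph[child.val].append(node.val)  # child -> parent
                stack.append(child)
    return graph


# A small hypothetical tree: 3 is the root, 5 is its left child,
# and 6 and 2 are the children of 5.
root = TreeNode(3)
root.left = TreeNode(5)
root.left.left = TreeNode(6)
root.left.right = TreeNode(2)

graph = build_adjacency(root)
print(sorted(graph[5]))  # [2, 3, 6]: parent 3 plus children 6 and 2
```

Node 5 ends up with three neighbors, which is what lets a search that starts at 5 move upward to 3 as easily as downward to 6 or 2.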

Algorithm Steps

  • Graph Conversion: Traverse the binary tree (using DFS or BFS) to build an adjacency list that represents the connections between nodes. Each node's entry lists its neighbors: its parent, left child, and right child, where they exist.
  • BFS from Target: Start a BFS from the `target` node. Use a queue to keep track of nodes to visit and their distance from the target.
  • Track Visited Nodes: Use a set to keep track of visited nodes to avoid cycles and redundant processing.
  • Find Nodes at Distance K: When the distance of a node in the BFS equals `k`, add its value to the result list.
  • Stop Early: BFS reaches nodes in nondecreasing order of distance, so a node whose distance exceeds `k` never needs to be expanded.
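
The BFS steps above can be sketched on a graph that is already in adjacency-list form. The dictionary below is a hypothetical hand-written graph keyed by node value, and `nodes_at_distance` is an illustrative name. Instead of storing a distance with each queue entry, this variant processes the queue one level at a time, so the level index is the distance.

```python
from collections import deque


def nodes_at_distance(graph, start, k):
    """Return the nodes exactly k edges away from start in an undirected graph."""
    visited = {start}
    queue = deque([start])
    dist = 0
    while queue and dist < k:
        # Everything in the queue right now is exactly `dist` edges from start.
        for _ in range(len(queue)):
            node = queue.popleft()
            for neighbor in graph.get(node, []):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(neighbor)
        dist += 1
    # Either k levels were expanded (the queue holds the distance-k nodes)
    # or the graph ran out of nodes first (the queue is empty).
    return list(queue)


# Hypothetical undirected graph: a path 1 - 2 - 3 - 4 with an extra branch 2 - 5.
graph = {
    1: [2],
    2: [1, 3, 5],
    3: [2, 4],
    4: [3],
    5: [2],
}

print(nodes_at_distance(graph, 1, 2))  # [3, 5]
print(nodes_at_distance(graph, 3, 0))  # [3]
print(nodes_at_distance(graph, 1, 9))  # []
```

The loop stops as soon as level `k` is reached, so nodes beyond distance `k` are never placed in the queue.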

All Nodes Distance K Solution

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None

import collections
from typing import List

class Solution:
    def distanceK(self, root: TreeNode, target: TreeNode, k: int) -> List[int]:
        graph = collections.defaultdict(list)
        
        # Convert the tree to an undirected graph: store every parent-child edge in both directions
        q = collections.deque([root])
        while q:
            node = q.popleft()
            if node.left:
                graph[node].append(node.left)
                graph[node.left].append(node)
                q.append(node.left)
            if node.right:
                graph[node].append(node.right)
                graph[node.right].append(node)
                q.append(node.right)

        # BFS from target
        ans = []
        visited = {target}
        q = collections.deque([(target, 0)])
        
        while q:
            node, dist = q.popleft()
            
            if dist == k:
                # Neighbors of a distance-k node are farther than k, so do not expand it.
                ans.append(node.val)
                continue
                
            for neighbor in graph[node]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    q.append((neighbor, dist + 1))
                    
        return ans
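
As a runnable companion to the listing above, the following sketch swaps the full adjacency list for a child-to-parent map, which is a different (though closely related) technique: each node's neighbors are then its `left`, its `right`, and its recorded parent. The names `build_tree` and `distance_k` are hypothetical helpers, and the level-order list used to build the tree is an assumption chosen to be consistent with the example (target = 5, k = 2).

```python
from collections import deque
from typing import Dict, List, Optional


class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a level-order list (None = missing child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        if values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def distance_k(root: Optional[TreeNode], target: TreeNode, k: int) -> List[int]:
    """Variant that stores only child -> parent links instead of an adjacency list."""
    parent: Dict[TreeNode, Optional[TreeNode]] = {}
    stack = [(root, None)] if root else []
    while stack:
        node, par = stack.pop()
        parent[node] = par
        for child in (node.left, node.right):
            if child:
                stack.append((child, node))

    ans = []
    visited = {target}
    queue = deque([(target, 0)])
    while queue:
        node, dist = queue.popleft()
        if dist == k:
            ans.append(node.val)
            continue  # neighbors of a distance-k node are farther than k
        for neighbor in (node.left, node.right, parent.get(node)):
            if neighbor and neighbor not in visited:
                visited.add(neighbor)
                queue.append((neighbor, dist + 1))
    return ans


# Assumed level-order layout of the example tree.
root = build_tree([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4])
target = root.left  # the node with value 5
print(sorted(distance_k(root, target, 2)))  # [1, 4, 7]
```

Both versions visit each node a constant number of times, so the running time and extra space are O(n) for a tree with n nodes; the parent map simply avoids storing the child links a second time.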