[Graph Partition] reorder for minimal number of partitions (#151968)

This PR adds an optimal reordering for minimizing the number of partitions (#partitions).

## Optimal reordering for minimizing #partitions

A BFS can minimize #partitions (ignoring peak memory for now):
1. For each node, compute node_to_indegree: dict[node, int].
2. Maintain 2 queues: cudagraphable_nodes and non_cudagraphable_nodes. Iterate through all nodes and add each node to one of these 2 queues if node_to_indegree[node] == 0.
3. While non_cudagraphable_nodes is not empty: pop 1 node, schedule it, update the indegree of all its successors, and add each successor to one of the queues once node_to_indegree[successor] == 0.
4. While cudagraphable_nodes is not empty: pop 1 node, schedule it, update the indegree of all its successors, and add each successor to one of the queues once node_to_indegree[successor] == 0.
5. Repeat steps 3 & 4 until all nodes have been scheduled.

We call this strategy `reorder_for_minimizing_partition`.
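
A minimal, self-contained sketch of this BFS (the names `preds`, `succs`, and `is_cudagraphable` are illustrative stand-ins, not the actual inductor scheduler API; the graph is assumed to be a DAG):

```
from collections import deque

def reorder_for_minimizing_partition_sketch(nodes, preds, succs, is_cudagraphable):
    # step 1: indegree of every node
    indegree = {n: len(preds[n]) for n in nodes}
    cudagraphable_q, non_cudagraphable_q = deque(), deque()

    def push_if_ready(n):
        # step 2: route a ready node to the matching queue
        (cudagraphable_q if is_cudagraphable(n) else non_cudagraphable_q).append(n)

    for n in nodes:
        if indegree[n] == 0:
            push_if_ready(n)

    schedule = []

    def schedule_one(queue):
        n = queue.popleft()
        schedule.append(n)
        for succ in succs[n]:
            indegree[succ] -= 1
            if indegree[succ] == 0:
                push_if_ready(succ)

    # steps 3-5: alternately drain the two queues until nothing is ready
    while cudagraphable_q or non_cudagraphable_q:
        while non_cudagraphable_q:
            schedule_one(non_cudagraphable_q)
        while cudagraphable_q:
            schedule_one(cudagraphable_q)
    return schedule
```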

**Q: Why is this optimal?**

Suppose this is not optimal. Then there is a counterexample with 2 non_cudagraphable regions:

```
[non_cudagraphable1, cudagraphable2, non_cudagraphable3]
```

where we can reorder to only 1 non_cudagraphable region:

```
[non_cudagraphable1, non_cudagraphable3, cudagraphable2]
```

For this reorder to be valid, non_cudagraphable3 must not depend on cudagraphable2. So after non_cudagraphable1 is scheduled, both non_cudagraphable3 and cudagraphable2 have in-degree 0. But then step 3 would already have scheduled non_cudagraphable3 before cudagraphable2, so the counterexample cannot exist.

Hence no such counterexample can be constructed, and the BFS is optimal for minimizing #partitions.
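
As a concrete check, feeding the 3-node counterexample graph (with no edge from cudagraphable2 to non_cudagraphable3) into the sketch above yields the single-region order:

```
nodes = ["non_cudagraphable1", "cudagraphable2", "non_cudagraphable3"]
preds = {
    "non_cudagraphable1": set(),
    "cudagraphable2": {"non_cudagraphable1"},
    "non_cudagraphable3": {"non_cudagraphable1"},
}
succs = {
    "non_cudagraphable1": {"cudagraphable2", "non_cudagraphable3"},
    "cudagraphable2": set(),
    "non_cudagraphable3": set(),
}
order = reorder_for_minimizing_partition_sketch(
    nodes, preds, succs, is_cudagraphable=lambda n: n.startswith("cudagraphable")
)
print(order)
# ['non_cudagraphable1', 'non_cudagraphable3', 'cudagraphable2']
# -> a single non_cudagraphable region, so the counterexample cannot arise
```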

## Minimize peak memory

`reorder_for_peak_memory` currently uses topological_sort_dfs, topological_sort_lpmf, and topological_sort_bfs, where the latter 2 are BFS-based. According to @xuanzhang816, ILP brings only small benefits and can hardly scale beyond 100 nodes, so ILP is not used for the peak memory reorder in inductor.

Heuristic strategy:
- Use the order produced by `reorder_for_peak_memory` as the default order.
- Run `reorder_for_minimizing_partition` and get the result as list[tuple[partition, bool]], where partition is a list[BaseSchedulerNode] and the bool marks whether the partition is cudagraphable.
- If the reorder increases peak memory too much (more than 10% over the default order's estimated peak memory), fall back to the default order; see the sketch below.
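
A condensed, generic sketch of this guard (the function and parameter names here are illustrative; in the PR the check lives in `maybe_reorder_for_minimizing_partition`, shown in the diff below, and the 1.1 factor encodes the 10% budget):

```
def pick_order(default_nodes, reorder_fn, estimate_peak_memory_fn, budget=1.1):
    # Peak memory estimate of the default (peak-memory-friendly) order.
    default_peak = estimate_peak_memory_fn(default_nodes)
    # Candidate order that minimizes #partitions.
    reordered = reorder_fn(default_nodes)
    reordered_peak = estimate_peak_memory_fn(reordered)
    # Accept the reorder only within the (arbitrary) 10% extra-memory budget.
    return reordered if reordered_peak < default_peak * budget else default_nodes
```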

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151968
Approved by: https://github.com/eellison
Author: Boyuan Feng
Date: 2025-04-29 17:17:16 +00:00
Committed by: PyTorch MergeBot
Commit: 797768cd90 (parent a77a44761b)
3 changed files with 222 additions and 21 deletions


@@ -2686,8 +2686,8 @@ if HAS_CUDA:
                loss.backward()
                optimizer.step()

            # 2 graph partitions lead to 2 fwd cudagraphs and 2 bwd cudagraphs
            self.assertEqual(self.get_manager().new_graph_id().id, 4)
            # 2 graph partitions lead to 2 fwd cudagraphs and 1 bwd cudagraphs
            self.assertEqual(self.get_manager().new_graph_id().id, 3)

        @torch._inductor.config.patch("graph_partition", True)
        def test_graph_partition_cpu_only(self):
@@ -3088,6 +3088,89 @@ if HAS_CUDA:
            self.assertEqual(self.get_manager().new_graph_id().id, 3)

        @torch._inductor.config.patch("graph_partition", True)
        def test_graph_partition_reorder_cpu_and_gpu(self):
            def f(x_cuda, y_cpu, z_cuda, weight_cuda, weight_cpu):
                x_cuda0 = x_cuda + 1
                x_cuda1 = x_cuda0 @ weight_cuda
                x_cuda2 = 2 * (x_cuda1 + x_cuda)

                y_cpu0 = y_cpu + 1
                y_cpu1 = y_cpu0 @ weight_cpu

                z_cuda0 = z_cuda + 1
                z_cuda1 = z_cuda0 @ weight_cuda
                z_cuda2 = 2 * (z_cuda1 + z_cuda)

                return x_cuda2, y_cpu1, z_cuda2

            x_cuda = torch.randn(3, 3, device="cuda")
            y_cpu = torch.randn(3, 3, device="cpu")
            z_cuda = torch.randn(3, 3, device="cuda")
            weight_cuda = torch.randn(3, 3, device="cuda")
            weight_cpu = torch.randn(3, 3, device="cpu")

            eager_out = f(x_cuda, y_cpu, z_cuda, weight_cuda, weight_cpu)
            compiled_f = torch.compile(f, mode="reduce-overhead")
            for _ in range(3):
                compiled_out = compiled_f(
                    x_cuda, y_cpu, z_cuda, weight_cuda, weight_cpu
                )
                self.assertEqual(eager_out, compiled_out)

            # reorder merges ops on cuda into 1 graph partition
            self.assertEqual(self.get_manager().new_graph_id().id, 1)

        @torch._inductor.config.patch("graph_partition", True)
        def test_graph_partition_reorder_cpu_and_gpu_interleave(self):
            def f(x_cuda, y_cpu, z_cuda, weight_cuda, weight_cpu):
                # partition 1 on cuda, no dependency
                x_cuda0 = x_cuda + 1
                x_cuda1 = x_cuda0 @ weight_cuda
                x_cuda2 = 2 * (x_cuda1 + x_cuda)

                # partition 2 on cpu w/ dependency on partition 1
                y_cpu0 = y_cpu + 1
                x_cuda2_cpu = x_cuda2.cpu()  # adds dependency on gpu computations
                y_cpu1 = y_cpu0 @ weight_cpu + x_cuda2_cpu

                # partition 3 on cuda w/o dependency
                z_cuda0 = z_cuda + 1
                z_cuda1 = z_cuda0 @ weight_cuda
                z_cuda2 = 2 * (z_cuda1 + z_cuda)

                # partition 4 on cpu w/o dependency
                y_cpu2 = y_cpu + 5
                y_cpu3 = y_cpu2 @ weight_cpu

                # partition 5 on cuda w/o dependency
                u_cuda0 = z_cuda + 3
                u_cuda1 = u_cuda0 @ weight_cuda
                u_cuda2 = 2 * (u_cuda0 + u_cuda1)

                return x_cuda2, y_cpu1, z_cuda2, y_cpu3, u_cuda2

            x_cuda = torch.randn(3, 3, device="cuda")
            y_cpu = torch.randn(3, 3, device="cpu")
            z_cuda = torch.randn(3, 3, device="cuda")
            weight_cuda = torch.randn(3, 3, device="cuda")
            weight_cpu = torch.randn(3, 3, device="cpu")

            eager_out = f(x_cuda, y_cpu, z_cuda, weight_cuda, weight_cpu)
            compiled_f = torch.compile(f, mode="reduce-overhead")
            for _ in range(3):
                compiled_out = compiled_f(
                    x_cuda, y_cpu, z_cuda, weight_cuda, weight_cpu
                )
                self.assertEqual(eager_out, compiled_out)

            # the optimal order is
            # [[partition 4 on cpu], [partition 1,3,5 on cuda], [partition 2 on cpu]]
            # since partition2 depends on partition1. So we have 1 cudagraph in total.
            self.assertEqual(self.get_manager().new_graph_id().id, 1)

        @config.patch(implicit_fallbacks=True)
        @torch._inductor.config.patch("graph_partition", True)
        def test_graph_partition_reorder_custom_op_with_no_dependency(self):


@@ -22,6 +22,13 @@ if TYPE_CHECKING:
torch_log = logging.getLogger(__name__)


@dataclasses.dataclass
class PeakMemoryResult:
    order: list[BaseSchedulerNode]
    peak_memory: int
    method: str


@dataclasses.dataclass
class MemoryPlanningInfoForBuffer:
    size_alloc: int = 0
@@ -578,6 +585,35 @@ def topological_sort_dfs(nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNo
    return result


def prepare_planning_info(
    nodes: list[BaseSchedulerNode],
    name_to_buf: dict[str, SchedulerBuffer],
    name_to_fused_node: dict[str, BaseSchedulerNode],
    graph_inputs: OrderedSet[str],
    graph_outputs: OrderedSet[str],
) -> tuple[int, dict[str, FreeableInputBuffer]]:
    """
    Prepare planning info. As nodes are scheduled one at a time, these help
    keep track of when a buffer can be freed, and when a node can be scheduled

    Returns:
        int: peak memory estimation
        dict[str, FreeableInputBuffer]: name to freeable input buffer
    """
    name_to_freeable_input_buf = get_freeable_input_buf(nodes, graph_inputs)
    assign_memory_planning_info_for_scheduler_buffers(nodes, name_to_buf)
    assign_memory_planning_info_for_scheduler_nodes(
        nodes, name_to_fused_node, name_to_buf, name_to_freeable_input_buf
    )

    # the default
    estimated_peak_memory, _ = estimate_peak_memory(
        nodes, name_to_freeable_input_buf, graph_outputs
    )

    return estimated_peak_memory, name_to_freeable_input_buf


def reorder_for_peak_memory(
    nodes: list[BaseSchedulerNode],
    name_to_buf: dict[str, SchedulerBuffer],
@@ -597,29 +633,16 @@ def reorder_for_peak_memory(
    torch_log.info("Reordering for peak memory -- %d nodes", len(nodes))

    @dataclasses.dataclass
    class PeakMemoryResult:
        order: list[BaseSchedulerNode]
        peak_memory: int
        method: str

    # preparation -- as nodes are scheduled one at a time, these help
    # keep track of when a buffer can be freed, and when a node can be scheduled
    name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
        nodes, graph_inputs
    )
    assign_memory_planning_info_for_scheduler_buffers(nodes, name_to_buf)
    assign_memory_planning_info_for_scheduler_nodes(
        nodes, name_to_fused_node, name_to_buf, name_to_freeable_input_buf
    estimated_peak_memory, name_to_freeable_input_buf = prepare_planning_info(
        nodes,
        name_to_buf,
        name_to_fused_node,
        graph_inputs,
        graph_outputs,
    )

    # keep track of the peak memory estimates of different methods
    peak_memory_diff_methods: list[PeakMemoryResult] = []

    # the default
    estimated_peak_memory, _ = estimate_peak_memory(
        nodes, name_to_freeable_input_buf, graph_outputs
    )
    peak_memory_diff_methods.append(
        PeakMemoryResult(nodes, estimated_peak_memory, "baseline")
    )


@@ -2104,6 +2104,7 @@ class Scheduler:
        self.process_grouped_nodes()

        if torch._inductor.config.graph_partition:
            self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes)
            self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes)

        self.compute_last_usage()

@@ -4282,6 +4283,100 @@ class Scheduler:
        return signatures[::-1]
    def reorder_for_minimizing_partition(
        self,
        nodes: list[BaseSchedulerNode],
    ) -> list[BaseSchedulerNode]:
        """
        Reorder nodes to minimize the number of partitions via a bfs
        topological sort. This is the optimal reordering such that the
        number of partitions cannot be reduced further. This may be
        sub-optimal for other metrics such as peak memory. This does not
        change the relative order of two cudagraphable nodes, nor the
        relative order of two non_cudagraphable nodes.
        """
        node_to_indegree: dict[BaseSchedulerNode, int] = dict()
        cudagraphable_nodes: collections.deque[BaseSchedulerNode] = collections.deque()
        non_cudagraphable_nodes: collections.deque[BaseSchedulerNode] = (
            collections.deque()
        )

        def insert_pending_nodes(node: BaseSchedulerNode) -> None:
            if self.should_partition(node):
                non_cudagraphable_nodes.append(node)
            else:
                cudagraphable_nodes.append(node)

        def update_indegree(node: BaseSchedulerNode) -> None:
            for succ_node in node.mpi_node.succ_nodes:
                assert node_to_indegree[succ_node] > 0
                node_to_indegree[succ_node] -= 1
                if node_to_indegree[succ_node] == 0:
                    insert_pending_nodes(succ_node)

        for node in nodes:
            node_to_indegree[node] = len(node.mpi_node.pred_nodes)
            if node_to_indegree[node] == 0:
                insert_pending_nodes(node)

        schedule: list[BaseSchedulerNode] = []
        num_iters: int = 0
        while num_iters < len(nodes) and (
            non_cudagraphable_nodes or cudagraphable_nodes
        ):
            while non_cudagraphable_nodes:
                node = non_cudagraphable_nodes.popleft()
                schedule.append(node)
                update_indegree(node)

            while cudagraphable_nodes:
                node = cudagraphable_nodes.popleft()
                schedule.append(node)
                update_indegree(node)

            num_iters += 1

        if num_iters > len(nodes):
            raise RuntimeError(
                """
                Failed to schedule, while loop ran too long when
                reordering for minimizing the num of partitions
                """
            )

        return schedule

    def maybe_reorder_for_minimizing_partition(
        self,
        nodes: list[BaseSchedulerNode],
    ) -> list[BaseSchedulerNode]:
        """
        Reorder nodes to minimize the number of partitions if this only slightly
        increases peak memory.
        """
        from .memory import estimate_peak_memory, prepare_planning_info

        graph_outputs = OrderedSet(V.graph.get_output_names())
        default_peak_memory, name_to_freeable_input_buf = prepare_planning_info(
            nodes,
            self.name_to_buf,
            self.name_to_fused_node,
            OrderedSet(V.graph.graph_inputs.keys()),
            graph_outputs,
        )

        reordered_nodes = self.reorder_for_minimizing_partition(nodes)
        reorder_peak_memory, _ = estimate_peak_memory(
            reordered_nodes, name_to_freeable_input_buf, graph_outputs
        )

        # 1.1 here means 10% extra peak memory budget which is quite arbitrary
        if reorder_peak_memory < default_peak_memory * 1.1:
            return reordered_nodes

        return nodes

    def reorder_for_partition_with_simple_dependency(
        self, nodes: list[BaseSchedulerNode]
    ) -> list[BaseSchedulerNode]: