[Graph Partition] add log for graph partition reasons and #partitions (#159425)

Previously, we log `skipping cudagraphs due to [xxx reasons]` when there are cudagraph-unsafe ops. With graph partition, we will split off these ops and cudagraph remaining parts. But the log message is also skipped.

In this PR, we add logs for graph partition reasons and the number of partitions to better understand the workload.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159425
Approved by: https://github.com/eellison
This commit is contained in:
Boyuan Feng
2025-07-31 04:21:03 +00:00
committed by PyTorch MergeBot
parent 7a4167a164
commit 6b9473469f
3 changed files with 68 additions and 5 deletions

View File

@ -68,6 +68,7 @@ from .utils import (
is_multi_outputs_template,
is_output_of_multi_outputs_template,
is_wait,
maybe_log_cudagraph_partition,
sympy_product,
)
from .virtualized import V
@ -4225,27 +4226,42 @@ class Scheduler:
and name not in self.mutation_real_name
)
def should_partition(self, node: BaseSchedulerNode) -> bool:
def should_partition(
self, node: BaseSchedulerNode, should_log: bool = False
) -> bool:
"""Return True if we should partition the inductor graph on this node"""
# avoid duplicating logs when should_partition is called multiple times
# on the same node
def noop_log(msg: str, node: Optional[BaseSchedulerNode]) -> None:
return
log_partition_reason = maybe_log_cudagraph_partition if should_log else noop_log
if isinstance(node, FusedSchedulerNode):
return any(self.should_partition(snode) for snode in node.snodes)
if not node.is_gpu():
return True
assert node.node is not None
if not node.is_gpu():
log_partition_reason("non gpu ops", node=node)
if node.node is None:
return True
if isinstance(node.node, ir.DeviceCopy):
log_partition_reason("DeviceCopy ops", node=node)
return True
if isinstance(node.node, ir.Conditional):
log_partition_reason("Conditional ops", node=node)
return True
if getattr(node.node, "unbacked_bindings", None):
log_partition_reason("unbacked binding ops", node=node)
return True
if is_cudagraph_unsafe_op(node.node):
log_partition_reason("CUDAGraph-unsafe custom ops", node=node)
return True
return False
@ -4715,7 +4731,7 @@ class Scheduler:
cur_partition: PartitionType = []
skip_cudagraphs = []
for node in self.nodes:
should_partition = self.should_partition(node)
should_partition = self.should_partition(node, should_log=True)
if cur_partition and skip_cudagraph != should_partition:
partitions.append(cur_partition)
skip_cudagraphs.append(skip_cudagraph)
@ -4793,6 +4809,10 @@ class Scheduler:
"""
partitions, signatures = self.graph_partition()
if len(partitions) > 1:
msg = f"cudagraph partition into {len(partitions)} partitions"
maybe_log_cudagraph_partition(msg=msg, prefix="")
for partition, signature in zip(partitions, signatures):
assert len(partition) >= 1, (
f"Each partition must have at least one node but found {len(partition)}"