mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Graph Partition] add log for graph partition reasons and #partitions (#159425)
Previously, we logged `skipping cudagraphs due to [xxx reasons]` when there were cudagraph-unsafe ops. With graph partition, we split off these ops and cudagraph the remaining parts — but the log message was then also skipped. In this PR, we add logs for the graph partition reasons and the number of partitions, to better understand the workload. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159425 Approved by: https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
7a4167a164
commit
6b9473469f
@ -68,6 +68,7 @@ from .utils import (
|
||||
is_multi_outputs_template,
|
||||
is_output_of_multi_outputs_template,
|
||||
is_wait,
|
||||
maybe_log_cudagraph_partition,
|
||||
sympy_product,
|
||||
)
|
||||
from .virtualized import V
|
||||
@ -4225,27 +4226,42 @@ class Scheduler:
|
||||
and name not in self.mutation_real_name
|
||||
)
|
||||
|
||||
def should_partition(self, node: BaseSchedulerNode) -> bool:
|
||||
def should_partition(
|
||||
self, node: BaseSchedulerNode, should_log: bool = False
|
||||
) -> bool:
|
||||
"""Return True if we should partition the inductor graph on this node"""
|
||||
|
||||
# avoid duplicating logs when should_partition is called multiple times
|
||||
# on the same node
|
||||
def noop_log(msg: str, node: Optional[BaseSchedulerNode]) -> None:
|
||||
return
|
||||
|
||||
log_partition_reason = maybe_log_cudagraph_partition if should_log else noop_log
|
||||
|
||||
if isinstance(node, FusedSchedulerNode):
|
||||
return any(self.should_partition(snode) for snode in node.snodes)
|
||||
|
||||
if not node.is_gpu():
|
||||
return True
|
||||
assert node.node is not None
|
||||
|
||||
if not node.is_gpu():
|
||||
log_partition_reason("non gpu ops", node=node)
|
||||
|
||||
if node.node is None:
|
||||
return True
|
||||
|
||||
if isinstance(node.node, ir.DeviceCopy):
|
||||
log_partition_reason("DeviceCopy ops", node=node)
|
||||
return True
|
||||
|
||||
if isinstance(node.node, ir.Conditional):
|
||||
log_partition_reason("Conditional ops", node=node)
|
||||
return True
|
||||
|
||||
if getattr(node.node, "unbacked_bindings", None):
|
||||
log_partition_reason("unbacked binding ops", node=node)
|
||||
return True
|
||||
|
||||
if is_cudagraph_unsafe_op(node.node):
|
||||
log_partition_reason("CUDAGraph-unsafe custom ops", node=node)
|
||||
return True
|
||||
|
||||
return False
|
||||
@ -4715,7 +4731,7 @@ class Scheduler:
|
||||
cur_partition: PartitionType = []
|
||||
skip_cudagraphs = []
|
||||
for node in self.nodes:
|
||||
should_partition = self.should_partition(node)
|
||||
should_partition = self.should_partition(node, should_log=True)
|
||||
if cur_partition and skip_cudagraph != should_partition:
|
||||
partitions.append(cur_partition)
|
||||
skip_cudagraphs.append(skip_cudagraph)
|
||||
@ -4793,6 +4809,10 @@ class Scheduler:
|
||||
"""
|
||||
partitions, signatures = self.graph_partition()
|
||||
|
||||
if len(partitions) > 1:
|
||||
msg = f"cudagraph partition into {len(partitions)} partitions"
|
||||
maybe_log_cudagraph_partition(msg=msg, prefix="")
|
||||
|
||||
for partition, signature in zip(partitions, signatures):
|
||||
assert len(partition) >= 1, (
|
||||
f"Each partition must have at least one node but found {len(partition)}"
|
||||
|
Reference in New Issue
Block a user