[graph partition] Add way to register custom rule (#163310)

This PR adds an experimental way to register a custom rule for whether
Inductor should partition the graph around an operator.
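
A minimal sketch of how this might be used (the import path is an assumption
based on the Scheduler context in the diff; the custom op and the heuristic
below are hypothetical):

    import torch
    from torch._inductor.scheduler import register_should_partition_rule

    # Hypothetical custom op; the rule only takes effect for ops that
    # Inductor lowers as fallback kernels.
    @torch.library.custom_op("mylib::attention", mutates_args=())
    def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        return torch.softmax(q @ k.transpose(-2, -1), dim=-1) @ v

    @attention.register_fake
    def _(q, k, v):
        return torch.empty_like(q)

    # Same signature as the op; Inductor calls it with FakeTensors.
    # Returning True asks Inductor to partition the graph on this op.
    def should_partition(q, k, v) -> bool:
        return q.shape[-2] >= 1024  # e.g. only partition for long sequences

    register_should_partition_rule(torch.ops.mylib.attention.default, should_partition)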

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163310
Approved by: https://github.com/ProExpertProg, https://github.com/BoyuanFeng, https://github.com/eellison
ghstack dependencies: #162117, #162307, #162651
rzou
2025-09-19 11:55:11 -07:00
committed by PyTorch MergeBot
parent 0098e5636d
commit ee7bdd8f2f
2 changed files with 105 additions and 0 deletions


@@ -23,6 +23,8 @@ if TYPE_CHECKING:
    from collections.abc import Iterator, Sequence
    from types import ModuleType

import weakref

import sympy
import torch
@@ -92,6 +94,28 @@ _T = TypeVar("_T")
_P = ParamSpec("_P")

_custom_should_partition_fns: weakref.WeakKeyDictionary[
    torch._ops.OpOverload, Callable[..., bool]
] = weakref.WeakKeyDictionary()
def register_should_partition_rule(
    op: torch._ops.OpOverload,
    func: Callable[..., bool],
) -> None:
    """Register a function that says if Inductor should partition the graph on this op.

    The function should have the same signature as the operator.
    Inductor will invoke the function with FakeTensors when it needs to decide
    if the graph should be partitioned.

    `register_should_partition_rule` is currently private and experimental.
    Use at your own risk.
    """
    assert isinstance(op, torch._ops.OpOverload)
    _custom_should_partition_fns[op] = func

@dataclasses.dataclass
class SchedulerBuffer:
    scheduler: Scheduler
@@ -4632,6 +4656,25 @@ class Scheduler:
    ) -> bool:
        """Return True if we should partition the inductor graph on this node"""
        # Allow users to manually specify if a node should be partitioned
        # Can only do this for FallbackKernels
        ir_node = node.node
        if isinstance(ir_node, torch._inductor.ir.FallbackKernel):
            operator = ir_node.op_overload
            if operator is not None and operator in _custom_should_partition_fns:
                assert isinstance(operator, torch._ops.OpOverload)
                should_partition_fn = _custom_should_partition_fns[operator]
                fx_node = ir_node.get_origin_node()
                assert fx_node is not None
                success, fake_args, fake_kwargs = (
                    torch._inductor.fx_utils.get_fake_args_kwargs(fx_node)
                )
                assert success, (
                    "If this op came from a custom inductor pass, make sure to run FakeTensorUpdater"
                )
                should_partition = should_partition_fn(*fake_args, **fake_kwargs)
                return should_partition

        # When not using cudagraphs, keep all kernels in the `call` function
        # instead of graph partition functions, since graph partition only brings
        # benefit to cudagraph