mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[graph partition] Add way to register custom rule (#163310)
This PR adds an experimental way to register a custom rule that decides whether Inductor should partition the graph around a given operator.

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163310
Approved by: https://github.com/ProExpertProg, https://github.com/BoyuanFeng, https://github.com/eellison
ghstack dependencies: #162117, #162307, #162651
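As a rough usage sketch (not part of the commit): this assumes the helper is importable from torch._inductor.scheduler, where the diff below defines it, and uses a hypothetical mylib::noisy_add custom op, since torch.library custom ops lower to the FallbackKernels the rule is consulted for:

import torch
from torch._inductor.scheduler import register_should_partition_rule

# Hypothetical custom op: custom ops compile to FallbackKernels in Inductor,
# the only IR nodes the custom partition rules apply to.
@torch.library.custom_op("mylib::noisy_add", mutates_args=())
def noisy_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y

@noisy_add.register_fake
def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)

# The rule mirrors the op's signature; Inductor calls it with FakeTensors,
# so only metadata (shape, dtype, device) is available to the decision.
def always_partition(x: torch.Tensor, y: torch.Tensor) -> bool:
    return True

register_should_partition_rule(torch.ops.mylib.noisy_add.default, always_partition)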
@@ -23,6 +23,8 @@ if TYPE_CHECKING:
     from collections.abc import Iterator, Sequence
     from types import ModuleType
 
+import weakref
+
 import sympy
 
 import torch
@@ -92,6 +94,28 @@ _T = TypeVar("_T")
 _P = ParamSpec("_P")
 
 
+_custom_should_partition_fns: weakref.WeakKeyDictionary[
+    torch._ops.OpOverload, Callable[..., bool]
+] = weakref.WeakKeyDictionary()
+
+
+def register_should_partition_rule(
+    op: torch._ops.OpOverload,
+    func: Callable[..., bool],
+) -> None:
+    """Register a function that says if Inductor should partition the graph on this op.
+
+    The function should have the same signature as the operator.
+    Inductor will invoke the function with FakeTensors when it needs to decide
+    if the graph should be partitioned.
+
+    `register_should_partition_rule` is currently private and experimental.
+    Use at your own risk.
+    """
+    assert isinstance(op, torch._ops.OpOverload)
+    _custom_should_partition_fns[op] = func
+
+
 @dataclasses.dataclass
 class SchedulerBuffer:
     scheduler: Scheduler
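Design note on the registry above: because it is a weakref.WeakKeyDictionary, registering a rule does not keep the OpOverload key alive; when the last strong reference to the op goes away, its entry disappears with it. A minimal standalone sketch of that behavior, using a stand-in class since constructing a real OpOverload requires a registered op:

import weakref

class FakeOpOverload:
    """Stand-in for torch._ops.OpOverload."""

registry: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()

op = FakeOpOverload()
registry[op] = lambda *args, **kwargs: True
print(len(registry))  # 1 -- the rule is registered
del op                # drop the last strong reference to the key
print(len(registry))  # 0 on CPython -- the entry was collected with the key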
@@ -4632,6 +4656,25 @@ class Scheduler:
     ) -> bool:
         """Return True if we should partition the inductor graph on this node"""
 
+        # Allow users to manually specify if a node should be partitioned.
+        # This is only possible for FallbackKernels.
+        ir_node = node.node
+        if isinstance(ir_node, torch._inductor.ir.FallbackKernel):
+            operator = ir_node.op_overload
+            if operator is not None and operator in _custom_should_partition_fns:
+                assert isinstance(operator, torch._ops.OpOverload)
+                should_partition_fn = _custom_should_partition_fns[operator]
+                fx_node = ir_node.get_origin_node()
+                assert fx_node is not None
+                success, fake_args, fake_kwargs = (
+                    torch._inductor.fx_utils.get_fake_args_kwargs(fx_node)
+                )
+                assert success, (
+                    "If this op came from a custom inductor pass, make sure to run FakeTensorUpdater"
+                )
+                should_partition = should_partition_fn(*fake_args, **fake_kwargs)
+                return should_partition
+
         # When not using cudagraphs, keep all kernels in the `call` function
         # instead of graph partition functions, since graph partition only brings
         # benefit to cudagraph
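The hunk above fetches fake args/kwargs via get_fake_args_kwargs and calls the rule with them, so a rule only ever sees FakeTensors: shapes, dtypes, and devices are real, element data is not. A small sketch of exercising a hypothetical rule the same way:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

def my_rule(x: torch.Tensor) -> bool:
    # Decisions must depend on metadata only; element values don't exist.
    return x.is_cuda or x.numel() > 1024

# Mimic the scheduler: call the rule with FakeTensors, not real tensors.
with FakeTensorMode():
    fake = torch.empty(64, 64)  # a FakeTensor inside this mode
    print(my_rule(fake))        # True: 4096 elements > 1024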