Compare commits

...

4 Commits

3 changed files with 43 additions and 51 deletions


@@ -945,35 +945,46 @@ if HAS_CUDA_AND_TRITON:
             self.assertEqual(num_partitions, 1)

             @torch.library.custom_op("mylib::baz", mutates_args=())
-            def baz(x: torch.Tensor, flag: int) -> torch.Tensor:
+            def baz(x: torch.Tensor) -> torch.Tensor:
                 return x.clone()

             @baz.register_fake
-            def _(x, flag):
+            def _(x):
                 return x.clone()

-            def should_partition(x, flag):
-                return flag
-
-            torch._inductor.scheduler.register_should_partition_rule(
-                torch.ops.mylib.baz.default, should_partition
-            )
+            # custom_should_partition_ops takes effect, which leads to 2 partitions
+            torch._inductor.config.custom_should_partition_ops = ["mylib::baz"]

-            def f(x, flag):
+            def f(x):
                 x = x + 1
-                x = baz(x, flag)
+                x = baz(x)
                 x = x + 1
                 return x

             f_compiled = torch.compile(f, mode="reduce-overhead", fullgraph=True)
-            _, code = run_and_get_code(f_compiled, x, True)
+            _, code = run_and_get_code(f_compiled, x)
             num_partitions = get_num_partitions(code)
             self.assertEqual(num_partitions, 2)
-            _, code = run_and_get_code(f_compiled, x, False)
+
+            # updating the config should NOT force a recompile
+            torch._inductor.config.custom_should_partition_ops = []
+            with torch.compiler.set_stance("fail_on_recompile"):
+                f_compiled(x)
+
+            # run_and_get_code forces a recompile. Now we should cache miss, recompile, and
+            # only have 1 partition.
+            _, code = run_and_get_code(f_compiled, x)
             num_partitions = get_num_partitions(code)
             self.assertEqual(num_partitions, 1)
+
+            # test that the op_overload name takes effect, which leads to 2 partitions
+            torch._inductor.config.custom_should_partition_ops = ["mylib::baz.default"]
+            f_compiled = torch.compile(f, mode="reduce-overhead", fullgraph=True)
+            _, code = run_and_get_code(f_compiled, x)
+            num_partitions = get_num_partitions(code)
+            self.assertEqual(num_partitions, 2)

         @torch._inductor.config.patch("graph_partition", True)
         @torch._inductor.config.patch("implicit_fallbacks", True)
         def test_graph_partition_with_memory_plan_reuse(self):


@@ -483,6 +483,11 @@ graph_partition: bool = (
     == "1"
 )

+# Register ops on which Inductor should partition the graph. The name format should be
+# "namespace::kernel_name" (e.g., aten::mm) for an op overload packet, or
+# "namespace::kernel_name.overload" (e.g., aten::mm.default) for a specific overload.
+custom_should_partition_ops: list[str] = []
+
 # whether template autotuning should allow flexible layouts if possible (e.g. only extern choices)
 max_autotune_allow_flexible_layouts: bool = False
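To show how the new config is meant to be used, here is a minimal usage sketch built around the same mylib::baz custom op as the test above. The explicit graph_partition/implicit_fallbacks settings and the CUDA tensor are assumptions about the surrounding setup (the test enables them via config.patch), not part of the diff.

import torch

# Graph partitioning must be enabled; the test above enables these via config.patch.
torch._inductor.config.graph_partition = True
torch._inductor.config.implicit_fallbacks = True

@torch.library.custom_op("mylib::baz", mutates_args=())
def baz(x: torch.Tensor) -> torch.Tensor:
    return x.clone()

@baz.register_fake
def _(x):
    return x.clone()

# Either name format is accepted:
torch._inductor.config.custom_should_partition_ops = ["mylib::baz"]            # op overload packet
# torch._inductor.config.custom_should_partition_ops = ["mylib::baz.default"]  # specific overload

@torch.compile(mode="reduce-overhead", fullgraph=True)
def f(x):
    return baz(x + 1) + 1

f(torch.randn(8, device="cuda"))  # Inductor partitions the graph around mylib::baz

Since the list is plain config state, changing it does not invalidate already-compiled functions; the test above checks that a later recompile picks up the new value.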


@@ -25,8 +25,6 @@ if TYPE_CHECKING:
     from collections.abc import Iterator, Sequence
     from types import ModuleType

-import weakref
-
 import sympy

 import torch
@@ -97,28 +95,6 @@ _T = TypeVar("_T")
 _P = ParamSpec("_P")

-_custom_should_partition_fns: weakref.WeakKeyDictionary[
-    torch._ops.OpOverload, Callable[..., bool]
-] = weakref.WeakKeyDictionary()
-
-
-def register_should_partition_rule(
-    op: torch._ops.OpOverload,
-    func: Callable[..., bool],
-) -> None:
-    """Register a function that says if Inductor should partition the graph on this op.
-
-    The function should be have the same signature as the operator.
-    Inductor will invoke the function with FakeTensors when it needs to decide
-    if the graph should be partitioned.
-
-    `register_should_partition_rule` is currently private and experimental.
-    Use at your own risk.
-    """
-    assert isinstance(op, torch._ops.OpOverload)
-    _custom_should_partition_fns[op] = func
-
-
 class MixOrderReduction:
     """
     This class contains utility functions to decide if we should fuse reductions
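For callers of the removed private hook whose rule unconditionally partitioned on a given op, a hedged migration sketch follows; note that the string-based config matches by name only and cannot inspect FakeTensor arguments the way the old callback could.

import torch

# before (removed private API):
#   torch._inductor.scheduler.register_should_partition_rule(
#       torch.ops.mylib.baz.default, lambda *args, **kwargs: True
#   )

# after: list the op by overload-packet name or fully qualified overload name
torch._inductor.config.custom_should_partition_ops = ["mylib::baz.default"]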
@@ -4946,21 +4922,21 @@ class Scheduler:
         # Allow users to manually specify if a node should be partitioned
         # Can only do this for FallbackKernels
         ir_node = node.node
-        if isinstance(ir_node, torch._inductor.ir.FallbackKernel):
-            operator = ir_node.op_overload
-            if operator is not None and operator in _custom_should_partition_fns:
-                assert isinstance(operator, torch._ops.OpOverload)
-                should_partition_fn = _custom_should_partition_fns[operator]
-                fx_node = ir_node.get_origin_node()
-                assert fx_node is not None
-                success, fake_args, fake_kwargs = (
-                    torch._inductor.fx_utils.get_fake_args_kwargs(fx_node)
-                )
-                assert success, (
-                    "If this op came from a custom inductor pass, make sure to run FakeTensorUpdator"
-                )
-                should_partition = should_partition_fn(*fake_args, **fake_kwargs)
-                return should_partition
+        if isinstance(ir_node, torch._inductor.ir.FallbackKernel) and (
+            op := ir_node.op_overload
+        ):
+            op_overload_packet_name = op.name()
+            op_overload_name = (
+                f"{op_overload_packet_name}.{op._overloadname}"
+                if isinstance(op, torch._ops.OpOverload)
+                else op_overload_packet_name
+            )
+            if (
+                op_overload_packet_name in config.custom_should_partition_ops
+                or op_overload_name in config.custom_should_partition_ops
+            ):
+                assert isinstance(op, torch._ops.OpOverload)
+                return True

         # When not using cudagraphs, keep all kernels in the `call` function
         # instead of graph partition functions, since graph partition only brings