From 4d4abec80f03cd8fdefe1d9cb3a60d3690cd777e Mon Sep 17 00:00:00 2001
From: Xuan Zhang
Date: Fri, 5 Sep 2025 12:06:16 -0700
Subject: [PATCH] allow user to pass in custom partitioner function (#157580)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157580
Approved by: https://github.com/bdhirsh
---
 test/functorch/test_aotdispatch.py          | 44 +++++++++++++
 test/inductor/test_codecache.py             | 43 ++++++++++++
 test/inductor/test_custom_partitioner_fn.py | 72 +++++++++++++++++++++
 torch/_inductor/codecache.py                | 15 +++++
 torch/_inductor/compile_fx.py               | 35 +++++++---
 torch/_inductor/config.py                   |  3 +
 torch/_inductor/custom_graph_pass.py        | 56 ++++++++++++++++
 torch/_inductor/fuzzer.py                   | 19 +++++-
 8 files changed, 276 insertions(+), 11 deletions(-)
 create mode 100644 test/inductor/test_custom_partitioner_fn.py

diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 5d068310f69d..5e8902b0aa8f 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -58,6 +58,7 @@ from torch._functorch.aot_autograd import (
 )
 from torch._higher_order_ops.out_dtype import out_dtype
 from torch._inductor.codecache import compiled_fx_graph_hash
+from torch._inductor.custom_graph_pass import CustomPartitionerFn
 from torch._inductor.output_code import MockFXGraphCacheOutput
 from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode
 from torch.fx.experimental.proxy_tensor import is_sym_node
@@ -5687,6 +5688,49 @@ def forward(self, primals_1, tangents_1):
     return (cat,)""",
         )
 
+    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
+    def test_custom_partitioner_fn(self):
+        class MyCustomPartitionerFn(CustomPartitionerFn):
+            def __init__(self):
+                super().__init__()
+                self.called = False
+
+            def __call__(self, gm, joint_inputs, **kwargs):
+                self.called = True
+                return min_cut_rematerialization_partition(gm, joint_inputs, **kwargs)
+
+            def uuid(self):
+                return None
+
+        def f(x):
+            return x.cos().cos()
+
+        inp = [torch.randn((4, 4), requires_grad=True)]
+        custom_partitioner_fn = MyCustomPartitionerFn()
+        fw_graph, bw_graph = get_fw_bw_graph(f, inp, partitioner=custom_partitioner_fn)
+        self.assertTrue(custom_partitioner_fn.called)
+        self.assertExpectedInline(
+            fw_graph.code.strip(),
+            """\
+def forward(self, primals_1):
+    cos = torch.ops.aten.cos.default(primals_1)
+    cos_1 = torch.ops.aten.cos.default(cos);  cos = None
+    return (cos_1, primals_1)""",
+        )
+        self.assertExpectedInline(
+            bw_graph.code.strip(),
+            """\
+def forward(self, primals_1, tangents_1):
+    cos = torch.ops.aten.cos.default(primals_1)
+    sin = torch.ops.aten.sin.default(cos);  cos = None
+    neg = torch.ops.aten.neg.default(sin);  sin = None
+    mul = torch.ops.aten.mul.Tensor(tangents_1, neg);  tangents_1 = neg = None
+    sin_1 = torch.ops.aten.sin.default(primals_1);  primals_1 = None
+    neg_1 = torch.ops.aten.neg.default(sin_1);  sin_1 = None
+    mul_1 = torch.ops.aten.mul.Tensor(mul, neg_1);  mul = neg_1 = None
+    return (mul_1,)""",
+        )
+
     @unittest.skipIf(not USE_NETWORKX, "networkx not available")
     def test_min_cut_partitioner_save_shape(self):
         def f(x):
diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py
index 757ea061c26f..6da49ab39229 100644
--- a/test/inductor/test_codecache.py
+++ b/test/inductor/test_codecache.py
@@ -33,6 +33,7 @@ from torch._inductor.cpp_builder import normalize_path_separator
 from torch._inductor.custom_graph_pass import (
     CustomGraphModulePass,
     CustomGraphPass,
+    CustomPartitionerFn,
     get_hash_for_files,
 )
 from torch._inductor.graph import GraphLowering
@@ -2115,6 +2116,19 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
         self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
 
 
+class TestCustomPartitionerFn(CustomPartitionerFn):
+    def __init__(self):
+        self._uuid = None
+
+    def __call__(
+        self, gm, joint_inputs, **kwargs
+    ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
+        return gm, gm  # Dummy implementation
+
+    def uuid(self) -> Optional[Union[bytes, str]]:
+        return self._uuid
+
+
 class TestFxGraphCacheHashing(TestCase):
     def test_parameter_constants(self):
         """
@@ -2520,6 +2534,35 @@ class TestFxGraphCacheHashing(TestCase):
         self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
         self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
 
+    def test_hash_custom_partitioner_fn(self):
+        """
+        Test that the custom partitioner function's UUID is properly used in the FX graph cache hashing.
+        """
+        custom_partitioner_fn = TestCustomPartitionerFn()
+        with config.patch({"custom_partitioner_fn": custom_partitioner_fn}):
+            custom_partitioner_fn._uuid = "1"
+            details1 = FxGraphHashDetails(None, [], {}, [])
+            details2 = FxGraphHashDetails(None, [], {}, [])
+
+            custom_partitioner_fn._uuid = "2"
+            details3 = FxGraphHashDetails(None, [], {}, [])
+
+            self.assertEqual(details1._custom_partitioner_fn, "1")
+            self.assertEqual(details2._custom_partitioner_fn, "1")
+            self.assertEqual(details3._custom_partitioner_fn, "2")
+
+            gm = torch.fx.GraphModule({}, torch.fx.Graph())
+            pickler = FxGraphCachePickler(gm)
+
+            self.assertEqual(
+                pickler.dumps(details1),
+                pickler.dumps(details2),
+            )
+            self.assertNotEqual(
+                pickler.dumps(details1),
+                pickler.dumps(details3),
+            )
+
     def test_bypass_unsupported(self):
         """
         Test _reduce_unsupported
diff --git a/test/inductor/test_custom_partitioner_fn.py b/test/inductor/test_custom_partitioner_fn.py
new file mode 100644
index 000000000000..722a154b27ff
--- /dev/null
+++ b/test/inductor/test_custom_partitioner_fn.py
@@ -0,0 +1,72 @@
+# Owner(s): ["module: pt2-dispatcher"]
+import torch
+from functorch.compile import min_cut_rematerialization_partition
+from torch._C import FileCheck
+from torch._inductor.custom_graph_pass import CustomPartitionerFn, get_hash_for_files
+from torch._inductor.test_case import TestCase
+from torch._inductor.utils import run_fw_bw_and_get_code
+from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
+
+
+class MyCustomPartitionerFn(CustomPartitionerFn):
+    """
+    A custom partitioner function that overrides static_lifetime_input_indices.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.called = False
+
+    def __call__(self, gm, joint_inputs, **kwargs):
+        self.called = True
+        kwargs["static_lifetime_input_indices"] = [0, 1]
+        return min_cut_rematerialization_partition(gm, joint_inputs, **kwargs)
+
+    def uuid(self):
+        return get_hash_for_files((__file__,))
+
+
+class TestCustomPartitionerFn(TestCase):
+    def test_custom_partitioner_fn(self):
+        """
+        For function f(a, b), with the partitioner in the compile_fx stack,
+        the addition `a+b` (equivalently `buf0`) is saved for backward.
+        With the custom partitioner function, we indicate that
+        `a` and `b` (equivalently `primals_1` and `primals_2`) do not take
+        additional memory, and thus they are saved for backward.
+        """
+
+        # initialization
+        @torch.compile
+        def f(a, b):
+            return (a + b).cos().cos()
+
+        a = torch.randn((2, 2), requires_grad=True, device=GPU_TYPE)
+        b = torch.randn((2, 2), requires_grad=True, device=GPU_TYPE)
+
+        # CASE 1 -- default
+        # addition `a + b` (i.e., `buf0`) is saved for backward.
+        code_og = run_fw_bw_and_get_code(lambda: f(a, b))
+        fwd_code_og = code_og[1][0]
+        FileCheck().check("return (buf1, buf0, )").run(fwd_code_og)
+
+        # CASE 2 -- custom partitioner function
+        # `a` and `b` (i.e., `primals_1` and `primals_2`) are saved for backward.
+        custom_partitioner_fn = MyCustomPartitionerFn()
+        self.assertFalse(custom_partitioner_fn.called)
+        self.assertIsNotNone(custom_partitioner_fn.uuid())
+
+        with torch._inductor.config.patch(custom_partitioner_fn=custom_partitioner_fn):
+            code_cp = run_fw_bw_and_get_code(lambda: f(a, b))
+            fwd_code_cp = code_cp[1][0]
+            FileCheck().check("return (buf0, primals_1, primals_2, )").run(fwd_code_cp)
+
+        # make sure the custom partitioner function is indeed invoked
+        self.assertTrue(custom_partitioner_fn.called)
+
+
+if __name__ == "__main__":
+    from torch._inductor.test_case import run_tests
+
+    if HAS_GPU:
+        run_tests()
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 041abc9a473e..7b24208a2c51 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -83,6 +83,8 @@ from torch._inductor.custom_graph_pass import (
     CustomGraphModulePass,
     CustomGraphPass,
     CustomGraphPassType,
+    CustomPartitionerFn,
+    CustomPartitionerFnType,
 )
 from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param
 from torch._inductor.runtime.compile_tasks import _reload_python_module
@@ -895,6 +897,11 @@ class FxGraphHashDetails:
             if custom_config is not None
         }
 
+        # Include the custom partitioner function's uuid in the cache key details
+        self._custom_partitioner_fn = self._get_custom_partitioner_fn_detail(
+            config.custom_partitioner_fn
+        )
+
         # This is mainly added to handle these two inductor configs, which are (unfortunately)
         # sometimes cache safe:
         # - _pre_fusion_custom_pass
@@ -927,6 +934,14 @@
         assert isinstance(custom_pass, (CustomGraphPass, CustomGraphModulePass))
         return custom_pass.uuid()
 
+    def _get_custom_partitioner_fn_detail(
+        self, custom_partitioner_fn: CustomPartitionerFnType
+    ) -> Optional[Any]:
+        if not custom_partitioner_fn:
+            return None
+        assert isinstance(custom_partitioner_fn, CustomPartitionerFn)
+        return custom_partitioner_fn.uuid()
+
 
 def compiled_fx_graph_hash(
     gm: torch.fx.GraphModule,
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 0489bc1ba866..9e4661330045 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -65,6 +65,7 @@ from torch._inductor.cudagraph_utils import (
     log_cudagraph_skip_and_bump_counter,
     PlaceholderInfo,
 )
+from torch._inductor.custom_graph_pass import CustomPartitionerFn
 from torch._inductor.debug import (
     create_mapping_pre_post_grad_nodes,
     save_args_for_compile_fx_inner,
@@ -2110,16 +2111,30 @@ def partition_fn(
         "static_lifetime_input_indices", None
     )
 
-    with dynamo_utils.dynamo_timed(
-        "min_cut_rematerialization_partition", log_pt2_compile_event=True
-    ):
-        return min_cut_rematerialization_partition(
-            gm,
-            joint_inputs,
-            compiler="inductor",
-            static_lifetime_input_indices=static_lifetime_input_indices,
-            **kwargs,
-        )
+    if config.custom_partitioner_fn is None:
+        with dynamo_utils.dynamo_timed(
+            "min_cut_rematerialization_partition", log_pt2_compile_event=True
+        ):
+            return min_cut_rematerialization_partition(
+                gm,
+                joint_inputs,
+                compiler="inductor",
+                static_lifetime_input_indices=static_lifetime_input_indices,
+                **kwargs,
+            )
+    else:
+        assert isinstance(config.custom_partitioner_fn, CustomPartitionerFn)
+        with dynamo_utils.dynamo_timed(
+            config.custom_partitioner_fn.__class__.__name__,
+            log_pt2_compile_event=True,
+        ):
+            return config.custom_partitioner_fn(
+                gm,
+                joint_inputs,
+                compiler="inductor",
+                static_lifetime_input_indices=static_lifetime_input_indices,
+                **kwargs,
+            )
 
 
 def get_num_model_outputs(model: GraphModule) -> int:
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index 44cda0ad3c62..beb1641785de 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -266,6 +266,9 @@ b2b_gemm_pass = False
 post_grad_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
 post_grad_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
 
+# Allows users to pass in a custom partitioner function
+custom_partitioner_fn: torch._inductor.custom_graph_pass.CustomPartitionerFnType = None
+
 # Registers a custom joint graph pass.
 joint_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
 joint_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
diff --git a/torch/_inductor/custom_graph_pass.py b/torch/_inductor/custom_graph_pass.py
index c9a8e33a1145..413a224724fd 100644
--- a/torch/_inductor/custom_graph_pass.py
+++ b/torch/_inductor/custom_graph_pass.py
@@ -1,5 +1,6 @@
 import hashlib
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from functools import lru_cache
 from typing import Any, Callable, Optional, Union
 from typing_extensions import TypeAlias
@@ -102,3 +103,58 @@ def get_hash_for_files(paths: tuple[str], extra: str = "") -> bytes:
             hasher.update(path.encode("utf-8"))
             hasher.update(f.read())
     return hasher.digest()
+
+
+class CustomPartitionerFn(ABC):
+    """
+    Implement this interface for a custom partitioner:
+
+    1) The __call__() method contains the implementation of the custom partitioner.
+
+    2) The uuid() method enables inductor to cache compiled graphs when your custom
+       partitioner is applied. This method can return any identifier as long as it
+       uniquely identifies your implementation (and can be pickled). The caching logic
+       includes this identifier in its key calculation, i.e., any new value will
+       effectively invalidate existing entries. We expect a custom partitioner to
+       typically depend purely on the textual representation of its implementation.
+       In that case, we recommend using the 'get_hash_for_files' helper in this module
+       to compute a unique hash from the contents of a static list of source files,
+       i.e., the source(s) containing the custom partitioner implementation. That
+       approach ensures that any change to the implementation will mean a new uuid.
+
+    EXAMPLE:
+
+        from torch._inductor.custom_graph_pass import get_hash_for_files
+
+        class MyCustomPartitionerFn(CustomPartitionerFn):
+            def __call__(
+                self,
+                gm: torch.fx.GraphModule,
+                joint_inputs: Sequence[object],
+                **kwargs: Any
+            ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
+                # my custom partitioner implementation
+                # ...
+
+            def uuid(self) -> Optional[Any]:
+                return get_hash_for_files((__file__,))
+
+    """
+
+    @abstractmethod
+    def __call__(
+        self, gm: torch.fx.GraphModule, joint_inputs: Sequence[object], **kwargs: Any
+    ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
+        """
+        Implementation of the custom partitioner.
+        """
+
+    @abstractmethod
+    def uuid(self) -> Optional[Any]:
+        """
+        Return an ID to uniquely identify your custom partitioner implementation.
+        Return None to skip inductor code caching entirely.
+        """
+
+
+CustomPartitionerFnType: TypeAlias = Optional[CustomPartitionerFn]
diff --git a/torch/_inductor/fuzzer.py b/torch/_inductor/fuzzer.py
index 82edd5d4d5b6..8149bc7e98e7 100644
--- a/torch/_inductor/fuzzer.py
+++ b/torch/_inductor/fuzzer.py
@@ -23,7 +23,8 @@ from typing import (
 )
 
 import torch
-from torch._inductor.custom_graph_pass import CustomGraphPass
+from functorch.compile import min_cut_rematerialization_partition
+from torch._inductor.custom_graph_pass import CustomGraphPass, CustomPartitionerFn
 from torch._inductor.scheduler import BaseSchedulerNode
 from torch.utils._config_module import _ConfigEntry, ConfigModule
 from torch.utils._ordered_set import OrderedSet
@@ -74,6 +75,20 @@ class DummyPass(CustomGraphPass):
         return None
 
 
+class DummyPartitionerFn(CustomPartitionerFn):
+    """
+    A dummy partitioner function to be used by ConfigFuzzer.
+    """
+
+    def __call__(
+        self, gm: torch.fx.GraphModule, joint_inputs: Sequence[object], **kwargs: Any
+    ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
+        return min_cut_rematerialization_partition(gm, joint_inputs, **kwargs)
+
+    def uuid(self) -> Optional[Any]:
+        return None
+
+
 T = TypeVar("T")
 
 
@@ -84,6 +99,7 @@ class TypeExemplars:
 
     TYPE_EXEMPLARS: dict[str, Any] = {
         CustomGraphPass.__name__: DummyPass(),
+        CustomPartitionerFn.__name__: DummyPartitionerFn(),
         torch.fx.graph.Graph.__name__: torch.fx.graph.Graph(),
         BaseSchedulerNode.__name__: BaseSchedulerNode(None),  # type: ignore[arg-type]
     }
@@ -499,6 +515,7 @@ MODULE_DEFAULTS: dict[str, ConfigType] = {
         "joint_custom_post_pass": DEFAULT,  # Typing
         "joint_custom_pre_pass": DEFAULT,  # Typing
         "pre_grad_custom_pass": DEFAULT,  # Typing
+        "custom_partitioner_fn": DEFAULT,  # Typing
     },
     "torch._dynamo.config": {
         "traceable_tensor_subclasses": DEFAULT,  # Typing