allow user to pass in custom partitioner function (#157580)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157580
Approved by: https://github.com/bdhirsh
Author: Xuan Zhang
Date: 2025-09-05 12:06:16 -07:00
Committed by: PyTorch MergeBot
Parent: 9c03d6be87
Commit: 4d4abec80f

8 changed files with 276 additions and 11 deletions
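
The commit message itself carries no usage snippet; the sketch below shows how the new knob is intended to be wired up, assembled from the tests added in this commit. The class name `MyPartitioner` is illustrative; `custom_partitioner_fn`, `CustomPartitionerFn`, and `get_hash_for_files` come from the diff.

# Minimal usage sketch (illustrative class name; API names taken from this diff):
# subclass CustomPartitionerFn and register it via the new inductor config.
import torch
from functorch.compile import min_cut_rematerialization_partition
from torch._inductor import config as inductor_config
from torch._inductor.custom_graph_pass import CustomPartitionerFn, get_hash_for_files


class MyPartitioner(CustomPartitionerFn):
    def __call__(self, gm, joint_inputs, **kwargs):
        # Delegate to the stock min-cut partitioner; tweak kwargs here as needed.
        return min_cut_rematerialization_partition(gm, joint_inputs, **kwargs)

    def uuid(self):
        # Tie the FX-graph cache key to this source file so edits invalidate it.
        return get_hash_for_files((__file__,))


@torch.compile
def f(x):
    return x.cos().cos()


with inductor_config.patch(custom_partitioner_fn=MyPartitioner()):
    f(torch.randn(4, 4, requires_grad=True)).sum().backward()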

View File

@@ -58,6 +58,7 @@ from torch._functorch.aot_autograd import (
)
from torch._higher_order_ops.out_dtype import out_dtype
from torch._inductor.codecache import compiled_fx_graph_hash
from torch._inductor.custom_graph_pass import CustomPartitionerFn
from torch._inductor.output_code import MockFXGraphCacheOutput
from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode
from torch.fx.experimental.proxy_tensor import is_sym_node
@@ -5687,6 +5688,49 @@ def forward(self, primals_1, tangents_1):
    return (cat,)""",
        )

    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
    def test_custom_partitioner_fn(self):
        class MyCustomPartitionerFn(CustomPartitionerFn):
            def __init__(self):
                super().__init__()
                self.called = False

            def __call__(self, gm, joint_inputs, **kwargs):
                self.called = True
                return min_cut_rematerialization_partition(gm, joint_inputs, **kwargs)

            def uuid(self):
                return None

        def f(x):
            return x.cos().cos()

        inp = [torch.randn((4, 4), requires_grad=True)]
        custom_partitioner_fn = MyCustomPartitionerFn()
        fw_graph, bw_graph = get_fw_bw_graph(f, inp, partitioner=custom_partitioner_fn)
        self.assertTrue(custom_partitioner_fn.called)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    cos = torch.ops.aten.cos.default(primals_1)
    cos_1 = torch.ops.aten.cos.default(cos); cos = None
    return (cos_1, primals_1)""",
        )
        self.assertExpectedInline(
            bw_graph.code.strip(),
            """\
def forward(self, primals_1, tangents_1):
    cos = torch.ops.aten.cos.default(primals_1)
    sin = torch.ops.aten.sin.default(cos); cos = None
    neg = torch.ops.aten.neg.default(sin); sin = None
    mul = torch.ops.aten.mul.Tensor(tangents_1, neg); tangents_1 = neg = None
    sin_1 = torch.ops.aten.sin.default(primals_1); primals_1 = None
    neg_1 = torch.ops.aten.neg.default(sin_1); sin_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(mul, neg_1); mul = neg_1 = None
    return (mul_1,)""",
        )

    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
    def test_min_cut_partitioner_save_shape(self):
        def f(x):

View File

@@ -33,6 +33,7 @@ from torch._inductor.cpp_builder import normalize_path_separator
from torch._inductor.custom_graph_pass import (
    CustomGraphModulePass,
    CustomGraphPass,
    CustomPartitionerFn,
    get_hash_for_files,
)
from torch._inductor.graph import GraphLowering
@@ -2115,6 +2116,19 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)


class TestCustomPartitionerFn(CustomPartitionerFn):
    def __init__(self):
        self._uuid = None

    def __call__(
        self, gm, joint_inputs, **kwargs
    ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
        return gm, gm  # Dummy implementation

    def uuid(self) -> Optional[Union[bytes, str]]:
        return self._uuid


class TestFxGraphCacheHashing(TestCase):
    def test_parameter_constants(self):
        """
@@ -2520,6 +2534,35 @@ class TestFxGraphCacheHashing(TestCase):
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

    def test_hash_custom_partitioner_fn(self):
        """
        Test that the custom partitioner function's UUID is properly used in the FX graph cache hashing.
        """
        custom_partitioner_fn = TestCustomPartitionerFn()
        with config.patch({"custom_partitioner_fn": custom_partitioner_fn}):
            custom_partitioner_fn._uuid = "1"
            details1 = FxGraphHashDetails(None, [], {}, [])
            details2 = FxGraphHashDetails(None, [], {}, [])

            custom_partitioner_fn._uuid = "2"
            details3 = FxGraphHashDetails(None, [], {}, [])

            self.assertEqual(details1._custom_partitioner_fn, "1")
            self.assertEqual(details2._custom_partitioner_fn, "1")
            self.assertEqual(details3._custom_partitioner_fn, "2")

            gm = torch.fx.GraphModule({}, torch.fx.Graph())
            pickler = FxGraphCachePickler(gm)
            self.assertEqual(
                pickler.dumps(details1),
                pickler.dumps(details2),
            )
            self.assertNotEqual(
                pickler.dumps(details1),
                pickler.dumps(details3),
            )

    def test_bypass_unsupported(self):
        """
        Test _reduce_unsupported

View File

@@ -0,0 +1,72 @@
# Owner(s): ["module: pt2-dispatcher"]
import torch
from functorch.compile import min_cut_rematerialization_partition
from torch._C import FileCheck
from torch._inductor.custom_graph_pass import CustomPartitionerFn, get_hash_for_files
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_fw_bw_and_get_code
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


class MyCustomPartitionerFn(CustomPartitionerFn):
    """
    A custom partitioner function that overrides static_lifetime_input_indices.
    """

    def __init__(self):
        super().__init__()
        self.called = False

    def __call__(self, gm, joint_inputs, **kwargs):
        self.called = True
        kwargs["static_lifetime_input_indices"] = [0, 1]
        return min_cut_rematerialization_partition(gm, joint_inputs, **kwargs)

    def uuid(self):
        return get_hash_for_files((__file__,))


class TestCustomPartitionerFn(TestCase):
    def test_custom_partitioner_fn(self):
        """
        For a function f(a, b) compiled through the compile_fx stack, the default
        partitioner saves the addition `a + b` (equivalently `buf0`) for backward.
        With the custom partitioner function we indicate that `a` and `b`
        (equivalently `primals_1` and `primals_2`) do not take additional memory,
        so they are saved for backward instead.
        """
        # initialization
        @torch.compile
        def f(a, b):
            return (a + b).cos().cos()

        a = torch.randn((2, 2), requires_grad=True, device=GPU_TYPE)
        b = torch.randn((2, 2), requires_grad=True, device=GPU_TYPE)

        # CASE 1 -- default partitioner:
        # the addition `a + b` (i.e., `buf0`) is saved for backward.
        code_og = run_fw_bw_and_get_code(lambda: f(a, b))
        fwd_code_og = code_og[1][0]
        FileCheck().check("return (buf1, buf0, )").run(fwd_code_og)

        # CASE 2 -- custom partitioner function:
        # `a` and `b` (i.e., `primals_1` and `primals_2`) are saved for backward.
        custom_partitioner_fn = MyCustomPartitionerFn()
        self.assertFalse(custom_partitioner_fn.called)
        self.assertIsNotNone(custom_partitioner_fn.uuid())
        with torch._inductor.config.patch(custom_partitioner_fn=custom_partitioner_fn):
            code_cp = run_fw_bw_and_get_code(lambda: f(a, b))
            fwd_code_cp = code_cp[1][0]
            FileCheck().check("return (buf0, primals_1, primals_2, )").run(fwd_code_cp)

        # make sure the custom partitioner function was indeed invoked
        self.assertTrue(custom_partitioner_fn.called)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_GPU:
        run_tests()

View File

@@ -83,6 +83,8 @@ from torch._inductor.custom_graph_pass import (
    CustomGraphModulePass,
    CustomGraphPass,
    CustomGraphPassType,
    CustomPartitionerFn,
    CustomPartitionerFnType,
)
from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param
from torch._inductor.runtime.compile_tasks import _reload_python_module
@@ -895,6 +897,11 @@ class FxGraphHashDetails:
            if custom_config is not None
        }

        # Register the custom partitioner function
        self._custom_partitioner_fn = self._get_custom_partitioner_fn_detail(
            config.custom_partitioner_fn
        )

        # This is mainly added to handle these two inductor configs, which are (unfortunately)
        # sometimes cache safe:
        # - _pre_fusion_custom_pass
@@ -927,6 +934,14 @@ class FxGraphHashDetails:
        assert isinstance(custom_pass, (CustomGraphPass, CustomGraphModulePass))
        return custom_pass.uuid()

    def _get_custom_partitioner_fn_detail(
        self, custom_partitioner_fn: CustomPartitionerFnType
    ) -> Optional[Any]:
        if not custom_partitioner_fn:
            return None
        assert isinstance(custom_partitioner_fn, CustomPartitionerFn)
        return custom_partitioner_fn.uuid()


def compiled_fx_graph_hash(
    gm: torch.fx.GraphModule,

View File

@@ -65,6 +65,7 @@ from torch._inductor.cudagraph_utils import (
    log_cudagraph_skip_and_bump_counter,
    PlaceholderInfo,
)
from torch._inductor.custom_graph_pass import CustomPartitionerFn
from torch._inductor.debug import (
    create_mapping_pre_post_grad_nodes,
    save_args_for_compile_fx_inner,
@@ -2110,16 +2111,30 @@ def partition_fn(
        "static_lifetime_input_indices", None
    )
    with dynamo_utils.dynamo_timed(
        "min_cut_rematerialization_partition", log_pt2_compile_event=True
    ):
        return min_cut_rematerialization_partition(
            gm,
            joint_inputs,
            compiler="inductor",
            static_lifetime_input_indices=static_lifetime_input_indices,
            **kwargs,
        )
    if config.custom_partitioner_fn is None:
        with dynamo_utils.dynamo_timed(
            "min_cut_rematerialization_partition", log_pt2_compile_event=True
        ):
            return min_cut_rematerialization_partition(
                gm,
                joint_inputs,
                compiler="inductor",
                static_lifetime_input_indices=static_lifetime_input_indices,
                **kwargs,
            )
    else:
        assert isinstance(config.custom_partitioner_fn, CustomPartitionerFn)
        with dynamo_utils.dynamo_timed(
            config.custom_partitioner_fn.__class__.__name__,
            log_pt2_compile_event=True,
        ):
            return config.custom_partitioner_fn(
                gm,
                joint_inputs,
                compiler="inductor",
                static_lifetime_input_indices=static_lifetime_input_indices,
                **kwargs,
            )


def get_num_model_outputs(model: GraphModule) -> int:

View File

@@ -266,6 +266,9 @@ b2b_gemm_pass = False
post_grad_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
post_grad_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
# Allow users to pass in a custom partitioner function
custom_partitioner_fn: torch._inductor.custom_graph_pass.CustomPartitionerFnType = None
# Registers a custom joint graph pass.
joint_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
joint_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None

View File

@@ -1,5 +1,6 @@
import hashlib
from abc import ABC, abstractmethod
from collections.abc import Sequence
from functools import lru_cache
from typing import Any, Callable, Optional, Union
from typing_extensions import TypeAlias
@@ -102,3 +103,58 @@ def get_hash_for_files(paths: tuple[str], extra: str = "") -> bytes:
            hasher.update(path.encode("utf-8"))
            hasher.update(f.read())
    return hasher.digest()


class CustomPartitionerFn(ABC):
    """
    Implement this interface for a custom partitioner:

    1) The __call__() method contains the implementation of the custom partitioner.

    2) The uuid() method enables inductor to cache compiled graphs when your custom
       partitioner is applied. This method can return any identifier as long as it
       uniquely identifies your implementation (and can be pickled). The caching logic
       includes this identifier in its key calculation, i.e., any new value will
       effectively invalidate existing entries. A custom partitioner would typically
       depend purely on the textual representation of its implementation. In that case,
       we recommend using the 'get_hash_for_files' helper to compute a unique hash from
       the contents of a static list of source files, i.e., the source(s) containing
       the custom partitioner implementation. That approach ensures that any change to
       the implementation results in a new uuid.

    EXAMPLE:

        from torch._inductor.custom_graph_pass import get_hash_for_files

        class MyCustomPartitionerFn(CustomPartitionerFn):
            def __call__(
                self,
                gm: torch.fx.GraphModule,
                joint_inputs: Sequence[object],
                **kwargs: Any,
            ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
                # my custom partitioner implementation
                ...

            def uuid(self) -> Optional[Any]:
                return get_hash_for_files((__file__,))
    """

    @abstractmethod
    def __call__(
        self, gm: torch.fx.GraphModule, joint_inputs: Sequence[object], **kwargs: Any
    ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
        """
        Implementation of the custom partitioner.
        """

    @abstractmethod
    def uuid(self) -> Optional[Any]:
        """
        Return an ID to uniquely identify your custom partitioner implementation.
        Return None to skip inductor code caching entirely.
        """


CustomPartitionerFnType: TypeAlias = Optional[CustomPartitionerFn]
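
As a concrete illustration of the uuid() contract described above (the returned value is what FxGraphHashDetails folds into the FX-graph cache key earlier in this commit), here is a small sketch. The class names are illustrative and both classes simply delegate to the stock min-cut partitioner.

# Sketch of the two common uuid() strategies (class names are illustrative):
from typing import Any, Optional

from functorch.compile import min_cut_rematerialization_partition
from torch._inductor.custom_graph_pass import CustomPartitionerFn, get_hash_for_files


class FileHashedPartitioner(CustomPartitionerFn):
    """Cached graphs are invalidated whenever this source file changes."""

    def __call__(self, gm, joint_inputs, **kwargs):
        return min_cut_rematerialization_partition(gm, joint_inputs, **kwargs)

    def uuid(self) -> Optional[Any]:
        return get_hash_for_files((__file__,))


class UncacheablePartitioner(CustomPartitionerFn):
    """Returning None from uuid() skips inductor code caching entirely."""

    def __call__(self, gm, joint_inputs, **kwargs):
        return min_cut_rematerialization_partition(gm, joint_inputs, **kwargs)

    def uuid(self) -> Optional[Any]:
        return None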

View File

@@ -23,7 +23,8 @@ from typing import (
)
import torch
from torch._inductor.custom_graph_pass import CustomGraphPass
from functorch.compile import min_cut_rematerialization_partition
from torch._inductor.custom_graph_pass import CustomGraphPass, CustomPartitionerFn
from torch._inductor.scheduler import BaseSchedulerNode
from torch.utils._config_module import _ConfigEntry, ConfigModule
from torch.utils._ordered_set import OrderedSet
@@ -74,6 +75,20 @@ class DummyPass(CustomGraphPass):
        return None


class DummyPartitionerFn(CustomPartitionerFn):
    """
    A dummy partitioner function to be used by ConfigFuzzer.
    """

    def __call__(
        self, gm: torch.fx.GraphModule, joint_inputs: Sequence[object], **kwargs: Any
    ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
        return min_cut_rematerialization_partition(gm, joint_inputs, **kwargs)

    def uuid(self) -> Optional[Any]:
        return None


T = TypeVar("T")
@@ -84,6 +99,7 @@ class TypeExemplars:
    TYPE_EXEMPLARS: dict[str, Any] = {
        CustomGraphPass.__name__: DummyPass(),
        CustomPartitionerFn.__name__: DummyPartitionerFn(),
        torch.fx.graph.Graph.__name__: torch.fx.graph.Graph(),
        BaseSchedulerNode.__name__: BaseSchedulerNode(None),  # type: ignore[arg-type]
    }
@@ -499,6 +515,7 @@ MODULE_DEFAULTS: dict[str, ConfigType] = {
        "joint_custom_post_pass": DEFAULT,  # Typing
        "joint_custom_pre_pass": DEFAULT,  # Typing
        "pre_grad_custom_pass": DEFAULT,  # Typing
        "custom_partitioner_fn": DEFAULT,  # Typing
    },
    "torch._dynamo.config": {
        "traceable_tensor_subclasses": DEFAULT,  # Typing