allow user to pass in custom partitioner function (#157580)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157580 Approved by: https://github.com/bdhirsh
Committed by: PyTorch MergeBot
Parent: 9c03d6be87
Commit: 4d4abec80f
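In short, this change adds a `custom_partitioner_fn` knob to `torch._inductor.config`: when it is set to a `CustomPartitionerFn` subclass, inductor's `partition_fn` dispatches to it instead of `min_cut_rematerialization_partition`, and its `uuid()` is folded into the FX graph cache key. Below is a minimal usage sketch assembled from the new test in this PR (`test/inductor/test_custom_partitioner_fn.py`); the names `MyPartitioner` and `f` are illustrative only.

```python
# Sketch only: mirrors the pattern used in test/inductor/test_custom_partitioner_fn.py.
import torch
from functorch.compile import min_cut_rematerialization_partition
from torch._inductor.custom_graph_pass import CustomPartitionerFn, get_hash_for_files


class MyPartitioner(CustomPartitionerFn):
    def __call__(self, gm, joint_inputs, **kwargs):
        # Delegate to the default min-cut partitioner, optionally tweaking
        # kwargs first (e.g. static_lifetime_input_indices).
        return min_cut_rematerialization_partition(gm, joint_inputs, **kwargs)

    def uuid(self):
        # Hash this source file so edits invalidate cached FX graphs.
        return get_hash_for_files((__file__,))


@torch.compile
def f(a, b):
    return (a + b).cos().cos()


a = torch.randn(2, 2, requires_grad=True)
b = torch.randn(2, 2, requires_grad=True)

# Compile (and partition) while the custom partitioner is registered.
with torch._inductor.config.patch(custom_partitioner_fn=MyPartitioner()):
    f(a, b).sum().backward()
```

Returning None from `uuid()` skips inductor code caching entirely, per the new `CustomPartitionerFn` docstring in the diff below.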
@@ -58,6 +58,7 @@ from torch._functorch.aot_autograd import (
)
from torch._higher_order_ops.out_dtype import out_dtype
from torch._inductor.codecache import compiled_fx_graph_hash
from torch._inductor.custom_graph_pass import CustomPartitionerFn
from torch._inductor.output_code import MockFXGraphCacheOutput
from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode
from torch.fx.experimental.proxy_tensor import is_sym_node
@@ -5687,6 +5688,49 @@ def forward(self, primals_1, tangents_1):
    return (cat,)""",
        )

    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
    def test_custom_partitioner_fn(self):
        class MyCustomPartitionerFn(CustomPartitionerFn):
            def __init__(self):
                super().__init__()
                self.called = False

            def __call__(self, gm, joint_inputs, **kwargs):
                self.called = True
                return min_cut_rematerialization_partition(gm, joint_inputs, **kwargs)

            def uuid(self):
                return None

        def f(x):
            return x.cos().cos()

        inp = [torch.randn((4, 4), requires_grad=True)]
        custom_partitioner_fn = MyCustomPartitionerFn()
        fw_graph, bw_graph = get_fw_bw_graph(f, inp, partitioner=custom_partitioner_fn)
        self.assertTrue(custom_partitioner_fn.called)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, primals_1):
    cos = torch.ops.aten.cos.default(primals_1)
    cos_1 = torch.ops.aten.cos.default(cos); cos = None
    return (cos_1, primals_1)""",
        )
        self.assertExpectedInline(
            bw_graph.code.strip(),
            """\
def forward(self, primals_1, tangents_1):
    cos = torch.ops.aten.cos.default(primals_1)
    sin = torch.ops.aten.sin.default(cos); cos = None
    neg = torch.ops.aten.neg.default(sin); sin = None
    mul = torch.ops.aten.mul.Tensor(tangents_1, neg); tangents_1 = neg = None
    sin_1 = torch.ops.aten.sin.default(primals_1); primals_1 = None
    neg_1 = torch.ops.aten.neg.default(sin_1); sin_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(mul, neg_1); mul = neg_1 = None
    return (mul_1,)""",
        )

    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
    def test_min_cut_partitioner_save_shape(self):
        def f(x):
@@ -33,6 +33,7 @@ from torch._inductor.cpp_builder import normalize_path_separator
from torch._inductor.custom_graph_pass import (
    CustomGraphModulePass,
    CustomGraphPass,
    CustomPartitionerFn,
    get_hash_for_files,
)
from torch._inductor.graph import GraphLowering
@@ -2115,6 +2116,19 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)


class TestCustomPartitionerFn(CustomPartitionerFn):
    def __init__(self):
        self._uuid = None

    def __call__(
        self, gm, joint_inputs, **kwargs
    ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
        return gm, gm  # Dummy implementation

    def uuid(self) -> Optional[Union[bytes, str]]:
        return self._uuid


class TestFxGraphCacheHashing(TestCase):
    def test_parameter_constants(self):
        """
@@ -2520,6 +2534,35 @@ class TestFxGraphCacheHashing(TestCase):
        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

    def test_hash_custom_partitioner_fn(self):
        """
        Test that the custom partitioner function's UUID is properly used in the FX graph cache hashing.
        """
        custom_partitioner_fn = TestCustomPartitionerFn()
        with config.patch({"custom_partitioner_fn": custom_partitioner_fn}):
            custom_partitioner_fn._uuid = "1"
            details1 = FxGraphHashDetails(None, [], {}, [])
            details2 = FxGraphHashDetails(None, [], {}, [])

            custom_partitioner_fn._uuid = "2"
            details3 = FxGraphHashDetails(None, [], {}, [])

            self.assertEqual(details1._custom_partitioner_fn, "1")
            self.assertEqual(details2._custom_partitioner_fn, "1")
            self.assertEqual(details3._custom_partitioner_fn, "2")

            gm = torch.fx.GraphModule({}, torch.fx.Graph())
            pickler = FxGraphCachePickler(gm)

            self.assertEqual(
                pickler.dumps(details1),
                pickler.dumps(details2),
            )
            self.assertNotEqual(
                pickler.dumps(details1),
                pickler.dumps(details3),
            )

    def test_bypass_unsupported(self):
        """
        Test _reduce_unsupported
test/inductor/test_custom_partitioner_fn.py (new file, 72 lines)
@@ -0,0 +1,72 @@
# Owner(s): ["module: pt2-dispatcher"]
import torch
from functorch.compile import min_cut_rematerialization_partition
from torch._C import FileCheck
from torch._inductor.custom_graph_pass import CustomPartitionerFn, get_hash_for_files
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_fw_bw_and_get_code
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


class MyCustomPartitionerFn(CustomPartitionerFn):
    """
    A custom partitioner function that overrides static_lifetime_input_indices.
    """

    def __init__(self):
        super().__init__()
        self.called = False

    def __call__(self, gm, joint_inputs, **kwargs):
        self.called = True
        kwargs["static_lifetime_input_indices"] = [0, 1]
        return min_cut_rematerialization_partition(gm, joint_inputs, **kwargs)

    def uuid(self):
        return get_hash_for_files((__file__,))


class TestCustomPartitionerFn(TestCase):
    def test_custom_partitioner_fn(self):
        """
        For f(a, b), the default partitioner in the compile_fx stack saves
        the addition `a + b` (i.e., `buf0`) for backward. With the custom
        partitioner function, we indicate that `a` and `b` (i.e., `primals_1`
        and `primals_2`) do not require additional memory, so they are saved
        for backward instead.
        """

        # initialization
        @torch.compile
        def f(a, b):
            return (a + b).cos().cos()

        a = torch.randn((2, 2), requires_grad=True, device=GPU_TYPE)
        b = torch.randn((2, 2), requires_grad=True, device=GPU_TYPE)

        # CASE 1 -- default
        # addition `a + b` (i.e., `buf0`) is saved for backward.
        code_og = run_fw_bw_and_get_code(lambda: f(a, b))
        fwd_code_og = code_og[1][0]
        FileCheck().check("return (buf1, buf0, )").run(fwd_code_og)

        # CASE 2 -- custom partitioner function
        # `a` and `b` (i.e., `primals_1` and `primals_2`) are saved for backward.
        custom_partitioner_fn = MyCustomPartitionerFn()
        self.assertFalse(custom_partitioner_fn.called)
        self.assertIsNotNone(custom_partitioner_fn.uuid())

        with torch._inductor.config.patch(custom_partitioner_fn=custom_partitioner_fn):
            code_cp = run_fw_bw_and_get_code(lambda: f(a, b))
        fwd_code_cp = code_cp[1][0]
        FileCheck().check("return (buf0, primals_1, primals_2, )").run(fwd_code_cp)

        # make sure the custom partitioner function is indeed invoked
        self.assertTrue(custom_partitioner_fn.called)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_GPU:
        run_tests()
@@ -83,6 +83,8 @@ from torch._inductor.custom_graph_pass import (
    CustomGraphModulePass,
    CustomGraphPass,
    CustomGraphPassType,
    CustomPartitionerFn,
    CustomPartitionerFnType,
)
from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param
from torch._inductor.runtime.compile_tasks import _reload_python_module
@@ -895,6 +897,11 @@ class FxGraphHashDetails:
            if custom_config is not None
        }

        # Register the custom partitioner function
        self._custom_partitioner_fn = self._get_custom_partitioner_fn_detail(
            config.custom_partitioner_fn
        )

        # This is mainly added to handle these two inductor configs, which are (unfortunately)
        # sometimes cache safe:
        # - _pre_fusion_custom_pass
@@ -927,6 +934,14 @@ class FxGraphHashDetails:
        assert isinstance(custom_pass, (CustomGraphPass, CustomGraphModulePass))
        return custom_pass.uuid()

    def _get_custom_partitioner_fn_detail(
        self, custom_partitioner_fn: CustomPartitionerFnType
    ) -> Optional[Any]:
        if not custom_partitioner_fn:
            return None
        assert isinstance(custom_partitioner_fn, CustomPartitionerFn)
        return custom_partitioner_fn.uuid()


def compiled_fx_graph_hash(
    gm: torch.fx.GraphModule,
@@ -65,6 +65,7 @@ from torch._inductor.cudagraph_utils import (
    log_cudagraph_skip_and_bump_counter,
    PlaceholderInfo,
)
from torch._inductor.custom_graph_pass import CustomPartitionerFn
from torch._inductor.debug import (
    create_mapping_pre_post_grad_nodes,
    save_args_for_compile_fx_inner,
@@ -2110,16 +2111,30 @@ def partition_fn(
        "static_lifetime_input_indices", None
    )

    with dynamo_utils.dynamo_timed(
        "min_cut_rematerialization_partition", log_pt2_compile_event=True
    ):
        return min_cut_rematerialization_partition(
            gm,
            joint_inputs,
            compiler="inductor",
            static_lifetime_input_indices=static_lifetime_input_indices,
            **kwargs,
        )
    if config.custom_partitioner_fn is None:
        with dynamo_utils.dynamo_timed(
            "min_cut_rematerialization_partition", log_pt2_compile_event=True
        ):
            return min_cut_rematerialization_partition(
                gm,
                joint_inputs,
                compiler="inductor",
                static_lifetime_input_indices=static_lifetime_input_indices,
                **kwargs,
            )
    else:
        assert isinstance(config.custom_partitioner_fn, CustomPartitionerFn)
        with dynamo_utils.dynamo_timed(
            config.custom_partitioner_fn.__class__.__name__,
            log_pt2_compile_event=True,
        ):
            return config.custom_partitioner_fn(
                gm,
                joint_inputs,
                compiler="inductor",
                static_lifetime_input_indices=static_lifetime_input_indices,
                **kwargs,
            )


def get_num_model_outputs(model: GraphModule) -> int:
@@ -266,6 +266,9 @@ b2b_gemm_pass = False
post_grad_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
post_grad_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None

# Allow users to pass in a custom partitioner function
custom_partitioner_fn: torch._inductor.custom_graph_pass.CustomPartitionerFnType = None

# Registers a custom joint graph pass.
joint_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
joint_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
@@ -1,5 +1,6 @@
import hashlib
from abc import ABC, abstractmethod
from collections.abc import Sequence
from functools import lru_cache
from typing import Any, Callable, Optional, Union
from typing_extensions import TypeAlias
@@ -102,3 +103,58 @@ def get_hash_for_files(paths: tuple[str], extra: str = "") -> bytes:
        hasher.update(path.encode("utf-8"))
        hasher.update(f.read())
    return hasher.digest()


class CustomPartitionerFn(ABC):
    """
    Implement this interface for a custom partitioner:

    1) The __call__() method contains the implementation of the custom partitioner.

    2) The uuid() method enables inductor to cache compiled graphs when your custom
    partitioner is applied. This method can return any identifier as long as it uniquely
    identifies your implementation (and can be pickled). The caching logic includes this
    identifier in its key calculation, i.e., any new value will effectively invalidate
    existing entries. We expect a custom partitioner would typically depend purely on the
    textual representation of the implementation. In that case, we recommend using the
    'get_hash_for_files' helper below to compute a unique hash from the contents of a
    static list of source files, i.e., the source(s) containing the custom partitioner
    implementation. That approach ensures that any change to the implementation will
    mean a new uuid.

    EXAMPLE:

        from torch._inductor.custom_graph_pass import get_hash_for_files

        class MyCustomPartitionerFn(CustomPartitionerFn):
            def __call__(
                self,
                gm: torch.fx.GraphModule,
                joint_inputs: Sequence[object],
                **kwargs: Any
            ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
                # my custom partitioner implementation
                # ...

            def uuid(self) -> Optional[Any]:
                return get_hash_for_files((__file__,))

    """

    @abstractmethod
    def __call__(
        self, gm: torch.fx.GraphModule, joint_inputs: Sequence[object], **kwargs: Any
    ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
        """
        Implementation of the custom partitioner.
        """

    @abstractmethod
    def uuid(self) -> Optional[Any]:
        """
        Return an ID to uniquely identify your custom partitioner implementation.
        Return None to skip inductor code caching entirely.
        """


CustomPartitionerFnType: TypeAlias = Optional[CustomPartitionerFn]
@@ -23,7 +23,8 @@ from typing import (
)

import torch
from torch._inductor.custom_graph_pass import CustomGraphPass
from functorch.compile import min_cut_rematerialization_partition
from torch._inductor.custom_graph_pass import CustomGraphPass, CustomPartitionerFn
from torch._inductor.scheduler import BaseSchedulerNode
from torch.utils._config_module import _ConfigEntry, ConfigModule
from torch.utils._ordered_set import OrderedSet
@@ -74,6 +75,20 @@ class DummyPass(CustomGraphPass):
        return None


class DummyPartitionerFn(CustomPartitionerFn):
    """
    A dummy partitioner function to be used by ConfigFuzzer.
    """

    def __call__(
        self, gm: torch.fx.GraphModule, joint_inputs: Sequence[object], **kwargs: Any
    ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
        return min_cut_rematerialization_partition(gm, joint_inputs, **kwargs)

    def uuid(self) -> Optional[Any]:
        return None


T = TypeVar("T")

@@ -84,6 +99,7 @@ class TypeExemplars:

TYPE_EXEMPLARS: dict[str, Any] = {
    CustomGraphPass.__name__: DummyPass(),
    CustomPartitionerFn.__name__: DummyPartitionerFn(),
    torch.fx.graph.Graph.__name__: torch.fx.graph.Graph(),
    BaseSchedulerNode.__name__: BaseSchedulerNode(None),  # type: ignore[arg-type]
}
@@ -499,6 +515,7 @@ MODULE_DEFAULTS: dict[str, ConfigType] = {
        "joint_custom_post_pass": DEFAULT,  # Typing
        "joint_custom_pre_pass": DEFAULT,  # Typing
        "pre_grad_custom_pass": DEFAULT,  # Typing
        "custom_partitioner_fn": DEFAULT,  # Typing
    },
    "torch._dynamo.config": {
        "traceable_tensor_subclasses": DEFAULT,  # Typing