mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Introduce joint_custom_pass callback (#164981)
``` def joint_custom_pass(joint_gm: torch.fx.GraphModule, joint_inputs): # apply your pass for joint graph here return joint_gm class M(torch.nn.Module): def forward(self, x): return x.sin() x = torch.randn(10, requires_grad=False) compiled_fn = torch.compile(M(), backend="aot_eager") with torch._functorch.config.patch("joint_custom_pass", joint_custom_pass): out = compiled_fn(x) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/164981 Approved by: https://github.com/ezyang, https://github.com/anijain2305
This commit is contained in:
committed by
PyTorch MergeBot
parent
1f73b96668
commit
e532f62e0d
@ -715,6 +715,42 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
|
||||
out = compiled_fn(x, y)
|
||||
out.sum().backward()
|
||||
|
||||
def test_joint_custom_pass(self):
|
||||
is_called = False
|
||||
|
||||
def joint_custom_pass(joint_gm: torch.fx.GraphModule, joint_inputs):
|
||||
nonlocal is_called
|
||||
is_called = True
|
||||
|
||||
self.assertTrue(isinstance(joint_gm, torch.fx.GraphModule))
|
||||
|
||||
self.assertTrue(isinstance(joint_inputs, tuple))
|
||||
# first input is list of primals
|
||||
self.assertTrue(isinstance(joint_inputs[0], list))
|
||||
# second input is list of tangents
|
||||
self.assertTrue(isinstance(joint_inputs[1], list))
|
||||
|
||||
return joint_gm
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x.sin()
|
||||
|
||||
x = torch.randn(10, requires_grad=False)
|
||||
compiled_fn = torch.compile(M(), backend="aot_eager")
|
||||
|
||||
with torch._functorch.config.patch("joint_custom_pass", joint_custom_pass):
|
||||
_ = compiled_fn(x)
|
||||
# x doesn't require grad, shouldn't trigger joint graph compiler
|
||||
self.assertFalse(is_called)
|
||||
|
||||
y = torch.randn(10, requires_grad=True)
|
||||
with torch._functorch.config.patch("joint_custom_pass", joint_custom_pass):
|
||||
out = compiled_fn(y)
|
||||
# y requires grad, should trigger joint graph compiler
|
||||
self.assertTrue(is_called)
|
||||
out.sum().backward()
|
||||
|
||||
@expectedFailureDynamic # https://github.com/pytorch/pytorch/issues/103539
|
||||
@torch._dynamo.config.patch(automatic_dynamic_shapes=False)
|
||||
@patch("torch._functorch.config.debug_assert", True)
|
||||
|
@ -305,7 +305,7 @@ def aot_stage2_inference(
|
||||
"name": "torch._functorch.config",
|
||||
"encoding": "string",
|
||||
},
|
||||
payload_fn=lambda: torch._functorch.config.get_config_copy(),
|
||||
payload_fn=lambda: torch._functorch.config.get_serializable_config_copy(),
|
||||
)
|
||||
|
||||
disable_amp = torch._C._is_any_autocast_enabled()
|
||||
@ -1410,6 +1410,10 @@ def aot_stage2_autograd(
|
||||
if fake_mode is not None and fake_mode.shape_env is not None:
|
||||
tensorify_python_scalars(fx_g, fake_mode.shape_env, fake_mode)
|
||||
|
||||
# apply joint_gm callback here
|
||||
if callable(torch._functorch.config.joint_custom_pass):
|
||||
fx_g = torch._functorch.config.joint_custom_pass(fx_g, joint_inputs)
|
||||
|
||||
static_lifetime_input_indices = fw_metadata.static_input_indices
|
||||
fw_module, bw_module = aot_config.partition_fn(
|
||||
fx_g,
|
||||
@ -1491,7 +1495,7 @@ def aot_stage2_autograd(
|
||||
"name": "torch._functorch.config",
|
||||
"encoding": "string",
|
||||
},
|
||||
payload_fn=lambda: torch._functorch.config.get_config_copy(),
|
||||
payload_fn=lambda: torch._functorch.config.get_serializable_config_copy(),
|
||||
)
|
||||
aot_graphs_log.info(
|
||||
"aot_config id: %s, fw_metadata=%s, inner_meta=%s",
|
||||
|
@ -4,6 +4,9 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Callable
|
||||
|
||||
|
||||
"""
|
||||
Global flags for aot autograd
|
||||
"""
|
||||
@ -15,6 +18,13 @@ from typing import Literal, Optional, TYPE_CHECKING
|
||||
from torch.utils._config_module import Config, install_config_module
|
||||
|
||||
|
||||
# [@compile_ignored: debug]
|
||||
_save_config_ignore = [
|
||||
# callable not serializeable
|
||||
"joint_custom_pass",
|
||||
]
|
||||
|
||||
|
||||
# Converts torch rng ops to their functional philox rng equivalents. Note that
|
||||
# we functionalize only CUDA rng ops today.
|
||||
functionalize_rng_ops = False
|
||||
@ -358,6 +368,10 @@ _sync_decision_cross_ranks = False
|
||||
saved_tensors_hooks_filtering_mode = "donated"
|
||||
|
||||
|
||||
# This callback is invoked on the joint graph before partitioning
|
||||
joint_custom_pass: Callable = None # type: ignore[assignment]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.utils._config_typing import * # noqa: F401, F403
|
||||
|
||||
|
@ -614,6 +614,9 @@ class ConfigModule(ModuleType):
|
||||
def get_config_copy(self) -> dict[str, Any]:
|
||||
return self._get_dict()
|
||||
|
||||
def get_serializable_config_copy(self) -> dict[str, Any]:
|
||||
return self._get_dict(ignored_keys=getattr(self, "_save_config_ignore", []))
|
||||
|
||||
def patch(
|
||||
self,
|
||||
arg1: Optional[Union[str, dict[str, Any]]] = None,
|
||||
|
@ -31,4 +31,5 @@ def to_dict() -> dict[str, Any]: ...
|
||||
def shallow_copy_dict() -> dict[str, Any]: ...
|
||||
def load_config(config: bytes | dict[str, Any]) -> None: ...
|
||||
def get_config_copy() -> dict[str, Any]: ...
|
||||
def get_serializable_config_copy() -> dict[str, Any]: ...
|
||||
def patch(arg1: str | dict[str, Any] | None = None, arg2: Any = None, **kwargs): ...
|
||||
|
Reference in New Issue
Block a user