Introduce joint_custom_pass callback (#164981)

```
import torch
import torch._functorch.config


def joint_custom_pass(joint_gm: torch.fx.GraphModule, joint_inputs):
    # apply your pass for the joint graph here
    return joint_gm


class M(torch.nn.Module):
    def forward(self, x):
        return x.sin()


# requires_grad=True so a joint (forward + backward) graph is built and the pass runs
x = torch.randn(10, requires_grad=True)
compiled_fn = torch.compile(M(), backend="aot_eager")

with torch._functorch.config.patch("joint_custom_pass", joint_custom_pass):
    out = compiled_fn(x)
```
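For reference, a pass can walk the joint (forward + backward) graph with the standard `torch.fx` node API before handing it back; the op histogram below is only an illustrative sketch, not part of this PR:

```
import torch


def logging_joint_pass(joint_gm: torch.fx.GraphModule, joint_inputs):
    # Count call_function ops in the joint graph before partitioning.
    op_counts: dict[str, int] = {}
    for node in joint_gm.graph.nodes:
        if node.op == "call_function":
            op_counts[str(node.target)] = op_counts.get(str(node.target), 0) + 1
    print("joint graph op histogram:", op_counts)
    # The callback must hand back the (possibly transformed) joint GraphModule.
    return joint_gm
```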

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164981
Approved by: https://github.com/ezyang, https://github.com/anijain2305
Author: Sherlock Huang
Date: 2025-10-08 13:06:32 -07:00
Committed by: PyTorch MergeBot
Parent: 1f73b96668
Commit: e532f62e0d

5 changed files with 60 additions and 2 deletions

```
@@ -715,6 +715,42 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
         out = compiled_fn(x, y)
         out.sum().backward()

+    def test_joint_custom_pass(self):
+        is_called = False
+
+        def joint_custom_pass(joint_gm: torch.fx.GraphModule, joint_inputs):
+            nonlocal is_called
+            is_called = True
+
+            self.assertTrue(isinstance(joint_gm, torch.fx.GraphModule))
+            self.assertTrue(isinstance(joint_inputs, tuple))
+            # first input is list of primals
+            self.assertTrue(isinstance(joint_inputs[0], list))
+            # second input is list of tangents
+            self.assertTrue(isinstance(joint_inputs[1], list))
+            return joint_gm
+
+        class M(torch.nn.Module):
+            def forward(self, x):
+                return x.sin()
+
+        x = torch.randn(10, requires_grad=False)
+        compiled_fn = torch.compile(M(), backend="aot_eager")
+
+        with torch._functorch.config.patch("joint_custom_pass", joint_custom_pass):
+            _ = compiled_fn(x)
+        # x doesn't require grad, shouldn't trigger joint graph compiler
+        self.assertFalse(is_called)
+
+        y = torch.randn(10, requires_grad=True)
+        with torch._functorch.config.patch("joint_custom_pass", joint_custom_pass):
+            out = compiled_fn(y)
+        # y requires grad, should trigger joint graph compiler
+        self.assertTrue(is_called)
+
+        out.sum().backward()
+
     @expectedFailureDynamic  # https://github.com/pytorch/pytorch/issues/103539
     @torch._dynamo.config.patch(automatic_dynamic_shapes=False)
     @patch("torch._functorch.config.debug_assert", True)
```
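As the assertions in the test above show, `joint_inputs` arrives as a `(primals, tangents)` pair of lists. A hedged sketch of a pass that only inspects those inputs (the logging itself is illustrative, not from the PR):

```
import torch


def input_logging_pass(joint_gm: torch.fx.GraphModule, joint_inputs):
    # Per the test, joint_inputs is a tuple: (list of primals, list of tangents).
    primals, tangents = joint_inputs
    print(f"joint graph sees {len(primals)} primal(s) and {len(tangents)} tangent(s)")
    return joint_gm
```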

```
@@ -305,7 +305,7 @@ def aot_stage2_inference(
             "name": "torch._functorch.config",
             "encoding": "string",
         },
-        payload_fn=lambda: torch._functorch.config.get_config_copy(),
+        payload_fn=lambda: torch._functorch.config.get_serializable_config_copy(),
     )

     disable_amp = torch._C._is_any_autocast_enabled()
@@ -1410,6 +1410,10 @@ def aot_stage2_autograd(
        if fake_mode is not None and fake_mode.shape_env is not None:
            tensorify_python_scalars(fx_g, fake_mode.shape_env, fake_mode)

+        # apply joint_gm callback here
+        if callable(torch._functorch.config.joint_custom_pass):
+            fx_g = torch._functorch.config.joint_custom_pass(fx_g, joint_inputs)
+
        static_lifetime_input_indices = fw_metadata.static_input_indices
        fw_module, bw_module = aot_config.partition_fn(
            fx_g,
@@ -1491,7 +1495,7 @@ def aot_stage2_autograd(
             "name": "torch._functorch.config",
             "encoding": "string",
         },
-        payload_fn=lambda: torch._functorch.config.get_config_copy(),
+        payload_fn=lambda: torch._functorch.config.get_serializable_config_copy(),
     )
     aot_graphs_log.info(
         "aot_config id: %s, fw_metadata=%s, inner_meta=%s",
```

```
@@ -4,6 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from typing import Callable
+
+
 """
 Global flags for aot autograd
 """
@@ -15,6 +18,13 @@ from typing import Literal, Optional, TYPE_CHECKING
 from torch.utils._config_module import Config, install_config_module

+# [@compile_ignored: debug]
+_save_config_ignore = [
+    # callable not serializeable
+    "joint_custom_pass",
+]
+
 # Converts torch rng ops to their functional philox rng equivalents. Note that
 # we functionalize only CUDA rng ops today.
 functionalize_rng_ops = False
@@ -358,6 +368,10 @@ _sync_decision_cross_ranks = False
 saved_tensors_hooks_filtering_mode = "donated"

+# This callback is invoked on the joint graph before partitioning
+joint_custom_pass: Callable = None  # type: ignore[assignment]
+
 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
```
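Since a callable cannot be serialized, `joint_custom_pass` is listed in `_save_config_ignore` above, so the serializable config copy should omit it. A quick check, written as a sketch against the config module:

```
import torch._functorch.config as functorch_config

# "joint_custom_pass" is in _save_config_ignore, so the serializable copy drops it.
cfg = functorch_config.get_serializable_config_copy()
assert "joint_custom_pass" not in cfg
```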

```
@@ -614,6 +614,9 @@ class ConfigModule(ModuleType):
     def get_config_copy(self) -> dict[str, Any]:
         return self._get_dict()

+    def get_serializable_config_copy(self) -> dict[str, Any]:
+        return self._get_dict(ignored_keys=getattr(self, "_save_config_ignore", []))
+
     def patch(
         self,
         arg1: Optional[Union[str, dict[str, Any]]] = None,
```
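The structured-trace payloads above switch from `get_config_copy()` to `get_serializable_config_copy()` because a raw callable in the config is not serializable (per the "callable not serializeable" note). A small stand-in illustration with `json`; the dict below is hypothetical, not the real config or trace path:

```
import json

# Stand-in for a config dict that contains a callable-valued entry.
full_copy = {"functionalize_rng_ops": False, "joint_custom_pass": lambda gm, inputs: gm}
serializable_copy = {k: v for k, v in full_copy.items() if k != "joint_custom_pass"}

json.dumps(serializable_copy)  # works
try:
    json.dumps(full_copy)  # a callable is not JSON serializable
except TypeError as e:
    print("cannot serialize full copy:", e)
```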

```
@@ -31,4 +31,5 @@ def to_dict() -> dict[str, Any]: ...
 def shallow_copy_dict() -> dict[str, Any]: ...
 def load_config(config: bytes | dict[str, Any]) -> None: ...
 def get_config_copy() -> dict[str, Any]: ...
+def get_serializable_config_copy() -> dict[str, Any]: ...
 def patch(arg1: str | dict[str, Any] | None = None, arg2: Any = None, **kwargs): ...
```