Introduce joint_custom_pass callback (#164981)

```
        def joint_custom_pass(joint_gm: torch.fx.GraphModule, joint_inputs):
           # apply your pass for joint graph here

            return joint_gm

        class M(torch.nn.Module):
            def forward(self, x):
                return x.sin()

        x = torch.randn(10, requires_grad=False)
        compiled_fn = torch.compile(M(), backend="aot_eager")

        with torch._functorch.config.patch("joint_custom_pass", joint_custom_pass):
            out = compiled_fn(x)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164981
Approved by: https://github.com/ezyang, https://github.com/anijain2305
This commit is contained in:
Sherlock Huang
2025-10-08 13:06:32 -07:00
committed by PyTorch MergeBot
parent 1f73b96668
commit e532f62e0d
5 changed files with 60 additions and 2 deletions

View File

@ -614,6 +614,9 @@ class ConfigModule(ModuleType):
def get_config_copy(self) -> dict[str, Any]:
return self._get_dict()
def get_serializable_config_copy(self) -> dict[str, Any]:
return self._get_dict(ignored_keys=getattr(self, "_save_config_ignore", []))
def patch(
self,
arg1: Optional[Union[str, dict[str, Any]]] = None,

View File

@ -31,4 +31,5 @@ def to_dict() -> dict[str, Any]: ...
def shallow_copy_dict() -> dict[str, Any]: ...
def load_config(config: bytes | dict[str, Any]) -> None: ...
def get_config_copy() -> dict[str, Any]: ...
def get_serializable_config_copy() -> dict[str, Any]: ...
def patch(arg1: str | dict[str, Any] | None = None, arg2: Any = None, **kwargs): ...