mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
Introduce joint_custom_pass callback (#164981)
```
def joint_custom_pass(joint_gm: torch.fx.GraphModule, joint_inputs):
# apply your pass for joint graph here
return joint_gm
class M(torch.nn.Module):
def forward(self, x):
return x.sin()
x = torch.randn(10, requires_grad=False)
compiled_fn = torch.compile(M(), backend="aot_eager")
with torch._functorch.config.patch("joint_custom_pass", joint_custom_pass):
out = compiled_fn(x)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164981
Approved by: https://github.com/ezyang, https://github.com/anijain2305
This commit is contained in:
committed by
PyTorch MergeBot
parent
1f73b96668
commit
e532f62e0d
@ -614,6 +614,9 @@ class ConfigModule(ModuleType):
|
||||
def get_config_copy(self) -> dict[str, Any]:
|
||||
return self._get_dict()
|
||||
|
||||
def get_serializable_config_copy(self) -> dict[str, Any]:
|
||||
return self._get_dict(ignored_keys=getattr(self, "_save_config_ignore", []))
|
||||
|
||||
def patch(
|
||||
self,
|
||||
arg1: Optional[Union[str, dict[str, Any]]] = None,
|
||||
|
||||
@ -31,4 +31,5 @@ def to_dict() -> dict[str, Any]: ...
|
||||
def shallow_copy_dict() -> dict[str, Any]: ...
|
||||
def load_config(config: bytes | dict[str, Any]) -> None: ...
|
||||
def get_config_copy() -> dict[str, Any]: ...
|
||||
def get_serializable_config_copy() -> dict[str, Any]: ...
|
||||
def patch(arg1: str | dict[str, Any] | None = None, arg2: Any = None, **kwargs): ...
|
||||
|
||||
Reference in New Issue
Block a user