mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
```
def joint_custom_pass(joint_gm: torch.fx.GraphModule, joint_inputs):
    # apply your pass for joint graph here
    return joint_gm


class M(torch.nn.Module):
    def forward(self, x):
        return x.sin()


x = torch.randn(10, requires_grad=False)
compiled_fn = torch.compile(M(), backend="aot_eager")
with torch._functorch.config.patch("joint_custom_pass", joint_custom_pass):
    out = compiled_fn(x)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164981
Approved by: https://github.com/ezyang, https://github.com/anijain2305
36 lines
1.2 KiB
Python
36 lines
1.2 KiB
Python
# mypy: allow-untyped-defs
|
|
from typing import Any, TYPE_CHECKING
|
|
|
|
"""
|
|
This was semi-automatically generated by running
|
|
|
|
    stubgen torch/utils/_config_module.py
|
|
|
|
and then manually extracting the methods of ConfigModule and converting them into top-level functions.
|
|
|
|
This file should be imported into any file that uses install_config_module like so:
|
|
|
|
import sys
from typing import TYPE_CHECKING

if TYPE_CHECKING:
|
|
    from torch.utils._config_typing import *  # noqa: F401, F403
|
|
|
|
from torch.utils._config_module import install_config_module
|
|
|
|
# adds patch, save_config, etc
|
|
install_config_module(sys.modules[__name__])
|
|
|
|
Note that the import should happen before the call to install_config_module(); otherwise, runtime errors may occur.
|
|
"""
|
|
|
|
# This module only supplies signatures for type checkers, and type checkers
# treat TYPE_CHECKING as True, so for them this guard is a no-op. If the file
# is ever executed as ordinary Python, TYPE_CHECKING is False and the import
# fails immediately with AssertionError.
# NOTE: like every assert, this check is removed when Python runs with -O.
assert TYPE_CHECKING, "Do not use at runtime"
|
|
|
|
def save_config() -> bytes:
    # Module-level form of ConfigModule.save_config (added by
    # install_config_module). Returns bytes -- presumably a serialized
    # snapshot of the current config values; confirm in _config_module.
    ...
|
|
def save_config_portable(
    *,
    ignore_private_configs: bool = True,
) -> dict[str, Any]:
    # Module-level form of ConfigModule.save_config_portable. The only
    # option is keyword-only and defaults to skipping private configs
    # (which names count as "private" is decided in _config_module).
    ...
|
|
def codegen_config() -> str:
    # Module-level form of ConfigModule.codegen_config. Returns a str --
    # presumably source text that recreates the current settings; confirm.
    ...
|
|
def get_hash() -> bytes:
    # Module-level form of ConfigModule.get_hash. Returns bytes --
    # presumably a digest of the current config state; confirm.
    ...
|
|
def to_dict() -> dict[str, Any]:
    # Module-level form of ConfigModule.to_dict. Returns a str-keyed dict
    # with untyped values (presumably config name -> current value).
    ...
|
|
def shallow_copy_dict() -> dict[str, Any]:
    # Module-level form of ConfigModule.shallow_copy_dict. Returns a
    # str-keyed dict; per the name, a shallow copy, so mutable values may
    # be shared with the live config -- confirm before mutating them.
    ...
|
|
def load_config(config: bytes | dict[str, Any]) -> None: ...
|
|
def get_config_copy() -> dict[str, Any]:
    # Module-level form of ConfigModule.get_config_copy. Returns a
    # str-keyed dict; whether the copy is deep or shallow is not visible
    # from this stub -- check _config_module before relying on either.
    ...
|
|
def get_serializable_config_copy() -> dict[str, Any]:
    # Module-level form of ConfigModule.get_serializable_config_copy.
    # Returns a str-keyed dict -- presumably limited to entries that can be
    # serialized; confirm which entries are dropped in _config_module.
    ...
|
|
def patch(arg1: str | dict[str, Any] | None = None, arg2: Any = None, **kwargs): ...
|