Files
pytorch/torch/utils/_config_typing.pyi
Sherlock Huang e532f62e0d Introduce joint_custom_pass callback (#164981)
```
        def joint_custom_pass(joint_gm: torch.fx.GraphModule, joint_inputs):
            # apply your pass to the joint graph here

            return joint_gm

        class M(torch.nn.Module):
            def forward(self, x):
                return x.sin()

        x = torch.randn(10, requires_grad=False)
        compiled_fn = torch.compile(M(), backend="aot_eager")

        with torch._functorch.config.patch("joint_custom_pass", joint_custom_pass):
            out = compiled_fn(x)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164981
Approved by: https://github.com/ezyang, https://github.com/anijain2305
2025-10-09 04:40:54 +00:00

36 lines
1.2 KiB
Python

# mypy: allow-untyped-defs
from typing import Any, TYPE_CHECKING
"""
This was semi-automatically generated by running
stubgen torch.utils._config_module.py
And then manually extracting the methods of ConfigModule and converting them into top-level functions.
This file should be imported into any file that uses install_config_module like so:
if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403
from torch.utils._config_module import install_config_module
# adds patch, save_config, etc
install_config_module(sys.modules[__name__])
Note that the import should happen before the call to install_config_module(), otherwise runtime errors may occur.
"""
# Guard against accidental runtime use: typing.TYPE_CHECKING is False when the
# interpreter actually executes code and only assumed True by static type
# checkers, so this assertion fails if the module is ever run for real.
assert TYPE_CHECKING, "Do not use at runtime"
# Stub for the module-level `save_config` that install_config_module() adds
# (extracted from ConfigModule, per the note at the top of this file).
# Takes no arguments and returns bytes.
# NOTE(review): presumably a serialized snapshot of the config that
# load_config() accepts -- confirm against _config_module.py.
def save_config() -> bytes: ...
# Stub for the ConfigModule method `save_config_portable`, exposed as a
# top-level function. `ignore_private_configs` is keyword-only and defaults to
# True; the result is a dict with str keys.
# NOTE(review): the parameter name suggests private entries are left out by
# default -- confirm against the implementation.
def save_config_portable(*, ignore_private_configs: bool = True) -> dict[str, Any]: ...
# Stub for the ConfigModule method `codegen_config`, exposed as a top-level
# function. Takes no arguments and returns a str.
# NOTE(review): presumably generated source text that reproduces the current
# config -- confirm against the implementation.
def codegen_config() -> str: ...
# Stub for the ConfigModule method `get_hash`, exposed as a top-level function.
# Takes no arguments and returns bytes.
# NOTE(review): presumably a digest of the current config values -- confirm
# against the implementation.
def get_hash() -> bytes: ...
# Stub for the ConfigModule method `to_dict`, exposed as a top-level function.
# Takes no arguments and returns a dict with str keys and arbitrary values.
# NOTE(review): presumably maps config names to their current values --
# confirm against the implementation.
def to_dict() -> dict[str, Any]: ...
# Stub for the ConfigModule method `shallow_copy_dict`, exposed as a top-level
# function. Takes no arguments and returns a dict with str keys.
# NOTE(review): the name implies the values are not deep-copied -- confirm
# against the implementation before mutating them.
def shallow_copy_dict() -> dict[str, Any]: ...
# Stub for the ConfigModule method `load_config`, exposed as a top-level
# function. Accepts either bytes or a str-keyed dict and returns None.
# NOTE(review): the accepted types match what save_config() (bytes) and the
# dict-returning functions in this file produce, so this is presumably their
# inverse -- confirm against the implementation.
def load_config(config: bytes | dict[str, Any]) -> None: ...
# Stub for the ConfigModule method `get_config_copy`, exposed as a top-level
# function. Takes no arguments and returns a dict with str keys.
# NOTE(review): presumably an independent copy of the current config --
# confirm the copy depth against the implementation.
def get_config_copy() -> dict[str, Any]: ...
# Stub for the ConfigModule method `get_serializable_config_copy`, exposed as a
# top-level function. Takes no arguments and returns a dict with str keys.
# NOTE(review): presumably like get_config_copy() but limited to entries that
# can be serialized -- confirm against the implementation.
def get_serializable_config_copy() -> dict[str, Any]: ...
# Stub for the module-level `patch` that install_config_module() adds (see the
# note at the top of this file). `arg1` may be a str, a str-keyed dict, or None
# (the default); `arg2` is any value (default None); arbitrary keyword
# arguments are also accepted.
# The return type and **kwargs are unannotated, so type checkers treat them as
# Any; presumably why the file carries `# mypy: allow-untyped-defs`.
# NOTE(review): presumably called as patch("name", value), patch({"name": value})
# or patch(name=value), with a result usable as a context manager -- confirm
# against ConfigModule.patch.
def patch(arg1: str | dict[str, Any] | None = None, arg2: Any = None, **kwargs): ...