PEP585 update - torch/export (#145165)

See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145165
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Aaron Orenstein
2025-01-18 15:03:05 -08:00
committed by PyTorch MergeBot
parent 316808e4e9
commit b6c5562c1f
14 changed files with 257 additions and 269 deletions

View File

@ -8,12 +8,12 @@ import sys
import typing
import warnings
import zipfile
from collections.abc import Iterator
from enum import auto, Enum
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Tuple,
@ -82,12 +82,12 @@ PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
def export_for_training(
mod: torch.nn.Module,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
*,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
strict: bool = True,
preserve_module_call_signature: Tuple[str, ...] = (),
preserve_module_call_signature: tuple[str, ...] = (),
) -> ExportedProgram:
"""
:func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing
@ -177,13 +177,13 @@ def export_for_training(
def export_for_inference(
mod: torch.nn.Module,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
*,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
strict: bool = True,
preserve_module_call_signature: Tuple[str, ...] = (),
decomp_table: Optional[Dict["OpOverload", Optional[Callable]]] = None,
preserve_module_call_signature: tuple[str, ...] = (),
decomp_table: Optional[dict["OpOverload", Optional[Callable]]] = None,
) -> ExportedProgram:
"""
:func:`export_for_inference` takes any nn.Module along with example inputs, and produces a traced graph representing
@ -262,12 +262,12 @@ def export_for_inference(
def export(
mod: torch.nn.Module,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
*,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
strict: bool = True,
preserve_module_call_signature: Tuple[str, ...] = (),
preserve_module_call_signature: tuple[str, ...] = (),
) -> ExportedProgram:
"""
:func:`export` takes any nn.Module along with example inputs, and produces a traced graph representing
@ -383,8 +383,8 @@ def save(
ep: ExportedProgram,
f: Union[str, os.PathLike, io.BytesIO],
*,
extra_files: Optional[Dict[str, Any]] = None,
opset_version: Optional[Dict[str, int]] = None,
extra_files: Optional[dict[str, Any]] = None,
opset_version: Optional[dict[str, int]] = None,
pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
) -> None:
"""
@ -466,8 +466,8 @@ def save(
def load(
f: Union[str, os.PathLike, io.BytesIO],
*,
extra_files: Optional[Dict[str, Any]] = None,
expected_opset_version: Optional[Dict[str, int]] = None,
extra_files: Optional[dict[str, Any]] = None,
expected_opset_version: Optional[dict[str, int]] = None,
) -> ExportedProgram:
"""
@ -577,7 +577,7 @@ def load(
def register_dataclass(
cls: Type[Any],
cls: type[Any],
*,
serialized_type_name: Optional[str] = None,
) -> None: