mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
PEP585 update - torch/export (#145165)
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145165 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
316808e4e9
commit
b6c5562c1f
@ -8,12 +8,12 @@ import sys
|
||||
import typing
|
||||
import warnings
|
||||
import zipfile
|
||||
from collections.abc import Iterator
|
||||
from enum import auto, Enum
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
@ -82,12 +82,12 @@ PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
|
||||
|
||||
def export_for_training(
|
||||
mod: torch.nn.Module,
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
*,
|
||||
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
|
||||
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
|
||||
strict: bool = True,
|
||||
preserve_module_call_signature: Tuple[str, ...] = (),
|
||||
preserve_module_call_signature: tuple[str, ...] = (),
|
||||
) -> ExportedProgram:
|
||||
"""
|
||||
:func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing
|
||||
@ -177,13 +177,13 @@ def export_for_training(
|
||||
|
||||
def export_for_inference(
|
||||
mod: torch.nn.Module,
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
*,
|
||||
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
|
||||
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
|
||||
strict: bool = True,
|
||||
preserve_module_call_signature: Tuple[str, ...] = (),
|
||||
decomp_table: Optional[Dict["OpOverload", Optional[Callable]]] = None,
|
||||
preserve_module_call_signature: tuple[str, ...] = (),
|
||||
decomp_table: Optional[dict["OpOverload", Optional[Callable]]] = None,
|
||||
) -> ExportedProgram:
|
||||
"""
|
||||
:func:`export_for_inference` takes any nn.Module along with example inputs, and produces a traced graph representing
|
||||
@ -262,12 +262,12 @@ def export_for_inference(
|
||||
|
||||
def export(
|
||||
mod: torch.nn.Module,
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
*,
|
||||
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
|
||||
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
|
||||
strict: bool = True,
|
||||
preserve_module_call_signature: Tuple[str, ...] = (),
|
||||
preserve_module_call_signature: tuple[str, ...] = (),
|
||||
) -> ExportedProgram:
|
||||
"""
|
||||
:func:`export` takes any nn.Module along with example inputs, and produces a traced graph representing
|
||||
@ -383,8 +383,8 @@ def save(
|
||||
ep: ExportedProgram,
|
||||
f: Union[str, os.PathLike, io.BytesIO],
|
||||
*,
|
||||
extra_files: Optional[Dict[str, Any]] = None,
|
||||
opset_version: Optional[Dict[str, int]] = None,
|
||||
extra_files: Optional[dict[str, Any]] = None,
|
||||
opset_version: Optional[dict[str, int]] = None,
|
||||
pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
|
||||
) -> None:
|
||||
"""
|
||||
@ -466,8 +466,8 @@ def save(
|
||||
def load(
|
||||
f: Union[str, os.PathLike, io.BytesIO],
|
||||
*,
|
||||
extra_files: Optional[Dict[str, Any]] = None,
|
||||
expected_opset_version: Optional[Dict[str, int]] = None,
|
||||
extra_files: Optional[dict[str, Any]] = None,
|
||||
expected_opset_version: Optional[dict[str, int]] = None,
|
||||
) -> ExportedProgram:
|
||||
"""
|
||||
|
||||
@ -577,7 +577,7 @@ def load(
|
||||
|
||||
|
||||
def register_dataclass(
|
||||
cls: Type[Any],
|
||||
cls: type[Any],
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
) -> None:
|
||||
|
Reference in New Issue
Block a user