mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Grab bag of (mostly) typing improvements (#158075)
Collects some scattershot improvements made while attempting to enable training for AOTInductor. Non-typing changes are: 1. Swapping a few custom searches for the output node in an FX graph for calling `graph.output_node()`. 2. Removing two unused parameters from `torch.export._unlift._unlift`. 3. Switching handles to constants in `cpp_wrapper_cpu` to use C++ references for memory efficiency. 4. Cleaning out unused, unexported imports from `torch/export/__init__.py`, and adding one missing export to `__all__`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158075 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
ad2dec1997
commit
22920c9138
@ -1,59 +1,38 @@
|
||||
import builtins
|
||||
import copy
|
||||
import dataclasses
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
import zipfile
|
||||
from collections.abc import Iterator
|
||||
from enum import auto, Enum
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch.fx._compatibility import compatibility
|
||||
from torch.fx.passes.infra.pass_base import PassResult
|
||||
from torch.fx.passes.infra.pass_manager import PassManager
|
||||
from torch.types import FileLike
|
||||
from torch.utils._pytree import (
|
||||
FlattenFunc,
|
||||
FromDumpableContextFn,
|
||||
ToDumpableContextFn,
|
||||
UnflattenFunc,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Import the following modules during type checking to enable code intelligence features,
|
||||
# Do not import unconditionally, as they import sympy and importing sympy is very slow
|
||||
from torch._ops import OpOverload
|
||||
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AdditionalInputs",
|
||||
"Constraint",
|
||||
"Dim",
|
||||
"ExportBackwardSignature",
|
||||
"ExportGraphSignature",
|
||||
"ExportedProgram",
|
||||
"CustomDecompTable",
|
||||
"default_decompositions",
|
||||
"Dim",
|
||||
"dims",
|
||||
"draft_export",
|
||||
"export_for_training",
|
||||
"export",
|
||||
"ExportBackwardSignature",
|
||||
"ExportedProgram",
|
||||
"ExportGraphSignature",
|
||||
"FlatArgsAdapter",
|
||||
"load",
|
||||
"ModuleCallEntry",
|
||||
"ModuleCallSignature",
|
||||
"default_decompositions",
|
||||
"dims",
|
||||
"export",
|
||||
"export_for_training",
|
||||
"load",
|
||||
"register_dataclass",
|
||||
"save",
|
||||
"ShapesCollection",
|
||||
"unflatten",
|
||||
"FlatArgsAdapter",
|
||||
"UnflattenedModule",
|
||||
"AdditionalInputs",
|
||||
"draft_export",
|
||||
]
|
||||
|
||||
# To make sure export specific custom ops are loaded
|
||||
@ -82,9 +61,9 @@ PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
|
||||
def export_for_training(
|
||||
mod: torch.nn.Module,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
kwargs: Optional[Mapping[str, Any]] = None,
|
||||
*,
|
||||
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
|
||||
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None,
|
||||
strict: bool = False,
|
||||
preserve_module_call_signature: tuple[str, ...] = (),
|
||||
) -> ExportedProgram:
|
||||
@ -181,9 +160,9 @@ def export_for_training(
|
||||
def export(
|
||||
mod: torch.nn.Module,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
kwargs: Optional[Mapping[str, Any]] = None,
|
||||
*,
|
||||
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
|
||||
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None,
|
||||
strict: bool = False,
|
||||
preserve_module_call_signature: tuple[str, ...] = (),
|
||||
) -> ExportedProgram:
|
||||
@ -540,9 +519,9 @@ def load(
|
||||
def draft_export(
|
||||
mod: torch.nn.Module,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
kwargs: Optional[Mapping[str, Any]] = None,
|
||||
*,
|
||||
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
|
||||
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None,
|
||||
preserve_module_call_signature: tuple[str, ...] = (),
|
||||
strict: bool = False,
|
||||
) -> ExportedProgram:
|
||||
|
Reference in New Issue
Block a user