Grab bag of (mostly) typing improvements (#158075)

Collects some scattershot improvements made while attempting to enable training for AOTInductor. Non-typing changes are:

1. Replacing a few hand-rolled searches for the output node of an FX graph with calls to `graph.output_node()` (a minimal sketch follows this list).
2. Removing two unused parameters from `torch.export._unlift._unlift`.
3. Switching handles to constants in `cpp_wrapper_cpu` to use C++ references for memory efficiency.
4. Cleaning out unused, unexported imports from `torch/export/__init__.py`, and adding one missing export to `__all__`.
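
A minimal sketch of the substitution in item 1 (the toy module is illustrative, not code from this PR):

```python
import torch
from torch import fx


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = fx.symbolic_trace(M())

# Before: a hand-rolled scan over the graph for its output node.
output = next(n for n in gm.graph.nodes if n.op == "output")

# After: the FX Graph API already provides this lookup.
output = gm.graph.output_node()
```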

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158075
Approved by: https://github.com/Skylion007
Benjamin Glass
2025-07-21 15:42:02 +00:00
committed by PyTorch MergeBot
parent ad2dec1997
commit 22920c9138
13 changed files with 126 additions and 114 deletions


@@ -1,59 +1,38 @@
-import builtins
-import copy
-import dataclasses
-import inspect
-import os
-import sys
-import typing
-import warnings
-import zipfile
-from collections.abc import Iterator
-from enum import auto, Enum
-from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from collections.abc import Mapping
+from typing import Any, Callable, Optional, Union
 from typing_extensions import deprecated
 import torch
 import torch.utils._pytree as pytree
 from torch.fx._compatibility import compatibility
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.fx.passes.infra.pass_manager import PassManager
 from torch.types import FileLike
-from torch.utils._pytree import (
-    FlattenFunc,
-    FromDumpableContextFn,
-    ToDumpableContextFn,
-    UnflattenFunc,
-)
-if TYPE_CHECKING:
-    # Import the following modules during type checking to enable code intelligence features,
-    # Do not import unconditionally, as they import sympy and importing sympy is very slow
-    from torch._ops import OpOverload
-    from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
 __all__ = [
+    "AdditionalInputs",
     "Constraint",
-    "Dim",
-    "ExportBackwardSignature",
-    "ExportGraphSignature",
-    "ExportedProgram",
     "CustomDecompTable",
     "default_decompositions",
+    "Dim",
     "dims",
+    "draft_export",
+    "export_for_training",
+    "export",
+    "ExportBackwardSignature",
+    "ExportedProgram",
+    "ExportGraphSignature",
+    "FlatArgsAdapter",
+    "load",
     "ModuleCallEntry",
     "ModuleCallSignature",
-    "default_decompositions",
-    "dims",
-    "export",
-    "export_for_training",
-    "load",
     "register_dataclass",
     "save",
     "ShapesCollection",
     "unflatten",
-    "FlatArgsAdapter",
     "UnflattenedModule",
-    "AdditionalInputs",
-    "draft_export",
 ]
 # To make sure export specific custom ops are loaded
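
A quick way to sanity-check an `__all__` cleanup like the one above (a suggested check, not part of the PR): every listed name should resolve on the module and appear exactly once.

```python
import torch.export

# Suggested check, not from the PR: the export list should be
# duplicate-free, and every listed name should actually exist.
names = torch.export.__all__
assert len(names) == len(set(names)), "duplicate entry in __all__"
missing = [n for n in names if not hasattr(torch.export, n)]
assert not missing, f"unexported names: {missing}"
```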
@@ -82,9 +61,9 @@ PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
 def export_for_training(
     mod: torch.nn.Module,
     args: tuple[Any, ...],
-    kwargs: Optional[dict[str, Any]] = None,
+    kwargs: Optional[Mapping[str, Any]] = None,
     *,
-    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None,
     strict: bool = False,
     preserve_module_call_signature: tuple[str, ...] = (),
 ) -> ExportedProgram:
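
Loosening `kwargs` from `dict` to `Mapping` is safe because these functions only read the argument; under the new annotation, read-only mappings type-check as well. A small illustration (the function below is hypothetical, not PyTorch's):

```python
from collections.abc import Mapping
from types import MappingProxyType
from typing import Any, Optional


def call_with(kwargs: Optional[Mapping[str, Any]] = None) -> dict[str, Any]:
    # The mapping is only read, so any Mapping implementation is acceptable.
    return dict(kwargs or {})


call_with({"scale": 2.0})                    # a plain dict still works
call_with(MappingProxyType({"scale": 2.0}))  # read-only views now type-check too
```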
@@ -181,9 +160,9 @@ def export_for_training(
 def export(
     mod: torch.nn.Module,
     args: tuple[Any, ...],
-    kwargs: Optional[dict[str, Any]] = None,
+    kwargs: Optional[Mapping[str, Any]] = None,
     *,
-    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None,
     strict: bool = False,
     preserve_module_call_signature: tuple[str, ...] = (),
 ) -> ExportedProgram:
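
The `tuple[Any]` → `tuple[Any, ...]` change fixes a real annotation bug: `tuple[Any]` denotes a tuple of exactly one element, while `tuple[Any, ...]` denotes a tuple of any length. For example:

```python
from typing import Any

one: tuple[Any] = ("batch",)                 # exactly one element: OK
many: tuple[Any, ...] = ("batch", "seq", 8)  # any length: OK

# Under the old annotation, a multi-element dynamic_shapes tuple such as
# ("batch", "seq", 8) would be rejected by a static type checker.
```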
@@ -540,9 +519,9 @@ def load(
 def draft_export(
     mod: torch.nn.Module,
     args: tuple[Any, ...],
-    kwargs: Optional[dict[str, Any]] = None,
+    kwargs: Optional[Mapping[str, Any]] = None,
     *,
-    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None,
     preserve_module_call_signature: tuple[str, ...] = (),
     strict: bool = False,
 ) -> ExportedProgram: