PEP585 update - torch/export (#145165)

See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145165
Approved by: https://github.com/bobrenjc93
Aaron Orenstein
2025-01-18 15:03:05 -08:00
committed by PyTorch MergeBot
parent 316808e4e9
commit b6c5562c1f
14 changed files with 257 additions and 269 deletions
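The whole diff follows one mechanical pattern from PEP 585 (Python 3.9+): the typing aliases Dict/List/Tuple/Set/Type become the builtin generics dict/list/tuple/set/type, and ABCs such as Iterator move from typing to collections.abc. A minimal before/after sketch of the pattern (illustrative names, not taken from the diff):

# Before: pre-PEP 585 spelling
from typing import Dict, List

def mean_by_key(scores: Dict[str, List[float]]) -> Dict[str, float]:
    return {k: sum(v) / len(v) for k, v in scores.items()}

# After: builtin generics plus collections.abc, no typing aliases needed
from collections.abc import Iterator

def mean_by_key_585(scores: dict[str, list[float]]) -> dict[str, float]:
    return {k: sum(v) / len(v) for k, v in scores.items()}

def pairs(d: dict[str, int]) -> Iterator[tuple[str, int]]:
    yield from d.items()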

View File

@ -8,12 +8,12 @@ import sys
import typing
import warnings
import zipfile
from collections.abc import Iterator
from enum import auto, Enum
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Tuple,
@ -82,12 +82,12 @@ PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
def export_for_training(
mod: torch.nn.Module,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
*,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
strict: bool = True,
preserve_module_call_signature: Tuple[str, ...] = (),
preserve_module_call_signature: tuple[str, ...] = (),
) -> ExportedProgram:
"""
:func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing
@ -177,13 +177,13 @@ def export_for_training(
def export_for_inference(
mod: torch.nn.Module,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
*,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
strict: bool = True,
preserve_module_call_signature: Tuple[str, ...] = (),
decomp_table: Optional[Dict["OpOverload", Optional[Callable]]] = None,
preserve_module_call_signature: tuple[str, ...] = (),
decomp_table: Optional[dict["OpOverload", Optional[Callable]]] = None,
) -> ExportedProgram:
"""
:func:`export_for_inference` takes any nn.Module along with example inputs, and produces a traced graph representing
@ -262,12 +262,12 @@ def export_for_inference(
def export(
mod: torch.nn.Module,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
*,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
strict: bool = True,
preserve_module_call_signature: Tuple[str, ...] = (),
preserve_module_call_signature: tuple[str, ...] = (),
) -> ExportedProgram:
"""
:func:`export` takes any nn.Module along with example inputs, and produces a traced graph representing
@ -383,8 +383,8 @@ def save(
ep: ExportedProgram,
f: Union[str, os.PathLike, io.BytesIO],
*,
extra_files: Optional[Dict[str, Any]] = None,
opset_version: Optional[Dict[str, int]] = None,
extra_files: Optional[dict[str, Any]] = None,
opset_version: Optional[dict[str, int]] = None,
pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
) -> None:
"""
@ -466,8 +466,8 @@ def save(
def load(
f: Union[str, os.PathLike, io.BytesIO],
*,
extra_files: Optional[Dict[str, Any]] = None,
expected_opset_version: Optional[Dict[str, int]] = None,
extra_files: Optional[dict[str, Any]] = None,
expected_opset_version: Optional[dict[str, int]] = None,
) -> ExportedProgram:
"""
@ -577,7 +577,7 @@ def load(
def register_dataclass(
cls: Type[Any],
cls: type[Any],
*,
serialized_type_name: Optional[str] = None,
) -> None:
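At runtime nothing changes for callers of these entry points; only the annotations are modernized. A minimal usage sketch of the updated export signature (the toy module and shapes are invented for illustration):

import torch
from torch.export import Dim, export

class Double(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

batch = Dim("batch")  # symbolic size for the first dimension
ep = export(
    Double(),
    args=(torch.randn(4, 8),),         # args: tuple[Any, ...]
    kwargs=None,                       # kwargs: Optional[dict[str, Any]]
    dynamic_shapes={"x": {0: batch}},  # the dict[str, Any] arm of the Union
)
print(type(ep).__name__)  # ExportedProgram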

View File

@ -2,7 +2,7 @@ import inspect
import logging
import os
from enum import IntEnum
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union
import torch
import torch._logging._internal
@ -26,7 +26,7 @@ class FailureType(IntEnum):
return self.name
def prettify_stack(stack: List[Dict[str, str]], str_to_filename: Dict[str, str]) -> str:
def prettify_stack(stack: list[dict[str, str]], str_to_filename: dict[str, str]) -> str:
res = ""
for frame in stack:
if frame["filename"] not in str_to_filename:
@ -38,8 +38,8 @@ def prettify_stack(stack: List[Dict[str, str]], str_to_filename: Dict[str, str])
def filter_stack(
stack: List[Dict[str, str]], str_to_filename: Dict[str, str]
) -> List[Dict[str, str]]:
stack: list[dict[str, str]], str_to_filename: dict[str, str]
) -> list[dict[str, str]]:
for i, s in enumerate(reversed(stack)):
s["filename"] = str(s["filename"])
if s["filename"] not in str_to_filename:
@ -50,22 +50,22 @@ def filter_stack(
return stack[-3:]
def hash_stack(stack: List[Dict[str, str]]) -> str:
def hash_stack(stack: list[dict[str, str]]) -> str:
return ";".join(f'line: {s["line"]} filename: {s["filename"]}' for s in stack)
class FailureReport:
def __init__(
self, failure_type: FailureType, data: Dict[str, Any], xfail: bool = False
self, failure_type: FailureType, data: dict[str, Any], xfail: bool = False
) -> None:
self.failure_type: FailureType = failure_type
self.data: Dict[str, Any] = data
self.data: dict[str, Any] = data
self.xfail: bool = xfail
def __repr__(self) -> str:
return f"FailureReport(failure_type={self.failure_type}, xfail={self.xfail}, data={self.data})"
def print(self, str_to_filename: Dict[str, str]) -> str:
def print(self, str_to_filename: dict[str, str]) -> str:
if self.failure_type == FailureType.MISSING_FAKE_KERNEL:
op = self.data["op"]
@ -113,8 +113,8 @@ class FailureReport:
class DraftExportReport:
def __init__(self, failures: List[FailureReport], str_to_filename: Dict[str, str]):
self.failures: List[FailureReport] = failures
def __init__(self, failures: list[FailureReport], str_to_filename: dict[str, str]):
self.failures: list[FailureReport] = failures
self.str_to_filename = str_to_filename
def successful(self) -> bool:
@ -156,10 +156,10 @@ Please follow the instructions to fix the errors.
class CaptureStructuredTrace(logging.Handler):
def __init__(self, specific_log_keys: List[str]):
def __init__(self, specific_log_keys: list[str]):
super().__init__()
self.specific_log_keys = specific_log_keys
self.logs: List[Tuple[str, Dict[str, Any]]] = []
self.logs: list[tuple[str, dict[str, Any]]] = []
self.logger = logging.getLogger("torch.__trace")
self.prev_get_dtrace = False
@ -185,14 +185,14 @@ class CaptureStructuredTrace(logging.Handler):
def draft_export(
mod: torch.nn.Module,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
*,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
preserve_module_call_signature: Tuple[str, ...] = (),
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
preserve_module_call_signature: tuple[str, ...] = (),
strict: bool = False,
pre_dispatch: bool = False,
) -> Tuple[ExportedProgram, DraftExportReport]:
) -> tuple[ExportedProgram, DraftExportReport]:
kwargs = kwargs or {}
dynamic_shapes = dynamic_shapes or {}
@ -234,15 +234,15 @@ def draft_export(
preserve_module_call_signature=preserve_module_call_signature,
)
str_to_filename: Dict[str, str] = {
str_to_filename: dict[str, str] = {
str(v): k for (k, v) in torch._logging.structured.INTERN_TABLE.items()
}
failures: List[FailureReport] = []
custom_ops_logs: Dict[
Any, Tuple[Dict[str, Any], FailureType]
failures: list[FailureReport] = []
custom_ops_logs: dict[
Any, tuple[dict[str, Any], FailureType]
] = {} # Dedup custom ops
data_dependent_logs: Dict[
str, Dict[str, Any]
data_dependent_logs: dict[
str, dict[str, Any]
] = {} # Dedup data dependent errors based on stacktrace
for log_name, log_contents in capture_structured_log.logs:
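For context, hash_stack above is the dedup key: a stack (a list[dict[str, str]] of frames) collapses to one string, and later errors with the same key are dropped. A standalone sketch of that idea (frame contents invented):

def stack_key(stack: list[dict[str, str]]) -> str:
    # same shape as hash_stack: one fragment per frame, joined with ";"
    return ";".join(f'line: {s["line"]} filename: {s["filename"]}' for s in stack)

seen: dict[str, dict[str, str]] = {}  # mirrors data_dependent_logs
frames = [{"line": "12", "filename": "7"}, {"line": "34", "filename": "7"}]
seen.setdefault(stack_key(frames), {"reason": "data-dependent guard"})
seen.setdefault(stack_key(frames), {"reason": "duplicate, ignored"})
assert len(seen) == 1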

View File

@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
import operator
from typing import List
import torch
from torch._higher_order_ops.effects import _get_schema, with_effects
@ -22,7 +21,7 @@ def _remove_effect_tokens_from_graph_helper(
inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs
output_node = None
with_effect_nodes: List[torch.fx.Node] = []
with_effect_nodes: list[torch.fx.Node] = []
# Output node needs to check its args against output_token_names (collected from output_spec)
# Therefore, we only need to find the top-level output node
@ -127,8 +126,8 @@ def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
This function does an inplace modification on the given ExportedProgram.
"""
num_tokens: int = 0
input_token_names: List[str] = []
new_input_specs: List[InputSpec] = []
input_token_names: list[str] = []
new_input_specs: list[InputSpec] = []
for inp in ep.graph_signature.input_specs:
if inp.kind == InputKind.TOKEN:
num_tokens += 1
@ -138,8 +137,8 @@ def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
new_input_specs.append(inp)
num_out_tokens: int = 0
new_output_specs: List[OutputSpec] = []
output_token_names: List[OutputSpec] = []
new_output_specs: list[OutputSpec] = []
output_token_names: list[OutputSpec] = []
for out in ep.graph_signature.output_specs:
if out.kind == OutputKind.TOKEN:
num_out_tokens += 1

View File

@ -2,7 +2,7 @@ import logging
import operator
import types
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple
from typing import Optional
import torch
import torch.fx._pytree as fx_pytree
@ -19,7 +19,7 @@ from torch.fx.passes.utils.fuser_utils import erase_nodes, fuse_as_graphmodule
log = logging.getLogger(__name__)
def _get_getitem_users(node: torch.fx.Node) -> Set[torch.fx.Node]:
def _get_getitem_users(node: torch.fx.Node) -> set[torch.fx.Node]:
node_users = list(node.users.keys())
getitem_users = set()
for user in node_users:
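Judging by its name and the set it builds, the helper collects the operator.getitem users of a node; a self-contained sketch of that filter on a hand-built FX graph (assuming that reading of the truncated loop):

import operator

import torch

def getitem_users(node: torch.fx.Node) -> set[torch.fx.Node]:
    return {
        u
        for u in node.users
        if u.op == "call_function" and u.target == operator.getitem
    }

g = torch.fx.Graph()
x = g.placeholder("x")
item = g.call_function(operator.getitem, (x, 0))  # registers item as a user of x
assert getitem_users(x) == {item}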
@ -172,9 +172,9 @@ def _remove_extraneous_pytrees(gm: torch.fx.GraphModule) -> None:
def _construct_inputs(
gm: torch.fx.GraphModule,
signature: ModuleCallSignature,
node_name_map: Dict[str, torch.fx.Node],
) -> Tuple[List[torch.fx.Node], Dict[str, torch.fx.Node]]:
tree_unflatten_args: List[Optional[torch.fx.Node]] = []
node_name_map: dict[str, torch.fx.Node],
) -> tuple[list[torch.fx.Node], dict[str, torch.fx.Node]]:
tree_unflatten_args: list[Optional[torch.fx.Node]] = []
for input_ in signature.inputs:
if isinstance(input_, ConstantArgument) and input_.value is None:
# Constants should be directly embedded into the graph and not used
@ -213,8 +213,8 @@ def _construct_inputs(
def _insert_call_module(
gm: torch.fx.GraphModule,
args_nodes: List[torch.fx.Node],
kwargs_nodes: Dict[str, torch.fx.Node],
args_nodes: list[torch.fx.Node],
kwargs_nodes: dict[str, torch.fx.Node],
module_to_swap: torch.nn.Module,
name: str,
) -> torch.fx.Node:
@ -229,8 +229,8 @@ def _deconstruct_outputs(
gm: torch.fx.GraphModule,
signature: ModuleCallSignature,
module_node: torch.fx.Node,
node_name_map: Dict[str, torch.fx.Node],
orig_outputs: Tuple[torch.fx.Node, ...],
node_name_map: dict[str, torch.fx.Node],
orig_outputs: tuple[torch.fx.Node, ...],
) -> None:
from .unflatten import _generate_flatten_spec
@ -246,17 +246,17 @@ def _deconstruct_outputs(
def _swap_module_helper(
gm: torch.fx.GraphModule,
modules_to_swap: Dict[str, torch.nn.Module],
module_call_graph: Dict[str, ModuleCallSignature],
modules_to_swap: dict[str, torch.nn.Module],
module_call_graph: dict[str, ModuleCallSignature],
) -> torch.fx.GraphModule:
log.debug("Starting graph:")
log.debug(gm.graph)
legalize_graph(gm)
partitions: Dict[str, NodeList] = defaultdict(list)
partitions: dict[str, NodeList] = defaultdict(list)
node_name_map: Dict[str, torch.fx.Node] = {
node_name_map: dict[str, torch.fx.Node] = {
node.name: node for node in gm.graph.nodes
}
@ -399,7 +399,7 @@ def _fix_input_output_signature(
def _swap_modules(
ep: ExportedProgram, modules_to_swap: Dict[str, torch.nn.Module]
ep: ExportedProgram, modules_to_swap: dict[str, torch.nn.Module]
) -> torch.fx.GraphModule:
"""
Unlifts the given ExportedProgram into a fx.GraphModule, and then swaps

View File

@ -10,7 +10,7 @@ import time
import types
import warnings
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Optional, Union
import torch
import torch._dynamo
@ -109,7 +109,7 @@ class ExportDynamoConfig:
"""
allow_rnn: bool = True
reorderable_logging_functions: Set[Callable] = dataclasses.field(
reorderable_logging_functions: set[Callable] = dataclasses.field(
default_factory=set
)
# Emit runtime asserts after AOTAutograd instead.
@ -123,7 +123,7 @@ class ExportDynamoConfig:
class ATenExportArtifact:
gm: torch.fx.GraphModule
sig: ExportGraphSignature
constants: Dict[
constants: dict[
str,
Union[
torch.Tensor,
@ -139,7 +139,7 @@ class ExportArtifact:
in_spec: TreeSpec
out_spec: TreeSpec
fake_mode: FakeTensorMode
module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]]
module_call_specs: dict[str, dict[str, pytree.TreeSpec]]
DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig()
@ -223,8 +223,8 @@ def _extract_fake_inputs(gm, args, kwargs):
Also return the fake mode used to fakify those inputs.
"""
fake_inps: List[torch.Tensor] = []
fake_vals: List[torch.Tensor] = []
fake_inps: list[torch.Tensor] = []
fake_vals: list[torch.Tensor] = []
for node in gm.graph.nodes:
if node.op == "placeholder" and "val" in node.meta:
fake_val = node.meta["val"]
@ -336,21 +336,21 @@ def _normalize_nn_module_stack(gm_torch_level, root_cls):
def _get_param_buffer_mapping(
original_module: torch.nn.Module,
traced_module: torch.nn.Module,
) -> Dict[str, str]:
) -> dict[str, str]:
"""
Returns a mapping of parameter/buffer names from the new module to the
original module. This helps restore the FQNs of parameters/buffers
of a traced module to what the original module contains.
"""
param_lookup: Dict[int, str] = {}
buffer_lookup: Dict[int, str] = {}
param_lookup: dict[int, str] = {}
buffer_lookup: dict[int, str] = {}
for name, param in original_module.named_parameters(remove_duplicate=False):
param_lookup[id(param)] = name
for name, buffer in original_module.named_buffers(remove_duplicate=False):
buffer_lookup[id(buffer)] = name
param_buffer_table: Dict[str, str] = {}
param_buffer_table: dict[str, str] = {}
for dynamo_name, dynamo_param in traced_module.named_parameters(
remove_duplicate=False
):
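As the docstring says, _get_param_buffer_mapping matches tensors across the two modules by object identity. A hedged sketch of the id()-keyed lookup for parameters only (the real helper also handles buffers):

import torch

def param_name_map(orig: torch.nn.Module, traced: torch.nn.Module) -> dict[str, str]:
    # index the original parameters by id(), then look traced ones up by identity
    by_id: dict[int, str] = {
        id(p): name for name, p in orig.named_parameters(remove_duplicate=False)
    }
    return {
        name: by_id[id(p)]
        for name, p in traced.named_parameters(remove_duplicate=False)
        if id(p) in by_id
    }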
@ -371,9 +371,9 @@ def _get_param_buffer_mapping(
def _preserve_requires_grad_pass(
gm: torch.fx.GraphModule,
sig: ExportGraphSignature,
fake_params_buffers: Dict[str, torch.Tensor],
constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
flat_fake_args: List[Any],
fake_params_buffers: dict[str, torch.Tensor],
constants: dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
flat_fake_args: list[Any],
):
placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
assert len(sig.input_specs) == len(placeholders)
@ -410,10 +410,10 @@ def _preserve_requires_grad_pass(
def _remap_constants(
orig_constant_attrs: ConstantAttrMap,
graph_signature: ExportGraphSignature,
constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
constants: dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
) -> None:
"""Rewrite the graph signature and constants table to use the FQN from the original module."""
remap_table: Dict[str, List[str]] = {}
remap_table: dict[str, list[str]] = {}
for name, value in constants.items():
if value in orig_constant_attrs:
remap_table[name] = orig_constant_attrs[value]
@ -617,7 +617,7 @@ def _restore_state_dict(
traced_module.recompile()
def _get_module_hierarchy(mod: torch.nn.Module) -> Dict[str, str]:
def _get_module_hierarchy(mod: torch.nn.Module) -> dict[str, str]:
return {
name: type(m).__name__ for name, m in mod.named_modules(remove_duplicate=False)
}
@ -626,9 +626,9 @@ def _get_module_hierarchy(mod: torch.nn.Module) -> Dict[str, str]:
def _make_module_call_graph(
in_spec: TreeSpec,
out_spec: TreeSpec,
module_call_signatures: Dict[str, ModuleCallSignature],
forward_arg_names: Optional[List[str]] = None,
) -> List[ModuleCallEntry]:
module_call_signatures: dict[str, ModuleCallSignature],
forward_arg_names: Optional[list[str]] = None,
) -> list[ModuleCallEntry]:
original = [
ModuleCallEntry(fqn=fqn, signature=module_call_signatures.get(fqn))
for fqn in _EXPORT_MODULE_HIERARCHY # type: ignore[union-attr]
@ -651,11 +651,11 @@ def _make_module_call_graph(
def _export_to_torch_ir(
f: Callable,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
*,
preserve_module_call_signature: Tuple[str, ...] = (),
preserve_module_call_signature: tuple[str, ...] = (),
disable_constraint_solver: bool = False,
allow_complex_guards_as_runtime_asserts: bool = False,
restore_fqn: bool = True,
@ -681,7 +681,7 @@ def _export_to_torch_ir(
_check_dynamic_shapes(combined_args, dynamic_shapes)
with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
try:
module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = {}
ctx = nullcontext()
if not isinstance(f, torch.fx.GraphModule):
ctx = _wrap_submodules( # type: ignore[assignment]
@ -814,9 +814,9 @@ def _export_to_aten_ir(
def _get_forward_arg_names(
mod: torch.nn.Module,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
) -> List[str]:
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
) -> list[str]:
"""
Gets the names of the forward arguments that are used, so the
original signature can be restored when unlifting the exported program module.
@ -829,7 +829,7 @@ def _get_forward_arg_names(
sig = inspect.signature(mod.forward)
_args = sig.bind_partial(*args).arguments
names: List[str] = []
names: list[str] = []
for name, value in _args.items():
# handle variable number of positional args
if sig.parameters[name].kind == inspect._ParameterKind.VAR_POSITIONAL:
@ -843,7 +843,7 @@ def _get_forward_arg_names(
return names
def _get_non_persistent_buffers(mod: torch.nn.Module) -> Set[str]:
def _get_non_persistent_buffers(mod: torch.nn.Module) -> set[str]:
"""
Returns set of non-persistent buffers in a module and its submodules.
"""
@ -855,10 +855,10 @@ def _get_non_persistent_buffers(mod: torch.nn.Module) -> Set[str]:
def _rewrite_dynamo_tensor_constants(
orig_mod_buffers: Set[torch.Tensor],
traced_mod_buffers: Dict[str, torch.Tensor],
orig_mod_buffers: set[torch.Tensor],
traced_mod_buffers: dict[str, torch.Tensor],
graph_signature: ExportGraphSignature,
constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
constants: dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
) -> None:
"""
Dynamo erroneously marks tensor attributes on modules as buffers.
@ -879,7 +879,7 @@ def _rewrite_dynamo_tensor_constants(
def _move_non_persistent_buffers_to_tensor_constants(
orig_mod: torch.nn.Module,
graph_signature: ExportGraphSignature,
constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
constants: dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
) -> None:
"""
Moves non-persistent buffers to tensor constants.
@ -984,7 +984,7 @@ def _verify_placeholder_names(
)
def get_ep_stats(ep: ExportedProgram) -> Dict[str, Any]:
def get_ep_stats(ep: ExportedProgram) -> dict[str, Any]:
op_count = 0
op_set = set()
for m in ep.graph_module.modules():
@ -1000,8 +1000,8 @@ def get_ep_stats(ep: ExportedProgram) -> Dict[str, Any]:
return {"op_count": op_count, "op_set": op_set}
_EXPORT_FLAGS: Optional[Set[str]] = None
_EXPORT_MODULE_HIERARCHY: Optional[Dict[str, str]] = None
_EXPORT_FLAGS: Optional[set[str]] = None
_EXPORT_MODULE_HIERARCHY: Optional[dict[str, str]] = None
def _log_export_wrapper(fn):
@ -1065,7 +1065,7 @@ def _process_jit_trace_inputs_for_export(example_inputs, example_kwarg_inputs):
return example_inputs, example_kwarg_inputs
def _get_original_state_dict(mod: torch.nn.Module) -> Dict[str, Any]:
def _get_original_state_dict(mod: torch.nn.Module) -> dict[str, Any]:
# Explicitly not calling mod.state_dict() as we do not want the module state for serialization
# but the running module state so we can always match by id() the entries here with the graph inputs
named_parameters = dict(mod.named_parameters(remove_duplicate=False))
@ -1096,24 +1096,24 @@ def _process_export_inputs(mod, args, kwargs, dynamic_shapes):
def _get_module_call_graph(
export_artifact: ExportArtifact,
preserve_module_call_signature: Tuple[str, ...],
preserve_module_call_signature: tuple[str, ...],
strict_mode_export: bool,
forward_arg_names: Optional[List[str]] = None,
) -> Tuple[torch.fx.GraphModule, List[ModuleCallEntry]]:
forward_arg_names: Optional[list[str]] = None,
) -> tuple[torch.fx.GraphModule, list[ModuleCallEntry]]:
"""
Modifies the graph module in export_artifact in place, removes _export_tracepoint nodes, and
returns module_call_graph.
"""
gm: torch.fx.GraphModule = export_artifact.aten.gm
export_graph_signature: ExportGraphSignature = export_artifact.aten.sig
module_call_specs: Dict[
str, Dict[str, TreeSpec]
module_call_specs: dict[
str, dict[str, TreeSpec]
] = export_artifact.module_call_specs
in_spec: TreeSpec = export_artifact.in_spec
out_spec: TreeSpec = export_artifact.out_spec
# Make module signatures.
module_call_signatures: Dict[str, ModuleCallSignature] = {}
module_call_signatures: dict[str, ModuleCallSignature] = {}
for fqn, specs in module_call_specs.items():
mod_fqn = _strip_root(fqn) if not strict_mode_export else fqn
module_call_signatures[mod_fqn] = ModuleCallSignature(
@ -1142,7 +1142,7 @@ def _get_module_call_graph(
def _get_range_constraints(
export_artifact: ExportArtifact, combined_args: Dict[str, Any], dynamic_shapes
export_artifact: ExportArtifact, combined_args: dict[str, Any], dynamic_shapes
):
gm: torch.fx.GraphModule = export_artifact.aten.gm
export_graph_signature: ExportGraphSignature = export_artifact.aten.sig
@ -1250,12 +1250,12 @@ def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None):
def _strict_export(
mod: torch.nn.Module,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]],
preserve_module_call_signature: Tuple[str, ...],
args: tuple[Any, ...],
kwargs: dict[str, Any],
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]],
preserve_module_call_signature: tuple[str, ...],
pre_dispatch: bool,
original_state_dict: Dict[str, Any],
original_state_dict: dict[str, Any],
orig_in_spec: TreeSpec,
allow_complex_guards_as_runtime_asserts: bool,
_is_torch_jit_trace: bool,
@ -1278,12 +1278,12 @@ def _strict_export(
def _strict_export_lower_to_aten_ir(
mod: torch.nn.Module,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]],
preserve_module_call_signature: Tuple[str, ...],
args: tuple[Any, ...],
kwargs: dict[str, Any],
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]],
preserve_module_call_signature: tuple[str, ...],
pre_dispatch: bool,
original_state_dict: Dict[str, Any],
original_state_dict: dict[str, Any],
orig_in_spec: TreeSpec,
allow_complex_guards_as_runtime_asserts: bool,
_is_torch_jit_trace: bool,
@ -1364,7 +1364,7 @@ def _strict_export_lower_to_aten_ir(
# params_buffers_to_node_meta = _collect_param_buffer_metadata(gm_torch_level)
constant_attrs = _gather_constant_attrs(mod)
param_buffer_table: Dict[str, str] = _get_param_buffer_mapping(mod, gm_torch_level)
param_buffer_table: dict[str, str] = _get_param_buffer_mapping(mod, gm_torch_level)
# Dynamo does not track which buffers were registered as non-persistent. This info
# is available in the original module, so we transfer it to the traced module. Also,
@ -1459,7 +1459,7 @@ def _export_to_aten_ir_make_fx(
mod, params_spec, params_len, store_orig_mod=True
)
params_buffers_args: List[Any] = []
params_buffers_args: list[Any] = []
params_buffers_args.extend(params_and_buffers_flat)
params_buffers_args.extend(args)
@ -1480,7 +1480,7 @@ def _export_to_aten_ir_make_fx(
# For any buffer that is assigned, we want to associate it to the final proxy node
# that it is assigned to. This node can then be copied into the buffer.
assigned_buffers: Dict[str, str] = {}
assigned_buffers: dict[str, str] = {}
hook = register_buffer_assignment_hook(
non_strict_root, assigned_buffers
)
@ -1532,12 +1532,12 @@ def _export_to_aten_ir_make_fx(
# Dictionary that tracks subclass type to original getattr function
# and the attributes we can proxy.
tensor_type_to_old_getattribute: Dict[
Type[torch.Tensor], Tuple[Callable, Set[str]]
tensor_type_to_old_getattribute: dict[
type[torch.Tensor], tuple[Callable, set[str]]
] = {}
for arg in args:
subclass_types_to_instances: Dict[
Type[torch.Tensor], List[Type[torch.Tensor]]
subclass_types_to_instances: dict[
type[torch.Tensor], list[type[torch.Tensor]]
] = get_subclass_typing_container(arg)
for subclass_type in subclass_types_to_instances:
if subclass_type not in tensor_type_to_old_getattribute:
@ -1725,12 +1725,12 @@ def _find_node(gm: torch.fx.GraphModule, name: str) -> torch.fx.Node:
def _non_strict_export(
mod: torch.nn.Module,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]],
preserve_module_call_signature: Tuple[str, ...],
args: tuple[Any, ...],
kwargs: dict[str, Any],
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]],
preserve_module_call_signature: tuple[str, ...],
pre_dispatch: bool,
original_state_dict: Dict[str, Any],
original_state_dict: dict[str, Any],
orig_in_spec: TreeSpec,
allow_complex_guards_as_runtime_asserts: bool,
_is_torch_jit_trace: bool,
@ -1744,7 +1744,7 @@ def _non_strict_export(
out_spec: Optional[TreeSpec] = None
in_spec: Optional[TreeSpec] = None
module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = {}
def _tuplify_outputs(aot_export):
def _aot_export_non_strict(mod, args, kwargs=None, **flags):
@ -1896,12 +1896,12 @@ def _non_strict_export(
@_disable_prexisiting_fake_mode
def _export_for_training(
mod: torch.nn.Module,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
*,
strict: bool = True,
preserve_module_call_signature: Tuple[str, ...] = (),
preserve_module_call_signature: tuple[str, ...] = (),
) -> ExportedProgram:
global _EXPORT_MODULE_HIERARCHY
_EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)
@ -1987,12 +1987,12 @@ def _export_for_training(
@_disable_prexisiting_fake_mode
def _export(
mod: torch.nn.Module,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
*,
strict: bool = True,
preserve_module_call_signature: Tuple[str, ...] = (),
preserve_module_call_signature: tuple[str, ...] = (),
pre_dispatch: bool = False,
allow_complex_guards_as_runtime_asserts: bool = False,
_is_torch_jit_trace: bool = False,

View File

@ -1,9 +1,9 @@
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Optional
from torch.utils._pytree import Context, TreeSpec
def reorder_kwargs(user_kwargs: Dict[str, Any], spec: TreeSpec) -> Dict[str, Any]:
def reorder_kwargs(user_kwargs: dict[str, Any], spec: TreeSpec) -> dict[str, Any]:
"""Reorder user-provided kwargs to match the order in `spec`. `spec` is
expected to be the in_spec of an exported program, i.e. the spec that
results from flattening `(args, kwargs)`.
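The same idea with the target order given as a plain list instead of being read out of a TreeSpec (reorder_by is a hypothetical stand-in for reorder_kwargs):

def reorder_by(user_kwargs: dict[str, object], order: list[str]) -> dict[str, object]:
    # dicts preserve insertion order, so rebuilding in `order` is enough
    return {k: user_kwargs[k] for k in order}

assert list(reorder_by({"b": 2, "a": 1}, ["a", "b"])) == ["a", "b"]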

View File

@ -1,8 +1,9 @@
# mypy: allow-untyped-defs
import copy
import warnings
from collections.abc import Sequence
from itertools import chain
from typing import Any, Dict, List, Optional, Sequence, Tuple
from typing import Any, Optional
import torch
import torch.utils._pytree as pytree
@ -25,7 +26,7 @@ from .exported_program import (
)
def _check_inputs_match(args, kwargs, in_spec: pytree.TreeSpec) -> List:
def _check_inputs_match(args, kwargs, in_spec: pytree.TreeSpec) -> list:
reordered_kwargs = reorder_kwargs(kwargs, in_spec)
flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
(args, reordered_kwargs)
@ -61,7 +62,7 @@ def _check_input_constraints_pre_hook(self, args, kwargs):
def _unlift_inputs_as_getattr(
gm: torch.fx.GraphModule,
lifted_inputs: Sequence[Optional[str]],
) -> Tuple[Dict[str, torch.fx.Node], Dict[str, torch.fx.Node]]:
) -> tuple[dict[str, torch.fx.Node], dict[str, torch.fx.Node]]:
"""
Unlift inputs referring to params/buffers/constants as getattr nodes in the
graph
@ -90,8 +91,8 @@ def _unlift_inputs_as_getattr(
def _insert_copy_for_mutations(
gm: torch.fx.GraphModule,
mutated_outputs: Sequence[Optional[str]],
unlifted_name_to_node: Dict[str, torch.fx.Node],
input_name_to_node: Dict[str, torch.fx.Node],
unlifted_name_to_node: dict[str, torch.fx.Node],
input_name_to_node: dict[str, torch.fx.Node],
) -> None:
"""
Find all the buffers and inputs that were mutated and insert copy_
@ -144,7 +145,7 @@ def _insert_copy_for_mutations(
def _get_codegen(
in_spec: pytree.TreeSpec,
out_spec: Optional[pytree.TreeSpec],
forward_arg_names: Optional[List[str]] = None,
forward_arg_names: Optional[list[str]] = None,
) -> _PyTreeCodeGen:
"""
Create the codegen for the graph module based on the in/out specs
@ -180,9 +181,9 @@ def _unlift(
mutated_outputs: Sequence[Optional[str]],
in_spec: pytree.TreeSpec,
out_spec: Optional[pytree.TreeSpec],
state_dict: Dict[str, Any],
constants: Dict[str, Any],
forward_arg_names: Optional[List[str]] = None,
state_dict: dict[str, Any],
constants: dict[str, Any],
forward_arg_names: Optional[list[str]] = None,
):
"""
Args:
@ -214,8 +215,8 @@ def _unlift(
def _register_attrs_to_new_gm(
new_gm: torch.fx.GraphModule,
graph_signature: ExportGraphSignature,
state_dict: Dict[str, Any],
constants: Dict[str, Any],
state_dict: dict[str, Any],
constants: dict[str, Any],
) -> None:
non_persistent_buffers = set(graph_signature.non_persistent_buffers)
for name in graph_signature.buffers:
@ -388,7 +389,7 @@ def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Modu
forward_arg_names = (
sig.forward_arg_names if (sig := ep.module_call_graph[0].signature) else None
)
lifted_inputs: List[Optional[str]] = [
lifted_inputs: list[Optional[str]] = [
(
in_spec.target
if in_spec.kind
@ -403,7 +404,7 @@ def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Modu
for in_spec in ep.graph_signature.input_specs
]
mutated_outputs: List[Optional[str]] = [
mutated_outputs: list[Optional[str]] = [
(
out_spec.target
if out_spec.kind
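The unlifting machinery above replaces lifted placeholder inputs with get_attr nodes. A minimal sketch of that rewrite on a bare FX graph (a single invented placeholder standing in for a lifted parameter):

import torch

g = torch.fx.Graph()
ph = g.placeholder("p_weight")        # a lifted input in the exported graph
with g.inserting_before(ph):
    attr = g.get_attr("weight")       # the unlifted get_attr replacement
ph.replace_all_uses_with(attr)
g.erase_node(ph)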

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Callable, Dict
from typing import Callable
import torch
from torch._export.utils import (
@ -13,7 +13,7 @@ from torch._export.utils import (
__all__ = ["CustomDecompTable"]
class CustomDecompTable(Dict[torch._ops.OperatorBase, Callable]):
class CustomDecompTable(dict[torch._ops.OperatorBase, Callable]):
"""
This is a custom dictionary that is specifically used for handling decomp_table in export.
We need this because, in the new world, you can only *delete* an op from decomp
@ -126,7 +126,7 @@ class CustomDecompTable(Dict[torch._ops.OperatorBase, Callable]):
self._materialize_if_needed()
return self.decomp_table.items()
def materialize(self) -> Dict[torch._ops.OperatorBase, Callable]:
def materialize(self) -> dict[torch._ops.OperatorBase, Callable]:
for op in _collect_all_valid_cia_ops():
if _is_aten_op(op):
continue
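This subclassing is itself a PEP 585 payoff: since Python 3.9 the builtin dict is both subscriptable and usable as a generic base class, so typing.Dict is unnecessary even here. A toy subclass in the same spirit (names hypothetical):

from typing import Callable

class LazyTable(dict[str, Callable[[], int]]):
    def materialize(self) -> dict[str, Callable[[], int]]:
        # flatten back to a plain dict, loosely echoing materialize() above
        return dict(self)

t = LazyTable()
t["answer"] = lambda: 42
assert t.materialize()["answer"]() == 42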

View File

@ -5,7 +5,7 @@ import logging
import sys
from collections import defaultdict
from enum import auto, Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import torch
from torch.utils._pytree import (
@ -242,7 +242,7 @@ Dim.DYNAMIC = _DimHint.DYNAMIC # type: ignore[attr-defined]
def dims(
*names: str, min: Optional[int] = None, max: Optional[int] = None
) -> Tuple[_Dim, ...]:
) -> tuple[_Dim, ...]:
"""
Util to create multiple :func:`Dim` types.
@ -401,13 +401,13 @@ Constraint = Union[_Constraint, _DerivedConstraint, _RelaxedConstraint]
def _process_equalities(
constraint: Constraint,
get_sources: Callable[[int, int], List["Source"]],
get_sources: Callable[[int, int], list["Source"]],
shape_env: "ShapeEnv",
names: Dict[str, Tuple[int, int]],
source_pairs: List[Tuple["Source", "Source"]],
derived_equalities: List[Tuple["Source", Union["Source", "Symbol"], Callable]],
phantom_symbols: Dict[str, "Symbol"],
relaxed_sources: Set["Source"],
names: dict[str, tuple[int, int]],
source_pairs: list[tuple["Source", "Source"]],
derived_equalities: list[tuple["Source", Union["Source", "Symbol"], Callable]],
phantom_symbols: dict[str, "Symbol"],
relaxed_sources: set["Source"],
):
"""
Updates `source_pairs`, `derived_equalities`, and `phantom_symbols` (which become
@ -582,7 +582,7 @@ def _tree_map_with_path(
raise
def _combine_args(f, args, kwargs, _is_torch_jit_trace=False) -> Dict[str, Any]:
def _combine_args(f, args, kwargs, _is_torch_jit_trace=False) -> dict[str, Any]:
# combine args and kwargs following the signature of f, as it happens
# in the body of f when called with *args, **kwargs
if isinstance(f, ExportedProgram):
@ -684,8 +684,8 @@ def _warn_on_None_dynamic_shape_dimension():
def _check_dynamic_shapes(
combined_args: Dict[str, Any],
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
combined_args: dict[str, Any],
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
):
"""
Checks the dynamic_shapes specification for correctness,
@ -698,7 +698,7 @@ def _check_dynamic_shapes(
if isinstance(dynamic_shapes, (tuple, list)):
combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc]
bounds: Dict[str, Tuple[int, int]] = {}
bounds: dict[str, tuple[int, int]] = {}
def check_same_bounds(dim):
if dim.__name__ in bounds:
@ -799,9 +799,9 @@ def _check_dynamic_shapes(
def _process_dynamic_shapes(
combined_args: Dict[str, Any],
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
) -> List[Constraint]:
combined_args: dict[str, Any],
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
) -> list[Constraint]:
"""
Reads the dynamic_shapes specification and produces a list of constraints.
"""
@ -814,12 +814,12 @@ def _process_dynamic_shapes(
combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc]
# map of Dim names representing input shape dimensions to constraints on them
symbols: Dict[str, List[Constraint]] = defaultdict(list)
symbols: dict[str, list[Constraint]] = defaultdict(list)
# track roots that do not directly represent input shape dimensions
phantom_roots: Dict[str, _PhantomRoot] = {}
derived_constraints_with_phantom_root: List[_DerivedConstraint] = []
phantom_roots: dict[str, _PhantomRoot] = {}
derived_constraints_with_phantom_root: list[_DerivedConstraint] = []
# list of constraints to return
constraints: List[Constraint] = []
constraints: list[Constraint] = []
def to_constraint(dim, tensor, i):
import sympy
@ -979,7 +979,7 @@ def _process_dynamic_shapes(
def _get_dim_name_mapping(
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None]
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None]
):
name_to_dim = {}
for dim in tree_flatten(
@ -1002,8 +1002,8 @@ def _get_dim_name_mapping(
def refine_dynamic_shapes_from_suggested_fixes(
msg: str,
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]],
) -> Union[Dict[str, Any], Tuple[Any], List[Any]]:
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any]],
) -> Union[dict[str, Any], tuple[Any], list[Any]]:
"""
When exporting with :func:`dynamic_shapes`, export may fail with a ConstraintViolation error if the specification
doesn't match the constraints inferred from tracing the model. The error message may provide suggested fixes -
@ -1072,7 +1072,7 @@ def refine_dynamic_shapes_from_suggested_fixes(
name_to_dim = _get_dim_name_mapping(dynamic_shapes)
# track derived dim roots
roots: Set[str] = set()
roots: set[str] = set()
for k, c in shape_fixes.items():
assert isinstance(c, (int, _Dim, _DerivedDim, sympy.Expr))
if isinstance(c, sympy.Expr): # check dim/derived dim expression
@ -1087,7 +1087,7 @@ def refine_dynamic_shapes_from_suggested_fixes(
assert k in name_to_dim or k in roots
# cache so we don't produce multiple derived dim objects
derived_dim_cache: Dict[str, _DerivedDim] = {}
derived_dim_cache: dict[str, _DerivedDim] = {}
def apply_fixes(path, dim, dummy):
if dim is None or isinstance(dim, int): # not dynamic
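For reference, dims() shown earlier in this file is the plural Dim constructor and now returns a builtin tuple. A short usage sketch (bounds chosen arbitrarily):

from torch.export import dims

batch, seq = dims("batch", "seq", min=1, max=1024)  # -> tuple[_Dim, ...]
dynamic_shapes = {"x": {0: batch, 1: seq}}  # a dict-form spec for export()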

View File

@ -7,9 +7,7 @@ from torch.export.exported_program import _decompose_exported_program
def _copy_graph_module_and_signature(
ep: torch.fx.GraphModule,
) -> typing.Tuple[
torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature
]:
) -> tuple[torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature]:
# copy.deepcopy lets the objects override __deepcopy__ methods with graph_copy() and node_copy(),
# and this can break placeholder names in some particular cases.
# For example, node copying will avoid Python keywords like 'input', suffixing and renaming to 'input_1'.

View File

@ -8,20 +8,9 @@ import operator
import types
import warnings
from collections import namedtuple
from collections.abc import Iterator
from contextlib import contextmanager
from typing import (
Any,
Callable,
Dict,
final,
Iterator,
List,
Optional,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
from typing import Any, Callable, final, Optional, TYPE_CHECKING, Union
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._library.fake_class_registry import FakeScriptObject
@ -100,11 +89,11 @@ PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
@dataclasses.dataclass
class ModuleCallSignature:
inputs: List[ArgumentSpec]
outputs: List[ArgumentSpec]
inputs: list[ArgumentSpec]
outputs: list[ArgumentSpec]
in_spec: pytree.TreeSpec
out_spec: pytree.TreeSpec
forward_arg_names: Optional[List[str]] = None
forward_arg_names: Optional[list[str]] = None
def replace_all_uses_with(self, original_node, new_node):
for i in self.inputs:
@ -300,8 +289,8 @@ def _override_decomp_aten_to_variants():
def _split_decomp_table_to_cia_and_python_decomp(
decomp_table: Dict[torch._ops.OperatorBase, Callable]
) -> Tuple[Dict[torch._ops.OperatorBase, Callable], ...]:
decomp_table: dict[torch._ops.OperatorBase, Callable]
) -> tuple[dict[torch._ops.OperatorBase, Callable], ...]:
all_preservable_cia_ops = set(_collect_all_valid_cia_ops())
cia_ops_to_callable = {}
@ -355,8 +344,8 @@ def default_decompositions() -> "CustomDecompTable":
def _decompose_and_get_gm_with_new_signature_constants(
ep,
*,
cia_to_decomp: Dict[torch._ops.OperatorBase, Callable],
python_decomp_table: Dict[torch._ops.OperatorBase, Callable],
cia_to_decomp: dict[torch._ops.OperatorBase, Callable],
python_decomp_table: dict[torch._ops.OperatorBase, Callable],
joint_loss_index: Optional[int],
decompose_custom_triton_ops,
):
@ -743,7 +732,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
def _remove_unneccessary_copy_op_pass(
gm: torch.fx.GraphModule, new_graph_signature: ExportGraphSignature
) -> Tuple[torch.fx.GraphModule, ExportGraphSignature]:
) -> tuple[torch.fx.GraphModule, ExportGraphSignature]:
"""
Removes a redundant copy_ node that was introduced due to a mutated buffer.
"""
@ -774,8 +763,8 @@ def _common_getitem_elimination_pass(
if not isinstance(module, torch.fx.GraphModule):
continue
node_id: Dict[torch.fx.Node, str] = {}
getitems: Dict[str, torch.fx.Node] = {}
node_id: dict[torch.fx.Node, str] = {}
getitems: dict[str, torch.fx.Node] = {}
for node in list(module.graph.nodes):
if node.op == "call_function" and node.target == operator.getitem:
source, idx = node.args
@ -797,13 +786,13 @@ def _common_getitem_elimination_pass(
def _get_updated_module_call_graph(
gm: torch.fx.GraphModule,
old_module_call_graph: List[ModuleCallEntry],
old_module_call_graph: list[ModuleCallEntry],
):
new_module_call_graph = copy.deepcopy(old_module_call_graph)
# use node-level provenance metadata to create a map
# from old node names to new node names
provenance: Dict[str, str] = {}
provenance: dict[str, str] = {}
for node in gm.graph.nodes:
if history := node.meta.get("from_node", []):
provenance[history[-1].name] = node.name
@ -822,8 +811,8 @@ def _get_updated_module_call_graph(
def _decompose_exported_program(
ep,
*,
cia_to_decomp: Dict[torch._ops.OperatorBase, Callable],
python_decomp_table: Dict[torch._ops.OperatorBase, Callable],
cia_to_decomp: dict[torch._ops.OperatorBase, Callable],
python_decomp_table: dict[torch._ops.OperatorBase, Callable],
joint_loss_index: Optional[int],
decompose_custom_triton_ops: bool,
):
@ -889,18 +878,18 @@ class ExportedProgram:
def __init__(
self,
root: Union[torch.nn.Module, Dict[str, Any]],
root: Union[torch.nn.Module, dict[str, Any]],
graph: torch.fx.Graph,
graph_signature: ExportGraphSignature,
state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
range_constraints: "Dict[sympy.Symbol, Any]",
module_call_graph: List[ModuleCallEntry],
example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
state_dict: dict[str, Union[torch.Tensor, torch.nn.Parameter]],
range_constraints: "dict[sympy.Symbol, Any]",
module_call_graph: list[ModuleCallEntry],
example_inputs: Optional[tuple[tuple[Any, ...], dict[str, Any]]] = None,
constants: Optional[
Dict[str, Union[torch.Tensor, FakeScriptObject, torch._C.ScriptObject]]
dict[str, Union[torch.Tensor, FakeScriptObject, torch._C.ScriptObject]]
] = None,
*,
verifiers: Optional[List[Type[Verifier]]] = None,
verifiers: Optional[list[type[Verifier]]] = None,
):
# Remove codegen related things from the graph. It should just be a flat graph.
graph._codegen = torch.fx.graph.CodeGen()
@ -912,10 +901,10 @@ class ExportedProgram:
self._graph_module, graph_signature, module_call_graph
)
self._graph_signature: ExportGraphSignature = graph_signature
self._state_dict: Dict[str, Any] = state_dict
self._range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
self._state_dict: dict[str, Any] = state_dict
self._range_constraints: dict[sympy.Symbol, ValueRanges] = range_constraints
assert module_call_graph is not None
self._module_call_graph: List[ModuleCallEntry] = module_call_graph
self._module_call_graph: list[ModuleCallEntry] = module_call_graph
self._example_inputs = example_inputs
self._constants = constants or {}
@ -975,7 +964,7 @@ class ExportedProgram:
yield param
@compatibility(is_backward_compatible=False)
def named_parameters(self) -> Iterator[Tuple[str, torch.nn.Parameter]]:
def named_parameters(self) -> Iterator[tuple[str, torch.nn.Parameter]]:
"""
Returns an iterator over original module parameters, yielding
both the name of the parameter as well as the parameter itself.
@ -992,7 +981,7 @@ class ExportedProgram:
yield buf
@compatibility(is_backward_compatible=False)
def named_buffers(self) -> Iterator[Tuple[str, torch.Tensor]]:
def named_buffers(self) -> Iterator[tuple[str, torch.Tensor]]:
"""
Returns an iterator over original module buffers, yielding
both the name of the buffer as well as the buffer itself.
@ -1299,7 +1288,7 @@ class ExportedProgram:
@_disable_prexisiting_fake_mode
def run_decompositions(
self,
decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
decomp_table: Optional[dict[torch._ops.OperatorBase, Callable]] = None,
decompose_custom_triton_ops: bool = False,
) -> "ExportedProgram":
"""
@ -1520,8 +1509,8 @@ def _get_shape_env(gm):
def _get_updated_range_constraints(
gm: torch.fx.GraphModule,
old_range_constraints: "Optional[Dict[sympy.Symbol, Any]]" = None,
) -> "Dict[sympy.Symbol, Any]":
old_range_constraints: "Optional[dict[sympy.Symbol, Any]]" = None,
) -> "dict[sympy.Symbol, Any]":
assert old_range_constraints is not None
shape_env = _get_shape_env(gm)
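The ExportedProgram accessors touched above now advertise Iterator[tuple[...]] instead of Iterator[Tuple[...]]; runtime behavior is unchanged. A small round trip (toy module invented):

import torch
from torch.export import export

class Lin(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

ep = export(Lin(), (torch.randn(2, 4),))
for name, p in ep.named_parameters():  # Iterator[tuple[str, torch.nn.Parameter]]
    print(name, tuple(p.shape))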

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
import dataclasses
from collections.abc import Collection, Mapping
from enum import auto, Enum
from typing import Collection, Dict, List, Mapping, Optional, Set, TYPE_CHECKING, Union
from typing import Optional, TYPE_CHECKING, Union
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses.fake_tensor import is_fake
@ -144,8 +145,8 @@ class OutputSpec:
@dataclasses.dataclass
class ExportBackwardSignature:
gradients_to_parameters: Dict[str, str]
gradients_to_user_inputs: Dict[str, str]
gradients_to_parameters: dict[str, str]
gradients_to_user_inputs: dict[str, str]
loss_output: str
@ -221,8 +222,8 @@ class ExportGraphSignature:
)
"""
input_specs: List[InputSpec]
output_specs: List[OutputSpec]
input_specs: list[InputSpec]
output_specs: list[OutputSpec]
# A list of parameters uniquely identified by mangled fully qualified name
@property
@ -276,7 +277,7 @@ class ExportGraphSignature:
# Graph node names of pytree-flattened inputs of original program
@property
def user_inputs(self) -> Collection[Union[int, float, bool, None, str]]:
user_inputs: List[Union[int, float, bool, None, str]] = []
user_inputs: list[Union[int, float, bool, None, str]] = []
for s in self.input_specs:
if s.kind != InputKind.USER_INPUT:
continue
@ -302,7 +303,7 @@ class ExportGraphSignature:
# For joint-graph purposes, will include the loss output.
@property
def user_outputs(self) -> Collection[Union[int, float, bool, None, str]]:
user_outputs: List[Union[int, float, bool, None, str]] = []
user_outputs: list[Union[int, float, bool, None, str]] = []
for s in self.output_specs:
if s.kind not in [
OutputKind.USER_OUTPUT,
@ -393,8 +394,8 @@ class ExportGraphSignature:
@property
def backward_signature(self) -> Optional[ExportBackwardSignature]:
loss_output = None
gradients_to_parameters: Dict[str, str] = {}
gradients_to_user_inputs: Dict[str, str] = {}
gradients_to_parameters: dict[str, str] = {}
gradients_to_user_inputs: dict[str, str] = {}
for spec in self.output_specs:
if spec.kind == OutputKind.LOSS_OUTPUT:
assert loss_output is None
@ -537,7 +538,7 @@ def _make_argument_spec(node, token_names) -> ArgumentSpec:
def _convert_to_export_graph_signature(
graph_signature: "GraphSignature",
gm: "torch.fx.GraphModule",
non_persistent_buffers: Set[str],
non_persistent_buffers: set[str],
) -> "ExportGraphSignature":
from torch.utils import _pytree as pytree
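The signature specs are plain dataclasses, so downstream code filters them directly. A hedged sketch of reading the user inputs off an exported program (placeholder names will vary by model):

import torch
from torch.export import export
from torch.export.graph_signature import InputKind

ep = export(torch.nn.Linear(4, 4), (torch.randn(2, 4),))
user_inputs = [
    s.arg.name  # TensorArgument for tensor inputs
    for s in ep.graph_signature.input_specs
    if s.kind == InputKind.USER_INPUT
]
print(user_inputs)  # e.g. ['input'] — the flattened user-facing placeholders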

View File

@ -9,7 +9,7 @@ __all__ = ["move_to_device_pass"]
def move_to_device_pass(
ep: ExportedProgram, location: Union[torch.device, str, Dict[str, str]]
ep: ExportedProgram, location: Union[torch.device, str, dict[str, str]]
) -> ExportedProgram:
"""
Move the exported program to the given device.
@ -27,7 +27,7 @@ def move_to_device_pass(
def _get_new_device(
curr_device: torch.device,
location: Union[torch.device, str, Dict[str, str]],
location: Union[torch.device, str, dict[str, str]],
) -> str:
if isinstance(location, dict):
if str(curr_device) in location.keys():
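location accepts a single device, a device string, or a dict[str, str] remapping specific devices. A minimal sketch of both forms (the exported toy model is invented):

import torch
from torch.export import export
from torch.export.passes import move_to_device_pass

ep = export(torch.nn.Linear(4, 4), (torch.randn(2, 4),))
ep_a = move_to_device_pass(ep, "cpu")              # move everything
ep_b = move_to_device_pass(ep, {"cuda:0": "cpu"})  # remap only cuda:0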

View File

@ -9,7 +9,7 @@ from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, cast, Optional, Union
import torch
import torch.fx._pytree as fx_pytree
@ -85,7 +85,7 @@ def _assign_attr(
# foo.bar, foo.bar@1, foo.bar@2, foo@1.bar, foo@1.bar@1, foo@1.bar@2.
to_modules = {to_module}
for item in prefix:
ts: Set[torch.nn.Module] = set()
ts: set[torch.nn.Module] = set()
for to_module in to_modules:
if not hasattr(to_module, item):
setattr(to_module, item, torch.nn.Module())
@ -222,7 +222,7 @@ class InterpreterModuleDispatcher(_SubmoduleBase, torch.nn.Module):
to the next InterpreterModule, and wraps back around after the last.
"""
def __init__(self, attrs: Set[str], call_modules: List[InterpreterModule]):
def __init__(self, attrs: set[str], call_modules: list[InterpreterModule]):
super().__init__()
assert call_modules
self._modules = call_modules[0]._modules
@ -273,8 +273,8 @@ class FlatArgsAdapter(abc.ABC):
self,
target_spec: pytree.TreeSpec,
input_spec: pytree.TreeSpec,
input_args: List[Any],
) -> List[Any]:
input_args: list[Any],
) -> list[Any]:
"""NOTE: This adapter may mutate given ``input_args_with_path``."""
...
@ -316,7 +316,7 @@ class UnflattenedModule(torch.nn.Module):
_copy_graph_attrs(export_module._graph_module, self, seen_attrs)
self.range_constraints = export_module.range_constraints
self.equality_constraints: List = []
self.equality_constraints: list = []
# aliasing/unused param or buffer issues:
# in strict-mode export, dynamo export will deduplicate aliased tensors,
@ -329,8 +329,8 @@ class UnflattenedModule(torch.nn.Module):
# the state_dict as module attributes, but only keep the used tensors in the
# graph's forward pass (_sink_params).
state_dict = export_module.state_dict
assigned_params: Set[str] = set() # tracking unused params
id_to_param: Dict[int, torch.nn.Parameter] = {} # handling weight-sharing
assigned_params: set[str] = set() # tracking unused params
id_to_param: dict[int, torch.nn.Parameter] = {} # handling weight-sharing
for name in self.graph_signature.parameters: # this loop adds used params
param = state_dict[name]
if id(param) not in id_to_param:
@ -347,8 +347,8 @@ class UnflattenedModule(torch.nn.Module):
assigned_params.add(name)
non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
assigned_buffers: Set[str] = set() # tracking unused buffers
id_to_buffer: Dict[int, Tuple[torch.nn.Parameter, bool]] = {}
assigned_buffers: set[str] = set() # tracking unused buffers
id_to_buffer: dict[int, tuple[torch.nn.Parameter, bool]] = {}
for name in self.graph_signature.buffers: # this loop adds used buffers
if name in non_persistent_buffers:
persistent = False
@ -407,7 +407,7 @@ class UnflattenedModule(torch.nn.Module):
)
# use id map so we don't double-clone aliased constants
id_to_const: Dict[int, Union[torch.Tensor, torch._C.ScriptObject]] = {}
id_to_const: dict[int, Union[torch.Tensor, torch._C.ScriptObject]] = {}
for fqn, constant in export_module.constants.items():
if id(constant) not in id_to_const:
if isinstance(constant, torch.Tensor):
@ -423,14 +423,14 @@ class UnflattenedModule(torch.nn.Module):
# This is to handle parameters/buffers that point to the same tensor
# object id -> list of (node_name, target_name)
consts_map: Dict[int, List[Tuple[str, str]]] = defaultdict(list)
consts_targets: Set[str] = set()
consts_map: dict[int, list[tuple[str, str]]] = defaultdict(list)
consts_targets: set[str] = set()
def add_to_consts_map(obj_id, node_name, target_name):
name_list = consts_map[obj_id]
name_list.append((node_name, target_name))
added_params_buffers: Set[str] = set() # track aliased/unused params, buffers
added_params_buffers: set[str] = set() # track aliased/unused params, buffers
for s in self.graph_signature.input_specs:
if s.kind == InputKind.PARAMETER or (
s.kind == InputKind.BUFFER and s.persistent
@ -476,7 +476,7 @@ class UnflattenedModule(torch.nn.Module):
add_to_consts_map(id(tensor), ph_name, fqn)
# node name -> list of possible targets
inputs_to_state: Dict[str, List[str]] = {}
inputs_to_state: dict[str, list[str]] = {}
for node_target in consts_map.values():
targets = [t[1] for t in node_target]
for n, _ in node_target:
@ -790,7 +790,7 @@ def _compute_accessor(parent_fqn: str, child_fqn: str) -> str:
def _check_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
def graph_dump(graph: torch.fx.Graph) -> str:
ret = []
nodes_idx: Dict[int, int] = {}
nodes_idx: dict[int, int] = {}
def arg_dump(arg) -> str:
if isinstance(arg, torch.fx.Node):
@ -902,15 +902,15 @@ class _ModuleFrame:
def __init__(
self,
flat_graph: torch.fx.Graph,
nodes: Tuple[torch.fx.Node, ...],
nodes: tuple[torch.fx.Node, ...],
seen_nodes,
seen_modules,
seen_attrs,
created_modules,
parent,
module_stack: List[Tuple[str, Optional[str], int]],
module_stack: list[tuple[str, Optional[str], int]],
module_id,
module_call_graph: Dict[str, ModuleCallSignature],
module_call_graph: dict[str, ModuleCallSignature],
module: Optional[Union[torch.fx.GraphModule, UnflattenedModule]] = None,
):
self.flat_graph = flat_graph
@ -944,7 +944,7 @@ class _ModuleFrame:
self.graph = self.module.graph
# Mapping of nodes in the flat graph to nodes in this graph.
self.node_map: Dict[torch.fx.Node, torch.fx.Node] = {}
self.node_map: dict[torch.fx.Node, torch.fx.Node] = {}
self.node_to_placeholder = {}
self.parent_call_module: Optional[torch.fx.Node] = None
@ -1017,7 +1017,7 @@ class _ModuleFrame:
] = flat_arg_node
with self.parent.graph.inserting_before(self.parent_call_module):
input_nodes: List[Optional[torch.fx.Node]] = []
input_nodes: list[Optional[torch.fx.Node]] = []
for input in signature.inputs:
if isinstance(input, ConstantArgument):
input_nodes.append(input.value) # type: ignore[arg-type]
@ -1175,7 +1175,7 @@ class _ModuleFrame:
parent_out: Optional[torch.fx.Node] = _generate_flatten_spec(
self.parent.module, self.parent_call_module, signature.out_spec
)
graph_outputs: Union[torch.fx.Node, List[torch.fx.Node]] = tree_out_node
graph_outputs: Union[torch.fx.Node, list[torch.fx.Node]] = tree_out_node
else:
graph_outputs = []
# Iterate through nodes we have copied into self.graph.
@ -1351,10 +1351,10 @@ class _SubmoduleEntry:
def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule):
seen_nodes: Dict[str, torch.fx.Node] = {}
seen_modules: Dict[int, List[_SubmoduleEntry]] = defaultdict(list)
seen_attrs: Dict[str, Set[str]] = defaultdict(set)
created_modules: Dict[str, torch.nn.Module] = {}
seen_nodes: dict[str, torch.fx.Node] = {}
seen_modules: dict[int, list[_SubmoduleEntry]] = defaultdict(list)
seen_attrs: dict[str, set[str]] = defaultdict(set)
created_modules: dict[str, torch.nn.Module] = {}
_ModuleFrame(
orig_graph,
tuple(orig_graph.nodes),
@ -1376,7 +1376,7 @@ def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModu
def _reorder_submodules(
parent: torch.nn.Module, fqn_order: Dict[str, int], prefix: str = ""
parent: torch.nn.Module, fqn_order: dict[str, int], prefix: str = ""
):
# TODO Can be optimized by adding submodules ahead of time.
if prefix == "":
@ -1496,7 +1496,7 @@ class _IVals:
def _copy_graph_attrs(
gm: torch.fx.GraphModule,
root_module: UnflattenedModule,
seen_attrs: Dict[str, Set[str]],
seen_attrs: dict[str, set[str]],
):
for child_fqn, names in seen_attrs.items():
module = _get_attr(root_module, child_fqn) if child_fqn else root_module
@ -1550,9 +1550,9 @@ def _deduplicate_modules(partitions):
def _sink_params(
module: torch.nn.Module,
inputs_to_state: Dict[str, List[str]],
scope: List[str],
module_id_to_inputs_removed: Optional[Dict[int, Set[str]]] = None,
inputs_to_state: dict[str, list[str]],
scope: list[str],
module_id_to_inputs_removed: Optional[dict[int, set[str]]] = None,
):
"""Sink params, buffers, and constants from graph inputs into get_attr nodes.
@ -1613,7 +1613,7 @@ def _sink_params(
)
# Filter inputs_to_state down to the current scope.
inputs_to_state_of_scope: Dict[torch.fx.Node, list[str]] = {}
inputs_to_state_of_scope: dict[torch.fx.Node, list[str]] = {}
for node in inputs:
if node.name not in inputs_to_state:
continue
@ -1640,7 +1640,7 @@ def _sink_params(
inputs_to_state_of_scope[node] = state_name
# Record the names of removed inputs for the return value.
inputs_removed: Set[str] = set()
inputs_removed: set[str] = set()
for node, state_name in inputs_to_state_of_scope.items():
if len(node.users) > 0:
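unflatten() is the consumer of all this bookkeeping: consts_map, inputs_to_state, and _sink_params are what let the rebuilt submodule hierarchy share parameters and constants correctly. A minimal round trip (toy module invented):

import torch
from torch.export import export, unflatten

class Outer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.inner = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.inner(x)

ep = export(Outer(), (torch.randn(2, 4),))
unflat = unflatten(ep)  # UnflattenedModule with `inner` restored as a submodule
out = unflat(torch.randn(2, 4))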