PEP585 update - torch/_export (#145138)

See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145138
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #145154
Authored by Aaron Orenstein on 2025-01-18 08:47:47 -08:00
Committed by PyTorch MergeBot
parent cd8d0fa20c
commit 97d4d3c40a
26 changed files with 325 additions and 351 deletions
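For context, PEP 585 (Python 3.9) lets the built-in container types be used directly as generics, so the `Dict`/`List`/`Set`/`Tuple`/`Type` aliases imported from `typing` can be replaced with `dict`/`list`/`set`/`tuple`/`type`, and abstract containers such as `Sequence`, `Iterable`, and `Iterator` move to `collections.abc`. The diffs below apply that substitution mechanically. A minimal sketch of the pattern, using an illustrative function that is not part of this change:

# Before (pre-PEP 585 style, as the files looked before this commit):
#   from typing import Dict, List, Optional
#   def group_by_len(words: List[str]) -> Dict[int, List[str]]: ...
#
# After (PEP 585): builtin types used as generics; Optional/Union/Callable
# still come from typing, and Sequence/Iterable/Iterator come from collections.abc.
from collections.abc import Sequence
from typing import Optional

def group_by_len(words: Sequence[str], limit: Optional[int] = None) -> dict[int, list[str]]:
    out: dict[int, list[str]] = {}
    for w in words:
        if limit is None or len(w) <= limit:
            out.setdefault(len(w), []).append(w)
    return out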

View File

@ -76,14 +76,14 @@ def aot_compile_warning():
def aot_compile(
f: Callable,
args: tuple[Any],
kwargs: Optional[Dict[str, Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
*,
dynamic_shapes: Optional[Dict[str, Any]] = None,
options: Optional[Dict[str, Any]] = None,
dynamic_shapes: Optional[dict[str, Any]] = None,
options: Optional[dict[str, Any]] = None,
remove_runtime_assertions: bool = False,
disable_constraint_solver: bool = False,
same_signature: bool = True,
) -> Union[List[str], str]:
) -> Union[list[str], str]:
"""
Note: this function is not stable yet

View File

@ -4,8 +4,9 @@ import logging
import operator
import typing
import warnings
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Sequence, Set, Union
from typing import Any, Optional, Union
import torch
import torch.export._trace
@ -72,7 +73,7 @@ def _trace_and_get_graph_from_model(model, args):
def _create_jit_graph(
model: Union[torch.nn.Module, torch.jit.ScriptFunction], args: Sequence[Any]
) -> tuple[torch.Graph, List["_C.IValue"], Any, Optional[torch.ScriptModule]]:
) -> tuple[torch.Graph, list["_C.IValue"], Any, Optional[torch.ScriptModule]]:
if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)):
flattened_args = tuple(torch.jit._flatten(tuple(args))[0])
torch_out = None
@ -263,7 +264,7 @@ def construct_fqn(ir, ref_map, name_map):
def get_block_to_lifted_attrs(
graph: torch._C.Graph,
) -> tuple[Dict[torch._C.Block, Set[str]], Dict[str, str]]:
) -> tuple[dict[torch._C.Block, set[str]], dict[str, str]]:
"""
Perform two passes to get a mapping from each block to the set of FQNs of its lifted attributes.
When a graph has control flow, the graph will be divided into multiple blocks. We want to convert
@ -280,19 +281,19 @@ def get_block_to_lifted_attrs(
"""
# A map from a block to its expected to be lifted arguments.
blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]] = {}
blocks_to_lifted_attrs: dict[torch._C.Block, set[str]] = {}
# Reference map stores the input (i.e., src) and output (i.e., dest) IR of a
# GetAttr node. By traversing this reference map, we can figure out the
# full IR aliasing pass and figure out the FQN of an attribute.
# E.g., %2 = GetAttr(linear)[%1] --> node_to_parent_map["%2"] = "%1"
node_to_parent_map: Dict[str, str] = {}
node_to_parent_map: dict[str, str] = {}
# Used for reconstructing the FQN of an attribute based on the reference map.
# In a nutshell, for each GetAttr call, GetAttr(input IR, attribute name) -> output IR
# This name map stores which attribute name is called for a src IR --> dest IR action.
# E.g., %2 = GetAttr(linear)[%1] --> node_to_attr_name["%2"] = "linear"
node_to_attr_name: Dict[str, str] = {}
node_to_attr_name: dict[str, str] = {}
def _dfs_get_attr_dependency(entry):
"""
@ -315,7 +316,7 @@ def get_block_to_lifted_attrs(
Walk the graph in a bottom-up fashion to build the expected to be
lifted arguments for each block.
"""
arguments: Set[str] = set()
arguments: set[str] = set()
for node in entry.nodes():
for block in node.blocks():
# Recursively build.
@ -342,7 +343,7 @@ def get_block_to_lifted_attrs(
def get_attribute_fqn_from_ts_node(
name_to_attribute_fqn: Dict[str, str], node: torch._C.Node
name_to_attribute_fqn: dict[str, str], node: torch._C.Node
) -> str:
def get_attr(name: str):
if name in name_to_attribute_fqn:
@ -392,12 +393,12 @@ class TS2FXGraphConverter:
def __init__(
self,
ts_graph: Union[torch._C.Graph, torch._C.Block],
name_to_param: Dict[str, torch.Tensor],
name_to_buffer: Dict[str, torch.Tensor],
blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]],
name_to_non_tensor_attribute: Dict[str, Any],
name_to_constant: Dict[str, Any],
name_to_attribute_fqn: Dict[str, str],
name_to_param: dict[str, torch.Tensor],
name_to_buffer: dict[str, torch.Tensor],
blocks_to_lifted_attrs: dict[torch._C.Block, set[str]],
name_to_non_tensor_attribute: dict[str, Any],
name_to_constant: dict[str, Any],
name_to_attribute_fqn: dict[str, str],
):
self.ts_graph = ts_graph
# Mapping of parameter FQN to actual parameter value
@ -406,19 +407,19 @@ class TS2FXGraphConverter:
self.name_to_buffer = name_to_buffer
self.fx_graph: torch.fx.Graph = torch.fx.Graph()
self.input_specs: List[InputSpec] = []
self.output_specs: List[OutputSpec] = []
self.input_specs: list[InputSpec] = []
self.output_specs: list[OutputSpec] = []
# Mapping of TS node name to converted FX node
self.name_to_node: Dict[
str, Union[torch.fx.Node, List[torch.fx.Node], Dict[Any, torch.fx.Node]]
self.name_to_node: dict[
str, Union[torch.fx.Node, list[torch.fx.Node], dict[Any, torch.fx.Node]]
] = {}
# Mapping of TS node name to constant value (int, str, TorchBind obj,
# tensor constants ...)
self.name_to_constant: Dict[str, Any] = name_to_constant
self.name_to_constant: dict[str, Any] = name_to_constant
# Mapping from torchscript node output name to attribute fully qualified name
self.name_to_attribute_fqn: Dict[str, str] = name_to_attribute_fqn
self.name_to_attribute_fqn: dict[str, str] = name_to_attribute_fqn
# Mapping from fully qualified name to real values or a fx graph node
# During convert, this represents the current value of a non-tensor attribute
@ -428,14 +429,14 @@ class TS2FXGraphConverter:
# self.count += 1
# c2 = self.count
# return x + c1 + c2
self.name_to_non_tensor_attribute_node: Dict[str, Any] = {}
self.name_to_non_tensor_attribute_node: dict[str, Any] = {}
# Mapping from fully qualified name to initial real values inputs
# We separate it from self.name_to_non_tensor_attribute_node since
# we need initial real value input when we construct fx.GraphModule
self.name_to_non_tensor_attribute: Dict[str, Any] = name_to_non_tensor_attribute
self.name_to_non_tensor_attribute: dict[str, Any] = name_to_non_tensor_attribute
self.subgraphs: Dict[str, torch.fx.GraphModule] = {}
self.subgraphs: dict[str, torch.fx.GraphModule] = {}
# Mapping of block to list of attributes that need to be lifted for each
# block
@ -457,7 +458,7 @@ class TS2FXGraphConverter:
# might have inplace updates to the variable defined in the parent fx graph. After
# the execution of that sub-block, the variable defined in the parent fx graph also
# needs to be updated.
self.name_update_from_subblock_to_parent: Set[str] = set()
self.name_update_from_subblock_to_parent: set[str] = set()
def _is_get_attr_node(self, fqn):
return (
@ -469,7 +470,7 @@ class TS2FXGraphConverter:
)
)
def _convert_block_to_subgraph(self, node: torch._C.Node, arguments: List[str]):
def _convert_block_to_subgraph(self, node: torch._C.Node, arguments: list[str]):
subgraph_nodes, subgraph_converters = [], []
for block in node.blocks():
subgraph_converter = TS2FXGraphConverter(
@ -506,7 +507,7 @@ class TS2FXGraphConverter:
Block[x.1]
%2 = x.1 ...
"""
arguments: Set[str] = set()
arguments: set[str] = set()
for block in entry.blocks():
for block_node in block.nodes():
for block_node_in in block_node.inputs():
@ -1332,12 +1333,12 @@ class ExplainTS2FXGraphConverter(TS2FXGraphConverter):
def __init__(
self,
ts_graph: Union[torch._C.Graph, torch._C.Block],
name_to_param: Dict[str, torch.Tensor],
name_to_buffer: Dict[str, torch.Tensor],
blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]],
name_to_non_tensor_attribute: Dict[str, Any],
name_to_constant: Dict[str, Any],
name_to_attribute_fqn: Dict[str, str],
name_to_param: dict[str, torch.Tensor],
name_to_buffer: dict[str, torch.Tensor],
blocks_to_lifted_attrs: dict[torch._C.Block, set[str]],
name_to_non_tensor_attribute: dict[str, Any],
name_to_constant: dict[str, Any],
name_to_attribute_fqn: dict[str, str],
):
super().__init__(
ts_graph,
@ -1350,7 +1351,7 @@ class ExplainTS2FXGraphConverter(TS2FXGraphConverter):
)
# Data to keep track of unsupported nodes.
self.unsupported_node_list: List[torch._C.Node] = []
self.unsupported_node_list: list[torch._C.Node] = []
# Add mock to needed attributes.
self.name_to_node = ExplainTS2FXGraphConverter._DictMock(
@ -1395,7 +1396,7 @@ class TS2EPConverter:
self,
ts_model: Union[torch.jit.ScriptModule, torch.jit.ScriptFunction],
sample_args: tuple[Any, ...],
sample_kwargs: Optional[Dict[str, Any]] = None,
sample_kwargs: Optional[dict[str, Any]] = None,
):
self.ts_model = ts_model
self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args)
@ -1403,8 +1404,8 @@ class TS2EPConverter:
self.sample_args = sample_args
self.sample_kwargs = sample_kwargs
self.name_to_param: Dict[str, torch.Tensor] = {}
self.name_to_buffer: Dict[str, torch.Tensor] = {}
self.name_to_param: dict[str, torch.Tensor] = {}
self.name_to_buffer: dict[str, torch.Tensor] = {}
param_list = (
list(self.ts_model.parameters())
if not isinstance(self.ts_model, torch._C.ScriptFunction)
@ -1422,8 +1423,8 @@ class TS2EPConverter:
else:
self.name_to_buffer[k] = tensor
self.name_to_non_tensor_attributes: Dict[str, Any] = {}
self.name_to_constant: Dict[str, Any] = {}
self.name_to_non_tensor_attributes: dict[str, Any] = {}
self.name_to_constant: dict[str, Any] = {}
self.lift_get_attr()
@ -1509,7 +1510,7 @@ DEBUG: (TORCH_LOGS="+export" <cmd>), additionally
def retrace_as_exported_program(
self,
gm: torch.fx.GraphModule,
name_to_constant: Dict[str, Any],
name_to_constant: dict[str, Any],
):
dynamic_shapes = _tree_map_with_path(
lambda path, x: (
@ -1569,7 +1570,7 @@ DEBUG: (TORCH_LOGS="+export" <cmd>), additionally
# TS2FXGraphConverter since it gets attributes from self.ts_model
# which is not accessible in TS2FXGraphConverter. It is similar to where
# we collect self.name_to_param and self.name_to_buffer.
name_to_attribute_fqn: Dict[str, str] = {}
name_to_attribute_fqn: dict[str, str] = {}
def get_attr(fqn: str):
name = fqn.split(".")

View File

@ -4,12 +4,12 @@ import re
import string
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Set
from typing import Any, Optional
from types import ModuleType
import torch
_TAGS: Dict[str, Dict[str, Any]] = {
_TAGS: dict[str, dict[str, Any]] = {
"torch": {
"cond": {},
"dynamic-shape": {},
@ -79,12 +79,12 @@ class ExportCase:
description: str # A description of the use case.
model: torch.nn.Module
name: str
example_kwargs: Dict[str, Any] = field(default_factory=dict)
example_kwargs: dict[str, Any] = field(default_factory=dict)
extra_args: Optional[ArgsType] = None # For testing graph generalization.
# Tags associated with the use case. (e.g dynamic-shape, escape-hatch)
tags: Set[str] = field(default_factory=set)
tags: set[str] = field(default_factory=set)
support_level: SupportLevel = SupportLevel.SUPPORTED
dynamic_shapes: Optional[Dict[str, Any]] = None
dynamic_shapes: Optional[dict[str, Any]] = None
def __post_init__(self):
check_inputs_type(self.example_args, self.example_kwargs)
@ -98,10 +98,10 @@ class ExportCase:
raise ValueError(f'Invalid description: "{self.description}"')
_EXAMPLE_CASES: Dict[str, ExportCase] = {}
_MODULES: Set[ModuleType] = set()
_EXAMPLE_CONFLICT_CASES: Dict[str, List[ExportCase]] = {}
_EXAMPLE_REWRITE_CASES: Dict[str, List[ExportCase]] = {}
_EXAMPLE_CASES: dict[str, ExportCase] = {}
_MODULES: set[ModuleType] = set()
_EXAMPLE_CONFLICT_CASES: dict[str, list[ExportCase]] = {}
_EXAMPLE_REWRITE_CASES: dict[str, list[ExportCase]] = {}
def register_db_case(case: ExportCase) -> None:

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
from typing import List
import torch
@ -9,7 +8,7 @@ class ListUnpack(torch.nn.Module):
erased after tracing.
"""
def forward(self, args: List[torch.Tensor]):
def forward(self, args: list[torch.Tensor]):
"""
Lists are treated as a static construct, therefore unpacking should be
erased after tracing.

View File

@ -3,18 +3,7 @@ import contextlib
import inspect
import logging
from collections import defaultdict
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import torch
import torch.utils._pytree as pytree
@ -97,8 +86,8 @@ def fakify(
mode: FakeTensorMode,
kp: KeyPath,
t: Any,
t_constraints: Dict[int, Dict[int, Constraint]],
sources: Dict[tuple[int, int], List[Source]],
t_constraints: dict[int, dict[int, Constraint]],
sources: dict[tuple[int, int], list[Source]],
):
source = key_path_to_source(kp)
if _is_constant_argument(t) or isinstance(t, (torch.ScriptObject, torch.nn.Module)):
@ -165,7 +154,7 @@ def make_fake_inputs(
combined_args = _combine_args(nn_module, args, kwargs)
_check_dynamic_shapes(combined_args, dynamic_shapes)
constraints = _process_dynamic_shapes(combined_args, dynamic_shapes)
t_constraints: Dict[int, Dict[int, Constraint]] = defaultdict(dict)
t_constraints: dict[int, dict[int, Constraint]] = defaultdict(dict)
for constraint in constraints:
t_constraints[constraint.t_id][constraint.dim] = constraint
@ -214,17 +203,17 @@ def make_fake_inputs(
original_signature = inspect.signature(nn_module.forward)
else:
original_signature = None
sources: Dict[tuple[int, int], List[Source]] = defaultdict(list)
sources: dict[tuple[int, int], list[Source]] = defaultdict(list)
fake_args, fake_kwargs = tree_map_with_path(
lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources),
(args, kwargs),
)
names: Dict[str, tuple[int, int]] = {}
source_pairs: List[tuple[Source, Source]] = []
derived_equalities: List[tuple[Source, Union[Source, Symbol], Callable]] = []
phantom_symbols: Dict[str, Symbol] = {}
relaxed_sources: Set[Source] = set()
names: dict[str, tuple[int, int]] = {}
source_pairs: list[tuple[Source, Source]] = []
derived_equalities: list[tuple[Source, Union[Source, Symbol], Callable]] = []
phantom_symbols: dict[str, Symbol] = {}
relaxed_sources: set[Source] = set()
for constraint in constraints:
torch.export.dynamic_shapes._process_equalities(
constraint,
@ -255,9 +244,9 @@ def make_fake_inputs(
def _flatten_dynamic_shapes(
combined_args: Dict[str, Any],
dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any]],
) -> List[Any]:
combined_args: dict[str, Any],
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any]],
) -> list[Any]:
flat_shapes = []
def _tree_map_helper(path, t, shape):
@ -283,7 +272,7 @@ def _clean_dynamic_markers(tensor: torch.Tensor) -> None:
def produce_guards_and_solve_constraints(
fake_mode: FakeTensorMode,
gm: torch.fx.GraphModule,
dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any], None],
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
equalities_inputs: EqualityConstraint,
original_signature: inspect.Signature,
_is_torch_jit_trace=False,
@ -348,8 +337,8 @@ def produce_guards_and_solve_constraints(
def make_constraints(
fake_mode: FakeTensorMode,
gm: torch.fx.GraphModule,
combined_args: Dict[str, Any],
dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any], None],
combined_args: dict[str, Any],
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
num_lifted_inputs: int,
):
"""
@ -435,7 +424,7 @@ def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap:
buffers_parameters = set(m.buffers())
buffers_parameters.update(m.parameters())
def inner(m: torch.nn.Module, prefix_atoms: List[str], constants):
def inner(m: torch.nn.Module, prefix_atoms: list[str], constants):
for k, v in m.__dict__.items():
if isinstance(
v,
@ -459,8 +448,8 @@ def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap:
def _get_graph_inputs_of_type_nn_module(
args: Optional[Tuple[Tuple[Any], Dict[Any, Any]]],
) -> Set[Type[torch.nn.Module]]:
args: Optional[tuple[tuple[Any], dict[Any, Any]]],
) -> set[type[torch.nn.Module]]:
if args is None:
return set()
module_types = set()
@ -471,14 +460,14 @@ def _get_graph_inputs_of_type_nn_module(
def _enter_enable_graph_inputs_of_type_nn_module(
module_types: Set[Type[torch.nn.Module]],
module_types: set[type[torch.nn.Module]],
) -> None:
for t in module_types:
torch._export.utils.register_module_as_pytree_input_node(t)
def _exit_enable_graph_inputs_of_type_nn_module(
module_types: Set[Type[torch.nn.Module]],
module_types: set[type[torch.nn.Module]],
) -> None:
for t in module_types:
torch._export.utils.deregister_module_as_pytree_input_node(t)
@ -486,7 +475,7 @@ def _exit_enable_graph_inputs_of_type_nn_module(
@contextlib.contextmanager
def _enable_graph_inputs_of_type_nn_module(
args: Optional[Tuple[Tuple[Any], Dict[Any, Any]]],
args: Optional[tuple[tuple[Any], dict[Any, Any]]],
):
if args is None:
yield
@ -502,8 +491,8 @@ def _enable_graph_inputs_of_type_nn_module(
@contextlib.contextmanager
def _fakify_module_inputs(
args: Tuple[Any],
kwargs: Dict[Any, Any],
args: tuple[Any],
kwargs: dict[Any, Any],
fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
):
# This context manager is used to fakify module inputs.
@ -534,7 +523,7 @@ def _fakify_module_inputs(
def _fakify_script_objects(
mod: torch.nn.Module,
args: tuple[Any],
kwargs: Dict[Any, Any],
kwargs: dict[Any, Any],
fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
):
# This context manager is used to fakify script objects into FakeScriptObject.

View File

@ -3,7 +3,7 @@ import operator
import traceback
import typing
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional, Set, Union
from typing import Any, Callable, Optional, Union
import torch
from functorch.experimental.control_flow import _unstack_pytree
@ -31,7 +31,7 @@ Fn = Callable[..., Any]
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
_TORCH_SYM_OPS: Set[Callable] = {
_TORCH_SYM_OPS: set[Callable] = {
torch.sym_int,
torch.sym_float,
torch.sym_ite,
@ -64,9 +64,9 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
self.root = torch.nn.Module()
self.graph = torch.fx.Graph()
self.graph.set_codegen(codegen)
self.tensor_attrs: Dict[str, torch.Tensor] = {} # type: ignore[assignment]
self.tensor_attrs: dict[str, torch.Tensor] = {} # type: ignore[assignment]
self.fake_tensor_mode: Optional[FakeTensorMode] = None
self.submodules: Dict[torch.nn.Module, str] = {}
self.submodules: dict[torch.nn.Module, str] = {}
def trace(self) -> None: # type: ignore[override]
raise ExportPassBaseError("ExportTracer doesn't support trace().")
@ -162,7 +162,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
self,
target: str, # type: ignore[override]
args: tuple[Argument, ...],
kwargs: Dict[str, Argument],
kwargs: dict[str, Argument],
) -> ProxyValue:
arg = super().placeholder(target, args, kwargs)
return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta))
@ -171,7 +171,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
self,
target: torch.fx.node.Target,
args: tuple[Argument, ...],
kwargs: Dict[str, Argument],
kwargs: dict[str, Argument],
) -> ProxyValue:
return self.callback.output(args[0], NodeMetadata(self.node.meta)).data # type: ignore[return-value]
@ -179,7 +179,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
self,
target: torch.fx.node.Target,
args: tuple[Argument, ...],
kwargs: Dict[str, Argument],
kwargs: dict[str, Argument],
) -> ProxyValue:
meta = NodeMetadata(self.node.meta)
@ -218,7 +218,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
raise ExportPassBaseError(f"Unsupported target type: {target}")
def get_attr(
self, target: str, args: tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override]
self, target: str, args: tuple[Argument, ...], kwargs: dict[str, Argument] # type: ignore[override]
) -> Argument:
return super().get_attr(target, args, kwargs)
@ -226,12 +226,12 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
self,
target: torch.fx.node.Target,
args: tuple[Argument, ...],
kwargs: Dict[str, Argument],
kwargs: dict[str, Argument],
) -> None:
raise ExportPassBaseError("call_module is not supported.")
def call_method(
self, target: str, args: tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override]
self, target: str, args: tuple[Argument, ...], kwargs: dict[str, Argument] # type: ignore[override]
) -> None:
raise ExportPassBaseError("call_method is not supported.")
@ -254,7 +254,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
kind: str,
target: torch.fx.node.Target,
args: tuple[Argument, ...],
kwargs: Dict[str, Argument],
kwargs: dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
args_data, kwargs_data = pytree.tree_map_only(
@ -277,7 +277,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
self.tracer.set_metadata(res_proxy.node, res_data)
return ProxyValue(res_data, res_proxy)
def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]:
def inputs(self, graph_module: torch.fx.GraphModule) -> list[Argument]:
# TODO(angelayi): Update this with what we decide to do for metadata in
# the exported graph module
if (args := graph_module.meta.get("args", None)) is not None:
@ -327,7 +327,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
self,
op,
args: tuple[Argument, ...],
kwargs: Dict[str, Argument],
kwargs: dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
return self._fx("call_function", op, args, kwargs, meta)
@ -345,7 +345,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
pred: ProxyValue,
true_fn: torch.fx.GraphModule,
false_fn: torch.fx.GraphModule,
inputs: List[Argument],
inputs: list[Argument],
meta: NodeMetadata,
) -> ProxyValue:
true_branch = self.call_submodule(true_fn, tuple(inputs))
@ -363,8 +363,8 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
def call_map(
self,
f: torch.fx.GraphModule,
mapped_args: List[ProxyValue],
operands: List[ProxyValue],
mapped_args: list[ProxyValue],
operands: list[ProxyValue],
meta: NodeMetadata,
) -> ProxyValue:
xs = _unstack_pytree([arg.data for arg in mapped_args])[0]
@ -383,7 +383,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
) -> ProxyValue:
return self._fx("call_function", operator.getitem, (value, key), {}, meta)
def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
def output(self, results: list[Argument], meta: NodeMetadata) -> ProxyValue:
return self._fx("output", "output", (results,), {}, meta)
def call_submodule(

View File

@ -1,10 +1,10 @@
from typing import Any, Dict, Set
from typing import Any
NodeMetadataValue = Any
PROTECTED_KEYS: Set[str] = {
PROTECTED_KEYS: set[str] = {
"val",
"stack_trace",
"nn_module_stack",
@ -14,8 +14,8 @@ PROTECTED_KEYS: Set[str] = {
class NodeMetadata:
def __init__(self, data: Dict[str, Any]) -> None:
self.data: Dict[str, Any] = data.copy()
def __init__(self, data: dict[str, Any]) -> None:
self.data: dict[str, Any] = data.copy()
def __getitem__(self, key: str) -> NodeMetadataValue:
return self.data[key]

View File

@ -1,5 +1,6 @@
# pyre-strict
from typing import Union, Iterator, Iterable, Generic
from typing import Union, Generic
from collections.abc import Iterator, Iterable
import torch
from typing import TypeVar

View File

@ -3,7 +3,7 @@ import math
import operator
import traceback
from functools import partial
from typing import Callable, Dict, List, NamedTuple, Set
from typing import Callable, NamedTuple
import sympy
@ -45,11 +45,11 @@ def _convert_range_to_int(range: ValueRanges):
class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
def __init__(
self,
range_constraints: Dict[sympy.Symbol, ValueRanges],
range_constraints: dict[sympy.Symbol, ValueRanges],
):
super().__init__()
self.range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
self._asserts_generated_unbacked_symbols: Set[sympy.Symbol] = set()
self.range_constraints: dict[sympy.Symbol, ValueRanges] = range_constraints
self._asserts_generated_unbacked_symbols: set[sympy.Symbol] = set()
self.counter = 0
def _assert_range_constraint(self, node, lower, upper, assert_msg):
@ -105,8 +105,8 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
# need the proxy for shape, which further requires the proxy for ret[1], etc.
def add_assertions(val):
call_backs: List[Callable] = []
messages: List[str] = []
call_backs: list[Callable] = []
messages: list[str] = []
if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)):
symbol = val.node.expr
if symbol in self.existing_inline_assertions:
@ -161,9 +161,9 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
def _get_existing_inline_assertions(
graph_module: torch.fx.GraphModule,
range_constraints: Dict[sympy.Symbol, ValueRanges],
) -> Dict[sympy.Symbol, ValueRanges]:
existing_inline_assertions: Dict[sympy.Symbol, ValueRanges] = {}
range_constraints: dict[sympy.Symbol, ValueRanges],
) -> dict[sympy.Symbol, ValueRanges]:
existing_inline_assertions: dict[sympy.Symbol, ValueRanges] = {}
for module in graph_module.modules():
if not isinstance(module, torch.fx.GraphModule):

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import operator
from typing import Dict, Optional, TYPE_CHECKING, Union
from typing import Optional, TYPE_CHECKING, Union
import torch
from torch.export.exported_program import ConstantArgument, TensorArgument
@ -23,7 +23,7 @@ class CollectTracepointsPass(PassBase):
"""
def __init__(
self, specs: Dict[str, ModuleCallSignature], sig: ExportGraphSignature
self, specs: dict[str, ModuleCallSignature], sig: ExportGraphSignature
) -> None:
super().__init__()
self.specs = specs

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import collections
from collections import defaultdict
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Optional
import torch
import torch.utils._pytree as pytree
@ -53,8 +53,8 @@ class ConstantFolder(torch.fx.Interpreter):
skip_constructors: bool = False,
):
super().__init__(gm)
self.node_replacements: Dict[torch.fx.Node, Any] = {}
self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
self.node_replacements: dict[torch.fx.Node, Any] = {}
self.replaced_uses: dict[torch.fx.Node, int] = collections.Counter()
self.unknown_value = object()
self.skip_constructors: bool = skip_constructors
@ -281,7 +281,7 @@ def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule
new_graph = torch.fx.Graph()
node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
node_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
output_nodes = []
for node in gm.graph.nodes:
if node.meta[META_TAG] == MODULE_TAG:

View File

@ -1,5 +1,5 @@
import copy
from typing import Dict, Optional, List
from typing import Optional
import torch
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, PassResult, Argument
@ -9,7 +9,7 @@ from torch._ops import OpOverload
aten = torch.ops.aten
_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: Dict[OpOverload, OpOverload] = {
_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: dict[OpOverload, OpOverload] = {
aten.sym_constrain_range.default: aten._functional_sym_constrain_range,
aten._assert_async.msg: aten._functional_assert_async.msg,
}
@ -60,7 +60,7 @@ class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse):
self,
op: OpOverload,
args: tuple[Argument, ...],
kwargs: Dict[str, Argument],
kwargs: dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
if op not in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS:
@ -88,7 +88,7 @@ class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse):
return self._dep_token
def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
def output(self, results: list[Argument], meta: NodeMetadata) -> ProxyValue:
assert self._dep_token is not None
return super().output(results=(*results, self._dep_token), meta=meta) # type: ignore[arg-type]

View File

@ -1,5 +1,4 @@
import functools
from typing import List
import torch
from torch._export.passes._node_metadata_hook import (
@ -8,7 +7,7 @@ from torch._export.passes._node_metadata_hook import (
)
def insert_custom_op_guards(gm: torch.fx.GraphModule, ops_to_guard: List[str]) -> None:
def insert_custom_op_guards(gm: torch.fx.GraphModule, ops_to_guard: list[str]) -> None:
"""
This is used by draft_export to insert guards in front of calls to custom
operators which have a generated fake kernel.

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import collections
import warnings
from typing import Any, Dict, List, Union
from typing import Any, Union
import torch
from torch._export.verifier import SpecViolationError
@ -29,13 +29,13 @@ class ConstantAttrMap(collections.abc.MutableMapping):
def __init__(self) -> None:
# Underlying dict that we use to implement this mapping.
self._constant_attrs: Dict[
Union[int, torch.Tensor, FakeScriptObject], List[Any]
self._constant_attrs: dict[
Union[int, torch.Tensor, FakeScriptObject], list[Any]
] = {}
# Map from the hash(ScriptObject) to the ScriptObject itself. Used for
# APIs like `__iter__` that should look like they're returning the
# original ScriptObjects.
self._script_object_map: Dict[int, torch.ScriptObject] = {}
self._script_object_map: dict[int, torch.ScriptObject] = {}
def __getitem__(
self, key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]
@ -113,7 +113,7 @@ def lift_constants_pass(
gm: torch.fx.GraphModule,
graph_signature: ExportGraphSignature,
constant_attrs: ConstantAttrMap,
) -> Dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]]:
) -> dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]]:
"""
Takes a graph module and graph signature, and modifies them in place to lift any
constants (tensors or custom classes) as inputs to the graph. Returns a
@ -131,7 +131,7 @@ def lift_constants_pass(
Returns:
A dictionary of fqn => constant value.
"""
all_constants: Dict[
all_constants: dict[
str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]
] = {}
@ -300,13 +300,13 @@ def lift_constants_pass(
def rewrite_script_object_meta(
gm: torch.fx.GraphModule,
) -> Dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject],]:
) -> dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject],]:
"""When tracing, we produce a graph with FakeScriptObject in the
meta["val"].
For now, we rewrite meta["val"] to be a placeholder CustomObjArgument
"""
constants: Dict[
constants: dict[
str,
Union[
torch.Tensor,

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
from __future__ import annotations
from typing import List, Optional, TYPE_CHECKING, Union
from typing import Optional, TYPE_CHECKING, Union
import torch
from torch._higher_order_ops.wrap import wrap_with_autocast
@ -116,7 +116,7 @@ def _split_autocast(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
exit_autocast # 3
E # 4
"""
enter_autocast_node_stack: List[torch.fx.Node] = []
enter_autocast_node_stack: list[torch.fx.Node] = []
first_node_after_outer_most_exit: bool = False
def node_call_back(node: torch.fx.Node) -> bool:

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import logging
import operator
from typing import List, Optional, Union
from typing import Optional, Union
import torch
import torch.export._trace
@ -269,9 +269,9 @@ def _conv1d_op_with_squeeze(
inp: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
stride: List[int],
padding: List[int],
dilation: List[int],
stride: list[int],
padding: list[int],
dilation: list[int],
groups: int,
) -> torch.Tensor:
# In quantized version, conv1d is emulated using conv2d with squeeze and unsqueeze

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Dict, Optional
from typing import Optional
import torch
from torch._ops import OpOverload, HigherOrderOperator
from torch._export.error import InternalError
@ -9,7 +9,7 @@ from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
__all__ = ["ReplaceViewOpsWithViewCopyOpsPass"]
_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: Dict[OpOverload, OpOverload] = {
_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: dict[OpOverload, OpOverload] = {
torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default,
}

View File

@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import List
from torch._export.serde.schema import Node
@ -12,4 +11,4 @@ class ExternKernelNode:
@dataclass
class ExternKernelNodes:
nodes: List[ExternKernelNode]
nodes: list[ExternKernelNode]

View File

@ -1,5 +1,5 @@
import dataclasses
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
import torch
from torch._dynamo.exc import UserError, UserErrorType
@ -24,7 +24,7 @@ class RootDim:
min: int
max: Union[int, None]
derived: List[str]
derived: list[str]
@dataclasses.dataclass
@ -33,15 +33,15 @@ class DynamicShapesSpec:
This stores a dynamic_shapes spec for de/serialization.
"""
dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any], None]
dims: Dict[str, RootDim]
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None]
dims: dict[str, RootDim]
def _postprocess_serialized_shapes(
dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any], None],
dims: Dict[str, Dict[str, Union[int, List[str], None]]],
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
dims: dict[str, dict[str, Union[int, list[str], None]]],
to_dict: Optional[bool] = False,
) -> Union[DynamicShapesSpec, Dict[str, Any]]:
) -> Union[DynamicShapesSpec, dict[str, Any]]:
"""
Sorts dims and dumps to dictionary format.
"""
@ -63,11 +63,11 @@ def _postprocess_serialized_shapes(
def _dump_dynamic_shapes(
dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any], None],
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
args: tuple[Any],
kwargs: Optional[Dict[str, Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
to_dict: Optional[bool] = False,
) -> Union[DynamicShapesSpec, Dict[str, Any]]:
) -> Union[DynamicShapesSpec, dict[str, Any]]:
"""
Utility function for dynamic shapes serialization, serializing a dynamic_shapes spec.
Returns a DynamicShapesSpec dataclass containing 2 fields, "dynamic_shapes" and "dims".
@ -127,7 +127,7 @@ def _dump_dynamic_shapes(
}
```
"""
dims: Dict[str, Dict[str, Any]] = {}
dims: dict[str, dict[str, Any]] = {}
def _standardize_shapes(path, tensor, shape): # type: ignore[no-untyped-def]
"""
@ -198,9 +198,9 @@ def _dump_dynamic_shapes(
def _load_dynamic_shapes(
spec: Union[DynamicShapesSpec, Dict[str, Any]],
spec: Union[DynamicShapesSpec, dict[str, Any]],
from_dict: Optional[bool] = False,
) -> Union[Dict[str, Any], tuple[Any], List[Any], None]:
) -> Union[dict[str, Any], tuple[Any], list[Any], None]:
"""
Utility function for dynamic shapes serialization.
Deserializes a DynamicShapesSpec or corresponding dictionary into a dynamic_shapes input to export().

View File

@ -3,7 +3,7 @@
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Annotated, Dict, List, Optional
from typing import Annotated, Optional
from torch._export.serde.union import _Union
@ -96,10 +96,10 @@ class SymBool(_Union):
@dataclass
class TensorMeta:
dtype: Annotated[ScalarType, 10]
sizes: Annotated[List[SymInt], 20]
sizes: Annotated[list[SymInt], 20]
requires_grad: Annotated[bool, 30]
device: Annotated[Device, 40]
strides: Annotated[List[SymInt], 50]
strides: Annotated[list[SymInt], 50]
storage_offset: Annotated[SymInt, 60]
layout: Annotated[Layout, 70]
@ -175,29 +175,29 @@ class CustomObjArgument:
class Argument(_Union):
as_none: Annotated[bool, 10]
as_tensor: Annotated[TensorArgument, 20]
as_tensors: Annotated[List[TensorArgument], 30]
as_tensors: Annotated[list[TensorArgument], 30]
as_int: Annotated[int, 50]
as_ints: Annotated[List[int], 70]
as_ints: Annotated[list[int], 70]
as_float: Annotated[float, 80]
as_floats: Annotated[List[float], 90]
as_floats: Annotated[list[float], 90]
as_string: Annotated[str, 100]
as_strings: Annotated[List[str], 101]
as_strings: Annotated[list[str], 101]
as_sym_int: Annotated[SymIntArgument, 110]
as_sym_ints: Annotated[List[SymIntArgument], 120]
as_sym_ints: Annotated[list[SymIntArgument], 120]
as_scalar_type: Annotated[ScalarType, 130]
as_memory_format: Annotated[MemoryFormat, 140]
as_layout: Annotated[Layout, 150]
as_device: Annotated[Device, 160]
as_bool: Annotated[bool, 170]
as_bools: Annotated[List[bool], 180]
as_bools: Annotated[list[bool], 180]
as_sym_bool: Annotated[SymBoolArgument, 182]
as_sym_bools: Annotated[List[SymBoolArgument], 184]
as_sym_bools: Annotated[list[SymBoolArgument], 184]
as_graph: Annotated[GraphArgument, 200]
as_optional_tensors: Annotated[List[OptionalTensorArgument], 190]
as_optional_tensors: Annotated[list[OptionalTensorArgument], 190]
as_custom_obj: Annotated[CustomObjArgument, 210]
as_operator: Annotated[str, 220]
as_sym_float: Annotated[SymFloatArgument, 230]
as_sym_floats: Annotated[List[SymFloatArgument], 240]
as_sym_floats: Annotated[list[SymFloatArgument], 240]
class ArgumentKind(IntEnum):
@ -217,27 +217,27 @@ class NamedArgument:
@dataclass
class Node:
target: Annotated[str, 10]
inputs: Annotated[List[NamedArgument], 20]
outputs: Annotated[List[Argument], 30]
metadata: Annotated[Dict[str, str], 40]
inputs: Annotated[list[NamedArgument], 20]
outputs: Annotated[list[Argument], 30]
metadata: Annotated[dict[str, str], 40]
is_hop_single_tensor_return: Annotated[Optional[bool], 50] = None
@dataclass
class Graph:
inputs: Annotated[List[Argument], 10]
outputs: Annotated[List[Argument], 20]
nodes: Annotated[List[Node], 30]
tensor_values: Annotated[Dict[str, TensorMeta], 40]
sym_int_values: Annotated[Dict[str, SymInt], 50]
sym_bool_values: Annotated[Dict[str, SymBool], 60]
inputs: Annotated[list[Argument], 10]
outputs: Annotated[list[Argument], 20]
nodes: Annotated[list[Node], 30]
tensor_values: Annotated[dict[str, TensorMeta], 40]
sym_int_values: Annotated[dict[str, SymInt], 50]
sym_bool_values: Annotated[dict[str, SymBool], 60]
# This is for deserializing the submodule graphs from higher order ops
# (ex. cond, map) where single tensor returns will just return a single
# tensor, rather than following export schema and returning a singleton
# list.
is_single_tensor_return: Annotated[bool, 70] = False
custom_obj_values: Annotated[Dict[str, CustomObjArgument], 80] = field(default_factory=dict)
sym_float_values: Annotated[Dict[str, SymFloat], 90] = field(default_factory=dict)
custom_obj_values: Annotated[dict[str, CustomObjArgument], 80] = field(default_factory=dict)
sym_float_values: Annotated[dict[str, SymFloat], 90] = field(default_factory=dict)
@dataclass
class UserInputSpec:
@ -354,8 +354,8 @@ class OutputSpec(_Union):
@dataclass
class GraphSignature:
input_specs: Annotated[List[InputSpec], 10]
output_specs: Annotated[List[OutputSpec], 20]
input_specs: Annotated[list[InputSpec], 10]
output_specs: Annotated[list[OutputSpec], 20]
@dataclass
@ -366,8 +366,8 @@ class RangeConstraint:
@dataclass
class ModuleCallSignature:
inputs: Annotated[List[Argument], 10]
outputs: Annotated[List[Argument], 20]
inputs: Annotated[list[Argument], 10]
outputs: Annotated[list[Argument], 20]
# These are serialized by calling pytree.treespec_dumps
# and deserialized by calling pytree.treespec_loads
@ -376,7 +376,7 @@ class ModuleCallSignature:
# This field is used to prettify the graph placeholders
# after we ser/der and retrace
forward_arg_names: Annotated[Optional[List[str]], 50] = None
forward_arg_names: Annotated[Optional[list[str]], 50] = None
@dataclass
@ -392,8 +392,8 @@ class GraphModule:
# This is used for unflattening, by tracking the calling structure of all of
# the modules in order to unflatten the modules back to the eager calling
# conventions.
module_call_graph: Annotated[List[ModuleCallEntry], 60]
metadata: Annotated[Dict[str, str], 40] = field(default_factory=dict)
module_call_graph: Annotated[list[ModuleCallEntry], 60]
metadata: Annotated[dict[str, str], 40] = field(default_factory=dict)
# Invariant: Every time a change is made to the schema, one of the versions
@ -408,8 +408,8 @@ class SchemaVersion:
class ExportedProgram:
graph_module: Annotated[GraphModule, 10]
# Key is the opset namespace (ex. aten), and value is the version number
opset_version: Annotated[Dict[str, int], 20]
range_constraints: Annotated[Dict[str, RangeConstraint], 30]
opset_version: Annotated[dict[str, int], 20]
range_constraints: Annotated[dict[str, RangeConstraint], 30]
schema_version: Annotated[SchemaVersion, 60]
verifiers: Annotated[List[str], 70] = field(default_factory=list)
verifiers: Annotated[list[str], 70] = field(default_factory=list)
torch_version: Annotated[str, 80] = "<=2.4"

View File

@ -5,7 +5,7 @@ import inspect
import re
import typing
from enum import IntEnum
from typing import Annotated, Any, Dict, ForwardRef, List, Optional, Union
from typing import Annotated, Any, ForwardRef, Optional, Union
from torch._export.serde import schema
from torch._export.serde.union import _Union
@ -36,16 +36,16 @@ _THRIFT_TYPE_MAP = {
def _staged_schema():
yaml_ret: Dict[str, Any] = {}
yaml_ret: dict[str, Any] = {}
defs = {}
cpp_enum_defs: Dict[str, str] = {}
cpp_class_defs: Dict[str, str] = {}
cpp_type_decls: List[str] = []
cpp_json_defs: List[str] = []
thrift_enum_defs: List[str] = []
thrift_type_defs: Dict[str, str] = {}
cpp_enum_defs: dict[str, str] = {}
cpp_class_defs: dict[str, str] = {}
cpp_type_decls: list[str] = []
cpp_json_defs: list[str] = []
thrift_enum_defs: list[str] = []
thrift_type_defs: dict[str, str] = {}
def _handle_aggregate(ty) -> tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
def _handle_aggregate(ty) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
def dump_type(t, level: int) -> tuple[str, str, str]:
if getattr(t, "__name__", None) in cpp_enum_defs:
return t.__name__, "int64_t", t.__name__
@ -125,7 +125,7 @@ def _staged_schema():
f"Default value {v} is not supported yet in export schema."
)
def dump_field(f) -> tuple[Dict[str, Any], str, Optional[str], str, int]:
def dump_field(f) -> tuple[dict[str, Any], str, Optional[str], str, int]:
t, cpp_type, thrift_type = dump_type(f.type, 0)
ret = {"type": t}
cpp_default: Optional[str] = None
@ -524,12 +524,12 @@ def _hash_content(s: str):
@dataclasses.dataclass
class _Commit:
result: Dict[str, Any]
result: dict[str, Any]
checksum_next: str
yaml_path: str
additions: Dict[str, Any]
subtractions: Dict[str, Any]
base: Dict[str, Any]
additions: dict[str, Any]
subtractions: dict[str, Any]
base: dict[str, Any]
checksum_head: Optional[str]
cpp_header: str
cpp_header_path: str

View File

@ -24,15 +24,11 @@ from typing import (
Any,
Callable,
cast,
Dict,
final,
Iterator,
List,
Optional,
Set,
Type,
Union,
)
from collections.abc import Iterator
import sympy
@ -118,7 +114,7 @@ class SerializeError(RuntimeError):
pass
def _reverse_map(d: Dict[Any, Enum]):
def _reverse_map(d: dict[Any, Enum]):
return {v.value: k for k, v in d.items()}
@ -352,7 +348,7 @@ def serialize_torch_artifact(artifact: Optional[Any], pickle_protocol: int = DEF
del copyreg.dispatch_table[FakeTensor]
def deserialize_torch_artifact(serialized: Union[Dict[str, Any], tuple[Any, ...], bytes]):
def deserialize_torch_artifact(serialized: Union[dict[str, Any], tuple[Any, ...], bytes]):
if isinstance(serialized, (dict, tuple)):
return serialized
if len(serialized) == 0:
@ -401,8 +397,8 @@ def _int_to_sympy_int(val: Optional[int], default) -> sympy.Expr:
def serialize_range_constraints(
range_constraints: Dict[sympy.Symbol, ValueRanges]
) -> Dict[str, RangeConstraint]:
range_constraints: dict[sympy.Symbol, ValueRanges]
) -> dict[str, RangeConstraint]:
return {
str(k): RangeConstraint(
_sympy_int_to_int(v.lower, "ceil"), # type: ignore[arg-type]
@ -426,15 +422,15 @@ def _get_schema_from_target(target):
@dataclass
class GraphState:
inputs: List[Argument] = field(default_factory=list)
outputs: List[Argument] = field(default_factory=list)
nodes: List[Node] = field(default_factory=list)
tensor_values: Dict[str, TensorMeta] = field(default_factory=dict)
sym_int_values: Dict[str, SymInt] = field(default_factory=dict)
sym_bool_values: Dict[str, SymBool] = field(default_factory=dict)
sym_float_values: Dict[str, SymFloat] = field(default_factory=dict)
inputs: list[Argument] = field(default_factory=list)
outputs: list[Argument] = field(default_factory=list)
nodes: list[Node] = field(default_factory=list)
tensor_values: dict[str, TensorMeta] = field(default_factory=dict)
sym_int_values: dict[str, SymInt] = field(default_factory=dict)
sym_bool_values: dict[str, SymBool] = field(default_factory=dict)
sym_float_values: dict[str, SymFloat] = field(default_factory=dict)
is_single_tensor_return: bool = False
custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict)
custom_obj_values: dict[str, CustomObjArgument] = field(default_factory=dict)
class Final(type):
@ -450,13 +446,13 @@ class GraphModuleSerializer(metaclass=Final):
def __init__(
self,
graph_signature: ep.ExportGraphSignature,
module_call_graph: List[ep.ModuleCallEntry],
module_call_graph: list[ep.ModuleCallEntry],
):
self.graph_state = GraphState()
self.graph_signature = graph_signature
self.module_call_graph = module_call_graph
self.custom_objs: Dict[str, torch._C.ScriptObject] = {}
self.duplicate_getitem_nodes: Dict[str, str] = {}
self.custom_objs: dict[str, torch._C.ScriptObject] = {}
self.duplicate_getitem_nodes: dict[str, str] = {}
@contextmanager
def save_graph_state(self):
@ -597,7 +593,7 @@ class GraphModuleSerializer(metaclass=Final):
else:
return user_node.name
def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
def serialize_metadata(self, node: torch.fx.Node) -> dict[str, str]:
ret = {}
if stack_trace := node.meta.get("stack_trace"):
ret["stack_trace"] = stack_trace
@ -647,7 +643,7 @@ class GraphModuleSerializer(metaclass=Final):
class_fqn=script_obj_meta.class_fqn,
)
def serialize_sym_op_inputs(self, op, args) -> List[NamedArgument]:
def serialize_sym_op_inputs(self, op, args) -> list[NamedArgument]:
if isinstance(op, torch._ops.OpOverload):
args_names = [arg.name for arg in op._schema.arguments]
else:
@ -669,7 +665,7 @@ class GraphModuleSerializer(metaclass=Final):
target: Any, # torch._ops.OpOverload and other custom operator types.
args,
kwargs=None
) -> List[NamedArgument]:
) -> list[NamedArgument]:
assert isinstance(target, (torch._ops.OpOverload, *_registered_extension_types()))
kwargs = kwargs or {}
serialized_args = []
@ -700,7 +696,7 @@ class GraphModuleSerializer(metaclass=Final):
return serialized_args
def serialize_hoo_inputs(self, args, kwargs) -> List[NamedArgument]:
def serialize_hoo_inputs(self, args, kwargs) -> list[NamedArgument]:
"""
For serializing HOO inputs since HOOs do not have a schema.
"""
@ -1180,8 +1176,8 @@ class GraphModuleSerializer(metaclass=Final):
)
def serialize_module_call_graph(
self, module_call_graph: List[ep.ModuleCallEntry]
) -> List[ModuleCallEntry]:
self, module_call_graph: list[ep.ModuleCallEntry]
) -> list[ModuleCallEntry]:
return [
ModuleCallEntry(
fqn=entry.fqn,
@ -1194,7 +1190,7 @@ class GraphModuleSerializer(metaclass=Final):
for entry in module_call_graph
]
def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]:
def serialize_outputs(self, node: torch.fx.Node) -> list[Argument]:
"""For a given node, return the dataclass representing its output values.
[NOTE: Multiple outputs] We handle aggregates differently than FX. For
@ -1294,7 +1290,7 @@ class GraphModuleSerializer(metaclass=Final):
return output_arguments
def serialize_hoo_outputs(self, node: torch.fx.Node) -> List[Argument]:
def serialize_hoo_outputs(self, node: torch.fx.Node) -> list[Argument]:
"""
For serializing HOO outputs since HOOs do not have a schema.
"""
@ -1359,7 +1355,7 @@ class GraphModuleSerializer(metaclass=Final):
# list outputs should've been handled earlier
raise SerializeError(f"Unable to serialize output {meta_val}")
def _handle_getitem_users(self, node: torch.fx.Node) -> List[TensorArgument]:
def _handle_getitem_users(self, node: torch.fx.Node) -> list[TensorArgument]:
meta_val = node.meta["val"]
idx_to_name = {}
@ -1406,7 +1402,7 @@ class GraphModuleSerializer(metaclass=Final):
is_single_tensor_return=self.graph_state.is_single_tensor_return,
)
def serialize_graph_module_metadata(self, meta: Dict[str, Any]):
def serialize_graph_module_metadata(self, meta: dict[str, Any]):
ret = {}
if custom := meta.get("custom"):
try:
@ -1431,8 +1427,8 @@ class GraphModuleSerializer(metaclass=Final):
@final
class ExportedProgramSerializer(metaclass=Final):
def __init__(self, opset_version: Optional[Dict[str, int]] = None, pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL):
self.opset_version: Dict[str, int] = {}
def __init__(self, opset_version: Optional[dict[str, int]] = None, pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL):
self.opset_version: dict[str, int] = {}
if opset_version:
self.opset_version.update(opset_version)
if "aten" not in self.opset_version:
@ -1498,15 +1494,15 @@ class GraphModuleDeserializer(metaclass=Final):
class Result:
graph_module: torch.fx.GraphModule
signature: ep.ExportGraphSignature
module_call_graph: List[ep.ModuleCallEntry]
names_to_symbols: Dict[str, sympy.Symbol]
state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]]
constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]]
example_inputs: Optional[tuple[tuple[torch.Tensor, ...], Dict[str, Any]]]
module_call_graph: list[ep.ModuleCallEntry]
names_to_symbols: dict[str, sympy.Symbol]
state_dict: dict[str, Union[torch.Tensor, torch.nn.Parameter]]
constants: dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]]
example_inputs: Optional[tuple[tuple[torch.Tensor, ...], dict[str, Any]]]
def __init__(self) -> None:
self.serialized_name_to_node: Dict[str, torch.fx.Node] = {}
self.serialized_name_to_meta: Dict[str, MetaType] = {}
self.serialized_name_to_node: dict[str, torch.fx.Node] = {}
self.serialized_name_to_meta: dict[str, MetaType] = {}
self.graph = torch.fx.Graph()
self.module = torch.nn.Module()
@ -1953,10 +1949,10 @@ class GraphModuleDeserializer(metaclass=Final):
def deserialize(
self,
serialized_graph_module: GraphModule,
serialized_state_dict: Union[Dict[str, torch.Tensor], bytes],
constants: Union[Dict[str, Any], bytes],
example_inputs: Optional[Union[tuple[tuple[torch.Tensor, ...], Dict[str, Any]], bytes]] = None,
symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None,
serialized_state_dict: Union[dict[str, torch.Tensor], bytes],
constants: Union[dict[str, Any], bytes],
example_inputs: Optional[Union[tuple[tuple[torch.Tensor, ...], dict[str, Any]], bytes]] = None,
symbol_name_to_range: Optional[dict[str, symbolic_shapes.ValueRanges]] = None,
) -> Result:
global _CURRENT_DESERIALIZER
assert _CURRENT_DESERIALIZER is None
@ -1996,7 +1992,7 @@ class GraphModuleDeserializer(metaclass=Final):
"ToFloat": torch.utils._sympy.functions.ToFloat,
"Identity": torch.utils._sympy.functions.Identity,
}
self.symbol_name_to_symbol: Dict[str, sympy.Symbol] = {}
self.symbol_name_to_symbol: dict[str, sympy.Symbol] = {}
self.constants = deserialize_torch_artifact(constants)
self.signature = self.deserialize_signature(serialized_graph_module.signature)
@ -2091,7 +2087,7 @@ class GraphModuleDeserializer(metaclass=Final):
kwargs[schema_arg.name] = actual_args[schema_arg.name]
return tuple(args), kwargs
def deserialize_hoo_inputs(self, inputs: List[NamedArgument]):
def deserialize_hoo_inputs(self, inputs: list[NamedArgument]):
"""
For deserializing HOO inputs since HOOs do not have a schema.
"""
@ -2232,7 +2228,7 @@ class GraphModuleDeserializer(metaclass=Final):
"torch.ops.higher_order" in serialized_node.target
and not getattr(serialized_node, "is_hop_single_tensor_return", True)
):
meta_val: List[Any] = []
meta_val: list[Any] = []
arg = serialized_node.outputs[0].as_tensor
deserialized_metadata = self.deserialize_metadata(serialized_node.metadata)
self.generate_getitem(meta_val, fx_node, arg, 0, deserialized_metadata)
@ -2261,7 +2257,7 @@ class GraphModuleDeserializer(metaclass=Final):
fx_node: torch.fx.Node,
arg: Union[TensorArgument, SymIntArgument, SymFloatArgument],
idx: int,
deserialized_metadata: Dict[str, Any],
deserialized_metadata: dict[str, Any],
):
if isinstance(arg, TensorArgument):
name = arg.name
@ -2290,7 +2286,7 @@ class GraphModuleDeserializer(metaclass=Final):
meta_val,
fx_node: torch.fx.Node,
args,
deserialized_metadata: Dict[str, Any],
deserialized_metadata: dict[str, Any],
):
for idx, arg in enumerate(args):
if isinstance(arg, (TensorArgument, SymIntArgument, SymFloatArgument)):
@ -2343,7 +2339,7 @@ class GraphModuleDeserializer(metaclass=Final):
# return value.
# This performs the inverse mapping of the `serialize_outputs` call in
# serialization, see [NOTE: Multiple outputs]
meta_val: List[Any] = []
meta_val: list[Any] = []
if len(serialized_node.outputs) == 1:
assert isinstance(serialized_node.outputs[0].value, list)
assert isinstance(serialized_node.outputs[0].value[0], TensorArgument)
@ -2355,8 +2351,8 @@ class GraphModuleDeserializer(metaclass=Final):
fx_node.meta["val"] = tuple(meta_val)
self.serialized_name_to_node[fx_node.name] = fx_node
def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
ret: Dict[str, Any] = {}
def deserialize_metadata(self, metadata: dict[str, str]) -> dict[str, Any]:
ret: dict[str, Any] = {}
if stack_trace := metadata.get("stack_trace"):
ret["stack_trace"] = stack_trace
@ -2446,8 +2442,8 @@ class GraphModuleDeserializer(metaclass=Final):
)
def deserialize_module_call_graph(
self, module_call_graph: List[ModuleCallEntry]
) -> List[ep.ModuleCallEntry]:
self, module_call_graph: list[ModuleCallEntry]
) -> list[ep.ModuleCallEntry]:
return [
ep.ModuleCallEntry(
fqn=entry.fqn,
@ -2463,8 +2459,8 @@ class GraphModuleDeserializer(metaclass=Final):
@final
class ExportedProgramDeserializer(metaclass=Final):
def __init__(self, expected_opset_version: Optional[Dict[str, int]] = None):
self.expected_opset_version: Dict[str, int] = {}
def __init__(self, expected_opset_version: Optional[dict[str, int]] = None):
self.expected_opset_version: dict[str, int] = {}
if expected_opset_version:
self.expected_opset_version.update(expected_opset_version)
if "aten" not in self.expected_opset_version:
@ -2472,9 +2468,9 @@ class ExportedProgramDeserializer(metaclass=Final):
def deserialize_range_constraints(
self,
symbol_name_to_range: Dict[str, symbolic_shapes.ValueRanges],
symbol_name_to_symbol: Dict[str, sympy.Symbol],
) -> Dict[sympy.Symbol, ValueRanges]:
symbol_name_to_range: dict[str, symbolic_shapes.ValueRanges],
symbol_name_to_symbol: dict[str, sympy.Symbol],
) -> dict[sympy.Symbol, ValueRanges]:
range_constraints = {}
for k, v in symbol_name_to_range.items():
if symbol := symbol_name_to_symbol.get(k):
@ -2486,9 +2482,9 @@ class ExportedProgramDeserializer(metaclass=Final):
def deserialize(
self,
exported_program: ExportedProgram,
state_dict: Union[Dict[str, torch.Tensor], bytes],
constants: Union[Dict[str, torch.Tensor], bytes],
example_inputs: Optional[Union[tuple[tuple[torch.Tensor, ...], Dict[str, Any]], bytes]] = None,
state_dict: Union[dict[str, torch.Tensor], bytes],
constants: Union[dict[str, torch.Tensor], bytes],
example_inputs: Optional[Union[tuple[tuple[torch.Tensor, ...], dict[str, Any]], bytes]] = None,
*,
_unsafe_skip_version_check=False,
) -> ep.ExportedProgram:
@ -2566,7 +2562,7 @@ def _dataclass_to_dict(obj):
def serialize(
exported_program: ep.ExportedProgram,
opset_version: Optional[Dict[str, int]] = None,
opset_version: Optional[dict[str, int]] = None,
pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
) -> SerializedArtifact:
with _enable_graph_inputs_of_type_nn_module(exported_program.example_inputs):
@ -2627,7 +2623,7 @@ def _dict_to_dataclass(cls, data):
def deserialize(
artifact: SerializedArtifact,
expected_opset_version: Optional[Dict[str, int]] = None,
expected_opset_version: Optional[dict[str, int]] = None,
*,
_unsafe_skip_version_check=False,
) -> ep.ExportedProgram:
@ -2649,7 +2645,7 @@ def deserialize(
def _canonicalize_graph(
sorted_inputs, sorted_outputs, graph
) -> tuple[Graph, Dict[str, str]]:
) -> tuple[Graph, dict[str, str]]:
def _get_argument(a: Argument):
if a.type == "as_none":
return None
@ -2712,15 +2708,15 @@ def _canonicalize_graph(
def sort_nodes(nodes):
@dataclass
class Edges:
outs: List[int]
outs: list[int]
ins: int
graph_inputs: Set[str] = set()
def_table: Dict[str, int] = {}
edges: Dict[int, Edges] = {}
candidates: List[tuple[str, List[tuple[str, List[int]]], int]] = []
rank: Dict[str, int] = {}
ret: List[Node] = []
graph_inputs: set[str] = set()
def_table: dict[str, int] = {}
edges: dict[int, Edges] = {}
candidates: list[tuple[str, list[tuple[str, list[int]]], int]] = []
rank: dict[str, int] = {}
ret: list[Node] = []
def get_name(a) -> Optional[str]:
if a is None:
@ -2827,7 +2823,7 @@ def _canonicalize_graph(
assert len(sorted_nodes) == len(graph.nodes)
# Stage 2: Rename nodes.
name_table: Dict[str, str] = {}
name_table: dict[str, str] = {}
def rename_def(a):
def _rename(arg_name, values):
@ -3163,8 +3159,8 @@ class ExtensionHandler:
def register_extension(
op_type: Type[Any],
extension_handler: Type[ExtensionHandler],
op_type: type[Any],
extension_handler: type[ExtensionHandler],
):
"""Register custom de/serialization method for a node with non-standard type."""
assert issubclass(extension_handler, ExtensionHandler), f"Expected ExtensionHandler, got {extension_handler}."
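The same op-type-to-handler registry idea, reduced to a self-contained sketch (Handler and the registry below are illustrative, not the torch._export classes):

from typing import Any

class Handler:                                   # stand-in for ExtensionHandler
    pass

_registry: dict[type[Any], type[Handler]] = {}

def register(op_type: type[Any], handler: type[Handler]) -> None:
    assert issubclass(handler, Handler), f"Expected Handler, got {handler}."
    assert op_type not in _registry, f"{op_type} is already registered."
    _registry[op_type] = handler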
@ -3187,5 +3183,5 @@ def _registered_extension_types():
# namespace to avoid conflicts.
# Serialization: Op type --> custom handler.
# De-serialization: Namespace --> custom handler.
_serialization_registry: Dict[Type[Any], Type[ExtensionHandler]] = {}
_deserialization_registry: Dict[str, Type[ExtensionHandler]] = {}
_serialization_registry: dict[type[Any], type[ExtensionHandler]] = {}
_deserialization_registry: dict[str, type[ExtensionHandler]] = {}

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import functools
from collections.abc import Hashable
from dataclasses import fields
from typing import Hashable, Set
class _UnionTag(str):
@ -26,8 +26,8 @@ class _UnionTag(str):
return hash(str(self))
@functools.lru_cache(maxsize=None)
def _get_field_names(cls) -> Set[str]:
@functools.cache
def _get_field_names(cls) -> set[str]:
return {f.name for f in fields(cls)}
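functools.cache (Python 3.9+) is equivalent to lru_cache(maxsize=None); a quick sketch of the cached field-name lookup with an illustrative dataclass:

import functools
from dataclasses import dataclass, fields

@functools.cache
def field_names(cls) -> set[str]:
    return {f.name for f in fields(cls)}

@dataclass
class Point:          # illustrative
    x: int
    y: int

assert field_names(Point) == {"x", "y"}   # a second call hits the cache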

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
import logging
import warnings
from typing import Any, Dict, Iterable, Optional
from collections.abc import Iterable
from typing import Any, Optional
import torch
import torch.export
@ -18,8 +19,8 @@ def _generate_inputs_for_submodules(
model: torch.nn.Module,
target_submodules: Iterable[str],
args: tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, tuple[Any, Any]]:
kwargs: Optional[dict[str, Any]] = None,
) -> dict[str, tuple[Any, Any]]:
"""
Generate inputs for targeting submodules in the given model. Note that if two submodules refer to the same obj, this
function doesn't work.
@ -61,11 +62,11 @@ def _generate_inputs_for_submodules(
def report_exportability(
mod: torch.nn.Module,
args: tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
*,
strict: bool = True,
pre_dispatch: bool = False,
) -> Dict[str, Optional[Exception]]:
) -> dict[str, Optional[Exception]]:
"""
Report exportability issues for a module in one-shot.
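A hedged usage sketch (the module and inputs are illustrative, and the import path is assumed to be torch._export.tools); the returned dict maps submodule FQNs to the exception raised during export, or None when export succeeded:

import torch
from torch._export.tools import report_exportability  # assumed import path

class M(torch.nn.Module):
    def forward(self, x):
        return x.relu()

report = report_exportability(M(), (torch.randn(2, 2),))
print(report)  # e.g. {"": None} when the root module exports cleanly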
@ -92,7 +93,7 @@ def report_exportability(
submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs)
tried_module_types = set()
report: Dict[str, Optional[Exception]] = {}
report: dict[str, Optional[Exception]] = {}
def try_export(module, module_name, args, kwargs):
nonlocal submod_inputs, report, strict, pre_dispatch, tried_module_types

View File

@ -8,21 +8,10 @@ import json
import math
import operator
import re
from collections.abc import Iterable
from contextlib import contextmanager
from inspect import Parameter
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import torch
from torch._guards import detect_fake_mode
@ -116,7 +105,7 @@ def _overwrite_signature_for_non_persistent_buffers(
return new_sig
def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> Dict[str, Any]:
def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> dict[str, Any]:
"""
Param/buffer metadata needs to be saved before lowering to aten IR
because aten IR lifts them, as a result, automatic preservation doesn't work.
@ -174,7 +163,7 @@ def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> Dict[str, Any]:
def _populate_param_buffer_metadata_to_new_gm(
params_buffers_to_node_meta: Dict[str, Any],
params_buffers_to_node_meta: dict[str, Any],
gm: torch.fx.GraphModule,
new_sig: "ExportGraphSignature",
) -> None:
@ -217,7 +206,7 @@ def _get_shape_env_from_gm(gm: torch.fx.GraphModule):
def _rename_without_collisions(
name_map: Dict[str, str],
name_map: dict[str, str],
orig_name: str,
name: str,
is_placeholder: bool = False,
@ -246,7 +235,7 @@ def _rename_without_collisions(
def _check_input_constraints_for_graph(
input_placeholders: List[torch.fx.Node], flat_args_with_path, range_constraints
input_placeholders: list[torch.fx.Node], flat_args_with_path, range_constraints
) -> None:
def get_keystr(key_path: KeyPath) -> str:
"""For a given index into the flat_args, return a human readable string
@ -280,7 +269,7 @@ def _check_input_constraints_for_graph(
# NOTE: export already guarantees that the same symbol is used in metadata
# for all InputDims related by equality constraints, so we can just unify
# symbols with given input dimension values to check equality constraints.
unification_map: Dict[sympy.Symbol, Any] = {}
unification_map: dict[sympy.Symbol, Any] = {}
for (key_path, arg), node in zip(flat_args_with_path, input_placeholders):
node_val = node.meta.get("val")
if isinstance(node_val, FakeTensor):
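The unification idea in isolation (a simplified sketch over plain ints, not the real FakeTensor-based check): the first concrete value seen for a symbol is recorded, and every later occurrence must agree with it.

import sympy

s0 = sympy.Symbol("s0")
expected_dims = [s0, 3, s0]   # hypothetical placeholder dims (symbols or fixed ints)
actual_dims = [4, 3, 4]       # dims of the runtime inputs

unification_map: dict[sympy.Symbol, int] = {}
for expected, actual in zip(expected_dims, actual_dims):
    if isinstance(expected, sympy.Symbol):
        if expected in unification_map and unification_map[expected] != actual:
            raise RuntimeError(f"{expected} was {unification_map[expected]}, got {actual}")
        unification_map[expected] = actual
    elif expected != actual:
        raise RuntimeError(f"expected dim {expected}, got {actual}")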
@ -372,7 +361,7 @@ def _check_input_constraints_for_graph(
def register_dataclass_as_pytree_node(
cls: Type[Any],
cls: type[Any],
flatten_fn: Optional[FlattenFunc] = None,
unflatten_fn: Optional[UnflattenFunc] = None,
*,
@ -385,7 +374,7 @@ def register_dataclass_as_pytree_node(
cls
), f"Only dataclasses can be registered with this function: {cls}"
def default_flatten_fn(obj: Any) -> tuple[List[Any], Context]:
def default_flatten_fn(obj: Any) -> tuple[list[Any], Context]:
flattened = []
flat_names = []
none_names = []
@ -402,7 +391,7 @@ def register_dataclass_as_pytree_node(
flat_names, none_names = context
return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))
def default_flatten_fn_with_keys(obj: Any) -> tuple[List[Any], Context]:
def default_flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]:
flattened, (flat_names, _none_names) = flatten_fn(obj) # type: ignore[misc]
return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names
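The flatten/unflatten round trip these defaults implement, shown in isolation for a hypothetical dataclass (the real registration also tracks None-valued fields through the context):

from dataclasses import dataclass, fields
from typing import Any

@dataclass
class Inputs:                      # illustrative
    x: int
    y: int

def flatten(obj: Inputs) -> tuple[list[Any], list[str]]:
    names = [f.name for f in fields(obj)]
    return [getattr(obj, n) for n in names], names

def unflatten(values: list[Any], context: list[str]) -> Inputs:
    return Inputs(**dict(zip(context, values)))

vals, ctx = flatten(Inputs(1, 2))
assert unflatten(vals, ctx) == Inputs(1, 2)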
@ -537,7 +526,7 @@ def sequential_split(
return new_gm
def nodes_filter(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]:
def nodes_filter(nodes: list[torch.fx.Node], node_call_back) -> list[torch.fx.Node]:
"""Returns the nodes that match the node_call_back as a list."""
return [node for node in nodes if node_call_back(node)]
@ -572,7 +561,7 @@ def apply_runtime_assertion_pass(gm: torch.fx.GraphModule, graph_signature):
def nodes_first(
nodes: List[torch.fx.Node], node_call_back=None
nodes: list[torch.fx.Node], node_call_back=None
) -> Optional[torch.fx.Node]:
"""
Returns the first node that matches the node_call_back. If no node matches, returns None.
@ -584,12 +573,12 @@ def nodes_first(
return None
def nodes_count(nodes: List[torch.fx.Node], node_call_back) -> int:
def nodes_count(nodes: list[torch.fx.Node], node_call_back) -> int:
"""Returns the number of nodes that match the node_call_back."""
return len(nodes_filter(nodes, node_call_back))
def nodes_map(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]:
def nodes_map(nodes: list[torch.fx.Node], node_call_back) -> list[torch.fx.Node]:
"""
Sequentially visit the nodes list and invoke node_call_back on each element.
Returns the nodes list after the node_call_back is invoked on each element.
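A quick sketch of how these small helpers compose over a traced graph (import path assumed to be torch._export.utils):

import torch
from torch._export.utils import nodes_count, nodes_filter, nodes_first  # assumed path

def fn(x):
    return x + 1

gm = torch.fx.symbolic_trace(fn)
nodes = list(gm.graph.nodes)

placeholders = nodes_filter(nodes, lambda n: n.op == "placeholder")
first_call = nodes_first(nodes, lambda n: n.op == "call_function")
assert nodes_count(nodes, lambda n: n.op == "output") == 1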
@ -748,7 +737,7 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None:
and gather the top-level named placeholder nodes.
"""
# gather all HOO subgraphs and their top-level named placeholder nodes
subgraph_ph_tuples: List[tuple[torch.fx.GraphModule, List[torch.fx.Node]]] = []
subgraph_ph_tuples: list[tuple[torch.fx.GraphModule, list[torch.fx.Node]]] = []
for node in gm.graph.nodes:
if node.op == "call_function" and isinstance(
node.target, torch._ops.HigherOrderOperator
@ -769,7 +758,7 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None:
# propagate names
for subgraph, hoo_phs in subgraph_ph_tuples:
name_map: Dict[str, str] = {}
name_map: dict[str, str] = {}
for i, node in enumerate(subgraph.graph.nodes):
if i < len(hoo_phs): # placeholder, retain name
name_map[node.name] = hoo_phs[i].name
@ -789,7 +778,7 @@ def placeholder_naming_pass(
fake_args,
fake_kwargs,
fake_params_buffers,
constants: Dict[str, Any],
constants: dict[str, Any],
) -> None:
"""
This pass is run at the end of _export_non_strict() to assign better placeholder node names:
@ -828,7 +817,7 @@ def placeholder_naming_pass(
else:
raise RuntimeError(f"Pytree key of type {type(x)} not handled for {x}")
name_map: Dict[str, str] = {}
name_map: dict[str, str] = {}
# map user input names with mod.forward() signature
combined_args = _bind_signature_to_inputs(mod, fake_args, fake_kwargs)
@ -927,7 +916,7 @@ def placeholder_naming_pass(
del constants[name]
def remove_proxy_from_state_dict(state_dict: Dict, in_place: bool) -> Dict:
def remove_proxy_from_state_dict(state_dict: dict, in_place: bool) -> dict:
"""
If `in_place` is false, return a new copy of `state_dict` with "proxy" removed from `v.__dict__`.
Here `v` denotes the values in the dictionary.
@ -957,8 +946,8 @@ def _detect_fake_mode_from_gm(
If no fake mode is found, we return None for fake_mode.
"""
fake_inps: List[torch.Tensor] = []
fake_vals: List[torch.Tensor] = []
fake_inps: list[torch.Tensor] = []
fake_vals: list[torch.Tensor] = []
for node in gm.graph.nodes:
if node.op == "placeholder" and "val" in node.meta:
fake_val = node.meta["val"]
@ -980,8 +969,8 @@ def _detect_fake_mode_from_gm(
@contextmanager
def _disable_load_state_dict_hooks(mod: torch.nn.Module):
state_dict_hooks: Dict[int, Callable] = dict(mod._state_dict_hooks)
state_dict_pre_hooks: Dict[int, Callable] = dict(mod._state_dict_pre_hooks)
state_dict_hooks: dict[int, Callable] = dict(mod._state_dict_hooks)
state_dict_pre_hooks: dict[int, Callable] = dict(mod._state_dict_pre_hooks)
mod._state_dict_hooks.clear()
mod._state_dict_pre_hooks.clear()
try:
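The save/clear/restore shape of this context manager, in a generic form (a sketch over a plain dict rather than the module's hook state):

from contextlib import contextmanager

@contextmanager
def _suspended(entries: dict):
    saved = dict(entries)      # snapshot the current entries
    entries.clear()
    try:
        yield
    finally:
        entries.update(saved)  # restore even if the body raises

hooks = {1: print}
with _suspended(hooks):
    assert not hooks           # hooks are disabled inside the block
assert hooks == {1: print}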
@ -1075,11 +1064,11 @@ def _check_valid_to_preserve(op_overload: "OperatorBase"):
@functools.lru_cache(maxsize=1)
def _collect_all_valid_cia_ops_for_aten_namespace() -> Set["OperatorBase"]:
def _collect_all_valid_cia_ops_for_aten_namespace() -> set["OperatorBase"]:
return _collect_all_valid_cia_ops_for_namespace("aten")
def _collect_all_valid_cia_ops_for_namespace(namespace: str) -> Set["OperatorBase"]:
def _collect_all_valid_cia_ops_for_namespace(namespace: str) -> set["OperatorBase"]:
# Step 1: Materialize all ops from C++ dispatcher
_materialize_cpp_cia_ops()
@ -1096,7 +1085,7 @@ def _collect_all_valid_cia_ops_for_namespace(namespace: str) -> Set["OperatorBas
return cia_ops
def _collect_all_valid_cia_ops() -> Set["OperatorBase"]:
def _collect_all_valid_cia_ops() -> set["OperatorBase"]:
"""
This is a utility function that gets all the CIA functional ops.
@ -1166,14 +1155,14 @@ def _compiling_state_context():
def _fakify_params_buffers(
fake_mode: FakeTensorMode,
mod: torch.nn.Module,
) -> Dict[str, Union[torch.Tensor, torch.nn.Parameter]]:
) -> dict[str, Union[torch.Tensor, torch.nn.Parameter]]:
params_buffers = {
**dict(mod.named_parameters(remove_duplicate=False)),
**dict(mod.named_buffers(remove_duplicate=False)),
}
faked_params_buffers = {}
memo: Dict[int, FakeTensor] = {}
memo: dict[int, FakeTensor] = {}
for key, value in params_buffers.items():
if id(value) in memo:
fake_tensor = memo[id(value)]
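The id-based memoization pattern used in this loop, isolated (a sketch; the fake-tensor conversion is replaced by a stand-in transform):

shared = [1, 2]
params = {"a.weight": shared, "b.weight": shared}   # two names, one underlying object

memo: dict[int, list] = {}
converted = {}
for name, value in params.items():
    if id(value) in memo:
        converted[name] = memo[id(value)]           # reuse the already-converted object
    else:
        converted[name] = memo[id(value)] = list(value)

assert converted["a.weight"] is converted["b.weight"]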
@ -1184,7 +1173,7 @@ def _fakify_params_buffers(
return faked_params_buffers # type: ignore[return-value]
def register_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
def register_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None:
"""
Registers a module as a valid input type for :func:`torch.export.export`.
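A hedged usage sketch (assumes the helper is importable from torch._export.utils; the module class is illustrative):

import torch
from torch._export.utils import register_module_as_pytree_input_node  # assumed path

class Inner(torch.nn.Module):
    def forward(self, x):
        return x * 2

register_module_as_pytree_input_node(Inner)
# Instances of Inner may now appear among the example inputs given to torch.export.export.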
@ -1233,7 +1222,7 @@ def register_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
def __deepcopy__(self, memo):
return PrototypeModule(self())
def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
def default_flatten_fn(obj: Any) -> tuple[list[Any], Context]:
named_parameters = dict(obj.named_parameters())
named_buffers = dict(obj.named_buffers())
params_buffers = {**named_parameters, **named_buffers}
@ -1270,7 +1259,7 @@ def register_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
ret = obj
return ret
def default_flatten_fn_with_keys(obj: Any) -> Tuple[List[Any], Context]:
def default_flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]:
flattened, [flat_names, *args] = flatten_fn(obj) # type: ignore[misc]
return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], [
flat_names,
@ -1301,7 +1290,7 @@ def register_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
from_dumpable_context=from_dumpable_context,
)
def default_flatten_fn_spec(obj, spec) -> List[Any]:
def default_flatten_fn_spec(obj, spec) -> list[Any]:
flats, context = flatten_fn(obj)
assert context == spec.context
return flats
@ -1312,6 +1301,6 @@ def register_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
)
def deregister_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
def deregister_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None:
_deregister_pytree_node(cls)
_deregister_pytree_flatten_spec(cls)

View File

@ -4,7 +4,7 @@ import inspect
import math
import operator
from collections.abc import Iterable
from typing import Any, Dict, final, List, Type, TYPE_CHECKING
from typing import Any, final, TYPE_CHECKING
import torch
from torch._ops import HigherOrderOperator, OpOverload
@ -83,7 +83,7 @@ def _check_torch_fn(node: torch.fx.Node) -> None:
raise SpecViolationError(f"Node.meta {node.name} has invalid torch_fn field {torch_fn}")
class _VerifierMeta(type):
_registry: Dict[str, Type['Verifier']] = {}
_registry: dict[str, type['Verifier']] = {}
def __new__(metacls, name, bases, attrs):
if bases:
@ -113,7 +113,7 @@ def getattr_recursive(obj: Any, target: str) -> Any:
class Verifier(metaclass=_VerifierMeta):
dialect = "ATEN"
def allowed_builtin_ops(self) -> List:
def allowed_builtin_ops(self) -> list:
return [
operator.getitem,
operator.add,
@ -141,10 +141,10 @@ class Verifier(metaclass=_VerifierMeta):
builtins.getattr,
]
def allowed_op_types(self) -> tuple[Type[Any], ...]:
def allowed_op_types(self) -> tuple[type[Any], ...]:
return (OpOverload, HigherOrderOperator)
def allowed_getattr_types(self) -> tuple[Type[Any], ...]:
def allowed_getattr_types(self) -> tuple[type[Any], ...]:
return (torch.fx.GraphModule,)
def check_valid_op(self, op):
@ -163,18 +163,18 @@ class Verifier(metaclass=_VerifierMeta):
@final
def _check_graph_module(self, gm: torch.fx.GraphModule) -> None:
def _allowed_getattr_types() -> tuple[Type[Any], ...]:
def _allowed_getattr_types() -> tuple[type[Any], ...]:
ret = self.allowed_getattr_types()
assert not any(t is object for t in ret)
return ret
def _check_valid_op(op) -> None:
def _allowed_builtin_ops() -> List:
def _allowed_builtin_ops() -> list:
ret = self.allowed_builtin_ops()
assert all(inspect.isbuiltin(op) for op in ret)
return ret
def _allowed_op_types() -> tuple[Type[Any], ...]:
def _allowed_op_types() -> tuple[type[Any], ...]:
ret = self.allowed_op_types()
assert not any(t is object for t in ret)
return ret
@ -426,7 +426,7 @@ def _verify_exported_program_signature(exported_program) -> None:
num_tokens = len(gs.output_tokens)
end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens
mutate_nodes: List[str] = output_nodes[num_tokens:end]
mutate_nodes: list[str] = output_nodes[num_tokens:end]
user_output_nodes = output_nodes[end:end + len(gs.user_outputs)]
for mutation_node in mutate_nodes:
@ -458,7 +458,7 @@ def _verify_exported_program_signature(exported_program) -> None:
)
def load_verifier(dialect: str) -> Type[Verifier]:
def load_verifier(dialect: str) -> type[Verifier]:
if dialect == "ATEN" or dialect == "":
return _VerifierMeta._registry.get(dialect, Verifier)
return _VerifierMeta._registry[dialect]
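The metaclass registry behind _VerifierMeta and load_verifier, reduced to a standalone sketch (class names here are illustrative, not the torch._export.verifier classes):

class _Meta(type):
    _registry: dict[str, type] = {}

    def __new__(metacls, name, bases, attrs):
        cls = super().__new__(metacls, name, bases, attrs)
        if bases:                                    # only subclasses get registered
            _Meta._registry[attrs["dialect"]] = cls
        return cls

class BaseVerifier(metaclass=_Meta):
    dialect = "ATEN"

class EdgeVerifier(BaseVerifier):
    dialect = "EDGE"

def load(dialect: str) -> type[BaseVerifier]:
    if dialect in ("", "ATEN"):
        return _Meta._registry.get(dialect, BaseVerifier)
    return _Meta._registry[dialect]

assert load("EDGE") is EdgeVerifier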