diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index 1b53f63a5ab4..ddba404ccdda 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -76,14 +76,14 @@ def aot_compile_warning(): def aot_compile( f: Callable, args: tuple[Any], - kwargs: Optional[Dict[str, Any]] = None, + kwargs: Optional[dict[str, Any]] = None, *, - dynamic_shapes: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, Any]] = None, + dynamic_shapes: Optional[dict[str, Any]] = None, + options: Optional[dict[str, Any]] = None, remove_runtime_assertions: bool = False, disable_constraint_solver: bool = False, same_signature: bool = True, -) -> Union[List[str], str]: +) -> Union[list[str], str]: """ Note: this function is not stable yet diff --git a/torch/_export/converter.py b/torch/_export/converter.py index 25ef114c27e2..74eaebcff127 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -4,8 +4,9 @@ import logging import operator import typing import warnings +from collections.abc import Sequence from contextlib import contextmanager -from typing import Any, Dict, List, Optional, Sequence, Set, Union +from typing import Any, Optional, Union import torch import torch.export._trace @@ -72,7 +73,7 @@ def _trace_and_get_graph_from_model(model, args): def _create_jit_graph( model: Union[torch.nn.Module, torch.jit.ScriptFunction], args: Sequence[Any] -) -> tuple[torch.Graph, List["_C.IValue"], Any, Optional[torch.ScriptModule]]: +) -> tuple[torch.Graph, list["_C.IValue"], Any, Optional[torch.ScriptModule]]: if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)): flattened_args = tuple(torch.jit._flatten(tuple(args))[0]) torch_out = None @@ -263,7 +264,7 @@ def construct_fqn(ir, ref_map, name_map): def get_block_to_lifted_attrs( graph: torch._C.Graph, -) -> tuple[Dict[torch._C.Block, Set[str]], Dict[str, str]]: +) -> tuple[dict[torch._C.Block, set[str]], dict[str, str]]: """ Perform two passes to get a mapping of blocks to a set of FQNs of its lifted attributes. When a graph has control flow, the graph will be divided into multiple blocks. We want to convert @@ -280,19 +281,19 @@ def get_block_to_lifted_attrs( """ # A map from a block to its expected to be lifted arguments. - blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]] = {} + blocks_to_lifted_attrs: dict[torch._C.Block, set[str]] = {} # Reference map stores the input (i.e., src) and output (i.e., dest) IR of a # GetAttr node. By traversing this reference map, we can figure out the # full IR aliasing pass and figure out the FQN of an attribute. # E.g., %2 = GetAttr(linear)[%1] --> node_to_parent_map["%2"] = "%1" - node_to_parent_map: Dict[str, str] = {} + node_to_parent_map: dict[str, str] = {} # Used for reconstructing the FQN of an attribute based on the reference map. # In nutshell, for each GetAttr call, GetAttr(input IR, attribute name) -> output IR # This name map stores which attribute name is called for a src IR --> dest IR action. # E.g., %2 = GetAttr(linear)[%1] --> node_to_attr_name["%2"] = "linear" - node_to_attr_name: Dict[str, str] = {} + node_to_attr_name: dict[str, str] = {} def _dfs_get_attr_dependency(entry): """ @@ -315,7 +316,7 @@ def get_block_to_lifted_attrs( Walk the graph in a bottom-up fashion to build the expected to be lifted arguments for each block. """ - arguments: Set[str] = set() + arguments: set[str] = set() for node in entry.nodes(): for block in node.blocks(): # Recursively build. 
@@ -342,7 +343,7 @@ def get_block_to_lifted_attrs( def get_attribute_fqn_from_ts_node( - name_to_attribute_fqn: Dict[str, str], node: torch._C.Node + name_to_attribute_fqn: dict[str, str], node: torch._C.Node ) -> str: def get_attr(name: str): if name in name_to_attribute_fqn: @@ -392,12 +393,12 @@ class TS2FXGraphConverter: def __init__( self, ts_graph: Union[torch._C.Graph, torch._C.Block], - name_to_param: Dict[str, torch.Tensor], - name_to_buffer: Dict[str, torch.Tensor], - blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]], - name_to_non_tensor_attribute: Dict[str, Any], - name_to_constant: Dict[str, Any], - name_to_attribute_fqn: Dict[str, str], + name_to_param: dict[str, torch.Tensor], + name_to_buffer: dict[str, torch.Tensor], + blocks_to_lifted_attrs: dict[torch._C.Block, set[str]], + name_to_non_tensor_attribute: dict[str, Any], + name_to_constant: dict[str, Any], + name_to_attribute_fqn: dict[str, str], ): self.ts_graph = ts_graph # Mapping of parameter FQN to actual parameter value @@ -406,19 +407,19 @@ class TS2FXGraphConverter: self.name_to_buffer = name_to_buffer self.fx_graph: torch.fx.Graph = torch.fx.Graph() - self.input_specs: List[InputSpec] = [] - self.output_specs: List[OutputSpec] = [] + self.input_specs: list[InputSpec] = [] + self.output_specs: list[OutputSpec] = [] # Mapping of TS node name to converted FX node - self.name_to_node: Dict[ - str, Union[torch.fx.Node, List[torch.fx.Node], Dict[Any, torch.fx.Node]] + self.name_to_node: dict[ + str, Union[torch.fx.Node, list[torch.fx.Node], dict[Any, torch.fx.Node]] ] = {} # Mapping of TS node name to constant value (int, str, TorchBind obj, # tensor constants ...) - self.name_to_constant: Dict[str, Any] = name_to_constant + self.name_to_constant: dict[str, Any] = name_to_constant # Mapping from torchscript node output name to attribute fully qualified name - self.name_to_attribute_fqn: Dict[str, str] = name_to_attribute_fqn + self.name_to_attribute_fqn: dict[str, str] = name_to_attribute_fqn # Mapping from fully qualified name to real values or a fx graph node # During convert, this represents the current value of a non-tensor attribute @@ -428,14 +429,14 @@ class TS2FXGraphConverter: # self.count += 1 # c2 = self.count # return x + c1 + c2 - self.name_to_non_tensor_attribute_node: Dict[str, Any] = {} + self.name_to_non_tensor_attribute_node: dict[str, Any] = {} # Mapping from fully qualified name to initial real values inputs # We separate it from self.name_to_non_tensor_attribute_node since # we need initial real value input when we construct fx.GraphModule - self.name_to_non_tensor_attribute: Dict[str, Any] = name_to_non_tensor_attribute + self.name_to_non_tensor_attribute: dict[str, Any] = name_to_non_tensor_attribute - self.subgraphs: Dict[str, torch.fx.GraphModule] = {} + self.subgraphs: dict[str, torch.fx.GraphModule] = {} # Mapping of block to list of attributes that need to be lifted for each # block @@ -457,7 +458,7 @@ class TS2FXGraphConverter: # might have inplace updates to the variable defined in the parent fx graph. After # the execution of that sub-block, the variable defined in the parent fx graph also # needs to be updated. 
- self.name_update_from_subblock_to_parent: Set[str] = set() + self.name_update_from_subblock_to_parent: set[str] = set() def _is_get_attr_node(self, fqn): return ( @@ -469,7 +470,7 @@ class TS2FXGraphConverter: ) ) - def _convert_block_to_subgraph(self, node: torch._C.Node, arguments: List[str]): + def _convert_block_to_subgraph(self, node: torch._C.Node, arguments: list[str]): subgraph_nodes, subgraph_converters = [], [] for block in node.blocks(): subgraph_converter = TS2FXGraphConverter( @@ -506,7 +507,7 @@ class TS2FXGraphConverter: Block[x.1] %2 = x.1 ... """ - arguments: Set[str] = set() + arguments: set[str] = set() for block in entry.blocks(): for block_node in block.nodes(): for block_node_in in block_node.inputs(): @@ -1332,12 +1333,12 @@ class ExplainTS2FXGraphConverter(TS2FXGraphConverter): def __init__( self, ts_graph: Union[torch._C.Graph, torch._C.Block], - name_to_param: Dict[str, torch.Tensor], - name_to_buffer: Dict[str, torch.Tensor], - blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]], - name_to_non_tensor_attribute: Dict[str, Any], - name_to_constant: Dict[str, Any], - name_to_attribute_fqn: Dict[str, str], + name_to_param: dict[str, torch.Tensor], + name_to_buffer: dict[str, torch.Tensor], + blocks_to_lifted_attrs: dict[torch._C.Block, set[str]], + name_to_non_tensor_attribute: dict[str, Any], + name_to_constant: dict[str, Any], + name_to_attribute_fqn: dict[str, str], ): super().__init__( ts_graph, @@ -1350,7 +1351,7 @@ class ExplainTS2FXGraphConverter(TS2FXGraphConverter): ) # Data to keep track of unsupported nodes. - self.unsupported_node_list: List[torch._C.Node] = [] + self.unsupported_node_list: list[torch._C.Node] = [] # Add mock to needed attributes. self.name_to_node = ExplainTS2FXGraphConverter._DictMock( @@ -1395,7 +1396,7 @@ class TS2EPConverter: self, ts_model: Union[torch.jit.ScriptModule, torch.jit.ScriptFunction], sample_args: tuple[Any, ...], - sample_kwargs: Optional[Dict[str, Any]] = None, + sample_kwargs: Optional[dict[str, Any]] = None, ): self.ts_model = ts_model self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args) @@ -1403,8 +1404,8 @@ class TS2EPConverter: self.sample_args = sample_args self.sample_kwargs = sample_kwargs - self.name_to_param: Dict[str, torch.Tensor] = {} - self.name_to_buffer: Dict[str, torch.Tensor] = {} + self.name_to_param: dict[str, torch.Tensor] = {} + self.name_to_buffer: dict[str, torch.Tensor] = {} param_list = ( list(self.ts_model.parameters()) if not isinstance(self.ts_model, torch._C.ScriptFunction) @@ -1422,8 +1423,8 @@ class TS2EPConverter: else: self.name_to_buffer[k] = tensor - self.name_to_non_tensor_attributes: Dict[str, Any] = {} - self.name_to_constant: Dict[str, Any] = {} + self.name_to_non_tensor_attributes: dict[str, Any] = {} + self.name_to_constant: dict[str, Any] = {} self.lift_get_attr() @@ -1509,7 +1510,7 @@ DEBUG: (TORCH_LOGS="+export" ), additionally def retrace_as_exported_program( self, gm: torch.fx.GraphModule, - name_to_constant: Dict[str, Any], + name_to_constant: dict[str, Any], ): dynamic_shapes = _tree_map_with_path( lambda path, x: ( @@ -1569,7 +1570,7 @@ DEBUG: (TORCH_LOGS="+export" ), additionally # TS2FXGraphConverter since it gets attributes from self.ts_model # which is not accessable in TS2FXGraphConverter. It is similar to where # we collect self.name_to_param and self.name_to_buffer. 
- name_to_attribute_fqn: Dict[str, str] = {} + name_to_attribute_fqn: dict[str, str] = {} def get_attr(fqn: str): name = fqn.split(".") diff --git a/torch/_export/db/case.py b/torch/_export/db/case.py index 2ce0b9a87a53..6d32eab79d3e 100644 --- a/torch/_export/db/case.py +++ b/torch/_export/db/case.py @@ -4,12 +4,12 @@ import re import string from dataclasses import dataclass, field from enum import Enum -from typing import Any, Dict, List, Optional, Set +from typing import Any, Optional from types import ModuleType import torch -_TAGS: Dict[str, Dict[str, Any]] = { +_TAGS: dict[str, dict[str, Any]] = { "torch": { "cond": {}, "dynamic-shape": {}, @@ -79,12 +79,12 @@ class ExportCase: description: str # A description of the use case. model: torch.nn.Module name: str - example_kwargs: Dict[str, Any] = field(default_factory=dict) + example_kwargs: dict[str, Any] = field(default_factory=dict) extra_args: Optional[ArgsType] = None # For testing graph generalization. # Tags associated with the use case. (e.g dynamic-shape, escape-hatch) - tags: Set[str] = field(default_factory=set) + tags: set[str] = field(default_factory=set) support_level: SupportLevel = SupportLevel.SUPPORTED - dynamic_shapes: Optional[Dict[str, Any]] = None + dynamic_shapes: Optional[dict[str, Any]] = None def __post_init__(self): check_inputs_type(self.example_args, self.example_kwargs) @@ -98,10 +98,10 @@ class ExportCase: raise ValueError(f'Invalid description: "{self.description}"') -_EXAMPLE_CASES: Dict[str, ExportCase] = {} -_MODULES: Set[ModuleType] = set() -_EXAMPLE_CONFLICT_CASES: Dict[str, List[ExportCase]] = {} -_EXAMPLE_REWRITE_CASES: Dict[str, List[ExportCase]] = {} +_EXAMPLE_CASES: dict[str, ExportCase] = {} +_MODULES: set[ModuleType] = set() +_EXAMPLE_CONFLICT_CASES: dict[str, list[ExportCase]] = {} +_EXAMPLE_REWRITE_CASES: dict[str, list[ExportCase]] = {} def register_db_case(case: ExportCase) -> None: diff --git a/torch/_export/db/examples/list_unpack.py b/torch/_export/db/examples/list_unpack.py index 3e2f8e2469a0..98533cfab549 100644 --- a/torch/_export/db/examples/list_unpack.py +++ b/torch/_export/db/examples/list_unpack.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import List import torch @@ -9,7 +8,7 @@ class ListUnpack(torch.nn.Module): erased after tracing. """ - def forward(self, args: List[torch.Tensor]): + def forward(self, args: list[torch.Tensor]): """ Lists are treated as static construct, therefore unpacking should be erased after tracing. 
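For readers unfamiliar with the pattern applied throughout this change: the diff replaces the deprecated typing.Dict/List/Set/Tuple aliases with the built-in generic types standardized by PEP 585, which are subscriptable at runtime on Python 3.9 and later. The following is a minimal standalone sketch of the before/after style, not part of the patch; the function name is hypothetical and only mirrors the kind of signatures rewritten above.

from typing import Any, Optional, Union

def compile_like(
    args: tuple[Any, ...],
    kwargs: Optional[dict[str, Any]] = None,   # previously Optional[Dict[str, Any]]
    options: Optional[dict[str, Any]] = None,  # previously Optional[Dict[str, Any]]
) -> Union[list[str], str]:                    # previously Union[List[str], str]
    """Hypothetical signature mirroring the rewritten annotations above."""
    merged: dict[str, Any] = {**(kwargs or {}), **(options or {})}
    return list(merged) if merged else ""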
diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 83f9255f9c37..91d8d083fa1b 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -3,18 +3,7 @@ import contextlib import inspect import logging from collections import defaultdict -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Set, - Tuple, - Type, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, Optional, TYPE_CHECKING, Union import torch import torch.utils._pytree as pytree @@ -97,8 +86,8 @@ def fakify( mode: FakeTensorMode, kp: KeyPath, t: Any, - t_constraints: Dict[int, Dict[int, Constraint]], - sources: Dict[tuple[int, int], List[Source]], + t_constraints: dict[int, dict[int, Constraint]], + sources: dict[tuple[int, int], list[Source]], ): source = key_path_to_source(kp) if _is_constant_argument(t) or isinstance(t, (torch.ScriptObject, torch.nn.Module)): @@ -165,7 +154,7 @@ def make_fake_inputs( combined_args = _combine_args(nn_module, args, kwargs) _check_dynamic_shapes(combined_args, dynamic_shapes) constraints = _process_dynamic_shapes(combined_args, dynamic_shapes) - t_constraints: Dict[int, Dict[int, Constraint]] = defaultdict(dict) + t_constraints: dict[int, dict[int, Constraint]] = defaultdict(dict) for constraint in constraints: t_constraints[constraint.t_id][constraint.dim] = constraint @@ -214,17 +203,17 @@ def make_fake_inputs( original_signature = inspect.signature(nn_module.forward) else: original_signature = None - sources: Dict[tuple[int, int], List[Source]] = defaultdict(list) + sources: dict[tuple[int, int], list[Source]] = defaultdict(list) fake_args, fake_kwargs = tree_map_with_path( lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources), (args, kwargs), ) - names: Dict[str, tuple[int, int]] = {} - source_pairs: List[tuple[Source, Source]] = [] - derived_equalities: List[tuple[Source, Union[Source, Symbol], Callable]] = [] - phantom_symbols: Dict[str, Symbol] = {} - relaxed_sources: Set[Source] = set() + names: dict[str, tuple[int, int]] = {} + source_pairs: list[tuple[Source, Source]] = [] + derived_equalities: list[tuple[Source, Union[Source, Symbol], Callable]] = [] + phantom_symbols: dict[str, Symbol] = {} + relaxed_sources: set[Source] = set() for constraint in constraints: torch.export.dynamic_shapes._process_equalities( constraint, @@ -255,9 +244,9 @@ def make_fake_inputs( def _flatten_dynamic_shapes( - combined_args: Dict[str, Any], - dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any]], -) -> List[Any]: + combined_args: dict[str, Any], + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any]], +) -> list[Any]: flat_shapes = [] def _tree_map_helper(path, t, shape): @@ -283,7 +272,7 @@ def _clean_dynamic_markers(tensor: torch.Tensor) -> None: def produce_guards_and_solve_constraints( fake_mode: FakeTensorMode, gm: torch.fx.GraphModule, - dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any], None], + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None], equalities_inputs: EqualityConstraint, original_signature: inspect.Signature, _is_torch_jit_trace=False, @@ -348,8 +337,8 @@ def produce_guards_and_solve_constraints( def make_constraints( fake_mode: FakeTensorMode, gm: torch.fx.GraphModule, - combined_args: Dict[str, Any], - dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any], None], + combined_args: dict[str, Any], + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None], num_lifted_inputs: int, ): """ @@ -435,7 +424,7 @@ def 
_gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap: buffers_parameters = set(m.buffers()) buffers_parameters.update(m.parameters()) - def inner(m: torch.nn.Module, prefix_atoms: List[str], constants): + def inner(m: torch.nn.Module, prefix_atoms: list[str], constants): for k, v in m.__dict__.items(): if isinstance( v, @@ -459,8 +448,8 @@ def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap: def _get_graph_inputs_of_type_nn_module( - args: Optional[Tuple[Tuple[Any], Dict[Any, Any]]], -) -> Set[Type[torch.nn.Module]]: + args: Optional[tuple[tuple[Any], dict[Any, Any]]], +) -> set[type[torch.nn.Module]]: if args is None: return set() module_types = set() @@ -471,14 +460,14 @@ def _get_graph_inputs_of_type_nn_module( def _enter_enable_graph_inputs_of_type_nn_module( - module_types: Set[Type[torch.nn.Module]], + module_types: set[type[torch.nn.Module]], ) -> None: for t in module_types: torch._export.utils.register_module_as_pytree_input_node(t) def _exit_enable_graph_inputs_of_type_nn_module( - module_types: Set[Type[torch.nn.Module]], + module_types: set[type[torch.nn.Module]], ) -> None: for t in module_types: torch._export.utils.deregister_module_as_pytree_input_node(t) @@ -486,7 +475,7 @@ def _exit_enable_graph_inputs_of_type_nn_module( @contextlib.contextmanager def _enable_graph_inputs_of_type_nn_module( - args: Optional[Tuple[Tuple[Any], Dict[Any, Any]]], + args: Optional[tuple[tuple[Any], dict[Any, Any]]], ): if args is None: yield @@ -502,8 +491,8 @@ def _enable_graph_inputs_of_type_nn_module( @contextlib.contextmanager def _fakify_module_inputs( - args: Tuple[Any], - kwargs: Dict[Any, Any], + args: tuple[Any], + kwargs: dict[Any, Any], fake_mode: torch._subclasses.fake_tensor.FakeTensorMode, ): # This context manager is used to fakify module inputs. @@ -534,7 +523,7 @@ def _fakify_module_inputs( def _fakify_script_objects( mod: torch.nn.Module, args: tuple[Any], - kwargs: Dict[Any, Any], + kwargs: dict[Any, Any], fake_mode: torch._subclasses.fake_tensor.FakeTensorMode, ): # This context manager is used to fakify script objects into FakeScriptObject. 
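The same modernization applies to the abstract container types: since Python 3.9 they are importable (and subscriptable) directly from collections.abc, which is why converter.py, proxy_value.py, tools.py, and utils.py now import Sequence, Iterator, Iterable, and Hashable from there instead of typing. A small illustrative sketch of the combined style, not taken from the patch, with hypothetical helper names:

from collections.abc import Iterable, Iterator, Sequence
from typing import Any, Optional

def first_names(records: Sequence[dict[str, Any]]) -> list[str]:
    # Sequence describes the accepted interface; dict/list are the concrete containers.
    return [r["name"] for r in records if "name" in r]

def consecutive_pairs(values: Iterable[int]) -> Iterator[tuple[int, int]]:
    # Yield (previous, current) pairs from any iterable.
    prev: Optional[int] = None
    for v in values:
        if prev is not None:
            yield (prev, v)
        prev = v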
diff --git a/torch/_export/pass_base.py b/torch/_export/pass_base.py index 4e8cf901f2e2..9d63811f09ed 100644 --- a/torch/_export/pass_base.py +++ b/torch/_export/pass_base.py @@ -3,7 +3,7 @@ import operator import traceback import typing from contextlib import nullcontext -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, Callable, Optional, Union import torch from functorch.experimental.control_flow import _unstack_pytree @@ -31,7 +31,7 @@ Fn = Callable[..., Any] PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]] -_TORCH_SYM_OPS: Set[Callable] = { +_TORCH_SYM_OPS: set[Callable] = { torch.sym_int, torch.sym_float, torch.sym_ite, @@ -64,9 +64,9 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase): self.root = torch.nn.Module() self.graph = torch.fx.Graph() self.graph.set_codegen(codegen) - self.tensor_attrs: Dict[str, torch.Tensor] = {} # type: ignore[assignment] + self.tensor_attrs: dict[str, torch.Tensor] = {} # type: ignore[assignment] self.fake_tensor_mode: Optional[FakeTensorMode] = None - self.submodules: Dict[torch.nn.Module, str] = {} + self.submodules: dict[torch.nn.Module, str] = {} def trace(self) -> None: # type: ignore[override] raise ExportPassBaseError("ExportTracer doesn't support trace().") @@ -162,7 +162,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase): self, target: str, # type: ignore[override] args: tuple[Argument, ...], - kwargs: Dict[str, Argument], + kwargs: dict[str, Argument], ) -> ProxyValue: arg = super().placeholder(target, args, kwargs) return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta)) @@ -171,7 +171,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase): self, target: torch.fx.node.Target, args: tuple[Argument, ...], - kwargs: Dict[str, Argument], + kwargs: dict[str, Argument], ) -> ProxyValue: return self.callback.output(args[0], NodeMetadata(self.node.meta)).data # type: ignore[return-value] @@ -179,7 +179,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase): self, target: torch.fx.node.Target, args: tuple[Argument, ...], - kwargs: Dict[str, Argument], + kwargs: dict[str, Argument], ) -> ProxyValue: meta = NodeMetadata(self.node.meta) @@ -218,7 +218,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase): raise ExportPassBaseError(f"Unsupported target type: {target}") def get_attr( - self, target: str, args: tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override] + self, target: str, args: tuple[Argument, ...], kwargs: dict[str, Argument] # type: ignore[override] ) -> Argument: return super().get_attr(target, args, kwargs) @@ -226,12 +226,12 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase): self, target: torch.fx.node.Target, args: tuple[Argument, ...], - kwargs: Dict[str, Argument], + kwargs: dict[str, Argument], ) -> None: raise ExportPassBaseError("call_module is not supported.") def call_method( - self, target: str, args: tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override] + self, target: str, args: tuple[Argument, ...], kwargs: dict[str, Argument] # type: ignore[override] ) -> None: raise ExportPassBaseError("call_method is not supported.") @@ -254,7 +254,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase): kind: str, target: torch.fx.node.Target, args: tuple[Argument, ...], - kwargs: Dict[str, Argument], + kwargs: dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: args_data, kwargs_data = pytree.tree_map_only( @@ -277,7 +277,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase): 
self.tracer.set_metadata(res_proxy.node, res_data) return ProxyValue(res_data, res_proxy) - def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]: + def inputs(self, graph_module: torch.fx.GraphModule) -> list[Argument]: # TODO(angelayi): Update this with what we decide to do for metadata in # the exported graph module if (args := graph_module.meta.get("args", None)) is not None: @@ -327,7 +327,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase): self, op, args: tuple[Argument, ...], - kwargs: Dict[str, Argument], + kwargs: dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: return self._fx("call_function", op, args, kwargs, meta) @@ -345,7 +345,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase): pred: ProxyValue, true_fn: torch.fx.GraphModule, false_fn: torch.fx.GraphModule, - inputs: List[Argument], + inputs: list[Argument], meta: NodeMetadata, ) -> ProxyValue: true_branch = self.call_submodule(true_fn, tuple(inputs)) @@ -363,8 +363,8 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase): def call_map( self, f: torch.fx.GraphModule, - mapped_args: List[ProxyValue], - operands: List[ProxyValue], + mapped_args: list[ProxyValue], + operands: list[ProxyValue], meta: NodeMetadata, ) -> ProxyValue: xs = _unstack_pytree([arg.data for arg in mapped_args])[0] @@ -383,7 +383,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase): ) -> ProxyValue: return self._fx("call_function", operator.getitem, (value, key), {}, meta) - def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue: + def output(self, results: list[Argument], meta: NodeMetadata) -> ProxyValue: return self._fx("output", "output", (results,), {}, meta) def call_submodule( diff --git a/torch/_export/pass_infra/node_metadata.py b/torch/_export/pass_infra/node_metadata.py index 4aa9b8093c37..9874dc1520fd 100644 --- a/torch/_export/pass_infra/node_metadata.py +++ b/torch/_export/pass_infra/node_metadata.py @@ -1,10 +1,10 @@ -from typing import Any, Dict, Set +from typing import Any NodeMetadataValue = Any -PROTECTED_KEYS: Set[str] = { +PROTECTED_KEYS: set[str] = { "val", "stack_trace", "nn_module_stack", @@ -14,8 +14,8 @@ PROTECTED_KEYS: Set[str] = { class NodeMetadata: - def __init__(self, data: Dict[str, Any]) -> None: - self.data: Dict[str, Any] = data.copy() + def __init__(self, data: dict[str, Any]) -> None: + self.data: dict[str, Any] = data.copy() def __getitem__(self, key: str) -> NodeMetadataValue: return self.data[key] diff --git a/torch/_export/pass_infra/proxy_value.py b/torch/_export/pass_infra/proxy_value.py index 01c3b6612ca8..df62c9d0ffe5 100644 --- a/torch/_export/pass_infra/proxy_value.py +++ b/torch/_export/pass_infra/proxy_value.py @@ -1,5 +1,6 @@ # pyre-strict -from typing import Union, Iterator, Iterable, Generic +from typing import Union, Generic +from collections.abc import Iterator, Iterable import torch from typing import TypeVar diff --git a/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py index e8ed5931a74f..99df6c7fb635 100644 --- a/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py +++ b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py @@ -3,7 +3,7 @@ import math import operator import traceback from functools import partial -from typing import Callable, Dict, List, NamedTuple, Set +from typing import Callable, NamedTuple import sympy @@ -45,11 +45,11 @@ def _convert_range_to_int(range: ValueRanges): class 
_AddRuntimeAssertionsForInlineConstraintsPass(PassBase): def __init__( self, - range_constraints: Dict[sympy.Symbol, ValueRanges], + range_constraints: dict[sympy.Symbol, ValueRanges], ): super().__init__() - self.range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints - self._asserts_generated_unbacked_symbols: Set[sympy.Symbol] = set() + self.range_constraints: dict[sympy.Symbol, ValueRanges] = range_constraints + self._asserts_generated_unbacked_symbols: set[sympy.Symbol] = set() self.counter = 0 def _assert_range_constraint(self, node, lower, upper, assert_msg): @@ -105,8 +105,8 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase): # need the proxy for shape, which further requires the proxy for ret[1], etc. def add_assertions(val): - call_backs: List[Callable] = [] - messages: List[str] = [] + call_backs: list[Callable] = [] + messages: list[str] = [] if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)): symbol = val.node.expr if symbol in self.existing_inline_assertions: @@ -161,9 +161,9 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase): def _get_existing_inline_assertions( graph_module: torch.fx.GraphModule, - range_constraints: Dict[sympy.Symbol, ValueRanges], -) -> Dict[sympy.Symbol, ValueRanges]: - existing_inline_assertions: Dict[sympy.Symbol, ValueRanges] = {} + range_constraints: dict[sympy.Symbol, ValueRanges], +) -> dict[sympy.Symbol, ValueRanges]: + existing_inline_assertions: dict[sympy.Symbol, ValueRanges] = {} for module in graph_module.modules(): if not isinstance(module, torch.fx.GraphModule): diff --git a/torch/_export/passes/collect_tracepoints_pass.py b/torch/_export/passes/collect_tracepoints_pass.py index 8a7da09e35c4..8162342e50c8 100644 --- a/torch/_export/passes/collect_tracepoints_pass.py +++ b/torch/_export/passes/collect_tracepoints_pass.py @@ -2,7 +2,7 @@ from __future__ import annotations import operator -from typing import Dict, Optional, TYPE_CHECKING, Union +from typing import Optional, TYPE_CHECKING, Union import torch from torch.export.exported_program import ConstantArgument, TensorArgument @@ -23,7 +23,7 @@ class CollectTracepointsPass(PassBase): """ def __init__( - self, specs: Dict[str, ModuleCallSignature], sig: ExportGraphSignature + self, specs: dict[str, ModuleCallSignature], sig: ExportGraphSignature ) -> None: super().__init__() self.specs = specs diff --git a/torch/_export/passes/constant_folding.py b/torch/_export/passes/constant_folding.py index 43971ce49380..977180ed9791 100644 --- a/torch/_export/passes/constant_folding.py +++ b/torch/_export/passes/constant_folding.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import collections from collections import defaultdict -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional import torch import torch.utils._pytree as pytree @@ -53,8 +53,8 @@ class ConstantFolder(torch.fx.Interpreter): skip_constructors: bool = False, ): super().__init__(gm) - self.node_replacements: Dict[torch.fx.Node, Any] = {} - self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter() + self.node_replacements: dict[torch.fx.Node, Any] = {} + self.replaced_uses: dict[torch.fx.Node, int] = collections.Counter() self.unknown_value = object() self.skip_constructors: bool = skip_constructors @@ -281,7 +281,7 @@ def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule new_graph = torch.fx.Graph() - node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {} + node_remapping: dict[torch.fx.Node, 
torch.fx.Node] = {} output_nodes = [] for node in gm.graph.nodes: if node.meta[META_TAG] == MODULE_TAG: diff --git a/torch/_export/passes/functionalize_side_effectful_ops_pass.py b/torch/_export/passes/functionalize_side_effectful_ops_pass.py index 12867f121496..c14e859e4ef3 100644 --- a/torch/_export/passes/functionalize_side_effectful_ops_pass.py +++ b/torch/_export/passes/functionalize_side_effectful_ops_pass.py @@ -1,5 +1,5 @@ import copy -from typing import Dict, Optional, List +from typing import Optional import torch from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, PassResult, Argument @@ -9,7 +9,7 @@ from torch._ops import OpOverload aten = torch.ops.aten -_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: Dict[OpOverload, OpOverload] = { +_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: dict[OpOverload, OpOverload] = { aten.sym_constrain_range.default: aten._functional_sym_constrain_range, aten._assert_async.msg: aten._functional_assert_async.msg, } @@ -60,7 +60,7 @@ class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse): self, op: OpOverload, args: tuple[Argument, ...], - kwargs: Dict[str, Argument], + kwargs: dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: if op not in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: @@ -88,7 +88,7 @@ class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse): return self._dep_token - def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue: + def output(self, results: list[Argument], meta: NodeMetadata) -> ProxyValue: assert self._dep_token is not None return super().output(results=(*results, self._dep_token), meta=meta) # type: ignore[arg-type] diff --git a/torch/_export/passes/insert_custom_op_guards.py b/torch/_export/passes/insert_custom_op_guards.py index 0550e2c34e2a..bd68f2488993 100644 --- a/torch/_export/passes/insert_custom_op_guards.py +++ b/torch/_export/passes/insert_custom_op_guards.py @@ -1,5 +1,4 @@ import functools -from typing import List import torch from torch._export.passes._node_metadata_hook import ( @@ -8,7 +7,7 @@ from torch._export.passes._node_metadata_hook import ( ) -def insert_custom_op_guards(gm: torch.fx.GraphModule, ops_to_guard: List[str]) -> None: +def insert_custom_op_guards(gm: torch.fx.GraphModule, ops_to_guard: list[str]) -> None: """ This is used by draft_export to insert guards in front of calls to custom operators which have a generated fake kernel. diff --git a/torch/_export/passes/lift_constants_pass.py b/torch/_export/passes/lift_constants_pass.py index 8194e801d0bf..1fd03ce4ac04 100644 --- a/torch/_export/passes/lift_constants_pass.py +++ b/torch/_export/passes/lift_constants_pass.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import collections import warnings -from typing import Any, Dict, List, Union +from typing import Any, Union import torch from torch._export.verifier import SpecViolationError @@ -29,13 +29,13 @@ class ConstantAttrMap(collections.abc.MutableMapping): def __init__(self) -> None: # Underlying dict that we use to implement this mapping. - self._constant_attrs: Dict[ - Union[int, torch.Tensor, FakeScriptObject], List[Any] + self._constant_attrs: dict[ + Union[int, torch.Tensor, FakeScriptObject], list[Any] ] = {} # Map from the hash(ScriptObject) to the ScriptObject itself. Used for # APIs like `__iter__` that should look like they're returning the # original ScriptObjects. 
- self._script_object_map: Dict[int, torch.ScriptObject] = {} + self._script_object_map: dict[int, torch.ScriptObject] = {} def __getitem__( self, key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject] @@ -113,7 +113,7 @@ def lift_constants_pass( gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature, constant_attrs: ConstantAttrMap, -) -> Dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]]: +) -> dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]]: """ Takes a graph module, graph signature, and modifies them implace to lift any constants (tensors or custom classes) as inputs to the graph. Returns a @@ -131,7 +131,7 @@ def lift_constants_pass( Returns: A dictionary of fqn => constant value. """ - all_constants: Dict[ + all_constants: dict[ str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject] ] = {} @@ -300,13 +300,13 @@ def lift_constants_pass( def rewrite_script_object_meta( gm: torch.fx.GraphModule, -) -> Dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject],]: +) -> dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject],]: """When tracing, we produce a graph with FakeScriptObject in the meta["val"]. For now, we rewrie meta["val"] to be a placeholder CustomObjArgument """ - constants: Dict[ + constants: dict[ str, Union[ torch.Tensor, diff --git a/torch/_export/passes/replace_autocast_with_hop_pass.py b/torch/_export/passes/replace_autocast_with_hop_pass.py index 224aa9c7a883..9d415c4a0891 100644 --- a/torch/_export/passes/replace_autocast_with_hop_pass.py +++ b/torch/_export/passes/replace_autocast_with_hop_pass.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs from __future__ import annotations -from typing import List, Optional, TYPE_CHECKING, Union +from typing import Optional, TYPE_CHECKING, Union import torch from torch._higher_order_ops.wrap import wrap_with_autocast @@ -116,7 +116,7 @@ def _split_autocast(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: exit_autocast # 3 E # 4 """ - enter_autocast_node_stack: List[torch.fx.Node] = [] + enter_autocast_node_stack: list[torch.fx.Node] = [] first_node_after_outer_most_exit: bool = False def node_call_back(node: torch.fx.Node) -> bool: diff --git a/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py b/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py index 1d2b4ec2a8ec..afa40d200620 100644 --- a/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py +++ b/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import logging import operator -from typing import List, Optional, Union +from typing import Optional, Union import torch import torch.export._trace @@ -269,9 +269,9 @@ def _conv1d_op_with_squeeze( inp: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], - stride: List[int], - padding: List[int], - dilation: List[int], + stride: list[int], + padding: list[int], + dilation: list[int], groups: int, ) -> torch.Tensor: # In quantized version, conv1d is emulated using conv2d with squeeze and unsqueeze diff --git a/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py b/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py index 6723ac5f86a6..2043212d0f66 100644 --- a/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py +++ b/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Dict, Optional +from typing import Optional 
import torch from torch._ops import OpOverload, HigherOrderOperator from torch._export.error import InternalError @@ -9,7 +9,7 @@ from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse __all__ = ["ReplaceViewOpsWithViewCopyOpsPass"] -_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: Dict[OpOverload, OpOverload] = { +_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: dict[OpOverload, OpOverload] = { torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default, } diff --git a/torch/_export/serde/aoti_schema.py b/torch/_export/serde/aoti_schema.py index 17d5ceda0ef0..d19add43705c 100644 --- a/torch/_export/serde/aoti_schema.py +++ b/torch/_export/serde/aoti_schema.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import List from torch._export.serde.schema import Node @@ -12,4 +11,4 @@ class ExternKernelNode: @dataclass class ExternKernelNodes: - nodes: List[ExternKernelNode] + nodes: list[ExternKernelNode] diff --git a/torch/_export/serde/dynamic_shapes.py b/torch/_export/serde/dynamic_shapes.py index f7aaff8a3336..241199b56b86 100644 --- a/torch/_export/serde/dynamic_shapes.py +++ b/torch/_export/serde/dynamic_shapes.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union import torch from torch._dynamo.exc import UserError, UserErrorType @@ -24,7 +24,7 @@ class RootDim: min: int max: Union[int, None] - derived: List[str] + derived: list[str] @dataclasses.dataclass @@ -33,15 +33,15 @@ class DynamicShapesSpec: This stores a dynamic_shapes spec for de/serialization. """ - dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any], None] - dims: Dict[str, RootDim] + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None] + dims: dict[str, RootDim] def _postprocess_serialized_shapes( - dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any], None], - dims: Dict[str, Dict[str, Union[int, List[str], None]]], + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None], + dims: dict[str, dict[str, Union[int, list[str], None]]], to_dict: Optional[bool] = False, -) -> Union[DynamicShapesSpec, Dict[str, Any]]: +) -> Union[DynamicShapesSpec, dict[str, Any]]: """ Sorts dims and dumps to dictionary format. """ @@ -63,11 +63,11 @@ def _postprocess_serialized_shapes( def _dump_dynamic_shapes( - dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any], None], + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None], args: tuple[Any], - kwargs: Optional[Dict[str, Any]] = None, + kwargs: Optional[dict[str, Any]] = None, to_dict: Optional[bool] = False, -) -> Union[DynamicShapesSpec, Dict[str, Any]]: +) -> Union[DynamicShapesSpec, dict[str, Any]]: """ Utility function for dynamic shapes serialization, serializing a dynamic_shapes spec. Returns a DynamicShapesSpec dataclass containing 2 fields, "dynamic_shapes" and "dims". @@ -127,7 +127,7 @@ def _dump_dynamic_shapes( } ``` """ - dims: Dict[str, Dict[str, Any]] = {} + dims: dict[str, dict[str, Any]] = {} def _standardize_shapes(path, tensor, shape): # type: ignore[no-untyped-def] """ @@ -198,9 +198,9 @@ def _dump_dynamic_shapes( def _load_dynamic_shapes( - spec: Union[DynamicShapesSpec, Dict[str, Any]], + spec: Union[DynamicShapesSpec, dict[str, Any]], from_dict: Optional[bool] = False, -) -> Union[Dict[str, Any], tuple[Any], List[Any], None]: +) -> Union[dict[str, Any], tuple[Any], list[Any], None]: """ Utility function for dynamic shapes serialization. 
Deserializes a DynamicShapesSpec or corresponding dictionary into a dynamic_shapes input to export(). diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index 7802a950a07b..ff2d90adabca 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, field from enum import IntEnum -from typing import Annotated, Dict, List, Optional +from typing import Annotated, Optional from torch._export.serde.union import _Union @@ -96,10 +96,10 @@ class SymBool(_Union): @dataclass class TensorMeta: dtype: Annotated[ScalarType, 10] - sizes: Annotated[List[SymInt], 20] + sizes: Annotated[list[SymInt], 20] requires_grad: Annotated[bool, 30] device: Annotated[Device, 40] - strides: Annotated[List[SymInt], 50] + strides: Annotated[list[SymInt], 50] storage_offset: Annotated[SymInt, 60] layout: Annotated[Layout, 70] @@ -175,29 +175,29 @@ class CustomObjArgument: class Argument(_Union): as_none: Annotated[bool, 10] as_tensor: Annotated[TensorArgument, 20] - as_tensors: Annotated[List[TensorArgument], 30] + as_tensors: Annotated[list[TensorArgument], 30] as_int: Annotated[int, 50] - as_ints: Annotated[List[int], 70] + as_ints: Annotated[list[int], 70] as_float: Annotated[float, 80] - as_floats: Annotated[List[float], 90] + as_floats: Annotated[list[float], 90] as_string: Annotated[str, 100] - as_strings: Annotated[List[str], 101] + as_strings: Annotated[list[str], 101] as_sym_int: Annotated[SymIntArgument, 110] - as_sym_ints: Annotated[List[SymIntArgument], 120] + as_sym_ints: Annotated[list[SymIntArgument], 120] as_scalar_type: Annotated[ScalarType, 130] as_memory_format: Annotated[MemoryFormat, 140] as_layout: Annotated[Layout, 150] as_device: Annotated[Device, 160] as_bool: Annotated[bool, 170] - as_bools: Annotated[List[bool], 180] + as_bools: Annotated[list[bool], 180] as_sym_bool: Annotated[SymBoolArgument, 182] - as_sym_bools: Annotated[List[SymBoolArgument], 184] + as_sym_bools: Annotated[list[SymBoolArgument], 184] as_graph: Annotated[GraphArgument, 200] - as_optional_tensors: Annotated[List[OptionalTensorArgument], 190] + as_optional_tensors: Annotated[list[OptionalTensorArgument], 190] as_custom_obj: Annotated[CustomObjArgument, 210] as_operator: Annotated[str, 220] as_sym_float: Annotated[SymFloatArgument, 230] - as_sym_floats: Annotated[List[SymFloatArgument], 240] + as_sym_floats: Annotated[list[SymFloatArgument], 240] class ArgumentKind(IntEnum): @@ -217,27 +217,27 @@ class NamedArgument: @dataclass class Node: target: Annotated[str, 10] - inputs: Annotated[List[NamedArgument], 20] - outputs: Annotated[List[Argument], 30] - metadata: Annotated[Dict[str, str], 40] + inputs: Annotated[list[NamedArgument], 20] + outputs: Annotated[list[Argument], 30] + metadata: Annotated[dict[str, str], 40] is_hop_single_tensor_return: Annotated[Optional[bool], 50] = None @dataclass class Graph: - inputs: Annotated[List[Argument], 10] - outputs: Annotated[List[Argument], 20] - nodes: Annotated[List[Node], 30] - tensor_values: Annotated[Dict[str, TensorMeta], 40] - sym_int_values: Annotated[Dict[str, SymInt], 50] - sym_bool_values: Annotated[Dict[str, SymBool], 60] + inputs: Annotated[list[Argument], 10] + outputs: Annotated[list[Argument], 20] + nodes: Annotated[list[Node], 30] + tensor_values: Annotated[dict[str, TensorMeta], 40] + sym_int_values: Annotated[dict[str, SymInt], 50] + sym_bool_values: Annotated[dict[str, SymBool], 60] # This is for deserializing the submodule graphs from higher order ops # (ex. 
cond, map) where single tensor returns will just return a single # tensor, rather than following export schema and returning a singleton # list. is_single_tensor_return: Annotated[bool, 70] = False - custom_obj_values: Annotated[Dict[str, CustomObjArgument], 80] = field(default_factory=dict) - sym_float_values: Annotated[Dict[str, SymFloat], 90] = field(default_factory=dict) + custom_obj_values: Annotated[dict[str, CustomObjArgument], 80] = field(default_factory=dict) + sym_float_values: Annotated[dict[str, SymFloat], 90] = field(default_factory=dict) @dataclass class UserInputSpec: @@ -354,8 +354,8 @@ class OutputSpec(_Union): @dataclass class GraphSignature: - input_specs: Annotated[List[InputSpec], 10] - output_specs: Annotated[List[OutputSpec], 20] + input_specs: Annotated[list[InputSpec], 10] + output_specs: Annotated[list[OutputSpec], 20] @dataclass @@ -366,8 +366,8 @@ class RangeConstraint: @dataclass class ModuleCallSignature: - inputs: Annotated[List[Argument], 10] - outputs: Annotated[List[Argument], 20] + inputs: Annotated[list[Argument], 10] + outputs: Annotated[list[Argument], 20] # These are serialized by calling pytree.treespec_loads # And deserialized by calling pytree.treespec_dumps @@ -376,7 +376,7 @@ class ModuleCallSignature: # This field is used to prettify the graph placeholders # after we ser/der and retrace - forward_arg_names: Annotated[Optional[List[str]], 50] = None + forward_arg_names: Annotated[Optional[list[str]], 50] = None @dataclass @@ -392,8 +392,8 @@ class GraphModule: # This is used for unflattening, by tracking the calling structure of all of # the modules in order to unflatten the modules back to the eager calling # conventions. - module_call_graph: Annotated[List[ModuleCallEntry], 60] - metadata: Annotated[Dict[str, str], 40] = field(default_factory=dict) + module_call_graph: Annotated[list[ModuleCallEntry], 60] + metadata: Annotated[dict[str, str], 40] = field(default_factory=dict) # Invariant: Every time a change is made to the schema, one of the versions @@ -408,8 +408,8 @@ class SchemaVersion: class ExportedProgram: graph_module: Annotated[GraphModule, 10] # Key is the opset namespace (ex. 
aten), and value is the version number - opset_version: Annotated[Dict[str, int], 20] - range_constraints: Annotated[Dict[str, RangeConstraint], 30] + opset_version: Annotated[dict[str, int], 20] + range_constraints: Annotated[dict[str, RangeConstraint], 30] schema_version: Annotated[SchemaVersion, 60] - verifiers: Annotated[List[str], 70] = field(default_factory=list) + verifiers: Annotated[list[str], 70] = field(default_factory=list) torch_version: Annotated[str, 80] = "<=2.4" diff --git a/torch/_export/serde/schema_check.py b/torch/_export/serde/schema_check.py index e0a45318f7a0..5c3f7ccb9e17 100644 --- a/torch/_export/serde/schema_check.py +++ b/torch/_export/serde/schema_check.py @@ -5,7 +5,7 @@ import inspect import re import typing from enum import IntEnum -from typing import Annotated, Any, Dict, ForwardRef, List, Optional, Union +from typing import Annotated, Any, ForwardRef, Optional, Union from torch._export.serde import schema from torch._export.serde.union import _Union @@ -36,16 +36,16 @@ _THRIFT_TYPE_MAP = { def _staged_schema(): - yaml_ret: Dict[str, Any] = {} + yaml_ret: dict[str, Any] = {} defs = {} - cpp_enum_defs: Dict[str, str] = {} - cpp_class_defs: Dict[str, str] = {} - cpp_type_decls: List[str] = [] - cpp_json_defs: List[str] = [] - thrift_enum_defs: List[str] = [] - thrift_type_defs: Dict[str, str] = {} + cpp_enum_defs: dict[str, str] = {} + cpp_class_defs: dict[str, str] = {} + cpp_type_decls: list[str] = [] + cpp_json_defs: list[str] = [] + thrift_enum_defs: list[str] = [] + thrift_type_defs: dict[str, str] = {} - def _handle_aggregate(ty) -> tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + def _handle_aggregate(ty) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: def dump_type(t, level: int) -> tuple[str, str, str]: if getattr(t, "__name__", None) in cpp_enum_defs: return t.__name__, "int64_t", t.__name__ @@ -125,7 +125,7 @@ def _staged_schema(): f"Default value {v} is not supported yet in export schema." 
) - def dump_field(f) -> tuple[Dict[str, Any], str, Optional[str], str, int]: + def dump_field(f) -> tuple[dict[str, Any], str, Optional[str], str, int]: t, cpp_type, thrift_type = dump_type(f.type, 0) ret = {"type": t} cpp_default: Optional[str] = None @@ -524,12 +524,12 @@ def _hash_content(s: str): @dataclasses.dataclass class _Commit: - result: Dict[str, Any] + result: dict[str, Any] checksum_next: str yaml_path: str - additions: Dict[str, Any] - subtractions: Dict[str, Any] - base: Dict[str, Any] + additions: dict[str, Any] + subtractions: dict[str, Any] + base: dict[str, Any] checksum_head: Optional[str] cpp_header: str cpp_header_path: str diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index c23797901cb6..bfea0f4630d8 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -24,15 +24,11 @@ from typing import ( Any, Callable, cast, - Dict, final, - Iterator, - List, Optional, - Set, - Type, Union, ) +from collections.abc import Iterator import sympy @@ -118,7 +114,7 @@ class SerializeError(RuntimeError): pass -def _reverse_map(d: Dict[Any, Enum]): +def _reverse_map(d: dict[Any, Enum]): return {v.value: k for k, v in d.items()} @@ -352,7 +348,7 @@ def serialize_torch_artifact(artifact: Optional[Any], pickle_protocol: int = DEF del copyreg.dispatch_table[FakeTensor] -def deserialize_torch_artifact(serialized: Union[Dict[str, Any], tuple[Any, ...], bytes]): +def deserialize_torch_artifact(serialized: Union[dict[str, Any], tuple[Any, ...], bytes]): if isinstance(serialized, (dict, tuple)): return serialized if len(serialized) == 0: @@ -401,8 +397,8 @@ def _int_to_sympy_int(val: Optional[int], default) -> sympy.Expr: def serialize_range_constraints( - range_constraints: Dict[sympy.Symbol, ValueRanges] -) -> Dict[str, RangeConstraint]: + range_constraints: dict[sympy.Symbol, ValueRanges] +) -> dict[str, RangeConstraint]: return { str(k): RangeConstraint( _sympy_int_to_int(v.lower, "ceil"), # type: ignore[arg-type] @@ -426,15 +422,15 @@ def _get_schema_from_target(target): @dataclass class GraphState: - inputs: List[Argument] = field(default_factory=list) - outputs: List[Argument] = field(default_factory=list) - nodes: List[Node] = field(default_factory=list) - tensor_values: Dict[str, TensorMeta] = field(default_factory=dict) - sym_int_values: Dict[str, SymInt] = field(default_factory=dict) - sym_bool_values: Dict[str, SymBool] = field(default_factory=dict) - sym_float_values: Dict[str, SymFloat] = field(default_factory=dict) + inputs: list[Argument] = field(default_factory=list) + outputs: list[Argument] = field(default_factory=list) + nodes: list[Node] = field(default_factory=list) + tensor_values: dict[str, TensorMeta] = field(default_factory=dict) + sym_int_values: dict[str, SymInt] = field(default_factory=dict) + sym_bool_values: dict[str, SymBool] = field(default_factory=dict) + sym_float_values: dict[str, SymFloat] = field(default_factory=dict) is_single_tensor_return: bool = False - custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict) + custom_obj_values: dict[str, CustomObjArgument] = field(default_factory=dict) class Final(type): @@ -450,13 +446,13 @@ class GraphModuleSerializer(metaclass=Final): def __init__( self, graph_signature: ep.ExportGraphSignature, - module_call_graph: List[ep.ModuleCallEntry], + module_call_graph: list[ep.ModuleCallEntry], ): self.graph_state = GraphState() self.graph_signature = graph_signature self.module_call_graph = module_call_graph - self.custom_objs: 
Dict[str, torch._C.ScriptObject] = {} - self.duplicate_getitem_nodes: Dict[str, str] = {} + self.custom_objs: dict[str, torch._C.ScriptObject] = {} + self.duplicate_getitem_nodes: dict[str, str] = {} @contextmanager def save_graph_state(self): @@ -597,7 +593,7 @@ class GraphModuleSerializer(metaclass=Final): else: return user_node.name - def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]: + def serialize_metadata(self, node: torch.fx.Node) -> dict[str, str]: ret = {} if stack_trace := node.meta.get("stack_trace"): ret["stack_trace"] = stack_trace @@ -647,7 +643,7 @@ class GraphModuleSerializer(metaclass=Final): class_fqn=script_obj_meta.class_fqn, ) - def serialize_sym_op_inputs(self, op, args) -> List[NamedArgument]: + def serialize_sym_op_inputs(self, op, args) -> list[NamedArgument]: if isinstance(op, torch._ops.OpOverload): args_names = [arg.name for arg in op._schema.arguments] else: @@ -669,7 +665,7 @@ class GraphModuleSerializer(metaclass=Final): target: Any, # torch._ops.OpOverload and other custom operator types. args, kwargs=None - ) -> List[NamedArgument]: + ) -> list[NamedArgument]: assert isinstance(target, (torch._ops.OpOverload, *_registered_extension_types())) kwargs = kwargs or {} serialized_args = [] @@ -700,7 +696,7 @@ class GraphModuleSerializer(metaclass=Final): return serialized_args - def serialize_hoo_inputs(self, args, kwargs) -> List[NamedArgument]: + def serialize_hoo_inputs(self, args, kwargs) -> list[NamedArgument]: """ For serializing HOO inputs since HOOs do not have a schema. """ @@ -1180,8 +1176,8 @@ class GraphModuleSerializer(metaclass=Final): ) def serialize_module_call_graph( - self, module_call_graph: List[ep.ModuleCallEntry] - ) -> List[ModuleCallEntry]: + self, module_call_graph: list[ep.ModuleCallEntry] + ) -> list[ModuleCallEntry]: return [ ModuleCallEntry( fqn=entry.fqn, @@ -1194,7 +1190,7 @@ class GraphModuleSerializer(metaclass=Final): for entry in module_call_graph ] - def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]: + def serialize_outputs(self, node: torch.fx.Node) -> list[Argument]: """For a given node, return the dataclass representing its output values. [NOTE: Multiple outputs] We handle aggregates differently than FX. For @@ -1294,7 +1290,7 @@ class GraphModuleSerializer(metaclass=Final): return output_arguments - def serialize_hoo_outputs(self, node: torch.fx.Node) -> List[Argument]: + def serialize_hoo_outputs(self, node: torch.fx.Node) -> list[Argument]: """ For serializing HOO outputs since HOOs do not have a schema. 
""" @@ -1359,7 +1355,7 @@ class GraphModuleSerializer(metaclass=Final): # list outputs should've been handled earlier raise SerializeError(f"Unable to serialize output {meta_val}") - def _handle_getitem_users(self, node: torch.fx.Node) -> List[TensorArgument]: + def _handle_getitem_users(self, node: torch.fx.Node) -> list[TensorArgument]: meta_val = node.meta["val"] idx_to_name = {} @@ -1406,7 +1402,7 @@ class GraphModuleSerializer(metaclass=Final): is_single_tensor_return=self.graph_state.is_single_tensor_return, ) - def serialize_graph_module_metadata(self, meta: Dict[str, Any]): + def serialize_graph_module_metadata(self, meta: dict[str, Any]): ret = {} if custom := meta.get("custom"): try: @@ -1431,8 +1427,8 @@ class GraphModuleSerializer(metaclass=Final): @final class ExportedProgramSerializer(metaclass=Final): - def __init__(self, opset_version: Optional[Dict[str, int]] = None, pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL): - self.opset_version: Dict[str, int] = {} + def __init__(self, opset_version: Optional[dict[str, int]] = None, pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL): + self.opset_version: dict[str, int] = {} if opset_version: self.opset_version.update(opset_version) if "aten" not in self.opset_version: @@ -1498,15 +1494,15 @@ class GraphModuleDeserializer(metaclass=Final): class Result: graph_module: torch.fx.GraphModule signature: ep.ExportGraphSignature - module_call_graph: List[ep.ModuleCallEntry] - names_to_symbols: Dict[str, sympy.Symbol] - state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]] - constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]] - example_inputs: Optional[tuple[tuple[torch.Tensor, ...], Dict[str, Any]]] + module_call_graph: list[ep.ModuleCallEntry] + names_to_symbols: dict[str, sympy.Symbol] + state_dict: dict[str, Union[torch.Tensor, torch.nn.Parameter]] + constants: dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]] + example_inputs: Optional[tuple[tuple[torch.Tensor, ...], dict[str, Any]]] def __init__(self) -> None: - self.serialized_name_to_node: Dict[str, torch.fx.Node] = {} - self.serialized_name_to_meta: Dict[str, MetaType] = {} + self.serialized_name_to_node: dict[str, torch.fx.Node] = {} + self.serialized_name_to_meta: dict[str, MetaType] = {} self.graph = torch.fx.Graph() self.module = torch.nn.Module() @@ -1953,10 +1949,10 @@ class GraphModuleDeserializer(metaclass=Final): def deserialize( self, serialized_graph_module: GraphModule, - serialized_state_dict: Union[Dict[str, torch.Tensor], bytes], - constants: Union[Dict[str, Any], bytes], - example_inputs: Optional[Union[tuple[tuple[torch.Tensor, ...], Dict[str, Any]], bytes]] = None, - symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None, + serialized_state_dict: Union[dict[str, torch.Tensor], bytes], + constants: Union[dict[str, Any], bytes], + example_inputs: Optional[Union[tuple[tuple[torch.Tensor, ...], dict[str, Any]], bytes]] = None, + symbol_name_to_range: Optional[dict[str, symbolic_shapes.ValueRanges]] = None, ) -> Result: global _CURRENT_DESERIALIZER assert _CURRENT_DESERIALIZER is None @@ -1996,7 +1992,7 @@ class GraphModuleDeserializer(metaclass=Final): "ToFloat": torch.utils._sympy.functions.ToFloat, "Identity": torch.utils._sympy.functions.Identity, } - self.symbol_name_to_symbol: Dict[str, sympy.Symbol] = {} + self.symbol_name_to_symbol: dict[str, sympy.Symbol] = {} self.constants = deserialize_torch_artifact(constants) self.signature = 
self.deserialize_signature(serialized_graph_module.signature) @@ -2091,7 +2087,7 @@ class GraphModuleDeserializer(metaclass=Final): kwargs[schema_arg.name] = actual_args[schema_arg.name] return tuple(args), kwargs - def deserialize_hoo_inputs(self, inputs: List[NamedArgument]): + def deserialize_hoo_inputs(self, inputs: list[NamedArgument]): """ For deserializing HOO inputs since HOOs do not have a schema. """ @@ -2232,7 +2228,7 @@ class GraphModuleDeserializer(metaclass=Final): "torch.ops.higher_order" in serialized_node.target and not getattr(serialized_node, "is_hop_single_tensor_return", True) ): - meta_val: List[Any] = [] + meta_val: list[Any] = [] arg = serialized_node.outputs[0].as_tensor deserialized_metadata = self.deserialize_metadata(serialized_node.metadata) self.generate_getitem(meta_val, fx_node, arg, 0, deserialized_metadata) @@ -2261,7 +2257,7 @@ class GraphModuleDeserializer(metaclass=Final): fx_node: torch.fx.Node, arg: Union[TensorArgument, SymIntArgument, SymFloatArgument], idx: int, - deserialized_metadata: Dict[str, Any], + deserialized_metadata: dict[str, Any], ): if isinstance(arg, TensorArgument): name = arg.name @@ -2290,7 +2286,7 @@ class GraphModuleDeserializer(metaclass=Final): meta_val, fx_node: torch.fx.Node, args, - deserialized_metadata: Dict[str, Any], + deserialized_metadata: dict[str, Any], ): for idx, arg in enumerate(args): if isinstance(arg, (TensorArgument, SymIntArgument, SymFloatArgument)): @@ -2343,7 +2339,7 @@ class GraphModuleDeserializer(metaclass=Final): # return value. # This performs the inverse mapping of the `serialize_outputs` call in # serialization, see [NOTE: Multiple outputs] - meta_val: List[Any] = [] + meta_val: list[Any] = [] if len(serialized_node.outputs) == 1: assert isinstance(serialized_node.outputs[0].value, list) assert isinstance(serialized_node.outputs[0].value[0], TensorArgument) @@ -2355,8 +2351,8 @@ class GraphModuleDeserializer(metaclass=Final): fx_node.meta["val"] = tuple(meta_val) self.serialized_name_to_node[fx_node.name] = fx_node - def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]: - ret: Dict[str, Any] = {} + def deserialize_metadata(self, metadata: dict[str, str]) -> dict[str, Any]: + ret: dict[str, Any] = {} if stack_trace := metadata.get("stack_trace"): ret["stack_trace"] = stack_trace @@ -2446,8 +2442,8 @@ class GraphModuleDeserializer(metaclass=Final): ) def deserialize_module_call_graph( - self, module_call_graph: List[ModuleCallEntry] - ) -> List[ep.ModuleCallEntry]: + self, module_call_graph: list[ModuleCallEntry] + ) -> list[ep.ModuleCallEntry]: return [ ep.ModuleCallEntry( fqn=entry.fqn, @@ -2463,8 +2459,8 @@ class GraphModuleDeserializer(metaclass=Final): @final class ExportedProgramDeserializer(metaclass=Final): - def __init__(self, expected_opset_version: Optional[Dict[str, int]] = None): - self.expected_opset_version: Dict[str, int] = {} + def __init__(self, expected_opset_version: Optional[dict[str, int]] = None): + self.expected_opset_version: dict[str, int] = {} if expected_opset_version: self.expected_opset_version.update(expected_opset_version) if "aten" not in self.expected_opset_version: @@ -2472,9 +2468,9 @@ class ExportedProgramDeserializer(metaclass=Final): def deserialize_range_constraints( self, - symbol_name_to_range: Dict[str, symbolic_shapes.ValueRanges], - symbol_name_to_symbol: Dict[str, sympy.Symbol], - ) -> Dict[sympy.Symbol, ValueRanges]: + symbol_name_to_range: dict[str, symbolic_shapes.ValueRanges], + symbol_name_to_symbol: dict[str, sympy.Symbol], + ) 
         range_constraints = {}
         for k, v in symbol_name_to_range.items():
             if symbol := symbol_name_to_symbol.get(k):
@@ -2486,9 +2482,9 @@ class ExportedProgramDeserializer(metaclass=Final):
     def deserialize(
         self,
         exported_program: ExportedProgram,
-        state_dict: Union[Dict[str, torch.Tensor], bytes],
-        constants: Union[Dict[str, torch.Tensor], bytes],
-        example_inputs: Optional[Union[tuple[tuple[torch.Tensor, ...], Dict[str, Any]], bytes]] = None,
+        state_dict: Union[dict[str, torch.Tensor], bytes],
+        constants: Union[dict[str, torch.Tensor], bytes],
+        example_inputs: Optional[Union[tuple[tuple[torch.Tensor, ...], dict[str, Any]], bytes]] = None,
         *,
         _unsafe_skip_version_check=False,
     ) -> ep.ExportedProgram:
@@ -2566,7 +2562,7 @@ def _dataclass_to_dict(obj):

 def serialize(
     exported_program: ep.ExportedProgram,
-    opset_version: Optional[Dict[str, int]] = None,
+    opset_version: Optional[dict[str, int]] = None,
     pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
 ) -> SerializedArtifact:
     with _enable_graph_inputs_of_type_nn_module(exported_program.example_inputs):
@@ -2627,7 +2623,7 @@ def _dict_to_dataclass(cls, data):

 def deserialize(
     artifact: SerializedArtifact,
-    expected_opset_version: Optional[Dict[str, int]] = None,
+    expected_opset_version: Optional[dict[str, int]] = None,
     *,
     _unsafe_skip_version_check=False,
 ) -> ep.ExportedProgram:
@@ -2649,7 +2645,7 @@ def deserialize(

 def _canonicalize_graph(
     sorted_inputs, sorted_outputs, graph
-) -> tuple[Graph, Dict[str, str]]:
+) -> tuple[Graph, dict[str, str]]:
     def _get_argument(a: Argument):
         if a.type == "as_none":
             return None
@@ -2712,15 +2708,15 @@ def _canonicalize_graph(
     def sort_nodes(nodes):
         @dataclass
         class Edges:
-            outs: List[int]
+            outs: list[int]
             ins: int

-        graph_inputs: Set[str] = set()
-        def_table: Dict[str, int] = {}
-        edges: Dict[int, Edges] = {}
-        candidates: List[tuple[str, List[tuple[str, List[int]]], int]] = []
-        rank: Dict[str, int] = {}
-        ret: List[Node] = []
+        graph_inputs: set[str] = set()
+        def_table: dict[str, int] = {}
+        edges: dict[int, Edges] = {}
+        candidates: list[tuple[str, list[tuple[str, list[int]]], int]] = []
+        rank: dict[str, int] = {}
+        ret: list[Node] = []

         def get_name(a) -> Optional[str]:
             if a is None:
@@ -2827,7 +2823,7 @@ def _canonicalize_graph(
     assert len(sorted_nodes) == len(graph.nodes)

     # Stage 2: Rename nodes.
-    name_table: Dict[str, str] = {}
+    name_table: dict[str, str] = {}

     def rename_def(a):
         def _rename(arg_name, values):
@@ -3163,8 +3159,8 @@ class ExtensionHandler:


 def register_extension(
-    op_type: Type[Any],
-    extension_handler: Type[ExtensionHandler],
+    op_type: type[Any],
+    extension_handler: type[ExtensionHandler],
 ):
     """Register custom de/serialization method for a node with non-standard type."""
     assert issubclass(extension_handler, ExtensionHandler), f"Expected ExtensionHandler, got {extension_handler}."
@@ -3187,5 +3183,5 @@ def _registered_extension_types():
 # namespace to avoid conflicts.
 # Serialization: Op type --> custom handler.
 # De-serialization: Namespace --> custom handler.
-_serialization_registry: Dict[Type[Any], Type[ExtensionHandler]] = {}
-_deserialization_registry: Dict[str, Type[ExtensionHandler]] = {}
+_serialization_registry: dict[type[Any], type[ExtensionHandler]] = {}
+_deserialization_registry: dict[str, type[ExtensionHandler]] = {}
diff --git a/torch/_export/serde/union.py b/torch/_export/serde/union.py
index b129e8dd9a89..006b809e1e56 100644
--- a/torch/_export/serde/union.py
+++ b/torch/_export/serde/union.py
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import functools
+from collections.abc import Hashable
 from dataclasses import fields
-from typing import Hashable, Set


 class _UnionTag(str):
@@ -26,8 +26,8 @@ class _UnionTag(str):
         return hash(str(self))


-@functools.lru_cache(maxsize=None)
-def _get_field_names(cls) -> Set[str]:
+@functools.cache
+def _get_field_names(cls) -> set[str]:
     return {f.name for f in fields(cls)}


diff --git a/torch/_export/tools.py b/torch/_export/tools.py
index f054532e7779..0007de25d3e9 100644
--- a/torch/_export/tools.py
+++ b/torch/_export/tools.py
@@ -1,7 +1,8 @@
 # mypy: allow-untyped-defs
 import logging
 import warnings
-from typing import Any, Dict, Iterable, Optional
+from collections.abc import Iterable
+from typing import Any, Optional

 import torch
 import torch.export
@@ -18,8 +19,8 @@ def _generate_inputs_for_submodules(
     model: torch.nn.Module,
     target_submodules: Iterable[str],
     args: tuple[Any, ...],
-    kwargs: Optional[Dict[str, Any]] = None,
-) -> Dict[str, tuple[Any, Any]]:
+    kwargs: Optional[dict[str, Any]] = None,
+) -> dict[str, tuple[Any, Any]]:
     """
     Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this function
     doesn't work.
@@ -61,11 +62,11 @@ def _generate_inputs_for_submodules(
 def report_exportability(
     mod: torch.nn.Module,
     args: tuple[Any, ...],
-    kwargs: Optional[Dict[str, Any]] = None,
+    kwargs: Optional[dict[str, Any]] = None,
     *,
     strict: bool = True,
     pre_dispatch: bool = False,
-) -> Dict[str, Optional[Exception]]:
+) -> dict[str, Optional[Exception]]:
     """
     Report exportability issues for a module in one-shot.

@@ -92,7 +93,7 @@ def report_exportability(
     submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs)

     tried_module_types = set()
-    report: Dict[str, Optional[Exception]] = {}
+    report: dict[str, Optional[Exception]] = {}

     def try_export(module, module_name, args, kwargs):
         nonlocal submod_inputs, report, strict, pre_dispatch, tried_module_types
diff --git a/torch/_export/utils.py b/torch/_export/utils.py
index 004ffb114b2f..e3031ab04ef1 100644
--- a/torch/_export/utils.py
+++ b/torch/_export/utils.py
@@ -8,21 +8,10 @@ import json
 import math
 import operator
 import re
+from collections.abc import Iterable
 from contextlib import contextmanager
 from inspect import Parameter
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    Type,
-    TYPE_CHECKING,
-    Union,
-)
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union

 import torch
 from torch._guards import detect_fake_mode
@@ -116,7 +105,7 @@ def _overwrite_signature_for_non_persistent_buffers(
     return new_sig


-def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> Dict[str, Any]:
+def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> dict[str, Any]:
     """
     Param/buffer metadata needs to be saved before lowering to aten IR because aten IR lifts them,
     as a result, automatic preservation doesn't work.
@@ -174,7 +163,7 @@ def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> Dict[str, Any]:


 def _populate_param_buffer_metadata_to_new_gm(
-    params_buffers_to_node_meta: Dict[str, Any],
+    params_buffers_to_node_meta: dict[str, Any],
     gm: torch.fx.GraphModule,
     new_sig: "ExportGraphSignature",
 ) -> None:
@@ -217,7 +206,7 @@ def _get_shape_env_from_gm(gm: torch.fx.GraphModule):


 def _rename_without_collisions(
-    name_map: Dict[str, str],
+    name_map: dict[str, str],
     orig_name: str,
     name: str,
     is_placeholder: bool = False,
@@ -246,7 +235,7 @@ def _rename_without_collisions(


 def _check_input_constraints_for_graph(
-    input_placeholders: List[torch.fx.Node], flat_args_with_path, range_constraints
+    input_placeholders: list[torch.fx.Node], flat_args_with_path, range_constraints
 ) -> None:
     def get_keystr(key_path: KeyPath) -> str:
         """For a given index into the flat_args, return a human readable string
@@ -280,7 +269,7 @@ def _check_input_constraints_for_graph(
     # NOTE: export already guarantees that the same symbol is used in metadata
     # for all InputDims related by equality constraints, so we can just unify
     # symbols with given input dimension values to check equality constraints.
-    unification_map: Dict[sympy.Symbol, Any] = {}
+    unification_map: dict[sympy.Symbol, Any] = {}
     for (key_path, arg), node in zip(flat_args_with_path, input_placeholders):
         node_val = node.meta.get("val")
         if isinstance(node_val, FakeTensor):
@@ -372,7 +361,7 @@ def _check_input_constraints_for_graph(


 def register_dataclass_as_pytree_node(
-    cls: Type[Any],
+    cls: type[Any],
     flatten_fn: Optional[FlattenFunc] = None,
     unflatten_fn: Optional[UnflattenFunc] = None,
     *,
@@ -385,7 +374,7 @@ def register_dataclass_as_pytree_node(
         cls
     ), f"Only dataclasses can be registered with this function: {cls}"

-    def default_flatten_fn(obj: Any) -> tuple[List[Any], Context]:
+    def default_flatten_fn(obj: Any) -> tuple[list[Any], Context]:
         flattened = []
         flat_names = []
         none_names = []
@@ -402,7 +391,7 @@ def register_dataclass_as_pytree_node(
         flat_names, none_names = context
         return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))

-    def default_flatten_fn_with_keys(obj: Any) -> tuple[List[Any], Context]:
+    def default_flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]:
         flattened, (flat_names, _none_names) = flatten_fn(obj)  # type: ignore[misc]
         return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names

@@ -537,7 +526,7 @@ def sequential_split(
     return new_gm


-def nodes_filter(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]:
+def nodes_filter(nodes: list[torch.fx.Node], node_call_back) -> list[torch.fx.Node]:
     """Returns the nodes that match the node_call_back as a list."""
     return [node for node in nodes if node_call_back(node)]

@@ -572,7 +561,7 @@ def apply_runtime_assertion_pass(gm: torch.fx.GraphModule, graph_signature):


 def nodes_first(
-    nodes: List[torch.fx.Node], node_call_back=None
+    nodes: list[torch.fx.Node], node_call_back=None
 ) -> Optional[torch.fx.Node]:
     """
     Returns the first node that matches the node_call_back. If no node matches, returns None.
@@ -584,12 +573,12 @@ def nodes_first(
         return None


-def nodes_count(nodes: List[torch.fx.Node], node_call_back) -> int:
+def nodes_count(nodes: list[torch.fx.Node], node_call_back) -> int:
     """Returns the number of nodes that match the node_call_back."""
     return len(nodes_filter(nodes, node_call_back))


-def nodes_map(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]:
+def nodes_map(nodes: list[torch.fx.Node], node_call_back) -> list[torch.fx.Node]:
     """
     Sequentially visit the nodes list and invoke node_call_back on each element.
     Returns the nodes list after the node_call_back is invoked on each element.
@@ -748,7 +737,7 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None:
     and gather the top-level named placeholder nodes.
     """
     # gather all HOO subgraphs and their top-level named placeholder nodes
-    subgraph_ph_tuples: List[tuple[torch.fx.GraphModule, List[torch.fx.Node]]] = []
+    subgraph_ph_tuples: list[tuple[torch.fx.GraphModule, list[torch.fx.Node]]] = []
     for node in gm.graph.nodes:
         if node.op == "call_function" and isinstance(
             node.target, torch._ops.HigherOrderOperator
@@ -769,7 +758,7 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None:

     # propagate names
     for subgraph, hoo_phs in subgraph_ph_tuples:
-        name_map: Dict[str, str] = {}
+        name_map: dict[str, str] = {}
         for i, node in enumerate(subgraph.graph.nodes):
             if i < len(hoo_phs):  # placeholder, retain name
                 name_map[node.name] = hoo_phs[i].name
@@ -789,7 +778,7 @@ def placeholder_naming_pass(
     fake_args,
     fake_kwargs,
     fake_params_buffers,
-    constants: Dict[str, Any],
+    constants: dict[str, Any],
 ) -> None:
     """
     This pass is run at the end of _export_non_strict() to assign better placeholder node names:
@@ -828,7 +817,7 @@ def placeholder_naming_pass(
         else:
             raise RuntimeError(f"Pytree key of type {type(x)} not handled for {x}")

-    name_map: Dict[str, str] = {}
+    name_map: dict[str, str] = {}

     # map user input names with mod.forward() signature
     combined_args = _bind_signature_to_inputs(mod, fake_args, fake_kwargs)
@@ -927,7 +916,7 @@ def placeholder_naming_pass(
             del constants[name]


-def remove_proxy_from_state_dict(state_dict: Dict, in_place: bool) -> Dict:
+def remove_proxy_from_state_dict(state_dict: dict, in_place: bool) -> dict:
     """
     If `in_place` is false, return a new copy of `state_dict` with "proxy" removed from `v.__dict__`.
     `v` is the values in the dictionary.
@@ -957,8 +946,8 @@ def _detect_fake_mode_from_gm(
     If no fake mode is found, we return None for fake_mode.
""" - fake_inps: List[torch.Tensor] = [] - fake_vals: List[torch.Tensor] = [] + fake_inps: list[torch.Tensor] = [] + fake_vals: list[torch.Tensor] = [] for node in gm.graph.nodes: if node.op == "placeholder" and "val" in node.meta: fake_val = node.meta["val"] @@ -980,8 +969,8 @@ def _detect_fake_mode_from_gm( @contextmanager def _disable_load_state_dict_hooks(mod: torch.nn.Module): - state_dict_hooks: Dict[int, Callable] = dict(mod._state_dict_hooks) - state_dict_pre_hooks: Dict[int, Callable] = dict(mod._state_dict_pre_hooks) + state_dict_hooks: dict[int, Callable] = dict(mod._state_dict_hooks) + state_dict_pre_hooks: dict[int, Callable] = dict(mod._state_dict_pre_hooks) mod._state_dict_hooks.clear() mod._state_dict_pre_hooks.clear() try: @@ -1075,11 +1064,11 @@ def _check_valid_to_preserve(op_overload: "OperatorBase"): @functools.lru_cache(maxsize=1) -def _collect_all_valid_cia_ops_for_aten_namespace() -> Set["OperatorBase"]: +def _collect_all_valid_cia_ops_for_aten_namespace() -> set["OperatorBase"]: return _collect_all_valid_cia_ops_for_namespace("aten") -def _collect_all_valid_cia_ops_for_namespace(namespace: str) -> Set["OperatorBase"]: +def _collect_all_valid_cia_ops_for_namespace(namespace: str) -> set["OperatorBase"]: # Step 1: Materialize all ops from C++ dispatcher _materialize_cpp_cia_ops() @@ -1096,7 +1085,7 @@ def _collect_all_valid_cia_ops_for_namespace(namespace: str) -> Set["OperatorBas return cia_ops -def _collect_all_valid_cia_ops() -> Set["OperatorBase"]: +def _collect_all_valid_cia_ops() -> set["OperatorBase"]: """ This is an util function that gets the all CIA functional ops. @@ -1166,14 +1155,14 @@ def _compiling_state_context(): def _fakify_params_buffers( fake_mode: FakeTensorMode, mod: torch.nn.Module, -) -> Dict[str, Union[torch.Tensor, torch.nn.Parameter]]: +) -> dict[str, Union[torch.Tensor, torch.nn.Parameter]]: params_buffers = { **dict(mod.named_parameters(remove_duplicate=False)), **dict(mod.named_buffers(remove_duplicate=False)), } faked_params_buffers = {} - memo: Dict[int, FakeTensor] = {} + memo: dict[int, FakeTensor] = {} for key, value in params_buffers.items(): if id(value) in memo: fake_tensor = memo[id(value)] @@ -1184,7 +1173,7 @@ def _fakify_params_buffers( return faked_params_buffers # type: ignore[return-value] -def register_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None: +def register_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None: """ Registers a module as a valid input type for :func:`torch.export.export`. 
@@ -1233,7 +1222,7 @@ def register_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
         def __deepcopy__(self, memo):
             return PrototypeModule(self())

-    def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
+    def default_flatten_fn(obj: Any) -> tuple[list[Any], Context]:
         named_parameters = dict(obj.named_parameters())
         named_buffers = dict(obj.named_buffers())
         params_buffers = {**named_parameters, **named_buffers}
@@ -1270,7 +1259,7 @@ def register_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
             ret = obj
         return ret

-    def default_flatten_fn_with_keys(obj: Any) -> Tuple[List[Any], Context]:
+    def default_flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]:
         flattened, [flat_names, *args] = flatten_fn(obj)  # type: ignore[misc]
         return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], [
             flat_names,
@@ -1301,7 +1290,7 @@ def register_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
         from_dumpable_context=from_dumpable_context,
     )

-    def default_flatten_fn_spec(obj, spec) -> List[Any]:
+    def default_flatten_fn_spec(obj, spec) -> list[Any]:
         flats, context = flatten_fn(obj)
         assert context == spec.context
         return flats
@@ -1312,6 +1301,6 @@ def register_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
     )


-def deregister_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
+def deregister_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None:
     _deregister_pytree_node(cls)
     _deregister_pytree_flatten_spec(cls)
diff --git a/torch/_export/verifier.py b/torch/_export/verifier.py
index d04c71213403..ad5380b04c9c 100644
--- a/torch/_export/verifier.py
+++ b/torch/_export/verifier.py
@@ -4,7 +4,7 @@ import inspect
 import math
 import operator
 from collections.abc import Iterable
-from typing import Any, Dict, final, List, Type, TYPE_CHECKING
+from typing import Any, final, TYPE_CHECKING

 import torch
 from torch._ops import HigherOrderOperator, OpOverload
@@ -83,7 +83,7 @@ def _check_torch_fn(node: torch.fx.Node) -> None:
             raise SpecViolationError(f"Node.meta {node.name} has invalid torch_fn field {torch_fn}")

 class _VerifierMeta(type):
-    _registry: Dict[str, Type['Verifier']] = {}
+    _registry: dict[str, type['Verifier']] = {}

     def __new__(metacls, name, bases, attrs):
         if bases:
@@ -113,7 +113,7 @@ def getattr_recursive(obj: Any, target: str) -> Any:
 class Verifier(metaclass=_VerifierMeta):
     dialect = "ATEN"

-    def allowed_builtin_ops(self) -> List:
+    def allowed_builtin_ops(self) -> list:
         return [
             operator.getitem,
             operator.add,
@@ -141,10 +141,10 @@ class Verifier(metaclass=_VerifierMeta):
             builtins.getattr,
         ]

-    def allowed_op_types(self) -> tuple[Type[Any], ...]:
+    def allowed_op_types(self) -> tuple[type[Any], ...]:
         return (OpOverload, HigherOrderOperator)

-    def allowed_getattr_types(self) -> tuple[Type[Any], ...]:
+    def allowed_getattr_types(self) -> tuple[type[Any], ...]:
         return (torch.fx.GraphModule,)

     def check_valid_op(self, op):
@@ -163,18 +163,18 @@ class Verifier(metaclass=_VerifierMeta):

     @final
     def _check_graph_module(self, gm: torch.fx.GraphModule) -> None:
-        def _allowed_getattr_types() -> tuple[Type[Any], ...]:
+        def _allowed_getattr_types() -> tuple[type[Any], ...]:
             ret = self.allowed_getattr_types()
             assert not any(t is object for t in ret)
             return ret

         def _check_valid_op(op) -> None:
-            def _allowed_builtin_ops() -> List:
+            def _allowed_builtin_ops() -> list:
                 ret = self.allowed_builtin_ops()
                 assert all(inspect.isbuiltin(op) for op in ret)
                 return ret

-            def _allowed_op_types() -> tuple[Type[Any], ...]:
+            def _allowed_op_types() -> tuple[type[Any], ...]:
                 ret = self.allowed_op_types()
                 assert not any(t is object for t in ret)
                 return ret
@@ -426,7 +426,7 @@ def _verify_exported_program_signature(exported_program) -> None:

     num_tokens = len(gs.output_tokens)
     end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens
-    mutate_nodes: List[str] = output_nodes[num_tokens:end]
+    mutate_nodes: list[str] = output_nodes[num_tokens:end]
     user_output_nodes = output_nodes[end:end + len(gs.user_outputs)]

     for mutation_node in mutate_nodes:
@@ -458,7 +458,7 @@ def _verify_exported_program_signature(exported_program) -> None:
     )


-def load_verifier(dialect: str) -> Type[Verifier]:
+def load_verifier(dialect: str) -> type[Verifier]:
     if dialect == "ATEN" or dialect == "":
         return _VerifierMeta._registry.get(dialect, Verifier)
     return _VerifierMeta._registry[dialect]
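For reference, a minimal standalone sketch of the annotation style this patch converges on: PEP 585 builtin generics (dict/list/set/tuple) in place of typing.Dict/List/Set/Tuple/Type, abstract container types imported from collections.abc, and functools.cache in place of functools.lru_cache(maxsize=None). The function and variable names below are hypothetical illustrations, not PyTorch APIs.

# Illustrative only; requires Python 3.9+ (collections.abc generics and functools.cache).
import functools
from collections.abc import Iterable
from typing import Any, Optional, Union


@functools.cache  # equivalent to functools.lru_cache(maxsize=None)
def collect_names(prefix: str) -> set[str]:
    # Builtin generic `set[str]` replaces typing.Set[str].
    return {f"{prefix}_{i}" for i in range(3)}


def fqn_map(
    names: Iterable[str],
    extra: Optional[dict[str, Any]] = None,
) -> dict[str, Union[int, str]]:
    # `Iterable` comes from collections.abc; `dict[...]` replaces typing.Dict[...].
    out: dict[str, Union[int, str]] = {n: i for i, n in enumerate(names)}
    if extra:
        out.update({k: str(v) for k, v in extra.items()})
    return out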