PEP585 update - torch/_export (#145138)

See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145138
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #145154
Aaron Orenstein
2025-01-18 08:47:47 -08:00
committed by PyTorch MergeBot
parent cd8d0fa20c
commit 97d4d3c40a
26 changed files with 325 additions and 351 deletions
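The change is mechanical: PEP 585 (Python 3.9) made the builtin container types generic, so the deprecated typing aliases Dict, List, Set, Tuple, and Type can be replaced by dict, list, set, tuple, and type, while Optional and Union still come from typing. A minimal sketch of the pattern applied throughout this PR; the function and its names are illustrative, not taken from the diff:

    # Before: deprecated typing aliases
    from typing import Dict, List, Optional

    def token_counts(lines: List[str], seed: Optional[Dict[str, int]] = None) -> Dict[str, int]:
        counts = dict(seed or {})
        for line in lines:
            for token in line.split():
                counts[token] = counts.get(token, 0) + 1
        return counts

    # After: PEP 585 builtin generics, identical runtime behavior
    from typing import Optional

    def token_counts(lines: list[str], seed: Optional[dict[str, int]] = None) -> dict[str, int]:
        counts = dict(seed or {})
        for line in lines:
            for token in line.split():
                counts[token] = counts.get(token, 0) + 1
        return counts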

View File

@@ -76,14 +76,14 @@ def aot_compile_warning():
 def aot_compile(
     f: Callable,
     args: tuple[Any],
-    kwargs: Optional[Dict[str, Any]] = None,
+    kwargs: Optional[dict[str, Any]] = None,
     *,
-    dynamic_shapes: Optional[Dict[str, Any]] = None,
-    options: Optional[Dict[str, Any]] = None,
+    dynamic_shapes: Optional[dict[str, Any]] = None,
+    options: Optional[dict[str, Any]] = None,
     remove_runtime_assertions: bool = False,
     disable_constraint_solver: bool = False,
     same_signature: bool = True,
-) -> Union[List[str], str]:
+) -> Union[list[str], str]:
     """
     Note: this function is not stable yet

View File

@@ -4,8 +4,9 @@ import logging
 import operator
 import typing
 import warnings
+from collections.abc import Sequence
 from contextlib import contextmanager
-from typing import Any, Dict, List, Optional, Sequence, Set, Union
+from typing import Any, Optional, Union
 
 import torch
 import torch.export._trace
@@ -72,7 +73,7 @@ def _trace_and_get_graph_from_model(model, args):
 def _create_jit_graph(
     model: Union[torch.nn.Module, torch.jit.ScriptFunction], args: Sequence[Any]
-) -> tuple[torch.Graph, List["_C.IValue"], Any, Optional[torch.ScriptModule]]:
+) -> tuple[torch.Graph, list["_C.IValue"], Any, Optional[torch.ScriptModule]]:
     if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)):
         flattened_args = tuple(torch.jit._flatten(tuple(args))[0])
         torch_out = None
@@ -263,7 +264,7 @@ def construct_fqn(ir, ref_map, name_map):
 def get_block_to_lifted_attrs(
     graph: torch._C.Graph,
-) -> tuple[Dict[torch._C.Block, Set[str]], Dict[str, str]]:
+) -> tuple[dict[torch._C.Block, set[str]], dict[str, str]]:
     """
     Perform two passes to get a mapping of blocks to a set of FQNs of its lifted attributes.
     When a graph has control flow, the graph will be divided into multiple blocks. We want to convert
@@ -280,19 +281,19 @@ def get_block_to_lifted_attrs(
     """
 
     # A map from a block to its expected to be lifted arguments.
-    blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]] = {}
+    blocks_to_lifted_attrs: dict[torch._C.Block, set[str]] = {}
 
     # Reference map stores the input (i.e., src) and output (i.e., dest) IR of a
     # GetAttr node. By traversing this reference map, we can figure out the
     # full IR aliasing pass and figure out the FQN of an attribute.
     # E.g., %2 = GetAttr(linear)[%1] --> node_to_parent_map["%2"] = "%1"
-    node_to_parent_map: Dict[str, str] = {}
+    node_to_parent_map: dict[str, str] = {}
 
     # Used for reconstructing the FQN of an attribute based on the reference map.
     # In nutshell, for each GetAttr call, GetAttr(input IR, attribute name) -> output IR
     # This name map stores which attribute name is called for a src IR --> dest IR action.
     # E.g., %2 = GetAttr(linear)[%1] --> node_to_attr_name["%2"] = "linear"
-    node_to_attr_name: Dict[str, str] = {}
+    node_to_attr_name: dict[str, str] = {}
 
     def _dfs_get_attr_dependency(entry):
         """
@@ -315,7 +316,7 @@ def get_block_to_lifted_attrs(
         Walk the graph in a bottom-up fashion to build the expected to be
         lifted arguments for each block.
         """
-        arguments: Set[str] = set()
+        arguments: set[str] = set()
         for node in entry.nodes():
             for block in node.blocks():
                 # Recursively build.
@@ -342,7 +343,7 @@ def get_block_to_lifted_attrs(
 
 def get_attribute_fqn_from_ts_node(
-    name_to_attribute_fqn: Dict[str, str], node: torch._C.Node
+    name_to_attribute_fqn: dict[str, str], node: torch._C.Node
 ) -> str:
     def get_attr(name: str):
         if name in name_to_attribute_fqn:
@@ -392,12 +393,12 @@ class TS2FXGraphConverter:
     def __init__(
         self,
         ts_graph: Union[torch._C.Graph, torch._C.Block],
-        name_to_param: Dict[str, torch.Tensor],
-        name_to_buffer: Dict[str, torch.Tensor],
-        blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]],
-        name_to_non_tensor_attribute: Dict[str, Any],
-        name_to_constant: Dict[str, Any],
-        name_to_attribute_fqn: Dict[str, str],
+        name_to_param: dict[str, torch.Tensor],
+        name_to_buffer: dict[str, torch.Tensor],
+        blocks_to_lifted_attrs: dict[torch._C.Block, set[str]],
+        name_to_non_tensor_attribute: dict[str, Any],
+        name_to_constant: dict[str, Any],
+        name_to_attribute_fqn: dict[str, str],
     ):
         self.ts_graph = ts_graph
         # Mapping of parameter FQN to actual parameter value
@@ -406,19 +407,19 @@ class TS2FXGraphConverter:
         self.name_to_buffer = name_to_buffer
 
         self.fx_graph: torch.fx.Graph = torch.fx.Graph()
-        self.input_specs: List[InputSpec] = []
-        self.output_specs: List[OutputSpec] = []
+        self.input_specs: list[InputSpec] = []
+        self.output_specs: list[OutputSpec] = []
 
         # Mapping of TS node name to converted FX node
-        self.name_to_node: Dict[
-            str, Union[torch.fx.Node, List[torch.fx.Node], Dict[Any, torch.fx.Node]]
+        self.name_to_node: dict[
+            str, Union[torch.fx.Node, list[torch.fx.Node], dict[Any, torch.fx.Node]]
         ] = {}
         # Mapping of TS node name to constant value (int, str, TorchBind obj,
         # tensor constants ...)
-        self.name_to_constant: Dict[str, Any] = name_to_constant
+        self.name_to_constant: dict[str, Any] = name_to_constant
 
         # Mapping from torchscript node output name to attribute fully qualified name
-        self.name_to_attribute_fqn: Dict[str, str] = name_to_attribute_fqn
+        self.name_to_attribute_fqn: dict[str, str] = name_to_attribute_fqn
 
         # Mapping from fully qualified name to real values or a fx graph node
         # During convert, this represents the current value of a non-tensor attribute
@@ -428,14 +429,14 @@ class TS2FXGraphConverter:
         #     self.count += 1
         #     c2 = self.count
         #     return x + c1 + c2
-        self.name_to_non_tensor_attribute_node: Dict[str, Any] = {}
+        self.name_to_non_tensor_attribute_node: dict[str, Any] = {}
 
         # Mapping from fully qualified name to initial real values inputs
        # We separate it from self.name_to_non_tensor_attribute_node since
        # we need initial real value input when we construct fx.GraphModule
-        self.name_to_non_tensor_attribute: Dict[str, Any] = name_to_non_tensor_attribute
+        self.name_to_non_tensor_attribute: dict[str, Any] = name_to_non_tensor_attribute
 
-        self.subgraphs: Dict[str, torch.fx.GraphModule] = {}
+        self.subgraphs: dict[str, torch.fx.GraphModule] = {}
 
         # Mapping of block to list of attributes that need to be lifted for each
         # block
@@ -457,7 +458,7 @@ class TS2FXGraphConverter:
         # might have inplace updates to the variable defined in the parent fx graph. After
         # the execution of that sub-block, the variable defined in the parent fx graph also
         # needs to be updated.
-        self.name_update_from_subblock_to_parent: Set[str] = set()
+        self.name_update_from_subblock_to_parent: set[str] = set()
 
     def _is_get_attr_node(self, fqn):
         return (
@@ -469,7 +470,7 @@ class TS2FXGraphConverter:
             )
         )
 
-    def _convert_block_to_subgraph(self, node: torch._C.Node, arguments: List[str]):
+    def _convert_block_to_subgraph(self, node: torch._C.Node, arguments: list[str]):
         subgraph_nodes, subgraph_converters = [], []
         for block in node.blocks():
             subgraph_converter = TS2FXGraphConverter(
@@ -506,7 +507,7 @@ class TS2FXGraphConverter:
             Block[x.1]
                 %2 = x.1 ...
         """
-        arguments: Set[str] = set()
+        arguments: set[str] = set()
         for block in entry.blocks():
             for block_node in block.nodes():
                 for block_node_in in block_node.inputs():
@@ -1332,12 +1333,12 @@ class ExplainTS2FXGraphConverter(TS2FXGraphConverter):
     def __init__(
         self,
         ts_graph: Union[torch._C.Graph, torch._C.Block],
-        name_to_param: Dict[str, torch.Tensor],
-        name_to_buffer: Dict[str, torch.Tensor],
-        blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]],
-        name_to_non_tensor_attribute: Dict[str, Any],
-        name_to_constant: Dict[str, Any],
-        name_to_attribute_fqn: Dict[str, str],
+        name_to_param: dict[str, torch.Tensor],
+        name_to_buffer: dict[str, torch.Tensor],
+        blocks_to_lifted_attrs: dict[torch._C.Block, set[str]],
+        name_to_non_tensor_attribute: dict[str, Any],
+        name_to_constant: dict[str, Any],
+        name_to_attribute_fqn: dict[str, str],
     ):
         super().__init__(
             ts_graph,
@@ -1350,7 +1351,7 @@ class ExplainTS2FXGraphConverter(TS2FXGraphConverter):
         )
 
         # Data to keep track of unsupported nodes.
-        self.unsupported_node_list: List[torch._C.Node] = []
+        self.unsupported_node_list: list[torch._C.Node] = []
 
         # Add mock to needed attributes.
         self.name_to_node = ExplainTS2FXGraphConverter._DictMock(
@@ -1395,7 +1396,7 @@ class TS2EPConverter:
         self,
         ts_model: Union[torch.jit.ScriptModule, torch.jit.ScriptFunction],
         sample_args: tuple[Any, ...],
-        sample_kwargs: Optional[Dict[str, Any]] = None,
+        sample_kwargs: Optional[dict[str, Any]] = None,
     ):
         self.ts_model = ts_model
         self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args)
@@ -1403,8 +1404,8 @@ class TS2EPConverter:
         self.sample_args = sample_args
         self.sample_kwargs = sample_kwargs
 
-        self.name_to_param: Dict[str, torch.Tensor] = {}
-        self.name_to_buffer: Dict[str, torch.Tensor] = {}
+        self.name_to_param: dict[str, torch.Tensor] = {}
+        self.name_to_buffer: dict[str, torch.Tensor] = {}
         param_list = (
             list(self.ts_model.parameters())
             if not isinstance(self.ts_model, torch._C.ScriptFunction)
@@ -1422,8 +1423,8 @@ class TS2EPConverter:
             else:
                 self.name_to_buffer[k] = tensor
 
-        self.name_to_non_tensor_attributes: Dict[str, Any] = {}
-        self.name_to_constant: Dict[str, Any] = {}
+        self.name_to_non_tensor_attributes: dict[str, Any] = {}
+        self.name_to_constant: dict[str, Any] = {}
 
         self.lift_get_attr()
@@ -1509,7 +1510,7 @@ DEBUG: (TORCH_LOGS="+export" <cmd>), additionally
     def retrace_as_exported_program(
         self,
        gm: torch.fx.GraphModule,
-        name_to_constant: Dict[str, Any],
+        name_to_constant: dict[str, Any],
     ):
         dynamic_shapes = _tree_map_with_path(
             lambda path, x: (
@@ -1569,7 +1570,7 @@ DEBUG: (TORCH_LOGS="+export" <cmd>), additionally
         # TS2FXGraphConverter since it gets attributes from self.ts_model
         # which is not accessable in TS2FXGraphConverter. It is similar to where
         # we collect self.name_to_param and self.name_to_buffer.
-        name_to_attribute_fqn: Dict[str, str] = {}
+        name_to_attribute_fqn: dict[str, str] = {}
 
         def get_attr(fqn: str):
             name = fqn.split(".")

View File

@@ -4,12 +4,12 @@ import re
 import string
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Dict, List, Optional, Set
+from typing import Any, Optional
 from types import ModuleType
 
 import torch
 
-_TAGS: Dict[str, Dict[str, Any]] = {
+_TAGS: dict[str, dict[str, Any]] = {
     "torch": {
         "cond": {},
         "dynamic-shape": {},
@@ -79,12 +79,12 @@ class ExportCase:
     description: str  # A description of the use case.
     model: torch.nn.Module
     name: str
-    example_kwargs: Dict[str, Any] = field(default_factory=dict)
+    example_kwargs: dict[str, Any] = field(default_factory=dict)
     extra_args: Optional[ArgsType] = None  # For testing graph generalization.
     # Tags associated with the use case. (e.g dynamic-shape, escape-hatch)
-    tags: Set[str] = field(default_factory=set)
+    tags: set[str] = field(default_factory=set)
     support_level: SupportLevel = SupportLevel.SUPPORTED
-    dynamic_shapes: Optional[Dict[str, Any]] = None
+    dynamic_shapes: Optional[dict[str, Any]] = None
 
     def __post_init__(self):
         check_inputs_type(self.example_args, self.example_kwargs)
@@ -98,10 +98,10 @@ class ExportCase:
             raise ValueError(f'Invalid description: "{self.description}"')
 
 
-_EXAMPLE_CASES: Dict[str, ExportCase] = {}
-_MODULES: Set[ModuleType] = set()
-_EXAMPLE_CONFLICT_CASES: Dict[str, List[ExportCase]] = {}
-_EXAMPLE_REWRITE_CASES: Dict[str, List[ExportCase]] = {}
+_EXAMPLE_CASES: dict[str, ExportCase] = {}
+_MODULES: set[ModuleType] = set()
+_EXAMPLE_CONFLICT_CASES: dict[str, list[ExportCase]] = {}
+_EXAMPLE_REWRITE_CASES: dict[str, list[ExportCase]] = {}
 
 
 def register_db_case(case: ExportCase) -> None:

View File

@@ -1,5 +1,4 @@
 # mypy: allow-untyped-defs
-from typing import List
 
 import torch
 
@@ -9,7 +8,7 @@ class ListUnpack(torch.nn.Module):
     erased after tracing.
     """
 
-    def forward(self, args: List[torch.Tensor]):
+    def forward(self, args: list[torch.Tensor]):
         """
         Lists are treated as static construct, therefore unpacking should be
        erased after tracing.

View File

@@ -3,18 +3,7 @@ import contextlib
 import inspect
 import logging
 from collections import defaultdict
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    Type,
-    TYPE_CHECKING,
-    Union,
-)
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 
 import torch
 import torch.utils._pytree as pytree
@@ -97,8 +86,8 @@ def fakify(
     mode: FakeTensorMode,
     kp: KeyPath,
     t: Any,
-    t_constraints: Dict[int, Dict[int, Constraint]],
-    sources: Dict[tuple[int, int], List[Source]],
+    t_constraints: dict[int, dict[int, Constraint]],
+    sources: dict[tuple[int, int], list[Source]],
 ):
     source = key_path_to_source(kp)
     if _is_constant_argument(t) or isinstance(t, (torch.ScriptObject, torch.nn.Module)):
@@ -165,7 +154,7 @@ def make_fake_inputs(
     combined_args = _combine_args(nn_module, args, kwargs)
     _check_dynamic_shapes(combined_args, dynamic_shapes)
     constraints = _process_dynamic_shapes(combined_args, dynamic_shapes)
-    t_constraints: Dict[int, Dict[int, Constraint]] = defaultdict(dict)
+    t_constraints: dict[int, dict[int, Constraint]] = defaultdict(dict)
     for constraint in constraints:
         t_constraints[constraint.t_id][constraint.dim] = constraint
@@ -214,17 +203,17 @@ def make_fake_inputs(
         original_signature = inspect.signature(nn_module.forward)
     else:
         original_signature = None
-    sources: Dict[tuple[int, int], List[Source]] = defaultdict(list)
+    sources: dict[tuple[int, int], list[Source]] = defaultdict(list)
     fake_args, fake_kwargs = tree_map_with_path(
         lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources),
         (args, kwargs),
     )
 
-    names: Dict[str, tuple[int, int]] = {}
-    source_pairs: List[tuple[Source, Source]] = []
-    derived_equalities: List[tuple[Source, Union[Source, Symbol], Callable]] = []
-    phantom_symbols: Dict[str, Symbol] = {}
-    relaxed_sources: Set[Source] = set()
+    names: dict[str, tuple[int, int]] = {}
+    source_pairs: list[tuple[Source, Source]] = []
+    derived_equalities: list[tuple[Source, Union[Source, Symbol], Callable]] = []
+    phantom_symbols: dict[str, Symbol] = {}
+    relaxed_sources: set[Source] = set()
     for constraint in constraints:
         torch.export.dynamic_shapes._process_equalities(
             constraint,
@@ -255,9 +244,9 @@ def make_fake_inputs(
 
 def _flatten_dynamic_shapes(
-    combined_args: Dict[str, Any],
-    dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any]],
-) -> List[Any]:
+    combined_args: dict[str, Any],
+    dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any]],
+) -> list[Any]:
     flat_shapes = []
 
     def _tree_map_helper(path, t, shape):
@@ -283,7 +272,7 @@ def _clean_dynamic_markers(tensor: torch.Tensor) -> None:
 def produce_guards_and_solve_constraints(
     fake_mode: FakeTensorMode,
     gm: torch.fx.GraphModule,
-    dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any], None],
+    dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
     equalities_inputs: EqualityConstraint,
     original_signature: inspect.Signature,
     _is_torch_jit_trace=False,
@@ -348,8 +337,8 @@ def produce_guards_and_solve_constraints(
 def make_constraints(
     fake_mode: FakeTensorMode,
     gm: torch.fx.GraphModule,
-    combined_args: Dict[str, Any],
-    dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any], None],
+    combined_args: dict[str, Any],
+    dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
     num_lifted_inputs: int,
 ):
     """
@@ -435,7 +424,7 @@ def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap:
     buffers_parameters = set(m.buffers())
     buffers_parameters.update(m.parameters())
 
-    def inner(m: torch.nn.Module, prefix_atoms: List[str], constants):
+    def inner(m: torch.nn.Module, prefix_atoms: list[str], constants):
         for k, v in m.__dict__.items():
             if isinstance(
                 v,
@@ -459,8 +448,8 @@ def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap:
 
 
 def _get_graph_inputs_of_type_nn_module(
-    args: Optional[Tuple[Tuple[Any], Dict[Any, Any]]],
-) -> Set[Type[torch.nn.Module]]:
+    args: Optional[tuple[tuple[Any], dict[Any, Any]]],
+) -> set[type[torch.nn.Module]]:
     if args is None:
         return set()
     module_types = set()
@@ -471,14 +460,14 @@ def _get_graph_inputs_of_type_nn_module(
 
 
 def _enter_enable_graph_inputs_of_type_nn_module(
-    module_types: Set[Type[torch.nn.Module]],
+    module_types: set[type[torch.nn.Module]],
 ) -> None:
     for t in module_types:
         torch._export.utils.register_module_as_pytree_input_node(t)
 
 
 def _exit_enable_graph_inputs_of_type_nn_module(
-    module_types: Set[Type[torch.nn.Module]],
+    module_types: set[type[torch.nn.Module]],
 ) -> None:
     for t in module_types:
         torch._export.utils.deregister_module_as_pytree_input_node(t)
@@ -486,7 +475,7 @@ def _exit_enable_graph_inputs_of_type_nn_module(
 
 @contextlib.contextmanager
 def _enable_graph_inputs_of_type_nn_module(
-    args: Optional[Tuple[Tuple[Any], Dict[Any, Any]]],
+    args: Optional[tuple[tuple[Any], dict[Any, Any]]],
 ):
     if args is None:
         yield
@@ -502,8 +491,8 @@ def _enable_graph_inputs_of_type_nn_module(
 
 @contextlib.contextmanager
 def _fakify_module_inputs(
-    args: Tuple[Any],
-    kwargs: Dict[Any, Any],
+    args: tuple[Any],
+    kwargs: dict[Any, Any],
     fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
 ):
     # This context manager is used to fakify module inputs.
@@ -534,7 +523,7 @@ def _fakify_module_inputs(
 def _fakify_script_objects(
     mod: torch.nn.Module,
     args: tuple[Any],
-    kwargs: Dict[Any, Any],
+    kwargs: dict[Any, Any],
     fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
 ):
     # This context manager is used to fakify script objects into FakeScriptObject.

View File

@@ -3,7 +3,7 @@ import operator
 import traceback
 import typing
 from contextlib import nullcontext
-from typing import Any, Callable, Dict, List, Optional, Set, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 from functorch.experimental.control_flow import _unstack_pytree
@@ -31,7 +31,7 @@ Fn = Callable[..., Any]
 PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
 
 
-_TORCH_SYM_OPS: Set[Callable] = {
+_TORCH_SYM_OPS: set[Callable] = {
     torch.sym_int,
     torch.sym_float,
     torch.sym_ite,
@@ -64,9 +64,9 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
             self.root = torch.nn.Module()
             self.graph = torch.fx.Graph()
             self.graph.set_codegen(codegen)
-            self.tensor_attrs: Dict[str, torch.Tensor] = {}  # type: ignore[assignment]
+            self.tensor_attrs: dict[str, torch.Tensor] = {}  # type: ignore[assignment]
             self.fake_tensor_mode: Optional[FakeTensorMode] = None
-            self.submodules: Dict[torch.nn.Module, str] = {}
+            self.submodules: dict[torch.nn.Module, str] = {}
 
         def trace(self) -> None:  # type: ignore[override]
             raise ExportPassBaseError("ExportTracer doesn't support trace().")
@@ -162,7 +162,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
             self,
             target: str,  # type: ignore[override]
             args: tuple[Argument, ...],
-            kwargs: Dict[str, Argument],
+            kwargs: dict[str, Argument],
         ) -> ProxyValue:
             arg = super().placeholder(target, args, kwargs)
             return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta))
@@ -171,7 +171,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
             self,
             target: torch.fx.node.Target,
             args: tuple[Argument, ...],
-            kwargs: Dict[str, Argument],
+            kwargs: dict[str, Argument],
         ) -> ProxyValue:
             return self.callback.output(args[0], NodeMetadata(self.node.meta)).data  # type: ignore[return-value]
@@ -179,7 +179,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
             self,
             target: torch.fx.node.Target,
             args: tuple[Argument, ...],
-            kwargs: Dict[str, Argument],
+            kwargs: dict[str, Argument],
         ) -> ProxyValue:
             meta = NodeMetadata(self.node.meta)
@@ -218,7 +218,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
                 raise ExportPassBaseError(f"Unsupported target type: {target}")
 
         def get_attr(
-            self, target: str, args: tuple[Argument, ...], kwargs: Dict[str, Argument]  # type: ignore[override]
+            self, target: str, args: tuple[Argument, ...], kwargs: dict[str, Argument]  # type: ignore[override]
         ) -> Argument:
             return super().get_attr(target, args, kwargs)
@@ -226,12 +226,12 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
             self,
             target: torch.fx.node.Target,
             args: tuple[Argument, ...],
-            kwargs: Dict[str, Argument],
+            kwargs: dict[str, Argument],
         ) -> None:
             raise ExportPassBaseError("call_module is not supported.")
 
         def call_method(
-            self, target: str, args: tuple[Argument, ...], kwargs: Dict[str, Argument]  # type: ignore[override]
+            self, target: str, args: tuple[Argument, ...], kwargs: dict[str, Argument]  # type: ignore[override]
         ) -> None:
             raise ExportPassBaseError("call_method is not supported.")
@@ -254,7 +254,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
         kind: str,
         target: torch.fx.node.Target,
         args: tuple[Argument, ...],
-        kwargs: Dict[str, Argument],
+        kwargs: dict[str, Argument],
         meta: NodeMetadata,
     ) -> ProxyValue:
         args_data, kwargs_data = pytree.tree_map_only(
@@ -277,7 +277,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
         self.tracer.set_metadata(res_proxy.node, res_data)
         return ProxyValue(res_data, res_proxy)
 
-    def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]:
+    def inputs(self, graph_module: torch.fx.GraphModule) -> list[Argument]:
         # TODO(angelayi): Update this with what we decide to do for metadata in
         # the exported graph module
         if (args := graph_module.meta.get("args", None)) is not None:
@@ -327,7 +327,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
         self,
         op,
         args: tuple[Argument, ...],
-        kwargs: Dict[str, Argument],
+        kwargs: dict[str, Argument],
         meta: NodeMetadata,
     ) -> ProxyValue:
         return self._fx("call_function", op, args, kwargs, meta)
@@ -345,7 +345,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
         pred: ProxyValue,
         true_fn: torch.fx.GraphModule,
         false_fn: torch.fx.GraphModule,
-        inputs: List[Argument],
+        inputs: list[Argument],
         meta: NodeMetadata,
     ) -> ProxyValue:
         true_branch = self.call_submodule(true_fn, tuple(inputs))
@@ -363,8 +363,8 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
     def call_map(
         self,
         f: torch.fx.GraphModule,
-        mapped_args: List[ProxyValue],
-        operands: List[ProxyValue],
+        mapped_args: list[ProxyValue],
+        operands: list[ProxyValue],
         meta: NodeMetadata,
     ) -> ProxyValue:
         xs = _unstack_pytree([arg.data for arg in mapped_args])[0]
@@ -383,7 +383,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
     ) -> ProxyValue:
         return self._fx("call_function", operator.getitem, (value, key), {}, meta)
 
-    def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
+    def output(self, results: list[Argument], meta: NodeMetadata) -> ProxyValue:
         return self._fx("output", "output", (results,), {}, meta)
 
     def call_submodule(

View File

@@ -1,10 +1,10 @@
-from typing import Any, Dict, Set
+from typing import Any
 
 
 NodeMetadataValue = Any
 
 
-PROTECTED_KEYS: Set[str] = {
+PROTECTED_KEYS: set[str] = {
     "val",
     "stack_trace",
     "nn_module_stack",
@@ -14,8 +14,8 @@ PROTECTED_KEYS: Set[str] = {
 
 
 class NodeMetadata:
-    def __init__(self, data: Dict[str, Any]) -> None:
-        self.data: Dict[str, Any] = data.copy()
+    def __init__(self, data: dict[str, Any]) -> None:
+        self.data: dict[str, Any] = data.copy()
 
     def __getitem__(self, key: str) -> NodeMetadataValue:
         return self.data[key]

View File

@@ -1,5 +1,6 @@
 # pyre-strict
-from typing import Union, Iterator, Iterable, Generic
+from typing import Union, Generic
+from collections.abc import Iterator, Iterable
 
 import torch
 from typing import TypeVar
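The hunk above shows the companion rule: PEP 585 also deprecated the typing aliases for the container ABCs, so Iterator, Iterable, and Sequence now come from collections.abc, while Union, Optional, Generic, and TypeVar stay in typing. A small illustrative sketch of the split; the helper below is hypothetical, not taken from this diff:

    from collections.abc import Iterable, Iterator
    from typing import Optional, TypeVar

    T = TypeVar("T")

    def first(items: Iterable[T]) -> Optional[T]:
        it: Iterator[T] = iter(items)  # ABC types from collections.abc
        return next(it, None)          # Optional still from typing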

View File

@@ -3,7 +3,7 @@ import math
 import operator
 import traceback
 from functools import partial
-from typing import Callable, Dict, List, NamedTuple, Set
+from typing import Callable, NamedTuple
 
 import sympy
@@ -45,11 +45,11 @@ def _convert_range_to_int(range: ValueRanges):
 class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
     def __init__(
         self,
-        range_constraints: Dict[sympy.Symbol, ValueRanges],
+        range_constraints: dict[sympy.Symbol, ValueRanges],
     ):
         super().__init__()
-        self.range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
-        self._asserts_generated_unbacked_symbols: Set[sympy.Symbol] = set()
+        self.range_constraints: dict[sympy.Symbol, ValueRanges] = range_constraints
+        self._asserts_generated_unbacked_symbols: set[sympy.Symbol] = set()
         self.counter = 0
 
     def _assert_range_constraint(self, node, lower, upper, assert_msg):
@@ -105,8 +105,8 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
         # need the proxy for shape, which further requires the proxy for ret[1], etc.
         def add_assertions(val):
-            call_backs: List[Callable] = []
-            messages: List[str] = []
+            call_backs: list[Callable] = []
+            messages: list[str] = []
             if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)):
                 symbol = val.node.expr
                 if symbol in self.existing_inline_assertions:
@@ -161,9 +161,9 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
 
 def _get_existing_inline_assertions(
     graph_module: torch.fx.GraphModule,
-    range_constraints: Dict[sympy.Symbol, ValueRanges],
-) -> Dict[sympy.Symbol, ValueRanges]:
-    existing_inline_assertions: Dict[sympy.Symbol, ValueRanges] = {}
+    range_constraints: dict[sympy.Symbol, ValueRanges],
+) -> dict[sympy.Symbol, ValueRanges]:
+    existing_inline_assertions: dict[sympy.Symbol, ValueRanges] = {}
 
     for module in graph_module.modules():
         if not isinstance(module, torch.fx.GraphModule):

View File

@@ -2,7 +2,7 @@
 from __future__ import annotations
 
 import operator
-from typing import Dict, Optional, TYPE_CHECKING, Union
+from typing import Optional, TYPE_CHECKING, Union
 
 import torch
 from torch.export.exported_program import ConstantArgument, TensorArgument
@@ -23,7 +23,7 @@ class CollectTracepointsPass(PassBase):
     """
 
     def __init__(
-        self, specs: Dict[str, ModuleCallSignature], sig: ExportGraphSignature
+        self, specs: dict[str, ModuleCallSignature], sig: ExportGraphSignature
     ) -> None:
         super().__init__()
         self.specs = specs

View File

@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import collections
 from collections import defaultdict
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Optional
 
 import torch
 import torch.utils._pytree as pytree
@@ -53,8 +53,8 @@ class ConstantFolder(torch.fx.Interpreter):
         skip_constructors: bool = False,
     ):
         super().__init__(gm)
-        self.node_replacements: Dict[torch.fx.Node, Any] = {}
-        self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
+        self.node_replacements: dict[torch.fx.Node, Any] = {}
+        self.replaced_uses: dict[torch.fx.Node, int] = collections.Counter()
         self.unknown_value = object()
         self.skip_constructors: bool = skip_constructors
@@ -281,7 +281,7 @@ def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule
 
     new_graph = torch.fx.Graph()
 
-    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
+    node_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
     output_nodes = []
     for node in gm.graph.nodes:
         if node.meta[META_TAG] == MODULE_TAG:

View File

@@ -1,5 +1,5 @@
 import copy
-from typing import Dict, Optional, List
+from typing import Optional
 
 import torch
 from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, PassResult, Argument
@@ -9,7 +9,7 @@ from torch._ops import OpOverload, HigherOrderOperator
 
 aten = torch.ops.aten
 
-_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: Dict[OpOverload, OpOverload] = {
+_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: dict[OpOverload, OpOverload] = {
     aten.sym_constrain_range.default: aten._functional_sym_constrain_range,
     aten._assert_async.msg: aten._functional_assert_async.msg,
 }
@@ -60,7 +60,7 @@ class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse):
         self,
         op: OpOverload,
         args: tuple[Argument, ...],
-        kwargs: Dict[str, Argument],
+        kwargs: dict[str, Argument],
         meta: NodeMetadata,
     ) -> ProxyValue:
         if op not in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS:
@@ -88,7 +88,7 @@ class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse):
         return self._dep_token
 
-    def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
+    def output(self, results: list[Argument], meta: NodeMetadata) -> ProxyValue:
         assert self._dep_token is not None
         return super().output(results=(*results, self._dep_token), meta=meta)  # type: ignore[arg-type]

View File

@@ -1,5 +1,4 @@
 import functools
-from typing import List
 
 import torch
 from torch._export.passes._node_metadata_hook import (
@@ -8,7 +7,7 @@ from torch._export.passes._node_metadata_hook import (
 )
 
 
-def insert_custom_op_guards(gm: torch.fx.GraphModule, ops_to_guard: List[str]) -> None:
+def insert_custom_op_guards(gm: torch.fx.GraphModule, ops_to_guard: list[str]) -> None:
     """
     This is used by draft_export to insert guards in front of calls to custom
     operators which have a generated fake kernel.

View File

@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import collections
 import warnings
-from typing import Any, Dict, List, Union
+from typing import Any, Union
 
 import torch
 from torch._export.verifier import SpecViolationError
@@ -29,13 +29,13 @@ class ConstantAttrMap(collections.abc.MutableMapping):
 
     def __init__(self) -> None:
         # Underlying dict that we use to implement this mapping.
-        self._constant_attrs: Dict[
-            Union[int, torch.Tensor, FakeScriptObject], List[Any]
+        self._constant_attrs: dict[
+            Union[int, torch.Tensor, FakeScriptObject], list[Any]
         ] = {}
         # Map from the hash(ScriptObject) to the ScriptObject itself. Used for
         # APIs like `__iter__` that should look like they're returning the
         # original ScriptObjects.
-        self._script_object_map: Dict[int, torch.ScriptObject] = {}
+        self._script_object_map: dict[int, torch.ScriptObject] = {}
 
     def __getitem__(
         self, key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]
@@ -113,7 +113,7 @@ def lift_constants_pass(
     gm: torch.fx.GraphModule,
     graph_signature: ExportGraphSignature,
     constant_attrs: ConstantAttrMap,
-) -> Dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]]:
+) -> dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]]:
     """
     Takes a graph module, graph signature, and modifies them implace to lift any
     constants (tensors or custom classes) as inputs to the graph. Returns a
@@ -131,7 +131,7 @@ def lift_constants_pass(
     Returns:
         A dictionary of fqn => constant value.
     """
-    all_constants: Dict[
+    all_constants: dict[
         str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]
     ] = {}
@@ -300,13 +300,13 @@ def lift_constants_pass(
 
 def rewrite_script_object_meta(
     gm: torch.fx.GraphModule,
-) -> Dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject],]:
+) -> dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject],]:
     """When tracing, we produce a graph with FakeScriptObject in the
     meta["val"].
 
     For now, we rewrie meta["val"] to be a placeholder CustomObjArgument
     """
-    constants: Dict[
+    constants: dict[
         str,
         Union[
             torch.Tensor,

View File

@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 from __future__ import annotations
 
-from typing import List, Optional, TYPE_CHECKING, Union
+from typing import Optional, TYPE_CHECKING, Union
 
 import torch
 from torch._higher_order_ops.wrap import wrap_with_autocast
@@ -116,7 +116,7 @@ def _split_autocast(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     exit_autocast # 3
     E # 4
     """
-    enter_autocast_node_stack: List[torch.fx.Node] = []
+    enter_autocast_node_stack: list[torch.fx.Node] = []
     first_node_after_outer_most_exit: bool = False
 
     def node_call_back(node: torch.fx.Node) -> bool:

View File

@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import logging
 import operator
-from typing import List, Optional, Union
+from typing import Optional, Union
 
 import torch
 import torch.export._trace
@@ -269,9 +269,9 @@ def _conv1d_op_with_squeeze(
     inp: torch.Tensor,
     weight: torch.Tensor,
     bias: Optional[torch.Tensor],
-    stride: List[int],
-    padding: List[int],
-    dilation: List[int],
+    stride: list[int],
+    padding: list[int],
+    dilation: list[int],
     groups: int,
 ) -> torch.Tensor:
     # In quantized version, conv1d is emulated using conv2d with squeeze and unsqueeze

View File

@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing import Dict, Optional
+from typing import Optional
 
 import torch
 from torch._ops import OpOverload, HigherOrderOperator
 from torch._export.error import InternalError
@@ -9,7 +9,7 @@ from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
 
 __all__ = ["ReplaceViewOpsWithViewCopyOpsPass"]
 
-_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: Dict[OpOverload, OpOverload] = {
+_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: dict[OpOverload, OpOverload] = {
     torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default,
 }

View File

@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from typing import List
 
 from torch._export.serde.schema import Node
@@ -12,4 +11,4 @@ class ExternKernelNode:
 
 @dataclass
 class ExternKernelNodes:
-    nodes: List[ExternKernelNode]
+    nodes: list[ExternKernelNode]

View File

@@ -1,5 +1,5 @@
 import dataclasses
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from torch._dynamo.exc import UserError, UserErrorType
@@ -24,7 +24,7 @@ class RootDim:
     """
     min: int
     max: Union[int, None]
-    derived: List[str]
+    derived: list[str]
 
 
 @dataclasses.dataclass
@@ -33,15 +33,15 @@ class DynamicShapesSpec:
     This stores a dynamic_shapes spec for de/serialization.
     """
 
-    dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any], None]
-    dims: Dict[str, RootDim]
+    dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None]
+    dims: dict[str, RootDim]
 
 
 def _postprocess_serialized_shapes(
-    dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any], None],
-    dims: Dict[str, Dict[str, Union[int, List[str], None]]],
+    dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
+    dims: dict[str, dict[str, Union[int, list[str], None]]],
     to_dict: Optional[bool] = False,
-) -> Union[DynamicShapesSpec, Dict[str, Any]]:
+) -> Union[DynamicShapesSpec, dict[str, Any]]:
     """
     Sorts dims and dumps to dictionary format.
     """
@@ -63,11 +63,11 @@ def _postprocess_serialized_shapes(
 
 def _dump_dynamic_shapes(
-    dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any], None],
+    dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
     args: tuple[Any],
-    kwargs: Optional[Dict[str, Any]] = None,
+    kwargs: Optional[dict[str, Any]] = None,
     to_dict: Optional[bool] = False,
-) -> Union[DynamicShapesSpec, Dict[str, Any]]:
+) -> Union[DynamicShapesSpec, dict[str, Any]]:
     """
     Utility function for dynamic shapes serialization, serializing a dynamic_shapes spec.
 
     Returns a DynamicShapesSpec dataclass containing 2 fields, "dynamic_shapes" and "dims".
@@ -127,7 +127,7 @@ def _dump_dynamic_shapes(
     }
     ```
     """
-    dims: Dict[str, Dict[str, Any]] = {}
+    dims: dict[str, dict[str, Any]] = {}
 
     def _standardize_shapes(path, tensor, shape):  # type: ignore[no-untyped-def]
         """
@@ -198,9 +198,9 @@ def _dump_dynamic_shapes(
 
 
 def _load_dynamic_shapes(
-    spec: Union[DynamicShapesSpec, Dict[str, Any]],
+    spec: Union[DynamicShapesSpec, dict[str, Any]],
     from_dict: Optional[bool] = False,
-) -> Union[Dict[str, Any], tuple[Any], List[Any], None]:
+) -> Union[dict[str, Any], tuple[Any], list[Any], None]:
     """
     Utility function for dynamic shapes serialization.
     Deserializes a DynamicShapesSpec or corresponding dictionary into a dynamic_shapes input to export().

View File

@@ -3,7 +3,7 @@
 from dataclasses import dataclass, field
 from enum import IntEnum
-from typing import Annotated, Dict, List, Optional
+from typing import Annotated, Optional
 
 from torch._export.serde.union import _Union
@@ -96,10 +96,10 @@ class SymBool(_Union):
 @dataclass
 class TensorMeta:
     dtype: Annotated[ScalarType, 10]
-    sizes: Annotated[List[SymInt], 20]
+    sizes: Annotated[list[SymInt], 20]
     requires_grad: Annotated[bool, 30]
     device: Annotated[Device, 40]
-    strides: Annotated[List[SymInt], 50]
+    strides: Annotated[list[SymInt], 50]
     storage_offset: Annotated[SymInt, 60]
     layout: Annotated[Layout, 70]
@@ -175,29 +175,29 @@ class CustomObjArgument:
 class Argument(_Union):
     as_none: Annotated[bool, 10]
     as_tensor: Annotated[TensorArgument, 20]
-    as_tensors: Annotated[List[TensorArgument], 30]
+    as_tensors: Annotated[list[TensorArgument], 30]
     as_int: Annotated[int, 50]
-    as_ints: Annotated[List[int], 70]
+    as_ints: Annotated[list[int], 70]
     as_float: Annotated[float, 80]
-    as_floats: Annotated[List[float], 90]
+    as_floats: Annotated[list[float], 90]
     as_string: Annotated[str, 100]
-    as_strings: Annotated[List[str], 101]
+    as_strings: Annotated[list[str], 101]
     as_sym_int: Annotated[SymIntArgument, 110]
-    as_sym_ints: Annotated[List[SymIntArgument], 120]
+    as_sym_ints: Annotated[list[SymIntArgument], 120]
     as_scalar_type: Annotated[ScalarType, 130]
     as_memory_format: Annotated[MemoryFormat, 140]
     as_layout: Annotated[Layout, 150]
     as_device: Annotated[Device, 160]
     as_bool: Annotated[bool, 170]
-    as_bools: Annotated[List[bool], 180]
+    as_bools: Annotated[list[bool], 180]
     as_sym_bool: Annotated[SymBoolArgument, 182]
-    as_sym_bools: Annotated[List[SymBoolArgument], 184]
+    as_sym_bools: Annotated[list[SymBoolArgument], 184]
     as_graph: Annotated[GraphArgument, 200]
-    as_optional_tensors: Annotated[List[OptionalTensorArgument], 190]
+    as_optional_tensors: Annotated[list[OptionalTensorArgument], 190]
     as_custom_obj: Annotated[CustomObjArgument, 210]
     as_operator: Annotated[str, 220]
     as_sym_float: Annotated[SymFloatArgument, 230]
-    as_sym_floats: Annotated[List[SymFloatArgument], 240]
+    as_sym_floats: Annotated[list[SymFloatArgument], 240]
 
 
 class ArgumentKind(IntEnum):
@@ -217,27 +217,27 @@ class NamedArgument:
 @dataclass
 class Node:
     target: Annotated[str, 10]
-    inputs: Annotated[List[NamedArgument], 20]
-    outputs: Annotated[List[Argument], 30]
-    metadata: Annotated[Dict[str, str], 40]
+    inputs: Annotated[list[NamedArgument], 20]
+    outputs: Annotated[list[Argument], 30]
+    metadata: Annotated[dict[str, str], 40]
     is_hop_single_tensor_return: Annotated[Optional[bool], 50] = None
 
 
 @dataclass
 class Graph:
-    inputs: Annotated[List[Argument], 10]
-    outputs: Annotated[List[Argument], 20]
-    nodes: Annotated[List[Node], 30]
-    tensor_values: Annotated[Dict[str, TensorMeta], 40]
-    sym_int_values: Annotated[Dict[str, SymInt], 50]
-    sym_bool_values: Annotated[Dict[str, SymBool], 60]
+    inputs: Annotated[list[Argument], 10]
+    outputs: Annotated[list[Argument], 20]
+    nodes: Annotated[list[Node], 30]
+    tensor_values: Annotated[dict[str, TensorMeta], 40]
+    sym_int_values: Annotated[dict[str, SymInt], 50]
+    sym_bool_values: Annotated[dict[str, SymBool], 60]
     # This is for deserializing the submodule graphs from higher order ops
     # (ex. cond, map) where single tensor returns will just return a single
     # tensor, rather than following export schema and returning a singleton
     # list.
     is_single_tensor_return: Annotated[bool, 70] = False
-    custom_obj_values: Annotated[Dict[str, CustomObjArgument], 80] = field(default_factory=dict)
-    sym_float_values: Annotated[Dict[str, SymFloat], 90] = field(default_factory=dict)
+    custom_obj_values: Annotated[dict[str, CustomObjArgument], 80] = field(default_factory=dict)
+    sym_float_values: Annotated[dict[str, SymFloat], 90] = field(default_factory=dict)
 
 
 @dataclass
class UserInputSpec:
@@ -354,8 +354,8 @@ class OutputSpec(_Union):
 
 @dataclass
 class GraphSignature:
-    input_specs: Annotated[List[InputSpec], 10]
-    output_specs: Annotated[List[OutputSpec], 20]
+    input_specs: Annotated[list[InputSpec], 10]
+    output_specs: Annotated[list[OutputSpec], 20]
@@ -366,8 +366,8 @@ class RangeConstraint:
 
 @dataclass
 class ModuleCallSignature:
-    inputs: Annotated[List[Argument], 10]
-    outputs: Annotated[List[Argument], 20]
+    inputs: Annotated[list[Argument], 10]
+    outputs: Annotated[list[Argument], 20]
 
     # These are serialized by calling pytree.treespec_loads
     # And deserialized by calling pytree.treespec_dumps
@@ -376,7 +376,7 @@ class ModuleCallSignature:
 
     # This field is used to prettify the graph placeholders
     # after we ser/der and retrace
-    forward_arg_names: Annotated[Optional[List[str]], 50] = None
+    forward_arg_names: Annotated[Optional[list[str]], 50] = None
@@ -392,8 +392,8 @@ class GraphModule:
     # This is used for unflattening, by tracking the calling structure of all of
     # the modules in order to unflatten the modules back to the eager calling
     # conventions.
-    module_call_graph: Annotated[List[ModuleCallEntry], 60]
-    metadata: Annotated[Dict[str, str], 40] = field(default_factory=dict)
+    module_call_graph: Annotated[list[ModuleCallEntry], 60]
+    metadata: Annotated[dict[str, str], 40] = field(default_factory=dict)
 
 
 # Invariant: Every time a change is made to the schema, one of the versions
@@ -408,8 +408,8 @@ class SchemaVersion:
 class ExportedProgram:
     graph_module: Annotated[GraphModule, 10]
     # Key is the opset namespace (ex. aten), and value is the version number
-    opset_version: Annotated[Dict[str, int], 20]
-    range_constraints: Annotated[Dict[str, RangeConstraint], 30]
+    opset_version: Annotated[dict[str, int], 20]
+    range_constraints: Annotated[dict[str, RangeConstraint], 30]
     schema_version: Annotated[SchemaVersion, 60]
-    verifiers: Annotated[List[str], 70] = field(default_factory=list)
+    verifiers: Annotated[list[str], 70] = field(default_factory=list)
     torch_version: Annotated[str, 80] = "<=2.4"

View File

@ -5,7 +5,7 @@ import inspect
import re import re
import typing import typing
from enum import IntEnum from enum import IntEnum
from typing import Annotated, Any, Dict, ForwardRef, List, Optional, Union from typing import Annotated, Any, ForwardRef, Optional, Union
from torch._export.serde import schema from torch._export.serde import schema
from torch._export.serde.union import _Union from torch._export.serde.union import _Union
@ -36,16 +36,16 @@ _THRIFT_TYPE_MAP = {
def _staged_schema(): def _staged_schema():
yaml_ret: Dict[str, Any] = {} yaml_ret: dict[str, Any] = {}
defs = {} defs = {}
cpp_enum_defs: Dict[str, str] = {} cpp_enum_defs: dict[str, str] = {}
cpp_class_defs: Dict[str, str] = {} cpp_class_defs: dict[str, str] = {}
cpp_type_decls: List[str] = [] cpp_type_decls: list[str] = []
cpp_json_defs: List[str] = [] cpp_json_defs: list[str] = []
thrift_enum_defs: List[str] = [] thrift_enum_defs: list[str] = []
thrift_type_defs: Dict[str, str] = {} thrift_type_defs: dict[str, str] = {}
def _handle_aggregate(ty) -> tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: def _handle_aggregate(ty) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
def dump_type(t, level: int) -> tuple[str, str, str]: def dump_type(t, level: int) -> tuple[str, str, str]:
if getattr(t, "__name__", None) in cpp_enum_defs: if getattr(t, "__name__", None) in cpp_enum_defs:
return t.__name__, "int64_t", t.__name__ return t.__name__, "int64_t", t.__name__
@ -125,7 +125,7 @@ def _staged_schema():
f"Default value {v} is not supported yet in export schema." f"Default value {v} is not supported yet in export schema."
) )
def dump_field(f) -> tuple[Dict[str, Any], str, Optional[str], str, int]: def dump_field(f) -> tuple[dict[str, Any], str, Optional[str], str, int]:
t, cpp_type, thrift_type = dump_type(f.type, 0) t, cpp_type, thrift_type = dump_type(f.type, 0)
ret = {"type": t} ret = {"type": t}
cpp_default: Optional[str] = None cpp_default: Optional[str] = None
@ -524,12 +524,12 @@ def _hash_content(s: str):
@dataclasses.dataclass @dataclasses.dataclass
class _Commit: class _Commit:
result: Dict[str, Any] result: dict[str, Any]
checksum_next: str checksum_next: str
yaml_path: str yaml_path: str
additions: Dict[str, Any] additions: dict[str, Any]
subtractions: Dict[str, Any] subtractions: dict[str, Any]
base: Dict[str, Any] base: dict[str, Any]
checksum_head: Optional[str] checksum_head: Optional[str]
cpp_header: str cpp_header: str
cpp_header_path: str cpp_header_path: str

View File

@ -24,15 +24,11 @@ from typing import (
Any, Any,
Callable, Callable,
cast, cast,
Dict,
final, final,
Iterator,
List,
Optional, Optional,
Set,
Type,
Union, Union,
) )
from collections.abc import Iterator
import sympy import sympy
@ -118,7 +114,7 @@ class SerializeError(RuntimeError):
pass pass
def _reverse_map(d: Dict[Any, Enum]): def _reverse_map(d: dict[Any, Enum]):
return {v.value: k for k, v in d.items()} return {v.value: k for k, v in d.items()}
@ -352,7 +348,7 @@ def serialize_torch_artifact(artifact: Optional[Any], pickle_protocol: int = DEF
del copyreg.dispatch_table[FakeTensor] del copyreg.dispatch_table[FakeTensor]
def deserialize_torch_artifact(serialized: Union[Dict[str, Any], tuple[Any, ...], bytes]): def deserialize_torch_artifact(serialized: Union[dict[str, Any], tuple[Any, ...], bytes]):
if isinstance(serialized, (dict, tuple)): if isinstance(serialized, (dict, tuple)):
return serialized return serialized
if len(serialized) == 0: if len(serialized) == 0:
@ -401,8 +397,8 @@ def _int_to_sympy_int(val: Optional[int], default) -> sympy.Expr:
def serialize_range_constraints( def serialize_range_constraints(
range_constraints: Dict[sympy.Symbol, ValueRanges] range_constraints: dict[sympy.Symbol, ValueRanges]
) -> Dict[str, RangeConstraint]: ) -> dict[str, RangeConstraint]:
return { return {
str(k): RangeConstraint( str(k): RangeConstraint(
_sympy_int_to_int(v.lower, "ceil"), # type: ignore[arg-type] _sympy_int_to_int(v.lower, "ceil"), # type: ignore[arg-type]
@ -426,15 +422,15 @@ def _get_schema_from_target(target):
@dataclass @dataclass
class GraphState: class GraphState:
inputs: List[Argument] = field(default_factory=list) inputs: list[Argument] = field(default_factory=list)
outputs: List[Argument] = field(default_factory=list) outputs: list[Argument] = field(default_factory=list)
nodes: List[Node] = field(default_factory=list) nodes: list[Node] = field(default_factory=list)
tensor_values: Dict[str, TensorMeta] = field(default_factory=dict) tensor_values: dict[str, TensorMeta] = field(default_factory=dict)
sym_int_values: Dict[str, SymInt] = field(default_factory=dict) sym_int_values: dict[str, SymInt] = field(default_factory=dict)
sym_bool_values: Dict[str, SymBool] = field(default_factory=dict) sym_bool_values: dict[str, SymBool] = field(default_factory=dict)
sym_float_values: Dict[str, SymFloat] = field(default_factory=dict) sym_float_values: dict[str, SymFloat] = field(default_factory=dict)
is_single_tensor_return: bool = False is_single_tensor_return: bool = False
custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict) custom_obj_values: dict[str, CustomObjArgument] = field(default_factory=dict)
class Final(type): class Final(type):
@ -450,13 +446,13 @@ class GraphModuleSerializer(metaclass=Final):
def __init__( def __init__(
self, self,
graph_signature: ep.ExportGraphSignature, graph_signature: ep.ExportGraphSignature,
module_call_graph: List[ep.ModuleCallEntry], module_call_graph: list[ep.ModuleCallEntry],
): ):
self.graph_state = GraphState() self.graph_state = GraphState()
self.graph_signature = graph_signature self.graph_signature = graph_signature
self.module_call_graph = module_call_graph self.module_call_graph = module_call_graph
self.custom_objs: Dict[str, torch._C.ScriptObject] = {} self.custom_objs: dict[str, torch._C.ScriptObject] = {}
self.duplicate_getitem_nodes: Dict[str, str] = {} self.duplicate_getitem_nodes: dict[str, str] = {}
@contextmanager @contextmanager
def save_graph_state(self): def save_graph_state(self):
@ -597,7 +593,7 @@ class GraphModuleSerializer(metaclass=Final):
else: else:
return user_node.name return user_node.name
def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]: def serialize_metadata(self, node: torch.fx.Node) -> dict[str, str]:
ret = {} ret = {}
if stack_trace := node.meta.get("stack_trace"): if stack_trace := node.meta.get("stack_trace"):
ret["stack_trace"] = stack_trace ret["stack_trace"] = stack_trace
@ -647,7 +643,7 @@ class GraphModuleSerializer(metaclass=Final):
class_fqn=script_obj_meta.class_fqn, class_fqn=script_obj_meta.class_fqn,
) )
def serialize_sym_op_inputs(self, op, args) -> List[NamedArgument]: def serialize_sym_op_inputs(self, op, args) -> list[NamedArgument]:
if isinstance(op, torch._ops.OpOverload): if isinstance(op, torch._ops.OpOverload):
args_names = [arg.name for arg in op._schema.arguments] args_names = [arg.name for arg in op._schema.arguments]
else: else:
@ -669,7 +665,7 @@ class GraphModuleSerializer(metaclass=Final):
target: Any, # torch._ops.OpOverload and other custom operator types. target: Any, # torch._ops.OpOverload and other custom operator types.
args, args,
kwargs=None kwargs=None
) -> List[NamedArgument]: ) -> list[NamedArgument]:
assert isinstance(target, (torch._ops.OpOverload, *_registered_extension_types())) assert isinstance(target, (torch._ops.OpOverload, *_registered_extension_types()))
kwargs = kwargs or {} kwargs = kwargs or {}
serialized_args = [] serialized_args = []
@ -700,7 +696,7 @@ class GraphModuleSerializer(metaclass=Final):
return serialized_args return serialized_args
def serialize_hoo_inputs(self, args, kwargs) -> List[NamedArgument]: def serialize_hoo_inputs(self, args, kwargs) -> list[NamedArgument]:
""" """
For serializing HOO inputs since HOOs do not have a schema. For serializing HOO inputs since HOOs do not have a schema.
""" """
@ -1180,8 +1176,8 @@ class GraphModuleSerializer(metaclass=Final):
) )
def serialize_module_call_graph( def serialize_module_call_graph(
self, module_call_graph: List[ep.ModuleCallEntry] self, module_call_graph: list[ep.ModuleCallEntry]
) -> List[ModuleCallEntry]: ) -> list[ModuleCallEntry]:
return [ return [
ModuleCallEntry( ModuleCallEntry(
fqn=entry.fqn, fqn=entry.fqn,
@ -1194,7 +1190,7 @@ class GraphModuleSerializer(metaclass=Final):
for entry in module_call_graph for entry in module_call_graph
] ]
def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]: def serialize_outputs(self, node: torch.fx.Node) -> list[Argument]:
"""For a given node, return the dataclass representing its output values. """For a given node, return the dataclass representing its output values.
[NOTE: Multiple outputs] We handle aggregates differently than FX. For [NOTE: Multiple outputs] We handle aggregates differently than FX. For
@ -1294,7 +1290,7 @@ class GraphModuleSerializer(metaclass=Final):
return output_arguments return output_arguments
def serialize_hoo_outputs(self, node: torch.fx.Node) -> List[Argument]: def serialize_hoo_outputs(self, node: torch.fx.Node) -> list[Argument]:
""" """
For serializing HOO outputs since HOOs do not have a schema. For serializing HOO outputs since HOOs do not have a schema.
""" """
@ -1359,7 +1355,7 @@ class GraphModuleSerializer(metaclass=Final):
# list outputs should've been handled earlier # list outputs should've been handled earlier
raise SerializeError(f"Unable to serialize output {meta_val}") raise SerializeError(f"Unable to serialize output {meta_val}")
def _handle_getitem_users(self, node: torch.fx.Node) -> List[TensorArgument]: def _handle_getitem_users(self, node: torch.fx.Node) -> list[TensorArgument]:
meta_val = node.meta["val"] meta_val = node.meta["val"]
idx_to_name = {} idx_to_name = {}
@ -1406,7 +1402,7 @@ class GraphModuleSerializer(metaclass=Final):
is_single_tensor_return=self.graph_state.is_single_tensor_return, is_single_tensor_return=self.graph_state.is_single_tensor_return,
) )
def serialize_graph_module_metadata(self, meta: Dict[str, Any]): def serialize_graph_module_metadata(self, meta: dict[str, Any]):
ret = {} ret = {}
if custom := meta.get("custom"): if custom := meta.get("custom"):
try: try:
@ -1431,8 +1427,8 @@ class GraphModuleSerializer(metaclass=Final):
@final @final
class ExportedProgramSerializer(metaclass=Final): class ExportedProgramSerializer(metaclass=Final):
def __init__(self, opset_version: Optional[Dict[str, int]] = None, pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL): def __init__(self, opset_version: Optional[dict[str, int]] = None, pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL):
self.opset_version: Dict[str, int] = {} self.opset_version: dict[str, int] = {}
if opset_version: if opset_version:
self.opset_version.update(opset_version) self.opset_version.update(opset_version)
if "aten" not in self.opset_version: if "aten" not in self.opset_version:
@ -1498,15 +1494,15 @@ class GraphModuleDeserializer(metaclass=Final):
class Result: class Result:
graph_module: torch.fx.GraphModule graph_module: torch.fx.GraphModule
signature: ep.ExportGraphSignature signature: ep.ExportGraphSignature
module_call_graph: List[ep.ModuleCallEntry] module_call_graph: list[ep.ModuleCallEntry]
names_to_symbols: Dict[str, sympy.Symbol] names_to_symbols: dict[str, sympy.Symbol]
state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]] state_dict: dict[str, Union[torch.Tensor, torch.nn.Parameter]]
constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]] constants: dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]]
example_inputs: Optional[tuple[tuple[torch.Tensor, ...], Dict[str, Any]]] example_inputs: Optional[tuple[tuple[torch.Tensor, ...], dict[str, Any]]]
def __init__(self) -> None: def __init__(self) -> None:
self.serialized_name_to_node: Dict[str, torch.fx.Node] = {} self.serialized_name_to_node: dict[str, torch.fx.Node] = {}
self.serialized_name_to_meta: Dict[str, MetaType] = {} self.serialized_name_to_meta: dict[str, MetaType] = {}
self.graph = torch.fx.Graph() self.graph = torch.fx.Graph()
self.module = torch.nn.Module() self.module = torch.nn.Module()
@ -1953,10 +1949,10 @@ class GraphModuleDeserializer(metaclass=Final):
def deserialize( def deserialize(
self, self,
serialized_graph_module: GraphModule, serialized_graph_module: GraphModule,
serialized_state_dict: Union[Dict[str, torch.Tensor], bytes], serialized_state_dict: Union[dict[str, torch.Tensor], bytes],
constants: Union[Dict[str, Any], bytes], constants: Union[dict[str, Any], bytes],
example_inputs: Optional[Union[tuple[tuple[torch.Tensor, ...], Dict[str, Any]], bytes]] = None, example_inputs: Optional[Union[tuple[tuple[torch.Tensor, ...], dict[str, Any]], bytes]] = None,
symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None, symbol_name_to_range: Optional[dict[str, symbolic_shapes.ValueRanges]] = None,
) -> Result: ) -> Result:
global _CURRENT_DESERIALIZER global _CURRENT_DESERIALIZER
assert _CURRENT_DESERIALIZER is None assert _CURRENT_DESERIALIZER is None
@ -1996,7 +1992,7 @@ class GraphModuleDeserializer(metaclass=Final):
"ToFloat": torch.utils._sympy.functions.ToFloat, "ToFloat": torch.utils._sympy.functions.ToFloat,
"Identity": torch.utils._sympy.functions.Identity, "Identity": torch.utils._sympy.functions.Identity,
} }
self.symbol_name_to_symbol: Dict[str, sympy.Symbol] = {} self.symbol_name_to_symbol: dict[str, sympy.Symbol] = {}
self.constants = deserialize_torch_artifact(constants) self.constants = deserialize_torch_artifact(constants)
self.signature = self.deserialize_signature(serialized_graph_module.signature) self.signature = self.deserialize_signature(serialized_graph_module.signature)
@ -2091,7 +2087,7 @@ class GraphModuleDeserializer(metaclass=Final):
kwargs[schema_arg.name] = actual_args[schema_arg.name] kwargs[schema_arg.name] = actual_args[schema_arg.name]
return tuple(args), kwargs return tuple(args), kwargs
def deserialize_hoo_inputs(self, inputs: List[NamedArgument]): def deserialize_hoo_inputs(self, inputs: list[NamedArgument]):
""" """
For deserializing HOO inputs since HOOs do not have a schema. For deserializing HOO inputs since HOOs do not have a schema.
""" """
@ -2232,7 +2228,7 @@ class GraphModuleDeserializer(metaclass=Final):
"torch.ops.higher_order" in serialized_node.target "torch.ops.higher_order" in serialized_node.target
and not getattr(serialized_node, "is_hop_single_tensor_return", True) and not getattr(serialized_node, "is_hop_single_tensor_return", True)
): ):
meta_val: List[Any] = [] meta_val: list[Any] = []
arg = serialized_node.outputs[0].as_tensor arg = serialized_node.outputs[0].as_tensor
deserialized_metadata = self.deserialize_metadata(serialized_node.metadata) deserialized_metadata = self.deserialize_metadata(serialized_node.metadata)
self.generate_getitem(meta_val, fx_node, arg, 0, deserialized_metadata) self.generate_getitem(meta_val, fx_node, arg, 0, deserialized_metadata)
@ -2261,7 +2257,7 @@ class GraphModuleDeserializer(metaclass=Final):
fx_node: torch.fx.Node, fx_node: torch.fx.Node,
arg: Union[TensorArgument, SymIntArgument, SymFloatArgument], arg: Union[TensorArgument, SymIntArgument, SymFloatArgument],
idx: int, idx: int,
deserialized_metadata: Dict[str, Any], deserialized_metadata: dict[str, Any],
): ):
if isinstance(arg, TensorArgument): if isinstance(arg, TensorArgument):
name = arg.name name = arg.name
@ -2290,7 +2286,7 @@ class GraphModuleDeserializer(metaclass=Final):
meta_val, meta_val,
fx_node: torch.fx.Node, fx_node: torch.fx.Node,
args, args,
deserialized_metadata: Dict[str, Any], deserialized_metadata: dict[str, Any],
): ):
for idx, arg in enumerate(args): for idx, arg in enumerate(args):
if isinstance(arg, (TensorArgument, SymIntArgument, SymFloatArgument)): if isinstance(arg, (TensorArgument, SymIntArgument, SymFloatArgument)):
@ -2343,7 +2339,7 @@ class GraphModuleDeserializer(metaclass=Final):
# return value. # return value.
# This performs the inverse mapping of the `serialize_outputs` call in # This performs the inverse mapping of the `serialize_outputs` call in
# serialization, see [NOTE: Multiple outputs] # serialization, see [NOTE: Multiple outputs]
meta_val: List[Any] = [] meta_val: list[Any] = []
if len(serialized_node.outputs) == 1: if len(serialized_node.outputs) == 1:
assert isinstance(serialized_node.outputs[0].value, list) assert isinstance(serialized_node.outputs[0].value, list)
assert isinstance(serialized_node.outputs[0].value[0], TensorArgument) assert isinstance(serialized_node.outputs[0].value[0], TensorArgument)
@ -2355,8 +2351,8 @@ class GraphModuleDeserializer(metaclass=Final):
fx_node.meta["val"] = tuple(meta_val) fx_node.meta["val"] = tuple(meta_val)
self.serialized_name_to_node[fx_node.name] = fx_node self.serialized_name_to_node[fx_node.name] = fx_node
def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]: def deserialize_metadata(self, metadata: dict[str, str]) -> dict[str, Any]:
ret: Dict[str, Any] = {} ret: dict[str, Any] = {}
if stack_trace := metadata.get("stack_trace"): if stack_trace := metadata.get("stack_trace"):
ret["stack_trace"] = stack_trace ret["stack_trace"] = stack_trace
@ -2446,8 +2442,8 @@ class GraphModuleDeserializer(metaclass=Final):
) )
def deserialize_module_call_graph( def deserialize_module_call_graph(
self, module_call_graph: List[ModuleCallEntry] self, module_call_graph: list[ModuleCallEntry]
) -> List[ep.ModuleCallEntry]: ) -> list[ep.ModuleCallEntry]:
return [ return [
ep.ModuleCallEntry( ep.ModuleCallEntry(
fqn=entry.fqn, fqn=entry.fqn,
@ -2463,8 +2459,8 @@ class GraphModuleDeserializer(metaclass=Final):
@final @final
class ExportedProgramDeserializer(metaclass=Final): class ExportedProgramDeserializer(metaclass=Final):
def __init__(self, expected_opset_version: Optional[Dict[str, int]] = None): def __init__(self, expected_opset_version: Optional[dict[str, int]] = None):
self.expected_opset_version: Dict[str, int] = {} self.expected_opset_version: dict[str, int] = {}
if expected_opset_version: if expected_opset_version:
self.expected_opset_version.update(expected_opset_version) self.expected_opset_version.update(expected_opset_version)
if "aten" not in self.expected_opset_version: if "aten" not in self.expected_opset_version:
@ -2472,9 +2468,9 @@ class ExportedProgramDeserializer(metaclass=Final):
def deserialize_range_constraints( def deserialize_range_constraints(
self, self,
symbol_name_to_range: Dict[str, symbolic_shapes.ValueRanges], symbol_name_to_range: dict[str, symbolic_shapes.ValueRanges],
symbol_name_to_symbol: Dict[str, sympy.Symbol], symbol_name_to_symbol: dict[str, sympy.Symbol],
) -> Dict[sympy.Symbol, ValueRanges]: ) -> dict[sympy.Symbol, ValueRanges]:
range_constraints = {} range_constraints = {}
for k, v in symbol_name_to_range.items(): for k, v in symbol_name_to_range.items():
if symbol := symbol_name_to_symbol.get(k): if symbol := symbol_name_to_symbol.get(k):
@ -2486,9 +2482,9 @@ class ExportedProgramDeserializer(metaclass=Final):
def deserialize( def deserialize(
self, self,
exported_program: ExportedProgram, exported_program: ExportedProgram,
state_dict: Union[Dict[str, torch.Tensor], bytes], state_dict: Union[dict[str, torch.Tensor], bytes],
constants: Union[Dict[str, torch.Tensor], bytes], constants: Union[dict[str, torch.Tensor], bytes],
example_inputs: Optional[Union[tuple[tuple[torch.Tensor, ...], Dict[str, Any]], bytes]] = None, example_inputs: Optional[Union[tuple[tuple[torch.Tensor, ...], dict[str, Any]], bytes]] = None,
*, *,
_unsafe_skip_version_check=False, _unsafe_skip_version_check=False,
) -> ep.ExportedProgram: ) -> ep.ExportedProgram:
@ -2566,7 +2562,7 @@ def _dataclass_to_dict(obj):
def serialize( def serialize(
exported_program: ep.ExportedProgram, exported_program: ep.ExportedProgram,
opset_version: Optional[Dict[str, int]] = None, opset_version: Optional[dict[str, int]] = None,
pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
) -> SerializedArtifact: ) -> SerializedArtifact:
with _enable_graph_inputs_of_type_nn_module(exported_program.example_inputs): with _enable_graph_inputs_of_type_nn_module(exported_program.example_inputs):
@ -2627,7 +2623,7 @@ def _dict_to_dataclass(cls, data):
def deserialize( def deserialize(
artifact: SerializedArtifact, artifact: SerializedArtifact,
expected_opset_version: Optional[Dict[str, int]] = None, expected_opset_version: Optional[dict[str, int]] = None,
*, *,
_unsafe_skip_version_check=False, _unsafe_skip_version_check=False,
) -> ep.ExportedProgram: ) -> ep.ExportedProgram:
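These two entry points round-trip an ExportedProgram; a minimal sketch, assuming a trivially exportable module:

    import torch
    from torch._export.serde.serialize import deserialize, serialize

    class M(torch.nn.Module):
        def forward(self, x):
            return x.sin()

    ep = torch.export.export(M(), (torch.randn(3),))
    artifact = serialize(ep)     # SerializedArtifact holding bytes payloads
    ep2 = deserialize(artifact)  # reconstructs an ExportedProgram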
@ -2649,7 +2645,7 @@ def deserialize(
def _canonicalize_graph( def _canonicalize_graph(
sorted_inputs, sorted_outputs, graph sorted_inputs, sorted_outputs, graph
) -> tuple[Graph, Dict[str, str]]: ) -> tuple[Graph, dict[str, str]]:
def _get_argument(a: Argument): def _get_argument(a: Argument):
if a.type == "as_none": if a.type == "as_none":
return None return None
@ -2712,15 +2708,15 @@ def _canonicalize_graph(
def sort_nodes(nodes): def sort_nodes(nodes):
@dataclass @dataclass
class Edges: class Edges:
outs: List[int] outs: list[int]
ins: int ins: int
graph_inputs: Set[str] = set() graph_inputs: set[str] = set()
def_table: Dict[str, int] = {} def_table: dict[str, int] = {}
edges: Dict[int, Edges] = {} edges: dict[int, Edges] = {}
candidates: List[tuple[str, List[tuple[str, List[int]]], int]] = [] candidates: list[tuple[str, list[tuple[str, list[int]]], int]] = []
rank: Dict[str, int] = {} rank: dict[str, int] = {}
ret: List[Node] = [] ret: list[Node] = []
def get_name(a) -> Optional[str]: def get_name(a) -> Optional[str]:
if a is None: if a is None:
@ -2827,7 +2823,7 @@ def _canonicalize_graph(
assert len(sorted_nodes) == len(graph.nodes) assert len(sorted_nodes) == len(graph.nodes)
# Stage 2: Rename nodes. # Stage 2: Rename nodes.
name_table: Dict[str, str] = {} name_table: dict[str, str] = {}
def rename_def(a): def rename_def(a):
def _rename(arg_name, values): def _rename(arg_name, values):
@ -3163,8 +3159,8 @@ class ExtensionHandler:
def register_extension( def register_extension(
op_type: Type[Any], op_type: type[Any],
extension_handler: Type[ExtensionHandler], extension_handler: type[ExtensionHandler],
): ):
"""Register custom de/serialization method for a node with non-standard type.""" """Register custom de/serialization method for a node with non-standard type."""
assert issubclass(extension_handler, ExtensionHandler), f"Expected ExtensionHandler, got {extension_handler}." assert issubclass(extension_handler, ExtensionHandler), f"Expected ExtensionHandler, got {extension_handler}."
@ -3187,5 +3183,5 @@ def _registered_extension_types():
# namespace to avoid conflicts. # namespace to avoid conflicts.
# Serialization: Op type --> custom handler. # Serialization: Op type --> custom handler.
# De-serialization: Namespace --> custom handler. # De-serialization: Namespace --> custom handler.
_serialization_registry: Dict[Type[Any], Type[ExtensionHandler]] = {} _serialization_registry: dict[type[Any], type[ExtensionHandler]] = {}
_deserialization_registry: Dict[str, Type[ExtensionHandler]] = {} _deserialization_registry: dict[str, type[ExtensionHandler]] = {}

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
import functools import functools
from collections.abc import Hashable
from dataclasses import fields from dataclasses import fields
from typing import Hashable, Set
class _UnionTag(str): class _UnionTag(str):
@ -26,8 +26,8 @@ class _UnionTag(str):
return hash(str(self)) return hash(str(self))
@functools.lru_cache(maxsize=None) @functools.cache
def _get_field_names(cls) -> Set[str]: def _get_field_names(cls) -> set[str]:
return {f.name for f in fields(cls)} return {f.name for f in fields(cls)}

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
import logging import logging
import warnings import warnings
from typing import Any, Dict, Iterable, Optional from collections.abc import Iterable
from typing import Any, Optional
import torch import torch
import torch.export import torch.export
@ -18,8 +19,8 @@ def _generate_inputs_for_submodules(
model: torch.nn.Module, model: torch.nn.Module,
target_submodules: Iterable[str], target_submodules: Iterable[str],
args: tuple[Any, ...], args: tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None, kwargs: Optional[dict[str, Any]] = None,
) -> Dict[str, tuple[Any, Any]]: ) -> dict[str, tuple[Any, Any]]:
""" """
Generate inputs for the target submodules in the given model. Note that if two submodules refer to the same obj, this Generate inputs for the target submodules in the given model. Note that if two submodules refer to the same obj, this
function doesn't work. function doesn't work.
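A hedged usage sketch, assuming this helper lives at torch._export.tools as in the PyTorch tree; captured inputs come back keyed by submodule FQN as (args, kwargs) pairs:

    import torch
    from torch._export.tools import _generate_inputs_for_submodules

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.lin(x).relu()

    captured = _generate_inputs_for_submodules(M(), ["lin"], (torch.randn(2, 4),))
    lin_args, lin_kwargs = captured["lin"]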
@ -61,11 +62,11 @@ def _generate_inputs_for_submodules(
def report_exportability( def report_exportability(
mod: torch.nn.Module, mod: torch.nn.Module,
args: tuple[Any, ...], args: tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None, kwargs: Optional[dict[str, Any]] = None,
*, *,
strict: bool = True, strict: bool = True,
pre_dispatch: bool = False, pre_dispatch: bool = False,
) -> Dict[str, Optional[Exception]]: ) -> dict[str, Optional[Exception]]:
""" """
Report exportability issues for a module in one shot. Report exportability issues for a module in one shot.
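A companion sketch for the reporting entry point (same import-path assumption); a None value means that submodule exported cleanly:

    import torch
    from torch._export.tools import report_exportability

    mod = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    report = report_exportability(mod, (torch.randn(2, 4),))
    for fqn, exc in report.items():
        print(fqn or "<root>", "ok" if exc is None else exc)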
@ -92,7 +93,7 @@ def report_exportability(
submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs) submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs)
tried_module_types = set() tried_module_types = set()
report: Dict[str, Optional[Exception]] = {} report: dict[str, Optional[Exception]] = {}
def try_export(module, module_name, args, kwargs): def try_export(module, module_name, args, kwargs):
nonlocal submod_inputs, report, strict, pre_dispatch, tried_module_types nonlocal submod_inputs, report, strict, pre_dispatch, tried_module_types

View File

@ -8,21 +8,10 @@ import json
import math import math
import operator import operator
import re import re
from collections.abc import Iterable
from contextlib import contextmanager from contextlib import contextmanager
from inspect import Parameter from inspect import Parameter
from typing import ( from typing import Any, Callable, Optional, TYPE_CHECKING, Union
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
import torch import torch
from torch._guards import detect_fake_mode from torch._guards import detect_fake_mode
@ -116,7 +105,7 @@ def _overwrite_signature_for_non_persistent_buffers(
return new_sig return new_sig
def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> Dict[str, Any]: def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> dict[str, Any]:
""" """
Param/buffer metadata needs to be saved before lowering to aten IR Param/buffer metadata needs to be saved before lowering to aten IR
because aten IR lifts them; as a result, automatic preservation doesn't work. because aten IR lifts them; as a result, automatic preservation doesn't work.
@ -174,7 +163,7 @@ def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> Dict[str, Any]:
def _populate_param_buffer_metadata_to_new_gm( def _populate_param_buffer_metadata_to_new_gm(
params_buffers_to_node_meta: Dict[str, Any], params_buffers_to_node_meta: dict[str, Any],
gm: torch.fx.GraphModule, gm: torch.fx.GraphModule,
new_sig: "ExportGraphSignature", new_sig: "ExportGraphSignature",
) -> None: ) -> None:
@ -217,7 +206,7 @@ def _get_shape_env_from_gm(gm: torch.fx.GraphModule):
def _rename_without_collisions( def _rename_without_collisions(
name_map: Dict[str, str], name_map: dict[str, str],
orig_name: str, orig_name: str,
name: str, name: str,
is_placeholder: bool = False, is_placeholder: bool = False,
@ -246,7 +235,7 @@ def _rename_without_collisions(
def _check_input_constraints_for_graph( def _check_input_constraints_for_graph(
input_placeholders: List[torch.fx.Node], flat_args_with_path, range_constraints input_placeholders: list[torch.fx.Node], flat_args_with_path, range_constraints
) -> None: ) -> None:
def get_keystr(key_path: KeyPath) -> str: def get_keystr(key_path: KeyPath) -> str:
"""For a given index into the flat_args, return a human readable string """For a given index into the flat_args, return a human readable string
@ -280,7 +269,7 @@ def _check_input_constraints_for_graph(
# NOTE: export already guarantees that the same symbol is used in metadata # NOTE: export already guarantees that the same symbol is used in metadata
# for all InputDims related by equality constraints, so we can just unify # for all InputDims related by equality constraints, so we can just unify
# symbols with given input dimension values to check equality constraints. # symbols with given input dimension values to check equality constraints.
unification_map: Dict[sympy.Symbol, Any] = {} unification_map: dict[sympy.Symbol, Any] = {}
for (key_path, arg), node in zip(flat_args_with_path, input_placeholders): for (key_path, arg), node in zip(flat_args_with_path, input_placeholders):
node_val = node.meta.get("val") node_val = node.meta.get("val")
if isinstance(node_val, FakeTensor): if isinstance(node_val, FakeTensor):
@ -372,7 +361,7 @@ def _check_input_constraints_for_graph(
def register_dataclass_as_pytree_node( def register_dataclass_as_pytree_node(
cls: Type[Any], cls: type[Any],
flatten_fn: Optional[FlattenFunc] = None, flatten_fn: Optional[FlattenFunc] = None,
unflatten_fn: Optional[UnflattenFunc] = None, unflatten_fn: Optional[UnflattenFunc] = None,
*, *,
@ -385,7 +374,7 @@ def register_dataclass_as_pytree_node(
cls cls
), f"Only dataclasses can be registered with this function: {cls}" ), f"Only dataclasses can be registered with this function: {cls}"
def default_flatten_fn(obj: Any) -> tuple[List[Any], Context]: def default_flatten_fn(obj: Any) -> tuple[list[Any], Context]:
flattened = [] flattened = []
flat_names = [] flat_names = []
none_names = [] none_names = []
@ -402,7 +391,7 @@ def register_dataclass_as_pytree_node(
flat_names, none_names = context flat_names, none_names = context
return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names)) return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))
def default_flatten_fn_with_keys(obj: Any) -> tuple[List[Any], Context]: def default_flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]:
flattened, (flat_names, _none_names) = flatten_fn(obj) # type: ignore[misc] flattened, (flat_names, _none_names) = flatten_fn(obj) # type: ignore[misc]
return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names
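A hedged sketch of registering a dataclass through this helper, assuming the import path torch._export.utils and the serialized_type_name keyword as in the PyTorch tree; the Pair type is made up:

    import dataclasses

    import torch
    import torch.utils._pytree as pytree
    from torch._export.utils import register_dataclass_as_pytree_node

    @dataclasses.dataclass
    class Pair:  # hypothetical example type
        a: torch.Tensor
        b: torch.Tensor

    register_dataclass_as_pytree_node(Pair, serialized_type_name="test.Pair")
    leaves, spec = pytree.tree_flatten(Pair(torch.ones(1), torch.zeros(1)))
    assert len(leaves) == 2  # both tensor fields become pytree leaves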
@ -537,7 +526,7 @@ def sequential_split(
return new_gm return new_gm
def nodes_filter(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]: def nodes_filter(nodes: list[torch.fx.Node], node_call_back) -> list[torch.fx.Node]:
"""Returns the nodes that match the node_call_back as a list.""" """Returns the nodes that match the node_call_back as a list."""
return [node for node in nodes if node_call_back(node)] return [node for node in nodes if node_call_back(node)]
@ -572,7 +561,7 @@ def apply_runtime_assertion_pass(gm: torch.fx.GraphModule, graph_signature):
def nodes_first( def nodes_first(
nodes: List[torch.fx.Node], node_call_back=None nodes: list[torch.fx.Node], node_call_back=None
) -> Optional[torch.fx.Node]: ) -> Optional[torch.fx.Node]:
""" """
Returns the first node that matches the node_call_back. If no node matches, returns None. Returns the first node that matches the node_call_back. If no node matches, returns None.
@ -584,12 +573,12 @@ def nodes_first(
return None return None
def nodes_count(nodes: List[torch.fx.Node], node_call_back) -> int: def nodes_count(nodes: list[torch.fx.Node], node_call_back) -> int:
"""Returns the number of nodes that match the node_call_back.""" """Returns the number of nodes that match the node_call_back."""
return len(nodes_filter(nodes, node_call_back)) return len(nodes_filter(nodes, node_call_back))
def nodes_map(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]: def nodes_map(nodes: list[torch.fx.Node], node_call_back) -> list[torch.fx.Node]:
""" """
Sequentially visit the nodes list and invoke node_call_back on each element. Sequentially visit the nodes list and invoke node_call_back on each element.
Returns the nodes list after the node_call_back is invoked on each element. Returns the nodes list after the node_call_back is invoked on each element.
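A small sketch exercising these node helpers on a symbolically traced module, assuming the import path torch._export.utils:

    import torch
    from torch._export.utils import nodes_count, nodes_filter, nodes_first

    gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 4))
    nodes = list(gm.graph.nodes)
    placeholders = nodes_filter(nodes, lambda n: n.op == "placeholder")
    first_call = nodes_first(nodes, lambda n: n.op == "call_function")
    assert nodes_count(nodes, lambda n: n.op == "output") == 1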
@ -748,7 +737,7 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None:
and gather the top-level named placeholder nodes. and gather the top-level named placeholder nodes.
""" """
# gather all HOO subgraphs and their top-level named placeholder nodes # gather all HOO subgraphs and their top-level named placeholder nodes
subgraph_ph_tuples: List[tuple[torch.fx.GraphModule, List[torch.fx.Node]]] = [] subgraph_ph_tuples: list[tuple[torch.fx.GraphModule, list[torch.fx.Node]]] = []
for node in gm.graph.nodes: for node in gm.graph.nodes:
if node.op == "call_function" and isinstance( if node.op == "call_function" and isinstance(
node.target, torch._ops.HigherOrderOperator node.target, torch._ops.HigherOrderOperator
@ -769,7 +758,7 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None:
# propagate names # propagate names
for subgraph, hoo_phs in subgraph_ph_tuples: for subgraph, hoo_phs in subgraph_ph_tuples:
name_map: Dict[str, str] = {} name_map: dict[str, str] = {}
for i, node in enumerate(subgraph.graph.nodes): for i, node in enumerate(subgraph.graph.nodes):
if i < len(hoo_phs): # placeholder, retain name if i < len(hoo_phs): # placeholder, retain name
name_map[node.name] = hoo_phs[i].name name_map[node.name] = hoo_phs[i].name
@ -789,7 +778,7 @@ def placeholder_naming_pass(
fake_args, fake_args,
fake_kwargs, fake_kwargs,
fake_params_buffers, fake_params_buffers,
constants: Dict[str, Any], constants: dict[str, Any],
) -> None: ) -> None:
""" """
This pass is run at the end of _export_non_strict() to assign better placeholder node names: This pass is run at the end of _export_non_strict() to assign better placeholder node names:
@ -828,7 +817,7 @@ def placeholder_naming_pass(
else: else:
raise RuntimeError(f"Pytree key of type {type(x)} not handled for {x}") raise RuntimeError(f"Pytree key of type {type(x)} not handled for {x}")
name_map: Dict[str, str] = {} name_map: dict[str, str] = {}
# map user input names with mod.forward() signature # map user input names with mod.forward() signature
combined_args = _bind_signature_to_inputs(mod, fake_args, fake_kwargs) combined_args = _bind_signature_to_inputs(mod, fake_args, fake_kwargs)
@ -927,7 +916,7 @@ def placeholder_naming_pass(
del constants[name] del constants[name]
def remove_proxy_from_state_dict(state_dict: Dict, in_place: bool) -> Dict: def remove_proxy_from_state_dict(state_dict: dict, in_place: bool) -> dict:
""" """
If `in_place` is false, return a new copy of `state_dict` with "proxy" removed from `v.__dict__`. If `in_place` is false, return a new copy of `state_dict` with "proxy" removed from `v.__dict__`.
`v` denotes each value in the dictionary. `v` denotes each value in the dictionary.
@ -957,8 +946,8 @@ def _detect_fake_mode_from_gm(
If no fake mode is found, we return None for fake_mode. If no fake mode is found, we return None for fake_mode.
""" """
fake_inps: List[torch.Tensor] = [] fake_inps: list[torch.Tensor] = []
fake_vals: List[torch.Tensor] = [] fake_vals: list[torch.Tensor] = []
for node in gm.graph.nodes: for node in gm.graph.nodes:
if node.op == "placeholder" and "val" in node.meta: if node.op == "placeholder" and "val" in node.meta:
fake_val = node.meta["val"] fake_val = node.meta["val"]
@ -980,8 +969,8 @@ def _detect_fake_mode_from_gm(
@contextmanager @contextmanager
def _disable_load_state_dict_hooks(mod: torch.nn.Module): def _disable_load_state_dict_hooks(mod: torch.nn.Module):
state_dict_hooks: Dict[int, Callable] = dict(mod._state_dict_hooks) state_dict_hooks: dict[int, Callable] = dict(mod._state_dict_hooks)
state_dict_pre_hooks: Dict[int, Callable] = dict(mod._state_dict_pre_hooks) state_dict_pre_hooks: dict[int, Callable] = dict(mod._state_dict_pre_hooks)
mod._state_dict_hooks.clear() mod._state_dict_hooks.clear()
mod._state_dict_pre_hooks.clear() mod._state_dict_pre_hooks.clear()
try: try:
@ -1075,11 +1064,11 @@ def _check_valid_to_preserve(op_overload: "OperatorBase"):
@functools.lru_cache(maxsize=1) @functools.lru_cache(maxsize=1)
def _collect_all_valid_cia_ops_for_aten_namespace() -> Set["OperatorBase"]: def _collect_all_valid_cia_ops_for_aten_namespace() -> set["OperatorBase"]:
return _collect_all_valid_cia_ops_for_namespace("aten") return _collect_all_valid_cia_ops_for_namespace("aten")
def _collect_all_valid_cia_ops_for_namespace(namespace: str) -> Set["OperatorBase"]: def _collect_all_valid_cia_ops_for_namespace(namespace: str) -> set["OperatorBase"]:
# Step 1: Materialize all ops from C++ dispatcher # Step 1: Materialize all ops from C++ dispatcher
_materialize_cpp_cia_ops() _materialize_cpp_cia_ops()
@ -1096,7 +1085,7 @@ def _collect_all_valid_cia_ops_for_namespace(namespace: str) -> Set["OperatorBas
return cia_ops return cia_ops
def _collect_all_valid_cia_ops() -> Set["OperatorBase"]: def _collect_all_valid_cia_ops() -> set["OperatorBase"]:
""" """
This is a utility function that gets all the CIA functional ops. This is a utility function that gets all the CIA functional ops.
@ -1166,14 +1155,14 @@ def _compiling_state_context():
def _fakify_params_buffers( def _fakify_params_buffers(
fake_mode: FakeTensorMode, fake_mode: FakeTensorMode,
mod: torch.nn.Module, mod: torch.nn.Module,
) -> Dict[str, Union[torch.Tensor, torch.nn.Parameter]]: ) -> dict[str, Union[torch.Tensor, torch.nn.Parameter]]:
params_buffers = { params_buffers = {
**dict(mod.named_parameters(remove_duplicate=False)), **dict(mod.named_parameters(remove_duplicate=False)),
**dict(mod.named_buffers(remove_duplicate=False)), **dict(mod.named_buffers(remove_duplicate=False)),
} }
faked_params_buffers = {} faked_params_buffers = {}
memo: Dict[int, FakeTensor] = {} memo: dict[int, FakeTensor] = {}
for key, value in params_buffers.items(): for key, value in params_buffers.items():
if id(value) in memo: if id(value) in memo:
fake_tensor = memo[id(value)] fake_tensor = memo[id(value)]
@ -1184,7 +1173,7 @@ def _fakify_params_buffers(
return faked_params_buffers # type: ignore[return-value] return faked_params_buffers # type: ignore[return-value]
def register_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None: def register_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None:
""" """
Registers a module as a valid input type for :func:`torch.export.export`. Registers a module as a valid input type for :func:`torch.export.export`.
@ -1233,7 +1222,7 @@ def register_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
return PrototypeModule(self()) return PrototypeModule(self())
def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]: def default_flatten_fn(obj: Any) -> tuple[list[Any], Context]:
named_parameters = dict(obj.named_parameters()) named_parameters = dict(obj.named_parameters())
named_buffers = dict(obj.named_buffers()) named_buffers = dict(obj.named_buffers())
params_buffers = {**named_parameters, **named_buffers} params_buffers = {**named_parameters, **named_buffers}
@ -1270,7 +1259,7 @@ def register_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
ret = obj ret = obj
return ret return ret
def default_flatten_fn_with_keys(obj: Any) -> Tuple[List[Any], Context]: def default_flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]:
flattened, [flat_names, *args] = flatten_fn(obj) # type: ignore[misc] flattened, [flat_names, *args] = flatten_fn(obj) # type: ignore[misc]
return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], [ return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], [
flat_names, flat_names,
@ -1301,7 +1290,7 @@ def register_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
from_dumpable_context=from_dumpable_context, from_dumpable_context=from_dumpable_context,
) )
def default_flatten_fn_spec(obj, spec) -> List[Any]: def default_flatten_fn_spec(obj, spec) -> list[Any]:
flats, context = flatten_fn(obj) flats, context = flatten_fn(obj)
assert context == spec.context assert context == spec.context
return flats return flats
@ -1312,6 +1301,6 @@ def register_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
) )
def deregister_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None: def deregister_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None:
_deregister_pytree_node(cls) _deregister_pytree_node(cls)
_deregister_pytree_flatten_spec(cls) _deregister_pytree_flatten_spec(cls)
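A hedged sketch of the register/deregister pair above, assuming the import path torch._export.utils; under the default flatten function, a registered module's parameters and buffers become pytree leaves:

    import torch
    import torch.utils._pytree as pytree
    from torch._export.utils import (
        deregister_module_as_pytree_input_node,
        register_module_as_pytree_input_node,
    )

    class Inner(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(2, 2)

        def forward(self, x):
            return self.lin(x)

    register_module_as_pytree_input_node(Inner)
    leaves, spec = pytree.tree_flatten(Inner())  # weight and bias as leaves
    deregister_module_as_pytree_input_node(Inner)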

View File

@ -4,7 +4,7 @@ import inspect
import math import math
import operator import operator
from collections.abc import Iterable from collections.abc import Iterable
from typing import Any, Dict, final, List, Type, TYPE_CHECKING from typing import Any, final, TYPE_CHECKING
import torch import torch
from torch._ops import HigherOrderOperator, OpOverload from torch._ops import HigherOrderOperator, OpOverload
@ -83,7 +83,7 @@ def _check_torch_fn(node: torch.fx.Node) -> None:
raise SpecViolationError(f"Node.meta {node.name} has invalid torch_fn field {torch_fn}") raise SpecViolationError(f"Node.meta {node.name} has invalid torch_fn field {torch_fn}")
class _VerifierMeta(type): class _VerifierMeta(type):
_registry: Dict[str, Type['Verifier']] = {} _registry: dict[str, type['Verifier']] = {}
def __new__(metacls, name, bases, attrs): def __new__(metacls, name, bases, attrs):
if bases: if bases:
@ -113,7 +113,7 @@ def getattr_recursive(obj: Any, target: str) -> Any:
class Verifier(metaclass=_VerifierMeta): class Verifier(metaclass=_VerifierMeta):
dialect = "ATEN" dialect = "ATEN"
def allowed_builtin_ops(self) -> List: def allowed_builtin_ops(self) -> list:
return [ return [
operator.getitem, operator.getitem,
operator.add, operator.add,
@ -141,10 +141,10 @@ class Verifier(metaclass=_VerifierMeta):
builtins.getattr, builtins.getattr,
] ]
def allowed_op_types(self) -> tuple[Type[Any], ...]: def allowed_op_types(self) -> tuple[type[Any], ...]:
return (OpOverload, HigherOrderOperator) return (OpOverload, HigherOrderOperator)
def allowed_getattr_types(self) -> tuple[Type[Any], ...]: def allowed_getattr_types(self) -> tuple[type[Any], ...]:
return (torch.fx.GraphModule,) return (torch.fx.GraphModule,)
def check_valid_op(self, op): def check_valid_op(self, op):
@ -163,18 +163,18 @@ class Verifier(metaclass=_VerifierMeta):
@final @final
def _check_graph_module(self, gm: torch.fx.GraphModule) -> None: def _check_graph_module(self, gm: torch.fx.GraphModule) -> None:
def _allowed_getattr_types() -> tuple[Type[Any], ...]: def _allowed_getattr_types() -> tuple[type[Any], ...]:
ret = self.allowed_getattr_types() ret = self.allowed_getattr_types()
assert not any(t is object for t in ret) assert not any(t is object for t in ret)
return ret return ret
def _check_valid_op(op) -> None: def _check_valid_op(op) -> None:
def _allowed_builtin_ops() -> List: def _allowed_builtin_ops() -> list:
ret = self.allowed_builtin_ops() ret = self.allowed_builtin_ops()
assert all(inspect.isbuiltin(op) for op in ret) assert all(inspect.isbuiltin(op) for op in ret)
return ret return ret
def _allowed_op_types() -> tuple[Type[Any], ...]: def _allowed_op_types() -> tuple[type[Any], ...]:
ret = self.allowed_op_types() ret = self.allowed_op_types()
assert not any(t is object for t in ret) assert not any(t is object for t in ret)
return ret return ret
@ -426,7 +426,7 @@ def _verify_exported_program_signature(exported_program) -> None:
num_tokens = len(gs.output_tokens) num_tokens = len(gs.output_tokens)
end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens
mutate_nodes: List[str] = output_nodes[num_tokens:end] mutate_nodes: list[str] = output_nodes[num_tokens:end]
user_output_nodes = output_nodes[end:end + len(gs.user_outputs)] user_output_nodes = output_nodes[end:end + len(gs.user_outputs)]
for mutation_node in mutate_nodes: for mutation_node in mutate_nodes:
@ -458,7 +458,7 @@ def _verify_exported_program_signature(exported_program) -> None:
) )
def load_verifier(dialect: str) -> Type[Verifier]: def load_verifier(dialect: str) -> type[Verifier]:
if dialect == "ATEN" or dialect == "": if dialect == "ATEN" or dialect == "":
return _VerifierMeta._registry.get(dialect, Verifier) return _VerifierMeta._registry.get(dialect, Verifier)
return _VerifierMeta._registry[dialect] return _VerifierMeta._registry[dialect]
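A hedged sketch of the registry behavior above: defining a Verifier subclass with a new dialect registers it, and load_verifier retrieves it (the dialect name here is invented for illustration):

    from torch._export.verifier import Verifier, load_verifier

    class MyDialectVerifier(Verifier):  # hypothetical non-ATEN dialect
        dialect = "MY_DIALECT"

    assert load_verifier("MY_DIALECT") is MyDialectVerifier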