mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
PEP585 update - torch/_export (#145138)
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145138 Approved by: https://github.com/bobrenjc93 ghstack dependencies: #145154
This commit is contained in:
committed by
PyTorch MergeBot
parent
cd8d0fa20c
commit
97d4d3c40a
@ -76,14 +76,14 @@ def aot_compile_warning():
|
||||
def aot_compile(
|
||||
f: Callable,
|
||||
args: tuple[Any],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
*,
|
||||
dynamic_shapes: Optional[Dict[str, Any]] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
dynamic_shapes: Optional[dict[str, Any]] = None,
|
||||
options: Optional[dict[str, Any]] = None,
|
||||
remove_runtime_assertions: bool = False,
|
||||
disable_constraint_solver: bool = False,
|
||||
same_signature: bool = True,
|
||||
) -> Union[List[str], str]:
|
||||
) -> Union[list[str], str]:
|
||||
"""
|
||||
Note: this function is not stable yet
|
||||
|
||||
|
@ -4,8 +4,9 @@ import logging
|
||||
import operator
|
||||
import typing
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.export._trace
|
||||
@ -72,7 +73,7 @@ def _trace_and_get_graph_from_model(model, args):
|
||||
|
||||
def _create_jit_graph(
|
||||
model: Union[torch.nn.Module, torch.jit.ScriptFunction], args: Sequence[Any]
|
||||
) -> tuple[torch.Graph, List["_C.IValue"], Any, Optional[torch.ScriptModule]]:
|
||||
) -> tuple[torch.Graph, list["_C.IValue"], Any, Optional[torch.ScriptModule]]:
|
||||
if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)):
|
||||
flattened_args = tuple(torch.jit._flatten(tuple(args))[0])
|
||||
torch_out = None
|
||||
@ -263,7 +264,7 @@ def construct_fqn(ir, ref_map, name_map):
|
||||
|
||||
def get_block_to_lifted_attrs(
|
||||
graph: torch._C.Graph,
|
||||
) -> tuple[Dict[torch._C.Block, Set[str]], Dict[str, str]]:
|
||||
) -> tuple[dict[torch._C.Block, set[str]], dict[str, str]]:
|
||||
"""
|
||||
Perform two passes to get a mapping of blocks to a set of FQNs of its lifted attributes.
|
||||
When a graph has control flow, the graph will be divided into multiple blocks. We want to convert
|
||||
@ -280,19 +281,19 @@ def get_block_to_lifted_attrs(
|
||||
"""
|
||||
|
||||
# A map from a block to the arguments that are expected to be lifted for it.
|
||||
blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]] = {}
|
||||
blocks_to_lifted_attrs: dict[torch._C.Block, set[str]] = {}
|
||||
|
||||
# Reference map stores the input (i.e., src) and output (i.e., dest) IR of a
|
||||
# GetAttr node. By traversing this reference map, we can figure out the
|
||||
# full IR aliasing pass and figure out the FQN of an attribute.
|
||||
# E.g., %2 = GetAttr(linear)[%1] --> node_to_parent_map["%2"] = "%1"
|
||||
node_to_parent_map: Dict[str, str] = {}
|
||||
node_to_parent_map: dict[str, str] = {}
|
||||
|
||||
# Used for reconstructing the FQN of an attribute based on the reference map.
|
||||
# In a nutshell, for each GetAttr call, GetAttr(input IR, attribute name) -> output IR
|
||||
# This name map stores which attribute name is called for a src IR --> dest IR action.
|
||||
# E.g., %2 = GetAttr(linear)[%1] --> node_to_attr_name["%2"] = "linear"
|
||||
node_to_attr_name: Dict[str, str] = {}
|
||||
node_to_attr_name: dict[str, str] = {}
|
||||
|
||||
def _dfs_get_attr_dependency(entry):
|
||||
"""
|
||||
@ -315,7 +316,7 @@ def get_block_to_lifted_attrs(
|
||||
Walk the graph in a bottom-up fashion to build the expected to be
|
||||
lifted arguments for each block.
|
||||
"""
|
||||
arguments: Set[str] = set()
|
||||
arguments: set[str] = set()
|
||||
for node in entry.nodes():
|
||||
for block in node.blocks():
|
||||
# Recursively build.
|
||||
@ -342,7 +343,7 @@ def get_block_to_lifted_attrs(
|
||||
|
||||
|
||||
def get_attribute_fqn_from_ts_node(
|
||||
name_to_attribute_fqn: Dict[str, str], node: torch._C.Node
|
||||
name_to_attribute_fqn: dict[str, str], node: torch._C.Node
|
||||
) -> str:
|
||||
def get_attr(name: str):
|
||||
if name in name_to_attribute_fqn:
|
||||
@ -392,12 +393,12 @@ class TS2FXGraphConverter:
|
||||
def __init__(
|
||||
self,
|
||||
ts_graph: Union[torch._C.Graph, torch._C.Block],
|
||||
name_to_param: Dict[str, torch.Tensor],
|
||||
name_to_buffer: Dict[str, torch.Tensor],
|
||||
blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]],
|
||||
name_to_non_tensor_attribute: Dict[str, Any],
|
||||
name_to_constant: Dict[str, Any],
|
||||
name_to_attribute_fqn: Dict[str, str],
|
||||
name_to_param: dict[str, torch.Tensor],
|
||||
name_to_buffer: dict[str, torch.Tensor],
|
||||
blocks_to_lifted_attrs: dict[torch._C.Block, set[str]],
|
||||
name_to_non_tensor_attribute: dict[str, Any],
|
||||
name_to_constant: dict[str, Any],
|
||||
name_to_attribute_fqn: dict[str, str],
|
||||
):
|
||||
self.ts_graph = ts_graph
|
||||
# Mapping of parameter FQN to actual parameter value
|
||||
@ -406,19 +407,19 @@ class TS2FXGraphConverter:
|
||||
self.name_to_buffer = name_to_buffer
|
||||
|
||||
self.fx_graph: torch.fx.Graph = torch.fx.Graph()
|
||||
self.input_specs: List[InputSpec] = []
|
||||
self.output_specs: List[OutputSpec] = []
|
||||
self.input_specs: list[InputSpec] = []
|
||||
self.output_specs: list[OutputSpec] = []
|
||||
|
||||
# Mapping of TS node name to converted FX node
|
||||
self.name_to_node: Dict[
|
||||
str, Union[torch.fx.Node, List[torch.fx.Node], Dict[Any, torch.fx.Node]]
|
||||
self.name_to_node: dict[
|
||||
str, Union[torch.fx.Node, list[torch.fx.Node], dict[Any, torch.fx.Node]]
|
||||
] = {}
|
||||
# Mapping of TS node name to constant value (int, str, TorchBind obj,
|
||||
# tensor constants ...)
|
||||
self.name_to_constant: Dict[str, Any] = name_to_constant
|
||||
self.name_to_constant: dict[str, Any] = name_to_constant
|
||||
|
||||
# Mapping from torchscript node output name to attribute fully qualified name
|
||||
self.name_to_attribute_fqn: Dict[str, str] = name_to_attribute_fqn
|
||||
self.name_to_attribute_fqn: dict[str, str] = name_to_attribute_fqn
|
||||
|
||||
# Mapping from fully qualified name to real values or a fx graph node
|
||||
# During convert, this represents the current value of a non-tensor attribute
|
||||
@ -428,14 +429,14 @@ class TS2FXGraphConverter:
|
||||
# self.count += 1
|
||||
# c2 = self.count
|
||||
# return x + c1 + c2
|
||||
self.name_to_non_tensor_attribute_node: Dict[str, Any] = {}
|
||||
self.name_to_non_tensor_attribute_node: dict[str, Any] = {}
|
||||
|
||||
# Mapping from fully qualified name to initial real values inputs
|
||||
# We separate it from self.name_to_non_tensor_attribute_node since
|
||||
# we need initial real value input when we construct fx.GraphModule
|
||||
self.name_to_non_tensor_attribute: Dict[str, Any] = name_to_non_tensor_attribute
|
||||
self.name_to_non_tensor_attribute: dict[str, Any] = name_to_non_tensor_attribute
|
||||
|
||||
self.subgraphs: Dict[str, torch.fx.GraphModule] = {}
|
||||
self.subgraphs: dict[str, torch.fx.GraphModule] = {}
|
||||
|
||||
# Mapping of block to list of attributes that need to be lifted for each
|
||||
# block
|
||||
@ -457,7 +458,7 @@ class TS2FXGraphConverter:
|
||||
# might have inplace updates to the variable defined in the parent fx graph. After
|
||||
# the execution of that sub-block, the variable defined in the parent fx graph also
|
||||
# needs to be updated.
|
||||
self.name_update_from_subblock_to_parent: Set[str] = set()
|
||||
self.name_update_from_subblock_to_parent: set[str] = set()
|
||||
|
||||
def _is_get_attr_node(self, fqn):
|
||||
return (
|
||||
@ -469,7 +470,7 @@ class TS2FXGraphConverter:
|
||||
)
|
||||
)
|
||||
|
||||
def _convert_block_to_subgraph(self, node: torch._C.Node, arguments: List[str]):
|
||||
def _convert_block_to_subgraph(self, node: torch._C.Node, arguments: list[str]):
|
||||
subgraph_nodes, subgraph_converters = [], []
|
||||
for block in node.blocks():
|
||||
subgraph_converter = TS2FXGraphConverter(
|
||||
@ -506,7 +507,7 @@ class TS2FXGraphConverter:
|
||||
Block[x.1]
|
||||
%2 = x.1 ...
|
||||
"""
|
||||
arguments: Set[str] = set()
|
||||
arguments: set[str] = set()
|
||||
for block in entry.blocks():
|
||||
for block_node in block.nodes():
|
||||
for block_node_in in block_node.inputs():
|
||||
@ -1332,12 +1333,12 @@ class ExplainTS2FXGraphConverter(TS2FXGraphConverter):
|
||||
def __init__(
|
||||
self,
|
||||
ts_graph: Union[torch._C.Graph, torch._C.Block],
|
||||
name_to_param: Dict[str, torch.Tensor],
|
||||
name_to_buffer: Dict[str, torch.Tensor],
|
||||
blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]],
|
||||
name_to_non_tensor_attribute: Dict[str, Any],
|
||||
name_to_constant: Dict[str, Any],
|
||||
name_to_attribute_fqn: Dict[str, str],
|
||||
name_to_param: dict[str, torch.Tensor],
|
||||
name_to_buffer: dict[str, torch.Tensor],
|
||||
blocks_to_lifted_attrs: dict[torch._C.Block, set[str]],
|
||||
name_to_non_tensor_attribute: dict[str, Any],
|
||||
name_to_constant: dict[str, Any],
|
||||
name_to_attribute_fqn: dict[str, str],
|
||||
):
|
||||
super().__init__(
|
||||
ts_graph,
|
||||
@ -1350,7 +1351,7 @@ class ExplainTS2FXGraphConverter(TS2FXGraphConverter):
|
||||
)
|
||||
|
||||
# Data to keep track of unsupported nodes.
|
||||
self.unsupported_node_list: List[torch._C.Node] = []
|
||||
self.unsupported_node_list: list[torch._C.Node] = []
|
||||
|
||||
# Add mock to needed attributes.
|
||||
self.name_to_node = ExplainTS2FXGraphConverter._DictMock(
|
||||
@ -1395,7 +1396,7 @@ class TS2EPConverter:
|
||||
self,
|
||||
ts_model: Union[torch.jit.ScriptModule, torch.jit.ScriptFunction],
|
||||
sample_args: tuple[Any, ...],
|
||||
sample_kwargs: Optional[Dict[str, Any]] = None,
|
||||
sample_kwargs: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
self.ts_model = ts_model
|
||||
self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args)
|
||||
@ -1403,8 +1404,8 @@ class TS2EPConverter:
|
||||
self.sample_args = sample_args
|
||||
self.sample_kwargs = sample_kwargs
|
||||
|
||||
self.name_to_param: Dict[str, torch.Tensor] = {}
|
||||
self.name_to_buffer: Dict[str, torch.Tensor] = {}
|
||||
self.name_to_param: dict[str, torch.Tensor] = {}
|
||||
self.name_to_buffer: dict[str, torch.Tensor] = {}
|
||||
param_list = (
|
||||
list(self.ts_model.parameters())
|
||||
if not isinstance(self.ts_model, torch._C.ScriptFunction)
|
||||
@ -1422,8 +1423,8 @@ class TS2EPConverter:
|
||||
else:
|
||||
self.name_to_buffer[k] = tensor
|
||||
|
||||
self.name_to_non_tensor_attributes: Dict[str, Any] = {}
|
||||
self.name_to_constant: Dict[str, Any] = {}
|
||||
self.name_to_non_tensor_attributes: dict[str, Any] = {}
|
||||
self.name_to_constant: dict[str, Any] = {}
|
||||
|
||||
self.lift_get_attr()
|
||||
|
||||
@ -1509,7 +1510,7 @@ DEBUG: (TORCH_LOGS="+export" <cmd>), additionally
|
||||
def retrace_as_exported_program(
|
||||
self,
|
||||
gm: torch.fx.GraphModule,
|
||||
name_to_constant: Dict[str, Any],
|
||||
name_to_constant: dict[str, Any],
|
||||
):
|
||||
dynamic_shapes = _tree_map_with_path(
|
||||
lambda path, x: (
|
||||
@ -1569,7 +1570,7 @@ DEBUG: (TORCH_LOGS="+export" <cmd>), additionally
|
||||
# TS2FXGraphConverter since it gets attributes from self.ts_model
|
||||
# which is not accessible in TS2FXGraphConverter. It is similar to where
|
||||
# we collect self.name_to_param and self.name_to_buffer.
|
||||
name_to_attribute_fqn: Dict[str, str] = {}
|
||||
name_to_attribute_fqn: dict[str, str] = {}
|
||||
|
||||
def get_attr(fqn: str):
|
||||
name = fqn.split(".")
|
||||
|
@ -4,12 +4,12 @@ import re
|
||||
import string
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from typing import Any, Optional
|
||||
from types import ModuleType
|
||||
|
||||
import torch
|
||||
|
||||
_TAGS: Dict[str, Dict[str, Any]] = {
|
||||
_TAGS: dict[str, dict[str, Any]] = {
|
||||
"torch": {
|
||||
"cond": {},
|
||||
"dynamic-shape": {},
|
||||
@ -79,12 +79,12 @@ class ExportCase:
|
||||
description: str # A description of the use case.
|
||||
model: torch.nn.Module
|
||||
name: str
|
||||
example_kwargs: Dict[str, Any] = field(default_factory=dict)
|
||||
example_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
extra_args: Optional[ArgsType] = None # For testing graph generalization.
|
||||
# Tags associated with the use case (e.g. dynamic-shape, escape-hatch).
|
||||
tags: Set[str] = field(default_factory=set)
|
||||
tags: set[str] = field(default_factory=set)
|
||||
support_level: SupportLevel = SupportLevel.SUPPORTED
|
||||
dynamic_shapes: Optional[Dict[str, Any]] = None
|
||||
dynamic_shapes: Optional[dict[str, Any]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
check_inputs_type(self.example_args, self.example_kwargs)
|
||||
@ -98,10 +98,10 @@ class ExportCase:
|
||||
raise ValueError(f'Invalid description: "{self.description}"')
|
||||
|
||||
|
||||
_EXAMPLE_CASES: Dict[str, ExportCase] = {}
|
||||
_MODULES: Set[ModuleType] = set()
|
||||
_EXAMPLE_CONFLICT_CASES: Dict[str, List[ExportCase]] = {}
|
||||
_EXAMPLE_REWRITE_CASES: Dict[str, List[ExportCase]] = {}
|
||||
_EXAMPLE_CASES: dict[str, ExportCase] = {}
|
||||
_MODULES: set[ModuleType] = set()
|
||||
_EXAMPLE_CONFLICT_CASES: dict[str, list[ExportCase]] = {}
|
||||
_EXAMPLE_REWRITE_CASES: dict[str, list[ExportCase]] = {}
|
||||
|
||||
|
||||
def register_db_case(case: ExportCase) -> None:
|
||||
|
@ -1,5 +1,4 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
@ -9,7 +8,7 @@ class ListUnpack(torch.nn.Module):
|
||||
erased after tracing.
|
||||
"""
|
||||
|
||||
def forward(self, args: List[torch.Tensor]):
|
||||
def forward(self, args: list[torch.Tensor]):
|
||||
"""
|
||||
Lists are treated as a static construct, therefore unpacking should be
|
||||
erased after tracing.
|
||||
|
@ -3,18 +3,7 @@ import contextlib
|
||||
import inspect
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
@ -97,8 +86,8 @@ def fakify(
|
||||
mode: FakeTensorMode,
|
||||
kp: KeyPath,
|
||||
t: Any,
|
||||
t_constraints: Dict[int, Dict[int, Constraint]],
|
||||
sources: Dict[tuple[int, int], List[Source]],
|
||||
t_constraints: dict[int, dict[int, Constraint]],
|
||||
sources: dict[tuple[int, int], list[Source]],
|
||||
):
|
||||
source = key_path_to_source(kp)
|
||||
if _is_constant_argument(t) or isinstance(t, (torch.ScriptObject, torch.nn.Module)):
|
||||
@ -165,7 +154,7 @@ def make_fake_inputs(
|
||||
combined_args = _combine_args(nn_module, args, kwargs)
|
||||
_check_dynamic_shapes(combined_args, dynamic_shapes)
|
||||
constraints = _process_dynamic_shapes(combined_args, dynamic_shapes)
|
||||
t_constraints: Dict[int, Dict[int, Constraint]] = defaultdict(dict)
|
||||
t_constraints: dict[int, dict[int, Constraint]] = defaultdict(dict)
|
||||
for constraint in constraints:
|
||||
t_constraints[constraint.t_id][constraint.dim] = constraint
|
||||
|
||||
@ -214,17 +203,17 @@ def make_fake_inputs(
|
||||
original_signature = inspect.signature(nn_module.forward)
|
||||
else:
|
||||
original_signature = None
|
||||
sources: Dict[tuple[int, int], List[Source]] = defaultdict(list)
|
||||
sources: dict[tuple[int, int], list[Source]] = defaultdict(list)
|
||||
fake_args, fake_kwargs = tree_map_with_path(
|
||||
lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources),
|
||||
(args, kwargs),
|
||||
)
|
||||
|
||||
names: Dict[str, tuple[int, int]] = {}
|
||||
source_pairs: List[tuple[Source, Source]] = []
|
||||
derived_equalities: List[tuple[Source, Union[Source, Symbol], Callable]] = []
|
||||
phantom_symbols: Dict[str, Symbol] = {}
|
||||
relaxed_sources: Set[Source] = set()
|
||||
names: dict[str, tuple[int, int]] = {}
|
||||
source_pairs: list[tuple[Source, Source]] = []
|
||||
derived_equalities: list[tuple[Source, Union[Source, Symbol], Callable]] = []
|
||||
phantom_symbols: dict[str, Symbol] = {}
|
||||
relaxed_sources: set[Source] = set()
|
||||
for constraint in constraints:
|
||||
torch.export.dynamic_shapes._process_equalities(
|
||||
constraint,
|
||||
@ -255,9 +244,9 @@ def make_fake_inputs(
|
||||
|
||||
|
||||
def _flatten_dynamic_shapes(
|
||||
combined_args: Dict[str, Any],
|
||||
dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any]],
|
||||
) -> List[Any]:
|
||||
combined_args: dict[str, Any],
|
||||
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any]],
|
||||
) -> list[Any]:
|
||||
flat_shapes = []
|
||||
|
||||
def _tree_map_helper(path, t, shape):
|
||||
@ -283,7 +272,7 @@ def _clean_dynamic_markers(tensor: torch.Tensor) -> None:
|
||||
def produce_guards_and_solve_constraints(
|
||||
fake_mode: FakeTensorMode,
|
||||
gm: torch.fx.GraphModule,
|
||||
dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any], None],
|
||||
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
|
||||
equalities_inputs: EqualityConstraint,
|
||||
original_signature: inspect.Signature,
|
||||
_is_torch_jit_trace=False,
|
||||
@ -348,8 +337,8 @@ def produce_guards_and_solve_constraints(
|
||||
def make_constraints(
|
||||
fake_mode: FakeTensorMode,
|
||||
gm: torch.fx.GraphModule,
|
||||
combined_args: Dict[str, Any],
|
||||
dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any], None],
|
||||
combined_args: dict[str, Any],
|
||||
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
|
||||
num_lifted_inputs: int,
|
||||
):
|
||||
"""
|
||||
@ -435,7 +424,7 @@ def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap:
|
||||
buffers_parameters = set(m.buffers())
|
||||
buffers_parameters.update(m.parameters())
|
||||
|
||||
def inner(m: torch.nn.Module, prefix_atoms: List[str], constants):
|
||||
def inner(m: torch.nn.Module, prefix_atoms: list[str], constants):
|
||||
for k, v in m.__dict__.items():
|
||||
if isinstance(
|
||||
v,
|
||||
@ -459,8 +448,8 @@ def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap:
|
||||
|
||||
|
||||
def _get_graph_inputs_of_type_nn_module(
|
||||
args: Optional[Tuple[Tuple[Any], Dict[Any, Any]]],
|
||||
) -> Set[Type[torch.nn.Module]]:
|
||||
args: Optional[tuple[tuple[Any], dict[Any, Any]]],
|
||||
) -> set[type[torch.nn.Module]]:
|
||||
if args is None:
|
||||
return set()
|
||||
module_types = set()
|
||||
@ -471,14 +460,14 @@ def _get_graph_inputs_of_type_nn_module(
|
||||
|
||||
|
||||
def _enter_enable_graph_inputs_of_type_nn_module(
|
||||
module_types: Set[Type[torch.nn.Module]],
|
||||
module_types: set[type[torch.nn.Module]],
|
||||
) -> None:
|
||||
for t in module_types:
|
||||
torch._export.utils.register_module_as_pytree_input_node(t)
|
||||
|
||||
|
||||
def _exit_enable_graph_inputs_of_type_nn_module(
|
||||
module_types: Set[Type[torch.nn.Module]],
|
||||
module_types: set[type[torch.nn.Module]],
|
||||
) -> None:
|
||||
for t in module_types:
|
||||
torch._export.utils.deregister_module_as_pytree_input_node(t)
|
||||
@ -486,7 +475,7 @@ def _exit_enable_graph_inputs_of_type_nn_module(
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _enable_graph_inputs_of_type_nn_module(
|
||||
args: Optional[Tuple[Tuple[Any], Dict[Any, Any]]],
|
||||
args: Optional[tuple[tuple[Any], dict[Any, Any]]],
|
||||
):
|
||||
if args is None:
|
||||
yield
|
||||
@ -502,8 +491,8 @@ def _enable_graph_inputs_of_type_nn_module(
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _fakify_module_inputs(
|
||||
args: Tuple[Any],
|
||||
kwargs: Dict[Any, Any],
|
||||
args: tuple[Any],
|
||||
kwargs: dict[Any, Any],
|
||||
fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
|
||||
):
|
||||
# This context manager is used to fakify module inputs.
|
||||
@ -534,7 +523,7 @@ def _fakify_module_inputs(
|
||||
def _fakify_script_objects(
|
||||
mod: torch.nn.Module,
|
||||
args: tuple[Any],
|
||||
kwargs: Dict[Any, Any],
|
||||
kwargs: dict[Any, Any],
|
||||
fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
|
||||
):
|
||||
# This context manager is used to fakify script objects into FakeScriptObject.
|
||||
|
@ -3,7 +3,7 @@ import operator
|
||||
import traceback
|
||||
import typing
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from functorch.experimental.control_flow import _unstack_pytree
|
||||
@ -31,7 +31,7 @@ Fn = Callable[..., Any]
|
||||
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
|
||||
|
||||
|
||||
_TORCH_SYM_OPS: Set[Callable] = {
|
||||
_TORCH_SYM_OPS: set[Callable] = {
|
||||
torch.sym_int,
|
||||
torch.sym_float,
|
||||
torch.sym_ite,
|
||||
@ -64,9 +64,9 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
||||
self.root = torch.nn.Module()
|
||||
self.graph = torch.fx.Graph()
|
||||
self.graph.set_codegen(codegen)
|
||||
self.tensor_attrs: Dict[str, torch.Tensor] = {} # type: ignore[assignment]
|
||||
self.tensor_attrs: dict[str, torch.Tensor] = {} # type: ignore[assignment]
|
||||
self.fake_tensor_mode: Optional[FakeTensorMode] = None
|
||||
self.submodules: Dict[torch.nn.Module, str] = {}
|
||||
self.submodules: dict[torch.nn.Module, str] = {}
|
||||
|
||||
def trace(self) -> None: # type: ignore[override]
|
||||
raise ExportPassBaseError("ExportTracer doesn't support trace().")
|
||||
@ -162,7 +162,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
||||
self,
|
||||
target: str, # type: ignore[override]
|
||||
args: tuple[Argument, ...],
|
||||
kwargs: Dict[str, Argument],
|
||||
kwargs: dict[str, Argument],
|
||||
) -> ProxyValue:
|
||||
arg = super().placeholder(target, args, kwargs)
|
||||
return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta))
|
||||
@ -171,7 +171,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
||||
self,
|
||||
target: torch.fx.node.Target,
|
||||
args: tuple[Argument, ...],
|
||||
kwargs: Dict[str, Argument],
|
||||
kwargs: dict[str, Argument],
|
||||
) -> ProxyValue:
|
||||
return self.callback.output(args[0], NodeMetadata(self.node.meta)).data # type: ignore[return-value]
|
||||
|
||||
@ -179,7 +179,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
||||
self,
|
||||
target: torch.fx.node.Target,
|
||||
args: tuple[Argument, ...],
|
||||
kwargs: Dict[str, Argument],
|
||||
kwargs: dict[str, Argument],
|
||||
) -> ProxyValue:
|
||||
meta = NodeMetadata(self.node.meta)
|
||||
|
||||
@ -218,7 +218,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
||||
raise ExportPassBaseError(f"Unsupported target type: {target}")
|
||||
|
||||
def get_attr(
|
||||
self, target: str, args: tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override]
|
||||
self, target: str, args: tuple[Argument, ...], kwargs: dict[str, Argument] # type: ignore[override]
|
||||
) -> Argument:
|
||||
return super().get_attr(target, args, kwargs)
|
||||
|
||||
@ -226,12 +226,12 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
||||
self,
|
||||
target: torch.fx.node.Target,
|
||||
args: tuple[Argument, ...],
|
||||
kwargs: Dict[str, Argument],
|
||||
kwargs: dict[str, Argument],
|
||||
) -> None:
|
||||
raise ExportPassBaseError("call_module is not supported.")
|
||||
|
||||
def call_method(
|
||||
self, target: str, args: tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override]
|
||||
self, target: str, args: tuple[Argument, ...], kwargs: dict[str, Argument] # type: ignore[override]
|
||||
) -> None:
|
||||
raise ExportPassBaseError("call_method is not supported.")
|
||||
|
||||
@ -254,7 +254,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
||||
kind: str,
|
||||
target: torch.fx.node.Target,
|
||||
args: tuple[Argument, ...],
|
||||
kwargs: Dict[str, Argument],
|
||||
kwargs: dict[str, Argument],
|
||||
meta: NodeMetadata,
|
||||
) -> ProxyValue:
|
||||
args_data, kwargs_data = pytree.tree_map_only(
|
||||
@ -277,7 +277,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
||||
self.tracer.set_metadata(res_proxy.node, res_data)
|
||||
return ProxyValue(res_data, res_proxy)
|
||||
|
||||
def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]:
|
||||
def inputs(self, graph_module: torch.fx.GraphModule) -> list[Argument]:
|
||||
# TODO(angelayi): Update this with what we decide to do for metadata in
|
||||
# the exported graph module
|
||||
if (args := graph_module.meta.get("args", None)) is not None:
|
||||
@ -327,7 +327,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
||||
self,
|
||||
op,
|
||||
args: tuple[Argument, ...],
|
||||
kwargs: Dict[str, Argument],
|
||||
kwargs: dict[str, Argument],
|
||||
meta: NodeMetadata,
|
||||
) -> ProxyValue:
|
||||
return self._fx("call_function", op, args, kwargs, meta)
|
||||
@ -345,7 +345,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
||||
pred: ProxyValue,
|
||||
true_fn: torch.fx.GraphModule,
|
||||
false_fn: torch.fx.GraphModule,
|
||||
inputs: List[Argument],
|
||||
inputs: list[Argument],
|
||||
meta: NodeMetadata,
|
||||
) -> ProxyValue:
|
||||
true_branch = self.call_submodule(true_fn, tuple(inputs))
|
||||
@ -363,8 +363,8 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
||||
def call_map(
|
||||
self,
|
||||
f: torch.fx.GraphModule,
|
||||
mapped_args: List[ProxyValue],
|
||||
operands: List[ProxyValue],
|
||||
mapped_args: list[ProxyValue],
|
||||
operands: list[ProxyValue],
|
||||
meta: NodeMetadata,
|
||||
) -> ProxyValue:
|
||||
xs = _unstack_pytree([arg.data for arg in mapped_args])[0]
|
||||
@ -383,7 +383,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
||||
) -> ProxyValue:
|
||||
return self._fx("call_function", operator.getitem, (value, key), {}, meta)
|
||||
|
||||
def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
|
||||
def output(self, results: list[Argument], meta: NodeMetadata) -> ProxyValue:
|
||||
return self._fx("output", "output", (results,), {}, meta)
|
||||
|
||||
def call_submodule(
|
||||
|
@ -1,10 +1,10 @@
|
||||
from typing import Any, Dict, Set
|
||||
from typing import Any
|
||||
|
||||
|
||||
NodeMetadataValue = Any
|
||||
|
||||
|
||||
PROTECTED_KEYS: Set[str] = {
|
||||
PROTECTED_KEYS: set[str] = {
|
||||
"val",
|
||||
"stack_trace",
|
||||
"nn_module_stack",
|
||||
@ -14,8 +14,8 @@ PROTECTED_KEYS: Set[str] = {
|
||||
|
||||
|
||||
class NodeMetadata:
|
||||
def __init__(self, data: Dict[str, Any]) -> None:
|
||||
self.data: Dict[str, Any] = data.copy()
|
||||
def __init__(self, data: dict[str, Any]) -> None:
|
||||
self.data: dict[str, Any] = data.copy()
|
||||
|
||||
def __getitem__(self, key: str) -> NodeMetadataValue:
|
||||
return self.data[key]
|
||||
|
@ -1,5 +1,6 @@
|
||||
# pyre-strict
|
||||
from typing import Union, Iterator, Iterable, Generic
|
||||
from typing import Union, Generic
|
||||
from collections.abc import Iterator, Iterable
|
||||
import torch
|
||||
|
||||
from typing import TypeVar
|
||||
|
@ -3,7 +3,7 @@ import math
|
||||
import operator
|
||||
import traceback
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, NamedTuple, Set
|
||||
from typing import Callable, NamedTuple
|
||||
|
||||
import sympy
|
||||
|
||||
@ -45,11 +45,11 @@ def _convert_range_to_int(range: ValueRanges):
|
||||
class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
|
||||
def __init__(
|
||||
self,
|
||||
range_constraints: Dict[sympy.Symbol, ValueRanges],
|
||||
range_constraints: dict[sympy.Symbol, ValueRanges],
|
||||
):
|
||||
super().__init__()
|
||||
self.range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
|
||||
self._asserts_generated_unbacked_symbols: Set[sympy.Symbol] = set()
|
||||
self.range_constraints: dict[sympy.Symbol, ValueRanges] = range_constraints
|
||||
self._asserts_generated_unbacked_symbols: set[sympy.Symbol] = set()
|
||||
self.counter = 0
|
||||
|
||||
def _assert_range_constraint(self, node, lower, upper, assert_msg):
|
||||
@ -105,8 +105,8 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
|
||||
# need the proxy for shape, which further requires the proxy for ret[1], etc.
|
||||
|
||||
def add_assertions(val):
|
||||
call_backs: List[Callable] = []
|
||||
messages: List[str] = []
|
||||
call_backs: list[Callable] = []
|
||||
messages: list[str] = []
|
||||
if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)):
|
||||
symbol = val.node.expr
|
||||
if symbol in self.existing_inline_assertions:
|
||||
@ -161,9 +161,9 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
|
||||
|
||||
def _get_existing_inline_assertions(
|
||||
graph_module: torch.fx.GraphModule,
|
||||
range_constraints: Dict[sympy.Symbol, ValueRanges],
|
||||
) -> Dict[sympy.Symbol, ValueRanges]:
|
||||
existing_inline_assertions: Dict[sympy.Symbol, ValueRanges] = {}
|
||||
range_constraints: dict[sympy.Symbol, ValueRanges],
|
||||
) -> dict[sympy.Symbol, ValueRanges]:
|
||||
existing_inline_assertions: dict[sympy.Symbol, ValueRanges] = {}
|
||||
|
||||
for module in graph_module.modules():
|
||||
if not isinstance(module, torch.fx.GraphModule):
|
||||
|
@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from typing import Dict, Optional, TYPE_CHECKING, Union
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from torch.export.exported_program import ConstantArgument, TensorArgument
|
||||
@ -23,7 +23,7 @@ class CollectTracepointsPass(PassBase):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, specs: Dict[str, ModuleCallSignature], sig: ExportGraphSignature
|
||||
self, specs: dict[str, ModuleCallSignature], sig: ExportGraphSignature
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.specs = specs
|
||||
|
@ -1,7 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import collections
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
@ -53,8 +53,8 @@ class ConstantFolder(torch.fx.Interpreter):
|
||||
skip_constructors: bool = False,
|
||||
):
|
||||
super().__init__(gm)
|
||||
self.node_replacements: Dict[torch.fx.Node, Any] = {}
|
||||
self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
|
||||
self.node_replacements: dict[torch.fx.Node, Any] = {}
|
||||
self.replaced_uses: dict[torch.fx.Node, int] = collections.Counter()
|
||||
self.unknown_value = object()
|
||||
self.skip_constructors: bool = skip_constructors
|
||||
|
||||
@ -281,7 +281,7 @@ def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule
|
||||
|
||||
new_graph = torch.fx.Graph()
|
||||
|
||||
node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
|
||||
node_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
|
||||
output_nodes = []
|
||||
for node in gm.graph.nodes:
|
||||
if node.meta[META_TAG] == MODULE_TAG:
|
||||
|
@ -1,5 +1,5 @@
|
||||
import copy
|
||||
from typing import Dict, Optional, List
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, PassResult, Argument
|
||||
@ -9,7 +9,7 @@ from torch._ops import OpOverload
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: Dict[OpOverload, OpOverload] = {
|
||||
_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: dict[OpOverload, OpOverload] = {
|
||||
aten.sym_constrain_range.default: aten._functional_sym_constrain_range,
|
||||
aten._assert_async.msg: aten._functional_assert_async.msg,
|
||||
}
|
||||
@ -60,7 +60,7 @@ class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse):
|
||||
self,
|
||||
op: OpOverload,
|
||||
args: tuple[Argument, ...],
|
||||
kwargs: Dict[str, Argument],
|
||||
kwargs: dict[str, Argument],
|
||||
meta: NodeMetadata,
|
||||
) -> ProxyValue:
|
||||
if op not in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS:
|
||||
@ -88,7 +88,7 @@ class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse):
|
||||
|
||||
return self._dep_token
|
||||
|
||||
def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
|
||||
def output(self, results: list[Argument], meta: NodeMetadata) -> ProxyValue:
|
||||
assert self._dep_token is not None
|
||||
|
||||
return super().output(results=(*results, self._dep_token), meta=meta) # type: ignore[arg-type]
|
||||
|
@ -1,5 +1,4 @@
|
||||
import functools
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch._export.passes._node_metadata_hook import (
|
||||
@ -8,7 +7,7 @@ from torch._export.passes._node_metadata_hook import (
|
||||
)
|
||||
|
||||
|
||||
def insert_custom_op_guards(gm: torch.fx.GraphModule, ops_to_guard: List[str]) -> None:
|
||||
def insert_custom_op_guards(gm: torch.fx.GraphModule, ops_to_guard: list[str]) -> None:
|
||||
"""
|
||||
This is used by draft_export to insert guards in front of calls to custom
|
||||
operators which have a generated fake kernel.
|
||||
|
@ -1,7 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import collections
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Union
|
||||
|
||||
import torch
|
||||
from torch._export.verifier import SpecViolationError
|
||||
@ -29,13 +29,13 @@ class ConstantAttrMap(collections.abc.MutableMapping):
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Underlying dict that we use to implement this mapping.
|
||||
self._constant_attrs: Dict[
|
||||
Union[int, torch.Tensor, FakeScriptObject], List[Any]
|
||||
self._constant_attrs: dict[
|
||||
Union[int, torch.Tensor, FakeScriptObject], list[Any]
|
||||
] = {}
|
||||
# Map from the hash(ScriptObject) to the ScriptObject itself. Used for
|
||||
# APIs like `__iter__` that should look like they're returning the
|
||||
# original ScriptObjects.
|
||||
self._script_object_map: Dict[int, torch.ScriptObject] = {}
|
||||
self._script_object_map: dict[int, torch.ScriptObject] = {}
|
||||
|
||||
def __getitem__(
|
||||
self, key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]
|
||||
@ -113,7 +113,7 @@ def lift_constants_pass(
|
||||
gm: torch.fx.GraphModule,
|
||||
graph_signature: ExportGraphSignature,
|
||||
constant_attrs: ConstantAttrMap,
|
||||
) -> Dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]]:
|
||||
) -> dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]]:
|
||||
"""
|
||||
Takes a graph module, graph signature, and modifies them in place to lift any
|
||||
constants (tensors or custom classes) as inputs to the graph. Returns a
|
||||
@ -131,7 +131,7 @@ def lift_constants_pass(
|
||||
Returns:
|
||||
A dictionary of fqn => constant value.
|
||||
"""
|
||||
all_constants: Dict[
|
||||
all_constants: dict[
|
||||
str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]
|
||||
] = {}
|
||||
|
||||
@ -300,13 +300,13 @@ def lift_constants_pass(
|
||||
|
||||
def rewrite_script_object_meta(
|
||||
gm: torch.fx.GraphModule,
|
||||
) -> Dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject],]:
|
||||
) -> dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject],]:
|
||||
"""When tracing, we produce a graph with FakeScriptObject in the
|
||||
meta["val"].
|
||||
|
||||
For now, we rewrite meta["val"] to be a placeholder CustomObjArgument
|
||||
"""
|
||||
constants: Dict[
|
||||
constants: dict[
|
||||
str,
|
||||
Union[
|
||||
torch.Tensor,
|
||||
|
@ -1,7 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, TYPE_CHECKING, Union
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.wrap import wrap_with_autocast
|
||||
@ -116,7 +116,7 @@ def _split_autocast(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
||||
exit_autocast # 3
|
||||
E # 4
|
||||
"""
|
||||
enter_autocast_node_stack: List[torch.fx.Node] = []
|
||||
enter_autocast_node_stack: list[torch.fx.Node] = []
|
||||
first_node_after_outer_most_exit: bool = False
|
||||
|
||||
def node_call_back(node: torch.fx.Node) -> bool:
|
||||
|
@ -1,7 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
import operator
|
||||
from typing import List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.export._trace
|
||||
@ -269,9 +269,9 @@ def _conv1d_op_with_squeeze(
|
||||
inp: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
stride: List[int],
|
||||
padding: List[int],
|
||||
dilation: List[int],
|
||||
stride: list[int],
|
||||
padding: list[int],
|
||||
dilation: list[int],
|
||||
groups: int,
|
||||
) -> torch.Tensor:
|
||||
# In quantized version, conv1d is emulated using conv2d with squeeze and unsqueeze
|
||||
|
@ -1,5 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
import torch
|
||||
from torch._ops import OpOverload, HigherOrderOperator
|
||||
from torch._export.error import InternalError
|
||||
@ -9,7 +9,7 @@ from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
|
||||
__all__ = ["ReplaceViewOpsWithViewCopyOpsPass"]
|
||||
|
||||
|
||||
_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: Dict[OpOverload, OpOverload] = {
|
||||
_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: dict[OpOverload, OpOverload] = {
|
||||
torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default,
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from torch._export.serde.schema import Node
|
||||
|
||||
@ -12,4 +11,4 @@ class ExternKernelNode:
|
||||
|
||||
@dataclass
|
||||
class ExternKernelNodes:
|
||||
nodes: List[ExternKernelNode]
|
||||
nodes: list[ExternKernelNode]
|
||||
|
@ -1,5 +1,5 @@
|
||||
import dataclasses
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._dynamo.exc import UserError, UserErrorType
|
||||
@ -24,7 +24,7 @@ class RootDim:
|
||||
|
||||
min: int
|
||||
max: Union[int, None]
|
||||
derived: List[str]
|
||||
derived: list[str]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -33,15 +33,15 @@ class DynamicShapesSpec:
|
||||
This stores a dynamic_shapes spec for de/serialization.
|
||||
"""
|
||||
|
||||
dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any], None]
|
||||
dims: Dict[str, RootDim]
|
||||
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None]
|
||||
dims: dict[str, RootDim]
|
||||
|
||||
|
||||
def _postprocess_serialized_shapes(
|
||||
dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any], None],
|
||||
dims: Dict[str, Dict[str, Union[int, List[str], None]]],
|
||||
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
|
||||
dims: dict[str, dict[str, Union[int, list[str], None]]],
|
||||
to_dict: Optional[bool] = False,
|
||||
) -> Union[DynamicShapesSpec, Dict[str, Any]]:
|
||||
) -> Union[DynamicShapesSpec, dict[str, Any]]:
|
||||
"""
|
||||
Sorts dims and dumps to dictionary format.
|
||||
"""
|
||||
@ -63,11 +63,11 @@ def _postprocess_serialized_shapes(
|
||||
|
||||
|
||||
def _dump_dynamic_shapes(
|
||||
dynamic_shapes: Union[Dict[str, Any], tuple[Any], List[Any], None],
|
||||
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
|
||||
args: tuple[Any],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
to_dict: Optional[bool] = False,
|
||||
) -> Union[DynamicShapesSpec, Dict[str, Any]]:
|
||||
) -> Union[DynamicShapesSpec, dict[str, Any]]:
|
||||
"""
|
||||
Utility function for dynamic shapes serialization, serializing a dynamic_shapes spec.
|
||||
Returns a DynamicShapesSpec dataclass containing 2 fields, "dynamic_shapes" and "dims".
|
||||
@ -127,7 +127,7 @@ def _dump_dynamic_shapes(
|
||||
}
|
||||
```
|
||||
"""
|
||||
dims: Dict[str, Dict[str, Any]] = {}
|
||||
dims: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def _standardize_shapes(path, tensor, shape): # type: ignore[no-untyped-def]
|
||||
"""
|
||||
@ -198,9 +198,9 @@ def _dump_dynamic_shapes(
|
||||
|
||||
|
||||
def _load_dynamic_shapes(
|
||||
spec: Union[DynamicShapesSpec, Dict[str, Any]],
|
||||
spec: Union[DynamicShapesSpec, dict[str, Any]],
|
||||
from_dict: Optional[bool] = False,
|
||||
) -> Union[Dict[str, Any], tuple[Any], List[Any], None]:
|
||||
) -> Union[dict[str, Any], tuple[Any], list[Any], None]:
|
||||
"""
|
||||
Utility function for dynamic shapes serialization.
|
||||
Deserializes a DynamicShapesSpec or corresponding dictionary into a dynamic_shapes input to export().
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
from typing import Annotated, Dict, List, Optional
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from torch._export.serde.union import _Union
|
||||
|
||||
@ -96,10 +96,10 @@ class SymBool(_Union):
|
||||
@dataclass
|
||||
class TensorMeta:
|
||||
dtype: Annotated[ScalarType, 10]
|
||||
sizes: Annotated[List[SymInt], 20]
|
||||
sizes: Annotated[list[SymInt], 20]
|
||||
requires_grad: Annotated[bool, 30]
|
||||
device: Annotated[Device, 40]
|
||||
strides: Annotated[List[SymInt], 50]
|
||||
strides: Annotated[list[SymInt], 50]
|
||||
storage_offset: Annotated[SymInt, 60]
|
||||
layout: Annotated[Layout, 70]
|
||||
|
||||
@ -175,29 +175,29 @@ class CustomObjArgument:
|
||||
class Argument(_Union):
|
||||
as_none: Annotated[bool, 10]
|
||||
as_tensor: Annotated[TensorArgument, 20]
|
||||
as_tensors: Annotated[List[TensorArgument], 30]
|
||||
as_tensors: Annotated[list[TensorArgument], 30]
|
||||
as_int: Annotated[int, 50]
|
||||
as_ints: Annotated[List[int], 70]
|
||||
as_ints: Annotated[list[int], 70]
|
||||
as_float: Annotated[float, 80]
|
||||
as_floats: Annotated[List[float], 90]
|
||||
as_floats: Annotated[list[float], 90]
|
||||
as_string: Annotated[str, 100]
|
||||
as_strings: Annotated[List[str], 101]
|
||||
as_strings: Annotated[list[str], 101]
|
||||
as_sym_int: Annotated[SymIntArgument, 110]
|
||||
as_sym_ints: Annotated[List[SymIntArgument], 120]
|
||||
as_sym_ints: Annotated[list[SymIntArgument], 120]
|
||||
as_scalar_type: Annotated[ScalarType, 130]
|
||||
as_memory_format: Annotated[MemoryFormat, 140]
|
||||
as_layout: Annotated[Layout, 150]
|
||||
as_device: Annotated[Device, 160]
|
||||
as_bool: Annotated[bool, 170]
|
||||
as_bools: Annotated[List[bool], 180]
|
||||
as_bools: Annotated[list[bool], 180]
|
||||
as_sym_bool: Annotated[SymBoolArgument, 182]
|
||||
as_sym_bools: Annotated[List[SymBoolArgument], 184]
|
||||
as_sym_bools: Annotated[list[SymBoolArgument], 184]
|
||||
as_graph: Annotated[GraphArgument, 200]
|
||||
as_optional_tensors: Annotated[List[OptionalTensorArgument], 190]
|
||||
as_optional_tensors: Annotated[list[OptionalTensorArgument], 190]
|
||||
as_custom_obj: Annotated[CustomObjArgument, 210]
|
||||
as_operator: Annotated[str, 220]
|
||||
as_sym_float: Annotated[SymFloatArgument, 230]
|
||||
as_sym_floats: Annotated[List[SymFloatArgument], 240]
|
||||
as_sym_floats: Annotated[list[SymFloatArgument], 240]
|
||||
|
||||
|
||||
class ArgumentKind(IntEnum):
|
||||
@ -217,27 +217,27 @@ class NamedArgument:
|
||||
@dataclass
|
||||
class Node:
|
||||
target: Annotated[str, 10]
|
||||
inputs: Annotated[List[NamedArgument], 20]
|
||||
outputs: Annotated[List[Argument], 30]
|
||||
metadata: Annotated[Dict[str, str], 40]
|
||||
inputs: Annotated[list[NamedArgument], 20]
|
||||
outputs: Annotated[list[Argument], 30]
|
||||
metadata: Annotated[dict[str, str], 40]
|
||||
is_hop_single_tensor_return: Annotated[Optional[bool], 50] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Graph:
|
||||
inputs: Annotated[List[Argument], 10]
|
||||
outputs: Annotated[List[Argument], 20]
|
||||
nodes: Annotated[List[Node], 30]
|
||||
tensor_values: Annotated[Dict[str, TensorMeta], 40]
|
||||
sym_int_values: Annotated[Dict[str, SymInt], 50]
|
||||
sym_bool_values: Annotated[Dict[str, SymBool], 60]
|
||||
inputs: Annotated[list[Argument], 10]
|
||||
outputs: Annotated[list[Argument], 20]
|
||||
nodes: Annotated[list[Node], 30]
|
||||
tensor_values: Annotated[dict[str, TensorMeta], 40]
|
||||
sym_int_values: Annotated[dict[str, SymInt], 50]
|
||||
sym_bool_values: Annotated[dict[str, SymBool], 60]
|
||||
# This is for deserializing the submodule graphs from higher order ops
|
||||
# (ex. cond, map) where single tensor returns will just return a single
|
||||
# tensor, rather than following export schema and returning a singleton
|
||||
# list.
|
||||
is_single_tensor_return: Annotated[bool, 70] = False
|
||||
custom_obj_values: Annotated[Dict[str, CustomObjArgument], 80] = field(default_factory=dict)
|
||||
sym_float_values: Annotated[Dict[str, SymFloat], 90] = field(default_factory=dict)
|
||||
custom_obj_values: Annotated[dict[str, CustomObjArgument], 80] = field(default_factory=dict)
|
||||
sym_float_values: Annotated[dict[str, SymFloat], 90] = field(default_factory=dict)
|
||||
|
||||
@dataclass
|
||||
class UserInputSpec:
|
||||
@ -354,8 +354,8 @@ class OutputSpec(_Union):
|
||||
|
||||
@dataclass
|
||||
class GraphSignature:
|
||||
input_specs: Annotated[List[InputSpec], 10]
|
||||
output_specs: Annotated[List[OutputSpec], 20]
|
||||
input_specs: Annotated[list[InputSpec], 10]
|
||||
output_specs: Annotated[list[OutputSpec], 20]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -366,8 +366,8 @@ class RangeConstraint:
|
||||
|
||||
@dataclass
|
||||
class ModuleCallSignature:
|
||||
inputs: Annotated[List[Argument], 10]
|
||||
outputs: Annotated[List[Argument], 20]
|
||||
inputs: Annotated[list[Argument], 10]
|
||||
outputs: Annotated[list[Argument], 20]
|
||||
|
||||
# These are serialized by calling pytree.treespec_dumps
|
||||
# And deserialized by calling pytree.treespec_loads
|
||||
@ -376,7 +376,7 @@ class ModuleCallSignature:
|
||||
|
||||
# This field is used to prettify the graph placeholders
|
||||
# after we ser/der and retrace
|
||||
forward_arg_names: Annotated[Optional[List[str]], 50] = None
|
||||
forward_arg_names: Annotated[Optional[list[str]], 50] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -392,8 +392,8 @@ class GraphModule:
|
||||
# This is used for unflattening, by tracking the calling structure of all of
|
||||
# the modules in order to unflatten the modules back to the eager calling
|
||||
# conventions.
|
||||
module_call_graph: Annotated[List[ModuleCallEntry], 60]
|
||||
metadata: Annotated[Dict[str, str], 40] = field(default_factory=dict)
|
||||
module_call_graph: Annotated[list[ModuleCallEntry], 60]
|
||||
metadata: Annotated[dict[str, str], 40] = field(default_factory=dict)
|
||||
|
||||
|
||||
# Invariant: Every time a change is made to the schema, one of the versions
|
||||
@ -408,8 +408,8 @@ class SchemaVersion:
|
||||
class ExportedProgram:
|
||||
graph_module: Annotated[GraphModule, 10]
|
||||
# Key is the opset namespace (ex. aten), and value is the version number
|
||||
opset_version: Annotated[Dict[str, int], 20]
|
||||
range_constraints: Annotated[Dict[str, RangeConstraint], 30]
|
||||
opset_version: Annotated[dict[str, int], 20]
|
||||
range_constraints: Annotated[dict[str, RangeConstraint], 30]
|
||||
schema_version: Annotated[SchemaVersion, 60]
|
||||
verifiers: Annotated[List[str], 70] = field(default_factory=list)
|
||||
verifiers: Annotated[list[str], 70] = field(default_factory=list)
|
||||
torch_version: Annotated[str, 80] = "<=2.4"
|
||||
|
@ -5,7 +5,7 @@ import inspect
|
||||
import re
|
||||
import typing
|
||||
from enum import IntEnum
|
||||
from typing import Annotated, Any, Dict, ForwardRef, List, Optional, Union
|
||||
from typing import Annotated, Any, ForwardRef, Optional, Union
|
||||
|
||||
from torch._export.serde import schema
|
||||
from torch._export.serde.union import _Union
|
||||
@ -36,16 +36,16 @@ _THRIFT_TYPE_MAP = {
|
||||
|
||||
|
||||
def _staged_schema():
|
||||
yaml_ret: Dict[str, Any] = {}
|
||||
yaml_ret: dict[str, Any] = {}
|
||||
defs = {}
|
||||
cpp_enum_defs: Dict[str, str] = {}
|
||||
cpp_class_defs: Dict[str, str] = {}
|
||||
cpp_type_decls: List[str] = []
|
||||
cpp_json_defs: List[str] = []
|
||||
thrift_enum_defs: List[str] = []
|
||||
thrift_type_defs: Dict[str, str] = {}
|
||||
cpp_enum_defs: dict[str, str] = {}
|
||||
cpp_class_defs: dict[str, str] = {}
|
||||
cpp_type_decls: list[str] = []
|
||||
cpp_json_defs: list[str] = []
|
||||
thrift_enum_defs: list[str] = []
|
||||
thrift_type_defs: dict[str, str] = {}
|
||||
|
||||
def _handle_aggregate(ty) -> tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
def _handle_aggregate(ty) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
|
||||
def dump_type(t, level: int) -> tuple[str, str, str]:
|
||||
if getattr(t, "__name__", None) in cpp_enum_defs:
|
||||
return t.__name__, "int64_t", t.__name__
|
||||
@ -125,7 +125,7 @@ def _staged_schema():
|
||||
f"Default value {v} is not supported yet in export schema."
|
||||
)
|
||||
|
||||
def dump_field(f) -> tuple[Dict[str, Any], str, Optional[str], str, int]:
|
||||
def dump_field(f) -> tuple[dict[str, Any], str, Optional[str], str, int]:
|
||||
t, cpp_type, thrift_type = dump_type(f.type, 0)
|
||||
ret = {"type": t}
|
||||
cpp_default: Optional[str] = None
|
||||
@ -524,12 +524,12 @@ def _hash_content(s: str):
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Commit:
|
||||
result: Dict[str, Any]
|
||||
result: dict[str, Any]
|
||||
checksum_next: str
|
||||
yaml_path: str
|
||||
additions: Dict[str, Any]
|
||||
subtractions: Dict[str, Any]
|
||||
base: Dict[str, Any]
|
||||
additions: dict[str, Any]
|
||||
subtractions: dict[str, Any]
|
||||
base: dict[str, Any]
|
||||
checksum_head: Optional[str]
|
||||
cpp_header: str
|
||||
cpp_header_path: str
|
||||
|
@ -24,15 +24,11 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
final,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
from collections.abc import Iterator
|
||||
|
||||
import sympy
|
||||
|
||||
@ -118,7 +114,7 @@ class SerializeError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def _reverse_map(d: Dict[Any, Enum]):
|
||||
def _reverse_map(d: dict[Any, Enum]):
|
||||
return {v.value: k for k, v in d.items()}
|
||||
|
||||
|
||||
@ -352,7 +348,7 @@ def serialize_torch_artifact(artifact: Optional[Any], pickle_protocol: int = DEF
|
||||
del copyreg.dispatch_table[FakeTensor]
|
||||
|
||||
|
||||
def deserialize_torch_artifact(serialized: Union[Dict[str, Any], tuple[Any, ...], bytes]):
|
||||
def deserialize_torch_artifact(serialized: Union[dict[str, Any], tuple[Any, ...], bytes]):
|
||||
if isinstance(serialized, (dict, tuple)):
|
||||
return serialized
|
||||
if len(serialized) == 0:
|
||||
@ -401,8 +397,8 @@ def _int_to_sympy_int(val: Optional[int], default) -> sympy.Expr:
|
||||
|
||||
|
||||
def serialize_range_constraints(
|
||||
range_constraints: Dict[sympy.Symbol, ValueRanges]
|
||||
) -> Dict[str, RangeConstraint]:
|
||||
range_constraints: dict[sympy.Symbol, ValueRanges]
|
||||
) -> dict[str, RangeConstraint]:
|
||||
return {
|
||||
str(k): RangeConstraint(
|
||||
_sympy_int_to_int(v.lower, "ceil"), # type: ignore[arg-type]
|
||||
@ -426,15 +422,15 @@ def _get_schema_from_target(target):
|
||||
|
||||
@dataclass
|
||||
class GraphState:
|
||||
inputs: List[Argument] = field(default_factory=list)
|
||||
outputs: List[Argument] = field(default_factory=list)
|
||||
nodes: List[Node] = field(default_factory=list)
|
||||
tensor_values: Dict[str, TensorMeta] = field(default_factory=dict)
|
||||
sym_int_values: Dict[str, SymInt] = field(default_factory=dict)
|
||||
sym_bool_values: Dict[str, SymBool] = field(default_factory=dict)
|
||||
sym_float_values: Dict[str, SymFloat] = field(default_factory=dict)
|
||||
inputs: list[Argument] = field(default_factory=list)
|
||||
outputs: list[Argument] = field(default_factory=list)
|
||||
nodes: list[Node] = field(default_factory=list)
|
||||
tensor_values: dict[str, TensorMeta] = field(default_factory=dict)
|
||||
sym_int_values: dict[str, SymInt] = field(default_factory=dict)
|
||||
sym_bool_values: dict[str, SymBool] = field(default_factory=dict)
|
||||
sym_float_values: dict[str, SymFloat] = field(default_factory=dict)
|
||||
is_single_tensor_return: bool = False
|
||||
custom_obj_values: Dict[str, CustomObjArgument] = field(default_factory=dict)
|
||||
custom_obj_values: dict[str, CustomObjArgument] = field(default_factory=dict)
|
||||
|
||||
|
||||
class Final(type):
|
||||
@ -450,13 +446,13 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
def __init__(
|
||||
self,
|
||||
graph_signature: ep.ExportGraphSignature,
|
||||
module_call_graph: List[ep.ModuleCallEntry],
|
||||
module_call_graph: list[ep.ModuleCallEntry],
|
||||
):
|
||||
self.graph_state = GraphState()
|
||||
self.graph_signature = graph_signature
|
||||
self.module_call_graph = module_call_graph
|
||||
self.custom_objs: Dict[str, torch._C.ScriptObject] = {}
|
||||
self.duplicate_getitem_nodes: Dict[str, str] = {}
|
||||
self.custom_objs: dict[str, torch._C.ScriptObject] = {}
|
||||
self.duplicate_getitem_nodes: dict[str, str] = {}
|
||||
|
||||
@contextmanager
|
||||
def save_graph_state(self):
|
||||
@ -597,7 +593,7 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
else:
|
||||
return user_node.name
|
||||
|
||||
def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
|
||||
def serialize_metadata(self, node: torch.fx.Node) -> dict[str, str]:
|
||||
ret = {}
|
||||
if stack_trace := node.meta.get("stack_trace"):
|
||||
ret["stack_trace"] = stack_trace
|
||||
@ -647,7 +643,7 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
class_fqn=script_obj_meta.class_fqn,
|
||||
)
|
||||
|
||||
def serialize_sym_op_inputs(self, op, args) -> List[NamedArgument]:
|
||||
def serialize_sym_op_inputs(self, op, args) -> list[NamedArgument]:
|
||||
if isinstance(op, torch._ops.OpOverload):
|
||||
args_names = [arg.name for arg in op._schema.arguments]
|
||||
else:
|
||||
@ -669,7 +665,7 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
target: Any, # torch._ops.OpOverload and other custom operator types.
|
||||
args,
|
||||
kwargs=None
|
||||
) -> List[NamedArgument]:
|
||||
) -> list[NamedArgument]:
|
||||
assert isinstance(target, (torch._ops.OpOverload, *_registered_extension_types()))
|
||||
kwargs = kwargs or {}
|
||||
serialized_args = []
|
||||
@ -700,7 +696,7 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
|
||||
return serialized_args
|
||||
|
||||
def serialize_hoo_inputs(self, args, kwargs) -> List[NamedArgument]:
|
||||
def serialize_hoo_inputs(self, args, kwargs) -> list[NamedArgument]:
|
||||
"""
|
||||
For serializing HOO inputs since HOOs do not have a schema.
|
||||
"""
|
||||
@ -1180,8 +1176,8 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
)
|
||||
|
||||
def serialize_module_call_graph(
|
||||
self, module_call_graph: List[ep.ModuleCallEntry]
|
||||
) -> List[ModuleCallEntry]:
|
||||
self, module_call_graph: list[ep.ModuleCallEntry]
|
||||
) -> list[ModuleCallEntry]:
|
||||
return [
|
||||
ModuleCallEntry(
|
||||
fqn=entry.fqn,
|
||||
@ -1194,7 +1190,7 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
for entry in module_call_graph
|
||||
]
|
||||
|
||||
def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]:
|
||||
def serialize_outputs(self, node: torch.fx.Node) -> list[Argument]:
|
||||
"""For a given node, return the dataclass representing its output values.
|
||||
|
||||
[NOTE: Multiple outputs] We handle aggregates differently than FX. For
|
||||
@ -1294,7 +1290,7 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
|
||||
return output_arguments
|
||||
|
||||
def serialize_hoo_outputs(self, node: torch.fx.Node) -> List[Argument]:
|
||||
def serialize_hoo_outputs(self, node: torch.fx.Node) -> list[Argument]:
|
||||
"""
|
||||
For serializing HOO outputs since HOOs do not have a schema.
|
||||
"""
|
||||
@ -1359,7 +1355,7 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
# list outputs should've been handled earlier
|
||||
raise SerializeError(f"Unable to serialize output {meta_val}")
|
||||
|
||||
def _handle_getitem_users(self, node: torch.fx.Node) -> List[TensorArgument]:
|
||||
def _handle_getitem_users(self, node: torch.fx.Node) -> list[TensorArgument]:
|
||||
meta_val = node.meta["val"]
|
||||
|
||||
idx_to_name = {}
|
||||
@ -1406,7 +1402,7 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
is_single_tensor_return=self.graph_state.is_single_tensor_return,
|
||||
)
|
||||
|
||||
def serialize_graph_module_metadata(self, meta: Dict[str, Any]):
|
||||
def serialize_graph_module_metadata(self, meta: dict[str, Any]):
|
||||
ret = {}
|
||||
if custom := meta.get("custom"):
|
||||
try:
|
||||
@ -1431,8 +1427,8 @@ class GraphModuleSerializer(metaclass=Final):
|
||||
|
||||
@final
|
||||
class ExportedProgramSerializer(metaclass=Final):
|
||||
def __init__(self, opset_version: Optional[Dict[str, int]] = None, pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL):
|
||||
self.opset_version: Dict[str, int] = {}
|
||||
def __init__(self, opset_version: Optional[dict[str, int]] = None, pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL):
|
||||
self.opset_version: dict[str, int] = {}
|
||||
if opset_version:
|
||||
self.opset_version.update(opset_version)
|
||||
if "aten" not in self.opset_version:
|
||||
@ -1498,15 +1494,15 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
class Result:
|
||||
graph_module: torch.fx.GraphModule
|
||||
signature: ep.ExportGraphSignature
|
||||
module_call_graph: List[ep.ModuleCallEntry]
|
||||
names_to_symbols: Dict[str, sympy.Symbol]
|
||||
state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]]
|
||||
constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]]
|
||||
example_inputs: Optional[tuple[tuple[torch.Tensor, ...], Dict[str, Any]]]
|
||||
module_call_graph: list[ep.ModuleCallEntry]
|
||||
names_to_symbols: dict[str, sympy.Symbol]
|
||||
state_dict: dict[str, Union[torch.Tensor, torch.nn.Parameter]]
|
||||
constants: dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]]
|
||||
example_inputs: Optional[tuple[tuple[torch.Tensor, ...], dict[str, Any]]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.serialized_name_to_node: Dict[str, torch.fx.Node] = {}
|
||||
self.serialized_name_to_meta: Dict[str, MetaType] = {}
|
||||
self.serialized_name_to_node: dict[str, torch.fx.Node] = {}
|
||||
self.serialized_name_to_meta: dict[str, MetaType] = {}
|
||||
self.graph = torch.fx.Graph()
|
||||
self.module = torch.nn.Module()
|
||||
|
||||
@ -1953,10 +1949,10 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
def deserialize(
|
||||
self,
|
||||
serialized_graph_module: GraphModule,
|
||||
serialized_state_dict: Union[Dict[str, torch.Tensor], bytes],
|
||||
constants: Union[Dict[str, Any], bytes],
|
||||
example_inputs: Optional[Union[tuple[tuple[torch.Tensor, ...], Dict[str, Any]], bytes]] = None,
|
||||
symbol_name_to_range: Optional[Dict[str, symbolic_shapes.ValueRanges]] = None,
|
||||
serialized_state_dict: Union[dict[str, torch.Tensor], bytes],
|
||||
constants: Union[dict[str, Any], bytes],
|
||||
example_inputs: Optional[Union[tuple[tuple[torch.Tensor, ...], dict[str, Any]], bytes]] = None,
|
||||
symbol_name_to_range: Optional[dict[str, symbolic_shapes.ValueRanges]] = None,
|
||||
) -> Result:
|
||||
global _CURRENT_DESERIALIZER
|
||||
assert _CURRENT_DESERIALIZER is None
|
||||
@ -1996,7 +1992,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
"ToFloat": torch.utils._sympy.functions.ToFloat,
|
||||
"Identity": torch.utils._sympy.functions.Identity,
|
||||
}
|
||||
self.symbol_name_to_symbol: Dict[str, sympy.Symbol] = {}
|
||||
self.symbol_name_to_symbol: dict[str, sympy.Symbol] = {}
|
||||
self.constants = deserialize_torch_artifact(constants)
|
||||
self.signature = self.deserialize_signature(serialized_graph_module.signature)
|
||||
|
||||
@ -2091,7 +2087,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
kwargs[schema_arg.name] = actual_args[schema_arg.name]
|
||||
return tuple(args), kwargs
|
||||
|
||||
def deserialize_hoo_inputs(self, inputs: List[NamedArgument]):
|
||||
def deserialize_hoo_inputs(self, inputs: list[NamedArgument]):
|
||||
"""
|
||||
For deserializing HOO inputs since HOOs do not have a schema.
|
||||
"""
|
||||
@ -2232,7 +2228,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
"torch.ops.higher_order" in serialized_node.target
|
||||
and not getattr(serialized_node, "is_hop_single_tensor_return", True)
|
||||
):
|
||||
meta_val: List[Any] = []
|
||||
meta_val: list[Any] = []
|
||||
arg = serialized_node.outputs[0].as_tensor
|
||||
deserialized_metadata = self.deserialize_metadata(serialized_node.metadata)
|
||||
self.generate_getitem(meta_val, fx_node, arg, 0, deserialized_metadata)
|
||||
@ -2261,7 +2257,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
fx_node: torch.fx.Node,
|
||||
arg: Union[TensorArgument, SymIntArgument, SymFloatArgument],
|
||||
idx: int,
|
||||
deserialized_metadata: Dict[str, Any],
|
||||
deserialized_metadata: dict[str, Any],
|
||||
):
|
||||
if isinstance(arg, TensorArgument):
|
||||
name = arg.name
|
||||
@ -2290,7 +2286,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
meta_val,
|
||||
fx_node: torch.fx.Node,
|
||||
args,
|
||||
deserialized_metadata: Dict[str, Any],
|
||||
deserialized_metadata: dict[str, Any],
|
||||
):
|
||||
for idx, arg in enumerate(args):
|
||||
if isinstance(arg, (TensorArgument, SymIntArgument, SymFloatArgument)):
|
||||
@ -2343,7 +2339,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
# return value.
|
||||
# This performs the inverse mapping of the `serialize_outputs` call in
|
||||
# serialization, see [NOTE: Multiple outputs]
|
||||
meta_val: List[Any] = []
|
||||
meta_val: list[Any] = []
|
||||
if len(serialized_node.outputs) == 1:
|
||||
assert isinstance(serialized_node.outputs[0].value, list)
|
||||
assert isinstance(serialized_node.outputs[0].value[0], TensorArgument)
|
||||
@ -2355,8 +2351,8 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
fx_node.meta["val"] = tuple(meta_val)
|
||||
self.serialized_name_to_node[fx_node.name] = fx_node
|
||||
|
||||
def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
|
||||
ret: Dict[str, Any] = {}
|
||||
def deserialize_metadata(self, metadata: dict[str, str]) -> dict[str, Any]:
|
||||
ret: dict[str, Any] = {}
|
||||
if stack_trace := metadata.get("stack_trace"):
|
||||
ret["stack_trace"] = stack_trace
|
||||
|
||||
@ -2446,8 +2442,8 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
)
|
||||
|
||||
def deserialize_module_call_graph(
|
||||
self, module_call_graph: List[ModuleCallEntry]
|
||||
) -> List[ep.ModuleCallEntry]:
|
||||
self, module_call_graph: list[ModuleCallEntry]
|
||||
) -> list[ep.ModuleCallEntry]:
|
||||
return [
|
||||
ep.ModuleCallEntry(
|
||||
fqn=entry.fqn,
|
||||
@ -2463,8 +2459,8 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
|
||||
@final
|
||||
class ExportedProgramDeserializer(metaclass=Final):
|
||||
def __init__(self, expected_opset_version: Optional[Dict[str, int]] = None):
|
||||
self.expected_opset_version: Dict[str, int] = {}
|
||||
def __init__(self, expected_opset_version: Optional[dict[str, int]] = None):
|
||||
self.expected_opset_version: dict[str, int] = {}
|
||||
if expected_opset_version:
|
||||
self.expected_opset_version.update(expected_opset_version)
|
||||
if "aten" not in self.expected_opset_version:
|
||||
@ -2472,9 +2468,9 @@ class ExportedProgramDeserializer(metaclass=Final):
|
||||
|
||||
def deserialize_range_constraints(
|
||||
self,
|
||||
symbol_name_to_range: Dict[str, symbolic_shapes.ValueRanges],
|
||||
symbol_name_to_symbol: Dict[str, sympy.Symbol],
|
||||
) -> Dict[sympy.Symbol, ValueRanges]:
|
||||
symbol_name_to_range: dict[str, symbolic_shapes.ValueRanges],
|
||||
symbol_name_to_symbol: dict[str, sympy.Symbol],
|
||||
) -> dict[sympy.Symbol, ValueRanges]:
|
||||
range_constraints = {}
|
||||
for k, v in symbol_name_to_range.items():
|
||||
if symbol := symbol_name_to_symbol.get(k):
|
||||
@ -2486,9 +2482,9 @@ class ExportedProgramDeserializer(metaclass=Final):
|
||||
def deserialize(
|
||||
self,
|
||||
exported_program: ExportedProgram,
|
||||
state_dict: Union[Dict[str, torch.Tensor], bytes],
|
||||
constants: Union[Dict[str, torch.Tensor], bytes],
|
||||
example_inputs: Optional[Union[tuple[tuple[torch.Tensor, ...], Dict[str, Any]], bytes]] = None,
|
||||
state_dict: Union[dict[str, torch.Tensor], bytes],
|
||||
constants: Union[dict[str, torch.Tensor], bytes],
|
||||
example_inputs: Optional[Union[tuple[tuple[torch.Tensor, ...], dict[str, Any]], bytes]] = None,
|
||||
*,
|
||||
_unsafe_skip_version_check=False,
|
||||
) -> ep.ExportedProgram:
|
||||
@ -2566,7 +2562,7 @@ def _dataclass_to_dict(obj):
|
||||
|
||||
def serialize(
|
||||
exported_program: ep.ExportedProgram,
|
||||
opset_version: Optional[Dict[str, int]] = None,
|
||||
opset_version: Optional[dict[str, int]] = None,
|
||||
pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
|
||||
) -> SerializedArtifact:
|
||||
with _enable_graph_inputs_of_type_nn_module(exported_program.example_inputs):
|
||||
@ -2627,7 +2623,7 @@ def _dict_to_dataclass(cls, data):
|
||||
|
||||
def deserialize(
|
||||
artifact: SerializedArtifact,
|
||||
expected_opset_version: Optional[Dict[str, int]] = None,
|
||||
expected_opset_version: Optional[dict[str, int]] = None,
|
||||
*,
|
||||
_unsafe_skip_version_check=False,
|
||||
) -> ep.ExportedProgram:
|
||||
@ -2649,7 +2645,7 @@ def deserialize(
|
||||
|
||||
def _canonicalize_graph(
|
||||
sorted_inputs, sorted_outputs, graph
|
||||
) -> tuple[Graph, Dict[str, str]]:
|
||||
) -> tuple[Graph, dict[str, str]]:
|
||||
def _get_argument(a: Argument):
|
||||
if a.type == "as_none":
|
||||
return None
|
||||
@ -2712,15 +2708,15 @@ def _canonicalize_graph(
|
||||
def sort_nodes(nodes):
|
||||
@dataclass
|
||||
class Edges:
|
||||
outs: List[int]
|
||||
outs: list[int]
|
||||
ins: int
|
||||
|
||||
graph_inputs: Set[str] = set()
|
||||
def_table: Dict[str, int] = {}
|
||||
edges: Dict[int, Edges] = {}
|
||||
candidates: List[tuple[str, List[tuple[str, List[int]]], int]] = []
|
||||
rank: Dict[str, int] = {}
|
||||
ret: List[Node] = []
|
||||
graph_inputs: set[str] = set()
|
||||
def_table: dict[str, int] = {}
|
||||
edges: dict[int, Edges] = {}
|
||||
candidates: list[tuple[str, list[tuple[str, list[int]]], int]] = []
|
||||
rank: dict[str, int] = {}
|
||||
ret: list[Node] = []
|
||||
|
||||
def get_name(a) -> Optional[str]:
|
||||
if a is None:
|
||||
@ -2827,7 +2823,7 @@ def _canonicalize_graph(
|
||||
assert len(sorted_nodes) == len(graph.nodes)
|
||||
|
||||
# Stage 2: Rename nodes.
|
||||
name_table: Dict[str, str] = {}
|
||||
name_table: dict[str, str] = {}
|
||||
|
||||
def rename_def(a):
|
||||
def _rename(arg_name, values):
|
||||
@ -3163,8 +3159,8 @@ class ExtensionHandler:
|
||||
|
||||
|
||||
def register_extension(
|
||||
op_type: Type[Any],
|
||||
extension_handler: Type[ExtensionHandler],
|
||||
op_type: type[Any],
|
||||
extension_handler: type[ExtensionHandler],
|
||||
):
|
||||
"""Register custom de/serialization method for a node with non-standard type."""
|
||||
assert issubclass(extension_handler, ExtensionHandler), f"Expected ExtensionHandler, got {extension_handler}."
|
||||
@ -3187,5 +3183,5 @@ def _registered_extension_types():
|
||||
# namespace to avoid conflicts.
|
||||
# Serialization: Op type --> custom handler.
|
||||
# De-serialization: Namespace --> custom handler.
|
||||
_serialization_registry: Dict[Type[Any], Type[ExtensionHandler]] = {}
|
||||
_deserialization_registry: Dict[str, Type[ExtensionHandler]] = {}
|
||||
_serialization_registry: dict[type[Any], type[ExtensionHandler]] = {}
|
||||
_deserialization_registry: dict[str, type[ExtensionHandler]] = {}
|
||||
|
@ -1,7 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import functools
|
||||
from collections.abc import Hashable
|
||||
from dataclasses import fields
|
||||
from typing import Hashable, Set
|
||||
|
||||
|
||||
class _UnionTag(str):
|
||||
@ -26,8 +26,8 @@ class _UnionTag(str):
|
||||
return hash(str(self))
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _get_field_names(cls) -> Set[str]:
|
||||
@functools.cache
|
||||
def _get_field_names(cls) -> set[str]:
|
||||
return {f.name for f in fields(cls)}
|
||||
|
||||
|
||||
|
@ -1,7 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any, Dict, Iterable, Optional
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.export
|
||||
@ -18,8 +19,8 @@ def _generate_inputs_for_submodules(
|
||||
model: torch.nn.Module,
|
||||
target_submodules: Iterable[str],
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, tuple[Any, Any]]:
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> dict[str, tuple[Any, Any]]:
|
||||
"""
|
||||
Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this
|
||||
function doesn't work.
|
||||
@ -61,11 +62,11 @@ def _generate_inputs_for_submodules(
|
||||
def report_exportability(
|
||||
mod: torch.nn.Module,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
*,
|
||||
strict: bool = True,
|
||||
pre_dispatch: bool = False,
|
||||
) -> Dict[str, Optional[Exception]]:
|
||||
) -> dict[str, Optional[Exception]]:
|
||||
"""
|
||||
Report exportability issues for a module in one-shot.
|
||||
|
||||
@ -92,7 +93,7 @@ def report_exportability(
|
||||
submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs)
|
||||
|
||||
tried_module_types = set()
|
||||
report: Dict[str, Optional[Exception]] = {}
|
||||
report: dict[str, Optional[Exception]] = {}
|
||||
|
||||
def try_export(module, module_name, args, kwargs):
|
||||
nonlocal submod_inputs, report, strict, pre_dispatch, tried_module_types
|
||||
|
@ -8,21 +8,10 @@ import json
|
||||
import math
|
||||
import operator
|
||||
import re
|
||||
from collections.abc import Iterable
|
||||
from contextlib import contextmanager
|
||||
from inspect import Parameter
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from torch._guards import detect_fake_mode
|
||||
@ -116,7 +105,7 @@ def _overwrite_signature_for_non_persistent_buffers(
|
||||
return new_sig
|
||||
|
||||
|
||||
def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> Dict[str, Any]:
|
||||
def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> dict[str, Any]:
|
||||
"""
|
||||
Param/buffer metadata needs to be saved before lowering to aten IR
|
||||
because aten IR lifts them, as a result, automatic preservation doesn't work.
|
||||
@ -174,7 +163,7 @@ def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> Dict[str, Any]:
|
||||
|
||||
|
||||
def _populate_param_buffer_metadata_to_new_gm(
|
||||
params_buffers_to_node_meta: Dict[str, Any],
|
||||
params_buffers_to_node_meta: dict[str, Any],
|
||||
gm: torch.fx.GraphModule,
|
||||
new_sig: "ExportGraphSignature",
|
||||
) -> None:
|
||||
@ -217,7 +206,7 @@ def _get_shape_env_from_gm(gm: torch.fx.GraphModule):
|
||||
|
||||
|
||||
def _rename_without_collisions(
|
||||
name_map: Dict[str, str],
|
||||
name_map: dict[str, str],
|
||||
orig_name: str,
|
||||
name: str,
|
||||
is_placeholder: bool = False,
|
||||
@ -246,7 +235,7 @@ def _rename_without_collisions(
|
||||
|
||||
|
||||
def _check_input_constraints_for_graph(
|
||||
input_placeholders: List[torch.fx.Node], flat_args_with_path, range_constraints
|
||||
input_placeholders: list[torch.fx.Node], flat_args_with_path, range_constraints
|
||||
) -> None:
|
||||
def get_keystr(key_path: KeyPath) -> str:
|
||||
"""For a given index into the flat_args, return a human readable string
|
||||
@ -280,7 +269,7 @@ def _check_input_constraints_for_graph(
|
||||
# NOTE: export already guarantees that the same symbol is used in metadata
|
||||
# for all InputDims related by equality constraints, so we can just unify
|
||||
# symbols with given input dimension values to check equality constraints.
|
||||
unification_map: Dict[sympy.Symbol, Any] = {}
|
||||
unification_map: dict[sympy.Symbol, Any] = {}
|
||||
for (key_path, arg), node in zip(flat_args_with_path, input_placeholders):
|
||||
node_val = node.meta.get("val")
|
||||
if isinstance(node_val, FakeTensor):
|
||||
@ -372,7 +361,7 @@ def _check_input_constraints_for_graph(
|
||||
|
||||
|
||||
def register_dataclass_as_pytree_node(
|
||||
cls: Type[Any],
|
||||
cls: type[Any],
|
||||
flatten_fn: Optional[FlattenFunc] = None,
|
||||
unflatten_fn: Optional[UnflattenFunc] = None,
|
||||
*,
|
||||
@ -385,7 +374,7 @@ def register_dataclass_as_pytree_node(
|
||||
cls
|
||||
), f"Only dataclasses can be registered with this function: {cls}"
|
||||
|
||||
def default_flatten_fn(obj: Any) -> tuple[List[Any], Context]:
|
||||
def default_flatten_fn(obj: Any) -> tuple[list[Any], Context]:
|
||||
flattened = []
|
||||
flat_names = []
|
||||
none_names = []
|
||||
@ -402,7 +391,7 @@ def register_dataclass_as_pytree_node(
|
||||
flat_names, none_names = context
|
||||
return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))
|
||||
|
||||
def default_flatten_fn_with_keys(obj: Any) -> tuple[List[Any], Context]:
|
||||
def default_flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]:
|
||||
flattened, (flat_names, _none_names) = flatten_fn(obj) # type: ignore[misc]
|
||||
return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names
|
||||
|
||||
@ -537,7 +526,7 @@ def sequential_split(
|
||||
return new_gm
|
||||
|
||||
|
||||
def nodes_filter(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]:
|
||||
def nodes_filter(nodes: list[torch.fx.Node], node_call_back) -> list[torch.fx.Node]:
|
||||
"""Returns the nodes that match the node_call_back as a list."""
|
||||
return [node for node in nodes if node_call_back(node)]
|
||||
|
||||
@ -572,7 +561,7 @@ def apply_runtime_assertion_pass(gm: torch.fx.GraphModule, graph_signature):
|
||||
|
||||
|
||||
def nodes_first(
|
||||
nodes: List[torch.fx.Node], node_call_back=None
|
||||
nodes: list[torch.fx.Node], node_call_back=None
|
||||
) -> Optional[torch.fx.Node]:
|
||||
"""
|
||||
Returns the first node that matches the node_call_back. If no node matches, returns None.
|
||||
@ -584,12 +573,12 @@ def nodes_first(
|
||||
return None
|
||||
|
||||
|
||||
def nodes_count(nodes: List[torch.fx.Node], node_call_back) -> int:
|
||||
def nodes_count(nodes: list[torch.fx.Node], node_call_back) -> int:
|
||||
"""Returns the number of nodes that match the node_call_back."""
|
||||
return len(nodes_filter(nodes, node_call_back))
|
||||
|
||||
|
||||
def nodes_map(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]:
|
||||
def nodes_map(nodes: list[torch.fx.Node], node_call_back) -> list[torch.fx.Node]:
|
||||
"""
|
||||
Sequentially visit the nodes list and invoke node_call_back on each element.
|
||||
Returns the nodes list after the node_call_back is invoked on each element.
|
||||
@ -748,7 +737,7 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None:
|
||||
and gather the top-level named placeholder nodes.
|
||||
"""
|
||||
# gather all HOO subgraphs and their top-level named placeholder nodes
|
||||
subgraph_ph_tuples: List[tuple[torch.fx.GraphModule, List[torch.fx.Node]]] = []
|
||||
subgraph_ph_tuples: list[tuple[torch.fx.GraphModule, list[torch.fx.Node]]] = []
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "call_function" and isinstance(
|
||||
node.target, torch._ops.HigherOrderOperator
|
||||
@ -769,7 +758,7 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None:
|
||||
|
||||
# propagate names
|
||||
for subgraph, hoo_phs in subgraph_ph_tuples:
|
||||
name_map: Dict[str, str] = {}
|
||||
name_map: dict[str, str] = {}
|
||||
for i, node in enumerate(subgraph.graph.nodes):
|
||||
if i < len(hoo_phs): # placeholder, retain name
|
||||
name_map[node.name] = hoo_phs[i].name
|
||||
@ -789,7 +778,7 @@ def placeholder_naming_pass(
|
||||
fake_args,
|
||||
fake_kwargs,
|
||||
fake_params_buffers,
|
||||
constants: Dict[str, Any],
|
||||
constants: dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
This pass is run at the end of _export_non_strict() to assign better placeholder node names:
|
||||
@ -828,7 +817,7 @@ def placeholder_naming_pass(
|
||||
else:
|
||||
raise RuntimeError(f"Pytree key of type {type(x)} not handled for {x}")
|
||||
|
||||
name_map: Dict[str, str] = {}
|
||||
name_map: dict[str, str] = {}
|
||||
|
||||
# map user input names with mod.forward() signature
|
||||
combined_args = _bind_signature_to_inputs(mod, fake_args, fake_kwargs)
|
||||
@ -927,7 +916,7 @@ def placeholder_naming_pass(
|
||||
del constants[name]
|
||||
|
||||
|
||||
def remove_proxy_from_state_dict(state_dict: Dict, in_place: bool) -> Dict:
|
||||
def remove_proxy_from_state_dict(state_dict: dict, in_place: bool) -> dict:
|
||||
"""
|
||||
If `in_place` is false, return a new copy of `state_dict` with "proxy" removed from `v.__dict__`.
|
||||
`v` is the values in the dictionary.
|
||||
@ -957,8 +946,8 @@ def _detect_fake_mode_from_gm(
|
||||
If no fake mode is found, we return None for fake_mode.
|
||||
"""
|
||||
|
||||
fake_inps: List[torch.Tensor] = []
|
||||
fake_vals: List[torch.Tensor] = []
|
||||
fake_inps: list[torch.Tensor] = []
|
||||
fake_vals: list[torch.Tensor] = []
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "placeholder" and "val" in node.meta:
|
||||
fake_val = node.meta["val"]
|
||||
@ -980,8 +969,8 @@ def _detect_fake_mode_from_gm(
|
||||
|
||||
@contextmanager
|
||||
def _disable_load_state_dict_hooks(mod: torch.nn.Module):
|
||||
state_dict_hooks: Dict[int, Callable] = dict(mod._state_dict_hooks)
|
||||
state_dict_pre_hooks: Dict[int, Callable] = dict(mod._state_dict_pre_hooks)
|
||||
state_dict_hooks: dict[int, Callable] = dict(mod._state_dict_hooks)
|
||||
state_dict_pre_hooks: dict[int, Callable] = dict(mod._state_dict_pre_hooks)
|
||||
mod._state_dict_hooks.clear()
|
||||
mod._state_dict_pre_hooks.clear()
|
||||
try:
|
||||
@ -1075,11 +1064,11 @@ def _check_valid_to_preserve(op_overload: "OperatorBase"):
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _collect_all_valid_cia_ops_for_aten_namespace() -> Set["OperatorBase"]:
|
||||
def _collect_all_valid_cia_ops_for_aten_namespace() -> set["OperatorBase"]:
|
||||
return _collect_all_valid_cia_ops_for_namespace("aten")
|
||||
|
||||
|
||||
def _collect_all_valid_cia_ops_for_namespace(namespace: str) -> Set["OperatorBase"]:
|
||||
def _collect_all_valid_cia_ops_for_namespace(namespace: str) -> set["OperatorBase"]:
|
||||
# Step 1: Materialize all ops from C++ dispatcher
|
||||
_materialize_cpp_cia_ops()
|
||||
|
||||
@ -1096,7 +1085,7 @@ def _collect_all_valid_cia_ops_for_namespace(namespace: str) -> Set["OperatorBas
|
||||
return cia_ops
|
||||
|
||||
|
||||
def _collect_all_valid_cia_ops() -> Set["OperatorBase"]:
|
||||
def _collect_all_valid_cia_ops() -> set["OperatorBase"]:
|
||||
"""
|
||||
This is an util function that gets the all CIA functional ops.
|
||||
|
||||
@ -1166,14 +1155,14 @@ def _compiling_state_context():
|
||||
def _fakify_params_buffers(
|
||||
fake_mode: FakeTensorMode,
|
||||
mod: torch.nn.Module,
|
||||
) -> Dict[str, Union[torch.Tensor, torch.nn.Parameter]]:
|
||||
) -> dict[str, Union[torch.Tensor, torch.nn.Parameter]]:
|
||||
params_buffers = {
|
||||
**dict(mod.named_parameters(remove_duplicate=False)),
|
||||
**dict(mod.named_buffers(remove_duplicate=False)),
|
||||
}
|
||||
|
||||
faked_params_buffers = {}
|
||||
memo: Dict[int, FakeTensor] = {}
|
||||
memo: dict[int, FakeTensor] = {}
|
||||
for key, value in params_buffers.items():
|
||||
if id(value) in memo:
|
||||
fake_tensor = memo[id(value)]
|
||||
@ -1184,7 +1173,7 @@ def _fakify_params_buffers(
|
||||
return faked_params_buffers # type: ignore[return-value]
|
||||
|
||||
|
||||
def register_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
|
||||
def register_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None:
|
||||
"""
|
||||
Registers a module as a valid input type for :func:`torch.export.export`.
|
||||
|
||||
@ -1233,7 +1222,7 @@ def register_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
|
||||
def __deepcopy__(self, memo):
|
||||
return PrototypeModule(self())
|
||||
|
||||
def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
|
||||
def default_flatten_fn(obj: Any) -> tuple[list[Any], Context]:
|
||||
named_parameters = dict(obj.named_parameters())
|
||||
named_buffers = dict(obj.named_buffers())
|
||||
params_buffers = {**named_parameters, **named_buffers}
|
||||
@ -1270,7 +1259,7 @@ def register_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
|
||||
ret = obj
|
||||
return ret
|
||||
|
||||
def default_flatten_fn_with_keys(obj: Any) -> Tuple[List[Any], Context]:
|
||||
def default_flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]:
|
||||
flattened, [flat_names, *args] = flatten_fn(obj) # type: ignore[misc]
|
||||
return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], [
|
||||
flat_names,
|
||||
@ -1301,7 +1290,7 @@ def register_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
|
||||
from_dumpable_context=from_dumpable_context,
|
||||
)
|
||||
|
||||
def default_flatten_fn_spec(obj, spec) -> List[Any]:
|
||||
def default_flatten_fn_spec(obj, spec) -> list[Any]:
|
||||
flats, context = flatten_fn(obj)
|
||||
assert context == spec.context
|
||||
return flats
|
||||
@ -1312,6 +1301,6 @@ def register_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
|
||||
)
|
||||
|
||||
|
||||
def deregister_module_as_pytree_input_node(cls: Type[torch.nn.Module]) -> None:
|
||||
def deregister_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None:
|
||||
_deregister_pytree_node(cls)
|
||||
_deregister_pytree_flatten_spec(cls)
|
||||
|
@ -4,7 +4,7 @@ import inspect
|
||||
import math
|
||||
import operator
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Dict, final, List, Type, TYPE_CHECKING
|
||||
from typing import Any, final, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch._ops import HigherOrderOperator, OpOverload
|
||||
@ -83,7 +83,7 @@ def _check_torch_fn(node: torch.fx.Node) -> None:
|
||||
raise SpecViolationError(f"Node.meta {node.name} has invalid torch_fn field {torch_fn}")
|
||||
|
||||
class _VerifierMeta(type):
|
||||
_registry: Dict[str, Type['Verifier']] = {}
|
||||
_registry: dict[str, type['Verifier']] = {}
|
||||
|
||||
def __new__(metacls, name, bases, attrs):
|
||||
if bases:
|
||||
@ -113,7 +113,7 @@ def getattr_recursive(obj: Any, target: str) -> Any:
|
||||
class Verifier(metaclass=_VerifierMeta):
|
||||
dialect = "ATEN"
|
||||
|
||||
def allowed_builtin_ops(self) -> List:
|
||||
def allowed_builtin_ops(self) -> list:
|
||||
return [
|
||||
operator.getitem,
|
||||
operator.add,
|
||||
@ -141,10 +141,10 @@ class Verifier(metaclass=_VerifierMeta):
|
||||
builtins.getattr,
|
||||
]
|
||||
|
||||
def allowed_op_types(self) -> tuple[Type[Any], ...]:
|
||||
def allowed_op_types(self) -> tuple[type[Any], ...]:
|
||||
return (OpOverload, HigherOrderOperator)
|
||||
|
||||
def allowed_getattr_types(self) -> tuple[Type[Any], ...]:
|
||||
def allowed_getattr_types(self) -> tuple[type[Any], ...]:
|
||||
return (torch.fx.GraphModule,)
|
||||
|
||||
def check_valid_op(self, op):
|
||||
@ -163,18 +163,18 @@ class Verifier(metaclass=_VerifierMeta):
|
||||
|
||||
@final
|
||||
def _check_graph_module(self, gm: torch.fx.GraphModule) -> None:
|
||||
def _allowed_getattr_types() -> tuple[Type[Any], ...]:
|
||||
def _allowed_getattr_types() -> tuple[type[Any], ...]:
|
||||
ret = self.allowed_getattr_types()
|
||||
assert not any(t is object for t in ret)
|
||||
return ret
|
||||
|
||||
def _check_valid_op(op) -> None:
|
||||
def _allowed_builtin_ops() -> List:
|
||||
def _allowed_builtin_ops() -> list:
|
||||
ret = self.allowed_builtin_ops()
|
||||
assert all(inspect.isbuiltin(op) for op in ret)
|
||||
return ret
|
||||
|
||||
def _allowed_op_types() -> tuple[Type[Any], ...]:
|
||||
def _allowed_op_types() -> tuple[type[Any], ...]:
|
||||
ret = self.allowed_op_types()
|
||||
assert not any(t is object for t in ret)
|
||||
return ret
|
||||
@ -426,7 +426,7 @@ def _verify_exported_program_signature(exported_program) -> None:
|
||||
|
||||
num_tokens = len(gs.output_tokens)
|
||||
end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens
|
||||
mutate_nodes: List[str] = output_nodes[num_tokens:end]
|
||||
mutate_nodes: list[str] = output_nodes[num_tokens:end]
|
||||
user_output_nodes = output_nodes[end:end + len(gs.user_outputs)]
|
||||
|
||||
for mutation_node in mutate_nodes:
|
||||
@ -458,7 +458,7 @@ def _verify_exported_program_signature(exported_program) -> None:
|
||||
)
|
||||
|
||||
|
||||
def load_verifier(dialect: str) -> Type[Verifier]:
|
||||
def load_verifier(dialect: str) -> type[Verifier]:
|
||||
if dialect == "ATEN" or dialect == "":
|
||||
return _VerifierMeta._registry.get(dialect, Verifier)
|
||||
return _VerifierMeta._registry[dialect]
|
||||
|
Reference in New Issue
Block a user