Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

PEP585 update - torch/export (#145165)

See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145165
Approved by: https://github.com/bobrenjc93

Committed by PyTorch MergeBot
parent 316808e4e9
commit b6c5562c1f
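The diff below is one mechanical change applied across torch/export: PEP 585 (available since Python 3.9) made the builtin container types subscriptable, so annotations can use `dict`, `list`, `tuple`, `set`, and `type` directly instead of the deprecated `typing.Dict`, `typing.List`, `typing.Tuple`, `typing.Set`, and `typing.Type` aliases. A minimal before/after sketch of the pattern (the function here is hypothetical, not taken from the diff):

# Before: PEP 484 generic aliases imported from typing.
from typing import Dict, Optional, Tuple

def describe(args: Tuple[int, ...], extra: Optional[Dict[str, int]] = None) -> Dict[str, int]:
    out = dict(extra or {})
    out["n_args"] = len(args)
    return out

# After: PEP 585 builtin generics; only Optional still comes from typing.
from typing import Optional

def describe_585(args: tuple[int, ...], extra: Optional[dict[str, int]] = None) -> dict[str, int]:
    out = dict(extra or {})
    out["n_args"] = len(args)
    return out

Runtime behavior is identical; only the annotation objects change (e.g. `dict[str, int]` is a `types.GenericAlias`), which is why every hunk below is annotation-only.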
@@ -8,12 +8,12 @@ import sys
 import typing
 import warnings
 import zipfile
+from collections.abc import Iterator
 from enum import auto, Enum
 from typing import (
     Any,
     Callable,
-    Dict,
-    Iterator,
-    List,
     Optional,
-    Tuple,
@@ -82,12 +82,12 @@ PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]

 def export_for_training(
     mod: torch.nn.Module,
-    args: Tuple[Any, ...],
-    kwargs: Optional[Dict[str, Any]] = None,
+    args: tuple[Any, ...],
+    kwargs: Optional[dict[str, Any]] = None,
     *,
-    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
     strict: bool = True,
-    preserve_module_call_signature: Tuple[str, ...] = (),
+    preserve_module_call_signature: tuple[str, ...] = (),
 ) -> ExportedProgram:
     """
     :func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing
@@ -177,13 +177,13 @@ def export_for_training(

 def export_for_inference(
     mod: torch.nn.Module,
-    args: Tuple[Any, ...],
-    kwargs: Optional[Dict[str, Any]] = None,
+    args: tuple[Any, ...],
+    kwargs: Optional[dict[str, Any]] = None,
     *,
-    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
     strict: bool = True,
-    preserve_module_call_signature: Tuple[str, ...] = (),
-    decomp_table: Optional[Dict["OpOverload", Optional[Callable]]] = None,
+    preserve_module_call_signature: tuple[str, ...] = (),
+    decomp_table: Optional[dict["OpOverload", Optional[Callable]]] = None,
 ) -> ExportedProgram:
     """
     :func:`export_for_inference` takes any nn.Module along with example inputs, and produces a traced graph representing
@@ -262,12 +262,12 @@ def export_for_inference(

 def export(
     mod: torch.nn.Module,
-    args: Tuple[Any, ...],
-    kwargs: Optional[Dict[str, Any]] = None,
+    args: tuple[Any, ...],
+    kwargs: Optional[dict[str, Any]] = None,
     *,
-    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
     strict: bool = True,
-    preserve_module_call_signature: Tuple[str, ...] = (),
+    preserve_module_call_signature: tuple[str, ...] = (),
 ) -> ExportedProgram:
     """
     :func:`export` takes any nn.Module along with example inputs, and produces a traced graph representing
@@ -383,8 +383,8 @@ def save(
     ep: ExportedProgram,
     f: Union[str, os.PathLike, io.BytesIO],
     *,
-    extra_files: Optional[Dict[str, Any]] = None,
-    opset_version: Optional[Dict[str, int]] = None,
+    extra_files: Optional[dict[str, Any]] = None,
+    opset_version: Optional[dict[str, int]] = None,
     pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
 ) -> None:
     """
@@ -466,8 +466,8 @@ def save(
 def load(
     f: Union[str, os.PathLike, io.BytesIO],
     *,
-    extra_files: Optional[Dict[str, Any]] = None,
-    expected_opset_version: Optional[Dict[str, int]] = None,
+    extra_files: Optional[dict[str, Any]] = None,
+    expected_opset_version: Optional[dict[str, int]] = None,
 ) -> ExportedProgram:
     """

@@ -577,7 +577,7 @@ def load(


 def register_dataclass(
-    cls: Type[Any],
+    cls: type[Any],
     *,
     serialized_type_name: Optional[str] = None,
 ) -> None:
@@ -2,7 +2,7 @@ import inspect
 import logging
 import os
 from enum import IntEnum
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union

 import torch
 import torch._logging._internal
@@ -26,7 +26,7 @@ class FailureType(IntEnum):
         return self.name


-def prettify_stack(stack: List[Dict[str, str]], str_to_filename: Dict[str, str]) -> str:
+def prettify_stack(stack: list[dict[str, str]], str_to_filename: dict[str, str]) -> str:
     res = ""
     for frame in stack:
         if frame["filename"] not in str_to_filename:
@@ -38,8 +38,8 @@ def prettify_stack(stack: List[Dict[str, str]], str_to_filename: Dict[str, str])


 def filter_stack(
-    stack: List[Dict[str, str]], str_to_filename: Dict[str, str]
-) -> List[Dict[str, str]]:
+    stack: list[dict[str, str]], str_to_filename: dict[str, str]
+) -> list[dict[str, str]]:
     for i, s in enumerate(reversed(stack)):
         s["filename"] = str(s["filename"])
         if s["filename"] not in str_to_filename:
@@ -50,22 +50,22 @@ def filter_stack(
     return stack[-3:]


-def hash_stack(stack: List[Dict[str, str]]) -> str:
+def hash_stack(stack: list[dict[str, str]]) -> str:
     return ";".join(f'line: {s["line"]} filename: {s["filename"]}' for s in stack)


 class FailureReport:
     def __init__(
-        self, failure_type: FailureType, data: Dict[str, Any], xfail: bool = False
+        self, failure_type: FailureType, data: dict[str, Any], xfail: bool = False
     ) -> None:
         self.failure_type: FailureType = failure_type
-        self.data: Dict[str, Any] = data
+        self.data: dict[str, Any] = data
         self.xfail: bool = xfail

     def __repr__(self) -> str:
         return f"FailureReport(failure_type={self.failure_type}, xfail={self.xfail}, data={self.data})"

-    def print(self, str_to_filename: Dict[str, str]) -> str:
+    def print(self, str_to_filename: dict[str, str]) -> str:
         if self.failure_type == FailureType.MISSING_FAKE_KERNEL:
             op = self.data["op"]

@@ -113,8 +113,8 @@ class FailureReport:


 class DraftExportReport:
-    def __init__(self, failures: List[FailureReport], str_to_filename: Dict[str, str]):
-        self.failures: List[FailureReport] = failures
+    def __init__(self, failures: list[FailureReport], str_to_filename: dict[str, str]):
+        self.failures: list[FailureReport] = failures
         self.str_to_filename = str_to_filename

     def successful(self) -> bool:
@@ -156,10 +156,10 @@ Please follow the instructions to fix the errors.


 class CaptureStructuredTrace(logging.Handler):
-    def __init__(self, specific_log_keys: List[str]):
+    def __init__(self, specific_log_keys: list[str]):
         super().__init__()
         self.specific_log_keys = specific_log_keys
-        self.logs: List[Tuple[str, Dict[str, Any]]] = []
+        self.logs: list[tuple[str, dict[str, Any]]] = []
         self.logger = logging.getLogger("torch.__trace")
         self.prev_get_dtrace = False

@@ -185,14 +185,14 @@ class CaptureStructuredTrace(logging.Handler):

 def draft_export(
     mod: torch.nn.Module,
-    args: Tuple[Any, ...],
-    kwargs: Optional[Dict[str, Any]] = None,
+    args: tuple[Any, ...],
+    kwargs: Optional[dict[str, Any]] = None,
     *,
-    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
-    preserve_module_call_signature: Tuple[str, ...] = (),
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
+    preserve_module_call_signature: tuple[str, ...] = (),
     strict: bool = False,
     pre_dispatch: bool = False,
-) -> Tuple[ExportedProgram, DraftExportReport]:
+) -> tuple[ExportedProgram, DraftExportReport]:
     kwargs = kwargs or {}
     dynamic_shapes = dynamic_shapes or {}

@@ -234,15 +234,15 @@ def draft_export(
         preserve_module_call_signature=preserve_module_call_signature,
     )

-    str_to_filename: Dict[str, str] = {
+    str_to_filename: dict[str, str] = {
         str(v): k for (k, v) in torch._logging.structured.INTERN_TABLE.items()
     }
-    failures: List[FailureReport] = []
-    custom_ops_logs: Dict[
-        Any, Tuple[Dict[str, Any], FailureType]
+    failures: list[FailureReport] = []
+    custom_ops_logs: dict[
+        Any, tuple[dict[str, Any], FailureType]
     ] = {}  # Dedup custom ops
-    data_dependent_logs: Dict[
-        str, Dict[str, Any]
+    data_dependent_logs: dict[
+        str, dict[str, Any]
     ] = {}  # Dedup data dependent errors based on stacktrace

     for log_name, log_contents in capture_structured_log.logs:
@@ -1,6 +1,5 @@
 # mypy: allow-untyped-defs
 import operator
-from typing import List

 import torch
 from torch._higher_order_ops.effects import _get_schema, with_effects
@@ -22,7 +21,7 @@ def _remove_effect_tokens_from_graph_helper(
     inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs

     output_node = None
-    with_effect_nodes: List[torch.fx.Node] = []
+    with_effect_nodes: list[torch.fx.Node] = []

     # Output node need to check its args agianst output_token_names (collected from output_spec)
     # Therefore, we only need to find the top-levele output node
@@ -127,8 +126,8 @@ def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
     This function does an inplace modification on the given ExportedProgram.
     """
     num_tokens: int = 0
-    input_token_names: List[str] = []
-    new_input_specs: List[InputSpec] = []
+    input_token_names: list[str] = []
+    new_input_specs: list[InputSpec] = []
     for inp in ep.graph_signature.input_specs:
         if inp.kind == InputKind.TOKEN:
             num_tokens += 1
@@ -138,8 +137,8 @@ def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
             new_input_specs.append(inp)

     num_out_tokens: int = 0
-    new_output_specs: List[OutputSpec] = []
-    output_token_names: List[OutputSpec] = []
+    new_output_specs: list[OutputSpec] = []
+    output_token_names: list[OutputSpec] = []
     for out in ep.graph_signature.output_specs:
         if out.kind == OutputKind.TOKEN:
             num_out_tokens += 1
@@ -2,7 +2,7 @@ import logging
 import operator
 import types
 from collections import defaultdict
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Optional

 import torch
 import torch.fx._pytree as fx_pytree
@@ -19,7 +19,7 @@ from torch.fx.passes.utils.fuser_utils import erase_nodes, fuse_as_graphmodule
 log = logging.getLogger(__name__)


-def _get_getitem_users(node: torch.fx.Node) -> Set[torch.fx.Node]:
+def _get_getitem_users(node: torch.fx.Node) -> set[torch.fx.Node]:
     node_users = list(node.users.keys())
     getitem_users = set()
     for user in node_users:
@@ -172,9 +172,9 @@ def _remove_extraneous_pytrees(gm: torch.fx.GraphModule) -> None:
 def _construct_inputs(
     gm: torch.fx.GraphModule,
     signature: ModuleCallSignature,
-    node_name_map: Dict[str, torch.fx.Node],
-) -> Tuple[List[torch.fx.Node], Dict[str, torch.fx.Node]]:
-    tree_unflatten_args: List[Optional[torch.fx.Node]] = []
+    node_name_map: dict[str, torch.fx.Node],
+) -> tuple[list[torch.fx.Node], dict[str, torch.fx.Node]]:
+    tree_unflatten_args: list[Optional[torch.fx.Node]] = []
     for input_ in signature.inputs:
         if isinstance(input_, ConstantArgument) and input_.value is None:
             # Constants should be directly embedded into the graph and not used
@@ -213,8 +213,8 @@ def _construct_inputs(

 def _insert_call_module(
     gm: torch.fx.GraphModule,
-    args_nodes: List[torch.fx.Node],
-    kwargs_nodes: Dict[str, torch.fx.Node],
+    args_nodes: list[torch.fx.Node],
+    kwargs_nodes: dict[str, torch.fx.Node],
     module_to_swap: torch.nn.Module,
     name: str,
 ) -> torch.fx.Node:
@@ -229,8 +229,8 @@ def _deconstruct_outputs(
     gm: torch.fx.GraphModule,
     signature: ModuleCallSignature,
     module_node: torch.fx.Node,
-    node_name_map: Dict[str, torch.fx.Node],
-    orig_outputs: Tuple[torch.fx.Node, ...],
+    node_name_map: dict[str, torch.fx.Node],
+    orig_outputs: tuple[torch.fx.Node, ...],
 ) -> None:
     from .unflatten import _generate_flatten_spec

@@ -246,17 +246,17 @@ def _deconstruct_outputs(

 def _swap_module_helper(
     gm: torch.fx.GraphModule,
-    modules_to_swap: Dict[str, torch.nn.Module],
-    module_call_graph: Dict[str, ModuleCallSignature],
+    modules_to_swap: dict[str, torch.nn.Module],
+    module_call_graph: dict[str, ModuleCallSignature],
 ) -> torch.fx.GraphModule:
     log.debug("Starting graph:")
     log.debug(gm.graph)

     legalize_graph(gm)

-    partitions: Dict[str, NodeList] = defaultdict(list)
+    partitions: dict[str, NodeList] = defaultdict(list)

-    node_name_map: Dict[str, torch.fx.Node] = {
+    node_name_map: dict[str, torch.fx.Node] = {
         node.name: node for node in gm.graph.nodes
     }

@@ -399,7 +399,7 @@ def _fix_input_output_signature(


 def _swap_modules(
-    ep: ExportedProgram, modules_to_swap: Dict[str, torch.nn.Module]
+    ep: ExportedProgram, modules_to_swap: dict[str, torch.nn.Module]
 ) -> torch.fx.GraphModule:
     """
     Unlifts the given ExportedProgram into a fx.GraphModule, and then swaps
@@ -10,7 +10,7 @@ import time
 import types
 import warnings
 from contextlib import contextmanager, nullcontext
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Optional, Union

 import torch
 import torch._dynamo
@@ -109,7 +109,7 @@ class ExportDynamoConfig:
     """

     allow_rnn: bool = True
-    reorderable_logging_functions: Set[Callable] = dataclasses.field(
+    reorderable_logging_functions: set[Callable] = dataclasses.field(
         default_factory=set
     )
     # Emit runtime asserts after AOTAutograd instead.
@@ -123,7 +123,7 @@ class ExportDynamoConfig:
 class ATenExportArtifact:
     gm: torch.fx.GraphModule
     sig: ExportGraphSignature
-    constants: Dict[
+    constants: dict[
         str,
         Union[
             torch.Tensor,
@@ -139,7 +139,7 @@ class ExportArtifact:
     in_spec: TreeSpec
     out_spec: TreeSpec
     fake_mode: FakeTensorMode
-    module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]]
+    module_call_specs: dict[str, dict[str, pytree.TreeSpec]]


 DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig()
@@ -223,8 +223,8 @@ def _extract_fake_inputs(gm, args, kwargs):
     Also return the fake mode used to fakify those inputs.
     """

-    fake_inps: List[torch.Tensor] = []
-    fake_vals: List[torch.Tensor] = []
+    fake_inps: list[torch.Tensor] = []
+    fake_vals: list[torch.Tensor] = []
     for node in gm.graph.nodes:
         if node.op == "placeholder" and "val" in node.meta:
             fake_val = node.meta["val"]
@@ -336,21 +336,21 @@ def _normalize_nn_module_stack(gm_torch_level, root_cls):
 def _get_param_buffer_mapping(
     original_module: torch.nn.Module,
     traced_module: torch.nn.Module,
-) -> Dict[str, str]:
+) -> dict[str, str]:
     """
     Returns a mapping of parameter/buffer names from the new module to the
     original model. This is to help with restoring the FQN for parameter/buffers
     of a traced module to what the original module contains.
     """

-    param_lookup: Dict[int, str] = {}
-    buffer_lookup: Dict[int, str] = {}
+    param_lookup: dict[int, str] = {}
+    buffer_lookup: dict[int, str] = {}
     for name, param in original_module.named_parameters(remove_duplicate=False):
         param_lookup[id(param)] = name
     for name, buffer in original_module.named_buffers(remove_duplicate=False):
         buffer_lookup[id(buffer)] = name

-    param_buffer_table: Dict[str, str] = {}
+    param_buffer_table: dict[str, str] = {}
     for dynamo_name, dynamo_param in traced_module.named_parameters(
         remove_duplicate=False
     ):
@@ -371,9 +371,9 @@ def _get_param_buffer_mapping(
 def _preserve_requires_grad_pass(
     gm: torch.fx.GraphModule,
     sig: ExportGraphSignature,
-    fake_params_buffers: Dict[str, torch.Tensor],
-    constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
-    flat_fake_args: List[Any],
+    fake_params_buffers: dict[str, torch.Tensor],
+    constants: dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
+    flat_fake_args: list[Any],
 ):
     placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
     assert len(sig.input_specs) == len(placeholders)
@@ -410,10 +410,10 @@ def _preserve_requires_grad_pass(
 def _remap_constants(
     orig_constant_attrs: ConstantAttrMap,
     graph_signature: ExportGraphSignature,
-    constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
+    constants: dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
 ) -> None:
     """Rewrite the graph signature and constants table to use the FQN from the original module."""
-    remap_table: Dict[str, List[str]] = {}
+    remap_table: dict[str, list[str]] = {}
     for name, value in constants.items():
         if value in orig_constant_attrs:
             remap_table[name] = orig_constant_attrs[value]
@@ -617,7 +617,7 @@ def _restore_state_dict(
     traced_module.recompile()


-def _get_module_hierarchy(mod: torch.nn.Module) -> Dict[str, str]:
+def _get_module_hierarchy(mod: torch.nn.Module) -> dict[str, str]:
     return {
         name: type(m).__name__ for name, m in mod.named_modules(remove_duplicate=False)
     }
@@ -626,9 +626,9 @@ def _get_module_hierarchy(mod: torch.nn.Module) -> Dict[str, str]:
 def _make_module_call_graph(
     in_spec: TreeSpec,
     out_spec: TreeSpec,
-    module_call_signatures: Dict[str, ModuleCallSignature],
-    forward_arg_names: Optional[List[str]] = None,
-) -> List[ModuleCallEntry]:
+    module_call_signatures: dict[str, ModuleCallSignature],
+    forward_arg_names: Optional[list[str]] = None,
+) -> list[ModuleCallEntry]:
     original = [
         ModuleCallEntry(fqn=fqn, signature=module_call_signatures.get(fqn))
         for fqn in _EXPORT_MODULE_HIERARCHY  # type: ignore[union-attr]
@@ -651,11 +651,11 @@ def _make_module_call_graph(

 def _export_to_torch_ir(
     f: Callable,
-    args: Tuple[Any, ...],
-    kwargs: Optional[Dict[str, Any]] = None,
-    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
+    args: tuple[Any, ...],
+    kwargs: Optional[dict[str, Any]] = None,
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
     *,
-    preserve_module_call_signature: Tuple[str, ...] = (),
+    preserve_module_call_signature: tuple[str, ...] = (),
     disable_constraint_solver: bool = False,
     allow_complex_guards_as_runtime_asserts: bool = False,
     restore_fqn: bool = True,
@@ -681,7 +681,7 @@ def _export_to_torch_ir(
     _check_dynamic_shapes(combined_args, dynamic_shapes)
     with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
         try:
-            module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
+            module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = {}
             ctx = nullcontext()
             if not isinstance(f, torch.fx.GraphModule):
                 ctx = _wrap_submodules(  # type: ignore[assignment]
@@ -814,9 +814,9 @@ def _export_to_aten_ir(

 def _get_forward_arg_names(
     mod: torch.nn.Module,
-    args: Tuple[Any, ...],
-    kwargs: Optional[Dict[str, Any]] = None,
-) -> List[str]:
+    args: tuple[Any, ...],
+    kwargs: Optional[dict[str, Any]] = None,
+) -> list[str]:
     """
     Gets the argument names to forward that are used, for restoring the
     original signature when unlifting the exported program module.
@@ -829,7 +829,7 @@ def _get_forward_arg_names(
     sig = inspect.signature(mod.forward)
     _args = sig.bind_partial(*args).arguments

-    names: List[str] = []
+    names: list[str] = []
     for name, value in _args.items():
         # handle variable number of positional args
         if sig.parameters[name].kind == inspect._ParameterKind.VAR_POSITIONAL:
@@ -843,7 +843,7 @@ def _get_forward_arg_names(
     return names


-def _get_non_persistent_buffers(mod: torch.nn.Module) -> Set[str]:
+def _get_non_persistent_buffers(mod: torch.nn.Module) -> set[str]:
     """
     Returns set of non-persistent buffers in a module and its submodules.
     """
@@ -855,10 +855,10 @@ def _get_non_persistent_buffers(mod: torch.nn.Module) -> Set[str]:


 def _rewrite_dynamo_tensor_constants(
-    orig_mod_buffers: Set[torch.Tensor],
-    traced_mod_buffers: Dict[str, torch.Tensor],
+    orig_mod_buffers: set[torch.Tensor],
+    traced_mod_buffers: dict[str, torch.Tensor],
     graph_signature: ExportGraphSignature,
-    constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
+    constants: dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
 ) -> None:
     """
     Dynamo erroneously marks tensor attributes on modules as buffers.
@@ -879,7 +879,7 @@ def _rewrite_dynamo_tensor_constants(
 def _move_non_persistent_buffers_to_tensor_constants(
     orig_mod: torch.nn.Module,
     graph_signature: ExportGraphSignature,
-    constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
+    constants: dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
 ) -> None:
     """
     Moves non-persistent buffers to tensor constants.
@@ -984,7 +984,7 @@ def _verify_placeholder_names(
     )


-def get_ep_stats(ep: ExportedProgram) -> Dict[str, Any]:
+def get_ep_stats(ep: ExportedProgram) -> dict[str, Any]:
     op_count = 0
     op_set = set()
     for m in ep.graph_module.modules():
@@ -1000,8 +1000,8 @@ def get_ep_stats(ep: ExportedProgram) -> Dict[str, Any]:
     return {"op_count": op_count, "op_set": op_set}


-_EXPORT_FLAGS: Optional[Set[str]] = None
-_EXPORT_MODULE_HIERARCHY: Optional[Dict[str, str]] = None
+_EXPORT_FLAGS: Optional[set[str]] = None
+_EXPORT_MODULE_HIERARCHY: Optional[dict[str, str]] = None


 def _log_export_wrapper(fn):
@@ -1065,7 +1065,7 @@ def _process_jit_trace_inputs_for_export(example_inputs, example_kwarg_inputs):
     return example_inputs, example_kwarg_inputs


-def _get_original_state_dict(mod: torch.nn.Module) -> Dict[str, Any]:
+def _get_original_state_dict(mod: torch.nn.Module) -> dict[str, Any]:
     # Explicitly not calling mode.state_dict() as we do not want the module state for serialization
     # but the running module state so we can always match by id() the entries here with the graph inputs
     named_parameters = dict(mod.named_parameters(remove_duplicate=False))
@@ -1096,24 +1096,24 @@ def _process_export_inputs(mod, args, kwargs, dynamic_shapes):

 def _get_module_call_graph(
     export_artifact: ExportArtifact,
-    preserve_module_call_signature: Tuple[str, ...],
+    preserve_module_call_signature: tuple[str, ...],
     strict_mode_export: bool,
-    forward_arg_names: Optional[List[str]] = None,
-) -> Tuple[torch.fx.GraphModule, List[ModuleCallEntry]]:
+    forward_arg_names: Optional[list[str]] = None,
+) -> tuple[torch.fx.GraphModule, list[ModuleCallEntry]]:
     """
     In-place modify the graph module in export_artifact, remove _export_tracepoint nodes and
     return module_call_graph.
     """
     gm: torch.fx.GraphModule = export_artifact.aten.gm
     export_graph_signature: ExportGraphSignature = export_artifact.aten.sig
-    module_call_specs: Dict[
-        str, Dict[str, TreeSpec]
+    module_call_specs: dict[
+        str, dict[str, TreeSpec]
     ] = export_artifact.module_call_specs
     in_spec: TreeSpec = export_artifact.in_spec
     out_spec: TreeSpec = export_artifact.out_spec

     # Make module signatures.
-    module_call_signatures: Dict[str, ModuleCallSignature] = {}
+    module_call_signatures: dict[str, ModuleCallSignature] = {}
     for fqn, specs in module_call_specs.items():
         mod_fqn = _strip_root(fqn) if not strict_mode_export else fqn
         module_call_signatures[mod_fqn] = ModuleCallSignature(
@@ -1142,7 +1142,7 @@ def _get_module_call_graph(


 def _get_range_constraints(
-    export_artifact: ExportArtifact, combined_args: Dict[str, Any], dynamic_shapes
+    export_artifact: ExportArtifact, combined_args: dict[str, Any], dynamic_shapes
 ):
     gm: torch.fx.GraphModule = export_artifact.aten.gm
     export_graph_signature: ExportGraphSignature = export_artifact.aten.sig
@@ -1250,12 +1250,12 @@ def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None):

 def _strict_export(
     mod: torch.nn.Module,
-    args: Tuple[Any, ...],
-    kwargs: Dict[str, Any],
-    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]],
-    preserve_module_call_signature: Tuple[str, ...],
+    args: tuple[Any, ...],
+    kwargs: dict[str, Any],
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]],
+    preserve_module_call_signature: tuple[str, ...],
     pre_dispatch: bool,
-    original_state_dict: Dict[str, Any],
+    original_state_dict: dict[str, Any],
     orig_in_spec: TreeSpec,
     allow_complex_guards_as_runtime_asserts: bool,
     _is_torch_jit_trace: bool,
@@ -1278,12 +1278,12 @@ def _strict_export(

 def _strict_export_lower_to_aten_ir(
     mod: torch.nn.Module,
-    args: Tuple[Any, ...],
-    kwargs: Dict[str, Any],
-    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]],
-    preserve_module_call_signature: Tuple[str, ...],
+    args: tuple[Any, ...],
+    kwargs: dict[str, Any],
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]],
+    preserve_module_call_signature: tuple[str, ...],
     pre_dispatch: bool,
-    original_state_dict: Dict[str, Any],
+    original_state_dict: dict[str, Any],
     orig_in_spec: TreeSpec,
     allow_complex_guards_as_runtime_asserts: bool,
     _is_torch_jit_trace: bool,
@@ -1364,7 +1364,7 @@ def _strict_export_lower_to_aten_ir(
     # params_buffers_to_node_meta = _collect_param_buffer_metadata(gm_torch_level)

     constant_attrs = _gather_constant_attrs(mod)
-    param_buffer_table: Dict[str, str] = _get_param_buffer_mapping(mod, gm_torch_level)
+    param_buffer_table: dict[str, str] = _get_param_buffer_mapping(mod, gm_torch_level)

     # Dynamo does not track which buffers were registered as non-persistent. This info
     # is available in the original module, so we transfer it to the traced module. Also,
@@ -1459,7 +1459,7 @@ def _export_to_aten_ir_make_fx(
         mod, params_spec, params_len, store_orig_mod=True
     )

-    params_buffers_args: List[Any] = []
+    params_buffers_args: list[Any] = []
     params_buffers_args.extend(params_and_buffers_flat)
     params_buffers_args.extend(args)

@@ -1480,7 +1480,7 @@ def _export_to_aten_ir_make_fx(

         # For any buffer that is assigned, we want to associate it to the final proxy node
         # that it is assigned to. This node can then be copied into the buffer.
-        assigned_buffers: Dict[str, str] = {}
+        assigned_buffers: dict[str, str] = {}
         hook = register_buffer_assignment_hook(
             non_strict_root, assigned_buffers
         )
@@ -1532,12 +1532,12 @@ def _export_to_aten_ir_make_fx(

     # Dictionary that tracks subclass type to original getattr function
     # and the attributes we can proxy.
-    tensor_type_to_old_getattribute: Dict[
-        Type[torch.Tensor], Tuple[Callable, Set[str]]
+    tensor_type_to_old_getattribute: dict[
+        type[torch.Tensor], tuple[Callable, set[str]]
     ] = {}
     for arg in args:
-        subclass_types_to_instances: Dict[
-            Type[torch.Tensor], List[Type[torch.Tensor]]
+        subclass_types_to_instances: dict[
+            type[torch.Tensor], list[type[torch.Tensor]]
         ] = get_subclass_typing_container(arg)
         for subclass_type in subclass_types_to_instances:
             if subclass_type not in tensor_type_to_old_getattribute:
@@ -1725,12 +1725,12 @@ def _find_node(gm: torch.fx.GraphModule, name: str) -> torch.fx.Node:

 def _non_strict_export(
     mod: torch.nn.Module,
-    args: Tuple[Any, ...],
-    kwargs: Dict[str, Any],
-    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]],
-    preserve_module_call_signature: Tuple[str, ...],
+    args: tuple[Any, ...],
+    kwargs: dict[str, Any],
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]],
+    preserve_module_call_signature: tuple[str, ...],
     pre_dispatch: bool,
-    original_state_dict: Dict[str, Any],
+    original_state_dict: dict[str, Any],
     orig_in_spec: TreeSpec,
     allow_complex_guards_as_runtime_asserts: bool,
     _is_torch_jit_trace: bool,
@@ -1744,7 +1744,7 @@ def _non_strict_export(
     out_spec: Optional[TreeSpec] = None
     in_spec: Optional[TreeSpec] = None

-    module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
+    module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = {}

     def _tuplify_outputs(aot_export):
         def _aot_export_non_strict(mod, args, kwargs=None, **flags):
@@ -1896,12 +1896,12 @@ def _non_strict_export(
 @_disable_prexisiting_fake_mode
 def _export_for_training(
     mod: torch.nn.Module,
-    args: Tuple[Any, ...],
-    kwargs: Optional[Dict[str, Any]] = None,
-    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
+    args: tuple[Any, ...],
+    kwargs: Optional[dict[str, Any]] = None,
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
     *,
     strict: bool = True,
-    preserve_module_call_signature: Tuple[str, ...] = (),
+    preserve_module_call_signature: tuple[str, ...] = (),
 ) -> ExportedProgram:
     global _EXPORT_MODULE_HIERARCHY
     _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)
@@ -1987,12 +1987,12 @@ def _export_for_training(
 @_disable_prexisiting_fake_mode
 def _export(
     mod: torch.nn.Module,
-    args: Tuple[Any, ...],
-    kwargs: Optional[Dict[str, Any]] = None,
-    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
+    args: tuple[Any, ...],
+    kwargs: Optional[dict[str, Any]] = None,
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
     *,
     strict: bool = True,
-    preserve_module_call_signature: Tuple[str, ...] = (),
+    preserve_module_call_signature: tuple[str, ...] = (),
     pre_dispatch: bool = False,
     allow_complex_guards_as_runtime_asserts: bool = False,
     _is_torch_jit_trace: bool = False,
@@ -1,9 +1,9 @@
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Optional

 from torch.utils._pytree import Context, TreeSpec


-def reorder_kwargs(user_kwargs: Dict[str, Any], spec: TreeSpec) -> Dict[str, Any]:
+def reorder_kwargs(user_kwargs: dict[str, Any], spec: TreeSpec) -> dict[str, Any]:
     """Reorder user-provided kwargs to match the order in `spec`. `spec` is
     expected to be the in_spec of an exported program, i.e. the spec that
     results from flattening `(args, kwargs)`.
@@ -1,8 +1,9 @@
 # mypy: allow-untyped-defs
 import copy
 import warnings
+from collections.abc import Sequence
 from itertools import chain
-from typing import Any, Dict, List, Optional, Sequence, Tuple
+from typing import Any, Optional

 import torch
 import torch.utils._pytree as pytree
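The hunk above also shows the companion change made in several files: abstract container types (`Sequence`, `Iterator`, `Collection`, `Mapping`) now come from `collections.abc` rather than `typing`, since the `typing` re-exports are deprecated together with the PEP 585 aliases. A small sketch of that import style (hypothetical function, not from this diff):

from collections.abc import Iterator, Sequence

def adjacent_pairs(xs: Sequence[int]) -> Iterator[tuple[int, int]]:
    # Sequence and Iterator are imported from collections.abc;
    # the tuple annotation uses the PEP 585 builtin generic.
    for a, b in zip(xs, xs[1:]):
        yield (a, b)

assert list(adjacent_pairs([1, 2, 3])) == [(1, 2), (2, 3)]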
@@ -25,7 +26,7 @@ from .exported_program import (
 )


-def _check_inputs_match(args, kwargs, in_spec: pytree.TreeSpec) -> List:
+def _check_inputs_match(args, kwargs, in_spec: pytree.TreeSpec) -> list:
     reordered_kwargs = reorder_kwargs(kwargs, in_spec)
     flat_args_with_path, received_spec = pytree.tree_flatten_with_path(
         (args, reordered_kwargs)
@@ -61,7 +62,7 @@ def _check_input_constraints_pre_hook(self, args, kwargs):
 def _unlift_inputs_as_getattr(
     gm: torch.fx.GraphModule,
     lifted_inputs: Sequence[Optional[str]],
-) -> Tuple[Dict[str, torch.fx.Node], Dict[str, torch.fx.Node]]:
+) -> tuple[dict[str, torch.fx.Node], dict[str, torch.fx.Node]]:
     """
     Unlift inputs referring to params/buffers/constants as getattr nodes in the
     graph
@@ -90,8 +91,8 @@ def _unlift_inputs_as_getattr(
 def _insert_copy_for_mutations(
     gm: torch.fx.GraphModule,
     mutated_outputs: Sequence[Optional[str]],
-    unlifted_name_to_node: Dict[str, torch.fx.Node],
-    input_name_to_node: Dict[str, torch.fx.Node],
+    unlifted_name_to_node: dict[str, torch.fx.Node],
+    input_name_to_node: dict[str, torch.fx.Node],
 ) -> None:
     """
     Find the all the buffers and inputs that were mutated and insert copy_
@@ -144,7 +145,7 @@ def _insert_copy_for_mutations(
 def _get_codegen(
     in_spec: pytree.TreeSpec,
     out_spec: Optional[pytree.TreeSpec],
-    forward_arg_names: Optional[List[str]] = None,
+    forward_arg_names: Optional[list[str]] = None,
 ) -> _PyTreeCodeGen:
     """
     Create the codegen for the graph module based on the in/out specs
@@ -180,9 +181,9 @@ def _unlift(
     mutated_outputs: Sequence[Optional[str]],
     in_spec: pytree.TreeSpec,
     out_spec: Optional[pytree.TreeSpec],
-    state_dict: Dict[str, Any],
-    constants: Dict[str, Any],
-    forward_arg_names: Optional[List[str]] = None,
+    state_dict: dict[str, Any],
+    constants: dict[str, Any],
+    forward_arg_names: Optional[list[str]] = None,
 ):
     """
     Args:
@@ -214,8 +215,8 @@ def _unlift(
 def _register_attrs_to_new_gm(
     new_gm: torch.fx.GraphModule,
     graph_signature: ExportGraphSignature,
-    state_dict: Dict[str, Any],
-    constants: Dict[str, Any],
+    state_dict: dict[str, Any],
+    constants: dict[str, Any],
 ) -> None:
     non_persistent_buffers = set(graph_signature.non_persistent_buffers)
     for name in graph_signature.buffers:
@@ -388,7 +389,7 @@ def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Modu
     forward_arg_names = (
         sig.forward_arg_names if (sig := ep.module_call_graph[0].signature) else None
     )
-    lifted_inputs: List[Optional[str]] = [
+    lifted_inputs: list[Optional[str]] = [
         (
             in_spec.target
             if in_spec.kind
@@ -403,7 +404,7 @@ def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Modu
         for in_spec in ep.graph_signature.input_specs
     ]

-    mutated_outputs: List[Optional[str]] = [
+    mutated_outputs: list[Optional[str]] = [
         (
             out_spec.target
             if out_spec.kind
@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing import Callable, Dict
+from typing import Callable

 import torch
 from torch._export.utils import (
@@ -13,7 +13,7 @@ from torch._export.utils import (
 __all__ = ["CustomDecompTable"]


-class CustomDecompTable(Dict[torch._ops.OperatorBase, Callable]):
+class CustomDecompTable(dict[torch._ops.OperatorBase, Callable]):
     """
     This is a custom dictionary that is specifically used for handling decomp_table in export.
     The reason we need this is because in the new world, you can only *delete* an op from decomp
@@ -126,7 +126,7 @@ class CustomDecompTable(Dict[torch._ops.OperatorBase, Callable]):
         self._materialize_if_needed()
         return self.decomp_table.items()

-    def materialize(self) -> Dict[torch._ops.OperatorBase, Callable]:
+    def materialize(self) -> dict[torch._ops.OperatorBase, Callable]:
         for op in _collect_all_valid_cia_ops():
             if _is_aten_op(op):
                 continue
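Note that the `CustomDecompTable` hunks above apply PEP 585 to a base class as well: on Python 3.9+ a parameterized builtin such as `dict[K, V]` is a valid base class, so `typing.Dict` is unnecessary even there. A hedged sketch of the same pattern with a hypothetical class:

from collections.abc import Callable

# dict[str, Callable[[int], int]] is a types.GenericAlias and can be subclassed.
class HandlerTable(dict[str, Callable[[int], int]]):
    def register(self, name: str, fn: Callable[[int], int]) -> None:
        self[name] = fn

table = HandlerTable()
table.register("double", lambda x: x * 2)
assert table["double"](3) == 6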
@@ -5,7 +5,7 @@ import logging
 import sys
 from collections import defaultdict
 from enum import auto, Enum
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union

 import torch
 from torch.utils._pytree import (
@@ -242,7 +242,7 @@ Dim.DYNAMIC = _DimHint.DYNAMIC  # type: ignore[attr-defined]

 def dims(
     *names: str, min: Optional[int] = None, max: Optional[int] = None
-) -> Tuple[_Dim, ...]:
+) -> tuple[_Dim, ...]:
     """
     Util to create multiple :func:`Dim` types.

@@ -401,13 +401,13 @@ Constraint = Union[_Constraint, _DerivedConstraint, _RelaxedConstraint]

 def _process_equalities(
     constraint: Constraint,
-    get_sources: Callable[[int, int], List["Source"]],
+    get_sources: Callable[[int, int], list["Source"]],
     shape_env: "ShapeEnv",
-    names: Dict[str, Tuple[int, int]],
-    source_pairs: List[Tuple["Source", "Source"]],
-    derived_equalities: List[Tuple["Source", Union["Source", "Symbol"], Callable]],
-    phantom_symbols: Dict[str, "Symbol"],
-    relaxed_sources: Set["Source"],
+    names: dict[str, tuple[int, int]],
+    source_pairs: list[tuple["Source", "Source"]],
+    derived_equalities: list[tuple["Source", Union["Source", "Symbol"], Callable]],
+    phantom_symbols: dict[str, "Symbol"],
+    relaxed_sources: set["Source"],
 ):
     """
     Updates `source_pairs`, `derived_equalities`, and `phantom_symbols` (which become
@@ -582,7 +582,7 @@ def _tree_map_with_path(
         raise


-def _combine_args(f, args, kwargs, _is_torch_jit_trace=False) -> Dict[str, Any]:
+def _combine_args(f, args, kwargs, _is_torch_jit_trace=False) -> dict[str, Any]:
     # combine args and kwargs following the signature of f, as it happens
     # in the body of f when called with *args, **kwargs
     if isinstance(f, ExportedProgram):
@@ -684,8 +684,8 @@ def _warn_on_None_dynamic_shape_dimension():


 def _check_dynamic_shapes(
-    combined_args: Dict[str, Any],
-    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
+    combined_args: dict[str, Any],
+    dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
 ):
     """
     Checks the dynamic_shapes specification for correctness,
@@ -698,7 +698,7 @@ def _check_dynamic_shapes(
     if isinstance(dynamic_shapes, (tuple, list)):
         combined_args = type(dynamic_shapes)(combined_args.values())  # type: ignore[assignment, misc]

-    bounds: Dict[str, Tuple[int, int]] = {}
+    bounds: dict[str, tuple[int, int]] = {}

     def check_same_bounds(dim):
         if dim.__name__ in bounds:
@@ -799,9 +799,9 @@ def _check_dynamic_shapes(


 def _process_dynamic_shapes(
-    combined_args: Dict[str, Any],
-    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
-) -> List[Constraint]:
+    combined_args: dict[str, Any],
+    dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
+) -> list[Constraint]:
     """
     Reads the dynamic_shapes specification and produces a list of constraints.
     """
@@ -814,12 +814,12 @@ def _process_dynamic_shapes(
         combined_args = type(dynamic_shapes)(combined_args.values())  # type: ignore[assignment, misc]

     # map of Dim names representing input shape dimensions to constraints on them
-    symbols: Dict[str, List[Constraint]] = defaultdict(list)
+    symbols: dict[str, list[Constraint]] = defaultdict(list)
     # track roots that do not directly represent input shape dimensions
-    phantom_roots: Dict[str, _PhantomRoot] = {}
-    derived_constraints_with_phantom_root: List[_DerivedConstraint] = []
+    phantom_roots: dict[str, _PhantomRoot] = {}
+    derived_constraints_with_phantom_root: list[_DerivedConstraint] = []
     # list of constraints to return
-    constraints: List[Constraint] = []
+    constraints: list[Constraint] = []

     def to_constraint(dim, tensor, i):
         import sympy
@@ -979,7 +979,7 @@ def _process_dynamic_shapes(


 def _get_dim_name_mapping(
-    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None]
+    dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None]
 ):
     name_to_dim = {}
     for dim in tree_flatten(
@@ -1002,8 +1002,8 @@ def _get_dim_name_mapping(

 def refine_dynamic_shapes_from_suggested_fixes(
     msg: str,
-    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]],
-) -> Union[Dict[str, Any], Tuple[Any], List[Any]]:
+    dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any]],
+) -> Union[dict[str, Any], tuple[Any], list[Any]]:
     """
     When exporting with :func:`dynamic_shapes`, export may fail with a ConstraintViolation error if the specification
     doesn't match the constraints inferred from tracing the model. The error message may provide suggested fixes -
@@ -1072,7 +1072,7 @@ def refine_dynamic_shapes_from_suggested_fixes(
     name_to_dim = _get_dim_name_mapping(dynamic_shapes)

     # track derived dim roots
-    roots: Set[str] = set()
+    roots: set[str] = set()
     for k, c in shape_fixes.items():
         assert isinstance(c, (int, _Dim, _DerivedDim, sympy.Expr))
         if isinstance(c, sympy.Expr):  # check dim/derived dim expression
@@ -1087,7 +1087,7 @@ def refine_dynamic_shapes_from_suggested_fixes(
     assert k in name_to_dim or k in roots

     # cache so we don't produce multiple derived dim objects
-    derived_dim_cache: Dict[str, _DerivedDim] = {}
+    derived_dim_cache: dict[str, _DerivedDim] = {}

     def apply_fixes(path, dim, dummy):
         if dim is None or isinstance(dim, int):  # not dynamic
@@ -7,9 +7,7 @@ from torch.export.exported_program import _decompose_exported_program

 def _copy_graph_module_and_signature(
     ep: torch.fx.GraphModule,
-) -> typing.Tuple[
-    torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature
-]:
+) -> tuple[torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature]:
     # copy.deepcopy lets the objects override __deepcopy__ methods with graph_copy() and node_copy(),
     # and this can break placeholder names in some particular cases.
     # For example, node copying will avoid Python keywords like 'input', suffixing and renaming to 'input_1'.
@@ -8,20 +8,9 @@ import operator
 import types
 import warnings
 from collections import namedtuple
+from collections.abc import Iterator
 from contextlib import contextmanager
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    final,
-    Iterator,
-    List,
-    Optional,
-    Tuple,
-    Type,
-    TYPE_CHECKING,
-    Union,
-)
+from typing import Any, Callable, final, Optional, TYPE_CHECKING, Union

 from torch._higher_order_ops.utils import autograd_not_implemented
 from torch._library.fake_class_registry import FakeScriptObject
@@ -100,11 +89,11 @@ PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]

 @dataclasses.dataclass
 class ModuleCallSignature:
-    inputs: List[ArgumentSpec]
-    outputs: List[ArgumentSpec]
+    inputs: list[ArgumentSpec]
+    outputs: list[ArgumentSpec]
     in_spec: pytree.TreeSpec
     out_spec: pytree.TreeSpec
-    forward_arg_names: Optional[List[str]] = None
+    forward_arg_names: Optional[list[str]] = None

     def replace_all_uses_with(self, original_node, new_node):
         for i in self.inputs:
@@ -300,8 +289,8 @@ def _override_decomp_aten_to_variants():


 def _split_decomp_table_to_cia_and_python_decomp(
-    decomp_table: Dict[torch._ops.OperatorBase, Callable]
-) -> Tuple[Dict[torch._ops.OperatorBase, Callable], ...]:
+    decomp_table: dict[torch._ops.OperatorBase, Callable]
+) -> tuple[dict[torch._ops.OperatorBase, Callable], ...]:
     all_preservable_cia_ops = set(_collect_all_valid_cia_ops())
     cia_ops_to_callable = {}

@@ -355,8 +344,8 @@ def default_decompositions() -> "CustomDecompTable":
 def _decompose_and_get_gm_with_new_signature_constants(
     ep,
     *,
-    cia_to_decomp: Dict[torch._ops.OperatorBase, Callable],
-    python_decomp_table: Dict[torch._ops.OperatorBase, Callable],
+    cia_to_decomp: dict[torch._ops.OperatorBase, Callable],
+    python_decomp_table: dict[torch._ops.OperatorBase, Callable],
     joint_loss_index: Optional[int],
     decompose_custom_triton_ops,
 ):
@@ -743,7 +732,7 @@ def _decompose_and_get_gm_with_new_signature_constants(

 def _remove_unneccessary_copy_op_pass(
     gm: torch.fx.GraphModule, new_graph_signature: ExportGraphSignature
-) -> Tuple[torch.fx.GraphModule, ExportGraphSignature]:
+) -> tuple[torch.fx.GraphModule, ExportGraphSignature]:
     """
     Removes redundant copy_ node that was introduced due to mutated buffer.
     """
@@ -774,8 +763,8 @@ def _common_getitem_elimination_pass(
         if not isinstance(module, torch.fx.GraphModule):
             continue

-        node_id: Dict[torch.fx.Node, str] = {}
-        getitems: Dict[str, torch.fx.Node] = {}
+        node_id: dict[torch.fx.Node, str] = {}
+        getitems: dict[str, torch.fx.Node] = {}
         for node in list(module.graph.nodes):
             if node.op == "call_function" and node.target == operator.getitem:
                 source, idx = node.args
@@ -797,13 +786,13 @@ def _common_getitem_elimination_pass(

 def _get_updated_module_call_graph(
     gm: torch.fx.GraphModule,
-    old_module_call_graph: List[ModuleCallEntry],
+    old_module_call_graph: list[ModuleCallEntry],
 ):
     new_module_call_graph = copy.deepcopy(old_module_call_graph)

     # use node-level provenance metadata to create a map
     # from old node names to new node names
-    provenance: Dict[str, str] = {}
+    provenance: dict[str, str] = {}
     for node in gm.graph.nodes:
         if history := node.meta.get("from_node", []):
             provenance[history[-1].name] = node.name
@@ -822,8 +811,8 @@ def _get_updated_module_call_graph(
 def _decompose_exported_program(
     ep,
     *,
-    cia_to_decomp: Dict[torch._ops.OperatorBase, Callable],
-    python_decomp_table: Dict[torch._ops.OperatorBase, Callable],
+    cia_to_decomp: dict[torch._ops.OperatorBase, Callable],
+    python_decomp_table: dict[torch._ops.OperatorBase, Callable],
     joint_loss_index: Optional[int],
     decompose_custom_triton_ops: bool,
 ):
@@ -889,18 +878,18 @@ class ExportedProgram:

     def __init__(
         self,
-        root: Union[torch.nn.Module, Dict[str, Any]],
+        root: Union[torch.nn.Module, dict[str, Any]],
         graph: torch.fx.Graph,
         graph_signature: ExportGraphSignature,
-        state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
-        range_constraints: "Dict[sympy.Symbol, Any]",
-        module_call_graph: List[ModuleCallEntry],
-        example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
+        state_dict: dict[str, Union[torch.Tensor, torch.nn.Parameter]],
+        range_constraints: "dict[sympy.Symbol, Any]",
+        module_call_graph: list[ModuleCallEntry],
+        example_inputs: Optional[tuple[tuple[Any, ...], dict[str, Any]]] = None,
         constants: Optional[
-            Dict[str, Union[torch.Tensor, FakeScriptObject, torch._C.ScriptObject]]
+            dict[str, Union[torch.Tensor, FakeScriptObject, torch._C.ScriptObject]]
         ] = None,
         *,
-        verifiers: Optional[List[Type[Verifier]]] = None,
+        verifiers: Optional[list[type[Verifier]]] = None,
     ):
         # Remove codegen related things from the graph. It should just be a flat graph.
         graph._codegen = torch.fx.graph.CodeGen()
@@ -912,10 +901,10 @@ class ExportedProgram:
             self._graph_module, graph_signature, module_call_graph
         )
         self._graph_signature: ExportGraphSignature = graph_signature
-        self._state_dict: Dict[str, Any] = state_dict
-        self._range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
+        self._state_dict: dict[str, Any] = state_dict
+        self._range_constraints: dict[sympy.Symbol, ValueRanges] = range_constraints
         assert module_call_graph is not None
-        self._module_call_graph: List[ModuleCallEntry] = module_call_graph
+        self._module_call_graph: list[ModuleCallEntry] = module_call_graph
         self._example_inputs = example_inputs

         self._constants = constants or {}
@@ -975,7 +964,7 @@ class ExportedProgram:
             yield param

     @compatibility(is_backward_compatible=False)
-    def named_parameters(self) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+    def named_parameters(self) -> Iterator[tuple[str, torch.nn.Parameter]]:
         """
         Returns an iterator over original module parameters, yielding
         both the name of the parameter as well as the parameter itself.
@@ -992,7 +981,7 @@ class ExportedProgram:
             yield buf

     @compatibility(is_backward_compatible=False)
-    def named_buffers(self) -> Iterator[Tuple[str, torch.Tensor]]:
+    def named_buffers(self) -> Iterator[tuple[str, torch.Tensor]]:
         """
         Returns an iterator over original module buffers, yielding
         both the name of the buffer as well as the buffer itself.
@@ -1299,7 +1288,7 @@ class ExportedProgram:
     @_disable_prexisiting_fake_mode
     def run_decompositions(
         self,
-        decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
+        decomp_table: Optional[dict[torch._ops.OperatorBase, Callable]] = None,
         decompose_custom_triton_ops: bool = False,
     ) -> "ExportedProgram":
         """
@@ -1520,8 +1509,8 @@ def _get_shape_env(gm):

 def _get_updated_range_constraints(
     gm: torch.fx.GraphModule,
-    old_range_constraints: "Optional[Dict[sympy.Symbol, Any]]" = None,
-) -> "Dict[sympy.Symbol, Any]":
+    old_range_constraints: "Optional[dict[sympy.Symbol, Any]]" = None,
+) -> "dict[sympy.Symbol, Any]":
     assert old_range_constraints is not None

     shape_env = _get_shape_env(gm)
@@ -1,7 +1,8 @@
 # mypy: allow-untyped-defs
 import dataclasses
+from collections.abc import Collection, Mapping
 from enum import auto, Enum
-from typing import Collection, Dict, List, Mapping, Optional, Set, TYPE_CHECKING, Union
+from typing import Optional, TYPE_CHECKING, Union

 from torch._library.fake_class_registry import FakeScriptObject
 from torch._subclasses.fake_tensor import is_fake
@@ -144,8 +145,8 @@ class OutputSpec:

 @dataclasses.dataclass
 class ExportBackwardSignature:
-    gradients_to_parameters: Dict[str, str]
-    gradients_to_user_inputs: Dict[str, str]
+    gradients_to_parameters: dict[str, str]
+    gradients_to_user_inputs: dict[str, str]
     loss_output: str


@@ -221,8 +222,8 @@ class ExportGraphSignature:
     )
     """

-    input_specs: List[InputSpec]
-    output_specs: List[OutputSpec]
+    input_specs: list[InputSpec]
+    output_specs: list[OutputSpec]

     # A list of parameters uniquely identified by mangled fully qualified name
     @property
@@ -276,7 +277,7 @@ class ExportGraphSignature:
     # Graph node names of pytree-flattened inputs of original program
     @property
     def user_inputs(self) -> Collection[Union[int, float, bool, None, str]]:
-        user_inputs: List[Union[int, float, bool, None, str]] = []
+        user_inputs: list[Union[int, float, bool, None, str]] = []
         for s in self.input_specs:
             if s.kind != InputKind.USER_INPUT:
                 continue
@@ -302,7 +303,7 @@ class ExportGraphSignature:
     # For joint-graph purposes, will include the loss output.
     @property
     def user_outputs(self) -> Collection[Union[int, float, bool, None, str]]:
-        user_outputs: List[Union[int, float, bool, None, str]] = []
+        user_outputs: list[Union[int, float, bool, None, str]] = []
         for s in self.output_specs:
             if s.kind not in [
                 OutputKind.USER_OUTPUT,
@@ -393,8 +394,8 @@ class ExportGraphSignature:
     @property
     def backward_signature(self) -> Optional[ExportBackwardSignature]:
         loss_output = None
-        gradients_to_parameters: Dict[str, str] = {}
-        gradients_to_user_inputs: Dict[str, str] = {}
+        gradients_to_parameters: dict[str, str] = {}
+        gradients_to_user_inputs: dict[str, str] = {}
         for spec in self.output_specs:
             if spec.kind == OutputKind.LOSS_OUTPUT:
                 assert loss_output is None
@@ -537,7 +538,7 @@ def _make_argument_spec(node, token_names) -> ArgumentSpec:
 def _convert_to_export_graph_signature(
     graph_signature: "GraphSignature",
     gm: "torch.fx.GraphModule",
-    non_persistent_buffers: Set[str],
+    non_persistent_buffers: set[str],
 ) -> "ExportGraphSignature":
     from torch.utils import _pytree as pytree

@@ -9,7 +9,7 @@ __all__ = ["move_to_device_pass"]


 def move_to_device_pass(
-    ep: ExportedProgram, location: Union[torch.device, str, Dict[str, str]]
+    ep: ExportedProgram, location: Union[torch.device, str, dict[str, str]]
 ) -> ExportedProgram:
     """
     Move the exported program to the given device.
@@ -27,7 +27,7 @@ def move_to_device_pass(

 def _get_new_device(
     curr_device: torch.device,
-    location: Union[torch.device, str, Dict[str, str]],
+    location: Union[torch.device, str, dict[str, str]],
 ) -> str:
     if isinstance(location, dict):
         if str(curr_device) in location.keys():
@@ -9,7 +9,7 @@ from contextlib import contextmanager
 from copy import deepcopy
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, cast, Optional, Union

 import torch
 import torch.fx._pytree as fx_pytree
@@ -85,7 +85,7 @@ def _assign_attr(
     # foo.bar, foo.bar@1, foo.bar@2, foo@1.bar, foo@1.bar@1, foo@1.bar@2.
     to_modules = {to_module}
     for item in prefix:
-        ts: Set[torch.nn.Module] = set()
+        ts: set[torch.nn.Module] = set()
         for to_module in to_modules:
             if not hasattr(to_module, item):
                 setattr(to_module, item, torch.nn.Module())
@ -222,7 +222,7 @@ class InterpreterModuleDispatcher(_SubmoduleBase, torch.nn.Module):
|
||||
to the next InterpreterModule, and wraps back around after the last.
|
||||
"""
|
||||
|
||||
def __init__(self, attrs: Set[str], call_modules: List[InterpreterModule]):
|
||||
def __init__(self, attrs: set[str], call_modules: list[InterpreterModule]):
|
||||
super().__init__()
|
||||
assert call_modules
|
||||
self._modules = call_modules[0]._modules
|
||||
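Reviewer note: the import hunk above shows the rule the migration follows — only names with no builtin spelling (`Any`, `Callable`, `cast`, `Optional`, `Union`) remain in the `typing` import, while `Dict`, `List`, `Set`, and `Tuple` are dropped. A small sketch of code written under the post-migration convention (the `lookup` function is invented for illustration; on 3.10+ `Optional`/`Union` could also become `|` unions, but that is out of scope for a 3.9-compatible codebase):

from typing import Optional, Union

def lookup(table: dict[str, list[int]], key: Optional[str]) -> Union[list[int], None]:
    # Builtin generics in the annotations; typing only where unavoidable.
    return table.get(key) if key is not None else None

assert lookup({"a": [1, 2]}, "a") == [1, 2]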
@ -273,8 +273,8 @@ class FlatArgsAdapter(abc.ABC):
self,
target_spec: pytree.TreeSpec,
input_spec: pytree.TreeSpec,
input_args: List[Any],
) -> List[Any]:
input_args: list[Any],
) -> list[Any]:
"""NOTE: This adapter may mutate given ``input_args_with_path``."""
...

@ -316,7 +316,7 @@ class UnflattenedModule(torch.nn.Module):
_copy_graph_attrs(export_module._graph_module, self, seen_attrs)

self.range_constraints = export_module.range_constraints
self.equality_constraints: List = []
self.equality_constraints: list = []

# aliasing/unused param or buffer issues:
# in strict-mode export, dynamo export will deduplicate aliased tensors,

@ -329,8 +329,8 @@ class UnflattenedModule(torch.nn.Module):
# the state_dict as module attributes, but only keep the used tensors in the
# graph's forward pass (_sink_params).
state_dict = export_module.state_dict
assigned_params: Set[str] = set()  # tracking unused params
id_to_param: Dict[int, torch.nn.Parameter] = {}  # handling weight-sharing
assigned_params: set[str] = set()  # tracking unused params
id_to_param: dict[int, torch.nn.Parameter] = {}  # handling weight-sharing
for name in self.graph_signature.parameters:  # this loop adds used params
param = state_dict[name]
if id(param) not in id_to_param:

@ -347,8 +347,8 @@ class UnflattenedModule(torch.nn.Module):
assigned_params.add(name)

non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
assigned_buffers: Set[str] = set()  # tracking unused buffers
id_to_buffer: Dict[int, Tuple[torch.nn.Parameter, bool]] = {}
assigned_buffers: set[str] = set()  # tracking unused buffers
id_to_buffer: dict[int, tuple[torch.nn.Parameter, bool]] = {}
for name in self.graph_signature.buffers:  # this loop adds used buffers
if name in non_persistent_buffers:
persistent = False

@ -407,7 +407,7 @@ class UnflattenedModule(torch.nn.Module):
)

# use id map so we don't double-clone aliased constants
id_to_const: Dict[int, Union[torch.Tensor, torch._C.ScriptObject]] = {}
id_to_const: dict[int, Union[torch.Tensor, torch._C.ScriptObject]] = {}
for fqn, constant in export_module.constants.items():
if id(constant) not in id_to_const:
if isinstance(constant, torch.Tensor):

@ -423,14 +423,14 @@ class UnflattenedModule(torch.nn.Module):

# This is to handle parameters/buffers that point to the same tensor
# object id -> list of (node_name, target_name)
consts_map: Dict[int, List[Tuple[str, str]]] = defaultdict(list)
consts_targets: Set[str] = set()
consts_map: dict[int, list[tuple[str, str]]] = defaultdict(list)
consts_targets: set[str] = set()

def add_to_consts_map(obj_id, node_name, target_name):
name_list = consts_map[obj_id]
name_list.append((node_name, target_name))

added_params_buffers: Set[str] = set()  # track aliased/unused params, buffers
added_params_buffers: set[str] = set()  # track aliased/unused params, buffers
for s in self.graph_signature.input_specs:
if s.kind == InputKind.PARAMETER or (
s.kind == InputKind.BUFFER and s.persistent

@ -476,7 +476,7 @@ class UnflattenedModule(torch.nn.Module):
add_to_consts_map(id(tensor), ph_name, fqn)

# node name -> list of possible targets
inputs_to_state: Dict[str, List[str]] = {}
inputs_to_state: dict[str, list[str]] = {}
for node_target in consts_map.values():
targets = [t[1] for t in node_target]
for n, _ in node_target:
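Reviewer note: the `consts_map`/`add_to_consts_map` machinery above groups graph inputs by the `id()` of the underlying tensor, so aliased parameters, buffers, and constants collapse into one entry. A minimal standalone sketch of the same identity-keyed pattern (all names below are invented for illustration):

from collections import defaultdict

import torch

# object id -> list of (node_name, target_name), mirroring consts_map
aliases: dict[int, list[tuple[str, str]]] = defaultdict(list)

t = torch.zeros(3)
state = {"m.weight": t, "m.weight_alias": t}  # two FQNs, one tensor object
for i, (fqn, tensor) in enumerate(state.items()):
    aliases[id(tensor)].append((f"ph_{i}", fqn))

# Both names land in a single bucket keyed by the shared tensor's identity.
assert len(aliases) == 1 and len(aliases[id(t)]) == 2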
@ -790,7 +790,7 @@ def _compute_accessor(parent_fqn: str, child_fqn: str) -> str:
def _check_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
def graph_dump(graph: torch.fx.Graph) -> str:
ret = []
nodes_idx: Dict[int, int] = {}
nodes_idx: dict[int, int] = {}

def arg_dump(arg) -> str:
if isinstance(arg, torch.fx.Node):

@ -902,15 +902,15 @@ class _ModuleFrame:
def __init__(
self,
flat_graph: torch.fx.Graph,
nodes: Tuple[torch.fx.Node, ...],
nodes: tuple[torch.fx.Node, ...],
seen_nodes,
seen_modules,
seen_attrs,
created_modules,
parent,
module_stack: List[Tuple[str, Optional[str], int]],
module_stack: list[tuple[str, Optional[str], int]],
module_id,
module_call_graph: Dict[str, ModuleCallSignature],
module_call_graph: dict[str, ModuleCallSignature],
module: Optional[Union[torch.fx.GraphModule, UnflattenedModule]] = None,
):
self.flat_graph = flat_graph

@ -944,7 +944,7 @@ class _ModuleFrame:
self.graph = self.module.graph

# Mapping of nodes in the flat graph to nodes in this graph.
self.node_map: Dict[torch.fx.Node, torch.fx.Node] = {}
self.node_map: dict[torch.fx.Node, torch.fx.Node] = {}
self.node_to_placeholder = {}

self.parent_call_module: Optional[torch.fx.Node] = None

@ -1017,7 +1017,7 @@ class _ModuleFrame:
] = flat_arg_node

with self.parent.graph.inserting_before(self.parent_call_module):
input_nodes: List[Optional[torch.fx.Node]] = []
input_nodes: list[Optional[torch.fx.Node]] = []
for input in signature.inputs:
if isinstance(input, ConstantArgument):
input_nodes.append(input.value)  # type: ignore[arg-type]

@ -1175,7 +1175,7 @@ class _ModuleFrame:
parent_out: Optional[torch.fx.Node] = _generate_flatten_spec(
self.parent.module, self.parent_call_module, signature.out_spec
)
graph_outputs: Union[torch.fx.Node, List[torch.fx.Node]] = tree_out_node
graph_outputs: Union[torch.fx.Node, list[torch.fx.Node]] = tree_out_node
else:
graph_outputs = []
# Iterate through nodes we have copied into self.graph.
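Reviewer note: the variadic form `tuple[torch.fx.Node, ...]` in `_ModuleFrame.__init__` carries the `Ellipsis` over from `typing.Tuple` unchanged; both spell a homogeneous tuple of arbitrary length. A quick illustrative check using only the standard library (names invented):

import typing

Nodes = tuple[str, ...]  # was Tuple[str, ...]
assert typing.get_args(Nodes) == (str, Ellipsis)

def head(nodes: tuple[str, ...]) -> str:
    return nodes[0]

assert head(("a", "b", "c")) == "a"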
@ -1351,10 +1351,10 @@ class _SubmoduleEntry:


def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule):
seen_nodes: Dict[str, torch.fx.Node] = {}
seen_modules: Dict[int, List[_SubmoduleEntry]] = defaultdict(list)
seen_attrs: Dict[str, Set[str]] = defaultdict(set)
created_modules: Dict[str, torch.nn.Module] = {}
seen_nodes: dict[str, torch.fx.Node] = {}
seen_modules: dict[int, list[_SubmoduleEntry]] = defaultdict(list)
seen_attrs: dict[str, set[str]] = defaultdict(set)
created_modules: dict[str, torch.nn.Module] = {}
_ModuleFrame(
orig_graph,
tuple(orig_graph.nodes),

@ -1376,7 +1376,7 @@ def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModu


def _reorder_submodules(
parent: torch.nn.Module, fqn_order: Dict[str, int], prefix: str = ""
parent: torch.nn.Module, fqn_order: dict[str, int], prefix: str = ""
):
# TODO Can be optimized by adding submodules ahead of time.
if prefix == "":

@ -1496,7 +1496,7 @@ class _IVals:
def _copy_graph_attrs(
gm: torch.fx.GraphModule,
root_module: UnflattenedModule,
seen_attrs: Dict[str, Set[str]],
seen_attrs: dict[str, set[str]],
):
for child_fqn, names in seen_attrs.items():
module = _get_attr(root_module, child_fqn) if child_fqn else root_module

@ -1550,9 +1550,9 @@ def _deduplicate_modules(partitions):

def _sink_params(
module: torch.nn.Module,
inputs_to_state: Dict[str, List[str]],
scope: List[str],
module_id_to_inputs_removed: Optional[Dict[int, Set[str]]] = None,
inputs_to_state: dict[str, list[str]],
scope: list[str],
module_id_to_inputs_removed: Optional[dict[int, set[str]]] = None,
):
"""Sink params, buffers, and constants from graph inputs into get_attr nodes.

@ -1613,7 +1613,7 @@ def _sink_params(
)

# Filter out inputs_to_state corresponding to current scope.
inputs_to_state_of_scope: Dict[torch.fx.Node, list[str]] = {}
inputs_to_state_of_scope: dict[torch.fx.Node, list[str]] = {}
for node in inputs:
if node.name not in inputs_to_state:
continue

@ -1640,7 +1640,7 @@ def _sink_params(
inputs_to_state_of_scope[node] = state_name

# Record name of remove inputs for return purpose.
inputs_removed: Set[str] = set()
inputs_removed: set[str] = set()

for node, state_name in inputs_to_state_of_scope.items():
if len(node.users) > 0:
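Reviewer note: one pattern recurring in `_outline_submodules` and `_sink_params` above is that the annotation names the abstract mapping shape (`dict[int, list[_SubmoduleEntry]]`, `dict[int, set[str]]`) while the runtime value is a `defaultdict`. A minimal sketch of that convention (toy key and value types, for illustration only):

from collections import defaultdict

# Annotated with the builtin generic; instantiated as a dict subclass.
seen_modules: dict[int, list[str]] = defaultdict(list)
seen_modules[0].append("submodule_a")  # a missing key materializes []
seen_modules[0].append("submodule_b")
assert dict(seen_modules) == {0: ["submodule_a", "submodule_b"]}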