PEP585 update - torch/fx (#145166)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145166
Approved by: https://github.com/bobrenjc93
Aaron Orenstein
2025-01-19 19:32:07 -08:00
committed by PyTorch MergeBot
parent 6374332d33
commit 0b2a3687b9
57 changed files with 904 additions and 917 deletions
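
For context, PEP 585 (Python 3.9) made the builtin containers generic, deprecating the typing.List/Dict/Tuple/Set/Type aliases that this diff replaces throughout torch/fx. A minimal sketch of the before/after pattern applied across these files:

```python
from typing import Optional

# Pre-PEP 585 spelling (requires importing aliases from typing):
#   from typing import Dict, List
#   def summarize(xs: List[int]) -> Optional[Dict[str, int]]: ...

# PEP 585 spelling (Python 3.9+): the builtin containers are themselves
# generic, so the Dict/List/Tuple/Type/Set imports become unnecessary.
def summarize(xs: list[int]) -> Optional[dict[str, int]]:
    return {"sum": sum(xs)} if xs else None
```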

View File

@ -3,6 +3,7 @@
import builtins
import contextlib
import collections
import copy
import functools
import inspect
@ -2767,7 +2768,7 @@ class TestFX(JitTestCase):
return self.other(x)
traced = symbolic_trace(ReturnTypeModule())
self.assertIn("-> typing_List[str]", traced._code)
self.assertIn("-> list[str]", traced._code)
scripted = torch.jit.script(traced)
self.assertIn("-> List[str]", scripted.code)
@ -3566,8 +3567,8 @@ class TestFX(JitTestCase):
traced(x, y)
FileCheck().check("_Tuple[()]") \
.check("typing_Tuple[str,typing_Tuple[()]]") \
FileCheck().check("tuple[()]") \
.check("tuple[str,tuple[()]]") \
.run(traced.code)
scripted = torch.jit.script(traced)
@ -4063,45 +4064,62 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
return f'{fn_name}({", ".join(arg_strs)}){return_annot}'
def _annotation_type_to_stable_str(self, t, sig_str):
_trivial_mappings = {
str : 'str',
int : 'int',
float: 'float',
bool: 'bool',
torch.dtype: 'torch.dtype',
torch.Tensor: 'torch.Tensor',
torch.device: 'torch.device',
torch.memory_format: 'torch.memory_format',
slice: 'slice',
torch.nn.Module: 'torch.nn.modules.module.Module',
torch.fx.Graph : 'torch.fx.graph.Graph',
torch.fx.Node : 'torch.fx.node.Node',
torch.fx.Proxy : 'torch.fx.proxy.Proxy',
torch.fx.node.Target : 'torch.fx.node.Target',
torch.fx.node.Argument : 'torch.fx.node.Argument',
torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
Ellipsis : '...',
typing.Any: 'Any',
type(None): 'NoneType',
None: 'None',
typing.Iterator: 'Iterator',
collections.abc.Iterator: 'Iterator',
}
_UNBOUND_TYPES = {
dict,
list,
tuple,
type,
typing.Callable,
typing.Dict,
typing.List,
typing.Tuple,
typing.Type,
typing.Union,
}
def _annotation_type_to_stable_str(self, t, sig_str, recursive: bool = False):
if t is inspect.Signature.empty:
return ''
# Forward ref
if isinstance(t, str):
return f"'{t}'"
if recursive:
return t
else:
return f"'{t}'"
if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef):
return t.__forward_arg__
if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef):
return t.__forward_arg__
trivial_mappings = {
str : 'str',
int : 'int',
float: 'float',
bool: 'bool',
torch.dtype: 'torch.dtype',
torch.Tensor: 'torch.Tensor',
torch.device: 'torch.device',
torch.memory_format: 'torch.memory_format',
slice: 'slice',
torch.nn.Module: 'torch.nn.modules.module.Module',
torch.fx.Graph : 'torch.fx.graph.Graph',
torch.fx.Node : 'torch.fx.node.Node',
torch.fx.Proxy : 'torch.fx.proxy.Proxy',
torch.fx.node.Target : 'torch.fx.node.Target',
torch.fx.node.Argument : 'torch.fx.node.Argument',
torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
Ellipsis : '...',
typing.Any: 'Any',
type(None): 'NoneType',
None: 'None',
typing.Iterator: 'Iterator',
}
mapping = trivial_mappings.get(t, None)
mapping = self._trivial_mappings.get(t, None)
if mapping:
return mapping
@ -4115,14 +4133,14 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
if all(isinstance(ct, typing.TypeVar) for ct in contained):
contained = []
contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained]
contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str, True) for ct in contained]
contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else ''
origin = getattr(t, '__origin__', None)
if origin is None:
# Unbound types don't have `__origin__` in some Python versions, so fix that up here.
origin = t if t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable} else origin
origin = t if t in self._UNBOUND_TYPES else origin
if origin in {tuple, typing.Tuple}:
return f'Tuple{contained_type_str}'
@ -4130,7 +4148,7 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
# Annoying hack to detect Optional
if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)):
not_none_param = contained[0] if contained[0] is not type(None) else contained[1]
return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]'
return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str, True)}]'
return f'Union{contained_type_str}'
if origin in {dict, typing.Dict}:
return f'Dict{contained_type_str}'
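The updated helper leans on runtime typing introspection; the invariants it relies on (Optional as a two-member Union, both generic spellings sharing one origin) can be checked directly with the standard library:

```python
import typing

# Optional[int] is represented as Union[int, None] at runtime, which is why
# the code above detects Optional as a Union with exactly one NoneType member.
t = typing.Optional[int]
assert typing.get_origin(t) is typing.Union
assert typing.get_args(t) == (int, type(None))

# Bound generics normalize to the builtin container under both spellings,
# so list[int] and typing.List[int] stringify to the same stable form.
assert typing.get_origin(list[int]) is list
assert typing.get_origin(typing.List[int]) is list
```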

View File

@ -1524,6 +1524,29 @@ class {test_classname}(torch.nn.Module):
(int, type(torch.float)),
(Union[int, float], int),
(Union[int, float], float),
(list[int], int),
(list[int], create_type_hint([int, int])),
(list[int], create_type_hint((int, int))),
(list[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
(
list[torch.Tensor],
create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
),
(torch.Tensor, torch.nn.Parameter),
(list[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
(list[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
(list[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
(
list[torch.Tensor],
create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
),
(torch.Tensor, torch.nn.Parameter),
(list[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
(list[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
(Optional[list[torch.Tensor]], list[torch.Tensor]),
(Optional[list[int]], list[int]),
] + [
# pre-PEP585 signatures
(List[int], int),
(List[int], create_type_hint([int, int])),
(List[int], create_type_hint((int, int))),
@ -1532,7 +1555,6 @@ class {test_classname}(torch.nn.Module):
List[torch.Tensor],
create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
),
(torch.Tensor, torch.nn.Parameter),
(List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
(List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
(List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
@ -1540,18 +1562,21 @@ class {test_classname}(torch.nn.Module):
List[torch.Tensor],
create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
),
(torch.Tensor, torch.nn.Parameter),
(List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
(List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
(Optional[List[torch.Tensor]], List[torch.Tensor]),
(Optional[List[int]], List[int]),
]
for sig_type, arg_type in should_be_equal:
self.assertTrue(type_matches(sig_type, arg_type))
should_fail = [
(int, float),
(Union[int, float], str),
(list[torch.Tensor], List[int]),
] + [
# pre-PEP585 signatures
(List[torch.Tensor], List[int]),
]
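
The duplicated PEP 585 / pre-PEP 585 test rows work because both spellings expose the same origin at runtime. A hypothetical, much-simplified check in that spirit (not FX's actual type_matches):

```python
from typing import List, Union, get_args, get_origin

def matches(sig_type, arg_type) -> bool:
    # Simplified sketch: exact match, Union members tried individually,
    # or list-vs-list with matching element types.
    if sig_type is arg_type:
        return True
    if get_origin(sig_type) is Union:
        return any(matches(m, arg_type) for m in get_args(sig_type))
    if get_origin(sig_type) is list and get_origin(arg_type) is list:
        return get_args(sig_type) == get_args(arg_type)
    return False

assert matches(list[int], list[int])
assert matches(Union[int, float], int)
assert matches(List[int], list[int])   # both origins normalize to list
assert not matches(list[int], list[str])
```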

View File

@ -1,9 +1,9 @@
import textwrap
from typing import Any, Callable, Dict, TypeVar
from typing import Any, Callable, TypeVar
_BACK_COMPAT_OBJECTS: Dict[Any, None] = {}
_MARKED_WITH_COMPATIBILITY: Dict[Any, None] = {}
_BACK_COMPAT_OBJECTS: dict[Any, None] = {}
_MARKED_WITH_COMPATIBILITY: dict[Any, None] = {}
_T = TypeVar("_T")
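Both registries here are dicts with None values, an insertion-ordered-set idiom that recurs throughout this diff; only the key-type spelling changes. A sketch of the idiom:

```python
# dict[Any, None] as an insertion-ordered set: keys are the registered
# objects, the None values are never read.
_registry: dict[object, None] = {}

def _register(obj: object) -> None:
    _registry.setdefault(obj, None)
```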

View File

@ -1,20 +1,20 @@
# mypy: allow-untyped-defs
from collections import namedtuple
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type
from typing import Any, Callable, NamedTuple, Optional
import torch.return_types
from torch.utils._pytree import PyTree, TreeSpec
FlattenFuncSpec = Callable[[PyTree, TreeSpec], List]
FlattenFuncSpec = Callable[[PyTree, TreeSpec], list]
FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool]
SUPPORTED_NODES: Dict[Type[Any], FlattenFuncSpec] = {}
SUPPORTED_NODES_EXACT_MATCH: Dict[Type[Any], Optional[FlattenFuncExactMatchSpec]] = {}
SUPPORTED_NODES: dict[type[Any], FlattenFuncSpec] = {}
SUPPORTED_NODES_EXACT_MATCH: dict[type[Any], Optional[FlattenFuncExactMatchSpec]] = {}
def register_pytree_flatten_spec(
cls: Type[Any],
cls: type[Any],
flatten_fn_spec: FlattenFuncSpec,
flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None,
) -> None:
@ -23,7 +23,7 @@ def register_pytree_flatten_spec(
def _deregister_pytree_flatten_spec(
cls: Type[Any],
cls: type[Any],
) -> None:
del SUPPORTED_NODES[cls]
del SUPPORTED_NODES_EXACT_MATCH[cls]
@ -33,7 +33,7 @@ def tree_flatten_spec(
pytree: PyTree,
spec: TreeSpec,
exact_structural_match=False,
) -> List[Any]:
) -> list[Any]:
if spec.is_leaf():
return [pytree]
if spec.type not in SUPPORTED_NODES:
@ -58,31 +58,31 @@ def tree_flatten_spec(
return result
def _dict_flatten_spec(d: Dict[Any, Any], spec: TreeSpec) -> List[Any]:
def _dict_flatten_spec(d: dict[Any, Any], spec: TreeSpec) -> list[Any]:
return [d[k] for k in spec.context]
def _list_flatten_spec(d: List[Any], spec: TreeSpec) -> List[Any]:
def _list_flatten_spec(d: list[Any], spec: TreeSpec) -> list[Any]:
return [d[i] for i in range(spec.num_children)]
def _tuple_flatten_spec(d: Tuple[Any], spec: TreeSpec) -> List[Any]:
def _tuple_flatten_spec(d: tuple[Any], spec: TreeSpec) -> list[Any]:
return [d[i] for i in range(spec.num_children)]
def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> List[Any]:
def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> list[Any]:
return [d[i] for i in range(spec.num_children)]
def _dict_flatten_spec_exact_match(d: Dict[Any, Any], spec: TreeSpec) -> bool:
def _dict_flatten_spec_exact_match(d: dict[Any, Any], spec: TreeSpec) -> bool:
return len(d) == spec.num_children
def _list_flatten_spec_exact_match(d: List[Any], spec: TreeSpec) -> bool:
def _list_flatten_spec_exact_match(d: list[Any], spec: TreeSpec) -> bool:
return len(d) == spec.num_children
def _tuple_flatten_spec_exact_match(d: Tuple[Any], spec: TreeSpec) -> bool:
def _tuple_flatten_spec_exact_match(d: tuple[Any], spec: TreeSpec) -> bool:
return len(d) == spec.num_children
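These handlers flatten a container against an already-computed TreeSpec. A sketch of registering one for a custom type (Point is hypothetical; the real registrations above cover dict/list/tuple/namedtuple):

```python
from typing import NamedTuple

from torch.fx._pytree import register_pytree_flatten_spec
from torch.utils._pytree import TreeSpec

class Point(NamedTuple):
    x: float
    y: float

def _point_flatten_spec(p: Point, spec: TreeSpec) -> list:
    # Mirror the namedtuple handler: one entry per child recorded in the spec.
    return [p[i] for i in range(spec.num_children)]

register_pytree_flatten_spec(Point, _point_flatten_spec)
```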

View File

@ -10,18 +10,7 @@ import os
import warnings
from itertools import chain
from types import CodeType, FunctionType, ModuleType
from typing import (
Any,
Callable,
Dict,
List,
NamedTuple,
Optional,
Set,
Tuple,
Type,
Union,
)
from typing import Any, Callable, NamedTuple, Optional, Union
import torch
import torch.utils._pytree as pytree
@ -42,7 +31,7 @@ HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS
_orig_module_call: Callable = torch.nn.Module.__call__
_orig_module_getattr: Callable = torch.nn.Module.__getattr__
_proxyable_classes: Dict[Type, None] = {}
_proxyable_classes: dict[type, None] = {}
_is_fx_tracing_flag = False
@ -262,8 +251,8 @@ class Tracer(TracerBase):
@compatibility(is_backward_compatible=True)
def __init__(
self,
autowrap_modules: Tuple[ModuleType] = (math,),
autowrap_functions: Tuple[Callable, ...] = (),
autowrap_modules: tuple[ModuleType] = (math,),
autowrap_functions: tuple[Callable, ...] = (),
param_shapes_constant: bool = False,
) -> None:
# This method's signature is overridden by the first line of this class'
@ -296,7 +285,7 @@ class Tracer(TracerBase):
# Functions we will eagerly wrap when we see them while tracing
# this captures both `math.sqrt()` and `from math import sqrt` automatically
self._autowrap_function_ids: Set[int] = {
self._autowrap_function_ids: set[int] = {
id(value)
for name, value in chain(*[m.__dict__.items() for m in autowrap_modules])
if not name.startswith("_") and callable(value)
@ -305,20 +294,20 @@ class Tracer(TracerBase):
# Python modules to apply autowrap to at the start, in addition to
# modules we see while tracing
self._autowrap_search: List[ModuleType] = list(autowrap_modules)
self._autowrap_search: list[ModuleType] = list(autowrap_modules)
self.param_shapes_constant = param_shapes_constant
self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None
self.submodule_paths: Optional[dict[torch.nn.Module, str]] = None
self.root_module_name: str = ""
# Maps the containing module's name to the operator name
self.scope = Scope("", None)
# Records the module call stack
self.module_stack = collections.OrderedDict()
self.num_calls: Dict[str, int] = {}
self.num_calls: dict[str, int] = {}
# Mapping of node name to module scope
self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}
self.node_name_to_scope: dict[str, tuple[str, type]] = {}
_qualname_counter: Dict[str, int] = collections.defaultdict(int)
_qualname_counter: dict[str, int] = collections.defaultdict(int)
@compatibility(is_backward_compatible=True)
def get_fresh_qualname(self, prefix: str) -> str:
@ -492,8 +481,8 @@ class Tracer(TracerBase):
self,
m: torch.nn.Module,
forward: Callable[..., Any],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> Any:
"""
Method that specifies the behavior of this ``Tracer`` when it encounters
@ -547,7 +536,7 @@ class Tracer(TracerBase):
return ret_val
@compatibility(is_backward_compatible=False)
def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: dict[str, Any]):
"""
Method that specifies the behavior of this ``Tracer`` when we call getattr
on a call to an ``nn.Module`` instance.
@ -626,7 +615,7 @@ class Tracer(TracerBase):
total_args = co.co_argcount + co.co_kwonlyargcount
orig_args = list(co.co_varnames)
names_iter = iter(co.co_varnames)
args: List[Any] = []
args: list[Any] = []
skip_arg_idx = 0
if is_module:
if total_args == 0:
@ -712,7 +701,7 @@ class Tracer(TracerBase):
def trace(
self,
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
concrete_args: Optional[dict[str, Any]] = None,
) -> Graph:
"""
Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
@ -763,7 +752,7 @@ class Tracer(TracerBase):
self.root = torch.nn.Module()
fn = root
tracer_cls: Optional[Type[Tracer]] = getattr(self, "__class__", None)
tracer_cls: Optional[type[Tracer]] = getattr(self, "__class__", None)
self.graph = Graph(tracer_cls=tracer_cls)
if hasattr(fn, "__code__"):
code = fn.__code__
@ -777,11 +766,11 @@ class Tracer(TracerBase):
# is some other attribute on the model. Construct a dict mapping Tensor
# values to the qualified name here for efficiency. This is used downstream
# in create_arg
self.tensor_attrs: Dict[
self.tensor_attrs: dict[
Union[torch.Tensor, ScriptObject, FakeScriptObject], str
] = {}
def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]):
def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: list[str]):
for k, v in m.__dict__.items():
if isinstance(v, (torch.Tensor, ScriptObject, FakeScriptObject)):
self.tensor_attrs[v] = ".".join(prefix_atoms + [k])
@ -797,7 +786,7 @@ class Tracer(TracerBase):
fn, isinstance(root, torch.nn.Module), concrete_args
)
parameter_proxy_cache: Dict[
parameter_proxy_cache: dict[
str, Proxy
] = {} # Reduce number of get_attr calls
@ -872,7 +861,7 @@ class Tracer(TracerBase):
nonlocal cnt
cnt += 1
param = sig.parameters[name]
default: Tuple[Any, ...] = (
default: tuple[Any, ...] = (
() if param.default is inspect.Parameter.empty else (param.default,)
)
out = self.create_proxy(
@ -913,7 +902,7 @@ class Tracer(TracerBase):
return pytree.tree_map(replace_ph, concrete_args[name])
if name[0] == "*":
default: Tuple[Any, ...] = ()
default: tuple[Any, ...] = ()
else:
param = sig.parameters[name]
default = ( # type: ignore[assignment]
@ -932,11 +921,11 @@ class Tracer(TracerBase):
# the purposes of the wrap() API.
# We key by the globals dict id and function name to ensure we're wrapping a given
# function only once.
_wrapped_fns_to_patch: Dict[Tuple[int, str], dict] = {}
_wrapped_fns_to_patch: dict[tuple[int, str], dict] = {}
# List of methods on classes to wrap (class type, function name)
# this currently only works for Tensor.* methods that aren't traced properly
_wrapped_methods_to_patch: List[Tuple[type, str]] = []
_wrapped_methods_to_patch: list[tuple[type, str]] = []
if os.environ.get("FX_PATCH_GETITEM") == "1":
# This change is needed to trace models like PositionalEmbedding from BERT:
@ -1043,12 +1032,12 @@ class _PatchedFnSetAttr(_PatchedFn):
class _Patcher:
def __init__(self) -> None:
super().__init__()
self.patches_made: List[_PatchedFn] = []
self.visited: Set[int] = set()
self.patches_made: list[_PatchedFn] = []
self.visited: set[int] = set()
def patch(
self,
frame_dict: Dict[str, Any],
frame_dict: dict[str, Any],
name: str,
new_fn: Callable,
deduplicate: bool = True,
@ -1169,7 +1158,7 @@ def _patch_wrapped_functions(patcher: _Patcher):
def _autowrap_check(
patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int]
patcher: _Patcher, frame_dict: dict[str, Any], function_ids: set[int]
):
"""
Some methods, like `math.sqrt`, are common enough that we want to automatically wrap them as we see them.
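
The autowrap machinery patches functions like math.sqrt automatically; user code can opt a function in explicitly with torch.fx.wrap. A minimal sketch (the helper below is hypothetical):

```python
import torch
import torch.fx

@torch.fx.wrap
def size_minus_one(x):
    # Kept opaque during tracing: a call_function node is recorded instead
    # of tracing into the data-dependent Python below.
    return x.narrow(0, 0, x.shape[0] - 1)

class M(torch.nn.Module):
    def forward(self, x):
        return size_minus_one(x) * 2

gm = torch.fx.symbolic_trace(M())
```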
@ -1252,7 +1241,7 @@ def wrap(fn_or_name: Union[str, Callable]):
@compatibility(is_backward_compatible=True)
def symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
concrete_args: Optional[dict[str, Any]] = None,
) -> GraphModule:
"""
Symbolic tracing API
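
concrete_args partially specializes the traced callable, which is how Python control flow that FX cannot proxy gets baked into the graph. A minimal sketch:

```python
import torch
from torch.fx import symbolic_trace

def f(x, flag):
    # Branching on a Proxy would fail; branching on a concrete value works.
    return x * 2 if flag else x / 2

gm = symbolic_trace(f, concrete_args={"flag": True})
print(gm.code)  # only the `x * 2` branch survives in the graph
```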

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import sys
from typing import Dict, Optional
from typing import Optional
import torch
from torch._logging import LazyString
@ -43,7 +43,7 @@ def _format_graph_code(name, filename, graph_str):
return f"TRACED GRAPH\n {name} {filename} {graph_str}\n"
def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[Dict]:
def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[dict]:
"""
Returns the nn_module_stack of the first call_function node.
"""

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import operator
from collections import deque
from typing import Deque, Dict, List, NamedTuple, Set, Tuple
from typing import NamedTuple
import torch
from torch.fx.experimental.partitioner_utils import (
@ -28,15 +28,15 @@ class DAGNode:
def __init__(
self,
submodule_node: Node,
input_nodes: List[Node],
output_nodes: List[Node],
logical_device_ids: List[int],
input_nodes: list[Node],
output_nodes: list[Node],
logical_device_ids: list[int],
size_bytes: int,
) -> None:
self.submodule_node: Node = submodule_node
self.input_nodes: List[Node] = input_nodes
self.output_nodes: List[Node] = output_nodes
self.logical_device_ids: List[int] = logical_device_ids
self.input_nodes: list[Node] = input_nodes
self.output_nodes: list[Node] = output_nodes
self.logical_device_ids: list[int] = logical_device_ids
self.size_bytes = size_bytes
def __str__(self) -> str:
@ -47,14 +47,14 @@ class DAG:
"""DAG class contains all the DAG nodes"""
def __init__(self) -> None:
self.nodes: List[DAGNode] = []
self.nodes: list[DAGNode] = []
def create_node(
self,
submodule_node: Node,
input_nodes: List[Node],
output_nodes: List[Node],
logical_devices: List[int],
input_nodes: list[Node],
output_nodes: list[Node],
logical_devices: list[int],
size_bytes: int,
) -> None:
node = DAGNode(
@ -79,7 +79,7 @@ def reset_partition_device(partitions):
def combine_two_partitions(
partition_0: Partition, partition_1: Partition, partitions: List[Partition]
partition_0: Partition, partition_1: Partition, partitions: list[Partition]
) -> None:
"""Given a list of partitions and its two partitions,
combine these two partitions into a new one appending to the partitions
@ -95,7 +95,7 @@ def combine_two_partitions(
return
def set_parents_and_children(partitions: List[Partition]) -> None:
def set_parents_and_children(partitions: list[Partition]) -> None:
"""Given a list of partitions, mark parents and children for each partition"""
# Go through all nodes in a partition.
# If a node's user is in another partition,
@ -119,7 +119,7 @@ def set_parents_and_children(partitions: List[Partition]) -> None:
return
def reorganize_partitions(partitions: List[Partition]) -> None:
def reorganize_partitions(partitions: list[Partition]) -> None:
"""Given a list of partitions, reorganize partition id,
its parents and its children for each partition
"""
@ -130,17 +130,17 @@ def reorganize_partitions(partitions: List[Partition]) -> None:
return
def get_bfs_level_partition(partitions: List[Partition]) -> None:
def get_bfs_level_partition(partitions: list[Partition]) -> None:
"""Given a list of partitions,
mark the BFS level for each partition
"""
current_level: Set[Partition] = set()
visited: Set[Partition] = set()
current_level: set[Partition] = set()
visited: set[Partition] = set()
for partition in partitions:
# If a partition has no parent, it should be in the root level
if len(partition.parents) == 0:
current_level.add(partition)
next_level: Set[Partition] = set()
next_level: set[Partition] = set()
level = 0
# bfs
while current_level:
@ -158,26 +158,26 @@ def get_bfs_level_partition(partitions: List[Partition]) -> None:
return
def get_node_to_partition_mapping(partitions: List[Partition]) -> Dict[Node, int]:
def get_node_to_partition_mapping(partitions: list[Partition]) -> dict[Node, int]:
"""Given a list of partitions,return node to partition mapping"""
node_to_partition: Dict[Node, int] = {}
node_to_partition: dict[Node, int] = {}
for partition in partitions:
for node in partition.nodes:
node_to_partition[node] = partition.partition_id
return node_to_partition
def get_logical_id_to_device(devices: List[Device]) -> Dict[int, Device]:
def get_logical_id_to_device(devices: list[Device]) -> dict[int, Device]:
"""Get a mapping from device logical ID to Device object."""
logical_id_to_device: Dict[int, Device] = {}
logical_id_to_device: dict[int, Device] = {}
for d in devices:
logical_id_to_device[d.logical_id] = d
return logical_id_to_device
def get_device_partition_stats(
partitions: List[Partition], devices: List[Device]
) -> Tuple[Dict[Device, List[Partition]], Dict[Device, int], List[Partition]]:
partitions: list[Partition], devices: list[Device]
) -> tuple[dict[Device, list[Partition]], dict[Device, int], list[Partition]]:
"""Given a list of partitions and a list of devices, returns:
1. A mapping from device to partitions on it;
2. A mapping from device to its remaining memory size;
@ -186,9 +186,9 @@ def get_device_partition_stats(
# logical id to device
logical_id_to_device = get_logical_id_to_device(devices)
# Track partitions on device
device_to_partitions: Dict[Device, List[Partition]] = {}
device_to_partitions: dict[Device, list[Partition]] = {}
# Track device's left mem size
device_to_left_mem_bytes: Dict[Device, int] = {}
device_to_left_mem_bytes: dict[Device, int] = {}
for d in devices:
device_to_partitions[d] = []
device_to_left_mem_bytes[d] = d.available_mem_bytes
@ -213,16 +213,16 @@ def get_device_partition_stats(
def get_device_to_partitions_mapping(
partitions: List[Partition], devices: List[Device]
partitions: list[Partition], devices: list[Device]
):
"""Given a list of partitions and a list of devices,
map each partition into a device.
"""
def calculate_extra_mem_bytes_needed_for(
partition: Partition, partitions: List[Partition]
partition: Partition, partitions: list[Partition]
):
all_nodes: Set[Node] = set()
all_nodes: set[Node] = set()
for p in partitions:
all_nodes = all_nodes.union(p.nodes)
if len(all_nodes) == 0:
@ -273,8 +273,8 @@ def check_dependency(partition):
"""Given a partition,check if there is a circular dependency on
this partition using bfs
"""
visited: Set[Partition] = {partition}
queue: Deque[Partition] = deque([partition])
visited: set[Partition] = {partition}
queue: deque[Partition] = deque([partition])
while queue:
p = queue.popleft()
for child in p.children:
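The dependency check is a plain BFS from the partition over child edges; stripped of the Partition type it looks like this (children_of is a hypothetical accessor standing in for Partition.children):

```python
from collections import deque

def has_circular_dependency(start, children_of) -> bool:
    # Walk child edges from `start`; report a cycle if we reach it again.
    visited = {start}
    queue = deque([start])
    while queue:
        p = queue.popleft()
        for child in children_of(p):
            if child is start:
                return True
            if child not in visited:
                visited.add(child)
                queue.append(child)
    return False
```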
@ -298,9 +298,9 @@ class Partitioner:
"""
def __init__(self) -> None:
self.partitions: List[Partition] = []
self.node_to_partition: Dict[Node, int] = {}
self.devices: List[Device] = []
self.partitions: list[Partition] = []
self.node_to_partition: dict[Node, int] = {}
self.devices: list[Device] = []
def partition_graph(
self,
@ -435,9 +435,9 @@ class Partitioner:
return device
# Track partition and its left mem size
partition_to_left_mem_bytes: Dict[Partition, int] = {}
partition_to_left_mem_bytes: dict[Partition, int] = {}
# Track all the devices that have been used
occupied_devices: List[Device] = []
occupied_devices: list[Device] = []
partition = self.create_partition()
for node in self.graph_module.graph.nodes:
if node.op in {"call_module", "call_method", "call_function"}:
@ -516,7 +516,7 @@ class Partitioner:
# Devices that hold partitions
used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0]
# Track replicates of the assigned devices
replicated_device_to_used_device: Dict[Device, Device] = {}
replicated_device_to_used_device: dict[Device, Device] = {}
while len(used_devices) * 2 + len(replicated_device_to_used_device) <= len(
self.devices
@ -583,7 +583,7 @@ class Partitioner:
continue
if node.target == operator.__getitem__:
continue
input_nodes: Dict[Node, None] = {}
input_nodes: dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault)
# When a node has two or more output nodes,
@ -634,7 +634,7 @@ class Partitioner:
"""
def combine_partitions_based_on_size(
partitions: List[Partition], available_mem_bytes: int
partitions: list[Partition], available_mem_bytes: int
) -> None:
"""Combining small partitions together to keep as less partitions as possible.
Here is an example of the algorithm to do this:
@ -672,10 +672,10 @@ class Partitioner:
return mem_bytes_needed
def find_partition_to_combine_based_on_size(
sorted_partitions: List[Partition],
sorted_partitions: list[Partition],
available_mem_bytes: int,
partitions: List[Partition],
) -> Tuple[bool, List[Partition]]:
partitions: list[Partition],
) -> tuple[bool, list[Partition]]:
"""step 1 in combine_partition_based_on_size()"""
find_combination = False
smallest_partition = sorted_partitions.pop(0)
@ -721,8 +721,8 @@ class Partitioner:
return False
# Track embedding partitions and non-embedding partitions separately
embedding_partitions: List[Partition] = []
non_embedding_partitions: List[Partition] = []
embedding_partitions: list[Partition] = []
non_embedding_partitions: list[Partition] = []
# A Flag to check the boundary
in_embedding_region: bool = False
partition = self.create_partition()
@ -794,7 +794,7 @@ class Partitioner:
def cost_aware_partition(
self,
transfer_rate_bytes_per_sec: float,
node_to_latency_mapping: Dict[Node, NodeLatency],
node_to_latency_mapping: dict[Node, NodeLatency],
) -> None:
"""This method is to partition the fx module based on the cost.
The cost is the total latency of running the whole fx module.
@ -872,7 +872,7 @@ class Partitioner:
)
if len(self.partitions) == 1:
return False
partition_pair: List[int] = []
partition_pair: list[int] = []
for i in range(len(self.partitions) - 1):
for j in range(i + 1, len(self.partitions)):
# Try to combine the partition pair
@ -915,7 +915,7 @@ class Partitioner:
def kl_based_partition(
self,
transfer_rate_bytes_per_sec: float,
node_to_latency_mapping: Dict[Node, NodeLatency],
node_to_latency_mapping: dict[Node, NodeLatency],
) -> None:
"""This function is a cost aware partition based
on Kernighan-Lin algorithm.
@ -987,7 +987,7 @@ class Partitioner:
"""
p1_nodes = list(p1.nodes) + [None]
min_cost = float("inf")
node_pair: List[Node] = []
node_pair: list[Node] = []
for n1 in p1_nodes:
# Ignore the node if it is not a op node
if n1 is not None and n1.op in {"placeholder", "get_attr"}:
@ -1011,9 +1011,9 @@ class Partitioner:
self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec
)
# Keep tracking the node pair that shows the better cost
node_pair: List[Node] = []
node_pair: list[Node] = []
# Keep tracking the partition pair of node pair
partition_pair: List[Partition] = []
partition_pair: list[Partition] = []
# Collect all the op nodes from the graph
op_nodes = [
n
@ -1060,7 +1060,7 @@ class Partitioner:
"""This function helps to rebuild the partitions given the nodes and its
corresponding partition id
"""
partition_id_to_partition_mapping: Dict[int, Partition] = {}
partition_id_to_partition_mapping: dict[int, Partition] = {}
self.node_to_partition = node_to_partition_mapping
for node in self.node_to_partition:
partition_id = self.node_to_partition[node]

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import re
from typing import Callable, Dict, Optional, Set, Union
from typing import Callable, Optional, Union
import torch.fx
from torch.fx.node import map_arg
@ -100,7 +100,7 @@ def _inline_module(gm: torch.fx.GraphModule, inline_mod_name: str):
call_mod_args = call_mod_node_to_replace.args
call_mod_kwargs = call_mod_node_to_replace.kwargs
replacement_mapping: Dict[torch.fx.Node, torch.fx.Node] = {}
replacement_mapping: dict[torch.fx.Node, torch.fx.Node] = {}
ph_count = 0
def replacement_fn(node):
@ -171,7 +171,7 @@ def split_const_subgraphs(
# Build up a list of const_nodes, defined as nodes that are themselves
# get_attrs, or have all get_attr or other constant node inputs.
const_nodes: Set[torch.fx.Node] = set()
const_nodes: set[torch.fx.Node] = set()
found_const_folding = False
for node in mod_traced.graph.nodes:
# Skip over placeholders/outputs because they can't be const folded and

View File

@ -1,4 +1,4 @@
from typing import List, Sequence
from collections.abc import Sequence
import torch.fx as fx
@ -19,7 +19,7 @@ def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
the `gm` with breakpoint inserted.
"""
def insert_pdb(body: Sequence[str]) -> List[str]:
def insert_pdb(body: Sequence[str]) -> list[str]:
return ["import pdb; pdb.set_trace()\n", *body]
with gm.graph.on_generate_code(
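on_generate_code installs a codegen hook; the insert_pdb helper above prepends a breakpoint to every generated forward body. A sketch of wiring it up by hand (the lambda shape follows the Graph.on_generate_code API, which passes the currently registered transformer, possibly None, and expects a new one back):

```python
import torch
from torch.fx import symbolic_trace

gm = symbolic_trace(torch.nn.ReLU())

with gm.graph.on_generate_code(
    lambda current: (lambda body: ["import pdb; pdb.set_trace()\n", *body])
):
    gm.recompile()  # regenerate gm.code with the hook active
print(gm.code)
```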

View File

@ -2,7 +2,7 @@
import itertools
import operator
from functools import reduce
from typing import Callable, Dict, TypeVar
from typing import Callable, TypeVar
from typing_extensions import ParamSpec
import sympy
@ -19,9 +19,9 @@ from torch.nn.modules.conv import Conv2d
_T = TypeVar("_T")
_P = ParamSpec("_P")
_INFERENCE_RULES: Dict[Target, Callable] = {}
_REFINEMENT_RULES: Dict[Target, Callable] = {}
_RULES: Dict[Target, Callable] = {}
_INFERENCE_RULES: dict[Target, Callable] = {}
_REFINEMENT_RULES: dict[Target, Callable] = {}
_RULES: dict[Target, Callable] = {}
__all__ = [
"GraphTypeChecker",

View File

@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import itertools
import operator
from typing import Dict, List, Tuple
import torch
from torch.fx._symbolic_trace import symbolic_trace
@ -10,8 +9,8 @@ from torch.fx.passes.tools_common import legalize_graph
def split_result_tensors(
result: torch.Tensor, inputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, ...]:
result: torch.Tensor, inputs: list[torch.Tensor]
) -> tuple[torch.Tensor, ...]:
"""
A free function for use in the merge_matmul graph transformation below that
splits the output from a merged matmul into the individual results for each
@ -71,7 +70,7 @@ def may_depend_on(a: Node, b: Node, search_depth: int = 6):
return False
def are_nodes_independent(nodes: List[Node]):
def are_nodes_independent(nodes: list[Node]):
"""
Check if all of the given nodes are pairwise data-independent.
@ -102,8 +101,8 @@ def merge_matmul(in_mod: torch.nn.Module):
"""
gm = symbolic_trace(in_mod)
rhs_users: Dict[Node, List[Node]] = {}
lhs_users: Dict[Node, List[Node]] = {}
rhs_users: dict[Node, list[Node]] = {}
lhs_users: dict[Node, list[Node]] = {}
# Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to
# the matmul of which they are the LHS/RHS.
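The splitting step itself is simple: the merged matmul concatenates the LHS operands along dim 0, so the output splits back by each input's row count. A self-contained sketch of that idea (hedged: the file's own helper may differ in edge-case handling):

```python
import torch

def split_result_tensors(result: torch.Tensor, inputs: list[torch.Tensor]) -> tuple[torch.Tensor, ...]:
    # Split the merged output back into one chunk per original LHS operand.
    splits = [x.shape[0] for x in inputs]
    return torch.split(result, splits)

a, b = torch.randn(2, 4), torch.randn(3, 4)
w = torch.randn(4, 5)
merged = torch.cat([a, b]) @ w
ra, rb = split_result_tensors(merged, [a, b])
assert torch.allclose(ra, a @ w) and torch.allclose(rb, b @ w)
```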

View File

@ -2,7 +2,7 @@
import builtins
import functools
import warnings
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Optional, Union
import torch
import torch.fx
@ -40,7 +40,7 @@ def torch_abs_override(input, *, out=None):
return input
manual_meta_overrides: Dict[Callable, Callable] = {
manual_meta_overrides: dict[Callable, Callable] = {
torch.nn.Embedding: embedding_override,
torch.nn.LayerNorm: nn_layernorm_override,
torch.relu: torch_relu_override,
@ -274,7 +274,7 @@ class MetaTracer(torch.fx.Tracer):
def proxy(self, node):
return MetaProxy(node, self)
def trace(self, root, meta_args: Dict[str, torch.Tensor], concrete_args=None): # type: ignore[override]
def trace(self, root, meta_args: dict[str, torch.Tensor], concrete_args=None): # type: ignore[override]
assert isinstance(meta_args, dict)
self.meta_args = meta_args
@ -299,8 +299,8 @@ class MetaTracer(torch.fx.Tracer):
def symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
meta_args: Optional[Dict[str, torch.Tensor]] = None,
concrete_args: Optional[Dict[str, Any]] = None,
meta_args: Optional[dict[str, torch.Tensor]] = None,
concrete_args: Optional[dict[str, Any]] = None,
) -> torch.fx.GraphModule:
tracer = MetaTracer()
graph = tracer.trace(root, meta_args, concrete_args) # type: ignore[arg-type]
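A usage sketch, assuming this hunk is torch/fx/experimental/meta_tracer.py (the module path is an assumption inferred from the symbols shown):

```python
import torch
from torch.fx.experimental.meta_tracer import symbolic_trace as meta_trace

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

# Shapes are propagated on the meta device, so no real data is allocated.
gm = meta_trace(M(), meta_args={"x": torch.empty(8, device="meta")})
```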

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
import operator
import warnings
from typing import Callable, Dict, Iterable, TypeVar
from collections.abc import Iterable
from typing import Callable, TypeVar
from typing_extensions import ParamSpec
import torch
@ -57,7 +58,7 @@ from torch.nn.modules.conv import Conv2d
_T = TypeVar("_T")
_P = ParamSpec("_P")
_INFERENCE_RULES: Dict[Target, Callable] = {}
_INFERENCE_RULES: dict[Target, Callable] = {}
MAX_TENSOR_RANK = 4

View File

@ -1,7 +1,7 @@
# mypy: ignore-errors
import copy
import itertools
from typing import Callable, Dict, List
from typing import Callable
from torch.fx.experimental.migrate_gradual_types.constraint import (
ApplyBroadcasting,
@ -50,7 +50,7 @@ from torch.fx.experimental.migrate_gradual_types.util import (
from torch.fx.tensor_type import Dyn, TensorType
_TRANSFORMATION_RULES: Dict[Constraint, Callable] = {}
_TRANSFORMATION_RULES: dict[Constraint, Callable] = {}
def register_transformation_rule(call_target):
@ -797,7 +797,7 @@ def transform_constraint(constraint: Constraint, counter: int):
return constraint, counter
def calc_last_two_dims(constraint, d: List[DVar]):
def calc_last_two_dims(constraint, d: list[DVar]):
"""
Generates constraints for the last two dimensions of a convolution or a maxpool output
Args:
@ -866,7 +866,7 @@ def calc_last_two_dims(constraint, d: List[DVar]):
return c4, c5
def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]):
def generate_all_int_dyn_dim_possibilities(my_list: list[DVar]):
"""
Generate all possibilities of each element of my_list being equal or not equal to Dyn
Args:
@ -888,7 +888,7 @@ def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]):
return all_possibilities
def is_target_div_by_dim(target: List[int], dim: List[DVar]):
def is_target_div_by_dim(target: list[int], dim: list[DVar]):
"""
Generate constraints to check if the target dimensions are divisible by the input dimensions
Args:
@ -901,7 +901,7 @@ def is_target_div_by_dim(target: List[int], dim: List[DVar]):
return BinConstraintD(BinConstraintD(Prod(target), dim, op_mod), 0, op_eq)
def is_dim_div_by_target(target: List[int], dim: List[DVar]):
def is_dim_div_by_target(target: list[int], dim: list[DVar]):
"""
Generate constraints to check if the input dimensions are divisible by the target dimensions
Args:
@ -1000,9 +1000,9 @@ def apply_padding(
e11: BinConstraintT,
e2: BinConstraintT,
e12: BinConstraintT,
d2: List[DVar],
d11: List[DVar],
d12: List[DVar],
d2: list[DVar],
d11: list[DVar],
d12: list[DVar],
counter: int,
):
"""
@ -1068,7 +1068,7 @@ def apply_padding(
def no_broadcast_dim_with_index(
d1: List[DVar], d2: List[DVar], d3: List[DVar], d4: List[DVar], i: int
d1: list[DVar], d2: list[DVar], d3: list[DVar], d4: list[DVar], i: int
):
"""
Args:
@ -1129,10 +1129,10 @@ def create_equality_constraints_for_broadcasting(
e2: TVar,
e11: TVar,
e12: TVar,
d1: List[DVar],
d2: List[DVar],
d11: List[DVar],
d12: List[DVar],
d1: list[DVar],
d2: list[DVar],
d11: list[DVar],
d12: list[DVar],
):
"""
Create equality constraints for when no broadcasting occurs
@ -1236,7 +1236,7 @@ def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int):
def generate_all_broadcasting_possibilities_no_padding(
d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar]
d1: list[DVar], d2: list[DVar], d11: list[DVar], d12: list[DVar]
):
"""
Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension.

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import operator
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Callable, Optional
import torch
import torch.fx
@ -38,7 +38,7 @@ class NormalizeArgs(Transformer):
self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True
):
super().__init__(module)
self.node_map: Dict[Proxy, Node] = {}
self.node_map: dict[Proxy, Node] = {}
self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs
def run_node(self, n: Node) -> Any:
@ -66,10 +66,10 @@ class NormalizeArgs(Transformer):
def call_function(
self,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Any],
arg_types: Optional[Tuple[Any, ...]] = None,
kwarg_types: Optional[Dict[str, Any]] = None,
args: tuple[Argument, ...],
kwargs: dict[str, Any],
arg_types: Optional[tuple[Any, ...]] = None,
kwarg_types: Optional[dict[str, Any]] = None,
):
assert callable(target)
new_args_and_kwargs = normalize_function(
@ -89,7 +89,7 @@ class NormalizeArgs(Transformer):
return super().call_function(target, args, kwargs)
def call_module(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
):
assert isinstance(target, str)
new_args_and_kwargs = normalize_module(
@ -124,7 +124,7 @@ class NormalizeOperators(AnnotateTypesWithSchema):
traced = NormalizeOperators(traced).transform()
"""
binary_magic_method_remap: Dict[
binary_magic_method_remap: dict[
Callable[[Any, Any], Any], Callable[[Any, Any], Any]
] = {
torch.add: operator.add,
@ -142,7 +142,7 @@ class NormalizeOperators(AnnotateTypesWithSchema):
}
def call_function(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
):
# Normalize operators according to the magic methods implemented on tensors here:
# https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950
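
NormalizeArgs is built on normalize_function, which resolves positional arguments against an operator's schema. A small sketch of calling it directly (the printed key names are an expectation, not guaranteed output):

```python
import torch
from torch.fx.operator_schemas import normalize_function

# Returns an ArgsKwargsPair with everything moved into kwargs,
# or None if the overload cannot be resolved unambiguously.
pair = normalize_function(
    torch.add,
    args=(torch.randn(2), torch.randn(2)),
    normalize_to_only_use_kwargs=True,
)
if pair is not None:
    print(sorted(pair.kwargs))  # expect something like ['input', 'other']
```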

View File

@ -4,8 +4,9 @@ import logging
import operator
import time
from collections import defaultdict
from collections.abc import Iterable
from enum import Enum
from typing import Any, cast, Dict, Iterable, List, Optional, Tuple, Type
from typing import Any, cast, Optional
import torch
import torch.fx as fx
@ -33,7 +34,7 @@ __all__ = [
]
def _parent_name(target: str) -> Tuple[str, str]:
def _parent_name(target: str) -> tuple[str, str]:
"""
Splits a qualname into parent path and last atom.
For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
@ -44,11 +45,11 @@ def _parent_name(target: str) -> Tuple[str, str]:
# Works for length 2 patterns with 2 modules
def matches_module_pattern(
pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]
pattern: Iterable[type], node: fx.Node, modules: dict[str, Any]
):
if len(node.args) == 0:
return False
nodes: Tuple[Any, fx.Node] = (node.args[0], node)
nodes: tuple[Any, fx.Node] = (node.args[0], node)
for expected_type, current_node in zip(pattern, nodes):
if not isinstance(current_node, fx.Node):
return False
@ -64,7 +65,7 @@ def matches_module_pattern(
def replace_node_module(
node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module
node: fx.Node, modules: dict[str, Any], new_module: torch.nn.Module
):
assert isinstance(node.target, str)
parent_name, name = _parent_name(node.target)
@ -120,7 +121,7 @@ def remove_dropout(model: nn.Module) -> nn.Module:
class DropoutRemover(torch.fx.Transformer):
def call_module(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
) -> Any:
if isinstance(self.submodules[target], nn.Dropout):
assert len(args) == 1
@ -133,15 +134,15 @@ def remove_dropout(model: nn.Module) -> nn.Module:
def extract_subgraph(
orig_module: nn.Module,
nodes: List[fx.Node],
inputs: List[fx.Node],
outputs: List[fx.Node],
nodes: list[fx.Node],
inputs: list[fx.Node],
outputs: list[fx.Node],
):
"""
Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph.
"""
new_graph = fx.Graph()
env: Dict[fx.Node, fx.Node] = {}
env: dict[fx.Node, fx.Node] = {}
for input in inputs:
new_node = new_graph.placeholder(input.name)
env[input] = new_node
@ -180,13 +181,13 @@ mkldnn_map = {
}
def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]):
def modules_to_mkldnn(nodes: list[fx.Node], modules: dict[str, nn.Module]):
"""
For each node, if it's a module that can be preconverted into MKLDNN,
then we do so and create a mapping to allow us to convert from the MKLDNN
version of the module to the original.
"""
old_modules: Dict[nn.Module, nn.Module] = {}
old_modules: dict[nn.Module, nn.Module] = {}
for node in nodes:
if node.op == "call_module":
assert isinstance(node.target, str)
@ -200,9 +201,9 @@ def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]):
def reset_modules(
nodes: List[fx.Node],
modules: Dict[str, nn.Module],
old_modules: Dict[nn.Module, nn.Module],
nodes: list[fx.Node],
modules: dict[str, nn.Module],
old_modules: dict[nn.Module, nn.Module],
):
"""
Maps each module that's been changed with `modules_to_mkldnn` back to its
@ -219,9 +220,9 @@ def reset_modules(
class MklSubgraph:
def __init__(self, fx_graph: fx.Graph):
self.fx_graph = fx_graph
self.nodes: List[fx.Node] = []
self.start_nodes: List[fx.Node] = []
self.end_nodes: List[fx.Node] = []
self.nodes: list[fx.Node] = []
self.start_nodes: list[fx.Node] = []
self.end_nodes: list[fx.Node] = []
def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
@ -244,7 +245,7 @@ def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
old_modules = graph.fx_graph.old_modules # type: ignore[attr-defined]
ShapeProp(fx_model).propagate(example_inputs)
sample_inputs = [torch.randn(node.shape) for node in input_nodes] # type: ignore[attr-defined]
output_args = cast(List[fx.Node], [node.args[0] for node in graph.end_nodes])
output_args = cast(list[fx.Node], [node.args[0] for node in graph.end_nodes])
submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args)
def benchmark(f):
@ -281,8 +282,8 @@ def use_mkl_length(graph: MklSubgraph) -> bool:
class UnionFind:
def __init__(self, n):
self.parent: List[Optional[int]] = [None] * n
self.size: List[int] = [0] * n
self.parent: list[Optional[int]] = [None] * n
self.size: list[int] = [0] * n
def make_set(self, v: int):
self.parent[v] = v
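The diff only shows the fields and make_set; for reference, a self-contained sketch of the structure with the standard find/join (union by size with path halving), not necessarily identical to the file's implementation:

```python
from typing import Optional

class UnionFindSketch:
    def __init__(self, n: int) -> None:
        self.parent: list[Optional[int]] = [None] * n
        self.size: list[int] = [0] * n

    def make_set(self, v: int) -> None:
        self.parent[v] = v
        self.size[v] = 1

    def find(self, v: int) -> int:
        while self.parent[v] != v:
            self.parent[v] = self.parent[self.parent[v]]  # path halving
            v = self.parent[v]
        return v

    def join(self, a: int, b: int) -> None:
        a, b = self.find(a), self.find(b)
        if a == b:
            return
        if self.size[a] < self.size[b]:
            a, b = b, a  # attach the smaller tree under the larger
        self.parent[b] = a
        self.size[a] += self.size[b]
```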
@ -308,8 +309,8 @@ class UnionFind:
def optimize_for_inference(
model: torch.nn.Module,
pass_config: Optional[Dict[str, Any]] = None,
tracer: Type[fx.Tracer] = fx.Tracer,
pass_config: Optional[dict[str, Any]] = None,
tracer: type[fx.Tracer] = fx.Tracer,
) -> torch.nn.Module:
"""
Performs a set of optimization passes to optimize a model for the
@ -348,7 +349,7 @@ def optimize_for_inference(
cur_tracer = tracer()
fx_graph = cur_tracer.trace(copy.deepcopy(model))
fx.GraphModule(cur_tracer.root, fx_graph)
modules: Dict[str, nn.Module] = dict(model.named_modules())
modules: dict[str, nn.Module] = dict(model.named_modules())
class MklSupport(Enum):
NO = 1
@ -388,7 +389,7 @@ def optimize_for_inference(
node.args, lambda n: fx_graph.call_method("to_mkldnn", (n,))
)
node.args = cast(Tuple[fx.node.Argument], mkldnn_args)
node.args = cast(tuple[fx.node.Argument], mkldnn_args)
with fx_graph.inserting_after(node):
dense_x = fx_graph.create_node("call_method", "to_dense", (node,))
@ -455,7 +456,7 @@ def optimize_for_inference(
for other_color in cur_colors[1:]:
uf.join(cur_colors[0], other_color)
mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph))
mkldnn_graphs: dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph))
for node in fx_graph.nodes:
if hasattr(node, "color"):
mkldnn_graphs[uf.find(node.color)].nodes.append(node)

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
from enum import Enum
from typing import Dict, List, NamedTuple, Set
from typing import NamedTuple
from torch.fx.node import map_arg, Node
@ -11,13 +11,13 @@ class Partition:
"""
def __init__(self, partition_id: int) -> None:
self.nodes: Set[Node] = set()
self.nodes: set[Node] = set()
self.partition_id = partition_id
self.parents: Set[Partition] = set()
self.children: Set[Partition] = set()
self.parents: set[Partition] = set()
self.children: set[Partition] = set()
self.bfs_level: int = -1
self.used_mem_bytes: int = 0
self.logical_device_ids: List[int] = []
self.logical_device_ids: list[int] = []
def __str__(self):
return str(self.partition_id)
@ -28,7 +28,7 @@ class Partition:
self.used_mem_bytes += get_extra_size_of(node, self.nodes)
def add_node(self, node):
input_nodes: Dict[Node, None] = {}
input_nodes: dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault)
# Add current node's input nodes if they are placeholder or constants
@ -43,7 +43,7 @@ class Partition:
if node in self.nodes:
self.nodes.remove(node)
# Collect the node's input nodes
input_nodes: Dict[Node, None] = {}
input_nodes: dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault)
# Check if an input node is a placeholder or get_attr,
@ -88,23 +88,23 @@ class PartitionMode(Enum):
class PartitionerConfig(NamedTuple):
devices: List[Device]
devices: list[Device]
mode: PartitionMode = PartitionMode.size_based
transfer_rate_bytes_per_sec: float = 0.0
node_to_latency_mapping: Dict[Node, NodeLatency] = {}
node_to_partition_mapping: Dict[Node, int] = {}
partition_to_logical_device_mapping: Dict[int, List[int]] = {}
node_to_latency_mapping: dict[Node, NodeLatency] = {}
node_to_partition_mapping: dict[Node, int] = {}
partition_to_logical_device_mapping: dict[int, list[int]] = {}
# Saturate host by replicating partitions to the remaining idle devices.
saturate_host: bool = False
def get_extra_size_of(node: Node, nodes: Set[Node]) -> int:
def get_extra_size_of(node: Node, nodes: set[Node]) -> int:
"""Given a node and a set of nodes,
this function returns the extra size needed
if this node is included in the set.
"""
# Find all its input nodes
input_nodes: Dict[Node, None] = {}
input_nodes: dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault)
# Calculate total size of related nodes
@ -127,18 +127,18 @@ def get_extra_size_of(node: Node, nodes: Set[Node]) -> int:
def get_latency_of_one_partition(
partition: Partition, node_to_latency_mapping: Dict[Node, NodeLatency]
partition: Partition, node_to_latency_mapping: dict[Node, NodeLatency]
) -> PartitionLatency:
"""Given a partition and its nodes' latency, return a PartitionLatency for this partition"""
def get_top_nodes(partition: Partition) -> List[Node]:
def get_top_nodes(partition: Partition) -> list[Node]:
"""Given a partition, return a list of nodes on the top bfs level"""
top_nodes: List[Node] = []
top_nodes: list[Node] = []
for node in partition.nodes:
# Skip placeholder and get_attr nodes
if node.op in {"placeholder", "get_attr"}:
continue
input_nodes: Dict[Node, None] = {}
input_nodes: dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault)
# If a node has no input nodes in this partition,
@ -216,12 +216,12 @@ def get_latency_of_one_partition(
def get_partition_to_latency_mapping(
partitions: List[Partition], node_to_latency_mapping: Dict[Node, NodeLatency]
) -> Dict[Partition, PartitionLatency]:
partitions: list[Partition], node_to_latency_mapping: dict[Node, NodeLatency]
) -> dict[Partition, PartitionLatency]:
"""Given all the partitions and node_to_latency_mapping dictionary,
return a mapping dictionary of each partition to its overall latency
"""
partition_to_latency_mapping: Dict[Partition, PartitionLatency] = {}
partition_to_latency_mapping: dict[Partition, PartitionLatency] = {}
# Go through each partition and get its latency
for partition in partitions:
partition_latency = get_latency_of_one_partition(
@ -255,7 +255,7 @@ def get_comm_latency_between(
# the output size of those input nodes will be counted
# and added to comm_size
for node in child_partition.nodes:
input_nodes: Dict[Node, None] = {}
input_nodes: dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault)
for n in input_nodes:
@ -268,8 +268,8 @@ def get_comm_latency_between(
def get_latency_of_partitioned_graph(
partitions: List[Partition],
partition_to_latency_mapping: Dict[Partition, PartitionLatency],
partitions: list[Partition],
partition_to_latency_mapping: dict[Partition, PartitionLatency],
transfer_rate_bytes_per_sec: float,
):
"""Given all partitions in a graph, find the critical path among all partitions
@ -298,7 +298,7 @@ def get_latency_of_partitioned_graph(
return max_latency_sec
return latency_so_far_sec
def get_top_partitions(partitions: List[Partition]) -> List[Partition]:
def get_top_partitions(partitions: list[Partition]) -> list[Partition]:
"""This function is to return all the partitions without parents
as the starting points of all the paths
"""

View File

@ -17,21 +17,15 @@ import typing_extensions
import warnings
import weakref
from collections import defaultdict
from collections.abc import Generator, Mapping, Sequence
from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext
from dataclasses import dataclass
from typing import (
Any,
Callable,
Dict,
Generator,
List,
Mapping,
Optional,
overload,
Protocol,
Sequence,
Tuple,
Type,
TYPE_CHECKING,
TypeVar,
Union,
@ -168,7 +162,7 @@ from torch.types import py_sym_types, PySymType
class _HasMeta(Protocol):
meta: Dict[str, PySymType]
meta: dict[str, PySymType]
def is_sym_node(node: _HasMeta) -> bool:
@ -377,9 +371,9 @@ _ExtractValType = Optional[
PySymType,
_AnyScriptObjectType,
BackwardState,
List["_ExtractValType"],
Tuple["_ExtractValType", ...],
Dict[str, "_ExtractValType"],
list["_ExtractValType"],
tuple["_ExtractValType", ...],
dict[str, "_ExtractValType"],
Tensor,
int,
float,
@ -767,10 +761,10 @@ def proxy_call(
proxy_mode: ProxyTorchDispatchMode,
func: OpOverload,
pre_dispatch: bool,
args: Tuple[object, ...],
kwargs: Dict[str, object],
args: tuple[object, ...],
kwargs: dict[str, object],
) -> object:
unrecognized_types: List[Type] = []
unrecognized_types: list[type] = []
flat_args_kwargs, spec = pytree.tree_flatten((args, kwargs))
def can_handle_tensor(x: Tensor) -> bool:
@ -987,7 +981,7 @@ class _SymNodeDict:
"""
def __init__(self) -> None:
self.sym_node_dict: Dict[PySymType, _PySymProxyType] = {}
self.sym_node_dict: dict[PySymType, _PySymProxyType] = {}
def __setitem__(self, key: PySymType, value: _PySymProxyType) -> None:
self.sym_node_dict[key.node] = value
@ -1015,9 +1009,9 @@ class _SymNodeDict:
class PythonKeyTracer(Tracer):
script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy]
symnode_tracker: _SymNodeDict
sympy_expr_tracker: Dict[sympy.Symbol, object]
sympy_expr_tracker: dict[sympy.Symbol, object]
tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
torch_fn_counts: Dict[OpOverload, int]
torch_fn_counts: dict[OpOverload, int]
enable_thunkify: bool = False
def __init__(self) -> None:
@ -1043,14 +1037,14 @@ class PythonKeyTracer(Tracer):
self,
m: Module,
forward: Callable[..., Any],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> Any:
return forward(*args, **kwargs)
# We don't want to turn getattr calls into proxies. So we just return the actual value.
def getattr(
self, attr: str, attr_val: object, parameter_proxy_cache: Dict[str, Proxy]
self, attr: str, attr_val: object, parameter_proxy_cache: dict[str, Proxy]
) -> object:
return attr_val
@ -1095,7 +1089,7 @@ class PythonKeyTracer(Tracer):
def _make_temp_remove_mode_context_manager(
mode_ty: Type[TorchFunctionMode],
mode_ty: type[TorchFunctionMode],
) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]:
@contextmanager
def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]:
@ -1137,7 +1131,7 @@ def _make_temp_remove_mode_context_manager(
def dispatch_trace(
root: Union[Module, Callable],
tracer: Tracer,
concrete_args: Optional[Tuple[Any, ...]] = None,
concrete_args: Optional[tuple[Any, ...]] = None,
) -> GraphModule:
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
@ -1235,9 +1229,9 @@ class TorchFunctionMetadataMode(TorchFunctionMode):
def __torch_function__(
self,
func: OpOverload,
types: Tuple[torch._C._TensorMeta, ...],
args: Tuple[object, ...] = (),
kwargs: Optional[Dict[str, object]] = None,
types: tuple[torch._C._TensorMeta, ...],
args: tuple[object, ...] = (),
kwargs: Optional[dict[str, object]] = None,
) -> object:
kwargs = kwargs or {}
self.tracer.torch_fn_metadata = func
@ -1259,14 +1253,14 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
# The input to torch.amp.autocast_mode._exit_autocast graph node should be the
# enter_autocast node. So we have to save the enter autocast node here, and assign it
# to the exit_autocast call_function node.
self.enter_autocast_nodes: List[torch.fx.Node] = []
self.enter_autocast_nodes: list[torch.fx.Node] = []
def __torch_function__(
self,
func: Union[OpOverload, Callable],
types: Tuple[torch._C._TensorMeta, ...],
args: Tuple[object, ...] = (),
kwargs: Optional[Dict[str, object]] = None,
types: tuple[torch._C._TensorMeta, ...],
args: tuple[object, ...] = (),
kwargs: Optional[dict[str, object]] = None,
) -> object:
kwargs = kwargs or {}
if func in _side_effectful_need_to_be_preserved_pre_dispatch:
@ -1324,7 +1318,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
# Every time we enter a mode, we maintain a stack telling us what the previous
# ProxyTorchDispatchMode state was (if there was any).
# This lets us properly reset the state on exit.
self.enter_stack: List[Optional[ProxyTorchDispatchMode]] = []
self.enter_stack: list[Optional[ProxyTorchDispatchMode]] = []
self.decomp_layers = 0
from torch._inductor import config
@ -1334,9 +1328,9 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
def __torch_dispatch__(
self,
func: OpOverload,
types: Tuple[torch._C._TensorMeta, ...],
args: Tuple[object, ...] = (),
kwargs: Optional[Dict[str, object]] = None,
types: tuple[torch._C._TensorMeta, ...],
args: tuple[object, ...] = (),
kwargs: Optional[dict[str, object]] = None,
) -> object:
with set_original_aten_op(func):
kwargs = kwargs or {}
@ -1354,7 +1348,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[types.TracebackType],
) -> Optional[bool]:
@ -1372,10 +1366,10 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
return True
def _compute_proxy(
self, func: OpOverload, args: Tuple[object, ...], out: PySymType
self, func: OpOverload, args: tuple[object, ...], out: PySymType
) -> Proxy:
# Handle torch.sym_sum
n_args: Tuple[object, ...]
n_args: tuple[object, ...]
if len(args) == 1 and isinstance(args[0], (list, tuple)):
n_args = (
tuple(
@ -1403,9 +1397,9 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
def __sym_dispatch__(
self,
func: OpOverload,
types: Tuple[torch._C._TensorMeta, ...],
args: Tuple[object, ...],
kwargs: Dict[str, object],
types: tuple[torch._C._TensorMeta, ...],
args: tuple[object, ...],
kwargs: dict[str, object],
) -> object:
# Peephole optimize multiply by one
# NB: be careful not to trigger guards here!
@ -1438,9 +1432,9 @@ class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer):
script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy]
symnode_tracker: MutableMapping[PySymType, _PySymProxyType]
tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
sympy_expr_tracker: Dict[sympy.Symbol, object]
sympy_expr_tracker: dict[sympy.Symbol, object]
torch_fn_metadata: Optional[OpOverload]
torch_fn_counts: Dict[OpOverload, int]
torch_fn_counts: dict[OpOverload, int]
enable_thunkify: bool = False
def __init__(self, graph: fx.graph.Graph) -> None:
@ -1476,7 +1470,7 @@ class DecompositionInterpreter(fx.Interpreter):
self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")
def placeholder(
self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object] # type: ignore[override]
self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override]
) -> object:
out = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer)
@ -1485,7 +1479,7 @@ class DecompositionInterpreter(fx.Interpreter):
return out
def get_attr(
self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object] # type: ignore[override]
self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override]
) -> object:
out = super().get_attr(target, args, kwargs) # type: ignore[arg-type]
proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer)
@ -1495,7 +1489,7 @@ class DecompositionInterpreter(fx.Interpreter):
# call_function, call_method, call_module get traced automatically by the outer mode.
def output(
self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object] # type: ignore[override]
self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override]
) -> object:
out = super().output(target, args, kwargs) # type: ignore[arg-type]
@ -1516,13 +1510,13 @@ class DecompositionInterpreter(fx.Interpreter):
def wrapper_and_args_for_make_fx(
func: Callable[..., R], args: Tuple[object, ...], kwargs: Dict[str, object]
) -> Tuple[Callable[[List[object]], R], List[object]]:
func: Callable[..., R], args: tuple[object, ...], kwargs: dict[str, object]
) -> tuple[Callable[[list[object]], R], list[object]]:
# make_fx doesn't support kwargs, so we need to do this flattening
# and then unflatten the args before calling func
flat_args, spec = pytree.tree_flatten((args, kwargs))
def wrapped(flat_args: List[object]) -> R:
def wrapped(flat_args: list[object]) -> R:
fn_args, fn_kwargs = pytree.tree_unflatten(flat_args, spec)
return func(*fn_args, **fn_kwargs)
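The flatten/unflatten pattern above is how kwargs get funneled through an args-only `make_fx`. A minimal, self-contained sketch of the same trick, where the function `f` and its values are illustrative, not from this commit:

```python
import torch.utils._pytree as pytree

def f(x, *, scale=1.0):
    return x * scale

args, kwargs = (2.0,), {"scale": 3.0}
# Flatten (args, kwargs) into a flat list of leaves plus a spec
# that remembers the original structure.
flat_args, spec = pytree.tree_flatten((args, kwargs))

def wrapped(flat):
    # Rebuild the original (args, kwargs) and forward the call.
    fn_args, fn_kwargs = pytree.tree_unflatten(flat, spec)
    return f(*fn_args, **fn_kwargs)

assert wrapped(flat_args) == 6.0  # the args-only wrapper reproduces the call
```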
@ -1642,7 +1636,7 @@ class _ModuleStackTracer(PythonKeyTracer):
return tracer.proxy_modules[self]
@property
def _modules(self) -> Dict[str, AttrProxy]:
def _modules(self) -> dict[str, AttrProxy]:
assert "_modules" in self.__dict__
submodules = self.__dict__["_modules"]
assert isinstance(submodules, dict)
@ -1674,7 +1668,7 @@ class _ModuleStackTracer(PythonKeyTracer):
raise _ModuleNotInstalledAsSubmoduleError from e
def getattr(
self, attr: str, attr_val: object, parameter_proxy_cache: Dict[str, Proxy]
self, attr: str, attr_val: object, parameter_proxy_cache: dict[str, Proxy]
) -> object:
if (
not isinstance(attr_val, Module)
@ -1693,7 +1687,7 @@ class _ModuleStackTracer(PythonKeyTracer):
return self.attr_proxy_map[attr_val]
def trace( # type: ignore[override]
self, root: Union[Module, Callable], concrete_args: Optional[Dict[str, object]]
self, root: Union[Module, Callable], concrete_args: Optional[dict[str, object]]
) -> fx.Graph:
res = super().trace(root, concrete_args)
@ -1702,7 +1696,7 @@ class _ModuleStackTracer(PythonKeyTracer):
# to the tracer while tracing, the proxy object gets registered
# first. So we need to replace the proxy modules with the real ones
# This can happen during HOO tracing
proxy_module_names_to_be_replaced: List[Tuple[str, _AttrProxy]] = []
proxy_module_names_to_be_replaced: list[tuple[str, _AttrProxy]] = []
for name, module in self.root.named_modules():
if module in self.proxy_modules:
proxy_module_names_to_be_replaced.append((name, module))
@ -1746,8 +1740,8 @@ class _ModuleStackTracer(PythonKeyTracer):
self,
m: Module,
forward: Callable,
args: Tuple[object, ...],
kwargs: Dict[str, object],
args: tuple[object, ...],
kwargs: dict[str, object],
) -> None:
"""PythonKeyTracer overrides call_module to avoid the scope handling,
but we actually want it.
@ -1857,7 +1851,7 @@ class _MakefxTracer:
) -> None:
# Configurations that are used to initialize the context managers and their states.
# Should not modify them during tracing.
self.decomposition_table: Dict[OpOverload, Callable] = dict(
self.decomposition_table: dict[OpOverload, Callable] = dict(
decomposition_table or {}
)
self.decomposition_table.setdefault(
@ -1885,7 +1879,7 @@ class _MakefxTracer:
nullcontext, TorchFunctionMetadataMode
] = nullcontext()
def _checkpoint_modes(self) -> List[Any]:
def _checkpoint_modes(self) -> list[Any]:
return [
self.fake_tensor_mode,
self.proxy_mode,
@ -1913,7 +1907,7 @@ class _MakefxTracer:
@contextmanager
def _init_modes_from_inputs(
self, f: Callable, args: Tuple[object, ...]
self, f: Callable, args: tuple[object, ...]
) -> Generator[None, None, None]:
prev_modes = self._checkpoint_modes()
try:
@ -2202,7 +2196,7 @@ def make_fx(
return wrapped
def get_torch_dispatch_modes() -> List[TorchDispatchMode]:
def get_torch_dispatch_modes() -> list[TorchDispatchMode]:
return torch.utils._python_dispatch._get_current_dispatch_mode_stack()
@ -2240,7 +2234,7 @@ def handle_sym_dispatch(func: Callable[_P, R], args: _P.args, kwargs: _P.kwargs)
# dispatch machinery which disables it for us
with disable_proxy_modes_tracing():
# TODO: properly compute types
types: List[Type] = []
types: list[type] = []
return mode.__sym_dispatch__(func, types, args, kwargs) # type: ignore[arg-type, return-value]
@ -2252,8 +2246,8 @@ def disable_proxy_modes_tracing() -> Generator[ProxyTorchDispatchMode, None, Non
def maybe_handle_decomp(
proxy_mode: ProxyTorchDispatchMode,
op: OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
args: tuple[object, ...],
kwargs: dict[str, object],
) -> object:
from torch._inductor.compiler_bisector import CompilerBisector
@ -2274,8 +2268,8 @@ def maybe_handle_decomp(
def get_isolated_graphmodule(
func: Callable,
args: Tuple[object, ...],
kwargs: Dict[str, object],
args: tuple[object, ...],
kwargs: dict[str, object],
tracing_mode: str = "real",
decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
) -> GraphModule:

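Every hunk in this commit applies the same mechanical rewrite: `typing.List`/`Dict`/`Tuple`/`Set`/`Type` become the subscriptable builtins of PEP 585. A minimal sketch of why the rewrite is behavior-preserving on Python 3.9+ (the function names here are invented):

```python
from typing import List, get_type_hints

def f_old(xs: List[int]) -> List[int]:  # deprecated typing alias
    return xs

def f_new(xs: list[int]) -> list[int]:  # PEP 585 builtin generic
    return xs

# Both spell the same parameterized type on Python 3.9+.
assert List[int] == list[int]
assert get_type_hints(f_new)["xs"] == list[int]
```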

@ -4,7 +4,7 @@ import inspect
import itertools
import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
import torch
import torch.utils._pytree as pytree
@ -83,11 +83,11 @@ class ShapeEnvEvent:
f: Callable
# Arguments and keyword arguments called with.
args: Optional[List[Any]] = None
kwargs: Optional[Dict[str, Any]] = None
args: Optional[list[Any]] = None
kwargs: Optional[dict[str, Any]] = None
# List of tracked_fakes at the time the method was called.
tracked_fakes: Optional[List[Any]] = None
tracked_fakes: Optional[list[Any]] = None
# Name of the captured event.
# Used for special handling of particular methods.
@ -344,15 +344,15 @@ def replay_shape_env_events(events):
# ShapeEnv.produce_guards.
@dataclass
class FakeTensorMeta:
tensor_size: Tuple[Union[int, torch.SymInt], ...]
tensor_stride: Tuple[Union[int, torch.SymInt], ...]
tensor_size: tuple[Union[int, torch.SymInt], ...]
tensor_stride: tuple[Union[int, torch.SymInt], ...]
tensor_storage_offset: Union[int, torch.SymInt]
is_nested: bool
def size(self) -> Tuple[Union[int, torch.SymInt], ...]:
def size(self) -> tuple[Union[int, torch.SymInt], ...]:
return self.tensor_size
def stride(self) -> Tuple[Union[int, torch.SymInt], ...]:
def stride(self) -> tuple[Union[int, torch.SymInt], ...]:
return self.tensor_stride
def storage_offset(self) -> Union[int, torch.SymInt]:
@ -445,7 +445,7 @@ def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value)
# compare the two values.
def compare_vars(
map_value: Callable[[str, Any], Any]
) -> List[Tuple[str, str, str]]:
) -> list[tuple[str, str, str]]:
env1_set, env2_set = set(env1_vars), set(env2_vars)
# First, compare the set of keys in each vars dictionary.
@ -489,7 +489,7 @@ class NotEqualError(Exception):
def __init__(
self,
msg: str,
mismatched: List[Tuple[str, str, str]],
mismatched: list[tuple[str, str, str]],
) -> None:
details = "\n".join(
[

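`FakeTensorMeta` in the file above only needs to duck-type the tensor metadata accessors that ShapeEnv replay calls. A simplified sketch of that protocol, with plain `int` standing in for `Union[int, torch.SymInt]`; the class and helper below are illustrative, not from this commit:

```python
from dataclasses import dataclass

@dataclass
class FakeTensorMetaSketch:
    tensor_size: tuple[int, ...]
    tensor_stride: tuple[int, ...]
    tensor_storage_offset: int
    is_nested: bool

    def size(self) -> tuple[int, ...]:
        return self.tensor_size

    def stride(self) -> tuple[int, ...]:
        return self.tensor_stride

    def storage_offset(self) -> int:
        return self.tensor_storage_offset

# Consumers that only call .size()/.stride()/.storage_offset()
# accept either a real tensor or this stand-in.
def numel(t) -> int:
    n = 1
    for d in t.size():
        n *= d
    return n

assert numel(FakeTensorMetaSketch((2, 3), (3, 1), 0, False)) == 6
```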

@ -6,7 +6,7 @@ import functools
import inspect
import textwrap
from types import FunctionType
from typing import Any, Callable, cast, Dict, Optional, Union
from typing import Any, Callable, cast, Optional, Union
import torch
from torch._sources import normalize_source_lines
@ -112,7 +112,7 @@ class RewritingTracer(Tracer):
def trace(
self,
root: Union[torch.nn.Module, Callable],
concrete_args: Optional[Dict[str, Any]] = None,
concrete_args: Optional[dict[str, Any]] = None,
) -> Graph:
return super().trace(_rewrite(root), concrete_args)


@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import inspect
from typing import Any, Dict, Optional, Tuple
from typing import Any, Optional
import torch
import torch.fx
@ -42,7 +42,7 @@ class AnnotateTypesWithSchema(Transformer):
self.annotate_get_attrs = annotate_get_attrs
def call_function(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
):
python_ret_type = None
if self.annotate_functionals and target.__module__ == "torch.nn.functional":
@ -73,7 +73,7 @@ class AnnotateTypesWithSchema(Transformer):
return return_proxy
def call_module(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
):
python_ret_type = None
assert isinstance(target, str)
@ -91,8 +91,8 @@ class AnnotateTypesWithSchema(Transformer):
def get_attr(
self,
target: torch.fx.node.Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Any],
args: tuple[Argument, ...],
kwargs: dict[str, Any],
):
attr_proxy = super().get_attr(target, args, kwargs)


@ -1,5 +1,6 @@
import re
from typing import Any, DefaultDict, Dict, List, Tuple, Union
from collections import defaultdict
from typing import Any, Union
import numpy as np
import sympy as sp
@ -13,10 +14,10 @@ s_pattern = r"s\d+"
def infer_symbol_values(
symints: List[Union[torch.SymInt, int]],
init_symints: List[Union[torch.SymInt, int]],
symbol_idx_dict: Dict[str, int],
padding_constraints: DefaultDict[torch.SymInt, List[Union[sp.Expr, int]]],
symints: list[Union[torch.SymInt, int]],
init_symints: list[Union[torch.SymInt, int]],
symbol_idx_dict: dict[str, int],
padding_constraints: defaultdict[torch.SymInt, list[Union[sp.Expr, int]]],
constraint: str,
) -> None:
if constraint.find("non-singleton") != -1:
@ -83,8 +84,8 @@ def infer_symbol_values(
def calculate_value(
left_expression: Union[str, Any, None],
right_expression: Union[str, Any, None],
symints: List[Union[torch.SymInt, int]],
symbol_idx_dict: Dict[str, int],
symints: list[Union[torch.SymInt, int]],
symbol_idx_dict: dict[str, int],
) -> None:
var, val = solve_equation(left_expression, right_expression)
idx = symbol_idx_dict[var]
@ -95,7 +96,7 @@ def calculate_value(
def solve_equation(
left_expression: Union[str, Any, None],
right_expression: Union[str, Any, None],
) -> Tuple[str, int]:
) -> tuple[str, int]:
expression = f"{left_expression} - {right_expression}"
var = re.findall(s_pattern, expression)[0]
if re.findall(parentheses_pattern, expression):
@ -116,9 +117,9 @@ def solve_equation(
def update_equation(
symints: List[Union[torch.SymInt, int]],
init_symints: List[Union[torch.SymInt, int]],
padding_constraints: DefaultDict[torch.SymInt, List[Union[sp.Expr, int]]],
symints: list[Union[torch.SymInt, int]],
init_symints: list[Union[torch.SymInt, int]],
padding_constraints: defaultdict[torch.SymInt, list[Union[sp.Expr, int]]],
init_eq: sp.Expr,
new_mod_num: int,
var: torch.SymInt,

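The file above also swaps `typing.DefaultDict` for `collections.defaultdict`, which PEP 585 makes directly subscriptable. A quick sketch, with keys simplified to `str` instead of `torch.SymInt`:

```python
from collections import defaultdict

# Since Python 3.9 the concrete class is usable as a generic annotation:
padding_constraints: defaultdict[str, list[int]] = defaultdict(list)
padding_constraints["s0"].append(64)
assert padding_constraints["s0"] == [64]
assert padding_constraints["s1"] == []  # missing keys default to []
```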

@ -20,7 +20,7 @@ import math
import operator
import sys
from functools import lru_cache, update_wrapper
from typing import Optional, Type, TYPE_CHECKING, Union
from typing import Optional, TYPE_CHECKING, Union
import torch
@ -1272,7 +1272,7 @@ def _make_node_magic(method, func):
log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr)
raise
sym_node_log.debug("%s %s %s -> %s", method, self.expr, other.expr, out)
pytype: Type
pytype: type
# This is not strictly correct. In Python, a**b may return complex when
# a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This
# returns a float while both arguments are ints: 2**(-1). Also, max and
@ -1335,7 +1335,7 @@ def _make_node_magic(method, func):
out_hint = None
if self.hint is not None:
out_hint = op(self.hint)
pytype: Type
pytype: type
if method in always_int_magic_methods:
pytype = int
elif method in always_bool_magic_methods:
@ -1485,7 +1485,7 @@ def _make_node_sizes_strides(method, func):
out_hint = op(size_hints, stride_hints)
# NB: This is the indicator function, not the actual bool!
pytype: Type
pytype: type
if method.endswith("_indicator"):
pytype = int
else:

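The `pytype: Type` annotations in the sym_node hunks above become the plain builtin `type`, which is itself a valid annotation. A hedged sketch of the dispatch pattern those annotations describe, where the method sets are invented stand-ins for `always_int_magic_methods` and friends:

```python
def result_pytype(method: str) -> type:
    always_int_magic_methods = {"floor", "ceil"}   # illustrative stand-ins
    always_bool_magic_methods = {"eq", "ne", "lt"}
    if method in always_int_magic_methods:
        return int
    if method in always_bool_magic_methods:
        return bool
    return float

assert result_pytype("floor") is int
assert result_pytype("eq") is bool
assert result_pytype("truediv") is float
```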

@ -23,7 +23,8 @@ import re
import sys
import threading
import traceback
from collections import defaultdict
from collections import Counter, defaultdict
from collections.abc import Iterator, Mapping, Sequence
from contextlib import _GeneratorContextManager, contextmanager
from dataclasses import dataclass, field
from enum import Enum
@ -31,19 +32,9 @@ from typing import (
Any,
Callable,
cast,
Counter,
DefaultDict,
Dict,
Iterator,
List,
Mapping,
NamedTuple,
NoReturn,
Optional,
Sequence,
Set,
Tuple,
Type,
TYPE_CHECKING,
TypeVar,
Union,
@ -104,8 +95,8 @@ if TYPE_CHECKING:
from torch.types import BoolLikeType
InputList = List
DimList = List
InputList = list
DimList = list
log = logging.getLogger(__name__)
@ -236,8 +227,8 @@ class SymIntEqByExpr:
def _nested_int_aware_sort(
tup: Tuple[Union[SymInt, int], int]
) -> Tuple[int, Union[SymInt, int], int]:
tup: tuple[Union[SymInt, int], int]
) -> tuple[int, Union[SymInt, int], int]:
return (
# Order nested ints by their coefficients.
# 1 here to order nested ints after non-nested-ints.
@ -289,7 +280,7 @@ def lru_cache(
# These are modules that contain generic code for interacting with ShapeEnv
# which are unlikely to identify a particular interesting guard statement
@lru_cache(None)
def uninteresting_files() -> Set[str]:
def uninteresting_files() -> set[str]:
import torch._compile
import torch._dynamo.eval_frame
import torch._inductor.sizevars
@ -332,8 +323,8 @@ def has_symbolic_sizes_strides(elem: torch.Tensor) -> bool:
Int: TypeAlias = Union[torch.SymInt, int]
def create_contiguous(shape: Sequence[Int]) -> List[Int]:
strides: List[Int] = [1]
def create_contiguous(shape: Sequence[Int]) -> list[Int]:
strides: list[Int] = [1]
for dim in reversed(shape[:-1]):
strides.append(dim * strides[-1]) # type: ignore[operator]
return list(reversed(strides))
@ -461,15 +452,15 @@ def check_consistent(new: _T, old: _T) -> None:
def resolve_unbacked_bindings(
shape_env: Optional[ShapeEnv],
bindings: Optional[Dict[sympy.Symbol, pytree.KeyPath]],
) -> Optional[Dict[sympy.Symbol, pytree.KeyPath]]:
bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]],
) -> Optional[dict[sympy.Symbol, pytree.KeyPath]]:
if bindings is None:
return None
assert shape_env is not None
return {shape_env.unbacked_renamings.get(k, k): v for k, v in bindings.items()}
Result: TypeAlias = Union[torch.Tensor, Tuple[torch.Tensor, ...]]
Result: TypeAlias = Union[torch.Tensor, tuple[torch.Tensor, ...]]
def rebind_unbacked(
@ -557,7 +548,7 @@ def rebind_unbacked(
and len(raw_u1.args) == 2
and (
raw_u1_args0 := cast(
Tuple[sympy.Basic, sympy.Basic], raw_u1.args[0]
tuple[sympy.Basic, sympy.Basic], raw_u1.args[0]
)
)
and raw_u1_args0[0] == 1
@ -565,7 +556,7 @@ def rebind_unbacked(
and isinstance(new_raw_u1 := eq.lhs, sympy.Symbol)
and shape_env.var_to_range[new_raw_u1].issubset(ValueRanges(0, 1))
and eq.rhs == 1
and cast(Tuple[sympy.Basic, sympy.Basic], raw_u1.args[1]) == (0, True)
and cast(tuple[sympy.Basic, sympy.Basic], raw_u1.args[1]) == (0, True)
):
# This is what the pattern match above is testing
repacked = _sympy_cast_symbool_to_symint_guardless(
@ -645,8 +636,8 @@ def canonicalize_bool_expr(expr: _T) -> _T:
def _sympy_from_args(
cls: Union[Type[sympy.Add], Type[sympy.Mul]],
args: List[sympy.Expr],
cls: type[Union[sympy.Add, sympy.Mul]],
args: list[sympy.Expr],
sort: bool = True,
is_commutative: Optional[bool] = None,
) -> sympy.Expr:
@ -686,7 +677,7 @@ def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean:
return type(expr)(*map(canonicalize_bool_expr, expr.args))
opposite = {sympy.Gt: sympy.Lt, sympy.Ge: sympy.Le}
t: Union[Type[Any]]
t: Union[type[Any]]
if isinstance(expr, tuple(opposite.keys())):
rhs = expr.lhs - expr.rhs # type: ignore[attr-defined]
t = opposite[type(expr)] # type: ignore[index]
@ -888,7 +879,7 @@ def is_symbol_binding_fx_node(node: torch.fx.Node) -> Optional[sympy.Symbol]:
def find_symbol_binding_fx_nodes(
graph: torch.fx.Graph,
) -> Dict[sympy.Symbol, torch.fx.Node]:
) -> dict[sympy.Symbol, torch.fx.Node]:
r = {}
# NB: Prefer first occurrence of symbol
for node in graph.nodes:
@ -949,7 +940,7 @@ def compute_unbacked_bindings(
example_value: object,
old_example_value: Optional[object] = None,
peek: bool = False,
) -> Optional[Dict[sympy.Symbol, pytree.KeyPath]]:
) -> Optional[dict[sympy.Symbol, pytree.KeyPath]]:
"""
After having run fake tensor propagation and producing example_value
result, traverse example_value looking for freshly bound unbacked
@ -977,7 +968,7 @@ def compute_unbacked_bindings(
def free_unbacked_symbols_with_path(
a: object, path: pytree.KeyPath, real: Optional[object] = None
) -> Dict[sympy.Symbol, pytree.KeyPath]:
) -> dict[sympy.Symbol, pytree.KeyPath]:
assert shape_env is not None
r = {}
if isinstance(a, (tuple, list)):
@ -1456,11 +1447,11 @@ def guard_float(a: Union[SymFloat, float]) -> float:
# Given a GraphModule, return all the FakeTensors for all the placeholders
def fx_placeholder_vals(gm: torch.fx.GraphModule) -> List[object]:
def fx_placeholder_vals(gm: torch.fx.GraphModule) -> list[object]:
return [n.meta["val"] for n in gm.graph.nodes if n.op == "placeholder"]
def fx_placeholder_targets(gm: torch.fx.GraphModule) -> List[str]:
def fx_placeholder_targets(gm: torch.fx.GraphModule) -> list[str]:
return [n.target for n in gm.graph.nodes if n.op == "placeholder"]
@ -1475,7 +1466,7 @@ def eval_guards(
)
def bind_symbols(gm: torch.fx.GraphModule, *args: Tensor) -> Dict[sympy.Symbol, int]:
def bind_symbols(gm: torch.fx.GraphModule, *args: Tensor) -> dict[sympy.Symbol, int]:
return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args) # type: ignore[operator, union-attr]
@ -1617,15 +1608,15 @@ class EqualityConstraint(Constraint):
form and so the problem reduces to symbolic expression equality.)
"""
source_pairs: List[Tuple[Source, Source]]
derived_equalities: List[
Tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]]
source_pairs: list[tuple[Source, Source]]
derived_equalities: list[
tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]]
]
phantom_symbols: List[sympy.Symbol]
relaxed_sources: Set[Source]
phantom_symbols: list[sympy.Symbol]
relaxed_sources: set[Source]
_parents: Dict[Source, Source] = field(init=False)
_defs: Dict[Source, sympy.Expr] = field(init=False)
_parents: dict[Source, Source] = field(init=False)
_defs: dict[Source, sympy.Expr] = field(init=False)
def __post_init__(self) -> None:
"""
@ -1643,12 +1634,12 @@ class EqualityConstraint(Constraint):
# self._parents is a map from input sources to input sources where, conceptually,
# these are directed edges in a union-find forest
_parents: Dict[Source, Source] = {}
_parents: dict[Source, Source] = {}
object.__setattr__(self, "_parents", _parents)
# self._defs is a map from input sources to "canonical" symbolic expressions,
# i.e., unary expressions with symbols that corresponds to regular Dims (i.e.,
# not derived Dims)
_defs: Dict[Source, sympy.Expr] = {}
_defs: dict[Source, sympy.Expr] = {}
object.__setattr__(self, "_defs", _defs)
for source1, source2 in self.source_pairs:
@ -1838,7 +1829,7 @@ class StatefulSymbolicContext(StatelessSymbolicContext):
# cause it to fail with unknown symbols, as the symbols cached here will skip creation, and never
# get recorded in var_to_val, etc.
# TODO(voz): consider a weakref to the shape_env here
shape_env_to_source_to_symbol_cache: Dict[int, Dict[str, sympy.Expr]] = None # type: ignore[assignment]
shape_env_to_source_to_symbol_cache: dict[int, dict[str, sympy.Expr]] = None # type: ignore[assignment]
def __post_init__(self) -> None:
super().__post_init__()
@ -1856,7 +1847,7 @@ class SubclassSymbolicContext(StatefulSymbolicContext):
flexibility, with inner symbolic contexts mapped via attr -> symbolic context.
"""
inner_contexts: Dict[str, SymbolicContext] = None # type: ignore[assignment]
inner_contexts: dict[str, SymbolicContext] = None # type: ignore[assignment]
def __post_init__(self) -> None:
super().__post_init__()
@ -1875,7 +1866,7 @@ def is_symbolic(
IndicatorTypes = (IsNonOverlappingAndDenseIndicator,)
def _expandsums(args: List[sympy.Expr]) -> Tuple[sympy.Expr, bool]:
def _expandsums(args: list[sympy.Expr]) -> tuple[sympy.Expr, bool]:
adds, other = [], []
for arg in args:
if arg.is_Add:
@ -1912,8 +1903,8 @@ def _fast_expand(expr: _SympyT) -> _SympyT:
elif exp < 0:
return S.One / sympy.expand_multinomial(S.One / expr, deep=False)
elif expr.is_Mul:
num: List[sympy.Expr] = []
den: List[sympy.Expr] = []
num: list[sympy.Expr] = []
den: list[sympy.Expr] = []
for arg in expr.args:
if arg.is_Pow and arg.args[1] == -1:
den.append(S.One / arg) # type: ignore[operator, arg-type]
@ -1961,7 +1952,7 @@ class _SymbolInfo(NamedTuple):
def _maybe_evaluate_static_worker(
expr: _SympyT,
# NB: this is a tuple to ensure it can be LRU cached
symbol_info: Tuple[_SymbolInfo, ...],
symbol_info: tuple[_SymbolInfo, ...],
unbacked_only: bool,
size_oblivious: bool,
) -> Optional[_SympyT]:
@ -2193,9 +2184,9 @@ class SymExprPrinter(PythonPrinter):
class _ShapeGuardPrinter(abc.ABC):
def __init__(
self,
symbol_to_source: Mapping[sympy.Symbol, List[Source]],
symbol_to_source: Mapping[sympy.Symbol, list[Source]],
source_ref: Callable[[Source], str],
var_to_sources: Mapping[sympy.Symbol, List[Source]],
var_to_sources: Mapping[sympy.Symbol, list[Source]],
) -> None:
self.symbol_to_source = symbol_to_source
self.source_ref = source_ref
@ -2246,7 +2237,7 @@ class ShapeGuardPrinter(ShapeGuardPythonPrinter):
class LoggingShapeGuardPrinter(ShapeGuardPythonPrinter):
def __init__(self, var_to_sources: Mapping[sympy.Symbol, List[Source]]):
def __init__(self, var_to_sources: Mapping[sympy.Symbol, list[Source]]):
super().__init__(var_to_sources, lambda n: n.name(), var_to_sources)
@ -2261,7 +2252,7 @@ class DynamicDimConstraintPrinter(PythonPrinter):
def __init__(
self,
symbol_to_source: Dict[sympy.Symbol, List[Source]],
symbol_to_source: dict[sympy.Symbol, list[Source]],
source_name_to_debug_name: Mapping[str, str],
):
super().__init__()
@ -2284,23 +2275,23 @@ class DimConstraints:
def __init__(
self,
symbol_to_source: Dict[sympy.Symbol, List[Source]],
symbol_to_source: dict[sympy.Symbol, list[Source]],
var_to_val: Mapping[sympy.Symbol, sympy.Integer],
marked_dynamic: Set[sympy.Symbol],
marked_dynamic: set[sympy.Symbol],
source_name_to_debug_name: Mapping[str, str],
) -> None:
# We try to solve systems of inequalities with 1 free variable.
self._univariate_inequalities: Dict[
sympy.Symbol, Set[SympyBoolean]
self._univariate_inequalities: dict[
sympy.Symbol, set[SympyBoolean]
] = defaultdict(set)
# Among them, we prioritize solving for a free variable that has equalities.
# NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys()
# and removing a symbol from the former => removing it from the latter.
self._symbols_with_equalities: Set[sympy.Symbol] = set()
self._symbols_with_equalities: set[sympy.Symbol] = set()
# A solution of a free variable with equalities becomes a substitution.
# We use these substitutions to simplify other constraints.
# NOTE: removing a symbol from _symbols_with_equalities => adding it to _substitutions.
self._substitutions: Dict[sympy.Symbol, sympy.Integer] = {}
self._substitutions: dict[sympy.Symbol, sympy.Integer] = {}
# In general, constraints may have // and % operations.
# Of course, // can be expressed in terms of / and %.
@ -2308,20 +2299,20 @@ class DimConstraints:
# We do so by using the values of variables as hints to evaluate %.
# For soundness we record additional congruence guards and solve them separately.
self._var_to_val: Mapping[sympy.Symbol, sympy.Integer] = var_to_val
self._congruences: DefaultDict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set)
self._congruences: defaultdict[sympy.Symbol, set[sympy.Expr]] = defaultdict(set)
# We do not try to (directly) solve inequalities with > 1 free variables.
# NOTE: free variables in these inequalities cannot also be in _substitutions.
self._multivariate_inequalities: Set[SympyBoolean] = set()
self._multivariate_inequalities: set[SympyBoolean] = set()
# We park external equalities between free variables here.
self._symbolic_equivalences: List[Tuple[Source, sympy.Expr]] = []
self._symbolic_equivalences: list[tuple[Source, sympy.Expr]] = []
# Solutions come in two forms:
# - (static) specializations
# - (dynamic) inequalities / congruences
self._static_results: Set[str] = set()
self._dynamic_results: Set[str] = set()
self._static_results: set[str] = set()
self._dynamic_results: set[str] = set()
# printer for solutions
self._dcp = DynamicDimConstraintPrinter(
@ -2329,13 +2320,13 @@ class DimConstraints:
)
# inconsistencies found on substituting with concrete values / static solutions
self._inconsistencies: List[str] = []
self._inconsistencies: list[str] = []
# symbols that are marked dynamic
self._marked_dynamic = marked_dynamic
# track supported sympy functions and subtract from list of all sympy functions
self._supported_sympy_functions: Set[sympy.Function] = {
self._supported_sympy_functions: set[sympy.Function] = {
Application,
Mod,
PythonMod,
@ -2488,8 +2479,8 @@ class DimConstraints:
# these will resolve to either specializations or dynamic equality constraints
self._symbolic_equivalences.append((source, expr))
def _reduce_congruences(self) -> Dict[sympy.Symbol, Set[sympy.Expr]]:
reduced_congruences: Dict[sympy.Symbol, Set[sympy.Expr]] = {}
def _reduce_congruences(self) -> dict[sympy.Symbol, set[sympy.Expr]]:
reduced_congruences: dict[sympy.Symbol, set[sympy.Expr]] = {}
for s, congruences in self._congruences.items():
remainder_modulus_pairs = []
congruences_to_check = set()
@ -2650,7 +2641,7 @@ class DimConstraints:
cond = cond and isinstance(divisor, sympy.Integer)
return cond
def forced_specializations(self) -> Dict[str, sympy.Expr]:
def forced_specializations(self) -> dict[str, sympy.Expr]:
"""Returns a dictionary of the names of symbols to their specialized value"""
def debug_name(src: Source) -> str:
@ -2678,8 +2669,8 @@ class DimConstraints:
def _process_derived_dim_roots(
self,
results: Dict[str, Dict[str, Any]],
name_to_dim: Dict[str, Any],
results: dict[str, dict[str, Any]],
name_to_dim: dict[str, Any],
) -> None:
"""
Here we resolve 2 concerns with derived dims suggested fixes: 1) newly introduced roots,
@ -2745,7 +2736,7 @@ class DimConstraints:
# {"dx": {"eq": 3*_dx+1, "min": 4, "max": 10}, "dy": dx+1, "dz": dx+2}
# we want instead:
# {"_dx": {"min": 1, "max": 4}, "dx": 3*_dx+1, "dy": 3*_dx+2, "dz": 3*_dx+3}
introduced_roots: Dict[str, str] = {} # map new root -> old root
introduced_roots: dict[str, str] = {} # map new root -> old root
for k, c in list(results.items()):
if "eq" in c and isinstance(c["eq"], sympy.Expr): # derived dim
root = next(iter(c["eq"].free_symbols))
@ -2782,7 +2773,7 @@ class DimConstraints:
# this consists of:
# 1) {"dx": {"min": ..., "max": ...}} -> dx: refined root dim
# 2) {"dy": "dx + 1"} -> dx: root for suggested fix
modified_roots: Set[str] = set()
modified_roots: set[str] = set()
for k, c in results.items():
if k not in name_to_dim: # _dynamo.export() may handle source directly
continue
@ -2799,7 +2790,7 @@ class DimConstraints:
# evaluate the new value for each root
# this is now either 1) unchanged, 2) refined with a new range,
# or 3) specialized to a concrete value
modified_root_values: Dict[str, Dict[str, Any]] = {}
modified_root_values: dict[str, dict[str, Any]] = {}
for mroot in modified_roots:
swapped_root = True
if mroot in results:
@ -2860,9 +2851,9 @@ class DimConstraints:
def prettify_results(
self,
original_signature: inspect.Signature,
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]],
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any]],
constraint_violation_error: object,
forced_specializations: Dict[str, str],
forced_specializations: dict[str, str],
) -> str:
"""Format a message for constraint violation erros"""
from torch.export.dynamic_shapes import _get_dim_name_mapping
@ -2876,7 +2867,7 @@ class DimConstraints:
s = s.replace(k, v) if not inverse else s.replace(v, k)
return s
results: DefaultDict[str, Dict[str, Any]] = defaultdict(dict)
results: defaultdict[str, dict[str, Any]] = defaultdict(dict)
if dynamic_shapes is None:
dynamic_shapes = {}
@ -3050,7 +3041,7 @@ class ShapeEnv:
self,
*,
should_record_events: Optional[bool] = None,
tracked_fakes: Optional[List[Any]] = None,
tracked_fakes: Optional[list[Any]] = None,
**kwargs: Any,
) -> None:
self._init(**kwargs)
@ -3086,7 +3077,7 @@ class ShapeEnv:
# Keep track of the list of tracked fakes.
self.tracked_fakes = tracked_fakes
# List of events for reconstructing ShapeEnv at arbitrary points in time.
self.events: List[ShapeEnvEvent] = (
self.events: list[ShapeEnvEvent] = (
[ShapeEnvEvent(ShapeEnv, kwargs=kwargs)]
if self.should_record_events
else []
@ -3099,7 +3090,7 @@ class ShapeEnv:
# NOTE: It's important that SymNodes in this cache have their ShapeEnv
# stripped otherwise you end up with cycles which can only be cleaned
# with the GC.
self.fake_tensor_cache: Dict[
self.fake_tensor_cache: dict[
torch._subclasses.fake_tensor._DispatchCacheKey,
torch._subclasses.fake_tensor._DispatchCacheEntry,
] = {}
@ -3134,7 +3125,7 @@ class ShapeEnv:
# symbolically equal.
duck_shape: Optional[bool] = None,
# For debugging
co_fields: Optional[Dict[str, str]] = None,
co_fields: Optional[dict[str, str]] = None,
# When True, whenever safe, we will generate a deferred runtime assert
# instead of a guard whenever we know that an expression must be True,
# otherwise it would be an error, even for backed SymInts (where we
@ -3165,50 +3156,50 @@ class ShapeEnv:
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
)
self.guards: List[ShapeGuard] = []
self.axioms: Dict[sympy.Expr, sympy.Expr] = {}
self.guards: list[ShapeGuard] = []
self.axioms: dict[sympy.Expr, sympy.Expr] = {}
# Maps symbolic ints to their original concrete values
# Currently populated from tensors
self.var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
self.var_to_val: dict[sympy.Symbol, sympy.Integer] = {}
# Like var_to_val, but only set when propagate_real_tensors is on.
# Used as last resort to avoid GuardOnDataDependent error
self.unbacked_var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
self.unbacked_var_to_val: dict[sympy.Symbol, sympy.Integer] = {}
# Like above, but used exclusively for OBLIVIOUS_SIZE. These
# potentially could be put together but I am not sure, writing out
# the logic individually before abstracting.
self.oblivious_var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
self.oblivious_var_to_val: dict[sympy.Symbol, sympy.Integer] = {}
# Maps symbolic ints to their min/max range. These ranges
# are conservative: the int MUST fall in the range, but the
# range may contain ints which may not actually appear in
# practice
self.var_to_range: Dict[sympy.Symbol, ValueRanges] = {}
self.var_to_range_sloc: Dict[sympy.Symbol, ValueRangesSLoc] = {}
self.var_to_range: dict[sympy.Symbol, ValueRanges] = {}
self.var_to_range_sloc: dict[sympy.Symbol, ValueRangesSLoc] = {}
# When doing a size-oblivious test, exclude this integer and
# everything higher than it from the acceptable range. This solves
# https://github.com/pytorch/pytorch/issues/120288 for constant range
# case
# TODO: generalize this to work with expressions (in that case, we
# need to maintain a SET and we need extra symbolic reasoning on top)
self.oblivious_upper_bound_exclusive: Dict[sympy.Symbol, sympy.Integer] = {}
self.source_name_to_debug_name: Dict[str, str] = {}
self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {}
self.var_to_stack: Dict[sympy.Symbol, CapturedTraceback] = {}
self.oblivious_upper_bound_exclusive: dict[sympy.Symbol, sympy.Integer] = {}
self.source_name_to_debug_name: dict[str, str] = {}
self.var_to_sources: dict[sympy.Symbol, list[Source]] = {}
self.var_to_stack: dict[sympy.Symbol, CapturedTraceback] = {}
# Maps a source to the *original* symbol that was assigned to it
self.source_to_var: Dict[str, sympy.Symbol] = {}
self.source_to_var: dict[str, sympy.Symbol] = {}
# Maps from sympy ints to expressions representing them
# Populated from equality guards (i.e. a.shape[0] == b.shape[0])
self.replacements: Dict[sympy.Symbol, sympy.Expr] = {}
self.replacements: dict[sympy.Symbol, sympy.Expr] = {}
# The sloc of the guard that triggered this replacement to be added
self.replacements_slocs: Dict[sympy.Symbol, SLoc] = {}
self.unbacked_renamings: Dict[sympy.Symbol, sympy.Symbol] = {}
self.replacements_slocs: dict[sympy.Symbol, SLoc] = {}
self.unbacked_renamings: dict[sympy.Symbol, sympy.Symbol] = {}
# Set holds a % b expressions that evaluate to 0.
self.divisible: Set[sympy.Expr] = set()
self.divisible: set[sympy.Expr] = set()
# Set that holds "size-like" symbols. When we perform
# "size-oblivious" tests, these can be assumed to be >= 2.
self.size_like: Set[sympy.Symbol] = set()
self.size_like: set[sympy.Symbol] = set()
# Duck-shaping says that if two input tensors have the same size,
# they get assigned the same symbolic variable
self.val_to_var: Dict[int, sympy.Symbol] = {}
self.val_to_var: dict[int, sympy.Symbol] = {}
if specialize_zero_one:
self.val_to_var = {0: sympy.S.Zero, 1: sympy.S.One}
self.unbacked_symfloat_counter = itertools.count()
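A minimal sketch of the duck-shaping rule that `val_to_var` implements, per the comment above: equal input sizes share one symbol, and 0/1 stay specialized when `specialize_zero_one` is set. The cache policy here is simplified relative to ShapeEnv's:

```python
import itertools
import sympy

val_to_var: dict[int, sympy.Expr] = {0: sympy.S.Zero, 1: sympy.S.One}
counter = itertools.count()

def symbol_for_size(val: int) -> sympy.Expr:
    # Duck shaping: the same concrete size always maps to the same symbol.
    if val not in val_to_var:
        val_to_var[val] = sympy.Symbol(f"s{next(counter)}", integer=True, positive=True)
    return val_to_var[val]

assert symbol_for_size(64) is symbol_for_size(64)  # shared symbol
assert symbol_for_size(1) == sympy.S.One           # 0/1 remain specialized
```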
@ -3241,8 +3232,8 @@ class ShapeEnv:
# to the next unbacked symbol to wait on, but if we choose the
# latest key, an assert will only show up at the moment when
# we can actually codegen it.
self.deferred_runtime_asserts: Dict[
Optional[sympy.Symbol], List[RuntimeAssert]
self.deferred_runtime_asserts: dict[
Optional[sympy.Symbol], list[RuntimeAssert]
] = {}
# This exists so we can efficiently invalidate the cache (it's used as
# part of the cache key); otherwise we'd have to iterate through
@ -3279,7 +3270,7 @@ class ShapeEnv:
#
# NB: fresh unbacked symbols NEVER get substitutions applied to them,
# they are binding sites!
self.pending_fresh_unbacked_symbols: List[sympy.Symbol] = []
self.pending_fresh_unbacked_symbols: list[sympy.Symbol] = []
# Version counter used to invalidate cached values
self._prev_cache_key = self._get_key()
@ -3294,8 +3285,8 @@ class ShapeEnv:
# 2. list of arguments
# This drastically reduces the size of the FX graph, avoiding
# duplicated nodes.
self.fx_node_cache: Dict[Tuple[Callable, Tuple[Any, ...]], torch.fx.Node] = {}
self.source_to_symbol: Dict[str, sympy.Symbol] = {}
self.fx_node_cache: dict[tuple[Callable, tuple[Any, ...]], torch.fx.Node] = {}
self.source_to_symbol: dict[str, sympy.Symbol] = {}
# Suppose you want to replace an unbacked symbol with another
# unbacked symbol. This is error prone because you can cause
@ -3322,7 +3313,7 @@ class ShapeEnv:
# bindings. At the moment, this is not tracked, but we potentially
# could track this at the IR level using a higher order operator
# with something like effect token tracking.
self.unbacked_alloc_order: Dict[sympy.Symbol, int] = {}
self.unbacked_alloc_order: dict[sympy.Symbol, int] = {}
from torch.fx.experimental.validator import translation_validation_enabled
@ -3345,7 +3336,7 @@ class ShapeEnv:
# Whenever you add a node to self.graph, you must add a mapping to this
# variable. Otherwise, the built FX graph on the replayed ShapeEnv will
# not be valid.
self.name_to_node: Dict[str, torch.fx.Node] = {}
self.name_to_node: dict[str, torch.fx.Node] = {}
@property
def allow_scalar_outputs(self) -> bool:
@ -3439,7 +3430,7 @@ class ShapeEnv:
shape_env_check_state_equal(self, other, non_state_variable_names, map_value)
def _snapshot_tracked_fakes(self) -> Optional[List[Any]]:
def _snapshot_tracked_fakes(self) -> Optional[list[Any]]:
if self.tracked_fakes is None:
return None
@ -3631,7 +3622,7 @@ class ShapeEnv:
self.source_to_symbol[srcname] = sympy.Symbol(srcname, integer=True)
return self.source_to_symbol[srcname]
def _add_z3var(self, symbol: sympy.Symbol, type: Type) -> None:
def _add_z3var(self, symbol: sympy.Symbol, type: type) -> None:
if self._translation_validation_enabled:
self.validator.add_var(symbol, type)
@ -3651,8 +3642,8 @@ class ShapeEnv:
def _create_fx_call_function(
self,
op: Callable,
args: Tuple,
) -> Tuple[Optional[torch.fx.Node], bool]:
args: tuple,
) -> tuple[Optional[torch.fx.Node], bool]:
# Cache this tuple in order to avoid duplicated nodes.
node_key = (op, args)
# Flags whether the returned node was cached or not.
@ -3681,7 +3672,7 @@ class ShapeEnv:
def _create_fx_placeholder_and_z3var(
self,
symbol: sympy.Symbol,
type: Type,
type: type,
) -> Optional[torch.fx.Node]:
if not self._translation_validation_enabled:
return None
@ -3742,7 +3733,7 @@ class ShapeEnv:
"""Context manager to ignore all guards generated inside"""
return _suppress_guards(self)
def _get_key(self) -> Tuple[int, int, int, int]:
def _get_key(self) -> tuple[int, int, int, int]:
"""
Defines the current "state" of the guards we've accumulated in this ShapeEnv.
Determines when we need to invalidate our cache
@ -3778,7 +3769,7 @@ class ShapeEnv:
ex_size: Sequence[Union[int, SymInt]],
source: Source,
symbolic_context: SymbolicContext,
) -> List[sympy.Expr]:
) -> list[sympy.Expr]:
return self._produce_dyn_sizes_from_int_tuple(
tuple(ex_size), source, symbolic_context
)
@ -3788,7 +3779,7 @@ class ShapeEnv:
tensor_size: Sequence[Union[int, SymInt]],
source: Source,
symbolic_context: SymbolicContext,
) -> List[sympy.Expr]:
) -> list[sympy.Expr]:
assert all(
not is_symbolic(val) for val in tensor_size
), f"Expect size to be a plain tuple of ints but got {tensor_size}"
@ -3816,9 +3807,9 @@ class ShapeEnv:
source: Source,
*,
symbolic_context: Optional[SymbolicContext] = None,
) -> Tuple[
Tuple[Union[int, SymInt], ...],
Tuple[Union[int, SymInt], ...],
) -> tuple[
tuple[Union[int, SymInt], ...],
tuple[Union[int, SymInt], ...],
Union[int, SymInt],
]:
"""
@ -3903,17 +3894,17 @@ class ShapeEnv:
source: Source,
*,
symbolic_context: Optional[SymbolicContext] = None,
) -> Tuple[
Tuple[Union[int, SymInt], ...],
Tuple[Union[int, SymInt], ...],
) -> tuple[
tuple[Union[int, SymInt], ...],
tuple[Union[int, SymInt], ...],
Union[int, SymInt],
]:
dim = len(ex_size)
# Reimplement the legacy behavior
if symbolic_context is None:
constraint_sizes: List[DimConstraint] = [None] * dim
constraint_strides: List[DimConstraint] = [None] * dim
constraint_sizes: list[DimConstraint] = [None] * dim
constraint_strides: list[DimConstraint] = [None] * dim
dynamic_dims = []
dynamic_strides = []
for i in range(dim):
@ -3963,7 +3954,7 @@ class ShapeEnv:
from torch._dynamo.source import TensorProperty, TensorPropertySource
size: List[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(
size: list[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(
ex_size, source, symbolic_context
)
stride = self._compute_symbolic_stride(
@ -4022,11 +4013,11 @@ class ShapeEnv:
],
are_sizes_static: bool,
symbolic_context: SymbolicContext,
) -> List[sympy.Expr]:
) -> list[sympy.Expr]:
from torch._dynamo.source import TensorProperty, TensorPropertySource
stride: List[Optional[sympy.Expr]] = [None] * len(size)
candidates: Dict[Union[int, SymInt], sympy.Expr] = {}
stride: list[Optional[sympy.Expr]] = [None] * len(size)
candidates: dict[Union[int, SymInt], sympy.Expr] = {}
# iterate over unbound strides in val ascending order with
# index descending as a tie breaker since for cases like
@ -4590,7 +4581,7 @@ class ShapeEnv:
return c_render
return c.render(source)
def produce_guards(self, *args: Any, **kwargs: Any) -> List[str]:
def produce_guards(self, *args: Any, **kwargs: Any) -> list[str]:
"""
Like produce_guards_verbose, but only returns the non-verbose guard expressions
(no verbose guards produced.)
@ -4603,7 +4594,7 @@ class ShapeEnv:
sources: Sequence[Source],
source_ref: Callable[[Source], str] = lambda n: n.name(),
*,
guards: Optional[List[ShapeGuard]] = None,
guards: Optional[list[ShapeGuard]] = None,
input_contexts: Optional[DimList[SymbolicContext]] = None,
# Encodes user-specified input shape equations of the form s = s' and s = fn(s').
# (See docs on EqualityConstraint for details of the encoding.)
@ -4611,7 +4602,7 @@ class ShapeEnv:
_simplified: bool = False,
# Indicates if we should produce guards for known static values.
ignore_static: bool = True,
) -> Tuple[List[str], List[str]]: # python, verbose
) -> tuple[list[str], list[str]]: # python, verbose
"""
Generates a list of guard strings which, when evaluated in a context that
defines tensors for all the sources, returns True or False depending
@ -4740,13 +4731,13 @@ class ShapeEnv:
# the symbol mapping is
input_guards = []
symbol_to_source: Dict[sympy.Symbol, List[Source]] = collections.defaultdict(
symbol_to_source: dict[sympy.Symbol, list[Source]] = collections.defaultdict(
list
)
symbol_to_constraints: DefaultDict[
sympy.Symbol, Set[Constraint]
symbol_to_constraints: defaultdict[
sympy.Symbol, set[Constraint]
] = collections.defaultdict(set)
constraint_violations: List[Tuple[bool, str, Callable[[], str]]] = []
constraint_violations: list[tuple[bool, str, Callable[[], str]]] = []
py_printer = ShapeGuardPythonPrinter(
symbol_to_source, source_ref, self.var_to_sources
@ -4956,7 +4947,7 @@ class ShapeEnv:
# For subclasses, we need to track symints on BOTH the outer
# and inner tensors.
# TODO: type this better
sources_tensors_constraints: List[Tuple[Source, Any, Any, Any]] = [
sources_tensors_constraints: list[tuple[Source, Any, Any, Any]] = [
(source, t, context.constraint_sizes, context.constraint_strides)
]
attrs, _ = t.__tensor_flatten__()
@ -5256,8 +5247,8 @@ class ShapeEnv:
)
if constraint_violations:
warn_msgs: List[str] = []
error_msgs: List[str] = []
warn_msgs: list[str] = []
error_msgs: list[str] = []
debug_names = set()
for warn_only, debug_name, msg_cb in constraint_violations:
if warn_only:
@ -5327,7 +5318,7 @@ class ShapeEnv:
self,
placeholders: Sequence[Union[SymInt, FakeTensor]],
*,
guards: Optional[List[ShapeGuard]] = None,
guards: Optional[list[ShapeGuard]] = None,
ignore_static: bool = True,
) -> Optional[str]:
"""
@ -5386,7 +5377,7 @@ class ShapeEnv:
return self.evaluate_guards_expression(code, args)
return True
def get_pruned_guards(self, symints: Sequence[torch.SymInt]) -> List[ShapeGuard]:
def get_pruned_guards(self, symints: Sequence[torch.SymInt]) -> list[ShapeGuard]:
"""
Get a list of guards, but pruned so it only provides guards that
reference symints from the passed in input
@ -5401,7 +5392,7 @@ class ShapeEnv:
def bind_symbols(
self, placeholders: Sequence[FakeTensor], args: Sequence[Tensor]
) -> Dict[sympy.Symbol, int]:
) -> dict[sympy.Symbol, int]:
"""
Given a paired list of placeholders (fake tensors with
symbolic sizes) and concrete arguments (regular tensors
@ -5418,7 +5409,7 @@ class ShapeEnv:
another copy. This assumes the guards are already checked,
though if it's cheap we'll check for shenanigans
"""
bindings: Dict[sympy.Symbol, int] = {}
bindings: dict[sympy.Symbol, int] = {}
def bind_symint(arg: object, val: object) -> None:
if isinstance(val, SymInt):
@ -5451,7 +5442,7 @@ class ShapeEnv:
return bindings
def get_nontrivial_guards(self) -> List[SympyBoolean]:
def get_nontrivial_guards(self) -> list[SympyBoolean]:
"""Returns a list of guard expressions that aren't statically known (i.e. not trivial)"""
return [
self.simplify(guard.expr)
@ -5488,9 +5479,9 @@ class ShapeEnv:
@_lru_cache
def get_axioms(
self,
symbols: Optional[Tuple[sympy.Symbol]] = None,
symbols: Optional[tuple[sympy.Symbol]] = None,
compute_hint: bool = False,
) -> Tuple[SympyBoolean, ...]:
) -> tuple[SympyBoolean, ...]:
"""
Given the symbols in an expression, it returns all the runtime asserts that involve those symbols,
concatenated with all the guards.
@ -5518,9 +5509,9 @@ class ShapeEnv:
@lru_cache(None)
def get_implications(
self, e: SympyBoolean
) -> Tuple[Tuple[SympyBoolean, sympy.logic.boolalg.BooleanAtom], ...]:
) -> tuple[tuple[SympyBoolean, sympy.logic.boolalg.BooleanAtom], ...]:
"""Given a expression, it returns a list of predicates that follow from it"""
equiv: Dict[SympyBoolean, sympy.logic.boolalg.BooleanAtom] = {}
equiv: dict[SympyBoolean, sympy.logic.boolalg.BooleanAtom] = {}
def add_expr(expr: SympyBoolean) -> None:
expr = canonicalize_bool_expr(expr)
@ -5564,8 +5555,8 @@ class ShapeEnv:
unbacked_only: bool = False,
compute_hint: bool = False,
size_oblivious: bool = False,
axioms: Optional[Tuple[SympyBoolean]] = None,
var_to_range: Optional[Tuple[Tuple[sympy.Symbol, ValueRanges]]] = None,
axioms: Optional[tuple[SympyBoolean]] = None,
var_to_range: Optional[tuple[tuple[sympy.Symbol, ValueRanges]]] = None,
) -> Optional[sympy.Basic]:
"""
Tries to evaluate expr without introducing guards
@ -5589,7 +5580,7 @@ class ShapeEnv:
expr = canonicalize_bool_expr(expr)
def resimplify_floor_div(axioms: Dict[sympy.Expr, sympy.Expr]) -> None:
def resimplify_floor_div(axioms: dict[sympy.Expr, sympy.Expr]) -> None:
if not self._resimplify_floor_div_axioms:
return
self._resimplify_floor_div_axioms = False
@ -6114,7 +6105,7 @@ class ShapeEnv:
# Prefer to simplify out lexicographically higher symbols (i.e. simplify out s4 over s3).
# (NB: this unfortunately isn't strictly equivalent to simplifying out newer symbols)
# Prefer to simplify out symbols with ephemeral sources.
def _smart_symbol_sort(x: sympy.Symbol) -> Tuple[int, int, str]:
def _smart_symbol_sort(x: sympy.Symbol) -> tuple[int, int, str]:
has_only_ephemeral_sources = x in self.var_to_sources and all(
s.is_ephemeral() for s in self.var_to_sources[x]
)
@ -6282,7 +6273,7 @@ class ShapeEnv:
def _get_stack_summary(
self, is_debug: bool = False, framework_loc: Optional[str] = None
) -> Tuple[SLoc, str]:
) -> tuple[SLoc, str]:
floc: Optional[Union[str, traceback.FrameSummary]] = framework_loc
if floc is None:
frame = inspect.currentframe()
@ -6903,7 +6894,7 @@ class _PythonMsgPrinter(PythonPrinter):
(i.e., as ==, !=, >, <).
"""
def __init__(self, src_map: Dict[str, List[str]]) -> None:
def __init__(self, src_map: dict[str, list[str]]) -> None:
super().__init__()
self.src_map = src_map
@ -6912,7 +6903,7 @@ class _PythonMsgPrinter(PythonPrinter):
def _suggest_torch_checks(
e: GuardOnDataDependentSymNode, src_map: DefaultDict[str, List[str]]
e: GuardOnDataDependentSymNode, src_map: defaultdict[str, list[str]]
) -> None:
# extract the unresolved condition on unbacked symints in the error
cond = e.cond


@ -5,7 +5,7 @@ import logging
import math
import operator
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Optional, Union
import sympy
@ -60,7 +60,7 @@ try:
def z3str(e: z3.ExprRef) -> str:
assert z3.is_expr(e), f"unsupported expression type: {e}"
def get_args_str(e: z3.ExprRef) -> List[str]:
def get_args_str(e: z3.ExprRef) -> list[str]:
return [z3str(e.arg(i)) for i in range(e.num_args())]
# First, we simplify the given expression.
@ -350,13 +350,13 @@ try:
super().__init__(module, garbage_collect_values=True)
def placeholder(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
) -> Any:
symbol = fx_traceback.get_current_meta()["symbol"]
return self.validator.z3var(symbol)
def call_function(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
) -> Any:
if target != torch._assert:
# Lift and runs the node target function
@ -481,21 +481,21 @@ try:
log.debug("new instance")
# Mapping of SymPy symbols to Z3 variables.
self.symbols: Dict[sympy.Symbol, z3.ExprRef] = {}
self.symbols: dict[sympy.Symbol, z3.ExprRef] = {}
# Set of source Z3 expressions.
# They represent the generated guards without any kind of
# simplification or transformation.
self._source_exprs: Set[z3.BoolRef] = set()
self._source_exprs: set[z3.BoolRef] = set()
# Set of target Z3 expressions.
# They represent the actual checked guards at runtime. They might
# be simplified or transformed versions of the source guards.
self._target_exprs: Set[z3.BoolRef] = set()
self._target_exprs: set[z3.BoolRef] = set()
# Set of Z3 expressions representing assertions over both the
# source and target expressions.
self._assertions: Set[z3.BoolRef] = set()
self._assertions: set[z3.BoolRef] = set()
# Retrieves the corresponding Z3 variable.
def z3var(self, symbol: sympy.Symbol) -> z3.ExprRef:
@ -503,7 +503,7 @@ try:
return self.symbols[symbol]
# Create a variable in Z3 of 'type' for 'symbol', if it doesn't already exist.
def add_var(self, symbol: sympy.Symbol, type: Type) -> z3.ExprRef:
def add_var(self, symbol: sympy.Symbol, type: type) -> z3.ExprRef:
if symbol in self.symbols:
return self.symbols[symbol]
@ -769,7 +769,7 @@ def bisect(shape_env):
# Checks whether the given shape_env fails when produce_guards is called.
def check_shapeenv_fails(
shape_env: ShapeEnv, tracked_fakes: Optional[List[Any]]
shape_env: ShapeEnv, tracked_fakes: Optional[list[Any]]
) -> Optional[ValidationException]:
assert tracked_fakes is not None
try:

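The validator file above mirrors SymPy guards as Z3 terms (`_source_exprs` versus `_target_exprs`) and asks the solver whether the simplified target guards still imply the original source guards. A toy round-trip of that idea, assuming the `z3-solver` package and invented guard sets:

```python
import z3

s0 = z3.Int("s0")
source_guards = [s0 >= 2, s0 % 2 == 0]  # guards as originally generated
target_guards = [s0 >= 2]               # simplified guards actually checked

# Validation fails iff some assignment satisfies the targets while
# violating a source guard.
solver = z3.Solver()
solver.add(z3.And(*target_guards), z3.Not(z3.And(*source_guards)))
print(solver.check())  # sat -> the simplification dropped a constraint
```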

@ -11,23 +11,10 @@ import os
import re
import warnings
from collections import defaultdict
from collections.abc import Iterable
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (
Any,
Callable,
Dict,
FrozenSet,
Iterable,
List,
Literal,
NamedTuple,
Optional,
Set,
Tuple,
Type,
TYPE_CHECKING,
)
from typing import Any, Callable, Literal, NamedTuple, Optional, TYPE_CHECKING
import torch
import torch.utils._pytree as pytree
@ -47,11 +34,11 @@ if TYPE_CHECKING:
# Mapping of builtins to their `typing` equivalent.
_origin_type_map = {
list: List,
dict: Dict,
set: Set,
frozenset: FrozenSet,
tuple: Tuple,
list: list,
dict: dict,
set: set,
frozenset: frozenset,
tuple: tuple,
}
_legal_ops = dict.fromkeys(
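After this hunk `_origin_type_map` is an identity map: under PEP 585, `typing.get_origin` already yields the spellable builtin, so nothing needs translating back into `typing` names. A sketch:

```python
from typing import get_origin

origin_type_map = {list: list, dict: dict, set: set, frozenset: frozenset, tuple: tuple}

ann = dict[str, int]
assert get_origin(ann) is dict
assert origin_type_map[get_origin(ann)] is dict  # identity: no typing.Dict rewrite needed
```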
@ -61,7 +48,7 @@ _legal_ops = dict.fromkeys(
# Signature for functions that transform the body (`list[str]`) of the
# generated code
TransformCodeFunc = Callable[[List[str]], List[str]]
TransformCodeFunc = Callable[[list[str]], list[str]]
class _CustomBuiltin(NamedTuple):
@ -78,7 +65,7 @@ class _CustomBuiltin(NamedTuple):
obj: Any
_custom_builtins: Dict[str, _CustomBuiltin] = {}
_custom_builtins: dict[str, _CustomBuiltin] = {}
def _register_custom_builtin(name: str, import_str: str, obj: Any):
@ -144,10 +131,10 @@ class _Namespace:
"""
def __init__(self):
self._obj_to_name: Dict[Any, str] = {}
self._obj_to_name: dict[Any, str] = {}
self._unassociated_names = set()
self._used_names: Set[str] = set()
self._base_count: Dict[str, int] = defaultdict(int)
self._used_names: set[str] = set()
self._base_count: dict[str, int] = defaultdict(int)
self._illegal_char_regex = re.compile("[^0-9a-zA-Z_]+")
self._name_suffix_regex = re.compile(r"(.*)_(\d+)$")
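The two regexes in `_Namespace.__init__` implement FX's name mangling: illegal characters collapse to `_`, and numeric suffixes are re-issued to keep names unique. A stripped-down sketch of that policy, not the class's real method:

```python
import re
from collections import defaultdict

_illegal_char = re.compile("[^0-9a-zA-Z_]+")
_name_suffix = re.compile(r"(.*)_(\d+)$")
_base_count: defaultdict[str, int] = defaultdict(int)

def create_name(candidate: str) -> str:
    base = _illegal_char.sub("_", candidate)  # "conv.1" -> "conv_1"
    if m := _name_suffix.match(base):
        base = m.group(1)                     # strip any trailing _<n>
    n = _base_count[base]
    _base_count[base] += 1
    return base if n == 0 else f"{base}_{n}"

assert create_name("conv.1") == "conv"
assert create_name("conv.1") == "conv_1"  # second request gets a fresh suffix
```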
@ -261,10 +248,10 @@ class PythonCode:
# Python source code for the forward function definition.
src: str
# Values in global scope during execution of `src_def`.
globals: Dict[str, Any]
globals: dict[str, Any]
# Optional mapping from the forward function's line number to
# node index.
_lineno_map: Optional[Dict[int, Optional[int]]]
_lineno_map: Optional[dict[int, Optional[int]]]
def _format_target(base: str, target: str) -> str:
@ -311,7 +298,7 @@ class _PyTreeInfo(NamedTuple):
Contains extra info stored when we're using Pytrees
"""
orig_args: List[str]
orig_args: list[str]
in_spec: pytree.TreeSpec
out_spec: Optional[pytree.TreeSpec]
@ -359,7 +346,7 @@ class CodeGen:
self._body_transformer: Optional[TransformCodeFunc] = None
self._func_name: str = "forward"
def gen_fn_def(self, free_vars: List[str], maybe_return_annotation: str) -> str:
def gen_fn_def(self, free_vars: list[str], maybe_return_annotation: str) -> str:
"""
Given the free variables and a return annotation, generates the beginning of the FX function.
By default, `gen_fn_def(['a', 'b'], '') == 'def {self._func_name}(a, b):'`
@ -398,7 +385,7 @@ class CodeGen:
"""
return outputs
def additional_globals(self) -> List[Tuple[str, Any]]:
def additional_globals(self) -> list[tuple[str, Any]]:
"""
If your codegen uses extra global values, add tuples of (identifier, reference to the value) here.
For example, return ['List', typing.List] if you need ``List`` in the global context.
@ -416,13 +403,13 @@ class CodeGen:
include_device: bool = False,
colored: bool = False,
) -> PythonCode:
free_vars: List[str] = []
body: List[str] = []
globals_: Dict[str, Any] = {}
wrapped_fns: Dict[str, None] = {}
free_vars: list[str] = []
body: list[str] = []
globals_: dict[str, Any] = {}
wrapped_fns: dict[str, None] = {}
# Wrap string in list to pass by reference
maybe_return_annotation: List[str] = [""]
maybe_return_annotation: list[str] = [""]
include_stride = include_stride or (
os.environ.get("FX_GRAPH_SHOW_STRIDE", "0") == "1"
)
@ -553,7 +540,7 @@ class CodeGen:
return blue(repr(arg))
def _format_args(
args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
args: tuple[Argument, ...], kwargs: dict[str, Argument]
) -> str:
args_s = ", ".join(_get_repr(a) for a in args)
kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
@ -565,8 +552,8 @@ class CodeGen:
# of a given node. This represents the *last* use of the node in the
# execution order of the program, which we will use to free unused
# values
node_to_last_use: Dict[Node, Node] = {}
user_to_last_uses: Dict[Node, List[Node]] = {}
node_to_last_use: dict[Node, Node] = {}
user_to_last_uses: dict[Node, list[Node]] = {}
def register_last_uses(n: Node, user: Node):
if n not in node_to_last_use:
@ -782,9 +769,9 @@ class CodeGen:
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
# remove counter and generate lineno to node index mapping
lineno_map: Dict[int, Optional[int]] = {}
lineno_map: dict[int, Optional[int]] = {}
prologue_len = prologue.count("\n") + 1
new_lines: List[str] = []
new_lines: list[str] = []
cur_idx = None
for line in "".join(body).split("\n"):
counter = re.search(r"# COUNTER: (\d+)", line)
@ -904,11 +891,11 @@ class _FindNodesLookupTable:
"""
def __init__(self):
self.table: Dict[Tuple[str, Optional[Target]], Dict[Node, None]] = defaultdict(
self.table: dict[tuple[str, Optional[Target]], dict[Node, None]] = defaultdict(
dict
)
def _key(self, node) -> Tuple[str, Optional[Target]]:
def _key(self, node) -> tuple[str, Optional[Target]]:
return (node.op, node.target if node.op == "call_function" else None)
def __contains__(self, node) -> bool:
@ -985,14 +972,14 @@ class Graph:
def __init__(
self,
owning_module: Optional["GraphModule"] = None,
tracer_cls: Optional[Type["Tracer"]] = None,
tracer_extras: Optional[Dict[str, Any]] = None,
tracer_cls: Optional[type["Tracer"]] = None,
tracer_extras: Optional[dict[str, Any]] = None,
):
"""
Construct an empty Graph.
"""
self._root: Node = Node(self, "", "root", "", (), {})
self._used_names: Dict[str, int] = {} # base name -> number
self._used_names: dict[str, int] = {} # base name -> number
self._insert = self._root.prepend
self._len = 0
self._graph_namespace = _Namespace()
@ -1000,7 +987,7 @@ class Graph:
self._tracer_cls = tracer_cls
self._tracer_extras = tracer_extras
self._codegen = CodeGen()
self._co_fields: Dict[str, Any] = {}
self._co_fields: dict[str, Any] = {}
self._find_nodes_lookup_table = _FindNodesLookupTable()
@property
@ -1060,7 +1047,7 @@ class Graph:
@compatibility(is_backward_compatible=True)
def graph_copy(
self, g: "Graph", val_map: Dict[Node, Node], return_output_node=False
self, g: "Graph", val_map: dict[Node, Node], return_output_node=False
) -> "Optional[Argument]":
"""
Copy all nodes from a given graph into ``self``.
@ -1113,8 +1100,8 @@ class Graph:
self,
op: str,
target: "Target",
args: Optional[Tuple["Argument", ...]] = None,
kwargs: Optional[Dict[str, "Argument"]] = None,
args: Optional[tuple["Argument", ...]] = None,
kwargs: Optional[dict[str, "Argument"]] = None,
name: Optional[str] = None,
type_expr: Optional[Any] = None,
) -> Node:
@ -1373,8 +1360,8 @@ class Graph:
def call_module(
self,
module_name: str,
args: Optional[Tuple["Argument", ...]] = None,
kwargs: Optional[Dict[str, "Argument"]] = None,
args: Optional[tuple["Argument", ...]] = None,
kwargs: Optional[dict[str, "Argument"]] = None,
type_expr: Optional[Any] = None,
) -> Node:
"""
@ -1423,8 +1410,8 @@ class Graph:
def call_method(
self,
method_name: str,
args: Optional[Tuple["Argument", ...]] = None,
kwargs: Optional[Dict[str, "Argument"]] = None,
args: Optional[tuple["Argument", ...]] = None,
kwargs: Optional[dict[str, "Argument"]] = None,
type_expr: Optional[Any] = None,
) -> Node:
"""
@ -1462,8 +1449,8 @@ class Graph:
def call_function(
self,
the_function: Callable[..., Any],
args: Optional[Tuple["Argument", ...]] = None,
kwargs: Optional[Dict[str, "Argument"]] = None,
args: Optional[tuple["Argument", ...]] = None,
kwargs: Optional[dict[str, "Argument"]] = None,
type_expr: Optional[Any] = None,
) -> Node:
"""
@ -1668,10 +1655,10 @@ class Graph:
Return a human-readable (not machine-readable) string representation
of this Graph
"""
placeholder_names: List[str] = []
placeholder_names: list[str] = []
# This is a one-element array just so ``format_node`` can modify the closed
# over value
maybe_return_typename: List[str] = [""]
maybe_return_typename: list[str] = [""]
node_strs = [node.format_node(placeholder_names) for node in self.nodes]
param_str = ", ".join(placeholder_names)
@ -1729,8 +1716,8 @@ class Graph:
f"defined! Please check that Nodes in the graph are topologically ordered\n{self}"
)
seen_names: Set[str] = set()
seen_values: Set[Node] = set()
seen_names: set[str] = set()
seen_values: set[Node] = set()
for node in self.nodes:
if node.op not in [
"placeholder",

View File

@ -8,7 +8,7 @@ import sys
import traceback
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
from typing import Any, Callable, Optional, Union
import torch
import torch.nn as nn
@ -39,7 +39,7 @@ class _EvalCacheLoader:
self.eval_cache = {}
self.next_id = 0
def cache(self, src: str, globals: Dict[str, Any], co_fields=None):
def cache(self, src: str, globals: dict[str, Any], co_fields=None):
"""Store the source in a private cache, and add a lazy entry in linecache
that allows the source to be retrieved by 'filename'.
@ -83,19 +83,19 @@ class _EvalCacheLoader:
_loader = _EvalCacheLoader()
def _exec_with_source(src: str, globals: Dict[str, Any], co_fields=None):
def _exec_with_source(src: str, globals: dict[str, Any], co_fields=None):
key = _loader.cache(src, globals, co_fields)
exec(compile(src, key, "exec"), globals)
def _forward_from_src(src: str, globals: Dict[str, Any], co_fields=None):
def _forward_from_src(src: str, globals: dict[str, Any], co_fields=None):
return _method_from_src(
method_name="forward", src=src, globals=globals, co_fields=co_fields
)
def _method_from_src(
method_name: str, src: str, globals: Dict[str, Any], co_fields=None
method_name: str, src: str, globals: dict[str, Any], co_fields=None
) -> Callable:
# avoid mutating the passed in dict
globals_copy = globals.copy()
@ -114,8 +114,8 @@ def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
return f"from {module_name} import {attr_name} as {name}"
def _format_import_block(globals: Dict[str, Any], importer: Importer):
import_strs: Set[str] = {
def _format_import_block(globals: dict[str, Any], importer: Importer):
import_strs: set[str] = {
_format_import_statement(name, obj, importer) for name, obj in globals.items()
}
# Sort the imports so we have a stable import block that allows us to
@ -124,7 +124,7 @@ def _format_import_block(globals: Dict[str, Any], importer: Importer):
@compatibility(is_backward_compatible=True)
def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
def reduce_graph_module(body: dict[Any, Any], import_block: str) -> torch.nn.Module:
# BC: attribute name was changed from `code` to `_code` to facilitate
# making `code` into a property and adding a docstring to it
fn_src = body.get("_code") or body["code"]
@ -134,7 +134,7 @@ def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Mod
@compatibility(is_backward_compatible=True)
def reduce_package_graph_module(
importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
importer: PackageImporter, body: dict[Any, Any], generated_module_name: str
) -> torch.nn.Module:
forward = importer.import_module(generated_module_name).forward
return _deserialize_graph_module(forward, body)
@ -142,7 +142,7 @@ def reduce_package_graph_module(
@compatibility(is_backward_compatible=True)
def reduce_deploy_graph_module(
importer: PackageImporter, body: Dict[Any, Any], import_block: str
importer: PackageImporter, body: dict[Any, Any], import_block: str
) -> torch.nn.Module:
ns = {}
ns["__builtins__"] = importer.patched_builtins
@ -162,7 +162,7 @@ class _CodeOnlyModule(torch.nn.Module):
def _deserialize_graph_module(
forward, body: Dict[Any, Any], graph_module_cls=None
forward, body: dict[Any, Any], graph_module_cls=None
) -> torch.nn.Module:
"""
Deserialize a GraphModule given the dictionary of the original module,
@ -271,7 +271,7 @@ def _get_attr(model: torch.nn.Module, attr_name: str):
return _get_attr_via_attr_list(model, attr_name.split("."))
def _get_attr_via_attr_list(model: torch.nn.Module, attr_list: List[str]):
def _get_attr_via_attr_list(model: torch.nn.Module, attr_list: list[str]):
if len(attr_list) == 0:
return model
*prefix, field = attr_list
@ -415,7 +415,7 @@ class GraphModule(torch.nn.Module):
code.
"""
def __new__(cls: "Type[GraphModule]", *args, **kwargs):
def __new__(cls: "type[GraphModule]", *args, **kwargs):
# each instance of a graph module needs its own forward method
# so create a new singleton class for each instance.
# it is a subclass of the user-defined class, the only difference
@ -437,7 +437,7 @@ class GraphModule(torch.nn.Module):
@compatibility(is_backward_compatible=True)
def __init__(
self,
root: Union[torch.nn.Module, Dict[str, Any]],
root: Union[torch.nn.Module, dict[str, Any]],
graph: Graph,
class_name: str = "GraphModule",
):
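Since `root` is annotated as `Union[torch.nn.Module, dict[str, Any]]`, a plain dict of qualified names also works as the attribute source for `get_attr`/`call_module` targets. A hedged sketch (the `weight` name is illustrative):

import torch
from torch.fx import Graph, GraphModule

g = Graph()
x = g.placeholder("x")
w = g.get_attr("weight")                     # resolved against `root`
g.output(g.call_function(torch.add, (x, w)))

gm = GraphModule({"weight": torch.ones(3)}, g)  # dict-based root
print(gm(torch.zeros(3)))                    # tensor([1., 1., 1.])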
@ -527,12 +527,12 @@ class GraphModule(torch.nn.Module):
self._tracer_extras = self.graph._tracer_extras
# Dictionary to store metadata
self.meta: Dict[str, Any] = {}
self._replace_hooks: List[Callable] = []
self._create_node_hooks: List[Callable] = []
self._erase_node_hooks: List[Callable] = []
self.meta: dict[str, Any] = {}
self._replace_hooks: list[Callable] = []
self._create_node_hooks: list[Callable] = []
self._erase_node_hooks: list[Callable] = []
# Used to remove hooks from deepcopied graph modules within a context manager.
self._deepcopy_hooks: List[Callable] = []
self._deepcopy_hooks: list[Callable] = []
# TorchScript breaks trying to compile the graph setter because of the
# continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
@ -739,7 +739,7 @@ class {module_name}(torch.nn.Module):
This method can be called to clean up an ``nn.Module`` without
manually calling ``delete_submodule`` on each unused submodule.
"""
used: List[str] = []
used: list[str] = []
for node in self.graph.nodes:
if node.op == "call_module" or node.op == "get_attr":

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, Iterable, List, Tuple
from collections.abc import Iterable
from typing import Any
from torch.utils._pytree import (
_dict_flatten,
@ -79,25 +80,25 @@ compatibility(is_backward_compatible=True)(immutable_dict)
# Register immutable collections for PyTree operations
def _immutable_dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
def _immutable_dict_flatten(d: dict[Any, Any]) -> tuple[list[Any], Context]:
return _dict_flatten(d)
def _immutable_dict_unflatten(
values: Iterable[Any],
context: Context,
) -> Dict[Any, Any]:
) -> dict[Any, Any]:
return immutable_dict(_dict_unflatten(values, context))
def _immutable_list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
def _immutable_list_flatten(d: list[Any]) -> tuple[list[Any], Context]:
return _list_flatten(d)
def _immutable_list_unflatten(
values: Iterable[Any],
context: Context,
) -> List[Any]:
) -> list[Any]:
return immutable_list(_list_unflatten(values, context))
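These registrations make the immutable containers transparent to pytree traversal. A hedged round-trip example (`torch.utils._pytree` is a private module, so the exact entry points may shift between releases):

import torch.utils._pytree as pytree
from torch.fx.immutable_collections import immutable_dict, immutable_list

d = immutable_dict({"a": 1, "b": immutable_list([2, 3])})
leaves, spec = pytree.tree_flatten(d)
assert leaves == [1, 2, 3]
restored = pytree.tree_unflatten([v * 10 for v in leaves], spec)
# `restored` is again an immutable_dict containing an immutable_list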

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import inspect
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Optional, TYPE_CHECKING, Union
import torch
import torch.fx.traceback as fx_traceback
@ -17,6 +17,10 @@ from .node import Argument, map_aggregate, map_arg, Node, Target
from .proxy import Proxy
if TYPE_CHECKING:
from collections.abc import Iterator
__all__ = ["Interpreter", "Transformer"]
@ -92,7 +96,7 @@ class Interpreter:
self.graph = graph
else:
self.graph = self.module.graph # type: ignore[assignment]
self.env: Dict[Node, Any] = {}
self.env: dict[Node, Any] = {}
self.name = "Interpreter"
self.garbage_collect_values = garbage_collect_values
self.extra_traceback = True
@ -102,8 +106,8 @@ class Interpreter:
# of a given node. This represents the *last* use of the node in the
# execution order of the program, which we will use to free unused
# values
node_to_last_use: Dict[Node, Node] = {}
self.user_to_last_uses: Dict[Node, List[Node]] = {}
node_to_last_use: dict[Node, Node] = {}
self.user_to_last_uses: dict[Node, list[Node]] = {}
def register_last_uses(n: Node, user: Node):
if n not in node_to_last_use:
@ -118,7 +122,7 @@ class Interpreter:
def run(
self,
*args,
initial_env: Optional[Dict[Node, Any]] = None,
initial_env: Optional[dict[Node, Any]] = None,
enable_io_processing: bool = True,
) -> Any:
"""
@ -232,7 +236,7 @@ class Interpreter:
# Main Node running APIs
@compatibility(is_backward_compatible=True)
def placeholder(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
) -> Any:
"""
Execute a ``placeholder`` node. Note that this is stateful:
@ -268,7 +272,7 @@ class Interpreter:
@compatibility(is_backward_compatible=True)
def get_attr(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
) -> Any:
"""
Execute a ``get_attr`` node. Will retrieve an attribute
@ -289,7 +293,7 @@ class Interpreter:
@compatibility(is_backward_compatible=True)
def call_function(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
) -> Any:
"""
Execute a ``call_function`` node and return the result.
@ -311,7 +315,7 @@ class Interpreter:
@compatibility(is_backward_compatible=True)
def call_method(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
) -> Any:
"""
Execute a ``call_method`` node and return the result.
@ -335,7 +339,7 @@ class Interpreter:
@compatibility(is_backward_compatible=True)
def call_module(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
) -> Any:
"""
Execute a ``call_module`` node and return the result.
@ -360,7 +364,7 @@ class Interpreter:
@compatibility(is_backward_compatible=True)
def output(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
) -> Any:
"""
Execute an ``output`` node. This really just retrieves
@ -401,7 +405,7 @@ class Interpreter:
return attr_itr
@compatibility(is_backward_compatible=True)
def fetch_args_kwargs_from_env(self, n: Node) -> Tuple[Tuple, Dict]:
def fetch_args_kwargs_from_env(self, n: Node) -> tuple[tuple, dict]:
"""
Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
from the current execution environment.
@ -497,7 +501,7 @@ class Transformer(Interpreter):
def __init__(self, graph: Graph):
super().__init__()
self.graph = graph
self.tensor_attrs: Dict[torch.Tensor, str] = {} # type: ignore[assignment]
self.tensor_attrs: dict[torch.Tensor, str] = {} # type: ignore[assignment]
def is_leaf_module(self, _, __) -> bool:
return True
@ -507,7 +511,7 @@ class Transformer(Interpreter):
@compatibility(is_backward_compatible=True)
def placeholder(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
) -> Proxy:
"""
Execute a ``placeholder`` node. In ``Transformer``, this is
@ -529,7 +533,7 @@ class Transformer(Interpreter):
@compatibility(is_backward_compatible=True)
def get_attr(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
) -> Proxy:
"""
Execute a ``get_attr`` node. In ``Transformer``, this is
@ -548,7 +552,7 @@ class Transformer(Interpreter):
@compatibility(is_backward_compatible=True)
def call_module(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
) -> Any:
# Override so that the leaf module policy from `self.tracer` is respected.
assert isinstance(target, str)
@ -557,7 +561,7 @@ class Transformer(Interpreter):
@compatibility(is_backward_compatible=True)
def call_function(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
) -> Any:
# Override so that functions that were wrapped are still wrapped.
return self.tracer.create_proxy("call_function", target, args, kwargs)
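The per-opcode methods above (`placeholder`, `get_attr`, `call_function`, ...) are the intended override points for both `Interpreter` and `Transformer`. A hedged sketch of a subclass that logs every function call while delegating to the stock behavior:

import torch
from torch.fx import Interpreter, symbolic_trace

class LoggingInterpreter(Interpreter):
    def call_function(self, target, args, kwargs):
        print(f"calling {target} with {len(args)} positional args")
        return super().call_function(target, args, kwargs)

def f(x):
    return torch.relu(x) + 1.0

gm = symbolic_trace(f)
out = LoggingInterpreter(gm).run(torch.randn(4))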

View File

@ -3,19 +3,8 @@ import builtins
import inspect
import types
import warnings
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
from collections.abc import Mapping, Sequence
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import torch
from torch._C import _NodeBase
@ -57,7 +46,7 @@ Target = Union[Callable[..., Any], str]
Argument = Optional[
Union[
Tuple["Argument", ...],
tuple["Argument", ...],
Sequence["Argument"],
Mapping[str, "Argument"],
slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
@ -79,7 +68,7 @@ _legal_ops = dict.fromkeys(
]
)
_side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = {
_side_effectful_need_to_be_preserved_pre_dispatch: set[Callable] = {
torch._C._set_grad_enabled,
torch.amp._enter_autocast,
torch.amp._exit_autocast,
@ -87,7 +76,7 @@ _side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = {
# TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs,
# or add logic to correctly mark all inplace ops as side effectful.
_side_effectful_functions: Set[Callable] = {
_side_effectful_functions: set[Callable] = {
torch._assert,
torch._assert_async,
_ops.aten._assert_async.msg,
@ -227,18 +216,18 @@ class Node(_NodeBase):
in the Graph printout.
"""
_args: Tuple["Argument", ...]
_kwargs: Dict[str, "Argument"]
_args: tuple["Argument", ...]
_kwargs: dict[str, "Argument"]
graph: "Graph"
name: str
op: str
target: "Target"
_input_nodes: Dict["Node", None]
users: Dict["Node", None]
_input_nodes: dict["Node", None]
users: dict["Node", None]
type: Optional[Any]
_sort_key: Any
_repr_fn: Optional[Callable[["Node"], str]]
meta: Dict[str, Any]
meta: dict[str, Any]
@compatibility(is_backward_compatible=True)
def __init__(
@ -247,8 +236,8 @@ class Node(_NodeBase):
name: str,
op: str,
target: "Target",
args: Tuple["Argument", ...],
kwargs: Dict[str, "Argument"],
args: tuple["Argument", ...],
kwargs: dict[str, "Argument"],
return_type: Optional[Any] = None,
) -> None:
"""
@ -339,14 +328,14 @@ class Node(_NodeBase):
# transformations. This metadata is preserved across node copies
assign(self, "meta", {})
def __getstate__(self) -> Dict[str, Any]:
def __getstate__(self) -> dict[str, Any]:
state = self.__dict__.copy()
state["_erased"] = self._erased
state["_prev"] = self._prev
state["_next"] = self._next
return state
def __setstate__(self, state: Dict[str, Any]) -> None:
def __setstate__(self, state: dict[str, Any]) -> None:
_erased = state.pop("_erased")
_prev = state.pop("_prev")
_next = state.pop("_next")
@ -442,7 +431,7 @@ class Node(_NodeBase):
p._next, n._prev = n, p
@property
def args(self) -> Tuple[Argument, ...]:
def args(self) -> tuple[Argument, ...]:
"""
The tuple of arguments to this ``Node``. The interpretation of arguments
depends on the node's opcode. See the :class:`Node` docstring for more
@ -454,7 +443,7 @@ class Node(_NodeBase):
return self._args
@args.setter
def args(self, a: Tuple[Argument, ...]) -> None:
def args(self, a: tuple[Argument, ...]) -> None:
"""
Set the tuple of arguments to this Node. The interpretation of arguments
depends on the node's opcode. See the ``fx.Graph`` docstring for more
@ -465,7 +454,7 @@ class Node(_NodeBase):
self.__update_args_kwargs(a, self._kwargs)
@property
def kwargs(self) -> Dict[str, Argument]:
def kwargs(self) -> dict[str, Argument]:
"""
The dict of keyword arguments to this ``Node``. The interpretation of arguments
depends on the node's opcode. See the :class:`Node` docstring for more
@ -477,7 +466,7 @@ class Node(_NodeBase):
return self._kwargs
@kwargs.setter
def kwargs(self, k: Dict[str, Argument]) -> None:
def kwargs(self, k: dict[str, Argument]) -> None:
"""
Set the dict of kwargs to this Node. The interpretation of arguments
depends on the node's opcode. See the ``fx.Graph`` docstring for more
@ -488,7 +477,7 @@ class Node(_NodeBase):
self.__update_args_kwargs(self._args, k)
@property
def all_input_nodes(self) -> List["Node"]:
def all_input_nodes(self) -> list["Node"]:
"""
Return all Nodes that are inputs to this Node. This is equivalent to
iterating over ``args`` and ``kwargs`` and only collecting the values that
@ -534,7 +523,7 @@ class Node(_NodeBase):
self._args = args_left + (arg,) + args_right
_new_input_nodes: Dict[Node, None] = {}
_new_input_nodes: dict[Node, None] = {}
map_arg(arg, _new_input_nodes.setdefault)
for new_use in _new_input_nodes.keys():
@ -574,7 +563,7 @@ class Node(_NodeBase):
self.meta["stack_trace"] = trace
def __update_args_kwargs(
self, new_args: Tuple["Argument", ...], new_kwargs: Dict[str, "Argument"]
self, new_args: tuple["Argument", ...], new_kwargs: dict[str, "Argument"]
) -> None:
"""
This API is internal. Do *not* call it directly.
@ -634,8 +623,8 @@ class Node(_NodeBase):
@compatibility(is_backward_compatible=True)
def format_node(
self,
placeholder_names: Optional[List[str]] = None,
maybe_return_typename: Optional[List[str]] = None,
placeholder_names: Optional[list[str]] = None,
maybe_return_typename: Optional[list[str]] = None,
) -> Optional[str]:
"""
Return a descriptive string representation of ``self``.
@ -704,7 +693,7 @@ class Node(_NodeBase):
delete_user_cb: Callable[["Node"], bool] = lambda user: True,
*,
propagate_meta: bool = False,
) -> List["Node"]:
) -> list["Node"]:
"""
Replace all uses of ``self`` in the Graph with the Node ``replace_with``.
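A hedged example of the usual rewrite idiom built on `replace_all_uses_with`, swapping `torch.neg` for `torch.relu` (the traced function is illustrative):

import torch
from torch.fx import symbolic_trace

def f(x):
    return torch.neg(x) + 1

gm = symbolic_trace(f)
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.neg:
        with gm.graph.inserting_after(node):
            repl = gm.graph.call_function(torch.relu, node.args)
        node.replace_all_uses_with(repl)
        gm.graph.erase_node(node)
gm.recompile()
print(gm.code)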
@ -775,7 +764,7 @@ class Node(_NodeBase):
# impure since it mutates inputs
return True
tags: Optional[List[torch.Tag]] = getattr(self.target, "_tags", None)
tags: Optional[list[torch.Tag]] = getattr(self.target, "_tags", None)
if tags is not None and torch.Tag.nondeterministic_seeded in tags:
# impure since it mutates RNG state
return True
@ -799,8 +788,8 @@ class Node(_NodeBase):
def normalized_arguments(
self,
root: torch.nn.Module,
arg_types: Optional[Tuple[Any]] = None,
kwarg_types: Optional[Dict[str, Any]] = None,
arg_types: Optional[tuple[Any]] = None,
kwarg_types: Optional[dict[str, Any]] = None,
normalize_to_only_use_kwargs: bool = False,
) -> Optional[ArgsKwargsPair]:
"""

View File

@ -5,17 +5,7 @@ import numbers
import types
import typing
import warnings
from typing import (
Any,
Callable,
cast,
Dict,
List,
NamedTuple,
Optional,
Tuple,
TYPE_CHECKING,
)
from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING
import torch
from torch._jit_internal import boolean_dispatched
@ -44,11 +34,11 @@ class ArgsKwargsPair(NamedTuple):
Simple named tuple for wrapping args/kwargs pairs.
"""
args: Tuple[Any, ...]
kwargs: Dict[str, Any]
args: tuple[Any, ...]
kwargs: dict[str, Any]
_manual_overrides: Dict[Callable, List[inspect.Signature]] = {}
_manual_overrides: dict[Callable, list[inspect.Signature]] = {}
def _nonzero_schemas():
@ -108,7 +98,7 @@ def _torchscript_schema_to_signature_impl(
) -> inspect.Signature:
from inspect import Parameter
parameters: List[Parameter] = []
parameters: list[Parameter] = []
for arg in ts_schema.arguments:
arg_type = _torchscript_type_to_python_type(arg.type)
default = arg.default_value if arg.has_default_value() else Parameter.empty
@ -154,7 +144,7 @@ def _torchscript_schema_to_signature_impl(
return inspect.Signature(parameters, return_annotation=return_type)
_SCHEMA_TO_SIGNATURE_CACHE: Dict[Tuple[str, str], inspect.Signature] = {}
_SCHEMA_TO_SIGNATURE_CACHE: dict[tuple[str, str], inspect.Signature] = {}
def _torchscript_schema_to_signature(
@ -173,7 +163,7 @@ def _torchscript_schema_to_signature(
@compatibility(is_backward_compatible=False)
def check_for_mutable_operation(
target: Callable, args: Tuple["Argument", ...], kwargs: Dict[str, "Argument"]
target: Callable, args: tuple["Argument", ...], kwargs: dict[str, "Argument"]
):
signatures, schemas = get_signature_for_torch_op(target, return_schemas=True)
@ -265,12 +255,12 @@ def create_type_hint(x):
if isinstance(x, list):
def ret_type(x):
return List[x] # type: ignore[valid-type]
return list[x] # type: ignore[valid-type]
else:
def ret_type(x):
return Tuple[x, ...]
return tuple[x, ...] # type: ignore[valid-type]
if len(x) == 0:
return ret_type(Any)
@ -291,6 +281,10 @@ def create_type_hint(x):
return x
_LIST_TYPES = (list, typing.List) # noqa: UP006
_TUPLE_TYPES = (tuple, typing.Tuple) # noqa: UP006
@compatibility(is_backward_compatible=False)
def type_matches(signature_type: Any, argument_type: Any):
sig_origin_type = getattr(signature_type, "__origin__", signature_type)
@ -304,22 +298,24 @@ def type_matches(signature_type: Any, argument_type: Any):
sig_contained = signature_type.__args__
return any(type_matches(c, argument_type) for c in sig_contained)
if signature_type is List[int] and argument_type is int:
if signature_type is typing.List[int] and argument_type is int: # noqa: UP006
# int can be promoted to List[int]
return True
if getattr(signature_type, "__origin__", None) in {list, List}:
if getattr(signature_type, "__origin__", None) in _LIST_TYPES:
sig_el_type = signature_type.__args__[0]
if sig_el_type is argument_type:
return True
if not inspect.isclass(sig_el_type):
warnings.warn(
f"Does not support nested parametric types, got {signature_type}. Please file a bug."
)
return False
if getattr(argument_type, "__origin__", None) in {list, List}:
if getattr(argument_type, "__origin__", None) in _LIST_TYPES:
return issubclass(argument_type.__args__[0], sig_el_type)
def is_homogeneous_tuple(t):
if getattr(t, "__origin__", None) not in {tuple, Tuple}:
if getattr(t, "__origin__", None) not in _TUPLE_TYPES:
return False
contained = t.__args__
if t.__args__ == ((),): # Tuple[()].__args__ == ((),) for some reason
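A hedged illustration of `type_matches` under the checks above (the `is`-based `typing.List[int]` special case relies on `typing`'s subscription cache):

import typing
from torch.fx.operator_schemas import type_matches

assert type_matches(typing.List[int], int)   # int promotes to List[int]
assert type_matches(list[int], list[int])    # element types are compared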
@ -344,10 +340,10 @@ def type_matches(signature_type: Any, argument_type: Any):
@compatibility(is_backward_compatible=False)
def normalize_function(
target: Callable,
args: Tuple[Any],
kwargs: Optional[Dict[str, Any]] = None,
arg_types: Optional[Tuple[Any]] = None,
kwarg_types: Optional[Dict[str, Any]] = None,
args: tuple[Any],
kwargs: Optional[dict[str, Any]] = None,
arg_types: Optional[tuple[Any]] = None,
kwarg_types: Optional[dict[str, Any]] = None,
normalize_to_only_use_kwargs: bool = False,
) -> Optional[ArgsKwargsPair]:
"""
@ -424,7 +420,7 @@ def normalize_function(
)
else:
if arg_types is not None or kwarg_types is not None:
arg_types = arg_types if arg_types else cast(Tuple[Any], ())
arg_types = arg_types if arg_types else cast(tuple[Any], ())
kwarg_types = kwarg_types if kwarg_types else {}
for candidate_signature in torch_op_schemas:
sig_matches = True
@ -468,8 +464,8 @@ def normalize_function(
def normalize_module(
root: torch.nn.Module,
target: str,
args: Tuple[Any],
kwargs: Optional[Dict[str, Any]] = None,
args: tuple[Any],
kwargs: Optional[dict[str, Any]] = None,
normalize_to_only_use_kwargs: bool = False,
) -> Optional[ArgsKwargsPair]:
"""
@ -513,8 +509,8 @@ def normalize_module(
def _args_kwargs_to_normalized_args_kwargs(
sig: inspect.Signature,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
normalize_to_only_use_kwargs: bool,
) -> Optional[ArgsKwargsPair]:
"""
@ -552,8 +548,8 @@ def _args_kwargs_to_normalized_args_kwargs(
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
new_kwargs: Dict[str, Any] = {}
new_args: List[Any] = []
new_kwargs: dict[str, Any] = {}
new_args: list[Any] = []
for i, param in enumerate(sig.parameters):
if not normalize_to_only_use_kwargs and i < len(args):
new_args.append(bound_args.arguments[param])

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import logging
import os
from typing import Any, List, Set, Union
from typing import Any, Union
from sympy import Integer, Number, Symbol
from sympy.logic.boolalg import BooleanAtom
@ -28,7 +28,7 @@ from torch.utils._sympy.reference import TensorReferenceAnalysis
from torch.utils._sympy.symbol import symbol_is_type, SymT
__all__: List[str] = []
__all__: list[str] = []
log = logging.getLogger(__name__)
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
@ -242,7 +242,7 @@ def tensorify_python_scalars(
if node.op == "call_function" and (
replacement_op := SUPPORTED_OPS.get(node.target)
):
args: List[Any] = []
args: list[Any] = []
transform = False
compute_dtype = get_computation_dtype(node.meta["val"].dtype)
@ -299,7 +299,7 @@ def tensorify_python_scalars(
"tensorify_float_success", True, overwrite=True
)
failed_tensorify_ops: Set[str] = set()
failed_tensorify_ops: set[str] = set()
# Now do one more pass that specializes all symfloats we didn't manage
# to tensorify away.

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, Tuple
from typing import Any
import torch
from torch.fx import Graph, GraphModule, Node
@ -90,14 +90,14 @@ class CSEPass(PassBase):
modified = False
new_graph = Graph()
env: Dict[
env: dict[
Node, Node
] = {} # map from node in the old graph to node in the new graph
hash_env: Dict[
Tuple[torch._ops.OpOverload, int], Node
hash_env: dict[
tuple[torch._ops.OpOverload, int], Node
] = {} # map from hash to a node in the new graph
token_map: Dict[
Tuple[torch._ops.OpOverload, int], Dict[str, Any]
token_map: dict[
tuple[torch._ops.OpOverload, int], dict[str, Any]
] = {} # map from hash to token
for n in graph_module.graph.nodes:
# The placeholder, output, and get_attr nodes are copied to the new graph without change

View File

@ -2,7 +2,7 @@
import hashlib
from itertools import chain
from typing import Any, Dict, Optional, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING
import torch
import torch.fx
@ -150,10 +150,10 @@ if HAS_PYDOT:
def get_submod_dot_graph(self, submod_name) -> pydot.Dot:
return self._dot_graphs[f"{self._name}_{submod_name}"]
def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]:
def get_all_dot_graphs(self) -> dict[str, pydot.Dot]:
return self._dot_graphs
def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]:
def _get_node_style(self, node: torch.fx.Node) -> dict[str, str]:
template = {
"shape": self.dot_graph_shape,
"fillcolor": "#CAFFE3",

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, List, NamedTuple, Optional
from typing import Any, NamedTuple, Optional
import torch
from torch.fx._compatibility import compatibility
@ -29,7 +29,7 @@ def replace_target_nodes_with(
"""Modifies all nodes in fx_module.graph.nodes which match the specified op code and target,
and updates them to match the new op code and target"""
new_graph = Graph()
val_map: Dict[Node, Node] = {}
val_map: dict[Node, Node] = {}
for node in fx_module.graph.nodes:
if node.op == old_op and node.target == old_target:
args = map_arg(node.args, lambda n: val_map[n])
@ -52,7 +52,7 @@ class size_bytes(NamedTuple):
@compatibility(is_backward_compatible=False)
def get_size_of_all_nodes(
fx_module: GraphModule, args: Optional[List[torch.Tensor]] = None
fx_module: GraphModule, args: Optional[list[torch.Tensor]] = None
) -> None:
"""Given a fx graph module, update each node with its total size (weights + bias + output)
and its output_size(output). For a non-module node, the total size is the output size.

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import os
from typing import Callable, Dict, List, Optional, Set, TypeVar
from typing import Callable, Optional, TypeVar
from torch.fx import Graph, Node
from torch.fx._compatibility import compatibility
@ -45,11 +45,11 @@ class GraphTransformObserver:
self.active = trace.enabled or self.log_url is not None
if self.active:
self.erased_nodes: Set[str] = set()
self.created_nodes: Set[str] = set()
self.name_to_node: Dict[str, Node] = {}
self.erased_nodes: set[str] = set()
self.created_nodes: set[str] = set()
self.name_to_node: dict[str, Node] = {}
# record graph modules deepcopied from self.gm, so we can remove hooks on them when exiting the context
self.copied_gms: List[GraphModule] = []
self.copied_gms: list[GraphModule] = []
self._node_creation_hook = self.get_node_creation_hook()
self._node_erase_hook = self.get_node_erase_hook()
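A hedged usage sketch; the constructor details have varied across versions, and the "demo_pass" name is illustrative:

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.graph_transform_observer import GraphTransformObserver

gm = symbolic_trace(lambda x: torch.relu(x))
with GraphTransformObserver(gm, "demo_pass"):
    # graph mutations performed here are attributed to "demo_pass",
    # with created/erased node names recorded via the hooks above
    pass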

View File

@ -2,8 +2,9 @@
import collections
import itertools
import logging
from collections.abc import Iterable, Sequence
from copy import copy
from typing import Dict, Iterable, List, Optional, Sequence, Set
from typing import Optional
from torch.fx.graph_module import GraphModule
from torch.fx.node import _get_qualified_name, Node
@ -52,10 +53,10 @@ class _DependencyViewer:
self.downstreams[node].add(output_node)
self.downstreams[node].update(self.downstreams[output_node])
def downstreams_of(self, node: Node) -> Set[Node]:
def downstreams_of(self, node: Node) -> set[Node]:
return self.downstreams[node]
def upstreams_of(self, node: Node) -> Set[Node]:
def upstreams_of(self, node: Node) -> set[Node]:
return self.upstreams[node]
@ -84,21 +85,21 @@ class CapabilityBasedPartitioner:
dict(self.graph_module.named_modules()), node
)
def propose_partitions(self) -> List[Partition]:
def propose_partitions(self) -> list[Partition]:
# partition_map is a mapping from partition id to a set of partition ids.
# The value set contains all the partition ids that can be reached by doing a
# DFS starting from the partition id in the key.
partition_map: Dict[int, Set] = collections.defaultdict(set)
partition_map: dict[int, set] = collections.defaultdict(set)
# assumptions: nodes in the candidate list are sorted in topological order
assignment: Dict[Node, int] = {} # mapping from node to partition_id
partitions_by_id: Dict[
assignment: dict[Node, int] = {} # mapping from node to partition_id
partitions_by_id: dict[
int, Partition
] = {} # mapping from partition_id to partition
nodes_order: Dict[
nodes_order: dict[
Node, int
] = {} # mapping from nodes to reversed topological order
partitions_order: Dict[
partitions_order: dict[
int, int
] = {} # mapping from partition_id to minimum topo order of nodes in partition
new_partition_id = itertools.count()
@ -111,7 +112,7 @@ class CapabilityBasedPartitioner:
merged_nodes = copy(partitions_by_id[self_id].nodes)
merged_nodes.update(partitions_by_id[other_id].nodes)
def dfs_iter_find_cycle(all_user_nodes: Set[Node]):
def dfs_iter_find_cycle(all_user_nodes: set[Node]):
for user_node in all_user_nodes:
visited_partition_ids = set()
@ -210,7 +211,7 @@ class CapabilityBasedPartitioner:
for node in reversed(self.graph_module.graph.nodes):
# use dict as an ordered set to ensure a deterministic partitioning result; values are ignored
merge_candidates: Dict[int, None] = {}
merge_candidates: dict[int, None] = {}
# Note that a limited horizontal fusion is enabled:
# when `node` is not supported, the code below attempts to fuse the consumers of `node`.
@ -241,7 +242,7 @@ class CapabilityBasedPartitioner:
# post processing to re-assign "getitem" nodes into upstream partition
logger.debug("Reassigning getitem nodes to its producer node's partition...")
nodes_reassignment: Dict[Node, int] = {}
nodes_reassignment: dict[Node, int] = {}
for node in self.graph_module.graph.nodes:
is_tuple_output = True
for user in node.users:
@ -266,7 +267,7 @@ class CapabilityBasedPartitioner:
logger.debug("Filtering out single node partitions...")
default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
partitions_to_remove: List[int] = []
partitions_to_remove: list[int] = []
for id, partition in partitions_by_id.items():
compute_node_count = 0
for node in partition.nodes:
@ -295,7 +296,7 @@ class CapabilityBasedPartitioner:
]
def fuse_partitions(
self, partitions: List[Partition], prefix: str = "fused_"
self, partitions: list[Partition], prefix: str = "fused_"
) -> GraphModule:
logger.debug("Fusing partitions...")
# fuse_by_partitions expects partitions in list[dict[Node, None]]: [ {node0 : None}, {node1 : None} ]
@ -306,7 +307,7 @@ class CapabilityBasedPartitioner:
)
# remove non-compute-ops that sit at the boundary of a partition.
def remove_bookend_non_compute_ops(self, partitions: List[Partition]):
def remove_bookend_non_compute_ops(self, partitions: list[Partition]):
non_compute_ops = set(self.non_compute_ops)
def is_non_compute_node(node: Node):
@ -316,11 +317,11 @@ class CapabilityBasedPartitioner:
)
# cache transparent nodes
transparent_input_nodes: Dict[Node, bool] = {}
transparent_output_nodes: Dict[Node, bool] = {}
transparent_input_nodes: dict[Node, bool] = {}
transparent_output_nodes: dict[Node, bool] = {}
def is_transparent_input_node(
node: Node, partition: Set[Node], removed_nodes: Set[Node]
node: Node, partition: set[Node], removed_nodes: set[Node]
):
if (
node.op == "placeholder"
@ -341,7 +342,7 @@ class CapabilityBasedPartitioner:
return False
def is_transparent_output_node(
node: Node, partition: Set[Node], removed_nodes: Set[Node]
node: Node, partition: set[Node], removed_nodes: set[Node]
):
if (
node.op == "placeholder"
@ -367,7 +368,7 @@ class CapabilityBasedPartitioner:
# Note it's ok to use `set` here, since we only query whether a node
# has been removed. We are NEVER going to iterate over the nodes inside
# the set.
remove_node: Set[Node] = set()
remove_node: set[Node] = set()
for node in partition.nodes:
if is_non_compute_node(node) and (
is_transparent_input_node(node, set(partition.nodes), remove_node)
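End-to-end, the partitioner is driven by an `OperatorSupport` policy. A hedged sketch (the `AddOnly` policy and the traced function are illustrative):

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport

class AddOnly(OperatorSupport):
    def is_node_supported(self, submodules, node):
        return node.op == "call_function" and node.target is torch.add

def f(x, y):
    a = torch.add(x, y)
    b = torch.add(a, x)
    return torch.mul(b, y)

gm = symbolic_trace(f)
partitioner = CapabilityBasedPartitioner(
    gm, AddOnly(), allows_single_node_partition=True
)
fused = partitioner.fuse_partitions(partitioner.propose_partitions())
print(fused.graph)  # the two adds end up in one fused submodule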

View File

@ -3,7 +3,7 @@ import inspect
import logging
from functools import wraps
from queue import Queue
from typing import Callable, Dict, List
from typing import Callable
import torch.nn as nn
from torch.fx._compatibility import compatibility
@ -50,7 +50,7 @@ def pass_result_wrapper(fn: Callable) -> Callable:
def _validate_pass_schedule_constraint(
constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
constraint: Callable[[Callable, Callable], bool], passes: list[Callable]
) -> None:
for i, a in enumerate(passes):
for j, b in enumerate(passes[i + 1 :]):
@ -64,8 +64,8 @@ def _validate_pass_schedule_constraint(
def _topological_sort_passes(
passes: List[Callable], constraints: List[Callable]
) -> List[Callable]:
passes: list[Callable], constraints: list[Callable]
) -> list[Callable]:
"""
Args
passes: Passes that we are ordering
@ -79,8 +79,8 @@ def _topological_sort_passes(
return passes
# Construct a graph mapping nodes to a list of their users
graph: Dict[Callable, List[Callable]] = {p: [] for p in passes}
indegree_map: Dict[Callable, int] = dict.fromkeys(passes, 0)
graph: dict[Callable, list[Callable]] = {p: [] for p in passes}
indegree_map: dict[Callable, int] = dict.fromkeys(passes, 0)
candidates: Queue = Queue()
for a in passes:
for b in passes:
@ -95,8 +95,8 @@ def _topological_sort_passes(
if indegree_map[a] == 0:
candidates.put(a)
visited: Dict[Callable, bool] = dict.fromkeys(passes, False)
sorted_passes: List[Callable] = []
visited: dict[Callable, bool] = dict.fromkeys(passes, False)
sorted_passes: list[Callable] = []
while not candidates.empty():
p = candidates.get()
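The sort above is Kahn's algorithm over the constraint graph. A hedged sketch of how constraints are declared (the stub passes are illustrative; real passes return a `PassResult`):

from torch.fx.passes.infra.pass_manager import (
    PassManager,
    this_before_that_pass_constraint,
)

def pass_a(mod): ...  # stub; a real pass returns PassResult(mod, modified)
def pass_b(mod): ...

pm = PassManager(passes=[pass_b, pass_a])
pm.add_constraint(this_before_that_pass_constraint(pass_a, pass_b))
pm.solve_constraints()  # reorders so pass_a runs before pass_b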
@ -169,8 +169,8 @@ class PassManager:
checks
"""
passes: List[Callable[[nn.Module], PassResult]]
constraints: List[Callable[[Callable, Callable], bool]]
passes: list[Callable[[nn.Module], PassResult]]
constraints: list[Callable[[Callable, Callable], bool]]
_validated: bool = False
steps: int = 1

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Optional
import torch
import torch.fx
@ -106,7 +106,7 @@ class _MinimizerBase:
module: torch.fx.GraphModule,
sample_input: Tensors,
compare_fn: Callable[
[TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool]
[TensorOrTensors, TensorOrTensors, Names], tuple[float, bool]
],
settings: _MinimizerSettingBase,
module_exporter: Optional[
@ -124,16 +124,16 @@ class _MinimizerBase:
self.exclusion_fn = exclusion_fn
# Stores outputs of run_a function
self.a_outputs: Dict[str, Any] = {}
self.a_outputs: dict[str, Any] = {}
# Stores outputs of run_b function
self.b_outputs: Dict[str, Any] = {}
self.b_outputs: dict[str, Any] = {}
# Stores the results of compare_fn
self.results: Dict[Any, Any] = {}
self.results: dict[Any, Any] = {}
# Stores the report for the runs
self.reports: List[List[str]] = []
self.reports: list[list[str]] = []
# Current iteration
self.iteration: int = 0
@ -205,7 +205,7 @@ class _MinimizerBase:
def _get_submod_inputs(
self, main_module: torch.fx.GraphModule, submod_path: str
) -> Tuple[Tensors, Tensors]:
) -> tuple[Tensors, Tensors]:
"""
Try to get submodule inputs from stored outputs. If not found, then use
torch_glow.get_submod_inputs to get the inputs.
@ -280,7 +280,7 @@ class _MinimizerBase:
else:
node.tag = "main_0"
def _build_submodule(self, nodes: NodeSet) -> Tuple[torch.fx.GraphModule, str]:
def _build_submodule(self, nodes: NodeSet) -> tuple[torch.fx.GraphModule, str]:
"""
Split self.module so that one submodule consists of `nodes` and only `nodes`.
@ -412,7 +412,7 @@ class _MinimizerBase:
culprits: NodeSet = set()
nodes: NodeList = all_nodes[start_idx:end_idx]
report: List[str] = []
report: list[str] = []
if self.exclusion_fn is not None:
self.exclusion_fn(nodes, start_idx, end_idx)
if len(nodes) == 0:
@ -484,7 +484,7 @@ class _MinimizerBase:
culprits: NodeSet = set()
for node in nodes:
report: List[str] = []
report: list[str] = []
self.reports.append(report)
self.iteration += 1
report.append(f"Sequential traverse iteration {self.iteration}.")
@ -534,7 +534,7 @@ class _MinimizerBase:
find_last_node: If True, search for the last node which results in a numerics difference;
if False, find the first node in the sorted node list
"""
report: List[str] = []
report: list[str] = []
mid = (start_idx + end_idx) // 2
cur_nodes_list: NodeList = nodes[: mid + 1] if find_last_node else nodes[mid:]
@ -726,7 +726,7 @@ class _MinimizerBase:
return culprits
for node in nodes:
report: List[str] = []
report: list[str] = []
self.reports.append(report)
self.iteration += 1
report.append(f"Accumulate traverse iteration {self.iteration}.")
@ -770,7 +770,7 @@ class _MinimizerBase:
for node in nodes:
if node in self.fusions:
cur_nodes.update(self.fusions[node])
report: List[str] = []
report: list[str] = []
self.reports.append(report)
self.iteration += 1
report.append(f" Nodes block {self.iteration}.")
@ -797,7 +797,7 @@ class _MinimizerBase:
self.print_report(report)
return set()
def _skip_traverse(self, all_nodes: NodeList, skip_nodes: List) -> NodeSet:
def _skip_traverse(self, all_nodes: NodeList, skip_nodes: list) -> NodeSet:
"""
Skip certain nodes in graph based on settings
"""
@ -874,7 +874,7 @@ class _MinimizerBase:
) as e:
print(e)
def print_report(self, report: List[str]):
def print_report(self, report: list[str]):
for i in range(len(report)):
if i > 0:
print(" . " + report[i])
@ -889,7 +889,7 @@ class _MinimizerBase:
self,
start: Optional[str] = None,
end: Optional[str] = None,
skip_nodes: Optional[List] = None,
skip_nodes: Optional[list] = None,
find_last_node: Optional[bool] = None,
) -> NodeSet:
"""

View File

@ -24,9 +24,9 @@ TargetTypeName = str
# Arguments' dtypes for a given node, see `OperatorSupport`
SupportedArgumentDTypes = t.Optional[
t.Tuple[
tuple[
t.Sequence[t.Sequence[torch.dtype]],
t.Dict[str, t.Sequence[torch.dtype]],
dict[str, t.Sequence[torch.dtype]],
]
]
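A hedged illustration of the shape this alias describes: one dtype sequence per positional argument, plus a dict for keyword arguments (the op name and dtypes are illustrative):

import torch

support_dict = {
    "torch.ops.aten.add.Tensor": (
        (
            [torch.float16, torch.float32],  # dtypes allowed for arg 0
            [torch.float16, torch.float32],  # dtypes allowed for arg 1
        ),
        {},  # no keyword-argument constraints
    ),
}
# OperatorSupport(support_dict) consumes entries like this; a value of
# None means the op is supported for any dtypes.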
@ -204,7 +204,7 @@ class OpSupports:
return create_op_support(_decline_if_input_dtype)
@classmethod
def decline_if_node_in_names(cls, disallow_set: t.Set[str]) -> OperatorSupportBase:
def decline_if_node_in_names(cls, disallow_set: set[str]) -> OperatorSupportBase:
"""
If a node has a name that is in the disallow set, report it as non-supported.
"""

View File

@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Tuple, Type
from typing import Any, Callable
import torch
import torch.nn as nn
@ -23,7 +23,7 @@ def default_matching(name: str, target_version: int) -> str:
# This dict maps the nn.Module class to the attribute name list that we want to fetch for lowering.
# The first integer in the tuple is the version number of the nn.Module class when we create the parameter list.
# If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module.
module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]] = {
module_fetch_book: dict[type, tuple[int, list[str], Callable[[str, int], str]]] = {
torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching),
torch.nn.modules.conv.Conv2d: (
1,
@ -55,11 +55,11 @@ module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]]
@compatibility(is_backward_compatible=False)
def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]:
def extract_attrs_for_lowering(mod: nn.Module) -> dict[str, Any]:
"""If `mod` is in `module_fetch_book`, fetch the mod's attributes that in the `module_fetch_book`
after checking module's version is compatible with the `module_fetch_book`.
"""
attrs_for_lowering: Dict[str, Any] = {}
attrs_for_lowering: dict[str, Any] = {}
attrs_for_lowering["name"] = torch.typename(mod)
if type(mod) in module_fetch_book:

View File

@ -2,7 +2,7 @@
import logging
from functools import wraps
from inspect import unwrap
from typing import Callable, List, Optional
from typing import Callable, Optional
logger = logging.getLogger(__name__)
@ -121,7 +121,7 @@ def loop_pass(
# Implemented as 'depends on' operators. A constraint is satisfied iff a list
# has a valid partial ordering according to this comparison operator.
def _validate_pass_schedule_constraint(
constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
constraint: Callable[[Callable, Callable], bool], passes: list[Callable]
):
for i, a in enumerate(passes):
for j, b in enumerate(passes[i + 1 :]):
@ -191,8 +191,8 @@ class PassManager:
`this_before_that_pass_constraint` for example.
"""
passes: List[Callable]
constraints: List[Callable]
passes: list[Callable]
constraints: list[Callable]
_validated: bool = False
def __init__(
@ -217,7 +217,7 @@ class PassManager:
self.constraints.append(constraint)
self._validated = False
def remove_pass(self, _passes: List[str]):
def remove_pass(self, _passes: list[str]):
if _passes is None:
return
passes_left = [ps for ps in self.passes if ps.__name__ not in _passes]

View File

@ -3,7 +3,6 @@ import _operator
import itertools
from collections import defaultdict
from enum import Enum
from typing import Dict, Set
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@ -199,7 +198,7 @@ _VIEW_INVERSE_MAP = {
# Given a set of (aliased) tensor nodes, this function returns any nodes in the
# graph that *use* any of the aliases and that occur *after* op_index
# in the node ordering.
def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
def _get_all_later_node_usages(tensor_aliases: set[Node], op_index: int):
def _add_if_tensor(x, set_):
if isinstance(x, FakeTensor):
set_.add(StorageWeakRef(x._typed_storage()))
@ -233,8 +232,8 @@ def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
# (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata
# as "alias"
def _get_view_inverse_node_usages(
later_node_usages: Set[Node], self_aliases: Set[Node]
) -> Set[Node]:
later_node_usages: set[Node], self_aliases: set[Node]
) -> set[Node]:
def matching_view_metadata(a, b):
return (
a.size() == b.size()
@ -515,7 +514,7 @@ def reinplace(gm, *sample_args):
}
# We also need to know, for a given node, all of its aliasing nodes.
storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set)
storage_to_nodes: dict[StorageWeakRef, set[Node]] = defaultdict(set)
for n in gm.graph.nodes:
if "fake_result" in n.meta:
# Tree-mapping because some ops can return lists of tensors.

View File

@ -3,7 +3,7 @@ import functools
import logging
import operator
import sys
from typing import Any, Dict, Optional, Set, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING
# Import sympy and ShapeEnv during TYPE_CHECKING since importing sympy is slow
@ -123,7 +123,7 @@ def insert_deferred_runtime_asserts(
)
# We are going to mutate the dict
expr_to_proxy: Dict[sympy.Expr, fx.Proxy] = {}
expr_to_proxy: dict[sympy.Expr, fx.Proxy] = {}
placeholders = set()
first_non_placeholder = None
for node in graph.nodes:
@ -163,7 +163,7 @@ def insert_deferred_runtime_asserts(
def _node_metadata_hook(
node: torch.fx.Node,
stack_trace: Optional[str] = None,
nn_module_stack: Optional[Dict[str, Any]] = None,
nn_module_stack: Optional[dict[str, Any]] = None,
) -> None:
fake_args = pytree.tree_map(
lambda arg: (
@ -189,8 +189,8 @@ def insert_deferred_runtime_asserts(
node.meta["nn_module_stack"] = nn_module_stack
# Track asserts/checks we've added
added_asserts: Set[sympy.Expr] = set()
constrained_unbacked_symbols: Set[sympy.Symbol] = set()
added_asserts: set[sympy.Expr] = set()
constrained_unbacked_symbols: set[sympy.Symbol] = set()
Analysis = PythonReferenceAnalysis if export else OptimizedPythonReferenceAnalysis

View File

@ -1,7 +1,7 @@
# mypy: ignore-errors
import traceback
from typing import Any, Dict, NamedTuple, Optional, Tuple
from typing import Any, NamedTuple, Optional
import torch
import torch.fx
@ -24,12 +24,12 @@ class TensorMetadata(NamedTuple):
shape: torch.Size
dtype: torch.dtype
requires_grad: bool
stride: Tuple[int, ...]
stride: tuple[int, ...]
memory_format: Optional[torch.memory_format]
# Quantization metadata
is_quantized: bool
qparams: Dict[str, Any]
qparams: dict[str, Any]
def _extract_tensor_metadata(
@ -57,7 +57,7 @@ def _extract_tensor_metadata(
break
is_quantized = result.is_quantized
qparams: Dict[str, Any] = {}
qparams: dict[str, Any] = {}
if is_quantized:
qscheme = result.qscheme()
qparams["qscheme"] = qscheme

View File

@ -2,7 +2,7 @@
import inspect
import logging
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Set
from typing import Any, Callable, Optional
import torch
from torch.fx._compatibility import compatibility
@ -20,14 +20,14 @@ class Partition:
def __init__(self, name: str):
self.name: str = name
self.submod_name = f"submod_{name}"
self.node_names: List[str] = []
self.inputs: Dict[str, None] = {}
self.outputs: Dict[str, None] = {}
self.dependencies: Dict[str, None] = {}
self.dependents: Dict[str, None] = {}
self.node_names: list[str] = []
self.inputs: dict[str, None] = {}
self.outputs: dict[str, None] = {}
self.dependencies: dict[str, None] = {}
self.dependents: dict[str, None] = {}
self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
self.environment: Dict[Node, Node] = {}
self.targets: Dict[str, Any] = {}
self.environment: dict[Node, Node] = {}
self.targets: dict[str, Any] = {}
def __repr__(self) -> str:
return (
@ -55,7 +55,7 @@ def split_module(
m: GraphModule,
root_m: torch.nn.Module,
split_callback: Callable[[Node], int],
qualname_map: Optional[Dict[str, str]] = None,
qualname_map: Optional[dict[str, str]] = None,
keep_original_order: Optional[bool] = False,
keep_original_node_name: Optional[bool] = False,
):
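A hedged usage sketch: `split_callback` maps each node to an integer partition id, and each id becomes one `submod_<i>` in the result (the name-based mapping is illustrative):

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module

def f(x):
    x = torch.relu(x)
    return torch.sigmoid(x)

gm = symbolic_trace(f)
part_of = {"relu": 0, "sigmoid": 1}
split = split_module(gm, gm, lambda n: part_of.get(n.name, 0))
print(split.graph)  # calls submod_0, then submod_1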
@ -161,8 +161,8 @@ def split_module(
def construct_graph(
node: Node,
base_mod_env: Dict[str, Node],
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule],
base_mod_env: dict[str, Node],
base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule],
):
if node.op == "placeholder":
default_value = (
@ -195,9 +195,9 @@ def split_module(
import sympy
partitions: Dict[str, Partition] = {}
orig_nodes: Dict[str, Node] = {}
symbol_to_node: Dict[sympy.Symbol, Node] = {}
partitions: dict[str, Partition] = {}
orig_nodes: dict[str, Node] = {}
symbol_to_node: dict[sympy.Symbol, Node] = {}
def record_cross_partition_use(def_node: Node, use_node: Optional[Node]):
from torch.fx.experimental.symbolic_shapes import free_symbols
@ -273,7 +273,7 @@ def split_module(
# ------------------------
# 1. first region: we do nothing
# 2. subsequent regions: we insert the set_grad at the beginning
grad_regions: OrderedDict[Node, Set[int]] = OrderedDict()
grad_regions: OrderedDict[Node, set[int]] = OrderedDict()
# For autocast regions:
# ------------------------
@ -282,8 +282,8 @@ def split_module(
# _enter at the beginning and _exit at the end
# 3. last region: we will only insert _enter at the beginning
# We will do so in the order in which the autocasts were instantiated.
autocast_regions: OrderedDict[Node, Set[int]] = OrderedDict()
autocast_exits: Dict[Node, Optional[Node]] = {}
autocast_regions: OrderedDict[Node, set[int]] = OrderedDict()
autocast_exits: dict[Node, Optional[Node]] = {}
active_grad = None
active_autocasts = set()
@ -379,13 +379,13 @@ def split_module(
original_partition_order = list(partitions.keys())
# find partitions with no dependencies
root_partitions: List[str] = []
root_partitions: list[str] = []
for partition_name, partition in partitions.items():
if not len(partition.dependencies):
root_partitions.append(partition_name)
# check partitions for circular dependencies and create topological partition ordering
sorted_partitions: List[str] = []
sorted_partitions: list[str] = []
while root_partitions:
root_partition = root_partitions.pop()
sorted_partitions.append(root_partition)
@ -418,7 +418,7 @@ def split_module(
# add placeholders to partition inputs
for partition_name in sorted_partitions:
partition = partitions[partition_name]
new_inputs: Dict[str, None] = {}
new_inputs: dict[str, None] = {}
for inp in partition.inputs:
orig_node = orig_nodes[inp]
# We don't pass in get_attr nodes as inputs to the partition, but
@ -507,11 +507,11 @@ def split_module(
) # is it really a good idea to copy this?
# original module environment dict mapping node names to nodes
orig_mod_env: Dict[str, Node] = {}
orig_mod_env: dict[str, Node] = {}
# Set up values to construct base module
base_mod_env: Dict[str, Node] = {}
base_mod_env: dict[str, Node] = {}
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule] = {}
if not keep_original_order:
for node in m.graph.nodes:
base_mod_env, base_mod_attrs = construct_graph(
@ -559,7 +559,7 @@ def split_module(
if keep_original_order:
# first get the attr nodes required by this partition
orig_mod_attr_nodes: List[Node] = [
orig_mod_attr_nodes: list[Node] = [
orig_mod_env[key]
for key in partition.inputs
if key not in original_order

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import copy
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Type, Union
from typing import Optional, Union
import torch.fx
from torch.fx._compatibility import compatibility
@ -45,28 +45,28 @@ class Component:
name: str
# Stores the placeholder nodes in `graph`.
input_placeholders: List = field(default_factory=list)
input_placeholders: list = field(default_factory=list)
# Store the nodes in original graph that are placeholder in `graph`.
orig_inputs: List = field(default_factory=list)
orig_inputs: list = field(default_factory=list)
# Store the nodes in original graph that are outputs in `graph`.
orig_outputs: List = field(default_factory=list)
orig_outputs: list = field(default_factory=list)
# Mapping from get_attr node in original graph to get_attr node in `graph`.
getattr_maps: Dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict)
constructor_args: List[str] = field(default_factory=list)
getattr_maps: dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict)
constructor_args: list[str] = field(default_factory=list)
gm: Optional[torch.fx.GraphModule] = None
@compatibility(is_backward_compatible=False)
def split_by_tags(
gm: torch.fx.GraphModule,
tags: List[str],
tags: list[str],
return_fqn_mapping: bool = False,
return_tuple: bool = False,
GraphModuleCls: Type[torch.fx.GraphModule] = torch.fx.GraphModule,
) -> Union[torch.fx.GraphModule, Tuple[torch.fx.GraphModule, Dict[str, str]]]:
GraphModuleCls: type[torch.fx.GraphModule] = torch.fx.GraphModule,
) -> Union[torch.fx.GraphModule, tuple[torch.fx.GraphModule, dict[str, str]]]:
"""
Splits a GraphModule using tags on its graph nodes. We honor the order of
tags. For example, if we have tags = ["a", "b", "c"], the function will create
@ -133,26 +133,26 @@ def split_by_tags(
return r
# Mapping from node in original module to node in created submodule.
node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
node_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
# Mapping from node in original module or created submodules to
# corresponding component.
node_to_component: Dict[torch.fx.Node, Component] = {}
node_to_component: dict[torch.fx.Node, Component] = {}
# Mapping from tag to the corresponding component.
tag_to_component: Dict[str, Component] = {}
tag_to_component: dict[str, Component] = {}
# Stores all components.
all_components: List[Component] = []
all_components: list[Component] = []
# Stores nodes that will be used in main graph.
used_in_main: Dict[torch.fx.Node, None] = {}
used_in_main: dict[torch.fx.Node, None] = {}
# Main graph after split.
main_g = torch.fx.Graph()
# Mapping from node in original module to node in main graph after split.
main_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
main_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
# Output node of original module.
output_node: Optional[torch.fx.Node] = None
@ -258,7 +258,7 @@ def split_by_tags(
node_to_component[n].orig_outputs.append(n)
# Now we create a graphmodule for each component.
orig_to_split_fqn_mapping: Dict[str, str] = {}
orig_to_split_fqn_mapping: dict[str, str] = {}
for comp in all_components:
outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))
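A hedged end-to-end sketch of `split_by_tags`: tag every non-placeholder node, then split; each tag becomes one submodule, in tag-list order (the tag names and traced function are illustrative):

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_utils import split_by_tags

def f(x):
    a = torch.relu(x)
    return torch.sigmoid(a)

gm = symbolic_trace(f)
for node in gm.graph.nodes:
    if node.op == "call_function":
        node.tag = "red" if node.target is torch.relu else "green"
split = split_by_tags(gm, ["red", "green"])
print(split.graph)  # calls the "red" submodule, then the "green" one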

View File

@ -3,8 +3,9 @@ import argparse
import copy
import logging
from collections import defaultdict
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Sequence, Tuple
from typing import Any, NamedTuple, Optional
import torch
from torch.fx._compatibility import compatibility
@ -225,7 +226,7 @@ class SplitResult(NamedTuple):
"""
split_module: torch.fx.GraphModule
submodule_inputs: Dict[str, Any]
submodule_inputs: dict[str, Any]
non_acc_submodule_prefix: str
@ -235,7 +236,7 @@ def generate_inputs_for_submodules(
inputs: Sequence[Any],
target_submodules: Iterable[str],
deepcopy: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Generate inputs for the targeted submodules in the given model. Note that if two submodules refer to the same obj, this
function doesn't work.
@ -365,16 +366,16 @@ class _SplitterBase:
self.update_deps_for_fusions()
self.non_acc_submodule_name = non_acc_submodule_name
self._node_submodule_map: Dict[str, str] = {}
self._node_submodule_map: dict[str, str] = {}
self._return_tuple = return_tuple
self.tags: List[str] = []
self.tags: list[str] = []
# ===============================================================
# Helpers for ctor and initial state
# ===============================================================
def get_node_submodule_map(self) -> Dict[str, str]:
def get_node_submodule_map(self) -> dict[str, str]:
"""Returns a map from node name to submodule name, e.g.
node: main_module_impl_impl_over_arch_unary_multiple_embedding
_pooling_embedding_pooling_sparse_entity_equivalence_key
@ -383,7 +384,7 @@ class _SplitterBase:
"""
return self._node_submodule_map
def find_deps(self) -> Dict[torch.fx.Node, NodeSet]:
def find_deps(self) -> dict[torch.fx.Node, NodeSet]:
"""
Builds a graph of node dependencies. Leaf nodes don't have any
dependencies and the "output" node doesn't have nodes depending on it.
@ -391,7 +392,7 @@ class _SplitterBase:
Resulting graph has only direct dependencies, i.e. there are no
transitive dependencies.
"""
deps: Dict[torch.fx.Node, NodeSet] = defaultdict(set)
deps: dict[torch.fx.Node, NodeSet] = defaultdict(set)
for node in self.module.graph.nodes:
if node.op not in CALLABLE_NODE_OPS:
continue
@ -647,12 +648,12 @@ class _SplitterBase:
def find_reverse_deps(
self, tag_id: Optional[int] = None
) -> Dict[torch.fx.Node, NodeSet]:
) -> dict[torch.fx.Node, NodeSet]:
"""
Builds reversed topological node dependencies. If tag_id is specified,
we ignore nodes that are in a later subgraph, i.e. nodes with a greater tag_id.
"""
result: Dict[torch.fx.Node, NodeSet] = defaultdict(set)
result: dict[torch.fx.Node, NodeSet] = defaultdict(set)
for node in self.module.graph.nodes:
if node.op not in CALLABLE_NODE_OPS:
@ -667,7 +668,7 @@ class _SplitterBase:
return result
def update_reverse_deps_for_fusions(self, deps: Dict[torch.fx.Node, NodeSet]):
def update_reverse_deps_for_fusions(self, deps: dict[torch.fx.Node, NodeSet]):
processed_node = set()
for node, fusion in self.fusions.items():
@ -757,7 +758,7 @@ class _SplitterBase:
# Helpers for split() method
# ===============================================================
def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
def starter_nodes(self) -> tuple[NodeSet, NodeSet]:
"""
Finds nodes that consume module inputs or get_attr nodes.
"""
@ -773,7 +774,7 @@ class _SplitterBase:
starter_cpu_nodes.add(user)
return starter_cpu_nodes, starter_acc_nodes
def put_nodes_into_subgraphs(self) -> List[Subgraph]:
def put_nodes_into_subgraphs(self) -> list[Subgraph]:
# We start graph traversal from leaf nodes
current_cpu_nodes, current_acc_nodes = self.starter_nodes()
visited_nodes: NodeSet = set()
@ -785,7 +786,7 @@ class _SplitterBase:
current_subgraph_nodes: NodeList = []
# Result accumulator
subgraphs: List[Subgraph] = []
subgraphs: list[Subgraph] = []
while current_cpu_nodes or current_acc_nodes:
# Find the first node that should belong to the current subgraph and has all dependencies resolved
current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes
@ -839,12 +840,12 @@ class _SplitterBase:
return subgraphs
def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
def remove_small_acc_subgraphs(self, subgraphs: list[Subgraph]) -> list[Subgraph]:
"""
This pass finds ACC submodules with fewer nodes than the specified minimum
size and merges them with adjacent CPU submodules.
"""
result: List[Subgraph] = []
result: list[Subgraph] = []
for subgraph in subgraphs:
if subgraph.is_acc:
if len(subgraph.nodes) >= self.settings.min_acc_module_size:
@ -866,7 +867,7 @@ class _SplitterBase:
result.append(subgraph)
return result
def tag(self, subgraphs: List[Subgraph]):
def tag(self, subgraphs: list[Subgraph]):
self.tags = []
for subgraph in subgraphs:
tag = (
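
# Taken together, the split flow of a _SplitterBase subclass is roughly the
# following (a hedged sketch of the sequencing, not the verbatim
# implementation; `splitter` is assumed to be a concrete subclass instance):
subgraphs = splitter.put_nodes_into_subgraphs()             # CPU/ACC walk
subgraphs = splitter.remove_small_acc_subgraphs(subgraphs)  # merge tiny ACC
splitter.tag(subgraphs)                                     # stamp node tags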

View File

@ -1,8 +1,9 @@
# mypy: allow-untyped-defs
import collections
import operator
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union
from typing import Any, Optional, Union
import torch
import torch.fx
@ -18,11 +19,11 @@ __all__ = [
"legalize_graph",
]
Tensors = Union[Tuple[torch.Tensor], List[torch.Tensor]]
Tensors = Union[tuple[torch.Tensor], list[torch.Tensor]]
TensorOrTensors = Union[torch.Tensor, Tensors]
NodeList = List[torch.fx.Node]
NodeSet = Set[torch.fx.Node]
Names = List[str]
NodeList = list[torch.fx.Node]
NodeSet = set[torch.fx.Node]
Names = list[str]
CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"}
@ -172,8 +173,8 @@ class FxNetAccFusionsFinder:
return False
def __call__(self) -> Dict[torch.fx.Node, NodeSet]:
result: Dict[torch.fx.Node, NodeSet] = {}
def __call__(self) -> dict[torch.fx.Node, NodeSet]:
result: dict[torch.fx.Node, NodeSet] = {}
acc_nodes = list(self.acc_nodes)
for node in acc_nodes:
@ -294,7 +295,7 @@ def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
for node in gm.graph.nodes:
if indeg[node] == 0:
queue.append(node)
env: Dict[torch.fx.Node, torch.fx.Node] = {}
env: dict[torch.fx.Node, torch.fx.Node] = {}
# Pop nodes from the queue, and add nodes that have had all their
# dependencies fulfilled
while len(queue) > 0:
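
# A hedged usage sketch: after a pass inserts nodes out of topological
# order, legalize_graph re-sorts gm.graph via the Kahn's-algorithm queue
# above and returns the same GraphModule (`gm` is assumed to exist).
from torch.fx.passes.tools_common import legalize_graph

gm = legalize_graph(gm)
gm.recompile()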

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
from typing import Dict, Tuple
from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph
@ -30,7 +29,7 @@ def lift_subgraph_as_module(
subgraph: Graph,
comp_name: str = "",
class_name: str = "GraphModule",
) -> Tuple[GraphModule, Dict[str, str]]:
) -> tuple[GraphModule, dict[str, str]]:
"""
Create a GraphModule for ``subgraph``, copying the necessary attributes from the original parent graph_module.
@ -52,7 +51,7 @@ def lift_subgraph_as_module(
# make "weight" a attribute of "conv" HolderModule and point to conv.weight in
# the original module.
submodule = HolderModule({})
orig_to_split_fqn_mapping: Dict[str, str] = {}
orig_to_split_fqn_mapping: dict[str, str] = {}
for n in subgraph.nodes:
if n.op not in ("call_module", "get_attr"):
continue
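
# A hedged usage sketch: lift a Graph carved out of `gm` into its own
# GraphModule; the returned mapping records where attributes moved.
submod, orig_to_split = lift_subgraph_as_module(gm, subgraph, class_name="Sub")
# orig_to_split: original fully-qualified attr name -> name inside `submod`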

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import copy
from queue import SimpleQueue
from typing import Dict, List, Optional as _Optional, Tuple
from typing import Optional as _Optional
import torch.fx
from torch.fx._compatibility import compatibility
@ -97,10 +97,10 @@ def fuse_as_graphmodule(
gm: GraphModule,
nodes: NodeList,
module_name: str,
partition_lookup_table: _Optional[Dict[Node, None]] = None,
partition_lookup_table: _Optional[dict[Node, None]] = None,
*,
always_return_tuple: bool = False,
) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]:
) -> tuple[GraphModule, tuple[Node, ...], tuple[Node, ...]]:
"""
Fuse nodes in graph_module into a GraphModule.
@ -144,10 +144,10 @@ def fuse_as_graphmodule(
subgraph = Graph()
node_to_placeholder: Dict[
node_to_placeholder: dict[
Node, Node
] = {} # mapping of nodes from old graph to placeholder in new graph
node_map: Dict[Node, Node] = {} # mapping of nodes from old graph to new graph
node_map: dict[Node, Node] = {} # mapping of nodes from old graph to new graph
# handles inputs through graph.node_copy's arg_transform functions
def remap_inputs(x):
@ -176,7 +176,7 @@ def fuse_as_graphmodule(
node_map[node] = new_node
# handles outputs
output_mapping: Dict[Node, Node] = {} # mapping from old output to new outputs
output_mapping: dict[Node, Node] = {} # mapping from old output to new outputs
for node in nodes:
for user_node in node.users:
@ -202,10 +202,10 @@ def fuse_as_graphmodule(
)
# sub_gm's input nodes in the original module
original_inputs: Tuple[Node, ...] = tuple(node_to_placeholder.keys())
original_inputs: tuple[Node, ...] = tuple(node_to_placeholder.keys())
# sub_gm's outputs node in the original module
original_outputs: Tuple[Node, ...] = tuple(output_mapping.keys())
original_outputs: tuple[Node, ...] = tuple(output_mapping.keys())
return fused_gm, original_inputs, original_outputs
@ -214,8 +214,8 @@ def fuse_as_graphmodule(
def insert_subgm(
gm: GraphModule,
sub_gm: GraphModule,
orig_inputs: Tuple[Node, ...],
orig_outputs: Tuple[Node, ...],
orig_inputs: tuple[Node, ...],
orig_outputs: tuple[Node, ...],
):
# add sub_gm into gm
submodule_name = sub_gm.__class__.__name__
@ -250,7 +250,7 @@ def erase_nodes(gm: GraphModule, nodes: NodeList):
@compatibility(is_backward_compatible=False)
def fuse_by_partitions(
gm: GraphModule,
partitions: List[Dict[Node, None]],
partitions: list[dict[Node, None]],
prefix: str = "fused_",
always_return_tuple: bool = False,
) -> GraphModule:
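# A hedged usage sketch: each partition is a dict[Node, None] used as an
# ordered set, matching the signature above (the node picks are illustrative).
nodes = list(gm.graph.nodes)
partition = {nodes[1]: None, nodes[2]: None}
fused_gm = fuse_by_partitions(gm, [partition], prefix="fused_")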

View File

@ -4,7 +4,7 @@ import logging
import os
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Set, Tuple, Union
from typing import Any, Union
import torch
from torch.fx import Graph, Node
@ -37,19 +37,19 @@ logger = _init_logger()
@dataclass
class InternalMatch:
# Nodes from which the match was found
anchors: List[Node]
anchors: list[Node]
# Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node] = field(default_factory=dict)
nodes_map: dict[Node, Node] = field(default_factory=dict)
# nodes in target graph that are matched placeholder in pattern
placeholder_nodes: List[Node] = field(default_factory=list)
placeholder_nodes: list[Node] = field(default_factory=list)
# nodes in matched subgraph returned by output
returning_nodes: List[Node] = field(default_factory=list)
returning_nodes: list[Node] = field(default_factory=list)
# map from a string name to a node in the target graph
# only available if the matcher is `SubgraphMatcherWithNameNodesMap`
name_node_map: Dict[str, Node] = field(default_factory=dict)
name_node_map: dict[str, Node] = field(default_factory=dict)
def __copy__(self):
return InternalMatch(
@ -107,9 +107,9 @@ class SubgraphMatcher:
]
output_node = next(iter(reversed(pattern.nodes)))
# nodes returned by outputs
self.pattern_returning_nodes: List[Node] = output_node.all_input_nodes
self.pattern_returning_nodes: list[Node] = output_node.all_input_nodes
self.pattern_anchors: List[Node] = []
self.pattern_anchors: list[Node] = []
if match_output:
self.pattern_anchors = [output_node]
else:
@ -150,12 +150,12 @@ class SubgraphMatcher:
return pn.target == gn.target
return False
def _is_contained(self, nodes_map: Dict[Node, Node]) -> bool:
def _is_contained(self, nodes_map: dict[Node, Node]) -> bool:
# `lookup` represents all the nodes in `original_graph`
# that are part of `pattern`
# Placeholders can be used by other nodes in the graphs
lookup: Dict[Node, Node] = {
lookup: dict[Node, Node] = {
gn: pn for pn, gn in nodes_map.items() if pn.op != "placeholder"
}
@ -172,10 +172,10 @@ class SubgraphMatcher:
return True
def _remove_overlapping_matches(
self, matches: List[InternalMatch]
) -> List[InternalMatch]:
non_overlapping_matches: List[InternalMatch] = []
nodes_matched: Set[Node] = set()
self, matches: list[InternalMatch]
) -> list[InternalMatch]:
non_overlapping_matches: list[InternalMatch] = []
nodes_matched: set[Node] = set()
for match in matches:
found_overlap = False
@ -244,7 +244,7 @@ class SubgraphMatcher:
# match for `gn`
match_found = True
def _match_args(args1: Union[List, Tuple], args2: Union[List, Tuple]) -> bool:
def _match_args(args1: Union[list, tuple], args2: Union[list, tuple]) -> bool:
if len(args1) != len(args2):
return False
@ -313,7 +313,7 @@ class SubgraphMatcher:
return True
def match(self, graph: Graph) -> List[InternalMatch]:
def match(self, graph: Graph) -> list[InternalMatch]:
"""
Returns:
The matched subgraphs.
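# A hedged usage sketch: trace a small pattern and scan a larger graph for
# it (the target `gm` is assumed to be an existing GraphModule).
pattern_gm = torch.fx.symbolic_trace(lambda x: torch.relu(x + 1))
matcher = SubgraphMatcher(pattern_gm.graph)
for m in matcher.match(gm.graph):
    print(m.anchors, m.nodes_map)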
@ -352,7 +352,7 @@ class SubgraphMatcher:
from torch.fx.passes.utils.fuser_utils import validate_partition
# find candidate nodes to match with pattern anchors
match_candidates: Dict[Node, List[Node]] = defaultdict(list)
match_candidates: dict[Node, list[Node]] = defaultdict(list)
for pattern_anchor in self.pattern_anchors:
for node in graph.nodes:
if self._nodes_are_equal(pattern_anchor, node):
@ -361,7 +361,7 @@ class SubgraphMatcher:
logger.info("Initial match_candidates_list: %s\n", match_candidates_list)
matches: List[InternalMatch] = []
matches: list[InternalMatch] = []
def backtracking(anchor_index, match):
if anchor_index == len(match_candidates_list):

View File

@ -1,5 +1,3 @@
from typing import Dict, List, Tuple
from torch.fx import Graph, GraphModule, Node
from torch.fx._compatibility import compatibility
@ -11,7 +9,7 @@ __all__ = ["SubgraphMatcherWithNameNodeMap"]
def _split_to_graph_and_name_node_map(
gm: GraphModule,
) -> Tuple[GraphModule, Dict[str, Node]]:
) -> tuple[GraphModule, dict[str, Node]]:
from torch.fx.graph import _PyTreeInfo
from torch.utils._pytree import tree_flatten, tree_unflatten
@ -29,7 +27,7 @@ def _split_to_graph_and_name_node_map(
*out, name_node_map = output
flattened, out_spec = tree_flatten(out)
assert isinstance(
name_node_map, Dict
name_node_map, dict
), "Expecting the input graph to have a dict output as the last element"
n.args = (flattened,)
orig_pytree_info = gm._graph._codegen.pytree_info # type: ignore[attr-defined]
@ -88,7 +86,7 @@ class SubgraphMatcherWithNameNodeMap(SubgraphMatcher):
ignore_literals,
)
def match(self, graph: Graph) -> List[InternalMatch]:
def match(self, graph: Graph) -> list[InternalMatch]:
"""The returned InternalMatch will have name_node_map populated with a map
from node name (str) to the target node, e.g.
{"conv": target_conv_ndoe, "relu": target_relu_node}

View File

@ -1,7 +1,7 @@
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Type
from typing import Any, Callable, Optional
from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph
@ -34,29 +34,29 @@ logger = _init_logger()
@dataclass
class SourcePartition:
# Nodes in a particular partition
nodes: List[Node]
nodes: list[Node]
# The source these nodes decomposed from
source: Any
# Nodes in the graph that are needed as inputs to the partition
# These do not include the params of the partition
input_nodes: List[Node] = field(default_factory=list)
input_nodes: list[Node] = field(default_factory=list)
# Nodes in the partition that are being used by nodes outside of the
# partition
output_nodes: List[Node] = field(default_factory=list)
output_nodes: list[Node] = field(default_factory=list)
# Parameters that are being used
params: List[Node] = field(default_factory=list)
params: list[Node] = field(default_factory=list)
@compatibility(is_backward_compatible=False) # type: ignore[misc]
def get_source_partitions(
graph: Graph,
wanted_sources: List[Any],
wanted_sources: list[Any],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Dict[Any, List[SourcePartition]]:
) -> dict[Any, list[SourcePartition]]:
"""
Args:
graph: The graph we want to partition
@ -69,7 +69,7 @@ def get_source_partitions(
that correspond to the list of nodes that were decomposed from the given
source.
"""
modules: Dict[Type, Dict[str, List[Node]]] = {}
modules: dict[type, dict[str, list[Node]]] = {}
for node in graph.nodes:
# The metadata source_fn should contain a tuple of a unique name for the
@ -98,7 +98,7 @@ def get_source_partitions(
partition = diff_modules.setdefault(source_fn[0], [])
partition.append(node)
def make_partition(nodes: List[Node], module_type: Type) -> SourcePartition:
def make_partition(nodes: list[Node], module_type: type) -> SourcePartition:
input_nodes = set()
output_nodes = set()
params = set()
@ -124,7 +124,7 @@ def get_source_partitions(
list(params), # type: ignore[arg-type]
)
ret: Dict[Type[Any], List[SourcePartition]] = {}
ret: dict[type[Any], list[SourcePartition]] = {}
if filter_fn:
# for each partition, we apply filter_fn to filter out all partitions that don't satisfy the
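
# A hedged usage sketch: group nodes by the source module they were
# decomposed from (torch.nn.Linear here is just one example source).
partitions = get_source_partitions(gm.graph, [torch.nn.Linear])
for p in partitions.get(torch.nn.Linear, []):
    print(p.nodes, p.input_nodes, p.output_nodes)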

View File

@ -8,8 +8,10 @@ import inspect
import logging
import operator
import sys
from collections import OrderedDict
from collections.abc import Iterator
from dataclasses import fields, is_dataclass
from typing import Any, Callable, Dict, Iterator, Optional, OrderedDict, Tuple
from typing import Any, Callable, Optional
import torch
import torch.fx.traceback as fx_traceback
@ -135,18 +137,18 @@ class TracerBase:
scope: Scope
# Records the module call stack
module_stack: OrderedDict[str, Tuple[str, Any]]
module_stack: OrderedDict[str, tuple[str, Any]]
# Mapping of node name to module scope
node_name_to_scope: Dict[str, Tuple[str, type]]
node_name_to_scope: dict[str, tuple[str, type]]
@compatibility(is_backward_compatible=True)
def create_node(
self,
kind: str,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
args: tuple[Argument, ...],
kwargs: dict[str, Argument],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
) -> Node:
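# The same (kind, target, args, kwargs) shape is exposed on
# torch.fx.Graph.create_node; a hedged sketch of building a tiny graph:
import operator
g = torch.fx.Graph()
x = g.placeholder("x")
add = g.create_node("call_function", operator.add, (x, 1), {}, name="add")
g.output(add)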
@ -171,7 +173,7 @@ class TracerBase:
# Optionally set stack trace on the created Node for debugging purposes
if fx_traceback.has_preserved_node_meta():
current_meta: Dict[str, Any] = fx_traceback.get_current_meta()
current_meta: dict[str, Any] = fx_traceback.get_current_meta()
stack_trace = current_meta.get("stack_trace")
if stack_trace:
@ -211,8 +213,8 @@ class TracerBase:
self,
kind: str,
target: Target,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
# fix noqa when updating bc tests
@ -455,10 +457,10 @@ class Proxy:
# we peephole optimize to the method invocation
return Attribute(self, k)
def __getstate__(self) -> Dict:
def __getstate__(self) -> dict:
return self.__dict__
def __deepcopy__(self, memo) -> Dict:
def __deepcopy__(self, memo) -> dict:
# We have to explicitly override this method, because otherwise deepcopy
# will go to __getattr__(self, "__deepcopy__") and return an
# Attribute(__deepcopy__), and may go into an infinite loop in some cases.
@ -564,7 +566,7 @@ class Proxy:
args = args if args else ()
kwargs = kwargs if kwargs else {}
tracers: Dict[Any, None] = {}
tracers: dict[Any, None] = {}
def find_tracer(a):
if isinstance(a, cls):

View File

@ -1,16 +1,6 @@
import copy
from dataclasses import dataclass
from typing import (
Any,
Callable,
Dict,
List,
NamedTuple,
Optional,
Set,
TYPE_CHECKING,
Union,
)
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union
import torch
@ -37,7 +27,7 @@ class Match(NamedTuple):
# Node from which the match was found
anchor: Node
# Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node]
nodes_map: dict[Node, Node]
@compatibility(is_backward_compatible=False)
@ -46,9 +36,9 @@ class ReplacedPatterns:
# Node from which the match was found
anchor: Node
# Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node]
nodes_map: dict[Node, Node]
# List of nodes that were added into the graph
replacements: List[Node]
replacements: list[Node]
def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
@ -106,7 +96,7 @@ def replace_pattern(
gm: GraphModule,
pattern: Union[Callable, GraphModule],
replacement: Union[Callable, GraphModule],
) -> List[Match]:
) -> list[Match]:
"""
Matches all possible non-overlapping sets of operators and their
data dependencies (``pattern``) in the Graph of a GraphModule
@ -237,14 +227,14 @@ def replace_pattern_with_filters(
pattern: Union[Callable, Graph, GraphModule],
replacement: Union[Callable, Graph, GraphModule, None] = None,
match_filters: Optional[
List[Callable[["InternalMatch", Graph, Graph], bool]]
list[Callable[["InternalMatch", Graph, Graph], bool]]
] = None,
ignore_literals: bool = False,
# Placed at the end to avoid breaking backward compatibility
replacement_callback: Optional[
Callable[["InternalMatch", Graph, Graph], Graph]
] = None,
) -> List[ReplacedPatterns]:
) -> list[ReplacedPatterns]:
"""
See replace_pattern for documentation. This function is an overload with an additional match_filter argument.
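# A hedged sketch mirroring the replace_pattern docs: swap cat+sum for
# stack (traced_gm is assumed to be an already-traced GraphModule).
def pattern(w1, w2):
    return torch.cat([w1, w2]).sum()

def replacement(w1, w2):
    return torch.stack([w1, w2])

matches = replace_pattern(traced_gm, pattern, replacement)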
@ -268,14 +258,14 @@ def _replace_pattern(
pattern: Union[Callable, Graph, GraphModule],
replacement: Union[Callable, Graph, GraphModule, None] = None,
match_filters: Optional[
List[Callable[["InternalMatch", Graph, Graph], bool]]
list[Callable[["InternalMatch", Graph, Graph], bool]]
] = None,
ignore_literals: bool = False,
# Placed at the end to avoid breaking backward compatibility
replacement_callback: Optional[
Callable[["InternalMatch", Graph, Graph], Graph]
] = None,
) -> List[ReplacedPatterns]:
) -> list[ReplacedPatterns]:
from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher
if match_filters is None:
@ -298,7 +288,7 @@ def _replace_pattern(
remove_overlapping_matches=True,
ignore_literals=ignore_literals,
)
_matches: List[InternalMatch] = matcher.match(original_graph)
_matches: list[InternalMatch] = matcher.match(original_graph)
# Filter out matches that don't match the filter
_matches = [
@ -323,7 +313,7 @@ def _replace_pattern(
common_replacement_graph = None
# As we progressively replace nodes, we'll need to keep track of how the match results should change
match_changed_node: Dict[Node, Node] = {}
match_changed_node: dict[Node, Node] = {}
match_and_replacements = []
for match in _matches:
@ -345,7 +335,7 @@ def _replace_pattern(
# Initialize `val_map` with mappings from placeholder nodes in
# `replacement` to their corresponding node in `original_graph`
assert len(match.placeholder_nodes) == len(replacement_placeholders)
val_map: Dict[Node, Node] = {}
val_map: dict[Node, Node] = {}
for rn, gn in zip(replacement_placeholders, match.placeholder_nodes):
if isinstance(gn, Node):
val_map[rn] = match_changed_node.get(gn, gn)
@ -361,7 +351,7 @@ def _replace_pattern(
val_map[rn] = gn
# Copy the replacement graph over
user_nodes: Set[Node] = set()
user_nodes: set[Node] = set()
for n in match.returning_nodes:
user_nodes.update(n.users)
@ -402,7 +392,7 @@ def _replace_pattern(
copied_returning_nodes = (copied_returning_nodes,)
# Get a list of nodes that have been replaced into the graph
replacement_nodes: List[Node] = [
replacement_nodes: list[Node] = [
v for v in val_map.values() if v not in match.placeholder_nodes
]

View File

@ -4,7 +4,7 @@ import json
import traceback
from contextlib import contextmanager
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
from ._compatibility import compatibility
from .graph import Graph
@ -25,7 +25,7 @@ __all__ = [
"get_graph_provenance_json",
]
current_meta: Dict[str, Any] = {}
current_meta: dict[str, Any] = {}
should_preserve_node_meta = False
@ -49,15 +49,15 @@ class NodeSource:
self.graph_id = graph_id
pass_name: str
action: List["NodeSourceAction"]
from_node: List["NodeSource"]
action: list["NodeSourceAction"]
from_node: list["NodeSource"]
node_info: Optional["NodeInfo"]
def __init__(
self,
node: Optional[Node],
pass_name: str = "",
action: Optional[Union["NodeSourceAction", List["NodeSourceAction"]]] = None,
action: Optional[Union["NodeSourceAction", list["NodeSourceAction"]]] = None,
):
self.pass_name = pass_name
@ -146,7 +146,7 @@ def preserve_node_meta(enable=True):
@compatibility(is_backward_compatible=False)
def set_stack_trace(stack: List[str]):
def set_stack_trace(stack: list[str]):
global current_meta
if should_preserve_node_meta and stack:
@ -182,7 +182,7 @@ def reset_grad_fn_seq_nr():
@compatibility(is_backward_compatible=False)
def format_stack() -> List[str]:
def format_stack() -> list[str]:
if should_preserve_node_meta:
return [current_meta.get("stack_trace", "")]
else:
@ -219,7 +219,7 @@ def set_current_meta(node, pass_name=""):
@compatibility(is_backward_compatible=False)
def get_current_meta() -> Dict[str, Any]:
def get_current_meta() -> dict[str, Any]:
return current_meta
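
# A hedged usage sketch: run a transformation under preserve_node_meta() so
# per-node provenance survives, then read the metadata dict declared above
# (run_my_pass and gm are hypothetical stand-ins).
import torch.fx.traceback as fx_traceback

with fx_traceback.preserve_node_meta():
    run_my_pass(gm)
    meta = fx_traceback.get_current_meta()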