diff --git a/test/test_fx.py b/test/test_fx.py index 239db0a597ec..b76dca02e166 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -3,6 +3,7 @@ import builtins import contextlib +import collections import copy import functools import inspect @@ -2767,7 +2768,7 @@ class TestFX(JitTestCase): return self.other(x) traced = symbolic_trace(ReturnTypeModule()) - self.assertIn("-> typing_List[str]", traced._code) + self.assertIn("-> list[str]", traced._code) scripted = torch.jit.script(traced) self.assertIn("-> List[str]", scripted.code) @@ -3566,8 +3567,8 @@ class TestFX(JitTestCase): traced(x, y) - FileCheck().check("_Tuple[()]") \ - .check("typing_Tuple[str,typing_Tuple[()]]") \ + FileCheck().check("tuple[()]") \ + .check("tuple[str,tuple[()]]") \ .run(traced.code) scripted = torch.jit.script(traced) @@ -4063,45 +4064,62 @@ class TestFXAPIBackwardCompatibility(JitTestCase): return f'{fn_name}({", ".join(arg_strs)}){return_annot}' - def _annotation_type_to_stable_str(self, t, sig_str): + _trivial_mappings = { + str : 'str', + int : 'int', + float: 'float', + bool: 'bool', + torch.dtype: 'torch.dtype', + torch.Tensor: 'torch.Tensor', + torch.device: 'torch.device', + torch.memory_format: 'torch.memory_format', + slice: 'slice', + torch.nn.Module: 'torch.nn.modules.module.Module', + torch.fx.Graph : 'torch.fx.graph.Graph', + torch.fx.Node : 'torch.fx.node.Node', + torch.fx.Proxy : 'torch.fx.proxy.Proxy', + torch.fx.node.Target : 'torch.fx.node.Target', + torch.fx.node.Argument : 'torch.fx.node.Argument', + torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode', + torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule', + torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match', + Ellipsis : '...', + typing.Any: 'Any', + type(None): 'NoneType', + None: 'None', + typing.Iterator: 'Iterator', + collections.abc.Iterator: 'Iterator', + } + + _UNBOUND_TYPES = { + dict, + list, + tuple, + type, + typing.Callable, + typing.Dict, + typing.List, + typing.Tuple, + typing.Type, + typing.Union, + } + + def _annotation_type_to_stable_str(self, t, sig_str, recursive: bool = False): if t is inspect.Signature.empty: return '' # Forward ref if isinstance(t, str): - return f"'{t}'" + if recursive: + return t + else: + return f"'{t}'" if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef): return t.__forward_arg__ if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef): return t.__forward_arg__ - trivial_mappings = { - str : 'str', - int : 'int', - float: 'float', - bool: 'bool', - torch.dtype: 'torch.dtype', - torch.Tensor: 'torch.Tensor', - torch.device: 'torch.device', - torch.memory_format: 'torch.memory_format', - slice: 'slice', - torch.nn.Module: 'torch.nn.modules.module.Module', - torch.fx.Graph : 'torch.fx.graph.Graph', - torch.fx.Node : 'torch.fx.node.Node', - torch.fx.Proxy : 'torch.fx.proxy.Proxy', - torch.fx.node.Target : 'torch.fx.node.Target', - torch.fx.node.Argument : 'torch.fx.node.Argument', - torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode', - torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule', - torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match', - Ellipsis : '...', - typing.Any: 'Any', - type(None): 'NoneType', - None: 'None', - typing.Iterator: 'Iterator', - } - - mapping = trivial_mappings.get(t, None) + mapping = self._trivial_mappings.get(t, None) if mapping: return mapping @@ -4115,14 +4133,14 @@ class TestFXAPIBackwardCompatibility(JitTestCase): if all(isinstance(ct, typing.TypeVar) 
for ct in contained): contained = [] - contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained] + contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str, True) for ct in contained] contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else '' origin = getattr(t, '__origin__', None) if origin is None: # Unbound types don't have `__origin__` in some Python versions, so fix that up here. - origin = t if t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable} else origin + origin = t if t in self._UNBOUND_TYPES else origin if origin in {tuple, typing.Tuple}: return f'Tuple{contained_type_str}' @@ -4130,7 +4148,7 @@ class TestFXAPIBackwardCompatibility(JitTestCase): # Annoying hack to detect Optional if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)): not_none_param = contained[0] if contained[0] is not type(None) else contained[1] - return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]' + return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str, True)}]' return f'Union{contained_type_str}' if origin in {dict, typing.Dict}: return f'Dict{contained_type_str}' diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 61c5687bf886..3373a1bb98b6 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -1524,6 +1524,29 @@ class {test_classname}(torch.nn.Module): (int, type(torch.float)), (Union[int, float], int), (Union[int, float], float), + (list[int], int), + (list[int], create_type_hint([int, int])), + (list[int], create_type_hint((int, int))), + (list[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])), + ( + list[torch.Tensor], + create_type_hint([torch.nn.Parameter, torch.nn.Parameter]), + ), + (torch.Tensor, torch.nn.Parameter), + (list[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])), + (list[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])), + (list[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))), + ( + list[torch.Tensor], + create_type_hint((torch.nn.Parameter, torch.nn.Parameter)), + ), + (torch.Tensor, torch.nn.Parameter), + (list[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))), + (list[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))), + (Optional[list[torch.Tensor]], list[torch.Tensor]), + (Optional[list[int]], list[int]), + ] + [ + # pre-PEP585 signatures (List[int], int), (List[int], create_type_hint([int, int])), (List[int], create_type_hint((int, int))), @@ -1532,7 +1555,6 @@ class {test_classname}(torch.nn.Module): List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.nn.Parameter]), ), - (torch.Tensor, torch.nn.Parameter), (List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])), (List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])), (List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))), @@ -1540,18 +1562,21 @@ class {test_classname}(torch.nn.Module): List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.nn.Parameter)), ), - (torch.Tensor, torch.nn.Parameter), (List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))), (List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))), (Optional[List[torch.Tensor]], List[torch.Tensor]), (Optional[List[int]], List[int]), ] + for sig_type, arg_type in should_be_equal: 
self.assertTrue(type_matches(sig_type, arg_type)) should_fail = [ (int, float), (Union[int, float], str), + (list[torch.Tensor], List[int]), + ] + [ + # pre-PEP585 signatures (List[torch.Tensor], List[int]), ] diff --git a/torch/fx/_compatibility.py b/torch/fx/_compatibility.py index 8a2eeb0d2d69..26bb3ff3b772 100644 --- a/torch/fx/_compatibility.py +++ b/torch/fx/_compatibility.py @@ -1,9 +1,9 @@ import textwrap -from typing import Any, Callable, Dict, TypeVar +from typing import Any, Callable, TypeVar -_BACK_COMPAT_OBJECTS: Dict[Any, None] = {} -_MARKED_WITH_COMPATIBILITY: Dict[Any, None] = {} +_BACK_COMPAT_OBJECTS: dict[Any, None] = {} +_MARKED_WITH_COMPATIBILITY: dict[Any, None] = {} _T = TypeVar("_T") diff --git a/torch/fx/_pytree.py b/torch/fx/_pytree.py index 611ad149e290..60349750ca48 100644 --- a/torch/fx/_pytree.py +++ b/torch/fx/_pytree.py @@ -1,20 +1,20 @@ # mypy: allow-untyped-defs from collections import namedtuple -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type +from typing import Any, Callable, NamedTuple, Optional import torch.return_types from torch.utils._pytree import PyTree, TreeSpec -FlattenFuncSpec = Callable[[PyTree, TreeSpec], List] +FlattenFuncSpec = Callable[[PyTree, TreeSpec], list] FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool] -SUPPORTED_NODES: Dict[Type[Any], FlattenFuncSpec] = {} -SUPPORTED_NODES_EXACT_MATCH: Dict[Type[Any], Optional[FlattenFuncExactMatchSpec]] = {} +SUPPORTED_NODES: dict[type[Any], FlattenFuncSpec] = {} +SUPPORTED_NODES_EXACT_MATCH: dict[type[Any], Optional[FlattenFuncExactMatchSpec]] = {} def register_pytree_flatten_spec( - cls: Type[Any], + cls: type[Any], flatten_fn_spec: FlattenFuncSpec, flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None, ) -> None: @@ -23,7 +23,7 @@ def register_pytree_flatten_spec( def _deregister_pytree_flatten_spec( - cls: Type[Any], + cls: type[Any], ) -> None: del SUPPORTED_NODES[cls] del SUPPORTED_NODES_EXACT_MATCH[cls] @@ -33,7 +33,7 @@ def tree_flatten_spec( pytree: PyTree, spec: TreeSpec, exact_structural_match=False, -) -> List[Any]: +) -> list[Any]: if spec.is_leaf(): return [pytree] if spec.type not in SUPPORTED_NODES: @@ -58,31 +58,31 @@ def tree_flatten_spec( return result -def _dict_flatten_spec(d: Dict[Any, Any], spec: TreeSpec) -> List[Any]: +def _dict_flatten_spec(d: dict[Any, Any], spec: TreeSpec) -> list[Any]: return [d[k] for k in spec.context] -def _list_flatten_spec(d: List[Any], spec: TreeSpec) -> List[Any]: +def _list_flatten_spec(d: list[Any], spec: TreeSpec) -> list[Any]: return [d[i] for i in range(spec.num_children)] -def _tuple_flatten_spec(d: Tuple[Any], spec: TreeSpec) -> List[Any]: +def _tuple_flatten_spec(d: tuple[Any], spec: TreeSpec) -> list[Any]: return [d[i] for i in range(spec.num_children)] -def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> List[Any]: +def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> list[Any]: return [d[i] for i in range(spec.num_children)] -def _dict_flatten_spec_exact_match(d: Dict[Any, Any], spec: TreeSpec) -> bool: +def _dict_flatten_spec_exact_match(d: dict[Any, Any], spec: TreeSpec) -> bool: return len(d) == spec.num_children -def _list_flatten_spec_exact_match(d: List[Any], spec: TreeSpec) -> bool: +def _list_flatten_spec_exact_match(d: list[Any], spec: TreeSpec) -> bool: return len(d) == spec.num_children -def _tuple_flatten_spec_exact_match(d: Tuple[Any], spec: TreeSpec) -> bool: +def _tuple_flatten_spec_exact_match(d: tuple[Any], spec: TreeSpec) -> bool: 
return len(d) == spec.num_children diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index 2489ac6189fe..072328cda7f4 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -10,18 +10,7 @@ import os import warnings from itertools import chain from types import CodeType, FunctionType, ModuleType -from typing import ( - Any, - Callable, - Dict, - List, - NamedTuple, - Optional, - Set, - Tuple, - Type, - Union, -) +from typing import Any, Callable, NamedTuple, Optional, Union import torch import torch.utils._pytree as pytree @@ -42,7 +31,7 @@ HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS _orig_module_call: Callable = torch.nn.Module.__call__ _orig_module_getattr: Callable = torch.nn.Module.__getattr__ -_proxyable_classes: Dict[Type, None] = {} +_proxyable_classes: dict[type, None] = {} _is_fx_tracing_flag = False @@ -262,8 +251,8 @@ class Tracer(TracerBase): @compatibility(is_backward_compatible=True) def __init__( self, - autowrap_modules: Tuple[ModuleType] = (math,), - autowrap_functions: Tuple[Callable, ...] = (), + autowrap_modules: tuple[ModuleType] = (math,), + autowrap_functions: tuple[Callable, ...] = (), param_shapes_constant: bool = False, ) -> None: # This method's signature is overridden by the first line of this class' @@ -296,7 +285,7 @@ class Tracer(TracerBase): # Functions we will eagerly wrap when we see them while tracing # this captures both `math.sqrt()` and `from math import sqrt` automatically - self._autowrap_function_ids: Set[int] = { + self._autowrap_function_ids: set[int] = { id(value) for name, value in chain(*[m.__dict__.items() for m in autowrap_modules]) if not name.startswith("_") and callable(value) @@ -305,20 +294,20 @@ class Tracer(TracerBase): # Python modules to apply autowrap to at the start, in addition to # modules we see while tracing - self._autowrap_search: List[ModuleType] = list(autowrap_modules) + self._autowrap_search: list[ModuleType] = list(autowrap_modules) self.param_shapes_constant = param_shapes_constant - self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None + self.submodule_paths: Optional[dict[torch.nn.Module, str]] = None self.root_module_name: str = "" # Maps the containing module's name to the operator name self.scope = Scope("", None) # Records the module call stack self.module_stack = collections.OrderedDict() - self.num_calls: Dict[str, int] = {} + self.num_calls: dict[str, int] = {} # Mapping of node name to module scope - self.node_name_to_scope: Dict[str, Tuple[str, type]] = {} + self.node_name_to_scope: dict[str, tuple[str, type]] = {} - _qualname_counter: Dict[str, int] = collections.defaultdict(int) + _qualname_counter: dict[str, int] = collections.defaultdict(int) @compatibility(is_backward_compatible=True) def get_fresh_qualname(self, prefix: str) -> str: @@ -492,8 +481,8 @@ class Tracer(TracerBase): self, m: torch.nn.Module, forward: Callable[..., Any], - args: Tuple[Any, ...], - kwargs: Dict[str, Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], ) -> Any: """ Method that specifies the behavior of this ``Tracer`` when it encounters @@ -547,7 +536,7 @@ class Tracer(TracerBase): return ret_val @compatibility(is_backward_compatible=False) - def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]): + def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: dict[str, Any]): """ Method that specifies the behavior of this ``Tracer`` when we call getattr on a call to an ``nn.Module`` instance. 
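Review note: the new expected strings in the test_fx.py hunks at the top of this patch ("-> list[str]", "tuple[()]") follow from how PEP 585 aliases print. A minimal torch-free sketch (assuming, as the test changes suggest, that FX renders annotations from their repr, and used to surface typing aliases through a typing_List-style global):

    import typing

    print(repr(typing.List[str]))  # typing.List[str]
    print(repr(list[str]))         # list[str]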
@@ -626,7 +615,7 @@ class Tracer(TracerBase): total_args = co.co_argcount + co.co_kwonlyargcount orig_args = list(co.co_varnames) names_iter = iter(co.co_varnames) - args: List[Any] = [] + args: list[Any] = [] skip_arg_idx = 0 if is_module: if total_args == 0: @@ -712,7 +701,7 @@ class Tracer(TracerBase): def trace( self, root: Union[torch.nn.Module, Callable[..., Any]], - concrete_args: Optional[Dict[str, Any]] = None, + concrete_args: Optional[dict[str, Any]] = None, ) -> Graph: """ Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root`` @@ -763,7 +752,7 @@ class Tracer(TracerBase): self.root = torch.nn.Module() fn = root - tracer_cls: Optional[Type[Tracer]] = getattr(self, "__class__", None) + tracer_cls: Optional[type[Tracer]] = getattr(self, "__class__", None) self.graph = Graph(tracer_cls=tracer_cls) if hasattr(fn, "__code__"): code = fn.__code__ @@ -777,11 +766,11 @@ class Tracer(TracerBase): # is some other attribute on the model. Construct a dict mapping Tensor # values to the qualified name here for efficiency. This is used downstream # in create_arg - self.tensor_attrs: Dict[ + self.tensor_attrs: dict[ Union[torch.Tensor, ScriptObject, FakeScriptObject], str ] = {} - def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): + def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: list[str]): for k, v in m.__dict__.items(): if isinstance(v, (torch.Tensor, ScriptObject, FakeScriptObject)): self.tensor_attrs[v] = ".".join(prefix_atoms + [k]) @@ -797,7 +786,7 @@ class Tracer(TracerBase): fn, isinstance(root, torch.nn.Module), concrete_args ) - parameter_proxy_cache: Dict[ + parameter_proxy_cache: dict[ str, Proxy ] = {} # Reduce number of get_attr calls @@ -872,7 +861,7 @@ class Tracer(TracerBase): nonlocal cnt cnt += 1 param = sig.parameters[name] - default: Tuple[Any, ...] = ( + default: tuple[Any, ...] = ( () if param.default is inspect.Parameter.empty else (param.default,) ) out = self.create_proxy( @@ -913,7 +902,7 @@ class Tracer(TracerBase): return pytree.tree_map(replace_ph, concrete_args[name]) if name[0] == "*": - default: Tuple[Any, ...] = () + default: tuple[Any, ...] = () else: param = sig.parameters[name] default = ( # type: ignore[assignment] @@ -932,11 +921,11 @@ class Tracer(TracerBase): # the purposes of the wrap() API. # We key by the globals dict id and function name to ensure we're wrapping a given # function only once. 
-_wrapped_fns_to_patch: Dict[Tuple[int, str], dict] = {} +_wrapped_fns_to_patch: dict[tuple[int, str], dict] = {} # List of methods on classes to wrap (class type, function name) # this currently only works for Tensor.* methods that aren't traced properly -_wrapped_methods_to_patch: List[Tuple[type, str]] = [] +_wrapped_methods_to_patch: list[tuple[type, str]] = [] if os.environ.get("FX_PATCH_GETITEM") == "1": # This change is needed to trace models like PositionalEmbedding from BERT: @@ -1043,12 +1032,12 @@ class _PatchedFnSetAttr(_PatchedFn): class _Patcher: def __init__(self) -> None: super().__init__() - self.patches_made: List[_PatchedFn] = [] - self.visited: Set[int] = set() + self.patches_made: list[_PatchedFn] = [] + self.visited: set[int] = set() def patch( self, - frame_dict: Dict[str, Any], + frame_dict: dict[str, Any], name: str, new_fn: Callable, deduplicate: bool = True, @@ -1169,7 +1158,7 @@ def _patch_wrapped_functions(patcher: _Patcher): def _autowrap_check( - patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int] + patcher: _Patcher, frame_dict: dict[str, Any], function_ids: set[int] ): """ Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them. @@ -1252,7 +1241,7 @@ def wrap(fn_or_name: Union[str, Callable]): @compatibility(is_backward_compatible=True) def symbolic_trace( root: Union[torch.nn.Module, Callable[..., Any]], - concrete_args: Optional[Dict[str, Any]] = None, + concrete_args: Optional[dict[str, Any]] = None, ) -> GraphModule: """ Symbolic tracing API diff --git a/torch/fx/_utils.py b/torch/fx/_utils.py index fc62453d67c2..25f1c5117173 100644 --- a/torch/fx/_utils.py +++ b/torch/fx/_utils.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import sys -from typing import Dict, Optional +from typing import Optional import torch from torch._logging import LazyString @@ -43,7 +43,7 @@ def _format_graph_code(name, filename, graph_str): return f"TRACED GRAPH\n {name} {filename} {graph_str}\n" -def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[Dict]: +def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[dict]: """ Returns the nn_module_stack of the first call_function node. 
""" diff --git a/torch/fx/experimental/accelerator_partitioner.py b/torch/fx/experimental/accelerator_partitioner.py index d2907ad3d08c..29b8d4541b81 100644 --- a/torch/fx/experimental/accelerator_partitioner.py +++ b/torch/fx/experimental/accelerator_partitioner.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import operator from collections import deque -from typing import Deque, Dict, List, NamedTuple, Set, Tuple +from typing import NamedTuple import torch from torch.fx.experimental.partitioner_utils import ( @@ -28,15 +28,15 @@ class DAGNode: def __init__( self, submodule_node: Node, - input_nodes: List[Node], - output_nodes: List[Node], - logical_device_ids: List[int], + input_nodes: list[Node], + output_nodes: list[Node], + logical_device_ids: list[int], size_bytes: int, ) -> None: self.submodule_node: Node = submodule_node - self.input_nodes: List[Node] = input_nodes - self.output_nodes: List[Node] = output_nodes - self.logical_device_ids: List[int] = logical_device_ids + self.input_nodes: list[Node] = input_nodes + self.output_nodes: list[Node] = output_nodes + self.logical_device_ids: list[int] = logical_device_ids self.size_bytes = size_bytes def __str__(self) -> str: @@ -47,14 +47,14 @@ class DAG: """DAG class contains all the DAG nodes""" def __init__(self) -> None: - self.nodes: List[DAGNode] = [] + self.nodes: list[DAGNode] = [] def create_node( self, submodule_node: Node, - input_nodes: List[Node], - output_nodes: List[Node], - logical_devices: List[int], + input_nodes: list[Node], + output_nodes: list[Node], + logical_devices: list[int], size_bytes: int, ) -> None: node = DAGNode( @@ -79,7 +79,7 @@ def reset_partition_device(partitions): def combine_two_partitions( - partition_0: Partition, partition_1: Partition, partitions: List[Partition] + partition_0: Partition, partition_1: Partition, partitions: list[Partition] ) -> None: """Given a list of partitions and its two partitions, combine these two partitions into a new one appending to the partitions @@ -95,7 +95,7 @@ def combine_two_partitions( return -def set_parents_and_children(partitions: List[Partition]) -> None: +def set_parents_and_children(partitions: list[Partition]) -> None: """Given a list of partitions, mark parents and children for each partition""" # Go through all nodes in a partition. 
# If a node's user is in other partition, @@ -119,7 +119,7 @@ def set_parents_and_children(partitions: List[Partition]) -> None: return -def reorganize_partitions(partitions: List[Partition]) -> None: +def reorganize_partitions(partitions: list[Partition]) -> None: """Given a list of partitions, reorganize partition id, its parents and its children for each partition """ @@ -130,17 +130,17 @@ def reorganize_partitions(partitions: List[Partition]) -> None: return -def get_bfs_level_partition(partitions: List[Partition]) -> None: +def get_bfs_level_partition(partitions: list[Partition]) -> None: """Given a list of partitions, mark the bfs level for each partition """ - current_level: Set[Partition] = set() - visited: Set[Partition] = set() + current_level: set[Partition] = set() + visited: set[Partition] = set() for partition in partitions: # If a partition has no parent, it should be in root level if len(partition.parents) == 0: current_level.add(partition) - next_level: Set[Partition] = set() + next_level: set[Partition] = set() level = 0 # bfs while current_level: @@ -158,26 +158,26 @@ def get_bfs_level_partition(partitions: List[Partition]) -> None: return -def get_node_to_partition_mapping(partitions: List[Partition]) -> Dict[Node, int]: +def get_node_to_partition_mapping(partitions: list[Partition]) -> dict[Node, int]: """Given a list of partitions,return node to partition mapping""" - node_to_partition: Dict[Node, int] = {} + node_to_partition: dict[Node, int] = {} for partition in partitions: for node in partition.nodes: node_to_partition[node] = partition.partition_id return node_to_partition -def get_logical_id_to_device(devices: List[Device]) -> Dict[int, Device]: +def get_logical_id_to_device(devices: list[Device]) -> dict[int, Device]: """Get a mapping from device logical ID to Device object.""" - logical_id_to_device: Dict[int, Device] = {} + logical_id_to_device: dict[int, Device] = {} for d in devices: logical_id_to_device[d.logical_id] = d return logical_id_to_device def get_device_partition_stats( - partitions: List[Partition], devices: List[Device] -) -> Tuple[Dict[Device, List[Partition]], Dict[Device, int], List[Partition]]: + partitions: list[Partition], devices: list[Device] +) -> tuple[dict[Device, list[Partition]], dict[Device, int], list[Partition]]: """Given a list of partitions and a list of devices, returns: 1. A mapping from device to partitions on it; 2. A mapping from device to its remaining memory size; @@ -186,9 +186,9 @@ def get_device_partition_stats( # logical id to device logical_id_to_device = get_logical_id_to_device(devices) # Track partitions on device - device_to_partitions: Dict[Device, List[Partition]] = {} + device_to_partitions: dict[Device, list[Partition]] = {} # Track device's left mem size - device_to_left_mem_bytes: Dict[Device, int] = {} + device_to_left_mem_bytes: dict[Device, int] = {} for d in devices: device_to_partitions[d] = [] device_to_left_mem_bytes[d] = d.available_mem_bytes @@ -213,16 +213,16 @@ def get_device_partition_stats( def get_device_to_partitions_mapping( - partitions: List[Partition], devices: List[Device] + partitions: list[Partition], devices: list[Device] ): """Given a list of partitions and a list of devices, map each partition into a device. 
""" def calculate_extra_mem_bytes_needed_for( - partition: Partition, partitions: List[Partition] + partition: Partition, partitions: list[Partition] ): - all_nodes: Set[Node] = set() + all_nodes: set[Node] = set() for p in partitions: all_nodes = all_nodes.union(p.nodes) if len(all_nodes) == 0: @@ -273,8 +273,8 @@ def check_dependency(partition): """Given a partition,check if there is a circular dependency on this partition using bfs """ - visited: Set[Partition] = {partition} - queue: Deque[Partition] = deque([partition]) + visited: set[Partition] = {partition} + queue: deque[Partition] = deque([partition]) while queue: p = queue.popleft() for child in p.children: @@ -298,9 +298,9 @@ class Partitioner: """ def __init__(self) -> None: - self.partitions: List[Partition] = [] - self.node_to_partition: Dict[Node, int] = {} - self.devices: List[Device] = [] + self.partitions: list[Partition] = [] + self.node_to_partition: dict[Node, int] = {} + self.devices: list[Device] = [] def partition_graph( self, @@ -435,9 +435,9 @@ class Partitioner: return device # Track partition and its left mem size - partition_to_left_mem_bytes: Dict[Partition, int] = {} + partition_to_left_mem_bytes: dict[Partition, int] = {} # Track all the devices that have been used - occupied_devices: List[Device] = [] + occupied_devices: list[Device] = [] partition = self.create_partition() for node in self.graph_module.graph.nodes: if node.op in {"call_module", "call_method", "call_function"}: @@ -516,7 +516,7 @@ class Partitioner: # Devices that hold partitions used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0] # Track replicates of the assigned devices - replicated_device_to_used_device: Dict[Device, Device] = {} + replicated_device_to_used_device: dict[Device, Device] = {} while len(used_devices) * 2 + len(replicated_device_to_used_device) <= len( self.devices @@ -583,7 +583,7 @@ class Partitioner: continue if node.target == operator.__getitem__: continue - input_nodes: Dict[Node, None] = {} + input_nodes: dict[Node, None] = {} map_arg(node.args, input_nodes.setdefault) map_arg(node.kwargs, input_nodes.setdefault) # When a node has two or more output nodes, @@ -634,7 +634,7 @@ class Partitioner: """ def combine_partitions_based_on_size( - partitions: List[Partition], available_mem_bytes: int + partitions: list[Partition], available_mem_bytes: int ) -> None: """Combining small partitions together to keep as less partitions as possible. 
Here is an example of the algorithm to do this: @@ -672,10 +672,10 @@ class Partitioner: return mem_bytes_needed def find_partition_to_combine_based_on_size( - sorted_partitions: List[Partition], + sorted_partitions: list[Partition], available_mem_bytes: int, - partitions: List[Partition], - ) -> Tuple[bool, List[Partition]]: + partitions: list[Partition], + ) -> tuple[bool, list[Partition]]: """step 1 in combine_partition_based_on_size()""" find_combination = False smallest_partition = sorted_partitions.pop(0) @@ -721,8 +721,8 @@ class Partitioner: return False # Track embedding partitions and non-embedding partitions separately - embedding_partitions: List[Partition] = [] - non_embedding_partitions: List[Partition] = [] + embedding_partitions: list[Partition] = [] + non_embedding_partitions: list[Partition] = [] # A Flag to check the boundary in_embedding_region: bool = False partition = self.create_partition() @@ -794,7 +794,7 @@ class Partitioner: def cost_aware_partition( self, transfer_rate_bytes_per_sec: float, - node_to_latency_mapping: Dict[Node, NodeLatency], + node_to_latency_mapping: dict[Node, NodeLatency], ) -> None: """This method is to partition the fx module based on the cost. The cost is the total latency of running the whole fx module. @@ -872,7 +872,7 @@ class Partitioner: ) if len(self.partitions) == 1: return False - partition_pair: List[int] = [] + partition_pair: list[int] = [] for i in range(len(self.partitions) - 1): for j in range(i + 1, len(self.partitions)): # Try to combine the partition pair @@ -915,7 +915,7 @@ class Partitioner: def kl_based_partition( self, transfer_rate_bytes_per_sec: float, - node_to_latency_mapping: Dict[Node, NodeLatency], + node_to_latency_mapping: dict[Node, NodeLatency], ) -> None: """This function is a cost aware partition based on Kernighan-Lin algorithm. 
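Review note: the partitioner hunks above (e.g. the deque[Partition] queue in check_dependency) can drop typing.Deque/Set/List because, since Python 3.9 (PEP 585), the concrete containers are subscriptable at runtime. A small self-contained sketch:

    from collections import deque

    visited: set[int] = {0}          # was typing.Set[int]
    queue: deque[int] = deque([0])   # was typing.Deque[int]
    # Both subscriptions produce the same runtime alias type.
    assert type(deque[int]) is type(list[int])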
@@ -987,7 +987,7 @@ class Partitioner: """ p1_nodes = list(p1.nodes) + [None] min_cost = float("inf") - node_pair: List[Node] = [] + node_pair: list[Node] = [] for n1 in p1_nodes: # Ignore the node if it is not a op node if n1 is not None and n1.op in {"placeholder", "get_attr"}: @@ -1011,9 +1011,9 @@ class Partitioner: self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec ) # Keep tracking the node pair that shows the better cost - node_pair: List[Node] = [] + node_pair: list[Node] = [] # Keep tracking the partition pair of node pair - partition_pair: List[Partition] = [] + partition_pair: list[Partition] = [] # Collect all the op nodes from the graph op_nodes = [ n @@ -1060,7 +1060,7 @@ class Partitioner: """This function helps to rebuild the partitions given the nodes and its corresponding partition id """ - partition_id_to_partition_mapping: Dict[int, Partition] = {} + partition_id_to_partition_mapping: dict[int, Partition] = {} self.node_to_partition = node_to_partition_mapping for node in self.node_to_partition: partition_id = self.node_to_partition[node] diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py index d1ca4acde2b8..483b7e8b2ea2 100644 --- a/torch/fx/experimental/const_fold.py +++ b/torch/fx/experimental/const_fold.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import re -from typing import Callable, Dict, Optional, Set, Union +from typing import Callable, Optional, Union import torch.fx from torch.fx.node import map_arg @@ -100,7 +100,7 @@ def _inline_module(gm: torch.fx.GraphModule, inline_mod_name: str): call_mod_args = call_mod_node_to_replace.args call_mod_kwargs = call_mod_node_to_replace.kwargs - replacement_mapping: Dict[torch.fx.Node, torch.fx.Node] = {} + replacement_mapping: dict[torch.fx.Node, torch.fx.Node] = {} ph_count = 0 def replacement_fn(node): @@ -171,7 +171,7 @@ def split_const_subgraphs( # Build up a list of const_nodes, defined as nodes that are themselves # get_attrs, or have all get_attr or other constant node inputs. - const_nodes: Set[torch.fx.Node] = set() + const_nodes: set[torch.fx.Node] = set() found_const_folding = False for node in mod_traced.graph.nodes: # Skip over placeholders/outputs because they can't be const folded and diff --git a/torch/fx/experimental/debug.py b/torch/fx/experimental/debug.py index 5c290cceab48..b87dee9db9c7 100644 --- a/torch/fx/experimental/debug.py +++ b/torch/fx/experimental/debug.py @@ -1,4 +1,4 @@ -from typing import List, Sequence +from collections.abc import Sequence import torch.fx as fx @@ -19,7 +19,7 @@ def set_trace(gm: fx.GraphModule) -> fx.GraphModule: the `gm` with breakpoint inserted. 
""" - def insert_pdb(body: Sequence[str]) -> List[str]: + def insert_pdb(body: Sequence[str]) -> list[str]: return ["import pdb; pdb.set_trace()\n", *body] with gm.graph.on_generate_code( diff --git a/torch/fx/experimental/graph_gradual_typechecker.py b/torch/fx/experimental/graph_gradual_typechecker.py index 0e71aba53a0e..3b15ae0a6739 100644 --- a/torch/fx/experimental/graph_gradual_typechecker.py +++ b/torch/fx/experimental/graph_gradual_typechecker.py @@ -2,7 +2,7 @@ import itertools import operator from functools import reduce -from typing import Callable, Dict, TypeVar +from typing import Callable, TypeVar from typing_extensions import ParamSpec import sympy @@ -19,9 +19,9 @@ from torch.nn.modules.conv import Conv2d _T = TypeVar("_T") _P = ParamSpec("_P") -_INFERENCE_RULES: Dict[Target, Callable] = {} -_REFINEMENT_RULES: Dict[Target, Callable] = {} -_RULES: Dict[Target, Callable] = {} +_INFERENCE_RULES: dict[Target, Callable] = {} +_REFINEMENT_RULES: dict[Target, Callable] = {} +_RULES: dict[Target, Callable] = {} __all__ = [ "GraphTypeChecker", diff --git a/torch/fx/experimental/merge_matmul.py b/torch/fx/experimental/merge_matmul.py index b3e1efcbd19e..c6a51918f930 100644 --- a/torch/fx/experimental/merge_matmul.py +++ b/torch/fx/experimental/merge_matmul.py @@ -1,7 +1,6 @@ # mypy: allow-untyped-defs import itertools import operator -from typing import Dict, List, Tuple import torch from torch.fx._symbolic_trace import symbolic_trace @@ -10,8 +9,8 @@ from torch.fx.passes.tools_common import legalize_graph def split_result_tensors( - result: torch.Tensor, inputs: List[torch.Tensor] -) -> Tuple[torch.Tensor, ...]: + result: torch.Tensor, inputs: list[torch.Tensor] +) -> tuple[torch.Tensor, ...]: """ A free function for use in the merge_matmul graph transformation below that splits the output from a merged matmul into the individual results for each @@ -71,7 +70,7 @@ def may_depend_on(a: Node, b: Node, search_depth: int = 6): return False -def are_nodes_independent(nodes: List[Node]): +def are_nodes_independent(nodes: list[Node]): """ Check if all of the given nodes are pairwise-data independent. @@ -102,8 +101,8 @@ def merge_matmul(in_mod: torch.nn.Module): """ gm = symbolic_trace(in_mod) - rhs_users: Dict[Node, List[Node]] = {} - lhs_users: Dict[Node, List[Node]] = {} + rhs_users: dict[Node, list[Node]] = {} + lhs_users: dict[Node, list[Node]] = {} # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to # the matmul of which they are the LHS/RHS. 
diff --git a/torch/fx/experimental/meta_tracer.py b/torch/fx/experimental/meta_tracer.py
index 1b74f33f40b5..e2fc033e0b8d 100644
--- a/torch/fx/experimental/meta_tracer.py
+++ b/torch/fx/experimental/meta_tracer.py
@@ -2,7 +2,7 @@ import builtins import functools import warnings
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Callable, Optional, Union
import torch import torch.fx
@@ -40,7 +40,7 @@ def torch_abs_override(input, *, out=None): return input
-manual_meta_overrides: Dict[Callable, Callable] = {
+manual_meta_overrides: dict[Callable, Callable] = {
torch.nn.Embedding: embedding_override, torch.nn.LayerNorm: nn_layernorm_override, torch.relu: torch_relu_override,
@@ -274,7 +274,7 @@ class MetaTracer(torch.fx.Tracer): def proxy(self, node): return MetaProxy(node, self)
- def trace(self, root, meta_args: Dict[str, torch.Tensor], concrete_args=None): # type: ignore[override]
+ def trace(self, root, meta_args: dict[str, torch.Tensor], concrete_args=None): # type: ignore[override]
assert isinstance(meta_args, dict) self.meta_args = meta_args
@@ -299,8 +299,8 @@ class MetaTracer(torch.fx.Tracer): def symbolic_trace( root: Union[torch.nn.Module, Callable[..., Any]],
- meta_args: Optional[Dict[str, torch.Tensor]] = None,
- concrete_args: Optional[Dict[str, Any]] = None,
+ meta_args: Optional[dict[str, torch.Tensor]] = None,
+ concrete_args: Optional[dict[str, Any]] = None,
) -> torch.fx.GraphModule: tracer = MetaTracer() graph = tracer.trace(root, meta_args, concrete_args) # type: ignore[arg-type]
diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
index e8c95ca7231c..03346b800924 100644
--- a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
+++ b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
@@ -1,7 +1,8 @@ # mypy: allow-untyped-defs import operator import warnings
-from typing import Callable, Dict, Iterable, TypeVar
+from collections.abc import Iterable
+from typing import Callable, TypeVar
from typing_extensions import ParamSpec import torch
@@ -57,7 +58,7 @@ from torch.nn.modules.conv import Conv2d _T = TypeVar("_T") _P = ParamSpec("_P")
-_INFERENCE_RULES: Dict[Target, Callable] = {}
+_INFERENCE_RULES: dict[Target, Callable] = {}
MAX_TENSOR_RANK = 4
diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
index 263ac5de560d..11ebff010209 100644
--- a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
+++ b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
@@ -1,7 +1,7 @@ # mypy: ignore-errors import copy import itertools
-from typing import Callable, Dict, List
+from typing import Callable
from torch.fx.experimental.migrate_gradual_types.constraint import ( ApplyBroadcasting,
@@ -50,7 +50,7 @@ from torch.fx.experimental.migrate_gradual_types.util import ( from torch.fx.tensor_type import Dyn, TensorType
-_TRANSFORMATION_RULES: Dict[Constraint, Callable] = {}
+_TRANSFORMATION_RULES: dict[Constraint, Callable] = {}
def register_transformation_rule(call_target):
@@ -797,7 +797,7 @@ def transform_constraint(constraint: Constraint, counter: int): return constraint, counter
-def calc_last_two_dims(constraint, d: List[DVar]):
+def calc_last_two_dims(constraint, d: list[DVar]):
""" Generates constraints for the last two dimensions of a convolution or a maxpool output Args:
@@ -866,7 +866,7 @@ def calc_last_two_dims(constraint, d: List[DVar]): return c4, c5
-def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]):
+def generate_all_int_dyn_dim_possibilities(my_list: list[DVar]):
""" Generate all possibilities of being equal or not equal to dyn for my_list Args:
@@ -888,7 +888,7 @@ def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]): return all_possibilities
-def is_target_div_by_dim(target: List[int], dim: List[DVar]):
+def is_target_div_by_dim(target: list[int], dim: list[DVar]):
""" Generate constraints to check if the target dimensions are divisible by the input dimensions Args:
@@ -901,7 +901,7 @@ def is_target_div_by_dim(target: List[int], dim: List[DVar]): return BinConstraintD(BinConstraintD(Prod(target), dim, op_mod), 0, op_eq)
-def is_dim_div_by_target(target: List[int], dim: List[DVar]):
+def is_dim_div_by_target(target: list[int], dim: list[DVar]):
""" Generate constraints to check if the input dimensions is divisible by the target dimensions Args:
@@ -1000,9 +1000,9 @@ def apply_padding( e11: BinConstraintT, e2: BinConstraintT, e12: BinConstraintT,
- d2: List[DVar],
- d11: List[DVar],
- d12: List[DVar],
+ d2: list[DVar],
+ d11: list[DVar],
+ d12: list[DVar],
counter: int, ): """
@@ -1068,7 +1068,7 @@ def apply_padding( def no_broadcast_dim_with_index(
- d1: List[DVar], d2: List[DVar], d3: List[DVar], d4: List[DVar], i: int
+ d1: list[DVar], d2: list[DVar], d3: list[DVar], d4: list[DVar], i: int
): """ Args:
@@ -1129,10 +1129,10 @@ def create_equality_constraints_for_broadcasting( e2: TVar, e11: TVar, e12: TVar,
- d1: List[DVar],
- d2: List[DVar],
- d11: List[DVar],
- d12: List[DVar],
+ d1: list[DVar],
+ d2: list[DVar],
+ d11: list[DVar],
+ d12: list[DVar],
): """ Create equality constraints for when no broadcasting occurs
@@ -1236,7 +1236,7 @@ def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int): def generate_all_broadcasting_possibilities_no_padding(
- d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar]
+ d1: list[DVar], d2: list[DVar], d11: list[DVar], d12: list[DVar]
): """ Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension.
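Review note: on the parallel type_matches cases added to test_fx_experimental.py earlier in this patch: list[int] and typing.List[int] are distinct objects, but they normalize to the same origin and arguments, which is presumably what lets one matching routine accept both spellings (torch-free sketch):

    import typing

    assert list[int] != typing.List[int]                # different alias objects
    assert typing.get_origin(list[int]) is list
    assert typing.get_origin(typing.List[int]) is list  # same runtime origin
    assert typing.get_args(list[int]) == (int,) == typing.get_args(typing.List[int])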
diff --git a/torch/fx/experimental/normalize.py b/torch/fx/experimental/normalize.py index cc6944d5a5af..73cce6017bf1 100644 --- a/torch/fx/experimental/normalize.py +++ b/torch/fx/experimental/normalize.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import operator -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Optional import torch import torch.fx @@ -38,7 +38,7 @@ class NormalizeArgs(Transformer): self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True ): super().__init__(module) - self.node_map: Dict[Proxy, Node] = {} + self.node_map: dict[Proxy, Node] = {} self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs def run_node(self, n: Node) -> Any: @@ -66,10 +66,10 @@ class NormalizeArgs(Transformer): def call_function( self, target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Any], - arg_types: Optional[Tuple[Any, ...]] = None, - kwarg_types: Optional[Dict[str, Any]] = None, + args: tuple[Argument, ...], + kwargs: dict[str, Any], + arg_types: Optional[tuple[Any, ...]] = None, + kwarg_types: Optional[dict[str, Any]] = None, ): assert callable(target) new_args_and_kwargs = normalize_function( @@ -89,7 +89,7 @@ class NormalizeArgs(Transformer): return super().call_function(target, args, kwargs) def call_module( - self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] ): assert isinstance(target, str) new_args_and_kwargs = normalize_module( @@ -124,7 +124,7 @@ class NormalizeOperators(AnnotateTypesWithSchema): traced = NormalizeOperators(traced).transform() """ - binary_magic_method_remap: Dict[ + binary_magic_method_remap: dict[ Callable[[Any, Any], Any], Callable[[Any, Any], Any] ] = { torch.add: operator.add, @@ -142,7 +142,7 @@ class NormalizeOperators(AnnotateTypesWithSchema): } def call_function( - self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] ): # Normalize operators according to the magic methods implemented on tensors here: # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950 diff --git a/torch/fx/experimental/optimization.py b/torch/fx/experimental/optimization.py index 080c9e48102a..13d9c2d9ac77 100644 --- a/torch/fx/experimental/optimization.py +++ b/torch/fx/experimental/optimization.py @@ -4,8 +4,9 @@ import logging import operator import time from collections import defaultdict +from collections.abc import Iterable from enum import Enum -from typing import Any, cast, Dict, Iterable, List, Optional, Tuple, Type +from typing import Any, cast, Optional import torch import torch.fx as fx @@ -33,7 +34,7 @@ __all__ = [ ] -def _parent_name(target: str) -> Tuple[str, str]: +def _parent_name(target: str) -> tuple[str, str]: """ Splits a qualname into parent path and last atom. 
For example, `foo.bar.baz` -> (`foo.bar`, `baz`) @@ -44,11 +45,11 @@ def _parent_name(target: str) -> Tuple[str, str]: # Works for length 2 patterns with 2 modules def matches_module_pattern( - pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any] + pattern: Iterable[type], node: fx.Node, modules: dict[str, Any] ): if len(node.args) == 0: return False - nodes: Tuple[Any, fx.Node] = (node.args[0], node) + nodes: tuple[Any, fx.Node] = (node.args[0], node) for expected_type, current_node in zip(pattern, nodes): if not isinstance(current_node, fx.Node): return False @@ -64,7 +65,7 @@ def matches_module_pattern( def replace_node_module( - node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module + node: fx.Node, modules: dict[str, Any], new_module: torch.nn.Module ): assert isinstance(node.target, str) parent_name, name = _parent_name(node.target) @@ -120,7 +121,7 @@ def remove_dropout(model: nn.Module) -> nn.Module: class DropoutRemover(torch.fx.Transformer): def call_module( - self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] ) -> Any: if isinstance(self.submodules[target], nn.Dropout): assert len(args) == 1 @@ -133,15 +134,15 @@ def remove_dropout(model: nn.Module) -> nn.Module: def extract_subgraph( orig_module: nn.Module, - nodes: List[fx.Node], - inputs: List[fx.Node], - outputs: List[fx.Node], + nodes: list[fx.Node], + inputs: list[fx.Node], + outputs: list[fx.Node], ): """ Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph. """ new_graph = fx.Graph() - env: Dict[fx.Node, fx.Node] = {} + env: dict[fx.Node, fx.Node] = {} for input in inputs: new_node = new_graph.placeholder(input.name) env[input] = new_node @@ -180,13 +181,13 @@ mkldnn_map = { } -def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]): +def modules_to_mkldnn(nodes: list[fx.Node], modules: dict[str, nn.Module]): """ For each node, if it's a module that can be preconverted into MKLDNN, then we do so and create a mapping to allow us to convert from the MKLDNN version of the module to the original. 
""" - old_modules: Dict[nn.Module, nn.Module] = {} + old_modules: dict[nn.Module, nn.Module] = {} for node in nodes: if node.op == "call_module": assert isinstance(node.target, str) @@ -200,9 +201,9 @@ def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]): def reset_modules( - nodes: List[fx.Node], - modules: Dict[str, nn.Module], - old_modules: Dict[nn.Module, nn.Module], + nodes: list[fx.Node], + modules: dict[str, nn.Module], + old_modules: dict[nn.Module, nn.Module], ): """ Maps each module that's been changed with `modules_to_mkldnn` back to its @@ -219,9 +220,9 @@ def reset_modules( class MklSubgraph: def __init__(self, fx_graph: fx.Graph): self.fx_graph = fx_graph - self.nodes: List[fx.Node] = [] - self.start_nodes: List[fx.Node] = [] - self.end_nodes: List[fx.Node] = [] + self.nodes: list[fx.Node] = [] + self.start_nodes: list[fx.Node] = [] + self.end_nodes: list[fx.Node] = [] def gen_mkl_autotuner(example_inputs, iters=10, warmup=1): @@ -244,7 +245,7 @@ def gen_mkl_autotuner(example_inputs, iters=10, warmup=1): old_modules = graph.fx_graph.old_modules # type: ignore[attr-defined] ShapeProp(fx_model).propagate(example_inputs) sample_inputs = [torch.randn(node.shape) for node in input_nodes] # type: ignore[attr-defined] - output_args = cast(List[fx.Node], [node.args[0] for node in graph.end_nodes]) + output_args = cast(list[fx.Node], [node.args[0] for node in graph.end_nodes]) submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args) def benchmark(f): @@ -281,8 +282,8 @@ def use_mkl_length(graph: MklSubgraph) -> bool: class UnionFind: def __init__(self, n): - self.parent: List[Optional[int]] = [None] * n - self.size: List[int] = [0] * n + self.parent: list[Optional[int]] = [None] * n + self.size: list[int] = [0] * n def make_set(self, v: int): self.parent[v] = v @@ -308,8 +309,8 @@ class UnionFind: def optimize_for_inference( model: torch.nn.Module, - pass_config: Optional[Dict[str, Any]] = None, - tracer: Type[fx.Tracer] = fx.Tracer, + pass_config: Optional[dict[str, Any]] = None, + tracer: type[fx.Tracer] = fx.Tracer, ) -> torch.nn.Module: """ Performs a set of optimization passes to optimize a model for the @@ -348,7 +349,7 @@ def optimize_for_inference( cur_tracer = tracer() fx_graph = cur_tracer.trace(copy.deepcopy(model)) fx.GraphModule(cur_tracer.root, fx_graph) - modules: Dict[str, nn.Module] = dict(model.named_modules()) + modules: dict[str, nn.Module] = dict(model.named_modules()) class MklSupport(Enum): NO = 1 @@ -388,7 +389,7 @@ def optimize_for_inference( node.args, lambda n: fx_graph.call_method("to_mkldnn", (n,)) ) - node.args = cast(Tuple[fx.node.Argument], mkldnn_args) + node.args = cast(tuple[fx.node.Argument], mkldnn_args) with fx_graph.inserting_after(node): dense_x = fx_graph.create_node("call_method", "to_dense", (node,)) @@ -455,7 +456,7 @@ def optimize_for_inference( for other_color in cur_colors[1:]: uf.join(cur_colors[0], other_color) - mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph)) + mkldnn_graphs: dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph)) for node in fx_graph.nodes: if hasattr(node, "color"): mkldnn_graphs[uf.find(node.color)].nodes.append(node) diff --git a/torch/fx/experimental/partitioner_utils.py b/torch/fx/experimental/partitioner_utils.py index 5cecb8c69945..3658dd1a9ce9 100644 --- a/torch/fx/experimental/partitioner_utils.py +++ b/torch/fx/experimental/partitioner_utils.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs from enum import Enum -from 
typing import Dict, List, NamedTuple, Set +from typing import NamedTuple from torch.fx.node import map_arg, Node @@ -11,13 +11,13 @@ class Partition: """ def __init__(self, partition_id: int) -> None: - self.nodes: Set[Node] = set() + self.nodes: set[Node] = set() self.partition_id = partition_id - self.parents: Set[Partition] = set() - self.children: Set[Partition] = set() + self.parents: set[Partition] = set() + self.children: set[Partition] = set() self.bfs_level: int = -1 self.used_mem_bytes: int = 0 - self.logical_device_ids: List[int] = [] + self.logical_device_ids: list[int] = [] def __str__(self): return str(self.partition_id) @@ -28,7 +28,7 @@ class Partition: self.used_mem_bytes += get_extra_size_of(node, self.nodes) def add_node(self, node): - input_nodes: Dict[Node, None] = {} + input_nodes: dict[Node, None] = {} map_arg(node.args, input_nodes.setdefault) map_arg(node.kwargs, input_nodes.setdefault) # Add current node's input nodes if they are placeholder or constants @@ -43,7 +43,7 @@ class Partition: if node in self.nodes: self.nodes.remove(node) # Collect the node's input nodes - input_nodes: Dict[Node, None] = {} + input_nodes: dict[Node, None] = {} map_arg(node.args, input_nodes.setdefault) map_arg(node.kwargs, input_nodes.setdefault) # Check if an input node is a placeholder or get_attr, @@ -88,23 +88,23 @@ class PartitionMode(Enum): class PartitionerConfig(NamedTuple): - devices: List[Device] + devices: list[Device] mode: PartitionMode = PartitionMode.size_based transfer_rate_bytes_per_sec: float = 0.0 - node_to_latency_mapping: Dict[Node, NodeLatency] = {} - node_to_partition_mapping: Dict[Node, int] = {} - partition_to_logical_device_mapping: Dict[int, List[int]] = {} + node_to_latency_mapping: dict[Node, NodeLatency] = {} + node_to_partition_mapping: dict[Node, int] = {} + partition_to_logical_device_mapping: dict[int, list[int]] = {} # Saturate host by replicating partitions to the remaining idle devices. saturate_host: bool = False -def get_extra_size_of(node: Node, nodes: Set[Node]) -> int: +def get_extra_size_of(node: Node, nodes: set[Node]) -> int: """Given a node and a set of nodes, this function return the extra size that needed if this node is included in this set. 
""" # Find all its input nodes - input_nodes: Dict[Node, None] = {} + input_nodes: dict[Node, None] = {} map_arg(node.args, input_nodes.setdefault) map_arg(node.kwargs, input_nodes.setdefault) # Calculate total size of related nodes @@ -127,18 +127,18 @@ def get_extra_size_of(node: Node, nodes: Set[Node]) -> int: def get_latency_of_one_partition( - partition: Partition, node_to_latency_mapping: Dict[Node, NodeLatency] + partition: Partition, node_to_latency_mapping: dict[Node, NodeLatency] ) -> PartitionLatency: """Given a partition and its nodes' latency, return a PartitionLatency for this partition""" - def get_top_nodes(partition: Partition) -> List[Node]: + def get_top_nodes(partition: Partition) -> list[Node]: """Given a partition, return a list of nodes on the top bfs level""" - top_nodes: List[Node] = [] + top_nodes: list[Node] = [] for node in partition.nodes: # Skip placeholder and get_attr nodes if node.op in {"placeholder", "get_attr"}: continue - input_nodes: Dict[Node, None] = {} + input_nodes: dict[Node, None] = {} map_arg(node.args, input_nodes.setdefault) map_arg(node.kwargs, input_nodes.setdefault) # If a node has no input nodes in this partition, @@ -216,12 +216,12 @@ def get_latency_of_one_partition( def get_partition_to_latency_mapping( - partitions: List[Partition], node_to_latency_mapping: Dict[Node, NodeLatency] -) -> Dict[Partition, PartitionLatency]: + partitions: list[Partition], node_to_latency_mapping: dict[Node, NodeLatency] +) -> dict[Partition, PartitionLatency]: """Given all the partitions and node_to_latency_mapping dictionary, return a mapping dictionary of each partition to its overall latency """ - partition_to_latency_mapping: Dict[Partition, PartitionLatency] = {} + partition_to_latency_mapping: dict[Partition, PartitionLatency] = {} # Go through each partition and get its latency for partition in partitions: partition_latency = get_latency_of_one_partition( @@ -255,7 +255,7 @@ def get_comm_latency_between( # the output size of those input nodes will be counted # and added to comm_size for node in child_partition.nodes: - input_nodes: Dict[Node, None] = {} + input_nodes: dict[Node, None] = {} map_arg(node.args, input_nodes.setdefault) map_arg(node.kwargs, input_nodes.setdefault) for n in input_nodes: @@ -268,8 +268,8 @@ def get_comm_latency_between( def get_latency_of_partitioned_graph( - partitions: List[Partition], - partition_to_latency_mapping: Dict[Partition, PartitionLatency], + partitions: list[Partition], + partition_to_latency_mapping: dict[Partition, PartitionLatency], transfer_rate_bytes_per_sec: float, ): """Given all partitions in a graph, find the critical path among all partitions @@ -298,7 +298,7 @@ def get_latency_of_partitioned_graph( return max_latency_sec return latency_so_far_sec - def get_top_partitions(partitions: List[Partition]) -> List[Partition]: + def get_top_partitions(partitions: list[Partition]) -> list[Partition]: """This function is to return all the partitions without parents as the starting points of all the paths """ diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 10a783dd74ba..3cf0edb3402c 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -17,21 +17,15 @@ import typing_extensions import warnings import weakref from collections import defaultdict +from collections.abc import Generator, Mapping, Sequence from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext from dataclasses import 
dataclass from typing import ( Any, Callable, - Dict, - Generator, - List, - Mapping, Optional, overload, Protocol, - Sequence, - Tuple, - Type, TYPE_CHECKING, TypeVar, Union, @@ -168,7 +162,7 @@ from torch.types import py_sym_types, PySymType class _HasMeta(Protocol): - meta: Dict[str, PySymType] + meta: dict[str, PySymType] def is_sym_node(node: _HasMeta) -> bool: @@ -377,9 +371,9 @@ _ExtractValType = Optional[ PySymType, _AnyScriptObjectType, BackwardState, - List["_ExtractValType"], - Tuple["_ExtractValType", ...], - Dict[str, "_ExtractValType"], + list["_ExtractValType"], + tuple["_ExtractValType", ...], + dict[str, "_ExtractValType"], Tensor, int, float, @@ -767,10 +761,10 @@ def proxy_call( proxy_mode: ProxyTorchDispatchMode, func: OpOverload, pre_dispatch: bool, - args: Tuple[object, ...], - kwargs: Dict[str, object], + args: tuple[object, ...], + kwargs: dict[str, object], ) -> object: - unrecognized_types: List[Type] = [] + unrecognized_types: list[type] = [] flat_args_kwargs, spec = pytree.tree_flatten((args, kwargs)) def can_handle_tensor(x: Tensor) -> bool: @@ -987,7 +981,7 @@ class _SymNodeDict: """ def __init__(self) -> None: - self.sym_node_dict: Dict[PySymType, _PySymProxyType] = {} + self.sym_node_dict: dict[PySymType, _PySymProxyType] = {} def __setitem__(self, key: PySymType, value: _PySymProxyType) -> None: self.sym_node_dict[key.node] = value @@ -1015,9 +1009,9 @@ class _SymNodeDict: class PythonKeyTracer(Tracer): script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy] symnode_tracker: _SymNodeDict - sympy_expr_tracker: Dict[sympy.Symbol, object] + sympy_expr_tracker: dict[sympy.Symbol, object] tensor_tracker: MutableMapping[Tensor, _ProxyTensor] - torch_fn_counts: Dict[OpOverload, int] + torch_fn_counts: dict[OpOverload, int] enable_thunkify: bool = False def __init__(self) -> None: @@ -1043,14 +1037,14 @@ class PythonKeyTracer(Tracer): self, m: Module, forward: Callable[..., Any], - args: Tuple[Any, ...], - kwargs: Dict[str, Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], ) -> Any: return forward(*args, **kwargs) # We don't want to turn getattr calls into proxies. So we just return the actual value. def getattr( - self, attr: str, attr_val: object, parameter_proxy_cache: Dict[str, Proxy] + self, attr: str, attr_val: object, parameter_proxy_cache: dict[str, Proxy] ) -> object: return attr_val @@ -1095,7 +1089,7 @@ class PythonKeyTracer(Tracer): def _make_temp_remove_mode_context_manager( - mode_ty: Type[TorchFunctionMode], + mode_ty: type[TorchFunctionMode], ) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]: @contextmanager def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]: @@ -1137,7 +1131,7 @@ def _make_temp_remove_mode_context_manager( def dispatch_trace( root: Union[Module, Callable], tracer: Tracer, - concrete_args: Optional[Tuple[Any, ...]] = None, + concrete_args: Optional[tuple[Any, ...]] = None, ) -> GraphModule: graph = tracer.trace(root, concrete_args) # type: ignore[arg-type] @@ -1235,9 +1229,9 @@ class TorchFunctionMetadataMode(TorchFunctionMode): def __torch_function__( self, func: OpOverload, - types: Tuple[torch._C._TensorMeta, ...], - args: Tuple[object, ...] = (), - kwargs: Optional[Dict[str, object]] = None, + types: tuple[torch._C._TensorMeta, ...], + args: tuple[object, ...] 
= (), + kwargs: Optional[dict[str, object]] = None, ) -> object: kwargs = kwargs or {} self.tracer.torch_fn_metadata = func @@ -1259,14 +1253,14 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode): # The input to torch.amp.autocast_mode._exit_autocast graph node should be the # enter_autocast node. So we have to save the enter autocast node here, and assign it # to the exit_autocast call_function node. - self.enter_autocast_nodes: List[torch.fx.Node] = [] + self.enter_autocast_nodes: list[torch.fx.Node] = [] def __torch_function__( self, func: Union[OpOverload, Callable], - types: Tuple[torch._C._TensorMeta, ...], - args: Tuple[object, ...] = (), - kwargs: Optional[Dict[str, object]] = None, + types: tuple[torch._C._TensorMeta, ...], + args: tuple[object, ...] = (), + kwargs: Optional[dict[str, object]] = None, ) -> object: kwargs = kwargs or {} if func in _side_effectful_need_to_be_preserved_pre_dispatch: @@ -1324,7 +1318,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode): # Every time we enter a mode, we maintain a stack telling us what the previous # ProxyTorchDispatchMode state was (if there was any). # This lets us properly reset the state on exit. - self.enter_stack: List[Optional[ProxyTorchDispatchMode]] = [] + self.enter_stack: list[Optional[ProxyTorchDispatchMode]] = [] self.decomp_layers = 0 from torch._inductor import config @@ -1334,9 +1328,9 @@ class ProxyTorchDispatchMode(TorchDispatchMode): def __torch_dispatch__( self, func: OpOverload, - types: Tuple[torch._C._TensorMeta, ...], - args: Tuple[object, ...] = (), - kwargs: Optional[Dict[str, object]] = None, + types: tuple[torch._C._TensorMeta, ...], + args: tuple[object, ...] = (), + kwargs: Optional[dict[str, object]] = None, ) -> object: with set_original_aten_op(func): kwargs = kwargs or {} @@ -1354,7 +1348,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode): def __exit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[types.TracebackType], ) -> Optional[bool]: @@ -1372,10 +1366,10 @@ class ProxyTorchDispatchMode(TorchDispatchMode): return True def _compute_proxy( - self, func: OpOverload, args: Tuple[object, ...], out: PySymType + self, func: OpOverload, args: tuple[object, ...], out: PySymType ) -> Proxy: # Handle torch.sym_sum - n_args: Tuple[object, ...] + n_args: tuple[object, ...] if len(args) == 1 and isinstance(args[0], (list, tuple)): n_args = ( tuple( @@ -1403,9 +1397,9 @@ class ProxyTorchDispatchMode(TorchDispatchMode): def __sym_dispatch__( self, func: OpOverload, - types: Tuple[torch._C._TensorMeta, ...], - args: Tuple[object, ...], - kwargs: Dict[str, object], + types: tuple[torch._C._TensorMeta, ...], + args: tuple[object, ...], + kwargs: dict[str, object], ) -> object: # Peephole optimize multiply by one # NB: be careful not to trigger guards here! 
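Note on the pattern in the hunks above: every change swaps a `typing` alias for its PEP 585 builtin generic, which is legal on Python 3.9+. A minimal sketch of the target style (illustrative only, not part of the patch):

# Illustrative only: builtin containers are subscriptable on Python 3.9+,
# so typing.Dict/List/Tuple/Type become redundant as annotations.
from collections.abc import Sequence  # replaces typing.Sequence

def op_histogram(ops: Sequence[str]) -> dict[str, int]:  # was Dict[str, int]
    counts: dict[str, int] = {}
    for op in ops:
        counts[op] = counts.get(op, 0) + 1
    return counts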
@@ -1438,9 +1432,9 @@ class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer): script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy] symnode_tracker: MutableMapping[PySymType, _PySymProxyType] tensor_tracker: MutableMapping[Tensor, _ProxyTensor] - sympy_expr_tracker: Dict[sympy.Symbol, object] + sympy_expr_tracker: dict[sympy.Symbol, object] torch_fn_metadata: Optional[OpOverload] - torch_fn_counts: Dict[OpOverload, int] + torch_fn_counts: dict[OpOverload, int] enable_thunkify: bool = False def __init__(self, graph: fx.graph.Graph) -> None: @@ -1476,7 +1470,7 @@ class DecompositionInterpreter(fx.Interpreter): self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real") def placeholder( - self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object] # type: ignore[override] + self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override] ) -> object: out = super().placeholder(target, args, kwargs) # type: ignore[arg-type] proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer) @@ -1485,7 +1479,7 @@ class DecompositionInterpreter(fx.Interpreter): return out def get_attr( - self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object] # type: ignore[override] + self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override] ) -> object: out = super().get_attr(target, args, kwargs) # type: ignore[arg-type] proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer) @@ -1495,7 +1489,7 @@ class DecompositionInterpreter(fx.Interpreter): # call_function, call_method, call_module get traced automatically by the outer mode. def output( - self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object] # type: ignore[override] + self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override] ) -> object: out = super().output(target, args, kwargs) # type: ignore[arg-type] @@ -1516,13 +1510,13 @@ class DecompositionInterpreter(fx.Interpreter): def wrapper_and_args_for_make_fx( - func: Callable[..., R], args: Tuple[object, ...], kwargs: Dict[str, object] -) -> Tuple[Callable[[List[object]], R], List[object]]: + func: Callable[..., R], args: tuple[object, ...], kwargs: dict[str, object] +) -> tuple[Callable[[list[object]], R], list[object]]: # make_fx doesn't support kwargs, so we need to do this flattening # and then unflatten the args before calling func flat_args, spec = pytree.tree_flatten((args, kwargs)) - def wrapped(flat_args: List[object]) -> R: + def wrapped(flat_args: list[object]) -> R: fn_args, fn_kwargs = pytree.tree_unflatten(flat_args, spec) return func(*fn_args, **fn_kwargs) @@ -1642,7 +1636,7 @@ class _ModuleStackTracer(PythonKeyTracer): return tracer.proxy_modules[self] @property - def _modules(self) -> Dict[str, AttrProxy]: + def _modules(self) -> dict[str, AttrProxy]: assert "_modules" in self.__dict__ submodules = self.__dict__["_modules"] assert isinstance(submodules, dict) @@ -1674,7 +1668,7 @@ class _ModuleStackTracer(PythonKeyTracer): raise _ModuleNotInstalledAsSubmoduleError from e def getattr( - self, attr: str, attr_val: object, parameter_proxy_cache: Dict[str, Proxy] + self, attr: str, attr_val: object, parameter_proxy_cache: dict[str, Proxy] ) -> object: if ( not isinstance(attr_val, Module) @@ -1693,7 +1687,7 @@ class _ModuleStackTracer(PythonKeyTracer): return self.attr_proxy_map[attr_val] def trace( # type: ignore[override] - self, root: Union[Module, Callable], concrete_args: Optional[Dict[str, object]] + self, 
root: Union[Module, Callable], concrete_args: Optional[dict[str, object]] ) -> fx.Graph: res = super().trace(root, concrete_args) @@ -1702,7 +1696,7 @@ class _ModuleStackTracer(PythonKeyTracer): # to the tracer while tracing, the proxy object gets registered # first. So we need to replace the proxy modules with the real ones # This can happen during HOO tracing - proxy_module_names_to_be_replaced: List[Tuple[str, _AttrProxy]] = [] + proxy_module_names_to_be_replaced: list[tuple[str, _AttrProxy]] = [] for name, module in self.root.named_modules(): if module in self.proxy_modules: proxy_module_names_to_be_replaced.append((name, module)) @@ -1746,8 +1740,8 @@ class _ModuleStackTracer(PythonKeyTracer): self, m: Module, forward: Callable, - args: Tuple[object, ...], - kwargs: Dict[str, object], + args: tuple[object, ...], + kwargs: dict[str, object], ) -> None: """PythonKeyTracer overrides call_module to avoid the scope handling, but we actually want it. @@ -1857,7 +1851,7 @@ class _MakefxTracer: ) -> None: # Configurations that are used to initialize the context managers and their states. # Should not modify them during tracing. - self.decomposition_table: Dict[OpOverload, Callable] = dict( + self.decomposition_table: dict[OpOverload, Callable] = dict( decomposition_table or {} ) self.decomposition_table.setdefault( @@ -1885,7 +1879,7 @@ class _MakefxTracer: nullcontext, TorchFunctionMetadataMode ] = nullcontext() - def _checkpoint_modes(self) -> List[Any]: + def _checkpoint_modes(self) -> list[Any]: return [ self.fake_tensor_mode, self.proxy_mode, @@ -1913,7 +1907,7 @@ class _MakefxTracer: @contextmanager def _init_modes_from_inputs( - self, f: Callable, args: Tuple[object, ...] + self, f: Callable, args: tuple[object, ...] ) -> Generator[None, None, None]: prev_modes = self._checkpoint_modes() try: @@ -2202,7 +2196,7 @@ def make_fx( return wrapped -def get_torch_dispatch_modes() -> List[TorchDispatchMode]: +def get_torch_dispatch_modes() -> list[TorchDispatchMode]: return torch.utils._python_dispatch._get_current_dispatch_mode_stack() @@ -2240,7 +2234,7 @@ def handle_sym_dispatch(func: Callable[_P, R], args: _P.args, kwargs: _P.kwargs) # dispatch machinery which disables it for us with disable_proxy_modes_tracing(): # TODO: properly compute types - types: List[Type] = [] + types: list[type] = [] return mode.__sym_dispatch__(func, types, args, kwargs) # type: ignore[arg-type, return-value] @@ -2252,8 +2246,8 @@ def disable_proxy_modes_tracing() -> Generator[ProxyTorchDispatchMode, None, Non def maybe_handle_decomp( proxy_mode: ProxyTorchDispatchMode, op: OpOverload, - args: Tuple[object, ...], - kwargs: Dict[str, object], + args: tuple[object, ...], + kwargs: dict[str, object], ) -> object: from torch._inductor.compiler_bisector import CompilerBisector @@ -2274,8 +2268,8 @@ def maybe_handle_decomp( def get_isolated_graphmodule( func: Callable, - args: Tuple[object, ...], - kwargs: Dict[str, object], + args: tuple[object, ...], + kwargs: dict[str, object], tracing_mode: str = "real", decomposition_table: Optional[Mapping[OpOverload, Callable]] = None, ) -> GraphModule: diff --git a/torch/fx/experimental/recording.py b/torch/fx/experimental/recording.py index 957d17e77376..dcaa6659571f 100644 --- a/torch/fx/experimental/recording.py +++ b/torch/fx/experimental/recording.py @@ -4,7 +4,7 @@ import inspect import itertools import logging from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import 
torch import torch.utils._pytree as pytree @@ -83,11 +83,11 @@ class ShapeEnvEvent: f: Callable # Arguments and keyword arguments called with. - args: Optional[List[Any]] = None - kwargs: Optional[Dict[str, Any]] = None + args: Optional[list[Any]] = None + kwargs: Optional[dict[str, Any]] = None # List of tracked_fakes at the time the method was called. - tracked_fakes: Optional[List[Any]] = None + tracked_fakes: Optional[list[Any]] = None # Name of the captured event. # Used for special handling of particular methods. @@ -344,15 +344,15 @@ def replay_shape_env_events(events): # ShapeEnv.produce_guards. @dataclass class FakeTensorMeta: - tensor_size: Tuple[Union[int, torch.SymInt], ...] - tensor_stride: Tuple[Union[int, torch.SymInt], ...] + tensor_size: tuple[Union[int, torch.SymInt], ...] + tensor_stride: tuple[Union[int, torch.SymInt], ...] tensor_storage_offset: Union[int, torch.SymInt] is_nested: bool - def size(self) -> Tuple[Union[int, torch.SymInt], ...]: + def size(self) -> tuple[Union[int, torch.SymInt], ...]: return self.tensor_size - def stride(self) -> Tuple[Union[int, torch.SymInt], ...]: + def stride(self) -> tuple[Union[int, torch.SymInt], ...]: return self.tensor_stride def storage_offset(self) -> Union[int, torch.SymInt]: @@ -445,7 +445,7 @@ def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value) # compare the two values. def compare_vars( map_value: Callable[[str, Any], Any] - ) -> List[Tuple[str, str, str]]: + ) -> list[tuple[str, str, str]]: env1_set, env2_set = set(env1_vars), set(env2_vars) # First, compare the set of keys in each vars dictionary. @@ -489,7 +489,7 @@ class NotEqualError(Exception): def __init__( self, msg: str, - mismatched: List[Tuple[str, str, str]], + mismatched: list[tuple[str, str, str]], ) -> None: details = "\n".join( [ diff --git a/torch/fx/experimental/rewriter.py b/torch/fx/experimental/rewriter.py index 76ec03f86289..8e635a525f6f 100644 --- a/torch/fx/experimental/rewriter.py +++ b/torch/fx/experimental/rewriter.py @@ -6,7 +6,7 @@ import functools import inspect import textwrap from types import FunctionType -from typing import Any, Callable, cast, Dict, Optional, Union +from typing import Any, Callable, cast, Optional, Union import torch from torch._sources import normalize_source_lines @@ -112,7 +112,7 @@ class RewritingTracer(Tracer): def trace( self, root: Union[torch.nn.Module, Callable], - concrete_args: Optional[Dict[str, Any]] = None, + concrete_args: Optional[dict[str, Any]] = None, ) -> Graph: return super().trace(_rewrite(root), concrete_args) diff --git a/torch/fx/experimental/schema_type_annotation.py b/torch/fx/experimental/schema_type_annotation.py index 519fec16cfc8..335c027c9321 100644 --- a/torch/fx/experimental/schema_type_annotation.py +++ b/torch/fx/experimental/schema_type_annotation.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import inspect -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional import torch import torch.fx @@ -42,7 +42,7 @@ class AnnotateTypesWithSchema(Transformer): self.annotate_get_attrs = annotate_get_attrs def call_function( - self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] ): python_ret_type = None if self.annotate_functionals and target.__module__ == "torch.nn.functional": @@ -73,7 +73,7 @@ class AnnotateTypesWithSchema(Transformer): return return_proxy def call_module( - self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] 
+ self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] ): python_ret_type = None assert isinstance(target, str) @@ -91,8 +91,8 @@ class AnnotateTypesWithSchema(Transformer): def get_attr( self, target: torch.fx.node.Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Any], + args: tuple[Argument, ...], + kwargs: dict[str, Any], ): attr_proxy = super().get_attr(target, args, kwargs) diff --git a/torch/fx/experimental/shape_inference/infer_symbol_values.py b/torch/fx/experimental/shape_inference/infer_symbol_values.py index 81eaca310407..d7ff154c16ad 100644 --- a/torch/fx/experimental/shape_inference/infer_symbol_values.py +++ b/torch/fx/experimental/shape_inference/infer_symbol_values.py @@ -1,5 +1,6 @@ import re -from typing import Any, DefaultDict, Dict, List, Tuple, Union +from collections import defaultdict +from typing import Any, Union import numpy as np import sympy as sp @@ -13,10 +14,10 @@ s_pattern = r"s\d+" def infer_symbol_values( - symints: List[Union[torch.SymInt, int]], - init_symints: List[Union[torch.SymInt, int]], - symbol_idx_dict: Dict[str, int], - padding_constraints: DefaultDict[torch.SymInt, List[Union[sp.Expr, int]]], + symints: list[Union[torch.SymInt, int]], + init_symints: list[Union[torch.SymInt, int]], + symbol_idx_dict: dict[str, int], + padding_constraints: defaultdict[torch.SymInt, list[Union[sp.Expr, int]]], constraint: str, ) -> None: if constraint.find("non-singleton") != -1: @@ -83,8 +84,8 @@ def infer_symbol_values( def calculate_value( left_expression: Union[str, Any, None], right_expression: Union[str, Any, None], - symints: List[Union[torch.SymInt, int]], - symbol_idx_dict: Dict[str, int], + symints: list[Union[torch.SymInt, int]], + symbol_idx_dict: dict[str, int], ) -> None: var, val = solve_equation(left_expression, right_expression) idx = symbol_idx_dict[var] @@ -95,7 +96,7 @@ def calculate_value( def solve_equation( left_expression: Union[str, Any, None], right_expression: Union[str, Any, None], -) -> Tuple[str, int]: +) -> tuple[str, int]: expression = f"{left_expression} - {right_expression}" var = re.findall(s_pattern, expression)[0] if re.findall(parentheses_pattern, expression): @@ -116,9 +117,9 @@ def solve_equation( def update_equation( - symints: List[Union[torch.SymInt, int]], - init_symints: List[Union[torch.SymInt, int]], - padding_constraints: DefaultDict[torch.SymInt, List[Union[sp.Expr, int]]], + symints: list[Union[torch.SymInt, int]], + init_symints: list[Union[torch.SymInt, int]], + padding_constraints: defaultdict[torch.SymInt, list[Union[sp.Expr, int]]], init_eq: sp.Expr, new_mod_num: int, var: torch.SymInt, diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index 6f2b9fe696d1..941b30518c8d 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -20,7 +20,7 @@ import math import operator import sys from functools import lru_cache, update_wrapper -from typing import Optional, Type, TYPE_CHECKING, Union +from typing import Optional, TYPE_CHECKING, Union import torch @@ -1272,7 +1272,7 @@ def _make_node_magic(method, func): log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr) raise sym_node_log.debug("%s %s %s -> %s", method, self.expr, other.expr, out) - pytype: Type + pytype: type # This is not strictly correct. In Python, a**b may return complex when # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This # returns a float while both arguments are ints: 2**(-1). 
Also, max and @@ -1335,7 +1335,7 @@ def _make_node_magic(method, func): out_hint = None if self.hint is not None: out_hint = op(self.hint) - pytype: Type + pytype: type if method in always_int_magic_methods: pytype = int elif method in always_bool_magic_methods: @@ -1485,7 +1485,7 @@ def _make_node_sizes_strides(method, func): out_hint = op(size_hints, stride_hints) # NB: This is the indicator function, not the actual bool! - pytype: Type + pytype: type if method.endswith("_indicator"): pytype = int else: diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index dd971cd98d15..669f6208fb2f 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -23,7 +23,8 @@ import re import sys import threading import traceback -from collections import defaultdict +from collections import Counter, defaultdict +from collections.abc import Iterator, Mapping, Sequence from contextlib import _GeneratorContextManager, contextmanager from dataclasses import dataclass, field from enum import Enum @@ -31,19 +32,9 @@ from typing import ( Any, Callable, cast, - Counter, - DefaultDict, - Dict, - Iterator, - List, - Mapping, NamedTuple, NoReturn, Optional, - Sequence, - Set, - Tuple, - Type, TYPE_CHECKING, TypeVar, Union, @@ -104,8 +95,8 @@ if TYPE_CHECKING: from torch.types import BoolLikeType -InputList = List -DimList = List +InputList = list +DimList = list log = logging.getLogger(__name__) @@ -236,8 +227,8 @@ class SymIntEqByExpr: def _nested_int_aware_sort( - tup: Tuple[Union[SymInt, int], int] -) -> Tuple[int, Union[SymInt, int], int]: + tup: tuple[Union[SymInt, int], int] +) -> tuple[int, Union[SymInt, int], int]: return ( # Order nested ints by their coefficients. # 1 here to order nested ints after non-nested-ints. 
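The `_nested_int_aware_sort` hunk above relies on lexicographic tuple comparison for its "1 orders nested ints last" trick; a hedged illustration with made-up data:

# Illustrative only: Python compares tuples element by element, so a leading
# 0/1 flag partitions the items before any other criterion applies.
vals = [("nested", 7), ("plain", 9), ("nested", 2), ("plain", 1)]
vals.sort(key=lambda v: (1 if v[0] == "nested" else 0, v[1]))
assert [v[0] for v in vals] == ["plain", "plain", "nested", "nested"]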
@@ -289,7 +280,7 @@ def lru_cache( # These are modules that contain generic code for interacting with ShapeEnv # which are unlikely to identify a particular interesting guard statement @lru_cache(None) -def uninteresting_files() -> Set[str]: +def uninteresting_files() -> set[str]: import torch._compile import torch._dynamo.eval_frame import torch._inductor.sizevars @@ -332,8 +323,8 @@ def has_symbolic_sizes_strides(elem: torch.Tensor) -> bool: Int: TypeAlias = Union[torch.SymInt, int] -def create_contiguous(shape: Sequence[Int]) -> List[Int]: - strides: List[Int] = [1] +def create_contiguous(shape: Sequence[Int]) -> list[Int]: + strides: list[Int] = [1] for dim in reversed(shape[:-1]): strides.append(dim * strides[-1]) # type: ignore[operator] return list(reversed(strides)) @@ -461,15 +452,15 @@ def check_consistent(new: _T, old: _T) -> None: def resolve_unbacked_bindings( shape_env: Optional[ShapeEnv], - bindings: Optional[Dict[sympy.Symbol, pytree.KeyPath]], -) -> Optional[Dict[sympy.Symbol, pytree.KeyPath]]: + bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]], +) -> Optional[dict[sympy.Symbol, pytree.KeyPath]]: if bindings is None: return None assert shape_env is not None return {shape_env.unbacked_renamings.get(k, k): v for k, v in bindings.items()} -Result: TypeAlias = Union[torch.Tensor, Tuple[torch.Tensor, ...]] +Result: TypeAlias = Union[torch.Tensor, tuple[torch.Tensor, ...]] def rebind_unbacked( @@ -557,7 +548,7 @@ def rebind_unbacked( and len(raw_u1.args) == 2 and ( raw_u1_args0 := cast( - Tuple[sympy.Basic, sympy.Basic], raw_u1.args[0] + tuple[sympy.Basic, sympy.Basic], raw_u1.args[0] ) ) and raw_u1_args0[0] == 1 @@ -565,7 +556,7 @@ def rebind_unbacked( and isinstance(new_raw_u1 := eq.lhs, sympy.Symbol) and shape_env.var_to_range[new_raw_u1].issubset(ValueRanges(0, 1)) and eq.rhs == 1 - and cast(Tuple[sympy.Basic, sympy.Basic], raw_u1.args[1]) == (0, True) + and cast(tuple[sympy.Basic, sympy.Basic], raw_u1.args[1]) == (0, True) ): # This is what the pattern match above is testing repacked = _sympy_cast_symbool_to_symint_guardless( @@ -645,8 +636,8 @@ def canonicalize_bool_expr(expr: _T) -> _T: def _sympy_from_args( - cls: Union[Type[sympy.Add], Type[sympy.Mul]], - args: List[sympy.Expr], + cls: type[Union[sympy.Add, sympy.Mul]], + args: list[sympy.Expr], sort: bool = True, is_commutative: Optional[bool] = None, ) -> sympy.Expr: @@ -686,7 +677,7 @@ def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean: return type(expr)(*map(canonicalize_bool_expr, expr.args)) opposite = {sympy.Gt: sympy.Lt, sympy.Ge: sympy.Le} - t: Union[Type[Any]] + t: Union[type[Any]] if isinstance(expr, tuple(opposite.keys())): rhs = expr.lhs - expr.rhs # type: ignore[attr-defined] t = opposite[type(expr)] # type: ignore[index] @@ -888,7 +879,7 @@ def is_symbol_binding_fx_node(node: torch.fx.Node) -> Optional[sympy.Symbol]: def find_symbol_binding_fx_nodes( graph: torch.fx.Graph, -) -> Dict[sympy.Symbol, torch.fx.Node]: +) -> dict[sympy.Symbol, torch.fx.Node]: r = {} # NB: Prefer first occurrence of symbol for node in graph.nodes: @@ -949,7 +940,7 @@ def compute_unbacked_bindings( example_value: object, old_example_value: Optional[object] = None, peek: bool = False, -) -> Optional[Dict[sympy.Symbol, pytree.KeyPath]]: +) -> Optional[dict[sympy.Symbol, pytree.KeyPath]]: """ After having run fake tensor propagation and producing example_value result, traverse example_value looking for freshly bound unbacked @@ -977,7 +968,7 @@ def compute_unbacked_bindings( def 
free_unbacked_symbols_with_path( a: object, path: pytree.KeyPath, real: Optional[object] = None - ) -> Dict[sympy.Symbol, pytree.KeyPath]: + ) -> dict[sympy.Symbol, pytree.KeyPath]: assert shape_env is not None r = {} if isinstance(a, (tuple, list)): @@ -1456,11 +1447,11 @@ def guard_float(a: Union[SymFloat, float]) -> float: # Given a GraphModule, return all the FakeTensors for all the placeholders -def fx_placeholder_vals(gm: torch.fx.GraphModule) -> List[object]: +def fx_placeholder_vals(gm: torch.fx.GraphModule) -> list[object]: return [n.meta["val"] for n in gm.graph.nodes if n.op == "placeholder"] -def fx_placeholder_targets(gm: torch.fx.GraphModule) -> List[str]: +def fx_placeholder_targets(gm: torch.fx.GraphModule) -> list[str]: return [n.target for n in gm.graph.nodes if n.op == "placeholder"] @@ -1475,7 +1466,7 @@ def eval_guards( ) -def bind_symbols(gm: torch.fx.GraphModule, *args: Tensor) -> Dict[sympy.Symbol, int]: +def bind_symbols(gm: torch.fx.GraphModule, *args: Tensor) -> dict[sympy.Symbol, int]: return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args) # type: ignore[operator, union-attr] @@ -1617,15 +1608,15 @@ class EqualityConstraint(Constraint): form and so the problem reduces to symbolic expression equality.) """ - source_pairs: List[Tuple[Source, Source]] - derived_equalities: List[ - Tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]] + source_pairs: list[tuple[Source, Source]] + derived_equalities: list[ + tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]] ] - phantom_symbols: List[sympy.Symbol] - relaxed_sources: Set[Source] + phantom_symbols: list[sympy.Symbol] + relaxed_sources: set[Source] - _parents: Dict[Source, Source] = field(init=False) - _defs: Dict[Source, sympy.Expr] = field(init=False) + _parents: dict[Source, Source] = field(init=False) + _defs: dict[Source, sympy.Expr] = field(init=False) def __post_init__(self) -> None: """ @@ -1643,12 +1634,12 @@ class EqualityConstraint(Constraint): # self._parents is a map from input sources to input sources where, conceptually, # these are directed edges in a union-find forest - _parents: Dict[Source, Source] = {} + _parents: dict[Source, Source] = {} object.__setattr__(self, "_parents", _parents) # self._defs is a map from input sources to "canonical" symbolic expressions, # i.e., unary expressions with symbols that correspond to regular Dims (i.e., # not derived Dims) - _defs: Dict[Source, sympy.Expr] = {} + _defs: dict[Source, sympy.Expr] = {} object.__setattr__(self, "_defs", _defs) for source1, source2 in self.source_pairs: @@ -1838,7 +1829,7 @@ class StatefulSymbolicContext(StatelessSymbolicContext): # cause it to fail with unknown symbols, as the symbols cached here will skip creation, and never # get recorded in var_to_val, etc. # TODO(voz): consider a weakref to the shape_env here - shape_env_to_source_to_symbol_cache: Dict[int, Dict[str, sympy.Expr]] = None # type: ignore[assignment] + shape_env_to_source_to_symbol_cache: dict[int, dict[str, sympy.Expr]] = None # type: ignore[assignment] def __post_init__(self) -> None: super().__post_init__() @@ -1856,7 +1847,7 @@ class SubclassSymbolicContext(StatefulSymbolicContext): flexibility, with inner symbolic contexts mapped via attr -> symbolic context.
""" - inner_contexts: Dict[str, SymbolicContext] = None # type: ignore[assignment] + inner_contexts: dict[str, SymbolicContext] = None # type: ignore[assignment] def __post_init__(self) -> None: super().__post_init__() @@ -1875,7 +1866,7 @@ def is_symbolic( IndicatorTypes = (IsNonOverlappingAndDenseIndicator,) -def _expandsums(args: List[sympy.Expr]) -> Tuple[sympy.Expr, bool]: +def _expandsums(args: list[sympy.Expr]) -> tuple[sympy.Expr, bool]: adds, other = [], [] for arg in args: if arg.is_Add: @@ -1912,8 +1903,8 @@ def _fast_expand(expr: _SympyT) -> _SympyT: elif exp < 0: return S.One / sympy.expand_multinomial(S.One / expr, deep=False) elif expr.is_Mul: - num: List[sympy.Expr] = [] - den: List[sympy.Expr] = [] + num: list[sympy.Expr] = [] + den: list[sympy.Expr] = [] for arg in expr.args: if arg.is_Pow and arg.args[1] == -1: den.append(S.One / arg) # type: ignore[operator, arg-type] @@ -1961,7 +1952,7 @@ class _SymbolInfo(NamedTuple): def _maybe_evaluate_static_worker( expr: _SympyT, # NB: this is a tuple to ensure it can be LRU cached - symbol_info: Tuple[_SymbolInfo, ...], + symbol_info: tuple[_SymbolInfo, ...], unbacked_only: bool, size_oblivious: bool, ) -> Optional[_SympyT]: @@ -2193,9 +2184,9 @@ class SymExprPrinter(PythonPrinter): class _ShapeGuardPrinter(abc.ABC): def __init__( self, - symbol_to_source: Mapping[sympy.Symbol, List[Source]], + symbol_to_source: Mapping[sympy.Symbol, list[Source]], source_ref: Callable[[Source], str], - var_to_sources: Mapping[sympy.Symbol, List[Source]], + var_to_sources: Mapping[sympy.Symbol, list[Source]], ) -> None: self.symbol_to_source = symbol_to_source self.source_ref = source_ref @@ -2246,7 +2237,7 @@ class ShapeGuardPrinter(ShapeGuardPythonPrinter): class LoggingShapeGuardPrinter(ShapeGuardPythonPrinter): - def __init__(self, var_to_sources: Mapping[sympy.Symbol, List[Source]]): + def __init__(self, var_to_sources: Mapping[sympy.Symbol, list[Source]]): super().__init__(var_to_sources, lambda n: n.name(), var_to_sources) @@ -2261,7 +2252,7 @@ class DynamicDimConstraintPrinter(PythonPrinter): def __init__( self, - symbol_to_source: Dict[sympy.Symbol, List[Source]], + symbol_to_source: dict[sympy.Symbol, list[Source]], source_name_to_debug_name: Mapping[str, str], ): super().__init__() @@ -2284,23 +2275,23 @@ class DimConstraints: def __init__( self, - symbol_to_source: Dict[sympy.Symbol, List[Source]], + symbol_to_source: dict[sympy.Symbol, list[Source]], var_to_val: Mapping[sympy.Symbol, sympy.Integer], - marked_dynamic: Set[sympy.Symbol], + marked_dynamic: set[sympy.Symbol], source_name_to_debug_name: Mapping[str, str], ) -> None: # We try to solve systems of inequalities with 1 free variable. - self._univariate_inequalities: Dict[ - sympy.Symbol, Set[SympyBoolean] + self._univariate_inequalities: dict[ + sympy.Symbol, set[SympyBoolean] ] = defaultdict(set) # Among them, we prioritize solving for a free variable that has equalities. # NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys() # and removing a symbol from the former => removing it from the latter. - self._symbols_with_equalities: Set[sympy.Symbol] = set() + self._symbols_with_equalities: set[sympy.Symbol] = set() # A solution of a free variable with equalities becomes a substitution. # We use these substitutions to simplify other constraints. # NOTE: removing a symbol from _symbols_with_equalities => adding it to _substitutions. 
- self._substitutions: Dict[sympy.Symbol, sympy.Integer] = {} + self._substitutions: dict[sympy.Symbol, sympy.Integer] = {} # In general, constraints may have // and % operations. # Of course, // can be expressed in terms of / and %. @@ -2308,20 +2299,20 @@ class DimConstraints: # We do so by using the values of variables as hints to evaluate %. # For soundness we record additional congruence guards and solve them separately. self._var_to_val: Mapping[sympy.Symbol, sympy.Integer] = var_to_val - self._congruences: DefaultDict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set) + self._congruences: defaultdict[sympy.Symbol, set[sympy.Expr]] = defaultdict(set) # We do not try to (directly) solve inequalities with > 1 free variables. # NOTE: free variables in these inequalities cannot also be in _substitutions. - self._multivariate_inequalities: Set[SympyBoolean] = set() + self._multivariate_inequalities: set[SympyBoolean] = set() # We park external equalities between free variables here. - self._symbolic_equivalences: List[Tuple[Source, sympy.Expr]] = [] + self._symbolic_equivalences: list[tuple[Source, sympy.Expr]] = [] # Solutions come in two forms: # - (static) specializations # - (dynamic) inequalities / congruences - self._static_results: Set[str] = set() - self._dynamic_results: Set[str] = set() + self._static_results: set[str] = set() + self._dynamic_results: set[str] = set() # printer for solutions self._dcp = DynamicDimConstraintPrinter( @@ -2329,13 +2320,13 @@ class DimConstraints: ) # inconsistencies found on substituting with concrete values / static solutions - self._inconsistencies: List[str] = [] + self._inconsistencies: list[str] = [] # symbols that are marked dynamic self._marked_dynamic = marked_dynamic # track supported sympy functions and subtract from list of all sympy functions - self._supported_sympy_functions: Set[sympy.Function] = { + self._supported_sympy_functions: set[sympy.Function] = { Application, Mod, PythonMod, @@ -2488,8 +2479,8 @@ class DimConstraints: # these will resolve to either specializations or dynamic equality constraints self._symbolic_equivalences.append((source, expr)) - def _reduce_congruences(self) -> Dict[sympy.Symbol, Set[sympy.Expr]]: - reduced_congruences: Dict[sympy.Symbol, Set[sympy.Expr]] = {} + def _reduce_congruences(self) -> dict[sympy.Symbol, set[sympy.Expr]]: + reduced_congruences: dict[sympy.Symbol, set[sympy.Expr]] = {} for s, congruences in self._congruences.items(): remainder_modulus_pairs = [] congruences_to_check = set() @@ -2650,7 +2641,7 @@ class DimConstraints: cond = cond and isinstance(divisor, sympy.Integer) return cond - def forced_specializations(self) -> Dict[str, sympy.Expr]: + def forced_specializations(self) -> dict[str, sympy.Expr]: """Returns a dictionary of the names of symbols to their specialized value""" def debug_name(src: Source) -> str: @@ -2678,8 +2669,8 @@ class DimConstraints: def _process_derived_dim_roots( self, - results: Dict[str, Dict[str, Any]], - name_to_dim: Dict[str, Any], + results: dict[str, dict[str, Any]], + name_to_dim: dict[str, Any], ) -> None: """ Here we resolve 2 concerns with derived dims suggested fixes: 1) newly introduced roots, @@ -2745,7 +2736,7 @@ class DimConstraints: # {"dx": {"eq": 3*_dx+1, "min": 4, "max": 10}, "dy": dx+1, "dz": dx+2} # we want instead: # {"_dx": {"min": 1, "max": 4}, "dx": 3*_dx+1, "dy": 3*_dx+2, "dz": 3*_dx+3} - introduced_roots: Dict[str, str] = {} # map new root -> old root + introduced_roots: dict[str, str] = {} # map new root -> old root for k, c in 
list(results.items()): if "eq" in c and isinstance(c["eq"], sympy.Expr): # derived dim root = next(iter(c["eq"].free_symbols)) @@ -2782,7 +2773,7 @@ class DimConstraints: # this consists of: # 1) {"dx": {"min": ..., "max": ...}} -> dx: refined root dim # 2) {"dy": "dx + 1"} -> dx: root for suggested fix - modified_roots: Set[str] = set() + modified_roots: set[str] = set() for k, c in results.items(): if k not in name_to_dim: # _dynamo.export() may handle source directly continue @@ -2799,7 +2790,7 @@ class DimConstraints: # evaluate the new value for each root # this is now either 1) unchanged, 2) refined with a new range, # or 3) specialized to a concrete value - modified_root_values: Dict[str, Dict[str, Any]] = {} + modified_root_values: dict[str, dict[str, Any]] = {} for mroot in modified_roots: swapped_root = True if mroot in results: @@ -2860,9 +2851,9 @@ class DimConstraints: def prettify_results( self, original_signature: inspect.Signature, - dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]], + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any]], constraint_violation_error: object, - forced_specializations: Dict[str, str], + forced_specializations: dict[str, str], ) -> str: """Format a message for constraint violation errors""" from torch.export.dynamic_shapes import _get_dim_name_mapping @@ -2876,7 +2867,7 @@ class DimConstraints: s = s.replace(k, v) if not inverse else s.replace(v, k) return s - results: DefaultDict[str, Dict[str, Any]] = defaultdict(dict) + results: defaultdict[str, dict[str, Any]] = defaultdict(dict) if dynamic_shapes is None: dynamic_shapes = {} @@ -3050,7 +3041,7 @@ class ShapeEnv: self, *, should_record_events: Optional[bool] = None, - tracked_fakes: Optional[List[Any]] = None, + tracked_fakes: Optional[list[Any]] = None, **kwargs: Any, ) -> None: self._init(**kwargs) @@ -3086,7 +3077,7 @@ class ShapeEnv: # Keep track of the list of tracked fakes. self.tracked_fakes = tracked_fakes # List of events for reconstructing ShapeEnv at arbitrary points in time. - self.events: List[ShapeEnvEvent] = ( + self.events: list[ShapeEnvEvent] = ( [ShapeEnvEvent(ShapeEnv, kwargs=kwargs)] if self.should_record_events else [] @@ -3099,7 +3090,7 @@ class ShapeEnv: # NOTE: It's important that SymNodes in this cache have their ShapeEnv # stripped otherwise you end up with cycles which can only be cleaned # with the GC. - self.fake_tensor_cache: Dict[ + self.fake_tensor_cache: dict[ torch._subclasses.fake_tensor._DispatchCacheKey, torch._subclasses.fake_tensor._DispatchCacheEntry, ] = {} @@ -3134,7 +3125,7 @@ class ShapeEnv: # symbolically equal. duck_shape: Optional[bool] = None, # For debugging - co_fields: Optional[Dict[str, str]] = None, + co_fields: Optional[dict[str, str]] = None, # When True, whenever safe, we will generate a deferred runtime assert # instead of a guard whenever we know that an expression must be True, # otherwise it would be an error, even for backed SymInts (where we @@ -3165,50 +3156,50 @@ class ShapeEnv: allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, ) - self.guards: List[ShapeGuard] = [] - self.axioms: Dict[sympy.Expr, sympy.Expr] = {} + self.guards: list[ShapeGuard] = [] + self.axioms: dict[sympy.Expr, sympy.Expr] = {} # Maps symbolic ints to their original concrete values # Currently populated from tensors - self.var_to_val: Dict[sympy.Symbol, sympy.Integer] = {} + self.var_to_val: dict[sympy.Symbol, sympy.Integer] = {} # Like var_to_val, but only set when propagate_real_tensors is on.
# Used as last resort to avoid GuardOnDataDependent error - self.unbacked_var_to_val: Dict[sympy.Symbol, sympy.Integer] = {} + self.unbacked_var_to_val: dict[sympy.Symbol, sympy.Integer] = {} # Like above, but used exclusively for OBLIVIOUS_SIZE. These # potentially could be put together but I am not sure, writing out # the logic individually before abstracting. - self.oblivious_var_to_val: Dict[sympy.Symbol, sympy.Integer] = {} + self.oblivious_var_to_val: dict[sympy.Symbol, sympy.Integer] = {} # Maps symbolic ints to their min/max range. These ranges # are conservative: the int MUST fall in the range, but the # range may contain ints which may not actually appear in # practice - self.var_to_range: Dict[sympy.Symbol, ValueRanges] = {} - self.var_to_range_sloc: Dict[sympy.Symbol, ValueRangesSLoc] = {} + self.var_to_range: dict[sympy.Symbol, ValueRanges] = {} + self.var_to_range_sloc: dict[sympy.Symbol, ValueRangesSLoc] = {} # When doing a size-oblivious test, exclude this integer and # everything higher than it from the acceptable range. This solves # https://github.com/pytorch/pytorch/issues/120288 for constant range # case # TODO: generalize this to work with expressions (in that case, we # need to maintain a SET and we need extra symbolic reasoning on top) - self.oblivious_upper_bound_exclusive: Dict[sympy.Symbol, sympy.Integer] = {} - self.source_name_to_debug_name: Dict[str, str] = {} - self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {} - self.var_to_stack: Dict[sympy.Symbol, CapturedTraceback] = {} + self.oblivious_upper_bound_exclusive: dict[sympy.Symbol, sympy.Integer] = {} + self.source_name_to_debug_name: dict[str, str] = {} + self.var_to_sources: dict[sympy.Symbol, list[Source]] = {} + self.var_to_stack: dict[sympy.Symbol, CapturedTraceback] = {} # Maps a source to the *original* symbol that was assigned to it - self.source_to_var: Dict[str, sympy.Symbol] = {} + self.source_to_var: dict[str, sympy.Symbol] = {} # Maps from sympy ints to expressions representing them # Populated from equality guards (i.e. a.shape[0] == b.shape[0]) - self.replacements: Dict[sympy.Symbol, sympy.Expr] = {} + self.replacements: dict[sympy.Symbol, sympy.Expr] = {} # The sloc of the guard that triggered this replacement to be added - self.replacements_slocs: Dict[sympy.Symbol, SLoc] = {} - self.unbacked_renamings: Dict[sympy.Symbol, sympy.Symbol] = {} + self.replacements_slocs: dict[sympy.Symbol, SLoc] = {} + self.unbacked_renamings: dict[sympy.Symbol, sympy.Symbol] = {} # Set holds a % b expressions that evaluate to 0. - self.divisible: Set[sympy.Expr] = set() + self.divisible: set[sympy.Expr] = set() # Set that holds "size-like" symbols. When we perform # "size-oblivious" tests, these can be assumed to be >= 2. - self.size_like: Set[sympy.Symbol] = set() + self.size_like: set[sympy.Symbol] = set() # Duck-shaping says that if two input tensors have the same size, # they get assigned the same symbolic variable - self.val_to_var: Dict[int, sympy.Symbol] = {} + self.val_to_var: dict[int, sympy.Symbol] = {} if specialize_zero_one: self.val_to_var = {0: sympy.S.Zero, 1: sympy.S.One} self.unbacked_symfloat_counter = itertools.count() @@ -3241,8 +3232,8 @@ class ShapeEnv: # to the next unbacked symbol to wait on, but if we choose the # latest key, an assert will only show up at the moment when # we can actually codegen it. 
- self.deferred_runtime_asserts: Dict[ - Optional[sympy.Symbol], List[RuntimeAssert] + self.deferred_runtime_asserts: dict[ + Optional[sympy.Symbol], list[RuntimeAssert] ] = {} # This exists so we can efficiently invalidate the cache (it's used as # part of the cache key); otherwise we'd have to iterate through @@ -3279,7 +3270,7 @@ class ShapeEnv: # # NB: fresh unbacked symbols NEVER get substitutions applied to them, # they are binding sites! - self.pending_fresh_unbacked_symbols: List[sympy.Symbol] = [] + self.pending_fresh_unbacked_symbols: list[sympy.Symbol] = [] # Version counter used to invalidate cached values self._prev_cache_key = self._get_key() @@ -3294,8 +3285,8 @@ class ShapeEnv: # 2. list of arguments # This drastically reduces the size of the FX graph, avoiding # duplicated nodes. - self.fx_node_cache: Dict[Tuple[Callable, Tuple[Any, ...]], torch.fx.Node] = {} - self.source_to_symbol: Dict[str, sympy.Symbol] = {} + self.fx_node_cache: dict[tuple[Callable, tuple[Any, ...]], torch.fx.Node] = {} + self.source_to_symbol: dict[str, sympy.Symbol] = {} # Suppose you want to replace an unbacked symbol with another # unbacked symbol. This is error prone because you can cause @@ -3322,7 +3313,7 @@ class ShapeEnv: # bindings. At the moment, this is not tracked, but we potentially # could track this at the IR level using a higher order operator # with something like effect token tracking. - self.unbacked_alloc_order: Dict[sympy.Symbol, int] = {} + self.unbacked_alloc_order: dict[sympy.Symbol, int] = {} from torch.fx.experimental.validator import translation_validation_enabled @@ -3345,7 +3336,7 @@ class ShapeEnv: # Whenever you add a node to self.graph, you must add a mapping to this # variable. Otherwise, the built FX graph on the replayed ShapeEnv will # not be valid. - self.name_to_node: Dict[str, torch.fx.Node] = {} + self.name_to_node: dict[str, torch.fx.Node] = {} @property def allow_scalar_outputs(self) -> bool: @@ -3439,7 +3430,7 @@ class ShapeEnv: shape_env_check_state_equal(self, other, non_state_variable_names, map_value) - def _snapshot_tracked_fakes(self) -> Optional[List[Any]]: + def _snapshot_tracked_fakes(self) -> Optional[list[Any]]: if self.tracked_fakes is None: return None @@ -3631,7 +3622,7 @@ class ShapeEnv: self.source_to_symbol[srcname] = sympy.Symbol(srcname, integer=True) return self.source_to_symbol[srcname] - def _add_z3var(self, symbol: sympy.Symbol, type: Type) -> None: + def _add_z3var(self, symbol: sympy.Symbol, type: type) -> None: if self._translation_validation_enabled: self.validator.add_var(symbol, type) @@ -3651,8 +3642,8 @@ class ShapeEnv: def _create_fx_call_function( self, op: Callable, - args: Tuple, - ) -> Tuple[Optional[torch.fx.Node], bool]: + args: tuple, + ) -> tuple[Optional[torch.fx.Node], bool]: # Cache this tuple in order to avoid duplicated nodes. node_key = (op, args) # Flags whether the returned node was cached or not. @@ -3681,7 +3672,7 @@ class ShapeEnv: def _create_fx_placeholder_and_z3var( self, symbol: sympy.Symbol, - type: Type, + type: type, ) -> Optional[torch.fx.Node]: if not self._translation_validation_enabled: return None @@ -3742,7 +3733,7 @@ class ShapeEnv: """Context manager to ignore all guards generated inside""" return _suppress_guards(self) - def _get_key(self) -> Tuple[int, int, int, int]: + def _get_key(self) -> tuple[int, int, int, int]: """ Defines the current "state" of the guards we've accumulated in this ShapeEnv. 
Determines when we need to invalidate our cache @@ -3778,7 +3769,7 @@ class ShapeEnv: ex_size: Sequence[Union[int, SymInt]], source: Source, symbolic_context: SymbolicContext, - ) -> List[sympy.Expr]: + ) -> list[sympy.Expr]: return self._produce_dyn_sizes_from_int_tuple( tuple(ex_size), source, symbolic_context ) @@ -3788,7 +3779,7 @@ class ShapeEnv: tensor_size: Sequence[Union[int, SymInt]], source: Source, symbolic_context: SymbolicContext, - ) -> List[sympy.Expr]: + ) -> list[sympy.Expr]: assert all( not is_symbolic(val) for val in tensor_size ), f"Expect size to be a plain tuple of ints but got {tensor_size}" @@ -3816,9 +3807,9 @@ class ShapeEnv: source: Source, *, symbolic_context: Optional[SymbolicContext] = None, - ) -> Tuple[ - Tuple[Union[int, SymInt], ...], - Tuple[Union[int, SymInt], ...], + ) -> tuple[ + tuple[Union[int, SymInt], ...], + tuple[Union[int, SymInt], ...], Union[int, SymInt], ]: """ @@ -3903,17 +3894,17 @@ class ShapeEnv: source: Source, *, symbolic_context: Optional[SymbolicContext] = None, - ) -> Tuple[ - Tuple[Union[int, SymInt], ...], - Tuple[Union[int, SymInt], ...], + ) -> tuple[ + tuple[Union[int, SymInt], ...], + tuple[Union[int, SymInt], ...], Union[int, SymInt], ]: dim = len(ex_size) # Reimplement the legacy behavior if symbolic_context is None: - constraint_sizes: List[DimConstraint] = [None] * dim - constraint_strides: List[DimConstraint] = [None] * dim + constraint_sizes: list[DimConstraint] = [None] * dim + constraint_strides: list[DimConstraint] = [None] * dim dynamic_dims = [] dynamic_strides = [] for i in range(dim): @@ -3963,7 +3954,7 @@ class ShapeEnv: from torch._dynamo.source import TensorProperty, TensorPropertySource - size: List[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple( + size: list[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple( ex_size, source, symbolic_context ) stride = self._compute_symbolic_stride( @@ -4022,11 +4013,11 @@ class ShapeEnv: ], are_sizes_static: bool, symbolic_context: SymbolicContext, - ) -> List[sympy.Expr]: + ) -> list[sympy.Expr]: from torch._dynamo.source import TensorProperty, TensorPropertySource - stride: List[Optional[sympy.Expr]] = [None] * len(size) - candidates: Dict[Union[int, SymInt], sympy.Expr] = {} + stride: list[Optional[sympy.Expr]] = [None] * len(size) + candidates: dict[Union[int, SymInt], sympy.Expr] = {} # iterate over unbound strides in val ascending order with # index descending as a tie breaker since for cases like @@ -4590,7 +4581,7 @@ class ShapeEnv: return c_render return c.render(source) - def produce_guards(self, *args: Any, **kwargs: Any) -> List[str]: + def produce_guards(self, *args: Any, **kwargs: Any) -> list[str]: """ Like produce_guards_verbose, but only returns the non-verbose guard expressions (no verbose guards produced.) @@ -4603,7 +4594,7 @@ class ShapeEnv: sources: Sequence[Source], source_ref: Callable[[Source], str] = lambda n: n.name(), *, - guards: Optional[List[ShapeGuard]] = None, + guards: Optional[list[ShapeGuard]] = None, input_contexts: Optional[DimList[SymbolicContext]] = None, # Encodes user-specified input shape equations of the form s = s' and s = fn(s'). # (See docs on EqualityConstraint for details of the encoding.) @@ -4611,7 +4602,7 @@ class ShapeEnv: _simplified: bool = False, # Indicates if we should produce guards for known static values. 
ignore_static: bool = True, - ) -> Tuple[List[str], List[str]]: # python, verbose + ) -> tuple[list[str], list[str]]: # python, verbose """ Generates a list of guard strings which, when evaluated in a context that defines tensors for all the sources, returns True or False depending @@ -4740,13 +4731,13 @@ class ShapeEnv: # the symbol mapping is input_guards = [] - symbol_to_source: Dict[sympy.Symbol, List[Source]] = collections.defaultdict( + symbol_to_source: dict[sympy.Symbol, list[Source]] = collections.defaultdict( list ) - symbol_to_constraints: DefaultDict[ - sympy.Symbol, Set[Constraint] + symbol_to_constraints: defaultdict[ + sympy.Symbol, set[Constraint] ] = collections.defaultdict(set) - constraint_violations: List[Tuple[bool, str, Callable[[], str]]] = [] + constraint_violations: list[tuple[bool, str, Callable[[], str]]] = [] py_printer = ShapeGuardPythonPrinter( symbol_to_source, source_ref, self.var_to_sources @@ -4956,7 +4947,7 @@ class ShapeEnv: # For subclasses, we need to track symints on BOTH the outer # and inner tensors. # TODO: type this better - sources_tensors_constraints: List[Tuple[Source, Any, Any, Any]] = [ + sources_tensors_constraints: list[tuple[Source, Any, Any, Any]] = [ (source, t, context.constraint_sizes, context.constraint_strides) ] attrs, _ = t.__tensor_flatten__() @@ -5256,8 +5247,8 @@ class ShapeEnv: ) if constraint_violations: - warn_msgs: List[str] = [] - error_msgs: List[str] = [] + warn_msgs: list[str] = [] + error_msgs: list[str] = [] debug_names = set() for warn_only, debug_name, msg_cb in constraint_violations: if warn_only: @@ -5327,7 +5318,7 @@ class ShapeEnv: self, placeholders: Sequence[Union[SymInt, FakeTensor]], *, - guards: Optional[List[ShapeGuard]] = None, + guards: Optional[list[ShapeGuard]] = None, ignore_static: bool = True, ) -> Optional[str]: """ @@ -5386,7 +5377,7 @@ class ShapeEnv: return self.evaluate_guards_expression(code, args) return True - def get_pruned_guards(self, symints: Sequence[torch.SymInt]) -> List[ShapeGuard]: + def get_pruned_guards(self, symints: Sequence[torch.SymInt]) -> list[ShapeGuard]: """ Get a list of guards, but pruned so it only provides guards that reference symints from the passed in input @@ -5401,7 +5392,7 @@ class ShapeEnv: def bind_symbols( self, placeholders: Sequence[FakeTensor], args: Sequence[Tensor] - ) -> Dict[sympy.Symbol, int]: + ) -> dict[sympy.Symbol, int]: """ Given a paired list of placeholders (fake tensors with symbolic sizes) and concrete arguments (regular tensors @@ -5418,7 +5409,7 @@ class ShapeEnv: another copy. This assumes the guards are already checked, though if it's cheap we'll check for shenanigans """ - bindings: Dict[sympy.Symbol, int] = {} + bindings: dict[sympy.Symbol, int] = {} def bind_symint(arg: object, val: object) -> None: if isinstance(val, SymInt): @@ -5451,7 +5442,7 @@ class ShapeEnv: return bindings - def get_nontrivial_guards(self) -> List[SympyBoolean]: + def get_nontrivial_guards(self) -> list[SympyBoolean]: """Returns a list of guard expressions that aren't statically known (i.e. not trivial)""" return [ self.simplify(guard.expr) @@ -5488,9 +5479,9 @@ class ShapeEnv: @_lru_cache def get_axioms( self, - symbols: Optional[Tuple[sympy.Symbol]] = None, + symbols: Optional[tuple[sympy.Symbol]] = None, compute_hint: bool = False, - ) -> Tuple[SympyBoolean, ...]: + ) -> tuple[SympyBoolean, ...]: """ Given the symbols in an expression, it returns all the runtime asserts that have those symbols concatenated with all the guards.
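For the `bind_symbols` docstring above, a minimal sketch of the pairing walk it describes (an assumption, not the real implementation: `SymInt` is stubbed out, and the actual method also recurses through strides and storage offsets):

# Hedged sketch only: record the first concrete value observed for each
# symbolic dimension in a placeholder's sizes.
from typing import Union

def bind_sizes(
    sym_sizes: list[Union[int, object]],   # object stands in for SymInt here
    real_sizes: list[int],
    bindings: dict[object, int],
) -> None:
    for sym, real in zip(sym_sizes, real_sizes):
        if not isinstance(sym, int):       # a symbolic dimension
            bindings.setdefault(sym, real)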
@@ -5518,9 +5509,9 @@ class ShapeEnv: @lru_cache(None) def get_implications( self, e: SympyBoolean - ) -> Tuple[Tuple[SympyBoolean, sympy.logic.boolalg.BooleanAtom], ...]: + ) -> tuple[tuple[SympyBoolean, sympy.logic.boolalg.BooleanAtom], ...]: """Given an expression, it returns a list of predicates that follow from it""" - equiv: Dict[SympyBoolean, sympy.logic.boolalg.BooleanAtom] = {} + equiv: dict[SympyBoolean, sympy.logic.boolalg.BooleanAtom] = {} def add_expr(expr: SympyBoolean) -> None: expr = canonicalize_bool_expr(expr) @@ -5564,8 +5555,8 @@ class ShapeEnv: unbacked_only: bool = False, compute_hint: bool = False, size_oblivious: bool = False, - axioms: Optional[Tuple[SympyBoolean]] = None, - var_to_range: Optional[Tuple[Tuple[sympy.Symbol, ValueRanges]]] = None, + axioms: Optional[tuple[SympyBoolean]] = None, + var_to_range: Optional[tuple[tuple[sympy.Symbol, ValueRanges]]] = None, ) -> Optional[sympy.Basic]: """ Tries to evaluate expr without introducing guards @@ -5589,7 +5580,7 @@ class ShapeEnv: expr = canonicalize_bool_expr(expr) - def resimplify_floor_div(axioms: Dict[sympy.Expr, sympy.Expr]) -> None: + def resimplify_floor_div(axioms: dict[sympy.Expr, sympy.Expr]) -> None: if not self._resimplify_floor_div_axioms: return self._resimplify_floor_div_axioms = False @@ -6114,7 +6105,7 @@ class ShapeEnv: # Prefer to simplify out lexicographically higher symbols (i.e. simplify out s4 over s3). # (NB: this unfortunately isn't strictly equivalent to simplifying out newer symbols) # Prefer to simplify out symbols with ephemeral sources. - def _smart_symbol_sort(x: sympy.Symbol) -> Tuple[int, int, str]: + def _smart_symbol_sort(x: sympy.Symbol) -> tuple[int, int, str]: has_only_ephemeral_sources = x in self.var_to_sources and all( s.is_ephemeral() for s in self.var_to_sources[x] ) @@ -6282,7 +6273,7 @@ class ShapeEnv: def _get_stack_summary( self, is_debug: bool = False, framework_loc: Optional[str] = None - ) -> Tuple[SLoc, str]: + ) -> tuple[SLoc, str]: floc: Optional[Union[str, traceback.FrameSummary]] = framework_loc if floc is None: frame = inspect.currentframe() @@ -6903,7 +6894,7 @@ class _PythonMsgPrinter(PythonPrinter): (i.e., as ==, !=, >, <). """ - def __init__(self, src_map: Dict[str, List[str]]) -> None: + def __init__(self, src_map: dict[str, list[str]]) -> None: super().__init__() self.src_map = src_map @@ -6912,7 +6903,7 @@ class _PythonMsgPrinter(PythonPrinter): def _suggest_torch_checks( - e: GuardOnDataDependentSymNode, src_map: DefaultDict[str, List[str]] + e: GuardOnDataDependentSymNode, src_map: defaultdict[str, list[str]] ) -> None: # extract the unresolved condition on unbacked symints in the error cond = e.cond diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index 5f9c2ecde9ca..61a51b977311 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -5,7 +5,7 @@ import logging import math import operator from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Optional, Union import sympy @@ -60,7 +60,7 @@ try: def z3str(e: z3.ExprRef) -> str: assert z3.is_expr(e), f"unsupported expression type: {e}" - def get_args_str(e: z3.ExprRef) -> List[str]: + def get_args_str(e: z3.ExprRef) -> list[str]: return [z3str(e.arg(i)) for i in range(e.num_args())] # First, we simplify the given expression.
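The `z3str` helper above simplifies before printing; a tiny demonstration of that flow (assumes the optional `z3-solver` dependency, which the surrounding `try:` block already guards against):

# Illustrative only: simplify-then-print, as z3str does for guard expressions.
import z3

a = z3.Int("a")
guard = z3.And(a > 0, a < 10)
print(z3.simplify(guard))  # a canonical, possibly rewritten form of the guard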
@@ -350,13 +350,13 @@ try: super().__init__(module, garbage_collect_values=True) def placeholder( - self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] ) -> Any: symbol = fx_traceback.get_current_meta()["symbol"] return self.validator.z3var(symbol) def call_function( - self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] ) -> Any: if target != torch._assert: # Lifts and runs the node target function @@ -481,21 +481,21 @@ try: log.debug("new instance") # Mapping of SymPy symbols to Z3 variables. - self.symbols: Dict[sympy.Symbol, z3.ExprRef] = {} + self.symbols: dict[sympy.Symbol, z3.ExprRef] = {} # Set of source Z3 expressions. # They represent the generated guards without any kind of # simplification or transformation. - self._source_exprs: Set[z3.BoolRef] = set() + self._source_exprs: set[z3.BoolRef] = set() # Set of target Z3 expressions. # They represent the actual checked guards at runtime. They might # be simplified or transformed versions of the source guards. - self._target_exprs: Set[z3.BoolRef] = set() + self._target_exprs: set[z3.BoolRef] = set() # Set of Z3 expressions representing assertions over both the # source and target expressions. - self._assertions: Set[z3.BoolRef] = set() + self._assertions: set[z3.BoolRef] = set() # Retrieves the corresponding Z3 variable. def z3var(self, symbol: sympy.Symbol) -> z3.ExprRef: @@ -503,7 +503,7 @@ try: return self.symbols[symbol] # Create a variable in Z3 of 'type' for 'symbol', if it doesn't already exist. - def add_var(self, symbol: sympy.Symbol, type: Type) -> z3.ExprRef: + def add_var(self, symbol: sympy.Symbol, type: type) -> z3.ExprRef: if symbol in self.symbols: return self.symbols[symbol] @@ -769,7 +769,7 @@ def bisect(shape_env): # Checks whether the given shape_env fails when produce_guards is called. def check_shapeenv_fails( - shape_env: ShapeEnv, tracked_fakes: Optional[List[Any]] + shape_env: ShapeEnv, tracked_fakes: Optional[list[Any]] ) -> Optional[ValidationException]: assert tracked_fakes is not None try: diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 6e3e8a20687c..5698d76d66cc 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -11,23 +11,10 @@ import os import re import warnings from collections import defaultdict +from collections.abc import Iterable from contextlib import contextmanager from dataclasses import dataclass -from typing import ( - Any, - Callable, - Dict, - FrozenSet, - Iterable, - List, - Literal, - NamedTuple, - Optional, - Set, - Tuple, - Type, - TYPE_CHECKING, -) +from typing import Any, Callable, Literal, NamedTuple, Optional, TYPE_CHECKING import torch import torch.utils._pytree as pytree @@ -47,11 +34,11 @@ if TYPE_CHECKING: # Mapping of builtins to their `typing` equivalent.
_origin_type_map = { - list: List, - dict: Dict, - set: Set, - frozenset: FrozenSet, - tuple: Tuple, + list: list, + dict: dict, + set: set, + frozenset: frozenset, + tuple: tuple, } _legal_ops = dict.fromkeys( @@ -61,7 +48,7 @@ _legal_ops = dict.fromkeys( # Signature for functions that transform the body (`list[str]`) of the # generated code -TransformCodeFunc = Callable[[List[str]], List[str]] +TransformCodeFunc = Callable[[list[str]], list[str]] class _CustomBuiltin(NamedTuple): @@ -78,7 +65,7 @@ class _CustomBuiltin(NamedTuple): obj: Any -_custom_builtins: Dict[str, _CustomBuiltin] = {} +_custom_builtins: dict[str, _CustomBuiltin] = {} def _register_custom_builtin(name: str, import_str: str, obj: Any): @@ -144,10 +131,10 @@ class _Namespace: """ def __init__(self): - self._obj_to_name: Dict[Any, str] = {} + self._obj_to_name: dict[Any, str] = {} self._unassociated_names = set() - self._used_names: Set[str] = set() - self._base_count: Dict[str, int] = defaultdict(int) + self._used_names: set[str] = set() + self._base_count: dict[str, int] = defaultdict(int) self._illegal_char_regex = re.compile("[^0-9a-zA-Z_]+") self._name_suffix_regex = re.compile(r"(.*)_(\d+)$") @@ -261,10 +248,10 @@ class PythonCode: # Python source code for the forward function definition. src: str # Values in global scope during execution of `src_def`. - globals: Dict[str, Any] + globals: dict[str, Any] # Optional mapping from the forward function's line number to # node index. - _lineno_map: Optional[Dict[int, Optional[int]]] + _lineno_map: Optional[dict[int, Optional[int]]] def _format_target(base: str, target: str) -> str: @@ -311,7 +298,7 @@ class _PyTreeInfo(NamedTuple): Contains extra info stored when we're using Pytrees """ - orig_args: List[str] + orig_args: list[str] in_spec: pytree.TreeSpec out_spec: Optional[pytree.TreeSpec] @@ -359,7 +346,7 @@ class CodeGen: self._body_transformer: Optional[TransformCodeFunc] = None self._func_name: str = "forward" - def gen_fn_def(self, free_vars: List[str], maybe_return_annotation: str) -> str: + def gen_fn_def(self, free_vars: list[str], maybe_return_annotation: str) -> str: """ Given the free variables and a return annotation, generates the beginning of the FX function. By default, `gen_fn_def(['a', 'b'], '') == 'def {self._func_name}(a, b):'` @@ -398,7 +385,7 @@ class CodeGen: """ return outputs - def additional_globals(self) -> List[Tuple[str, Any]]: + def additional_globals(self) -> list[tuple[str, Any]]: """ If your codegen uses extra global values, add tuples of (identifier, reference to the value) here. For example, return [('List', typing.List)] if you need ``List`` in the global context.
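A hypothetical subclass (the class name is made up) following the contract that `additional_globals` documents above:

# Hypothetical example: expose typing.List to the generated code's globals.
import typing
from torch.fx.graph import CodeGen

class ListAwareCodeGen(CodeGen):
    def additional_globals(self) -> list[tuple[str, typing.Any]]:
        return [("List", typing.List)]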
@@ -416,13 +403,13 @@ class CodeGen: include_device: bool = False, colored: bool = False, ) -> PythonCode: - free_vars: List[str] = [] - body: List[str] = [] - globals_: Dict[str, Any] = {} - wrapped_fns: Dict[str, None] = {} + free_vars: list[str] = [] + body: list[str] = [] + globals_: dict[str, Any] = {} + wrapped_fns: dict[str, None] = {} # Wrap string in list to pass by reference - maybe_return_annotation: List[str] = [""] + maybe_return_annotation: list[str] = [""] include_stride = include_stride or ( os.environ.get("FX_GRAPH_SHOW_STRIDE", "0") == "1" ) @@ -553,7 +540,7 @@ class CodeGen: return blue(repr(arg)) def _format_args( - args: Tuple[Argument, ...], kwargs: Dict[str, Argument] + args: tuple[Argument, ...], kwargs: dict[str, Argument] ) -> str: args_s = ", ".join(_get_repr(a) for a in args) kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) @@ -565,8 +552,8 @@ class CodeGen: # of a given node. This represents the *last* use of the node in the # execution order of the program, which we will use to free unused # values - node_to_last_use: Dict[Node, Node] = {} - user_to_last_uses: Dict[Node, List[Node]] = {} + node_to_last_use: dict[Node, Node] = {} + user_to_last_uses: dict[Node, list[Node]] = {} def register_last_uses(n: Node, user: Node): if n not in node_to_last_use: @@ -782,9 +769,9 @@ class CodeGen: prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) # remove counter and generate lineno to node index mapping - lineno_map: Dict[int, Optional[int]] = {} + lineno_map: dict[int, Optional[int]] = {} prologue_len = prologue.count("\n") + 1 - new_lines: List[str] = [] + new_lines: list[str] = [] cur_idx = None for line in "".join(body).split("\n"): counter = re.search(r"# COUNTER: (\d+)", line) @@ -904,11 +891,11 @@ class _FindNodesLookupTable: """ def __init__(self): - self.table: Dict[Tuple[str, Optional[Target]], Dict[Node, None]] = defaultdict( + self.table: dict[tuple[str, Optional[Target]], dict[Node, None]] = defaultdict( dict ) - def _key(self, node) -> Tuple[str, Optional[Target]]: + def _key(self, node) -> tuple[str, Optional[Target]]: return (node.op, node.target if node.op == "call_function" else None) def __contains__(self, node) -> bool: @@ -985,14 +972,14 @@ class Graph: def __init__( self, owning_module: Optional["GraphModule"] = None, - tracer_cls: Optional[Type["Tracer"]] = None, - tracer_extras: Optional[Dict[str, Any]] = None, + tracer_cls: Optional[type["Tracer"]] = None, + tracer_extras: Optional[dict[str, Any]] = None, ): """ Construct an empty Graph. """ self._root: Node = Node(self, "", "root", "", (), {}) - self._used_names: Dict[str, int] = {} # base name -> number + self._used_names: dict[str, int] = {} # base name -> number self._insert = self._root.prepend self._len = 0 self._graph_namespace = _Namespace() @@ -1000,7 +987,7 @@ class Graph: self._tracer_cls = tracer_cls self._tracer_extras = tracer_extras self._codegen = CodeGen() - self._co_fields: Dict[str, Any] = {} + self._co_fields: dict[str, Any] = {} self._find_nodes_lookup_table = _FindNodesLookupTable() @property @@ -1060,7 +1047,7 @@ class Graph: @compatibility(is_backward_compatible=True) def graph_copy( - self, g: "Graph", val_map: Dict[Node, Node], return_output_node=False + self, g: "Graph", val_map: dict[Node, Node], return_output_node=False ) -> "Optional[Argument]": """ Copy all nodes from a given graph into ``self``. 
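For readers unfamiliar with `graph_copy` and the `val_map` parameter annotated above, a small usage sketch (illustrative, not from the patch):

```python
import torch
from torch.fx import Graph, Node, symbolic_trace

gm = symbolic_trace(torch.nn.Linear(4, 4))

new_graph = Graph()
val_map: dict[Node, Node] = {}  # filled in by graph_copy: old node -> copied node
# graph_copy returns the source graph's output value remapped into the new
# graph, so the caller re-emits the output node itself.
out = new_graph.graph_copy(gm.graph, val_map)
new_graph.output(out)
```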
@@ -1113,8 +1100,8 @@ class Graph: self, op: str, target: "Target", - args: Optional[Tuple["Argument", ...]] = None, - kwargs: Optional[Dict[str, "Argument"]] = None, + args: Optional[tuple["Argument", ...]] = None, + kwargs: Optional[dict[str, "Argument"]] = None, name: Optional[str] = None, type_expr: Optional[Any] = None, ) -> Node: @@ -1373,8 +1360,8 @@ class Graph: def call_module( self, module_name: str, - args: Optional[Tuple["Argument", ...]] = None, - kwargs: Optional[Dict[str, "Argument"]] = None, + args: Optional[tuple["Argument", ...]] = None, + kwargs: Optional[dict[str, "Argument"]] = None, type_expr: Optional[Any] = None, ) -> Node: """ @@ -1423,8 +1410,8 @@ class Graph: def call_method( self, method_name: str, - args: Optional[Tuple["Argument", ...]] = None, - kwargs: Optional[Dict[str, "Argument"]] = None, + args: Optional[tuple["Argument", ...]] = None, + kwargs: Optional[dict[str, "Argument"]] = None, type_expr: Optional[Any] = None, ) -> Node: """ @@ -1462,8 +1449,8 @@ class Graph: def call_function( self, the_function: Callable[..., Any], - args: Optional[Tuple["Argument", ...]] = None, - kwargs: Optional[Dict[str, "Argument"]] = None, + args: Optional[tuple["Argument", ...]] = None, + kwargs: Optional[dict[str, "Argument"]] = None, type_expr: Optional[Any] = None, ) -> Node: """ @@ -1668,10 +1655,10 @@ class Graph: Return a human-readable (not machine-readable) string representation of this Graph """ - placeholder_names: List[str] = [] + placeholder_names: list[str] = [] # This is a one-element array just so ``format_node`` can modify the closed # over value - maybe_return_typename: List[str] = [""] + maybe_return_typename: list[str] = [""] node_strs = [node.format_node(placeholder_names) for node in self.nodes] param_str = ", ".join(placeholder_names) @@ -1729,8 +1716,8 @@ class Graph: f"defined! Please check that Nodes in the graph are topologically ordered\n{self}" ) - seen_names: Set[str] = set() - seen_values: Set[Node] = set() + seen_names: set[str] = set() + seen_values: set[Node] = set() for node in self.nodes: if node.op not in [ "placeholder", diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index f53c2678e549..5a051d537ff8 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -8,7 +8,7 @@ import sys import traceback import warnings from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Set, Type, Union +from typing import Any, Callable, Optional, Union import torch import torch.nn as nn @@ -39,7 +39,7 @@ class _EvalCacheLoader: self.eval_cache = {} self.next_id = 0 - def cache(self, src: str, globals: Dict[str, Any], co_fields=None): + def cache(self, src: str, globals: dict[str, Any], co_fields=None): """Store the source in a private cache, and add a lazy entry in linecache that allows the source to be retrieved by 'filename'. 
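The `cache`/linecache mechanism above is what makes generated `forward` source visible in tracebacks. A standalone sketch of the underlying pattern; the synthetic filename scheme is an assumption, not the module's actual key format:

```python
import linecache

def exec_with_source(src: str, globals_: dict) -> None:
    key = f"<fx-generated-{hash(src)}>"
    # linecache entries are (size, mtime, lines, fullname); registering one
    # lets tracebacks and inspect.getsource resolve the synthetic filename.
    linecache.cache[key] = (len(src), None, src.splitlines(True), key)
    exec(compile(src, key, "exec"), globals_)
```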
@@ -83,19 +83,19 @@ class _EvalCacheLoader: _loader = _EvalCacheLoader() -def _exec_with_source(src: str, globals: Dict[str, Any], co_fields=None): +def _exec_with_source(src: str, globals: dict[str, Any], co_fields=None): key = _loader.cache(src, globals, co_fields) exec(compile(src, key, "exec"), globals) -def _forward_from_src(src: str, globals: Dict[str, Any], co_fields=None): +def _forward_from_src(src: str, globals: dict[str, Any], co_fields=None): return _method_from_src( method_name="forward", src=src, globals=globals, co_fields=co_fields ) def _method_from_src( - method_name: str, src: str, globals: Dict[str, Any], co_fields=None + method_name: str, src: str, globals: dict[str, Any], co_fields=None ) -> Callable: # avoid mutating the passed in dict globals_copy = globals.copy() @@ -114,8 +114,8 @@ def _format_import_statement(name: str, obj: Any, importer: Importer) -> str: return f"from {module_name} import {attr_name} as {name}" -def _format_import_block(globals: Dict[str, Any], importer: Importer): - import_strs: Set[str] = { +def _format_import_block(globals: dict[str, Any], importer: Importer): + import_strs: set[str] = { _format_import_statement(name, obj, importer) for name, obj in globals.items() } # Sort the imports so we have a stable import block that allows us to @@ -124,7 +124,7 @@ def _format_import_block(globals: Dict[str, Any], importer: Importer): @compatibility(is_backward_compatible=True) -def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module: +def reduce_graph_module(body: dict[Any, Any], import_block: str) -> torch.nn.Module: # BC: attribute name was changed from `code` to `_code` to facilitate # making `code` into a property and adding a docstring to it fn_src = body.get("_code") or body["code"] @@ -134,7 +134,7 @@ def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Mod @compatibility(is_backward_compatible=True) def reduce_package_graph_module( - importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str + importer: PackageImporter, body: dict[Any, Any], generated_module_name: str ) -> torch.nn.Module: forward = importer.import_module(generated_module_name).forward return _deserialize_graph_module(forward, body) @@ -142,7 +142,7 @@ def reduce_package_graph_module( @compatibility(is_backward_compatible=True) def reduce_deploy_graph_module( - importer: PackageImporter, body: Dict[Any, Any], import_block: str + importer: PackageImporter, body: dict[Any, Any], import_block: str ) -> torch.nn.Module: ns = {} ns["__builtins__"] = importer.patched_builtins @@ -162,7 +162,7 @@ class _CodeOnlyModule(torch.nn.Module): def _deserialize_graph_module( - forward, body: Dict[Any, Any], graph_module_cls=None + forward, body: dict[Any, Any], graph_module_cls=None ) -> torch.nn.Module: """ Deserialize a GraphModule given the dictionary of the original module, @@ -271,7 +271,7 @@ def _get_attr(model: torch.nn.Module, attr_name: str): return _get_attr_via_attr_list(model, attr_name.split(".")) -def _get_attr_via_attr_list(model: torch.nn.Module, attr_list: List[str]): +def _get_attr_via_attr_list(model: torch.nn.Module, attr_list: list[str]): if len(attr_list) == 0: return model *prefix, field = attr_list @@ -415,7 +415,7 @@ class GraphModule(torch.nn.Module): code. """ - def __new__(cls: "Type[GraphModule]", *args, **kwargs): + def __new__(cls: "type[GraphModule]", *args, **kwargs): # each instance of a graph module needs its own forward method # so create a new singleton class for each instance. 
# it is a subclass of the user-defined class, the only difference @@ -437,7 +437,7 @@ class GraphModule(torch.nn.Module): @compatibility(is_backward_compatible=True) def __init__( self, - root: Union[torch.nn.Module, Dict[str, Any]], + root: Union[torch.nn.Module, dict[str, Any]], graph: Graph, class_name: str = "GraphModule", ): @@ -527,12 +527,12 @@ class GraphModule(torch.nn.Module): self._tracer_extras = self.graph._tracer_extras # Dictionary to store metadata - self.meta: Dict[str, Any] = {} - self._replace_hooks: List[Callable] = [] - self._create_node_hooks: List[Callable] = [] - self._erase_node_hooks: List[Callable] = [] + self.meta: dict[str, Any] = {} + self._replace_hooks: list[Callable] = [] + self._create_node_hooks: list[Callable] = [] + self._erase_node_hooks: list[Callable] = [] # Used to remove hooks from deepcopied graph modules within a context manager. - self._deepcopy_hooks: List[Callable] = [] + self._deepcopy_hooks: list[Callable] = [] # TorchScript breaks trying to compile the graph setter because of the # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842 @@ -739,7 +739,7 @@ class {module_name}(torch.nn.Module): This method can be called to clean up an ``nn.Module`` without manually calling ``delete_submodule`` on each unused submodule. """ - used: List[str] = [] + used: list[str] = [] for node in self.graph.nodes: if node.op == "call_module" or node.op == "get_attr": diff --git a/torch/fx/immutable_collections.py b/torch/fx/immutable_collections.py index 2ff29cba474d..484f9c18f628 100644 --- a/torch/fx/immutable_collections.py +++ b/torch/fx/immutable_collections.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs -from typing import Any, Dict, Iterable, List, Tuple +from collections.abc import Iterable +from typing import Any from torch.utils._pytree import ( _dict_flatten, @@ -79,25 +80,25 @@ compatibility(is_backward_compatible=True)(immutable_dict) # Register immutable collections for PyTree operations -def _immutable_dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: +def _immutable_dict_flatten(d: dict[Any, Any]) -> tuple[list[Any], Context]: return _dict_flatten(d) def _immutable_dict_unflatten( values: Iterable[Any], context: Context, -) -> Dict[Any, Any]: +) -> dict[Any, Any]: return immutable_dict(_dict_unflatten(values, context)) -def _immutable_list_flatten(d: List[Any]) -> Tuple[List[Any], Context]: +def _immutable_list_flatten(d: list[Any]) -> tuple[list[Any], Context]: return _list_flatten(d) def _immutable_list_unflatten( values: Iterable[Any], context: Context, -) -> List[Any]: +) -> list[Any]: return immutable_list(_list_unflatten(values, context)) diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index aa24eb4bed1a..9e2563b756c5 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import inspect from contextlib import contextmanager -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Optional, TYPE_CHECKING, Union import torch import torch.fx.traceback as fx_traceback @@ -17,6 +17,10 @@ from .node import Argument, map_aggregate, map_arg, Node, Target from .proxy import Proxy +if TYPE_CHECKING: + from collections.abc import Iterator + + __all__ = ["Interpreter", "Transformer"] @@ -92,7 +96,7 @@ class Interpreter: self.graph = graph else: self.graph = self.module.graph # type: ignore[assignment] - self.env: Dict[Node, Any] = {} + self.env: dict[Node, Any] = {} self.name = "Interpreter" 
self.garbage_collect_values = garbage_collect_values self.extra_traceback = True @@ -102,8 +106,8 @@ class Interpreter: # of a given node. This represents the *last* use of the node in the # execution order of the program, which we will use to free unused # values - node_to_last_use: Dict[Node, Node] = {} - self.user_to_last_uses: Dict[Node, List[Node]] = {} + node_to_last_use: dict[Node, Node] = {} + self.user_to_last_uses: dict[Node, list[Node]] = {} def register_last_uses(n: Node, user: Node): if n not in node_to_last_use: @@ -118,7 +122,7 @@ class Interpreter: def run( self, *args, - initial_env: Optional[Dict[Node, Any]] = None, + initial_env: Optional[dict[Node, Any]] = None, enable_io_processing: bool = True, ) -> Any: """ @@ -232,7 +236,7 @@ class Interpreter: # Main Node running APIs @compatibility(is_backward_compatible=True) def placeholder( - self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] ) -> Any: """ Execute a ``placeholder`` node. Note that this is stateful: @@ -268,7 +272,7 @@ class Interpreter: @compatibility(is_backward_compatible=True) def get_attr( - self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] ) -> Any: """ Execute a ``get_attr`` node. Will retrieve an attribute @@ -289,7 +293,7 @@ class Interpreter: @compatibility(is_backward_compatible=True) def call_function( - self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] ) -> Any: """ Execute a ``call_function`` node and return the result. @@ -311,7 +315,7 @@ class Interpreter: @compatibility(is_backward_compatible=True) def call_method( - self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] ) -> Any: """ Execute a ``call_method`` node and return the result. @@ -335,7 +339,7 @@ class Interpreter: @compatibility(is_backward_compatible=True) def call_module( - self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] ) -> Any: """ Execute a ``call_module`` node and return the result. @@ -360,7 +364,7 @@ class Interpreter: @compatibility(is_backward_compatible=True) def output( - self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] ) -> Any: """ Execute an ``output`` node. This really just retrieves @@ -401,7 +405,7 @@ class Interpreter: return attr_itr @compatibility(is_backward_compatible=True) - def fetch_args_kwargs_from_env(self, n: Node) -> Tuple[Tuple, Dict]: + def fetch_args_kwargs_from_env(self, n: Node) -> tuple[tuple, dict]: """ Fetch the concrete values of ``args`` and ``kwargs`` of node ``n`` from the current execution environment. 
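A brief, hypothetical example of the per-opcode dispatch shown above: subclass `Interpreter`, override one handler, and delegate the rest to the base class.

```python
import torch
from torch.fx import Interpreter, symbolic_trace

class TracingInterpreter(Interpreter):
    # Log every call_function node as it executes; args arrive as concrete
    # values fetched from self.env (see fetch_args_kwargs_from_env above).
    def call_function(self, target, args, kwargs):
        result = super().call_function(target, args, kwargs)
        name = getattr(target, "__name__", str(target))
        print(f"ran {name}, result type: {type(result).__name__}")
        return result

def f(x):
    return torch.relu(x) * 2

gm = symbolic_trace(f)
TracingInterpreter(gm).run(torch.randn(3))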
@@ -497,7 +501,7 @@ class Transformer(Interpreter): def __init__(self, graph: Graph): super().__init__() self.graph = graph - self.tensor_attrs: Dict[torch.Tensor, str] = {} # type: ignore[assignment] + self.tensor_attrs: dict[torch.Tensor, str] = {} # type: ignore[assignment] def is_leaf_module(self, _, __) -> bool: return True @@ -507,7 +511,7 @@ class Transformer(Interpreter): @compatibility(is_backward_compatible=True) def placeholder( - self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] ) -> Proxy: """ Execute a ``placeholder`` node. In ``Transformer``, this is @@ -529,7 +533,7 @@ class Transformer(Interpreter): @compatibility(is_backward_compatible=True) def get_attr( - self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] ) -> Proxy: """ Execute a ``get_attr`` node. In ``Transformer``, this is @@ -548,7 +552,7 @@ class Transformer(Interpreter): @compatibility(is_backward_compatible=True) def call_module( - self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] ) -> Any: # Override so that the leaf module policy from `self.tracer` is respected. assert isinstance(target, str) @@ -557,7 +561,7 @@ class Transformer(Interpreter): @compatibility(is_backward_compatible=True) def call_function( - self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] ) -> Any: # Override so that functions that were wrapped are still wrapped. return self.tracer.create_proxy("call_function", target, args, kwargs) diff --git a/torch/fx/node.py b/torch/fx/node.py index 2f03341b6af0..b2aa020316c6 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -3,19 +3,8 @@ import builtins import inspect import types import warnings -from typing import ( - Any, - Callable, - Dict, - List, - Mapping, - Optional, - Sequence, - Set, - Tuple, - TYPE_CHECKING, - Union, -) +from collections.abc import Mapping, Sequence +from typing import Any, Callable, Optional, TYPE_CHECKING, Union import torch from torch._C import _NodeBase @@ -57,7 +46,7 @@ Target = Union[Callable[..., Any], str] Argument = Optional[ Union[ - Tuple["Argument", ...], + tuple["Argument", ...], Sequence["Argument"], Mapping[str, "Argument"], slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing @@ -79,7 +68,7 @@ _legal_ops = dict.fromkeys( ] ) -_side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = { +_side_effectful_need_to_be_preserved_pre_dispatch: set[Callable] = { torch._C._set_grad_enabled, torch.amp._enter_autocast, torch.amp._exit_autocast, @@ -87,7 +76,7 @@ _side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = { # TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs, # or add logic to correctly mark all inplace ops as side effectful. -_side_effectful_functions: Set[Callable] = { +_side_effectful_functions: set[Callable] = { torch._assert, torch._assert_async, _ops.aten._assert_async.msg, @@ -227,18 +216,18 @@ class Node(_NodeBase): in the Graph printout. """ - _args: Tuple["Argument", ...] - _kwargs: Dict[str, "Argument"] + _args: tuple["Argument", ...] 
+ _kwargs: dict[str, "Argument"] graph: "Graph" name: str op: str target: "Target" - _input_nodes: Dict["Node", None] - users: Dict["Node", None] + _input_nodes: dict["Node", None] + users: dict["Node", None] type: Optional[Any] _sort_key: Any _repr_fn: Optional[Callable[["Node"], str]] - meta: Dict[str, Any] + meta: dict[str, Any] @compatibility(is_backward_compatible=True) def __init__( @@ -247,8 +236,8 @@ class Node(_NodeBase): name: str, op: str, target: "Target", - args: Tuple["Argument", ...], - kwargs: Dict[str, "Argument"], + args: tuple["Argument", ...], + kwargs: dict[str, "Argument"], return_type: Optional[Any] = None, ) -> None: """ @@ -339,14 +328,14 @@ class Node(_NodeBase): # transformations. This metadata is preserved across node copies assign(self, "meta", {}) - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() state["_erased"] = self._erased state["_prev"] = self._prev state["_next"] = self._next return state - def __setstate__(self, state: Dict[str, Any]) -> None: + def __setstate__(self, state: dict[str, Any]) -> None: _erased = state.pop("_erased") _prev = state.pop("_prev") _next = state.pop("_next") @@ -442,7 +431,7 @@ class Node(_NodeBase): p._next, n._prev = n, p @property - def args(self) -> Tuple[Argument, ...]: + def args(self) -> tuple[Argument, ...]: """ The tuple of arguments to this ``Node``. The interpretation of arguments depends on the node's opcode. See the :class:`Node` docstring for more @@ -454,7 +443,7 @@ class Node(_NodeBase): return self._args @args.setter - def args(self, a: Tuple[Argument, ...]) -> None: + def args(self, a: tuple[Argument, ...]) -> None: """ Set the tuple of arguments to this Node. The interpretation of arguments depends on the node's opcode. See the ``fx.Graph`` docstring for more @@ -465,7 +454,7 @@ class Node(_NodeBase): self.__update_args_kwargs(a, self._kwargs) @property - def kwargs(self) -> Dict[str, Argument]: + def kwargs(self) -> dict[str, Argument]: """ The dict of keyword arguments to this ``Node``. The interpretation of arguments depends on the node's opcode. See the :class:`Node` docstring for more @@ -477,7 +466,7 @@ class Node(_NodeBase): return self._kwargs @kwargs.setter - def kwargs(self, k: Dict[str, Argument]) -> None: + def kwargs(self, k: dict[str, Argument]) -> None: """ Set the dict of kwargs to this Node. The interpretation of arguments depends on the node's opcode. See the ``fx.Graph`` docstring for more @@ -488,7 +477,7 @@ class Node(_NodeBase): self.__update_args_kwargs(self._args, k) @property - def all_input_nodes(self) -> List["Node"]: + def all_input_nodes(self) -> list["Node"]: """ Return all Nodes that are inputs to this Node. This is equivalent to iterating over ``args`` and ``kwargs`` and only collecting the values that @@ -534,7 +523,7 @@ class Node(_NodeBase): self._args = args_left + (arg,) + args_right - _new_input_nodes: Dict[Node, None] = {} + _new_input_nodes: dict[Node, None] = {} map_arg(arg, _new_input_nodes.setdefault) for new_use in _new_input_nodes.keys(): @@ -574,7 +563,7 @@ class Node(_NodeBase): self.meta["stack_trace"] = trace def __update_args_kwargs( - self, new_args: Tuple["Argument", ...], new_kwargs: Dict[str, "Argument"] + self, new_args: tuple["Argument", ...], new_kwargs: dict[str, "Argument"] ) -> None: """ This API is internal. Do *not* call it directly. 
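As a concrete illustration of the bookkeeping above (a sketch, not from the patch): a tiny pass that retargets nodes, relying on the `args`/`kwargs` setters and `__update_args_kwargs` to keep `users` and `_input_nodes` consistent.

```python
import torch
from torch.fx import symbolic_trace

def add_to_mul(gm):
    # Retarget every torch.add call; mutations to node.target/args/kwargs go
    # through the machinery above, so use lists stay in sync automatically.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.add:
            node.target = torch.mul
    gm.graph.lint()
    gm.recompile()
    return gm

gm = add_to_mul(symbolic_trace(lambda x: torch.add(x, x)))
```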
@@ -634,8 +623,8 @@ class Node(_NodeBase):
     @compatibility(is_backward_compatible=True)
     def format_node(
         self,
-        placeholder_names: Optional[List[str]] = None,
-        maybe_return_typename: Optional[List[str]] = None,
+        placeholder_names: Optional[list[str]] = None,
+        maybe_return_typename: Optional[list[str]] = None,
     ) -> Optional[str]:
         """
         Return a descriptive string representation of ``self``.
@@ -704,7 +693,7 @@ class Node(_NodeBase):
         delete_user_cb: Callable[["Node"], bool] = lambda user: True,
         *,
         propagate_meta: bool = False,
-    ) -> List["Node"]:
+    ) -> list["Node"]:
         """
         Replace all uses of ``self`` in the Graph with the Node ``replace_with``.
 
@@ -775,7 +764,7 @@ class Node(_NodeBase):
             # impure since it mutates inputs
             return True
 
-        tags: Optional[List[torch.Tag]] = getattr(self.target, "_tags", None)
+        tags: Optional[list[torch.Tag]] = getattr(self.target, "_tags", None)
         if tags is not None and torch.Tag.nondeterministic_seeded in tags:
             # impure since it mutates RNG state
             return True
@@ -799,8 +788,8 @@ class Node(_NodeBase):
     def normalized_arguments(
         self,
         root: torch.nn.Module,
-        arg_types: Optional[Tuple[Any]] = None,
-        kwarg_types: Optional[Dict[str, Any]] = None,
+        arg_types: Optional[tuple[Any]] = None,
+        kwarg_types: Optional[dict[str, Any]] = None,
         normalize_to_only_use_kwargs: bool = False,
     ) -> Optional[ArgsKwargsPair]:
         """
diff --git a/torch/fx/operator_schemas.py b/torch/fx/operator_schemas.py
index f654b6c060e8..c9319726318a 100644
--- a/torch/fx/operator_schemas.py
+++ b/torch/fx/operator_schemas.py
@@ -5,17 +5,7 @@ import numbers
 import types
 import typing
 import warnings
-from typing import (
-    Any,
-    Callable,
-    cast,
-    Dict,
-    List,
-    NamedTuple,
-    Optional,
-    Tuple,
-    TYPE_CHECKING,
-)
+from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING
 
 import torch
 from torch._jit_internal import boolean_dispatched
@@ -44,11 +34,11 @@ class ArgsKwargsPair(NamedTuple):
     Simple named tuple for wrapping args/kwargs pairs.
     """
 
-    args: Tuple[Any, ...]
-    kwargs: Dict[str, Any]
+    args: tuple[Any, ...]
+    kwargs: dict[str, Any]
 
 
-_manual_overrides: Dict[Callable, List[inspect.Signature]] = {}
+_manual_overrides: dict[Callable, list[inspect.Signature]] = {}
 
 
 def _nonzero_schemas():
@@ -108,7 +98,7 @@ def _torchscript_schema_to_signature_impl(
 ) -> inspect.Signature:
     from inspect import Parameter
 
-    parameters: List[Parameter] = []
+    parameters: list[Parameter] = []
     for arg in ts_schema.arguments:
         arg_type = _torchscript_type_to_python_type(arg.type)
         default = arg.default_value if arg.has_default_value() else Parameter.empty
@@ -154,7 +144,7 @@ def _torchscript_schema_to_signature_impl(
     return inspect.Signature(parameters, return_annotation=return_type)
 
 
-_SCHEMA_TO_SIGNATURE_CACHE: Dict[Tuple[str, str], inspect.Signature] = {}
+_SCHEMA_TO_SIGNATURE_CACHE: dict[tuple[str, str], inspect.Signature] = {}
 
 
 def _torchscript_schema_to_signature(
@@ -173,7 +163,7 @@ def _torchscript_schema_to_signature(
 
 @compatibility(is_backward_compatible=False)
 def check_for_mutable_operation(
-    target: Callable, args: Tuple["Argument", ...], kwargs: Dict[str, "Argument"]
+    target: Callable, args: tuple["Argument", ...], kwargs: dict[str, "Argument"]
 ):
     signatures, schemas = get_signature_for_torch_op(target, return_schemas=True)
 
@@ -265,12 +255,12 @@ def create_type_hint(x):
         if isinstance(x, list):
 
             def ret_type(x):
-                return List[x]  # type: ignore[valid-type]
+                return list[x]  # type: ignore[valid-type]
 
         else:
 
             def ret_type(x):
-                return Tuple[x, ...]
+                return tuple[x, ...]  # type: ignore[valid-type]
 
         if len(x) == 0:
             return ret_type(Any)
@@ -291,6 +281,10 @@ def create_type_hint(x):
     return x
 
 
+_LIST_TYPES = (list, typing.List)  # noqa: UP006
+_TUPLE_TYPES = (tuple, typing.Tuple)  # noqa: UP006
+
+
 @compatibility(is_backward_compatible=False)
 def type_matches(signature_type: Any, argument_type: Any):
     sig_origin_type = getattr(signature_type, "__origin__", signature_type)
@@ -304,22 +298,24 @@ def type_matches(signature_type: Any, argument_type: Any):
         sig_contained = signature_type.__args__
         return any(type_matches(c, argument_type) for c in sig_contained)
 
-    if signature_type is List[int] and argument_type is int:
+    if signature_type is typing.List[int] and argument_type is int:  # noqa: UP006
         # int can be promoted to List[int]
         return True
 
-    if getattr(signature_type, "__origin__", None) in {list, List}:
+    if getattr(signature_type, "__origin__", None) in _LIST_TYPES:
         sig_el_type = signature_type.__args__[0]
 
+        if sig_el_type is argument_type:
+            return True
         if not inspect.isclass(sig_el_type):
             warnings.warn(
                 f"Does not support nested parametric types, got {signature_type}. Please file a bug."
             )
             return False
 
-        if getattr(argument_type, "__origin__", None) in {list, List}:
+        if getattr(argument_type, "__origin__", None) in _LIST_TYPES:
             return issubclass(argument_type.__args__[0], sig_el_type)
 
     def is_homogeneous_tuple(t):
-        if getattr(t, "__origin__", None) not in {tuple, Tuple}:
+        if getattr(t, "__origin__", None) not in _TUPLE_TYPES:
             return False
         contained = t.__args__
         if t.__args__ == ((),):  # Tuple[()].__args__ == ((),) for some reason
@@ -344,10 +340,10 @@ def type_matches(signature_type: Any, argument_type: Any):
 
 @compatibility(is_backward_compatible=False)
 def normalize_function(
     target: Callable,
-    args: Tuple[Any],
-    kwargs: Optional[Dict[str, Any]] = None,
-    arg_types: Optional[Tuple[Any]] = None,
-    kwarg_types: Optional[Dict[str, Any]] = None,
+    args: tuple[Any],
+    kwargs: Optional[dict[str, Any]] = None,
+    arg_types: Optional[tuple[Any]] = None,
+    kwarg_types: Optional[dict[str, Any]] = None,
     normalize_to_only_use_kwargs: bool = False,
 ) -> Optional[ArgsKwargsPair]:
     """
@@ -424,7 +420,7 @@ def normalize_function(
             )
         else:
             if arg_types is not None or kwarg_types is not None:
-                arg_types = arg_types if arg_types else cast(Tuple[Any], ())
+                arg_types = arg_types if arg_types else cast(tuple[Any], ())
                 kwarg_types = kwarg_types if kwarg_types else {}
                 for candidate_signature in torch_op_schemas:
                     sig_matches = True
@@ -468,8 +464,8 @@ def normalize_function(
 def normalize_module(
     root: torch.nn.Module,
     target: str,
-    args: Tuple[Any],
-    kwargs: Optional[Dict[str, Any]] = None,
+    args: tuple[Any],
+    kwargs: Optional[dict[str, Any]] = None,
     normalize_to_only_use_kwargs: bool = False,
 ) -> Optional[ArgsKwargsPair]:
     """
@@ -513,8 +509,8 @@ def normalize_module(
 
 def _args_kwargs_to_normalized_args_kwargs(
     sig: inspect.Signature,
-    args: Tuple[Any, ...],
-    kwargs: Dict[str, Any],
+    args: tuple[Any, ...],
+    kwargs: dict[str, Any],
     normalize_to_only_use_kwargs: bool,
 ) -> Optional[ArgsKwargsPair]:
     """
@@ -552,8 +548,8 @@ def _args_kwargs_to_normalized_args_kwargs(
     bound_args = sig.bind(*args, **kwargs)
     bound_args.apply_defaults()
 
-    new_kwargs: Dict[str, Any] = {}
-    new_args: List[Any] = []
+    new_kwargs: dict[str, Any] = {}
+    new_args: list[Any] = []
     for i, param in enumerate(sig.parameters):
         if not normalize_to_only_use_kwargs and i < len(args):
             new_args.append(bound_args.arguments[param])
diff --git a/torch/fx/passes/_tensorify_python_scalars.py b/torch/fx/passes/_tensorify_python_scalars.py
index b8eb4c4ce628..a7a2cdfcb17e 100644
--- a/torch/fx/passes/_tensorify_python_scalars.py
+++ b/torch/fx/passes/_tensorify_python_scalars.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import logging
 import os
-from typing import Any, List, Set, Union
+from typing import Any, Union
 
 from sympy import Integer, Number, Symbol
 from sympy.logic.boolalg import BooleanAtom
@@ -28,7 +28,7 @@ from torch.utils._sympy.reference import TensorReferenceAnalysis
 from torch.utils._sympy.symbol import symbol_is_type, SymT
 
 
-__all__: List[str] = []
+__all__: list[str] = []
 
 log = logging.getLogger(__name__)
 graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
@@ -242,7 +242,7 @@ def tensorify_python_scalars(
             if node.op == "call_function" and (
                 replacement_op := SUPPORTED_OPS.get(node.target)
             ):
-                args: List[Any] = []
+                args: list[Any] = []
                 transform = False
                 compute_dtype = get_computation_dtype(node.meta["val"].dtype)
 
@@ -299,7 +299,7 @@ def tensorify_python_scalars(
                     "tensorify_float_success", True, overwrite=True
                 )
 
-    failed_tensorify_ops: Set[str] = set()
+    failed_tensorify_ops: set[str] = set()
 
     # Now do one more pass that specializes all symfloats we didn't manage
     # to tensorify away.
diff --git a/torch/fx/passes/dialect/common/cse_pass.py b/torch/fx/passes/dialect/common/cse_pass.py
index 6a501f041d19..e5889375bb07 100644
--- a/torch/fx/passes/dialect/common/cse_pass.py
+++ b/torch/fx/passes/dialect/common/cse_pass.py
@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing import Any, Dict, Tuple
+from typing import Any
 
 import torch
 from torch.fx import Graph, GraphModule, Node
@@ -90,14 +90,14 @@ class CSEPass(PassBase):
         modified = False
 
         new_graph = Graph()
-        env: Dict[
+        env: dict[
             Node, Node
         ] = {}  # map from node in the old graph to node in the new graph
-        hash_env: Dict[
-            Tuple[torch._ops.OpOverload, int], Node
+        hash_env: dict[
+            tuple[torch._ops.OpOverload, int], Node
         ] = {}  # map from hash to a node in the new graph
-        token_map: Dict[
-            Tuple[torch._ops.OpOverload, int], Dict[str, Any]
+        token_map: dict[
+            tuple[torch._ops.OpOverload, int], dict[str, Any]
         ] = {}  # map from hash to token
         for n in graph_module.graph.nodes:
             # The placeholder, output, and get_attr nodes are copied to the new graph without change
diff --git a/torch/fx/passes/graph_drawer.py b/torch/fx/passes/graph_drawer.py
index d8b95231f891..ab696837bb17 100644
--- a/torch/fx/passes/graph_drawer.py
+++ b/torch/fx/passes/graph_drawer.py
@@ -2,7 +2,7 @@
 
 import hashlib
 from itertools import chain
-from typing import Any, Dict, Optional, TYPE_CHECKING
+from typing import Any, Optional, TYPE_CHECKING
 
 import torch
 import torch.fx
@@ -150,10 +150,10 @@ if HAS_PYDOT:
         def get_submod_dot_graph(self, submod_name) -> pydot.Dot:
             return self._dot_graphs[f"{self._name}_{submod_name}"]
 
-        def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]:
+        def get_all_dot_graphs(self) -> dict[str, pydot.Dot]:
             return self._dot_graphs
 
-        def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]:
+        def _get_node_style(self, node: torch.fx.Node) -> dict[str, str]:
             template = {
                 "shape": self.dot_graph_shape,
                 "fillcolor": "#CAFFE3",
diff --git a/torch/fx/passes/graph_manipulation.py b/torch/fx/passes/graph_manipulation.py
index ce9904fc500e..f559aa0bfcb3 100644
--- a/torch/fx/passes/graph_manipulation.py
+++ b/torch/fx/passes/graph_manipulation.py
@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing import Any, Dict, List, NamedTuple, Optional
+from typing import Any, NamedTuple, Optional
 
 import torch
 from torch.fx._compatibility import compatibility
@@ -29,7 +29,7 @@ def replace_target_nodes_with(
     """Modifies all nodes in fx_module.graph.nodes which match the specified op code and target,
     and updates them to match the new op code and target"""
     new_graph = Graph()
-    val_map: Dict[Node, Node] = {}
+    val_map: dict[Node, Node] = {}
     for node in fx_module.graph.nodes:
         if node.op == old_op and node.target == old_target:
             args = map_arg(node.args, lambda n: val_map[n])
@@ -52,7 +52,7 @@ class size_bytes(NamedTuple):
 
 @compatibility(is_backward_compatible=False)
 def get_size_of_all_nodes(
-    fx_module: GraphModule, args: Optional[List[torch.Tensor]] = None
+    fx_module: GraphModule, args: Optional[list[torch.Tensor]] = None
 ) -> None:
     """Given a fx graph module, update each node with its total size (weights + bias + output)
     and its output_size(output). For a non-module node, the total size is the output size.
diff --git a/torch/fx/passes/graph_transform_observer.py b/torch/fx/passes/graph_transform_observer.py
index d72a7599f349..e19abc7ad3d8 100644
--- a/torch/fx/passes/graph_transform_observer.py
+++ b/torch/fx/passes/graph_transform_observer.py
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import os
-from typing import Callable, Dict, List, Optional, Set, TypeVar
+from typing import Callable, Optional, TypeVar
 
 from torch.fx import Graph, Node
 from torch.fx._compatibility import compatibility
@@ -45,11 +45,11 @@ class GraphTransformObserver:
         self.active = trace.enabled or self.log_url is not None
 
         if self.active:
-            self.erased_nodes: Set[str] = set()
-            self.created_nodes: Set[str] = set()
-            self.name_to_node: Dict[str, Node] = {}
+            self.erased_nodes: set[str] = set()
+            self.created_nodes: set[str] = set()
+            self.name_to_node: dict[str, Node] = {}
             # record graph modules deepcopied from self.gm, so we can remove hooks on them when exiting the context
-            self.copied_gms: List[GraphModule] = []
+            self.copied_gms: list[GraphModule] = []
 
             self._node_creation_hook = self.get_node_creation_hook()
             self._node_erase_hook = self.get_node_erase_hook()
diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py
index ca5848f1ff04..7867a0a7a6ae 100644
--- a/torch/fx/passes/infra/partitioner.py
+++ b/torch/fx/passes/infra/partitioner.py
@@ -2,8 +2,9 @@
 import collections
 import itertools
 import logging
+from collections.abc import Iterable, Sequence
 from copy import copy
-from typing import Dict, Iterable, List, Optional, Sequence, Set
+from typing import Optional
 
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import _get_qualified_name, Node
@@ -52,10 +53,10 @@ class _DependencyViewer:
                 self.downstreams[node].add(output_node)
                 self.downstreams[node].update(self.downstreams[output_node])
 
-    def downstreams_of(self, node: Node) -> Set[Node]:
+    def downstreams_of(self, node: Node) -> set[Node]:
         return self.downstreams[node]
 
-    def upstreams_of(self, node: Node) -> Set[Node]:
+    def upstreams_of(self, node: Node) -> set[Node]:
         return self.upstreams[node]
 
 
@@ -84,21 +85,21 @@ class CapabilityBasedPartitioner:
             dict(self.graph_module.named_modules()), node
         )
 
-    def propose_partitions(self) -> List[Partition]:
+    def propose_partitions(self) -> list[Partition]:
         # partition_map is a mapping from partition id to a set of partition id's.
         # The value set contains all the partition ids that can be reached by doing a
         # DFS starting from the partition id in the key.
-        partition_map: Dict[int, Set] = collections.defaultdict(set)
+        partition_map: dict[int, set] = collections.defaultdict(set)
 
         # assumptions: nodes in the candidate list are sorted in topological order
-        assignment: Dict[Node, int] = {}  # mapping from node to partition_id
-        partitions_by_id: Dict[
+        assignment: dict[Node, int] = {}  # mapping from node to partition_id
+        partitions_by_id: dict[
             int, Partition
         ] = {}  # mapping from partition_id to partition
-        nodes_order: Dict[
+        nodes_order: dict[
             Node, int
         ] = {}  # mapping from nodes to reversed topological order
-        partitions_order: Dict[
+        partitions_order: dict[
             int, int
         ] = {}  # mapping from partition_id to minimum topo order of nodes in partition
         new_partition_id = itertools.count()
@@ -111,7 +112,7 @@ class CapabilityBasedPartitioner:
             merged_nodes = copy(partitions_by_id[self_id].nodes)
             merged_nodes.update(partitions_by_id[other_id].nodes)
 
-            def dfs_iter_find_cycle(all_user_nodes: Set[Node]):
+            def dfs_iter_find_cycle(all_user_nodes: set[Node]):
                 for user_node in all_user_nodes:
                     visited_partition_ids = set()
 
@@ -210,7 +211,7 @@ class CapabilityBasedPartitioner:
 
         for node in reversed(self.graph_module.graph.nodes):
             # use Dict as an ordered set to ensure deterministic partitioning result, don't care value
-            merge_candidates: Dict[int, None] = {}
+            merge_candidates: dict[int, None] = {}
 
             # Note a limited horizontal fusion is enabled:
             # when `node` is not supported, the code below attempts to fuse consumers of `node`.
@@ -241,7 +242,7 @@ class CapabilityBasedPartitioner:
 
         # post processing to re-assign "getitem" nodes into upstream partition
        logger.debug("Reassigning getitem nodes to their producer node's partition...")
-        nodes_reassignment: Dict[Node, int] = {}
+        nodes_reassignment: dict[Node, int] = {}
         for node in self.graph_module.graph.nodes:
             is_tuple_output = True
             for user in node.users:
@@ -266,7 +267,7 @@ class CapabilityBasedPartitioner:
             logger.debug("Filtering out single node partitions...")
             default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
             non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
-            partitions_to_remove: List[int] = []
+            partitions_to_remove: list[int] = []
             for id, partition in partitions_by_id.items():
                 compute_node_count = 0
                 for node in partition.nodes:
@@ -295,7 +296,7 @@ class CapabilityBasedPartitioner:
         ]
 
     def fuse_partitions(
-        self, partitions: List[Partition], prefix: str = "fused_"
+        self, partitions: list[Partition], prefix: str = "fused_"
     ) -> GraphModule:
         logger.debug("Fusing partitions...")
         # fuse_by_partitions expects partitions in List[Dict[Node, None]]: [ {node0 : None}, {node1 : None} ]
@@ -306,7 +307,7 @@ class CapabilityBasedPartitioner:
         )
 
     # remove non-compute-ops that sit at the boundary of a partition.
-    def remove_bookend_non_compute_ops(self, partitions: List[Partition]):
+    def remove_bookend_non_compute_ops(self, partitions: list[Partition]):
         non_compute_ops = set(self.non_compute_ops)
 
         def is_non_compute_node(node: Node):
@@ -316,11 +317,11 @@ class CapabilityBasedPartitioner:
             )
 
         # cache transparent nodes
-        transparent_input_nodes: Dict[Node, bool] = {}
-        transparent_output_nodes: Dict[Node, bool] = {}
+        transparent_input_nodes: dict[Node, bool] = {}
+        transparent_output_nodes: dict[Node, bool] = {}
 
         def is_transparent_input_node(
-            node: Node, partition: Set[Node], removed_nodes: Set[Node]
+            node: Node, partition: set[Node], removed_nodes: set[Node]
         ):
             if (
                 node.op == "placeholder"
@@ -341,7 +342,7 @@ class CapabilityBasedPartitioner:
             return False
 
         def is_transparent_output_node(
-            node: Node, partition: Set[Node], removed_nodes: Set[Node]
+            node: Node, partition: set[Node], removed_nodes: set[Node]
         ):
             if (
                 node.op == "placeholder"
@@ -367,7 +368,7 @@ class CapabilityBasedPartitioner:
             # Note it's ok to use `set` here, since we only query whether a node
             # has been removed. We are NEVER going to iterate on nodes inside
             # the set.
-            remove_node: Set[Node] = set()
+            remove_node: set[Node] = set()
             for node in partition.nodes:
                 if is_non_compute_node(node) and (
                     is_transparent_input_node(node, set(partition.nodes), remove_node)
diff --git a/torch/fx/passes/infra/pass_manager.py b/torch/fx/passes/infra/pass_manager.py
index cea5f4f25c77..68753d9351f1 100644
--- a/torch/fx/passes/infra/pass_manager.py
+++ b/torch/fx/passes/infra/pass_manager.py
@@ -3,7 +3,7 @@ import inspect
 import logging
 from functools import wraps
 from queue import Queue
-from typing import Callable, Dict, List
+from typing import Callable
 
 import torch.nn as nn
 from torch.fx._compatibility import compatibility
@@ -50,7 +50,7 @@ def pass_result_wrapper(fn: Callable) -> Callable:
 
 
 def _validate_pass_schedule_constraint(
-    constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
+    constraint: Callable[[Callable, Callable], bool], passes: list[Callable]
 ) -> None:
     for i, a in enumerate(passes):
         for j, b in enumerate(passes[i + 1 :]):
@@ -64,8 +64,8 @@ def _validate_pass_schedule_constraint(
 
 
 def _topological_sort_passes(
-    passes: List[Callable], constraints: List[Callable]
-) -> List[Callable]:
+    passes: list[Callable], constraints: list[Callable]
+) -> list[Callable]:
     """
     Args
         passes: Passes that we are ordering
@@ -79,8 +79,8 @@
         return passes
 
     # Construct a graph mapping nodes to a list of their users
-    graph: Dict[Callable, List[Callable]] = {p: [] for p in passes}
-    indegree_map: Dict[Callable, int] = dict.fromkeys(passes, 0)
+    graph: dict[Callable, list[Callable]] = {p: [] for p in passes}
+    indegree_map: dict[Callable, int] = dict.fromkeys(passes, 0)
     candidates: Queue = Queue()
     for a in passes:
         for b in passes:
@@ -95,8 +95,8 @@
         if indegree_map[a] == 0:
             candidates.put(a)
 
-    visited: Dict[Callable, bool] = dict.fromkeys(passes, False)
-    sorted_passes: List[Callable] = []
+    visited: dict[Callable, bool] = dict.fromkeys(passes, False)
+    sorted_passes: list[Callable] = []
 
     while not candidates.empty():
         p = candidates.get()
@@ -169,8 +169,8 @@ class PassManager:
             checks
     """
 
-    passes: List[Callable[[nn.Module], PassResult]]
-    constraints: List[Callable[[Callable, Callable], bool]]
+    passes: list[Callable[[nn.Module], PassResult]]
+    constraints: list[Callable[[Callable, Callable], bool]]
     _validated: bool = False
     steps: int = 1
diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py
index c349c896ac3e..590dcde8152f 100644
--- a/torch/fx/passes/net_min_base.py
+++ b/torch/fx/passes/net_min_base.py
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import logging
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Optional
 
 import torch
 import torch.fx
@@ -106,7 +106,7 @@ class _MinimizerBase:
         module: torch.fx.GraphModule,
         sample_input: Tensors,
         compare_fn: Callable[
-            [TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool]
+            [TensorOrTensors, TensorOrTensors, Names], tuple[float, bool]
         ],
         settings: _MinimizerSettingBase,
         module_exporter: Optional[
@@ -124,16 +124,16 @@ class _MinimizerBase:
         self.exclusion_fn = exclusion_fn
 
         # Stores outputs of run_a function
-        self.a_outputs: Dict[str, Any] = {}
+        self.a_outputs: dict[str, Any] = {}
 
         # Stores outputs of run_b function
-        self.b_outputs: Dict[str, Any] = {}
+        self.b_outputs: dict[str, Any] = {}
 
         # Stores the results of compare_fn
-        self.results: Dict[Any, Any] = {}
+        self.results: dict[Any, Any] = {}
 
         # Stores the report for the runs
-        self.reports: List[List[str]] = []
+        self.reports: list[list[str]] = []
 
         # Current iteration
         self.iteration: int = 0
@@ -205,7 +205,7 @@ class _MinimizerBase:
 
     def _get_submod_inputs(
         self, main_module: torch.fx.GraphModule, submod_path: str
-    ) -> Tuple[Tensors, Tensors]:
+    ) -> tuple[Tensors, Tensors]:
         """
         Try to get submodule inputs from stored outputs. If they are not found, then use
         torch_glow.get_submod_inputs to get the inputs.
@@ -280,7 +280,7 @@ class _MinimizerBase:
         else:
             node.tag = "main_0"
 
-    def _build_submodule(self, nodes: NodeSet) -> Tuple[torch.fx.GraphModule, str]:
+    def _build_submodule(self, nodes: NodeSet) -> tuple[torch.fx.GraphModule, str]:
         """
         Split self.module so that one submodule consists of `nodes` and only `nodes`.
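The `compare_fn` parameter above is annotated as returning `tuple[float, bool]`; a hypothetical implementation, for context, showing the shape of that contract:

```python
import torch

def rel_err_compare(result_a, result_b, names) -> tuple[float, bool]:
    # Minimizers call this with the outputs of run_a/run_b plus the node
    # names being bisected; reduce each side to one tensor and score it.
    a = result_a[0] if isinstance(result_a, (list, tuple)) else result_a
    b = result_b[0] if isinstance(result_b, (list, tuple)) else result_b
    err = (a.float() - b.float()).abs().max().item()
    return err, err < 1e-2
```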
@@ -412,7 +412,7 @@ class _MinimizerBase:
         culprits: NodeSet = set()
         nodes: NodeList = all_nodes[start_idx:end_idx]
 
-        report: List[str] = []
+        report: list[str] = []
         if self.exclusion_fn is not None:
             self.exclusion_fn(nodes, start_idx, end_idx)
             if len(nodes) == 0:
@@ -484,7 +484,7 @@ class _MinimizerBase:
         culprits: NodeSet = set()
 
         for node in nodes:
-            report: List[str] = []
+            report: list[str] = []
             self.reports.append(report)
             self.iteration += 1
             report.append(f"Sequential traverse iteration {self.iteration}.")
@@ -534,7 +534,7 @@ class _MinimizerBase:
             find_last_node: If True, search for the last node which results in a numerics
                 difference; if False, find the first node in the sorted node list
         """
-        report: List[str] = []
+        report: list[str] = []
 
         mid = (start_idx + end_idx) // 2
         cur_nodes_list: NodeList = nodes[: mid + 1] if find_last_node else nodes[mid:]
@@ -726,7 +726,7 @@ class _MinimizerBase:
             return culprits
 
         for node in nodes:
-            report: List[str] = []
+            report: list[str] = []
             self.reports.append(report)
             self.iteration += 1
             report.append(f"Accumulate traverse iteration {self.iteration}.")
@@ -770,7 +770,7 @@ class _MinimizerBase:
         for node in nodes:
             if node in self.fusions:
                 cur_nodes.update(self.fusions[node])
-            report: List[str] = []
+            report: list[str] = []
             self.reports.append(report)
             self.iteration += 1
             report.append(f" Nodes block {self.iteration}.")
@@ -797,7 +797,7 @@ class _MinimizerBase:
             self.print_report(report)
             return set()
 
-    def _skip_traverse(self, all_nodes: NodeList, skip_nodes: List) -> NodeSet:
+    def _skip_traverse(self, all_nodes: NodeList, skip_nodes: list) -> NodeSet:
         """
         Skip certain nodes in the graph based on settings
         """
@@ -874,7 +874,7 @@ class _MinimizerBase:
             ) as e:
                 print(e)
 
-    def print_report(self, report: List[str]):
+    def print_report(self, report: list[str]):
         for i in range(len(report)):
             if i > 0:
                 print(" . " + report[i])
@@ -889,7 +889,7 @@ class _MinimizerBase:
         self,
         start: Optional[str] = None,
         end: Optional[str] = None,
-        skip_nodes: Optional[List] = None,
+        skip_nodes: Optional[list] = None,
         find_last_node: Optional[bool] = None,
     ) -> NodeSet:
         """
diff --git a/torch/fx/passes/operator_support.py b/torch/fx/passes/operator_support.py
index 53e8be37cecf..6cb14d312b60 100644
--- a/torch/fx/passes/operator_support.py
+++ b/torch/fx/passes/operator_support.py
@@ -24,9 +24,9 @@ TargetTypeName = str
 
 # Arguments' dtypes for a given node, see `OperatorSupport`
 SupportedArgumentDTypes = t.Optional[
-    t.Tuple[
+    tuple[
         t.Sequence[t.Sequence[torch.dtype]],
-        t.Dict[str, t.Sequence[torch.dtype]],
+        dict[str, t.Sequence[torch.dtype]],
     ]
 ]
 
@@ -204,7 +204,7 @@ class OpSupports:
         return create_op_support(_decline_if_input_dtype)
 
     @classmethod
-    def decline_if_node_in_names(cls, disallow_set: t.Set[str]) -> OperatorSupportBase:
+    def decline_if_node_in_names(cls, disallow_set: set[str]) -> OperatorSupportBase:
         """
         If a node has a name that is in the disallow set, report it as non-supported.
         """
diff --git a/torch/fx/passes/param_fetch.py b/torch/fx/passes/param_fetch.py
index 3eba16b06b03..02904b8e403e 100644
--- a/torch/fx/passes/param_fetch.py
+++ b/torch/fx/passes/param_fetch.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Tuple, Type
+from typing import Any, Callable
 
 import torch
 import torch.nn as nn
@@ -23,7 +23,7 @@ def default_matching(name: str, target_version: int) -> str:
 
 # This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering.
 # The first integer in the tuple is the version number of the nn.Module class when we create the parameter list.
 # If there's a version mismatch, then it means the parameter names in the book might be mismatched with nn.Module.
-module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]] = {
+module_fetch_book: dict[type, tuple[int, list[str], Callable[[str, int], str]]] = {
     torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching),
     torch.nn.modules.conv.Conv2d: (
         1,
@@ -55,11 +55,11 @@ module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]]
 
 
 @compatibility(is_backward_compatible=False)
-def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]:
+def extract_attrs_for_lowering(mod: nn.Module) -> dict[str, Any]:
     """If `mod` is in `module_fetch_book`, fetch the mod's attributes that are in the
     `module_fetch_book` after checking that the module's version is compatible with
     the `module_fetch_book`.
     """
-    attrs_for_lowering: Dict[str, Any] = {}
+    attrs_for_lowering: dict[str, Any] = {}
     attrs_for_lowering["name"] = torch.typename(mod)
 
     if type(mod) in module_fetch_book:
diff --git a/torch/fx/passes/pass_manager.py b/torch/fx/passes/pass_manager.py
index 9b0ccbb82d50..ddb1410f6840 100644
--- a/torch/fx/passes/pass_manager.py
+++ b/torch/fx/passes/pass_manager.py
@@ -2,7 +2,7 @@ import logging
 from functools import wraps
 from inspect import unwrap
-from typing import Callable, List, Optional
+from typing import Callable, Optional
 
 logger = logging.getLogger(__name__)
 
@@ -121,7 +121,7 @@ def loop_pass(
 # Implemented as 'depends on' operators. A constraint is satisfied iff a list
 # has a valid partial ordering according to this comparison operator.
 def _validate_pass_schedule_constraint(
-    constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
+    constraint: Callable[[Callable, Callable], bool], passes: list[Callable]
 ):
     for i, a in enumerate(passes):
         for j, b in enumerate(passes[i + 1 :]):
@@ -191,8 +191,8 @@ class PassManager:
         `this_before_that_pass_constraint` for example.
     """
 
-    passes: List[Callable]
-    constraints: List[Callable]
+    passes: list[Callable]
+    constraints: list[Callable]
     _validated: bool = False
 
     def __init__(
@@ -217,7 +217,7 @@ class PassManager:
             self.constraints.append(constraint)
             self._validated = False
 
-    def remove_pass(self, _passes: List[str]):
+    def remove_pass(self, _passes: list[str]):
         if _passes is None:
             return
         passes_left = [ps for ps in self.passes if ps.__name__ not in _passes]
diff --git a/torch/fx/passes/reinplace.py b/torch/fx/passes/reinplace.py
index 3b61446a92f7..0fcd72938367 100644
--- a/torch/fx/passes/reinplace.py
+++ b/torch/fx/passes/reinplace.py
@@ -3,7 +3,6 @@ import _operator
 import itertools
 from collections import defaultdict
 from enum import Enum
-from typing import Dict, Set
 
 import torch
 from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@@ -199,7 +198,7 @@ _VIEW_INVERSE_MAP = {
 # This function, given a set of (aliased) tensor nodes,
 # returns any nodes in the graph that *use* any of the aliases and that occur *after* op_index
 # in the node ordering.
-def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
+def _get_all_later_node_usages(tensor_aliases: set[Node], op_index: int):
     def _add_if_tensor(x, set_):
         if isinstance(x, FakeTensor):
             set_.add(StorageWeakRef(x._typed_storage()))
@@ -233,8 +232,8 @@ def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
 # (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata
 #     as "alias"
 def _get_view_inverse_node_usages(
-    later_node_usages: Set[Node], self_aliases: Set[Node]
-) -> Set[Node]:
+    later_node_usages: set[Node], self_aliases: set[Node]
+) -> set[Node]:
     def matching_view_metadata(a, b):
         return (
             a.size() == b.size()
@@ -515,7 +514,7 @@ def reinplace(gm, *sample_args):
     }
 
     # We also need to know for a given node, what are all of its aliasing nodes.
-    storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set)
+    storage_to_nodes: dict[StorageWeakRef, set[Node]] = defaultdict(set)
     for n in gm.graph.nodes:
         if "fake_result" in n.meta:
             # Tree-mapping because some ops can return lists of tensors.
diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py
index 34ebadc0ef7a..f8c12327f318 100644
--- a/torch/fx/passes/runtime_assert.py
+++ b/torch/fx/passes/runtime_assert.py
@@ -3,7 +3,7 @@ import functools
 import logging
 import operator
 import sys
-from typing import Any, Dict, Optional, Set, TYPE_CHECKING
+from typing import Any, Optional, TYPE_CHECKING
 
 
 # Import sympy and ShapeEnv during TYPE_CHECKING since importing sympy is slow
@@ -123,7 +123,7 @@ def insert_deferred_runtime_asserts(
     )
 
     # We are going to mutate the dict
-    expr_to_proxy: Dict[sympy.Expr, fx.Proxy] = {}
+    expr_to_proxy: dict[sympy.Expr, fx.Proxy] = {}
     placeholders = set()
     first_non_placeholder = None
     for node in graph.nodes:
@@ -163,7 +163,7 @@ def insert_deferred_runtime_asserts(
     def _node_metadata_hook(
         node: torch.fx.Node,
         stack_trace: Optional[str] = None,
-        nn_module_stack: Optional[Dict[str, Any]] = None,
+        nn_module_stack: Optional[dict[str, Any]] = None,
     ) -> None:
         fake_args = pytree.tree_map(
             lambda arg: (
@@ -189,8 +189,8 @@ def insert_deferred_runtime_asserts(
             node.meta["nn_module_stack"] = nn_module_stack
 
     # Track asserts/checks we've added
-    added_asserts: Set[sympy.Expr] = set()
-    constrained_unbacked_symbols: Set[sympy.Symbol] = set()
+    added_asserts: set[sympy.Expr] = set()
+    constrained_unbacked_symbols: set[sympy.Symbol] = set()
 
     Analysis = PythonReferenceAnalysis if export else OptimizedPythonReferenceAnalysis
diff --git a/torch/fx/passes/shape_prop.py b/torch/fx/passes/shape_prop.py
index 49a1046d54e3..28c6e8a57fe0 100644
--- a/torch/fx/passes/shape_prop.py
+++ b/torch/fx/passes/shape_prop.py
@@ -1,7 +1,7 @@
 # mypy: ignore-errors
 
 import traceback
-from typing import Any, Dict, NamedTuple, Optional, Tuple
+from typing import Any, NamedTuple, Optional
 
 import torch
 import torch.fx
@@ -24,12 +24,12 @@ class TensorMetadata(NamedTuple):
     shape: torch.Size
     dtype: torch.dtype
     requires_grad: bool
-    stride: Tuple[int, ...]
+    stride: tuple[int, ...]
memory_format: Optional[torch.memory_format] # Quantization metadata is_quantized: bool - qparams: Dict[str, Any] + qparams: dict[str, Any] def _extract_tensor_metadata( @@ -57,7 +57,7 @@ def _extract_tensor_metadata( break is_quantized = result.is_quantized - qparams: Dict[str, Any] = {} + qparams: dict[str, Any] = {} if is_quantized: qscheme = result.qscheme() qparams["qscheme"] = qscheme diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py index 7fec3089c527..59c560423d40 100644 --- a/torch/fx/passes/split_module.py +++ b/torch/fx/passes/split_module.py @@ -2,7 +2,7 @@ import inspect import logging from collections import OrderedDict -from typing import Any, Callable, Dict, List, Optional, Set +from typing import Any, Callable, Optional import torch from torch.fx._compatibility import compatibility @@ -20,14 +20,14 @@ class Partition: def __init__(self, name: str): self.name: str = name self.submod_name = f"submod_{name}" - self.node_names: List[str] = [] - self.inputs: Dict[str, None] = {} - self.outputs: Dict[str, None] = {} - self.dependencies: Dict[str, None] = {} - self.dependents: Dict[str, None] = {} + self.node_names: list[str] = [] + self.inputs: dict[str, None] = {} + self.outputs: dict[str, None] = {} + self.dependencies: dict[str, None] = {} + self.dependents: dict[str, None] = {} self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph() - self.environment: Dict[Node, Node] = {} - self.targets: Dict[str, Any] = {} + self.environment: dict[Node, Node] = {} + self.targets: dict[str, Any] = {} def __repr__(self) -> str: return ( @@ -55,7 +55,7 @@ def split_module( m: GraphModule, root_m: torch.nn.Module, split_callback: Callable[[Node], int], - qualname_map: Optional[Dict[str, str]] = None, + qualname_map: Optional[dict[str, str]] = None, keep_original_order: Optional[bool] = False, keep_original_node_name: Optional[bool] = False, ): @@ -161,8 +161,8 @@ def split_module( def construct_graph( node: Node, - base_mod_env: Dict[str, Node], - base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule], + base_mod_env: dict[str, Node], + base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule], ): if node.op == "placeholder": default_value = ( @@ -195,9 +195,9 @@ def split_module( import sympy - partitions: Dict[str, Partition] = {} - orig_nodes: Dict[str, Node] = {} - symbol_to_node: Dict[sympy.Symbol, Node] = {} + partitions: dict[str, Partition] = {} + orig_nodes: dict[str, Node] = {} + symbol_to_node: dict[sympy.Symbol, Node] = {} def record_cross_partition_use(def_node: Node, use_node: Optional[Node]): from torch.fx.experimental.symbolic_shapes import free_symbols @@ -273,7 +273,7 @@ def split_module( # ------------------------ # 1. first region: we do nothing # 2. subsequent regions: we insert the set_grad at the beginning - grad_regions: OrderedDict[Node, Set[int]] = OrderedDict() + grad_regions: OrderedDict[Node, set[int]] = OrderedDict() # For autocast regions: # ------------------------ @@ -282,8 +282,8 @@ def split_module( # _enter at the beginning and _exit at the end # 3. last region: we will only insert _enter at the beginning # We will do so in the order in which the autocasts were instantiated. 
- autocast_regions: OrderedDict[Node, Set[int]] = OrderedDict() - autocast_exits: Dict[Node, Optional[Node]] = {} + autocast_regions: OrderedDict[Node, set[int]] = OrderedDict() + autocast_exits: dict[Node, Optional[Node]] = {} active_grad = None active_autocasts = set() @@ -379,13 +379,13 @@ def split_module( original_partition_order = list(partitions.keys()) # find partitions with no dependencies - root_partitions: List[str] = [] + root_partitions: list[str] = [] for partition_name, partition in partitions.items(): if not len(partition.dependencies): root_partitions.append(partition_name) # check partitions for circular dependencies and create topological partition ordering - sorted_partitions: List[str] = [] + sorted_partitions: list[str] = [] while root_partitions: root_partition = root_partitions.pop() sorted_partitions.append(root_partition) @@ -418,7 +418,7 @@ def split_module( # add placeholders to partition inputs for partition_name in sorted_partitions: partition = partitions[partition_name] - new_inputs: Dict[str, None] = {} + new_inputs: dict[str, None] = {} for inp in partition.inputs: orig_node = orig_nodes[inp] # We don't pass in get_attr nodes as inputs to the partition, but @@ -507,11 +507,11 @@ def split_module( ) # is it really a good idea to copy this? # original module environment dict mapping node names to nodes - orig_mod_env: Dict[str, Node] = {} + orig_mod_env: dict[str, Node] = {} # Set up values to construct base module - base_mod_env: Dict[str, Node] = {} + base_mod_env: dict[str, Node] = {} base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph() - base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {} + base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule] = {} if not keep_original_order: for node in m.graph.nodes: base_mod_env, base_mod_attrs = construct_graph( @@ -559,7 +559,7 @@ def split_module( if keep_original_order: # first get the attr nodes required by this partition - orig_mod_attr_nodes: List[Node] = [ + orig_mod_attr_nodes: list[Node] = [ orig_mod_env[key] for key in partition.inputs if key not in original_order diff --git a/torch/fx/passes/split_utils.py b/torch/fx/passes/split_utils.py index 46e89814f625..c95a2b4cbfd1 100644 --- a/torch/fx/passes/split_utils.py +++ b/torch/fx/passes/split_utils.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import copy from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple, Type, Union +from typing import Optional, Union import torch.fx from torch.fx._compatibility import compatibility @@ -45,28 +45,28 @@ class Component: name: str # Stores the placeholder nodes in `graph`. - input_placeholders: List = field(default_factory=list) + input_placeholders: list = field(default_factory=list) # Store the nodes in original graph that are placeholder in `graph`. - orig_inputs: List = field(default_factory=list) + orig_inputs: list = field(default_factory=list) # Store the nodes in original graph that are outputs in `graph`. - orig_outputs: List = field(default_factory=list) + orig_outputs: list = field(default_factory=list) # Mapping from get_attr node in original graph to get_attr node in `graph`. 
- getattr_maps: Dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict) - constructor_args: List[str] = field(default_factory=list) + getattr_maps: dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict) + constructor_args: list[str] = field(default_factory=list) gm: Optional[torch.fx.GraphModule] = None @compatibility(is_backward_compatible=False) def split_by_tags( gm: torch.fx.GraphModule, - tags: List[str], + tags: list[str], return_fqn_mapping: bool = False, return_tuple: bool = False, - GraphModuleCls: Type[torch.fx.GraphModule] = torch.fx.GraphModule, -) -> Union[torch.fx.GraphModule, Tuple[torch.fx.GraphModule, Dict[str, str]]]: + GraphModuleCls: type[torch.fx.GraphModule] = torch.fx.GraphModule, +) -> Union[torch.fx.GraphModule, tuple[torch.fx.GraphModule, dict[str, str]]]: """ Splits a GraphModule using tags on its graph nodes. We honor the order of tags. For example, we have tags = ["a", "b", "c"], the function will create @@ -133,26 +133,26 @@ def split_by_tags( return r # Mapping from node in original module to node in created submodule. - node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {} + node_remapping: dict[torch.fx.Node, torch.fx.Node] = {} # Mapping from node in original module or created submodules to # corresponding component. - node_to_component: Dict[torch.fx.Node, Component] = {} + node_to_component: dict[torch.fx.Node, Component] = {} # Mapping from tag to the corresponding component. - tag_to_component: Dict[str, Component] = {} + tag_to_component: dict[str, Component] = {} # Stores all components. - all_components: List[Component] = [] + all_components: list[Component] = [] # Stores nodes that will be used in main graph. - used_in_main: Dict[torch.fx.Node, None] = {} + used_in_main: dict[torch.fx.Node, None] = {} # Main graph after split. main_g = torch.fx.Graph() # Mapping from node in original module to node in main graph after split. - main_remapping: Dict[torch.fx.Node, torch.fx.Node] = {} + main_remapping: dict[torch.fx.Node, torch.fx.Node] = {} # Output node of original module. output_node: Optional[torch.fx.Node] = None @@ -258,7 +258,7 @@ def split_by_tags( node_to_component[n].orig_outputs.append(n) # Now we create a graphmodule for each component. - orig_to_split_fqn_mapping: Dict[str, str] = {} + orig_to_split_fqn_mapping: dict[str, str] = {} for comp in all_components: outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs)) diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py index 880ebc68cd83..6ca9da390f35 100644 --- a/torch/fx/passes/splitter_base.py +++ b/torch/fx/passes/splitter_base.py @@ -3,8 +3,9 @@ import argparse import copy import logging from collections import defaultdict +from collections.abc import Iterable, Sequence from dataclasses import dataclass -from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Sequence, Tuple +from typing import Any, NamedTuple, Optional import torch from torch.fx._compatibility import compatibility @@ -225,7 +226,7 @@ class SplitResult(NamedTuple): """ split_module: torch.fx.GraphModule - submodule_inputs: Dict[str, Any] + submodule_inputs: dict[str, Any] non_acc_submodule_prefix: str @@ -235,7 +236,7 @@ def generate_inputs_for_submodules( inputs: Sequence[Any], target_submodules: Iterable[str], deepcopy: bool = False, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this function doesn't work. 
@@ -365,16 +366,16 @@ class _SplitterBase: self.update_deps_for_fusions() self.non_acc_submodule_name = non_acc_submodule_name - self._node_submodule_map: Dict[str, str] = {} + self._node_submodule_map: dict[str, str] = {} self._return_tuple = return_tuple - self.tags: List[str] = [] + self.tags: list[str] = [] # =============================================================== # Helpers for ctor and initial state # =============================================================== - def get_node_submodule_map(self) -> Dict[str, str]: + def get_node_submodule_map(self) -> dict[str, str]: """Returns a map from node name to submodule name, e.g. node: main_module_impl_impl_over_arch_unary_multiple_embedding _pooling_embedding_pooling_sparse_entity_equivalence_key @@ -383,7 +384,7 @@ class _SplitterBase: """ return self._node_submodule_map - def find_deps(self) -> Dict[torch.fx.Node, NodeSet]: + def find_deps(self) -> dict[torch.fx.Node, NodeSet]: """ Builds a graph of node dependencies. Leaf nodes don't have any dependencies and the "output" node doesn't have nodes depending on it. @@ -391,7 +392,7 @@ class _SplitterBase: Resulting graph has only direct dependencies, i.e. there are no transitive dependencies. """ - deps: Dict[torch.fx.Node, NodeSet] = defaultdict(set) + deps: dict[torch.fx.Node, NodeSet] = defaultdict(set) for node in self.module.graph.nodes: if node.op not in CALLABLE_NODE_OPS: continue @@ -647,12 +648,12 @@ class _SplitterBase: def find_reverse_deps( self, tag_id: Optional[int] = None - ) -> Dict[torch.fx.Node, NodeSet]: + ) -> dict[torch.fx.Node, NodeSet]: """ Builds reversed topological node dependencies, if tag_id is specified, we ignore nodes that are in later subgraph i.e. nodes have greater tag_id. """ - result: Dict[torch.fx.Node, NodeSet] = defaultdict(set) + result: dict[torch.fx.Node, NodeSet] = defaultdict(set) for node in self.module.graph.nodes: if node.op not in CALLABLE_NODE_OPS: @@ -667,7 +668,7 @@ class _SplitterBase: return result - def update_reverse_deps_for_fusions(self, deps: Dict[torch.fx.Node, NodeSet]): + def update_reverse_deps_for_fusions(self, deps: dict[torch.fx.Node, NodeSet]): processed_node = set() for node, fusion in self.fusions.items(): @@ -757,7 +758,7 @@ class _SplitterBase: # Helpers for split() method # =============================================================== - def starter_nodes(self) -> Tuple[NodeSet, NodeSet]: + def starter_nodes(self) -> tuple[NodeSet, NodeSet]: """ Finds nodes that consume module inputs or get_attr nodes. 
""" @@ -773,7 +774,7 @@ class _SplitterBase: starter_cpu_nodes.add(user) return starter_cpu_nodes, starter_acc_nodes - def put_nodes_into_subgraphs(self) -> List[Subgraph]: + def put_nodes_into_subgraphs(self) -> list[Subgraph]: # We start graph traversal from leaf nodes current_cpu_nodes, current_acc_nodes = self.starter_nodes() visited_nodes: NodeSet = set() @@ -785,7 +786,7 @@ class _SplitterBase: current_subgraph_nodes: NodeList = [] # Result accumulator - subgraphs: List[Subgraph] = [] + subgraphs: list[Subgraph] = [] while current_cpu_nodes or current_acc_nodes: # Find the first node that should belong to the current subgraph and has all dependencies resolved current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes @@ -839,12 +840,12 @@ class _SplitterBase: return subgraphs - def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]: + def remove_small_acc_subgraphs(self, subgraphs: list[Subgraph]) -> list[Subgraph]: """ This pass finds ACC submodules with less than specified size and merges them with adjacent CPU submodules. """ - result: List[Subgraph] = [] + result: list[Subgraph] = [] for subgraph in subgraphs: if subgraph.is_acc: if len(subgraph.nodes) >= self.settings.min_acc_module_size: @@ -866,7 +867,7 @@ class _SplitterBase: result.append(subgraph) return result - def tag(self, subgraphs: List[Subgraph]): + def tag(self, subgraphs: list[Subgraph]): self.tags = [] for subgraph in subgraphs: tag = ( diff --git a/torch/fx/passes/tools_common.py b/torch/fx/passes/tools_common.py index 4ed56be63b09..212b094e86e3 100644 --- a/torch/fx/passes/tools_common.py +++ b/torch/fx/passes/tools_common.py @@ -1,8 +1,9 @@ # mypy: allow-untyped-defs import collections import operator +from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union +from typing import Any, Optional, Union import torch import torch.fx @@ -18,11 +19,11 @@ __all__ = [ "legalize_graph", ] -Tensors = Union[Tuple[torch.Tensor], List[torch.Tensor]] +Tensors = Union[tuple[torch.Tensor], list[torch.Tensor]] TensorOrTensors = Union[torch.Tensor, Tensors] -NodeList = List[torch.fx.Node] -NodeSet = Set[torch.fx.Node] -Names = List[str] +NodeList = list[torch.fx.Node] +NodeSet = set[torch.fx.Node] +Names = list[str] CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"} @@ -172,8 +173,8 @@ class FxNetAccFusionsFinder: return False - def __call__(self) -> Dict[torch.fx.Node, NodeSet]: - result: Dict[torch.fx.Node, NodeSet] = {} + def __call__(self) -> dict[torch.fx.Node, NodeSet]: + result: dict[torch.fx.Node, NodeSet] = {} acc_nodes = list(self.acc_nodes) for node in acc_nodes: @@ -294,7 +295,7 @@ def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: for node in gm.graph.nodes: if indeg[node] == 0: queue.append(node) - env: Dict[torch.fx.Node, torch.fx.Node] = {} + env: dict[torch.fx.Node, torch.fx.Node] = {} # Pop nodes from the queue, and add nodes that have had all their # dependencies fulfilled while len(queue) > 0: diff --git a/torch/fx/passes/utils/common.py b/torch/fx/passes/utils/common.py index bb628372337b..17362c9eec12 100644 --- a/torch/fx/passes/utils/common.py +++ b/torch/fx/passes/utils/common.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Dict, Tuple from torch.fx._compatibility import compatibility from torch.fx.graph import Graph @@ -30,7 +29,7 @@ def lift_subgraph_as_module( subgraph: Graph, comp_name: str = "", class_name: str = 
"GraphModule", -) -> Tuple[GraphModule, Dict[str, str]]: +) -> tuple[GraphModule, dict[str, str]]: """ Create a GraphModule for subgraph, which copies the necessary attributes from the original parent graph_module. @@ -52,7 +51,7 @@ def lift_subgraph_as_module( # make "weight" a attribute of "conv" HolderModule and point to conv.weight in # the original module. submodule = HolderModule({}) - orig_to_split_fqn_mapping: Dict[str, str] = {} + orig_to_split_fqn_mapping: dict[str, str] = {} for n in subgraph.nodes: if n.op not in ("call_module", "get_attr"): continue diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index fa090b677f32..7487bc2c6631 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import copy from queue import SimpleQueue -from typing import Dict, List, Optional as _Optional, Tuple +from typing import Optional as _Optional import torch.fx from torch.fx._compatibility import compatibility @@ -97,10 +97,10 @@ def fuse_as_graphmodule( gm: GraphModule, nodes: NodeList, module_name: str, - partition_lookup_table: _Optional[Dict[Node, None]] = None, + partition_lookup_table: _Optional[dict[Node, None]] = None, *, always_return_tuple: bool = False, -) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]: +) -> tuple[GraphModule, tuple[Node, ...], tuple[Node, ...]]: """ Fuse nodes in graph_module into a GraphModule. @@ -144,10 +144,10 @@ def fuse_as_graphmodule( subgraph = Graph() - node_to_placeholder: Dict[ + node_to_placeholder: dict[ Node, Node ] = {} # mapping of nodes from old graph to placeholder in new graph - node_map: Dict[Node, Node] = {} # mapping of nodes from old graph to new graph + node_map: dict[Node, Node] = {} # mapping of nodes from old graph to new graph # handles inputs through graph.node_copy's arg_transform functions def remap_inputs(x): @@ -176,7 +176,7 @@ def fuse_as_graphmodule( node_map[node] = new_node # handles outputs - output_mapping: Dict[Node, Node] = {} # mapping from old output to new outputs + output_mapping: dict[Node, Node] = {} # mapping from old output to new outputs for node in nodes: for user_node in node.users: @@ -202,10 +202,10 @@ def fuse_as_graphmodule( ) # sub_gm's input nodes in the original module - original_inputs: Tuple[Node, ...] = tuple(node_to_placeholder.keys()) + original_inputs: tuple[Node, ...] = tuple(node_to_placeholder.keys()) # sub_gm's outputs node in the original module - original_outputs: Tuple[Node, ...] = tuple(output_mapping.keys()) + original_outputs: tuple[Node, ...] 
= tuple(output_mapping.keys()) return fused_gm, original_inputs, original_outputs @@ -214,8 +214,8 @@ def fuse_as_graphmodule( def insert_subgm( gm: GraphModule, sub_gm: GraphModule, - orig_inputs: Tuple[Node, ...], - orig_outputs: Tuple[Node, ...], + orig_inputs: tuple[Node, ...], + orig_outputs: tuple[Node, ...], ): # add sub_gm into gm submodule_name = sub_gm.__class__.__name__ @@ -250,7 +250,7 @@ def erase_nodes(gm: GraphModule, nodes: NodeList): @compatibility(is_backward_compatible=False) def fuse_by_partitions( gm: GraphModule, - partitions: List[Dict[Node, None]], + partitions: list[dict[Node, None]], prefix: str = "fused_", always_return_tuple: bool = False, ) -> GraphModule: diff --git a/torch/fx/passes/utils/matcher_utils.py b/torch/fx/passes/utils/matcher_utils.py index cc05b8f512b1..27d24ed29945 100644 --- a/torch/fx/passes/utils/matcher_utils.py +++ b/torch/fx/passes/utils/matcher_utils.py @@ -4,7 +4,7 @@ import logging import os from collections import defaultdict from dataclasses import dataclass, field -from typing import Any, Dict, List, Set, Tuple, Union +from typing import Any, Union import torch from torch.fx import Graph, Node @@ -37,19 +37,19 @@ logger = _init_logger() @dataclass class InternalMatch: # Nodes from which the match was found - anchors: List[Node] + anchors: list[Node] # Maps nodes in the pattern subgraph to nodes in the larger graph - nodes_map: Dict[Node, Node] = field(default_factory=dict) + nodes_map: dict[Node, Node] = field(default_factory=dict) # nodes in target graph that are matched placeholder in pattern - placeholder_nodes: List[Node] = field(default_factory=list) + placeholder_nodes: list[Node] = field(default_factory=list) # nodes in matched subgraph returned by output - returning_nodes: List[Node] = field(default_factory=list) + returning_nodes: list[Node] = field(default_factory=list) # map from a string name to a node in the target graph # only available if the matcher is `SubgraphMatcherWithNameNodesMap` - name_node_map: Dict[str, Node] = field(default_factory=dict) + name_node_map: dict[str, Node] = field(default_factory=dict) def __copy__(self): return InternalMatch( @@ -107,9 +107,9 @@ class SubgraphMatcher: ] output_node = next(iter(reversed(pattern.nodes))) # nodes returned by outputs - self.pattern_returning_nodes: List[Node] = output_node.all_input_nodes + self.pattern_returning_nodes: list[Node] = output_node.all_input_nodes - self.pattern_anchors: List[Node] = [] + self.pattern_anchors: list[Node] = [] if match_output: self.pattern_anchors = [output_node] else: @@ -150,12 +150,12 @@ class SubgraphMatcher: return pn.target == gn.target return False - def _is_contained(self, nodes_map: Dict[Node, Node]) -> bool: + def _is_contained(self, nodes_map: dict[Node, Node]) -> bool: # `lookup` represents all the nodes in `original_graph` # that are part of `pattern` # Placeholders can be used by other nodes in the graphs - lookup: Dict[Node, Node] = { + lookup: dict[Node, Node] = { gn: pn for pn, gn in nodes_map.items() if pn.op != "placeholder" } @@ -172,10 +172,10 @@ class SubgraphMatcher: return True def _remove_overlapping_matches( - self, matches: List[InternalMatch] - ) -> List[InternalMatch]: - non_overlapping_matches: List[InternalMatch] = [] - nodes_matched: Set[Node] = set() + self, matches: list[InternalMatch] + ) -> list[InternalMatch]: + non_overlapping_matches: list[InternalMatch] = [] + nodes_matched: set[Node] = set() for match in matches: found_overlap = False @@ -244,7 +244,7 @@ class SubgraphMatcher: # match for `gn` 
match_found = True - def _match_args(args1: Union[List, Tuple], args2: Union[List, Tuple]) -> bool: + def _match_args(args1: Union[list, tuple], args2: Union[list, tuple]) -> bool: if len(args1) != len(args2): return False @@ -313,7 +313,7 @@ class SubgraphMatcher: return True - def match(self, graph: Graph) -> List[InternalMatch]: + def match(self, graph: Graph) -> list[InternalMatch]: """ Returns: The matched subgraphs. @@ -352,7 +352,7 @@ class SubgraphMatcher: from torch.fx.passes.utils.fuser_utils import validate_partition # find candidate nodes to match with pattern anchors - match_candidates: Dict[Node, List[Node]] = defaultdict(list) + match_candidates: dict[Node, list[Node]] = defaultdict(list) for pattern_anchor in self.pattern_anchors: for node in graph.nodes: if self._nodes_are_equal(pattern_anchor, node): @@ -361,7 +361,7 @@ class SubgraphMatcher: logger.info("Initial match_candidates_list: %s\n", match_candidates_list) - matches: List[InternalMatch] = [] + matches: list[InternalMatch] = [] def backtracking(anchor_index, match): if anchor_index == len(match_candidates_list): diff --git a/torch/fx/passes/utils/matcher_with_name_node_map_utils.py b/torch/fx/passes/utils/matcher_with_name_node_map_utils.py index 78b063ff8a7a..1fa9b721e9cc 100644 --- a/torch/fx/passes/utils/matcher_with_name_node_map_utils.py +++ b/torch/fx/passes/utils/matcher_with_name_node_map_utils.py @@ -1,5 +1,3 @@ -from typing import Dict, List, Tuple - from torch.fx import Graph, GraphModule, Node from torch.fx._compatibility import compatibility @@ -11,7 +9,7 @@ __all__ = ["SubgraphMatcherWithNameNodeMap"] def _split_to_graph_and_name_node_map( gm: GraphModule, -) -> Tuple[GraphModule, Dict[str, Node]]: +) -> tuple[GraphModule, dict[str, Node]]: from torch.fx.graph import _PyTreeInfo from torch.utils._pytree import tree_flatten, tree_unflatten @@ -29,7 +27,7 @@ def _split_to_graph_and_name_node_map( *out, name_node_map = output flattened, out_spec = tree_flatten(out) assert isinstance( - name_node_map, Dict + name_node_map, dict ), "Expecting the input graph to have a dict output as the last element" n.args = (flattened,) orig_pytree_info = gm._graph._codegen.pytree_info # type: ignore[attr-defined] @@ -88,7 +86,7 @@ class SubgraphMatcherWithNameNodeMap(SubgraphMatcher): ignore_literals, ) - def match(self, graph: Graph) -> List[InternalMatch]: + def match(self, graph: Graph) -> list[InternalMatch]: """The returned InternalMatch will have name_node_map populated with a map from node name (str) to the target node, e.g. 
{"conv": target_conv_ndoe, "relu": target_relu_node} diff --git a/torch/fx/passes/utils/source_matcher_utils.py b/torch/fx/passes/utils/source_matcher_utils.py index 8826dc62f745..97a60b06694c 100644 --- a/torch/fx/passes/utils/source_matcher_utils.py +++ b/torch/fx/passes/utils/source_matcher_utils.py @@ -1,7 +1,7 @@ import logging import os from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Optional, Type +from typing import Any, Callable, Optional from torch.fx._compatibility import compatibility from torch.fx.graph import Graph @@ -34,29 +34,29 @@ logger = _init_logger() @dataclass class SourcePartition: # Nodes in a particular partition - nodes: List[Node] + nodes: list[Node] # The source these nodes decomposed from source: Any # Nodes in the graph that are needed as inputs to the partition # These do not include the params of the partition - input_nodes: List[Node] = field(default_factory=list) + input_nodes: list[Node] = field(default_factory=list) # Nodes in the partition that are being used by nodes outside of the # partition - output_nodes: List[Node] = field(default_factory=list) + output_nodes: list[Node] = field(default_factory=list) # Parameters that are being used - params: List[Node] = field(default_factory=list) + params: list[Node] = field(default_factory=list) @compatibility(is_backward_compatible=False) # type: ignore[misc] def get_source_partitions( graph: Graph, - wanted_sources: List[Any], + wanted_sources: list[Any], filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Dict[Any, List[SourcePartition]]: +) -> dict[Any, list[SourcePartition]]: """ Args: graph: The graph we want to partition @@ -69,7 +69,7 @@ def get_source_partitions( that correspond to the list of nodes that were decomposed from the given source. 
""" - modules: Dict[Type, Dict[str, List[Node]]] = {} + modules: dict[type, dict[str, list[Node]]] = {} for node in graph.nodes: # The metadata source_fn should contain a tuple of a unique name for the @@ -98,7 +98,7 @@ def get_source_partitions( partition = diff_modules.setdefault(source_fn[0], []) partition.append(node) - def make_partition(nodes: List[Node], module_type: Type) -> SourcePartition: + def make_partition(nodes: list[Node], module_type: type) -> SourcePartition: input_nodes = set() output_nodes = set() params = set() @@ -124,7 +124,7 @@ def get_source_partitions( list(params), # type: ignore[arg-type] ) - ret: Dict[Type[Any], List[SourcePartition]] = {} + ret: dict[type[Any], list[SourcePartition]] = {} if filter_fn: # for each partition, we apply filter_fn to filter out all partitions that doesn't satisfy the diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index 7268a7363336..24ac76ac6318 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -8,8 +8,10 @@ import inspect import logging import operator import sys +from collections import OrderedDict +from collections.abc import Iterator from dataclasses import fields, is_dataclass -from typing import Any, Callable, Dict, Iterator, Optional, OrderedDict, Tuple +from typing import Any, Callable, Optional import torch import torch.fx.traceback as fx_traceback @@ -135,18 +137,18 @@ class TracerBase: scope: Scope # Records the module call stack - module_stack: OrderedDict[str, Tuple[str, Any]] + module_stack: OrderedDict[str, tuple[str, Any]] # Mapping of node name to module scope - node_name_to_scope: Dict[str, Tuple[str, type]] + node_name_to_scope: dict[str, tuple[str, type]] @compatibility(is_backward_compatible=True) def create_node( self, kind: str, target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], + args: tuple[Argument, ...], + kwargs: dict[str, Argument], name: Optional[str] = None, type_expr: Optional[Any] = None, ) -> Node: @@ -171,7 +173,7 @@ class TracerBase: # Optionally set stack trace on the created Node for debugging purposes if fx_traceback.has_preserved_node_meta(): - current_meta: Dict[str, Any] = fx_traceback.get_current_meta() + current_meta: dict[str, Any] = fx_traceback.get_current_meta() stack_trace = current_meta.get("stack_trace") if stack_trace: @@ -211,8 +213,8 @@ class TracerBase: self, kind: str, target: Target, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], name: Optional[str] = None, type_expr: Optional[Any] = None, # fix noqa when updating bc tests @@ -455,10 +457,10 @@ class Proxy: # we peephole optimize to the method invocation return Attribute(self, k) - def __getstate__(self) -> Dict: + def __getstate__(self) -> dict: return self.__dict__ - def __deepcopy__(self, memo) -> Dict: + def __deepcopy__(self, memo) -> dict: # We have to explicitly override this method, because otherwise deepcopy # will go to __getattr__(self, "__deepcopy__") and return a # Attribute(__deepcopy__), and may go into an infinite loop in some cases. 
@@ -564,7 +566,7 @@ class Proxy: args = args if args else () kwargs = kwargs if kwargs else {} - tracers: Dict[Any, None] = {} + tracers: dict[Any, None] = {} def find_tracer(a): if isinstance(a, cls): diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index 44b03c3acf5f..711fba12542d 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -1,16 +1,6 @@ import copy from dataclasses import dataclass -from typing import ( - Any, - Callable, - Dict, - List, - NamedTuple, - Optional, - Set, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union import torch @@ -37,7 +27,7 @@ class Match(NamedTuple): # Node from which the match was found anchor: Node # Maps nodes in the pattern subgraph to nodes in the larger graph - nodes_map: Dict[Node, Node] + nodes_map: dict[Node, Node] @compatibility(is_backward_compatible=False) @@ -46,9 +36,9 @@ class ReplacedPatterns: # Node from which the match was found anchor: Node # Maps nodes in the pattern subgraph to nodes in the larger graph - nodes_map: Dict[Node, Node] + nodes_map: dict[Node, Node] # List of nodes that were added into the graph - replacements: List[Node] + replacements: list[Node] def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None: @@ -106,7 +96,7 @@ def replace_pattern( gm: GraphModule, pattern: Union[Callable, GraphModule], replacement: Union[Callable, GraphModule], -) -> List[Match]: +) -> list[Match]: """ Matches all possible non-overlapping sets of operators and their data dependencies (``pattern``) in the Graph of a GraphModule @@ -237,14 +227,14 @@ def replace_pattern_with_filters( pattern: Union[Callable, Graph, GraphModule], replacement: Union[Callable, Graph, GraphModule, None] = None, match_filters: Optional[ - List[Callable[["InternalMatch", Graph, Graph], bool]] + list[Callable[["InternalMatch", Graph, Graph], bool]] ] = None, ignore_literals: bool = False, # Placed at the end to avoid breaking backward compatibility replacement_callback: Optional[ Callable[["InternalMatch", Graph, Graph], Graph] ] = None, -) -> List[ReplacedPatterns]: +) -> list[ReplacedPatterns]: """ See replace_pattern for documentation. This function is an overload with an additional match_filter argument. 
@@ -268,14 +258,14 @@ def _replace_pattern( pattern: Union[Callable, Graph, GraphModule], replacement: Union[Callable, Graph, GraphModule, None] = None, match_filters: Optional[ - List[Callable[["InternalMatch", Graph, Graph], bool]] + list[Callable[["InternalMatch", Graph, Graph], bool]] ] = None, ignore_literals: bool = False, # Placed at the end to avoid breaking backward compatibility replacement_callback: Optional[ Callable[["InternalMatch", Graph, Graph], Graph] ] = None, -) -> List[ReplacedPatterns]: +) -> list[ReplacedPatterns]: from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher if match_filters is None: @@ -298,7 +288,7 @@ def _replace_pattern( remove_overlapping_matches=True, ignore_literals=ignore_literals, ) - _matches: List[InternalMatch] = matcher.match(original_graph) + _matches: list[InternalMatch] = matcher.match(original_graph) # Filter out matches that don't match the filter _matches = [ @@ -323,7 +313,7 @@ def _replace_pattern( common_replacement_graph = None # As we progressively replace nodes, we'll need to keep track of how the match results should change - match_changed_node: Dict[Node, Node] = {} + match_changed_node: dict[Node, Node] = {} match_and_replacements = [] for match in _matches: @@ -345,7 +335,7 @@ def _replace_pattern( # Initialize `val_map` with mappings from placeholder nodes in # `replacement` to their corresponding node in `original_graph` assert len(match.placeholder_nodes) == len(replacement_placeholders) - val_map: Dict[Node, Node] = {} + val_map: dict[Node, Node] = {} for rn, gn in zip(replacement_placeholders, match.placeholder_nodes): if isinstance(gn, Node): val_map[rn] = match_changed_node.get(gn, gn) @@ -361,7 +351,7 @@ def _replace_pattern( val_map[rn] = gn # Copy the replacement graph over - user_nodes: Set[Node] = set() + user_nodes: set[Node] = set() for n in match.returning_nodes: user_nodes.update(n.users) @@ -402,7 +392,7 @@ def _replace_pattern( copied_returning_nodes = (copied_returning_nodes,) # Get a list of nodes that have been replaced into the graph - replacement_nodes: List[Node] = [ + replacement_nodes: list[Node] = [ v for v in val_map.values() if v not in match.placeholder_nodes ] diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index 88a8fc54fa5f..095b3f1b27bc 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -4,7 +4,7 @@ import json import traceback from contextlib import contextmanager from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from ._compatibility import compatibility from .graph import Graph @@ -25,7 +25,7 @@ __all__ = [ "get_graph_provenance_json", ] -current_meta: Dict[str, Any] = {} +current_meta: dict[str, Any] = {} should_preserve_node_meta = False @@ -49,15 +49,15 @@ class NodeSource: self.graph_id = graph_id pass_name: str - action: List["NodeSourceAction"] - from_node: List["NodeSource"] + action: list["NodeSourceAction"] + from_node: list["NodeSource"] node_info: Optional["NodeInfo"] def __init__( self, node: Optional[Node], pass_name: str = "", - action: Optional[Union["NodeSourceAction", List["NodeSourceAction"]]] = None, + action: Optional[Union["NodeSourceAction", list["NodeSourceAction"]]] = None, ): self.pass_name = pass_name @@ -146,7 +146,7 @@ def preserve_node_meta(enable=True): @compatibility(is_backward_compatible=False) -def set_stack_trace(stack: List[str]): +def set_stack_trace(stack: list[str]): global current_meta if should_preserve_node_meta and stack: @@ -182,7 
+182,7 @@ def reset_grad_fn_seq_nr(): @compatibility(is_backward_compatible=False) -def format_stack() -> List[str]: +def format_stack() -> list[str]: if should_preserve_node_meta: return [current_meta.get("stack_trace", "")] else: @@ -219,7 +219,7 @@ def set_current_meta(node, pass_name=""): @compatibility(is_backward_compatible=False) -def get_current_meta() -> Dict[str, Any]: +def get_current_meta() -> dict[str, Any]: return current_meta