PEP585 update - torch/fx (#145166)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145166
Approved by: https://github.com/bobrenjc93
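
PEP 585 (Python 3.9+) lets the builtin container types be parameterized directly, so the `typing` aliases (`List`, `Dict`, `Tuple`, `Set`, `Type`, `Deque`, ...) and the `typing` re-exports of `collections.abc` (`Iterable`, `Iterator`, `Sequence`, ...) can be dropped. A representative before/after of the mechanical change applied across these files (illustrative, not copied from the diff):

    # Before (pre-PEP 585 typing aliases):
    from typing import Dict, List, Optional, Tuple

    def f(xs: List[int], m: Dict[str, Tuple[int, ...]]) -> Optional[List[int]]: ...

    # After (PEP 585 builtin generics; only Optional still comes from typing):
    from typing import Optional

    def f(xs: list[int], m: dict[str, tuple[int, ...]]) -> Optional[list[int]]: ...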
commit 0b2a3687b9
parent 6374332d33
Author: Aaron Orenstein (2025-01-19 19:32:07 -08:00)
Committed by: PyTorch MergeBot
57 changed files with 904 additions and 917 deletions

View File

@@ -3,6 +3,7 @@
 import builtins
 import contextlib
+import collections
 import copy
 import functools
 import inspect
@@ -2767,7 +2768,7 @@ class TestFX(JitTestCase):
                 return self.other(x)

         traced = symbolic_trace(ReturnTypeModule())
-        self.assertIn("-> typing_List[str]", traced._code)
+        self.assertIn("-> list[str]", traced._code)
         scripted = torch.jit.script(traced)
         self.assertIn("-> List[str]", scripted.code)
@@ -3566,8 +3567,8 @@ class TestFX(JitTestCase):
         traced(x, y)

-        FileCheck().check("_Tuple[()]") \
-            .check("typing_Tuple[str,typing_Tuple[()]]") \
+        FileCheck().check("tuple[()]") \
+            .check("tuple[str,tuple[()]]") \
             .run(traced.code)

         scripted = torch.jit.script(traced)
@@ -4063,45 +4064,62 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
         return f'{fn_name}({", ".join(arg_strs)}){return_annot}'

-    def _annotation_type_to_stable_str(self, t, sig_str):
+    _trivial_mappings = {
+        str : 'str',
+        int : 'int',
+        float: 'float',
+        bool: 'bool',
+        torch.dtype: 'torch.dtype',
+        torch.Tensor: 'torch.Tensor',
+        torch.device: 'torch.device',
+        torch.memory_format: 'torch.memory_format',
+        slice: 'slice',
+        torch.nn.Module: 'torch.nn.modules.module.Module',
+        torch.fx.Graph : 'torch.fx.graph.Graph',
+        torch.fx.Node : 'torch.fx.node.Node',
+        torch.fx.Proxy : 'torch.fx.proxy.Proxy',
+        torch.fx.node.Target : 'torch.fx.node.Target',
+        torch.fx.node.Argument : 'torch.fx.node.Argument',
+        torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
+        torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
+        torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
+        Ellipsis : '...',
+        typing.Any: 'Any',
+        type(None): 'NoneType',
+        None: 'None',
+        typing.Iterator: 'Iterator',
+        collections.abc.Iterator: 'Iterator',
+    }
+    _UNBOUND_TYPES = {
+        dict,
+        list,
+        tuple,
+        type,
+        typing.Callable,
+        typing.Dict,
+        typing.List,
+        typing.Tuple,
+        typing.Type,
+        typing.Union,
+    }
+
+    def _annotation_type_to_stable_str(self, t, sig_str, recursive: bool = False):
         if t is inspect.Signature.empty:
             return ''

         # Forward ref
         if isinstance(t, str):
-            return f"'{t}'"
+            if recursive:
+                return t
+            else:
+                return f"'{t}'"
         if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef):
             return t.__forward_arg__
         if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef):
             return t.__forward_arg__

-        trivial_mappings = {
-            str : 'str',
-            int : 'int',
-            float: 'float',
-            bool: 'bool',
-            torch.dtype: 'torch.dtype',
-            torch.Tensor: 'torch.Tensor',
-            torch.device: 'torch.device',
-            torch.memory_format: 'torch.memory_format',
-            slice: 'slice',
-            torch.nn.Module: 'torch.nn.modules.module.Module',
-            torch.fx.Graph : 'torch.fx.graph.Graph',
-            torch.fx.Node : 'torch.fx.node.Node',
-            torch.fx.Proxy : 'torch.fx.proxy.Proxy',
-            torch.fx.node.Target : 'torch.fx.node.Target',
-            torch.fx.node.Argument : 'torch.fx.node.Argument',
-            torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
-            torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
-            torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
-            Ellipsis : '...',
-            typing.Any: 'Any',
-            type(None): 'NoneType',
-            None: 'None',
-            typing.Iterator: 'Iterator',
-        }
-        mapping = trivial_mappings.get(t, None)
+        mapping = self._trivial_mappings.get(t, None)
         if mapping:
             return mapping
@@ -4115,14 +4133,14 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
         if all(isinstance(ct, typing.TypeVar) for ct in contained):
             contained = []

-        contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained]
+        contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str, True) for ct in contained]
         contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else ''

         origin = getattr(t, '__origin__', None)
         if origin is None:
             # Unbound types don't have `__origin__` in some Python versions, so fix that up here.
-            origin = t if t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable} else origin
+            origin = t if t in self._UNBOUND_TYPES else origin

         if origin in {tuple, typing.Tuple}:
             return f'Tuple{contained_type_str}'
@@ -4130,7 +4148,7 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
             # Annoying hack to detect Optional
             if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)):
                 not_none_param = contained[0] if contained[0] is not type(None) else contained[1]
-                return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]'
+                return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str, True)}]'
             return f'Union{contained_type_str}'
         if origin in {dict, typing.Dict}:
             return f'Dict{contained_type_str}'
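
An aside on why the `_UNBOUND_TYPES` set above has to list the bare builtins alongside the bare `typing` aliases: parameterized forms from either era expose the same runtime `__origin__`, but a bare builtin has no `__origin__` at all (plain-Python illustration, separate from the diff):

    import typing

    assert list[int].__origin__ is list         # PEP 585 builtin generic alias
    assert typing.List[int].__origin__ is list  # legacy alias, same runtime origin
    assert typing.List.__origin__ is list       # even the bare typing alias has one
    assert getattr(list, "__origin__", None) is None  # bare builtins do not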

View File

@@ -1524,6 +1524,29 @@ class {test_classname}(torch.nn.Module):
             (int, type(torch.float)),
             (Union[int, float], int),
             (Union[int, float], float),
+            (list[int], int),
+            (list[int], create_type_hint([int, int])),
+            (list[int], create_type_hint((int, int))),
+            (list[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
+            (
+                list[torch.Tensor],
+                create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
+            ),
+            (torch.Tensor, torch.nn.Parameter),
+            (list[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
+            (list[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
+            (list[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
+            (
+                list[torch.Tensor],
+                create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
+            ),
+            (torch.Tensor, torch.nn.Parameter),
+            (list[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
+            (list[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
+            (Optional[list[torch.Tensor]], list[torch.Tensor]),
+            (Optional[list[int]], list[int]),
+        ] + [
+            # pre-PEP585 signatures
             (List[int], int),
             (List[int], create_type_hint([int, int])),
             (List[int], create_type_hint((int, int))),
@@ -1532,7 +1555,6 @@ class {test_classname}(torch.nn.Module):
                 List[torch.Tensor],
                 create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
             ),
-            (torch.Tensor, torch.nn.Parameter),
             (List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
             (List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
             (List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
@@ -1540,18 +1562,21 @@ class {test_classname}(torch.nn.Module):
                 List[torch.Tensor],
                 create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
             ),
-            (torch.Tensor, torch.nn.Parameter),
             (List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
             (List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
             (Optional[List[torch.Tensor]], List[torch.Tensor]),
             (Optional[List[int]], List[int]),
         ]

         for sig_type, arg_type in should_be_equal:
             self.assertTrue(type_matches(sig_type, arg_type))

         should_fail = [
             (int, float),
             (Union[int, float], str),
+            (list[torch.Tensor], List[int]),
+        ] + [
+            # pre-PEP585 signatures
             (List[torch.Tensor], List[int]),
         ]
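
For orientation, `type_matches` and `create_type_hint` come from `torch.fx.operator_schemas`; the new rows assert that PEP 585 signatures behave like the legacy aliases. A rough sketch of what the additions exercise, with the updated helpers (mirrors the test rows above; not code from the diff):

    from typing import List
    import torch
    from torch.fx.operator_schemas import create_type_hint, type_matches

    hint = create_type_hint([int, int])   # unifies element types, e.g. List[int]
    assert type_matches(list[int], hint)  # PEP 585 signature accepted
    assert type_matches(List[int], hint)  # legacy alias still accepted
    assert not type_matches(list[torch.Tensor], List[int])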

View File

@@ -1,9 +1,9 @@
 import textwrap
-from typing import Any, Callable, Dict, TypeVar
+from typing import Any, Callable, TypeVar

-_BACK_COMPAT_OBJECTS: Dict[Any, None] = {}
-_MARKED_WITH_COMPATIBILITY: Dict[Any, None] = {}
+_BACK_COMPAT_OBJECTS: dict[Any, None] = {}
+_MARKED_WITH_COMPATIBILITY: dict[Any, None] = {}

 _T = TypeVar("_T")

View File

@@ -1,20 +1,20 @@
 # mypy: allow-untyped-defs
 from collections import namedtuple
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type
+from typing import Any, Callable, NamedTuple, Optional

 import torch.return_types
 from torch.utils._pytree import PyTree, TreeSpec

-FlattenFuncSpec = Callable[[PyTree, TreeSpec], List]
+FlattenFuncSpec = Callable[[PyTree, TreeSpec], list]
 FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool]

-SUPPORTED_NODES: Dict[Type[Any], FlattenFuncSpec] = {}
-SUPPORTED_NODES_EXACT_MATCH: Dict[Type[Any], Optional[FlattenFuncExactMatchSpec]] = {}
+SUPPORTED_NODES: dict[type[Any], FlattenFuncSpec] = {}
+SUPPORTED_NODES_EXACT_MATCH: dict[type[Any], Optional[FlattenFuncExactMatchSpec]] = {}

 def register_pytree_flatten_spec(
-    cls: Type[Any],
+    cls: type[Any],
     flatten_fn_spec: FlattenFuncSpec,
     flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None,
 ) -> None:
@@ -23,7 +23,7 @@ def register_pytree_flatten_spec(
 def _deregister_pytree_flatten_spec(
-    cls: Type[Any],
+    cls: type[Any],
 ) -> None:
     del SUPPORTED_NODES[cls]
     del SUPPORTED_NODES_EXACT_MATCH[cls]
@@ -33,7 +33,7 @@ def tree_flatten_spec(
     pytree: PyTree,
     spec: TreeSpec,
     exact_structural_match=False,
-) -> List[Any]:
+) -> list[Any]:
     if spec.is_leaf():
         return [pytree]
     if spec.type not in SUPPORTED_NODES:
@@ -58,31 +58,31 @@ def tree_flatten_spec(
     return result

-def _dict_flatten_spec(d: Dict[Any, Any], spec: TreeSpec) -> List[Any]:
+def _dict_flatten_spec(d: dict[Any, Any], spec: TreeSpec) -> list[Any]:
     return [d[k] for k in spec.context]

-def _list_flatten_spec(d: List[Any], spec: TreeSpec) -> List[Any]:
+def _list_flatten_spec(d: list[Any], spec: TreeSpec) -> list[Any]:
     return [d[i] for i in range(spec.num_children)]

-def _tuple_flatten_spec(d: Tuple[Any], spec: TreeSpec) -> List[Any]:
+def _tuple_flatten_spec(d: tuple[Any], spec: TreeSpec) -> list[Any]:
     return [d[i] for i in range(spec.num_children)]

-def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> List[Any]:
+def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> list[Any]:
     return [d[i] for i in range(spec.num_children)]

-def _dict_flatten_spec_exact_match(d: Dict[Any, Any], spec: TreeSpec) -> bool:
+def _dict_flatten_spec_exact_match(d: dict[Any, Any], spec: TreeSpec) -> bool:
     return len(d) == spec.num_children

-def _list_flatten_spec_exact_match(d: List[Any], spec: TreeSpec) -> bool:
+def _list_flatten_spec_exact_match(d: list[Any], spec: TreeSpec) -> bool:
     return len(d) == spec.num_children

-def _tuple_flatten_spec_exact_match(d: Tuple[Any], spec: TreeSpec) -> bool:
+def _tuple_flatten_spec_exact_match(d: tuple[Any], spec: TreeSpec) -> bool:
     return len(d) == spec.num_children
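
A minimal usage sketch of the registration API whose signature changed above (the `Point` container and its flatten function are hypothetical, for illustration only):

    from dataclasses import dataclass
    import torch.fx._pytree as fx_pytree
    from torch.utils._pytree import TreeSpec

    @dataclass
    class Point:
        x: int
        y: int

    def _point_flatten_spec(d: Point, spec: TreeSpec) -> list:
        # Return the children in the order the TreeSpec recorded them.
        return [d.x, d.y]

    fx_pytree.register_pytree_flatten_spec(Point, _point_flatten_spec)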

View File

@@ -10,18 +10,7 @@ import os
 import warnings
 from itertools import chain
 from types import CodeType, FunctionType, ModuleType
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    NamedTuple,
-    Optional,
-    Set,
-    Tuple,
-    Type,
-    Union,
-)
+from typing import Any, Callable, NamedTuple, Optional, Union

 import torch
 import torch.utils._pytree as pytree
@@ -42,7 +31,7 @@ HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS
 _orig_module_call: Callable = torch.nn.Module.__call__
 _orig_module_getattr: Callable = torch.nn.Module.__getattr__

-_proxyable_classes: Dict[Type, None] = {}
+_proxyable_classes: dict[type, None] = {}

 _is_fx_tracing_flag = False
@@ -262,8 +251,8 @@ class Tracer(TracerBase):
     @compatibility(is_backward_compatible=True)
     def __init__(
         self,
-        autowrap_modules: Tuple[ModuleType] = (math,),
-        autowrap_functions: Tuple[Callable, ...] = (),
+        autowrap_modules: tuple[ModuleType] = (math,),
+        autowrap_functions: tuple[Callable, ...] = (),
         param_shapes_constant: bool = False,
     ) -> None:
         # This method's signature is overridden by the first line of this class'
@@ -296,7 +285,7 @@ class Tracer(TracerBase):
         # Functions we will eagerly wrap when we see them while tracing
         # this captures both `math.sqrt()` and `from math import sqrt` automatically
-        self._autowrap_function_ids: Set[int] = {
+        self._autowrap_function_ids: set[int] = {
             id(value)
             for name, value in chain(*[m.__dict__.items() for m in autowrap_modules])
             if not name.startswith("_") and callable(value)
@@ -305,20 +294,20 @@ class Tracer(TracerBase):
         # Python modules to apply autowrap to at the start, in addition to
         # modules we see while tracing
-        self._autowrap_search: List[ModuleType] = list(autowrap_modules)
+        self._autowrap_search: list[ModuleType] = list(autowrap_modules)
         self.param_shapes_constant = param_shapes_constant

-        self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None
+        self.submodule_paths: Optional[dict[torch.nn.Module, str]] = None
         self.root_module_name: str = ""
         # Maps the containing module's name to the operator name
         self.scope = Scope("", None)
         # Records the module call stack
         self.module_stack = collections.OrderedDict()
-        self.num_calls: Dict[str, int] = {}
+        self.num_calls: dict[str, int] = {}
         # Mapping of node name to module scope
-        self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}
+        self.node_name_to_scope: dict[str, tuple[str, type]] = {}

-    _qualname_counter: Dict[str, int] = collections.defaultdict(int)
+    _qualname_counter: dict[str, int] = collections.defaultdict(int)

     @compatibility(is_backward_compatible=True)
     def get_fresh_qualname(self, prefix: str) -> str:
@@ -492,8 +481,8 @@ class Tracer(TracerBase):
         self,
         m: torch.nn.Module,
         forward: Callable[..., Any],
-        args: Tuple[Any, ...],
-        kwargs: Dict[str, Any],
+        args: tuple[Any, ...],
+        kwargs: dict[str, Any],
     ) -> Any:
         """
         Method that specifies the behavior of this ``Tracer`` when it encounters
@@ -547,7 +536,7 @@ class Tracer(TracerBase):
         return ret_val

     @compatibility(is_backward_compatible=False)
-    def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
+    def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: dict[str, Any]):
         """
         Method that specifies the behavior of this ``Tracer`` when we call getattr
         on a call to an ``nn.Module`` instance.
@@ -626,7 +615,7 @@ class Tracer(TracerBase):
         total_args = co.co_argcount + co.co_kwonlyargcount
         orig_args = list(co.co_varnames)
         names_iter = iter(co.co_varnames)
-        args: List[Any] = []
+        args: list[Any] = []
         skip_arg_idx = 0
         if is_module:
             if total_args == 0:
@@ -712,7 +701,7 @@ class Tracer(TracerBase):
     def trace(
         self,
         root: Union[torch.nn.Module, Callable[..., Any]],
-        concrete_args: Optional[Dict[str, Any]] = None,
+        concrete_args: Optional[dict[str, Any]] = None,
     ) -> Graph:
         """
         Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
@@ -763,7 +752,7 @@ class Tracer(TracerBase):
             self.root = torch.nn.Module()
             fn = root

-        tracer_cls: Optional[Type[Tracer]] = getattr(self, "__class__", None)
+        tracer_cls: Optional[type[Tracer]] = getattr(self, "__class__", None)
         self.graph = Graph(tracer_cls=tracer_cls)
         if hasattr(fn, "__code__"):
             code = fn.__code__
@@ -777,11 +766,11 @@ class Tracer(TracerBase):
         # is some other attribute on the model. Construct a dict mapping Tensor
         # values to the qualified name here for efficiency. This is used downstream
         # in create_arg
-        self.tensor_attrs: Dict[
+        self.tensor_attrs: dict[
             Union[torch.Tensor, ScriptObject, FakeScriptObject], str
         ] = {}

-        def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]):
+        def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: list[str]):
             for k, v in m.__dict__.items():
                 if isinstance(v, (torch.Tensor, ScriptObject, FakeScriptObject)):
                     self.tensor_attrs[v] = ".".join(prefix_atoms + [k])
@@ -797,7 +786,7 @@ class Tracer(TracerBase):
             fn, isinstance(root, torch.nn.Module), concrete_args
         )

-        parameter_proxy_cache: Dict[
+        parameter_proxy_cache: dict[
             str, Proxy
         ] = {}  # Reduce number of get_attr calls
@@ -872,7 +861,7 @@ class Tracer(TracerBase):
                 nonlocal cnt
                 cnt += 1
                 param = sig.parameters[name]
-                default: Tuple[Any, ...] = (
+                default: tuple[Any, ...] = (
                     () if param.default is inspect.Parameter.empty else (param.default,)
                 )
                 out = self.create_proxy(
@@ -913,7 +902,7 @@ class Tracer(TracerBase):
                 return pytree.tree_map(replace_ph, concrete_args[name])
             if name[0] == "*":
-                default: Tuple[Any, ...] = ()
+                default: tuple[Any, ...] = ()
             else:
                 param = sig.parameters[name]
                 default = (  # type: ignore[assignment]
@@ -932,11 +921,11 @@ class Tracer(TracerBase):
 # the purposes of the wrap() API.
 # We key by the globals dict id and function name to ensure we're wrapping a given
 # function only once.
-_wrapped_fns_to_patch: Dict[Tuple[int, str], dict] = {}
+_wrapped_fns_to_patch: dict[tuple[int, str], dict] = {}

 # List of methods on classes to wrap (class type, function name)
 # this currently only works for Tensor.* methods that aren't traced properly
-_wrapped_methods_to_patch: List[Tuple[type, str]] = []
+_wrapped_methods_to_patch: list[tuple[type, str]] = []

 if os.environ.get("FX_PATCH_GETITEM") == "1":
     # This change is needed to trace models like PositionalEmbedding from BERT:
@@ -1043,12 +1032,12 @@ class _PatchedFnSetAttr(_PatchedFn):
 class _Patcher:
     def __init__(self) -> None:
         super().__init__()
-        self.patches_made: List[_PatchedFn] = []
-        self.visited: Set[int] = set()
+        self.patches_made: list[_PatchedFn] = []
+        self.visited: set[int] = set()

     def patch(
         self,
-        frame_dict: Dict[str, Any],
+        frame_dict: dict[str, Any],
         name: str,
         new_fn: Callable,
         deduplicate: bool = True,
@@ -1169,7 +1158,7 @@ def _patch_wrapped_functions(patcher: _Patcher):
 def _autowrap_check(
-    patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int]
+    patcher: _Patcher, frame_dict: dict[str, Any], function_ids: set[int]
 ):
     """
     Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them.
@@ -1252,7 +1241,7 @@ def wrap(fn_or_name: Union[str, Callable]):
 @compatibility(is_backward_compatible=True)
 def symbolic_trace(
     root: Union[torch.nn.Module, Callable[..., Any]],
-    concrete_args: Optional[Dict[str, Any]] = None,
+    concrete_args: Optional[dict[str, Any]] = None,
 ) -> GraphModule:
     """
     Symbolic tracing API
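
For reference, the `concrete_args` parameter whose annotation changed here is the documented way to partially specialize a trace; a short usage sketch (adapted from the `symbolic_trace` docstring):

    from torch.fx import symbolic_trace

    def f(x, flag):
        if flag:
            return x
        return x * 2

    # Fixing `flag` lets the data-dependent branch be traced through;
    # the resulting graph no longer takes `flag` as an input.
    traced = symbolic_trace(f, concrete_args={"flag": True})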

View File

@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import sys
-from typing import Dict, Optional
+from typing import Optional

 import torch
 from torch._logging import LazyString
@@ -43,7 +43,7 @@ def _format_graph_code(name, filename, graph_str):
     return f"TRACED GRAPH\n {name} {filename} {graph_str}\n"

-def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[Dict]:
+def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[dict]:
     """
     Returns the nn_module_stack of the first call_function node.
     """

View File

@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import operator
 from collections import deque
-from typing import Deque, Dict, List, NamedTuple, Set, Tuple
+from typing import NamedTuple

 import torch
 from torch.fx.experimental.partitioner_utils import (
@@ -28,15 +28,15 @@ class DAGNode:
     def __init__(
         self,
         submodule_node: Node,
-        input_nodes: List[Node],
-        output_nodes: List[Node],
-        logical_device_ids: List[int],
+        input_nodes: list[Node],
+        output_nodes: list[Node],
+        logical_device_ids: list[int],
         size_bytes: int,
     ) -> None:
         self.submodule_node: Node = submodule_node
-        self.input_nodes: List[Node] = input_nodes
-        self.output_nodes: List[Node] = output_nodes
-        self.logical_device_ids: List[int] = logical_device_ids
+        self.input_nodes: list[Node] = input_nodes
+        self.output_nodes: list[Node] = output_nodes
+        self.logical_device_ids: list[int] = logical_device_ids
         self.size_bytes = size_bytes

     def __str__(self) -> str:
@@ -47,14 +47,14 @@ class DAG:
     """DAG class contains all the DAG nodes"""

     def __init__(self) -> None:
-        self.nodes: List[DAGNode] = []
+        self.nodes: list[DAGNode] = []

     def create_node(
         self,
         submodule_node: Node,
-        input_nodes: List[Node],
-        output_nodes: List[Node],
-        logical_devices: List[int],
+        input_nodes: list[Node],
+        output_nodes: list[Node],
+        logical_devices: list[int],
         size_bytes: int,
     ) -> None:
         node = DAGNode(
@@ -79,7 +79,7 @@ def reset_partition_device(partitions):
 def combine_two_partitions(
-    partition_0: Partition, partition_1: Partition, partitions: List[Partition]
+    partition_0: Partition, partition_1: Partition, partitions: list[Partition]
 ) -> None:
     """Given a list of partitions and its two partitions,
     combine these two partitions into a new one appending to the partitions
@@ -95,7 +95,7 @@ def combine_two_partitions(
     return

-def set_parents_and_children(partitions: List[Partition]) -> None:
+def set_parents_and_children(partitions: list[Partition]) -> None:
     """Given a list of partitions, mark parents and children for each partition"""
     # Go through all nodes in a partition.
     # If a node's user is in other partition,
@@ -119,7 +119,7 @@ def set_parents_and_children(partitions: List[Partition]) -> None:
     return

-def reorganize_partitions(partitions: List[Partition]) -> None:
+def reorganize_partitions(partitions: list[Partition]) -> None:
     """Given a list of partitions, reorganize partition id,
     its parents and its children for each partition
     """
@@ -130,17 +130,17 @@ def reorganize_partitions(partitions: List[Partition]) -> None:
     return

-def get_bfs_level_partition(partitions: List[Partition]) -> None:
+def get_bfs_level_partition(partitions: list[Partition]) -> None:
     """Given a list of partitions,
     mark the bfs level for each partition
     """
-    current_level: Set[Partition] = set()
-    visited: Set[Partition] = set()
+    current_level: set[Partition] = set()
+    visited: set[Partition] = set()
     for partition in partitions:
         # If a partition has no parent, it should be in root level
         if len(partition.parents) == 0:
             current_level.add(partition)
-    next_level: Set[Partition] = set()
+    next_level: set[Partition] = set()
     level = 0
     # bfs
     while current_level:
@@ -158,26 +158,26 @@ def get_bfs_level_partition(partitions: List[Partition]) -> None:
     return

-def get_node_to_partition_mapping(partitions: List[Partition]) -> Dict[Node, int]:
+def get_node_to_partition_mapping(partitions: list[Partition]) -> dict[Node, int]:
     """Given a list of partitions,return node to partition mapping"""
-    node_to_partition: Dict[Node, int] = {}
+    node_to_partition: dict[Node, int] = {}
     for partition in partitions:
         for node in partition.nodes:
             node_to_partition[node] = partition.partition_id
     return node_to_partition

-def get_logical_id_to_device(devices: List[Device]) -> Dict[int, Device]:
+def get_logical_id_to_device(devices: list[Device]) -> dict[int, Device]:
     """Get a mapping from device logical ID to Device object."""
-    logical_id_to_device: Dict[int, Device] = {}
+    logical_id_to_device: dict[int, Device] = {}
     for d in devices:
         logical_id_to_device[d.logical_id] = d
     return logical_id_to_device

 def get_device_partition_stats(
-    partitions: List[Partition], devices: List[Device]
-) -> Tuple[Dict[Device, List[Partition]], Dict[Device, int], List[Partition]]:
+    partitions: list[Partition], devices: list[Device]
+) -> tuple[dict[Device, list[Partition]], dict[Device, int], list[Partition]]:
     """Given a list of partitions and a list of devices, returns:
     1. A mapping from device to partitions on it;
     2. A mapping from device to its remaining memory size;
@@ -186,9 +186,9 @@ def get_device_partition_stats(
     # logical id to device
     logical_id_to_device = get_logical_id_to_device(devices)
     # Track partitions on device
-    device_to_partitions: Dict[Device, List[Partition]] = {}
+    device_to_partitions: dict[Device, list[Partition]] = {}
     # Track device's left mem size
-    device_to_left_mem_bytes: Dict[Device, int] = {}
+    device_to_left_mem_bytes: dict[Device, int] = {}
     for d in devices:
         device_to_partitions[d] = []
         device_to_left_mem_bytes[d] = d.available_mem_bytes
@@ -213,16 +213,16 @@ def get_device_partition_stats(
 def get_device_to_partitions_mapping(
-    partitions: List[Partition], devices: List[Device]
+    partitions: list[Partition], devices: list[Device]
 ):
     """Given a list of partitions and a list of devices,
     map each partition into a device.
     """

     def calculate_extra_mem_bytes_needed_for(
-        partition: Partition, partitions: List[Partition]
+        partition: Partition, partitions: list[Partition]
     ):
-        all_nodes: Set[Node] = set()
+        all_nodes: set[Node] = set()
         for p in partitions:
             all_nodes = all_nodes.union(p.nodes)
         if len(all_nodes) == 0:
@@ -273,8 +273,8 @@ def check_dependency(partition):
     """Given a partition,check if there is a circular dependency on
     this partition using bfs
     """
-    visited: Set[Partition] = {partition}
-    queue: Deque[Partition] = deque([partition])
+    visited: set[Partition] = {partition}
+    queue: deque[Partition] = deque([partition])
     while queue:
         p = queue.popleft()
         for child in p.children:
@@ -298,9 +298,9 @@ class Partitioner:
     """

     def __init__(self) -> None:
-        self.partitions: List[Partition] = []
-        self.node_to_partition: Dict[Node, int] = {}
-        self.devices: List[Device] = []
+        self.partitions: list[Partition] = []
+        self.node_to_partition: dict[Node, int] = {}
+        self.devices: list[Device] = []

     def partition_graph(
         self,
@@ -435,9 +435,9 @@ class Partitioner:
             return device

         # Track partition and its left mem size
-        partition_to_left_mem_bytes: Dict[Partition, int] = {}
+        partition_to_left_mem_bytes: dict[Partition, int] = {}
         # Track all the devices that have been used
-        occupied_devices: List[Device] = []
+        occupied_devices: list[Device] = []
         partition = self.create_partition()
         for node in self.graph_module.graph.nodes:
             if node.op in {"call_module", "call_method", "call_function"}:
@@ -516,7 +516,7 @@ class Partitioner:
         # Devices that hold partitions
         used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0]
         # Track replicates of the assigned devices
-        replicated_device_to_used_device: Dict[Device, Device] = {}
+        replicated_device_to_used_device: dict[Device, Device] = {}

         while len(used_devices) * 2 + len(replicated_device_to_used_device) <= len(
             self.devices
@@ -583,7 +583,7 @@ class Partitioner:
                 continue
             if node.target == operator.__getitem__:
                 continue
-            input_nodes: Dict[Node, None] = {}
+            input_nodes: dict[Node, None] = {}
             map_arg(node.args, input_nodes.setdefault)
             map_arg(node.kwargs, input_nodes.setdefault)
             # When a node has two or more output nodes,
@@ -634,7 +634,7 @@ class Partitioner:
         """

         def combine_partitions_based_on_size(
-            partitions: List[Partition], available_mem_bytes: int
+            partitions: list[Partition], available_mem_bytes: int
         ) -> None:
             """Combining small partitions together to keep as less partitions as possible.
             Here is an example of the algorithm to do this:
@@ -672,10 +672,10 @@ class Partitioner:
             return mem_bytes_needed

         def find_partition_to_combine_based_on_size(
-            sorted_partitions: List[Partition],
+            sorted_partitions: list[Partition],
             available_mem_bytes: int,
-            partitions: List[Partition],
-        ) -> Tuple[bool, List[Partition]]:
+            partitions: list[Partition],
+        ) -> tuple[bool, list[Partition]]:
             """step 1 in combine_partition_based_on_size()"""
             find_combination = False
             smallest_partition = sorted_partitions.pop(0)
@@ -721,8 +721,8 @@ class Partitioner:
                 return False

         # Track embedding partitions and non-embedding partitions separately
-        embedding_partitions: List[Partition] = []
-        non_embedding_partitions: List[Partition] = []
+        embedding_partitions: list[Partition] = []
+        non_embedding_partitions: list[Partition] = []
         # A Flag to check the boundary
         in_embedding_region: bool = False
         partition = self.create_partition()
@@ -794,7 +794,7 @@ class Partitioner:
     def cost_aware_partition(
         self,
         transfer_rate_bytes_per_sec: float,
-        node_to_latency_mapping: Dict[Node, NodeLatency],
+        node_to_latency_mapping: dict[Node, NodeLatency],
     ) -> None:
         """This method is to partition the fx module based on the cost.
         The cost is the total latency of running the whole fx module.
@@ -872,7 +872,7 @@ class Partitioner:
             )
             if len(self.partitions) == 1:
                 return False
-            partition_pair: List[int] = []
+            partition_pair: list[int] = []
             for i in range(len(self.partitions) - 1):
                 for j in range(i + 1, len(self.partitions)):
                     # Try to combine the partition pair
@@ -915,7 +915,7 @@ class Partitioner:
     def kl_based_partition(
         self,
         transfer_rate_bytes_per_sec: float,
-        node_to_latency_mapping: Dict[Node, NodeLatency],
+        node_to_latency_mapping: dict[Node, NodeLatency],
     ) -> None:
         """This function is a cost aware partition based
         on Kernighan-Lin algorithm.
@@ -987,7 +987,7 @@ class Partitioner:
         """
         p1_nodes = list(p1.nodes) + [None]
         min_cost = float("inf")
-        node_pair: List[Node] = []
+        node_pair: list[Node] = []
         for n1 in p1_nodes:
             # Ignore the node if it is not a op node
             if n1 is not None and n1.op in {"placeholder", "get_attr"}:
@@ -1011,9 +1011,9 @@ class Partitioner:
             self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec
         )
         # Keep tracking the node pair that shows the better cost
-        node_pair: List[Node] = []
+        node_pair: list[Node] = []
         # Keep tracking the partition pair of node pair
-        partition_pair: List[Partition] = []
+        partition_pair: list[Partition] = []
         # Collect all the op nodes from the graph
         op_nodes = [
             n
@@ -1060,7 +1060,7 @@ class Partitioner:
        """This function helps to rebuild the partitions given the nodes and its
        corresponding partition id
        """
-        partition_id_to_partition_mapping: Dict[int, Partition] = {}
+        partition_id_to_partition_mapping: dict[int, Partition] = {}
        self.node_to_partition = node_to_partition_mapping
        for node in self.node_to_partition:
            partition_id = self.node_to_partition[node]

View File

@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import re
-from typing import Callable, Dict, Optional, Set, Union
+from typing import Callable, Optional, Union

 import torch.fx
 from torch.fx.node import map_arg
@@ -100,7 +100,7 @@ def _inline_module(gm: torch.fx.GraphModule, inline_mod_name: str):
     call_mod_args = call_mod_node_to_replace.args
     call_mod_kwargs = call_mod_node_to_replace.kwargs

-    replacement_mapping: Dict[torch.fx.Node, torch.fx.Node] = {}
+    replacement_mapping: dict[torch.fx.Node, torch.fx.Node] = {}
     ph_count = 0

     def replacement_fn(node):
@@ -171,7 +171,7 @@ def split_const_subgraphs(
     # Build up a list of const_nodes, defined as nodes that are themselves
     # get_attrs, or have all get_attr or other constant node inputs.
-    const_nodes: Set[torch.fx.Node] = set()
+    const_nodes: set[torch.fx.Node] = set()
     found_const_folding = False
     for node in mod_traced.graph.nodes:
         # Skip over placeholders/outputs because they can't be const folded and

View File

@@ -1,4 +1,4 @@
-from typing import List, Sequence
+from collections.abc import Sequence

 import torch.fx as fx
@@ -19,7 +19,7 @@ def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
     the `gm` with breakpoint inserted.
     """

-    def insert_pdb(body: Sequence[str]) -> List[str]:
+    def insert_pdb(body: Sequence[str]) -> list[str]:
         return ["import pdb; pdb.set_trace()\n", *body]

     with gm.graph.on_generate_code(
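
The hunk is cut off at the `with` statement; for context, the documented `on_generate_code` pattern this function follows looks roughly like this (sketch of the surrounding code, not part of the diff):

    with gm.graph.on_generate_code(
        # Compose with any code transform that is already installed.
        lambda current_transform: (
            lambda body: insert_pdb(
                current_transform(body) if current_transform else body
            )
        )
    ):
        # Recompiling under the context manager regenerates gm.forward
        # with the pdb breakpoint as its first line.
        gm.recompile()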

View File

@@ -2,7 +2,7 @@
 import itertools
 import operator
 from functools import reduce
-from typing import Callable, Dict, TypeVar
+from typing import Callable, TypeVar
 from typing_extensions import ParamSpec

 import sympy
@@ -19,9 +19,9 @@ from torch.nn.modules.conv import Conv2d
 _T = TypeVar("_T")
 _P = ParamSpec("_P")

-_INFERENCE_RULES: Dict[Target, Callable] = {}
-_REFINEMENT_RULES: Dict[Target, Callable] = {}
-_RULES: Dict[Target, Callable] = {}
+_INFERENCE_RULES: dict[Target, Callable] = {}
+_REFINEMENT_RULES: dict[Target, Callable] = {}
+_RULES: dict[Target, Callable] = {}

 __all__ = [
     "GraphTypeChecker",

View File

@@ -1,7 +1,6 @@
 # mypy: allow-untyped-defs
 import itertools
 import operator
-from typing import Dict, List, Tuple

 import torch
 from torch.fx._symbolic_trace import symbolic_trace
@@ -10,8 +9,8 @@ from torch.fx.passes.tools_common import legalize_graph
 def split_result_tensors(
-    result: torch.Tensor, inputs: List[torch.Tensor]
-) -> Tuple[torch.Tensor, ...]:
+    result: torch.Tensor, inputs: list[torch.Tensor]
+) -> tuple[torch.Tensor, ...]:
     """
     A free function for use in the merge_matmul graph transformation below that
     splits the output from a merged matmul into the individual results for each
@@ -71,7 +70,7 @@ def may_depend_on(a: Node, b: Node, search_depth: int = 6):
     return False

-def are_nodes_independent(nodes: List[Node]):
+def are_nodes_independent(nodes: list[Node]):
     """
     Check if all of the given nodes are pairwise-data independent.
@@ -102,8 +101,8 @@ def merge_matmul(in_mod: torch.nn.Module):
     """
     gm = symbolic_trace(in_mod)

-    rhs_users: Dict[Node, List[Node]] = {}
-    lhs_users: Dict[Node, List[Node]] = {}
+    rhs_users: dict[Node, list[Node]] = {}
+    lhs_users: dict[Node, list[Node]] = {}

     # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to
     # the matmul of which they are the LHS/RHS.
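
On the `split_result_tensors` docstring above: one plausible shape of that helper is a `torch.split` by each input's leading dimension (a sketch of my reading, not necessarily the exact body in the file):

    import torch

    def split_result_tensors_sketch(
        result: torch.Tensor, inputs: list[torch.Tensor]
    ) -> tuple[torch.Tensor, ...]:
        # Each merged input contributed its first dimension to the concatenated
        # matmul result, so splitting by those sizes recovers per-input outputs.
        sizes = [inp.shape[0] for inp in inputs]
        return torch.split(result, sizes)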

View File

@@ -2,7 +2,7 @@
 import builtins
 import functools
 import warnings
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Callable, Optional, Union

 import torch
 import torch.fx
@@ -40,7 +40,7 @@ def torch_abs_override(input, *, out=None):
     return input

-manual_meta_overrides: Dict[Callable, Callable] = {
+manual_meta_overrides: dict[Callable, Callable] = {
     torch.nn.Embedding: embedding_override,
     torch.nn.LayerNorm: nn_layernorm_override,
     torch.relu: torch_relu_override,
@@ -274,7 +274,7 @@ class MetaTracer(torch.fx.Tracer):
     def proxy(self, node):
         return MetaProxy(node, self)

-    def trace(self, root, meta_args: Dict[str, torch.Tensor], concrete_args=None):  # type: ignore[override]
+    def trace(self, root, meta_args: dict[str, torch.Tensor], concrete_args=None):  # type: ignore[override]
         assert isinstance(meta_args, dict)

         self.meta_args = meta_args
@@ -299,8 +299,8 @@ class MetaTracer(torch.fx.Tracer):
 def symbolic_trace(
     root: Union[torch.nn.Module, Callable[..., Any]],
-    meta_args: Optional[Dict[str, torch.Tensor]] = None,
-    concrete_args: Optional[Dict[str, Any]] = None,
+    meta_args: Optional[dict[str, torch.Tensor]] = None,
+    concrete_args: Optional[dict[str, Any]] = None,
 ) -> torch.fx.GraphModule:
     tracer = MetaTracer()
     graph = tracer.trace(root, meta_args, concrete_args)  # type: ignore[arg-type]

View File

@@ -1,7 +1,8 @@
 # mypy: allow-untyped-defs
 import operator
 import warnings
-from typing import Callable, Dict, Iterable, TypeVar
+from collections.abc import Iterable
+from typing import Callable, TypeVar
 from typing_extensions import ParamSpec

 import torch
@@ -57,7 +58,7 @@ from torch.nn.modules.conv import Conv2d
 _T = TypeVar("_T")
 _P = ParamSpec("_P")

-_INFERENCE_RULES: Dict[Target, Callable] = {}
+_INFERENCE_RULES: dict[Target, Callable] = {}

 MAX_TENSOR_RANK = 4

View File

@@ -1,7 +1,7 @@
 # mypy: ignore-errors
 import copy
 import itertools
-from typing import Callable, Dict, List
+from typing import Callable

 from torch.fx.experimental.migrate_gradual_types.constraint import (
     ApplyBroadcasting,
@@ -50,7 +50,7 @@ from torch.fx.experimental.migrate_gradual_types.util import (
 from torch.fx.tensor_type import Dyn, TensorType

-_TRANSFORMATION_RULES: Dict[Constraint, Callable] = {}
+_TRANSFORMATION_RULES: dict[Constraint, Callable] = {}

 def register_transformation_rule(call_target):
@@ -797,7 +797,7 @@ def transform_constraint(constraint: Constraint, counter: int):
     return constraint, counter

-def calc_last_two_dims(constraint, d: List[DVar]):
+def calc_last_two_dims(constraint, d: list[DVar]):
     """
     Generates constraints for the last two dimensions of a convolution or a maxpool output
     Args:
@@ -866,7 +866,7 @@ def calc_last_two_dims(constraint, d: List[DVar]):
     return c4, c5

-def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]):
+def generate_all_int_dyn_dim_possibilities(my_list: list[DVar]):
     """
     Generate all possibilities of being equal or not equal to dyn for my_list
     Args:
@@ -888,7 +888,7 @@ def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]):
     return all_possibilities

-def is_target_div_by_dim(target: List[int], dim: List[DVar]):
+def is_target_div_by_dim(target: list[int], dim: list[DVar]):
     """
     Generate constraints to check if the target dimensions are divisible by the input dimensions
     Args:
@@ -901,7 +901,7 @@ def is_target_div_by_dim(target: List[int], dim: List[DVar]):
     return BinConstraintD(BinConstraintD(Prod(target), dim, op_mod), 0, op_eq)

-def is_dim_div_by_target(target: List[int], dim: List[DVar]):
+def is_dim_div_by_target(target: list[int], dim: list[DVar]):
     """
     Generate constraints to check if the input dimensions is divisible by the target dimensions
     Args:
@@ -1000,9 +1000,9 @@ def apply_padding(
     e11: BinConstraintT,
     e2: BinConstraintT,
     e12: BinConstraintT,
-    d2: List[DVar],
-    d11: List[DVar],
-    d12: List[DVar],
+    d2: list[DVar],
+    d11: list[DVar],
+    d12: list[DVar],
     counter: int,
 ):
     """
@@ -1068,7 +1068,7 @@ def apply_padding(
 def no_broadcast_dim_with_index(
-    d1: List[DVar], d2: List[DVar], d3: List[DVar], d4: List[DVar], i: int
+    d1: list[DVar], d2: list[DVar], d3: list[DVar], d4: list[DVar], i: int
 ):
     """
     Args:
@@ -1129,10 +1129,10 @@ def create_equality_constraints_for_broadcasting(
     e2: TVar,
     e11: TVar,
     e12: TVar,
-    d1: List[DVar],
-    d2: List[DVar],
-    d11: List[DVar],
-    d12: List[DVar],
+    d1: list[DVar],
+    d2: list[DVar],
+    d11: list[DVar],
+    d12: list[DVar],
 ):
     """
     Create equality constraints for when no broadcasting occurs
@@ -1236,7 +1236,7 @@ def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int):
 def generate_all_broadcasting_possibilities_no_padding(
-    d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar]
+    d1: list[DVar], d2: list[DVar], d11: list[DVar], d12: list[DVar]
 ):
     """
     Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension.

View File

@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import operator
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Optional

 import torch
 import torch.fx
@@ -38,7 +38,7 @@ class NormalizeArgs(Transformer):
         self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True
     ):
         super().__init__(module)
-        self.node_map: Dict[Proxy, Node] = {}
+        self.node_map: dict[Proxy, Node] = {}
         self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs

     def run_node(self, n: Node) -> Any:
@@ -66,10 +66,10 @@ class NormalizeArgs(Transformer):
     def call_function(
         self,
         target: Target,
-        args: Tuple[Argument, ...],
-        kwargs: Dict[str, Any],
-        arg_types: Optional[Tuple[Any, ...]] = None,
-        kwarg_types: Optional[Dict[str, Any]] = None,
+        args: tuple[Argument, ...],
+        kwargs: dict[str, Any],
+        arg_types: Optional[tuple[Any, ...]] = None,
+        kwarg_types: Optional[dict[str, Any]] = None,
     ):
         assert callable(target)
         new_args_and_kwargs = normalize_function(
@@ -89,7 +89,7 @@ class NormalizeArgs(Transformer):
         return super().call_function(target, args, kwargs)

     def call_module(
-        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+        self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
     ):
         assert isinstance(target, str)
         new_args_and_kwargs = normalize_module(
@@ -124,7 +124,7 @@ class NormalizeOperators(AnnotateTypesWithSchema):
         traced = NormalizeOperators(traced).transform()
     """

-    binary_magic_method_remap: Dict[
+    binary_magic_method_remap: dict[
         Callable[[Any, Any], Any], Callable[[Any, Any], Any]
     ] = {
         torch.add: operator.add,
@@ -142,7 +142,7 @@ class NormalizeOperators(AnnotateTypesWithSchema):
     }

     def call_function(
-        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+        self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
     ):
         # Normalize operators according to the magic methods implemented on tensors here:
         # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137  # noqa: B950

View File

@ -4,8 +4,9 @@ import logging
import operator import operator
import time import time
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable
from enum import Enum from enum import Enum
from typing import Any, cast, Dict, Iterable, List, Optional, Tuple, Type from typing import Any, cast, Optional
import torch import torch
import torch.fx as fx import torch.fx as fx
@ -33,7 +34,7 @@ __all__ = [
] ]
def _parent_name(target: str) -> Tuple[str, str]: def _parent_name(target: str) -> tuple[str, str]:
""" """
Splits a qualname into parent path and last atom. Splits a qualname into parent path and last atom.
For example, `foo.bar.baz` -> (`foo.bar`, `baz`) For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
@ -44,11 +45,11 @@ def _parent_name(target: str) -> Tuple[str, str]:
# Works for length 2 patterns with 2 modules # Works for length 2 patterns with 2 modules
def matches_module_pattern( def matches_module_pattern(
pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any] pattern: Iterable[type], node: fx.Node, modules: dict[str, Any]
): ):
if len(node.args) == 0: if len(node.args) == 0:
return False return False
nodes: Tuple[Any, fx.Node] = (node.args[0], node) nodes: tuple[Any, fx.Node] = (node.args[0], node)
for expected_type, current_node in zip(pattern, nodes): for expected_type, current_node in zip(pattern, nodes):
if not isinstance(current_node, fx.Node): if not isinstance(current_node, fx.Node):
return False return False
@ -64,7 +65,7 @@ def matches_module_pattern(
def replace_node_module( def replace_node_module(
node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module node: fx.Node, modules: dict[str, Any], new_module: torch.nn.Module
): ):
assert isinstance(node.target, str) assert isinstance(node.target, str)
parent_name, name = _parent_name(node.target) parent_name, name = _parent_name(node.target)
@ -120,7 +121,7 @@ def remove_dropout(model: nn.Module) -> nn.Module:
class DropoutRemover(torch.fx.Transformer): class DropoutRemover(torch.fx.Transformer):
def call_module( def call_module(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
) -> Any: ) -> Any:
if isinstance(self.submodules[target], nn.Dropout): if isinstance(self.submodules[target], nn.Dropout):
assert len(args) == 1 assert len(args) == 1
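A hedged end-to-end sketch of the transformer, filling in the branch the hunk elides (returning the dropout input unchanged) so it runs standalone:

import torch
import torch.fx as fx
from torch import nn

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.drop = nn.Dropout(0.5)

    def forward(self, x):
        return self.drop(x) + 1

class DropoutRemover(fx.Transformer):
    def call_module(self, target, args, kwargs):
        if isinstance(self.submodules[target], nn.Dropout):
            assert len(args) == 1
            return args[0]  # pass the dropout input straight through
        return super().call_module(target, args, kwargs)

gm = DropoutRemover(fx.symbolic_trace(Net())).transform()
print(gm.code)  # the dropout call is gone; only the `+ 1` remains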
@ -133,15 +134,15 @@ def remove_dropout(model: nn.Module) -> nn.Module:
def extract_subgraph( def extract_subgraph(
orig_module: nn.Module, orig_module: nn.Module,
nodes: List[fx.Node], nodes: list[fx.Node],
inputs: List[fx.Node], inputs: list[fx.Node],
outputs: List[fx.Node], outputs: list[fx.Node],
): ):
""" """
Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph. Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph.
""" """
new_graph = fx.Graph() new_graph = fx.Graph()
env: Dict[fx.Node, fx.Node] = {} env: dict[fx.Node, fx.Node] = {}
for input in inputs: for input in inputs:
new_node = new_graph.placeholder(input.name) new_node = new_graph.placeholder(input.name)
env[input] = new_node env[input] = new_node
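The rest of the loop (elided by the hunk) copies each interior node into the new graph through env; a hedged sketch of that copy pattern:

import torch.fx as fx

def copy_subgraph(nodes, inputs, outputs) -> fx.Graph:
    new_graph = fx.Graph()
    env: dict[fx.Node, fx.Node] = {}
    for inp in inputs:
        env[inp] = new_graph.placeholder(inp.name)
    for node in nodes:
        # node_copy remaps every argument through env, keeping edges consistent.
        env[node] = new_graph.node_copy(node, lambda n: env[n])
    new_graph.output([env[out] for out in outputs])
    return new_graph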
@ -180,13 +181,13 @@ mkldnn_map = {
} }
def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]): def modules_to_mkldnn(nodes: list[fx.Node], modules: dict[str, nn.Module]):
""" """
For each node, if it's a module that can be preconverted into MKLDNN, For each node, if it's a module that can be preconverted into MKLDNN,
we do so and record a mapping that lets us convert the MKLDNN we do so and record a mapping that lets us convert the MKLDNN
version of the module back to the original. version of the module back to the original.
""" """
old_modules: Dict[nn.Module, nn.Module] = {} old_modules: dict[nn.Module, nn.Module] = {}
for node in nodes: for node in nodes:
if node.op == "call_module": if node.op == "call_module":
assert isinstance(node.target, str) assert isinstance(node.target, str)
def reset_modules( def reset_modules(
nodes: List[fx.Node], nodes: list[fx.Node],
modules: Dict[str, nn.Module], modules: dict[str, nn.Module],
old_modules: Dict[nn.Module, nn.Module], old_modules: dict[nn.Module, nn.Module],
): ):
""" """
Maps each module that's been changed with `modules_to_mkldnn` back to its Maps each module that's been changed with `modules_to_mkldnn` back to its
@ -219,9 +220,9 @@ def reset_modules(
class MklSubgraph: class MklSubgraph:
def __init__(self, fx_graph: fx.Graph): def __init__(self, fx_graph: fx.Graph):
self.fx_graph = fx_graph self.fx_graph = fx_graph
self.nodes: List[fx.Node] = [] self.nodes: list[fx.Node] = []
self.start_nodes: List[fx.Node] = [] self.start_nodes: list[fx.Node] = []
self.end_nodes: List[fx.Node] = [] self.end_nodes: list[fx.Node] = []
def gen_mkl_autotuner(example_inputs, iters=10, warmup=1): def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
@ -244,7 +245,7 @@ def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
old_modules = graph.fx_graph.old_modules # type: ignore[attr-defined] old_modules = graph.fx_graph.old_modules # type: ignore[attr-defined]
ShapeProp(fx_model).propagate(example_inputs) ShapeProp(fx_model).propagate(example_inputs)
sample_inputs = [torch.randn(node.shape) for node in input_nodes] # type: ignore[attr-defined] sample_inputs = [torch.randn(node.shape) for node in input_nodes] # type: ignore[attr-defined]
output_args = cast(List[fx.Node], [node.args[0] for node in graph.end_nodes]) output_args = cast(list[fx.Node], [node.args[0] for node in graph.end_nodes])
submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args) submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args)
def benchmark(f): def benchmark(f):
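The benchmark closure itself falls outside the hunk; a hedged standalone version of the warmup-then-average harness it implements:

import time

def benchmark(f, iters: int = 10, warmup: int = 1) -> float:
    for _ in range(warmup):
        f()  # warmup runs are excluded from the measurement
    begin = time.time()
    for _ in range(iters):
        f()
    return (time.time() - begin) / iters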
@ -281,8 +282,8 @@ def use_mkl_length(graph: MklSubgraph) -> bool:
class UnionFind: class UnionFind:
def __init__(self, n): def __init__(self, n):
self.parent: List[Optional[int]] = [None] * n self.parent: list[Optional[int]] = [None] * n
self.size: List[int] = [0] * n self.size: list[int] = [0] * n
def make_set(self, v: int): def make_set(self, v: int):
self.parent[v] = v self.parent[v] = v
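The find/join halves of UnionFind sit outside this hunk; a hedged sketch of the whole structure (path halving in find, union by size in join; make_set must run on a vertex before find/join):

from typing import Optional

class UnionFind:
    def __init__(self, n: int) -> None:
        self.parent: list[Optional[int]] = [None] * n
        self.size: list[int] = [0] * n

    def make_set(self, v: int) -> None:
        self.parent[v] = v
        self.size[v] = 1

    def find(self, v: int) -> int:
        while self.parent[v] != v:
            self.parent[v] = self.parent[self.parent[v]]  # path halving
            v = self.parent[v]
        return v

    def join(self, a: int, b: int) -> None:
        a, b = self.find(a), self.find(b)
        if a == b:
            return
        if self.size[a] < self.size[b]:
            a, b = b, a
        self.parent[b] = a  # attach the smaller tree under the larger
        self.size[a] += self.size[b]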
@ -308,8 +309,8 @@ class UnionFind:
def optimize_for_inference( def optimize_for_inference(
model: torch.nn.Module, model: torch.nn.Module,
pass_config: Optional[Dict[str, Any]] = None, pass_config: Optional[dict[str, Any]] = None,
tracer: Type[fx.Tracer] = fx.Tracer, tracer: type[fx.Tracer] = fx.Tracer,
) -> torch.nn.Module: ) -> torch.nn.Module:
""" """
Performs a set of optimization passes to optimize a model for the Performs a set of optimization passes to optimize a model for the
@ -348,7 +349,7 @@ def optimize_for_inference(
cur_tracer = tracer() cur_tracer = tracer()
fx_graph = cur_tracer.trace(copy.deepcopy(model)) fx_graph = cur_tracer.trace(copy.deepcopy(model))
fx.GraphModule(cur_tracer.root, fx_graph) fx.GraphModule(cur_tracer.root, fx_graph)
modules: Dict[str, nn.Module] = dict(model.named_modules()) modules: dict[str, nn.Module] = dict(model.named_modules())
class MklSupport(Enum): class MklSupport(Enum):
NO = 1 NO = 1
@ -388,7 +389,7 @@ def optimize_for_inference(
node.args, lambda n: fx_graph.call_method("to_mkldnn", (n,)) node.args, lambda n: fx_graph.call_method("to_mkldnn", (n,))
) )
node.args = cast(Tuple[fx.node.Argument], mkldnn_args) node.args = cast(tuple[fx.node.Argument], mkldnn_args)
with fx_graph.inserting_after(node): with fx_graph.inserting_after(node):
dense_x = fx_graph.create_node("call_method", "to_dense", (node,)) dense_x = fx_graph.create_node("call_method", "to_dense", (node,))
@ -455,7 +456,7 @@ def optimize_for_inference(
for other_color in cur_colors[1:]: for other_color in cur_colors[1:]:
uf.join(cur_colors[0], other_color) uf.join(cur_colors[0], other_color)
mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph)) mkldnn_graphs: dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph))
for node in fx_graph.nodes: for node in fx_graph.nodes:
if hasattr(node, "color"): if hasattr(node, "color"):
mkldnn_graphs[uf.find(node.color)].nodes.append(node) mkldnn_graphs[uf.find(node.color)].nodes.append(node)
View File
@ -1,6 +1,6 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from enum import Enum from enum import Enum
from typing import Dict, List, NamedTuple, Set from typing import NamedTuple
from torch.fx.node import map_arg, Node from torch.fx.node import map_arg, Node
@ -11,13 +11,13 @@ class Partition:
""" """
def __init__(self, partition_id: int) -> None: def __init__(self, partition_id: int) -> None:
self.nodes: Set[Node] = set() self.nodes: set[Node] = set()
self.partition_id = partition_id self.partition_id = partition_id
self.parents: Set[Partition] = set() self.parents: set[Partition] = set()
self.children: Set[Partition] = set() self.children: set[Partition] = set()
self.bfs_level: int = -1 self.bfs_level: int = -1
self.used_mem_bytes: int = 0 self.used_mem_bytes: int = 0
self.logical_device_ids: List[int] = [] self.logical_device_ids: list[int] = []
def __str__(self): def __str__(self):
return str(self.partition_id) return str(self.partition_id)
@ -28,7 +28,7 @@ class Partition:
self.used_mem_bytes += get_extra_size_of(node, self.nodes) self.used_mem_bytes += get_extra_size_of(node, self.nodes)
def add_node(self, node): def add_node(self, node):
input_nodes: Dict[Node, None] = {} input_nodes: dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault) map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault) map_arg(node.kwargs, input_nodes.setdefault)
# Add current node's input nodes if they are placeholder or constants # Add current node's input nodes if they are placeholder or constants
@ -43,7 +43,7 @@ class Partition:
if node in self.nodes: if node in self.nodes:
self.nodes.remove(node) self.nodes.remove(node)
# Collect the node's input nodes # Collect the node's input nodes
input_nodes: Dict[Node, None] = {} input_nodes: dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault) map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault) map_arg(node.kwargs, input_nodes.setdefault)
# Check if an input node is a placeholder or get_attr, # Check if an input node is a placeholder or get_attr,
@ -88,23 +88,23 @@ class PartitionMode(Enum):
class PartitionerConfig(NamedTuple): class PartitionerConfig(NamedTuple):
devices: List[Device] devices: list[Device]
mode: PartitionMode = PartitionMode.size_based mode: PartitionMode = PartitionMode.size_based
transfer_rate_bytes_per_sec: float = 0.0 transfer_rate_bytes_per_sec: float = 0.0
node_to_latency_mapping: Dict[Node, NodeLatency] = {} node_to_latency_mapping: dict[Node, NodeLatency] = {}
node_to_partition_mapping: Dict[Node, int] = {} node_to_partition_mapping: dict[Node, int] = {}
partition_to_logical_device_mapping: Dict[int, List[int]] = {} partition_to_logical_device_mapping: dict[int, list[int]] = {}
# Saturate host by replicating partitions to the remaining idle devices. # Saturate host by replicating partitions to the remaining idle devices.
saturate_host: bool = False saturate_host: bool = False
def get_extra_size_of(node: Node, nodes: Set[Node]) -> int: def get_extra_size_of(node: Node, nodes: set[Node]) -> int:
"""Given a node and a set of nodes, """Given a node and a set of nodes,
this function returns the extra size needed this function returns the extra size needed
if this node is included in the set. if this node is included in the set.
""" """
# Find all its input nodes # Find all its input nodes
input_nodes: Dict[Node, None] = {} input_nodes: dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault) map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault) map_arg(node.kwargs, input_nodes.setdefault)
# Calculate total size of related nodes # Calculate total size of related nodes
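The dict[Node, None] idiom used throughout this file is an insertion-ordered set: map_arg calls setdefault on every argument, deduplicating while preserving order. A toy illustration with strings standing in for nodes:

input_nodes: dict[str, None] = {}
for arg in ("a", "b", "a", "c"):
    input_nodes.setdefault(arg)
print(list(input_nodes))  # ['a', 'b', 'c'] -- duplicates dropped, order kept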
@ -127,18 +127,18 @@ def get_extra_size_of(node: Node, nodes: Set[Node]) -> int:
def get_latency_of_one_partition( def get_latency_of_one_partition(
partition: Partition, node_to_latency_mapping: Dict[Node, NodeLatency] partition: Partition, node_to_latency_mapping: dict[Node, NodeLatency]
) -> PartitionLatency: ) -> PartitionLatency:
"""Given a partition and its nodes' latency, return a PartitionLatency for this partition""" """Given a partition and its nodes' latency, return a PartitionLatency for this partition"""
def get_top_nodes(partition: Partition) -> List[Node]: def get_top_nodes(partition: Partition) -> list[Node]:
"""Given a partition, return a list of nodes on the top bfs level""" """Given a partition, return a list of nodes on the top bfs level"""
top_nodes: List[Node] = [] top_nodes: list[Node] = []
for node in partition.nodes: for node in partition.nodes:
# Skip placeholder and get_attr nodes # Skip placeholder and get_attr nodes
if node.op in {"placeholder", "get_attr"}: if node.op in {"placeholder", "get_attr"}:
continue continue
input_nodes: Dict[Node, None] = {} input_nodes: dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault) map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault) map_arg(node.kwargs, input_nodes.setdefault)
# If a node has no input nodes in this partition, # If a node has no input nodes in this partition,
@ -216,12 +216,12 @@ def get_latency_of_one_partition(
def get_partition_to_latency_mapping( def get_partition_to_latency_mapping(
partitions: List[Partition], node_to_latency_mapping: Dict[Node, NodeLatency] partitions: list[Partition], node_to_latency_mapping: dict[Node, NodeLatency]
) -> Dict[Partition, PartitionLatency]: ) -> dict[Partition, PartitionLatency]:
"""Given all the partitions and node_to_latency_mapping dictionary, """Given all the partitions and node_to_latency_mapping dictionary,
return a mapping dictionary of each partition to its overall latency return a mapping dictionary of each partition to its overall latency
""" """
partition_to_latency_mapping: Dict[Partition, PartitionLatency] = {} partition_to_latency_mapping: dict[Partition, PartitionLatency] = {}
# Go through each partition and get its latency # Go through each partition and get its latency
for partition in partitions: for partition in partitions:
partition_latency = get_latency_of_one_partition( partition_latency = get_latency_of_one_partition(
@ -255,7 +255,7 @@ def get_comm_latency_between(
# the output size of those input nodes will be counted # the output size of those input nodes will be counted
# and added to comm_size # and added to comm_size
for node in child_partition.nodes: for node in child_partition.nodes:
input_nodes: Dict[Node, None] = {} input_nodes: dict[Node, None] = {}
map_arg(node.args, input_nodes.setdefault) map_arg(node.args, input_nodes.setdefault)
map_arg(node.kwargs, input_nodes.setdefault) map_arg(node.kwargs, input_nodes.setdefault)
for n in input_nodes: for n in input_nodes:
@ -268,8 +268,8 @@ def get_comm_latency_between(
def get_latency_of_partitioned_graph( def get_latency_of_partitioned_graph(
partitions: List[Partition], partitions: list[Partition],
partition_to_latency_mapping: Dict[Partition, PartitionLatency], partition_to_latency_mapping: dict[Partition, PartitionLatency],
transfer_rate_bytes_per_sec: float, transfer_rate_bytes_per_sec: float,
): ):
"""Given all partitions in a graph, find the critical path among all partitions """Given all partitions in a graph, find the critical path among all partitions
@ -298,7 +298,7 @@ def get_latency_of_partitioned_graph(
return max_latency_sec return max_latency_sec
return latency_so_far_sec return latency_so_far_sec
def get_top_partitions(partitions: List[Partition]) -> List[Partition]: def get_top_partitions(partitions: list[Partition]) -> list[Partition]:
"""This function is to return all the partitions without parents """This function is to return all the partitions without parents
as the starting points of all the paths as the starting points of all the paths
""" """
View File
@ -17,21 +17,15 @@ import typing_extensions
import warnings import warnings
import weakref import weakref
from collections import defaultdict from collections import defaultdict
from collections.abc import Generator, Mapping, Sequence
from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext
from dataclasses import dataclass from dataclasses import dataclass
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Dict,
Generator,
List,
Mapping,
Optional, Optional,
overload, overload,
Protocol, Protocol,
Sequence,
Tuple,
Type,
TYPE_CHECKING, TYPE_CHECKING,
TypeVar, TypeVar,
Union, Union,
@ -168,7 +162,7 @@ from torch.types import py_sym_types, PySymType
class _HasMeta(Protocol): class _HasMeta(Protocol):
meta: Dict[str, PySymType] meta: dict[str, PySymType]
def is_sym_node(node: _HasMeta) -> bool: def is_sym_node(node: _HasMeta) -> bool:
@ -377,9 +371,9 @@ _ExtractValType = Optional[
PySymType, PySymType,
_AnyScriptObjectType, _AnyScriptObjectType,
BackwardState, BackwardState,
List["_ExtractValType"], list["_ExtractValType"],
Tuple["_ExtractValType", ...], tuple["_ExtractValType", ...],
Dict[str, "_ExtractValType"], dict[str, "_ExtractValType"],
Tensor, Tensor,
int, int,
float, float,
@ -767,10 +761,10 @@ def proxy_call(
proxy_mode: ProxyTorchDispatchMode, proxy_mode: ProxyTorchDispatchMode,
func: OpOverload, func: OpOverload,
pre_dispatch: bool, pre_dispatch: bool,
args: Tuple[object, ...], args: tuple[object, ...],
kwargs: Dict[str, object], kwargs: dict[str, object],
) -> object: ) -> object:
unrecognized_types: List[Type] = [] unrecognized_types: list[type] = []
flat_args_kwargs, spec = pytree.tree_flatten((args, kwargs)) flat_args_kwargs, spec = pytree.tree_flatten((args, kwargs))
def can_handle_tensor(x: Tensor) -> bool: def can_handle_tensor(x: Tensor) -> bool:
@ -987,7 +981,7 @@ class _SymNodeDict:
""" """
def __init__(self) -> None: def __init__(self) -> None:
self.sym_node_dict: Dict[PySymType, _PySymProxyType] = {} self.sym_node_dict: dict[PySymType, _PySymProxyType] = {}
def __setitem__(self, key: PySymType, value: _PySymProxyType) -> None: def __setitem__(self, key: PySymType, value: _PySymProxyType) -> None:
self.sym_node_dict[key.node] = value self.sym_node_dict[key.node] = value
@ -1015,9 +1009,9 @@ class _SymNodeDict:
class PythonKeyTracer(Tracer): class PythonKeyTracer(Tracer):
script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy] script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy]
symnode_tracker: _SymNodeDict symnode_tracker: _SymNodeDict
sympy_expr_tracker: Dict[sympy.Symbol, object] sympy_expr_tracker: dict[sympy.Symbol, object]
tensor_tracker: MutableMapping[Tensor, _ProxyTensor] tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
torch_fn_counts: Dict[OpOverload, int] torch_fn_counts: dict[OpOverload, int]
enable_thunkify: bool = False enable_thunkify: bool = False
def __init__(self) -> None: def __init__(self) -> None:
@ -1043,14 +1037,14 @@ class PythonKeyTracer(Tracer):
self, self,
m: Module, m: Module,
forward: Callable[..., Any], forward: Callable[..., Any],
args: Tuple[Any, ...], args: tuple[Any, ...],
kwargs: Dict[str, Any], kwargs: dict[str, Any],
) -> Any: ) -> Any:
return forward(*args, **kwargs) return forward(*args, **kwargs)
# We don't want to turn getattr calls into proxies. So we just return the actual value. # We don't want to turn getattr calls into proxies. So we just return the actual value.
def getattr( def getattr(
self, attr: str, attr_val: object, parameter_proxy_cache: Dict[str, Proxy] self, attr: str, attr_val: object, parameter_proxy_cache: dict[str, Proxy]
) -> object: ) -> object:
return attr_val return attr_val
@ -1095,7 +1089,7 @@ class PythonKeyTracer(Tracer):
def _make_temp_remove_mode_context_manager( def _make_temp_remove_mode_context_manager(
mode_ty: Type[TorchFunctionMode], mode_ty: type[TorchFunctionMode],
) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]: ) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]:
@contextmanager @contextmanager
def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]: def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]:
@ -1137,7 +1131,7 @@ def _make_temp_remove_mode_context_manager(
def dispatch_trace( def dispatch_trace(
root: Union[Module, Callable], root: Union[Module, Callable],
tracer: Tracer, tracer: Tracer,
concrete_args: Optional[Tuple[Any, ...]] = None, concrete_args: Optional[tuple[Any, ...]] = None,
) -> GraphModule: ) -> GraphModule:
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type] graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
@ -1235,9 +1229,9 @@ class TorchFunctionMetadataMode(TorchFunctionMode):
def __torch_function__( def __torch_function__(
self, self,
func: OpOverload, func: OpOverload,
types: Tuple[torch._C._TensorMeta, ...], types: tuple[torch._C._TensorMeta, ...],
args: Tuple[object, ...] = (), args: tuple[object, ...] = (),
kwargs: Optional[Dict[str, object]] = None, kwargs: Optional[dict[str, object]] = None,
) -> object: ) -> object:
kwargs = kwargs or {} kwargs = kwargs or {}
self.tracer.torch_fn_metadata = func self.tracer.torch_fn_metadata = func
@ -1259,14 +1253,14 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
# The input to torch.amp.autocast_mode._exit_autocast graph node should be the # The input to torch.amp.autocast_mode._exit_autocast graph node should be the
# enter_autocast node. So we have to save the enter autocast node here, and assign it # enter_autocast node. So we have to save the enter autocast node here, and assign it
# to the exit_autocast call_function node. # to the exit_autocast call_function node.
self.enter_autocast_nodes: List[torch.fx.Node] = [] self.enter_autocast_nodes: list[torch.fx.Node] = []
def __torch_function__( def __torch_function__(
self, self,
func: Union[OpOverload, Callable], func: Union[OpOverload, Callable],
types: Tuple[torch._C._TensorMeta, ...], types: tuple[torch._C._TensorMeta, ...],
args: Tuple[object, ...] = (), args: tuple[object, ...] = (),
kwargs: Optional[Dict[str, object]] = None, kwargs: Optional[dict[str, object]] = None,
) -> object: ) -> object:
kwargs = kwargs or {} kwargs = kwargs or {}
if func in _side_effectful_need_to_be_preserved_pre_dispatch: if func in _side_effectful_need_to_be_preserved_pre_dispatch:
@ -1324,7 +1318,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
# Every time we enter a mode, we maintain a stack telling us what the previous # Every time we enter a mode, we maintain a stack telling us what the previous
# ProxyTorchDispatchMode state was (if there was any). # ProxyTorchDispatchMode state was (if there was any).
# This lets us properly reset the state on exit. # This lets us properly reset the state on exit.
self.enter_stack: List[Optional[ProxyTorchDispatchMode]] = [] self.enter_stack: list[Optional[ProxyTorchDispatchMode]] = []
self.decomp_layers = 0 self.decomp_layers = 0
from torch._inductor import config from torch._inductor import config
@ -1334,9 +1328,9 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
def __torch_dispatch__( def __torch_dispatch__(
self, self,
func: OpOverload, func: OpOverload,
types: Tuple[torch._C._TensorMeta, ...], types: tuple[torch._C._TensorMeta, ...],
args: Tuple[object, ...] = (), args: tuple[object, ...] = (),
kwargs: Optional[Dict[str, object]] = None, kwargs: Optional[dict[str, object]] = None,
) -> object: ) -> object:
with set_original_aten_op(func): with set_original_aten_op(func):
kwargs = kwargs or {} kwargs = kwargs or {}
@ -1354,7 +1348,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
def __exit__( def __exit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException], exc_value: Optional[BaseException],
traceback: Optional[types.TracebackType], traceback: Optional[types.TracebackType],
) -> Optional[bool]: ) -> Optional[bool]:
@ -1372,10 +1366,10 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
return True return True
def _compute_proxy( def _compute_proxy(
self, func: OpOverload, args: Tuple[object, ...], out: PySymType self, func: OpOverload, args: tuple[object, ...], out: PySymType
) -> Proxy: ) -> Proxy:
# Handle torch.sym_sum # Handle torch.sym_sum
n_args: Tuple[object, ...] n_args: tuple[object, ...]
if len(args) == 1 and isinstance(args[0], (list, tuple)): if len(args) == 1 and isinstance(args[0], (list, tuple)):
n_args = ( n_args = (
tuple( tuple(
@ -1403,9 +1397,9 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
def __sym_dispatch__( def __sym_dispatch__(
self, self,
func: OpOverload, func: OpOverload,
types: Tuple[torch._C._TensorMeta, ...], types: tuple[torch._C._TensorMeta, ...],
args: Tuple[object, ...], args: tuple[object, ...],
kwargs: Dict[str, object], kwargs: dict[str, object],
) -> object: ) -> object:
# Peephole optimize multiply by one # Peephole optimize multiply by one
# NB: be careful not to trigger guards here! # NB: be careful not to trigger guards here!
@ -1438,9 +1432,9 @@ class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer):
script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy] script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy]
symnode_tracker: MutableMapping[PySymType, _PySymProxyType] symnode_tracker: MutableMapping[PySymType, _PySymProxyType]
tensor_tracker: MutableMapping[Tensor, _ProxyTensor] tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
sympy_expr_tracker: Dict[sympy.Symbol, object] sympy_expr_tracker: dict[sympy.Symbol, object]
torch_fn_metadata: Optional[OpOverload] torch_fn_metadata: Optional[OpOverload]
torch_fn_counts: Dict[OpOverload, int] torch_fn_counts: dict[OpOverload, int]
enable_thunkify: bool = False enable_thunkify: bool = False
def __init__(self, graph: fx.graph.Graph) -> None: def __init__(self, graph: fx.graph.Graph) -> None:
@ -1476,7 +1470,7 @@ class DecompositionInterpreter(fx.Interpreter):
self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real") self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")
def placeholder( def placeholder(
self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object] # type: ignore[override] self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override]
) -> object: ) -> object:
out = super().placeholder(target, args, kwargs) # type: ignore[arg-type] out = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer) proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer)
@ -1485,7 +1479,7 @@ class DecompositionInterpreter(fx.Interpreter):
return out return out
def get_attr( def get_attr(
self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object] # type: ignore[override] self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override]
) -> object: ) -> object:
out = super().get_attr(target, args, kwargs) # type: ignore[arg-type] out = super().get_attr(target, args, kwargs) # type: ignore[arg-type]
proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer) proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer)
@ -1495,7 +1489,7 @@ class DecompositionInterpreter(fx.Interpreter):
# call_function, call_method, call_module get traced automatically by the outer mode. # call_function, call_method, call_module get traced automatically by the outer mode.
def output( def output(
self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object] # type: ignore[override] self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override]
) -> object: ) -> object:
out = super().output(target, args, kwargs) # type: ignore[arg-type] out = super().output(target, args, kwargs) # type: ignore[arg-type]
@ -1516,13 +1510,13 @@ class DecompositionInterpreter(fx.Interpreter):
def wrapper_and_args_for_make_fx( def wrapper_and_args_for_make_fx(
func: Callable[..., R], args: Tuple[object, ...], kwargs: Dict[str, object] func: Callable[..., R], args: tuple[object, ...], kwargs: dict[str, object]
) -> Tuple[Callable[[List[object]], R], List[object]]: ) -> tuple[Callable[[list[object]], R], list[object]]:
# make_fx doesn't support kwargs, so we need to do this flattening # make_fx doesn't support kwargs, so we need to do this flattening
# and then unflatten the args before calling func # and then unflatten the args before calling func
flat_args, spec = pytree.tree_flatten((args, kwargs)) flat_args, spec = pytree.tree_flatten((args, kwargs))
def wrapped(flat_args: List[object]) -> R: def wrapped(flat_args: list[object]) -> R:
fn_args, fn_kwargs = pytree.tree_unflatten(flat_args, spec) fn_args, fn_kwargs = pytree.tree_unflatten(flat_args, spec)
return func(*fn_args, **fn_kwargs) return func(*fn_args, **fn_kwargs)
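A hedged illustration of the flatten/unflatten trick outside make_fx, using a toy function with a keyword-only argument:

import torch
import torch.utils._pytree as pytree

def f(x, *, scale):
    return x * scale

args, kwargs = (torch.ones(2),), {"scale": 3.0}
flat_args, spec = pytree.tree_flatten((args, kwargs))

def wrapped(flat):
    # Rebuild (args, kwargs) from the flat list and the saved spec.
    fn_args, fn_kwargs = pytree.tree_unflatten(flat, spec)
    return f(*fn_args, **fn_kwargs)

print(wrapped(flat_args))  # tensor([3., 3.])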
@ -1642,7 +1636,7 @@ class _ModuleStackTracer(PythonKeyTracer):
return tracer.proxy_modules[self] return tracer.proxy_modules[self]
@property @property
def _modules(self) -> Dict[str, AttrProxy]: def _modules(self) -> dict[str, AttrProxy]:
assert "_modules" in self.__dict__ assert "_modules" in self.__dict__
submodules = self.__dict__["_modules"] submodules = self.__dict__["_modules"]
assert isinstance(submodules, dict) assert isinstance(submodules, dict)
@ -1674,7 +1668,7 @@ class _ModuleStackTracer(PythonKeyTracer):
raise _ModuleNotInstalledAsSubmoduleError from e raise _ModuleNotInstalledAsSubmoduleError from e
def getattr( def getattr(
self, attr: str, attr_val: object, parameter_proxy_cache: Dict[str, Proxy] self, attr: str, attr_val: object, parameter_proxy_cache: dict[str, Proxy]
) -> object: ) -> object:
if ( if (
not isinstance(attr_val, Module) not isinstance(attr_val, Module)
@ -1693,7 +1687,7 @@ class _ModuleStackTracer(PythonKeyTracer):
return self.attr_proxy_map[attr_val] return self.attr_proxy_map[attr_val]
def trace( # type: ignore[override] def trace( # type: ignore[override]
self, root: Union[Module, Callable], concrete_args: Optional[Dict[str, object]] self, root: Union[Module, Callable], concrete_args: Optional[dict[str, object]]
) -> fx.Graph: ) -> fx.Graph:
res = super().trace(root, concrete_args) res = super().trace(root, concrete_args)
@ -1702,7 +1696,7 @@ class _ModuleStackTracer(PythonKeyTracer):
# to the tracer while tracing, the proxy object gets registered # to the tracer while tracing, the proxy object gets registered
# first. So we need to replace the proxy modules with the real ones # first. So we need to replace the proxy modules with the real ones
# This can happen during HOO tracing # This can happen during HOO tracing
proxy_module_names_to_be_replaced: List[Tuple[str, _AttrProxy]] = [] proxy_module_names_to_be_replaced: list[tuple[str, _AttrProxy]] = []
for name, module in self.root.named_modules(): for name, module in self.root.named_modules():
if module in self.proxy_modules: if module in self.proxy_modules:
proxy_module_names_to_be_replaced.append((name, module)) proxy_module_names_to_be_replaced.append((name, module))
@ -1746,8 +1740,8 @@ class _ModuleStackTracer(PythonKeyTracer):
self, self,
m: Module, m: Module,
forward: Callable, forward: Callable,
args: Tuple[object, ...], args: tuple[object, ...],
kwargs: Dict[str, object], kwargs: dict[str, object],
) -> None: ) -> None:
"""PythonKeyTracer overrides call_module to avoid the scope handling, """PythonKeyTracer overrides call_module to avoid the scope handling,
but we actually want it. but we actually want it.
@ -1857,7 +1851,7 @@ class _MakefxTracer:
) -> None: ) -> None:
# Configurations that are used to initialize the context managers and their states. # Configurations that are used to initialize the context managers and their states.
# Should not modify them during tracing. # Should not modify them during tracing.
self.decomposition_table: Dict[OpOverload, Callable] = dict( self.decomposition_table: dict[OpOverload, Callable] = dict(
decomposition_table or {} decomposition_table or {}
) )
self.decomposition_table.setdefault( self.decomposition_table.setdefault(
@ -1885,7 +1879,7 @@ class _MakefxTracer:
nullcontext, TorchFunctionMetadataMode nullcontext, TorchFunctionMetadataMode
] = nullcontext() ] = nullcontext()
def _checkpoint_modes(self) -> List[Any]: def _checkpoint_modes(self) -> list[Any]:
return [ return [
self.fake_tensor_mode, self.fake_tensor_mode,
self.proxy_mode, self.proxy_mode,
@ -1913,7 +1907,7 @@ class _MakefxTracer:
@contextmanager @contextmanager
def _init_modes_from_inputs( def _init_modes_from_inputs(
self, f: Callable, args: Tuple[object, ...] self, f: Callable, args: tuple[object, ...]
) -> Generator[None, None, None]: ) -> Generator[None, None, None]:
prev_modes = self._checkpoint_modes() prev_modes = self._checkpoint_modes()
try: try:
@ -2202,7 +2196,7 @@ def make_fx(
return wrapped return wrapped
def get_torch_dispatch_modes() -> List[TorchDispatchMode]: def get_torch_dispatch_modes() -> list[TorchDispatchMode]:
return torch.utils._python_dispatch._get_current_dispatch_mode_stack() return torch.utils._python_dispatch._get_current_dispatch_mode_stack()
@ -2240,7 +2234,7 @@ def handle_sym_dispatch(func: Callable[_P, R], args: _P.args, kwargs: _P.kwargs)
# dispatch machinery which disables it for us # dispatch machinery which disables it for us
with disable_proxy_modes_tracing(): with disable_proxy_modes_tracing():
# TODO: properly compute types # TODO: properly compute types
types: List[Type] = [] types: list[type] = []
return mode.__sym_dispatch__(func, types, args, kwargs) # type: ignore[arg-type, return-value] return mode.__sym_dispatch__(func, types, args, kwargs) # type: ignore[arg-type, return-value]
@ -2252,8 +2246,8 @@ def disable_proxy_modes_tracing() -> Generator[ProxyTorchDispatchMode, None, Non
def maybe_handle_decomp( def maybe_handle_decomp(
proxy_mode: ProxyTorchDispatchMode, proxy_mode: ProxyTorchDispatchMode,
op: OpOverload, op: OpOverload,
args: Tuple[object, ...], args: tuple[object, ...],
kwargs: Dict[str, object], kwargs: dict[str, object],
) -> object: ) -> object:
from torch._inductor.compiler_bisector import CompilerBisector from torch._inductor.compiler_bisector import CompilerBisector
@ -2274,8 +2268,8 @@ def maybe_handle_decomp(
def get_isolated_graphmodule( def get_isolated_graphmodule(
func: Callable, func: Callable,
args: Tuple[object, ...], args: tuple[object, ...],
kwargs: Dict[str, object], kwargs: dict[str, object],
tracing_mode: str = "real", tracing_mode: str = "real",
decomposition_table: Optional[Mapping[OpOverload, Callable]] = None, decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
) -> GraphModule: ) -> GraphModule:
View File
@ -4,7 +4,7 @@ import inspect
import itertools import itertools
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Optional, Union
import torch import torch
import torch.utils._pytree as pytree import torch.utils._pytree as pytree
@ -83,11 +83,11 @@ class ShapeEnvEvent:
f: Callable f: Callable
# Arguments and keyword arguments the call was made with. # Arguments and keyword arguments the call was made with.
args: Optional[List[Any]] = None args: Optional[list[Any]] = None
kwargs: Optional[Dict[str, Any]] = None kwargs: Optional[dict[str, Any]] = None
# List of tracked_fakes at the time the method was called. # List of tracked_fakes at the time the method was called.
tracked_fakes: Optional[List[Any]] = None tracked_fakes: Optional[list[Any]] = None
# Name of the captured event. # Name of the captured event.
# Used for special handling of particular methods. # Used for special handling of particular methods.
@ -344,15 +344,15 @@ def replay_shape_env_events(events):
# ShapeEnv.produce_guards. # ShapeEnv.produce_guards.
@dataclass @dataclass
class FakeTensorMeta: class FakeTensorMeta:
tensor_size: Tuple[Union[int, torch.SymInt], ...] tensor_size: tuple[Union[int, torch.SymInt], ...]
tensor_stride: Tuple[Union[int, torch.SymInt], ...] tensor_stride: tuple[Union[int, torch.SymInt], ...]
tensor_storage_offset: Union[int, torch.SymInt] tensor_storage_offset: Union[int, torch.SymInt]
is_nested: bool is_nested: bool
def size(self) -> Tuple[Union[int, torch.SymInt], ...]: def size(self) -> tuple[Union[int, torch.SymInt], ...]:
return self.tensor_size return self.tensor_size
def stride(self) -> Tuple[Union[int, torch.SymInt], ...]: def stride(self) -> tuple[Union[int, torch.SymInt], ...]:
return self.tensor_stride return self.tensor_stride
def storage_offset(self) -> Union[int, torch.SymInt]: def storage_offset(self) -> Union[int, torch.SymInt]:
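A hedged illustration of the duck-typed metadata this dataclass mirrors, captured from a real tensor (direct construction shown for clarity; assumes the dataclass above is in scope):

import torch

t = torch.empty(2, 3)
meta = FakeTensorMeta(
    tensor_size=tuple(t.size()),
    tensor_stride=tuple(t.stride()),
    tensor_storage_offset=t.storage_offset(),
    is_nested=t.is_nested,
)
print(meta.size(), meta.stride(), meta.storage_offset())  # (2, 3) (3, 1) 0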
@ -445,7 +445,7 @@ def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value)
# compare the two values. # compare the two values.
def compare_vars( def compare_vars(
map_value: Callable[[str, Any], Any] map_value: Callable[[str, Any], Any]
) -> List[Tuple[str, str, str]]: ) -> list[tuple[str, str, str]]:
env1_set, env2_set = set(env1_vars), set(env2_vars) env1_set, env2_set = set(env1_vars), set(env2_vars)
# First, compare the set of keys in each vars dictionary. # First, compare the set of keys in each vars dictionary.
@ -489,7 +489,7 @@ class NotEqualError(Exception):
def __init__( def __init__(
self, self,
msg: str, msg: str,
mismatched: List[Tuple[str, str, str]], mismatched: list[tuple[str, str, str]],
) -> None: ) -> None:
details = "\n".join( details = "\n".join(
[ [
View File
@ -6,7 +6,7 @@ import functools
import inspect import inspect
import textwrap import textwrap
from types import FunctionType from types import FunctionType
from typing import Any, Callable, cast, Dict, Optional, Union from typing import Any, Callable, cast, Optional, Union
import torch import torch
from torch._sources import normalize_source_lines from torch._sources import normalize_source_lines
@ -112,7 +112,7 @@ class RewritingTracer(Tracer):
def trace( def trace(
self, self,
root: Union[torch.nn.Module, Callable], root: Union[torch.nn.Module, Callable],
concrete_args: Optional[Dict[str, Any]] = None, concrete_args: Optional[dict[str, Any]] = None,
) -> Graph: ) -> Graph:
return super().trace(_rewrite(root), concrete_args) return super().trace(_rewrite(root), concrete_args)
View File
@ -1,6 +1,6 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
import inspect import inspect
from typing import Any, Dict, Optional, Tuple from typing import Any, Optional
import torch import torch
import torch.fx import torch.fx
@ -42,7 +42,7 @@ class AnnotateTypesWithSchema(Transformer):
self.annotate_get_attrs = annotate_get_attrs self.annotate_get_attrs = annotate_get_attrs
def call_function( def call_function(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
): ):
python_ret_type = None python_ret_type = None
if self.annotate_functionals and target.__module__ == "torch.nn.functional": if self.annotate_functionals and target.__module__ == "torch.nn.functional":
@ -73,7 +73,7 @@ class AnnotateTypesWithSchema(Transformer):
return return_proxy return return_proxy
def call_module( def call_module(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
): ):
python_ret_type = None python_ret_type = None
assert isinstance(target, str) assert isinstance(target, str)
@ -91,8 +91,8 @@ class AnnotateTypesWithSchema(Transformer):
def get_attr( def get_attr(
self, self,
target: torch.fx.node.Target, target: torch.fx.node.Target,
args: Tuple[Argument, ...], args: tuple[Argument, ...],
kwargs: Dict[str, Any], kwargs: dict[str, Any],
): ):
attr_proxy = super().get_attr(target, args, kwargs) attr_proxy = super().get_attr(target, args, kwargs)
View File
@ -1,5 +1,6 @@
import re import re
from typing import Any, DefaultDict, Dict, List, Tuple, Union from collections import defaultdict
from typing import Any, Union
import numpy as np import numpy as np
import sympy as sp import sympy as sp
@ -13,10 +14,10 @@ s_pattern = r"s\d+"
def infer_symbol_values( def infer_symbol_values(
symints: List[Union[torch.SymInt, int]], symints: list[Union[torch.SymInt, int]],
init_symints: List[Union[torch.SymInt, int]], init_symints: list[Union[torch.SymInt, int]],
symbol_idx_dict: Dict[str, int], symbol_idx_dict: dict[str, int],
padding_constraints: DefaultDict[torch.SymInt, List[Union[sp.Expr, int]]], padding_constraints: defaultdict[torch.SymInt, list[Union[sp.Expr, int]]],
constraint: str, constraint: str,
) -> None: ) -> None:
if constraint.find("non-singleton") != -1: if constraint.find("non-singleton") != -1:
@ -83,8 +84,8 @@ def infer_symbol_values(
def calculate_value( def calculate_value(
left_expression: Union[str, Any, None], left_expression: Union[str, Any, None],
right_expression: Union[str, Any, None], right_expression: Union[str, Any, None],
symints: List[Union[torch.SymInt, int]], symints: list[Union[torch.SymInt, int]],
symbol_idx_dict: Dict[str, int], symbol_idx_dict: dict[str, int],
) -> None: ) -> None:
var, val = solve_equation(left_expression, right_expression) var, val = solve_equation(left_expression, right_expression)
idx = symbol_idx_dict[var] idx = symbol_idx_dict[var]
@ -95,7 +96,7 @@ def calculate_value(
def solve_equation( def solve_equation(
left_expression: Union[str, Any, None], left_expression: Union[str, Any, None],
right_expression: Union[str, Any, None], right_expression: Union[str, Any, None],
) -> Tuple[str, int]: ) -> tuple[str, int]:
expression = f"{left_expression} - {right_expression}" expression = f"{left_expression} - {right_expression}"
var = re.findall(s_pattern, expression)[0] var = re.findall(s_pattern, expression)[0]
if re.findall(parentheses_pattern, expression): if re.findall(parentheses_pattern, expression):
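A hedged standalone sketch of the "left - right = 0" solve step, letting sympy do the algebra instead of the regex pipeline above:

import sympy as sp

def solve_for_symbol(left: str, right: str) -> tuple[str, int]:
    expr = sp.sympify(f"({left}) - ({right})")
    (var,) = expr.free_symbols   # assumes exactly one free symbol, e.g. s0
    (val,) = sp.solve(expr, var)
    return str(var), int(val)

print(solve_for_symbol("2*s0 + 4", "10"))  # ('s0', 3)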
@ -116,9 +117,9 @@ def solve_equation(
def update_equation( def update_equation(
symints: List[Union[torch.SymInt, int]], symints: list[Union[torch.SymInt, int]],
init_symints: List[Union[torch.SymInt, int]], init_symints: list[Union[torch.SymInt, int]],
padding_constraints: DefaultDict[torch.SymInt, List[Union[sp.Expr, int]]], padding_constraints: defaultdict[torch.SymInt, list[Union[sp.Expr, int]]],
init_eq: sp.Expr, init_eq: sp.Expr,
new_mod_num: int, new_mod_num: int,
var: torch.SymInt, var: torch.SymInt,
View File
@ -20,7 +20,7 @@ import math
import operator import operator
import sys import sys
from functools import lru_cache, update_wrapper from functools import lru_cache, update_wrapper
from typing import Optional, Type, TYPE_CHECKING, Union from typing import Optional, TYPE_CHECKING, Union
import torch import torch
@ -1272,7 +1272,7 @@ def _make_node_magic(method, func):
log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr) log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr)
raise raise
sym_node_log.debug("%s %s %s -> %s", method, self.expr, other.expr, out) sym_node_log.debug("%s %s %s -> %s", method, self.expr, other.expr, out)
pytype: Type pytype: type
# This is not strictly correct. In Python, a**b may return complex when # This is not strictly correct. In Python, a**b may return complex when
# a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This
# returns a float while both arguments are ints: 2**(-1). Also, max and # returns a float while both arguments are ints: 2**(-1). Also, max and
@ -1335,7 +1335,7 @@ def _make_node_magic(method, func):
out_hint = None out_hint = None
if self.hint is not None: if self.hint is not None:
out_hint = op(self.hint) out_hint = op(self.hint)
pytype: Type pytype: type
if method in always_int_magic_methods: if method in always_int_magic_methods:
pytype = int pytype = int
elif method in always_bool_magic_methods: elif method in always_bool_magic_methods:
@ -1485,7 +1485,7 @@ def _make_node_sizes_strides(method, func):
out_hint = op(size_hints, stride_hints) out_hint = op(size_hints, stride_hints)
# NB: This is the indicator function, not the actual bool! # NB: This is the indicator function, not the actual bool!
pytype: Type pytype: type
if method.endswith("_indicator"): if method.endswith("_indicator"):
pytype = int pytype = int
else: else:
View File
@ -23,7 +23,8 @@ import re
import sys import sys
import threading import threading
import traceback import traceback
from collections import defaultdict from collections import Counter, defaultdict
from collections.abc import Iterator, Mapping, Sequence
from contextlib import _GeneratorContextManager, contextmanager from contextlib import _GeneratorContextManager, contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
@ -31,19 +32,9 @@ from typing import (
Any, Any,
Callable, Callable,
cast, cast,
Counter,
DefaultDict,
Dict,
Iterator,
List,
Mapping,
NamedTuple, NamedTuple,
NoReturn, NoReturn,
Optional, Optional,
Sequence,
Set,
Tuple,
Type,
TYPE_CHECKING, TYPE_CHECKING,
TypeVar, TypeVar,
Union, Union,
@ -104,8 +95,8 @@ if TYPE_CHECKING:
from torch.types import BoolLikeType from torch.types import BoolLikeType
InputList = List InputList = list
DimList = List DimList = list
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -236,8 +227,8 @@ class SymIntEqByExpr:
def _nested_int_aware_sort( def _nested_int_aware_sort(
tup: Tuple[Union[SymInt, int], int] tup: tuple[Union[SymInt, int], int]
) -> Tuple[int, Union[SymInt, int], int]: ) -> tuple[int, Union[SymInt, int], int]:
return ( return (
# Order nested ints by their coefficients. # Order nested ints by their coefficients.
# 1 here to order nested ints after non-nested-ints. # 1 here to order nested ints after non-nested-ints.
@ -289,7 +280,7 @@ def lru_cache(
# These are modules that contain generic code for interacting with ShapeEnv # These are modules that contain generic code for interacting with ShapeEnv
# which are unlikely to identify a particular interesting guard statement # which are unlikely to identify a particular interesting guard statement
@lru_cache(None) @lru_cache(None)
def uninteresting_files() -> Set[str]: def uninteresting_files() -> set[str]:
import torch._compile import torch._compile
import torch._dynamo.eval_frame import torch._dynamo.eval_frame
import torch._inductor.sizevars import torch._inductor.sizevars
@ -332,8 +323,8 @@ def has_symbolic_sizes_strides(elem: torch.Tensor) -> bool:
Int: TypeAlias = Union[torch.SymInt, int] Int: TypeAlias = Union[torch.SymInt, int]
def create_contiguous(shape: Sequence[Int]) -> List[Int]: def create_contiguous(shape: Sequence[Int]) -> list[Int]:
strides: List[Int] = [1] strides: list[Int] = [1]
for dim in reversed(shape[:-1]): for dim in reversed(shape[:-1]):
strides.append(dim * strides[-1]) # type: ignore[operator] strides.append(dim * strides[-1]) # type: ignore[operator]
return list(reversed(strides)) return list(reversed(strides))
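For reference, a worked example of contiguous strides, the quantity this helper produces: the innermost dimension has stride 1 and each outer stride is the product of all sizes to its right (a minimal sketch, spelled here by iterating the sizes from the inside out):

import torch

shape = (2, 3, 4)
strides = [1]                     # innermost dimension always has stride 1
for size in reversed(shape[1:]):  # multiply in the sizes to the right
    strides.append(size * strides[-1])
strides.reverse()
print(strides)                        # [12, 4, 1]
print(torch.empty(shape).stride())    # (12, 4, 1) -- matches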
@ -461,15 +452,15 @@ def check_consistent(new: _T, old: _T) -> None:
def resolve_unbacked_bindings( def resolve_unbacked_bindings(
shape_env: Optional[ShapeEnv], shape_env: Optional[ShapeEnv],
bindings: Optional[Dict[sympy.Symbol, pytree.KeyPath]], bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]],
) -> Optional[Dict[sympy.Symbol, pytree.KeyPath]]: ) -> Optional[dict[sympy.Symbol, pytree.KeyPath]]:
if bindings is None: if bindings is None:
return None return None
assert shape_env is not None assert shape_env is not None
return {shape_env.unbacked_renamings.get(k, k): v for k, v in bindings.items()} return {shape_env.unbacked_renamings.get(k, k): v for k, v in bindings.items()}
Result: TypeAlias = Union[torch.Tensor, Tuple[torch.Tensor, ...]] Result: TypeAlias = Union[torch.Tensor, tuple[torch.Tensor, ...]]
def rebind_unbacked( def rebind_unbacked(
@ -557,7 +548,7 @@ def rebind_unbacked(
and len(raw_u1.args) == 2 and len(raw_u1.args) == 2
and ( and (
raw_u1_args0 := cast( raw_u1_args0 := cast(
Tuple[sympy.Basic, sympy.Basic], raw_u1.args[0] tuple[sympy.Basic, sympy.Basic], raw_u1.args[0]
) )
) )
and raw_u1_args0[0] == 1 and raw_u1_args0[0] == 1
@ -565,7 +556,7 @@ def rebind_unbacked(
and isinstance(new_raw_u1 := eq.lhs, sympy.Symbol) and isinstance(new_raw_u1 := eq.lhs, sympy.Symbol)
and shape_env.var_to_range[new_raw_u1].issubset(ValueRanges(0, 1)) and shape_env.var_to_range[new_raw_u1].issubset(ValueRanges(0, 1))
and eq.rhs == 1 and eq.rhs == 1
and cast(Tuple[sympy.Basic, sympy.Basic], raw_u1.args[1]) == (0, True) and cast(tuple[sympy.Basic, sympy.Basic], raw_u1.args[1]) == (0, True)
): ):
# This is what the pattern match above is testing # This is what the pattern match above is testing
repacked = _sympy_cast_symbool_to_symint_guardless( repacked = _sympy_cast_symbool_to_symint_guardless(
@ -645,8 +636,8 @@ def canonicalize_bool_expr(expr: _T) -> _T:
def _sympy_from_args( def _sympy_from_args(
cls: Union[Type[sympy.Add], Type[sympy.Mul]], cls: type[Union[sympy.Add, sympy.Mul]],
args: List[sympy.Expr], args: list[sympy.Expr],
sort: bool = True, sort: bool = True,
is_commutative: Optional[bool] = None, is_commutative: Optional[bool] = None,
) -> sympy.Expr: ) -> sympy.Expr:
@ -686,7 +677,7 @@ def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean:
return type(expr)(*map(canonicalize_bool_expr, expr.args)) return type(expr)(*map(canonicalize_bool_expr, expr.args))
opposite = {sympy.Gt: sympy.Lt, sympy.Ge: sympy.Le} opposite = {sympy.Gt: sympy.Lt, sympy.Ge: sympy.Le}
t: Union[Type[Any]] t: Union[type[Any]]
if isinstance(expr, tuple(opposite.keys())): if isinstance(expr, tuple(opposite.keys())):
rhs = expr.lhs - expr.rhs # type: ignore[attr-defined] rhs = expr.lhs - expr.rhs # type: ignore[attr-defined]
t = opposite[type(expr)] # type: ignore[index] t = opposite[type(expr)] # type: ignore[index]
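A hedged sympy sketch of the flip performed here (assuming the elided return builds t(0, rhs)): a > b is normalized to the Lt form 0 < a - b, so equivalent comparisons canonicalize to one shape.

import sympy

a, b = sympy.symbols("a b")
expr = sympy.Gt(a, b)                     # a > b
canon = sympy.Lt(0, expr.lhs - expr.rhs)  # 0 < a - b
print(canon)
print(canon.subs({a: 3, b: 1}))           # True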
@ -888,7 +879,7 @@ def is_symbol_binding_fx_node(node: torch.fx.Node) -> Optional[sympy.Symbol]:
def find_symbol_binding_fx_nodes( def find_symbol_binding_fx_nodes(
graph: torch.fx.Graph, graph: torch.fx.Graph,
) -> Dict[sympy.Symbol, torch.fx.Node]: ) -> dict[sympy.Symbol, torch.fx.Node]:
r = {} r = {}
# NB: Prefer first occurrence of symbol # NB: Prefer first occurrence of symbol
for node in graph.nodes: for node in graph.nodes:
@ -949,7 +940,7 @@ def compute_unbacked_bindings(
example_value: object, example_value: object,
old_example_value: Optional[object] = None, old_example_value: Optional[object] = None,
peek: bool = False, peek: bool = False,
) -> Optional[Dict[sympy.Symbol, pytree.KeyPath]]: ) -> Optional[dict[sympy.Symbol, pytree.KeyPath]]:
""" """
After having run fake tensor propagation and producing example_value After having run fake tensor propagation and producing example_value
result, traverse example_value looking for freshly bound unbacked result, traverse example_value looking for freshly bound unbacked
@ -977,7 +968,7 @@ def compute_unbacked_bindings(
def free_unbacked_symbols_with_path( def free_unbacked_symbols_with_path(
a: object, path: pytree.KeyPath, real: Optional[object] = None a: object, path: pytree.KeyPath, real: Optional[object] = None
) -> Dict[sympy.Symbol, pytree.KeyPath]: ) -> dict[sympy.Symbol, pytree.KeyPath]:
assert shape_env is not None assert shape_env is not None
r = {} r = {}
if isinstance(a, (tuple, list)): if isinstance(a, (tuple, list)):
@ -1456,11 +1447,11 @@ def guard_float(a: Union[SymFloat, float]) -> float:
# Given a GraphModule, return all the FakeTensors for all the placeholders # Given a GraphModule, return all the FakeTensors for all the placeholders
def fx_placeholder_vals(gm: torch.fx.GraphModule) -> List[object]: def fx_placeholder_vals(gm: torch.fx.GraphModule) -> list[object]:
return [n.meta["val"] for n in gm.graph.nodes if n.op == "placeholder"] return [n.meta["val"] for n in gm.graph.nodes if n.op == "placeholder"]
def fx_placeholder_targets(gm: torch.fx.GraphModule) -> List[str]: def fx_placeholder_targets(gm: torch.fx.GraphModule) -> list[str]:
return [n.target for n in gm.graph.nodes if n.op == "placeholder"] return [n.target for n in gm.graph.nodes if n.op == "placeholder"]
@ -1475,7 +1466,7 @@ def eval_guards(
) )
def bind_symbols(gm: torch.fx.GraphModule, *args: Tensor) -> Dict[sympy.Symbol, int]: def bind_symbols(gm: torch.fx.GraphModule, *args: Tensor) -> dict[sympy.Symbol, int]:
return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args) # type: ignore[operator, union-attr] return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args) # type: ignore[operator, union-attr]
@ -1617,15 +1608,15 @@ class EqualityConstraint(Constraint):
form and so the problem reduces to symbolic expression equality.) form and so the problem reduces to symbolic expression equality.)
""" """
source_pairs: List[Tuple[Source, Source]] source_pairs: list[tuple[Source, Source]]
derived_equalities: List[ derived_equalities: list[
Tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]] tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]]
] ]
phantom_symbols: List[sympy.Symbol] phantom_symbols: list[sympy.Symbol]
relaxed_sources: Set[Source] relaxed_sources: set[Source]
_parents: Dict[Source, Source] = field(init=False) _parents: dict[Source, Source] = field(init=False)
_defs: Dict[Source, sympy.Expr] = field(init=False) _defs: dict[Source, sympy.Expr] = field(init=False)
def __post_init__(self) -> None: def __post_init__(self) -> None:
""" """
@ -1643,12 +1634,12 @@ class EqualityConstraint(Constraint):
# self._parents is a map from input sources to input sources where, conceptually, # self._parents is a map from input sources to input sources where, conceptually,
# these are directed edges in a union-find forest # these are directed edges in a union-find forest
_parents: Dict[Source, Source] = {} _parents: dict[Source, Source] = {}
object.__setattr__(self, "_parents", _parents) object.__setattr__(self, "_parents", _parents)
# self._defs is a map from input sources to "canonical" symbolic expressions, # self._defs is a map from input sources to "canonical" symbolic expressions,
# i.e., unary expressions with symbols that correspond to regular Dims (i.e., # i.e., unary expressions with symbols that correspond to regular Dims (i.e.,
# not derived Dims) # not derived Dims)
_defs: Dict[Source, sympy.Expr] = {} _defs: dict[Source, sympy.Expr] = {}
object.__setattr__(self, "_defs", _defs) object.__setattr__(self, "_defs", _defs)
for source1, source2 in self.source_pairs: for source1, source2 in self.source_pairs:
@ -1838,7 +1829,7 @@ class StatefulSymbolicContext(StatelessSymbolicContext):
# cause it to fail with unknown symbols, as the symbols cached here will skip creation, and never # cause it to fail with unknown symbols, as the symbols cached here will skip creation, and never
# get recorded in var_to_val, etc. # get recorded in var_to_val, etc.
# TODO(voz): consider a weakref to the shape_env here # TODO(voz): consider a weakref to the shape_env here
shape_env_to_source_to_symbol_cache: Dict[int, Dict[str, sympy.Expr]] = None # type: ignore[assignment] shape_env_to_source_to_symbol_cache: dict[int, dict[str, sympy.Expr]] = None # type: ignore[assignment]
def __post_init__(self) -> None: def __post_init__(self) -> None:
super().__post_init__() super().__post_init__()
@ -1856,7 +1847,7 @@ class SubclassSymbolicContext(StatefulSymbolicContext):
flexibility, with inner symbolic contexts mapped via attr -> symbolic context. flexibility, with inner symbolic contexts mapped via attr -> symbolic context.
""" """
inner_contexts: Dict[str, SymbolicContext] = None # type: ignore[assignment] inner_contexts: dict[str, SymbolicContext] = None # type: ignore[assignment]
def __post_init__(self) -> None: def __post_init__(self) -> None:
super().__post_init__() super().__post_init__()
@ -1875,7 +1866,7 @@ def is_symbolic(
IndicatorTypes = (IsNonOverlappingAndDenseIndicator,) IndicatorTypes = (IsNonOverlappingAndDenseIndicator,)
def _expandsums(args: List[sympy.Expr]) -> Tuple[sympy.Expr, bool]: def _expandsums(args: list[sympy.Expr]) -> tuple[sympy.Expr, bool]:
adds, other = [], [] adds, other = [], []
for arg in args: for arg in args:
if arg.is_Add: if arg.is_Add:
@ -1912,8 +1903,8 @@ def _fast_expand(expr: _SympyT) -> _SympyT:
elif exp < 0: elif exp < 0:
return S.One / sympy.expand_multinomial(S.One / expr, deep=False) return S.One / sympy.expand_multinomial(S.One / expr, deep=False)
elif expr.is_Mul: elif expr.is_Mul:
num: List[sympy.Expr] = [] num: list[sympy.Expr] = []
den: List[sympy.Expr] = [] den: list[sympy.Expr] = []
for arg in expr.args: for arg in expr.args:
if arg.is_Pow and arg.args[1] == -1: if arg.is_Pow and arg.args[1] == -1:
den.append(S.One / arg) # type: ignore[operator, arg-type] den.append(S.One / arg) # type: ignore[operator, arg-type]
@ -1961,7 +1952,7 @@ class _SymbolInfo(NamedTuple):
def _maybe_evaluate_static_worker( def _maybe_evaluate_static_worker(
expr: _SympyT, expr: _SympyT,
# NB: this is a tuple to ensure it can be LRU cached # NB: this is a tuple to ensure it can be LRU cached
symbol_info: Tuple[_SymbolInfo, ...], symbol_info: tuple[_SymbolInfo, ...],
unbacked_only: bool, unbacked_only: bool,
size_oblivious: bool, size_oblivious: bool,
) -> Optional[_SympyT]: ) -> Optional[_SympyT]:
@ -2193,9 +2184,9 @@ class SymExprPrinter(PythonPrinter):
class _ShapeGuardPrinter(abc.ABC): class _ShapeGuardPrinter(abc.ABC):
def __init__( def __init__(
self, self,
symbol_to_source: Mapping[sympy.Symbol, List[Source]], symbol_to_source: Mapping[sympy.Symbol, list[Source]],
source_ref: Callable[[Source], str], source_ref: Callable[[Source], str],
var_to_sources: Mapping[sympy.Symbol, List[Source]], var_to_sources: Mapping[sympy.Symbol, list[Source]],
) -> None: ) -> None:
self.symbol_to_source = symbol_to_source self.symbol_to_source = symbol_to_source
self.source_ref = source_ref self.source_ref = source_ref
@ -2246,7 +2237,7 @@ class ShapeGuardPrinter(ShapeGuardPythonPrinter):
class LoggingShapeGuardPrinter(ShapeGuardPythonPrinter): class LoggingShapeGuardPrinter(ShapeGuardPythonPrinter):
def __init__(self, var_to_sources: Mapping[sympy.Symbol, List[Source]]): def __init__(self, var_to_sources: Mapping[sympy.Symbol, list[Source]]):
super().__init__(var_to_sources, lambda n: n.name(), var_to_sources) super().__init__(var_to_sources, lambda n: n.name(), var_to_sources)
@ -2261,7 +2252,7 @@ class DynamicDimConstraintPrinter(PythonPrinter):
def __init__( def __init__(
self, self,
symbol_to_source: Dict[sympy.Symbol, List[Source]], symbol_to_source: dict[sympy.Symbol, list[Source]],
source_name_to_debug_name: Mapping[str, str], source_name_to_debug_name: Mapping[str, str],
): ):
super().__init__() super().__init__()
@ -2284,23 +2275,23 @@ class DimConstraints:
def __init__( def __init__(
self, self,
symbol_to_source: Dict[sympy.Symbol, List[Source]], symbol_to_source: dict[sympy.Symbol, list[Source]],
var_to_val: Mapping[sympy.Symbol, sympy.Integer], var_to_val: Mapping[sympy.Symbol, sympy.Integer],
marked_dynamic: Set[sympy.Symbol], marked_dynamic: set[sympy.Symbol],
source_name_to_debug_name: Mapping[str, str], source_name_to_debug_name: Mapping[str, str],
) -> None: ) -> None:
# We try to solve systems of inequalities with 1 free variable. # We try to solve systems of inequalities with 1 free variable.
self._univariate_inequalities: Dict[ self._univariate_inequalities: dict[
sympy.Symbol, Set[SympyBoolean] sympy.Symbol, set[SympyBoolean]
] = defaultdict(set) ] = defaultdict(set)
# Among them, we prioritize solving for a free variable that has equalities. # Among them, we prioritize solving for a free variable that has equalities.
# NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys() # NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys()
# and removing a symbol from the former => removing it from the latter. # and removing a symbol from the former => removing it from the latter.
self._symbols_with_equalities: Set[sympy.Symbol] = set() self._symbols_with_equalities: set[sympy.Symbol] = set()
# A solution of a free variable with equalities becomes a substitution. # A solution of a free variable with equalities becomes a substitution.
# We use these substitutions to simplify other constraints. # We use these substitutions to simplify other constraints.
# NOTE: removing a symbol from _symbols_with_equalities => adding it to _substitutions. # NOTE: removing a symbol from _symbols_with_equalities => adding it to _substitutions.
self._substitutions: Dict[sympy.Symbol, sympy.Integer] = {} self._substitutions: dict[sympy.Symbol, sympy.Integer] = {}
# In general, constraints may have // and % operations. # In general, constraints may have // and % operations.
# Of course, // can be expressed in terms of / and %. # Of course, // can be expressed in terms of / and %.
@ -2308,20 +2299,20 @@ class DimConstraints:
# We do so by using the values of variables as hints to evaluate %. # We do so by using the values of variables as hints to evaluate %.
# For soundness we record additional congruence guards and solve them separately. # For soundness we record additional congruence guards and solve them separately.
self._var_to_val: Mapping[sympy.Symbol, sympy.Integer] = var_to_val self._var_to_val: Mapping[sympy.Symbol, sympy.Integer] = var_to_val
self._congruences: DefaultDict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set) self._congruences: defaultdict[sympy.Symbol, set[sympy.Expr]] = defaultdict(set)
# We do not try to (directly) solve inequalities with > 1 free variables. # We do not try to (directly) solve inequalities with > 1 free variables.
# NOTE: free variables in these inequalities cannot also be in _substitutions. # NOTE: free variables in these inequalities cannot also be in _substitutions.
self._multivariate_inequalities: Set[SympyBoolean] = set() self._multivariate_inequalities: set[SympyBoolean] = set()
# We park external equalities between free variables here. # We park external equalities between free variables here.
self._symbolic_equivalences: List[Tuple[Source, sympy.Expr]] = [] self._symbolic_equivalences: list[tuple[Source, sympy.Expr]] = []
# Solutions come in two forms: # Solutions come in two forms:
# - (static) specializations # - (static) specializations
# - (dynamic) inequalities / congruences # - (dynamic) inequalities / congruences
self._static_results: Set[str] = set() self._static_results: set[str] = set()
self._dynamic_results: Set[str] = set() self._dynamic_results: set[str] = set()
# printer for solutions # printer for solutions
self._dcp = DynamicDimConstraintPrinter( self._dcp = DynamicDimConstraintPrinter(
@ -2329,13 +2320,13 @@ class DimConstraints:
) )
# inconsistencies found on substituting with concrete values / static solutions # inconsistencies found on substituting with concrete values / static solutions
self._inconsistencies: List[str] = [] self._inconsistencies: list[str] = []
# symbols that are marked dynamic # symbols that are marked dynamic
self._marked_dynamic = marked_dynamic self._marked_dynamic = marked_dynamic
# track supported sympy functions and subtract from list of all sympy functions # track supported sympy functions and subtract from list of all sympy functions
self._supported_sympy_functions: Set[sympy.Function] = { self._supported_sympy_functions: set[sympy.Function] = {
Application, Application,
Mod, Mod,
PythonMod, PythonMod,
@ -2488,8 +2479,8 @@ class DimConstraints:
# these will resolve to either specializations or dynamic equality constraints # these will resolve to either specializations or dynamic equality constraints
self._symbolic_equivalences.append((source, expr)) self._symbolic_equivalences.append((source, expr))
def _reduce_congruences(self) -> Dict[sympy.Symbol, Set[sympy.Expr]]: def _reduce_congruences(self) -> dict[sympy.Symbol, set[sympy.Expr]]:
reduced_congruences: Dict[sympy.Symbol, Set[sympy.Expr]] = {} reduced_congruences: dict[sympy.Symbol, set[sympy.Expr]] = {}
for s, congruences in self._congruences.items(): for s, congruences in self._congruences.items():
remainder_modulus_pairs = [] remainder_modulus_pairs = []
congruences_to_check = set() congruences_to_check = set()
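The remainder/modulus pairs gathered here can be merged via the Chinese Remainder Theorem, which sympy exposes directly; a standalone illustration of the reduction:

    from sympy.ntheory.modular import solve_congruence

    # x = 2 (mod 3) and x = 3 (mod 5) collapse to the single congruence x = 8 (mod 15)
    solve_congruence((2, 3), (3, 5))   # -> (8, 15)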
@ -2650,7 +2641,7 @@ class DimConstraints:
cond = cond and isinstance(divisor, sympy.Integer) cond = cond and isinstance(divisor, sympy.Integer)
return cond return cond
def forced_specializations(self) -> Dict[str, sympy.Expr]: def forced_specializations(self) -> dict[str, sympy.Expr]:
"""Returns a dictionary of the names of symbols to their specialized value""" """Returns a dictionary of the names of symbols to their specialized value"""
def debug_name(src: Source) -> str: def debug_name(src: Source) -> str:
@ -2678,8 +2669,8 @@ class DimConstraints:
def _process_derived_dim_roots( def _process_derived_dim_roots(
self, self,
results: Dict[str, Dict[str, Any]], results: dict[str, dict[str, Any]],
name_to_dim: Dict[str, Any], name_to_dim: dict[str, Any],
) -> None: ) -> None:
""" """
Here we resolve 2 concerns with derived dims' suggested fixes: 1) newly introduced roots, Here we resolve 2 concerns with derived dims' suggested fixes: 1) newly introduced roots,
@ -2745,7 +2736,7 @@ class DimConstraints:
# {"dx": {"eq": 3*_dx+1, "min": 4, "max": 10}, "dy": dx+1, "dz": dx+2} # {"dx": {"eq": 3*_dx+1, "min": 4, "max": 10}, "dy": dx+1, "dz": dx+2}
# we want instead: # we want instead:
# {"_dx": {"min": 1, "max": 4}, "dx": 3*_dx+1, "dy": 3*_dx+2, "dz": 3*_dx+3} # {"_dx": {"min": 1, "max": 4}, "dx": 3*_dx+1, "dy": 3*_dx+2, "dz": 3*_dx+3}
introduced_roots: Dict[str, str] = {} # map new root -> old root introduced_roots: dict[str, str] = {} # map new root -> old root
for k, c in list(results.items()): for k, c in list(results.items()):
if "eq" in c and isinstance(c["eq"], sympy.Expr): # derived dim if "eq" in c and isinstance(c["eq"], sympy.Expr): # derived dim
root = next(iter(c["eq"].free_symbols)) root = next(iter(c["eq"].free_symbols))
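For reference, the rewritten fixes sketched in the comment above correspond to derived dims in the public torch.export API; expressed there (names illustrative), the suggestion would read:

    from torch.export import Dim

    _dx = Dim("_dx", min=1, max=4)      # the newly introduced root
    dynamic_shapes = {
        "x": {0: 3 * _dx + 1},          # dx = 3*_dx + 1
        "y": {0: 3 * _dx + 2},          # dy = 3*_dx + 2
        "z": {0: 3 * _dx + 3},          # dz = 3*_dx + 3
    }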
@ -2782,7 +2773,7 @@ class DimConstraints:
# this consists of: # this consists of:
# 1) {"dx": {"min": ..., "max": ...}} -> dx: refined root dim # 1) {"dx": {"min": ..., "max": ...}} -> dx: refined root dim
# 2) {"dy": "dx + 1"} -> dx: root for suggested fix # 2) {"dy": "dx + 1"} -> dx: root for suggested fix
modified_roots: Set[str] = set() modified_roots: set[str] = set()
for k, c in results.items(): for k, c in results.items():
if k not in name_to_dim: # _dynamo.export() may handle source directly if k not in name_to_dim: # _dynamo.export() may handle source directly
continue continue
@ -2799,7 +2790,7 @@ class DimConstraints:
# evaluate the new value for each root # evaluate the new value for each root
# this is now either 1) unchanged, 2) refined with a new range, # this is now either 1) unchanged, 2) refined with a new range,
# or 3) specialized to a concrete value # or 3) specialized to a concrete value
modified_root_values: Dict[str, Dict[str, Any]] = {} modified_root_values: dict[str, dict[str, Any]] = {}
for mroot in modified_roots: for mroot in modified_roots:
swapped_root = True swapped_root = True
if mroot in results: if mroot in results:
@ -2860,9 +2851,9 @@ class DimConstraints:
def prettify_results( def prettify_results(
self, self,
original_signature: inspect.Signature, original_signature: inspect.Signature,
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]], dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any]],
constraint_violation_error: object, constraint_violation_error: object,
forced_specializations: Dict[str, str], forced_specializations: dict[str, str],
) -> str: ) -> str:
"""Format a message for constraint violation erros""" """Format a message for constraint violation erros"""
from torch.export.dynamic_shapes import _get_dim_name_mapping from torch.export.dynamic_shapes import _get_dim_name_mapping
@ -2876,7 +2867,7 @@ class DimConstraints:
s = s.replace(k, v) if not inverse else s.replace(v, k) s = s.replace(k, v) if not inverse else s.replace(v, k)
return s return s
results: DefaultDict[str, Dict[str, Any]] = defaultdict(dict) results: defaultdict[str, dict[str, Any]] = defaultdict(dict)
if dynamic_shapes is None: if dynamic_shapes is None:
dynamic_shapes = {} dynamic_shapes = {}
@ -3050,7 +3041,7 @@ class ShapeEnv:
self, self,
*, *,
should_record_events: Optional[bool] = None, should_record_events: Optional[bool] = None,
tracked_fakes: Optional[List[Any]] = None, tracked_fakes: Optional[list[Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
self._init(**kwargs) self._init(**kwargs)
@ -3086,7 +3077,7 @@ class ShapeEnv:
# Keep track of the list of tracked fakes. # Keep track of the list of tracked fakes.
self.tracked_fakes = tracked_fakes self.tracked_fakes = tracked_fakes
# List of events for reconstructing ShapeEnv at arbitrary points in time. # List of events for reconstructing ShapeEnv at arbitrary points in time.
self.events: List[ShapeEnvEvent] = ( self.events: list[ShapeEnvEvent] = (
[ShapeEnvEvent(ShapeEnv, kwargs=kwargs)] [ShapeEnvEvent(ShapeEnv, kwargs=kwargs)]
if self.should_record_events if self.should_record_events
else [] else []
@ -3099,7 +3090,7 @@ class ShapeEnv:
# NOTE: It's important that SymNodes in this cache have their ShapeEnv # NOTE: It's important that SymNodes in this cache have their ShapeEnv
# stripped otherwise you end up with cycles which can only be cleaned # stripped otherwise you end up with cycles which can only be cleaned
# with the GC. # with the GC.
self.fake_tensor_cache: Dict[ self.fake_tensor_cache: dict[
torch._subclasses.fake_tensor._DispatchCacheKey, torch._subclasses.fake_tensor._DispatchCacheKey,
torch._subclasses.fake_tensor._DispatchCacheEntry, torch._subclasses.fake_tensor._DispatchCacheEntry,
] = {} ] = {}
@ -3134,7 +3125,7 @@ class ShapeEnv:
# symbolically equal. # symbolically equal.
duck_shape: Optional[bool] = None, duck_shape: Optional[bool] = None,
# For debugging # For debugging
co_fields: Optional[Dict[str, str]] = None, co_fields: Optional[dict[str, str]] = None,
# When True, whenever safe, we will generate a deferred runtime assert # When True, whenever safe, we will generate a deferred runtime assert
# instead of a guard whenever we know that an expression must be True, # instead of a guard whenever we know that an expression must be True,
# otherwise it would be an error, even for backed SymInts (where we # otherwise it would be an error, even for backed SymInts (where we
@ -3165,50 +3156,50 @@ class ShapeEnv:
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
) )
self.guards: List[ShapeGuard] = [] self.guards: list[ShapeGuard] = []
self.axioms: Dict[sympy.Expr, sympy.Expr] = {} self.axioms: dict[sympy.Expr, sympy.Expr] = {}
# Maps symbolic ints to their original concrete values # Maps symbolic ints to their original concrete values
# Currently populated from tensors # Currently populated from tensors
self.var_to_val: Dict[sympy.Symbol, sympy.Integer] = {} self.var_to_val: dict[sympy.Symbol, sympy.Integer] = {}
# Like var_to_val, but only set when propagate_real_tensors is on. # Like var_to_val, but only set when propagate_real_tensors is on.
# Used as last resort to avoid GuardOnDataDependent error # Used as last resort to avoid GuardOnDataDependent error
self.unbacked_var_to_val: Dict[sympy.Symbol, sympy.Integer] = {} self.unbacked_var_to_val: dict[sympy.Symbol, sympy.Integer] = {}
# Like above, but used exclusively for OBLIVIOUS_SIZE. These # Like above, but used exclusively for OBLIVIOUS_SIZE. These
# potentially could be put together but I am not sure, writing out # potentially could be put together but I am not sure, writing out
# the logic individually before abstracting. # the logic individually before abstracting.
self.oblivious_var_to_val: Dict[sympy.Symbol, sympy.Integer] = {} self.oblivious_var_to_val: dict[sympy.Symbol, sympy.Integer] = {}
# Maps symbolic ints to their min/max range. These ranges # Maps symbolic ints to their min/max range. These ranges
# are conservative: the int MUST fall in the range, but the # are conservative: the int MUST fall in the range, but the
# range may contain ints which may not actually appear in # range may contain ints which may not actually appear in
# practice # practice
self.var_to_range: Dict[sympy.Symbol, ValueRanges] = {} self.var_to_range: dict[sympy.Symbol, ValueRanges] = {}
self.var_to_range_sloc: Dict[sympy.Symbol, ValueRangesSLoc] = {} self.var_to_range_sloc: dict[sympy.Symbol, ValueRangesSLoc] = {}
# When doing a size-oblivious test, exclude this integer and # When doing a size-oblivious test, exclude this integer and
# everything higher than it from the acceptable range. This solves # everything higher than it from the acceptable range. This solves
# https://github.com/pytorch/pytorch/issues/120288 for constant range # https://github.com/pytorch/pytorch/issues/120288 for constant range
# case # case
# TODO: generalize this to work with expressions (in that case, we # TODO: generalize this to work with expressions (in that case, we
# need to maintain a SET and we need extra symbolic reasoning on top) # need to maintain a SET and we need extra symbolic reasoning on top)
self.oblivious_upper_bound_exclusive: Dict[sympy.Symbol, sympy.Integer] = {} self.oblivious_upper_bound_exclusive: dict[sympy.Symbol, sympy.Integer] = {}
self.source_name_to_debug_name: Dict[str, str] = {} self.source_name_to_debug_name: dict[str, str] = {}
self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {} self.var_to_sources: dict[sympy.Symbol, list[Source]] = {}
self.var_to_stack: Dict[sympy.Symbol, CapturedTraceback] = {} self.var_to_stack: dict[sympy.Symbol, CapturedTraceback] = {}
# Maps a source to the *original* symbol that was assigned to it # Maps a source to the *original* symbol that was assigned to it
self.source_to_var: Dict[str, sympy.Symbol] = {} self.source_to_var: dict[str, sympy.Symbol] = {}
# Maps from sympy ints to expressions representing them # Maps from sympy ints to expressions representing them
# Populated from equality guards (i.e. a.shape[0] == b.shape[0]) # Populated from equality guards (i.e. a.shape[0] == b.shape[0])
self.replacements: Dict[sympy.Symbol, sympy.Expr] = {} self.replacements: dict[sympy.Symbol, sympy.Expr] = {}
# The sloc of the guard that triggered this replacement to be added # The sloc of the guard that triggered this replacement to be added
self.replacements_slocs: Dict[sympy.Symbol, SLoc] = {} self.replacements_slocs: dict[sympy.Symbol, SLoc] = {}
self.unbacked_renamings: Dict[sympy.Symbol, sympy.Symbol] = {} self.unbacked_renamings: dict[sympy.Symbol, sympy.Symbol] = {}
# Set holds a % b expressions that evaluate to 0. # Set holds a % b expressions that evaluate to 0.
self.divisible: Set[sympy.Expr] = set() self.divisible: set[sympy.Expr] = set()
# Set that holds "size-like" symbols. When we perform # Set that holds "size-like" symbols. When we perform
# "size-oblivious" tests, these can be assumed to be >= 2. # "size-oblivious" tests, these can be assumed to be >= 2.
self.size_like: Set[sympy.Symbol] = set() self.size_like: set[sympy.Symbol] = set()
# Duck-shaping says that if two input tensors have the same size, # Duck-shaping says that if two input tensors have the same size,
# they get assigned the same symbolic variable # they get assigned the same symbolic variable
self.val_to_var: Dict[int, sympy.Symbol] = {} self.val_to_var: dict[int, sympy.Symbol] = {}
if specialize_zero_one: if specialize_zero_one:
self.val_to_var = {0: sympy.S.Zero, 1: sympy.S.One} self.val_to_var = {0: sympy.S.Zero, 1: sympy.S.One}
self.unbacked_symfloat_counter = itertools.count() self.unbacked_symfloat_counter = itertools.count()
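A toy sketch of the duck-shaping rule (not the real ShapeEnv API): equal hinted sizes share one symbol, and 0/1 stay constant when specialize_zero_one is set.

    import sympy

    val_to_var: dict[int, sympy.Expr] = {0: sympy.S.Zero, 1: sympy.S.One}

    def duck_symbol(size: int) -> sympy.Expr:
        # same concrete value -> same symbolic variable
        if size not in val_to_var:
            val_to_var[size] = sympy.Symbol(f"s{len(val_to_var)}", integer=True, positive=True)
        return val_to_var[size]

    a = [duck_symbol(d) for d in (8, 1, 3)]   # [s2, 1, s3]
    b = [duck_symbol(d) for d in (8, 5)]      # [s2, s4] -- both 8s map to s2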
@ -3241,8 +3232,8 @@ class ShapeEnv:
# to the next unbacked symbol to wait on, but if we choose the # to the next unbacked symbol to wait on, but if we choose the
# latest key, an assert will only show up at the moment when # latest key, an assert will only show up at the moment when
# we can actually codegen it. # we can actually codegen it.
self.deferred_runtime_asserts: Dict[ self.deferred_runtime_asserts: dict[
Optional[sympy.Symbol], List[RuntimeAssert] Optional[sympy.Symbol], list[RuntimeAssert]
] = {} ] = {}
# This exists so we can efficiently invalidate the cache (it's used as # This exists so we can efficiently invalidate the cache (it's used as
# part of the cache key); otherwise we'd have to iterate through # part of the cache key); otherwise we'd have to iterate through
@ -3279,7 +3270,7 @@ class ShapeEnv:
# #
# NB: fresh unbacked symbols NEVER get substitutions applied to them, # NB: fresh unbacked symbols NEVER get substitutions applied to them,
# they are binding sites! # they are binding sites!
self.pending_fresh_unbacked_symbols: List[sympy.Symbol] = [] self.pending_fresh_unbacked_symbols: list[sympy.Symbol] = []
# Version counter used to invalidate cached values # Version counter used to invalidate cached values
self._prev_cache_key = self._get_key() self._prev_cache_key = self._get_key()
@ -3294,8 +3285,8 @@ class ShapeEnv:
# 2. list of arguments # 2. list of arguments
# This drastically reduces the size of the FX graph, avoiding # This drastically reduces the size of the FX graph, avoiding
# duplicated nodes. # duplicated nodes.
self.fx_node_cache: Dict[Tuple[Callable, Tuple[Any, ...]], torch.fx.Node] = {} self.fx_node_cache: dict[tuple[Callable, tuple[Any, ...]], torch.fx.Node] = {}
self.source_to_symbol: Dict[str, sympy.Symbol] = {} self.source_to_symbol: dict[str, sympy.Symbol] = {}
# Suppose you want to replace an unbacked symbol with another # Suppose you want to replace an unbacked symbol with another
# unbacked symbol. This is error prone because you can cause # unbacked symbol. This is error prone because you can cause
@ -3322,7 +3313,7 @@ class ShapeEnv:
# bindings. At the moment, this is not tracked, but we potentially # bindings. At the moment, this is not tracked, but we potentially
# could track this at the IR level using a higher order operator # could track this at the IR level using a higher order operator
# with something like effect token tracking. # with something like effect token tracking.
self.unbacked_alloc_order: Dict[sympy.Symbol, int] = {} self.unbacked_alloc_order: dict[sympy.Symbol, int] = {}
from torch.fx.experimental.validator import translation_validation_enabled from torch.fx.experimental.validator import translation_validation_enabled
@ -3345,7 +3336,7 @@ class ShapeEnv:
# Whenever you add a node to self.graph, you must add a mapping to this # Whenever you add a node to self.graph, you must add a mapping to this
# variable. Otherwise, the built FX graph on the replayed ShapeEnv will # variable. Otherwise, the built FX graph on the replayed ShapeEnv will
# not be valid. # not be valid.
self.name_to_node: Dict[str, torch.fx.Node] = {} self.name_to_node: dict[str, torch.fx.Node] = {}
@property @property
def allow_scalar_outputs(self) -> bool: def allow_scalar_outputs(self) -> bool:
@ -3439,7 +3430,7 @@ class ShapeEnv:
shape_env_check_state_equal(self, other, non_state_variable_names, map_value) shape_env_check_state_equal(self, other, non_state_variable_names, map_value)
def _snapshot_tracked_fakes(self) -> Optional[List[Any]]: def _snapshot_tracked_fakes(self) -> Optional[list[Any]]:
if self.tracked_fakes is None: if self.tracked_fakes is None:
return None return None
@ -3631,7 +3622,7 @@ class ShapeEnv:
self.source_to_symbol[srcname] = sympy.Symbol(srcname, integer=True) self.source_to_symbol[srcname] = sympy.Symbol(srcname, integer=True)
return self.source_to_symbol[srcname] return self.source_to_symbol[srcname]
def _add_z3var(self, symbol: sympy.Symbol, type: Type) -> None: def _add_z3var(self, symbol: sympy.Symbol, type: type) -> None:
if self._translation_validation_enabled: if self._translation_validation_enabled:
self.validator.add_var(symbol, type) self.validator.add_var(symbol, type)
@ -3651,8 +3642,8 @@ class ShapeEnv:
def _create_fx_call_function( def _create_fx_call_function(
self, self,
op: Callable, op: Callable,
args: Tuple, args: tuple,
) -> Tuple[Optional[torch.fx.Node], bool]: ) -> tuple[Optional[torch.fx.Node], bool]:
# Cache this tuple in order to avoid duplicated nodes. # Cache this tuple in order to avoid duplicated nodes.
node_key = (op, args) node_key = (op, args)
# Flags whether the returned node was cached or not. # Flags whether the returned node was cached or not.
@ -3681,7 +3672,7 @@ class ShapeEnv:
def _create_fx_placeholder_and_z3var( def _create_fx_placeholder_and_z3var(
self, self,
symbol: sympy.Symbol, symbol: sympy.Symbol,
type: Type, type: type,
) -> Optional[torch.fx.Node]: ) -> Optional[torch.fx.Node]:
if not self._translation_validation_enabled: if not self._translation_validation_enabled:
return None return None
@ -3742,7 +3733,7 @@ class ShapeEnv:
"""Context manager to ignore all guards generated inside""" """Context manager to ignore all guards generated inside"""
return _suppress_guards(self) return _suppress_guards(self)
def _get_key(self) -> Tuple[int, int, int, int]: def _get_key(self) -> tuple[int, int, int, int]:
""" """
Defines the current "state" of the guards we've accumulated in this ShapeEnv. Defines the current "state" of the guards we've accumulated in this ShapeEnv.
Determines when we need to invalidate our cache Determines when we need to invalidate our cache
@ -3778,7 +3769,7 @@ class ShapeEnv:
ex_size: Sequence[Union[int, SymInt]], ex_size: Sequence[Union[int, SymInt]],
source: Source, source: Source,
symbolic_context: SymbolicContext, symbolic_context: SymbolicContext,
) -> List[sympy.Expr]: ) -> list[sympy.Expr]:
return self._produce_dyn_sizes_from_int_tuple( return self._produce_dyn_sizes_from_int_tuple(
tuple(ex_size), source, symbolic_context tuple(ex_size), source, symbolic_context
) )
@ -3788,7 +3779,7 @@ class ShapeEnv:
tensor_size: Sequence[Union[int, SymInt]], tensor_size: Sequence[Union[int, SymInt]],
source: Source, source: Source,
symbolic_context: SymbolicContext, symbolic_context: SymbolicContext,
) -> List[sympy.Expr]: ) -> list[sympy.Expr]:
assert all( assert all(
not is_symbolic(val) for val in tensor_size not is_symbolic(val) for val in tensor_size
), f"Expect size to be a plain tuple of ints but got {tensor_size}" ), f"Expect size to be a plain tuple of ints but got {tensor_size}"
@ -3816,9 +3807,9 @@ class ShapeEnv:
source: Source, source: Source,
*, *,
symbolic_context: Optional[SymbolicContext] = None, symbolic_context: Optional[SymbolicContext] = None,
) -> Tuple[ ) -> tuple[
Tuple[Union[int, SymInt], ...], tuple[Union[int, SymInt], ...],
Tuple[Union[int, SymInt], ...], tuple[Union[int, SymInt], ...],
Union[int, SymInt], Union[int, SymInt],
]: ]:
""" """
@ -3903,17 +3894,17 @@ class ShapeEnv:
source: Source, source: Source,
*, *,
symbolic_context: Optional[SymbolicContext] = None, symbolic_context: Optional[SymbolicContext] = None,
) -> Tuple[ ) -> tuple[
Tuple[Union[int, SymInt], ...], tuple[Union[int, SymInt], ...],
Tuple[Union[int, SymInt], ...], tuple[Union[int, SymInt], ...],
Union[int, SymInt], Union[int, SymInt],
]: ]:
dim = len(ex_size) dim = len(ex_size)
# Reimplement the legacy behavior # Reimplement the legacy behavior
if symbolic_context is None: if symbolic_context is None:
constraint_sizes: List[DimConstraint] = [None] * dim constraint_sizes: list[DimConstraint] = [None] * dim
constraint_strides: List[DimConstraint] = [None] * dim constraint_strides: list[DimConstraint] = [None] * dim
dynamic_dims = [] dynamic_dims = []
dynamic_strides = [] dynamic_strides = []
for i in range(dim): for i in range(dim):
@ -3963,7 +3954,7 @@ class ShapeEnv:
from torch._dynamo.source import TensorProperty, TensorPropertySource from torch._dynamo.source import TensorProperty, TensorPropertySource
size: List[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple( size: list[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(
ex_size, source, symbolic_context ex_size, source, symbolic_context
) )
stride = self._compute_symbolic_stride( stride = self._compute_symbolic_stride(
@ -4022,11 +4013,11 @@ class ShapeEnv:
], ],
are_sizes_static: bool, are_sizes_static: bool,
symbolic_context: SymbolicContext, symbolic_context: SymbolicContext,
) -> List[sympy.Expr]: ) -> list[sympy.Expr]:
from torch._dynamo.source import TensorProperty, TensorPropertySource from torch._dynamo.source import TensorProperty, TensorPropertySource
stride: List[Optional[sympy.Expr]] = [None] * len(size) stride: list[Optional[sympy.Expr]] = [None] * len(size)
candidates: Dict[Union[int, SymInt], sympy.Expr] = {} candidates: dict[Union[int, SymInt], sympy.Expr] = {}
# iterate over unbound strides in val ascending order with # iterate over unbound strides in val ascending order with
# index descending as a tie breaker since for cases like # index descending as a tie breaker since for cases like
@ -4590,7 +4581,7 @@ class ShapeEnv:
return c_render return c_render
return c.render(source) return c.render(source)
def produce_guards(self, *args: Any, **kwargs: Any) -> List[str]: def produce_guards(self, *args: Any, **kwargs: Any) -> list[str]:
""" """
Like produce_guards_verbose, but only returns the non-verbose guard expressions Like produce_guards_verbose, but only returns the non-verbose guard expressions
(no verbose guards produced.) (no verbose guards produced.)
@ -4603,7 +4594,7 @@ class ShapeEnv:
sources: Sequence[Source], sources: Sequence[Source],
source_ref: Callable[[Source], str] = lambda n: n.name(), source_ref: Callable[[Source], str] = lambda n: n.name(),
*, *,
guards: Optional[List[ShapeGuard]] = None, guards: Optional[list[ShapeGuard]] = None,
input_contexts: Optional[DimList[SymbolicContext]] = None, input_contexts: Optional[DimList[SymbolicContext]] = None,
# Encodes user-specified input shape equations of the form s = s' and s = fn(s'). # Encodes user-specified input shape equations of the form s = s' and s = fn(s').
# (See docs on EqualityConstraint for details of the encoding.) # (See docs on EqualityConstraint for details of the encoding.)
@ -4611,7 +4602,7 @@ class ShapeEnv:
_simplified: bool = False, _simplified: bool = False,
# Indicates if we should produce guards for known static values. # Indicates if we should produce guards for known static values.
ignore_static: bool = True, ignore_static: bool = True,
) -> Tuple[List[str], List[str]]: # python, verbose ) -> tuple[list[str], list[str]]: # python, verbose
""" """
Generates a list of guards strings which, when evaluated in a context that Generates a list of guards strings which, when evaluated in a context that
defines tensors for all the sources, returns True or False depending defines tensors for all the sources, returns True or False depending
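A hedged sketch of how such a guard string can be consumed downstream (both the guard text and the L namespace are illustrative):

    guard = "L['x'].size()[0] == L['y'].size()[0]"

    class FakeShaped:
        def __init__(self, *shape): self._shape = shape
        def size(self): return self._shape

    L = {"x": FakeShaped(4, 3), "y": FakeShaped(4, 7)}
    assert eval(guard, {"L": L})   # leading dims match, so the guard holds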
@ -4740,13 +4731,13 @@ class ShapeEnv:
# the symbol mapping is # the symbol mapping is
input_guards = [] input_guards = []
symbol_to_source: Dict[sympy.Symbol, List[Source]] = collections.defaultdict( symbol_to_source: dict[sympy.Symbol, list[Source]] = collections.defaultdict(
list list
) )
symbol_to_constraints: DefaultDict[ symbol_to_constraints: defaultdict[
sympy.Symbol, Set[Constraint] sympy.Symbol, set[Constraint]
] = collections.defaultdict(set) ] = collections.defaultdict(set)
constraint_violations: List[Tuple[bool, str, Callable[[], str]]] = [] constraint_violations: list[tuple[bool, str, Callable[[], str]]] = []
py_printer = ShapeGuardPythonPrinter( py_printer = ShapeGuardPythonPrinter(
symbol_to_source, source_ref, self.var_to_sources symbol_to_source, source_ref, self.var_to_sources
@ -4956,7 +4947,7 @@ class ShapeEnv:
# For subclasses, we need to track symints on BOTH the outer # For subclasses, we need to track symints on BOTH the outer
# and inner tensors. # and inner tensors.
# TODO: type this better # TODO: type this better
sources_tensors_constraints: List[Tuple[Source, Any, Any, Any]] = [ sources_tensors_constraints: list[tuple[Source, Any, Any, Any]] = [
(source, t, context.constraint_sizes, context.constraint_strides) (source, t, context.constraint_sizes, context.constraint_strides)
] ]
attrs, _ = t.__tensor_flatten__() attrs, _ = t.__tensor_flatten__()
@ -5256,8 +5247,8 @@ class ShapeEnv:
) )
if constraint_violations: if constraint_violations:
warn_msgs: List[str] = [] warn_msgs: list[str] = []
error_msgs: List[str] = [] error_msgs: list[str] = []
debug_names = set() debug_names = set()
for warn_only, debug_name, msg_cb in constraint_violations: for warn_only, debug_name, msg_cb in constraint_violations:
if warn_only: if warn_only:
@ -5327,7 +5318,7 @@ class ShapeEnv:
self, self,
placeholders: Sequence[Union[SymInt, FakeTensor]], placeholders: Sequence[Union[SymInt, FakeTensor]],
*, *,
guards: Optional[List[ShapeGuard]] = None, guards: Optional[list[ShapeGuard]] = None,
ignore_static: bool = True, ignore_static: bool = True,
) -> Optional[str]: ) -> Optional[str]:
""" """
@ -5386,7 +5377,7 @@ class ShapeEnv:
return self.evaluate_guards_expression(code, args) return self.evaluate_guards_expression(code, args)
return True return True
def get_pruned_guards(self, symints: Sequence[torch.SymInt]) -> List[ShapeGuard]: def get_pruned_guards(self, symints: Sequence[torch.SymInt]) -> list[ShapeGuard]:
""" """
Get a list of guards, but pruned so it only provides guards that Get a list of guards, but pruned so it only provides guards that
reference symints from the passed-in input reference symints from the passed-in input
@ -5401,7 +5392,7 @@ class ShapeEnv:
def bind_symbols( def bind_symbols(
self, placeholders: Sequence[FakeTensor], args: Sequence[Tensor] self, placeholders: Sequence[FakeTensor], args: Sequence[Tensor]
) -> Dict[sympy.Symbol, int]: ) -> dict[sympy.Symbol, int]:
""" """
Given a paired list of placeholders (fake tensors with Given a paired list of placeholders (fake tensors with
symbolic sizes) and concrete arguments (regular tensors symbolic sizes) and concrete arguments (regular tensors
@ -5418,7 +5409,7 @@ class ShapeEnv:
another copy. This assumes the guards are already checked, another copy. This assumes the guards are already checked,
though if it's cheap we'll check for shenanigans though if it's cheap we'll check for shenanigans
""" """
bindings: Dict[sympy.Symbol, int] = {} bindings: dict[sympy.Symbol, int] = {}
def bind_symint(arg: object, val: object) -> None: def bind_symint(arg: object, val: object) -> None:
if isinstance(val, SymInt): if isinstance(val, SymInt):
@ -5451,7 +5442,7 @@ class ShapeEnv:
return bindings return bindings
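A toy version of that binding walk (hedged; the real method recurses through the sizes, strides and storage offsets of the fake tensors):

    import sympy

    s0, s1 = sympy.symbols("s0 s1")
    placeholder_sizes = [(s0, s1), (s1, 4)]   # symbolic shapes of placeholders
    arg_sizes = [(8, 3), (3, 4)]              # concrete shapes of the real args

    bindings: dict[sympy.Symbol, int] = {}
    for sym_shape, real_shape in zip(placeholder_sizes, arg_sizes):
        for sym, val in zip(sym_shape, real_shape):
            if isinstance(sym, sympy.Symbol) and sym not in bindings:
                bindings[sym] = val
    # bindings == {s0: 8, s1: 3}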
def get_nontrivial_guards(self) -> List[SympyBoolean]: def get_nontrivial_guards(self) -> list[SympyBoolean]:
"""Returns a list of guard expressions that aren't statically known (i.e. not trivial)""" """Returns a list of guard expressions that aren't statically known (i.e. not trivial)"""
return [ return [
self.simplify(guard.expr) self.simplify(guard.expr)
@ -5488,9 +5479,9 @@ class ShapeEnv:
@_lru_cache @_lru_cache
def get_axioms( def get_axioms(
self, self,
symbols: Optional[Tuple[sympy.Symbol]] = None, symbols: Optional[tuple[sympy.Symbol]] = None,
compute_hint: bool = False, compute_hint: bool = False,
) -> Tuple[SympyBoolean, ...]: ) -> tuple[SympyBoolean, ...]:
""" """
Given the symbols in an expression, it returns all the runtime asserts that have those symbols Given the symbols in an expression, it returns all the runtime asserts that have those symbols
concatenated with all the guards. concatenated with all the guards.
@ -5518,9 +5509,9 @@ class ShapeEnv:
@lru_cache(None) @lru_cache(None)
def get_implications( def get_implications(
self, e: SympyBoolean self, e: SympyBoolean
) -> Tuple[Tuple[SympyBoolean, sympy.logic.boolalg.BooleanAtom], ...]: ) -> tuple[tuple[SympyBoolean, sympy.logic.boolalg.BooleanAtom], ...]:
"""Given a expression, it returns a list of predicates that follow from it""" """Given a expression, it returns a list of predicates that follow from it"""
equiv: Dict[SympyBoolean, sympy.logic.boolalg.BooleanAtom] = {} equiv: dict[SympyBoolean, sympy.logic.boolalg.BooleanAtom] = {}
def add_expr(expr: SympyBoolean) -> None: def add_expr(expr: SympyBoolean) -> None:
expr = canonicalize_bool_expr(expr) expr = canonicalize_bool_expr(expr)
@ -5564,8 +5555,8 @@ class ShapeEnv:
unbacked_only: bool = False, unbacked_only: bool = False,
compute_hint: bool = False, compute_hint: bool = False,
size_oblivious: bool = False, size_oblivious: bool = False,
axioms: Optional[Tuple[SympyBoolean]] = None, axioms: Optional[tuple[SympyBoolean]] = None,
var_to_range: Optional[Tuple[Tuple[sympy.Symbol, ValueRanges]]] = None, var_to_range: Optional[tuple[tuple[sympy.Symbol, ValueRanges]]] = None,
) -> Optional[sympy.Basic]: ) -> Optional[sympy.Basic]:
""" """
Tries to evaluate expr without introducing guards Tries to evaluate expr without introducing guards
@ -5589,7 +5580,7 @@ class ShapeEnv:
expr = canonicalize_bool_expr(expr) expr = canonicalize_bool_expr(expr)
def resimplify_floor_div(axioms: Dict[sympy.Expr, sympy.Expr]) -> None: def resimplify_floor_div(axioms: dict[sympy.Expr, sympy.Expr]) -> None:
if not self._resimplify_floor_div_axioms: if not self._resimplify_floor_div_axioms:
return return
self._resimplify_floor_div_axioms = False self._resimplify_floor_div_axioms = False
@ -6114,7 +6105,7 @@ class ShapeEnv:
# Prefer to simplify out lexicographically higher symbols (i.e. simplify out s4 over s3). # Prefer to simplify out lexicographically higher symbols (i.e. simplify out s4 over s3).
# (NB: this unfortunately isn't strictly equivalent to simplifying out newer symbols) # (NB: this unfortunately isn't strictly equivalent to simplifying out newer symbols)
# Prefer to simplify out symbols with ephemeral sources. # Prefer to simplify out symbols with ephemeral sources.
def _smart_symbol_sort(x: sympy.Symbol) -> Tuple[int, int, str]: def _smart_symbol_sort(x: sympy.Symbol) -> tuple[int, int, str]:
has_only_ephemeral_sources = x in self.var_to_sources and all( has_only_ephemeral_sources = x in self.var_to_sources and all(
s.is_ephemeral() for s in self.var_to_sources[x] s.is_ephemeral() for s in self.var_to_sources[x]
) )
@ -6282,7 +6273,7 @@ class ShapeEnv:
def _get_stack_summary( def _get_stack_summary(
self, is_debug: bool = False, framework_loc: Optional[str] = None self, is_debug: bool = False, framework_loc: Optional[str] = None
) -> Tuple[SLoc, str]: ) -> tuple[SLoc, str]:
floc: Optional[Union[str, traceback.FrameSummary]] = framework_loc floc: Optional[Union[str, traceback.FrameSummary]] = framework_loc
if floc is None: if floc is None:
frame = inspect.currentframe() frame = inspect.currentframe()
@ -6903,7 +6894,7 @@ class _PythonMsgPrinter(PythonPrinter):
(i.e., as ==, !=, >, <). (i.e., as ==, !=, >, <).
""" """
def __init__(self, src_map: Dict[str, List[str]]) -> None: def __init__(self, src_map: dict[str, list[str]]) -> None:
super().__init__() super().__init__()
self.src_map = src_map self.src_map = src_map
@ -6912,7 +6903,7 @@ class _PythonMsgPrinter(PythonPrinter):
def _suggest_torch_checks( def _suggest_torch_checks(
e: GuardOnDataDependentSymNode, src_map: DefaultDict[str, List[str]] e: GuardOnDataDependentSymNode, src_map: defaultdict[str, list[str]]
) -> None: ) -> None:
# extract the unresolved condition on unbacked symints in the error # extract the unresolved condition on unbacked symints in the error
cond = e.cond cond = e.cond

View File

@ -5,7 +5,7 @@ import logging
import math import math
import operator import operator
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from typing import Any, Callable, Optional, Union
import sympy import sympy
@ -60,7 +60,7 @@ try:
def z3str(e: z3.ExprRef) -> str: def z3str(e: z3.ExprRef) -> str:
assert z3.is_expr(e), f"unsupported expression type: {e}" assert z3.is_expr(e), f"unsupported expression type: {e}"
def get_args_str(e: z3.ExprRef) -> List[str]: def get_args_str(e: z3.ExprRef) -> list[str]:
return [z3str(e.arg(i)) for i in range(e.num_args())] return [z3str(e.arg(i)) for i in range(e.num_args())]
# First, we simplify the given expression. # First, we simplify the given expression.
@ -350,13 +350,13 @@ try:
super().__init__(module, garbage_collect_values=True) super().__init__(module, garbage_collect_values=True)
def placeholder( def placeholder(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
) -> Any: ) -> Any:
symbol = fx_traceback.get_current_meta()["symbol"] symbol = fx_traceback.get_current_meta()["symbol"]
return self.validator.z3var(symbol) return self.validator.z3var(symbol)
def call_function( def call_function(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
) -> Any: ) -> Any:
if target != torch._assert: if target != torch._assert:
# Lift and runs the node target function # Lift and runs the node target function
@ -481,21 +481,21 @@ try:
log.debug("new instance") log.debug("new instance")
# Mapping of SymPy symbols to Z3 variables. # Mapping of SymPy symbols to Z3 variables.
self.symbols: Dict[sympy.Symbol, z3.ExprRef] = {} self.symbols: dict[sympy.Symbol, z3.ExprRef] = {}
# Set of source Z3 expressions. # Set of source Z3 expressions.
# They represent the generated guards without any kind of # They represent the generated guards without any kind of
# simplification or transformation. # simplification or transformation.
self._source_exprs: Set[z3.BoolRef] = set() self._source_exprs: set[z3.BoolRef] = set()
# Set of target Z3 expressions. # Set of target Z3 expressions.
# They represent the actual checked guards at runtime. They might # They represent the actual checked guards at runtime. They might
# be simplified or transformed versions of the source guards. # be simplified or transformed versions of the source guards.
self._target_exprs: Set[z3.BoolRef] = set() self._target_exprs: set[z3.BoolRef] = set()
# Set of Z3 expressions representing assertions over both the # Set of Z3 expressions representing assertions over both the
# source and target expressions. # source and target expressions.
self._assertions: Set[z3.BoolRef] = set() self._assertions: set[z3.BoolRef] = set()
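In the same spirit as translation validation, one can ask Z3 whether a rewritten guard could ever disagree with its source; a standalone sketch (assumes the z3-solver package):

    import z3

    x = z3.Int("x")
    source = x > 2
    target = x >= 3                              # simplified form over the integers
    s = z3.Solver()
    s.add(z3.Not(z3.Implies(source, target)))    # search for a counterexample
    assert s.check() == z3.unsat                 # none exists: the rewrite is sound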
# Retrieves the corresponding Z3 variable. # Retrieves the corresponding Z3 variable.
def z3var(self, symbol: sympy.Symbol) -> z3.ExprRef: def z3var(self, symbol: sympy.Symbol) -> z3.ExprRef:
@ -503,7 +503,7 @@ try:
return self.symbols[symbol] return self.symbols[symbol]
# Create a variable in Z3 of 'type' for 'symbol', if it doesn't already exist. # Create a variable in Z3 of 'type' for 'symbol', if it doesn't already exist.
def add_var(self, symbol: sympy.Symbol, type: Type) -> z3.ExprRef: def add_var(self, symbol: sympy.Symbol, type: type) -> z3.ExprRef:
if symbol in self.symbols: if symbol in self.symbols:
return self.symbols[symbol] return self.symbols[symbol]
@ -769,7 +769,7 @@ def bisect(shape_env):
# Checks whether the given shape_env fails when produce_guards is called. # Checks whether the given shape_env fails when produce_guards is called.
def check_shapeenv_fails( def check_shapeenv_fails(
shape_env: ShapeEnv, tracked_fakes: Optional[List[Any]] shape_env: ShapeEnv, tracked_fakes: Optional[list[Any]]
) -> Optional[ValidationException]: ) -> Optional[ValidationException]:
assert tracked_fakes is not None assert tracked_fakes is not None
try: try:

View File

@ -11,23 +11,10 @@ import os
import re import re
import warnings import warnings
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import ( from typing import Any, Callable, Literal, NamedTuple, Optional, TYPE_CHECKING
Any,
Callable,
Dict,
FrozenSet,
Iterable,
List,
Literal,
NamedTuple,
Optional,
Set,
Tuple,
Type,
TYPE_CHECKING,
)
import torch import torch
import torch.utils._pytree as pytree import torch.utils._pytree as pytree
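For context, the mechanical change throughout this PR rests on PEP 585: from Python 3.9 the builtin containers are themselves generic, so the typing aliases are redundant. A minimal before/after:

    # Before (pre-PEP 585):
    #   from typing import Dict, List
    #   def group(xs: List[int]) -> Dict[int, List[int]]: ...

    # After (Python 3.9+): the builtins are subscriptable directly
    def group(xs: list[int]) -> dict[int, list[int]]:
        out: dict[int, list[int]] = {}
        for x in xs:
            out.setdefault(x % 2, []).append(x)
        return out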
@ -47,11 +34,11 @@ if TYPE_CHECKING:
# Mapping of builtins to their `typing` equivalent. # Mapping of builtins to their `typing` equivalent.
_origin_type_map = { _origin_type_map = {
list: List, list: list,
dict: Dict, dict: dict,
set: Set, set: set,
frozenset: FrozenSet, frozenset: frozenset,
tuple: Tuple, tuple: tuple,
} }
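Mapping each builtin to itself is consistent with how typing introspection behaves under PEP 585, e.g.:

    import typing

    assert typing.get_origin(list[int]) is list
    assert typing.get_origin(typing.List[int]) is list   # the legacy alias agrees
    assert typing.get_args(dict[str, int]) == (str, int)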
_legal_ops = dict.fromkeys( _legal_ops = dict.fromkeys(
@ -61,7 +48,7 @@ _legal_ops = dict.fromkeys(
# Signature for functions that transform the body (`list[str]`) of the # Signature for functions that transform the body (`list[str]`) of the
# generated code # generated code
TransformCodeFunc = Callable[[List[str]], List[str]] TransformCodeFunc = Callable[[list[str]], list[str]]
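Any function of this shape qualifies as a body transformer; a hypothetical example:

    def strip_comments(body: list[str]) -> list[str]:
        # a valid TransformCodeFunc: drop trailing '#' comments from each line
        return [line.split("#", 1)[0].rstrip() + "\n" for line in body]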
class _CustomBuiltin(NamedTuple): class _CustomBuiltin(NamedTuple):
@ -78,7 +65,7 @@ class _CustomBuiltin(NamedTuple):
obj: Any obj: Any
_custom_builtins: Dict[str, _CustomBuiltin] = {} _custom_builtins: dict[str, _CustomBuiltin] = {}
def _register_custom_builtin(name: str, import_str: str, obj: Any): def _register_custom_builtin(name: str, import_str: str, obj: Any):
@ -144,10 +131,10 @@ class _Namespace:
""" """
def __init__(self): def __init__(self):
self._obj_to_name: Dict[Any, str] = {} self._obj_to_name: dict[Any, str] = {}
self._unassociated_names = set() self._unassociated_names = set()
self._used_names: Set[str] = set() self._used_names: set[str] = set()
self._base_count: Dict[str, int] = defaultdict(int) self._base_count: dict[str, int] = defaultdict(int)
self._illegal_char_regex = re.compile("[^0-9a-zA-Z_]+") self._illegal_char_regex = re.compile("[^0-9a-zA-Z_]+")
self._name_suffix_regex = re.compile(r"(.*)_(\d+)$") self._name_suffix_regex = re.compile(r"(.*)_(\d+)$")
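A toy namespacer in the same spirit (hedged; the real _Namespace also handles names that collide with builtins or start with digits):

    import re

    used: set[str] = set()
    counts: dict[str, int] = {}

    def create_name(candidate: str) -> str:
        base = re.sub(r"[^0-9a-zA-Z_]+", "_", candidate) or "_unnamed"
        name = base
        while name in used:
            counts[base] = counts.get(base, 0) + 1
            name = f"{base}_{counts[base]}"
        used.add(name)
        return name

    create_name("x"), create_name("x"), create_name("x")   # ('x', 'x_1', 'x_2')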
@ -261,10 +248,10 @@ class PythonCode:
# Python source code for the forward function definition. # Python source code for the forward function definition.
src: str src: str
# Values in global scope during execution of `src_def`. # Values in global scope during execution of `src_def`.
globals: Dict[str, Any] globals: dict[str, Any]
# Optional mapping from the forward function's line number to # Optional mapping from the forward function's line number to
# node index. # node index.
_lineno_map: Optional[Dict[int, Optional[int]]] _lineno_map: Optional[dict[int, Optional[int]]]
def _format_target(base: str, target: str) -> str: def _format_target(base: str, target: str) -> str:
@ -311,7 +298,7 @@ class _PyTreeInfo(NamedTuple):
Contains extra info stored when we're using Pytrees Contains extra info stored when we're using Pytrees
""" """
orig_args: List[str] orig_args: list[str]
in_spec: pytree.TreeSpec in_spec: pytree.TreeSpec
out_spec: Optional[pytree.TreeSpec] out_spec: Optional[pytree.TreeSpec]
@ -359,7 +346,7 @@ class CodeGen:
self._body_transformer: Optional[TransformCodeFunc] = None self._body_transformer: Optional[TransformCodeFunc] = None
self._func_name: str = "forward" self._func_name: str = "forward"
def gen_fn_def(self, free_vars: List[str], maybe_return_annotation: str) -> str: def gen_fn_def(self, free_vars: list[str], maybe_return_annotation: str) -> str:
""" """
Given the free variables and a return annotation, generates the beginning of the FX function. Given the free variables and a return annotation, generates the beginning of the FX function.
By default, `gen_fn_def(['a', 'b'], '') == 'def {self._func_name}(a, b):'` By default, `gen_fn_def(['a', 'b'], '') == 'def {self._func_name}(a, b):'`
@ -398,7 +385,7 @@ class CodeGen:
""" """
return outputs return outputs
def additional_globals(self) -> List[Tuple[str, Any]]: def additional_globals(self) -> list[tuple[str, Any]]:
""" """
If your codegen uses extra global values, add tuples of (identifier, reference to the value) here. If your codegen uses extra global values, add tuples of (identifier, reference to the value) here.
For example, return [('List', typing.List)] if you need ``List`` in the global context. For example, return [('List', typing.List)] if you need ``List`` in the global context.
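A hedged sketch of an override matching that signature:

    import typing
    from torch.fx.graph import CodeGen

    class SequenceAwareCodeGen(CodeGen):
        def additional_globals(self) -> list[tuple[str, typing.Any]]:
            # make `Sequence` resolvable in the generated code's globals
            return [("Sequence", typing.Sequence)]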
@ -416,13 +403,13 @@ class CodeGen:
include_device: bool = False, include_device: bool = False,
colored: bool = False, colored: bool = False,
) -> PythonCode: ) -> PythonCode:
free_vars: List[str] = [] free_vars: list[str] = []
body: List[str] = [] body: list[str] = []
globals_: Dict[str, Any] = {} globals_: dict[str, Any] = {}
wrapped_fns: Dict[str, None] = {} wrapped_fns: dict[str, None] = {}
# Wrap string in list to pass by reference # Wrap string in list to pass by reference
maybe_return_annotation: List[str] = [""] maybe_return_annotation: list[str] = [""]
include_stride = include_stride or ( include_stride = include_stride or (
os.environ.get("FX_GRAPH_SHOW_STRIDE", "0") == "1" os.environ.get("FX_GRAPH_SHOW_STRIDE", "0") == "1"
) )
@ -553,7 +540,7 @@ class CodeGen:
return blue(repr(arg)) return blue(repr(arg))
def _format_args( def _format_args(
args: Tuple[Argument, ...], kwargs: Dict[str, Argument] args: tuple[Argument, ...], kwargs: dict[str, Argument]
) -> str: ) -> str:
args_s = ", ".join(_get_repr(a) for a in args) args_s = ", ".join(_get_repr(a) for a in args)
kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
@ -565,8 +552,8 @@ class CodeGen:
# of a given node. This represents the *last* use of the node in the # of a given node. This represents the *last* use of the node in the
# execution order of the program, which we will use to free unused # execution order of the program, which we will use to free unused
# values # values
node_to_last_use: Dict[Node, Node] = {} node_to_last_use: dict[Node, Node] = {}
user_to_last_uses: Dict[Node, List[Node]] = {} user_to_last_uses: dict[Node, list[Node]] = {}
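A toy version of the last-use computation that follows (hedged; the real pass maps Node objects, not strings):

    ops = [("a", []), ("b", ["a"]), ("c", ["a", "b"])]   # (name, inputs), topo order
    last_use: dict[str, str] = {}
    for name, inputs in reversed(ops):
        for inp in inputs:
            last_use.setdefault(inp, name)   # first hit in reverse order == last use
    # last_use == {"a": "c", "b": "c"} -> free a and b after c runs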
def register_last_uses(n: Node, user: Node): def register_last_uses(n: Node, user: Node):
if n not in node_to_last_use: if n not in node_to_last_use:
@ -782,9 +769,9 @@ class CodeGen:
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
# remove counter and generate lineno to node index mapping # remove counter and generate lineno to node index mapping
lineno_map: Dict[int, Optional[int]] = {} lineno_map: dict[int, Optional[int]] = {}
prologue_len = prologue.count("\n") + 1 prologue_len = prologue.count("\n") + 1
new_lines: List[str] = [] new_lines: list[str] = []
cur_idx = None cur_idx = None
for line in "".join(body).split("\n"): for line in "".join(body).split("\n"):
counter = re.search(r"# COUNTER: (\d+)", line) counter = re.search(r"# COUNTER: (\d+)", line)
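A toy version of the counter-stripping pass (hedged; the real code also offsets line numbers by the prologue length):

    import re
    from typing import Optional

    body = "x = a + b  # COUNTER: 0\ny = x * 2"
    lineno_map: dict[int, Optional[int]] = {}
    new_lines: list[str] = []
    for lineno, line in enumerate(body.split("\n"), start=1):
        m = re.search(r"# COUNTER: (\d+)", line)
        lineno_map[lineno] = int(m.group(1)) if m else None
        new_lines.append(re.sub(r"\s*# COUNTER: \d+", "", line))
    # lineno_map == {1: 0, 2: None}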
@ -904,11 +891,11 @@ class _FindNodesLookupTable:
""" """
def __init__(self): def __init__(self):
self.table: Dict[Tuple[str, Optional[Target]], Dict[Node, None]] = defaultdict( self.table: dict[tuple[str, Optional[Target]], dict[Node, None]] = defaultdict(
dict dict
) )
def _key(self, node) -> Tuple[str, Optional[Target]]: def _key(self, node) -> tuple[str, Optional[Target]]:
return (node.op, node.target if node.op == "call_function" else None) return (node.op, node.target if node.op == "call_function" else None)
def __contains__(self, node) -> bool: def __contains__(self, node) -> bool:
@ -985,14 +972,14 @@ class Graph:
def __init__( def __init__(
self, self,
owning_module: Optional["GraphModule"] = None, owning_module: Optional["GraphModule"] = None,
tracer_cls: Optional[Type["Tracer"]] = None, tracer_cls: Optional[type["Tracer"]] = None,
tracer_extras: Optional[Dict[str, Any]] = None, tracer_extras: Optional[dict[str, Any]] = None,
): ):
""" """
Construct an empty Graph. Construct an empty Graph.
""" """
self._root: Node = Node(self, "", "root", "", (), {}) self._root: Node = Node(self, "", "root", "", (), {})
self._used_names: Dict[str, int] = {} # base name -> number self._used_names: dict[str, int] = {} # base name -> number
self._insert = self._root.prepend self._insert = self._root.prepend
self._len = 0 self._len = 0
self._graph_namespace = _Namespace() self._graph_namespace = _Namespace()
@ -1000,7 +987,7 @@ class Graph:
self._tracer_cls = tracer_cls self._tracer_cls = tracer_cls
self._tracer_extras = tracer_extras self._tracer_extras = tracer_extras
self._codegen = CodeGen() self._codegen = CodeGen()
self._co_fields: Dict[str, Any] = {} self._co_fields: dict[str, Any] = {}
self._find_nodes_lookup_table = _FindNodesLookupTable() self._find_nodes_lookup_table = _FindNodesLookupTable()
@property @property
@ -1060,7 +1047,7 @@ class Graph:
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
def graph_copy( def graph_copy(
self, g: "Graph", val_map: Dict[Node, Node], return_output_node=False self, g: "Graph", val_map: dict[Node, Node], return_output_node=False
) -> "Optional[Argument]": ) -> "Optional[Argument]":
""" """
Copy all nodes from a given graph into ``self``. Copy all nodes from a given graph into ``self``.
@ -1113,8 +1100,8 @@ class Graph:
self, self,
op: str, op: str,
target: "Target", target: "Target",
args: Optional[Tuple["Argument", ...]] = None, args: Optional[tuple["Argument", ...]] = None,
kwargs: Optional[Dict[str, "Argument"]] = None, kwargs: Optional[dict[str, "Argument"]] = None,
name: Optional[str] = None, name: Optional[str] = None,
type_expr: Optional[Any] = None, type_expr: Optional[Any] = None,
) -> Node: ) -> Node:
@ -1373,8 +1360,8 @@ class Graph:
def call_module( def call_module(
self, self,
module_name: str, module_name: str,
args: Optional[Tuple["Argument", ...]] = None, args: Optional[tuple["Argument", ...]] = None,
kwargs: Optional[Dict[str, "Argument"]] = None, kwargs: Optional[dict[str, "Argument"]] = None,
type_expr: Optional[Any] = None, type_expr: Optional[Any] = None,
) -> Node: ) -> Node:
""" """
@ -1423,8 +1410,8 @@ class Graph:
def call_method( def call_method(
self, self,
method_name: str, method_name: str,
args: Optional[Tuple["Argument", ...]] = None, args: Optional[tuple["Argument", ...]] = None,
kwargs: Optional[Dict[str, "Argument"]] = None, kwargs: Optional[dict[str, "Argument"]] = None,
type_expr: Optional[Any] = None, type_expr: Optional[Any] = None,
) -> Node: ) -> Node:
""" """
@ -1462,8 +1449,8 @@ class Graph:
def call_function( def call_function(
self, self,
the_function: Callable[..., Any], the_function: Callable[..., Any],
args: Optional[Tuple["Argument", ...]] = None, args: Optional[tuple["Argument", ...]] = None,
kwargs: Optional[Dict[str, "Argument"]] = None, kwargs: Optional[dict[str, "Argument"]] = None,
type_expr: Optional[Any] = None, type_expr: Optional[Any] = None,
) -> Node: ) -> Node:
""" """
@ -1668,10 +1655,10 @@ class Graph:
Return a human-readable (not machine-readable) string representation Return a human-readable (not machine-readable) string representation
of this Graph of this Graph
""" """
placeholder_names: List[str] = [] placeholder_names: list[str] = []
# This is a one-element array just so ``format_node`` can modify the closed # This is a one-element array just so ``format_node`` can modify the closed
# over value # over value
maybe_return_typename: List[str] = [""] maybe_return_typename: list[str] = [""]
node_strs = [node.format_node(placeholder_names) for node in self.nodes] node_strs = [node.format_node(placeholder_names) for node in self.nodes]
param_str = ", ".join(placeholder_names) param_str = ", ".join(placeholder_names)
@ -1729,8 +1716,8 @@ class Graph:
f"defined! Please check that Nodes in the graph are topologically ordered\n{self}" f"defined! Please check that Nodes in the graph are topologically ordered\n{self}"
) )
seen_names: Set[str] = set() seen_names: set[str] = set()
seen_values: Set[Node] = set() seen_values: set[Node] = set()
for node in self.nodes: for node in self.nodes:
if node.op not in [ if node.op not in [
"placeholder", "placeholder",

View File

@ -8,7 +8,7 @@ import sys
import traceback import traceback
import warnings import warnings
from pathlib import Path from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
+from typing import Any, Callable, Optional, Union
 import torch
 import torch.nn as nn
@@ -39,7 +39,7 @@ class _EvalCacheLoader:
         self.eval_cache = {}
         self.next_id = 0
-    def cache(self, src: str, globals: Dict[str, Any], co_fields=None):
+    def cache(self, src: str, globals: dict[str, Any], co_fields=None):
         """Store the source in a private cache, and add a lazy entry in linecache
         that allows the source to be retrieved by 'filename'.
@@ -83,19 +83,19 @@ class _EvalCacheLoader:
 _loader = _EvalCacheLoader()
-def _exec_with_source(src: str, globals: Dict[str, Any], co_fields=None):
+def _exec_with_source(src: str, globals: dict[str, Any], co_fields=None):
     key = _loader.cache(src, globals, co_fields)
     exec(compile(src, key, "exec"), globals)
-def _forward_from_src(src: str, globals: Dict[str, Any], co_fields=None):
+def _forward_from_src(src: str, globals: dict[str, Any], co_fields=None):
     return _method_from_src(
         method_name="forward", src=src, globals=globals, co_fields=co_fields
     )
 def _method_from_src(
-    method_name: str, src: str, globals: Dict[str, Any], co_fields=None
+    method_name: str, src: str, globals: dict[str, Any], co_fields=None
 ) -> Callable:
     # avoid mutating the passed in dict
     globals_copy = globals.copy()
@@ -114,8 +114,8 @@ def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
     return f"from {module_name} import {attr_name} as {name}"
-def _format_import_block(globals: Dict[str, Any], importer: Importer):
-    import_strs: Set[str] = {
+def _format_import_block(globals: dict[str, Any], importer: Importer):
+    import_strs: set[str] = {
         _format_import_statement(name, obj, importer) for name, obj in globals.items()
     }
     # Sort the imports so we have a stable import block that allows us to
@@ -124,7 +124,7 @@ def _format_import_block(globals: Dict[str, Any], importer: Importer):
 @compatibility(is_backward_compatible=True)
-def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
+def reduce_graph_module(body: dict[Any, Any], import_block: str) -> torch.nn.Module:
     # BC: attribute name was changed from `code` to `_code` to facilitate
     # making `code` into a property and adding a docstring to it
     fn_src = body.get("_code") or body["code"]
@@ -134,7 +134,7 @@ def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Mod
 @compatibility(is_backward_compatible=True)
 def reduce_package_graph_module(
-    importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
+    importer: PackageImporter, body: dict[Any, Any], generated_module_name: str
 ) -> torch.nn.Module:
     forward = importer.import_module(generated_module_name).forward
     return _deserialize_graph_module(forward, body)
@@ -142,7 +142,7 @@ def reduce_package_graph_module(
 @compatibility(is_backward_compatible=True)
 def reduce_deploy_graph_module(
-    importer: PackageImporter, body: Dict[Any, Any], import_block: str
+    importer: PackageImporter, body: dict[Any, Any], import_block: str
 ) -> torch.nn.Module:
     ns = {}
     ns["__builtins__"] = importer.patched_builtins
@@ -162,7 +162,7 @@ class _CodeOnlyModule(torch.nn.Module):
 def _deserialize_graph_module(
-    forward, body: Dict[Any, Any], graph_module_cls=None
+    forward, body: dict[Any, Any], graph_module_cls=None
 ) -> torch.nn.Module:
     """
     Deserialize a GraphModule given the dictionary of the original module,
@@ -271,7 +271,7 @@ def _get_attr(model: torch.nn.Module, attr_name: str):
     return _get_attr_via_attr_list(model, attr_name.split("."))
-def _get_attr_via_attr_list(model: torch.nn.Module, attr_list: List[str]):
+def _get_attr_via_attr_list(model: torch.nn.Module, attr_list: list[str]):
     if len(attr_list) == 0:
         return model
     *prefix, field = attr_list
@@ -415,7 +415,7 @@ class GraphModule(torch.nn.Module):
     code.
     """
-    def __new__(cls: "Type[GraphModule]", *args, **kwargs):
+    def __new__(cls: "type[GraphModule]", *args, **kwargs):
         # each instance of a graph module needs its own forward method
         # so create a new singleton class for each instance.
         # it is a subclass of the user-defined class, the only difference
@@ -437,7 +437,7 @@ class GraphModule(torch.nn.Module):
     @compatibility(is_backward_compatible=True)
     def __init__(
         self,
-        root: Union[torch.nn.Module, Dict[str, Any]],
+        root: Union[torch.nn.Module, dict[str, Any]],
        graph: Graph,
        class_name: str = "GraphModule",
     ):
@@ -527,12 +527,12 @@ class GraphModule(torch.nn.Module):
            self._tracer_extras = self.graph._tracer_extras
        # Dictionary to store metadata
-        self.meta: Dict[str, Any] = {}
-        self._replace_hooks: List[Callable] = []
-        self._create_node_hooks: List[Callable] = []
-        self._erase_node_hooks: List[Callable] = []
+        self.meta: dict[str, Any] = {}
+        self._replace_hooks: list[Callable] = []
+        self._create_node_hooks: list[Callable] = []
+        self._erase_node_hooks: list[Callable] = []
         # Used to remove hooks from deepcopied graph modules within a context manager.
-        self._deepcopy_hooks: List[Callable] = []
+        self._deepcopy_hooks: list[Callable] = []
     # TorchScript breaks trying to compile the graph setter because of the
     # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
@@ -739,7 +739,7 @@ class {module_name}(torch.nn.Module):
         This method can be called to clean up an ``nn.Module`` without
         manually calling ``delete_submodule`` on each unused submodule.
         """
-        used: List[str] = []
+        used: list[str] = []
        for node in self.graph.nodes:
            if node.op == "call_module" or node.op == "get_attr":
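
The pattern above repeats throughout the patch; as a minimal standalone sketch (a hypothetical function, not code from this repository), PEP 585 lets the builtin containers be subscripted directly on Python 3.9+, so the `typing.Dict`/`typing.List`/`typing.Set` imports can simply be dropped:

    from typing import Any, Optional

    def cache_source(src: str, globals: dict[str, Any]) -> Optional[str]:
        # Builtin generics are real runtime objects on 3.9+, so these
        # annotations need no imports from typing.
        keys: list[str] = sorted(globals)
        return keys[0] if keys else None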

View File

@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
-from typing import Any, Dict, Iterable, List, Tuple
+from collections.abc import Iterable
+from typing import Any
 from torch.utils._pytree import (
     _dict_flatten,
@@ -79,25 +80,25 @@ compatibility(is_backward_compatible=True)(immutable_dict)
 # Register immutable collections for PyTree operations
-def _immutable_dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
+def _immutable_dict_flatten(d: dict[Any, Any]) -> tuple[list[Any], Context]:
     return _dict_flatten(d)
 def _immutable_dict_unflatten(
     values: Iterable[Any],
     context: Context,
-) -> Dict[Any, Any]:
+) -> dict[Any, Any]:
     return immutable_dict(_dict_unflatten(values, context))
-def _immutable_list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
+def _immutable_list_flatten(d: list[Any]) -> tuple[list[Any], Context]:
     return _list_flatten(d)
 def _immutable_list_unflatten(
     values: Iterable[Any],
     context: Context,
-) -> List[Any]:
+) -> list[Any]:
     return immutable_list(_list_unflatten(values, context))
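
For context on what these helpers preserve, here is a hedged sketch of the flatten/unflatten contract they follow (hypothetical standalone functions, not the pytree-registered ones): flatten splits a container into leaf values plus reconstruction context, and unflatten inverts it.

    def dict_flatten(d: dict[str, int]) -> tuple[list[int], list[str]]:
        # Leaves in key order, plus the keys as reconstruction context.
        keys = list(d)
        return [d[k] for k in keys], keys

    def dict_unflatten(values: list[int], context: list[str]) -> dict[str, int]:
        return dict(zip(context, values))

    leaves, ctx = dict_flatten({"a": 1, "b": 2})
    assert dict_unflatten(leaves, ctx) == {"a": 1, "b": 2}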

View File

@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import inspect
 from contextlib import contextmanager
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Optional, TYPE_CHECKING, Union
 import torch
 import torch.fx.traceback as fx_traceback
@@ -17,6 +17,10 @@ from .node import Argument, map_aggregate, map_arg, Node, Target
 from .proxy import Proxy
+if TYPE_CHECKING:
+    from collections.abc import Iterator
 __all__ = ["Interpreter", "Transformer"]
@@ -92,7 +96,7 @@ class Interpreter:
            self.graph = graph
        else:
            self.graph = self.module.graph  # type: ignore[assignment]
-        self.env: Dict[Node, Any] = {}
+        self.env: dict[Node, Any] = {}
        self.name = "Interpreter"
        self.garbage_collect_values = garbage_collect_values
        self.extra_traceback = True
@@ -102,8 +106,8 @@ class Interpreter:
        # of a given node. This represents the *last* use of the node in the
        # execution order of the program, which we will use to free unused
        # values
-        node_to_last_use: Dict[Node, Node] = {}
-        self.user_to_last_uses: Dict[Node, List[Node]] = {}
+        node_to_last_use: dict[Node, Node] = {}
+        self.user_to_last_uses: dict[Node, list[Node]] = {}
        def register_last_uses(n: Node, user: Node):
            if n not in node_to_last_use:
@@ -118,7 +122,7 @@ class Interpreter:
     def run(
         self,
         *args,
-        initial_env: Optional[Dict[Node, Any]] = None,
+        initial_env: Optional[dict[Node, Any]] = None,
         enable_io_processing: bool = True,
     ) -> Any:
         """
@@ -232,7 +236,7 @@ class Interpreter:
     # Main Node running APIs
     @compatibility(is_backward_compatible=True)
     def placeholder(
-        self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+        self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
     ) -> Any:
         """
         Execute a ``placeholder`` node. Note that this is stateful:
@@ -268,7 +272,7 @@ class Interpreter:
     @compatibility(is_backward_compatible=True)
     def get_attr(
-        self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+        self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
     ) -> Any:
         """
         Execute a ``get_attr`` node. Will retrieve an attribute
@@ -289,7 +293,7 @@ class Interpreter:
     @compatibility(is_backward_compatible=True)
     def call_function(
-        self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+        self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
     ) -> Any:
         """
         Execute a ``call_function`` node and return the result.
@@ -311,7 +315,7 @@ class Interpreter:
     @compatibility(is_backward_compatible=True)
     def call_method(
-        self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+        self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
     ) -> Any:
         """
         Execute a ``call_method`` node and return the result.
@@ -335,7 +339,7 @@ class Interpreter:
     @compatibility(is_backward_compatible=True)
     def call_module(
-        self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+        self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
     ) -> Any:
         """
         Execute a ``call_module`` node and return the result.
@@ -360,7 +364,7 @@ class Interpreter:
     @compatibility(is_backward_compatible=True)
     def output(
-        self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+        self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
     ) -> Any:
         """
         Execute an ``output`` node. This really just retrieves
@@ -401,7 +405,7 @@ class Interpreter:
         return attr_itr
     @compatibility(is_backward_compatible=True)
-    def fetch_args_kwargs_from_env(self, n: Node) -> Tuple[Tuple, Dict]:
+    def fetch_args_kwargs_from_env(self, n: Node) -> tuple[tuple, dict]:
         """
         Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
         from the current execution environment.
@@ -497,7 +501,7 @@ class Transformer(Interpreter):
         def __init__(self, graph: Graph):
             super().__init__()
             self.graph = graph
-            self.tensor_attrs: Dict[torch.Tensor, str] = {}  # type: ignore[assignment]
+            self.tensor_attrs: dict[torch.Tensor, str] = {}  # type: ignore[assignment]
         def is_leaf_module(self, _, __) -> bool:
             return True
@@ -507,7 +511,7 @@ class Transformer(Interpreter):
     @compatibility(is_backward_compatible=True)
     def placeholder(
-        self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+        self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
     ) -> Proxy:
         """
         Execute a ``placeholder`` node. In ``Transformer``, this is
@@ -529,7 +533,7 @@ class Transformer(Interpreter):
     @compatibility(is_backward_compatible=True)
     def get_attr(
-        self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+        self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
     ) -> Proxy:
         """
         Execute a ``get_attr`` node. In ``Transformer``, this is
@@ -548,7 +552,7 @@ class Transformer(Interpreter):
     @compatibility(is_backward_compatible=True)
     def call_module(
-        self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+        self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
     ) -> Any:
         # Override so that the leaf module policy from `self.tracer` is respected.
         assert isinstance(target, str)
@@ -557,7 +561,7 @@ class Transformer(Interpreter):
     @compatibility(is_backward_compatible=True)
     def call_function(
-        self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+        self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
     ) -> Any:
         # Override so that functions that were wrapped are still wrapped.
         return self.tracer.create_proxy("call_function", target, args, kwargs)
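
The `if TYPE_CHECKING:` block added above is the standard way to keep `Iterator` visible to the type checker without importing it at runtime; a minimal standalone sketch of the same idiom:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from collections.abc import Iterator

    def node_names() -> "Iterator[str]":
        # The string annotation keeps this valid at runtime even though
        # Iterator is only imported for type checking.
        yield "placeholder"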

View File

@@ -3,19 +3,8 @@ import builtins
 import inspect
 import types
 import warnings
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    Mapping,
-    Optional,
-    Sequence,
-    Set,
-    Tuple,
-    TYPE_CHECKING,
-    Union,
-)
+from collections.abc import Mapping, Sequence
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 import torch
 from torch._C import _NodeBase
@@ -57,7 +46,7 @@ Target = Union[Callable[..., Any], str]
 Argument = Optional[
     Union[
-        Tuple["Argument", ...],
+        tuple["Argument", ...],
         Sequence["Argument"],
         Mapping[str, "Argument"],
         slice,  # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
@@ -79,7 +68,7 @@ _legal_ops = dict.fromkeys(
     ]
 )
-_side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = {
+_side_effectful_need_to_be_preserved_pre_dispatch: set[Callable] = {
     torch._C._set_grad_enabled,
     torch.amp._enter_autocast,
     torch.amp._exit_autocast,
@@ -87,7 +76,7 @@ _side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = {
 # TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs,
 # or add logic to correctly mark all inplace ops as side effectful.
-_side_effectful_functions: Set[Callable] = {
+_side_effectful_functions: set[Callable] = {
     torch._assert,
     torch._assert_async,
     _ops.aten._assert_async.msg,
@@ -227,18 +216,18 @@ class Node(_NodeBase):
     in the Graph printout.
     """
-    _args: Tuple["Argument", ...]
-    _kwargs: Dict[str, "Argument"]
+    _args: tuple["Argument", ...]
+    _kwargs: dict[str, "Argument"]
     graph: "Graph"
     name: str
     op: str
     target: "Target"
-    _input_nodes: Dict["Node", None]
-    users: Dict["Node", None]
+    _input_nodes: dict["Node", None]
+    users: dict["Node", None]
     type: Optional[Any]
     _sort_key: Any
     _repr_fn: Optional[Callable[["Node"], str]]
-    meta: Dict[str, Any]
+    meta: dict[str, Any]
     @compatibility(is_backward_compatible=True)
     def __init__(
@@ -247,8 +236,8 @@ class Node(_NodeBase):
         name: str,
         op: str,
         target: "Target",
-        args: Tuple["Argument", ...],
-        kwargs: Dict[str, "Argument"],
+        args: tuple["Argument", ...],
+        kwargs: dict[str, "Argument"],
         return_type: Optional[Any] = None,
     ) -> None:
         """
@@ -339,14 +328,14 @@ class Node(_NodeBase):
         # transformations. This metadata is preserved across node copies
         assign(self, "meta", {})
-    def __getstate__(self) -> Dict[str, Any]:
+    def __getstate__(self) -> dict[str, Any]:
         state = self.__dict__.copy()
         state["_erased"] = self._erased
         state["_prev"] = self._prev
         state["_next"] = self._next
         return state
-    def __setstate__(self, state: Dict[str, Any]) -> None:
+    def __setstate__(self, state: dict[str, Any]) -> None:
         _erased = state.pop("_erased")
         _prev = state.pop("_prev")
         _next = state.pop("_next")
@@ -442,7 +431,7 @@ class Node(_NodeBase):
         p._next, n._prev = n, p
     @property
-    def args(self) -> Tuple[Argument, ...]:
+    def args(self) -> tuple[Argument, ...]:
         """
         The tuple of arguments to this ``Node``. The interpretation of arguments
         depends on the node's opcode. See the :class:`Node` docstring for more
@@ -454,7 +443,7 @@ class Node(_NodeBase):
         return self._args
     @args.setter
-    def args(self, a: Tuple[Argument, ...]) -> None:
+    def args(self, a: tuple[Argument, ...]) -> None:
         """
         Set the tuple of arguments to this Node. The interpretation of arguments
         depends on the node's opcode. See the ``fx.Graph`` docstring for more
@@ -465,7 +454,7 @@ class Node(_NodeBase):
         self.__update_args_kwargs(a, self._kwargs)
     @property
-    def kwargs(self) -> Dict[str, Argument]:
+    def kwargs(self) -> dict[str, Argument]:
         """
         The dict of keyword arguments to this ``Node``. The interpretation of arguments
         depends on the node's opcode. See the :class:`Node` docstring for more
@@ -477,7 +466,7 @@ class Node(_NodeBase):
         return self._kwargs
     @kwargs.setter
-    def kwargs(self, k: Dict[str, Argument]) -> None:
+    def kwargs(self, k: dict[str, Argument]) -> None:
         """
         Set the dict of kwargs to this Node. The interpretation of arguments
         depends on the node's opcode. See the ``fx.Graph`` docstring for more
@@ -488,7 +477,7 @@ class Node(_NodeBase):
         self.__update_args_kwargs(self._args, k)
     @property
-    def all_input_nodes(self) -> List["Node"]:
+    def all_input_nodes(self) -> list["Node"]:
         """
         Return all Nodes that are inputs to this Node. This is equivalent to
         iterating over ``args`` and ``kwargs`` and only collecting the values that
@@ -534,7 +523,7 @@ class Node(_NodeBase):
         self._args = args_left + (arg,) + args_right
-        _new_input_nodes: Dict[Node, None] = {}
+        _new_input_nodes: dict[Node, None] = {}
         map_arg(arg, _new_input_nodes.setdefault)
         for new_use in _new_input_nodes.keys():
@@ -574,7 +563,7 @@ class Node(_NodeBase):
         self.meta["stack_trace"] = trace
     def __update_args_kwargs(
-        self, new_args: Tuple["Argument", ...], new_kwargs: Dict[str, "Argument"]
+        self, new_args: tuple["Argument", ...], new_kwargs: dict[str, "Argument"]
     ) -> None:
         """
         This API is internal. Do *not* call it directly.
@@ -634,8 +623,8 @@ class Node(_NodeBase):
     @compatibility(is_backward_compatible=True)
     def format_node(
         self,
-        placeholder_names: Optional[List[str]] = None,
-        maybe_return_typename: Optional[List[str]] = None,
+        placeholder_names: Optional[list[str]] = None,
+        maybe_return_typename: Optional[list[str]] = None,
     ) -> Optional[str]:
         """
         Return a descriptive string representation of ``self``.
@@ -704,7 +693,7 @@ class Node(_NodeBase):
         delete_user_cb: Callable[["Node"], bool] = lambda user: True,
         *,
         propagate_meta: bool = False,
-    ) -> List["Node"]:
+    ) -> list["Node"]:
         """
         Replace all uses of ``self`` in the Graph with the Node ``replace_with``.
@@ -775,7 +764,7 @@ class Node(_NodeBase):
            # impure since it mutates inputs
            return True
-        tags: Optional[List[torch.Tag]] = getattr(self.target, "_tags", None)
+        tags: Optional[list[torch.Tag]] = getattr(self.target, "_tags", None)
        if tags is not None and torch.Tag.nondeterministic_seeded in tags:
            # impure since it mutates RNG state
            return True
@@ -799,8 +788,8 @@ class Node(_NodeBase):
     def normalized_arguments(
         self,
         root: torch.nn.Module,
-        arg_types: Optional[Tuple[Any]] = None,
-        kwarg_types: Optional[Dict[str, Any]] = None,
+        arg_types: Optional[tuple[Any]] = None,
+        kwarg_types: Optional[dict[str, Any]] = None,
         normalize_to_only_use_kwargs: bool = False,
     ) -> Optional[ArgsKwargsPair]:
         """

View File

@@ -5,17 +5,7 @@ import numbers
 import types
 import typing
 import warnings
-from typing import (
-    Any,
-    Callable,
-    cast,
-    Dict,
-    List,
-    NamedTuple,
-    Optional,
-    Tuple,
-    TYPE_CHECKING,
-)
+from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING
 import torch
 from torch._jit_internal import boolean_dispatched
@@ -44,11 +34,11 @@ class ArgsKwargsPair(NamedTuple):
     Simple named tuple for wrapping args/kwargs pairs.
     """
-    args: Tuple[Any, ...]
-    kwargs: Dict[str, Any]
+    args: tuple[Any, ...]
+    kwargs: dict[str, Any]
-_manual_overrides: Dict[Callable, List[inspect.Signature]] = {}
+_manual_overrides: dict[Callable, list[inspect.Signature]] = {}
 def _nonzero_schemas():
@@ -108,7 +98,7 @@ def _torchscript_schema_to_signature_impl(
 ) -> inspect.Signature:
     from inspect import Parameter
-    parameters: List[Parameter] = []
+    parameters: list[Parameter] = []
     for arg in ts_schema.arguments:
         arg_type = _torchscript_type_to_python_type(arg.type)
         default = arg.default_value if arg.has_default_value() else Parameter.empty
@@ -154,7 +144,7 @@ def _torchscript_schema_to_signature_impl(
     return inspect.Signature(parameters, return_annotation=return_type)
-_SCHEMA_TO_SIGNATURE_CACHE: Dict[Tuple[str, str], inspect.Signature] = {}
+_SCHEMA_TO_SIGNATURE_CACHE: dict[tuple[str, str], inspect.Signature] = {}
 def _torchscript_schema_to_signature(
@@ -173,7 +163,7 @@ def _torchscript_schema_to_signature(
 @compatibility(is_backward_compatible=False)
 def check_for_mutable_operation(
-    target: Callable, args: Tuple["Argument", ...], kwargs: Dict[str, "Argument"]
+    target: Callable, args: tuple["Argument", ...], kwargs: dict[str, "Argument"]
 ):
     signatures, schemas = get_signature_for_torch_op(target, return_schemas=True)
@@ -265,12 +255,12 @@ def create_type_hint(x):
         if isinstance(x, list):
             def ret_type(x):
-                return List[x]  # type: ignore[valid-type]
+                return list[x]  # type: ignore[valid-type]
         else:
             def ret_type(x):
-                return Tuple[x, ...]
+                return tuple[x, ...]  # type: ignore[valid-type]
         if len(x) == 0:
             return ret_type(Any)
@@ -291,6 +281,10 @@ def create_type_hint(x):
     return x
+_LIST_TYPES = (list, typing.List)  # noqa: UP006
+_TUPLE_TYPES = (tuple, typing.Tuple)  # noqa: UP006
 @compatibility(is_backward_compatible=False)
 def type_matches(signature_type: Any, argument_type: Any):
     sig_origin_type = getattr(signature_type, "__origin__", signature_type)
@@ -304,22 +298,24 @@ def type_matches(signature_type: Any, argument_type: Any):
         sig_contained = signature_type.__args__
         return any(type_matches(c, argument_type) for c in sig_contained)
-    if signature_type is List[int] and argument_type is int:
+    if signature_type is typing.List[int] and argument_type is int:  # noqa: UP006
         # int can be promoted to List[int]
         return True
-    if getattr(signature_type, "__origin__", None) in {list, List}:
+    if getattr(signature_type, "__origin__", None) in _LIST_TYPES:
         sig_el_type = signature_type.__args__[0]
+        if sig_el_type is argument_type:
+            return True
         if not inspect.isclass(sig_el_type):
             warnings.warn(
                 f"Does not support nested parametric types, got {signature_type}. Please file a bug."
             )
             return False
-        if getattr(argument_type, "__origin__", None) in {list, List}:
+        if getattr(argument_type, "__origin__", None) in _LIST_TYPES:
             return issubclass(argument_type.__args__[0], sig_el_type)
     def is_homogeneous_tuple(t):
-        if getattr(t, "__origin__", None) not in {tuple, Tuple}:
+        if getattr(t, "__origin__", None) not in _TUPLE_TYPES:
             return False
         contained = t.__args__
         if t.__args__ == ((),):  # Tuple[()].__args__ == ((),) for some reason
@@ -344,10 +340,10 @@ def type_matches(signature_type: Any, argument_type: Any):
 @compatibility(is_backward_compatible=False)
 def normalize_function(
     target: Callable,
-    args: Tuple[Any],
-    kwargs: Optional[Dict[str, Any]] = None,
-    arg_types: Optional[Tuple[Any]] = None,
-    kwarg_types: Optional[Dict[str, Any]] = None,
+    args: tuple[Any],
+    kwargs: Optional[dict[str, Any]] = None,
+    arg_types: Optional[tuple[Any]] = None,
+    kwarg_types: Optional[dict[str, Any]] = None,
     normalize_to_only_use_kwargs: bool = False,
 ) -> Optional[ArgsKwargsPair]:
     """
@@ -424,7 +420,7 @@ def normalize_function(
         )
     else:
         if arg_types is not None or kwarg_types is not None:
-            arg_types = arg_types if arg_types else cast(Tuple[Any], ())
+            arg_types = arg_types if arg_types else cast(tuple[Any], ())
            kwarg_types = kwarg_types if kwarg_types else {}
            for candidate_signature in torch_op_schemas:
                sig_matches = True
@@ -468,8 +464,8 @@ def normalize_function(
 def normalize_module(
     root: torch.nn.Module,
     target: str,
-    args: Tuple[Any],
-    kwargs: Optional[Dict[str, Any]] = None,
+    args: tuple[Any],
+    kwargs: Optional[dict[str, Any]] = None,
     normalize_to_only_use_kwargs: bool = False,
 ) -> Optional[ArgsKwargsPair]:
     """
@@ -513,8 +509,8 @@ def normalize_module(
 def _args_kwargs_to_normalized_args_kwargs(
     sig: inspect.Signature,
-    args: Tuple[Any, ...],
-    kwargs: Dict[str, Any],
+    args: tuple[Any, ...],
+    kwargs: dict[str, Any],
     normalize_to_only_use_kwargs: bool,
 ) -> Optional[ArgsKwargsPair]:
     """
@@ -552,8 +548,8 @@ def _args_kwargs_to_normalized_args_kwargs(
     bound_args = sig.bind(*args, **kwargs)
     bound_args.apply_defaults()
-    new_kwargs: Dict[str, Any] = {}
-    new_args: List[Any] = []
+    new_kwargs: dict[str, Any] = {}
+    new_args: list[Any] = []
     for i, param in enumerate(sig.parameters):
         if not normalize_to_only_use_kwargs and i < len(args):
             new_args.append(bound_args.arguments[param])
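
The new `_LIST_TYPES`/`_TUPLE_TYPES` tuples guard against the two spellings a caller might pass; on current Python both subscripted forms normalize their `__origin__` to the builtin, while the bare aliases remain distinct objects. A quick check one can run (assuming Python 3.9+):

    import typing

    assert list[int].__origin__ is list
    assert typing.List[int].__origin__ is list  # noqa: UP006
    # The unsubscripted alias is still a distinct object, hence the
    # membership tests against both `list` and `typing.List` above.
    assert typing.List is not list  # noqa: UP006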

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
 import logging
 import os
-from typing import Any, List, Set, Union
+from typing import Any, Union
 from sympy import Integer, Number, Symbol
 from sympy.logic.boolalg import BooleanAtom
@@ -28,7 +28,7 @@ from torch.utils._sympy.reference import TensorReferenceAnalysis
 from torch.utils._sympy.symbol import symbol_is_type, SymT
-__all__: List[str] = []
+__all__: list[str] = []
 log = logging.getLogger(__name__)
 graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
@@ -242,7 +242,7 @@ def tensorify_python_scalars(
         if node.op == "call_function" and (
             replacement_op := SUPPORTED_OPS.get(node.target)
         ):
-            args: List[Any] = []
+            args: list[Any] = []
            transform = False
            compute_dtype = get_computation_dtype(node.meta["val"].dtype)
@@ -299,7 +299,7 @@ def tensorify_python_scalars(
        "tensorify_float_success", True, overwrite=True
     )
-    failed_tensorify_ops: Set[str] = set()
+    failed_tensorify_ops: set[str] = set()
     # Now do one more pass that specializes all symfloats we didn't manage
     # to tensorify away.

View File

@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing import Any, Dict, Tuple
+from typing import Any
 import torch
 from torch.fx import Graph, GraphModule, Node
@@ -90,14 +90,14 @@ class CSEPass(PassBase):
        modified = False
        new_graph = Graph()
-        env: Dict[
+        env: dict[
            Node, Node
        ] = {}  # map from node in the old graph to node in the new graph
-        hash_env: Dict[
-            Tuple[torch._ops.OpOverload, int], Node
+        hash_env: dict[
+            tuple[torch._ops.OpOverload, int], Node
        ] = {}  # map from hash to a node in the new graph
-        token_map: Dict[
-            Tuple[torch._ops.OpOverload, int], Dict[str, Any]
+        token_map: dict[
+            tuple[torch._ops.OpOverload, int], dict[str, Any]
        ] = {}  # map from hash to token
        for n in graph_module.graph.nodes:
            # The placeholder, output, and get_attr nodes are copied to the new graph without change
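
As a toy illustration of the hash-then-reuse shape of `hash_env` above (a simplified sketch, not the actual CSEPass algorithm): repeated (op, args) keys resolve to one canonical entry.

    hash_env: dict[tuple[str, tuple[int, ...]], int] = {}

    def intern(op: str, args: tuple[int, ...]) -> int:
        # Reuse the id previously assigned to an identical expression.
        key = (op, args)
        if key not in hash_env:
            hash_env[key] = len(hash_env)
        return hash_env[key]

    assert intern("add", (1, 2)) == intern("add", (1, 2))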

View File

@@ -2,7 +2,7 @@
 import hashlib
 from itertools import chain
-from typing import Any, Dict, Optional, TYPE_CHECKING
+from typing import Any, Optional, TYPE_CHECKING
 import torch
 import torch.fx
@@ -150,10 +150,10 @@ if HAS_PYDOT:
        def get_submod_dot_graph(self, submod_name) -> pydot.Dot:
            return self._dot_graphs[f"{self._name}_{submod_name}"]
-        def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]:
+        def get_all_dot_graphs(self) -> dict[str, pydot.Dot]:
            return self._dot_graphs
-        def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]:
+        def _get_node_style(self, node: torch.fx.Node) -> dict[str, str]:
            template = {
                "shape": self.dot_graph_shape,
                "fillcolor": "#CAFFE3",

View File

@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing import Any, Dict, List, NamedTuple, Optional
+from typing import Any, NamedTuple, Optional
 import torch
 from torch.fx._compatibility import compatibility
@@ -29,7 +29,7 @@ def replace_target_nodes_with(
     """Modifies all nodes in fx_module.graph.nodes which match the specified op code and target,
     and updates them to match the new op code and target"""
     new_graph = Graph()
-    val_map: Dict[Node, Node] = {}
+    val_map: dict[Node, Node] = {}
     for node in fx_module.graph.nodes:
         if node.op == old_op and node.target == old_target:
             args = map_arg(node.args, lambda n: val_map[n])
@@ -52,7 +52,7 @@ class size_bytes(NamedTuple):
 @compatibility(is_backward_compatible=False)
 def get_size_of_all_nodes(
-    fx_module: GraphModule, args: Optional[List[torch.Tensor]] = None
+    fx_module: GraphModule, args: Optional[list[torch.Tensor]] = None
 ) -> None:
     """Given a fx graph module, update each node with its total size (weights + bias + output)
     and its output_size(output). For a non-module node, the total size is the output size.

View File

@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import os
-from typing import Callable, Dict, List, Optional, Set, TypeVar
+from typing import Callable, Optional, TypeVar
 from torch.fx import Graph, Node
 from torch.fx._compatibility import compatibility
@@ -45,11 +45,11 @@ class GraphTransformObserver:
        self.active = trace.enabled or self.log_url is not None
        if self.active:
-            self.erased_nodes: Set[str] = set()
-            self.created_nodes: Set[str] = set()
-            self.name_to_node: Dict[str, Node] = {}
+            self.erased_nodes: set[str] = set()
+            self.created_nodes: set[str] = set()
+            self.name_to_node: dict[str, Node] = {}
            # record graph modules deepcopied from self.gm, so we can remove hooks on them when exiting the context
-            self.copied_gms: List[GraphModule] = []
+            self.copied_gms: list[GraphModule] = []
            self._node_creation_hook = self.get_node_creation_hook()
            self._node_erase_hook = self.get_node_erase_hook()

View File

@@ -2,8 +2,9 @@
 import collections
 import itertools
 import logging
+from collections.abc import Iterable, Sequence
 from copy import copy
-from typing import Dict, Iterable, List, Optional, Sequence, Set
+from typing import Optional
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import _get_qualified_name, Node
@@ -52,10 +53,10 @@ class _DependencyViewer:
                self.downstreams[node].add(output_node)
                self.downstreams[node].update(self.downstreams[output_node])
-    def downstreams_of(self, node: Node) -> Set[Node]:
+    def downstreams_of(self, node: Node) -> set[Node]:
         return self.downstreams[node]
-    def upstreams_of(self, node: Node) -> Set[Node]:
+    def upstreams_of(self, node: Node) -> set[Node]:
         return self.upstreams[node]
@@ -84,21 +85,21 @@ class CapabilityBasedPartitioner:
            dict(self.graph_module.named_modules()), node
        )
-    def propose_partitions(self) -> List[Partition]:
+    def propose_partitions(self) -> list[Partition]:
        # partition_map is a mapping from partition id to a set of partition id's.
        # The value set contains all the partition ids that can be reached by doing a
        # DFS starting from the partition id in the key.
-        partition_map: Dict[int, Set] = collections.defaultdict(set)
+        partition_map: dict[int, set] = collections.defaultdict(set)
        # assumption: nodes in the candidate list are sorted in topological order
-        assignment: Dict[Node, int] = {}  # mapping from node to partition_id
-        partitions_by_id: Dict[
+        assignment: dict[Node, int] = {}  # mapping from node to partition_id
+        partitions_by_id: dict[
            int, Partition
        ] = {}  # mapping from partition_id to partition
-        nodes_order: Dict[
+        nodes_order: dict[
            Node, int
        ] = {}  # mapping from nodes to reversed topological order
-        partitions_order: Dict[
+        partitions_order: dict[
            int, int
        ] = {}  # mapping from partition_id to minimum topo order of nodes in partition
        new_partition_id = itertools.count()
@@ -111,7 +112,7 @@ class CapabilityBasedPartitioner:
            merged_nodes = copy(partitions_by_id[self_id].nodes)
            merged_nodes.update(partitions_by_id[other_id].nodes)
-        def dfs_iter_find_cycle(all_user_nodes: Set[Node]):
+        def dfs_iter_find_cycle(all_user_nodes: set[Node]):
            for user_node in all_user_nodes:
                visited_partition_ids = set()
@@ -210,7 +211,7 @@ class CapabilityBasedPartitioner:
        for node in reversed(self.graph_module.graph.nodes):
            # use a dict as an ordered set to ensure a deterministic partitioning result; values are unused
-            merge_candidates: Dict[int, None] = {}
+            merge_candidates: dict[int, None] = {}
            # Note a limited horizontal fusion is enabled:
            # when `node` is not supported, the code below attempts to fuse consumers of `node`.
@@ -241,7 +242,7 @@ class CapabilityBasedPartitioner:
        # post processing to re-assign "getitem" nodes into upstream partition
        logger.debug("Reassigning getitem nodes to its producer node's partition...")
-        nodes_reassignment: Dict[Node, int] = {}
+        nodes_reassignment: dict[Node, int] = {}
        for node in self.graph_module.graph.nodes:
            is_tuple_output = True
            for user in node.users:
@@ -266,7 +267,7 @@ class CapabilityBasedPartitioner:
        logger.debug("Filtering out single node partitions...")
        default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
        non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
-        partitions_to_remove: List[int] = []
+        partitions_to_remove: list[int] = []
        for id, partition in partitions_by_id.items():
            compute_node_count = 0
            for node in partition.nodes:
@@ -295,7 +296,7 @@ class CapabilityBasedPartitioner:
        ]
     def fuse_partitions(
-        self, partitions: List[Partition], prefix: str = "fused_"
+        self, partitions: list[Partition], prefix: str = "fused_"
     ) -> GraphModule:
        logger.debug("Fusing partitions...")
        # fuse_by_partitions expects partitions in List[Dict[Node, None]]: [ {node0 : None}, {node1 : None} ]
@@ -306,7 +307,7 @@ class CapabilityBasedPartitioner:
        )
     # remove non-compute ops that sit at the boundary of a partition.
-    def remove_bookend_non_compute_ops(self, partitions: List[Partition]):
+    def remove_bookend_non_compute_ops(self, partitions: list[Partition]):
        non_compute_ops = set(self.non_compute_ops)
        def is_non_compute_node(node: Node):
@@ -316,11 +317,11 @@ class CapabilityBasedPartitioner:
            )
        # cache transparent nodes
-        transparent_input_nodes: Dict[Node, bool] = {}
-        transparent_output_nodes: Dict[Node, bool] = {}
+        transparent_input_nodes: dict[Node, bool] = {}
+        transparent_output_nodes: dict[Node, bool] = {}
        def is_transparent_input_node(
-            node: Node, partition: Set[Node], removed_nodes: Set[Node]
+            node: Node, partition: set[Node], removed_nodes: set[Node]
        ):
            if (
                node.op == "placeholder"
@@ -341,7 +342,7 @@ class CapabilityBasedPartitioner:
            return False
        def is_transparent_output_node(
-            node: Node, partition: Set[Node], removed_nodes: Set[Node]
+            node: Node, partition: set[Node], removed_nodes: set[Node]
        ):
            if (
                node.op == "placeholder"
@@ -367,7 +368,7 @@ class CapabilityBasedPartitioner:
        # Note it's ok to use `set` here, since we are only querying whether a node
        # has been removed. We are NEVER going to iterate on nodes inside
        # the set.
-        remove_node: Set[Node] = set()
+        remove_node: set[Node] = set()
        for node in partition.nodes:
            if is_non_compute_node(node) and (
                is_transparent_input_node(node, set(partition.nodes), remove_node)

View File

@@ -3,7 +3,7 @@ import inspect
 import logging
 from functools import wraps
 from queue import Queue
-from typing import Callable, Dict, List
+from typing import Callable
 import torch.nn as nn
 from torch.fx._compatibility import compatibility
@@ -50,7 +50,7 @@ def pass_result_wrapper(fn: Callable) -> Callable:
 def _validate_pass_schedule_constraint(
-    constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
+    constraint: Callable[[Callable, Callable], bool], passes: list[Callable]
 ) -> None:
     for i, a in enumerate(passes):
         for j, b in enumerate(passes[i + 1 :]):
@@ -64,8 +64,8 @@ def _validate_pass_schedule_constraint(
 def _topological_sort_passes(
-    passes: List[Callable], constraints: List[Callable]
-) -> List[Callable]:
+    passes: list[Callable], constraints: list[Callable]
+) -> list[Callable]:
     """
     Args
         passes: Passes that we are ordering
@@ -79,8 +79,8 @@ def _topological_sort_passes(
         return passes
     # Construct a graph mapping nodes to a list of their users
-    graph: Dict[Callable, List[Callable]] = {p: [] for p in passes}
-    indegree_map: Dict[Callable, int] = dict.fromkeys(passes, 0)
+    graph: dict[Callable, list[Callable]] = {p: [] for p in passes}
+    indegree_map: dict[Callable, int] = dict.fromkeys(passes, 0)
     candidates: Queue = Queue()
     for a in passes:
         for b in passes:
@@ -95,8 +95,8 @@ def _topological_sort_passes(
         if indegree_map[a] == 0:
             candidates.put(a)
-    visited: Dict[Callable, bool] = dict.fromkeys(passes, False)
-    sorted_passes: List[Callable] = []
+    visited: dict[Callable, bool] = dict.fromkeys(passes, False)
+    sorted_passes: list[Callable] = []
     while not candidates.empty():
         p = candidates.get()
@@ -169,8 +169,8 @@ class PassManager:
        checks
     """
-    passes: List[Callable[[nn.Module], PassResult]]
-    constraints: List[Callable[[Callable, Callable], bool]]
+    passes: list[Callable[[nn.Module], PassResult]]
+    constraints: list[Callable[[Callable, Callable], bool]]
     _validated: bool = False
     steps: int = 1
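
`_topological_sort_passes` above is a Kahn's-algorithm sort over the constraint graph; a self-contained sketch of the same idea with builtin-generic annotations (hypothetical names, using `deque` instead of `Queue`):

    from collections import deque

    def topo_sort(graph: dict[str, list[str]]) -> list[str]:
        # Count incoming edges, then repeatedly emit zero-indegree nodes.
        indegree: dict[str, int] = dict.fromkeys(graph, 0)
        for users in graph.values():
            for u in users:
                indegree[u] += 1
        queue = deque(n for n, d in indegree.items() if d == 0)
        ordered: list[str] = []
        while queue:
            n = queue.popleft()
            ordered.append(n)
            for u in graph[n]:
                indegree[u] -= 1
                if indegree[u] == 0:
                    queue.append(u)
        return ordered

    assert topo_sort({"a": ["b"], "b": ["c"], "c": []}) == ["a", "b", "c"]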

View File

@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import logging
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Optional
 import torch
 import torch.fx
@@ -106,7 +106,7 @@ class _MinimizerBase:
        module: torch.fx.GraphModule,
        sample_input: Tensors,
        compare_fn: Callable[
-            [TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool]
+            [TensorOrTensors, TensorOrTensors, Names], tuple[float, bool]
        ],
        settings: _MinimizerSettingBase,
        module_exporter: Optional[
@@ -124,16 +124,16 @@ class _MinimizerBase:
        self.exclusion_fn = exclusion_fn
        # Stores outputs of run_a function
-        self.a_outputs: Dict[str, Any] = {}
+        self.a_outputs: dict[str, Any] = {}
        # Stores outputs of run_b function
-        self.b_outputs: Dict[str, Any] = {}
+        self.b_outputs: dict[str, Any] = {}
        # Stores the results of compare_fn
-        self.results: Dict[Any, Any] = {}
+        self.results: dict[Any, Any] = {}
        # Stores the report for the runs
-        self.reports: List[List[str]] = []
+        self.reports: list[list[str]] = []
        # Current iteration
        self.iteration: int = 0
@@ -205,7 +205,7 @@ class _MinimizerBase:
     def _get_submod_inputs(
         self, main_module: torch.fx.GraphModule, submod_path: str
-    ) -> Tuple[Tensors, Tensors]:
+    ) -> tuple[Tensors, Tensors]:
        """
        Try to get submodule inputs from stored outputs. If not found, then use
        torch_glow.get_submod_inputs to get the inputs.
@@ -280,7 +280,7 @@ class _MinimizerBase:
            else:
                node.tag = "main_0"
-    def _build_submodule(self, nodes: NodeSet) -> Tuple[torch.fx.GraphModule, str]:
+    def _build_submodule(self, nodes: NodeSet) -> tuple[torch.fx.GraphModule, str]:
        """
        Split self.module so that one submodule consists of `nodes` and only `nodes`.
@@ -412,7 +412,7 @@ class _MinimizerBase:
        culprits: NodeSet = set()
        nodes: NodeList = all_nodes[start_idx:end_idx]
-        report: List[str] = []
+        report: list[str] = []
        if self.exclusion_fn is not None:
            self.exclusion_fn(nodes, start_idx, end_idx)
        if len(nodes) == 0:
@@ -484,7 +484,7 @@ class _MinimizerBase:
        culprits: NodeSet = set()
        for node in nodes:
-            report: List[str] = []
+            report: list[str] = []
            self.reports.append(report)
            self.iteration += 1
            report.append(f"Sequential traverse iteration {self.iteration}.")
@@ -534,7 +534,7 @@ class _MinimizerBase:
            find_last_node: If True, search for the last node which results in a numerics difference;
            if False, find the first node in the sorted node list
        """
-        report: List[str] = []
+        report: list[str] = []
        mid = (start_idx + end_idx) // 2
        cur_nodes_list: NodeList = nodes[: mid + 1] if find_last_node else nodes[mid:]
@@ -726,7 +726,7 @@ class _MinimizerBase:
            return culprits
        for node in nodes:
-            report: List[str] = []
+            report: list[str] = []
            self.reports.append(report)
            self.iteration += 1
            report.append(f"Accumulate traverse iteration {self.iteration}.")
@@ -770,7 +770,7 @@ class _MinimizerBase:
        for node in nodes:
            if node in self.fusions:
                cur_nodes.update(self.fusions[node])
-        report: List[str] = []
+        report: list[str] = []
        self.reports.append(report)
        self.iteration += 1
        report.append(f" Nodes block {self.iteration}.")
@@ -797,7 +797,7 @@ class _MinimizerBase:
            self.print_report(report)
            return set()
-    def _skip_traverse(self, all_nodes: NodeList, skip_nodes: List) -> NodeSet:
+    def _skip_traverse(self, all_nodes: NodeList, skip_nodes: list) -> NodeSet:
        """
        Skip certain nodes in graph based on settings
        """
@@ -874,7 +874,7 @@ class _MinimizerBase:
        ) as e:
            print(e)
-    def print_report(self, report: List[str]):
+    def print_report(self, report: list[str]):
        for i in range(len(report)):
            if i > 0:
                print(" . " + report[i])
@@ -889,7 +889,7 @@ class _MinimizerBase:
        self,
        start: Optional[str] = None,
        end: Optional[str] = None,
-        skip_nodes: Optional[List] = None,
+        skip_nodes: Optional[list] = None,
        find_last_node: Optional[bool] = None,
     ) -> NodeSet:
        """

View File

@@ -24,9 +24,9 @@ TargetTypeName = str
 # Arguments' dtypes for a given node, see `OperatorSupport`
 SupportedArgumentDTypes = t.Optional[
-    t.Tuple[
+    tuple[
         t.Sequence[t.Sequence[torch.dtype]],
-        t.Dict[str, t.Sequence[torch.dtype]],
+        dict[str, t.Sequence[torch.dtype]],
     ]
 ]
@@ -204,7 +204,7 @@ class OpSupports:
        return create_op_support(_decline_if_input_dtype)
     @classmethod
-    def decline_if_node_in_names(cls, disallow_set: t.Set[str]) -> OperatorSupportBase:
+    def decline_if_node_in_names(cls, disallow_set: set[str]) -> OperatorSupportBase:
        """
        If a node has a name that is in the disallow set, report it as non-supported.
        """

View File

@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Tuple, Type
+from typing import Any, Callable
 import torch
 import torch.nn as nn
@@ -23,7 +23,7 @@ def default_matching(name: str, target_version: int) -> str:
 # This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering.
 # The first integer in the tuple is the version number of the nn.Module class when we create the parameter list.
 # If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module.
-module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]] = {
+module_fetch_book: dict[type, tuple[int, list[str], Callable[[str, int], str]]] = {
     torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching),
     torch.nn.modules.conv.Conv2d: (
         1,
@@ -55,11 +55,11 @@ module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]]
 @compatibility(is_backward_compatible=False)
-def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]:
+def extract_attrs_for_lowering(mod: nn.Module) -> dict[str, Any]:
     """If `mod` is in `module_fetch_book`, fetch the mod's attributes that are in the `module_fetch_book`
     after checking that the module's version is compatible with the `module_fetch_book`.
     """
-    attrs_for_lowering: Dict[str, Any] = {}
+    attrs_for_lowering: dict[str, Any] = {}
     attrs_for_lowering["name"] = torch.typename(mod)
     if type(mod) in module_fetch_book:

View File

@ -2,7 +2,7 @@
import logging import logging
from functools import wraps from functools import wraps
from inspect import unwrap from inspect import unwrap
from typing import Callable, List, Optional from typing import Callable, Optional
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -121,7 +121,7 @@ def loop_pass(
# Implemented as 'depends on' operators. A constraint is satisfied iff a list # Implemented as 'depends on' operators. A constraint is satisfied iff a list
# has a valid partial ordering according to this comparison operator. # has a valid partial ordering according to this comparison operator.
def _validate_pass_schedule_constraint( def _validate_pass_schedule_constraint(
constraint: Callable[[Callable, Callable], bool], passes: List[Callable] constraint: Callable[[Callable, Callable], bool], passes: list[Callable]
): ):
for i, a in enumerate(passes): for i, a in enumerate(passes):
for j, b in enumerate(passes[i + 1 :]): for j, b in enumerate(passes[i + 1 :]):
@ -191,8 +191,8 @@ class PassManager:
`this_before_that_pass_constraint` for example. `this_before_that_pass_constraint` for example.
""" """
passes: List[Callable] passes: list[Callable]
constraints: List[Callable] constraints: list[Callable]
_validated: bool = False _validated: bool = False
def __init__( def __init__(
@ -217,7 +217,7 @@ class PassManager:
self.constraints.append(constraint) self.constraints.append(constraint)
self._validated = False self._validated = False
def remove_pass(self, _passes: List[str]): def remove_pass(self, _passes: list[str]):
if _passes is None: if _passes is None:
return return
passes_left = [ps for ps in self.passes if ps.__name__ not in _passes] passes_left = [ps for ps in self.passes if ps.__name__ not in _passes]
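
A minimal sketch of the constraint contract validated above (the two passes are placeholders): a constraint returns False only for a forbidden ordering, and validation checks every ordered pair of scheduled passes.

def pass_a(module):
    return module

def pass_b(module):
    return module

def this_before_that_constraint(this, that):
    # satisfied unless `that` is scheduled before `this`
    def constraint(a, b):
        return not (a is that and b is this)
    return constraint

passes = [pass_a, pass_b]
constraint = this_before_that_constraint(pass_a, pass_b)
for i, a in enumerate(passes):
    for b in passes[i + 1 :]:
        assert constraint(a, b), "pass schedule constraint violated"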


@ -3,7 +3,6 @@ import _operator
import itertools import itertools
from collections import defaultdict from collections import defaultdict
from enum import Enum from enum import Enum
from typing import Dict, Set
import torch import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@ -199,7 +198,7 @@ _VIEW_INVERSE_MAP = {
# This function, given a set of (aliased) tensor nodes, # This function, given a set of (aliased) tensor nodes,
# Returns any nodes in the graph that *use* any of the aliases, that occur *after* op_index # Returns any nodes in the graph that *use* any of the aliases, that occur *after* op_index
# in the node ordering. # in the node ordering.
def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int): def _get_all_later_node_usages(tensor_aliases: set[Node], op_index: int):
def _add_if_tensor(x, set_): def _add_if_tensor(x, set_):
if isinstance(x, FakeTensor): if isinstance(x, FakeTensor):
set_.add(StorageWeakRef(x._typed_storage())) set_.add(StorageWeakRef(x._typed_storage()))
@ -233,8 +232,8 @@ def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
# (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata # (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata
# as "alias" # as "alias"
def _get_view_inverse_node_usages( def _get_view_inverse_node_usages(
later_node_usages: Set[Node], self_aliases: Set[Node] later_node_usages: set[Node], self_aliases: set[Node]
) -> Set[Node]: ) -> set[Node]:
def matching_view_metadata(a, b): def matching_view_metadata(a, b):
return ( return (
a.size() == b.size() a.size() == b.size()
@ -515,7 +514,7 @@ def reinplace(gm, *sample_args):
} }
# We also need to know, for a given node, all of its aliasing nodes. # We also need to know, for a given node, all of its aliasing nodes.
storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set) storage_to_nodes: dict[StorageWeakRef, set[Node]] = defaultdict(set)
for n in gm.graph.nodes: for n in gm.graph.nodes:
if "fake_result" in n.meta: if "fake_result" in n.meta:
# Tree-mapping because some ops can return lists of tensors. # Tree-mapping because some ops can return lists of tensors.
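
The storage-keyed map above works because a view shares storage with its base, so `StorageWeakRef` produces one key per set of aliases. A minimal sketch, using the same import this pass relies on:

import torch
from torch.multiprocessing.reductions import StorageWeakRef

base = torch.ones(4)
view = base[:2]  # a view aliases base's storage
assert StorageWeakRef(base._typed_storage()) == StorageWeakRef(view._typed_storage())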


@ -3,7 +3,7 @@ import functools
import logging import logging
import operator import operator
import sys import sys
from typing import Any, Dict, Optional, Set, TYPE_CHECKING from typing import Any, Optional, TYPE_CHECKING
# Import sympy and ShapeEnv during TYPE_CHECKING since importing sympy is slow # Import sympy and ShapeEnv during TYPE_CHECKING since importing sympy is slow
@ -123,7 +123,7 @@ def insert_deferred_runtime_asserts(
) )
# We are going to mutate the dict # We are going to mutate the dict
expr_to_proxy: Dict[sympy.Expr, fx.Proxy] = {} expr_to_proxy: dict[sympy.Expr, fx.Proxy] = {}
placeholders = set() placeholders = set()
first_non_placeholder = None first_non_placeholder = None
for node in graph.nodes: for node in graph.nodes:
@ -163,7 +163,7 @@ def insert_deferred_runtime_asserts(
def _node_metadata_hook( def _node_metadata_hook(
node: torch.fx.Node, node: torch.fx.Node,
stack_trace: Optional[str] = None, stack_trace: Optional[str] = None,
nn_module_stack: Optional[Dict[str, Any]] = None, nn_module_stack: Optional[dict[str, Any]] = None,
) -> None: ) -> None:
fake_args = pytree.tree_map( fake_args = pytree.tree_map(
lambda arg: ( lambda arg: (
@ -189,8 +189,8 @@ def insert_deferred_runtime_asserts(
node.meta["nn_module_stack"] = nn_module_stack node.meta["nn_module_stack"] = nn_module_stack
# Track asserts/checks we've added # Track asserts/checks we've added
added_asserts: Set[sympy.Expr] = set() added_asserts: set[sympy.Expr] = set()
constrained_unbacked_symbols: Set[sympy.Symbol] = set() constrained_unbacked_symbols: set[sympy.Symbol] = set()
Analysis = PythonReferenceAnalysis if export else OptimizedPythonReferenceAnalysis Analysis = PythonReferenceAnalysis if export else OptimizedPythonReferenceAnalysis


@ -1,7 +1,7 @@
# mypy: ignore-errors # mypy: ignore-errors
import traceback import traceback
from typing import Any, Dict, NamedTuple, Optional, Tuple from typing import Any, NamedTuple, Optional
import torch import torch
import torch.fx import torch.fx
@ -24,12 +24,12 @@ class TensorMetadata(NamedTuple):
shape: torch.Size shape: torch.Size
dtype: torch.dtype dtype: torch.dtype
requires_grad: bool requires_grad: bool
stride: Tuple[int, ...] stride: tuple[int, ...]
memory_format: Optional[torch.memory_format] memory_format: Optional[torch.memory_format]
# Quantization metadata # Quantization metadata
is_quantized: bool is_quantized: bool
qparams: Dict[str, Any] qparams: dict[str, Any]
def _extract_tensor_metadata( def _extract_tensor_metadata(
@ -57,7 +57,7 @@ def _extract_tensor_metadata(
break break
is_quantized = result.is_quantized is_quantized = result.is_quantized
qparams: Dict[str, Any] = {} qparams: dict[str, Any] = {}
if is_quantized: if is_quantized:
qscheme = result.qscheme() qscheme = result.qscheme()
qparams["qscheme"] = qscheme qparams["qscheme"] = qscheme


@ -2,7 +2,7 @@
import inspect import inspect
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Set from typing import Any, Callable, Optional
import torch import torch
from torch.fx._compatibility import compatibility from torch.fx._compatibility import compatibility
@ -20,14 +20,14 @@ class Partition:
def __init__(self, name: str): def __init__(self, name: str):
self.name: str = name self.name: str = name
self.submod_name = f"submod_{name}" self.submod_name = f"submod_{name}"
self.node_names: List[str] = [] self.node_names: list[str] = []
self.inputs: Dict[str, None] = {} self.inputs: dict[str, None] = {}
self.outputs: Dict[str, None] = {} self.outputs: dict[str, None] = {}
self.dependencies: Dict[str, None] = {} self.dependencies: dict[str, None] = {}
self.dependents: Dict[str, None] = {} self.dependents: dict[str, None] = {}
self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph() self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
self.environment: Dict[Node, Node] = {} self.environment: dict[Node, Node] = {}
self.targets: Dict[str, Any] = {} self.targets: dict[str, Any] = {}
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
@ -55,7 +55,7 @@ def split_module(
m: GraphModule, m: GraphModule,
root_m: torch.nn.Module, root_m: torch.nn.Module,
split_callback: Callable[[Node], int], split_callback: Callable[[Node], int],
qualname_map: Optional[Dict[str, str]] = None, qualname_map: Optional[dict[str, str]] = None,
keep_original_order: Optional[bool] = False, keep_original_order: Optional[bool] = False,
keep_original_node_name: Optional[bool] = False, keep_original_node_name: Optional[bool] = False,
): ):
@ -161,8 +161,8 @@ def split_module(
def construct_graph( def construct_graph(
node: Node, node: Node,
base_mod_env: Dict[str, Node], base_mod_env: dict[str, Node],
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule], base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule],
): ):
if node.op == "placeholder": if node.op == "placeholder":
default_value = ( default_value = (
@ -195,9 +195,9 @@ def split_module(
import sympy import sympy
partitions: Dict[str, Partition] = {} partitions: dict[str, Partition] = {}
orig_nodes: Dict[str, Node] = {} orig_nodes: dict[str, Node] = {}
symbol_to_node: Dict[sympy.Symbol, Node] = {} symbol_to_node: dict[sympy.Symbol, Node] = {}
def record_cross_partition_use(def_node: Node, use_node: Optional[Node]): def record_cross_partition_use(def_node: Node, use_node: Optional[Node]):
from torch.fx.experimental.symbolic_shapes import free_symbols from torch.fx.experimental.symbolic_shapes import free_symbols
@ -273,7 +273,7 @@ def split_module(
# ------------------------ # ------------------------
# 1. first region: we do nothing # 1. first region: we do nothing
# 2. subsequent regions: we insert the set_grad at the beginning # 2. subsequent regions: we insert the set_grad at the beginning
grad_regions: OrderedDict[Node, Set[int]] = OrderedDict() grad_regions: OrderedDict[Node, set[int]] = OrderedDict()
# For autocast regions: # For autocast regions:
# ------------------------ # ------------------------
@ -282,8 +282,8 @@ def split_module(
# _enter at the beginning and _exit at the end # _enter at the beginning and _exit at the end
# 3. last region: we will only insert _enter at the beginning # 3. last region: we will only insert _enter at the beginning
# We will do so in the order in which the autocasts were instantiated. # We will do so in the order in which the autocasts were instantiated.
autocast_regions: OrderedDict[Node, Set[int]] = OrderedDict() autocast_regions: OrderedDict[Node, set[int]] = OrderedDict()
autocast_exits: Dict[Node, Optional[Node]] = {} autocast_exits: dict[Node, Optional[Node]] = {}
active_grad = None active_grad = None
active_autocasts = set() active_autocasts = set()
@ -379,13 +379,13 @@ def split_module(
original_partition_order = list(partitions.keys()) original_partition_order = list(partitions.keys())
# find partitions with no dependencies # find partitions with no dependencies
root_partitions: List[str] = [] root_partitions: list[str] = []
for partition_name, partition in partitions.items(): for partition_name, partition in partitions.items():
if not len(partition.dependencies): if not len(partition.dependencies):
root_partitions.append(partition_name) root_partitions.append(partition_name)
# check partitions for circular dependencies and create topological partition ordering # check partitions for circular dependencies and create topological partition ordering
sorted_partitions: List[str] = [] sorted_partitions: list[str] = []
while root_partitions: while root_partitions:
root_partition = root_partitions.pop() root_partition = root_partitions.pop()
sorted_partitions.append(root_partition) sorted_partitions.append(root_partition)
@ -418,7 +418,7 @@ def split_module(
# add placeholders to partition inputs # add placeholders to partition inputs
for partition_name in sorted_partitions: for partition_name in sorted_partitions:
partition = partitions[partition_name] partition = partitions[partition_name]
new_inputs: Dict[str, None] = {} new_inputs: dict[str, None] = {}
for inp in partition.inputs: for inp in partition.inputs:
orig_node = orig_nodes[inp] orig_node = orig_nodes[inp]
# We don't pass in get_attr nodes as inputs to the partition, but # We don't pass in get_attr nodes as inputs to the partition, but
@ -507,11 +507,11 @@ def split_module(
) # is it really a good idea to copy this? ) # is it really a good idea to copy this?
# original module environment dict mapping node names to nodes # original module environment dict mapping node names to nodes
orig_mod_env: Dict[str, Node] = {} orig_mod_env: dict[str, Node] = {}
# Set up values to construct base module # Set up values to construct base module
base_mod_env: Dict[str, Node] = {} base_mod_env: dict[str, Node] = {}
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph() base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {} base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule] = {}
if not keep_original_order: if not keep_original_order:
for node in m.graph.nodes: for node in m.graph.nodes:
base_mod_env, base_mod_attrs = construct_graph( base_mod_env, base_mod_attrs = construct_graph(
@ -559,7 +559,7 @@ def split_module(
if keep_original_order: if keep_original_order:
# first get the attr nodes required by this partition # first get the attr nodes required by this partition
orig_mod_attr_nodes: List[Node] = [ orig_mod_attr_nodes: list[Node] = [
orig_mod_env[key] orig_mod_env[key]
for key in partition.inputs for key in partition.inputs
if key not in original_order if key not in original_order
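
A usage sketch of `split_module` (toy module and partitioning rule): `split_callback` maps every op node to an integer partition id, and the result holds one `submod_<id>` per partition.

import operator

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1) * 2

def split_callback(node):
    # add/relu land in partition 0, the final multiply in partition 1
    return 0 if node.target in (operator.add, torch.relu) else 1

m = M()
split = split_module(symbolic_trace(m), m, split_callback)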


@ -1,7 +1,7 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
import copy import copy
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Type, Union from typing import Optional, Union
import torch.fx import torch.fx
from torch.fx._compatibility import compatibility from torch.fx._compatibility import compatibility
@ -45,28 +45,28 @@ class Component:
name: str name: str
# Stores the placeholder nodes in `graph`. # Stores the placeholder nodes in `graph`.
input_placeholders: List = field(default_factory=list) input_placeholders: list = field(default_factory=list)
# Store the nodes in original graph that are placeholder in `graph`. # Store the nodes in original graph that are placeholder in `graph`.
orig_inputs: List = field(default_factory=list) orig_inputs: list = field(default_factory=list)
# Store the nodes in original graph that are outputs in `graph`. # Store the nodes in original graph that are outputs in `graph`.
orig_outputs: List = field(default_factory=list) orig_outputs: list = field(default_factory=list)
# Mapping from get_attr node in original graph to get_attr node in `graph`. # Mapping from get_attr node in original graph to get_attr node in `graph`.
getattr_maps: Dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict) getattr_maps: dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict)
constructor_args: List[str] = field(default_factory=list) constructor_args: list[str] = field(default_factory=list)
gm: Optional[torch.fx.GraphModule] = None gm: Optional[torch.fx.GraphModule] = None
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
def split_by_tags( def split_by_tags(
gm: torch.fx.GraphModule, gm: torch.fx.GraphModule,
tags: List[str], tags: list[str],
return_fqn_mapping: bool = False, return_fqn_mapping: bool = False,
return_tuple: bool = False, return_tuple: bool = False,
GraphModuleCls: Type[torch.fx.GraphModule] = torch.fx.GraphModule, GraphModuleCls: type[torch.fx.GraphModule] = torch.fx.GraphModule,
) -> Union[torch.fx.GraphModule, Tuple[torch.fx.GraphModule, Dict[str, str]]]: ) -> Union[torch.fx.GraphModule, tuple[torch.fx.GraphModule, dict[str, str]]]:
""" """
Splits a GraphModule using tags on its graph nodes. We honor the order of Splits a GraphModule using tags on its graph nodes. We honor the order of
tags. For example, given tags = ["a", "b", "c"], the function will create tags. For example, given tags = ["a", "b", "c"], the function will create
@ -133,26 +133,26 @@ def split_by_tags(
return r return r
# Mapping from node in original module to node in created submodule. # Mapping from node in original module to node in created submodule.
node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {} node_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
# Mapping from node in original module or created submodules to # Mapping from node in original module or created submodules to
# corresponding component. # corresponding component.
node_to_component: Dict[torch.fx.Node, Component] = {} node_to_component: dict[torch.fx.Node, Component] = {}
# Mapping from tag to the corresponding component. # Mapping from tag to the corresponding component.
tag_to_component: Dict[str, Component] = {} tag_to_component: dict[str, Component] = {}
# Stores all components. # Stores all components.
all_components: List[Component] = [] all_components: list[Component] = []
# Stores nodes that will be used in main graph. # Stores nodes that will be used in main graph.
used_in_main: Dict[torch.fx.Node, None] = {} used_in_main: dict[torch.fx.Node, None] = {}
# Main graph after split. # Main graph after split.
main_g = torch.fx.Graph() main_g = torch.fx.Graph()
# Mapping from node in original module to node in main graph after split. # Mapping from node in original module to node in main graph after split.
main_remapping: Dict[torch.fx.Node, torch.fx.Node] = {} main_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
# Output node of original module. # Output node of original module.
output_node: Optional[torch.fx.Node] = None output_node: Optional[torch.fx.Node] = None
@ -258,7 +258,7 @@ def split_by_tags(
node_to_component[n].orig_outputs.append(n) node_to_component[n].orig_outputs.append(n)
# Now we create a graphmodule for each component. # Now we create a graphmodule for each component.
orig_to_split_fqn_mapping: Dict[str, str] = {} orig_to_split_fqn_mapping: dict[str, str] = {}
for comp in all_components: for comp in all_components:
outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs)) outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))
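
A usage sketch of `split_by_tags` (toy model): every callable node carries a `tag` attribute drawn from `tags`, and one submodule is created per tag, in tag order.

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_utils import split_by_tags

gm = symbolic_trace(torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Sigmoid()))
for node in gm.graph.nodes:
    if node.op == "call_module":
        node.tag = "a" if node.target == "0" else "b"

split = split_by_tags(gm, ["a", "b"])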


@ -3,8 +3,9 @@ import argparse
import copy import copy
import logging import logging
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Sequence, Tuple from typing import Any, NamedTuple, Optional
import torch import torch
from torch.fx._compatibility import compatibility from torch.fx._compatibility import compatibility
@ -225,7 +226,7 @@ class SplitResult(NamedTuple):
""" """
split_module: torch.fx.GraphModule split_module: torch.fx.GraphModule
submodule_inputs: Dict[str, Any] submodule_inputs: dict[str, Any]
non_acc_submodule_prefix: str non_acc_submodule_prefix: str
@ -235,7 +236,7 @@ def generate_inputs_for_submodules(
inputs: Sequence[Any], inputs: Sequence[Any],
target_submodules: Iterable[str], target_submodules: Iterable[str],
deepcopy: bool = False, deepcopy: bool = False,
) -> Dict[str, Any]: ) -> dict[str, Any]:
""" """
Generate inputs for targeting submodules in the given model. Note that if two submodules refer to the same obj, this Generate inputs for targeting submodules in the given model. Note that if two submodules refer to the same obj, this
function doesn't work. function doesn't work.
@ -365,16 +366,16 @@ class _SplitterBase:
self.update_deps_for_fusions() self.update_deps_for_fusions()
self.non_acc_submodule_name = non_acc_submodule_name self.non_acc_submodule_name = non_acc_submodule_name
self._node_submodule_map: Dict[str, str] = {} self._node_submodule_map: dict[str, str] = {}
self._return_tuple = return_tuple self._return_tuple = return_tuple
self.tags: List[str] = [] self.tags: list[str] = []
# =============================================================== # ===============================================================
# Helpers for ctor and initial state # Helpers for ctor and initial state
# =============================================================== # ===============================================================
def get_node_submodule_map(self) -> Dict[str, str]: def get_node_submodule_map(self) -> dict[str, str]:
"""Returns a map from node name to submodule name, e.g. """Returns a map from node name to submodule name, e.g.
node: main_module_impl_impl_over_arch_unary_multiple_embedding node: main_module_impl_impl_over_arch_unary_multiple_embedding
_pooling_embedding_pooling_sparse_entity_equivalence_key _pooling_embedding_pooling_sparse_entity_equivalence_key
@ -383,7 +384,7 @@ class _SplitterBase:
""" """
return self._node_submodule_map return self._node_submodule_map
def find_deps(self) -> Dict[torch.fx.Node, NodeSet]: def find_deps(self) -> dict[torch.fx.Node, NodeSet]:
""" """
Builds a graph of node dependencies. Leaf nodes don't have any Builds a graph of node dependencies. Leaf nodes don't have any
dependencies and the "output" node doesn't have nodes depending on it. dependencies and the "output" node doesn't have nodes depending on it.
@ -391,7 +392,7 @@ class _SplitterBase:
Resulting graph has only direct dependencies, i.e. there are no Resulting graph has only direct dependencies, i.e. there are no
transitive dependencies. transitive dependencies.
""" """
deps: Dict[torch.fx.Node, NodeSet] = defaultdict(set) deps: dict[torch.fx.Node, NodeSet] = defaultdict(set)
for node in self.module.graph.nodes: for node in self.module.graph.nodes:
if node.op not in CALLABLE_NODE_OPS: if node.op not in CALLABLE_NODE_OPS:
continue continue
@ -647,12 +648,12 @@ class _SplitterBase:
def find_reverse_deps( def find_reverse_deps(
self, tag_id: Optional[int] = None self, tag_id: Optional[int] = None
) -> Dict[torch.fx.Node, NodeSet]: ) -> dict[torch.fx.Node, NodeSet]:
""" """
Builds reversed topological node dependencies. If tag_id is specified, Builds reversed topological node dependencies. If tag_id is specified,
we ignore nodes that are in a later subgraph, i.e. nodes that have a greater tag_id. we ignore nodes that are in a later subgraph, i.e. nodes that have a greater tag_id.
""" """
result: Dict[torch.fx.Node, NodeSet] = defaultdict(set) result: dict[torch.fx.Node, NodeSet] = defaultdict(set)
for node in self.module.graph.nodes: for node in self.module.graph.nodes:
if node.op not in CALLABLE_NODE_OPS: if node.op not in CALLABLE_NODE_OPS:
@ -667,7 +668,7 @@ class _SplitterBase:
return result return result
def update_reverse_deps_for_fusions(self, deps: Dict[torch.fx.Node, NodeSet]): def update_reverse_deps_for_fusions(self, deps: dict[torch.fx.Node, NodeSet]):
processed_node = set() processed_node = set()
for node, fusion in self.fusions.items(): for node, fusion in self.fusions.items():
@ -757,7 +758,7 @@ class _SplitterBase:
# Helpers for split() method # Helpers for split() method
# =============================================================== # ===============================================================
def starter_nodes(self) -> Tuple[NodeSet, NodeSet]: def starter_nodes(self) -> tuple[NodeSet, NodeSet]:
""" """
Finds nodes that consume module inputs or get_attr nodes. Finds nodes that consume module inputs or get_attr nodes.
""" """
@ -773,7 +774,7 @@ class _SplitterBase:
starter_cpu_nodes.add(user) starter_cpu_nodes.add(user)
return starter_cpu_nodes, starter_acc_nodes return starter_cpu_nodes, starter_acc_nodes
def put_nodes_into_subgraphs(self) -> List[Subgraph]: def put_nodes_into_subgraphs(self) -> list[Subgraph]:
# We start graph traversal from leaf nodes # We start graph traversal from leaf nodes
current_cpu_nodes, current_acc_nodes = self.starter_nodes() current_cpu_nodes, current_acc_nodes = self.starter_nodes()
visited_nodes: NodeSet = set() visited_nodes: NodeSet = set()
@ -785,7 +786,7 @@ class _SplitterBase:
current_subgraph_nodes: NodeList = [] current_subgraph_nodes: NodeList = []
# Result accumulator # Result accumulator
subgraphs: List[Subgraph] = [] subgraphs: list[Subgraph] = []
while current_cpu_nodes or current_acc_nodes: while current_cpu_nodes or current_acc_nodes:
# Find the first node that should belong to the current subgraph and has all dependencies resolved # Find the first node that should belong to the current subgraph and has all dependencies resolved
current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes
@ -839,12 +840,12 @@ class _SplitterBase:
return subgraphs return subgraphs
def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]: def remove_small_acc_subgraphs(self, subgraphs: list[Subgraph]) -> list[Subgraph]:
""" """
This pass finds ACC submodules with less than the specified size and merges This pass finds ACC submodules with less than the specified size and merges
them with adjacent CPU submodules. them with adjacent CPU submodules.
""" """
result: List[Subgraph] = [] result: list[Subgraph] = []
for subgraph in subgraphs: for subgraph in subgraphs:
if subgraph.is_acc: if subgraph.is_acc:
if len(subgraph.nodes) >= self.settings.min_acc_module_size: if len(subgraph.nodes) >= self.settings.min_acc_module_size:
@ -866,7 +867,7 @@ class _SplitterBase:
result.append(subgraph) result.append(subgraph)
return result return result
def tag(self, subgraphs: List[Subgraph]): def tag(self, subgraphs: list[Subgraph]):
self.tags = [] self.tags = []
for subgraph in subgraphs: for subgraph in subgraphs:
tag = ( tag = (
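
A minimal standalone sketch of the dependency map `find_deps` describes (a reimplementation for illustration, not the method itself): each callable user depends directly on its callable producers.

from collections import defaultdict

CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"}

def find_deps(graph):
    deps = defaultdict(set)
    for node in graph.nodes:
        if node.op not in CALLABLE_NODE_OPS:
            continue
        for user in node.users:
            if user.op in CALLABLE_NODE_OPS:
                deps[user].add(node)  # user depends on node
    return deps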


@ -1,8 +1,9 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
import collections import collections
import operator import operator
from collections.abc import Mapping
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union from typing import Any, Optional, Union
import torch import torch
import torch.fx import torch.fx
@ -18,11 +19,11 @@ __all__ = [
"legalize_graph", "legalize_graph",
] ]
Tensors = Union[Tuple[torch.Tensor], List[torch.Tensor]] Tensors = Union[tuple[torch.Tensor], list[torch.Tensor]]
TensorOrTensors = Union[torch.Tensor, Tensors] TensorOrTensors = Union[torch.Tensor, Tensors]
NodeList = List[torch.fx.Node] NodeList = list[torch.fx.Node]
NodeSet = Set[torch.fx.Node] NodeSet = set[torch.fx.Node]
Names = List[str] Names = list[str]
CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"} CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"}
@ -172,8 +173,8 @@ class FxNetAccFusionsFinder:
return False return False
def __call__(self) -> Dict[torch.fx.Node, NodeSet]: def __call__(self) -> dict[torch.fx.Node, NodeSet]:
result: Dict[torch.fx.Node, NodeSet] = {} result: dict[torch.fx.Node, NodeSet] = {}
acc_nodes = list(self.acc_nodes) acc_nodes = list(self.acc_nodes)
for node in acc_nodes: for node in acc_nodes:
@ -294,7 +295,7 @@ def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
for node in gm.graph.nodes: for node in gm.graph.nodes:
if indeg[node] == 0: if indeg[node] == 0:
queue.append(node) queue.append(node)
env: Dict[torch.fx.Node, torch.fx.Node] = {} env: dict[torch.fx.Node, torch.fx.Node] = {}
# Pop nodes from the queue, and add nodes that have had all their # Pop nodes from the queue, and add nodes that have had all their
# dependencies fulfilled # dependencies fulfilled
while len(queue) > 0: while len(queue) > 0:
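
The loop above is the tail of a Kahn-style topological sort; a minimal standalone sketch of the same scheme over an FX graph:

from collections import deque

def topological_order(graph):
    indeg = {node: len(node.all_input_nodes) for node in graph.nodes}
    queue = deque(node for node in graph.nodes if indeg[node] == 0)
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for user in node.users:
            indeg[user] -= 1
            if indeg[user] == 0:
                queue.append(user)
    return order  # shorter than graph.nodes iff the graph has a cycle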


@ -1,5 +1,4 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from typing import Dict, Tuple
from torch.fx._compatibility import compatibility from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph from torch.fx.graph import Graph
@ -30,7 +29,7 @@ def lift_subgraph_as_module(
subgraph: Graph, subgraph: Graph,
comp_name: str = "", comp_name: str = "",
class_name: str = "GraphModule", class_name: str = "GraphModule",
) -> Tuple[GraphModule, Dict[str, str]]: ) -> tuple[GraphModule, dict[str, str]]:
""" """
Create a GraphModule for subgraph, which copies the necessary attributes from the original parent graph_module. Create a GraphModule for subgraph, which copies the necessary attributes from the original parent graph_module.
@ -52,7 +51,7 @@ def lift_subgraph_as_module(
# make "weight" a attribute of "conv" HolderModule and point to conv.weight in # make "weight" a attribute of "conv" HolderModule and point to conv.weight in
# the original module. # the original module.
submodule = HolderModule({}) submodule = HolderModule({})
orig_to_split_fqn_mapping: Dict[str, str] = {} orig_to_split_fqn_mapping: dict[str, str] = {}
for n in subgraph.nodes: for n in subgraph.nodes:
if n.op not in ("call_module", "get_attr"): if n.op not in ("call_module", "get_attr"):
continue continue


@ -1,7 +1,7 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
import copy import copy
from queue import SimpleQueue from queue import SimpleQueue
from typing import Dict, List, Optional as _Optional, Tuple from typing import Optional as _Optional
import torch.fx import torch.fx
from torch.fx._compatibility import compatibility from torch.fx._compatibility import compatibility
@ -97,10 +97,10 @@ def fuse_as_graphmodule(
gm: GraphModule, gm: GraphModule,
nodes: NodeList, nodes: NodeList,
module_name: str, module_name: str,
partition_lookup_table: _Optional[Dict[Node, None]] = None, partition_lookup_table: _Optional[dict[Node, None]] = None,
*, *,
always_return_tuple: bool = False, always_return_tuple: bool = False,
) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]: ) -> tuple[GraphModule, tuple[Node, ...], tuple[Node, ...]]:
""" """
Fuse nodes in graph_module into a GraphModule. Fuse nodes in graph_module into a GraphModule.
@ -144,10 +144,10 @@ def fuse_as_graphmodule(
subgraph = Graph() subgraph = Graph()
node_to_placeholder: Dict[ node_to_placeholder: dict[
Node, Node Node, Node
] = {} # mapping of nodes from old graph to placeholder in new graph ] = {} # mapping of nodes from old graph to placeholder in new graph
node_map: Dict[Node, Node] = {} # mapping of nodes from old graph to new graph node_map: dict[Node, Node] = {} # mapping of nodes from old graph to new graph
# handles inputs through graph.node_copy's arg_transform functions # handles inputs through graph.node_copy's arg_transform functions
def remap_inputs(x): def remap_inputs(x):
@ -176,7 +176,7 @@ def fuse_as_graphmodule(
node_map[node] = new_node node_map[node] = new_node
# handles outputs # handles outputs
output_mapping: Dict[Node, Node] = {} # mapping from old output to new outputs output_mapping: dict[Node, Node] = {} # mapping from old output to new outputs
for node in nodes: for node in nodes:
for user_node in node.users: for user_node in node.users:
@ -202,10 +202,10 @@ def fuse_as_graphmodule(
) )
# sub_gm's input nodes in the original module # sub_gm's input nodes in the original module
original_inputs: Tuple[Node, ...] = tuple(node_to_placeholder.keys()) original_inputs: tuple[Node, ...] = tuple(node_to_placeholder.keys())
# sub_gm's outputs node in the original module # sub_gm's outputs node in the original module
original_outputs: Tuple[Node, ...] = tuple(output_mapping.keys()) original_outputs: tuple[Node, ...] = tuple(output_mapping.keys())
return fused_gm, original_inputs, original_outputs return fused_gm, original_inputs, original_outputs
@ -214,8 +214,8 @@ def fuse_as_graphmodule(
def insert_subgm( def insert_subgm(
gm: GraphModule, gm: GraphModule,
sub_gm: GraphModule, sub_gm: GraphModule,
orig_inputs: Tuple[Node, ...], orig_inputs: tuple[Node, ...],
orig_outputs: Tuple[Node, ...], orig_outputs: tuple[Node, ...],
): ):
# add sub_gm into gm # add sub_gm into gm
submodule_name = sub_gm.__class__.__name__ submodule_name = sub_gm.__class__.__name__
@ -250,7 +250,7 @@ def erase_nodes(gm: GraphModule, nodes: NodeList):
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
def fuse_by_partitions( def fuse_by_partitions(
gm: GraphModule, gm: GraphModule,
partitions: List[Dict[Node, None]], partitions: list[dict[Node, None]],
prefix: str = "fused_", prefix: str = "fused_",
always_return_tuple: bool = False, always_return_tuple: bool = False,
) -> GraphModule: ) -> GraphModule:
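
A usage sketch of `fuse_by_partitions` (toy module): each partition is a dict of nodes, matching the signature above, and is lifted into its own submodule.

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)

gm = symbolic_trace(M())
partition = {n: None for n in gm.graph.nodes if n.op == "call_function"}
fused = fuse_by_partitions(gm, [partition], prefix="fused_")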


@ -4,7 +4,7 @@ import logging
import os import os
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, List, Set, Tuple, Union from typing import Any, Union
import torch import torch
from torch.fx import Graph, Node from torch.fx import Graph, Node
@ -37,19 +37,19 @@ logger = _init_logger()
@dataclass @dataclass
class InternalMatch: class InternalMatch:
# Nodes from which the match was found # Nodes from which the match was found
anchors: List[Node] anchors: list[Node]
# Maps nodes in the pattern subgraph to nodes in the larger graph # Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node] = field(default_factory=dict) nodes_map: dict[Node, Node] = field(default_factory=dict)
# nodes in target graph that are matched placeholder in pattern # nodes in target graph that are matched placeholder in pattern
placeholder_nodes: List[Node] = field(default_factory=list) placeholder_nodes: list[Node] = field(default_factory=list)
# nodes in matched subgraph returned by output # nodes in matched subgraph returned by output
returning_nodes: List[Node] = field(default_factory=list) returning_nodes: list[Node] = field(default_factory=list)
# map from a string name to a node in the target graph # map from a string name to a node in the target graph
# only available if the matcher is `SubgraphMatcherWithNameNodeMap` # only available if the matcher is `SubgraphMatcherWithNameNodeMap`
name_node_map: Dict[str, Node] = field(default_factory=dict) name_node_map: dict[str, Node] = field(default_factory=dict)
def __copy__(self): def __copy__(self):
return InternalMatch( return InternalMatch(
@ -107,9 +107,9 @@ class SubgraphMatcher:
] ]
output_node = next(iter(reversed(pattern.nodes))) output_node = next(iter(reversed(pattern.nodes)))
# nodes returned by outputs # nodes returned by outputs
self.pattern_returning_nodes: List[Node] = output_node.all_input_nodes self.pattern_returning_nodes: list[Node] = output_node.all_input_nodes
self.pattern_anchors: List[Node] = [] self.pattern_anchors: list[Node] = []
if match_output: if match_output:
self.pattern_anchors = [output_node] self.pattern_anchors = [output_node]
else: else:
@ -150,12 +150,12 @@ class SubgraphMatcher:
return pn.target == gn.target return pn.target == gn.target
return False return False
def _is_contained(self, nodes_map: Dict[Node, Node]) -> bool: def _is_contained(self, nodes_map: dict[Node, Node]) -> bool:
# `lookup` represents all the nodes in `original_graph` # `lookup` represents all the nodes in `original_graph`
# that are part of `pattern` # that are part of `pattern`
# Placeholders can be used by other nodes in the graphs # Placeholders can be used by other nodes in the graphs
lookup: Dict[Node, Node] = { lookup: dict[Node, Node] = {
gn: pn for pn, gn in nodes_map.items() if pn.op != "placeholder" gn: pn for pn, gn in nodes_map.items() if pn.op != "placeholder"
} }
@ -172,10 +172,10 @@ class SubgraphMatcher:
return True return True
def _remove_overlapping_matches( def _remove_overlapping_matches(
self, matches: List[InternalMatch] self, matches: list[InternalMatch]
) -> List[InternalMatch]: ) -> list[InternalMatch]:
non_overlapping_matches: List[InternalMatch] = [] non_overlapping_matches: list[InternalMatch] = []
nodes_matched: Set[Node] = set() nodes_matched: set[Node] = set()
for match in matches: for match in matches:
found_overlap = False found_overlap = False
@ -244,7 +244,7 @@ class SubgraphMatcher:
# match for `gn` # match for `gn`
match_found = True match_found = True
def _match_args(args1: Union[List, Tuple], args2: Union[List, Tuple]) -> bool: def _match_args(args1: Union[list, tuple], args2: Union[list, tuple]) -> bool:
if len(args1) != len(args2): if len(args1) != len(args2):
return False return False
@ -313,7 +313,7 @@ class SubgraphMatcher:
return True return True
def match(self, graph: Graph) -> List[InternalMatch]: def match(self, graph: Graph) -> list[InternalMatch]:
""" """
Returns: Returns:
The matched subgraphs. The matched subgraphs.
@ -352,7 +352,7 @@ class SubgraphMatcher:
from torch.fx.passes.utils.fuser_utils import validate_partition from torch.fx.passes.utils.fuser_utils import validate_partition
# find candidate nodes to match with pattern anchors # find candidate nodes to match with pattern anchors
match_candidates: Dict[Node, List[Node]] = defaultdict(list) match_candidates: dict[Node, list[Node]] = defaultdict(list)
for pattern_anchor in self.pattern_anchors: for pattern_anchor in self.pattern_anchors:
for node in graph.nodes: for node in graph.nodes:
if self._nodes_are_equal(pattern_anchor, node): if self._nodes_are_equal(pattern_anchor, node):
@ -361,7 +361,7 @@ class SubgraphMatcher:
logger.info("Initial match_candidates_list: %s\n", match_candidates_list) logger.info("Initial match_candidates_list: %s\n", match_candidates_list)
matches: List[InternalMatch] = [] matches: list[InternalMatch] = []
def backtracking(anchor_index, match): def backtracking(anchor_index, match):
if anchor_index == len(match_candidates_list): if anchor_index == len(match_candidates_list):
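
A usage sketch of the matcher (toy pattern and target): both sides are traced to `Graph`s, and `match` returns the non-overlapping `InternalMatch`es.

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

def pattern(x):
    return torch.relu(x + 1)

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1) * 2

matcher = SubgraphMatcher(symbolic_trace(pattern).graph)
matches = matcher.match(symbolic_trace(M()).graph)  # expect a single match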


@ -1,5 +1,3 @@
from typing import Dict, List, Tuple
from torch.fx import Graph, GraphModule, Node from torch.fx import Graph, GraphModule, Node
from torch.fx._compatibility import compatibility from torch.fx._compatibility import compatibility
@ -11,7 +9,7 @@ __all__ = ["SubgraphMatcherWithNameNodeMap"]
def _split_to_graph_and_name_node_map( def _split_to_graph_and_name_node_map(
gm: GraphModule, gm: GraphModule,
) -> Tuple[GraphModule, Dict[str, Node]]: ) -> tuple[GraphModule, dict[str, Node]]:
from torch.fx.graph import _PyTreeInfo from torch.fx.graph import _PyTreeInfo
from torch.utils._pytree import tree_flatten, tree_unflatten from torch.utils._pytree import tree_flatten, tree_unflatten
@ -29,7 +27,7 @@ def _split_to_graph_and_name_node_map(
*out, name_node_map = output *out, name_node_map = output
flattened, out_spec = tree_flatten(out) flattened, out_spec = tree_flatten(out)
assert isinstance( assert isinstance(
name_node_map, Dict name_node_map, dict
), "Expecting the input graph to have a dict output as the last element" ), "Expecting the input graph to have a dict output as the last element"
n.args = (flattened,) n.args = (flattened,)
orig_pytree_info = gm._graph._codegen.pytree_info # type: ignore[attr-defined] orig_pytree_info = gm._graph._codegen.pytree_info # type: ignore[attr-defined]
@ -88,7 +86,7 @@ class SubgraphMatcherWithNameNodeMap(SubgraphMatcher):
ignore_literals, ignore_literals,
) )
def match(self, graph: Graph) -> List[InternalMatch]: def match(self, graph: Graph) -> list[InternalMatch]:
"""The returned InternalMatch will have name_node_map populated with a map """The returned InternalMatch will have name_node_map populated with a map
from node name (str) to the target node, e.g. from node name (str) to the target node, e.g.
{"conv": target_conv_ndoe, "relu": target_relu_node} {"conv": target_conv_ndoe, "relu": target_relu_node}


@ -1,7 +1,7 @@
import logging import logging
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Type from typing import Any, Callable, Optional
from torch.fx._compatibility import compatibility from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph from torch.fx.graph import Graph
@ -34,29 +34,29 @@ logger = _init_logger()
@dataclass @dataclass
class SourcePartition: class SourcePartition:
# Nodes in a particular partition # Nodes in a particular partition
nodes: List[Node] nodes: list[Node]
# The source these nodes decomposed from # The source these nodes decomposed from
source: Any source: Any
# Nodes in the graph that are needed as inputs to the partition # Nodes in the graph that are needed as inputs to the partition
# These do not include the params of the partition # These do not include the params of the partition
input_nodes: List[Node] = field(default_factory=list) input_nodes: list[Node] = field(default_factory=list)
# Nodes in the partition that are being used by nodes outside of the # Nodes in the partition that are being used by nodes outside of the
# partition # partition
output_nodes: List[Node] = field(default_factory=list) output_nodes: list[Node] = field(default_factory=list)
# Parameters that are being used # Parameters that are being used
params: List[Node] = field(default_factory=list) params: list[Node] = field(default_factory=list)
@compatibility(is_backward_compatible=False) # type: ignore[misc] @compatibility(is_backward_compatible=False) # type: ignore[misc]
def get_source_partitions( def get_source_partitions(
graph: Graph, graph: Graph,
wanted_sources: List[Any], wanted_sources: list[Any],
filter_fn: Optional[Callable[[Node], bool]] = None, filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Dict[Any, List[SourcePartition]]: ) -> dict[Any, list[SourcePartition]]:
""" """
Args: Args:
graph: The graph we want to partition graph: The graph we want to partition
@ -69,7 +69,7 @@ def get_source_partitions(
that correspond to the list of nodes that were decomposed from the given that correspond to the list of nodes that were decomposed from the given
source. source.
""" """
modules: Dict[Type, Dict[str, List[Node]]] = {} modules: dict[type, dict[str, list[Node]]] = {}
for node in graph.nodes: for node in graph.nodes:
# The metadata source_fn should contain a tuple of a unique name for the # The metadata source_fn should contain a tuple of a unique name for the
@ -98,7 +98,7 @@ def get_source_partitions(
partition = diff_modules.setdefault(source_fn[0], []) partition = diff_modules.setdefault(source_fn[0], [])
partition.append(node) partition.append(node)
def make_partition(nodes: List[Node], module_type: Type) -> SourcePartition: def make_partition(nodes: list[Node], module_type: type) -> SourcePartition:
input_nodes = set() input_nodes = set()
output_nodes = set() output_nodes = set()
params = set() params = set()
@ -124,7 +124,7 @@ def get_source_partitions(
list(params), # type: ignore[arg-type] list(params), # type: ignore[arg-type]
) )
ret: Dict[Type[Any], List[SourcePartition]] = {} ret: dict[type[Any], list[SourcePartition]] = {}
if filter_fn: if filter_fn:
# for each partition, we apply filter_fn to filter out all partitions that don't satisfy the # for each partition, we apply filter_fn to filter out all partitions that don't satisfy the
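
A usage sketch, assuming a graph that carries `source_fn` metadata such as one produced by export (the toy model is illustrative):

import torch
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

ep = torch.export.export(torch.nn.Linear(4, 2), (torch.randn(1, 4),))
partitions = get_source_partitions(ep.graph_module.graph, [torch.nn.Linear])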


@ -8,8 +8,10 @@ import inspect
import logging import logging
import operator import operator
import sys import sys
from collections import OrderedDict
from collections.abc import Iterator
from dataclasses import fields, is_dataclass from dataclasses import fields, is_dataclass
from typing import Any, Callable, Dict, Iterator, Optional, OrderedDict, Tuple from typing import Any, Callable, Optional
import torch import torch
import torch.fx.traceback as fx_traceback import torch.fx.traceback as fx_traceback
@ -135,18 +137,18 @@ class TracerBase:
scope: Scope scope: Scope
# Records the module call stack # Records the module call stack
module_stack: OrderedDict[str, Tuple[str, Any]] module_stack: OrderedDict[str, tuple[str, Any]]
# Mapping of node name to module scope # Mapping of node name to module scope
node_name_to_scope: Dict[str, Tuple[str, type]] node_name_to_scope: dict[str, tuple[str, type]]
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
def create_node( def create_node(
self, self,
kind: str, kind: str,
target: Target, target: Target,
args: Tuple[Argument, ...], args: tuple[Argument, ...],
kwargs: Dict[str, Argument], kwargs: dict[str, Argument],
name: Optional[str] = None, name: Optional[str] = None,
type_expr: Optional[Any] = None, type_expr: Optional[Any] = None,
) -> Node: ) -> Node:
@ -171,7 +173,7 @@ class TracerBase:
# Optionally set stack trace on the created Node for debugging purposes # Optionally set stack trace on the created Node for debugging purposes
if fx_traceback.has_preserved_node_meta(): if fx_traceback.has_preserved_node_meta():
current_meta: Dict[str, Any] = fx_traceback.get_current_meta() current_meta: dict[str, Any] = fx_traceback.get_current_meta()
stack_trace = current_meta.get("stack_trace") stack_trace = current_meta.get("stack_trace")
if stack_trace: if stack_trace:
@ -211,8 +213,8 @@ class TracerBase:
self, self,
kind: str, kind: str,
target: Target, target: Target,
args: Tuple[Any, ...], args: tuple[Any, ...],
kwargs: Dict[str, Any], kwargs: dict[str, Any],
name: Optional[str] = None, name: Optional[str] = None,
type_expr: Optional[Any] = None, type_expr: Optional[Any] = None,
# fix noqa when updating bc tests # fix noqa when updating bc tests
@ -455,10 +457,10 @@ class Proxy:
# we peephole optimize to the method invocation # we peephole optimize to the method invocation
return Attribute(self, k) return Attribute(self, k)
def __getstate__(self) -> Dict: def __getstate__(self) -> dict:
return self.__dict__ return self.__dict__
def __deepcopy__(self, memo) -> Dict: def __deepcopy__(self, memo) -> dict:
# We have to explicitly override this method, because otherwise deepcopy # We have to explicitly override this method, because otherwise deepcopy
# will go to __getattr__(self, "__deepcopy__") and return a # will go to __getattr__(self, "__deepcopy__") and return a
# Attribute(__deepcopy__), and may go into an infinite loop in some cases. # Attribute(__deepcopy__), and may go into an infinite loop in some cases.
@ -564,7 +566,7 @@ class Proxy:
args = args if args else () args = args if args else ()
kwargs = kwargs if kwargs else {} kwargs = kwargs if kwargs else {}
tracers: Dict[Any, None] = {} tracers: dict[Any, None] = {}
def find_tracer(a): def find_tracer(a):
if isinstance(a, cls): if isinstance(a, cls):
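
A minimal sketch of hooking `create_node` from a `Tracer` subclass to stamp extra metadata on every node as it is created (the metadata key is made up for illustration):

import torch.fx as fx

class TaggingTracer(fx.Tracer):
    def create_node(self, kind, target, args, kwargs, name=None, type_expr=None):
        node = super().create_node(kind, target, args, kwargs, name, type_expr)
        node.meta["custom_tag"] = "example"  # hypothetical metadata key
        return node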


@ -1,16 +1,6 @@
import copy import copy
from dataclasses import dataclass from dataclasses import dataclass
from typing import ( from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union
Any,
Callable,
Dict,
List,
NamedTuple,
Optional,
Set,
TYPE_CHECKING,
Union,
)
import torch import torch
@ -37,7 +27,7 @@ class Match(NamedTuple):
# Node from which the match was found # Node from which the match was found
anchor: Node anchor: Node
# Maps nodes in the pattern subgraph to nodes in the larger graph # Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node] nodes_map: dict[Node, Node]
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
@ -46,9 +36,9 @@ class ReplacedPatterns:
# Node from which the match was found # Node from which the match was found
anchor: Node anchor: Node
# Maps nodes in the pattern subgraph to nodes in the larger graph # Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node] nodes_map: dict[Node, Node]
# List of nodes that were added into the graph # List of nodes that were added into the graph
replacements: List[Node] replacements: list[Node]
def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None: def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
@ -106,7 +96,7 @@ def replace_pattern(
gm: GraphModule, gm: GraphModule,
pattern: Union[Callable, GraphModule], pattern: Union[Callable, GraphModule],
replacement: Union[Callable, GraphModule], replacement: Union[Callable, GraphModule],
) -> List[Match]: ) -> list[Match]:
""" """
Matches all possible non-overlapping sets of operators and their Matches all possible non-overlapping sets of operators and their
data dependencies (``pattern``) in the Graph of a GraphModule data dependencies (``pattern``) in the Graph of a GraphModule
@ -237,14 +227,14 @@ def replace_pattern_with_filters(
pattern: Union[Callable, Graph, GraphModule], pattern: Union[Callable, Graph, GraphModule],
replacement: Union[Callable, Graph, GraphModule, None] = None, replacement: Union[Callable, Graph, GraphModule, None] = None,
match_filters: Optional[ match_filters: Optional[
List[Callable[["InternalMatch", Graph, Graph], bool]] list[Callable[["InternalMatch", Graph, Graph], bool]]
] = None, ] = None,
ignore_literals: bool = False, ignore_literals: bool = False,
# Placed at the end to avoid breaking backward compatibility # Placed at the end to avoid breaking backward compatibility
replacement_callback: Optional[ replacement_callback: Optional[
Callable[["InternalMatch", Graph, Graph], Graph] Callable[["InternalMatch", Graph, Graph], Graph]
] = None, ] = None,
) -> List[ReplacedPatterns]: ) -> list[ReplacedPatterns]:
""" """
See replace_pattern for documentation. This function is an overload with an additional match_filter argument. See replace_pattern for documentation. This function is an overload with an additional match_filter argument.
@ -268,14 +258,14 @@ def _replace_pattern(
pattern: Union[Callable, Graph, GraphModule], pattern: Union[Callable, Graph, GraphModule],
replacement: Union[Callable, Graph, GraphModule, None] = None, replacement: Union[Callable, Graph, GraphModule, None] = None,
match_filters: Optional[ match_filters: Optional[
List[Callable[["InternalMatch", Graph, Graph], bool]] list[Callable[["InternalMatch", Graph, Graph], bool]]
] = None, ] = None,
ignore_literals: bool = False, ignore_literals: bool = False,
# Placed at the end to avoid breaking backward compatibility # Placed at the end to avoid breaking backward compatibility
replacement_callback: Optional[ replacement_callback: Optional[
Callable[["InternalMatch", Graph, Graph], Graph] Callable[["InternalMatch", Graph, Graph], Graph]
] = None, ] = None,
) -> List[ReplacedPatterns]: ) -> list[ReplacedPatterns]:
from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher
if match_filters is None: if match_filters is None:
@ -298,7 +288,7 @@ def _replace_pattern(
remove_overlapping_matches=True, remove_overlapping_matches=True,
ignore_literals=ignore_literals, ignore_literals=ignore_literals,
) )
_matches: List[InternalMatch] = matcher.match(original_graph) _matches: list[InternalMatch] = matcher.match(original_graph)
# Filter out matches that don't match the filter # Filter out matches that don't match the filter
_matches = [ _matches = [
@ -323,7 +313,7 @@ def _replace_pattern(
common_replacement_graph = None common_replacement_graph = None
# As we progressively replace nodes, we'll need to keep track of how the match results should change # As we progressively replace nodes, we'll need to keep track of how the match results should change
match_changed_node: Dict[Node, Node] = {} match_changed_node: dict[Node, Node] = {}
match_and_replacements = [] match_and_replacements = []
for match in _matches: for match in _matches:
@ -345,7 +335,7 @@ def _replace_pattern(
# Initialize `val_map` with mappings from placeholder nodes in # Initialize `val_map` with mappings from placeholder nodes in
# `replacement` to their corresponding node in `original_graph` # `replacement` to their corresponding node in `original_graph`
assert len(match.placeholder_nodes) == len(replacement_placeholders) assert len(match.placeholder_nodes) == len(replacement_placeholders)
val_map: Dict[Node, Node] = {} val_map: dict[Node, Node] = {}
for rn, gn in zip(replacement_placeholders, match.placeholder_nodes): for rn, gn in zip(replacement_placeholders, match.placeholder_nodes):
if isinstance(gn, Node): if isinstance(gn, Node):
val_map[rn] = match_changed_node.get(gn, gn) val_map[rn] = match_changed_node.get(gn, gn)
@ -361,7 +351,7 @@ def _replace_pattern(
val_map[rn] = gn val_map[rn] = gn
# Copy the replacement graph over # Copy the replacement graph over
user_nodes: Set[Node] = set() user_nodes: set[Node] = set()
for n in match.returning_nodes: for n in match.returning_nodes:
user_nodes.update(n.users) user_nodes.update(n.users)
@ -402,7 +392,7 @@ def _replace_pattern(
copied_returning_nodes = (copied_returning_nodes,) copied_returning_nodes = (copied_returning_nodes,)
# Get a list of nodes that have been replaced into the graph # Get a list of nodes that have been replaced into the graph
replacement_nodes: List[Node] = [ replacement_nodes: list[Node] = [
v for v in val_map.values() if v not in match.placeholder_nodes v for v in val_map.values() if v not in match.placeholder_nodes
] ]
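
A usage sketch of the rewriter in the shape the docstring describes (toy pattern and replacement):

import torch
from torch.fx import subgraph_rewriter, symbolic_trace

class M(torch.nn.Module):
    def forward(self, x):
        return torch.add(x, x)

def pattern(x):
    return torch.add(x, x)

def replacement(x):
    return torch.mul(x, 2)

gm = symbolic_trace(M())
matches = subgraph_rewriter.replace_pattern(gm, pattern, replacement)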


@ -4,7 +4,7 @@ import json
import traceback import traceback
from contextlib import contextmanager from contextlib import contextmanager
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Union from typing import Any, Optional, Union
from ._compatibility import compatibility from ._compatibility import compatibility
from .graph import Graph from .graph import Graph
@ -25,7 +25,7 @@ __all__ = [
"get_graph_provenance_json", "get_graph_provenance_json",
] ]
current_meta: Dict[str, Any] = {} current_meta: dict[str, Any] = {}
should_preserve_node_meta = False should_preserve_node_meta = False
@ -49,15 +49,15 @@ class NodeSource:
self.graph_id = graph_id self.graph_id = graph_id
pass_name: str pass_name: str
action: List["NodeSourceAction"] action: list["NodeSourceAction"]
from_node: List["NodeSource"] from_node: list["NodeSource"]
node_info: Optional["NodeInfo"] node_info: Optional["NodeInfo"]
def __init__( def __init__(
self, self,
node: Optional[Node], node: Optional[Node],
pass_name: str = "", pass_name: str = "",
action: Optional[Union["NodeSourceAction", List["NodeSourceAction"]]] = None, action: Optional[Union["NodeSourceAction", list["NodeSourceAction"]]] = None,
): ):
self.pass_name = pass_name self.pass_name = pass_name
@ -146,7 +146,7 @@ def preserve_node_meta(enable=True):
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
def set_stack_trace(stack: List[str]): def set_stack_trace(stack: list[str]):
global current_meta global current_meta
if should_preserve_node_meta and stack: if should_preserve_node_meta and stack:
@ -182,7 +182,7 @@ def reset_grad_fn_seq_nr():
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
def format_stack() -> List[str]: def format_stack() -> list[str]:
if should_preserve_node_meta: if should_preserve_node_meta:
return [current_meta.get("stack_trace", "")] return [current_meta.get("stack_trace", "")]
else: else:
@ -219,7 +219,7 @@ def set_current_meta(node, pass_name=""):
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
def get_current_meta() -> Dict[str, Any]: def get_current_meta() -> dict[str, Any]:
return current_meta return current_meta
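
A minimal sketch of the preserved-meta flow these helpers implement: inside `preserve_node_meta()`, node creation consults `current_meta` for stack traces and provenance.

import torch
import torch.fx.traceback as fx_traceback
from torch.fx import symbolic_trace

with fx_traceback.preserve_node_meta():
    gm = symbolic_trace(torch.nn.ReLU())

for node in gm.graph.nodes:
    print(node.name, node.meta.get("stack_trace"))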