Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
PEP585 update - torch/fx (#145166)
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145166 Approved by: https://github.com/bobrenjc93
Committed by: PyTorch MergeBot
Parent: 6374332d33
Commit: 0b2a3687b9
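For context: PEP 585 (Python 3.9+) makes the builtin container types usable directly as generic annotations, so imports of typing.List, typing.Dict, typing.Tuple, typing.Set, typing.Type and typing.Deque can be dropped in favor of list, dict, tuple, set, type and collections.deque. A minimal before/after sketch of the pattern this commit applies across torch/fx (illustrative only; the function below is hypothetical, not from the diff):

from typing import Optional

# Pre-PEP585 spelling (needs `from typing import Dict, List, Tuple`):
#     def lookup(table: Dict[str, List[int]], keys: Tuple[str, ...]) -> Optional[List[int]]: ...
# PEP585 spelling used throughout this commit:
def lookup(table: dict[str, list[int]], keys: tuple[str, ...]) -> Optional[list[int]]:
    # Return the first value found for any of the given keys, else None.
    for key in keys:
        if key in table:
            return table[key]
    return None

Note that Optional and Union still come from typing; only the container generics move to the builtins.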
@@ -3,6 +3,7 @@

 import builtins
 import contextlib
+import collections
 import copy
 import functools
 import inspect
@@ -2767,7 +2768,7 @@ class TestFX(JitTestCase):
                 return self.other(x)

         traced = symbolic_trace(ReturnTypeModule())
-        self.assertIn("-> typing_List[str]", traced._code)
+        self.assertIn("-> list[str]", traced._code)
         scripted = torch.jit.script(traced)
         self.assertIn("-> List[str]", scripted.code)

@@ -3566,8 +3567,8 @@ class TestFX(JitTestCase):

         traced(x, y)

-        FileCheck().check("_Tuple[()]") \
-            .check("typing_Tuple[str,typing_Tuple[()]]") \
+        FileCheck().check("tuple[()]") \
+            .check("tuple[str,tuple[()]]") \
             .run(traced.code)

         scripted = torch.jit.script(traced)
@@ -4063,45 +4064,62 @@ class TestFXAPIBackwardCompatibility(JitTestCase):

         return f'{fn_name}({", ".join(arg_strs)}){return_annot}'

-    def _annotation_type_to_stable_str(self, t, sig_str):
+    _trivial_mappings = {
+        str : 'str',
+        int : 'int',
+        float: 'float',
+        bool: 'bool',
+        torch.dtype: 'torch.dtype',
+        torch.Tensor: 'torch.Tensor',
+        torch.device: 'torch.device',
+        torch.memory_format: 'torch.memory_format',
+        slice: 'slice',
+        torch.nn.Module: 'torch.nn.modules.module.Module',
+        torch.fx.Graph : 'torch.fx.graph.Graph',
+        torch.fx.Node : 'torch.fx.node.Node',
+        torch.fx.Proxy : 'torch.fx.proxy.Proxy',
+        torch.fx.node.Target : 'torch.fx.node.Target',
+        torch.fx.node.Argument : 'torch.fx.node.Argument',
+        torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
+        torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
+        torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
+        Ellipsis : '...',
+        typing.Any: 'Any',
+        type(None): 'NoneType',
+        None: 'None',
+        typing.Iterator: 'Iterator',
+        collections.abc.Iterator: 'Iterator',
+    }
+
+    _UNBOUND_TYPES = {
+        dict,
+        list,
+        tuple,
+        type,
+        typing.Callable,
+        typing.Dict,
+        typing.List,
+        typing.Tuple,
+        typing.Type,
+        typing.Union,
+    }
+
+    def _annotation_type_to_stable_str(self, t, sig_str, recursive: bool = False):
         if t is inspect.Signature.empty:
             return ''

         # Forward ref
         if isinstance(t, str):
-            return f"'{t}'"
+            if recursive:
+                return t
+            else:
+                return f"'{t}'"
         if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef):
             return t.__forward_arg__
         if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef):
             return t.__forward_arg__

-        trivial_mappings = {
-            str : 'str',
-            int : 'int',
-            float: 'float',
-            bool: 'bool',
-            torch.dtype: 'torch.dtype',
-            torch.Tensor: 'torch.Tensor',
-            torch.device: 'torch.device',
-            torch.memory_format: 'torch.memory_format',
-            slice: 'slice',
-            torch.nn.Module: 'torch.nn.modules.module.Module',
-            torch.fx.Graph : 'torch.fx.graph.Graph',
-            torch.fx.Node : 'torch.fx.node.Node',
-            torch.fx.Proxy : 'torch.fx.proxy.Proxy',
-            torch.fx.node.Target : 'torch.fx.node.Target',
-            torch.fx.node.Argument : 'torch.fx.node.Argument',
-            torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
-            torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
-            torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
-            Ellipsis : '...',
-            typing.Any: 'Any',
-            type(None): 'NoneType',
-            None: 'None',
-            typing.Iterator: 'Iterator',
-        }
-
-        mapping = trivial_mappings.get(t, None)
+        mapping = self._trivial_mappings.get(t, None)
         if mapping:
             return mapping

@@ -4115,14 +4133,14 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
         if all(isinstance(ct, typing.TypeVar) for ct in contained):
             contained = []

-        contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained]
+        contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str, True) for ct in contained]
         contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else ''


         origin = getattr(t, '__origin__', None)
         if origin is None:
             # Unbound types don't have `__origin__` in some Python versions, so fix that up here.
-            origin = t if t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable} else origin
+            origin = t if t in self._UNBOUND_TYPES else origin

         if origin in {tuple, typing.Tuple}:
             return f'Tuple{contained_type_str}'
@@ -4130,7 +4148,7 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
             # Annoying hack to detect Optional
             if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)):
                 not_none_param = contained[0] if contained[0] is not type(None) else contained[1]
-                return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]'
+                return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str, True)}]'
             return f'Union{contained_type_str}'
         if origin in {dict, typing.Dict}:
             return f'Dict{contained_type_str}'
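The hunks above also restructure the test helper: the per-call trivial_mappings dict becomes the class-level _trivial_mappings table (plus _UNBOUND_TYPES), and a recursive flag is threaded through so that string forward references nested inside containers are rendered without an extra layer of quotes. A rough standalone sketch of that lookup-plus-recursion pattern, with simplified hypothetical names (not the test's actual implementation):

import typing

class AnnotationStringifier:
    # Built once at class level instead of on every call.
    _trivial = {int: "int", str: "str", type(None): "NoneType"}

    def to_stable_str(self, t, recursive: bool = False) -> str:
        if isinstance(t, str):
            # A bare string is a forward reference: quote it only at the top level.
            return t if recursive else f"'{t}'"
        if t in self._trivial:
            return self._trivial[t]
        # Recurse into container arguments such as list[int] or dict[str, int].
        args = typing.get_args(t)
        if not args:
            return str(t)
        origin = typing.get_origin(t) or t
        inner = ", ".join(self.to_stable_str(a, recursive=True) for a in args)
        return f"{getattr(origin, '__name__', str(origin))}[{inner}]"

print(AnnotationStringifier().to_stable_str(dict[str, list[int]]))  # dict[str, list[int]]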
@@ -1524,6 +1524,29 @@ class {test_classname}(torch.nn.Module):
             (int, type(torch.float)),
             (Union[int, float], int),
             (Union[int, float], float),
+            (list[int], int),
+            (list[int], create_type_hint([int, int])),
+            (list[int], create_type_hint((int, int))),
+            (list[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
+            (
+                list[torch.Tensor],
+                create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
+            ),
+            (torch.Tensor, torch.nn.Parameter),
+            (list[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
+            (list[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
+            (list[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
+            (
+                list[torch.Tensor],
+                create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
+            ),
+            (torch.Tensor, torch.nn.Parameter),
+            (list[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
+            (list[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
+            (Optional[list[torch.Tensor]], list[torch.Tensor]),
+            (Optional[list[int]], list[int]),
+        ] + [
+            # pre-PEP585 signatures
             (List[int], int),
             (List[int], create_type_hint([int, int])),
             (List[int], create_type_hint((int, int))),
@@ -1532,7 +1555,6 @@ class {test_classname}(torch.nn.Module):
                 List[torch.Tensor],
                 create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
             ),
-            (torch.Tensor, torch.nn.Parameter),
             (List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
             (List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
             (List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
@@ -1540,18 +1562,21 @@ class {test_classname}(torch.nn.Module):
                 List[torch.Tensor],
                 create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
             ),
-            (torch.Tensor, torch.nn.Parameter),
             (List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
             (List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
             (Optional[List[torch.Tensor]], List[torch.Tensor]),
             (Optional[List[int]], List[int]),
         ]

         for sig_type, arg_type in should_be_equal:
             self.assertTrue(type_matches(sig_type, arg_type))

         should_fail = [
             (int, float),
             (Union[int, float], str),
+            (list[torch.Tensor], List[int]),
+        ] + [
+            # pre-PEP585 signatures
             (List[torch.Tensor], List[int]),
         ]

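The new should_be_equal / should_fail rows mix PEP585 and pre-PEP585 spellings, so both must be accepted by the schema-matching helpers. A hedged usage sketch, assuming the helpers keep their current homes in torch.fx.operator_schemas:

from typing import List, Optional

import torch
from torch.fx.operator_schemas import create_type_hint, type_matches

# create_type_hint builds a container hint from a list/tuple of element types ...
hint = create_type_hint([torch.nn.Parameter, torch.Tensor])

# ... and type_matches checks an argument type against a signature type,
# treating list[...] and List[...] interchangeably (mirroring the rows above).
assert type_matches(list[torch.Tensor], hint)
assert type_matches(List[torch.Tensor], hint)
assert type_matches(Optional[list[int]], list[int])
assert not type_matches(list[torch.Tensor], List[int])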
@@ -1,9 +1,9 @@
 import textwrap
-from typing import Any, Callable, Dict, TypeVar
+from typing import Any, Callable, TypeVar


-_BACK_COMPAT_OBJECTS: Dict[Any, None] = {}
-_MARKED_WITH_COMPATIBILITY: Dict[Any, None] = {}
+_BACK_COMPAT_OBJECTS: dict[Any, None] = {}
+_MARKED_WITH_COMPATIBILITY: dict[Any, None] = {}


 _T = TypeVar("_T")
@@ -1,20 +1,20 @@
 # mypy: allow-untyped-defs
 from collections import namedtuple
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type
+from typing import Any, Callable, NamedTuple, Optional

 import torch.return_types
 from torch.utils._pytree import PyTree, TreeSpec


-FlattenFuncSpec = Callable[[PyTree, TreeSpec], List]
+FlattenFuncSpec = Callable[[PyTree, TreeSpec], list]
 FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool]

-SUPPORTED_NODES: Dict[Type[Any], FlattenFuncSpec] = {}
-SUPPORTED_NODES_EXACT_MATCH: Dict[Type[Any], Optional[FlattenFuncExactMatchSpec]] = {}
+SUPPORTED_NODES: dict[type[Any], FlattenFuncSpec] = {}
+SUPPORTED_NODES_EXACT_MATCH: dict[type[Any], Optional[FlattenFuncExactMatchSpec]] = {}


 def register_pytree_flatten_spec(
-    cls: Type[Any],
+    cls: type[Any],
     flatten_fn_spec: FlattenFuncSpec,
     flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None,
 ) -> None:
@@ -23,7 +23,7 @@ def register_pytree_flatten_spec(


 def _deregister_pytree_flatten_spec(
-    cls: Type[Any],
+    cls: type[Any],
 ) -> None:
     del SUPPORTED_NODES[cls]
     del SUPPORTED_NODES_EXACT_MATCH[cls]
@@ -33,7 +33,7 @@ def tree_flatten_spec(
     pytree: PyTree,
     spec: TreeSpec,
     exact_structural_match=False,
-) -> List[Any]:
+) -> list[Any]:
     if spec.is_leaf():
         return [pytree]
     if spec.type not in SUPPORTED_NODES:
@@ -58,31 +58,31 @@ def tree_flatten_spec(
     return result


-def _dict_flatten_spec(d: Dict[Any, Any], spec: TreeSpec) -> List[Any]:
+def _dict_flatten_spec(d: dict[Any, Any], spec: TreeSpec) -> list[Any]:
     return [d[k] for k in spec.context]


-def _list_flatten_spec(d: List[Any], spec: TreeSpec) -> List[Any]:
+def _list_flatten_spec(d: list[Any], spec: TreeSpec) -> list[Any]:
     return [d[i] for i in range(spec.num_children)]


-def _tuple_flatten_spec(d: Tuple[Any], spec: TreeSpec) -> List[Any]:
+def _tuple_flatten_spec(d: tuple[Any], spec: TreeSpec) -> list[Any]:
     return [d[i] for i in range(spec.num_children)]


-def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> List[Any]:
+def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> list[Any]:
     return [d[i] for i in range(spec.num_children)]


-def _dict_flatten_spec_exact_match(d: Dict[Any, Any], spec: TreeSpec) -> bool:
+def _dict_flatten_spec_exact_match(d: dict[Any, Any], spec: TreeSpec) -> bool:
     return len(d) == spec.num_children


-def _list_flatten_spec_exact_match(d: List[Any], spec: TreeSpec) -> bool:
+def _list_flatten_spec_exact_match(d: list[Any], spec: TreeSpec) -> bool:
     return len(d) == spec.num_children


-def _tuple_flatten_spec_exact_match(d: Tuple[Any], spec: TreeSpec) -> bool:
+def _tuple_flatten_spec_exact_match(d: tuple[Any], spec: TreeSpec) -> bool:
     return len(d) == spec.num_children


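The _pytree hunks above only touch annotations; the flatten-spec machinery behaves the same. A small usage sketch of tree_flatten_spec with the builtin container specs registered in this file (a hedged illustration, not taken from the diff):

from typing import Any

import torch
from torch.fx._pytree import tree_flatten_spec
from torch.utils._pytree import tree_flatten

# Flatten one structure to capture its TreeSpec ...
values = {"a": torch.zeros(2), "b": (torch.ones(3), torch.arange(4))}
flat, spec = tree_flatten(values)

# ... then flatten another pytree in the exact key/positional order that spec records.
other = {"a": torch.rand(2), "b": (torch.rand(3), torch.rand(4))}
flat_other: list[Any] = tree_flatten_spec(other, spec)
assert len(flat_other) == len(flat)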
@@ -10,18 +10,7 @@ import os
 import warnings
 from itertools import chain
 from types import CodeType, FunctionType, ModuleType
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    NamedTuple,
-    Optional,
-    Set,
-    Tuple,
-    Type,
-    Union,
-)
+from typing import Any, Callable, NamedTuple, Optional, Union

 import torch
 import torch.utils._pytree as pytree
@@ -42,7 +31,7 @@ HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS
 _orig_module_call: Callable = torch.nn.Module.__call__
 _orig_module_getattr: Callable = torch.nn.Module.__getattr__

-_proxyable_classes: Dict[Type, None] = {}
+_proxyable_classes: dict[type, None] = {}

 _is_fx_tracing_flag = False

@@ -262,8 +251,8 @@ class Tracer(TracerBase):
     @compatibility(is_backward_compatible=True)
     def __init__(
         self,
-        autowrap_modules: Tuple[ModuleType] = (math,),
-        autowrap_functions: Tuple[Callable, ...] = (),
+        autowrap_modules: tuple[ModuleType] = (math,),
+        autowrap_functions: tuple[Callable, ...] = (),
         param_shapes_constant: bool = False,
     ) -> None:
         # This method's signature is overridden by the first line of this class'
@@ -296,7 +285,7 @@ class Tracer(TracerBase):

         # Functions we will eagerly wrap when we see them while tracing
         # this captures both `math.sqrt()` and `from math import sqrt` automatically
-        self._autowrap_function_ids: Set[int] = {
+        self._autowrap_function_ids: set[int] = {
             id(value)
             for name, value in chain(*[m.__dict__.items() for m in autowrap_modules])
             if not name.startswith("_") and callable(value)
@@ -305,20 +294,20 @@ class Tracer(TracerBase):

         # Python modules to apply autowrap to at the start, in addition to
         # modules we see while tracing
-        self._autowrap_search: List[ModuleType] = list(autowrap_modules)
+        self._autowrap_search: list[ModuleType] = list(autowrap_modules)
         self.param_shapes_constant = param_shapes_constant

-        self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None
+        self.submodule_paths: Optional[dict[torch.nn.Module, str]] = None
         self.root_module_name: str = ""
         # Maps the containing module's name to the operator name
         self.scope = Scope("", None)
         # Records the module call stack
         self.module_stack = collections.OrderedDict()
-        self.num_calls: Dict[str, int] = {}
+        self.num_calls: dict[str, int] = {}
         # Mapping of node name to module scope
-        self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}
+        self.node_name_to_scope: dict[str, tuple[str, type]] = {}

-    _qualname_counter: Dict[str, int] = collections.defaultdict(int)
+    _qualname_counter: dict[str, int] = collections.defaultdict(int)

     @compatibility(is_backward_compatible=True)
     def get_fresh_qualname(self, prefix: str) -> str:
@@ -492,8 +481,8 @@ class Tracer(TracerBase):
         self,
         m: torch.nn.Module,
         forward: Callable[..., Any],
-        args: Tuple[Any, ...],
-        kwargs: Dict[str, Any],
+        args: tuple[Any, ...],
+        kwargs: dict[str, Any],
     ) -> Any:
         """
         Method that specifies the behavior of this ``Tracer`` when it encounters
@@ -547,7 +536,7 @@ class Tracer(TracerBase):
         return ret_val

     @compatibility(is_backward_compatible=False)
-    def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
+    def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: dict[str, Any]):
         """
         Method that specifies the behavior of this ``Tracer`` when we call getattr
         on a call to an ``nn.Module`` instance.
@@ -626,7 +615,7 @@ class Tracer(TracerBase):
         total_args = co.co_argcount + co.co_kwonlyargcount
         orig_args = list(co.co_varnames)
         names_iter = iter(co.co_varnames)
-        args: List[Any] = []
+        args: list[Any] = []
         skip_arg_idx = 0
         if is_module:
             if total_args == 0:
@@ -712,7 +701,7 @@ class Tracer(TracerBase):
     def trace(
         self,
         root: Union[torch.nn.Module, Callable[..., Any]],
-        concrete_args: Optional[Dict[str, Any]] = None,
+        concrete_args: Optional[dict[str, Any]] = None,
     ) -> Graph:
         """
         Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
@@ -763,7 +752,7 @@ class Tracer(TracerBase):
             self.root = torch.nn.Module()
             fn = root

-        tracer_cls: Optional[Type[Tracer]] = getattr(self, "__class__", None)
+        tracer_cls: Optional[type[Tracer]] = getattr(self, "__class__", None)
         self.graph = Graph(tracer_cls=tracer_cls)
         if hasattr(fn, "__code__"):
             code = fn.__code__
@@ -777,11 +766,11 @@ class Tracer(TracerBase):
         # is some other attribute on the model. Construct a dict mapping Tensor
         # values to the qualified name here for efficiency. This is used downstream
         # in create_arg
-        self.tensor_attrs: Dict[
+        self.tensor_attrs: dict[
             Union[torch.Tensor, ScriptObject, FakeScriptObject], str
         ] = {}

-        def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]):
+        def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: list[str]):
             for k, v in m.__dict__.items():
                 if isinstance(v, (torch.Tensor, ScriptObject, FakeScriptObject)):
                     self.tensor_attrs[v] = ".".join(prefix_atoms + [k])
@@ -797,7 +786,7 @@ class Tracer(TracerBase):
             fn, isinstance(root, torch.nn.Module), concrete_args
         )

-        parameter_proxy_cache: Dict[
+        parameter_proxy_cache: dict[
             str, Proxy
         ] = {}  # Reduce number of get_attr calls

@@ -872,7 +861,7 @@ class Tracer(TracerBase):
                 nonlocal cnt
                 cnt += 1
                 param = sig.parameters[name]
-                default: Tuple[Any, ...] = (
+                default: tuple[Any, ...] = (
                     () if param.default is inspect.Parameter.empty else (param.default,)
                 )
                 out = self.create_proxy(
@@ -913,7 +902,7 @@ class Tracer(TracerBase):

                 return pytree.tree_map(replace_ph, concrete_args[name])
             if name[0] == "*":
-                default: Tuple[Any, ...] = ()
+                default: tuple[Any, ...] = ()
             else:
                 param = sig.parameters[name]
                 default = (  # type: ignore[assignment]
@@ -932,11 +921,11 @@ class Tracer(TracerBase):
 # the purposes of the wrap() API.
 # We key by the globals dict id and function name to ensure we're wrapping a given
 # function only once.
-_wrapped_fns_to_patch: Dict[Tuple[int, str], dict] = {}
+_wrapped_fns_to_patch: dict[tuple[int, str], dict] = {}

 # List of methods on classes to wrap (class type, function name)
 # this currently only works for Tensor.* methods that aren't traced properly
-_wrapped_methods_to_patch: List[Tuple[type, str]] = []
+_wrapped_methods_to_patch: list[tuple[type, str]] = []

 if os.environ.get("FX_PATCH_GETITEM") == "1":
     # This change is needed to trace models like PositionalEmbedding from BERT:
@@ -1043,12 +1032,12 @@ class _PatchedFnSetAttr(_PatchedFn):
 class _Patcher:
     def __init__(self) -> None:
         super().__init__()
-        self.patches_made: List[_PatchedFn] = []
-        self.visited: Set[int] = set()
+        self.patches_made: list[_PatchedFn] = []
+        self.visited: set[int] = set()

     def patch(
         self,
-        frame_dict: Dict[str, Any],
+        frame_dict: dict[str, Any],
         name: str,
         new_fn: Callable,
         deduplicate: bool = True,
@@ -1169,7 +1158,7 @@ def _patch_wrapped_functions(patcher: _Patcher):


 def _autowrap_check(
-    patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int]
+    patcher: _Patcher, frame_dict: dict[str, Any], function_ids: set[int]
 ):
     """
     Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them.
@@ -1252,7 +1241,7 @@ def wrap(fn_or_name: Union[str, Callable]):
 @compatibility(is_backward_compatible=True)
 def symbolic_trace(
     root: Union[torch.nn.Module, Callable[..., Any]],
-    concrete_args: Optional[Dict[str, Any]] = None,
+    concrete_args: Optional[dict[str, Any]] = None,
 ) -> GraphModule:
     """
     Symbolic tracing API
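Tracer.trace and symbolic_trace above now annotate concrete_args as a plain dict; callers pass the same thing they always did. A short usage sketch, hedged, based on the documented symbolic_trace behavior:

import torch
from torch.fx import symbolic_trace

def f(x: torch.Tensor, flag: bool):
    # Data-dependent control flow cannot be traced symbolically ...
    return x + 1 if flag else x - 1

# ... so `flag` is pinned via concrete_args, which is just a dict[str, Any].
gm = symbolic_trace(f, concrete_args={"flag": True})
print(gm.code)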
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import sys
-from typing import Dict, Optional
+from typing import Optional

 import torch
 from torch._logging import LazyString
@@ -43,7 +43,7 @@ def _format_graph_code(name, filename, graph_str):
     return f"TRACED GRAPH\n {name} {filename} {graph_str}\n"


-def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[Dict]:
+def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[dict]:
     """
     Returns the nn_module_stack of the first call_function node.
     """
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import operator
 from collections import deque
-from typing import Deque, Dict, List, NamedTuple, Set, Tuple
+from typing import NamedTuple

 import torch
 from torch.fx.experimental.partitioner_utils import (
@@ -28,15 +28,15 @@ class DAGNode:
     def __init__(
         self,
         submodule_node: Node,
-        input_nodes: List[Node],
-        output_nodes: List[Node],
-        logical_device_ids: List[int],
+        input_nodes: list[Node],
+        output_nodes: list[Node],
+        logical_device_ids: list[int],
         size_bytes: int,
     ) -> None:
         self.submodule_node: Node = submodule_node
-        self.input_nodes: List[Node] = input_nodes
-        self.output_nodes: List[Node] = output_nodes
-        self.logical_device_ids: List[int] = logical_device_ids
+        self.input_nodes: list[Node] = input_nodes
+        self.output_nodes: list[Node] = output_nodes
+        self.logical_device_ids: list[int] = logical_device_ids
         self.size_bytes = size_bytes

     def __str__(self) -> str:
@@ -47,14 +47,14 @@ class DAG:
     """DAG class contains all the DAG nodes"""

     def __init__(self) -> None:
-        self.nodes: List[DAGNode] = []
+        self.nodes: list[DAGNode] = []

     def create_node(
         self,
         submodule_node: Node,
-        input_nodes: List[Node],
-        output_nodes: List[Node],
-        logical_devices: List[int],
+        input_nodes: list[Node],
+        output_nodes: list[Node],
+        logical_devices: list[int],
         size_bytes: int,
     ) -> None:
         node = DAGNode(
@@ -79,7 +79,7 @@ def reset_partition_device(partitions):


 def combine_two_partitions(
-    partition_0: Partition, partition_1: Partition, partitions: List[Partition]
+    partition_0: Partition, partition_1: Partition, partitions: list[Partition]
 ) -> None:
     """Given a list of partitions and its two partitions,
     combine these two partitions into a new one appending to the partitions
@@ -95,7 +95,7 @@ def combine_two_partitions(
     return


-def set_parents_and_children(partitions: List[Partition]) -> None:
+def set_parents_and_children(partitions: list[Partition]) -> None:
     """Given a list of partitions, mark parents and children for each partition"""
     # Go through all nodes in a partition.
     # If a node's user is in other partition,
@@ -119,7 +119,7 @@ def set_parents_and_children(partitions: List[Partition]) -> None:
     return


-def reorganize_partitions(partitions: List[Partition]) -> None:
+def reorganize_partitions(partitions: list[Partition]) -> None:
     """Given a list of partitions, reorganize partition id,
     its parents and its children for each partition
     """
@@ -130,17 +130,17 @@ def reorganize_partitions(partitions: List[Partition]) -> None:
     return


-def get_bfs_level_partition(partitions: List[Partition]) -> None:
+def get_bfs_level_partition(partitions: list[Partition]) -> None:
     """Given a list of partitions,
     mark the bfs level for each partition
     """
-    current_level: Set[Partition] = set()
-    visited: Set[Partition] = set()
+    current_level: set[Partition] = set()
+    visited: set[Partition] = set()
     for partition in partitions:
         # If a partition has no parent, it should be in root level
         if len(partition.parents) == 0:
             current_level.add(partition)
-    next_level: Set[Partition] = set()
+    next_level: set[Partition] = set()
     level = 0
     # bfs
     while current_level:
@@ -158,26 +158,26 @@ def get_bfs_level_partition(partitions: List[Partition]) -> None:
     return


-def get_node_to_partition_mapping(partitions: List[Partition]) -> Dict[Node, int]:
+def get_node_to_partition_mapping(partitions: list[Partition]) -> dict[Node, int]:
     """Given a list of partitions,return node to partition mapping"""
-    node_to_partition: Dict[Node, int] = {}
+    node_to_partition: dict[Node, int] = {}
     for partition in partitions:
         for node in partition.nodes:
             node_to_partition[node] = partition.partition_id
     return node_to_partition


-def get_logical_id_to_device(devices: List[Device]) -> Dict[int, Device]:
+def get_logical_id_to_device(devices: list[Device]) -> dict[int, Device]:
     """Get a mapping from device logical ID to Device object."""
-    logical_id_to_device: Dict[int, Device] = {}
+    logical_id_to_device: dict[int, Device] = {}
     for d in devices:
         logical_id_to_device[d.logical_id] = d
     return logical_id_to_device


 def get_device_partition_stats(
-    partitions: List[Partition], devices: List[Device]
-) -> Tuple[Dict[Device, List[Partition]], Dict[Device, int], List[Partition]]:
+    partitions: list[Partition], devices: list[Device]
+) -> tuple[dict[Device, list[Partition]], dict[Device, int], list[Partition]]:
     """Given a list of partitions and a list of devices, returns:
     1. A mapping from device to partitions on it;
     2. A mapping from device to its remaining memory size;
@@ -186,9 +186,9 @@ def get_device_partition_stats(
     # logical id to device
     logical_id_to_device = get_logical_id_to_device(devices)
     # Track partitions on device
-    device_to_partitions: Dict[Device, List[Partition]] = {}
+    device_to_partitions: dict[Device, list[Partition]] = {}
     # Track device's left mem size
-    device_to_left_mem_bytes: Dict[Device, int] = {}
+    device_to_left_mem_bytes: dict[Device, int] = {}
     for d in devices:
         device_to_partitions[d] = []
         device_to_left_mem_bytes[d] = d.available_mem_bytes
@@ -213,16 +213,16 @@ def get_device_partition_stats(


 def get_device_to_partitions_mapping(
-    partitions: List[Partition], devices: List[Device]
+    partitions: list[Partition], devices: list[Device]
 ):
     """Given a list of partitions and a list of devices,
     map each partition into a device.
     """

     def calculate_extra_mem_bytes_needed_for(
-        partition: Partition, partitions: List[Partition]
+        partition: Partition, partitions: list[Partition]
     ):
-        all_nodes: Set[Node] = set()
+        all_nodes: set[Node] = set()
         for p in partitions:
             all_nodes = all_nodes.union(p.nodes)
         if len(all_nodes) == 0:
@@ -273,8 +273,8 @@ def check_dependency(partition):
     """Given a partition,check if there is a circular dependency on
     this partition using bfs
     """
-    visited: Set[Partition] = {partition}
-    queue: Deque[Partition] = deque([partition])
+    visited: set[Partition] = {partition}
+    queue: deque[Partition] = deque([partition])
     while queue:
         p = queue.popleft()
         for child in p.children:
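Note the queue annotation above: typing.Deque[Partition] becomes deque[Partition], since collections.deque itself is subscriptable from Python 3.9 on. A one-line illustration of the same spelling (illustrative only):

from collections import deque

# PEP 585: the concrete class doubles as the annotation.
queue: deque[int] = deque([1, 2, 3])
queue.append(4)
assert queue.popleft() == 1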
@@ -298,9 +298,9 @@ class Partitioner:
     """

     def __init__(self) -> None:
-        self.partitions: List[Partition] = []
-        self.node_to_partition: Dict[Node, int] = {}
-        self.devices: List[Device] = []
+        self.partitions: list[Partition] = []
+        self.node_to_partition: dict[Node, int] = {}
+        self.devices: list[Device] = []

     def partition_graph(
         self,
@@ -435,9 +435,9 @@ class Partitioner:
             return device

         # Track partition and its left mem size
-        partition_to_left_mem_bytes: Dict[Partition, int] = {}
+        partition_to_left_mem_bytes: dict[Partition, int] = {}
         # Track all the devices that have been used
-        occupied_devices: List[Device] = []
+        occupied_devices: list[Device] = []
         partition = self.create_partition()
         for node in self.graph_module.graph.nodes:
             if node.op in {"call_module", "call_method", "call_function"}:
@@ -516,7 +516,7 @@ class Partitioner:
         # Devices that hold partitions
         used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0]
         # Track replicates of the assigned devices
-        replicated_device_to_used_device: Dict[Device, Device] = {}
+        replicated_device_to_used_device: dict[Device, Device] = {}

         while len(used_devices) * 2 + len(replicated_device_to_used_device) <= len(
             self.devices
@@ -583,7 +583,7 @@ class Partitioner:
                 continue
             if node.target == operator.__getitem__:
                 continue
-            input_nodes: Dict[Node, None] = {}
+            input_nodes: dict[Node, None] = {}
             map_arg(node.args, input_nodes.setdefault)
             map_arg(node.kwargs, input_nodes.setdefault)
             # When a node has two or more output nodes,
@@ -634,7 +634,7 @@ class Partitioner:
         """

         def combine_partitions_based_on_size(
-            partitions: List[Partition], available_mem_bytes: int
+            partitions: list[Partition], available_mem_bytes: int
         ) -> None:
             """Combining small partitions together to keep as less partitions as possible.
             Here is an example of the algorithm to do this:
@@ -672,10 +672,10 @@ class Partitioner:
             return mem_bytes_needed

         def find_partition_to_combine_based_on_size(
-            sorted_partitions: List[Partition],
+            sorted_partitions: list[Partition],
             available_mem_bytes: int,
-            partitions: List[Partition],
-        ) -> Tuple[bool, List[Partition]]:
+            partitions: list[Partition],
+        ) -> tuple[bool, list[Partition]]:
             """step 1 in combine_partition_based_on_size()"""
             find_combination = False
             smallest_partition = sorted_partitions.pop(0)
@@ -721,8 +721,8 @@ class Partitioner:
                 return False

         # Track embedding partitions and non-embedding partitions separately
-        embedding_partitions: List[Partition] = []
-        non_embedding_partitions: List[Partition] = []
+        embedding_partitions: list[Partition] = []
+        non_embedding_partitions: list[Partition] = []
         # A Flag to check the boundary
         in_embedding_region: bool = False
         partition = self.create_partition()
@@ -794,7 +794,7 @@ class Partitioner:
     def cost_aware_partition(
         self,
         transfer_rate_bytes_per_sec: float,
-        node_to_latency_mapping: Dict[Node, NodeLatency],
+        node_to_latency_mapping: dict[Node, NodeLatency],
     ) -> None:
         """This method is to partition the fx module based on the cost.
         The cost is the total latency of running the whole fx module.
@@ -872,7 +872,7 @@ class Partitioner:
             )
             if len(self.partitions) == 1:
                 return False
-            partition_pair: List[int] = []
+            partition_pair: list[int] = []
             for i in range(len(self.partitions) - 1):
                 for j in range(i + 1, len(self.partitions)):
                     # Try to combine the partition pair
@@ -915,7 +915,7 @@ class Partitioner:
     def kl_based_partition(
         self,
         transfer_rate_bytes_per_sec: float,
-        node_to_latency_mapping: Dict[Node, NodeLatency],
+        node_to_latency_mapping: dict[Node, NodeLatency],
     ) -> None:
         """This function is a cost aware partition based
         on Kernighan-Lin algorithm.
@@ -987,7 +987,7 @@ class Partitioner:
         """
         p1_nodes = list(p1.nodes) + [None]
         min_cost = float("inf")
-        node_pair: List[Node] = []
+        node_pair: list[Node] = []
         for n1 in p1_nodes:
             # Ignore the node if it is not a op node
             if n1 is not None and n1.op in {"placeholder", "get_attr"}:
@@ -1011,9 +1011,9 @@ class Partitioner:
             self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec
         )
         # Keep tracking the node pair that shows the better cost
-        node_pair: List[Node] = []
+        node_pair: list[Node] = []
         # Keep tracking the partition pair of node pair
-        partition_pair: List[Partition] = []
+        partition_pair: list[Partition] = []
         # Collect all the op nodes from the graph
         op_nodes = [
             n
@@ -1060,7 +1060,7 @@ class Partitioner:
         """This function helps to rebuild the partitions given the nodes and its
         corresponding partition id
         """
-        partition_id_to_partition_mapping: Dict[int, Partition] = {}
+        partition_id_to_partition_mapping: dict[int, Partition] = {}
         self.node_to_partition = node_to_partition_mapping
         for node in self.node_to_partition:
             partition_id = self.node_to_partition[node]
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import re
-from typing import Callable, Dict, Optional, Set, Union
+from typing import Callable, Optional, Union

 import torch.fx
 from torch.fx.node import map_arg
@@ -100,7 +100,7 @@ def _inline_module(gm: torch.fx.GraphModule, inline_mod_name: str):
     call_mod_args = call_mod_node_to_replace.args
     call_mod_kwargs = call_mod_node_to_replace.kwargs

-    replacement_mapping: Dict[torch.fx.Node, torch.fx.Node] = {}
+    replacement_mapping: dict[torch.fx.Node, torch.fx.Node] = {}
     ph_count = 0

     def replacement_fn(node):
@@ -171,7 +171,7 @@ def split_const_subgraphs(

     # Build up a list of const_nodes, defined as nodes that are themselves
     # get_attrs, or have all get_attr or other constant node inputs.
-    const_nodes: Set[torch.fx.Node] = set()
+    const_nodes: set[torch.fx.Node] = set()
     found_const_folding = False
     for node in mod_traced.graph.nodes:
         # Skip over placeholders/outputs because they can't be const folded and
@@ -1,4 +1,4 @@
-from typing import List, Sequence
+from collections.abc import Sequence

 import torch.fx as fx

@@ -19,7 +19,7 @@ def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
     the `gm` with breakpoint inserted.
     """

-    def insert_pdb(body: Sequence[str]) -> List[str]:
+    def insert_pdb(body: Sequence[str]) -> list[str]:
         return ["import pdb; pdb.set_trace()\n", *body]

     with gm.graph.on_generate_code(
@@ -2,7 +2,7 @@
 import itertools
 import operator
 from functools import reduce
-from typing import Callable, Dict, TypeVar
+from typing import Callable, TypeVar
 from typing_extensions import ParamSpec

 import sympy
@@ -19,9 +19,9 @@ from torch.nn.modules.conv import Conv2d
 _T = TypeVar("_T")
 _P = ParamSpec("_P")

-_INFERENCE_RULES: Dict[Target, Callable] = {}
-_REFINEMENT_RULES: Dict[Target, Callable] = {}
-_RULES: Dict[Target, Callable] = {}
+_INFERENCE_RULES: dict[Target, Callable] = {}
+_REFINEMENT_RULES: dict[Target, Callable] = {}
+_RULES: dict[Target, Callable] = {}

 __all__ = [
     "GraphTypeChecker",
@@ -1,7 +1,6 @@
 # mypy: allow-untyped-defs
 import itertools
 import operator
-from typing import Dict, List, Tuple

 import torch
 from torch.fx._symbolic_trace import symbolic_trace
@@ -10,8 +9,8 @@ from torch.fx.passes.tools_common import legalize_graph


 def split_result_tensors(
-    result: torch.Tensor, inputs: List[torch.Tensor]
-) -> Tuple[torch.Tensor, ...]:
+    result: torch.Tensor, inputs: list[torch.Tensor]
+) -> tuple[torch.Tensor, ...]:
     """
     A free function for use in the merge_matmul graph transformation below that
     splits the output from a merged matmul into the individual results for each
@@ -71,7 +70,7 @@ def may_depend_on(a: Node, b: Node, search_depth: int = 6):
     return False


-def are_nodes_independent(nodes: List[Node]):
+def are_nodes_independent(nodes: list[Node]):
     """
     Check if all of the given nodes are pairwise-data independent.

@@ -102,8 +101,8 @@ def merge_matmul(in_mod: torch.nn.Module):
     """
     gm = symbolic_trace(in_mod)

-    rhs_users: Dict[Node, List[Node]] = {}
-    lhs_users: Dict[Node, List[Node]] = {}
+    rhs_users: dict[Node, list[Node]] = {}
+    lhs_users: dict[Node, list[Node]] = {}

     # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to
     # the matmul of which they are the LHS/RHS.
@@ -2,7 +2,7 @@
 import builtins
 import functools
 import warnings
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Callable, Optional, Union

 import torch
 import torch.fx
@@ -40,7 +40,7 @@ def torch_abs_override(input, *, out=None):
     return input


-manual_meta_overrides: Dict[Callable, Callable] = {
+manual_meta_overrides: dict[Callable, Callable] = {
     torch.nn.Embedding: embedding_override,
     torch.nn.LayerNorm: nn_layernorm_override,
     torch.relu: torch_relu_override,
@@ -274,7 +274,7 @@ class MetaTracer(torch.fx.Tracer):
     def proxy(self, node):
         return MetaProxy(node, self)

-    def trace(self, root, meta_args: Dict[str, torch.Tensor], concrete_args=None):  # type: ignore[override]
+    def trace(self, root, meta_args: dict[str, torch.Tensor], concrete_args=None):  # type: ignore[override]
         assert isinstance(meta_args, dict)
         self.meta_args = meta_args

@@ -299,8 +299,8 @@ class MetaTracer(torch.fx.Tracer):

 def symbolic_trace(
     root: Union[torch.nn.Module, Callable[..., Any]],
-    meta_args: Optional[Dict[str, torch.Tensor]] = None,
-    concrete_args: Optional[Dict[str, Any]] = None,
+    meta_args: Optional[dict[str, torch.Tensor]] = None,
+    concrete_args: Optional[dict[str, Any]] = None,
 ) -> torch.fx.GraphModule:
     tracer = MetaTracer()
     graph = tracer.trace(root, meta_args, concrete_args)  # type: ignore[arg-type]
@@ -1,7 +1,8 @@
 # mypy: allow-untyped-defs
 import operator
 import warnings
-from typing import Callable, Dict, Iterable, TypeVar
+from collections.abc import Iterable
+from typing import Callable, TypeVar
 from typing_extensions import ParamSpec

 import torch
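Alongside the builtin generics, the abstract container types move too: typing.Iterable and typing.Sequence are deprecated aliases of the collections.abc classes, which is why the import above switches to collections.abc. A brief illustration of the same spelling (hedged, not from the diff):

from collections.abc import Iterable, Sequence

def total(xs: Iterable[int]) -> int:
    # Works with any iterable: lists, tuples, generators.
    return sum(xs)

def first(xs: Sequence[str]) -> str:
    # Sequences additionally support indexing.
    return xs[0]

assert total([1, 2, 3]) == 6 and first(["a", "b"]) == "a"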
@@ -57,7 +58,7 @@ from torch.nn.modules.conv import Conv2d
 _T = TypeVar("_T")
 _P = ParamSpec("_P")

-_INFERENCE_RULES: Dict[Target, Callable] = {}
+_INFERENCE_RULES: dict[Target, Callable] = {}

 MAX_TENSOR_RANK = 4

@@ -1,7 +1,7 @@
 # mypy: ignore-errors
 import copy
 import itertools
-from typing import Callable, Dict, List
+from typing import Callable

 from torch.fx.experimental.migrate_gradual_types.constraint import (
     ApplyBroadcasting,
@ -50,7 +50,7 @@ from torch.fx.experimental.migrate_gradual_types.util import (
|
|||||||
from torch.fx.tensor_type import Dyn, TensorType
|
from torch.fx.tensor_type import Dyn, TensorType
|
||||||
|
|
||||||
|
|
||||||
_TRANSFORMATION_RULES: Dict[Constraint, Callable] = {}
|
_TRANSFORMATION_RULES: dict[Constraint, Callable] = {}
|
||||||
|
|
||||||
|
|
||||||
def register_transformation_rule(call_target):
|
def register_transformation_rule(call_target):
|
||||||
@ -797,7 +797,7 @@ def transform_constraint(constraint: Constraint, counter: int):
|
|||||||
return constraint, counter
|
return constraint, counter
|
||||||
|
|
||||||
|
|
||||||
def calc_last_two_dims(constraint, d: List[DVar]):
|
def calc_last_two_dims(constraint, d: list[DVar]):
|
||||||
"""
|
"""
|
||||||
Generates constraints for the last two dimensions of a convolution or a maxpool output
|
Generates constraints for the last two dimensions of a convolution or a maxpool output
|
||||||
Args:
|
Args:
|
||||||
@ -866,7 +866,7 @@ def calc_last_two_dims(constraint, d: List[DVar]):
|
|||||||
return c4, c5
|
return c4, c5
|
||||||
|
|
||||||
|
|
||||||
def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]):
|
def generate_all_int_dyn_dim_possibilities(my_list: list[DVar]):
|
||||||
"""
|
"""
|
||||||
Generate all possibilities of being equal or not equal to dyn for my_list
|
Generate all possibilities of being equal or not equal to dyn for my_list
|
||||||
Args:
|
Args:
|
||||||
@ -888,7 +888,7 @@ def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]):
|
|||||||
return all_possibilities
|
return all_possibilities
|
||||||
|
|
||||||
|
|
||||||
def is_target_div_by_dim(target: List[int], dim: List[DVar]):
|
def is_target_div_by_dim(target: list[int], dim: list[DVar]):
|
||||||
"""
|
"""
|
||||||
Generate constraints to check if the target dimensions are divisible by the input dimensions
|
Generate constraints to check if the target dimensions are divisible by the input dimensions
|
||||||
Args:
|
Args:
|
||||||
@ -901,7 +901,7 @@ def is_target_div_by_dim(target: List[int], dim: List[DVar]):
|
|||||||
return BinConstraintD(BinConstraintD(Prod(target), dim, op_mod), 0, op_eq)
|
return BinConstraintD(BinConstraintD(Prod(target), dim, op_mod), 0, op_eq)
|
||||||
|
|
||||||
|
|
||||||
def is_dim_div_by_target(target: List[int], dim: List[DVar]):
|
def is_dim_div_by_target(target: list[int], dim: list[DVar]):
|
||||||
"""
|
"""
|
||||||
Generate constraints to check if the input dimensions is divisible by the target dimensions
|
Generate constraints to check if the input dimensions is divisible by the target dimensions
|
||||||
Args:
|
Args:
|
||||||
@ -1000,9 +1000,9 @@ def apply_padding(
|
|||||||
e11: BinConstraintT,
|
e11: BinConstraintT,
|
||||||
e2: BinConstraintT,
|
e2: BinConstraintT,
|
||||||
e12: BinConstraintT,
|
e12: BinConstraintT,
|
||||||
d2: List[DVar],
|
d2: list[DVar],
|
||||||
d11: List[DVar],
|
d11: list[DVar],
|
||||||
d12: List[DVar],
|
d12: list[DVar],
|
||||||
counter: int,
|
counter: int,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -1068,7 +1068,7 @@ def apply_padding(
|
|||||||
|
|
||||||
|
|
||||||
def no_broadcast_dim_with_index(
|
def no_broadcast_dim_with_index(
|
||||||
d1: List[DVar], d2: List[DVar], d3: List[DVar], d4: List[DVar], i: int
|
d1: list[DVar], d2: list[DVar], d3: list[DVar], d4: list[DVar], i: int
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -1129,10 +1129,10 @@ def create_equality_constraints_for_broadcasting(
|
|||||||
e2: TVar,
|
e2: TVar,
|
||||||
e11: TVar,
|
e11: TVar,
|
||||||
e12: TVar,
|
e12: TVar,
|
||||||
d1: List[DVar],
|
d1: list[DVar],
|
||||||
d2: List[DVar],
|
d2: list[DVar],
|
||||||
d11: List[DVar],
|
d11: list[DVar],
|
||||||
d12: List[DVar],
|
d12: list[DVar],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create equality constraints for when no broadcasting occurs
|
Create equality constraints for when no broadcasting occurs
|
||||||
@ -1236,7 +1236,7 @@ def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int):
|
|||||||
|
|
||||||
|
|
||||||
def generate_all_broadcasting_possibilities_no_padding(
|
def generate_all_broadcasting_possibilities_no_padding(
|
||||||
d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar]
|
d1: list[DVar], d2: list[DVar], d11: list[DVar], d12: list[DVar]
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension.
|
Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension.
|
||||||
|
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import operator
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Optional
 
 import torch
 import torch.fx
@@ -38,7 +38,7 @@ class NormalizeArgs(Transformer):
         self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True
     ):
         super().__init__(module)
-        self.node_map: Dict[Proxy, Node] = {}
+        self.node_map: dict[Proxy, Node] = {}
         self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs
 
     def run_node(self, n: Node) -> Any:
@@ -66,10 +66,10 @@ class NormalizeArgs(Transformer):
     def call_function(
         self,
         target: Target,
-        args: Tuple[Argument, ...],
-        kwargs: Dict[str, Any],
-        arg_types: Optional[Tuple[Any, ...]] = None,
-        kwarg_types: Optional[Dict[str, Any]] = None,
+        args: tuple[Argument, ...],
+        kwargs: dict[str, Any],
+        arg_types: Optional[tuple[Any, ...]] = None,
+        kwarg_types: Optional[dict[str, Any]] = None,
     ):
         assert callable(target)
         new_args_and_kwargs = normalize_function(
@@ -89,7 +89,7 @@ class NormalizeArgs(Transformer):
             return super().call_function(target, args, kwargs)
 
     def call_module(
-        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+        self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
     ):
         assert isinstance(target, str)
         new_args_and_kwargs = normalize_module(
@@ -124,7 +124,7 @@ class NormalizeOperators(AnnotateTypesWithSchema):
         traced = NormalizeOperators(traced).transform()
     """
 
-    binary_magic_method_remap: Dict[
+    binary_magic_method_remap: dict[
         Callable[[Any, Any], Any], Callable[[Any, Any], Any]
     ] = {
         torch.add: operator.add,
@@ -142,7 +142,7 @@ class NormalizeOperators(AnnotateTypesWithSchema):
     }
 
     def call_function(
-        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+        self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
     ):
         # Normalize operators according to the magic methods implemented on tensors here:
         # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137  # noqa: B950
@@ -4,8 +4,9 @@ import logging
 import operator
 import time
 from collections import defaultdict
+from collections.abc import Iterable
 from enum import Enum
-from typing import Any, cast, Dict, Iterable, List, Optional, Tuple, Type
+from typing import Any, cast, Optional
 
 import torch
 import torch.fx as fx
@@ -33,7 +34,7 @@ __all__ = [
 ]
 
 
-def _parent_name(target: str) -> Tuple[str, str]:
+def _parent_name(target: str) -> tuple[str, str]:
     """
     Splits a qualname into parent path and last atom.
     For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
@@ -44,11 +45,11 @@ def _parent_name(target: str) -> Tuple[str, str]:
 
 # Works for length 2 patterns with 2 modules
 def matches_module_pattern(
-    pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]
+    pattern: Iterable[type], node: fx.Node, modules: dict[str, Any]
 ):
     if len(node.args) == 0:
         return False
-    nodes: Tuple[Any, fx.Node] = (node.args[0], node)
+    nodes: tuple[Any, fx.Node] = (node.args[0], node)
     for expected_type, current_node in zip(pattern, nodes):
         if not isinstance(current_node, fx.Node):
             return False
@@ -64,7 +65,7 @@ def matches_module_pattern(
 
 
 def replace_node_module(
-    node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module
+    node: fx.Node, modules: dict[str, Any], new_module: torch.nn.Module
 ):
     assert isinstance(node.target, str)
     parent_name, name = _parent_name(node.target)
@@ -120,7 +121,7 @@ def remove_dropout(model: nn.Module) -> nn.Module:
 
     class DropoutRemover(torch.fx.Transformer):
         def call_module(
-            self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+            self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
         ) -> Any:
             if isinstance(self.submodules[target], nn.Dropout):
                 assert len(args) == 1
@@ -133,15 +134,15 @@ def remove_dropout(model: nn.Module) -> nn.Module:
 
 def extract_subgraph(
     orig_module: nn.Module,
-    nodes: List[fx.Node],
-    inputs: List[fx.Node],
-    outputs: List[fx.Node],
+    nodes: list[fx.Node],
+    inputs: list[fx.Node],
+    outputs: list[fx.Node],
 ):
     """
     Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph.
     """
     new_graph = fx.Graph()
-    env: Dict[fx.Node, fx.Node] = {}
+    env: dict[fx.Node, fx.Node] = {}
     for input in inputs:
         new_node = new_graph.placeholder(input.name)
         env[input] = new_node
@@ -180,13 +181,13 @@ mkldnn_map = {
 }
 
 
-def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]):
+def modules_to_mkldnn(nodes: list[fx.Node], modules: dict[str, nn.Module]):
     """
     For each node, if it's a module that can be preconverted into MKLDNN,
     then we do so and create a mapping to allow us to convert from the MKLDNN
     version of the module to the original.
     """
-    old_modules: Dict[nn.Module, nn.Module] = {}
+    old_modules: dict[nn.Module, nn.Module] = {}
     for node in nodes:
         if node.op == "call_module":
             assert isinstance(node.target, str)
@@ -200,9 +201,9 @@ def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]):
 
 
 def reset_modules(
-    nodes: List[fx.Node],
-    modules: Dict[str, nn.Module],
-    old_modules: Dict[nn.Module, nn.Module],
+    nodes: list[fx.Node],
+    modules: dict[str, nn.Module],
+    old_modules: dict[nn.Module, nn.Module],
 ):
     """
     Maps each module that's been changed with `modules_to_mkldnn` back to its
@@ -219,9 +220,9 @@ def reset_modules(
 class MklSubgraph:
     def __init__(self, fx_graph: fx.Graph):
         self.fx_graph = fx_graph
-        self.nodes: List[fx.Node] = []
-        self.start_nodes: List[fx.Node] = []
-        self.end_nodes: List[fx.Node] = []
+        self.nodes: list[fx.Node] = []
+        self.start_nodes: list[fx.Node] = []
+        self.end_nodes: list[fx.Node] = []
 
 
 def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
@@ -244,7 +245,7 @@ def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
         old_modules = graph.fx_graph.old_modules  # type: ignore[attr-defined]
         ShapeProp(fx_model).propagate(example_inputs)
         sample_inputs = [torch.randn(node.shape) for node in input_nodes]  # type: ignore[attr-defined]
-        output_args = cast(List[fx.Node], [node.args[0] for node in graph.end_nodes])
+        output_args = cast(list[fx.Node], [node.args[0] for node in graph.end_nodes])
         submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args)
 
         def benchmark(f):
@@ -281,8 +282,8 @@ def use_mkl_length(graph: MklSubgraph) -> bool:
 
 class UnionFind:
     def __init__(self, n):
-        self.parent: List[Optional[int]] = [None] * n
-        self.size: List[int] = [0] * n
+        self.parent: list[Optional[int]] = [None] * n
+        self.size: list[int] = [0] * n
 
     def make_set(self, v: int):
         self.parent[v] = v
@@ -308,8 +309,8 @@ class UnionFind:
 
 def optimize_for_inference(
     model: torch.nn.Module,
-    pass_config: Optional[Dict[str, Any]] = None,
-    tracer: Type[fx.Tracer] = fx.Tracer,
+    pass_config: Optional[dict[str, Any]] = None,
+    tracer: type[fx.Tracer] = fx.Tracer,
 ) -> torch.nn.Module:
     """
     Performs a set of optimization passes to optimize a model for the
@@ -348,7 +349,7 @@ def optimize_for_inference(
     cur_tracer = tracer()
     fx_graph = cur_tracer.trace(copy.deepcopy(model))
     fx.GraphModule(cur_tracer.root, fx_graph)
-    modules: Dict[str, nn.Module] = dict(model.named_modules())
+    modules: dict[str, nn.Module] = dict(model.named_modules())
 
     class MklSupport(Enum):
         NO = 1
@@ -388,7 +389,7 @@ def optimize_for_inference(
                 node.args, lambda n: fx_graph.call_method("to_mkldnn", (n,))
             )
 
-            node.args = cast(Tuple[fx.node.Argument], mkldnn_args)
+            node.args = cast(tuple[fx.node.Argument], mkldnn_args)
 
             with fx_graph.inserting_after(node):
                 dense_x = fx_graph.create_node("call_method", "to_dense", (node,))
@@ -455,7 +456,7 @@ def optimize_for_inference(
         for other_color in cur_colors[1:]:
             uf.join(cur_colors[0], other_color)
 
    mkldnn_graphs: dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph))
-    mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph))
+    mkldnn_graphs: dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph))
     for node in fx_graph.nodes:
        if hasattr(node, "color"):
            mkldnn_graphs[uf.find(node.color)].nodes.append(node)
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 from enum import Enum
-from typing import Dict, List, NamedTuple, Set
+from typing import NamedTuple
 
 from torch.fx.node import map_arg, Node
 
@@ -11,13 +11,13 @@ class Partition:
     """
 
     def __init__(self, partition_id: int) -> None:
-        self.nodes: Set[Node] = set()
+        self.nodes: set[Node] = set()
         self.partition_id = partition_id
-        self.parents: Set[Partition] = set()
-        self.children: Set[Partition] = set()
+        self.parents: set[Partition] = set()
+        self.children: set[Partition] = set()
         self.bfs_level: int = -1
         self.used_mem_bytes: int = 0
-        self.logical_device_ids: List[int] = []
+        self.logical_device_ids: list[int] = []
 
     def __str__(self):
         return str(self.partition_id)
@@ -28,7 +28,7 @@ class Partition:
         self.used_mem_bytes += get_extra_size_of(node, self.nodes)
 
     def add_node(self, node):
-        input_nodes: Dict[Node, None] = {}
+        input_nodes: dict[Node, None] = {}
         map_arg(node.args, input_nodes.setdefault)
         map_arg(node.kwargs, input_nodes.setdefault)
         # Add current node's input nodes if they are placeholder or constants
@@ -43,7 +43,7 @@ class Partition:
         if node in self.nodes:
             self.nodes.remove(node)
         # Collect the node's input nodes
-        input_nodes: Dict[Node, None] = {}
+        input_nodes: dict[Node, None] = {}
         map_arg(node.args, input_nodes.setdefault)
         map_arg(node.kwargs, input_nodes.setdefault)
         # Check if an input node is a placeholder or get_attr,
@@ -88,23 +88,23 @@ class PartitionMode(Enum):
 
 
 class PartitionerConfig(NamedTuple):
-    devices: List[Device]
+    devices: list[Device]
     mode: PartitionMode = PartitionMode.size_based
     transfer_rate_bytes_per_sec: float = 0.0
-    node_to_latency_mapping: Dict[Node, NodeLatency] = {}
-    node_to_partition_mapping: Dict[Node, int] = {}
-    partition_to_logical_device_mapping: Dict[int, List[int]] = {}
+    node_to_latency_mapping: dict[Node, NodeLatency] = {}
+    node_to_partition_mapping: dict[Node, int] = {}
+    partition_to_logical_device_mapping: dict[int, list[int]] = {}
     # Saturate host by replicating partitions to the remaining idle devices.
     saturate_host: bool = False
 
 
-def get_extra_size_of(node: Node, nodes: Set[Node]) -> int:
+def get_extra_size_of(node: Node, nodes: set[Node]) -> int:
     """Given a node and a set of nodes,
     this function return the extra size that needed
     if this node is included in this set.
     """
     # Find all its input nodes
-    input_nodes: Dict[Node, None] = {}
+    input_nodes: dict[Node, None] = {}
     map_arg(node.args, input_nodes.setdefault)
     map_arg(node.kwargs, input_nodes.setdefault)
     # Calculate total size of related nodes
@@ -127,18 +127,18 @@ def get_extra_size_of(node: Node, nodes: Set[Node]) -> int:
 
 
 def get_latency_of_one_partition(
-    partition: Partition, node_to_latency_mapping: Dict[Node, NodeLatency]
+    partition: Partition, node_to_latency_mapping: dict[Node, NodeLatency]
 ) -> PartitionLatency:
     """Given a partition and its nodes' latency, return a PartitionLatency for this partition"""
 
-    def get_top_nodes(partition: Partition) -> List[Node]:
+    def get_top_nodes(partition: Partition) -> list[Node]:
         """Given a partition, return a list of nodes on the top bfs level"""
-        top_nodes: List[Node] = []
+        top_nodes: list[Node] = []
         for node in partition.nodes:
             # Skip placeholder and get_attr nodes
             if node.op in {"placeholder", "get_attr"}:
                 continue
-            input_nodes: Dict[Node, None] = {}
+            input_nodes: dict[Node, None] = {}
             map_arg(node.args, input_nodes.setdefault)
             map_arg(node.kwargs, input_nodes.setdefault)
             # If a node has no input nodes in this partition,
@@ -216,12 +216,12 @@ def get_latency_of_one_partition(
 
 
 def get_partition_to_latency_mapping(
-    partitions: List[Partition], node_to_latency_mapping: Dict[Node, NodeLatency]
-) -> Dict[Partition, PartitionLatency]:
+    partitions: list[Partition], node_to_latency_mapping: dict[Node, NodeLatency]
+) -> dict[Partition, PartitionLatency]:
     """Given all the partitions and node_to_latency_mapping dictionary,
     return a mapping dictionary of each partition to its overall latency
     """
-    partition_to_latency_mapping: Dict[Partition, PartitionLatency] = {}
+    partition_to_latency_mapping: dict[Partition, PartitionLatency] = {}
     # Go through each partition and get its latency
     for partition in partitions:
         partition_latency = get_latency_of_one_partition(
@@ -255,7 +255,7 @@ def get_comm_latency_between(
     # the output size of those input nodes will be counted
     # and added to comm_size
     for node in child_partition.nodes:
-        input_nodes: Dict[Node, None] = {}
+        input_nodes: dict[Node, None] = {}
         map_arg(node.args, input_nodes.setdefault)
         map_arg(node.kwargs, input_nodes.setdefault)
         for n in input_nodes:
@@ -268,8 +268,8 @@ def get_comm_latency_between(
 
 
 def get_latency_of_partitioned_graph(
-    partitions: List[Partition],
-    partition_to_latency_mapping: Dict[Partition, PartitionLatency],
+    partitions: list[Partition],
+    partition_to_latency_mapping: dict[Partition, PartitionLatency],
     transfer_rate_bytes_per_sec: float,
 ):
     """Given all partitions in a graph, find the critical path among all partitions
@@ -298,7 +298,7 @@ def get_latency_of_partitioned_graph(
             return max_latency_sec
         return latency_so_far_sec
 
-    def get_top_partitions(partitions: List[Partition]) -> List[Partition]:
+    def get_top_partitions(partitions: list[Partition]) -> list[Partition]:
         """This function is to return all the partitions without parents
         as the starting points of all the paths
         """
@@ -17,21 +17,15 @@ import typing_extensions
 import warnings
 import weakref
 from collections import defaultdict
+from collections.abc import Generator, Mapping, Sequence
 from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext
 from dataclasses import dataclass
 from typing import (
     Any,
     Callable,
-    Dict,
-    Generator,
-    List,
-    Mapping,
     Optional,
     overload,
     Protocol,
-    Sequence,
-    Tuple,
-    Type,
     TYPE_CHECKING,
     TypeVar,
     Union,
@@ -168,7 +162,7 @@ from torch.types import py_sym_types, PySymType
 
 
 class _HasMeta(Protocol):
-    meta: Dict[str, PySymType]
+    meta: dict[str, PySymType]
 
 
 def is_sym_node(node: _HasMeta) -> bool:
@@ -377,9 +371,9 @@ _ExtractValType = Optional[
     PySymType,
     _AnyScriptObjectType,
     BackwardState,
-    List["_ExtractValType"],
-    Tuple["_ExtractValType", ...],
-    Dict[str, "_ExtractValType"],
+    list["_ExtractValType"],
+    tuple["_ExtractValType", ...],
+    dict[str, "_ExtractValType"],
     Tensor,
     int,
     float,
@@ -767,10 +761,10 @@ def proxy_call(
     proxy_mode: ProxyTorchDispatchMode,
     func: OpOverload,
     pre_dispatch: bool,
-    args: Tuple[object, ...],
-    kwargs: Dict[str, object],
+    args: tuple[object, ...],
+    kwargs: dict[str, object],
 ) -> object:
-    unrecognized_types: List[Type] = []
+    unrecognized_types: list[type] = []
     flat_args_kwargs, spec = pytree.tree_flatten((args, kwargs))
 
     def can_handle_tensor(x: Tensor) -> bool:
@@ -987,7 +981,7 @@ class _SymNodeDict:
     """
 
     def __init__(self) -> None:
-        self.sym_node_dict: Dict[PySymType, _PySymProxyType] = {}
+        self.sym_node_dict: dict[PySymType, _PySymProxyType] = {}
 
     def __setitem__(self, key: PySymType, value: _PySymProxyType) -> None:
         self.sym_node_dict[key.node] = value
@@ -1015,9 +1009,9 @@ class _SymNodeDict:
 class PythonKeyTracer(Tracer):
     script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy]
     symnode_tracker: _SymNodeDict
-    sympy_expr_tracker: Dict[sympy.Symbol, object]
+    sympy_expr_tracker: dict[sympy.Symbol, object]
     tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
-    torch_fn_counts: Dict[OpOverload, int]
+    torch_fn_counts: dict[OpOverload, int]
     enable_thunkify: bool = False
 
     def __init__(self) -> None:
@@ -1043,14 +1037,14 @@ class PythonKeyTracer(Tracer):
         self,
         m: Module,
         forward: Callable[..., Any],
-        args: Tuple[Any, ...],
-        kwargs: Dict[str, Any],
+        args: tuple[Any, ...],
+        kwargs: dict[str, Any],
     ) -> Any:
         return forward(*args, **kwargs)
 
     # We don't want to turn getattr calls into proxies. So we just return the actual value.
     def getattr(
-        self, attr: str, attr_val: object, parameter_proxy_cache: Dict[str, Proxy]
+        self, attr: str, attr_val: object, parameter_proxy_cache: dict[str, Proxy]
     ) -> object:
         return attr_val
 
@@ -1095,7 +1089,7 @@ class PythonKeyTracer(Tracer):
 
 
 def _make_temp_remove_mode_context_manager(
-    mode_ty: Type[TorchFunctionMode],
+    mode_ty: type[TorchFunctionMode],
 ) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]:
     @contextmanager
     def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]:
@@ -1137,7 +1131,7 @@ def _make_temp_remove_mode_context_manager(
 def dispatch_trace(
     root: Union[Module, Callable],
     tracer: Tracer,
-    concrete_args: Optional[Tuple[Any, ...]] = None,
+    concrete_args: Optional[tuple[Any, ...]] = None,
 ) -> GraphModule:
     graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
 
@@ -1235,9 +1229,9 @@ class TorchFunctionMetadataMode(TorchFunctionMode):
     def __torch_function__(
         self,
         func: OpOverload,
-        types: Tuple[torch._C._TensorMeta, ...],
-        args: Tuple[object, ...] = (),
-        kwargs: Optional[Dict[str, object]] = None,
+        types: tuple[torch._C._TensorMeta, ...],
+        args: tuple[object, ...] = (),
+        kwargs: Optional[dict[str, object]] = None,
     ) -> object:
         kwargs = kwargs or {}
         self.tracer.torch_fn_metadata = func
@@ -1259,14 +1253,14 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
         # The input to torch.amp.autocast_mode._exit_autocast graph node should be the
        # enter_autocast node. So we have to save the enter autocast node here, and assign it
        # to the exit_autocast call_function node.
-        self.enter_autocast_nodes: List[torch.fx.Node] = []
+        self.enter_autocast_nodes: list[torch.fx.Node] = []
 
     def __torch_function__(
         self,
         func: Union[OpOverload, Callable],
-        types: Tuple[torch._C._TensorMeta, ...],
-        args: Tuple[object, ...] = (),
-        kwargs: Optional[Dict[str, object]] = None,
+        types: tuple[torch._C._TensorMeta, ...],
+        args: tuple[object, ...] = (),
+        kwargs: Optional[dict[str, object]] = None,
     ) -> object:
         kwargs = kwargs or {}
         if func in _side_effectful_need_to_be_preserved_pre_dispatch:
@@ -1324,7 +1318,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
         # Every time we enter a mode, we maintain a stack telling us what the previous
         # ProxyTorchDispatchMode state was (if there was any).
        # This lets us properly reset the state on exit.
-        self.enter_stack: List[Optional[ProxyTorchDispatchMode]] = []
+        self.enter_stack: list[Optional[ProxyTorchDispatchMode]] = []
         self.decomp_layers = 0
         from torch._inductor import config
 
@@ -1334,9 +1328,9 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
     def __torch_dispatch__(
         self,
         func: OpOverload,
-        types: Tuple[torch._C._TensorMeta, ...],
-        args: Tuple[object, ...] = (),
-        kwargs: Optional[Dict[str, object]] = None,
+        types: tuple[torch._C._TensorMeta, ...],
+        args: tuple[object, ...] = (),
+        kwargs: Optional[dict[str, object]] = None,
     ) -> object:
         with set_original_aten_op(func):
             kwargs = kwargs or {}
@@ -1354,7 +1348,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
 
     def __exit__(
         self,
-        exc_type: Optional[Type[BaseException]],
+        exc_type: Optional[type[BaseException]],
         exc_value: Optional[BaseException],
         traceback: Optional[types.TracebackType],
     ) -> Optional[bool]:
@@ -1372,10 +1366,10 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
         return True
 
     def _compute_proxy(
-        self, func: OpOverload, args: Tuple[object, ...], out: PySymType
+        self, func: OpOverload, args: tuple[object, ...], out: PySymType
     ) -> Proxy:
         # Handle torch.sym_sum
-        n_args: Tuple[object, ...]
+        n_args: tuple[object, ...]
         if len(args) == 1 and isinstance(args[0], (list, tuple)):
             n_args = (
                 tuple(
@@ -1403,9 +1397,9 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
     def __sym_dispatch__(
         self,
         func: OpOverload,
-        types: Tuple[torch._C._TensorMeta, ...],
-        args: Tuple[object, ...],
-        kwargs: Dict[str, object],
+        types: tuple[torch._C._TensorMeta, ...],
+        args: tuple[object, ...],
+        kwargs: dict[str, object],
     ) -> object:
         # Peephole optimize multiply by one
         # NB: be careful not to trigger guards here!
@@ -1438,9 +1432,9 @@ class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer):
     script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy]
     symnode_tracker: MutableMapping[PySymType, _PySymProxyType]
     tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
-    sympy_expr_tracker: Dict[sympy.Symbol, object]
+    sympy_expr_tracker: dict[sympy.Symbol, object]
     torch_fn_metadata: Optional[OpOverload]
-    torch_fn_counts: Dict[OpOverload, int]
+    torch_fn_counts: dict[OpOverload, int]
     enable_thunkify: bool = False
 
     def __init__(self, graph: fx.graph.Graph) -> None:
@@ -1476,7 +1470,7 @@ class DecompositionInterpreter(fx.Interpreter):
         self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")
 
     def placeholder(
-        self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]  # type: ignore[override]
+        self, target: str, args: tuple[object, ...], kwargs: dict[str, object]  # type: ignore[override]
     ) -> object:
         out = super().placeholder(target, args, kwargs)  # type: ignore[arg-type]
         proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer)
@@ -1485,7 +1479,7 @@ class DecompositionInterpreter(fx.Interpreter):
         return out
 
     def get_attr(
-        self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]  # type: ignore[override]
+        self, target: str, args: tuple[object, ...], kwargs: dict[str, object]  # type: ignore[override]
     ) -> object:
         out = super().get_attr(target, args, kwargs)  # type: ignore[arg-type]
         proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer)
@@ -1495,7 +1489,7 @@ class DecompositionInterpreter(fx.Interpreter):
     # call_function, call_method, call_module get traced automatically by the outer mode.
 
     def output(
-        self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object]  # type: ignore[override]
+        self, target: str, args: tuple[object, ...], kwargs: dict[str, object]  # type: ignore[override]
     ) -> object:
         out = super().output(target, args, kwargs)  # type: ignore[arg-type]
 
@@ -1516,13 +1510,13 @@ class DecompositionInterpreter(fx.Interpreter):
 
 
 def wrapper_and_args_for_make_fx(
-    func: Callable[..., R], args: Tuple[object, ...], kwargs: Dict[str, object]
-) -> Tuple[Callable[[List[object]], R], List[object]]:
+    func: Callable[..., R], args: tuple[object, ...], kwargs: dict[str, object]
+) -> tuple[Callable[[list[object]], R], list[object]]:
     # make_fx doesn't support kwargs, so we need to do this flattening
     # and then unflatten the args before calling func
     flat_args, spec = pytree.tree_flatten((args, kwargs))
 
-    def wrapped(flat_args: List[object]) -> R:
+    def wrapped(flat_args: list[object]) -> R:
         fn_args, fn_kwargs = pytree.tree_unflatten(flat_args, spec)
         return func(*fn_args, **fn_kwargs)
 
@@ -1642,7 +1636,7 @@ class _ModuleStackTracer(PythonKeyTracer):
             return tracer.proxy_modules[self]
 
         @property
-        def _modules(self) -> Dict[str, AttrProxy]:
+        def _modules(self) -> dict[str, AttrProxy]:
             assert "_modules" in self.__dict__
             submodules = self.__dict__["_modules"]
             assert isinstance(submodules, dict)
@@ -1674,7 +1668,7 @@ class _ModuleStackTracer(PythonKeyTracer):
             raise _ModuleNotInstalledAsSubmoduleError from e
 
     def getattr(
-        self, attr: str, attr_val: object, parameter_proxy_cache: Dict[str, Proxy]
+        self, attr: str, attr_val: object, parameter_proxy_cache: dict[str, Proxy]
    ) -> object:
         if (
             not isinstance(attr_val, Module)
@@ -1693,7 +1687,7 @@ class _ModuleStackTracer(PythonKeyTracer):
         return self.attr_proxy_map[attr_val]
 
     def trace(  # type: ignore[override]
-        self, root: Union[Module, Callable], concrete_args: Optional[Dict[str, object]]
+        self, root: Union[Module, Callable], concrete_args: Optional[dict[str, object]]
     ) -> fx.Graph:
         res = super().trace(root, concrete_args)
 
@@ -1702,7 +1696,7 @@ class _ModuleStackTracer(PythonKeyTracer):
         # to the tracer while tracing, the proxy object gets registered
         # first. So we need to replace the proxy modules with the real ones
        # This can happen during HOO tracing
-        proxy_module_names_to_be_replaced: List[Tuple[str, _AttrProxy]] = []
+        proxy_module_names_to_be_replaced: list[tuple[str, _AttrProxy]] = []
         for name, module in self.root.named_modules():
             if module in self.proxy_modules:
                 proxy_module_names_to_be_replaced.append((name, module))
@@ -1746,8 +1740,8 @@ class _ModuleStackTracer(PythonKeyTracer):
         self,
         m: Module,
         forward: Callable,
-        args: Tuple[object, ...],
-        kwargs: Dict[str, object],
+        args: tuple[object, ...],
+        kwargs: dict[str, object],
     ) -> None:
         """PythonKeyTracer overrides call_module to avoid the scope handling,
         but we actually want it.
@@ -1857,7 +1851,7 @@ class _MakefxTracer:
     ) -> None:
         # Configurations that are used to initialize the context managers and their states.
         # Should not modify them during tracing.
-        self.decomposition_table: Dict[OpOverload, Callable] = dict(
+        self.decomposition_table: dict[OpOverload, Callable] = dict(
             decomposition_table or {}
         )
         self.decomposition_table.setdefault(
@@ -1885,7 +1879,7 @@ class _MakefxTracer:
             nullcontext, TorchFunctionMetadataMode
         ] = nullcontext()
 
-    def _checkpoint_modes(self) -> List[Any]:
+    def _checkpoint_modes(self) -> list[Any]:
         return [
             self.fake_tensor_mode,
             self.proxy_mode,
@@ -1913,7 +1907,7 @@ class _MakefxTracer:
 
     @contextmanager
     def _init_modes_from_inputs(
-        self, f: Callable, args: Tuple[object, ...]
+        self, f: Callable, args: tuple[object, ...]
     ) -> Generator[None, None, None]:
         prev_modes = self._checkpoint_modes()
         try:
@@ -2202,7 +2196,7 @@ def make_fx(
     return wrapped
 
 
-def get_torch_dispatch_modes() -> List[TorchDispatchMode]:
+def get_torch_dispatch_modes() -> list[TorchDispatchMode]:
     return torch.utils._python_dispatch._get_current_dispatch_mode_stack()
 
 
@@ -2240,7 +2234,7 @@ def handle_sym_dispatch(func: Callable[_P, R], args: _P.args, kwargs: _P.kwargs)
     # dispatch machinery which disables it for us
     with disable_proxy_modes_tracing():
         # TODO: properly compute types
-        types: List[Type] = []
+        types: list[type] = []
         return mode.__sym_dispatch__(func, types, args, kwargs)  # type: ignore[arg-type, return-value]
 
 
@@ -2252,8 +2246,8 @@ def disable_proxy_modes_tracing() -> Generator[ProxyTorchDispatchMode, None, Non
 def maybe_handle_decomp(
     proxy_mode: ProxyTorchDispatchMode,
     op: OpOverload,
-    args: Tuple[object, ...],
-    kwargs: Dict[str, object],
+    args: tuple[object, ...],
+    kwargs: dict[str, object],
 ) -> object:
     from torch._inductor.compiler_bisector import CompilerBisector
 
@@ -2274,8 +2268,8 @@ def maybe_handle_decomp(
 
 def get_isolated_graphmodule(
     func: Callable,
-    args: Tuple[object, ...],
-    kwargs: Dict[str, object],
+    args: tuple[object, ...],
+    kwargs: dict[str, object],
     tracing_mode: str = "real",
     decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
 ) -> GraphModule:
@@ -4,7 +4,7 @@ import inspect
 import itertools
 import logging
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 import torch.utils._pytree as pytree
@@ -83,11 +83,11 @@ class ShapeEnvEvent:
     f: Callable
 
     # Arguments and keyword arguments called with.
-    args: Optional[List[Any]] = None
-    kwargs: Optional[Dict[str, Any]] = None
+    args: Optional[list[Any]] = None
+    kwargs: Optional[dict[str, Any]] = None
 
     # List of tracked_fakes at the time the method was called.
-    tracked_fakes: Optional[List[Any]] = None
+    tracked_fakes: Optional[list[Any]] = None
 
     # Name of the captured event.
     # Used for special handling of particular methods.
@@ -344,15 +344,15 @@ def replay_shape_env_events(events):
 # ShapeEnv.produce_guards.
 @dataclass
 class FakeTensorMeta:
-    tensor_size: Tuple[Union[int, torch.SymInt], ...]
-    tensor_stride: Tuple[Union[int, torch.SymInt], ...]
+    tensor_size: tuple[Union[int, torch.SymInt], ...]
+    tensor_stride: tuple[Union[int, torch.SymInt], ...]
     tensor_storage_offset: Union[int, torch.SymInt]
     is_nested: bool
 
-    def size(self) -> Tuple[Union[int, torch.SymInt], ...]:
+    def size(self) -> tuple[Union[int, torch.SymInt], ...]:
         return self.tensor_size
 
-    def stride(self) -> Tuple[Union[int, torch.SymInt], ...]:
+    def stride(self) -> tuple[Union[int, torch.SymInt], ...]:
         return self.tensor_stride
 
     def storage_offset(self) -> Union[int, torch.SymInt]:
@@ -445,7 +445,7 @@ def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value)
     # compare the two values.
     def compare_vars(
         map_value: Callable[[str, Any], Any]
-    ) -> List[Tuple[str, str, str]]:
+    ) -> list[tuple[str, str, str]]:
         env1_set, env2_set = set(env1_vars), set(env2_vars)
 
         # First, compare the set of keys in each vars dictionary.
@@ -489,7 +489,7 @@ class NotEqualError(Exception):
     def __init__(
         self,
         msg: str,
-        mismatched: List[Tuple[str, str, str]],
+        mismatched: list[tuple[str, str, str]],
     ) -> None:
         details = "\n".join(
             [
@@ -6,7 +6,7 @@ import functools
 import inspect
 import textwrap
 from types import FunctionType
-from typing import Any, Callable, cast, Dict, Optional, Union
+from typing import Any, Callable, cast, Optional, Union
 
 import torch
 from torch._sources import normalize_source_lines
@@ -112,7 +112,7 @@ class RewritingTracer(Tracer):
     def trace(
         self,
         root: Union[torch.nn.Module, Callable],
-        concrete_args: Optional[Dict[str, Any]] = None,
+        concrete_args: Optional[dict[str, Any]] = None,
     ) -> Graph:
         return super().trace(_rewrite(root), concrete_args)
 
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import inspect
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Optional
 
 import torch
 import torch.fx
@@ -42,7 +42,7 @@ class AnnotateTypesWithSchema(Transformer):
         self.annotate_get_attrs = annotate_get_attrs
 
     def call_function(
-        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+        self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
     ):
         python_ret_type = None
         if self.annotate_functionals and target.__module__ == "torch.nn.functional":
@@ -73,7 +73,7 @@ class AnnotateTypesWithSchema(Transformer):
         return return_proxy
 
     def call_module(
-        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+        self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
     ):
         python_ret_type = None
         assert isinstance(target, str)
@@ -91,8 +91,8 @@ class AnnotateTypesWithSchema(Transformer):
     def get_attr(
         self,
         target: torch.fx.node.Target,
-        args: Tuple[Argument, ...],
-        kwargs: Dict[str, Any],
+        args: tuple[Argument, ...],
+        kwargs: dict[str, Any],
     ):
         attr_proxy = super().get_attr(target, args, kwargs)
 
@@ -1,5 +1,6 @@
 import re
-from typing import Any, DefaultDict, Dict, List, Tuple, Union
+from collections import defaultdict
+from typing import Any, Union

 import numpy as np
 import sympy as sp
@@ -13,10 +14,10 @@ s_pattern = r"s\d+"


 def infer_symbol_values(
-    symints: List[Union[torch.SymInt, int]],
-    init_symints: List[Union[torch.SymInt, int]],
-    symbol_idx_dict: Dict[str, int],
-    padding_constraints: DefaultDict[torch.SymInt, List[Union[sp.Expr, int]]],
+    symints: list[Union[torch.SymInt, int]],
+    init_symints: list[Union[torch.SymInt, int]],
+    symbol_idx_dict: dict[str, int],
+    padding_constraints: defaultdict[torch.SymInt, list[Union[sp.Expr, int]]],
     constraint: str,
 ) -> None:
     if constraint.find("non-singleton") != -1:
@@ -83,8 +84,8 @@ def infer_symbol_values(
 def calculate_value(
     left_expression: Union[str, Any, None],
     right_expression: Union[str, Any, None],
-    symints: List[Union[torch.SymInt, int]],
-    symbol_idx_dict: Dict[str, int],
+    symints: list[Union[torch.SymInt, int]],
+    symbol_idx_dict: dict[str, int],
 ) -> None:
     var, val = solve_equation(left_expression, right_expression)
     idx = symbol_idx_dict[var]
@@ -95,7 +96,7 @@ def calculate_value(
 def solve_equation(
     left_expression: Union[str, Any, None],
     right_expression: Union[str, Any, None],
-) -> Tuple[str, int]:
+) -> tuple[str, int]:
     expression = f"{left_expression} - {right_expression}"
     var = re.findall(s_pattern, expression)[0]
     if re.findall(parentheses_pattern, expression):
@@ -116,9 +117,9 @@ def solve_equation(


 def update_equation(
-    symints: List[Union[torch.SymInt, int]],
-    init_symints: List[Union[torch.SymInt, int]],
-    padding_constraints: DefaultDict[torch.SymInt, List[Union[sp.Expr, int]]],
+    symints: list[Union[torch.SymInt, int]],
+    init_symints: list[Union[torch.SymInt, int]],
+    padding_constraints: defaultdict[torch.SymInt, list[Union[sp.Expr, int]]],
     init_eq: sp.Expr,
     new_mod_num: int,
     var: torch.SymInt,
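The hunks above swap typing.List/Dict/Tuple/DefaultDict for the runtime classes themselves, which PEP 585 made subscriptable on Python 3.9+. A minimal, standalone sketch of that substitution pattern (illustrative names only, not part of the diff):

from collections import defaultdict

# PEP 585: builtin containers accept subscripts at runtime on Python 3.9+,
# so the typing.List/Dict/Tuple/DefaultDict aliases are no longer needed.
constraints: defaultdict[str, list[int]] = defaultdict(list)
constraints["s0"].append(4)

def lookup(table: dict[str, int], keys: tuple[str, ...]) -> list[int]:
    # Plain dict/tuple/list replace typing.Dict/Tuple/List in annotations.
    return [table[k] for k in keys]

print(lookup({"a": 1, "b": 2}, ("a", "b")))  # [1, 2]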
@@ -20,7 +20,7 @@ import math
 import operator
 import sys
 from functools import lru_cache, update_wrapper
-from typing import Optional, Type, TYPE_CHECKING, Union
+from typing import Optional, TYPE_CHECKING, Union

 import torch

@@ -1272,7 +1272,7 @@ def _make_node_magic(method, func):
             log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr)
             raise
         sym_node_log.debug("%s %s %s -> %s", method, self.expr, other.expr, out)
-        pytype: Type
+        pytype: type
         # This is not strictly correct. In Python, a**b may return complex when
         # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This
         # returns a float while both arguments are ints: 2**(-1). Also, max and
@@ -1335,7 +1335,7 @@ def _make_node_magic(method, func):
         out_hint = None
         if self.hint is not None:
             out_hint = op(self.hint)
-        pytype: Type
+        pytype: type
         if method in always_int_magic_methods:
             pytype = int
         elif method in always_bool_magic_methods:
@@ -1485,7 +1485,7 @@ def _make_node_sizes_strides(method, func):
             out_hint = op(size_hints, stride_hints)

         # NB: This is the indicator function, not the actual bool!
-        pytype: Type
+        pytype: type
         if method.endswith("_indicator"):
             pytype = int
         else:
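In the hunks above, typing.Type becomes the builtin type; the two are interchangeable as annotations, and type[X] is a valid runtime subscript on Python 3.9+. A small illustrative sketch (assumed names, not taken from the diff):

# The bare builtin `type` annotates "some class object", as typing.Type did.
pytype: type
pytype = int
assert isinstance(3, pytype)

def wrap(cls: type[Exception], msg: str) -> Exception:
    # type[Exception] accepts Exception or any subclass (PEP 585 generic).
    return cls(msg)

print(wrap(ValueError, "bad value"))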
@@ -23,7 +23,8 @@ import re
 import sys
 import threading
 import traceback
-from collections import defaultdict
+from collections import Counter, defaultdict
+from collections.abc import Iterator, Mapping, Sequence
 from contextlib import _GeneratorContextManager, contextmanager
 from dataclasses import dataclass, field
 from enum import Enum
@@ -31,19 +32,9 @@ from typing import (
     Any,
     Callable,
     cast,
-    Counter,
-    DefaultDict,
-    Dict,
-    Iterator,
-    List,
-    Mapping,
     NamedTuple,
     NoReturn,
     Optional,
-    Sequence,
-    Set,
-    Tuple,
-    Type,
     TYPE_CHECKING,
     TypeVar,
     Union,
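Every name pruned from the typing import above has a PEP 585 replacement: the builtins (dict, list, set, tuple, type) and the collections / collections.abc classes (Counter, defaultdict, Iterator, Mapping, Sequence) are generic at runtime on Python 3.9+. A short, self-contained sketch of the replacements (illustrative only):

from collections import Counter
from collections.abc import Iterator, Mapping, Sequence

def histogram(items: Sequence[str]) -> Counter[str]:
    # collections.Counter is generic itself; typing.Counter is unnecessary.
    return Counter(items)

def walk(table: Mapping[str, int]) -> Iterator[tuple[str, int]]:
    # collections.abc.Mapping / Iterator replace the typing aliases.
    yield from table.items()

print(histogram(["a", "b", "a"]))    # Counter({'a': 2, 'b': 1})
print(list(walk({"x": 1, "y": 2})))  # [('x', 1), ('y', 2)]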
@@ -104,8 +95,8 @@ if TYPE_CHECKING:
     from torch.types import BoolLikeType


-InputList = List
-DimList = List
+InputList = list
+DimList = list

 log = logging.getLogger(__name__)

@@ -236,8 +227,8 @@ class SymIntEqByExpr:


 def _nested_int_aware_sort(
-    tup: Tuple[Union[SymInt, int], int]
-) -> Tuple[int, Union[SymInt, int], int]:
+    tup: tuple[Union[SymInt, int], int]
+) -> tuple[int, Union[SymInt, int], int]:
     return (
         # Order nested ints by their coefficients.
         # 1 here to order nested ints after non-nested-ints.
@@ -289,7 +280,7 @@ def lru_cache(
 # These are modules that contain generic code for interacting with ShapeEnv
 # which are unlikely to identify a particular interesting guard statement
 @lru_cache(None)
-def uninteresting_files() -> Set[str]:
+def uninteresting_files() -> set[str]:
     import torch._compile
     import torch._dynamo.eval_frame
     import torch._inductor.sizevars
@@ -332,8 +323,8 @@ def has_symbolic_sizes_strides(elem: torch.Tensor) -> bool:
 Int: TypeAlias = Union[torch.SymInt, int]


-def create_contiguous(shape: Sequence[Int]) -> List[Int]:
-    strides: List[Int] = [1]
+def create_contiguous(shape: Sequence[Int]) -> list[Int]:
+    strides: list[Int] = [1]
     for dim in reversed(shape[:-1]):
         strides.append(dim * strides[-1])  # type: ignore[operator]
     return list(reversed(strides))
@@ -461,15 +452,15 @@ def check_consistent(new: _T, old: _T) -> None:

 def resolve_unbacked_bindings(
     shape_env: Optional[ShapeEnv],
-    bindings: Optional[Dict[sympy.Symbol, pytree.KeyPath]],
-) -> Optional[Dict[sympy.Symbol, pytree.KeyPath]]:
+    bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]],
+) -> Optional[dict[sympy.Symbol, pytree.KeyPath]]:
     if bindings is None:
         return None
     assert shape_env is not None
     return {shape_env.unbacked_renamings.get(k, k): v for k, v in bindings.items()}


-Result: TypeAlias = Union[torch.Tensor, Tuple[torch.Tensor, ...]]
+Result: TypeAlias = Union[torch.Tensor, tuple[torch.Tensor, ...]]


 def rebind_unbacked(
@@ -557,7 +548,7 @@ def rebind_unbacked(
                 and len(raw_u1.args) == 2
                 and (
                     raw_u1_args0 := cast(
-                        Tuple[sympy.Basic, sympy.Basic], raw_u1.args[0]
+                        tuple[sympy.Basic, sympy.Basic], raw_u1.args[0]
                     )
                 )
                 and raw_u1_args0[0] == 1
@@ -565,7 +556,7 @@ def rebind_unbacked(
                 and isinstance(new_raw_u1 := eq.lhs, sympy.Symbol)
                 and shape_env.var_to_range[new_raw_u1].issubset(ValueRanges(0, 1))
                 and eq.rhs == 1
-                and cast(Tuple[sympy.Basic, sympy.Basic], raw_u1.args[1]) == (0, True)
+                and cast(tuple[sympy.Basic, sympy.Basic], raw_u1.args[1]) == (0, True)
             ):
                 # This is what the pattern match above is testing
                 repacked = _sympy_cast_symbool_to_symint_guardless(
@@ -645,8 +636,8 @@ def canonicalize_bool_expr(expr: _T) -> _T:


 def _sympy_from_args(
-    cls: Union[Type[sympy.Add], Type[sympy.Mul]],
-    args: List[sympy.Expr],
+    cls: type[Union[sympy.Add, sympy.Mul]],
+    args: list[sympy.Expr],
     sort: bool = True,
     is_commutative: Optional[bool] = None,
 ) -> sympy.Expr:
@@ -686,7 +677,7 @@ def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean:
         return type(expr)(*map(canonicalize_bool_expr, expr.args))

     opposite = {sympy.Gt: sympy.Lt, sympy.Ge: sympy.Le}
-    t: Union[Type[Any]]
+    t: Union[type[Any]]
     if isinstance(expr, tuple(opposite.keys())):
         rhs = expr.lhs - expr.rhs  # type: ignore[attr-defined]
         t = opposite[type(expr)]  # type: ignore[index]
@@ -888,7 +879,7 @@ def is_symbol_binding_fx_node(node: torch.fx.Node) -> Optional[sympy.Symbol]:

 def find_symbol_binding_fx_nodes(
     graph: torch.fx.Graph,
-) -> Dict[sympy.Symbol, torch.fx.Node]:
+) -> dict[sympy.Symbol, torch.fx.Node]:
     r = {}
     # NB: Prefer first occurrence of symbol
     for node in graph.nodes:
@@ -949,7 +940,7 @@ def compute_unbacked_bindings(
     example_value: object,
     old_example_value: Optional[object] = None,
     peek: bool = False,
-) -> Optional[Dict[sympy.Symbol, pytree.KeyPath]]:
+) -> Optional[dict[sympy.Symbol, pytree.KeyPath]]:
     """
     After having run fake tensor propagation and producing example_value
     result, traverse example_value looking for freshly bound unbacked
@@ -977,7 +968,7 @@ def compute_unbacked_bindings(

     def free_unbacked_symbols_with_path(
         a: object, path: pytree.KeyPath, real: Optional[object] = None
-    ) -> Dict[sympy.Symbol, pytree.KeyPath]:
+    ) -> dict[sympy.Symbol, pytree.KeyPath]:
         assert shape_env is not None
         r = {}
         if isinstance(a, (tuple, list)):
@@ -1456,11 +1447,11 @@ def guard_float(a: Union[SymFloat, float]) -> float:


 # Given a GraphModule, return all the FakeTensors for all the placeholders
-def fx_placeholder_vals(gm: torch.fx.GraphModule) -> List[object]:
+def fx_placeholder_vals(gm: torch.fx.GraphModule) -> list[object]:
     return [n.meta["val"] for n in gm.graph.nodes if n.op == "placeholder"]


-def fx_placeholder_targets(gm: torch.fx.GraphModule) -> List[str]:
+def fx_placeholder_targets(gm: torch.fx.GraphModule) -> list[str]:
     return [n.target for n in gm.graph.nodes if n.op == "placeholder"]


@@ -1475,7 +1466,7 @@ def eval_guards(
     )


-def bind_symbols(gm: torch.fx.GraphModule, *args: Tensor) -> Dict[sympy.Symbol, int]:
+def bind_symbols(gm: torch.fx.GraphModule, *args: Tensor) -> dict[sympy.Symbol, int]:
     return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args)  # type: ignore[operator, union-attr]


@@ -1617,15 +1608,15 @@ class EqualityConstraint(Constraint):
     form and so the problem reduces to symbolic expression equality.)
     """

-    source_pairs: List[Tuple[Source, Source]]
-    derived_equalities: List[
-        Tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]]
+    source_pairs: list[tuple[Source, Source]]
+    derived_equalities: list[
+        tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]]
     ]
-    phantom_symbols: List[sympy.Symbol]
-    relaxed_sources: Set[Source]
+    phantom_symbols: list[sympy.Symbol]
+    relaxed_sources: set[Source]

-    _parents: Dict[Source, Source] = field(init=False)
-    _defs: Dict[Source, sympy.Expr] = field(init=False)
+    _parents: dict[Source, Source] = field(init=False)
+    _defs: dict[Source, sympy.Expr] = field(init=False)

     def __post_init__(self) -> None:
         """
@@ -1643,12 +1634,12 @@ class EqualityConstraint(Constraint):

         # self._parents is a map from input sources to input sources where, conceptually,
         # these are directed edges in a union-find forest
-        _parents: Dict[Source, Source] = {}
+        _parents: dict[Source, Source] = {}
         object.__setattr__(self, "_parents", _parents)
         # self._defs is a map from input sources to "canonical" symbolic expressions,
         # i.e., unary expressions with symbols that corresponds to regular Dims (i.e.,
         # not derived Dims)
-        _defs: Dict[Source, sympy.Expr] = {}
+        _defs: dict[Source, sympy.Expr] = {}
         object.__setattr__(self, "_defs", _defs)

         for source1, source2 in self.source_pairs:
@@ -1838,7 +1829,7 @@ class StatefulSymbolicContext(StatelessSymbolicContext):
     # cause it to fail with unknown symbols, as the symbols cached here will skip creation, and never
     # get recorded in var_to_val, etc.
     # TODO(voz): consider a weakref to the shape_env here
-    shape_env_to_source_to_symbol_cache: Dict[int, Dict[str, sympy.Expr]] = None  # type: ignore[assignment]
+    shape_env_to_source_to_symbol_cache: dict[int, dict[str, sympy.Expr]] = None  # type: ignore[assignment]

     def __post_init__(self) -> None:
         super().__post_init__()
@@ -1856,7 +1847,7 @@ class SubclassSymbolicContext(StatefulSymbolicContext):
     flexibility, with inner symbolic contexts mapped via attr -> symbolic context.
     """

-    inner_contexts: Dict[str, SymbolicContext] = None  # type: ignore[assignment]
+    inner_contexts: dict[str, SymbolicContext] = None  # type: ignore[assignment]

     def __post_init__(self) -> None:
         super().__post_init__()
@@ -1875,7 +1866,7 @@ def is_symbolic(
 IndicatorTypes = (IsNonOverlappingAndDenseIndicator,)


-def _expandsums(args: List[sympy.Expr]) -> Tuple[sympy.Expr, bool]:
+def _expandsums(args: list[sympy.Expr]) -> tuple[sympy.Expr, bool]:
     adds, other = [], []
     for arg in args:
         if arg.is_Add:
@@ -1912,8 +1903,8 @@ def _fast_expand(expr: _SympyT) -> _SympyT:
         elif exp < 0:
             return S.One / sympy.expand_multinomial(S.One / expr, deep=False)
     elif expr.is_Mul:
-        num: List[sympy.Expr] = []
-        den: List[sympy.Expr] = []
+        num: list[sympy.Expr] = []
+        den: list[sympy.Expr] = []
         for arg in expr.args:
             if arg.is_Pow and arg.args[1] == -1:
                 den.append(S.One / arg)  # type: ignore[operator, arg-type]
@@ -1961,7 +1952,7 @@ class _SymbolInfo(NamedTuple):
 def _maybe_evaluate_static_worker(
     expr: _SympyT,
     # NB: this is a tuple to ensure it can be LRU cached
-    symbol_info: Tuple[_SymbolInfo, ...],
+    symbol_info: tuple[_SymbolInfo, ...],
     unbacked_only: bool,
     size_oblivious: bool,
 ) -> Optional[_SympyT]:
@@ -2193,9 +2184,9 @@ class SymExprPrinter(PythonPrinter):
 class _ShapeGuardPrinter(abc.ABC):
     def __init__(
         self,
-        symbol_to_source: Mapping[sympy.Symbol, List[Source]],
+        symbol_to_source: Mapping[sympy.Symbol, list[Source]],
         source_ref: Callable[[Source], str],
-        var_to_sources: Mapping[sympy.Symbol, List[Source]],
+        var_to_sources: Mapping[sympy.Symbol, list[Source]],
     ) -> None:
         self.symbol_to_source = symbol_to_source
         self.source_ref = source_ref
@@ -2246,7 +2237,7 @@ class ShapeGuardPrinter(ShapeGuardPythonPrinter):


 class LoggingShapeGuardPrinter(ShapeGuardPythonPrinter):
-    def __init__(self, var_to_sources: Mapping[sympy.Symbol, List[Source]]):
+    def __init__(self, var_to_sources: Mapping[sympy.Symbol, list[Source]]):
         super().__init__(var_to_sources, lambda n: n.name(), var_to_sources)


@@ -2261,7 +2252,7 @@ class DynamicDimConstraintPrinter(PythonPrinter):

     def __init__(
         self,
-        symbol_to_source: Dict[sympy.Symbol, List[Source]],
+        symbol_to_source: dict[sympy.Symbol, list[Source]],
         source_name_to_debug_name: Mapping[str, str],
     ):
         super().__init__()
@@ -2284,23 +2275,23 @@ class DimConstraints:

     def __init__(
         self,
-        symbol_to_source: Dict[sympy.Symbol, List[Source]],
+        symbol_to_source: dict[sympy.Symbol, list[Source]],
         var_to_val: Mapping[sympy.Symbol, sympy.Integer],
-        marked_dynamic: Set[sympy.Symbol],
+        marked_dynamic: set[sympy.Symbol],
         source_name_to_debug_name: Mapping[str, str],
     ) -> None:
         # We try to solve systems of inequalities with 1 free variable.
-        self._univariate_inequalities: Dict[
-            sympy.Symbol, Set[SympyBoolean]
+        self._univariate_inequalities: dict[
+            sympy.Symbol, set[SympyBoolean]
         ] = defaultdict(set)
         # Among them, we prioritize solving for a free variable that has equalities.
         # NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys()
         # and removing a symbol from the former => removing it from the latter.
-        self._symbols_with_equalities: Set[sympy.Symbol] = set()
+        self._symbols_with_equalities: set[sympy.Symbol] = set()
         # A solution of a free variable with equalities becomes a substitution.
         # We use these substitutions to simplify other constraints.
         # NOTE: removing a symbol from _symbols_with_equalities => adding it to _substitutions.
-        self._substitutions: Dict[sympy.Symbol, sympy.Integer] = {}
+        self._substitutions: dict[sympy.Symbol, sympy.Integer] = {}

         # In general, constraints may have // and % operations.
         # Of course, // can be expressed in terms of / and %.
@@ -2308,20 +2299,20 @@ class DimConstraints:
         # We do so by using the values of variables as hints to evaluate %.
         # For soundness we record additional congruence guards and solve them separately.
         self._var_to_val: Mapping[sympy.Symbol, sympy.Integer] = var_to_val
-        self._congruences: DefaultDict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set)
+        self._congruences: defaultdict[sympy.Symbol, set[sympy.Expr]] = defaultdict(set)

         # We do not try to (directly) solve inequalities with > 1 free variables.
         # NOTE: free variables in these inequalities cannot also be in _substitutions.
-        self._multivariate_inequalities: Set[SympyBoolean] = set()
+        self._multivariate_inequalities: set[SympyBoolean] = set()

         # We park external equalities between free variables here.
-        self._symbolic_equivalences: List[Tuple[Source, sympy.Expr]] = []
+        self._symbolic_equivalences: list[tuple[Source, sympy.Expr]] = []

         # Solutions come in two forms:
         # - (static) specializations
         # - (dynamic) inequalities / congruences
-        self._static_results: Set[str] = set()
-        self._dynamic_results: Set[str] = set()
+        self._static_results: set[str] = set()
+        self._dynamic_results: set[str] = set()

         # printer for solutions
         self._dcp = DynamicDimConstraintPrinter(
@@ -2329,13 +2320,13 @@ class DimConstraints:
         )

         # inconsistencies found on substituting with concrete values / static solutions
-        self._inconsistencies: List[str] = []
+        self._inconsistencies: list[str] = []

         # symbols that are marked dynamic
         self._marked_dynamic = marked_dynamic

         # track supported sympy functions and subtract from list of all sympy functions
-        self._supported_sympy_functions: Set[sympy.Function] = {
+        self._supported_sympy_functions: set[sympy.Function] = {
             Application,
             Mod,
             PythonMod,
@@ -2488,8 +2479,8 @@ class DimConstraints:
         # these will resolve to either specializations or dynamic equality constraints
         self._symbolic_equivalences.append((source, expr))

-    def _reduce_congruences(self) -> Dict[sympy.Symbol, Set[sympy.Expr]]:
-        reduced_congruences: Dict[sympy.Symbol, Set[sympy.Expr]] = {}
+    def _reduce_congruences(self) -> dict[sympy.Symbol, set[sympy.Expr]]:
+        reduced_congruences: dict[sympy.Symbol, set[sympy.Expr]] = {}
         for s, congruences in self._congruences.items():
             remainder_modulus_pairs = []
             congruences_to_check = set()
@@ -2650,7 +2641,7 @@ class DimConstraints:
             cond = cond and isinstance(divisor, sympy.Integer)
         return cond

-    def forced_specializations(self) -> Dict[str, sympy.Expr]:
+    def forced_specializations(self) -> dict[str, sympy.Expr]:
         """Returns a dictionary of the names of symbols to their specialized value"""

         def debug_name(src: Source) -> str:
@@ -2678,8 +2669,8 @@ class DimConstraints:

     def _process_derived_dim_roots(
         self,
-        results: Dict[str, Dict[str, Any]],
-        name_to_dim: Dict[str, Any],
+        results: dict[str, dict[str, Any]],
+        name_to_dim: dict[str, Any],
     ) -> None:
         """
         Here we resolve 2 concerns with derived dims suggested fixes: 1) newly introduced roots,
@@ -2745,7 +2736,7 @@ class DimConstraints:
         # {"dx": {"eq": 3*_dx+1, "min": 4, "max": 10}, "dy": dx+1, "dz": dx+2}
         # we want instead:
         # {"_dx": {"min": 1, "max": 4}, "dx": 3*_dx+1, "dy": 3*_dx+2, "dz": 3*_dx+3}
-        introduced_roots: Dict[str, str] = {}  # map new root -> old root
+        introduced_roots: dict[str, str] = {}  # map new root -> old root
         for k, c in list(results.items()):
             if "eq" in c and isinstance(c["eq"], sympy.Expr):  # derived dim
                 root = next(iter(c["eq"].free_symbols))
@@ -2782,7 +2773,7 @@ class DimConstraints:
         # this consists of:
         # 1) {"dx": {"min": ..., "max": ...}} -> dx: refined root dim
         # 2) {"dy": "dx + 1"} -> dx: root for suggested fix
-        modified_roots: Set[str] = set()
+        modified_roots: set[str] = set()
         for k, c in results.items():
             if k not in name_to_dim:  # _dynamo.export() may handle source directly
                 continue
@@ -2799,7 +2790,7 @@ class DimConstraints:
         # evaluate the new value for each root
         # this is now either 1) unchanged, 2) refined with a new range,
         # or 3) specialized to a concrete value
-        modified_root_values: Dict[str, Dict[str, Any]] = {}
+        modified_root_values: dict[str, dict[str, Any]] = {}
         for mroot in modified_roots:
             swapped_root = True
             if mroot in results:
@@ -2860,9 +2851,9 @@ class DimConstraints:
     def prettify_results(
         self,
         original_signature: inspect.Signature,
-        dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]],
+        dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any]],
         constraint_violation_error: object,
-        forced_specializations: Dict[str, str],
+        forced_specializations: dict[str, str],
     ) -> str:
         """Format a message for constraint violation erros"""
         from torch.export.dynamic_shapes import _get_dim_name_mapping
@@ -2876,7 +2867,7 @@ class DimConstraints:
                 s = s.replace(k, v) if not inverse else s.replace(v, k)
             return s

-        results: DefaultDict[str, Dict[str, Any]] = defaultdict(dict)
+        results: defaultdict[str, dict[str, Any]] = defaultdict(dict)
         if dynamic_shapes is None:
             dynamic_shapes = {}

@@ -3050,7 +3041,7 @@ class ShapeEnv:
         self,
         *,
         should_record_events: Optional[bool] = None,
-        tracked_fakes: Optional[List[Any]] = None,
+        tracked_fakes: Optional[list[Any]] = None,
         **kwargs: Any,
     ) -> None:
         self._init(**kwargs)
@@ -3086,7 +3077,7 @@ class ShapeEnv:
         # Keep track of the list of tracked fakes.
         self.tracked_fakes = tracked_fakes
         # List of events for reconstructing ShapeEnv at arbitrary points in time.
-        self.events: List[ShapeEnvEvent] = (
+        self.events: list[ShapeEnvEvent] = (
             [ShapeEnvEvent(ShapeEnv, kwargs=kwargs)]
             if self.should_record_events
             else []
@@ -3099,7 +3090,7 @@ class ShapeEnv:
         # NOTE: It's important that SymNodes in this cache have their ShapeEnv
         # stripped otherwise you end up with cycles which can only be cleaned
         # with the GC.
-        self.fake_tensor_cache: Dict[
+        self.fake_tensor_cache: dict[
             torch._subclasses.fake_tensor._DispatchCacheKey,
             torch._subclasses.fake_tensor._DispatchCacheEntry,
         ] = {}
@@ -3134,7 +3125,7 @@ class ShapeEnv:
         # symbolically equal.
         duck_shape: Optional[bool] = None,
         # For debugging
-        co_fields: Optional[Dict[str, str]] = None,
+        co_fields: Optional[dict[str, str]] = None,
         # When True, whenever safe, we will generate a deferred runtime assert
         # instead of a guard whenever we know that an expression must be True,
         # otherwise it would be an error, even for backed SymInts (where we
@@ -3165,50 +3156,50 @@ class ShapeEnv:
             allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
         )

-        self.guards: List[ShapeGuard] = []
-        self.axioms: Dict[sympy.Expr, sympy.Expr] = {}
+        self.guards: list[ShapeGuard] = []
+        self.axioms: dict[sympy.Expr, sympy.Expr] = {}
         # Maps symbolic ints to their original concrete values
         # Currently populated from tensors
-        self.var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
+        self.var_to_val: dict[sympy.Symbol, sympy.Integer] = {}
         # Like var_to_val, but only set when propagate_real_tensors is on.
         # Used as last resort to avoid GuardOnDataDependent error
-        self.unbacked_var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
+        self.unbacked_var_to_val: dict[sympy.Symbol, sympy.Integer] = {}
         # Like above, but used exclusively for OBLIVIOUS_SIZE. These
         # potentially could be put together but I am not sure, writing out
         # the logic individually before abstracting.
-        self.oblivious_var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
+        self.oblivious_var_to_val: dict[sympy.Symbol, sympy.Integer] = {}
         # Maps symbolic ints to their min/max range. These ranges
         # are conservative: the int MUST fall in the range, but the
         # range may contain ints which may not actually appear in
         # practice
-        self.var_to_range: Dict[sympy.Symbol, ValueRanges] = {}
-        self.var_to_range_sloc: Dict[sympy.Symbol, ValueRangesSLoc] = {}
+        self.var_to_range: dict[sympy.Symbol, ValueRanges] = {}
+        self.var_to_range_sloc: dict[sympy.Symbol, ValueRangesSLoc] = {}
         # When doing a size-oblivious test, exclude this integer and
         # everything higher than it from the acceptable range. This solves
         # https://github.com/pytorch/pytorch/issues/120288 for constant range
         # case
         # TODO: generalize this to work with expressions (in that case, we
         # need to maintain a SET and we need extra symbolic reasoning on top)
-        self.oblivious_upper_bound_exclusive: Dict[sympy.Symbol, sympy.Integer] = {}
-        self.source_name_to_debug_name: Dict[str, str] = {}
-        self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {}
-        self.var_to_stack: Dict[sympy.Symbol, CapturedTraceback] = {}
+        self.oblivious_upper_bound_exclusive: dict[sympy.Symbol, sympy.Integer] = {}
+        self.source_name_to_debug_name: dict[str, str] = {}
+        self.var_to_sources: dict[sympy.Symbol, list[Source]] = {}
+        self.var_to_stack: dict[sympy.Symbol, CapturedTraceback] = {}
         # Maps a source to the *original* symbol that was assigned to it
-        self.source_to_var: Dict[str, sympy.Symbol] = {}
+        self.source_to_var: dict[str, sympy.Symbol] = {}
         # Maps from sympy ints to expressions representing them
         # Populated from equality guards (i.e. a.shape[0] == b.shape[0])
-        self.replacements: Dict[sympy.Symbol, sympy.Expr] = {}
+        self.replacements: dict[sympy.Symbol, sympy.Expr] = {}
         # The sloc of the guard that triggered this replacement to be added
-        self.replacements_slocs: Dict[sympy.Symbol, SLoc] = {}
-        self.unbacked_renamings: Dict[sympy.Symbol, sympy.Symbol] = {}
+        self.replacements_slocs: dict[sympy.Symbol, SLoc] = {}
+        self.unbacked_renamings: dict[sympy.Symbol, sympy.Symbol] = {}
         # Set holds a % b expressions that evaluate to 0.
-        self.divisible: Set[sympy.Expr] = set()
+        self.divisible: set[sympy.Expr] = set()
         # Set that holds "size-like" symbols. When we perform
         # "size-oblivious" tests, these can be assumed to be >= 2.
-        self.size_like: Set[sympy.Symbol] = set()
+        self.size_like: set[sympy.Symbol] = set()
         # Duck-shaping says that if two input tensors have the same size,
         # they get assigned the same symbolic variable
-        self.val_to_var: Dict[int, sympy.Symbol] = {}
+        self.val_to_var: dict[int, sympy.Symbol] = {}
         if specialize_zero_one:
             self.val_to_var = {0: sympy.S.Zero, 1: sympy.S.One}
         self.unbacked_symfloat_counter = itertools.count()
@@ -3241,8 +3232,8 @@ class ShapeEnv:
         # to the next unbacked symbol to wait on, but if we choose the
         # latest key, an assert will only show up at the moment when
         # we can actually codegen it.
-        self.deferred_runtime_asserts: Dict[
-            Optional[sympy.Symbol], List[RuntimeAssert]
+        self.deferred_runtime_asserts: dict[
+            Optional[sympy.Symbol], list[RuntimeAssert]
         ] = {}
         # This exists so we can efficiently invalidate the cache (it's used as
         # part of the cache key); otherwise we'd have to iterate through
@@ -3279,7 +3270,7 @@ class ShapeEnv:
         #
         # NB: fresh unbacked symbols NEVER get substitutions applied to them,
         # they are binding sites!
-        self.pending_fresh_unbacked_symbols: List[sympy.Symbol] = []
+        self.pending_fresh_unbacked_symbols: list[sympy.Symbol] = []

         # Version counter used to invalidate cached values
         self._prev_cache_key = self._get_key()
@@ -3294,8 +3285,8 @@ class ShapeEnv:
         # 2. list of arguments
         # This drastically reduces the size of the FX graph, avoiding
         # duplicated nodes.
-        self.fx_node_cache: Dict[Tuple[Callable, Tuple[Any, ...]], torch.fx.Node] = {}
-        self.source_to_symbol: Dict[str, sympy.Symbol] = {}
+        self.fx_node_cache: dict[tuple[Callable, tuple[Any, ...]], torch.fx.Node] = {}
+        self.source_to_symbol: dict[str, sympy.Symbol] = {}

         # Suppose you want to replace an unbacked symbol with another
         # unbacked symbol. This is error prone because you can cause
@@ -3322,7 +3313,7 @@ class ShapeEnv:
         # bindings. At the moment, this is not tracked, but we potentially
         # could track this at the IR level using a higher order operator
         # with something like effect token tracking.
-        self.unbacked_alloc_order: Dict[sympy.Symbol, int] = {}
+        self.unbacked_alloc_order: dict[sympy.Symbol, int] = {}

         from torch.fx.experimental.validator import translation_validation_enabled

@@ -3345,7 +3336,7 @@ class ShapeEnv:
         # Whenever you add a node to self.graph, you must add a mapping to this
         # variable. Otherwise, the built FX graph on the replayed ShapeEnv will
         # not be valid.
-        self.name_to_node: Dict[str, torch.fx.Node] = {}
+        self.name_to_node: dict[str, torch.fx.Node] = {}

     @property
     def allow_scalar_outputs(self) -> bool:
@@ -3439,7 +3430,7 @@ class ShapeEnv:

         shape_env_check_state_equal(self, other, non_state_variable_names, map_value)

-    def _snapshot_tracked_fakes(self) -> Optional[List[Any]]:
+    def _snapshot_tracked_fakes(self) -> Optional[list[Any]]:
         if self.tracked_fakes is None:
             return None

@@ -3631,7 +3622,7 @@ class ShapeEnv:
             self.source_to_symbol[srcname] = sympy.Symbol(srcname, integer=True)
         return self.source_to_symbol[srcname]

-    def _add_z3var(self, symbol: sympy.Symbol, type: Type) -> None:
+    def _add_z3var(self, symbol: sympy.Symbol, type: type) -> None:
         if self._translation_validation_enabled:
             self.validator.add_var(symbol, type)

@@ -3651,8 +3642,8 @@ class ShapeEnv:
     def _create_fx_call_function(
         self,
         op: Callable,
-        args: Tuple,
-    ) -> Tuple[Optional[torch.fx.Node], bool]:
+        args: tuple,
+    ) -> tuple[Optional[torch.fx.Node], bool]:
         # Cache this tuple in order to avoid duplicated nodes.
         node_key = (op, args)
         # Flags whether the returned node was cached or not.
@@ -3681,7 +3672,7 @@ class ShapeEnv:
     def _create_fx_placeholder_and_z3var(
         self,
         symbol: sympy.Symbol,
-        type: Type,
+        type: type,
     ) -> Optional[torch.fx.Node]:
         if not self._translation_validation_enabled:
             return None
@@ -3742,7 +3733,7 @@ class ShapeEnv:
         """Context manager to ignore all guards generated inside"""
         return _suppress_guards(self)

-    def _get_key(self) -> Tuple[int, int, int, int]:
+    def _get_key(self) -> tuple[int, int, int, int]:
         """
         Defines the current "state" of the guards we've accumulated in this ShapeEnv.
         Determines when we need to invalidate our cache
@@ -3778,7 +3769,7 @@ class ShapeEnv:
         ex_size: Sequence[Union[int, SymInt]],
         source: Source,
         symbolic_context: SymbolicContext,
-    ) -> List[sympy.Expr]:
+    ) -> list[sympy.Expr]:
         return self._produce_dyn_sizes_from_int_tuple(
             tuple(ex_size), source, symbolic_context
         )
@@ -3788,7 +3779,7 @@ class ShapeEnv:
         tensor_size: Sequence[Union[int, SymInt]],
         source: Source,
         symbolic_context: SymbolicContext,
-    ) -> List[sympy.Expr]:
+    ) -> list[sympy.Expr]:
         assert all(
             not is_symbolic(val) for val in tensor_size
         ), f"Expect size to be a plain tuple of ints but got {tensor_size}"
@@ -3816,9 +3807,9 @@ class ShapeEnv:
         source: Source,
         *,
         symbolic_context: Optional[SymbolicContext] = None,
-    ) -> Tuple[
-        Tuple[Union[int, SymInt], ...],
-        Tuple[Union[int, SymInt], ...],
+    ) -> tuple[
+        tuple[Union[int, SymInt], ...],
+        tuple[Union[int, SymInt], ...],
         Union[int, SymInt],
     ]:
         """
@@ -3903,17 +3894,17 @@ class ShapeEnv:
         source: Source,
         *,
         symbolic_context: Optional[SymbolicContext] = None,
-    ) -> Tuple[
-        Tuple[Union[int, SymInt], ...],
-        Tuple[Union[int, SymInt], ...],
+    ) -> tuple[
+        tuple[Union[int, SymInt], ...],
+        tuple[Union[int, SymInt], ...],
         Union[int, SymInt],
     ]:
         dim = len(ex_size)

         # Reimplement the legacy behavior
         if symbolic_context is None:
-            constraint_sizes: List[DimConstraint] = [None] * dim
-            constraint_strides: List[DimConstraint] = [None] * dim
+            constraint_sizes: list[DimConstraint] = [None] * dim
+            constraint_strides: list[DimConstraint] = [None] * dim
             dynamic_dims = []
             dynamic_strides = []
             for i in range(dim):
@@ -3963,7 +3954,7 @@ class ShapeEnv:

         from torch._dynamo.source import TensorProperty, TensorPropertySource

-        size: List[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(
+        size: list[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(
             ex_size, source, symbolic_context
         )
         stride = self._compute_symbolic_stride(
@@ -4022,11 +4013,11 @@ class ShapeEnv:
         ],
         are_sizes_static: bool,
         symbolic_context: SymbolicContext,
-    ) -> List[sympy.Expr]:
+    ) -> list[sympy.Expr]:
         from torch._dynamo.source import TensorProperty, TensorPropertySource

-        stride: List[Optional[sympy.Expr]] = [None] * len(size)
-        candidates: Dict[Union[int, SymInt], sympy.Expr] = {}
+        stride: list[Optional[sympy.Expr]] = [None] * len(size)
+        candidates: dict[Union[int, SymInt], sympy.Expr] = {}

         # iterate over unbound strides in val ascending order with
         # index descending as a tie breaker since for cases like
@@ -4590,7 +4581,7 @@ class ShapeEnv:
                 return c_render
             return c.render(source)

-    def produce_guards(self, *args: Any, **kwargs: Any) -> List[str]:
+    def produce_guards(self, *args: Any, **kwargs: Any) -> list[str]:
         """
         Like produce_guards_verbose, but only returns the non-verbose guard expressions
         (no verbose guards produced.)
@@ -4603,7 +4594,7 @@ class ShapeEnv:
         sources: Sequence[Source],
         source_ref: Callable[[Source], str] = lambda n: n.name(),
         *,
-        guards: Optional[List[ShapeGuard]] = None,
+        guards: Optional[list[ShapeGuard]] = None,
         input_contexts: Optional[DimList[SymbolicContext]] = None,
         # Encodes user-specified input shape equations of the form s = s' and s = fn(s').
         # (See docs on EqualityConstraint for details of the encoding.)
@@ -4611,7 +4602,7 @@ class ShapeEnv:
         _simplified: bool = False,
         # Indicates if we should produce guards for known static values.
         ignore_static: bool = True,
-    ) -> Tuple[List[str], List[str]]:  # python, verbose
+    ) -> tuple[list[str], list[str]]:  # python, verbose
         """
         Generates a list of guards strings which, when evaluated in a context that
         defines tensors for all the sources, returns True or False depending
@@ -4740,13 +4731,13 @@ class ShapeEnv:
         # the symbol mapping is
         input_guards = []

-        symbol_to_source: Dict[sympy.Symbol, List[Source]] = collections.defaultdict(
+        symbol_to_source: dict[sympy.Symbol, list[Source]] = collections.defaultdict(
             list
         )
-        symbol_to_constraints: DefaultDict[
-            sympy.Symbol, Set[Constraint]
+        symbol_to_constraints: defaultdict[
+            sympy.Symbol, set[Constraint]
         ] = collections.defaultdict(set)
-        constraint_violations: List[Tuple[bool, str, Callable[[], str]]] = []
+        constraint_violations: list[tuple[bool, str, Callable[[], str]]] = []

         py_printer = ShapeGuardPythonPrinter(
             symbol_to_source, source_ref, self.var_to_sources
@@ -4956,7 +4947,7 @@ class ShapeEnv:
             # For subclasses, we need to track symints on BOTH the outer
             # and inner tensors.
             # TODO: type this better
-            sources_tensors_constraints: List[Tuple[Source, Any, Any, Any]] = [
+            sources_tensors_constraints: list[tuple[Source, Any, Any, Any]] = [
                 (source, t, context.constraint_sizes, context.constraint_strides)
             ]
             attrs, _ = t.__tensor_flatten__()
@@ -5256,8 +5247,8 @@ class ShapeEnv:
             )

         if constraint_violations:
-            warn_msgs: List[str] = []
-            error_msgs: List[str] = []
+            warn_msgs: list[str] = []
+            error_msgs: list[str] = []
             debug_names = set()
             for warn_only, debug_name, msg_cb in constraint_violations:
                 if warn_only:
@@ -5327,7 +5318,7 @@ class ShapeEnv:
         self,
         placeholders: Sequence[Union[SymInt, FakeTensor]],
         *,
-        guards: Optional[List[ShapeGuard]] = None,
+        guards: Optional[list[ShapeGuard]] = None,
         ignore_static: bool = True,
     ) -> Optional[str]:
         """
@@ -5386,7 +5377,7 @@ class ShapeEnv:
             return self.evaluate_guards_expression(code, args)
         return True

-    def get_pruned_guards(self, symints: Sequence[torch.SymInt]) -> List[ShapeGuard]:
+    def get_pruned_guards(self, symints: Sequence[torch.SymInt]) -> list[ShapeGuard]:
         """
         Get a list of guards, but pruned so it only provides guards that
         reference symints from the passed in input
@@ -5401,7 +5392,7 @@ class ShapeEnv:

     def bind_symbols(
         self, placeholders: Sequence[FakeTensor], args: Sequence[Tensor]
-    ) -> Dict[sympy.Symbol, int]:
+    ) -> dict[sympy.Symbol, int]:
         """
         Given a paired list of placeholders (fake tensors with
         symbolic sizes) and concrete arguments (regular tensors
@@ -5418,7 +5409,7 @@ class ShapeEnv:
         another copy. This assumes the guards are already checked,
         though if it's cheap we'll check for shenanigans
         """
-        bindings: Dict[sympy.Symbol, int] = {}
+        bindings: dict[sympy.Symbol, int] = {}

         def bind_symint(arg: object, val: object) -> None:
             if isinstance(val, SymInt):
@@ -5451,7 +5442,7 @@ class ShapeEnv:

         return bindings

-    def get_nontrivial_guards(self) -> List[SympyBoolean]:
+    def get_nontrivial_guards(self) -> list[SympyBoolean]:
         """Returns a list of guard expressions that aren't statically known (i.e. not trivial)"""
         return [
             self.simplify(guard.expr)
@@ -5488,9 +5479,9 @@ class ShapeEnv:
     @_lru_cache
     def get_axioms(
         self,
-        symbols: Optional[Tuple[sympy.Symbol]] = None,
+        symbols: Optional[tuple[sympy.Symbol]] = None,
         compute_hint: bool = False,
-    ) -> Tuple[SympyBoolean, ...]:
+    ) -> tuple[SympyBoolean, ...]:
         """
         Given the symbols in an expression, it returns all the runtime asserts that have those symbols
         concatenated with all the guards.
@@ -5518,9 +5509,9 @@ class ShapeEnv:
     @lru_cache(None)
     def get_implications(
         self, e: SympyBoolean
-    ) -> Tuple[Tuple[SympyBoolean, sympy.logic.boolalg.BooleanAtom], ...]:
+    ) -> tuple[tuple[SympyBoolean, sympy.logic.boolalg.BooleanAtom], ...]:
         """Given a expression, it returns a list of predicates that follow from it"""
-        equiv: Dict[SympyBoolean, sympy.logic.boolalg.BooleanAtom] = {}
+        equiv: dict[SympyBoolean, sympy.logic.boolalg.BooleanAtom] = {}

         def add_expr(expr: SympyBoolean) -> None:
             expr = canonicalize_bool_expr(expr)
@@ -5564,8 +5555,8 @@ class ShapeEnv:
         unbacked_only: bool = False,
|
||||||
compute_hint: bool = False,
|
compute_hint: bool = False,
|
||||||
size_oblivious: bool = False,
|
size_oblivious: bool = False,
|
||||||
axioms: Optional[Tuple[SympyBoolean]] = None,
|
axioms: Optional[tuple[SympyBoolean]] = None,
|
||||||
var_to_range: Optional[Tuple[Tuple[sympy.Symbol, ValueRanges]]] = None,
|
var_to_range: Optional[tuple[tuple[sympy.Symbol, ValueRanges]]] = None,
|
||||||
) -> Optional[sympy.Basic]:
|
) -> Optional[sympy.Basic]:
|
||||||
"""
|
"""
|
||||||
Tries to evaluate expr without introducing guards
|
Tries to evaluate expr without introducing guards
|
||||||
@ -5589,7 +5580,7 @@ class ShapeEnv:
|
|||||||
|
|
||||||
expr = canonicalize_bool_expr(expr)
|
expr = canonicalize_bool_expr(expr)
|
||||||
|
|
||||||
def resimplify_floor_div(axioms: Dict[sympy.Expr, sympy.Expr]) -> None:
|
def resimplify_floor_div(axioms: dict[sympy.Expr, sympy.Expr]) -> None:
|
||||||
if not self._resimplify_floor_div_axioms:
|
if not self._resimplify_floor_div_axioms:
|
||||||
return
|
return
|
||||||
self._resimplify_floor_div_axioms = False
|
self._resimplify_floor_div_axioms = False
|
||||||
@ -6114,7 +6105,7 @@ class ShapeEnv:
|
|||||||
# Prefer to simplify out lexicographically higher symbols (i.e. simplify out s4 over s3).
|
# Prefer to simplify out lexicographically higher symbols (i.e. simplify out s4 over s3).
|
||||||
# (NB: this unfortunately isn't strictly equivalent to simplifying out newer symbols)
|
# (NB: this unfortunately isn't strictly equivalent to simplifying out newer symbols)
|
||||||
# Prefer to simplify out symbols with ephemeral sources.
|
# Prefer to simplify out symbols with ephemeral sources.
|
||||||
def _smart_symbol_sort(x: sympy.Symbol) -> Tuple[int, int, str]:
|
def _smart_symbol_sort(x: sympy.Symbol) -> tuple[int, int, str]:
|
||||||
has_only_ephemeral_sources = x in self.var_to_sources and all(
|
has_only_ephemeral_sources = x in self.var_to_sources and all(
|
||||||
s.is_ephemeral() for s in self.var_to_sources[x]
|
s.is_ephemeral() for s in self.var_to_sources[x]
|
||||||
)
|
)
|
||||||
@ -6282,7 +6273,7 @@ class ShapeEnv:
|
|||||||
|
|
||||||
def _get_stack_summary(
|
def _get_stack_summary(
|
||||||
self, is_debug: bool = False, framework_loc: Optional[str] = None
|
self, is_debug: bool = False, framework_loc: Optional[str] = None
|
||||||
) -> Tuple[SLoc, str]:
|
) -> tuple[SLoc, str]:
|
||||||
floc: Optional[Union[str, traceback.FrameSummary]] = framework_loc
|
floc: Optional[Union[str, traceback.FrameSummary]] = framework_loc
|
||||||
if floc is None:
|
if floc is None:
|
||||||
frame = inspect.currentframe()
|
frame = inspect.currentframe()
|
||||||
@ -6903,7 +6894,7 @@ class _PythonMsgPrinter(PythonPrinter):
|
|||||||
(i.e., as ==, !=, >, <).
|
(i.e., as ==, !=, >, <).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, src_map: Dict[str, List[str]]) -> None:
|
def __init__(self, src_map: dict[str, list[str]]) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.src_map = src_map
|
self.src_map = src_map
|
||||||
|
|
||||||
@ -6912,7 +6903,7 @@ class _PythonMsgPrinter(PythonPrinter):
|
|||||||
|
|
||||||
|
|
||||||
def _suggest_torch_checks(
|
def _suggest_torch_checks(
|
||||||
e: GuardOnDataDependentSymNode, src_map: DefaultDict[str, List[str]]
|
e: GuardOnDataDependentSymNode, src_map: defaultdict[str, list[str]]
|
||||||
) -> None:
|
) -> None:
|
||||||
# extract the unresolved condition on unbacked symints in the error
|
# extract the unresolved condition on unbacked symints in the error
|
||||||
cond = e.cond
|
cond = e.cond
|
||||||
|
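Editor's note: the hunks above replace `typing` aliases (`Tuple`, `List`, `Set`, `DefaultDict`) with the builtin generics standardized by PEP 585. A minimal, self-contained sketch of what the new annotations rely on; the function and variable names here are illustrative, not part of the diff. Builtin containers and `collections.defaultdict` are directly subscriptable at runtime on Python 3.9+:

    import collections

    # Builtin containers can be parameterized directly; typing.List/Tuple are not needed.
    def produce_guard_strings(sources: list[str]) -> tuple[list[str], list[str]]:
        python: list[str] = [f"{s} == {s}" for s in sources]
        verbose: list[str] = [f"guard on {s}" for s in sources]
        return python, verbose

    # collections.defaultdict is itself generic, replacing typing.DefaultDict.
    symbol_to_source: collections.defaultdict[str, list[str]] = collections.defaultdict(list)
    symbol_to_source["s0"].append("L['x'].size()[0]")

    print(produce_guard_strings(["s0"]))
    print(dict(symbol_to_source))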
@@ -5,7 +5,7 @@ import logging
 import math
 import operator
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Optional, Union

 import sympy

@@ -60,7 +60,7 @@ try:
 def z3str(e: z3.ExprRef) -> str:
 assert z3.is_expr(e), f"unsupported expression type: {e}"

-def get_args_str(e: z3.ExprRef) -> List[str]:
+def get_args_str(e: z3.ExprRef) -> list[str]:
 return [z3str(e.arg(i)) for i in range(e.num_args())]

 # First, we simplify the given expression.

@@ -350,13 +350,13 @@ try:
 super().__init__(module, garbage_collect_values=True)

 def placeholder(
-self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
 ) -> Any:
 symbol = fx_traceback.get_current_meta()["symbol"]
 return self.validator.z3var(symbol)

 def call_function(
-self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
 ) -> Any:
 if target != torch._assert:
 # Lift and runs the node target function

@@ -481,21 +481,21 @@ try:
 log.debug("new instance")

 # Mapping of SymPy symbols to Z3 variables.
-self.symbols: Dict[sympy.Symbol, z3.ExprRef] = {}
+self.symbols: dict[sympy.Symbol, z3.ExprRef] = {}

 # Set of source Z3 expressions.
 # They represent the generated guards without any kind of
 # simplification or transformation.
-self._source_exprs: Set[z3.BoolRef] = set()
+self._source_exprs: set[z3.BoolRef] = set()

 # Set of target Z3 expressions.
 # They represent the actual checked guards at runtime. They might
 # be simplified or transformed versions of the source guards.
-self._target_exprs: Set[z3.BoolRef] = set()
+self._target_exprs: set[z3.BoolRef] = set()

 # Set of Z3 expressions representing assertions over both the
 # source and target expressions.
-self._assertions: Set[z3.BoolRef] = set()
+self._assertions: set[z3.BoolRef] = set()

 # Retrieves the corresponding Z3 variable.
 def z3var(self, symbol: sympy.Symbol) -> z3.ExprRef:

@@ -503,7 +503,7 @@ try:
 return self.symbols[symbol]

 # Create a variable in Z3 of 'type' for 'symbol', if it doesn't already exists.
-def add_var(self, symbol: sympy.Symbol, type: Type) -> z3.ExprRef:
+def add_var(self, symbol: sympy.Symbol, type: type) -> z3.ExprRef:
 if symbol in self.symbols:
 return self.symbols[symbol]

@@ -769,7 +769,7 @@ def bisect(shape_env):

 # Checks whether the given shape_env fails when produce_guards is called.
 def check_shapeenv_fails(
-shape_env: ShapeEnv, tracked_fakes: Optional[List[Any]]
+shape_env: ShapeEnv, tracked_fakes: Optional[list[Any]]
 ) -> Optional[ValidationException]:
 assert tracked_fakes is not None
 try:
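Similarly, `typing.Type` gives way to the builtin `type` in the validator hunks above. A hedged sketch with a toy registry standing in for the real Z3-backed one; `add_var` and `symbols` here are stand-ins, not the actual API:

    symbols: dict[str, object] = {}

    def add_var(name: str, type: type) -> object:
        # 'type' is a class object (e.g. int or float) used as a factory on first sight.
        if name not in symbols:
            symbols[name] = type()
        return symbols[name]

    add_var("s0", int)
    add_var("f0", float)
    print(symbols)  # {'s0': 0, 'f0': 0.0}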
@@ -11,23 +11,10 @@ import os
 import re
 import warnings
 from collections import defaultdict
+from collections.abc import Iterable
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import (
-Any,
-Callable,
-Dict,
-FrozenSet,
-Iterable,
-List,
-Literal,
-NamedTuple,
-Optional,
-Set,
-Tuple,
-Type,
-TYPE_CHECKING,
-)
+from typing import Any, Callable, Literal, NamedTuple, Optional, TYPE_CHECKING

 import torch
 import torch.utils._pytree as pytree

@@ -47,11 +34,11 @@ if TYPE_CHECKING:

 # Mapping of builtins to their `typing` equivalent.
 _origin_type_map = {
-list: List,
+list: list,
-dict: Dict,
+dict: dict,
-set: Set,
+set: set,
-frozenset: FrozenSet,
+frozenset: frozenset,
-tuple: Tuple,
+tuple: tuple,
 }

 _legal_ops = dict.fromkeys(

@@ -61,7 +48,7 @@ _legal_ops = dict.fromkeys(

 # Signature for functions thattransforms the body (`list[str]`) of the
 # generated code
-TransformCodeFunc = Callable[[List[str]], List[str]]
+TransformCodeFunc = Callable[[list[str]], list[str]]


 class _CustomBuiltin(NamedTuple):

@@ -78,7 +65,7 @@ class _CustomBuiltin(NamedTuple):
 obj: Any


-_custom_builtins: Dict[str, _CustomBuiltin] = {}
+_custom_builtins: dict[str, _CustomBuiltin] = {}


 def _register_custom_builtin(name: str, import_str: str, obj: Any):

@@ -144,10 +131,10 @@ class _Namespace:
 """

 def __init__(self):
-self._obj_to_name: Dict[Any, str] = {}
+self._obj_to_name: dict[Any, str] = {}
 self._unassociated_names = set()
-self._used_names: Set[str] = set()
+self._used_names: set[str] = set()
-self._base_count: Dict[str, int] = defaultdict(int)
+self._base_count: dict[str, int] = defaultdict(int)

 self._illegal_char_regex = re.compile("[^0-9a-zA-Z_]+")
 self._name_suffix_regex = re.compile(r"(.*)_(\d+)$")

@@ -261,10 +248,10 @@ class PythonCode:
 # Python source code for the forward function definition.
 src: str
 # Values in global scope during execution of `src_def`.
-globals: Dict[str, Any]
+globals: dict[str, Any]
 # Optional mapping from the forward function's line number to
 # node index.
-_lineno_map: Optional[Dict[int, Optional[int]]]
+_lineno_map: Optional[dict[int, Optional[int]]]


 def _format_target(base: str, target: str) -> str:

@@ -311,7 +298,7 @@ class _PyTreeInfo(NamedTuple):
 Contains extra info stored when we're using Pytrees
 """

-orig_args: List[str]
+orig_args: list[str]
 in_spec: pytree.TreeSpec
 out_spec: Optional[pytree.TreeSpec]

@@ -359,7 +346,7 @@ class CodeGen:
 self._body_transformer: Optional[TransformCodeFunc] = None
 self._func_name: str = "forward"

-def gen_fn_def(self, free_vars: List[str], maybe_return_annotation: str) -> str:
+def gen_fn_def(self, free_vars: list[str], maybe_return_annotation: str) -> str:
 """
 Given the free variables and a return annotation, generates the beginning of the FX function.
 By default, `gen_fn_def(['a', 'b'], '') == 'def {self._func_name}(a, b):'`

@@ -398,7 +385,7 @@ class CodeGen:
 """
 return outputs

-def additional_globals(self) -> List[Tuple[str, Any]]:
+def additional_globals(self) -> list[tuple[str, Any]]:
 """
 If your codegen uses extra global values, add tuples of (identifier,reference to the value) here.
 For example, return ['List', typing.List] if you need ``List`` in the global context.

@@ -416,13 +403,13 @@ class CodeGen:
 include_device: bool = False,
 colored: bool = False,
 ) -> PythonCode:
-free_vars: List[str] = []
+free_vars: list[str] = []
-body: List[str] = []
+body: list[str] = []
-globals_: Dict[str, Any] = {}
+globals_: dict[str, Any] = {}
-wrapped_fns: Dict[str, None] = {}
+wrapped_fns: dict[str, None] = {}

 # Wrap string in list to pass by reference
-maybe_return_annotation: List[str] = [""]
+maybe_return_annotation: list[str] = [""]
 include_stride = include_stride or (
 os.environ.get("FX_GRAPH_SHOW_STRIDE", "0") == "1"
 )

@@ -553,7 +540,7 @@ class CodeGen:
 return blue(repr(arg))

 def _format_args(
-args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
+args: tuple[Argument, ...], kwargs: dict[str, Argument]
 ) -> str:
 args_s = ", ".join(_get_repr(a) for a in args)
 kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())

@@ -565,8 +552,8 @@ class CodeGen:
 # of a given node. This represents the *last* use of the node in the
 # execution order of the program, which we will use to free unused
 # values
-node_to_last_use: Dict[Node, Node] = {}
+node_to_last_use: dict[Node, Node] = {}
-user_to_last_uses: Dict[Node, List[Node]] = {}
+user_to_last_uses: dict[Node, list[Node]] = {}

 def register_last_uses(n: Node, user: Node):
 if n not in node_to_last_use:

@@ -782,9 +769,9 @@ class CodeGen:
 prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])

 # remove counter and generate lineno to node index mapping
-lineno_map: Dict[int, Optional[int]] = {}
+lineno_map: dict[int, Optional[int]] = {}
 prologue_len = prologue.count("\n") + 1
-new_lines: List[str] = []
+new_lines: list[str] = []
 cur_idx = None
 for line in "".join(body).split("\n"):
 counter = re.search(r"# COUNTER: (\d+)", line)

@@ -904,11 +891,11 @@ class _FindNodesLookupTable:
 """

 def __init__(self):
-self.table: Dict[Tuple[str, Optional[Target]], Dict[Node, None]] = defaultdict(
+self.table: dict[tuple[str, Optional[Target]], dict[Node, None]] = defaultdict(
 dict
 )

-def _key(self, node) -> Tuple[str, Optional[Target]]:
+def _key(self, node) -> tuple[str, Optional[Target]]:
 return (node.op, node.target if node.op == "call_function" else None)

 def __contains__(self, node) -> bool:

@@ -985,14 +972,14 @@ class Graph:
 def __init__(
 self,
 owning_module: Optional["GraphModule"] = None,
-tracer_cls: Optional[Type["Tracer"]] = None,
+tracer_cls: Optional[type["Tracer"]] = None,
-tracer_extras: Optional[Dict[str, Any]] = None,
+tracer_extras: Optional[dict[str, Any]] = None,
 ):
 """
 Construct an empty Graph.
 """
 self._root: Node = Node(self, "", "root", "", (), {})
-self._used_names: Dict[str, int] = {}  # base name -> number
+self._used_names: dict[str, int] = {}  # base name -> number
 self._insert = self._root.prepend
 self._len = 0
 self._graph_namespace = _Namespace()

@@ -1000,7 +987,7 @@ class Graph:
 self._tracer_cls = tracer_cls
 self._tracer_extras = tracer_extras
 self._codegen = CodeGen()
-self._co_fields: Dict[str, Any] = {}
+self._co_fields: dict[str, Any] = {}
 self._find_nodes_lookup_table = _FindNodesLookupTable()

 @property

@@ -1060,7 +1047,7 @@ class Graph:

 @compatibility(is_backward_compatible=True)
 def graph_copy(
-self, g: "Graph", val_map: Dict[Node, Node], return_output_node=False
+self, g: "Graph", val_map: dict[Node, Node], return_output_node=False
 ) -> "Optional[Argument]":
 """
 Copy all nodes from a given graph into ``self``.

@@ -1113,8 +1100,8 @@ class Graph:
 self,
 op: str,
 target: "Target",
-args: Optional[Tuple["Argument", ...]] = None,
+args: Optional[tuple["Argument", ...]] = None,
-kwargs: Optional[Dict[str, "Argument"]] = None,
+kwargs: Optional[dict[str, "Argument"]] = None,
 name: Optional[str] = None,
 type_expr: Optional[Any] = None,
 ) -> Node:

@@ -1373,8 +1360,8 @@ class Graph:
 def call_module(
 self,
 module_name: str,
-args: Optional[Tuple["Argument", ...]] = None,
+args: Optional[tuple["Argument", ...]] = None,
-kwargs: Optional[Dict[str, "Argument"]] = None,
+kwargs: Optional[dict[str, "Argument"]] = None,
 type_expr: Optional[Any] = None,
 ) -> Node:
 """

@@ -1423,8 +1410,8 @@ class Graph:
 def call_method(
 self,
 method_name: str,
-args: Optional[Tuple["Argument", ...]] = None,
+args: Optional[tuple["Argument", ...]] = None,
-kwargs: Optional[Dict[str, "Argument"]] = None,
+kwargs: Optional[dict[str, "Argument"]] = None,
 type_expr: Optional[Any] = None,
 ) -> Node:
 """

@@ -1462,8 +1449,8 @@ class Graph:
 def call_function(
 self,
 the_function: Callable[..., Any],
-args: Optional[Tuple["Argument", ...]] = None,
+args: Optional[tuple["Argument", ...]] = None,
-kwargs: Optional[Dict[str, "Argument"]] = None,
+kwargs: Optional[dict[str, "Argument"]] = None,
 type_expr: Optional[Any] = None,
 ) -> Node:
 """

@@ -1668,10 +1655,10 @@ class Graph:
 Return a human-readable (not machine-readable) string representation
 of this Graph
 """
-placeholder_names: List[str] = []
+placeholder_names: list[str] = []
 # This is a one-element array just so ``format_node`` can modify the closed
 # over value
-maybe_return_typename: List[str] = [""]
+maybe_return_typename: list[str] = [""]

 node_strs = [node.format_node(placeholder_names) for node in self.nodes]
 param_str = ", ".join(placeholder_names)

@@ -1729,8 +1716,8 @@ class Graph:
 f"defined! Please check that Nodes in the graph are topologically ordered\n{self}"
 )

-seen_names: Set[str] = set()
+seen_names: set[str] = set()
-seen_values: Set[Node] = set()
+seen_values: set[Node] = set()
 for node in self.nodes:
 if node.op not in [
 "placeholder",
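On the `_origin_type_map` hunk above: `typing.get_origin` returns the builtin class for both the old and the new spelling, so mapping each builtin to itself keeps generated annotations on the builtin generics. A small sketch of that behavior, using only the standard library and independent of torch.fx:

    from typing import List, get_args, get_origin

    assert get_origin(List[int]) is list      # old-style typing alias
    assert get_origin(list[int]) is list      # PEP 585 builtin generic
    assert get_args(dict[str, int]) == (str, int)

    _origin_type_map = {list: list, dict: dict, set: set, frozenset: frozenset, tuple: tuple}
    print(_origin_type_map[get_origin(tuple[int, ...])])  # <class 'tuple'>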
@@ -8,7 +8,7 @@ import sys
 import traceback
 import warnings
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
+from typing import Any, Callable, Optional, Union

 import torch
 import torch.nn as nn

@@ -39,7 +39,7 @@ class _EvalCacheLoader:
 self.eval_cache = {}
 self.next_id = 0

-def cache(self, src: str, globals: Dict[str, Any], co_fields=None):
+def cache(self, src: str, globals: dict[str, Any], co_fields=None):
 """Store the source in a private cache, and add a lazy entry in linecache
 that allows the source to be retrieved by 'filename'.

@@ -83,19 +83,19 @@ class _EvalCacheLoader:
 _loader = _EvalCacheLoader()


-def _exec_with_source(src: str, globals: Dict[str, Any], co_fields=None):
+def _exec_with_source(src: str, globals: dict[str, Any], co_fields=None):
 key = _loader.cache(src, globals, co_fields)
 exec(compile(src, key, "exec"), globals)


-def _forward_from_src(src: str, globals: Dict[str, Any], co_fields=None):
+def _forward_from_src(src: str, globals: dict[str, Any], co_fields=None):
 return _method_from_src(
 method_name="forward", src=src, globals=globals, co_fields=co_fields
 )


 def _method_from_src(
-method_name: str, src: str, globals: Dict[str, Any], co_fields=None
+method_name: str, src: str, globals: dict[str, Any], co_fields=None
 ) -> Callable:
 # avoid mutating the passed in dict
 globals_copy = globals.copy()

@@ -114,8 +114,8 @@ def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
 return f"from {module_name} import {attr_name} as {name}"


-def _format_import_block(globals: Dict[str, Any], importer: Importer):
+def _format_import_block(globals: dict[str, Any], importer: Importer):
-import_strs: Set[str] = {
+import_strs: set[str] = {
 _format_import_statement(name, obj, importer) for name, obj in globals.items()
 }
 # Sort the imports so we have a stable import block that allows us to

@@ -124,7 +124,7 @@ def _format_import_block(globals: Dict[str, Any], importer: Importer):


 @compatibility(is_backward_compatible=True)
-def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
+def reduce_graph_module(body: dict[Any, Any], import_block: str) -> torch.nn.Module:
 # BC: attribute name was changed from `code` to `_code` to facilitate
 # making `code` into a property and adding a docstring to it
 fn_src = body.get("_code") or body["code"]

@@ -134,7 +134,7 @@ def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Mod

 @compatibility(is_backward_compatible=True)
 def reduce_package_graph_module(
-importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
+importer: PackageImporter, body: dict[Any, Any], generated_module_name: str
 ) -> torch.nn.Module:
 forward = importer.import_module(generated_module_name).forward
 return _deserialize_graph_module(forward, body)

@@ -142,7 +142,7 @@ def reduce_package_graph_module(

 @compatibility(is_backward_compatible=True)
 def reduce_deploy_graph_module(
-importer: PackageImporter, body: Dict[Any, Any], import_block: str
+importer: PackageImporter, body: dict[Any, Any], import_block: str
 ) -> torch.nn.Module:
 ns = {}
 ns["__builtins__"] = importer.patched_builtins

@@ -162,7 +162,7 @@ class _CodeOnlyModule(torch.nn.Module):


 def _deserialize_graph_module(
-forward, body: Dict[Any, Any], graph_module_cls=None
+forward, body: dict[Any, Any], graph_module_cls=None
 ) -> torch.nn.Module:
 """
 Deserialize a GraphModule given the dictionary of the original module,

@@ -271,7 +271,7 @@ def _get_attr(model: torch.nn.Module, attr_name: str):
 return _get_attr_via_attr_list(model, attr_name.split("."))


-def _get_attr_via_attr_list(model: torch.nn.Module, attr_list: List[str]):
+def _get_attr_via_attr_list(model: torch.nn.Module, attr_list: list[str]):
 if len(attr_list) == 0:
 return model
 *prefix, field = attr_list

@@ -415,7 +415,7 @@ class GraphModule(torch.nn.Module):
 code.
 """

-def __new__(cls: "Type[GraphModule]", *args, **kwargs):
+def __new__(cls: "type[GraphModule]", *args, **kwargs):
 # each instance of a graph module needs its own forward method
 # so create a new singleton class for each instance.
 # it is a subclass of the user-defined class, the only difference

@@ -437,7 +437,7 @@ class GraphModule(torch.nn.Module):
 @compatibility(is_backward_compatible=True)
 def __init__(
 self,
-root: Union[torch.nn.Module, Dict[str, Any]],
+root: Union[torch.nn.Module, dict[str, Any]],
 graph: Graph,
 class_name: str = "GraphModule",
 ):

@@ -527,12 +527,12 @@ class GraphModule(torch.nn.Module):
 self._tracer_extras = self.graph._tracer_extras

 # Dictionary to store metadata
-self.meta: Dict[str, Any] = {}
+self.meta: dict[str, Any] = {}
-self._replace_hooks: List[Callable] = []
+self._replace_hooks: list[Callable] = []
-self._create_node_hooks: List[Callable] = []
+self._create_node_hooks: list[Callable] = []
-self._erase_node_hooks: List[Callable] = []
+self._erase_node_hooks: list[Callable] = []
 # Used to remove hooks from deepcopied graph modules within a context manager.
-self._deepcopy_hooks: List[Callable] = []
+self._deepcopy_hooks: list[Callable] = []

 # TorchScript breaks trying to compile the graph setter because of the
 # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842

@@ -739,7 +739,7 @@ class {module_name}(torch.nn.Module):
 This method can be called to clean up an ``nn.Module`` without
 manually calling ``delete_submodule`` on each unused submodule.
 """
-used: List[str] = []
+used: list[str] = []

 for node in self.graph.nodes:
 if node.op == "call_module" or node.op == "get_attr":
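One point worth noting about `__new__(cls: "type[GraphModule]", ...)` above: the annotation is a string, so it is not evaluated when the class body runs, and the lowercase `type[...]` form is only ever inspected by type checkers. An illustrative stand-in class, not the real GraphModule:

    class GraphModule:  # stand-in for torch.fx.GraphModule
        def __new__(cls: "type[GraphModule]", *args, **kwargs):
            # Forward-reference strings are left unevaluated at runtime.
            return super().__new__(cls)

    gm = GraphModule()
    print(type(gm).__name__)  # GraphModule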
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
-from typing import Any, Dict, Iterable, List, Tuple
+from collections.abc import Iterable
+from typing import Any

 from torch.utils._pytree import (
 _dict_flatten,

@@ -79,25 +80,25 @@ compatibility(is_backward_compatible=True)(immutable_dict)


 # Register immutable collections for PyTree operations
-def _immutable_dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
+def _immutable_dict_flatten(d: dict[Any, Any]) -> tuple[list[Any], Context]:
 return _dict_flatten(d)


 def _immutable_dict_unflatten(
 values: Iterable[Any],
 context: Context,
-) -> Dict[Any, Any]:
+) -> dict[Any, Any]:
 return immutable_dict(_dict_unflatten(values, context))


-def _immutable_list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
+def _immutable_list_flatten(d: list[Any]) -> tuple[list[Any], Context]:
 return _list_flatten(d)


 def _immutable_list_unflatten(
 values: Iterable[Any],
 context: Context,
-) -> List[Any]:
+) -> list[Any]:
 return immutable_list(_list_unflatten(values, context))
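The import change above moves `Iterable` to `collections.abc`, which since Python 3.9 serves both as a subscriptable annotation and as a runtime ABC. A minimal sketch of that dual role, unrelated to the pytree helpers themselves:

    from collections.abc import Iterable

    def to_list(values: Iterable[int]) -> list[int]:
        # The same name works for isinstance checks and for annotations.
        assert isinstance(values, Iterable)
        return list(values)

    print(to_list((1, 2, 3)))  # [1, 2, 3]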
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import inspect
 from contextlib import contextmanager
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Optional, TYPE_CHECKING, Union

 import torch
 import torch.fx.traceback as fx_traceback

@@ -17,6 +17,10 @@ from .node import Argument, map_aggregate, map_arg, Node, Target
 from .proxy import Proxy


+if TYPE_CHECKING:
+    from collections.abc import Iterator


 __all__ = ["Interpreter", "Transformer"]


@@ -92,7 +96,7 @@ class Interpreter:
 self.graph = graph
 else:
 self.graph = self.module.graph  # type: ignore[assignment]
-self.env: Dict[Node, Any] = {}
+self.env: dict[Node, Any] = {}
 self.name = "Interpreter"
 self.garbage_collect_values = garbage_collect_values
 self.extra_traceback = True

@@ -102,8 +106,8 @@ class Interpreter:
 # of a given node. This represents the *last* use of the node in the
 # execution order of the program, which we will use to free unused
 # values
-node_to_last_use: Dict[Node, Node] = {}
+node_to_last_use: dict[Node, Node] = {}
-self.user_to_last_uses: Dict[Node, List[Node]] = {}
+self.user_to_last_uses: dict[Node, list[Node]] = {}

 def register_last_uses(n: Node, user: Node):
 if n not in node_to_last_use:

@@ -118,7 +122,7 @@ class Interpreter:
 def run(
 self,
 *args,
-initial_env: Optional[Dict[Node, Any]] = None,
+initial_env: Optional[dict[Node, Any]] = None,
 enable_io_processing: bool = True,
 ) -> Any:
 """

@@ -232,7 +236,7 @@ class Interpreter:
 # Main Node running APIs
 @compatibility(is_backward_compatible=True)
 def placeholder(
-self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
 ) -> Any:
 """
 Execute a ``placeholder`` node. Note that this is stateful:

@@ -268,7 +272,7 @@ class Interpreter:

 @compatibility(is_backward_compatible=True)
 def get_attr(
-self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
 ) -> Any:
 """
 Execute a ``get_attr`` node. Will retrieve an attribute

@@ -289,7 +293,7 @@ class Interpreter:

 @compatibility(is_backward_compatible=True)
 def call_function(
-self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
 ) -> Any:
 """
 Execute a ``call_function`` node and return the result.

@@ -311,7 +315,7 @@ class Interpreter:

 @compatibility(is_backward_compatible=True)
 def call_method(
-self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
 ) -> Any:
 """
 Execute a ``call_method`` node and return the result.

@@ -335,7 +339,7 @@ class Interpreter:

 @compatibility(is_backward_compatible=True)
 def call_module(
-self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
 ) -> Any:
 """
 Execute a ``call_module`` node and return the result.

@@ -360,7 +364,7 @@ class Interpreter:

 @compatibility(is_backward_compatible=True)
 def output(
-self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
 ) -> Any:
 """
 Execute an ``output`` node. This really just retrieves

@@ -401,7 +405,7 @@ class Interpreter:
 return attr_itr

 @compatibility(is_backward_compatible=True)
-def fetch_args_kwargs_from_env(self, n: Node) -> Tuple[Tuple, Dict]:
+def fetch_args_kwargs_from_env(self, n: Node) -> tuple[tuple, dict]:
 """
 Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
 from the current execution environment.

@@ -497,7 +501,7 @@ class Transformer(Interpreter):
 def __init__(self, graph: Graph):
 super().__init__()
 self.graph = graph
-self.tensor_attrs: Dict[torch.Tensor, str] = {}  # type: ignore[assignment]
+self.tensor_attrs: dict[torch.Tensor, str] = {}  # type: ignore[assignment]

 def is_leaf_module(self, _, __) -> bool:
 return True

@@ -507,7 +511,7 @@ class Transformer(Interpreter):

 @compatibility(is_backward_compatible=True)
 def placeholder(
-self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
 ) -> Proxy:
 """
 Execute a ``placeholder`` node. In ``Transformer``, this is

@@ -529,7 +533,7 @@ class Transformer(Interpreter):

 @compatibility(is_backward_compatible=True)
 def get_attr(
-self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
 ) -> Proxy:
 """
 Execute a ``get_attr`` node. In ``Transformer``, this is

@@ -548,7 +552,7 @@ class Transformer(Interpreter):

 @compatibility(is_backward_compatible=True)
 def call_module(
-self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
 ) -> Any:
 # Override so that the leaf module policy from `self.tracer` is respected.
 assert isinstance(target, str)

@@ -557,7 +561,7 @@ class Transformer(Interpreter):

 @compatibility(is_backward_compatible=True)
 def call_function(
-self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
+self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any]
 ) -> Any:
 # Override so that functions that were wrapped are still wrapped.
 return self.tracer.create_proxy("call_function", target, args, kwargs)
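The new `if TYPE_CHECKING:` block above imports `Iterator` purely for annotations. A self-contained sketch of the pattern, with a hypothetical generator rather than anything from Interpreter:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only type checkers see this import; there is no runtime cost.
        from collections.abc import Iterator

    def count_up(limit: int) -> "Iterator[int]":
        # The quoted annotation is never evaluated, so the module works
        # even though Iterator is not bound at runtime.
        yield from range(limit)

    print(list(count_up(3)))  # [0, 1, 2]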
@ -3,19 +3,8 @@ import builtins
|
|||||||
import inspect
|
import inspect
|
||||||
import types
|
import types
|
||||||
import warnings
|
import warnings
|
||||||
from typing import (
|
from collections.abc import Mapping, Sequence
|
||||||
Any,
|
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Mapping,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
Set,
|
|
||||||
Tuple,
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch._C import _NodeBase
|
from torch._C import _NodeBase
|
||||||
@ -57,7 +46,7 @@ Target = Union[Callable[..., Any], str]
|
|||||||
|
|
||||||
Argument = Optional[
|
Argument = Optional[
|
||||||
Union[
|
Union[
|
||||||
Tuple["Argument", ...],
|
tuple["Argument", ...],
|
||||||
Sequence["Argument"],
|
Sequence["Argument"],
|
||||||
Mapping[str, "Argument"],
|
Mapping[str, "Argument"],
|
||||||
slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
|
slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
|
||||||
@ -79,7 +68,7 @@ _legal_ops = dict.fromkeys(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
_side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = {
|
_side_effectful_need_to_be_preserved_pre_dispatch: set[Callable] = {
|
||||||
torch._C._set_grad_enabled,
|
torch._C._set_grad_enabled,
|
||||||
torch.amp._enter_autocast,
|
torch.amp._enter_autocast,
|
||||||
torch.amp._exit_autocast,
|
torch.amp._exit_autocast,
|
||||||
@ -87,7 +76,7 @@ _side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = {
|
|||||||
|
|
||||||
# TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs,
|
# TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs,
|
||||||
# or add logic to correctly mark all inplace ops as side effectful.
|
# or add logic to correctly mark all inplace ops as side effectful.
|
||||||
_side_effectful_functions: Set[Callable] = {
|
_side_effectful_functions: set[Callable] = {
|
||||||
torch._assert,
|
torch._assert,
|
||||||
torch._assert_async,
|
torch._assert_async,
|
||||||
_ops.aten._assert_async.msg,
|
_ops.aten._assert_async.msg,
|
||||||
@ -227,18 +216,18 @@ class Node(_NodeBase):
|
|||||||
in the Graph printout.
|
in the Graph printout.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_args: Tuple["Argument", ...]
|
_args: tuple["Argument", ...]
|
||||||
_kwargs: Dict[str, "Argument"]
|
_kwargs: dict[str, "Argument"]
|
||||||
graph: "Graph"
|
graph: "Graph"
|
||||||
name: str
|
name: str
|
||||||
op: str
|
op: str
|
||||||
target: "Target"
|
target: "Target"
|
||||||
_input_nodes: Dict["Node", None]
|
_input_nodes: dict["Node", None]
|
||||||
users: Dict["Node", None]
|
users: dict["Node", None]
|
||||||
type: Optional[Any]
|
type: Optional[Any]
|
||||||
_sort_key: Any
|
_sort_key: Any
|
||||||
_repr_fn: Optional[Callable[["Node"], str]]
|
_repr_fn: Optional[Callable[["Node"], str]]
|
||||||
meta: Dict[str, Any]
|
meta: dict[str, Any]
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=True)
|
@compatibility(is_backward_compatible=True)
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -247,8 +236,8 @@ class Node(_NodeBase):
|
|||||||
name: str,
|
name: str,
|
||||||
op: str,
|
op: str,
|
||||||
target: "Target",
|
target: "Target",
|
||||||
args: Tuple["Argument", ...],
|
args: tuple["Argument", ...],
|
||||||
kwargs: Dict[str, "Argument"],
|
kwargs: dict[str, "Argument"],
|
||||||
return_type: Optional[Any] = None,
|
return_type: Optional[Any] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@ -339,14 +328,14 @@ class Node(_NodeBase):
|
|||||||
# transformations. This metadata is preserved across node copies
|
# transformations. This metadata is preserved across node copies
|
||||||
assign(self, "meta", {})
|
assign(self, "meta", {})
|
||||||
|
|
||||||
def __getstate__(self) -> Dict[str, Any]:
|
def __getstate__(self) -> dict[str, Any]:
|
||||||
state = self.__dict__.copy()
|
state = self.__dict__.copy()
|
||||||
state["_erased"] = self._erased
|
state["_erased"] = self._erased
|
||||||
state["_prev"] = self._prev
|
state["_prev"] = self._prev
|
||||||
state["_next"] = self._next
|
state["_next"] = self._next
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def __setstate__(self, state: Dict[str, Any]) -> None:
|
def __setstate__(self, state: dict[str, Any]) -> None:
|
||||||
_erased = state.pop("_erased")
|
_erased = state.pop("_erased")
|
||||||
_prev = state.pop("_prev")
|
_prev = state.pop("_prev")
|
||||||
_next = state.pop("_next")
|
_next = state.pop("_next")
|
||||||
@ -442,7 +431,7 @@ class Node(_NodeBase):
|
|||||||
p._next, n._prev = n, p
|
p._next, n._prev = n, p
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def args(self) -> Tuple[Argument, ...]:
|
def args(self) -> tuple[Argument, ...]:
|
||||||
"""
|
"""
|
||||||
The tuple of arguments to this ``Node``. The interpretation of arguments
|
The tuple of arguments to this ``Node``. The interpretation of arguments
|
||||||
depends on the node's opcode. See the :class:`Node` docstring for more
|
depends on the node's opcode. See the :class:`Node` docstring for more
|
||||||
@ -454,7 +443,7 @@ class Node(_NodeBase):
|
|||||||
return self._args
|
return self._args
|
||||||
|
|
||||||
@args.setter
|
@args.setter
|
||||||
def args(self, a: Tuple[Argument, ...]) -> None:
|
def args(self, a: tuple[Argument, ...]) -> None:
|
||||||
"""
|
"""
|
||||||
Set the tuple of arguments to this Node. The interpretation of arguments
|
Set the tuple of arguments to this Node. The interpretation of arguments
|
||||||
depends on the node's opcode. See the ``fx.Graph`` docstring for more
@ -465,7 +454,7 @@ class Node(_NodeBase):
self.__update_args_kwargs(a, self._kwargs)

@property
def kwargs(self) -> Dict[str, Argument]:
def kwargs(self) -> dict[str, Argument]:
"""
The dict of keyword arguments to this ``Node``. The interpretation of arguments
depends on the node's opcode. See the :class:`Node` docstring for more
@ -477,7+466,7 @@ class Node(_NodeBase):
return self._kwargs

@kwargs.setter
def kwargs(self, k: Dict[str, Argument]) -> None:
def kwargs(self, k: dict[str, Argument]) -> None:
"""
Set the dict of kwargs to this Node. The interpretation of arguments
depends on the node's opcode. See the ``fx.Graph`` docstring for more
@ -488,7 +477,7 @@ class Node(_NodeBase):
self.__update_args_kwargs(self._args, k)

@property
def all_input_nodes(self) -> List["Node"]:
def all_input_nodes(self) -> list["Node"]:
"""
Return all Nodes that are inputs to this Node. This is equivalent to
iterating over ``args`` and ``kwargs`` and only collecting the values that
@ -534,7 +523,7 @@ class Node(_NodeBase):

self._args = args_left + (arg,) + args_right

_new_input_nodes: Dict[Node, None] = {}
_new_input_nodes: dict[Node, None] = {}
map_arg(arg, _new_input_nodes.setdefault)

for new_use in _new_input_nodes.keys():
@ -574,7 +563,7 @@ class Node(_NodeBase):
self.meta["stack_trace"] = trace

def __update_args_kwargs(
self, new_args: Tuple["Argument", ...], new_kwargs: Dict[str, "Argument"]
self, new_args: tuple["Argument", ...], new_kwargs: dict[str, "Argument"]
) -> None:
"""
This API is internal. Do *not* call it directly.
@ -634,8 +623,8 @@ class Node(_NodeBase):
@compatibility(is_backward_compatible=True)
def format_node(
self,
placeholder_names: Optional[List[str]] = None,
placeholder_names: Optional[list[str]] = None,
maybe_return_typename: Optional[List[str]] = None,
maybe_return_typename: Optional[list[str]] = None,
) -> Optional[str]:
"""
Return a descriptive string representation of ``self``.
@ -704,7 +693,7 @@ class Node(_NodeBase):
delete_user_cb: Callable[["Node"], bool] = lambda user: True,
*,
propagate_meta: bool = False,
) -> List["Node"]:
) -> list["Node"]:
"""
Replace all uses of ``self`` in the Graph with the Node ``replace_with``.

@ -775,7 +764,7 @@ class Node(_NodeBase):
# impure since it mutates inputs
return True

tags: Optional[List[torch.Tag]] = getattr(self.target, "_tags", None)
tags: Optional[list[torch.Tag]] = getattr(self.target, "_tags", None)
if tags is not None and torch.Tag.nondeterministic_seeded in tags:
# impure since it mutates RNG state
return True
@ -799,8 +788,8 @@ class Node(_NodeBase):
def normalized_arguments(
self,
root: torch.nn.Module,
arg_types: Optional[Tuple[Any]] = None,
arg_types: Optional[tuple[Any]] = None,
kwarg_types: Optional[Dict[str, Any]] = None,
kwarg_types: Optional[dict[str, Any]] = None,
normalize_to_only_use_kwargs: bool = False,
) -> Optional[ArgsKwargsPair]:
"""

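Editor's note: every hunk in this patch makes the same substitution, moving annotations from typing.Dict/List/Tuple/Set to the PEP 585 builtin generics. A minimal, self-contained illustration of the target style follows; the function below is made up for demonstration and is not part of the patch.

# PEP 585: builtin containers are subscriptable in annotations on Python 3.9+,
# so the typing.* aliases removed above are no longer needed.
from typing import get_type_hints

def collect_kwargs(pairs: list[tuple[str, int]]) -> dict[str, int]:
    return dict(pairs)

hints = get_type_hints(collect_kwargs)
assert hints["pairs"] == list[tuple[str, int]]
assert hints["return"] == dict[str, int]
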
@ -5,17 +5,7 @@ import numbers
import types
import typing
import warnings
from typing import (
from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING
Any,
Callable,
cast,
Dict,
List,
NamedTuple,
Optional,
Tuple,
TYPE_CHECKING,
)

import torch
from torch._jit_internal import boolean_dispatched
@ -44,11 +34,11 @@ class ArgsKwargsPair(NamedTuple):
Simple named tuple for wrapping args/kwargs pairs.
"""

args: Tuple[Any, ...]
args: tuple[Any, ...]
kwargs: Dict[str, Any]
kwargs: dict[str, Any]


_manual_overrides: Dict[Callable, List[inspect.Signature]] = {}
_manual_overrides: dict[Callable, list[inspect.Signature]] = {}


def _nonzero_schemas():
@ -108,7 +98,7 @@ def _torchscript_schema_to_signature_impl(
) -> inspect.Signature:
from inspect import Parameter

parameters: List[Parameter] = []
parameters: list[Parameter] = []
for arg in ts_schema.arguments:
arg_type = _torchscript_type_to_python_type(arg.type)
default = arg.default_value if arg.has_default_value() else Parameter.empty
@ -154,7 +144,7 @@ def _torchscript_schema_to_signature_impl(
return inspect.Signature(parameters, return_annotation=return_type)


_SCHEMA_TO_SIGNATURE_CACHE: Dict[Tuple[str, str], inspect.Signature] = {}
_SCHEMA_TO_SIGNATURE_CACHE: dict[tuple[str, str], inspect.Signature] = {}


def _torchscript_schema_to_signature(
@ -173,7 +163,7 @@ def _torchscript_schema_to_signature(

@compatibility(is_backward_compatible=False)
def check_for_mutable_operation(
target: Callable, args: Tuple["Argument", ...], kwargs: Dict[str, "Argument"]
target: Callable, args: tuple["Argument", ...], kwargs: dict[str, "Argument"]
):
signatures, schemas = get_signature_for_torch_op(target, return_schemas=True)

@ -265,12 +255,12 @@ def create_type_hint(x):
if isinstance(x, list):

def ret_type(x):
return List[x]  # type: ignore[valid-type]
return list[x]  # type: ignore[valid-type]

else:

def ret_type(x):
return Tuple[x, ...]
return tuple[x, ...]  # type: ignore[valid-type]

if len(x) == 0:
return ret_type(Any)
@ -291,6 +281,10 @@ def create_type_hint(x):
return x


_LIST_TYPES = (list, typing.List)  # noqa: UP006
_TUPLE_TYPES = (tuple, typing.Tuple)  # noqa: UP006


@compatibility(is_backward_compatible=False)
def type_matches(signature_type: Any, argument_type: Any):
sig_origin_type = getattr(signature_type, "__origin__", signature_type)
@ -304,22 +298,24 @@ def type_matches(signature_type: Any, argument_type: Any):
sig_contained = signature_type.__args__
return any(type_matches(c, argument_type) for c in sig_contained)

if signature_type is List[int] and argument_type is int:
if signature_type is typing.List[int] and argument_type is int:  # noqa: UP006
# int can be promoted to List[int]
return True

if getattr(signature_type, "__origin__", None) in {list, List}:
if getattr(signature_type, "__origin__", None) in _LIST_TYPES:
sig_el_type = signature_type.__args__[0]
if sig_el_type is argument_type:
return True
if not inspect.isclass(sig_el_type):
warnings.warn(
f"Does not support nested parametric types, got {signature_type}. Please file a bug."
)
return False
if getattr(argument_type, "__origin__", None) in {list, List}:
if getattr(argument_type, "__origin__", None) in _LIST_TYPES:
return issubclass(argument_type.__args__[0], sig_el_type)

def is_homogeneous_tuple(t):
if getattr(t, "__origin__", None) not in {tuple, Tuple}:
if getattr(t, "__origin__", None) not in _TUPLE_TYPES:
return False
contained = t.__args__
if t.__args__ == ((),):  # Tuple[()].__args__ == ((),) for some reason
@ -344,10 +340,10 @@ def type_matches(signature_type: Any, argument_type: Any):
@compatibility(is_backward_compatible=False)
def normalize_function(
target: Callable,
args: Tuple[Any],
args: tuple[Any],
kwargs: Optional[Dict[str, Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
arg_types: Optional[Tuple[Any]] = None,
arg_types: Optional[tuple[Any]] = None,
kwarg_types: Optional[Dict[str, Any]] = None,
kwarg_types: Optional[dict[str, Any]] = None,
normalize_to_only_use_kwargs: bool = False,
) -> Optional[ArgsKwargsPair]:
"""
@ -424,7 +420,7 @@ def normalize_function(
)
else:
if arg_types is not None or kwarg_types is not None:
arg_types = arg_types if arg_types else cast(Tuple[Any], ())
arg_types = arg_types if arg_types else cast(tuple[Any], ())
kwarg_types = kwarg_types if kwarg_types else {}
for candidate_signature in torch_op_schemas:
sig_matches = True
@ -468,8 +464,8 @@ def normalize_function(
def normalize_module(
root: torch.nn.Module,
target: str,
args: Tuple[Any],
args: tuple[Any],
kwargs: Optional[Dict[str, Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
normalize_to_only_use_kwargs: bool = False,
) -> Optional[ArgsKwargsPair]:
"""
@ -513,8 +509,8 @@ def normalize_module(

def _args_kwargs_to_normalized_args_kwargs(
sig: inspect.Signature,
args: Tuple[Any, ...],
args: tuple[Any, ...],
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
normalize_to_only_use_kwargs: bool,
) -> Optional[ArgsKwargsPair]:
"""
@ -552,8 +548,8 @@ def _args_kwargs_to_normalized_args_kwargs(
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()

new_kwargs: Dict[str, Any] = {}
new_kwargs: dict[str, Any] = {}
new_args: List[Any] = []
new_args: list[Any] = []
for i, param in enumerate(sig.parameters):
if not normalize_to_only_use_kwargs and i < len(args):
new_args.append(bound_args.arguments[param])

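Editor's note on the `_LIST_TYPES`/`_TUPLE_TYPES` tuples introduced above: `typing.List[int]` and `list[int]` are distinct objects but report the same origin, so origin-membership checks need to accept both spellings. A short sketch of the behavior this relies on (CPython 3.9+):

import typing

# Both spellings of a parametrized list report `list` as their origin,
# but the parametrized aliases themselves are different objects.
assert typing.get_origin(typing.List[int]) is list
assert typing.get_origin(list[int]) is list
assert typing.List[int] is not list[int]

_LIST_TYPES = (list, typing.List)
assert getattr(typing.List[int], "__origin__", None) in _LIST_TYPES
assert getattr(list[int], "__origin__", None) in _LIST_TYPES
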
@ -2,7 +2,7 @@ from __future__ import annotations

import logging
import os
from typing import Any, List, Set, Union
from typing import Any, Union

from sympy import Integer, Number, Symbol
from sympy.logic.boolalg import BooleanAtom
@ -28,7 +28,7 @@ from torch.utils._sympy.reference import TensorReferenceAnalysis
from torch.utils._sympy.symbol import symbol_is_type, SymT


__all__: List[str] = []
__all__: list[str] = []

log = logging.getLogger(__name__)
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
@ -242,7 +242,7 @@ def tensorify_python_scalars(
if node.op == "call_function" and (
replacement_op := SUPPORTED_OPS.get(node.target)
):
args: List[Any] = []
args: list[Any] = []
transform = False
compute_dtype = get_computation_dtype(node.meta["val"].dtype)

@ -299,7 +299,7 @@ def tensorify_python_scalars(
"tensorify_float_success", True, overwrite=True
)

failed_tensorify_ops: Set[str] = set()
failed_tensorify_ops: set[str] = set()

# Now do one more pass that specializes all symfloats we didn't manage
# to tensorify away.

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, Tuple
from typing import Any

import torch
from torch.fx import Graph, GraphModule, Node
@ -90,14 +90,14 @@ class CSEPass(PassBase):

modified = False
new_graph = Graph()
env: Dict[
env: dict[
Node, Node
] = {}  # map from node in the old graph to node in the new graph
hash_env: Dict[
hash_env: dict[
Tuple[torch._ops.OpOverload, int], Node
tuple[torch._ops.OpOverload, int], Node
] = {}  # map from hash to a node in the new graph
token_map: Dict[
token_map: dict[
Tuple[torch._ops.OpOverload, int], Dict[str, Any]
tuple[torch._ops.OpOverload, int], dict[str, Any]
] = {}  # map from hash to token
for n in graph_module.graph.nodes:
# The placeholder, output, and get_attr nodes are copied to the new graph without change

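Editor's note: the `env` / `hash_env` pair in this hunk is the usual hash-consing shape for common subexpression elimination. A standalone sketch of the idea, with plain tuples standing in for FX nodes (this is not the CSEPass implementation):

# Minimal hash-consing sketch: reuse a previously built value when an
# identical (op, args) key has already been seen.
from typing import Any

env: dict[str, Any] = {}                   # old name -> new value
hash_env: dict[tuple[Any, ...], str] = {}  # (op, args) -> name of new value

def emit(name: str, op: str, *args: int) -> str:
    key = (op, args)
    if key in hash_env:            # common subexpression: reuse it
        env[name] = env[hash_env[key]]
        return hash_env[key]
    env[name] = (op, args)         # stand-in for building a new node
    hash_env[key] = name
    return name

emit("a", "add", 1, 2)
assert emit("b", "add", 1, 2) == "a"   # the second identical op is deduplicated
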
@ -2,7 +2,7 @@

import hashlib
from itertools import chain
from typing import Any, Dict, Optional, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING

import torch
import torch.fx
@ -150,10 +150,10 @@ if HAS_PYDOT:
def get_submod_dot_graph(self, submod_name) -> pydot.Dot:
return self._dot_graphs[f"{self._name}_{submod_name}"]

def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]:
def get_all_dot_graphs(self) -> dict[str, pydot.Dot]:
return self._dot_graphs

def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]:
def _get_node_style(self, node: torch.fx.Node) -> dict[str, str]:
template = {
"shape": self.dot_graph_shape,
"fillcolor": "#CAFFE3",

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, List, NamedTuple, Optional
from typing import Any, NamedTuple, Optional

import torch
from torch.fx._compatibility import compatibility
@ -29,7 +29,7 @@ def replace_target_nodes_with(
"""Modifies all nodes in fx_module.graph.nodes which match the specified op code and target,
and updates them to match the new op code and target"""
new_graph = Graph()
val_map: Dict[Node, Node] = {}
val_map: dict[Node, Node] = {}
for node in fx_module.graph.nodes:
if node.op == old_op and node.target == old_target:
args = map_arg(node.args, lambda n: val_map[n])
@ -52,7 +52,7 @@ class size_bytes(NamedTuple):

@compatibility(is_backward_compatible=False)
def get_size_of_all_nodes(
fx_module: GraphModule, args: Optional[List[torch.Tensor]] = None
fx_module: GraphModule, args: Optional[list[torch.Tensor]] = None
) -> None:
"""Given a fx graph module, update each node with its total size (weights + bias + output)
and its output_size(output). For a non-module node, the total size is the output size.

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import os
from typing import Callable, Dict, List, Optional, Set, TypeVar
from typing import Callable, Optional, TypeVar

from torch.fx import Graph, Node
from torch.fx._compatibility import compatibility
@ -45,11 +45,11 @@ class GraphTransformObserver:
self.active = trace.enabled or self.log_url is not None

if self.active:
self.erased_nodes: Set[str] = set()
self.erased_nodes: set[str] = set()
self.created_nodes: Set[str] = set()
self.created_nodes: set[str] = set()
self.name_to_node: Dict[str, Node] = {}
self.name_to_node: dict[str, Node] = {}
# record graph modules deepcopied from self.gm, so we can remove hoooks on them when exiting the context
self.copied_gms: List[GraphModule] = []
self.copied_gms: list[GraphModule] = []

self._node_creation_hook = self.get_node_creation_hook()
self._node_erase_hook = self.get_node_erase_hook()

@ -2,8 +2,9 @@
import collections
import itertools
import logging
from collections.abc import Iterable, Sequence
from copy import copy
from typing import Dict, Iterable, List, Optional, Sequence, Set
from typing import Optional

from torch.fx.graph_module import GraphModule
from torch.fx.node import _get_qualified_name, Node
@ -52,10 +53,10 @@ class _DependencyViewer:
self.downstreams[node].add(output_node)
self.downstreams[node].update(self.downstreams[output_node])

def downstreams_of(self, node: Node) -> Set[Node]:
def downstreams_of(self, node: Node) -> set[Node]:
return self.downstreams[node]

def upstreams_of(self, node: Node) -> Set[Node]:
def upstreams_of(self, node: Node) -> set[Node]:
return self.upstreams[node]


@ -84,21 +85,21 @@ class CapabilityBasedPartitioner:
dict(self.graph_module.named_modules()), node
)

def propose_partitions(self) -> List[Partition]:
def propose_partitions(self) -> list[Partition]:
# partition_map is a mapping from partition id to a set of partition id's.
# The value set contains all the partition ids that can be reached by doing a
# DFS starting from the partition id in the key.
partition_map: Dict[int, Set] = collections.defaultdict(set)
partition_map: dict[int, set] = collections.defaultdict(set)

# assumptions: nodes in candidate list is sorted in topological order
assignment: Dict[Node, int] = {}  # mapping from node to partition_id
assignment: dict[Node, int] = {}  # mapping from node to partition_id
partitions_by_id: Dict[
partitions_by_id: dict[
int, Partition
] = {}  # mapping from partition_id to partition
nodes_order: Dict[
nodes_order: dict[
Node, int
] = {}  # mapping from nodes to reversed topological order
partitions_order: Dict[
partitions_order: dict[
int, int
] = {}  # mapping from partition_id to minimum topo order of nodes in partition
new_partition_id = itertools.count()
@ -111,7 +112,7 @@ class CapabilityBasedPartitioner:
merged_nodes = copy(partitions_by_id[self_id].nodes)
merged_nodes.update(partitions_by_id[other_id].nodes)

def dfs_iter_find_cycle(all_user_nodes: Set[Node]):
def dfs_iter_find_cycle(all_user_nodes: set[Node]):
for user_node in all_user_nodes:
visited_partition_ids = set()

@ -210,7 +211,7 @@ class CapabilityBasedPartitioner:

for node in reversed(self.graph_module.graph.nodes):
# use Dict as an ordered set to ensure deterministic partitioning result, don't care value
merge_candidates: Dict[int, None] = {}
merge_candidates: dict[int, None] = {}

# Note a limited horizontal fusion is enabled:
# when `node` is not supported, the code below attempts to fuse consumer of `node`.
@ -241,7 +242,7 @@ class CapabilityBasedPartitioner:

# post processing to re-assign "getitem" nodes into upstream partition
logger.debug("Reassigning getitem nodes to its producer node's partition...")
nodes_reassignment: Dict[Node, int] = {}
nodes_reassignment: dict[Node, int] = {}
for node in self.graph_module.graph.nodes:
is_tuple_output = True
for user in node.users:
@ -266,7 +267,7 @@ class CapabilityBasedPartitioner:
logger.debug("Filtering out single node partitions...")
default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
partitions_to_remove: List[int] = []
partitions_to_remove: list[int] = []
for id, partition in partitions_by_id.items():
compute_node_count = 0
for node in partition.nodes:
@ -295,7 +296,7 @@ class CapabilityBasedPartitioner:
]

def fuse_partitions(
self, partitions: List[Partition], prefix: str = "fused_"
self, partitions: list[Partition], prefix: str = "fused_"
) -> GraphModule:
logger.debug("Fusing partitions...")
# fuse_by_partitions expects partitions in List[Dict[Node, None]]: [ {node0 : None}, {node1 : None} ]
@ -306,7 +307,7 @@ class CapabilityBasedPartitioner:
)

# remove non-compute-ops that sits at the boundary of a partition.
def remove_bookend_non_compute_ops(self, partitions: List[Partition]):
def remove_bookend_non_compute_ops(self, partitions: list[Partition]):
non_compute_ops = set(self.non_compute_ops)

def is_non_compute_node(node: Node):
@ -316,11 +317,11 @@ class CapabilityBasedPartitioner:
)

# cache transparent nodes
transparent_input_nodes: Dict[Node, bool] = {}
transparent_input_nodes: dict[Node, bool] = {}
transparent_output_nodes: Dict[Node, bool] = {}
transparent_output_nodes: dict[Node, bool] = {}

def is_transparent_input_node(
node: Node, partition: Set[Node], removed_nodes: Set[Node]
node: Node, partition: set[Node], removed_nodes: set[Node]
):
if (
node.op == "placeholder"
@ -341,7 +342,7 @@ class CapabilityBasedPartitioner:
return False

def is_transparent_output_node(
node: Node, partition: Set[Node], removed_nodes: Set[Node]
node: Node, partition: set[Node], removed_nodes: set[Node]
):
if (
node.op == "placeholder"
@ -367,7 +368,7 @@ class CapabilityBasedPartitioner:
# Note it's ok to use `set` here, since we are only query if a node
# has been removed. We are NEVER going to iterate on nodes inside
# the set.
remove_node: Set[Node] = set()
remove_node: set[Node] = set()
for node in partition.nodes:
if is_non_compute_node(node) and (
is_transparent_input_node(node, set(partition.nodes), remove_node)

@ -3,7 +3,7 @@ import inspect
import logging
from functools import wraps
from queue import Queue
from typing import Callable, Dict, List
from typing import Callable

import torch.nn as nn
from torch.fx._compatibility import compatibility
@ -50,7 +50,7 @@ def pass_result_wrapper(fn: Callable) -> Callable:


def _validate_pass_schedule_constraint(
constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
constraint: Callable[[Callable, Callable], bool], passes: list[Callable]
) -> None:
for i, a in enumerate(passes):
for j, b in enumerate(passes[i + 1 :]):
@ -64,8 +64,8 @@ def _validate_pass_schedule_constraint(


def _topological_sort_passes(
passes: List[Callable], constraints: List[Callable]
passes: list[Callable], constraints: list[Callable]
) -> List[Callable]:
) -> list[Callable]:
"""
Args
passes: Passes that we are ordering
@ -79,8 +79,8 @@ def _topological_sort_passes(
return passes

# Contruct a graph mapping nodes to a list of their users
graph: Dict[Callable, List[Callable]] = {p: [] for p in passes}
graph: dict[Callable, list[Callable]] = {p: [] for p in passes}
indegree_map: Dict[Callable, int] = dict.fromkeys(passes, 0)
indegree_map: dict[Callable, int] = dict.fromkeys(passes, 0)
candidates: Queue = Queue()
for a in passes:
for b in passes:
@ -95,8 +95,8 @@ def _topological_sort_passes(
if indegree_map[a] == 0:
candidates.put(a)

visited: Dict[Callable, bool] = dict.fromkeys(passes, False)
visited: dict[Callable, bool] = dict.fromkeys(passes, False)
sorted_passes: List[Callable] = []
sorted_passes: list[Callable] = []

while not candidates.empty():
p = candidates.get()
@ -169,8 +169,8 @@ class PassManager:
checks
"""

passes: List[Callable[[nn.Module], PassResult]]
passes: list[Callable[[nn.Module], PassResult]]
constraints: List[Callable[[Callable, Callable], bool]]
constraints: list[Callable[[Callable, Callable], bool]]
_validated: bool = False
steps: int = 1

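Editor's note: the `indegree_map` / `candidates` queue touched above is Kahn's algorithm over the pass-constraint graph. A reduced sketch with plain strings standing in for passes (the `edges` mapping here is a made-up stand-in for the constraint-derived user lists):

# Reduced sketch of constraint-driven pass ordering (Kahn's algorithm).
from queue import Queue

def order(passes: list[str], edges: dict[str, list[str]]) -> list[str]:
    indegree = dict.fromkeys(passes, 0)
    for users in edges.values():
        for u in users:
            indegree[u] += 1
    ready: Queue = Queue()
    for p in passes:
        if indegree[p] == 0:
            ready.put(p)
    out: list[str] = []
    while not ready.empty():
        p = ready.get()
        out.append(p)
        for u in edges.get(p, []):
            indegree[u] -= 1
            if indegree[u] == 0:
                ready.put(u)
    if len(out) != len(passes):
        raise RuntimeError("circular constraint between passes")
    return out

assert order(["a", "b", "c"], {"a": ["b"], "b": ["c"]}) == ["a", "b", "c"]
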
@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Optional

import torch
import torch.fx
@ -106,7 +106,7 @@ class _MinimizerBase:
module: torch.fx.GraphModule,
sample_input: Tensors,
compare_fn: Callable[
[TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool]
[TensorOrTensors, TensorOrTensors, Names], tuple[float, bool]
],
settings: _MinimizerSettingBase,
module_exporter: Optional[
@ -124,16 +124,16 @@ class _MinimizerBase:
self.exclusion_fn = exclusion_fn

# Stores outputs of run_a function
self.a_outputs: Dict[str, Any] = {}
self.a_outputs: dict[str, Any] = {}

# Stores outputs of run_b function
self.b_outputs: Dict[str, Any] = {}
self.b_outputs: dict[str, Any] = {}

# Stores the results of compare_fn
self.results: Dict[Any, Any] = {}
self.results: dict[Any, Any] = {}

# Stores the report for the runs
self.reports: List[List[str]] = []
self.reports: list[list[str]] = []

# Current iteration
self.iteration: int = 0
@ -205,7 +205,7 @@ class _MinimizerBase:

def _get_submod_inputs(
self, main_module: torch.fx.GraphModule, submod_path: str
) -> Tuple[Tensors, Tensors]:
) -> tuple[Tensors, Tensors]:
"""
Try get submodule inputs from stored outputs. If not found then use
torch_glow.get_submod_inputs to get the inputs.
@ -280,7 +280,7 @@ class _MinimizerBase:
else:
node.tag = "main_0"

def _build_submodule(self, nodes: NodeSet) -> Tuple[torch.fx.GraphModule, str]:
def _build_submodule(self, nodes: NodeSet) -> tuple[torch.fx.GraphModule, str]:
"""
Split self.module so that one submodule consists of `nodes` and only `nodes`.

@ -412,7 +412,7 @@ class _MinimizerBase:
culprits: NodeSet = set()
nodes: NodeList = all_nodes[start_idx:end_idx]

report: List[str] = []
report: list[str] = []
if self.exclusion_fn is not None:
self.exclusion_fn(nodes, start_idx, end_idx)
if len(nodes) == 0:
@ -484,7 +484,7 @@ class _MinimizerBase:
culprits: NodeSet = set()

for node in nodes:
report: List[str] = []
report: list[str] = []
self.reports.append(report)
self.iteration += 1
report.append(f"Sequential traverse iteration {self.iteration}.")
@ -534,7 +534,7 @@ class _MinimizerBase:
find_last_node: If True, search for the last node which result in numerics difference
if False: find first node in sorted node list
"""
report: List[str] = []
report: list[str] = []

mid = (start_idx + end_idx) // 2
cur_nodes_list: NodeList = nodes[: mid + 1] if find_last_node else nodes[mid:]
@ -726,7 +726,7 @@ class _MinimizerBase:
return culprits

for node in nodes:
report: List[str] = []
report: list[str] = []
self.reports.append(report)
self.iteration += 1
report.append(f"Accumulate traverse iteration {self.iteration}.")
@ -770,7 +770,7 @@ class _MinimizerBase:
for node in nodes:
if node in self.fusions:
cur_nodes.update(self.fusions[node])
report: List[str] = []
report: list[str] = []
self.reports.append(report)
self.iteration += 1
report.append(f" Nodes block {self.iteration}.")
@ -797,7 +797,7 @@ class _MinimizerBase:
self.print_report(report)
return set()

def _skip_traverse(self, all_nodes: NodeList, skip_nodes: List) -> NodeSet:
def _skip_traverse(self, all_nodes: NodeList, skip_nodes: list) -> NodeSet:
"""
Skip certain nodes in graph based on settings
"""
@ -874,7 +874,7 @@ class _MinimizerBase:
) as e:
print(e)

def print_report(self, report: List[str]):
def print_report(self, report: list[str]):
for i in range(len(report)):
if i > 0:
print(" . " + report[i])
@ -889,7 +889,7 @@ class _MinimizerBase:
self,
start: Optional[str] = None,
end: Optional[str] = None,
skip_nodes: Optional[List] = None,
skip_nodes: Optional[list] = None,
find_last_node: Optional[bool] = None,
) -> NodeSet:
"""

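Editor's note: the `mid = (start_idx + end_idx) // 2` step touched above is the range-halving half of the minimizer's culprit search. A toy sketch of that idea over a plain list, with a made-up `has_culprit` predicate standing in for the compare/run machinery (assumes exactly one culprit):

# Toy sketch of the halving search: shrink [lo, hi) while the failing
# behavior still reproduces, returning the single remaining culprit.
def minimize(nodes: list[int], has_culprit) -> list[int]:
    lo, hi = 0, len(nodes)
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if has_culprit(nodes[lo:mid]):   # culprit is in the first half
            hi = mid
        else:                            # otherwise it must be in the rest
            lo = mid
    return nodes[lo:hi]

# here the "culprit" is simply the value 7
assert minimize([1, 7, 3, 4], lambda xs: 7 in xs) == [7]
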
@ -24,9 +24,9 @@ TargetTypeName = str

# Arguments' dtypes for a given node, see `OperatorSupport`
SupportedArgumentDTypes = t.Optional[
t.Tuple[
tuple[
t.Sequence[t.Sequence[torch.dtype]],
t.Dict[str, t.Sequence[torch.dtype]],
dict[str, t.Sequence[torch.dtype]],
]
]

@ -204,7 +204,7 @@ class OpSupports:
return create_op_support(_decline_if_input_dtype)

@classmethod
def decline_if_node_in_names(cls, disallow_set: t.Set[str]) -> OperatorSupportBase:
def decline_if_node_in_names(cls, disallow_set: set[str]) -> OperatorSupportBase:
"""
If a node has a name that is in the disallow set, reported it as non-supported.
"""

@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Tuple, Type
from typing import Any, Callable

import torch
import torch.nn as nn
@ -23,7 +23,7 @@ def default_matching(name: str, target_version: int) -> str:
# This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering.
# The first integer in the tuple is the version number of the nn.Module class when we create the parameter list.
# If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module.
module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]] = {
module_fetch_book: dict[type, tuple[int, list[str], Callable[[str, int], str]]] = {
torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching),
torch.nn.modules.conv.Conv2d: (
1,
@ -55,11 +55,11 @@ module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]]


@compatibility(is_backward_compatible=False)
def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]:
def extract_attrs_for_lowering(mod: nn.Module) -> dict[str, Any]:
"""If `mod` is in `module_fetch_book`, fetch the mod's attributes that in the `module_fetch_book`
after checking module's version is compatible with the `module_fetch_book`.
"""
attrs_for_lowering: Dict[str, Any] = {}
attrs_for_lowering: dict[str, Any] = {}
attrs_for_lowering["name"] = torch.typename(mod)

if type(mod) in module_fetch_book:

@ -2,7 +2,7 @@
import logging
from functools import wraps
from inspect import unwrap
from typing import Callable, List, Optional
from typing import Callable, Optional


logger = logging.getLogger(__name__)
@ -121,7 +121,7 @@ def loop_pass(
# Implemented as 'depends on' operators. A constraint is satisfied iff a list
# has a valid partial ordering according to this comparison operator.
def _validate_pass_schedule_constraint(
constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
constraint: Callable[[Callable, Callable], bool], passes: list[Callable]
):
for i, a in enumerate(passes):
for j, b in enumerate(passes[i + 1 :]):
@ -191,8 +191,8 @@ class PassManager:
`this_before_that_pass_constraint` for example.
"""

passes: List[Callable]
passes: list[Callable]
constraints: List[Callable]
constraints: list[Callable]
_validated: bool = False

def __init__(
@ -217,7 +217,7 @@ class PassManager:
self.constraints.append(constraint)
self._validated = False

def remove_pass(self, _passes: List[str]):
def remove_pass(self, _passes: list[str]):
if _passes is None:
return
passes_left = [ps for ps in self.passes if ps.__name__ not in _passes]

@ -3,7 +3,6 @@ import _operator
import itertools
from collections import defaultdict
from enum import Enum
from typing import Dict, Set

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@ -199,7 +198,7 @@ _VIEW_INVERSE_MAP = {
# This function, given a set of set of (aliased) tensor nodes,
# Returns any nodes in the graph that *use* any of the aliases, that occur *after* op_index
# in the node ordering.
def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
def _get_all_later_node_usages(tensor_aliases: set[Node], op_index: int):
def _add_if_tensor(x, set_):
if isinstance(x, FakeTensor):
set_.add(StorageWeakRef(x._typed_storage()))
@ -233,8 +232,8 @@ def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
# (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata
# as "alias"
def _get_view_inverse_node_usages(
later_node_usages: Set[Node], self_aliases: Set[Node]
later_node_usages: set[Node], self_aliases: set[Node]
) -> Set[Node]:
) -> set[Node]:
def matching_view_metadata(a, b):
return (
a.size() == b.size()
@ -515,7 +514,7 @@ def reinplace(gm, *sample_args):
}

# We also need to know for a given node, what are all of its aliasing nodes.
storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set)
storage_to_nodes: dict[StorageWeakRef, set[Node]] = defaultdict(set)
for n in gm.graph.nodes:
if "fake_result" in n.meta:
# Tree-mapping because some ops can return lists of tensors.

@ -3,7 +3,7 @@ import functools
import logging
import operator
import sys
from typing import Any, Dict, Optional, Set, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING


# Import sympy and ShapeEnv during TYPE_CHECKING since importing sympy is slow
@ -123,7 +123,7 @@ def insert_deferred_runtime_asserts(
)

# We are going to mutate the dict
expr_to_proxy: Dict[sympy.Expr, fx.Proxy] = {}
expr_to_proxy: dict[sympy.Expr, fx.Proxy] = {}
placeholders = set()
first_non_placeholder = None
for node in graph.nodes:
@ -163,7 +163,7 @@ def insert_deferred_runtime_asserts(
def _node_metadata_hook(
node: torch.fx.Node,
stack_trace: Optional[str] = None,
nn_module_stack: Optional[Dict[str, Any]] = None,
nn_module_stack: Optional[dict[str, Any]] = None,
) -> None:
fake_args = pytree.tree_map(
lambda arg: (
@ -189,8 +189,8 @@ def insert_deferred_runtime_asserts(
node.meta["nn_module_stack"] = nn_module_stack

# Track asserts/checks we've added
added_asserts: Set[sympy.Expr] = set()
added_asserts: set[sympy.Expr] = set()
constrained_unbacked_symbols: Set[sympy.Symbol] = set()
constrained_unbacked_symbols: set[sympy.Symbol] = set()

Analysis = PythonReferenceAnalysis if export else OptimizedPythonReferenceAnalysis

@ -1,7 +1,7 @@
# mypy: ignore-errors

import traceback
from typing import Any, Dict, NamedTuple, Optional, Tuple
from typing import Any, NamedTuple, Optional

import torch
import torch.fx
@ -24,12 +24,12 @@ class TensorMetadata(NamedTuple):
shape: torch.Size
dtype: torch.dtype
requires_grad: bool
stride: Tuple[int, ...]
stride: tuple[int, ...]
memory_format: Optional[torch.memory_format]

# Quantization metadata
is_quantized: bool
qparams: Dict[str, Any]
qparams: dict[str, Any]


def _extract_tensor_metadata(
@ -57,7 +57,7 @@ def _extract_tensor_metadata(
break

is_quantized = result.is_quantized
qparams: Dict[str, Any] = {}
qparams: dict[str, Any] = {}
if is_quantized:
qscheme = result.qscheme()
qparams["qscheme"] = qscheme

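Editor's note: the `TensorMetadata` fields above can all be read straight off a tensor. A small sketch that gathers the same information by hand (not the `_extract_tensor_metadata` implementation):

# Gather the per-tensor facts that TensorMetadata records.
import torch

t = torch.randn(2, 3)
meta = {
    "shape": t.shape,
    "dtype": t.dtype,
    "requires_grad": t.requires_grad,
    "stride": t.stride(),
    "is_quantized": t.is_quantized,
    "qparams": {},  # filled from t.qscheme() etc. only for quantized tensors
}
assert meta["stride"] == (3, 1)
assert meta["dtype"] is torch.float32
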
@ -2,7 +2,7 @@
import inspect
import logging
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Set
from typing import Any, Callable, Optional

import torch
from torch.fx._compatibility import compatibility
@ -20,14 +20,14 @@ class Partition:
def __init__(self, name: str):
self.name: str = name
self.submod_name = f"submod_{name}"
self.node_names: List[str] = []
self.node_names: list[str] = []
self.inputs: Dict[str, None] = {}
self.inputs: dict[str, None] = {}
self.outputs: Dict[str, None] = {}
self.outputs: dict[str, None] = {}
self.dependencies: Dict[str, None] = {}
self.dependencies: dict[str, None] = {}
self.dependents: Dict[str, None] = {}
self.dependents: dict[str, None] = {}
self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
self.environment: Dict[Node, Node] = {}
self.environment: dict[Node, Node] = {}
self.targets: Dict[str, Any] = {}
self.targets: dict[str, Any] = {}

def __repr__(self) -> str:
return (
@ -55,7 +55,7 @@ def split_module(
m: GraphModule,
root_m: torch.nn.Module,
split_callback: Callable[[Node], int],
qualname_map: Optional[Dict[str, str]] = None,
qualname_map: Optional[dict[str, str]] = None,
keep_original_order: Optional[bool] = False,
keep_original_node_name: Optional[bool] = False,
):
@ -161,8 +161,8 @@ def split_module(

def construct_graph(
node: Node,
base_mod_env: Dict[str, Node],
base_mod_env: dict[str, Node],
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule],
base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule],
):
if node.op == "placeholder":
default_value = (
@ -195,9 +195,9 @@ def split_module(

import sympy

partitions: Dict[str, Partition] = {}
partitions: dict[str, Partition] = {}
orig_nodes: Dict[str, Node] = {}
orig_nodes: dict[str, Node] = {}
symbol_to_node: Dict[sympy.Symbol, Node] = {}
symbol_to_node: dict[sympy.Symbol, Node] = {}

def record_cross_partition_use(def_node: Node, use_node: Optional[Node]):
from torch.fx.experimental.symbolic_shapes import free_symbols
@ -273,7 +273,7 @@ def split_module(
# ------------------------
# 1. first region: we do nothing
# 2. subsequent regions: we insert the set_grad at the beginning
grad_regions: OrderedDict[Node, Set[int]] = OrderedDict()
grad_regions: OrderedDict[Node, set[int]] = OrderedDict()

# For autocast regions:
# ------------------------
@ -282,8 +282,8 @@ def split_module(
# _enter at the beginning and _exit at the end
# 3. last region: we will only insert _enter at the beginning
# We will do so in the order in which the autocasts were instantiated.
autocast_regions: OrderedDict[Node, Set[int]] = OrderedDict()
autocast_regions: OrderedDict[Node, set[int]] = OrderedDict()
autocast_exits: Dict[Node, Optional[Node]] = {}
autocast_exits: dict[Node, Optional[Node]] = {}

active_grad = None
active_autocasts = set()
@ -379,13 +379,13 @@ def split_module(

original_partition_order = list(partitions.keys())
# find partitions with no dependencies
root_partitions: List[str] = []
root_partitions: list[str] = []
for partition_name, partition in partitions.items():
if not len(partition.dependencies):
root_partitions.append(partition_name)

# check partitions for circular dependencies and create topological partition ordering
sorted_partitions: List[str] = []
sorted_partitions: list[str] = []
while root_partitions:
root_partition = root_partitions.pop()
sorted_partitions.append(root_partition)
@ -418,7 +418,7 @@ def split_module(
# add placeholders to partition inputs
for partition_name in sorted_partitions:
partition = partitions[partition_name]
new_inputs: Dict[str, None] = {}
new_inputs: dict[str, None] = {}
for inp in partition.inputs:
orig_node = orig_nodes[inp]
# We don't pass in get_attr nodes as inputs to the partition, but
@ -507,11 +507,11 @@ def split_module(
)  # is it really a good idea to copy this?

# original module environment dict mapping node names to nodes
orig_mod_env: Dict[str, Node] = {}
orig_mod_env: dict[str, Node] = {}
# Set up values to construct base module
base_mod_env: Dict[str, Node] = {}
base_mod_env: dict[str, Node] = {}
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule] = {}
if not keep_original_order:
for node in m.graph.nodes:
base_mod_env, base_mod_attrs = construct_graph(
@ -559,7 +559,7 @@ def split_module(

if keep_original_order:
# first get the attr nodes required by this partition
orig_mod_attr_nodes: List[Node] = [
orig_mod_attr_nodes: list[Node] = [
orig_mod_env[key]
for key in partition.inputs
if key not in original_order

@@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import copy
from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Tuple, Type, Union
+from typing import Optional, Union

import torch.fx
from torch.fx._compatibility import compatibility
@@ -45,28 +45,28 @@ class Component:
name: str

# Stores the placeholder nodes in `graph`.
-input_placeholders: List = field(default_factory=list)
+input_placeholders: list = field(default_factory=list)

# Store the nodes in original graph that are placeholder in `graph`.
-orig_inputs: List = field(default_factory=list)
+orig_inputs: list = field(default_factory=list)

# Store the nodes in original graph that are outputs in `graph`.
-orig_outputs: List = field(default_factory=list)
+orig_outputs: list = field(default_factory=list)

# Mapping from get_attr node in original graph to get_attr node in `graph`.
-getattr_maps: Dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict)
+getattr_maps: dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict)
-constructor_args: List[str] = field(default_factory=list)
+constructor_args: list[str] = field(default_factory=list)
gm: Optional[torch.fx.GraphModule] = None


@compatibility(is_backward_compatible=False)
def split_by_tags(
gm: torch.fx.GraphModule,
-tags: List[str],
+tags: list[str],
return_fqn_mapping: bool = False,
return_tuple: bool = False,
-GraphModuleCls: Type[torch.fx.GraphModule] = torch.fx.GraphModule,
+GraphModuleCls: type[torch.fx.GraphModule] = torch.fx.GraphModule,
-) -> Union[torch.fx.GraphModule, Tuple[torch.fx.GraphModule, Dict[str, str]]]:
+) -> Union[torch.fx.GraphModule, tuple[torch.fx.GraphModule, dict[str, str]]]:
"""
Splits a GraphModule using tags on its graph nodes. We honor the order of
tags. For example, we have tags = ["a", "b", "c"], the function will create
@@ -133,26 +133,26 @@ def split_by_tags(
return r

# Mapping from node in original module to node in created submodule.
-node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
+node_remapping: dict[torch.fx.Node, torch.fx.Node] = {}

# Mapping from node in original module or created submodules to
# corresponding component.
-node_to_component: Dict[torch.fx.Node, Component] = {}
+node_to_component: dict[torch.fx.Node, Component] = {}

# Mapping from tag to the corresponding component.
-tag_to_component: Dict[str, Component] = {}
+tag_to_component: dict[str, Component] = {}

# Stores all components.
-all_components: List[Component] = []
+all_components: list[Component] = []

# Stores nodes that will be used in main graph.
-used_in_main: Dict[torch.fx.Node, None] = {}
+used_in_main: dict[torch.fx.Node, None] = {}

# Main graph after split.
main_g = torch.fx.Graph()

# Mapping from node in original module to node in main graph after split.
-main_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
+main_remapping: dict[torch.fx.Node, torch.fx.Node] = {}

# Output node of original module.
output_node: Optional[torch.fx.Node] = None
@@ -258,7 +258,7 @@ def split_by_tags(
node_to_component[n].orig_outputs.append(n)

# Now we create a graphmodule for each component.
-orig_to_split_fqn_mapping: Dict[str, str] = {}
+orig_to_split_fqn_mapping: dict[str, str] = {}
for comp in all_components:
outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))

@@ -3,8 +3,9 @@ import argparse
import copy
import logging
from collections import defaultdict
+from collections.abc import Iterable, Sequence
from dataclasses import dataclass
-from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Sequence, Tuple
+from typing import Any, NamedTuple, Optional

import torch
from torch.fx._compatibility import compatibility
@@ -225,7 +226,7 @@ class SplitResult(NamedTuple):
"""

split_module: torch.fx.GraphModule
-submodule_inputs: Dict[str, Any]
+submodule_inputs: dict[str, Any]
non_acc_submodule_prefix: str


@@ -235,7 +236,7 @@ def generate_inputs_for_submodules(
inputs: Sequence[Any],
target_submodules: Iterable[str],
deepcopy: bool = False,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
"""
Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this
function doesn't work.
@@ -365,16 +366,16 @@ class _SplitterBase:
self.update_deps_for_fusions()

self.non_acc_submodule_name = non_acc_submodule_name
-self._node_submodule_map: Dict[str, str] = {}
+self._node_submodule_map: dict[str, str] = {}
self._return_tuple = return_tuple

-self.tags: List[str] = []
+self.tags: list[str] = []

# ===============================================================
# Helpers for ctor and initial state
# ===============================================================

-def get_node_submodule_map(self) -> Dict[str, str]:
+def get_node_submodule_map(self) -> dict[str, str]:
"""Returns a map from node name to submodule name, e.g.
node: main_module_impl_impl_over_arch_unary_multiple_embedding
_pooling_embedding_pooling_sparse_entity_equivalence_key
@@ -383,7 +384,7 @@ class _SplitterBase:
"""
return self._node_submodule_map

-def find_deps(self) -> Dict[torch.fx.Node, NodeSet]:
+def find_deps(self) -> dict[torch.fx.Node, NodeSet]:
"""
Builds a graph of node dependencies. Leaf nodes don't have any
dependencies and the "output" node doesn't have nodes depending on it.
@@ -391,7 +392,7 @@ class _SplitterBase:
Resulting graph has only direct dependencies, i.e. there are no
transitive dependencies.
"""
-deps: Dict[torch.fx.Node, NodeSet] = defaultdict(set)
+deps: dict[torch.fx.Node, NodeSet] = defaultdict(set)
for node in self.module.graph.nodes:
if node.op not in CALLABLE_NODE_OPS:
continue
@@ -647,12 +648,12 @@ class _SplitterBase:

def find_reverse_deps(
self, tag_id: Optional[int] = None
-) -> Dict[torch.fx.Node, NodeSet]:
+) -> dict[torch.fx.Node, NodeSet]:
"""
Builds reversed topological node dependencies, if tag_id is specified,
we ignore nodes that are in later subgraph i.e. nodes have greater tag_id.
"""
-result: Dict[torch.fx.Node, NodeSet] = defaultdict(set)
+result: dict[torch.fx.Node, NodeSet] = defaultdict(set)

for node in self.module.graph.nodes:
if node.op not in CALLABLE_NODE_OPS:
@@ -667,7 +668,7 @@ class _SplitterBase:

return result

-def update_reverse_deps_for_fusions(self, deps: Dict[torch.fx.Node, NodeSet]):
+def update_reverse_deps_for_fusions(self, deps: dict[torch.fx.Node, NodeSet]):
processed_node = set()

for node, fusion in self.fusions.items():
@@ -757,7 +758,7 @@ class _SplitterBase:
# Helpers for split() method
# ===============================================================

-def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
+def starter_nodes(self) -> tuple[NodeSet, NodeSet]:
"""
Finds nodes that consume module inputs or get_attr nodes.
"""
@@ -773,7 +774,7 @@ class _SplitterBase:
starter_cpu_nodes.add(user)
return starter_cpu_nodes, starter_acc_nodes

-def put_nodes_into_subgraphs(self) -> List[Subgraph]:
+def put_nodes_into_subgraphs(self) -> list[Subgraph]:
# We start graph traversal from leaf nodes
current_cpu_nodes, current_acc_nodes = self.starter_nodes()
visited_nodes: NodeSet = set()
@@ -785,7 +786,7 @@ class _SplitterBase:
current_subgraph_nodes: NodeList = []

# Result accumulator
-subgraphs: List[Subgraph] = []
+subgraphs: list[Subgraph] = []
while current_cpu_nodes or current_acc_nodes:
# Find the first node that should belong to the current subgraph and has all dependencies resolved
current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes
@@ -839,12 +840,12 @@ class _SplitterBase:

return subgraphs

-def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
+def remove_small_acc_subgraphs(self, subgraphs: list[Subgraph]) -> list[Subgraph]:
"""
This pass finds ACC submodules with less than specified size and merges
them with adjacent CPU submodules.
"""
-result: List[Subgraph] = []
+result: list[Subgraph] = []
for subgraph in subgraphs:
if subgraph.is_acc:
if len(subgraph.nodes) >= self.settings.min_acc_module_size:
@@ -866,7 +867,7 @@ class _SplitterBase:
result.append(subgraph)
return result

-def tag(self, subgraphs: List[Subgraph]):
+def tag(self, subgraphs: list[Subgraph]):
self.tags = []
for subgraph in subgraphs:
tag = (
@@ -1,8 +1,9 @@
# mypy: allow-untyped-defs
import collections
import operator
+from collections.abc import Mapping
from dataclasses import dataclass
-from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union
+from typing import Any, Optional, Union

import torch
import torch.fx
@@ -18,11 +19,11 @@ __all__ = [
"legalize_graph",
]

-Tensors = Union[Tuple[torch.Tensor], List[torch.Tensor]]
+Tensors = Union[tuple[torch.Tensor], list[torch.Tensor]]
TensorOrTensors = Union[torch.Tensor, Tensors]
-NodeList = List[torch.fx.Node]
+NodeList = list[torch.fx.Node]
-NodeSet = Set[torch.fx.Node]
+NodeSet = set[torch.fx.Node]
-Names = List[str]
+Names = list[str]
CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"}


@@ -172,8 +173,8 @@ class FxNetAccFusionsFinder:

return False

-def __call__(self) -> Dict[torch.fx.Node, NodeSet]:
+def __call__(self) -> dict[torch.fx.Node, NodeSet]:
-result: Dict[torch.fx.Node, NodeSet] = {}
+result: dict[torch.fx.Node, NodeSet] = {}
acc_nodes = list(self.acc_nodes)

for node in acc_nodes:
@@ -294,7 +295,7 @@ def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
for node in gm.graph.nodes:
if indeg[node] == 0:
queue.append(node)
-env: Dict[torch.fx.Node, torch.fx.Node] = {}
+env: dict[torch.fx.Node, torch.fx.Node] = {}
# Pop nodes from the queue, and add nodes that have had all their
# dependencies fulfilled
while len(queue) > 0:
@@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
-from typing import Dict, Tuple

from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph
@@ -30,7 +29,7 @@ def lift_subgraph_as_module(
subgraph: Graph,
comp_name: str = "",
class_name: str = "GraphModule",
-) -> Tuple[GraphModule, Dict[str, str]]:
+) -> tuple[GraphModule, dict[str, str]]:
"""
Create a GraphModule for subgraph, which copies the necessary attributes from the original parent graph_module.

@@ -52,7 +51,7 @@ def lift_subgraph_as_module(
# make "weight" a attribute of "conv" HolderModule and point to conv.weight in
# the original module.
submodule = HolderModule({})
-orig_to_split_fqn_mapping: Dict[str, str] = {}
+orig_to_split_fqn_mapping: dict[str, str] = {}
for n in subgraph.nodes:
if n.op not in ("call_module", "get_attr"):
continue
@@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import copy
from queue import SimpleQueue
-from typing import Dict, List, Optional as _Optional, Tuple
+from typing import Optional as _Optional

import torch.fx
from torch.fx._compatibility import compatibility
@@ -97,10 +97,10 @@ def fuse_as_graphmodule(
gm: GraphModule,
nodes: NodeList,
module_name: str,
-partition_lookup_table: _Optional[Dict[Node, None]] = None,
+partition_lookup_table: _Optional[dict[Node, None]] = None,
*,
always_return_tuple: bool = False,
-) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]:
+) -> tuple[GraphModule, tuple[Node, ...], tuple[Node, ...]]:
"""
Fuse nodes in graph_module into a GraphModule.

@@ -144,10 +144,10 @@ def fuse_as_graphmodule(

subgraph = Graph()

-node_to_placeholder: Dict[
+node_to_placeholder: dict[
Node, Node
] = {} # mapping of nodes from old graph to placeholder in new graph
-node_map: Dict[Node, Node] = {} # mapping of nodes from old graph to new graph
+node_map: dict[Node, Node] = {} # mapping of nodes from old graph to new graph

# handles inputs through graph.node_copy's arg_transform functions
def remap_inputs(x):
@@ -176,7 +176,7 @@ def fuse_as_graphmodule(
node_map[node] = new_node

# handles outputs
-output_mapping: Dict[Node, Node] = {} # mapping from old output to new outputs
+output_mapping: dict[Node, Node] = {} # mapping from old output to new outputs

for node in nodes:
for user_node in node.users:
@@ -202,10 +202,10 @@ def fuse_as_graphmodule(
)

# sub_gm's input nodes in the original module
-original_inputs: Tuple[Node, ...] = tuple(node_to_placeholder.keys())
+original_inputs: tuple[Node, ...] = tuple(node_to_placeholder.keys())

# sub_gm's outputs node in the original module
-original_outputs: Tuple[Node, ...] = tuple(output_mapping.keys())
+original_outputs: tuple[Node, ...] = tuple(output_mapping.keys())

return fused_gm, original_inputs, original_outputs

@@ -214,8 +214,8 @@ def fuse_as_graphmodule(
def insert_subgm(
gm: GraphModule,
sub_gm: GraphModule,
-orig_inputs: Tuple[Node, ...],
+orig_inputs: tuple[Node, ...],
-orig_outputs: Tuple[Node, ...],
+orig_outputs: tuple[Node, ...],
):
# add sub_gm into gm
submodule_name = sub_gm.__class__.__name__
@@ -250,7 +250,7 @@ def erase_nodes(gm: GraphModule, nodes: NodeList):
@compatibility(is_backward_compatible=False)
def fuse_by_partitions(
gm: GraphModule,
-partitions: List[Dict[Node, None]],
+partitions: list[dict[Node, None]],
prefix: str = "fused_",
always_return_tuple: bool = False,
) -> GraphModule:
@@ -4,7 +4,7 @@ import logging
import os
from collections import defaultdict
from dataclasses import dataclass, field
-from typing import Any, Dict, List, Set, Tuple, Union
+from typing import Any, Union

import torch
from torch.fx import Graph, Node
@@ -37,19 +37,19 @@ logger = _init_logger()
@dataclass
class InternalMatch:
# Nodes from which the match was found
-anchors: List[Node]
+anchors: list[Node]
# Maps nodes in the pattern subgraph to nodes in the larger graph
-nodes_map: Dict[Node, Node] = field(default_factory=dict)
+nodes_map: dict[Node, Node] = field(default_factory=dict)

# nodes in target graph that are matched placeholder in pattern
-placeholder_nodes: List[Node] = field(default_factory=list)
+placeholder_nodes: list[Node] = field(default_factory=list)

# nodes in matched subgraph returned by output
-returning_nodes: List[Node] = field(default_factory=list)
+returning_nodes: list[Node] = field(default_factory=list)

# map from a string name to a node in the target graph
# only available if the matcher is `SubgraphMatcherWithNameNodesMap`
-name_node_map: Dict[str, Node] = field(default_factory=dict)
+name_node_map: dict[str, Node] = field(default_factory=dict)

def __copy__(self):
return InternalMatch(
@@ -107,9 +107,9 @@ class SubgraphMatcher:
]
output_node = next(iter(reversed(pattern.nodes)))
# nodes returned by outputs
-self.pattern_returning_nodes: List[Node] = output_node.all_input_nodes
+self.pattern_returning_nodes: list[Node] = output_node.all_input_nodes

-self.pattern_anchors: List[Node] = []
+self.pattern_anchors: list[Node] = []
if match_output:
self.pattern_anchors = [output_node]
else:
@@ -150,12 +150,12 @@ class SubgraphMatcher:
return pn.target == gn.target
return False

-def _is_contained(self, nodes_map: Dict[Node, Node]) -> bool:
+def _is_contained(self, nodes_map: dict[Node, Node]) -> bool:
# `lookup` represents all the nodes in `original_graph`
# that are part of `pattern`

# Placeholders can be used by other nodes in the graphs
-lookup: Dict[Node, Node] = {
+lookup: dict[Node, Node] = {
gn: pn for pn, gn in nodes_map.items() if pn.op != "placeholder"
}

@@ -172,10 +172,10 @@ class SubgraphMatcher:
return True

def _remove_overlapping_matches(
-self, matches: List[InternalMatch]
+self, matches: list[InternalMatch]
-) -> List[InternalMatch]:
+) -> list[InternalMatch]:
-non_overlapping_matches: List[InternalMatch] = []
+non_overlapping_matches: list[InternalMatch] = []
-nodes_matched: Set[Node] = set()
+nodes_matched: set[Node] = set()

for match in matches:
found_overlap = False
@@ -244,7 +244,7 @@ class SubgraphMatcher:
# match for `gn`
match_found = True

-def _match_args(args1: Union[List, Tuple], args2: Union[List, Tuple]) -> bool:
+def _match_args(args1: Union[list, tuple], args2: Union[list, tuple]) -> bool:
if len(args1) != len(args2):
return False

@@ -313,7 +313,7 @@ class SubgraphMatcher:

return True

-def match(self, graph: Graph) -> List[InternalMatch]:
+def match(self, graph: Graph) -> list[InternalMatch]:
"""
Returns:
The matched subgraphs.
@@ -352,7 +352,7 @@ class SubgraphMatcher:
from torch.fx.passes.utils.fuser_utils import validate_partition

# find candidate nodes to match with pattern anchors
-match_candidates: Dict[Node, List[Node]] = defaultdict(list)
+match_candidates: dict[Node, list[Node]] = defaultdict(list)
for pattern_anchor in self.pattern_anchors:
for node in graph.nodes:
if self._nodes_are_equal(pattern_anchor, node):
@@ -361,7 +361,7 @@ class SubgraphMatcher:

logger.info("Initial match_candidates_list: %s\n", match_candidates_list)

-matches: List[InternalMatch] = []
+matches: list[InternalMatch] = []

def backtracking(anchor_index, match):
if anchor_index == len(match_candidates_list):
@@ -1,5 +1,3 @@
-from typing import Dict, List, Tuple
-
from torch.fx import Graph, GraphModule, Node
from torch.fx._compatibility import compatibility

@@ -11,7 +9,7 @@ __all__ = ["SubgraphMatcherWithNameNodeMap"]

def _split_to_graph_and_name_node_map(
gm: GraphModule,
-) -> Tuple[GraphModule, Dict[str, Node]]:
+) -> tuple[GraphModule, dict[str, Node]]:
from torch.fx.graph import _PyTreeInfo
from torch.utils._pytree import tree_flatten, tree_unflatten

@@ -29,7 +27,7 @@ def _split_to_graph_and_name_node_map(
*out, name_node_map = output
flattened, out_spec = tree_flatten(out)
assert isinstance(
-name_node_map, Dict
+name_node_map, dict
), "Expecting the input graph to have a dict output as the last element"
n.args = (flattened,)
orig_pytree_info = gm._graph._codegen.pytree_info # type: ignore[attr-defined]
@@ -88,7 +86,7 @@ class SubgraphMatcherWithNameNodeMap(SubgraphMatcher):
ignore_literals,
)

-def match(self, graph: Graph) -> List[InternalMatch]:
+def match(self, graph: Graph) -> list[InternalMatch]:
"""The returned InternalMatch will have name_node_map populated with a map
from node name (str) to the target node, e.g.
{"conv": target_conv_ndoe, "relu": target_relu_node}
@@ -1,7 +1,7 @@
import logging
import os
from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, List, Optional, Type
+from typing import Any, Callable, Optional

from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph
@@ -34,29 +34,29 @@ logger = _init_logger()
@dataclass
class SourcePartition:
# Nodes in a particular partition
-nodes: List[Node]
+nodes: list[Node]

# The source these nodes decomposed from
source: Any

# Nodes in the graph that are needed as inputs to the partition
# These do not include the params of the partition
-input_nodes: List[Node] = field(default_factory=list)
+input_nodes: list[Node] = field(default_factory=list)

# Nodes in the partition that are being used by nodes outside of the
# partition
-output_nodes: List[Node] = field(default_factory=list)
+output_nodes: list[Node] = field(default_factory=list)

# Parameters that are being used
-params: List[Node] = field(default_factory=list)
+params: list[Node] = field(default_factory=list)


@compatibility(is_backward_compatible=False) # type: ignore[misc]
def get_source_partitions(
graph: Graph,
-wanted_sources: List[Any],
+wanted_sources: list[Any],
filter_fn: Optional[Callable[[Node], bool]] = None,
-) -> Dict[Any, List[SourcePartition]]:
+) -> dict[Any, list[SourcePartition]]:
"""
Args:
graph: The graph we want to partition
@@ -69,7 +69,7 @@ def get_source_partitions(
that correspond to the list of nodes that were decomposed from the given
source.
"""
-modules: Dict[Type, Dict[str, List[Node]]] = {}
+modules: dict[type, dict[str, list[Node]]] = {}

for node in graph.nodes:
# The metadata source_fn should contain a tuple of a unique name for the
@@ -98,7 +98,7 @@ def get_source_partitions(
partition = diff_modules.setdefault(source_fn[0], [])
partition.append(node)

-def make_partition(nodes: List[Node], module_type: Type) -> SourcePartition:
+def make_partition(nodes: list[Node], module_type: type) -> SourcePartition:
input_nodes = set()
output_nodes = set()
params = set()
@@ -124,7 +124,7 @@ def get_source_partitions(
list(params), # type: ignore[arg-type]
)

-ret: Dict[Type[Any], List[SourcePartition]] = {}
+ret: dict[type[Any], list[SourcePartition]] = {}

if filter_fn:
# for each partition, we apply filter_fn to filter out all partitions that doesn't satisfy the
@@ -8,8 +8,10 @@ import inspect
import logging
import operator
import sys
+from collections import OrderedDict
+from collections.abc import Iterator
from dataclasses import fields, is_dataclass
-from typing import Any, Callable, Dict, Iterator, Optional, OrderedDict, Tuple
+from typing import Any, Callable, Optional

import torch
import torch.fx.traceback as fx_traceback
@@ -135,18 +137,18 @@ class TracerBase:
scope: Scope

# Records the module call stack
-module_stack: OrderedDict[str, Tuple[str, Any]]
+module_stack: OrderedDict[str, tuple[str, Any]]

# Mapping of node name to module scope
-node_name_to_scope: Dict[str, Tuple[str, type]]
+node_name_to_scope: dict[str, tuple[str, type]]

@compatibility(is_backward_compatible=True)
def create_node(
self,
kind: str,
target: Target,
-args: Tuple[Argument, ...],
+args: tuple[Argument, ...],
-kwargs: Dict[str, Argument],
+kwargs: dict[str, Argument],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
) -> Node:
@@ -171,7 +173,7 @@ class TracerBase:

# Optionally set stack trace on the created Node for debugging purposes
if fx_traceback.has_preserved_node_meta():
-current_meta: Dict[str, Any] = fx_traceback.get_current_meta()
+current_meta: dict[str, Any] = fx_traceback.get_current_meta()

stack_trace = current_meta.get("stack_trace")
if stack_trace:
@@ -211,8 +213,8 @@ class TracerBase:
self,
kind: str,
target: Target,
-args: Tuple[Any, ...],
+args: tuple[Any, ...],
-kwargs: Dict[str, Any],
+kwargs: dict[str, Any],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
# fix noqa when updating bc tests
@@ -455,10 +457,10 @@ class Proxy:
# we peephole optimize to the method invocation
return Attribute(self, k)

-def __getstate__(self) -> Dict:
+def __getstate__(self) -> dict:
return self.__dict__

-def __deepcopy__(self, memo) -> Dict:
+def __deepcopy__(self, memo) -> dict:
# We have to explicitly override this method, because otherwise deepcopy
# will go to __getattr__(self, "__deepcopy__") and return a
# Attribute(__deepcopy__), and may go into an infinite loop in some cases.
@@ -564,7 +566,7 @@ class Proxy:
args = args if args else ()
kwargs = kwargs if kwargs else {}

-tracers: Dict[Any, None] = {}
+tracers: dict[Any, None] = {}

def find_tracer(a):
if isinstance(a, cls):
@@ -1,16 +1,6 @@
import copy
from dataclasses import dataclass
-from typing import (
-Any,
-Callable,
-Dict,
-List,
-NamedTuple,
-Optional,
-Set,
-TYPE_CHECKING,
-Union,
-)
+from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union

import torch

@@ -37,7 +27,7 @@ class Match(NamedTuple):
# Node from which the match was found
anchor: Node
# Maps nodes in the pattern subgraph to nodes in the larger graph
-nodes_map: Dict[Node, Node]
+nodes_map: dict[Node, Node]


@compatibility(is_backward_compatible=False)
@@ -46,9 +36,9 @@ class ReplacedPatterns:
# Node from which the match was found
anchor: Node
# Maps nodes in the pattern subgraph to nodes in the larger graph
-nodes_map: Dict[Node, Node]
+nodes_map: dict[Node, Node]
# List of nodes that were added into the graph
-replacements: List[Node]
+replacements: list[Node]


def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
@@ -106,7 +96,7 @@ def replace_pattern(
gm: GraphModule,
pattern: Union[Callable, GraphModule],
replacement: Union[Callable, GraphModule],
-) -> List[Match]:
+) -> list[Match]:
"""
Matches all possible non-overlapping sets of operators and their
data dependencies (``pattern``) in the Graph of a GraphModule
@@ -237,14 +227,14 @@ def replace_pattern_with_filters(
pattern: Union[Callable, Graph, GraphModule],
replacement: Union[Callable, Graph, GraphModule, None] = None,
match_filters: Optional[
-List[Callable[["InternalMatch", Graph, Graph], bool]]
+list[Callable[["InternalMatch", Graph, Graph], bool]]
] = None,
ignore_literals: bool = False,
# Placed at the end to avoid breaking backward compatibility
replacement_callback: Optional[
Callable[["InternalMatch", Graph, Graph], Graph]
] = None,
-) -> List[ReplacedPatterns]:
+) -> list[ReplacedPatterns]:
"""
See replace_pattern for documentation. This function is an overload with an additional match_filter argument.

@@ -268,14 +258,14 @@ def _replace_pattern(
pattern: Union[Callable, Graph, GraphModule],
replacement: Union[Callable, Graph, GraphModule, None] = None,
match_filters: Optional[
-List[Callable[["InternalMatch", Graph, Graph], bool]]
+list[Callable[["InternalMatch", Graph, Graph], bool]]
] = None,
ignore_literals: bool = False,
# Placed at the end to avoid breaking backward compatibility
replacement_callback: Optional[
Callable[["InternalMatch", Graph, Graph], Graph]
] = None,
-) -> List[ReplacedPatterns]:
+) -> list[ReplacedPatterns]:
from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher

if match_filters is None:
@@ -298,7 +288,7 @@ def _replace_pattern(
remove_overlapping_matches=True,
ignore_literals=ignore_literals,
)
-_matches: List[InternalMatch] = matcher.match(original_graph)
+_matches: list[InternalMatch] = matcher.match(original_graph)

# Filter out matches that don't match the filter
_matches = [
@@ -323,7 +313,7 @@ def _replace_pattern(
common_replacement_graph = None

# As we progressively replace nodes, we'll need to keep track of how the match results should change
-match_changed_node: Dict[Node, Node] = {}
+match_changed_node: dict[Node, Node] = {}

match_and_replacements = []
for match in _matches:
@@ -345,7 +335,7 @@ def _replace_pattern(
# Initialize `val_map` with mappings from placeholder nodes in
# `replacement` to their corresponding node in `original_graph`
assert len(match.placeholder_nodes) == len(replacement_placeholders)
-val_map: Dict[Node, Node] = {}
+val_map: dict[Node, Node] = {}
for rn, gn in zip(replacement_placeholders, match.placeholder_nodes):
if isinstance(gn, Node):
val_map[rn] = match_changed_node.get(gn, gn)
@@ -361,7 +351,7 @@ def _replace_pattern(
val_map[rn] = gn

# Copy the replacement graph over
-user_nodes: Set[Node] = set()
+user_nodes: set[Node] = set()
for n in match.returning_nodes:
user_nodes.update(n.users)

@@ -402,7 +392,7 @@ def _replace_pattern(
copied_returning_nodes = (copied_returning_nodes,)

# Get a list of nodes that have been replaced into the graph
-replacement_nodes: List[Node] = [
+replacement_nodes: list[Node] = [
v for v in val_map.values() if v not in match.placeholder_nodes
]

@@ -4,7 +4,7 @@ import json
import traceback
from contextlib import contextmanager
from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union

from ._compatibility import compatibility
from .graph import Graph
@@ -25,7 +25,7 @@ __all__ = [
"get_graph_provenance_json",
]

-current_meta: Dict[str, Any] = {}
+current_meta: dict[str, Any] = {}
should_preserve_node_meta = False


@@ -49,15 +49,15 @@ class NodeSource:
self.graph_id = graph_id

pass_name: str
-action: List["NodeSourceAction"]
+action: list["NodeSourceAction"]
-from_node: List["NodeSource"]
+from_node: list["NodeSource"]
node_info: Optional["NodeInfo"]

def __init__(
self,
node: Optional[Node],
pass_name: str = "",
-action: Optional[Union["NodeSourceAction", List["NodeSourceAction"]]] = None,
+action: Optional[Union["NodeSourceAction", list["NodeSourceAction"]]] = None,
):
self.pass_name = pass_name

@@ -146,7 +146,7 @@ def preserve_node_meta(enable=True):


@compatibility(is_backward_compatible=False)
-def set_stack_trace(stack: List[str]):
+def set_stack_trace(stack: list[str]):
global current_meta

if should_preserve_node_meta and stack:
@@ -182,7 +182,7 @@ def reset_grad_fn_seq_nr():


@compatibility(is_backward_compatible=False)
-def format_stack() -> List[str]:
+def format_stack() -> list[str]:
if should_preserve_node_meta:
return [current_meta.get("stack_trace", "")]
else:
@@ -219,7 +219,7 @@ def set_current_meta(node, pass_name=""):


@compatibility(is_backward_compatible=False)
-def get_current_meta() -> Dict[str, Any]:
+def get_current_meta() -> dict[str, Any]:
return current_meta

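For orientation, below is a minimal sketch, not part of the commit above, of the annotation style these hunks converge on: builtin generics such as `dict[str, Node]` and `list[Node]` replace the `typing.Dict`/`typing.List`/`typing.Tuple` aliases, and abstract container types come from `collections.abc`. The `nodes_by_tag` helper is hypothetical and assumes Python 3.9 or newer (PEP 585 builtin generics).

```python
# Hypothetical helper, not part of the diff: illustrates builtin generics
# (dict[...], list[...]) and collections.abc imports on Python 3.9+.
from collections.abc import Iterable

import torch.fx
from torch.fx import Node


def nodes_by_tag(gm: torch.fx.GraphModule, tags: list[str]) -> dict[str, list[Node]]:
    """Group the graph's nodes by the first tag that appears in their name."""
    result: dict[str, list[Node]] = {tag: [] for tag in tags}
    nodes: Iterable[Node] = gm.graph.nodes
    for node in nodes:
        for tag in tags:
            if tag in node.name:
                result[tag].append(node)
                break
    return result
```

Note that annotation-only spellings like `dict[str, Node]` require Python 3.9+ (or `from __future__ import annotations`), while runtime uses such as the `isinstance(name_node_map, dict)` check in the diff rely on the builtin type itself.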