[BE][Easy] enable PYFMT for torch.fx (#138443)

Reproduce command:

```bash
ghstack checkout https://github.com/pytorch/pytorch/pull/138443
git checkout HEAD~1 torch/
lintrunner -a --take "PYFMT" --all-files
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138443
Approved by: https://github.com/ezyang
Author: Xuehai Pan
Date: 2024-10-22 00:32:23 +08:00
Committed by: PyTorch MergeBot
Parent: 8231180147
Commit: abbd71d29d
78 changed files with 4403 additions and 2361 deletions
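The reproduce command checks out the PR, resets `torch/` to the parent commit, and re-runs the `PYFMT` linter over all files, which regenerates this diff. The first hunk below removes the `torch/fx` entries from `PYFMT`'s `exclude_patterns` in the lintrunner config (`.lintrunner.toml`); everything after that is the mechanical reformatting the formatter then applies. Most of the changes follow the pattern in this condensed sketch (illustrative only; the helper function is made up, not a line from the PR):

```python
# Condensed sketch of the kind of rewrite PYFMT applies throughout this diff:
# single quotes -> double quotes, backslash continuations -> parenthesized wrapping.
import inspect

HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS


def before(concrete_args, co_flags, total_args):
    if isinstance(concrete_args, tuple) and \
            len(concrete_args) > 0 and \
            (co_flags & HAS_VARSTUFF) and \
            total_args == 1:
        return {'case': 'flat varargs'}
    return {'case': 'normal'}


def after(concrete_args, co_flags, total_args):
    if (
        isinstance(concrete_args, tuple)
        and len(concrete_args) > 0
        and (co_flags & HAS_VARSTUFF)
        and total_args == 1
    ):
        return {"case": "flat varargs"}
    return {"case": "normal"}
```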

View File

@ -1232,87 +1232,6 @@ exclude_patterns = [
'torch/fft/__init__.py',
'torch/func/__init__.py',
'torch/futures/__init__.py',
'torch/fx/__init__.py',
'torch/fx/_compatibility.py',
'torch/fx/_symbolic_trace.py',
'torch/fx/annotate.py',
'torch/fx/config.py',
'torch/fx/experimental/__init__.py',
'torch/fx/experimental/accelerator_partitioner.py',
'torch/fx/experimental/const_fold.py',
'torch/fx/experimental/debug.py',
'torch/fx/experimental/graph_gradual_typechecker.py',
'torch/fx/experimental/merge_matmul.py',
'torch/fx/experimental/meta_tracer.py',
'torch/fx/experimental/migrate_gradual_types/__init__.py',
'torch/fx/experimental/migrate_gradual_types/constraint.py',
'torch/fx/experimental/migrate_gradual_types/constraint_generator.py',
'torch/fx/experimental/migrate_gradual_types/constraint_transformation.py',
'torch/fx/experimental/migrate_gradual_types/operation.py',
'torch/fx/experimental/migrate_gradual_types/transform_to_z3.py',
'torch/fx/experimental/migrate_gradual_types/util.py',
'torch/fx/experimental/migrate_gradual_types/z3_types.py',
'torch/fx/experimental/normalize.py',
'torch/fx/experimental/optimization.py',
'torch/fx/experimental/partitioner_utils.py',
'torch/fx/experimental/refinement_types.py',
'torch/fx/experimental/rewriter.py',
'torch/fx/experimental/schema_type_annotation.py',
'torch/fx/experimental/unification/__init__.py',
'torch/fx/experimental/unification/core.py',
'torch/fx/experimental/unification/dispatch.py',
'torch/fx/experimental/unification/match.py',
'torch/fx/experimental/unification/more.py',
'torch/fx/experimental/unification/multipledispatch/__init__.py',
'torch/fx/experimental/unification/multipledispatch/conflict.py',
'torch/fx/experimental/unification/multipledispatch/core.py',
'torch/fx/experimental/unification/multipledispatch/dispatcher.py',
'torch/fx/experimental/unification/multipledispatch/utils.py',
'torch/fx/experimental/unification/multipledispatch/variadic.py',
'torch/fx/experimental/unification/unification_tools.py',
'torch/fx/experimental/unification/utils.py',
'torch/fx/experimental/unification/variable.py',
'torch/fx/experimental/unify_refinements.py',
'torch/fx/graph.py',
'torch/fx/graph_module.py',
'torch/fx/interpreter.py',
'torch/fx/node.py',
'torch/fx/operator_schemas.py',
'torch/fx/passes/__init__.py',
'torch/fx/passes/annotate_getitem_nodes.py',
'torch/fx/passes/backends/__init__.py',
'torch/fx/passes/backends/cudagraphs.py',
'torch/fx/passes/dialect/__init__.py',
'torch/fx/passes/dialect/common/__init__.py',
'torch/fx/passes/dialect/common/cse_pass.py',
'torch/fx/passes/fake_tensor_prop.py',
'torch/fx/passes/graph_drawer.py',
'torch/fx/passes/graph_manipulation.py',
'torch/fx/passes/infra/__init__.py',
'torch/fx/passes/infra/partitioner.py',
'torch/fx/passes/infra/pass_base.py',
'torch/fx/passes/infra/pass_manager.py',
'torch/fx/passes/net_min_base.py',
'torch/fx/passes/operator_support.py',
'torch/fx/passes/param_fetch.py',
'torch/fx/passes/pass_manager.py',
'torch/fx/passes/reinplace.py',
'torch/fx/passes/shape_prop.py',
'torch/fx/passes/split_module.py',
'torch/fx/passes/split_utils.py',
'torch/fx/passes/splitter_base.py',
'torch/fx/passes/tests/__init__.py',
'torch/fx/passes/tests/test_pass_manager.py',
'torch/fx/passes/tools_common.py',
'torch/fx/passes/utils/__init__.py',
'torch/fx/passes/utils/common.py',
'torch/fx/passes/utils/fuser_utils.py',
'torch/fx/passes/utils/matcher_utils.py',
'torch/fx/passes/utils/source_matcher_utils.py',
'torch/fx/proxy.py',
'torch/fx/subgraph_rewriter.py',
'torch/fx/tensor_type.py',
'torch/fx/traceback.py',
'torch/linalg/__init__.py',
'torch/monitor/__init__.py',
'torch/nested/__init__.py',

View File

@ -262,7 +262,9 @@
"Future"
],
"torch.fx": [
"PH",
"ProxyableClassMeta",
"CodeGen",
"Tracer",
"symbolic_trace",
"wrap"

View File

@ -2514,6 +2514,7 @@ if "TORCH_CUDA_SANITIZER" in os.environ:
# Populate magic methods on SymInt and SymFloat
import torch.fx.experimental.sym_node
from torch import fx as fx
# Register MPS specific decomps

View File

@ -7,6 +7,8 @@ demonstration of these components in action:
::
import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
def __init__(self) -> None:
@ -17,11 +19,13 @@ demonstration of these components in action:
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
module = MyModule()
from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
@ -80,10 +84,32 @@ Several example transformations can be found at the
repository.
'''
from .graph_module import GraphModule
from ._symbolic_trace import symbolic_trace, Tracer, wrap, PH, ProxyableClassMeta
from .graph import Graph, CodeGen
from .node import Node, map_arg, has_side_effect
from .proxy import Proxy
from .interpreter import Interpreter as Interpreter, Transformer as Transformer
from .subgraph_rewriter import replace_pattern
from torch.fx._symbolic_trace import ( # noqa: F401
PH,
ProxyableClassMeta,
symbolic_trace,
Tracer,
wrap,
)
from torch.fx.graph import CodeGen, Graph # noqa: F401
from torch.fx.graph_module import GraphModule
from torch.fx.interpreter import Interpreter, Transformer
from torch.fx.node import has_side_effect, map_arg, Node
from torch.fx.proxy import Proxy
from torch.fx.subgraph_rewriter import replace_pattern
__all__ = [
"symbolic_trace",
"Tracer",
"wrap",
"Graph",
"GraphModule",
"Interpreter",
"Transformer",
"Node",
"Proxy",
"replace_pattern",
"has_side_effect",
"map_arg",
]
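The block above is the re-exported public surface of `torch.fx`. A short usage sketch of that API (the toy module is an illustration, not code from this PR):

```python
import torch
from torch.fx import Graph, GraphModule, Node, symbolic_trace


class ToyModule(torch.nn.Module):  # made-up example
    def forward(self, x):
        return x.relu() + 1


gm: GraphModule = symbolic_trace(ToyModule())
assert isinstance(gm.graph, Graph)
assert all(isinstance(n, Node) for n in gm.graph.nodes)
print(gm.code)  # generated Python source for the traced forward
```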

View File

@ -1,15 +0,0 @@
from torch.fx._symbolic_trace import (
symbolic_trace as symbolic_trace,
Tracer as Tracer,
wrap as wrap,
)
from torch.fx.graph import Graph as Graph
from torch.fx.graph_module import GraphModule as GraphModule
from torch.fx.interpreter import Interpreter as Interpreter, Transformer as Transformer
from torch.fx.node import (
has_side_effect as has_side_effect,
map_arg as map_arg,
Node as Node,
)
from torch.fx.proxy import Proxy as Proxy
from torch.fx.subgraph_rewriter import replace_pattern as replace_pattern

View File

@ -1,16 +1,19 @@
from typing import Any, Dict, Callable, TypeVar
import textwrap
from typing import Any, Callable, Dict, TypeVar
_BACK_COMPAT_OBJECTS: Dict[Any, None] = {}
_MARKED_WITH_COMPATIBILITY: Dict[Any, None] = {}
_BACK_COMPAT_OBJECTS : Dict[Any, None] = {}
_MARKED_WITH_COMPATIBILITY : Dict[Any, None] = {}
_T = TypeVar("_T")
def compatibility(is_backward_compatible: bool) -> Callable[[_T], _T]:
if is_backward_compatible:
def mark_back_compat(fn: _T) -> _T:
docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
docstring = textwrap.dedent(getattr(fn, "__doc__", None) or "")
docstring += """
.. note::
Backwards-compatibility for this API is guaranteed.
@ -24,7 +27,7 @@ def compatibility(is_backward_compatible: bool) -> Callable[[_T], _T]:
else:
def mark_not_back_compat(fn: _T) -> _T:
docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
docstring = textwrap.dedent(getattr(fn, "__doc__", None) or "")
docstring += """
.. warning::
This API is experimental and is *NOT* backward-compatible.
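For context, the decorator defined above is applied like this; a minimal sketch with a made-up function showing how the note/warning text ends up appended to the wrapped object's docstring:

```python
from torch.fx._compatibility import compatibility


@compatibility(is_backward_compatible=False)
def experimental_helper(x):  # hypothetical function, not part of this PR
    """Do something experimental."""
    return x


# The decorator dedents the original docstring, appends the ".. warning::" block
# shown above, and records the object in _MARKED_WITH_COMPATIBILITY.
print(experimental_helper.__doc__)
```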

View File

@ -1,9 +1,9 @@
# mypy: allow-untyped-defs
from contextlib import contextmanager
from torch.fx import GraphModule
from torch.fx.graph_module import (
_format_import_block,
GraphModule,
reduce_graph_module,
reduce_package_graph_module,
)

View File

@ -1,13 +1,13 @@
# mypy: allow-untyped-defs
import builtins
import copy
import collections
import contextlib
import copy
import functools
import inspect
import math
import os
import warnings
import collections
from itertools import chain
from types import CodeType, FunctionType, ModuleType
from typing import (
@ -29,11 +29,12 @@ from torch._C import ScriptObject # type: ignore[attr-defined]
from torch._library.fake_class_registry import FakeScriptObject
from ._compatibility import compatibility
from ._lazy_graph_module import _make_graph_module
from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph
from .graph_module import GraphModule
from ._lazy_graph_module import _make_graph_module
from .node import Argument, base_types, map_aggregate
from .proxy import ParameterProxy, Proxy, TracerBase, Scope, ScopeContextManager
from .proxy import ParameterProxy, Proxy, Scope, ScopeContextManager, TracerBase
HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS
@ -49,6 +50,7 @@ _is_fx_tracing_flag = False
def is_fx_tracing():
return _is_fx_tracing_flag
@compatibility(is_backward_compatible=True)
class ProxyableClassMeta(type):
"""
@ -58,6 +60,7 @@ class ProxyableClassMeta(type):
import torch
import torch.fx
class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
def __init__(self, left, right):
self.left, self.right = left, right
@ -72,10 +75,12 @@ class ProxyableClassMeta(type):
r = self.right * other.right
return TensorPair(l, r)
def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor):
def use_tensor_pair_ctor(x: TensorPair, y: torch.Tensor):
s = x.add(TensorPair(y, y))
return s.mul(x)
x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
y = torch.randn(5, 3)
ref_out = use_tensor_pair_ctor(x, y)
@ -214,6 +219,7 @@ class PHWithMeta(PHBase):
"""
Object representing an input placeholder to `concrete_args`
"""
def __init__(self, ph_key: Optional[str] = None):
super().__init__()
@ -404,7 +410,11 @@ class Tracer(TracerBase):
# Tensor was not found in the Module hierarchy, stow it away in a
# special attribute and set the qualname to refer to that
if not qualname:
base_name = "_tensor_constant" if isinstance(a, torch.Tensor) else "_torchbind_obj"
base_name = (
"_tensor_constant"
if isinstance(a, torch.Tensor)
else "_torchbind_obj"
)
qualname = self.get_fresh_qualname(base_name)
assert isinstance(qualname, str)
self.tensor_attrs[a] = qualname
@ -446,9 +456,9 @@ class Tracer(TracerBase):
appear with the qualified name ``foo.bar.baz`` here.
"""
return (
(m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn"))
and not isinstance(m, torch.nn.Sequential)
)
m.__module__.startswith("torch.nn")
or m.__module__.startswith("torch.ao.nn")
) and not isinstance(m, torch.nn.Sequential)
@compatibility(is_backward_compatible=True)
def path_of_module(self, mod: torch.nn.Module) -> str:
@ -512,17 +522,25 @@ class Tracer(TracerBase):
value was returned from the ``Module`` invocation.
"""
module_qualified_name = self.path_of_module(m)
with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope:
with ScopeContextManager(
self.scope, Scope(module_qualified_name, type(m))
) as _scope:
# module_stack is an ordered dict so writing then deleting the
# entry is equivalent to push/pop on a list
num_calls = self.num_calls.get(module_qualified_name, 0)
module_key = f"{_scope.module_path}@{num_calls}" if num_calls > 0 else _scope.module_path
module_key = (
f"{_scope.module_path}@{num_calls}"
if num_calls > 0
else _scope.module_path
)
self.module_stack[module_key] = (module_qualified_name, _scope.module_type)
self.num_calls[module_qualified_name] = num_calls + 1
if not self.is_leaf_module(m, module_qualified_name):
ret_val = forward(*args, **kwargs)
else:
ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs)
ret_val = self.create_proxy(
"call_module", module_qualified_name, args, kwargs
)
key, _ = self.module_stack.popitem(last=True)
assert key == module_key, f" Unexpected key {key}"
@ -551,6 +569,7 @@ class Tracer(TracerBase):
The return value from the getattr call.
"""
def maybe_get_proxy_for_attr(
attr_val, collection_to_search, parameter_proxy_cache
):
@ -620,15 +639,16 @@ class Tracer(TracerBase):
sig = inspect.signature(fn_for_analysis)
# This covers the very specific case where we are passing in flat
# concrete_args as a tuple, but our traced fn takes (*args, **kwargs).
# In this case, just take the concrete_args and pass them through.
name_idx = 0
if isinstance(concrete_args, tuple) and \
len(concrete_args) > 0 and \
(co.co_flags & HAS_VARSTUFF) and \
total_args == 1:
if (
isinstance(concrete_args, tuple)
and len(concrete_args) > 0
and (co.co_flags & HAS_VARSTUFF)
and total_args == 1
):
for concrete_arg in concrete_args:
out = self.create_proxy("placeholder", f"input_{name_idx}", (), {})
if isinstance(concrete_arg, PHBase):
@ -722,12 +742,12 @@ class Tracer(TracerBase):
_is_fx_tracing_flag = True
try:
if isinstance(root, torch.nn.Module):
# do real recompilation for _LazyGraphModule before retracing since the trace
# method can not trace the _lazy_forward method. Got error:
# https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259
# without this.
from torch.fx._lazy_graph_module import _LazyGraphModule
_LazyGraphModule.force_recompile(root)
self.root = root
@ -745,12 +765,12 @@ class Tracer(TracerBase):
tracer_cls: Optional[Type[Tracer]] = getattr(self, "__class__", None)
self.graph = Graph(tracer_cls=tracer_cls)
if hasattr(fn, '__code__'):
if hasattr(fn, "__code__"):
code = fn.__code__
self.graph._co_fields = {
'co_name': code.co_name,
'co_filename': code.co_filename,
'co_firstlineno': code.co_firstlineno,
"co_name": code.co_name,
"co_filename": code.co_filename,
"co_firstlineno": code.co_firstlineno,
}
# When we encounter a Tensor value that's not a parameter, we look if it
@ -758,11 +778,7 @@ class Tracer(TracerBase):
# values to the qualified name here for efficiency. This is used downstream
# in create_arg
self.tensor_attrs: Dict[
Union[
torch.Tensor,
ScriptObject,
FakeScriptObject
], str
Union[torch.Tensor, ScriptObject, FakeScriptObject], str
] = {}
def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]):
@ -839,7 +855,7 @@ class Tracer(TracerBase):
new_tracer = Tracer.__new__(Tracer)
for k, v in self.__dict__.items():
if k in {'_autowrap_search'}:
if k in {"_autowrap_search"}:
new_obj = copy.copy(v)
else:
new_obj = copy.deepcopy(v, memo)
@ -857,9 +873,7 @@ class Tracer(TracerBase):
cnt += 1
param = sig.parameters[name]
default = (
()
if param.default is inspect.Parameter.empty
else (param.default,)
() if param.default is inspect.Parameter.empty else (param.default,)
)
out = self.create_proxy(
"placeholder", f"{name}_{str(cnt)}", default, {}
@ -877,11 +891,7 @@ class Tracer(TracerBase):
return out
# Union[int, bool] == bool in Python <= 3.6
if (
type(x) == bool
or type(x) in base_types
and type(x) != torch.Tensor
):
if type(x) == bool or type(x) in base_types and type(x) != torch.Tensor:
torch._assert(
out == x,
f"{name} has been specialized to have value {x} but got another value",
@ -906,13 +916,15 @@ class Tracer(TracerBase):
default = ()
else:
param = sig.parameters[name]
default = () if param.default is inspect.Parameter.empty else (param.default,) # type: ignore[assignment]
default = ( # type: ignore[assignment]
() if param.default is inspect.Parameter.empty else (param.default,)
)
return self.create_proxy(
"placeholder",
name,
default,
{},
type_expr=fn_for_analysis.__annotations__.get(name, None)
type_expr=fn_for_analysis.__annotations__.get(name, None),
)
@ -1011,6 +1023,7 @@ class _PatchedFnSetItem(_PatchedFn):
def patch(self):
self.frame_dict[self.fn_name] = self.new_fn
class _PatchedFnDel(_PatchedFn):
def revert(self):
del self.frame_dict[self.fn_name]
@ -1026,6 +1039,7 @@ class _PatchedFnSetAttr(_PatchedFn):
def patch(self):
setattr(self.frame_dict, self.fn_name, self.new_fn)
class _Patcher:
def __init__(self) -> None:
super().__init__()
@ -1106,6 +1120,7 @@ class _Patcher:
CURRENT_PATCHER: Optional[_Patcher] = None
@contextlib.contextmanager
def _new_patcher():
global CURRENT_PATCHER
@ -1132,7 +1147,10 @@ def _maybe_revert_all_patches():
finally:
if current_patcher is not None:
patches_made = current_patcher.reapply_all_patches()
assert patches_made == patches_removed, "CURRENT_PATCHER was changed during a revert_all_patches"
assert (
patches_made == patches_removed
), "CURRENT_PATCHER was changed during a revert_all_patches"
def _patch_wrapped_functions(patcher: _Patcher):
"""
@ -1178,7 +1196,9 @@ def wrap(fn_or_name: Union[str, Callable]):
def my_custom_function(x, y):
return x * x + y * y
torch.fx.wrap('my_custom_function')
torch.fx.wrap("my_custom_function")
def fn_to_be_traced(x, y):
# When symbolic tracing, the below call to my_custom_function will be inserted into
@ -1248,14 +1268,14 @@ def symbolic_trace(
if b == True:
return a
else:
return a*2
return a * 2
FX can typically not trace through this due to the presence of control
flow. However, we can use `concrete_args` to specialize on the value of
`b` to trace through this::
f = fx.symbolic_trace(f, concrete_args={'b': False})
assert f(3, False) == 6
f = fx.symbolic_trace(f, concrete_args={"b": False})
assert f(3, False) == 6
Note that although you can still pass in different values of `b`, they will be ignored.
@ -1269,8 +1289,10 @@ def symbolic_trace(
for v in x.values():
out += v
return out
f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
assert f({'a': 1, 'b': 2, 'c': 4}) == 7
f = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}})
assert f({"a": 1, "b": 2, "c": 4}) == 7
Args:
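Assembled into a self-contained form, the dict/`fx.PH` example from the docstring above runs as follows:

```python
import torch.fx as fx


def f(x):
    out = 0
    for v in x.values():
        out += v
    return out


# Specialize on the dict structure of `x`; the PH placeholders keep the values symbolic.
traced = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}})
assert traced({"a": 1, "b": 2, "c": 4}) == 7
```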

View File

@ -1,7 +1,9 @@
# mypy: allow-untyped-defs
from torch.fx.proxy import Proxy
from ._compatibility import compatibility
@compatibility(is_backward_compatible=False)
def annotate(val, type):
"""
@ -18,13 +20,15 @@ def annotate(val, type):
"""
if isinstance(val, Proxy):
if val.node.type:
raise RuntimeError(f"Tried to annotate a value that already had a type on it!"
f" Existing type is {val.node.type} "
f"and new type is {type}. "
f"This could happen if you tried to annotate a function parameter "
f"value (in which case you should use the type slot "
f"on the function signature) or you called "
f"annotate on the same value twice")
raise RuntimeError(
f"Tried to annotate a value that already had a type on it!"
f" Existing type is {val.node.type} "
f"and new type is {type}. "
f"This could happen if you tried to annotate a function parameter "
f"value (in which case you should use the type slot "
f"on the function signature) or you called "
f"annotate on the same value twice"
)
else:
val.node.type = type
return val
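A hedged usage sketch of `annotate`: during symbolic tracing the argument is a `Proxy`, so the call stores the type on the underlying node (the traced function and the chosen `TensorType` are assumptions for illustration):

```python
from torch.fx import symbolic_trace
from torch.fx.annotate import annotate
from torch.fx.tensor_type import TensorType


def f(x):
    # x is a Proxy while tracing, so this sets x.node.type rather than touching a value.
    x = annotate(x, TensorType((2, 4)))
    return x.relu()


gm = symbolic_trace(f)
placeholder = next(iter(gm.graph.nodes))
print(placeholder.type)  # the annotated TensorType
```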

View File

@ -1,22 +1,22 @@
# mypy: allow-untyped-defs
import operator
from collections import deque
from typing import Dict, List, Set, NamedTuple, Tuple, Deque
from typing import Deque, Dict, List, NamedTuple, Set, Tuple
import torch
from torch.fx.passes.graph_manipulation import get_size_of_all_nodes
from torch.fx.experimental.partitioner_utils import (
Partition,
Device,
PartitionerConfig,
get_partition_to_latency_mapping,
get_latency_of_partitioned_graph,
NodeLatency,
get_extra_size_of,
get_latency_of_partitioned_graph,
get_partition_to_latency_mapping,
NodeLatency,
Partition,
PartitionerConfig,
PartitionMode,
)
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node, map_arg
from torch.fx.node import map_arg, Node
from torch.fx.passes.graph_manipulation import get_size_of_all_nodes
from torch.fx.passes.split_module import split_module
@ -260,7 +260,9 @@ def get_device_to_partitions_mapping(
# Find devices for all the partitions without a device
found_device = True
for partition in no_device_partitions:
device_to_left_mem_bytes = dict(sorted(device_to_left_mem_bytes.items(), key=operator.itemgetter(1)))
device_to_left_mem_bytes = dict(
sorted(device_to_left_mem_bytes.items(), key=operator.itemgetter(1))
)
found_device = find_device_for(partition)
if not found_device:
break

View File

@ -7,7 +7,12 @@ from torch.fx.node import map_arg
from torch.fx.passes.split_module import split_module
__all__ = ['FoldedGraphModule', 'get_unique_attr_name_in_module', 'split_const_subgraphs']
__all__ = [
"FoldedGraphModule",
"get_unique_attr_name_in_module",
"split_const_subgraphs",
]
class FoldedGraphModule(torch.fx.GraphModule):
"""

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import torch.fx as fx
def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
"""
Sets a breakpoint in `gm`'s generated python code. It drops into pdb when
@ -13,18 +14,14 @@ def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
Returns:
the `gm` with breakpoint inserted.
"""
def insert_pdb(body):
return ["import pdb; pdb.set_trace()\n", *body]
with gm.graph.on_generate_code(
make_transformer=lambda cur_transform: (
# new code transformer to register
lambda body: (
insert_pdb(
cur_transform(body) if cur_transform
else body
)
)
lambda body: (insert_pdb(cur_transform(body) if cur_transform else body))
)
):
gm.recompile()
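A usage sketch of `set_trace` (the traced module is a made-up example):

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.experimental.debug import set_trace

gm = symbolic_trace(torch.nn.Linear(3, 3))
gm = set_trace(gm)
print(gm.code)  # the generated forward() now starts with "import pdb; pdb.set_trace()"
# gm(torch.randn(2, 3))  # would drop into pdb at the top of the generated forward()
```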

View File

@ -1,20 +1,21 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from functools import reduce
import torch
import operator
from torch.fx.tensor_type import Dyn, is_consistent, TensorType, is_more_precise
from typing import Callable, Dict
from torch.fx.node import Target, Node
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.conv import Conv2d
from torch.fx.experimental.refinement_types import Equality
import itertools
from torch.fx.experimental.unification import Var # type: ignore[attr-defined]
import operator
from functools import reduce
from typing import Callable, Dict
import sympy
import torch
from torch.fx.experimental.refinement_types import Equality
from torch.fx.experimental.unification import Var # type: ignore[attr-defined]
from torch.fx.node import Node, Target
from torch.fx.tensor_type import Dyn, is_consistent, is_more_precise, TensorType
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.conv import Conv2d
_INFERENCE_RULES: Dict[Target, Callable] = {}
_REFINEMENT_RULES: Dict[Target, Callable] = {}
_RULES: Dict[Target, Callable] = {}
@ -32,10 +33,12 @@ def expand_to_tensor_dim(t, n):
return TensorType(tuple(dims))
elif isinstance(t, TensorType):
if len(t.__args__) != n:
raise TypeError(f'Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}')
raise TypeError(
f"Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}"
)
return t
else:
raise TypeError(f'Cannot match the type {t}')
raise TypeError(f"Cannot match the type {t}")
def broadcast_types(t1, t2):
@ -80,32 +83,39 @@ def broadcast_types(t1, t2):
(t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2))
return (t1, t2)
else:
raise TypeError(f'Cannot broadcast types {t1} and {t2}')
raise TypeError(f"Cannot broadcast types {t1} and {t2}")
def register_inference_rule(call_target):
def register(fn):
if call_target in _INFERENCE_RULES:
raise RuntimeError(f'Inference rule already registered for {call_target}!')
raise RuntimeError(f"Inference rule already registered for {call_target}!")
_INFERENCE_RULES[call_target] = fn
return fn
return register
def register_refinement_rule(call_target):
def register(fn):
if call_target in _REFINEMENT_RULES:
raise RuntimeError(f'Refinement rule already registered for {call_target}!')
raise RuntimeError(f"Refinement rule already registered for {call_target}!")
_REFINEMENT_RULES[call_target] = fn
return fn
return register
def register_algebraic_expressions_inference_rule(call_target):
def register(fn):
if call_target in _RULES:
raise RuntimeError(f'Rule already registered for {call_target}!')
raise RuntimeError(f"Rule already registered for {call_target}!")
_RULES[call_target] = fn
return fn
return register
@register_inference_rule(torch.add)
@register_inference_rule(operator.add)
def add_inference_rule(n: Node):
@ -142,15 +152,15 @@ def add_inference_rule(n: Node):
(new_t1, new_t2) = broadcast_types(t1, t2)
if new_t1 != t1 or new_t2 != t2:
n.meta['broadcast'] = True
n.meta["broadcast"] = True
n.meta[str(n.args[0])] = new_t1
n.meta[str(n.args[1])] = new_t2
else:
n.meta['broadcast'] = False
n.meta["broadcast"] = False
new_t1 = t1 if not n.meta['broadcast'] else new_t1
new_t2 = t2 if not n.meta['broadcast'] else new_t2
new_t1 = t1 if not n.meta["broadcast"] else new_t1
new_t2 = t2 if not n.meta["broadcast"] else new_t2
# we check for consistency between the new types
if is_consistent(new_t1, new_t2):
@ -164,8 +174,11 @@ def add_inference_rule(n: Node):
n.type = new_t1
return n.type
else:
raise TypeError(f'Cannot add arguments {n.args[0]} ({ n.args[0].type}) and {n.args[1]} ({ n.args[1].type}) in node {n}.'
f' Types should match ')
raise TypeError(
f"Cannot add arguments {n.args[0]} ({n.args[0].type}) and {n.args[1]} ({n.args[1].type}) in node {n}."
f" Types should match "
)
@register_inference_rule(getattr)
def get_attr_inference_rule(n: Node, traced):
@ -185,6 +198,7 @@ def get_attr_inference_rule(n: Node, traced):
# TODO. We leave it like this till we add a type to represent tensor sizes
return n.type
@register_inference_rule(torch.transpose)
def transpose_inference_rule(n: Node):
"""
@ -211,9 +225,13 @@ def transpose_inference_rule(n: Node):
n.type = get_greatest_upper_bound(n.type, final)
return n.type
else:
raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')
raise TypeError(
f"Cannot transpose {dim1} and {dim2} in type {t} for node {n}"
)
else:
raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')
raise TypeError(
f"Cannot transpose {dim1} and {dim2} in type {t} for node {n}"
)
@register_inference_rule(torch.reshape)
@ -251,9 +269,10 @@ def reshape_inference_rule(n: Node):
n.type = t2_type
return t2_type
else:
raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
raise TypeError(f"Cannot reshape in node {n} from {t1} to {t2_type}")
else:
raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
raise TypeError(f"Cannot reshape in node {n} from {t1} to {t2_type}")
@register_inference_rule(BatchNorm2d)
def bn2d_inference_rule(n: Node, module_instance):
@ -274,10 +293,11 @@ def bn2d_inference_rule(n: Node, module_instance):
# we check the conditions on the incoming argument
# and any existing annotation
# we also check for consistency between both annotations
if is_consistent(arg_type.__args__[1], module_instance.num_features) and \
is_consistent(n.type.__args__[1], module_instance.num_features) and \
is_consistent(arg_type, n.type):
if (
is_consistent(arg_type.__args__[1], module_instance.num_features)
and is_consistent(n.type.__args__[1], module_instance.num_features)
and is_consistent(arg_type, n.type)
):
# we choose the more precise type
# to be the node type
# so if an incoming argument has more type information
@ -285,21 +305,35 @@ def bn2d_inference_rule(n: Node, module_instance):
n.type = get_greatest_upper_bound(arg_type, n.type)
return n.type
else:
raise TypeError(f'Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}')
raise TypeError(
f"Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}"
)
def calculate_out_dimension(d_in, module_instance, index):
"""
For calculating h_in and w_out according to the conv2D documentation
"""
padding = (module_instance.padding, module_instance.padding) \
if isinstance(module_instance.padding, int) else module_instance.padding
kernel_size = (module_instance.kernel_size, module_instance.kernel_size) \
if isinstance(module_instance.kernel_size, int) else module_instance.kernel_size
stride = (module_instance.stride, module_instance.stride) \
if isinstance(module_instance.stride, int) else module_instance.stride
dilation = (module_instance.dilation, module_instance.dilation) \
if isinstance(module_instance.dilation, int) else module_instance.dilation
padding = (
(module_instance.padding, module_instance.padding)
if isinstance(module_instance.padding, int)
else module_instance.padding
)
kernel_size = (
(module_instance.kernel_size, module_instance.kernel_size)
if isinstance(module_instance.kernel_size, int)
else module_instance.kernel_size
)
stride = (
(module_instance.stride, module_instance.stride)
if isinstance(module_instance.stride, int)
else module_instance.stride
)
dilation = (
(module_instance.dilation, module_instance.dilation)
if isinstance(module_instance.dilation, int)
else module_instance.dilation
)
DIMENSION_TYPES = (int, sympy.Symbol)
@ -307,14 +341,14 @@ def calculate_out_dimension(d_in, module_instance, index):
return Dyn
elif isinstance(d_in, DIMENSION_TYPES):
n = d_in + 2 * padding[index] - \
dilation[index] * \
(kernel_size[index] - 1) - 1
n = d_in + 2 * padding[index] - dilation[index] * (kernel_size[index] - 1) - 1
return (n // stride[0]) + 1
else:
raise TypeError(f'{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}')
raise TypeError(
f"{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}"
)
def get_greatest_upper_bound(type1, type2):
@ -327,8 +361,11 @@ def get_greatest_upper_bound(type1, type2):
return type1
elif isinstance(type1, TensorType) and isinstance(type2, TensorType):
if not is_consistent(type1, type2):
raise TypeError(f'Inconsistent types {type1}, {type2}')
gub = [t1 if is_more_precise(t1, t2) else t2 for (t1, t2) in zip(type1.__args__, type2.__args__)]
raise TypeError(f"Inconsistent types {type1}, {type2}")
gub = [
t1 if is_more_precise(t1, t2) else t2
for (t1, t2) in zip(type1.__args__, type2.__args__)
]
return TensorType(tuple(gub))
@ -352,12 +389,16 @@ def conv2d_inference_rule(n: Node, module_instance):
h_in = arg_type.__args__[2]
h_out = calculate_out_dimension(h_in, module_instance, 0)
w_out = calculate_out_dimension(w_in, module_instance, 1)
new_type = TensorType((arg_type.__args__[0], module_instance.out_channels, h_out, w_out))
new_type = TensorType(
(arg_type.__args__[0], module_instance.out_channels, h_out, w_out)
)
gub = get_greatest_upper_bound(new_type, curr_node_type)
n.type = gub
return n.type
else:
raise TypeError(f'Cannot apply {module_instance} with input type { arg_type} and existing type {n.type} on {n}')
raise TypeError(
f"Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}"
)
@register_inference_rule(torch.nn.ReLU)
@ -393,7 +434,7 @@ def maxpool2d_check(typ, module_instance):
return TensorType(tuple(new_type_list))
else:
raise TypeError(f'Wrong size {typ} for {module_instance}')
raise TypeError(f"Wrong size {typ} for {module_instance}")
@register_inference_rule(torch.nn.MaxPool2d)
@ -417,7 +458,6 @@ def maxpool2d_inference_rule(n: Node, module_instance):
return n.type
def linear_check(tensor_type, module_instance):
"""
Checks that an input tensor type satisfies the conditions for linear operation
@ -429,9 +469,11 @@ def linear_check(tensor_type, module_instance):
new_type_args[-1] = module_instance.out_features
return TensorType(tuple(new_type_args))
else:
raise TypeError(f'Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}')
raise TypeError(
f"Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}"
)
else:
raise TypeError(f'Type {tensor_type} must have rank 2 or more.')
raise TypeError(f"Type {tensor_type} must have rank 2 or more.")
@register_inference_rule(torch.nn.Linear)
@ -469,7 +511,8 @@ def adaptiveavgpool2d_check(tensor_type, module_instance):
return TensorType(tuple(new_type_list))
else:
raise TypeError(f'Tensor ranks must be 3 or 4. Got {tensor_type}')
raise TypeError(f"Tensor ranks must be 3 or 4. Got {tensor_type}")
@register_inference_rule(torch.nn.AdaptiveAvgPool2d)
def adaptiveavgpool2d_inference_rule(n: Node, module_instance):
@ -485,6 +528,7 @@ def adaptiveavgpool2d_inference_rule(n: Node, module_instance):
n.type = get_greatest_upper_bound(n.type, output_type)
return n.type
def flatten_check(tensor_type, start_dim, end_dim):
l = len(tensor_type.__args__)
@ -503,7 +547,10 @@ def flatten_check(tensor_type, start_dim, end_dim):
new_type_list = lhs + mid + rhs
return TensorType(tuple(new_type_list))
else:
raise TypeError(f'Incompatible dimensions {start_dim}, {end_dim - 1} in type {tensor_type}')
raise TypeError(
f"Incompatible dimensions {start_dim}, {end_dim - 1} in type {tensor_type}"
)
@register_inference_rule(torch.flatten)
def flatten_inference_rule(n: Node):
@ -530,10 +577,11 @@ def flatten_inference_rule(n: Node):
if isinstance(n.args[0].type, TensorType):
output_type = flatten_check(n.args[0].type, start_dim, end_dim)
n.type = get_greatest_upper_bound(output_type , n.type)
n.type = get_greatest_upper_bound(output_type, n.type)
return n.type
class GraphTypeChecker:
def __init__(self, env, traced):
self.env = env
@ -571,16 +619,16 @@ class GraphTypeChecker:
if n.type is None:
n.type = Dyn
if n.op == 'placeholder':
if n.op == "placeholder":
return n.type
elif n.op == 'get_attr':
elif n.op == "get_attr":
t = get_parameter(self.traced, n.target) # type: ignore[arg-type]
if isinstance(t.data, torch.Tensor):
n.type = TensorType(t.data.shape)
return n.type
elif n.op == 'call_function':
elif n.op == "call_function":
if n.target == getattr:
assert getattr in _INFERENCE_RULES
return _INFERENCE_RULES[n.target](n, self.traced)
@ -588,18 +636,24 @@ class GraphTypeChecker:
elif n.target in _INFERENCE_RULES:
return _INFERENCE_RULES[n.target](n)
else:
raise RuntimeError(f'No inference rule registered for target {n.target}!')
raise RuntimeError(
f"No inference rule registered for target {n.target}!"
)
elif n.op == 'call_module':
elif n.op == "call_module":
module_instance = self.traced.get_submodule(n.target)
if type(module_instance) in _INFERENCE_RULES:
return _INFERENCE_RULES[type(module_instance)](n, module_instance)
else:
raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!')
raise RuntimeError(
f"No inference rule registered for class {type(module_instance)}!"
)
elif n.op == "output":
elif n.op == 'output':
def get_node_type(a):
return a.type
n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
return n.type
@ -634,6 +688,7 @@ def linear_refinement_rule(n: Node):
res = [Equality(arg_type.__args__[0], n.type.__args__[0])]
return res
@register_refinement_rule(BatchNorm2d)
@register_refinement_rule(torch.nn.ReLU)
def all_eq(n: Node):
@ -688,7 +743,11 @@ def element_wise_eq(n: Node):
if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
arg_type1 = n.args[0].type
arg_type2 = n.args[1].type
if isinstance(arg_type1, TensorType) and isinstance(arg_type2, TensorType) and isinstance(n.type, TensorType):
if (
isinstance(arg_type1, TensorType)
and isinstance(arg_type2, TensorType)
and isinstance(n.type, TensorType)
):
args1, args2 = broadcast_types(arg_type1, arg_type2)
# by this point, we know that args1 and args2 are the same size.
a1 = args1.__args__
@ -757,12 +816,14 @@ def conv_rule(n: Node, module_instance):
n.type = new_type
return new_type
class Refine:
"""
Symbolic shape inference.
Generates constraints over type variables.
Currently all constraints are equality constraints.
"""
def __init__(self, traced):
self.constraints = []
self.traced = traced
@ -805,7 +866,6 @@ class Refine:
else:
return typ
def convert_to_sympy_symbols(self, typ):
"""
Replace all unknown types with fresh type variables.
@ -835,22 +895,24 @@ class Refine:
n.type = self.replace_dyn_with_fresh_var(n.type)
if n.op == 'call_function':
if n.op == "call_function":
if n.target in _REFINEMENT_RULES:
self.constraints += _REFINEMENT_RULES[n.target](n)
else:
pass
if n.op == 'call_module':
if n.op == "call_module":
module_instance = self.traced.get_submodule(n.target)
if type(module_instance) in _REFINEMENT_RULES:
self.constraints += _REFINEMENT_RULES[type(module_instance)](n)
else:
pass
if n.op == 'output':
if n.op == "output":
def get_node_type(a):
return a.type
n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
return n.type
@ -859,28 +921,31 @@ class Refine:
def infer_symbolic_relations(self, n: Node):
n.type = self.convert_to_sympy_symbols(n.type)
if n.op == 'call_function':
if n.op == "call_function":
if n.target in _RULES:
return _RULES[n.target](n)
else:
pass
if n.op == 'call_module':
if n.op == "call_module":
module_instance = self.traced.get_submodule(n.target)
if type(module_instance) in _RULES:
return _RULES[type(module_instance)](n, module_instance)
else:
pass
if n.op == 'output':
if n.op == "output":
def get_node_type(a):
return a.type
n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
return n.type
else:
pass
def get_parameter(traced, target: str):
"""
Returns the parameter given by ``target`` if it exists,

View File

@ -1,14 +1,13 @@
# mypy: allow-untyped-defs
import torch
from torch.fx.node import Node
from torch.fx._symbolic_trace import symbolic_trace
from torch.fx.passes.tools_common import legalize_graph
import itertools
import operator
from typing import Dict, List, Tuple
import torch
from torch.fx._symbolic_trace import symbolic_trace
from torch.fx.node import Node
from torch.fx.passes.tools_common import legalize_graph
def split_result_tensors(
result: torch.Tensor, inputs: List[torch.Tensor]
@ -146,7 +145,14 @@ def merge_matmul(in_mod: torch.nn.Module):
# Multiply the concatenated LHS operands with the one RHS. This will produce
# the same results as all the individual matmuls involving rhs in the original graph,
# but they will all be concatenated together.
merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {})
merge_mm = gm.graph.call_function(
torch.matmul,
(
merge_mm_cat,
rhs,
),
{},
)
# Split the result of the merged matmul using the shapes of the LHS operands
# to ascertain how large each chunk should be.

View File

@ -1,14 +1,15 @@
# mypy: allow-untyped-defs
import torch
import torch.fx
import warnings
import functools
import builtins
import functools
import warnings
from typing import Any, Callable, Dict, Optional, Union
import torch
import torch.fx
def embedding_override(self, input):
return torch.empty(*input.shape, self.weight.shape[-1], device='meta')
return torch.empty(*input.shape, self.weight.shape[-1], device="meta")
def nn_layernorm_override(self, input):
@ -24,21 +25,22 @@ def torch_nn_relu_override(self, x):
def functional_relu_override(x, inplace=False):
assert not inplace, 'dont support inplace functional.relu for metatensor analysis'
assert not inplace, "dont support inplace functional.relu for metatensor analysis"
return x
def torch_where_override(condition, x, y):
# torch.where returns the broadcasted tensor of condition, x, and y,
# so hack it by using addition
return condition.to(device='meta') + x.to(device='meta') + y.to(device='meta')
return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
def torch_abs_override(input, *, out=None):
assert out is None, 'Dont support in-place abs for MetaTensor analysis'
assert out is None, "Dont support in-place abs for MetaTensor analysis"
return input
manual_meta_overrides : Dict[Callable, Callable] = {
manual_meta_overrides: Dict[Callable, Callable] = {
torch.nn.Embedding: embedding_override,
torch.nn.LayerNorm: nn_layernorm_override,
torch.relu: torch_relu_override,
@ -48,6 +50,7 @@ manual_meta_overrides : Dict[Callable, Callable] = {
torch.abs: torch_abs_override,
}
def gen_constructor_wrapper(target):
@functools.wraps(target)
def wrapper(*args, **kwargs):
@ -57,57 +60,66 @@ def gen_constructor_wrapper(target):
if isinstance(v, torch.fx.Proxy):
nonlocal proxy
proxy = v
torch.fx.node.map_aggregate(args, check_has_proxy)
torch.fx.node.map_aggregate(kwargs, check_has_proxy)
if proxy is not None:
return proxy.tracer.create_proxy('call_function', target, args, kwargs)
return proxy.tracer.create_proxy("call_function", target, args, kwargs)
else:
return target(*args, **kwargs)
return wrapper, target
class MetaProxy(torch.fx.Proxy):
def install_tensor_meta(self, tensor_meta):
self._tensor_meta = tensor_meta
def size(self, dim=None):
if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
if hasattr(self, "_tensor_meta") and self._tensor_meta is not None:
return self._tensor_meta.size(*[dim] if dim else [])
return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim else (self,), {})
return self.tracer.create_proxy(
"call_method", "size", (self, dim) if dim else (self,), {}
)
def dim(self):
if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
if hasattr(self, "_tensor_meta") and self._tensor_meta is not None:
return self._tensor_meta.dim()
return self.tracer.create_proxy('call_method', 'dim', (self,), {})
return self.tracer.create_proxy("call_method", "dim", (self,), {})
@property
def shape(self):
if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
if hasattr(self, "_tensor_meta") and self._tensor_meta is not None:
return self._tensor_meta.shape
return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'shape'), {})
return self.tracer.create_proxy(
"call_function", builtins.getattr, (self, "shape"), {}
)
@property
def dtype(self):
if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
if hasattr(self, "_tensor_meta") and self._tensor_meta is not None:
return self._tensor_meta.dtype
return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'dtype'), {})
return self.tracer.create_proxy(
"call_function", builtins.getattr, (self, "dtype"), {}
)
@property
def device(self):
# Hack so we can track when devices are used. During meta-tensor propagation,
# replace these values with a constant 'meta'
return MetaDeviceAttribute(self, 'device')
return MetaDeviceAttribute(self, "device")
def __getattr__(self, k):
if k == '_tensor_meta':
if k == "_tensor_meta":
return self.__getattribute__(k)
# note: not added to the graph yet, if this is a method call
# we peephole optimize to the method invocation
return MetaAttribute(self, k)
class MetaAttribute(MetaProxy):
def __init__(self, root, attr: str):
self.root = root
self.attr = attr
self.tracer = root.tracer
@ -118,33 +130,51 @@ class MetaAttribute(MetaProxy):
# the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call
if self._node is None:
self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
self._node = self.tracer.create_proxy(
"call_function", getattr, (self.root, self.attr), {}
).node
return self._node
def __call__(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
return self.tracer.create_proxy(
"call_method", self.attr, (self.root,) + args, kwargs
)
class MetaDeviceAttribute(MetaAttribute):
pass
def proxys_to_metas(v):
if isinstance(v, MetaDeviceAttribute):
return 'meta'
return "meta"
if isinstance(v, torch.fx.Proxy):
assert isinstance(v, MetaProxy), f'Expected MetaProxy but got {type(v)}'
assert hasattr(v, '_tensor_meta'), 'MetaProxy does not have an associated meta'
assert isinstance(v, MetaProxy), f"Expected MetaProxy but got {type(v)}"
assert hasattr(v, "_tensor_meta"), "MetaProxy does not have an associated meta"
return v._tensor_meta
return v
class MetaTracer(torch.fx.Tracer):
allow_insert_stateless_mods : bool = True
allow_insert_stateless_mods: bool = True
_TORCH_METHODS_TO_PATCH = ['arange', 'zeros', 'ones', 'full_like', 'eye']
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"]
def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
def create_proxy(
self,
kind,
target,
args,
kwargs,
name=None,
type_expr=None,
proxy_factory_fn=None,
):
rv = super().create_proxy(
kind, target, args, kwargs, name, type_expr, proxy_factory_fn
)
if kind == 'placeholder' and target in self.meta_args:
if kind == "placeholder" and target in self.meta_args:
rv.install_tensor_meta(self.meta_args[target])
return rv
@ -154,54 +184,57 @@ class MetaTracer(torch.fx.Tracer):
# _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
# this will break and you will likely see issues where we cannot infer
# the size of the output.
if 'device' in kwargs:
kwargs['device'] = 'meta'
if "device" in kwargs:
kwargs["device"] = "meta"
try:
args_metas = torch.fx.node.map_aggregate(args, proxys_to_metas)
kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas)
if kind == 'call_function':
if kind == "call_function":
meta_target = manual_meta_overrides.get(target, target)
meta_out = meta_target(*args_metas, **kwargs_metas)
elif kind == 'call_method':
meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas) # type: ignore[index]
elif kind == 'call_module':
assert hasattr(self, 'orig_forward')
elif kind == "call_method":
meta_target = getattr(args_metas[0], target) # type: ignore[index]
meta_out = meta_target(*args_metas[1:], **kwargs_metas) # type: ignore[index]
elif kind == "call_module":
assert hasattr(self, "orig_forward")
self._disable_module_getattr = True
try:
mod = self.root.get_submodule(target)
mod_type = type(mod)
if mod_type in manual_meta_overrides:
meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas) # type: ignore[misc, arg-type]
meta_out = manual_meta_overrides[mod_type](
mod, *args_metas, **kwargs_metas
) # type: ignore[misc, arg-type]
else:
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
finally:
self._disable_module_getattr = False
elif kind == 'get_attr':
elif kind == "get_attr":
self._disable_module_getattr = True
try:
attr_itr = self.root
atoms = target.split('.')
atoms = target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
assert isinstance(attr_itr, torch.Tensor)
meta_out = attr_itr.to(device='meta')
meta_out = attr_itr.to(device="meta")
finally:
self._disable_module_getattr = False
else:
return rv
# TODO
assert isinstance(rv, torch.fx.Proxy), 'Dont support composite output yet'
assert isinstance(rv, torch.fx.Proxy), "Dont support composite output yet"
rv.install_tensor_meta(meta_out)
except Exception as e:
warnings.warn(f'Could not compute metadata for {kind} target {target}: {e}')
warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
return rv
def getattr(self, attr, attr_val, parameter_proxy_cache):
if getattr(self, '_disable_module_getattr', False):
if getattr(self, "_disable_module_getattr", False):
return attr_val
else:
return super().getattr(attr, attr_val, parameter_proxy_cache)
@ -228,7 +261,11 @@ class MetaTracer(torch.fx.Tracer):
try:
return super().path_of_module(mod)
except NameError:
if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
if (
self.allow_insert_stateless_mods
and len(list(mod.parameters())) == 0
and len(list(mod.buffers())) == 0
):
path = self._insert_module_as_submodule(mod)
self.prev_module = path
return path
@ -237,12 +274,13 @@ class MetaTracer(torch.fx.Tracer):
def proxy(self, node):
return MetaProxy(node, self)
def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None): # type: ignore[override]
def trace(self, root, meta_args: Dict[str, torch.Tensor], concrete_args=None): # type: ignore[override]
assert isinstance(meta_args, dict)
self.meta_args = meta_args
self.patched_torch_methods = {
target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
target: gen_constructor_wrapper(getattr(torch, target))
for target in self._TORCH_METHODS_TO_PATCH
}
self.orig_fns = set()
@ -252,18 +290,22 @@ class MetaTracer(torch.fx.Tracer):
try:
graph = super().trace(root, concrete_args)
graph._tracer_extras = {'meta_args': meta_args}
graph._tracer_extras = {"meta_args": meta_args}
return graph
finally:
for name, (_, orig) in self.patched_torch_methods.items():
setattr(torch, name, orig)
def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]],
meta_args : Optional[Dict[str, torch.Tensor]] = None,
concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule:
def symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
meta_args: Optional[Dict[str, torch.Tensor]] = None,
concrete_args: Optional[Dict[str, Any]] = None,
) -> torch.fx.GraphModule:
tracer = MetaTracer()
graph = tracer.trace(root, meta_args, concrete_args) # type: ignore[arg-type]
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
name = (
root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
)
gm = torch.fx.GraphModule(tracer.root, graph, name)
return gm
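A hedged sketch of calling the experimental meta tracer above; the module and the `meta_args` key are assumptions chosen to match a `forward(self, x)` signature:

```python
import torch
import torch.fx.experimental.meta_tracer as meta_tracer


class EmbedAndActivate(torch.nn.Module):  # made-up example
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(10, 4)

    def forward(self, x):
        return self.emb(x).relu()


gm = meta_tracer.symbolic_trace(
    EmbedAndActivate(),
    meta_args={"x": torch.zeros(2, 3, dtype=torch.long, device="meta")},
)
print(gm.graph)  # traced with shapes/dtypes propagated on the "meta" device
```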

View File

@ -1,7 +1,16 @@
# mypy: allow-untyped-defs
from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_sub, op_mul, op_div, \
op_mod, op_gt, op_lt, op_neq, op_eq
from torch.fx.tensor_type import TensorType, Dyn
from torch.fx.experimental.migrate_gradual_types.operation import (
op_add,
op_div,
op_eq,
op_gt,
op_lt,
op_mod,
op_mul,
op_neq,
op_sub,
)
from torch.fx.tensor_type import Dyn, TensorType
class Constraint:
@ -22,7 +31,7 @@ class Conj(Constraint):
return False
def __repr__(self):
return f'And({self.conjucts})'
return f"And({self.conjucts})"
class Disj(Constraint):
@ -34,12 +43,14 @@ class Disj(Constraint):
def __eq__(self, other):
if isinstance(other, Disj):
return self.disjuncts == other.disjuncts and self.disjuncts == other.disjuncts
return (
self.disjuncts == other.disjuncts and self.disjuncts == other.disjuncts
)
else:
return False
def __repr__(self):
return f'Or({self.disjuncts})'
return f"Or({self.disjuncts})"
class Prod(Constraint):
@ -56,13 +67,14 @@ class Prod(Constraint):
return False
def __repr__(self):
return f'Product({self.products})'
return f"Product({self.products})"
class T(Constraint):
"""
True
"""
def __init__(self) -> None:
pass
@ -70,12 +82,14 @@ class T(Constraint):
return isinstance(other, T)
def __repr__(self):
return 'True'
return "True"
class F(Constraint):
"""
False
"""
def __init__(self) -> None:
pass
@ -83,13 +97,14 @@ class F(Constraint):
return isinstance(other, F)
def __repr__(self):
return 'False'
return "False"
class BinaryConstraint(Constraint):
"""
Represents all binary operations
"""
def __init__(self, lhs, rhs, op):
"""
:param lhs: lhs of the constraint
@ -102,21 +117,25 @@ class BinaryConstraint(Constraint):
def __eq__(self, other):
if isinstance(other, BinaryConstraint):
return self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op
return (
self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op
)
else:
return False
def __repr__(self):
return f'({self.lhs} {self.op} {self.rhs})'
return f"({self.lhs} {self.op} {self.rhs})"
class BinConstraintT(BinaryConstraint):
"""
Binary constraints about tensors
"""
def __init__(self, lhs, rhs, op):
assert (isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn) and \
(isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn)
assert (isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn) and (
isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn
)
super().__init__(lhs, rhs, op)
def __eq__(self, other):
@ -127,6 +146,7 @@ class BinConstraintD(BinaryConstraint):
"""
Binary constraints about dimensions
"""
def __init__(self, lhs, rhs, op):
assert is_algebraic_expression(lhs) or is_dim(lhs) or is_bool_expr(lhs)
assert is_algebraic_expression(rhs) or is_dim(rhs) or is_bool_expr(rhs)
@ -137,11 +157,11 @@ class BinConstraintD(BinaryConstraint):
return super().__eq__(other)
class TGreatestUpperBound(Constraint):
"""
Greatest Upper bound for tensors with dynamic type
"""
def __init__(self, res, rhs1, rhs2):
"""
:param res: tensor variable that stores the result of the outout
@ -153,11 +173,15 @@ class TGreatestUpperBound(Constraint):
self.rhs2 = rhs2
def __repr__(self):
return f'{self.res} = {self.rhs1}\u2294*{self.rhs2}'
return f"{self.res} = {self.rhs1}\u2294*{self.rhs2}"
def __eq__(self, other):
if isinstance(other, TGreatestUpperBound):
return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2
return (
self.res == other.res
and self.rhs1 == other.rhs1
and self.rhs2 == other.rhs2
)
else:
return False
@ -166,6 +190,7 @@ class DGreatestUpperBound(Constraint):
"""
Greatest Upper bound for dimensions
"""
def __init__(self, res, rhs1, rhs2):
"""
:param res: Dimension variable to store the result
@ -181,11 +206,15 @@ class DGreatestUpperBound(Constraint):
self.rhs2 = rhs2
def __repr__(self):
return f'{self.res} = {self.rhs1}\u2294{self.rhs2}'
return f"{self.res} = {self.rhs1}\u2294{self.rhs2}"
def __eq__(self, other):
if isinstance(other, DGreatestUpperBound):
return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2
return (
self.res == other.res
and self.rhs1 == other.rhs1
and self.rhs2 == other.rhs2
)
else:
return False
@ -194,6 +223,7 @@ class CanReshape(Constraint):
"""
can_reshape constraint
"""
def __init__(self, src, target):
"""
:param src: tensor variable
@ -203,7 +233,7 @@ class CanReshape(Constraint):
self.target = target
def __repr__(self):
return f'can-reshape({self.src}, {self.target})'
return f"can-reshape({self.src}, {self.target})"
def __eq__(self, other):
if isinstance(other, CanReshape):
@ -213,7 +243,6 @@ class CanReshape(Constraint):
class IndexSelect(Constraint):
def __init__(self, tensor_size, input_var, dim_replace, index, output):
"""
Args:
@ -235,26 +264,28 @@ class IndexSelect(Constraint):
self.output = output
def __repr__(self):
return f' {self.output} = ' \
f'IndexSelect({self.input_var}, ' \
f'tensor_size: {self.tensor_size}, ' \
f'{self.dim_replace}, ' \
f'{self.index})'
return (
f" {self.output} = "
f"IndexSelect({self.input_var}, "
f"tensor_size: {self.tensor_size}, "
f"{self.dim_replace}, "
f"{self.index})"
)
def __eq__(self, other):
if isinstance(other, IndexSelect):
return self.tensor_size == other.tensor_size and \
self.dim_replace == other.dim_replace and \
self.index == other.index and \
self.output == other.output and \
self.input_var == other.input_var
return (
self.tensor_size == other.tensor_size
and self.dim_replace == other.dim_replace
and self.index == other.index
and self.output == other.output
and self.input_var == other.input_var
)
else:
return False
class Transpose(Constraint):
def __init__(self, tensor_size, input_var, index1, index2, output):
"""
Args:
@ -276,26 +307,28 @@ class Transpose(Constraint):
self.output = output
def __repr__(self):
return f' {self.output} = ' \
f'Transpose({self.input_var}, ' \
f'tensor_size: {self.tensor_size}, ' \
f'{self.index1}, ' \
f'{self.index2})'
return (
f" {self.output} = "
f"Transpose({self.input_var}, "
f"tensor_size: {self.tensor_size}, "
f"{self.index1}, "
f"{self.index2})"
)
def __eq__(self, other):
if isinstance(other, Transpose):
return self.tensor_size == other.tensor_size and \
self.index1 == other.index1 and \
self.index2 == other.index2 and \
self.output == other.output and \
self.input_var == other.input_var
return (
self.tensor_size == other.tensor_size
and self.index1 == other.index1
and self.index2 == other.index2
and self.output == other.output
and self.input_var == other.input_var
)
else:
return False
class GetItem(Constraint):
def __init__(self, tensor_size, index, res, input_var):
"""
Constraint for getting item given a tensor size
@ -312,19 +345,21 @@ class GetItem(Constraint):
self.input_var = input_var
def __repr__(self):
return f' {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})'
return f" {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})"
def __eq__(self, other):
if isinstance(other, GetItem):
return self.res == other.res and \
self.tensor_size == other.tensor_size and \
self.index == other.index and \
self.input_var == other.input_var
return (
self.res == other.res
and self.tensor_size == other.tensor_size
and self.index == other.index
and self.input_var == other.input_var
)
else:
return False
class GetItemTensor(Constraint):
class GetItemTensor(Constraint):
def __init__(self, tensor_size, index_tuple, res, input_var):
"""
Constraint for getting item given a tensor size
@ -343,20 +378,32 @@ class GetItemTensor(Constraint):
self.input_var = input_var
def __repr__(self):
return f' {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})'
return f" {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})"
def __eq__(self, other):
if isinstance(other, GetItemTensor):
return self.res == other.res and \
self.tensor_size == other.tensor_size and \
self.index_tuple == other.index_tuple and \
self.input_var == other.input_var
return (
self.res == other.res
and self.tensor_size == other.tensor_size
and self.index_tuple == other.index_tuple
and self.input_var == other.input_var
)
else:
return False
class CalcConv(Constraint):
def __init__(self, conv_result, input_var, c_out, kernel, padding, stride, dilation, matching_constraint_vars):
class CalcConv(Constraint):
def __init__(
self,
conv_result,
input_var,
c_out,
kernel,
padding,
stride,
dilation,
matching_constraint_vars,
):
"""
:param conv_result: the convolution result
:param input_var: input to convolution
@ -373,25 +420,41 @@ class CalcConv(Constraint):
self.matching_constraint = matching_constraint_vars
def __repr__(self):
return f'{self.conv_result} =' \
f' calc-conv({self.input_var},' \
f' {self.c_out}, {self.kernel}, ' \
f'{self.padding}, {self.stride},' \
f' {self.dilation})'
return (
f"{self.conv_result} ="
f" calc-conv({self.input_var},"
f" {self.c_out}, {self.kernel}, "
f"{self.padding}, {self.stride},"
f" {self.dilation})"
)
def __eq__(self, other):
if isinstance(other, CalcConv):
return self.conv_result == other.conv_result and self.input_var == other.input_var and \
self.c_out == other.c_out and self.kernel == other.kernel and self.padding == other.padding \
and self.stride == other.stride and self.dilation == other.dilation \
return (
self.conv_result == other.conv_result
and self.input_var == other.input_var
and self.c_out == other.c_out
and self.kernel == other.kernel
and self.padding == other.padding
and self.stride == other.stride
and self.dilation == other.dilation
and self.matching_constraint == other.matching_constraint
)
else:
return False
class CalcMaxPool(Constraint):
def __init__(self, maxpool_result, input_var, kernel, padding, stride, dilation, matching_constraint_vars):
def __init__(
self,
maxpool_result,
input_var,
kernel,
padding,
stride,
dilation,
matching_constraint_vars,
):
"""
:param maxpool_result: the result of maxpool
:param input_var: input to maxpool
@ -406,18 +469,25 @@ class CalcMaxPool(Constraint):
self.matching_constraint = matching_constraint_vars
def __repr__(self):
return f'{self.maxpool_result} =' \
f' calc-maxpool({self.input_var},' \
f' {self.kernel}, ' \
f'{self.padding}, {self.stride},' \
f' {self.dilation})'
return (
f"{self.maxpool_result} ="
f" calc-maxpool({self.input_var},"
f" {self.kernel}, "
f"{self.padding}, {self.stride},"
f" {self.dilation})"
)
def __eq__(self, other):
if isinstance(other, CalcMaxPool):
return self.maxpool_result == other.maxpool_result and self.input_var == other.input_var \
and self.kernel == other.kernel and self.padding == other.padding \
and self.stride == other.stride and self.dilation == other.dilation \
return (
self.maxpool_result == other.maxpool_result
and self.input_var == other.input_var
and self.kernel == other.kernel
and self.padding == other.padding
and self.stride == other.stride
and self.dilation == other.dilation
and self.matching_constraint == other.matching_constraint
)
else:
return False
@ -437,21 +507,28 @@ class ApplyBroadcasting(Constraint):
def __eq__(self, other):
if isinstance(other, ApplyBroadcasting):
return self.res1 == other.res1 \
and self.res2 == other.res2 \
and self.input1 == other.input1 \
return (
self.res1 == other.res1
and self.res2 == other.res2
and self.input1 == other.input1
and self.input2 == other.input2
)
else:
return False
def __repr__(self):
return f'{self.res1}, {self.res2} ='f' apply-broadcasting({self.input1},' f' {self.input2})'
return (
f"{self.res1}, {self.res2} ="
f" apply-broadcasting({self.input1},"
f" {self.input2})"
)
class CalcProduct(Constraint):
"""
Given correct dimensions, calculate the product for flatten accounting for Dyn
"""
def __init__(self, start, end, flattened, dims_to_flatten):
"""
:param start: start index
@ -471,20 +548,25 @@ class CalcProduct(Constraint):
def __eq__(self, other):
if isinstance(other, CalcProduct):
return self.start == other.start and self.end == other.end and \
self.dims_to_flatten == other.dims_to_flatten and self.flattened == other.flattened
return (
self.start == other.start
and self.end == other.end
and self.dims_to_flatten == other.dims_to_flatten
and self.flattened == other.flattened
)
else:
return False
def __repr__(self):
return f'{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})'
return f"{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})"
class TVar:
"""
Tensor variable with no tensor constructor
"""
def __init__(self, tvar):
"""
:param tvar: tensor variable
@ -492,7 +574,7 @@ class TVar:
self.tvar = tvar
def __repr__(self):
return f'TV({self.tvar})'
return f"TV({self.tvar})"
def __eq__(self, other):
if isinstance(other, TVar):
@ -505,6 +587,7 @@ class DVar:
"""
Dimension variable
"""
def __init__(self, c):
"""
:param c: character or number
@ -512,7 +595,7 @@ class DVar:
self.c = c
def __repr__(self):
return f'DV({self.c})'
return f"DV({self.c})"
def __eq__(self, other):
if isinstance(other, DVar):
@ -525,6 +608,7 @@ class BVar:
"""
Boolean variable
"""
def __init__(self, c):
"""
:param c: character or number
@ -532,7 +616,7 @@ class BVar:
self.c = c
def __repr__(self):
return f'BV({self.c})'
return f"BV({self.c})"
def __eq__(self, other):
if isinstance(other, BVar):
@ -554,5 +638,6 @@ def is_bool_expr(constraint):
else:
return isinstance(constraint, (BVar, Conj, Disj))
def is_dim(d):
return isinstance(d, (DVar, int)) or d == Dyn
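To make the classes above more concrete, here is a small illustrative sketch (not taken from this patch) that builds one of the structured constraints by hand; the only assumptions are the constructors and `__repr__` methods shown earlier in this file's diff:

```python
# Sketch: variables print as TV(..)/DV(..)/BV(..); structured constraints embed them.
from torch.fx.experimental.migrate_gradual_types.constraint import DVar, GetItem, TVar

t_in = TVar(1)   # a tensor variable
d_res = DVar(2)  # dimension variable that will receive the selected size
# Informally: "if t_in has rank 3, then d_res equals the size of its dimension 0"
c = GetItem(3, 0, d_res, t_in)
print(c)  # " DV(2) = GetItem(TV(1), tensor_size: 3, 0)"
```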

View File

@ -1,34 +1,71 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import torch
import operator
import warnings
from typing import Callable, Dict, Iterable
import torch
from torch.fx._symbolic_trace import _assert_is_none
from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, CalcProduct, \
Disj, TGreatestUpperBound, CalcMaxPool, CalcConv, Conj, BinConstraintT, CanReshape, BinConstraintD, GetItem, T, F, \
TVar, DVar, GetItemTensor, IndexSelect, Transpose, DGreatestUpperBound
from torch.fx.experimental.migrate_gradual_types.operation import \
op_eq, op_matching, op_consistency, op_leq, op_precision, op_gt, op_div, op_sub, op_neq, op_lt, op_add, op_mul
from torch.fx.node import Target, Node
from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar, gen_tvar, \
gen_bvar
from torch.fx.experimental.migrate_gradual_types.constraint import (
ApplyBroadcasting,
BinConstraintD,
BinConstraintT,
CalcConv,
CalcMaxPool,
CalcProduct,
CanReshape,
Conj,
DGreatestUpperBound,
Disj,
DVar,
F,
GetItem,
GetItemTensor,
IndexSelect,
T,
TGreatestUpperBound,
Transpose,
TVar,
)
from torch.fx.experimental.migrate_gradual_types.operation import (
op_add,
op_consistency,
op_div,
op_eq,
op_gt,
op_leq,
op_lt,
op_matching,
op_mul,
op_neq,
op_precision,
op_sub,
)
from torch.fx.experimental.migrate_gradual_types.util import (
gen_bvar,
gen_dvar,
gen_nat_constraints,
gen_tensor_dims,
gen_tvar,
)
from torch.fx.node import Node, Target
from torch.fx.tensor_type import Dyn, TensorType
from torch.nn.modules.conv import Conv2d
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.conv import Conv2d
_INFERENCE_RULES: Dict[Target, Callable] = {}
MAX_TENSOR_RANK = 4
def register_inference_rule(call_target):
def register(fn):
if call_target in _INFERENCE_RULES:
raise RuntimeError(f'Inference rule already registered for {call_target}!')
raise RuntimeError(f"Inference rule already registered for {call_target}!")
_INFERENCE_RULES[call_target] = fn
return fn
return register
@ -55,10 +92,11 @@ def get_attr_inference_rule(n: Node, symbols, constraints, counter):
input = symbols[n.args[0]]
attr = n.args[1]
if attr == 'device':
if attr == "device":
return [BinConstraintT(input, output, op_eq)], counter
else:
raise NotImplementedError('Not yet implemented')
raise NotImplementedError("Not yet implemented")
@register_inference_rule(torch.bmm)
def bmm_inference_rule(n: Node, symbols, constraints, counter):
@ -79,26 +117,53 @@ def bmm_inference_rule(n: Node, symbols, constraints, counter):
dims_input1, counter = gen_tensor_dims(3, counter)
dims_input2, counter = gen_tensor_dims(3, counter)
inputs_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq),
BinConstraintT(bmm_input2, Dyn, op_eq),
BinConstraintT(bmm_output, Dyn, op_eq)])
inputs_dyn = Conj(
[
BinConstraintT(bmm_input1, Dyn, op_eq),
BinConstraintT(bmm_input2, Dyn, op_eq),
BinConstraintT(bmm_output, Dyn, op_eq),
]
)
input1_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq),
BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
BinConstraintT(bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq)])
input1_dyn = Conj(
[
BinConstraintT(bmm_input1, Dyn, op_eq),
BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
BinConstraintT(
bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq
),
]
)
input2_dyn = Conj([BinConstraintT(bmm_input2, Dyn, op_eq),
BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
BinConstraintT(bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq)])
input2_dyn = Conj(
[
BinConstraintT(bmm_input2, Dyn, op_eq),
BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
BinConstraintT(
bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq
),
]
)
consistency_constraints = [BinConstraintD(dims_input1[0], dims_input2[0], op_consistency)]
consistency_constraints = [
BinConstraintD(dims_input1[0], dims_input2[0], op_consistency)
]
batch_size, counter = gen_dvar(counter)
inputs_are_tensors = Conj([BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
BinConstraintT(bmm_output, TensorType([batch_size, dims_input1[1], dims_input2[2]]), op_eq),
*consistency_constraints, DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0])])
inputs_are_tensors = Conj(
[
BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
BinConstraintT(
bmm_output,
TensorType([batch_size, dims_input1[1], dims_input2[2]]),
op_eq,
),
*consistency_constraints,
DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0]),
]
)
return [Disj([inputs_dyn, input1_dyn, input2_dyn, inputs_are_tensors])], counter
@ -115,8 +180,6 @@ def index_select_inference_rule(n: Node, symbols, constraints, counter):
assert isinstance(n.args[1], int)
assert isinstance(n.args[2], Node)
index_select, counter = gen_tvar(counter)
symbols[n] = index_select
@ -126,10 +189,30 @@ def index_select_inference_rule(n: Node, symbols, constraints, counter):
is_size_1 = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq)
is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq)
c2 = Conj([is_size_1, Disj([IndexSelect(i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select)
for i in range(MAX_TENSOR_RANK)])])
c3 = Conj([is_dyn, Disj([IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select)
for i in range(MAX_TENSOR_RANK)])])
c2 = Conj(
[
is_size_1,
Disj(
[
IndexSelect(
i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select
)
for i in range(MAX_TENSOR_RANK)
]
),
]
)
c3 = Conj(
[
is_dyn,
Disj(
[
IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select)
for i in range(MAX_TENSOR_RANK)
]
),
]
)
return [Disj([c2, c3])], counter
@ -158,14 +241,27 @@ def expand_inference_rule(n: Node, symbols, constraints, counter):
assert isinstance(symbols[arg], DVar)
e2_nat_constraints.append(BinConstraintD(0, symbols[arg], op_leq))
e2_constraint = BinConstraintT(e2, TensorType([arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]]), op_eq)
e2_constraint = BinConstraintT(
e2,
TensorType(
[arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]]
),
op_eq,
)
constraints, counter = gen_broadcasting_constraints(e1, e2, symbols, counter, expand)
constraints, counter = gen_broadcasting_constraints(
e1, e2, symbols, counter, expand
)
# constrain the output size
dims, counter = gen_tensor_dims(len(n.args[1:]), counter)
nat_constraints = gen_nat_constraints(dims)
c = [BinConstraintT(expand, TensorType(dims), op_eq), *nat_constraints, e2_constraint, *e2_nat_constraints]
c = [
BinConstraintT(expand, TensorType(dims), op_eq),
*nat_constraints,
e2_constraint,
*e2_nat_constraints,
]
constraints += c
return constraints, counter
@ -206,7 +302,7 @@ def equality_inference_rule(n: Node, symbols, constraints, counter):
my_size = [symbols[arg] for arg in n.args[0]]
return [BinConstraintT(output, TensorType(my_size), op_eq)], counter
else:
raise NotImplementedError('Method not yet implemented')
raise NotImplementedError("Method not yet implemented")
@register_inference_rule("transpose")
@ -225,10 +321,17 @@ def transpose_inference_rule(n: Node, symbols, constraints, counter):
assert isinstance(from_arg, TVar)
# input and output are dyn
is_dyn = Conj([BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)])
is_dyn = Conj(
[BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)]
)
# or input is a tensor and we actually do the replacement
c3 = Disj([Transpose(i + 1, from_arg, n.args[1], n.args[2], output) for i in range(MAX_TENSOR_RANK)])
c3 = Disj(
[
Transpose(i + 1, from_arg, n.args[1], n.args[2], output)
for i in range(MAX_TENSOR_RANK)
]
)
return [Disj([is_dyn, c3])], counter
@ -250,8 +353,11 @@ def type_inference_rule(n: Node, symbols, constraints, counter):
assert isinstance(from_arg, TVar)
assert isinstance(to_arg, TVar)
return [BinConstraintT(from_arg, to_arg, op_consistency),
BinConstraintT(output, to_arg, op_eq)], counter
return [
BinConstraintT(from_arg, to_arg, op_consistency),
BinConstraintT(output, to_arg, op_eq),
], counter
@register_inference_rule("masked_fill_")
def masked_fill_inference_rule(n: Node, symbols, constraints, counter):
@ -273,9 +379,11 @@ def masked_fill_inference_rule(n: Node, symbols, constraints, counter):
if isinstance(e1, TVar) and isinstance(e2, TVar):
masked_fill_tensor, counter = gen_tvar(counter)
symbols[n] = masked_fill_tensor
return gen_broadcasting_constraints(e1, e2, symbols, counter, masked_fill_tensor)
return gen_broadcasting_constraints(
e1, e2, symbols, counter, masked_fill_tensor
)
else:
raise NotImplementedError('Not yet implemented')
raise NotImplementedError("Not yet implemented")
@register_inference_rule(torch.nn.functional.embedding)
@ -286,7 +394,9 @@ def embedding_inference_rule_functional(n: Node, symbols, constraints, counter):
# will treat this as a static shape. So we will not use matching.
weight_dims, counter = gen_tensor_dims(2, counter)
equality_constraint = BinConstraintT(embedding_dim_weights, TensorType(weight_dims), op_eq)
equality_constraint = BinConstraintT(
embedding_dim_weights, TensorType(weight_dims), op_eq
)
embedding_dim = weight_dims[1]
constraints, counter = gen_embedding_rules(n, symbols, embedding_dim, counter)
return [equality_constraint] + constraints, counter
@ -302,7 +412,6 @@ def embedding_inference_rule(n: Node, module_instance, symbols, constraints, cou
def gen_embedding_rules(n: Node, symbols, embedding_dim, counter):
embedding_output, counter = gen_tvar(counter)
symbols[n] = embedding_output
embedding_input = symbols[n.args[0]]
@ -318,9 +427,15 @@ def gen_embedding_rules(n: Node, symbols, embedding_dim, counter):
nat_constraints = gen_nat_constraints(new_dims)
# we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases
c_tensor_i = Conj([BinConstraintT(embedding_input, TensorType(new_dims), op_eq),
BinConstraintT(embedding_output, TensorType(new_dims + [embedding_dim]), op_eq)] +
nat_constraints)
c_tensor_i = Conj(
[
BinConstraintT(embedding_input, TensorType(new_dims), op_eq),
BinConstraintT(
embedding_output, TensorType(new_dims + [embedding_dim]), op_eq
),
]
+ nat_constraints
)
c2.append(c_tensor_i)
return [Disj([c1, Disj(c2)])], counter
@ -348,9 +463,10 @@ def view_inference_rule(n: Node, symbols, constraints, counter):
my_view, counter = gen_tvar(counter)
symbols[n] = my_view
src_var = symbols[n.args[0]]
t2 = [symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:]] # target shape
t2 = [
symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:]
] # target shape
t2_type = []
num_constraints = []
@ -382,7 +498,6 @@ def size_inference_rule(n: Node, symbols, constraints, counter):
Ex: size = input_ids.size()
"""
if len(n.args) == 1:
# generate the new variable
size, counter = gen_tvar(counter)
@ -398,7 +513,10 @@ def size_inference_rule(n: Node, symbols, constraints, counter):
size_index, counter = gen_dvar(counter)
symbols[n] = size_index
input = symbols[n.args[0]]
c2 = [GetItem(i + 1, n.args[1], size_index, input) for i in range(MAX_TENSOR_RANK)]
c2 = [
GetItem(i + 1, n.args[1], size_index, input)
for i in range(MAX_TENSOR_RANK)
]
c3 = BinConstraintD(0, size_index, op_leq)
input_dyn = BinConstraintT(input, Dyn, op_eq)
@ -452,9 +570,14 @@ def cumsum_inference_rule(n: Node, symbols, constraints, counter):
nat_constraints = gen_nat_constraints(new_dims)
c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims), op_eq),
BinConstraintT(output, TensorType(new_dims), op_eq)] +
[range_check(arg_1, i)] + nat_constraints)
c_tensor_i = Conj(
[
BinConstraintT(input, TensorType(new_dims), op_eq),
BinConstraintT(output, TensorType(new_dims), op_eq),
]
+ [range_check(arg_1, i)]
+ nat_constraints
)
c2.append(c_tensor_i)
dyn_or_tensor = Disj([c1, Disj(c2)])
@ -481,7 +604,6 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter):
get_item_arg = symbols[n.args[0]]
assert isinstance(get_item_arg, TVar)
# if the input is dynamic, we accept any index and return
# a dynamic dimension as output
input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq)
@ -492,8 +614,10 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter):
# generate a getItem constraint which will be expanded based on the
# tensor dimension.
c2 = [GetItem(i + 1, n.args[1], get_item_output, get_item_arg) for i in range(MAX_TENSOR_RANK)]
c2 = [
GetItem(i + 1, n.args[1], get_item_output, get_item_arg)
for i in range(MAX_TENSOR_RANK)
]
# since the output is a dimension, we make sure it's a natural number
# added as a conjunction to the disjunction of c2
@ -515,8 +639,10 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter):
output_dyn = BinConstraintT(get_item_output, Dyn, op_eq) # type: ignore[assignment]
c1 = Conj([input_dyn, output_dyn])
c2 = [GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg) # type: ignore[misc]
for i in range(MAX_TENSOR_RANK)]
c2 = [
GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg) # type: ignore[misc]
for i in range(MAX_TENSOR_RANK)
]
else:
# TODO: we should figure out why there is a key-error here.
return [], counter
@ -524,7 +650,7 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter):
return [Disj([c1, *c2])], counter
else:
raise RuntimeError('Method not yet implemented')
raise RuntimeError("Method not yet implemented")
@register_inference_rule(operator.gt)
@ -553,7 +679,7 @@ def gt_inference_rule(n: Node, symbols, constraints, counter):
return [equality_constraint], counter
else:
raise RuntimeError('Sort Mismatch')
raise RuntimeError("Sort Mismatch")
elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
if isinstance(e1, DVar):
@ -567,7 +693,9 @@ def gt_inference_rule(n: Node, symbols, constraints, counter):
elif isinstance(e1, TVar) and isinstance(e2, int):
# then we made the wrong assumption about the argument being a tensor
# so we should fix the assumption
warnings.warn(f'Made the wrong assumption for node {n}. Correctness not guaranteed.')
warnings.warn(
f"Made the wrong assumption for node {n}. Correctness not guaranteed."
)
new_e1, counter = gen_dvar(counter)
symbols[n.args[0]] = new_e1
@ -580,10 +708,10 @@ def gt_inference_rule(n: Node, symbols, constraints, counter):
return [equality_constraint], counter
else:
raise NotImplementedError('Method not yet implemented')
raise NotImplementedError("Method not yet implemented")
else:
raise NotImplementedError('Method not yet implemented')
raise NotImplementedError("Method not yet implemented")
@register_inference_rule(operator.eq)
@ -609,7 +737,7 @@ def eq_inference_rule(n: Node, symbols, constraints, counter):
return [equality_constraint], counter
else:
raise RuntimeError('Sort Mismatch')
raise RuntimeError("Sort Mismatch")
elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
if isinstance(e1, DVar):
@ -620,9 +748,10 @@ def eq_inference_rule(n: Node, symbols, constraints, counter):
equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq)
return [equality_constraint], counter
else:
raise NotImplementedError('Method not yet implemented')
raise NotImplementedError("Method not yet implemented")
else:
raise NotImplementedError('Method not yet implemented')
raise NotImplementedError("Method not yet implemented")
@register_inference_rule(operator.ne)
def neq_inference_rule(n: Node, symbols, constraints, counter):
@ -641,7 +770,6 @@ def neq_inference_rule(n: Node, symbols, constraints, counter):
# implementing for size 3 and 4
if len(n.args[1]) == 3:
assert isinstance(n.args[1][0], (Node, int))
assert isinstance(n.args[1][1], (Node, int))
assert isinstance(n.args[1][2], (Node, int))
@ -662,11 +790,19 @@ def neq_inference_rule(n: Node, symbols, constraints, counter):
neq_3 = BinConstraintD(d3, b[2], op_neq)
# dimensions inconsistent
dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1])
dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2])
dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3])
dims_inconsistent1 = Conj(
[BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1]
)
dims_inconsistent2 = Conj(
[BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2]
)
dims_inconsistent3 = Conj(
[BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3]
)
dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3])
dims_inconsistent = Disj(
[dims_inconsistent1, dims_inconsistent2, dims_inconsistent3]
)
# we are covering size 3 and 4 only for now
ne_constraint = Conj([input_is_size3, dims_inconsistent])
@ -675,7 +811,6 @@ def neq_inference_rule(n: Node, symbols, constraints, counter):
equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq)
elif len(n.args[1]) == 4:
assert isinstance(n.args[1][0], (Node, int))
assert isinstance(n.args[1][1], (Node, int))
assert isinstance(n.args[1][2], (Node, int))
@ -703,12 +838,27 @@ def neq_inference_rule(n: Node, symbols, constraints, counter):
neq_4 = BinConstraintD(d4, b4, op_neq)
# dimensions inconsistent
dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1])
dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2])
dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3])
dims_inconsistent4 = Conj([BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_4])
dims_inconsistent1 = Conj(
[BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1]
)
dims_inconsistent2 = Conj(
[BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2]
)
dims_inconsistent3 = Conj(
[BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3]
)
dims_inconsistent4 = Conj(
[BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_4]
)
dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3, dims_inconsistent4])
dims_inconsistent = Disj(
[
dims_inconsistent1,
dims_inconsistent2,
dims_inconsistent3,
dims_inconsistent4,
]
)
ne_constraint = Conj([input_is_size4, dims_inconsistent])
@ -717,7 +867,7 @@ def neq_inference_rule(n: Node, symbols, constraints, counter):
equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq)
else:
raise NotImplementedError('Method not yet implemented')
raise NotImplementedError("Method not yet implemented")
return [equality_constraint], counter
@ -748,7 +898,7 @@ def lt_inference_rule(n: Node, symbols, constraints, counter):
return [equality_constraint], counter
else:
raise RuntimeError('Sort Mismatch')
raise RuntimeError("Sort Mismatch")
elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
if isinstance(e1, DVar):
@ -759,10 +909,10 @@ def lt_inference_rule(n: Node, symbols, constraints, counter):
equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq)
return [equality_constraint], counter
else:
raise NotImplementedError('Method not yet implemented')
raise NotImplementedError("Method not yet implemented")
else:
raise NotImplementedError('Method not yet implemented')
raise NotImplementedError("Method not yet implemented")
@register_inference_rule(torch.full)
@ -788,28 +938,42 @@ def arange_inference_rule(n: Node, symbols, constraints, counter):
if len(n.args) == 1:
end = symbols[n.args[0]]
else:
raise NotImplementedError('Not yet implemented')
raise NotImplementedError("Not yet implemented")
# int((end - start) / step)
d1, counter = gen_dvar(counter)
size_constraint = BinConstraintD(d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq)
size_constraint = BinConstraintD(
d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq
)
arange, counter = gen_tvar(counter)
symbols[n] = arange
# each parameter is either a number or Dyn
c1 = Disj([BinConstraintD(end, Dyn, op_eq),
BinConstraintD(start, Dyn, op_eq),
BinConstraintD(step, Dyn, op_eq)])
c1 = Disj(
[
BinConstraintD(end, Dyn, op_eq),
BinConstraintD(start, Dyn, op_eq),
BinConstraintD(step, Dyn, op_eq),
]
)
c2 = BinConstraintD(d1, Dyn, op_eq)
both_dyn = Conj([c1, c2])
c11 = Conj([BinConstraintD(end, Dyn, op_neq),
BinConstraintD(start, Dyn, op_neq),
BinConstraintD(step, Dyn, op_neq)])
c11 = Conj(
[
BinConstraintD(end, Dyn, op_neq),
BinConstraintD(start, Dyn, op_neq),
BinConstraintD(step, Dyn, op_neq),
]
)
c22 = BinConstraintD(d1, Dyn, op_neq)
both_numbers = Conj([c11, c22, size_constraint])
return [BinConstraintT(arange, TensorType([d1]), op_eq), Disj([both_dyn, both_numbers])], counter
return [
BinConstraintT(arange, TensorType([d1]), op_eq),
Disj([both_dyn, both_numbers]),
], counter
def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var):
# additional vars that don't correspond to expressions
@ -829,7 +993,6 @@ def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var):
@register_inference_rule(torch.add)
@register_inference_rule(operator.add)
def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
op_code = None
if n.target == operator.add or n.target == torch.add:
op_code = op_add
@ -837,7 +1000,9 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
op_code = op_mul
if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
if isinstance(symbols[n.args[0]], TVar) and isinstance(symbols[n.args[1]], TVar):
if isinstance(symbols[n.args[0]], TVar) and isinstance(
symbols[n.args[1]], TVar
):
my_output, counter = gen_tvar(counter)
symbols[n] = my_output
e1 = symbols[n.args[0]]
@ -845,7 +1010,7 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
return gen_broadcasting_constraints(e1, e2, symbols, counter, my_output)
else:
raise NotImplementedError('Method not yet implemented')
raise NotImplementedError("Method not yet implemented")
elif isinstance(n.args[0], Node) and isinstance(n.args[1], (int, float)):
if isinstance(symbols[n.args[0]], TVar):
@ -859,8 +1024,14 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
e1 = symbols[n.args[0]]
# we will propagate the runtime value here since this is regular addition
c = Conj([BinConstraintD(my_output, BinConstraintD(e1, n.args[1], op_code), op_eq),
BinConstraintD(0, my_output, op_leq)])
c = Conj(
[
BinConstraintD(
my_output, BinConstraintD(e1, n.args[1], op_code), op_eq
),
BinConstraintD(0, my_output, op_leq),
]
)
return [c], counter
elif isinstance(n.args[1], Node) and isinstance(n.args[0], (int, float)):
@ -875,16 +1046,22 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
e2 = symbols[n.args[1]]
# we will propagate the runtime value here since this is regular addition
c = Conj([BinConstraintD(my_output, BinConstraintD(e2, n.args[0], op_code), op_eq),
BinConstraintD(0, my_output, op_leq)])
c = Conj(
[
BinConstraintD(
my_output, BinConstraintD(e2, n.args[0], op_code), op_eq
),
BinConstraintD(0, my_output, op_leq),
]
)
return [c], counter
else:
raise NotImplementedError('Method not yet implemented')
raise NotImplementedError("Method not yet implemented")
else:
# TODO generate add constraints for scalar addition
raise NotImplementedError('Addition not yet implemented')
raise NotImplementedError("Addition not yet implemented")
@register_inference_rule(torch.flatten)
@ -915,7 +1092,9 @@ def flatten_inference_rule(n: Node, symbols, constraints, counter):
const = []
for i in range(1, MAX_TENSOR_RANK + 1):
c, counter = generate_flatten_constraints(start_dim, end_dim, input, flattened, i, counter)
c, counter = generate_flatten_constraints(
start_dim, end_dim, input, flattened, i, counter
)
const.append(c)
return [Disj([both_dyn, *const])], counter
@ -937,7 +1116,9 @@ def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, co
Input should be consistent with the normalized_shape
"""
assert isinstance(n.args[0], Node)
return gen_layer_norm_constraints(n, module_instance.normalized_shape, symbols, counter)
return gen_layer_norm_constraints(
n, module_instance.normalized_shape, symbols, counter
)
def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter):
@ -955,13 +1136,18 @@ def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter):
new_dims_rhs, counter = gen_tensor_dims(i, counter)
nat_constraints = gen_nat_constraints(new_dims_rhs)
c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs), op_eq),
BinConstraintT(output, TensorType(new_dims_rhs), op_eq)] +
add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) +
nat_constraints)
c_tensor_i = Conj(
[
BinConstraintT(input, TensorType(new_dims_rhs), op_eq),
BinConstraintT(output, TensorType(new_dims_rhs), op_eq),
]
+ add_layer_norm_constraints(new_dims_rhs, list(normalized_shape))
+ nat_constraints
)
c2.append(c_tensor_i)
return [Disj([c1, Disj(c2)])], counter
@register_inference_rule(torch.nn.Dropout)
@register_inference_rule(torch.nn.ReLU)
def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter):
@ -983,7 +1169,9 @@ def linear_inference_rule(n: Node, module_instance, symbols, constraints, counte
If the input is Dyn, then so should the output be
"""
assert isinstance(n.args[0], Node)
return linear_constraints(n, module_instance.in_features, module_instance.out_features, symbols, counter)
return linear_constraints(
n, module_instance.in_features, module_instance.out_features, symbols, counter
)
@register_inference_rule("dim") # type: ignore[attr-defined]
@ -1001,8 +1189,12 @@ def torch_dim_inference_rule(n: Node, symbols, constraints, counter):
for i in range(1, MAX_TENSOR_RANK + 1):
new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq),
BinConstraintD(my_dim, i, op_eq)])
c_tensor_i = Conj(
[
BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq),
BinConstraintD(my_dim, i, op_eq),
]
)
c1.append(c_tensor_i)
return [Disj([Conj([input_dyn, output_dyn]), Disj(c1)])], counter
@ -1012,8 +1204,12 @@ def torch_dim_inference_rule(n: Node, symbols, constraints, counter):
def torch_linear_inference_rule(n: Node, symbols, constraints, counter):
assert isinstance(n.args[0], Node)
weight_dims, counter = gen_tensor_dims(2, counter)
equality_constraint = BinConstraintT(symbols[n.args[1]], TensorType(weight_dims), op_eq)
constraints, counter = linear_constraints(n, weight_dims[1], weight_dims[0], symbols, counter)
equality_constraint = BinConstraintT(
symbols[n.args[1]], TensorType(weight_dims), op_eq
)
constraints, counter = linear_constraints(
n, weight_dims[1], weight_dims[0], symbols, counter
)
return [equality_constraint] + constraints, counter
@ -1034,13 +1230,20 @@ def linear_constraints(n: Node, in_features, out_features, symbols, counter):
nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2)
c_tensor_i = Conj([BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq),
BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq)] +
add_linear_constraints(new_dims_rhs_1, new_dims_rhs_2, in_features, out_features) +
nat_constraints)
c_tensor_i = Conj(
[
BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq),
BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq),
]
+ add_linear_constraints(
new_dims_rhs_1, new_dims_rhs_2, in_features, out_features
)
+ nat_constraints
)
c2.append(c_tensor_i)
return [Disj([c1, Disj(c2)])], counter
def add_layer_norm_constraints(input_dim, normalized_dim):
"""
The constraints say that the type has the form: [*, 1024, 1024]
@ -1130,7 +1333,13 @@ def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, coun
d4, counter = gen_dvar(counter)
nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
c2 = BinConstraintT(avg_pool, TensorType([d1, d2, module_instance.output_size[0], module_instance.output_size[1]]), op_eq)
c2 = BinConstraintT(
avg_pool,
TensorType(
[d1, d2, module_instance.output_size[0], module_instance.output_size[1]]
),
op_eq,
)
return [c1, c2, *nat_constraints], counter
@ -1152,12 +1361,16 @@ def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counte
# c2 = DConsistency(module_instance.in_channels, d2)
c2 = BinConstraintD(module_instance.in_channels, d2, op_consistency)
c3 = CalcConv(my_conv, input_var,
module_instance.out_channels,
module_instance.kernel_size,
module_instance.padding,
module_instance.stride,
module_instance.dilation, [d1, d2, d3, d4])
c3 = CalcConv(
my_conv,
input_var,
module_instance.out_channels,
module_instance.kernel_size,
module_instance.padding,
module_instance.stride,
module_instance.dilation,
[d1, d2, d3, d4],
)
nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
@ -1176,8 +1389,15 @@ def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, count
c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
c2 = CalcMaxPool(maxpool, input_var, module_instance.kernel_size, module_instance.padding,
module_instance.stride, module_instance.dilation, [d1, d2, d3, d4])
c2 = CalcMaxPool(
maxpool,
input_var,
module_instance.kernel_size,
module_instance.padding,
module_instance.stride,
module_instance.dilation,
[d1, d2, d3, d4],
)
nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
@ -1190,8 +1410,7 @@ class ConstraintGenerator:
self.traced_params = dict(self.traced.named_parameters())
self.constraints = []
self.symbol_dict = {}
self.graph = traced.graph if hasattr(traced, 'graph') else graph
self.graph = traced.graph if hasattr(traced, "graph") else graph
def generate_constraints(self, counter=0):
"""
@ -1217,7 +1436,7 @@ class ConstraintGenerator:
- conv2d
"""
if n.op == 'placeholder':
if n.op == "placeholder":
x, counter = gen_tvar(counter)
self.symbol_dict[n] = x
@ -1226,8 +1445,8 @@ class ConstraintGenerator:
if n.type != Dyn and (not isinstance(n.type, TensorType)):
if n.type == torch.nn.parameter.Parameter:
# since we have a parameter, the shape must be static
assert 'example_value' in n.meta
my_type = TensorType(n.meta['example_value'].size())
assert "example_value" in n.meta
my_type = TensorType(n.meta["example_value"].size())
else:
my_type = Dyn
@ -1235,30 +1454,38 @@ class ConstraintGenerator:
c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq)
return [c1, c2], counter
elif n.op == 'call_function':
elif n.op == "call_function":
if n.target in _INFERENCE_RULES:
return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter)
return _INFERENCE_RULES[n.target](
n, self.symbol_dict, self.constraints, counter
)
else:
raise RuntimeError(f'No inference rule registered for target {n.target}!')
elif n.op == 'call_module':
raise RuntimeError(
f"No inference rule registered for target {n.target}!"
)
elif n.op == "call_module":
module_instance = self.traced.get_submodule(n.target)
if type(module_instance) in _INFERENCE_RULES:
return _INFERENCE_RULES[type(module_instance)](n,
module_instance,
self.symbol_dict,
self.constraints, counter)
return _INFERENCE_RULES[type(module_instance)](
n, module_instance, self.symbol_dict, self.constraints, counter
)
else:
raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!')
raise RuntimeError(
f"No inference rule registered for class {type(module_instance)}!"
)
elif n.op == 'call_method':
elif n.op == "call_method":
if n.target in _INFERENCE_RULES:
return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter)
return _INFERENCE_RULES[n.target](
n, self.symbol_dict, self.constraints, counter
)
else:
raise RuntimeError(f'No inference rule registered for target {n.target}!')
raise RuntimeError(
f"No inference rule registered for target {n.target}!"
)
elif n.op == 'get_attr':
elif n.op == "get_attr":
t = self.traced_params.get(n.target, None)
if isinstance(t, torch.Tensor):
@ -1274,7 +1501,7 @@ class ConstraintGenerator:
else:
return [], counter
elif n.op == 'output':
elif n.op == "output":
return [], counter
else:

View File

@ -1,14 +1,14 @@
op_add = '+'
op_sub = '-'
op_mul = '*'
op_div = '/'
op_eq = '='
op_neq = '!='
op_imp = '=>'
op_matching = '\u22b3' # (contains)
op_consistency = '~'
op_precision = '\u2291' # (square image of or equal to)
op_leq = '\u2264' # less-than or equal to
op_lt = '<'
op_gt = '>'
op_mod = '%'
op_add = "+"
op_sub = "-"
op_mul = "*"
op_div = "/"
op_eq = "="
op_neq = "!="
op_imp = "=>"
op_matching = "\u22b3" # (contains)
op_consistency = "~"
op_precision = "\u2291" # (square image of or equal to)
op_leq = "\u2264" # less-than or equal to
op_lt = "<"
op_gt = ">"
op_mod = "%"

View File

@ -1,16 +1,49 @@
# mypy: allow-untyped-defs
from torch.fx.experimental.migrate_gradual_types.constraint import Conj, Disj, T, F, BinConstraintT, BVar, is_bool_expr
from torch.fx.experimental.migrate_gradual_types.constraint import BinConstraintD, TVar, DVar
from torch.fx.experimental.migrate_gradual_types.constraint import Prod, is_algebraic_expression, is_dim
from torch.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator
from torch.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint
from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_eq, op_neq, op_gt, op_lt
from torch.fx.experimental.migrate_gradual_types.operation import op_leq, op_sub, op_div, op_mul, op_mod
from torch.fx.tensor_type import TensorType, Dyn
from torch.fx.experimental.migrate_gradual_types.constraint import (
BinConstraintD,
BinConstraintT,
BVar,
Conj,
Disj,
DVar,
F,
is_algebraic_expression,
is_bool_expr,
is_dim,
Prod,
T,
TVar,
)
from torch.fx.experimental.migrate_gradual_types.constraint_generator import (
ConstraintGenerator,
)
from torch.fx.experimental.migrate_gradual_types.constraint_transformation import (
transform_constraint,
)
from torch.fx.experimental.migrate_gradual_types.operation import (
op_add,
op_div,
op_eq,
op_gt,
op_leq,
op_lt,
op_mod,
op_mul,
op_neq,
op_sub,
)
from torch.fx.tensor_type import Dyn, TensorType
try:
import z3 # type: ignore[import]
from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, z3_dyn, D
from torch.fx.experimental.migrate_gradual_types.z3_types import (
D,
tensor_type,
z3_dyn,
)
HAS_Z3 = True
def transform_to_z3(constraint, counter, dimension_dict):
@ -41,35 +74,48 @@ try:
return (lhs == rhs), counter
else:
raise NotImplementedError('Method not yet implemented')
raise NotImplementedError("Method not yet implemented")
elif isinstance(constraint, BinConstraintD):
if constraint.op == op_eq:
if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs):
transformed_rhs, counter = transform_to_z3(constraint.rhs, counter, dimension_dict)
transformed_rhs, counter = transform_to_z3(
constraint.rhs, counter, dimension_dict
)
transformed_lhs = z3.Bool(constraint.lhs.c)
return transformed_lhs == transformed_rhs, counter
elif is_dim(constraint.lhs) and is_dim(constraint.rhs):
# with dimension transformations we consider the encoding
lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict)
lhs, counter = transform_dimension(
constraint.lhs, counter, dimension_dict
)
rhs, counter = transform_dimension(
constraint.rhs, counter, dimension_dict
)
return lhs == rhs, counter
else:
# then we have an algebraic expression which means that we disregard the
# first element of the encoding
lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
lhs, counter = transform_algebraic_expression(
constraint.lhs, counter, dimension_dict
)
rhs, counter = transform_algebraic_expression(
constraint.rhs, counter, dimension_dict
)
return lhs == rhs, counter
# The assumption here is that the LHS and RHS must be dimensions
elif constraint.op == op_neq:
assert is_dim(constraint.lhs)
assert is_dim(constraint.rhs)
lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict)
lhs, counter = transform_dimension(
constraint.lhs, counter, dimension_dict
)
rhs, counter = transform_dimension(
constraint.rhs, counter, dimension_dict
)
if constraint.rhs == Dyn or constraint.lhs == Dyn:
if constraint.rhs == Dyn:
return lhs.arg(0) == 1, counter
@ -79,44 +125,83 @@ try:
# if one of the instances is a number
elif isinstance(constraint.lhs, int) or isinstance(constraint.rhs, int):
if isinstance(constraint.lhs, int):
return z3.Or([rhs.arg(0) == 0, z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter
return (
z3.Or(
[
rhs.arg(0) == 0,
z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)]),
]
),
counter,
)
elif isinstance(constraint.rhs, int):
return z3.Or([lhs.arg(0) == 0, z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter
return (
z3.Or(
[
lhs.arg(0) == 0,
z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)]),
]
),
counter,
)
else:
return z3.Or([z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]),
z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]),
z3.And([lhs.arg(0) != 0, rhs.arg(0) != 0, lhs.arg(1) != rhs.arg(1)])]), counter
return (
z3.Or(
[
z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]),
z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]),
z3.And(
[
lhs.arg(0) != 0,
rhs.arg(0) != 0,
lhs.arg(1) != rhs.arg(1),
]
),
]
),
counter,
)
elif constraint.op == op_leq:
# if the dimensions are not dyn, this will come into effect
# there would have been another constraint specifying if a given dimension
# is dyn or not
assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
lhs, counter = transform_algebraic_expression(
constraint.lhs, counter, dimension_dict
)
rhs, counter = transform_algebraic_expression(
constraint.rhs, counter, dimension_dict
)
return lhs <= rhs, counter
elif constraint.op == op_gt:
assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
lhs, counter = transform_algebraic_expression(
constraint.lhs, counter, dimension_dict
)
rhs, counter = transform_algebraic_expression(
constraint.rhs, counter, dimension_dict
)
return lhs > rhs, counter
elif constraint.op == op_lt:
assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
lhs, counter = transform_algebraic_expression(
constraint.lhs, counter, dimension_dict
)
rhs, counter = transform_algebraic_expression(
constraint.rhs, counter, dimension_dict
)
return lhs < rhs, counter
else:
raise NotImplementedError('operation not yet implemented')
raise NotImplementedError("operation not yet implemented")
else:
raise NotImplementedError('Operation not yet implemented')
raise NotImplementedError("Operation not yet implemented")
def transform_var(tensor, counter, dimension_dict):
"""
@ -166,13 +251,15 @@ try:
return D(1, dimension), counter
elif isinstance(dimension, DVar):
if dimension.c in dimension_dict:
return D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)), counter
return (
D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)),
counter,
)
else:
counter += 1
dimension_dict[dimension.c] = counter
return D(z3.Int(counter), z3.Int(dimension.c)), counter
def transform_algebraic_expression(expr, counter, dimension_dict):
"""
Transforms an algebraic expression to z3 format
@ -190,7 +277,6 @@ try:
return transformed.arg(1), counter
elif isinstance(expr, Prod):
dims = []
for dim in expr.products:
assert is_dim(dim)
@ -199,9 +285,12 @@ try:
return z3.Product(dims), counter
elif is_algebraic_expression(expr):
lhs, counter = transform_algebraic_expression(expr.lhs, counter, dimension_dict)
rhs, counter = transform_algebraic_expression(expr.rhs, counter, dimension_dict)
lhs, counter = transform_algebraic_expression(
expr.lhs, counter, dimension_dict
)
rhs, counter = transform_algebraic_expression(
expr.rhs, counter, dimension_dict
)
if expr.op == op_sub:
c = lhs - rhs
@ -219,14 +308,13 @@ try:
c = lhs % rhs
else:
raise NotImplementedError('operation not yet implemented')
raise NotImplementedError("operation not yet implemented")
return c, counter
else:
raise RuntimeError
def transform_all_constraints(traced, counter=0):
"""
Given a trace, generates constraints and transforms them to z3 format
@ -291,7 +379,6 @@ try:
# transform precision, matching, consistency till obtaining a fixed point
new_constraints, counter = iterate_till_fixed_point(new_constraints, counter)
# since the function returns a list of one element, we get the first element
# we are only interested in the RHS in this case because the LHS just stores
# the result
@ -304,19 +391,27 @@ try:
condition_constraint_rhs = condition_constraint.rhs
# transform the condition constraint
condition_constraint_rhs, counter = iterate_till_fixed_point(condition_constraint_rhs, counter)
condition_constraint_rhs, counter = iterate_till_fixed_point(
condition_constraint_rhs, counter
)
transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict)
transformed_condition_constraint, counter = transform_to_z3(condition_constraint_rhs, counter, dimension_dict)
transformed_condition_constraint, counter = transform_to_z3(
condition_constraint_rhs, counter, dimension_dict
)
negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint)
negation_transformed_condition_constraint = z3.Not(
transformed_condition_constraint
)
return z3.And([transformed, transformed_condition_constraint]), \
z3.And([transformed, negation_transformed_condition_constraint])
return z3.And([transformed, transformed_condition_constraint]), z3.And(
[transformed, negation_transformed_condition_constraint]
)
def evaluate_conditional_with_constraints(tracer_root, graph, node, counter=0, user_constraints=None):
def evaluate_conditional_with_constraints(
tracer_root, graph, node, counter=0, user_constraints=None
):
"""
Given an IR and a node representing a conditional, evaluate the conditional
and its negation
@ -329,8 +424,10 @@ try:
"""
transformed_positive, transformed_negative = \
transform_all_constraints_trace_time(tracer_root, graph, node, counter)
(
transformed_positive,
transformed_negative,
) = transform_all_constraints_trace_time(tracer_root, graph, node, counter)
s = z3.Solver()
s.add(transformed_positive)
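For orientation, here is a hypothetical end-to-end use of `transform_all_constraints` defined above (a sketch, assuming z3 is installed and that the function returns a single z3 expression that can be handed to a solver):

```python
# Hypothetical usage sketch: trace a module, emit shape constraints, check satisfiability.
import torch
import z3
from torch.fx import symbolic_trace
from torch.fx.experimental.migrate_gradual_types.transform_to_z3 import (
    transform_all_constraints,
)


class Flattener(torch.nn.Module):
    def forward(self, x):
        return torch.flatten(x, 1)


traced = symbolic_trace(Flattener())
constraints = transform_all_constraints(traced, counter=0)

solver = z3.Solver()
solver.add(constraints)
print(solver.check())  # expected: sat
```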

View File

@ -1,6 +1,10 @@
# mypy: allow-untyped-defs
from torch.fx.experimental.migrate_gradual_types.constraint import TVar, DVar, BinConstraintD, \
BVar
from torch.fx.experimental.migrate_gradual_types.constraint import (
BinConstraintD,
BVar,
DVar,
TVar,
)
from torch.fx.experimental.migrate_gradual_types.operation import op_leq
@ -23,6 +27,7 @@ def gen_dvar(curr):
curr += 1
return DVar(curr), curr
def gen_bvar(curr):
"""
Generate a boolean variable
@ -32,6 +37,7 @@ def gen_bvar(curr):
curr += 1
return BVar(curr), curr
def gen_tensor_dims(n, curr):
"""
Generate a list of tensor dimensions
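All of these helpers thread an integer counter so that freshly generated variables never collide; each returns a `(new_var, updated_counter)` pair that the caller rebinds. A tiny sketch of the pattern (illustrative only):

```python
# Sketch: always rebind the counter returned by each generator helper.
from torch.fx.experimental.migrate_gradual_types.util import gen_dvar, gen_tvar

counter = 0
t, counter = gen_tvar(counter)  # TV(1), counter == 1
d, counter = gen_dvar(counter)  # DV(2), counter == 2
```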

View File

@ -1,22 +1,23 @@
try:
import z3 # type: ignore[import]
HAS_Z3 = True
# dynamic type
dyn = z3.DeclareSort('Dyn')
dyn_type = z3.Const('dyn', dyn)
dyn = z3.DeclareSort("Dyn")
dyn_type = z3.Const("dyn", dyn)
# dimension
dim = z3.Datatype('dim')
dim.declare('dim', ('0', z3.IntSort()), ('1', z3.IntSort()))
dim = z3.Datatype("dim")
dim.declare("dim", ("0", z3.IntSort()), ("1", z3.IntSort()))
dim = dim.create()
# tensors
tensor_type = z3.Datatype('TensorType')
tensor_type.declare('Dyn', ('dyn', dyn))
tensor_type.declare('tensor1', ('0', dim))
tensor_type.declare('tensor2', ('0', dim), ('1', dim))
tensor_type.declare('tensor3', ('0', dim), ('1', dim), ('2', dim))
tensor_type.declare('tensor4', ('0', dim), ('1', dim), ('2', dim), ('3', dim))
tensor_type = z3.Datatype("TensorType")
tensor_type.declare("Dyn", ("dyn", dyn))
tensor_type.declare("tensor1", ("0", dim))
tensor_type.declare("tensor2", ("0", dim), ("1", dim))
tensor_type.declare("tensor3", ("0", dim), ("1", dim), ("2", dim))
tensor_type.declare("tensor4", ("0", dim), ("1", dim), ("2", dim), ("3", dim))
tensor_type = tensor_type.create()
# create dimension
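The encoding above is what the solver-facing passes build terms from: a dimension is a pair whose first field flags whether the dimension is known (1) or Dyn (0) and whose second field carries the value. A short sketch (assuming z3 is installed and using the `D` constructor this module exports, imported earlier in this diff):

```python
# Sketch: encode a fully static tensor of shape [3, 4] and check satisfiability.
import z3
from torch.fx.experimental.migrate_gradual_types.z3_types import D, tensor_type

t = tensor_type.tensor2(D(1, 3), D(1, 4))  # flag 1 marks a concrete dimension

solver = z3.Solver()
solver.add(t == tensor_type.tensor2(D(1, 3), D(1, 4)))
print(solver.check())  # sat
```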

View File

@ -1,16 +1,16 @@
# mypy: allow-untyped-defs
import operator
from typing import Any, Callable, Dict, Tuple, Optional
from typing import Any, Callable, Dict, Optional, Tuple
import torch
import torch.fx
import torch.fx as fx
from torch.fx import Transformer, Proxy
from torch.fx.node import Argument, Target, Node, map_aggregate
from torch.fx import Proxy, Transformer
from torch.fx.node import Argument, map_aggregate, Node, Target
from torch.fx.operator_schemas import (
normalize_module,
normalize_function,
create_type_hint,
normalize_function,
normalize_module,
)
from .schema_type_annotation import AnnotateTypesWithSchema

View File

@ -1,37 +1,42 @@
# mypy: allow-untyped-defs
import torch.fx as fx
from torch.fx.node import Argument, Target
from torch.nn.utils.fusion import fuse_conv_bn_eval
from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx.passes.shape_prop import ShapeProp
import copy
from collections import defaultdict
import torch.utils.mkldnn as th_mkldnn
import logging
import operator
import time
import logging
from collections import defaultdict
from enum import Enum
from typing import Any, cast, Dict, Iterable, List, Optional, Tuple, Type
def _parent_name(target : str) -> Tuple[str, str]:
import torch
import torch.fx as fx
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.mkldnn as th_mkldnn
from torch.fx.node import Argument, Target
from torch.fx.passes.shape_prop import ShapeProp
from torch.nn.utils.fusion import fuse_conv_bn_eval
def _parent_name(target: str) -> Tuple[str, str]:
"""
Splits a qualname into parent path and last atom.
For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
"""
*parent, name = target.rsplit('.', 1)
return parent[0] if parent else '', name
*parent, name = target.rsplit(".", 1)
return parent[0] if parent else "", name
# Works for length 2 patterns with 2 modules
def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]):
def matches_module_pattern(
pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]
):
if len(node.args) == 0:
return False
nodes: Tuple[Any, fx.Node] = (node.args[0], node)
for expected_type, current_node in zip(pattern, nodes):
if not isinstance(current_node, fx.Node):
return False
if current_node.op != 'call_module':
if current_node.op != "call_module":
return False
if not isinstance(current_node.target, str):
return False
@ -42,20 +47,25 @@ def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict
return True
def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
def replace_node_module(
node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module
):
assert isinstance(node.target, str)
parent_name, name = _parent_name(node.target)
modules[node.target] = new_module
setattr(modules[parent_name], name, new_module)
def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module:
"""
Fuses convolution/BN layers for inference purposes. Will deepcopy your
model by default, but can modify the model inplace as well.
"""
patterns = [(nn.Conv1d, nn.BatchNorm1d),
(nn.Conv2d, nn.BatchNorm2d),
(nn.Conv3d, nn.BatchNorm3d)]
patterns = [
(nn.Conv1d, nn.BatchNorm1d),
(nn.Conv2d, nn.BatchNorm2d),
(nn.Conv3d, nn.BatchNorm3d),
]
if not inplace:
model = copy.deepcopy(model)
if not no_trace or not isinstance(model, torch.fx.GraphModule):
@ -80,6 +90,7 @@ def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Modu
new_graph.erase_node(node)
return fx.GraphModule(fx_model, new_graph)
def remove_dropout(model: nn.Module) -> nn.Module:
"""
Removes all dropout layers from the module.
@ -87,15 +98,24 @@ def remove_dropout(model: nn.Module) -> nn.Module:
fx_model = fx.symbolic_trace(model)
class DropoutRemover(torch.fx.Transformer):
def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
def call_module(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Any:
if isinstance(self.submodules[target], nn.Dropout):
assert len(args) == 1
return args[0]
else:
return super().call_module(target, args, kwargs)
return DropoutRemover(fx_model).transform()
def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[fx.Node], outputs: List[fx.Node]):
def extract_subgraph(
orig_module: nn.Module,
nodes: List[fx.Node],
inputs: List[fx.Node],
outputs: List[fx.Node],
):
"""
Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph.
"""
@ -111,10 +131,21 @@ def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[
new_graph.lint()
return fx.GraphModule(orig_module, new_graph)
mkldnn_supported = [
nn.Conv2d, nn.Linear, nn.BatchNorm2d, nn.ReLU, nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d,
torch.relu, torch.transpose, torch.sigmoid,
F.relu, F.avg_pool2d, F.adaptive_avg_pool2d
nn.Conv2d,
nn.Linear,
nn.BatchNorm2d,
nn.ReLU,
nn.MaxPool2d,
nn.AvgPool2d,
nn.AdaptiveAvgPool2d,
torch.relu,
torch.transpose,
torch.sigmoid,
F.relu,
F.avg_pool2d,
F.adaptive_avg_pool2d,
]
# These are operators that may not be convertible into MKLDNN ops (e.g. the
# args are scalar values). Thus, we only include them in the subgraph if their
@ -124,7 +155,7 @@ mkldnn_supported_unknown = [operator.add, operator.mul]
mkldnn_map = {
nn.Conv2d: th_mkldnn.MkldnnConv2d,
nn.Linear: th_mkldnn.MkldnnLinear,
nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a)
nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a),
}
@ -136,7 +167,7 @@ def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]):
"""
old_modules: Dict[nn.Module, nn.Module] = {}
for node in nodes:
if node.op == 'call_module':
if node.op == "call_module":
assert isinstance(node.target, str)
cur_module = modules[node.target]
if type(cur_module) in mkldnn_map:
@ -146,18 +177,24 @@ def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]):
replace_node_module(node, modules, new_module)
return old_modules
def reset_modules(nodes: List[fx.Node], modules: Dict[str, nn.Module], old_modules: Dict[nn.Module, nn.Module]):
def reset_modules(
nodes: List[fx.Node],
modules: Dict[str, nn.Module],
old_modules: Dict[nn.Module, nn.Module],
):
"""
Maps each module that's been changed with `modules_to_mkldnn` back to its
original.
"""
for node in nodes:
if node.op == 'call_module':
assert (isinstance(node.target, str))
if node.op == "call_module":
assert isinstance(node.target, str)
cur_module = modules[node.target]
if cur_module in old_modules:
replace_node_module(node, modules, old_modules[cur_module])
class MklSubgraph:
def __init__(self, fx_graph: fx.Graph):
self.fx_graph = fx_graph
@ -165,6 +202,7 @@ class MklSubgraph:
self.start_nodes: List[fx.Node] = []
self.end_nodes: List[fx.Node] = []
def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
"""
This generates a heuristic that can be passed into `optimize_for_inference` that
@ -196,13 +234,21 @@ def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
f()
return time.time() - begin
mkl_time = benchmark(lambda: [i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])])
mkl_time = benchmark(
lambda: [
i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])
]
)
reset_modules(submodule.graph.nodes, dict(submodule.named_modules()), old_modules)
reset_modules(
submodule.graph.nodes, dict(submodule.named_modules()), old_modules
)
no_mkl_time = benchmark(lambda: submodule(*sample_inputs))
return mkl_time < no_mkl_time
return use_mkl_heuristic
def use_mkl_length(graph: MklSubgraph) -> bool:
"""
This is a heuristic that can be passed into `optimize_for_inference` that
@ -211,6 +257,7 @@ def use_mkl_length(graph: MklSubgraph) -> bool:
"""
return len(graph.nodes) > 2
class UnionFind:
def __init__(self, n):
self.parent: List[Optional[int]] = [None] * n
@ -237,10 +284,11 @@ class UnionFind:
self.parent[b] = a
self.size[a] += self.size[b]
def optimize_for_inference(
model: torch.nn.Module,
pass_config: Optional[Dict[str, Any]] = None,
tracer: Type[fx.Tracer] = fx.Tracer
tracer: Type[fx.Tracer] = fx.Tracer,
) -> torch.nn.Module:
"""
Performs a set of optimization passes to optimize a model for the
@ -258,7 +306,7 @@ def optimize_for_inference(
default_pass_config = {
"conv_bn_fuse": True,
"remove_dropout": True,
"mkldnn_layout_optimize": {'heuristic': use_mkl_length},
"mkldnn_layout_optimize": {"heuristic": use_mkl_length},
}
if pass_config is None:
pass_config = {}
@ -292,15 +340,19 @@ def optimize_for_inference(
# a MKLDNN node if its inputs are MKLDNN nodes.
for node in list(fx_graph.nodes):
supports_mkldnn = MklSupport.NO
if node.op == 'call_module':
if node.op == "call_module":
cur_module = modules[node.target]
if type(cur_module) in mkldnn_supported:
supports_mkldnn = MklSupport.YES
sample_parameter = next(cur_module.parameters(), None)
if sample_parameter is not None:
assert sample_parameter.dtype == torch.float, "this pass is only for torch.float modules"
assert sample_parameter.device == torch.device('cpu'), "this pass is only for CPU modules"
elif node.op == 'call_function':
assert (
sample_parameter.dtype == torch.float
), "this pass is only for torch.float modules"
assert sample_parameter.device == torch.device(
"cpu"
), "this pass is only for CPU modules"
elif node.op == "call_function":
if node.target in mkldnn_supported:
supports_mkldnn = MklSupport.YES
elif node.target in mkldnn_supported_unknown:
@ -308,15 +360,17 @@ def optimize_for_inference(
if supports_mkldnn != MklSupport.NO:
if supports_mkldnn == MklSupport.UNKNOWN:
if not any(arg.target == 'to_dense' for arg in node.args):
if not any(arg.target == "to_dense" for arg in node.args):
continue
with fx_graph.inserting_before(node):
mkldnn_args = fx.map_arg(node.args, lambda n: fx_graph.call_method('to_mkldnn', (n, )))
mkldnn_args = fx.map_arg(
node.args, lambda n: fx_graph.call_method("to_mkldnn", (n,))
)
node.args = cast(Tuple[fx.node.Argument], mkldnn_args)
with fx_graph.inserting_after(node):
dense_x = fx_graph.create_node('call_method', 'to_dense', (node,))
dense_x = fx_graph.create_node("call_method", "to_dense", (node,))
node.replace_all_uses_with(dense_x)
dense_x.args = (node,)
@ -326,28 +380,26 @@ def optimize_for_inference(
# optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b
for node in fx_graph.nodes:
if node.op == 'call_method' and node.target == 'to_dense':
if node.op == "call_method" and node.target == "to_dense":
prv_node = node.args[0]
users = list(node.users)
for user in users:
if user.op == 'call_method' and user.target == 'to_mkldnn':
if user.op == "call_method" and user.target == "to_mkldnn":
user.replace_all_uses_with(prv_node)
fx_graph.erase_node(user)
if len(node.users) == 0:
fx_graph.erase_node(node)
num_nodes = len(fx_graph.nodes)
uf = UnionFind(num_nodes)
def get_color(n):
if hasattr(n, 'color'): # Current node is part of a MKL subgraph
if hasattr(n, "color"): # Current node is part of a MKL subgraph
return uf.find(n.color)
if hasattr(n, 'start_color'): # Current node is input to MKL subgraph
if hasattr(n, "start_color"): # Current node is input to MKL subgraph
return uf.find(n.start_color)
return None
# This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists
# of input nodes (which are only `to_mkldnn` calls), output nodes
# (`to_dense` calls), and intermediate nodes, which are run entirely on
@ -360,14 +412,19 @@ def optimize_for_inference(
# nodes (i.e. colors), we need to join these 2 colors into 1. That's done
# using a Disjoint Set Union.
for cur_idx, node in enumerate(fx_graph.nodes):
if node.op == 'call_method' and node.target == 'to_mkldnn':
if node.op == "call_method" and node.target == "to_mkldnn":
node.start_color = cur_idx
uf.make_set(cur_idx)
elif node.op == 'call_method' and node.target == 'to_dense':
elif node.op == "call_method" and node.target == "to_dense":
assert get_color(node.args[0]) is not None
node.end_color = get_color(node.args[0])
else:
cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None]
cur_colors = [
get_color(i)
for i in node.all_input_nodes
if isinstance(i, fx.Node)
if get_color(i) is not None
]
if len(cur_colors) == 0:
continue
@ -377,17 +434,15 @@ def optimize_for_inference(
for other_color in cur_colors[1:]:
uf.join(cur_colors[0], other_color)
mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph))
for node in fx_graph.nodes:
if hasattr(node, 'color'):
if hasattr(node, "color"):
mkldnn_graphs[uf.find(node.color)].nodes.append(node)
if hasattr(node, 'start_color'):
if hasattr(node, "start_color"):
mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node)
if hasattr(node, 'end_color'):
if hasattr(node, "end_color"):
mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node)
# Now that we have all the subgraphs, we need to decide which MKLDNN
# subgraphs we actually want to keep in MKLDNN.
for graph in mkldnn_graphs.values():
@ -400,7 +455,7 @@ def optimize_for_inference(
mkldnn_conversions = 0
for node in fx_graph.nodes:
if node.target == 'to_mkldnn' or node.target == 'to_dense':
if node.target == "to_mkldnn" or node.target == "to_dense":
mkldnn_conversions += 1
logging.getLogger(__name__).info("mkldnn conversions: %s", mkldnn_conversions)


@ -1,8 +1,8 @@
# mypy: allow-untyped-defs
from enum import Enum
from typing import NamedTuple, Dict, List, Set
from typing import Dict, List, NamedTuple, Set
from torch.fx.node import Node, map_arg
from torch.fx.node import map_arg, Node
class Partition:
@ -146,7 +146,7 @@ def get_latency_of_one_partition(
# this node is on the top bfs level in this partition
if not any(
n in partition.nodes and n.op not in {"placeholder", "get_attr"}
for n in input_nodes
for n in input_nodes
):
top_nodes.append(node)
return top_nodes
@ -279,7 +279,9 @@ def get_latency_of_partitioned_graph(
def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float:
"""This function helps to recursively get the latency of a path of partitions"""
# Update latency by adding current partition's latency
latency_so_far_sec += partition_to_latency_mapping[partition].overall_latency_sec
latency_so_far_sec += partition_to_latency_mapping[
partition
].overall_latency_sec
if partition.children:
max_latency_sec = 0.0


@ -5,10 +5,10 @@ class Equality:
self.rhs = rhs
def __str__(self):
return f'{self.lhs} = {self.rhs}'
return f"{self.lhs} = {self.rhs}"
def __repr__(self):
return f'{self.lhs} = {self.rhs}'
return f"{self.lhs} = {self.rhs}"
def __eq__(self, other):
if isinstance(other, Equality):


@ -1,16 +1,18 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import ast
import inspect
import textwrap
import copy
import functools
import inspect
import textwrap
from types import FunctionType
from typing import cast, Union, Callable, Dict, Optional, Any
from typing import Any, Callable, cast, Dict, Optional, Union
import torch
from torch._sources import normalize_source_lines
from torch.fx._symbolic_trace import Tracer
from torch.fx.graph import Graph
from torch._sources import normalize_source_lines
import torch
class AST_Rewriter(ast.NodeTransformer):
"""
@ -29,11 +31,10 @@ class AST_Rewriter(ast.NodeTransformer):
# suitable for dynamo tracing anyways.
@torch._dynamo.disable
def rewrite(self, fn: FunctionType):
# Normalize the source lines
sourcelines, _ = inspect.getsourcelines(fn)
sourcelines = normalize_source_lines(sourcelines)
source = ''.join(sourcelines)
source = "".join(sourcelines)
normalized_str = textwrap.dedent(source)
# Rewrite the original AST
@ -64,6 +65,7 @@ class AST_Rewriter(ast.NodeTransformer):
g = functools.update_wrapper(g, f)
g.__kwdefaults__ = copy.copy(f.__kwdefaults__) # type:ignore[attr-defined]
return g
# Return the correct FunctionType object
return change_func_globals(fn_compiled, globals=fn.__globals__)
@ -73,7 +75,7 @@ class AST_Rewriter(ast.NodeTransformer):
symbolically-traceable torch._assert function
"""
# Create the Call node
n = ast.parse('torch._assert()', mode='eval')
n = ast.parse("torch._assert()", mode="eval")
assert isinstance(n, ast.Expression)
call_node = n.body
assert isinstance(call_node, ast.Call)
@ -96,13 +98,22 @@ class AST_Rewriter(ast.NodeTransformer):
Output:
y = annotate(f2(x),Tensor_Type((1,2,3,Dyn)))
"""
return ast.Assign(targets=[node.target], value=ast.Call(
func=ast.Name(id='annotate', ctx=ast.Load()),
args=[node.value, node.annotation], keywords=[]))
return ast.Assign(
targets=[node.target],
value=ast.Call(
func=ast.Name(id="annotate", ctx=ast.Load()),
args=[node.value, node.annotation],
keywords=[],
),
)
class RewritingTracer(Tracer):
def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
def trace(
self,
root: Union[torch.nn.Module, Callable],
concrete_args: Optional[Dict[str, Any]] = None,
) -> Graph:
return super().trace(_rewrite(root), concrete_args)
@ -111,7 +122,7 @@ def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Cal
# Rewrite this module's `forward` as well as the `forward`s of
# all of this module's recursive descendents. Return the new,
# rewritten module hierarchy.
def rewrite_module(m : torch.nn.Module):
def rewrite_module(m: torch.nn.Module):
class RewrittenModule(torch.nn.Module):
def __init__(self, orig):
super().__init__()
@ -120,8 +131,12 @@ def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Cal
self.__dict__[k] = copy.copy(rewrite_module(v))
else:
self.__dict__[k] = copy.copy(v)
RewrittenModule.forward = AST_Rewriter().rewrite(cast(FunctionType, m.forward))
RewrittenModule.forward = AST_Rewriter().rewrite(
cast(FunctionType, m.forward)
)
return RewrittenModule(m)
return rewrite_module(fn)
else:
# Rewrite this single free function


@ -1,13 +1,14 @@
# mypy: allow-untyped-defs
import torch
import torch.fx
import inspect
from typing import Any, Dict, Optional, Tuple
from torch.fx.node import Argument, Target
import torch
import torch.fx
from torch._jit_internal import boolean_dispatched
from torch.fx import Transformer
from torch.fx.node import Argument, Target
from torch.fx.operator_schemas import _torchscript_type_to_python_type
from torch.fx import Transformer
class AnnotateTypesWithSchema(Transformer):
"""
@ -27,16 +28,24 @@ class AnnotateTypesWithSchema(Transformer):
traced = AnnotateTypesWithSchema(traced).transform()
"""
def __init__(self, module : torch.nn.Module, annotate_functionals : bool = True,
annotate_modules : bool = True, annotate_get_attrs : bool = True):
def __init__(
self,
module: torch.nn.Module,
annotate_functionals: bool = True,
annotate_modules: bool = True,
annotate_get_attrs: bool = True,
):
super().__init__(module)
self.annotate_functionals = annotate_functionals
self.annotate_modules = annotate_modules
self.annotate_get_attrs = annotate_get_attrs
def call_function(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
def call_function(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
):
python_ret_type = None
if self.annotate_functionals and target.__module__ == 'torch.nn.functional':
if self.annotate_functionals and target.__module__ == "torch.nn.functional":
target_for_analysis = target
if target in boolean_dispatched:
# HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
@ -45,51 +54,71 @@ class AnnotateTypesWithSchema(Transformer):
# branch signature for analysis. Otherwise, leave this un-normalized
assert not isinstance(target, str)
dispatched = boolean_dispatched[target]
if_true, if_false = dispatched['if_true'], dispatched['if_false']
if_true, if_false = dispatched["if_true"], dispatched["if_false"]
# TODO: can we emit the union of these? What are the implications on TorchScript
# compilation?
if inspect.signature(if_true).return_annotation != inspect.signature(if_false).return_annotation:
if (
inspect.signature(if_true).return_annotation
!= inspect.signature(if_false).return_annotation
):
return super().call_function(target, args, kwargs)
target_for_analysis = if_true
python_ret_type = self._extract_python_return_type(target_for_analysis)
return_proxy = super().call_function(target, args, kwargs)
return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type
return_proxy.node.type = (
return_proxy.node.type if return_proxy.node.type else python_ret_type
)
return return_proxy
def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
def call_module(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
):
python_ret_type = None
assert isinstance(target, str)
submod = self.fetch_attr(target)
if self.annotate_modules and hasattr(submod.__class__, '__name__'):
if self.annotate_modules and hasattr(submod.__class__, "__name__"):
classname = submod.__class__.__name__
if getattr(torch.nn, classname, None) == submod.__class__:
python_ret_type = self._extract_python_return_type(submod.forward)
return_proxy = super().call_module(target, args, kwargs)
return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type
return_proxy.node.type = (
return_proxy.node.type if return_proxy.node.type else python_ret_type
)
return return_proxy
def get_attr(self, target : torch.fx.node.Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
def get_attr(
self,
target: torch.fx.node.Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Any],
):
attr_proxy = super().get_attr(target, args, kwargs)
if self.annotate_get_attrs:
module_itr = self.module
assert isinstance(target, str)
atoms = target.split('.')
atoms = target.split(".")
for i, atom in enumerate(atoms):
if not hasattr(module_itr, atom):
raise RuntimeError(f'Node referenced nonextent target {".".join(atoms[:i])}!')
raise RuntimeError(
f'Node referenced nonextent target {".".join(atoms[:i])}!'
)
module_itr = getattr(module_itr, atom)
maybe_inferred_ts_type = torch._C._jit_try_infer_type(module_itr)
if maybe_inferred_ts_type.success():
python_type = _torchscript_type_to_python_type(maybe_inferred_ts_type.type())
attr_proxy.node.type = python_type if not attr_proxy.node.type else attr_proxy.node.type
python_type = _torchscript_type_to_python_type(
maybe_inferred_ts_type.type()
)
attr_proxy.node.type = (
python_type if not attr_proxy.node.type else attr_proxy.node.type
)
return attr_proxy
def _extract_python_return_type(self, target : Target) -> Optional[Any]:
def _extract_python_return_type(self, target: Target) -> Optional[Any]:
"""
Given a Python call target, try to extract the Python return annotation
if it is available, otherwise return None
@ -109,4 +138,8 @@ class AnnotateTypesWithSchema(Transformer):
except (ValueError, TypeError):
return None
return sig.return_annotation if sig.return_annotation is not inspect.Signature.empty else None
return (
sig.return_annotation
if sig.return_annotation is not inspect.Signature.empty
else None
)


@ -1,4 +1,4 @@
# mypy: disable-error-code=attr-defined
from .core import unify, reify # noqa: F403
from .core import reify, unify # noqa: F403
from .more import unifiable # noqa: F403
from .variable import var, isvar, vars, variables, Var # noqa: F403
from .variable import isvar, Var, var, variables, vars # noqa: F403


@ -2,10 +2,11 @@
from collections.abc import Iterator # type: ignore[import]
from functools import partial
from .dispatch import dispatch
from .unification_tools import assoc # type: ignore[import]
from .utils import transitive_get as walk
from .variable import isvar
from .dispatch import dispatch
__all__ = ["reify", "unify"]
@ -13,33 +14,47 @@ __all__ = ["reify", "unify"]
# Reification #
###############
@dispatch(Iterator, dict)
def _reify(t, s):
return map(partial(reify, s=s), t)
# return (reify(arg, s) for arg in t)
_reify
@dispatch(tuple, dict) # type: ignore[no-redef]
def _reify(t, s):
return tuple(reify(iter(t), s))
_reify
@dispatch(list, dict) # type: ignore[no-redef]
def _reify(t, s):
return list(reify(iter(t), s))
_reify
@dispatch(dict, dict) # type: ignore[no-redef]
def _reify(d, s):
return {k: reify(v, s) for k, v in d.items()}
_reify
@dispatch(object, dict) # type: ignore[no-redef]
def _reify(o, s):
return o # catch all, just return the object
def reify(e, s):
""" Replace variables of expression with substitution
"""Replace variables of expression with substitution
>>> # xdoctest: +SKIP
>>> x, y = var(), var()
>>> e = (1, x, (3, y))
@ -54,12 +69,14 @@ def reify(e, s):
return reify(s[e], s) if e in s else e
return _reify(e, s)
###############
# Unification #
###############
seq = tuple, list, Iterator
@dispatch(seq, seq, dict)
def _unify(u, v, s):
if len(u) != len(v):
@ -69,6 +86,8 @@ def _unify(u, v, s):
if s is False:
return False
return s
#
# @dispatch((set, frozenset), (set, frozenset), dict)
# def _unify(u, v, s):
@ -98,8 +117,8 @@ def _unify(u, v, s):
@dispatch(object, object, dict)
def unify(u, v, s): # no check at the moment
""" Find substitution so that u == v while satisfying s
>>> x = var('x')
"""Find substitution so that u == v while satisfying s
>>> x = var("x")
>>> unify((1, x), (1, 2), {})
{~x: 2}
"""
@ -112,8 +131,11 @@ def unify(u, v, s): # no check at the moment
if isvar(v):
return assoc(s, v, u)
return _unify(u, v, s)
unify
@dispatch(object, object) # type: ignore[no-redef]
def unify(u, v):
return unify(u, v, {})


@ -1,6 +1,8 @@
from functools import partial
from .multipledispatch import dispatch # type: ignore[import]
namespace = {} # type: ignore[var-annotated]
dispatch = partial(dispatch, namespace=namespace)


@ -1,8 +1,8 @@
# mypy: allow-untyped-defs
from .core import unify, reify # type: ignore[attr-defined]
from .variable import isvar
from .core import reify, unify # type: ignore[attr-defined]
from .unification_tools import first, groupby # type: ignore[import]
from .utils import _toposort, freeze
from .unification_tools import groupby, first # type: ignore[import]
from .variable import isvar
class Dispatcher:
@ -28,32 +28,38 @@ class Dispatcher:
if s is not False:
result = self.funcs[signature]
return result, s
raise NotImplementedError("No match found. \nKnown matches: "
+ str(self.ordering) + "\nInput: " + str(args))
raise NotImplementedError(
"No match found. \nKnown matches: "
+ str(self.ordering)
+ "\nInput: "
+ str(args)
)
def register(self, *signature):
def _(func):
self.add(signature, func)
return self
return _
class VarDispatcher(Dispatcher):
""" A dispatcher that calls functions with variable names
"""A dispatcher that calls functions with variable names
>>> # xdoctest: +SKIP
>>> d = VarDispatcher('d')
>>> x = var('x')
>>> @d.register('inc', x)
>>> d = VarDispatcher("d")
>>> x = var("x")
>>> @d.register("inc", x)
... def f(x):
... return x + 1
>>> @d.register('double', x)
>>> @d.register("double", x)
... def f(x):
... return x * 2
>>> d('inc', 10)
>>> d("inc", 10)
11
>>> d('double', 10)
>>> d("double", 10)
20
"""
def __call__(self, *args, **kwargs):
func, s = self.resolve(args)
d = {k.token: v for k, v in s.items()}
@ -64,8 +70,8 @@ global_namespace = {} # type: ignore[var-annotated]
def match(*signature, **kwargs):
namespace = kwargs.get('namespace', global_namespace)
dispatcher = kwargs.get('Dispatcher', Dispatcher)
namespace = kwargs.get("namespace", global_namespace)
dispatcher = kwargs.get("Dispatcher", Dispatcher)
def _(func):
name = func.__name__
@ -77,11 +83,12 @@ def match(*signature, **kwargs):
d.add(signature, func)
return d
return _
def supercedes(a, b):
""" ``a`` is a more specific match than ``b`` """
"""``a`` is a more specific match than ``b``"""
if isvar(b) and not isvar(a):
return True
s = unify(a, b)
@ -96,7 +103,7 @@ def supercedes(a, b):
# Taken from multipledispatch
def edge(a, b, tie_breaker=hash):
""" A should be checked before B
"""A should be checked before B
Tie broken by tie_breaker, defaults to ``hash``
"""
if supercedes(a, b):
@ -109,7 +116,7 @@ def edge(a, b, tie_breaker=hash):
# Taken from multipledispatch
def ordering(signatures):
""" A sane ordering of signatures to check, first to last
"""A sane ordering of signatures to check, first to last
Topological sort of edges as given by ``edge`` and ``supercedes``
"""
signatures = list(map(tuple, signatures))


@ -1,10 +1,10 @@
# mypy: allow-untyped-defs
from .core import unify, reify # type: ignore[attr-defined]
from .core import reify, unify # type: ignore[attr-defined]
from .dispatch import dispatch
def unifiable(cls):
""" Register standard unify and reify operations on class
"""Register standard unify and reify operations on class
This uses the type and __dict__ or __slots__ attributes to define the
nature of the term
See Also:
@ -15,7 +15,7 @@ def unifiable(cls):
... self.b = b
>>> unifiable(A)
<class 'unification.more.A'>
>>> x = var('x')
>>> x = var("x")
>>> a = A(1, 2)
>>> b = A(1, x)
>>> unify(a, b, {})
@ -33,22 +33,23 @@ def unifiable(cls):
def reify_object(o, s):
""" Reify a Python object with a substitution
"""Reify a Python object with a substitution
>>> # xdoctest: +SKIP
>>> class Foo(object):
... def __init__(self, a, b):
... self.a = a
... self.b = b
...
... def __str__(self):
... return "Foo(%s, %s)"%(str(self.a), str(self.b))
>>> x = var('x')
... return "Foo(%s, %s)" % (str(self.a), str(self.b))
>>> x = var("x")
>>> f = Foo(1, x)
>>> print(f)
Foo(1, ~x)
>>> print(reify_object(f, {x: 2}))
Foo(1, 2)
"""
if hasattr(o, '__slots__'):
if hasattr(o, "__slots__"):
return _reify_object_slots(o, s)
else:
return _reify_object_dict(o, s)
@ -77,7 +78,7 @@ def _reify_object_slots(o, s):
@dispatch(slice, dict)
def _reify(o, s):
""" Reify a Python ``slice`` object """
"""Reify a Python ``slice`` object"""
return slice(*reify((o.start, o.stop, o.step), s))
@ -87,16 +88,17 @@ def _reify(o, s):
def unify_object(u, v, s):
""" Unify two Python objects
"""Unify two Python objects
Unifies their type and ``__dict__`` attributes
>>> # xdoctest: +SKIP
>>> class Foo(object):
... def __init__(self, a, b):
... self.a = a
... self.b = b
...
... def __str__(self):
... return "Foo(%s, %s)"%(str(self.a), str(self.b))
>>> x = var('x')
... return "Foo(%s, %s)" % (str(self.a), str(self.b))
>>> x = var("x")
>>> f = Foo(1, x)
>>> g = Foo(1, 2)
>>> unify_object(f, g, {})
@ -104,15 +106,17 @@ def unify_object(u, v, s):
"""
if type(u) != type(v):
return False
if hasattr(u, '__slots__'):
return unify([getattr(u, slot) for slot in u.__slots__],
[getattr(v, slot) for slot in v.__slots__],
s)
if hasattr(u, "__slots__"):
return unify(
[getattr(u, slot) for slot in u.__slots__],
[getattr(v, slot) for slot in v.__slots__],
s,
)
else:
return unify(u.__dict__, v.__dict__, s)
@dispatch(slice, slice, dict)
def _unify(u, v, s):
""" Unify a Python ``slice`` object """
"""Unify a Python ``slice`` object"""
return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s)


@ -1,3 +1,7 @@
from .core import dispatch
from .dispatcher import (Dispatcher, halt_ordering, restart_ordering,
MDNotImplementedError)
from .dispatcher import (
Dispatcher,
halt_ordering,
MDNotImplementedError,
restart_ordering,
)


@ -1,17 +1,28 @@
# mypy: allow-untyped-defs
from .utils import _toposort, groupby
from .variadic import isvariadic
import operator
__all__ = ["AmbiguityWarning", "supercedes", "consistent", "ambiguous", "ambiguities", "super_signature",
"edge", "ordering"]
from .utils import _toposort, groupby
from .variadic import isvariadic
__all__ = [
"AmbiguityWarning",
"supercedes",
"consistent",
"ambiguous",
"ambiguities",
"super_signature",
"edge",
"ordering",
]
class AmbiguityWarning(Warning):
pass
def supercedes(a, b):
""" A is consistent and strictly more specific than B """
"""A is consistent and strictly more specific than B"""
if len(a) < len(b):
# only case is if a is empty and b is variadic
return not a and len(b) == 1 and isvariadic(b[-1])
@ -41,7 +52,7 @@ def supercedes(a, b):
def consistent(a, b):
""" It is possible for an argument list to satisfy both A and B """
"""It is possible for an argument list to satisfy both A and B"""
# Need to check for empty args
if not a:
@ -51,8 +62,7 @@ def consistent(a, b):
# Non-empty args check for mutual subclasses
if len(a) == len(b):
return all(issubclass(aa, bb) or issubclass(bb, aa)
for aa, bb in zip(a, b))
return all(issubclass(aa, bb) or issubclass(bb, aa) for aa, bb in zip(a, b))
else:
p1 = 0
p2 = 0
@ -70,45 +80,53 @@ def consistent(a, b):
p1 += 1
# We only need to check for variadic ends
# Variadic types are guaranteed to be the last element
return (isvariadic(cur_a) and p2 == len(b) or # type: ignore[possibly-undefined]
isvariadic(cur_b) and p1 == len(a)) # type: ignore[possibly-undefined]
return (
isvariadic(cur_a) # type: ignore[possibly-undefined]
and p2 == len(b)
or isvariadic(cur_b) # type: ignore[possibly-undefined]
and p1 == len(a)
)
def ambiguous(a, b):
""" A is consistent with B but neither is strictly more specific """
"""A is consistent with B but neither is strictly more specific"""
return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a))
def ambiguities(signatures):
""" All signature pairs such that A is ambiguous with B """
"""All signature pairs such that A is ambiguous with B"""
signatures = list(map(tuple, signatures))
return {(a, b) for a in signatures for b in signatures
if hash(a) < hash(b)
and ambiguous(a, b)
and not any(supercedes(c, a) and supercedes(c, b)
for c in signatures)}
return {
(a, b)
for a in signatures
for b in signatures
if hash(a) < hash(b)
and ambiguous(a, b)
and not any(supercedes(c, a) and supercedes(c, b) for c in signatures)
}
def super_signature(signatures):
""" A signature that would break ambiguities """
"""A signature that would break ambiguities"""
n = len(signatures[0])
assert all(len(s) == n for s in signatures)
return [max((type.mro(sig[i]) for sig in signatures), key=len)[0]
for i in range(n)]
return [max((type.mro(sig[i]) for sig in signatures), key=len)[0] for i in range(n)]
def edge(a, b, tie_breaker=hash):
""" A should be checked before B
"""A should be checked before B
Tie broken by tie_breaker, defaults to ``hash``
"""
# A either supercedes B and B does not supercede A or if B does then call
# tie_breaker
return supercedes(a, b) and (not supercedes(b, a) or tie_breaker(a) > tie_breaker(b))
return supercedes(a, b) and (
not supercedes(b, a) or tie_breaker(a) > tie_breaker(b)
)
def ordering(signatures):
""" A sane ordering of signatures to check, first to last
"""A sane ordering of signatures to check, first to last
Topological sort of edges as given by ``edge`` and ``supercedes``
"""
signatures = list(map(tuple, signatures))


@ -4,12 +4,14 @@ import sys
from .dispatcher import Dispatcher, MethodDispatcher
global_namespace = {} # type: ignore[var-annotated]
__all__ = ["dispatch", "ismethod"]
def dispatch(*types, **kwargs):
""" Dispatch function on the types of the inputs
"""Dispatch function on the types of the inputs
Supports dispatch on all non-keyword arguments.
Collects implementations based on the function name. Ignores namespaces.
If ambiguous type signatures occur a warning is raised when the function is
@ -38,6 +40,7 @@ def dispatch(*types, **kwargs):
... @dispatch(list)
... def __init__(self, data):
... self.data = data
...
... @dispatch(int)
... def __init__(self, datum):
... self.data = [datum]
@ -46,7 +49,7 @@ def dispatch(*types, **kwargs):
>>> MyClass(3).data
[3]
"""
namespace = kwargs.get('namespace', global_namespace)
namespace = kwargs.get("namespace", global_namespace)
types = tuple(types)
@ -65,20 +68,21 @@ def dispatch(*types, **kwargs):
dispatcher.add(types, func)
return dispatcher
return _df
def ismethod(func):
""" Is func a method?
"""Is func a method?
Note that this has to work as the method is defined but before the class is
defined. At this stage methods look like functions.
"""
if hasattr(inspect, "signature"):
signature = inspect.signature(func)
return signature.parameters.get('self', None) is not None
return signature.parameters.get("self", None) is not None
else:
if sys.version_info.major < 3:
spec = inspect.getargspec(func) # type: ignore[attr-defined]
else:
spec = inspect.getfullargspec(func) # type: ignore[union-attr, assignment]
return spec and spec.args and spec.args[0] == 'self'
return spec and spec.args and spec.args[0] == "self"


@ -1,21 +1,35 @@
# mypy: allow-untyped-defs
from warnings import warn
import inspect
from typing_extensions import deprecated
from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
from .utils import expand_tuples
from .variadic import Variadic, isvariadic
import itertools as itl
from typing_extensions import deprecated
from warnings import warn
from .conflict import ambiguities, AmbiguityWarning, ordering, super_signature
from .utils import expand_tuples
from .variadic import isvariadic, Variadic
__all__ = [
"MDNotImplementedError",
"ambiguity_warn",
"halt_ordering",
"restart_ordering",
"variadic_signature_matches_iter",
"variadic_signature_matches",
"Dispatcher",
"source",
"MethodDispatcher",
"str_signature",
"warning_text",
]
__all__ = ["MDNotImplementedError", "ambiguity_warn", "halt_ordering", "restart_ordering", "variadic_signature_matches_iter",
"variadic_signature_matches", "Dispatcher", "source", "MethodDispatcher", "str_signature", "warning_text"]
class MDNotImplementedError(NotImplementedError):
""" A NotImplementedError for multiple dispatch """
"""A NotImplementedError for multiple dispatch"""
def ambiguity_warn(dispatcher, ambiguities):
""" Raise warning when ambiguity is detected
"""Raise warning when ambiguity is detected
Parameters
----------
dispatcher : Dispatcher
@ -92,7 +106,7 @@ def variadic_signature_matches(types, full_signature):
class Dispatcher:
""" Dispatch methods based on type signature
"""Dispatch methods based on type signature
Use ``dispatch`` to add implementations
Examples
--------
@ -109,7 +123,8 @@ class Dispatcher:
>>> f(3.0)
2.0
"""
__slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc'
__slots__ = "__name__", "name", "funcs", "_ordering", "_cache", "doc"
def __init__(self, name, doc=None):
self.name = self.__name__ = name
@ -119,9 +134,9 @@ class Dispatcher:
self._cache = {}
def register(self, *types, **kwargs):
""" register dispatcher with new implementation
"""register dispatcher with new implementation
>>> # xdoctest: +SKIP
>>> f = Dispatcher('f')
>>> f = Dispatcher("f")
>>> @f.register(int)
... def inc(x):
... return x + 1
@ -139,9 +154,11 @@ class Dispatcher:
>>> f([1, 2, 3])
[3, 2, 1]
"""
def _df(func):
self.add(types, func, **kwargs) # type: ignore[call-arg]
self.add(types, func, **kwargs) # type: ignore[call-arg]
return func
return _df
@classmethod
@ -152,28 +169,27 @@ class Dispatcher:
@classmethod
def get_func_annotations(cls, func):
""" get annotations of function positional parameters
"""
"""get annotations of function positional parameters"""
params = cls.get_func_params(func)
if params:
Parameter = inspect.Parameter
params = (param for param in params
if param.kind in
(Parameter.POSITIONAL_ONLY,
Parameter.POSITIONAL_OR_KEYWORD))
params = (
param
for param in params
if param.kind
in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
)
annotations = tuple(
param.annotation
for param in params)
annotations = tuple(param.annotation for param in params)
if all(ann is not Parameter.empty for ann in annotations):
return annotations
def add(self, signature, func):
""" Add new types/method pair to dispatcher
"""Add new types/method pair to dispatcher
>>> # xdoctest: +SKIP
>>> D = Dispatcher('add')
>>> D = Dispatcher("add")
>>> D.add((int, int), lambda x, y: x + y)
>>> D.add((float, float), lambda x, y: x + y)
>>> D(1, 2)
@ -202,24 +218,25 @@ class Dispatcher:
for index, typ in enumerate(signature, start=1):
if not isinstance(typ, (type, list)):
str_sig = ', '.join(c.__name__ if isinstance(c, type)
else str(c) for c in signature)
raise TypeError(f"Tried to dispatch on non-type: {typ}\n"
f"In signature: <{str_sig}>\n"
f"In function: {self.name}")
str_sig = ", ".join(
c.__name__ if isinstance(c, type) else str(c) for c in signature
)
raise TypeError(
f"Tried to dispatch on non-type: {typ}\n"
f"In signature: <{str_sig}>\n"
f"In function: {self.name}"
)
# handle variadic signatures
if isinstance(typ, list):
if index != len(signature):
raise TypeError(
'Variadic signature must be the last element'
)
raise TypeError("Variadic signature must be the last element")
if len(typ) != 1:
raise TypeError(
'Variadic signature must contain exactly one element. '
'To use a variadic union type place the desired types '
'inside of a tuple, e.g., [(int, str)]'
"Variadic signature must contain exactly one element. "
"To use a variadic union type place the desired types "
"inside of a tuple, e.g., [(int, str)]"
)
new_signature.append(Variadic[typ[0]])
else:
@ -255,7 +272,8 @@ class Dispatcher:
func = self.dispatch(*types)
if not func:
raise NotImplementedError(
f'Could not find signature for {self.name}: <{str_signature(types)}>') from e
f"Could not find signature for {self.name}: <{str_signature(types)}>"
) from e
self._cache[types] = func
try:
return func(*args, **kwargs)
@ -271,10 +289,12 @@ class Dispatcher:
raise NotImplementedError(
"Matching functions for "
f"{self.name}: <{str_signature(types)}> found, but none completed successfully",) from e
f"{self.name}: <{str_signature(types)}> found, but none completed successfully",
) from e
def __str__(self):
return f"<dispatched {self.name}>"
__repr__ = __str__
def dispatch(self, *types):
@ -304,7 +324,6 @@ class Dispatcher:
return None
def dispatch_iter(self, *types):
n = len(types)
for signature in self.ordering:
if len(signature) == n and all(map(issubclass, types, signature)):
@ -315,21 +334,22 @@ class Dispatcher:
result = self.funcs[signature]
yield result
@deprecated("`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning)
@deprecated(
"`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning
)
def resolve(self, types):
""" Determine appropriate implementation for this type signature
"""Determine appropriate implementation for this type signature
.. deprecated:: 0.4.4
Use ``dispatch(*types)`` instead
"""
return self.dispatch(*types)
def __getstate__(self):
return {'name': self.name,
'funcs': self.funcs}
return {"name": self.name, "funcs": self.funcs}
def __setstate__(self, d):
self.name = d['name']
self.funcs = d['funcs']
self.name = d["name"]
self.funcs = d["funcs"]
self._ordering = ordering(self.funcs)
self._cache = {}
@ -344,23 +364,23 @@ class Dispatcher:
for sig in self.ordering[::-1]:
func = self.funcs[sig]
if func.__doc__:
s = f'Inputs: <{str_signature(sig)}>\n'
s += '-' * len(s) + '\n'
s = f"Inputs: <{str_signature(sig)}>\n"
s += "-" * len(s) + "\n"
s += func.__doc__.strip()
docs.append(s)
else:
other.append(str_signature(sig))
if other:
docs.append('Other signatures:\n ' + '\n '.join(other))
docs.append("Other signatures:\n " + "\n ".join(other))
return '\n\n'.join(docs)
return "\n\n".join(docs)
def _help(self, *args):
return self.dispatch(*map(type, args)).__doc__
def help(self, *args, **kwargs):
""" Print docstring for the function corresponding to inputs """
"""Print docstring for the function corresponding to inputs"""
print(self._help(*args))
def _source(self, *args):
@ -370,22 +390,23 @@ class Dispatcher:
return source(func)
def source(self, *args, **kwargs):
""" Print source code for the function corresponding to inputs """
"""Print source code for the function corresponding to inputs"""
print(self._source(*args))
def source(func):
s = f'File: {inspect.getsourcefile(func)}\n\n'
s = f"File: {inspect.getsourcefile(func)}\n\n"
s = s + inspect.getsource(func)
return s
class MethodDispatcher(Dispatcher):
""" Dispatch methods based on type signature
"""Dispatch methods based on type signature
See Also:
Dispatcher
"""
__slots__ = ('obj', 'cls')
__slots__ = ("obj", "cls")
@classmethod
def get_func_params(cls, func):
@ -402,26 +423,31 @@ class MethodDispatcher(Dispatcher):
types = tuple([type(arg) for arg in args])
func = self.dispatch(*types)
if not func:
raise NotImplementedError(f'Could not find signature for {self.name}: <{str_signature(types)}>')
raise NotImplementedError(
f"Could not find signature for {self.name}: <{str_signature(types)}>"
)
return func(self.obj, *args, **kwargs)
def str_signature(sig):
""" String representation of type signature
"""String representation of type signature
>>> str_signature((int, float))
'int, float'
"""
return ', '.join(cls.__name__ for cls in sig)
return ", ".join(cls.__name__ for cls in sig)
def warning_text(name, amb):
""" The text for ambiguity warnings """
"""The text for ambiguity warnings"""
text = f"\nAmbiguities exist in dispatched function {name}\n\n"
text += "The following signatures may result in ambiguous behavior:\n"
for pair in amb:
text += "\t" + \
', '.join('[' + str_signature(s) + ']' for s in pair) + "\n"
text += "\t" + ", ".join("[" + str_signature(s) + "]" for s in pair) + "\n"
text += "\n\nConsider making the following additions:\n\n"
text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s))
+ f')\ndef {name}(...)' for s in amb])
text += "\n\n".join(
[
"@dispatch(" + str_signature(super_signature(s)) + f")\ndef {name}(...)"
for s in amb
]
)
return text

View File

@ -1,8 +1,10 @@
# mypy: allow-untyped-defs
from collections import OrderedDict
__all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"]
def raises(err, lamda):
try:
lamda()
@ -31,12 +33,12 @@ def expand_tuples(L):
# Taken from theano/theano/gof/sched.py
# Avoids licensing issues because this was written by Matthew Rocklin
def _toposort(edges):
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
"""Topological sort algorithm by Kahn [1] - O(nodes + vertices)
inputs:
edges - a dict of the form {a: {b, c}} where b and c depend on a
outputs:
L - an ordered list of nodes that satisfy the dependencies of edges
>>> _toposort({1: (2, 3), 2: (3, )})
>>> _toposort({1: (2, 3), 2: (3,)})
[1, 2, 3]
>>> # Closely follows the wikipedia page [2]
>>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
@ -44,8 +46,7 @@ def _toposort(edges):
>>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
"""
incoming_edges = reverse_dict(edges)
incoming_edges = OrderedDict((k, set(val))
for k, val in incoming_edges.items())
incoming_edges = OrderedDict((k, set(val)) for k, val in incoming_edges.items())
S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges)
L = []
@ -64,7 +65,7 @@ def _toposort(edges):
def reverse_dict(d):
"""Reverses direction of dependence dict
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
>>> d = {"a": (1, 2), "b": (2, 3), "c": ()}
>>> reverse_dict(d) # doctest: +SKIP
{1: ('a',), 2: ('a', 'b'), 3: ('b',)}
:note: dict order are not deterministic. As we iterate on the
@ -82,8 +83,8 @@ def reverse_dict(d):
# Taken from toolz
# Avoids licensing issues because this version was authored by Matthew Rocklin
def groupby(func, seq):
""" Group a collection by a key function
>>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
"""Group a collection by a key function
>>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"]
>>> groupby(len, names) # doctest: +SKIP
{3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
>>> iseven = lambda x: x % 2 == 0


@ -1,15 +1,17 @@
# mypy: allow-untyped-defs
from .utils import typename
__all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"]
class VariadicSignatureType(type):
# checking if subclass is a subclass of self
def __subclasscheck__(cls, subclass):
other_type = (subclass.variadic_type if isvariadic(subclass)
else (subclass,))
other_type = subclass.variadic_type if isvariadic(subclass) else (subclass,)
return subclass is cls or all(
issubclass(other, cls.variadic_type) for other in other_type # type: ignore[attr-defined]
issubclass(other, cls.variadic_type) # type: ignore[attr-defined]
for other in other_type
)
def __eq__(cls, other):
@ -24,8 +26,7 @@ class VariadicSignatureType(type):
bool
Whether or not `other` is equal to `self`
"""
return (isvariadic(other) and
set(cls.variadic_type) == set(other.variadic_type)) # type: ignore[attr-defined]
return isvariadic(other) and set(cls.variadic_type) == set(other.variadic_type) # type: ignore[attr-defined]
def __hash__(cls):
return hash((type(cls), frozenset(cls.variadic_type))) # type: ignore[attr-defined]
@ -57,17 +58,20 @@ class VariadicSignatureMeta(type):
generate a new type for Variadic signatures. See the Variadic class for
examples of how this behaves.
"""
def __getitem__(cls, variadic_type):
if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)):
raise ValueError("Variadic types must be type or tuple of types"
" (Variadic[int] or Variadic[(int, float)]")
raise ValueError(
"Variadic types must be type or tuple of types"
" (Variadic[int] or Variadic[(int, float)]"
)
if not isinstance(variadic_type, tuple):
variadic_type = variadic_type,
variadic_type = (variadic_type,)
return VariadicSignatureType(
f'Variadic[{typename(variadic_type)}]',
f"Variadic[{typename(variadic_type)}]",
(),
dict(variadic_type=variadic_type, __slots__=())
dict(variadic_type=variadic_type, __slots__=()),
)


@ -1,25 +1,40 @@
# mypy: allow-untyped-defs
import collections
import operator
from functools import reduce
from collections.abc import Mapping
from functools import reduce
__all__ = ['merge', 'merge_with', 'valmap', 'keymap', 'itemmap',
'valfilter', 'keyfilter', 'itemfilter',
'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in']
__all__ = [
"merge",
"merge_with",
"valmap",
"keymap",
"itemmap",
"valfilter",
"keyfilter",
"itemfilter",
"assoc",
"dissoc",
"assoc_in",
"update_in",
"get_in",
]
def _get_factory(f, kwargs):
factory = kwargs.pop('factory', dict)
factory = kwargs.pop("factory", dict)
if kwargs:
raise TypeError(f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'")
raise TypeError(
f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'"
)
return factory
def merge(*dicts, **kwargs):
""" Merge a collection of dictionaries
"""Merge a collection of dictionaries
>>> merge({1: 'one'}, {2: 'two'})
>>> merge({1: "one"}, {2: "two"})
{1: 'one', 2: 'two'}
Later dictionaries have precedence
@ -41,7 +56,7 @@ def merge(*dicts, **kwargs):
def merge_with(func, *dicts, **kwargs):
""" Merge dictionaries and apply function to combined values
"""Merge dictionaries and apply function to combined values
A key may occur in more than one dict, and all values mapped from the key
will be passed to the function as a list, such as func([val1, val2, ...]).
@ -70,7 +85,7 @@ def merge_with(func, *dicts, **kwargs):
def valmap(func, d, factory=dict):
""" Apply function to values of dictionary
"""Apply function to values of dictionary
>>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
>>> valmap(sum, bills) # doctest: +SKIP
@ -86,7 +101,7 @@ def valmap(func, d, factory=dict):
def keymap(func, d, factory=dict):
""" Apply function to keys of dictionary
"""Apply function to keys of dictionary
>>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
>>> keymap(str.lower, bills) # doctest: +SKIP
@ -102,7 +117,7 @@ def keymap(func, d, factory=dict):
def itemmap(func, d, factory=dict):
""" Apply function to items of dictionary
"""Apply function to items of dictionary
>>> accountids = {"Alice": 10, "Bob": 20}
>>> itemmap(reversed, accountids) # doctest: +SKIP
@ -118,7 +133,7 @@ def itemmap(func, d, factory=dict):
def valfilter(predicate, d, factory=dict):
""" Filter items in dictionary by value
"""Filter items in dictionary by value
>>> iseven = lambda x: x % 2 == 0
>>> d = {1: 2, 2: 3, 3: 4, 4: 5}
@ -138,7 +153,7 @@ def valfilter(predicate, d, factory=dict):
def keyfilter(predicate, d, factory=dict):
""" Filter items in dictionary by key
"""Filter items in dictionary by key
>>> iseven = lambda x: x % 2 == 0
>>> d = {1: 2, 2: 3, 3: 4, 4: 5}
@ -158,7 +173,7 @@ def keyfilter(predicate, d, factory=dict):
def itemfilter(predicate, d, factory=dict):
""" Filter items in dictionary by item
"""Filter items in dictionary by item
>>> def isvalid(item):
... k, v = item
@ -182,13 +197,13 @@ def itemfilter(predicate, d, factory=dict):
def assoc(d, key, value, factory=dict):
""" Return a new dict with new key value pair
"""Return a new dict with new key value pair
New dict has d[key] set to value. Does not modify the initial dictionary.
>>> assoc({'x': 1}, 'x', 2)
>>> assoc({"x": 1}, "x", 2)
{'x': 2}
>>> assoc({'x': 1}, 'y', 3) # doctest: +SKIP
>>> assoc({"x": 1}, "y", 3) # doctest: +SKIP
{'x': 1, 'y': 3}
"""
d2 = factory()
@ -198,22 +213,22 @@ def assoc(d, key, value, factory=dict):
def dissoc(d, *keys, **kwargs):
""" Return a new dict with the given key(s) removed.
"""Return a new dict with the given key(s) removed.
New dict has d[key] deleted for each supplied key.
Does not modify the initial dictionary.
>>> dissoc({'x': 1, 'y': 2}, 'y')
>>> dissoc({"x": 1, "y": 2}, "y")
{'x': 1}
>>> dissoc({'x': 1, 'y': 2}, 'y', 'x')
>>> dissoc({"x": 1, "y": 2}, "y", "x")
{}
>>> dissoc({'x': 1}, 'y') # Ignores missing keys
>>> dissoc({"x": 1}, "y") # Ignores missing keys
{'x': 1}
"""
factory = _get_factory(dissoc, kwargs)
d2 = factory()
if len(keys) < len(d) * .6:
if len(keys) < len(d) * 0.6:
d2.update(d)
for key in keys:
if key in d2:
@ -227,13 +242,14 @@ def dissoc(d, *keys, **kwargs):
def assoc_in(d, keys, value, factory=dict):
""" Return a new dict with new, potentially nested, key value pair
"""Return a new dict with new, potentially nested, key value pair
>>> purchase = {'name': 'Alice',
... 'order': {'items': ['Apple', 'Orange'],
... 'costs': [0.50, 1.25]},
... 'credit card': '5555-1234-1234-1234'}
>>> assoc_in(purchase, ['order', 'costs'], [0.25, 1.00]) # doctest: +SKIP
>>> purchase = {
... "name": "Alice",
... "order": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]},
... "credit card": "5555-1234-1234-1234",
... }
>>> assoc_in(purchase, ["order", "costs"], [0.25, 1.00]) # doctest: +SKIP
{'credit card': '5555-1234-1234-1234',
'name': 'Alice',
'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}}
@ -242,7 +258,7 @@ def assoc_in(d, keys, value, factory=dict):
def update_in(d, keys, func, default=None, factory=dict):
""" Update value in a (potentially) nested dictionary
"""Update value in a (potentially) nested dictionary
inputs:
d - dictionary on which to operate
@ -257,14 +273,15 @@ def update_in(d, keys, func, default=None, factory=dict):
specified by the keys, with the innermost value set to func(default).
>>> inc = lambda x: x + 1
>>> update_in({'a': 0}, ['a'], inc)
>>> update_in({"a": 0}, ["a"], inc)
{'a': 1}
>>> transaction = {'name': 'Alice',
... 'purchase': {'items': ['Apple', 'Orange'],
... 'costs': [0.50, 1.25]},
... 'credit card': '5555-1234-1234-1234'}
>>> update_in(transaction, ['purchase', 'costs'], sum) # doctest: +SKIP
>>> transaction = {
... "name": "Alice",
... "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]},
... "credit card": "5555-1234-1234-1234",
... }
>>> update_in(transaction, ["purchase", "costs"], sum) # doctest: +SKIP
{'credit card': '5555-1234-1234-1234',
'name': 'Alice',
'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}}
@ -272,7 +289,7 @@ def update_in(d, keys, func, default=None, factory=dict):
>>> # updating a value when k0 is not in d
>>> update_in({}, [1, 2, 3], str, default="bar")
{1: {2: {3: 'bar'}}}
>>> update_in({1: 'foo'}, [2, 3, 4], inc, 0)
>>> update_in({1: "foo"}, [2, 3, 4], inc, 0)
{1: 'foo', 2: {3: {4: 1}}}
"""
ks = iter(keys)
@ -300,7 +317,7 @@ def update_in(d, keys, func, default=None, factory=dict):
def get_in(keys, coll, default=None, no_default=False):
""" Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys.
"""Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys.
If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless
``no_default`` is specified, then it raises KeyError or IndexError.
@ -308,20 +325,21 @@ def get_in(keys, coll, default=None, no_default=False):
``get_in`` is a generalization of ``operator.getitem`` for nested data
structures such as dictionaries and lists.
>>> transaction = {'name': 'Alice',
... 'purchase': {'items': ['Apple', 'Orange'],
... 'costs': [0.50, 1.25]},
... 'credit card': '5555-1234-1234-1234'}
>>> get_in(['purchase', 'items', 0], transaction)
>>> transaction = {
... "name": "Alice",
... "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]},
... "credit card": "5555-1234-1234-1234",
... }
>>> get_in(["purchase", "items", 0], transaction)
'Apple'
>>> get_in(['name'], transaction)
>>> get_in(["name"], transaction)
'Alice'
>>> get_in(['purchase', 'total'], transaction)
>>> get_in(['purchase', 'items', 'apple'], transaction)
>>> get_in(['purchase', 'items', 10], transaction)
>>> get_in(['purchase', 'total'], transaction, 0)
>>> get_in(["purchase", "total"], transaction)
>>> get_in(["purchase", "items", "apple"], transaction)
>>> get_in(["purchase", "items", 10], transaction)
>>> get_in(["purchase", "total"], transaction, 0)
0
>>> get_in(['y'], {}, no_default=True)
>>> get_in(["y"], {}, no_default=True)
Traceback (most recent call last):
...
KeyError: 'y'
@ -352,9 +370,9 @@ def getter(index):
def groupby(key, seq):
""" Group a collection by a key function
"""Group a collection by a key function
>>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
>>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"]
>>> groupby(len, names) # doctest: +SKIP
{3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
@ -364,9 +382,14 @@ def groupby(key, seq):
Non-callable keys imply grouping on a member.
>>> groupby('gender', [{'name': 'Alice', 'gender': 'F'},
... {'name': 'Bob', 'gender': 'M'},
... {'name': 'Charlie', 'gender': 'M'}]) # doctest:+SKIP
>>> groupby(
... "gender",
... [
... {"name": "Alice", "gender": "F"},
... {"name": "Bob", "gender": "M"},
... {"name": "Charlie", "gender": "M"},
... ],
... ) # doctest:+SKIP
{'F': [{'gender': 'F', 'name': 'Alice'}],
'M': [{'gender': 'M', 'name': 'Bob'},
{'gender': 'M', 'name': 'Charlie'}]}
@ -388,9 +411,9 @@ def groupby(key, seq):
def first(seq):
""" The first element in a sequence
"""The first element in a sequence
>>> first('ABC')
>>> first("ABC")
'A'
"""
return next(iter(seq))


@ -1,5 +1,7 @@
# mypy: allow-untyped-defs
__all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"]
def hashable(x):
try:
hash(x)
@ -9,7 +11,7 @@ def hashable(x):
def transitive_get(key, d):
""" Transitive dict.get
"""Transitive dict.get
>>> d = {1: 2, 2: 3, 3: 4}
>>> d.get(1)
2
@ -32,13 +34,13 @@ def raises(err, lamda):
# Taken from theano/theano/gof/sched.py
# Avoids licensing issues because this was written by Matthew Rocklin
def _toposort(edges):
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
"""Topological sort algorithm by Kahn [1] - O(nodes + vertices)
inputs:
edges - a dict of the form {a: {b, c}} where b and c depend on a
outputs:
L - an ordered list of nodes that satisfy the dependencies of edges
>>> # xdoctest: +SKIP
>>> _toposort({1: (2, 3), 2: (3, )})
>>> _toposort({1: (2, 3), 2: (3,)})
[1, 2, 3]
Closely follows the wikipedia page [2]
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
@ -47,7 +49,7 @@ def _toposort(edges):
"""
incoming_edges = reverse_dict(edges)
incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
S = ({v for v in edges if v not in incoming_edges})
S = {v for v in edges if v not in incoming_edges}
L = []
while S:
@ -65,7 +67,7 @@ def _toposort(edges):
def reverse_dict(d):
"""Reverses direction of dependence dict
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
>>> d = {"a": (1, 2), "b": (2, 3), "c": ()}
>>> reverse_dict(d) # doctest: +SKIP
{1: ('a',), 2: ('a', 'b'), 3: ('b',)}
:note: dict order are not deterministic. As we iterate on the
@ -89,12 +91,12 @@ def xfail(func):
def freeze(d):
""" Freeze container to hashable form
"""Freeze container to hashable form
>>> freeze(1)
1
>>> freeze([1, 2])
(1, 2)
>>> freeze({1: 2}) # doctest: +SKIP
>>> freeze({1: 2}) # doctest: +SKIP
frozenset([(1, 2)])
"""
if isinstance(d, dict):


@ -1,14 +1,16 @@
# mypy: allow-untyped-defs
from contextlib import contextmanager
from .utils import hashable
from .dispatch import dispatch
from .utils import hashable
_global_logic_variables = set() # type: ignore[var-annotated]
_glv = _global_logic_variables
class Var:
""" Logic Variable """
"""Logic Variable"""
_id = 1
@ -25,6 +27,7 @@ class Var:
def __str__(self):
return "~" + str(self.token) # type: ignore[attr-defined]
__repr__ = __str__
def __eq__(self, other):
@ -46,6 +49,7 @@ def vars():
def isvar(v):
return True
isvar
@ -69,12 +73,12 @@ def variables(*variables):
False
>>> # Normal approach
>>> from unification import unify
>>> x = var('x')
>>> x = var("x")
>>> unify(x, 1)
{~x: 1}
>>> # Context Manager approach
>>> with variables('x'):
... print(unify('x', 1))
>>> with variables("x"):
... print(unify("x", 1))
{'x': 1}
"""
old_global_logic_variables = _global_logic_variables.copy()


@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
from torch.fx.experimental.graph_gradual_typechecker import Refine
from torch.fx.experimental.unification import unify, Var # type: ignore[attr-defined]
from torch.fx.tensor_type import TensorType
from torch.fx.experimental.unification import Var, unify # type: ignore[attr-defined]
def infer_symbolic_types_single_pass(traced):
@ -13,6 +13,7 @@ def infer_symbolic_types_single_pass(traced):
mgu = unify_eq(r.constraints)
substitute_all_types(traced.graph, mgu)
def infer_symbolic_types(traced):
"""
Calls our symbolic inferencer twice.
@ -32,6 +33,7 @@ def infer_symbolic_types(traced):
r.symbolic_relations()
def convert_eq(list_of_eq):
"""
Convert equality constraints in the right format
@ -109,6 +111,7 @@ def substitute_all_types(graph, mapping):
for n in graph.nodes:
n.type = substitute_solution_one_type(mapping, n.type)
def check_for_type_equality(g1, g2):
"""
A check equality to be used in fixed points.

File diff suppressed because it is too large


@ -19,6 +19,7 @@ from torch.package import Importer, PackageExporter, PackageImporter, sys_import
from ._compatibility import compatibility
from .graph import _custom_builtins, _is_from_torch, _PyTreeCodeGen, Graph, PythonCode
__all__ = [
"reduce_graph_module",
"reduce_package_graph_module",
@ -386,11 +387,9 @@ class _WrappedCall:
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
except Exception as e:
assert e.__traceback__
topmost_framesummary: (
traceback.FrameSummary
) = traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[
-1
] # type: ignore[arg-type]
topmost_framesummary: traceback.FrameSummary = (
traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]
)
if "eval_with_key" in topmost_framesummary.filename:
print(
_WrappedCall._generate_error_message(topmost_framesummary),
@ -612,20 +611,20 @@ class {module_name}(torch.nn.Module):
module_str = (
f"torch.load(r'{module_file}', weights_only=False) # {module_repr}"
)
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
model_str += f"{tab * 2}self.{module_name} = {module_str}\n"
for buffer_name, buffer in self._buffers.items():
if buffer is None:
continue
model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"
model_str += f"{tab * 2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" # noqa: B950
for param_name, param in self._parameters.items():
if param is None:
continue
model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"
model_str += f"{tab * 2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" # noqa: B950
model_str += (
f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
f"{tab * 2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
)
model_str += f"{_addindent(self.code, 4)}\n"
@ -667,7 +666,6 @@ class {module_name}(torch.nn.Module):
mod: torch.nn.Module = self
for item in prefix:
submod = getattr(mod, item, None)
if submod is None:
@ -707,7 +705,6 @@ class {module_name}(torch.nn.Module):
# Get the parent module
for item in path:
if not hasattr(mod, item):
return False
@ -743,9 +740,7 @@ class {module_name}(torch.nn.Module):
used: List[str] = []
for node in self.graph.nodes:
if node.op == "call_module" or node.op == "get_attr":
# A list of strings representing the different parts
# of the path. For example, `foo.bar.baz` gives us
# ["foo", "bar", "baz"]

View File

@ -1,20 +1,24 @@
# mypy: allow-untyped-defs
from .graph_module import GraphModule
from ._lazy_graph_module import _make_graph_module
from .graph import Graph
from .node import Argument, Node, Target, map_arg, map_aggregate
from .proxy import Proxy
from ._symbolic_trace import Tracer
from ._compatibility import compatibility
from . import config
import torch.fx.traceback as fx_traceback
import torch
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import inspect
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import torch
import torch.fx.traceback as fx_traceback
from torch.hub import tqdm
__all__ = ['Interpreter', 'Transformer']
from . import config
from ._compatibility import compatibility
from ._lazy_graph_module import _make_graph_module
from ._symbolic_trace import Tracer
from .graph import Graph
from .graph_module import GraphModule
from .node import Argument, map_aggregate, map_arg, Node, Target
from .proxy import Proxy
__all__ = ["Interpreter", "Transformer"]
@compatibility(is_backward_compatible=True)
class Interpreter:
@ -43,22 +47,22 @@ class Interpreter:
method equivalents). We could subclass Interpreter like so::
class NegSigmSwapInterpreter(Interpreter):
def call_function(self, target : Target,
args : Tuple, kwargs : Dict) -> Any:
def call_function(self, target: Target, args: Tuple, kwargs: Dict) -> Any:
if target == torch.sigmoid:
return torch.neg(*args, **kwargs)
return super().call_function(n)
def call_method(self, target : Target,
args : Tuple, kwargs : Dict) -> Any:
if target == 'neg':
def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any:
if target == "neg":
call_self, *args_tail = args
return call_self.sigmoid(*args_tail, **kwargs)
return super().call_method(n)
def fn(x):
return torch.sigmoid(x).neg()
gm = torch.fx.symbolic_trace(fn)
input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
@ -74,15 +78,21 @@ class Interpreter:
graph instead of `module.graph`, using the provided `module`
argument to satisfy any requests for state.
"""
@compatibility(is_backward_compatible=True)
def __init__(self, module: torch.nn.Module, garbage_collect_values: bool = True, graph: Optional[Graph] = None):
def __init__(
self,
module: torch.nn.Module,
garbage_collect_values: bool = True,
graph: Optional[Graph] = None,
):
self.module = module
self.submodules = dict(self.module.named_modules())
if graph is not None:
self.graph = graph
else:
self.graph = self.module.graph
self.env : Dict[Node, Any] = {}
self.env: Dict[Node, Any] = {}
self.name = "Interpreter"
self.garbage_collect_values = garbage_collect_values
self.extra_traceback = True
@ -92,10 +102,10 @@ class Interpreter:
# of a given node. This represents the *last* use of the node in the
# execution order of the program, which we will use to free unused
# values
node_to_last_use : Dict[Node, Node] = {}
self.user_to_last_uses : Dict[Node, List[Node]] = {}
node_to_last_use: Dict[Node, Node] = {}
self.user_to_last_uses: Dict[Node, List[Node]] = {}
def register_last_uses(n : Node, user : Node):
def register_last_uses(n: Node, user: Node):
if n not in node_to_last_use:
node_to_last_use[n] = user
self.user_to_last_uses.setdefault(user, []).append(n)
@ -105,7 +115,12 @@ class Interpreter:
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
@compatibility(is_backward_compatible=True)
def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any:
def run(
self,
*args,
initial_env: Optional[Dict[Node, Any]] = None,
enable_io_processing: bool = True,
) -> Any:
"""
Run `module` via interpretation and return the result.
@ -128,10 +143,16 @@ class Interpreter:
# position and extract those values.
if enable_io_processing:
args = self.graph.process_inputs(*args)
self.args_iter : Iterator[Any] = iter(args)
pbar = tqdm(total=len(self.graph.nodes),
desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}",
initial=0, position=0, leave=True, disable=config.disable_progress, delay=0)
self.args_iter: Iterator[Any] = iter(args)
pbar = tqdm(
total=len(self.graph.nodes),
desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}",
initial=0,
position=0,
leave=True,
disable=config.disable_progress,
delay=0,
)
for node in self.graph.nodes:
pbar.update(1)
@ -147,7 +168,7 @@ class Interpreter:
except Exception as e:
if self.extra_traceback:
msg = f"While executing {node.format_node()}"
msg = f'{e.args[0]}\n\n{msg}' if e.args else str(msg)
msg = f"{e.args[0]}\n\n{msg}" if e.args else str(msg)
msg += f"\nOriginal traceback:\n{node.stack_trace}"
e.args = (msg,) + e.args[1:]
if isinstance(e, KeyError):
@ -158,9 +179,13 @@ class Interpreter:
for to_delete in self.user_to_last_uses.get(node, []):
del self.env[to_delete]
if node.op == 'output':
if node.op == "output":
output_val = self.env[node]
return self.graph.process_outputs(output_val) if enable_io_processing else output_val
return (
self.graph.process_outputs(output_val)
if enable_io_processing
else output_val
)
@compatibility(is_backward_compatible=True)
def boxed_run(self, args_list):
@ -183,7 +208,7 @@ class Interpreter:
yield
@compatibility(is_backward_compatible=True)
def run_node(self, n : Node) -> Any:
def run_node(self, n: Node) -> Any:
"""
Run a specific node ``n`` and return the result.
Calls into placeholder, get_attr, call_function,
@ -204,7 +229,9 @@ class Interpreter:
# Main Node running APIs
@compatibility(is_backward_compatible=True)
def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
def placeholder(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Any:
"""
Execute a ``placeholder`` node. Note that this is stateful:
``Interpreter`` maintains an internal iterator over
@ -222,7 +249,7 @@ class Interpreter:
Any: The argument value that was retrieved.
"""
assert isinstance(target, str)
if target.startswith('*'):
if target.startswith("*"):
# For a starred parameter e.g. `*args`, retrieve all
# remaining values from the args list.
return list(self.args_iter)
@ -233,10 +260,14 @@ class Interpreter:
if len(args) > 0:
return args[0]
else:
raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si
raise RuntimeError(
f"Expected positional argument for parameter {target}, but one was not passed in!"
) from si
@compatibility(is_backward_compatible=True)
def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
def get_attr(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Any:
"""
Execute a ``get_attr`` node. Will retrieve an attribute
value from the ``Module`` hierarchy of ``self.module``.
@ -255,7 +286,9 @@ class Interpreter:
return self.fetch_attr(target)
@compatibility(is_backward_compatible=True)
def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
def call_function(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Any:
"""
Execute a ``call_function`` node and return the result.
@ -275,7 +308,9 @@ class Interpreter:
return target(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
def call_method(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Any:
"""
Execute a ``call_method`` node and return the result.
@ -297,7 +332,9 @@ class Interpreter:
return getattr(self_obj, target)(*args_tail, **kwargs)
@compatibility(is_backward_compatible=True)
def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
def call_module(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Any:
"""
Execute a ``call_module`` node and return the result.
@ -320,7 +357,9 @@ class Interpreter:
return submod(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
def output(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Any:
"""
Execute an ``output`` node. This really just retrieves
the value referenced by the ``output`` node and returns it.
@ -339,7 +378,7 @@ class Interpreter:
# Helper methods
@compatibility(is_backward_compatible=True)
def fetch_attr(self, target : str):
def fetch_attr(self, target: str):
"""
Fetch an attribute from the ``Module`` hierarchy of ``self.module``.
@ -349,16 +388,18 @@ class Interpreter:
Return:
Any: The value of the attribute.
"""
target_atoms = target.split('.')
target_atoms = target.split(".")
attr_itr = self.module
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i+1])}")
raise RuntimeError(
f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}"
)
attr_itr = getattr(attr_itr, atom)
return attr_itr
@compatibility(is_backward_compatible=True)
def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]:
def fetch_args_kwargs_from_env(self, n: Node) -> Tuple[Tuple, Dict]:
"""
Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
from the current execution environment.
@ -376,7 +417,7 @@ class Interpreter:
return args, kwargs
@compatibility(is_backward_compatible=True)
def map_nodes_to_values(self, args : Argument, n : Node) -> Argument:
def map_nodes_to_values(self, args: Argument, n: Node) -> Argument:
"""
Recursively descend through ``args`` and look up the concrete value
for each ``Node`` in the current execution environment.
@ -386,13 +427,18 @@ class Interpreter:
n (Node): Node to which ``args`` belongs. This is only used for error reporting.
"""
def load_arg(n_arg : Node) -> Any:
def load_arg(n_arg: Node) -> Any:
if n_arg not in self.env:
raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() '
f'to diagnose such issues')
raise RuntimeError(
f"Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() "
f"to diagnose such issues"
)
return self.env[n_arg]
return map_arg(args, load_arg)
@compatibility(is_backward_compatible=True)
class Transformer(Interpreter):
"""
@ -409,23 +455,29 @@ class Transformer(Interpreter):
method equivalents). We could subclass ``Transformer`` like so::
class NegSigmSwapXformer(Transformer):
def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
def call_function(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Any:
if target == torch.sigmoid:
return torch.neg(*args, **kwargs)
return super().call_function(n)
def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
if target == 'neg':
def call_method(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Any:
if target == "neg":
call_self, *args_tail = args
return call_self.sigmoid(*args_tail, **kwargs)
return super().call_method(n)
def fn(x):
return torch.sigmoid(x).neg()
gm = torch.fx.symbolic_trace(fn)
transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform()
input = torch.randn(3, 4)
torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
@ -452,7 +504,9 @@ class Transformer(Interpreter):
self.tracer.root = module
@compatibility(is_backward_compatible=True)
def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
def placeholder(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Proxy:
"""
Execute a ``placeholder`` node. In ``Transformer``, this is
overridden to insert a new ``placeholder`` into the output
@ -467,10 +521,14 @@ class Transformer(Interpreter):
"""
assert isinstance(target, str)
default_value = next(iter(args)) if args else inspect.Signature.empty
return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer)
return Proxy(
self.new_graph.placeholder(target, default_value=default_value), self.tracer
)
@compatibility(is_backward_compatible=True)
def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
def get_attr(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Proxy:
"""
Execute a ``get_attr`` node. In ``Transformer``, this is
overridden to insert a new ``get_attr`` node into the output
@ -487,16 +545,20 @@ class Transformer(Interpreter):
return self.tracer.create_proxy("get_attr", target, args, kwargs)
@compatibility(is_backward_compatible=True)
def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
def call_module(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Any:
# Override so that the leaf module policy from `self.tracer` is respected.
assert isinstance(target, str)
submod = self.fetch_attr(target)
return self.tracer.call_module(submod, submod.forward, args, kwargs)
@compatibility(is_backward_compatible=True)
def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
def call_function(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Any:
# Override so that functions that were wrapped are still wrapped.
return self.tracer.create_proxy('call_function', target, args, kwargs)
return self.tracer.create_proxy("call_function", target, args, kwargs)
@compatibility(is_backward_compatible=True)
def transform(self) -> GraphModule:
@ -507,8 +569,10 @@ class Transformer(Interpreter):
with fx_traceback.preserve_node_meta():
result = super().run(enable_io_processing=False)
if result is not None:
def strip_proxy(a : Union[Argument, Proxy]) -> Any:
def strip_proxy(a: Union[Argument, Proxy]) -> Any:
return a.node if isinstance(a, Proxy) else a
new_output_node = self.new_graph.output(map_aggregate(result, strip_proxy))
# also preserve the metadata from the old output node, if it exists
old_output_node = list(self.graph.nodes)[-1]
@ -516,5 +580,4 @@ class Transformer(Interpreter):
for k, v in old_output_node.meta.items():
new_output_node.meta[k] = v
return _make_graph_module(self.module, self.new_graph)
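Beyond the `NegSigmSwapInterpreter` docstring reflowed above, a quick sketch of the intended extension point: subclass `Interpreter` and override `run_node` (the logging rule here is made up for illustration):

```python
import torch
from torch.fx import Interpreter, Node, symbolic_trace


class LoggingInterpreter(Interpreter):
    def run_node(self, n: Node):
        # run_node computes the concrete value for one node; log it and pass it on.
        result = super().run_node(n)
        print(f"{n.op:>14} {n.name:>12} -> {type(result).__name__}")
        return result


def fn(x):
    return torch.sigmoid(x).neg()


gm = symbolic_trace(fn)
out = LoggingInterpreter(gm).run(torch.randn(3, 4))
```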

View File

@ -1,39 +1,71 @@
# Nodes represent a definition of a value in our graph of operators.
from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set
import builtins
import inspect
import types
import warnings
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
import torch
from torch._C import _NodeBase
from torch.fx.operator_schemas import (
ArgsKwargsPair,
normalize_function,
normalize_module,
)
from .._ops import ops as _ops
from ._compatibility import compatibility
from .immutable_collections import immutable_dict, immutable_list
import torch
import builtins
import types
import inspect
import warnings
from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair
from .._ops import ops as _ops
from torch._C import _NodeBase
if TYPE_CHECKING:
from .graph import Graph
__all__ = ['Node', 'map_arg', 'map_aggregate', "has_side_effect"]
__all__ = ["Node", "map_arg", "map_aggregate", "has_side_effect"]
BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype,
torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload,
torch.SymInt, torch.SymBool, torch.SymFloat]
BaseArgumentTypes = Union[
str,
int,
float,
bool,
complex,
torch.dtype,
torch.Tensor,
torch.device,
torch.memory_format,
torch.layout,
torch._ops.OpOverload,
torch.SymInt,
torch.SymBool,
torch.SymFloat,
]
base_types = BaseArgumentTypes.__args__ # type: ignore[attr-defined]
Target = Union[Callable[..., Any], str]
Argument = Optional[Union[
Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types
List[Any], # actually Argument
Dict[str, Any], # actually Argument
slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
range,
'Node',
BaseArgumentTypes
]]
Argument = Optional[
Union[
Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types
List[Any], # actually Argument
Dict[str, Any], # actually Argument
slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
range,
"Node",
BaseArgumentTypes,
]
]
_legal_ops = dict.fromkeys(['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root'])
_legal_ops = dict.fromkeys(
[
"placeholder",
"call_method",
"call_module",
"call_function",
"get_attr",
"output",
"root",
]
)
_side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = {
torch._C._set_grad_enabled,
@ -74,7 +106,8 @@ def _find_module_of_method(orig_method: Callable[..., Any]) -> str:
for guess in [torch, torch.nn.functional]:
if getattr(guess, name, None) is orig_method:
return guess.__name__
raise RuntimeError(f'cannot find module for {orig_method}')
raise RuntimeError(f"cannot find module for {orig_method}")
# Borrowed from CPython typing module
# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156
@ -86,22 +119,24 @@ def _type_repr(obj: object) -> str:
else, we fall back on repr(obj).
"""
if isinstance(obj, type):
if obj.__module__ == 'builtins':
if obj.__module__ == "builtins":
return obj.__qualname__
return f'{obj.__module__}.{obj.__qualname__}'
return f"{obj.__module__}.{obj.__qualname__}"
if obj is ...:
return '...'
return "..."
if isinstance(obj, types.FunctionType):
return obj.__name__
return repr(obj)
def _get_qualified_name(func: Callable[..., Any]) -> str:
# things like getattr just appear in builtins
if getattr(builtins, func.__name__, None) is func:
return func.__name__
# torch.Tensor.{fn}
if (isinstance(func, (types.MethodDescriptorType, types.WrapperDescriptorType))
and func is getattr(torch.Tensor, func.__name__, None)):
if isinstance(
func, (types.MethodDescriptorType, types.WrapperDescriptorType)
) and func is getattr(torch.Tensor, func.__name__, None):
return f"torch.Tensor.{func.__name__}"
name = func.__name__
if name == "<lambda>":
@ -111,33 +146,45 @@ def _get_qualified_name(func: Callable[..., Any]) -> str:
except Exception as e:
raise RuntimeError("Unable to represent lambda") from e
module = _find_module_of_method(func)
module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module
module = module.replace(
"torch._ops", "torch.ops"
) # WAR for bug in how torch.ops assigns module
# Fixup segment_reduce mismatch
if module == "torch" and name == "segment_reduce":
name = "_" + name
return f'{module}.{name}'
return f"{module}.{name}"
def _format_arg(arg: object, max_list_len: float = float('inf')) -> str:
if hasattr(arg, '_custom_fx_repr_fn'):
def _format_arg(arg: object, max_list_len: float = float("inf")) -> str:
if hasattr(arg, "_custom_fx_repr_fn"):
return arg._custom_fx_repr_fn()
elif isinstance(arg, list):
items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len)
maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]'
return f'[{items}{maybe_len}]'
items = ", ".join(
_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len
)
maybe_len = (
"" if len(arg) < max_list_len + 1 else f", ...[total_len={len(arg)}]"
)
return f"[{items}{maybe_len}]"
elif isinstance(arg, tuple):
items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len)
maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]'
maybe_comma = ',' if len(arg) == 1 else ''
return f'({items}{maybe_comma}{maybe_len})'
items = ", ".join(
_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len
)
maybe_len = (
"" if len(arg) < max_list_len + 1 else f", ...[total_len={len(arg)}]"
)
maybe_comma = "," if len(arg) == 1 else ""
return f"({items}{maybe_comma}{maybe_len})"
elif isinstance(arg, dict):
items_str = ', '.join(f'{k}: {_format_arg(v)}' for k, v in arg.items())
return f'{{{items_str}}}'
items_str = ", ".join(f"{k}: {_format_arg(v)}" for k, v in arg.items())
return f"{{{items_str}}}"
if isinstance(arg, Node):
return '%' + str(arg)
return "%" + str(arg)
else:
return str(arg)
@compatibility(is_backward_compatible=True)
class Node(_NodeBase):
"""
@ -166,23 +213,31 @@ class Node(_NodeBase):
- ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement
in the Graph printout.
"""
_args: Tuple['Argument', ...]
_kwargs: Dict[str, 'Argument']
graph: 'Graph'
_args: Tuple["Argument", ...]
_kwargs: Dict[str, "Argument"]
graph: "Graph"
name: str
op: str
target: 'Target'
_input_nodes: Dict['Node', None]
users: Dict['Node', None]
target: "Target"
_input_nodes: Dict["Node", None]
users: Dict["Node", None]
type: Optional[Any]
_sort_key: Any
_repr_fn: Optional[Callable[['Node'], str]]
_repr_fn: Optional[Callable[["Node"], str]]
meta: Dict[str, Any]
@compatibility(is_backward_compatible=True)
def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target',
args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'],
return_type : Optional[Any] = None) -> None:
def __init__(
self,
graph: "Graph",
name: str,
op: str,
target: "Target",
args: Tuple["Argument", ...],
kwargs: Dict[str, "Argument"],
return_type: Optional[Any] = None,
) -> None:
"""
Instantiate an instance of ``Node``. Note: most often, you want to use the
Graph APIs, i.e. ``Graph.call_module``, ``Graph.call_method``, etc. rather
@ -210,14 +265,18 @@ class Node(_NodeBase):
of analyses.
"""
assert op in _legal_ops
if op == 'call_function':
if op == "call_function":
if not callable(target):
raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} '
'but a Callable is expected')
raise ValueError(
f"Node [graph = {graph}, name = '{name}'] target {target} has type {torch.typename(target)} "
"but a Callable is expected"
)
else:
if not isinstance(target, str):
raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} '
'but a str is expected')
raise ValueError(
f"Node [graph = {graph}, name = '{name}'] target {target} has type {torch.typename(target)} "
"but a str is expected"
)
super().__init__()
# bypass Node.__setattr__ for perf and so that it doesn't need to handle half-built objects
@ -225,9 +284,13 @@ class Node(_NodeBase):
assign(self, "graph", graph)
assign(self, "name", name) # unique name of value being created
assign(self, "op", op) # the kind of operation = placeholder|call_method|call_module|call_function|get_attr
assign(
self, "op", op
) # the kind of operation = placeholder|call_method|call_module|call_function|get_attr
assign(self, "target", target) # for method/module/function, the name of the method/module/function/attr
assign(
self, "target", target
) # for method/module/function, the name of the method/module/function/attr
# being invoked, e.g add, layer1, or torch.add
# All `Node`-valued inputs. Key is the Node, value is don't-care.
@ -280,7 +343,7 @@ class Node(_NodeBase):
self._next = _next
@property
def next(self) -> 'Node':
def next(self) -> "Node":
"""
Returns the next ``Node`` in the linked list of Nodes.
@ -291,7 +354,7 @@ class Node(_NodeBase):
return self._next
@property
def prev(self) -> 'Node':
def prev(self) -> "Node":
"""
Returns the previous ``Node`` in the linked list of Nodes.
@ -302,7 +365,7 @@ class Node(_NodeBase):
return self._prev
@compatibility(is_backward_compatible=True)
def prepend(self, x: 'Node') -> None:
def prepend(self, x: "Node") -> None:
"""
Insert x before this node in the list of nodes in the graph. Example::
@ -316,7 +379,9 @@ class Node(_NodeBase):
"""
assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
if self == x:
warnings.warn("Trying to prepend a node to itself. This behavior has no effect on the graph.")
warnings.warn(
"Trying to prepend a node to itself. This behavior has no effect on the graph."
)
return
x._remove_from_list()
p = self._prev
@ -328,28 +393,28 @@ class Node(_NodeBase):
nsk = x._next._sort_key
if len(psk) > len(nsk):
idx: int
*prefix, idx = psk[:len(nsk) + 1]
*prefix, idx = psk[: len(nsk) + 1]
x._sort_key = (*prefix, idx + 1)
elif len(psk) < len(nsk):
*prefix, idx = nsk[:len(psk) + 1]
*prefix, idx = nsk[: len(psk) + 1]
x._sort_key = (*prefix, idx - 1)
else: # same length, increase length by 1
x._sort_key = (*psk, 0)
def __gt__(self, other: 'Node') -> bool:
def __gt__(self, other: "Node") -> bool:
return self._sort_key > other._sort_key
def __lt__(self, other: 'Node') -> bool:
def __lt__(self, other: "Node") -> bool:
return self._sort_key < other._sort_key
def __ge__(self, other: 'Node') -> bool:
def __ge__(self, other: "Node") -> bool:
return self > other or self == other
def __le__(self, other: 'Node') -> bool:
def __le__(self, other: "Node") -> bool:
return self < other or self == other
@compatibility(is_backward_compatible=True)
def append(self, x: 'Node') -> None:
def append(self, x: "Node") -> None:
"""
Insert ``x`` after this node in the list of nodes in the graph.
Equivalent to ``self.next.prepend(x)``
@ -376,7 +441,7 @@ class Node(_NodeBase):
return self._args
@args.setter
def args(self, a : Tuple[Argument, ...]) -> None:
def args(self, a: Tuple[Argument, ...]) -> None:
"""
Set the tuple of arguments to this Node. The interpretation of arguments
depends on the node's opcode. See the ``fx.Graph`` docstring for more
@ -399,7 +464,7 @@ class Node(_NodeBase):
return self._kwargs
@kwargs.setter
def kwargs(self, k : Dict[str, Argument]) -> None:
def kwargs(self, k: Dict[str, Argument]) -> None:
"""
Set the dict of kwargs to this Node. The interpretation of arguments
depends on the node's opcode. See the ``fx.Graph`` docstring for more
@ -410,7 +475,7 @@ class Node(_NodeBase):
self.__update_args_kwargs(self._args, k)
@property
def all_input_nodes(self) -> List['Node']:
def all_input_nodes(self) -> List["Node"]:
"""
Return all Nodes that are inputs to this Node. This is equivalent to
iterating over ``args`` and ``kwargs`` and only collecting the values that
@ -424,7 +489,7 @@ class Node(_NodeBase):
return list(self._input_nodes.keys())
@compatibility(is_backward_compatible=True)
def update_arg(self, idx : int, arg : Argument) -> None:
def update_arg(self, idx: int, arg: Argument) -> None:
"""
Update an existing positional argument to contain the new value
``arg``. After calling, ``self.args[idx] == arg``.
@ -439,7 +504,7 @@ class Node(_NodeBase):
self.args = tuple(args)
@compatibility(is_backward_compatible=True)
def insert_arg(self, idx : int, arg : Argument) -> None:
def insert_arg(self, idx: int, arg: Argument) -> None:
"""
Insert an positional argument to the argument list with given index.
@ -448,7 +513,9 @@ class Node(_NodeBase):
idx (int): The index of the element in ``self.args`` to be inserted before.
arg (Argument): The new argument value to insert into ``args``
"""
assert 0 <= idx <= len(self.args), "insert_args index must be between 0 and len(self.args)"
assert (
0 <= idx <= len(self.args)
), "insert_args index must be between 0 and len(self.args)"
args_left = self.args[:idx]
args_right = self.args[idx:]
@ -463,7 +530,7 @@ class Node(_NodeBase):
new_use.users.setdefault(self)
@compatibility(is_backward_compatible=True)
def update_kwarg(self, key : str, arg : Argument) -> None:
def update_kwarg(self, key: str, arg: Argument) -> None:
"""
Update an existing keyword argument to contain the new value
``arg``. After calling, ``self.kwargs[key] == arg``.
@ -490,13 +557,16 @@ class Node(_NodeBase):
return self.meta.get("stack_trace", None)
@stack_trace.setter
def stack_trace(self, trace : Optional[str]) -> None:
def stack_trace(self, trace: Optional[str]) -> None:
self.meta["stack_trace"] = trace
def __update_args_kwargs(self, new_args : Tuple['Argument', ...], new_kwargs : Dict[str, 'Argument']) -> None:
def __update_args_kwargs(
self, new_args: Tuple["Argument", ...], new_kwargs: Dict[str, "Argument"]
) -> None:
"""
This API is internal. Do *not* call it directly.
"""
def update_users_and_input_nodes(n: Any) -> Any:
if isinstance(n, Node):
self._input_nodes.setdefault(n)
@ -512,8 +582,12 @@ class Node(_NodeBase):
# - Normalize list->immutable_list, dict->immutable_dict, etc
# - Populate self._input_nodes
# - Populate arg.users[self] for each arg
object.__setattr__(self, "_args", map_aggregate(new_args, update_users_and_input_nodes))
object.__setattr__(self, "_kwargs", map_aggregate(new_kwargs, update_users_and_input_nodes))
object.__setattr__(
self, "_args", map_aggregate(new_args, update_users_and_input_nodes)
)
object.__setattr__(
self, "_kwargs", map_aggregate(new_kwargs, update_users_and_input_nodes)
)
def __repr__(self) -> str:
if self._repr_fn:
@ -529,8 +603,8 @@ class Node(_NodeBase):
"""
if isinstance(target, str):
return target
if hasattr(target, '__module__'):
name = getattr(target, '__name__', None)
if hasattr(target, "__module__"):
name = getattr(target, "__name__", None)
if name is None:
# Just to be defensive, if we don't have `__name__`, get the
# qualname. Not sure if this happens for any members of `operator`
@ -538,16 +612,18 @@ class Node(_NodeBase):
# things in `operator` have `_operator` as their __module__.
# TODO: THIS IS BROKEN: _get_qualified_name calls `__name__`
return _get_qualified_name(target) # type: ignore[arg-type]
if target.__module__ == 'builtins':
return f'builtins.{name}'
elif target.__module__ == '_operator':
return f'operator.{name}'
if target.__module__ == "builtins":
return f"builtins.{name}"
elif target.__module__ == "_operator":
return f"operator.{name}"
return _get_qualified_name(target) # type: ignore[arg-type]
@compatibility(is_backward_compatible=True)
def format_node(self,
placeholder_names: Optional[List[str]] = None,
maybe_return_typename: Optional[List[str]] = None) -> Optional[str]:
def format_node(
self,
placeholder_names: Optional[List[str]] = None,
maybe_return_typename: Optional[List[str]] = None,
) -> Optional[str]:
"""
Return a descriptive string representation of ``self``.
@ -576,37 +652,46 @@ class Node(_NodeBase):
return a descriptive string representation of the
current Node.
"""
if self.op == 'placeholder':
if self.op == "placeholder":
assert isinstance(self.target, str)
arg_str = self.target
arg_str += arg_str + f': {_type_repr(self.type)}' if self.type else ''
arg_str += arg_str + f": {_type_repr(self.type)}" if self.type else ""
if placeholder_names:
placeholder_names.append(arg_str)
return None
maybe_typename = f'{_type_repr(self.type)} ' if self.type else ''
default_val = '(default=' + str(self.args[0]) + ')' if self.args else ''
return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = {self.op}[target={self.target}]{default_val}'
elif self.op == 'get_attr':
maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else ''
return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \
f'{self.op}[target={self._pretty_print_target(self.target)}]'
elif self.op == 'output':
maybe_typename = f"{_type_repr(self.type)} " if self.type else ""
default_val = "(default=" + str(self.args[0]) + ")" if self.args else ""
return f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = {self.op}[target={self.target}]{default_val}"
elif self.op == "get_attr":
maybe_typename = (
f"{_type_repr(self.type)} " if self.type is not None else ""
)
return (
f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = "
f"{self.op}[target={self._pretty_print_target(self.target)}]"
)
elif self.op == "output":
if self.type and maybe_return_typename:
maybe_return_typename[0] = f' -> {_type_repr(self.type)}'
return f'return {self.args[0]}'
maybe_return_typename[0] = f" -> {_type_repr(self.type)}"
return f"return {self.args[0]}"
else:
maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else ''
return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \
f'{self.op}[target={self._pretty_print_target(self.target)}](' \
f'args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})'
maybe_typename = (
f"{_type_repr(self.type)} " if self.type is not None else ""
)
return (
f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = "
f"{self.op}[target={self._pretty_print_target(self.target)}]("
f"args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})"
)
@compatibility(is_backward_compatible=True)
def replace_all_uses_with(self,
replace_with: 'Node',
delete_user_cb: Callable[['Node'], bool] = lambda user: True,
*,
propagate_meta: bool = False
) -> List['Node']:
def replace_all_uses_with(
self,
replace_with: "Node",
delete_user_cb: Callable[["Node"], bool] = lambda user: True,
*,
propagate_meta: bool = False,
) -> List["Node"]:
"""
Replace all uses of ``self`` in the Graph with the Node ``replace_with``.
@ -625,9 +710,10 @@ class Node(_NodeBase):
The list of Nodes on which this change was made.
"""
if propagate_meta:
assert len(replace_with.meta) == 0, \
'Called node.replace_all_uses_with(replace_with, propagate_meta=True), ' \
'but replace_with already has .meta keys'
assert len(replace_with.meta) == 0, (
"Called node.replace_all_uses_with(replace_with, propagate_meta=True), "
"but replace_with already has .meta keys"
)
for k, v in self.meta.items():
replace_with.meta[k] = v
to_process = list(self.users)
@ -638,7 +724,7 @@ class Node(_NodeBase):
skipped.append(use_node)
continue
def maybe_replace_node(n : Node) -> Node:
def maybe_replace_node(n: Node) -> Node:
if n == self:
return replace_with
else:
@ -690,9 +776,12 @@ class Node(_NodeBase):
@compatibility(is_backward_compatible=False)
def normalized_arguments(
self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None,
kwarg_types : Optional[Dict[str, Any]] = None,
normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
self,
root: torch.nn.Module,
arg_types: Optional[Tuple[Any]] = None,
kwarg_types: Optional[Dict[str, Any]] = None,
normalize_to_only_use_kwargs: bool = False,
) -> Optional[ArgsKwargsPair]:
"""
Returns normalized arguments to Python targets. This means that
`args/kwargs` will be matched up to the module/functional's
@ -715,17 +804,23 @@ class Node(_NodeBase):
Returns NamedTuple ArgsKwargsPair, or `None` if not successful.
"""
if self.op == 'call_function':
if self.op == "call_function":
assert callable(self.target)
return normalize_function(self.target, self.args, self.kwargs, arg_types, kwarg_types) # type: ignore[arg-type]
elif self.op == 'call_module':
return normalize_function(
self.target,
self.args, # type: ignore[arg-type]
self.kwargs,
arg_types,
kwarg_types,
)
elif self.op == "call_module":
assert isinstance(self.target, str)
return normalize_module(root, self.target, self.args, self.kwargs) # type: ignore[arg-type]
return None
@compatibility(is_backward_compatible=True)
def replace_input_with(self, old_input: 'Node', new_input: 'Node') -> None:
def replace_input_with(self, old_input: "Node", new_input: "Node") -> None:
"""
Loop through input nodes of ``self``, and replace all instances of
``old_input`` with ``new_input``.
@ -735,7 +830,8 @@ class Node(_NodeBase):
old_input (Node): The old input node to be replaced.
new_input (Node): The new input node to replace ``old_input``.
"""
def maybe_replace_node(n : Node) -> Node:
def maybe_replace_node(n: Node) -> Node:
return new_input if n == old_input else n
m = self.graph.owning_module
@ -756,7 +852,7 @@ class Node(_NodeBase):
self.graph._graph_namespace._rename_object(self, name)
def __setattr__(self, name: str, value: Any) -> None:
if name == 'name' and hasattr(self, "name"):
if name == "name" and hasattr(self, "name"):
m = self.graph.owning_module
if getattr(m, "_replace_hook", None):
assert isinstance(value, str)
@ -764,9 +860,9 @@ class Node(_NodeBase):
m._replace_hook(old=self, new=value, user=user)
update = False
if (
hasattr(self, name) and
hasattr(self.graph, "_find_nodes_lookup_table") and
self in self.graph._find_nodes_lookup_table
hasattr(self, name)
and hasattr(self.graph, "_find_nodes_lookup_table")
and self in self.graph._find_nodes_lookup_table
):
update = True
self.graph._find_nodes_lookup_table.remove(self)
@ -774,6 +870,7 @@ class Node(_NodeBase):
if update:
self.graph._find_nodes_lookup_table.insert(self)
@compatibility(is_backward_compatible=True)
def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
"""
@ -782,6 +879,7 @@ def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable"
return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
@compatibility(is_backward_compatible=True)
def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
"""
@ -790,7 +888,7 @@ def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
if isinstance(a, tuple):
t = tuple([map_aggregate(elem, fn) for elem in a])
# Support NamedTuple (if it has `_fields`) by repacking into original type.
return t if not hasattr(a, '_fields') else type(a)(*t) # type: ignore[arg-type]
return t if not hasattr(a, "_fields") else type(a)(*t) # type: ignore[arg-type]
elif isinstance(a, list):
return immutable_list([map_aggregate(elem, fn) for elem in a])
elif isinstance(a, dict):
@ -799,6 +897,10 @@ def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
dict.__setitem__(rv, k, map_aggregate(v, fn))
return rv
elif isinstance(a, slice):
return slice(map_aggregate(a.start, fn), map_aggregate(a.stop, fn), map_aggregate(a.step, fn))
return slice(
map_aggregate(a.start, fn),
map_aggregate(a.stop, fn),
map_aggregate(a.step, fn),
)
else:
return fn(a)
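The `Node` hunks above (`replace_all_uses_with`, the `args`/`kwargs` setters, `format_node`) are the building blocks of the usual FX rewrite loop. A minimal sketch, assuming a plain `symbolic_trace`d function:

```python
import operator

import torch
from torch.fx import symbolic_trace


def fn(x, y):
    return (x + y) * y


gm = symbolic_trace(fn)
for node in gm.graph.nodes:
    # `x + y` traces to a call_function node whose target is operator.add.
    if node.op == "call_function" and node.target is operator.add:
        with gm.graph.inserting_after(node):
            new_node = gm.graph.call_function(torch.sub, node.args, node.kwargs)
        node.replace_all_uses_with(new_node)  # rewire every user of the old node
        gm.graph.erase_node(node)
gm.recompile()

print(gm.code)
print(gm(torch.ones(2), torch.ones(2)))  # now computes (x - y) * y
```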

View File

@ -1,63 +1,100 @@
# mypy: allow-untyped-defs
import torch
import enum
import inspect
import numbers
import types
import typing
import enum
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING
from typing import (
Any,
Callable,
cast,
Dict,
List,
NamedTuple,
Optional,
Tuple,
TYPE_CHECKING,
)
import torch
from torch._jit_internal import boolean_dispatched
from torch._ops import OpOverload, OpOverloadPacket
from ._compatibility import compatibility
from torch._ops import OpOverloadPacket, OpOverload
if TYPE_CHECKING:
from .node import Argument
__all__ = ["ArgsKwargsPair", "check_for_mutable_operation", "get_signature_for_torch_op", "create_type_hint",
"type_matches", "normalize_function", "normalize_module"]
__all__ = [
"ArgsKwargsPair",
"check_for_mutable_operation",
"get_signature_for_torch_op",
"create_type_hint",
"type_matches",
"normalize_function",
"normalize_module",
]
@compatibility(is_backward_compatible=False)
class ArgsKwargsPair(NamedTuple):
"""
Simple named tuple for wrapping args/kwargs pairs.
"""
args: Tuple[Any, ...]
kwargs: Dict[str, Any]
_manual_overrides : Dict[Callable, List[inspect.Signature]] = {}
_manual_overrides: Dict[Callable, List[inspect.Signature]] = {}
def _nonzero_schemas():
signatures = []
def nonzero(self):
pass
signatures.append(inspect.signature(nonzero))
def nonzero(self, *, as_tuple : bool): # type: ignore[no-redef]
def nonzero(self, *, as_tuple: bool): # type: ignore[no-redef]
pass
signatures.append(inspect.signature(nonzero))
return signatures
_manual_overrides[torch.nonzero] = _nonzero_schemas()
class _FakeGlobalNamespace:
def __getattr__(self, name):
if name == 'torch':
if name == "torch":
return torch
raise RuntimeError('Expected a torch namespace lookup')
raise RuntimeError("Expected a torch namespace lookup")
_type_eval_globals = {'Tensor' : torch.Tensor, 'Device' : torch.device, 'Layout' : torch.layout,
'number' : numbers.Number, 'Future' : torch.jit.Future,
'AnyEnumType' : enum.Enum, 'QScheme' : torch.qscheme,
'__torch__': _FakeGlobalNamespace(), 'NoneType': type(None),
'Storage': torch.UntypedStorage,
't': typing.TypeVar('t')}
_type_eval_globals = {
"Tensor": torch.Tensor,
"Device": torch.device,
"Layout": torch.layout,
"number": numbers.Number,
"Future": torch.jit.Future,
"AnyEnumType": enum.Enum,
"QScheme": torch.qscheme,
"__torch__": _FakeGlobalNamespace(),
"NoneType": type(None),
"Storage": torch.UntypedStorage,
"t": typing.TypeVar("t"),
}
for k in dir(typing):
_type_eval_globals[k] = getattr(typing, k)
def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any:
def _torchscript_type_to_python_type(ts_type: "torch._C.JitType") -> Any:
"""
Convert a TorchScript type to a Python type (including subtypes) via
eval'ing the annotation_str. _type_eval_globals sets up expressions
@ -65,9 +102,13 @@ def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any:
"""
return eval(ts_type.annotation_str, _type_eval_globals)
def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) -> inspect.Signature:
def _torchscript_schema_to_signature_impl(
ts_schema: torch._C.FunctionSchema,
) -> inspect.Signature:
from inspect import Parameter
parameters : List[Parameter] = []
parameters: List[Parameter] = []
for arg in ts_schema.arguments:
arg_type = _torchscript_type_to_python_type(arg.type)
default = arg.default_value if arg.has_default_value() else Parameter.empty
@ -76,8 +117,12 @@ def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) -
# argument name. Downstream, if someone converts that positional argument to a keyword
# argument, the name mismatch will break things, so here we're going to normalize the
# name to "input"
name = arg.name if arg.name != 'self' else 'input'
kind = Parameter.KEYWORD_ONLY if arg.kwarg_only else Parameter.POSITIONAL_OR_KEYWORD
name = arg.name if arg.name != "self" else "input"
kind = (
Parameter.KEYWORD_ONLY
if arg.kwarg_only
else Parameter.POSITIONAL_OR_KEYWORD
)
# "from" is a keyword therefore it must be a POSITIONAL_ONLY argument
if name == "from":
assert kind == Parameter.POSITIONAL_OR_KEYWORD
@ -87,9 +132,18 @@ def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) -
# This renders all previous arguments to positional only
for idx, p in enumerate(parameters):
assert p.kind == Parameter.POSITIONAL_OR_KEYWORD
parameters[idx] = Parameter(name=p.name, kind=Parameter.POSITIONAL_ONLY, default=p.default, annotation=p.annotation)
parameters.append(Parameter(name=name, kind=kind, default=default, annotation=arg_type))
return_types = [_torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns]
parameters[idx] = Parameter(
name=p.name,
kind=Parameter.POSITIONAL_ONLY,
default=p.default,
annotation=p.annotation,
)
parameters.append(
Parameter(name=name, kind=kind, default=default, annotation=arg_type)
)
return_types = [
_torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns
]
if len(return_types) == 0:
return_type = None
elif len(return_types) == 1:
@ -99,9 +153,13 @@ def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) -
return inspect.Signature(parameters, return_annotation=return_type)
_SCHEMA_TO_SIGNATURE_CACHE : Dict[Tuple[str, str], inspect.Signature] = {}
def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Signature:
_SCHEMA_TO_SIGNATURE_CACHE: Dict[Tuple[str, str], inspect.Signature] = {}
def _torchscript_schema_to_signature(
ts_schema: torch._C.FunctionSchema,
) -> inspect.Signature:
# Cached as it's called in the hot path of FakeTensor dispatch
cache_key = ts_schema.name, ts_schema.overload_name
cache_val = _SCHEMA_TO_SIGNATURE_CACHE.get(cache_key)
@ -112,8 +170,11 @@ def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> ins
_SCHEMA_TO_SIGNATURE_CACHE[cache_key] = res
return res
@compatibility(is_backward_compatible=False)
def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...], kwargs : Dict[str, 'Argument']):
def check_for_mutable_operation(
target: Callable, args: Tuple["Argument", ...], kwargs: Dict[str, "Argument"]
):
signatures, schemas = get_signature_for_torch_op(target, return_schemas=True)
if signatures and schemas:
@ -131,9 +192,11 @@ def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...]
def throw_if_mutable(schema):
if schema.is_mutable:
raise RuntimeError(f'Tried to trace mutable operation {schema}. FX only supports functional '
f'code, so operations that mutate operands in-place (e.g. via `out` arguments) '
f'are not supported')
raise RuntimeError(
f"Tried to trace mutable operation {schema}. FX only supports functional "
f"code, so operations that mutate operands in-place (e.g. via `out` arguments) "
f"are not supported"
)
if len(matched_schemas) == 0:
# Did not match any schema. Cannot check for mutation
@ -147,8 +210,9 @@ def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...]
# do nothing.
pass
@compatibility(is_backward_compatible=False)
def get_signature_for_torch_op(op : Callable, return_schemas : bool = False):
def get_signature_for_torch_op(op: Callable, return_schemas: bool = False):
"""
Given an operator on the `torch` namespace, return a list of `inspect.Signature`
objects corresponding to the overloads of that op.. May return `None` if a signature
@ -181,6 +245,7 @@ def get_signature_for_torch_op(op : Callable, return_schemas : bool = False):
signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
return (signatures, schemas) if return_schemas else signatures
@compatibility(is_backward_compatible=False)
def create_type_hint(x):
"""
@ -198,11 +263,15 @@ def create_type_hint(x):
if isinstance(x, (list, tuple)):
# todo(chilli): Figure out the right way for mypy to handle this
if isinstance(x, list):
def ret_type(x):
return List[x] # type: ignore[valid-type]
else:
def ret_type(x):
return Tuple[x, ...]
if len(x) == 0:
return ret_type(Any)
base_type = x[0]
@ -216,12 +285,15 @@ def create_type_hint(x):
return ret_type(base_type)
except Exception:
# We tried to create a type hint for list but failed.
warnings.warn(f"We were not able to successfully create type hint from the type {x}")
warnings.warn(
f"We were not able to successfully create type hint from the type {x}"
)
return x
@compatibility(is_backward_compatible=False)
def type_matches(signature_type : Any, argument_type : Any):
sig_origin_type = getattr(signature_type, '__origin__', signature_type)
def type_matches(signature_type: Any, argument_type: Any):
sig_origin_type = getattr(signature_type, "__origin__", signature_type)
if signature_type is argument_type:
return True
@ -236,13 +308,14 @@ def type_matches(signature_type : Any, argument_type : Any):
# int can be promoted to List[int]
return True
if getattr(signature_type, '__origin__', None) in {list, List}:
if getattr(signature_type, "__origin__", None) in {list, List}:
sig_el_type = signature_type.__args__[0]
if not inspect.isclass(sig_el_type):
warnings.warn(
f"Does not support nested parametric types, got {signature_type}. Please file a bug.")
f"Does not support nested parametric types, got {signature_type}. Please file a bug."
)
return False
if getattr(argument_type, '__origin__', None) in {list, List}:
if getattr(argument_type, "__origin__", None) in {list, List}:
return issubclass(argument_type.__args__[0], sig_el_type)
def is_homogeneous_tuple(t):
@ -267,11 +340,16 @@ def type_matches(signature_type : Any, argument_type : Any):
return False
@compatibility(is_backward_compatible=False)
def normalize_function(
target: Callable, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, arg_types : Optional[Tuple[Any]] = None,
kwarg_types : Optional[Dict[str, Any]] = None,
normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
target: Callable,
args: Tuple[Any],
kwargs: Optional[Dict[str, Any]] = None,
arg_types: Optional[Tuple[Any]] = None,
kwarg_types: Optional[Dict[str, Any]] = None,
normalize_to_only_use_kwargs: bool = False,
) -> Optional[ArgsKwargsPair]:
"""
Returns normalized arguments to PyTorch functions. This means that
`args/kwargs` will be matched up to the functional's
@ -308,14 +386,19 @@ def normalize_function(
# branch signature for analysis. Otherwise, leave this un-normalized
assert not isinstance(target, str)
dispatched = boolean_dispatched[target]
if_true, if_false = dispatched['if_true'], dispatched['if_false']
if inspect.signature(if_true).parameters != inspect.signature(if_false).parameters:
if_true, if_false = dispatched["if_true"], dispatched["if_false"]
if (
inspect.signature(if_true).parameters
!= inspect.signature(if_false).parameters
):
return None
target_for_analysis = if_true
assert callable(target_for_analysis)
sig = inspect.signature(inspect.unwrap(target_for_analysis))
new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, normalize_to_only_use_kwargs)
new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(
sig, args, kwargs, normalize_to_only_use_kwargs
)
else:
assert callable(target)
torch_op_schemas = get_signature_for_torch_op(target)
@ -336,8 +419,9 @@ def normalize_function(
pass
elif len(matched_schemas) == 1:
# Matched exactly one schema, unambiguous
new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(matched_schemas[0], args, kwargs,
normalize_to_only_use_kwargs)
new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(
matched_schemas[0], args, kwargs, normalize_to_only_use_kwargs
)
else:
if arg_types is not None or kwarg_types is not None:
arg_types = arg_types if arg_types else cast(Tuple[Any], ())
@ -345,30 +429,49 @@ def normalize_function(
for candidate_signature in torch_op_schemas:
sig_matches = True
try:
bound_types = candidate_signature.bind(*arg_types, **kwarg_types)
bound_types = candidate_signature.bind(
*arg_types, **kwarg_types
)
for arg_name, arg_type in bound_types.arguments.items():
param = candidate_signature.parameters[arg_name]
sig_matches = sig_matches and type_matches(param.annotation, arg_type)
sig_matches = sig_matches and type_matches(
param.annotation, arg_type
)
except TypeError:
sig_matches = False
if sig_matches:
new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(candidate_signature, args, kwargs,
normalize_to_only_use_kwargs)
new_args_and_kwargs = (
_args_kwargs_to_normalized_args_kwargs(
candidate_signature,
args,
kwargs,
normalize_to_only_use_kwargs,
)
)
break
else:
# Matched more than one schema. In this situation, the caller must provide the types of
# the arguments of the overload they expect.
schema_printouts = '\n'.join(str(schema) for schema in matched_schemas)
raise RuntimeError(f'Tried to normalize arguments to {torch.typename(target)} but '
f'the schema match was ambiguous! Please provide argument types to '
f'the normalize_arguments() call. Available schemas:\n{schema_printouts}')
schema_printouts = "\n".join(
str(schema) for schema in matched_schemas
)
raise RuntimeError(
f"Tried to normalize arguments to {torch.typename(target)} but "
f"the schema match was ambiguous! Please provide argument types to "
f"the normalize_arguments() call. Available schemas:\n{schema_printouts}"
)
return new_args_and_kwargs
@compatibility(is_backward_compatible=False)
def normalize_module(
root: torch.nn.Module, target: str, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None,
normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
root: torch.nn.Module,
target: str,
args: Tuple[Any],
kwargs: Optional[Dict[str, Any]] = None,
normalize_to_only_use_kwargs: bool = False,
) -> Optional[ArgsKwargsPair]:
"""
Returns normalized arguments to PyTorch modules. This means that
`args/kwargs` will be matched up to the functional's
@ -391,22 +494,29 @@ def normalize_module(
try:
submod = root.get_submodule(target)
except AttributeError as e:
raise RuntimeError(f"Tried to normalize node with target {target} but root did not "
f"have that target!") from e
if hasattr(submod.__class__, '__name__'):
raise RuntimeError(
f"Tried to normalize node with target {target} but root did not "
f"have that target!"
) from e
if hasattr(submod.__class__, "__name__"):
classname = submod.__class__.__name__
if getattr(torch.nn, classname, None) == submod.__class__:
sig = inspect.signature(inspect.unwrap(submod.forward))
if kwargs is None:
kwargs = {}
new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs,
normalize_to_only_use_kwargs)
new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(
sig, args, kwargs, normalize_to_only_use_kwargs
)
return new_args_and_kwargs
return None
def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple[Any, ...],
kwargs : Dict[str, Any],
normalize_to_only_use_kwargs : bool) -> Optional[ArgsKwargsPair]:
def _args_kwargs_to_normalized_args_kwargs(
sig: inspect.Signature,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
normalize_to_only_use_kwargs: bool,
) -> Optional[ArgsKwargsPair]:
"""
Given a call target, args, and kwargs, return the arguments normalized into
an ArgsKwargsPair, or None if the type signature is not supported by
@ -428,20 +538,22 @@ def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple
# Don't currently support positional-only
# or varargs (*args, **kwargs) signatures
supported_parameter_types = {
inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY}
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
}
if any(p.kind not in supported_parameter_types for p in sig.parameters.values()):
# Add an exception for one signature, which is common for random/uniform, i.e.:
# Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None
# `from` is Python keyword and as such functions with that signature should have
# positional-only args, but at the same time they could be dispatched as kwargs
if list(sig.parameters.keys()) != ['input', 'from', 'to', 'generator']:
if list(sig.parameters.keys()) != ["input", "from", "to", "generator"]:
return None
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
new_kwargs : Dict[str, Any] = {}
new_args : List[Any] = []
new_kwargs: Dict[str, Any] = {}
new_args: List[Any] = []
for i, param in enumerate(sig.parameters):
if not normalize_to_only_use_kwargs and i < len(args):
new_args.append(bound_args.arguments[param])
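The `normalize_function` changes above are formatting-only, but the ambiguity handling is easy to misread: when several schemas match, `arg_types` must be supplied or a `RuntimeError` is raised. A hedged sketch (the printed keys are indicative, not guaranteed across PyTorch versions):

```python
import torch
from torch.fx.operator_schemas import normalize_function

t = torch.randn(3)
# torch.add has multiple schemas, so pass arg_types to pick the Tensor overload
# and ask for a kwargs-only normalization.
pair = normalize_function(
    torch.add,
    (t, t),
    {},
    arg_types=(torch.Tensor, torch.Tensor),
    normalize_to_only_use_kwargs=True,
)
if pair is not None:
    print(pair.args, list(pair.kwargs))  # e.g. () ['input', 'other', 'alpha']
```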

View File

@ -1,12 +1,14 @@
from . import graph_drawer
from . import graph_manipulation
from . import net_min_base
from . import operator_support
from . import param_fetch
from . import reinplace
from . import runtime_assert
from . import shape_prop
from . import split_module
from . import split_utils
from . import splitter_base
from . import tools_common
from . import (
graph_drawer,
graph_manipulation,
net_min_base,
operator_support,
param_fetch,
reinplace,
runtime_assert,
shape_prop,
split_module,
split_utils,
splitter_base,
tools_common,
)

View File

@ -1,12 +1,13 @@
# mypy: allow-untyped-defs
import operator
import torch
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.utils import _pytree as pytree
import operator
class CudaGraphsSupport(OperatorSupport):
# TODO: why is submodules passed here
@ -27,7 +28,7 @@ class CudaGraphsSupport(OperatorSupport):
def find_not_cuda(t):
nonlocal found_not_cuda
if isinstance(t, torch.Tensor) and t.device.type != 'cuda':
if isinstance(t, torch.Tensor) and t.device.type != "cuda":
found_not_cuda = True
for n in node.all_input_nodes:
@ -40,6 +41,7 @@ class CudaGraphsSupport(OperatorSupport):
return not found_not_cuda
def partition_cudagraphs(gm, inputs):
"""
Partition an FX graph into sub-GraphModules that can be validly run under
@ -51,7 +53,9 @@ def partition_cudagraphs(gm, inputs):
supported_ops = CudaGraphsSupport()
# TODO: single node partition may be wrong due to the pessimization
# from copying in and out the data. Check in benchmarks, perhaps
partitioner = CapabilityBasedPartitioner(gm, supported_ops, allows_single_node_partition=True)
partitioner = CapabilityBasedPartitioner(
gm, supported_ops, allows_single_node_partition=True
)
partitions = partitioner.propose_partitions()
fused_graph = partitioner.fuse_partitions(partitions)
return fused_graph
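For context, a minimal usage sketch of the pass above; it assumes a CUDA-capable build, since the support check keys off tensor devices, and the toy function is only illustrative.

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.backends.cudagraphs import partition_cudagraphs

def f(x):
    return torch.relu(x) + 1

gm = symbolic_trace(f)
inputs = [torch.randn(8, device="cuda")]  # fake-tensor propagation uses these devices
fused = partition_cudagraphs(gm, inputs)
print(fused.graph)  # supported regions are fused into call_module sub-GraphModules
```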

View File

@ -1,20 +1,45 @@
# mypy: allow-untyped-defs
from typing import Dict, Tuple, Any
from typing import Any, Dict, Tuple
import torch
from torch.fx import Graph, GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.utils._pytree import tree_flatten
from torch.fx import GraphModule, Graph
from torch.fx import Node
aten = torch.ops.aten
# stateful ops are banned from CSE
rand_ops = {aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm} # noqa: E501,B950
rand_ops = {
aten.dropout,
aten._fused_dropout,
aten._standard_gamma,
aten.bernoulli,
aten.multinomial,
aten.native_dropout,
aten.normal,
aten.poisson,
aten.binomial,
aten.rrelu,
aten.rand_like,
aten.rand,
aten.randint,
aten.randn,
aten.randperm,
} # noqa: E501,B950
inplace_ops = {aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_} # noqa: E501
inplace_ops = {
aten.add_,
aten.sub_,
aten.mul_,
aten.div_,
aten.pow_,
aten.lerp_,
aten.relu_,
aten.sigmoid_,
aten.tanh_,
} # noqa: E501
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
@ -24,7 +49,6 @@ def get_CSE_banned_ops():
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
class CSEPass(PassBase):
def __init__(self, banned_ops=None):
"""
This version of CSE Pass aims to be dialect agnostic, and it's implemented purely based on the connectivity between fx.Node.
@ -58,20 +82,32 @@ class CSEPass(PassBase):
result = p(traced_graph)
print(result.graph_module)
"""
def get_aten_target(node):
if hasattr(node.target, 'overloadpacket'):
if hasattr(node.target, "overloadpacket"):
return node.target.overloadpacket
return node.target
modified = False
new_graph = Graph()
env: Dict[Node, Node] = {} # map from node in the old graph to node in the new graph
hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {} # map from hash to a node in the new graph
token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {} # map from hash to token
env: Dict[
Node, Node
] = {} # map from node in the old graph to node in the new graph
hash_env: Dict[
Tuple[torch._ops.OpOverload, int], Node
] = {} # map from hash to a node in the new graph
token_map: Dict[
Tuple[torch._ops.OpOverload, int], Dict[str, Any]
] = {} # map from hash to token
for n in graph_module.graph.nodes:
# The placeholder, output, and get_attr nodes are copied to the new graph without change
# do not CSE away random operations
if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in self.banned_ops:
if (
n.op == "placeholder"
or n.op == "output"
or n.op == "get_attr"
or get_aten_target(n) in self.banned_ops
):
new_node = new_graph.node_copy(n, lambda x: env[x])
env[n] = new_node
else: # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method'
@ -84,13 +120,19 @@ class CSEPass(PassBase):
if isinstance(v, Node) and v in env:
arg_list[i] = env[v]
return tuple(arg_list), spec
args, args_spec = substitute(n.args)
kwargs, kwargs_spec = substitute(n.kwargs)
# each token corresponds to a unique node
# nodes with the same token can be substituted
token = {"target": n.target, "args": args, "args_spec": args_spec,
"kwargs": kwargs, "kwargs_spec": kwargs_spec}
token = {
"target": n.target,
"args": args,
"args_spec": args_spec,
"kwargs": kwargs,
"kwargs_spec": kwargs_spec,
}
# hash substituted args to a number, do not hash specs because specs are not hashable
hash_arg = hash((args, kwargs))
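For reference, the docstring's `result = p(traced_graph)` flow looks roughly like this; the duplicated `torch.cos` call is contrived purely to give the pass something to eliminate.

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.dialect.common.cse_pass import CSEPass

def f(x):
    a = torch.cos(x)
    b = torch.cos(x)  # redundant with `a`; a CSE candidate
    return a + b

gm = symbolic_trace(f)
result = CSEPass()(gm)            # returns PassResult(graph_module, modified)
print(result.modified)            # whether any node was deduplicated
print(result.graph_module.graph)
```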

View File

@ -2,13 +2,15 @@
from typing import Optional
import torch.fx
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx import Node
from torch.fx.node import map_aggregate
from torch.fx._compatibility import compatibility
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
from torch.fx.experimental.proxy_tensor import snapshot_fake, py_sym_types
from torch.fx.experimental.proxy_tensor import py_sym_types, snapshot_fake
from torch.fx.node import map_aggregate
__all__ = ["FakeTensorProp"]
__all__ = ['FakeTensorProp']
@compatibility(is_backward_compatible=False)
class FakeTensorProp(torch.fx.Interpreter):
@ -24,7 +26,10 @@ class FakeTensorProp(torch.fx.Interpreter):
module (GraphModule): The module to be executed
mode (Optional[FakeTensorMode]): The dispatch mode used to execute computation indicated by each FX Node.
"""
def __init__(self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None):
def __init__(
self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None
):
super().__init__(module)
if mode is None:
mode = FakeTensorMode()
@ -33,7 +38,10 @@ class FakeTensorProp(torch.fx.Interpreter):
mode.reset_nt_tensor_id_counter()
def run_node(self, n: Node):
from torch.fx.experimental.symbolic_shapes import rebind_unbacked, compute_unbacked_bindings
from torch.fx.experimental.symbolic_shapes import (
compute_unbacked_bindings,
rebind_unbacked,
)
result = super().run_node(n)
rebind_unbacked(self._mode.shape_env, n, result)
@ -52,8 +60,10 @@ class FakeTensorProp(torch.fx.Interpreter):
meta = map_aggregate(result, extract_val)
if meta is not None:
n.meta['val'] = meta
if (shape_env := self._mode.shape_env) and (symbol_to_path := compute_unbacked_bindings(shape_env, result)):
n.meta["val"] = meta
if (shape_env := self._mode.shape_env) and (
symbol_to_path := compute_unbacked_bindings(shape_env, result)
):
n.meta["unbacked_bindings"] = symbol_to_path
return result
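A small sketch of the usual entry point for the interpreter above; the toy matmul graph is illustrative, and the point is just that `node.meta["val"]` ends up holding fake tensors with shapes and dtypes.

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

def f(x, y):
    return (x @ y).relu()

gm = symbolic_trace(f)
FakeTensorProp(gm).propagate(torch.randn(2, 3), torch.randn(3, 4))
for node in gm.graph.nodes:
    print(node.name, node.meta.get("val"))  # FakeTensor metadata, no real compute
```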

View File

@ -58,6 +58,7 @@ _WEIGHT_TEMPLATE = {
}
if HAS_PYDOT:
@compatibility(is_backward_compatible=False)
class FxGraphDrawer:
"""
@ -87,7 +88,12 @@ if HAS_PYDOT:
self._dot_graphs = {
name: self._to_dot(
graph_module, name, ignore_getattr, ignore_parameters_and_buffers, skip_node_names_in_args, parse_stack_trace
graph_module,
name,
ignore_getattr,
ignore_parameters_and_buffers,
skip_node_names_in_args,
parse_stack_trace,
)
}
@ -127,8 +133,8 @@ if HAS_PYDOT:
>>> symbolic_traced = torch.fx.symbolic_trace(module)
>>> # setup output file
>>> import ubelt as ub
>>> dpath = ub.Path.appdir('torch/tests/FxGraphDrawer').ensuredir()
>>> fpath = dpath / 'linear.svg'
>>> dpath = ub.Path.appdir("torch/tests/FxGraphDrawer").ensuredir()
>>> fpath = dpath / "linear.svg"
>>> # draw the graph
>>> g = FxGraphDrawer(symbolic_traced, "linear")
>>> g.get_dot_graph().write_svg(fpath)
@ -148,7 +154,6 @@ if HAS_PYDOT:
return self._dot_graphs
def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]:
template = {
"shape": self.dot_graph_shape,
"fillcolor": "#CAFFE3",
@ -161,7 +166,9 @@ if HAS_PYDOT:
# Use a random color for each node; based on its name so it's stable.
target_name = node._pretty_print_target(node.target)
target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16)
template["fillcolor"] = _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)]
template["fillcolor"] = _HASH_COLOR_MAP[
target_hash % len(_HASH_COLOR_MAP)
]
return template
def _get_leaf_node(
@ -199,12 +206,11 @@ if HAS_PYDOT:
full_file_name: str,
truncate_to_last_n: int = 2,
):
splits = full_file_name.split('/')
splits = full_file_name.split("/")
if len(splits) >= truncate_to_last_n:
return '/'.join(splits[-truncate_to_last_n:])
return "/".join(splits[-truncate_to_last_n:])
return full_file_name
def _get_node_label(
self,
module: torch.fx.GraphModule,
@ -219,8 +225,7 @@ if HAS_PYDOT:
elif isinstance(arg, dict):
prefix, suffix = r"|kwargs={\l", r",\n}\l"
arg_strs_list = [
f"{k}: {_format_arg(v, max_list_len=8)}"
for k, v in arg.items()
f"{k}: {_format_arg(v, max_list_len=8)}" for k, v in arg.items()
]
else: # Fall back to nothing in unexpected case.
return ""
@ -235,7 +240,6 @@ if HAS_PYDOT:
arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "")
return arg_strs.replace("{", r"\{").replace("}", r"\}")
label = "{" + f"name=%{node.name}|op_code={node.op}\n"
if node.op == "call_module":
@ -244,7 +248,10 @@ if HAS_PYDOT:
extra = ""
if hasattr(leaf_module, "__constants__"):
extra = r"\n".join(
[f"{c}: {getattr(leaf_module, c)}" for c in leaf_module.__constants__] # type: ignore[union-attr]
[
f"{c}: {getattr(leaf_module, c)}"
for c in leaf_module.__constants__
] # type: ignore[union-attr]
)
label += extra + r"\n"
else:
@ -252,7 +259,10 @@ if HAS_PYDOT:
if self.normalize_args:
try:
args, kwargs = normalize_function( # type: ignore[misc]
node.target, node.args, node.kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type]
node.target, # type: ignore[arg-type]
node.args, # type: ignore[arg-type]
node.kwargs,
normalize_to_only_use_kwargs=True,
)
except Exception:
# Fallback to not normalizing if there's an exception.
@ -266,12 +276,12 @@ if HAS_PYDOT:
label += _get_str_for_args_kwargs(kwargs)
label += f"|num_users={len(node.users)}" + r"\n"
tensor_meta = node.meta.get('tensor_meta')
tensor_meta = node.meta.get("tensor_meta")
label += self._tensor_meta_to_label(tensor_meta)
# for original fx graph
# print buf=buf0, n_origin=6
buf_meta = node.meta.get('buf_meta', None)
buf_meta = node.meta.get("buf_meta", None)
if buf_meta is not None:
label += f"|buf={buf_meta.name}" + r"\n"
label += f"|n_origin={buf_meta.n_origin}" + r"\n"
@ -281,8 +291,10 @@ if HAS_PYDOT:
if parse_stack_trace and node.stack_trace is not None:
parsed_stack_trace = _parse_stack_trace(node.stack_trace)
fname = self._shorten_file_name(parsed_stack_trace.file)
label += f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" + r"\n"
label += (
f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}"
+ r"\n"
)
return label + "}"
@ -322,19 +334,43 @@ if HAS_PYDOT:
assert "qscheme" in tm.qparams
qscheme = tm.qparams["qscheme"]
if qscheme in {
torch.per_tensor_affine,
torch.per_tensor_symmetric,
torch.per_tensor_affine,
torch.per_tensor_symmetric,
}:
result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
result += "|" + "q_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n"
result += (
"|"
+ "q_zero_point"
+ "="
+ str(tm.qparams["zero_point"])
+ r"\n"
)
elif qscheme in {
torch.per_channel_affine,
torch.per_channel_symmetric,
torch.per_channel_affine_float_qparams,
torch.per_channel_affine,
torch.per_channel_symmetric,
torch.per_channel_affine_float_qparams,
}:
result += "|" + "q_per_channel_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
result += "|" + "q_per_channel_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n"
result += "|" + "q_per_channel_axis" + "=" + str(tm.qparams["axis"]) + r"\n"
result += (
"|"
+ "q_per_channel_scale"
+ "="
+ str(tm.qparams["scale"])
+ r"\n"
)
result += (
"|"
+ "q_per_channel_zero_point"
+ "="
+ str(tm.qparams["zero_point"])
+ r"\n"
)
result += (
"|"
+ "q_per_channel_axis"
+ "="
+ str(tm.qparams["axis"])
+ r"\n"
)
else:
raise RuntimeError(f"Unsupported qscheme: {qscheme}")
result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n"
@ -363,7 +399,6 @@ if HAS_PYDOT:
# "TB" means top-to-bottom rank direction in layout
dot_graph = pydot.Dot(name, rankdir="TB")
buf_name_to_subgraph = {}
for node in graph_module.graph.nodes:
@ -372,16 +407,22 @@ if HAS_PYDOT:
style = self._get_node_style(node)
dot_node = pydot.Node(
node.name, label=self._get_node_label(graph_module, node, skip_node_names_in_args, parse_stack_trace), **style
node.name,
label=self._get_node_label(
graph_module, node, skip_node_names_in_args, parse_stack_trace
),
**style,
)
current_graph = dot_graph
buf_meta = node.meta.get('buf_meta', None)
buf_meta = node.meta.get("buf_meta", None)
if buf_meta is not None and buf_meta.n_origin > 1:
buf_name = buf_meta.name
if buf_name not in buf_name_to_subgraph:
buf_name_to_subgraph[buf_name] = pydot.Cluster(buf_name, label=buf_name)
buf_name_to_subgraph[buf_name] = pydot.Cluster(
buf_name, label=buf_name
)
current_graph = buf_name_to_subgraph.get(buf_name)
current_graph.add_node(dot_node)
@ -407,12 +448,14 @@ if HAS_PYDOT:
if node.op == "call_module":
leaf_module = self._get_leaf_node(graph_module, node)
if not ignore_parameters_and_buffers and not isinstance(leaf_module, torch.fx.GraphModule):
if not ignore_parameters_and_buffers and not isinstance(
leaf_module, torch.fx.GraphModule
):
get_module_params_or_buffers()
for subgraph in buf_name_to_subgraph.values():
subgraph.set('color', 'royalblue')
subgraph.set('penwidth', '2')
subgraph.set("color", "royalblue")
subgraph.set("penwidth", "2")
dot_graph.add_subgraph(subgraph)
for node in graph_module.graph.nodes:
@ -426,6 +469,7 @@ if HAS_PYDOT:
else:
if not TYPE_CHECKING:
@compatibility(is_backward_compatible=False)
class FxGraphDrawer:
def __init__(
@ -439,5 +483,7 @@ else:
dot_graph_shape: Optional[str] = None,
normalize_args: bool = False,
):
raise RuntimeError('FXGraphDrawer requires the pydot package to be installed. Please install '
'pydot through your favorite Python package manager.')
raise RuntimeError(
"FXGraphDrawer requires the pydot package to be installed. Please install "
"pydot through your favorite Python package manager."
)
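Beyond the doctest above, a bare-bones sketch of the drawer (requires pydot, plus graphviz for the SVG write; the file name is arbitrary):

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.graph_drawer import FxGraphDrawer

gm = symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()))
drawer = FxGraphDrawer(gm, "linear")
drawer.get_dot_graph().write_svg("linear.svg")  # one dot node per fx.Node
```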

View File

@ -5,15 +5,18 @@ import torch
from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.node import (
map_arg,
Node,
Target,
)
from torch.fx.node import map_arg, Node, Target
from torch.fx.passes.shape_prop import ShapeProp
__all__ = ['replace_target_nodes_with', 'size_bytes', 'get_size_of_all_nodes', 'get_tensor_meta',
'get_size_of_node']
__all__ = [
"replace_target_nodes_with",
"size_bytes",
"get_size_of_all_nodes",
"get_tensor_meta",
"get_size_of_node",
]
@compatibility(is_backward_compatible=False)
def replace_target_nodes_with(

View File

@ -1,2 +1 @@
from . import pass_manager

View File

@ -1,22 +1,24 @@
# mypy: allow-untyped-defs
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
import collections
import itertools
import logging
from copy import copy
from typing import Dict, Iterable, List, Optional, Sequence, Set
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node, _get_qualified_name
from torch.fx.node import _get_qualified_name, Node
from torch.fx.passes.operator_support import OperatorSupportBase
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
class Partition:
def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None):
def __init__(
self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None
):
self.id = id
self.nodes = dict.fromkeys(nodes) if nodes is not None else {}
@ -32,6 +34,7 @@ class Partition:
def size(self):
return len(self.nodes)
class _DependencyViewer:
def __init__(self, graph_module: GraphModule):
self.upstreams = collections.defaultdict(set)
@ -55,15 +58,16 @@ class _DependencyViewer:
def upstreams_of(self, node: Node) -> Set[Node]:
return self.upstreams[node]
class CapabilityBasedPartitioner:
def __init__(self,
graph_module: GraphModule,
operator_support: OperatorSupportBase,
allows_single_node_partition: bool = False,
non_compute_ops: Optional[Sequence[str]] = None,
allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
) -> None:
class CapabilityBasedPartitioner:
def __init__(
self,
graph_module: GraphModule,
operator_support: OperatorSupportBase,
allows_single_node_partition: bool = False,
non_compute_ops: Optional[Sequence[str]] = None,
allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
) -> None:
self.graph_module = graph_module
self.operator_support = operator_support
self.allows_single_node_partition = allows_single_node_partition
@ -76,19 +80,21 @@ class CapabilityBasedPartitioner:
self.dependency_viewer = _DependencyViewer(graph_module)
def __is_node_supported(self, node: Node) -> bool:
return (
self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node)
return self.operator_support.is_node_supported(
dict(self.graph_module.named_modules()), node
)
def propose_partitions(self) -> List[Partition]:
# partition_map is a mapping from partition id to a set of partition id's.
# The value set contains all the partition ids that can be reached by doing a
# DFS starting from the partition id in the key.
partition_map : Dict[int, Set] = collections.defaultdict(set)
partition_map: Dict[int, Set] = collections.defaultdict(set)
# assumptions: nodes in candidate list are sorted in topological order
assignment: Dict[Node, int] = {} # mapping from node to partition_id
partitions_by_id: Dict[int, Partition] = {} # mapping from partition_id to partition
assignment: Dict[Node, int] = {} # mapping from node to partition_id
partitions_by_id: Dict[
int, Partition
] = {} # mapping from partition_id to partition
new_partition_id = itertools.count()
# try to merge partition other_id into partition self_id
@ -149,7 +155,9 @@ class CapabilityBasedPartitioner:
# delete other partition
del partitions_by_id[other_id]
partition_map[self_id] = partition_map[self_id].union(partition_map[other_id])
partition_map[self_id] = partition_map[self_id].union(
partition_map[other_id]
)
del partition_map[other_id]
return True
@ -223,16 +231,18 @@ class CapabilityBasedPartitioner:
for node in self.graph_module.graph.nodes:
is_tuple_output = True
for user in node.users:
if user.op != "call_function" or \
_get_qualified_name(user.target) != "_operator.getitem": # type: ignore[arg-type]
if (
user.op != "call_function"
or _get_qualified_name(user.target) != "_operator.getitem"
): # type: ignore[arg-type]
is_tuple_output = False
break
# node has tuple outputs, re-assign all following getitem nodes into node's partition
if is_tuple_output:
id = assignment.get(node, None) # type: ignore[arg-type]
id = assignment.get(node, None) # type: ignore[arg-type]
for user in node.users:
if assignment.get(user, None) != id: # type: ignore[arg-type]
if assignment.get(user, None) != id: # type: ignore[arg-type]
nodes_reassignment[user] = id # type: ignore[assignment]
for node, id in nodes_reassignment.items():
merge_single_node(node, id)
@ -250,7 +260,10 @@ class CapabilityBasedPartitioner:
assert callable(node.target)
if _get_qualified_name(node.target) not in non_compute_ops:
compute_node_count += 1
if _get_qualified_name(node.target) in self.allowed_single_node_partition_ops:
if (
_get_qualified_name(node.target)
in self.allowed_single_node_partition_ops
):
compute_node_count += 1
if compute_node_count <= 1:
partitions_to_remove.append(id)
@ -259,11 +272,17 @@ class CapabilityBasedPartitioner:
logger.debug("Partitions proposed:")
for id, partition in partitions_by_id.items():
logger.debug("partition #%s: %s", id, [node.name for node in partition.nodes])
logger.debug(
"partition #%s: %s", id, [node.name for node in partition.nodes]
)
return [partition for partition in partitions_by_id.values() if partition.size() > 0]
return [
partition for partition in partitions_by_id.values() if partition.size() > 0
]
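To make the propose/fuse flow above concrete, a hedged sketch with a hypothetical `SupportAll` policy (real backends implement `is_node_supported` per operator):

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase

class SupportAll(OperatorSupportBase):
    def is_node_supported(self, submodules, node):
        return node.op == "call_function"  # toy policy: every call_function is supported

def f(x):
    return torch.relu(x) + torch.sigmoid(x)

gm = symbolic_trace(f)
partitioner = CapabilityBasedPartitioner(
    gm, SupportAll(), allows_single_node_partition=True
)
partitions = partitioner.propose_partitions()
fused = partitioner.fuse_partitions(partitions)  # fuse_partitions is defined just below
print(fused.graph)
```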
def fuse_partitions(self, partitions: List[Partition], prefix: str = "fused_") -> GraphModule:
def fuse_partitions(
self, partitions: List[Partition], prefix: str = "fused_"
) -> GraphModule:
logger.debug("Fusing partitions...")
# fuse_by_partitions expects partitions in List[Dict[Node, None]]: [ {node0 : None}, {node1 : None} ]
return fuse_by_partitions(
@ -277,15 +296,23 @@ class CapabilityBasedPartitioner:
non_compute_ops = set(self.non_compute_ops)
def is_non_compute_node(node: Node):
return node.op == "call_function" and \
_get_qualified_name(node.target) in non_compute_ops # type: ignore[arg-type]
return (
node.op == "call_function"
and _get_qualified_name(node.target) in non_compute_ops # type: ignore[arg-type]
)
# cache transparent nodes
transparent_input_nodes: Dict[Node, bool] = {}
transparent_output_nodes: Dict[Node, bool] = {}
def is_transparent_input_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]):
if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
def is_transparent_input_node(
node: Node, partition: Set[Node], removed_nodes: Set[Node]
):
if (
node.op == "placeholder"
or (node not in partition)
or (node in removed_nodes)
):
return True
if node in transparent_input_nodes:
return transparent_input_nodes[node]
@ -299,14 +326,22 @@ class CapabilityBasedPartitioner:
transparent_input_nodes[node] = False
return False
def is_transparent_output_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]):
if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
def is_transparent_output_node(
node: Node, partition: Set[Node], removed_nodes: Set[Node]
):
if (
node.op == "placeholder"
or (node not in partition)
or (node in removed_nodes)
):
return True
if node in transparent_output_nodes:
return transparent_output_nodes[node]
if is_non_compute_node(node):
for output_n in node.users:
if not is_transparent_output_node(output_n, partition, removed_nodes):
if not is_transparent_output_node(
output_n, partition, removed_nodes
):
transparent_output_nodes[node] = False
return False
transparent_output_nodes[node] = True
@ -320,9 +355,12 @@ class CapabilityBasedPartitioner:
# the set.
remove_node: Set[Node] = set()
for node in partition.nodes:
if is_non_compute_node(node) and \
(is_transparent_input_node(node, set(partition.nodes), remove_node) or
is_transparent_output_node(node, set(partition.nodes), remove_node)):
if is_non_compute_node(node) and (
is_transparent_input_node(node, set(partition.nodes), remove_node)
or is_transparent_output_node(
node, set(partition.nodes), remove_node
)
):
remove_node.add(node)
if len(remove_node) != 0:

View File

@ -3,11 +3,12 @@ import abc
from collections import namedtuple
from typing import Optional
from torch.fx.graph_module import GraphModule
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
__all__ = ['PassResult', 'PassBase']
__all__ = ["PassResult", "PassBase"]
@compatibility(is_backward_compatible=False)
class PassResult(namedtuple("PassResult", ["graph_module", "modified"])):
@ -16,9 +17,11 @@ class PassResult(namedtuple("PassResult", ["graph_module", "modified"])):
graph_module: The modified graph module
modified: A flag for if the pass has modified the graph module
"""
def __new__(cls, graph_module, modified):
return super().__new__(cls, graph_module, modified)
@compatibility(is_backward_compatible=False)
class PassBase(abc.ABC):
"""

View File

@ -1,19 +1,21 @@
# mypy: allow-untyped-defs
import inspect
import logging
from queue import Queue
from functools import wraps
from queue import Queue
from typing import Callable, Dict, List
import torch.nn as nn
from torch.fx.graph_module import GraphModule
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
from torch.fx.passes.infra.pass_base import PassResult
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
__all__ = ['pass_result_wrapper', 'this_before_that_pass_constraint', 'PassManager']
__all__ = ["pass_result_wrapper", "this_before_that_pass_constraint", "PassManager"]
@compatibility(is_backward_compatible=False)
def pass_result_wrapper(fn: Callable) -> Callable:
@ -46,6 +48,7 @@ def pass_result_wrapper(fn: Callable) -> Callable:
return wrapped_fn
def _validate_pass_schedule_constraint(
constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
) -> None:
@ -59,6 +62,7 @@ def _validate_pass_schedule_constraint(
f" list."
)
def _topological_sort_passes(
passes: List[Callable], constraints: List[Callable]
) -> List[Callable]:
@ -75,7 +79,7 @@ def _topological_sort_passes(
return passes
# Construct a graph mapping nodes to a list of their users
graph: Dict[Callable, List[Callable]] = {p : [] for p in passes}
graph: Dict[Callable, List[Callable]] = {p: [] for p in passes}
indegree_map: Dict[Callable, int] = dict.fromkeys(passes, 0)
candidates: Queue = Queue()
for a in passes:
@ -108,11 +112,14 @@ def _topological_sort_passes(
# Check if there are unvisited nodes (aka cycles in the graph)
cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys()))
if len(cycle_passes) != 0:
error = f"Circular dependency detected within the following passes: {cycle_passes}"
error = (
f"Circular dependency detected within the following passes: {cycle_passes}"
)
raise RuntimeError(error)
return sorted_passes
@compatibility(is_backward_compatible=False)
def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable:
"""
@ -123,9 +130,7 @@ def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable
```
passes = [pass_b, pass_a]
constraints = [
this_before_that_pass_constraint(pass_a, pass_b)
]
constraints = [this_before_that_pass_constraint(pass_a, pass_b)]
```
Args:
@ -231,7 +236,9 @@ class PassManager:
sig = inspect.signature(check)
if len(list(sig.parameters.values())) != 1:
raise TypeError("PassManager check function should only take in one variable, a module")
raise TypeError(
"PassManager check function should only take in one variable, a module"
)
setattr(self, "check", check) # noqa: B010
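As a usage sketch for the constraint helper and manager above, with `pass_a`/`pass_b` as placeholder passes that leave the module untouched:

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import (
    PassManager,
    this_before_that_pass_constraint,
)

def pass_a(gm):
    return PassResult(gm, False)

def pass_b(gm):
    return PassResult(gm, False)

pm = PassManager(
    passes=[pass_a, pass_b],
    constraints=[this_before_that_pass_constraint(pass_a, pass_b)],
)
result = pm(symbolic_trace(lambda x: x + 1))  # PassResult; constraint already satisfied
```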

View File

@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.fx
from torch.fx._compatibility import compatibility
from torch.fx.node import map_arg
@ -21,6 +20,7 @@ from .tools_common import (
Tensors,
)
__all__ = [
"FxNetMinimizerBadModuleError",
"FxNetMinimizerRunFuncError",
@ -37,7 +37,6 @@ class FxNetMinimizerBadModuleError(Exception):
"""
@compatibility(is_backward_compatible=False)
class FxNetMinimizerRunFuncError(Exception):
"""
@ -45,7 +44,6 @@ class FxNetMinimizerRunFuncError(Exception):
"""
@compatibility(is_backward_compatible=False)
class FxNetMinimizerResultMismatchError(Exception):
"""
@ -53,7 +51,6 @@ class FxNetMinimizerResultMismatchError(Exception):
"""
@dataclass
class _MinimizerSettingBase:
"""
@ -109,14 +106,9 @@ class _MinimizerBase:
],
settings: _MinimizerSettingBase,
module_exporter: Optional[
Callable[
[Tensors, torch.fx.GraphModule, str],
None
]
] = None,
exclusion_fn: Optional[
Callable[[NodeList, int, int], None]
Callable[[Tensors, torch.fx.GraphModule, str], None]
] = None,
exclusion_fn: Optional[Callable[[NodeList, int, int], None]] = None,
):
assert isinstance(module, torch.fx.GraphModule)
@ -159,14 +151,18 @@ class _MinimizerBase:
self.a_outputs[name] = sample_input[i]
self.b_outputs[name] = sample_input[i]
def run_a(self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1) -> TensorOrTensors:
def run_a(
self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1
) -> TensorOrTensors:
"""
Run `mod` with `inputs` and generate output. The output will be compared with
output of run_b().
"""
raise RuntimeError("run_a() is not implemented.")
def run_b(self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1) -> TensorOrTensors:
def run_b(
self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1
) -> TensorOrTensors:
"""
Run `mod` with `inputs` and generate output. The output will be compared with
output of run_a().
@ -323,7 +319,7 @@ class _MinimizerBase:
split_module: torch.fx.GraphModule,
submod_name: str,
output_names: Names,
report_idx: int = -1
report_idx: int = -1,
):
"""
Run the submodule in `split_module` that has name `submod_name`
@ -388,10 +384,14 @@ class _MinimizerBase:
report.append(f"Result mismatch for {result_key}")
if self.module_exporter:
self.module_exporter(
a_input, submodule, str(result_key[0]) + "_cpu", # type: ignore[index]
a_input,
submodule,
str(result_key[0]) + "_cpu", # type: ignore[index]
)
self.module_exporter(
b_input, submodule, str(result_key[0]) + "_acc", # type: ignore[index]
b_input,
submodule,
str(result_key[0]) + "_acc", # type: ignore[index]
)
raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}")
@ -418,7 +418,7 @@ class _MinimizerBase:
self.reports.append(report)
report.append(f"Binary search iteration {self.iteration}")
report.append(
f"From node index {start_idx}:{first_node_name} to {end_idx-1}:{output_node_name}. "
f"From node index {start_idx}:{first_node_name} to {end_idx - 1}:{output_node_name}. "
f"Size of the interested node list is {len(nodes)}"
)
cur_nodes: NodeSet = set(nodes)
@ -428,7 +428,6 @@ class _MinimizerBase:
self._run_and_compare(split_module, submod_name, [output_node_name])
except (FxNetMinimizerRunFuncError, FxNetMinimizerResultMismatchError):
if len(nodes) == 1:
report.append(
f"This is the last node in the sub-module. "
@ -504,13 +503,13 @@ class _MinimizerBase:
split_module, submod_name = self._build_submodule(cur_nodes)
self._run_and_compare(split_module, submod_name, [node.name])
self.print_report(report)
except (FxNetMinimizerResultMismatchError):
except FxNetMinimizerResultMismatchError:
culprits.add(node)
report.append(f"Found culprit from numeric error: {node}")
self.print_report(report)
if not self.settings.find_all:
return culprits
except (FxNetMinimizerRunFuncError):
except FxNetMinimizerRunFuncError:
culprits.update(cur_nodes)
report.append(f"Found culprit from run error: {node}")
self.print_report(report)
@ -519,8 +518,9 @@ class _MinimizerBase:
return culprits
def _block_traverse_impl(self, nodes: NodeList, start_idx: int, end_idx: int, find_last_node: bool) -> int:
def _block_traverse_impl(
self, nodes: NodeList, start_idx: int, end_idx: int, find_last_node: bool
) -> int:
"""
Recursive block search implementation.
find_last_node: If True, search for the last node which results in numerics difference
@ -529,7 +529,7 @@ class _MinimizerBase:
report: List[str] = []
mid = (start_idx + end_idx) // 2
cur_nodes_list: NodeList = nodes[:mid + 1] if find_last_node else nodes[mid:]
cur_nodes_list: NodeList = nodes[: mid + 1] if find_last_node else nodes[mid:]
if self.exclusion_fn:
self.exclusion_fn(cur_nodes_list, -1, -1)
@ -561,16 +561,20 @@ class _MinimizerBase:
try:
split_module, submod_name = self._build_submodule(cur_nodes)
self._run_and_compare(split_module, submod_name, [last_node_name], report_idx)
self._run_and_compare(
split_module, submod_name, [last_node_name], report_idx
)
except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
report.append(f"Culprits found from node {first_node_name} to {last_node_name}.")
report.append(
f"Culprits found from node {first_node_name} to {last_node_name}."
)
if start_idx == mid:
report.extend(
[
"This is the last node in the sub-module. ",
"Search in the current branch is successful with node :",
f"{start_idx}, node name: {nodes[start_idx].name}."
f"{start_idx}, node name: {nodes[start_idx].name}.",
]
)
self.print_report(report)
@ -585,9 +589,13 @@ class _MinimizerBase:
if find_last_node:
return self._block_traverse_impl(nodes, start_idx, mid, find_last_node)
else:
return self._block_traverse_impl(nodes, mid + 1, end_idx, find_last_node)
return self._block_traverse_impl(
nodes, mid + 1, end_idx, find_last_node
)
else:
report.append(f"Culprits not found from node start to {mid}:{nodes[mid].name}.")
report.append(
f"Culprits not found from node start to {mid}:{nodes[mid].name}."
)
if start_idx == mid:
report.extend(
@ -607,12 +615,15 @@ class _MinimizerBase:
self.print_report(report)
if find_last_node:
return self._block_traverse_impl(nodes, mid + 1, end_idx, find_last_node)
return self._block_traverse_impl(
nodes, mid + 1, end_idx, find_last_node
)
else:
return self._block_traverse_impl(nodes, start_idx, mid, find_last_node)
def _block_traverse(self, nodes: NodeList, find_last_node: Optional[bool]) -> NodeSet:
def _block_traverse(
self, nodes: NodeList, find_last_node: Optional[bool]
) -> NodeSet:
"""
Traverse topologically sorted node list
Find minimum block (start_idx, end_idx) which contains the culprit
@ -639,10 +650,7 @@ class _MinimizerBase:
self.print_report(last_node_report)
end_idx = self._block_traverse_impl(nodes, start_idx, end_idx, True)
last_node_report.extend(
[
"Finish Pass 1",
f"Find end_idx = {end_idx}:{nodes[end_idx].name}"
]
["Finish Pass 1", f"Find end_idx = {end_idx}:{nodes[end_idx].name}"]
)
self.print_report(last_node_report)
@ -650,25 +658,28 @@ class _MinimizerBase:
if run_both or not find_last_node:
first_node_report = ["Start searching for first node in culprit"]
self.print_report(first_node_report)
start_idx = self._block_traverse_impl(nodes[0:end_idx + 1], start_idx, end_idx, False)
start_idx = self._block_traverse_impl(
nodes[0 : end_idx + 1], start_idx, end_idx, False
)
first_node_report.append("*" * 50)
self.reports.append(first_node_report)
first_node_report.extend(
[
"Finish Pass 2",
f"Find start_idx = {start_idx}:{nodes[start_idx].name}"
f"Find start_idx = {start_idx}:{nodes[start_idx].name}",
]
)
self.print_report(first_node_report)
# step 3: form module with minimum culprits
culprits.update(nodes[start_idx:end_idx + 1])
result_report = [f"Finish searching, found minimum block ({nodes[start_idx]},{nodes[end_idx]})"]
culprits.update(nodes[start_idx : end_idx + 1])
result_report = [
f"Finish searching, found minimum block ({nodes[start_idx]},{nodes[end_idx]})"
]
self.reports.append(result_report)
self.print_report(result_report)
return culprits
def _defined_traverse(self, nodes: NodeList) -> NodeSet:
"""
run user defined `nodes` and determine if it is a culprit.
@ -735,7 +746,9 @@ class _MinimizerBase:
return culprits
def _skip_traverse_impl(self, all_nodes: NodeList, start_idx: int, end_idx: int) -> NodeSet:
def _skip_traverse_impl(
self, all_nodes: NodeList, start_idx: int, end_idx: int
) -> NodeSet:
"""
Skip certain nodes in graph based on settings
"""
@ -754,19 +767,19 @@ class _MinimizerBase:
self.iteration += 1
report.append(f" Nodes block {self.iteration}.")
report.append(
f"From node index {start_idx} to {end_idx-1}. "
f"From node index {start_idx} to {end_idx - 1}. "
f"Size of the interested node list is {len(nodes)}"
)
try:
split_module, submod_name = self._build_submodule(cur_nodes)
self._run_and_compare(split_module, submod_name, [])
except (FxNetMinimizerResultMismatchError):
except FxNetMinimizerResultMismatchError:
culprits.update(cur_nodes)
report.append(f"Found culprit from numeric error: {cur_nodes}")
self.print_report(report)
return culprits
except (FxNetMinimizerRunFuncError):
except FxNetMinimizerRunFuncError:
culprits.update(cur_nodes)
report.append(f"Found culprit from run error: {cur_nodes}")
self.print_report(report)
@ -776,7 +789,6 @@ class _MinimizerBase:
self.print_report(report)
return set()
def _skip_traverse(self, all_nodes: NodeList, skip_nodes: List) -> NodeSet:
"""
Skip certain nodes in graph based on settings
@ -787,7 +799,7 @@ class _MinimizerBase:
culprits = set()
while idx < num_nodes:
node = all_nodes[idx]
if (node.name in skip_nodes): # skip the node
if node.name in skip_nodes: # skip the node
if idx > start_idx:
culprits = self._skip_traverse_impl(all_nodes, start_idx, idx)
start_idx = idx + 1
@ -797,8 +809,6 @@ class _MinimizerBase:
return culprits
def _collect_nodes(self, start: Optional[str], end: Optional[str]) -> NodeList:
"""
Collect nodes in the model that are between the nodes named `start` and `end`.
@ -911,8 +921,10 @@ class _MinimizerBase:
return self._accumulate_traverse(nodes)
if self.settings.traverse_method == "skip":
if (skip_nodes is None):
raise RuntimeError("'skip_nodes' can't be None when 'traverse_method' is 'skip'.")
if skip_nodes is None:
raise RuntimeError(
"'skip_nodes' can't be None when 'traverse_method' is 'skip'."
)
return self._skip_traverse(nodes, skip_nodes)
if self.settings.traverse_method == "defined":

View File

@ -5,11 +5,19 @@ import typing as t
import torch
import torch.fx
from torch.fx._compatibility import compatibility
from .shape_prop import TensorMetadata
from .tools_common import get_node_target, CALLABLE_NODE_OPS
from .tools_common import CALLABLE_NODE_OPS, get_node_target
__all__ = ['OperatorSupportBase', 'OperatorSupport', 'create_op_support', 'chain', 'OpSupports', 'any_chain']
__all__ = [
"OperatorSupportBase",
"OperatorSupport",
"create_op_support",
"chain",
"OpSupports",
"any_chain",
]
# fx.Node.target typename, as returned by `get_node_target()`
TargetTypeName = str
@ -28,6 +36,7 @@ SupportDict = t.Mapping[TargetTypeName, SupportedArgumentDTypes]
@compatibility(is_backward_compatible=False)
class OperatorSupportBase(abc.ABC):
"""Interface for determining if a fx.Node is supported by a backend"""
@abc.abstractmethod
def is_node_supported(
self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
@ -57,10 +66,7 @@ class OperatorSupport(OperatorSupportBase):
_support_dict: SupportDict
def __init__(
self,
support_dict: t.Optional[SupportDict] = None
):
def __init__(self, support_dict: t.Optional[SupportDict] = None):
self._support_dict = support_dict or {}
def is_node_supported(
@ -139,11 +145,13 @@ def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase
`IsNodeSupported` has the same call signature as
`OperatorSupportBase.is_node_supported`
"""
class FunctionalOperatorSupport(OperatorSupportBase):
def is_node_supported(
self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
) -> bool:
return is_node_supported(submodules, node)
return FunctionalOperatorSupport()
@ -153,11 +161,10 @@ def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
instance by evaluating each input `OperatorSupportBase` instance, and returns False if
any of it reports False.
"""
def _chain(submods, node) -> bool:
return all(
x.is_node_supported(submods, node)
for x in op_support
)
return all(x.is_node_supported(submods, node) for x in op_support)
return create_op_support(_chain)
@ -167,11 +174,10 @@ def any_chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
instance by evaluating each input `OperatorSupportBase` instance, and returns True if
any of it reports True.
"""
def _any_chain(submods, node) -> bool:
return any(
x.is_node_supported(submods, node)
for x in op_support
)
return any(x.is_node_supported(submods, node) for x in op_support)
return create_op_support(_any_chain)
@ -180,6 +186,7 @@ class OpSupports:
"""A set of atomic `OperatorSupportBase` instances that can be combined together
to form more complex operator support logic.
"""
@classmethod
def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase:
"""Report a node as non-supported, if any of its arguments is of dtype"""
@ -193,6 +200,7 @@ class OpSupports:
if arg_dtype == dtype:
return False
return True
return create_op_support(_decline_if_input_dtype)
@classmethod
@ -200,16 +208,22 @@ class OpSupports:
"""
If a node has a name that is in the disallow set, report it as non-supported.
"""
def _decline_if_node_in_names(
submodules: t.Mapping[str, torch.nn.Module],
node: torch.fx.Node,
) -> bool:
return node.name not in disallow_set
return create_op_support(_decline_if_node_in_names)
def _get_arg_dtype(arg: torch.fx.Node) -> t.Any:
assert isinstance(arg, torch.fx.Node)
tensor_meta = arg.meta.get("tensor_meta") # type: ignore[union-attr]
dtype = tensor_meta.dtype if isinstance(tensor_meta, TensorMetadata) else arg.meta["type"]
dtype = (
tensor_meta.dtype
if isinstance(tensor_meta, TensorMetadata)
else arg.meta["type"]
)
return dtype
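A short sketch of composing these pieces; `"_operator.add"` is the qualified target name the `+` operator gets in an FX graph, and `None` as the dict value means no dtype constraint:

```python
import torch
from torch.fx.passes.operator_support import chain, OperatorSupport, OpSupports

supported = chain(
    OpSupports.decline_if_input_dtype(torch.float64),       # veto float64 inputs
    OperatorSupport(support_dict={"_operator.add": None}),  # otherwise allow only add
)
# `supported.is_node_supported(submodules, node)` can now be handed to a partitioner.
```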

View File

@ -1,35 +1,59 @@
from torch.fx.graph_module import GraphModule
from typing import Any, Callable, Dict, List, Tuple, Type
import torch
import torch.nn as nn
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
__all__ = [
"default_matching",
"extract_attrs_for_lowering",
"lift_lowering_attrs_to_nodes",
]
__all__ = ['default_matching', 'extract_attrs_for_lowering', 'lift_lowering_attrs_to_nodes']
# Matching method matches the attribute name of current version to the attribute name of `target_version`
@compatibility(is_backward_compatible=False)
def default_matching(name: str, target_version: int) -> str:
"""Default matching method
"""
"""Default matching method"""
return name
# This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering.
# The first integer in the tuple is the version number of the nn.Module class when we create the parameter list.
# If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module.
module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]] = {
torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching),
torch.nn.modules.conv.Conv2d: (
1, ["weight", "bias", "kernel_size", "stride", "padding", "dilation", "groups", "padding_mode"], default_matching
1,
[
"weight",
"bias",
"kernel_size",
"stride",
"padding",
"dilation",
"groups",
"padding_mode",
],
default_matching,
),
torch.nn.modules.batchnorm.BatchNorm2d: (
2,
["weight", "bias", "running_mean", "running_var", "eps"],
default_matching,
),
torch.nn.modules.batchnorm.BatchNorm2d: (2, ["weight", "bias", "running_mean", "running_var", "eps"], default_matching),
torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching),
torch.nn.modules.pooling.MaxPool2d: (
1, ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"], default_matching
1,
["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"],
default_matching,
),
torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching),
}
@compatibility(is_backward_compatible=False)
def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]:
"""If `mod` is in `module_fetch_book`, fetch the mod's attributes that in the `module_fetch_book`
@ -41,21 +65,25 @@ def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]:
if type(mod) in module_fetch_book:
version, param_to_fetch, matching_method = module_fetch_book[type(mod)]
if version < mod._version:
raise RuntimeError(f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, "
"please upgrade the module_fetch_book, open an issue and @842974287 "
"or report a bug to AIACC team directly.")
raise RuntimeError(
f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, "
"please upgrade the module_fetch_book, open an issue and @842974287 "
"or report a bug to AIACC team directly."
)
for attr in param_to_fetch:
attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version))
else:
raise RuntimeError(f"{torch.typename(mod)} is not in the module_fetch_book yet, "
"please add it to the module_fetch_book, open an issue and @842974287 "
"or report a bug to AIACC team directly.")
raise RuntimeError(
f"{torch.typename(mod)} is not in the module_fetch_book yet, "
"please add it to the module_fetch_book, open an issue and @842974287 "
"or report a bug to AIACC team directly."
)
return attrs_for_lowering
@compatibility(is_backward_compatible=False)
def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None:
"""Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module.
"""
"""Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module."""
submodules = dict(fx_module.named_modules())
for node in fx_module.graph.nodes:
@ -63,4 +91,6 @@ def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None:
if isinstance(submodules[node.target], GraphModule):
lift_lowering_attrs_to_nodes(submodules[node.target])
else:
node.attrs_for_lowering = extract_attrs_for_lowering(submodules[node.target])
node.attrs_for_lowering = extract_attrs_for_lowering(
submodules[node.target]
)
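A minimal sketch of the lifting entry point above; the Sequential model is arbitrary, and both Linear and ReLU happen to have entries in `module_fetch_book`:

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes

gm = symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()))
lift_lowering_attrs_to_nodes(gm)
for node in gm.graph.nodes:
    if node.op == "call_module":
        print(node.target, node.attrs_for_lowering)  # e.g. weight/bias for Linear
```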

View File

@ -1,8 +1,9 @@
# mypy: allow-untyped-defs
import logging
from functools import wraps
from inspect import unwrap
from typing import Callable, List, Optional
import logging
logger = logging.getLogger(__name__)
@ -15,6 +16,7 @@ __all__ = [
"these_before_those_pass_constraint",
]
# for callables which modify object inplace and return something other than
# the object on which they act
def inplace_wrapper(fn: Callable) -> Callable:
@ -36,6 +38,7 @@ def inplace_wrapper(fn: Callable) -> Callable:
return wrapped_fn
def log_hook(fn: Callable, level=logging.INFO) -> Callable:
"""
Logs callable output.
@ -48,16 +51,13 @@ def log_hook(fn: Callable, level=logging.INFO) -> Callable:
```
def my_pass(d: Dict) -> bool:
changed = False
if 'foo' in d:
d['foo'] = 'bar'
if "foo" in d:
d["foo"] = "bar"
changed = True
return changed
pm = PassManager(
passes=[
inplace_wrapper(log_hook(my_pass))
]
)
pm = PassManager(passes=[inplace_wrapper(log_hook(my_pass))])
```
Args:
@ -67,6 +67,7 @@ def log_hook(fn: Callable, level=logging.INFO) -> Callable:
Returns:
wrapped_fn (Callable[Type1, Type2])
"""
@wraps(fn)
def wrapped_fn(gm):
val = fn(gm)
@ -76,8 +77,11 @@ def log_hook(fn: Callable, level=logging.INFO) -> Callable:
return wrapped_fn
def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None):
def loop_pass(
base_pass: Callable,
n_iter: Optional[int] = None,
predicate: Optional[Callable] = None,
):
"""
Convenience wrapper for passes which need to be applied multiple times.
@ -154,9 +158,7 @@ def these_before_those_pass_constraint(these: Callable, those: Callable):
loop_pass(pass_a, 5),
]
constraints = [
these_before_those_pass_constraint(pass_a, pass_b)
]
constraints = [these_before_those_pass_constraint(pass_a, pass_b)]
```
Args:

View File

@ -1,32 +1,38 @@
# mypy: allow-untyped-defs
import _operator
import itertools
from collections import defaultdict
from enum import Enum
from typing import Dict, Set
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx import Node
from torch.fx._compatibility import compatibility
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
from torch.utils._pytree import tree_map_only
from torch.utils import _pytree as pytree
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_map_only
import _operator
from enum import Enum
import itertools
from typing import Set, Dict
from collections import defaultdict
__all__ = ['reinplace']
__all__ = ["reinplace"]
class _ViewType(Enum):
NonView = 0
SingleOutputView = 1
MultiOutputView = 2
def _is_view_op(tgt):
if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
schema = tgt._schema
if len(schema.arguments) > 0:
first_arg = schema.arguments[0]
# check if op is a view
return first_arg.alias_info is not None and not first_arg.alias_info.is_write
return (
first_arg.alias_info is not None and not first_arg.alias_info.is_write
)
def _get_view_type(tgt) -> _ViewType:
if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
@ -36,7 +42,7 @@ def _get_view_type(tgt) -> _ViewType:
# check if op is a view
if first_arg.alias_info is not None and not first_arg.alias_info.is_write:
# check if op is a multi-output view
if '*' in first_arg.alias_info.after_set:
if "*" in first_arg.alias_info.after_set:
return _ViewType.MultiOutputView
else:
return _ViewType.SingleOutputView
@ -54,12 +60,11 @@ def _get_view_type(tgt) -> _ViewType:
# to sanity check that our aliasing information is correct.
@compatibility(is_backward_compatible=False)
class _FunctionalizationMetadataProp(torch.fx.Interpreter):
def run_node(self, node: Node):
self.node_counter += 1
result = super().run_node(node)
node.meta['fake_result'] = result
node.meta['node_idx'] = self.node_counter
node.meta["fake_result"] = result
node.meta["node_idx"] = self.node_counter
# (1) Update metadata with the list of nodes that are used by this node
# copy_() doesn't read from its first argument; it writes to it, overwriting previous data.
@ -69,11 +74,11 @@ class _FunctionalizationMetadataProp(torch.fx.Interpreter):
node_args = node_args[1:]
# (2) Update metadata to track aliasing information about view tensor nodes.
if node.op == 'call_function':
if node.op == "call_function":
view_type = _get_view_type(node.target)
if view_type == _ViewType.SingleOutputView:
assert isinstance(node.args[0], Node)
node.meta['view_of'] = node.args[0]
node.meta["view_of"] = node.args[0]
elif view_type == _ViewType.MultiOutputView:
self.multi_output_view_nodes[node] = node.args[0]
@ -95,38 +100,52 @@ class _FunctionalizationMetadataProp(torch.fx.Interpreter):
# Note: we could also track indexing info here for multi-output views.
# I don't think this metadata is strictly needed for de-functionalization.
assert isinstance(maybe_base_of_view, Node)
node.meta['view_of'] = maybe_base_of_view
node.meta["view_of"] = maybe_base_of_view
if 'view_of' in node.meta:
if "view_of" in node.meta:
# We're linking the current node with its first argument as views.
# Assert here that this is actually the case, and their storages are the same.
assert isinstance(node.meta['fake_result'], FakeTensor)
assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor)
view_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result']._typed_storage())
assert isinstance(node.meta["fake_result"], FakeTensor)
assert isinstance(node.meta["view_of"].meta["fake_result"], FakeTensor)
view_storage = StorageWeakRef(node.meta["fake_result"]._typed_storage())
base_storage = StorageWeakRef(
node.meta["view_of"].meta["fake_result"]._typed_storage()
)
assert view_storage == base_storage
return result
def propagate(self, *args):
self.multi_output_view_nodes = {}
self.node_counter = -1
with FakeTensorMode() as mode:
fake_args = [mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
fake_args = [
mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args
]
return super().run(*fake_args)
def _schemas_match(functional_schema, inplace_schema):
names_match = inplace_schema.name.endswith("_") and inplace_schema.name[:-1] == functional_schema.name
arg_types_match = len(functional_schema.arguments) == len(inplace_schema.arguments) and all(
a1.type == a2.type for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments))
names_match = (
inplace_schema.name.endswith("_")
and inplace_schema.name[:-1] == functional_schema.name
)
arg_types_match = len(functional_schema.arguments) == len(
inplace_schema.arguments
) and all(
a1.type == a2.type
for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments)
)
# for the inplace op, its first argument should be mutable
assert inplace_schema.arguments[0].alias_info is not None and inplace_schema.arguments[0].alias_info.is_write
assert (
inplace_schema.arguments[0].alias_info is not None
and inplace_schema.arguments[0].alias_info.is_write
)
# and its remaining arguments shouldn't be.
assert all(a.alias_info is None for a in inplace_schema.arguments[1:])
return names_match and arg_types_match
# TODO: this should be beefed up to be able to properly re-inplace with:
# - mutating ops (e.g. _fused_moving_avg_obs_fq_helper)
# - out= ops (e.g. angle -> angle.out)
@ -143,17 +162,20 @@ def _maybe_get_inplace_op(op):
op_namespace = op.__module__.split(".")[-1]
op_base_name = op.overloadpacket.__name__
maybe_namespace_module = getattr(torch.ops, op_namespace)
maybe_inplace_op = None if maybe_namespace_module is None else getattr(maybe_namespace_module, f'{op_base_name}_', None)
maybe_inplace_op = (
None
if maybe_namespace_module is None
else getattr(maybe_namespace_module, f"{op_base_name}_", None)
)
if maybe_inplace_op is None:
return None
inplace_overloads = [
getattr(maybe_inplace_op, overload_name) for overload_name in maybe_inplace_op.overloads()
getattr(maybe_inplace_op, overload_name)
for overload_name in maybe_inplace_op.overloads()
]
inplace_overloads_with_matching_schemas = [
f
for f in inplace_overloads
if _schemas_match(op._schema, f._schema)
f for f in inplace_overloads if _schemas_match(op._schema, f._schema)
]
# Just because foo() and foo_() are both existing operators,
# They aren't guaranteed to have compatible schemas.
@ -165,6 +187,7 @@ def _maybe_get_inplace_op(op):
inplace_op = inplace_overloads_with_matching_schemas[0]
return inplace_op
_VIEW_INVERSE_MAP = {
torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
@ -172,6 +195,7 @@ _VIEW_INVERSE_MAP = {
torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
}
# This function, given a set of (aliased) tensor nodes,
# Returns any nodes in the graph that *use* any of the aliases, that occur *after* op_index
# in the node ordering.
@ -186,17 +210,21 @@ def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
usage_nodes = t.users
for n in usage_nodes:
# We only care about usages after the current node
if 'node_idx' not in n.meta or n.meta['node_idx'] <= op_index:
if "node_idx" not in n.meta or n.meta["node_idx"] <= op_index:
continue
# We also don't care about intermediate view ops.
# They only matter if their output is then used elsewhere
# (either in an out-of-place op, or as an output to the function).
if n in tensor_aliases:
if isinstance(n.target, torch._ops.OpOverload) or n.target == _operator.getitem:
if (
isinstance(n.target, torch._ops.OpOverload)
or n.target == _operator.getitem
):
continue
nodes_used_after.add(n)
return nodes_used_after
# Given an op that we're trying to re-inplace, "b = foo(a)",
# And given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)"
# Then re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF:
@ -204,23 +232,27 @@ def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
# (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base"
# (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata
# as "alias"
def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Set[Node]) -> Set[Node]:
def _get_view_inverse_node_usages(
later_node_usages: Set[Node], self_aliases: Set[Node]
) -> Set[Node]:
def matching_view_metadata(a, b):
return a.size() == b.size() and \
a.stride() == b.stride() and \
a.storage_offset() == b.storage_offset()
return (
a.size() == b.size()
and a.stride() == b.stride()
and a.storage_offset() == b.storage_offset()
)
view_inverse_nodes = set()
# Go through them in node order, so we can see chains of view_scatter ops.
for n in sorted(later_node_usages, key=lambda x: x.meta['node_idx']):
for n in sorted(later_node_usages, key=lambda x: x.meta["node_idx"]):
if n.target not in _VIEW_INVERSE_MAP:
continue
base = n.args[0]
mutated_view = n.args[1]
assert isinstance(base, Node)
assert isinstance(base.meta['fake_result'], FakeTensor)
assert isinstance(base.meta["fake_result"], FakeTensor)
assert isinstance(mutated_view, Node)
assert isinstance(mutated_view.meta['fake_result'], FakeTensor)
assert isinstance(mutated_view.meta["fake_result"], FakeTensor)
# Check that this view_inverse op actually corresponds to doing the inverse
# of one of our existing self_alias nodes.
original_view = _VIEW_INVERSE_MAP[n.target]
@ -229,18 +261,21 @@ def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Se
# that was created from some op `alias = foo(base, args...)`
# such that the current _scatter op "inverts" that foo call.
# We can check that by running the original op again, and checking that the strides match.
if 'view_of' not in self_alias.meta:
if "view_of" not in self_alias.meta:
continue
self_alias_base = self_alias.meta['view_of']
self_alias_base = self_alias.meta["view_of"]
try:
# Here we're trying to re-use the args from the view_scatter call inside of the corresponding
# view op, which might throw. This just indicates that view_scatter op isn't a valid inverse
# of the current alias we're looking at.
view_replay_metadata = original_view(self_alias_base.meta['fake_result'], *n.args[2:], **n.kwargs)
expected_metadata = self_alias.meta['fake_result']
view_replay_metadata = original_view(
self_alias_base.meta["fake_result"], *n.args[2:], **n.kwargs
)
expected_metadata = self_alias.meta["fake_result"]
# If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace.
if matching_view_metadata(self_alias_base.meta['fake_result'], base.meta['fake_result']) and \
matching_view_metadata(view_replay_metadata, expected_metadata):
if matching_view_metadata(
self_alias_base.meta["fake_result"], base.meta["fake_result"]
) and matching_view_metadata(view_replay_metadata, expected_metadata):
view_inverse_nodes.add(n)
except Exception:
continue
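
(For intuition, the view/{view}_scatter pairing this function relies on can be reproduced with eager ops. The snippet below is an illustrative sketch only; it assumes `aten.slice`/`aten.slice_scatter` is one of the pairs recorded in `_VIEW_INVERSE_MAP`, whose contents are not shown in this hunk.)

```python
import torch

base = torch.zeros(4, 4)
# "alias = foo(base, args...)" -- a view into base
alias = torch.ops.aten.slice.Tensor(base, 0, 1, 3)
# Writing the unmodified view back through the matching {view}_scatter op
# reproduces base exactly, which is what "inverse" means for this pass.
roundtrip = torch.ops.aten.slice_scatter.default(base, alias, 0, 1, 3)
assert torch.equal(roundtrip, base)
```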
@ -471,25 +506,29 @@ def reinplace(gm, *sample_args):
# NOTE: later, we'll need to add an optimization for fully recovering performance
# on programs that mutate inputs.
input_storages = {
StorageWeakRef(
node.meta['fake_result']._typed_storage()
) for node in gm.graph.nodes if (node.op == 'placeholder' and isinstance(node.meta['fake_result'], torch.Tensor))}
StorageWeakRef(node.meta["fake_result"]._typed_storage())
for node in gm.graph.nodes
if (
node.op == "placeholder"
and isinstance(node.meta["fake_result"], torch.Tensor)
)
}
# We also need to know for a given node, what are all of its aliasing nodes.
storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set)
for n in gm.graph.nodes:
if 'fake_result' in n.meta:
if "fake_result" in n.meta:
# Tree-mapping because some ops can return lists of tensors.
def _add_to_map(x):
if isinstance(x, FakeTensor):
storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n)
pytree.tree_map_(_add_to_map, n.meta['fake_result'])
pytree.tree_map_(_add_to_map, n.meta["fake_result"])
# inplace-ify functional ops, subject to the constraints written below.
all_later_view_inverse_nodes_to_delete = set()
for node in gm.graph.nodes:
if node.op == 'call_function':
if node.op == "call_function":
            # Today, the re-inplace pass directly acts on:
# - functional ops with an inplace variant
# - {view}_scatter ops that can be potentially removed from the graph.
@ -512,8 +551,8 @@ def reinplace(gm, *sample_args):
# (We could potentially swizzle this into larger_tensor.add_(scalar_tensor),
# this is probably an optimization to revisit later).
self_arg = node.args[0]
self_flattened = pytree.tree_leaves(self_arg.meta['fake_result'])
node_flattened = pytree.tree_leaves(node.meta['fake_result'])
self_flattened = pytree.tree_leaves(self_arg.meta["fake_result"])
node_flattened = pytree.tree_leaves(node.meta["fake_result"])
self_has_wrong_metadata = False
if len(self_flattened) == len(node_flattened):
for self_meta, node_meta in zip(self_flattened, node_flattened):
@ -532,7 +571,9 @@ def reinplace(gm, *sample_args):
continue
# Step 1b: ensure that the op we're trying to re-inplace isn't a program input
self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
self_arg_storage = StorageWeakRef(
self_arg.meta["fake_result"]._typed_storage()
)
if self_arg_storage in input_storages:
# TODO: later, add the optimization for handling `copy_()` calls in the graph.
continue
@ -542,14 +583,20 @@ def reinplace(gm, *sample_args):
# so we prevent re-inplacing in this case.
continue
self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
self_arg_storage = StorageWeakRef(
self_arg.meta["fake_result"]._typed_storage()
)
self_aliases = storage_to_nodes[self_arg_storage]
# First, we find all later usages of any of the aliases of self_arg.
later_node_usages = _get_all_later_node_usages(self_aliases, node.meta['node_idx'])
later_node_usages = _get_all_later_node_usages(
self_aliases, node.meta["node_idx"]
)
# Then, we check if any of those later usages are actually view_scatter ops
# that are safe to fully remove.
later_view_inverse_node_usages = _get_view_inverse_node_usages(later_node_usages, self_aliases)
later_view_inverse_node_usages = _get_view_inverse_node_usages(
later_node_usages, self_aliases
)
# Step 2: Check to see if the input to the op is re-used later in the graph.
# If not (same goes for its aliases), then this op is safe to re-in place.
@ -565,7 +612,10 @@ def reinplace(gm, *sample_args):
# we would prefer to remove it from the graph entirely,
# and instead copy_() the slice directly into the larger tensor.
# See the description of the algorithm for a full example.
if node.target in _VIEW_INVERSE_MAP and node not in all_later_view_inverse_nodes_to_delete:
if (
node.target in _VIEW_INVERSE_MAP
and node not in all_later_view_inverse_nodes_to_delete
):
view_op = _VIEW_INVERSE_MAP[node.target]
# Before:
# base_updated = torch.ops.aten.slice_scatter.default(base, mutated_slice, args...)
@ -576,13 +626,23 @@ def reinplace(gm, *sample_args):
mutated_slice_node = node.args[1]
remaining_slice_args = node.args[2:]
slice_node = gm.graph.create_node(
'call_function', view_op, (self_arg,) + tuple(remaining_slice_args), node.kwargs)
"call_function",
view_op,
(self_arg,) + tuple(remaining_slice_args),
node.kwargs,
)
gm.graph.create_node(
'call_function', torch.ops.aten.copy_.default, (slice_node, mutated_slice_node,), {})
"call_function",
torch.ops.aten.copy_.default,
(
slice_node,
mutated_slice_node,
),
{},
)
# Add the slice_scatter node to our "nodes to delete" list.
all_later_view_inverse_nodes_to_delete.add(node)
else:
# Step 3b: Check to see if this operator has an inplace variant.
maybe_inplace_op = _maybe_get_inplace_op(node.target)
@ -597,19 +657,29 @@ def reinplace(gm, *sample_args):
# Hmm... morally I think we also want to keep the `fake_result` metadata
# up to date here, but I'm not sure how easy it is to do.
# Maybe it's fine to wait until the end of the pass to update it.
curr_node_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage])
storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage])
curr_node_storage = StorageWeakRef(
node.meta["fake_result"]._typed_storage()
)
storage_to_nodes[self_arg_storage].update(
storage_to_nodes[curr_node_storage]
)
storage_to_nodes[curr_node_storage].update(
storage_to_nodes[self_arg_storage]
)
                # Need to remember the view_scatter view nodes we found so we can remove them later.
all_later_view_inverse_nodes_to_delete.update(later_view_inverse_node_usages)
all_later_view_inverse_nodes_to_delete.update(
later_view_inverse_node_usages
)
# Step 4:
# Now that we've replaced b = a.foo() with a.foo_(),
# We need to replace any later usages of "b" with "a"
for old in itertools.chain([node], later_view_inverse_node_usages):
new = old.args[0]
nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']]
nodes_to_update = [
n for n in old.users if n.meta["node_idx"] > node.meta["node_idx"]
]
for node_to_update in nodes_to_update:
def replace_arg(a):
@ -618,21 +688,29 @@ def reinplace(gm, *sample_args):
return a
# First, replace usages of "b" with "a"
node_to_update.args = tree_map_only(Node, replace_arg, node_to_update.args)
node_to_update.kwargs = tree_map_only(Node, replace_arg, node_to_update.kwargs)
node_to_update.args = tree_map_only(
Node, replace_arg, node_to_update.args
)
node_to_update.kwargs = tree_map_only(
Node, replace_arg, node_to_update.kwargs
)
# Second, update our storage_to_nodes data structure.
old_flattened_res = pytree.tree_leaves(old.meta['fake_result'])
node_flattened_res = pytree.tree_leaves(node_to_update.meta['fake_result'])
old_flattened_res = pytree.tree_leaves(old.meta["fake_result"])
node_flattened_res = pytree.tree_leaves(
node_to_update.meta["fake_result"]
)
old_res_storage = {
StorageWeakRef(
x._typed_storage()
) for x in old_flattened_res if isinstance(x, FakeTensor)}
StorageWeakRef(x._typed_storage())
for x in old_flattened_res
if isinstance(x, FakeTensor)
}
node_res_storage = {
StorageWeakRef(
x._typed_storage()
) for x in node_flattened_res if isinstance(x, FakeTensor)}
StorageWeakRef(x._typed_storage())
for x in node_flattened_res
if isinstance(x, FakeTensor)
}
                            # This will happen if we're updating a view op, e.g.
                            # replacing
@ -644,12 +722,17 @@ def reinplace(gm, *sample_args):
# We're checking for len(...) == 1 here because all view ops are guaranteed to return either a single tensor,
# or multiple tensors that all share the same storage.
# We can't just check equality because we might encounter FX nodes that return zero tensor outputs.
if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage:
new_flattened_res = pytree.tree_leaves(new.meta['fake_result'])
if (
len(old_res_storage) == 1
and len(node_res_storage) == 1
and old_res_storage == node_res_storage
):
new_flattened_res = pytree.tree_leaves(new.meta["fake_result"])
new_res_storage = {
StorageWeakRef(
x._typed_storage()
) for x in new_flattened_res if isinstance(x, FakeTensor)}
StorageWeakRef(x._typed_storage())
for x in new_flattened_res
if isinstance(x, FakeTensor)
}
assert len(new_res_storage) == 1
(new_ref,) = new_res_storage
(node_ref,) = node_res_storage
@ -666,6 +749,5 @@ def reinplace(gm, *sample_args):
for to_delete in all_later_view_inverse_nodes_to_delete:
gm.graph.erase_node(to_delete)
gm.recompile()
return gm
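
(A minimal sketch of one way this pass can be driven. The import path of `reinplace` and the toy function `f` are assumptions not shown in these hunks; `make_fx` comes from `torch.fx.experimental.proxy_tensor`.)

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.reinplace import reinplace

def f(x):
    a = x.clone()
    b = a.add(1)  # functional op with an in-place variant (add_)
    return b

inpt = torch.ones(2)
gm = reinplace(make_fx(f)(inpt), inpt)
# After the pass, `a.add(1)` may become `a.add_(1)` because `a` has no later users;
# either way the program's semantics are preserved.
assert torch.equal(gm(inpt), f(inpt))
```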

View File

@ -1,17 +1,19 @@
# mypy: ignore-errors
import traceback
from typing import Any, Dict, NamedTuple, Optional, Tuple
import torch
import torch.fx
import traceback
from torch._dispatch.python import enable_python_dispatcher
from torch.fx.node import Node, map_aggregate
from typing import Any, Tuple, NamedTuple, Optional, Dict
from torch.fx._compatibility import compatibility
from torch._guards import detect_fake_mode
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx._compatibility import compatibility
from torch.fx.node import map_aggregate, Node
__all__ = ["TensorMetadata", "ShapeProp"]
__all__ = ['TensorMetadata', 'ShapeProp']
@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
@ -19,17 +21,20 @@ class TensorMetadata(NamedTuple):
# about a tensor within a PyTorch program.
# General Tensor metadata
shape : torch.Size
dtype : torch.dtype
requires_grad : bool
stride : Tuple[int, ...]
memory_format : Optional[torch.memory_format]
shape: torch.Size
dtype: torch.dtype
requires_grad: bool
stride: Tuple[int, ...]
memory_format: Optional[torch.memory_format]
# Quantization metadata
is_quantized : bool
is_quantized: bool
qparams: Dict[str, Any]
def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> TensorMetadata:
def _extract_tensor_metadata(
result: torch.Tensor, include_contiguity=True
) -> TensorMetadata:
"""
Extract a TensorMetadata NamedTuple describing `result`.
"""
@ -59,7 +64,11 @@ def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) ->
if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
qparams["scale"] = result.q_scale() # type: ignore[assignment]
qparams["zero_point"] = result.q_zero_point() # type: ignore[assignment]
elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}:
elif qscheme in {
torch.per_channel_affine,
torch.per_channel_affine_float_qparams,
torch.per_channel_symmetric,
}:
# In this branch, scale and zero_point are expected to be tensors,
# we store the values as immutable_list in TensorMetadata for
# easier serialization downstream
@ -68,7 +77,9 @@ def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) ->
qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment]
return TensorMetadata(
shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams)
shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams
)
@compatibility(is_backward_compatible=True)
class ShapeProp(torch.fx.Interpreter):
@ -117,12 +128,14 @@ class ShapeProp(torch.fx.Interpreter):
fake_mode (FakeTensorMode): A fake mode for copying the gm
"""
def __init__(self, gm, fake_mode=None):
super().__init__(gm)
if fake_mode is None:
fake_mode = detect_fake_mode()
if fake_mode is not None:
from torch._dynamo.utils import deepcopy_to_fake_tensor
# Note:
            # We need fake execution because the inputs are fake; however, we cannot fakify the module
# - because we need to write to the tensor_meta of the real module. So we fakify to
@ -140,7 +153,7 @@ class ShapeProp(torch.fx.Interpreter):
self.real_module = self.module
def run_node(self, n : Node) -> Any:
def run_node(self, n: Node) -> Any:
try:
if self.fake_module is not None:
# Hacky swap. Alternatively, we could do this with overriding
@ -157,8 +170,7 @@ class ShapeProp(torch.fx.Interpreter):
except Exception as e:
traceback.print_exc()
raise RuntimeError(
f"ShapeProp error for: node={n.format_node()} with "
f"meta={n.meta}"
f"ShapeProp error for: node={n.format_node()} with " f"meta={n.meta}"
) from e
found_tensor = False
@ -173,9 +185,9 @@ class ShapeProp(torch.fx.Interpreter):
meta = map_aggregate(result, extract_tensor_meta)
if found_tensor:
n.meta['tensor_meta'] = meta
n.meta["tensor_meta"] = meta
n.meta['type'] = type(result)
n.meta["type"] = type(result)
return result
def propagate(self, *args):
@ -190,7 +202,10 @@ class ShapeProp(torch.fx.Interpreter):
Any: The value returned from executing the Module
"""
if self.fake_mode is not None:
fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args]
fake_args = [
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
for t in args
]
else:
fake_args = args
return super().run(*fake_args)
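
(A minimal sketch of the usual ShapeProp workflow: run `propagate` with example inputs, then read the `tensor_meta` recorded on each node. The toy module `M` is hypothetical.)

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1.0

gm = symbolic_trace(M())
ShapeProp(gm).propagate(torch.randn(3, 4))
for node in gm.graph.nodes:
    tm = node.meta.get("tensor_meta")
    if tm is not None:
        # Each entry is a TensorMetadata NamedTuple (shape, dtype, stride, ...).
        print(node.name, tuple(tm.shape), tm.dtype)
```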

View File

@ -1,19 +1,20 @@
# mypy: allow-untyped-defs
import inspect
from typing import Any, Callable, Dict, List, Optional, Set
from collections import OrderedDict
import logging
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Set
import torch
from torch.fx._compatibility import compatibility
from torch.fx._utils import lazy_format_graph_code
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx._utils import lazy_format_graph_code
__all__ = ["Partition", "split_module"]
log = _LOGGER = logging.getLogger(__name__)
@compatibility(is_backward_compatible=True)
class Partition:
def __init__(self, name: str):
@ -146,9 +147,7 @@ def split_module(
log.debug(
"%s",
lazy_format_graph_code(
"pre split_module", m, colored=True
),
lazy_format_graph_code("pre split_module", m, colored=True),
)
def construct_graph(
@ -161,11 +160,20 @@ def split_module(
node.args[0] if len(node.args) > 0 else inspect.Signature.empty
)
if keep_original_node_name:
args = () if default_value is inspect.Signature.empty else (default_value,)
base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type) # type: ignore[arg-type]
args = (
() if default_value is inspect.Signature.empty else (default_value,)
)
base_mod_env[node.name] = base_mod_graph.create_node(
"placeholder",
node.name,
args=args, # type: ignore[arg-type]
type_expr=node.type,
)
else:
base_mod_env[node.name] = base_mod_graph.placeholder(
node.target, type_expr=node.type, default_value=default_value # type: ignore[arg-type]
node.target, # type: ignore[arg-type]
type_expr=node.type,
default_value=default_value,
)
base_mod_env[node.name].meta = node.meta.copy()
elif node.op == "get_attr":
@ -185,9 +193,7 @@ def split_module(
orig_nodes: Dict[str, Node] = {}
symbol_to_node: Dict[sympy.Symbol, Node] = {}
def record_cross_partition_use(
def_node: Node, use_node: Optional[Node]
): # noqa: B950
def record_cross_partition_use(def_node: Node, use_node: Optional[Node]):
from torch.fx.experimental.symbolic_shapes import free_symbols
defined = getattr(def_node, "_fx_partition", None)
@ -195,7 +201,10 @@ def split_module(
log.debug(
"record_cross_partition_use %s (%s) %s (%s)",
def_node.name, defined, use_node.name if use_node is not None else "-", used
def_node.name,
defined,
use_node.name if use_node is not None else "-",
used,
)
if defined != used:
@ -234,7 +243,9 @@ def split_module(
def instantiate_node_partition_mapping(node):
partition_name = str(split_callback(node))
log.debug("instantiate_node_partition_mapping %s (%s)", node.name, partition_name)
log.debug(
"instantiate_node_partition_mapping %s (%s)", node.name, partition_name
)
# add node to partitions
partition = partitions.get(partition_name)
@ -249,7 +260,7 @@ def split_module(
GLOBAL_STATE_NODES = [
torch.amp._enter_autocast,
torch.amp._exit_autocast,
torch._C._set_grad_enabled
torch._C._set_grad_enabled,
]
# For grad regions:
@ -280,10 +291,10 @@ def split_module(
# rely on later, but this needs some extra work. Quick fix first.
# See https://github.com/pytorch/pytorch/issues/130534
if (
(val := node.meta.get("example_value")) is not None and
isinstance(val, torch.SymInt) and
isinstance(s0 := val.node.expr, sympy.Symbol) and
s0 not in symbol_to_node
(val := node.meta.get("example_value")) is not None
and isinstance(val, torch.SymInt)
and isinstance(s0 := val.node.expr, sympy.Symbol)
and s0 not in symbol_to_node
):
symbol_to_node[val.node.expr] = node
@ -344,9 +355,10 @@ def split_module(
if assert_monotonically_increasing:
pid = split_callback(node)
assert highest_partition <= pid, \
("autocast or set_grad_enabled require monotonically increasing partitions:"
f"highest: {highest_partition}, this node's: {pid}")
assert highest_partition <= pid, (
"autocast or set_grad_enabled require monotonically increasing partitions:"
f"highest: {highest_partition}, this node's: {pid}"
)
highest_partition = pid
# do not capture cross-partition dependencies for global state nodes as they will be
@ -392,7 +404,9 @@ def split_module(
kwargs={},
type_expr=node.type,
)
new_node.meta = node.meta.copy() # is it really a good idea to copy this?
new_node.meta = (
node.meta.copy()
) # is it really a good idea to copy this?
partition.environment[node] = new_node
# add placeholders to partition inputs
@ -425,7 +439,9 @@ def split_module(
target_attr = m
for atom in target_atoms:
if not hasattr(target_attr, atom):
raise AttributeError(f"Operator target {node.target} not found!")
raise AttributeError(
f"Operator target {node.target} not found!"
)
target_attr = getattr(target_attr, atom)
# target = target_atoms[-1]
target = "_".join(target_atoms)
@ -467,7 +483,9 @@ def split_module(
kwargs={},
type_expr=exit_node.type,
)
new_node.meta = exit_node.meta.copy() # is it really a good idea to copy this?
new_node.meta = (
exit_node.meta.copy()
) # is it really a good idea to copy this?
# original module environment dict mapping node names to nodes
orig_mod_env: Dict[str, Node] = {}
@ -520,7 +538,9 @@ def split_module(
if keep_original_order:
# first get the attr nodes required by this partition
orig_mod_attr_nodes: List[Node] = [
orig_mod_env[key] for key in partition.inputs if key not in original_order
orig_mod_env[key]
for key in partition.inputs
if key not in original_order
]
for node in original_order:
@ -568,8 +588,6 @@ def split_module(
ret = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
log.debug(
"%s",
lazy_format_graph_code(
"post split_module", ret, colored=True
),
lazy_format_graph_code("post split_module", ret, colored=True),
)
return ret
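
(A sketch of driving `split_module` with a toy `split_callback`; the policy below, which puts the first call into partition 0 and everything after into partition 1, is hypothetical.)

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module

class M(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        x = x * 2
        return x.relu()

m = M()
gm = symbolic_trace(m)

seen_calls = 0
def split_callback(node):
    # Partition 0 holds the first call node; partition 1 holds the rest.
    global seen_calls
    if node.op in ("call_function", "call_method", "call_module"):
        seen_calls += 1
    return 0 if seen_calls <= 1 else 1

split = split_module(gm, m, split_callback)
print(split)  # contains submod_0 / submod_1, invoked from the top-level graph
x = torch.randn(3)
assert torch.allclose(split(x), m(x))
```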

View File

@ -10,6 +10,7 @@ from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module
from .tools_common import NodeList
__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"]

View File

@ -1,40 +1,44 @@
# mypy: allow-untyped-defs
import argparse
import copy
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import NamedTuple, Sequence, Iterable, Any, List, Dict, Optional, Tuple
import logging
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Sequence, Tuple
import torch
from torch.fx.passes.graph_manipulation import get_size_of_node
from torch.fx.node import map_arg
from torch.fx._compatibility import compatibility
from torch.fx.node import map_arg
from torch.fx.passes.graph_manipulation import get_size_of_node
from .operator_support import (
get_node_target,
OperatorSupportBase,
)
from .graph_drawer import FxGraphDrawer
from .operator_support import get_node_target, OperatorSupportBase
from .shape_prop import ShapeProp
from .split_utils import split_by_tags
from .tools_common import (
FxNetAccFusionsFinder,
CALLABLE_NODE_OPS,
Tensors,
FxNetAccFusionsFinder,
is_node_output_tensor,
NodeList,
NodeSet,
is_node_output_tensor,
Tensors,
)
__all__ = ['FxNetAccNodesFinder', 'FxNetSplitterInternalError', 'Subgraph', 'SplitResult', 'generate_inputs_for_submodules']
__all__ = [
"FxNetAccNodesFinder",
"FxNetSplitterInternalError",
"Subgraph",
"SplitResult",
"generate_inputs_for_submodules",
]
_LOGGER = logging.getLogger(__name__)
DEFAULT_MIN_ACC_MODULE_SIZE = 1
DEFAULT_SKIP_FUSION = False
DEFAULT_ALLOW_NON_TENSOR = False
class _SplitterSettingBase:
def __init__(
self,
@ -82,9 +86,15 @@ class _SplitterSettingBase:
)
args, _unknown = parser.parse_known_args()
self.min_acc_module_size: int = args.min_acc_module_size if args.min_acc_module_size else min_acc_module_size
self.min_acc_module_size: int = (
args.min_acc_module_size
if args.min_acc_module_size
else min_acc_module_size
)
self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion
self.allow_non_tensor: bool = args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor
self.allow_non_tensor: bool = (
args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor
)
self.max_acc_splits: int = max_acc_splits
@ -114,9 +124,7 @@ class FxNetAccNodesFinder:
self.allow_non_tensor = allow_non_tensor
self.acc_nodes: NodeSet = set()
def reduce_acc_nodes_non_tensor_input_helper(
self, cpu_worklist: NodeList
):
def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList):
"""
Transitively excludes nodes from ACC supported set.
For every node in the worklist:
@ -190,10 +198,12 @@ class FxNetAccNodesFinder:
return self.acc_nodes
@compatibility(is_backward_compatible=False)
class FxNetSplitterInternalError(Exception):
pass
@compatibility(is_backward_compatible=False)
@dataclass
class Subgraph:
@ -201,6 +211,7 @@ class Subgraph:
nodes: NodeList
device_ordinal: Optional[int] = None
@compatibility(is_backward_compatible=False)
class SplitResult(NamedTuple):
"""
@ -243,7 +254,9 @@ def generate_inputs_for_submodules(
submodule_to_names = {mod: name for name, mod in model.named_modules()}
def pre_forward(module, module_inputs):
results[submodule_to_names[module]] = copy.deepcopy(module_inputs) if deepcopy else module_inputs
results[submodule_to_names[module]] = (
copy.deepcopy(module_inputs) if deepcopy else module_inputs
)
for name, mod in model.named_modules():
if name in target_submodules:
@ -308,7 +321,7 @@ class _SplitterBase:
"""
# PCIe bandwidth for the backend, default to 100 GB/s
PCIe_BW = 100 * 2 ** 30
PCIe_BW = 100 * 2**30
def __init__(
self,
@ -335,7 +348,9 @@ class _SplitterBase:
self.settings = settings
self.operator_support = operator_support
self.sample_input = sample_input
self.acc_nodes = FxNetAccNodesFinder(self.module, self.operator_support, self.settings.allow_non_tensor)()
self.acc_nodes = FxNetAccNodesFinder(
self.module, self.operator_support, self.settings.allow_non_tensor
)()
if self.settings.skip_fusion:
self.fusions = {}
@ -357,11 +372,11 @@ class _SplitterBase:
# ===============================================================
def get_node_submodule_map(self) -> Dict[str, str]:
""" Returns a map from node name to submodule name, e.g.
node: main_module_impl_impl_over_arch_unary_multiple_embedding
_pooling_embedding_pooling_sparse_entity_equivalence_key
_proxy_embedding_bag
maps to submodule name of: _run_on_acc_1
"""Returns a map from node name to submodule name, e.g.
node: main_module_impl_impl_over_arch_unary_multiple_embedding
_pooling_embedding_pooling_sparse_entity_equivalence_key
_proxy_embedding_bag
maps to submodule name of: _run_on_acc_1
"""
return self._node_submodule_map
@ -411,9 +426,7 @@ class _SplitterBase:
return mod
def _find_culprit(
self, mod: torch.fx.GraphModule, inputs: Tensors
) -> str:
def _find_culprit(self, mod: torch.fx.GraphModule, inputs: Tensors) -> str:
"""
When an error occurs during lowering or running the lowered mod, we use this
        function to find culprits in the `mod` that cause the error.
@ -492,7 +505,9 @@ class _SplitterBase:
supported_nodes.append(node)
supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))
else:
unsupported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))
unsupported_node_types[target].add(
(arg_dtypes_tuple, kwarg_dtypes_tuple)
)
if dump_graph:
self._draw_graph_based_on_node_support(self.module, supported_nodes)
@ -527,7 +542,11 @@ class _SplitterBase:
reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"
for i, subgraph in enumerate(subgraphs):
reports += f"_run_on_acc_{i}: " if subgraph.is_acc else f"{self.non_acc_submodule_name}{i}: "
reports += (
f"_run_on_acc_{i}: "
if subgraph.is_acc
else f"{self.non_acc_submodule_name}{i}: "
)
reports += f"{len(subgraph.nodes)} node(s)\n"
self.tag(subgraphs)
@ -535,9 +554,7 @@ class _SplitterBase:
split_mod.eval()
if dump_graph:
drawer = FxGraphDrawer(
split_mod, "preview", ignore_getattr=True
)
drawer = FxGraphDrawer(split_mod, "preview", ignore_getattr=True)
dot_graphs = drawer.get_all_dot_graphs()
for name, dot_graph in dot_graphs.items():
# pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`.
@ -564,9 +581,7 @@ class _SplitterBase:
handle.remove()
return sub_inputs
submod_inputs = get_submod_inputs(
split_mod, submod, self.sample_input
)
submod_inputs = get_submod_inputs(split_mod, submod, self.sample_input)
ShapeProp(submod).propagate(*submod_inputs)
total_input_bytes = 0
@ -649,9 +664,7 @@ class _SplitterBase:
return result
def update_reverse_deps_for_fusions(
self, deps: Dict[torch.fx.Node, NodeSet]
):
def update_reverse_deps_for_fusions(self, deps: Dict[torch.fx.Node, NodeSet]):
processed_node = set()
for node, fusion in self.fusions.items():
@ -853,7 +866,11 @@ class _SplitterBase:
def tag(self, subgraphs: List[Subgraph]):
self.tags = []
for subgraph in subgraphs:
tag = f"_run_on_acc_{len(self.tags)}" if subgraph.is_acc else f"{self.non_acc_submodule_name}{len(self.tags)}"
tag = (
f"_run_on_acc_{len(self.tags)}"
if subgraph.is_acc
else f"{self.non_acc_submodule_name}{len(self.tags)}"
)
self.tags.append(tag)
for node in subgraph.nodes:
if hasattr(node, "tag"):
@ -863,7 +880,9 @@ class _SplitterBase:
self._node_submodule_map[node.name] = tag
def split(self, remove_tag: bool = False) -> torch.fx.GraphModule:
split_module = split_by_tags(self.module, self.tags, return_tuple=self._return_tuple)
split_module = split_by_tags(
self.module, self.tags, return_tuple=self._return_tuple
)
if remove_tag:
for node in self.module.graph.nodes:
if hasattr(node, "tag"):
@ -875,7 +894,9 @@ class _SplitterBase:
subgraphs = self.remove_small_acc_subgraphs(subgraphs)
acc_subgraphs_count = len([s for s in subgraphs if s.is_acc])
non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count
print(f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs")
print(
f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs"
)
self.tag(subgraphs)
return self.split()
@ -894,5 +915,7 @@ class _SplitterBase:
"result in performance issues."
)
submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names)
submodule_inputs = generate_inputs_for_submodules(
split_module, self.sample_input, submodule_names
)
return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name)

View File

@ -26,9 +26,7 @@ class TestPassManager(unittest.TestCase):
def test_these_before_those_pass_constraint(self) -> None:
passes = [lambda x: 2 * x for _ in range(10)]
constraint = these_before_those_pass_constraint(passes[-1], passes[0])
pm = PassManager(
[inplace_wrapper(p) for p in passes]
)
pm = PassManager([inplace_wrapper(p) for p in passes])
# add unfulfillable constraint
pm.add_constraint(constraint)
@ -46,7 +44,7 @@ class TestPassManager(unittest.TestCase):
pm1.add_pass(p)
pm1.add_constraint(constraint)
output1 = pm1(1)
self.assertEqual(output1, 2 ** 3)
self.assertEqual(output1, 2**3)
passes = [lambda x: 3 * x for _ in range(3)]
constraint = these_before_those_pass_constraint(passes[0], passes[1])
@ -55,4 +53,4 @@ class TestPassManager(unittest.TestCase):
pm2.add_pass(p)
pm2.add_constraint(constraint)
output2 = pm2(1)
self.assertEqual(output2, 3 ** 3)
self.assertEqual(output2, 3**3)

View File

@ -1,15 +1,22 @@
# mypy: allow-untyped-defs
from typing import List, Tuple, Union, Dict, Any, Set, Mapping, Optional
import collections
from dataclasses import dataclass
import operator
from dataclasses import dataclass
from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union
import torch
import torch.fx
from torch.fx.node import _get_qualified_name
from torch.fx._compatibility import compatibility
from torch.fx.node import _get_qualified_name
__all__ = ['get_acc_ops_name', 'get_node_target', 'is_node_output_tensor', 'FxNetAccFusionsFinder', 'legalize_graph']
__all__ = [
"get_acc_ops_name",
"get_node_target",
"is_node_output_tensor",
"FxNetAccFusionsFinder",
"legalize_graph",
]
Tensors = Union[Tuple[torch.Tensor], List[torch.Tensor]]
TensorOrTensors = Union[torch.Tensor, Tensors]
@ -26,12 +33,16 @@ def get_acc_ops_name(k):
elif k.__module__ and "acc_ops" in k.__module__:
return f"acc_ops.{k.__name__}"
else:
module = k.__module__.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module
module = k.__module__.replace(
"torch._ops", "torch.ops"
) # WAR for bug in how torch.ops assigns module
return f"{module if module else ''}.{k.__name__}"
@compatibility(is_backward_compatible=False)
def get_node_target(submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> str:
def get_node_target(
submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
) -> str:
"""
Given a `node` returns its target typename.
@ -66,6 +77,7 @@ def get_node_target(submodules: Mapping[str, torch.nn.Module], node: torch.fx.No
assert isinstance(node.target, str)
return node.target
@compatibility(is_backward_compatible=False)
def is_node_output_tensor(node: torch.fx.Node) -> bool:
"""Checks if the node output produces a Tensor or not.
@ -77,6 +89,7 @@ def is_node_output_tensor(node: torch.fx.Node) -> bool:
type_ = node.meta.get("type", None)
return type_ is not None and issubclass(type_, torch.Tensor)
@compatibility(is_backward_compatible=False)
class FxNetAccFusionsFinder:
"""
@ -297,7 +310,9 @@ def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
# If the new graph's size is not as large as the old one, then there must be
# a cycle (i.e. some node's dependencies were not satisfied.)
if len(new_graph.nodes) < len(gm.graph.nodes):
raise RuntimeError(f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}")
raise RuntimeError(
f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}"
)
new_graph._codegen = gm.graph._codegen
gm.graph = new_graph
return gm
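
(A contrived sketch of why `legalize_graph` exists: after a manual graph edit leaves a node physically after its user, the pass re-sorts the node list topologically. The function `f` and the `doubled` node are hypothetical.)

```python
import operator
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.tools_common import legalize_graph

def f(x):
    return torch.relu(x + 1)

gm = symbolic_trace(f)
graph = gm.graph
relu_node = next(n for n in graph.nodes if n.target is torch.relu)
output_node = next(n for n in graph.nodes if n.op == "output")

# Append a new node *after* the output node that consumes it -- the node list
# is now out of dependency order.
with graph.inserting_after(output_node):
    doubled = graph.call_function(operator.mul, (relu_node, 2))
output_node.args = (doubled,)

gm = legalize_graph(gm)  # rebuilds the graph with nodes in topological order
gm.recompile()
assert torch.allclose(gm(torch.ones(3)), torch.relu(torch.ones(3) + 1) * 2)
```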

View File

@ -1 +1 @@
from .common import lift_subgraph_as_module, HolderModule, compare_graphs
from .common import compare_graphs, HolderModule, lift_subgraph_as_module

View File

@ -3,7 +3,6 @@ from typing import Dict, Tuple
from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
from torch.nn import Module

View File

@ -1,15 +1,16 @@
# mypy: allow-untyped-defs
import copy
from queue import SimpleQueue
from typing import List, Dict, Optional as _Optional, Tuple
from typing import Dict, List, Optional as _Optional, Tuple
import torch.fx
from torch.fx.graph_module import GraphModule
from torch.fx.graph import Graph
from torch.fx.node import Node
from torch.fx.passes.tools_common import NodeList, NodeSet, legalize_graph
from torch.fx.passes.utils import lift_subgraph_as_module
from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx.passes.tools_common import legalize_graph, NodeList, NodeSet
from torch.fx.passes.utils import lift_subgraph_as_module
@compatibility(is_backward_compatible=False)
def topo_sort(nodes: NodeList) -> NodeList:
@ -35,7 +36,9 @@ def topo_sort(nodes: NodeList) -> NodeList:
if indegree_map[n] == 0:
candidates.put(n)
assert len(nodes) == len(sorted_nodes), "topological sorted nodes doesn't have same length as input nodes"
assert len(nodes) == len(
sorted_nodes
), "topological sorted nodes doesn't have same length as input nodes"
return sorted_nodes
@ -96,7 +99,6 @@ def fuse_as_graphmodule(
module_name: str,
partition_lookup_table: _Optional[Dict[Node, None]] = None,
) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]:
"""
Fuse nodes in graph_module into a GraphModule.
@ -121,9 +123,13 @@ def fuse_as_graphmodule(
# assumption: nodes are already sorted in topo order
for node in nodes:
assert node.graph.owning_module is gm, f"{node} doesn't belong to passed in graph module {gm._get_name()}"
assert (
node.graph.owning_module is gm
), f"{node} doesn't belong to passed in graph module {gm._get_name()}"
assert not node._erased, f"{node} has been removed from owning graph"
assert node in gm.graph._find_nodes_lookup_table, f"{node} is not found in graph module {gm._get_name()}"
assert (
node in gm.graph._find_nodes_lookup_table
), f"{node} is not found in graph module {gm._get_name()}"
# validates partition doesn't introduce dependency circles in the graph
assert validate_partition(nodes), "Invalid partition, found dependency cycles"
@ -134,8 +140,10 @@ def fuse_as_graphmodule(
subgraph = Graph()
node_to_placeholder: Dict[Node, Node] = {} # mapping of nodes from old graph to placeholder in new graph
node_map: Dict[Node, Node] = {} # mapping of nodes from old graph to new graph
node_to_placeholder: Dict[
Node, Node
] = {} # mapping of nodes from old graph to placeholder in new graph
node_map: Dict[Node, Node] = {} # mapping of nodes from old graph to new graph
# handles inputs through graph.node_copy's arg_transform functions
def remap_inputs(x):
@ -184,7 +192,9 @@ def fuse_as_graphmodule(
# lint to ensure correctness
subgraph.lint()
fused_gm: GraphModule
fused_gm, _ = lift_subgraph_as_module(gm, subgraph, comp_name="", class_name=module_name)
fused_gm, _ = lift_subgraph_as_module(
gm, subgraph, comp_name="", class_name=module_name
)
# sub_gm's input nodes in the original module
original_inputs: Tuple[Node, ...] = tuple(node_to_placeholder.keys())
@ -196,16 +206,18 @@ def fuse_as_graphmodule(
@compatibility(is_backward_compatible=False)
def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node, ...]):
def insert_subgm(
gm: GraphModule,
sub_gm: GraphModule,
orig_inputs: Tuple[Node, ...],
orig_outputs: Tuple[Node, ...],
):
# add sub_gm into gm
submodule_name = sub_gm.__class__.__name__
gm.add_submodule(submodule_name, sub_gm)
# Create a call_module node in main graph.
module_node = gm.graph.call_module(
submodule_name,
args=orig_inputs,
kwargs=None)
module_node = gm.graph.call_module(submodule_name, args=orig_inputs, kwargs=None)
if len(orig_outputs) == 1:
# main_remapping[comp.orig_outputs[0]] = module_node
@ -216,24 +228,30 @@ def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node,
proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index]
orig_output.replace_all_uses_with(proxy_out, propagate_meta=True)
module_node.meta["val"] = tuple(orig_output.meta.get("val", None) for orig_output in orig_outputs)
module_node.meta["val"] = tuple(
orig_output.meta.get("val", None) for orig_output in orig_outputs
)
return gm
@compatibility(is_backward_compatible=False)
def erase_nodes(gm: GraphModule, nodes: NodeList):
    # erase original nodes in reverse topological order
for node in reversed(nodes):
gm.graph.erase_node(node)
@compatibility(is_backward_compatible=False)
def fuse_by_partitions(gm: GraphModule, partitions: List[Dict[Node, None]], prefix: str = "fused_") -> GraphModule:
def fuse_by_partitions(
gm: GraphModule, partitions: List[Dict[Node, None]], prefix: str = "fused_"
) -> GraphModule:
for partition_id, partition in enumerate(partitions):
sorted_nodes = topo_sort(list(partition))
submodule_name = prefix + str(partition_id)
sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, submodule_name, partition)
sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(
gm, sorted_nodes, submodule_name, partition
)
insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)

View File

@ -10,6 +10,7 @@ import torch
from torch.fx import Graph, Node
from torch.fx._compatibility import compatibility
__all__ = ["SubgraphMatcher", "InternalMatch"]

View File

@ -1,19 +1,21 @@
from dataclasses import dataclass, field
from torch.fx.graph import Graph
from torch.fx.node import Node
from torch.fx._compatibility import compatibility
from typing import Dict, List, Any, Type, Optional, Callable
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Type
from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph
from torch.fx.node import Node
__all__ = ['get_source_partitions', 'check_subgraphs_connected', 'SourcePartition']
__all__ = ["get_source_partitions", "check_subgraphs_connected", "SourcePartition"]
# Set `PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
def _init_logger() -> logging.Logger:
logger = logging.getLogger(__name__)
level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper()
level = os.environ.get("PYTORCH_MATCHER_LOGLEVEL", "WARNING").upper()
logger.setLevel(level)
console = logging.StreamHandler()
formatter = logging.Formatter("%(filename)s > %(message)s")
@ -24,6 +26,7 @@ def _init_logger() -> logging.Logger:
logger.propagate = False
return logger
logger = _init_logger()
@ -77,8 +80,9 @@ def get_source_partitions(
# be different from "source_fn_stack", for example for the add_ node
# decomposed from batch norm. We should remove the check on "source_fn_stack"
# after we fix "torch_fn". T199561090
if ((source_fn_st := node.meta.get("source_fn_stack", None)) is None and
(torch_fn := node.meta.get("torch_fn", None)) is not None):
if (source_fn_st := node.meta.get("source_fn_stack", None)) is None and (
torch_fn := node.meta.get("torch_fn", None)
) is not None:
node_fqn, source_fn = torch_fn
source_fn_name = source_fn.split(".")[1]
if source_fn_name in wanted_sources:
@ -86,7 +90,6 @@ def get_source_partitions(
partition = diff_modules.setdefault(node_fqn, [])
partition.append(node)
if (source_fn_st := node.meta.get("source_fn_stack", None)) is not None:
source_fn = source_fn_st[-1]
if source_fn[1] in wanted_sources:
@ -140,7 +143,9 @@ def get_source_partitions(
@compatibility(is_backward_compatible=False) # type: ignore[misc]
def check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool:
def check_subgraphs_connected(
subgraph1: SourcePartition, subgraph2: SourcePartition
) -> bool:
"""
Given two subgraphs A and B (in the form of a list of nodes), checks if
A has nodes connecting to at least one node in B -- aka there exists a node

View File

@ -1,29 +1,37 @@
# mypy: ignore-errors
import enum
import dis
import copy
import sys
import torch
import inspect
import operator
import collections
import copy
import dis
import enum
import inspect
import logging
import operator
import sys
from dataclasses import fields, is_dataclass
from typing import Any, Callable, Dict, Iterator, Optional, OrderedDict, Tuple
from dataclasses import is_dataclass, fields
from .graph import magic_methods, reflectable_magic_methods, Graph
from torch.utils._traceback import CapturedTraceback
from typing import Tuple, Dict, OrderedDict, Optional, Any, Iterator, Callable
from .node import Target, Node, Argument, base_types, map_aggregate
from ._compatibility import compatibility
from .operator_schemas import check_for_mutable_operation
import torch
import torch.fx.traceback as fx_traceback
from torch.utils._traceback import CapturedTraceback
__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError',
'Proxy', 'MetaProxy', 'Attribute', 'ParameterProxy', 'Scope',
'ScopeContextManager']
from ._compatibility import compatibility
from .graph import Graph, magic_methods, reflectable_magic_methods
from .node import Argument, base_types, map_aggregate, Node, Target
from .operator_schemas import check_for_mutable_operation
__all__ = [
"TracerBase",
"GraphAppendingTracer",
"TraceError",
"Proxy",
"MetaProxy",
"Attribute",
"ParameterProxy",
"Scope",
"ScopeContextManager",
]
log = logging.getLogger(__name__)
@ -31,7 +39,7 @@ log = logging.getLogger(__name__)
@compatibility(is_backward_compatible=False)
class Scope:
""" Scope object that records the module path and the module type
"""Scope object that records the module path and the module type
of a module. Scope is used to track the information of the module
that contains a Node in a Graph of GraphModule. For example::
@ -41,6 +49,7 @@ class Scope:
# scope for this would be (module_path="sub", module_type=Sub)
return x.transpose(1, 2)
class M(torch.nn.Module):
def __init__(self) -> None:
self.sub = Sub()
@ -62,7 +71,7 @@ class Scope:
@compatibility(is_backward_compatible=False)
class ScopeContextManager:
""" A context manager to track the Scope of Node during symbolic tracing.
"""A context manager to track the Scope of Node during symbolic tracing.
When entering a forward function of a Module, we'll update the scope information of
the current module, and when we exit, we'll restore the previous scope information.
"""
@ -102,28 +111,28 @@ _COPY_META_FIELDS = [
"quantization_tag", # TODO deprecated
"_numeric_debug_handle", # TODO deprecated
"custom",
"partitioner_tag"
"partitioner_tag",
]
@compatibility(is_backward_compatible=True)
class TracerBase:
graph: Graph
record_stack_traces : bool = False
record_stack_traces: bool = False
# Feature flag for mutable schema checking
    # Enabled by default in 1.12
check_mutable_operations : bool = False
check_mutable_operations: bool = False
# Feature flag for assert tracing
trace_asserts : bool = False
trace_asserts: bool = False
# Feature flag for proxying accesses to buffer values
proxy_buffer_attributes : bool = False
proxy_buffer_attributes: bool = False
# Name of the function to be traced. It will only be used when
# ``root`` is an instance of ``nn.Module``
traced_func_name: str = "forward"
# Maps the containing module's name to the operator name
scope : Scope
scope: Scope
# Records the module call stack
module_stack: OrderedDict[str, Tuple[str, Any]]
@ -132,9 +141,15 @@ class TracerBase:
node_name_to_scope: Dict[str, Tuple[str, type]]
@compatibility(is_backward_compatible=True)
def create_node(self, kind : str, target : Target,
args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
type_expr : Optional[Any] = None) -> Node:
def create_node(
self,
kind: str,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
) -> Node:
"""
Inserts a graph node given target, args, kwargs, and name.
@ -143,7 +158,7 @@ class TracerBase:
want to disallow in-place operations from being recorded.
"""
if kind == 'call_function' and self.check_mutable_operations:
if kind == "call_function" and self.check_mutable_operations:
check_for_mutable_operation(target, args, kwargs)
node = self.graph.create_node(kind, target, args, kwargs, name, type_expr)
@ -182,20 +197,27 @@ class TracerBase:
node.meta["seq_nr"] = new_seq_nr
elif self.module_stack:
node.meta['nn_module_stack'] = copy.copy(self.module_stack)
node.meta["nn_module_stack"] = copy.copy(self.module_stack)
log.debug("create_node %s", node)
return node
@compatibility(is_backward_compatible=True)
def proxy(self, node: Node) -> 'Proxy':
def proxy(self, node: Node) -> "Proxy":
return Proxy(node, self)
@compatibility(is_backward_compatible=True)
def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
name: Optional[str] = None, type_expr : Optional[Any] = None,
proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
'''
def create_proxy(
self,
kind: str,
target: Target,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
proxy_factory_fn: Callable[[Node], "Proxy"] = None,
):
"""
Create a Node from the given arguments, then return the Node
wrapped in a Proxy object.
@ -203,7 +225,7 @@ class TracerBase:
represents the parameter of a function. If we need to encode
a default parameter, we use the ``args`` tuple. ``args`` is
otherwise empty for ``placeholder`` Nodes.
'''
"""
args_ = self.create_arg(args)
kwargs_ = self.create_arg(kwargs)
@ -218,8 +240,7 @@ class TracerBase:
proxy = proxy_factory_fn(node)
if self.record_stack_traces and not proxy.node.stack_trace:
proxy.node.stack_trace = ''.join(CapturedTraceback.extract().format())
proxy.node.stack_trace = "".join(CapturedTraceback.extract().format())
return proxy
@ -233,20 +254,23 @@ class TracerBase:
# the user code during tracing.
frame = inspect.currentframe()
pt_files = ['torch/fx/proxy.py',
'torch/fx/_symbolic_trace.py',
'torch/fx/experimental/proxy_tensor.py',
'torch/_ops.py',
'torch/_tensor.py',
'torch/utils/_python_dispatch.py',
'torch/_prims_common/wrappers.py',
'torch/_refs/__init__.py',
'torch/_refs/nn/functional/__init__.py',
'torch/utils/_stats.py',
]
pt_files = [
"torch/fx/proxy.py",
"torch/fx/_symbolic_trace.py",
"torch/fx/experimental/proxy_tensor.py",
"torch/_ops.py",
"torch/_tensor.py",
"torch/utils/_python_dispatch.py",
"torch/_prims_common/wrappers.py",
"torch/_refs/__init__.py",
"torch/_refs/nn/functional/__init__.py",
"torch/utils/_stats.py",
]
while frame:
frame = frame.f_back
if frame and all(not frame.f_code.co_filename.endswith(file) for file in pt_files):
if frame and all(
not frame.f_code.co_filename.endswith(file) for file in pt_files
):
break
if not frame:
@ -264,11 +288,11 @@ class TracerBase:
"""
if isinstance(a, Proxy):
return a.node # most common arg type goes first
elif hasattr(a, '__fx_create_arg__'):
elif hasattr(a, "__fx_create_arg__"):
return a.__fx_create_arg__(self)
# aggregates
elif isinstance(a, tuple):
if hasattr(a, '_fields'):
if hasattr(a, "_fields"):
# NamedTuple constructors don't seem to like getting a generator
# expression as an argument to their constructor, so build this
# intermediate tuple and unpack it into the NamedTuple constructor
@ -278,10 +302,13 @@ class TracerBase:
elif isinstance(a, list):
return [self.create_arg(elem) for elem in a]
elif isinstance(a, dict):
def no_node(arg):
if isinstance(arg, Node):
raise RuntimeError("Keys for dictionaries used as an argument cannot contain a "
f"Node. Got key: {k}")
raise RuntimeError(
"Keys for dictionaries used as an argument cannot contain a "
f"Node. Got key: {k}"
)
r = {}
for k, v in a.items():
@ -294,16 +321,27 @@ class TracerBase:
r[k] = self.create_arg(v)
return r
elif isinstance(a, slice):
return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))
return slice(
self.create_arg(a.start),
self.create_arg(a.stop),
self.create_arg(a.step),
)
elif isinstance(a, range):
return range(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))
return range(
self.create_arg(a.start),
self.create_arg(a.stop),
self.create_arg(a.step),
)
elif isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
return a
elif is_dataclass(a):
kwargs = {field.name: self.create_arg(getattr(a, field.name)) for field in fields(a)}
kwargs = {
field.name: self.create_arg(getattr(a, field.name))
for field in fields(a)
}
return self.create_node("call_function", a.__class__, (), kwargs)
elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...:
@ -312,37 +350,41 @@ class TracerBase:
raise NotImplementedError(f"argument of type: {type(a)}")
@compatibility(is_backward_compatible=True)
def to_bool(self, obj: 'Proxy') -> bool:
def to_bool(self, obj: "Proxy") -> bool:
"""Called when a proxy object is being converted to a boolean, such as
when used in control flow. Normally we don't know what to do because
we don't know the value of the proxy, but a custom tracer can attach more
information to the graph node using create_node and can choose to return a value.
"""
raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
raise TraceError(
"symbolically traced variables cannot be used as inputs to control flow"
)
@compatibility(is_backward_compatible=True)
def iter(self, obj: 'Proxy') -> Iterator:
def iter(self, obj: "Proxy") -> Iterator:
"""Called when a proxy object is being iterated over, such as
when used in control flow. Normally we don't know what to do because
we don't know the value of the proxy, but a custom tracer can attach more
information to the graph node using create_node and can choose to return an iterator.
"""
raise TraceError('Proxy object cannot be iterated. This can be '
'attempted when the Proxy is used in a loop or'
' as a *args or **kwargs function argument. '
'See the torch.fx docs on pytorch.org for a '
'more detailed explanation of what types of '
'control flow can be traced, and check out the'
' Proxy docstring for help troubleshooting '
'Proxy iteration errors')
raise TraceError(
"Proxy object cannot be iterated. This can be "
"attempted when the Proxy is used in a loop or"
" as a *args or **kwargs function argument. "
"See the torch.fx docs on pytorch.org for a "
"more detailed explanation of what types of "
"control flow can be traced, and check out the"
" Proxy docstring for help troubleshooting "
"Proxy iteration errors"
)
@compatibility(is_backward_compatible=True)
def keys(self, obj: 'Proxy') -> Any:
def keys(self, obj: "Proxy") -> Any:
"""Called when a proxy object is has the keys() method called.
This is what happens when ** is called on a proxy. This should return an
iterator it ** is suppose to work in your custom tracer.
"""
return Attribute(obj, 'keys')()
return Attribute(obj, "keys")()
# used in Proxy object when just appending to the graph while not tracing.
@ -355,14 +397,17 @@ class GraphAppendingTracer(TracerBase):
self.module_stack = collections.OrderedDict()
self.node_name_to_scope = {}
@compatibility(is_backward_compatible=False)
def assert_fn(x):
assert x
@compatibility(is_backward_compatible=True)
class TraceError(ValueError):
pass
@compatibility(is_backward_compatible=True)
class Proxy:
"""
@ -394,7 +439,7 @@ class Proxy:
"""
@compatibility(is_backward_compatible=True)
def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None):
def __init__(self, node: Node, tracer: "Optional[TracerBase]" = None):
if tracer is None:
# This allows you to create a Proxy object around a raw Node
tracer = GraphAppendingTracer(node.graph)
@ -402,9 +447,9 @@ class Proxy:
self.node = node
def __repr__(self) -> str:
return f'Proxy({self.node.name})'
return f"Proxy({self.node.name})"
def __getattr__(self, k) -> 'Attribute':
def __getattr__(self, k) -> "Attribute":
# note: not added to the graph yet, if this is a method call
# we peephole optimize to the method invocation
return Attribute(self, k)
@ -417,6 +462,7 @@ class Proxy:
# will go to __getattr__(self, "__deepcopy__") and return a
# Attribute(__deepcopy__), and may go into an infinite loop in some cases.
import copy
new_dict = {}
for k, v in self.__dict__.items():
try:
@ -424,7 +470,10 @@ class Proxy:
except Exception:
log.warning(
"Shallow copy %s of Proxy because it cannot be deepcopied. "
"Proxy is created for node %s", k, self.node.name)
"Proxy is created for node %s",
k,
self.node.name,
)
new_obj = copy.copy(v)
new_dict[k] = new_obj
assert "node" in new_dict
@ -438,10 +487,12 @@ class Proxy:
# This is called when being unpickled/loaded.
self.__dict__ = d
def __call__(self, *args, **kwargs) -> 'Proxy':
return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs)
def __call__(self, *args, **kwargs) -> "Proxy":
return self.tracer.create_proxy(
"call_method", "__call__", (self,) + args, kwargs
)
def __iter__(self) -> Iterator['Proxy']:
def __iter__(self) -> Iterator["Proxy"]:
frame = inspect.currentframe()
assert frame is not None
calling_frame = frame.f_back
@ -449,17 +500,20 @@ class Proxy:
inst_list = list(dis.get_instructions(calling_frame.f_code))
if sys.version_info >= (3, 11):
from bisect import bisect_left
inst_idx = bisect_left(inst_list, calling_frame.f_lasti, key=lambda x: x.offset)
inst_idx = bisect_left(
inst_list, calling_frame.f_lasti, key=lambda x: x.offset
)
else:
inst_idx = calling_frame.f_lasti // 2
inst = inst_list[inst_idx]
if inst.opname == 'UNPACK_SEQUENCE':
if inst.opname == "UNPACK_SEQUENCE":
return (self[i] for i in range(inst.argval)) # type: ignore[index]
return self.tracer.iter(self)
def __abs__(self):
return self.tracer.create_proxy('call_function', operator.abs, (self,), {})
return self.tracer.create_proxy("call_function", operator.abs, (self,), {})
def __bool__(self) -> bool:
if self.tracer.trace_asserts:
@ -472,19 +526,23 @@ class Proxy:
insts = list(dis.get_instructions(calling_frame.f_code))
if sys.version_info >= (3, 11):
from bisect import bisect_left
cur = bisect_left(insts, calling_frame.f_lasti, key=lambda x: x.offset)
else:
cur = calling_frame.f_lasti // 2
inst = insts[cur]
if inst.opname == 'POP_JUMP_IF_TRUE':
if inst.opname == "POP_JUMP_IF_TRUE":
first = insts[cur + 1]
assert inst.arg is not None
last = insts[inst.arg // 2 - 1]
starts_with_assert = (first.opname == 'LOAD_GLOBAL' and first.argval == 'AssertionError'
or first.opname == 'LOAD_ASSERTION_ERROR')
if starts_with_assert and last.opname == 'RAISE_VARARGS':
self.tracer.create_proxy('call_function', assert_fn, (self,), {})
starts_with_assert = (
first.opname == "LOAD_GLOBAL"
and first.argval == "AssertionError"
or first.opname == "LOAD_ASSERTION_ERROR"
)
if starts_with_assert and last.opname == "RAISE_VARARGS":
self.tracer.create_proxy("call_function", assert_fn, (self,), {})
return True
return self.tracer.to_bool(self)
@ -494,39 +552,51 @@ class Proxy:
return self.tracer.keys(self)
def __len__(self):
raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
"this call to be recorded, please call torch.fx.wrap('len') at "
"module scope")
raise RuntimeError(
"'len' is not supported in symbolic tracing by default. If you want "
"this call to be recorded, please call torch.fx.wrap('len') at "
"module scope"
)
@classmethod
def __torch_function__(cls, orig_method, types, args=None, kwargs=None):
args = args if args else ()
kwargs = kwargs if kwargs else {}
tracers : Dict[Any, None] = {}
tracers: Dict[Any, None] = {}
def find_tracer(a):
if isinstance(a, cls):
tracers[a.tracer] = None
torch.fx.node.map_aggregate(args, find_tracer)
torch.fx.node.map_aggregate(kwargs, find_tracer)
if len(tracers) > 1:
raise RuntimeError(f'Found multiple different tracers {list(tracers.keys())} while '
f'trying to trace operations {orig_method}')
raise RuntimeError(
f"Found multiple different tracers {list(tracers.keys())} while "
f"trying to trace operations {orig_method}"
)
tracer = next(iter(tracers.keys()))
if isinstance(orig_method, torch._C.ScriptMethod):
args = (orig_method.owner,) + args
return tracer.create_proxy('call_method', orig_method.name, args, kwargs)
return tracer.create_proxy("call_method", orig_method.name, args, kwargs)
if torch.overrides.is_tensor_method_or_property(orig_method):
return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs)
return tracer.create_proxy(
"call_method", orig_method.__name__, args, kwargs
)
else:
if isinstance(orig_method, torch._ops.HigherOrderOperator):
# TODO: Define how to symbolically trace HigherOrderOperators
raise RuntimeError("Unable to symbolically trace HigherOrderOperators")
return tracer.create_proxy('call_function', orig_method, args, kwargs,
name=tracer.graph._target_to_str(orig_method.__name__))
return tracer.create_proxy(
"call_function",
orig_method,
args,
kwargs,
name=tracer.graph._target_to_str(orig_method.__name__),
)
@compatibility(is_backward_compatible=False)
@ -535,12 +605,14 @@ class MetaProxy(Proxy):
A Proxy subclass that propagates metadata (meta['val']) during graph tracing.
"""
def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None, fake_mode=None):
def __init__(
self, node: Node, tracer: "Optional[TracerBase]" = None, fake_mode=None
):
super().__init__(node, tracer)
self.fake_mode = fake_mode
def __repr__(self) -> str:
return f'MetaProxy({self.node.name})'
return f"MetaProxy({self.node.name})"
@classmethod
def __torch_function__(cls, orig_method, types, args=None, kwargs=None):
@ -553,16 +625,19 @@ class MetaProxy(Proxy):
meta_proxy = arg
break
assert meta_proxy is not None, "No MetaProxy found in arguments, but one is expected."
assert (
meta_proxy is not None
), "No MetaProxy found in arguments, but one is expected."
proxy = super().__torch_function__(orig_method, types, args, kwargs)
with meta_proxy.fake_mode:
proxy.node.meta["val"] = orig_method(
*[a.node.meta["val"] if isinstance(a, Proxy) else a for a in args],
**kwargs
**kwargs,
)
return MetaProxy(proxy.node, proxy.tracer, meta_proxy.fake_mode)
@compatibility(is_backward_compatible=True)
class Attribute(Proxy):
@compatibility(is_backward_compatible=True)
@ -577,11 +652,15 @@ class Attribute(Proxy):
# the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call
if self._node is None:
self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
self._node = self.tracer.create_proxy(
"call_function", getattr, (self.root, self.attr), {}
).node
return self._node
def __call__(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
return self.tracer.create_proxy(
"call_method", self.attr, (self.root,) + args, kwargs
)
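For context on the `Attribute` proxy being reformatted here: attribute access on a `Proxy` yields an `Attribute`, calling it records a `call_method` node, and the lazy `getattr` node is only materialized when the attribute itself is needed as a value. A minimal sketch through the public API (module and method choices are just examples):

```python
import torch
import torch.fx


class UsesMethods(torch.nn.Module):
    def forward(self, x):
        # x.relu returns an Attribute proxy; calling it records a call_method
        # node, and the lazy getattr node is never created in this case.
        return x.relu().clamp(min=0.0, max=1.0)


gm = torch.fx.symbolic_trace(UsesMethods())
for node in gm.graph.nodes:
    print(node.op, node.target)
# Expected ops: placeholder, call_method (relu), call_method (clamp), output.
```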
@compatibility(is_backward_compatible=False)
@ -591,6 +670,7 @@ class ParameterProxy(Proxy):
attribute accesses pass through to the underlying module parameter object,
so that conditional tests on these attributes will not throw an exception during tracing
"""
def __init__(self, tracer: TracerBase, node: Node, name, param):
super().__init__(node, tracer)
assert isinstance(param, torch.nn.Parameter)
@ -598,7 +678,7 @@ class ParameterProxy(Proxy):
self.name = name
def __repr__(self) -> str:
return f'ParameterProxy({self.name})'
return f"ParameterProxy({self.name})"
@property
def shape(self):
@ -622,25 +702,31 @@ class ParameterProxy(Proxy):
for method in magic_methods:
def _scope(method):
def impl(*args, **kwargs):
tracer = args[0].tracer
target = getattr(operator, method)
return tracer.create_proxy('call_function', target, args, kwargs)
return tracer.create_proxy("call_function", target, args, kwargs)
impl.__name__ = method
as_magic = f'__{method.strip("_")}__'
setattr(Proxy, as_magic, impl)
_scope(method)
def _define_reflectable(orig_method_name):
method_name = f'__r{orig_method_name.strip("_")}__'
def impl(self, rhs):
target = getattr(operator, orig_method_name)
return self.tracer.create_proxy('call_function', target, (rhs, self), {})
return self.tracer.create_proxy("call_function", target, (rhs, self), {})
impl.__name__ = method_name
impl.__qualname__ = method_name
setattr(Proxy, method_name, impl)
for orig_method_name in reflectable_magic_methods:
_define_reflectable(orig_method_name)
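The two loops above install the arithmetic dunder methods (and their reflected `__r*__` variants) on `Proxy`, so plain Python operators applied to proxies are recorded as `call_function` nodes targeting the corresponding `operator.*` function. A small sketch of what that looks like from the public API (the module name is illustrative):

```python
import operator

import torch
import torch.fx


class AddMul(torch.nn.Module):
    def forward(self, x, y):
        z = x + y     # Proxy.__add__ -> call_function targeting operator.add
        return 2 * z  # int.__mul__ fails, Proxy.__rmul__ -> operator.mul(2, z)


gm = torch.fx.symbolic_trace(AddMul())
assert any(
    n.op == "call_function" and n.target is operator.add for n in gm.graph.nodes
)
print(gm.graph)
```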

View File

@ -1,18 +1,36 @@
from .graph_module import GraphModule
from .graph import Graph
from .node import Node
from ._symbolic_trace import symbolic_trace
from ._compatibility import compatibility
import copy
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union, TYPE_CHECKING
from typing import (
Any,
Callable,
Dict,
List,
NamedTuple,
Optional,
Set,
TYPE_CHECKING,
Union,
)
import torch
from ._compatibility import compatibility
from ._symbolic_trace import symbolic_trace
from .graph import Graph
from .graph_module import GraphModule
from .node import Node
if TYPE_CHECKING:
from .passes.utils.matcher_with_name_node_map_utils import InternalMatch
__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters', "ReplacedPatterns"]
__all__ = [
"Match",
"replace_pattern",
"replace_pattern_with_filters",
"ReplacedPatterns",
]
@compatibility(is_backward_compatible=True)
class Match(NamedTuple):
@ -21,6 +39,7 @@ class Match(NamedTuple):
# Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node]
@compatibility(is_backward_compatible=False)
@dataclass
class ReplacedPatterns:
@ -31,6 +50,7 @@ class ReplacedPatterns:
# List of nodes that were added into the graph
replacements: List[Node]
def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
gm.delete_all_unused_submodules()
@ -48,7 +68,6 @@ def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
for node in gm.graph.nodes:
if node.op == "call_module" or node.op == "get_attr":
gm_attr = try_get_attr(gm, node.target)
replacement_attr = try_get_attr(replacement, node.target)
@ -70,11 +89,14 @@ def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
# CASE 3: The target doesn't exist as an attribute in `gm`
# or `replacement`
else:
raise RuntimeError('Attempted to create a "', node.op,
'" node during subgraph rewriting '
f"with target {node.target}, but "
"the referenced attribute does not "
"exist in the replacement GraphModule")
raise RuntimeError(
'Attempted to create a "',
node.op,
'" node during subgraph rewriting '
f"with target {node.target}, but "
"the referenced attribute does not "
"exist in the replacement GraphModule",
)
gm.graph.lint()
@ -83,7 +105,7 @@ def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
def replace_pattern(
gm: GraphModule,
pattern: Union[Callable, GraphModule],
replacement: Union[Callable, GraphModule]
replacement: Union[Callable, GraphModule],
) -> List[Match]:
"""
Matches all possible non-overlapping sets of operators and their
@ -116,6 +138,7 @@ def replace_pattern(
import torch
from torch.fx import symbolic_trace, subgraph_rewriter
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@ -125,12 +148,15 @@ def replace_pattern(
m2 = torch.cat([w1, w2]).sum()
return x + torch.max(m1) + torch.max(m2)
def pattern(w1, w2):
return torch.cat([w1, w2]).sum()
def replacement(w1, w2):
return torch.stack([w1, w2])
traced_module = symbolic_trace(M())
subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
@ -199,7 +225,9 @@ def replace_pattern(
return add_2
"""
match_and_replacements = _replace_pattern(gm, pattern, replacement)
return [Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements]
return [
Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements
]
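Pulling the docstring example above into a runnable form (the `forward` signature is not visible in this hunk and is assumed to be `(x, w1, w2)`, matching the upstream docstring):

```python
import torch
from torch.fx import subgraph_rewriter, symbolic_trace


class M(torch.nn.Module):
    def forward(self, x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w1, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)


def pattern(w1, w2):
    return torch.cat([w1, w2]).sum()


def replacement(w1, w2):
    return torch.stack([w1, w2])


traced = symbolic_trace(M())
matches = subgraph_rewriter.replace_pattern(traced, pattern, replacement)
traced.graph.lint()  # each cat + sum occurrence is now a stack call
print(traced.code)
```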
# Experimental API, not backward compatible
@ -208,10 +236,14 @@ def replace_pattern_with_filters(
gm: GraphModule,
pattern: Union[Callable, Graph, GraphModule],
replacement: Union[Callable, Graph, GraphModule, None] = None,
match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,
match_filters: Optional[
List[Callable[["InternalMatch", Graph, Graph], bool]]
] = None,
ignore_literals: bool = False,
# Placed at the end to avoid breaking backward compatibility
replacement_callback: Optional[Callable[["InternalMatch", Graph, Graph], Graph]] = None,
replacement_callback: Optional[
Callable[["InternalMatch", Graph, Graph], Graph]
] = None,
) -> List[ReplacedPatterns]:
"""
See replace_pattern for documentation. This function is an overload with an additional match_filters argument.

@ -226,20 +258,25 @@ def replace_pattern_with_filters(
replacement graph based on the match.
"""
return _replace_pattern(gm, pattern, replacement, match_filters, ignore_literals, replacement_callback)
return _replace_pattern(
gm, pattern, replacement, match_filters, ignore_literals, replacement_callback
)
def _replace_pattern(
gm: GraphModule,
pattern: Union[Callable, Graph, GraphModule],
replacement: Union[Callable, Graph, GraphModule, None] = None,
match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,
match_filters: Optional[
List[Callable[["InternalMatch", Graph, Graph], bool]]
] = None,
ignore_literals: bool = False,
# Placed at the end to avoid breaking backward compatibility
replacement_callback: Optional[Callable[["InternalMatch", Graph, Graph], Graph]] = None,
replacement_callback: Optional[
Callable[["InternalMatch", Graph, Graph], Graph]
] = None,
) -> List[ReplacedPatterns]:
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch
from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher
if match_filters is None:
match_filters = []
@ -254,15 +291,23 @@ def _replace_pattern(
else:
pattern_graph = symbolic_trace(pattern).graph
matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False,
remove_overlapping_matches=True, ignore_literals=ignore_literals)
matcher = SubgraphMatcher(
pattern_graph,
match_output=False,
match_placeholder=False,
remove_overlapping_matches=True,
ignore_literals=ignore_literals,
)
_matches: List[InternalMatch] = matcher.match(original_graph)
# Filter out matches that don't match the filter
_matches = [
m for m in _matches
if all(match_filter(m, original_graph, pattern_graph)
for match_filter in match_filters)
m
for m in _matches
if all(
match_filter(m, original_graph, pattern_graph)
for match_filter in match_filters
)
]
if isinstance(replacement, GraphModule):
@ -272,7 +317,9 @@ def _replace_pattern(
elif callable(replacement):
common_replacement_graph = symbolic_trace(replacement).graph
else:
assert replacement_callback is not None, "Must provide either a replacement GraphModule or a replacement callback"
assert (
replacement_callback is not None
), "Must provide either a replacement GraphModule or a replacement callback"
common_replacement_graph = None
# As we progressively replace nodes, we'll need to keep track of how the match results should change
@ -281,11 +328,17 @@ def _replace_pattern(
match_and_replacements = []
for match in _matches:
if replacement_callback is not None:
replacement_graph = replacement_callback(match, original_graph, pattern_graph)
replacement_graph = replacement_callback(
match, original_graph, pattern_graph
)
else:
assert common_replacement_graph is not None, "Must provide either a replacement GraphModule or a replacement callback"
assert (
common_replacement_graph is not None
), "Must provide either a replacement GraphModule or a replacement callback"
replacement_graph = common_replacement_graph
replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"]
replacement_placeholders = [
n for n in replacement_graph.nodes if n.op == "placeholder"
]
# Build connections between the replacement graph's inputs and the original graph's input producer nodes
@ -300,7 +353,9 @@ def _replace_pattern(
# Update match.placeholder_nodes and match.nodes_map with the node that replaced gn
gn_ind = match.placeholder_nodes.index(gn)
match.placeholder_nodes[gn_ind] = match_changed_node[gn]
map_key = list(match.nodes_map.keys())[list(match.nodes_map.values()).index(gn)]
map_key = list(match.nodes_map.keys())[
list(match.nodes_map.values()).index(gn)
]
match.nodes_map[map_key] = match_changed_node[gn]
else:
val_map[rn] = gn
@ -322,13 +377,17 @@ def _replace_pattern(
break
with original_graph.inserting_before(first_user_node): # type: ignore[possibly-undefined]
copied_returning_nodes = original_graph.graph_copy(replacement_graph, val_map)
copied_returning_nodes = original_graph.graph_copy(
replacement_graph, val_map
)
if isinstance(copied_returning_nodes, Node):
copied_returning_nodes = (copied_returning_nodes, )
copied_returning_nodes = (copied_returning_nodes,)
# Get a list of nodes that have been replaced into the graph
replacement_nodes: List[Node] = [v for v in val_map.values() if v not in match.placeholder_nodes]
replacement_nodes: List[Node] = [
v for v in val_map.values() if v not in match.placeholder_nodes
]
# Hook the output Node of the replacement subgraph in to the
# original Graph at the correct location
@ -346,7 +405,7 @@ def _replace_pattern(
ReplacedPatterns(
anchor=match.anchors[0],
nodes_map=match.nodes_map,
replacements=replacement_nodes
replacements=replacement_nodes,
)
)
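Since the experimental overload only differs from `replace_pattern` in its extra arguments, a short hedged sketch of the `match_filters` hook may help. The module and filter below are hypothetical; a real filter would typically inspect `match.nodes_map` or `match.placeholder_nodes` before accepting a candidate:

```python
import torch
from torch.fx import subgraph_rewriter, symbolic_trace


class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + torch.relu(x)


def pattern(x):
    return torch.relu(x)


def replacement(x):
    return torch.sigmoid(x)


def log_and_accept(match, original_graph, pattern_graph):
    # Hypothetical filter: accept every candidate; returning False here
    # would skip the rewrite for that particular match.
    print(f"candidate match anchored at {match.anchors}")
    return True


traced = symbolic_trace(M())
replaced = subgraph_rewriter.replace_pattern_with_filters(
    traced, pattern, replacement, match_filters=[log_and_accept]
)
print(len(replaced))  # one ReplacedPatterns record per rewritten match
```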

View File

@ -19,7 +19,7 @@ class TensorType:
self.__args__ = dim
def __repr__(self):
return f'TensorType[{self.__args__}]'
return f"TensorType[{self.__args__}]"
def __eq__(self, other):
if isinstance(other, self.__class__):
@ -38,8 +38,9 @@ class _DynType:
"""
_DynType defines a type which stands for the absence of type information.
"""
def __init__(self) -> None:
self.__name__ = '_DynType'
self.__name__ = "_DynType"
def __eq__(self, other):
return isinstance(other, self.__class__)
@ -53,6 +54,7 @@ class _DynType:
Dyn = _DynType()
@compatibility(is_backward_compatible=False)
def is_consistent(t1, t2):
"""
@ -73,8 +75,10 @@ def is_consistent(t1, t2):
return True
if isinstance(t1, TensorType) and isinstance(t2, TensorType):
return len(t1.__args__) == len(t2.__args__) and \
all(is_consistent(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__))
return len(t1.__args__) == len(t2.__args__) and all(
is_consistent(elem1, elem2)
for elem1, elem2 in zip(t1.__args__, t2.__args__)
)
else:
return False
@ -98,8 +102,10 @@ def is_more_precise(t1, t2):
return True
if isinstance(t1, TensorType) and isinstance(t2, TensorType):
return len(t1.__args__) == len(t2.__args__) and \
all(is_more_precise(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__))
return len(t1.__args__) == len(t2.__args__) and all(
is_more_precise(elem1, elem2)
for elem1, elem2 in zip(t1.__args__, t2.__args__)
)
else:
return False
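For orientation, these experimental gradual-typing helpers can be exercised directly. The sketch below assumes the definitions live in `torch.fx.tensor_type`, as in this diff:

```python
from torch.fx.tensor_type import Dyn, TensorType, is_consistent, is_more_precise

concrete = TensorType((1, 2, 3))
partial = TensorType((1, Dyn, 3))

print(is_consistent(concrete, partial))    # True: Dyn is consistent with any dim
print(is_more_precise(concrete, partial))  # True: 2 refines Dyn
print(is_consistent(concrete, TensorType((1, 2))))  # False: ranks differ
print(partial)  # e.g. TensorType[(1, Dyn, 3)]
```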

View File

@ -1,12 +1,21 @@
# mypy: allow-untyped-defs
import traceback
from contextlib import contextmanager
from typing import List, Any, Dict
from typing import Any, Dict, List
from ._compatibility import compatibility
__all__ = ['preserve_node_meta', 'has_preserved_node_meta',
'set_stack_trace', 'set_grad_fn_seq_nr', 'reset_grad_fn_seq_nr',
'format_stack', 'set_current_meta', 'get_current_meta']
__all__ = [
"preserve_node_meta",
"has_preserved_node_meta",
"set_stack_trace",
"set_grad_fn_seq_nr",
"reset_grad_fn_seq_nr",
"format_stack",
"set_current_meta",
"get_current_meta",
]
current_meta: Dict[str, Any] = {}
should_preserve_node_meta = False
@ -30,7 +39,7 @@ def preserve_node_meta():
@compatibility(is_backward_compatible=False)
def set_stack_trace(stack : List[str]):
def set_stack_trace(stack: List[str]):
global current_meta
if should_preserve_node_meta and stack:
@ -43,7 +52,9 @@ def set_grad_fn_seq_nr(seq_nr):
if should_preserve_node_meta:
# The seq_nr is captured by eager mode in the grad_fn during forward
current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [seq_nr]
current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [
seq_nr
]
current_meta["in_grad_fn"] = current_meta.get("in_grad_fn", 0) + 1
@ -90,7 +101,9 @@ def set_current_meta(node):
if "from_node" not in current_meta:
current_meta["from_node"] = [(node.name, node.target)]
elif current_meta["from_node"][-1][0] != node.name:
current_meta["from_node"] = current_meta["from_node"] + [(node.name, node.target)]
current_meta["from_node"] = current_meta["from_node"] + [
(node.name, node.target)
]
yield
finally: