Mirror of https://github.com/pytorch/pytorch.git
[BE][Easy] enable PYFMT for torch.fx (#138443)
Reproduce command:

```bash
ghstack checkout https://github.com/pytorch/pytorch/pull/138443
git checkout HEAD~1 torch/
lintrunner -a --take "PYFMT" --all-files
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138443
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 8231180147
Commit: abbd71d29d
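The diff below is a pure formatting change; no behavior changes. As a rough sketch (my illustration, not code from this commit), the Black-style rules that PYFMT applies here normalize string quotes to double quotes, sort imports, and replace backslash continuations with parenthesized expressions:

```python
# Hypothetical before/after, assuming Black-style rules like those this PR applies.
def normalize_padding(padding):
    # Before formatting, code like this used a backslash continuation:
    #   result = (padding, padding) \
    #       if isinstance(padding, int) else padding
    # After formatting, the conditional is wrapped in parentheses instead:
    result = (
        (padding, padding)
        if isinstance(padding, int)
        else padding
    )
    print("normalized:", result)  # string quotes are normalized to double quotes
    return result
```

The transformation is purely syntactic, which is why the diff touches so many lines without changing any runtime behavior.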
@@ -1232,87 +1232,6 @@ exclude_patterns = [
'torch/fft/__init__.py',
'torch/func/__init__.py',
'torch/futures/__init__.py',
'torch/fx/__init__.py',
'torch/fx/_compatibility.py',
'torch/fx/_symbolic_trace.py',
'torch/fx/annotate.py',
'torch/fx/config.py',
'torch/fx/experimental/__init__.py',
'torch/fx/experimental/accelerator_partitioner.py',
'torch/fx/experimental/const_fold.py',
'torch/fx/experimental/debug.py',
'torch/fx/experimental/graph_gradual_typechecker.py',
'torch/fx/experimental/merge_matmul.py',
'torch/fx/experimental/meta_tracer.py',
'torch/fx/experimental/migrate_gradual_types/__init__.py',
'torch/fx/experimental/migrate_gradual_types/constraint.py',
'torch/fx/experimental/migrate_gradual_types/constraint_generator.py',
'torch/fx/experimental/migrate_gradual_types/constraint_transformation.py',
'torch/fx/experimental/migrate_gradual_types/operation.py',
'torch/fx/experimental/migrate_gradual_types/transform_to_z3.py',
'torch/fx/experimental/migrate_gradual_types/util.py',
'torch/fx/experimental/migrate_gradual_types/z3_types.py',
'torch/fx/experimental/normalize.py',
'torch/fx/experimental/optimization.py',
'torch/fx/experimental/partitioner_utils.py',
'torch/fx/experimental/refinement_types.py',
'torch/fx/experimental/rewriter.py',
'torch/fx/experimental/schema_type_annotation.py',
'torch/fx/experimental/unification/__init__.py',
'torch/fx/experimental/unification/core.py',
'torch/fx/experimental/unification/dispatch.py',
'torch/fx/experimental/unification/match.py',
'torch/fx/experimental/unification/more.py',
'torch/fx/experimental/unification/multipledispatch/__init__.py',
'torch/fx/experimental/unification/multipledispatch/conflict.py',
'torch/fx/experimental/unification/multipledispatch/core.py',
'torch/fx/experimental/unification/multipledispatch/dispatcher.py',
'torch/fx/experimental/unification/multipledispatch/utils.py',
'torch/fx/experimental/unification/multipledispatch/variadic.py',
'torch/fx/experimental/unification/unification_tools.py',
'torch/fx/experimental/unification/utils.py',
'torch/fx/experimental/unification/variable.py',
'torch/fx/experimental/unify_refinements.py',
'torch/fx/graph.py',
'torch/fx/graph_module.py',
'torch/fx/interpreter.py',
'torch/fx/node.py',
'torch/fx/operator_schemas.py',
'torch/fx/passes/__init__.py',
'torch/fx/passes/annotate_getitem_nodes.py',
'torch/fx/passes/backends/__init__.py',
'torch/fx/passes/backends/cudagraphs.py',
'torch/fx/passes/dialect/__init__.py',
'torch/fx/passes/dialect/common/__init__.py',
'torch/fx/passes/dialect/common/cse_pass.py',
'torch/fx/passes/fake_tensor_prop.py',
'torch/fx/passes/graph_drawer.py',
'torch/fx/passes/graph_manipulation.py',
'torch/fx/passes/infra/__init__.py',
'torch/fx/passes/infra/partitioner.py',
'torch/fx/passes/infra/pass_base.py',
'torch/fx/passes/infra/pass_manager.py',
'torch/fx/passes/net_min_base.py',
'torch/fx/passes/operator_support.py',
'torch/fx/passes/param_fetch.py',
'torch/fx/passes/pass_manager.py',
'torch/fx/passes/reinplace.py',
'torch/fx/passes/shape_prop.py',
'torch/fx/passes/split_module.py',
'torch/fx/passes/split_utils.py',
'torch/fx/passes/splitter_base.py',
'torch/fx/passes/tests/__init__.py',
'torch/fx/passes/tests/test_pass_manager.py',
'torch/fx/passes/tools_common.py',
'torch/fx/passes/utils/__init__.py',
'torch/fx/passes/utils/common.py',
'torch/fx/passes/utils/fuser_utils.py',
'torch/fx/passes/utils/matcher_utils.py',
'torch/fx/passes/utils/source_matcher_utils.py',
'torch/fx/proxy.py',
'torch/fx/subgraph_rewriter.py',
'torch/fx/tensor_type.py',
'torch/fx/traceback.py',
'torch/linalg/__init__.py',
'torch/monitor/__init__.py',
'torch/nested/__init__.py',
@@ -262,7 +262,9 @@
"Future"
],
"torch.fx": [
"PH",
"ProxyableClassMeta",
"CodeGen",
"Tracer",
"symbolic_trace",
"wrap"

@@ -2514,6 +2514,7 @@ if "TORCH_CUDA_SANITIZER" in os.environ:

# Populate magic methods on SymInt and SymFloat
import torch.fx.experimental.sym_node
from torch import fx as fx


# Register MPS specific decomps
@@ -7,6 +7,8 @@ demonstration of these components in action:
::

import torch


# Simple module for demonstration
class MyModule(torch.nn.Module):
def __init__(self) -> None:

@@ -17,11 +19,13 @@ demonstration of these components in action:
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)


module = MyModule()

from torch.fx import symbolic_trace

# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)

@@ -80,10 +84,32 @@ Several example transformations can be found at the
repository.
'''

from .graph_module import GraphModule
from ._symbolic_trace import symbolic_trace, Tracer, wrap, PH, ProxyableClassMeta
from .graph import Graph, CodeGen
from .node import Node, map_arg, has_side_effect
from .proxy import Proxy
from .interpreter import Interpreter as Interpreter, Transformer as Transformer
from .subgraph_rewriter import replace_pattern
from torch.fx._symbolic_trace import ( # noqa: F401
PH,
ProxyableClassMeta,
symbolic_trace,
Tracer,
wrap,
)
from torch.fx.graph import CodeGen, Graph # noqa: F401
from torch.fx.graph_module import GraphModule
from torch.fx.interpreter import Interpreter, Transformer
from torch.fx.node import has_side_effect, map_arg, Node
from torch.fx.proxy import Proxy
from torch.fx.subgraph_rewriter import replace_pattern


__all__ = [
"symbolic_trace",
"Tracer",
"wrap",
"Graph",
"GraphModule",
"Interpreter",
"Transformer",
"Node",
"Proxy",
"replace_pattern",
"has_side_effect",
"map_arg",
]
@@ -1,15 +0,0 @@
from torch.fx._symbolic_trace import (
symbolic_trace as symbolic_trace,
Tracer as Tracer,
wrap as wrap,
)
from torch.fx.graph import Graph as Graph
from torch.fx.graph_module import GraphModule as GraphModule
from torch.fx.interpreter import Interpreter as Interpreter, Transformer as Transformer
from torch.fx.node import (
has_side_effect as has_side_effect,
map_arg as map_arg,
Node as Node,
)
from torch.fx.proxy import Proxy as Proxy
from torch.fx.subgraph_rewriter import replace_pattern as replace_pattern
@@ -1,16 +1,19 @@
from typing import Any, Dict, Callable, TypeVar
import textwrap
from typing import Any, Callable, Dict, TypeVar


_BACK_COMPAT_OBJECTS: Dict[Any, None] = {}
_MARKED_WITH_COMPATIBILITY: Dict[Any, None] = {}

_BACK_COMPAT_OBJECTS : Dict[Any, None] = {}
_MARKED_WITH_COMPATIBILITY : Dict[Any, None] = {}

_T = TypeVar("_T")


def compatibility(is_backward_compatible: bool) -> Callable[[_T], _T]:
if is_backward_compatible:

def mark_back_compat(fn: _T) -> _T:
docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
docstring = textwrap.dedent(getattr(fn, "__doc__", None) or "")
docstring += """
.. note::
Backwards-compatibility for this API is guaranteed.

@@ -24,7 +27,7 @@ def compatibility(is_backward_compatible: bool) -> Callable[[_T], _T]:
else:

def mark_not_back_compat(fn: _T) -> _T:
docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
docstring = textwrap.dedent(getattr(fn, "__doc__", None) or "")
docstring += """
.. warning::
This API is experimental and is *NOT* backward-compatible.
@@ -1,9 +1,9 @@
# mypy: allow-untyped-defs
from contextlib import contextmanager

from torch.fx import GraphModule
from torch.fx.graph_module import (
_format_import_block,
GraphModule,
reduce_graph_module,
reduce_package_graph_module,
)
@@ -1,13 +1,13 @@
# mypy: allow-untyped-defs
import builtins
import copy
import collections
import contextlib
import copy
import functools
import inspect
import math
import os
import warnings
import collections
from itertools import chain
from types import CodeType, FunctionType, ModuleType
from typing import (

@@ -29,11 +29,12 @@ from torch._C import ScriptObject # type: ignore[attr-defined]
from torch._library.fake_class_registry import FakeScriptObject

from ._compatibility import compatibility
from ._lazy_graph_module import _make_graph_module
from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph
from .graph_module import GraphModule
from ._lazy_graph_module import _make_graph_module
from .node import Argument, base_types, map_aggregate
from .proxy import ParameterProxy, Proxy, TracerBase, Scope, ScopeContextManager
from .proxy import ParameterProxy, Proxy, Scope, ScopeContextManager, TracerBase


HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS
@@ -49,6 +50,7 @@ _is_fx_tracing_flag = False
def is_fx_tracing():
return _is_fx_tracing_flag


@compatibility(is_backward_compatible=True)
class ProxyableClassMeta(type):
"""

@@ -58,6 +60,7 @@ class ProxyableClassMeta(type):
import torch
import torch.fx


class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
def __init__(self, left, right):
self.left, self.right = left, right

@@ -72,10 +75,12 @@ class ProxyableClassMeta(type):
r = self.right * other.right
return TensorPair(l, r)

def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor):

def use_tensor_pair_ctor(x: TensorPair, y: torch.Tensor):
s = x.add(TensorPair(y, y))
return s.mul(x)


x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
y = torch.randn(5, 3)
ref_out = use_tensor_pair_ctor(x, y)

@@ -214,6 +219,7 @@ class PHWithMeta(PHBase):
"""
Object representing an input placeholder to `concrete_args`
"""

def __init__(self, ph_key: Optional[str] = None):
super().__init__()
@@ -404,7 +410,11 @@ class Tracer(TracerBase):
# Tensor was not found in the Module hierarchy, stow it away in a
# special attribute and set the qualname to refer to that
if not qualname:
base_name = "_tensor_constant" if isinstance(a, torch.Tensor) else "_torchbind_obj"
base_name = (
"_tensor_constant"
if isinstance(a, torch.Tensor)
else "_torchbind_obj"
)
qualname = self.get_fresh_qualname(base_name)
assert isinstance(qualname, str)
self.tensor_attrs[a] = qualname

@@ -446,9 +456,9 @@ class Tracer(TracerBase):
appear with the qualified name ``foo.bar.baz`` here.
"""
return (
(m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn"))
and not isinstance(m, torch.nn.Sequential)
)
m.__module__.startswith("torch.nn")
or m.__module__.startswith("torch.ao.nn")
) and not isinstance(m, torch.nn.Sequential)

@compatibility(is_backward_compatible=True)
def path_of_module(self, mod: torch.nn.Module) -> str:
@@ -512,17 +522,25 @@ class Tracer(TracerBase):
value was returned from the ``Module`` invocation.
"""
module_qualified_name = self.path_of_module(m)
with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope:
with ScopeContextManager(
self.scope, Scope(module_qualified_name, type(m))
) as _scope:
# module_stack is an ordered dict so writing then deleting the
# entry is equivalent to push/pop on a list
num_calls = self.num_calls.get(module_qualified_name, 0)
module_key = f"{_scope.module_path}@{num_calls}" if num_calls > 0 else _scope.module_path
module_key = (
f"{_scope.module_path}@{num_calls}"
if num_calls > 0
else _scope.module_path
)
self.module_stack[module_key] = (module_qualified_name, _scope.module_type)
self.num_calls[module_qualified_name] = num_calls + 1
if not self.is_leaf_module(m, module_qualified_name):
ret_val = forward(*args, **kwargs)
else:
ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs)
ret_val = self.create_proxy(
"call_module", module_qualified_name, args, kwargs
)
key, _ = self.module_stack.popitem(last=True)
assert key == module_key, f" Unexpected key {key}"
@@ -551,6 +569,7 @@ class Tracer(TracerBase):

The return value from the getattr call.
"""

def maybe_get_proxy_for_attr(
attr_val, collection_to_search, parameter_proxy_cache
):

@@ -620,15 +639,16 @@ class Tracer(TracerBase):

sig = inspect.signature(fn_for_analysis)


# This covers the very specific case where we are passing in flat
# concrete_args as a tuple, but our traced fn takes (*args, **kwargs).
# In this case, just take the concrete_args and pass them through.
name_idx = 0
if isinstance(concrete_args, tuple) and \
len(concrete_args) > 0 and \
(co.co_flags & HAS_VARSTUFF) and \
total_args == 1:
if (
isinstance(concrete_args, tuple)
and len(concrete_args) > 0
and (co.co_flags & HAS_VARSTUFF)
and total_args == 1
):
for concrete_arg in concrete_args:
out = self.create_proxy("placeholder", f"input_{name_idx}", (), {})
if isinstance(concrete_arg, PHBase):
@@ -722,12 +742,12 @@ class Tracer(TracerBase):
_is_fx_tracing_flag = True
try:
if isinstance(root, torch.nn.Module):

# do real recompilation for _LazyGraphModule before retracing since the trace
# method can not trace the _lazy_forward method. Got error:
# https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259
# without this.
from torch.fx._lazy_graph_module import _LazyGraphModule

_LazyGraphModule.force_recompile(root)

self.root = root

@@ -745,12 +765,12 @@ class Tracer(TracerBase):

tracer_cls: Optional[Type[Tracer]] = getattr(self, "__class__", None)
self.graph = Graph(tracer_cls=tracer_cls)
if hasattr(fn, '__code__'):
if hasattr(fn, "__code__"):
code = fn.__code__
self.graph._co_fields = {
'co_name': code.co_name,
'co_filename': code.co_filename,
'co_firstlineno': code.co_firstlineno,
"co_name": code.co_name,
"co_filename": code.co_filename,
"co_firstlineno": code.co_firstlineno,
}

# When we encounter a Tensor value that's not a parameter, we look if it

@@ -758,11 +778,7 @@ class Tracer(TracerBase):
# values to the qualified name here for efficiency. This is used downstream
# in create_arg
self.tensor_attrs: Dict[
Union[
torch.Tensor,
ScriptObject,
FakeScriptObject
], str
Union[torch.Tensor, ScriptObject, FakeScriptObject], str
] = {}

def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]):

@@ -839,7 +855,7 @@ class Tracer(TracerBase):
new_tracer = Tracer.__new__(Tracer)

for k, v in self.__dict__.items():
if k in {'_autowrap_search'}:
if k in {"_autowrap_search"}:
new_obj = copy.copy(v)
else:
new_obj = copy.deepcopy(v, memo)

@@ -857,9 +873,7 @@ class Tracer(TracerBase):
cnt += 1
param = sig.parameters[name]
default = (
()
if param.default is inspect.Parameter.empty
else (param.default,)
() if param.default is inspect.Parameter.empty else (param.default,)
)
out = self.create_proxy(
"placeholder", f"{name}_{str(cnt)}", default, {}

@@ -877,11 +891,7 @@ class Tracer(TracerBase):

return out
# Union[int, bool] == bool in Python <= 3.6
if (
type(x) == bool
or type(x) in base_types
and type(x) != torch.Tensor
):
if type(x) == bool or type(x) in base_types and type(x) != torch.Tensor:
torch._assert(
out == x,
f"{name} has been specialized to have value {x} but got another value",

@@ -906,13 +916,15 @@ class Tracer(TracerBase):
default = ()
else:
param = sig.parameters[name]
default = () if param.default is inspect.Parameter.empty else (param.default,) # type: ignore[assignment]
default = ( # type: ignore[assignment]
() if param.default is inspect.Parameter.empty else (param.default,)
)
return self.create_proxy(
"placeholder",
name,
default,
{},
type_expr=fn_for_analysis.__annotations__.get(name, None)
type_expr=fn_for_analysis.__annotations__.get(name, None),
)
@@ -1011,6 +1023,7 @@ class _PatchedFnSetItem(_PatchedFn):
def patch(self):
self.frame_dict[self.fn_name] = self.new_fn


class _PatchedFnDel(_PatchedFn):
def revert(self):
del self.frame_dict[self.fn_name]

@@ -1026,6 +1039,7 @@ class _PatchedFnSetAttr(_PatchedFn):
def patch(self):
setattr(self.frame_dict, self.fn_name, self.new_fn)


class _Patcher:
def __init__(self) -> None:
super().__init__()

@@ -1106,6 +1120,7 @@ class _Patcher:

CURRENT_PATCHER: Optional[_Patcher] = None


@contextlib.contextmanager
def _new_patcher():
global CURRENT_PATCHER

@@ -1132,7 +1147,10 @@ def _maybe_revert_all_patches():
finally:
if current_patcher is not None:
patches_made = current_patcher.reapply_all_patches()
assert patches_made == patches_removed, "CURRENT_PATCHER was changed during a revert_all_patches"
assert (
patches_made == patches_removed
), "CURRENT_PATCHER was changed during a revert_all_patches"


def _patch_wrapped_functions(patcher: _Patcher):
"""

@@ -1178,7 +1196,9 @@ def wrap(fn_or_name: Union[str, Callable]):
def my_custom_function(x, y):
return x * x + y * y

torch.fx.wrap('my_custom_function')

torch.fx.wrap("my_custom_function")


def fn_to_be_traced(x, y):
# When symbolic tracing, the below call to my_custom_function will be inserted into

@@ -1248,14 +1268,14 @@ def symbolic_trace(
if b == True:
return a
else:
return a*2
return a * 2

FX can typically not trace through this due to the presence of control
flow. However, we can use `concrete_args` to specialize on the value of
`b` to trace through this::

f = fx.symbolic_trace(f, concrete_args={'b': False})
assert f(3, False) == 6
f = fx.symbolic_trace(f, concrete_args={"b": False})
assert f(3, False) == 6

Note that although you can still pass in different values of `b`, they will be ignored.

@@ -1269,8 +1289,10 @@ def symbolic_trace(
for v in x.values():
out += v
return out
f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
assert f({'a': 1, 'b': 2, 'c': 4}) == 7


f = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}})
assert f({"a": 1, "b": 2, "c": 4}) == 7


Args:
@@ -1,7 +1,9 @@
# mypy: allow-untyped-defs
from torch.fx.proxy import Proxy

from ._compatibility import compatibility


@compatibility(is_backward_compatible=False)
def annotate(val, type):
"""

@@ -18,13 +20,15 @@ def annotate(val, type):
"""
if isinstance(val, Proxy):
if val.node.type:
raise RuntimeError(f"Tried to annotate a value that already had a type on it!"
f" Existing type is {val.node.type} "
f"and new type is {type}. "
f"This could happen if you tried to annotate a function parameter "
f"value (in which case you should use the type slot "
f"on the function signature) or you called "
f"annotate on the same value twice")
raise RuntimeError(
f"Tried to annotate a value that already had a type on it!"
f" Existing type is {val.node.type} "
f"and new type is {type}. "
f"This could happen if you tried to annotate a function parameter "
f"value (in which case you should use the type slot "
f"on the function signature) or you called "
f"annotate on the same value twice"
)
else:
val.node.type = type
return val
@@ -1,22 +1,22 @@
# mypy: allow-untyped-defs
import operator
from collections import deque
from typing import Dict, List, Set, NamedTuple, Tuple, Deque
from typing import Deque, Dict, List, NamedTuple, Set, Tuple

import torch
from torch.fx.passes.graph_manipulation import get_size_of_all_nodes
from torch.fx.experimental.partitioner_utils import (
Partition,
Device,
PartitionerConfig,
get_partition_to_latency_mapping,
get_latency_of_partitioned_graph,
NodeLatency,
get_extra_size_of,
get_latency_of_partitioned_graph,
get_partition_to_latency_mapping,
NodeLatency,
Partition,
PartitionerConfig,
PartitionMode,
)
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node, map_arg
from torch.fx.node import map_arg, Node
from torch.fx.passes.graph_manipulation import get_size_of_all_nodes
from torch.fx.passes.split_module import split_module


@@ -260,7 +260,9 @@ def get_device_to_partitions_mapping(
# Find devices for all the partitions without a device
found_device = True
for partition in no_device_partitions:
device_to_left_mem_bytes = dict(sorted(device_to_left_mem_bytes.items(), key=operator.itemgetter(1)))
device_to_left_mem_bytes = dict(
sorted(device_to_left_mem_bytes.items(), key=operator.itemgetter(1))
)
found_device = find_device_for(partition)
if not found_device:
break
@@ -7,7 +7,12 @@ from torch.fx.node import map_arg
from torch.fx.passes.split_module import split_module


__all__ = ['FoldedGraphModule', 'get_unique_attr_name_in_module', 'split_const_subgraphs']
__all__ = [
"FoldedGraphModule",
"get_unique_attr_name_in_module",
"split_const_subgraphs",
]


class FoldedGraphModule(torch.fx.GraphModule):
"""
@@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import torch.fx as fx


def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
"""
Sets a breakpoint in `gm`'s generated python code. It drops into pdb when

@@ -13,18 +14,14 @@ def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
Returns:
the `gm` with breakpoint inserted.
"""

def insert_pdb(body):
return ["import pdb; pdb.set_trace()\n", *body]

with gm.graph.on_generate_code(
make_transformer=lambda cur_transform: (
# new code transformer to register
lambda body: (
insert_pdb(
cur_transform(body) if cur_transform
else body
)
)
lambda body: (insert_pdb(cur_transform(body) if cur_transform else body))
)
):
gm.recompile()
@@ -1,20 +1,21 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from functools import reduce
import torch
import operator
from torch.fx.tensor_type import Dyn, is_consistent, TensorType, is_more_precise
from typing import Callable, Dict
from torch.fx.node import Target, Node
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.conv import Conv2d
from torch.fx.experimental.refinement_types import Equality
import itertools

from torch.fx.experimental.unification import Var # type: ignore[attr-defined]
import operator
from functools import reduce
from typing import Callable, Dict

import sympy

import torch
from torch.fx.experimental.refinement_types import Equality
from torch.fx.experimental.unification import Var # type: ignore[attr-defined]
from torch.fx.node import Node, Target
from torch.fx.tensor_type import Dyn, is_consistent, is_more_precise, TensorType
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.conv import Conv2d


_INFERENCE_RULES: Dict[Target, Callable] = {}
_REFINEMENT_RULES: Dict[Target, Callable] = {}
_RULES: Dict[Target, Callable] = {}
@@ -32,10 +33,12 @@ def expand_to_tensor_dim(t, n):
return TensorType(tuple(dims))
elif isinstance(t, TensorType):
if len(t.__args__) != n:
raise TypeError(f'Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}')
raise TypeError(
f"Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}"
)
return t
else:
raise TypeError(f'Cannot match the type {t}')
raise TypeError(f"Cannot match the type {t}")


def broadcast_types(t1, t2):

@@ -80,32 +83,39 @@ def broadcast_types(t1, t2):
(t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2))
return (t1, t2)
else:
raise TypeError(f'Cannot broadcast types {t1} and {t2}')
raise TypeError(f"Cannot broadcast types {t1} and {t2}")


def register_inference_rule(call_target):
def register(fn):
if call_target in _INFERENCE_RULES:
raise RuntimeError(f'Inference rule already registered for {call_target}!')
raise RuntimeError(f"Inference rule already registered for {call_target}!")
_INFERENCE_RULES[call_target] = fn
return fn

return register


def register_refinement_rule(call_target):
def register(fn):
if call_target in _REFINEMENT_RULES:
raise RuntimeError(f'Refinement rule already registered for {call_target}!')
raise RuntimeError(f"Refinement rule already registered for {call_target}!")
_REFINEMENT_RULES[call_target] = fn
return fn

return register


def register_algebraic_expressions_inference_rule(call_target):
def register(fn):
if call_target in _RULES:
raise RuntimeError(f'Rule already registered for {call_target}!')
raise RuntimeError(f"Rule already registered for {call_target}!")
_RULES[call_target] = fn
return fn

return register


@register_inference_rule(torch.add)
@register_inference_rule(operator.add)
def add_inference_rule(n: Node):
@@ -142,15 +152,15 @@ def add_inference_rule(n: Node):
(new_t1, new_t2) = broadcast_types(t1, t2)

if new_t1 != t1 or new_t2 != t2:
n.meta['broadcast'] = True
n.meta["broadcast"] = True
n.meta[str(n.args[0])] = new_t1
n.meta[str(n.args[1])] = new_t2

else:
n.meta['broadcast'] = False
n.meta["broadcast"] = False

new_t1 = t1 if not n.meta['broadcast'] else new_t1
new_t2 = t2 if not n.meta['broadcast'] else new_t2
new_t1 = t1 if not n.meta["broadcast"] else new_t1
new_t2 = t2 if not n.meta["broadcast"] else new_t2

# we check for consistency between the new types
if is_consistent(new_t1, new_t2):

@@ -164,8 +174,11 @@ def add_inference_rule(n: Node):
n.type = new_t1
return n.type
else:
raise TypeError(f'Cannot add arguments {n.args[0]} ({ n.args[0].type}) and {n.args[1]} ({ n.args[1].type}) in node {n}.'
f' Types should match ')
raise TypeError(
f"Cannot add arguments {n.args[0]} ({n.args[0].type}) and {n.args[1]} ({n.args[1].type}) in node {n}."
f" Types should match "
)


@register_inference_rule(getattr)
def get_attr_inference_rule(n: Node, traced):

@@ -185,6 +198,7 @@ def get_attr_inference_rule(n: Node, traced):
# TODO. We leave it like this till we add a type to represent tensor sizes
return n.type


@register_inference_rule(torch.transpose)
def transpose_inference_rule(n: Node):
"""

@@ -211,9 +225,13 @@ def transpose_inference_rule(n: Node):
n.type = get_greatest_upper_bound(n.type, final)
return n.type
else:
raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')
raise TypeError(
f"Cannot transpose {dim1} and {dim2} in type {t} for node {n}"
)
else:
raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')
raise TypeError(
f"Cannot transpose {dim1} and {dim2} in type {t} for node {n}"
)


@register_inference_rule(torch.reshape)
@@ -251,9 +269,10 @@ def reshape_inference_rule(n: Node):
n.type = t2_type
return t2_type
else:
raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
raise TypeError(f"Cannot reshape in node {n} from {t1} to {t2_type}")
else:
raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
raise TypeError(f"Cannot reshape in node {n} from {t1} to {t2_type}")


@register_inference_rule(BatchNorm2d)
def bn2d_inference_rule(n: Node, module_instance):

@@ -274,10 +293,11 @@ def bn2d_inference_rule(n: Node, module_instance):
# we check the conditions on the incoming argument
# and any existing annotation
# we also check for consistency between both annotations
if is_consistent(arg_type.__args__[1], module_instance.num_features) and \
is_consistent(n.type.__args__[1], module_instance.num_features) and \
is_consistent(arg_type, n.type):

if (
is_consistent(arg_type.__args__[1], module_instance.num_features)
and is_consistent(n.type.__args__[1], module_instance.num_features)
and is_consistent(arg_type, n.type)
):
# we choose the more precise type
# to be the node type
# so if an incoming argument has more type information

@@ -285,21 +305,35 @@ def bn2d_inference_rule(n: Node, module_instance):
n.type = get_greatest_upper_bound(arg_type, n.type)
return n.type
else:
raise TypeError(f'Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}')
raise TypeError(
f"Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}"
)


def calculate_out_dimension(d_in, module_instance, index):
"""
For calculating h_in and w_out according to the conv2D documentation
"""
padding = (module_instance.padding, module_instance.padding) \
if isinstance(module_instance.padding, int) else module_instance.padding
kernel_size = (module_instance.kernel_size, module_instance.kernel_size) \
if isinstance(module_instance.kernel_size, int) else module_instance.kernel_size
stride = (module_instance.stride, module_instance.stride) \
if isinstance(module_instance.stride, int) else module_instance.stride
dilation = (module_instance.dilation, module_instance.dilation) \
if isinstance(module_instance.dilation, int) else module_instance.dilation
padding = (
(module_instance.padding, module_instance.padding)
if isinstance(module_instance.padding, int)
else module_instance.padding
)
kernel_size = (
(module_instance.kernel_size, module_instance.kernel_size)
if isinstance(module_instance.kernel_size, int)
else module_instance.kernel_size
)
stride = (
(module_instance.stride, module_instance.stride)
if isinstance(module_instance.stride, int)
else module_instance.stride
)
dilation = (
(module_instance.dilation, module_instance.dilation)
if isinstance(module_instance.dilation, int)
else module_instance.dilation
)

DIMENSION_TYPES = (int, sympy.Symbol)
@@ -307,14 +341,14 @@ def calculate_out_dimension(d_in, module_instance, index):
return Dyn

elif isinstance(d_in, DIMENSION_TYPES):
n = d_in + 2 * padding[index] - \
dilation[index] * \
(kernel_size[index] - 1) - 1
n = d_in + 2 * padding[index] - dilation[index] * (kernel_size[index] - 1) - 1

return (n // stride[0]) + 1

else:
raise TypeError(f'{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}')
raise TypeError(
f"{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}"
)


def get_greatest_upper_bound(type1, type2):

@@ -327,8 +361,11 @@ def get_greatest_upper_bound(type1, type2):
return type1
elif isinstance(type1, TensorType) and isinstance(type2, TensorType):
if not is_consistent(type1, type2):
raise TypeError(f'Inconsistent types {type1}, {type2}')
gub = [t1 if is_more_precise(t1, t2) else t2 for (t1, t2) in zip(type1.__args__, type2.__args__)]
raise TypeError(f"Inconsistent types {type1}, {type2}")
gub = [
t1 if is_more_precise(t1, t2) else t2
for (t1, t2) in zip(type1.__args__, type2.__args__)
]
return TensorType(tuple(gub))
@@ -352,12 +389,16 @@ def conv2d_inference_rule(n: Node, module_instance):
h_in = arg_type.__args__[2]
h_out = calculate_out_dimension(h_in, module_instance, 0)
w_out = calculate_out_dimension(w_in, module_instance, 1)
new_type = TensorType((arg_type.__args__[0], module_instance.out_channels, h_out, w_out))
new_type = TensorType(
(arg_type.__args__[0], module_instance.out_channels, h_out, w_out)
)
gub = get_greatest_upper_bound(new_type, curr_node_type)
n.type = gub
return n.type
else:
raise TypeError(f'Cannot apply {module_instance} with input type { arg_type} and existing type {n.type} on {n}')
raise TypeError(
f"Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}"
)


@register_inference_rule(torch.nn.ReLU)

@@ -393,7 +434,7 @@ def maxpool2d_check(typ, module_instance):
return TensorType(tuple(new_type_list))

else:
raise TypeError(f'Wrong size {typ} for {module_instance}')
raise TypeError(f"Wrong size {typ} for {module_instance}")


@register_inference_rule(torch.nn.MaxPool2d)

@@ -417,7 +458,6 @@ def maxpool2d_inference_rule(n: Node, module_instance):
return n.type



def linear_check(tensor_type, module_instance):
"""
Checks that an input tensor type satisfies the conditions for linear operation

@@ -429,9 +469,11 @@ def linear_check(tensor_type, module_instance):
new_type_args[-1] = module_instance.out_features
return TensorType(tuple(new_type_args))
else:
raise TypeError(f'Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}')
raise TypeError(
f"Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}"
)
else:
raise TypeError(f'Type {tensor_type} must have rank 2 or more.')
raise TypeError(f"Type {tensor_type} must have rank 2 or more.")


@register_inference_rule(torch.nn.Linear)

@@ -469,7 +511,8 @@ def adaptiveavgpool2d_check(tensor_type, module_instance):
return TensorType(tuple(new_type_list))

else:
raise TypeError(f'Tensor ranks must be 3 or 4. Got {tensor_type}')
raise TypeError(f"Tensor ranks must be 3 or 4. Got {tensor_type}")


@register_inference_rule(torch.nn.AdaptiveAvgPool2d)
def adaptiveavgpool2d_inference_rule(n: Node, module_instance):

@@ -485,6 +528,7 @@ def adaptiveavgpool2d_inference_rule(n: Node, module_instance):
n.type = get_greatest_upper_bound(n.type, output_type)
return n.type


def flatten_check(tensor_type, start_dim, end_dim):
l = len(tensor_type.__args__)
@@ -503,7 +547,10 @@ def flatten_check(tensor_type, start_dim, end_dim):
new_type_list = lhs + mid + rhs
return TensorType(tuple(new_type_list))
else:
raise TypeError(f'Incompatible dimensions {start_dim}, {end_dim - 1} in type {tensor_type}')
raise TypeError(
f"Incompatible dimensions {start_dim}, {end_dim - 1} in type {tensor_type}"
)


@register_inference_rule(torch.flatten)
def flatten_inference_rule(n: Node):

@@ -530,10 +577,11 @@ def flatten_inference_rule(n: Node):

if isinstance(n.args[0].type, TensorType):
output_type = flatten_check(n.args[0].type, start_dim, end_dim)
n.type = get_greatest_upper_bound(output_type , n.type)
n.type = get_greatest_upper_bound(output_type, n.type)

return n.type


class GraphTypeChecker:
def __init__(self, env, traced):
self.env = env

@@ -571,16 +619,16 @@ class GraphTypeChecker:
if n.type is None:
n.type = Dyn

if n.op == 'placeholder':
if n.op == "placeholder":
return n.type

elif n.op == 'get_attr':
elif n.op == "get_attr":
t = get_parameter(self.traced, n.target) # type: ignore[arg-type]
if isinstance(t.data, torch.Tensor):
n.type = TensorType(t.data.shape)
return n.type

elif n.op == 'call_function':
elif n.op == "call_function":
if n.target == getattr:
assert getattr in _INFERENCE_RULES
return _INFERENCE_RULES[n.target](n, self.traced)

@@ -588,18 +636,24 @@ class GraphTypeChecker:
elif n.target in _INFERENCE_RULES:
return _INFERENCE_RULES[n.target](n)
else:
raise RuntimeError(f'No inference rule registered for target {n.target}!')
raise RuntimeError(
f"No inference rule registered for target {n.target}!"
)

elif n.op == 'call_module':
elif n.op == "call_module":
module_instance = self.traced.get_submodule(n.target)
if type(module_instance) in _INFERENCE_RULES:
return _INFERENCE_RULES[type(module_instance)](n, module_instance)
else:
raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!')
raise RuntimeError(
f"No inference rule registered for class {type(module_instance)}!"
)

elif n.op == "output":

elif n.op == 'output':
def get_node_type(a):
return a.type

n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
return n.type
@@ -634,6 +688,7 @@ def linear_refinement_rule(n: Node):
res = [Equality(arg_type.__args__[0], n.type.__args__[0])]
return res


@register_refinement_rule(BatchNorm2d)
@register_refinement_rule(torch.nn.ReLU)
def all_eq(n: Node):

@@ -688,7 +743,11 @@ def element_wise_eq(n: Node):
if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
arg_type1 = n.args[0].type
arg_type2 = n.args[1].type
if isinstance(arg_type1, TensorType) and isinstance(arg_type2, TensorType) and isinstance(n.type, TensorType):
if (
isinstance(arg_type1, TensorType)
and isinstance(arg_type2, TensorType)
and isinstance(n.type, TensorType)
):
args1, args2 = broadcast_types(arg_type1, arg_type2)
# by this point, we know that args1 and args2 are the same size.
a1 = args1.__args__

@@ -757,12 +816,14 @@ def conv_rule(n: Node, module_instance):
n.type = new_type
return new_type


class Refine:
"""
Symbolic shape inference.
Generates constraints over type variables.
Currently all constraints are equality constraints.
"""

def __init__(self, traced):
self.constraints = []
self.traced = traced

@@ -805,7 +866,6 @@ class Refine:
else:
return typ


def convert_to_sympy_symbols(self, typ):
"""
Replace all unknown types with fresh type variables.
@@ -835,22 +895,24 @@ class Refine:

n.type = self.replace_dyn_with_fresh_var(n.type)

if n.op == 'call_function':
if n.op == "call_function":
if n.target in _REFINEMENT_RULES:
self.constraints += _REFINEMENT_RULES[n.target](n)
else:
pass

if n.op == 'call_module':
if n.op == "call_module":
module_instance = self.traced.get_submodule(n.target)
if type(module_instance) in _REFINEMENT_RULES:
self.constraints += _REFINEMENT_RULES[type(module_instance)](n)
else:
pass

if n.op == 'output':
if n.op == "output":

def get_node_type(a):
return a.type

n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
return n.type

@@ -859,28 +921,31 @@ class Refine:

def infer_symbolic_relations(self, n: Node):
n.type = self.convert_to_sympy_symbols(n.type)
if n.op == 'call_function':
if n.op == "call_function":
if n.target in _RULES:
return _RULES[n.target](n)
else:
pass

if n.op == 'call_module':
if n.op == "call_module":
module_instance = self.traced.get_submodule(n.target)
if type(module_instance) in _RULES:
return _RULES[type(module_instance)](n, module_instance)
else:
pass

if n.op == 'output':
if n.op == "output":

def get_node_type(a):
return a.type

n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
return n.type

else:
pass


def get_parameter(traced, target: str):
"""
Returns the parameter given by ``target`` if it exists,
@@ -1,14 +1,13 @@
# mypy: allow-untyped-defs
import torch

from torch.fx.node import Node
from torch.fx._symbolic_trace import symbolic_trace
from torch.fx.passes.tools_common import legalize_graph
import itertools
import operator

from typing import Dict, List, Tuple

import torch
from torch.fx._symbolic_trace import symbolic_trace
from torch.fx.node import Node
from torch.fx.passes.tools_common import legalize_graph


def split_result_tensors(
result: torch.Tensor, inputs: List[torch.Tensor]

@@ -146,7 +145,14 @@ def merge_matmul(in_mod: torch.nn.Module):
# Multiply the concatenated LHS operands with the one RHS. This will produce
# the same results as all the individual matmuls involving rhs in the original graph,
# but they will all be concatenated together.
merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {})
merge_mm = gm.graph.call_function(
torch.matmul,
(
merge_mm_cat,
rhs,
),
{},
)

# Split the result of the merged matmul using the shapes of the LHS operands
# to ascertain how large each chunk should be.
@@ -1,14 +1,15 @@
# mypy: allow-untyped-defs
import torch
import torch.fx
import warnings
import functools
import builtins

import functools
import warnings
from typing import Any, Callable, Dict, Optional, Union

import torch
import torch.fx


def embedding_override(self, input):
return torch.empty(*input.shape, self.weight.shape[-1], device='meta')
return torch.empty(*input.shape, self.weight.shape[-1], device="meta")


def nn_layernorm_override(self, input):

@@ -24,21 +25,22 @@ def torch_nn_relu_override(self, x):


def functional_relu_override(x, inplace=False):
assert not inplace, 'dont support inplace functional.relu for metatensor analysis'
assert not inplace, "dont support inplace functional.relu for metatensor analysis"
return x


def torch_where_override(condition, x, y):
# torch.where returns the broadcasted tensor of condition, x, and y,
# so hack it by using addition
return condition.to(device='meta') + x.to(device='meta') + y.to(device='meta')
return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")


def torch_abs_override(input, *, out=None):
assert out is None, 'Dont support in-place abs for MetaTensor analysis'
assert out is None, "Dont support in-place abs for MetaTensor analysis"
return input

manual_meta_overrides : Dict[Callable, Callable] = {

manual_meta_overrides: Dict[Callable, Callable] = {
torch.nn.Embedding: embedding_override,
torch.nn.LayerNorm: nn_layernorm_override,
torch.relu: torch_relu_override,

@@ -48,6 +50,7 @@ manual_meta_overrides : Dict[Callable, Callable] = {
torch.abs: torch_abs_override,
}


def gen_constructor_wrapper(target):
@functools.wraps(target)
def wrapper(*args, **kwargs):
@@ -57,57 +60,66 @@ def gen_constructor_wrapper(target):
if isinstance(v, torch.fx.Proxy):
nonlocal proxy
proxy = v

torch.fx.node.map_aggregate(args, check_has_proxy)
torch.fx.node.map_aggregate(kwargs, check_has_proxy)

if proxy is not None:
return proxy.tracer.create_proxy('call_function', target, args, kwargs)
return proxy.tracer.create_proxy("call_function", target, args, kwargs)
else:
return target(*args, **kwargs)

return wrapper, target


class MetaProxy(torch.fx.Proxy):
def install_tensor_meta(self, tensor_meta):
self._tensor_meta = tensor_meta

def size(self, dim=None):
if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
if hasattr(self, "_tensor_meta") and self._tensor_meta is not None:
return self._tensor_meta.size(*[dim] if dim else [])
return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim else (self,), {})
return self.tracer.create_proxy(
"call_method", "size", (self, dim) if dim else (self,), {}
)

def dim(self):
if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
if hasattr(self, "_tensor_meta") and self._tensor_meta is not None:
return self._tensor_meta.dim()
return self.tracer.create_proxy('call_method', 'dim', (self,), {})
return self.tracer.create_proxy("call_method", "dim", (self,), {})

@property
def shape(self):
if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
if hasattr(self, "_tensor_meta") and self._tensor_meta is not None:
return self._tensor_meta.shape
return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'shape'), {})
return self.tracer.create_proxy(
"call_function", builtins.getattr, (self, "shape"), {}
)

@property
def dtype(self):
if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
if hasattr(self, "_tensor_meta") and self._tensor_meta is not None:
return self._tensor_meta.dtype
return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'dtype'), {})
return self.tracer.create_proxy(
"call_function", builtins.getattr, (self, "dtype"), {}
)

@property
def device(self):
# Hack so we can track when devices are used. During meta-tensor propagation,
# replace these values with a constant 'meta'
return MetaDeviceAttribute(self, 'device')
return MetaDeviceAttribute(self, "device")

def __getattr__(self, k):
if k == '_tensor_meta':
if k == "_tensor_meta":
return self.__getattribute__(k)
# note: not added to the graph yet, if this is a method call
# we peephole optimize to the method invocation
return MetaAttribute(self, k)


class MetaAttribute(MetaProxy):
def __init__(self, root, attr: str):

self.root = root
self.attr = attr
self.tracer = root.tracer
@@ -118,33 +130,51 @@ class MetaAttribute(MetaProxy):
# the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call
if self._node is None:
self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
self._node = self.tracer.create_proxy(
"call_function", getattr, (self.root, self.attr), {}
).node
return self._node

def __call__(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
return self.tracer.create_proxy(
"call_method", self.attr, (self.root,) + args, kwargs
)


class MetaDeviceAttribute(MetaAttribute):
pass


def proxys_to_metas(v):
if isinstance(v, MetaDeviceAttribute):
return 'meta'
return "meta"
if isinstance(v, torch.fx.Proxy):
assert isinstance(v, MetaProxy), f'Expected MetaProxy but got {type(v)}'
assert hasattr(v, '_tensor_meta'), 'MetaProxy does not have an associated meta'
assert isinstance(v, MetaProxy), f"Expected MetaProxy but got {type(v)}"
assert hasattr(v, "_tensor_meta"), "MetaProxy does not have an associated meta"
return v._tensor_meta
return v


class MetaTracer(torch.fx.Tracer):
allow_insert_stateless_mods : bool = True
allow_insert_stateless_mods: bool = True

_TORCH_METHODS_TO_PATCH = ['arange', 'zeros', 'ones', 'full_like', 'eye']
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"]

def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
def create_proxy(
self,
kind,
target,
args,
kwargs,
name=None,
type_expr=None,
proxy_factory_fn=None,
):
rv = super().create_proxy(
kind, target, args, kwargs, name, type_expr, proxy_factory_fn
)

if kind == 'placeholder' and target in self.meta_args:
if kind == "placeholder" and target in self.meta_args:
rv.install_tensor_meta(self.meta_args[target])
return rv
@@ -154,54 +184,57 @@ class MetaTracer(torch.fx.Tracer):
# _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
# this will break and you will likely see issues where we cannot infer
# the size of the output.
if 'device' in kwargs:
kwargs['device'] = 'meta'
if "device" in kwargs:
kwargs["device"] = "meta"

try:
args_metas = torch.fx.node.map_aggregate(args, proxys_to_metas)
kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas)

if kind == 'call_function':
if kind == "call_function":
meta_target = manual_meta_overrides.get(target, target)
meta_out = meta_target(*args_metas, **kwargs_metas)
elif kind == 'call_method':
meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas) # type: ignore[index]
elif kind == 'call_module':
assert hasattr(self, 'orig_forward')
elif kind == "call_method":
meta_target = getattr(args_metas[0], target) # type: ignore[index]
meta_out = meta_target(*args_metas[1:], **kwargs_metas) # type: ignore[index]
elif kind == "call_module":
assert hasattr(self, "orig_forward")
self._disable_module_getattr = True
try:
mod = self.root.get_submodule(target)
mod_type = type(mod)
if mod_type in manual_meta_overrides:
meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas) # type: ignore[misc, arg-type]
meta_out = manual_meta_overrides[mod_type](
mod, *args_metas, **kwargs_metas
) # type: ignore[misc, arg-type]
else:
meta_out = self.orig_forward(*args_metas, **kwargs_metas)
finally:
self._disable_module_getattr = False
elif kind == 'get_attr':
elif kind == "get_attr":
self._disable_module_getattr = True
try:
attr_itr = self.root
atoms = target.split('.')
atoms = target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
assert isinstance(attr_itr, torch.Tensor)
meta_out = attr_itr.to(device='meta')
meta_out = attr_itr.to(device="meta")
finally:
self._disable_module_getattr = False
else:
return rv

# TODO
assert isinstance(rv, torch.fx.Proxy), 'Dont support composite output yet'
assert isinstance(rv, torch.fx.Proxy), "Dont support composite output yet"
rv.install_tensor_meta(meta_out)
except Exception as e:
warnings.warn(f'Could not compute metadata for {kind} target {target}: {e}')
warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")

return rv

def getattr(self, attr, attr_val, parameter_proxy_cache):
if getattr(self, '_disable_module_getattr', False):
if getattr(self, "_disable_module_getattr", False):
return attr_val
else:
return super().getattr(attr, attr_val, parameter_proxy_cache)
@ -228,7 +261,11 @@ class MetaTracer(torch.fx.Tracer):
|
||||
try:
|
||||
return super().path_of_module(mod)
|
||||
except NameError:
|
||||
if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
|
||||
if (
|
||||
self.allow_insert_stateless_mods
|
||||
and len(list(mod.parameters())) == 0
|
||||
and len(list(mod.buffers())) == 0
|
||||
):
|
||||
path = self._insert_module_as_submodule(mod)
|
||||
self.prev_module = path
|
||||
return path
|
||||
@ -237,12 +274,13 @@ class MetaTracer(torch.fx.Tracer):
|
||||
def proxy(self, node):
|
||||
return MetaProxy(node, self)
|
||||
|
||||
def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None): # type: ignore[override]
|
||||
def trace(self, root, meta_args: Dict[str, torch.Tensor], concrete_args=None): # type: ignore[override]
|
||||
assert isinstance(meta_args, dict)
|
||||
self.meta_args = meta_args
|
||||
|
||||
self.patched_torch_methods = {
|
||||
target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
|
||||
target: gen_constructor_wrapper(getattr(torch, target))
|
||||
for target in self._TORCH_METHODS_TO_PATCH
|
||||
}
|
||||
self.orig_fns = set()
|
||||
|
||||
@ -252,18 +290,22 @@ class MetaTracer(torch.fx.Tracer):
|
||||
|
||||
try:
|
||||
graph = super().trace(root, concrete_args)
|
||||
graph._tracer_extras = {'meta_args': meta_args}
|
||||
graph._tracer_extras = {"meta_args": meta_args}
|
||||
return graph
|
||||
finally:
|
||||
for name, (_, orig) in self.patched_torch_methods.items():
|
||||
setattr(torch, name, orig)
|
||||
|
||||
|
||||
def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]],
|
||||
meta_args : Optional[Dict[str, torch.Tensor]] = None,
|
||||
concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule:
|
||||
def symbolic_trace(
|
||||
root: Union[torch.nn.Module, Callable[..., Any]],
|
||||
meta_args: Optional[Dict[str, torch.Tensor]] = None,
|
||||
concrete_args: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.fx.GraphModule:
|
||||
tracer = MetaTracer()
|
||||
graph = tracer.trace(root, meta_args, concrete_args) # type: ignore[arg-type]
|
||||
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
|
||||
name = (
|
||||
root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
|
||||
)
|
||||
gm = torch.fx.GraphModule(tracer.root, graph, name)
|
||||
return gm
|
||||
|
||||
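For orientation, a minimal usage sketch of the meta tracer above (the toy module, shapes, and placeholder name are illustrative assumptions, not part of this commit):

```python
# Sketch: trace a module with meta tensors so shape metadata propagates
# through the graph without allocating real storage.
import torch
from torch.fx.experimental.meta_tracer import symbolic_trace


class SmallNet(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin = torch.nn.Linear(8, 4)

    def forward(self, x):
        return torch.relu(self.lin(x))


# meta_args maps placeholder names to meta tensors carrying only shape/dtype.
gm = symbolic_trace(SmallNet(), meta_args={"x": torch.empty(2, 8, device="meta")})
print(gm.graph)  # nodes carry tensor metadata installed during tracing
```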
diff --git a/torch/fx/experimental/migrate_gradual_types/constraint.py b/torch/fx/experimental/migrate_gradual_types/constraint.py
@@ -1,7 +1,16 @@
 # mypy: allow-untyped-defs
-from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_sub, op_mul, op_div, \
-    op_mod, op_gt, op_lt, op_neq, op_eq
-from torch.fx.tensor_type import TensorType, Dyn
+from torch.fx.experimental.migrate_gradual_types.operation import (
+    op_add,
+    op_div,
+    op_eq,
+    op_gt,
+    op_lt,
+    op_mod,
+    op_mul,
+    op_neq,
+    op_sub,
+)
+from torch.fx.tensor_type import Dyn, TensorType


 class Constraint:
@@ -22,7 +31,7 @@ class Conj(Constraint):
         return False

     def __repr__(self):
-        return f'And({self.conjucts})'
+        return f"And({self.conjucts})"


 class Disj(Constraint):
@@ -34,12 +43,14 @@ class Disj(Constraint):

     def __eq__(self, other):
         if isinstance(other, Disj):
-            return self.disjuncts == other.disjuncts and self.disjuncts == other.disjuncts
+            return (
+                self.disjuncts == other.disjuncts and self.disjuncts == other.disjuncts
+            )
         else:
             return False

     def __repr__(self):
-        return f'Or({self.disjuncts})'
+        return f"Or({self.disjuncts})"


 class Prod(Constraint):
@@ -56,13 +67,14 @@ class Prod(Constraint):
         return False

     def __repr__(self):
-        return f'Product({self.products})'
+        return f"Product({self.products})"


 class T(Constraint):
     """
     True
     """
+
     def __init__(self) -> None:
         pass

@@ -70,12 +82,14 @@ class T(Constraint):
         return isinstance(other, T)

     def __repr__(self):
-        return 'True'
+        return "True"


 class F(Constraint):
     """
     False
     """
+
     def __init__(self) -> None:
         pass

@@ -83,13 +97,14 @@ class F(Constraint):
         return isinstance(other, F)

     def __repr__(self):
-        return 'False'
+        return "False"


 class BinaryConstraint(Constraint):
     """
     Represents all binary operations
     """
+
     def __init__(self, lhs, rhs, op):
         """
         :param lhs: lhs of the constraint
@@ -102,21 +117,25 @@ class BinaryConstraint(Constraint):

     def __eq__(self, other):
         if isinstance(other, BinaryConstraint):
-            return self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op
+            return (
+                self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op
+            )
         else:
             return False

     def __repr__(self):
-        return f'({self.lhs} {self.op} {self.rhs})'
+        return f"({self.lhs} {self.op} {self.rhs})"


 class BinConstraintT(BinaryConstraint):
     """
     Binary constraints about tensors
     """
+
     def __init__(self, lhs, rhs, op):
-        assert (isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn) and \
-            (isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn)
+        assert (isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn) and (
+            isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn
+        )
         super().__init__(lhs, rhs, op)

     def __eq__(self, other):
@@ -127,6 +146,7 @@ class BinConstraintD(BinaryConstraint):
     """
     Binary constraints about dimensions
     """
+
     def __init__(self, lhs, rhs, op):
         assert is_algebraic_expression(lhs) or is_dim(lhs) or is_bool_expr(lhs)
         assert is_algebraic_expression(rhs) or is_dim(rhs) or is_bool_expr(rhs)
@@ -137,11 +157,11 @@ class BinConstraintD(BinaryConstraint):
         return super().__eq__(other)


 class TGreatestUpperBound(Constraint):
     """
     Greatest Upper bound for tensors with dynamic type
     """
+
     def __init__(self, res, rhs1, rhs2):
         """
         :param res: tensor variable that stores the result of the output
@@ -153,11 +173,15 @@ class TGreatestUpperBound(Constraint):
         self.rhs2 = rhs2

     def __repr__(self):
-        return f'{self.res} = {self.rhs1}\u2294*{self.rhs2}'
+        return f"{self.res} = {self.rhs1}\u2294*{self.rhs2}"

     def __eq__(self, other):
         if isinstance(other, TGreatestUpperBound):
-            return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2
+            return (
+                self.res == other.res
+                and self.rhs1 == other.rhs1
+                and self.rhs2 == other.rhs2
+            )
         else:
             return False

@@ -166,6 +190,7 @@ class DGreatestUpperBound(Constraint):
     """
     Greatest Upper bound for dimensions
     """
+
     def __init__(self, res, rhs1, rhs2):
         """
         :param res: Dimension variable to store the result
@@ -181,11 +206,15 @@ class DGreatestUpperBound(Constraint):
         self.rhs2 = rhs2

     def __repr__(self):
-        return f'{self.res} = {self.rhs1}\u2294{self.rhs2}'
+        return f"{self.res} = {self.rhs1}\u2294{self.rhs2}"

     def __eq__(self, other):
         if isinstance(other, DGreatestUpperBound):
-            return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2
+            return (
+                self.res == other.res
+                and self.rhs1 == other.rhs1
+                and self.rhs2 == other.rhs2
+            )
         else:
             return False

@@ -194,6 +223,7 @@ class CanReshape(Constraint):
     """
     can_reshape constraint
     """
+
     def __init__(self, src, target):
         """
         :param src: tensor variable
@@ -203,7 +233,7 @@ class CanReshape(Constraint):
         self.target = target

     def __repr__(self):
-        return f'can-reshape({self.src}, {self.target})'
+        return f"can-reshape({self.src}, {self.target})"

     def __eq__(self, other):
         if isinstance(other, CanReshape):
@@ -213,7 +243,6 @@ class CanReshape(Constraint):


 class IndexSelect(Constraint):
-
     def __init__(self, tensor_size, input_var, dim_replace, index, output):
         """
         Args:
@@ -235,26 +264,28 @@ class IndexSelect(Constraint):
         self.output = output

     def __repr__(self):
-
-        return f' {self.output} = ' \
-            f'IndexSelect({self.input_var}, ' \
-            f'tensor_size: {self.tensor_size}, ' \
-            f'{self.dim_replace}, ' \
-            f'{self.index})'
+        return (
+            f" {self.output} = "
+            f"IndexSelect({self.input_var}, "
+            f"tensor_size: {self.tensor_size}, "
+            f"{self.dim_replace}, "
+            f"{self.index})"
+        )

     def __eq__(self, other):
         if isinstance(other, IndexSelect):
-            return self.tensor_size == other.tensor_size and \
-                self.dim_replace == other.dim_replace and \
-                self.index == other.index and \
-                self.output == other.output and \
-                self.input_var == other.input_var
+            return (
+                self.tensor_size == other.tensor_size
+                and self.dim_replace == other.dim_replace
+                and self.index == other.index
+                and self.output == other.output
+                and self.input_var == other.input_var
+            )
         else:
             return False


 class Transpose(Constraint):
-
     def __init__(self, tensor_size, input_var, index1, index2, output):
         """
         Args:
@@ -276,26 +307,28 @@ class Transpose(Constraint):
         self.output = output

     def __repr__(self):
-
-        return f' {self.output} = ' \
-            f'Transpose({self.input_var}, ' \
-            f'tensor_size: {self.tensor_size}, ' \
-            f'{self.index1}, ' \
-            f'{self.index2})'
+        return (
+            f" {self.output} = "
+            f"Transpose({self.input_var}, "
+            f"tensor_size: {self.tensor_size}, "
+            f"{self.index1}, "
+            f"{self.index2})"
+        )

     def __eq__(self, other):
         if isinstance(other, Transpose):
-            return self.tensor_size == other.tensor_size and \
-                self.index1 == other.index1 and \
-                self.index2 == other.index2 and \
-                self.output == other.output and \
-                self.input_var == other.input_var
+            return (
+                self.tensor_size == other.tensor_size
+                and self.index1 == other.index1
+                and self.index2 == other.index2
+                and self.output == other.output
+                and self.input_var == other.input_var
+            )
         else:
             return False


 class GetItem(Constraint):
-
     def __init__(self, tensor_size, index, res, input_var):
         """
         Constraint for getting item given a tensor size
@@ -312,19 +345,21 @@ class GetItem(Constraint):
         self.input_var = input_var

     def __repr__(self):
-        return f' {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})'
+        return f" {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})"

     def __eq__(self, other):
         if isinstance(other, GetItem):
-            return self.res == other.res and \
-                self.tensor_size == other.tensor_size and \
-                self.index == other.index and \
-                self.input_var == other.input_var
+            return (
+                self.res == other.res
+                and self.tensor_size == other.tensor_size
+                and self.index == other.index
+                and self.input_var == other.input_var
+            )
         else:
             return False

-class GetItemTensor(Constraint):
-
+
+class GetItemTensor(Constraint):
     def __init__(self, tensor_size, index_tuple, res, input_var):
         """
         Constraint for getting item given a tensor size
@@ -343,20 +378,32 @@ class GetItemTensor(Constraint):
         self.input_var = input_var

     def __repr__(self):
-        return f' {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})'
+        return f" {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})"

     def __eq__(self, other):
         if isinstance(other, GetItemTensor):
-            return self.res == other.res and \
-                self.tensor_size == other.tensor_size and \
-                self.index_tuple == other.index_tuple and \
-                self.input_var == other.input_var
+            return (
+                self.res == other.res
+                and self.tensor_size == other.tensor_size
+                and self.index_tuple == other.index_tuple
+                and self.input_var == other.input_var
+            )
         else:
             return False

-class CalcConv(Constraint):
-
-    def __init__(self, conv_result, input_var, c_out, kernel, padding, stride, dilation, matching_constraint_vars):
+
+class CalcConv(Constraint):
+    def __init__(
+        self,
+        conv_result,
+        input_var,
+        c_out,
+        kernel,
+        padding,
+        stride,
+        dilation,
+        matching_constraint_vars,
+    ):
         """
         :param conv_result: the convolution result
         :param input_var: input to convolution
@@ -373,25 +420,41 @@ class CalcConv(Constraint):
         self.matching_constraint = matching_constraint_vars

     def __repr__(self):
-        return f'{self.conv_result} =' \
-            f' calc-conv({self.input_var},' \
-            f' {self.c_out}, {self.kernel}, ' \
-            f'{self.padding}, {self.stride},' \
-            f' {self.dilation})'
+        return (
+            f"{self.conv_result} ="
+            f" calc-conv({self.input_var},"
+            f" {self.c_out}, {self.kernel}, "
+            f"{self.padding}, {self.stride},"
+            f" {self.dilation})"
+        )

     def __eq__(self, other):
         if isinstance(other, CalcConv):
-            return self.conv_result == other.conv_result and self.input_var == other.input_var and \
-                self.c_out == other.c_out and self.kernel == other.kernel and self.padding == other.padding \
-                and self.stride == other.stride and self.dilation == other.dilation \
+            return (
+                self.conv_result == other.conv_result
+                and self.input_var == other.input_var
+                and self.c_out == other.c_out
+                and self.kernel == other.kernel
+                and self.padding == other.padding
+                and self.stride == other.stride
+                and self.dilation == other.dilation
                 and self.matching_constraint == other.matching_constraint
+            )
         else:
             return False


 class CalcMaxPool(Constraint):
-
-    def __init__(self, maxpool_result, input_var, kernel, padding, stride, dilation, matching_constraint_vars):
+    def __init__(
+        self,
+        maxpool_result,
+        input_var,
+        kernel,
+        padding,
+        stride,
+        dilation,
+        matching_constraint_vars,
+    ):
         """
         :param maxpool_result: the result of maxpool
         :param input_var: input to maxpool
@@ -406,18 +469,25 @@ class CalcMaxPool(Constraint):
         self.matching_constraint = matching_constraint_vars

     def __repr__(self):
-        return f'{self.maxpool_result} =' \
-            f' calc-maxpool({self.input_var},' \
-            f' {self.kernel}, ' \
-            f'{self.padding}, {self.stride},' \
-            f' {self.dilation})'
+        return (
+            f"{self.maxpool_result} ="
+            f" calc-maxpool({self.input_var},"
+            f" {self.kernel}, "
+            f"{self.padding}, {self.stride},"
+            f" {self.dilation})"
+        )

     def __eq__(self, other):
         if isinstance(other, CalcMaxPool):
-            return self.maxpool_result == other.maxpool_result and self.input_var == other.input_var \
-                and self.kernel == other.kernel and self.padding == other.padding \
-                and self.stride == other.stride and self.dilation == other.dilation \
+            return (
+                self.maxpool_result == other.maxpool_result
+                and self.input_var == other.input_var
+                and self.kernel == other.kernel
+                and self.padding == other.padding
+                and self.stride == other.stride
+                and self.dilation == other.dilation
                 and self.matching_constraint == other.matching_constraint
+            )
         else:
             return False

@@ -437,21 +507,28 @@ class ApplyBroadcasting(Constraint):

     def __eq__(self, other):
         if isinstance(other, ApplyBroadcasting):
-            return self.res1 == other.res1 \
-                and self.res2 == other.res2 \
-                and self.input1 == other.input1 \
+            return (
+                self.res1 == other.res1
+                and self.res2 == other.res2
+                and self.input1 == other.input1
                 and self.input2 == other.input2
+            )
         else:
             return False

     def __repr__(self):
-        return f'{self.res1}, {self.res2} ='f' apply-broadcasting({self.input1},' f' {self.input2})'
+        return (
+            f"{self.res1}, {self.res2} ="
+            f" apply-broadcasting({self.input1},"
+            f" {self.input2})"
+        )


 class CalcProduct(Constraint):
     """
     Given correct dimensions, calculate the product for flatten accounting for Dyn
     """
+
     def __init__(self, start, end, flattened, dims_to_flatten):
         """
         :param start: start index
@@ -471,20 +548,25 @@ class CalcProduct(Constraint):

     def __eq__(self, other):
         if isinstance(other, CalcProduct):
-            return self.start == other.start and self.end == other.end and \
-                self.dims_to_flatten == other.dims_to_flatten and self.flattened == other.flattened
+            return (
+                self.start == other.start
+                and self.end == other.end
+                and self.dims_to_flatten == other.dims_to_flatten
+                and self.flattened == other.flattened
+            )

         else:
             return False

     def __repr__(self):
-        return f'{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})'
+        return f"{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})"


 class TVar:
     """
     Tensor variable with no tensor constructor
     """
+
     def __init__(self, tvar):
         """
         :param tvar: tensor variable
@@ -492,7 +574,7 @@ class TVar:
         self.tvar = tvar

     def __repr__(self):
-        return f'TV({self.tvar})'
+        return f"TV({self.tvar})"

     def __eq__(self, other):
         if isinstance(other, TVar):
@@ -505,6 +587,7 @@ class DVar:
     """
     Dimension variable
     """
+
     def __init__(self, c):
         """
         :param c: character or number
@@ -512,7 +595,7 @@ class DVar:
         self.c = c

     def __repr__(self):
-        return f'DV({self.c})'
+        return f"DV({self.c})"

     def __eq__(self, other):
         if isinstance(other, DVar):
@@ -525,6 +608,7 @@ class BVar:
     """
     Boolean variable
     """
+
     def __init__(self, c):
         """
         :param c: character or number
@@ -532,7 +616,7 @@ class BVar:
         self.c = c

     def __repr__(self):
-        return f'BV({self.c})'
+        return f"BV({self.c})"

     def __eq__(self, other):
         if isinstance(other, BVar):
@@ -554,5 +638,6 @@ def is_bool_expr(constraint):
     else:
         return isinstance(constraint, (BVar, Conj, Disj))

+
 def is_dim(d):
     return isinstance(d, (DVar, int)) or d == Dyn
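For orientation, a small sketch of how the constraint classes above compose (the variables and shapes are illustrative):

```python
# Sketch: build "t1 is Dyn, or t1 and t2 are both rank-2 [2, 4] tensors"
# and rely on the __repr__ methods above for readable output.
from torch.fx.experimental.migrate_gradual_types.constraint import (
    BinConstraintT,
    Conj,
    Disj,
    TVar,
)
from torch.fx.experimental.migrate_gradual_types.operation import op_eq
from torch.fx.tensor_type import Dyn, TensorType

t1, t2 = TVar(1), TVar(2)
c = Disj(
    [
        BinConstraintT(t1, Dyn, op_eq),
        Conj(
            [
                BinConstraintT(t1, TensorType([2, 4]), op_eq),
                BinConstraintT(t2, TensorType([2, 4]), op_eq),
            ]
        ),
    ]
)
print(c)  # Or([(TV(1) = Dyn), And([...])])
```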
diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
@@ -1,34 +1,71 @@
 # mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
-import torch
 import operator
 import warnings
 from typing import Callable, Dict, Iterable

+import torch
 from torch.fx._symbolic_trace import _assert_is_none
-from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, CalcProduct, \
-    Disj, TGreatestUpperBound, CalcMaxPool, CalcConv, Conj, BinConstraintT, CanReshape, BinConstraintD, GetItem, T, F, \
-    TVar, DVar, GetItemTensor, IndexSelect, Transpose, DGreatestUpperBound
-from torch.fx.experimental.migrate_gradual_types.operation import \
-    op_eq, op_matching, op_consistency, op_leq, op_precision, op_gt, op_div, op_sub, op_neq, op_lt, op_add, op_mul
-from torch.fx.node import Target, Node
-from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar, gen_tvar, \
-    gen_bvar
-
+from torch.fx.experimental.migrate_gradual_types.constraint import (
+    ApplyBroadcasting,
+    BinConstraintD,
+    BinConstraintT,
+    CalcConv,
+    CalcMaxPool,
+    CalcProduct,
+    CanReshape,
+    Conj,
+    DGreatestUpperBound,
+    Disj,
+    DVar,
+    F,
+    GetItem,
+    GetItemTensor,
+    IndexSelect,
+    T,
+    TGreatestUpperBound,
+    Transpose,
+    TVar,
+)
+from torch.fx.experimental.migrate_gradual_types.operation import (
+    op_add,
+    op_consistency,
+    op_div,
+    op_eq,
+    op_gt,
+    op_leq,
+    op_lt,
+    op_matching,
+    op_mul,
+    op_neq,
+    op_precision,
+    op_sub,
+)
+from torch.fx.experimental.migrate_gradual_types.util import (
+    gen_bvar,
+    gen_dvar,
+    gen_nat_constraints,
+    gen_tensor_dims,
+    gen_tvar,
+)
+from torch.fx.node import Node, Target
 from torch.fx.tensor_type import Dyn, TensorType
-from torch.nn.modules.conv import Conv2d
 from torch.nn.modules.batchnorm import BatchNorm2d
+from torch.nn.modules.conv import Conv2d


 _INFERENCE_RULES: Dict[Target, Callable] = {}

 MAX_TENSOR_RANK = 4


 def register_inference_rule(call_target):
     def register(fn):
         if call_target in _INFERENCE_RULES:
-            raise RuntimeError(f'Inference rule already registered for {call_target}!')
+            raise RuntimeError(f"Inference rule already registered for {call_target}!")
         _INFERENCE_RULES[call_target] = fn
         return fn
+
     return register
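The decorator above is the extension point for the generator below: a rule receives a node, the symbol table, the accumulated constraints, and a fresh-variable counter, and returns new constraints plus the updated counter. A hedged sketch for a hypothetical shape-preserving target (assuming `torch.sigmoid` has no rule registered already; the other names are this module's own):

```python
# Sketch: register a rule for a shape-preserving op. Since sigmoid does
# not change shapes, the output type is simply equated with the input type.
@register_inference_rule(torch.sigmoid)
def sigmoid_inference_rule(n: Node, symbols, constraints, counter):
    assert isinstance(n.args[0], Node)
    output, counter = gen_tvar(counter)  # fresh tensor variable for the result
    symbols[n] = output
    return [BinConstraintT(symbols[n.args[0]], output, op_eq)], counter
```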
@@ -55,10 +92,11 @@ def get_attr_inference_rule(n: Node, symbols, constraints, counter):
     input = symbols[n.args[0]]
     attr = n.args[1]

-    if attr == 'device':
+    if attr == "device":
         return [BinConstraintT(input, output, op_eq)], counter
     else:
-        raise NotImplementedError('Not yet implemented')
+        raise NotImplementedError("Not yet implemented")


 @register_inference_rule(torch.bmm)
 def bmm_inference_rule(n: Node, symbols, constraints, counter):
@@ -79,26 +117,53 @@ def bmm_inference_rule(n: Node, symbols, constraints, counter):
     dims_input1, counter = gen_tensor_dims(3, counter)
     dims_input2, counter = gen_tensor_dims(3, counter)

-    inputs_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq),
-                       BinConstraintT(bmm_input2, Dyn, op_eq),
-                       BinConstraintT(bmm_output, Dyn, op_eq)])
+    inputs_dyn = Conj(
+        [
+            BinConstraintT(bmm_input1, Dyn, op_eq),
+            BinConstraintT(bmm_input2, Dyn, op_eq),
+            BinConstraintT(bmm_output, Dyn, op_eq),
+        ]
+    )

-    input1_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq),
-                       BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
-                       BinConstraintT(bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq)])
+    input1_dyn = Conj(
+        [
+            BinConstraintT(bmm_input1, Dyn, op_eq),
+            BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
+            BinConstraintT(
+                bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq
+            ),
+        ]
+    )

-    input2_dyn = Conj([BinConstraintT(bmm_input2, Dyn, op_eq),
-                       BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
-                       BinConstraintT(bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq)])
+    input2_dyn = Conj(
+        [
+            BinConstraintT(bmm_input2, Dyn, op_eq),
+            BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
+            BinConstraintT(
+                bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq
+            ),
+        ]
+    )

-    consistency_constraints = [BinConstraintD(dims_input1[0], dims_input2[0], op_consistency)]
+    consistency_constraints = [
+        BinConstraintD(dims_input1[0], dims_input2[0], op_consistency)
+    ]

     batch_size, counter = gen_dvar(counter)

-    inputs_are_tensors = Conj([BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
-                               BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
-                               BinConstraintT(bmm_output, TensorType([batch_size, dims_input1[1], dims_input2[2]]), op_eq),
-                               *consistency_constraints, DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0])])
+    inputs_are_tensors = Conj(
+        [
+            BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
+            BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
+            BinConstraintT(
+                bmm_output,
+                TensorType([batch_size, dims_input1[1], dims_input2[2]]),
+                op_eq,
+            ),
+            *consistency_constraints,
+            DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0]),
+        ]
+    )

     return [Disj([inputs_dyn, input1_dyn, input2_dyn, inputs_are_tensors])], counter

@@ -115,8 +180,6 @@ def index_select_inference_rule(n: Node, symbols, constraints, counter):
     assert isinstance(n.args[1], int)
     assert isinstance(n.args[2], Node)

-
-
     index_select, counter = gen_tvar(counter)
     symbols[n] = index_select

@@ -126,10 +189,30 @@ def index_select_inference_rule(n: Node, symbols, constraints, counter):
     is_size_1 = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq)
     is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq)

-    c2 = Conj([is_size_1, Disj([IndexSelect(i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select)
-                                for i in range(MAX_TENSOR_RANK)])])
-    c3 = Conj([is_dyn, Disj([IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select)
-                             for i in range(MAX_TENSOR_RANK)])])
+    c2 = Conj(
+        [
+            is_size_1,
+            Disj(
+                [
+                    IndexSelect(
+                        i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select
+                    )
+                    for i in range(MAX_TENSOR_RANK)
+                ]
+            ),
+        ]
+    )
+    c3 = Conj(
+        [
+            is_dyn,
+            Disj(
+                [
+                    IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select)
+                    for i in range(MAX_TENSOR_RANK)
+                ]
+            ),
+        ]
+    )

     return [Disj([c2, c3])], counter

@@ -158,14 +241,27 @@ def expand_inference_rule(n: Node, symbols, constraints, counter):
     assert isinstance(symbols[arg], DVar)
     e2_nat_constraints.append(BinConstraintD(0, symbols[arg], op_leq))

-    e2_constraint = BinConstraintT(e2, TensorType([arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]]), op_eq)
+    e2_constraint = BinConstraintT(
+        e2,
+        TensorType(
+            [arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]]
+        ),
+        op_eq,
+    )

-    constraints, counter = gen_broadcasting_constraints(e1, e2, symbols, counter, expand)
+    constraints, counter = gen_broadcasting_constraints(
+        e1, e2, symbols, counter, expand
+    )

-    # constrain the output size
+    # constrain the output size
     dims, counter = gen_tensor_dims(len(n.args[1:]), counter)
     nat_constraints = gen_nat_constraints(dims)
-    c = [BinConstraintT(expand, TensorType(dims), op_eq), *nat_constraints, e2_constraint, *e2_nat_constraints]
+    c = [
+        BinConstraintT(expand, TensorType(dims), op_eq),
+        *nat_constraints,
+        e2_constraint,
+        *e2_nat_constraints,
+    ]
     constraints += c

     return constraints, counter
@@ -206,7 +302,7 @@ def equality_inference_rule(n: Node, symbols, constraints, counter):
         my_size = [symbols[arg] for arg in n.args[0]]
         return [BinConstraintT(output, TensorType(my_size), op_eq)], counter
     else:
-        raise NotImplementedError('Method not yet implemented')
+        raise NotImplementedError("Method not yet implemented")


 @register_inference_rule("transpose")
@@ -225,10 +321,17 @@ def transpose_inference_rule(n: Node, symbols, constraints, counter):
     assert isinstance(from_arg, TVar)

     # input and output are dyn
-    is_dyn = Conj([BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)])
+    is_dyn = Conj(
+        [BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)]
+    )

     # or input is a tensor and we actually do the replacement
-    c3 = Disj([Transpose(i + 1, from_arg, n.args[1], n.args[2], output) for i in range(MAX_TENSOR_RANK)])
+    c3 = Disj(
+        [
+            Transpose(i + 1, from_arg, n.args[1], n.args[2], output)
+            for i in range(MAX_TENSOR_RANK)
+        ]
+    )

     return [Disj([is_dyn, c3])], counter

@@ -250,8 +353,11 @@ def type_inference_rule(n: Node, symbols, constraints, counter):
     assert isinstance(from_arg, TVar)
     assert isinstance(to_arg, TVar)

-    return [BinConstraintT(from_arg, to_arg, op_consistency),
-            BinConstraintT(output, to_arg, op_eq)], counter
+    return [
+        BinConstraintT(from_arg, to_arg, op_consistency),
+        BinConstraintT(output, to_arg, op_eq),
+    ], counter


 @register_inference_rule("masked_fill_")
 def masked_fill_inference_rule(n: Node, symbols, constraints, counter):
@@ -273,9 +379,11 @@ def masked_fill_inference_rule(n: Node, symbols, constraints, counter):
     if isinstance(e1, TVar) and isinstance(e2, TVar):
         masked_fill_tensor, counter = gen_tvar(counter)
         symbols[n] = masked_fill_tensor
-        return gen_broadcasting_constraints(e1, e2, symbols, counter, masked_fill_tensor)
+        return gen_broadcasting_constraints(
+            e1, e2, symbols, counter, masked_fill_tensor
+        )
     else:
-        raise NotImplementedError('Not yet implemented')
+        raise NotImplementedError("Not yet implemented")


 @register_inference_rule(torch.nn.functional.embedding)
@@ -286,7 +394,9 @@ def embedding_inference_rule_functional(n: Node, symbols, constraints, counter):

     # will treat this as a static shape. So we will not use matching.
     weight_dims, counter = gen_tensor_dims(2, counter)
-    equality_constraint = BinConstraintT(embedding_dim_weights, TensorType(weight_dims), op_eq)
+    equality_constraint = BinConstraintT(
+        embedding_dim_weights, TensorType(weight_dims), op_eq
+    )
     embedding_dim = weight_dims[1]
     constraints, counter = gen_embedding_rules(n, symbols, embedding_dim, counter)
     return [equality_constraint] + constraints, counter
@@ -302,7 +412,6 @@ def embedding_inference_rule(n: Node, module_instance, symbols, constraints, cou


 def gen_embedding_rules(n: Node, symbols, embedding_dim, counter):
-
     embedding_output, counter = gen_tvar(counter)
     symbols[n] = embedding_output
     embedding_input = symbols[n.args[0]]
@@ -318,9 +427,15 @@ def gen_embedding_rules(n: Node, symbols, embedding_dim, counter):
     nat_constraints = gen_nat_constraints(new_dims)

     # we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases
-    c_tensor_i = Conj([BinConstraintT(embedding_input, TensorType(new_dims), op_eq),
-                       BinConstraintT(embedding_output, TensorType(new_dims + [embedding_dim]), op_eq)] +
-                      nat_constraints)
+    c_tensor_i = Conj(
+        [
+            BinConstraintT(embedding_input, TensorType(new_dims), op_eq),
+            BinConstraintT(
+                embedding_output, TensorType(new_dims + [embedding_dim]), op_eq
+            ),
+        ]
+        + nat_constraints
+    )
     c2.append(c_tensor_i)

     return [Disj([c1, Disj(c2)])], counter
@@ -348,9 +463,10 @@ def view_inference_rule(n: Node, symbols, constraints, counter):
     my_view, counter = gen_tvar(counter)
     symbols[n] = my_view

-
     src_var = symbols[n.args[0]]
-    t2 = [symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:]]  # target shape
+    t2 = [
+        symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:]
+    ]  # target shape
     t2_type = []
     num_constraints = []

@@ -382,7 +498,6 @@ def size_inference_rule(n: Node, symbols, constraints, counter):
     Ex: size = input_ids.size()
     """

-
     if len(n.args) == 1:
         # generate the new variable
         size, counter = gen_tvar(counter)
@@ -398,7 +513,10 @@ def size_inference_rule(n: Node, symbols, constraints, counter):
             size_index, counter = gen_dvar(counter)
             symbols[n] = size_index
             input = symbols[n.args[0]]
-            c2 = [GetItem(i + 1, n.args[1], size_index, input) for i in range(MAX_TENSOR_RANK)]
+            c2 = [
+                GetItem(i + 1, n.args[1], size_index, input)
+                for i in range(MAX_TENSOR_RANK)
+            ]
             c3 = BinConstraintD(0, size_index, op_leq)

             input_dyn = BinConstraintT(input, Dyn, op_eq)
@@ -452,9 +570,14 @@ def cumsum_inference_rule(n: Node, symbols, constraints, counter):

         nat_constraints = gen_nat_constraints(new_dims)

-        c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims), op_eq),
-                           BinConstraintT(output, TensorType(new_dims), op_eq)] +
-                          [range_check(arg_1, i)] + nat_constraints)
+        c_tensor_i = Conj(
+            [
+                BinConstraintT(input, TensorType(new_dims), op_eq),
+                BinConstraintT(output, TensorType(new_dims), op_eq),
+            ]
+            + [range_check(arg_1, i)]
+            + nat_constraints
+        )

         c2.append(c_tensor_i)
     dyn_or_tensor = Disj([c1, Disj(c2)])
@@ -481,7 +604,6 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter):
         get_item_arg = symbols[n.args[0]]
         assert isinstance(get_item_arg, TVar)

-
         # if the input is dynamic, we accept any index and return
         # a dynamic dimension as output
         input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq)
@@ -492,8 +614,10 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter):
         # generate a getItem constraint which will be expanded based on the
         # tensor dimension.
-
-        c2 = [GetItem(i + 1, n.args[1], get_item_output, get_item_arg) for i in range(MAX_TENSOR_RANK)]
-
+        c2 = [
+            GetItem(i + 1, n.args[1], get_item_output, get_item_arg)
+            for i in range(MAX_TENSOR_RANK)
+        ]

         # since the output is a dimension, we make sure it's a natural number
         # added as a conjunction to the disjunction of c2
@@ -515,8 +639,10 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter):
             output_dyn = BinConstraintT(get_item_output, Dyn, op_eq)  # type: ignore[assignment]
             c1 = Conj([input_dyn, output_dyn])

-            c2 = [GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg)  # type: ignore[misc]
-                  for i in range(MAX_TENSOR_RANK)]
+            c2 = [
+                GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg)  # type: ignore[misc]
+                for i in range(MAX_TENSOR_RANK)
+            ]
         else:
             # TODO: we should figure out why there is a key-error here.
             return [], counter
@@ -524,7 +650,7 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter):
         return [Disj([c1, *c2])], counter

     else:
-        raise RuntimeError('Method not yet implemented')
+        raise RuntimeError("Method not yet implemented")


 @register_inference_rule(operator.gt)
@@ -553,7 +679,7 @@ def gt_inference_rule(n: Node, symbols, constraints, counter):
             return [equality_constraint], counter

         else:
-            raise RuntimeError('Sort Mismatch')
+            raise RuntimeError("Sort Mismatch")

     elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
         if isinstance(e1, DVar):
@@ -567,7 +693,9 @@ def gt_inference_rule(n: Node, symbols, constraints, counter):
         elif isinstance(e1, TVar) and isinstance(e2, int):
             # then we made the wrong assumption about the argument being a tensor
             # so we should fix the assumption
-            warnings.warn(f'Made the wrong assumption for node {n}. Correctness not guaranteed.')
+            warnings.warn(
+                f"Made the wrong assumption for node {n}. Correctness not guaranteed."
+            )

             new_e1, counter = gen_dvar(counter)
             symbols[n.args[0]] = new_e1
@@ -580,10 +708,10 @@ def gt_inference_rule(n: Node, symbols, constraints, counter):
             return [equality_constraint], counter

         else:
-            raise NotImplementedError('Method not yet implemented')
+            raise NotImplementedError("Method not yet implemented")

     else:
-        raise NotImplementedError('Method not yet implemented')
+        raise NotImplementedError("Method not yet implemented")


 @register_inference_rule(operator.eq)
@@ -609,7 +737,7 @@ def eq_inference_rule(n: Node, symbols, constraints, counter):
             return [equality_constraint], counter

         else:
-            raise RuntimeError('Sort Mismatch')
+            raise RuntimeError("Sort Mismatch")

     elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
         if isinstance(e1, DVar):
@@ -620,9 +748,10 @@ def eq_inference_rule(n: Node, symbols, constraints, counter):
             equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq)
             return [equality_constraint], counter
         else:
-            raise NotImplementedError('Method not yet implemented')
+            raise NotImplementedError("Method not yet implemented")
     else:
-        raise NotImplementedError('Method not yet implemented')
+        raise NotImplementedError("Method not yet implemented")


 @register_inference_rule(operator.ne)
 def neq_inference_rule(n: Node, symbols, constraints, counter):
@@ -641,7 +770,6 @@ def neq_inference_rule(n: Node, symbols, constraints, counter):

     # implementing for size 3 and 4
     if len(n.args[1]) == 3:
-
         assert isinstance(n.args[1][0], (Node, int))
         assert isinstance(n.args[1][1], (Node, int))
         assert isinstance(n.args[1][2], (Node, int))
@@ -662,11 +790,19 @@ def neq_inference_rule(n: Node, symbols, constraints, counter):
         neq_3 = BinConstraintD(d3, b[2], op_neq)

         # dimensions inconsistent
-        dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1])
-        dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2])
-        dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3])
+        dims_inconsistent1 = Conj(
+            [BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1]
+        )
+        dims_inconsistent2 = Conj(
+            [BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2]
+        )
+        dims_inconsistent3 = Conj(
+            [BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3]
+        )

-        dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3])
+        dims_inconsistent = Disj(
+            [dims_inconsistent1, dims_inconsistent2, dims_inconsistent3]
+        )

         # we are covering size 3 and 4 only for now
         ne_constraint = Conj([input_is_size3, dims_inconsistent])
@@ -675,7 +811,6 @@ def neq_inference_rule(n: Node, symbols, constraints, counter):
         equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq)

     elif len(n.args[1]) == 4:
-
         assert isinstance(n.args[1][0], (Node, int))
         assert isinstance(n.args[1][1], (Node, int))
         assert isinstance(n.args[1][2], (Node, int))
@@ -703,12 +838,27 @@ def neq_inference_rule(n: Node, symbols, constraints, counter):
         neq_4 = BinConstraintD(d4, b4, op_neq)

         # dimensions are inconsistent
-        dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1])
-        dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2])
-        dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3])
-        dims_inconsistent4 = Conj([BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_4])
+        dims_inconsistent1 = Conj(
+            [BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1]
+        )
+        dims_inconsistent2 = Conj(
+            [BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2]
+        )
+        dims_inconsistent3 = Conj(
+            [BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3]
+        )
+        dims_inconsistent4 = Conj(
+            [BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_4]
+        )

-        dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3, dims_inconsistent4])
+        dims_inconsistent = Disj(
+            [
+                dims_inconsistent1,
+                dims_inconsistent2,
+                dims_inconsistent3,
+                dims_inconsistent4,
+            ]
+        )

         ne_constraint = Conj([input_is_size4, dims_inconsistent])

@@ -717,7 +867,7 @@ def neq_inference_rule(n: Node, symbols, constraints, counter):
         equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq)

     else:
-        raise NotImplementedError('Method not yet implemented')
+        raise NotImplementedError("Method not yet implemented")

     return [equality_constraint], counter

@@ -748,7 +898,7 @@ def lt_inference_rule(n: Node, symbols, constraints, counter):
             return [equality_constraint], counter

         else:
-            raise RuntimeError('Sort Mismatch')
+            raise RuntimeError("Sort Mismatch")

     elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
         if isinstance(e1, DVar):
@@ -759,10 +909,10 @@ def lt_inference_rule(n: Node, symbols, constraints, counter):
             equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq)
             return [equality_constraint], counter
         else:
-            raise NotImplementedError('Method not yet implemented')
+            raise NotImplementedError("Method not yet implemented")

     else:
-        raise NotImplementedError('Method not yet implemented')
+        raise NotImplementedError("Method not yet implemented")


 @register_inference_rule(torch.full)
@@ -788,28 +938,42 @@ def arange_inference_rule(n: Node, symbols, constraints, counter):
     if len(n.args) == 1:
         end = symbols[n.args[0]]
     else:
-        raise NotImplementedError('Not yet implemented')
+        raise NotImplementedError("Not yet implemented")

     # int((end - start) / step)
     d1, counter = gen_dvar(counter)
-    size_constraint = BinConstraintD(d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq)
+    size_constraint = BinConstraintD(
+        d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq
+    )
     arange, counter = gen_tvar(counter)
     symbols[n] = arange

     # either a parameter is a number or it is Dyn
-    c1 = Disj([BinConstraintD(end, Dyn, op_eq),
-               BinConstraintD(start, Dyn, op_eq),
-               BinConstraintD(step, Dyn, op_eq)])
+    c1 = Disj(
+        [
+            BinConstraintD(end, Dyn, op_eq),
+            BinConstraintD(start, Dyn, op_eq),
+            BinConstraintD(step, Dyn, op_eq),
+        ]
+    )
     c2 = BinConstraintD(d1, Dyn, op_eq)
     both_dyn = Conj([c1, c2])

-    c11 = Conj([BinConstraintD(end, Dyn, op_neq),
-                BinConstraintD(start, Dyn, op_neq),
-                BinConstraintD(step, Dyn, op_neq)])
+    c11 = Conj(
+        [
+            BinConstraintD(end, Dyn, op_neq),
+            BinConstraintD(start, Dyn, op_neq),
+            BinConstraintD(step, Dyn, op_neq),
+        ]
+    )
     c22 = BinConstraintD(d1, Dyn, op_neq)
     both_numbers = Conj([c11, c22, size_constraint])

-    return [BinConstraintT(arange, TensorType([d1]), op_eq), Disj([both_dyn, both_numbers])], counter
+    return [
+        BinConstraintT(arange, TensorType([d1]), op_eq),
+        Disj([both_dyn, both_numbers]),
+    ], counter


 def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var):
     # additional vars that don't correspond to expressions
@@ -829,7 +993,6 @@ def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var):
 @register_inference_rule(torch.add)
 @register_inference_rule(operator.add)
 def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
-
     op_code = None
     if n.target == operator.add or n.target == torch.add:
         op_code = op_add
@@ -837,7 +1000,9 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
         op_code = op_mul

     if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
-        if isinstance(symbols[n.args[0]], TVar) and isinstance(symbols[n.args[1]], TVar):
+        if isinstance(symbols[n.args[0]], TVar) and isinstance(
+            symbols[n.args[1]], TVar
+        ):
             my_output, counter = gen_tvar(counter)
             symbols[n] = my_output
             e1 = symbols[n.args[0]]
@@ -845,7 +1010,7 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter):

             return gen_broadcasting_constraints(e1, e2, symbols, counter, my_output)
         else:
-            raise NotImplementedError('Method not yet implemented')
+            raise NotImplementedError("Method not yet implemented")

     elif isinstance(n.args[0], Node) and isinstance(n.args[1], (int, float)):
         if isinstance(symbols[n.args[0]], TVar):
@@ -859,8 +1024,14 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
             e1 = symbols[n.args[0]]

             # we will propagate the runtime value here since this is regular addition
-            c = Conj([BinConstraintD(my_output, BinConstraintD(e1, n.args[1], op_code), op_eq),
-                      BinConstraintD(0, my_output, op_leq)])
+            c = Conj(
+                [
+                    BinConstraintD(
+                        my_output, BinConstraintD(e1, n.args[1], op_code), op_eq
+                    ),
+                    BinConstraintD(0, my_output, op_leq),
+                ]
+            )
             return [c], counter

     elif isinstance(n.args[1], Node) and isinstance(n.args[0], (int, float)):
@@ -875,16 +1046,22 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
             e2 = symbols[n.args[1]]

             # we will propagate the runtime value here since this is regular addition
-            c = Conj([BinConstraintD(my_output, BinConstraintD(e2, n.args[0], op_code), op_eq),
-                      BinConstraintD(0, my_output, op_leq)])
+            c = Conj(
+                [
+                    BinConstraintD(
+                        my_output, BinConstraintD(e2, n.args[0], op_code), op_eq
+                    ),
+                    BinConstraintD(0, my_output, op_leq),
+                ]
+            )
             return [c], counter

         else:
-            raise NotImplementedError('Method not yet implemented')
+            raise NotImplementedError("Method not yet implemented")

     else:
         # TODO generate add constraints for scalar addition
-        raise NotImplementedError('Addition not yet implemented')
+        raise NotImplementedError("Addition not yet implemented")


 @register_inference_rule(torch.flatten)
@@ -915,7 +1092,9 @@ def flatten_inference_rule(n: Node, symbols, constraints, counter):

     const = []
     for i in range(1, MAX_TENSOR_RANK + 1):
-        c, counter = generate_flatten_constraints(start_dim, end_dim, input, flattened, i, counter)
+        c, counter = generate_flatten_constraints(
+            start_dim, end_dim, input, flattened, i, counter
+        )
         const.append(c)

     return [Disj([both_dyn, *const])], counter
@@ -937,7 +1116,9 @@ def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, co
     Input should be consistent with the normalized_shape
     """
     assert isinstance(n.args[0], Node)
-    return gen_layer_norm_constraints(n, module_instance.normalized_shape, symbols, counter)
+    return gen_layer_norm_constraints(
+        n, module_instance.normalized_shape, symbols, counter
+    )


 def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter):
@@ -955,13 +1136,18 @@ def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter):
         new_dims_rhs, counter = gen_tensor_dims(i, counter)
         nat_constraints = gen_nat_constraints(new_dims_rhs)

-        c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs), op_eq),
-                           BinConstraintT(output, TensorType(new_dims_rhs), op_eq)] +
-                          add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) +
-                          nat_constraints)
+        c_tensor_i = Conj(
+            [
+                BinConstraintT(input, TensorType(new_dims_rhs), op_eq),
+                BinConstraintT(output, TensorType(new_dims_rhs), op_eq),
+            ]
+            + add_layer_norm_constraints(new_dims_rhs, list(normalized_shape))
+            + nat_constraints
+        )
         c2.append(c_tensor_i)
     return [Disj([c1, Disj(c2)])], counter


 @register_inference_rule(torch.nn.Dropout)
 @register_inference_rule(torch.nn.ReLU)
 def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter):
@@ -983,7 +1169,9 @@ def linear_inference_rule(n: Node, module_instance, symbols, constraints, counte
     If the input is Dyn, then so should the output
     """
     assert isinstance(n.args[0], Node)
-    return linear_constraints(n, module_instance.in_features, module_instance.out_features, symbols, counter)
+    return linear_constraints(
+        n, module_instance.in_features, module_instance.out_features, symbols, counter
+    )


 @register_inference_rule("dim")  # type: ignore[attr-defined]
@@ -1001,8 +1189,12 @@ def torch_dim_inference_rule(n: Node, symbols, constraints, counter):
     for i in range(1, MAX_TENSOR_RANK + 1):
         new_dims_rhs_1, counter = gen_tensor_dims(i, counter)

-        c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq),
-                           BinConstraintD(my_dim, i, op_eq)])
+        c_tensor_i = Conj(
+            [
+                BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq),
+                BinConstraintD(my_dim, i, op_eq),
+            ]
+        )
         c1.append(c_tensor_i)

     return [Disj([Conj([input_dyn, output_dyn]), Disj(c1)])], counter
@@ -1012,8 +1204,12 @@ def torch_dim_inference_rule(n: Node, symbols, constraints, counter):
 def torch_linear_inference_rule(n: Node, symbols, constraints, counter):
     assert isinstance(n.args[0], Node)
     weight_dims, counter = gen_tensor_dims(2, counter)
-    equality_constraint = BinConstraintT(symbols[n.args[1]], TensorType(weight_dims), op_eq)
-    constraints, counter = linear_constraints(n, weight_dims[1], weight_dims[0], symbols, counter)
+    equality_constraint = BinConstraintT(
+        symbols[n.args[1]], TensorType(weight_dims), op_eq
+    )
+    constraints, counter = linear_constraints(
+        n, weight_dims[1], weight_dims[0], symbols, counter
+    )
     return [equality_constraint] + constraints, counter

@@ -1034,13 +1230,20 @@ def linear_constraints(n: Node, in_features, out_features, symbols, counter):

         nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2)

-        c_tensor_i = Conj([BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq),
-                           BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq)] +
-                          add_linear_constraints(new_dims_rhs_1, new_dims_rhs_2, in_features, out_features) +
-                          nat_constraints)
+        c_tensor_i = Conj(
+            [
+                BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq),
+                BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq),
+            ]
+            + add_linear_constraints(
+                new_dims_rhs_1, new_dims_rhs_2, in_features, out_features
+            )
+            + nat_constraints
+        )
         c2.append(c_tensor_i)
     return [Disj([c1, Disj(c2)])], counter


 def add_layer_norm_constraints(input_dim, normalized_dim):
     """
     The constraints say that the type has the form: [*, 1024, 1024]
@@ -1130,7 +1333,13 @@ def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, coun
     d4, counter = gen_dvar(counter)
     nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
     c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
-    c2 = BinConstraintT(avg_pool, TensorType([d1, d2, module_instance.output_size[0], module_instance.output_size[1]]), op_eq)
+    c2 = BinConstraintT(
+        avg_pool,
+        TensorType(
+            [d1, d2, module_instance.output_size[0], module_instance.output_size[1]]
+        ),
+        op_eq,
+    )

     return [c1, c2, *nat_constraints], counter

@@ -1152,12 +1361,16 @@ def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counte
     # c2 = DConsistency(module_instance.in_channels, d2)
     c2 = BinConstraintD(module_instance.in_channels, d2, op_consistency)

-    c3 = CalcConv(my_conv, input_var,
-                  module_instance.out_channels,
-                  module_instance.kernel_size,
-                  module_instance.padding,
-                  module_instance.stride,
-                  module_instance.dilation, [d1, d2, d3, d4])
+    c3 = CalcConv(
+        my_conv,
+        input_var,
+        module_instance.out_channels,
+        module_instance.kernel_size,
+        module_instance.padding,
+        module_instance.stride,
+        module_instance.dilation,
+        [d1, d2, d3, d4],
+    )

     nat_constraints = gen_nat_constraints([d1, d2, d3, d4])

@@ -1176,8 +1389,15 @@ def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, count

     c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)

-    c2 = CalcMaxPool(maxpool, input_var, module_instance.kernel_size, module_instance.padding,
-                     module_instance.stride, module_instance.dilation, [d1, d2, d3, d4])
+    c2 = CalcMaxPool(
+        maxpool,
+        input_var,
+        module_instance.kernel_size,
+        module_instance.padding,
+        module_instance.stride,
+        module_instance.dilation,
+        [d1, d2, d3, d4],
+    )

     nat_constraints = gen_nat_constraints([d1, d2, d3, d4])

@@ -1190,8 +1410,7 @@ class ConstraintGenerator:
         self.traced_params = dict(self.traced.named_parameters())
         self.constraints = []
         self.symbol_dict = {}
-        self.graph = traced.graph if hasattr(traced, 'graph') else graph
-
+        self.graph = traced.graph if hasattr(traced, "graph") else graph

     def generate_constraints(self, counter=0):
         """
@@ -1217,7 +1436,7 @@ class ConstraintGenerator:
         - conv2d
         """

-        if n.op == 'placeholder':
+        if n.op == "placeholder":
             x, counter = gen_tvar(counter)
             self.symbol_dict[n] = x

@@ -1226,8 +1445,8 @@ class ConstraintGenerator:
             if n.type != Dyn and (not isinstance(n.type, TensorType)):
                 if n.type == torch.nn.parameter.Parameter:
                     # since we have a parameter, the shape must be static
-                    assert 'example_value' in n.meta
-                    my_type = TensorType(n.meta['example_value'].size())
+                    assert "example_value" in n.meta
+                    my_type = TensorType(n.meta["example_value"].size())
                 else:
                     my_type = Dyn

@@ -1235,30 +1454,38 @@ class ConstraintGenerator:
             c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq)
             return [c1, c2], counter

-        elif n.op == 'call_function':
+        elif n.op == "call_function":
             if n.target in _INFERENCE_RULES:
-                return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter)
+                return _INFERENCE_RULES[n.target](
+                    n, self.symbol_dict, self.constraints, counter
+                )
             else:
-                raise RuntimeError(f'No inference rule registered for target {n.target}!')
-
-        elif n.op == 'call_module':
+                raise RuntimeError(
+                    f"No inference rule registered for target {n.target}!"
+                )

+        elif n.op == "call_module":
             module_instance = self.traced.get_submodule(n.target)
             if type(module_instance) in _INFERENCE_RULES:
-                return _INFERENCE_RULES[type(module_instance)](n,
-                                                               module_instance,
-                                                               self.symbol_dict,
-                                                               self.constraints, counter)
+                return _INFERENCE_RULES[type(module_instance)](
+                    n, module_instance, self.symbol_dict, self.constraints, counter
+                )
             else:
-                raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!')
+                raise RuntimeError(
+                    f"No inference rule registered for class {type(module_instance)}!"
+                )

-        elif n.op == 'call_method':
+        elif n.op == "call_method":
             if n.target in _INFERENCE_RULES:
-                return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter)
+                return _INFERENCE_RULES[n.target](
+                    n, self.symbol_dict, self.constraints, counter
+                )
             else:
-                raise RuntimeError(f'No inference rule registered for target {n.target}!')
+                raise RuntimeError(
+                    f"No inference rule registered for target {n.target}!"
+                )

-        elif n.op == 'get_attr':
+        elif n.op == "get_attr":
             t = self.traced_params.get(n.target, None)

             if isinstance(t, torch.Tensor):
@@ -1274,7 +1501,7 @@ class ConstraintGenerator:
             else:
                 return [], counter

-        elif n.op == 'output':
+        elif n.op == "output":
             return [], counter

         else:
File diff suppressed because it is too large
@ -1,14 +1,14 @@
op_add = '+'
op_sub = '-'
op_mul = '*'
op_div = '/'
op_eq = '='
op_neq = '!='
op_imp = '=>'
op_matching = '\u22b3' # (contains)
op_consistency = '~'
op_precision = '\u2291' # (square image of or equal to)
op_leq = '\u2264' # less-than or equal to
op_lt = '<'
op_gt = '>'
op_mod = '%'
op_add = "+"
op_sub = "-"
op_mul = "*"
op_div = "/"
op_eq = "="
op_neq = "!="
op_imp = "=>"
op_matching = "\u22b3" # (contains)
op_consistency = "~"
op_precision = "\u2291" # (square image of or equal to)
op_leq = "\u2264" # less-than or equal to
op_lt = "<"
op_gt = ">"
op_mod = "%"

@ -1,16 +1,49 @@
# mypy: allow-untyped-defs
from torch.fx.experimental.migrate_gradual_types.constraint import Conj, Disj, T, F, BinConstraintT, BVar, is_bool_expr
from torch.fx.experimental.migrate_gradual_types.constraint import BinConstraintD, TVar, DVar
from torch.fx.experimental.migrate_gradual_types.constraint import Prod, is_algebraic_expression, is_dim
from torch.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator
from torch.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint
from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_eq, op_neq, op_gt, op_lt
from torch.fx.experimental.migrate_gradual_types.operation import op_leq, op_sub, op_div, op_mul, op_mod
from torch.fx.tensor_type import TensorType, Dyn
from torch.fx.experimental.migrate_gradual_types.constraint import (
BinConstraintD,
BinConstraintT,
BVar,
Conj,
Disj,
DVar,
F,
is_algebraic_expression,
is_bool_expr,
is_dim,
Prod,
T,
TVar,
)
from torch.fx.experimental.migrate_gradual_types.constraint_generator import (
ConstraintGenerator,
)
from torch.fx.experimental.migrate_gradual_types.constraint_transformation import (
transform_constraint,
)
from torch.fx.experimental.migrate_gradual_types.operation import (
op_add,
op_div,
op_eq,
op_gt,
op_leq,
op_lt,
op_mod,
op_mul,
op_neq,
op_sub,
)
from torch.fx.tensor_type import Dyn, TensorType


try:
import z3 # type: ignore[import]
from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, z3_dyn, D

from torch.fx.experimental.migrate_gradual_types.z3_types import (
D,
tensor_type,
z3_dyn,
)

HAS_Z3 = True

def transform_to_z3(constraint, counter, dimension_dict):
@ -41,35 +74,48 @@ try:
return (lhs == rhs), counter

else:
raise NotImplementedError('Method not yet implemented')
raise NotImplementedError("Method not yet implemented")

elif isinstance(constraint, BinConstraintD):
if constraint.op == op_eq:

if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs):
transformed_rhs, counter = transform_to_z3(constraint.rhs, counter, dimension_dict)
transformed_rhs, counter = transform_to_z3(
constraint.rhs, counter, dimension_dict
)
transformed_lhs = z3.Bool(constraint.lhs.c)
return transformed_lhs == transformed_rhs, counter

elif is_dim(constraint.lhs) and is_dim(constraint.rhs):
# with dimension transformations we consider the encoding
lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict)
lhs, counter = transform_dimension(
constraint.lhs, counter, dimension_dict
)
rhs, counter = transform_dimension(
constraint.rhs, counter, dimension_dict
)
return lhs == rhs, counter

else:
# then we have an algebraic expression which means that we disregard the
# first element of the encoding
lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
lhs, counter = transform_algebraic_expression(
constraint.lhs, counter, dimension_dict
)
rhs, counter = transform_algebraic_expression(
constraint.rhs, counter, dimension_dict
)
return lhs == rhs, counter

# The assumption here is that the LHS and RHS must be dimensions
elif constraint.op == op_neq:
assert is_dim(constraint.lhs)
assert is_dim(constraint.rhs)
lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict)
lhs, counter = transform_dimension(
constraint.lhs, counter, dimension_dict
)
rhs, counter = transform_dimension(
constraint.rhs, counter, dimension_dict
)
if constraint.rhs == Dyn or constraint.lhs == Dyn:
if constraint.rhs == Dyn:
return lhs.arg(0) == 1, counter
@ -79,44 +125,83 @@ try:
# if one of the instances is a number
elif isinstance(constraint.lhs, int) or isinstance(constraint.rhs, int):
if isinstance(constraint.lhs, int):
return z3.Or([rhs.arg(0) == 0, z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter
return (
z3.Or(
[
rhs.arg(0) == 0,
z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)]),
]
),
counter,
)

elif isinstance(constraint.rhs, int):
return z3.Or([lhs.arg(0) == 0, z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter
return (
z3.Or(
[
lhs.arg(0) == 0,
z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)]),
]
),
counter,
)

else:
return z3.Or([z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]),
z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]),
z3.And([lhs.arg(0) != 0, rhs.arg(0) != 0, lhs.arg(1) != rhs.arg(1)])]), counter

return (
z3.Or(
[
z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]),
z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]),
z3.And(
[
lhs.arg(0) != 0,
rhs.arg(0) != 0,
lhs.arg(1) != rhs.arg(1),
]
),
]
),
counter,
)

elif constraint.op == op_leq:
# if the dimensions are not dyn, this will come into effect
# there would have been another constraint specifying if a given dimension
# is dyn or not
assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
lhs, counter = transform_algebraic_expression(
constraint.lhs, counter, dimension_dict
)
rhs, counter = transform_algebraic_expression(
constraint.rhs, counter, dimension_dict
)
return lhs <= rhs, counter

elif constraint.op == op_gt:
assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
lhs, counter = transform_algebraic_expression(
constraint.lhs, counter, dimension_dict
)
rhs, counter = transform_algebraic_expression(
constraint.rhs, counter, dimension_dict
)
return lhs > rhs, counter

elif constraint.op == op_lt:
assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
lhs, counter = transform_algebraic_expression(
constraint.lhs, counter, dimension_dict
)
rhs, counter = transform_algebraic_expression(
constraint.rhs, counter, dimension_dict
)
return lhs < rhs, counter

else:
raise NotImplementedError('operation not yet implemented')
raise NotImplementedError("operation not yet implemented")

else:
raise NotImplementedError('Operation not yet implemented')

raise NotImplementedError("Operation not yet implemented")

def transform_var(tensor, counter, dimension_dict):
"""
@ -166,13 +251,15 @@ try:
return D(1, dimension), counter
elif isinstance(dimension, DVar):
if dimension.c in dimension_dict:
return D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)), counter
return (
D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)),
counter,
)
else:
counter += 1
dimension_dict[dimension.c] = counter
return D(z3.Int(counter), z3.Int(dimension.c)), counter


def transform_algebraic_expression(expr, counter, dimension_dict):
"""
Transforms an algebraic expression to z3 format
@ -190,7 +277,6 @@ try:
return transformed.arg(1), counter

elif isinstance(expr, Prod):

dims = []
for dim in expr.products:
assert is_dim(dim)
@ -199,9 +285,12 @@ try:
return z3.Product(dims), counter

elif is_algebraic_expression(expr):

lhs, counter = transform_algebraic_expression(expr.lhs, counter, dimension_dict)
rhs, counter = transform_algebraic_expression(expr.rhs, counter, dimension_dict)
lhs, counter = transform_algebraic_expression(
expr.lhs, counter, dimension_dict
)
rhs, counter = transform_algebraic_expression(
expr.rhs, counter, dimension_dict
)

if expr.op == op_sub:
c = lhs - rhs
@ -219,14 +308,13 @@ try:
c = lhs % rhs

else:
raise NotImplementedError('operation not yet implemented')
raise NotImplementedError("operation not yet implemented")

return c, counter

else:
raise RuntimeError


def transform_all_constraints(traced, counter=0):
"""
Given a trace, generates constraints and transforms them to z3 format
@ -291,7 +379,6 @@ try:
# transform precision, matching, consistency till obtaining a fixed point
new_constraints, counter = iterate_till_fixed_point(new_constraints, counter)


# since the function returns a list of one element, we get the first element
# we are only interested in the RHS in this case because the LHS just stores
# the result
@ -304,19 +391,27 @@ try:
condition_constraint_rhs = condition_constraint.rhs

# transform the condition constraint
condition_constraint_rhs, counter = iterate_till_fixed_point(condition_constraint_rhs, counter)
condition_constraint_rhs, counter = iterate_till_fixed_point(
condition_constraint_rhs, counter
)

transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict)

transformed_condition_constraint, counter = transform_to_z3(condition_constraint_rhs, counter, dimension_dict)
transformed_condition_constraint, counter = transform_to_z3(
condition_constraint_rhs, counter, dimension_dict
)

negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint)
negation_transformed_condition_constraint = z3.Not(
transformed_condition_constraint
)

return z3.And([transformed, transformed_condition_constraint]), \
z3.And([transformed, negation_transformed_condition_constraint])
return z3.And([transformed, transformed_condition_constraint]), z3.And(
[transformed, negation_transformed_condition_constraint]
)


def evaluate_conditional_with_constraints(tracer_root, graph, node, counter=0, user_constraints=None):
def evaluate_conditional_with_constraints(
tracer_root, graph, node, counter=0, user_constraints=None
):
"""
Given an IR and a node representing a conditional, evaluate the conditional
and its negation
@ -329,8 +424,10 @@ try:

"""

transformed_positive, transformed_negative = \
transform_all_constraints_trace_time(tracer_root, graph, node, counter)
(
transformed_positive,
transformed_negative,
) = transform_all_constraints_trace_time(tracer_root, graph, node, counter)

s = z3.Solver()
s.add(transformed_positive)

@ -1,6 +1,10 @@
# mypy: allow-untyped-defs
from torch.fx.experimental.migrate_gradual_types.constraint import TVar, DVar, BinConstraintD, \
BVar
from torch.fx.experimental.migrate_gradual_types.constraint import (
BinConstraintD,
BVar,
DVar,
TVar,
)
from torch.fx.experimental.migrate_gradual_types.operation import op_leq


@ -23,6 +27,7 @@ def gen_dvar(curr):
curr += 1
return DVar(curr), curr


def gen_bvar(curr):
"""
Generate a boolean variable
@ -32,6 +37,7 @@ def gen_bvar(curr):
curr += 1
return BVar(curr), curr


def gen_tensor_dims(n, curr):
"""
Generate a list of tensor dimensions

@ -1,22 +1,23 @@
try:
import z3 # type: ignore[import]

HAS_Z3 = True
# dynamic type
dyn = z3.DeclareSort('Dyn')
dyn_type = z3.Const('dyn', dyn)
dyn = z3.DeclareSort("Dyn")
dyn_type = z3.Const("dyn", dyn)

# dimension
dim = z3.Datatype('dim')
dim.declare('dim', ('0', z3.IntSort()), ('1', z3.IntSort()))
dim = z3.Datatype("dim")
dim.declare("dim", ("0", z3.IntSort()), ("1", z3.IntSort()))
dim = dim.create()

# tensors
tensor_type = z3.Datatype('TensorType')
tensor_type.declare('Dyn', ('dyn', dyn))
tensor_type.declare('tensor1', ('0', dim))
tensor_type.declare('tensor2', ('0', dim), ('1', dim))
tensor_type.declare('tensor3', ('0', dim), ('1', dim), ('2', dim))
tensor_type.declare('tensor4', ('0', dim), ('1', dim), ('2', dim), ('3', dim))
tensor_type = z3.Datatype("TensorType")
tensor_type.declare("Dyn", ("dyn", dyn))
tensor_type.declare("tensor1", ("0", dim))
tensor_type.declare("tensor2", ("0", dim), ("1", dim))
tensor_type.declare("tensor3", ("0", dim), ("1", dim), ("2", dim))
tensor_type.declare("tensor4", ("0", dim), ("1", dim), ("2", dim), ("3", dim))
tensor_type = tensor_type.create()

# create dimension

@ -1,16 +1,16 @@
# mypy: allow-untyped-defs
import operator
from typing import Any, Callable, Dict, Tuple, Optional
from typing import Any, Callable, Dict, Optional, Tuple

import torch
import torch.fx
import torch.fx as fx
from torch.fx import Transformer, Proxy
from torch.fx.node import Argument, Target, Node, map_aggregate
from torch.fx import Proxy, Transformer
from torch.fx.node import Argument, map_aggregate, Node, Target
from torch.fx.operator_schemas import (
normalize_module,
normalize_function,
create_type_hint,
normalize_function,
normalize_module,
)

from .schema_type_annotation import AnnotateTypesWithSchema

@ -1,37 +1,42 @@
# mypy: allow-untyped-defs
import torch.fx as fx
from torch.fx.node import Argument, Target
from torch.nn.utils.fusion import fuse_conv_bn_eval
from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx.passes.shape_prop import ShapeProp
import copy
from collections import defaultdict
import torch.utils.mkldnn as th_mkldnn
import logging
import operator
import time
import logging
from collections import defaultdict
from enum import Enum
from typing import Any, cast, Dict, Iterable, List, Optional, Tuple, Type

def _parent_name(target : str) -> Tuple[str, str]:
import torch
import torch.fx as fx
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.mkldnn as th_mkldnn
from torch.fx.node import Argument, Target
from torch.fx.passes.shape_prop import ShapeProp
from torch.nn.utils.fusion import fuse_conv_bn_eval


def _parent_name(target: str) -> Tuple[str, str]:
"""
Splits a qualname into parent path and last atom.
For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
"""
*parent, name = target.rsplit('.', 1)
return parent[0] if parent else '', name
*parent, name = target.rsplit(".", 1)
return parent[0] if parent else "", name


# Works for length 2 patterns with 2 modules
def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]):
def matches_module_pattern(
pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]
):
if len(node.args) == 0:
return False
nodes: Tuple[Any, fx.Node] = (node.args[0], node)
for expected_type, current_node in zip(pattern, nodes):
if not isinstance(current_node, fx.Node):
return False
if current_node.op != 'call_module':
if current_node.op != "call_module":
return False
if not isinstance(current_node.target, str):
return False
@ -42,20 +47,25 @@ def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict
return True


def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
def replace_node_module(
node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module
):
assert isinstance(node.target, str)
parent_name, name = _parent_name(node.target)
modules[node.target] = new_module
setattr(modules[parent_name], name, new_module)


def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module:
"""
Fuses convolution/BN layers for inference purposes. Will deepcopy your
model by default, but can modify the model inplace as well.
"""
patterns = [(nn.Conv1d, nn.BatchNorm1d),
(nn.Conv2d, nn.BatchNorm2d),
(nn.Conv3d, nn.BatchNorm3d)]
patterns = [
(nn.Conv1d, nn.BatchNorm1d),
(nn.Conv2d, nn.BatchNorm2d),
(nn.Conv3d, nn.BatchNorm3d),
]
if not inplace:
model = copy.deepcopy(model)
if not no_trace or not isinstance(model, torch.fx.GraphModule):
@ -80,6 +90,7 @@ def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Modu
new_graph.erase_node(node)
return fx.GraphModule(fx_model, new_graph)


def remove_dropout(model: nn.Module) -> nn.Module:
"""
Removes all dropout layers from the module.
@ -87,15 +98,24 @@ def remove_dropout(model: nn.Module) -> nn.Module:
fx_model = fx.symbolic_trace(model)

class DropoutRemover(torch.fx.Transformer):
def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
def call_module(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Any:
if isinstance(self.submodules[target], nn.Dropout):
assert len(args) == 1
return args[0]
else:
return super().call_module(target, args, kwargs)

return DropoutRemover(fx_model).transform()

def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[fx.Node], outputs: List[fx.Node]):

def extract_subgraph(
orig_module: nn.Module,
nodes: List[fx.Node],
inputs: List[fx.Node],
outputs: List[fx.Node],
):
"""
Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph.
"""
@ -111,10 +131,21 @@ def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[
new_graph.lint()
return fx.GraphModule(orig_module, new_graph)


mkldnn_supported = [
nn.Conv2d, nn.Linear, nn.BatchNorm2d, nn.ReLU, nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d,
torch.relu, torch.transpose, torch.sigmoid,
F.relu, F.avg_pool2d, F.adaptive_avg_pool2d
nn.Conv2d,
nn.Linear,
nn.BatchNorm2d,
nn.ReLU,
nn.MaxPool2d,
nn.AvgPool2d,
nn.AdaptiveAvgPool2d,
torch.relu,
torch.transpose,
torch.sigmoid,
F.relu,
F.avg_pool2d,
F.adaptive_avg_pool2d,
]
# These are operators that may not be convertible into MKLDNN ops (e.g. the
# args are scalar values). Thus, we only include them in the subgraph if their
@ -124,7 +155,7 @@ mkldnn_supported_unknown = [operator.add, operator.mul]
mkldnn_map = {
nn.Conv2d: th_mkldnn.MkldnnConv2d,
nn.Linear: th_mkldnn.MkldnnLinear,
nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a)
nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a),
}


@ -136,7 +167,7 @@ def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]):
"""
old_modules: Dict[nn.Module, nn.Module] = {}
for node in nodes:
if node.op == 'call_module':
if node.op == "call_module":
assert isinstance(node.target, str)
cur_module = modules[node.target]
if type(cur_module) in mkldnn_map:
@ -146,18 +177,24 @@ def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]):
replace_node_module(node, modules, new_module)
return old_modules

def reset_modules(nodes: List[fx.Node], modules: Dict[str, nn.Module], old_modules: Dict[nn.Module, nn.Module]):

def reset_modules(
nodes: List[fx.Node],
modules: Dict[str, nn.Module],
old_modules: Dict[nn.Module, nn.Module],
):
"""
Maps each module that's been changed with `modules_to_mkldnn` back to its
original.
"""
for node in nodes:
if node.op == 'call_module':
assert (isinstance(node.target, str))
if node.op == "call_module":
assert isinstance(node.target, str)
cur_module = modules[node.target]
if cur_module in old_modules:
replace_node_module(node, modules, old_modules[cur_module])


class MklSubgraph:
def __init__(self, fx_graph: fx.Graph):
self.fx_graph = fx_graph
@ -165,6 +202,7 @@ class MklSubgraph:
self.start_nodes: List[fx.Node] = []
self.end_nodes: List[fx.Node] = []


def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
"""
This generates a heuristic that can be passed into `optimize_for_inference` that
@ -196,13 +234,21 @@ def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
f()
return time.time() - begin

mkl_time = benchmark(lambda: [i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])])
mkl_time = benchmark(
lambda: [
i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])
]
)

reset_modules(submodule.graph.nodes, dict(submodule.named_modules()), old_modules)
reset_modules(
submodule.graph.nodes, dict(submodule.named_modules()), old_modules
)
no_mkl_time = benchmark(lambda: submodule(*sample_inputs))
return mkl_time < no_mkl_time

return use_mkl_heuristic


def use_mkl_length(graph: MklSubgraph) -> bool:
"""
This is a heuristic that can be passed into `optimize_for_inference` that
@ -211,6 +257,7 @@ def use_mkl_length(graph: MklSubgraph) -> bool:
"""
return len(graph.nodes) > 2


class UnionFind:
def __init__(self, n):
self.parent: List[Optional[int]] = [None] * n
@ -237,10 +284,11 @@ class UnionFind:
self.parent[b] = a
self.size[a] += self.size[b]


def optimize_for_inference(
model: torch.nn.Module,
pass_config: Optional[Dict[str, Any]] = None,
tracer: Type[fx.Tracer] = fx.Tracer
tracer: Type[fx.Tracer] = fx.Tracer,
) -> torch.nn.Module:
"""
Performs a set of optimization passes to optimize a model for the
@ -258,7 +306,7 @@ def optimize_for_inference(
default_pass_config = {
"conv_bn_fuse": True,
"remove_dropout": True,
"mkldnn_layout_optimize": {'heuristic': use_mkl_length},
"mkldnn_layout_optimize": {"heuristic": use_mkl_length},
}
if pass_config is None:
pass_config = {}
@ -292,15 +340,19 @@ def optimize_for_inference(
# a MKLDNN node if its inputs are MKLDNN nodes.
for node in list(fx_graph.nodes):
supports_mkldnn = MklSupport.NO
if node.op == 'call_module':
if node.op == "call_module":
cur_module = modules[node.target]
if type(cur_module) in mkldnn_supported:
supports_mkldnn = MklSupport.YES
sample_parameter = next(cur_module.parameters(), None)
if sample_parameter is not None:
assert sample_parameter.dtype == torch.float, "this pass is only for torch.float modules"
assert sample_parameter.device == torch.device('cpu'), "this pass is only for CPU modules"
elif node.op == 'call_function':
assert (
sample_parameter.dtype == torch.float
), "this pass is only for torch.float modules"
assert sample_parameter.device == torch.device(
"cpu"
), "this pass is only for CPU modules"
elif node.op == "call_function":
if node.target in mkldnn_supported:
supports_mkldnn = MklSupport.YES
elif node.target in mkldnn_supported_unknown:
@ -308,15 +360,17 @@ def optimize_for_inference(

if supports_mkldnn != MklSupport.NO:
if supports_mkldnn == MklSupport.UNKNOWN:
if not any(arg.target == 'to_dense' for arg in node.args):
if not any(arg.target == "to_dense" for arg in node.args):
continue
with fx_graph.inserting_before(node):
mkldnn_args = fx.map_arg(node.args, lambda n: fx_graph.call_method('to_mkldnn', (n, )))
mkldnn_args = fx.map_arg(
node.args, lambda n: fx_graph.call_method("to_mkldnn", (n,))
)

node.args = cast(Tuple[fx.node.Argument], mkldnn_args)

with fx_graph.inserting_after(node):
dense_x = fx_graph.create_node('call_method', 'to_dense', (node,))
dense_x = fx_graph.create_node("call_method", "to_dense", (node,))
node.replace_all_uses_with(dense_x)
dense_x.args = (node,)

@ -326,28 +380,26 @@ def optimize_for_inference(

# optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b
for node in fx_graph.nodes:
if node.op == 'call_method' and node.target == 'to_dense':
if node.op == "call_method" and node.target == "to_dense":
prv_node = node.args[0]
users = list(node.users)
for user in users:
if user.op == 'call_method' and user.target == 'to_mkldnn':
if user.op == "call_method" and user.target == "to_mkldnn":
user.replace_all_uses_with(prv_node)
fx_graph.erase_node(user)
if len(node.users) == 0:
fx_graph.erase_node(node)


num_nodes = len(fx_graph.nodes)
uf = UnionFind(num_nodes)

def get_color(n):
if hasattr(n, 'color'): # Current node is part of a MKL subgraph
if hasattr(n, "color"): # Current node is part of a MKL subgraph
return uf.find(n.color)
if hasattr(n, 'start_color'): # Current node is input to MKL subgraph
if hasattr(n, "start_color"): # Current node is input to MKL subgraph
return uf.find(n.start_color)
return None


# This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists
# of input nodes (which are only `to_mkldnn` calls), output nodes
# (`to_dense` calls), and intermediate nodes, which are run entirely on
@ -360,14 +412,19 @@ def optimize_for_inference(
# nodes (i.e. colors), we need to join these 2 colors into 1. That's done
# using a Disjoint Set Union.
for cur_idx, node in enumerate(fx_graph.nodes):
if node.op == 'call_method' and node.target == 'to_mkldnn':
if node.op == "call_method" and node.target == "to_mkldnn":
node.start_color = cur_idx
uf.make_set(cur_idx)
elif node.op == 'call_method' and node.target == 'to_dense':
elif node.op == "call_method" and node.target == "to_dense":
assert get_color(node.args[0]) is not None
node.end_color = get_color(node.args[0])
else:
cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None]
cur_colors = [
get_color(i)
for i in node.all_input_nodes
if isinstance(i, fx.Node)
if get_color(i) is not None
]

if len(cur_colors) == 0:
continue
@ -377,17 +434,15 @@ def optimize_for_inference(
for other_color in cur_colors[1:]:
uf.join(cur_colors[0], other_color)


mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph))
for node in fx_graph.nodes:
if hasattr(node, 'color'):
if hasattr(node, "color"):
mkldnn_graphs[uf.find(node.color)].nodes.append(node)
if hasattr(node, 'start_color'):
if hasattr(node, "start_color"):
mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node)
if hasattr(node, 'end_color'):
if hasattr(node, "end_color"):
mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node)


# Now that we have all the subgraphs, we need to decide which MKLDNN
# subgraphs we actually want to keep in MKLDNN.
for graph in mkldnn_graphs.values():
@ -400,7 +455,7 @@ def optimize_for_inference(

mkldnn_conversions = 0
for node in fx_graph.nodes:
if node.target == 'to_mkldnn' or node.target == 'to_dense':
if node.target == "to_mkldnn" or node.target == "to_dense":
mkldnn_conversions += 1

logging.getLogger(__name__).info("mkldnn conversions: %s", mkldnn_conversions)

@ -1,8 +1,8 @@
# mypy: allow-untyped-defs
from enum import Enum
from typing import NamedTuple, Dict, List, Set
from typing import Dict, List, NamedTuple, Set

from torch.fx.node import Node, map_arg
from torch.fx.node import map_arg, Node


class Partition:
@ -146,7 +146,7 @@ def get_latency_of_one_partition(
# this node is on the top bfs level in this partition
if not any(
n in partition.nodes and n.op not in {"placeholder", "get_attr"}
for n in input_nodes
for n in input_nodes
):
top_nodes.append(node)
return top_nodes
@ -279,7 +279,9 @@ def get_latency_of_partitioned_graph(
def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float:
"""This function helps to recursively get the latency of a path of partitions"""
# Update latency by adding current partition's latency
latency_so_far_sec += partition_to_latency_mapping[partition].overall_latency_sec
latency_so_far_sec += partition_to_latency_mapping[
partition
].overall_latency_sec

if partition.children:
max_latency_sec = 0.0

@ -5,10 +5,10 @@ class Equality:
self.rhs = rhs

def __str__(self):
return f'{self.lhs} = {self.rhs}'
return f"{self.lhs} = {self.rhs}"

def __repr__(self):
return f'{self.lhs} = {self.rhs}'
return f"{self.lhs} = {self.rhs}"

def __eq__(self, other):
if isinstance(other, Equality):

@ -1,16 +1,18 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import ast
import inspect
import textwrap
import copy
import functools
import inspect
import textwrap
from types import FunctionType
from typing import cast, Union, Callable, Dict, Optional, Any
from typing import Any, Callable, cast, Dict, Optional, Union

import torch
from torch._sources import normalize_source_lines
from torch.fx._symbolic_trace import Tracer
from torch.fx.graph import Graph
from torch._sources import normalize_source_lines
import torch


class AST_Rewriter(ast.NodeTransformer):
"""
@ -29,11 +31,10 @@ class AST_Rewriter(ast.NodeTransformer):
# suitable for dynamo tracing anyways.
@torch._dynamo.disable
def rewrite(self, fn: FunctionType):

# Normalize the source lines
sourcelines, _ = inspect.getsourcelines(fn)
sourcelines = normalize_source_lines(sourcelines)
source = ''.join(sourcelines)
source = "".join(sourcelines)
normalized_str = textwrap.dedent(source)

# Rewrite the original AST
@ -64,6 +65,7 @@ class AST_Rewriter(ast.NodeTransformer):
g = functools.update_wrapper(g, f)
g.__kwdefaults__ = copy.copy(f.__kwdefaults__) # type:ignore[attr-defined]
return g

# Return the correct FunctionType object
return change_func_globals(fn_compiled, globals=fn.__globals__)

@ -73,7 +75,7 @@ class AST_Rewriter(ast.NodeTransformer):
symbolically-traceable torch._assert function
"""
# Create the Call node
n = ast.parse('torch._assert()', mode='eval')
n = ast.parse("torch._assert()", mode="eval")
assert isinstance(n, ast.Expression)
call_node = n.body
assert isinstance(call_node, ast.Call)
@ -96,13 +98,22 @@ class AST_Rewriter(ast.NodeTransformer):
Output:
y = annotate(f2(x),Tensor_Type((1,2,3,Dyn)))
"""
return ast.Assign(targets=[node.target], value=ast.Call(
func=ast.Name(id='annotate', ctx=ast.Load()),
args=[node.value, node.annotation], keywords=[]))
return ast.Assign(
targets=[node.target],
value=ast.Call(
func=ast.Name(id="annotate", ctx=ast.Load()),
args=[node.value, node.annotation],
keywords=[],
),
)


class RewritingTracer(Tracer):
def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
def trace(
self,
root: Union[torch.nn.Module, Callable],
concrete_args: Optional[Dict[str, Any]] = None,
) -> Graph:
return super().trace(_rewrite(root), concrete_args)


@ -111,7 +122,7 @@ def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Cal
# Rewrite this module's `forward` as well as the `forward`s of
# all of this module's recursive descendents. Return the new,
# rewritten module hierarchy.
def rewrite_module(m : torch.nn.Module):
def rewrite_module(m: torch.nn.Module):
class RewrittenModule(torch.nn.Module):
def __init__(self, orig):
super().__init__()
@ -120,8 +131,12 @@ def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Cal
self.__dict__[k] = copy.copy(rewrite_module(v))
else:
self.__dict__[k] = copy.copy(v)
RewrittenModule.forward = AST_Rewriter().rewrite(cast(FunctionType, m.forward))

RewrittenModule.forward = AST_Rewriter().rewrite(
cast(FunctionType, m.forward)
)
return RewrittenModule(m)

return rewrite_module(fn)
else:
# Rewrite this single free function

@ -1,13 +1,14 @@
# mypy: allow-untyped-defs
import torch
import torch.fx
import inspect
from typing import Any, Dict, Optional, Tuple
from torch.fx.node import Argument, Target

import torch
import torch.fx
from torch._jit_internal import boolean_dispatched
from torch.fx import Transformer
from torch.fx.node import Argument, Target
from torch.fx.operator_schemas import _torchscript_type_to_python_type

from torch.fx import Transformer

class AnnotateTypesWithSchema(Transformer):
"""
@ -27,16 +28,24 @@ class AnnotateTypesWithSchema(Transformer):
traced = AnnotateTypesWithSchema(traced).transform()

"""
def __init__(self, module : torch.nn.Module, annotate_functionals : bool = True,
annotate_modules : bool = True, annotate_get_attrs : bool = True):

def __init__(
self,
module: torch.nn.Module,
annotate_functionals: bool = True,
annotate_modules: bool = True,
annotate_get_attrs: bool = True,
):
super().__init__(module)
self.annotate_functionals = annotate_functionals
self.annotate_modules = annotate_modules
self.annotate_get_attrs = annotate_get_attrs

def call_function(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
def call_function(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
):
python_ret_type = None
if self.annotate_functionals and target.__module__ == 'torch.nn.functional':
if self.annotate_functionals and target.__module__ == "torch.nn.functional":
target_for_analysis = target
if target in boolean_dispatched:
# HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
@ -45,51 +54,71 @@ class AnnotateTypesWithSchema(Transformer):
# branch signature for analysis. Otherwise, leave this un-normalized
assert not isinstance(target, str)
dispatched = boolean_dispatched[target]
if_true, if_false = dispatched['if_true'], dispatched['if_false']
if_true, if_false = dispatched["if_true"], dispatched["if_false"]
# TODO: can we emit the union of these? What are the implications on TorchScript
# compilation?
if inspect.signature(if_true).return_annotation != inspect.signature(if_false).return_annotation:
if (
inspect.signature(if_true).return_annotation
!= inspect.signature(if_false).return_annotation
):
return super().call_function(target, args, kwargs)
target_for_analysis = if_true

python_ret_type = self._extract_python_return_type(target_for_analysis)

return_proxy = super().call_function(target, args, kwargs)
return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type
return_proxy.node.type = (
return_proxy.node.type if return_proxy.node.type else python_ret_type
)
return return_proxy

def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
def call_module(
self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
):
python_ret_type = None
assert isinstance(target, str)
submod = self.fetch_attr(target)
if self.annotate_modules and hasattr(submod.__class__, '__name__'):
if self.annotate_modules and hasattr(submod.__class__, "__name__"):
classname = submod.__class__.__name__
if getattr(torch.nn, classname, None) == submod.__class__:
python_ret_type = self._extract_python_return_type(submod.forward)
return_proxy = super().call_module(target, args, kwargs)
return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type
return_proxy.node.type = (
return_proxy.node.type if return_proxy.node.type else python_ret_type
)
return return_proxy

def get_attr(self, target : torch.fx.node.Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
def get_attr(
self,
target: torch.fx.node.Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Any],
):
attr_proxy = super().get_attr(target, args, kwargs)

if self.annotate_get_attrs:
module_itr = self.module
assert isinstance(target, str)
atoms = target.split('.')
atoms = target.split(".")
for i, atom in enumerate(atoms):
if not hasattr(module_itr, atom):
raise RuntimeError(f'Node referenced nonextent target {".".join(atoms[:i])}!')
raise RuntimeError(
f'Node referenced nonextent target {".".join(atoms[:i])}!'
)
module_itr = getattr(module_itr, atom)

maybe_inferred_ts_type = torch._C._jit_try_infer_type(module_itr)
if maybe_inferred_ts_type.success():
python_type = _torchscript_type_to_python_type(maybe_inferred_ts_type.type())
attr_proxy.node.type = python_type if not attr_proxy.node.type else attr_proxy.node.type
python_type = _torchscript_type_to_python_type(
maybe_inferred_ts_type.type()
)
attr_proxy.node.type = (
python_type if not attr_proxy.node.type else attr_proxy.node.type
)

return attr_proxy

def _extract_python_return_type(self, target : Target) -> Optional[Any]:
def _extract_python_return_type(self, target: Target) -> Optional[Any]:
"""
Given a Python call target, try to extract the Python return annotation
if it is available, otherwise return None
@ -109,4 +138,8 @@ class AnnotateTypesWithSchema(Transformer):
except (ValueError, TypeError):
return None

return sig.return_annotation if sig.return_annotation is not inspect.Signature.empty else None
return (
sig.return_annotation
if sig.return_annotation is not inspect.Signature.empty
else None
)

@ -1,4 +1,4 @@
# mypy: disable-error-code=attr-defined
from .core import unify, reify # noqa: F403
from .core import reify, unify # noqa: F403
from .more import unifiable # noqa: F403
from .variable import var, isvar, vars, variables, Var # noqa: F403
from .variable import isvar, Var, var, variables, vars # noqa: F403

@ -2,10 +2,11 @@
from collections.abc import Iterator # type: ignore[import]
from functools import partial

from .dispatch import dispatch
from .unification_tools import assoc # type: ignore[import]
from .utils import transitive_get as walk
from .variable import isvar
from .dispatch import dispatch


__all__ = ["reify", "unify"]

@ -13,33 +14,47 @@ __all__ = ["reify", "unify"]
# Reification #
###############


@dispatch(Iterator, dict)
def _reify(t, s):
return map(partial(reify, s=s), t)
# return (reify(arg, s) for arg in t)


_reify


@dispatch(tuple, dict) # type: ignore[no-redef]
def _reify(t, s):
return tuple(reify(iter(t), s))


_reify


@dispatch(list, dict) # type: ignore[no-redef]
def _reify(t, s):
return list(reify(iter(t), s))


_reify


@dispatch(dict, dict) # type: ignore[no-redef]
def _reify(d, s):
return {k: reify(v, s) for k, v in d.items()}


_reify


@dispatch(object, dict) # type: ignore[no-redef]
def _reify(o, s):
return o # catch all, just return the object


def reify(e, s):
""" Replace variables of expression with substitution
"""Replace variables of expression with substitution
>>> # xdoctest: +SKIP
>>> x, y = var(), var()
>>> e = (1, x, (3, y))
@ -54,12 +69,14 @@ def reify(e, s):
return reify(s[e], s) if e in s else e
return _reify(e, s)


###############
# Unification #
###############

seq = tuple, list, Iterator


@dispatch(seq, seq, dict)
def _unify(u, v, s):
if len(u) != len(v):
@ -69,6 +86,8 @@ def _unify(u, v, s):
if s is False:
return False
return s


#
# @dispatch((set, frozenset), (set, frozenset), dict)
# def _unify(u, v, s):
@ -98,8 +117,8 @@ def _unify(u, v, s):

@dispatch(object, object, dict)
def unify(u, v, s): # no check at the moment
""" Find substitution so that u == v while satisfying s
>>> x = var('x')
"""Find substitution so that u == v while satisfying s
>>> x = var("x")
>>> unify((1, x), (1, 2), {})
{~x: 2}
"""
@ -112,8 +131,11 @@ def unify(u, v, s): # no check at the moment
if isvar(v):
return assoc(s, v, u)
return _unify(u, v, s)


unify


@dispatch(object, object) # type: ignore[no-redef]
def unify(u, v):
return unify(u, v, {})

@ -1,6 +1,8 @@
from functools import partial

from .multipledispatch import dispatch # type: ignore[import]


namespace = {} # type: ignore[var-annotated]

dispatch = partial(dispatch, namespace=namespace)

@ -1,8 +1,8 @@
# mypy: allow-untyped-defs
from .core import unify, reify # type: ignore[attr-defined]
from .variable import isvar
from .core import reify, unify # type: ignore[attr-defined]
from .unification_tools import first, groupby # type: ignore[import]
from .utils import _toposort, freeze
from .unification_tools import groupby, first # type: ignore[import]
from .variable import isvar


class Dispatcher:
@ -28,32 +28,38 @@ class Dispatcher:
if s is not False:
result = self.funcs[signature]
return result, s
raise NotImplementedError("No match found. \nKnown matches: "
+ str(self.ordering) + "\nInput: " + str(args))
raise NotImplementedError(
"No match found. \nKnown matches: "
+ str(self.ordering)
+ "\nInput: "
+ str(args)
)

def register(self, *signature):
def _(func):
self.add(signature, func)
return self

return _


class VarDispatcher(Dispatcher):
""" A dispatcher that calls functions with variable names
"""A dispatcher that calls functions with variable names
>>> # xdoctest: +SKIP
>>> d = VarDispatcher('d')
>>> x = var('x')
>>> @d.register('inc', x)
>>> d = VarDispatcher("d")
>>> x = var("x")
>>> @d.register("inc", x)
... def f(x):
... return x + 1
>>> @d.register('double', x)
>>> @d.register("double", x)
... def f(x):
... return x * 2
>>> d('inc', 10)
>>> d("inc", 10)
11
>>> d('double', 10)
>>> d("double", 10)
20
"""

def __call__(self, *args, **kwargs):
func, s = self.resolve(args)
d = {k.token: v for k, v in s.items()}
@ -64,8 +70,8 @@ global_namespace = {} # type: ignore[var-annotated]


def match(*signature, **kwargs):
namespace = kwargs.get('namespace', global_namespace)
dispatcher = kwargs.get('Dispatcher', Dispatcher)
namespace = kwargs.get("namespace", global_namespace)
dispatcher = kwargs.get("Dispatcher", Dispatcher)

def _(func):
name = func.__name__
@ -77,11 +83,12 @@ def match(*signature, **kwargs):
d.add(signature, func)

return d

return _


def supercedes(a, b):
""" ``a`` is a more specific match than ``b`` """
"""``a`` is a more specific match than ``b``"""
if isvar(b) and not isvar(a):
return True
s = unify(a, b)
@ -96,7 +103,7 @@ def supercedes(a, b):

# Taken from multipledispatch
def edge(a, b, tie_breaker=hash):
""" A should be checked before B
"""A should be checked before B
Tie broken by tie_breaker, defaults to ``hash``
"""
if supercedes(a, b):
@ -109,7 +116,7 @@ def edge(a, b, tie_breaker=hash):

# Taken from multipledispatch
def ordering(signatures):
""" A sane ordering of signatures to check, first to last
"""A sane ordering of signatures to check, first to last
Topological sort of edges as given by ``edge`` and ``supercedes``
"""
signatures = list(map(tuple, signatures))

@ -1,10 +1,10 @@
# mypy: allow-untyped-defs
from .core import unify, reify # type: ignore[attr-defined]
from .core import reify, unify # type: ignore[attr-defined]
from .dispatch import dispatch


def unifiable(cls):
""" Register standard unify and reify operations on class
"""Register standard unify and reify operations on class
This uses the type and __dict__ or __slots__ attributes to define the
nature of the term
See Also:
@ -15,7 +15,7 @@ def unifiable(cls):
... self.b = b
>>> unifiable(A)
<class 'unification.more.A'>
>>> x = var('x')
>>> x = var("x")
>>> a = A(1, 2)
>>> b = A(1, x)
>>> unify(a, b, {})
@ -33,22 +33,23 @@ def unifiable(cls):


def reify_object(o, s):
""" Reify a Python object with a substitution
"""Reify a Python object with a substitution
>>> # xdoctest: +SKIP
>>> class Foo(object):
... def __init__(self, a, b):
... self.a = a
... self.b = b
...
... def __str__(self):
... return "Foo(%s, %s)"%(str(self.a), str(self.b))
>>> x = var('x')
... return "Foo(%s, %s)" % (str(self.a), str(self.b))
>>> x = var("x")
>>> f = Foo(1, x)
>>> print(f)
Foo(1, ~x)
>>> print(reify_object(f, {x: 2}))
Foo(1, 2)
"""
if hasattr(o, '__slots__'):
if hasattr(o, "__slots__"):
return _reify_object_slots(o, s)
else:
return _reify_object_dict(o, s)
@ -77,7 +78,7 @@ def _reify_object_slots(o, s):

@dispatch(slice, dict)
def _reify(o, s):
""" Reify a Python ``slice`` object """
"""Reify a Python ``slice`` object"""
return slice(*reify((o.start, o.stop, o.step), s))


@ -87,16 +88,17 @@ def _reify(o, s):


def unify_object(u, v, s):
""" Unify two Python objects
"""Unify two Python objects
Unifies their type and ``__dict__`` attributes
>>> # xdoctest: +SKIP
>>> class Foo(object):
... def __init__(self, a, b):
... self.a = a
... self.b = b
...
... def __str__(self):
... return "Foo(%s, %s)"%(str(self.a), str(self.b))
>>> x = var('x')
... return "Foo(%s, %s)" % (str(self.a), str(self.b))
>>> x = var("x")
>>> f = Foo(1, x)
>>> g = Foo(1, 2)
>>> unify_object(f, g, {})
@ -104,15 +106,17 @@ def unify_object(u, v, s):
"""
if type(u) != type(v):
return False
if hasattr(u, '__slots__'):
return unify([getattr(u, slot) for slot in u.__slots__],
[getattr(v, slot) for slot in v.__slots__],
s)
if hasattr(u, "__slots__"):
return unify(
[getattr(u, slot) for slot in u.__slots__],
[getattr(v, slot) for slot in v.__slots__],
s,
)
else:
return unify(u.__dict__, v.__dict__, s)


@dispatch(slice, slice, dict)
def _unify(u, v, s):
""" Unify a Python ``slice`` object """
"""Unify a Python ``slice`` object"""
return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s)

@ -1,3 +1,7 @@
from .core import dispatch
from .dispatcher import (Dispatcher, halt_ordering, restart_ordering,
MDNotImplementedError)
from .dispatcher import (
Dispatcher,
halt_ordering,
MDNotImplementedError,
restart_ordering,
)

@@ -1,17 +1,28 @@
# mypy: allow-untyped-defs
from .utils import _toposort, groupby
from .variadic import isvariadic
import operator

__all__ = ["AmbiguityWarning", "supercedes", "consistent", "ambiguous", "ambiguities", "super_signature",
           "edge", "ordering"]
from .utils import _toposort, groupby
from .variadic import isvariadic


__all__ = [
    "AmbiguityWarning",
    "supercedes",
    "consistent",
    "ambiguous",
    "ambiguities",
    "super_signature",
    "edge",
    "ordering",
]


class AmbiguityWarning(Warning):
    pass


def supercedes(a, b):
    """ A is consistent and strictly more specific than B """
    """A is consistent and strictly more specific than B"""
    if len(a) < len(b):
        # only case is if a is empty and b is variadic
        return not a and len(b) == 1 and isvariadic(b[-1])
@@ -41,7 +52,7 @@ def supercedes(a, b):


def consistent(a, b):
    """ It is possible for an argument list to satisfy both A and B """
    """It is possible for an argument list to satisfy both A and B"""

    # Need to check for empty args
    if not a:
@@ -51,8 +62,7 @@ def consistent(a, b):

    # Non-empty args check for mutual subclasses
    if len(a) == len(b):
        return all(issubclass(aa, bb) or issubclass(bb, aa)
                   for aa, bb in zip(a, b))
        return all(issubclass(aa, bb) or issubclass(bb, aa) for aa, bb in zip(a, b))
    else:
        p1 = 0
        p2 = 0
@@ -70,45 +80,53 @@ def consistent(a, b):
            p1 += 1
    # We only need to check for variadic ends
    # Variadic types are guaranteed to be the last element
    return (isvariadic(cur_a) and p2 == len(b) or  # type: ignore[possibly-undefined]
            isvariadic(cur_b) and p1 == len(a))  # type: ignore[possibly-undefined]
    return (
        isvariadic(cur_a)  # type: ignore[possibly-undefined]
        and p2 == len(b)
        or isvariadic(cur_b)  # type: ignore[possibly-undefined]
        and p1 == len(a)
    )


def ambiguous(a, b):
    """ A is consistent with B but neither is strictly more specific """
    """A is consistent with B but neither is strictly more specific"""
    return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a))


def ambiguities(signatures):
    """ All signature pairs such that A is ambiguous with B """
    """All signature pairs such that A is ambiguous with B"""
    signatures = list(map(tuple, signatures))
    return {(a, b) for a in signatures for b in signatures
            if hash(a) < hash(b)
            and ambiguous(a, b)
            and not any(supercedes(c, a) and supercedes(c, b)
                        for c in signatures)}
    return {
        (a, b)
        for a in signatures
        for b in signatures
        if hash(a) < hash(b)
        and ambiguous(a, b)
        and not any(supercedes(c, a) and supercedes(c, b) for c in signatures)
    }


def super_signature(signatures):
    """ A signature that would break ambiguities """
    """A signature that would break ambiguities"""
    n = len(signatures[0])
    assert all(len(s) == n for s in signatures)

    return [max((type.mro(sig[i]) for sig in signatures), key=len)[0]
            for i in range(n)]
    return [max((type.mro(sig[i]) for sig in signatures), key=len)[0] for i in range(n)]


def edge(a, b, tie_breaker=hash):
    """ A should be checked before B
    """A should be checked before B
    Tie broken by tie_breaker, defaults to ``hash``
    """
    # A either supercedes B and B does not supercede A or if B does then call
    # tie_breaker
    return supercedes(a, b) and (not supercedes(b, a) or tie_breaker(a) > tie_breaker(b))
    return supercedes(a, b) and (
        not supercedes(b, a) or tie_breaker(a) > tie_breaker(b)
    )


def ordering(signatures):
    """ A sane ordering of signatures to check, first to last
    """A sane ordering of signatures to check, first to last
    Topological sort of edges as given by ``edge`` and ``supercedes``
    """
    signatures = list(map(tuple, signatures))

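Usage sketch (illustrative, not part of the diff): the ordering helpers above operate on plain tuples of types, so their semantics can be checked directly:

```python
from torch.fx.experimental.unification.multipledispatch.conflict import (
    ambiguities,
    consistent,
    supercedes,
)

print(supercedes((int,), (object,)))  # True: int is strictly more specific
print(consistent((int,), (float,)))   # False: no argument satisfies both
sigs = [(object, int), (int, object)]
print(ambiguities(sigs))  # one ambiguous pair (element order depends on hash)
```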
@@ -4,12 +4,14 @@ import sys

from .dispatcher import Dispatcher, MethodDispatcher


global_namespace = {}  # type: ignore[var-annotated]

__all__ = ["dispatch", "ismethod"]


def dispatch(*types, **kwargs):
    """ Dispatch function on the types of the inputs
    """Dispatch function on the types of the inputs
    Supports dispatch on all non-keyword arguments.
    Collects implementations based on the function name.  Ignores namespaces.
    If ambiguous type signatures occur a warning is raised when the function is
@@ -38,6 +40,7 @@ def dispatch(*types, **kwargs):
    ...     @dispatch(list)
    ...     def __init__(self, data):
    ...         self.data = data
    ...
    ...     @dispatch(int)
    ...     def __init__(self, datum):
    ...         self.data = [datum]
@@ -46,7 +49,7 @@ def dispatch(*types, **kwargs):
    >>> MyClass(3).data
    [3]
    """
    namespace = kwargs.get('namespace', global_namespace)
    namespace = kwargs.get("namespace", global_namespace)

    types = tuple(types)

@@ -65,20 +68,21 @@ def dispatch(*types, **kwargs):

        dispatcher.add(types, func)
        return dispatcher

    return _df


def ismethod(func):
    """ Is func a method?
    """Is func a method?
    Note that this has to work as the method is defined but before the class is
    defined.  At this stage methods look like functions.
    """
    if hasattr(inspect, "signature"):
        signature = inspect.signature(func)
        return signature.parameters.get('self', None) is not None
        return signature.parameters.get("self", None) is not None
    else:
        if sys.version_info.major < 3:
            spec = inspect.getargspec(func)  # type: ignore[attr-defined]
        else:
            spec = inspect.getfullargspec(func)  # type: ignore[union-attr, assignment]
        return spec and spec.args and spec.args[0] == 'self'
        return spec and spec.args and spec.args[0] == "self"

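Usage sketch (illustrative, not part of the diff), mirroring the `dispatch()` doctest reformatted above:

```python
from torch.fx.experimental.unification.multipledispatch import dispatch


@dispatch(int)
def f(x):
    return x + 1


@dispatch(float)
def f(x):  # noqa: F811  -- intentionally re-registered under a new signature
    return x - 1


print(f(3))    # 4   (int implementation)
print(f(3.0))  # 2.0 (float implementation)
```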
@@ -1,21 +1,35 @@
# mypy: allow-untyped-defs
from warnings import warn
import inspect
from typing_extensions import deprecated
from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
from .utils import expand_tuples
from .variadic import Variadic, isvariadic
import itertools as itl
from typing_extensions import deprecated
from warnings import warn

from .conflict import ambiguities, AmbiguityWarning, ordering, super_signature
from .utils import expand_tuples
from .variadic import isvariadic, Variadic


__all__ = [
    "MDNotImplementedError",
    "ambiguity_warn",
    "halt_ordering",
    "restart_ordering",
    "variadic_signature_matches_iter",
    "variadic_signature_matches",
    "Dispatcher",
    "source",
    "MethodDispatcher",
    "str_signature",
    "warning_text",
]

__all__ = ["MDNotImplementedError", "ambiguity_warn", "halt_ordering", "restart_ordering", "variadic_signature_matches_iter",
           "variadic_signature_matches", "Dispatcher", "source", "MethodDispatcher", "str_signature", "warning_text"]

class MDNotImplementedError(NotImplementedError):
    """ A NotImplementedError for multiple dispatch """
    """A NotImplementedError for multiple dispatch"""


def ambiguity_warn(dispatcher, ambiguities):
    """ Raise warning when ambiguity is detected
    """Raise warning when ambiguity is detected
    Parameters
    ----------
    dispatcher : Dispatcher
@@ -92,7 +106,7 @@ def variadic_signature_matches(types, full_signature):


class Dispatcher:
    """ Dispatch methods based on type signature
    """Dispatch methods based on type signature
    Use ``dispatch`` to add implementations
    Examples
    --------
@@ -109,7 +123,8 @@ class Dispatcher:
    >>> f(3.0)
    2.0
    """
    __slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc'

    __slots__ = "__name__", "name", "funcs", "_ordering", "_cache", "doc"

    def __init__(self, name, doc=None):
        self.name = self.__name__ = name
@@ -119,9 +134,9 @@ class Dispatcher:
        self._cache = {}

    def register(self, *types, **kwargs):
        """ register dispatcher with new implementation
        """register dispatcher with new implementation
        >>> # xdoctest: +SKIP
        >>> f = Dispatcher('f')
        >>> f = Dispatcher("f")
        >>> @f.register(int)
        ... def inc(x):
        ...     return x + 1
@@ -139,9 +154,11 @@ class Dispatcher:
        >>> f([1, 2, 3])
        [3, 2, 1]
        """

        def _df(func):
            self.add(types, func, **kwargs)   # type: ignore[call-arg]
            self.add(types, func, **kwargs)  # type: ignore[call-arg]
            return func

        return _df

    @classmethod
@@ -152,28 +169,27 @@ class Dispatcher:

    @classmethod
    def get_func_annotations(cls, func):
        """ get annotations of function positional parameters
        """
        """get annotations of function positional parameters"""
        params = cls.get_func_params(func)
        if params:
            Parameter = inspect.Parameter

            params = (param for param in params
                      if param.kind in
                      (Parameter.POSITIONAL_ONLY,
                       Parameter.POSITIONAL_OR_KEYWORD))
            params = (
                param
                for param in params
                if param.kind
                in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
            )

            annotations = tuple(
                param.annotation
                for param in params)
            annotations = tuple(param.annotation for param in params)

            if all(ann is not Parameter.empty for ann in annotations):
                return annotations

    def add(self, signature, func):
        """ Add new types/method pair to dispatcher
        """Add new types/method pair to dispatcher
        >>> # xdoctest: +SKIP
        >>> D = Dispatcher('add')
        >>> D = Dispatcher("add")
        >>> D.add((int, int), lambda x, y: x + y)
        >>> D.add((float, float), lambda x, y: x + y)
        >>> D(1, 2)
@@ -202,24 +218,25 @@ class Dispatcher:

        for index, typ in enumerate(signature, start=1):
            if not isinstance(typ, (type, list)):
                str_sig = ', '.join(c.__name__ if isinstance(c, type)
                                    else str(c) for c in signature)
                raise TypeError(f"Tried to dispatch on non-type: {typ}\n"
                                f"In signature: <{str_sig}>\n"
                                f"In function: {self.name}")
                str_sig = ", ".join(
                    c.__name__ if isinstance(c, type) else str(c) for c in signature
                )
                raise TypeError(
                    f"Tried to dispatch on non-type: {typ}\n"
                    f"In signature: <{str_sig}>\n"
                    f"In function: {self.name}"
                )

            # handle variadic signatures
            if isinstance(typ, list):
                if index != len(signature):
                    raise TypeError(
                        'Variadic signature must be the last element'
                    )
                    raise TypeError("Variadic signature must be the last element")

                if len(typ) != 1:
                    raise TypeError(
                        'Variadic signature must contain exactly one element. '
                        'To use a variadic union type place the desired types '
                        'inside of a tuple, e.g., [(int, str)]'
                        "Variadic signature must contain exactly one element. "
                        "To use a variadic union type place the desired types "
                        "inside of a tuple, e.g., [(int, str)]"
                    )
                new_signature.append(Variadic[typ[0]])
            else:
@@ -255,7 +272,8 @@ class Dispatcher:
            func = self.dispatch(*types)
            if not func:
                raise NotImplementedError(
                    f'Could not find signature for {self.name}: <{str_signature(types)}>') from e
                    f"Could not find signature for {self.name}: <{str_signature(types)}>"
                ) from e
            self._cache[types] = func
        try:
            return func(*args, **kwargs)
@@ -271,10 +289,12 @@ class Dispatcher:

            raise NotImplementedError(
                "Matching functions for "
                f"{self.name}: <{str_signature(types)}> found, but none completed successfully",) from e
                f"{self.name}: <{str_signature(types)}> found, but none completed successfully",
            ) from e

    def __str__(self):
        return f"<dispatched {self.name}>"

    __repr__ = __str__

    def dispatch(self, *types):
@@ -304,7 +324,6 @@ class Dispatcher:
        return None

    def dispatch_iter(self, *types):

        n = len(types)
        for signature in self.ordering:
            if len(signature) == n and all(map(issubclass, types, signature)):
@@ -315,21 +334,22 @@ class Dispatcher:
                result = self.funcs[signature]
                yield result

    @deprecated("`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning)
    @deprecated(
        "`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning
    )
    def resolve(self, types):
        """ Determine appropriate implementation for this type signature
        """Determine appropriate implementation for this type signature
        .. deprecated:: 0.4.4
            Use ``dispatch(*types)`` instead
        """
        return self.dispatch(*types)

    def __getstate__(self):
        return {'name': self.name,
                'funcs': self.funcs}
        return {"name": self.name, "funcs": self.funcs}

    def __setstate__(self, d):
        self.name = d['name']
        self.funcs = d['funcs']
        self.name = d["name"]
        self.funcs = d["funcs"]
        self._ordering = ordering(self.funcs)
        self._cache = {}

@@ -344,23 +364,23 @@ class Dispatcher:
        for sig in self.ordering[::-1]:
            func = self.funcs[sig]
            if func.__doc__:
                s = f'Inputs: <{str_signature(sig)}>\n'
                s += '-' * len(s) + '\n'
                s = f"Inputs: <{str_signature(sig)}>\n"
                s += "-" * len(s) + "\n"
                s += func.__doc__.strip()
                docs.append(s)
            else:
                other.append(str_signature(sig))

        if other:
            docs.append('Other signatures:\n    ' + '\n    '.join(other))
            docs.append("Other signatures:\n    " + "\n    ".join(other))

        return '\n\n'.join(docs)
        return "\n\n".join(docs)

    def _help(self, *args):
        return self.dispatch(*map(type, args)).__doc__

    def help(self, *args, **kwargs):
        """ Print docstring for the function corresponding to inputs """
        """Print docstring for the function corresponding to inputs"""
        print(self._help(*args))

    def _source(self, *args):
@@ -370,22 +390,23 @@ class Dispatcher:
        return source(func)

    def source(self, *args, **kwargs):
        """ Print source code for the function corresponding to inputs """
        """Print source code for the function corresponding to inputs"""
        print(self._source(*args))


def source(func):
    s = f'File: {inspect.getsourcefile(func)}\n\n'
    s = f"File: {inspect.getsourcefile(func)}\n\n"
    s = s + inspect.getsource(func)
    return s


class MethodDispatcher(Dispatcher):
    """ Dispatch methods based on type signature
    """Dispatch methods based on type signature
    See Also:
        Dispatcher
    """
    __slots__ = ('obj', 'cls')

    __slots__ = ("obj", "cls")

    @classmethod
    def get_func_params(cls, func):
@@ -402,26 +423,31 @@ class MethodDispatcher(Dispatcher):
        types = tuple([type(arg) for arg in args])
        func = self.dispatch(*types)
        if not func:
            raise NotImplementedError(f'Could not find signature for {self.name}: <{str_signature(types)}>')
            raise NotImplementedError(
                f"Could not find signature for {self.name}: <{str_signature(types)}>"
            )
        return func(self.obj, *args, **kwargs)


def str_signature(sig):
    """ String representation of type signature
    """String representation of type signature
    >>> str_signature((int, float))
    'int, float'
    """
    return ', '.join(cls.__name__ for cls in sig)
    return ", ".join(cls.__name__ for cls in sig)


def warning_text(name, amb):
    """ The text for ambiguity warnings """
    """The text for ambiguity warnings"""
    text = f"\nAmbiguities exist in dispatched function {name}\n\n"
    text += "The following signatures may result in ambiguous behavior:\n"
    for pair in amb:
        text += "\t" + \
            ', '.join('[' + str_signature(s) + ']' for s in pair) + "\n"
        text += "\t" + ", ".join("[" + str_signature(s) + "]" for s in pair) + "\n"
    text += "\n\nConsider making the following additions:\n\n"
    text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s))
                         + f')\ndef {name}(...)' for s in amb])
    text += "\n\n".join(
        [
            "@dispatch(" + str_signature(super_signature(s)) + f")\ndef {name}(...)"
            for s in amb
        ]
    )
    return text

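Usage sketch (illustrative, not part of the diff), following the `register()` doctest reformatted above:

```python
from torch.fx.experimental.unification.multipledispatch.dispatcher import Dispatcher

f = Dispatcher("f")


@f.register(int)
def inc(x):
    return x + 1


@f.register(float)
def dec(x):
    return x - 1


print(f(1))    # 2   (routed to inc)
print(f(1.0))  # 0.0 (routed to dec)
```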
@@ -1,8 +1,10 @@
# mypy: allow-untyped-defs
from collections import OrderedDict


__all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"]


def raises(err, lamda):
    try:
        lamda()
@@ -31,12 +33,12 @@ def expand_tuples(L):
# Taken from theano/theano/gof/sched.py
# Avoids licensing issues because this was written by Matthew Rocklin
def _toposort(edges):
    """ Topological sort algorithm by Kahn [1] - O(nodes + vertices)
    """Topological sort algorithm by Kahn [1] - O(nodes + vertices)
    inputs:
        edges - a dict of the form {a: {b, c}} where b and c depend on a
    outputs:
        L - an ordered list of nodes that satisfy the dependencies of edges
    >>> _toposort({1: (2, 3), 2: (3, )})
    >>> _toposort({1: (2, 3), 2: (3,)})
    [1, 2, 3]
    >>> # Closely follows the wikipedia page [2]
    >>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
@@ -44,8 +46,7 @@ def _toposort(edges):
    >>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
    """
    incoming_edges = reverse_dict(edges)
    incoming_edges = OrderedDict((k, set(val))
                                 for k, val in incoming_edges.items())
    incoming_edges = OrderedDict((k, set(val)) for k, val in incoming_edges.items())
    S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges)
    L = []

@@ -64,7 +65,7 @@ def _toposort(edges):

def reverse_dict(d):
    """Reverses direction of dependence dict
    >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
    >>> d = {"a": (1, 2), "b": (2, 3), "c": ()}
    >>> reverse_dict(d)  # doctest: +SKIP
    {1: ('a',), 2: ('a', 'b'), 3: ('b',)}
    :note: dict order are not deterministic. As we iterate on the
@@ -82,8 +83,8 @@ def reverse_dict(d):
# Taken from toolz
# Avoids licensing issues because this version was authored by Matthew Rocklin
def groupby(func, seq):
    """ Group a collection by a key function
    >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
    """Group a collection by a key function
    >>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"]
    >>> groupby(len, names)  # doctest: +SKIP
    {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
    >>> iseven = lambda x: x % 2 == 0

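Usage sketch (illustrative, not part of the diff), matching the doctests above:

```python
from torch.fx.experimental.unification.multipledispatch.utils import (
    _toposort,
    groupby,
    reverse_dict,
)

print(_toposort({1: (2, 3), 2: (3,)}))  # [1, 2, 3]
print(reverse_dict({"a": (1, 2), "b": (2, 3), "c": ()}))
# {1: ('a',), 2: ('a', 'b'), 3: ('b',)} (ordering follows dict iteration)
names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"]
print(groupby(len, names))  # buckets keyed by name length: 3, 5, and 7
```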
@@ -1,15 +1,17 @@
# mypy: allow-untyped-defs
from .utils import typename


__all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"]


class VariadicSignatureType(type):
    # checking if subclass is a subclass of self
    def __subclasscheck__(cls, subclass):
        other_type = (subclass.variadic_type if isvariadic(subclass)
                      else (subclass,))
        other_type = subclass.variadic_type if isvariadic(subclass) else (subclass,)
        return subclass is cls or all(
            issubclass(other, cls.variadic_type) for other in other_type  # type: ignore[attr-defined]
            issubclass(other, cls.variadic_type)  # type: ignore[attr-defined]
            for other in other_type
        )

    def __eq__(cls, other):
@@ -24,8 +26,7 @@ class VariadicSignatureType(type):
            bool
            Whether or not `other` is equal to `self`
        """
        return (isvariadic(other) and
                set(cls.variadic_type) == set(other.variadic_type))  # type: ignore[attr-defined]
        return isvariadic(other) and set(cls.variadic_type) == set(other.variadic_type)  # type: ignore[attr-defined]

    def __hash__(cls):
        return hash((type(cls), frozenset(cls.variadic_type)))  # type: ignore[attr-defined]
@@ -57,17 +58,20 @@ class VariadicSignatureMeta(type):
    generate a new type for Variadic signatures. See the Variadic class for
    examples of how this behaves.
    """

    def __getitem__(cls, variadic_type):
        if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)):
            raise ValueError("Variadic types must be type or tuple of types"
                             " (Variadic[int] or Variadic[(int, float)]")
            raise ValueError(
                "Variadic types must be type or tuple of types"
                " (Variadic[int] or Variadic[(int, float)]"
            )

        if not isinstance(variadic_type, tuple):
            variadic_type = variadic_type,
            variadic_type = (variadic_type,)
        return VariadicSignatureType(
            f'Variadic[{typename(variadic_type)}]',
            f"Variadic[{typename(variadic_type)}]",
            (),
            dict(variadic_type=variadic_type, __slots__=())
            dict(variadic_type=variadic_type, __slots__=()),
        )


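Usage sketch (illustrative, not part of the diff): `Variadic[...]` builds a marker type that stands in for any number of trailing arguments of the given type(s):

```python
from torch.fx.experimental.unification.multipledispatch.variadic import (
    isvariadic,
    Variadic,
)

V = Variadic[int]
print(isvariadic(V))                   # True
print(issubclass(int, V))              # True: int is covered by Variadic[int]
print(issubclass(str, V))              # False
print(Variadic[int] == Variadic[int])  # True: equality compares variadic_type sets
```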
@@ -1,25 +1,40 @@
# mypy: allow-untyped-defs
import collections
import operator
from functools import reduce
from collections.abc import Mapping
from functools import reduce

__all__ = ['merge', 'merge_with', 'valmap', 'keymap', 'itemmap',
           'valfilter', 'keyfilter', 'itemfilter',
           'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in']

__all__ = [
    "merge",
    "merge_with",
    "valmap",
    "keymap",
    "itemmap",
    "valfilter",
    "keyfilter",
    "itemfilter",
    "assoc",
    "dissoc",
    "assoc_in",
    "update_in",
    "get_in",
]


def _get_factory(f, kwargs):
    factory = kwargs.pop('factory', dict)
    factory = kwargs.pop("factory", dict)
    if kwargs:
        raise TypeError(f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'")
        raise TypeError(
            f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'"
        )
    return factory


def merge(*dicts, **kwargs):
    """ Merge a collection of dictionaries
    """Merge a collection of dictionaries

    >>> merge({1: 'one'}, {2: 'two'})
    >>> merge({1: "one"}, {2: "two"})
    {1: 'one', 2: 'two'}

    Later dictionaries have precedence
@@ -41,7 +56,7 @@ def merge(*dicts, **kwargs):


def merge_with(func, *dicts, **kwargs):
    """ Merge dictionaries and apply function to combined values
    """Merge dictionaries and apply function to combined values

    A key may occur in more than one dict, and all values mapped from the key
    will be passed to the function as a list, such as func([val1, val2, ...]).
@@ -70,7 +85,7 @@ def merge_with(func, *dicts, **kwargs):


def valmap(func, d, factory=dict):
    """ Apply function to values of dictionary
    """Apply function to values of dictionary

    >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
    >>> valmap(sum, bills)  # doctest: +SKIP
@@ -86,7 +101,7 @@ def valmap(func, d, factory=dict):


def keymap(func, d, factory=dict):
    """ Apply function to keys of dictionary
    """Apply function to keys of dictionary

    >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
    >>> keymap(str.lower, bills)  # doctest: +SKIP
@@ -102,7 +117,7 @@ def keymap(func, d, factory=dict):


def itemmap(func, d, factory=dict):
    """ Apply function to items of dictionary
    """Apply function to items of dictionary

    >>> accountids = {"Alice": 10, "Bob": 20}
    >>> itemmap(reversed, accountids)  # doctest: +SKIP
@@ -118,7 +133,7 @@ def itemmap(func, d, factory=dict):


def valfilter(predicate, d, factory=dict):
    """ Filter items in dictionary by value
    """Filter items in dictionary by value

    >>> iseven = lambda x: x % 2 == 0
    >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
@@ -138,7 +153,7 @@ def valfilter(predicate, d, factory=dict):


def keyfilter(predicate, d, factory=dict):
    """ Filter items in dictionary by key
    """Filter items in dictionary by key

    >>> iseven = lambda x: x % 2 == 0
    >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
@@ -158,7 +173,7 @@ def keyfilter(predicate, d, factory=dict):


def itemfilter(predicate, d, factory=dict):
    """ Filter items in dictionary by item
    """Filter items in dictionary by item

    >>> def isvalid(item):
    ...     k, v = item
@@ -182,13 +197,13 @@ def itemfilter(predicate, d, factory=dict):


def assoc(d, key, value, factory=dict):
    """ Return a new dict with new key value pair
    """Return a new dict with new key value pair

    New dict has d[key] set to value. Does not modify the initial dictionary.

    >>> assoc({'x': 1}, 'x', 2)
    >>> assoc({"x": 1}, "x", 2)
    {'x': 2}
    >>> assoc({'x': 1}, 'y', 3)   # doctest: +SKIP
    >>> assoc({"x": 1}, "y", 3)  # doctest: +SKIP
    {'x': 1, 'y': 3}
    """
    d2 = factory()
@@ -198,22 +213,22 @@ def assoc(d, key, value, factory=dict):


def dissoc(d, *keys, **kwargs):
    """ Return a new dict with the given key(s) removed.
    """Return a new dict with the given key(s) removed.

    New dict has d[key] deleted for each supplied key.
    Does not modify the initial dictionary.

    >>> dissoc({'x': 1, 'y': 2}, 'y')
    >>> dissoc({"x": 1, "y": 2}, "y")
    {'x': 1}
    >>> dissoc({'x': 1, 'y': 2}, 'y', 'x')
    >>> dissoc({"x": 1, "y": 2}, "y", "x")
    {}
    >>> dissoc({'x': 1}, 'y') # Ignores missing keys
    >>> dissoc({"x": 1}, "y")  # Ignores missing keys
    {'x': 1}
    """
    factory = _get_factory(dissoc, kwargs)
    d2 = factory()

    if len(keys) < len(d) * .6:
    if len(keys) < len(d) * 0.6:
        d2.update(d)
        for key in keys:
            if key in d2:
@@ -227,13 +242,14 @@ def dissoc(d, *keys, **kwargs):


def assoc_in(d, keys, value, factory=dict):
    """ Return a new dict with new, potentially nested, key value pair
    """Return a new dict with new, potentially nested, key value pair

    >>> purchase = {'name': 'Alice',
    ...             'order': {'items': ['Apple', 'Orange'],
    ...                       'costs': [0.50, 1.25]},
    ...             'credit card': '5555-1234-1234-1234'}
    >>> assoc_in(purchase, ['order', 'costs'], [0.25, 1.00]) # doctest: +SKIP
    >>> purchase = {
    ...     "name": "Alice",
    ...     "order": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]},
    ...     "credit card": "5555-1234-1234-1234",
    ... }
    >>> assoc_in(purchase, ["order", "costs"], [0.25, 1.00])  # doctest: +SKIP
    {'credit card': '5555-1234-1234-1234',
     'name': 'Alice',
     'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}}
@@ -242,7 +258,7 @@ def assoc_in(d, keys, value, factory=dict):


def update_in(d, keys, func, default=None, factory=dict):
    """ Update value in a (potentially) nested dictionary
    """Update value in a (potentially) nested dictionary

    inputs:
    d - dictionary on which to operate
@@ -257,14 +273,15 @@ def update_in(d, keys, func, default=None, factory=dict):
    specified by the keys, with the innermost value set to func(default).

    >>> inc = lambda x: x + 1
    >>> update_in({'a': 0}, ['a'], inc)
    >>> update_in({"a": 0}, ["a"], inc)
    {'a': 1}

    >>> transaction = {'name': 'Alice',
    ...                'purchase': {'items': ['Apple', 'Orange'],
    ...                             'costs': [0.50, 1.25]},
    ...                'credit card': '5555-1234-1234-1234'}
    >>> update_in(transaction, ['purchase', 'costs'], sum) # doctest: +SKIP
    >>> transaction = {
    ...     "name": "Alice",
    ...     "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]},
    ...     "credit card": "5555-1234-1234-1234",
    ... }
    >>> update_in(transaction, ["purchase", "costs"], sum)  # doctest: +SKIP
    {'credit card': '5555-1234-1234-1234',
     'name': 'Alice',
     'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}}
@@ -272,7 +289,7 @@ def update_in(d, keys, func, default=None, factory=dict):
    >>> # updating a value when k0 is not in d
    >>> update_in({}, [1, 2, 3], str, default="bar")
    {1: {2: {3: 'bar'}}}
    >>> update_in({1: 'foo'}, [2, 3, 4], inc, 0)
    >>> update_in({1: "foo"}, [2, 3, 4], inc, 0)
    {1: 'foo', 2: {3: {4: 1}}}
    """
    ks = iter(keys)
@@ -300,7 +317,7 @@ def update_in(d, keys, func, default=None, factory=dict):


def get_in(keys, coll, default=None, no_default=False):
    """ Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys.
    """Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys.

    If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless
    ``no_default`` is specified, then it raises KeyError or IndexError.
@@ -308,20 +325,21 @@ def get_in(keys, coll, default=None, no_default=False):
    ``get_in`` is a generalization of ``operator.getitem`` for nested data
    structures such as dictionaries and lists.

    >>> transaction = {'name': 'Alice',
    ...                'purchase': {'items': ['Apple', 'Orange'],
    ...                             'costs': [0.50, 1.25]},
    ...                'credit card': '5555-1234-1234-1234'}
    >>> get_in(['purchase', 'items', 0], transaction)
    >>> transaction = {
    ...     "name": "Alice",
    ...     "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]},
    ...     "credit card": "5555-1234-1234-1234",
    ... }
    >>> get_in(["purchase", "items", 0], transaction)
    'Apple'
    >>> get_in(['name'], transaction)
    >>> get_in(["name"], transaction)
    'Alice'
    >>> get_in(['purchase', 'total'], transaction)
    >>> get_in(['purchase', 'items', 'apple'], transaction)
    >>> get_in(['purchase', 'items', 10], transaction)
    >>> get_in(['purchase', 'total'], transaction, 0)
    >>> get_in(["purchase", "total"], transaction)
    >>> get_in(["purchase", "items", "apple"], transaction)
    >>> get_in(["purchase", "items", 10], transaction)
    >>> get_in(["purchase", "total"], transaction, 0)
    0
    >>> get_in(['y'], {}, no_default=True)
    >>> get_in(["y"], {}, no_default=True)
    Traceback (most recent call last):
        ...
    KeyError: 'y'
@@ -352,9 +370,9 @@ def getter(index):


def groupby(key, seq):
    """ Group a collection by a key function
    """Group a collection by a key function

    >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
    >>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"]
    >>> groupby(len, names)  # doctest: +SKIP
    {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}

@@ -364,9 +382,14 @@ def groupby(key, seq):

    Non-callable keys imply grouping on a member.

    >>> groupby('gender', [{'name': 'Alice', 'gender': 'F'},
    ...                    {'name': 'Bob', 'gender': 'M'},
    ...                    {'name': 'Charlie', 'gender': 'M'}]) # doctest:+SKIP
    >>> groupby(
    ...     "gender",
    ...     [
    ...         {"name": "Alice", "gender": "F"},
    ...         {"name": "Bob", "gender": "M"},
    ...         {"name": "Charlie", "gender": "M"},
    ...     ],
    ... )  # doctest:+SKIP
    {'F': [{'gender': 'F', 'name': 'Alice'}],
     'M': [{'gender': 'M', 'name': 'Bob'},
           {'gender': 'M', 'name': 'Charlie'}]}
@@ -388,9 +411,9 @@ def groupby(key, seq):


def first(seq):
    """ The first element in a sequence
    """The first element in a sequence

    >>> first('ABC')
    >>> first("ABC")
    'A'
    """
    return next(iter(seq))

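Usage sketch (illustrative, not part of the diff), combining the dict-utility doctests above:

```python
from torch.fx.experimental.unification.unification_tools import (
    assoc,
    get_in,
    merge,
    update_in,
)

print(merge({1: "one"}, {2: "two"}))                # {1: 'one', 2: 'two'}
print(assoc({"x": 1}, "x", 2))                      # {'x': 2}; input untouched
print(update_in({"a": 0}, ["a"], lambda x: x + 1))  # {'a': 1}
purchase = {"order": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]}}
print(get_in(["order", "items", 0], purchase))      # 'Apple'
print(get_in(["order", "total"], purchase, 0))      # 0: default for missing path
```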
@@ -1,5 +1,7 @@
# mypy: allow-untyped-defs
__all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"]


def hashable(x):
    try:
        hash(x)
@@ -9,7 +11,7 @@ def hashable(x):


def transitive_get(key, d):
    """ Transitive dict.get
    """Transitive dict.get
    >>> d = {1: 2, 2: 3, 3: 4}
    >>> d.get(1)
    2
@@ -32,13 +34,13 @@ def raises(err, lamda):
# Taken from theano/theano/gof/sched.py
# Avoids licensing issues because this was written by Matthew Rocklin
def _toposort(edges):
    """ Topological sort algorithm by Kahn [1] - O(nodes + vertices)
    """Topological sort algorithm by Kahn [1] - O(nodes + vertices)
    inputs:
        edges - a dict of the form {a: {b, c}} where b and c depend on a
    outputs:
        L - an ordered list of nodes that satisfy the dependencies of edges
    >>> # xdoctest: +SKIP
    >>> _toposort({1: (2, 3), 2: (3, )})
    >>> _toposort({1: (2, 3), 2: (3,)})
    [1, 2, 3]
    Closely follows the wikipedia page [2]
    [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
@@ -47,7 +49,7 @@ def _toposort(edges):
    """
    incoming_edges = reverse_dict(edges)
    incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
    S = ({v for v in edges if v not in incoming_edges})
    S = {v for v in edges if v not in incoming_edges}
    L = []

    while S:
@@ -65,7 +67,7 @@ def _toposort(edges):

def reverse_dict(d):
    """Reverses direction of dependence dict
    >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
    >>> d = {"a": (1, 2), "b": (2, 3), "c": ()}
    >>> reverse_dict(d)  # doctest: +SKIP
    {1: ('a',), 2: ('a', 'b'), 3: ('b',)}
    :note: dict order are not deterministic. As we iterate on the
@@ -89,12 +91,12 @@ def xfail(func):


def freeze(d):
    """ Freeze container to hashable form
    """Freeze container to hashable form
    >>> freeze(1)
    1
    >>> freeze([1, 2])
    (1, 2)
    >>> freeze({1: 2}) # doctest: +SKIP
    >>> freeze({1: 2})  # doctest: +SKIP
    frozenset([(1, 2)])
    """
    if isinstance(d, dict):

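Usage sketch (illustrative, not part of the diff), matching the doctests above:

```python
from torch.fx.experimental.unification.utils import freeze, transitive_get

d = {1: 2, 2: 3, 3: 4}
print(transitive_get(1, d))  # 4: follows 1 -> 2 -> 3 -> 4 to a fixed point
print(freeze([1, [2, 3]]))   # (1, (2, 3)): nested containers become hashable
```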
@@ -1,14 +1,16 @@
# mypy: allow-untyped-defs
from contextlib import contextmanager
from .utils import hashable

from .dispatch import dispatch
from .utils import hashable


_global_logic_variables = set()  # type: ignore[var-annotated]
_glv = _global_logic_variables


class Var:
    """ Logic Variable """
    """Logic Variable"""

    _id = 1

@@ -25,6 +27,7 @@ class Var:

    def __str__(self):
        return "~" + str(self.token)  # type: ignore[attr-defined]

    __repr__ = __str__

    def __eq__(self, other):
@@ -46,6 +49,7 @@ def vars():
    def isvar(v):
        return True

    isvar


@@ -69,12 +73,12 @@ def variables(*variables):
    False
    >>> # Normal approach
    >>> from unification import unify
    >>> x = var('x')
    >>> x = var("x")
    >>> unify(x, 1)
    {~x: 1}
    >>> # Context Manager approach
    >>> with variables('x'):
    ...     print(unify('x', 1))
    >>> with variables("x"):
    ...     print(unify("x", 1))
    {'x': 1}
    """
    old_global_logic_variables = _global_logic_variables.copy()

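Usage sketch (illustrative, not part of the diff), following the `variables()` doctest; it assumes `var` and `variables` are re-exported by the package `__init__`, as `unify` is (see the import in unify_refinements.py below):

```python
from torch.fx.experimental.unification import unify, var, variables

x = var("x")
print(unify(x, 1))  # {~x: 1}: an explicit logic variable

with variables("x"):
    # Inside the context manager the plain string "x" acts as a logic variable.
    print(unify("x", 1))  # {'x': 1}
```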
@@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
from torch.fx.experimental.graph_gradual_typechecker import Refine
from torch.fx.experimental.unification import unify, Var  # type: ignore[attr-defined]
from torch.fx.tensor_type import TensorType
from torch.fx.experimental.unification import Var, unify  # type: ignore[attr-defined]


def infer_symbolic_types_single_pass(traced):
@@ -13,6 +13,7 @@ def infer_symbolic_types_single_pass(traced):
    mgu = unify_eq(r.constraints)
    substitute_all_types(traced.graph, mgu)


def infer_symbolic_types(traced):
    """
    Calls our symbolic inferencer twice.
@@ -32,6 +33,7 @@ def infer_symbolic_types(traced):

    r.symbolic_relations()


def convert_eq(list_of_eq):
    """
    Convert equality constraints in the right format
@@ -109,6 +111,7 @@ def substitute_all_types(graph, mapping):
    for n in graph.nodes:
        n.type = substitute_solution_one_type(mapping, n.type)


def check_for_type_equality(g1, g2):
    """
    A check equality to be used in fixed points.

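Hypothetical usage sketch (not part of the diff, and not verified against the typechecker's requirements; in practice placeholders may need `TensorType` annotations for refinement to produce non-trivial results): the inferencer runs over a symbolically traced module and writes inferred types onto `node.type`:

```python
import torch
import torch.fx
from torch.fx.experimental.unify_refinements import infer_symbolic_types


class M(torch.nn.Module):  # made-up module for illustration
    def forward(self, x):
        return torch.relu(x)


traced = torch.fx.symbolic_trace(M())
infer_symbolic_types(traced)
for node in traced.graph.nodes:
    print(node.name, node.type)  # node.type carries the inferred type, if any
```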
File diff suppressed because it is too large
@@ -19,6 +19,7 @@ from torch.package import Importer, PackageExporter, PackageImporter, sys_importer
from ._compatibility import compatibility
from .graph import _custom_builtins, _is_from_torch, _PyTreeCodeGen, Graph, PythonCode


__all__ = [
    "reduce_graph_module",
    "reduce_package_graph_module",
@@ -386,11 +387,9 @@ class _WrappedCall:
            return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
        except Exception as e:
            assert e.__traceback__
            topmost_framesummary: (
                traceback.FrameSummary
            ) = traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[
                -1
            ]  # type: ignore[arg-type]
            topmost_framesummary: traceback.FrameSummary = (
                traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]
            )
            if "eval_with_key" in topmost_framesummary.filename:
                print(
                    _WrappedCall._generate_error_message(topmost_framesummary),
@@ -612,20 +611,20 @@ class {module_name}(torch.nn.Module):
                module_str = (
                    f"torch.load(r'{module_file}', weights_only=False) # {module_repr}"
                )
            model_str += f"{tab*2}self.{module_name} = {module_str}\n"
            model_str += f"{tab * 2}self.{module_name} = {module_str}\n"

        for buffer_name, buffer in self._buffers.items():
            if buffer is None:
                continue
            model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"
            model_str += f"{tab * 2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"  # noqa: B950

        for param_name, param in self._parameters.items():
            if param is None:
                continue
            model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"
            model_str += f"{tab * 2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"  # noqa: B950

        model_str += (
            f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
            f"{tab * 2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
        )
        model_str += f"{_addindent(self.code, 4)}\n"

@@ -667,7 +666,6 @@ class {module_name}(torch.nn.Module):
        mod: torch.nn.Module = self

        for item in prefix:

            submod = getattr(mod, item, None)

            if submod is None:
@@ -707,7 +705,6 @@ class {module_name}(torch.nn.Module):

        # Get the parent module
        for item in path:

            if not hasattr(mod, item):
                return False

@@ -743,9 +740,7 @@ class {module_name}(torch.nn.Module):
        used: List[str] = []

        for node in self.graph.nodes:

            if node.op == "call_module" or node.op == "get_attr":

                # A list of strings representing the different parts
                # of the path. For example, `foo.bar.baz` gives us
                # ["foo", "bar", "baz"]

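Minimal sketch (not part of the diff): the `to_folder()` hunks above are purely cosmetic, since `{tab*2}` and `{tab * 2}` render identically inside an f-string:

```python
tab = " " * 4
module_name = "sub"                 # hypothetical names for illustration
module_str = "torch.nn.Linear(3, 3)"
assert f"{tab*2}self.{module_name} = {module_str}\n" == (
    f"{tab * 2}self.{module_name} = {module_str}\n"
)
```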
@ -1,20 +1,24 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from .graph_module import GraphModule
|
||||
from ._lazy_graph_module import _make_graph_module
|
||||
from .graph import Graph
|
||||
from .node import Argument, Node, Target, map_arg, map_aggregate
|
||||
from .proxy import Proxy
|
||||
from ._symbolic_trace import Tracer
|
||||
from ._compatibility import compatibility
|
||||
from . import config
|
||||
import torch.fx.traceback as fx_traceback
|
||||
import torch
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
import inspect
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.fx.traceback as fx_traceback
|
||||
from torch.hub import tqdm
|
||||
|
||||
__all__ = ['Interpreter', 'Transformer']
|
||||
from . import config
|
||||
from ._compatibility import compatibility
|
||||
from ._lazy_graph_module import _make_graph_module
|
||||
from ._symbolic_trace import Tracer
|
||||
from .graph import Graph
|
||||
from .graph_module import GraphModule
|
||||
from .node import Argument, map_aggregate, map_arg, Node, Target
|
||||
from .proxy import Proxy
|
||||
|
||||
|
||||
__all__ = ["Interpreter", "Transformer"]
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
class Interpreter:
|
||||
@ -43,22 +47,22 @@ class Interpreter:
|
||||
method equivalents). We could subclass Interpreter like so::
|
||||
|
||||
class NegSigmSwapInterpreter(Interpreter):
|
||||
def call_function(self, target : Target,
|
||||
args : Tuple, kwargs : Dict) -> Any:
|
||||
def call_function(self, target: Target, args: Tuple, kwargs: Dict) -> Any:
|
||||
if target == torch.sigmoid:
|
||||
return torch.neg(*args, **kwargs)
|
||||
return super().call_function(n)
|
||||
|
||||
def call_method(self, target : Target,
|
||||
args : Tuple, kwargs : Dict) -> Any:
|
||||
if target == 'neg':
|
||||
def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any:
|
||||
if target == "neg":
|
||||
call_self, *args_tail = args
|
||||
return call_self.sigmoid(*args_tail, **kwargs)
|
||||
return super().call_method(n)
|
||||
|
||||
|
||||
def fn(x):
|
||||
return torch.sigmoid(x).neg()
|
||||
|
||||
|
||||
gm = torch.fx.symbolic_trace(fn)
|
||||
input = torch.randn(3, 4)
|
||||
result = NegSigmSwapInterpreter(gm).run(input)
|
||||
@ -74,15 +78,21 @@ class Interpreter:
|
||||
graph instead of `module.graph`, using the provided `module`
|
||||
argument to satisfy any requests for state.
|
||||
"""
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def __init__(self, module: torch.nn.Module, garbage_collect_values: bool = True, graph: Optional[Graph] = None):
|
||||
def __init__(
|
||||
self,
|
||||
module: torch.nn.Module,
|
||||
garbage_collect_values: bool = True,
|
||||
graph: Optional[Graph] = None,
|
||||
):
|
||||
self.module = module
|
||||
self.submodules = dict(self.module.named_modules())
|
||||
if graph is not None:
|
||||
self.graph = graph
|
||||
else:
|
||||
self.graph = self.module.graph
|
||||
self.env : Dict[Node, Any] = {}
|
||||
self.env: Dict[Node, Any] = {}
|
||||
self.name = "Interpreter"
|
||||
self.garbage_collect_values = garbage_collect_values
|
||||
self.extra_traceback = True
|
||||
@ -92,10 +102,10 @@ class Interpreter:
|
||||
# of a given node. This represents the *last* use of the node in the
|
||||
# execution order of the program, which we will use to free unused
|
||||
# values
|
||||
node_to_last_use : Dict[Node, Node] = {}
|
||||
self.user_to_last_uses : Dict[Node, List[Node]] = {}
|
||||
node_to_last_use: Dict[Node, Node] = {}
|
||||
self.user_to_last_uses: Dict[Node, List[Node]] = {}
|
||||
|
||||
def register_last_uses(n : Node, user : Node):
|
||||
def register_last_uses(n: Node, user: Node):
|
||||
if n not in node_to_last_use:
|
||||
node_to_last_use[n] = user
|
||||
self.user_to_last_uses.setdefault(user, []).append(n)
|
||||
@ -105,7 +115,12 @@ class Interpreter:
|
||||
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any:
|
||||
def run(
|
||||
self,
|
||||
*args,
|
||||
initial_env: Optional[Dict[Node, Any]] = None,
|
||||
enable_io_processing: bool = True,
|
||||
) -> Any:
|
||||
"""
|
||||
Run `module` via interpretation and return the result.
|
||||
|
||||
@ -128,10 +143,16 @@ class Interpreter:
|
||||
# position and extract those values.
|
||||
if enable_io_processing:
|
||||
args = self.graph.process_inputs(*args)
|
||||
self.args_iter : Iterator[Any] = iter(args)
|
||||
pbar = tqdm(total=len(self.graph.nodes),
|
||||
desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}",
|
||||
initial=0, position=0, leave=True, disable=config.disable_progress, delay=0)
|
||||
self.args_iter: Iterator[Any] = iter(args)
|
||||
pbar = tqdm(
|
||||
total=len(self.graph.nodes),
|
||||
desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}",
|
||||
initial=0,
|
||||
position=0,
|
||||
leave=True,
|
||||
disable=config.disable_progress,
|
||||
delay=0,
|
||||
)
|
||||
|
||||
for node in self.graph.nodes:
|
||||
pbar.update(1)
|
||||
@ -147,7 +168,7 @@ class Interpreter:
|
||||
except Exception as e:
|
||||
if self.extra_traceback:
|
||||
msg = f"While executing {node.format_node()}"
|
||||
msg = f'{e.args[0]}\n\n{msg}' if e.args else str(msg)
|
||||
msg = f"{e.args[0]}\n\n{msg}" if e.args else str(msg)
|
||||
msg += f"\nOriginal traceback:\n{node.stack_trace}"
|
||||
e.args = (msg,) + e.args[1:]
|
||||
if isinstance(e, KeyError):
|
||||
@ -158,9 +179,13 @@ class Interpreter:
|
||||
for to_delete in self.user_to_last_uses.get(node, []):
|
||||
del self.env[to_delete]
|
||||
|
||||
if node.op == 'output':
|
||||
if node.op == "output":
|
||||
output_val = self.env[node]
|
||||
return self.graph.process_outputs(output_val) if enable_io_processing else output_val
|
||||
return (
|
||||
self.graph.process_outputs(output_val)
|
||||
if enable_io_processing
|
||||
else output_val
|
||||
)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def boxed_run(self, args_list):
|
||||
@ -183,7 +208,7 @@ class Interpreter:
|
||||
yield
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def run_node(self, n : Node) -> Any:
|
||||
def run_node(self, n: Node) -> Any:
|
||||
"""
|
||||
Run a specific node ``n`` and return the result.
|
||||
Calls into placeholder, get_attr, call_function,
|
||||
@ -204,7 +229,9 @@ class Interpreter:
|
||||
|
||||
# Main Node running APIs
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
|
||||
def placeholder(
|
||||
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Execute a ``placeholder`` node. Note that this is stateful:
|
||||
``Interpreter`` maintains an internal iterator over
|
||||
@ -222,7 +249,7 @@ class Interpreter:
|
||||
Any: The argument value that was retrieved.
|
||||
"""
|
||||
assert isinstance(target, str)
|
||||
if target.startswith('*'):
|
||||
if target.startswith("*"):
|
||||
# For a starred parameter e.g. `*args`, retrieve all
|
||||
# remaining values from the args list.
|
||||
return list(self.args_iter)
|
||||
@ -233,10 +260,14 @@ class Interpreter:
|
||||
if len(args) > 0:
|
||||
return args[0]
|
||||
else:
|
||||
raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si
|
||||
raise RuntimeError(
|
||||
f"Expected positional argument for parameter {target}, but one was not passed in!"
|
||||
) from si
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
|
||||
def get_attr(
|
||||
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Execute a ``get_attr`` node. Will retrieve an attribute
|
||||
value from the ``Module`` hierarchy of ``self.module``.
|
||||
@ -255,7 +286,9 @@ class Interpreter:
|
||||
return self.fetch_attr(target)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
|
||||
def call_function(
|
||||
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Execute a ``call_function`` node and return the result.
|
||||
|
||||
@ -275,7 +308,9 @@ class Interpreter:
|
||||
return target(*args, **kwargs)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
|
||||
def call_method(
|
||||
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Execute a ``call_method`` node and return the result.
|
||||
|
||||
@ -297,7 +332,9 @@ class Interpreter:
|
||||
return getattr(self_obj, target)(*args_tail, **kwargs)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
|
||||
def call_module(
|
||||
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Execute a ``call_module`` node and return the result.
|
||||
|
||||
@ -320,7 +357,9 @@ class Interpreter:
|
||||
return submod(*args, **kwargs)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
|
||||
def output(
|
||||
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Execute an ``output`` node. This really just retrieves
|
||||
the value referenced by the ``output`` node and returns it.
|
||||
@ -339,7 +378,7 @@ class Interpreter:
|
||||
|
||||
# Helper methods
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def fetch_attr(self, target : str):
|
||||
def fetch_attr(self, target: str):
|
||||
"""
|
||||
Fetch an attribute from the ``Module`` hierarchy of ``self.module``.
|
||||
|
||||
@ -349,16 +388,18 @@ class Interpreter:
|
||||
Return:
|
||||
Any: The value of the attribute.
|
||||
"""
|
||||
target_atoms = target.split('.')
|
||||
target_atoms = target.split(".")
|
||||
attr_itr = self.module
|
||||
for i, atom in enumerate(target_atoms):
|
||||
if not hasattr(attr_itr, atom):
|
||||
raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i+1])}")
|
||||
raise RuntimeError(
|
||||
f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}"
|
||||
)
|
||||
attr_itr = getattr(attr_itr, atom)
|
||||
return attr_itr
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]:
|
||||
def fetch_args_kwargs_from_env(self, n: Node) -> Tuple[Tuple, Dict]:
|
||||
"""
|
||||
Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
|
||||
from the current execution environment.
|
||||
@ -376,7 +417,7 @@ class Interpreter:
|
||||
return args, kwargs
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def map_nodes_to_values(self, args : Argument, n : Node) -> Argument:
|
||||
def map_nodes_to_values(self, args: Argument, n: Node) -> Argument:
|
||||
"""
|
||||
Recursively descend through ``args`` and look up the concrete value
|
||||
for each ``Node`` in the current execution environment.
|
||||
@ -386,13 +427,18 @@ class Interpreter:
|
||||
|
||||
n (Node): Node to which ``args`` belongs. This is only used for error reporting.
|
||||
"""
|
||||
def load_arg(n_arg : Node) -> Any:
|
||||
|
||||
def load_arg(n_arg: Node) -> Any:
|
||||
if n_arg not in self.env:
|
||||
raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() '
|
||||
f'to diagnose such issues')
|
||||
raise RuntimeError(
|
||||
f"Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() "
|
||||
f"to diagnose such issues"
|
||||
)
|
||||
return self.env[n_arg]
|
||||
|
||||
return map_arg(args, load_arg)
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
class Transformer(Interpreter):
|
||||
"""
|
||||
@ -409,23 +455,29 @@ class Transformer(Interpreter):
|
||||
method equivalents). We could subclass ``Transformer`` like so::
|
||||
|
||||
class NegSigmSwapXformer(Transformer):
|
            def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
            def call_function(
                self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
            ) -> Any:
                if target == torch.sigmoid:
                    return torch.neg(*args, **kwargs)
                return super().call_function(n)

            def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
                if target == 'neg':
            def call_method(
                self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
            ) -> Any:
                if target == "neg":
                    call_self, *args_tail = args
                    return call_self.sigmoid(*args_tail, **kwargs)
                return super().call_method(n)


        def fn(x):
            return torch.sigmoid(x).neg()


        gm = torch.fx.symbolic_trace(fn)

        transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
        transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform()
        input = torch.randn(3, 4)
        torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())

@@ -452,7 +504,9 @@ class Transformer(Interpreter):
        self.tracer.root = module

    @compatibility(is_backward_compatible=True)
    def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
    def placeholder(
        self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ) -> Proxy:
        """
        Execute a ``placeholder`` node. In ``Transformer``, this is
        overridden to insert a new ``placeholder`` into the output
@@ -467,10 +521,14 @@ class Transformer(Interpreter):
        """
        assert isinstance(target, str)
        default_value = next(iter(args)) if args else inspect.Signature.empty
        return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer)
        return Proxy(
            self.new_graph.placeholder(target, default_value=default_value), self.tracer
        )

    @compatibility(is_backward_compatible=True)
    def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
    def get_attr(
        self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ) -> Proxy:
        """
        Execute a ``get_attr`` node. In ``Transformer``, this is
        overridden to insert a new ``get_attr`` node into the output
@@ -487,16 +545,20 @@ class Transformer(Interpreter):
        return self.tracer.create_proxy("get_attr", target, args, kwargs)

    @compatibility(is_backward_compatible=True)
    def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
    def call_module(
        self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ) -> Any:
        # Override so that the leaf module policy from `self.tracer` is respected.
        assert isinstance(target, str)
        submod = self.fetch_attr(target)
        return self.tracer.call_module(submod, submod.forward, args, kwargs)

    @compatibility(is_backward_compatible=True)
    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
    def call_function(
        self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ) -> Any:
        # Override so that functions that were wrapped are still wrapped.
        return self.tracer.create_proxy('call_function', target, args, kwargs)
        return self.tracer.create_proxy("call_function", target, args, kwargs)

    @compatibility(is_backward_compatible=True)
    def transform(self) -> GraphModule:
@@ -507,8 +569,10 @@ class Transformer(Interpreter):
        with fx_traceback.preserve_node_meta():
            result = super().run(enable_io_processing=False)
        if result is not None:
            def strip_proxy(a : Union[Argument, Proxy]) -> Any:

            def strip_proxy(a: Union[Argument, Proxy]) -> Any:
                return a.node if isinstance(a, Proxy) else a

            new_output_node = self.new_graph.output(map_aggregate(result, strip_proxy))
            # also preserve the metadata from the old output node, if it exists
            old_output_node = list(self.graph.nodes)[-1]
@@ -516,5 +580,4 @@ class Transformer(Interpreter):
            for k, v in old_output_node.meta.items():
                new_output_node.meta[k] = v


        return _make_graph_module(self.module, self.new_graph)
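The interleaved before/after fragments above make the docstring example hard to follow, so here is a consolidated, runnable sketch. Note that the docstring's `super().call_function(n)` carries a stray `n` from an older revision of the example; the sketch below passes the actual `(target, args, kwargs)` triple instead.

```python
# A minimal runnable version of the NegSigmSwapXformer docstring example.
import torch
import torch.fx
from torch.fx import Transformer

class NegSigmSwapXformer(Transformer):
    def call_function(self, target, args, kwargs):
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(target, args, kwargs)

    def call_method(self, target, args, kwargs):
        if target == "neg":
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(target, args, kwargs)

def fn(x):
    return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)
transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform()
inp = torch.randn(3, 4)
torch.testing.assert_close(transformed(inp), torch.neg(inp).sigmoid())
```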
364 torch/fx/node.py
@@ -1,39 +1,71 @@
# Nodes represent a definition of a value in our graph of operators.
from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set
import builtins
import inspect
import types
import warnings
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union

import torch
from torch._C import _NodeBase
from torch.fx.operator_schemas import (
    ArgsKwargsPair,
    normalize_function,
    normalize_module,
)

from .._ops import ops as _ops
from ._compatibility import compatibility
from .immutable_collections import immutable_dict, immutable_list
import torch
import builtins
import types
import inspect
import warnings
from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair
from .._ops import ops as _ops
from torch._C import _NodeBase


if TYPE_CHECKING:
    from .graph import Graph

__all__ = ['Node', 'map_arg', 'map_aggregate', "has_side_effect"]
__all__ = ["Node", "map_arg", "map_aggregate", "has_side_effect"]

BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype,
                          torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload,
                          torch.SymInt, torch.SymBool, torch.SymFloat]
BaseArgumentTypes = Union[
    str,
    int,
    float,
    bool,
    complex,
    torch.dtype,
    torch.Tensor,
    torch.device,
    torch.memory_format,
    torch.layout,
    torch._ops.OpOverload,
    torch.SymInt,
    torch.SymBool,
    torch.SymFloat,
]
base_types = BaseArgumentTypes.__args__  # type: ignore[attr-defined]

Target = Union[Callable[..., Any], str]

Argument = Optional[Union[
    Tuple[Any, ...],  # actually Argument, but mypy can't represent recursive types
    List[Any],  # actually Argument
    Dict[str, Any],  # actually Argument
    slice,  # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
    range,
    'Node',
    BaseArgumentTypes
]]
Argument = Optional[
    Union[
        Tuple[Any, ...],  # actually Argument, but mypy can't represent recursive types
        List[Any],  # actually Argument
        Dict[str, Any],  # actually Argument
        slice,  # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
        range,
        "Node",
        BaseArgumentTypes,
    ]
]

_legal_ops = dict.fromkeys(['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root'])
_legal_ops = dict.fromkeys(
    [
        "placeholder",
        "call_method",
        "call_module",
        "call_function",
        "get_attr",
        "output",
        "root",
    ]
)

_side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = {
    torch._C._set_grad_enabled,
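The hunk above only reflows the `BaseArgumentTypes`/`Argument` unions and the `_legal_ops` table without changing them. For readers new to FX, an illustrative sketch of how those opcodes show up on a real graph:

```python
# Illustrative: print the opcode, target, and args of every node in a traced graph.
import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + x.sum()

gm = torch.fx.symbolic_trace(M())
for node in gm.graph.nodes:
    # node.op is always one of the opcodes listed in _legal_ops (minus 'root')
    print(f"{node.op:14} target={node.target!s:20} args={node.args}")
```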
@@ -74,7 +106,8 @@ def _find_module_of_method(orig_method: Callable[..., Any]) -> str:
    for guess in [torch, torch.nn.functional]:
        if getattr(guess, name, None) is orig_method:
            return guess.__name__
    raise RuntimeError(f'cannot find module for {orig_method}')
    raise RuntimeError(f"cannot find module for {orig_method}")


# Borrowed from CPython typing module
# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156
@@ -86,22 +119,24 @@ def _type_repr(obj: object) -> str:
    else, we fall back on repr(obj).
    """
    if isinstance(obj, type):
        if obj.__module__ == 'builtins':
        if obj.__module__ == "builtins":
            return obj.__qualname__
        return f'{obj.__module__}.{obj.__qualname__}'
        return f"{obj.__module__}.{obj.__qualname__}"
    if obj is ...:
        return '...'
        return "..."
    if isinstance(obj, types.FunctionType):
        return obj.__name__
    return repr(obj)


def _get_qualified_name(func: Callable[..., Any]) -> str:
    # things like getattr just appear in builtins
    if getattr(builtins, func.__name__, None) is func:
        return func.__name__
    # torch.Tensor.{fn}
    if (isinstance(func, (types.MethodDescriptorType, types.WrapperDescriptorType))
            and func is getattr(torch.Tensor, func.__name__, None)):
    if isinstance(
        func, (types.MethodDescriptorType, types.WrapperDescriptorType)
    ) and func is getattr(torch.Tensor, func.__name__, None):
        return f"torch.Tensor.{func.__name__}"
    name = func.__name__
    if name == "<lambda>":
@@ -111,33 +146,45 @@ def _get_qualified_name(func: Callable[..., Any]) -> str:
        except Exception as e:
            raise RuntimeError("Unable to represent lambda") from e
    module = _find_module_of_method(func)
    module = module.replace('torch._ops', 'torch.ops')  # WAR for bug in how torch.ops assigns module
    module = module.replace(
        "torch._ops", "torch.ops"
    )  # WAR for bug in how torch.ops assigns module
    # Fixup segment_reduce mismatch
    if module == "torch" and name == "segment_reduce":
        name = "_" + name
    return f'{module}.{name}'
    return f"{module}.{name}"

def _format_arg(arg: object, max_list_len: float = float('inf')) -> str:
    if hasattr(arg, '_custom_fx_repr_fn'):

def _format_arg(arg: object, max_list_len: float = float("inf")) -> str:
    if hasattr(arg, "_custom_fx_repr_fn"):
        return arg._custom_fx_repr_fn()
    elif isinstance(arg, list):
        items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len)
        maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]'
        return f'[{items}{maybe_len}]'
        items = ", ".join(
            _format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len
        )
        maybe_len = (
            "" if len(arg) < max_list_len + 1 else f", ...[total_len={len(arg)}]"
        )
        return f"[{items}{maybe_len}]"
    elif isinstance(arg, tuple):
        items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len)
        maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]'
        maybe_comma = ',' if len(arg) == 1 else ''
        return f'({items}{maybe_comma}{maybe_len})'
        items = ", ".join(
            _format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len
        )
        maybe_len = (
            "" if len(arg) < max_list_len + 1 else f", ...[total_len={len(arg)}]"
        )
        maybe_comma = "," if len(arg) == 1 else ""
        return f"({items}{maybe_comma}{maybe_len})"
    elif isinstance(arg, dict):
        items_str = ', '.join(f'{k}: {_format_arg(v)}' for k, v in arg.items())
        return f'{{{items_str}}}'
        items_str = ", ".join(f"{k}: {_format_arg(v)}" for k, v in arg.items())
        return f"{{{items_str}}}"

    if isinstance(arg, Node):
        return '%' + str(arg)
        return "%" + str(arg)
    else:
        return str(arg)


@compatibility(is_backward_compatible=True)
class Node(_NodeBase):
    """
@@ -166,23 +213,31 @@ class Node(_NodeBase):
    - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement
      in the Graph printout.
    """
    _args: Tuple['Argument', ...]
    _kwargs: Dict[str, 'Argument']
    graph: 'Graph'

    _args: Tuple["Argument", ...]
    _kwargs: Dict[str, "Argument"]
    graph: "Graph"
    name: str
    op: str
    target: 'Target'
    _input_nodes: Dict['Node', None]
    users: Dict['Node', None]
    target: "Target"
    _input_nodes: Dict["Node", None]
    users: Dict["Node", None]
    type: Optional[Any]
    _sort_key: Any
    _repr_fn: Optional[Callable[['Node'], str]]
    _repr_fn: Optional[Callable[["Node"], str]]
    meta: Dict[str, Any]

    @compatibility(is_backward_compatible=True)
    def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target',
                 args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'],
                 return_type : Optional[Any] = None) -> None:
    def __init__(
        self,
        graph: "Graph",
        name: str,
        op: str,
        target: "Target",
        args: Tuple["Argument", ...],
        kwargs: Dict[str, "Argument"],
        return_type: Optional[Any] = None,
    ) -> None:
        """
        Instantiate an instance of ``Node``. Note: most often, you want to use the
        Graph APIs, i.e. ``Graph.call_module``, ``Graph.call_method``, etc. rather
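As the `__init__` docstring says, nodes are normally created through the `Graph` helper APIs rather than the constructor. A minimal sketch of that recommended path:

```python
# Sketch: building a tiny graph with the Graph helper APIs instead of Node(...).
import torch
from torch.fx import Graph, GraphModule

g = Graph()
x = g.placeholder("x")                    # op='placeholder'
relu = g.call_function(torch.relu, (x,))  # op='call_function'
g.output(relu)                            # op='output'

gm = GraphModule(torch.nn.Module(), g)
print(gm.graph)  # prints the three nodes created above
```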
@@ -210,14 +265,18 @@ class Node(_NodeBase):
           of analyses.
        """
        assert op in _legal_ops
        if op == 'call_function':
        if op == "call_function":
            if not callable(target):
                raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} '
                                 'but a Callable is expected')
                raise ValueError(
                    f"Node [graph = {graph}, name = '{name}'] target {target} has type {torch.typename(target)} "
                    "but a Callable is expected"
                )
        else:
            if not isinstance(target, str):
                raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} '
                                 'but a str is expected')
                raise ValueError(
                    f"Node [graph = {graph}, name = '{name}'] target {target} has type {torch.typename(target)} "
                    "but a str is expected"
                )
        super().__init__()

        # bypass Node.__setattr__ for perf and so that it doesn't need to handle half-built objects
@@ -225,9 +284,13 @@ class Node(_NodeBase):

        assign(self, "graph", graph)
        assign(self, "name", name)  # unique name of value being created
        assign(self, "op", op)  # the kind of operation = placeholder|call_method|call_module|call_function|get_attr
        assign(
            self, "op", op
        )  # the kind of operation = placeholder|call_method|call_module|call_function|get_attr

        assign(self, "target", target)  # for method/module/function, the name of the method/module/function/attr
        assign(
            self, "target", target
        )  # for method/module/function, the name of the method/module/function/attr
        # being invoked, e.g add, layer1, or torch.add

        # All `Node`-valued inputs. Key is the Node, value is don't-care.
@@ -280,7 +343,7 @@ class Node(_NodeBase):
        self._next = _next

    @property
    def next(self) -> 'Node':
    def next(self) -> "Node":
        """
        Returns the next ``Node`` in the linked list of Nodes.

@@ -291,7 +354,7 @@ class Node(_NodeBase):
        return self._next

    @property
    def prev(self) -> 'Node':
    def prev(self) -> "Node":
        """
        Returns the previous ``Node`` in the linked list of Nodes.

@@ -302,7 +365,7 @@ class Node(_NodeBase):
        return self._prev

    @compatibility(is_backward_compatible=True)
    def prepend(self, x: 'Node') -> None:
    def prepend(self, x: "Node") -> None:
        """
        Insert x before this node in the list of nodes in the graph. Example::

@@ -316,7 +379,9 @@ class Node(_NodeBase):
        """
        assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
        if self == x:
            warnings.warn("Trying to prepend a node to itself. This behavior has no effect on the graph.")
            warnings.warn(
                "Trying to prepend a node to itself. This behavior has no effect on the graph."
            )
            return
        x._remove_from_list()
        p = self._prev
@@ -328,28 +393,28 @@ class Node(_NodeBase):
            nsk = x._next._sort_key
            if len(psk) > len(nsk):
                idx: int
                *prefix, idx = psk[:len(nsk) + 1]
                *prefix, idx = psk[: len(nsk) + 1]
                x._sort_key = (*prefix, idx + 1)
            elif len(psk) < len(nsk):
                *prefix, idx = nsk[:len(psk) + 1]
                *prefix, idx = nsk[: len(psk) + 1]
                x._sort_key = (*prefix, idx - 1)
            else:  # same length, increase length by 1
                x._sort_key = (*psk, 0)

    def __gt__(self, other: 'Node') -> bool:
    def __gt__(self, other: "Node") -> bool:
        return self._sort_key > other._sort_key

    def __lt__(self, other: 'Node') -> bool:
    def __lt__(self, other: "Node") -> bool:
        return self._sort_key < other._sort_key

    def __ge__(self, other: 'Node') -> bool:
    def __ge__(self, other: "Node") -> bool:
        return self > other or self == other

    def __le__(self, other: 'Node') -> bool:
    def __le__(self, other: "Node") -> bool:
        return self < other or self == other

    @compatibility(is_backward_compatible=True)
    def append(self, x: 'Node') -> None:
    def append(self, x: "Node") -> None:
        """
        Insert ``x`` after this node in the list of nodes in the graph.
        Equivalent to ``self.next.prepend(x)``
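`prepend`/`append` keep the doubly linked node list and the `_sort_key` ordering shown above in sync. A hedged sketch of moving a node that has no data dependency on its neighbor:

```python
# Sketch: reordering independent nodes with Node.append (illustrative only).
import torch
import torch.fx

def f(x):
    a = x + 1
    b = x * 2  # independent of `a`, so the two may be reordered
    return a + b

gm = torch.fx.symbolic_trace(f)
nodes = list(gm.graph.nodes)       # [x, add, mul, add_1, output]
add_node, mul_node = nodes[1], nodes[2]
mul_node.append(add_node)          # move the add after the mul
gm.graph.lint()                    # verifies the ordering is still topologically valid
```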
@@ -376,7 +441,7 @@ class Node(_NodeBase):
        return self._args

    @args.setter
    def args(self, a : Tuple[Argument, ...]) -> None:
    def args(self, a: Tuple[Argument, ...]) -> None:
        """
        Set the tuple of arguments to this Node. The interpretation of arguments
        depends on the node's opcode. See the ``fx.Graph`` docstring for more
@@ -399,7 +464,7 @@ class Node(_NodeBase):
        return self._kwargs

    @kwargs.setter
    def kwargs(self, k : Dict[str, Argument]) -> None:
    def kwargs(self, k: Dict[str, Argument]) -> None:
        """
        Set the dict of kwargs to this Node. The interpretation of arguments
        depends on the node's opcode. See the ``fx.Graph`` docstring for more
@@ -410,7 +475,7 @@ class Node(_NodeBase):
        self.__update_args_kwargs(self._args, k)

    @property
    def all_input_nodes(self) -> List['Node']:
    def all_input_nodes(self) -> List["Node"]:
        """
        Return all Nodes that are inputs to this Node. This is equivalent to
        iterating over ``args`` and ``kwargs`` and only collecting the values that
@@ -424,7 +489,7 @@ class Node(_NodeBase):
        return list(self._input_nodes.keys())

    @compatibility(is_backward_compatible=True)
    def update_arg(self, idx : int, arg : Argument) -> None:
    def update_arg(self, idx: int, arg: Argument) -> None:
        """
        Update an existing positional argument to contain the new value
        ``arg``. After calling, ``self.args[idx] == arg``.
@@ -439,7 +504,7 @@ class Node(_NodeBase):
        self.args = tuple(args)

    @compatibility(is_backward_compatible=True)
    def insert_arg(self, idx : int, arg : Argument) -> None:
    def insert_arg(self, idx: int, arg: Argument) -> None:
        """
        Insert an positional argument to the argument list with given index.

@@ -448,7 +513,9 @@ class Node(_NodeBase):
            idx (int): The index of the element in ``self.args`` to be inserted before.
            arg (Argument): The new argument value to insert into ``args``
        """
        assert 0 <= idx <= len(self.args), "insert_args index must be between 0 and len(self.args)"
        assert (
            0 <= idx <= len(self.args)
        ), "insert_args index must be between 0 and len(self.args)"
        args_left = self.args[:idx]
        args_right = self.args[idx:]

@@ -463,7 +530,7 @@ class Node(_NodeBase):
        new_use.users.setdefault(self)

    @compatibility(is_backward_compatible=True)
    def update_kwarg(self, key : str, arg : Argument) -> None:
    def update_kwarg(self, key: str, arg: Argument) -> None:
        """
        Update an existing keyword argument to contain the new value
        ``arg``. After calling, ``self.kwargs[key] == arg``.
@@ -490,13 +557,16 @@ class Node(_NodeBase):
        return self.meta.get("stack_trace", None)

    @stack_trace.setter
    def stack_trace(self, trace : Optional[str]) -> None:
    def stack_trace(self, trace: Optional[str]) -> None:
        self.meta["stack_trace"] = trace

    def __update_args_kwargs(self, new_args : Tuple['Argument', ...], new_kwargs : Dict[str, 'Argument']) -> None:
    def __update_args_kwargs(
        self, new_args: Tuple["Argument", ...], new_kwargs: Dict[str, "Argument"]
    ) -> None:
        """
        This API is internal. Do *not* call it directly.
        """

        def update_users_and_input_nodes(n: Any) -> Any:
            if isinstance(n, Node):
                self._input_nodes.setdefault(n)
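`update_arg`, `insert_arg`, and `update_kwarg` all funnel into `__update_args_kwargs`, so use-def bookkeeping stays consistent. A small sketch, assuming the standard `operator.add` target produced by tracing `x + 1`:

```python
# Sketch: rewriting one argument of a node while keeping use-def info intact.
import operator
import torch
import torch.fx

def f(x):
    return x + 1

gm = torch.fx.symbolic_trace(f)
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is operator.add:
        node.update_arg(1, 2)  # change the constant 1 -> 2
gm.recompile()
assert gm(torch.tensor(3.0)) == 5.0
```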
@@ -512,8 +582,12 @@ class Node(_NodeBase):
        # - Normalize list->immutable_list, dict->immutable_dict, etc
        # - Populate self._input_nodes
        # - Populate arg.users[self] for each arg
        object.__setattr__(self, "_args", map_aggregate(new_args, update_users_and_input_nodes))
        object.__setattr__(self, "_kwargs", map_aggregate(new_kwargs, update_users_and_input_nodes))
        object.__setattr__(
            self, "_args", map_aggregate(new_args, update_users_and_input_nodes)
        )
        object.__setattr__(
            self, "_kwargs", map_aggregate(new_kwargs, update_users_and_input_nodes)
        )

    def __repr__(self) -> str:
        if self._repr_fn:
@@ -529,8 +603,8 @@ class Node(_NodeBase):
        """
        if isinstance(target, str):
            return target
        if hasattr(target, '__module__'):
            name = getattr(target, '__name__', None)
        if hasattr(target, "__module__"):
            name = getattr(target, "__name__", None)
            if name is None:
                # Just to be defensive, if we don't have `__name__`, get the
                # qualname. Not sure if this happens for any members of `operator`
@@ -538,16 +612,18 @@ class Node(_NodeBase):
                # things in `operator` have `_operator` as their __module__.
                # TODO: THIS IS BROKEN: _get_qualified_name calls `__name__`
                return _get_qualified_name(target)  # type: ignore[arg-type]
            if target.__module__ == 'builtins':
                return f'builtins.{name}'
            elif target.__module__ == '_operator':
                return f'operator.{name}'
            if target.__module__ == "builtins":
                return f"builtins.{name}"
            elif target.__module__ == "_operator":
                return f"operator.{name}"
        return _get_qualified_name(target)  # type: ignore[arg-type]

    @compatibility(is_backward_compatible=True)
    def format_node(self,
                    placeholder_names: Optional[List[str]] = None,
                    maybe_return_typename: Optional[List[str]] = None) -> Optional[str]:
    def format_node(
        self,
        placeholder_names: Optional[List[str]] = None,
        maybe_return_typename: Optional[List[str]] = None,
    ) -> Optional[str]:
        """
        Return a descriptive string representation of ``self``.

@@ -576,37 +652,46 @@ class Node(_NodeBase):
            return a descriptive string representation of the
            current Node.
        """
        if self.op == 'placeholder':
        if self.op == "placeholder":
            assert isinstance(self.target, str)
            arg_str = self.target
            arg_str += arg_str + f': {_type_repr(self.type)}' if self.type else ''
            arg_str += arg_str + f": {_type_repr(self.type)}" if self.type else ""
            if placeholder_names:
                placeholder_names.append(arg_str)
                return None
            maybe_typename = f'{_type_repr(self.type)} ' if self.type else ''
            default_val = '(default=' + str(self.args[0]) + ')' if self.args else ''
            return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = {self.op}[target={self.target}]{default_val}'
        elif self.op == 'get_attr':
            maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else ''
            return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \
                   f'{self.op}[target={self._pretty_print_target(self.target)}]'
        elif self.op == 'output':
            maybe_typename = f"{_type_repr(self.type)} " if self.type else ""
            default_val = "(default=" + str(self.args[0]) + ")" if self.args else ""
            return f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = {self.op}[target={self.target}]{default_val}"
        elif self.op == "get_attr":
            maybe_typename = (
                f"{_type_repr(self.type)} " if self.type is not None else ""
            )
            return (
                f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = "
                f"{self.op}[target={self._pretty_print_target(self.target)}]"
            )
        elif self.op == "output":
            if self.type and maybe_return_typename:
                maybe_return_typename[0] = f' -> {_type_repr(self.type)}'
            return f'return {self.args[0]}'
                maybe_return_typename[0] = f" -> {_type_repr(self.type)}"
            return f"return {self.args[0]}"
        else:
            maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else ''
            return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \
                   f'{self.op}[target={self._pretty_print_target(self.target)}](' \
                   f'args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})'
            maybe_typename = (
                f"{_type_repr(self.type)} " if self.type is not None else ""
            )
            return (
                f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = "
                f"{self.op}[target={self._pretty_print_target(self.target)}]("
                f"args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})"
            )

    @compatibility(is_backward_compatible=True)
    def replace_all_uses_with(self,
                              replace_with: 'Node',
                              delete_user_cb: Callable[['Node'], bool] = lambda user: True,
                              *,
                              propagate_meta: bool = False
                              ) -> List['Node']:
    def replace_all_uses_with(
        self,
        replace_with: "Node",
        delete_user_cb: Callable[["Node"], bool] = lambda user: True,
        *,
        propagate_meta: bool = False,
    ) -> List["Node"]:
        """
        Replace all uses of ``self`` in the Graph with the Node ``replace_with``.

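`format_node()` produces the per-node lines you see when printing a whole graph; a quick illustrative use:

```python
# Sketch: inspecting individual node printouts.
import torch
import torch.fx

gm = torch.fx.symbolic_trace(lambda x: torch.relu(x))
for node in gm.graph.nodes:
    print(node.format_node())
# e.g. "%x : [num_users=1] = placeholder[target=x]"
#      "%relu : [num_users=1] = call_function[target=torch.relu](args = (%x,), kwargs = {})"
```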
@@ -625,9 +710,10 @@ class Node(_NodeBase):
            The list of Nodes on which this change was made.
        """
        if propagate_meta:
            assert len(replace_with.meta) == 0, \
                'Called node.replace_all_uses_with(replace_with, propagate_meta=True), ' \
                'but replace_with already has .meta keys'
            assert len(replace_with.meta) == 0, (
                "Called node.replace_all_uses_with(replace_with, propagate_meta=True), "
                "but replace_with already has .meta keys"
            )
            for k, v in self.meta.items():
                replace_with.meta[k] = v
        to_process = list(self.users)
@@ -638,7 +724,7 @@ class Node(_NodeBase):
                skipped.append(use_node)
                continue

            def maybe_replace_node(n : Node) -> Node:
            def maybe_replace_node(n: Node) -> Node:
                if n == self:
                    return replace_with
                else:
@@ -690,9 +776,12 @@ class Node(_NodeBase):

    @compatibility(is_backward_compatible=False)
    def normalized_arguments(
        self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None,
        kwarg_types : Optional[Dict[str, Any]] = None,
        normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
        self,
        root: torch.nn.Module,
        arg_types: Optional[Tuple[Any]] = None,
        kwarg_types: Optional[Dict[str, Any]] = None,
        normalize_to_only_use_kwargs: bool = False,
    ) -> Optional[ArgsKwargsPair]:
        """
        Returns normalized arguments to Python targets. This means that
        `args/kwargs` will be matched up to the module/functional's
@@ -715,17 +804,23 @@ class Node(_NodeBase):

        Returns NamedTuple ArgsKwargsPair, or `None` if not successful.
        """
        if self.op == 'call_function':
        if self.op == "call_function":
            assert callable(self.target)
            return normalize_function(self.target, self.args, self.kwargs, arg_types, kwarg_types)  # type: ignore[arg-type]
        elif self.op == 'call_module':
            return normalize_function(
                self.target,
                self.args,  # type: ignore[arg-type]
                self.kwargs,
                arg_types,
                kwarg_types,
            )
        elif self.op == "call_module":
            assert isinstance(self.target, str)
            return normalize_module(root, self.target, self.args, self.kwargs)  # type: ignore[arg-type]

        return None

    @compatibility(is_backward_compatible=True)
    def replace_input_with(self, old_input: 'Node', new_input: 'Node') -> None:
    def replace_input_with(self, old_input: "Node", new_input: "Node") -> None:
        """
        Loop through input nodes of ``self``, and replace all instances of
        ``old_input`` with ``new_input``.
@@ -735,7 +830,8 @@ class Node(_NodeBase):
            old_input (Node): The old input node to be replaced.
            new_input (Node): The new input node to replace ``old_input``.
        """
        def maybe_replace_node(n : Node) -> Node:

        def maybe_replace_node(n: Node) -> Node:
            return new_input if n == old_input else n

        m = self.graph.owning_module
@@ -756,7 +852,7 @@ class Node(_NodeBase):
        self.graph._graph_namespace._rename_object(self, name)

    def __setattr__(self, name: str, value: Any) -> None:
        if name == 'name' and hasattr(self, "name"):
        if name == "name" and hasattr(self, "name"):
            m = self.graph.owning_module
            if getattr(m, "_replace_hook", None):
                assert isinstance(value, str)
@@ -764,9 +860,9 @@ class Node(_NodeBase):
                    m._replace_hook(old=self, new=value, user=user)
        update = False
        if (
            hasattr(self, name) and
            hasattr(self.graph, "_find_nodes_lookup_table") and
            self in self.graph._find_nodes_lookup_table
            hasattr(self, name)
            and hasattr(self.graph, "_find_nodes_lookup_table")
            and self in self.graph._find_nodes_lookup_table
        ):
            update = True
            self.graph._find_nodes_lookup_table.remove(self)
@@ -774,6 +870,7 @@ class Node(_NodeBase):
        if update:
            self.graph._find_nodes_lookup_table.insert(self)


@compatibility(is_backward_compatible=True)
def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
    """
@@ -782,6 +879,7 @@ def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
    assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable"
    return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)


@compatibility(is_backward_compatible=True)
def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
    """
@@ -790,7 +888,7 @@ def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
    if isinstance(a, tuple):
        t = tuple([map_aggregate(elem, fn) for elem in a])
        # Support NamedTuple (if it has `_fields`) by repacking into original type.
        return t if not hasattr(a, '_fields') else type(a)(*t)  # type: ignore[arg-type]
        return t if not hasattr(a, "_fields") else type(a)(*t)  # type: ignore[arg-type]
    elif isinstance(a, list):
        return immutable_list([map_aggregate(elem, fn) for elem in a])
    elif isinstance(a, dict):
@@ -799,6 +897,10 @@ def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
        dict.__setitem__(rv, k, map_aggregate(v, fn))
        return rv
    elif isinstance(a, slice):
        return slice(map_aggregate(a.start, fn), map_aggregate(a.stop, fn), map_aggregate(a.step, fn))
        return slice(
            map_aggregate(a.start, fn),
            map_aggregate(a.stop, fn),
            map_aggregate(a.step, fn),
        )
    else:
        return fn(a)
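`map_aggregate` recurses through tuples, lists, dicts, and slices, while `map_arg` only rewrites the `Node` leaves. An illustrative sketch:

```python
# Sketch: map_aggregate visits every leaf of a nested argument structure.
from torch.fx.node import map_aggregate

nested = (1, [2, 3], {"k": 4})
doubled = map_aggregate(nested, lambda a: a * 2 if isinstance(a, int) else a)
print(doubled)  # (2, [4, 6], {'k': 8}) -- lists/dicts come back as immutable variants
```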
torch/fx/operator_schemas.py
@@ -1,63 +1,100 @@
# mypy: allow-untyped-defs
import torch
import enum
import inspect
import numbers
import types
import typing
import enum
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    List,
    NamedTuple,
    Optional,
    Tuple,
    TYPE_CHECKING,
)

import torch
from torch._jit_internal import boolean_dispatched
from torch._ops import OpOverload, OpOverloadPacket

from ._compatibility import compatibility
from torch._ops import OpOverloadPacket, OpOverload


if TYPE_CHECKING:
    from .node import Argument

__all__ = ["ArgsKwargsPair", "check_for_mutable_operation", "get_signature_for_torch_op", "create_type_hint",
           "type_matches", "normalize_function", "normalize_module"]
__all__ = [
    "ArgsKwargsPair",
    "check_for_mutable_operation",
    "get_signature_for_torch_op",
    "create_type_hint",
    "type_matches",
    "normalize_function",
    "normalize_module",
]


@compatibility(is_backward_compatible=False)
class ArgsKwargsPair(NamedTuple):
    """
    Simple named tuple for wrapping args/kwargs pairs.
    """

    args: Tuple[Any, ...]
    kwargs: Dict[str, Any]

_manual_overrides : Dict[Callable, List[inspect.Signature]] = {}

_manual_overrides: Dict[Callable, List[inspect.Signature]] = {}


def _nonzero_schemas():
    signatures = []

    def nonzero(self):
        pass

    signatures.append(inspect.signature(nonzero))

    def nonzero(self, *, as_tuple : bool):  # type: ignore[no-redef]
    def nonzero(self, *, as_tuple: bool):  # type: ignore[no-redef]
        pass

    signatures.append(inspect.signature(nonzero))

    return signatures


_manual_overrides[torch.nonzero] = _nonzero_schemas()


class _FakeGlobalNamespace:
    def __getattr__(self, name):
        if name == 'torch':
        if name == "torch":
            return torch
        raise RuntimeError('Expected a torch namespace lookup')
        raise RuntimeError("Expected a torch namespace lookup")

_type_eval_globals = {'Tensor' : torch.Tensor, 'Device' : torch.device, 'Layout' : torch.layout,
                      'number' : numbers.Number, 'Future' : torch.jit.Future,
                      'AnyEnumType' : enum.Enum, 'QScheme' : torch.qscheme,
                      '__torch__': _FakeGlobalNamespace(), 'NoneType': type(None),
                      'Storage': torch.UntypedStorage,
                      't': typing.TypeVar('t')}

_type_eval_globals = {
    "Tensor": torch.Tensor,
    "Device": torch.device,
    "Layout": torch.layout,
    "number": numbers.Number,
    "Future": torch.jit.Future,
    "AnyEnumType": enum.Enum,
    "QScheme": torch.qscheme,
    "__torch__": _FakeGlobalNamespace(),
    "NoneType": type(None),
    "Storage": torch.UntypedStorage,
    "t": typing.TypeVar("t"),
}
for k in dir(typing):
    _type_eval_globals[k] = getattr(typing, k)

def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any:

def _torchscript_type_to_python_type(ts_type: "torch._C.JitType") -> Any:
    """
    Convert a TorchScript type to a Python type (including subtypes) via
    eval'ing the annotation_str. _type_eval_globals sets up expressions
@@ -65,9 +102,13 @@ def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any:
    """
    return eval(ts_type.annotation_str, _type_eval_globals)

def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) -> inspect.Signature:

def _torchscript_schema_to_signature_impl(
    ts_schema: torch._C.FunctionSchema,
) -> inspect.Signature:
    from inspect import Parameter
    parameters : List[Parameter] = []

    parameters: List[Parameter] = []
    for arg in ts_schema.arguments:
        arg_type = _torchscript_type_to_python_type(arg.type)
        default = arg.default_value if arg.has_default_value() else Parameter.empty
@@ -76,8 +117,12 @@ def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) -
        # argument name. Downstream, if someone converts that positional argument to a keyword
        # argument, the name mismatch will break things, so here we're going to normalize the
        # name to "input"
        name = arg.name if arg.name != 'self' else 'input'
        kind = Parameter.KEYWORD_ONLY if arg.kwarg_only else Parameter.POSITIONAL_OR_KEYWORD
        name = arg.name if arg.name != "self" else "input"
        kind = (
            Parameter.KEYWORD_ONLY
            if arg.kwarg_only
            else Parameter.POSITIONAL_OR_KEYWORD
        )
        # "from" is a keyword therefore it must be a POSITIONAL_ONLY argument
        if name == "from":
            assert kind == Parameter.POSITIONAL_OR_KEYWORD
@@ -87,9 +132,18 @@ def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) -
            # This renders all previous arguments to positional only
            for idx, p in enumerate(parameters):
                assert p.kind == Parameter.POSITIONAL_OR_KEYWORD
                parameters[idx] = Parameter(name=p.name, kind=Parameter.POSITIONAL_ONLY, default=p.default, annotation=p.annotation)
        parameters.append(Parameter(name=name, kind=kind, default=default, annotation=arg_type))
    return_types = [_torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns]
                parameters[idx] = Parameter(
                    name=p.name,
                    kind=Parameter.POSITIONAL_ONLY,
                    default=p.default,
                    annotation=p.annotation,
                )
        parameters.append(
            Parameter(name=name, kind=kind, default=default, annotation=arg_type)
        )
    return_types = [
        _torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns
    ]
    if len(return_types) == 0:
        return_type = None
    elif len(return_types) == 1:
@@ -99,9 +153,13 @@ def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) -

    return inspect.Signature(parameters, return_annotation=return_type)

_SCHEMA_TO_SIGNATURE_CACHE : Dict[Tuple[str, str], inspect.Signature] = {}

def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Signature:
_SCHEMA_TO_SIGNATURE_CACHE: Dict[Tuple[str, str], inspect.Signature] = {}


def _torchscript_schema_to_signature(
    ts_schema: torch._C.FunctionSchema,
) -> inspect.Signature:
    # Cached as it's called in the hot path of FakeTensor dispatch
    cache_key = ts_schema.name, ts_schema.overload_name
    cache_val = _SCHEMA_TO_SIGNATURE_CACHE.get(cache_key)
@@ -112,8 +170,11 @@ def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> ins
    _SCHEMA_TO_SIGNATURE_CACHE[cache_key] = res
    return res


@compatibility(is_backward_compatible=False)
def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...], kwargs : Dict[str, 'Argument']):
def check_for_mutable_operation(
    target: Callable, args: Tuple["Argument", ...], kwargs: Dict[str, "Argument"]
):
    signatures, schemas = get_signature_for_torch_op(target, return_schemas=True)

    if signatures and schemas:
@@ -131,9 +192,11 @@ def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...]

    def throw_if_mutable(schema):
        if schema.is_mutable:
            raise RuntimeError(f'Tried to trace mutable operation {schema}. FX only supports functional '
                               f'code, so operations that mutate operands in-place (e.g. via `out` arguments) '
                               f'are not supported')
            raise RuntimeError(
                f"Tried to trace mutable operation {schema}. FX only supports functional "
                f"code, so operations that mutate operands in-place (e.g. via `out` arguments) "
                f"are not supported"
            )

    if len(matched_schemas) == 0:
        # Did not match any schema. Cannot check for mutation
@@ -147,8 +210,9 @@ def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...]
        # do nothing.
        pass

|
||||
def get_signature_for_torch_op(op : Callable, return_schemas : bool = False):
|
||||
def get_signature_for_torch_op(op: Callable, return_schemas: bool = False):
|
||||
"""
|
||||
Given an operator on the `torch` namespace, return a list of `inspect.Signature`
|
||||
objects corresponding to the overloads of that op.. May return `None` if a signature
|
||||
@ -181,6 +245,7 @@ def get_signature_for_torch_op(op : Callable, return_schemas : bool = False):
|
||||
signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
|
||||
return (signatures, schemas) if return_schemas else signatures
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def create_type_hint(x):
|
||||
"""
|
||||
@ -198,11 +263,15 @@ def create_type_hint(x):
|
||||
if isinstance(x, (list, tuple)):
|
||||
# todo(chilli): Figure out the right way for mypy to handle this
|
||||
if isinstance(x, list):
|
||||
|
||||
def ret_type(x):
|
||||
return List[x] # type: ignore[valid-type]
|
||||
|
||||
else:
|
||||
|
||||
def ret_type(x):
|
||||
return Tuple[x, ...]
|
||||
|
||||
if len(x) == 0:
|
||||
return ret_type(Any)
|
||||
base_type = x[0]
|
||||
@ -216,12 +285,15 @@ def create_type_hint(x):
|
||||
return ret_type(base_type)
|
||||
except Exception:
|
||||
# We tried to create a type hint for list but failed.
|
||||
warnings.warn(f"We were not able to successfully create type hint from the type {x}")
|
||||
warnings.warn(
|
||||
f"We were not able to successfully create type hint from the type {x}"
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
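`create_type_hint` widens a container of types to a single typing hint. A hedged sketch of its observable behavior, assuming the common-base widening logic shown above:

```python
# Sketch: container of types -> typing hint.
from torch.fx.operator_schemas import create_type_hint

print(create_type_hint([int, int]))   # typing.List[int]
print(create_type_hint((int, bool)))  # typing.Tuple[int, ...] (bool subclasses int)
print(create_type_hint([]))           # typing.List[typing.Any] (empty widens to Any)
```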
@compatibility(is_backward_compatible=False)
def type_matches(signature_type : Any, argument_type : Any):
    sig_origin_type = getattr(signature_type, '__origin__', signature_type)
def type_matches(signature_type: Any, argument_type: Any):
    sig_origin_type = getattr(signature_type, "__origin__", signature_type)

    if signature_type is argument_type:
        return True
@@ -236,13 +308,14 @@ def type_matches(signature_type : Any, argument_type : Any):
        # int can be promoted to List[int]
        return True

    if getattr(signature_type, '__origin__', None) in {list, List}:
    if getattr(signature_type, "__origin__", None) in {list, List}:
        sig_el_type = signature_type.__args__[0]
        if not inspect.isclass(sig_el_type):
            warnings.warn(
                f"Does not support nested parametric types, got {signature_type}. Please file a bug.")
                f"Does not support nested parametric types, got {signature_type}. Please file a bug."
            )
            return False
        if getattr(argument_type, '__origin__', None) in {list, List}:
        if getattr(argument_type, "__origin__", None) in {list, List}:
            return issubclass(argument_type.__args__[0], sig_el_type)

    def is_homogeneous_tuple(t):
@@ -267,11 +340,16 @@ def type_matches(signature_type : Any, argument_type : Any):

    return False

@compatibility(is_backward_compatible=False)
def normalize_function(
    target: Callable, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, arg_types : Optional[Tuple[Any]] = None,
    kwarg_types : Optional[Dict[str, Any]] = None,
    normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
    target: Callable,
    args: Tuple[Any],
    kwargs: Optional[Dict[str, Any]] = None,
    arg_types: Optional[Tuple[Any]] = None,
    kwarg_types: Optional[Dict[str, Any]] = None,
    normalize_to_only_use_kwargs: bool = False,
) -> Optional[ArgsKwargsPair]:
    """
    Returns normalized arguments to PyTorch functions. This means that
    `args/kwargs` will be matched up to the functional's
@@ -308,14 +386,19 @@ def normalize_function(
        # branch signature for analysis. Otherwise, leave this un-normalized
        assert not isinstance(target, str)
        dispatched = boolean_dispatched[target]
        if_true, if_false = dispatched['if_true'], dispatched['if_false']
        if inspect.signature(if_true).parameters != inspect.signature(if_false).parameters:
        if_true, if_false = dispatched["if_true"], dispatched["if_false"]
        if (
            inspect.signature(if_true).parameters
            != inspect.signature(if_false).parameters
        ):
            return None
        target_for_analysis = if_true

        assert callable(target_for_analysis)
        sig = inspect.signature(inspect.unwrap(target_for_analysis))
        new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, normalize_to_only_use_kwargs)
        new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(
            sig, args, kwargs, normalize_to_only_use_kwargs
        )
    else:
        assert callable(target)
        torch_op_schemas = get_signature_for_torch_op(target)
@@ -336,8 +419,9 @@ def normalize_function(
            pass
        elif len(matched_schemas) == 1:
            # Matched exactly one schema, unambiguous
            new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(matched_schemas[0], args, kwargs,
                                                                         normalize_to_only_use_kwargs)
            new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(
                matched_schemas[0], args, kwargs, normalize_to_only_use_kwargs
            )
        else:
            if arg_types is not None or kwarg_types is not None:
                arg_types = arg_types if arg_types else cast(Tuple[Any], ())
@@ -345,30 +429,49 @@ def normalize_function(
                for candidate_signature in torch_op_schemas:
                    sig_matches = True
                    try:
                        bound_types = candidate_signature.bind(*arg_types, **kwarg_types)
                        bound_types = candidate_signature.bind(
                            *arg_types, **kwarg_types
                        )
                        for arg_name, arg_type in bound_types.arguments.items():
                            param = candidate_signature.parameters[arg_name]
                            sig_matches = sig_matches and type_matches(param.annotation, arg_type)
                            sig_matches = sig_matches and type_matches(
                                param.annotation, arg_type
                            )
                    except TypeError:
                        sig_matches = False
                    if sig_matches:
                        new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(candidate_signature, args, kwargs,
                                                                                     normalize_to_only_use_kwargs)
                        new_args_and_kwargs = (
                            _args_kwargs_to_normalized_args_kwargs(
                                candidate_signature,
                                args,
                                kwargs,
                                normalize_to_only_use_kwargs,
                            )
                        )
                        break
            else:
                # Matched more than one schema. In this situation, the caller must provide the types of
                # the arguments of the overload they expect.
                schema_printouts = '\n'.join(str(schema) for schema in matched_schemas)
                raise RuntimeError(f'Tried to normalize arguments to {torch.typename(target)} but '
                                   f'the schema match was ambiguous! Please provide argument types to '
                                   f'the normalize_arguments() call. Available schemas:\n{schema_printouts}')
                schema_printouts = "\n".join(
                    str(schema) for schema in matched_schemas
                )
                raise RuntimeError(
                    f"Tried to normalize arguments to {torch.typename(target)} but "
                    f"the schema match was ambiguous! Please provide argument types to "
                    f"the normalize_arguments() call. Available schemas:\n{schema_printouts}"
                )

    return new_args_and_kwargs


@compatibility(is_backward_compatible=False)
def normalize_module(
    root: torch.nn.Module, target: str, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None,
    normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
    root: torch.nn.Module,
    target: str,
    args: Tuple[Any],
    kwargs: Optional[Dict[str, Any]] = None,
    normalize_to_only_use_kwargs: bool = False,
) -> Optional[ArgsKwargsPair]:
    """
    Returns normalized arguments to PyTorch modules. This means that
    `args/kwargs` will be matched up to the functional's
@@ -391,22 +494,29 @@ def normalize_module(
    try:
        submod = root.get_submodule(target)
    except AttributeError as e:
        raise RuntimeError(f"Tried to normalize node with target {target} but root did not "
                           f"have that target!") from e
    if hasattr(submod.__class__, '__name__'):
        raise RuntimeError(
            f"Tried to normalize node with target {target} but root did not "
            f"have that target!"
        ) from e
    if hasattr(submod.__class__, "__name__"):
        classname = submod.__class__.__name__
        if getattr(torch.nn, classname, None) == submod.__class__:
            sig = inspect.signature(inspect.unwrap(submod.forward))
            if kwargs is None:
                kwargs = {}
            new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs,
                                                                         normalize_to_only_use_kwargs)
            new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(
                sig, args, kwargs, normalize_to_only_use_kwargs
            )
            return new_args_and_kwargs
    return None

def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple[Any, ...],
                                           kwargs : Dict[str, Any],
                                           normalize_to_only_use_kwargs : bool) -> Optional[ArgsKwargsPair]:

def _args_kwargs_to_normalized_args_kwargs(
    sig: inspect.Signature,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    normalize_to_only_use_kwargs: bool,
) -> Optional[ArgsKwargsPair]:
    """
    Given a call target, args, and kwargs, return the arguments normalized into
    an ArgsKwargsPair, or None if the type signature is not supported by
@@ -428,20 +538,22 @@ def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple
    # Don't currently support positional-only
    # or varargs (*args, **kwargs) signatures
    supported_parameter_types = {
        inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY}
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    }
    if any(p.kind not in supported_parameter_types for p in sig.parameters.values()):
        # Add an exception for one signature, which is common for random/uniform, i.e.:
        # Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None
        # `from` is Python keyword and as such functions with that signature should have
        # positional-only args, but at the same time they could be dispatched as kwargs
        if list(sig.parameters.keys()) != ['input', 'from', 'to', 'generator']:
        if list(sig.parameters.keys()) != ["input", "from", "to", "generator"]:
            return None

    bound_args = sig.bind(*args, **kwargs)
    bound_args.apply_defaults()

    new_kwargs : Dict[str, Any] = {}
    new_args : List[Any] = []
    new_kwargs: Dict[str, Any] = {}
    new_args: List[Any] = []
    for i, param in enumerate(sig.parameters):
        if not normalize_to_only_use_kwargs and i < len(args):
            new_args.append(bound_args.arguments[param])

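`normalize_function` is the public entry point for this schema-matching machinery. A small sketch with a plain Python functional, which avoids the ambiguous-overload path:

```python
# Sketch: matching positional args to parameter names via normalize_function.
import torch
import torch.nn.functional as F
from torch.fx.operator_schemas import normalize_function

pair = normalize_function(F.relu, (torch.randn(3),), normalize_to_only_use_kwargs=True)
if pair is not None:  # None means the signature wasn't supported
    print(list(pair.kwargs))  # ['input', 'inplace']
```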
torch/fx/passes/__init__.py
@@ -1,12 +1,14 @@
from . import graph_drawer
from . import graph_manipulation
from . import net_min_base
from . import operator_support
from . import param_fetch
from . import reinplace
from . import runtime_assert
from . import shape_prop
from . import split_module
from . import split_utils
from . import splitter_base
from . import tools_common
from . import (
    graph_drawer,
    graph_manipulation,
    net_min_base,
    operator_support,
    param_fetch,
    reinplace,
    runtime_assert,
    shape_prop,
    split_module,
    split_utils,
    splitter_base,
    tools_common,
)
torch/fx/passes/backends/cudagraphs.py
@@ -1,12 +1,13 @@
# mypy: allow-untyped-defs
import operator

import torch
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.utils import _pytree as pytree

import operator

class CudaGraphsSupport(OperatorSupport):
    # TODO: why is submodules passed here
@@ -27,7 +28,7 @@ class CudaGraphsSupport(OperatorSupport):

        def find_not_cuda(t):
            nonlocal found_not_cuda
            if isinstance(t, torch.Tensor) and t.device.type != 'cuda':
            if isinstance(t, torch.Tensor) and t.device.type != "cuda":
                found_not_cuda = True

        for n in node.all_input_nodes:
@@ -40,6 +41,7 @@ class CudaGraphsSupport(OperatorSupport):

        return not found_not_cuda


def partition_cudagraphs(gm, inputs):
    """
    Partition an FX graph into sub-GraphModules that can be validly run under
@@ -51,7 +53,9 @@ def partition_cudagraphs(gm, inputs):
    supported_ops = CudaGraphsSupport()
    # TODO: single node partition may be wrong due to the pessimization
    # from copying in and out the data. Check in benchmarks, perhaps
    partitioner = CapabilityBasedPartitioner(gm, supported_ops, allows_single_node_partition=True)
    partitioner = CapabilityBasedPartitioner(
        gm, supported_ops, allows_single_node_partition=True
    )
    partitions = partitioner.propose_partitions()
    fused_graph = partitioner.fuse_partitions(partitions)
    return fused_graph
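A hedged sketch of driving `partition_cudagraphs`. It assumes a CUDA build and device; the `.cpu()` hop is just an illustrative way to force a node the CUDA-graphs support check rejects:

```python
# Sketch: splitting a module into CUDA-graph-safe partitions (requires a GPU).
import torch
import torch.fx
from torch.fx.passes.backends.cudagraphs import partition_cudagraphs

class M(torch.nn.Module):
    def forward(self, x):
        y = x.relu()
        return y.cpu().relu()  # the .cpu() hop forces a partition boundary

if torch.cuda.is_available():
    gm = torch.fx.symbolic_trace(M())
    inputs = [torch.randn(4, device="cuda")]
    fused = partition_cudagraphs(gm, inputs)
    print(fused.graph)  # supported regions become fused submodules
```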
torch/fx/passes/dialect/common/cse_pass.py
@@ -1,20 +1,45 @@
# mypy: allow-untyped-defs
from typing import Dict, Tuple, Any
from typing import Any, Dict, Tuple

import torch
from torch.fx import Graph, GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.utils._pytree import tree_flatten

from torch.fx import GraphModule, Graph
from torch.fx import Node

aten = torch.ops.aten


# stateful ops are banned from CSE
rand_ops = {aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm}  # noqa: E501,B950
rand_ops = {
    aten.dropout,
    aten._fused_dropout,
    aten._standard_gamma,
    aten.bernoulli,
    aten.multinomial,
    aten.native_dropout,
    aten.normal,
    aten.poisson,
    aten.binomial,
    aten.rrelu,
    aten.rand_like,
    aten.rand,
    aten.randint,
    aten.randn,
    aten.randperm,
}  # noqa: E501,B950

inplace_ops = {aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_}  # noqa: E501
inplace_ops = {
    aten.add_,
    aten.sub_,
    aten.mul_,
    aten.div_,
    aten.pow_,
    aten.lerp_,
    aten.relu_,
    aten.sigmoid_,
    aten.tanh_,
}  # noqa: E501


@torch.fx._compatibility.compatibility(is_backward_compatible=False)
@@ -24,7 +49,6 @@ def get_CSE_banned_ops():

@torch.fx._compatibility.compatibility(is_backward_compatible=False)
class CSEPass(PassBase):

    def __init__(self, banned_ops=None):
        """
        This version of CSE Pass aims to be dialect agnostic, and it's implemented purely based on the connectivity between fx.Node.
@@ -58,20 +82,32 @@ class CSEPass(PassBase):
        result = p(traced_graph)
        print(result.graph_module)
        """

        def get_aten_target(node):
            if hasattr(node.target, 'overloadpacket'):
            if hasattr(node.target, "overloadpacket"):
                return node.target.overloadpacket
            return node.target

        modified = False
        new_graph = Graph()
        env: Dict[Node, Node] = {}  # map from node in the old graph to node in the new graph
        hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {}  # map from hash to a node in the new graph
        token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {}  # map from hash to token
        env: Dict[
            Node, Node
        ] = {}  # map from node in the old graph to node in the new graph
        hash_env: Dict[
            Tuple[torch._ops.OpOverload, int], Node
        ] = {}  # map from hash to a node in the new graph
        token_map: Dict[
            Tuple[torch._ops.OpOverload, int], Dict[str, Any]
        ] = {}  # map from hash to token
        for n in graph_module.graph.nodes:
            # The placeholder, output, and get_attr nodes are copied to the new graph without change
            # do not CSE away random operations
            if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in self.banned_ops:
            if (
                n.op == "placeholder"
                or n.op == "output"
                or n.op == "get_attr"
                or get_aten_target(n) in self.banned_ops
            ):
                new_node = new_graph.node_copy(n, lambda x: env[x])
                env[n] = new_node
            else:  # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method'
@@ -84,13 +120,19 @@ class CSEPass(PassBase):
                        if isinstance(v, Node) and v in env:
                            arg_list[i] = env[v]
                    return tuple(arg_list), spec

                args, args_spec = substitute(n.args)
                kwargs, kwargs_spec = substitute(n.kwargs)

                # each token corresponds to a unique node
                # nodes with the same token can be substituted
                token = {"target": n.target, "args": args, "args_spec": args_spec,
                         "kwargs": kwargs, "kwargs_spec": kwargs_spec}
                token = {
                    "target": n.target,
                    "args": args,
                    "args_spec": args_spec,
                    "kwargs": kwargs,
                    "kwargs_spec": kwargs_spec,
                }

                # hash substituted args to a number, do not hash specs because specs are not hashable
                hash_arg = hash((args, kwargs))

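The `CSEPass` docstring above already shows the basic pattern; a slightly fuller hedged version, assuming an ATen-level graph produced by `make_fx`:

```python
# Sketch: running CSEPass over an aten-dialect graph from make_fx.
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.dialect.common.cse_pass import CSEPass

def f(x):
    a = x.cos()
    b = x.cos()  # duplicate of `a`, eligible for CSE
    return a + b

traced = make_fx(f)(torch.randn(3))
result = CSEPass()(traced)              # returns a PassResult
print(result.graph_module.graph)        # x.cos() now appears once
```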
@ -2,13 +2,15 @@
from typing import Optional

import torch.fx
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx import Node
from torch.fx.node import map_aggregate
from torch.fx._compatibility import compatibility
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
from torch.fx.experimental.proxy_tensor import snapshot_fake, py_sym_types
from torch.fx.experimental.proxy_tensor import py_sym_types, snapshot_fake
from torch.fx.node import map_aggregate


__all__ = ["FakeTensorProp"]

__all__ = ['FakeTensorProp']

@compatibility(is_backward_compatible=False)
class FakeTensorProp(torch.fx.Interpreter):
@ -24,7 +26,10 @@ class FakeTensorProp(torch.fx.Interpreter):
module (GraphModule): The module to be executed
mode (Optional[FakeTensorMode]): The dispatch mode used to execute computation indicated by each FX Node.
"""
def __init__(self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None):

def __init__(
self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None
):
super().__init__(module)
if mode is None:
mode = FakeTensorMode()
@ -33,7 +38,10 @@ class FakeTensorProp(torch.fx.Interpreter):
mode.reset_nt_tensor_id_counter()

def run_node(self, n: Node):
from torch.fx.experimental.symbolic_shapes import rebind_unbacked, compute_unbacked_bindings
from torch.fx.experimental.symbolic_shapes import (
compute_unbacked_bindings,
rebind_unbacked,
)

result = super().run_node(n)
rebind_unbacked(self._mode.shape_env, n, result)
@ -52,8 +60,10 @@ class FakeTensorProp(torch.fx.Interpreter):

meta = map_aggregate(result, extract_val)
if meta is not None:
n.meta['val'] = meta
if (shape_env := self._mode.shape_env) and (symbol_to_path := compute_unbacked_bindings(shape_env, result)):
n.meta["val"] = meta
if (shape_env := self._mode.shape_env) and (
symbol_to_path := compute_unbacked_bindings(shape_env, result)
):
n.meta["unbacked_bindings"] = symbol_to_path

return result

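`FakeTensorProp` is an `Interpreter` that runs the graph under a `FakeTensorMode` and records each node's result in `node.meta["val"]` (plus unbacked-symbol bindings when a shape environment is attached). A small usage sketch; the traced function is illustrative:

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

def f(x):
    return x.relu() + 1

gm = symbolic_trace(f)
FakeTensorProp(gm).propagate(torch.randn(4, 4))  # no real compute happens
for node in gm.graph.nodes:
    print(node.name, node.meta.get("val"))  # FakeTensor metadata per node
```
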
@ -58,6 +58,7 @@ _WEIGHT_TEMPLATE = {
}

if HAS_PYDOT:

@compatibility(is_backward_compatible=False)
class FxGraphDrawer:
"""
@ -87,7 +88,12 @@ if HAS_PYDOT:

self._dot_graphs = {
name: self._to_dot(
graph_module, name, ignore_getattr, ignore_parameters_and_buffers, skip_node_names_in_args, parse_stack_trace
graph_module,
name,
ignore_getattr,
ignore_parameters_and_buffers,
skip_node_names_in_args,
parse_stack_trace,
)
}

@ -127,8 +133,8 @@ if HAS_PYDOT:
>>> symbolic_traced = torch.fx.symbolic_trace(module)
>>> # setup output file
>>> import ubelt as ub
>>> dpath = ub.Path.appdir('torch/tests/FxGraphDrawer').ensuredir()
>>> fpath = dpath / 'linear.svg'
>>> dpath = ub.Path.appdir("torch/tests/FxGraphDrawer").ensuredir()
>>> fpath = dpath / "linear.svg"
>>> # draw the graph
>>> g = FxGraphDrawer(symbolic_traced, "linear")
>>> g.get_dot_graph().write_svg(fpath)
@ -148,7 +154,6 @@ if HAS_PYDOT:
return self._dot_graphs

def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]:

template = {
"shape": self.dot_graph_shape,
"fillcolor": "#CAFFE3",
@ -161,7 +166,9 @@ if HAS_PYDOT:
# Use a random color for each node; based on its name so it's stable.
target_name = node._pretty_print_target(node.target)
target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16)
template["fillcolor"] = _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)]
template["fillcolor"] = _HASH_COLOR_MAP[
target_hash % len(_HASH_COLOR_MAP)
]
return template

def _get_leaf_node(
@ -199,12 +206,11 @@ if HAS_PYDOT:
full_file_name: str,
truncate_to_last_n: int = 2,
):
splits = full_file_name.split('/')
splits = full_file_name.split("/")
if len(splits) >= truncate_to_last_n:
return '/'.join(splits[-truncate_to_last_n:])
return "/".join(splits[-truncate_to_last_n:])
return full_file_name


def _get_node_label(
self,
module: torch.fx.GraphModule,
@ -219,8 +225,7 @@ if HAS_PYDOT:
elif isinstance(arg, dict):
prefix, suffix = r"|kwargs={\l", r",\n}\l"
arg_strs_list = [
f"{k}: {_format_arg(v, max_list_len=8)}"
for k, v in arg.items()
f"{k}: {_format_arg(v, max_list_len=8)}" for k, v in arg.items()
]
else:  # Fall back to nothing in unexpected case.
return ""
@ -235,7 +240,6 @@ if HAS_PYDOT:
arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "")
return arg_strs.replace("{", r"\{").replace("}", r"\}")


label = "{" + f"name=%{node.name}|op_code={node.op}\n"

if node.op == "call_module":
@ -244,7 +248,10 @@ if HAS_PYDOT:
extra = ""
if hasattr(leaf_module, "__constants__"):
extra = r"\n".join(
[f"{c}: {getattr(leaf_module, c)}" for c in leaf_module.__constants__]  # type: ignore[union-attr]
[
f"{c}: {getattr(leaf_module, c)}"
for c in leaf_module.__constants__
]  # type: ignore[union-attr]
)
label += extra + r"\n"
else:
@ -252,7 +259,10 @@ if HAS_PYDOT:
if self.normalize_args:
try:
args, kwargs = normalize_function(  # type: ignore[misc]
node.target, node.args, node.kwargs, normalize_to_only_use_kwargs=True  # type: ignore[arg-type]
node.target,  # type: ignore[arg-type]
node.args,  # type: ignore[arg-type]
node.kwargs,
normalize_to_only_use_kwargs=True,
)
except Exception:
# Fallback to not normalizing if there's an exception.
@ -266,12 +276,12 @@ if HAS_PYDOT:
label += _get_str_for_args_kwargs(kwargs)
label += f"|num_users={len(node.users)}" + r"\n"

tensor_meta = node.meta.get('tensor_meta')
tensor_meta = node.meta.get("tensor_meta")
label += self._tensor_meta_to_label(tensor_meta)

# for original fx graph
# print buf=buf0, n_origin=6
buf_meta = node.meta.get('buf_meta', None)
buf_meta = node.meta.get("buf_meta", None)
if buf_meta is not None:
label += f"|buf={buf_meta.name}" + r"\n"
label += f"|n_origin={buf_meta.n_origin}" + r"\n"
@ -281,8 +291,10 @@ if HAS_PYDOT:
if parse_stack_trace and node.stack_trace is not None:
parsed_stack_trace = _parse_stack_trace(node.stack_trace)
fname = self._shorten_file_name(parsed_stack_trace.file)
label += f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" + r"\n"

label += (
f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}"
+ r"\n"
)

return label + "}"

@ -322,19 +334,43 @@ if HAS_PYDOT:
assert "qscheme" in tm.qparams
qscheme = tm.qparams["qscheme"]
if qscheme in {
torch.per_tensor_affine,
torch.per_tensor_symmetric,
torch.per_tensor_affine,
torch.per_tensor_symmetric,
}:
result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
result += "|" + "q_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n"
result += (
"|"
+ "q_zero_point"
+ "="
+ str(tm.qparams["zero_point"])
+ r"\n"
)
elif qscheme in {
torch.per_channel_affine,
torch.per_channel_symmetric,
torch.per_channel_affine_float_qparams,
torch.per_channel_affine,
torch.per_channel_symmetric,
torch.per_channel_affine_float_qparams,
}:
result += "|" + "q_per_channel_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
result += "|" + "q_per_channel_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n"
result += "|" + "q_per_channel_axis" + "=" + str(tm.qparams["axis"]) + r"\n"
result += (
"|"
+ "q_per_channel_scale"
+ "="
+ str(tm.qparams["scale"])
+ r"\n"
)
result += (
"|"
+ "q_per_channel_zero_point"
+ "="
+ str(tm.qparams["zero_point"])
+ r"\n"
)
result += (
"|"
+ "q_per_channel_axis"
+ "="
+ str(tm.qparams["axis"])
+ r"\n"
)
else:
raise RuntimeError(f"Unsupported qscheme: {qscheme}")
result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n"
@ -363,7 +399,6 @@ if HAS_PYDOT:
# "TB" means top-to-bottom rank direction in layout
dot_graph = pydot.Dot(name, rankdir="TB")


buf_name_to_subgraph = {}

for node in graph_module.graph.nodes:
@ -372,16 +407,22 @@ if HAS_PYDOT:

style = self._get_node_style(node)
dot_node = pydot.Node(
node.name, label=self._get_node_label(graph_module, node, skip_node_names_in_args, parse_stack_trace), **style
node.name,
label=self._get_node_label(
graph_module, node, skip_node_names_in_args, parse_stack_trace
),
**style,
)

current_graph = dot_graph

buf_meta = node.meta.get('buf_meta', None)
buf_meta = node.meta.get("buf_meta", None)
if buf_meta is not None and buf_meta.n_origin > 1:
buf_name = buf_meta.name
if buf_name not in buf_name_to_subgraph:
buf_name_to_subgraph[buf_name] = pydot.Cluster(buf_name, label=buf_name)
buf_name_to_subgraph[buf_name] = pydot.Cluster(
buf_name, label=buf_name
)
current_graph = buf_name_to_subgraph.get(buf_name)

current_graph.add_node(dot_node)
@ -407,12 +448,14 @@ if HAS_PYDOT:
if node.op == "call_module":
leaf_module = self._get_leaf_node(graph_module, node)

if not ignore_parameters_and_buffers and not isinstance(leaf_module, torch.fx.GraphModule):
if not ignore_parameters_and_buffers and not isinstance(
leaf_module, torch.fx.GraphModule
):
get_module_params_or_buffers()

for subgraph in buf_name_to_subgraph.values():
subgraph.set('color', 'royalblue')
subgraph.set('penwidth', '2')
subgraph.set("color", "royalblue")
subgraph.set("penwidth", "2")
dot_graph.add_subgraph(subgraph)

for node in graph_module.graph.nodes:
@ -426,6 +469,7 @@ if HAS_PYDOT:

else:
if not TYPE_CHECKING:

@compatibility(is_backward_compatible=False)
class FxGraphDrawer:
def __init__(
@ -439,5 +483,7 @@ else:
dot_graph_shape: Optional[str] = None,
normalize_args: bool = False,
):
raise RuntimeError('FXGraphDrawer requires the pydot package to be installed. Please install '
'pydot through your favorite Python package manager.')
raise RuntimeError(
"FXGraphDrawer requires the pydot package to be installed. Please install "
"pydot through your favorite Python package manager."
)

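End to end, the drawer is used as in its doctest above; a condensed sketch (requires pydot, output path illustrative):

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.graph_drawer import FxGraphDrawer  # needs pydot installed

gm = symbolic_trace(torch.nn.Linear(4, 4))
drawer = FxGraphDrawer(gm, "linear")
drawer.get_dot_graph().write_svg("linear.svg")  # rendering is handled by pydot
```
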
@ -5,15 +5,18 @@ import torch
from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.node import (
map_arg,
Node,
Target,
)
from torch.fx.node import map_arg, Node, Target
from torch.fx.passes.shape_prop import ShapeProp

__all__ = ['replace_target_nodes_with', 'size_bytes', 'get_size_of_all_nodes', 'get_tensor_meta',
'get_size_of_node']

__all__ = [
"replace_target_nodes_with",
"size_bytes",
"get_size_of_all_nodes",
"get_tensor_meta",
"get_size_of_node",
]


@compatibility(is_backward_compatible=False)
def replace_target_nodes_with(

@ -1,2 +1 @@

from . import pass_manager

@ -1,22 +1,24 @@
# mypy: allow-untyped-defs
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
import collections
import itertools
import logging

from copy import copy
from typing import Dict, Iterable, List, Optional, Sequence, Set

from torch.fx.graph_module import GraphModule
from torch.fx.node import Node, _get_qualified_name
from torch.fx.node import _get_qualified_name, Node
from torch.fx.passes.operator_support import OperatorSupportBase
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions


logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


class Partition:
def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None):
def __init__(
self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None
):
self.id = id
self.nodes = dict.fromkeys(nodes) if nodes is not None else {}

@ -32,6 +34,7 @@ class Partition:
def size(self):
return len(self.nodes)


class _DependencyViewer:
def __init__(self, graph_module: GraphModule):
self.upstreams = collections.defaultdict(set)
@ -55,15 +58,16 @@ class _DependencyViewer:
def upstreams_of(self, node: Node) -> Set[Node]:
return self.upstreams[node]

class CapabilityBasedPartitioner:

def __init__(self,
graph_module: GraphModule,
operator_support: OperatorSupportBase,
allows_single_node_partition: bool = False,
non_compute_ops: Optional[Sequence[str]] = None,
allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
) -> None:
class CapabilityBasedPartitioner:
def __init__(
self,
graph_module: GraphModule,
operator_support: OperatorSupportBase,
allows_single_node_partition: bool = False,
non_compute_ops: Optional[Sequence[str]] = None,
allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
) -> None:
self.graph_module = graph_module
self.operator_support = operator_support
self.allows_single_node_partition = allows_single_node_partition
@ -76,19 +80,21 @@ class CapabilityBasedPartitioner:
self.dependency_viewer = _DependencyViewer(graph_module)

def __is_node_supported(self, node: Node) -> bool:
return (
self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node)
return self.operator_support.is_node_supported(
dict(self.graph_module.named_modules()), node
)

def propose_partitions(self) -> List[Partition]:
# partition_map is a mapping from partition id to a set of partition id's.
# The value set contains all the partition ids that can be reached by doing a
# DFS starting from the partition id in the key.
partition_map : Dict[int, Set] = collections.defaultdict(set)
partition_map: Dict[int, Set] = collections.defaultdict(set)

# assumptions: nodes in candidate list is sorted in topological order
assignment: Dict[Node, int] = {}  # mapping from node to partition_id
partitions_by_id: Dict[int, Partition] = {}  # mapping from partition_id to partition
assignment: Dict[Node, int] = {}  # mapping from node to partition_id
partitions_by_id: Dict[
int, Partition
] = {}  # mapping from partition_id to partition
new_partition_id = itertools.count()

# try to merge partition other_id into partition self_id
@ -149,7 +155,9 @@ class CapabilityBasedPartitioner:
# delete other partition
del partitions_by_id[other_id]

partition_map[self_id] = partition_map[self_id].union(partition_map[other_id])
partition_map[self_id] = partition_map[self_id].union(
partition_map[other_id]
)
del partition_map[other_id]

return True
@ -223,16 +231,18 @@ class CapabilityBasedPartitioner:
for node in self.graph_module.graph.nodes:
is_tuple_output = True
for user in node.users:
if user.op != "call_function" or \
_get_qualified_name(user.target) != "_operator.getitem":  # type: ignore[arg-type]
if (
user.op != "call_function"
or _get_qualified_name(user.target) != "_operator.getitem"
):  # type: ignore[arg-type]
is_tuple_output = False
break

# node has tuple outputs, re-assign all following getitem node into node's partition
if is_tuple_output:
id = assignment.get(node, None) # type: ignore[arg-type]
id = assignment.get(node, None)  # type: ignore[arg-type]
for user in node.users:
if assignment.get(user, None) != id: # type: ignore[arg-type]
if assignment.get(user, None) != id:  # type: ignore[arg-type]
nodes_reassignment[user] = id  # type: ignore[assignment]
for node, id in nodes_reassignment.items():
merge_single_node(node, id)
@ -250,7 +260,10 @@ class CapabilityBasedPartitioner:
assert callable(node.target)
if _get_qualified_name(node.target) not in non_compute_ops:
compute_node_count += 1
if _get_qualified_name(node.target) in self.allowed_single_node_partition_ops:
if (
_get_qualified_name(node.target)
in self.allowed_single_node_partition_ops
):
compute_node_count += 1
if compute_node_count <= 1:
partitions_to_remove.append(id)
@ -259,11 +272,17 @@ class CapabilityBasedPartitioner:

logger.debug("Partitions proposed:")
for id, partition in partitions_by_id.items():
logger.debug("partition #%s: %s", id, [node.name for node in partition.nodes])
logger.debug(
"partition #%s: %s", id, [node.name for node in partition.nodes]
)

return [partition for partition in partitions_by_id.values() if partition.size() > 0]
return [
partition for partition in partitions_by_id.values() if partition.size() > 0
]

def fuse_partitions(self, partitions: List[Partition], prefix: str = "fused_") -> GraphModule:
def fuse_partitions(
self, partitions: List[Partition], prefix: str = "fused_"
) -> GraphModule:
logger.debug("Fusing partitions...")
# fuse_by_partitions expects partitions in List[Dict[Node, None]]: [ {node0 : None}, {node1 : None} ]
return fuse_by_partitions(
@ -277,15 +296,23 @@ class CapabilityBasedPartitioner:
non_compute_ops = set(self.non_compute_ops)

def is_non_compute_node(node: Node):
return node.op == "call_function" and \
_get_qualified_name(node.target) in non_compute_ops  # type: ignore[arg-type]
return (
node.op == "call_function"
and _get_qualified_name(node.target) in non_compute_ops  # type: ignore[arg-type]
)

# cache transparent nodes
transparent_input_nodes: Dict[Node, bool] = {}
transparent_output_nodes: Dict[Node, bool] = {}

def is_transparent_input_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]):
if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
def is_transparent_input_node(
node: Node, partition: Set[Node], removed_nodes: Set[Node]
):
if (
node.op == "placeholder"
or (node not in partition)
or (node in removed_nodes)
):
return True
if node in transparent_input_nodes:
return transparent_input_nodes[node]
@ -299,14 +326,22 @@ class CapabilityBasedPartitioner:
transparent_input_nodes[node] = False
return False

def is_transparent_output_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]):
if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
def is_transparent_output_node(
node: Node, partition: Set[Node], removed_nodes: Set[Node]
):
if (
node.op == "placeholder"
or (node not in partition)
or (node in removed_nodes)
):
return True
if node in transparent_output_nodes:
return transparent_output_nodes[node]
if is_non_compute_node(node):
for output_n in node.users:
if not is_transparent_output_node(output_n, partition, removed_nodes):
if not is_transparent_output_node(
output_n, partition, removed_nodes
):
transparent_output_nodes[node] = False
return False
transparent_output_nodes[node] = True
@ -320,9 +355,12 @@ class CapabilityBasedPartitioner:
# the set.
remove_node: Set[Node] = set()
for node in partition.nodes:
if is_non_compute_node(node) and \
(is_transparent_input_node(node, set(partition.nodes), remove_node) or
is_transparent_output_node(node, set(partition.nodes), remove_node)):
if is_non_compute_node(node) and (
is_transparent_input_node(node, set(partition.nodes), remove_node)
or is_transparent_output_node(
node, set(partition.nodes), remove_node
)
):
remove_node.add(node)

if len(remove_node) != 0:

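A hedged sketch of driving the partitioner: `OperatorSupport` is keyed by qualified target names (the `"_operator.mul"` entry below is an assumption about how `x * x` traces), and unsupported nodes simply stay outside every partition:

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport

def f(x):
    return (x * x).relu() * 2

gm = symbolic_trace(f)
support = OperatorSupport({"_operator.mul": None})  # hypothetical support table
partitioner = CapabilityBasedPartitioner(
    gm, support, allows_single_node_partition=True
)
partitions = partitioner.propose_partitions()
fused = partitioner.fuse_partitions(partitions)  # GraphModule with fused submodules
```
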
@ -3,11 +3,12 @@ import abc
from collections import namedtuple
from typing import Optional

from torch.fx.graph_module import GraphModule
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule


__all__ = ['PassResult', 'PassBase']
__all__ = ["PassResult", "PassBase"]


@compatibility(is_backward_compatible=False)
class PassResult(namedtuple("PassResult", ["graph_module", "modified"])):
@ -16,9 +17,11 @@ class PassResult(namedtuple("PassResult", ["graph_module", "modified"])):
graph_module: The modified graph module
modified: A flag for if the pass has modified the graph module
"""

def __new__(cls, graph_module, modified):
return super().__new__(cls, graph_module, modified)


@compatibility(is_backward_compatible=False)
class PassBase(abc.ABC):
"""

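A minimal concrete pass against this interface; the dead-code body is only an illustration of the `call` to `PassResult` contract:

```python
from torch.fx.passes.infra.pass_base import PassBase, PassResult

class DeadCodeEliminationPass(PassBase):
    """Illustrative pass: wraps Graph's built-in dead-code elimination."""

    def call(self, graph_module):
        modified = graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, modified)
```
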
@ -1,19 +1,21 @@
# mypy: allow-untyped-defs
import inspect
import logging
from queue import Queue
from functools import wraps
from queue import Queue
from typing import Callable, Dict, List

import torch.nn as nn
from torch.fx.graph_module import GraphModule
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
from torch.fx.passes.infra.pass_base import PassResult


logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

__all__ = ['pass_result_wrapper', 'this_before_that_pass_constraint', 'PassManager']
__all__ = ["pass_result_wrapper", "this_before_that_pass_constraint", "PassManager"]


@compatibility(is_backward_compatible=False)
def pass_result_wrapper(fn: Callable) -> Callable:
@ -46,6 +48,7 @@ def pass_result_wrapper(fn: Callable) -> Callable:

return wrapped_fn


def _validate_pass_schedule_constraint(
constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
) -> None:
@ -59,6 +62,7 @@ def _validate_pass_schedule_constraint(
f" list."
)


def _topological_sort_passes(
passes: List[Callable], constraints: List[Callable]
) -> List[Callable]:
@ -75,7 +79,7 @@ def _topological_sort_passes(
return passes

# Contruct a graph mapping nodes to a list of their users
graph: Dict[Callable, List[Callable]] = {p : [] for p in passes}
graph: Dict[Callable, List[Callable]] = {p: [] for p in passes}
indegree_map: Dict[Callable, int] = dict.fromkeys(passes, 0)
candidates: Queue = Queue()
for a in passes:
@ -108,11 +112,14 @@ def _topological_sort_passes(
# Check if there are unvisited nodes (aka cycles in the graph)
cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys()))
if len(cycle_passes) != 0:
error = f"Circular dependency detected within the following passes: {cycle_passes}"
error = (
f"Circular dependency detected within the following passes: {cycle_passes}"
)
raise RuntimeError(error)

return sorted_passes


@compatibility(is_backward_compatible=False)
def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable:
"""
@ -123,9 +130,7 @@ def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable
```
passes = [pass_b, pass_a]

constraints = [
this_before_that_pass_constraint(pass_a, pass_b)
]
constraints = [this_before_that_pass_constraint(pass_a, pass_b)]
```

Args:
@ -231,7 +236,9 @@ class PassManager:
sig = inspect.signature(check)

if len(list(sig.parameters.values())) != 1:
raise TypeError("PassManager check function should only take in one variable, a module")
raise TypeError(
"PassManager check function should only take in one variable, a module"
)

setattr(self, "check", check)  # noqa: B010


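Putting the scheduler together: constraints are validated and topologically solved before the passes run. A sketch with two no-op passes (names illustrative); `add_constraint` is assumed to be the registration hook used here:

```python
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import (
    PassManager,
    this_before_that_pass_constraint,
)

def pass_a(gm):
    return PassResult(gm, False)  # illustrative no-op pass

def pass_b(gm):
    return PassResult(gm, False)

pm = PassManager(passes=[pass_b, pass_a])
pm.add_constraint(this_before_that_pass_constraint(pass_a, pass_b))
# pm(graph_module) now reorders so pass_a runs before pass_b.
```
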
@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.fx

from torch.fx._compatibility import compatibility
from torch.fx.node import map_arg

@ -21,6 +20,7 @@ from .tools_common import (
Tensors,
)


__all__ = [
"FxNetMinimizerBadModuleError",
"FxNetMinimizerRunFuncError",
@ -37,7 +37,6 @@ class FxNetMinimizerBadModuleError(Exception):
"""



@compatibility(is_backward_compatible=False)
class FxNetMinimizerRunFuncError(Exception):
"""
@ -45,7 +44,6 @@ class FxNetMinimizerRunFuncError(Exception):
"""



@compatibility(is_backward_compatible=False)
class FxNetMinimizerResultMismatchError(Exception):
"""
@ -53,7 +51,6 @@ class FxNetMinimizerResultMismatchError(Exception):
"""



@dataclass
class _MinimizerSettingBase:
"""
@ -109,14 +106,9 @@ class _MinimizerBase:
],
settings: _MinimizerSettingBase,
module_exporter: Optional[
Callable[
[Tensors, torch.fx.GraphModule, str],
None
]
] = None,
exclusion_fn: Optional[
Callable[[NodeList, int, int], None]
Callable[[Tensors, torch.fx.GraphModule, str], None]
] = None,
exclusion_fn: Optional[Callable[[NodeList, int, int], None]] = None,
):
assert isinstance(module, torch.fx.GraphModule)

@ -159,14 +151,18 @@ class _MinimizerBase:
self.a_outputs[name] = sample_input[i]
self.b_outputs[name] = sample_input[i]

def run_a(self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1) -> TensorOrTensors:
def run_a(
self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1
) -> TensorOrTensors:
"""
Run `mod` with `inputs` and generate output. The output will be compared with
output of run_b().
"""
raise RuntimeError("run_a() is not implemented.")

def run_b(self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1) -> TensorOrTensors:
def run_b(
self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1
) -> TensorOrTensors:
"""
Run `mod` with `inputs` and generate output. The output will be compared with
output of run_a().
@ -323,7 +319,7 @@ class _MinimizerBase:
split_module: torch.fx.GraphModule,
submod_name: str,
output_names: Names,
report_idx: int = -1
report_idx: int = -1,
):
"""
Run the submodule in `split_module` that has name `submod_name`
@ -388,10 +384,14 @@ class _MinimizerBase:
report.append(f"Result mismatch for {result_key}")
if self.module_exporter:
self.module_exporter(
a_input, submodule, str(result_key[0]) + "_cpu",  # type: ignore[index]
a_input,
submodule,
str(result_key[0]) + "_cpu",  # type: ignore[index]
)
self.module_exporter(
b_input, submodule, str(result_key[0]) + "_acc",  # type: ignore[index]
b_input,
submodule,
str(result_key[0]) + "_acc",  # type: ignore[index]
)
raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}")

@ -418,7 +418,7 @@ class _MinimizerBase:
self.reports.append(report)
report.append(f"Binary search iteration {self.iteration}")
report.append(
f"From node index {start_idx}:{first_node_name} to {end_idx-1}:{output_node_name}. "
f"From node index {start_idx}:{first_node_name} to {end_idx - 1}:{output_node_name}. "
f"Size of the interested node list is {len(nodes)}"
)
cur_nodes: NodeSet = set(nodes)
@ -428,7 +428,6 @@ class _MinimizerBase:
self._run_and_compare(split_module, submod_name, [output_node_name])

except (FxNetMinimizerRunFuncError, FxNetMinimizerResultMismatchError):

if len(nodes) == 1:
report.append(
f"This is the last node in the sub-module. "
@ -504,13 +503,13 @@ class _MinimizerBase:
split_module, submod_name = self._build_submodule(cur_nodes)
self._run_and_compare(split_module, submod_name, [node.name])
self.print_report(report)
except (FxNetMinimizerResultMismatchError):
except FxNetMinimizerResultMismatchError:
culprits.add(node)
report.append(f"Found culprit from numeric error: {node}")
self.print_report(report)
if not self.settings.find_all:
return culprits
except (FxNetMinimizerRunFuncError):
except FxNetMinimizerRunFuncError:
culprits.update(cur_nodes)
report.append(f"Found culprit from run error: {node}")
self.print_report(report)
@ -519,8 +518,9 @@ class _MinimizerBase:

return culprits


def _block_traverse_impl(self, nodes: NodeList, start_idx: int, end_idx: int, find_last_node: bool) -> int:
def _block_traverse_impl(
self, nodes: NodeList, start_idx: int, end_idx: int, find_last_node: bool
) -> int:
"""
Recursive block search implementation.
find_last_node: If True, search for the last node which result in numerics difference
@ -529,7 +529,7 @@ class _MinimizerBase:
report: List[str] = []

mid = (start_idx + end_idx) // 2
cur_nodes_list: NodeList = nodes[:mid + 1] if find_last_node else nodes[mid:]
cur_nodes_list: NodeList = nodes[: mid + 1] if find_last_node else nodes[mid:]

if self.exclusion_fn:
self.exclusion_fn(cur_nodes_list, -1, -1)
@ -561,16 +561,20 @@ class _MinimizerBase:

try:
split_module, submod_name = self._build_submodule(cur_nodes)
self._run_and_compare(split_module, submod_name, [last_node_name], report_idx)
self._run_and_compare(
split_module, submod_name, [last_node_name], report_idx
)
except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
report.append(f"Culprits found from node {first_node_name} to {last_node_name}.")
report.append(
f"Culprits found from node {first_node_name} to {last_node_name}."
)

if start_idx == mid:
report.extend(
[
"This is the last node in the sub-module. ",
"Search in the current branch is successful with node :",
f"{start_idx}, node name: {nodes[start_idx].name}."
f"{start_idx}, node name: {nodes[start_idx].name}.",
]
)
self.print_report(report)
@ -585,9 +589,13 @@ class _MinimizerBase:
if find_last_node:
return self._block_traverse_impl(nodes, start_idx, mid, find_last_node)
else:
return self._block_traverse_impl(nodes, mid + 1, end_idx, find_last_node)
return self._block_traverse_impl(
nodes, mid + 1, end_idx, find_last_node
)
else:
report.append(f"Culprits not found from node start to {mid}:{nodes[mid].name}.")
report.append(
f"Culprits not found from node start to {mid}:{nodes[mid].name}."
)

if start_idx == mid:
report.extend(
@ -607,12 +615,15 @@ class _MinimizerBase:
self.print_report(report)

if find_last_node:
return self._block_traverse_impl(nodes, mid + 1, end_idx, find_last_node)
return self._block_traverse_impl(
nodes, mid + 1, end_idx, find_last_node
)
else:
return self._block_traverse_impl(nodes, start_idx, mid, find_last_node)


def _block_traverse(self, nodes: NodeList, find_last_node: Optional[bool]) -> NodeSet:
def _block_traverse(
self, nodes: NodeList, find_last_node: Optional[bool]
) -> NodeSet:
"""
Traverse topologically sorted node list
Find minimium block (start_idx, end_idx) which contains the culprit
@ -639,10 +650,7 @@ class _MinimizerBase:
self.print_report(last_node_report)
end_idx = self._block_traverse_impl(nodes, start_idx, end_idx, True)
last_node_report.extend(
[
"Finish Pass 1",
f"Find end_idx = {end_idx}:{nodes[end_idx].name}"
]
["Finish Pass 1", f"Find end_idx = {end_idx}:{nodes[end_idx].name}"]
)
self.print_report(last_node_report)

@ -650,25 +658,28 @@ class _MinimizerBase:
if run_both or not find_last_node:
first_node_report = ["Start searching for first node in culprit"]
self.print_report(first_node_report)
start_idx = self._block_traverse_impl(nodes[0:end_idx + 1], start_idx, end_idx, False)
start_idx = self._block_traverse_impl(
nodes[0 : end_idx + 1], start_idx, end_idx, False
)
first_node_report.append("*" * 50)
self.reports.append(first_node_report)
first_node_report.extend(
[
"Finish Pass 2",
f"Find start_idx = {start_idx}:{nodes[start_idx].name}"
f"Find start_idx = {start_idx}:{nodes[start_idx].name}",
]
)
self.print_report(first_node_report)

# step 3: form module with minimum culprits
culprits.update(nodes[start_idx:end_idx + 1])
result_report = [f"Finish searching, found minimum block ({nodes[start_idx]},{nodes[end_idx]})"]
culprits.update(nodes[start_idx : end_idx + 1])
result_report = [
f"Finish searching, found minimum block ({nodes[start_idx]},{nodes[end_idx]})"
]
self.reports.append(result_report)
self.print_report(result_report)
return culprits


def _defined_traverse(self, nodes: NodeList) -> NodeSet:
"""
run user defined `nodes` and determine if it is a culprit.
@ -735,7 +746,9 @@ class _MinimizerBase:

return culprits

def _skip_traverse_impl(self, all_nodes: NodeList, start_idx: int, end_idx: int) -> NodeSet:
def _skip_traverse_impl(
self, all_nodes: NodeList, start_idx: int, end_idx: int
) -> NodeSet:
"""
Skip certain nodes in graph based on settings
"""
@ -754,19 +767,19 @@ class _MinimizerBase:
self.iteration += 1
report.append(f" Nodes block {self.iteration}.")
report.append(
f"From node index {start_idx} to {end_idx-1}. "
f"From node index {start_idx} to {end_idx - 1}. "
f"Size of the interested node list is {len(nodes)}"
)

try:
split_module, submod_name = self._build_submodule(cur_nodes)
self._run_and_compare(split_module, submod_name, [])
except (FxNetMinimizerResultMismatchError):
except FxNetMinimizerResultMismatchError:
culprits.update(cur_nodes)
report.append(f"Found culprit from numeric error: {cur_nodes}")
self.print_report(report)
return culprits
except (FxNetMinimizerRunFuncError):
except FxNetMinimizerRunFuncError:
culprits.update(cur_nodes)
report.append(f"Found culprit from run error: {cur_nodes}")
self.print_report(report)
@ -776,7 +789,6 @@ class _MinimizerBase:
self.print_report(report)
return set()


def _skip_traverse(self, all_nodes: NodeList, skip_nodes: List) -> NodeSet:
"""
Skip certain nodes in graph based on settings
@ -787,7 +799,7 @@ class _MinimizerBase:
culprits = set()
while idx < num_nodes:
node = all_nodes[idx]
if (node.name in skip_nodes):  # skip the node
if node.name in skip_nodes:  # skip the node
if idx > start_idx:
culprits = self._skip_traverse_impl(all_nodes, start_idx, idx)
start_idx = idx + 1
@ -797,8 +809,6 @@ class _MinimizerBase:

return culprits



def _collect_nodes(self, start: Optional[str], end: Optional[str]) -> NodeList:
"""
Collect nodes in the model that between nodes with name of `start` and `end`.
@ -911,8 +921,10 @@ class _MinimizerBase:
return self._accumulate_traverse(nodes)

if self.settings.traverse_method == "skip":
if (skip_nodes is None):
raise RuntimeError("'skip_nodes' can't be None when 'traverse_method' is 'skip'.")
if skip_nodes is None:
raise RuntimeError(
"'skip_nodes' can't be None when 'traverse_method' is 'skip'."
)
return self._skip_traverse(nodes, skip_nodes)

if self.settings.traverse_method == "defined":

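`_MinimizerBase` never runs anything itself: a concrete minimizer supplies `run_a` and `run_b`, one per backend under comparison, and the traversal strategies above bisect or skip through the node list. A hedged sketch (`my_backend_compile` is a hypothetical lowering entry point):

```python
from torch.fx.passes.net_min_base import _MinimizerBase

class MyMinimizer(_MinimizerBase):
    def run_a(self, mod, inputs, report_idx=-1):
        return mod(*inputs)  # reference result: eager execution

    def run_b(self, mod, inputs, report_idx=-1):
        lowered = my_backend_compile(mod)  # hypothetical backend under test
        return lowered(*inputs)
```
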
@ -5,11 +5,19 @@ import typing as t
import torch
import torch.fx
from torch.fx._compatibility import compatibility

from .shape_prop import TensorMetadata
from .tools_common import get_node_target, CALLABLE_NODE_OPS
from .tools_common import CALLABLE_NODE_OPS, get_node_target


__all__ = ['OperatorSupportBase', 'OperatorSupport', 'create_op_support', 'chain', 'OpSupports', 'any_chain']
__all__ = [
"OperatorSupportBase",
"OperatorSupport",
"create_op_support",
"chain",
"OpSupports",
"any_chain",
]

# fx.Node.target typename, as returned by `get_node_target()`
TargetTypeName = str
@ -28,6 +36,7 @@ SupportDict = t.Mapping[TargetTypeName, SupportedArgumentDTypes]
@compatibility(is_backward_compatible=False)
class OperatorSupportBase(abc.ABC):
"""Interface for determining if a fx.Node is supported by a backend"""

@abc.abstractmethod
def is_node_supported(
self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
@ -57,10 +66,7 @@ class OperatorSupport(OperatorSupportBase):

_support_dict: SupportDict

def __init__(
self,
support_dict: t.Optional[SupportDict] = None
):
def __init__(self, support_dict: t.Optional[SupportDict] = None):
self._support_dict = support_dict or {}

def is_node_supported(
@ -139,11 +145,13 @@ def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase
`IsNodeSupported` has the same call signature as
`OperatorSupportBase.is_node_supported`
"""

class FunctionalOperatorSupport(OperatorSupportBase):
def is_node_supported(
self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
) -> bool:
return is_node_supported(submodules, node)

return FunctionalOperatorSupport()


@ -153,11 +161,10 @@ def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
instance by evaluating each input `OperatorSupportBase` instance, and returns False if
any of it reports False.
"""

def _chain(submods, node) -> bool:
return all(
x.is_node_supported(submods, node)
for x in op_support
)
return all(x.is_node_supported(submods, node) for x in op_support)

return create_op_support(_chain)


@ -167,11 +174,10 @@ def any_chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
instance by evaluating each input `OperatorSupportBase` instance, and returns True if
any of it reports True.
"""

def _any_chain(submods, node) -> bool:
return any(
x.is_node_supported(submods, node)
for x in op_support
)
return any(x.is_node_supported(submods, node) for x in op_support)

return create_op_support(_any_chain)


@ -180,6 +186,7 @@ class OpSupports:
"""A set of atomic `OperatorSupportBase` instances that can be combined together
to form more complex operator support logic.
"""

@classmethod
def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase:
"""Report a node as non-supported, if any of its arguments is of dtype"""
@ -193,6 +200,7 @@ class OpSupports:
if arg_dtype == dtype:
return False
return True

return create_op_support(_decline_if_input_dtype)

@classmethod
@ -200,16 +208,22 @@ class OpSupports:
"""
If a node has a name that is in the disallow set, reported it as non-supported.
"""

def _decline_if_node_in_names(
submodules: t.Mapping[str, torch.nn.Module],
node: torch.fx.Node,
) -> bool:
return node.name not in disallow_set

return create_op_support(_decline_if_node_in_names)


def _get_arg_dtype(arg: torch.fx.Node) -> t.Any:
assert isinstance(arg, torch.fx.Node)
tensor_meta = arg.meta.get("tensor_meta")  # type: ignore[union-attr]
dtype = tensor_meta.dtype if isinstance(tensor_meta, TensorMetadata) else arg.meta["type"]
dtype = (
tensor_meta.dtype
if isinstance(tensor_meta, TensorMetadata)
else arg.meta["type"]
)
return dtype

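The combinators compose naturally: `chain` is a logical AND over `is_node_supported`, `any_chain` an OR. A sketch mixing a built-in rule with a hypothetical custom predicate:

```python
import typing as t

import torch
from torch.fx.passes.operator_support import OpSupports, chain, create_op_support

def _no_dropout(
    submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
) -> bool:
    # Hypothetical predicate: decline nodes whose target mentions "dropout".
    return "dropout" not in str(node.target)

supported = chain(
    OpSupports.decline_if_input_dtype(torch.float64),
    create_op_support(_no_dropout),
)
```
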
@ -1,35 +1,59 @@
from torch.fx.graph_module import GraphModule
from typing import Any, Callable, Dict, List, Tuple, Type

import torch
import torch.nn as nn

from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule


__all__ = [
"default_matching",
"extract_attrs_for_lowering",
"lift_lowering_attrs_to_nodes",
]

__all__ = ['default_matching', 'extract_attrs_for_lowering', 'lift_lowering_attrs_to_nodes']

# Matching method matches the attribute name of current version to the attribute name of `target_version`
@compatibility(is_backward_compatible=False)
def default_matching(name: str, target_version: int) -> str:
"""Default matching method
"""
"""Default matching method"""
return name


# This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering.
# The first integer in the tuple is the version number of the nn.Module class when we create the parameter list.
# If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module.
module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]] = {
torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching),
torch.nn.modules.conv.Conv2d: (
1, ["weight", "bias", "kernel_size", "stride", "padding", "dilation", "groups", "padding_mode"], default_matching
1,
[
"weight",
"bias",
"kernel_size",
"stride",
"padding",
"dilation",
"groups",
"padding_mode",
],
default_matching,
),
torch.nn.modules.batchnorm.BatchNorm2d: (
2,
["weight", "bias", "running_mean", "running_var", "eps"],
default_matching,
),
torch.nn.modules.batchnorm.BatchNorm2d: (2, ["weight", "bias", "running_mean", "running_var", "eps"], default_matching),
torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching),
torch.nn.modules.pooling.MaxPool2d: (
1, ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"], default_matching
1,
["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"],
default_matching,
),
torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching),
}


@compatibility(is_backward_compatible=False)
def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]:
"""If `mod` is in `module_fetch_book`, fetch the mod's attributes that in the `module_fetch_book`
@ -41,21 +65,25 @@ def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]:
if type(mod) in module_fetch_book:
version, param_to_fetch, matching_method = module_fetch_book[type(mod)]
if version < mod._version:
raise RuntimeError(f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, "
"please upgrade the module_fetch_book, open an issue and @842974287 "
"or report a bug to AIACC team directly.")
raise RuntimeError(
f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, "
"please upgrade the module_fetch_book, open an issue and @842974287 "
"or report a bug to AIACC team directly."
)
for attr in param_to_fetch:
attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version))
else:
raise RuntimeError(f"{torch.typename(mod)} is not in the module_fetch_book yet, "
"please add it to the module_fetch_book, open an issue and @842974287 "
"or report a bug to AIACC team directly.")
raise RuntimeError(
f"{torch.typename(mod)} is not in the module_fetch_book yet, "
"please add it to the module_fetch_book, open an issue and @842974287 "
"or report a bug to AIACC team directly."
)
return attrs_for_lowering


@compatibility(is_backward_compatible=False)
def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None:
"""Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module.
"""
"""Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module."""
submodules = dict(fx_module.named_modules())

for node in fx_module.graph.nodes:
@ -63,4 +91,6 @@ def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None:
if isinstance(submodules[node.target], GraphModule):
lift_lowering_attrs_to_nodes(submodules[node.target])
else:
node.attrs_for_lowering = extract_attrs_for_lowering(submodules[node.target])
node.attrs_for_lowering = extract_attrs_for_lowering(
submodules[node.target]
)

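Usage sketch for the fetch book: after tracing, `lift_lowering_attrs_to_nodes` stamps each leaf `call_module` node with the attributes its entry lists (both `Linear` and `ReLU` appear in the book above):

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes

gm = symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()))
lift_lowering_attrs_to_nodes(gm)
for node in gm.graph.nodes:
    if node.op == "call_module":
        print(node.target, sorted(node.attrs_for_lowering))  # e.g. bias, weight
```
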
@ -1,8 +1,9 @@
# mypy: allow-untyped-defs
import logging
from functools import wraps
from inspect import unwrap
from typing import Callable, List, Optional
import logging


logger = logging.getLogger(__name__)

@ -15,6 +16,7 @@ __all__ = [
"these_before_those_pass_constraint",
]


# for callables which modify object inplace and return something other than
# the object on which they act
def inplace_wrapper(fn: Callable) -> Callable:
@ -36,6 +38,7 @@ def inplace_wrapper(fn: Callable) -> Callable:

return wrapped_fn


def log_hook(fn: Callable, level=logging.INFO) -> Callable:
"""
Logs callable output.
@ -48,16 +51,13 @@ def log_hook(fn: Callable, level=logging.INFO) -> Callable:
```
def my_pass(d: Dict) -> bool:
changed = False
if 'foo' in d:
d['foo'] = 'bar'
if "foo" in d:
d["foo"] = "bar"
changed = True
return changed

pm = PassManager(
passes=[
inplace_wrapper(log_hook(my_pass))
]
)

pm = PassManager(passes=[inplace_wrapper(log_hook(my_pass))])
```

Args:
@ -67,6 +67,7 @@ def log_hook(fn: Callable, level=logging.INFO) -> Callable:
Returns:
wrapped_fn (Callable[Type1, Type2])
"""

@wraps(fn)
def wrapped_fn(gm):
val = fn(gm)
@ -76,8 +77,11 @@ def log_hook(fn: Callable, level=logging.INFO) -> Callable:
return wrapped_fn



def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None):
def loop_pass(
base_pass: Callable,
n_iter: Optional[int] = None,
predicate: Optional[Callable] = None,
):
"""
Convenience wrapper for passes which need to be applied multiple times.

@ -154,9 +158,7 @@ def these_before_those_pass_constraint(these: Callable, those: Callable):
loop_pass(pass_a, 5),
]

constraints = [
these_before_those_pass_constraint(pass_a, pass_b)
]
constraints = [these_before_those_pass_constraint(pass_a, pass_b)]
```

Args:

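`loop_pass` accepts exactly one of `n_iter` or `predicate`; a toy sketch over integers standing in for graph objects:

```python
from torch.fx.passes.pass_manager import loop_pass

def halve(x: int) -> int:  # toy stand-in for a real pass
    return x // 2

fixed_count = loop_pass(halve, n_iter=3)                  # exactly 3 applications
until_small = loop_pass(halve, predicate=lambda x: x > 1)

assert fixed_count(40) == 5   # 40 -> 20 -> 10 -> 5
assert until_small(40) == 1   # halves while the predicate holds
```
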
@ -1,32 +1,38 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import _operator
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from typing import Dict, Set
|
||||
|
||||
import torch
|
||||
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
|
||||
from torch.fx import Node
|
||||
from torch.fx._compatibility import compatibility
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
|
||||
from torch.utils._pytree import tree_map_only
|
||||
from torch.utils import _pytree as pytree
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
from torch.utils import _pytree as pytree
|
||||
from torch.utils._pytree import tree_map_only
|
||||
|
||||
import _operator
|
||||
from enum import Enum
|
||||
import itertools
|
||||
from typing import Set, Dict
|
||||
from collections import defaultdict
|
||||
|
||||
__all__ = ['reinplace']
|
||||
__all__ = ["reinplace"]
|
||||
|
||||
|
||||
class _ViewType(Enum):
|
||||
NonView = 0
|
||||
SingleOutputView = 1
|
||||
MultiOutputView = 2
|
||||
|
||||
|
||||
def _is_view_op(tgt):
|
||||
if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
|
||||
schema = tgt._schema
|
||||
if len(schema.arguments) > 0:
|
||||
first_arg = schema.arguments[0]
|
||||
# check if op is a view
|
||||
return first_arg.alias_info is not None and not first_arg.alias_info.is_write
|
||||
return (
|
||||
first_arg.alias_info is not None and not first_arg.alias_info.is_write
|
||||
)
|
||||
|
||||
|
||||
def _get_view_type(tgt) -> _ViewType:
|
||||
if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
|
||||
@ -36,7 +42,7 @@ def _get_view_type(tgt) -> _ViewType:
|
||||
# check if op is a view
|
||||
if first_arg.alias_info is not None and not first_arg.alias_info.is_write:
|
||||
# check if op is a multi-output view
|
||||
if '*' in first_arg.alias_info.after_set:
|
||||
if "*" in first_arg.alias_info.after_set:
|
||||
return _ViewType.MultiOutputView
|
||||
else:
|
||||
return _ViewType.SingleOutputView
|
||||
@ -54,12 +60,11 @@ def _get_view_type(tgt) -> _ViewType:
|
||||
# to sanity check that our aliasing information is correct.
|
||||
@compatibility(is_backward_compatible=False)
|
||||
class _FunctionalizationMetadataProp(torch.fx.Interpreter):
|
||||
|
||||
def run_node(self, node: Node):
|
||||
self.node_counter += 1
|
||||
result = super().run_node(node)
|
||||
node.meta['fake_result'] = result
|
||||
node.meta['node_idx'] = self.node_counter
|
||||
node.meta["fake_result"] = result
|
||||
node.meta["node_idx"] = self.node_counter
|
||||
|
||||
# (1) Update metadata with the list of nodes that are used by this node
|
||||
# copy_() doesn't read from its first argument; it writes to it, overwriting previous data.
|
||||
@ -69,11 +74,11 @@ class _FunctionalizationMetadataProp(torch.fx.Interpreter):
|
||||
node_args = node_args[1:]
|
||||
|
||||
# (2) Update metadata to track aliasing information about view tensor nodes.
|
||||
if node.op == 'call_function':
|
||||
if node.op == "call_function":
|
||||
view_type = _get_view_type(node.target)
|
||||
if view_type == _ViewType.SingleOutputView:
|
||||
assert isinstance(node.args[0], Node)
|
||||
node.meta['view_of'] = node.args[0]
|
||||
node.meta["view_of"] = node.args[0]
|
||||
elif view_type == _ViewType.MultiOutputView:
|
||||
self.multi_output_view_nodes[node] = node.args[0]
|
||||
|
||||
@ -95,38 +100,52 @@ class _FunctionalizationMetadataProp(torch.fx.Interpreter):
|
||||
# Note: we could also track indexing info here for multi-output views.
|
||||
# I don't think this metadata is strictly needed for de-functionalization.
|
||||
assert isinstance(maybe_base_of_view, Node)
|
||||
node.meta['view_of'] = maybe_base_of_view
|
||||
node.meta["view_of"] = maybe_base_of_view
|
||||
|
||||
if 'view_of' in node.meta:
|
||||
if "view_of" in node.meta:
|
||||
# We're linking the current node with its first argument as views.
|
||||
# Assert here that this is actually the case, and their storages are the same.
|
||||
assert isinstance(node.meta['fake_result'], FakeTensor)
|
||||
assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor)
|
||||
view_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
|
||||
base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result']._typed_storage())
|
||||
assert isinstance(node.meta["fake_result"], FakeTensor)
|
||||
assert isinstance(node.meta["view_of"].meta["fake_result"], FakeTensor)
|
||||
view_storage = StorageWeakRef(node.meta["fake_result"]._typed_storage())
|
||||
base_storage = StorageWeakRef(
|
||||
node.meta["view_of"].meta["fake_result"]._typed_storage()
|
||||
)
|
||||
assert view_storage == base_storage
|
||||
return result
|
||||
|
||||
|
||||
|
||||
def propagate(self, *args):
|
||||
self.multi_output_view_nodes = {}
|
||||
self.node_counter = -1
|
||||
|
||||
with FakeTensorMode() as mode:
|
||||
fake_args = [mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
|
||||
fake_args = [
|
||||
mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args
|
||||
]
|
||||
return super().run(*fake_args)
|
||||
|
||||
|
||||
def _schemas_match(functional_schema, inplace_schema):
|
||||
names_match = inplace_schema.name.endswith("_") and inplace_schema.name[:-1] == functional_schema.name
|
||||
arg_types_match = len(functional_schema.arguments) == len(inplace_schema.arguments) and all(
|
||||
a1.type == a2.type for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments))
|
||||
names_match = (
|
||||
inplace_schema.name.endswith("_")
|
||||
and inplace_schema.name[:-1] == functional_schema.name
|
||||
)
|
||||
arg_types_match = len(functional_schema.arguments) == len(
|
||||
inplace_schema.arguments
|
||||
) and all(
|
||||
a1.type == a2.type
|
||||
for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments)
|
||||
)
|
||||
# for the inplace op, its first argument should be mutable
|
||||
assert inplace_schema.arguments[0].alias_info is not None and inplace_schema.arguments[0].alias_info.is_write
|
||||
assert (
|
||||
inplace_schema.arguments[0].alias_info is not None
|
||||
and inplace_schema.arguments[0].alias_info.is_write
|
||||
)
|
||||
# and its remaining arguments shouldn't be.
|
||||
assert all(a.alias_info is None for a in inplace_schema.arguments[1:])
|
||||
return names_match and arg_types_match
|
||||
|
||||
|
||||
# TODO: this should be beefed up to be able to properly re-inplace with:
# - mutating ops (e.g. _fused_moving_avg_obs_fq_helper)
# - out= ops (e.g. angle -> angle.out)
@ -143,17 +162,20 @@ def _maybe_get_inplace_op(op):
op_namespace = op.__module__.split(".")[-1]
op_base_name = op.overloadpacket.__name__
maybe_namespace_module = getattr(torch.ops, op_namespace)
maybe_inplace_op = None if maybe_namespace_module is None else getattr(maybe_namespace_module, f'{op_base_name}_', None)
maybe_inplace_op = (
None
if maybe_namespace_module is None
else getattr(maybe_namespace_module, f"{op_base_name}_", None)
)
if maybe_inplace_op is None:
return None

inplace_overloads = [
getattr(maybe_inplace_op, overload_name) for overload_name in maybe_inplace_op.overloads()
getattr(maybe_inplace_op, overload_name)
for overload_name in maybe_inplace_op.overloads()
]
inplace_overloads_with_matching_schemas = [
f
for f in inplace_overloads
if _schemas_match(op._schema, f._schema)
f for f in inplace_overloads if _schemas_match(op._schema, f._schema)
]
# Just because foo() and foo_() are both existing operators,
# they aren't guaranteed to have compatible schemas.
@ -165,6 +187,7 @@ def _maybe_get_inplace_op(op):
inplace_op = inplace_overloads_with_matching_schemas[0]
return inplace_op

_VIEW_INVERSE_MAP = {
torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
@ -172,6 +195,7 @@ _VIEW_INVERSE_MAP = {
torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
}

# This function, given a set of (aliased) tensor nodes,
# returns any nodes in the graph that *use* any of the aliases and occur *after* op_index
# in the node ordering.
@ -186,17 +210,21 @@ def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
usage_nodes = t.users
for n in usage_nodes:
# We only care about usages after the current node
if 'node_idx' not in n.meta or n.meta['node_idx'] <= op_index:
if "node_idx" not in n.meta or n.meta["node_idx"] <= op_index:
continue
# We also don't care about intermediate view ops.
# They only matter if their output is then used elsewhere
# (either in an out-of-place op, or as an output to the function).
if n in tensor_aliases:
if isinstance(n.target, torch._ops.OpOverload) or n.target == _operator.getitem:
if (
isinstance(n.target, torch._ops.OpOverload)
or n.target == _operator.getitem
):
continue
nodes_used_after.add(n)
return nodes_used_after

# Given an op that we're trying to re-inplace, "b = foo(a)",
# And given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)"
# Then re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF:
@ -204,23 +232,27 @@ def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
# (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base"
# (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata
# as "alias"
def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Set[Node]) -> Set[Node]:
def _get_view_inverse_node_usages(
later_node_usages: Set[Node], self_aliases: Set[Node]
) -> Set[Node]:
def matching_view_metadata(a, b):
return a.size() == b.size() and \
a.stride() == b.stride() and \
a.storage_offset() == b.storage_offset()
return (
a.size() == b.size()
and a.stride() == b.stride()
and a.storage_offset() == b.storage_offset()
)

view_inverse_nodes = set()
# Go through them in node order, so we can see chains of view_scatter ops.
for n in sorted(later_node_usages, key=lambda x: x.meta['node_idx']):
for n in sorted(later_node_usages, key=lambda x: x.meta["node_idx"]):
if n.target not in _VIEW_INVERSE_MAP:
continue
base = n.args[0]
mutated_view = n.args[1]
assert isinstance(base, Node)
assert isinstance(base.meta['fake_result'], FakeTensor)
assert isinstance(base.meta["fake_result"], FakeTensor)
assert isinstance(mutated_view, Node)
assert isinstance(mutated_view.meta['fake_result'], FakeTensor)
assert isinstance(mutated_view.meta["fake_result"], FakeTensor)
# Check that this view_inverse op actually corresponds to doing the inverse
# of one of our existing self_alias nodes.
original_view = _VIEW_INVERSE_MAP[n.target]
@ -229,18 +261,21 @@ def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Se
# that was created from some op `alias = foo(base, args...)`
# such that the current _scatter op "inverts" that foo call.
# We can check that by running the original op again, and checking that the strides match.
if 'view_of' not in self_alias.meta:
if "view_of" not in self_alias.meta:
continue
self_alias_base = self_alias.meta['view_of']
self_alias_base = self_alias.meta["view_of"]
try:
# Here we're trying to re-use the args from the view_scatter call inside of the corresponding
# view op, which might throw. This just indicates that the view_scatter op isn't a valid inverse
# of the current alias we're looking at.
view_replay_metadata = original_view(self_alias_base.meta['fake_result'], *n.args[2:], **n.kwargs)
expected_metadata = self_alias.meta['fake_result']
view_replay_metadata = original_view(
self_alias_base.meta["fake_result"], *n.args[2:], **n.kwargs
)
expected_metadata = self_alias.meta["fake_result"]
# If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace.
if matching_view_metadata(self_alias_base.meta['fake_result'], base.meta['fake_result']) and \
matching_view_metadata(view_replay_metadata, expected_metadata):
if matching_view_metadata(
self_alias_base.meta["fake_result"], base.meta["fake_result"]
) and matching_view_metadata(view_replay_metadata, expected_metadata):
view_inverse_nodes.add(n)
except Exception:
continue
@ -471,25 +506,29 @@ def reinplace(gm, *sample_args):
# NOTE: later, we'll need to add an optimization for fully recovering performance
# on programs that mutate inputs.
input_storages = {
StorageWeakRef(
node.meta['fake_result']._typed_storage()
) for node in gm.graph.nodes if (node.op == 'placeholder' and isinstance(node.meta['fake_result'], torch.Tensor))}
StorageWeakRef(node.meta["fake_result"]._typed_storage())
for node in gm.graph.nodes
if (
node.op == "placeholder"
and isinstance(node.meta["fake_result"], torch.Tensor)
)
}

# We also need to know for a given node, what are all of its aliasing nodes.
storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set)
for n in gm.graph.nodes:
if 'fake_result' in n.meta:
if "fake_result" in n.meta:
# Tree-mapping because some ops can return lists of tensors.
def _add_to_map(x):
if isinstance(x, FakeTensor):
storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n)
pytree.tree_map_(_add_to_map, n.meta['fake_result'])

pytree.tree_map_(_add_to_map, n.meta["fake_result"])

# inplace-ify functional ops, subject to the constraints written below.
all_later_view_inverse_nodes_to_delete = set()
for node in gm.graph.nodes:
if node.op == 'call_function':

if node.op == "call_function":
# Today, the re-inplace pass directly acts on:
# - functional ops with an inplace variant
# - {view}_scatter ops that can be potentially removed from the graph.
@ -512,8 +551,8 @@ def reinplace(gm, *sample_args):
# (We could potentially swizzle this into larger_tensor.add_(scalar_tensor),
# this is probably an optimization to revisit later).
self_arg = node.args[0]
self_flattened = pytree.tree_leaves(self_arg.meta['fake_result'])
node_flattened = pytree.tree_leaves(node.meta['fake_result'])
self_flattened = pytree.tree_leaves(self_arg.meta["fake_result"])
node_flattened = pytree.tree_leaves(node.meta["fake_result"])
self_has_wrong_metadata = False
if len(self_flattened) == len(node_flattened):
for self_meta, node_meta in zip(self_flattened, node_flattened):
@ -532,7 +571,9 @@ def reinplace(gm, *sample_args):
continue

# Step 1b: ensure that the op we're trying to re-inplace isn't a program input
self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
self_arg_storage = StorageWeakRef(
self_arg.meta["fake_result"]._typed_storage()
)
if self_arg_storage in input_storages:
# TODO: later, add the optimization for handling `copy_()` calls in the graph.
continue
@ -542,14 +583,20 @@ def reinplace(gm, *sample_args):
# so we prevent re-inplacing in this case.
continue

self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
self_arg_storage = StorageWeakRef(
self_arg.meta["fake_result"]._typed_storage()
)
self_aliases = storage_to_nodes[self_arg_storage]

# First, we find all later usages of any of the aliases of self_arg.
later_node_usages = _get_all_later_node_usages(self_aliases, node.meta['node_idx'])
later_node_usages = _get_all_later_node_usages(
self_aliases, node.meta["node_idx"]
)
# Then, we check if any of those later usages are actually view_scatter ops
# that are safe to fully remove.
later_view_inverse_node_usages = _get_view_inverse_node_usages(later_node_usages, self_aliases)
later_view_inverse_node_usages = _get_view_inverse_node_usages(
later_node_usages, self_aliases
)

# Step 2: Check to see if the input to the op is re-used later in the graph.
# If not (same goes for its aliases), then this op is safe to re-inplace.
@ -565,7 +612,10 @@ def reinplace(gm, *sample_args):
# we would prefer to remove it from the graph entirely,
# and instead copy_() the slice directly into the larger tensor.
# See the description of the algorithm for a full example.
if node.target in _VIEW_INVERSE_MAP and node not in all_later_view_inverse_nodes_to_delete:
if (
node.target in _VIEW_INVERSE_MAP
and node not in all_later_view_inverse_nodes_to_delete
):
view_op = _VIEW_INVERSE_MAP[node.target]
# Before:
# base_updated = torch.ops.aten.slice_scatter.default(base, mutated_slice, args...)
@ -576,13 +626,23 @@ def reinplace(gm, *sample_args):
mutated_slice_node = node.args[1]
remaining_slice_args = node.args[2:]
slice_node = gm.graph.create_node(
'call_function', view_op, (self_arg,) + tuple(remaining_slice_args), node.kwargs)
"call_function",
view_op,
(self_arg,) + tuple(remaining_slice_args),
node.kwargs,
)
gm.graph.create_node(
'call_function', torch.ops.aten.copy_.default, (slice_node, mutated_slice_node,), {})
"call_function",
torch.ops.aten.copy_.default,
(
slice_node,
mutated_slice_node,
),
{},
)
# Add the slice_scatter node to our "nodes to delete" list.
all_later_view_inverse_nodes_to_delete.add(node)

else:
# Step 3b: Check to see if this operator has an inplace variant.
maybe_inplace_op = _maybe_get_inplace_op(node.target)
@ -597,19 +657,29 @@ def reinplace(gm, *sample_args):
# Hmm... morally I think we also want to keep the `fake_result` metadata
# up to date here, but I'm not sure how easy it is to do.
# Maybe it's fine to wait until the end of the pass to update it.
curr_node_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage])
storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage])
curr_node_storage = StorageWeakRef(
node.meta["fake_result"]._typed_storage()
)
storage_to_nodes[self_arg_storage].update(
storage_to_nodes[curr_node_storage]
)
storage_to_nodes[curr_node_storage].update(
storage_to_nodes[self_arg_storage]
)

# Need to remember the view_scatter view nodes we found so we can remove them later.
all_later_view_inverse_nodes_to_delete.update(later_view_inverse_node_usages)
all_later_view_inverse_nodes_to_delete.update(
later_view_inverse_node_usages
)

# Step 4:
# Now that we've replaced b = a.foo() with a.foo_(),
# We need to replace any later usages of "b" with "a"
for old in itertools.chain([node], later_view_inverse_node_usages):
new = old.args[0]
nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']]
nodes_to_update = [
n for n in old.users if n.meta["node_idx"] > node.meta["node_idx"]
]
for node_to_update in nodes_to_update:

def replace_arg(a):
@ -618,21 +688,29 @@ def reinplace(gm, *sample_args):
return a

# First, replace usages of "b" with "a"
node_to_update.args = tree_map_only(Node, replace_arg, node_to_update.args)
node_to_update.kwargs = tree_map_only(Node, replace_arg, node_to_update.kwargs)
node_to_update.args = tree_map_only(
Node, replace_arg, node_to_update.args
)
node_to_update.kwargs = tree_map_only(
Node, replace_arg, node_to_update.kwargs
)

# Second, update our storage_to_nodes data structure.
old_flattened_res = pytree.tree_leaves(old.meta['fake_result'])
node_flattened_res = pytree.tree_leaves(node_to_update.meta['fake_result'])
old_flattened_res = pytree.tree_leaves(old.meta["fake_result"])
node_flattened_res = pytree.tree_leaves(
node_to_update.meta["fake_result"]
)

old_res_storage = {
StorageWeakRef(
x._typed_storage()
) for x in old_flattened_res if isinstance(x, FakeTensor)}
StorageWeakRef(x._typed_storage())
for x in old_flattened_res
if isinstance(x, FakeTensor)
}
node_res_storage = {
StorageWeakRef(
x._typed_storage()
) for x in node_flattened_res if isinstance(x, FakeTensor)}
StorageWeakRef(x._typed_storage())
for x in node_flattened_res
if isinstance(x, FakeTensor)
}

# This will happen if we're updating a view op,
# e.g. replacing
@ -644,12 +722,17 @@ def reinplace(gm, *sample_args):
# We're checking for len(...) == 1 here because all view ops are guaranteed to return either a single tensor,
# or multiple tensors that all share the same storage.
# We can't just check equality because we might encounter FX nodes that return zero tensor outputs.
if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage:
new_flattened_res = pytree.tree_leaves(new.meta['fake_result'])
if (
len(old_res_storage) == 1
and len(node_res_storage) == 1
and old_res_storage == node_res_storage
):
new_flattened_res = pytree.tree_leaves(new.meta["fake_result"])
new_res_storage = {
StorageWeakRef(
x._typed_storage()
) for x in new_flattened_res if isinstance(x, FakeTensor)}
StorageWeakRef(x._typed_storage())
for x in new_flattened_res
if isinstance(x, FakeTensor)
}
assert len(new_res_storage) == 1
(new_ref,) = new_res_storage
(node_ref,) = node_res_storage
@ -666,6 +749,5 @@ def reinplace(gm, *sample_args):
for to_delete in all_later_view_inverse_nodes_to_delete:
gm.graph.erase_node(to_delete)

gm.recompile()
return gm

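For orientation, a minimal sketch of driving this pass end to end (assuming `make_fx` to obtain an aten-level graph; the `clone` keeps the mutated tensor from aliasing a program input, which Step 1b above would otherwise reject):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.reinplace import reinplace

def f(x):
    a = x.clone()    # not a program input, so re-inplacing is allowed
    return a.add(1)  # functional add: a candidate for add_

gm = make_fx(f)(torch.zeros(4))
gm = reinplace(gm, torch.zeros(4))
print(gm.code)  # the add should now show up as its in-place variant
```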
@ -1,17 +1,19 @@
# mypy: ignore-errors

import traceback
from typing import Any, Dict, NamedTuple, Optional, Tuple

import torch
import torch.fx
import traceback

from torch._dispatch.python import enable_python_dispatcher
from torch.fx.node import Node, map_aggregate
from typing import Any, Tuple, NamedTuple, Optional, Dict
from torch.fx._compatibility import compatibility
from torch._guards import detect_fake_mode
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx._compatibility import compatibility
from torch.fx.node import map_aggregate, Node


__all__ = ["TensorMetadata", "ShapeProp"]

__all__ = ['TensorMetadata', 'ShapeProp']

@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
@ -19,17 +21,20 @@ class TensorMetadata(NamedTuple):
# about a tensor within a PyTorch program.

# General Tensor metadata
shape : torch.Size
dtype : torch.dtype
requires_grad : bool
stride : Tuple[int, ...]
memory_format : Optional[torch.memory_format]
shape: torch.Size
dtype: torch.dtype
requires_grad: bool
stride: Tuple[int, ...]
memory_format: Optional[torch.memory_format]

# Quantization metadata
is_quantized : bool
is_quantized: bool
qparams: Dict[str, Any]

def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> TensorMetadata:

def _extract_tensor_metadata(
result: torch.Tensor, include_contiguity=True
) -> TensorMetadata:
"""
Extract a TensorMetadata NamedTuple describing `result`.
"""
@ -59,7 +64,11 @@ def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) ->
if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
qparams["scale"] = result.q_scale()  # type: ignore[assignment]
qparams["zero_point"] = result.q_zero_point()  # type: ignore[assignment]
elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}:
elif qscheme in {
torch.per_channel_affine,
torch.per_channel_affine_float_qparams,
torch.per_channel_symmetric,
}:
# In this branch, scale and zero_point are expected to be tensors,
# we store the values as immutable_list in TensorMetadata for
# easier serialization downstream
@ -68,7 +77,9 @@ def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) ->
qparams["axis"] = result.q_per_channel_axis()  # type: ignore[assignment]

return TensorMetadata(
shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams)
shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams
)


@compatibility(is_backward_compatible=True)
class ShapeProp(torch.fx.Interpreter):
@ -117,12 +128,14 @@ class ShapeProp(torch.fx.Interpreter):
fake_mode (FakeTensorMode): A fake mode for copying the gm

"""

def __init__(self, gm, fake_mode=None):
super().__init__(gm)
if fake_mode is None:
fake_mode = detect_fake_mode()
if fake_mode is not None:
from torch._dynamo.utils import deepcopy_to_fake_tensor

# Note:
# We need fake execution because the inputs are fake, however, we cannot fakify the module
# - because we need to write to the tensor_meta of the real module. So we fakify to
@ -140,7 +153,7 @@ class ShapeProp(torch.fx.Interpreter):

self.real_module = self.module

def run_node(self, n : Node) -> Any:
def run_node(self, n: Node) -> Any:
try:
if self.fake_module is not None:
# Hacky swap. Alternatively, we could do this with overriding
@ -157,8 +170,7 @@ class ShapeProp(torch.fx.Interpreter):
except Exception as e:
traceback.print_exc()
raise RuntimeError(
f"ShapeProp error for: node={n.format_node()} with "
f"meta={n.meta}"
f"ShapeProp error for: node={n.format_node()} with " f"meta={n.meta}"
) from e

found_tensor = False
@ -173,9 +185,9 @@ class ShapeProp(torch.fx.Interpreter):

meta = map_aggregate(result, extract_tensor_meta)
if found_tensor:
n.meta['tensor_meta'] = meta
n.meta["tensor_meta"] = meta

n.meta['type'] = type(result)
n.meta["type"] = type(result)
return result

def propagate(self, *args):
@ -190,7 +202,10 @@ class ShapeProp(torch.fx.Interpreter):
Any: The value returned from executing the Module
"""
if self.fake_mode is not None:
fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args]
fake_args = [
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
for t in args
]
else:
fake_args = args
return super().run(*fake_args)

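A minimal usage sketch for `ShapeProp`: propagate example inputs through a traced module, then read back the `tensor_meta` that `run_node` records above:

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

class M(torch.nn.Module):
    def forward(self, x):
        return x.relu() + 1

gm = symbolic_trace(M())
ShapeProp(gm).propagate(torch.randn(2, 3))
for node in gm.graph.nodes:
    # each tensor-producing node now carries TensorMetadata (shape, dtype, stride, ...)
    print(node.name, node.meta.get("tensor_meta"))
```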
@ -1,19 +1,20 @@
# mypy: allow-untyped-defs
import inspect
from typing import Any, Callable, Dict, List, Optional, Set
from collections import OrderedDict
import logging
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Set

import torch
from torch.fx._compatibility import compatibility
from torch.fx._utils import lazy_format_graph_code
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx._utils import lazy_format_graph_code


__all__ = ["Partition", "split_module"]
log = _LOGGER = logging.getLogger(__name__)


@compatibility(is_backward_compatible=True)
class Partition:
def __init__(self, name: str):
@ -146,9 +147,7 @@ def split_module(

log.debug(
"%s",
lazy_format_graph_code(
"pre split_module", m, colored=True
),
lazy_format_graph_code("pre split_module", m, colored=True),
)

def construct_graph(
@ -161,11 +160,20 @@ def split_module(
node.args[0] if len(node.args) > 0 else inspect.Signature.empty
)
if keep_original_node_name:
args = () if default_value is inspect.Signature.empty else (default_value,)
base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type)  # type: ignore[arg-type]
args = (
() if default_value is inspect.Signature.empty else (default_value,)
)
base_mod_env[node.name] = base_mod_graph.create_node(
"placeholder",
node.name,
args=args,  # type: ignore[arg-type]
type_expr=node.type,
)
else:
base_mod_env[node.name] = base_mod_graph.placeholder(
node.target, type_expr=node.type, default_value=default_value  # type: ignore[arg-type]
node.target,  # type: ignore[arg-type]
type_expr=node.type,
default_value=default_value,
)
base_mod_env[node.name].meta = node.meta.copy()
elif node.op == "get_attr":
@ -185,9 +193,7 @@ def split_module(
orig_nodes: Dict[str, Node] = {}
symbol_to_node: Dict[sympy.Symbol, Node] = {}

def record_cross_partition_use(
def_node: Node, use_node: Optional[Node]
):  # noqa: B950
def record_cross_partition_use(def_node: Node, use_node: Optional[Node]):
from torch.fx.experimental.symbolic_shapes import free_symbols

defined = getattr(def_node, "_fx_partition", None)
@ -195,7 +201,10 @@ def split_module(

log.debug(
"record_cross_partition_use %s (%s) %s (%s)",
def_node.name, defined, use_node.name if use_node is not None else "-", used
def_node.name,
defined,
use_node.name if use_node is not None else "-",
used,
)

if defined != used:
@ -234,7 +243,9 @@ def split_module(

def instantiate_node_partition_mapping(node):
partition_name = str(split_callback(node))
log.debug("instantiate_node_partition_mapping %s (%s)", node.name, partition_name)
log.debug(
"instantiate_node_partition_mapping %s (%s)", node.name, partition_name
)

# add node to partitions
partition = partitions.get(partition_name)
@ -249,7 +260,7 @@ def split_module(
GLOBAL_STATE_NODES = [
torch.amp._enter_autocast,
torch.amp._exit_autocast,
torch._C._set_grad_enabled
torch._C._set_grad_enabled,
]

# For grad regions:
@ -280,10 +291,10 @@ def split_module(
# rely on later, but this needs some extra work. Quick fix first.
# See https://github.com/pytorch/pytorch/issues/130534
if (
(val := node.meta.get("example_value")) is not None and
isinstance(val, torch.SymInt) and
isinstance(s0 := val.node.expr, sympy.Symbol) and
s0 not in symbol_to_node
(val := node.meta.get("example_value")) is not None
and isinstance(val, torch.SymInt)
and isinstance(s0 := val.node.expr, sympy.Symbol)
and s0 not in symbol_to_node
):
symbol_to_node[val.node.expr] = node

@ -344,9 +355,10 @@ def split_module(

if assert_monotonically_increasing:
pid = split_callback(node)
assert highest_partition <= pid, \
("autocast or set_grad_enabled require monotonically increasing partitions:"
f"highest: {highest_partition}, this node's: {pid}")
assert highest_partition <= pid, (
"autocast or set_grad_enabled require monotonically increasing partitions:"
f"highest: {highest_partition}, this node's: {pid}"
)
highest_partition = pid

# do not capture cross-partition dependencies for global state nodes as they will be
@ -392,7 +404,9 @@ def split_module(
kwargs={},
type_expr=node.type,
)
new_node.meta = node.meta.copy()  # is it really a good idea to copy this?
new_node.meta = (
node.meta.copy()
)  # is it really a good idea to copy this?
partition.environment[node] = new_node

# add placeholders to partition inputs
@ -425,7 +439,9 @@ def split_module(
target_attr = m
for atom in target_atoms:
if not hasattr(target_attr, atom):
raise AttributeError(f"Operator target {node.target} not found!")
raise AttributeError(
f"Operator target {node.target} not found!"
)
target_attr = getattr(target_attr, atom)
# target = target_atoms[-1]
target = "_".join(target_atoms)
@ -467,7 +483,9 @@ def split_module(
kwargs={},
type_expr=exit_node.type,
)
new_node.meta = exit_node.meta.copy()  # is it really a good idea to copy this?
new_node.meta = (
exit_node.meta.copy()
)  # is it really a good idea to copy this?

# original module environment dict mapping node names to nodes
orig_mod_env: Dict[str, Node] = {}
@ -520,7 +538,9 @@ def split_module(
if keep_original_order:
# first get the attr nodes required by this partition
orig_mod_attr_nodes: List[Node] = [
orig_mod_env[key] for key in partition.inputs if key not in original_order
orig_mod_env[key]
for key in partition.inputs
if key not in original_order
]

for node in original_order:
@ -568,8 +588,6 @@ def split_module(
ret = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
log.debug(
"%s",
lazy_format_graph_code(
"post split_module", ret, colored=True
),
lazy_format_graph_code("post split_module", ret, colored=True),
)
return ret

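A minimal sketch of calling `split_module` (the `split_callback` maps each node to an integer partition id; the toy callback here is purely illustrative and simply puts the first node it sees in partition 0 and everything else in partition 1):

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module

class M(torch.nn.Module):
    def forward(self, x):
        return (x + 1).relu()

m = M()
gm = symbolic_trace(m)

seen = []
def split_callback(node):
    seen.append(node)
    return 0 if len(seen) <= 1 else 1

split = split_module(gm, m, split_callback)
print(split)  # contains submod_0 and submod_1 wired together
```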
@ -10,6 +10,7 @@ from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module

from .tools_common import NodeList


__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"]


@ -1,40 +1,44 @@
# mypy: allow-untyped-defs
import argparse
import copy
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import NamedTuple, Sequence, Iterable, Any, List, Dict, Optional, Tuple
import logging
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Sequence, Tuple

import torch
from torch.fx.passes.graph_manipulation import get_size_of_node
from torch.fx.node import map_arg
from torch.fx._compatibility import compatibility
from torch.fx.node import map_arg
from torch.fx.passes.graph_manipulation import get_size_of_node

from .operator_support import (
get_node_target,
OperatorSupportBase,
)
from .graph_drawer import FxGraphDrawer
from .operator_support import get_node_target, OperatorSupportBase
from .shape_prop import ShapeProp
from .split_utils import split_by_tags
from .tools_common import (
FxNetAccFusionsFinder,
CALLABLE_NODE_OPS,
Tensors,
FxNetAccFusionsFinder,
is_node_output_tensor,
NodeList,
NodeSet,
is_node_output_tensor,
Tensors,
)


__all__ = ['FxNetAccNodesFinder', 'FxNetSplitterInternalError', 'Subgraph', 'SplitResult', 'generate_inputs_for_submodules']
__all__ = [
"FxNetAccNodesFinder",
"FxNetSplitterInternalError",
"Subgraph",
"SplitResult",
"generate_inputs_for_submodules",
]
_LOGGER = logging.getLogger(__name__)

DEFAULT_MIN_ACC_MODULE_SIZE = 1
DEFAULT_SKIP_FUSION = False
DEFAULT_ALLOW_NON_TENSOR = False


class _SplitterSettingBase:
def __init__(
self,
@ -82,9 +86,15 @@ class _SplitterSettingBase:
)
args, _unknown = parser.parse_known_args()

self.min_acc_module_size: int = args.min_acc_module_size if args.min_acc_module_size else min_acc_module_size
self.min_acc_module_size: int = (
args.min_acc_module_size
if args.min_acc_module_size
else min_acc_module_size
)
self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion
self.allow_non_tensor: bool = args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor
self.allow_non_tensor: bool = (
args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor
)
self.max_acc_splits: int = max_acc_splits


@ -114,9 +124,7 @@ class FxNetAccNodesFinder:
self.allow_non_tensor = allow_non_tensor
self.acc_nodes: NodeSet = set()

def reduce_acc_nodes_non_tensor_input_helper(
self, cpu_worklist: NodeList
):
def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList):
"""
Transitively excludes nodes from ACC supported set.
For every node in the worklist:
@ -190,10 +198,12 @@ class FxNetAccNodesFinder:

return self.acc_nodes


@compatibility(is_backward_compatible=False)
class FxNetSplitterInternalError(Exception):
pass


@compatibility(is_backward_compatible=False)
@dataclass
class Subgraph:
@ -201,6 +211,7 @@ class Subgraph:
nodes: NodeList
device_ordinal: Optional[int] = None


@compatibility(is_backward_compatible=False)
class SplitResult(NamedTuple):
"""
@ -243,7 +254,9 @@ def generate_inputs_for_submodules(
submodule_to_names = {mod: name for name, mod in model.named_modules()}

def pre_forward(module, module_inputs):
results[submodule_to_names[module]] = copy.deepcopy(module_inputs) if deepcopy else module_inputs
results[submodule_to_names[module]] = (
copy.deepcopy(module_inputs) if deepcopy else module_inputs
)

for name, mod in model.named_modules():
if name in target_submodules:
@ -308,7 +321,7 @@ class _SplitterBase:
"""

# PCIe bandwidth for the backend, default to 100 GB/s
PCIe_BW = 100 * 2 ** 30
PCIe_BW = 100 * 2**30

def __init__(
self,
@ -335,7 +348,9 @@ class _SplitterBase:
self.settings = settings
self.operator_support = operator_support
self.sample_input = sample_input
self.acc_nodes = FxNetAccNodesFinder(self.module, self.operator_support, self.settings.allow_non_tensor)()
self.acc_nodes = FxNetAccNodesFinder(
self.module, self.operator_support, self.settings.allow_non_tensor
)()

if self.settings.skip_fusion:
self.fusions = {}
@ -357,11 +372,11 @@ class _SplitterBase:
# ===============================================================

def get_node_submodule_map(self) -> Dict[str, str]:
""" Returns a map from node name to submodule name, e.g.
node: main_module_impl_impl_over_arch_unary_multiple_embedding
_pooling_embedding_pooling_sparse_entity_equivalence_key
_proxy_embedding_bag
maps to submodule name of: _run_on_acc_1
"""Returns a map from node name to submodule name, e.g.
node: main_module_impl_impl_over_arch_unary_multiple_embedding
_pooling_embedding_pooling_sparse_entity_equivalence_key
_proxy_embedding_bag
maps to submodule name of: _run_on_acc_1
"""
return self._node_submodule_map

@ -411,9 +426,7 @@ class _SplitterBase:

return mod

def _find_culprit(
self, mod: torch.fx.GraphModule, inputs: Tensors
) -> str:
def _find_culprit(self, mod: torch.fx.GraphModule, inputs: Tensors) -> str:
"""
When an error occurs during lowering or running the lowered mod, we use this
function to find culprits in the `mod` that cause the error.
@ -492,7 +505,9 @@ class _SplitterBase:
supported_nodes.append(node)
supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))
else:
unsupported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))
unsupported_node_types[target].add(
(arg_dtypes_tuple, kwarg_dtypes_tuple)
)

if dump_graph:
self._draw_graph_based_on_node_support(self.module, supported_nodes)
@ -527,7 +542,11 @@ class _SplitterBase:
reports += f"  {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

for i, subgraph in enumerate(subgraphs):
reports += f"_run_on_acc_{i}: " if subgraph.is_acc else f"{self.non_acc_submodule_name}{i}: "
reports += (
f"_run_on_acc_{i}: "
if subgraph.is_acc
else f"{self.non_acc_submodule_name}{i}: "
)
reports += f"{len(subgraph.nodes)} node(s)\n"

self.tag(subgraphs)
@ -535,9 +554,7 @@ class _SplitterBase:
split_mod.eval()

if dump_graph:
drawer = FxGraphDrawer(
split_mod, "preview", ignore_getattr=True
)
drawer = FxGraphDrawer(split_mod, "preview", ignore_getattr=True)
dot_graphs = drawer.get_all_dot_graphs()
for name, dot_graph in dot_graphs.items():
# pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`.
@ -564,9 +581,7 @@ class _SplitterBase:
handle.remove()
return sub_inputs

submod_inputs = get_submod_inputs(
split_mod, submod, self.sample_input
)
submod_inputs = get_submod_inputs(split_mod, submod, self.sample_input)
ShapeProp(submod).propagate(*submod_inputs)

total_input_bytes = 0
@ -649,9 +664,7 @@ class _SplitterBase:

return result

def update_reverse_deps_for_fusions(
self, deps: Dict[torch.fx.Node, NodeSet]
):
def update_reverse_deps_for_fusions(self, deps: Dict[torch.fx.Node, NodeSet]):
processed_node = set()

for node, fusion in self.fusions.items():
@ -853,7 +866,11 @@ class _SplitterBase:
def tag(self, subgraphs: List[Subgraph]):
self.tags = []
for subgraph in subgraphs:
tag = f"_run_on_acc_{len(self.tags)}" if subgraph.is_acc else f"{self.non_acc_submodule_name}{len(self.tags)}"
tag = (
f"_run_on_acc_{len(self.tags)}"
if subgraph.is_acc
else f"{self.non_acc_submodule_name}{len(self.tags)}"
)
self.tags.append(tag)
for node in subgraph.nodes:
if hasattr(node, "tag"):
@ -863,7 +880,9 @@ class _SplitterBase:
self._node_submodule_map[node.name] = tag

def split(self, remove_tag: bool = False) -> torch.fx.GraphModule:
split_module = split_by_tags(self.module, self.tags, return_tuple=self._return_tuple)
split_module = split_by_tags(
self.module, self.tags, return_tuple=self._return_tuple
)
if remove_tag:
for node in self.module.graph.nodes:
if hasattr(node, "tag"):
@ -875,7 +894,9 @@ class _SplitterBase:
subgraphs = self.remove_small_acc_subgraphs(subgraphs)
acc_subgraphs_count = len([s for s in subgraphs if s.is_acc])
non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count
print(f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs")
print(
f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs"
)
self.tag(subgraphs)
return self.split()

@ -894,5 +915,7 @@ class _SplitterBase:
"result in performance issues."
)

submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names)
submodule_inputs = generate_inputs_for_submodules(
split_module, self.sample_input, submodule_names
)
return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name)

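A short, hedged sketch of the input-capture helper above (assuming a module that has already been split so that a submodule named `_run_on_acc_0` exists; `split_mod` and `sample_inputs` are placeholder names):

```python
from torch.fx.passes.splitter_base import generate_inputs_for_submodules

# Runs one forward pass with forward-pre-hooks installed on the listed
# submodules and records the inputs each one receives.
captured = generate_inputs_for_submodules(split_mod, sample_inputs, ["_run_on_acc_0"])
acc_inputs = captured["_run_on_acc_0"]
```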
@ -26,9 +26,7 @@ class TestPassManager(unittest.TestCase):
def test_these_before_those_pass_constraint(self) -> None:
passes = [lambda x: 2 * x for _ in range(10)]
constraint = these_before_those_pass_constraint(passes[-1], passes[0])
pm = PassManager(
[inplace_wrapper(p) for p in passes]
)
pm = PassManager([inplace_wrapper(p) for p in passes])

# add unfulfillable constraint
pm.add_constraint(constraint)
@ -46,7 +44,7 @@ class TestPassManager(unittest.TestCase):
pm1.add_pass(p)
pm1.add_constraint(constraint)
output1 = pm1(1)
self.assertEqual(output1, 2 ** 3)
self.assertEqual(output1, 2**3)

passes = [lambda x: 3 * x for _ in range(3)]
constraint = these_before_those_pass_constraint(passes[0], passes[1])
@ -55,4 +53,4 @@ class TestPassManager(unittest.TestCase):
pm2.add_pass(p)
pm2.add_constraint(constraint)
output2 = pm2(1)
self.assertEqual(output2, 3 ** 3)
self.assertEqual(output2, 3**3)

@ -1,15 +1,22 @@
# mypy: allow-untyped-defs
from typing import List, Tuple, Union, Dict, Any, Set, Mapping, Optional
import collections
from dataclasses import dataclass
import operator
from dataclasses import dataclass
from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union

import torch
import torch.fx
from torch.fx.node import _get_qualified_name
from torch.fx._compatibility import compatibility
from torch.fx.node import _get_qualified_name

__all__ = ['get_acc_ops_name', 'get_node_target', 'is_node_output_tensor', 'FxNetAccFusionsFinder', 'legalize_graph']

__all__ = [
"get_acc_ops_name",
"get_node_target",
"is_node_output_tensor",
"FxNetAccFusionsFinder",
"legalize_graph",
]

Tensors = Union[Tuple[torch.Tensor], List[torch.Tensor]]
TensorOrTensors = Union[torch.Tensor, Tensors]
@ -26,12 +33,16 @@ def get_acc_ops_name(k):
elif k.__module__ and "acc_ops" in k.__module__:
return f"acc_ops.{k.__name__}"
else:
module = k.__module__.replace('torch._ops', 'torch.ops')  # WAR for bug in how torch.ops assigns module
module = k.__module__.replace(
"torch._ops", "torch.ops"
)  # WAR for bug in how torch.ops assigns module
return f"{module if module else ''}.{k.__name__}"


@compatibility(is_backward_compatible=False)
def get_node_target(submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> str:
def get_node_target(
submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
) -> str:
"""
Given a `node` returns its target typename.

@ -66,6 +77,7 @@ def get_node_target(submodules: Mapping[str, torch.nn.Module], node: torch.fx.No
assert isinstance(node.target, str)
return node.target


@compatibility(is_backward_compatible=False)
def is_node_output_tensor(node: torch.fx.Node) -> bool:
"""Checks if the node output produces a Tensor or not.
@ -77,6 +89,7 @@ def is_node_output_tensor(node: torch.fx.Node) -> bool:
type_ = node.meta.get("type", None)
return type_ is not None and issubclass(type_, torch.Tensor)


@compatibility(is_backward_compatible=False)
class FxNetAccFusionsFinder:
"""
@ -297,7 +310,9 @@ def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
# If the new graph's size is not as large as the old one, then there must be
# a cycle (i.e. some node's dependencies were not satisfied.)
if len(new_graph.nodes) < len(gm.graph.nodes):
raise RuntimeError(f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}")
raise RuntimeError(
f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}"
)
new_graph._codegen = gm.graph._codegen
gm.graph = new_graph
return gm

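A small sketch of when `legalize_graph` is useful: after manual graph surgery the node list may no longer be topologically sorted, and this pass re-sorts it (raising on cycles, per the error above):

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.tools_common import legalize_graph

gm = symbolic_trace(lambda x: x.relu() + 1)
# ... imagine earlier passes appended nodes out of dependency order ...
gm = legalize_graph(gm)  # topologically re-sorts gm.graph and returns gm
```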
@ -1 +1 @@
from .common import lift_subgraph_as_module, HolderModule, compare_graphs
from .common import compare_graphs, HolderModule, lift_subgraph_as_module

@ -3,7 +3,6 @@ from typing import Dict, Tuple

from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph

from torch.fx.graph_module import GraphModule
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
from torch.nn import Module

@ -1,15 +1,16 @@
# mypy: allow-untyped-defs
import copy
from queue import SimpleQueue
from typing import List, Dict, Optional as _Optional, Tuple
from typing import Dict, List, Optional as _Optional, Tuple

import torch.fx
from torch.fx.graph_module import GraphModule
from torch.fx.graph import Graph
from torch.fx.node import Node
from torch.fx.passes.tools_common import NodeList, NodeSet, legalize_graph
from torch.fx.passes.utils import lift_subgraph_as_module
from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx.passes.tools_common import legalize_graph, NodeList, NodeSet
from torch.fx.passes.utils import lift_subgraph_as_module


@compatibility(is_backward_compatible=False)
def topo_sort(nodes: NodeList) -> NodeList:
@ -35,7 +36,9 @@ def topo_sort(nodes: NodeList) -> NodeList:
if indegree_map[n] == 0:
candidates.put(n)

assert len(nodes) == len(sorted_nodes), "topological sorted nodes doesn't have same length as input nodes"
assert len(nodes) == len(
sorted_nodes
), "topological sorted nodes doesn't have same length as input nodes"

return sorted_nodes

@ -96,7 +99,6 @@ def fuse_as_graphmodule(
module_name: str,
partition_lookup_table: _Optional[Dict[Node, None]] = None,
) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]:

"""
Fuse nodes in graph_module into a GraphModule.

@ -121,9 +123,13 @@ def fuse_as_graphmodule(
# assumption: nodes are already sorted in topo order

for node in nodes:
assert node.graph.owning_module is gm, f"{node} doesn't belong to passed in graph module {gm._get_name()}"
assert (
node.graph.owning_module is gm
), f"{node} doesn't belong to passed in graph module {gm._get_name()}"
assert not node._erased, f"{node} has been removed from owning graph"
assert node in gm.graph._find_nodes_lookup_table, f"{node} is not found in graph module {gm._get_name()}"
assert (
node in gm.graph._find_nodes_lookup_table
), f"{node} is not found in graph module {gm._get_name()}"

# validates that the partition doesn't introduce dependency cycles in the graph
assert validate_partition(nodes), "Invalid partition, found dependency cycles"
@ -134,8 +140,10 @@ def fuse_as_graphmodule(

subgraph = Graph()

node_to_placeholder: Dict[Node, Node] = {}  # mapping of nodes from old graph to placeholder in new graph
node_map: Dict[Node, Node] = {}  # mapping of nodes from old graph to new graph
node_to_placeholder: Dict[
Node, Node
] = {}  # mapping of nodes from old graph to placeholder in new graph
node_map: Dict[Node, Node] = {}  # mapping of nodes from old graph to new graph

# handles inputs through graph.node_copy's arg_transform functions
def remap_inputs(x):
@ -184,7 +192,9 @@ def fuse_as_graphmodule(
# lint to ensure correctness
subgraph.lint()
fused_gm: GraphModule
fused_gm, _ = lift_subgraph_as_module(gm, subgraph, comp_name="", class_name=module_name)
fused_gm, _ = lift_subgraph_as_module(
gm, subgraph, comp_name="", class_name=module_name
)

# sub_gm's input nodes in the original module
original_inputs: Tuple[Node, ...] = tuple(node_to_placeholder.keys())
@ -196,16 +206,18 @@ def fuse_as_graphmodule(


@compatibility(is_backward_compatible=False)
def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node, ...]):
def insert_subgm(
gm: GraphModule,
sub_gm: GraphModule,
orig_inputs: Tuple[Node, ...],
orig_outputs: Tuple[Node, ...],
):
# add sub_gm into gm
submodule_name = sub_gm.__class__.__name__
gm.add_submodule(submodule_name, sub_gm)

# Create a call_module node in main graph.
module_node = gm.graph.call_module(
submodule_name,
args=orig_inputs,
kwargs=None)
module_node = gm.graph.call_module(submodule_name, args=orig_inputs, kwargs=None)

if len(orig_outputs) == 1:
# main_remapping[comp.orig_outputs[0]] = module_node
@ -216,24 +228,30 @@ def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node,
proxy_out = torch.fx.Proxy(module_node)[i].node  # type: ignore[index]
orig_output.replace_all_uses_with(proxy_out, propagate_meta=True)

module_node.meta["val"] = tuple(orig_output.meta.get("val", None) for orig_output in orig_outputs)
module_node.meta["val"] = tuple(
orig_output.meta.get("val", None) for orig_output in orig_outputs
)
return gm


@compatibility(is_backward_compatible=False)
def erase_nodes(gm: GraphModule, nodes: NodeList):

# erase original nodes in reversed topological order
for node in reversed(nodes):
gm.graph.erase_node(node)


@compatibility(is_backward_compatible=False)
def fuse_by_partitions(gm: GraphModule, partitions: List[Dict[Node, None]], prefix: str = "fused_") -> GraphModule:
def fuse_by_partitions(
gm: GraphModule, partitions: List[Dict[Node, None]], prefix: str = "fused_"
) -> GraphModule:
for partition_id, partition in enumerate(partitions):
sorted_nodes = topo_sort(list(partition))

submodule_name = prefix + str(partition_id)
sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, submodule_name, partition)
sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(
gm, sorted_nodes, submodule_name, partition
)

insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)

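A minimal sketch tying `topo_sort`, `fuse_as_graphmodule`, and `insert_subgm` together through the `fuse_by_partitions` entry point (partitions are dicts used as ordered sets, matching the signature above):

```python
import operator

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions

def f(x):
    a = x + 1
    b = a * 2
    return b - 3

gm = symbolic_trace(f)
to_fuse = [
    n for n in gm.graph.nodes
    if n.op == "call_function" and n.target in (operator.add, operator.mul)
]
# one partition -> one lifted "fused_0" submodule called from the main graph
fused = fuse_by_partitions(gm, [dict.fromkeys(to_fuse)])
```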
@ -10,6 +10,7 @@ import torch
from torch.fx import Graph, Node
from torch.fx._compatibility import compatibility


__all__ = ["SubgraphMatcher", "InternalMatch"]


@ -1,19 +1,21 @@
from dataclasses import dataclass, field
from torch.fx.graph import Graph
from torch.fx.node import Node
from torch.fx._compatibility import compatibility
from typing import Dict, List, Any, Type, Optional, Callable
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Type

from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph
from torch.fx.node import Node


__all__ = ['get_source_partitions', 'check_subgraphs_connected', 'SourcePartition']
__all__ = ["get_source_partitions", "check_subgraphs_connected", "SourcePartition"]


# Set `PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
def _init_logger() -> logging.Logger:
logger = logging.getLogger(__name__)

level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper()
level = os.environ.get("PYTORCH_MATCHER_LOGLEVEL", "WARNING").upper()
logger.setLevel(level)
console = logging.StreamHandler()
formatter = logging.Formatter("%(filename)s > %(message)s")
@ -24,6 +26,7 @@ def _init_logger() -> logging.Logger:
logger.propagate = False
return logger


logger = _init_logger()


@ -77,8 +80,9 @@ def get_source_partitions(
# be different from "source_fn_stack", for example for the add_ node
# decomposed from batch norm. We should remove the check on "source_fn_stack"
# after we fix "torch_fn". T199561090
if ((source_fn_st := node.meta.get("source_fn_stack", None)) is None and
(torch_fn := node.meta.get("torch_fn", None)) is not None):
if (source_fn_st := node.meta.get("source_fn_stack", None)) is None and (
torch_fn := node.meta.get("torch_fn", None)
) is not None:
node_fqn, source_fn = torch_fn
source_fn_name = source_fn.split(".")[1]
if source_fn_name in wanted_sources:
@ -86,7 +90,6 @@ def get_source_partitions(
partition = diff_modules.setdefault(node_fqn, [])
partition.append(node)


if (source_fn_st := node.meta.get("source_fn_stack", None)) is not None:
source_fn = source_fn_st[-1]
if source_fn[1] in wanted_sources:
@ -140,7 +143,9 @@ def get_source_partitions(


@compatibility(is_backward_compatible=False)  # type: ignore[misc]
def check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool:
def check_subgraphs_connected(
subgraph1: SourcePartition, subgraph2: SourcePartition
) -> bool:
"""
Given two subgraphs A and B (in the form of a list of nodes), checks if
A has nodes connecting to at least one node in B -- aka there exists a node

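A minimal usage sketch for `get_source_partitions` (assuming an exported graph, since export records the `source_fn_stack` / `torch_fn` metadata this matcher reads):

```python
import torch
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x).relu()

ep = torch.export.export(M(), (torch.randn(2, 4),))
# maps source (e.g. torch.nn.Linear) -> list of SourcePartition objects
partitions = get_source_partitions(ep.graph, [torch.nn.Linear])
```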
@ -1,29 +1,37 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
import enum
|
||||
import dis
|
||||
import copy
|
||||
import sys
|
||||
import torch
|
||||
import inspect
|
||||
import operator
|
||||
import collections
|
||||
import copy
|
||||
import dis
|
||||
import enum
|
||||
import inspect
|
||||
import logging
|
||||
import operator
|
||||
import sys
|
||||
from dataclasses import fields, is_dataclass
|
||||
from typing import Any, Callable, Dict, Iterator, Optional, OrderedDict, Tuple
|
||||
|
||||
from dataclasses import is_dataclass, fields
|
||||
|
||||
|
||||
from .graph import magic_methods, reflectable_magic_methods, Graph
|
||||
from torch.utils._traceback import CapturedTraceback
|
||||
from typing import Tuple, Dict, OrderedDict, Optional, Any, Iterator, Callable
|
||||
from .node import Target, Node, Argument, base_types, map_aggregate
|
||||
from ._compatibility import compatibility
|
||||
from .operator_schemas import check_for_mutable_operation
|
||||
import torch
|
||||
import torch.fx.traceback as fx_traceback
|
||||
from torch.utils._traceback import CapturedTraceback
|
||||
|
||||
__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError',
|
||||
'Proxy', 'MetaProxy', 'Attribute', 'ParameterProxy', 'Scope',
|
||||
'ScopeContextManager']
|
||||
from ._compatibility import compatibility
|
||||
from .graph import Graph, magic_methods, reflectable_magic_methods
|
||||
from .node import Argument, base_types, map_aggregate, Node, Target
|
||||
from .operator_schemas import check_for_mutable_operation
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TracerBase",
|
||||
"GraphAppendingTracer",
|
||||
"TraceError",
|
||||
"Proxy",
|
||||
"MetaProxy",
|
||||
"Attribute",
|
||||
"ParameterProxy",
|
||||
"Scope",
|
||||
"ScopeContextManager",
|
||||
]
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -31,7 +39,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
class Scope:
|
||||
""" Scope object that records the module path and the module type
|
||||
"""Scope object that records the module path and the module type
|
||||
of a module. Scope is used to track the information of the module
|
||||
that contains a Node in a Graph of GraphModule. For example::
|
||||
|
||||
@ -41,6 +49,7 @@ class Scope:
|
||||
# scope for this would be (module_path="sub", module_type=Sub)
|
||||
return x.transpose(1, 2)
|
||||
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
self.sub = Sub()
|
||||
@ -62,7 +71,7 @@ class Scope:
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
class ScopeContextManager:
|
||||
""" A context manager to track the Scope of Node during symbolic tracing.
|
||||
"""A context manager to track the Scope of Node during symbolic tracing.
|
||||
When entering a forward function of a Module, we'll update the scope information of
|
||||
the current module, and when we exit, we'll restore the previous scope information.
|
||||
"""
|
||||
@@ -102,28 +111,28 @@ _COPY_META_FIELDS = [
     "quantization_tag",  # TODO deprecated
     "_numeric_debug_handle",  # TODO deprecated
     "custom",
-    "partitioner_tag"
+    "partitioner_tag",
 ]
 
 
 @compatibility(is_backward_compatible=True)
 class TracerBase:
     graph: Graph
-    record_stack_traces : bool = False
+    record_stack_traces: bool = False
     # Feature flag for mutable schema checking
    # Enable by default in 1.12
-    check_mutable_operations : bool = False
+    check_mutable_operations: bool = False
     # Feature flag for assert tracing
-    trace_asserts : bool = False
+    trace_asserts: bool = False
     # Feature flag for proxying accesses to buffer values
-    proxy_buffer_attributes : bool = False
+    proxy_buffer_attributes: bool = False
 
     # Name of the function to be traced. It will only be used when
     # ``root`` is an instance of ``nn.Module``
     traced_func_name: str = "forward"
 
     # Maps the containing module's name to the operator name
-    scope : Scope
+    scope: Scope
 
     # Records the module call stack
     module_stack: OrderedDict[str, Tuple[str, Any]]
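The hunk above only normalizes whitespace around `TracerBase`'s feature flags in torch/fx/proxy.py; no behavior changes. As an illustrative aside, these flags are ordinary class attributes that subclasses flip. A minimal sketch, assuming only the public `torch.fx.Tracer` API (which inherits these flags from `TracerBase`):

```python
import torch
import torch.fx


class StackTraceTracer(torch.fx.Tracer):
    # TracerBase feature flag: record a stack trace on every node
    # created during tracing (consumed by create_proxy in this file).
    record_stack_traces = True


class M(torch.nn.Module):
    def forward(self, x):
        return x.relu() + 1


graph = StackTraceTracer().trace(M())
for node in graph.nodes:
    if node.op not in ("placeholder", "output"):
        assert node.stack_trace  # populated because of the flag above
```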
@@ -132,9 +141,15 @@ class TracerBase:
     node_name_to_scope: Dict[str, Tuple[str, type]]
 
     @compatibility(is_backward_compatible=True)
-    def create_node(self, kind : str, target : Target,
-                    args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
-                    type_expr : Optional[Any] = None) -> Node:
+    def create_node(
+        self,
+        kind: str,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: Optional[str] = None,
+        type_expr: Optional[Any] = None,
+    ) -> Node:
         """
         Inserts a graph node given target, args, kwargs, and name.
 
@@ -143,7 +158,7 @@ class TracerBase:
         want to disallow in-place operations from being recorded.
         """
 
-        if kind == 'call_function' and self.check_mutable_operations:
+        if kind == "call_function" and self.check_mutable_operations:
             check_for_mutable_operation(target, args, kwargs)
 
         node = self.graph.create_node(kind, target, args, kwargs, name, type_expr)
@@ -182,20 +197,27 @@ class TracerBase:
                 node.meta["seq_nr"] = new_seq_nr
 
         elif self.module_stack:
-            node.meta['nn_module_stack'] = copy.copy(self.module_stack)
+            node.meta["nn_module_stack"] = copy.copy(self.module_stack)
 
         log.debug("create_node %s", node)
         return node
 
     @compatibility(is_backward_compatible=True)
-    def proxy(self, node: Node) -> 'Proxy':
+    def proxy(self, node: Node) -> "Proxy":
         return Proxy(node, self)
 
     @compatibility(is_backward_compatible=True)
-    def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
-                     name: Optional[str] = None, type_expr : Optional[Any] = None,
-                     proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
-        '''
+    def create_proxy(
+        self,
+        kind: str,
+        target: Target,
+        args: Tuple[Any, ...],
+        kwargs: Dict[str, Any],
+        name: Optional[str] = None,
+        type_expr: Optional[Any] = None,
+        proxy_factory_fn: Callable[[Node], "Proxy"] = None,
+    ):
+        """
         Create a Node from the given arguments, then return the Node
         wrapped in a Proxy object.
 
@@ -203,7 +225,7 @@ class TracerBase:
         represents the parameter of a function. If we need to encode
         a default parameter, we use the ``args`` tuple. ``args`` is
         otherwise empty for ``placeholder`` Nodes.
-        '''
+        """
 
         args_ = self.create_arg(args)
         kwargs_ = self.create_arg(kwargs)
@@ -218,8 +240,7 @@ class TracerBase:
             proxy = proxy_factory_fn(node)
 
         if self.record_stack_traces and not proxy.node.stack_trace:
-            proxy.node.stack_trace = ''.join(CapturedTraceback.extract().format())
-
+            proxy.node.stack_trace = "".join(CapturedTraceback.extract().format())
 
         return proxy
 
@@ -233,20 +254,23 @@ class TracerBase:
         # the user code during tracing.
         frame = inspect.currentframe()
 
-        pt_files = ['torch/fx/proxy.py',
-                    'torch/fx/_symbolic_trace.py',
-                    'torch/fx/experimental/proxy_tensor.py',
-                    'torch/_ops.py',
-                    'torch/_tensor.py',
-                    'torch/utils/_python_dispatch.py',
-                    'torch/_prims_common/wrappers.py',
-                    'torch/_refs/__init__.py',
-                    'torch/_refs/nn/functional/__init__.py',
-                    'torch/utils/_stats.py',
-                    ]
+        pt_files = [
+            "torch/fx/proxy.py",
+            "torch/fx/_symbolic_trace.py",
+            "torch/fx/experimental/proxy_tensor.py",
+            "torch/_ops.py",
+            "torch/_tensor.py",
+            "torch/utils/_python_dispatch.py",
+            "torch/_prims_common/wrappers.py",
+            "torch/_refs/__init__.py",
+            "torch/_refs/nn/functional/__init__.py",
+            "torch/utils/_stats.py",
+        ]
         while frame:
             frame = frame.f_back
-            if frame and all(not frame.f_code.co_filename.endswith(file) for file in pt_files):
+            if frame and all(
+                not frame.f_code.co_filename.endswith(file) for file in pt_files
+            ):
                 break
 
         if not frame:
@@ -264,11 +288,11 @@ class TracerBase:
         """
         if isinstance(a, Proxy):
             return a.node  # most common arg type goes first
-        elif hasattr(a, '__fx_create_arg__'):
+        elif hasattr(a, "__fx_create_arg__"):
             return a.__fx_create_arg__(self)
         # aggregates
         elif isinstance(a, tuple):
-            if hasattr(a, '_fields'):
+            if hasattr(a, "_fields"):
                 # NamedTuple constructors don't seem to like getting a generator
                 # expression as an argument to their constructor, so build this
                 # intermediate tuple and unpack it into the NamedTuple constructor
@@ -278,10 +302,13 @@ class TracerBase:
         elif isinstance(a, list):
             return [self.create_arg(elem) for elem in a]
         elif isinstance(a, dict):
+
             def no_node(arg):
                 if isinstance(arg, Node):
-                    raise RuntimeError("Keys for dictionaries used as an argument cannot contain a "
-                                       f"Node. Got key: {k}")
+                    raise RuntimeError(
+                        "Keys for dictionaries used as an argument cannot contain a "
+                        f"Node. Got key: {k}"
+                    )
 
             r = {}
             for k, v in a.items():
@@ -294,16 +321,27 @@ class TracerBase:
                 r[k] = self.create_arg(v)
             return r
         elif isinstance(a, slice):
-            return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))
+            return slice(
+                self.create_arg(a.start),
+                self.create_arg(a.stop),
+                self.create_arg(a.step),
+            )
 
         elif isinstance(a, range):
-            return range(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))
+            return range(
+                self.create_arg(a.start),
+                self.create_arg(a.stop),
+                self.create_arg(a.step),
+            )
 
         elif isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
             return a
 
         elif is_dataclass(a):
-            kwargs = {field.name: self.create_arg(getattr(a, field.name)) for field in fields(a)}
+            kwargs = {
+                field.name: self.create_arg(getattr(a, field.name))
+                for field in fields(a)
+            }
             return self.create_node("call_function", a.__class__, (), kwargs)
 
         elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...:
@@ -312,37 +350,41 @@ class TracerBase:
         raise NotImplementedError(f"argument of type: {type(a)}")
 
     @compatibility(is_backward_compatible=True)
-    def to_bool(self, obj: 'Proxy') -> bool:
+    def to_bool(self, obj: "Proxy") -> bool:
         """Called when a proxy object is being converted to a boolean, such as
         when used in control flow. Normally we don't know what to do because
         we don't know the value of the proxy, but a custom tracer can attach more
         information to the graph node using create_node and can choose to return a value.
         """
-        raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
+        raise TraceError(
+            "symbolically traced variables cannot be used as inputs to control flow"
+        )
 
     @compatibility(is_backward_compatible=True)
-    def iter(self, obj: 'Proxy') -> Iterator:
+    def iter(self, obj: "Proxy") -> Iterator:
         """Called when a proxy object is being iterated over, such as
         when used in control flow. Normally we don't know what to do because
         we don't know the value of the proxy, but a custom tracer can attach more
         information to the graph node using create_node and can choose to return an iterator.
         """
-        raise TraceError('Proxy object cannot be iterated. This can be '
-                         'attempted when the Proxy is used in a loop or'
-                         ' as a *args or **kwargs function argument. '
-                         'See the torch.fx docs on pytorch.org for a '
-                         'more detailed explanation of what types of '
-                         'control flow can be traced, and check out the'
-                         ' Proxy docstring for help troubleshooting '
-                         'Proxy iteration errors')
+        raise TraceError(
+            "Proxy object cannot be iterated. This can be "
+            "attempted when the Proxy is used in a loop or"
+            " as a *args or **kwargs function argument. "
+            "See the torch.fx docs on pytorch.org for a "
+            "more detailed explanation of what types of "
+            "control flow can be traced, and check out the"
+            " Proxy docstring for help troubleshooting "
+            "Proxy iteration errors"
+        )
 
     @compatibility(is_backward_compatible=True)
-    def keys(self, obj: 'Proxy') -> Any:
+    def keys(self, obj: "Proxy") -> Any:
         """Called when a proxy object is has the keys() method called.
         This is what happens when ** is called on a proxy. This should return an
         iterator it ** is suppose to work in your custom tracer.
         """
-        return Attribute(obj, 'keys')()
+        return Attribute(obj, "keys")()
 
 
 # used in Proxy object when just appending to the graph while not tracing.
@@ -355,14 +397,17 @@ class GraphAppendingTracer(TracerBase):
         self.module_stack = collections.OrderedDict()
         self.node_name_to_scope = {}
 
+
 @compatibility(is_backward_compatible=False)
 def assert_fn(x):
     assert x
 
+
 @compatibility(is_backward_compatible=True)
 class TraceError(ValueError):
     pass
 
+
 @compatibility(is_backward_compatible=True)
 class Proxy:
     """
@@ -394,7 +439,7 @@ class Proxy:
     """
 
     @compatibility(is_backward_compatible=True)
-    def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None):
+    def __init__(self, node: Node, tracer: "Optional[TracerBase]" = None):
         if tracer is None:
             # This allows you to create a Proxy object around a raw Node
             tracer = GraphAppendingTracer(node.graph)
@@ -402,9 +447,9 @@ class Proxy:
         self.node = node
 
     def __repr__(self) -> str:
-        return f'Proxy({self.node.name})'
+        return f"Proxy({self.node.name})"
 
-    def __getattr__(self, k) -> 'Attribute':
+    def __getattr__(self, k) -> "Attribute":
         # note: not added to the graph yet, if this is a method call
         # we peephole optimize to the method invocation
         return Attribute(self, k)
@@ -417,6 +462,7 @@ class Proxy:
         # will go to __getattr__(self, "__deepcopy__") and return a
         # Attribute(__deepcopy__), and may go into an infinite loop in some cases.
         import copy
+
         new_dict = {}
         for k, v in self.__dict__.items():
             try:
@@ -424,7 +470,10 @@ class Proxy:
             except Exception:
                 log.warning(
                     "Shallow copy %s of Proxy because it cannot be deepcopied. "
-                    "Proxy is created for node %s", k, self.node.name)
+                    "Proxy is created for node %s",
+                    k,
+                    self.node.name,
+                )
             new_obj = copy.copy(v)
             new_dict[k] = new_obj
         assert "node" in new_dict
@@ -438,10 +487,12 @@ class Proxy:
         # This is called when being unpickled/loaded.
         self.__dict__ = d
 
-    def __call__(self, *args, **kwargs) -> 'Proxy':
-        return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs)
+    def __call__(self, *args, **kwargs) -> "Proxy":
+        return self.tracer.create_proxy(
+            "call_method", "__call__", (self,) + args, kwargs
+        )
 
-    def __iter__(self) -> Iterator['Proxy']:
+    def __iter__(self) -> Iterator["Proxy"]:
         frame = inspect.currentframe()
         assert frame is not None
         calling_frame = frame.f_back
@@ -449,17 +500,20 @@ class Proxy:
         inst_list = list(dis.get_instructions(calling_frame.f_code))
         if sys.version_info >= (3, 11):
             from bisect import bisect_left
-            inst_idx = bisect_left(inst_list, calling_frame.f_lasti, key=lambda x: x.offset)
+
+            inst_idx = bisect_left(
+                inst_list, calling_frame.f_lasti, key=lambda x: x.offset
+            )
         else:
             inst_idx = calling_frame.f_lasti // 2
         inst = inst_list[inst_idx]
-        if inst.opname == 'UNPACK_SEQUENCE':
+        if inst.opname == "UNPACK_SEQUENCE":
             return (self[i] for i in range(inst.argval))  # type: ignore[index]
 
         return self.tracer.iter(self)
 
     def __abs__(self):
-        return self.tracer.create_proxy('call_function', operator.abs, (self,), {})
+        return self.tracer.create_proxy("call_function", operator.abs, (self,), {})
 
     def __bool__(self) -> bool:
         if self.tracer.trace_asserts:
@@ -472,19 +526,23 @@ class Proxy:
             insts = list(dis.get_instructions(calling_frame.f_code))
             if sys.version_info >= (3, 11):
                 from bisect import bisect_left
+
                 cur = bisect_left(insts, calling_frame.f_lasti, key=lambda x: x.offset)
             else:
                 cur = calling_frame.f_lasti // 2
             inst = insts[cur]
 
-            if inst.opname == 'POP_JUMP_IF_TRUE':
+            if inst.opname == "POP_JUMP_IF_TRUE":
                 first = insts[cur + 1]
                 assert inst.arg is not None
                 last = insts[inst.arg // 2 - 1]
-                starts_with_assert = (first.opname == 'LOAD_GLOBAL' and first.argval == 'AssertionError'
-                                      or first.opname == 'LOAD_ASSERTION_ERROR')
-                if starts_with_assert and last.opname == 'RAISE_VARARGS':
-                    self.tracer.create_proxy('call_function', assert_fn, (self,), {})
+                starts_with_assert = (
+                    first.opname == "LOAD_GLOBAL"
+                    and first.argval == "AssertionError"
+                    or first.opname == "LOAD_ASSERTION_ERROR"
+                )
+                if starts_with_assert and last.opname == "RAISE_VARARGS":
+                    self.tracer.create_proxy("call_function", assert_fn, (self,), {})
                 return True
 
         return self.tracer.to_bool(self)
@@ -494,39 +552,51 @@ class Proxy:
         return self.tracer.keys(self)
 
     def __len__(self):
-        raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
-                           "this call to be recorded, please call torch.fx.wrap('len') at "
-                           "module scope")
+        raise RuntimeError(
+            "'len' is not supported in symbolic tracing by default. If you want "
+            "this call to be recorded, please call torch.fx.wrap('len') at "
+            "module scope"
+        )
 
     @classmethod
     def __torch_function__(cls, orig_method, types, args=None, kwargs=None):
         args = args if args else ()
         kwargs = kwargs if kwargs else {}
 
-        tracers : Dict[Any, None] = {}
+        tracers: Dict[Any, None] = {}
 
         def find_tracer(a):
             if isinstance(a, cls):
                 tracers[a.tracer] = None
+
         torch.fx.node.map_aggregate(args, find_tracer)
         torch.fx.node.map_aggregate(kwargs, find_tracer)
 
         if len(tracers) > 1:
-            raise RuntimeError(f'Found multiple different tracers {list(tracers.keys())} while '
-                               f'trying to trace operations {orig_method}')
+            raise RuntimeError(
+                f"Found multiple different tracers {list(tracers.keys())} while "
+                f"trying to trace operations {orig_method}"
+            )
         tracer = next(iter(tracers.keys()))
 
         if isinstance(orig_method, torch._C.ScriptMethod):
             args = (orig_method.owner,) + args
-            return tracer.create_proxy('call_method', orig_method.name, args, kwargs)
+            return tracer.create_proxy("call_method", orig_method.name, args, kwargs)
         if torch.overrides.is_tensor_method_or_property(orig_method):
-            return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs)
+            return tracer.create_proxy(
+                "call_method", orig_method.__name__, args, kwargs
+            )
         else:
             if isinstance(orig_method, torch._ops.HigherOrderOperator):
                 # TODO: Define how to symbolically trace HigherOrderOperators
                 raise RuntimeError("Unable to symbolically trace HigherOrderOperators")
-            return tracer.create_proxy('call_function', orig_method, args, kwargs,
-                                       name=tracer.graph._target_to_str(orig_method.__name__))
+            return tracer.create_proxy(
+                "call_function",
+                orig_method,
+                args,
+                kwargs,
+                name=tracer.graph._target_to_str(orig_method.__name__),
+            )
 
 
 @compatibility(is_backward_compatible=False)
@@ -535,12 +605,14 @@ class MetaProxy(Proxy):
     A Proxy subclass that propagates metadata (meta['val']) during graph tracing.
     """
 
-    def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None, fake_mode=None):
+    def __init__(
+        self, node: Node, tracer: "Optional[TracerBase]" = None, fake_mode=None
+    ):
         super().__init__(node, tracer)
         self.fake_mode = fake_mode
 
     def __repr__(self) -> str:
-        return f'MetaProxy({self.node.name})'
+        return f"MetaProxy({self.node.name})"
 
     @classmethod
     def __torch_function__(cls, orig_method, types, args=None, kwargs=None):
@@ -553,16 +625,19 @@ class MetaProxy(Proxy):
                 meta_proxy = arg
                 break
 
-        assert meta_proxy is not None, "No MetaProxy found in arguments, but one is expected."
+        assert (
+            meta_proxy is not None
+        ), "No MetaProxy found in arguments, but one is expected."
 
         proxy = super().__torch_function__(orig_method, types, args, kwargs)
         with meta_proxy.fake_mode:
             proxy.node.meta["val"] = orig_method(
                 *[a.node.meta["val"] if isinstance(a, Proxy) else a for a in args],
-                **kwargs
+                **kwargs,
             )
         return MetaProxy(proxy.node, proxy.tracer, meta_proxy.fake_mode)
 
 
 @compatibility(is_backward_compatible=True)
 class Attribute(Proxy):
     @compatibility(is_backward_compatible=True)
@@ -577,11 +652,15 @@ class Attribute(Proxy):
         # the node for attributes is added lazily, since most will just be method calls
         # which do not rely on the getitem call
         if self._node is None:
-            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
+            self._node = self.tracer.create_proxy(
+                "call_function", getattr, (self.root, self.attr), {}
+            ).node
         return self._node
 
     def __call__(self, *args, **kwargs):
-        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
+        return self.tracer.create_proxy(
+            "call_method", self.attr, (self.root,) + args, kwargs
+        )
 
 
 @compatibility(is_backward_compatible=False)
@@ -591,6 +670,7 @@ class ParameterProxy(Proxy):
     attribute accesses pass through to the underlying module parameter object,
     so that conditional tests on these attributes will not throw exception during tracing
     """
+
     def __init__(self, tracer: TracerBase, node: Node, name, param):
         super().__init__(node, tracer)
         assert isinstance(param, torch.nn.Parameter)
@@ -598,7 +678,7 @@ class ParameterProxy(Proxy):
         self.name = name
 
     def __repr__(self) -> str:
-        return f'ParameterProxy({self.name})'
+        return f"ParameterProxy({self.name})"
 
     @property
     def shape(self):
@@ -622,25 +702,31 @@ class ParameterProxy(Proxy):
 
 
 for method in magic_methods:
+
     def _scope(method):
         def impl(*args, **kwargs):
            tracer = args[0].tracer
            target = getattr(operator, method)
-            return tracer.create_proxy('call_function', target, args, kwargs)
+            return tracer.create_proxy("call_function", target, args, kwargs)
+
         impl.__name__ = method
         as_magic = f'__{method.strip("_")}__'
         setattr(Proxy, as_magic, impl)
+
     _scope(method)
 
+
 def _define_reflectable(orig_method_name):
     method_name = f'__r{orig_method_name.strip("_")}__'
 
     def impl(self, rhs):
         target = getattr(operator, orig_method_name)
-        return self.tracer.create_proxy('call_function', target, (rhs, self), {})
+        return self.tracer.create_proxy("call_function", target, (rhs, self), {})
+
     impl.__name__ = method_name
     impl.__qualname__ = method_name
     setattr(Proxy, method_name, impl)
 
+
 for orig_method_name in reflectable_magic_methods:
     _define_reflectable(orig_method_name)
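That closes out the torch/fx/proxy.py hunks: the loops above install the `operator` magic methods on `Proxy`, and the changes are purely quoting and layout. The recorded behavior — arithmetic on a traced value becoming `call_function` nodes — is unchanged. A minimal sketch of that behavior through the public `torch.fx.symbolic_trace` entry point (the `AddMul` module is an illustrative stand-in):

```python
import operator

import torch
import torch.fx


class AddMul(torch.nn.Module):
    def forward(self, x):
        # During tracing, `x` is a Proxy; `+` and `*` resolve to the
        # magic-method wrappers installed on Proxy above and are recorded
        # as call_function nodes targeting operator.add / operator.mul.
        return x + x * 2


gm = torch.fx.symbolic_trace(AddMul())
targets = [n.target for n in gm.graph.nodes if n.op == "call_function"]
assert operator.add in targets and operator.mul in targets
```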
@@ -1,18 +1,36 @@
-from .graph_module import GraphModule
-from .graph import Graph
-from .node import Node
-from ._symbolic_trace import symbolic_trace
-from ._compatibility import compatibility
-
 import copy
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union, TYPE_CHECKING
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Set,
+    TYPE_CHECKING,
+    Union,
+)
 
 import torch
 
+from ._compatibility import compatibility
+from ._symbolic_trace import symbolic_trace
+from .graph import Graph
+from .graph_module import GraphModule
+from .node import Node
+
+
 if TYPE_CHECKING:
     from .passes.utils.matcher_with_name_node_map_utils import InternalMatch
 
-__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters', "ReplacedPatterns"]
+__all__ = [
+    "Match",
+    "replace_pattern",
+    "replace_pattern_with_filters",
+    "ReplacedPatterns",
+]
 
 
 @compatibility(is_backward_compatible=True)
 class Match(NamedTuple):
@@ -21,6 +39,7 @@ class Match(NamedTuple):
     # Maps nodes in the pattern subgraph to nodes in the larger graph
     nodes_map: Dict[Node, Node]
 
+
 @compatibility(is_backward_compatible=False)
 @dataclass
 class ReplacedPatterns:
@@ -31,6 +50,7 @@ class ReplacedPatterns:
     # List of nodes that were added into the graph
     replacements: List[Node]
 
+
 def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
     gm.delete_all_unused_submodules()
 
@@ -48,7 +68,6 @@ def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
 
     for node in gm.graph.nodes:
         if node.op == "call_module" or node.op == "get_attr":
-
             gm_attr = try_get_attr(gm, node.target)
             replacement_attr = try_get_attr(replacement, node.target)
 
@@ -70,11 +89,14 @@ def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
         # CASE 3: The target doesn't exist as an attribute in `gm`
         # or `replacement`
         else:
-            raise RuntimeError('Attempted to create a "', node.op,
-                               '" node during subgraph rewriting '
-                               f"with target {node.target}, but "
-                               "the referenced attribute does not "
-                               "exist in the replacement GraphModule")
+            raise RuntimeError(
+                'Attempted to create a "',
+                node.op,
+                '" node during subgraph rewriting '
+                f"with target {node.target}, but "
+                "the referenced attribute does not "
+                "exist in the replacement GraphModule",
+            )
 
     gm.graph.lint()
 
@@ -83,7 +105,7 @@ def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
 def replace_pattern(
     gm: GraphModule,
     pattern: Union[Callable, GraphModule],
-    replacement: Union[Callable, GraphModule]
+    replacement: Union[Callable, GraphModule],
 ) -> List[Match]:
     """
     Matches all possible non-overlapping sets of operators and their
@@ -116,6 +138,7 @@ def replace_pattern(
         import torch
         from torch.fx import symbolic_trace, subgraph_rewriter
 
+
         class M(torch.nn.Module):
             def __init__(self) -> None:
                 super().__init__()
@@ -125,12 +148,15 @@ def replace_pattern(
             m2 = torch.cat([w1, w2]).sum()
             return x + torch.max(m1) + torch.max(m2)
 
+
         def pattern(w1, w2):
             return torch.cat([w1, w2]).sum()
 
+
         def replacement(w1, w2):
             return torch.stack([w1, w2])
 
+
         traced_module = symbolic_trace(M())
 
         subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
@@ -199,7 +225,9 @@ def replace_pattern(
         return add_2
     """
     match_and_replacements = _replace_pattern(gm, pattern, replacement)
-    return [Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements]
+    return [
+        Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements
+    ]
 
 
 # Experimental API, not backward compatible
@@ -208,10 +236,14 @@ def replace_pattern_with_filters(
     gm: GraphModule,
     pattern: Union[Callable, Graph, GraphModule],
     replacement: Union[Callable, Graph, GraphModule, None] = None,
-    match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,
+    match_filters: Optional[
+        List[Callable[["InternalMatch", Graph, Graph], bool]]
+    ] = None,
     ignore_literals: bool = False,
     # Placed at the end to avoid breaking backward compatibility
-    replacement_callback: Optional[Callable[["InternalMatch", Graph, Graph], Graph]] = None,
+    replacement_callback: Optional[
+        Callable[["InternalMatch", Graph, Graph], Graph]
+    ] = None,
 ) -> List[ReplacedPatterns]:
     """
     See replace_pattern for documentation. This function is an overload with an additional match_filter argument.
@@ -226,20 +258,25 @@ def replace_pattern_with_filters(
         replacement graph based on the match.
     """
 
-    return _replace_pattern(gm, pattern, replacement, match_filters, ignore_literals, replacement_callback)
+    return _replace_pattern(
+        gm, pattern, replacement, match_filters, ignore_literals, replacement_callback
+    )
 
 
 def _replace_pattern(
     gm: GraphModule,
     pattern: Union[Callable, Graph, GraphModule],
     replacement: Union[Callable, Graph, GraphModule, None] = None,
-    match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,
+    match_filters: Optional[
+        List[Callable[["InternalMatch", Graph, Graph], bool]]
+    ] = None,
     ignore_literals: bool = False,
     # Placed at the end to avoid breaking backward compatibility
-    replacement_callback: Optional[Callable[["InternalMatch", Graph, Graph], Graph]] = None,
+    replacement_callback: Optional[
+        Callable[["InternalMatch", Graph, Graph], Graph]
+    ] = None,
 ) -> List[ReplacedPatterns]:
-
-    from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch
+    from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher
 
     if match_filters is None:
         match_filters = []
@@ -254,15 +291,23 @@ def _replace_pattern(
     else:
         pattern_graph = symbolic_trace(pattern).graph
 
-    matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False,
-                              remove_overlapping_matches=True, ignore_literals=ignore_literals)
+    matcher = SubgraphMatcher(
+        pattern_graph,
+        match_output=False,
+        match_placeholder=False,
+        remove_overlapping_matches=True,
+        ignore_literals=ignore_literals,
+    )
     _matches: List[InternalMatch] = matcher.match(original_graph)
 
     # Filter out matches that don't match the filter
     _matches = [
-        m for m in _matches
-        if all(match_filter(m, original_graph, pattern_graph)
-               for match_filter in match_filters)
+        m
+        for m in _matches
+        if all(
+            match_filter(m, original_graph, pattern_graph)
+            for match_filter in match_filters
+        )
     ]
 
     if isinstance(replacement, GraphModule):
@@ -272,7 +317,9 @@ def _replace_pattern(
     elif callable(replacement):
         common_replacement_graph = symbolic_trace(replacement).graph
     else:
-        assert replacement_callback is not None, "Must provide either a replacement GraphModule or a replacement callback"
+        assert (
+            replacement_callback is not None
+        ), "Must provide either a replacement GraphModule or a replacement callback"
         common_replacement_graph = None
 
     # As we progressively replace nodes, we'll need to keep track of how the match results should change
@@ -281,11 +328,17 @@ def _replace_pattern(
     match_and_replacements = []
     for match in _matches:
         if replacement_callback is not None:
-            replacement_graph = replacement_callback(match, original_graph, pattern_graph)
+            replacement_graph = replacement_callback(
+                match, original_graph, pattern_graph
+            )
         else:
-            assert common_replacement_graph is not None, "Must provide either a replacement GraphModule or a replacement callback"
+            assert (
+                common_replacement_graph is not None
+            ), "Must provide either a replacement GraphModule or a replacement callback"
             replacement_graph = common_replacement_graph
-        replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"]
+        replacement_placeholders = [
+            n for n in replacement_graph.nodes if n.op == "placeholder"
+        ]
 
         # Build connecting between replacement graph's input and original graph input producer node
 
@@ -300,7 +353,9 @@ def _replace_pattern(
                 # Update match.placeholder_nodes and match.nodes_map with the node that replaced gn
                 gn_ind = match.placeholder_nodes.index(gn)
                 match.placeholder_nodes[gn_ind] = match_changed_node[gn]
-                map_key = list(match.nodes_map.keys())[list(match.nodes_map.values()).index(gn)]
+                map_key = list(match.nodes_map.keys())[
+                    list(match.nodes_map.values()).index(gn)
+                ]
                 match.nodes_map[map_key] = match_changed_node[gn]
             else:
                 val_map[rn] = gn
@@ -322,13 +377,17 @@ def _replace_pattern(
                 break
 
         with original_graph.inserting_before(first_user_node):  # type: ignore[possibly-undefined]
-            copied_returning_nodes = original_graph.graph_copy(replacement_graph, val_map)
+            copied_returning_nodes = original_graph.graph_copy(
+                replacement_graph, val_map
+            )
 
         if isinstance(copied_returning_nodes, Node):
-            copied_returning_nodes = (copied_returning_nodes, )
+            copied_returning_nodes = (copied_returning_nodes,)
 
         # Get a list of nodes that have been replaced into the graph
-        replacement_nodes: List[Node] = [v for v in val_map.values() if v not in match.placeholder_nodes]
+        replacement_nodes: List[Node] = [
+            v for v in val_map.values() if v not in match.placeholder_nodes
+        ]
 
         # Hook the output Node of the replacement subgraph in to the
        # original Graph at the correct location
@@ -346,7 +405,7 @@ def _replace_pattern(
             ReplacedPatterns(
                 anchor=match.anchors[0],
                 nodes_map=match.nodes_map,
-                replacements=replacement_nodes
+                replacements=replacement_nodes,
             )
         )
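The hunks above reformat torch/fx/subgraph_rewriter.py, including the signatures of `replace_pattern_with_filters` and `_replace_pattern`. A hedged sketch of the filter hook those signatures describe — the `only_cat_matches` helper is hypothetical, written against the `Callable[["InternalMatch", Graph, Graph], bool]` type shown above:

```python
import torch
from torch.fx import subgraph_rewriter, symbolic_trace


class M(torch.nn.Module):
    def forward(self, x, w1, w2):
        return x + torch.cat([w1, w2]).sum()


def pattern(w1, w2):
    return torch.cat([w1, w2]).sum()


def replacement(w1, w2):
    return torch.stack([w1, w2]).sum()


def only_cat_matches(match, original_graph, pattern_graph) -> bool:
    # A match_filter can veto a candidate match by returning False;
    # here we only accept matches that actually contain a torch.cat node.
    return any(
        n.op == "call_function" and n.target is torch.cat
        for n in match.nodes_map.values()
    )


traced = symbolic_trace(M())
subgraph_rewriter.replace_pattern_with_filters(
    traced, pattern, replacement, match_filters=[only_cat_matches]
)
traced.graph.lint()  # the rewritten graph should still be well formed
```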
@@ -19,7 +19,7 @@ class TensorType:
         self.__args__ = dim
 
     def __repr__(self):
-        return f'TensorType[{self.__args__}]'
+        return f"TensorType[{self.__args__}]"
 
     def __eq__(self, other):
         if isinstance(other, self.__class__):
@@ -38,8 +38,9 @@ class _DynType:
     """
     _DynType defines a type which stands for the absence of type information.
     """
+
     def __init__(self) -> None:
-        self.__name__ = '_DynType'
+        self.__name__ = "_DynType"
 
     def __eq__(self, other):
         return isinstance(other, self.__class__)
@@ -53,6 +54,7 @@ class _DynType:
 
 Dyn = _DynType()
 
+
 @compatibility(is_backward_compatible=False)
 def is_consistent(t1, t2):
     """
@@ -73,8 +75,10 @@ def is_consistent(t1, t2):
         return True
 
     if isinstance(t1, TensorType) and isinstance(t2, TensorType):
-        return len(t1.__args__) == len(t2.__args__) and \
-            all(is_consistent(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__))
+        return len(t1.__args__) == len(t2.__args__) and all(
+            is_consistent(elem1, elem2)
+            for elem1, elem2 in zip(t1.__args__, t2.__args__)
+        )
     else:
         return False
 
@@ -98,8 +102,10 @@ def is_more_precise(t1, t2):
         return True
 
     if isinstance(t1, TensorType) and isinstance(t2, TensorType):
-        return len(t1.__args__) == len(t2.__args__) and \
-            all(is_more_precise(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__))
+        return len(t1.__args__) == len(t2.__args__) and all(
+            is_more_precise(elem1, elem2)
+            for elem1, elem2 in zip(t1.__args__, t2.__args__)
+        )
 
     else:
         return False
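The four hunks above touch torch/fx/tensor_type.py; only line layout changes. A short sketch of the consistency relation they reflow, assuming the constructor and helpers behave exactly as the code above shows:

```python
from torch.fx.tensor_type import Dyn, TensorType, is_consistent, is_more_precise

t1 = TensorType((1, 2, 3))
t2 = TensorType((1, Dyn, 3))

# Dyn stands for the absence of type information, so it is consistent with
# any dimension; consistency is checked element-wise over the zipped dims.
assert is_consistent(t1, t2)
# t1 fixes the middle dimension that t2 leaves dynamic, so it is more precise.
assert is_more_precise(t1, t2)
```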
@@ -1,12 +1,21 @@
 # mypy: allow-untyped-defs
 import traceback
 from contextlib import contextmanager
-from typing import List, Any, Dict
+from typing import Any, Dict, List
 
 from ._compatibility import compatibility
 
-__all__ = ['preserve_node_meta', 'has_preserved_node_meta',
-           'set_stack_trace', 'set_grad_fn_seq_nr', 'reset_grad_fn_seq_nr',
-           'format_stack', 'set_current_meta', 'get_current_meta']
+
+__all__ = [
+    "preserve_node_meta",
+    "has_preserved_node_meta",
+    "set_stack_trace",
+    "set_grad_fn_seq_nr",
+    "reset_grad_fn_seq_nr",
+    "format_stack",
+    "set_current_meta",
+    "get_current_meta",
+]
 
 current_meta: Dict[str, Any] = {}
 should_preserve_node_meta = False
@@ -30,7 +39,7 @@ def preserve_node_meta():
 
 
 @compatibility(is_backward_compatible=False)
-def set_stack_trace(stack : List[str]):
+def set_stack_trace(stack: List[str]):
     global current_meta
 
     if should_preserve_node_meta and stack:
@@ -43,7 +52,9 @@ def set_grad_fn_seq_nr(seq_nr):
 
     if should_preserve_node_meta:
         # The seq_nr is captured by eager mode in the grad_fn during forward
-        current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [seq_nr]
+        current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [
+            seq_nr
+        ]
         current_meta["in_grad_fn"] = current_meta.get("in_grad_fn", 0) + 1
 
@@ -90,7 +101,9 @@ def set_current_meta(node):
             if "from_node" not in current_meta:
                 current_meta["from_node"] = [(node.name, node.target)]
             elif current_meta["from_node"][-1][0] != node.name:
-                current_meta["from_node"] = current_meta["from_node"] + [(node.name, node.target)]
+                current_meta["from_node"] = current_meta["from_node"] + [
+                    (node.name, node.target)
+                ]
 
             yield
         finally:
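The final hunks reflow torch/fx/traceback.py's `__all__` list and a few long assignments. A brief, hedged sketch of the context-manager API reformatted above, using only names that appear in that `__all__`:

```python
import torch.fx.traceback as fx_traceback

# Passes wrap node-by-node graph transformations in preserve_node_meta()
# so stack traces and other metadata carry over to newly created nodes.
with fx_traceback.preserve_node_meta():
    assert fx_traceback.has_preserved_node_meta()
    fx_traceback.set_stack_trace(["File example.py, line 1, in forward\n"])
    print(fx_traceback.format_stack())
```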