diff --git a/.lintrunner.toml b/.lintrunner.toml index e4afff558bf6..1e87d6f9c0ad 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1232,87 +1232,6 @@ exclude_patterns = [ 'torch/fft/__init__.py', 'torch/func/__init__.py', 'torch/futures/__init__.py', - 'torch/fx/__init__.py', - 'torch/fx/_compatibility.py', - 'torch/fx/_symbolic_trace.py', - 'torch/fx/annotate.py', - 'torch/fx/config.py', - 'torch/fx/experimental/__init__.py', - 'torch/fx/experimental/accelerator_partitioner.py', - 'torch/fx/experimental/const_fold.py', - 'torch/fx/experimental/debug.py', - 'torch/fx/experimental/graph_gradual_typechecker.py', - 'torch/fx/experimental/merge_matmul.py', - 'torch/fx/experimental/meta_tracer.py', - 'torch/fx/experimental/migrate_gradual_types/__init__.py', - 'torch/fx/experimental/migrate_gradual_types/constraint.py', - 'torch/fx/experimental/migrate_gradual_types/constraint_generator.py', - 'torch/fx/experimental/migrate_gradual_types/constraint_transformation.py', - 'torch/fx/experimental/migrate_gradual_types/operation.py', - 'torch/fx/experimental/migrate_gradual_types/transform_to_z3.py', - 'torch/fx/experimental/migrate_gradual_types/util.py', - 'torch/fx/experimental/migrate_gradual_types/z3_types.py', - 'torch/fx/experimental/normalize.py', - 'torch/fx/experimental/optimization.py', - 'torch/fx/experimental/partitioner_utils.py', - 'torch/fx/experimental/refinement_types.py', - 'torch/fx/experimental/rewriter.py', - 'torch/fx/experimental/schema_type_annotation.py', - 'torch/fx/experimental/unification/__init__.py', - 'torch/fx/experimental/unification/core.py', - 'torch/fx/experimental/unification/dispatch.py', - 'torch/fx/experimental/unification/match.py', - 'torch/fx/experimental/unification/more.py', - 'torch/fx/experimental/unification/multipledispatch/__init__.py', - 'torch/fx/experimental/unification/multipledispatch/conflict.py', - 'torch/fx/experimental/unification/multipledispatch/core.py', - 'torch/fx/experimental/unification/multipledispatch/dispatcher.py', - 'torch/fx/experimental/unification/multipledispatch/utils.py', - 'torch/fx/experimental/unification/multipledispatch/variadic.py', - 'torch/fx/experimental/unification/unification_tools.py', - 'torch/fx/experimental/unification/utils.py', - 'torch/fx/experimental/unification/variable.py', - 'torch/fx/experimental/unify_refinements.py', - 'torch/fx/graph.py', - 'torch/fx/graph_module.py', - 'torch/fx/interpreter.py', - 'torch/fx/node.py', - 'torch/fx/operator_schemas.py', - 'torch/fx/passes/__init__.py', - 'torch/fx/passes/annotate_getitem_nodes.py', - 'torch/fx/passes/backends/__init__.py', - 'torch/fx/passes/backends/cudagraphs.py', - 'torch/fx/passes/dialect/__init__.py', - 'torch/fx/passes/dialect/common/__init__.py', - 'torch/fx/passes/dialect/common/cse_pass.py', - 'torch/fx/passes/fake_tensor_prop.py', - 'torch/fx/passes/graph_drawer.py', - 'torch/fx/passes/graph_manipulation.py', - 'torch/fx/passes/infra/__init__.py', - 'torch/fx/passes/infra/partitioner.py', - 'torch/fx/passes/infra/pass_base.py', - 'torch/fx/passes/infra/pass_manager.py', - 'torch/fx/passes/net_min_base.py', - 'torch/fx/passes/operator_support.py', - 'torch/fx/passes/param_fetch.py', - 'torch/fx/passes/pass_manager.py', - 'torch/fx/passes/reinplace.py', - 'torch/fx/passes/shape_prop.py', - 'torch/fx/passes/split_module.py', - 'torch/fx/passes/split_utils.py', - 'torch/fx/passes/splitter_base.py', - 'torch/fx/passes/tests/__init__.py', - 'torch/fx/passes/tests/test_pass_manager.py', - 'torch/fx/passes/tools_common.py', - 
'torch/fx/passes/utils/__init__.py', - 'torch/fx/passes/utils/common.py', - 'torch/fx/passes/utils/fuser_utils.py', - 'torch/fx/passes/utils/matcher_utils.py', - 'torch/fx/passes/utils/source_matcher_utils.py', - 'torch/fx/proxy.py', - 'torch/fx/subgraph_rewriter.py', - 'torch/fx/tensor_type.py', - 'torch/fx/traceback.py', 'torch/linalg/__init__.py', 'torch/monitor/__init__.py', 'torch/nested/__init__.py', diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 9107da9a37cf..f2908243477c 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -262,7 +262,9 @@ "Future" ], "torch.fx": [ + "PH", "ProxyableClassMeta", + "CodeGen", "Tracer", "symbolic_trace", "wrap" diff --git a/torch/__init__.py b/torch/__init__.py index 1d84420f9643..995d90763531 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -2514,6 +2514,7 @@ if "TORCH_CUDA_SANITIZER" in os.environ: # Populate magic methods on SymInt and SymFloat import torch.fx.experimental.sym_node +from torch import fx as fx # Register MPS specific decomps diff --git a/torch/fx/__init__.py b/torch/fx/__init__.py index dd04cdd09d7f..74691bbe72ac 100644 --- a/torch/fx/__init__.py +++ b/torch/fx/__init__.py @@ -7,6 +7,8 @@ demonstration of these components in action: :: import torch + + # Simple module for demonstration class MyModule(torch.nn.Module): def __init__(self) -> None: @@ -17,11 +19,13 @@ demonstration of these components in action: def forward(self, x): return self.linear(x + self.param).clamp(min=0.0, max=1.0) + module = MyModule() from torch.fx import symbolic_trace + # Symbolic tracing frontend - captures the semantics of the module - symbolic_traced : torch.fx.GraphModule = symbolic_trace(module) + symbolic_traced: torch.fx.GraphModule = symbolic_trace(module) # High-level intermediate representation (IR) - Graph representation print(symbolic_traced.graph) @@ -80,10 +84,32 @@ Several example transformations can be found at the repository. 
''' -from .graph_module import GraphModule -from ._symbolic_trace import symbolic_trace, Tracer, wrap, PH, ProxyableClassMeta -from .graph import Graph, CodeGen -from .node import Node, map_arg, has_side_effect -from .proxy import Proxy -from .interpreter import Interpreter as Interpreter, Transformer as Transformer -from .subgraph_rewriter import replace_pattern +from torch.fx._symbolic_trace import ( # noqa: F401 + PH, + ProxyableClassMeta, + symbolic_trace, + Tracer, + wrap, +) +from torch.fx.graph import CodeGen, Graph # noqa: F401 +from torch.fx.graph_module import GraphModule +from torch.fx.interpreter import Interpreter, Transformer +from torch.fx.node import has_side_effect, map_arg, Node +from torch.fx.proxy import Proxy +from torch.fx.subgraph_rewriter import replace_pattern + + +__all__ = [ + "symbolic_trace", + "Tracer", + "wrap", + "Graph", + "GraphModule", + "Interpreter", + "Transformer", + "Node", + "Proxy", + "replace_pattern", + "has_side_effect", + "map_arg", +] diff --git a/torch/fx/__init__.pyi b/torch/fx/__init__.pyi deleted file mode 100644 index 0a263dfc5071..000000000000 --- a/torch/fx/__init__.pyi +++ /dev/null @@ -1,15 +0,0 @@ -from torch.fx._symbolic_trace import ( - symbolic_trace as symbolic_trace, - Tracer as Tracer, - wrap as wrap, -) -from torch.fx.graph import Graph as Graph -from torch.fx.graph_module import GraphModule as GraphModule -from torch.fx.interpreter import Interpreter as Interpreter, Transformer as Transformer -from torch.fx.node import ( - has_side_effect as has_side_effect, - map_arg as map_arg, - Node as Node, -) -from torch.fx.proxy import Proxy as Proxy -from torch.fx.subgraph_rewriter import replace_pattern as replace_pattern diff --git a/torch/fx/_compatibility.py b/torch/fx/_compatibility.py index 27c1e600036d..8a2eeb0d2d69 100644 --- a/torch/fx/_compatibility.py +++ b/torch/fx/_compatibility.py @@ -1,16 +1,19 @@ -from typing import Any, Dict, Callable, TypeVar import textwrap +from typing import Any, Callable, Dict, TypeVar + + +_BACK_COMPAT_OBJECTS: Dict[Any, None] = {} +_MARKED_WITH_COMPATIBILITY: Dict[Any, None] = {} -_BACK_COMPAT_OBJECTS : Dict[Any, None] = {} -_MARKED_WITH_COMPATIBILITY : Dict[Any, None] = {} _T = TypeVar("_T") + def compatibility(is_backward_compatible: bool) -> Callable[[_T], _T]: if is_backward_compatible: def mark_back_compat(fn: _T) -> _T: - docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '') + docstring = textwrap.dedent(getattr(fn, "__doc__", None) or "") docstring += """ .. note:: Backwards-compatibility for this API is guaranteed. @@ -24,7 +27,7 @@ def compatibility(is_backward_compatible: bool) -> Callable[[_T], _T]: else: def mark_not_back_compat(fn: _T) -> _T: - docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '') + docstring = textwrap.dedent(getattr(fn, "__doc__", None) or "") docstring += """ .. warning:: This API is experimental and is *NOT* backward-compatible. 
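For orientation, the ``compatibility`` decorator reformatted above appends a backward-compatibility note (or an experimental warning) to the wrapped object's docstring and records the object in a module-level registry. A minimal usage sketch, assuming a hypothetical ``my_pass`` function::

    from torch.fx._compatibility import compatibility

    @compatibility(is_backward_compatible=False)
    def my_pass(gm):
        """Rewrite the graph in place."""
        return gm

    # The function remains callable as before; its __doc__ now ends with the
    # ".. warning:: This API is experimental ..." text appended by the decorator.
    print(my_pass.__doc__)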
diff --git a/torch/fx/_lazy_graph_module.py b/torch/fx/_lazy_graph_module.py index 2a14fce3782e..cc2f686ebba1 100644 --- a/torch/fx/_lazy_graph_module.py +++ b/torch/fx/_lazy_graph_module.py @@ -1,9 +1,9 @@ # mypy: allow-untyped-defs from contextlib import contextmanager -from torch.fx import GraphModule from torch.fx.graph_module import ( _format_import_block, + GraphModule, reduce_graph_module, reduce_package_graph_module, ) diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index 635686070399..38835c6ca374 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -1,13 +1,13 @@ # mypy: allow-untyped-defs import builtins -import copy +import collections import contextlib +import copy import functools import inspect import math import os import warnings -import collections from itertools import chain from types import CodeType, FunctionType, ModuleType from typing import ( @@ -29,11 +29,12 @@ from torch._C import ScriptObject # type: ignore[attr-defined] from torch._library.fake_class_registry import FakeScriptObject from ._compatibility import compatibility +from ._lazy_graph_module import _make_graph_module from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph from .graph_module import GraphModule -from ._lazy_graph_module import _make_graph_module from .node import Argument, base_types, map_aggregate -from .proxy import ParameterProxy, Proxy, TracerBase, Scope, ScopeContextManager +from .proxy import ParameterProxy, Proxy, Scope, ScopeContextManager, TracerBase + HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS @@ -49,6 +50,7 @@ _is_fx_tracing_flag = False def is_fx_tracing(): return _is_fx_tracing_flag + @compatibility(is_backward_compatible=True) class ProxyableClassMeta(type): """ @@ -58,6 +60,7 @@ class ProxyableClassMeta(type): import torch import torch.fx + class TensorPair(metaclass=torch.fx.ProxyableClassMeta): def __init__(self, left, right): self.left, self.right = left, right @@ -72,10 +75,12 @@ class ProxyableClassMeta(type): r = self.right * other.right return TensorPair(l, r) - def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor): + + def use_tensor_pair_ctor(x: TensorPair, y: torch.Tensor): s = x.add(TensorPair(y, y)) return s.mul(x) + x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) y = torch.randn(5, 3) ref_out = use_tensor_pair_ctor(x, y) @@ -214,6 +219,7 @@ class PHWithMeta(PHBase): """ Object representing an input placeholder to `concrete_args` """ + def __init__(self, ph_key: Optional[str] = None): super().__init__() @@ -404,7 +410,11 @@ class Tracer(TracerBase): # Tensor was not found in the Module hierarchy, stow it away in a # special attribute and set the qualname to refer to that if not qualname: - base_name = "_tensor_constant" if isinstance(a, torch.Tensor) else "_torchbind_obj" + base_name = ( + "_tensor_constant" + if isinstance(a, torch.Tensor) + else "_torchbind_obj" + ) qualname = self.get_fresh_qualname(base_name) assert isinstance(qualname, str) self.tensor_attrs[a] = qualname @@ -446,9 +456,9 @@ class Tracer(TracerBase): appear with the qualified name ``foo.bar.baz`` here. 
""" return ( - (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn")) - and not isinstance(m, torch.nn.Sequential) - ) + m.__module__.startswith("torch.nn") + or m.__module__.startswith("torch.ao.nn") + ) and not isinstance(m, torch.nn.Sequential) @compatibility(is_backward_compatible=True) def path_of_module(self, mod: torch.nn.Module) -> str: @@ -512,17 +522,25 @@ class Tracer(TracerBase): value was returned from the ``Module`` invocation. """ module_qualified_name = self.path_of_module(m) - with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope: + with ScopeContextManager( + self.scope, Scope(module_qualified_name, type(m)) + ) as _scope: # module_stack is an ordered dict so writing then deleting the # entry is equivalent to push/pop on a list num_calls = self.num_calls.get(module_qualified_name, 0) - module_key = f"{_scope.module_path}@{num_calls}" if num_calls > 0 else _scope.module_path + module_key = ( + f"{_scope.module_path}@{num_calls}" + if num_calls > 0 + else _scope.module_path + ) self.module_stack[module_key] = (module_qualified_name, _scope.module_type) self.num_calls[module_qualified_name] = num_calls + 1 if not self.is_leaf_module(m, module_qualified_name): ret_val = forward(*args, **kwargs) else: - ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs) + ret_val = self.create_proxy( + "call_module", module_qualified_name, args, kwargs + ) key, _ = self.module_stack.popitem(last=True) assert key == module_key, f" Unexpected key {key}" @@ -551,6 +569,7 @@ class Tracer(TracerBase): The return value from the getattr call. """ + def maybe_get_proxy_for_attr( attr_val, collection_to_search, parameter_proxy_cache ): @@ -620,15 +639,16 @@ class Tracer(TracerBase): sig = inspect.signature(fn_for_analysis) - # This covers the very specific case where we are passing in flat # concrete_args as a tuple, but our traced fn takes (*args, **kwargs). # In this case, just take the concrete_args and pass them through. name_idx = 0 - if isinstance(concrete_args, tuple) and \ - len(concrete_args) > 0 and \ - (co.co_flags & HAS_VARSTUFF) and \ - total_args == 1: + if ( + isinstance(concrete_args, tuple) + and len(concrete_args) > 0 + and (co.co_flags & HAS_VARSTUFF) + and total_args == 1 + ): for concrete_arg in concrete_args: out = self.create_proxy("placeholder", f"input_{name_idx}", (), {}) if isinstance(concrete_arg, PHBase): @@ -722,12 +742,12 @@ class Tracer(TracerBase): _is_fx_tracing_flag = True try: if isinstance(root, torch.nn.Module): - # do real recompilation for _LazyGraphModule before retracing since the trace # method can not trace the _lazy_forward method. Got error: # https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259 # without this. from torch.fx._lazy_graph_module import _LazyGraphModule + _LazyGraphModule.force_recompile(root) self.root = root @@ -745,12 +765,12 @@ class Tracer(TracerBase): tracer_cls: Optional[Type[Tracer]] = getattr(self, "__class__", None) self.graph = Graph(tracer_cls=tracer_cls) - if hasattr(fn, '__code__'): + if hasattr(fn, "__code__"): code = fn.__code__ self.graph._co_fields = { - 'co_name': code.co_name, - 'co_filename': code.co_filename, - 'co_firstlineno': code.co_firstlineno, + "co_name": code.co_name, + "co_filename": code.co_filename, + "co_firstlineno": code.co_firstlineno, } # When we encounter a Tensor value that's not a parameter, we look if it @@ -758,11 +778,7 @@ class Tracer(TracerBase): # values to the qualified name here for efficiency. 
This is used downstream # in create_arg self.tensor_attrs: Dict[ - Union[ - torch.Tensor, - ScriptObject, - FakeScriptObject - ], str + Union[torch.Tensor, ScriptObject, FakeScriptObject], str ] = {} def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): @@ -839,7 +855,7 @@ class Tracer(TracerBase): new_tracer = Tracer.__new__(Tracer) for k, v in self.__dict__.items(): - if k in {'_autowrap_search'}: + if k in {"_autowrap_search"}: new_obj = copy.copy(v) else: new_obj = copy.deepcopy(v, memo) @@ -857,9 +873,7 @@ class Tracer(TracerBase): cnt += 1 param = sig.parameters[name] default = ( - () - if param.default is inspect.Parameter.empty - else (param.default,) + () if param.default is inspect.Parameter.empty else (param.default,) ) out = self.create_proxy( "placeholder", f"{name}_{str(cnt)}", default, {} @@ -877,11 +891,7 @@ class Tracer(TracerBase): return out # Union[int, bool] == bool in Python <= 3.6 - if ( - type(x) == bool - or type(x) in base_types - and type(x) != torch.Tensor - ): + if type(x) == bool or type(x) in base_types and type(x) != torch.Tensor: torch._assert( out == x, f"{name} has been specialized to have value {x} but got another value", @@ -906,13 +916,15 @@ class Tracer(TracerBase): default = () else: param = sig.parameters[name] - default = () if param.default is inspect.Parameter.empty else (param.default,) # type: ignore[assignment] + default = ( # type: ignore[assignment] + () if param.default is inspect.Parameter.empty else (param.default,) + ) return self.create_proxy( "placeholder", name, default, {}, - type_expr=fn_for_analysis.__annotations__.get(name, None) + type_expr=fn_for_analysis.__annotations__.get(name, None), ) @@ -1011,6 +1023,7 @@ class _PatchedFnSetItem(_PatchedFn): def patch(self): self.frame_dict[self.fn_name] = self.new_fn + class _PatchedFnDel(_PatchedFn): def revert(self): del self.frame_dict[self.fn_name] @@ -1026,6 +1039,7 @@ class _PatchedFnSetAttr(_PatchedFn): def patch(self): setattr(self.frame_dict, self.fn_name, self.new_fn) + class _Patcher: def __init__(self) -> None: super().__init__() @@ -1106,6 +1120,7 @@ class _Patcher: CURRENT_PATCHER: Optional[_Patcher] = None + @contextlib.contextmanager def _new_patcher(): global CURRENT_PATCHER @@ -1132,7 +1147,10 @@ def _maybe_revert_all_patches(): finally: if current_patcher is not None: patches_made = current_patcher.reapply_all_patches() - assert patches_made == patches_removed, "CURRENT_PATCHER was changed during a revert_all_patches" + assert ( + patches_made == patches_removed + ), "CURRENT_PATCHER was changed during a revert_all_patches" + def _patch_wrapped_functions(patcher: _Patcher): """ @@ -1178,7 +1196,9 @@ def wrap(fn_or_name: Union[str, Callable]): def my_custom_function(x, y): return x * x + y * y - torch.fx.wrap('my_custom_function') + + torch.fx.wrap("my_custom_function") + def fn_to_be_traced(x, y): # When symbolic tracing, the below call to my_custom_function will be inserted into @@ -1248,14 +1268,14 @@ def symbolic_trace( if b == True: return a else: - return a*2 + return a * 2 FX can typically not trace through this due to the presence of control flow. However, we can use `concrete_args` to specialize on the value of `b` to trace through this:: - f = fx.symbolic_trace(f, concrete_args={'b': False}) - assert f(3, False) == 6 + f = fx.symbolic_trace(f, concrete_args={"b": False}) + assert f(3, False) == 6 Note that although you can still pass in different values of `b`, they will be ignored. 
@@ -1269,8 +1289,10 @@ def symbolic_trace( for v in x.values(): out += v return out - f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}}) - assert f({'a': 1, 'b': 2, 'c': 4}) == 7 + + + f = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}}) + assert f({"a": 1, "b": 2, "c": 4}) == 7 Args: diff --git a/torch/fx/annotate.py b/torch/fx/annotate.py index d1b5b5f2d376..b3c505606625 100644 --- a/torch/fx/annotate.py +++ b/torch/fx/annotate.py @@ -1,7 +1,9 @@ # mypy: allow-untyped-defs from torch.fx.proxy import Proxy + from ._compatibility import compatibility + @compatibility(is_backward_compatible=False) def annotate(val, type): """ @@ -18,13 +20,15 @@ def annotate(val, type): """ if isinstance(val, Proxy): if val.node.type: - raise RuntimeError(f"Tried to annotate a value that already had a type on it!" - f" Existing type is {val.node.type} " - f"and new type is {type}. " - f"This could happen if you tried to annotate a function parameter " - f"value (in which case you should use the type slot " - f"on the function signature) or you called " - f"annotate on the same value twice") + raise RuntimeError( + f"Tried to annotate a value that already had a type on it!" + f" Existing type is {val.node.type} " + f"and new type is {type}. " + f"This could happen if you tried to annotate a function parameter " + f"value (in which case you should use the type slot " + f"on the function signature) or you called " + f"annotate on the same value twice" + ) else: val.node.type = type return val diff --git a/torch/fx/experimental/accelerator_partitioner.py b/torch/fx/experimental/accelerator_partitioner.py index b832124bd09d..4f9fe0f9a140 100644 --- a/torch/fx/experimental/accelerator_partitioner.py +++ b/torch/fx/experimental/accelerator_partitioner.py @@ -1,22 +1,22 @@ # mypy: allow-untyped-defs import operator from collections import deque -from typing import Dict, List, Set, NamedTuple, Tuple, Deque +from typing import Deque, Dict, List, NamedTuple, Set, Tuple import torch -from torch.fx.passes.graph_manipulation import get_size_of_all_nodes from torch.fx.experimental.partitioner_utils import ( - Partition, Device, - PartitionerConfig, - get_partition_to_latency_mapping, - get_latency_of_partitioned_graph, - NodeLatency, get_extra_size_of, + get_latency_of_partitioned_graph, + get_partition_to_latency_mapping, + NodeLatency, + Partition, + PartitionerConfig, PartitionMode, ) from torch.fx.graph_module import GraphModule -from torch.fx.node import Node, map_arg +from torch.fx.node import map_arg, Node +from torch.fx.passes.graph_manipulation import get_size_of_all_nodes from torch.fx.passes.split_module import split_module @@ -260,7 +260,9 @@ def get_device_to_partitions_mapping( # Find devices for all the partitions without a device found_device = True for partition in no_device_partitions: - device_to_left_mem_bytes = dict(sorted(device_to_left_mem_bytes.items(), key=operator.itemgetter(1))) + device_to_left_mem_bytes = dict( + sorted(device_to_left_mem_bytes.items(), key=operator.itemgetter(1)) + ) found_device = find_device_for(partition) if not found_device: break diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py index 9b12a027f056..d1ca4acde2b8 100644 --- a/torch/fx/experimental/const_fold.py +++ b/torch/fx/experimental/const_fold.py @@ -7,7 +7,12 @@ from torch.fx.node import map_arg from torch.fx.passes.split_module import split_module -__all__ = ['FoldedGraphModule', 'get_unique_attr_name_in_module', 
'split_const_subgraphs'] +__all__ = [ + "FoldedGraphModule", + "get_unique_attr_name_in_module", + "split_const_subgraphs", +] + class FoldedGraphModule(torch.fx.GraphModule): """ diff --git a/torch/fx/experimental/debug.py b/torch/fx/experimental/debug.py index d3c482319f2e..e59dcbb3296f 100644 --- a/torch/fx/experimental/debug.py +++ b/torch/fx/experimental/debug.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import torch.fx as fx + def set_trace(gm: fx.GraphModule) -> fx.GraphModule: """ Sets a breakpoint in `gm`'s generated python code. It drops into pdb when @@ -13,18 +14,14 @@ def set_trace(gm: fx.GraphModule) -> fx.GraphModule: Returns: the `gm` with breakpoint inserted. """ + def insert_pdb(body): return ["import pdb; pdb.set_trace()\n", *body] with gm.graph.on_generate_code( make_transformer=lambda cur_transform: ( # new code transformer to register - lambda body: ( - insert_pdb( - cur_transform(body) if cur_transform - else body - ) - ) + lambda body: (insert_pdb(cur_transform(body) if cur_transform else body)) ) ): gm.recompile() diff --git a/torch/fx/experimental/graph_gradual_typechecker.py b/torch/fx/experimental/graph_gradual_typechecker.py index af1e6ab057c7..0be22bc0d795 100644 --- a/torch/fx/experimental/graph_gradual_typechecker.py +++ b/torch/fx/experimental/graph_gradual_typechecker.py @@ -1,20 +1,21 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs -from functools import reduce -import torch -import operator -from torch.fx.tensor_type import Dyn, is_consistent, TensorType, is_more_precise -from typing import Callable, Dict -from torch.fx.node import Target, Node -from torch.nn.modules.batchnorm import BatchNorm2d -from torch.nn.modules.conv import Conv2d -from torch.fx.experimental.refinement_types import Equality import itertools - -from torch.fx.experimental.unification import Var # type: ignore[attr-defined] +import operator +from functools import reduce +from typing import Callable, Dict import sympy +import torch +from torch.fx.experimental.refinement_types import Equality +from torch.fx.experimental.unification import Var # type: ignore[attr-defined] +from torch.fx.node import Node, Target +from torch.fx.tensor_type import Dyn, is_consistent, is_more_precise, TensorType +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.nn.modules.conv import Conv2d + + _INFERENCE_RULES: Dict[Target, Callable] = {} _REFINEMENT_RULES: Dict[Target, Callable] = {} _RULES: Dict[Target, Callable] = {} @@ -32,10 +33,12 @@ def expand_to_tensor_dim(t, n): return TensorType(tuple(dims)) elif isinstance(t, TensorType): if len(t.__args__) != n: - raise TypeError(f'Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}') + raise TypeError( + f"Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. 
It should have rank {n}" + ) return t else: - raise TypeError(f'Cannot match the type {t}') + raise TypeError(f"Cannot match the type {t}") def broadcast_types(t1, t2): @@ -80,32 +83,39 @@ def broadcast_types(t1, t2): (t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2)) return (t1, t2) else: - raise TypeError(f'Cannot broadcast types {t1} and {t2}') + raise TypeError(f"Cannot broadcast types {t1} and {t2}") + def register_inference_rule(call_target): def register(fn): if call_target in _INFERENCE_RULES: - raise RuntimeError(f'Inference rule already registered for {call_target}!') + raise RuntimeError(f"Inference rule already registered for {call_target}!") _INFERENCE_RULES[call_target] = fn return fn + return register + def register_refinement_rule(call_target): def register(fn): if call_target in _REFINEMENT_RULES: - raise RuntimeError(f'Refinement rule already registered for {call_target}!') + raise RuntimeError(f"Refinement rule already registered for {call_target}!") _REFINEMENT_RULES[call_target] = fn return fn + return register + def register_algebraic_expressions_inference_rule(call_target): def register(fn): if call_target in _RULES: - raise RuntimeError(f'Rule already registered for {call_target}!') + raise RuntimeError(f"Rule already registered for {call_target}!") _RULES[call_target] = fn return fn + return register + @register_inference_rule(torch.add) @register_inference_rule(operator.add) def add_inference_rule(n: Node): @@ -142,15 +152,15 @@ def add_inference_rule(n: Node): (new_t1, new_t2) = broadcast_types(t1, t2) if new_t1 != t1 or new_t2 != t2: - n.meta['broadcast'] = True + n.meta["broadcast"] = True n.meta[str(n.args[0])] = new_t1 n.meta[str(n.args[1])] = new_t2 else: - n.meta['broadcast'] = False + n.meta["broadcast"] = False - new_t1 = t1 if not n.meta['broadcast'] else new_t1 - new_t2 = t2 if not n.meta['broadcast'] else new_t2 + new_t1 = t1 if not n.meta["broadcast"] else new_t1 + new_t2 = t2 if not n.meta["broadcast"] else new_t2 # we check for consistency between the new types if is_consistent(new_t1, new_t2): @@ -164,8 +174,11 @@ def add_inference_rule(n: Node): n.type = new_t1 return n.type else: - raise TypeError(f'Cannot add arguments {n.args[0]} ({ n.args[0].type}) and {n.args[1]} ({ n.args[1].type}) in node {n}.' - f' Types should match ') + raise TypeError( + f"Cannot add arguments {n.args[0]} ({n.args[0].type}) and {n.args[1]} ({n.args[1].type}) in node {n}." + f" Types should match " + ) + @register_inference_rule(getattr) def get_attr_inference_rule(n: Node, traced): @@ -185,6 +198,7 @@ def get_attr_inference_rule(n: Node, traced): # TODO. 
We leave it like this till we add a type to represent tensor sizes return n.type + @register_inference_rule(torch.transpose) def transpose_inference_rule(n: Node): """ @@ -211,9 +225,13 @@ def transpose_inference_rule(n: Node): n.type = get_greatest_upper_bound(n.type, final) return n.type else: - raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}') + raise TypeError( + f"Cannot transpose {dim1} and {dim2} in type {t} for node {n}" + ) else: - raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}') + raise TypeError( + f"Cannot transpose {dim1} and {dim2} in type {t} for node {n}" + ) @register_inference_rule(torch.reshape) @@ -251,9 +269,10 @@ def reshape_inference_rule(n: Node): n.type = t2_type return t2_type else: - raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}') + raise TypeError(f"Cannot reshape in node {n} from {t1} to {t2_type}") else: - raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}') + raise TypeError(f"Cannot reshape in node {n} from {t1} to {t2_type}") + @register_inference_rule(BatchNorm2d) def bn2d_inference_rule(n: Node, module_instance): @@ -274,10 +293,11 @@ def bn2d_inference_rule(n: Node, module_instance): # we check the conditions on the incoming argument # and any existing annotation # we also check for consistency between both annotations - if is_consistent(arg_type.__args__[1], module_instance.num_features) and \ - is_consistent(n.type.__args__[1], module_instance.num_features) and \ - is_consistent(arg_type, n.type): - + if ( + is_consistent(arg_type.__args__[1], module_instance.num_features) + and is_consistent(n.type.__args__[1], module_instance.num_features) + and is_consistent(arg_type, n.type) + ): # we choose the more precise type # to be the node type # so if an incoming argument has more type information @@ -285,21 +305,35 @@ def bn2d_inference_rule(n: Node, module_instance): n.type = get_greatest_upper_bound(arg_type, n.type) return n.type else: - raise TypeError(f'Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}') + raise TypeError( + f"Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}" + ) def calculate_out_dimension(d_in, module_instance, index): """ For calculating h_in and w_out according to the conv2D documentation """ - padding = (module_instance.padding, module_instance.padding) \ - if isinstance(module_instance.padding, int) else module_instance.padding - kernel_size = (module_instance.kernel_size, module_instance.kernel_size) \ - if isinstance(module_instance.kernel_size, int) else module_instance.kernel_size - stride = (module_instance.stride, module_instance.stride) \ - if isinstance(module_instance.stride, int) else module_instance.stride - dilation = (module_instance.dilation, module_instance.dilation) \ - if isinstance(module_instance.dilation, int) else module_instance.dilation + padding = ( + (module_instance.padding, module_instance.padding) + if isinstance(module_instance.padding, int) + else module_instance.padding + ) + kernel_size = ( + (module_instance.kernel_size, module_instance.kernel_size) + if isinstance(module_instance.kernel_size, int) + else module_instance.kernel_size + ) + stride = ( + (module_instance.stride, module_instance.stride) + if isinstance(module_instance.stride, int) + else module_instance.stride + ) + dilation = ( + (module_instance.dilation, module_instance.dilation) + if isinstance(module_instance.dilation, int) + else 
module_instance.dilation + ) DIMENSION_TYPES = (int, sympy.Symbol) @@ -307,14 +341,14 @@ def calculate_out_dimension(d_in, module_instance, index): return Dyn elif isinstance(d_in, DIMENSION_TYPES): - n = d_in + 2 * padding[index] - \ - dilation[index] * \ - (kernel_size[index] - 1) - 1 + n = d_in + 2 * padding[index] - dilation[index] * (kernel_size[index] - 1) - 1 return (n // stride[0]) + 1 else: - raise TypeError(f'{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}') + raise TypeError( + f"{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}" + ) def get_greatest_upper_bound(type1, type2): @@ -327,8 +361,11 @@ def get_greatest_upper_bound(type1, type2): return type1 elif isinstance(type1, TensorType) and isinstance(type2, TensorType): if not is_consistent(type1, type2): - raise TypeError(f'Inconsistent types {type1}, {type2}') - gub = [t1 if is_more_precise(t1, t2) else t2 for (t1, t2) in zip(type1.__args__, type2.__args__)] + raise TypeError(f"Inconsistent types {type1}, {type2}") + gub = [ + t1 if is_more_precise(t1, t2) else t2 + for (t1, t2) in zip(type1.__args__, type2.__args__) + ] return TensorType(tuple(gub)) @@ -352,12 +389,16 @@ def conv2d_inference_rule(n: Node, module_instance): h_in = arg_type.__args__[2] h_out = calculate_out_dimension(h_in, module_instance, 0) w_out = calculate_out_dimension(w_in, module_instance, 1) - new_type = TensorType((arg_type.__args__[0], module_instance.out_channels, h_out, w_out)) + new_type = TensorType( + (arg_type.__args__[0], module_instance.out_channels, h_out, w_out) + ) gub = get_greatest_upper_bound(new_type, curr_node_type) n.type = gub return n.type else: - raise TypeError(f'Cannot apply {module_instance} with input type { arg_type} and existing type {n.type} on {n}') + raise TypeError( + f"Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}" + ) @register_inference_rule(torch.nn.ReLU) @@ -393,7 +434,7 @@ def maxpool2d_check(typ, module_instance): return TensorType(tuple(new_type_list)) else: - raise TypeError(f'Wrong size {typ} for {module_instance}') + raise TypeError(f"Wrong size {typ} for {module_instance}") @register_inference_rule(torch.nn.MaxPool2d) @@ -417,7 +458,6 @@ def maxpool2d_inference_rule(n: Node, module_instance): return n.type - def linear_check(tensor_type, module_instance): """ Checks that an input tensor type satisfies the conditions for linear operation @@ -429,9 +469,11 @@ def linear_check(tensor_type, module_instance): new_type_args[-1] = module_instance.out_features return TensorType(tuple(new_type_args)) else: - raise TypeError(f'Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}') + raise TypeError( + f"Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}" + ) else: - raise TypeError(f'Type {tensor_type} must have rank 2 or more.') + raise TypeError(f"Type {tensor_type} must have rank 2 or more.") @register_inference_rule(torch.nn.Linear) @@ -469,7 +511,8 @@ def adaptiveavgpool2d_check(tensor_type, module_instance): return TensorType(tuple(new_type_list)) else: - raise TypeError(f'Tensor ranks must be 3 or 4. Got {tensor_type}') + raise TypeError(f"Tensor ranks must be 3 or 4. 
Got {tensor_type}") + @register_inference_rule(torch.nn.AdaptiveAvgPool2d) def adaptiveavgpool2d_inference_rule(n: Node, module_instance): @@ -485,6 +528,7 @@ def adaptiveavgpool2d_inference_rule(n: Node, module_instance): n.type = get_greatest_upper_bound(n.type, output_type) return n.type + def flatten_check(tensor_type, start_dim, end_dim): l = len(tensor_type.__args__) @@ -503,7 +547,10 @@ def flatten_check(tensor_type, start_dim, end_dim): new_type_list = lhs + mid + rhs return TensorType(tuple(new_type_list)) else: - raise TypeError(f'Incompatible dimensions {start_dim}, {end_dim - 1} in type {tensor_type}') + raise TypeError( + f"Incompatible dimensions {start_dim}, {end_dim - 1} in type {tensor_type}" + ) + @register_inference_rule(torch.flatten) def flatten_inference_rule(n: Node): @@ -530,10 +577,11 @@ def flatten_inference_rule(n: Node): if isinstance(n.args[0].type, TensorType): output_type = flatten_check(n.args[0].type, start_dim, end_dim) - n.type = get_greatest_upper_bound(output_type , n.type) + n.type = get_greatest_upper_bound(output_type, n.type) return n.type + class GraphTypeChecker: def __init__(self, env, traced): self.env = env @@ -571,16 +619,16 @@ class GraphTypeChecker: if n.type is None: n.type = Dyn - if n.op == 'placeholder': + if n.op == "placeholder": return n.type - elif n.op == 'get_attr': + elif n.op == "get_attr": t = get_parameter(self.traced, n.target) # type: ignore[arg-type] if isinstance(t.data, torch.Tensor): n.type = TensorType(t.data.shape) return n.type - elif n.op == 'call_function': + elif n.op == "call_function": if n.target == getattr: assert getattr in _INFERENCE_RULES return _INFERENCE_RULES[n.target](n, self.traced) @@ -588,18 +636,24 @@ class GraphTypeChecker: elif n.target in _INFERENCE_RULES: return _INFERENCE_RULES[n.target](n) else: - raise RuntimeError(f'No inference rule registered for target {n.target}!') + raise RuntimeError( + f"No inference rule registered for target {n.target}!" + ) - elif n.op == 'call_module': + elif n.op == "call_module": module_instance = self.traced.get_submodule(n.target) if type(module_instance) in _INFERENCE_RULES: return _INFERENCE_RULES[type(module_instance)](n, module_instance) else: - raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!') + raise RuntimeError( + f"No inference rule registered for class {type(module_instance)}!" + ) + + elif n.op == "output": - elif n.op == 'output': def get_node_type(a): return a.type + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) return n.type @@ -634,6 +688,7 @@ def linear_refinement_rule(n: Node): res = [Equality(arg_type.__args__[0], n.type.__args__[0])] return res + @register_refinement_rule(BatchNorm2d) @register_refinement_rule(torch.nn.ReLU) def all_eq(n: Node): @@ -688,7 +743,11 @@ def element_wise_eq(n: Node): if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): arg_type1 = n.args[0].type arg_type2 = n.args[1].type - if isinstance(arg_type1, TensorType) and isinstance(arg_type2, TensorType) and isinstance(n.type, TensorType): + if ( + isinstance(arg_type1, TensorType) + and isinstance(arg_type2, TensorType) + and isinstance(n.type, TensorType) + ): args1, args2 = broadcast_types(arg_type1, arg_type2) # by this point, we know that args1 and args2 are the same size. a1 = args1.__args__ @@ -757,12 +816,14 @@ def conv_rule(n: Node, module_instance): n.type = new_type return new_type + class Refine: """ Symbolic shape inference. Generates constraints over type variables. 
Currently all constraints are equality constraints. """ + def __init__(self, traced): self.constraints = [] self.traced = traced @@ -805,7 +866,6 @@ class Refine: else: return typ - def convert_to_sympy_symbols(self, typ): """ Replace all unknown types with fresh type variables. @@ -835,22 +895,24 @@ class Refine: n.type = self.replace_dyn_with_fresh_var(n.type) - if n.op == 'call_function': + if n.op == "call_function": if n.target in _REFINEMENT_RULES: self.constraints += _REFINEMENT_RULES[n.target](n) else: pass - if n.op == 'call_module': + if n.op == "call_module": module_instance = self.traced.get_submodule(n.target) if type(module_instance) in _REFINEMENT_RULES: self.constraints += _REFINEMENT_RULES[type(module_instance)](n) else: pass - if n.op == 'output': + if n.op == "output": + def get_node_type(a): return a.type + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) return n.type @@ -859,28 +921,31 @@ class Refine: def infer_symbolic_relations(self, n: Node): n.type = self.convert_to_sympy_symbols(n.type) - if n.op == 'call_function': + if n.op == "call_function": if n.target in _RULES: return _RULES[n.target](n) else: pass - if n.op == 'call_module': + if n.op == "call_module": module_instance = self.traced.get_submodule(n.target) if type(module_instance) in _RULES: return _RULES[type(module_instance)](n, module_instance) else: pass - if n.op == 'output': + if n.op == "output": + def get_node_type(a): return a.type + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) return n.type else: pass + def get_parameter(traced, target: str): """ Returns the parameter given by ``target`` if it exists, diff --git a/torch/fx/experimental/merge_matmul.py b/torch/fx/experimental/merge_matmul.py index c1a634b2602a..b3e1efcbd19e 100644 --- a/torch/fx/experimental/merge_matmul.py +++ b/torch/fx/experimental/merge_matmul.py @@ -1,14 +1,13 @@ # mypy: allow-untyped-defs -import torch - -from torch.fx.node import Node -from torch.fx._symbolic_trace import symbolic_trace -from torch.fx.passes.tools_common import legalize_graph import itertools import operator - from typing import Dict, List, Tuple +import torch +from torch.fx._symbolic_trace import symbolic_trace +from torch.fx.node import Node +from torch.fx.passes.tools_common import legalize_graph + def split_result_tensors( result: torch.Tensor, inputs: List[torch.Tensor] @@ -146,7 +145,14 @@ def merge_matmul(in_mod: torch.nn.Module): # Multiply the concatenated LHS operands with the one RHS. This will produce # the same results as all the individual matmuls involving rhs in the original graph, # but they will all be concatenated together. - merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {}) + merge_mm = gm.graph.call_function( + torch.matmul, + ( + merge_mm_cat, + rhs, + ), + {}, + ) # Split the result of the merged matmul using the shapes of the LHS operands # to ascertain how large each chunk should be. 
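For orientation before the next file, ``merge_matmul`` (diffed above) fuses several ``torch.matmul`` calls that share one right-hand-side operand: it concatenates the left-hand sides, performs a single matmul, and splits the result back into per-call chunks. A minimal sketch, assuming a toy ``ThreeMatmuls`` module::

    import torch
    from torch.fx.experimental.merge_matmul import merge_matmul

    class ThreeMatmuls(torch.nn.Module):
        def forward(self, x, y, z, rhs):
            # Three matmuls against the same right-hand side.
            return torch.matmul(x, rhs) + torch.matmul(y, rhs) + torch.matmul(z, rhs)

    # merge_matmul symbolically traces the module and returns a rewritten
    # GraphModule with one concatenated matmul followed by a split.
    merged = merge_matmul(ThreeMatmuls())
    print(merged.graph)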
diff --git a/torch/fx/experimental/meta_tracer.py b/torch/fx/experimental/meta_tracer.py index 0ea47f7c3b14..1b74f33f40b5 100644 --- a/torch/fx/experimental/meta_tracer.py +++ b/torch/fx/experimental/meta_tracer.py @@ -1,14 +1,15 @@ # mypy: allow-untyped-defs -import torch -import torch.fx -import warnings -import functools import builtins - +import functools +import warnings from typing import Any, Callable, Dict, Optional, Union +import torch +import torch.fx + + def embedding_override(self, input): - return torch.empty(*input.shape, self.weight.shape[-1], device='meta') + return torch.empty(*input.shape, self.weight.shape[-1], device="meta") def nn_layernorm_override(self, input): @@ -24,21 +25,22 @@ def torch_nn_relu_override(self, x): def functional_relu_override(x, inplace=False): - assert not inplace, 'dont support inplace functional.relu for metatensor analysis' + assert not inplace, "dont support inplace functional.relu for metatensor analysis" return x def torch_where_override(condition, x, y): # torch.where returns the broadcasted tensor of condition, x, and y, # so hack it by using addition - return condition.to(device='meta') + x.to(device='meta') + y.to(device='meta') + return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta") def torch_abs_override(input, *, out=None): - assert out is None, 'Dont support in-place abs for MetaTensor analysis' + assert out is None, "Dont support in-place abs for MetaTensor analysis" return input -manual_meta_overrides : Dict[Callable, Callable] = { + +manual_meta_overrides: Dict[Callable, Callable] = { torch.nn.Embedding: embedding_override, torch.nn.LayerNorm: nn_layernorm_override, torch.relu: torch_relu_override, @@ -48,6 +50,7 @@ manual_meta_overrides : Dict[Callable, Callable] = { torch.abs: torch_abs_override, } + def gen_constructor_wrapper(target): @functools.wraps(target) def wrapper(*args, **kwargs): @@ -57,57 +60,66 @@ def gen_constructor_wrapper(target): if isinstance(v, torch.fx.Proxy): nonlocal proxy proxy = v + torch.fx.node.map_aggregate(args, check_has_proxy) torch.fx.node.map_aggregate(kwargs, check_has_proxy) if proxy is not None: - return proxy.tracer.create_proxy('call_function', target, args, kwargs) + return proxy.tracer.create_proxy("call_function", target, args, kwargs) else: return target(*args, **kwargs) + return wrapper, target + class MetaProxy(torch.fx.Proxy): def install_tensor_meta(self, tensor_meta): self._tensor_meta = tensor_meta def size(self, dim=None): - if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: + if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: return self._tensor_meta.size(*[dim] if dim else []) - return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim else (self,), {}) + return self.tracer.create_proxy( + "call_method", "size", (self, dim) if dim else (self,), {} + ) def dim(self): - if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: + if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: return self._tensor_meta.dim() - return self.tracer.create_proxy('call_method', 'dim', (self,), {}) + return self.tracer.create_proxy("call_method", "dim", (self,), {}) @property def shape(self): - if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: + if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: return self._tensor_meta.shape - return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'shape'), {}) + return self.tracer.create_proxy( + "call_function", 
builtins.getattr, (self, "shape"), {} + ) @property def dtype(self): - if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: + if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: return self._tensor_meta.dtype - return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'dtype'), {}) + return self.tracer.create_proxy( + "call_function", builtins.getattr, (self, "dtype"), {} + ) @property def device(self): # Hack so we can track when devices are used. During meta-tensor propagation, # replace these values with a constant 'meta' - return MetaDeviceAttribute(self, 'device') + return MetaDeviceAttribute(self, "device") def __getattr__(self, k): - if k == '_tensor_meta': + if k == "_tensor_meta": return self.__getattribute__(k) # note: not added to the graph yet, if this is a method call # we peephole optimize to the method invocation return MetaAttribute(self, k) + class MetaAttribute(MetaProxy): def __init__(self, root, attr: str): - self.root = root self.attr = attr self.tracer = root.tracer @@ -118,33 +130,51 @@ class MetaAttribute(MetaProxy): # the node for attributes is added lazily, since most will just be method calls # which do not rely on the getitem call if self._node is None: - self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node + self._node = self.tracer.create_proxy( + "call_function", getattr, (self.root, self.attr), {} + ).node return self._node def __call__(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + return self.tracer.create_proxy( + "call_method", self.attr, (self.root,) + args, kwargs + ) + class MetaDeviceAttribute(MetaAttribute): pass + def proxys_to_metas(v): if isinstance(v, MetaDeviceAttribute): - return 'meta' + return "meta" if isinstance(v, torch.fx.Proxy): - assert isinstance(v, MetaProxy), f'Expected MetaProxy but got {type(v)}' - assert hasattr(v, '_tensor_meta'), 'MetaProxy does not have an associated meta' + assert isinstance(v, MetaProxy), f"Expected MetaProxy but got {type(v)}" + assert hasattr(v, "_tensor_meta"), "MetaProxy does not have an associated meta" return v._tensor_meta return v + class MetaTracer(torch.fx.Tracer): - allow_insert_stateless_mods : bool = True + allow_insert_stateless_mods: bool = True - _TORCH_METHODS_TO_PATCH = ['arange', 'zeros', 'ones', 'full_like', 'eye'] + _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"] - def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None): - rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) + def create_proxy( + self, + kind, + target, + args, + kwargs, + name=None, + type_expr=None, + proxy_factory_fn=None, + ): + rv = super().create_proxy( + kind, target, args, kwargs, name, type_expr, proxy_factory_fn + ) - if kind == 'placeholder' and target in self.meta_args: + if kind == "placeholder" and target in self.meta_args: rv.install_tensor_meta(self.meta_args[target]) return rv @@ -154,54 +184,57 @@ class MetaTracer(torch.fx.Tracer): # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only, # this will break and you will likely see issues where we cannot infer # the size of the output. 
- if 'device' in kwargs: - kwargs['device'] = 'meta' + if "device" in kwargs: + kwargs["device"] = "meta" try: args_metas = torch.fx.node.map_aggregate(args, proxys_to_metas) kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas) - if kind == 'call_function': + if kind == "call_function": meta_target = manual_meta_overrides.get(target, target) meta_out = meta_target(*args_metas, **kwargs_metas) - elif kind == 'call_method': - meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas) # type: ignore[index] - elif kind == 'call_module': - assert hasattr(self, 'orig_forward') + elif kind == "call_method": + meta_target = getattr(args_metas[0], target) # type: ignore[index] + meta_out = meta_target(*args_metas[1:], **kwargs_metas) # type: ignore[index] + elif kind == "call_module": + assert hasattr(self, "orig_forward") self._disable_module_getattr = True try: mod = self.root.get_submodule(target) mod_type = type(mod) if mod_type in manual_meta_overrides: - meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas) # type: ignore[misc, arg-type] + meta_out = manual_meta_overrides[mod_type]( + mod, *args_metas, **kwargs_metas + ) # type: ignore[misc, arg-type] else: meta_out = self.orig_forward(*args_metas, **kwargs_metas) finally: self._disable_module_getattr = False - elif kind == 'get_attr': + elif kind == "get_attr": self._disable_module_getattr = True try: attr_itr = self.root - atoms = target.split('.') + atoms = target.split(".") for atom in atoms: attr_itr = getattr(attr_itr, atom) assert isinstance(attr_itr, torch.Tensor) - meta_out = attr_itr.to(device='meta') + meta_out = attr_itr.to(device="meta") finally: self._disable_module_getattr = False else: return rv # TODO - assert isinstance(rv, torch.fx.Proxy), 'Dont support composite output yet' + assert isinstance(rv, torch.fx.Proxy), "Dont support composite output yet" rv.install_tensor_meta(meta_out) except Exception as e: - warnings.warn(f'Could not compute metadata for {kind} target {target}: {e}') + warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}") return rv def getattr(self, attr, attr_val, parameter_proxy_cache): - if getattr(self, '_disable_module_getattr', False): + if getattr(self, "_disable_module_getattr", False): return attr_val else: return super().getattr(attr, attr_val, parameter_proxy_cache) @@ -228,7 +261,11 @@ class MetaTracer(torch.fx.Tracer): try: return super().path_of_module(mod) except NameError: - if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0: + if ( + self.allow_insert_stateless_mods + and len(list(mod.parameters())) == 0 + and len(list(mod.buffers())) == 0 + ): path = self._insert_module_as_submodule(mod) self.prev_module = path return path @@ -237,12 +274,13 @@ class MetaTracer(torch.fx.Tracer): def proxy(self, node): return MetaProxy(node, self) - def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None): # type: ignore[override] + def trace(self, root, meta_args: Dict[str, torch.Tensor], concrete_args=None): # type: ignore[override] assert isinstance(meta_args, dict) self.meta_args = meta_args self.patched_torch_methods = { - target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH + target: gen_constructor_wrapper(getattr(torch, target)) + for target in self._TORCH_METHODS_TO_PATCH } self.orig_fns = set() @@ -252,18 +290,22 @@ class MetaTracer(torch.fx.Tracer): try: graph = super().trace(root, concrete_args) - 
graph._tracer_extras = {'meta_args': meta_args} + graph._tracer_extras = {"meta_args": meta_args} return graph finally: for name, (_, orig) in self.patched_torch_methods.items(): setattr(torch, name, orig) -def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]], - meta_args : Optional[Dict[str, torch.Tensor]] = None, - concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule: +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + meta_args: Optional[Dict[str, torch.Tensor]] = None, + concrete_args: Optional[Dict[str, Any]] = None, +) -> torch.fx.GraphModule: tracer = MetaTracer() graph = tracer.trace(root, meta_args, concrete_args) # type: ignore[arg-type] - name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + name = ( + root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + ) gm = torch.fx.GraphModule(tracer.root, graph, name) return gm diff --git a/torch/fx/experimental/migrate_gradual_types/constraint.py b/torch/fx/experimental/migrate_gradual_types/constraint.py index 4693a62de240..8aca3e482c95 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint.py @@ -1,7 +1,16 @@ # mypy: allow-untyped-defs -from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_sub, op_mul, op_div, \ - op_mod, op_gt, op_lt, op_neq, op_eq -from torch.fx.tensor_type import TensorType, Dyn +from torch.fx.experimental.migrate_gradual_types.operation import ( + op_add, + op_div, + op_eq, + op_gt, + op_lt, + op_mod, + op_mul, + op_neq, + op_sub, +) +from torch.fx.tensor_type import Dyn, TensorType class Constraint: @@ -22,7 +31,7 @@ class Conj(Constraint): return False def __repr__(self): - return f'And({self.conjucts})' + return f"And({self.conjucts})" class Disj(Constraint): @@ -34,12 +43,14 @@ class Disj(Constraint): def __eq__(self, other): if isinstance(other, Disj): - return self.disjuncts == other.disjuncts and self.disjuncts == other.disjuncts + return ( + self.disjuncts == other.disjuncts and self.disjuncts == other.disjuncts + ) else: return False def __repr__(self): - return f'Or({self.disjuncts})' + return f"Or({self.disjuncts})" class Prod(Constraint): @@ -56,13 +67,14 @@ class Prod(Constraint): return False def __repr__(self): - return f'Product({self.products})' + return f"Product({self.products})" class T(Constraint): """ True """ + def __init__(self) -> None: pass @@ -70,12 +82,14 @@ class T(Constraint): return isinstance(other, T) def __repr__(self): - return 'True' + return "True" + class F(Constraint): """ False """ + def __init__(self) -> None: pass @@ -83,13 +97,14 @@ class F(Constraint): return isinstance(other, F) def __repr__(self): - return 'False' + return "False" class BinaryConstraint(Constraint): """ Represents all binary operations """ + def __init__(self, lhs, rhs, op): """ :param lhs: lhs of the constraint @@ -102,21 +117,25 @@ class BinaryConstraint(Constraint): def __eq__(self, other): if isinstance(other, BinaryConstraint): - return self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op + return ( + self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op + ) else: return False def __repr__(self): - return f'({self.lhs} {self.op} {self.rhs})' + return f"({self.lhs} {self.op} {self.rhs})" class BinConstraintT(BinaryConstraint): """ Binary constraints about tensors """ + def __init__(self, lhs, rhs, op): - assert (isinstance(lhs, (TVar, TensorType, 
int)) or lhs == Dyn) and \ - (isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn) + assert (isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn) and ( + isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn + ) super().__init__(lhs, rhs, op) def __eq__(self, other): @@ -127,6 +146,7 @@ class BinConstraintD(BinaryConstraint): """ Binary constraints about dimensions """ + def __init__(self, lhs, rhs, op): assert is_algebraic_expression(lhs) or is_dim(lhs) or is_bool_expr(lhs) assert is_algebraic_expression(rhs) or is_dim(rhs) or is_bool_expr(rhs) @@ -137,11 +157,11 @@ class BinConstraintD(BinaryConstraint): return super().__eq__(other) - class TGreatestUpperBound(Constraint): """ Greatest Upper bound for tensors with dynamic type """ + def __init__(self, res, rhs1, rhs2): """ :param res: tensor variable that stores the result of the outout @@ -153,11 +173,15 @@ class TGreatestUpperBound(Constraint): self.rhs2 = rhs2 def __repr__(self): - return f'{self.res} = {self.rhs1}\u2294*{self.rhs2}' + return f"{self.res} = {self.rhs1}\u2294*{self.rhs2}" def __eq__(self, other): if isinstance(other, TGreatestUpperBound): - return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2 + return ( + self.res == other.res + and self.rhs1 == other.rhs1 + and self.rhs2 == other.rhs2 + ) else: return False @@ -166,6 +190,7 @@ class DGreatestUpperBound(Constraint): """ Greatest Upper bound for dimensions """ + def __init__(self, res, rhs1, rhs2): """ :param res: Dimension variable to store the result @@ -181,11 +206,15 @@ class DGreatestUpperBound(Constraint): self.rhs2 = rhs2 def __repr__(self): - return f'{self.res} = {self.rhs1}\u2294{self.rhs2}' + return f"{self.res} = {self.rhs1}\u2294{self.rhs2}" def __eq__(self, other): if isinstance(other, DGreatestUpperBound): - return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2 + return ( + self.res == other.res + and self.rhs1 == other.rhs1 + and self.rhs2 == other.rhs2 + ) else: return False @@ -194,6 +223,7 @@ class CanReshape(Constraint): """ can_reshape constraint """ + def __init__(self, src, target): """ :param src: tensor variable @@ -203,7 +233,7 @@ class CanReshape(Constraint): self.target = target def __repr__(self): - return f'can-reshape({self.src}, {self.target})' + return f"can-reshape({self.src}, {self.target})" def __eq__(self, other): if isinstance(other, CanReshape): @@ -213,7 +243,6 @@ class CanReshape(Constraint): class IndexSelect(Constraint): - def __init__(self, tensor_size, input_var, dim_replace, index, output): """ Args: @@ -235,26 +264,28 @@ class IndexSelect(Constraint): self.output = output def __repr__(self): - - return f' {self.output} = ' \ - f'IndexSelect({self.input_var}, ' \ - f'tensor_size: {self.tensor_size}, ' \ - f'{self.dim_replace}, ' \ - f'{self.index})' + return ( + f" {self.output} = " + f"IndexSelect({self.input_var}, " + f"tensor_size: {self.tensor_size}, " + f"{self.dim_replace}, " + f"{self.index})" + ) def __eq__(self, other): if isinstance(other, IndexSelect): - return self.tensor_size == other.tensor_size and \ - self.dim_replace == other.dim_replace and \ - self.index == other.index and \ - self.output == other.output and \ - self.input_var == other.input_var + return ( + self.tensor_size == other.tensor_size + and self.dim_replace == other.dim_replace + and self.index == other.index + and self.output == other.output + and self.input_var == other.input_var + ) else: return False class Transpose(Constraint): - def __init__(self, tensor_size, input_var, 
index1, index2, output): """ Args: @@ -276,26 +307,28 @@ class Transpose(Constraint): self.output = output def __repr__(self): - - return f' {self.output} = ' \ - f'Transpose({self.input_var}, ' \ - f'tensor_size: {self.tensor_size}, ' \ - f'{self.index1}, ' \ - f'{self.index2})' + return ( + f" {self.output} = " + f"Transpose({self.input_var}, " + f"tensor_size: {self.tensor_size}, " + f"{self.index1}, " + f"{self.index2})" + ) def __eq__(self, other): if isinstance(other, Transpose): - return self.tensor_size == other.tensor_size and \ - self.index1 == other.index1 and \ - self.index2 == other.index2 and \ - self.output == other.output and \ - self.input_var == other.input_var + return ( + self.tensor_size == other.tensor_size + and self.index1 == other.index1 + and self.index2 == other.index2 + and self.output == other.output + and self.input_var == other.input_var + ) else: return False class GetItem(Constraint): - def __init__(self, tensor_size, index, res, input_var): """ Constraint for getting item given a tensor size @@ -312,19 +345,21 @@ class GetItem(Constraint): self.input_var = input_var def __repr__(self): - return f' {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})' + return f" {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})" def __eq__(self, other): if isinstance(other, GetItem): - return self.res == other.res and \ - self.tensor_size == other.tensor_size and \ - self.index == other.index and \ - self.input_var == other.input_var + return ( + self.res == other.res + and self.tensor_size == other.tensor_size + and self.index == other.index + and self.input_var == other.input_var + ) else: return False -class GetItemTensor(Constraint): +class GetItemTensor(Constraint): def __init__(self, tensor_size, index_tuple, res, input_var): """ Constraint for getting item given a tensor size @@ -343,20 +378,32 @@ class GetItemTensor(Constraint): self.input_var = input_var def __repr__(self): - return f' {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})' + return f" {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})" def __eq__(self, other): if isinstance(other, GetItemTensor): - return self.res == other.res and \ - self.tensor_size == other.tensor_size and \ - self.index_tuple == other.index_tuple and \ - self.input_var == other.input_var + return ( + self.res == other.res + and self.tensor_size == other.tensor_size + and self.index_tuple == other.index_tuple + and self.input_var == other.input_var + ) else: return False -class CalcConv(Constraint): - def __init__(self, conv_result, input_var, c_out, kernel, padding, stride, dilation, matching_constraint_vars): +class CalcConv(Constraint): + def __init__( + self, + conv_result, + input_var, + c_out, + kernel, + padding, + stride, + dilation, + matching_constraint_vars, + ): """ :param conv_result: the convolution result :param input_var: input to convolution @@ -373,25 +420,41 @@ class CalcConv(Constraint): self.matching_constraint = matching_constraint_vars def __repr__(self): - return f'{self.conv_result} =' \ - f' calc-conv({self.input_var},' \ - f' {self.c_out}, {self.kernel}, ' \ - f'{self.padding}, {self.stride},' \ - f' {self.dilation})' + return ( + f"{self.conv_result} =" + f" calc-conv({self.input_var}," + f" {self.c_out}, {self.kernel}, " + f"{self.padding}, {self.stride}," + f" {self.dilation})" + ) def __eq__(self, other): if isinstance(other, CalcConv): - return 
self.conv_result == other.conv_result and self.input_var == other.input_var and \ - self.c_out == other.c_out and self.kernel == other.kernel and self.padding == other.padding \ - and self.stride == other.stride and self.dilation == other.dilation \ + return ( + self.conv_result == other.conv_result + and self.input_var == other.input_var + and self.c_out == other.c_out + and self.kernel == other.kernel + and self.padding == other.padding + and self.stride == other.stride + and self.dilation == other.dilation and self.matching_constraint == other.matching_constraint + ) else: return False class CalcMaxPool(Constraint): - - def __init__(self, maxpool_result, input_var, kernel, padding, stride, dilation, matching_constraint_vars): + def __init__( + self, + maxpool_result, + input_var, + kernel, + padding, + stride, + dilation, + matching_constraint_vars, + ): """ :param maxpool_result: the result of maxpool :param input_var: input to maxpool @@ -406,18 +469,25 @@ class CalcMaxPool(Constraint): self.matching_constraint = matching_constraint_vars def __repr__(self): - return f'{self.maxpool_result} =' \ - f' calc-maxpool({self.input_var},' \ - f' {self.kernel}, ' \ - f'{self.padding}, {self.stride},' \ - f' {self.dilation})' + return ( + f"{self.maxpool_result} =" + f" calc-maxpool({self.input_var}," + f" {self.kernel}, " + f"{self.padding}, {self.stride}," + f" {self.dilation})" + ) def __eq__(self, other): if isinstance(other, CalcMaxPool): - return self.maxpool_result == other.maxpool_result and self.input_var == other.input_var \ - and self.kernel == other.kernel and self.padding == other.padding \ - and self.stride == other.stride and self.dilation == other.dilation \ + return ( + self.maxpool_result == other.maxpool_result + and self.input_var == other.input_var + and self.kernel == other.kernel + and self.padding == other.padding + and self.stride == other.stride + and self.dilation == other.dilation and self.matching_constraint == other.matching_constraint + ) else: return False @@ -437,21 +507,28 @@ class ApplyBroadcasting(Constraint): def __eq__(self, other): if isinstance(other, ApplyBroadcasting): - return self.res1 == other.res1 \ - and self.res2 == other.res2 \ - and self.input1 == other.input1 \ + return ( + self.res1 == other.res1 + and self.res2 == other.res2 + and self.input1 == other.input1 and self.input2 == other.input2 + ) else: return False def __repr__(self): - return f'{self.res1}, {self.res2} ='f' apply-broadcasting({self.input1},' f' {self.input2})' + return ( + f"{self.res1}, {self.res2} =" + f" apply-broadcasting({self.input1}," + f" {self.input2})" + ) class CalcProduct(Constraint): """ Given correct dimensions, calculate the product for flatten accounting for Dyn """ + def __init__(self, start, end, flattened, dims_to_flatten): """ :param start: start index @@ -471,20 +548,25 @@ class CalcProduct(Constraint): def __eq__(self, other): if isinstance(other, CalcProduct): - return self.start == other.start and self.end == other.end and \ - self.dims_to_flatten == other.dims_to_flatten and self.flattened == other.flattened + return ( + self.start == other.start + and self.end == other.end + and self.dims_to_flatten == other.dims_to_flatten + and self.flattened == other.flattened + ) else: return False def __repr__(self): - return f'{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})' + return f"{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})" class TVar: """ Tensor variable with no tensor constructor
""" + def __init__(self, tvar): """ :param tvar: tensor variable @@ -492,7 +574,7 @@ class TVar: self.tvar = tvar def __repr__(self): - return f'TV({self.tvar})' + return f"TV({self.tvar})" def __eq__(self, other): if isinstance(other, TVar): @@ -505,6 +587,7 @@ class DVar: """ Dimension variable """ + def __init__(self, c): """ :param c: character or number @@ -512,7 +595,7 @@ class DVar: self.c = c def __repr__(self): - return f'DV({self.c})' + return f"DV({self.c})" def __eq__(self, other): if isinstance(other, DVar): @@ -525,6 +608,7 @@ class BVar: """ Boolean variable """ + def __init__(self, c): """ :param c: character or number @@ -532,7 +616,7 @@ class BVar: self.c = c def __repr__(self): - return f'BV({self.c})' + return f"BV({self.c})" def __eq__(self, other): if isinstance(other, BVar): @@ -554,5 +638,6 @@ def is_bool_expr(constraint): else: return isinstance(constraint, (BVar, Conj, Disj)) + def is_dim(d): return isinstance(d, (DVar, int)) or d == Dyn diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py index 952dde662f2a..de7fd6689451 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py @@ -1,34 +1,71 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs -import torch import operator import warnings from typing import Callable, Dict, Iterable +import torch from torch.fx._symbolic_trace import _assert_is_none -from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, CalcProduct, \ - Disj, TGreatestUpperBound, CalcMaxPool, CalcConv, Conj, BinConstraintT, CanReshape, BinConstraintD, GetItem, T, F, \ - TVar, DVar, GetItemTensor, IndexSelect, Transpose, DGreatestUpperBound -from torch.fx.experimental.migrate_gradual_types.operation import \ - op_eq, op_matching, op_consistency, op_leq, op_precision, op_gt, op_div, op_sub, op_neq, op_lt, op_add, op_mul -from torch.fx.node import Target, Node -from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar, gen_tvar, \ - gen_bvar - +from torch.fx.experimental.migrate_gradual_types.constraint import ( + ApplyBroadcasting, + BinConstraintD, + BinConstraintT, + CalcConv, + CalcMaxPool, + CalcProduct, + CanReshape, + Conj, + DGreatestUpperBound, + Disj, + DVar, + F, + GetItem, + GetItemTensor, + IndexSelect, + T, + TGreatestUpperBound, + Transpose, + TVar, +) +from torch.fx.experimental.migrate_gradual_types.operation import ( + op_add, + op_consistency, + op_div, + op_eq, + op_gt, + op_leq, + op_lt, + op_matching, + op_mul, + op_neq, + op_precision, + op_sub, +) +from torch.fx.experimental.migrate_gradual_types.util import ( + gen_bvar, + gen_dvar, + gen_nat_constraints, + gen_tensor_dims, + gen_tvar, +) +from torch.fx.node import Node, Target from torch.fx.tensor_type import Dyn, TensorType -from torch.nn.modules.conv import Conv2d from torch.nn.modules.batchnorm import BatchNorm2d +from torch.nn.modules.conv import Conv2d + _INFERENCE_RULES: Dict[Target, Callable] = {} MAX_TENSOR_RANK = 4 + def register_inference_rule(call_target): def register(fn): if call_target in _INFERENCE_RULES: - raise RuntimeError(f'Inference rule already registered for {call_target}!') + raise RuntimeError(f"Inference rule already registered for {call_target}!") _INFERENCE_RULES[call_target] = fn return fn + return register @@ -55,10 +92,11 @@ def get_attr_inference_rule(n: Node, 
symbols, constraints, counter): input = symbols[n.args[0]] attr = n.args[1] - if attr == 'device': + if attr == "device": return [BinConstraintT(input, output, op_eq)], counter else: - raise NotImplementedError('Not yet implemented') + raise NotImplementedError("Not yet implemented") + @register_inference_rule(torch.bmm) def bmm_inference_rule(n: Node, symbols, constraints, counter): @@ -79,26 +117,53 @@ def bmm_inference_rule(n: Node, symbols, constraints, counter): dims_input1, counter = gen_tensor_dims(3, counter) dims_input2, counter = gen_tensor_dims(3, counter) - inputs_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq), - BinConstraintT(bmm_input2, Dyn, op_eq), - BinConstraintT(bmm_output, Dyn, op_eq)]) + inputs_dyn = Conj( + [ + BinConstraintT(bmm_input1, Dyn, op_eq), + BinConstraintT(bmm_input2, Dyn, op_eq), + BinConstraintT(bmm_output, Dyn, op_eq), + ] + ) - input1_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq), - BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), - BinConstraintT(bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq)]) + input1_dyn = Conj( + [ + BinConstraintT(bmm_input1, Dyn, op_eq), + BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), + BinConstraintT( + bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq + ), + ] + ) - input2_dyn = Conj([BinConstraintT(bmm_input2, Dyn, op_eq), - BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), - BinConstraintT(bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq)]) + input2_dyn = Conj( + [ + BinConstraintT(bmm_input2, Dyn, op_eq), + BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), + BinConstraintT( + bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq + ), + ] + ) - consistency_constraints = [BinConstraintD(dims_input1[0], dims_input2[0], op_consistency)] + consistency_constraints = [ + BinConstraintD(dims_input1[0], dims_input2[0], op_consistency) + ] batch_size, counter = gen_dvar(counter) - inputs_are_tensors = Conj([BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), - BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), - BinConstraintT(bmm_output, TensorType([batch_size, dims_input1[1], dims_input2[2]]), op_eq), - *consistency_constraints, DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0])]) + inputs_are_tensors = Conj( + [ + BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), + BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), + BinConstraintT( + bmm_output, + TensorType([batch_size, dims_input1[1], dims_input2[2]]), + op_eq, + ), + *consistency_constraints, + DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0]), + ] + ) return [Disj([inputs_dyn, input1_dyn, input2_dyn, inputs_are_tensors])], counter @@ -115,8 +180,6 @@ def index_select_inference_rule(n: Node, symbols, constraints, counter): assert isinstance(n.args[1], int) assert isinstance(n.args[2], Node) - - index_select, counter = gen_tvar(counter) symbols[n] = index_select @@ -126,10 +189,30 @@ def index_select_inference_rule(n: Node, symbols, constraints, counter): is_size_1 = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq) is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq) - c2 = Conj([is_size_1, Disj([IndexSelect(i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select) - for i in range(MAX_TENSOR_RANK)])]) - c3 = Conj([is_dyn, Disj([IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select) - for i in range(MAX_TENSOR_RANK)])]) + c2 = Conj( + [ + is_size_1, + Disj( 
+ [ + IndexSelect( + i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select + ) + for i in range(MAX_TENSOR_RANK) + ] + ), + ] + ) + c3 = Conj( + [ + is_dyn, + Disj( + [ + IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select) + for i in range(MAX_TENSOR_RANK) + ] + ), + ] + ) return [Disj([c2, c3])], counter @@ -158,14 +241,27 @@ def expand_inference_rule(n: Node, symbols, constraints, counter): assert isinstance(symbols[arg], DVar) e2_nat_constraints.append(BinConstraintD(0, symbols[arg], op_leq)) - e2_constraint = BinConstraintT(e2, TensorType([arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]]), op_eq) + e2_constraint = BinConstraintT( + e2, + TensorType( + [arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]] + ), + op_eq, + ) - constraints, counter = gen_broadcasting_constraints(e1, e2, symbols, counter, expand) + constraints, counter = gen_broadcasting_constraints( + e1, e2, symbols, counter, expand + ) # constrain the output size dims, counter = gen_tensor_dims(len(n.args[1:]), counter) nat_constraints = gen_nat_constraints(dims) - c = [BinConstraintT(expand, TensorType(dims), op_eq), *nat_constraints, e2_constraint, *e2_nat_constraints] + c = [ + BinConstraintT(expand, TensorType(dims), op_eq), + *nat_constraints, + e2_constraint, + *e2_nat_constraints, + ] constraints += c return constraints, counter @@ -206,7 +302,7 @@ def equality_inference_rule(n: Node, symbols, constraints, counter): my_size = [symbols[arg] for arg in n.args[0]] return [BinConstraintT(output, TensorType(my_size), op_eq)], counter else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") @register_inference_rule("transpose") @@ -225,10 +321,17 @@ def transpose_inference_rule(n: Node, symbols, constraints, counter): assert isinstance(from_arg, TVar) # input and output are dyn - is_dyn = Conj([BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)]) + is_dyn = Conj( + [BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)] + ) # or input is a tensor and we actually do the replacement - c3 = Disj([Transpose(i + 1, from_arg, n.args[1], n.args[2], output) for i in range(MAX_TENSOR_RANK)]) + c3 = Disj( + [ + Transpose(i + 1, from_arg, n.args[1], n.args[2], output) + for i in range(MAX_TENSOR_RANK) + ] + ) return [Disj([is_dyn, c3])], counter @@ -250,8 +353,11 @@ def type_inference_rule(n: Node, symbols, constraints, counter): assert isinstance(from_arg, TVar) assert isinstance(to_arg, TVar) - return [BinConstraintT(from_arg, to_arg, op_consistency), - BinConstraintT(output, to_arg, op_eq)], counter + return [ + BinConstraintT(from_arg, to_arg, op_consistency), + BinConstraintT(output, to_arg, op_eq), + ], counter + @register_inference_rule("masked_fill_") def masked_fill_inference_rule(n: Node, symbols, constraints, counter): @@ -273,9 +379,11 @@ def masked_fill_inference_rule(n: Node, symbols, constraints, counter): if isinstance(e1, TVar) and isinstance(e2, TVar): masked_fill_tensor, counter = gen_tvar(counter) symbols[n] = masked_fill_tensor - return gen_broadcasting_constraints(e1, e2, symbols, counter, masked_fill_tensor) + return gen_broadcasting_constraints( + e1, e2, symbols, counter, masked_fill_tensor + ) else: - raise NotImplementedError('Not yet implemented') + raise NotImplementedError("Not yet implemented") @register_inference_rule(torch.nn.functional.embedding) @@ -286,7 +394,9 @@ def embedding_inference_rule_functional(n: Node, symbols,
constraints, counter): # will treat this as a static shape. So we will not use matching. weight_dims, counter = gen_tensor_dims(2, counter) - equality_constraint = BinConstraintT(embedding_dim_weights, TensorType(weight_dims), op_eq) + equality_constraint = BinConstraintT( + embedding_dim_weights, TensorType(weight_dims), op_eq + ) embedding_dim = weight_dims[1] constraints, counter = gen_embedding_rules(n, symbols, embedding_dim, counter) return [equality_constraint] + constraints, counter @@ -302,7 +412,6 @@ def embedding_inference_rule(n: Node, module_instance, symbols, constraints, cou def gen_embedding_rules(n: Node, symbols, embedding_dim, counter): - embedding_output, counter = gen_tvar(counter) symbols[n] = embedding_output embedding_input = symbols[n.args[0]] @@ -318,9 +427,15 @@ def gen_embedding_rules(n: Node, symbols, embedding_dim, counter): nat_constraints = gen_nat_constraints(new_dims) # we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases - c_tensor_i = Conj([BinConstraintT(embedding_input, TensorType(new_dims), op_eq), - BinConstraintT(embedding_output, TensorType(new_dims + [embedding_dim]), op_eq)] + - nat_constraints) + c_tensor_i = Conj( + [ + BinConstraintT(embedding_input, TensorType(new_dims), op_eq), + BinConstraintT( + embedding_output, TensorType(new_dims + [embedding_dim]), op_eq + ), + ] + + nat_constraints + ) c2.append(c_tensor_i) return [Disj([c1, Disj(c2)])], counter @@ -348,9 +463,10 @@ def view_inference_rule(n: Node, symbols, constraints, counter): my_view, counter = gen_tvar(counter) symbols[n] = my_view - src_var = symbols[n.args[0]] - t2 = [symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:]] # target shape + t2 = [ + symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:] + ] # target shape t2_type = [] num_constraints = [] @@ -382,7 +498,6 @@ def size_inference_rule(n: Node, symbols, constraints, counter): Ex: size = input_ids.size() """ - if len(n.args) == 1: # generate the new variable size, counter = gen_tvar(counter) @@ -398,7 +513,10 @@ def size_inference_rule(n: Node, symbols, constraints, counter): size_index, counter = gen_dvar(counter) symbols[n] = size_index input = symbols[n.args[0]] - c2 = [GetItem(i + 1, n.args[1], size_index, input) for i in range(MAX_TENSOR_RANK)] + c2 = [ + GetItem(i + 1, n.args[1], size_index, input) + for i in range(MAX_TENSOR_RANK) + ] c3 = BinConstraintD(0, size_index, op_leq) input_dyn = BinConstraintT(input, Dyn, op_eq) @@ -452,9 +570,14 @@ def cumsum_inference_rule(n: Node, symbols, constraints, counter): nat_constraints = gen_nat_constraints(new_dims) - c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims), op_eq), - BinConstraintT(output, TensorType(new_dims), op_eq)] + - [range_check(arg_1, i)] + nat_constraints) + c_tensor_i = Conj( + [ + BinConstraintT(input, TensorType(new_dims), op_eq), + BinConstraintT(output, TensorType(new_dims), op_eq), + ] + + [range_check(arg_1, i)] + + nat_constraints + ) c2.append(c_tensor_i) dyn_or_tensor = Disj([c1, Disj(c2)]) @@ -481,7 +604,6 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter): get_item_arg = symbols[n.args[0]] assert isinstance(get_item_arg, TVar) - # if the input is dynamic, we accept any index and return # a dynamic dimension as output input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq) @@ -492,8 +614,10 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter): # generate a getItem constraint which will be expanded based on the # 
tensor dimension. - c2 = [GetItem(i + 1, n.args[1], get_item_output, get_item_arg) for i in range(MAX_TENSOR_RANK)] - + c2 = [ + GetItem(i + 1, n.args[1], get_item_output, get_item_arg) + for i in range(MAX_TENSOR_RANK) + ] # since the output is a dimension, we make sure it's a natural number # added as a conjunction to the disjunction of c2 @@ -515,8 +639,10 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter): output_dyn = BinConstraintT(get_item_output, Dyn, op_eq) # type: ignore[assignment] c1 = Conj([input_dyn, output_dyn]) - c2 = [GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg) # type: ignore[misc] - for i in range(MAX_TENSOR_RANK)] + c2 = [ + GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg) # type: ignore[misc] + for i in range(MAX_TENSOR_RANK) + ] else: # TODO: we should figure out why there is a key-error here. return [], counter @@ -524,7 +650,7 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter): return [Disj([c1, *c2])], counter else: - raise RuntimeError('Method not yet implemented') + raise RuntimeError("Method not yet implemented") @register_inference_rule(operator.gt) @@ -553,7 +679,7 @@ def gt_inference_rule(n: Node, symbols, constraints, counter): return [equality_constraint], counter else: - raise RuntimeError('Sort Mismatch') + raise RuntimeError("Sort Mismatch") elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): if isinstance(e1, DVar): @@ -567,7 +693,9 @@ def gt_inference_rule(n: Node, symbols, constraints, counter): elif isinstance(e1, TVar) and isinstance(e2, int): # then we made the wrong assumption about the argument being a tensor # so we should fix the assumption - warnings.warn(f'Made the wrong assumption for node {n}. Correctness not guaranteed.') + warnings.warn( + f"Made the wrong assumption for node {n}. Correctness not guaranteed." 
+ ) new_e1, counter = gen_dvar(counter) symbols[n.args[0]] = new_e1 @@ -580,10 +708,10 @@ def gt_inference_rule(n: Node, symbols, constraints, counter): return [equality_constraint], counter else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") @register_inference_rule(operator.eq) @@ -609,7 +737,7 @@ def eq_inference_rule(n: Node, symbols, constraints, counter): return [equality_constraint], counter else: - raise RuntimeError('Sort Mismatch') + raise RuntimeError("Sort Mismatch") elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): if isinstance(e1, DVar): @@ -620,9 +748,10 @@ def eq_inference_rule(n: Node, symbols, constraints, counter): equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq) return [equality_constraint], counter else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") + @register_inference_rule(operator.ne) def neq_inference_rule(n: Node, symbols, constraints, counter): @@ -641,7 +770,6 @@ def neq_inference_rule(n: Node, symbols, constraints, counter): # implementing for size 3 and 4 if len(n.args[1]) == 3: - assert isinstance(n.args[1][0], (Node, int)) assert isinstance(n.args[1][1], (Node, int)) assert isinstance(n.args[1][2], (Node, int)) @@ -662,11 +790,19 @@ def neq_inference_rule(n: Node, symbols, constraints, counter): neq_3 = BinConstraintD(d3, b[2], op_neq) # dimensions inconsistent - dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1]) - dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2]) - dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3]) + dims_inconsistent1 = Conj( + [BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1] + ) + dims_inconsistent2 = Conj( + [BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2] + ) + dims_inconsistent3 = Conj( + [BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3] + ) - dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3]) + dims_inconsistent = Disj( + [dims_inconsistent1, dims_inconsistent2, dims_inconsistent3] + ) # we are covering size 3 and 4 only for now ne_constraint = Conj([input_is_size3, dims_inconsistent]) @@ -675,7 +811,6 @@ def neq_inference_rule(n: Node, symbols, constraints, counter): equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq) elif len(n.args[1]) == 4: - assert isinstance(n.args[1][0], (Node, int)) assert isinstance(n.args[1][1], (Node, int)) assert isinstance(n.args[1][2], (Node, int)) @@ -703,12 +838,27 @@ def neq_inference_rule(n: Node, symbols, constraints, counter): neq_4 = BinConstraintD(d4, b4, op_neq) # dimensions are inconsistent - dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1]) - dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2]) - dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3]) - dims_inconsistent4 = Conj([BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_4]) +
dims_inconsistent1 = Conj( + [BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1] + ) + dims_inconsistent2 = Conj( + [BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2] + ) + dims_inconsistent3 = Conj( + [BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3] + ) + dims_inconsistent4 = Conj( + [BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq), neq_4] + ) - dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3, dims_inconsistent4]) + dims_inconsistent = Disj( + [ + dims_inconsistent1, + dims_inconsistent2, + dims_inconsistent3, + dims_inconsistent4, + ] + ) ne_constraint = Conj([input_is_size4, dims_inconsistent]) @@ -717,7 +867,7 @@ def neq_inference_rule(n: Node, symbols, constraints, counter): equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq) else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") return [equality_constraint], counter @@ -748,7 +898,7 @@ def lt_inference_rule(n: Node, symbols, constraints, counter): return [equality_constraint], counter else: - raise RuntimeError('Sort Mismatch') + raise RuntimeError("Sort Mismatch") elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): if isinstance(e1, DVar): @@ -759,10 +909,10 @@ def lt_inference_rule(n: Node, symbols, constraints, counter): equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq) return [equality_constraint], counter else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") @register_inference_rule(torch.full) @@ -788,28 +938,42 @@ def arange_inference_rule(n: Node, symbols, constraints, counter): if len(n.args) == 1: end = symbols[n.args[0]] else: - raise NotImplementedError('Not yet implemented') + raise NotImplementedError("Not yet implemented") # int((end - start) / step) d1, counter = gen_dvar(counter) - size_constraint = BinConstraintD(d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq) + size_constraint = BinConstraintD( + d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq + ) arange, counter = gen_tvar(counter) symbols[n] = arange # either a parameter is a number or it is Dyn - c1 = Disj([BinConstraintD(end, Dyn, op_eq), - BinConstraintD(start, Dyn, op_eq), - BinConstraintD(step, Dyn, op_eq)]) + c1 = Disj( + [ + BinConstraintD(end, Dyn, op_eq), + BinConstraintD(start, Dyn, op_eq), + BinConstraintD(step, Dyn, op_eq), + ] + ) c2 = BinConstraintD(d1, Dyn, op_eq) both_dyn = Conj([c1, c2]) - c11 = Conj([BinConstraintD(end, Dyn, op_neq), - BinConstraintD(start, Dyn, op_neq), - BinConstraintD(step, Dyn, op_neq)]) + c11 = Conj( + [ + BinConstraintD(end, Dyn, op_neq), + BinConstraintD(start, Dyn, op_neq), + BinConstraintD(step, Dyn, op_neq), + ] + ) c22 = BinConstraintD(d1, Dyn, op_neq) both_numbers = Conj([c11, c22, size_constraint]) - return [BinConstraintT(arange, TensorType([d1]), op_eq), Disj([both_dyn, both_numbers])], counter + return [ + BinConstraintT(arange, TensorType([d1]), op_eq), + Disj([both_dyn, both_numbers]), + ], counter + def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var): # additional vars that don't correspond to expressions @@ -829,7 +993,6 @@
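Note on the inference-rule hunks that follow: each rule maps one FX node onto logical constraints over tensor variables (TVar) and dimension variables (DVar). A minimal sketch of driving these rules end to end, assuming a symbolically traceable module; the module and variable names are illustrative, and the ConstraintGenerator constructor is assumed to accept just the traced module (its full signature lies outside this hunk)::

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.experimental.migrate_gradual_types.constraint_generator import (
        ConstraintGenerator,
    )

    class AddModule(torch.nn.Module):
        def forward(self, x, y):
            # torch.add is handled by broadcasting_inference_rule below
            return torch.add(x, y)

    traced = symbolic_trace(AddModule())
    generator = ConstraintGenerator(traced)  # assumed constructor usage
    # generate_constraints walks the graph nodes, dispatching each one to its
    # registered rule, and returns the accumulated constraints together with
    # the updated fresh-variable counter.
    constraints, counter = generator.generate_constraints(counter=0)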
@register_inference_rule(torch.add) @register_inference_rule(operator.add) def broadcasting_inference_rule(n: Node, symbols, constraints, counter): - op_code = None if n.target == operator.add or n.target == torch.add: op_code = op_add @@ -837,7 +1000,9 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter): op_code = op_mul if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): - if isinstance(symbols[n.args[0]], TVar) and isinstance(symbols[n.args[1]], TVar): + if isinstance(symbols[n.args[0]], TVar) and isinstance( + symbols[n.args[1]], TVar + ): my_output, counter = gen_tvar(counter) symbols[n] = my_output e1 = symbols[n.args[0]] @@ -845,7 +1010,7 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter): return gen_broadcasting_constraints(e1, e2, symbols, counter, my_output) else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") elif isinstance(n.args[0], Node) and isinstance(n.args[1], (int, float)): if isinstance(symbols[n.args[0]], TVar): @@ -859,8 +1024,14 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter): e1 = symbols[n.args[0]] # we will propagate the runtime value here since this is regular addition - c = Conj([BinConstraintD(my_output, BinConstraintD(e1, n.args[1], op_code), op_eq), - BinConstraintD(0, my_output, op_leq)]) + c = Conj( + [ + BinConstraintD( + my_output, BinConstraintD(e1, n.args[1], op_code), op_eq + ), + BinConstraintD(0, my_output, op_leq), + ] + ) return [c], counter elif isinstance(n.args[1], Node) and isinstance(n.args[0], (int, float)): @@ -875,16 +1046,22 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter): e2 = symbols[n.args[1]] # we will propagate the runtime value here since this is regular addition - c = Conj([BinConstraintD(my_output, BinConstraintD(e2, n.args[0], op_code), op_eq), - BinConstraintD(0, my_output, op_leq)]) + c = Conj( + [ + BinConstraintD( + my_output, BinConstraintD(e2, n.args[0], op_code), op_eq + ), + BinConstraintD(0, my_output, op_leq), + ] + ) return [c], counter else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") else: # TODO generate add constraints for scalar addition - raise NotImplementedError('Addition not yet implemented') + raise NotImplementedError("Addition not yet implemented") @register_inference_rule(torch.flatten) @@ -915,7 +1092,9 @@ def flatten_inference_rule(n: Node, symbols, constraints, counter): const = [] for i in range(1, MAX_TENSOR_RANK + 1): - c, counter = generate_flatten_constraints(start_dim, end_dim, input, flattened, i, counter) + c, counter = generate_flatten_constraints( + start_dim, end_dim, input, flattened, i, counter + ) const.append(c) return [Disj([both_dyn, *const])], counter @@ -937,7 +1116,9 @@ def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, co Input should be consistent with the normalized_shape """ assert isinstance(n.args[0], Node) - return gen_layer_norm_constraints(n, module_instance.normalized_shape, symbols, counter) + return gen_layer_norm_constraints( + n, module_instance.normalized_shape, symbols, counter + ) def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter): @@ -955,13 +1136,18 @@ def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter): new_dims_rhs, counter = gen_tensor_dims(i, counter) nat_constraints = gen_nat_constraints(new_dims_rhs) - c_tensor_i = 
Conj([BinConstraintT(input, TensorType(new_dims_rhs), op_eq), - BinConstraintT(output, TensorType(new_dims_rhs), op_eq)] + - add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) + - nat_constraints) + c_tensor_i = Conj( + [ + BinConstraintT(input, TensorType(new_dims_rhs), op_eq), + BinConstraintT(output, TensorType(new_dims_rhs), op_eq), + ] + + add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) + + nat_constraints + ) c2.append(c_tensor_i) return [Disj([c1, Disj(c2)])], counter + @register_inference_rule(torch.nn.Dropout) @register_inference_rule(torch.nn.ReLU) def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter): @@ -983,7 +1169,9 @@ def linear_inference_rule(n: Node, module_instance, symbols, constraints, counte If the input is Dyn, then so should the output """ assert isinstance(n.args[0], Node) - return linear_constraints(n, module_instance.in_features, module_instance.out_features, symbols, counter) + return linear_constraints( + n, module_instance.in_features, module_instance.out_features, symbols, counter + ) @register_inference_rule("dim") # type: ignore[attr-defined] @@ -1001,8 +1189,12 @@ def torch_dim_inference_rule(n: Node, symbols, constraints, counter): for i in range(1, MAX_TENSOR_RANK + 1): new_dims_rhs_1, counter = gen_tensor_dims(i, counter) - c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq), - BinConstraintD(my_dim, i, op_eq)]) + c_tensor_i = Conj( + [ + BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq), + BinConstraintD(my_dim, i, op_eq), + ] + ) c1.append(c_tensor_i) return [Disj([Conj([input_dyn, output_dyn]), Disj(c1)])], counter @@ -1012,8 +1204,12 @@ def torch_dim_inference_rule(n: Node, symbols, constraints, counter): def torch_linear_inference_rule(n: Node, symbols, constraints, counter): assert isinstance(n.args[0], Node) weight_dims, counter = gen_tensor_dims(2, counter) - equality_constraint = BinConstraintT(symbols[n.args[1]], TensorType(weight_dims), op_eq) - constraints, counter = linear_constraints(n, weight_dims[1], weight_dims[0], symbols, counter) + equality_constraint = BinConstraintT( + symbols[n.args[1]], TensorType(weight_dims), op_eq + ) + constraints, counter = linear_constraints( + n, weight_dims[1], weight_dims[0], symbols, counter + ) return [equality_constraint] + constraints, counter @@ -1034,13 +1230,20 @@ def linear_constraints(n: Node, in_features, out_features, symbols, counter): nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2) - c_tensor_i = Conj([BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq), - BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq)] + - add_linear_constraints(new_dims_rhs_1, new_dims_rhs_2, in_features, out_features) + - nat_constraints) + c_tensor_i = Conj( + [ + BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq), + BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq), + ] + + add_linear_constraints( + new_dims_rhs_1, new_dims_rhs_2, in_features, out_features + ) + + nat_constraints + ) c2.append(c_tensor_i) return [Disj([c1, Disj(c2)])], counter + def add_layer_norm_constraints(input_dim, normalized_dim): """ The constraints say that the type has the form: [*, 1024, 1024] @@ -1130,7 +1333,13 @@ def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, coun d4, counter = gen_dvar(counter) nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) - c2 =
BinConstraintT(avg_pool, TensorType([d1, d2, module_instance.output_size[0], module_instance.output_size[1]]), op_eq) + c2 = BinConstraintT( + avg_pool, + TensorType( + [d1, d2, module_instance.output_size[0], module_instance.output_size[1]] + ), + op_eq, + ) return [c1, c2, *nat_constraints], counter @@ -1152,12 +1361,16 @@ def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counte # c2 = DConsistency(module_instance.in_channels, d2) c2 = BinConstraintD(module_instance.in_channels, d2, op_consistency) - c3 = CalcConv(my_conv, input_var, - module_instance.out_channels, - module_instance.kernel_size, - module_instance.padding, - module_instance.stride, - module_instance.dilation, [d1, d2, d3, d4]) + c3 = CalcConv( + my_conv, + input_var, + module_instance.out_channels, + module_instance.kernel_size, + module_instance.padding, + module_instance.stride, + module_instance.dilation, + [d1, d2, d3, d4], + ) nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) @@ -1176,8 +1389,15 @@ def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, count c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) - c2 = CalcMaxPool(maxpool, input_var, module_instance.kernel_size, module_instance.padding, - module_instance.stride, module_instance.dilation, [d1, d2, d3, d4]) + c2 = CalcMaxPool( + maxpool, + input_var, + module_instance.kernel_size, + module_instance.padding, + module_instance.stride, + module_instance.dilation, + [d1, d2, d3, d4], + ) nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) @@ -1190,8 +1410,7 @@ class ConstraintGenerator: self.traced_params = dict(self.traced.named_parameters()) self.constraints = [] self.symbol_dict = {} - self.graph = traced.graph if hasattr(traced, 'graph') else graph - + self.graph = traced.graph if hasattr(traced, "graph") else graph def generate_constraints(self, counter=0): """ @@ -1217,7 +1436,7 @@ class ConstraintGenerator: - conv2d """ - if n.op == 'placeholder': + if n.op == "placeholder": x, counter = gen_tvar(counter) self.symbol_dict[n] = x @@ -1226,8 +1445,8 @@ class ConstraintGenerator: if n.type != Dyn and (not isinstance(n.type, TensorType)): if n.type == torch.nn.parameter.Parameter: # since we have a parameter, the shape must be static - assert 'example_value' in n.meta - my_type = TensorType(n.meta['example_value'].size()) + assert "example_value" in n.meta + my_type = TensorType(n.meta["example_value"].size()) else: my_type = Dyn @@ -1235,30 +1454,38 @@ class ConstraintGenerator: c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq) return [c1, c2], counter - elif n.op == 'call_function': + elif n.op == "call_function": if n.target in _INFERENCE_RULES: - return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter) + return _INFERENCE_RULES[n.target]( + n, self.symbol_dict, self.constraints, counter + ) else: - raise RuntimeError(f'No inference rule registered for target {n.target}!') - - elif n.op == 'call_module': + raise RuntimeError( + f"No inference rule registered for target {n.target}!" 
+ ) + elif n.op == "call_module": module_instance = self.traced.get_submodule(n.target) if type(module_instance) in _INFERENCE_RULES: - return _INFERENCE_RULES[type(module_instance)](n, - module_instance, - self.symbol_dict, - self.constraints, counter) + return _INFERENCE_RULES[type(module_instance)]( + n, module_instance, self.symbol_dict, self.constraints, counter + ) else: - raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!') + raise RuntimeError( + f"No inference rule registered for class {type(module_instance)}!" + ) - elif n.op == 'call_method': + elif n.op == "call_method": if n.target in _INFERENCE_RULES: - return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter) + return _INFERENCE_RULES[n.target]( + n, self.symbol_dict, self.constraints, counter + ) else: - raise RuntimeError(f'No inference rule registered for target {n.target}!') + raise RuntimeError( + f"No inference rule registered for target {n.target}!" + ) - elif n.op == 'get_attr': + elif n.op == "get_attr": t = self.traced_params.get(n.target, None) if isinstance(t, torch.Tensor): @@ -1274,7 +1501,7 @@ class ConstraintGenerator: else: return [], counter - elif n.op == 'output': + elif n.op == "output": return [], counter else: diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py index a784495a1d81..7a854b1dabe8 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py @@ -1,30 +1,67 @@ # mypy: ignore-errors import copy import itertools -from torch.fx.experimental.migrate_gradual_types.constraint_generator import BinConstraintT, MAX_TENSOR_RANK -from torch.fx.experimental.migrate_gradual_types.constraint import T, BinConstraintD, Conj, Constraint, DVar, TVar, \ - Transpose -from torch.fx.experimental.migrate_gradual_types.constraint import Disj, TGreatestUpperBound -from torch.fx.experimental.migrate_gradual_types.constraint import DGreatestUpperBound -from torch.fx.experimental.migrate_gradual_types.constraint import CalcConv, CalcMaxPool -from torch.fx.experimental.migrate_gradual_types.constraint import CalcProduct, CanReshape -from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, Prod, F, GetItem, GetItemTensor, IndexSelect -from torch.fx.experimental.migrate_gradual_types.operation import op_eq, op_precision, op_leq, op_matching -from torch.fx.experimental.migrate_gradual_types.operation import op_consistency, op_neq -from torch.fx.experimental.migrate_gradual_types.operation import op_mul, op_add, op_sub, op_div, op_mod -from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar -from torch.fx.tensor_type import TensorType, Dyn from typing import Callable, Dict, List +from torch.fx.experimental.migrate_gradual_types.constraint import ( + ApplyBroadcasting, + BinConstraintD, + CalcConv, + CalcMaxPool, + CalcProduct, + CanReshape, + Conj, + Constraint, + DGreatestUpperBound, + Disj, + DVar, + F, + GetItem, + GetItemTensor, + IndexSelect, + Prod, + T, + TGreatestUpperBound, + Transpose, + TVar, +) +from torch.fx.experimental.migrate_gradual_types.constraint_generator import ( + BinConstraintT, + MAX_TENSOR_RANK, +) +from torch.fx.experimental.migrate_gradual_types.operation import ( + op_add, + op_consistency, + op_div, + op_eq, + op_leq, + op_matching, + op_mod, + 
op_mul, + op_neq, + op_precision, + op_sub, +) +from torch.fx.experimental.migrate_gradual_types.util import ( + gen_dvar, + gen_nat_constraints, + gen_tensor_dims, +) +from torch.fx.tensor_type import Dyn, TensorType + + _TRANSFORMATION_RULES: Dict[Constraint, Callable] = {} def register_transformation_rule(call_target): def register(fn): if call_target in _TRANSFORMATION_RULES: - raise RuntimeError(f'Transformation rule already registered for {call_target}!') + raise RuntimeError( + f"Transformation rule already registered for {call_target}!" + ) _TRANSFORMATION_RULES[call_target] = fn return fn + return register @@ -54,10 +91,15 @@ def transform_transpose(constraint, counter): new_dims[constraint.index1] = dims[constraint.index2] new_dims[constraint.index2] = dims[constraint.index1] - transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq), - *nat_constraints, - is_valid_index1, is_valid_index2, - BinConstraintT(constraint.output, TensorType(new_dims), op_eq)]) + transformed_constraint = Conj( + [ + BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index1, + is_valid_index2, + BinConstraintT(constraint.output, TensorType(new_dims), op_eq), + ] + ) return transformed_constraint, counter @@ -78,10 +120,14 @@ def transform_index_select(constraint, counter): new_dims = copy.deepcopy(dims) new_dims[constraint.index] = constraint.dim_replace - transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq), - *nat_constraints, - is_valid_index, - BinConstraintT(constraint.output, TensorType(new_dims), op_eq)]) + transformed_constraint = Conj( + [ + BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index, + BinConstraintT(constraint.output, TensorType(new_dims), op_eq), + ] + ) # print(constraints) return transformed_constraint, counter @@ -106,20 +152,24 @@ def transform_get_item(constraint, counter): dims, counter = gen_tensor_dims(constraint.tensor_size, counter) nat_constraints = gen_nat_constraints(dims) - is_valid_index = valid_index(constraint.index, dims) - all_constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq), - *nat_constraints, - is_valid_index] + all_constraints = [ + BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index, + ] # if the index is valid, we generate a constraint for getting an item # otherwise this clause will have been UNSAT due to the wrong index if is_valid_index == T(): - all_constraints.append(BinConstraintD(constraint.res, dims[constraint.index], op_eq)) + all_constraints.append( + BinConstraintD(constraint.res, dims[constraint.index], op_eq) + ) return Conj(all_constraints), counter + def valid_index_tensor(index, dims): """ if the slice instances exceed the length of the dimensions @@ -134,6 +184,7 @@ def valid_index_tensor(index, dims): else: return T() + @register_transformation_rule(GetItemTensor) def transform_get_item_tensor(constraint, counter): """ @@ -151,7 +202,6 @@ def transform_get_item_tensor(constraint, counter): """ assert isinstance(constraint.index_tuple, tuple) - # generate a result tensor of the expected size dims, counter = gen_tensor_dims(constraint.tensor_size, counter) nat_constraints = gen_nat_constraints(dims) @@ -163,7 +213,6 @@ def transform_get_item_tensor(constraint, counter): dim_index = 0 for i in range(len(constraint.index_tuple)): - # append 1 to the right location of the resulting tensor if 
constraint.index_tuple[i] is None: resulting_tensor_dims[i] = 1 @@ -172,7 +221,7 @@ def transform_get_item_tensor(constraint, counter): pass else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") # append the remaining dimensions to the right location dim_index = 0 @@ -189,10 +238,12 @@ def transform_get_item_tensor(constraint, counter): return F(), counter else: - constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq), - BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq), - *nat_constraints, - is_valid_index] + constraints = [ + BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq), + *nat_constraints, + is_valid_index, + ] return Conj(constraints), counter @@ -217,11 +268,14 @@ def generate_binconstraint_t(constraint, counter): dim, counter = gen_dvar(counter) new_dims.append(dim) - new_dim_constraints = [BinConstraintD(old_dim, new_dim, op_precision) for - new_dim, old_dim in zip(new_dims, constraint.lhs.__args__)] + \ - [BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)] + \ - [BinConstraintD(1, new_dim, op_leq) for - new_dim in new_dims] + new_dim_constraints = ( + [ + BinConstraintD(old_dim, new_dim, op_precision) + for new_dim, old_dim in zip(new_dims, constraint.lhs.__args__) + ] + + [BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)] + + [BinConstraintD(1, new_dim, op_leq) for new_dim in new_dims] + ) return Conj(new_dim_constraints), counter # matching @@ -232,17 +286,39 @@ def generate_binconstraint_t(constraint, counter): d3 = constraint.rhs.__args__[2] d4 = constraint.rhs.__args__[3] - conj = [BinConstraintT(constraint.lhs, Dyn, op_eq), - BinConstraintD(d1, Dyn, op_eq), - BinConstraintD(d2, Dyn, op_eq), - BinConstraintD(d3, Dyn, op_eq), - BinConstraintD(d4, Dyn, op_eq)] - return Disj([Conj(conj), - BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq)]), counter + conj = [ + BinConstraintT(constraint.lhs, Dyn, op_eq), + BinConstraintD(d1, Dyn, op_eq), + BinConstraintD(d2, Dyn, op_eq), + BinConstraintD(d3, Dyn, op_eq), + BinConstraintD(d4, Dyn, op_eq), + ] + return ( + Disj( + [ + Conj(conj), + BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq), + ] + ), + counter, + ) elif constraint.op == op_consistency: - c_dyn = Disj([BinConstraintT(constraint.lhs, Dyn, op_eq), BinConstraintT(constraint.rhs, Dyn, op_eq)]) - [c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4], counter = gen_consistency_constraints(constraint, counter) + c_dyn = Disj( + [ + BinConstraintT(constraint.lhs, Dyn, op_eq), + BinConstraintT(constraint.rhs, Dyn, op_eq), + ] + ) + ( + ( + c_tensor_1, + c_tensor_2, + c_tensor_3, + c_tensor_4, + ), + counter, + ) = gen_consistency_constraints(constraint, counter) return Disj([c_dyn, c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4]), counter @@ -272,8 +348,16 @@ def generate_binconstraint_d(constraint, counter): return T(), counter elif constraint.op == op_consistency: - return Disj([BinConstraintD(constraint.lhs, constraint.rhs, op_eq), - BinConstraintD(constraint.rhs, Dyn, op_eq), BinConstraintD(constraint.lhs, Dyn, op_eq)]), counter + return ( + Disj( + [ + BinConstraintD(constraint.lhs, constraint.rhs, op_eq), + BinConstraintD(constraint.rhs, Dyn, op_eq), + BinConstraintD(constraint.lhs, Dyn, op_eq), + ] + ), + counter, + ) else: return constraint, counter @@ -309,8 +393,17 @@ def generate_gub(constraint, counter): Transform 
greatest upper bound for tensors. Results in equality and Greatest Upper Bound on dimensions """ - c1 = Conj([Disj([BinConstraintT(constraint.rhs1, Dyn, op_eq), - BinConstraintT(constraint.rhs2, Dyn, op_eq)]), BinConstraintT(constraint.res, Dyn, op_eq)]) + c1 = Conj( + [ + Disj( + [ + BinConstraintT(constraint.rhs1, Dyn, op_eq), + BinConstraintT(constraint.rhs2, Dyn, op_eq), + ] + ), + BinConstraintT(constraint.res, Dyn, op_eq), + ] + ) [c2, c3, c4, c5], counter = gen_greatest_upper_bound(constraint, counter) @@ -322,9 +415,24 @@ def generate_d_gub(constraint, counter): """ Transform greatest upper bound for dimensions into equality constraints """ - c1 = Conj([BinConstraintD(constraint.rhs1, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs2, op_eq)]) - c2 = Conj([BinConstraintD(constraint.rhs2, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)]) - c3 = Conj([BinConstraintD(constraint.rhs2, constraint.rhs1, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)]) + c1 = Conj( + [ + BinConstraintD(constraint.rhs1, Dyn, op_eq), + BinConstraintD(constraint.res, constraint.rhs2, op_eq), + ] + ) + c2 = Conj( + [ + BinConstraintD(constraint.rhs2, Dyn, op_eq), + BinConstraintD(constraint.res, constraint.rhs1, op_eq), + ] + ) + c3 = Conj( + [ + BinConstraintD(constraint.rhs2, constraint.rhs1, op_eq), + BinConstraintD(constraint.res, constraint.rhs1, op_eq), + ] + ) return Disj([c1, c2, c3]), counter @@ -337,17 +445,26 @@ def generate_calc_conv(constraint, counter): c1 = BinConstraintT(constraint.conv_result, conv_result, op_eq) # the second dimension of the output is equal to the output channels - c2 = Conj([BinConstraintD(d[1], constraint.c_out, op_eq), BinConstraintD(d[1], Dyn, op_neq)]) + c2 = Conj( + [ + BinConstraintD(d[1], constraint.c_out, op_eq), + BinConstraintD(d[1], Dyn, op_neq), + ] + ) # the input corresponds to the output in the first dimension of the convolution c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq) c4, c5 = calc_last_two_dims(constraint, d) - leq_constraints = Conj([BinConstraintD(0, d[0], op_leq), - BinConstraintD(0, d[1], op_leq), - BinConstraintD(0, d[2], op_leq), - BinConstraintD(0, d[3], op_leq)]) + leq_constraints = Conj( + [ + BinConstraintD(0, d[0], op_leq), + BinConstraintD(0, d[1], op_leq), + BinConstraintD(0, d[2], op_leq), + BinConstraintD(0, d[3], op_leq), + ] + ) return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter @@ -368,10 +485,14 @@ def generate_calc_maxpool(constraint, counter): c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq) c4, c5 = calc_last_two_dims(constraint, d) - leq_constraints = Conj([BinConstraintD(0, d[0], op_leq), - BinConstraintD(0, d[1], op_leq), - BinConstraintD(0, d[2], op_leq), - BinConstraintD(0, d[3], op_leq)]) + leq_constraints = Conj( + [ + BinConstraintD(0, d[0], op_leq), + BinConstraintD(0, d[1], op_leq), + BinConstraintD(0, d[2], op_leq), + BinConstraintD(0, d[3], op_leq), + ] + ) return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter @@ -388,7 +509,7 @@ def generate_calc_product(constraint, counter): n = len(constraint.dims_to_flatten) # this will be evaluated right here - boundary_check = (0 <= start and start < end and end <= n) + boundary_check = 0 <= start and start < end and end <= n c_boundary = T() if boundary_check else F() @@ -410,16 +531,40 @@ def generate_calc_product(constraint, counter): if len(total_constraints) > 4: all_constraints.append(F()) else: - all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + 
mid_var + rhs), op_eq)] + p)) + all_constraints.append( + Conj( + [ + BinConstraintT( + flattened, TensorType(lhs + mid_var + rhs), op_eq + ) + ] + + p + ) + ) else: new_var, counter = gen_dvar(counter) - mid_eq_prod = Conj([BinConstraintD(new_var, Prod(mid), op_eq), BinConstraintD(new_var, Dyn, op_neq)]) + mid_eq_prod = Conj( + [ + BinConstraintD(new_var, Prod(mid), op_eq), + BinConstraintD(new_var, Dyn, op_neq), + ] + ) mid_var = [new_var] total_constraints = lhs + mid_var + rhs if len(total_constraints) > 4: all_constraints.append(F()) else: - all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq), mid_eq_prod] + p)) + all_constraints.append( + Conj( + [ + BinConstraintT( + flattened, TensorType(lhs + mid_var + rhs), op_eq + ), + mid_eq_prod, + ] + + p + ) + ) return Conj([Disj(all_constraints), c_boundary]), counter @@ -466,22 +611,40 @@ def generate_reshape(constraint, counter): if is_fully_static: # size 1 tensor - c3_tensor1 = Disj([d1_eq_dyn, - (Conj([d1_neq_dyn, - BinConstraintD(d1, Prod(target), op_eq)]))]) + c3_tensor1 = Disj( + [d1_eq_dyn, (Conj([d1_neq_dyn, BinConstraintD(d1, Prod(target), op_eq)]))] + ) all_tensor_1 = Conj([c2_tensor1, c3_tensor1]) # size 2 tensor - all_tensor_2 = Conj([c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)]) + all_tensor_2 = Conj( + [c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)] + ) # size 3 tensor - all_tensor_3 = Conj([c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)]) + all_tensor_3 = Conj( + [c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)] + ) # size 4 tensor - all_tensor_4 = Conj([c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)]) + all_tensor_4 = Conj( + [c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)] + ) - return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]), - nat_d1, nat_d2, nat_d3, nat_d4]), counter + return ( + Conj( + [ + Disj( + [c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4] + ), + nat_d1, + nat_d2, + nat_d3, + nat_d4, + ] + ), + counter, + ) # then there must be exactly one occurrence of dyn else: @@ -492,28 +655,57 @@ def generate_reshape(constraint, counter): new_target.append(n) # tensor 1 - c3_tensor1 = Disj([d1_eq_dyn, - (Conj([d1_neq_dyn, - is_dim_div_by_target(new_target, d1)]))]) + c3_tensor1 = Disj( + [d1_eq_dyn, (Conj([d1_neq_dyn, is_dim_div_by_target(new_target, d1)]))] + ) all_tensor_1 = Conj([c2_tensor1, c3_tensor1]) # tensor 2 c21 = Disj([d1_eq_dyn, d2_eq_dyn]) - c22 = Conj([d1_neq_dyn, d2_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2]))]) + c22 = Conj( + [d1_neq_dyn, d2_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2]))] + ) all_tensor_2 = Conj([c2_tensor2, Disj([c21, c22])]) # tensor 3 c31 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn]) - c32 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3]))]) + c32 = Conj( + [ + d1_neq_dyn, + d2_neq_dyn, + d3_neq_dyn, + is_dim_div_by_target(new_target, Prod([d1, d2, d3])), + ] + ) all_tensor_3 = Conj([c2_tensor3, Disj([c31, c32])]) # tensor 4 c41 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn, d4_eq_dyn]) - c42 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, d4_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4]))]) + c42 = Conj( + [ + d1_neq_dyn, + d2_neq_dyn, + d3_neq_dyn, + d4_neq_dyn, + is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4])), + ] + ) all_tensor_4 = Conj([c2_tensor4, Disj([c41, c42])]) - return Conj([Disj([c1_dyn, 
all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]), - nat_d1, nat_d2, nat_d3, nat_d4]), counter + return ( + Conj( + [ + Disj( + [c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4] + ), + nat_d1, + nat_d2, + nat_d3, + nat_d4, + ] + ), + counter, + ) @register_transformation_rule(ApplyBroadcasting) @@ -537,40 +729,58 @@ def generate_broadcasting(constraint, counter): # tensor possibility # generate dimensions to create tensors of size 1 - final_tensor_1_constraint, _, _, nat_dims_1, counter = \ - gen_broadcasting_constraints(e1, e2, e11, e12, 1, counter) + final_tensor_1_constraint, _, _, nat_dims_1, counter = gen_broadcasting_constraints( + e1, e2, e11, e12, 1, counter + ) # generate dimensions to create tensors of size 2 - final_tensor_2_constraint_no_padding, final_tensor_2_constraint_padding_arg1, \ - final_tensor_2_constraint_padding_arg2, nat_dims_2, counter = \ - gen_broadcasting_constraints(e1, e2, e11, e12, 2, counter) - - # generate dimensions to create tensors of size 3 - final_tensor_3_constraint_no_padding, final_tensor_3_constraint_padding_arg1, \ - final_tensor_3_constraint_padding_arg2, nat_dims_3, counter = \ - gen_broadcasting_constraints(e1, e2, e11, e12, 3, counter) - - # generate dimensions to create tensors of size 4 - final_tensor_4_constraint_no_padding, final_tensor_4_constraint_padding_arg1, \ - final_tensor_4_constraint_padding_arg2, nat_dims_4, counter = \ - gen_broadcasting_constraints(e1, e2, e11, e12, 4, counter) - - final_result = Disj([ - e1_dyn_constraint, - e2_dyn_constraint, - final_tensor_1_constraint, + ( final_tensor_2_constraint_no_padding, final_tensor_2_constraint_padding_arg1, final_tensor_2_constraint_padding_arg2, + nat_dims_2, + counter, + ) = gen_broadcasting_constraints(e1, e2, e11, e12, 2, counter) + + # generate dimensions to create tensors of size 3 + ( final_tensor_3_constraint_no_padding, final_tensor_3_constraint_padding_arg1, final_tensor_3_constraint_padding_arg2, + nat_dims_3, + counter, + ) = gen_broadcasting_constraints(e1, e2, e11, e12, 3, counter) + + # generate dimensions to create tensors of size 4 + ( final_tensor_4_constraint_no_padding, final_tensor_4_constraint_padding_arg1, - final_tensor_4_constraint_padding_arg2 - ]) + final_tensor_4_constraint_padding_arg2, + nat_dims_4, + counter, + ) = gen_broadcasting_constraints(e1, e2, e11, e12, 4, counter) - return Conj([final_result, *nat_dims_1, *nat_dims_2, *nat_dims_3, *nat_dims_4]), counter + final_result = Disj( + [ + e1_dyn_constraint, + e2_dyn_constraint, + final_tensor_1_constraint, + final_tensor_2_constraint_no_padding, + final_tensor_2_constraint_padding_arg1, + final_tensor_2_constraint_padding_arg2, + final_tensor_3_constraint_no_padding, + final_tensor_3_constraint_padding_arg1, + final_tensor_3_constraint_padding_arg2, + final_tensor_4_constraint_no_padding, + final_tensor_4_constraint_padding_arg1, + final_tensor_4_constraint_padding_arg2, + ] + ) + + return ( + Conj([final_result, *nat_dims_1, *nat_dims_2, *nat_dims_3, *nat_dims_4]), + counter, + ) def transform_constraint(constraint: Constraint, counter: int): @@ -591,8 +801,6 @@ def transform_constraint(constraint: Constraint, counter: int): return constraint, counter - - def calc_last_two_dims(constraint, d: List[DVar]): """ Generates constraints for the last two dimensions of a convolution or a maxpool output @@ -612,29 +820,49 @@ def calc_last_two_dims(constraint, d: List[DVar]): b3_dyn = Conj([BinConstraintD(d[2], Dyn, op_eq), BinConstraintD(b3, Dyn, op_eq)]) b4_dyn = 
Conj([BinConstraintD(d[3], Dyn, op_eq), BinConstraintD(b4, Dyn, op_eq)]) - d3_not_dyn = Conj([BinConstraintD(d[2], Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq)]) - d4_not_dyn = Conj([BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)]) + d3_not_dyn = Conj( + [BinConstraintD(d[2], Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq)] + ) + d4_not_dyn = Conj( + [BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)] + ) # transform parameters into tuples in case they are not already - padding = (constraint.padding, constraint.padding) \ if isinstance(constraint.padding, int) else constraint.padding - kernel = (constraint.kernel, constraint.kernel) \ if isinstance(constraint.kernel, int) else constraint.kernel - stride = (constraint.stride, constraint.stride) \ if isinstance(constraint.stride, int) else constraint.stride - dilation = (constraint.dilation, constraint.dilation) \ if isinstance(constraint.dilation, int) else constraint.dilation + padding = ( + (constraint.padding, constraint.padding) + if isinstance(constraint.padding, int) + else constraint.padding + ) + kernel = ( + (constraint.kernel, constraint.kernel) + if isinstance(constraint.kernel, int) + else constraint.kernel + ) + stride = ( + (constraint.stride, constraint.stride) + if isinstance(constraint.stride, int) + else constraint.stride + ) + dilation = ( + (constraint.dilation, constraint.dilation) + if isinstance(constraint.dilation, int) + else constraint.dilation + ) f1 = BinConstraintD(b3, BinConstraintD(2, padding[0], op_mul), op_add) f2 = BinConstraintD(dilation[0], BinConstraintD(kernel[0], 1, op_sub), op_mul) - f3 = BinConstraintD(BinConstraintD(BinConstraintD(f1, f2, op_sub), 1, op_sub), stride[0], op_div) + f3 = BinConstraintD( + BinConstraintD(BinConstraintD(f1, f2, op_sub), 1, op_sub), stride[0], op_div + ) f4 = BinConstraintD(f3, 1, op_add) c4 = Disj([b3_dyn, Conj([d3_not_dyn, BinConstraintD(d[2], f4, op_eq)])]) f11 = BinConstraintD(b4, BinConstraintD(2, padding[1], op_mul), op_add) f22 = BinConstraintD(dilation[1], BinConstraintD(kernel[1], 1, op_sub), op_mul) - f33 = BinConstraintD(BinConstraintD(BinConstraintD(f11, f22, op_sub), 1, op_sub), stride[1], op_div) + f33 = BinConstraintD( + BinConstraintD(BinConstraintD(f11, f22, op_sub), 1, op_sub), stride[1], op_div + ) f44 = BinConstraintD(f33, 1, op_add) c5 = Disj([b4_dyn, Conj([d4_not_dyn, BinConstraintD(d[3], f44, op_eq)])]) @@ -652,8 +880,12 @@ def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]): """ one possibility about the values of the dimension variables """ # generate all possibilities of being equal or not equal to dyn for my_list - eq_possibilities = [BinConstraintD(my_list[i], Dyn, op_eq) for i in range(len(my_list))] - neq_possibilities = [BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list))] + eq_possibilities = [ + BinConstraintD(my_list[i], Dyn, op_eq) for i in range(len(my_list)) + ] + neq_possibilities = [ + BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list)) + ] d_possibilities = [] for i in zip(eq_possibilities, neq_possibilities): @@ -721,10 +953,13 @@ def gen_all_reshape_possibilities(list_of_dims, target): all_constraints.append(Conj(p)) elif len(to_multiply) < len(list_of_dims): - all_constraints.append(Conj(p + [is_target_div_by_dim(target, Prod(to_multiply))])) + all_constraints.append( + Conj(p + [is_target_div_by_dim(target, Prod(to_multiply))]) + ) else: - all_constraints.append(Conj(p + [BinConstraintD(Prod(list_of_dims), - Prod(target), op_eq)])) + 
all_constraints.append( + Conj(p + [BinConstraintD(Prod(list_of_dims), Prod(target), op_eq)]) + ) return Disj(all_constraints) @@ -746,27 +981,36 @@ def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False if tensor_input1[index] is None: assert padding - if not padding: # then the inputs are the same length so they all have dimensions at "index" - return Conj([BinConstraintD(tensor_input1[index], 1, op_eq), - BinConstraintD(res1[index], res2[index], op_eq), - BinConstraintD(res2[index], tensor_input2[index], op_eq)]) + return Conj( + [ + BinConstraintD(tensor_input1[index], 1, op_eq), + BinConstraintD(res1[index], res2[index], op_eq), + BinConstraintD(res2[index], tensor_input2[index], op_eq), + ] + ) else: # we don't set the input dimension to 1, since it doesn't exist. - return Conj([BinConstraintD(res1[index], res2[index], op_eq), - BinConstraintD(res2[index], tensor_input2[index], op_eq)]) + return Conj( + [ + BinConstraintD(res1[index], res2[index], op_eq), + BinConstraintD(res2[index], tensor_input2[index], op_eq), + ] + ) -def apply_padding(e1_var: TVar, - e11: BinConstraintT, - e2: BinConstraintT, - e12: BinConstraintT, - d2: List[DVar], - d11: List[DVar], - d12: List[DVar], - counter: int): +def apply_padding( + e1_var: TVar, + e11: BinConstraintT, + e2: BinConstraintT, + e12: BinConstraintT, + d2: List[DVar], + d11: List[DVar], + d12: List[DVar], + counter: int, +): """ We are considering the possibility where one input has less dimensions than another input, so we apply padding to the broadcasted results @@ -789,7 +1033,6 @@ def apply_padding(e1_var: TVar, # pad the shorter input with None so we can pass it to the broadcasting helper function for i in range(1, len(d2)): - d1, counter = gen_tensor_dims(i, counter) nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12) @@ -804,30 +1047,37 @@ def apply_padding(e1_var: TVar, # for every padding size, we also consider broadcasting for j in range(len(d2) - i): - broadcast_padding.append(broadcast_dim(simulate_padding, d2, d11, d12, j, True)) + broadcast_padding.append( + broadcast_dim(simulate_padding, d2, d11, d12, j, True) + ) # we consider the possibilities for broadcasting for every dimension. 
Since we already # padded d1, we do not consider it while broadcasting - all_broadcasting_possibilities = generate_all_broadcasting_possibilities_no_padding(d1, - d2[(len(d2) - i):], - d11[(len(d2) - i):], - d12[(len(d2) - i):]) + all_broadcasting_possibilities = ( + generate_all_broadcasting_possibilities_no_padding( + d1, d2[(len(d2) - i) :], d11[(len(d2) - i) :], d12[(len(d2) - i) :] + ) + ) # combine all constraints into a conjunction - c = Conj([e1, e11, e2, e12, - *broadcast_padding, - all_broadcasting_possibilities, - *nat_constraints - ]) + c = Conj( + [ + e1, + e11, + e2, + e12, + *broadcast_padding, + all_broadcasting_possibilities, + *nat_constraints, + ] + ) res.append(c) return Disj(res), counter -def no_broadcast_dim_with_index(d1: List[DVar], - d2: List[DVar], - d3: List[DVar], - d4: List[DVar], - i: int): +def no_broadcast_dim_with_index( + d1: List[DVar], d2: List[DVar], d3: List[DVar], d4: List[DVar], i: int +): """ Args: d1: input 1 @@ -838,17 +1088,28 @@ def no_broadcast_dim_with_index(d1: List[DVar], Returns: Constraints for when no broadcasting occurs """ - return Conj([ - Disj([ - Conj([BinConstraintD(d1[i], 1, op_eq), - BinConstraintD(d2[i], 1, op_eq)]), - - Conj([BinConstraintD(d1[i], 1, op_neq), - BinConstraintD(d2[i], 1, op_neq)])]), - - BinConstraintD(d1[i], d3[i], op_eq), - BinConstraintD(d2[i], d4[i], op_eq)]) - + return Conj( + [ + Disj( + [ + Conj( + [ + BinConstraintD(d1[i], 1, op_eq), + BinConstraintD(d2[i], 1, op_eq), + ] + ), + Conj( + [ + BinConstraintD(d1[i], 1, op_neq), + BinConstraintD(d2[i], 1, op_neq), + ] + ), + ] + ), + BinConstraintD(d1[i], d3[i], op_eq), + BinConstraintD(d2[i], d4[i], op_eq), + ] + ) def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int): @@ -871,14 +1132,16 @@ def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int): return res, counter -def create_equality_constraints_for_broadcasting(e1: TVar, - e2: TVar, - e11: TVar, - e12: TVar, - d1: List[DVar], - d2: List[DVar], - d11: List[DVar], - d12: List[DVar]): +def create_equality_constraints_for_broadcasting( + e1: TVar, + e2: TVar, + e11: TVar, + e12: TVar, + d1: List[DVar], + d2: List[DVar], + d11: List[DVar], + d12: List[DVar], +): """ Create equality constraints for when no broadcasting occurs Args: @@ -920,10 +1183,17 @@ def gen_consistency_constraints(constraint: Constraint, counter: int): nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2) - c_tensor_i = Conj([BinConstraintT(constraint.lhs, TensorType(new_dims_rhs_1), op_eq), - BinConstraintT(constraint.rhs, TensorType(new_dims_rhs_2), op_eq)] + - [BinConstraintD(d1, d2, op_consistency) for - d1, d2 in zip(new_dims_rhs_1, new_dims_rhs_2)] + nat_constraints) + c_tensor_i = Conj( + [ + BinConstraintT(constraint.lhs, TensorType(new_dims_rhs_1), op_eq), + BinConstraintT(constraint.rhs, TensorType(new_dims_rhs_2), op_eq), + ] + + [ + BinConstraintD(d1, d2, op_consistency) + for d1, d2 in zip(new_dims_rhs_1, new_dims_rhs_2) + ] + + nat_constraints + ) all_constraints.append(c_tensor_i) @@ -953,22 +1223,29 @@ def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int): dims3, counter = gen_tensor_dims(i, counter) c3tensor = TensorType(dims3) - c += [BinConstraintT(constraint.rhs1, c1tensor, op_eq), - BinConstraintT(constraint.rhs2, c2tensor, op_eq), - BinConstraintT(constraint.res, c3tensor, op_eq)] + \ - gen_nat_constraints(dims1 + dims2 + dims3) + c += [ + BinConstraintT(constraint.rhs1, c1tensor, op_eq), + BinConstraintT(constraint.rhs2, c2tensor, op_eq), + 
BinConstraintT(constraint.res, c3tensor, op_eq), + ] + gen_nat_constraints(dims1 + dims2 + dims3) - assert len(c3tensor.__args__) == len(c1tensor.__args__) == len(c2tensor.__args__) + assert ( + len(c3tensor.__args__) == len(c1tensor.__args__) == len(c2tensor.__args__) + ) for i in range(len(c3tensor.__args__)): - c.append(DGreatestUpperBound(c3tensor.__args__[i], - c1tensor.__args__[i], - c2tensor.__args__[i])) + c.append( + DGreatestUpperBound( + c3tensor.__args__[i], c1tensor.__args__[i], c2tensor.__args__[i] + ) + ) all_constraints.append(Conj(c)) return all_constraints, counter -def generate_all_broadcasting_possibilities_no_padding(d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar]): +def generate_all_broadcasting_possibilities_no_padding( + d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar] +): """ Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension. We look at all combinations for all dimensions in d1 and d2 @@ -996,7 +1273,9 @@ def generate_all_broadcasting_possibilities_no_padding(d1: List[DVar], d2: List[ return Conj(res2) -def gen_broadcasting_constraints(e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int): +def gen_broadcasting_constraints( + e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int +): """ Simulates broadcasting on e1 and e2 and returns the results respectively in e11 and e12. Because of gradual types, @@ -1019,22 +1298,33 @@ def gen_broadcasting_constraints(e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: in [d1, d2, d3, d4] = dims nat_dims_i = gen_nat_constraints(list(itertools.chain.from_iterable(dims))) - initialize_tensors_constraints = create_equality_constraints_for_broadcasting(e1, e2, e11, e12, - d1, d2, d3, d4) + initialize_tensors_constraints = create_equality_constraints_for_broadcasting( + e1, e2, e11, e12, d1, d2, d3, d4 + ) [e1_tensor, e11_tensor, e2_tensor, e12_tensor] = initialize_tensors_constraints # without padding, broadcast all possibilities for tensors of size i - final_tensor_constraint_no_padding = Conj([*initialize_tensors_constraints, - generate_all_broadcasting_possibilities_no_padding(d1, d2, d3, d4)]) + final_tensor_constraint_no_padding = Conj( + [ + *initialize_tensors_constraints, + generate_all_broadcasting_possibilities_no_padding(d1, d2, d3, d4), + ] + ) # with padding, broadcast all possibilities for tensors of size i - final_tensor_constraint_padding_arg1, counter = \ - apply_padding(e1, e11_tensor, e2_tensor, e12_tensor, d2, d3, d4, counter) + final_tensor_constraint_padding_arg1, counter = apply_padding( + e1, e11_tensor, e2_tensor, e12_tensor, d2, d3, d4, counter + ) - final_tensor_constraint_padding_arg2, counter = \ - apply_padding(e2, e12_tensor, e1_tensor, e11_tensor, d1, d4, d3, counter) + final_tensor_constraint_padding_arg2, counter = apply_padding( + e2, e12_tensor, e1_tensor, e11_tensor, d1, d4, d3, counter + ) - return final_tensor_constraint_no_padding, \ - final_tensor_constraint_padding_arg1, \ - final_tensor_constraint_padding_arg2, nat_dims_i, counter + return ( + final_tensor_constraint_no_padding, + final_tensor_constraint_padding_arg1, + final_tensor_constraint_padding_arg2, + nat_dims_i, + counter, + ) diff --git a/torch/fx/experimental/migrate_gradual_types/operation.py b/torch/fx/experimental/migrate_gradual_types/operation.py index 432cd570bebb..267100c8545c 100644 --- a/torch/fx/experimental/migrate_gradual_types/operation.py +++ b/torch/fx/experimental/migrate_gradual_types/operation.py @@ -1,14 +1,14 @@ -op_add 
= '+' -op_sub = '-' -op_mul = '*' -op_div = '/' -op_eq = '=' -op_neq = '!=' -op_imp = '=>' -op_matching = '\u22b3' # (contains) -op_consistency = '~' -op_precision = '\u2291' # (square image of or equal to) -op_leq = '\u2264' # less-than or equal to -op_lt = '<' -op_gt = '>' -op_mod = '%' +op_add = "+" +op_sub = "-" +op_mul = "*" +op_div = "/" +op_eq = "=" +op_neq = "!=" +op_imp = "=>" +op_matching = "\u22b3" # (contains) +op_consistency = "~" +op_precision = "\u2291" # (square image of or equal to) +op_leq = "\u2264" # less-than or equal to +op_lt = "<" +op_gt = ">" +op_mod = "%" diff --git a/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py b/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py index c8cf70006cd8..d1f9f33965e0 100644 --- a/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py +++ b/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py @@ -1,16 +1,49 @@ # mypy: allow-untyped-defs -from torch.fx.experimental.migrate_gradual_types.constraint import Conj, Disj, T, F, BinConstraintT, BVar, is_bool_expr -from torch.fx.experimental.migrate_gradual_types.constraint import BinConstraintD, TVar, DVar -from torch.fx.experimental.migrate_gradual_types.constraint import Prod, is_algebraic_expression, is_dim -from torch.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator -from torch.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint -from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_eq, op_neq, op_gt, op_lt -from torch.fx.experimental.migrate_gradual_types.operation import op_leq, op_sub, op_div, op_mul, op_mod -from torch.fx.tensor_type import TensorType, Dyn +from torch.fx.experimental.migrate_gradual_types.constraint import ( + BinConstraintD, + BinConstraintT, + BVar, + Conj, + Disj, + DVar, + F, + is_algebraic_expression, + is_bool_expr, + is_dim, + Prod, + T, + TVar, +) +from torch.fx.experimental.migrate_gradual_types.constraint_generator import ( + ConstraintGenerator, +) +from torch.fx.experimental.migrate_gradual_types.constraint_transformation import ( + transform_constraint, +) +from torch.fx.experimental.migrate_gradual_types.operation import ( + op_add, + op_div, + op_eq, + op_gt, + op_leq, + op_lt, + op_mod, + op_mul, + op_neq, + op_sub, +) +from torch.fx.tensor_type import Dyn, TensorType + try: import z3 # type: ignore[import] - from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, z3_dyn, D + + from torch.fx.experimental.migrate_gradual_types.z3_types import ( + D, + tensor_type, + z3_dyn, + ) + HAS_Z3 = True def transform_to_z3(constraint, counter, dimension_dict): @@ -41,35 +74,48 @@ try: return (lhs == rhs), counter else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") elif isinstance(constraint, BinConstraintD): if constraint.op == op_eq: - if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs): - transformed_rhs, counter = transform_to_z3(constraint.rhs, counter, dimension_dict) + transformed_rhs, counter = transform_to_z3( + constraint.rhs, counter, dimension_dict + ) transformed_lhs = z3.Bool(constraint.lhs.c) return transformed_lhs == transformed_rhs, counter elif is_dim(constraint.lhs) and is_dim(constraint.rhs): # with dimension transformations we consider the encoding - lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_dimension(constraint.rhs, counter, 
dimension_dict) + lhs, counter = transform_dimension( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_dimension( + constraint.rhs, counter, dimension_dict + ) return lhs == rhs, counter else: # then we have an algebraic expression which means that we disregard the # first element of the encoding - lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) + lhs, counter = transform_algebraic_expression( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_algebraic_expression( + constraint.rhs, counter, dimension_dict + ) return lhs == rhs, counter # The assumption here is that the LHS and RHS must be dimensions elif constraint.op == op_neq: assert is_dim(constraint.lhs) assert is_dim(constraint.rhs) - lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict) + lhs, counter = transform_dimension( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_dimension( + constraint.rhs, counter, dimension_dict + ) if constraint.rhs == Dyn or constraint.lhs == Dyn: if constraint.rhs == Dyn: return lhs.arg(0) == 1, counter @@ -79,44 +125,83 @@ try: # if one of the instances is a number elif isinstance(constraint.lhs, int) or isinstance(constraint.rhs, int): if isinstance(constraint.lhs, int): - return z3.Or([rhs.arg(0) == 0, z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter + return ( + z3.Or( + [ + rhs.arg(0) == 0, + z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)]), + ] + ), + counter, + ) elif isinstance(constraint.rhs, int): - return z3.Or([lhs.arg(0) == 0, z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter + return ( + z3.Or( + [ + lhs.arg(0) == 0, + z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)]), + ] + ), + counter, + ) else: - return z3.Or([z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]), - z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]), - z3.And([lhs.arg(0) != 0, rhs.arg(0) != 0, lhs.arg(1) != rhs.arg(1)])]), counter - + return ( + z3.Or( + [ + z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]), + z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]), + z3.And( + [ + lhs.arg(0) != 0, + rhs.arg(0) != 0, + lhs.arg(1) != rhs.arg(1), + ] + ), + ] + ), + counter, + ) elif constraint.op == op_leq: # if the dimensions are not dyn, this will come into effect # there would have been another constraint specifying if a given dimension # is dyn or not assert is_dim(constraint.lhs) and is_dim(constraint.rhs) - lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) + lhs, counter = transform_algebraic_expression( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_algebraic_expression( + constraint.rhs, counter, dimension_dict + ) return lhs <= rhs, counter elif constraint.op == op_gt: assert is_dim(constraint.lhs) and is_dim(constraint.rhs) - lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) + lhs, counter = transform_algebraic_expression( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_algebraic_expression( + constraint.rhs, counter, dimension_dict + ) return lhs > rhs, counter elif constraint.op == op_lt: assert is_dim(constraint.lhs) and 
is_dim(constraint.rhs) - lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) + lhs, counter = transform_algebraic_expression( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_algebraic_expression( + constraint.rhs, counter, dimension_dict + ) return lhs < rhs, counter else: - raise NotImplementedError('operation not yet implemented') + raise NotImplementedError("operation not yet implemented") else: - raise NotImplementedError('Operation not yet implemented') - + raise NotImplementedError("Operation not yet implemented") def transform_var(tensor, counter, dimension_dict): """ @@ -166,13 +251,15 @@ try: return D(1, dimension), counter elif isinstance(dimension, DVar): if dimension.c in dimension_dict: - return D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)), counter + return ( + D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)), + counter, + ) else: counter += 1 dimension_dict[dimension.c] = counter return D(z3.Int(counter), z3.Int(dimension.c)), counter - def transform_algebraic_expression(expr, counter, dimension_dict): """ Transforms an algebraic expression to z3 format @@ -190,7 +277,6 @@ try: return transformed.arg(1), counter elif isinstance(expr, Prod): - dims = [] for dim in expr.products: assert is_dim(dim) @@ -199,9 +285,12 @@ try: return z3.Product(dims), counter elif is_algebraic_expression(expr): - - lhs, counter = transform_algebraic_expression(expr.lhs, counter, dimension_dict) - rhs, counter = transform_algebraic_expression(expr.rhs, counter, dimension_dict) + lhs, counter = transform_algebraic_expression( + expr.lhs, counter, dimension_dict + ) + rhs, counter = transform_algebraic_expression( + expr.rhs, counter, dimension_dict + ) if expr.op == op_sub: c = lhs - rhs @@ -219,14 +308,13 @@ try: c = lhs % rhs else: - raise NotImplementedError('operation not yet implemented') + raise NotImplementedError("operation not yet implemented") return c, counter else: raise RuntimeError - def transform_all_constraints(traced, counter=0): """ Given a trace, generates constraints and transforms them to z3 format @@ -291,7 +379,6 @@ try: # transform precision, matching, consistency till obtaining a fixed point new_constraints, counter = iterate_till_fixed_point(new_constraints, counter) - # since the function returns a list of one element, we get the first element # we are only interested in the RHS in this case because the LHS just stores # the result @@ -304,19 +391,27 @@ try: condition_constraint_rhs = condition_constraint.rhs # transform the condition constraint - condition_constraint_rhs, counter = iterate_till_fixed_point(condition_constraint_rhs, counter) + condition_constraint_rhs, counter = iterate_till_fixed_point( + condition_constraint_rhs, counter + ) transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict) - transformed_condition_constraint, counter = transform_to_z3(condition_constraint_rhs, counter, dimension_dict) + transformed_condition_constraint, counter = transform_to_z3( + condition_constraint_rhs, counter, dimension_dict + ) - negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint) + negation_transformed_condition_constraint = z3.Not( + transformed_condition_constraint + ) - return z3.And([transformed, transformed_condition_constraint]), \ - z3.And([transformed, negation_transformed_condition_constraint]) + return z3.And([transformed, 
transformed_condition_constraint]), z3.And( + [transformed, negation_transformed_condition_constraint] + ) - - def evaluate_conditional_with_constraints(tracer_root, graph, node, counter=0, user_constraints=None): + def evaluate_conditional_with_constraints( + tracer_root, graph, node, counter=0, user_constraints=None + ): """ Given an IR and a node representing a conditional, evaluate the conditional and its negation @@ -329,8 +424,10 @@ try: """ - transformed_positive, transformed_negative = \ - transform_all_constraints_trace_time(tracer_root, graph, node, counter) + ( + transformed_positive, + transformed_negative, + ) = transform_all_constraints_trace_time(tracer_root, graph, node, counter) s = z3.Solver() s.add(transformed_positive) diff --git a/torch/fx/experimental/migrate_gradual_types/util.py b/torch/fx/experimental/migrate_gradual_types/util.py index 99f94609f265..bd40d2a463f5 100644 --- a/torch/fx/experimental/migrate_gradual_types/util.py +++ b/torch/fx/experimental/migrate_gradual_types/util.py @@ -1,6 +1,10 @@ # mypy: allow-untyped-defs -from torch.fx.experimental.migrate_gradual_types.constraint import TVar, DVar, BinConstraintD, \ - BVar +from torch.fx.experimental.migrate_gradual_types.constraint import ( + BinConstraintD, + BVar, + DVar, + TVar, +) from torch.fx.experimental.migrate_gradual_types.operation import op_leq @@ -23,6 +27,7 @@ def gen_dvar(curr): curr += 1 return DVar(curr), curr + def gen_bvar(curr): """ Generate a boolean variable @@ -32,6 +37,7 @@ def gen_bvar(curr): curr += 1 return BVar(curr), curr + def gen_tensor_dims(n, curr): """ Generate a list of tensor dimensions diff --git a/torch/fx/experimental/migrate_gradual_types/z3_types.py b/torch/fx/experimental/migrate_gradual_types/z3_types.py index 897a79d56975..939f4865ab7d 100644 --- a/torch/fx/experimental/migrate_gradual_types/z3_types.py +++ b/torch/fx/experimental/migrate_gradual_types/z3_types.py @@ -1,22 +1,23 @@ try: import z3 # type: ignore[import] + HAS_Z3 = True # dynamic type - dyn = z3.DeclareSort('Dyn') - dyn_type = z3.Const('dyn', dyn) + dyn = z3.DeclareSort("Dyn") + dyn_type = z3.Const("dyn", dyn) # dimension - dim = z3.Datatype('dim') - dim.declare('dim', ('0', z3.IntSort()), ('1', z3.IntSort())) + dim = z3.Datatype("dim") + dim.declare("dim", ("0", z3.IntSort()), ("1", z3.IntSort())) dim = dim.create() # tensors - tensor_type = z3.Datatype('TensorType') - tensor_type.declare('Dyn', ('dyn', dyn)) - tensor_type.declare('tensor1', ('0', dim)) - tensor_type.declare('tensor2', ('0', dim), ('1', dim)) - tensor_type.declare('tensor3', ('0', dim), ('1', dim), ('2', dim)) - tensor_type.declare('tensor4', ('0', dim), ('1', dim), ('2', dim), ('3', dim)) + tensor_type = z3.Datatype("TensorType") + tensor_type.declare("Dyn", ("dyn", dyn)) + tensor_type.declare("tensor1", ("0", dim)) + tensor_type.declare("tensor2", ("0", dim), ("1", dim)) + tensor_type.declare("tensor3", ("0", dim), ("1", dim), ("2", dim)) + tensor_type.declare("tensor4", ("0", dim), ("1", dim), ("2", dim), ("3", dim)) tensor_type = tensor_type.create() # create dimension diff --git a/torch/fx/experimental/normalize.py b/torch/fx/experimental/normalize.py index 30b076a72bee..cc6944d5a5af 100644 --- a/torch/fx/experimental/normalize.py +++ b/torch/fx/experimental/normalize.py @@ -1,16 +1,16 @@ # mypy: allow-untyped-defs import operator -from typing import Any, Callable, Dict, Tuple, Optional +from typing import Any, Callable, Dict, Optional, Tuple import torch import torch.fx import torch.fx as fx -from torch.fx import Transformer, 
Proxy -from torch.fx.node import Argument, Target, Node, map_aggregate +from torch.fx import Proxy, Transformer +from torch.fx.node import Argument, map_aggregate, Node, Target from torch.fx.operator_schemas import ( - normalize_module, - normalize_function, create_type_hint, + normalize_function, + normalize_module, ) from .schema_type_annotation import AnnotateTypesWithSchema diff --git a/torch/fx/experimental/optimization.py b/torch/fx/experimental/optimization.py index 44b031524a41..2fe600c247b8 100644 --- a/torch/fx/experimental/optimization.py +++ b/torch/fx/experimental/optimization.py @@ -1,37 +1,42 @@ # mypy: allow-untyped-defs -import torch.fx as fx -from torch.fx.node import Argument, Target -from torch.nn.utils.fusion import fuse_conv_bn_eval -from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.fx.passes.shape_prop import ShapeProp import copy -from collections import defaultdict -import torch.utils.mkldnn as th_mkldnn +import logging import operator import time -import logging +from collections import defaultdict from enum import Enum +from typing import Any, cast, Dict, Iterable, List, Optional, Tuple, Type -def _parent_name(target : str) -> Tuple[str, str]: +import torch +import torch.fx as fx +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.mkldnn as th_mkldnn +from torch.fx.node import Argument, Target +from torch.fx.passes.shape_prop import ShapeProp +from torch.nn.utils.fusion import fuse_conv_bn_eval + + +def _parent_name(target: str) -> Tuple[str, str]: """ Splits a qualname into parent path and last atom. For example, `foo.bar.baz` -> (`foo.bar`, `baz`) """ - *parent, name = target.rsplit('.', 1) - return parent[0] if parent else '', name + *parent, name = target.rsplit(".", 1) + return parent[0] if parent else "", name + # Works for length 2 patterns with 2 modules -def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]): +def matches_module_pattern( + pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any] +): if len(node.args) == 0: return False nodes: Tuple[Any, fx.Node] = (node.args[0], node) for expected_type, current_node in zip(pattern, nodes): if not isinstance(current_node, fx.Node): return False - if current_node.op != 'call_module': + if current_node.op != "call_module": return False if not isinstance(current_node.target, str): return False @@ -42,20 +47,25 @@ def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict return True -def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module): +def replace_node_module( + node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module +): assert isinstance(node.target, str) parent_name, name = _parent_name(node.target) modules[node.target] = new_module setattr(modules[parent_name], name, new_module) + def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module: """ Fuses convolution/BN layers for inference purposes. Will deepcopy your model by default, but can modify the model inplace as well. 
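    A minimal usage sketch (hedged; the toy Sequential model is illustrative,
    and eval mode is required since fusion folds BN statistics into the conv)::

        import torch
        import torch.nn as nn
        from torch.fx.experimental.optimization import fuse

        model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()
        fused = fuse(model)  # GraphModule with each BN folded into the preceding conv

        x = torch.randn(1, 3, 16, 16)
        torch.testing.assert_close(model(x), fused(x))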
""" - patterns = [(nn.Conv1d, nn.BatchNorm1d), - (nn.Conv2d, nn.BatchNorm2d), - (nn.Conv3d, nn.BatchNorm3d)] + patterns = [ + (nn.Conv1d, nn.BatchNorm1d), + (nn.Conv2d, nn.BatchNorm2d), + (nn.Conv3d, nn.BatchNorm3d), + ] if not inplace: model = copy.deepcopy(model) if not no_trace or not isinstance(model, torch.fx.GraphModule): @@ -80,6 +90,7 @@ def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Modu new_graph.erase_node(node) return fx.GraphModule(fx_model, new_graph) + def remove_dropout(model: nn.Module) -> nn.Module: """ Removes all dropout layers from the module. @@ -87,15 +98,24 @@ def remove_dropout(model: nn.Module) -> nn.Module: fx_model = fx.symbolic_trace(model) class DropoutRemover(torch.fx.Transformer): - def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + def call_module( + self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: if isinstance(self.submodules[target], nn.Dropout): assert len(args) == 1 return args[0] else: return super().call_module(target, args, kwargs) + return DropoutRemover(fx_model).transform() -def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[fx.Node], outputs: List[fx.Node]): + +def extract_subgraph( + orig_module: nn.Module, + nodes: List[fx.Node], + inputs: List[fx.Node], + outputs: List[fx.Node], +): """ Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph. """ @@ -111,10 +131,21 @@ def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[ new_graph.lint() return fx.GraphModule(orig_module, new_graph) + mkldnn_supported = [ - nn.Conv2d, nn.Linear, nn.BatchNorm2d, nn.ReLU, nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d, - torch.relu, torch.transpose, torch.sigmoid, - F.relu, F.avg_pool2d, F.adaptive_avg_pool2d + nn.Conv2d, + nn.Linear, + nn.BatchNorm2d, + nn.ReLU, + nn.MaxPool2d, + nn.AvgPool2d, + nn.AdaptiveAvgPool2d, + torch.relu, + torch.transpose, + torch.sigmoid, + F.relu, + F.avg_pool2d, + F.adaptive_avg_pool2d, ] # These are operators that may not be convertible into MKLDNN ops (e.g. the # args are scalar values). Thus, we only include them in the subgraph if their @@ -124,7 +155,7 @@ mkldnn_supported_unknown = [operator.add, operator.mul] mkldnn_map = { nn.Conv2d: th_mkldnn.MkldnnConv2d, nn.Linear: th_mkldnn.MkldnnLinear, - nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a) + nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a), } @@ -136,7 +167,7 @@ def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]): """ old_modules: Dict[nn.Module, nn.Module] = {} for node in nodes: - if node.op == 'call_module': + if node.op == "call_module": assert isinstance(node.target, str) cur_module = modules[node.target] if type(cur_module) in mkldnn_map: @@ -146,18 +177,24 @@ def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]): replace_node_module(node, modules, new_module) return old_modules -def reset_modules(nodes: List[fx.Node], modules: Dict[str, nn.Module], old_modules: Dict[nn.Module, nn.Module]): + +def reset_modules( + nodes: List[fx.Node], + modules: Dict[str, nn.Module], + old_modules: Dict[nn.Module, nn.Module], +): """ Maps each module that's been changed with `modules_to_mkldnn` back to its original. 
""" for node in nodes: - if node.op == 'call_module': - assert (isinstance(node.target, str)) + if node.op == "call_module": + assert isinstance(node.target, str) cur_module = modules[node.target] if cur_module in old_modules: replace_node_module(node, modules, old_modules[cur_module]) + class MklSubgraph: def __init__(self, fx_graph: fx.Graph): self.fx_graph = fx_graph @@ -165,6 +202,7 @@ class MklSubgraph: self.start_nodes: List[fx.Node] = [] self.end_nodes: List[fx.Node] = [] + def gen_mkl_autotuner(example_inputs, iters=10, warmup=1): """ This generates a heuristic that can be passed into `optimize_for_inference` that @@ -196,13 +234,21 @@ def gen_mkl_autotuner(example_inputs, iters=10, warmup=1): f() return time.time() - begin - mkl_time = benchmark(lambda: [i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])]) + mkl_time = benchmark( + lambda: [ + i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs]) + ] + ) - reset_modules(submodule.graph.nodes, dict(submodule.named_modules()), old_modules) + reset_modules( + submodule.graph.nodes, dict(submodule.named_modules()), old_modules + ) no_mkl_time = benchmark(lambda: submodule(*sample_inputs)) return mkl_time < no_mkl_time + return use_mkl_heuristic + def use_mkl_length(graph: MklSubgraph) -> bool: """ This is a heuristic that can be passed into `optimize_for_inference` that @@ -211,6 +257,7 @@ def use_mkl_length(graph: MklSubgraph) -> bool: """ return len(graph.nodes) > 2 + class UnionFind: def __init__(self, n): self.parent: List[Optional[int]] = [None] * n @@ -237,10 +284,11 @@ class UnionFind: self.parent[b] = a self.size[a] += self.size[b] + def optimize_for_inference( model: torch.nn.Module, pass_config: Optional[Dict[str, Any]] = None, - tracer: Type[fx.Tracer] = fx.Tracer + tracer: Type[fx.Tracer] = fx.Tracer, ) -> torch.nn.Module: """ Performs a set of optimization passes to optimize a model for the @@ -258,7 +306,7 @@ def optimize_for_inference( default_pass_config = { "conv_bn_fuse": True, "remove_dropout": True, - "mkldnn_layout_optimize": {'heuristic': use_mkl_length}, + "mkldnn_layout_optimize": {"heuristic": use_mkl_length}, } if pass_config is None: pass_config = {} @@ -292,15 +340,19 @@ def optimize_for_inference( # a MKLDNN node if its inputs are MKLDNN nodes. 
for node in list(fx_graph.nodes): supports_mkldnn = MklSupport.NO - if node.op == 'call_module': + if node.op == "call_module": cur_module = modules[node.target] if type(cur_module) in mkldnn_supported: supports_mkldnn = MklSupport.YES sample_parameter = next(cur_module.parameters(), None) if sample_parameter is not None: - assert sample_parameter.dtype == torch.float, "this pass is only for torch.float modules" - assert sample_parameter.device == torch.device('cpu'), "this pass is only for CPU modules" - elif node.op == 'call_function': + assert ( + sample_parameter.dtype == torch.float + ), "this pass is only for torch.float modules" + assert sample_parameter.device == torch.device( + "cpu" + ), "this pass is only for CPU modules" + elif node.op == "call_function": if node.target in mkldnn_supported: supports_mkldnn = MklSupport.YES elif node.target in mkldnn_supported_unknown: @@ -308,15 +360,17 @@ def optimize_for_inference( if supports_mkldnn != MklSupport.NO: if supports_mkldnn == MklSupport.UNKNOWN: - if not any(arg.target == 'to_dense' for arg in node.args): + if not any(arg.target == "to_dense" for arg in node.args): continue with fx_graph.inserting_before(node): - mkldnn_args = fx.map_arg(node.args, lambda n: fx_graph.call_method('to_mkldnn', (n, ))) + mkldnn_args = fx.map_arg( + node.args, lambda n: fx_graph.call_method("to_mkldnn", (n,)) + ) node.args = cast(Tuple[fx.node.Argument], mkldnn_args) with fx_graph.inserting_after(node): - dense_x = fx_graph.create_node('call_method', 'to_dense', (node,)) + dense_x = fx_graph.create_node("call_method", "to_dense", (node,)) node.replace_all_uses_with(dense_x) dense_x.args = (node,) @@ -326,28 +380,26 @@ def optimize_for_inference( # optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b for node in fx_graph.nodes: - if node.op == 'call_method' and node.target == 'to_dense': + if node.op == "call_method" and node.target == "to_dense": prv_node = node.args[0] users = list(node.users) for user in users: - if user.op == 'call_method' and user.target == 'to_mkldnn': + if user.op == "call_method" and user.target == "to_mkldnn": user.replace_all_uses_with(prv_node) fx_graph.erase_node(user) if len(node.users) == 0: fx_graph.erase_node(node) - num_nodes = len(fx_graph.nodes) uf = UnionFind(num_nodes) def get_color(n): - if hasattr(n, 'color'): # Current node is part of a MKL subgraph + if hasattr(n, "color"): # Current node is part of a MKL subgraph return uf.find(n.color) - if hasattr(n, 'start_color'): # Current node is input to MKL subgraph + if hasattr(n, "start_color"): # Current node is input to MKL subgraph return uf.find(n.start_color) return None - # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists # of input nodes (which are only `to_mkldnn` calls), output nodes # (`to_dense` calls), and intermediate nodes, which are run entirely on @@ -360,14 +412,19 @@ def optimize_for_inference( # nodes (i.e. colors), we need to join these 2 colors into 1. That's done # using a Disjoint Set Union. 
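# A small, self-contained illustration of the color-joining step with the
# UnionFind defined earlier in this file (hedged; the integer indices stand
# in for positions of graph nodes).
from torch.fx.experimental.optimization import UnionFind

uf = UnionFind(10)
uf.make_set(0)  # color opened by one `to_mkldnn` start node
uf.make_set(5)  # color opened by another `to_mkldnn` start node
uf.join(0, 5)   # some op consumed inputs from both colors: merge them
assert uf.find(0) == uf.find(5)  # both now name the same MKLDNN subgraph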
for cur_idx, node in enumerate(fx_graph.nodes): - if node.op == 'call_method' and node.target == 'to_mkldnn': + if node.op == "call_method" and node.target == "to_mkldnn": node.start_color = cur_idx uf.make_set(cur_idx) - elif node.op == 'call_method' and node.target == 'to_dense': + elif node.op == "call_method" and node.target == "to_dense": assert get_color(node.args[0]) is not None node.end_color = get_color(node.args[0]) else: - cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None] + cur_colors = [ + get_color(i) + for i in node.all_input_nodes + if isinstance(i, fx.Node) + if get_color(i) is not None + ] if len(cur_colors) == 0: continue @@ -377,17 +434,15 @@ def optimize_for_inference( for other_color in cur_colors[1:]: uf.join(cur_colors[0], other_color) - mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph)) for node in fx_graph.nodes: - if hasattr(node, 'color'): + if hasattr(node, "color"): mkldnn_graphs[uf.find(node.color)].nodes.append(node) - if hasattr(node, 'start_color'): + if hasattr(node, "start_color"): mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node) - if hasattr(node, 'end_color'): + if hasattr(node, "end_color"): mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node) - # Now that we have all the subgraphs, we need to decide which MKLDNN # subgraphs we actually want to keep in MKLDNN. for graph in mkldnn_graphs.values(): @@ -400,7 +455,7 @@ def optimize_for_inference( mkldnn_conversions = 0 for node in fx_graph.nodes: - if node.target == 'to_mkldnn' or node.target == 'to_dense': + if node.target == "to_mkldnn" or node.target == "to_dense": mkldnn_conversions += 1 logging.getLogger(__name__).info("mkldnn conversions: %s", mkldnn_conversions) diff --git a/torch/fx/experimental/partitioner_utils.py b/torch/fx/experimental/partitioner_utils.py index a88c481f92f2..e59921c58fa1 100644 --- a/torch/fx/experimental/partitioner_utils.py +++ b/torch/fx/experimental/partitioner_utils.py @@ -1,8 +1,8 @@ # mypy: allow-untyped-defs from enum import Enum -from typing import NamedTuple, Dict, List, Set +from typing import Dict, List, NamedTuple, Set -from torch.fx.node import Node, map_arg +from torch.fx.node import map_arg, Node class Partition: @@ -146,7 +146,7 @@ def get_latency_of_one_partition( # this node is on the top bfs level in this partition if not any( n in partition.nodes and n.op not in {"placeholder", "get_attr"} - for n in input_nodes + for n in input_nodes ): top_nodes.append(node) return top_nodes @@ -279,7 +279,9 @@ def get_latency_of_partitioned_graph( def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float: """This function helps to recursively get the latency of a path of partitions""" # Update latency by adding current partition's latency - latency_so_far_sec += partition_to_latency_mapping[partition].overall_latency_sec + latency_so_far_sec += partition_to_latency_mapping[ + partition + ].overall_latency_sec if partition.children: max_latency_sec = 0.0 diff --git a/torch/fx/experimental/refinement_types.py b/torch/fx/experimental/refinement_types.py index a33ddf3710a4..4a262af8fad9 100644 --- a/torch/fx/experimental/refinement_types.py +++ b/torch/fx/experimental/refinement_types.py @@ -5,10 +5,10 @@ class Equality: self.rhs = rhs def __str__(self): - return f'{self.lhs} = {self.rhs}' + return f"{self.lhs} = {self.rhs}" def __repr__(self): - return f'{self.lhs} = {self.rhs}' + return f"{self.lhs} = {self.rhs}" def __eq__(self, other): 
if isinstance(other, Equality): diff --git a/torch/fx/experimental/rewriter.py b/torch/fx/experimental/rewriter.py index 3647ca59153b..76ec03f86289 100644 --- a/torch/fx/experimental/rewriter.py +++ b/torch/fx/experimental/rewriter.py @@ -1,16 +1,18 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs import ast -import inspect -import textwrap import copy import functools +import inspect +import textwrap from types import FunctionType -from typing import cast, Union, Callable, Dict, Optional, Any +from typing import Any, Callable, cast, Dict, Optional, Union + +import torch +from torch._sources import normalize_source_lines from torch.fx._symbolic_trace import Tracer from torch.fx.graph import Graph -from torch._sources import normalize_source_lines -import torch + class AST_Rewriter(ast.NodeTransformer): """ @@ -29,11 +31,10 @@ class AST_Rewriter(ast.NodeTransformer): # suitable for dynamo tracing anyways. @torch._dynamo.disable def rewrite(self, fn: FunctionType): - # Normalize the source lines sourcelines, _ = inspect.getsourcelines(fn) sourcelines = normalize_source_lines(sourcelines) - source = ''.join(sourcelines) + source = "".join(sourcelines) normalized_str = textwrap.dedent(source) # Rewrite the original AST @@ -64,6 +65,7 @@ class AST_Rewriter(ast.NodeTransformer): g = functools.update_wrapper(g, f) g.__kwdefaults__ = copy.copy(f.__kwdefaults__) # type:ignore[attr-defined] return g + # Return the correct FunctionType object return change_func_globals(fn_compiled, globals=fn.__globals__) @@ -73,7 +75,7 @@ class AST_Rewriter(ast.NodeTransformer): symbolically-traceable torch._assert function """ # Create the Call node - n = ast.parse('torch._assert()', mode='eval') + n = ast.parse("torch._assert()", mode="eval") assert isinstance(n, ast.Expression) call_node = n.body assert isinstance(call_node, ast.Call) @@ -96,13 +98,22 @@ class AST_Rewriter(ast.NodeTransformer): Output: y = annotate(f2(x),Tensor_Type((1,2,3,Dyn))) """ - return ast.Assign(targets=[node.target], value=ast.Call( - func=ast.Name(id='annotate', ctx=ast.Load()), - args=[node.value, node.annotation], keywords=[])) + return ast.Assign( + targets=[node.target], + value=ast.Call( + func=ast.Name(id="annotate", ctx=ast.Load()), + args=[node.value, node.annotation], + keywords=[], + ), + ) class RewritingTracer(Tracer): - def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph: + def trace( + self, + root: Union[torch.nn.Module, Callable], + concrete_args: Optional[Dict[str, Any]] = None, + ) -> Graph: return super().trace(_rewrite(root), concrete_args) @@ -111,7 +122,7 @@ def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Cal # Rewrite this module's `forward` as well as the `forward`s of # all of this module's recursive descendents. Return the new, # rewritten module hierarchy. 
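# A hedged end-to-end sketch (the module `M` is illustrative): a plain `assert`
# on a traced value would normally abort symbolic tracing, but RewritingTracer
# rewrites it into the traceable `torch._assert` before tracing.
import torch
from torch.fx import GraphModule
from torch.fx.experimental.rewriter import RewritingTracer

class M(torch.nn.Module):
    def forward(self, x):
        assert x.shape[0] > 0, "empty batch"
        return x.relu()

tracer = RewritingTracer()
graph = tracer.trace(M())
gm = GraphModule(tracer.root, graph)
print(gm.code)  # the assert appears as a call to torch._assert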
- def rewrite_module(m : torch.nn.Module): + def rewrite_module(m: torch.nn.Module): class RewrittenModule(torch.nn.Module): def __init__(self, orig): super().__init__() @@ -120,8 +131,12 @@ def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Cal self.__dict__[k] = copy.copy(rewrite_module(v)) else: self.__dict__[k] = copy.copy(v) - RewrittenModule.forward = AST_Rewriter().rewrite(cast(FunctionType, m.forward)) + + RewrittenModule.forward = AST_Rewriter().rewrite( + cast(FunctionType, m.forward) + ) return RewrittenModule(m) + return rewrite_module(fn) else: # Rewrite this single free function diff --git a/torch/fx/experimental/schema_type_annotation.py b/torch/fx/experimental/schema_type_annotation.py index 5c7ab78706cb..519fec16cfc8 100644 --- a/torch/fx/experimental/schema_type_annotation.py +++ b/torch/fx/experimental/schema_type_annotation.py @@ -1,13 +1,14 @@ # mypy: allow-untyped-defs -import torch -import torch.fx import inspect from typing import Any, Dict, Optional, Tuple -from torch.fx.node import Argument, Target + +import torch +import torch.fx from torch._jit_internal import boolean_dispatched +from torch.fx import Transformer +from torch.fx.node import Argument, Target from torch.fx.operator_schemas import _torchscript_type_to_python_type -from torch.fx import Transformer class AnnotateTypesWithSchema(Transformer): """ @@ -27,16 +28,24 @@ class AnnotateTypesWithSchema(Transformer): traced = AnnotateTypesWithSchema(traced).transform() """ - def __init__(self, module : torch.nn.Module, annotate_functionals : bool = True, - annotate_modules : bool = True, annotate_get_attrs : bool = True): + + def __init__( + self, + module: torch.nn.Module, + annotate_functionals: bool = True, + annotate_modules: bool = True, + annotate_get_attrs: bool = True, + ): super().__init__(module) self.annotate_functionals = annotate_functionals self.annotate_modules = annotate_modules self.annotate_get_attrs = annotate_get_attrs - def call_function(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]): + def call_function( + self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ): python_ret_type = None - if self.annotate_functionals and target.__module__ == 'torch.nn.functional': + if self.annotate_functionals and target.__module__ == "torch.nn.functional": target_for_analysis = target if target in boolean_dispatched: # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have @@ -45,51 +54,71 @@ class AnnotateTypesWithSchema(Transformer): # branch signature for analysis. Otherwise, leave this un-normalized assert not isinstance(target, str) dispatched = boolean_dispatched[target] - if_true, if_false = dispatched['if_true'], dispatched['if_false'] + if_true, if_false = dispatched["if_true"], dispatched["if_false"] # TODO: can we emit the union of these? What are the implications on TorchScript # compilation? 
- if inspect.signature(if_true).return_annotation != inspect.signature(if_false).return_annotation: + if ( + inspect.signature(if_true).return_annotation + != inspect.signature(if_false).return_annotation + ): return super().call_function(target, args, kwargs) target_for_analysis = if_true python_ret_type = self._extract_python_return_type(target_for_analysis) return_proxy = super().call_function(target, args, kwargs) - return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type + return_proxy.node.type = ( + return_proxy.node.type if return_proxy.node.type else python_ret_type + ) return return_proxy - def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]): + def call_module( + self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ): python_ret_type = None assert isinstance(target, str) submod = self.fetch_attr(target) - if self.annotate_modules and hasattr(submod.__class__, '__name__'): + if self.annotate_modules and hasattr(submod.__class__, "__name__"): classname = submod.__class__.__name__ if getattr(torch.nn, classname, None) == submod.__class__: python_ret_type = self._extract_python_return_type(submod.forward) return_proxy = super().call_module(target, args, kwargs) - return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type + return_proxy.node.type = ( + return_proxy.node.type if return_proxy.node.type else python_ret_type + ) return return_proxy - def get_attr(self, target : torch.fx.node.Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]): + def get_attr( + self, + target: torch.fx.node.Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Any], + ): attr_proxy = super().get_attr(target, args, kwargs) if self.annotate_get_attrs: module_itr = self.module assert isinstance(target, str) - atoms = target.split('.') + atoms = target.split(".") for i, atom in enumerate(atoms): if not hasattr(module_itr, atom): - raise RuntimeError(f'Node referenced nonextent target {".".join(atoms[:i])}!') + raise RuntimeError( + f'Node referenced nonextent target {".".join(atoms[:i])}!' 
+ ) module_itr = getattr(module_itr, atom) maybe_inferred_ts_type = torch._C._jit_try_infer_type(module_itr) if maybe_inferred_ts_type.success(): - python_type = _torchscript_type_to_python_type(maybe_inferred_ts_type.type()) - attr_proxy.node.type = python_type if not attr_proxy.node.type else attr_proxy.node.type + python_type = _torchscript_type_to_python_type( + maybe_inferred_ts_type.type() + ) + attr_proxy.node.type = ( + python_type if not attr_proxy.node.type else attr_proxy.node.type + ) return attr_proxy - def _extract_python_return_type(self, target : Target) -> Optional[Any]: + def _extract_python_return_type(self, target: Target) -> Optional[Any]: """ Given a Python call target, try to extract the Python return annotation if it is available, otherwise return None @@ -109,4 +138,8 @@ class AnnotateTypesWithSchema(Transformer): except (ValueError, TypeError): return None - return sig.return_annotation if sig.return_annotation is not inspect.Signature.empty else None + return ( + sig.return_annotation + if sig.return_annotation is not inspect.Signature.empty + else None + ) diff --git a/torch/fx/experimental/unification/__init__.py b/torch/fx/experimental/unification/__init__.py index 31446d0e6125..7db0e29d1d4f 100644 --- a/torch/fx/experimental/unification/__init__.py +++ b/torch/fx/experimental/unification/__init__.py @@ -1,4 +1,4 @@ # mypy: disable-error-code=attr-defined -from .core import unify, reify # noqa: F403 +from .core import reify, unify # noqa: F403 from .more import unifiable # noqa: F403 -from .variable import var, isvar, vars, variables, Var # noqa: F403 +from .variable import isvar, Var, var, variables, vars # noqa: F403 diff --git a/torch/fx/experimental/unification/core.py b/torch/fx/experimental/unification/core.py index 0893c385bbc9..e32f42c8968e 100644 --- a/torch/fx/experimental/unification/core.py +++ b/torch/fx/experimental/unification/core.py @@ -2,10 +2,11 @@ from collections.abc import Iterator # type: ignore[import] from functools import partial +from .dispatch import dispatch from .unification_tools import assoc # type: ignore[import] from .utils import transitive_get as walk from .variable import isvar -from .dispatch import dispatch + __all__ = ["reify", "unify"] @@ -13,33 +14,47 @@ __all__ = ["reify", "unify"] # Reification # ############### + @dispatch(Iterator, dict) def _reify(t, s): return map(partial(reify, s=s), t) # return (reify(arg, s) for arg in t) + + _reify + @dispatch(tuple, dict) # type: ignore[no-redef] def _reify(t, s): return tuple(reify(iter(t), s)) + + _reify + @dispatch(list, dict) # type: ignore[no-redef] def _reify(t, s): return list(reify(iter(t), s)) + + _reify + @dispatch(dict, dict) # type: ignore[no-redef] def _reify(d, s): return {k: reify(v, s) for k, v in d.items()} + + _reify + @dispatch(object, dict) # type: ignore[no-redef] def _reify(o, s): return o # catch all, just return the object + def reify(e, s): - """ Replace variables of expression with substitution + """Replace variables of expression with substitution >>> # xdoctest: +SKIP >>> x, y = var(), var() >>> e = (1, x, (3, y)) @@ -54,12 +69,14 @@ def reify(e, s): return reify(s[e], s) if e in s else e return _reify(e, s) + ############### # Unification # ############### seq = tuple, list, Iterator + @dispatch(seq, seq, dict) def _unify(u, v, s): if len(u) != len(v): @@ -69,6 +86,8 @@ def _unify(u, v, s): if s is False: return False return s + + # # @dispatch((set, frozenset), (set, frozenset), dict) # def _unify(u, v, s): @@ -98,8 +117,8 @@ def _unify(u, v, 
s): @dispatch(object, object, dict) def unify(u, v, s): # no check at the moment - """ Find substitution so that u == v while satisfying s - >>> x = var('x') + """Find substitution so that u == v while satisfying s + >>> x = var("x") >>> unify((1, x), (1, 2), {}) {~x: 2} """ @@ -112,8 +131,11 @@ def unify(u, v, s): # no check at the moment if isvar(v): return assoc(s, v, u) return _unify(u, v, s) + + unify + @dispatch(object, object) # type: ignore[no-redef] def unify(u, v): return unify(u, v, {}) diff --git a/torch/fx/experimental/unification/dispatch.py b/torch/fx/experimental/unification/dispatch.py index 93039ce75070..82d62e1f1619 100644 --- a/torch/fx/experimental/unification/dispatch.py +++ b/torch/fx/experimental/unification/dispatch.py @@ -1,6 +1,8 @@ from functools import partial + from .multipledispatch import dispatch # type: ignore[import] + namespace = {} # type: ignore[var-annotated] dispatch = partial(dispatch, namespace=namespace) diff --git a/torch/fx/experimental/unification/match.py b/torch/fx/experimental/unification/match.py index 1e7b3f2d22bb..01861a086f64 100644 --- a/torch/fx/experimental/unification/match.py +++ b/torch/fx/experimental/unification/match.py @@ -1,8 +1,8 @@ # mypy: allow-untyped-defs -from .core import unify, reify # type: ignore[attr-defined] -from .variable import isvar +from .core import reify, unify # type: ignore[attr-defined] +from .unification_tools import first, groupby # type: ignore[import] from .utils import _toposort, freeze -from .unification_tools import groupby, first # type: ignore[import] +from .variable import isvar class Dispatcher: @@ -28,32 +28,38 @@ class Dispatcher: if s is not False: result = self.funcs[signature] return result, s - raise NotImplementedError("No match found. \nKnown matches: " - + str(self.ordering) + "\nInput: " + str(args)) + raise NotImplementedError( + "No match found. \nKnown matches: " + + str(self.ordering) + + "\nInput: " + + str(args) + ) def register(self, *signature): def _(func): self.add(signature, func) return self + return _ class VarDispatcher(Dispatcher): - """ A dispatcher that calls functions with variable names + """A dispatcher that calls functions with variable names >>> # xdoctest: +SKIP - >>> d = VarDispatcher('d') - >>> x = var('x') - >>> @d.register('inc', x) + >>> d = VarDispatcher("d") + >>> x = var("x") + >>> @d.register("inc", x) ... def f(x): ... return x + 1 - >>> @d.register('double', x) + >>> @d.register("double", x) ... def f(x): ... 
return x * 2 - >>> d('inc', 10) + >>> d("inc", 10) 11 - >>> d('double', 10) + >>> d("double", 10) 20 """ + def __call__(self, *args, **kwargs): func, s = self.resolve(args) d = {k.token: v for k, v in s.items()} @@ -64,8 +70,8 @@ global_namespace = {} # type: ignore[var-annotated] def match(*signature, **kwargs): - namespace = kwargs.get('namespace', global_namespace) - dispatcher = kwargs.get('Dispatcher', Dispatcher) + namespace = kwargs.get("namespace", global_namespace) + dispatcher = kwargs.get("Dispatcher", Dispatcher) def _(func): name = func.__name__ @@ -77,11 +83,12 @@ def match(*signature, **kwargs): d.add(signature, func) return d + return _ def supercedes(a, b): - """ ``a`` is a more specific match than ``b`` """ + """``a`` is a more specific match than ``b``""" if isvar(b) and not isvar(a): return True s = unify(a, b) @@ -96,7 +103,7 @@ def supercedes(a, b): # Taken from multipledispatch def edge(a, b, tie_breaker=hash): - """ A should be checked before B + """A should be checked before B Tie broken by tie_breaker, defaults to ``hash`` """ if supercedes(a, b): @@ -109,7 +116,7 @@ def edge(a, b, tie_breaker=hash): # Taken from multipledispatch def ordering(signatures): - """ A sane ordering of signatures to check, first to last + """A sane ordering of signatures to check, first to last Topological sort of edges as given by ``edge`` and ``supercedes`` """ signatures = list(map(tuple, signatures)) diff --git a/torch/fx/experimental/unification/more.py b/torch/fx/experimental/unification/more.py index 2228448a71a1..da2b1773f95b 100644 --- a/torch/fx/experimental/unification/more.py +++ b/torch/fx/experimental/unification/more.py @@ -1,10 +1,10 @@ # mypy: allow-untyped-defs -from .core import unify, reify # type: ignore[attr-defined] +from .core import reify, unify # type: ignore[attr-defined] from .dispatch import dispatch def unifiable(cls): - """ Register standard unify and reify operations on class + """Register standard unify and reify operations on class This uses the type and __dict__ or __slots__ attributes to define the nature of the term See Also: @@ -15,7 +15,7 @@ def unifiable(cls): ... self.b = b >>> unifiable(A) - >>> x = var('x') + >>> x = var("x") >>> a = A(1, 2) >>> b = A(1, x) >>> unify(a, b, {}) @@ -33,22 +33,23 @@ def unifiable(cls): def reify_object(o, s): - """ Reify a Python object with a substitution + """Reify a Python object with a substitution >>> # xdoctest: +SKIP >>> class Foo(object): ... def __init__(self, a, b): ... self.a = a ... self.b = b + ... ... def __str__(self): - ... return "Foo(%s, %s)"%(str(self.a), str(self.b)) - >>> x = var('x') + ... return "Foo(%s, %s)" % (str(self.a), str(self.b)) + >>> x = var("x") >>> f = Foo(1, x) >>> print(f) Foo(1, ~x) >>> print(reify_object(f, {x: 2})) Foo(1, 2) """ - if hasattr(o, '__slots__'): + if hasattr(o, "__slots__"): return _reify_object_slots(o, s) else: return _reify_object_dict(o, s) @@ -77,7 +78,7 @@ def _reify_object_slots(o, s): @dispatch(slice, dict) def _reify(o, s): - """ Reify a Python ``slice`` object """ + """Reify a Python ``slice`` object""" return slice(*reify((o.start, o.stop, o.step), s)) @@ -87,16 +88,17 @@ def _reify(o, s): def unify_object(u, v, s): - """ Unify two Python objects + """Unify two Python objects Unifies their type and ``__dict__`` attributes >>> # xdoctest: +SKIP >>> class Foo(object): ... def __init__(self, a, b): ... self.a = a ... self.b = b + ... ... def __str__(self): - ... return "Foo(%s, %s)"%(str(self.a), str(self.b)) - >>> x = var('x') + ... 
return "Foo(%s, %s)" % (str(self.a), str(self.b)) + >>> x = var("x") >>> f = Foo(1, x) >>> g = Foo(1, 2) >>> unify_object(f, g, {}) @@ -104,15 +106,17 @@ def unify_object(u, v, s): """ if type(u) != type(v): return False - if hasattr(u, '__slots__'): - return unify([getattr(u, slot) for slot in u.__slots__], - [getattr(v, slot) for slot in v.__slots__], - s) + if hasattr(u, "__slots__"): + return unify( + [getattr(u, slot) for slot in u.__slots__], + [getattr(v, slot) for slot in v.__slots__], + s, + ) else: return unify(u.__dict__, v.__dict__, s) @dispatch(slice, slice, dict) def _unify(u, v, s): - """ Unify a Python ``slice`` object """ + """Unify a Python ``slice`` object""" return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s) diff --git a/torch/fx/experimental/unification/multipledispatch/__init__.py b/torch/fx/experimental/unification/multipledispatch/__init__.py index a0295af0ea6b..bb7304069243 100644 --- a/torch/fx/experimental/unification/multipledispatch/__init__.py +++ b/torch/fx/experimental/unification/multipledispatch/__init__.py @@ -1,3 +1,7 @@ from .core import dispatch -from .dispatcher import (Dispatcher, halt_ordering, restart_ordering, - MDNotImplementedError) +from .dispatcher import ( + Dispatcher, + halt_ordering, + MDNotImplementedError, + restart_ordering, +) diff --git a/torch/fx/experimental/unification/multipledispatch/conflict.py b/torch/fx/experimental/unification/multipledispatch/conflict.py index 7187330ead25..44a893ad56a4 100644 --- a/torch/fx/experimental/unification/multipledispatch/conflict.py +++ b/torch/fx/experimental/unification/multipledispatch/conflict.py @@ -1,17 +1,28 @@ # mypy: allow-untyped-defs -from .utils import _toposort, groupby -from .variadic import isvariadic import operator -__all__ = ["AmbiguityWarning", "supercedes", "consistent", "ambiguous", "ambiguities", "super_signature", - "edge", "ordering"] +from .utils import _toposort, groupby +from .variadic import isvariadic + + +__all__ = [ + "AmbiguityWarning", + "supercedes", + "consistent", + "ambiguous", + "ambiguities", + "super_signature", + "edge", + "ordering", +] + class AmbiguityWarning(Warning): pass def supercedes(a, b): - """ A is consistent and strictly more specific than B """ + """A is consistent and strictly more specific than B""" if len(a) < len(b): # only case is if a is empty and b is variadic return not a and len(b) == 1 and isvariadic(b[-1]) @@ -41,7 +52,7 @@ def supercedes(a, b): def consistent(a, b): - """ It is possible for an argument list to satisfy both A and B """ + """It is possible for an argument list to satisfy both A and B""" # Need to check for empty args if not a: @@ -51,8 +62,7 @@ def consistent(a, b): # Non-empty args check for mutual subclasses if len(a) == len(b): - return all(issubclass(aa, bb) or issubclass(bb, aa) - for aa, bb in zip(a, b)) + return all(issubclass(aa, bb) or issubclass(bb, aa) for aa, bb in zip(a, b)) else: p1 = 0 p2 = 0 @@ -70,45 +80,53 @@ def consistent(a, b): p1 += 1 # We only need to check for variadic ends # Variadic types are guaranteed to be the last element - return (isvariadic(cur_a) and p2 == len(b) or # type: ignore[possibly-undefined] - isvariadic(cur_b) and p1 == len(a)) # type: ignore[possibly-undefined] + return ( + isvariadic(cur_a) # type: ignore[possibly-undefined] + and p2 == len(b) + or isvariadic(cur_b) # type: ignore[possibly-undefined] + and p1 == len(a) + ) def ambiguous(a, b): - """ A is consistent with B but neither is strictly more specific """ + """A is consistent with B but neither 
is strictly more specific""" return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a)) def ambiguities(signatures): - """ All signature pairs such that A is ambiguous with B """ + """All signature pairs such that A is ambiguous with B""" signatures = list(map(tuple, signatures)) - return {(a, b) for a in signatures for b in signatures - if hash(a) < hash(b) - and ambiguous(a, b) - and not any(supercedes(c, a) and supercedes(c, b) - for c in signatures)} + return { + (a, b) + for a in signatures + for b in signatures + if hash(a) < hash(b) + and ambiguous(a, b) + and not any(supercedes(c, a) and supercedes(c, b) for c in signatures) + } def super_signature(signatures): - """ A signature that would break ambiguities """ + """A signature that would break ambiguities""" n = len(signatures[0]) assert all(len(s) == n for s in signatures) - return [max((type.mro(sig[i]) for sig in signatures), key=len)[0] - for i in range(n)] + return [max((type.mro(sig[i]) for sig in signatures), key=len)[0] for i in range(n)] def edge(a, b, tie_breaker=hash): - """ A should be checked before B + """A should be checked before B Tie broken by tie_breaker, defaults to ``hash`` """ # A either supercedes B and B does not supercede A or if B does then call # tie_breaker - return supercedes(a, b) and (not supercedes(b, a) or tie_breaker(a) > tie_breaker(b)) + return supercedes(a, b) and ( + not supercedes(b, a) or tie_breaker(a) > tie_breaker(b) + ) def ordering(signatures): - """ A sane ordering of signatures to check, first to last + """A sane ordering of signatures to check, first to last Topological sort of edges as given by ``edge`` and ``supercedes`` """ signatures = list(map(tuple, signatures)) diff --git a/torch/fx/experimental/unification/multipledispatch/core.py b/torch/fx/experimental/unification/multipledispatch/core.py index 5b5bdbc96301..57a0eadaae15 100644 --- a/torch/fx/experimental/unification/multipledispatch/core.py +++ b/torch/fx/experimental/unification/multipledispatch/core.py @@ -4,12 +4,14 @@ import sys from .dispatcher import Dispatcher, MethodDispatcher + global_namespace = {} # type: ignore[var-annotated] __all__ = ["dispatch", "ismethod"] + def dispatch(*types, **kwargs): - """ Dispatch function on the types of the inputs + """Dispatch function on the types of the inputs Supports dispatch on all non-keyword arguments. Collects implementations based on the function name. Ignores namespaces. If ambiguous type signatures occur a warning is raised when the function is @@ -38,6 +40,7 @@ def dispatch(*types, **kwargs): ... @dispatch(list) ... def __init__(self, data): ... self.data = data + ... ... @dispatch(int) ... def __init__(self, datum): ... self.data = [datum] @@ -46,7 +49,7 @@ def dispatch(*types, **kwargs): >>> MyClass(3).data [3] """ - namespace = kwargs.get('namespace', global_namespace) + namespace = kwargs.get("namespace", global_namespace) types = tuple(types) @@ -65,20 +68,21 @@ def dispatch(*types, **kwargs): dispatcher.add(types, func) return dispatcher + return _df def ismethod(func): - """ Is func a method? + """Is func a method? Note that this has to work as the method is defined but before the class is defined. At this stage methods look like functions. 
""" if hasattr(inspect, "signature"): signature = inspect.signature(func) - return signature.parameters.get('self', None) is not None + return signature.parameters.get("self", None) is not None else: if sys.version_info.major < 3: spec = inspect.getargspec(func) # type: ignore[attr-defined] else: spec = inspect.getfullargspec(func) # type: ignore[union-attr, assignment] - return spec and spec.args and spec.args[0] == 'self' + return spec and spec.args and spec.args[0] == "self" diff --git a/torch/fx/experimental/unification/multipledispatch/dispatcher.py b/torch/fx/experimental/unification/multipledispatch/dispatcher.py index a1d28201d041..4f160995cce0 100644 --- a/torch/fx/experimental/unification/multipledispatch/dispatcher.py +++ b/torch/fx/experimental/unification/multipledispatch/dispatcher.py @@ -1,21 +1,35 @@ # mypy: allow-untyped-defs -from warnings import warn import inspect -from typing_extensions import deprecated -from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning -from .utils import expand_tuples -from .variadic import Variadic, isvariadic import itertools as itl +from typing_extensions import deprecated +from warnings import warn + +from .conflict import ambiguities, AmbiguityWarning, ordering, super_signature +from .utils import expand_tuples +from .variadic import isvariadic, Variadic + + +__all__ = [ + "MDNotImplementedError", + "ambiguity_warn", + "halt_ordering", + "restart_ordering", + "variadic_signature_matches_iter", + "variadic_signature_matches", + "Dispatcher", + "source", + "MethodDispatcher", + "str_signature", + "warning_text", +] -__all__ = ["MDNotImplementedError", "ambiguity_warn", "halt_ordering", "restart_ordering", "variadic_signature_matches_iter", - "variadic_signature_matches", "Dispatcher", "source", "MethodDispatcher", "str_signature", "warning_text"] class MDNotImplementedError(NotImplementedError): - """ A NotImplementedError for multiple dispatch """ + """A NotImplementedError for multiple dispatch""" def ambiguity_warn(dispatcher, ambiguities): - """ Raise warning when ambiguity is detected + """Raise warning when ambiguity is detected Parameters ---------- dispatcher : Dispatcher @@ -92,7 +106,7 @@ def variadic_signature_matches(types, full_signature): class Dispatcher: - """ Dispatch methods based on type signature + """Dispatch methods based on type signature Use ``dispatch`` to add implementations Examples -------- @@ -109,7 +123,8 @@ class Dispatcher: >>> f(3.0) 2.0 """ - __slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc' + + __slots__ = "__name__", "name", "funcs", "_ordering", "_cache", "doc" def __init__(self, name, doc=None): self.name = self.__name__ = name @@ -119,9 +134,9 @@ class Dispatcher: self._cache = {} def register(self, *types, **kwargs): - """ register dispatcher with new implementation + """register dispatcher with new implementation >>> # xdoctest: +SKIP - >>> f = Dispatcher('f') + >>> f = Dispatcher("f") >>> @f.register(int) ... def inc(x): ... 
return x + 1 @@ -139,9 +154,11 @@ class Dispatcher: >>> f([1, 2, 3]) [3, 2, 1] """ + def _df(func): - self.add(types, func, **kwargs) # type: ignore[call-arg] + self.add(types, func, **kwargs) # type: ignore[call-arg] return func + return _df @classmethod @@ -152,28 +169,27 @@ class Dispatcher: @classmethod def get_func_annotations(cls, func): - """ get annotations of function positional parameters - """ + """get annotations of function positional parameters""" params = cls.get_func_params(func) if params: Parameter = inspect.Parameter - params = (param for param in params - if param.kind in - (Parameter.POSITIONAL_ONLY, - Parameter.POSITIONAL_OR_KEYWORD)) + params = ( + param + for param in params + if param.kind + in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD) + ) - annotations = tuple( - param.annotation - for param in params) + annotations = tuple(param.annotation for param in params) if all(ann is not Parameter.empty for ann in annotations): return annotations def add(self, signature, func): - """ Add new types/method pair to dispatcher + """Add new types/method pair to dispatcher >>> # xdoctest: +SKIP - >>> D = Dispatcher('add') + >>> D = Dispatcher("add") >>> D.add((int, int), lambda x, y: x + y) >>> D.add((float, float), lambda x, y: x + y) >>> D(1, 2) @@ -202,24 +218,25 @@ class Dispatcher: for index, typ in enumerate(signature, start=1): if not isinstance(typ, (type, list)): - str_sig = ', '.join(c.__name__ if isinstance(c, type) - else str(c) for c in signature) - raise TypeError(f"Tried to dispatch on non-type: {typ}\n" - f"In signature: <{str_sig}>\n" - f"In function: {self.name}") + str_sig = ", ".join( + c.__name__ if isinstance(c, type) else str(c) for c in signature + ) + raise TypeError( + f"Tried to dispatch on non-type: {typ}\n" + f"In signature: <{str_sig}>\n" + f"In function: {self.name}" + ) # handle variadic signatures if isinstance(typ, list): if index != len(signature): - raise TypeError( - 'Variadic signature must be the last element' - ) + raise TypeError("Variadic signature must be the last element") if len(typ) != 1: raise TypeError( - 'Variadic signature must contain exactly one element. ' - 'To use a variadic union type place the desired types ' - 'inside of a tuple, e.g., [(int, str)]' + "Variadic signature must contain exactly one element. 
" + "To use a variadic union type place the desired types " + "inside of a tuple, e.g., [(int, str)]" ) new_signature.append(Variadic[typ[0]]) else: @@ -255,7 +272,8 @@ class Dispatcher: func = self.dispatch(*types) if not func: raise NotImplementedError( - f'Could not find signature for {self.name}: <{str_signature(types)}>') from e + f"Could not find signature for {self.name}: <{str_signature(types)}>" + ) from e self._cache[types] = func try: return func(*args, **kwargs) @@ -271,10 +289,12 @@ class Dispatcher: raise NotImplementedError( "Matching functions for " - f"{self.name}: <{str_signature(types)}> found, but none completed successfully",) from e + f"{self.name}: <{str_signature(types)}> found, but none completed successfully", + ) from e def __str__(self): return f"" + __repr__ = __str__ def dispatch(self, *types): @@ -304,7 +324,6 @@ class Dispatcher: return None def dispatch_iter(self, *types): - n = len(types) for signature in self.ordering: if len(signature) == n and all(map(issubclass, types, signature)): @@ -315,21 +334,22 @@ class Dispatcher: result = self.funcs[signature] yield result - @deprecated("`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning) + @deprecated( + "`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning + ) def resolve(self, types): - """ Determine appropriate implementation for this type signature + """Determine appropriate implementation for this type signature .. deprecated:: 0.4.4 Use ``dispatch(*types)`` instead """ return self.dispatch(*types) def __getstate__(self): - return {'name': self.name, - 'funcs': self.funcs} + return {"name": self.name, "funcs": self.funcs} def __setstate__(self, d): - self.name = d['name'] - self.funcs = d['funcs'] + self.name = d["name"] + self.funcs = d["funcs"] self._ordering = ordering(self.funcs) self._cache = {} @@ -344,23 +364,23 @@ class Dispatcher: for sig in self.ordering[::-1]: func = self.funcs[sig] if func.__doc__: - s = f'Inputs: <{str_signature(sig)}>\n' - s += '-' * len(s) + '\n' + s = f"Inputs: <{str_signature(sig)}>\n" + s += "-" * len(s) + "\n" s += func.__doc__.strip() docs.append(s) else: other.append(str_signature(sig)) if other: - docs.append('Other signatures:\n ' + '\n '.join(other)) + docs.append("Other signatures:\n " + "\n ".join(other)) - return '\n\n'.join(docs) + return "\n\n".join(docs) def _help(self, *args): return self.dispatch(*map(type, args)).__doc__ def help(self, *args, **kwargs): - """ Print docstring for the function corresponding to inputs """ + """Print docstring for the function corresponding to inputs""" print(self._help(*args)) def _source(self, *args): @@ -370,22 +390,23 @@ class Dispatcher: return source(func) def source(self, *args, **kwargs): - """ Print source code for the function corresponding to inputs """ + """Print source code for the function corresponding to inputs""" print(self._source(*args)) def source(func): - s = f'File: {inspect.getsourcefile(func)}\n\n' + s = f"File: {inspect.getsourcefile(func)}\n\n" s = s + inspect.getsource(func) return s class MethodDispatcher(Dispatcher): - """ Dispatch methods based on type signature + """Dispatch methods based on type signature See Also: Dispatcher """ - __slots__ = ('obj', 'cls') + + __slots__ = ("obj", "cls") @classmethod def get_func_params(cls, func): @@ -402,26 +423,31 @@ class MethodDispatcher(Dispatcher): types = tuple([type(arg) for arg in args]) func = self.dispatch(*types) if not func: - raise NotImplementedError(f'Could not find signature for {self.name}: 
<{str_signature(types)}>') + raise NotImplementedError( + f"Could not find signature for {self.name}: <{str_signature(types)}>" + ) return func(self.obj, *args, **kwargs) def str_signature(sig): - """ String representation of type signature + """String representation of type signature >>> str_signature((int, float)) 'int, float' """ - return ', '.join(cls.__name__ for cls in sig) + return ", ".join(cls.__name__ for cls in sig) def warning_text(name, amb): - """ The text for ambiguity warnings """ + """The text for ambiguity warnings""" text = f"\nAmbiguities exist in dispatched function {name}\n\n" text += "The following signatures may result in ambiguous behavior:\n" for pair in amb: - text += "\t" + \ - ', '.join('[' + str_signature(s) + ']' for s in pair) + "\n" + text += "\t" + ", ".join("[" + str_signature(s) + "]" for s in pair) + "\n" text += "\n\nConsider making the following additions:\n\n" - text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s)) - + f')\ndef {name}(...)' for s in amb]) + text += "\n\n".join( + [ + "@dispatch(" + str_signature(super_signature(s)) + f")\ndef {name}(...)" + for s in amb + ] + ) return text diff --git a/torch/fx/experimental/unification/multipledispatch/utils.py b/torch/fx/experimental/unification/multipledispatch/utils.py index 77702e8ccb7f..9c91cca2067a 100644 --- a/torch/fx/experimental/unification/multipledispatch/utils.py +++ b/torch/fx/experimental/unification/multipledispatch/utils.py @@ -1,8 +1,10 @@ # mypy: allow-untyped-defs from collections import OrderedDict + __all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"] + def raises(err, lamda): try: lamda() @@ -31,12 +33,12 @@ def expand_tuples(L): # Taken from theano/theano/gof/sched.py # Avoids licensing issues because this was written by Matthew Rocklin def _toposort(edges): - """ Topological sort algorithm by Kahn [1] - O(nodes + vertices) + """Topological sort algorithm by Kahn [1] - O(nodes + vertices) inputs: edges - a dict of the form {a: {b, c}} where b and c depend on a outputs: L - an ordered list of nodes that satisfy the dependencies of edges - >>> _toposort({1: (2, 3), 2: (3, )}) + >>> _toposort({1: (2, 3), 2: (3,)}) [1, 2, 3] >>> # Closely follows the wikipedia page [2] >>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", @@ -44,8 +46,7 @@ def _toposort(edges): >>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms """ incoming_edges = reverse_dict(edges) - incoming_edges = OrderedDict((k, set(val)) - for k, val in incoming_edges.items()) + incoming_edges = OrderedDict((k, set(val)) for k, val in incoming_edges.items()) S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges) L = [] @@ -64,7 +65,7 @@ def _toposort(edges): def reverse_dict(d): """Reverses direction of dependence dict - >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} + >>> d = {"a": (1, 2), "b": (2, 3), "c": ()} >>> reverse_dict(d) # doctest: +SKIP {1: ('a',), 2: ('a', 'b'), 3: ('b',)} :note: dict order are not deterministic. 
As we iterate on the @@ -82,8 +83,8 @@ def reverse_dict(d): # Taken from toolz # Avoids licensing issues because this version was authored by Matthew Rocklin def groupby(func, seq): - """ Group a collection by a key function - >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank'] + """Group a collection by a key function + >>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"] >>> groupby(len, names) # doctest: +SKIP {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} >>> iseven = lambda x: x % 2 == 0 diff --git a/torch/fx/experimental/unification/multipledispatch/variadic.py b/torch/fx/experimental/unification/multipledispatch/variadic.py index 49e546e1ea26..1b5604a15248 100644 --- a/torch/fx/experimental/unification/multipledispatch/variadic.py +++ b/torch/fx/experimental/unification/multipledispatch/variadic.py @@ -1,15 +1,17 @@ # mypy: allow-untyped-defs from .utils import typename + __all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"] + class VariadicSignatureType(type): # checking if subclass is a subclass of self def __subclasscheck__(cls, subclass): - other_type = (subclass.variadic_type if isvariadic(subclass) - else (subclass,)) + other_type = subclass.variadic_type if isvariadic(subclass) else (subclass,) return subclass is cls or all( - issubclass(other, cls.variadic_type) for other in other_type # type: ignore[attr-defined] + issubclass(other, cls.variadic_type) # type: ignore[attr-defined] + for other in other_type ) def __eq__(cls, other): @@ -24,8 +26,7 @@ class VariadicSignatureType(type): bool Whether or not `other` is equal to `self` """ - return (isvariadic(other) and - set(cls.variadic_type) == set(other.variadic_type)) # type: ignore[attr-defined] + return isvariadic(other) and set(cls.variadic_type) == set(other.variadic_type) # type: ignore[attr-defined] def __hash__(cls): return hash((type(cls), frozenset(cls.variadic_type))) # type: ignore[attr-defined] @@ -57,17 +58,20 @@ class VariadicSignatureMeta(type): generate a new type for Variadic signatures. See the Variadic class for examples of how this behaves. 
""" + def __getitem__(cls, variadic_type): if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)): - raise ValueError("Variadic types must be type or tuple of types" - " (Variadic[int] or Variadic[(int, float)]") + raise ValueError( + "Variadic types must be type or tuple of types" + " (Variadic[int] or Variadic[(int, float)]" + ) if not isinstance(variadic_type, tuple): - variadic_type = variadic_type, + variadic_type = (variadic_type,) return VariadicSignatureType( - f'Variadic[{typename(variadic_type)}]', + f"Variadic[{typename(variadic_type)}]", (), - dict(variadic_type=variadic_type, __slots__=()) + dict(variadic_type=variadic_type, __slots__=()), ) diff --git a/torch/fx/experimental/unification/unification_tools.py b/torch/fx/experimental/unification/unification_tools.py index d06d9bef771c..a47d900273f5 100644 --- a/torch/fx/experimental/unification/unification_tools.py +++ b/torch/fx/experimental/unification/unification_tools.py @@ -1,25 +1,40 @@ # mypy: allow-untyped-defs import collections import operator -from functools import reduce from collections.abc import Mapping +from functools import reduce -__all__ = ['merge', 'merge_with', 'valmap', 'keymap', 'itemmap', - 'valfilter', 'keyfilter', 'itemfilter', - 'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in'] + +__all__ = [ + "merge", + "merge_with", + "valmap", + "keymap", + "itemmap", + "valfilter", + "keyfilter", + "itemfilter", + "assoc", + "dissoc", + "assoc_in", + "update_in", + "get_in", +] def _get_factory(f, kwargs): - factory = kwargs.pop('factory', dict) + factory = kwargs.pop("factory", dict) if kwargs: - raise TypeError(f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'") + raise TypeError( + f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'" + ) return factory def merge(*dicts, **kwargs): - """ Merge a collection of dictionaries + """Merge a collection of dictionaries - >>> merge({1: 'one'}, {2: 'two'}) + >>> merge({1: "one"}, {2: "two"}) {1: 'one', 2: 'two'} Later dictionaries have precedence @@ -41,7 +56,7 @@ def merge(*dicts, **kwargs): def merge_with(func, *dicts, **kwargs): - """ Merge dictionaries and apply function to combined values + """Merge dictionaries and apply function to combined values A key may occur in more than one dict, and all values mapped from the key will be passed to the function as a list, such as func([val1, val2, ...]). 
@@ -70,7 +85,7 @@ def merge_with(func, *dicts, **kwargs): def valmap(func, d, factory=dict): - """ Apply function to values of dictionary + """Apply function to values of dictionary >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} >>> valmap(sum, bills) # doctest: +SKIP @@ -86,7 +101,7 @@ def valmap(func, d, factory=dict): def keymap(func, d, factory=dict): - """ Apply function to keys of dictionary + """Apply function to keys of dictionary >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} >>> keymap(str.lower, bills) # doctest: +SKIP @@ -102,7 +117,7 @@ def keymap(func, d, factory=dict): def itemmap(func, d, factory=dict): - """ Apply function to items of dictionary + """Apply function to items of dictionary >>> accountids = {"Alice": 10, "Bob": 20} >>> itemmap(reversed, accountids) # doctest: +SKIP @@ -118,7 +133,7 @@ def itemmap(func, d, factory=dict): def valfilter(predicate, d, factory=dict): - """ Filter items in dictionary by value + """Filter items in dictionary by value >>> iseven = lambda x: x % 2 == 0 >>> d = {1: 2, 2: 3, 3: 4, 4: 5} @@ -138,7 +153,7 @@ def valfilter(predicate, d, factory=dict): def keyfilter(predicate, d, factory=dict): - """ Filter items in dictionary by key + """Filter items in dictionary by key >>> iseven = lambda x: x % 2 == 0 >>> d = {1: 2, 2: 3, 3: 4, 4: 5} @@ -158,7 +173,7 @@ def keyfilter(predicate, d, factory=dict): def itemfilter(predicate, d, factory=dict): - """ Filter items in dictionary by item + """Filter items in dictionary by item >>> def isvalid(item): ... k, v = item @@ -182,13 +197,13 @@ def itemfilter(predicate, d, factory=dict): def assoc(d, key, value, factory=dict): - """ Return a new dict with new key value pair + """Return a new dict with new key value pair New dict has d[key] set to value. Does not modify the initial dictionary. - >>> assoc({'x': 1}, 'x', 2) + >>> assoc({"x": 1}, "x", 2) {'x': 2} - >>> assoc({'x': 1}, 'y', 3) # doctest: +SKIP + >>> assoc({"x": 1}, "y", 3) # doctest: +SKIP {'x': 1, 'y': 3} """ d2 = factory() @@ -198,22 +213,22 @@ def assoc(d, key, value, factory=dict): def dissoc(d, *keys, **kwargs): - """ Return a new dict with the given key(s) removed. + """Return a new dict with the given key(s) removed. New dict has d[key] deleted for each supplied key. Does not modify the initial dictionary. - >>> dissoc({'x': 1, 'y': 2}, 'y') + >>> dissoc({"x": 1, "y": 2}, "y") {'x': 1} - >>> dissoc({'x': 1, 'y': 2}, 'y', 'x') + >>> dissoc({"x": 1, "y": 2}, "y", "x") {} - >>> dissoc({'x': 1}, 'y') # Ignores missing keys + >>> dissoc({"x": 1}, "y") # Ignores missing keys {'x': 1} """ factory = _get_factory(dissoc, kwargs) d2 = factory() - if len(keys) < len(d) * .6: + if len(keys) < len(d) * 0.6: d2.update(d) for key in keys: if key in d2: @@ -227,13 +242,14 @@ def dissoc(d, *keys, **kwargs): def assoc_in(d, keys, value, factory=dict): - """ Return a new dict with new, potentially nested, key value pair + """Return a new dict with new, potentially nested, key value pair - >>> purchase = {'name': 'Alice', - ... 'order': {'items': ['Apple', 'Orange'], - ... 'costs': [0.50, 1.25]}, - ... 'credit card': '5555-1234-1234-1234'} - >>> assoc_in(purchase, ['order', 'costs'], [0.25, 1.00]) # doctest: +SKIP + >>> purchase = { + ... "name": "Alice", + ... "order": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]}, + ... "credit card": "5555-1234-1234-1234", + ... 
} + >>> assoc_in(purchase, ["order", "costs"], [0.25, 1.00]) # doctest: +SKIP {'credit card': '5555-1234-1234-1234', 'name': 'Alice', 'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}} @@ -242,7 +258,7 @@ def assoc_in(d, keys, value, factory=dict): def update_in(d, keys, func, default=None, factory=dict): - """ Update value in a (potentially) nested dictionary + """Update value in a (potentially) nested dictionary inputs: d - dictionary on which to operate @@ -257,14 +273,15 @@ def update_in(d, keys, func, default=None, factory=dict): specified by the keys, with the innermost value set to func(default). >>> inc = lambda x: x + 1 - >>> update_in({'a': 0}, ['a'], inc) + >>> update_in({"a": 0}, ["a"], inc) {'a': 1} - >>> transaction = {'name': 'Alice', - ... 'purchase': {'items': ['Apple', 'Orange'], - ... 'costs': [0.50, 1.25]}, - ... 'credit card': '5555-1234-1234-1234'} - >>> update_in(transaction, ['purchase', 'costs'], sum) # doctest: +SKIP + >>> transaction = { + ... "name": "Alice", + ... "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]}, + ... "credit card": "5555-1234-1234-1234", + ... } + >>> update_in(transaction, ["purchase", "costs"], sum) # doctest: +SKIP {'credit card': '5555-1234-1234-1234', 'name': 'Alice', 'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}} @@ -272,7 +289,7 @@ def update_in(d, keys, func, default=None, factory=dict): >>> # updating a value when k0 is not in d >>> update_in({}, [1, 2, 3], str, default="bar") {1: {2: {3: 'bar'}}} - >>> update_in({1: 'foo'}, [2, 3, 4], inc, 0) + >>> update_in({1: "foo"}, [2, 3, 4], inc, 0) {1: 'foo', 2: {3: {4: 1}}} """ ks = iter(keys) @@ -300,7 +317,7 @@ def update_in(d, keys, func, default=None, factory=dict): def get_in(keys, coll, default=None, no_default=False): - """ Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys. + """Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys. If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless ``no_default`` is specified, then it raises KeyError or IndexError. @@ -308,20 +325,21 @@ def get_in(keys, coll, default=None, no_default=False): ``get_in`` is a generalization of ``operator.getitem`` for nested data structures such as dictionaries and lists. - >>> transaction = {'name': 'Alice', - ... 'purchase': {'items': ['Apple', 'Orange'], - ... 'costs': [0.50, 1.25]}, - ... 'credit card': '5555-1234-1234-1234'} - >>> get_in(['purchase', 'items', 0], transaction) + >>> transaction = { + ... "name": "Alice", + ... "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]}, + ... "credit card": "5555-1234-1234-1234", + ... } + >>> get_in(["purchase", "items", 0], transaction) 'Apple' - >>> get_in(['name'], transaction) + >>> get_in(["name"], transaction) 'Alice' - >>> get_in(['purchase', 'total'], transaction) - >>> get_in(['purchase', 'items', 'apple'], transaction) - >>> get_in(['purchase', 'items', 10], transaction) - >>> get_in(['purchase', 'total'], transaction, 0) + >>> get_in(["purchase", "total"], transaction) + >>> get_in(["purchase", "items", "apple"], transaction) + >>> get_in(["purchase", "items", 10], transaction) + >>> get_in(["purchase", "total"], transaction, 0) 0 - >>> get_in(['y'], {}, no_default=True) + >>> get_in(["y"], {}, no_default=True) Traceback (most recent call last): ... 
KeyError: 'y' @@ -352,9 +370,9 @@ def getter(index): def groupby(key, seq): - """ Group a collection by a key function + """Group a collection by a key function - >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank'] + >>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"] >>> groupby(len, names) # doctest: +SKIP {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} @@ -364,9 +382,14 @@ def groupby(key, seq): Non-callable keys imply grouping on a member. - >>> groupby('gender', [{'name': 'Alice', 'gender': 'F'}, - ... {'name': 'Bob', 'gender': 'M'}, - ... {'name': 'Charlie', 'gender': 'M'}]) # doctest:+SKIP + >>> groupby( + ... "gender", + ... [ + ... {"name": "Alice", "gender": "F"}, + ... {"name": "Bob", "gender": "M"}, + ... {"name": "Charlie", "gender": "M"}, + ... ], + ... ) # doctest:+SKIP {'F': [{'gender': 'F', 'name': 'Alice'}], 'M': [{'gender': 'M', 'name': 'Bob'}, {'gender': 'M', 'name': 'Charlie'}]} @@ -388,9 +411,9 @@ def groupby(key, seq): def first(seq): - """ The first element in a sequence + """The first element in a sequence - >>> first('ABC') + >>> first("ABC") 'A' """ return next(iter(seq)) diff --git a/torch/fx/experimental/unification/utils.py b/torch/fx/experimental/unification/utils.py index 609fe59d43f4..7634c9b2ec90 100644 --- a/torch/fx/experimental/unification/utils.py +++ b/torch/fx/experimental/unification/utils.py @@ -1,5 +1,7 @@ # mypy: allow-untyped-defs __all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"] + + def hashable(x): try: hash(x) @@ -9,7 +11,7 @@ def hashable(x): def transitive_get(key, d): - """ Transitive dict.get + """Transitive dict.get >>> d = {1: 2, 2: 3, 3: 4} >>> d.get(1) 2 @@ -32,13 +34,13 @@ def raises(err, lamda): # Taken from theano/theano/gof/sched.py # Avoids licensing issues because this was written by Matthew Rocklin def _toposort(edges): - """ Topological sort algorithm by Kahn [1] - O(nodes + vertices) + """Topological sort algorithm by Kahn [1] - O(nodes + vertices) inputs: edges - a dict of the form {a: {b, c}} where b and c depend on a outputs: L - an ordered list of nodes that satisfy the dependencies of edges >>> # xdoctest: +SKIP - >>> _toposort({1: (2, 3), 2: (3, )}) + >>> _toposort({1: (2, 3), 2: (3,)}) [1, 2, 3] Closely follows the wikipedia page [2] [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", @@ -47,7 +49,7 @@ def _toposort(edges): """ incoming_edges = reverse_dict(edges) incoming_edges = {k: set(val) for k, val in incoming_edges.items()} - S = ({v for v in edges if v not in incoming_edges}) + S = {v for v in edges if v not in incoming_edges} L = [] while S: @@ -65,7 +67,7 @@ def _toposort(edges): def reverse_dict(d): """Reverses direction of dependence dict - >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} + >>> d = {"a": (1, 2), "b": (2, 3), "c": ()} >>> reverse_dict(d) # doctest: +SKIP {1: ('a',), 2: ('a', 'b'), 3: ('b',)} :note: dict order are not deterministic. 
As we iterate on the @@ -89,12 +91,12 @@ def xfail(func): def freeze(d): - """ Freeze container to hashable form + """Freeze container to hashable form >>> freeze(1) 1 >>> freeze([1, 2]) (1, 2) - >>> freeze({1: 2}) # doctest: +SKIP + >>> freeze({1: 2}) # doctest: +SKIP frozenset([(1, 2)]) """ if isinstance(d, dict): diff --git a/torch/fx/experimental/unification/variable.py b/torch/fx/experimental/unification/variable.py index 66e97a3a7663..46e59851fdfa 100644 --- a/torch/fx/experimental/unification/variable.py +++ b/torch/fx/experimental/unification/variable.py @@ -1,14 +1,16 @@ # mypy: allow-untyped-defs from contextlib import contextmanager -from .utils import hashable + from .dispatch import dispatch +from .utils import hashable + _global_logic_variables = set() # type: ignore[var-annotated] _glv = _global_logic_variables class Var: - """ Logic Variable """ + """Logic Variable""" _id = 1 @@ -25,6 +27,7 @@ class Var: def __str__(self): return "~" + str(self.token) # type: ignore[attr-defined] + __repr__ = __str__ def __eq__(self, other): @@ -46,6 +49,7 @@ def vars(): def isvar(v): return True + isvar @@ -69,12 +73,12 @@ def variables(*variables): False >>> # Normal approach >>> from unification import unify - >>> x = var('x') + >>> x = var("x") >>> unify(x, 1) {~x: 1} >>> # Context Manager approach - >>> with variables('x'): - ... print(unify('x', 1)) + >>> with variables("x"): + ... print(unify("x", 1)) {'x': 1} """ old_global_logic_variables = _global_logic_variables.copy() diff --git a/torch/fx/experimental/unify_refinements.py b/torch/fx/experimental/unify_refinements.py index cad0a33425bf..bab662e0655a 100644 --- a/torch/fx/experimental/unify_refinements.py +++ b/torch/fx/experimental/unify_refinements.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs from torch.fx.experimental.graph_gradual_typechecker import Refine +from torch.fx.experimental.unification import unify, Var # type: ignore[attr-defined] from torch.fx.tensor_type import TensorType -from torch.fx.experimental.unification import Var, unify # type: ignore[attr-defined] def infer_symbolic_types_single_pass(traced): @@ -13,6 +13,7 @@ def infer_symbolic_types_single_pass(traced): mgu = unify_eq(r.constraints) substitute_all_types(traced.graph, mgu) + def infer_symbolic_types(traced): """ Calls our symbolic inferencer twice. @@ -32,6 +33,7 @@ def infer_symbolic_types(traced): r.symbolic_relations() + def convert_eq(list_of_eq): """ Convert equality constraints in the right format @@ -109,6 +111,7 @@ def substitute_all_types(graph, mapping): for n in graph.nodes: n.type = substitute_solution_one_type(mapping, n.type) + def check_for_type_equality(g1, g2): """ A check equality to be used in fixed points. diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 805d8e994ccd..67335ef92a76 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -1,32 +1,47 @@ # mypy: allow-untyped-defs -from collections import defaultdict -from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name -import torch.utils._pytree as pytree -from . 
import _pytree as fx_pytree -from ._compatibility import compatibility -from torch._C import _NodeIter - -import os import contextlib -from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type, Iterable -from dataclasses import dataclass -from contextlib import contextmanager import copy import enum -import torch -import keyword -import re -import builtins -import math -import warnings -import inspect import functools +import inspect +import keyword +import math +import os +import re +import warnings +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass +from typing import ( + Any, + Callable, + Dict, + FrozenSet, + Iterable, + List, + NamedTuple, + Optional, + Set, + Tuple, + Type, + TYPE_CHECKING, +) + +import torch +import torch.utils._pytree as pytree +from torch._C import _NodeIter + +from . import _pytree as fx_pytree +from ._compatibility import compatibility +from .node import _get_qualified_name, _type_repr, Argument, map_arg, Node, Target + __all__ = ["PythonCode", "CodeGen", "Graph"] if TYPE_CHECKING: + from ._symbolic_trace import Tracer # noqa: F401 from .graph_module import GraphModule # noqa: F401 - from ._symbolic_trace import Tracer # noqa: F401 # Mapping of builtins to their `typing` equivalent. @@ -38,7 +53,9 @@ _origin_type_map = { tuple: Tuple, } -_legal_ops = dict.fromkeys(['call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output']) +_legal_ops = dict.fromkeys( + ["call_function", "call_method", "get_attr", "call_module", "placeholder", "output"] +) # Signature for functions that transform the body (`list[str]`) of the @@ -53,11 +70,13 @@ class _CustomBuiltin(NamedTuple): an import. For common objects of this sort, we bundle them in the globals of every FX graph. """ + # How to import this object from the standard library. import_str: str # The actual object, produced from that import string. 
obj: Any + _custom_builtins: Dict[str, _CustomBuiltin] = {} @@ -65,17 +84,17 @@ def _register_custom_builtin(name: str, import_str: str, obj: Any): _custom_builtins[name] = _CustomBuiltin(import_str, obj) -_register_custom_builtin('inf', 'from math import inf', math.inf) -_register_custom_builtin('nan', 'from math import nan', math.nan) -_register_custom_builtin('NoneType', 'NoneType = type(None)', type(None)) -_register_custom_builtin('torch', 'import torch', torch) -_register_custom_builtin('device', 'from torch import device', torch.device) -_register_custom_builtin('fx_pytree', 'import torch.fx._pytree as fx_pytree', fx_pytree) -_register_custom_builtin('pytree', 'import torch.utils._pytree as pytree', pytree) +_register_custom_builtin("inf", "from math import inf", math.inf) +_register_custom_builtin("nan", "from math import nan", math.nan) +_register_custom_builtin("NoneType", "NoneType = type(None)", type(None)) +_register_custom_builtin("torch", "import torch", torch) +_register_custom_builtin("device", "from torch import device", torch.device) +_register_custom_builtin("fx_pytree", "import torch.fx._pytree as fx_pytree", fx_pytree) +_register_custom_builtin("pytree", "import torch.utils._pytree as pytree", pytree) def _is_magic(x: str) -> bool: - return x.startswith('__') and x.endswith('__') + return x.startswith("__") and x.endswith("__") def _snake_case(s: str) -> str: @@ -91,22 +110,22 @@ def _snake_case(s: str) -> str: # Replace occurrences where a lowercase letter is followed by an uppercase letter -_snake_case_sub = functools.partial(re.compile(r'(?<=[a-z])([A-Z])').sub, r'_\1') +_snake_case_sub = functools.partial(re.compile(r"(?<=[a-z])([A-Z])").sub, r"_\1") def _is_from_torch(obj: Any) -> bool: - module_name = getattr(obj, '__module__', None) + module_name = getattr(obj, "__module__", None) if module_name is not None: - base_module = module_name.partition('.')[0] + base_module = module_name.partition(".")[0] return ( - base_module == 'torch' and - not module_name.startswith("torch._dynamo.") and - not module_name.startswith("torch._inductor.") + base_module == "torch" + and not module_name.startswith("torch._dynamo.") + and not module_name.startswith("torch._inductor.") ) - name = getattr(obj, '__name__', None) + name = getattr(obj, "__name__", None) # exclude torch because torch.torch.torch.torch works. idk mang - if name is not None and name != 'torch': + if name is not None and name != "torch": for guess in [torch, torch.nn.functional]: if getattr(guess, name, None) is obj: return True @@ -122,13 +141,14 @@ class _Namespace: - Each name is unique within a given namespace. - Names generated do not shadow builtins, unless the object is indeed that builtin. 
""" + def __init__(self): self._obj_to_name: Dict[Any, str] = {} self._unassociated_names = set() self._used_names: Set[str] = set() self._base_count: Dict[str, int] = defaultdict(int) - self._illegal_char_regex = re.compile('[^0-9a-zA-Z_]+') + self._illegal_char_regex = re.compile("[^0-9a-zA-Z_]+") self._name_suffix_regex = re.compile(r"(.*)_(\d+)$") def create_name(self, candidate: str, obj: Optional[Any]) -> str: @@ -142,13 +162,13 @@ class _Namespace: return self._obj_to_name[obj] # delete all characters that are illegal in a Python identifier - candidate = self._illegal_char_regex.sub('_', candidate) + candidate = self._illegal_char_regex.sub("_", candidate) if not candidate: - candidate = '_unnamed' + candidate = "_unnamed" if candidate[0].isdigit(): - candidate = f'_{candidate}' + candidate = f"_{candidate}" match = self._name_suffix_regex.match(candidate) if match is None: @@ -158,13 +178,13 @@ class _Namespace: base, num_str = match.group(1, 2) num = int(num_str) - candidate = base if num is None else f'{base}_{num}' + candidate = base if num is None else f"{base}_{num}" if not num: num = self._base_count[base] while candidate in self._used_names or self._is_illegal_name(candidate, obj): num += 1 - candidate = f'{base}_{num}' + candidate = f"{base}_{num}" self._used_names.add(candidate) self._base_count[base] = num @@ -204,36 +224,39 @@ class _Namespace: self._obj_to_name[obj] = name self._used_names.add(name) + dtype_abbrs = { - torch.bfloat16: 'bf16', - torch.float64: 'f64', - torch.float32: 'f32', - torch.float16: 'f16', - torch.float8_e4m3fn: 'f8e4m3fn', - torch.float8_e5m2: 'f8e5m2', - torch.float8_e4m3fnuz: 'f8e4m3fnuz', - torch.float8_e5m2fnuz: 'f8e5m2fnuz', - torch.complex32: 'c32', - torch.complex64: 'c64', - torch.complex128: 'c128', - torch.int8: 'i8', - torch.int16: 'i16', - torch.int32: 'i32', - torch.int64: 'i64', - torch.bool: 'b8', - torch.uint8: 'u8', - torch.uint16: 'u16', - torch.uint32: 'u32', - torch.uint64: 'u64', - torch.bits16: 'b16', + torch.bfloat16: "bf16", + torch.float64: "f64", + torch.float32: "f32", + torch.float16: "f16", + torch.float8_e4m3fn: "f8e4m3fn", + torch.float8_e5m2: "f8e5m2", + torch.float8_e4m3fnuz: "f8e4m3fnuz", + torch.float8_e5m2fnuz: "f8e5m2fnuz", + torch.complex32: "c32", + torch.complex64: "c64", + torch.complex128: "c128", + torch.int8: "i8", + torch.int16: "i16", + torch.int32: "i32", + torch.int64: "i64", + torch.bool: "b8", + torch.uint8: "u8", + torch.uint16: "u16", + torch.uint32: "u32", + torch.uint64: "u64", + torch.bits16: "b16", } + @compatibility(is_backward_compatible=True) @dataclass class PythonCode: """ Represents all the information necessary to exec or save a graph as Python code. """ + # Python source code for the forward function definition. src: str # Values in global scope during execution of `src_def`. 
@@ -244,15 +267,16 @@ class PythonCode: def _format_target(base: str, target: str) -> str: - elems = target.split('.') + elems = target.split(".") r = base for e in elems: if not e.isidentifier(): r = f'getattr({r}, "{e}")' else: - r = f'{r}.{e}' + r = f"{r}.{e}" return r + class _InsertPoint: def __init__(self, graph, new_insert): self.graph = graph @@ -264,9 +288,10 @@ class _InsertPoint: def __exit__(self, type, value, tb): self.graph._insert = self.orig_insert + class _node_list: - def __init__(self, graph: 'Graph', direction: str = '_next'): - assert direction in ['_next', '_prev'] + def __init__(self, graph: "Graph", direction: str = "_next"): + assert direction in ["_next", "_prev"] self.graph = graph self.direction = direction @@ -278,35 +303,40 @@ class _node_list: yield from _NodeIter(self.graph._root, self.direction == "_prev") def __reversed__(self): - return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev') + return _node_list(self.graph, "_next" if self.direction == "_prev" else "_prev") + class _PyTreeInfo(NamedTuple): """ Contains extra info stored when we're using Pytrees """ + orig_args: List[str] in_spec: pytree.TreeSpec out_spec: Optional[pytree.TreeSpec] + @dataclass(frozen=True) class _ParsedStackTrace: """ Represents the top-most frame of a parsed stack trace """ + file: str lineno: str name: str code: str def get_summary_str(self): - return f'File: {self.file}:{self.lineno} in {self.name}, code: {self.code}' + return f"File: {self.file}:{self.lineno} in {self.name}, code: {self.code}" + # get File:lineno code from stack_trace def _parse_stack_trace(stack_trace: str): if stack_trace is None: return None pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$") - lines = stack_trace.strip().split('\n') + lines = stack_trace.strip().split("\n") # stacktrace should have innermost frame last, so we # iterate backwards to find the first line that starts # with 'File ' @@ -322,6 +352,7 @@ def _parse_stack_trace(stack_trace: str): return _ParsedStackTrace(file, lineno, name, code) return None + @compatibility(is_backward_compatible=False) class CodeGen: def __init__(self): @@ -335,16 +366,18 @@ class CodeGen: """ # If the original function didn't have self as its first argument, we # would have added it. - if len(free_vars) == 0 or free_vars[0] != 'self': - free_vars.insert(0, 'self') - return f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:" + if len(free_vars) == 0 or free_vars[0] != "self": + free_vars.insert(0, "self") + return ( + f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:" + ) def generate_output(self, output_args: Argument) -> str: """ Given the output arguments, generates the return statement of the FX function. Note: The returned statement should not be indented. 
""" - return f'return {repr(output_args)}' + return f"return {repr(output_args)}" def process_inputs(self, *args: Any) -> Any: """ @@ -373,8 +406,15 @@ class CodeGen: return [] def _gen_python_code( - self, nodes, root_module: str, namespace: _Namespace, *, - verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False + self, + nodes, + root_module: str, + namespace: _Namespace, + *, + verbose: bool = False, + include_stride: bool = False, + include_device: bool = False, + colored: bool = False, ) -> PythonCode: free_vars: List[str] = [] body: List[str] = [] @@ -382,9 +422,13 @@ class CodeGen: wrapped_fns: Dict[str, None] = {} # Wrap string in list to pass by reference - maybe_return_annotation : List[str] = [''] - include_stride = include_stride or (os.environ.get("FX_GRAPH_SHOW_STRIDE", "0") == "1") - include_device = include_device or (os.environ.get("FX_GRAPH_SHOW_DEVICE", "0") == "1") + maybe_return_annotation: List[str] = [""] + include_stride = include_stride or ( + os.environ.get("FX_GRAPH_SHOW_STRIDE", "0") == "1" + ) + include_device = include_device or ( + os.environ.get("FX_GRAPH_SHOW_DEVICE", "0") == "1" + ) def add_global(name_hint: str, obj: Any): """Add an obj to be tracked as a global. @@ -394,7 +438,9 @@ class CodeGen: Returns: the global name that should be used to reference 'obj' in generated source. """ - if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device + if ( + _is_from_torch(obj) and obj != torch.device + ): # to support registering torch.device # HACK: workaround for how torch custom ops are registered. We # can't import them like normal modules so they must retain their # fully qualified name. @@ -413,19 +459,19 @@ class CodeGen: for name, (_, obj) in _custom_builtins.items(): add_global(name, obj) - def type_repr(o : Any): + def type_repr(o: Any): if o == (): # Empty tuple is used for empty tuple type annotation Tuple[()] - return '()' + return "()" typename = _type_repr(o) - if hasattr(o, '__origin__'): + if hasattr(o, "__origin__"): # This is a generic type, e.g. typing.List[torch.Tensor] origin_type = _origin_type_map.get(o.__origin__, o.__origin__) origin_typename = add_global(_type_repr(origin_type), origin_type) - if hasattr(o, '__args__'): + if hasattr(o, "__args__"): # Assign global names for each of the inner type variables. args = [type_repr(arg) for arg in o.__args__] @@ -460,6 +506,7 @@ class CodeGen: if colored: return f"{codes[name]}{s}{codes['reset']}" return s + return f yellow = make_wrapper_func("yellow") # noqa: F841 @@ -473,11 +520,13 @@ class CodeGen: def _get_repr(arg: Any) -> str: # Handle NamedTuples (if it has `_fields`) via add_global. 
- if isinstance(arg, tuple) and hasattr(arg, '_fields'): + if isinstance(arg, tuple) and hasattr(arg, "_fields"): qualified_name = _get_qualified_name(type(arg)) global_name = add_global(qualified_name, type(arg)) return f"{global_name}{repr(tuple(arg))}" - elif isinstance(arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + elif isinstance( + arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator) + ): qualified_name = _get_qualified_name(arg) global_name = add_global(qualified_name, arg) return f"{global_name}" @@ -494,22 +543,23 @@ class CodeGen: else: return blue(repr(arg)) - - def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: - args_s = ', '.join(_get_repr(a) for a in args) - kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) + def _format_args( + args: Tuple[Argument, ...], kwargs: Dict[str, Argument] + ) -> str: + args_s = ", ".join(_get_repr(a) for a in args) + kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) if args_s and kwargs_s: - return f'{args_s}, {kwargs_s}' + return f"{args_s}, {kwargs_s}" return args_s or kwargs_s # Run through reverse nodes and record the first instance of a use # of a given node. This represents the *last* use of the node in the # execution order of the program, which we will use to free unused # values - node_to_last_use : Dict[Node, Node] = {} - user_to_last_uses : Dict[Node, List[Node]] = {} + node_to_last_use: Dict[Node, Node] = {} + user_to_last_uses: Dict[Node, List[Node]] = {} - def register_last_uses(n : Node, user : Node): + def register_last_uses(n: Node, user: Node): if n not in node_to_last_use: node_to_last_use[n] = user user_to_last_uses.setdefault(user, []).append(n) @@ -518,16 +568,16 @@ class CodeGen: map_arg(node.args, lambda n: register_last_uses(n, node)) map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - def delete_unused_values(user : Node): + def delete_unused_values(user: Node): """ Delete values after their last use. This ensures that values that are not used in the remainder of the code are freed and the memory usage of the code is optimal. """ - if user.op == 'placeholder': + if user.op == "placeholder": return - if user.op == 'output': - body.append('\n') + if user.op == "output": + body.append("\n") return nodes_to_delete = user_to_last_uses.get(user, []) @@ -538,21 +588,23 @@ class CodeGen: nodes_to_delete.append(user) if len(nodes_to_delete): - to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) - body.append(f'; {dim(to_delete_str)}\n') + to_delete_str = " = ".join( + [repr(n) for n in nodes_to_delete] + ["None"] + ) + body.append(f"; {dim(to_delete_str)}\n") else: - body.append('\n') + body.append("\n") prev_stacktrace = None - def append_stacktrace_summary(node : Node): + def append_stacktrace_summary(node: Node): """ Append a summary of the stacktrace to the generated code. This is useful for debugging. 
""" nonlocal prev_stacktrace - if node.op not in {'placeholder', 'output'}: + if node.op not in {"placeholder", "output"}: if node.stack_trace: if node.stack_trace != prev_stacktrace: prev_stacktrace = node.stack_trace @@ -565,93 +617,128 @@ class CodeGen: elif prev_stacktrace != "": prev_stacktrace = "" no_stacktrace_msg = "# No stacktrace found for following nodes" - body.append(f'\n{dim(no_stacktrace_msg)}\n') + body.append(f"\n{dim(no_stacktrace_msg)}\n") - def stringify_shape(shape : Iterable) -> str: + def stringify_shape(shape: Iterable) -> str: return f"[{', '.join(str(x) for x in shape)}]" - def emit_node(node : Node): - maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' + def emit_node(node: Node): + maybe_type_annotation = ( + "" if node.type is None else f" : {type_repr(node.type)}" + ) if verbose: # override annotation with more detailed information from torch.fx.experimental.proxy_tensor import py_sym_types from torch.fx.passes.shape_prop import TensorMetadata - meta_val = node.meta.get('val', node.meta.get('tensor_meta', node.meta.get('example_value', None))) + meta_val = node.meta.get( + "val", + node.meta.get("tensor_meta", node.meta.get("example_value", None)), + ) # use string as annotation, to make it valid python code if isinstance(meta_val, torch.Tensor): - stride_annotation = f"{stringify_shape(meta_val.stride())}" if include_stride else "" + stride_annotation = ( + f"{stringify_shape(meta_val.stride())}" + if include_stride + else "" + ) device_annotation = f"{meta_val.device}" if include_device else "" - maybe_type_annotation = \ - f': "{red(dtype_abbrs[meta_val.dtype])}{blue(stringify_shape(meta_val.shape))}' \ + maybe_type_annotation = ( + f': "{red(dtype_abbrs[meta_val.dtype])}{blue(stringify_shape(meta_val.shape))}' f'{dim_blue(stride_annotation)}{dim_green(device_annotation)}"' + ) elif isinstance(meta_val, py_sym_types): maybe_type_annotation = f': "Sym({meta_val})"' elif isinstance(meta_val, TensorMetadata): maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"' - if node.op == 'placeholder': + if node.op == "placeholder": assert isinstance(node.target, str) - maybe_default_arg = '' if not node.args else f' = {_get_repr(node.args[0])}' - free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') - raw_name = node.target.replace('*', '') + maybe_default_arg = ( + "" if not node.args else f" = {_get_repr(node.args[0])}" + ) + free_vars.append( + f"{node.target}{maybe_type_annotation}{maybe_default_arg}" + ) + raw_name = node.target.replace("*", "") if raw_name != repr(node): - body.append(f'{repr(node)} = {raw_name}\n') + body.append(f"{repr(node)} = {raw_name}\n") return - elif node.op == 'call_method': + elif node.op == "call_method": assert isinstance(node.target, str) body.append( - f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.target)}' - f'({_format_args(node.args[1:], node.kwargs)})') + f"{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.target)}" + f"({_format_args(node.args[1:], node.kwargs)})" + ) return - elif node.op == 'call_function': + elif node.op == "call_function": assert callable(node.target) # pretty print operators - if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in magic_methods: + if ( + getattr(node.target, "__module__", "") == "_operator" + and node.target.__name__ in magic_methods + ): assert isinstance(node.args, tuple) - 
body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{magic_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}" + ) return # pretty print inplace operators; required for jit.script to work properly # not currently supported in normal FX graphs, but generated by torchdynamo - if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in inplace_methods: - body.append(f'{inplace_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}; ' - f'{repr(node)}{maybe_type_annotation} = {_get_repr(node.args[0])}') + if ( + getattr(node.target, "__module__", "") == "_operator" + and node.target.__name__ in inplace_methods + ): + body.append( + f"{inplace_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {_get_repr(node.args[0])}" + ) return qualified_name = _get_qualified_name(node.target) global_name = add_global(qualified_name, node.target) # special case for getattr: node.args could be 2-argument or 3-argument # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if global_name == 'getattr' and \ - isinstance(node.args, tuple) and \ - isinstance(node.args[1], str) and \ - node.args[1].isidentifier() and \ - len(node.args) == 2: - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.args[1])}') + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.args[1])}" + ) return - body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') - if node.meta.get('is_wrapped', False): + body.append( + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + ) + if node.meta.get("is_wrapped", False): wrapped_fns.setdefault(global_name) return - elif node.op == 'call_module': + elif node.op == "call_module": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) return - elif node.op == 'get_attr': + elif node.op == "get_attr": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}" + ) return - elif node.op == 'output': + elif node.op == "output": if node.type is not None: maybe_return_annotation[0] = f" -> {type_repr(node.type)}" body.append(self.generate_output(node.args[0])) return - raise NotImplementedError(f'node: {node.op} {node.target}') + raise NotImplementedError(f"node: {node.op} {node.target}") for i, node in enumerate(nodes): # NOTE: emit_node does not emit a string with newline. It depends @@ -669,15 +756,13 @@ class CodeGen: # If the Graph has no non-placeholder nodes, no lines for the body # have been emitted. 
To continue to have valid Python code, emit a # single pass statement - body.append('pass\n') - - + body.append("pass\n") if len(wrapped_fns) > 0: - wrap_name = add_global('wrap', torch.fx.wrap) - wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + wrap_name = add_global("wrap", torch.fx.wrap) + wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns]) else: - wrap_stmts = '' + wrap_stmts = "" if self._body_transformer: body = self._body_transformer(body) @@ -689,10 +774,10 @@ class CodeGen: # remove counter and generate lineno to node index mapping lineno_map: Dict[int, Optional[int]] = {} - prologue_len = prologue.count('\n') + 1 + prologue_len = prologue.count("\n") + 1 new_lines: List[str] = [] cur_idx = None - for line in ''.join(body).split('\n'): + for line in "".join(body).split("\n"): counter = re.search(r"# COUNTER: (\d+)", line) if counter and counter.group(1) is not None: cur_idx = int(counter.group(1)) @@ -700,8 +785,8 @@ class CodeGen: lineno_map[len(new_lines) + prologue_len] = cur_idx new_lines.append(line) - code = "\n".join(new_lines).lstrip('\n') - code = '\n'.join(' ' + line for line in code.split('\n')) + code = "\n".join(new_lines).lstrip("\n") + code = "\n".join(" " + line for line in code.split("\n")) fn_code = f""" {wrap_stmts} @@ -754,25 +839,35 @@ class _PyTreeCodeGen(CodeGen): return super().gen_fn_def(free_vars, maybe_return_annotation) fn_args = self.pytree_info.orig_args - has_orig_self = (fn_args[0] == 'self') if len(fn_args) > 0 else False + has_orig_self = (fn_args[0] == "self") if len(fn_args) > 0 else False if has_orig_self: - free_vars.insert(0, 'self') + free_vars.insert(0, "self") fn_definition = super().gen_fn_def(fn_args[:], maybe_return_annotation) if len(free_vars) > 0: # pytree has placeholders in it # when kwargs is present, in_spec is tuple(args, kwargs) - has_args_kwargs_tuple = self.pytree_info.in_spec.type == tuple and \ - self.pytree_info.in_spec.num_children == 2 and \ - self.pytree_info.in_spec.children_specs[0].type == tuple and \ - self.pytree_info.in_spec.children_specs[1].type == dict - fn_kwargs = '{}' + has_args_kwargs_tuple = ( + self.pytree_info.in_spec.type == tuple + and self.pytree_info.in_spec.num_children == 2 + and self.pytree_info.in_spec.children_specs[0].type == tuple + and self.pytree_info.in_spec.children_specs[1].type == dict + ) + fn_kwargs = "{}" fn_signature = f"[{', '.join(fn_args)}], self._in_spec" if has_args_kwargs_tuple: count_args = self.pytree_info.in_spec.children_specs[0].num_children fn_args = self.pytree_info.orig_args[:count_args] - fn_kwargs = '{' + ', '.join(f"'{k}':{v}" for k, v in zip( - self.pytree_info.in_spec.children_specs[1].context, - self.pytree_info.orig_args[count_args:])) + '}' + fn_kwargs = ( + "{" + + ", ".join( + f"'{k}':{v}" + for k, v in zip( + self.pytree_info.in_spec.children_specs[1].context, + self.pytree_info.orig_args[count_args:], + ) + ) + + "}" + ) fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec" # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid. 
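Note: the codegen hunks above are formatting-only, and the path they touch can be spot-checked end to end through public APIs. A minimal sketch, not part of this diff, using `symbolic_trace` and `Graph.python_code` (the traced function `f` is illustrative)::

    import torch
    from torch.fx import symbolic_trace

    def f(x):
        # Traced into placeholder / call_function / call_method / output nodes.
        return (x + 1).relu()

    gm = symbolic_trace(f)
    # Graph.python_code drives CodeGen._gen_python_code; the returned
    # PythonCode object exposes the generated source as `.src`.
    code = gm.graph.python_code(root_module="self", verbose=True)
    print(code.src)
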
@@ -789,16 +884,20 @@ class _PyTreeCodeGen(CodeGen): def generate_output(self, output_args): if self.pytree_info and self.pytree_info.out_spec: - return f'return pytree.tree_unflatten({repr(output_args)}, self._out_spec)' + return f"return pytree.tree_unflatten({repr(output_args)}, self._out_spec)" else: return super().generate_output(output_args) + class _FindNodesLookupTable: """ Side table for the graph for the purpose of doing fast queries """ + def __init__(self): - self.table: Dict[Tuple[str, Optional[Target]], Dict[Node, None]] = defaultdict(dict) + self.table: Dict[Tuple[str, Optional[Target]], Dict[Node, None]] = defaultdict( + dict + ) def _key(self, node) -> Tuple[str, Optional[Target]]: return (node.op, node.target if node.op == "call_function" else None) @@ -812,7 +911,7 @@ class _FindNodesLookupTable: def remove(self, node: Node) -> None: self.table[self._key(node)].pop(node) - def find_nodes(self, *, op: str, target: Optional['Target'] = None): + def find_nodes(self, *, op: str, target: Optional["Target"] = None): if op == "call_function": assert target is not None return [*self.table[(op, target)].keys()] @@ -823,6 +922,7 @@ class _FindNodesLookupTable: # op is call_method, get_attr, call_module return [node for node in self.table[(op, None)].keys() if node.target == target] + @compatibility(is_backward_compatible=True) class Graph: """ @@ -838,6 +938,7 @@ class Graph: import torch import torch.fx + class MyModule(torch.nn.Module): def __init__(self): super().__init__() @@ -845,7 +946,10 @@ class Graph: self.linear = torch.nn.Linear(4, 5) def forward(self, x): - return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) + return torch.topk( + torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3 + ) + m = MyModule() gm = torch.fx.symbolic_trace(m) @@ -869,13 +973,17 @@ class Graph: """ @compatibility(is_backward_compatible=True) - def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None, - tracer_extras: Optional[Dict[str, Any]] = None): + def __init__( + self, + owning_module: Optional["GraphModule"] = None, + tracer_cls: Optional[Type["Tracer"]] = None, + tracer_extras: Optional[Dict[str, Any]] = None, + ): """ Construct an empty Graph. 
""" - self._root : Node = Node(self, '', 'root', '', (), {}) - self._used_names : Dict[str, int] = {} # base name -> number + self._root: Node = Node(self, "", "root", "", (), {}) + self._used_names: Dict[str, int] = {} # base name -> number self._insert = self._root.prepend self._len = 0 self._graph_namespace = _Namespace() @@ -883,7 +991,7 @@ class Graph: self._tracer_cls = tracer_cls self._tracer_extras = tracer_extras self._codegen = CodeGen() - self._co_fields : Dict[str, Any] = {} + self._co_fields: Dict[str, Any] = {} self._find_nodes_lookup_table = _FindNodesLookupTable() @property @@ -910,7 +1018,9 @@ class Graph: return _node_list(self) @compatibility(is_backward_compatible=False) - def find_nodes(self, *, op: str, target: Optional['Target'] = None, sort: bool = True): + def find_nodes( + self, *, op: str, target: Optional["Target"] = None, sort: bool = True + ): """ Allows for fast query of nodes @@ -934,7 +1044,9 @@ class Graph: return node_list @compatibility(is_backward_compatible=True) - def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node], return_output_node=False) -> 'Optional[Argument]': + def graph_copy( + self, g: "Graph", val_map: Dict[Node, Node], return_output_node=False + ) -> "Optional[Argument]": """ Copy all nodes from a given graph into ``self``. @@ -954,13 +1066,13 @@ class Graph: for node in g.nodes: if node in val_map: continue - if node.op == 'output': + if node.op == "output": rv = map_arg(node.args[0], lambda n: val_map[n]) return rv if not return_output_node else (rv, node) - val_map[node] = self.node_copy(node, lambda n : val_map[n]) + val_map[node] = self.node_copy(node, lambda n: val_map[n]) return None - def __deepcopy__(self, memo=None) -> 'Graph': + def __deepcopy__(self, memo=None) -> "Graph": """ Explicitly implement __deepcopy__ to prevent excessive recursion depth from the default implementation. This uses graph_copy to copy the nodes @@ -974,16 +1086,22 @@ class Graph: g._codegen = copy.deepcopy(self._codegen) assert isinstance(output_vals, tuple) output_val, old_output_node = output_vals - new_output_node = g.output(output_val, type_expr=getattr(old_output_node, 'type', None)) + new_output_node = g.output( + output_val, type_expr=getattr(old_output_node, "type", None) + ) new_output_node.meta = copy.copy(old_output_node.meta) return g @compatibility(is_backward_compatible=True) - def create_node(self, op: str, target: 'Target', - args: Optional[Tuple['Argument', ...]] = None, - kwargs: Optional[Dict[str, 'Argument']] = None, - name: Optional[str] = None, - type_expr: Optional[Any] = None) -> Node: + def create_node( + self, + op: str, + target: "Target", + args: Optional[Tuple["Argument", ...]] = None, + kwargs: Optional[Dict[str, "Argument"]] = None, + name: Optional[str] = None, + type_expr: Optional[Any] = None, + ) -> Node: """ Create a ``Node`` and add it to the ``Graph`` at the current insert-point. 
Note that the current insert-point can be set via :meth:`Graph.inserting_before` @@ -1019,7 +1137,10 @@ class Graph: name = self._graph_namespace.create_name(candidate, None) n = Node(self, name, op, target, args, kwargs, type_expr) - if self.owning_module is not None and getattr(self.owning_module, "_create_node_hooks", None) is not None: + if ( + self.owning_module is not None + and getattr(self.owning_module, "_create_node_hooks", None) is not None + ): for f in self.owning_module._create_node_hooks: f(n) @@ -1041,9 +1162,8 @@ class Graph: def process_outputs(self, out): return self._codegen.process_outputs(out) - @compatibility(is_backward_compatible=True) - def erase_node(self, to_erase : Node) -> None: + def erase_node(self, to_erase: Node) -> None: """ Erases a ``Node`` from the ``Graph``. Throws an exception if there are still users of that node in the ``Graph``. @@ -1053,15 +1173,20 @@ class Graph: to_erase (Node): The ``Node`` to erase from the ``Graph``. """ if len(to_erase.users) > 0: - raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} ' - f'users in the graph: {to_erase.users}!') + raise RuntimeError( + f"Tried to erase Node {to_erase} but it still had {len(to_erase.users)} " + f"users in the graph: {to_erase.users}!" + ) if to_erase.graph != self: raise RuntimeError(f"Attempting to remove {to_erase} from wrong graph!") if to_erase._erased: warnings.warn(f"erase_node({to_erase}) on an already erased node") return - if self.owning_module is not None and getattr(self.owning_module, "_erase_node_hooks", None) is not None: + if ( + self.owning_module is not None + and getattr(self.owning_module, "_erase_node_hooks", None) is not None + ): for f in self.owning_module._erase_node_hooks: f(to_erase) @@ -1086,9 +1211,9 @@ class Graph: then restore it when the with statement exits:: with g.inserting_before(n): - ... # inserting before node n - ... # insert point restored to what it was previously - g.inserting_before(n) # set the insert point permanently + ... # inserting before node n + ... # insert point restored to what it was previously + g.inserting_before(n) # set the insert point permanently Args: @@ -1110,9 +1235,9 @@ class Graph: then restore it when the with statement exits:: with g.inserting_after(n): - ... # inserting after node n - ... # insert point restored to what it was previously - g.inserting_after(n) # set the insert point permanently + ... # inserting after node n + ... # insert point restored to what it was previously + g.inserting_after(n) # set the insert point permanently Args: @@ -1128,8 +1253,12 @@ class Graph: return _InsertPoint(self, n.append) @compatibility(is_backward_compatible=True) - def placeholder(self, name: str, type_expr: Optional[Any] = None, - default_value : Any = inspect.Signature.empty) -> Node: + def placeholder( + self, + name: str, + type_expr: Optional[Any] = None, + default_value: Any = inspect.Signature.empty, + ) -> Node: """ Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents a function input. @@ -1154,7 +1283,7 @@ class Graph: as ``Graph.create_node``. 
""" args = () if default_value is inspect.Signature.empty else (default_value,) - return self.create_node('placeholder', name, args=args, type_expr=type_expr) + return self.create_node("placeholder", name, args=args, type_expr=type_expr) @compatibility(is_backward_compatible=True) def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node: @@ -1181,7 +1310,10 @@ class Graph: The same insertion point and type expression rules apply for this method as ``Graph.create_node``. """ - def _get_attr_reference_exists(mod: torch.nn.Module, qualified_name: str) -> bool: + + def _get_attr_reference_exists( + mod: torch.nn.Module, qualified_name: str + ) -> bool: module_path, _, name = qualified_name.rpartition(".") try: @@ -1195,32 +1327,40 @@ class Graph: res = getattr(submod, name) - if (not isinstance(res, torch.nn.Module) - and not isinstance(res, torch.nn.Parameter) - and name not in submod._buffers): + if ( + not isinstance(res, torch.nn.Module) + and not isinstance(res, torch.nn.Parameter) + and name not in submod._buffers + ): return False return True - if (self.owning_module and - not _get_attr_reference_exists(self.owning_module, qualified_name)): - warnings.warn("Attempted to insert a get_attr Node with no " - "underlying reference in the owning " - "GraphModule! Call " - "GraphModule.add_submodule to add the " - "necessary submodule, " - "GraphModule.add_parameter to add the " - "necessary Parameter, or " - "nn.Module.register_buffer to add the " - "necessary buffer", stacklevel=2) - return self.create_node('get_attr', qualified_name, type_expr=type_expr) + if self.owning_module and not _get_attr_reference_exists( + self.owning_module, qualified_name + ): + warnings.warn( + "Attempted to insert a get_attr Node with no " + "underlying reference in the owning " + "GraphModule! Call " + "GraphModule.add_submodule to add the " + "necessary submodule, " + "GraphModule.add_parameter to add the " + "necessary Parameter, or " + "nn.Module.register_buffer to add the " + "necessary buffer", + stacklevel=2, + ) + return self.create_node("get_attr", qualified_name, type_expr=type_expr) @compatibility(is_backward_compatible=True) - def call_module(self, - module_name: str, - args: Optional[Tuple['Argument', ...]] = None, - kwargs: Optional[Dict[str, 'Argument']] = None, - type_expr: Optional[Any] = None) -> Node: + def call_module( + self, + module_name: str, + args: Optional[Tuple["Argument", ...]] = None, + kwargs: Optional[Dict[str, "Argument"]] = None, + type_expr: Optional[Any] = None, + ) -> Node: """ Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node represents a call to the forward() function of a ``Module`` in the ``Module`` @@ -1251,21 +1391,26 @@ class Graph: The same insertion point and type expression rules apply for this method as :meth:`Graph.create_node`. """ - if (self.owning_module and - self.owning_module.get_submodule(module_name) is None): - warnings.warn("Attempted to insert a call_module Node with " - "no underlying reference in the owning " - "GraphModule! Call " - "GraphModule.add_submodule to add the " - "necessary submodule") - return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr) + if self.owning_module and self.owning_module.get_submodule(module_name) is None: + warnings.warn( + "Attempted to insert a call_module Node with " + "no underlying reference in the owning " + "GraphModule! 
Call " + "GraphModule.add_submodule to add the " + "necessary submodule" + ) + return self.create_node( + "call_module", module_name, args, kwargs, type_expr=type_expr + ) @compatibility(is_backward_compatible=True) - def call_method(self, - method_name: str, - args: Optional[Tuple['Argument', ...]] = None, - kwargs: Optional[Dict[str, 'Argument']] = None, - type_expr: Optional[Any] = None) -> Node: + def call_method( + self, + method_name: str, + args: Optional[Tuple["Argument", ...]] = None, + kwargs: Optional[Dict[str, "Argument"]] = None, + type_expr: Optional[Any] = None, + ) -> Node: """ Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node represents a call to a given method on the 0th element of ``args``. @@ -1293,14 +1438,18 @@ class Graph: The same insertion point and type expression rules apply for this method as :meth:`Graph.create_node`. """ - return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr) + return self.create_node( + "call_method", method_name, args, kwargs, type_expr=type_expr + ) @compatibility(is_backward_compatible=True) - def call_function(self, - the_function: Callable[..., Any], - args: Optional[Tuple['Argument', ...]] = None, - kwargs: Optional[Dict[str, 'Argument']] = None, - type_expr: Optional[Any] = None) -> Node: + def call_function( + self, + the_function: Callable[..., Any], + args: Optional[Tuple["Argument", ...]] = None, + kwargs: Optional[Dict[str, "Argument"]] = None, + type_expr: Optional[Any] = None, + ) -> Node: """ Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node represents a call to a Python callable, specified by ``the_function``. @@ -1328,20 +1477,24 @@ class Graph: The same insertion point and type expression rules apply for this method as :meth:`Graph.create_node`. """ - return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr) + return self.create_node( + "call_function", the_function, args, kwargs, type_expr=type_expr + ) @compatibility(is_backward_compatible=True) - def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = lambda x: x) -> Node: + def node_copy( + self, node: Node, arg_transform: Callable[[Node], "Argument"] = lambda x: x + ) -> Node: """ Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from the graph of node to the graph of self. Example:: # Copying all the nodes in `g` into `new_graph` - g : torch.fx.Graph = ... + g: torch.fx.Graph = ... new_graph = torch.fx.graph() value_remap = {} for node in g.nodes: - value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n]) + value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n]) Args: @@ -1357,12 +1510,14 @@ class Graph: kwargs = map_arg(node.kwargs, arg_transform) assert isinstance(args, tuple) assert isinstance(kwargs, dict) - result_node = self.create_node(node.op, node.target, args, kwargs, node.name, node.type) + result_node = self.create_node( + node.op, node.target, args, kwargs, node.name, node.type + ) result_node.meta = copy.copy(node.meta) return result_node @compatibility(is_backward_compatible=True) - def output(self, result: 'Argument', type_expr: Optional[Any] = None): + def output(self, result: "Argument", type_expr: Optional[Any] = None): """ Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents a ``return`` statement in Python code. 
``result`` is the value that should @@ -1380,9 +1535,11 @@ class Graph: The same insertion point and type expression rules apply for this method as ``Graph.create_node``. """ - return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr) + return self.create_node( + op="output", target="output", args=(result,), type_expr=type_expr + ) - def _target_to_str(self, target : Target) -> str: + def _target_to_str(self, target: Target) -> str: if callable(target): op = target.__name__ else: @@ -1395,8 +1552,13 @@ class Graph: @compatibility(is_backward_compatible=True) def python_code( - self, root_module: str, *, - verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False + self, + root_module: str, + *, + verbose: bool = False, + include_stride: bool = False, + include_device: bool = False, + colored: bool = False, ) -> PythonCode: """ Turn this ``Graph`` into valid Python code. @@ -1457,36 +1619,50 @@ class Graph: with override_node_repr(self): return self._python_code( - root_module, namespace, - verbose=verbose, include_stride=include_stride, include_device=include_device, colored=colored + root_module, + namespace, + verbose=verbose, + include_stride=include_stride, + include_device=include_device, + colored=colored, ) def _python_code( - self, root_module: str, namespace: _Namespace, *, - verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, + self, + root_module: str, + namespace: _Namespace, + *, + verbose: bool = False, + include_stride: bool = False, + include_device: bool = False, + colored: bool = False, ) -> PythonCode: return self._codegen._gen_python_code( - self.nodes, root_module, namespace, - verbose=verbose, include_stride=include_stride, include_device=include_device, colored=colored + self.nodes, + root_module, + namespace, + verbose=verbose, + include_stride=include_stride, + include_device=include_device, + colored=colored, ) - def __str__(self) -> str: """ Return a human-readable (not machine-readable) string representation of this Graph """ - placeholder_names : List[str] = [] + placeholder_names: List[str] = [] # This is a one-element array just so ``format_node`` can modify the closed # over value - maybe_return_typename : List[str] = [''] + maybe_return_typename: List[str] = [""] node_strs = [node.format_node(placeholder_names) for node in self.nodes] - param_str = ', '.join(placeholder_names) - s = f'graph({param_str}){maybe_return_typename[0]}:' + param_str = ", ".join(placeholder_names) + s = f"graph({param_str}){maybe_return_typename[0]}:" for node_str in node_strs: if node_str: - s += '\n ' + node_str + s += "\n " + node_str return s @compatibility(is_backward_compatible=True) @@ -1499,15 +1675,17 @@ class Graph: try: from tabulate import tabulate except ImportError: - print("`print_tabular` relies on the library `tabulate`, " - "which could not be found on this machine. Run `pip " - "install tabulate` to install the library.") + print( + "`print_tabular` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library." 
+ ) raise - node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] - for n in self.nodes] - print(tabulate(node_specs, - headers=['opcode', 'name', 'target', 'args', 'kwargs'])) + node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] for n in self.nodes] + print( + tabulate(node_specs, headers=["opcode", "name", "target", "args", "kwargs"]) + ) @compatibility(is_backward_compatible=True) def lint(self): @@ -1521,23 +1699,34 @@ class Graph: """ # Check topo order - def check_arg(arg : Node, n : Optional[Node] = None) -> None: - context_str = f' of Node \'{n}\' ' if n else ' ' + def check_arg(arg: Node, n: Optional[Node] = None) -> None: + context_str = f" of Node '{n}' " if n else " " if arg.graph is not self: - raise RuntimeError(f'Argument \'{arg}\'{context_str}does not belong to this Graph, ' - f'but was used as an argument! If you are copying nodes from another graph, make ' - f'sure to use ``arg_transform`` on node_copy() to remap values\n{self}') + raise RuntimeError( + f"Argument '{arg}'{context_str}does not belong to this Graph, " + f"but was used as an argument! If you are copying nodes from another graph, make " + f"sure to use ``arg_transform`` on node_copy() to remap values\n{self}" + ) if arg not in seen_values: - raise RuntimeError(f'Argument \'{arg}\'{context_str}was used before it has been ' - f'defined! Please check that Nodes in the graph are topologically ordered\n{self}') + raise RuntimeError( + f"Argument '{arg}'{context_str}was used before it has been " + f"defined! Please check that Nodes in the graph are topologically ordered\n{self}" + ) - seen_names : Set[str] = set() - seen_values : Set[Node] = set() + seen_names: Set[str] = set() + seen_values: Set[Node] = set() for node in self.nodes: - if node.op not in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']: - raise RuntimeError(f'Node {node} had unknown opcode {node.op}!') + if node.op not in [ + "placeholder", + "call_method", + "call_module", + "call_function", + "get_attr", + "output", + ]: + raise RuntimeError(f"Node {node} had unknown opcode {node.op}!") if node.graph is not self: - raise RuntimeError(f'Node \'{node}\' does not belong to this Graph!') + raise RuntimeError(f"Node '{node}' does not belong to this Graph!") if node not in self._find_nodes_lookup_table: raise RuntimeError(f"Node '{node}' is not added to the side table") map_arg(node.args, lambda arg: check_arg(arg, node)) @@ -1545,7 +1734,7 @@ class Graph: seen_values.add(node) if node.name in seen_names: - raise RuntimeError(f'Node redefined name {node.name}!') + raise RuntimeError(f"Node redefined name {node.name}!") seen_names.add(node.name) # Check targets are legit @@ -1553,49 +1742,64 @@ class Graph: num_warnings = 0 MAX_WARNINGS = 5 for node in self.nodes: - if node.op == 'call_function': + if node.op == "call_function": if not callable(node.target): - raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but ' - 'a Callable is expected') + raise ValueError( + f"Node {node} target {node.target} has type {torch.typename(node.target)} but " + "a Callable is expected" + ) else: if not isinstance(node.target, str): - raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but ' - 'a str is expected') - if node.op in ['get_attr', 'call_module']: - target_atoms = node.target.split('.') + raise ValueError( + f"Node {node} target {node.target} has type {torch.typename(node.target)} but " + "a str is expected" + ) + if node.op in 
["get_attr", "call_module"]: + target_atoms = node.target.split(".") m_itr = self.owning_module for i, atom in enumerate(target_atoms): new_m_itr = getattr(m_itr, atom, None) - seen_qualname = '.'.join(target_atoms[:i]) + seen_qualname = ".".join(target_atoms[:i]) if new_m_itr is None: - raise RuntimeError(f'Node {node} target {node.target} references nonexistent attribute ' - f'{atom} of {seen_qualname}') - if (node.op == "call_module" - and not isinstance(new_m_itr, torch.nn.Module)): - raise RuntimeError(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' - 'not reference an nn.Module') - elif (node.op == "get_attr" - and not isinstance(new_m_itr, torch.nn.Module) - and not isinstance(new_m_itr, torch.nn.Parameter) - and atom not in m_itr._buffers): + raise RuntimeError( + f"Node {node} target {node.target} references nonexistent attribute " + f"{atom} of {seen_qualname}" + ) + if node.op == "call_module" and not isinstance( + new_m_itr, torch.nn.Module + ): + raise RuntimeError( + f"Node {node} target {node.target} {atom} of {seen_qualname} does " + "not reference an nn.Module" + ) + elif ( + node.op == "get_attr" + and not isinstance(new_m_itr, torch.nn.Module) + and not isinstance(new_m_itr, torch.nn.Parameter) + and atom not in m_itr._buffers + ): if num_warnings < MAX_WARNINGS: # Don't emit this warning too frequently, # for very large graphs this can become very expensive # from a performance perspective. - warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' - 'not reference an nn.Module, nn.Parameter, or buffer, which is ' - 'what \'get_attr\' Nodes typically target') + warnings.warn( + f"Node {node} target {node.target} {atom} of {seen_qualname} does " + "not reference an nn.Module, nn.Parameter, or buffer, which is " + "what 'get_attr' Nodes typically target" + ) num_warnings += 1 else: m_itr = new_m_itr if num_warnings > MAX_WARNINGS: warnings.warn( - f'Additional {num_warnings - MAX_WARNINGS} warnings ' - 'suppressed about get_attr references' + f"Additional {num_warnings - MAX_WARNINGS} warnings " + "suppressed about get_attr references" ) @compatibility(is_backward_compatible=True) - def eliminate_dead_code(self, is_impure_node: Optional[Callable[[Node], bool]] = None): + def eliminate_dead_code( + self, is_impure_node: Optional[Callable[[Node], bool]] = None + ): """ Remove all dead code from the graph, based on each node's number of users, and whether the nodes have any side effects. The graph must be @@ -1664,7 +1868,7 @@ class Graph: @compatibility(is_backward_compatible=False) def on_generate_code( self, - make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc] + make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc], ): """Register a transformer function when python code is generated @@ -1690,6 +1894,7 @@ class Graph: gm: fx.GraphModule = ... + # This is a code transformer we want to register. 
This code # transformer prepends a pdb import and trace statement at the very # beginning of the generated torch.fx code to allow for manual @@ -1697,21 +1902,17 @@ class Graph: def insert_pdb(body): return ["import pdb; pdb.set_trace()\\n", *body] + # Registers `insert_pdb`, and overwrites the current registered # code transformer (given by `_` to the lambda): - gm.graph.on_generate_code( - lambda _: insert_pdb - ) + gm.graph.on_generate_code(lambda _: insert_pdb) # Or alternatively, registers a code transformer which first # runs `body` through existing registered transformer, then # through `insert_pdb`: gm.graph.on_generate_code( lambda current_trans: ( - lambda body: insert_pdb( - current_trans(body) if current_trans - else body - ) + lambda body: insert_pdb(current_trans(body) if current_trans else body) ) ) @@ -1749,47 +1950,51 @@ class Graph: reflectable_magic_methods = { - 'add': '{} + {}', - 'sub': '{} - {}', - 'mul': '{} * {}', - 'floordiv': '{} // {}', - 'truediv': '{} / {}', - 'div': '{} / {}', - 'mod': '{} % {}', - 'pow': '{} ** {}', - 'lshift': '{} << {}', - 'rshift': '{} >> {}', - 'and_': '{} & {}', - 'or_': '{} | {}', - 'xor': '{} ^ {}', - 'getitem': '{}[{}]', - 'matmul': '{} @ {}', + "add": "{} + {}", + "sub": "{} - {}", + "mul": "{} * {}", + "floordiv": "{} // {}", + "truediv": "{} / {}", + "div": "{} / {}", + "mod": "{} % {}", + "pow": "{} ** {}", + "lshift": "{} << {}", + "rshift": "{} >> {}", + "and_": "{} & {}", + "or_": "{} | {}", + "xor": "{} ^ {}", + "getitem": "{}[{}]", + "matmul": "{} @ {}", } -magic_methods = dict({ - 'eq': '{} == {}', - 'ne': '{} != {}', - 'lt': '{} < {}', - 'gt': '{} > {}', - 'le': '{} <= {}', - 'ge': '{} >= {}', - 'pos': '+{}', - 'neg': '-{}', - 'invert': '~{}'}, **reflectable_magic_methods) +magic_methods = dict( + { + "eq": "{} == {}", + "ne": "{} != {}", + "lt": "{} < {}", + "gt": "{} > {}", + "le": "{} <= {}", + "ge": "{} >= {}", + "pos": "+{}", + "neg": "-{}", + "invert": "~{}", + }, + **reflectable_magic_methods, +) inplace_methods = { - 'iadd': '{} += {}', - 'iand': '{} &= {}', - 'ifloordiv': '{} //= {}', - 'ilshift': '{} <<= {}', - 'imod': '{} %= {}', - 'imul': '{} *= {}', - 'imatmul': '{} @= {}', - 'ior': '{} |= {}', - 'ipow': '{} **= {}', - 'irshift': '{} >>= {}', - 'isub': '{} -= {}', - 'itruediv': '{} /= {}', - 'ixor': '{} ^= {}', - 'setitem': '{}[{}] = {}', + "iadd": "{} += {}", + "iand": "{} &= {}", + "ifloordiv": "{} //= {}", + "ilshift": "{} <<= {}", + "imod": "{} %= {}", + "imul": "{} *= {}", + "imatmul": "{} @= {}", + "ior": "{} |= {}", + "ipow": "{} **= {}", + "irshift": "{} >>= {}", + "isub": "{} -= {}", + "itruediv": "{} /= {}", + "ixor": "{} ^= {}", + "setitem": "{}[{}] = {}", } diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 2328541511fd..e2da57696177 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -19,6 +19,7 @@ from torch.package import Importer, PackageExporter, PackageImporter, sys_import from ._compatibility import compatibility from .graph import _custom_builtins, _is_from_torch, _PyTreeCodeGen, Graph, PythonCode + __all__ = [ "reduce_graph_module", "reduce_package_graph_module", @@ -386,11 +387,9 @@ class _WrappedCall: return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] except Exception as e: assert e.__traceback__ - topmost_framesummary: ( - traceback.FrameSummary - ) = traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[ - -1 - ] # type: ignore[arg-type] + topmost_framesummary: traceback.FrameSummary = ( + 
traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] + ) if "eval_with_key" in topmost_framesummary.filename: print( _WrappedCall._generate_error_message(topmost_framesummary), @@ -612,20 +611,20 @@ class {module_name}(torch.nn.Module): module_str = ( f"torch.load(r'{module_file}', weights_only=False) # {module_repr}" ) - model_str += f"{tab*2}self.{module_name} = {module_str}\n" + model_str += f"{tab * 2}self.{module_name} = {module_str}\n" for buffer_name, buffer in self._buffers.items(): if buffer is None: continue - model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" + model_str += f"{tab * 2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" # noqa: B950 for param_name, param in self._parameters.items(): if param is None: continue - model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" + model_str += f"{tab * 2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" # noqa: B950 model_str += ( - f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" + f"{tab * 2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" ) model_str += f"{_addindent(self.code, 4)}\n" @@ -667,7 +666,6 @@ class {module_name}(torch.nn.Module): mod: torch.nn.Module = self for item in prefix: - submod = getattr(mod, item, None) if submod is None: @@ -707,7 +705,6 @@ class {module_name}(torch.nn.Module): # Get the parent module for item in path: - if not hasattr(mod, item): return False @@ -743,9 +740,7 @@ class {module_name}(torch.nn.Module): used: List[str] = [] for node in self.graph.nodes: - if node.op == "call_module" or node.op == "get_attr": - # A list of strings representing the different parts # of the path. For example, `foo.bar.baz` gives us # ["foo", "bar", "baz"] diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index c75407583137..12a2070b586f 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -1,20 +1,24 @@ # mypy: allow-untyped-defs -from .graph_module import GraphModule -from ._lazy_graph_module import _make_graph_module -from .graph import Graph -from .node import Argument, Node, Target, map_arg, map_aggregate -from .proxy import Proxy -from ._symbolic_trace import Tracer -from ._compatibility import compatibility -from . import config -import torch.fx.traceback as fx_traceback -import torch -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import inspect from contextlib import contextmanager +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union + +import torch +import torch.fx.traceback as fx_traceback from torch.hub import tqdm -__all__ = ['Interpreter', 'Transformer'] +from . import config +from ._compatibility import compatibility +from ._lazy_graph_module import _make_graph_module +from ._symbolic_trace import Tracer +from .graph import Graph +from .graph_module import GraphModule +from .node import Argument, map_aggregate, map_arg, Node, Target +from .proxy import Proxy + + +__all__ = ["Interpreter", "Transformer"] + @compatibility(is_backward_compatible=True) class Interpreter: @@ -43,22 +47,22 @@ class Interpreter: method equivalents). 
We could subclass Interpreter like so:: class NegSigmSwapInterpreter(Interpreter): - def call_function(self, target : Target, - args : Tuple, kwargs : Dict) -> Any: + def call_function(self, target: Target, args: Tuple, kwargs: Dict) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(n) - def call_method(self, target : Target, - args : Tuple, kwargs : Dict) -> Any: - if target == 'neg': + def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any: + if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(n) + def fn(x): return torch.sigmoid(x).neg() + gm = torch.fx.symbolic_trace(fn) input = torch.randn(3, 4) result = NegSigmSwapInterpreter(gm).run(input) @@ -74,15 +78,21 @@ class Interpreter: graph instead of `module.graph`, using the provided `module` argument to satisfy any requests for state. """ + @compatibility(is_backward_compatible=True) - def __init__(self, module: torch.nn.Module, garbage_collect_values: bool = True, graph: Optional[Graph] = None): + def __init__( + self, + module: torch.nn.Module, + garbage_collect_values: bool = True, + graph: Optional[Graph] = None, + ): self.module = module self.submodules = dict(self.module.named_modules()) if graph is not None: self.graph = graph else: self.graph = self.module.graph - self.env : Dict[Node, Any] = {} + self.env: Dict[Node, Any] = {} self.name = "Interpreter" self.garbage_collect_values = garbage_collect_values self.extra_traceback = True @@ -92,10 +102,10 @@ class Interpreter: # of a given node. This represents the *last* use of the node in the # execution order of the program, which we will use to free unused # values - node_to_last_use : Dict[Node, Node] = {} - self.user_to_last_uses : Dict[Node, List[Node]] = {} + node_to_last_use: Dict[Node, Node] = {} + self.user_to_last_uses: Dict[Node, List[Node]] = {} - def register_last_uses(n : Node, user : Node): + def register_last_uses(n: Node, user: Node): if n not in node_to_last_use: node_to_last_use[n] = user self.user_to_last_uses.setdefault(user, []).append(n) @@ -105,7 +115,12 @@ class Interpreter: map_arg(node.kwargs, lambda n: register_last_uses(n, node)) @compatibility(is_backward_compatible=True) - def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any: + def run( + self, + *args, + initial_env: Optional[Dict[Node, Any]] = None, + enable_io_processing: bool = True, + ) -> Any: """ Run `module` via interpretation and return the result. @@ -128,10 +143,16 @@ class Interpreter: # position and extract those values. 
if enable_io_processing: args = self.graph.process_inputs(*args) - self.args_iter : Iterator[Any] = iter(args) - pbar = tqdm(total=len(self.graph.nodes), - desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}", - initial=0, position=0, leave=True, disable=config.disable_progress, delay=0) + self.args_iter: Iterator[Any] = iter(args) + pbar = tqdm( + total=len(self.graph.nodes), + desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}", + initial=0, + position=0, + leave=True, + disable=config.disable_progress, + delay=0, + ) for node in self.graph.nodes: pbar.update(1) @@ -147,7 +168,7 @@ class Interpreter: except Exception as e: if self.extra_traceback: msg = f"While executing {node.format_node()}" - msg = f'{e.args[0]}\n\n{msg}' if e.args else str(msg) + msg = f"{e.args[0]}\n\n{msg}" if e.args else str(msg) msg += f"\nOriginal traceback:\n{node.stack_trace}" e.args = (msg,) + e.args[1:] if isinstance(e, KeyError): @@ -158,9 +179,13 @@ class Interpreter: for to_delete in self.user_to_last_uses.get(node, []): del self.env[to_delete] - if node.op == 'output': + if node.op == "output": output_val = self.env[node] - return self.graph.process_outputs(output_val) if enable_io_processing else output_val + return ( + self.graph.process_outputs(output_val) + if enable_io_processing + else output_val + ) @compatibility(is_backward_compatible=True) def boxed_run(self, args_list): @@ -183,7 +208,7 @@ class Interpreter: yield @compatibility(is_backward_compatible=True) - def run_node(self, n : Node) -> Any: + def run_node(self, n: Node) -> Any: """ Run a specific node ``n`` and return the result. Calls into placeholder, get_attr, call_function, @@ -204,7 +229,9 @@ class Interpreter: # Main Node running APIs @compatibility(is_backward_compatible=True) - def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + def placeholder( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: """ Execute a ``placeholder`` node. Note that this is stateful: ``Interpreter`` maintains an internal iterator over @@ -222,7 +249,7 @@ class Interpreter: Any: The argument value that was retrieved. """ assert isinstance(target, str) - if target.startswith('*'): + if target.startswith("*"): # For a starred parameter e.g. `*args`, retrieve all # remaining values from the args list. return list(self.args_iter) @@ -233,10 +260,14 @@ class Interpreter: if len(args) > 0: return args[0] else: - raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si + raise RuntimeError( + f"Expected positional argument for parameter {target}, but one was not passed in!" + ) from si @compatibility(is_backward_compatible=True) - def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + def get_attr( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: """ Execute a ``get_attr`` node. Will retrieve an attribute value from the ``Module`` hierarchy of ``self.module``. @@ -255,7 +286,9 @@ class Interpreter: return self.fetch_attr(target) @compatibility(is_backward_compatible=True) - def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + def call_function( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: """ Execute a ``call_function`` node and return the result. 
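Note: the per-opcode handlers being reformatted here (`placeholder`, `get_attr`, `call_function`, ...) are all dispatched from `run_node`, which is the usual extension point. A minimal sketch assuming only the public `Interpreter` API; the subclass name is illustrative::

    import torch
    from torch.fx import Interpreter, symbolic_trace

    class ShapeLogger(Interpreter):
        def run_node(self, n):
            # run_node dispatches to the per-opcode handlers, so wrapping
            # it observes every intermediate value in execution order.
            result = super().run_node(n)
            if isinstance(result, torch.Tensor):
                print(n.format_node(), "->", tuple(result.shape))
            return result

    gm = symbolic_trace(
        torch.nn.Sequential(torch.nn.Linear(4, 5), torch.nn.ReLU())
    )
    ShapeLogger(gm).run(torch.randn(2, 4))
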
@@ -275,7 +308,9 @@ class Interpreter: return target(*args, **kwargs) @compatibility(is_backward_compatible=True) - def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + def call_method( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: """ Execute a ``call_method`` node and return the result. @@ -297,7 +332,9 @@ class Interpreter: return getattr(self_obj, target)(*args_tail, **kwargs) @compatibility(is_backward_compatible=True) - def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + def call_module( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: """ Execute a ``call_module`` node and return the result. @@ -320,7 +357,9 @@ class Interpreter: return submod(*args, **kwargs) @compatibility(is_backward_compatible=True) - def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + def output( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: """ Execute an ``output`` node. This really just retrieves the value referenced by the ``output`` node and returns it. @@ -339,7 +378,7 @@ class Interpreter: # Helper methods @compatibility(is_backward_compatible=True) - def fetch_attr(self, target : str): + def fetch_attr(self, target: str): """ Fetch an attribute from the ``Module`` hierarchy of ``self.module``. @@ -349,16 +388,18 @@ class Interpreter: Return: Any: The value of the attribute. """ - target_atoms = target.split('.') + target_atoms = target.split(".") attr_itr = self.module for i, atom in enumerate(target_atoms): if not hasattr(attr_itr, atom): - raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i+1])}") + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}" + ) attr_itr = getattr(attr_itr, atom) return attr_itr @compatibility(is_backward_compatible=True) - def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]: + def fetch_args_kwargs_from_env(self, n: Node) -> Tuple[Tuple, Dict]: """ Fetch the concrete values of ``args`` and ``kwargs`` of node ``n`` from the current execution environment. @@ -376,7 +417,7 @@ class Interpreter: return args, kwargs @compatibility(is_backward_compatible=True) - def map_nodes_to_values(self, args : Argument, n : Node) -> Argument: + def map_nodes_to_values(self, args: Argument, n: Node) -> Argument: """ Recursively descend through ``args`` and look up the concrete value for each ``Node`` in the current execution environment. @@ -386,13 +427,18 @@ class Interpreter: n (Node): Node to which ``args`` belongs. This is only used for error reporting. """ - def load_arg(n_arg : Node) -> Any: + + def load_arg(n_arg: Node) -> Any: if n_arg not in self.env: - raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() ' - f'to diagnose such issues') + raise RuntimeError( + f"Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() " + f"to diagnose such issues" + ) return self.env[n_arg] + return map_arg(args, load_arg) + @compatibility(is_backward_compatible=True) class Transformer(Interpreter): """ @@ -409,23 +455,29 @@ class Transformer(Interpreter): method equivalents). 
We could subclass ``Transformer`` like so:: class NegSigmSwapXformer(Transformer): - def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + def call_function( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(n) - def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - if target == 'neg': + def call_method( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: + if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(n) + def fn(x): return torch.sigmoid(x).neg() + gm = torch.fx.symbolic_trace(fn) - transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform() + transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform() input = torch.randn(3, 4) torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid()) @@ -452,7 +504,9 @@ class Transformer(Interpreter): self.tracer.root = module @compatibility(is_backward_compatible=True) - def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: + def placeholder( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Proxy: """ Execute a ``placeholder`` node. In ``Transformer``, this is overridden to insert a new ``placeholder`` into the output @@ -467,10 +521,14 @@ class Transformer(Interpreter): """ assert isinstance(target, str) default_value = next(iter(args)) if args else inspect.Signature.empty - return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer) + return Proxy( + self.new_graph.placeholder(target, default_value=default_value), self.tracer + ) @compatibility(is_backward_compatible=True) - def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: + def get_attr( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Proxy: """ Execute a ``get_attr`` node. In ``Transformer``, this is overridden to insert a new ``get_attr`` node into the output @@ -487,16 +545,20 @@ class Transformer(Interpreter): return self.tracer.create_proxy("get_attr", target, args, kwargs) @compatibility(is_backward_compatible=True) - def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + def call_module( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: # Override so that the leaf module policy from `self.tracer` is respected. assert isinstance(target, str) submod = self.fetch_attr(target) return self.tracer.call_module(submod, submod.forward, args, kwargs) @compatibility(is_backward_compatible=True) - def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + def call_function( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: # Override so that functions that were wrapped are still wrapped. 
- return self.tracer.create_proxy('call_function', target, args, kwargs) + return self.tracer.create_proxy("call_function", target, args, kwargs) @compatibility(is_backward_compatible=True) def transform(self) -> GraphModule: @@ -507,8 +569,10 @@ class Transformer(Interpreter): with fx_traceback.preserve_node_meta(): result = super().run(enable_io_processing=False) if result is not None: - def strip_proxy(a : Union[Argument, Proxy]) -> Any: + + def strip_proxy(a: Union[Argument, Proxy]) -> Any: return a.node if isinstance(a, Proxy) else a + new_output_node = self.new_graph.output(map_aggregate(result, strip_proxy)) # also preserve the metadata from the old output node, if it exists old_output_node = list(self.graph.nodes)[-1] @@ -516,5 +580,4 @@ class Transformer(Interpreter): for k, v in old_output_node.meta.items(): new_output_node.meta[k] = v - return _make_graph_module(self.module, self.new_graph) diff --git a/torch/fx/node.py b/torch/fx/node.py index 8c3461cbe23c..469b63403848 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -1,39 +1,71 @@ # Nodes represent a definition of a value in our graph of operators. -from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set +import builtins +import inspect +import types +import warnings +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union + +import torch +from torch._C import _NodeBase +from torch.fx.operator_schemas import ( + ArgsKwargsPair, + normalize_function, + normalize_module, +) + +from .._ops import ops as _ops from ._compatibility import compatibility from .immutable_collections import immutable_dict, immutable_list -import torch -import builtins -import types -import inspect -import warnings -from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair -from .._ops import ops as _ops -from torch._C import _NodeBase + if TYPE_CHECKING: from .graph import Graph -__all__ = ['Node', 'map_arg', 'map_aggregate', "has_side_effect"] +__all__ = ["Node", "map_arg", "map_aggregate", "has_side_effect"] -BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype, - torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, - torch.SymInt, torch.SymBool, torch.SymFloat] +BaseArgumentTypes = Union[ + str, + int, + float, + bool, + complex, + torch.dtype, + torch.Tensor, + torch.device, + torch.memory_format, + torch.layout, + torch._ops.OpOverload, + torch.SymInt, + torch.SymBool, + torch.SymFloat, +] base_types = BaseArgumentTypes.__args__ # type: ignore[attr-defined] Target = Union[Callable[..., Any], str] -Argument = Optional[Union[ - Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types - List[Any], # actually Argument - Dict[str, Any], # actually Argument - slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing - range, - 'Node', - BaseArgumentTypes -]] +Argument = Optional[ + Union[ + Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types + List[Any], # actually Argument + Dict[str, Any], # actually Argument + slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing + range, + "Node", + BaseArgumentTypes, + ] +] -_legal_ops = dict.fromkeys(['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root']) +_legal_ops = dict.fromkeys( + [ + "placeholder", + "call_method", + "call_module", + "call_function", + "get_attr", + "output", + "root", + ] +) 
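Note: the `_legal_ops` keys above are the complete opcode vocabulary a `Node` may carry. A short sketch, not part of this diff, of inspecting them on a traced function (the function `f` is illustrative)::

    import operator
    import torch
    from torch.fx import symbolic_trace

    def f(x, y):
        return x + y * 2

    gm = symbolic_trace(f)
    for node in gm.graph.nodes:
        # node.op is always one of the _legal_ops keys; for `x + y` the
        # call_function target is operator.add.
        print(node.op, node.target)
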
_side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = { torch._C._set_grad_enabled, @@ -74,7 +106,8 @@ def _find_module_of_method(orig_method: Callable[..., Any]) -> str: for guess in [torch, torch.nn.functional]: if getattr(guess, name, None) is orig_method: return guess.__name__ - raise RuntimeError(f'cannot find module for {orig_method}') + raise RuntimeError(f"cannot find module for {orig_method}") + # Borrowed from CPython typing module # https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156 @@ -86,22 +119,24 @@ def _type_repr(obj: object) -> str: else, we fall back on repr(obj). """ if isinstance(obj, type): - if obj.__module__ == 'builtins': + if obj.__module__ == "builtins": return obj.__qualname__ - return f'{obj.__module__}.{obj.__qualname__}' + return f"{obj.__module__}.{obj.__qualname__}" if obj is ...: - return '...' + return "..." if isinstance(obj, types.FunctionType): return obj.__name__ return repr(obj) + def _get_qualified_name(func: Callable[..., Any]) -> str: # things like getattr just appear in builtins if getattr(builtins, func.__name__, None) is func: return func.__name__ # torch.Tensor.{fn} - if (isinstance(func, (types.MethodDescriptorType, types.WrapperDescriptorType)) - and func is getattr(torch.Tensor, func.__name__, None)): + if isinstance( + func, (types.MethodDescriptorType, types.WrapperDescriptorType) + ) and func is getattr(torch.Tensor, func.__name__, None): return f"torch.Tensor.{func.__name__}" name = func.__name__ if name == "<lambda>": @@ -111,33 +146,45 @@ def _get_qualified_name(func: Callable[..., Any]) -> str: except Exception as e: raise RuntimeError("Unable to represent lambda") from e module = _find_module_of_method(func) - module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module + module = module.replace( + "torch._ops", "torch.ops" + ) # WAR for bug in how torch.ops assigns module # Fixup segment_reduce mismatch if module == "torch" and name == "segment_reduce": name = "_" + name - return f'{module}.{name}' + return f"{module}.{name}" -def _format_arg(arg: object, max_list_len: float = float('inf')) -> str: - if hasattr(arg, '_custom_fx_repr_fn'): + +def _format_arg(arg: object, max_list_len: float = float("inf")) -> str: + if hasattr(arg, "_custom_fx_repr_fn"): return arg._custom_fx_repr_fn() elif isinstance(arg, list): - items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len) - maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]' - return f'[{items}{maybe_len}]' + items = ", ".join( + _format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len + ) + maybe_len = ( + "" if len(arg) < max_list_len + 1 else f", ...[total_len={len(arg)}]" + ) + return f"[{items}{maybe_len}]" elif isinstance(arg, tuple): - items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len) - maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]' - maybe_comma = ',' if len(arg) == 1 else '' - return f'({items}{maybe_comma}{maybe_len})' + items = ", ".join( + _format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len + ) + maybe_len = ( + "" if len(arg) < max_list_len + 1 else f", ...[total_len={len(arg)}]" + ) + maybe_comma = "," if len(arg) == 1 else "" + return f"({items}{maybe_comma}{maybe_len})" elif isinstance(arg, dict): - items_str = ', '.join(f'{k}: {_format_arg(v)}' for k, v in arg.items()) - return f'{{{items_str}}}' + items_str = ", ".join(f"{k}: 
{_format_arg(v)}" for k, v in arg.items()) + return f"{{{items_str}}}" if isinstance(arg, Node): - return '%' + str(arg) + return "%" + str(arg) else: return str(arg) + @compatibility(is_backward_compatible=True) class Node(_NodeBase): """ @@ -166,23 +213,31 @@ class Node(_NodeBase): - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement in the Graph printout. """ - _args: Tuple['Argument', ...] - _kwargs: Dict[str, 'Argument'] - graph: 'Graph' + + _args: Tuple["Argument", ...] + _kwargs: Dict[str, "Argument"] + graph: "Graph" name: str op: str - target: 'Target' - _input_nodes: Dict['Node', None] - users: Dict['Node', None] + target: "Target" + _input_nodes: Dict["Node", None] + users: Dict["Node", None] type: Optional[Any] _sort_key: Any - _repr_fn: Optional[Callable[['Node'], str]] + _repr_fn: Optional[Callable[["Node"], str]] meta: Dict[str, Any] @compatibility(is_backward_compatible=True) - def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', - args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'], - return_type : Optional[Any] = None) -> None: + def __init__( + self, + graph: "Graph", + name: str, + op: str, + target: "Target", + args: Tuple["Argument", ...], + kwargs: Dict[str, "Argument"], + return_type: Optional[Any] = None, + ) -> None: """ Instantiate an instance of ``Node``. Note: most often, you want to use the Graph APIs, i.e. ``Graph.call_module``, ``Graph.call_method``, etc. rather @@ -210,14 +265,18 @@ class Node(_NodeBase): of analyses. """ assert op in _legal_ops - if op == 'call_function': + if op == "call_function": if not callable(target): - raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' - 'but a Callable is expected') + raise ValueError( + f"Node [graph = {graph}, name = '{name}'] target {target} has type {torch.typename(target)} " + "but a Callable is expected" + ) else: if not isinstance(target, str): - raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' - 'but a str is expected') + raise ValueError( + f"Node [graph = {graph}, name = '{name}'] target {target} has type {torch.typename(target)} " + "but a str is expected" + ) super().__init__() # bypass Node.__setattr__ for perf and so that it doesn't need to handle half-built objects @@ -225,9 +284,13 @@ class Node(_NodeBase): assign(self, "graph", graph) assign(self, "name", name) # unique name of value being created - assign(self, "op", op) # the kind of operation = placeholder|call_method|call_module|call_function|get_attr + assign( + self, "op", op + ) # the kind of operation = placeholder|call_method|call_module|call_function|get_attr - assign(self, "target", target) # for method/module/function, the name of the method/module/function/attr + assign( + self, "target", target + ) # for method/module/function, the name of the method/module/function/attr # being invoked, e.g add, layer1, or torch.add # All `Node`-valued inputs. Key is the Node, value is don't-care. @@ -280,7 +343,7 @@ class Node(_NodeBase): self._next = _next @property - def next(self) -> 'Node': + def next(self) -> "Node": """ Returns the next ``Node`` in the linked list of Nodes. @@ -291,7 +354,7 @@ class Node(_NodeBase): return self._next @property - def prev(self) -> 'Node': + def prev(self) -> "Node": """ Returns the previous ``Node`` in the linked list of Nodes. 
@@ -302,7 +365,7 @@ class Node(_NodeBase): return self._prev @compatibility(is_backward_compatible=True) - def prepend(self, x: 'Node') -> None: + def prepend(self, x: "Node") -> None: """ Insert x before this node in the list of nodes in the graph. Example:: @@ -316,7 +379,9 @@ """ assert self.graph == x.graph, "Attempting to move a Node into a different Graph" if self == x: - warnings.warn("Trying to prepend a node to itself. This behavior has no effect on the graph.") + warnings.warn( + "Trying to prepend a node to itself. This behavior has no effect on the graph." + ) return x._remove_from_list() p = self._prev @@ -328,28 +393,28 @@ nsk = x._next._sort_key if len(psk) > len(nsk): idx: int - *prefix, idx = psk[:len(nsk) + 1] + *prefix, idx = psk[: len(nsk) + 1] x._sort_key = (*prefix, idx + 1) elif len(psk) < len(nsk): - *prefix, idx = nsk[:len(psk) + 1] + *prefix, idx = nsk[: len(psk) + 1] x._sort_key = (*prefix, idx - 1) else: # same length, increase length by 1 x._sort_key = (*psk, 0) - def __gt__(self, other: 'Node') -> bool: + def __gt__(self, other: "Node") -> bool: return self._sort_key > other._sort_key - def __lt__(self, other: 'Node') -> bool: + def __lt__(self, other: "Node") -> bool: return self._sort_key < other._sort_key - def __ge__(self, other: 'Node') -> bool: + def __ge__(self, other: "Node") -> bool: return self > other or self == other - def __le__(self, other: 'Node') -> bool: + def __le__(self, other: "Node") -> bool: return self < other or self == other @compatibility(is_backward_compatible=True) - def append(self, x: 'Node') -> None: + def append(self, x: "Node") -> None: """ Insert ``x`` after this node in the list of nodes in the graph. Equivalent to ``self.next.prepend(x)`` @@ -376,7 +441,7 @@ return self._args @args.setter - def args(self, a : Tuple[Argument, ...]) -> None: + def args(self, a: Tuple[Argument, ...]) -> None: """ Set the tuple of arguments to this Node. The interpretation of arguments depends on the node's opcode. See the ``fx.Graph`` docstring for more @@ -399,7 +464,7 @@ return self._kwargs @kwargs.setter - def kwargs(self, k : Dict[str, Argument]) -> None: + def kwargs(self, k: Dict[str, Argument]) -> None: """ Set the dict of kwargs to this Node. The interpretation of arguments depends on the node's opcode. See the ``fx.Graph`` docstring for more @@ -410,7 +475,7 @@ self.__update_args_kwargs(self._args, k) @property - def all_input_nodes(self) -> List['Node']: + def all_input_nodes(self) -> List["Node"]: """ Return all Nodes that are inputs to this Node. This is equivalent to iterating over ``args`` and ``kwargs`` and only collecting the values that @@ -424,7 +489,7 @@ return list(self._input_nodes.keys()) @compatibility(is_backward_compatible=True) - def update_arg(self, idx : int, arg : Argument) -> None: + def update_arg(self, idx: int, arg: Argument) -> None: """ Update an existing positional argument to contain the new value ``arg``. After calling, ``self.args[idx] == arg``. @@ -439,7 +504,7 @@ self.args = tuple(args) @compatibility(is_backward_compatible=True) - def insert_arg(self, idx : int, arg : Argument) -> None: + def insert_arg(self, idx: int, arg: Argument) -> None: """ Insert a positional argument into the argument list at the given index. @@ -448,7 +513,9 @@ idx (int): The index of the element in ``self.args`` to be inserted before. 
arg (Argument): The new argument value to insert into ``args`` """ - assert 0 <= idx <= len(self.args), "insert_args index must be between 0 and len(self.args)" + assert ( + 0 <= idx <= len(self.args) + ), "insert_args index must be between 0 and len(self.args)" args_left = self.args[:idx] args_right = self.args[idx:] @@ -463,7 +530,7 @@ class Node(_NodeBase): new_use.users.setdefault(self) @compatibility(is_backward_compatible=True) - def update_kwarg(self, key : str, arg : Argument) -> None: + def update_kwarg(self, key: str, arg: Argument) -> None: """ Update an existing keyword argument to contain the new value ``arg``. After calling, ``self.kwargs[key] == arg``. @@ -490,13 +557,16 @@ class Node(_NodeBase): return self.meta.get("stack_trace", None) @stack_trace.setter - def stack_trace(self, trace : Optional[str]) -> None: + def stack_trace(self, trace: Optional[str]) -> None: self.meta["stack_trace"] = trace - def __update_args_kwargs(self, new_args : Tuple['Argument', ...], new_kwargs : Dict[str, 'Argument']) -> None: + def __update_args_kwargs( + self, new_args: Tuple["Argument", ...], new_kwargs: Dict[str, "Argument"] + ) -> None: """ This API is internal. Do *not* call it directly. """ + def update_users_and_input_nodes(n: Any) -> Any: if isinstance(n, Node): self._input_nodes.setdefault(n) @@ -512,8 +582,12 @@ class Node(_NodeBase): # - Normalize list->immutable_list, dict->immutable_dict, etc # - Populate self._input_nodes # - Populate arg.users[self] for each arg - object.__setattr__(self, "_args", map_aggregate(new_args, update_users_and_input_nodes)) - object.__setattr__(self, "_kwargs", map_aggregate(new_kwargs, update_users_and_input_nodes)) + object.__setattr__( + self, "_args", map_aggregate(new_args, update_users_and_input_nodes) + ) + object.__setattr__( + self, "_kwargs", map_aggregate(new_kwargs, update_users_and_input_nodes) + ) def __repr__(self) -> str: if self._repr_fn: @@ -529,8 +603,8 @@ class Node(_NodeBase): """ if isinstance(target, str): return target - if hasattr(target, '__module__'): - name = getattr(target, '__name__', None) + if hasattr(target, "__module__"): + name = getattr(target, "__name__", None) if name is None: # Just to be defensive, if we don't have `__name__`, get the # qualname. Not sure if this happens for any members of `operator` @@ -538,16 +612,18 @@ class Node(_NodeBase): # things in `operator` have `_operator` as their __module__. # TODO: THIS IS BROKEN: _get_qualified_name calls `__name__` return _get_qualified_name(target) # type: ignore[arg-type] - if target.__module__ == 'builtins': - return f'builtins.{name}' - elif target.__module__ == '_operator': - return f'operator.{name}' + if target.__module__ == "builtins": + return f"builtins.{name}" + elif target.__module__ == "_operator": + return f"operator.{name}" return _get_qualified_name(target) # type: ignore[arg-type] @compatibility(is_backward_compatible=True) - def format_node(self, - placeholder_names: Optional[List[str]] = None, - maybe_return_typename: Optional[List[str]] = None) -> Optional[str]: + def format_node( + self, + placeholder_names: Optional[List[str]] = None, + maybe_return_typename: Optional[List[str]] = None, + ) -> Optional[str]: """ Return a descriptive string representation of ``self``. @@ -576,37 +652,46 @@ class Node(_NodeBase): return a descriptive string representation of the current Node. 
""" - if self.op == 'placeholder': + if self.op == "placeholder": assert isinstance(self.target, str) arg_str = self.target - arg_str += arg_str + f': {_type_repr(self.type)}' if self.type else '' + arg_str += arg_str + f": {_type_repr(self.type)}" if self.type else "" if placeholder_names: placeholder_names.append(arg_str) return None - maybe_typename = f'{_type_repr(self.type)} ' if self.type else '' - default_val = '(default=' + str(self.args[0]) + ')' if self.args else '' - return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = {self.op}[target={self.target}]{default_val}' - elif self.op == 'get_attr': - maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else '' - return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \ - f'{self.op}[target={self._pretty_print_target(self.target)}]' - elif self.op == 'output': + maybe_typename = f"{_type_repr(self.type)} " if self.type else "" + default_val = "(default=" + str(self.args[0]) + ")" if self.args else "" + return f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = {self.op}[target={self.target}]{default_val}" + elif self.op == "get_attr": + maybe_typename = ( + f"{_type_repr(self.type)} " if self.type is not None else "" + ) + return ( + f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = " + f"{self.op}[target={self._pretty_print_target(self.target)}]" + ) + elif self.op == "output": if self.type and maybe_return_typename: - maybe_return_typename[0] = f' -> {_type_repr(self.type)}' - return f'return {self.args[0]}' + maybe_return_typename[0] = f" -> {_type_repr(self.type)}" + return f"return {self.args[0]}" else: - maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else '' - return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \ - f'{self.op}[target={self._pretty_print_target(self.target)}](' \ - f'args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})' + maybe_typename = ( + f"{_type_repr(self.type)} " if self.type is not None else "" + ) + return ( + f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = " + f"{self.op}[target={self._pretty_print_target(self.target)}](" + f"args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})" + ) @compatibility(is_backward_compatible=True) - def replace_all_uses_with(self, - replace_with: 'Node', - delete_user_cb: Callable[['Node'], bool] = lambda user: True, - *, - propagate_meta: bool = False - ) -> List['Node']: + def replace_all_uses_with( + self, + replace_with: "Node", + delete_user_cb: Callable[["Node"], bool] = lambda user: True, + *, + propagate_meta: bool = False, + ) -> List["Node"]: """ Replace all uses of ``self`` in the Graph with the Node ``replace_with``. @@ -625,9 +710,10 @@ class Node(_NodeBase): The list of Nodes on which this change was made. 
""" if propagate_meta: - assert len(replace_with.meta) == 0, \ - 'Called node.replace_all_uses_with(replace_with, propagate_meta=True), ' \ - 'but replace_with already has .meta keys' + assert len(replace_with.meta) == 0, ( + "Called node.replace_all_uses_with(replace_with, propagate_meta=True), " + "but replace_with already has .meta keys" + ) for k, v in self.meta.items(): replace_with.meta[k] = v to_process = list(self.users) @@ -638,7 +724,7 @@ class Node(_NodeBase): skipped.append(use_node) continue - def maybe_replace_node(n : Node) -> Node: + def maybe_replace_node(n: Node) -> Node: if n == self: return replace_with else: @@ -690,9 +776,12 @@ class Node(_NodeBase): @compatibility(is_backward_compatible=False) def normalized_arguments( - self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None, - kwarg_types : Optional[Dict[str, Any]] = None, - normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: + self, + root: torch.nn.Module, + arg_types: Optional[Tuple[Any]] = None, + kwarg_types: Optional[Dict[str, Any]] = None, + normalize_to_only_use_kwargs: bool = False, + ) -> Optional[ArgsKwargsPair]: """ Returns normalized arguments to Python targets. This means that `args/kwargs` will be matched up to the module/functional's @@ -715,17 +804,23 @@ class Node(_NodeBase): Returns NamedTuple ArgsKwargsPair, or `None` if not successful. """ - if self.op == 'call_function': + if self.op == "call_function": assert callable(self.target) - return normalize_function(self.target, self.args, self.kwargs, arg_types, kwarg_types) # type: ignore[arg-type] - elif self.op == 'call_module': + return normalize_function( + self.target, + self.args, # type: ignore[arg-type] + self.kwargs, + arg_types, + kwarg_types, + ) + elif self.op == "call_module": assert isinstance(self.target, str) return normalize_module(root, self.target, self.args, self.kwargs) # type: ignore[arg-type] return None @compatibility(is_backward_compatible=True) - def replace_input_with(self, old_input: 'Node', new_input: 'Node') -> None: + def replace_input_with(self, old_input: "Node", new_input: "Node") -> None: """ Loop through input nodes of ``self``, and replace all instances of ``old_input`` with ``new_input``. @@ -735,7 +830,8 @@ class Node(_NodeBase): old_input (Node): The old input node to be replaced. new_input (Node): The new input node to replace ``old_input``. 
""" - def maybe_replace_node(n : Node) -> Node: + + def maybe_replace_node(n: Node) -> Node: return new_input if n == old_input else n m = self.graph.owning_module @@ -756,7 +852,7 @@ class Node(_NodeBase): self.graph._graph_namespace._rename_object(self, name) def __setattr__(self, name: str, value: Any) -> None: - if name == 'name' and hasattr(self, "name"): + if name == "name" and hasattr(self, "name"): m = self.graph.owning_module if getattr(m, "_replace_hook", None): assert isinstance(value, str) @@ -764,9 +860,9 @@ class Node(_NodeBase): m._replace_hook(old=self, new=value, user=user) update = False if ( - hasattr(self, name) and - hasattr(self.graph, "_find_nodes_lookup_table") and - self in self.graph._find_nodes_lookup_table + hasattr(self, name) + and hasattr(self.graph, "_find_nodes_lookup_table") + and self in self.graph._find_nodes_lookup_table ): update = True self.graph._find_nodes_lookup_table.remove(self) @@ -774,6 +870,7 @@ class Node(_NodeBase): if update: self.graph._find_nodes_lookup_table.insert(self) + @compatibility(is_backward_compatible=True) def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: """ @@ -782,6 +879,7 @@ def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable" return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x) + @compatibility(is_backward_compatible=True) def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument: """ @@ -790,7 +888,7 @@ def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument: if isinstance(a, tuple): t = tuple([map_aggregate(elem, fn) for elem in a]) # Support NamedTuple (if it has `_fields`) by repacking into original type. - return t if not hasattr(a, '_fields') else type(a)(*t) # type: ignore[arg-type] + return t if not hasattr(a, "_fields") else type(a)(*t) # type: ignore[arg-type] elif isinstance(a, list): return immutable_list([map_aggregate(elem, fn) for elem in a]) elif isinstance(a, dict): @@ -799,6 +897,10 @@ def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument: dict.__setitem__(rv, k, map_aggregate(v, fn)) return rv elif isinstance(a, slice): - return slice(map_aggregate(a.start, fn), map_aggregate(a.stop, fn), map_aggregate(a.step, fn)) + return slice( + map_aggregate(a.start, fn), + map_aggregate(a.stop, fn), + map_aggregate(a.step, fn), + ) else: return fn(a) diff --git a/torch/fx/operator_schemas.py b/torch/fx/operator_schemas.py index 53f7099d7e68..f654b6c060e8 100644 --- a/torch/fx/operator_schemas.py +++ b/torch/fx/operator_schemas.py @@ -1,63 +1,100 @@ # mypy: allow-untyped-defs -import torch +import enum import inspect import numbers import types import typing -import enum import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING +from typing import ( + Any, + Callable, + cast, + Dict, + List, + NamedTuple, + Optional, + Tuple, + TYPE_CHECKING, +) + +import torch from torch._jit_internal import boolean_dispatched +from torch._ops import OpOverload, OpOverloadPacket + from ._compatibility import compatibility -from torch._ops import OpOverloadPacket, OpOverload + if TYPE_CHECKING: from .node import Argument -__all__ = ["ArgsKwargsPair", "check_for_mutable_operation", "get_signature_for_torch_op", "create_type_hint", - "type_matches", "normalize_function", "normalize_module"] +__all__ = [ + "ArgsKwargsPair", + "check_for_mutable_operation", + "get_signature_for_torch_op", + 
"create_type_hint", + "type_matches", + "normalize_function", + "normalize_module", +] + @compatibility(is_backward_compatible=False) class ArgsKwargsPair(NamedTuple): """ Simple named tuple for wrapping args/kwargs pairs. """ + args: Tuple[Any, ...] kwargs: Dict[str, Any] -_manual_overrides : Dict[Callable, List[inspect.Signature]] = {} + +_manual_overrides: Dict[Callable, List[inspect.Signature]] = {} + def _nonzero_schemas(): signatures = [] def nonzero(self): pass + signatures.append(inspect.signature(nonzero)) - def nonzero(self, *, as_tuple : bool): # type: ignore[no-redef] + def nonzero(self, *, as_tuple: bool): # type: ignore[no-redef] pass + signatures.append(inspect.signature(nonzero)) return signatures + _manual_overrides[torch.nonzero] = _nonzero_schemas() + class _FakeGlobalNamespace: def __getattr__(self, name): - if name == 'torch': + if name == "torch": return torch - raise RuntimeError('Expected a torch namespace lookup') + raise RuntimeError("Expected a torch namespace lookup") -_type_eval_globals = {'Tensor' : torch.Tensor, 'Device' : torch.device, 'Layout' : torch.layout, - 'number' : numbers.Number, 'Future' : torch.jit.Future, - 'AnyEnumType' : enum.Enum, 'QScheme' : torch.qscheme, - '__torch__': _FakeGlobalNamespace(), 'NoneType': type(None), - 'Storage': torch.UntypedStorage, - 't': typing.TypeVar('t')} + +_type_eval_globals = { + "Tensor": torch.Tensor, + "Device": torch.device, + "Layout": torch.layout, + "number": numbers.Number, + "Future": torch.jit.Future, + "AnyEnumType": enum.Enum, + "QScheme": torch.qscheme, + "__torch__": _FakeGlobalNamespace(), + "NoneType": type(None), + "Storage": torch.UntypedStorage, + "t": typing.TypeVar("t"), +} for k in dir(typing): _type_eval_globals[k] = getattr(typing, k) -def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any: + +def _torchscript_type_to_python_type(ts_type: "torch._C.JitType") -> Any: """ Convert a TorchScript type to a Python type (including subtypes) via eval'ing the annotation_str. _type_eval_globals sets up expressions @@ -65,9 +102,13 @@ def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any: """ return eval(ts_type.annotation_str, _type_eval_globals) -def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) -> inspect.Signature: + +def _torchscript_schema_to_signature_impl( + ts_schema: torch._C.FunctionSchema, +) -> inspect.Signature: from inspect import Parameter - parameters : List[Parameter] = [] + + parameters: List[Parameter] = [] for arg in ts_schema.arguments: arg_type = _torchscript_type_to_python_type(arg.type) default = arg.default_value if arg.has_default_value() else Parameter.empty @@ -76,8 +117,12 @@ def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) - # argument name. 
Downstream, if someone converts that positional argument to a keyword # argument, the name mismatch will break things, so here we're going to normalize the # name to "input" - name = arg.name if arg.name != 'self' else 'input' - kind = Parameter.KEYWORD_ONLY if arg.kwarg_only else Parameter.POSITIONAL_OR_KEYWORD + name = arg.name if arg.name != "self" else "input" + kind = ( + Parameter.KEYWORD_ONLY + if arg.kwarg_only + else Parameter.POSITIONAL_OR_KEYWORD + ) # "from" is a keyword therefore it must be a POSITIONAL_ONLY argument if name == "from": assert kind == Parameter.POSITIONAL_OR_KEYWORD @@ -87,9 +132,18 @@ def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) - # This renders all previous arguments to positional only for idx, p in enumerate(parameters): assert p.kind == Parameter.POSITIONAL_OR_KEYWORD - parameters[idx] = Parameter(name=p.name, kind=Parameter.POSITIONAL_ONLY, default=p.default, annotation=p.annotation) - parameters.append(Parameter(name=name, kind=kind, default=default, annotation=arg_type)) - return_types = [_torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns] + parameters[idx] = Parameter( + name=p.name, + kind=Parameter.POSITIONAL_ONLY, + default=p.default, + annotation=p.annotation, + ) + parameters.append( + Parameter(name=name, kind=kind, default=default, annotation=arg_type) + ) + return_types = [ + _torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns + ] if len(return_types) == 0: return_type = None elif len(return_types) == 1: @@ -99,9 +153,13 @@ def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) - return inspect.Signature(parameters, return_annotation=return_type) -_SCHEMA_TO_SIGNATURE_CACHE : Dict[Tuple[str, str], inspect.Signature] = {} -def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Signature: +_SCHEMA_TO_SIGNATURE_CACHE: Dict[Tuple[str, str], inspect.Signature] = {} + + +def _torchscript_schema_to_signature( + ts_schema: torch._C.FunctionSchema, +) -> inspect.Signature: # Cached as it's called in the hot path of FakeTensor dispatch cache_key = ts_schema.name, ts_schema.overload_name cache_val = _SCHEMA_TO_SIGNATURE_CACHE.get(cache_key) @@ -112,8 +170,11 @@ def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> ins _SCHEMA_TO_SIGNATURE_CACHE[cache_key] = res return res + @compatibility(is_backward_compatible=False) -def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...], kwargs : Dict[str, 'Argument']): +def check_for_mutable_operation( + target: Callable, args: Tuple["Argument", ...], kwargs: Dict[str, "Argument"] +): signatures, schemas = get_signature_for_torch_op(target, return_schemas=True) if signatures and schemas: @@ -131,9 +192,11 @@ def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...] def throw_if_mutable(schema): if schema.is_mutable: - raise RuntimeError(f'Tried to trace mutable operation {schema}. FX only supports functional ' - f'code, so operations that mutate operands in-place (e.g. via `out` arguments) ' - f'are not supported') + raise RuntimeError( + f"Tried to trace mutable operation {schema}. FX only supports functional " + f"code, so operations that mutate operands in-place (e.g. via `out` arguments) " + f"are not supported" + ) if len(matched_schemas) == 0: # Did not match any schema. Cannot check for mutation @@ -147,8 +210,9 @@ def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...] # do nothing. 
pass + @compatibility(is_backward_compatible=False) -def get_signature_for_torch_op(op : Callable, return_schemas : bool = False): +def get_signature_for_torch_op(op: Callable, return_schemas: bool = False): """ Given an operator on the `torch` namespace, return a list of `inspect.Signature` objects corresponding to the overloads of that op. May return `None` if a signature @@ -181,6 +245,7 @@ get_signature_for_torch_op(op : Callable, return_schemas : bool = False): signatures = [_torchscript_schema_to_signature(schema) for schema in schemas] return (signatures, schemas) if return_schemas else signatures + @compatibility(is_backward_compatible=False) def create_type_hint(x): """ @@ -198,11 +263,15 @@ def create_type_hint(x): if isinstance(x, (list, tuple)): # todo(chilli): Figure out the right way for mypy to handle this if isinstance(x, list): + def ret_type(x): return List[x] # type: ignore[valid-type] + else: + def ret_type(x): return Tuple[x, ...] + if len(x) == 0: return ret_type(Any) base_type = x[0] @@ -216,12 +285,15 @@ def create_type_hint(x): return ret_type(base_type) except Exception: # We tried to create a type hint for list but failed. - warnings.warn(f"We were not able to successfully create type hint from the type {x}") + warnings.warn( + f"We were not able to successfully create type hint from the type {x}" + ) return x + @compatibility(is_backward_compatible=False) -def type_matches(signature_type : Any, argument_type : Any): - sig_origin_type = getattr(signature_type, '__origin__', signature_type) +def type_matches(signature_type: Any, argument_type: Any): + sig_origin_type = getattr(signature_type, "__origin__", signature_type) if signature_type is argument_type: return True @@ -236,13 +308,14 @@ def type_matches(signature_type : Any, argument_type : Any): # int can be promoted to List[int] return True - if getattr(signature_type, '__origin__', None) in {list, List}: + if getattr(signature_type, "__origin__", None) in {list, List}: sig_el_type = signature_type.__args__[0] if not inspect.isclass(sig_el_type): warnings.warn( - f"Does not support nested parametric types, got {signature_type}. Please file a bug.") + f"Does not support nested parametric types, got {signature_type}. Please file a bug." + ) return False - if getattr(argument_type, '__origin__', None) in {list, List}: + if getattr(argument_type, "__origin__", None) in {list, List}: return issubclass(argument_type.__args__[0], sig_el_type) def is_homogeneous_tuple(t): @@ -267,11 +340,16 @@ def type_matches(signature_type : Any, argument_type : Any): return False + @compatibility(is_backward_compatible=False) def normalize_function( - target: Callable, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, arg_types : Optional[Tuple[Any]] = None, - kwarg_types : Optional[Dict[str, Any]] = None, - normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: + target: Callable, + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + arg_types: Optional[Tuple[Any]] = None, + kwarg_types: Optional[Dict[str, Any]] = None, + normalize_to_only_use_kwargs: bool = False, +) -> Optional[ArgsKwargsPair]: """ Returns normalized arguments to PyTorch functions. This means that `args/kwargs` will be matched up to the functional's @@ -308,14 +386,19 @@ def normalize_function( # branch signature for analysis. 
Otherwise, leave this un-normalized assert not isinstance(target, str) dispatched = boolean_dispatched[target] - if_true, if_false = dispatched['if_true'], dispatched['if_false'] - if inspect.signature(if_true).parameters != inspect.signature(if_false).parameters: + if_true, if_false = dispatched["if_true"], dispatched["if_false"] + if ( + inspect.signature(if_true).parameters + != inspect.signature(if_false).parameters + ): return None target_for_analysis = if_true assert callable(target_for_analysis) sig = inspect.signature(inspect.unwrap(target_for_analysis)) - new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, normalize_to_only_use_kwargs) + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs( + sig, args, kwargs, normalize_to_only_use_kwargs + ) else: assert callable(target) torch_op_schemas = get_signature_for_torch_op(target) @@ -336,8 +419,9 @@ def normalize_function( pass elif len(matched_schemas) == 1: # Matched exactly one schema, unambiguous - new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(matched_schemas[0], args, kwargs, - normalize_to_only_use_kwargs) + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs( + matched_schemas[0], args, kwargs, normalize_to_only_use_kwargs + ) else: if arg_types is not None or kwarg_types is not None: arg_types = arg_types if arg_types else cast(Tuple[Any], ()) @@ -345,30 +429,49 @@ def normalize_function( for candidate_signature in torch_op_schemas: sig_matches = True try: - bound_types = candidate_signature.bind(*arg_types, **kwarg_types) + bound_types = candidate_signature.bind( + *arg_types, **kwarg_types + ) for arg_name, arg_type in bound_types.arguments.items(): param = candidate_signature.parameters[arg_name] - sig_matches = sig_matches and type_matches(param.annotation, arg_type) + sig_matches = sig_matches and type_matches( + param.annotation, arg_type + ) except TypeError: sig_matches = False if sig_matches: - new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(candidate_signature, args, kwargs, - normalize_to_only_use_kwargs) + new_args_and_kwargs = ( + _args_kwargs_to_normalized_args_kwargs( + candidate_signature, + args, + kwargs, + normalize_to_only_use_kwargs, + ) + ) break else: # Matched more than one schema. In this situation, the caller must provide the types of # the arguments of the overload they expect. - schema_printouts = '\n'.join(str(schema) for schema in matched_schemas) - raise RuntimeError(f'Tried to normalize arguments to {torch.typename(target)} but ' - f'the schema match was ambiguous! Please provide argument types to ' - f'the normalize_arguments() call. Available schemas:\n{schema_printouts}') + schema_printouts = "\n".join( + str(schema) for schema in matched_schemas + ) + raise RuntimeError( + f"Tried to normalize arguments to {torch.typename(target)} but " + f"the schema match was ambiguous! Please provide argument types to " + f"the normalize_arguments() call. Available schemas:\n{schema_printouts}" + ) return new_args_and_kwargs + @compatibility(is_backward_compatible=False) def normalize_module( - root: torch.nn.Module, target: str, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, - normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: + root: torch.nn.Module, + target: str, + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + normalize_to_only_use_kwargs: bool = False, +) -> Optional[ArgsKwargsPair]: """ Returns normalized arguments to PyTorch modules. 
This means that `args/kwargs` will be matched up to the functional's @@ -391,22 +494,29 @@ def normalize_module( try: submod = root.get_submodule(target) except AttributeError as e: - raise RuntimeError(f"Tried to normalize node with target {target} but root did not " - f"have that target!") from e - if hasattr(submod.__class__, '__name__'): + raise RuntimeError( + f"Tried to normalize node with target {target} but root did not " + f"have that target!" + ) from e + if hasattr(submod.__class__, "__name__"): classname = submod.__class__.__name__ if getattr(torch.nn, classname, None) == submod.__class__: sig = inspect.signature(inspect.unwrap(submod.forward)) if kwargs is None: kwargs = {} - new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, - normalize_to_only_use_kwargs) + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs( + sig, args, kwargs, normalize_to_only_use_kwargs + ) return new_args_and_kwargs return None -def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple[Any, ...], - kwargs : Dict[str, Any], - normalize_to_only_use_kwargs : bool) -> Optional[ArgsKwargsPair]: + +def _args_kwargs_to_normalized_args_kwargs( + sig: inspect.Signature, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + normalize_to_only_use_kwargs: bool, +) -> Optional[ArgsKwargsPair]: """ Given a call target, args, and kwargs, return the arguments normalized into an ArgsKwargsPair, or None if the type signature is not supported by @@ -428,20 +538,22 @@ def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple # Don't currently support positional-only # or varargs (*args, **kwargs) signatures supported_parameter_types = { - inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY} + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + } if any(p.kind not in supported_parameter_types for p in sig.parameters.values()): # Add an exception for one signature, which is common for random/uniform, i.e.: # Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None # `from` is Python keyword and as such functions with that signature should have # positional-only args, but at the same time they could be dispatched as kwargs - if list(sig.parameters.keys()) != ['input', 'from', 'to', 'generator']: + if list(sig.parameters.keys()) != ["input", "from", "to", "generator"]: return None bound_args = sig.bind(*args, **kwargs) bound_args.apply_defaults() - new_kwargs : Dict[str, Any] = {} - new_args : List[Any] = [] + new_kwargs: Dict[str, Any] = {} + new_args: List[Any] = [] for i, param in enumerate(sig.parameters): if not normalize_to_only_use_kwargs and i < len(args): new_args.append(bound_args.arguments[param]) diff --git a/torch/fx/passes/__init__.py b/torch/fx/passes/__init__.py index f83a2f248fcd..433d8818e259 100644 --- a/torch/fx/passes/__init__.py +++ b/torch/fx/passes/__init__.py @@ -1,12 +1,14 @@ -from . import graph_drawer -from . import graph_manipulation -from . import net_min_base -from . import operator_support -from . import param_fetch -from . import reinplace -from . import runtime_assert -from . import shape_prop -from . import split_module -from . import split_utils -from . import splitter_base -from . import tools_common +from . 
import ( + graph_drawer, + graph_manipulation, + net_min_base, + operator_support, + param_fetch, + reinplace, + runtime_assert, + shape_prop, + split_module, + split_utils, + splitter_base, + tools_common, +) diff --git a/torch/fx/passes/backends/cudagraphs.py b/torch/fx/passes/backends/cudagraphs.py index 0f48165b7dab..b98178f0d533 100644 --- a/torch/fx/passes/backends/cudagraphs.py +++ b/torch/fx/passes/backends/cudagraphs.py @@ -1,12 +1,13 @@ # mypy: allow-untyped-defs +import operator + import torch +from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupport from torch.fx.passes.tools_common import CALLABLE_NODE_OPS -from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.utils import _pytree as pytree -import operator class CudaGraphsSupport(OperatorSupport): # TODO: why is submodules passed here @@ -27,7 +28,7 @@ class CudaGraphsSupport(OperatorSupport): def find_not_cuda(t): nonlocal found_not_cuda - if isinstance(t, torch.Tensor) and t.device.type != 'cuda': + if isinstance(t, torch.Tensor) and t.device.type != "cuda": found_not_cuda = True for n in node.all_input_nodes: @@ -40,6 +41,7 @@ class CudaGraphsSupport(OperatorSupport): return not found_not_cuda + def partition_cudagraphs(gm, inputs): """ Partition an FX graph into sub-GraphModules that can be validly run under @@ -51,7 +53,9 @@ def partition_cudagraphs(gm, inputs): supported_ops = CudaGraphsSupport() # TODO: single node partition may be wrong due to the pessimization # from copying in and out the data. Check in benchmarks, perhaps - partitioner = CapabilityBasedPartitioner(gm, supported_ops, allows_single_node_partition=True) + partitioner = CapabilityBasedPartitioner( + gm, supported_ops, allows_single_node_partition=True + ) partitions = partitioner.propose_partitions() fused_graph = partitioner.fuse_partitions(partitions) return fused_graph diff --git a/torch/fx/passes/dialect/common/cse_pass.py b/torch/fx/passes/dialect/common/cse_pass.py index 577f445e7b31..6a501f041d19 100644 --- a/torch/fx/passes/dialect/common/cse_pass.py +++ b/torch/fx/passes/dialect/common/cse_pass.py @@ -1,20 +1,45 @@ # mypy: allow-untyped-defs -from typing import Dict, Tuple, Any +from typing import Any, Dict, Tuple import torch +from torch.fx import Graph, GraphModule, Node from torch.fx.passes.infra.pass_base import PassBase, PassResult from torch.utils._pytree import tree_flatten -from torch.fx import GraphModule, Graph -from torch.fx import Node aten = torch.ops.aten # stateful ops are banned from CSE -rand_ops = {aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm} # noqa: E501,B950 +rand_ops = { + aten.dropout, + aten._fused_dropout, + aten._standard_gamma, + aten.bernoulli, + aten.multinomial, + aten.native_dropout, + aten.normal, + aten.poisson, + aten.binomial, + aten.rrelu, + aten.rand_like, + aten.rand, + aten.randint, + aten.randn, + aten.randperm, +} # noqa: E501,B950 -inplace_ops = {aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_} # noqa: E501 +inplace_ops = { + aten.add_, + aten.sub_, + aten.mul_, + aten.div_, + aten.pow_, + aten.lerp_, + aten.relu_, + aten.sigmoid_, + aten.tanh_, +} # noqa: E501 
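``rand_ops`` and ``inplace_ops`` above are excluded from CSE because merging stateful or mutating calls would change program semantics. A minimal sketch of the pass in the style of the docstring below, on a hypothetical function ``f`` chosen so that two nodes share the same target and arguments::

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.passes.dialect.common.cse_pass import CSEPass

    def f(x):
        a = x + 1
        b = x + 1      # same target and args as `a`: eligible for CSE
        return a + b

    gm = symbolic_trace(f)
    result = CSEPass()(gm)  # PassBase.__call__ runs the pass
    print(result.modified)  # True: the duplicate add was folded
    print(result.graph_module.code)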
@torch.fx._compatibility.compatibility(is_backward_compatible=False) @@ -24,7 +49,6 @@ def get_CSE_banned_ops(): @torch.fx._compatibility.compatibility(is_backward_compatible=False) class CSEPass(PassBase): - def __init__(self, banned_ops=None): """ This version of CSE Pass aims to be dialect agnostic, and it's implemented purely based on the connectivity between fx.Node. @@ -58,20 +82,32 @@ class CSEPass(PassBase): result = p(traced_graph) print(result.graph_module) """ + def get_aten_target(node): - if hasattr(node.target, 'overloadpacket'): + if hasattr(node.target, "overloadpacket"): return node.target.overloadpacket return node.target modified = False new_graph = Graph() - env: Dict[Node, Node] = {} # map from node in the old graph to node in the new graph - hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {} # map from hash to a node in the new graph - token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {} # map from hash to token + env: Dict[ + Node, Node + ] = {} # map from node in the old graph to node in the new graph + hash_env: Dict[ + Tuple[torch._ops.OpOverload, int], Node + ] = {} # map from hash to a node in the new graph + token_map: Dict[ + Tuple[torch._ops.OpOverload, int], Dict[str, Any] + ] = {} # map from hash to token for n in graph_module.graph.nodes: # The placeholder, output, and get_attr nodes are copied to the new graph without change # do not CSE away random operations - if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in self.banned_ops: + if ( + n.op == "placeholder" + or n.op == "output" + or n.op == "get_attr" + or get_aten_target(n) in self.banned_ops + ): new_node = new_graph.node_copy(n, lambda x: env[x]) env[n] = new_node else: # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method' @@ -84,13 +120,19 @@ class CSEPass(PassBase): if isinstance(v, Node) and v in env: arg_list[i] = env[v] return tuple(arg_list), spec + args, args_spec = substitute(n.args) kwargs, kwargs_spec = substitute(n.kwargs) # each token corresponds to a unique node # nodes with the same token can be substituted - token = {"target": n.target, "args": args, "args_spec": args_spec, - "kwargs": kwargs, "kwargs_spec": kwargs_spec} + token = { + "target": n.target, + "args": args, + "args_spec": args_spec, + "kwargs": kwargs, + "kwargs_spec": kwargs_spec, + } # hash substituted args to a number, do not hash specs because specs are not hashable hash_arg = hash((args, kwargs)) diff --git a/torch/fx/passes/fake_tensor_prop.py b/torch/fx/passes/fake_tensor_prop.py index 2b40207e0f80..8036f5d0fd55 100644 --- a/torch/fx/passes/fake_tensor_prop.py +++ b/torch/fx/passes/fake_tensor_prop.py @@ -2,13 +2,15 @@ from typing import Optional import torch.fx +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode from torch.fx import Node -from torch.fx.node import map_aggregate from torch.fx._compatibility import compatibility -from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor -from torch.fx.experimental.proxy_tensor import snapshot_fake, py_sym_types +from torch.fx.experimental.proxy_tensor import py_sym_types, snapshot_fake +from torch.fx.node import map_aggregate + + +__all__ = ["FakeTensorProp"] -__all__ = ['FakeTensorProp'] @compatibility(is_backward_compatible=False) class FakeTensorProp(torch.fx.Interpreter): @@ -24,7 +26,10 @@ class FakeTensorProp(torch.fx.Interpreter): module (GraphModule): The module to be executed mode (Optional[FakeTensorMode]): The dispatch mode used 
to execute computation indicated by each FX Node. """ - def __init__(self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None): + + def __init__( + self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None + ): super().__init__(module) if mode is None: mode = FakeTensorMode() @@ -33,7 +38,10 @@ class FakeTensorProp(torch.fx.Interpreter): mode.reset_nt_tensor_id_counter() def run_node(self, n: Node): - from torch.fx.experimental.symbolic_shapes import rebind_unbacked, compute_unbacked_bindings + from torch.fx.experimental.symbolic_shapes import ( + compute_unbacked_bindings, + rebind_unbacked, + ) result = super().run_node(n) rebind_unbacked(self._mode.shape_env, n, result) @@ -52,8 +60,10 @@ class FakeTensorProp(torch.fx.Interpreter): meta = map_aggregate(result, extract_val) if meta is not None: - n.meta['val'] = meta - if (shape_env := self._mode.shape_env) and (symbol_to_path := compute_unbacked_bindings(shape_env, result)): + n.meta["val"] = meta + if (shape_env := self._mode.shape_env) and ( + symbol_to_path := compute_unbacked_bindings(shape_env, result) + ): n.meta["unbacked_bindings"] = symbol_to_path return result diff --git a/torch/fx/passes/graph_drawer.py b/torch/fx/passes/graph_drawer.py index 975b2b617178..9a1710c9721a 100644 --- a/torch/fx/passes/graph_drawer.py +++ b/torch/fx/passes/graph_drawer.py @@ -58,6 +58,7 @@ _WEIGHT_TEMPLATE = { } if HAS_PYDOT: + @compatibility(is_backward_compatible=False) class FxGraphDrawer: """ @@ -87,7 +88,12 @@ if HAS_PYDOT: self._dot_graphs = { name: self._to_dot( - graph_module, name, ignore_getattr, ignore_parameters_and_buffers, skip_node_names_in_args, parse_stack_trace + graph_module, + name, + ignore_getattr, + ignore_parameters_and_buffers, + skip_node_names_in_args, + parse_stack_trace, ) } @@ -127,8 +133,8 @@ if HAS_PYDOT: >>> symbolic_traced = torch.fx.symbolic_trace(module) >>> # setup output file >>> import ubelt as ub - >>> dpath = ub.Path.appdir('torch/tests/FxGraphDrawer').ensuredir() - >>> fpath = dpath / 'linear.svg' + >>> dpath = ub.Path.appdir("torch/tests/FxGraphDrawer").ensuredir() + >>> fpath = dpath / "linear.svg" >>> # draw the graph >>> g = FxGraphDrawer(symbolic_traced, "linear") >>> g.get_dot_graph().write_svg(fpath) @@ -148,7 +154,6 @@ if HAS_PYDOT: return self._dot_graphs def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]: - template = { "shape": self.dot_graph_shape, "fillcolor": "#CAFFE3", @@ -161,7 +166,9 @@ if HAS_PYDOT: # Use a random color for each node; based on its name so it's stable. 
target_name = node._pretty_print_target(node.target) target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16) - template["fillcolor"] = _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)] + template["fillcolor"] = _HASH_COLOR_MAP[ + target_hash % len(_HASH_COLOR_MAP) + ] return template def _get_leaf_node( @@ -199,12 +206,11 @@ if HAS_PYDOT: full_file_name: str, truncate_to_last_n: int = 2, ): - splits = full_file_name.split('/') + splits = full_file_name.split("/") if len(splits) >= truncate_to_last_n: - return '/'.join(splits[-truncate_to_last_n:]) + return "/".join(splits[-truncate_to_last_n:]) return full_file_name - def _get_node_label( self, module: torch.fx.GraphModule, @@ -219,8 +225,7 @@ if HAS_PYDOT: elif isinstance(arg, dict): prefix, suffix = r"|kwargs={\l", r",\n}\l" arg_strs_list = [ - f"{k}: {_format_arg(v, max_list_len=8)}" - for k, v in arg.items() + f"{k}: {_format_arg(v, max_list_len=8)}" for k, v in arg.items() ] else: # Fall back to nothing in unexpected case. return "" @@ -235,7 +240,6 @@ if HAS_PYDOT: arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "") return arg_strs.replace("{", r"\{").replace("}", r"\}") - label = "{" + f"name=%{node.name}|op_code={node.op}\n" if node.op == "call_module": @@ -244,7 +248,10 @@ if HAS_PYDOT: extra = "" if hasattr(leaf_module, "__constants__"): extra = r"\n".join( - [f"{c}: {getattr(leaf_module, c)}" for c in leaf_module.__constants__] # type: ignore[union-attr] + [ + f"{c}: {getattr(leaf_module, c)}" + for c in leaf_module.__constants__ + ] # type: ignore[union-attr] ) label += extra + r"\n" else: @@ -252,7 +259,10 @@ if HAS_PYDOT: if self.normalize_args: try: args, kwargs = normalize_function( # type: ignore[misc] - node.target, node.args, node.kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type] + node.target, # type: ignore[arg-type] + node.args, # type: ignore[arg-type] + node.kwargs, + normalize_to_only_use_kwargs=True, ) except Exception: # Fallback to not normalizing if there's an exception. 
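When ``normalize_args`` is set, the drawer calls ``normalize_function`` so every rendered argument carries a keyword name. A minimal sketch of that call on its own; ``arg_types`` is passed explicitly here because ``torch.add`` matches more than one schema, which would otherwise trigger the ambiguity error shown earlier in this diff::

    import torch
    from torch.fx.operator_schemas import normalize_function

    # Returns an ArgsKwargsPair with everything matched to parameter names.
    pair = normalize_function(
        torch.add,
        (torch.randn(3), torch.randn(3)),
        arg_types=(torch.Tensor, torch.Tensor),
        normalize_to_only_use_kwargs=True,
    )
    print(list(pair.kwargs))  # ['input', 'other', 'alpha']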
@@ -266,12 +276,12 @@ if HAS_PYDOT: label += _get_str_for_args_kwargs(kwargs) label += f"|num_users={len(node.users)}" + r"\n" - tensor_meta = node.meta.get('tensor_meta') + tensor_meta = node.meta.get("tensor_meta") label += self._tensor_meta_to_label(tensor_meta) # for original fx graph # print buf=buf0, n_origin=6 - buf_meta = node.meta.get('buf_meta', None) + buf_meta = node.meta.get("buf_meta", None) if buf_meta is not None: label += f"|buf={buf_meta.name}" + r"\n" label += f"|n_origin={buf_meta.n_origin}" + r"\n" @@ -281,8 +291,10 @@ if HAS_PYDOT: if parse_stack_trace and node.stack_trace is not None: parsed_stack_trace = _parse_stack_trace(node.stack_trace) fname = self._shorten_file_name(parsed_stack_trace.file) - label += f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" + r"\n" - + label += ( + f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" + + r"\n" + ) return label + "}" @@ -322,19 +334,43 @@ if HAS_PYDOT: assert "qscheme" in tm.qparams qscheme = tm.qparams["qscheme"] if qscheme in { - torch.per_tensor_affine, - torch.per_tensor_symmetric, + torch.per_tensor_affine, + torch.per_tensor_symmetric, }: result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n" - result += "|" + "q_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n" + result += ( + "|" + + "q_zero_point" + + "=" + + str(tm.qparams["zero_point"]) + + r"\n" + ) elif qscheme in { - torch.per_channel_affine, - torch.per_channel_symmetric, - torch.per_channel_affine_float_qparams, + torch.per_channel_affine, + torch.per_channel_symmetric, + torch.per_channel_affine_float_qparams, }: - result += "|" + "q_per_channel_scale" + "=" + str(tm.qparams["scale"]) + r"\n" - result += "|" + "q_per_channel_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n" - result += "|" + "q_per_channel_axis" + "=" + str(tm.qparams["axis"]) + r"\n" + result += ( + "|" + + "q_per_channel_scale" + + "=" + + str(tm.qparams["scale"]) + + r"\n" + ) + result += ( + "|" + + "q_per_channel_zero_point" + + "=" + + str(tm.qparams["zero_point"]) + + r"\n" + ) + result += ( + "|" + + "q_per_channel_axis" + + "=" + + str(tm.qparams["axis"]) + + r"\n" + ) else: raise RuntimeError(f"Unsupported qscheme: {qscheme}") result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n" @@ -363,7 +399,6 @@ if HAS_PYDOT: # "TB" means top-to-bottom rank direction in layout dot_graph = pydot.Dot(name, rankdir="TB") - buf_name_to_subgraph = {} for node in graph_module.graph.nodes: @@ -372,16 +407,22 @@ if HAS_PYDOT: style = self._get_node_style(node) dot_node = pydot.Node( - node.name, label=self._get_node_label(graph_module, node, skip_node_names_in_args, parse_stack_trace), **style + node.name, + label=self._get_node_label( + graph_module, node, skip_node_names_in_args, parse_stack_trace + ), + **style, ) current_graph = dot_graph - buf_meta = node.meta.get('buf_meta', None) + buf_meta = node.meta.get("buf_meta", None) if buf_meta is not None and buf_meta.n_origin > 1: buf_name = buf_meta.name if buf_name not in buf_name_to_subgraph: - buf_name_to_subgraph[buf_name] = pydot.Cluster(buf_name, label=buf_name) + buf_name_to_subgraph[buf_name] = pydot.Cluster( + buf_name, label=buf_name + ) current_graph = buf_name_to_subgraph.get(buf_name) current_graph.add_node(dot_node) @@ -407,12 +448,14 @@ if HAS_PYDOT: if node.op == "call_module": leaf_module = self._get_leaf_node(graph_module, node) - if not ignore_parameters_and_buffers and not isinstance(leaf_module, torch.fx.GraphModule): + if not 
ignore_parameters_and_buffers and not isinstance( + leaf_module, torch.fx.GraphModule + ): get_module_params_or_buffers() for subgraph in buf_name_to_subgraph.values(): - subgraph.set('color', 'royalblue') - subgraph.set('penwidth', '2') + subgraph.set("color", "royalblue") + subgraph.set("penwidth", "2") dot_graph.add_subgraph(subgraph) for node in graph_module.graph.nodes: @@ -426,6 +469,7 @@ if HAS_PYDOT: else: if not TYPE_CHECKING: + @compatibility(is_backward_compatible=False) class FxGraphDrawer: def __init__( @@ -439,5 +483,7 @@ else: dot_graph_shape: Optional[str] = None, normalize_args: bool = False, ): - raise RuntimeError('FXGraphDrawer requires the pydot package to be installed. Please install ' - 'pydot through your favorite Python package manager.') + raise RuntimeError( + "FXGraphDrawer requires the pydot package to be installed. Please install " + "pydot through your favorite Python package manager." + ) diff --git a/torch/fx/passes/graph_manipulation.py b/torch/fx/passes/graph_manipulation.py index a573fea14362..ce9904fc500e 100644 --- a/torch/fx/passes/graph_manipulation.py +++ b/torch/fx/passes/graph_manipulation.py @@ -5,15 +5,18 @@ import torch from torch.fx._compatibility import compatibility from torch.fx.graph import Graph from torch.fx.graph_module import GraphModule -from torch.fx.node import ( - map_arg, - Node, - Target, -) +from torch.fx.node import map_arg, Node, Target from torch.fx.passes.shape_prop import ShapeProp -__all__ = ['replace_target_nodes_with', 'size_bytes', 'get_size_of_all_nodes', 'get_tensor_meta', - 'get_size_of_node'] + +__all__ = [ + "replace_target_nodes_with", + "size_bytes", + "get_size_of_all_nodes", + "get_tensor_meta", + "get_size_of_node", +] + @compatibility(is_backward_compatible=False) def replace_target_nodes_with( diff --git a/torch/fx/passes/infra/__init__.py b/torch/fx/passes/infra/__init__.py index 657b6a93014f..939157f1302e 100644 --- a/torch/fx/passes/infra/__init__.py +++ b/torch/fx/passes/infra/__init__.py @@ -1,2 +1 @@ - from . 
import pass_manager diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py index 4ffb5e3c3641..122545b8dccf 100644 --- a/torch/fx/passes/infra/partitioner.py +++ b/torch/fx/passes/infra/partitioner.py @@ -1,22 +1,24 @@ # mypy: allow-untyped-defs -from torch.fx.passes.utils.fuser_utils import fuse_by_partitions import collections import itertools import logging - from copy import copy from typing import Dict, Iterable, List, Optional, Sequence, Set from torch.fx.graph_module import GraphModule -from torch.fx.node import Node, _get_qualified_name +from torch.fx.node import _get_qualified_name, Node from torch.fx.passes.operator_support import OperatorSupportBase +from torch.fx.passes.utils.fuser_utils import fuse_by_partitions logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) + class Partition: - def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None): + def __init__( + self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None + ): self.id = id self.nodes = dict.fromkeys(nodes) if nodes is not None else {} @@ -32,6 +34,7 @@ class Partition: def size(self): return len(self.nodes) + class _DependencyViewer: def __init__(self, graph_module: GraphModule): self.upstreams = collections.defaultdict(set) @@ -55,15 +58,16 @@ class _DependencyViewer: def upstreams_of(self, node: Node) -> Set[Node]: return self.upstreams[node] -class CapabilityBasedPartitioner: - def __init__(self, - graph_module: GraphModule, - operator_support: OperatorSupportBase, - allows_single_node_partition: bool = False, - non_compute_ops: Optional[Sequence[str]] = None, - allowed_single_node_partition_ops: Optional[Sequence[str]] = None, - ) -> None: +class CapabilityBasedPartitioner: + def __init__( + self, + graph_module: GraphModule, + operator_support: OperatorSupportBase, + allows_single_node_partition: bool = False, + non_compute_ops: Optional[Sequence[str]] = None, + allowed_single_node_partition_ops: Optional[Sequence[str]] = None, + ) -> None: self.graph_module = graph_module self.operator_support = operator_support self.allows_single_node_partition = allows_single_node_partition @@ -76,19 +80,21 @@ class CapabilityBasedPartitioner: self.dependency_viewer = _DependencyViewer(graph_module) def __is_node_supported(self, node: Node) -> bool: - return ( - self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node) + return self.operator_support.is_node_supported( + dict(self.graph_module.named_modules()), node ) def propose_partitions(self) -> List[Partition]: # partition_map is a mapping from partition id to a set of partition id's. # The value set contains all the partition ids that can be reached by doing a # DFS starting from the partition id in the key. 
- partition_map : Dict[int, Set] = collections.defaultdict(set) + partition_map: Dict[int, Set] = collections.defaultdict(set) # assumptions: nodes in candidate list is sorted in topological order - assignment: Dict[Node, int] = {} # mapping from node to partition_id - partitions_by_id: Dict[int, Partition] = {} # mapping from partition_id to partition + assignment: Dict[Node, int] = {} # mapping from node to partition_id + partitions_by_id: Dict[ + int, Partition + ] = {} # mapping from partition_id to partition new_partition_id = itertools.count() # try to merge partition other_id into partition self_id @@ -149,7 +155,9 @@ class CapabilityBasedPartitioner: # delete other partition del partitions_by_id[other_id] - partition_map[self_id] = partition_map[self_id].union(partition_map[other_id]) + partition_map[self_id] = partition_map[self_id].union( + partition_map[other_id] + ) del partition_map[other_id] return True @@ -223,16 +231,18 @@ class CapabilityBasedPartitioner: for node in self.graph_module.graph.nodes: is_tuple_output = True for user in node.users: - if user.op != "call_function" or \ - _get_qualified_name(user.target) != "_operator.getitem": # type: ignore[arg-type] + if ( + user.op != "call_function" + or _get_qualified_name(user.target) != "_operator.getitem" + ): # type: ignore[arg-type] is_tuple_output = False break # node has tuple outputs, re-assign all following getitem node into node's partition if is_tuple_output: - id = assignment.get(node, None) # type: ignore[arg-type] + id = assignment.get(node, None) # type: ignore[arg-type] for user in node.users: - if assignment.get(user, None) != id: # type: ignore[arg-type] + if assignment.get(user, None) != id: # type: ignore[arg-type] nodes_reassignment[user] = id # type: ignore[assignment] for node, id in nodes_reassignment.items(): merge_single_node(node, id) @@ -250,7 +260,10 @@ class CapabilityBasedPartitioner: assert callable(node.target) if _get_qualified_name(node.target) not in non_compute_ops: compute_node_count += 1 - if _get_qualified_name(node.target) in self.allowed_single_node_partition_ops: + if ( + _get_qualified_name(node.target) + in self.allowed_single_node_partition_ops + ): compute_node_count += 1 if compute_node_count <= 1: partitions_to_remove.append(id) @@ -259,11 +272,17 @@ class CapabilityBasedPartitioner: logger.debug("Partitions proposed:") for id, partition in partitions_by_id.items(): - logger.debug("partition #%s: %s", id, [node.name for node in partition.nodes]) + logger.debug( + "partition #%s: %s", id, [node.name for node in partition.nodes] + ) - return [partition for partition in partitions_by_id.values() if partition.size() > 0] + return [ + partition for partition in partitions_by_id.values() if partition.size() > 0 + ] - def fuse_partitions(self, partitions: List[Partition], prefix: str = "fused_") -> GraphModule: + def fuse_partitions( + self, partitions: List[Partition], prefix: str = "fused_" + ) -> GraphModule: logger.debug("Fusing partitions...") # fuse_by_partitions expects partitions in List[Dict[Node, None]]: [ {node0 : None}, {node1 : None} ] return fuse_by_partitions( @@ -277,15 +296,23 @@ class CapabilityBasedPartitioner: non_compute_ops = set(self.non_compute_ops) def is_non_compute_node(node: Node): - return node.op == "call_function" and \ - _get_qualified_name(node.target) in non_compute_ops # type: ignore[arg-type] + return ( + node.op == "call_function" + and _get_qualified_name(node.target) in non_compute_ops # type: ignore[arg-type] + ) # cache transparent nodes 
transparent_input_nodes: Dict[Node, bool] = {} transparent_output_nodes: Dict[Node, bool] = {} - def is_transparent_input_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]): - if node.op == "placeholder" or (node not in partition) or (node in removed_nodes): + def is_transparent_input_node( + node: Node, partition: Set[Node], removed_nodes: Set[Node] + ): + if ( + node.op == "placeholder" + or (node not in partition) + or (node in removed_nodes) + ): return True if node in transparent_input_nodes: return transparent_input_nodes[node] @@ -299,14 +326,22 @@ class CapabilityBasedPartitioner: transparent_input_nodes[node] = False return False - def is_transparent_output_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]): - if node.op == "placeholder" or (node not in partition) or (node in removed_nodes): + def is_transparent_output_node( + node: Node, partition: Set[Node], removed_nodes: Set[Node] + ): + if ( + node.op == "placeholder" + or (node not in partition) + or (node in removed_nodes) + ): return True if node in transparent_output_nodes: return transparent_output_nodes[node] if is_non_compute_node(node): for output_n in node.users: - if not is_transparent_output_node(output_n, partition, removed_nodes): + if not is_transparent_output_node( + output_n, partition, removed_nodes + ): transparent_output_nodes[node] = False return False transparent_output_nodes[node] = True @@ -320,9 +355,12 @@ class CapabilityBasedPartitioner: # the set. remove_node: Set[Node] = set() for node in partition.nodes: - if is_non_compute_node(node) and \ - (is_transparent_input_node(node, set(partition.nodes), remove_node) or - is_transparent_output_node(node, set(partition.nodes), remove_node)): + if is_non_compute_node(node) and ( + is_transparent_input_node(node, set(partition.nodes), remove_node) + or is_transparent_output_node( + node, set(partition.nodes), remove_node + ) + ): remove_node.add(node) if len(remove_node) != 0: diff --git a/torch/fx/passes/infra/pass_base.py b/torch/fx/passes/infra/pass_base.py index 3f5b64eafbb6..acf78d2581b5 100644 --- a/torch/fx/passes/infra/pass_base.py +++ b/torch/fx/passes/infra/pass_base.py @@ -3,11 +3,12 @@ import abc from collections import namedtuple from typing import Optional -from torch.fx.graph_module import GraphModule from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule -__all__ = ['PassResult', 'PassBase'] +__all__ = ["PassResult", "PassBase"] + @compatibility(is_backward_compatible=False) class PassResult(namedtuple("PassResult", ["graph_module", "modified"])): @@ -16,9 +17,11 @@ class PassResult(namedtuple("PassResult", ["graph_module", "modified"])): graph_module: The modified graph module modified: A flag for if the pass has modified the graph module """ + def __new__(cls, graph_module, modified): return super().__new__(cls, graph_module, modified) + @compatibility(is_backward_compatible=False) class PassBase(abc.ABC): """ diff --git a/torch/fx/passes/infra/pass_manager.py b/torch/fx/passes/infra/pass_manager.py index 29540fa447eb..cea5f4f25c77 100644 --- a/torch/fx/passes/infra/pass_manager.py +++ b/torch/fx/passes/infra/pass_manager.py @@ -1,19 +1,21 @@ # mypy: allow-untyped-defs import inspect import logging -from queue import Queue from functools import wraps +from queue import Queue from typing import Callable, Dict, List import torch.nn as nn -from torch.fx.graph_module import GraphModule from torch.fx._compatibility import compatibility +from torch.fx.graph_module import 
GraphModule from torch.fx.passes.infra.pass_base import PassResult + logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) -__all__ = ['pass_result_wrapper', 'this_before_that_pass_constraint', 'PassManager'] +__all__ = ["pass_result_wrapper", "this_before_that_pass_constraint", "PassManager"] + @compatibility(is_backward_compatible=False) def pass_result_wrapper(fn: Callable) -> Callable: @@ -46,6 +48,7 @@ def pass_result_wrapper(fn: Callable) -> Callable: return wrapped_fn + def _validate_pass_schedule_constraint( constraint: Callable[[Callable, Callable], bool], passes: List[Callable] ) -> None: @@ -59,6 +62,7 @@ def _validate_pass_schedule_constraint( f" list." ) + def _topological_sort_passes( passes: List[Callable], constraints: List[Callable] ) -> List[Callable]: @@ -75,7 +79,7 @@ def _topological_sort_passes( return passes # Contruct a graph mapping nodes to a list of their users - graph: Dict[Callable, List[Callable]] = {p : [] for p in passes} + graph: Dict[Callable, List[Callable]] = {p: [] for p in passes} indegree_map: Dict[Callable, int] = dict.fromkeys(passes, 0) candidates: Queue = Queue() for a in passes: @@ -108,11 +112,14 @@ def _topological_sort_passes( # Check if there are unvisited nodes (aka cycles in the graph) cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys())) if len(cycle_passes) != 0: - error = f"Circular dependency detected within the following passes: {cycle_passes}" + error = ( + f"Circular dependency detected within the following passes: {cycle_passes}" + ) raise RuntimeError(error) return sorted_passes + @compatibility(is_backward_compatible=False) def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable: """ @@ -123,9 +130,7 @@ def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable ``` passes = [pass_b, pass_a] - constraints = [ - this_before_that_pass_constraint(pass_a, pass_b) - ] + constraints = [this_before_that_pass_constraint(pass_a, pass_b)] ``` Args: @@ -231,7 +236,9 @@ class PassManager: sig = inspect.signature(check) if len(list(sig.parameters.values())) != 1: - raise TypeError("PassManager check function should only take in one variable, a module") + raise TypeError( + "PassManager check function should only take in one variable, a module" + ) setattr(self, "check", check) # noqa: B010 diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py index 6182972e670e..81f8a845e83f 100644 --- a/torch/fx/passes/net_min_base.py +++ b/torch/fx/passes/net_min_base.py @@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple import torch import torch.fx - from torch.fx._compatibility import compatibility from torch.fx.node import map_arg @@ -21,6 +20,7 @@ from .tools_common import ( Tensors, ) + __all__ = [ "FxNetMinimizerBadModuleError", "FxNetMinimizerRunFuncError", @@ -37,7 +37,6 @@ class FxNetMinimizerBadModuleError(Exception): """ - @compatibility(is_backward_compatible=False) class FxNetMinimizerRunFuncError(Exception): """ @@ -45,7 +44,6 @@ class FxNetMinimizerRunFuncError(Exception): """ - @compatibility(is_backward_compatible=False) class FxNetMinimizerResultMismatchError(Exception): """ @@ -53,7 +51,6 @@ class FxNetMinimizerResultMismatchError(Exception): """ - @dataclass class _MinimizerSettingBase: """ @@ -109,14 +106,9 @@ class _MinimizerBase: ], settings: _MinimizerSettingBase, module_exporter: Optional[ - Callable[ - [Tensors, torch.fx.GraphModule, str], - None - ] - ] = None, - exclusion_fn: Optional[ - 
Callable[[NodeList, int, int], None] + Callable[[Tensors, torch.fx.GraphModule, str], None] ] = None, + exclusion_fn: Optional[Callable[[NodeList, int, int], None]] = None, ): assert isinstance(module, torch.fx.GraphModule) @@ -159,14 +151,18 @@ class _MinimizerBase: self.a_outputs[name] = sample_input[i] self.b_outputs[name] = sample_input[i] - def run_a(self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1) -> TensorOrTensors: + def run_a( + self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1 + ) -> TensorOrTensors: """ Run `mod` with `inputs` and generate output. The output will be compared with output of run_b(). """ raise RuntimeError("run_a() is not implemented.") - def run_b(self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1) -> TensorOrTensors: + def run_b( + self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1 + ) -> TensorOrTensors: """ Run `mod` with `inputs` and generate output. The output will be compared with output of run_a(). @@ -323,7 +319,7 @@ class _MinimizerBase: split_module: torch.fx.GraphModule, submod_name: str, output_names: Names, - report_idx: int = -1 + report_idx: int = -1, ): """ Run the submodule in `split_module` that has name `submod_name` @@ -388,10 +384,14 @@ class _MinimizerBase: report.append(f"Result mismatch for {result_key}") if self.module_exporter: self.module_exporter( - a_input, submodule, str(result_key[0]) + "_cpu", # type: ignore[index] + a_input, + submodule, + str(result_key[0]) + "_cpu", # type: ignore[index] ) self.module_exporter( - b_input, submodule, str(result_key[0]) + "_acc", # type: ignore[index] + b_input, + submodule, + str(result_key[0]) + "_acc", # type: ignore[index] ) raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}") @@ -418,7 +418,7 @@ class _MinimizerBase: self.reports.append(report) report.append(f"Binary search iteration {self.iteration}") report.append( - f"From node index {start_idx}:{first_node_name} to {end_idx-1}:{output_node_name}. " + f"From node index {start_idx}:{first_node_name} to {end_idx - 1}:{output_node_name}. " f"Size of the interested node list is {len(nodes)}" ) cur_nodes: NodeSet = set(nodes) @@ -428,7 +428,6 @@ class _MinimizerBase: self._run_and_compare(split_module, submod_name, [output_node_name]) except (FxNetMinimizerRunFuncError, FxNetMinimizerResultMismatchError): - if len(nodes) == 1: report.append( f"This is the last node in the sub-module. " @@ -504,13 +503,13 @@ class _MinimizerBase: split_module, submod_name = self._build_submodule(cur_nodes) self._run_and_compare(split_module, submod_name, [node.name]) self.print_report(report) - except (FxNetMinimizerResultMismatchError): + except FxNetMinimizerResultMismatchError: culprits.add(node) report.append(f"Found culprit from numeric error: {node}") self.print_report(report) if not self.settings.find_all: return culprits - except (FxNetMinimizerRunFuncError): + except FxNetMinimizerRunFuncError: culprits.update(cur_nodes) report.append(f"Found culprit from run error: {node}") self.print_report(report) @@ -519,8 +518,9 @@ class _MinimizerBase: return culprits - - def _block_traverse_impl(self, nodes: NodeList, start_idx: int, end_idx: int, find_last_node: bool) -> int: + def _block_traverse_impl( + self, nodes: NodeList, start_idx: int, end_idx: int, find_last_node: bool + ) -> int: """ Recursive block search implementation. 
find_last_node: If True, search for the last node which results in a numerics difference @@ -529,7 +529,7 @@ class _MinimizerBase: report: List[str] = [] mid = (start_idx + end_idx) // 2 - cur_nodes_list: NodeList = nodes[:mid + 1] if find_last_node else nodes[mid:] + cur_nodes_list: NodeList = nodes[: mid + 1] if find_last_node else nodes[mid:] if self.exclusion_fn: self.exclusion_fn(cur_nodes_list, -1, -1) @@ -561,16 +561,20 @@ class _MinimizerBase: try: split_module, submod_name = self._build_submodule(cur_nodes) - self._run_and_compare(split_module, submod_name, [last_node_name], report_idx) + self._run_and_compare( + split_module, submod_name, [last_node_name], report_idx + ) except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError): - report.append(f"Culprits found from node {first_node_name} to {last_node_name}.") + report.append( + f"Culprits found from node {first_node_name} to {last_node_name}." + ) if start_idx == mid: report.extend( [ "This is the last node in the sub-module. ", "Search in the current branch is successful with node :", - f"{start_idx}, node name: {nodes[start_idx].name}." + f"{start_idx}, node name: {nodes[start_idx].name}.", ] ) self.print_report(report) @@ -585,9 +589,13 @@ class _MinimizerBase: if find_last_node: return self._block_traverse_impl(nodes, start_idx, mid, find_last_node) else: - return self._block_traverse_impl(nodes, mid + 1, end_idx, find_last_node) + return self._block_traverse_impl( + nodes, mid + 1, end_idx, find_last_node + ) else: - report.append(f"Culprits not found from node start to {mid}:{nodes[mid].name}.") + report.append( + f"Culprits not found from node start to {mid}:{nodes[mid].name}." + ) if start_idx == mid: report.extend( @@ -607,12 +615,15 @@ class _MinimizerBase: self.print_report(report) if find_last_node: - return self._block_traverse_impl(nodes, mid + 1, end_idx, find_last_node) + return self._block_traverse_impl( + nodes, mid + 1, end_idx, find_last_node + ) else: return self._block_traverse_impl(nodes, start_idx, mid, find_last_node) - - def _block_traverse(self, nodes: NodeList, find_last_node: Optional[bool]) -> NodeSet: + def _block_traverse( + self, nodes: NodeList, find_last_node: Optional[bool] + ) -> NodeSet: """ Traverse topologically sorted node list Find minimum block (start_idx, end_idx) which contains the culprit @@ -639,10 +650,7 @@ class _MinimizerBase: self.print_report(last_node_report) end_idx = self._block_traverse_impl(nodes, start_idx, end_idx, True) last_node_report.extend( - [ - "Finish Pass 1", - f"Find end_idx = {end_idx}:{nodes[end_idx].name}" - ] + ["Finish Pass 1", f"Find end_idx = {end_idx}:{nodes[end_idx].name}"] ) self.print_report(last_node_report) @@ -650,25 +658,28 @@ class _MinimizerBase: if run_both or not find_last_node: first_node_report = ["Start searching for first node in culprit"] self.print_report(first_node_report) - start_idx = self._block_traverse_impl(nodes[0:end_idx + 1], start_idx, end_idx, False) + start_idx = self._block_traverse_impl( + nodes[0 : end_idx + 1], start_idx, end_idx, False + ) first_node_report.append("*" * 50) self.reports.append(first_node_report) first_node_report.extend( [ "Finish Pass 2", - f"Find start_idx = {start_idx}:{nodes[start_idx].name}" + f"Find start_idx = {start_idx}:{nodes[start_idx].name}", ] ) self.print_report(first_node_report) # step 3: form module with minimum culprits - culprits.update(nodes[start_idx:end_idx + 1]) - result_report = [f"Finish searching, found minimum block ({nodes[start_idx]},{nodes[end_idx]})"] +
culprits.update(nodes[start_idx : end_idx + 1]) + result_report = [ + f"Finish searching, found minimum block ({nodes[start_idx]},{nodes[end_idx]})" + ] self.reports.append(result_report) self.print_report(result_report) return culprits - def _defined_traverse(self, nodes: NodeList) -> NodeSet: """ run user defined `nodes` and determine if it is a culprit. @@ -735,7 +746,9 @@ class _MinimizerBase: return culprits - def _skip_traverse_impl(self, all_nodes: NodeList, start_idx: int, end_idx: int) -> NodeSet: + def _skip_traverse_impl( + self, all_nodes: NodeList, start_idx: int, end_idx: int + ) -> NodeSet: """ Skip certain nodes in graph based on settings """ @@ -754,19 +767,19 @@ class _MinimizerBase: self.iteration += 1 report.append(f" Nodes block {self.iteration}.") report.append( - f"From node index {start_idx} to {end_idx-1}. " + f"From node index {start_idx} to {end_idx - 1}. " f"Size of the interested node list is {len(nodes)}" ) try: split_module, submod_name = self._build_submodule(cur_nodes) self._run_and_compare(split_module, submod_name, []) - except (FxNetMinimizerResultMismatchError): + except FxNetMinimizerResultMismatchError: culprits.update(cur_nodes) report.append(f"Found culprit from numeric error: {cur_nodes}") self.print_report(report) return culprits - except (FxNetMinimizerRunFuncError): + except FxNetMinimizerRunFuncError: culprits.update(cur_nodes) report.append(f"Found culprit from run error: {cur_nodes}") self.print_report(report) @@ -776,7 +789,6 @@ class _MinimizerBase: self.print_report(report) return set() - def _skip_traverse(self, all_nodes: NodeList, skip_nodes: List) -> NodeSet: """ Skip certain nodes in graph based on settings @@ -787,7 +799,7 @@ class _MinimizerBase: culprits = set() while idx < num_nodes: node = all_nodes[idx] - if (node.name in skip_nodes): # skip the node + if node.name in skip_nodes: # skip the node if idx > start_idx: culprits = self._skip_traverse_impl(all_nodes, start_idx, idx) start_idx = idx + 1 @@ -797,8 +809,6 @@ class _MinimizerBase: return culprits - - def _collect_nodes(self, start: Optional[str], end: Optional[str]) -> NodeList: """ Collect nodes in the model that between nodes with name of `start` and `end`. @@ -911,8 +921,10 @@ class _MinimizerBase: return self._accumulate_traverse(nodes) if self.settings.traverse_method == "skip": - if (skip_nodes is None): - raise RuntimeError("'skip_nodes' can't be None when 'traverse_method' is 'skip'.") + if skip_nodes is None: + raise RuntimeError( + "'skip_nodes' can't be None when 'traverse_method' is 'skip'." 
+ ) return self._skip_traverse(nodes, skip_nodes) if self.settings.traverse_method == "defined": diff --git a/torch/fx/passes/operator_support.py b/torch/fx/passes/operator_support.py index 57edabc0a55a..53e8be37cecf 100644 --- a/torch/fx/passes/operator_support.py +++ b/torch/fx/passes/operator_support.py @@ -5,11 +5,19 @@ import typing as t import torch import torch.fx from torch.fx._compatibility import compatibility + from .shape_prop import TensorMetadata -from .tools_common import get_node_target, CALLABLE_NODE_OPS +from .tools_common import CALLABLE_NODE_OPS, get_node_target -__all__ = ['OperatorSupportBase', 'OperatorSupport', 'create_op_support', 'chain', 'OpSupports', 'any_chain'] +__all__ = [ + "OperatorSupportBase", + "OperatorSupport", + "create_op_support", + "chain", + "OpSupports", + "any_chain", +] # fx.Node.target typename, as returned by `get_node_target()` TargetTypeName = str @@ -28,6 +36,7 @@ SupportDict = t.Mapping[TargetTypeName, SupportedArgumentDTypes] @compatibility(is_backward_compatible=False) class OperatorSupportBase(abc.ABC): """Interface for determining if a fx.Node is supported by a backend""" + @abc.abstractmethod def is_node_supported( self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node @@ -57,10 +66,7 @@ class OperatorSupport(OperatorSupportBase): _support_dict: SupportDict - def __init__( - self, - support_dict: t.Optional[SupportDict] = None - ): + def __init__(self, support_dict: t.Optional[SupportDict] = None): self._support_dict = support_dict or {} def is_node_supported( @@ -139,11 +145,13 @@ def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase `IsNodeSupported` has the same call signature as `OperatorSupportBase.is_node_supported` """ + class FunctionalOperatorSupport(OperatorSupportBase): def is_node_supported( - self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node + self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node ) -> bool: return is_node_supported(submodules, node) + return FunctionalOperatorSupport() @@ -153,11 +161,10 @@ def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase: instance by evaluating each input `OperatorSupportBase` instance, and returns False if any of it reports False. """ + def _chain(submods, node) -> bool: - return all( - x.is_node_supported(submods, node) - for x in op_support - ) + return all(x.is_node_supported(submods, node) for x in op_support) + return create_op_support(_chain) @@ -167,11 +174,10 @@ def any_chain(*op_support: OperatorSupportBase) -> OperatorSupportBase: instance by evaluating each input `OperatorSupportBase` instance, and returns True if any of it reports True. """ + def _any_chain(submods, node) -> bool: - return any( - x.is_node_supported(submods, node) - for x in op_support - ) + return any(x.is_node_supported(submods, node) for x in op_support) + return create_op_support(_any_chain) @@ -180,6 +186,7 @@ class OpSupports: """A set of atomic `OperatorSupportBase` instances that can be combined together to form more complex operator support logic. """ + @classmethod def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase: """Report a node as non-supported, if any of its arguments is of dtype""" @@ -193,6 +200,7 @@ class OpSupports: if arg_dtype == dtype: return False return True + return create_op_support(_decline_if_input_dtype) @classmethod @@ -200,16 +208,22 @@ class OpSupports: """ If a node has a name that is in the disallow set, reported it as non-supported. 
""" + def _decline_if_node_in_names( submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node, ) -> bool: return node.name not in disallow_set + return create_op_support(_decline_if_node_in_names) def _get_arg_dtype(arg: torch.fx.Node) -> t.Any: assert isinstance(arg, torch.fx.Node) tensor_meta = arg.meta.get("tensor_meta") # type: ignore[union-attr] - dtype = tensor_meta.dtype if isinstance(tensor_meta, TensorMetadata) else arg.meta["type"] + dtype = ( + tensor_meta.dtype + if isinstance(tensor_meta, TensorMetadata) + else arg.meta["type"] + ) return dtype diff --git a/torch/fx/passes/param_fetch.py b/torch/fx/passes/param_fetch.py index 5979e29fcc6b..3eba16b06b03 100644 --- a/torch/fx/passes/param_fetch.py +++ b/torch/fx/passes/param_fetch.py @@ -1,35 +1,59 @@ -from torch.fx.graph_module import GraphModule from typing import Any, Callable, Dict, List, Tuple, Type + import torch import torch.nn as nn - from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule + + +__all__ = [ + "default_matching", + "extract_attrs_for_lowering", + "lift_lowering_attrs_to_nodes", +] -__all__ = ['default_matching', 'extract_attrs_for_lowering', 'lift_lowering_attrs_to_nodes'] # Matching method matches the attribute name of current version to the attribute name of `target_version` @compatibility(is_backward_compatible=False) def default_matching(name: str, target_version: int) -> str: - """Default matching method - """ + """Default matching method""" return name + # This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering. # The first integer in the tuple is the version number of the nn.Module class when we create the parameter list. # If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module. 
module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]] = { torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching), torch.nn.modules.conv.Conv2d: ( - 1, ["weight", "bias", "kernel_size", "stride", "padding", "dilation", "groups", "padding_mode"], default_matching + 1, + [ + "weight", + "bias", + "kernel_size", + "stride", + "padding", + "dilation", + "groups", + "padding_mode", + ], + default_matching, + ), + torch.nn.modules.batchnorm.BatchNorm2d: ( + 2, + ["weight", "bias", "running_mean", "running_var", "eps"], + default_matching, ), - torch.nn.modules.batchnorm.BatchNorm2d: (2, ["weight", "bias", "running_mean", "running_var", "eps"], default_matching), torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching), torch.nn.modules.pooling.MaxPool2d: ( - 1, ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"], default_matching + 1, + ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"], + default_matching, ), torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching), } + @compatibility(is_backward_compatible=False) def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]: """If `mod` is in `module_fetch_book`, fetch the mod's attributes that in the `module_fetch_book` @@ -41,21 +65,25 @@ def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]: if type(mod) in module_fetch_book: version, param_to_fetch, matching_method = module_fetch_book[type(mod)] if version < mod._version: - raise RuntimeError(f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, " - "please upgrade the module_fetch_book, open an issue and @842974287 " - "or report a bug to AIACC team directly.") + raise RuntimeError( + f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, " + "please upgrade the module_fetch_book, open an issue and @842974287 " + "or report a bug to AIACC team directly." + ) for attr in param_to_fetch: attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version)) else: - raise RuntimeError(f"{torch.typename(mod)} is not in the module_fetch_book yet, " - "please add it to the module_fetch_book, open an issue and @842974287 " - "or report a bug to AIACC team directly.") + raise RuntimeError( + f"{torch.typename(mod)} is not in the module_fetch_book yet, " + "please add it to the module_fetch_book, open an issue and @842974287 " + "or report a bug to AIACC team directly." + ) return attrs_for_lowering + @compatibility(is_backward_compatible=False) def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None: - """Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module. 
- """ + """Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module.""" submodules = dict(fx_module.named_modules()) for node in fx_module.graph.nodes: @@ -63,4 +91,6 @@ def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None: if isinstance(submodules[node.target], GraphModule): lift_lowering_attrs_to_nodes(submodules[node.target]) else: - node.attrs_for_lowering = extract_attrs_for_lowering(submodules[node.target]) + node.attrs_for_lowering = extract_attrs_for_lowering( + submodules[node.target] + ) diff --git a/torch/fx/passes/pass_manager.py b/torch/fx/passes/pass_manager.py index d8d9e79a95e6..eb793aa6f11e 100644 --- a/torch/fx/passes/pass_manager.py +++ b/torch/fx/passes/pass_manager.py @@ -1,8 +1,9 @@ # mypy: allow-untyped-defs +import logging from functools import wraps from inspect import unwrap from typing import Callable, List, Optional -import logging + logger = logging.getLogger(__name__) @@ -15,6 +16,7 @@ __all__ = [ "these_before_those_pass_constraint", ] + # for callables which modify object inplace and return something other than # the object on which they act def inplace_wrapper(fn: Callable) -> Callable: @@ -36,6 +38,7 @@ def inplace_wrapper(fn: Callable) -> Callable: return wrapped_fn + def log_hook(fn: Callable, level=logging.INFO) -> Callable: """ Logs callable output. @@ -48,16 +51,13 @@ def log_hook(fn: Callable, level=logging.INFO) -> Callable: ``` def my_pass(d: Dict) -> bool: changed = False - if 'foo' in d: - d['foo'] = 'bar' + if "foo" in d: + d["foo"] = "bar" changed = True return changed - pm = PassManager( - passes=[ - inplace_wrapper(log_hook(my_pass)) - ] - ) + + pm = PassManager(passes=[inplace_wrapper(log_hook(my_pass))]) ``` Args: @@ -67,6 +67,7 @@ def log_hook(fn: Callable, level=logging.INFO) -> Callable: Returns: wrapped_fn (Callable[Type1, Type2]) """ + @wraps(fn) def wrapped_fn(gm): val = fn(gm) @@ -76,8 +77,11 @@ def log_hook(fn: Callable, level=logging.INFO) -> Callable: return wrapped_fn - -def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None): +def loop_pass( + base_pass: Callable, + n_iter: Optional[int] = None, + predicate: Optional[Callable] = None, +): """ Convenience wrapper for passes which need to be applied multiple times. 
@@ -154,9 +158,7 @@ def these_before_those_pass_constraint(these: Callable, those: Callable): loop_pass(pass_a, 5), ] - constraints = [ - these_before_those_pass_constraint(pass_a, pass_b) - ] + constraints = [these_before_those_pass_constraint(pass_a, pass_b)] ``` Args: diff --git a/torch/fx/passes/reinplace.py b/torch/fx/passes/reinplace.py index c18a5ce4f570..3b61446a92f7 100644 --- a/torch/fx/passes/reinplace.py +++ b/torch/fx/passes/reinplace.py @@ -1,32 +1,38 @@ # mypy: allow-untyped-defs +import _operator +import itertools +from collections import defaultdict +from enum import Enum +from typing import Dict, Set + import torch +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode from torch.fx import Node from torch.fx._compatibility import compatibility -from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor -from torch.utils._pytree import tree_map_only -from torch.utils import _pytree as pytree from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils import _pytree as pytree +from torch.utils._pytree import tree_map_only -import _operator -from enum import Enum -import itertools -from typing import Set, Dict -from collections import defaultdict -__all__ = ['reinplace'] +__all__ = ["reinplace"] + class _ViewType(Enum): NonView = 0 SingleOutputView = 1 MultiOutputView = 2 + def _is_view_op(tgt): if tgt is not None and isinstance(tgt, torch._ops.OpOverload): schema = tgt._schema if len(schema.arguments) > 0: first_arg = schema.arguments[0] # check if op is a view - return first_arg.alias_info is not None and not first_arg.alias_info.is_write + return ( + first_arg.alias_info is not None and not first_arg.alias_info.is_write + ) + def _get_view_type(tgt) -> _ViewType: if tgt is not None and isinstance(tgt, torch._ops.OpOverload): @@ -36,7 +42,7 @@ def _get_view_type(tgt) -> _ViewType: # check if op is a view if first_arg.alias_info is not None and not first_arg.alias_info.is_write: # check if op is a multi-output view - if '*' in first_arg.alias_info.after_set: + if "*" in first_arg.alias_info.after_set: return _ViewType.MultiOutputView else: return _ViewType.SingleOutputView @@ -54,12 +60,11 @@ def _get_view_type(tgt) -> _ViewType: # to sanity check that our aliasing information is correct. @compatibility(is_backward_compatible=False) class _FunctionalizationMetadataProp(torch.fx.Interpreter): - def run_node(self, node: Node): self.node_counter += 1 result = super().run_node(node) - node.meta['fake_result'] = result - node.meta['node_idx'] = self.node_counter + node.meta["fake_result"] = result + node.meta["node_idx"] = self.node_counter # (1) Update metadata with the list of nodes that are used by this node # copy_() doesn't read from its first argument; it writes to it, overwriting previous data. @@ -69,11 +74,11 @@ class _FunctionalizationMetadataProp(torch.fx.Interpreter): node_args = node_args[1:] # (2) Update metadata to track aliasing information about view tensor nodes. - if node.op == 'call_function': + if node.op == "call_function": view_type = _get_view_type(node.target) if view_type == _ViewType.SingleOutputView: assert isinstance(node.args[0], Node) - node.meta['view_of'] = node.args[0] + node.meta["view_of"] = node.args[0] elif view_type == _ViewType.MultiOutputView: self.multi_output_view_nodes[node] = node.args[0] @@ -95,38 +100,52 @@ class _FunctionalizationMetadataProp(torch.fx.Interpreter): # Note: we could also track indexing info here for multi-output views. 
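An aside on the `alias_info` checks above: the same view test that `_is_view_op` and `_get_view_type` perform can be run by hand against any `OpOverload` schema (the op choices below are arbitrary):

```
import torch


def is_view_like(op):
    # Mirrors _is_view_op: the first argument aliases the output without being written.
    first = op._schema.arguments[0]
    return first.alias_info is not None and not first.alias_info.is_write


print(is_view_like(torch.ops.aten.diagonal.default))  # True: returns a view of self
print(is_view_like(torch.ops.aten.add.Tensor))        # False: no aliasing
```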
# I don't think this metadata is strictly needed for de-functionalization. assert isinstance(maybe_base_of_view, Node) - node.meta['view_of'] = maybe_base_of_view + node.meta["view_of"] = maybe_base_of_view - if 'view_of' in node.meta: + if "view_of" in node.meta: # We're linking the current node with its first argument as views. # Assert here that this is actually the case, and their storages are the same. - assert isinstance(node.meta['fake_result'], FakeTensor) - assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor) - view_storage = StorageWeakRef(node.meta['fake_result']._typed_storage()) - base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result']._typed_storage()) + assert isinstance(node.meta["fake_result"], FakeTensor) + assert isinstance(node.meta["view_of"].meta["fake_result"], FakeTensor) + view_storage = StorageWeakRef(node.meta["fake_result"]._typed_storage()) + base_storage = StorageWeakRef( + node.meta["view_of"].meta["fake_result"]._typed_storage() + ) assert view_storage == base_storage return result - - def propagate(self, *args): self.multi_output_view_nodes = {} self.node_counter = -1 with FakeTensorMode() as mode: - fake_args = [mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args] + fake_args = [ + mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args + ] return super().run(*fake_args) + def _schemas_match(functional_schema, inplace_schema): - names_match = inplace_schema.name.endswith("_") and inplace_schema.name[:-1] == functional_schema.name - arg_types_match = len(functional_schema.arguments) == len(inplace_schema.arguments) and all( - a1.type == a2.type for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments)) + names_match = ( + inplace_schema.name.endswith("_") + and inplace_schema.name[:-1] == functional_schema.name + ) + arg_types_match = len(functional_schema.arguments) == len( + inplace_schema.arguments + ) and all( + a1.type == a2.type + for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments) + ) # for the inplace op, its first argument should be mutable - assert inplace_schema.arguments[0].alias_info is not None and inplace_schema.arguments[0].alias_info.is_write + assert ( + inplace_schema.arguments[0].alias_info is not None + and inplace_schema.arguments[0].alias_info.is_write + ) # and its remaining arguments shouldn't be. assert all(a.alias_info is None for a in inplace_schema.arguments[1:]) return names_match and arg_types_match + # TODO: this should be beefed up to be able to properly re-inplace with: # - mutating ops (e.g. _fused_moving_avg_obs_fq_helper) # - out= ops (e.g. 
angle -> angle.out) @@ -143,17 +162,20 @@ def _maybe_get_inplace_op(op): op_namespace = op.__module__.split(".")[-1] op_base_name = op.overloadpacket.__name__ maybe_namespace_module = getattr(torch.ops, op_namespace) - maybe_inplace_op = None if maybe_namespace_module is None else getattr(maybe_namespace_module, f'{op_base_name}_', None) + maybe_inplace_op = ( + None + if maybe_namespace_module is None + else getattr(maybe_namespace_module, f"{op_base_name}_", None) + ) if maybe_inplace_op is None: return None inplace_overloads = [ - getattr(maybe_inplace_op, overload_name) for overload_name in maybe_inplace_op.overloads() + getattr(maybe_inplace_op, overload_name) + for overload_name in maybe_inplace_op.overloads() ] inplace_overloads_with_matching_schemas = [ - f - for f in inplace_overloads - if _schemas_match(op._schema, f._schema) + f for f in inplace_overloads if _schemas_match(op._schema, f._schema) ] # Just because foo() and foo_() are both existing operators, # They aren't guaranteed to have compatible schemas. @@ -165,6 +187,7 @@ def _maybe_get_inplace_op(op): inplace_op = inplace_overloads_with_matching_schemas[0] return inplace_op + _VIEW_INVERSE_MAP = { torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default, torch.ops.aten.select_scatter.default: torch.ops.aten.select.int, @@ -172,6 +195,7 @@ _VIEW_INVERSE_MAP = { torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default, } + # This function, given a set of set of (aliased) tensor nodes, # Returns any nodes in the graph that *use* any of the aliases, that occur *after* op_index # in the node ordering. @@ -186,17 +210,21 @@ def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int): usage_nodes = t.users for n in usage_nodes: # We only care about usages after the current node - if 'node_idx' not in n.meta or n.meta['node_idx'] <= op_index: + if "node_idx" not in n.meta or n.meta["node_idx"] <= op_index: continue # We also don't care about intermediate view ops. # They only matter if their output is then used elsewhere # (either in an out-of-place op, or as an output to the function). if n in tensor_aliases: - if isinstance(n.target, torch._ops.OpOverload) or n.target == _operator.getitem: + if ( + isinstance(n.target, torch._ops.OpOverload) + or n.target == _operator.getitem + ): continue nodes_used_after.add(n) return nodes_used_after + # Given an op that we're trying to re-inplace, "b = foo(a)", # And given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)" # Then re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF: @@ -204,23 +232,27 @@ def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int): # (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base" # (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata # as "alias" -def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Set[Node]) -> Set[Node]: +def _get_view_inverse_node_usages( + later_node_usages: Set[Node], self_aliases: Set[Node] +) -> Set[Node]: def matching_view_metadata(a, b): - return a.size() == b.size() and \ - a.stride() == b.stride() and \ - a.storage_offset() == b.storage_offset() + return ( + a.size() == b.size() + and a.stride() == b.stride() + and a.storage_offset() == b.storage_offset() + ) view_inverse_nodes = set() # Go through them in node order, so we can see chains of view_scatter ops. 
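To make `_schemas_match` concrete, the canonical pair it accepts is a functional op and its trailing-underscore in-place variant with identical argument types, e.g. `aten.add` / `aten.add_`:

```
import torch

functional = torch.ops.aten.add.Tensor._schema
inplace = torch.ops.aten.add_.Tensor._schema

names_match = inplace.name.endswith("_") and inplace.name[:-1] == functional.name
arg_types_match = len(functional.arguments) == len(inplace.arguments) and all(
    a1.type == a2.type for a1, a2 in zip(functional.arguments, inplace.arguments)
)
print(names_match and arg_types_match)  # True: add_ is a valid in-place variant of add
```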
- for n in sorted(later_node_usages, key=lambda x: x.meta['node_idx']): + for n in sorted(later_node_usages, key=lambda x: x.meta["node_idx"]): if n.target not in _VIEW_INVERSE_MAP: continue base = n.args[0] mutated_view = n.args[1] assert isinstance(base, Node) - assert isinstance(base.meta['fake_result'], FakeTensor) + assert isinstance(base.meta["fake_result"], FakeTensor) assert isinstance(mutated_view, Node) - assert isinstance(mutated_view.meta['fake_result'], FakeTensor) + assert isinstance(mutated_view.meta["fake_result"], FakeTensor) # Check that this view_inverse op actually corresponds to taking the inverse # of one of our existing self_alias nodes. original_view = _VIEW_INVERSE_MAP[n.target] @@ -229,18 +261,21 @@ def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Se # that was created from some op `alias = foo(base, args...)` # such that the current _scatter op "inverts" that foo call. # We can check that by running the original op again, and checking that the strides match. - if 'view_of' not in self_alias.meta: + if "view_of" not in self_alias.meta: continue - self_alias_base = self_alias.meta['view_of'] + self_alias_base = self_alias.meta["view_of"] try: # Here we're trying to re-use the args from the view_scatter call inside of the corresponding # view op, which might throw. This just indicates that the view_scatter op isn't a valid inverse # of the current alias we're looking at. - view_replay_metadata = original_view(self_alias_base.meta['fake_result'], *n.args[2:], **n.kwargs) - expected_metadata = self_alias.meta['fake_result'] + view_replay_metadata = original_view( + self_alias_base.meta["fake_result"], *n.args[2:], **n.kwargs + ) + expected_metadata = self_alias.meta["fake_result"] # If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace. - if matching_view_metadata(self_alias_base.meta['fake_result'], base.meta['fake_result']) and \ - matching_view_metadata(view_replay_metadata, expected_metadata): + if matching_view_metadata( + self_alias_base.meta["fake_result"], base.meta["fake_result"] + ) and matching_view_metadata(view_replay_metadata, expected_metadata): view_inverse_nodes.add(n) except Exception: continue @@ -471,25 +506,29 @@ def reinplace(gm, *sample_args): # NOTE: later, we'll need to add an optimization for fully recovering performance # on programs that mutate inputs. input_storages = { - StorageWeakRef( - node.meta['fake_result']._typed_storage() - ) for node in gm.graph.nodes if (node.op == 'placeholder' and isinstance(node.meta['fake_result'], torch.Tensor))} + StorageWeakRef(node.meta["fake_result"]._typed_storage()) + for node in gm.graph.nodes + if ( + node.op == "placeholder" + and isinstance(node.meta["fake_result"], torch.Tensor) + ) + } # We also need to know for a given node, what are all of its aliasing nodes. storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set) for n in gm.graph.nodes: - if 'fake_result' in n.meta: + if "fake_result" in n.meta: # Tree-mapping because some ops can return lists of tensors. def _add_to_map(x): if isinstance(x, FakeTensor): storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n) - pytree.tree_map_(_add_to_map, n.meta['fake_result']) + + pytree.tree_map_(_add_to_map, n.meta["fake_result"]) # inplace-ify functional ops, subject to the constraints written below.
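Putting the pieces together before the main loop below, a minimal end-to-end sketch of the pass. The `clone` gives the mutated buffer fresh storage, so step 1b (no re-inplacing onto program inputs) does not veto the rewrite; exact node-level results still depend on the constraints described above.

```
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.reinplace import reinplace


def f(x):
    a = x.clone()  # fresh storage, not aliased to the program input
    return torch.ops.aten.add.Tensor(a, 1)


sample = torch.randn(4)
gm = reinplace(make_fx(f)(sample), sample)
print(gm.graph)  # the add.Tensor call should now show up as add_.Tensor
```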
all_later_view_inverse_nodes_to_delete = set() for node in gm.graph.nodes: - if node.op == 'call_function': - + if node.op == "call_function": # Today, the re-inplace pass on directly acts on: # - functional ops with an inplace variant # - {view}_scatter ops that can be potentially removed from the graph. @@ -512,8 +551,8 @@ def reinplace(gm, *sample_args): # (We could potentially swizzle this into larger_tensor.add_(scalar_tensor), # this is probably an optimization to revisit later). self_arg = node.args[0] - self_flattened = pytree.tree_leaves(self_arg.meta['fake_result']) - node_flattened = pytree.tree_leaves(node.meta['fake_result']) + self_flattened = pytree.tree_leaves(self_arg.meta["fake_result"]) + node_flattened = pytree.tree_leaves(node.meta["fake_result"]) self_has_wrong_metadata = False if len(self_flattened) == len(node_flattened): for self_meta, node_meta in zip(self_flattened, node_flattened): @@ -532,7 +571,9 @@ def reinplace(gm, *sample_args): continue # Step 1b: ensure that the op we're trying to re-inplace isn't a program input - self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage()) + self_arg_storage = StorageWeakRef( + self_arg.meta["fake_result"]._typed_storage() + ) if self_arg_storage in input_storages: # TODO: later, add the optimization for handling `copy_()` calls in the graph. continue @@ -542,14 +583,20 @@ def reinplace(gm, *sample_args): # so we prevent re-inplacing in this case. continue - self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage()) + self_arg_storage = StorageWeakRef( + self_arg.meta["fake_result"]._typed_storage() + ) self_aliases = storage_to_nodes[self_arg_storage] # First, we find all later usages of any of the aliases of self_arg. - later_node_usages = _get_all_later_node_usages(self_aliases, node.meta['node_idx']) + later_node_usages = _get_all_later_node_usages( + self_aliases, node.meta["node_idx"] + ) # Then, we check if any of those later usages are actually view_scatter ops # that are safe to fully remove. - later_view_inverse_node_usages = _get_view_inverse_node_usages(later_node_usages, self_aliases) + later_view_inverse_node_usages = _get_view_inverse_node_usages( + later_node_usages, self_aliases + ) # Step 2: Check to see if the input to the op is re-used later in the graph. # If not (same goes for its aliases), then this op is safe to re-in place. @@ -565,7 +612,10 @@ def reinplace(gm, *sample_args): # we would prefer to remove it from the graph entirely, # and instead copy_() the slice directly into the larger tensor. # See the description of the algorithm for a full example. - if node.target in _VIEW_INVERSE_MAP and node not in all_later_view_inverse_nodes_to_delete: + if ( + node.target in _VIEW_INVERSE_MAP + and node not in all_later_view_inverse_nodes_to_delete + ): view_op = _VIEW_INVERSE_MAP[node.target] # Before: # base_updated = torch.ops.aten.slice_scatter.default(base, mutated_slice, args...) 
@@ -576,13 +626,23 @@ mutated_slice_node = node.args[1] remaining_slice_args = node.args[2:] slice_node = gm.graph.create_node( - 'call_function', view_op, (self_arg,) + tuple(remaining_slice_args), node.kwargs) + "call_function", + view_op, + (self_arg,) + tuple(remaining_slice_args), + node.kwargs, + ) gm.graph.create_node( - 'call_function', torch.ops.aten.copy_.default, (slice_node, mutated_slice_node,), {}) + "call_function", + torch.ops.aten.copy_.default, + ( + slice_node, + mutated_slice_node, + ), + {}, + ) # Add the slice_scatter node to our "nodes to delete" list. all_later_view_inverse_nodes_to_delete.add(node) - else: # Step 3b: Check to see if this operator has an inplace variant. maybe_inplace_op = _maybe_get_inplace_op(node.target) @@ -597,19 +657,29 @@ def reinplace(gm, *sample_args): # Hmm... morally I think we also want to keep the `fake_result` metadata # up to date here, but I'm not sure how easy it is to do. # Maybe it's fine to wait until the end of the pass to update it. - curr_node_storage = StorageWeakRef(node.meta['fake_result']._typed_storage()) - storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage]) - storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage]) + curr_node_storage = StorageWeakRef( + node.meta["fake_result"]._typed_storage() + ) + storage_to_nodes[self_arg_storage].update( + storage_to_nodes[curr_node_storage] + ) + storage_to_nodes[curr_node_storage].update( + storage_to_nodes[self_arg_storage] + ) # Need to remember the view_scatter view nodes we found so we can remove them later. - all_later_view_inverse_nodes_to_delete.update(later_view_inverse_node_usages) + all_later_view_inverse_nodes_to_delete.update( + later_view_inverse_node_usages + ) # Step 4: # Now that we've replaced b = a.foo() with a.foo_(), # We need to replace any later usages of "b" with "a" for old in itertools.chain([node], later_view_inverse_node_usages): new = old.args[0] - nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']] + nodes_to_update = [ + n for n in old.users if n.meta["node_idx"] > node.meta["node_idx"] + ] for node_to_update in nodes_to_update: def replace_arg(a): @@ -618,21 +688,29 @@ def reinplace(gm, *sample_args): return a # First, replace usages of "b" with "a" - node_to_update.args = tree_map_only(Node, replace_arg, node_to_update.args) - node_to_update.kwargs = tree_map_only(Node, replace_arg, node_to_update.kwargs) + node_to_update.args = tree_map_only( + Node, replace_arg, node_to_update.args + ) + node_to_update.kwargs = tree_map_only( + Node, replace_arg, node_to_update.kwargs + ) # Second, update our storage_to_nodes data structure. - old_flattened_res = pytree.tree_leaves(old.meta['fake_result']) - node_flattened_res = pytree.tree_leaves(node_to_update.meta['fake_result']) + old_flattened_res = pytree.tree_leaves(old.meta["fake_result"]) + node_flattened_res = pytree.tree_leaves( + node_to_update.meta["fake_result"] + ) old_res_storage = { - StorageWeakRef( - x._typed_storage() - ) for x in old_flattened_res if isinstance(x, FakeTensor)} + StorageWeakRef(x._typed_storage()) + for x in old_flattened_res + if isinstance(x, FakeTensor) + } node_res_storage = { - StorageWeakRef( - x._typed_storage() - ) for x in node_flattened_res if isinstance(x, FakeTensor)} + StorageWeakRef(x._typed_storage()) + for x in node_flattened_res + if isinstance(x, FakeTensor) + } # This will happen if we're updating a view op, e.g. # e.g.
replacing @@ -644,12 +722,17 @@ def reinplace(gm, *sample_args): # We're checking for len(...) == 1 here because all view ops are guaranteed to return either a single tensor, # or multiple tensors that all share the same storage. # We can't just check equality because we might encounter FX nodes that return zero tensor outputs. - if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage: - new_flattened_res = pytree.tree_leaves(new.meta['fake_result']) + if ( + len(old_res_storage) == 1 + and len(node_res_storage) == 1 + and old_res_storage == node_res_storage + ): + new_flattened_res = pytree.tree_leaves(new.meta["fake_result"]) new_res_storage = { - StorageWeakRef( - x._typed_storage() - ) for x in new_flattened_res if isinstance(x, FakeTensor)} + StorageWeakRef(x._typed_storage()) + for x in new_flattened_res + if isinstance(x, FakeTensor) + } assert len(new_res_storage) == 1 (new_ref,) = new_res_storage (node_ref,) = node_res_storage @@ -666,6 +749,5 @@ def reinplace(gm, *sample_args): for to_delete in all_later_view_inverse_nodes_to_delete: gm.graph.erase_node(to_delete) - gm.recompile() return gm diff --git a/torch/fx/passes/shape_prop.py b/torch/fx/passes/shape_prop.py index dcaee3f82113..4931e840707e 100644 --- a/torch/fx/passes/shape_prop.py +++ b/torch/fx/passes/shape_prop.py @@ -1,17 +1,19 @@ # mypy: ignore-errors +import traceback +from typing import Any, Dict, NamedTuple, Optional, Tuple + import torch import torch.fx -import traceback - from torch._dispatch.python import enable_python_dispatcher -from torch.fx.node import Node, map_aggregate -from typing import Any, Tuple, NamedTuple, Optional, Dict -from torch.fx._compatibility import compatibility from torch._guards import detect_fake_mode from torch._subclasses.meta_utils import is_sparse_any +from torch.fx._compatibility import compatibility +from torch.fx.node import map_aggregate, Node + + +__all__ = ["TensorMetadata", "ShapeProp"] -__all__ = ['TensorMetadata', 'ShapeProp'] @compatibility(is_backward_compatible=True) class TensorMetadata(NamedTuple): @@ -19,17 +21,20 @@ class TensorMetadata(NamedTuple): # about a tensor within a PyTorch program. # General Tensor metadata - shape : torch.Size - dtype : torch.dtype - requires_grad : bool - stride : Tuple[int, ...] - memory_format : Optional[torch.memory_format] + shape: torch.Size + dtype: torch.dtype + requires_grad: bool + stride: Tuple[int, ...] + memory_format: Optional[torch.memory_format] # Quantization metadata - is_quantized : bool + is_quantized: bool qparams: Dict[str, Any] -def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> TensorMetadata: + +def _extract_tensor_metadata( + result: torch.Tensor, include_contiguity=True +) -> TensorMetadata: """ Extract a TensorMetadata NamedTuple describing `result`. 
""" @@ -59,7 +64,11 @@ def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}: qparams["scale"] = result.q_scale() # type: ignore[assignment] qparams["zero_point"] = result.q_zero_point() # type: ignore[assignment] - elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}: + elif qscheme in { + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + torch.per_channel_symmetric, + }: # In this branch, scale and zero_point are expected to be tensors, # we store the values as immutable_list in TensorMetadata for # easier serialization downstream @@ -68,7 +77,9 @@ def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment] return TensorMetadata( - shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams) + shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams + ) + @compatibility(is_backward_compatible=True) class ShapeProp(torch.fx.Interpreter): @@ -117,12 +128,14 @@ class ShapeProp(torch.fx.Interpreter): fake_mode (FakeTensorMode): A fake mode for copying the gm """ + def __init__(self, gm, fake_mode=None): super().__init__(gm) if fake_mode is None: fake_mode = detect_fake_mode() if fake_mode is not None: from torch._dynamo.utils import deepcopy_to_fake_tensor + # Note: # We need fake execution cause the inputs are fake, however, we cannot fakify the module # - because we need to write to the tensor_meta of the real module. So we fakify to @@ -140,7 +153,7 @@ class ShapeProp(torch.fx.Interpreter): self.real_module = self.module - def run_node(self, n : Node) -> Any: + def run_node(self, n: Node) -> Any: try: if self.fake_module is not None: # Hacky swap. 
Alternatively, we could do this with overriding @@ -157,8 +170,7 @@ class ShapeProp(torch.fx.Interpreter): except Exception as e: traceback.print_exc() raise RuntimeError( - f"ShapeProp error for: node={n.format_node()} with " - f"meta={n.meta}" + f"ShapeProp error for: node={n.format_node()} with " f"meta={n.meta}" ) from e found_tensor = False @@ -173,9 +185,9 @@ class ShapeProp(torch.fx.Interpreter): meta = map_aggregate(result, extract_tensor_meta) if found_tensor: - n.meta['tensor_meta'] = meta + n.meta["tensor_meta"] = meta - n.meta['type'] = type(result) + n.meta["type"] = type(result) return result def propagate(self, *args): @@ -190,7 +202,10 @@ class ShapeProp(torch.fx.Interpreter): Any: The value returned from executing the Module """ if self.fake_mode is not None: - fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args] + fake_args = [ + self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t + for t in args + ] else: fake_args = args return super().run(*fake_args) diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py index 8c37dcf37a3a..19709978248d 100644 --- a/torch/fx/passes/split_module.py +++ b/torch/fx/passes/split_module.py @@ -1,19 +1,20 @@ # mypy: allow-untyped-defs import inspect -from typing import Any, Callable, Dict, List, Optional, Set -from collections import OrderedDict import logging +from collections import OrderedDict +from typing import Any, Callable, Dict, List, Optional, Set import torch from torch.fx._compatibility import compatibility +from torch.fx._utils import lazy_format_graph_code from torch.fx.graph_module import GraphModule from torch.fx.node import Node -from torch.fx._utils import lazy_format_graph_code __all__ = ["Partition", "split_module"] log = _LOGGER = logging.getLogger(__name__) + @compatibility(is_backward_compatible=True) class Partition: def __init__(self, name: str): @@ -146,9 +147,7 @@ def split_module( log.debug( "%s", - lazy_format_graph_code( - "pre split_module", m, colored=True - ), + lazy_format_graph_code("pre split_module", m, colored=True), ) def construct_graph( @@ -161,11 +160,20 @@ def split_module( node.args[0] if len(node.args) > 0 else inspect.Signature.empty ) if keep_original_node_name: - args = () if default_value is inspect.Signature.empty else (default_value,) - base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type) # type: ignore[arg-type] + args = ( + () if default_value is inspect.Signature.empty else (default_value,) + ) + base_mod_env[node.name] = base_mod_graph.create_node( + "placeholder", + node.name, + args=args, # type: ignore[arg-type] + type_expr=node.type, + ) else: base_mod_env[node.name] = base_mod_graph.placeholder( - node.target, type_expr=node.type, default_value=default_value # type: ignore[arg-type] + node.target, # type: ignore[arg-type] + type_expr=node.type, + default_value=default_value, ) base_mod_env[node.name].meta = node.meta.copy() elif node.op == "get_attr": @@ -185,9 +193,7 @@ def split_module( orig_nodes: Dict[str, Node] = {} symbol_to_node: Dict[sympy.Symbol, Node] = {} - def record_cross_partition_use( - def_node: Node, use_node: Optional[Node] - ): # noqa: B950 + def record_cross_partition_use(def_node: Node, use_node: Optional[Node]): from torch.fx.experimental.symbolic_shapes import free_symbols defined = getattr(def_node, "_fx_partition", None) @@ -195,7 +201,10 @@ def split_module( log.debug( "record_cross_partition_use %s (%s) %s (%s)", - 
def_node.name, defined, use_node.name if use_node is not None else "-", used + def_node.name, + defined, + use_node.name if use_node is not None else "-", + used, ) if defined != used: @@ -234,7 +243,9 @@ def split_module( def instantiate_node_partition_mapping(node): partition_name = str(split_callback(node)) - log.debug("instantiate_node_partition_mapping %s (%s)", node.name, partition_name) + log.debug( + "instantiate_node_partition_mapping %s (%s)", node.name, partition_name + ) # add node to partitions partition = partitions.get(partition_name) @@ -249,7 +260,7 @@ def split_module( GLOBAL_STATE_NODES = [ torch.amp._enter_autocast, torch.amp._exit_autocast, - torch._C._set_grad_enabled + torch._C._set_grad_enabled, ] # For grad regions: @@ -280,10 +291,10 @@ def split_module( # rely on later, but this needs some extra work. Quick fix first. # See https://github.com/pytorch/pytorch/issues/130534 if ( - (val := node.meta.get("example_value")) is not None and - isinstance(val, torch.SymInt) and - isinstance(s0 := val.node.expr, sympy.Symbol) and - s0 not in symbol_to_node + (val := node.meta.get("example_value")) is not None + and isinstance(val, torch.SymInt) + and isinstance(s0 := val.node.expr, sympy.Symbol) + and s0 not in symbol_to_node ): symbol_to_node[val.node.expr] = node @@ -344,9 +355,10 @@ def split_module( if assert_monotonically_increasing: pid = split_callback(node) - assert highest_partition <= pid, \ - ("autocast or set_grad_enabled require monotonically increasing partitions:" - f"highest: {highest_partition}, this node's: {pid}") + assert highest_partition <= pid, ( + "autocast or set_grad_enabled require monotonically increasing partitions:" + f"highest: {highest_partition}, this node's: {pid}" + ) highest_partition = pid # do not capture cross-partition dependencies for global state nodes as they will be @@ -392,7 +404,9 @@ def split_module( kwargs={}, type_expr=node.type, ) - new_node.meta = node.meta.copy() # is it really a good idea to copy this? + new_node.meta = ( + node.meta.copy() + ) # is it really a good idea to copy this? partition.environment[node] = new_node # add placeholders to partition inputs @@ -425,7 +439,9 @@ def split_module( target_attr = m for atom in target_atoms: if not hasattr(target_attr, atom): - raise AttributeError(f"Operator target {node.target} not found!") + raise AttributeError( + f"Operator target {node.target} not found!" + ) target_attr = getattr(target_attr, atom) # target = target_atoms[-1] target = "_".join(target_atoms) @@ -467,7 +483,9 @@ def split_module( kwargs={}, type_expr=exit_node.type, ) - new_node.meta = exit_node.meta.copy() # is it really a good idea to copy this? + new_node.meta = ( + exit_node.meta.copy() + ) # is it really a good idea to copy this? 
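Pausing on the ShapeProp hunks further up: the changes there are formatting-only, and the reformatted run_node still records a TensorMetadata under node.meta["tensor_meta"] and the Python type under node.meta["type"] for every executed node. A minimal sketch of driving the pass (the traced function and shapes are illustrative)::

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.passes.shape_prop import ShapeProp


    def f(x):
        return (x + 1).relu()


    gm = symbolic_trace(f)
    # Execute the graph on a sample input; each node gets shape/dtype metadata.
    ShapeProp(gm).propagate(torch.randn(2, 3))
    for node in gm.graph.nodes:
        tm = node.meta.get("tensor_meta", None)
        print(node.name, node.meta["type"], getattr(tm, "shape", None))

When a FakeTensorMode is detected (or passed in), propagate first converts the sample args to fake tensors, exactly as the reformatted list comprehension in the diff shows.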
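The split_module hunks above and below are likewise pure re-wraps. For orientation, the pass maps each compute node to a partition id via split_callback and emits one submodule per id; the two-partition policy below is a toy invented for illustration (real callbacks inspect node.op and node.target)::

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.passes.split_module import split_module


    class Net(torch.nn.Module):
        def forward(self, x):
            a = x + 1
            b = a * 2
            return b - 3


    gm = symbolic_trace(Net())


    def split_callback(node):
        # Toy policy: the add goes to partition 0, everything else to 1.
        return 0 if node.name in ("x", "add") else 1


    split = split_module(gm, Net(), split_callback)
    print(split)  # wraps submod_0 and submod_1, named after the partition ids

    x = torch.randn(4)
    assert torch.equal(split(x), gm(x))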
# original module environment dict mapping node names to nodes orig_mod_env: Dict[str, Node] = {} @@ -520,7 +538,9 @@ def split_module( if keep_original_order: # first get the attr nodes required by this partition orig_mod_attr_nodes: List[Node] = [ - orig_mod_env[key] for key in partition.inputs if key not in original_order + orig_mod_env[key] + for key in partition.inputs + if key not in original_order ] for node in original_order: @@ -568,8 +588,6 @@ def split_module( ret = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) log.debug( "%s", - lazy_format_graph_code( - "post split_module", ret, colored=True - ), + lazy_format_graph_code("post split_module", ret, colored=True), ) return ret diff --git a/torch/fx/passes/split_utils.py b/torch/fx/passes/split_utils.py index 1c003966983f..e2bece6f72f2 100644 --- a/torch/fx/passes/split_utils.py +++ b/torch/fx/passes/split_utils.py @@ -10,6 +10,7 @@ from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module from .tools_common import NodeList + __all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"] diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py index bccd4bda6dfd..31cb357df353 100644 --- a/torch/fx/passes/splitter_base.py +++ b/torch/fx/passes/splitter_base.py @@ -1,40 +1,44 @@ # mypy: allow-untyped-defs import argparse import copy +import logging from collections import defaultdict from dataclasses import dataclass -from typing import NamedTuple, Sequence, Iterable, Any, List, Dict, Optional, Tuple -import logging +from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Sequence, Tuple import torch -from torch.fx.passes.graph_manipulation import get_size_of_node -from torch.fx.node import map_arg from torch.fx._compatibility import compatibility +from torch.fx.node import map_arg +from torch.fx.passes.graph_manipulation import get_size_of_node -from .operator_support import ( - get_node_target, - OperatorSupportBase, -) from .graph_drawer import FxGraphDrawer +from .operator_support import get_node_target, OperatorSupportBase from .shape_prop import ShapeProp from .split_utils import split_by_tags from .tools_common import ( - FxNetAccFusionsFinder, CALLABLE_NODE_OPS, - Tensors, + FxNetAccFusionsFinder, + is_node_output_tensor, NodeList, NodeSet, - is_node_output_tensor, + Tensors, ) -__all__ = ['FxNetAccNodesFinder', 'FxNetSplitterInternalError', 'Subgraph', 'SplitResult', 'generate_inputs_for_submodules'] +__all__ = [ + "FxNetAccNodesFinder", + "FxNetSplitterInternalError", + "Subgraph", + "SplitResult", + "generate_inputs_for_submodules", +] _LOGGER = logging.getLogger(__name__) DEFAULT_MIN_ACC_MODULE_SIZE = 1 DEFAULT_SKIP_FUSION = False DEFAULT_ALLOW_NON_TENSOR = False + class _SplitterSettingBase: def __init__( self, @@ -82,9 +86,15 @@ class _SplitterSettingBase: ) args, _unknown = parser.parse_known_args() - self.min_acc_module_size: int = args.min_acc_module_size if args.min_acc_module_size else min_acc_module_size + self.min_acc_module_size: int = ( + args.min_acc_module_size + if args.min_acc_module_size + else min_acc_module_size + ) self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion - self.allow_non_tensor: bool = args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor + self.allow_non_tensor: bool = ( + args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor + ) self.max_acc_splits: int = max_acc_splits @@ -114,9 +124,7 @@ class FxNetAccNodesFinder: self.allow_non_tensor = 
allow_non_tensor self.acc_nodes: NodeSet = set() - def reduce_acc_nodes_non_tensor_input_helper( - self, cpu_worklist: NodeList - ): + def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList): """ Transitively excludes nodes from ACC supported set. For every node in the worklist: @@ -190,10 +198,12 @@ class FxNetAccNodesFinder: return self.acc_nodes + @compatibility(is_backward_compatible=False) class FxNetSplitterInternalError(Exception): pass + @compatibility(is_backward_compatible=False) @dataclass class Subgraph: @@ -201,6 +211,7 @@ class Subgraph: nodes: NodeList device_ordinal: Optional[int] = None + @compatibility(is_backward_compatible=False) class SplitResult(NamedTuple): """ @@ -243,7 +254,9 @@ def generate_inputs_for_submodules( submodule_to_names = {mod: name for name, mod in model.named_modules()} def pre_forward(module, module_inputs): - results[submodule_to_names[module]] = copy.deepcopy(module_inputs) if deepcopy else module_inputs + results[submodule_to_names[module]] = ( + copy.deepcopy(module_inputs) if deepcopy else module_inputs + ) for name, mod in model.named_modules(): if name in target_submodules: @@ -308,7 +321,7 @@ class _SplitterBase: """ # PCIe bandwidth for the backend, default to 100 GB/s - PCIe_BW = 100 * 2 ** 30 + PCIe_BW = 100 * 2**30 def __init__( self, @@ -335,7 +348,9 @@ class _SplitterBase: self.settings = settings self.operator_support = operator_support self.sample_input = sample_input - self.acc_nodes = FxNetAccNodesFinder(self.module, self.operator_support, self.settings.allow_non_tensor)() + self.acc_nodes = FxNetAccNodesFinder( + self.module, self.operator_support, self.settings.allow_non_tensor + )() if self.settings.skip_fusion: self.fusions = {} @@ -357,11 +372,11 @@ class _SplitterBase: # =============================================================== def get_node_submodule_map(self) -> Dict[str, str]: - """ Returns a map from node name to submodule name, e.g. - node: main_module_impl_impl_over_arch_unary_multiple_embedding - _pooling_embedding_pooling_sparse_entity_equivalence_key - _proxy_embedding_bag - maps to submodule name of: _run_on_acc_1 + """Returns a map from node name to submodule name, e.g. + node: main_module_impl_impl_over_arch_unary_multiple_embedding + _pooling_embedding_pooling_sparse_entity_equivalence_key + _proxy_embedding_bag + maps to submodule name of: _run_on_acc_1 """ return self._node_submodule_map @@ -411,9 +426,7 @@ class _SplitterBase: return mod - def _find_culprit( - self, mod: torch.fx.GraphModule, inputs: Tensors - ) -> str: + def _find_culprit(self, mod: torch.fx.GraphModule, inputs: Tensors) -> str: """ When an error occurs during lowering or running the lowered mod, we use this function to find culprits in the `mod` that causes the error. 
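The pre_forward hunk above belongs to generate_inputs_for_submodules, which runs the model once and uses forward pre-hooks to record the positional inputs each named submodule receives (deepcopy=True snapshots them, per the reformatted line). A minimal sketch; the module and shapes are invented::

    import torch
    from torch.fx.passes.splitter_base import generate_inputs_for_submodules


    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.lin(x).relu()


    captured = generate_inputs_for_submodules(
        M(), (torch.randn(2, 4),), ["lin"], deepcopy=True
    )
    print(captured["lin"][0].shape)  # torch.Size([2, 4])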
@@ -492,7 +505,9 @@ class _SplitterBase: supported_nodes.append(node) supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple)) else: - unsupported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple)) + unsupported_node_types[target].add( + (arg_dtypes_tuple, kwarg_dtypes_tuple) + ) if dump_graph: self._draw_graph_based_on_node_support(self.module, supported_nodes) @@ -527,7 +542,11 @@ class _SplitterBase: reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n" for i, subgraph in enumerate(subgraphs): - reports += f"_run_on_acc_{i}: " if subgraph.is_acc else f"{self.non_acc_submodule_name}{i}: " + reports += ( + f"_run_on_acc_{i}: " + if subgraph.is_acc + else f"{self.non_acc_submodule_name}{i}: " + ) reports += f"{len(subgraph.nodes)} node(s)\n" self.tag(subgraphs) @@ -535,9 +554,7 @@ class _SplitterBase: split_mod.eval() if dump_graph: - drawer = FxGraphDrawer( - split_mod, "preview", ignore_getattr=True - ) + drawer = FxGraphDrawer(split_mod, "preview", ignore_getattr=True) dot_graphs = drawer.get_all_dot_graphs() for name, dot_graph in dot_graphs.items(): # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`. @@ -564,9 +581,7 @@ class _SplitterBase: handle.remove() return sub_inputs - submod_inputs = get_submod_inputs( - split_mod, submod, self.sample_input - ) + submod_inputs = get_submod_inputs(split_mod, submod, self.sample_input) ShapeProp(submod).propagate(*submod_inputs) total_input_bytes = 0 @@ -649,9 +664,7 @@ class _SplitterBase: return result - def update_reverse_deps_for_fusions( - self, deps: Dict[torch.fx.Node, NodeSet] - ): + def update_reverse_deps_for_fusions(self, deps: Dict[torch.fx.Node, NodeSet]): processed_node = set() for node, fusion in self.fusions.items(): @@ -853,7 +866,11 @@ class _SplitterBase: def tag(self, subgraphs: List[Subgraph]): self.tags = [] for subgraph in subgraphs: - tag = f"_run_on_acc_{len(self.tags)}" if subgraph.is_acc else f"{self.non_acc_submodule_name}{len(self.tags)}" + tag = ( + f"_run_on_acc_{len(self.tags)}" + if subgraph.is_acc + else f"{self.non_acc_submodule_name}{len(self.tags)}" + ) self.tags.append(tag) for node in subgraph.nodes: if hasattr(node, "tag"): @@ -863,7 +880,9 @@ class _SplitterBase: self._node_submodule_map[node.name] = tag def split(self, remove_tag: bool = False) -> torch.fx.GraphModule: - split_module = split_by_tags(self.module, self.tags, return_tuple=self._return_tuple) + split_module = split_by_tags( + self.module, self.tags, return_tuple=self._return_tuple + ) if remove_tag: for node in self.module.graph.nodes: if hasattr(node, "tag"): @@ -875,7 +894,9 @@ class _SplitterBase: subgraphs = self.remove_small_acc_subgraphs(subgraphs) acc_subgraphs_count = len([s for s in subgraphs if s.is_acc]) non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count - print(f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs") + print( + f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs" + ) self.tag(subgraphs) return self.split() @@ -894,5 +915,7 @@ class _SplitterBase: "result in performance issues." 
) - submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names) + submodule_inputs = generate_inputs_for_submodules( + split_module, self.sample_input, submodule_names + ) return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name) diff --git a/torch/fx/passes/tests/test_pass_manager.py b/torch/fx/passes/tests/test_pass_manager.py index 60ed6671179b..157dc4017eda 100644 --- a/torch/fx/passes/tests/test_pass_manager.py +++ b/torch/fx/passes/tests/test_pass_manager.py @@ -26,9 +26,7 @@ class TestPassManager(unittest.TestCase): def test_these_before_those_pass_constraint(self) -> None: passes = [lambda x: 2 * x for _ in range(10)] constraint = these_before_those_pass_constraint(passes[-1], passes[0]) - pm = PassManager( - [inplace_wrapper(p) for p in passes] - ) + pm = PassManager([inplace_wrapper(p) for p in passes]) # add unfulfillable constraint pm.add_constraint(constraint) @@ -46,7 +44,7 @@ class TestPassManager(unittest.TestCase): pm1.add_pass(p) pm1.add_constraint(constraint) output1 = pm1(1) - self.assertEqual(output1, 2 ** 3) + self.assertEqual(output1, 2**3) passes = [lambda x: 3 * x for _ in range(3)] constraint = these_before_those_pass_constraint(passes[0], passes[1]) @@ -55,4 +53,4 @@ class TestPassManager(unittest.TestCase): pm2.add_pass(p) pm2.add_constraint(constraint) output2 = pm2(1) - self.assertEqual(output2, 3 ** 3) + self.assertEqual(output2, 3**3) diff --git a/torch/fx/passes/tools_common.py b/torch/fx/passes/tools_common.py index aac071ace8c2..4ed56be63b09 100644 --- a/torch/fx/passes/tools_common.py +++ b/torch/fx/passes/tools_common.py @@ -1,15 +1,22 @@ # mypy: allow-untyped-defs -from typing import List, Tuple, Union, Dict, Any, Set, Mapping, Optional import collections -from dataclasses import dataclass import operator +from dataclasses import dataclass +from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union import torch import torch.fx -from torch.fx.node import _get_qualified_name from torch.fx._compatibility import compatibility +from torch.fx.node import _get_qualified_name -__all__ = ['get_acc_ops_name', 'get_node_target', 'is_node_output_tensor', 'FxNetAccFusionsFinder', 'legalize_graph'] + +__all__ = [ + "get_acc_ops_name", + "get_node_target", + "is_node_output_tensor", + "FxNetAccFusionsFinder", + "legalize_graph", +] Tensors = Union[Tuple[torch.Tensor], List[torch.Tensor]] TensorOrTensors = Union[torch.Tensor, Tensors] @@ -26,12 +33,16 @@ def get_acc_ops_name(k): elif k.__module__ and "acc_ops" in k.__module__: return f"acc_ops.{k.__name__}" else: - module = k.__module__.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module + module = k.__module__.replace( + "torch._ops", "torch.ops" + ) # WAR for bug in how torch.ops assigns module return f"{module if module else ''}.{k.__name__}" @compatibility(is_backward_compatible=False) -def get_node_target(submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> str: +def get_node_target( + submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node +) -> str: """ Given a `node` returns its target typename. @@ -66,6 +77,7 @@ def get_node_target(submodules: Mapping[str, torch.nn.Module], node: torch.fx.No assert isinstance(node.target, str) return node.target + @compatibility(is_backward_compatible=False) def is_node_output_tensor(node: torch.fx.Node) -> bool: """Checks if the node output produces a Tensor or not. 
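The test_pass_manager changes above only reflow code; the assertions still rely on three doubling passes composing to 2**3. The same API stripped to its core, with the satisfied-constraint case mirroring the second half of the test::

    from torch.fx.passes.pass_manager import (
        PassManager,
        these_before_those_pass_constraint,
    )

    # Three distinct doubling passes, as in the test.
    passes = [lambda x: 2 * x for _ in range(3)]

    pm = PassManager()
    for p in passes:
        pm.add_pass(p)
    # Satisfied: passes[0] already runs before passes[1], so validation passes.
    pm.add_constraint(these_before_those_pass_constraint(passes[0], passes[1]))

    print(pm(1))  # 8 == 2**3; an unsatisfiable ordering would raise instead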
@@ -77,6 +89,7 @@ def is_node_output_tensor(node: torch.fx.Node) -> bool: type_ = node.meta.get("type", None) return type_ is not None and issubclass(type_, torch.Tensor) + @compatibility(is_backward_compatible=False) class FxNetAccFusionsFinder: """ @@ -297,7 +310,9 @@ def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: # If the new graph's size is not as large as the old one, then there must be # a cycle (i.e. some node's dependencies were not satisfied.) if len(new_graph.nodes) < len(gm.graph.nodes): - raise RuntimeError(f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}") + raise RuntimeError( + f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}" + ) new_graph._codegen = gm.graph._codegen gm.graph = new_graph return gm diff --git a/torch/fx/passes/utils/__init__.py b/torch/fx/passes/utils/__init__.py index 2a7970ba4c28..ee5e7e66868a 100644 --- a/torch/fx/passes/utils/__init__.py +++ b/torch/fx/passes/utils/__init__.py @@ -1 +1 @@ -from .common import lift_subgraph_as_module, HolderModule, compare_graphs +from .common import compare_graphs, HolderModule, lift_subgraph_as_module diff --git a/torch/fx/passes/utils/common.py b/torch/fx/passes/utils/common.py index ba2ae45aabf5..bb628372337b 100644 --- a/torch/fx/passes/utils/common.py +++ b/torch/fx/passes/utils/common.py @@ -3,7 +3,6 @@ from typing import Dict, Tuple from torch.fx._compatibility import compatibility from torch.fx.graph import Graph - from torch.fx.graph_module import GraphModule from torch.fx.passes.utils.matcher_utils import SubgraphMatcher from torch.nn import Module diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index 3268cc4a493c..8bcb9dee71c2 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -1,15 +1,16 @@ # mypy: allow-untyped-defs import copy from queue import SimpleQueue -from typing import List, Dict, Optional as _Optional, Tuple +from typing import Dict, List, Optional as _Optional, Tuple import torch.fx -from torch.fx.graph_module import GraphModule -from torch.fx.graph import Graph -from torch.fx.node import Node -from torch.fx.passes.tools_common import NodeList, NodeSet, legalize_graph -from torch.fx.passes.utils import lift_subgraph_as_module from torch.fx._compatibility import compatibility +from torch.fx.graph import Graph +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node +from torch.fx.passes.tools_common import legalize_graph, NodeList, NodeSet +from torch.fx.passes.utils import lift_subgraph_as_module + @compatibility(is_backward_compatible=False) def topo_sort(nodes: NodeList) -> NodeList: @@ -35,7 +36,9 @@ def topo_sort(nodes: NodeList) -> NodeList: if indegree_map[n] == 0: candidates.put(n) - assert len(nodes) == len(sorted_nodes), "topological sorted nodes doesn't have same length as input nodes" + assert len(nodes) == len( + sorted_nodes + ), "topological sorted nodes doesn't have same length as input nodes" return sorted_nodes @@ -96,7 +99,6 @@ def fuse_as_graphmodule( module_name: str, partition_lookup_table: _Optional[Dict[Node, None]] = None, ) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]: - """ Fuse nodes in graph_module into a GraphModule. 
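fuse_as_graphmodule, whose hunks continue below, is normally reached through fuse_by_partitions: each partition (a Dict[Node, None] used as an ordered set) is topo-sorted, lifted into a submodule named fused_<i>, and spliced back via insert_subgm. A sketch under that reading; the choice of nodes to fuse is arbitrary and purely for illustration::

    import operator

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.passes.utils.fuser_utils import fuse_by_partitions


    def f(x):
        a = x + 1
        b = a * 2
        return b.relu()


    gm = symbolic_trace(f)
    # Fuse the add and mul nodes into one submodule.
    partition = {
        n: None
        for n in gm.graph.nodes
        if n.op == "call_function" and n.target in (operator.add, operator.mul)
    }

    fused = fuse_by_partitions(gm, [partition])
    print(fused.graph)  # add/mul collapsed into a single call_module to fused_0

    x = torch.randn(3)
    assert torch.equal(fused(x), f(x))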
@@ -121,9 +123,13 @@ def fuse_as_graphmodule( # assumption: nodes are already sorted in topo order for node in nodes: - assert node.graph.owning_module is gm, f"{node} doesn't belong to passed in graph module {gm._get_name()}" + assert ( + node.graph.owning_module is gm + ), f"{node} doesn't belong to passed in graph module {gm._get_name()}" assert not node._erased, f"{node} has been removed from owning graph" - assert node in gm.graph._find_nodes_lookup_table, f"{node} is not found in graph module {gm._get_name()}" + assert ( + node in gm.graph._find_nodes_lookup_table + ), f"{node} is not found in graph module {gm._get_name()}" # validates partition doesn't introduce dependency circles in the graph assert validate_partition(nodes), "Invalid partition, found dependency cycles" @@ -134,8 +140,10 @@ def fuse_as_graphmodule( subgraph = Graph() - node_to_placeholder: Dict[Node, Node] = {} # mapping of nodes from old graph to placeholder in new graph - node_map: Dict[Node, Node] = {} # mapping of nodes from old graph to new graph + node_to_placeholder: Dict[ + Node, Node + ] = {} # mapping of nodes from old graph to placeholder in new graph + node_map: Dict[Node, Node] = {} # mapping of nodes from old graph to new graph # handles inputs through graph.node_copy's arg_transform functions def remap_inputs(x): @@ -184,7 +192,9 @@ def fuse_as_graphmodule( # lint to ensure correctness subgraph.lint() fused_gm: GraphModule - fused_gm, _ = lift_subgraph_as_module(gm, subgraph, comp_name="", class_name=module_name) + fused_gm, _ = lift_subgraph_as_module( + gm, subgraph, comp_name="", class_name=module_name + ) # sub_gm's input nodes in the original module original_inputs: Tuple[Node, ...] = tuple(node_to_placeholder.keys()) @@ -196,16 +206,18 @@ def fuse_as_graphmodule( @compatibility(is_backward_compatible=False) -def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node, ...]): +def insert_subgm( + gm: GraphModule, + sub_gm: GraphModule, + orig_inputs: Tuple[Node, ...], + orig_outputs: Tuple[Node, ...], +): # add sub_gm into gm submodule_name = sub_gm.__class__.__name__ gm.add_submodule(submodule_name, sub_gm) # Create a call_module node in main graph. 
- module_node = gm.graph.call_module( - submodule_name, - args=orig_inputs, - kwargs=None) + module_node = gm.graph.call_module(submodule_name, args=orig_inputs, kwargs=None) if len(orig_outputs) == 1: # main_remapping[comp.orig_outputs[0]] = module_node @@ -216,24 +228,30 @@ def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index] orig_output.replace_all_uses_with(proxy_out, propagate_meta=True) - module_node.meta["val"] = tuple(orig_output.meta.get("val", None) for orig_output in orig_outputs) + module_node.meta["val"] = tuple( + orig_output.meta.get("val", None) for orig_output in orig_outputs + ) return gm + @compatibility(is_backward_compatible=False) def erase_nodes(gm: GraphModule, nodes: NodeList): - # erase original nodes in inversed topological order for node in reversed(nodes): gm.graph.erase_node(node) @compatibility(is_backward_compatible=False) -def fuse_by_partitions(gm: GraphModule, partitions: List[Dict[Node, None]], prefix: str = "fused_") -> GraphModule: +def fuse_by_partitions( + gm: GraphModule, partitions: List[Dict[Node, None]], prefix: str = "fused_" +) -> GraphModule: for partition_id, partition in enumerate(partitions): sorted_nodes = topo_sort(list(partition)) submodule_name = prefix + str(partition_id) - sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, submodule_name, partition) + sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule( + gm, sorted_nodes, submodule_name, partition + ) insert_subgm(gm, sub_gm, orig_inputs, orig_outputs) diff --git a/torch/fx/passes/utils/matcher_utils.py b/torch/fx/passes/utils/matcher_utils.py index ba09ad177d29..cc05b8f512b1 100644 --- a/torch/fx/passes/utils/matcher_utils.py +++ b/torch/fx/passes/utils/matcher_utils.py @@ -10,6 +10,7 @@ import torch from torch.fx import Graph, Node from torch.fx._compatibility import compatibility + __all__ = ["SubgraphMatcher", "InternalMatch"] diff --git a/torch/fx/passes/utils/source_matcher_utils.py b/torch/fx/passes/utils/source_matcher_utils.py index 0a4f072644cd..f77db98880b7 100644 --- a/torch/fx/passes/utils/source_matcher_utils.py +++ b/torch/fx/passes/utils/source_matcher_utils.py @@ -1,19 +1,21 @@ -from dataclasses import dataclass, field -from torch.fx.graph import Graph -from torch.fx.node import Node -from torch.fx._compatibility import compatibility -from typing import Dict, List, Any, Type, Optional, Callable import logging import os +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Type + +from torch.fx._compatibility import compatibility +from torch.fx.graph import Graph +from torch.fx.node import Node -__all__ = ['get_source_partitions', 'check_subgraphs_connected', 'SourcePartition'] +__all__ = ["get_source_partitions", "check_subgraphs_connected", "SourcePartition"] + # Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs def _init_logger() -> logging.Logger: logger = logging.getLogger(__name__) - level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper() + level = os.environ.get("PYTORCH_MATCHER_LOGLEVEL", "WARNING").upper() logger.setLevel(level) console = logging.StreamHandler() formatter = logging.Formatter("%(filename)s > %(message)s") @@ -24,6 +26,7 @@ def _init_logger() -> logging.Logger: logger.propagate = False return logger + logger = _init_logger() @@ -77,8 +80,9 @@ def get_source_partitions( # be different from "source_fn_stack", for example for the add_ node # decomposed from batch 
norm. We should remove the check on "source_fn_stack" # after we fix "torch_fn". T199561090 - if ((source_fn_st := node.meta.get("source_fn_stack", None)) is None and - (torch_fn := node.meta.get("torch_fn", None)) is not None): + if (source_fn_st := node.meta.get("source_fn_stack", None)) is None and ( + torch_fn := node.meta.get("torch_fn", None) + ) is not None: node_fqn, source_fn = torch_fn source_fn_name = source_fn.split(".")[1] if source_fn_name in wanted_sources: @@ -86,7 +90,6 @@ def get_source_partitions( partition = diff_modules.setdefault(node_fqn, []) partition.append(node) - if (source_fn_st := node.meta.get("source_fn_stack", None)) is not None: source_fn = source_fn_st[-1] if source_fn[1] in wanted_sources: @@ -140,7 +143,9 @@ def get_source_partitions( @compatibility(is_backward_compatible=False) # type: ignore[misc] -def check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool: +def check_subgraphs_connected( + subgraph1: SourcePartition, subgraph2: SourcePartition +) -> bool: """ Given two subgraphs A and B (in the form of a list of nodes), checks if A has nodes connecting to at least one node in B -- aka there exists a node diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index 84a1b12f5bc9..ccbe06575474 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -1,29 +1,37 @@ # mypy: ignore-errors -import enum -import dis -import copy -import sys -import torch -import inspect -import operator import collections +import copy +import dis +import enum +import inspect import logging +import operator +import sys +from dataclasses import fields, is_dataclass +from typing import Any, Callable, Dict, Iterator, Optional, OrderedDict, Tuple -from dataclasses import is_dataclass, fields - - -from .graph import magic_methods, reflectable_magic_methods, Graph -from torch.utils._traceback import CapturedTraceback -from typing import Tuple, Dict, OrderedDict, Optional, Any, Iterator, Callable -from .node import Target, Node, Argument, base_types, map_aggregate -from ._compatibility import compatibility -from .operator_schemas import check_for_mutable_operation +import torch import torch.fx.traceback as fx_traceback +from torch.utils._traceback import CapturedTraceback -__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError', - 'Proxy', 'MetaProxy', 'Attribute', 'ParameterProxy', 'Scope', - 'ScopeContextManager'] +from ._compatibility import compatibility +from .graph import Graph, magic_methods, reflectable_magic_methods +from .node import Argument, base_types, map_aggregate, Node, Target +from .operator_schemas import check_for_mutable_operation + + +__all__ = [ + "TracerBase", + "GraphAppendingTracer", + "TraceError", + "Proxy", + "MetaProxy", + "Attribute", + "ParameterProxy", + "Scope", + "ScopeContextManager", +] log = logging.getLogger(__name__) @@ -31,7 +39,7 @@ log = logging.getLogger(__name__) @compatibility(is_backward_compatible=False) class Scope: - """ Scope object that records the module path and the module type + """Scope object that records the module path and the module type of a module. Scope is used to track the information of the module that contains a Node in a Graph of GraphModule. 
For example:: @@ -41,6 +49,7 @@ class Scope: # scope for this would be (module_path="sub", module_type=Sub) return x.transpose(1, 2) + class M(torch.nn.Module): def __init__(self) -> None: self.sub = Sub() @@ -62,7 +71,7 @@ class Scope: @compatibility(is_backward_compatible=False) class ScopeContextManager: - """ A context manager to track the Scope of Node during symbolic tracing. + """A context manager to track the Scope of Node during symbolic tracing. When entering a forward function of a Module, we'll update the scope information of the current module, and when we exit, we'll restore the previous scope information. """ @@ -102,28 +111,28 @@ _COPY_META_FIELDS = [ "quantization_tag", # TODO deprecated "_numeric_debug_handle", # TODO deprecated "custom", - "partitioner_tag" + "partitioner_tag", ] @compatibility(is_backward_compatible=True) class TracerBase: graph: Graph - record_stack_traces : bool = False + record_stack_traces: bool = False # Feature flag for mutable schema checking # Enableby default in 1.12 - check_mutable_operations : bool = False + check_mutable_operations: bool = False # Feature flag for assert tracing - trace_asserts : bool = False + trace_asserts: bool = False # Feature flag for proxying accesses to buffer values - proxy_buffer_attributes : bool = False + proxy_buffer_attributes: bool = False # Name of the function to be traced. It will only be used when # ``root`` is an instance of ``nn.Module`` traced_func_name: str = "forward" # Maps the containing module's name to the operator name - scope : Scope + scope: Scope # Records the module call stack module_stack: OrderedDict[str, Tuple[str, Any]] @@ -132,9 +141,15 @@ class TracerBase: node_name_to_scope: Dict[str, Tuple[str, type]] @compatibility(is_backward_compatible=True) - def create_node(self, kind : str, target : Target, - args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None, - type_expr : Optional[Any] = None) -> Node: + def create_node( + self, + kind: str, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + ) -> Node: """ Inserts a graph node given target, args, kwargs, and name. @@ -143,7 +158,7 @@ class TracerBase: want to disallow in-place operations from being recorded. 
""" - if kind == 'call_function' and self.check_mutable_operations: + if kind == "call_function" and self.check_mutable_operations: check_for_mutable_operation(target, args, kwargs) node = self.graph.create_node(kind, target, args, kwargs, name, type_expr) @@ -182,20 +197,27 @@ class TracerBase: node.meta["seq_nr"] = new_seq_nr elif self.module_stack: - node.meta['nn_module_stack'] = copy.copy(self.module_stack) + node.meta["nn_module_stack"] = copy.copy(self.module_stack) log.debug("create_node %s", node) return node @compatibility(is_backward_compatible=True) - def proxy(self, node: Node) -> 'Proxy': + def proxy(self, node: Node) -> "Proxy": return Proxy(node, self) @compatibility(is_backward_compatible=True) - def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], - name: Optional[str] = None, type_expr : Optional[Any] = None, - proxy_factory_fn: Callable[[Node], 'Proxy'] = None): - ''' + def create_proxy( + self, + kind: str, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + proxy_factory_fn: Callable[[Node], "Proxy"] = None, + ): + """ Create a Node from the given arguments, then return the Node wrapped in a Proxy object. @@ -203,7 +225,7 @@ class TracerBase: represents the parameter of a function. If we need to encode a default parameter, we use the ``args`` tuple. ``args`` is otherwise empty for ``placeholder`` Nodes. - ''' + """ args_ = self.create_arg(args) kwargs_ = self.create_arg(kwargs) @@ -218,8 +240,7 @@ class TracerBase: proxy = proxy_factory_fn(node) if self.record_stack_traces and not proxy.node.stack_trace: - proxy.node.stack_trace = ''.join(CapturedTraceback.extract().format()) - + proxy.node.stack_trace = "".join(CapturedTraceback.extract().format()) return proxy @@ -233,20 +254,23 @@ class TracerBase: # the user code during tracing. 
frame = inspect.currentframe() - pt_files = ['torch/fx/proxy.py', - 'torch/fx/_symbolic_trace.py', - 'torch/fx/experimental/proxy_tensor.py', - 'torch/_ops.py', - 'torch/_tensor.py', - 'torch/utils/_python_dispatch.py', - 'torch/_prims_common/wrappers.py', - 'torch/_refs/__init__.py', - 'torch/_refs/nn/functional/__init__.py', - 'torch/utils/_stats.py', - ] + pt_files = [ + "torch/fx/proxy.py", + "torch/fx/_symbolic_trace.py", + "torch/fx/experimental/proxy_tensor.py", + "torch/_ops.py", + "torch/_tensor.py", + "torch/utils/_python_dispatch.py", + "torch/_prims_common/wrappers.py", + "torch/_refs/__init__.py", + "torch/_refs/nn/functional/__init__.py", + "torch/utils/_stats.py", + ] while frame: frame = frame.f_back - if frame and all(not frame.f_code.co_filename.endswith(file) for file in pt_files): + if frame and all( + not frame.f_code.co_filename.endswith(file) for file in pt_files + ): break if not frame: @@ -264,11 +288,11 @@ class TracerBase: """ if isinstance(a, Proxy): return a.node # most common arg type goes first - elif hasattr(a, '__fx_create_arg__'): + elif hasattr(a, "__fx_create_arg__"): return a.__fx_create_arg__(self) # aggregates elif isinstance(a, tuple): - if hasattr(a, '_fields'): + if hasattr(a, "_fields"): # NamedTuple constructors don't seem to like getting a generator # expression as an argument to their constructor, so build this # intermediate tuple and unpack it into the NamedTuple constructor @@ -278,10 +302,13 @@ class TracerBase: elif isinstance(a, list): return [self.create_arg(elem) for elem in a] elif isinstance(a, dict): + def no_node(arg): if isinstance(arg, Node): - raise RuntimeError("Keys for dictionaries used as an argument cannot contain a " - f"Node. Got key: {k}") + raise RuntimeError( + "Keys for dictionaries used as an argument cannot contain a " + f"Node. Got key: {k}" + ) r = {} for k, v in a.items(): @@ -294,16 +321,27 @@ class TracerBase: r[k] = self.create_arg(v) return r elif isinstance(a, slice): - return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step)) + return slice( + self.create_arg(a.start), + self.create_arg(a.stop), + self.create_arg(a.step), + ) elif isinstance(a, range): - return range(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step)) + return range( + self.create_arg(a.start), + self.create_arg(a.stop), + self.create_arg(a.step), + ) elif isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): return a elif is_dataclass(a): - kwargs = {field.name: self.create_arg(getattr(a, field.name)) for field in fields(a)} + kwargs = { + field.name: self.create_arg(getattr(a, field.name)) + for field in fields(a) + } return self.create_node("call_function", a.__class__, (), kwargs) elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...: @@ -312,37 +350,41 @@ class TracerBase: raise NotImplementedError(f"argument of type: {type(a)}") @compatibility(is_backward_compatible=True) - def to_bool(self, obj: 'Proxy') -> bool: + def to_bool(self, obj: "Proxy") -> bool: """Called when a proxy object is being converted to a boolean, such as when used in control flow. Normally we don't know what to do because we don't know the value of the proxy, but a custom tracer can attach more information to the graph node using create_node and can choose to return a value. 
""" - raise TraceError('symbolically traced variables cannot be used as inputs to control flow') + raise TraceError( + "symbolically traced variables cannot be used as inputs to control flow" + ) @compatibility(is_backward_compatible=True) - def iter(self, obj: 'Proxy') -> Iterator: + def iter(self, obj: "Proxy") -> Iterator: """Called when a proxy object is being iterated over, such as when used in control flow. Normally we don't know what to do because we don't know the value of the proxy, but a custom tracer can attach more information to the graph node using create_node and can choose to return an iterator. """ - raise TraceError('Proxy object cannot be iterated. This can be ' - 'attempted when the Proxy is used in a loop or' - ' as a *args or **kwargs function argument. ' - 'See the torch.fx docs on pytorch.org for a ' - 'more detailed explanation of what types of ' - 'control flow can be traced, and check out the' - ' Proxy docstring for help troubleshooting ' - 'Proxy iteration errors') + raise TraceError( + "Proxy object cannot be iterated. This can be " + "attempted when the Proxy is used in a loop or" + " as a *args or **kwargs function argument. " + "See the torch.fx docs on pytorch.org for a " + "more detailed explanation of what types of " + "control flow can be traced, and check out the" + " Proxy docstring for help troubleshooting " + "Proxy iteration errors" + ) @compatibility(is_backward_compatible=True) - def keys(self, obj: 'Proxy') -> Any: + def keys(self, obj: "Proxy") -> Any: """Called when a proxy object is has the keys() method called. This is what happens when ** is called on a proxy. This should return an iterator it ** is suppose to work in your custom tracer. """ - return Attribute(obj, 'keys')() + return Attribute(obj, "keys")() # used in Proxy object when just appending to the graph while not tracing. @@ -355,14 +397,17 @@ class GraphAppendingTracer(TracerBase): self.module_stack = collections.OrderedDict() self.node_name_to_scope = {} + @compatibility(is_backward_compatible=False) def assert_fn(x): assert x + @compatibility(is_backward_compatible=True) class TraceError(ValueError): pass + @compatibility(is_backward_compatible=True) class Proxy: """ @@ -394,7 +439,7 @@ class Proxy: """ @compatibility(is_backward_compatible=True) - def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None): + def __init__(self, node: Node, tracer: "Optional[TracerBase]" = None): if tracer is None: # This allows you to create a Proxy object around a raw Node tracer = GraphAppendingTracer(node.graph) @@ -402,9 +447,9 @@ class Proxy: self.node = node def __repr__(self) -> str: - return f'Proxy({self.node.name})' + return f"Proxy({self.node.name})" - def __getattr__(self, k) -> 'Attribute': + def __getattr__(self, k) -> "Attribute": # note: not added to the graph yet, if this is a method call # we peephole optimize to the method invocation return Attribute(self, k) @@ -417,6 +462,7 @@ class Proxy: # will go to __getattr__(self, "__deepcopy__") and return a # Attribute(__deepcopy__), and may go into an infinite loop in some cases. import copy + new_dict = {} for k, v in self.__dict__.items(): try: @@ -424,7 +470,10 @@ class Proxy: except Exception: log.warning( "Shallow copy %s of Proxy because it cannot be deepcopied. 
" - "Proxy is created for node %s", k, self.node.name) + "Proxy is created for node %s", + k, + self.node.name, + ) new_obj = copy.copy(v) new_dict[k] = new_obj assert "node" in new_dict @@ -438,10 +487,12 @@ class Proxy: # This is called when being unpickled/loaded. self.__dict__ = d - def __call__(self, *args, **kwargs) -> 'Proxy': - return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs) + def __call__(self, *args, **kwargs) -> "Proxy": + return self.tracer.create_proxy( + "call_method", "__call__", (self,) + args, kwargs + ) - def __iter__(self) -> Iterator['Proxy']: + def __iter__(self) -> Iterator["Proxy"]: frame = inspect.currentframe() assert frame is not None calling_frame = frame.f_back @@ -449,17 +500,20 @@ class Proxy: inst_list = list(dis.get_instructions(calling_frame.f_code)) if sys.version_info >= (3, 11): from bisect import bisect_left - inst_idx = bisect_left(inst_list, calling_frame.f_lasti, key=lambda x: x.offset) + + inst_idx = bisect_left( + inst_list, calling_frame.f_lasti, key=lambda x: x.offset + ) else: inst_idx = calling_frame.f_lasti // 2 inst = inst_list[inst_idx] - if inst.opname == 'UNPACK_SEQUENCE': + if inst.opname == "UNPACK_SEQUENCE": return (self[i] for i in range(inst.argval)) # type: ignore[index] return self.tracer.iter(self) def __abs__(self): - return self.tracer.create_proxy('call_function', operator.abs, (self,), {}) + return self.tracer.create_proxy("call_function", operator.abs, (self,), {}) def __bool__(self) -> bool: if self.tracer.trace_asserts: @@ -472,19 +526,23 @@ class Proxy: insts = list(dis.get_instructions(calling_frame.f_code)) if sys.version_info >= (3, 11): from bisect import bisect_left + cur = bisect_left(insts, calling_frame.f_lasti, key=lambda x: x.offset) else: cur = calling_frame.f_lasti // 2 inst = insts[cur] - if inst.opname == 'POP_JUMP_IF_TRUE': + if inst.opname == "POP_JUMP_IF_TRUE": first = insts[cur + 1] assert inst.arg is not None last = insts[inst.arg // 2 - 1] - starts_with_assert = (first.opname == 'LOAD_GLOBAL' and first.argval == 'AssertionError' - or first.opname == 'LOAD_ASSERTION_ERROR') - if starts_with_assert and last.opname == 'RAISE_VARARGS': - self.tracer.create_proxy('call_function', assert_fn, (self,), {}) + starts_with_assert = ( + first.opname == "LOAD_GLOBAL" + and first.argval == "AssertionError" + or first.opname == "LOAD_ASSERTION_ERROR" + ) + if starts_with_assert and last.opname == "RAISE_VARARGS": + self.tracer.create_proxy("call_function", assert_fn, (self,), {}) return True return self.tracer.to_bool(self) @@ -494,39 +552,51 @@ class Proxy: return self.tracer.keys(self) def __len__(self): - raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want " - "this call to be recorded, please call torch.fx.wrap('len') at " - "module scope") + raise RuntimeError( + "'len' is not supported in symbolic tracing by default. 
If you want " + "this call to be recorded, please call torch.fx.wrap('len') at " + "module scope" + ) @classmethod def __torch_function__(cls, orig_method, types, args=None, kwargs=None): args = args if args else () kwargs = kwargs if kwargs else {} - tracers : Dict[Any, None] = {} + tracers: Dict[Any, None] = {} def find_tracer(a): if isinstance(a, cls): tracers[a.tracer] = None + torch.fx.node.map_aggregate(args, find_tracer) torch.fx.node.map_aggregate(kwargs, find_tracer) if len(tracers) > 1: - raise RuntimeError(f'Found multiple different tracers {list(tracers.keys())} while ' - f'trying to trace operations {orig_method}') + raise RuntimeError( + f"Found multiple different tracers {list(tracers.keys())} while " + f"trying to trace operations {orig_method}" + ) tracer = next(iter(tracers.keys())) if isinstance(orig_method, torch._C.ScriptMethod): args = (orig_method.owner,) + args - return tracer.create_proxy('call_method', orig_method.name, args, kwargs) + return tracer.create_proxy("call_method", orig_method.name, args, kwargs) if torch.overrides.is_tensor_method_or_property(orig_method): - return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs) + return tracer.create_proxy( + "call_method", orig_method.__name__, args, kwargs + ) else: if isinstance(orig_method, torch._ops.HigherOrderOperator): # TODO: Define how to symbolically trace HigherOrderOperators raise RuntimeError("Unable to symbolically trace HigherOrderOperators") - return tracer.create_proxy('call_function', orig_method, args, kwargs, - name=tracer.graph._target_to_str(orig_method.__name__)) + return tracer.create_proxy( + "call_function", + orig_method, + args, + kwargs, + name=tracer.graph._target_to_str(orig_method.__name__), + ) @compatibility(is_backward_compatible=False) @@ -535,12 +605,14 @@ class MetaProxy(Proxy): A Proxy subclass that propagates metadata (meta['val']) during graph tracing. """ - def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None, fake_mode=None): + def __init__( + self, node: Node, tracer: "Optional[TracerBase]" = None, fake_mode=None + ): super().__init__(node, tracer) self.fake_mode = fake_mode def __repr__(self) -> str: - return f'MetaProxy({self.node.name})' + return f"MetaProxy({self.node.name})" @classmethod def __torch_function__(cls, orig_method, types, args=None, kwargs=None): @@ -553,16 +625,19 @@ class MetaProxy(Proxy): meta_proxy = arg break - assert meta_proxy is not None, "No MetaProxy found in arguments, but one is expected." + assert ( + meta_proxy is not None + ), "No MetaProxy found in arguments, but one is expected." 
proxy = super().__torch_function__(orig_method, types, args, kwargs) with meta_proxy.fake_mode: proxy.node.meta["val"] = orig_method( *[a.node.meta["val"] if isinstance(a, Proxy) else a for a in args], - **kwargs + **kwargs, ) return MetaProxy(proxy.node, proxy.tracer, meta_proxy.fake_mode) + @compatibility(is_backward_compatible=True) class Attribute(Proxy): @compatibility(is_backward_compatible=True) @@ -577,11 +652,15 @@ class Attribute(Proxy): # the node for attributes is added lazily, since most will just be method calls # which do not rely on the getitem call if self._node is None: - self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node + self._node = self.tracer.create_proxy( + "call_function", getattr, (self.root, self.attr), {} + ).node return self._node def __call__(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + return self.tracer.create_proxy( + "call_method", self.attr, (self.root,) + args, kwargs + ) @compatibility(is_backward_compatible=False) @@ -591,6 +670,7 @@ class ParameterProxy(Proxy): attribute accesses pass through to the underlying module parameter object, so that conditional tests on these attributes will not throw exception during tracing """ + def __init__(self, tracer: TracerBase, node: Node, name, param): super().__init__(node, tracer) assert isinstance(param, torch.nn.Parameter) @@ -598,7 +678,7 @@ class ParameterProxy(Proxy): self.name = name def __repr__(self) -> str: - return f'ParameterProxy({self.name})' + return f"ParameterProxy({self.name})" @property def shape(self): @@ -622,25 +702,31 @@ class ParameterProxy(Proxy): for method in magic_methods: + def _scope(method): def impl(*args, **kwargs): tracer = args[0].tracer target = getattr(operator, method) - return tracer.create_proxy('call_function', target, args, kwargs) + return tracer.create_proxy("call_function", target, args, kwargs) + impl.__name__ = method as_magic = f'__{method.strip("_")}__' setattr(Proxy, as_magic, impl) + _scope(method) + def _define_reflectable(orig_method_name): method_name = f'__r{orig_method_name.strip("_")}__' def impl(self, rhs): target = getattr(operator, orig_method_name) - return self.tracer.create_proxy('call_function', target, (rhs, self), {}) + return self.tracer.create_proxy("call_function", target, (rhs, self), {}) + impl.__name__ = method_name impl.__qualname__ = method_name setattr(Proxy, method_name, impl) + for orig_method_name in reflectable_magic_methods: _define_reflectable(orig_method_name) diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index 48da77f32ece..b823fda3123f 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -1,18 +1,36 @@ -from .graph_module import GraphModule -from .graph import Graph -from .node import Node -from ._symbolic_trace import symbolic_trace -from ._compatibility import compatibility - import copy from dataclasses import dataclass -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union, TYPE_CHECKING +from typing import ( + Any, + Callable, + Dict, + List, + NamedTuple, + Optional, + Set, + TYPE_CHECKING, + Union, +) + import torch +from ._compatibility import compatibility +from ._symbolic_trace import symbolic_trace +from .graph import Graph +from .graph_module import GraphModule +from .node import Node + + if TYPE_CHECKING: from .passes.utils.matcher_with_name_node_map_utils import InternalMatch -__all__ = ['Match', 'replace_pattern', 
'replace_pattern_with_filters', "ReplacedPatterns"] +__all__ = [ + "Match", + "replace_pattern", + "replace_pattern_with_filters", + "ReplacedPatterns", +] + @compatibility(is_backward_compatible=True) class Match(NamedTuple): @@ -21,6 +39,7 @@ class Match(NamedTuple): # Maps nodes in the pattern subgraph to nodes in the larger graph nodes_map: Dict[Node, Node] + @compatibility(is_backward_compatible=False) @dataclass class ReplacedPatterns: @@ -31,6 +50,7 @@ class ReplacedPatterns: # List of nodes that were added into the graph replacements: List[Node] + def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None: gm.delete_all_unused_submodules() @@ -48,7 +68,6 @@ def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None: for node in gm.graph.nodes: if node.op == "call_module" or node.op == "get_attr": - gm_attr = try_get_attr(gm, node.target) replacement_attr = try_get_attr(replacement, node.target) @@ -70,11 +89,14 @@ def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None: # CASE 3: The target doesn't exist as an attribute in `gm` # or `replacement` else: - raise RuntimeError('Attempted to create a "', node.op, - '" node during subgraph rewriting ' - f"with target {node.target}, but " - "the referenced attribute does not " - "exist in the replacement GraphModule") + raise RuntimeError( + 'Attempted to create a "', + node.op, + '" node during subgraph rewriting ' + f"with target {node.target}, but " + "the referenced attribute does not " + "exist in the replacement GraphModule", + ) gm.graph.lint() @@ -83,7 +105,7 @@ def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None: def replace_pattern( gm: GraphModule, pattern: Union[Callable, GraphModule], - replacement: Union[Callable, GraphModule] + replacement: Union[Callable, GraphModule], ) -> List[Match]: """ Matches all possible non-overlapping sets of operators and their @@ -116,6 +138,7 @@ def replace_pattern( import torch from torch.fx import symbolic_trace, subgraph_rewriter + class M(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -125,12 +148,15 @@ def replace_pattern( m2 = torch.cat([w1, w2]).sum() return x + torch.max(m1) + torch.max(m2) + def pattern(w1, w2): return torch.cat([w1, w2]).sum() + def replacement(w1, w2): return torch.stack([w1, w2]) + traced_module = symbolic_trace(M()) subgraph_rewriter.replace_pattern(traced_module, pattern, replacement) @@ -199,7 +225,9 @@ def replace_pattern( return add_2 """ match_and_replacements = _replace_pattern(gm, pattern, replacement) - return [Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements] + return [ + Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements + ] # Experimental API, not backward compatible @@ -208,10 +236,14 @@ def replace_pattern_with_filters( gm: GraphModule, pattern: Union[Callable, Graph, GraphModule], replacement: Union[Callable, Graph, GraphModule, None] = None, - match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, + match_filters: Optional[ + List[Callable[["InternalMatch", Graph, Graph], bool]] + ] = None, ignore_literals: bool = False, # Placed at the end to avoid breaking backward compatibility - replacement_callback: Optional[Callable[["InternalMatch", Graph, Graph], Graph]] = None, + replacement_callback: Optional[ + Callable[["InternalMatch", Graph, Graph], Graph] + ] = None, ) -> List[ReplacedPatterns]: """ See replace_pattern for documentation. 
This function is an overload with an additional match_filter argument. @@ -226,20 +258,25 @@ def replace_pattern_with_filters( replacement graph based on the match. """ - return _replace_pattern(gm, pattern, replacement, match_filters, ignore_literals, replacement_callback) + return _replace_pattern( + gm, pattern, replacement, match_filters, ignore_literals, replacement_callback + ) def _replace_pattern( gm: GraphModule, pattern: Union[Callable, Graph, GraphModule], replacement: Union[Callable, Graph, GraphModule, None] = None, - match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, + match_filters: Optional[ + List[Callable[["InternalMatch", Graph, Graph], bool]] + ] = None, ignore_literals: bool = False, # Placed at the end to avoid breaking backward compatibility - replacement_callback: Optional[Callable[["InternalMatch", Graph, Graph], Graph]] = None, + replacement_callback: Optional[ + Callable[["InternalMatch", Graph, Graph], Graph] + ] = None, ) -> List[ReplacedPatterns]: - - from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch + from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher if match_filters is None: match_filters = [] @@ -254,15 +291,23 @@ def _replace_pattern( else: pattern_graph = symbolic_trace(pattern).graph - matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False, - remove_overlapping_matches=True, ignore_literals=ignore_literals) + matcher = SubgraphMatcher( + pattern_graph, + match_output=False, + match_placeholder=False, + remove_overlapping_matches=True, + ignore_literals=ignore_literals, + ) _matches: List[InternalMatch] = matcher.match(original_graph) # Filter out matches that don't match the filter _matches = [ - m for m in _matches - if all(match_filter(m, original_graph, pattern_graph) - for match_filter in match_filters) + m + for m in _matches + if all( + match_filter(m, original_graph, pattern_graph) + for match_filter in match_filters + ) ] if isinstance(replacement, GraphModule): @@ -272,7 +317,9 @@ def _replace_pattern( elif callable(replacement): common_replacement_graph = symbolic_trace(replacement).graph else: - assert replacement_callback is not None, "Must provide either a replacement GraphModule or a replacement callback" + assert ( + replacement_callback is not None + ), "Must provide either a replacement GraphModule or a replacement callback" common_replacement_graph = None # As we progressively replace nodes, we'll need to keep track of how the match results should change @@ -281,11 +328,17 @@ def _replace_pattern( match_and_replacements = [] for match in _matches: if replacement_callback is not None: - replacement_graph = replacement_callback(match, original_graph, pattern_graph) + replacement_graph = replacement_callback( + match, original_graph, pattern_graph + ) else: - assert common_replacement_graph is not None, "Must provide either a replacement GraphModule or a replacement callback" + assert ( + common_replacement_graph is not None + ), "Must provide either a replacement GraphModule or a replacement callback" replacement_graph = common_replacement_graph - replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"] + replacement_placeholders = [ + n for n in replacement_graph.nodes if n.op == "placeholder" + ] # Build connecting between replacement graph's input and original graph input producer node @@ -300,7 +353,9 @@ def _replace_pattern( # Update match.placeholder_nodes and match.nodes_map 
with the node that replaced gn gn_ind = match.placeholder_nodes.index(gn) match.placeholder_nodes[gn_ind] = match_changed_node[gn] - map_key = list(match.nodes_map.keys())[list(match.nodes_map.values()).index(gn)] + map_key = list(match.nodes_map.keys())[ + list(match.nodes_map.values()).index(gn) + ] match.nodes_map[map_key] = match_changed_node[gn] else: val_map[rn] = gn @@ -322,13 +377,17 @@ def _replace_pattern( break with original_graph.inserting_before(first_user_node): # type: ignore[possibly-undefined] - copied_returning_nodes = original_graph.graph_copy(replacement_graph, val_map) + copied_returning_nodes = original_graph.graph_copy( + replacement_graph, val_map + ) if isinstance(copied_returning_nodes, Node): - copied_returning_nodes = (copied_returning_nodes, ) + copied_returning_nodes = (copied_returning_nodes,) # Get a list of nodes that have been replaced into the graph - replacement_nodes: List[Node] = [v for v in val_map.values() if v not in match.placeholder_nodes] + replacement_nodes: List[Node] = [ + v for v in val_map.values() if v not in match.placeholder_nodes + ] # Hook the output Node of the replacement subgraph in to the # original Graph at the correct location @@ -346,7 +405,7 @@ def _replace_pattern( ReplacedPatterns( anchor=match.anchors[0], nodes_map=match.nodes_map, - replacements=replacement_nodes + replacements=replacement_nodes, ) ) diff --git a/torch/fx/tensor_type.py b/torch/fx/tensor_type.py index 83b5a9f8faf6..4f375e461ef2 100644 --- a/torch/fx/tensor_type.py +++ b/torch/fx/tensor_type.py @@ -19,7 +19,7 @@ class TensorType: self.__args__ = dim def __repr__(self): - return f'TensorType[{self.__args__}]' + return f"TensorType[{self.__args__}]" def __eq__(self, other): if isinstance(other, self.__class__): @@ -38,8 +38,9 @@ class _DynType: """ _DynType defines a type which stands for the absence of type information. 
""" + def __init__(self) -> None: - self.__name__ = '_DynType' + self.__name__ = "_DynType" def __eq__(self, other): return isinstance(other, self.__class__) @@ -53,6 +54,7 @@ class _DynType: Dyn = _DynType() + @compatibility(is_backward_compatible=False) def is_consistent(t1, t2): """ @@ -73,8 +75,10 @@ def is_consistent(t1, t2): return True if isinstance(t1, TensorType) and isinstance(t2, TensorType): - return len(t1.__args__) == len(t2.__args__) and \ - all(is_consistent(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__)) + return len(t1.__args__) == len(t2.__args__) and all( + is_consistent(elem1, elem2) + for elem1, elem2 in zip(t1.__args__, t2.__args__) + ) else: return False @@ -98,8 +102,10 @@ def is_more_precise(t1, t2): return True if isinstance(t1, TensorType) and isinstance(t2, TensorType): - return len(t1.__args__) == len(t2.__args__) and \ - all(is_more_precise(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__)) + return len(t1.__args__) == len(t2.__args__) and all( + is_more_precise(elem1, elem2) + for elem1, elem2 in zip(t1.__args__, t2.__args__) + ) else: return False diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index 4e72a8011f63..84c94c75cf66 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -1,12 +1,21 @@ # mypy: allow-untyped-defs import traceback from contextlib import contextmanager -from typing import List, Any, Dict +from typing import Any, Dict, List + from ._compatibility import compatibility -__all__ = ['preserve_node_meta', 'has_preserved_node_meta', - 'set_stack_trace', 'set_grad_fn_seq_nr', 'reset_grad_fn_seq_nr', - 'format_stack', 'set_current_meta', 'get_current_meta'] + +__all__ = [ + "preserve_node_meta", + "has_preserved_node_meta", + "set_stack_trace", + "set_grad_fn_seq_nr", + "reset_grad_fn_seq_nr", + "format_stack", + "set_current_meta", + "get_current_meta", +] current_meta: Dict[str, Any] = {} should_preserve_node_meta = False @@ -30,7 +39,7 @@ def preserve_node_meta(): @compatibility(is_backward_compatible=False) -def set_stack_trace(stack : List[str]): +def set_stack_trace(stack: List[str]): global current_meta if should_preserve_node_meta and stack: @@ -43,7 +52,9 @@ def set_grad_fn_seq_nr(seq_nr): if should_preserve_node_meta: # The seq_nr is captured by eager mode in the grad_fn during forward - current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [seq_nr] + current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [ + seq_nr + ] current_meta["in_grad_fn"] = current_meta.get("in_grad_fn", 0) + 1 @@ -90,7 +101,9 @@ def set_current_meta(node): if "from_node" not in current_meta: current_meta["from_node"] = [(node.name, node.target)] elif current_meta["from_node"][-1][0] != node.name: - current_meta["from_node"] = current_meta["from_node"] + [(node.name, node.target)] + current_meta["from_node"] = current_meta["from_node"] + [ + (node.name, node.target) + ] yield finally: