Revert "Reland "Add heirachical module names to torchFX graph.node" (#90205)"

This reverts commit 6b7efac3c9ea5c9fbfb18069abd254ad7d9a103e.

Reverted https://github.com/pytorch/pytorch/pull/90205 on behalf of https://github.com/seemethere due to Reverting since this caused failures in internal systems, see https://fb.workplace.com/groups/802176577445480/posts/894284641568006 for discussion
This commit is contained in:
PyTorch MergeBot
2022-12-13 17:47:03 +00:00
parent 1439ebd899
commit 1119d2fa54
7 changed files with 154 additions and 145 deletions

View File

@ -11,6 +11,8 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
'_check_is_graph_module',
'_swap_ff_with_fxff',
'_fuse_fx',
'Scope',
'ScopeContextManager',
'QuantizationTracer',
'_prepare_fx',
'_prepare_standalone_module_fx',

View File

@ -1679,36 +1679,6 @@ class TestFX(JitTestCase):
if node.op in {'placeholder'}:
self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last_3d)
def test_nn_module_stack(self):
class SubModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv_mod = torch.nn.Conv2d(64, 64, (3, 3), padding=1, bias=False)
def forward(self, x):
return self.conv_mod(x)
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.sub_mod = SubModule()
def forward(self, x):
return self.sub_mod(x)
m = MyModule()
gm = torch.fx.symbolic_trace(m)
mod_stack = {}
expected_stack = [('sub_mod', str(type(m.sub_mod))),
('sub_mod.conv_mod', str(type(m.sub_mod.conv_mod)))]
for node in gm.graph.nodes:
mod_stack = node.meta.get('nn_module_stack', {})
if mod_stack:
break
stack_list = list(mod_stack.items())
self.assertEqual(stack_list, expected_stack)
def test_interpreter(self):
class MyModule(torch.nn.Module):
def __init__(self):

View File

@ -1,13 +1,67 @@
import torch
from torch.fx._symbolic_trace import Tracer
from torch.fx.proxy import Scope
from torch.fx.node import Target, Node, Argument
from torch.nn.intrinsic import _FusedModule
from typing import List, Callable
from typing import List, Callable, Tuple, Any, Dict, Optional
__all__ = [
"QuantizationTracer",
]
class Scope(object):
""" Scope object that records the module path and the module type
of a module. Scope is used to track the information of the module
that contains a Node in a Graph of GraphModule. For example::
class Sub(torch.nn.Module):
def forward(self, x):
# This will be a call_method Node in GraphModule,
# scope for this would be (module_path="sub", module_type=Sub)
return x.transpose(1, 2)
class M(torch.nn.Module):
def __init__(self):
self.sub = Sub()
def forward(self, x):
# This will be a call_method Node as well,
# scope for this would be (module_path="", None)
x = x.transpose(1, 2)
x = self.sub(x)
return x
"""
def __init__(self, module_path: str, module_type: Any):
super().__init__()
self.module_path = module_path
self.module_type = module_type
class ScopeContextManager(object):
""" A context manager to track the Scope of Node during symbolic tracing.
When entering a forward function of a Module, we'll update the scope information of
the current module, and when we exit, we'll restore the previous scope information.
"""
def __init__(
self, scope: Scope, current_module: torch.nn.Module, current_module_path: str
):
super().__init__()
self.prev_module_type = scope.module_type
self.prev_module_path = scope.module_path
self.scope = scope
self.scope.module_path = current_module_path
self.scope.module_type = type(current_module)
def __enter__(self):
return
def __exit__(self, *args):
self.scope.module_path = self.prev_module_path
self.scope.module_type = self.prev_module_type
return
class QuantizationTracer(Tracer):
def __init__(
self, skipped_module_names: List[str], skipped_module_classes: List[Callable]
@ -21,6 +75,7 @@ class QuantizationTracer(Tracer):
# We can change this if there is a use case that configures
# qconfig using top level module type
self.scope = Scope("", None)
self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}
self.record_stack_traces = True
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
@ -33,3 +88,32 @@ class QuantizationTracer(Tracer):
or type(m) in self.skipped_module_classes
or isinstance(m, _FusedModule)
)
def call_module(
self,
m: torch.nn.Module,
forward: Callable[..., Any],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> Any:
module_qualified_name = self.path_of_module(m)
# Creating scope with information of current module
# scope will be restored automatically upon exit
with ScopeContextManager(self.scope, m, module_qualified_name):
return super().call_module(m, forward, args, kwargs)
def create_node(
self,
kind: str,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
) -> Node:
node = super().create_node(kind, target, args, kwargs, name, type_expr)
self.node_name_to_scope[node.name] = (
self.scope.module_path,
self.scope.module_type,
)
return node

View File

@ -63,6 +63,61 @@ def _fuse_fx(
graph_module, is_qat, fuse_custom_config, backend_config) # type: ignore[operator]
class Scope(object):
""" Scope object that records the module path and the module type
of a module. Scope is used to track the information of the module
that contains a Node in a Graph of GraphModule. For example::
class Sub(torch.nn.Module):
def forward(self, x):
# This will be a call_method Node in GraphModule,
# scope for this would be (module_path="sub", module_type=Sub)
return x.transpose(1, 2)
class M(torch.nn.Module):
def __init__(self):
self.sub = Sub()
def forward(self, x):
# This will be a call_method Node as well,
# scope for this would be (module_path="", None)
x = x.transpose(1, 2)
x = self.sub(x)
return x
"""
def __init__(self, module_path: str, module_type: Any):
super().__init__()
self.module_path = module_path
self.module_type = module_type
class ScopeContextManager(object):
""" A context manager to track the Scope of Node during symbolic tracing.
When entering a forward function of a Module, we'll update the scope information of
the current module, and when we exit, we'll restore the previous scope information.
"""
def __init__(
self, scope: Scope, current_module: torch.nn.Module, current_module_path: str
):
super().__init__()
self.prev_module_type = scope.module_type
self.prev_module_path = scope.module_path
self.scope = scope
self.scope.module_path = current_module_path
self.scope.module_type = type(current_module)
def __enter__(self):
return
def __exit__(self, *args):
self.scope.module_path = self.prev_module_path
self.scope.module_type = self.prev_module_type
return
def _prepare_fx(
model: torch.nn.Module,
qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],

View File

@ -5,7 +5,6 @@ import inspect
import math
import os
import warnings
import collections
from itertools import chain
from types import CodeType, FunctionType, ModuleType
from typing import (
@ -29,7 +28,7 @@ from ._compatibility import compatibility
from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph
from .graph_module import GraphModule
from .node import Argument, base_types, map_aggregate
from .proxy import ParameterProxy, Proxy, TracerBase, Scope, ScopeContextManager
from .proxy import ParameterProxy, Proxy, TracerBase
HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS
@ -45,6 +44,7 @@ _is_fx_tracing_flag = False
def is_fx_tracing():
return _is_fx_tracing_flag
@compatibility(is_backward_compatible=True)
class ProxyableClassMeta(type):
"""
@ -250,13 +250,6 @@ class Tracer(TracerBase):
self.param_shapes_constant = param_shapes_constant
self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None
self.root_module_name: str = ""
# Maps the containing module's name to the operator name
self.scope = Scope("", None)
# Records the module call stack
self.module_stack = collections.OrderedDict()
# Mapping of node name to module scope
self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}
@compatibility(is_backward_compatible=True)
def create_arg(self, a: Any) -> "Argument":
@ -437,18 +430,9 @@ class Tracer(TracerBase):
value was returned from the ``Module`` invocation.
"""
module_qualified_name = self.path_of_module(m)
with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope:
# module_stack is an ordered dict so writing then deleting the
# entry is equivalent to push/pop on a list
self.module_stack[_scope.module_path] = str(_scope.module_type)
if not self.is_leaf_module(m, module_qualified_name):
ret_val = forward(*args, **kwargs)
else:
ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs)
key, _ = self.module_stack.popitem(last=True)
assert key == _scope.module_path, f" Unexpected key {key}"
return ret_val
return forward(*args, **kwargs)
return self.create_proxy("call_module", module_qualified_name, args, kwargs)
@compatibility(is_backward_compatible=False)
def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
@ -596,7 +580,7 @@ class Tracer(TracerBase):
name,
default,
{},
type_expr=fn_for_analysis.__annotations__.get(name, None)
type_expr=fn_for_analysis.__annotations__.get(name, None),
)
arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)]
@ -679,7 +663,6 @@ class Tracer(TracerBase):
), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"
fn = getattr(type(root), self.traced_func_name)
self.root_module_name = root._get_name()
self.submodule_paths = {mod: name for name, mod in root.named_modules()}
else:
self.root = torch.nn.Module()

View File

@ -1,83 +1,17 @@
import dis
import copy
import torch
import inspect
import operator
import traceback
import collections
from .graph import magic_methods, reflectable_magic_methods, Graph
from typing import Tuple, Dict, OrderedDict, Optional, Iterable, Any, Iterator, Callable
from typing import Tuple, Dict, Optional, Iterable, Any, Iterator, Callable
from .node import Target, Node, Argument, base_types, map_aggregate
from ._compatibility import compatibility
from .operator_schemas import check_for_mutable_operation
import torch.fx.traceback as fx_traceback
__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError',
'Proxy', 'Attribute', 'ParameterProxy', 'Scope',
'ScopeContextManager']
@compatibility(is_backward_compatible=False)
class Scope(object):
""" Scope object that records the module path and the module type
of a module. Scope is used to track the information of the module
that contains a Node in a Graph of GraphModule. For example::
class Sub(torch.nn.Module):
def forward(self, x):
# This will be a call_method Node in GraphModule,
# scope for this would be (module_path="sub", module_type=Sub)
return x.transpose(1, 2)
class M(torch.nn.Module):
def __init__(self):
self.sub = Sub()
def forward(self, x):
# This will be a call_method Node as well,
# scope for this would be (module_path="", None)
x = x.transpose(1, 2)
x = self.sub(x)
return x
"""
def __init__(self, module_path: str, module_type: Any):
super().__init__()
self.module_path = module_path
self.module_type = module_type
@compatibility(is_backward_compatible=False)
class ScopeContextManager(object):
""" A context manager to track the Scope of Node during symbolic tracing.
When entering a forward function of a Module, we'll update the scope information of
the current module, and when we exit, we'll restore the previous scope information.
"""
def __init__(
self,
scope: Scope,
current_scope: Scope,
):
super().__init__()
# Keep a copy of prev scope to restore on exit
self._prev_scope = copy.copy(scope)
# Update scope to current scope
scope.module_path = current_scope.module_path
scope.module_type = current_scope.module_type
# Save a reference so we can restore it
self._scope = scope
def __enter__(self):
return self._scope
def __exit__(self, *args):
self._scope.module_path = self._prev_scope.module_path
self._scope.module_type = self._prev_scope.module_type
return
__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError', 'Proxy', 'Attribute', 'ParameterProxy']
@compatibility(is_backward_compatible=True)
class TracerBase:
@ -95,15 +29,6 @@ class TracerBase:
# ``root`` is an instance of ``nn.Module``
traced_func_name: str = "forward"
# Maps the containing module's name to the operator name
scope : Scope
# Records the module call stack
module_stack: OrderedDict[str, str]
# Mapping of node name to module scope
node_name_to_scope: Dict[str, Tuple[str, type]]
@compatibility(is_backward_compatible=True)
def create_node(self, kind : str, target : Target,
args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
@ -118,16 +43,7 @@ class TracerBase:
if kind == 'call_function' and self.check_mutable_operations:
check_for_mutable_operation(target, args, kwargs)
node = self.graph.create_node(kind, target, args, kwargs, name, type_expr)
# TODO node_name_to_scope will be depricated in favor of
# node.meta['nn_module_stack']
self.node_name_to_scope[node.name] = (
self.scope.module_path,
self.scope.module_type,
)
if self.module_stack:
node.meta['nn_module_stack'] = copy.copy(self.module_stack)
return node
return self.graph.create_node(kind, target, args, kwargs, name, type_expr)
@compatibility(is_backward_compatible=True)
def proxy(self, node: Node) -> 'Proxy':
@ -291,9 +207,6 @@ class GraphAppendingTracer(TracerBase):
def __init__(self, graph: Graph):
super().__init__()
self.graph = graph
self.scope = Scope("", None)
self.module_stack = collections.OrderedDict()
self.node_name_to_scope = {}
@compatibility(is_backward_compatible=False)
def assert_fn(x):

View File

@ -11,6 +11,8 @@ from torch.ao.quantization.quantize_fx import (
_check_is_graph_module,
_swap_ff_with_fxff,
_fuse_fx,
Scope,
ScopeContextManager,
QuantizationTracer,
_prepare_fx,
_prepare_standalone_module_fx,