Mirror of https://github.com/pytorch/pytorch.git
Revert "Reland "Add heirachical module names to torchFX graph.node" (#90205)"
This reverts commit 6b7efac3c9ea5c9fbfb18069abd254ad7d9a103e. Reverted https://github.com/pytorch/pytorch/pull/90205 on behalf of https://github.com/seemethere because it caused failures in internal systems; see https://fb.workplace.com/groups/802176577445480/posts/894284641568006 for discussion.
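For context, the change being reverted attached module-provenance metadata to FX nodes during symbolic tracing. Below is a minimal sketch, adapted from the test_nn_module_stack test removed in the test_fx.py hunk further down; with #90205 applied the call_module nodes carried a populated nn_module_stack, while after this revert the metadata is absent and the lookup falls back to an empty dict.

import torch
import torch.fx


class SubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_mod = torch.nn.Conv2d(64, 64, (3, 3), padding=1, bias=False)

    def forward(self, x):
        return self.conv_mod(x)


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub_mod = SubModule()

    def forward(self, x):
        return self.sub_mod(x)


gm = torch.fx.symbolic_trace(MyModule())
for node in gm.graph.nodes:
    # With the reverted change in place, call nodes carried an ordered mapping like
    # {'sub_mod': "<class '...SubModule'>", 'sub_mod.conv_mod': "<class '...Conv2d'>"};
    # after this revert, 'nn_module_stack' is no longer set and this prints {}.
    print(node.op, node.name, node.meta.get('nn_module_stack', {}))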
@@ -11,6 +11,8 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
             '_check_is_graph_module',
             '_swap_ff_with_fxff',
             '_fuse_fx',
+            'Scope',
+            'ScopeContextManager',
             'QuantizationTracer',
             '_prepare_fx',
             '_prepare_standalone_module_fx',
@@ -1679,36 +1679,6 @@ class TestFX(JitTestCase):
             if node.op in {'placeholder'}:
                 self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last_3d)
 
-    def test_nn_module_stack(self):
-        class SubModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.conv_mod = torch.nn.Conv2d(64, 64, (3, 3), padding=1, bias=False)
-
-            def forward(self, x):
-                return self.conv_mod(x)
-
-        class MyModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.sub_mod = SubModule()
-
-            def forward(self, x):
-                return self.sub_mod(x)
-
-        m = MyModule()
-        gm = torch.fx.symbolic_trace(m)
-
-        mod_stack = {}
-        expected_stack = [('sub_mod', str(type(m.sub_mod))),
-                          ('sub_mod.conv_mod', str(type(m.sub_mod.conv_mod)))]
-        for node in gm.graph.nodes:
-            mod_stack = node.meta.get('nn_module_stack', {})
-            if mod_stack:
-                break
-        stack_list = list(mod_stack.items())
-        self.assertEqual(stack_list, expected_stack)
-
     def test_interpreter(self):
         class MyModule(torch.nn.Module):
             def __init__(self):
@@ -1,13 +1,67 @@
 import torch
 from torch.fx._symbolic_trace import Tracer
-from torch.fx.proxy import Scope
+from torch.fx.node import Target, Node, Argument
 from torch.nn.intrinsic import _FusedModule
-from typing import List, Callable
+from typing import List, Callable, Tuple, Any, Dict, Optional
 
 __all__ = [
     "QuantizationTracer",
 ]
 
+class Scope(object):
+    """ Scope object that records the module path and the module type
+    of a module. Scope is used to track the information of the module
+    that contains a Node in a Graph of GraphModule. For example::
+
+        class Sub(torch.nn.Module):
+            def forward(self, x):
+                # This will be a call_method Node in GraphModule,
+                # scope for this would be (module_path="sub", module_type=Sub)
+                return x.transpose(1, 2)
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                self.sub = Sub()
+
+            def forward(self, x):
+                # This will be a call_method Node as well,
+                # scope for this would be (module_path="", None)
+                x = x.transpose(1, 2)
+                x = self.sub(x)
+                return x
+
+    """
+
+    def __init__(self, module_path: str, module_type: Any):
+        super().__init__()
+        self.module_path = module_path
+        self.module_type = module_type
+
+
+class ScopeContextManager(object):
+    """ A context manager to track the Scope of Node during symbolic tracing.
+    When entering a forward function of a Module, we'll update the scope information of
+    the current module, and when we exit, we'll restore the previous scope information.
+    """
+
+    def __init__(
+        self, scope: Scope, current_module: torch.nn.Module, current_module_path: str
+    ):
+        super().__init__()
+        self.prev_module_type = scope.module_type
+        self.prev_module_path = scope.module_path
+        self.scope = scope
+        self.scope.module_path = current_module_path
+        self.scope.module_type = type(current_module)
+
+    def __enter__(self):
+        return
+
+    def __exit__(self, *args):
+        self.scope.module_path = self.prev_module_path
+        self.scope.module_type = self.prev_module_type
+        return
+
 class QuantizationTracer(Tracer):
     def __init__(
         self, skipped_module_names: List[str], skipped_module_classes: List[Callable]
@@ -21,6 +75,7 @@ class QuantizationTracer(Tracer):
         # We can change this if there is a use case that configures
         # qconfig using top level module type
         self.scope = Scope("", None)
+        self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}
         self.record_stack_traces = True
 
     def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
@@ -33,3 +88,32 @@ class QuantizationTracer(Tracer):
             or type(m) in self.skipped_module_classes
             or isinstance(m, _FusedModule)
         )
+
+    def call_module(
+        self,
+        m: torch.nn.Module,
+        forward: Callable[..., Any],
+        args: Tuple[Any, ...],
+        kwargs: Dict[str, Any],
+    ) -> Any:
+        module_qualified_name = self.path_of_module(m)
+        # Creating scope with information of current module
+        # scope will be restored automatically upon exit
+        with ScopeContextManager(self.scope, m, module_qualified_name):
+            return super().call_module(m, forward, args, kwargs)
+
+    def create_node(
+        self,
+        kind: str,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: Optional[str] = None,
+        type_expr: Optional[Any] = None,
+    ) -> Node:
+        node = super().create_node(kind, target, args, kwargs, name, type_expr)
+        self.node_name_to_scope[node.name] = (
+            self.scope.module_path,
+            self.scope.module_type,
+        )
+        return node
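The two hunks above restore scope bookkeeping to QuantizationTracer itself: a Scope held on the tracer, updated by call_module and read by create_node. A hedged usage sketch follows; the import path mirrors the ao_migration test hunk at the end of this diff, and the toy Model is illustrative only.

import torch
from torch.ao.quantization.quantize_fx import QuantizationTracer


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


tracer = QuantizationTracer(skipped_module_names=[], skipped_module_classes=[])
graph = tracer.trace(Model())
# node_name_to_scope maps each node name to (module_path, module_type); the
# call_module node for self.linear should map to ('linear', torch.nn.Linear).
for name, (module_path, module_type) in tracer.node_name_to_scope.items():
    print(name, module_path, module_type)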
@@ -63,6 +63,61 @@ def _fuse_fx(
         graph_module, is_qat, fuse_custom_config, backend_config) # type: ignore[operator]
 
 
+class Scope(object):
+    """ Scope object that records the module path and the module type
+    of a module. Scope is used to track the information of the module
+    that contains a Node in a Graph of GraphModule. For example::
+
+        class Sub(torch.nn.Module):
+            def forward(self, x):
+                # This will be a call_method Node in GraphModule,
+                # scope for this would be (module_path="sub", module_type=Sub)
+                return x.transpose(1, 2)
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                self.sub = Sub()
+
+            def forward(self, x):
+                # This will be a call_method Node as well,
+                # scope for this would be (module_path="", None)
+                x = x.transpose(1, 2)
+                x = self.sub(x)
+                return x
+
+    """
+
+    def __init__(self, module_path: str, module_type: Any):
+        super().__init__()
+        self.module_path = module_path
+        self.module_type = module_type
+
+
+class ScopeContextManager(object):
+    """ A context manager to track the Scope of Node during symbolic tracing.
+    When entering a forward function of a Module, we'll update the scope information of
+    the current module, and when we exit, we'll restore the previous scope information.
+    """
+
+    def __init__(
+        self, scope: Scope, current_module: torch.nn.Module, current_module_path: str
+    ):
+        super().__init__()
+        self.prev_module_type = scope.module_type
+        self.prev_module_path = scope.module_path
+        self.scope = scope
+        self.scope.module_path = current_module_path
+        self.scope.module_type = type(current_module)
+
+    def __enter__(self):
+        return
+
+    def __exit__(self, *args):
+        self.scope.module_path = self.prev_module_path
+        self.scope.module_type = self.prev_module_type
+        return
+
+
 def _prepare_fx(
     model: torch.nn.Module,
     qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
@@ -5,7 +5,6 @@ import inspect
 import math
 import os
 import warnings
-import collections
 from itertools import chain
 from types import CodeType, FunctionType, ModuleType
 from typing import (
@@ -29,7 +28,7 @@ from ._compatibility import compatibility
 from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph
 from .graph_module import GraphModule
 from .node import Argument, base_types, map_aggregate
-from .proxy import ParameterProxy, Proxy, TracerBase, Scope, ScopeContextManager
+from .proxy import ParameterProxy, Proxy, TracerBase
 
 HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS
 
@@ -45,6 +44,7 @@ _is_fx_tracing_flag = False
 def is_fx_tracing():
     return _is_fx_tracing_flag
 
+
 @compatibility(is_backward_compatible=True)
 class ProxyableClassMeta(type):
     """
@@ -250,13 +250,6 @@ class Tracer(TracerBase):
         self.param_shapes_constant = param_shapes_constant
 
         self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None
-        self.root_module_name: str = ""
-        # Maps the containing module's name to the operator name
-        self.scope = Scope("", None)
-        # Records the module call stack
-        self.module_stack = collections.OrderedDict()
-        # Mapping of node name to module scope
-        self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}
 
     @compatibility(is_backward_compatible=True)
     def create_arg(self, a: Any) -> "Argument":
@@ -437,18 +430,9 @@ class Tracer(TracerBase):
             value was returned from the ``Module`` invocation.
         """
         module_qualified_name = self.path_of_module(m)
-        with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope:
-            # module_stack is an ordered dict so writing then deleting the
-            # entry is equivalent to push/pop on a list
-            self.module_stack[_scope.module_path] = str(_scope.module_type)
-            if not self.is_leaf_module(m, module_qualified_name):
-                ret_val = forward(*args, **kwargs)
-            else:
-                ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs)
-            key, _ = self.module_stack.popitem(last=True)
-            assert key == _scope.module_path, f" Unexpected key {key}"
-
-            return ret_val
+        if not self.is_leaf_module(m, module_qualified_name):
+            return forward(*args, **kwargs)
+        return self.create_proxy("call_module", module_qualified_name, args, kwargs)
 
     @compatibility(is_backward_compatible=False)
     def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
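The removed branch above leaned on an OrderedDict-as-stack idiom (its own comment notes that writing then deleting an entry is equivalent to push/pop on a list). A standalone sketch of that idiom, separate from any PyTorch code:

import collections

module_stack = collections.OrderedDict()

# "Push" scopes as tracing descends into submodules.
module_stack["sub_mod"] = "SubModule"
module_stack["sub_mod.conv_mod"] = "Conv2d"

# "Pop" in LIFO order on the way back out; popitem(last=True) removes the
# most recently inserted entry, exactly like popping a list used as a stack.
key, value = module_stack.popitem(last=True)
assert key == "sub_mod.conv_mod"
key, value = module_stack.popitem(last=True)
assert key == "sub_mod"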
@@ -596,7 +580,7 @@ class Tracer(TracerBase):
                 name,
                 default,
                 {},
-                type_expr=fn_for_analysis.__annotations__.get(name, None)
+                type_expr=fn_for_analysis.__annotations__.get(name, None),
             )
 
         arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)]
@@ -679,7 +663,6 @@ class Tracer(TracerBase):
             ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"
 
             fn = getattr(type(root), self.traced_func_name)
-            self.root_module_name = root._get_name()
             self.submodule_paths = {mod: name for name, mod in root.named_modules()}
         else:
             self.root = torch.nn.Module()
@@ -1,83 +1,17 @@
 import dis
-import copy
 import torch
 import inspect
 import operator
 import traceback
-import collections
 
 from .graph import magic_methods, reflectable_magic_methods, Graph
-from typing import Tuple, Dict, OrderedDict, Optional, Iterable, Any, Iterator, Callable
+from typing import Tuple, Dict, Optional, Iterable, Any, Iterator, Callable
 from .node import Target, Node, Argument, base_types, map_aggregate
 from ._compatibility import compatibility
 from .operator_schemas import check_for_mutable_operation
 import torch.fx.traceback as fx_traceback
 
-__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError',
-           'Proxy', 'Attribute', 'ParameterProxy', 'Scope',
-           'ScopeContextManager']
-
-
-@compatibility(is_backward_compatible=False)
-class Scope(object):
-    """ Scope object that records the module path and the module type
-    of a module. Scope is used to track the information of the module
-    that contains a Node in a Graph of GraphModule. For example::
-
-        class Sub(torch.nn.Module):
-            def forward(self, x):
-                # This will be a call_method Node in GraphModule,
-                # scope for this would be (module_path="sub", module_type=Sub)
-                return x.transpose(1, 2)
-
-        class M(torch.nn.Module):
-            def __init__(self):
-                self.sub = Sub()
-
-            def forward(self, x):
-                # This will be a call_method Node as well,
-                # scope for this would be (module_path="", None)
-                x = x.transpose(1, 2)
-                x = self.sub(x)
-                return x
-
-    """
-
-    def __init__(self, module_path: str, module_type: Any):
-        super().__init__()
-        self.module_path = module_path
-        self.module_type = module_type
-
-
-@compatibility(is_backward_compatible=False)
-class ScopeContextManager(object):
-    """ A context manager to track the Scope of Node during symbolic tracing.
-    When entering a forward function of a Module, we'll update the scope information of
-    the current module, and when we exit, we'll restore the previous scope information.
-    """
-
-    def __init__(
-        self,
-        scope: Scope,
-        current_scope: Scope,
-    ):
-        super().__init__()
-        # Keep a copy of prev scope to restore on exit
-        self._prev_scope = copy.copy(scope)
-        # Update scope to current scope
-        scope.module_path = current_scope.module_path
-        scope.module_type = current_scope.module_type
-        # Save a reference so we can restore it
-        self._scope = scope
-
-    def __enter__(self):
-        return self._scope
-
-    def __exit__(self, *args):
-        self._scope.module_path = self._prev_scope.module_path
-        self._scope.module_type = self._prev_scope.module_type
-        return
-
+__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError', 'Proxy', 'Attribute', 'ParameterProxy']
 
 @compatibility(is_backward_compatible=True)
 class TracerBase:
@@ -95,15 +29,6 @@ class TracerBase:
     # ``root`` is an instance of ``nn.Module``
     traced_func_name: str = "forward"
 
-    # Maps the containing module's name to the operator name
-    scope : Scope
-
-    # Records the module call stack
-    module_stack: OrderedDict[str, str]
-
-    # Mapping of node name to module scope
-    node_name_to_scope: Dict[str, Tuple[str, type]]
-
     @compatibility(is_backward_compatible=True)
     def create_node(self, kind : str, target : Target,
                     args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
@@ -118,16 +43,7 @@ class TracerBase:
         if kind == 'call_function' and self.check_mutable_operations:
             check_for_mutable_operation(target, args, kwargs)
 
-        node = self.graph.create_node(kind, target, args, kwargs, name, type_expr)
-        # TODO node_name_to_scope will be depricated in favor of
-        # node.meta['nn_module_stack']
-        self.node_name_to_scope[node.name] = (
-            self.scope.module_path,
-            self.scope.module_type,
-        )
-        if self.module_stack:
-            node.meta['nn_module_stack'] = copy.copy(self.module_stack)
-        return node
+        return self.graph.create_node(kind, target, args, kwargs, name, type_expr)
 
     @compatibility(is_backward_compatible=True)
     def proxy(self, node: Node) -> 'Proxy':
@@ -291,9 +207,6 @@ class GraphAppendingTracer(TracerBase):
     def __init__(self, graph: Graph):
        super().__init__()
        self.graph = graph
-       self.scope = Scope("", None)
-       self.module_stack = collections.OrderedDict()
-       self.node_name_to_scope = {}
 
 @compatibility(is_backward_compatible=False)
 def assert_fn(x):
@@ -11,6 +11,8 @@ from torch.ao.quantization.quantize_fx import (
     _check_is_graph_module,
     _swap_ff_with_fxff,
     _fuse_fx,
+    Scope,
+    ScopeContextManager,
     QuantizationTracer,
     _prepare_fx,
     _prepare_standalone_module_fx,