mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Revert "[DebugMode][2/N] add nn.Module tracking (#165498)"
This reverts commit 45afaf08a14ab760d86ea80dea6d50cec8626513. Reverted https://github.com/pytorch/pytorch/pull/165498 on behalf of https://github.com/seemethere due to First part of the stack was reverted so will need to revert this too ([comment](https://github.com/pytorch/pytorch/pull/165498#issuecomment-3416618198))
This commit is contained in:
@ -330,46 +330,6 @@ class TestDTensorDebugMode(TestCase):
|
||||
f(x)
|
||||
self.assertEqual(len(debug_mode.debug_string()), 0)
|
||||
|
||||
def test_nn_module(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.l1 = torch.nn.Linear(4, 4)
|
||||
self.l2 = torch.nn.Linear(4, 4)
|
||||
|
||||
def forward(self, x):
|
||||
return self.l2(self.l1(x))
|
||||
|
||||
class Bar(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.abc = Foo()
|
||||
self.xyz = torch.nn.Linear(4, 4)
|
||||
|
||||
def forward(self, x):
|
||||
return self.xyz(self.abc(x))
|
||||
|
||||
mod = Bar()
|
||||
inp = torch.randn(4, 4)
|
||||
with DebugMode(record_nn_module=True) as debug_mode:
|
||||
_ = mod(inp)
|
||||
|
||||
self.assertExpectedInline(
|
||||
debug_mode.debug_string(),
|
||||
"""\
|
||||
[nn.Mod] Bar
|
||||
[nn.Mod] Bar.abc
|
||||
[nn.Mod] Bar.abc.l1
|
||||
aten::t(t: f32[4, 4])
|
||||
aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
|
||||
[nn.Mod] Bar.abc.l2
|
||||
aten::t(t: f32[4, 4])
|
||||
aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
|
||||
[nn.Mod] Bar.xyz
|
||||
aten::t(t: f32[4, 4])
|
||||
aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])""",
|
||||
)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestDTensorDebugMode)
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
|
||||
@ -13,10 +13,6 @@ from torch.utils._python_dispatch import (
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.distributed._tools.mod_tracker import ModTracker
|
||||
|
||||
|
||||
__all__ = ["DebugMode", "get_active_debug_mode"]
|
||||
|
||||
REDISTRIBUTE_FUNC = "redistribute_input"
|
||||
@ -110,17 +106,6 @@ def _op_to_str(op, attributes, *args, **kwargs) -> str:
|
||||
return f"{op_name}({args_str}{kwargs_str})"
|
||||
|
||||
|
||||
class _NNModuleCall(_DebugCall):
|
||||
"""Designates entering an nn.Module's forward method"""
|
||||
|
||||
def __init__(self, module_name: str, call_depth: int):
|
||||
super().__init__(call_depth)
|
||||
self.module_name = module_name
|
||||
|
||||
def render(self, attributes: list[str]) -> str:
|
||||
return f"[nn.Mod] {self.module_name}"
|
||||
|
||||
|
||||
class DebugMode(TorchDispatchMode):
|
||||
def __init__(
|
||||
self,
|
||||
@ -129,7 +114,6 @@ class DebugMode(TorchDispatchMode):
|
||||
record_faketensor=False,
|
||||
record_realtensor=True,
|
||||
record_tensor_attributes=None,
|
||||
record_nn_module=False,
|
||||
):
|
||||
super().__init__()
|
||||
import torch.distributed.tensor # noqa: F401
|
||||
@ -140,12 +124,6 @@ class DebugMode(TorchDispatchMode):
|
||||
self.record_realtensor = record_realtensor
|
||||
self.record_tensor_attributes = record_tensor_attributes or []
|
||||
|
||||
self.record_nn_module = record_nn_module
|
||||
|
||||
self.module_tracker: Optional[ModTracker] = None
|
||||
if self.record_nn_module:
|
||||
self.module_tracker_setup()
|
||||
|
||||
self.operators = []
|
||||
self.call_depth = 0
|
||||
|
||||
@ -198,35 +176,14 @@ class DebugMode(TorchDispatchMode):
|
||||
torch._C._push_on_torch_function_stack(self)
|
||||
|
||||
super().__enter__()
|
||||
if self.record_nn_module:
|
||||
self.module_tracker.__enter__() # type: ignore[attribute, union-attr]
|
||||
return self
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
def __exit__(self, *args):
|
||||
super().__exit__(*args)
|
||||
if self.record_nn_module:
|
||||
self.module_tracker.__exit__() # type: ignore[attribute, union-attr]
|
||||
if self.record_torchfunction:
|
||||
torch._C._pop_torch_function_stack()
|
||||
|
||||
def module_tracker_setup(self):
|
||||
from torch.distributed._tools.mod_tracker import ModTracker
|
||||
|
||||
self.module_tracker = ModTracker()
|
||||
|
||||
# module pre-fw hook: record module call
|
||||
def pre_fw_hook(module, input):
|
||||
fqn = self.module_tracker._get_mod_name(module) # type: ignore[attribute, union-attr]
|
||||
self.operators.append(_NNModuleCall(fqn, self.call_depth + 1))
|
||||
self.call_depth += 1
|
||||
|
||||
# module post-fw hook: decrement call depth
|
||||
def post_fw_hook(module, input, output):
|
||||
self.call_depth -= 1
|
||||
|
||||
self.module_tracker.register_user_hooks(pre_fw_hook, post_fw_hook)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def record_redistribute_calls(
|
||||
self,
|
||||
|
Reference in New Issue
Block a user