Revert "[DebugMode][2/N] add nn.Module tracking (#165498)"

This reverts commit 45afaf08a14ab760d86ea80dea6d50cec8626513.

Reverted https://github.com/pytorch/pytorch/pull/165498 on behalf of https://github.com/seemethere because the first part of the stack was reverted, so this one needs to be reverted as well ([comment](https://github.com/pytorch/pytorch/pull/165498#issuecomment-3416618198))
PyTorch MergeBot
2025-10-17 18:22:46 +00:00
parent ca5b7f8ded
commit b08d8c2e50
2 changed files with 1 addition and 84 deletions


@@ -330,46 +330,6 @@ class TestDTensorDebugMode(TestCase):
             f(x)
         self.assertEqual(len(debug_mode.debug_string()), 0)
 
-    def test_nn_module(self):
-        class Foo(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.l1 = torch.nn.Linear(4, 4)
-                self.l2 = torch.nn.Linear(4, 4)
-
-            def forward(self, x):
-                return self.l2(self.l1(x))
-
-        class Bar(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.abc = Foo()
-                self.xyz = torch.nn.Linear(4, 4)
-
-            def forward(self, x):
-                return self.xyz(self.abc(x))
-
-        mod = Bar()
-        inp = torch.randn(4, 4)
-        with DebugMode(record_nn_module=True) as debug_mode:
-            _ = mod(inp)
-
-        self.assertExpectedInline(
-            debug_mode.debug_string(),
-            """\
-    [nn.Mod] Bar
-      [nn.Mod] Bar.abc
-        [nn.Mod] Bar.abc.l1
-          aten::t(t: f32[4, 4])
-          aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
-        [nn.Mod] Bar.abc.l2
-          aten::t(t: f32[4, 4])
-          aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
-      [nn.Mod] Bar.xyz
-        aten::t(t: f32[4, 4])
-        aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])""",
-        )
-
 
 instantiate_parametrized_tests(TestDTensorDebugMode)


@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import contextlib
-from typing import Optional, TYPE_CHECKING
+from typing import Optional
 
 import torch
 from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@@ -13,10 +13,6 @@ from torch.utils._python_dispatch import (
 from torch.utils._pytree import tree_map
 
-if TYPE_CHECKING:
-    from torch.distributed._tools.mod_tracker import ModTracker
-
-
 __all__ = ["DebugMode", "get_active_debug_mode"]
 
 REDISTRIBUTE_FUNC = "redistribute_input"
@@ -110,17 +106,6 @@ def _op_to_str(op, attributes, *args, **kwargs) -> str:
     return f"{op_name}({args_str}{kwargs_str})"
 
 
-class _NNModuleCall(_DebugCall):
-    """Designates entering an nn.Module's forward method"""
-
-    def __init__(self, module_name: str, call_depth: int):
-        super().__init__(call_depth)
-        self.module_name = module_name
-
-    def render(self, attributes: list[str]) -> str:
-        return f"[nn.Mod] {self.module_name}"
-
-
 class DebugMode(TorchDispatchMode):
     def __init__(
         self,
@@ -129,7 +114,6 @@ class DebugMode(TorchDispatchMode):
         record_faketensor=False,
         record_realtensor=True,
         record_tensor_attributes=None,
-        record_nn_module=False,
     ):
         super().__init__()
         import torch.distributed.tensor  # noqa: F401
@@ -140,12 +124,6 @@ class DebugMode(TorchDispatchMode):
         self.record_realtensor = record_realtensor
         self.record_tensor_attributes = record_tensor_attributes or []
-        self.record_nn_module = record_nn_module
-
-        self.module_tracker: Optional[ModTracker] = None
-        if self.record_nn_module:
-            self.module_tracker_setup()
-
 
         self.operators = []
         self.call_depth = 0
@@ -198,35 +176,14 @@ class DebugMode(TorchDispatchMode):
             torch._C._push_on_torch_function_stack(self)
         super().__enter__()
-        if self.record_nn_module:
-            self.module_tracker.__enter__()  # type: ignore[attribute, union-attr]
         return self
 
     # pyrefly: ignore  # bad-override
     def __exit__(self, *args):
         super().__exit__(*args)
-        if self.record_nn_module:
-            self.module_tracker.__exit__()  # type: ignore[attribute, union-attr]
         if self.record_torchfunction:
             torch._C._pop_torch_function_stack()
 
-    def module_tracker_setup(self):
-        from torch.distributed._tools.mod_tracker import ModTracker
-
-        self.module_tracker = ModTracker()
-
-        # module pre-fw hook: record module call
-        def pre_fw_hook(module, input):
-            fqn = self.module_tracker._get_mod_name(module)  # type: ignore[attribute, union-attr]
-            self.operators.append(_NNModuleCall(fqn, self.call_depth + 1))
-            self.call_depth += 1
-
-        # module post-fw hook: decrement call depth
-        def post_fw_hook(module, input, output):
-            self.call_depth -= 1
-
-        self.module_tracker.register_user_hooks(pre_fw_hook, post_fw_hook)
-
     @contextlib.contextmanager
     def record_redistribute_calls(
         self,
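
For context on what this revert removes: the deleted module_tracker_setup above wires DebugMode into ModTracker (torch.distributed._tools.mod_tracker) via register_user_hooks, recording an entry for every nn.Module forward call. A minimal standalone sketch of that same hook pattern, using only the calls visible in the deleted code, is below; the model and input here are placeholders, and _get_mod_name is the private helper the reverted code relied on.

    import torch
    from torch.distributed._tools.mod_tracker import ModTracker

    tracker = ModTracker()

    def pre_fw_hook(module, input):
        # Runs before each nn.Module forward; the tracker supplies the FQN.
        print("enter", tracker._get_mod_name(module))

    def post_fw_hook(module, input, output):
        # Runs after each nn.Module forward.
        print("exit", tracker._get_mod_name(module))

    tracker.register_user_hooks(pre_fw_hook, post_fw_hook)

    # Placeholder model/input purely for illustration.
    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    with tracker:
        model(torch.randn(2, 4))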