diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py
index 20da99f52eb0..aab91ddebe94 100644
--- a/test/distributed/tensor/debug/test_debug_mode.py
+++ b/test/distributed/tensor/debug/test_debug_mode.py
@@ -330,46 +330,6 @@ class TestDTensorDebugMode(TestCase):
             f(x)
         self.assertEqual(len(debug_mode.debug_string()), 0)
 
-    def test_nn_module(self):
-        class Foo(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.l1 = torch.nn.Linear(4, 4)
-                self.l2 = torch.nn.Linear(4, 4)
-
-            def forward(self, x):
-                return self.l2(self.l1(x))
-
-        class Bar(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.abc = Foo()
-                self.xyz = torch.nn.Linear(4, 4)
-
-            def forward(self, x):
-                return self.xyz(self.abc(x))
-
-        mod = Bar()
-        inp = torch.randn(4, 4)
-        with DebugMode(record_nn_module=True) as debug_mode:
-            _ = mod(inp)
-
-        self.assertExpectedInline(
-            debug_mode.debug_string(),
-            """\
-    [nn.Mod] Bar
-      [nn.Mod] Bar.abc
-        [nn.Mod] Bar.abc.l1
-          aten::t(t: f32[4, 4])
-          aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
-        [nn.Mod] Bar.abc.l2
-          aten::t(t: f32[4, 4])
-          aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
-      [nn.Mod] Bar.xyz
-        aten::t(t: f32[4, 4])
-        aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])""",
-        )
-
 
 instantiate_parametrized_tests(TestDTensorDebugMode)
 
diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py
index 1986828c519b..7f7de2b7334f 100644
--- a/torch/utils/_debug_mode.py
+++ b/torch/utils/_debug_mode.py
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import contextlib
-from typing import Optional, TYPE_CHECKING
+from typing import Optional
 
 import torch
 from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@@ -13,10 +13,6 @@ from torch.utils._python_dispatch import (
 from torch.utils._pytree import tree_map
 
 
-if TYPE_CHECKING:
-    from torch.distributed._tools.mod_tracker import ModTracker
-
-
 __all__ = ["DebugMode", "get_active_debug_mode"]
 
 REDISTRIBUTE_FUNC = "redistribute_input"
@@ -110,17 +106,6 @@ def _op_to_str(op, attributes, *args, **kwargs) -> str:
     return f"{op_name}({args_str}{kwargs_str})"
 
 
-class _NNModuleCall(_DebugCall):
-    """Designates entering an nn.Module's forward method"""
-
-    def __init__(self, module_name: str, call_depth: int):
-        super().__init__(call_depth)
-        self.module_name = module_name
-
-    def render(self, attributes: list[str]) -> str:
-        return f"[nn.Mod] {self.module_name}"
-
-
 class DebugMode(TorchDispatchMode):
     def __init__(
         self,
@@ -129,7 +114,6 @@ class DebugMode(TorchDispatchMode):
         record_faketensor=False,
         record_realtensor=True,
         record_tensor_attributes=None,
-        record_nn_module=False,
     ):
         super().__init__()
         import torch.distributed.tensor  # noqa: F401
@@ -140,12 +124,6 @@ class DebugMode(TorchDispatchMode):
         self.record_realtensor = record_realtensor
         self.record_tensor_attributes = record_tensor_attributes or []
 
-        self.record_nn_module = record_nn_module
-
-        self.module_tracker: Optional[ModTracker] = None
-        if self.record_nn_module:
-            self.module_tracker_setup()
-
         self.operators = []
         self.call_depth = 0
 
@@ -198,35 +176,14 @@ class DebugMode(TorchDispatchMode):
             torch._C._push_on_torch_function_stack(self)
 
         super().__enter__()
-        if self.record_nn_module:
-            self.module_tracker.__enter__()  # type: ignore[attribute, union-attr]
         return self
 
     # pyrefly: ignore  # bad-override
     def __exit__(self, *args):
         super().__exit__(*args)
-        if self.record_nn_module:
-            self.module_tracker.__exit__()  # type: ignore[attribute, union-attr]
         if self.record_torchfunction:
             torch._C._pop_torch_function_stack()
 
-    def module_tracker_setup(self):
-        from torch.distributed._tools.mod_tracker import ModTracker
-
-        self.module_tracker = ModTracker()
-
-        # module pre-fw hook: record module call
-        def pre_fw_hook(module, input):
-            fqn = self.module_tracker._get_mod_name(module)  # type: ignore[attribute, union-attr]
-            self.operators.append(_NNModuleCall(fqn, self.call_depth + 1))
-            self.call_depth += 1
-
-        # module post-fw hook: decrement call depth
-        def post_fw_hook(module, input, output):
-            self.call_depth -= 1
-
-        self.module_tracker.register_user_hooks(pre_fw_hook, post_fw_hook)
-
     @contextlib.contextmanager
     def record_redistribute_calls(
         self,