Compare commits

...

12 Commits

Author SHA1 Message Date
a46360061f Lazy import to avoid circular import issue 2025-09-19 15:47:03 -07:00
2cb3744b19 DisableTorchFunction in debug_string 2025-09-16 13:00:05 -07:00
28d63dab1c fix 2025-09-15 20:28:14 -07:00
cd65a1777e fix test 2025-09-15 20:27:27 -07:00
a715282154 make the test cuda only 2025-09-15 14:54:28 -07:00
0b042565b4 fix test 2025-09-15 14:43:09 -07:00
7333340b12 address comments 2025-09-15 14:43:09 -07:00
6e3d2bf02f add doc 2025-09-15 14:43:09 -07:00
57d563e5bd add test case, fix tests 2025-09-15 14:43:09 -07:00
0e6fd3dc05 fix lint 2025-09-15 14:43:09 -07:00
60ac912f05 Refactored as genearl purpose DebugMode 2025-09-15 14:43:09 -07:00
cc22fefef8 DTensorDebugMode 2025-09-15 14:43:09 -07:00
5 changed files with 440 additions and 3 deletions


@@ -3180,6 +3180,8 @@ coverage_ignore_classes = [
    "WeakIdKeyDictionary",
    "WeakIdRef",
    "WeakTensorKeyDictionary",
    # torch.utils.debug_mode
    "DebugMode",
]
# The suffix(es) of source filenames.


@@ -78,6 +78,7 @@ for tracking purposes -->
.. py:module:: torch.utils.data.graph
.. py:module:: torch.utils.data.graph_settings
.. py:module:: torch.utils.data.sampler
.. py:module:: torch.utils.debug_mode
.. py:module:: torch.utils.dlpack
.. py:module:: torch.utils.file_baton
.. py:module:: torch.utils.flop_counter


@@ -0,0 +1,248 @@
# Owner(s): ["oncall: distributed"]
import contextlib

import torch
import torch.distributed as dist
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    requires_cuda,
    run_tests,
    TestCase,
)
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.debug_mode import DebugMode


@requires_cuda
class TestDTensorDebugMode(TestCase):
    def tearDown(self):
        super().tearDown()
        dist.destroy_process_group()

    def setUp(self):
        super().setUp()
        self.world_size = 8

        store = FakeStore()
        dist.init_process_group(
            backend="fake", rank=0, world_size=self.world_size, store=store
        )
        self.device_type = "cuda"

    def test_debug_mode_mm(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        x = torch.randn(1, 8, requires_grad=False)
        y = torch.randn(1, 32, requires_grad=True)
        x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
        y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False)

        with DebugMode() as debug_mode:
            torch.mm(x_dtensor, y_dtensor).sum()

        self.assertExpectedInline(
            debug_mode.debug_string(),
            """\
torch.mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
aten::mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
redistribute_input(1, [S(0)] -> [R])
_c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0)
_c10d_functional::wait_tensor(t: f32[8, 32])
aten::mm(t: f32[1, 8], t: f32[8, 32])
<method 'sum' of 'torch._C.TensorBase' objects>(dt: f32[8, 32][S(0)])
aten::sum(dt: f32[8, 32][S(0)])
aten::sum(t: f32[1, 32])""",
        )

    def test_debug_string_inside_context(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        x = torch.randn(1, 8, requires_grad=False)
        y = torch.randn(1, 32, requires_grad=True)
        x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
        y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False)

        with DebugMode() as debug_mode:
            torch.mm(x_dtensor, y_dtensor).sum()
            s0 = debug_mode.debug_string()

        s1 = debug_mode.debug_string()
        self.assertEqual(s0, s1)

    def test_debug_mode_backward(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        x = torch.randn(1, 8, requires_grad=True)
        y = torch.randn(8, 1, requires_grad=True)
        x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
        y_dtensor = DTensor.from_local(y, mesh, [Shard(1)], run_check=False)

        with DebugMode() as debug_mode:
            z = x_dtensor + y_dtensor
            z.sum().backward()

        self.assertExpectedInline(
            debug_mode.debug_string(),
            """\
<method 'add' of 'torch._C.TensorBase' objects>(dt: f32[8, 8][S(0)], dt: f32[8, 8][S(1)])
aten::add.Tensor(dt: f32[8, 8][S(0)], dt: f32[8, 8][S(1)])
redistribute_input(1, [S(1)] -> [S(0)])
_dtensor::shard_dim_alltoall(t: f32[8, 1], 1, 0, 0)
aten::add.Tensor(t: f32[1, 8], t: f32[1, 8])
<method 'sum' of 'torch._C.TensorBase' objects>(dt: f32[8, 8][S(0)])
aten::sum(dt: f32[8, 8][S(0)])
aten::sum(t: f32[1, 8])
torch._tensor.backward(dt: f32[][P], gradient=None, retain_graph=None, create_graph=False, inputs=None)
aten::ones_like(dt: f32[][P], pin_memory=False, memory_format=torch.preserve_format)
aten::ones_like(t: f32[], pin_memory=False, memory_format=torch.preserve_format)
aten::expand(dt: f32[][R], [8, 8])
aten::expand(t: f32[], [8, 8])
aten::split.Tensor(t: f32[8, 8], 1, 1)
aten::clone(t: f32[8, 1])
aten::_to_copy(t: f32[8, 1], dtype=torch.float32, layout=torch.strided, device=cpu)
aten::detach(t: f32[8, 1])
aten::split.Tensor(t: f32[8, 8], 1)
aten::clone(t: f32[1, 8])
aten::_to_copy(t: f32[1, 8], dtype=torch.float32, layout=torch.strided, device=cpu)
aten::detach(t: f32[1, 8])""",
        )

    def test_debug_mode_einsum(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).view(4, 2))

        # Create test tensors
        a = torch.randn(16, 6, 8)
        b = torch.randn(8, 4, 4)
        a_dt = DTensor.from_local(a, mesh, [Partial(), Replicate()], run_check=False)
        b_dt = DTensor.from_local(b, mesh, [Replicate(), Partial()], run_check=False)

        # Capture the operator decomposition
        with DebugMode() as debug_mode:
            torch.einsum("bld,dnh->blnh", a_dt, b_dt)

        self.assertExpectedInline(
            debug_mode.debug_string(),
            """\
torch.functional.einsum(bld,dnh->blnh, dt: f32[16, 6, 8][P, R], dt: f32[8, 4, 4][R, P])
aten::unsqueeze(dt: f32[16, 6, 8][P, R], 3)
aten::unsqueeze(t: f32[16, 6, 8], 3)
aten::unsqueeze(dt: f32[16, 6, 8, 1][P, R], 4)
aten::unsqueeze(t: f32[16, 6, 8, 1], 4)
aten::permute(dt: f32[16, 6, 8, 1, 1][P, R], [0, 1, 3, 4, 2])
aten::permute(t: f32[16, 6, 8, 1, 1], [0, 1, 3, 4, 2])
aten::unsqueeze(dt: f32[8, 4, 4][R, P], 3)
aten::unsqueeze(t: f32[8, 4, 4], 3)
aten::unsqueeze(dt: f32[8, 4, 4, 1][R, P], 4)
aten::unsqueeze(t: f32[8, 4, 4, 1], 4)
aten::permute(dt: f32[8, 4, 4, 1, 1][R, P], [3, 4, 1, 2, 0])
aten::permute(t: f32[8, 4, 4, 1, 1], [3, 4, 1, 2, 0])
aten::permute(dt: f32[16, 6, 1, 1, 8][P, R], [0, 1, 4, 2, 3])
aten::permute(t: f32[16, 6, 1, 1, 8], [0, 1, 4, 2, 3])
aten::view(dt: f32[16, 6, 8, 1, 1][P, R], [1, 96, 8])
aten::view(t: f32[16, 6, 8, 1, 1], [1, 96, 8])
aten::permute(dt: f32[1, 1, 4, 4, 8][R, P], [4, 2, 3, 0, 1])
aten::permute(t: f32[1, 1, 4, 4, 8], [4, 2, 3, 0, 1])
aten::view(dt: f32[8, 4, 4, 1, 1][R, P], [1, 8, 16])
aten::view(t: f32[8, 4, 4, 1, 1], [1, 8, 16])
aten::bmm(dt: f32[1, 96, 8][P, R], dt: f32[1, 8, 16][R, P])
redistribute_input(0, [P, R] -> [S(2), S(2)])
aten::chunk(t: f32[1, 96, 8], 4, 2)
aten::cat(['t: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]'])
_c10d_functional::reduce_scatter_tensor(t: f32[4, 96, 2], sum, 4, 1)
aten::clone(t: f32[1, 96, 1])
redistribute_input(1, [R, P] -> [S(1), S(1)])
aten::chunk(t: f32[1, 8, 16], 4, 1)
aten::clone(t: f32[1, 2, 16])
aten::chunk(t: f32[1, 2, 16], 2, 1)
aten::cat(['t: f32[1, 1, 16]', 't: f32[1, 1, 16]'])
_c10d_functional::reduce_scatter_tensor(t: f32[2, 1, 16], sum, 2, 3)
_c10d_functional::wait_tensor(t: f32[1, 1, 16])
aten::bmm(t: f32[1, 96, 1], t: f32[1, 1, 16])
aten::view(dt: f32[1, 96, 16][P, P], [16, 6, 1, 4, 4])
aten::view(t: f32[1, 96, 16], [16, 6, 1, 4, 4])
aten::permute(dt: f32[16, 6, 1, 4, 4][P, P], [0, 1, 3, 4, 2])
aten::permute(t: f32[16, 6, 1, 4, 4], [0, 1, 3, 4, 2])
aten::view(dt: f32[16, 6, 4, 4, 1][P, P], [16, 6, 4, 4])
aten::view(t: f32[16, 6, 4, 4, 1], [16, 6, 4, 4])""",
        )

    def test_real_tensor(self):
        x = torch.randn(8, 8, 8)
        linear = torch.nn.Linear(8, 8)

        with DebugMode() as debug_mode:
            linear(x).sum()

        self.assertExpectedInline(
            debug_mode.debug_string(),
            """\
torch._C._nn.linear(t: f32[8, 8, 8], t: f32[8, 8], t: f32[8])
aten::view(t: f32[8, 8, 8], [64, 8])
aten::t(t: f32[8, 8])
aten::addmm(t: f32[8], t: f32[64, 8], t: f32[8, 8])
aten::view(t: f32[64, 8], [8, 8, 8])
<method 'sum' of 'torch._C.TensorBase' objects>(t: f32[8, 8, 8])
aten::sum(t: f32[8, 8, 8])""",
        )

    def test_fake_tensor(self):
        with FakeTensorMode():
            x = torch.randn(8, 8)
            y = torch.randn(8, 8, 8)

            with DebugMode(record_faketensor=True) as debug_mode:
                torch.matmul(y, x)

        self.assertExpectedInline(
            debug_mode.debug_string(),
            """\
torch.matmul(ft: f32[8, 8, 8], ft: f32[8, 8])
aten::view(ft: f32[8, 8, 8], [64, 8])
aten::mm(ft: f32[64, 8], ft: f32[8, 8])
aten::_unsafe_view(ft: f32[64, 8], [8, 8, 8])""",
        )

    @parametrize("has_inner_mode", [True, False])
    @parametrize("has_outer_mode", [True, False])
    def test_nested_debug_mode(self, has_inner_mode, has_outer_mode):
        class DummyTorchDispatchMode1(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                return func(*args, **kwargs)

        class DummyTorchDispatchMode2(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                return func(*args, **kwargs)

        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        x = torch.randn(1, 8, requires_grad=True)
        y = torch.randn(1, 32, requires_grad=True)
        x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
        y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False)

        inner_mode = (
            DummyTorchDispatchMode1() if has_inner_mode else contextlib.nullcontext()
        )
        outer_mode = (
            DummyTorchDispatchMode2() if has_outer_mode else contextlib.nullcontext()
        )

        with outer_mode:
            with DebugMode() as debug_mode:
                with inner_mode:
                    torch.mm(x_dtensor, y_dtensor)

        self.assertTrue(
            "redistribute_input(1, [S(0)] -> [R])" in debug_mode.debug_string()
        )


instantiate_parametrized_tests(TestDTensorDebugMode)


if __name__ == "__main__":
    run_tests()


@@ -23,7 +23,11 @@ from torch.distributed.tensor._tp_conv import (
)
from torch.distributed.tensor._utils import try_find_mesh_from_args
from torch.distributed.tensor.placement_types import Partial, Placement, Replicate
from torch.utils._python_dispatch import return_and_correct_aliasing
from torch.utils._python_dispatch import (
    _get_current_dispatch_mode,
    return_and_correct_aliasing,
)
from torch.utils.debug_mode import DebugMode


try:
@@ -334,6 +338,9 @@ class OpDispatcher:
        suggested_input_schema: OpSchema,
        use_val_from_redistribute_schema: bool,
    ) -> None:
        debug_mode = _get_current_dispatch_mode()
        in_debug_mode = isinstance(debug_mode, DebugMode)

        # NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it
        if op_info.args_tree_spec is not None:
            flatten_args_schema_to_reshard = tuple(
@@ -348,9 +355,18 @@ class OpDispatcher:
            if isinstance(arg_spec, DTensorSpec):
                local_tensor = cast(torch.Tensor, op_info.local_args[i])
                if arg_spec != reshard_arg_spec:
                    resharded_local_tensor = redistribute_local_tensor(
                        local_tensor, arg_spec, reshard_arg_spec
                    redistribute_context = (
                        debug_mode.record_redistribute_calls(
                            i, arg_spec, reshard_arg_spec
                        )
                        if in_debug_mode
                        else contextlib.nullcontext()
                    )
                    with redistribute_context:
                        resharded_local_tensor = redistribute_local_tensor(
                            local_tensor, arg_spec, reshard_arg_spec
                        )
                    new_local_args.append(resharded_local_tensor)
                else:
                    new_local_args.append(local_tensor)

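The hunk above enters DebugMode.record_redistribute_calls() around each reshard only when a DebugMode is the current dispatch mode, and otherwise falls back to contextlib.nullcontext() so the non-debug path is unchanged. A minimal, self-contained sketch of that gating pattern (the Recorder class and do_work function are illustrative stand-ins, not part of this PR):

import contextlib

class Recorder:
    def __init__(self):
        self.calls = []

    @contextlib.contextmanager
    def record(self, tag):
        # Log the tag, then let the wrapped work run inside this context.
        self.calls.append(tag)
        yield

def do_work(recorder=None):
    # Enter the recorder's context only when one is active; otherwise use a no-op context.
    ctx = recorder.record("redistribute") if recorder is not None else contextlib.nullcontext()
    with ctx:
        return sum(range(4))  # stands in for redistribute_local_tensor(...)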
torch/utils/debug_mode.py Normal file (170 lines added)

@@ -0,0 +1,170 @@
# mypy: allow-untyped-defs
import contextlib

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.utils._dtype_abbrs import dtype_abbrs
from torch.utils._python_dispatch import _get_current_dispatch_mode, TorchDispatchMode
from torch.utils._pytree import tree_map


__all__ = ["DebugMode"]

REDISTRIBUTE_FUNC = "redistribute_input"


def _stringify_shape(shape) -> str:
    return f"[{', '.join([str(x) for x in shape])}]"


def _stringify_device_mesh(mesh) -> str:
    return f"DM({', '.join([str(s) for s in mesh.shape])})"


def _stringify_placement(placement) -> str:
    return f"[{', '.join([str(p) for p in placement])}]"


def _tensor_debug_string(tensor) -> str:
    """Convert tensor to debug string representation."""
    if isinstance(tensor, torch.distributed.tensor.DTensor):
        # omitted device mesh
        return f"dt: {dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_placement(tensor.placements)}"
    elif isinstance(tensor, FakeTensor):
        return f"ft: {dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}"
    elif isinstance(tensor, torch.Tensor):
        return f"t: {dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}"
    else:
        raise RuntimeError(f"Unsupported tensor type: {type(tensor)}")


def _arg_to_str(arg) -> str:
    from torch.distributed.tensor._dtensor_spec import DTensorSpec

    def to_str(x):
        if isinstance(x, torch.Tensor):
            return _tensor_debug_string(x)
        elif isinstance(x, DTensorSpec):
            return _stringify_placement(x.placements)
        return x

    arg = tree_map(to_str, arg)
    return str(arg)


def _op_to_str(op, *args, **kwargs) -> str:
    if op == REDISTRIBUTE_FUNC:
        assert len(args) == 3
        _args = [_arg_to_str(arg) for arg in args]
        args_str = f"{_args[0]}, {_args[1]} -> {_args[2]}"
    else:
        args_str = ", ".join(_arg_to_str(arg) for arg in args)

    if kwargs:
        kwargs_str = ", " + ", ".join(
            f"{k}={_arg_to_str(v)}" for k, v in kwargs.items()
        )
    else:
        kwargs_str = ""

    if isinstance(op, torch._ops.OpOverload):
        op_name = op.__qualname__
    elif hasattr(op, "__module__") and hasattr(op, "__name__"):
        op_name = f"{op.__module__}.{op.__name__}"
    else:
        op_name = str(op)

    return f"{op_name}({args_str}{kwargs_str})"


class DebugMode(TorchDispatchMode):
    def __init__(
        self,
        *,
        record_torchfunction=True,
        record_faketensor=False,
        record_realtensor=True,
    ):
        super().__init__()
        import torch.distributed.tensor  # noqa: F401

        self.record_torchfunction = record_torchfunction
        self.record_faketensor = record_faketensor
        self.record_realtensor = record_realtensor

        self.operators = []
        self.call_depth = 0

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        self.operators.append((func, args, kwargs, self.call_depth))

        try:
            self.call_depth += 1
            return func(*args, **kwargs)
        finally:
            self.call_depth -= 1

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # Record the operation with its call depth
        if torch.distributed.tensor.DTensor in types:
            self.operators.append((func, args, kwargs, self.call_depth))
            return NotImplemented
        elif FakeTensor in types or isinstance(
            _get_current_dispatch_mode(), FakeTensorMode
        ):
            if self.record_faketensor:
                if func not in {torch.ops.prim.device.default}:
                    self.operators.append((func, args, kwargs, self.call_depth + 1))
        elif len(types) == 0:
            if self.record_realtensor:
                self.operators.append((func, args, kwargs, self.call_depth + 1))

        result = func(*args, **kwargs)
        return result

    def __enter__(self):
        self.operators = []
        self.call_depth = 0

        if self.record_torchfunction:
            torch._C._push_on_torch_function_stack(self)

        super().__enter__()
        return self

    def __exit__(self, *args):
        super().__exit__(*args)
        if self.record_torchfunction:
            torch._C._pop_torch_function_stack()

    @contextlib.contextmanager
    def record_redistribute_calls(self, arg_idx, src_placement, dst_placement):
        try:
            self.operators.append(
                (
                    REDISTRIBUTE_FUNC,
                    [arg_idx, src_placement, dst_placement],
                    {},
                    self.call_depth + 1,
                )
            )
            self.call_depth += 1
            yield
        finally:
            self.call_depth -= 1

    def debug_string(self) -> str:
        with torch._C.DisableTorchFunction():
            result = ""
            result += "\n".join(
                " " + " " * depth + _op_to_str(op, *args, **kwargs)
                for op, args, kwargs, depth in self.operators
            )
        return result
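For reference, a minimal usage sketch of the new DebugMode API, mirroring the test_real_tensor case above (plain tensors, default settings); the exact printed tree is simply whatever debug_string() recorded:

import torch
from torch.utils.debug_mode import DebugMode

x = torch.randn(8, 8, 8)
linear = torch.nn.Linear(8, 8)

# Record torch-level and aten-level calls issued while the mode is active.
with DebugMode() as debug_mode:
    linear(x).sum()

# Prints an indented call tree, e.g. torch._C._nn.linear(...) followed by nested aten:: calls.
print(debug_mode.debug_string())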