pytorch/torch/utils/_debug_mode.py
Sherlock Huang 95ac7d724e Rename to _debug_mode.py to make it private (#163534)
Rename debug_mode.py to _debug_mode.py to make it private, per @alban's request.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163534
Approved by: https://github.com/albanD
2025-09-23 04:27:10 +00:00


# mypy: allow-untyped-defs
import contextlib

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.utils._dtype_abbrs import dtype_abbrs
from torch.utils._python_dispatch import _get_current_dispatch_mode, TorchDispatchMode
from torch.utils._pytree import tree_map


__all__ = ["DebugMode"]

REDISTRIBUTE_FUNC = "redistribute_input"


def _stringify_shape(shape) -> str:
    return f"[{', '.join([str(x) for x in shape])}]"


def _stringify_device_mesh(mesh) -> str:
    return f"DM({', '.join([str(s) for s in mesh.shape])})"


def _stringify_placement(placement) -> str:
    return f"[{', '.join([str(p) for p in placement])}]"


def _tensor_debug_string(tensor) -> str:
    """Convert tensor to debug string representation."""
    if isinstance(tensor, torch.distributed.tensor.DTensor):
        # omitted device mesh
        return f"dt: {dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_placement(tensor.placements)}"
    elif isinstance(tensor, FakeTensor):
        return f"ft: {dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}"
    elif isinstance(tensor, torch.Tensor):
        return f"t: {dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}"
    else:
        raise RuntimeError(f"Unsupported tensor type: {type(tensor)}")
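
# Illustrative note (added; not in the original file): with dtype_abbrs, a real
# float32 [2, 3] tensor renders roughly as "t: f32[2, 3]", a FakeTensor of the
# same shape as "ft: f32[2, 3]", and a DTensor additionally shows its
# placements, e.g. roughly "dt: f32[2, 3][S(0)]" for a tensor sharded on dim 0.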


def _arg_to_str(arg) -> str:
    from torch.distributed.tensor._dtensor_spec import DTensorSpec

    def to_str(x):
        if isinstance(x, torch.Tensor):
            return _tensor_debug_string(x)
        elif isinstance(x, DTensorSpec):
            return _stringify_placement(x.placements)
        return x

    arg = tree_map(to_str, arg)
    return str(arg)


def _op_to_str(op, *args, **kwargs) -> str:
    if op == REDISTRIBUTE_FUNC:
        assert len(args) == 3
        _args = [_arg_to_str(arg) for arg in args]
        args_str = f"{_args[0]}, {_args[1]} -> {_args[2]}"
    else:
        args_str = ", ".join(_arg_to_str(arg) for arg in args)

    if kwargs:
        kwargs_str = ", " + ", ".join(
            f"{k}={_arg_to_str(v)}" for k, v in kwargs.items()
        )
    else:
        kwargs_str = ""

    if isinstance(op, torch._ops.OpOverload):
        op_name = op.__qualname__
    elif hasattr(op, "__module__") and hasattr(op, "__name__"):
        op_name = f"{op.__module__}.{op.__name__}"
    else:
        op_name = str(op)

    return f"{op_name}({args_str}{kwargs_str})"


class DebugMode(TorchDispatchMode):
    """TorchDispatchMode that records dispatched ops (and, optionally,
    torch-function calls) together with their call depth, for later
    pretty-printing via :meth:`debug_string`."""

    def __init__(
        self,
        *,
        record_torchfunction=False,
        record_faketensor=False,
        record_realtensor=True,
    ):
        super().__init__()
        import torch.distributed.tensor  # noqa: F401

        self.record_torchfunction = record_torchfunction
        self.record_faketensor = record_faketensor
        self.record_realtensor = record_realtensor

        self.operators = []
        self.call_depth = 0

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        self.operators.append((func, args, kwargs, self.call_depth))

        try:
            self.call_depth += 1
            return func(*args, **kwargs)
        finally:
            self.call_depth -= 1

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # Record the operation with its call depth
        if torch.distributed.tensor.DTensor in types:
            self.operators.append((func, args, kwargs, self.call_depth))
            # Defer to DTensor's own dispatch; we only record the call here.
            return NotImplemented
        elif FakeTensor in types or isinstance(
            _get_current_dispatch_mode(), FakeTensorMode
        ):
            if self.record_faketensor:
                if func not in {torch.ops.prim.device.default}:
                    self.operators.append((func, args, kwargs, self.call_depth + 1))
        elif len(types) == 0:
            if self.record_realtensor:
                self.operators.append((func, args, kwargs, self.call_depth + 1))

        result = func(*args, **kwargs)
        return result

    def __enter__(self):
        self.operators = []
        self.call_depth = 0

        if self.record_torchfunction:
            torch._C._push_on_torch_function_stack(self)

        super().__enter__()
        return self

    def __exit__(self, *args):
        super().__exit__(*args)
        if self.record_torchfunction:
            torch._C._pop_torch_function_stack()

    @contextlib.contextmanager
    def record_redistribute_calls(self, arg_idx, src_placement, dst_placement):
        try:
            self.operators.append(
                (
                    REDISTRIBUTE_FUNC,
                    [arg_idx, src_placement, dst_placement],
                    {},
                    self.call_depth + 1,
                )
            )
            self.call_depth += 1
            yield
        finally:
            self.call_depth -= 1
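
    # Hedged illustration (hypothetical caller, not in this file): DTensor's
    # redistribute machinery is expected to wrap each implicit input
    # redistribution roughly like
    #
    #     with debug_mode.record_redistribute_calls(i, src_spec, dst_spec):
    #         ...  # perform the actual redistribution
    #
    # so the placement transition appears, indented, in debug_string().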

    def debug_string(self) -> str:
        with torch._C.DisableTorchFunction():
            result = ""
            result += "\n".join(
                "  " + "  " * depth + _op_to_str(op, *args, **kwargs)
                for op, args, kwargs, depth in self.operators
            )
        return result
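

# Usage sketch (added for illustration; not part of the original module).
# Assumes a torch build with torch.distributed available, since DebugMode's
# __init__ imports torch.distributed.tensor.
if __name__ == "__main__":
    x = torch.randn(2, 3)
    y = torch.randn(3, 4)

    with DebugMode(record_torchfunction=True) as debug_mode:
        torch.mm(x, y)

    # Prints one line per recorded call, indented by call depth: the
    # torch-level torch.mm call and the aten-level dispatches beneath it.
    print(debug_mode.debug_string())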