mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
This reverts commit 45afaf08a14ab760d86ea80dea6d50cec8626513. Reverted https://github.com/pytorch/pytorch/pull/165498 on behalf of https://github.com/seemethere due to First part of the stack was reverted so will need to revert this too ([comment](https://github.com/pytorch/pytorch/pull/165498#issuecomment-3416618198))
233 lines
7.0 KiB
Python
233 lines
7.0 KiB
Python
# mypy: allow-untyped-defs
|
|
import contextlib
|
|
from typing import Optional
|
|
|
|
import torch
|
|
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
|
|
from torch.utils._dtype_abbrs import dtype_abbrs
|
|
from torch.utils._python_dispatch import (
|
|
_get_current_dispatch_mode,
|
|
_get_current_dispatch_mode_stack,
|
|
TorchDispatchMode,
|
|
)
|
|
from torch.utils._pytree import tree_map
|
|
|
|
|
|
# Public API of this module.
__all__ = [
    "DebugMode",
    "get_active_debug_mode",
]

# Pseudo-operator name appended to DebugMode.operators by
# DebugMode.record_redistribute_calls; _op_to_str formats it specially.
REDISTRIBUTE_FUNC = "redistribute_input"
|
|
|
|
|
|
def _stringify_shape(shape) -> str:
|
|
return f"[{', '.join([str(x) for x in shape])}]"
|
|
|
|
|
|
def _stringify_device_mesh(mesh) -> str:
|
|
return f"DM({', '.join([str(s) for s in mesh.shape])})"
|
|
|
|
|
|
def _stringify_placement(placement) -> str:
|
|
return f"[{', '.join([str(p) for p in placement])}]"
|
|
|
|
|
|
def _stringify_attributes(tensor, attributes) -> str:
|
|
pairs = {}
|
|
for attr in attributes:
|
|
if hasattr(tensor, attr):
|
|
pairs[attr] = getattr(tensor, attr)
|
|
if len(pairs) == 0:
|
|
return ""
|
|
return f"{{{', '.join([f'{k}={v}' for k, v in pairs.items()])}}}"
|
|
|
|
|
|
def _stringify_dtensor_spec(spec) -> str:
    """Describe a DTensorSpec by its placements and shard order.

    Formatting is delegated to ``DTensorSpec.format_shard_order_str``. Only
    ``spec.placements`` and ``spec.shard_order`` are passed to it, so the
    device mesh is not part of the result.
    """
    from torch.distributed.tensor._dtensor_spec import DTensorSpec

    placements, shard_order = spec.placements, spec.shard_order
    return DTensorSpec.format_shard_order_str(placements, shard_order)
|
|
|
|
|
|
def _tensor_debug_string(tensor, attributes) -> str:
    """Convert tensor to debug string representation.

    The summary is ``<dtype abbrev><shape><attrs>``, prefixed with a tag for the
    tensor kind:

    - ``dt:`` for DTensor, followed by ``| <placements / shard order>``. The
      device mesh is omitted.
    - ``ft:`` for FakeTensor.
    - ``t:`` for any other ``torch.Tensor``.

    Args:
        tensor: the tensor to describe.
        attributes: attribute names to append (via ``_stringify_attributes``)
            when they are present on ``tensor``.

    Raises:
        RuntimeError: if ``tensor`` is not a ``torch.Tensor``.
    """
    if not isinstance(tensor, torch.Tensor):
        raise RuntimeError(f"Unsupported tensor type: {type(tensor)}")

    # Import DTensor explicitly instead of reading the
    # ``torch.distributed.tensor`` attribute. That attribute only exists once
    # the subpackage has been imported elsewhere (e.g. by DebugMode.__init__),
    # so a direct call to this helper could otherwise raise AttributeError.
    from torch.distributed.tensor import DTensor

    tensor_debug_str = (
        f"{dtype_abbrs[tensor.dtype]}"
        f"{_stringify_shape(tensor.shape)}"
        f"{_stringify_attributes(tensor, attributes)}"
    )

    if isinstance(tensor, DTensor):
        # omitted device mesh
        return f"dt: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}"
    if isinstance(tensor, FakeTensor):
        return f"ft: {tensor_debug_str}"
    return f"t: {tensor_debug_str}"
|
|
|
|
|
|
def _arg_to_str(arg, attributes) -> str:
|
|
from torch.distributed.tensor._dtensor_spec import DTensorSpec
|
|
|
|
def to_str(x):
|
|
if isinstance(x, torch.Tensor):
|
|
return _tensor_debug_string(x, attributes)
|
|
elif isinstance(x, DTensorSpec):
|
|
return _stringify_dtensor_spec(x)
|
|
return x
|
|
|
|
arg = tree_map(to_str, arg)
|
|
return str(arg)
|
|
|
|
|
|
def _op_to_str(op, attributes, *args, **kwargs) -> str:
    """Render one recorded operation as ``op_name(arg, ..., key=value, ...)``.

    Args:
        op: an ``OpOverload``, another callable, or the ``REDISTRIBUTE_FUNC``
            pseudo-op name.
        attributes: tensor attribute names forwarded to ``_arg_to_str``.
        *args, **kwargs: the recorded call arguments.

    For ``REDISTRIBUTE_FUNC``, the positional args must take one of two forms:

    - ``(arg_idx, trace)``, rendered as ``<arg_idx>, trace: <trace>``.
    - ``(arg_idx, src, dst)``, rendered as ``<arg_idx>, <src> -> <dst>``.

    Raises:
        RuntimeError: for a ``REDISTRIBUTE_FUNC`` entry with any other number of
            positional args.
    """

    def fmt(value):
        return _arg_to_str(value, attributes)

    if op == REDISTRIBUTE_FUNC:
        arg_count = len(args)
        if arg_count == 2:
            # The trace string is shown verbatim, not passed through fmt().
            args_str = f"{fmt(args[0])}, trace: {args[1]}"
        elif arg_count == 3:
            idx_str, src_str, dst_str = [fmt(a) for a in args]
            args_str = f"{idx_str}, {src_str} -> {dst_str}"
        else:
            raise RuntimeError(f"Unsupported args for {REDISTRIBUTE_FUNC}: {args}")
    else:
        args_str = ", ".join(map(fmt, args))

    # Each kwarg carries its own ", " prefix, so no kwargs -> empty string.
    kwargs_str = "".join(f", {key}={fmt(value)}" for key, value in kwargs.items())

    if isinstance(op, torch._ops.OpOverload):
        op_name = op.__qualname__
    elif hasattr(op, "__module__") and hasattr(op, "__name__"):
        op_name = f"{op.__module__}.{op.__name__}"
    else:
        op_name = str(op)

    return f"{op_name}({args_str}{kwargs_str})"
|
|
|
|
|
|
class DebugMode(TorchDispatchMode):
|
|
def __init__(
|
|
self,
|
|
*,
|
|
record_torchfunction=False,
|
|
record_faketensor=False,
|
|
record_realtensor=True,
|
|
record_tensor_attributes=None,
|
|
):
|
|
super().__init__()
|
|
import torch.distributed.tensor # noqa: F401
|
|
|
|
self.supports_higher_order_operators = True
|
|
self.record_torchfunction = record_torchfunction
|
|
self.record_faketensor = record_faketensor
|
|
self.record_realtensor = record_realtensor
|
|
self.record_tensor_attributes = record_tensor_attributes or []
|
|
|
|
self.operators = []
|
|
self.call_depth = 0
|
|
|
|
# Without this override, running torch.compile under DebugMode
|
|
# will force torch.compile to always use the “eager” backend
|
|
# With this, DebugMode will not take effect on torch.compile
|
|
@classmethod
|
|
def ignore_compile_internals(cls):
|
|
return True
|
|
|
|
def __torch_function__(self, func, types, args=(), kwargs=None):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
|
|
self.operators.append((func, args, kwargs, self.call_depth))
|
|
|
|
try:
|
|
self.call_depth += 1
|
|
return func(*args, **kwargs)
|
|
finally:
|
|
self.call_depth -= 1
|
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
|
|
# Record the operation with its call depth
|
|
if torch.distributed.tensor.DTensor in types:
|
|
self.operators.append((func, args, kwargs, self.call_depth))
|
|
return NotImplemented
|
|
elif FakeTensor in types or isinstance(
|
|
_get_current_dispatch_mode(), FakeTensorMode
|
|
):
|
|
if self.record_faketensor:
|
|
if func != torch.ops.prim.device.default:
|
|
self.operators.append((func, args, kwargs, self.call_depth + 1))
|
|
elif len(types) == 0:
|
|
if self.record_realtensor:
|
|
self.operators.append((func, args, kwargs, self.call_depth + 1))
|
|
|
|
result = func(*args, **kwargs)
|
|
|
|
return result
|
|
|
|
def __enter__(self):
|
|
self.operators = []
|
|
self.call_depth = 0
|
|
|
|
if self.record_torchfunction:
|
|
torch._C._push_on_torch_function_stack(self)
|
|
|
|
super().__enter__()
|
|
return self
|
|
|
|
# pyrefly: ignore # bad-override
|
|
def __exit__(self, *args):
|
|
super().__exit__(*args)
|
|
if self.record_torchfunction:
|
|
torch._C._pop_torch_function_stack()
|
|
|
|
@contextlib.contextmanager
|
|
def record_redistribute_calls(
|
|
self,
|
|
arg_idx,
|
|
src_placement,
|
|
dst_placement,
|
|
transform_info_str: Optional[str] = None,
|
|
):
|
|
try:
|
|
arg_list = (
|
|
[arg_idx, transform_info_str]
|
|
if transform_info_str
|
|
else [arg_idx, src_placement, dst_placement]
|
|
)
|
|
self.operators.append(
|
|
(
|
|
REDISTRIBUTE_FUNC,
|
|
arg_list,
|
|
{},
|
|
self.call_depth + 1,
|
|
)
|
|
)
|
|
self.call_depth += 1
|
|
yield
|
|
finally:
|
|
self.call_depth -= 1
|
|
|
|
def debug_string(self) -> str:
|
|
with torch._C.DisableTorchFunction():
|
|
result = ""
|
|
result += "\n".join(
|
|
" "
|
|
+ " " * depth
|
|
+ _op_to_str(op, self.record_tensor_attributes, *args, **kwargs)
|
|
for op, args, kwargs, depth in self.operators
|
|
)
|
|
return result
|
|
|
|
|
|
def get_active_debug_mode() -> Optional[DebugMode]:
    """Return the first DebugMode on the current dispatch mode stack, or None.

    The stack is scanned in the order returned by
    ``_get_current_dispatch_mode_stack()``.
    """
    return next(
        (
            mode
            for mode in _get_current_dispatch_mode_stack()
            if isinstance(mode, DebugMode)
        ),
        None,
    )
|