# mypy: allow-untyped-defs
import contextlib
from typing import Optional

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.utils._dtype_abbrs import dtype_abbrs
from torch.utils._python_dispatch import (
    _get_current_dispatch_mode,
    _get_current_dispatch_mode_stack,
    TorchDispatchMode,
)
from torch.utils._pytree import tree_map


__all__ = ["DebugMode", "get_active_debug_mode"]

REDISTRIBUTE_FUNC = "redistribute_input"


def _stringify_shape(shape) -> str:
    return f"[{', '.join([str(x) for x in shape])}]"


def _stringify_device_mesh(mesh) -> str:
    return f"DM({', '.join([str(s) for s in mesh.shape])})"


def _stringify_placement(placement) -> str:
    return f"[{', '.join([str(p) for p in placement])}]"


def _stringify_attributes(tensor, attributes) -> str:
    pairs = {}
    for attr in attributes:
        if hasattr(tensor, attr):
            pairs[attr] = getattr(tensor, attr)
    if len(pairs) == 0:
        return ""
    return f"{{{', '.join([f'{k}={v}' for k, v in pairs.items()])}}}"


def _stringify_dtensor_spec(spec) -> str:
    from torch.distributed.tensor._dtensor_spec import DTensorSpec

    return DTensorSpec.format_shard_order_str(spec.placements, spec.shard_order)


def _tensor_debug_string(tensor, attributes) -> str:
    """Convert tensor to debug string representation."""

    if isinstance(tensor, torch.Tensor):
        tensor_debug_str = f"{dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_attributes(tensor, attributes)}"

        if isinstance(tensor, torch.distributed.tensor.DTensor):
            # omitted device mesh
            return f"dt: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}"
        elif isinstance(tensor, FakeTensor):
            return f"ft: {tensor_debug_str}"
        else:
            return f"t: {tensor_debug_str}"
    else:
        raise RuntimeError(f"Unsupported tensor type: {type(tensor)}")


def _arg_to_str(arg, attributes) -> str:
    from torch.distributed.tensor._dtensor_spec import DTensorSpec

    def to_str(x):
        if isinstance(x, torch.Tensor):
            return _tensor_debug_string(x, attributes)
        elif isinstance(x, DTensorSpec):
            return _stringify_dtensor_spec(x)
        return x

    arg = tree_map(to_str, arg)
    return str(arg)


class _DebugCall:
    """Base class for tracking operator calls in DebugMode"""

    def __init__(self, call_depth: int):
        self.call_depth = call_depth

    def render(self, attributes: list[str]) -> str:
        raise NotImplementedError("Subclasses must implement render()")


class _OpCall(_DebugCall):
    """Normal operator call"""

    def __init__(self, op, args: tuple, kwargs: dict, call_depth: int):
        super().__init__(call_depth)
        self.op = op
        self.args = args
        self.kwargs = kwargs

    def render(self, attributes: list[str]) -> str:
        args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args)

        if self.kwargs:
            kwargs_str = ", " + ", ".join(
                f"{k}={_arg_to_str(v, attributes)}" for k, v in self.kwargs.items()
            )
        else:
            kwargs_str = ""

        if isinstance(self.op, torch._ops.OpOverload):
            op_name = self.op.__qualname__
        elif hasattr(self.op, "__module__") and hasattr(self.op, "__name__"):
            op_name = f"{self.op.__module__}.{self.op.__name__}"
        else:
            op_name = str(self.op)

        return f"{op_name}({args_str}{kwargs_str})"


class _RedistributeCall(_DebugCall):
    """Redistribute call from DTensor dispatch"""

    def __init__(
        self, arg, src_placement, dst_placement, transform_info_str, call_depth
    ):
        super().__init__(call_depth)
        self.arg = arg
        self.src_placement = src_placement
        self.dst_placement = dst_placement
        self.transform_info_str = transform_info_str

    def render(self, attributes: list[str]) -> str:
        arg_str = f"{_arg_to_str(self.arg, attributes)}"
        if self.transform_info_str is not None:  # prioritize over src/dst placements
            placement_str = f"trace: {self.transform_info_str}"
        else:
            src_placement_str = _arg_to_str(self.src_placement, attributes)
            dst_placement_str = _arg_to_str(self.dst_placement, attributes)
            placement_str = f"{src_placement_str} -> {dst_placement_str}"
        return f"{REDISTRIBUTE_FUNC}({arg_str}, {placement_str})"


class DebugMode(TorchDispatchMode):
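    """
    ``TorchDispatchMode`` that records the operator calls (and, for DTensor,
    the redistribute calls) issued while the mode is active, together with
    their nesting depth, so they can be rendered as an indented trace via
    ``debug_string()``.

    Minimal usage sketch (illustrative only; ``x`` and ``w`` below are
    placeholder tensors and the exact trace formatting may differ):

        with DebugMode(record_torchfunction=True) as debug_mode:
            torch.mm(x, w)
        print(debug_mode.debug_string())
    """
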
    def __init__(
        self,
        *,
        record_torchfunction=False,
        record_faketensor=False,
        record_realtensor=True,
        record_tensor_attributes=None,
    ):
        super().__init__()
        import torch.distributed.tensor  # noqa: F401

        self.supports_higher_order_operators = True
        self.record_torchfunction = record_torchfunction
        self.record_faketensor = record_faketensor
        self.record_realtensor = record_realtensor
        self.record_tensor_attributes = record_tensor_attributes or []

        self.operators = []
        self.call_depth = 0

    # Without this override, running torch.compile under DebugMode would force
    # torch.compile to always use the "eager" backend.
    # With it, DebugMode does not take effect on torch.compile'd code.
    @classmethod
    def ignore_compile_internals(cls):
        return True

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        self.operators.append(_OpCall(func, args, kwargs, self.call_depth))

        try:
            self.call_depth += 1
            return func(*args, **kwargs)
        finally:
            self.call_depth -= 1

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # Record the operation with its call depth
        if torch.distributed.tensor.DTensor in types:
            self.operators.append(_OpCall(func, args, kwargs, self.call_depth))
            return NotImplemented
        elif FakeTensor in types or isinstance(
            _get_current_dispatch_mode(), FakeTensorMode
        ):
            if self.record_faketensor:
                if func != torch.ops.prim.device.default:
                    self.operators.append(
                        _OpCall(func, args, kwargs, self.call_depth + 1)
                    )
        elif len(types) == 0:
            if self.record_realtensor:
                self.operators.append(_OpCall(func, args, kwargs, self.call_depth + 1))

        result = func(*args, **kwargs)

        return result

    def __enter__(self):
        self.operators = []
        self.call_depth = 0

        if self.record_torchfunction:
            torch._C._push_on_torch_function_stack(self)

        super().__enter__()
        return self

    # pyrefly: ignore # bad-override
    def __exit__(self, *args):
        super().__exit__(*args)
        if self.record_torchfunction:
            torch._C._pop_torch_function_stack()

    @contextlib.contextmanager
    def record_redistribute_calls(
        self,
        arg,
        src_placement,
        dst_placement,
        transform_info_str: Optional[str] = None,
    ):
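        """
        Record a ``redistribute_input`` entry for ``arg`` (from ``src_placement``
        to ``dst_placement``, or annotated with ``transform_info_str`` when it is
        provided), and count operations issued inside the context as one call
        level deeper.
        """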
        try:
            self.operators.append(
                _RedistributeCall(
                    arg,
                    src_placement=src_placement,
                    dst_placement=dst_placement,
                    transform_info_str=transform_info_str,
                    call_depth=self.call_depth + 1,
                )
            )
            self.call_depth += 1
            yield
        finally:
            self.call_depth -= 1

    def debug_string(self) -> str:
        with torch._C.DisableTorchFunction():
            result = ""
            result += "\n".join(
                " " + " " * op.call_depth + op.render(self.record_tensor_attributes)
                for op in self.operators
            )
        return result


def get_active_debug_mode() -> Optional[DebugMode]:
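    """Return the ``DebugMode`` instance on the current dispatch mode stack, if any, else ``None``."""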
    debug_mode = None
    for mode in _get_current_dispatch_mode_stack():
        if isinstance(mode, DebugMode):
            debug_mode = mode
            break
    return debug_mode