Files
pytorch/torch/utils/_debug_mode.py
Pian Pawakapan 9997e853e9 [DebugMode] record triton kernels, run-to-run determinism checks (#167028)
Following up on https://github.com/pytorch/pytorch/pull/166348, this extends DebugMode to capture Inductor-generated Triton kernels at runtime, and adds an API for checking run-to-run determinism based on tensor hashes.

The workflow looks something like...
```python
# do 1st run with hashes, get logs
with DebugMode() as debug_mode, DebugMode.log_tensor_hashes():
    compiled_model(*inputs)
logs1 = debug_mode.logs

# do 2nd run
with DebugMode() as debug_mode, DebugMode.log_tensor_hashes():
    compiled_model(*inputs)
logs2 = debug_mode.logs

# returns list of calls w/ mismatched outputs
mismatches = DebugMode.check_hash_mismatches(logs1, logs2)
```
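
To inspect a run directly, the logs can be rendered as a string with `debug_string()` (the dump below is in this format), and the mismatches are plain dicts:

```python
print(debug_mode.debug_string())  # hierarchical call dump, incl. per-kernel hash annotations

for m in mismatches:
    print("mismatch:", m)  # dicts like the ones under $$$ MISMATCHES $$$ below
```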

Example dump from a smaller version of @drisspg's FlexAttention fwd+bwd determinism test [script](https://gist.github.com/pianpwk/f65cc63811d12853709dcc77d7eb69f1) (without forced reduction order):
```
cfg: TestConfig(name='Standard', B=2, Hq=32, Hkv=32, Q=2048, KV=2048, Dqk=128, Dv=128)
DETERMINISM: fwd: True, bwd_q: False, bwd_k: False, bwd_v: True

$$$ DEBUG MODE DUMP $$$  (this is what the logs look like)

    [triton] triton_tem_fused_0(arg_Q=t: bf16[2, 32, 2048, 128], arg_K=t: bf16[2, 32, 2048, 128], arg_V=t: bf16[2, 32, 2048, 128], arg_LSE=t: f32[2, 32, 2048], arg_MAX=t: f32[2, 32, 2048], arg_KV_NUM_BLKS=t: i32[2, 32, 16], arg_KV_IDX=t: i32[2, 32, 16, 16], arg_FULL_KV_NUM_BLKS=t: i32[2, 32, 16], arg_FULL_KV_IDX=t: i32[2, 32, 16, 16], out_ptr0=t: bf16[2, 32, 2048, 128])
    # post-kernel hashes: {arg_Q: 13385916.068706088, arg_K: 13389356.409105342, arg_V: 13384993.48412523, arg_LSE: 1347168.9026973695, arg_MAX: 81775.3811062593, arg_KV_NUM_BLKS: 1024.0, arg_KV_IDX: 122880.0, arg_FULL_KV_NUM_BLKS: 7680.0, arg_FULL_KV_IDX: 122880.0, out_ptr0: 924917.7918248245}

    [triton] triton_per_fused_zeros_0(in_ptr0=t: bf16[2, 32, 2048, 128], in_ptr1=t: bf16[2, 32, 2048, 128], out_ptr1=t: f32[2, 32, 2048], xnumel=131072, r0_numel=128)
    # post-kernel hashes: {in_ptr0: 924917.7918248245, in_ptr1: 13389213.797377996, out_ptr1: 81775.38106592931}

    [triton] triton_tem_fused_zeros_1(arg_Q=t: bf16[2, 32, 2048, 128], arg_K=t: bf16[2, 32, 2048, 128], arg_V=t: bf16[2, 32, 2048, 128], arg_LSE=t: f32[2, 32, 2048], arg_DELTA=t: f32[2, 32, 2048], arg_DO=t: bf16[2, 32, 2048, 128], arg_DQ=t: bf16[2, 32, 2048, 128], arg_DV=t: bf16[2, 32, 2048, 128], arg_KV_NUM_BLKS=t: i32[2, 32, 16], arg_KV_IDX=t: i32[2, 32, 16, 16], arg_Q_NUM_BLKS=t: i32[2, 32, 16], arg_Q_IDX=t: i32[2, 32, 16, 16], arg_FULL_KV_NUM_BLKS=t: i32[2, 32, 16], arg_FULL_KV_IDX=t: i32[2, 32, 16, 16], arg_FULL_Q_NUM_BLKS=t: i32[2, 32, 16], arg_FULL_Q_IDX=t: i32[2, 32, 16, 16], out_ptr0=t: bf16[2, 32, 2048, 128])
    # post-kernel hashes: {arg_Q: 13385916.068706088, arg_K: 13389356.409105342, arg_V: 13384993.48412523, arg_LSE: 1347168.9026973695, arg_DELTA: 81775.38106592931, arg_DO: 13389213.797377996, arg_DQ: 874474.8084187683, arg_DV: 727742.3138379117, arg_KV_NUM_BLKS: 1024.0, arg_KV_IDX: 122880.0, arg_Q_NUM_BLKS: 1024.0, arg_Q_IDX: 122880.0, arg_FULL_KV_NUM_BLKS: 7680.0, arg_FULL_KV_IDX: 122880.0, arg_FULL_Q_NUM_BLKS: 7680.0, arg_FULL_Q_IDX: 122880.0, out_ptr0: 700542.3431890717}

$$$ MISMATCHES $$$
mismatch: {'call_type': 'triton kernel', 'call': 'triton_tem_fused_0', 'arg_name': 'arg_MAX', 'pytree_path': None, 'hash1': 0.0, 'hash2': 81775.3811062593, 'rel_diff': 1.0, 'is_input_hash': False}  # I guess this one is misleading? not sure if I'm doing something wrong with waiting for kernel results
mismatch: {'call_type': 'triton kernel', 'call': 'triton_per_fused_zeros_0', 'arg_name': 'out_ptr1', 'pytree_path': None, 'hash1': 81775.3811062593, 'hash2': 81775.38106592931, 'rel_diff': 4.931801261646669e-10, 'is_input_hash': False}
mismatch: {'call_type': 'triton kernel', 'call': 'triton_tem_fused_zeros_1', 'arg_name': 'arg_DELTA', 'pytree_path': None, 'hash1': 81775.3811062593, 'hash2': 81775.38106592931, 'rel_diff': 4.931801261646669e-10, 'is_input_hash': False}
mismatch: {'call_type': 'triton kernel', 'call': 'triton_tem_fused_zeros_1', 'arg_name': 'arg_DQ', 'pytree_path': None, 'hash1': 874474.8097136207, 'hash2': 874474.8084187683, 'rel_diff': 1.480720012120795e-09, 'is_input_hash': False}
mismatch: {'call_type': 'triton kernel', 'call': 'triton_tem_fused_zeros_1', 'arg_name': 'out_ptr0', 'pytree_path': None, 'hash1': 700542.3488049245, 'hash2': 700542.3431890717, 'rel_diff': 8.016435812581196e-09, 'is_input_hash': False}
```

Note: the current hash implementation is basically a tensor norm, so tensor closeness implies hash closeness. This is likely to change soon, e.g. to `torch.hash_tensor` (https://github.com/pytorch/pytorch/pull/154149) by default.
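
For intuition, here's a minimal sketch (assuming the norm-based `default_hash_fn` defined in `torch.utils._debug_mode`, and the same relative-difference formula that `check_hash_mismatches` uses) of how tensor closeness maps to hash closeness:

```python
import torch
from torch.utils._debug_mode import default_hash_fn

a = torch.randn(1024, 1024)
b = a + 1e-7 * torch.randn_like(a)  # numerically close to a, but not bitwise-identical

h_a = default_hash_fn(a, use_scalar=True)  # roughly an L1 norm computed in float64
h_b = default_hash_fn(b, use_scalar=True)

rel_diff = abs(h_a - h_b) / max(abs(h_a), abs(h_b), 1e-10)
print(rel_diff)  # typically tiny but nonzero; bitwise-identical tensors give exactly 0.0
```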

Sample diff between the log dumps from two runs:
<img width="1665" height="445" alt="Screenshot 2025-11-05 at 11 27 24 PM" src="https://github.com/user-attachments/assets/41402e37-f50b-4a9e-a17c-bb98b5917076" />

In another case, running this for FSDP2 on Llama3-8B helped narrow down an aot_eager vs. inductor divergence to inductor's forward RMSNorm kernels: P2027003180

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167028
Approved by: https://github.com/v0i0
2025-11-12 05:21:07 +00:00


# mypy: allow-untyped-defs
"""
DebugMode is a debugging TorchDispatchMode that intercepts and logs runtime calls
to a hierarchical string dump. It logs real tensor, DTensor, and optionally FakeTensor
operations, with some additional handling for DTensor internals.
An example dump from an eager mode DTensor matmul:

    torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) -> dt$6: f32[8, 32]| S(0)
      aten::mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0))
        redistribute_input(1, S(0) -> R)
          redistribute_input(t$2: f32[1, 32], trace: S(0)->R)
            _c10d_functional::all_gather_into_tensor(t$2: f32[1, 32], 8, 0) -> t$3: f32[8, 32]
            _c10d_functional::wait_tensor(t$3: f32[8, 32]) -> t$3: f32[8, 32]
        aten::mm(t$4: f32[1, 8], t$3: f32[8, 32]) -> t$5: f32[1, 32]
This mode runs "under" compile: it hides itself during compilation and is re-enabled at runtime,
so DebugMode-related operations won't show up in the compiled region.
DebugMode also provides some visibility into non-torch-dispatch calls (e.g. DTensor redistribute calls,
inductor-generated triton kernels), but requires special handling for these, since dispatch modes
can't intercept them by default.
The mode also provides some extensions for custom debugging (e.g. adding custom dispatch call hooks
via dispatch_hooks) and numerics debugging (e.g. tensor hashing for bitwise equivalence/closeness,
via log_tensor_hashes). These context managers annotate the string dump with additional per-call
information, for any region of runtime code.
Usage::

    with DebugMode() as debug_mode:
        result = some_pytorch_operation(tensor_input)

    print(debug_mode.debug_string())
"""
import contextlib
import functools
import traceback
import weakref
from collections.abc import Callable
from typing import Any, TYPE_CHECKING
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.utils._dtype_abbrs import dtype_abbrs
from torch.utils._python_dispatch import (
_get_current_dispatch_mode,
_get_current_dispatch_mode_stack,
TorchDispatchMode,
)
from torch.utils._pytree import keystr, tree_all, tree_map, tree_map_with_path
from torch.utils._traceback import CapturedTraceback
from torch.utils.weak import WeakIdRef
if TYPE_CHECKING:
from torch._dynamo.device_interface import DeviceInterface
from torch.distributed._tools.mod_tracker import ModTracker
__all__ = ["DebugMode", "get_active_debug_mode"]
REDISTRIBUTE_FUNC = "redistribute_input"
# registered dispatch call hooks
_DISPATCH_RECORD_HOOKS: list[Callable] = []
_DISPATCH_LOG_HOOKS: list[Callable] = []
# Tracks if we're in inductor benchmarking, and temporarily disables logging
# (for ignoring autotuning kernel launches which don't affect the user-facing result)
_IN_INDUCTOR_BENCHMARK = False
# For record_outputs, log_tensor_hashes hooks for triton kernels.
# Stores kernel outputs in call.record["output"]
_RECORD_TRITON_OUTPUTS = False
# Annotates kernel output hashes, and stores them in call.post_hashes
_TRITON_OUTPUT_HASH_FN = None
# Annotates kernel input hashes, and stores them in call.pre_hashes
_TRITON_INPUT_HASH_FN = None
def _stringify_shape(shape) -> str:
return f"[{', '.join([str(x) for x in shape])}]"
def _stringify_device_mesh(mesh) -> str:
return f"DM({', '.join([str(s) for s in mesh.shape])})"
def _stringify_placement(placement) -> str:
return f"[{', '.join([str(p) for p in placement])}]"
def _stringify_attributes(tensor, attributes) -> str:
pairs = {}
for attr in attributes:
if hasattr(tensor, attr):
pairs[attr] = getattr(tensor, attr)
if len(pairs) == 0:
return ""
return f"{{{', '.join([f'{k}={v}' for k, v in pairs.items()])}}}"
def _stringify_dtensor_spec(spec) -> str:
from torch.distributed.tensor._dtensor_spec import DTensorSpec
return DTensorSpec.format_shard_order_str(spec.placements, spec.shard_order)
class TensorIdTracker:
def __init__(self) -> None:
self.tensor_memo: dict[WeakIdRef, int] = {}
self.next_tensor_id = 0
def _id(self, tensor) -> int:
with torch._C._DisablePythonDispatcher():
o = WeakIdRef(tensor)
def del_memo() -> None:
self.tensor_memo.pop(o, None)
weakref.finalize(tensor, del_memo)
if o not in self.tensor_memo:
self.tensor_memo[o] = self.next_tensor_id
self.next_tensor_id += 1
return self.tensor_memo[o]
def _tensor_debug_string(tensor, attributes, tensor_memo=None) -> str:
"""Convert tensor to debug string representation."""
if isinstance(tensor, torch.Tensor):
tensor_debug_str = f"{dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_attributes(tensor, attributes)}"
id_str = f"${tensor_memo._id(tensor)}" if tensor_memo is not None else ""
if isinstance(tensor, torch.distributed.tensor.DTensor):
# omitted device mesh
return f"dt{id_str}: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}"
elif isinstance(tensor, FakeTensor):
return f"ft{id_str}: {tensor_debug_str}"
else:
return f"t{id_str}: {tensor_debug_str}"
else:
raise RuntimeError(f"Unsupported tensor type: {type(tensor)}")
def _arg_to_str(arg, attributes, tensor_memo=None) -> str:
from torch.distributed.tensor._dtensor_spec import DTensorSpec
def to_str(x):
if isinstance(x, torch.Tensor):
return _tensor_debug_string(x, attributes, tensor_memo)
elif isinstance(x, DTensorSpec):
return _stringify_dtensor_spec(x)
return x
arg = tree_map(to_str, arg)
return str(arg)
def default_hash_fn(t: torch.Tensor, use_scalar: bool = False) -> torch.Tensor:
"""
Adapted from Observer. Computes a hash for a tensor by converting it to float (if needed), making it contiguous,
replacing NaN/inf values with fixed numbers, and then computing the L1 norm in float64 or complex128.
This is used to generate a deterministic summary value for tensor comparison.
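Illustrative example (the value is specific to the current norm-based implementation)::

    >>> import torch
    >>> default_hash_fn(torch.ones(4), use_scalar=True)  # L1 norm of four ones
    4.0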
"""
with torch._C._DisablePythonDispatcher(), torch._C._DisableTorchDispatch():
if not (t.is_floating_point() or t.is_complex()):
t = t.float()
t = t.contiguous()
# Clean the tensor to handle NaN/inf values, then compute norm
t_clean = torch.nan_to_num(t, nan=0.0, posinf=1.0, neginf=-1.0)
dtype = torch.complex128 if t.is_complex() else torch.float64
out = t_clean.norm(p=1, dtype=dtype)
if use_scalar:
return out.item()
return out
def _compute_rel_diff(hash1, hash2):
# Relative difference: |hash1 - hash2| / max(|hash1|, |hash2|, eps)
numerator = abs(hash1 - hash2)
denominator = max(abs(hash1), abs(hash2), 1e-10)
return numerator / denominator
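# Illustrative: _compute_rel_diff(100.0, 101.0) == 1.0 / 101.0 (~9.9e-3);
# the 1e-10 floor keeps the ratio well-defined when both hashes are 0.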
def _get_stack_trace() -> str:
from torch.fx.experimental.symbolic_shapes import uninteresting_files
summary = CapturedTraceback.extract().summary()
summary = summary[:-4] # filter out DebugMode frames
summary = [
frame for frame in summary if frame.filename not in uninteresting_files()
]
summary = traceback.StackSummary.from_list(summary)
return "".join(summary.format())
def _maybe_get_autograd_trace() -> str | None:
if torch._C._current_autograd_node() is not None:
tb = torch._C._current_autograd_node().metadata.get("traceback_") # type: ignore[attr-defined]
if tb:
return "".join(tb)
return None
def _get_op_name(op) -> str:
if isinstance(op, torch._ops.OpOverload):
op_name = op.__qualname__
elif hasattr(op, "__module__") and hasattr(op, "__name__"):
op_name = f"{op.__module__}.{op.__name__}"
else:
op_name = str(op)
return op_name
class _DebugCall:
"""Base class for tracking operator calls in DebugMode"""
def __init__(
self,
call_depth: int,
record: dict[str, Any] | None = None,
log: dict[str, Any] | None = None,
stack: bool = False,
) -> None:
self.call_depth = call_depth
if stack:
self.stack_trace = _get_stack_trace()
self.fwd_stack_trace = _maybe_get_autograd_trace()
# results from dispatch hooks
self.record = record
self.log = log
self.output_str: str | None = None
def stringify_args(
self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
) -> None:
"""
To reduce memory consumption, this method stringifies args/kwargs, stores the result, and deletes original args/kwargs.
"""
raise NotImplementedError(
"Subclasses must implement stringify_args(), even if no-op"
)
def stringify_output(
self,
output: Any,
attributes: list[str],
tensor_memo: TensorIdTracker | None = None,
) -> None:
"""Store stringified version of call output in self.output_str"""
if tree_all(lambda x: x is None, output):
return
output_str = tree_map(lambda x: _arg_to_str(x, attributes, tensor_memo), output)
self.output_str = f" -> {str(output_str)}"
def render(self, attributes: list[str]) -> str:
raise NotImplementedError("Subclasses must implement string render()")
def __repr__(self) -> str:
return self.render([])
class _OpCall(_DebugCall):
"""Normal operator call"""
def __init__(
self,
op,
args: tuple,
kwargs: dict,
call_depth: int,
stack: bool = False,
) -> None:
super().__init__(call_depth, stack=stack)
self.op = op
self.args = args
self.kwargs = kwargs
self.args_str: str | None = None
self.kwargs_str: str | None = None
def stringify_args(
self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
) -> None:
self.args_str = ", ".join(
_arg_to_str(arg, attributes, tensor_memo) for arg in self.args
)
if self.kwargs:
self.kwargs_str = ", " + ", ".join(
f"{k}={_arg_to_str(v, attributes, tensor_memo)}"
for k, v in self.kwargs.items()
)
else:
self.kwargs_str = ""
del self.args
del self.kwargs
def render(self, attributes: list[str]) -> str:
if self.args_str is not None:
args_str = self.args_str
else:
args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args)
if self.kwargs_str is not None:
kwargs_str = self.kwargs_str
else:
if self.kwargs:
kwargs_str = ", " + ", ".join(
f"{k}={_arg_to_str(v, attributes)}" for k, v in self.kwargs.items()
)
else:
kwargs_str = ""
if isinstance(self.op, torch._ops.OpOverload):
op_name = self.op.__qualname__
elif hasattr(self.op, "__module__") and hasattr(self.op, "__name__"):
op_name = f"{self.op.__module__}.{self.op.__name__}"
else:
op_name = str(self.op)
base_str = f"{op_name}({args_str}{kwargs_str})"
if self.output_str:
base_str += self.output_str
if self.log:
base_str += f" # {self.log}"
return base_str
def __iter__(self):
# for BC; tuple(self) returns (op, args, kwargs, call_depth)
if self.args_str is not None:
yield from [self.op, self.args_str, self.kwargs_str, self.call_depth]
else:
yield from [self.op, self.args, self.kwargs, self.call_depth]
class _RedistributeCall(_DebugCall):
"""Redistribute call from DTensor dispatch"""
def __init__(
self,
arg,
src_placement,
dst_placement,
transform_info_str,
call_depth,
stack=False,
) -> None:
super().__init__(call_depth, stack=stack)
self.arg = arg
self.src_placement = src_placement
self.dst_placement = dst_placement
self.transform_info_str = transform_info_str
self.arg_str: str | None = None
def stringify_args(
self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
) -> None:
self.arg_str = f"{_arg_to_str(self.arg, attributes, tensor_memo)}"
del self.arg
def render(self, attributes: list[str]) -> str:
if self.arg_str is not None:
arg_str = self.arg_str
else:
arg_str = f"{_arg_to_str(self.arg, attributes)}"
if self.transform_info_str is not None: # prioritize over src/dst placements
placement_str = f"trace: {self.transform_info_str}"
else:
src_placement_str = _arg_to_str(self.src_placement, attributes)
dst_placement_str = _arg_to_str(self.dst_placement, attributes)
placement_str = f"{src_placement_str} -> {dst_placement_str}"
base_str = f"{REDISTRIBUTE_FUNC}({arg_str}, {placement_str})"
if self.output_str:
base_str += self.output_str
return base_str
def __iter__(self):
# for BC; tuple(self) returns (op, placement info, kwargs, call_depth)
if self.arg_str is not None:
arg = self.arg_str
else:
arg = self.arg
yield REDISTRIBUTE_FUNC
if self.transform_info_str:
yield [arg, self.transform_info_str]
else:
yield [arg, self.src_placement, self.dst_placement]
yield {}
yield self.call_depth
class _NNModuleCall(_DebugCall):
"""Designates entering an nn.Module's forward method"""
def __init__(self, module_name: str, call_depth: int, stack: bool = False) -> None:
super().__init__(call_depth, stack=stack)
self.module_name = module_name
def stringify_args(
self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
) -> None:
pass # nothing to stringify
def render(self, attributes: list[str]) -> str:
return f"[nn.Mod] {self.module_name}"
def __iter__(self):
yield from [
f"[nn.Mod] {self.module_name}",
(),
{},
self.call_depth,
]
class _TritonKernelCall(_DebugCall):
"""Triton kernel call from Inductor"""
def __init__(
self,
kernel_name: str,
kwargs: dict[str, Any],
call_depth: int,
):
super().__init__(call_depth)
self.kernel_name = kernel_name
self.kwargs = kwargs
self.kwargs_str: str | None = None
self.pre_hashes: dict[str, Any] | None = None
self.post_hashes: dict[str, Any] | None = None
def stringify_args(
self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
) -> None:
# Optionally hash kernel inputs before launch
global _TRITON_INPUT_HASH_FN
if hash_fn := _TRITON_INPUT_HASH_FN:
self.pre_hashes = {
k: hash_fn(v)
for k, v in self.kwargs.items()
if isinstance(v, torch.Tensor)
}
if self.kwargs:
self.kwargs_str = ", ".join(
f"{k}={_arg_to_str(v, attributes, tensor_memo)}"
for k, v in self.kwargs.items()
)
else:
self.kwargs_str = ""
def render(self, attributes: list[str]) -> str:
base_str = f"[triton] {self.kernel_name}({self.kwargs_str})"
if self.pre_hashes:
pre_hashes_str = ", ".join(f"{k}: {v}" for k, v in self.pre_hashes.items())
pre_hashes_str = (
"\n "
+ " " * self.call_depth
+ f"# pre-kernel hashes: {{{pre_hashes_str}}}"
)
else:
pre_hashes_str = ""
if self.post_hashes:
post_hashes_str = ", ".join(
f"{k}: {v}" for k, v in self.post_hashes.items()
)
post_hashes_str = (
"\n "
+ " " * self.call_depth
+ f"# post-kernel hashes: {{{post_hashes_str}}}"
)
else:
post_hashes_str = ""
return f"{base_str}{pre_hashes_str}{post_hashes_str}\n"
def finalize(self, device_interface: "DeviceInterface"):
# synchronize -> hash/store kernel results
global _RECORD_TRITON_OUTPUTS, _TRITON_OUTPUT_HASH_FN
device_interface.synchronize(device_interface.current_device())
if _RECORD_TRITON_OUTPUTS:
self.record = {
"output": {
k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in self.kwargs.items()
}
}
if hash_fn := _TRITON_OUTPUT_HASH_FN:
self.post_hashes = {
k: hash_fn(v)
for k, v in self.kwargs.items()
if isinstance(v, torch.Tensor)
}
# don't store tensors
del self.kwargs
def __iter__(self):
yield from [self.kernel_name, (), self.kwargs_str, self.call_depth]
def _run_hook(hook, *args):
out = hook(*args)
assert out is None or isinstance(out, dict)
return out
def _run_dispatch_hooks(call: _DebugCall, func, types, args, kwargs, result) -> None:
global _DISPATCH_RECORD_HOOKS, _DISPATCH_LOG_HOOKS
if _DISPATCH_RECORD_HOOKS:
record = {}
for hook in _DISPATCH_RECORD_HOOKS:
hook_out = _run_hook(hook, func, types, args, kwargs, result)
if hook_out is not None:
record.update(hook_out)
if record:
call.record = record
if _DISPATCH_LOG_HOOKS:
log = {}
for hook in _DISPATCH_LOG_HOOKS:
hook_out = _run_hook(hook, func, types, args, kwargs, result)
if hook_out is not None:
log.update(hook_out)
if log:
call.log = log
def _get_call_name(call: _DebugCall) -> str:
"""String identifying _DebugCall (e.g. func, kernel, module name)"""
if isinstance(call, _OpCall):
return _get_op_name(call.op)
elif isinstance(call, _TritonKernelCall):
return call.kernel_name
elif isinstance(call, _NNModuleCall):
return call.module_name
elif isinstance(call, _RedistributeCall):
return REDISTRIBUTE_FUNC
else:
return str(call)
class DebugMode(TorchDispatchMode):
def __init__(
self,
*,
record_torchfunction=False,
record_faketensor=False,
record_realtensor=True,
record_tensor_attributes=None,
record_nn_module=False,
store_original_args=False,
record_stack_trace=False,
record_output=False,
record_ids=False,
) -> None:
super().__init__()
import torch.distributed.tensor # noqa: F401
self.supports_higher_order_operators = True
# Pushes DebugMode onto the torchfunction stack, and records __torch_function__ calls as well.
# WARNING: currently incompatible with torch.compile due to dynamo guard failures.
self.record_torchfunction = record_torchfunction
# Records __torch_dispatch__ calls on FakeTensors.
self.record_faketensor = record_faketensor
# Records __torch_dispatch__ calls on real tensors.
self.record_realtensor = record_realtensor
# Optional list[str] of tensor attributes, to be annotated in the string dump.
self.record_tensor_attributes = record_tensor_attributes or []
# Uses ModTracker to record nn.Module entrances, as _NNModuleCall entries.
# This flag currently has no effect on torch.compiled-regions.
self.record_nn_module = record_nn_module
self.module_tracker: ModTracker | None = None
if self.record_nn_module:
self.module_tracker_setup()
# If True, stores call args/kwargs in logs, without immediately stringifying.
# Defaults to False for memory concerns.
self.store_original_args = store_original_args
# For stack trace recording, stores log call stack traces in .stack_trace.
# For backward graph nodes, will also store the corresponding forward stack traces in .fwd_stack_trace.
# NOTE: this is only available if autograd tracebacks are being set during the forward pass,
# e.g. via DebugMode(record_stack_trace=True), or torch.autograd.set_detect_anomaly().
self.record_stack_trace = record_stack_trace
# Records call outputs in logs (e.g. for __torch_dispatch__, __torch_function__, redistribute_input)
self.record_output: bool = record_output
# Annotates string dumps with graph-style tensor ids, e.g. op($1, $2) -> $3.
self.record_ids: bool = record_ids
self.reset()
def reset(self) -> None:
self.operators = []
self.call_depth = 0
self._tensor_memo = TensorIdTracker()
self._output_info: dict[int, object] = {}
def _track_op_output(self, op_index, result) -> None:
"""Assign IDs to output tensors and store in output_info"""
# self._track_tensor_ids(result)
self._output_info[op_index] = result
# Without this override, running torch.compile under DebugMode
# would force torch.compile to always use the "eager" backend.
# With it, DebugMode does not take effect inside torch.compile'd regions.
@classmethod
def ignore_compile_internals(cls) -> bool:
return True
def _record_call(self, call) -> None:
global _IN_INDUCTOR_BENCHMARK
if _IN_INDUCTOR_BENCHMARK:
return
if str(call).startswith("profiler::_record_function"):
return
if not self.store_original_args:
call.stringify_args(
self.record_tensor_attributes,
self._tensor_memo if self.record_ids else None,
)
self.operators.append(call)
def _record_call_output(self, call, output) -> None:
if not self.record_output:
return
call.stringify_output(
output,
self.record_tensor_attributes,
self._tensor_memo if self.record_ids else None,
)
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
call = _OpCall(
func, args, kwargs, self.call_depth, stack=self.record_stack_trace
)
self._record_call(call)
try:
self.call_depth += 1
result = func(*args, **kwargs)
self._record_call_output(call, result)
return result
finally:
self.call_depth -= 1
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
# Record the operation with its call depth
call = None
if torch.distributed.tensor.DTensor in types:
call = _OpCall(
func, args, kwargs, self.call_depth, stack=self.record_stack_trace
)
self._record_call(call)
return NotImplemented
elif FakeTensor in types or isinstance(
_get_current_dispatch_mode(), FakeTensorMode
):
if self.record_faketensor:
if func != torch.ops.prim.device.default:
call = _OpCall(
func,
args,
kwargs,
self.call_depth + 1,
stack=self.record_stack_trace,
)
self._record_call(call)
elif len(types) == 0:
if self.record_realtensor:
call = _OpCall(
func,
args,
kwargs,
self.call_depth + 1,
stack=self.record_stack_trace,
)
self._record_call(call)
result = func(*args, **kwargs)
if call:
self._record_call_output(call, result)
_run_dispatch_hooks(call, func, types, args, kwargs, result)
return result
def __enter__(self):
self.reset()
if self.record_torchfunction:
torch._C._push_on_torch_function_stack(self)
super().__enter__()
if self.record_nn_module:
self.module_tracker.__enter__() # type: ignore[attribute, union-attr]
if self.record_stack_trace:
self.anomaly_for_traces = torch.autograd.set_detect_anomaly(
True, check_nan=False
)
self.anomaly_for_traces.__enter__()
return self
# pyrefly: ignore [bad-override]
def __exit__(self, *args):
super().__exit__(*args)
if self.record_nn_module:
self.module_tracker.__exit__() # type: ignore[attribute, union-attr]
if self.record_torchfunction:
torch._C._pop_torch_function_stack()
if self.record_stack_trace:
self.anomaly_for_traces.__exit__(*args)
def module_tracker_setup(self) -> None:
from torch.distributed._tools.mod_tracker import ModTracker
self.module_tracker = ModTracker()
# module pre-fw hook: record module call
def pre_fw_hook(module, input) -> None:
fqn = self.module_tracker._get_mod_name(module) # type: ignore[attribute, union-attr]
self.operators.append(_NNModuleCall(fqn, self.call_depth + 1))
self.call_depth += 1
# module post-fw hook: decrement call depth
def post_fw_hook(module, input, output) -> None:
self.call_depth -= 1
self.module_tracker.register_user_hooks(pre_fw_hook, post_fw_hook)
@contextlib.contextmanager
def record_redistribute_calls(
self,
arg,
src_placement,
dst_placement,
transform_info_str: str | None = None,
):
try:
self._record_call(
_RedistributeCall(
arg,
src_placement=src_placement,
dst_placement=dst_placement,
transform_info_str=transform_info_str,
call_depth=self.call_depth + 1,
stack=self.record_stack_trace,
)
)
self.call_depth += 1
yield
finally:
self.call_depth -= 1
def record_triton_kernel(
self, kernel_name: str, kwargs: dict[str, Any]
) -> _TritonKernelCall:
call = _TritonKernelCall(kernel_name, kwargs, self.call_depth + 1)
call.stringify_args(self.record_tensor_attributes)
self.operators.append(call)
return call
def debug_string(self) -> str:
with torch._C.DisableTorchFunction():
result = ""
result += "\n".join(
" " + " " * op.call_depth + op.render(self.record_tensor_attributes)
for op in self.operators
)
return result
@staticmethod
@contextlib.contextmanager
def dispatch_hooks(
record_hook: Callable | None = None,
log_hook: Callable | None = None,
):
"""
Allows installing post-hooks on arguments to intercepted __torch_dispatch__ calls;
hook signatures are expected as (func, types, args, kwargs, result),
i.e. __torch_dispatch__ args + return value.
Logging hook outputs are stored in call.log and annotate calls in debug_string(),
while recording hook outputs are just stored in call.record.
For now hooks are expected to return dictionaries.
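Example (a sketch; ``inp`` is a placeholder tensor)::

    def log_result_dtype(func, types, args, kwargs, result):
        if isinstance(result, torch.Tensor):
            return {"dtype": result.dtype}
        return None

    with DebugMode() as debug_mode, DebugMode.dispatch_hooks(log_hook=log_result_dtype):
        torch.relu(inp)
    # the returned dict shows up as a trailing "# {...}" annotation in debug_mode.debug_string()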
"""
global _DISPATCH_RECORD_HOOKS, _DISPATCH_LOG_HOOKS
if record_hook:
_DISPATCH_RECORD_HOOKS.append(record_hook)
if log_hook:
_DISPATCH_LOG_HOOKS.append(log_hook)
try:
yield
finally:
if record_hook:
_DISPATCH_RECORD_HOOKS.pop()
if log_hook:
_DISPATCH_LOG_HOOKS.pop()
@staticmethod
@contextlib.contextmanager
def record_outputs():
"""
Hook for storing cloned output tensors in .record["output"].
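Example (a sketch; ``model`` / ``x`` are placeholders)::

    with DebugMode() as debug_mode, DebugMode.record_outputs():
        model(x)
    # cloned outputs are stored under debug_mode.logs[i].record["output"]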
"""
def dispatch_hook(func, types, args, kwargs, result):
out = tree_map(
lambda x: x.clone() if isinstance(x, torch.Tensor) else x, result
)
return {"output": out}
global _RECORD_TRITON_OUTPUTS
try:
_old_record_triton = _RECORD_TRITON_OUTPUTS
_RECORD_TRITON_OUTPUTS = True
with DebugMode.dispatch_hooks(record_hook=dispatch_hook):
yield
finally:
_RECORD_TRITON_OUTPUTS = _old_record_triton
@staticmethod
@contextlib.contextmanager
def log_tensor_hashes(hash_fn: Callable | None = None, hash_inputs: bool = False):
"""
Installs hook for tensor hash logging.
hash_fn: optional function for custom hashing
hash_inputs: if True, also hashes tensors in (args, kwargs), storing them in "input_hash".
NOTE: this is currently a post-hook, so e.g. inplace ops will log the "output" hashes.
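Example (a sketch; ``model`` / ``x`` are placeholders)::

    with DebugMode() as debug_mode, DebugMode.log_tensor_hashes(hash_inputs=True):
        model(x)
    # dispatch calls are annotated with {'hash': ..., 'input_hash': ...} in debug_string(),
    # and triton kernel calls get pre-/post-kernel hash lines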
"""
if hash_fn is None:
hash_fn = functools.partial(default_hash_fn, use_scalar=True)
def _tree_hash(obj):
return tree_map(
lambda x: hash_fn(x) if isinstance(x, torch.Tensor) else None, obj
)
def _dispatch_hash_hook(func, types, args, kwargs, result):
if "empty" in str(func) or "profiler" in str(func):
return None
out = {}
out["hash"] = _tree_hash(result)
if hash_inputs:
out["input_hash"] = _tree_hash((args, kwargs))
if tree_all(lambda x: x is None, out.values()):
return None
return out
global _TRITON_INPUT_HASH_FN, _TRITON_OUTPUT_HASH_FN
try:
if hash_inputs:
_old_input_hfn = _TRITON_INPUT_HASH_FN
_TRITON_INPUT_HASH_FN = hash_fn
_old_output_hfn = _TRITON_OUTPUT_HASH_FN
_TRITON_OUTPUT_HASH_FN = hash_fn
with DebugMode.dispatch_hooks(log_hook=_dispatch_hash_hook):
yield
finally:
if hash_inputs:
_TRITON_INPUT_HASH_FN = _old_input_hfn # type: ignore[assignment]
_TRITON_OUTPUT_HASH_FN = _old_output_hfn
@staticmethod
@contextlib.contextmanager
def _benchmarking_inductor():
"""
Context manager for disabling logging during inductor benchmarking,
so logs don't contain all kernels launched from autotuning.
"""
global _IN_INDUCTOR_BENCHMARK
try:
_IN_INDUCTOR_BENCHMARK = True
yield
finally:
_IN_INDUCTOR_BENCHMARK = False
@property
def logs(self):
return list(self.operators)
@staticmethod
def check_hash_mismatches(
logs1: list, logs2: list, compare_inputs: bool = False
) -> list[dict]:
"""
Compares tensor hashes between two DebugMode runs, for checking run-to-run numerical divergence.
This first validates the two log sequences have identical structure (same operations, input shapes/dtypes, etc.),
then compares tensor hash values, and returns a list of call outputs where mismatches were found.
Expects input logs to have been run with log_tensor_hashes, and looks for hashes in .log["hash"] & .log["input_hash"]
(or .post_hashes & .pre_hashes for triton kernels).
note: skips checking log pairs where hashes aren't present, but will raise if present in one & not the other.
Args:
logs1: logs from the first DebugMode run (from debug_mode.logs)
logs2: logs from the second DebugMode run
compare_inputs: If True, also compare input tensor hashes (default: only output checking)
Returns:
List of dictionaries describing hash mismatches. Each dict contains:
- call_type: "torch op" or "triton kernel"
- call: Operator/kernel name
- arg_name: For triton kernels, the argument name; None for torch ops
- pytree_path: For torch ops, the pytree path to the differing tensor; None for kernels
- hash1: Hash value from the first run
- hash2: Hash value from the second run
- rel_diff: Relative difference between hash values
- is_input_hash: True if this is an input hash, False for output hash
Raises:
ValueError: If logs have different lengths, call types, operator names, or call depths
Usage::
# Run model first time
with DebugMode() as debug_mode, DebugMode.log_tensor_hashes():
model(x)
logs1 = debug_mode.logs
# Run again, in exactly the same way
with DebugMode() as debug_mode, DebugMode.log_tensor_hashes():
model(x)
logs2 = debug_mode.logs
mismatches = DebugMode.check_hash_mismatches(logs1, logs2)
for m in mismatches:
print(f"{m['call']}: hash diff {m['rel_diff']:.2e}")
"""
if len(logs1) != len(logs2):
raise ValueError(f"Log lengths don't match: {len(logs1)} vs {len(logs2)}")
difference_info = []
for i, (log1, log2) in enumerate(zip(logs1, logs2)):
# check call type
call1_type = type(log1).__name__
call2_type = type(log2).__name__
if call1_type != call2_type:
raise ValueError(
f"Call types don't match at index {i}: {call1_type} vs {call2_type}"
)
call_type = call1_type
# check call name
op1_name, op2_name = _get_call_name(log1), _get_call_name(log2)
if op1_name != op2_name:
raise ValueError(
f"Operators don't match at index {i}: {call_type}[{op1_name}] vs {call_type}[{op2_name}]"
)
op_name = op1_name
# check call depth
if log1.call_depth != log2.call_depth:
raise ValueError(
f"Call depths for {call_type}[{op_name}] don't match at index {i}: {log1.call_depth} vs {log2.call_depth}"
)
# Redistribute: call args should be the same
if isinstance(log1, _RedistributeCall):
if tuple(log1) != tuple(log2):
raise ValueError(
f"Redistribute calls don't match at index {i}: {log1} vs {log2}"
)
# Triton kernel: same arg names, arg types
elif isinstance(log1, _TritonKernelCall):
if log1.kwargs_str != log2.kwargs_str:
raise ValueError(
f"Triton kernel call args don't match for {log1.kernel_name} at index {i}:"
f"\n\nlog1: {log1.kwargs_str}\n\nlog2: {log2.kwargs_str}"
)
def compare_triton_hashes(hashes1, hashes2, is_input):
assert set(hashes1.keys()) == set(hashes2.keys()) # type: ignore[union-attr]
for key in hashes1.keys():
if hashes1[key] != hashes2[key]:
difference_info.append(
{
"call_type": "triton kernel",
"call": op_name,
"arg_name": key,
"pytree_path": None,
"hash1": hashes1[key],
"hash2": hashes2[key],
"rel_diff": _compute_rel_diff(
hashes1[key], hashes2[key]
),
"is_input_hash": is_input,
}
)
# check output hashes
has_post_1, has_post_2 = (
log1.post_hashes is not None,
log2.post_hashes is not None,
)
if has_post_1 != has_post_2:
raise ValueError(
f"Triton kernel post-hash presence inconsistent for {log1.kernel_name} "
f"at index {i}: log1 has post_hashes={has_post_1}, log2 has post_hashes={has_post_2}"
)
if has_post_1:
compare_triton_hashes(
log1.post_hashes, log2.post_hashes, is_input=False
)
# maybe check input hashes
if compare_inputs:
has_pre_1, has_pre_2 = (
log1.pre_hashes is not None,
log2.pre_hashes is not None,
)
if has_pre_1 != has_pre_2:
raise ValueError(
f"Triton kernel pre-hash presence inconsistent for {log1.kernel_name} "
f"at index {i}: log1 has pre_hashes={has_pre_1}, log2 has pre_hashes={has_pre_2}"
)
if has_pre_1:
compare_triton_hashes(
log1.pre_hashes, log2.pre_hashes, is_input=True
)
# regular log calls
elif isinstance(log1, _OpCall):
def compare_op_hashes(hashes1, hashes2, is_input):
def _helper(keypath, hash1, hash2):
if hash1 != hash2:
difference_info.append(
{
"call_type": "torch op",
"call": op_name,
"arg_name": None,
"pytree_path": keystr(keypath),
"hash1": hash1,
"hash2": hash2,
"rel_diff": _compute_rel_diff(hash1, hash2),
"is_input_hash": is_input,
}
)
tree_map_with_path(_helper, hashes1, hashes2)
# check output hashes
has_hash1 = log1.log is not None and "hash" in log1.log
has_hash2 = log2.log is not None and "hash" in log2.log
if has_hash1 != has_hash2:
raise ValueError(
f"Output hash presence inconsistent for triton kernel {call_type}[{op_name}] "
f"at index {i}: log1 has hash={has_hash1}, log2 has hash={has_hash2}"
)
if has_hash1:
compare_op_hashes(
log1.log["hash"], # type: ignore[union-attr]
log2.log["hash"],
is_input=False,
)
# maybe check input hashes
if compare_inputs:
has_hash1 = log1.log is not None and "input_hash" in log1.log
has_hash2 = log2.log is not None and "input_hash" in log2.log
if has_hash1 != has_hash2:
raise ValueError(
f"Input hash presence inconsistent for triton kernel {call_type}[{op_name}] "
f"at index {i}: log1 has input_hash={has_hash1}, log2 has input_hash={has_hash2}"
)
if has_hash1:
compare_op_hashes(
log1.log["input_hash"], # type: ignore[union-attr]
log2.log["input_hash"],
is_input=True,
)
return difference_info
def get_active_debug_mode() -> DebugMode | None:
debug_mode = None
for mode in _get_current_dispatch_mode_stack():
if isinstance(mode, DebugMode):
debug_mode = mode
break
return debug_mode