Following up on https://github.com/pytorch/pytorch/pull/166348, this extends DebugMode to capture inductor triton kernels at runtime, and adds an API for checking run-to-run determinism based on tensor hashes. The workflow looks something like:

```python
# do 1st run with hashes, get logs
with DebugMode() as debug_mode, DebugMode.log_tensor_hashes():
    compiled_model(*inputs)
logs1 = debug_mode.logs

# do 2nd run
with DebugMode() as debug_mode, DebugMode.log_tensor_hashes():
    compiled_model(*inputs)
logs2 = debug_mode.logs

# returns list of calls w/ mismatched outputs
mismatches = DebugMode.check_hash_mismatches(logs1, logs2)
```

Example dump from a smaller version of @drisspg's FlexAttention fwd+bwd determinism tests [script](https://gist.github.com/pianpwk/f65cc63811d12853709dcc77d7eb69f1) (without forced reduction order):

```
cfg: TestConfig(name='Standard', B=2, Hq=32, Hkv=32, Q=2048, KV=2048, Dqk=128, Dv=128)
DETERMINISM: fwd: True, bwd_q: False, bwd_k: False, bwd_v: True

$$$ DEBUG MODE DUMP $$$  (this is what the logs look like)
[triton] triton_tem_fused_0(arg_Q=t: bf16[2, 32, 2048, 128], arg_K=t: bf16[2, 32, 2048, 128], arg_V=t: bf16[2, 32, 2048, 128], arg_LSE=t: f32[2, 32, 2048], arg_MAX=t: f32[2, 32, 2048], arg_KV_NUM_BLKS=t: i32[2, 32, 16], arg_KV_IDX=t: i32[2, 32, 16, 16], arg_FULL_KV_NUM_BLKS=t: i32[2, 32, 16], arg_FULL_KV_IDX=t: i32[2, 32, 16, 16], out_ptr0=t: bf16[2, 32, 2048, 128])
  # post-kernel hashes: {arg_Q: 13385916.068706088, arg_K: 13389356.409105342, arg_V: 13384993.48412523, arg_LSE: 1347168.9026973695, arg_MAX: 81775.3811062593, arg_KV_NUM_BLKS: 1024.0, arg_KV_IDX: 122880.0, arg_FULL_KV_NUM_BLKS: 7680.0, arg_FULL_KV_IDX: 122880.0, out_ptr0: 924917.7918248245}
[triton] triton_per_fused_zeros_0(in_ptr0=t: bf16[2, 32, 2048, 128], in_ptr1=t: bf16[2, 32, 2048, 128], out_ptr1=t: f32[2, 32, 2048], xnumel=131072, r0_numel=128)
  # post-kernel hashes: {in_ptr0: 924917.7918248245, in_ptr1: 13389213.797377996, out_ptr1: 81775.38106592931}
[triton] triton_tem_fused_zeros_1(arg_Q=t: bf16[2, 32, 2048, 128], arg_K=t: bf16[2, 32, 2048, 128], arg_V=t: bf16[2, 32, 2048, 128], arg_LSE=t: f32[2, 32, 2048], arg_DELTA=t: f32[2, 32, 2048], arg_DO=t: bf16[2, 32, 2048, 128], arg_DQ=t: bf16[2, 32, 2048, 128], arg_DV=t: bf16[2, 32, 2048, 128], arg_KV_NUM_BLKS=t: i32[2, 32, 16], arg_KV_IDX=t: i32[2, 32, 16, 16], arg_Q_NUM_BLKS=t: i32[2, 32, 16], arg_Q_IDX=t: i32[2, 32, 16, 16], arg_FULL_KV_NUM_BLKS=t: i32[2, 32, 16], arg_FULL_KV_IDX=t: i32[2, 32, 16, 16], arg_FULL_Q_NUM_BLKS=t: i32[2, 32, 16], arg_FULL_Q_IDX=t: i32[2, 32, 16, 16], out_ptr0=t: bf16[2, 32, 2048, 128])
  # post-kernel hashes: {arg_Q: 13385916.068706088, arg_K: 13389356.409105342, arg_V: 13384993.48412523, arg_LSE: 1347168.9026973695, arg_DELTA: 81775.38106592931, arg_DO: 13389213.797377996, arg_DQ: 874474.8084187683, arg_DV: 727742.3138379117, arg_KV_NUM_BLKS: 1024.0, arg_KV_IDX: 122880.0, arg_Q_NUM_BLKS: 1024.0, arg_Q_IDX: 122880.0, arg_FULL_KV_NUM_BLKS: 7680.0, arg_FULL_KV_IDX: 122880.0, arg_FULL_Q_NUM_BLKS: 7680.0, arg_FULL_Q_IDX: 122880.0, out_ptr0: 700542.3431890717}

$$$ MISMATCHES $$$
mismatch: {'call_type': 'triton kernel', 'call': 'triton_tem_fused_0', 'arg_name': 'arg_MAX', 'pytree_path': None, 'hash1': 0.0, 'hash2': 81775.3811062593, 'rel_diff': 1.0, 'is_input_hash': False}  # I guess this one is misleading? not sure if I'm doing something wrong with waiting for kernel results
mismatch: {'call_type': 'triton kernel', 'call': 'triton_per_fused_zeros_0', 'arg_name': 'out_ptr1', 'pytree_path': None, 'hash1': 81775.3811062593, 'hash2': 81775.38106592931, 'rel_diff': 4.931801261646669e-10, 'is_input_hash': False}
mismatch: {'call_type': 'triton kernel', 'call': 'triton_tem_fused_zeros_1', 'arg_name': 'arg_DELTA', 'pytree_path': None, 'hash1': 81775.3811062593, 'hash2': 81775.38106592931, 'rel_diff': 4.931801261646669e-10, 'is_input_hash': False}
mismatch: {'call_type': 'triton kernel', 'call': 'triton_tem_fused_zeros_1', 'arg_name': 'arg_DQ', 'pytree_path': None, 'hash1': 874474.8097136207, 'hash2': 874474.8084187683, 'rel_diff': 1.480720012120795e-09, 'is_input_hash': False}
mismatch: {'call_type': 'triton kernel', 'call': 'triton_tem_fused_zeros_1', 'arg_name': 'out_ptr0', 'pytree_path': None, 'hash1': 700542.3488049245, 'hash2': 700542.3431890717, 'rel_diff': 8.016435812581196e-09, 'is_input_hash': False}
```

Note: the current hash implementation is basically a tensor norm, so tensor closeness -> hash closeness. This is likely to change soon, e.g. maybe to `torch.hash_tensor` (https://github.com/pytorch/pytorch/pull/154149) by default.

Sample paste diff between log dumps from 2 runs:

<img width="1665" height="445" alt="Screenshot 2025-11-05 at 11 27 24 PM" src="https://github.com/user-attachments/assets/41402e37-f50b-4a9e-a17c-bb98b5917076" />

In another case, running this for FSDP2 on Llama3-8B helped narrow down a divergence between aot_eager and inductor to inductor's forward RMSNorm kernels: P2027003180

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167028
Approved by: https://github.com/v0i0
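To localize where two runs first diverge, input hashing can be enabled as well. The sketch below only uses the APIs added here (`log_tensor_hashes(hash_inputs=True)` and `check_hash_mismatches(..., compare_inputs=True)`); `compiled_model` / `inputs` are the same placeholders as in the example above:

```python
# hash inputs in addition to outputs
with DebugMode() as debug_mode, DebugMode.log_tensor_hashes(hash_inputs=True):
    compiled_model(*inputs)
logs1 = debug_mode.logs

with DebugMode() as debug_mode, DebugMode.log_tensor_hashes(hash_inputs=True):
    compiled_model(*inputs)
logs2 = debug_mode.logs

mismatches = DebugMode.check_hash_mismatches(logs1, logs2, compare_inputs=True)
# first output-hash mismatch; input-hash mismatches are reported separately
# with is_input_hash=True, which helps tell apart calls that merely received
# already-diverged inputs from the call that introduced the divergence
first_output_mismatch = next((m for m in mismatches if not m["is_input_hash"]), None)
print(first_output_mismatch)
```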
# mypy: allow-untyped-defs
"""
DebugMode is a debugging TorchDispatchMode that intercepts and logs runtime calls
to a hierarchical string dump. It logs real tensor, DTensor, and optionally FakeTensor
operations, with some additional handling for DTensor internals.

An example dump from an eager mode DTensor matmul:

    torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) -> dt$6: f32[8, 32]| S(0)
      aten::mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0))
        redistribute_input(1, S(0) -> R)
          redistribute_input(t$2: f32[1, 32], trace: S(0)->R)
            _c10d_functional::all_gather_into_tensor(t$2: f32[1, 32], 8, 0) -> t$3: f32[8, 32]
            _c10d_functional::wait_tensor(t$3: f32[8, 32]) -> t$3: f32[8, 32]
        aten::mm(t$4: f32[1, 8], t$3: f32[8, 32]) -> t$5: f32[1, 32]

This mode runs "under" compile: it hides itself during compilation and is re-enabled
at runtime, so DebugMode-related operations won't show up in the compiled region.

DebugMode also provides some visibility into non-torch-dispatch calls (e.g. DTensor
redistribute calls, inductor-generated triton kernels), but these require special
handling, since dispatch modes can't intercept them by default.

The mode also provides some extensions for custom debugging (e.g. adding custom
dispatch call hooks via dispatch_hooks) and numerics debugging (e.g. tensor hashing
for bitwise equivalence/closeness, via log_tensor_hashes). These decorators annotate
string dumps with additional per-call information, for any region of runtime code.

Usage::

    with DebugMode() as debug_mode:
        result = some_pytorch_operation(tensor_input)
    print(debug_mode.debug_string())
"""
import contextlib
import functools
import traceback
import weakref
from collections.abc import Callable
from typing import Any, TYPE_CHECKING

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.utils._dtype_abbrs import dtype_abbrs
from torch.utils._python_dispatch import (
    _get_current_dispatch_mode,
    _get_current_dispatch_mode_stack,
    TorchDispatchMode,
)
from torch.utils._pytree import keystr, tree_all, tree_map, tree_map_with_path
from torch.utils._traceback import CapturedTraceback
from torch.utils.weak import WeakIdRef


if TYPE_CHECKING:
    from torch._dynamo.device_interface import DeviceInterface
    from torch.distributed._tools.mod_tracker import ModTracker


__all__ = ["DebugMode", "get_active_debug_mode"]


REDISTRIBUTE_FUNC = "redistribute_input"

# registered dispatch call hooks
_DISPATCH_RECORD_HOOKS: list[Callable] = []
_DISPATCH_LOG_HOOKS: list[Callable] = []

# Tracks if we're in inductor benchmarking, and temporarily disables logging
# (for ignoring autotuning kernel launches which don't affect the user-facing result)
_IN_INDUCTOR_BENCHMARK = False

# For record_outputs / log_tensor_hashes hooks for triton kernels.
# Stores kernel outputs in call.record["output"]
_RECORD_TRITON_OUTPUTS = False

# Annotates kernel output hashes, and stores them in call.post_hashes
_TRITON_OUTPUT_HASH_FN = None

# Annotates kernel input hashes, and stores them in call.pre_hashes
_TRITON_INPUT_HASH_FN = None
def _stringify_shape(shape) -> str:
    return f"[{', '.join([str(x) for x in shape])}]"


def _stringify_device_mesh(mesh) -> str:
    return f"DM({', '.join([str(s) for s in mesh.shape])})"


def _stringify_placement(placement) -> str:
    return f"[{', '.join([str(p) for p in placement])}]"


def _stringify_attributes(tensor, attributes) -> str:
    pairs = {}
    for attr in attributes:
        if hasattr(tensor, attr):
            pairs[attr] = getattr(tensor, attr)
    if len(pairs) == 0:
        return ""
    return f"{{{', '.join([f'{k}={v}' for k, v in pairs.items()])}}}"


def _stringify_dtensor_spec(spec) -> str:
    from torch.distributed.tensor._dtensor_spec import DTensorSpec

    return DTensorSpec.format_shard_order_str(spec.placements, spec.shard_order)
class TensorIdTracker:
    def __init__(self) -> None:
        self.tensor_memo: dict[WeakIdRef, int] = {}
        self.next_tensor_id = 0

    def _id(self, tensor) -> int:
        with torch._C._DisablePythonDispatcher():
            o = WeakIdRef(tensor)

        def del_memo() -> None:
            self.tensor_memo.pop(o, None)

        weakref.finalize(tensor, del_memo)
        if o not in self.tensor_memo:
            self.tensor_memo[o] = self.next_tensor_id
            self.next_tensor_id += 1
        return self.tensor_memo[o]
def _tensor_debug_string(tensor, attributes, tensor_memo=None) -> str:
    """Convert tensor to debug string representation."""

    if isinstance(tensor, torch.Tensor):
        tensor_debug_str = f"{dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_attributes(tensor, attributes)}"
        id_str = f"${tensor_memo._id(tensor)}" if tensor_memo is not None else ""
        if isinstance(tensor, torch.distributed.tensor.DTensor):
            # omitted device mesh
            return f"dt{id_str}: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}"
        elif isinstance(tensor, FakeTensor):
            return f"ft{id_str}: {tensor_debug_str}"
        else:
            return f"t{id_str}: {tensor_debug_str}"
    else:
        raise RuntimeError(f"Unsupported tensor type: {type(tensor)}")


def _arg_to_str(arg, attributes, tensor_memo=None) -> str:
    from torch.distributed.tensor._dtensor_spec import DTensorSpec

    def to_str(x):
        if isinstance(x, torch.Tensor):
            return _tensor_debug_string(x, attributes, tensor_memo)
        elif isinstance(x, DTensorSpec):
            return _stringify_dtensor_spec(x)
        return x

    arg = tree_map(to_str, arg)
    return str(arg)
def default_hash_fn(t: torch.Tensor, use_scalar: bool = False) -> torch.Tensor:
    """
    Adapted from Observer. Computes a hash for a tensor by converting it to float (if needed),
    making it contiguous, replacing NaN/inf values with fixed numbers, and then computing
    the L1 norm in float64 or complex128.
    This is used to generate a deterministic summary value for tensor comparison.
    """
    with torch._C._DisablePythonDispatcher(), torch._C._DisableTorchDispatch():
        if not (t.is_floating_point() or t.is_complex()):
            t = t.float()
        t = t.contiguous()
        # Clean the tensor to handle NaN/inf values, then compute norm
        t_clean = torch.nan_to_num(t, nan=0.0, posinf=1.0, neginf=-1.0)

        dtype = torch.complex128 if t.is_complex() else torch.float64
        out = t_clean.norm(p=1, dtype=dtype)
        if use_scalar:
            return out.item()
        return out


def _compute_rel_diff(hash1, hash2):
    # Relative difference: |hash1 - hash2| / max(|hash1|, |hash2|, eps)
    numerator = abs(hash1 - hash2)
    denominator = max(abs(hash1), abs(hash2), 1e-10)
    return numerator / denominator
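# Illustrative sketch (shown as a comment so it isn't executed on import): since
# default_hash_fn is an L1 norm rather than a bitwise hash, hash closeness tracks
# tensor closeness, and _compute_rel_diff gives a rough scale of the divergence.
#
#     a = torch.randn(1024, dtype=torch.float64)
#     b = a + 1e-9
#     h1 = default_hash_fn(a, use_scalar=True)
#     h2 = default_hash_fn(b, use_scalar=True)
#     _compute_rel_diff(h1, h2)  # small, roughly the size of the relative perturbation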
def _get_stack_trace() -> str:
    from torch.fx.experimental.symbolic_shapes import uninteresting_files

    summary = CapturedTraceback.extract().summary()
    summary = summary[:-4]  # filter out DebugMode frames
    summary = [
        frame for frame in summary if frame.filename not in uninteresting_files()
    ]
    summary = traceback.StackSummary.from_list(summary)
    return "".join(summary.format())


def _maybe_get_autograd_trace() -> str | None:
    if torch._C._current_autograd_node() is not None:
        tb = torch._C._current_autograd_node().metadata.get("traceback_")  # type: ignore[attr-defined]
        if tb:
            return "".join(tb)
    return None


def _get_op_name(op) -> str:
    if isinstance(op, torch._ops.OpOverload):
        op_name = op.__qualname__
    elif hasattr(op, "__module__") and hasattr(op, "__name__"):
        op_name = f"{op.__module__}.{op.__name__}"
    else:
        op_name = str(op)
    return op_name
class _DebugCall:
    """Base class for tracking operator calls in DebugMode"""

    def __init__(
        self,
        call_depth: int,
        record: dict[str, Any] | None = None,
        log: dict[str, Any] | None = None,
        stack: bool = False,
    ) -> None:
        self.call_depth = call_depth
        if stack:
            self.stack_trace = _get_stack_trace()
            self.fwd_stack_trace = _maybe_get_autograd_trace()

        # results from dispatch hooks
        self.record = record
        self.log = log
        self.output_str: str | None = None

    def stringify_args(
        self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
    ) -> None:
        """
        To reduce memory consumption, this method stringifies args/kwargs, stores the result,
        and deletes the original args/kwargs.
        """
        raise NotImplementedError(
            "Subclasses must implement stringify_args(), even if no-op"
        )

    def stringify_output(
        self,
        output: Any,
        attributes: list[str],
        tensor_memo: TensorIdTracker | None = None,
    ) -> None:
        """Store stringified version of call output in self.output_str"""
        if tree_all(lambda x: x is None, output):
            return
        output_str = tree_map(lambda x: _arg_to_str(x, attributes, tensor_memo), output)
        self.output_str = f" -> {str(output_str)}"

    def render(self, attributes: list[str]) -> str:
        raise NotImplementedError("Subclasses must implement string render()")

    def __repr__(self) -> str:
        return self.render([])
class _OpCall(_DebugCall):
    """Normal operator call"""

    def __init__(
        self,
        op,
        args: tuple,
        kwargs: dict,
        call_depth: int,
        stack: bool = False,
    ) -> None:
        super().__init__(call_depth, stack=stack)
        self.op = op
        self.args = args
        self.kwargs = kwargs

        self.args_str: str | None = None
        self.kwargs_str: str | None = None

    def stringify_args(
        self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
    ) -> None:
        self.args_str = ", ".join(
            _arg_to_str(arg, attributes, tensor_memo) for arg in self.args
        )
        if self.kwargs:
            self.kwargs_str = ", " + ", ".join(
                f"{k}={_arg_to_str(v, attributes, tensor_memo)}"
                for k, v in self.kwargs.items()
            )
        else:
            self.kwargs_str = ""
        del self.args
        del self.kwargs

    def render(self, attributes: list[str]) -> str:
        if self.args_str is not None:
            args_str = self.args_str
        else:
            args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args)

        if self.kwargs_str is not None:
            kwargs_str = self.kwargs_str
        else:
            if self.kwargs:
                kwargs_str = ", " + ", ".join(
                    f"{k}={_arg_to_str(v, attributes)}" for k, v in self.kwargs.items()
                )
            else:
                kwargs_str = ""

        if isinstance(self.op, torch._ops.OpOverload):
            op_name = self.op.__qualname__
        elif hasattr(self.op, "__module__") and hasattr(self.op, "__name__"):
            op_name = f"{self.op.__module__}.{self.op.__name__}"
        else:
            op_name = str(self.op)

        base_str = f"{op_name}({args_str}{kwargs_str})"

        if self.output_str:
            base_str += self.output_str
        if self.log:
            base_str += f" # {self.log}"
        return base_str

    def __iter__(self):
        # for BC; tuple(self) returns (op, args, kwargs, call_depth)
        if self.args_str is not None:
            yield from [self.op, self.args_str, self.kwargs_str, self.call_depth]
        else:
            yield from [self.op, self.args, self.kwargs, self.call_depth]
class _RedistributeCall(_DebugCall):
    """Redistribute call from DTensor dispatch"""

    def __init__(
        self,
        arg,
        src_placement,
        dst_placement,
        transform_info_str,
        call_depth,
        stack=False,
    ) -> None:
        super().__init__(call_depth, stack=stack)
        self.arg = arg
        self.src_placement = src_placement
        self.dst_placement = dst_placement
        self.transform_info_str = transform_info_str

        self.arg_str: str | None = None

    def stringify_args(
        self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
    ) -> None:
        self.arg_str = f"{_arg_to_str(self.arg, attributes, tensor_memo)}"
        del self.arg

    def render(self, attributes: list[str]) -> str:
        if self.arg_str is not None:
            arg_str = self.arg_str
        else:
            arg_str = f"{_arg_to_str(self.arg, attributes)}"

        if self.transform_info_str is not None:  # prioritize over src/dst placements
            placement_str = f"trace: {self.transform_info_str}"
        else:
            src_placement_str = _arg_to_str(self.src_placement, attributes)
            dst_placement_str = _arg_to_str(self.dst_placement, attributes)
            placement_str = f"{src_placement_str} -> {dst_placement_str}"

        base_str = f"{REDISTRIBUTE_FUNC}({arg_str}, {placement_str})"
        if self.output_str:
            base_str += self.output_str
        return base_str

    def __iter__(self):
        # for BC; tuple(self) returns (op, placement info, kwargs, call_depth)
        if self.arg_str is not None:
            arg = self.arg_str
        else:
            arg = self.arg

        yield REDISTRIBUTE_FUNC
        if self.transform_info_str:
            yield [arg, self.transform_info_str]
        else:
            yield [arg, self.src_placement, self.dst_placement]
        yield {}
        yield self.call_depth
class _NNModuleCall(_DebugCall):
    """Designates entering an nn.Module's forward method"""

    def __init__(self, module_name: str, call_depth: int, stack: bool = False) -> None:
        super().__init__(call_depth, stack=stack)
        self.module_name = module_name

    def stringify_args(
        self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
    ) -> None:
        pass  # nothing to stringify

    def render(self, attributes: list[str]) -> str:
        return f"[nn.Mod] {self.module_name}"

    def __iter__(self):
        yield from [
            f"[nn.Mod] {self.module_name}",
            (),
            {},
            self.call_depth,
        ]
class _TritonKernelCall(_DebugCall):
    """Triton kernel call from Inductor"""

    def __init__(
        self,
        kernel_name: str,
        kwargs: dict[str, Any],
        call_depth: int,
    ):
        super().__init__(call_depth)
        self.kernel_name = kernel_name
        self.kwargs = kwargs
        self.kwargs_str: str | None = None

        self.pre_hashes: dict[str, Any] | None = None
        self.post_hashes: dict[str, Any] | None = None

    def stringify_args(
        self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
    ) -> None:
        # Optionally hash kernel inputs before launch
        global _TRITON_INPUT_HASH_FN
        if hash_fn := _TRITON_INPUT_HASH_FN:
            self.pre_hashes = {
                k: hash_fn(v)
                for k, v in self.kwargs.items()
                if isinstance(v, torch.Tensor)
            }

        if self.kwargs:
            self.kwargs_str = ", ".join(
                f"{k}={_arg_to_str(v, attributes, tensor_memo)}"
                for k, v in self.kwargs.items()
            )
        else:
            self.kwargs_str = ""

    def render(self, attributes: list[str]) -> str:
        base_str = f"[triton] {self.kernel_name}({self.kwargs_str})"
        if self.pre_hashes:
            pre_hashes_str = ", ".join(f"{k}: {v}" for k, v in self.pre_hashes.items())
            pre_hashes_str = (
                "\n  "
                + "  " * self.call_depth
                + f"# pre-kernel hashes: {{{pre_hashes_str}}}"
            )
        else:
            pre_hashes_str = ""
        if self.post_hashes:
            post_hashes_str = ", ".join(
                f"{k}: {v}" for k, v in self.post_hashes.items()
            )
            post_hashes_str = (
                "\n  "
                + "  " * self.call_depth
                + f"# post-kernel hashes: {{{post_hashes_str}}}"
            )
        else:
            post_hashes_str = ""
        return f"{base_str}{pre_hashes_str}{post_hashes_str}\n"

    def finalize(self, device_interface: "DeviceInterface"):
        # synchronize -> hash/store kernel results
        global _RECORD_TRITON_OUTPUTS, _TRITON_OUTPUT_HASH_FN
        device_interface.synchronize(device_interface.current_device())
        if _RECORD_TRITON_OUTPUTS:
            self.record = {
                "output": {
                    k: v.clone() if isinstance(v, torch.Tensor) else v
                    for k, v in self.kwargs.items()
                }
            }
        if hash_fn := _TRITON_OUTPUT_HASH_FN:
            self.post_hashes = {
                k: hash_fn(v)
                for k, v in self.kwargs.items()
                if isinstance(v, torch.Tensor)
            }

        # don't store tensors
        del self.kwargs

    def __iter__(self):
        yield from [self.kernel_name, (), self.kwargs_str, self.call_depth]
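# Rough sketch of the expected caller-side protocol for _TritonKernelCall (the real
# integration lives in inductor-generated wrapper code, not in this file, so the
# launcher below is hypothetical): record the kernel and its kwargs before launch,
# then call finalize() afterwards so outputs are synchronized and hashed/recorded.
#
#     def _launch_with_debug_mode(kernel_name, kernel_fn, kwargs, device_interface):
#         debug_mode = get_active_debug_mode()
#         call = debug_mode.record_triton_kernel(kernel_name, kwargs) if debug_mode else None
#         kernel_fn(**kwargs)  # hypothetical launch; the real call depends on the kernel
#         if call is not None:
#             call.finalize(device_interface)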
def _run_hook(hook, *args):
    out = hook(*args)
    assert out is None or isinstance(out, dict)
    return out


def _run_dispatch_hooks(call: _DebugCall, func, types, args, kwargs, result) -> None:
    global _DISPATCH_RECORD_HOOKS, _DISPATCH_LOG_HOOKS
    if _DISPATCH_RECORD_HOOKS:
        record = {}
        for hook in _DISPATCH_RECORD_HOOKS:
            hook_out = _run_hook(hook, func, types, args, kwargs, result)
            if hook_out is not None:
                record.update(hook_out)
        if record:
            call.record = record

    if _DISPATCH_LOG_HOOKS:
        log = {}
        for hook in _DISPATCH_LOG_HOOKS:
            hook_out = _run_hook(hook, func, types, args, kwargs, result)
            if hook_out is not None:
                log.update(hook_out)
        if log:
            call.log = log


def _get_call_name(call: _DebugCall) -> str:
    """String identifying a _DebugCall (e.g. func, kernel, or module name)"""
    if isinstance(call, _OpCall):
        return _get_op_name(call.op)
    elif isinstance(call, _TritonKernelCall):
        return call.kernel_name
    elif isinstance(call, _NNModuleCall):
        return call.module_name
    elif isinstance(call, _RedistributeCall):
        return REDISTRIBUTE_FUNC
    else:
        return str(call)
class DebugMode(TorchDispatchMode):
    def __init__(
        self,
        *,
        record_torchfunction=False,
        record_faketensor=False,
        record_realtensor=True,
        record_tensor_attributes=None,
        record_nn_module=False,
        store_original_args=False,
        record_stack_trace=False,
        record_output=False,
        record_ids=False,
    ) -> None:
        super().__init__()
        import torch.distributed.tensor  # noqa: F401

        self.supports_higher_order_operators = True

        # Pushes DebugMode onto the torchfunction stack, and records __torch_function__ calls as well.
        # WARNING: currently incompatible with torch.compile due to dynamo guard failures.
        self.record_torchfunction = record_torchfunction

        # Records __torch_dispatch__ calls on FakeTensors.
        self.record_faketensor = record_faketensor

        # Records __torch_dispatch__ calls on real tensors.
        self.record_realtensor = record_realtensor

        # Optional list[str] of tensor attributes, to be annotated in the string dump.
        self.record_tensor_attributes = record_tensor_attributes or []

        # Uses ModTracker to record nn.Module entrances, as _NNModuleCall entries.
        # This flag currently has no effect on torch.compile'd regions.
        self.record_nn_module = record_nn_module

        self.module_tracker: ModTracker | None = None
        if self.record_nn_module:
            self.module_tracker_setup()

        # If True, stores call args/kwargs in logs, without immediately stringifying.
        # Defaults to False for memory concerns.
        self.store_original_args = store_original_args

        # For stack trace recording, stores call stack traces in .stack_trace.
        # For backward graph nodes, also stores the corresponding forward stack traces in .fwd_stack_trace.
        # NOTE: forward traces are only available if autograd tracebacks are being set during the forward pass,
        # e.g. via DebugMode(record_stack_trace=True), or torch.autograd.set_detect_anomaly().
        self.record_stack_trace = record_stack_trace

        # Records call outputs in logs (e.g. for __torch_dispatch__, __torch_function__, redistribute_input)
        self.record_output: bool = record_output

        # Annotates string dumps with graph-style tensor ids, e.g. op($1, $2) -> $3.
        self.record_ids: bool = record_ids

        self.reset()
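    # Illustrative sketch of the constructor flags above (shown as a comment so it
    # isn't executed on import; `model` and `x` are placeholders):
    #
    #     with DebugMode(
    #         record_nn_module=True,      # adds "[nn.Mod] <fqn>" entries
    #         record_output=True,         # appends "-> <output>" to logged calls
    #         record_ids=True,            # annotates tensors as t$0, t$1, ...
    #         record_stack_trace=True,    # stores .stack_trace / .fwd_stack_trace
    #     ) as debug_mode:
    #         loss = model(x).sum()
    #         loss.backward()
    #     print(debug_mode.debug_string())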
    def reset(self) -> None:
        self.operators = []
        self.call_depth = 0
        self._tensor_memo = TensorIdTracker()
        self._output_info: dict[int, object] = {}

    def _track_op_output(self, op_index, result) -> None:
        """Assign IDs to output tensors and store in output_info"""
        # self._track_tensor_ids(result)
        self._output_info[op_index] = result

    # Without this override, running torch.compile under DebugMode forces
    # torch.compile to always use the "eager" backend.
    # With this override, DebugMode does not take effect inside torch.compile'd regions.
    @classmethod
    def ignore_compile_internals(cls) -> bool:
        return True
    def _record_call(self, call) -> None:
        global _IN_INDUCTOR_BENCHMARK
        if _IN_INDUCTOR_BENCHMARK:
            return

        if str(call).startswith("profiler::_record_function"):
            return

        if not self.store_original_args:
            call.stringify_args(
                self.record_tensor_attributes,
                self._tensor_memo if self.record_ids else None,
            )
        self.operators.append(call)

    def _record_call_output(self, call, output) -> None:
        if not self.record_output:
            return
        call.stringify_output(
            output,
            self.record_tensor_attributes,
            self._tensor_memo if self.record_ids else None,
        )

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        call = _OpCall(
            func, args, kwargs, self.call_depth, stack=self.record_stack_trace
        )
        self._record_call(call)

        try:
            self.call_depth += 1
            result = func(*args, **kwargs)
            self._record_call_output(call, result)
            return result
        finally:
            self.call_depth -= 1
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # Record the operation with its call depth
        call = None
        if torch.distributed.tensor.DTensor in types:
            call = _OpCall(
                func, args, kwargs, self.call_depth, stack=self.record_stack_trace
            )
            self._record_call(call)
            return NotImplemented
        elif FakeTensor in types or isinstance(
            _get_current_dispatch_mode(), FakeTensorMode
        ):
            if self.record_faketensor:
                if func != torch.ops.prim.device.default:
                    call = _OpCall(
                        func,
                        args,
                        kwargs,
                        self.call_depth + 1,
                        stack=self.record_stack_trace,
                    )
                    self._record_call(call)
        elif len(types) == 0:
            if self.record_realtensor:
                call = _OpCall(
                    func,
                    args,
                    kwargs,
                    self.call_depth + 1,
                    stack=self.record_stack_trace,
                )
                self._record_call(call)

        result = func(*args, **kwargs)
        if call:
            self._record_call_output(call, result)
            _run_dispatch_hooks(call, func, types, args, kwargs, result)

        return result
    def __enter__(self):
        self.reset()

        if self.record_torchfunction:
            torch._C._push_on_torch_function_stack(self)

        super().__enter__()
        if self.record_nn_module:
            self.module_tracker.__enter__()  # type: ignore[attribute, union-attr]

        if self.record_stack_trace:
            self.anomaly_for_traces = torch.autograd.set_detect_anomaly(
                True, check_nan=False
            )
            self.anomaly_for_traces.__enter__()
        return self

    # pyrefly: ignore [bad-override]
    def __exit__(self, *args):
        super().__exit__(*args)
        if self.record_nn_module:
            self.module_tracker.__exit__()  # type: ignore[attribute, union-attr]
        if self.record_torchfunction:
            torch._C._pop_torch_function_stack()
        if self.record_stack_trace:
            self.anomaly_for_traces.__exit__(*args)

    def module_tracker_setup(self) -> None:
        from torch.distributed._tools.mod_tracker import ModTracker

        self.module_tracker = ModTracker()

        # module pre-fw hook: record module call
        def pre_fw_hook(module, input) -> None:
            fqn = self.module_tracker._get_mod_name(module)  # type: ignore[attribute, union-attr]
            self.operators.append(_NNModuleCall(fqn, self.call_depth + 1))
            self.call_depth += 1

        # module post-fw hook: decrement call depth
        def post_fw_hook(module, input, output) -> None:
            self.call_depth -= 1

        self.module_tracker.register_user_hooks(pre_fw_hook, post_fw_hook)
    @contextlib.contextmanager
    def record_redistribute_calls(
        self,
        arg,
        src_placement,
        dst_placement,
        transform_info_str: str | None = None,
    ):
        try:
            self._record_call(
                _RedistributeCall(
                    arg,
                    src_placement=src_placement,
                    dst_placement=dst_placement,
                    transform_info_str=transform_info_str,
                    call_depth=self.call_depth + 1,
                    stack=self.record_stack_trace,
                )
            )
            self.call_depth += 1
            yield
        finally:
            self.call_depth -= 1

    def record_triton_kernel(
        self, kernel_name: str, kwargs: dict[str, Any]
    ) -> _TritonKernelCall:
        call = _TritonKernelCall(kernel_name, kwargs, self.call_depth + 1)
        call.stringify_args(self.record_tensor_attributes)
        self.operators.append(call)
        return call

    def debug_string(self) -> str:
        with torch._C.DisableTorchFunction():
            result = ""
            result += "\n".join(
                "  " + "  " * op.call_depth + op.render(self.record_tensor_attributes)
                for op in self.operators
            )
        return result
    @staticmethod
    @contextlib.contextmanager
    def dispatch_hooks(
        record_hook: Callable | None = None,
        log_hook: Callable | None = None,
    ):
        """
        Allows installing post-hooks on intercepted __torch_dispatch__ calls;
        hook signatures are expected to be (func, types, args, kwargs, result),
        i.e. the __torch_dispatch__ args plus the return value.

        Logging hook outputs are stored in call.log and annotate calls in debug_string(),
        while recording hook outputs are just stored in call.record.
        For now, hooks are expected to return dictionaries.
        """
        global _DISPATCH_RECORD_HOOKS, _DISPATCH_LOG_HOOKS

        if record_hook:
            _DISPATCH_RECORD_HOOKS.append(record_hook)
        if log_hook:
            _DISPATCH_LOG_HOOKS.append(log_hook)
        try:
            yield
        finally:
            if record_hook:
                _DISPATCH_RECORD_HOOKS.pop()
            if log_hook:
                _DISPATCH_LOG_HOOKS.pop()
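    # Illustrative sketch of a custom logging hook (shown as a comment so it isn't
    # executed on import). The hook receives the __torch_dispatch__ args plus the
    # result; any dict it returns is merged into call.log and rendered by
    # debug_string(). `model` and `x` are placeholders.
    #
    #     def _nan_flag_hook(func, types, args, kwargs, result):
    #         return {
    #             "nan": tree_map(
    #                 lambda t: bool(t.isnan().any())
    #                 if isinstance(t, torch.Tensor) and t.is_floating_point()
    #                 else None,
    #                 result,
    #             )
    #         }
    #
    #     with DebugMode() as debug_mode, DebugMode.dispatch_hooks(log_hook=_nan_flag_hook):
    #         model(x)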
    @staticmethod
    @contextlib.contextmanager
    def record_outputs():
        """
        Hook for storing cloned output tensors in .record["output"].
        """

        def dispatch_hook(func, types, args, kwargs, result):
            out = tree_map(
                lambda x: x.clone() if isinstance(x, torch.Tensor) else x, result
            )
            return {"output": out}

        global _RECORD_TRITON_OUTPUTS
        try:
            _old_record_triton = _RECORD_TRITON_OUTPUTS
            _RECORD_TRITON_OUTPUTS = True
            with DebugMode.dispatch_hooks(record_hook=dispatch_hook):
                yield
        finally:
            _RECORD_TRITON_OUTPUTS = _old_record_triton
    @staticmethod
    @contextlib.contextmanager
    def log_tensor_hashes(hash_fn: Callable | None = None, hash_inputs: bool = False):
        """
        Installs hook for tensor hash logging.

        hash_fn: optional function for custom hashing
        hash_inputs: if True, also hashes tensors in (args, kwargs), storing them in "input_hash".
        NOTE: this is currently a post-hook, so e.g. inplace ops will log the "output" hashes.
        """
        if hash_fn is None:
            hash_fn = functools.partial(default_hash_fn, use_scalar=True)

        def _tree_hash(obj):
            return tree_map(
                lambda x: hash_fn(x) if isinstance(x, torch.Tensor) else None, obj
            )

        def _dispatch_hash_hook(func, types, args, kwargs, result):
            if "empty" in str(func) or "profiler" in str(func):
                return None

            out = {}
            out["hash"] = _tree_hash(result)
            if hash_inputs:
                out["input_hash"] = _tree_hash((args, kwargs))

            if tree_all(lambda x: x is None, out.values()):
                return None
            return out

        global _TRITON_INPUT_HASH_FN, _TRITON_OUTPUT_HASH_FN
        try:
            if hash_inputs:
                _old_input_hfn = _TRITON_INPUT_HASH_FN
                _TRITON_INPUT_HASH_FN = hash_fn
            _old_output_hfn = _TRITON_OUTPUT_HASH_FN
            _TRITON_OUTPUT_HASH_FN = hash_fn
            with DebugMode.dispatch_hooks(log_hook=_dispatch_hash_hook):
                yield
        finally:
            if hash_inputs:
                _TRITON_INPUT_HASH_FN = _old_input_hfn  # type: ignore[assignment]
            _TRITON_OUTPUT_HASH_FN = _old_output_hfn
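    # Illustrative sketch of swapping in a custom hash function (shown as a comment
    # so it isn't executed on import). The default hash is norm-based, i.e. a
    # closeness proxy; for stricter bitwise comparison a different reduction can be
    # passed in. `torch.hash_tensor` is an assumption here (see
    # https://github.com/pytorch/pytorch/pull/154149); any Callable mapping a tensor
    # to a comparable scalar works.
    #
    #     def _bitwise_hash(t):
    #         return torch.hash_tensor(t.detach()).item()  # assumed API
    #
    #     with DebugMode() as debug_mode, DebugMode.log_tensor_hashes(hash_fn=_bitwise_hash):
    #         compiled_model(*inputs)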
    @staticmethod
    @contextlib.contextmanager
    def _benchmarking_inductor():
        """
        Context manager for disabling logging during inductor benchmarking,
        so logs don't contain all kernels launched from autotuning.
        """
        global _IN_INDUCTOR_BENCHMARK
        try:
            _IN_INDUCTOR_BENCHMARK = True
            yield
        finally:
            _IN_INDUCTOR_BENCHMARK = False

    @property
    def logs(self):
        return list(self.operators)
    @staticmethod
    def check_hash_mismatches(
        logs1: list, logs2: list, compare_inputs: bool = False
    ) -> list[dict]:
        """
        Compares tensor hashes between two DebugMode runs, for checking run-to-run numerical divergence.

        This first validates that the two log sequences have identical structure (same operations,
        input shapes/dtypes, etc.), then compares tensor hash values, and returns a list of call
        outputs where mismatches were found.
        Expects input logs to have been collected under log_tensor_hashes, and looks for hashes in
        .log["hash"] & .log["input_hash"] (or .post_hashes & .pre_hashes for triton kernels).

        note: skips checking log pairs where hashes aren't present, but will raise if they are
        present in one log and not the other.

        Args:
            logs1: logs from the first DebugMode run (from debug_mode.logs)
            logs2: logs from the second DebugMode run
            compare_inputs: If True, also compares input tensor hashes (default: only output checking)

        Returns:
            List of dictionaries describing hash mismatches. Each dict contains:
                - call_type: "torch op" or "triton kernel"
                - call: Operator/kernel name
                - arg_name: For triton kernels, the argument name; None for torch ops
                - pytree_path: For torch ops, the pytree path to the differing tensor; None for kernels
                - hash1: Hash value from the first run
                - hash2: Hash value from the second run
                - rel_diff: Relative difference between hash values
                - is_input_hash: True if this is an input hash, False for an output hash

        Raises:
            ValueError: If logs have different lengths, call types, operator names, or call depths

        Usage::

            # Run model first time
            with DebugMode() as debug_mode, DebugMode.log_tensor_hashes():
                model(x)
            logs1 = debug_mode.logs

            # Run again, in exactly the same way
            with DebugMode() as debug_mode, DebugMode.log_tensor_hashes():
                model(x)
            logs2 = debug_mode.logs

            mismatches = DebugMode.check_hash_mismatches(logs1, logs2)
            for m in mismatches:
                print(f"{m['call']}: hash diff {m['rel_diff']:.2e}")
        """
        if len(logs1) != len(logs2):
            raise ValueError(f"Log lengths don't match: {len(logs1)} vs {len(logs2)}")

        difference_info = []
        for i, (log1, log2) in enumerate(zip(logs1, logs2)):
            # check call type
            call1_type = type(log1).__name__
            call2_type = type(log2).__name__
            if call1_type != call2_type:
                raise ValueError(
                    f"Call types don't match at index {i}: {call1_type} vs {call2_type}"
                )
            call_type = call1_type

            # check call name
            op1_name, op2_name = _get_call_name(log1), _get_call_name(log2)
            if op1_name != op2_name:
                raise ValueError(
                    f"Operators don't match at index {i}: {call_type}[{op1_name}] vs {call_type}[{op2_name}]"
                )
            op_name = op1_name

            # check call depth
            if log1.call_depth != log2.call_depth:
                raise ValueError(
                    f"Call depths for {call_type}[{op_name}] don't match at index {i}: {log1.call_depth} vs {log2.call_depth}"
                )

            # Redistribute: call args should be the same
            if isinstance(log1, _RedistributeCall):
                if tuple(log1) != tuple(log2):
                    raise ValueError(
                        f"Redistribute calls don't match at index {i}: {log1} vs {log2}"
                    )

            # Triton kernel: same arg names, arg types
            elif isinstance(log1, _TritonKernelCall):
                if log1.kwargs_str != log2.kwargs_str:
                    raise ValueError(
                        f"Triton kernel call args don't match for {log1.kernel_name} at index {i}:"
                        f"\n\nlog1: {log1.kwargs_str}\n\nlog2: {log2.kwargs_str}"
                    )

                def compare_triton_hashes(hashes1, hashes2, is_input):
                    assert set(hashes1.keys()) == set(hashes2.keys())  # type: ignore[union-attr]
                    for key in hashes1.keys():
                        if hashes1[key] != hashes2[key]:
                            difference_info.append(
                                {
                                    "call_type": "triton kernel",
                                    "call": op_name,
                                    "arg_name": key,
                                    "pytree_path": None,
                                    "hash1": hashes1[key],
                                    "hash2": hashes2[key],
                                    "rel_diff": _compute_rel_diff(
                                        hashes1[key], hashes2[key]
                                    ),
                                    "is_input_hash": is_input,
                                }
                            )

                # check output hashes
                has_post_1, has_post_2 = (
                    log1.post_hashes is not None,
                    log2.post_hashes is not None,
                )
                if has_post_1 != has_post_2:
                    raise ValueError(
                        f"Triton kernel post-hash presence inconsistent for {log1.kernel_name} "
                        f"at index {i}: log1 has post_hashes={has_post_1}, log2 has post_hashes={has_post_2}"
                    )

                if has_post_1:
                    compare_triton_hashes(
                        log1.post_hashes, log2.post_hashes, is_input=False
                    )

                # maybe check input hashes
                if compare_inputs:
                    has_pre_1, has_pre_2 = (
                        log1.pre_hashes is not None,
                        log2.pre_hashes is not None,
                    )
                    if has_pre_1 != has_pre_2:
                        raise ValueError(
                            f"Triton kernel pre-hash presence inconsistent for {log1.kernel_name} "
                            f"at index {i}: log1 has pre_hashes={has_pre_1}, log2 has pre_hashes={has_pre_2}"
                        )

                    if has_pre_1:
                        compare_triton_hashes(
                            log1.pre_hashes, log2.pre_hashes, is_input=True
                        )

            # regular log calls
            elif isinstance(log1, _OpCall):

                def compare_op_hashes(hashes1, hashes2, is_input):
                    def _helper(keypath, hash1, hash2):
                        if hash1 != hash2:
                            difference_info.append(
                                {
                                    "call_type": "torch op",
                                    "call": op_name,
                                    "arg_name": None,
                                    "pytree_path": keystr(keypath),
                                    "hash1": hash1,
                                    "hash2": hash2,
                                    "rel_diff": _compute_rel_diff(hash1, hash2),
                                    "is_input_hash": is_input,
                                }
                            )

                    tree_map_with_path(_helper, hashes1, hashes2)

                # check output hashes
                has_hash1 = log1.log is not None and "hash" in log1.log
                has_hash2 = log2.log is not None and "hash" in log2.log
                if has_hash1 != has_hash2:
                    raise ValueError(
                        f"Output hash presence inconsistent for {call_type}[{op_name}] "
                        f"at index {i}: log1 has hash={has_hash1}, log2 has hash={has_hash2}"
                    )

                if has_hash1:
                    compare_op_hashes(
                        log1.log["hash"],  # type: ignore[union-attr]
                        log2.log["hash"],
                        is_input=False,
                    )

                # maybe check input hashes
                if compare_inputs:
                    has_hash1 = log1.log is not None and "input_hash" in log1.log
                    has_hash2 = log2.log is not None and "input_hash" in log2.log
                    if has_hash1 != has_hash2:
                        raise ValueError(
                            f"Input hash presence inconsistent for {call_type}[{op_name}] "
                            f"at index {i}: log1 has input_hash={has_hash1}, log2 has input_hash={has_hash2}"
                        )

                    if has_hash1:
                        compare_op_hashes(
                            log1.log["input_hash"],  # type: ignore[union-attr]
                            log2.log["input_hash"],
                            is_input=True,
                        )

        return difference_info
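# Illustrative sketch (shown as a comment so it isn't executed on import): with the
# default norm-based hash_fn, rel_diff tracks how far apart two runs are, so sorting
# mismatches by rel_diff surfaces the largest divergences first. `logs1` / `logs2`
# are the hash-annotated logs from two runs, as in the docstring above.
#
#     mismatches = DebugMode.check_hash_mismatches(logs1, logs2)
#     for m in sorted(mismatches, key=lambda m: m["rel_diff"], reverse=True):
#         where = m["arg_name"] or m["pytree_path"]
#         print(f"{m['call_type']} {m['call']} ({where}): rel_diff={m['rel_diff']:.2e}")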
def get_active_debug_mode() -> DebugMode | None:
    debug_mode = None
    for mode in _get_current_dispatch_mode_stack():
        if isinstance(mode, DebugMode):
            debug_mode = mode
            break
    return debug_mode
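# Illustrative sketch of how library code can opt into DebugMode logging for calls
# that dispatch modes can't intercept (shown as a comment so it isn't executed on
# import); this mirrors how DTensor wraps its redistributions via
# record_redistribute_calls. The function and its arguments are hypothetical.
#
#     def _my_redistribute(tensor, src_placement, dst_placement):
#         debug_mode = get_active_debug_mode()
#         ctx = (
#             debug_mode.record_redistribute_calls(tensor, src_placement, dst_placement)
#             if debug_mode
#             else contextlib.nullcontext()
#         )
#         with ctx:
#             ...  # perform the actual communication / redistribution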