Show mismatching values when DeepCompile test fails (#7618)

This PR improves the error message shown when a DeepCompile test fails.

DeepCompile tests occasionally fail
([example](https://github.com/deepspeedai/DeepSpeed/actions/runs/18160078309/job/51688736712?pr=7604))
because of mismatching loss values.
To rule out a synchronization bug that produces `nan` loss values, this
PR makes the failing assertion report the mismatching values. We can
consider increasing the tolerances once we confirm the mismatch is
reasonable.
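
The core idea is to wrap the tensor comparison so that a failed assertion carries the actual values, and to reduce the per-rank result so that all ranks succeed or fail together. Below is a minimal single-process sketch of the reporting part only; `report_mismatch` and the sample loss values are illustrative stand-ins, not part of this change (the real helper added here, `allclose_on_all_ranks`, appears in the diff below).

```python
import torch


def report_mismatch(actual, expected, message="Tensors are not close.", **kwargs):
    """Illustrative wrapper: re-raise with the mismatching values in the message."""
    try:
        torch.testing.assert_close(actual, expected, **kwargs)
    except AssertionError:
        # Include the tensors themselves so the CI log shows how far apart they are.
        raise AssertionError(f"{message} {actual=}, {expected=} {kwargs=}")


# Hypothetical loss values that differ by more than the tolerance.
baseline_loss = torch.tensor(0.9273)
target_loss = torch.tensor(1.0412)
try:
    report_mismatch(baseline_loss, target_loss, rtol=5e-2, atol=1e-3)
except AssertionError as e:
    print(e)  # prints the actual tensors instead of a bare assertion failure
```

In the actual helper, the per-rank outcome is additionally passed through `reduce_boolean_flags` (an all-gather of a boolean flag), so that when any rank detects a mismatch the assertion is raised on every rank.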

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Authored by Masahiro Tanaka, committed via GitHub on 2025-10-03 02:23:13 -07:00
commit 82a9db7eba (parent 2a76988958)
3 changed files with 34 additions and 7 deletions


@@ -21,6 +21,8 @@ import deepspeed
 from deepspeed.accelerator import get_accelerator
 import deepspeed.comm as dist
+from .util import torch_assert_close
+
 import pytest
 from _pytest.outcomes import Skipped
 from _pytest.fixtures import FixtureLookupError, FixtureFunctionMarker
@@ -562,6 +564,8 @@ def enable_determinism(seed: int):
 def reduce_boolean_flags(flag: bool, op=all) -> bool:
     if not dist.is_initialized():
         return flag
     device = get_accelerator().current_device()
     tensor_flag = torch.tensor(1 if flag else 0, dtype=torch.int, device=device)
     world_size = dist.get_world_size()
@@ -569,3 +573,24 @@ def reduce_boolean_flags(flag: bool, op=all) -> bool:
     dist.all_gather_into_tensor(tensor_flag_buf, tensor_flag)
     list_flags = [bool(f) for f in tensor_flag_buf.tolist()]
     return op(list_flags)
+
+
+def allclose_on_all_ranks(actual, expected, assert_message=None, **kwargs) -> None:
+    """
+    Compare two tensors across all ranks.
+    We want to make sure that all ranks succeed or fail together.
+    """
+    allclose_local = False
+    allclose_global = False
+    mismatch_msg = ""
+    try:
+        torch_assert_close(actual, expected, **kwargs)
+        allclose_local = True
+        allclose_global = reduce_boolean_flags(allclose_local, all)
+    except AssertionError:
+        allclose_global = reduce_boolean_flags(allclose_local, all)
+        mismatch_msg = f"Tensors are not close: {actual=}, {expected=} {kwargs=}"
+
+    if not allclose_global:
+        message = "Tensors are not close on all ranks." if assert_message is None else assert_message
+        raise AssertionError(f"{message} {mismatch_msg}")


@@ -94,15 +94,15 @@ class no_child_process_in_deepspeed_io:
         deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = self.old_method


-def torch_assert_equal(actual, expected, **kwargs):
+def torch_assert_equal(actual, expected, **kwargs) -> None:
     """
     Compare two tensors or non-tensor numbers for their equality.
     Add msg=blah to add an additional comment to when assert fails.
     """
-    return torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0, **kwargs)
+    torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0, **kwargs)


-def torch_assert_close(actual, expected, **kwargs):
+def torch_assert_close(actual, expected, **kwargs) -> None:
     """
     Compare two tensors or non-tensor numbers for their closeness.
@@ -113,7 +113,7 @@ def torch_assert_close(actual, expected, **kwargs):
     The check doesn't assert when `|a - b| <= (atol + rtol * |b|)`
     """
-    return torch.testing.assert_close(actual, expected, **kwargs)
+    torch.testing.assert_close(actual, expected, **kwargs)


 def torch_assert_dicts_of_tensors_equal(actual, expected, **kwargs):


@@ -12,12 +12,14 @@ from deepspeed.accelerator import get_accelerator
 from deepspeed.runtime.zero import GatheredParameters
 from unit.simple_model import SimpleModel
-from unit.common import enable_determinism
+from unit.common import enable_determinism, allclose_on_all_ranks


 @enable_determinism(123)
 def compare_loss(self, config, dtype, iteration=5, hidden_dim_override=None):
     hidden_dim = hidden_dim_override if hidden_dim_override is not None else 10

     # the default tolerances of torch.testing.assert_close are too small
     RTOL = 5e-1
     ATOL = 1e-2
@@ -56,7 +58,7 @@ def compare_loss(self, config, dtype, iteration=5, hidden_dim_override=None):
         baseline_loss = baseline_engine(x, y)
         target_loss = target_engine(x, y)

-        assert torch.allclose(baseline_loss, target_loss, rtol=RTOL, atol=ATOL)
+        allclose_on_all_ranks(baseline_loss, target_loss, "Loss values are not close.", rtol=RTOL, atol=ATOL)

         baseline_engine.backward(baseline_loss)
         target_engine.backward(target_loss)
@@ -66,7 +68,7 @@ def compare_loss(self, config, dtype, iteration=5, hidden_dim_override=None):
     with GatheredParameters(target_engine.parameters()):
         for p1, p2 in zip(baseline_engine.parameters(), target_engine.parameters()):
-            assert torch.allclose(p1.to(dtype), p2, rtol=RTOL, atol=ATOL)
+            allclose_on_all_ranks(p1, p2, "Parameters are not equal.", rtol=RTOL, atol=ATOL)

     baseline_engine.destroy()
     target_engine.destroy()