mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support fp8 dtypes in assert_close (#150002)
Fixes #135998. Adds support for fp8 dtypes. These are compared bitwise, without atol and rtol. The implementation reuses the same comparison functions, just with atol and rtol forced to zero. The error message differs from the default case: it reports only the first mismatch, which avoids triggering the error from #135998. Test Plan: a new unit test covers the new code paths. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150002 Approved by: https://github.com/cyyever, https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
48761e9737
commit
a40e876b08
@ -943,6 +943,28 @@ class TestAssertCloseErrorMessage(TestCase):
|
||||
with self.assertRaisesRegex(AssertionError, re.escape("Greatest absolute difference: 2 at index (1, 0)")):
|
||||
fn()
|
||||
|
||||
def test_small_float_dtype(self):
    """Float8 dtypes are compared bitwise; a mismatch reports only the index
    of the first differing element instead of abs/rel difference statistics.
    """
    float8_dtypes = (
        torch.float8_e4m3fn,
        torch.float8_e4m3fnuz,
        torch.float8_e5m2,
        torch.float8_e5m2fnuz,
        torch.float8_e8m0fnu,
    )
    for dtype in float8_dtypes:
        w = torch.tensor([3.14, 1.0], dtype=dtype)
        x = torch.tensor([1.0, 3.14], dtype=dtype)
        y = torch.tensor([3.14, 3.14], dtype=dtype)
        z = torch.tensor([1.0, 3.14], dtype=dtype)

        # Mismatching pairs: the error message must name the first bad index.
        for actual, expected, idx in [(x, y, 0), (w, y, 1)]:
            pattern = re.escape(f"The first mismatched element is at index {idx}")
            for fn in assert_close_with_inputs(actual, expected):
                with self.assertRaisesRegex(AssertionError, pattern):
                    fn()

        # Bitwise-identical tensors must pass cleanly.
        for fn in assert_close_with_inputs(x, z):
            fn()
|
||||
|
||||
def test_abs_diff_scalar(self):
|
||||
actual = 3
|
||||
expected = 5
|
||||
|
@ -128,6 +128,37 @@ def get_tolerances(
|
||||
return default_tolerances(*inputs)
|
||||
|
||||
|
||||
def _make_bitwise_mismatch_msg(
|
||||
*,
|
||||
default_identifier: str,
|
||||
identifier: Optional[Union[str, Callable[[str], str]]] = None,
|
||||
extra: Optional[str] = None,
|
||||
first_mismatch_idx: Optional[int] = None,
|
||||
):
|
||||
"""Makes a mismatch error message for bitwise values.
|
||||
|
||||
Args:
|
||||
default_identifier (str): Default description of the compared values, e.g. "Tensor-likes".
|
||||
identifier (Optional[Union[str, Callable[[str], str]]]): Optional identifier that overrides
|
||||
``default_identifier``. Can be passed as callable in which case it will be called with
|
||||
``default_identifier`` to create the description at runtime.
|
||||
extra (Optional[str]): Extra information to be placed after the message header and the mismatch statistics.
|
||||
first_mismatch_idx (Optional[int]): the index of the first mismatch.
|
||||
"""
|
||||
if identifier is None:
|
||||
identifier = default_identifier
|
||||
elif callable(identifier):
|
||||
identifier = identifier(default_identifier)
|
||||
|
||||
msg = f"{identifier} are not 'equal'!\n\n"
|
||||
|
||||
if extra:
|
||||
msg += f"{extra.strip()}\n"
|
||||
if first_mismatch_idx is not None:
|
||||
msg += f"The first mismatched element is at index {first_mismatch_idx}.\n"
|
||||
return msg.strip()
|
||||
|
||||
|
||||
def _make_mismatch_msg(
|
||||
*,
|
||||
default_identifier: str,
|
||||
@ -263,6 +294,15 @@ def make_tensor_mismatch_msg(
|
||||
f"Mismatched elements: {total_mismatches} / {number_of_elements} "
|
||||
f"({total_mismatches / number_of_elements:.1%})"
|
||||
)
|
||||
if actual.dtype.is_floating_point and actual.dtype.itemsize == 1:
|
||||
# skip checking for max_abs_diff and max_rel_diff for float8-like values
|
||||
first_mismatch_idx = torch.nonzero(~matches, as_tuple=False)[0].item()
|
||||
return _make_bitwise_mismatch_msg(
|
||||
default_identifier="Tensor-likes",
|
||||
identifier=identifier,
|
||||
extra=extra,
|
||||
first_mismatch_idx=int(first_mismatch_idx),
|
||||
)
|
||||
|
||||
actual_flat = actual.flatten()
|
||||
expected_flat = expected.flatten()
|
||||
@ -824,6 +864,34 @@ class TensorLikePair(Pair):
|
||||
elif actual.layout == torch.jagged:
|
||||
actual, expected = actual.values(), expected.values()
|
||||
compare_fn = self._compare_regular_values_close
|
||||
elif actual.dtype.is_floating_point and actual.dtype.itemsize == 1:
|
||||
|
||||
def bitwise_comp(
|
||||
actual: torch.Tensor,
|
||||
expected: torch.Tensor,
|
||||
*,
|
||||
rtol: float,
|
||||
atol: float,
|
||||
equal_nan: bool,
|
||||
identifier: Optional[Union[str, Callable[[str], str]]] = None,
|
||||
) -> None:
|
||||
if rtol != 0.0 or atol != 0.0:
|
||||
raise ErrorMeta(
|
||||
AssertionError,
|
||||
f"Rtol={rtol} and atol={atol} are not supported for bitwise comparison of low \
|
||||
dimensional floats. Please use rtol=0.0 and atol=0.0",
|
||||
)
|
||||
|
||||
return self._compare_regular_values_close(
|
||||
actual,
|
||||
expected,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
equal_nan=equal_nan,
|
||||
identifier=identifier,
|
||||
)
|
||||
|
||||
compare_fn = bitwise_comp
|
||||
else:
|
||||
compare_fn = self._compare_regular_values_close
|
||||
|
||||
|
Reference in New Issue
Block a user