import sys
from collections import namedtuple
from typing import Any, Optional, Tuple, Type

import torch

from ._core import _unravel_index

__all__ = ["assert_tensors_equal", "assert_tensors_close"]


# The UsageError should be raised in case the test function is not used correctly. With this, the user is able to
# differentiate between a test failure (there is a bug in the tested code) and a test error (there is a bug in the
# test). If pytest is the test runner, we use its built-in UsageError instead of our custom one.

try:
    # The module 'pytest' will be imported if the 'pytest' runner is used. This will only give false positives in
    # case a previously imported module already directly or indirectly imported 'pytest', but the test is run by
    # another runner such as 'unittest'.
    # 'mypy' is not able to handle this within a type annotation
    # (see https://mypy.readthedocs.io/en/latest/common_issues.html#variables-vs-type-aliases for details). In case
    # 'UsageError' is used in an annotation, add a 'type: ignore[valid-type]' comment.
    UsageError: Type[Exception] = sys.modules["pytest"].UsageError  # type: ignore[attr-defined]
except (KeyError, AttributeError):

    class UsageError(Exception):  # type: ignore[no-redef]
        pass
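

# Illustrative sketch (not part of the module): callers can treat UsageError uniformly
# regardless of the active test runner, e.g.
#
#   try:
#       assert_tensors_close(a, b, rtol=1e-5, atol=None)
#   except UsageError as error:  # raised because only one tolerance was specified
#       print(error)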


# This is copy-pasted from torch.testing._internal.common_utils.TestCase.dtype_precisions. With this we avoid a
# dependency on torch.testing._internal at import. See
# https://github.com/pytorch/pytorch/pull/54769#issuecomment-813174256 for details.
# {dtype: (rtol, atol)}
_DTYPE_PRECISIONS = {
    torch.float16: (0.001, 1e-5),
    torch.bfloat16: (0.016, 1e-5),
    torch.float32: (1.3e-6, 1e-5),
    torch.float64: (1e-7, 1e-7),
    torch.complex32: (0.001, 1e-5),
    torch.complex64: (1.3e-6, 1e-5),
    torch.complex128: (1e-7, 1e-7),
}


def _get_default_rtol_and_atol(a: torch.Tensor, b: torch.Tensor) -> Tuple[float, float]:
    dtype = a.dtype if a.dtype == b.dtype else torch.promote_types(a.dtype, b.dtype)
    return _DTYPE_PRECISIONS.get(dtype, (0.0, 0.0))
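

# Illustrative sketch (not part of the module): for mixed dtypes, the promoted dtype's
# entry in _DTYPE_PRECISIONS supplies the defaults, e.g.
#
#   >>> _get_default_rtol_and_atol(torch.ones(3, dtype=torch.float32),
#   ...                            torch.ones(3, dtype=torch.float64))
#   (1e-07, 1e-07)  # float32 and float64 promote to float64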


def _check_are_tensors(a: Any, b: Any) -> Optional[AssertionError]:
    """Checks if both inputs are tensors.

    Args:
        a (Any): First input.
        b (Any): Second input.

    Returns:
        (Optional[AssertionError]): If the check did not pass.
    """
    if not (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)):
        return AssertionError(f"Both inputs have to be tensors, but got {type(a)} and {type(b)} instead.")

    return None


def _check_supported_tensors(a: torch.Tensor, b: torch.Tensor) -> Optional[UsageError]:  # type: ignore[valid-type]
    """Checks if the tensors are supported by the current infrastructure.

    All checks are temporary and will be relaxed in the future.

    Args:
        a (torch.Tensor): First tensor.
        b (torch.Tensor): Second tensor.

    Returns:
        (Optional[UsageError]): If the check did not pass.
    """
    if any(t.dtype in (torch.complex32, torch.complex64, torch.complex128) for t in (a, b)):
        return UsageError("Comparison for complex tensors is not supported yet.")
    if any(t.is_quantized for t in (a, b)):
        return UsageError("Comparison for quantized tensors is not supported yet.")
    if any(t.is_sparse for t in (a, b)):
        return UsageError("Comparison for sparse tensors is not supported yet.")

    return None


def _check_attributes_equal(
    a: torch.Tensor,
    b: torch.Tensor,
    *,
    check_device: bool = True,
    check_dtype: bool = True,
    check_stride: bool = True,
) -> Optional[AssertionError]:
    """Checks if the attributes of two tensors match.

    Always checks the :attr:`~torch.Tensor.shape`. Checks for :attr:`~torch.Tensor.device`,
    :attr:`~torch.Tensor.dtype`, and :meth:`~torch.Tensor.stride` are optional and can be disabled.

    Args:
        a (torch.Tensor): First tensor.
        b (torch.Tensor): Second tensor.
        check_device (bool): If ``True`` (default), asserts that both :attr:`a` and :attr:`b` live in the same
            :attr:`~torch.Tensor.device` memory.
        check_dtype (bool): If ``True`` (default), asserts that both :attr:`a` and :attr:`b` have the same
            :attr:`~torch.Tensor.dtype`.
        check_stride (bool): If ``True`` (default), asserts that both :attr:`a` and :attr:`b` have the same
            :meth:`~torch.Tensor.stride`.

    Returns:
        (Optional[AssertionError]): If the check did not pass.
    """
    msg_fmtstr = "The values for attribute '{}' do not match: {} != {}."

    if a.shape != b.shape:
        return AssertionError(msg_fmtstr.format("shape", a.shape, b.shape))

    if check_device and a.device != b.device:
        return AssertionError(msg_fmtstr.format("device", a.device, b.device))

    if check_dtype and a.dtype != b.dtype:
        return AssertionError(msg_fmtstr.format("dtype", a.dtype, b.dtype))

    if check_stride and a.stride() != b.stride():
        return AssertionError(msg_fmtstr.format("stride()", a.stride(), b.stride()))

    return None


def _equalize_attributes(a: torch.Tensor, b: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Equalizes some attributes of two tensors for value comparison.

    If :attr:`a` and :attr:`b`
    - do not live in the same :attr:`~torch.Tensor.device` memory, they are moved to CPU memory, and
    - do not have the same :attr:`~torch.Tensor.dtype`, they are copied to the :class:`~torch.dtype` returned by
      :func:`torch.promote_types`.

    Args:
        a (torch.Tensor): First tensor.
        b (torch.Tensor): Second tensor.

    Returns:
        (Tuple[torch.Tensor, torch.Tensor]): Equalized tensors.
    """
    if a.device != b.device:
        a = a.cpu()
        b = b.cpu()

    if a.dtype != b.dtype:
        dtype = torch.promote_types(a.dtype, b.dtype)
        a = a.to(dtype)
        b = b.to(dtype)

    return a, b
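

# Illustrative sketch (not part of the module): mismatched dtypes are promoted before
# the values are compared, e.g.
#
#   >>> a, b = _equalize_attributes(torch.ones(2, dtype=torch.int32), torch.ones(2, dtype=torch.int64))
#   >>> a.dtype, b.dtype
#   (torch.int64, torch.int64)  # torch.promote_types(torch.int32, torch.int64)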


_Trace = namedtuple("_Trace", ("total", "abs", "rel", "idx", "diff", "a", "b"))


def _trace_mismatches(a: torch.Tensor, b: torch.Tensor, mismatches: torch.Tensor) -> _Trace:
    """Traces mismatches and returns the found information.

    The returned named tuple has the following fields:
    - total (int): Total number of values in :attr:`a` and :attr:`b`.
    - abs (int): Absolute number of mismatches.
    - rel (float): Relative number of mismatches.
    - idx (Union[int, Tuple[int, ...]]): Index of greatest absolute difference.
    - diff (Union[int, float]): Greatest absolute difference.
    - a (Union[int, float]): Value of :attr:`a` at the greatest absolute difference.
    - b (Union[int, float]): Value of :attr:`b` at the greatest absolute difference.

    For ``diff``, ``a``, and ``b`` the returned type depends on the :attr:`~torch.Tensor.dtype` of :attr:`a` and
    :attr:`b`.

    Args:
        a (torch.Tensor): First tensor.
        b (torch.Tensor): Second tensor.
        mismatches (torch.Tensor): Boolean mask of the same shape as :attr:`a` and :attr:`b` that indicates the
            location of mismatches.
    """
    total = mismatches.numel()
    abs = torch.sum(mismatches).item()
    rel = abs / total

    dtype = torch.float64 if a.dtype.is_floating_point else torch.int64
    a_flat = a.flatten().to(dtype)
    b_flat = b.flatten().to(dtype)

    abs_diff_flat = torch.abs(a_flat - b_flat)
    idx_flat = torch.argmax(abs_diff_flat)

    return _Trace(
        total=total,
        abs=abs,
        rel=rel,
        idx=_unravel_index(idx_flat, a.shape),
        diff=abs_diff_flat[idx_flat].item(),
        a=a_flat[idx_flat].item(),
        b=b_flat[idx_flat].item(),
    )
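

# Illustrative sketch (not part of the module): for a = [1.0, 2.0, 3.0] and
# b = [1.0, 2.0, 5.0] the mask is [False, False, True], so the trace reports
# total=3, abs=1, rel = 1 / 3, diff=2.0, a=3.0, b=5.0, with idx pointing at the
# third element.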


def _check_values_equal(a: torch.Tensor, b: torch.Tensor) -> Optional[AssertionError]:
    """Checks if the values of two tensors are bitwise equal.

    Args:
        a (torch.Tensor): First tensor.
        b (torch.Tensor): Second tensor.

    Returns:
        (Optional[AssertionError]): If the check did not pass.
    """
    mismatches = torch.ne(a, b)
    if not torch.any(mismatches):
        return None

    trace = _trace_mismatches(a, b, mismatches)
    return AssertionError(
        f"Found {trace.abs} different element(s) out of {trace.total} ({trace.rel:.1%}). "
        f"The greatest difference of {trace.diff} ({trace.a} vs. {trace.b}) occurred at index {trace.idx}"
    )


def _check_values_close(
    a: torch.Tensor,
    b: torch.Tensor,
    *,
    rtol: float,
    atol: float,
) -> Optional[AssertionError]:
    """Checks if the values of two tensors are close up to a desired tolerance.

    Args:
        a (torch.Tensor): First tensor.
        b (torch.Tensor): Second tensor.
        rtol (float): Relative tolerance.
        atol (float): Absolute tolerance.

    Returns:
        (Optional[AssertionError]): If the check did not pass.
    """
    mismatches = ~torch.isclose(a, b, rtol=rtol, atol=atol)
    if not torch.any(mismatches):
        return None

    trace = _trace_mismatches(a, b, mismatches)
    return AssertionError(
        f"With rtol={rtol} and atol={atol}, "
        f"found {trace.abs} different element(s) out of {trace.total} ({trace.rel:.1%}). "
        f"The greatest difference of {trace.diff} ({trace.a} vs. {trace.b}) occurred at index {trace.idx}"
    )


def assert_tensors_equal(
    a: torch.Tensor,
    b: torch.Tensor,
    *,
    check_device: bool = True,
    check_dtype: bool = True,
    check_stride: bool = True,
) -> None:
    """Asserts that the values of two tensors are bitwise equal.

    Optionally, checks that some attributes of both tensors are equal.

    Args:
        a (torch.Tensor): First tensor.
        b (torch.Tensor): Second tensor.
        check_device (bool): If ``True`` (default), asserts that both :attr:`a` and :attr:`b` live in the same
            :attr:`~torch.Tensor.device` memory. If this check is disabled **and** :attr:`a` and :attr:`b` do not
            live in the same :attr:`~torch.Tensor.device` memory, they are moved to CPU memory before their values
            are compared.
        check_dtype (bool): If ``True`` (default), asserts that both :attr:`a` and :attr:`b` have the same
            :attr:`~torch.Tensor.dtype`. If this check is disabled **and** :attr:`a` and :attr:`b` do not have the
            same :attr:`~torch.Tensor.dtype`, they are copied to the :class:`~torch.dtype` returned by
            :func:`torch.promote_types` before their values are compared.
        check_stride (bool): If ``True`` (default), asserts that both :attr:`a` and :attr:`b` have the same stride.

    Raises:
        UsageError: If :attr:`a` or :attr:`b` is complex, quantized, or sparse. This is a temporary restriction and
            will be relaxed in the future.
        AssertionError: If :attr:`a` and :attr:`b` do not have the same :attr:`~torch.Tensor.shape`.
        AssertionError: If :attr:`check_device`, but :attr:`a` and :attr:`b` do not live in the same
            :attr:`~torch.Tensor.device` memory.
        AssertionError: If :attr:`check_dtype`, but :attr:`a` and :attr:`b` do not have the same
            :attr:`~torch.Tensor.dtype`.
        AssertionError: If :attr:`check_stride`, but :attr:`a` and :attr:`b` do not have the same stride.
        AssertionError: If the values of :attr:`a` and :attr:`b` are not bitwise equal.

    .. seealso::

        To assert that the values in two tensors are close but are not required to be bitwise equal, use
        :func:`assert_tensors_close` instead.
    """
    exc: Optional[Exception] = _check_are_tensors(a, b)
    if exc:
        raise exc

    exc = _check_supported_tensors(a, b)
    if exc:
        raise exc

    exc = _check_attributes_equal(a, b, check_device=check_device, check_dtype=check_dtype, check_stride=check_stride)
    if exc:
        raise exc
    a, b = _equalize_attributes(a, b)

    exc = _check_values_equal(a, b)
    if exc:
        raise exc
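

# A minimal usage sketch (illustrative, not part of the module):
#
#   >>> t = torch.tensor([1.0, 2.0])
#   >>> assert_tensors_equal(t, t.clone())                      # passes: bitwise equal
#   >>> assert_tensors_equal(t, t.double(), check_dtype=False)  # passes after dtype promotion
#   >>> assert_tensors_equal(t, t + 1)                          # raises AssertionError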


def assert_tensors_close(
    a: torch.Tensor,
    b: torch.Tensor,
    *,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    check_device: bool = True,
    check_dtype: bool = True,
    check_stride: bool = True,
) -> None:
    """Asserts that the values of two tensors are close up to a desired tolerance.

    If both tolerances, :attr:`rtol` and :attr:`atol`, are ``0``, asserts that :attr:`a` and :attr:`b` are bitwise
    equal. Optionally, checks that some attributes of both tensors are equal.

    Args:
        a (torch.Tensor): First tensor.
        b (torch.Tensor): Second tensor.
        rtol (Optional[float]): Relative tolerance. If specified, :attr:`atol` must also be specified. If omitted,
            default values based on the :attr:`~torch.Tensor.dtype` are selected with the below table.
        atol (Optional[float]): Absolute tolerance. If specified, :attr:`rtol` must also be specified. If omitted,
            default values based on the :attr:`~torch.Tensor.dtype` are selected with the below table.
        check_device (bool): If ``True`` (default), asserts that both :attr:`a` and :attr:`b` live in the same
            :attr:`~torch.Tensor.device` memory. If this check is disabled **and** :attr:`a` and :attr:`b` do not
            live in the same :attr:`~torch.Tensor.device` memory, they are moved to CPU memory before their values
            are compared.
        check_dtype (bool): If ``True`` (default), asserts that both :attr:`a` and :attr:`b` have the same
            :attr:`~torch.Tensor.dtype`. If this check is disabled **and** :attr:`a` and :attr:`b` do not have the
            same :attr:`~torch.Tensor.dtype`, they are copied to the :class:`~torch.dtype` returned by
            :func:`torch.promote_types` before their values are compared.
        check_stride (bool): If ``True`` (default), asserts that both :attr:`a` and :attr:`b` have the same stride.

    Raises:
        UsageError: If :attr:`a` or :attr:`b` is complex, quantized, or sparse. This is a temporary restriction and
            will be relaxed in the future.
        UsageError: If only one of :attr:`rtol` and :attr:`atol` is specified.
        AssertionError: If :attr:`a` and :attr:`b` do not have the same :attr:`~torch.Tensor.shape`.
        AssertionError: If :attr:`check_device`, but :attr:`a` and :attr:`b` do not live in the same
            :attr:`~torch.Tensor.device` memory.
        AssertionError: If :attr:`check_dtype`, but :attr:`a` and :attr:`b` do not have the same
            :attr:`~torch.Tensor.dtype`.
        AssertionError: If :attr:`check_stride`, but :attr:`a` and :attr:`b` do not have the same stride.
        AssertionError: If the values of :attr:`a` and :attr:`b` are not close up to the desired tolerance.

    The following table displays the default ``rtol`` and ``atol`` for floating point :attr:`~torch.Tensor.dtype`'s.
    For integer :attr:`~torch.Tensor.dtype`'s, ``rtol = atol = 0.0`` is used.

    +---------------------------+------------+----------+
    | :class:`~torch.dtype`     | ``rtol``   | ``atol`` |
    +===========================+============+==========+
    | :attr:`~torch.float16`    | ``1e-3``   | ``1e-5`` |
    +---------------------------+------------+----------+
    | :attr:`~torch.bfloat16`   | ``1.6e-2`` | ``1e-5`` |
    +---------------------------+------------+----------+
    | :attr:`~torch.float32`    | ``1.3e-6`` | ``1e-5`` |
    +---------------------------+------------+----------+
    | :attr:`~torch.float64`    | ``1e-7``   | ``1e-7`` |
    +---------------------------+------------+----------+
    | :attr:`~torch.complex32`  | ``1e-3``   | ``1e-5`` |
    +---------------------------+------------+----------+
    | :attr:`~torch.complex64`  | ``1.3e-6`` | ``1e-5`` |
    +---------------------------+------------+----------+
    | :attr:`~torch.complex128` | ``1e-7``   | ``1e-7`` |
    +---------------------------+------------+----------+

    .. seealso::

        To assert that the values in two tensors are bitwise equal, use :func:`assert_tensors_equal` instead.
    """
    exc: Optional[Exception] = _check_are_tensors(a, b)
    if exc:
        raise exc

    exc = _check_supported_tensors(a, b)
    if exc:
        raise exc

    if (rtol is None) ^ (atol is None):
        # We require both tolerances to be omitted or specified, because specifying only one might lead to
        # surprising results. Imagine setting atol=0.0 and the tensors still match because rtol>0.0.
        raise UsageError(
            f"Both 'rtol' and 'atol' must be omitted or specified, but got rtol={rtol} and atol={atol} instead."
        )
    elif rtol is None:
        rtol, atol = _get_default_rtol_and_atol(a, b)

    exc = _check_attributes_equal(a, b, check_device=check_device, check_dtype=check_dtype, check_stride=check_stride)
    if exc:
        raise exc
    a, b = _equalize_attributes(a, b)

    if (rtol == 0.0) and (atol == 0.0):
        exc = _check_values_equal(a, b)
    else:
        exc = _check_values_close(a, b, rtol=rtol, atol=atol)
    if exc:
        raise exc
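

# A minimal usage sketch (illustrative, not part of the module):
#
#   >>> t = torch.ones(2)
#   >>> assert_tensors_close(t, t + 1e-6)                        # passes with default float32 tolerances
#   >>> assert_tensors_close(t, t + 1e-6, rtol=0.0, atol=0.0)    # falls back to the bitwise check and raises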