Add support for checking tensor containers in torch.testing (#55385)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55385

This renames `assert_tensors_(equal|close)` to `_check_tensors_(equal|close)` and exposes two new functions: `assert_(equal|close)`. In addition to tensor pairs, the newly added functions also support comparing tensors contained in sequences or mappings. Otherwise, their signatures stay the same.
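For illustration, a minimal usage sketch of the newly exposed functions (values are made up):

    import torch
    import torch.testing

    actual = torch.tensor([1.0, 2.0])
    expected = actual.clone()

    # pairs of tensors, as before
    torch.testing.assert_close(actual, expected)

    # new: sequences and mappings of tensors
    torch.testing.assert_close([actual], [expected])
    torch.testing.assert_close({"t": actual}, {"t": expected})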

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D27903805

Pulled By: mruberry

fbshipit-source-id: 719d19a1d26de8d14cb25846e3d22a6ac828c80a
Philip Meier
2021-04-24 23:35:25 -07:00
committed by Facebook GitHub Bot
parent bcef7ebd60
commit dbf3451c6e
2 changed files with 558 additions and 260 deletions

test/test_testing.py

@@ -1,11 +1,17 @@
-import torch
import collections
import functools
import itertools
import math
import os
import random
import re
import unittest
from typing import Any, Callable, Iterator, List, Mapping, Sequence, Tuple, TypeVar
import numpy as np
+import torch
from torch.testing._internal.common_utils import \
(IS_SANDCASTLE, IS_WINDOWS, TestCase, make_tensor, run_tests, skipIfRocm, slowTest)
from torch.testing._internal.framework_utils import calculate_shards
@@ -740,17 +746,62 @@ if __name__ == '__main__':
self.assertNotIn('OK', stderr.decode('ascii'))
T = TypeVar("T", torch.Tensor, Sequence[torch.Tensor], Mapping[Any, torch.Tensor])
class TestAsserts(TestCase):
-    def assert_fns(self):
-        return [torch.testing.assert_tensors_equal, torch.testing.assert_tensors_close]
+    def get_assert_fns(self) -> List[Callable]:
+        """Gets assert functions to be tested.
+        Returns:
+            List[Callable]: Top-level assert functions from :mod:`torch.testing`.
+        """
+        return [torch.testing.assert_equal, torch.testing.assert_close]
def make_inputs(self, actual: torch.Tensor, expected: torch.Tensor) -> List[Tuple[T, T]]:
"""Makes inputs for assert functions based on two example tensors.
Args:
actual (torch.Tensor): Actual tensor.
expected (torch.Tensor): Expected tensor.
Returns:
List[Tuple[T, T]]: Pairs of tensors, tensor sequences (:class:`tuple`, :class:`list`), and tensor mappings
(:class:`dict`, :class:`~collections.OrderedDict`)
"""
return [
(actual, expected),
((actual,), (expected,)),
([actual], [expected]),
({"t": actual}, {"t": expected}),
(collections.OrderedDict([("t", actual)]), collections.OrderedDict([("t", expected)])),
]
def assert_fns_with_inputs(self, actual: torch.Tensor, expected: torch.Tensor) -> Iterator[Callable]:
"""Yields assert functions with with included positional inputs based on two example tensors.
.. note::
This is a valid product of combinations from :meth:`get_assert_fns` and :meth:`make_inputs`. Every test
that does not test for anything specific should iterate over this to maximize the coverage.
Args:
actual (torch.Tensor): Actual tensor.
expected (torch.Tensor): Expected tensor.
Yields:
List[Callable]: Assert functions with predefined positional inputs.
"""
for assert_fn, inputs in itertools.product(self.get_assert_fns(), self.make_inputs(actual, expected)):
yield functools.partial(assert_fn, *inputs)
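# Note: each yielded callable is an assert function with the example inputs
# already bound via functools.partial, so a test can call fn(check_dtype=False)
# instead of assert_fn(actual, expected, check_dtype=False).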
@onlyCPU
def test_not_tensors(self, device):
actual = torch.empty((), device=device)
expected = np.empty(())
-        for fn in self.assert_fns():
-            with self.assertRaises(AssertionError):
+        for fn in self.get_assert_fns():
+            with self.assertRaises(UsageError):
fn(actual, expected)
@onlyCPU
@@ -758,18 +809,18 @@ class TestAsserts(TestCase):
actual = torch.ones(1, dtype=torch.float32, device=device)
expected = torch.ones(1, dtype=torch.complex64, device=device)
-        for fn in self.assert_fns():
+        for fn in self.assert_fns_with_inputs(actual, expected):
             with self.assertRaises(UsageError):
-                fn(actual, expected, check_dtype=False)
+                fn(check_dtype=False)
@onlyCPU
def test_sparse_support(self, device):
actual = torch.empty((), device=device)
expected = torch.sparse_coo_tensor(size=(), device=device)
-        for fn in self.assert_fns():
+        for fn in self.assert_fns_with_inputs(actual, expected):
             with self.assertRaises(UsageError):
-                fn(actual, expected)
+                fn()
@onlyCPU
def test_quantized_support(self, device):
@@ -778,221 +829,246 @@ class TestAsserts(TestCase):
expected = torch._empty_affine_quantized(actual.shape, scale=1, zero_point=0, dtype=torch.qint32, device=device)
expected.fill_(val)
-        for fn in self.assert_fns():
+        for fn in self.assert_fns_with_inputs(actual, expected):
             with self.assertRaises(UsageError):
-                fn(actual, expected)
+                fn()
@onlyCPU
def test_mismatching_shape(self, device):
actual = torch.empty((), device=device)
expected = actual.clone().reshape((1,))
-        for fn in self.assert_fns():
+        for fn in self.assert_fns_with_inputs(actual, expected):
             with self.assertRaisesRegex(AssertionError, "shape"):
-                fn(actual, expected)
+                fn()
@onlyCUDA
def test_mismatching_device(self, device):
actual = torch.empty((), device=device)
expected = actual.clone().cpu()
-        for fn in self.assert_fns():
+        for fn in self.assert_fns_with_inputs(actual, expected):
             with self.assertRaisesRegex(AssertionError, "device"):
-                fn(actual, expected)
+                fn()
@onlyCUDA
def test_mismatching_device_no_check(self, device):
actual = torch.rand((), device=device)
expected = actual.clone().cpu()
-        for fn in self.assert_fns():
-            fn(actual, expected, check_device=False)
+        for fn in self.assert_fns_with_inputs(actual, expected):
+            fn(check_device=False)
@onlyCPU
def test_mismatching_dtype(self, device):
actual = torch.empty((), dtype=torch.float, device=device)
expected = actual.clone().to(torch.int)
-        for fn in self.assert_fns():
+        for fn in self.assert_fns_with_inputs(actual, expected):
             with self.assertRaisesRegex(AssertionError, "dtype"):
-                fn(actual, expected)
+                fn()
@onlyCPU
def test_mismatching_dtype_no_check(self, device):
actual = torch.ones((), dtype=torch.float, device=device)
expected = actual.clone().to(torch.int)
-        for fn in self.assert_fns():
-            fn(actual, expected, check_dtype=False)
+        for fn in self.assert_fns_with_inputs(actual, expected):
+            fn(check_dtype=False)
@onlyCPU
def test_mismatching_stride(self, device):
actual = torch.empty((2, 2), device=device)
expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])
-        for fn in self.assert_fns():
+        for fn in self.assert_fns_with_inputs(actual, expected):
             with self.assertRaisesRegex(AssertionError, "stride"):
-                fn(actual, expected)
+                fn()
@onlyCPU
def test_mismatching_stride_no_check(self, device):
actual = torch.rand((2, 2), device=device)
expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])
-        for fn in self.assert_fns():
-            fn(actual, expected, check_stride=False)
+        for fn in self.assert_fns_with_inputs(actual, expected):
+            fn(check_stride=False)
@onlyCPU
def test_mismatching_values(self, device):
actual = torch.tensor(1, device=device)
expected = torch.tensor(2, device=device)
-        for fn in self.assert_fns():
+        for fn in self.assert_fns_with_inputs(actual, expected):
             with self.assertRaises(AssertionError):
-                fn(actual, expected)
+                fn()
@onlyCPU
def test_assert_equal(self, device):
actual = torch.tensor(1, device=device)
expected = actual.clone()
torch.testing.assert_equal(actual, expected)
@onlyCPU
def test_assert_close(self, device):
actual = torch.tensor(1.0, device=device)
expected = actual.clone()
torch.testing.assert_close(actual, expected)
@onlyCPU
def test_assert_close_only_rtol(self, device):
actual = torch.empty((), device=device)
expected = actual.clone()
with self.assertRaises(UsageError):
torch.testing.assert_close(actual, expected, rtol=0.0)
@onlyCPU
def test_assert_close_only_atol(self, device):
actual = torch.empty((), device=device)
expected = actual.clone()
with self.assertRaises(UsageError):
torch.testing.assert_close(actual, expected, atol=0.0)
@onlyCPU
def test_assert_close_mismatching_values_rtol(self, device):
eps = 1e-3
actual = torch.tensor(1.0, device=device)
expected = torch.tensor(1.0 + eps, device=device)
with self.assertRaises(AssertionError):
torch.testing.assert_close(actual, expected, rtol=eps / 2, atol=0.0)
@onlyCPU
def test_assert_close_matching_values_rtol(self, device):
eps = 1e-3
actual = torch.tensor(1.0, device=device)
expected = torch.tensor(1.0 + eps, device=device)
torch.testing.assert_close(actual, expected, rtol=eps * 2, atol=0.0)
@onlyCPU
def test_assert_close_mismatching_values_atol(self, device):
eps = 1e-3
actual = torch.tensor(0.0, device=device)
expected = torch.tensor(eps, device=device)
with self.assertRaises(AssertionError):
torch.testing.assert_close(actual, expected, rtol=0.0, atol=eps / 2)
@onlyCPU
def test_assert_close_matching_values_atol(self, device):
eps = 1e-3
actual = torch.tensor(0.0, device=device)
expected = torch.tensor(eps, device=device)
torch.testing.assert_close(actual, expected, rtol=0.0, atol=eps * 2)
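# The four tolerance tests above follow directly from the closeness check
# abs(actual - expected) <= atol + rtol * abs(expected): with actual=0.0 and
# expected=1e-3 the absolute difference is 1e-3, which exceeds atol=eps / 2 but
# lies within atol=eps * 2; the rtol cases work analogously with expected=1.0 + eps.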
@onlyCPU
def test_mismatching_values_msg_mismatches(self, device):
actual = torch.tensor([1, 2, 3, 4], device=device)
expected = torch.tensor([1, 2, 5, 6], device=device)
for fn in self.assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("Mismatched elements: 2 / 4 (50.0%)")):
fn()
@onlyCPU
def test_mismatching_values_msg_abs_diff(self, device):
actual = torch.tensor([[1, 2], [3, 4]], device=device)
expected = torch.tensor([[1, 2], [5, 4]], device=device)
for fn in self.assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("Greatest absolute difference: 2 at (1, 0)")):
fn()
@onlyCPU
def test_mismatching_values_msg_rel_diff(self, device):
actual = torch.tensor([[1, 2], [3, 4]], device=device)
expected = torch.tensor([[1, 4], [3, 4]], device=device)
for fn in self.assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("Greatest relative difference: 0.5 at (0, 1)")):
fn()
@onlyCPU
def test_assert_close_mismatching_values_msg_rtol(self, device):
rtol = 1e-3
actual = torch.tensor(1, device=device)
expected = torch.tensor(2, device=device)
for inputs in self.make_inputs(actual, expected):
with self.assertRaisesRegex(
AssertionError, re.escape(f"Greatest relative difference: 0.5 at 0 (up to {rtol} allowed)")
):
torch.testing.assert_close(*inputs, rtol=rtol, atol=0.0)
@onlyCPU
def test_assert_close_mismatching_values_msg_atol(self, device):
atol = 1e-3
actual = torch.tensor(1, device=device)
expected = torch.tensor(2, device=device)
for inputs in self.make_inputs(actual, expected):
with self.assertRaisesRegex(
AssertionError, re.escape(f"Greatest absolute difference: 1 at 0 (up to {atol} allowed)")
):
torch.testing.assert_close(*inputs, rtol=0.0, atol=atol)
@onlyCPU
def test_unknown_type(self, device):
actual = torch.empty((), device=device)
expected = {actual.clone()}
for fn in self.get_assert_fns():
with self.assertRaisesRegex(UsageError, str(type(expected))):
fn(actual, expected)
@onlyCPU
def test_sequence_mismatching_len(self, device):
actual = (torch.empty((), device=device),)
expected = ()
for fn in self.get_assert_fns():
with self.assertRaises(AssertionError):
fn(actual, expected)
@onlyCPU
-    def test_mismatching_values_msg_abs_mismatches(self, device):
-        actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
-        expected = actual.clone()
+    def test_sequence_mismatching_values_msg(self, device):
+        t1 = torch.tensor(1, device=device)
+        t2 = torch.tensor(2, device=device)
-        actual[0, 1] = 1
-        expected[0, 1] = 2
-        expected[1, 2] = 9
+        actual = (t1, t1)
+        expected = (t1, t2)
-        for fn in self.assert_fns():
-            with self.assertRaisesRegex(AssertionError, r"\s+2\s+"):
+        for fn in self.get_assert_fns():
+            with self.assertRaisesRegex(AssertionError, r"index\s+1"):
                 fn(actual, expected)
@onlyCPU
-    def test_mismatching_values_msg_rel_mismatches(self, device):
-        actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
-        expected = actual.clone()
+    def test_mapping_mismatching_keys(self, device):
+        actual = {"a": torch.empty((), device=device)}
+        expected = {}
-        actual[0, 1] = 1
-        expected[0, 1] = 2
-        expected[1, 2] = 9
-        for fn in self.assert_fns():
-            with self.assertRaisesRegex(AssertionError, r"22([.]2+)?\s*[%]"):
+        for fn in self.get_assert_fns():
+            with self.assertRaises(AssertionError):
                 fn(actual, expected)
@onlyCPU
-    def test_mismatching_values_msg_max_abs_diff(self, device):
-        actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
-        expected = actual.clone()
+    def test_mapping_mismatching_values_msg(self, device):
+        t1 = torch.tensor(1, device=device)
+        t2 = torch.tensor(2, device=device)
-        actual[0, 1] = 1
-        expected[0, 1] = 2
-        expected[1, 2] = 9
+        actual = {"a": t1, "b": t1}
+        expected = {"a": t1, "b": t2}
-        for fn in self.assert_fns():
-            with self.assertRaisesRegex(AssertionError, r"\s+4[.]0\s+"):
+        for fn in self.get_assert_fns():
+            with self.assertRaisesRegex(AssertionError, r"key\s+'b'"):
                 fn(actual, expected)
-    @onlyCPU
-    def test_mismatching_values_max_abs_diff_idx(self, device):
-        actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
-        expected = actual.clone()
-        actual[0, 1] = 1
-        expected[0, 1] = 2
-        expected[1, 2] = 9
-        for fn in self.assert_fns():
-            with self.assertRaisesRegex(AssertionError, r"1,\s*2"):
-                fn(actual, expected)
-    @onlyCPU
-    def test_mismatching_values_msg_max_rel_diff(self, device):
-        actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
-        expected = actual.clone()
-        actual[0, 1] = 1
-        expected[0, 1] = 2
-        expected[1, 2] = 9
-        for fn in self.assert_fns():
-            with self.assertRaisesRegex(AssertionError, r"\s+0[.]5\s+"):
-                fn(actual, expected)
-    @onlyCPU
-    def test_mismatching_values_max_rel_diff_idx(self, device):
-        actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
-        expected = actual.clone()
-        actual[0, 1] = 1
-        expected[0, 1] = 2
-        expected[1, 2] = 9
-        for fn in self.assert_fns():
-            with self.assertRaisesRegex(AssertionError, r"0,\s*1"):
-                fn(actual, expected)
-    @onlyCPU
-    def test_assert_tensors_equal(self, device):
-        actual = torch.tensor(1, device=device)
-        expected = actual.clone()
-        torch.testing.assert_tensors_equal(actual, expected)
-    @onlyCPU
-    def test_assert_tensors_close(self, device):
-        actual = torch.tensor(1.0, device=device)
-        expected = actual.clone()
-        torch.testing.assert_tensors_close(actual, expected)
-    @onlyCPU
-    def test_assert_tensors_close_only_rtol(self, device):
-        actual = torch.empty((), device=device)
-        expected = actual.clone()
-        with self.assertRaises(UsageError):
-            torch.testing.assert_tensors_close(actual, expected, rtol=0.0)
-    @onlyCPU
-    def test_assert_tensors_close_only_atol(self, device):
-        actual = torch.empty((), device=device)
-        expected = actual.clone()
-        with self.assertRaises(UsageError):
-            torch.testing.assert_tensors_close(actual, expected, atol=0.0)
-    @onlyCPU
-    def test_assert_tensors_close_mismatching_values_rtol(self, device):
-        eps = 1e-3
-        actual = torch.tensor(1.0, device=device)
-        expected = torch.tensor(1.0 + eps, device=device)
-        with self.assertRaises(AssertionError):
-            torch.testing.assert_tensors_close(actual, expected, rtol=eps / 2, atol=0.0)
-    @onlyCPU
-    def test_assert_tensors_close_matching_values_rtol(self, device):
-        eps = 1e-3
-        actual = torch.tensor(1.0, device=device)
-        expected = torch.tensor(1.0 + eps, device=device)
-        torch.testing.assert_tensors_close(actual, expected, rtol=eps * 2, atol=0.0)
-    @onlyCPU
-    def test_assert_tensors_close_mismatching_values_atol(self, device):
-        eps = 1e-3
-        actual = torch.tensor(0.0, device=device)
-        expected = torch.tensor(eps, device=device)
-        with self.assertRaises(AssertionError):
-            torch.testing.assert_tensors_close(actual, expected, rtol=0.0, atol=eps / 2)
-    @onlyCPU
-    def test_assert_tensors_close_matching_values_atol(self, device):
-        eps = 1e-3
-        actual = torch.tensor(0.0, device=device)
-        expected = torch.tensor(eps, device=device)
-        torch.testing.assert_tensors_close(actual, expected, rtol=0.0, atol=eps * 2)
instantiate_device_type_tests(TestAsserts, globals())
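# instantiate_device_type_tests() replaces the generic TestAsserts class with
# per-device variants such as TestAssertsCPU (and TestAssertsCUDA when a GPU is
# available), so the @onlyCPU/@onlyCUDA tests above run on the matching device.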

torch/testing/_asserts.py

@@ -1,12 +1,15 @@
import collections.abc
import functools
import sys
from collections import namedtuple
-from typing import Any, Optional, Tuple, Type
+from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union, cast
 import torch
+from torch import Tensor
 from ._core import _unravel_index
-__all__ = ["assert_tensors_equal", "assert_tensors_close"]
+__all__ = ["assert_equal", "assert_close"]
# The UsageError should be raised in case the test function is not used correctly. With this the user is able to
# differentiate between a failed check and a wrong usage of the check function itself.
@@ -42,7 +45,7 @@ _DTYPE_PRECISIONS = {
}
-def _get_default_rtol_and_atol(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[float, float]:
+def _get_default_rtol_and_atol(actual: Tensor, expected: Tensor) -> Tuple[float, float]:
dtype = actual.dtype if actual.dtype == expected.dtype else torch.promote_types(actual.dtype, expected.dtype)
return _DTYPE_PRECISIONS.get(dtype, (0.0, 0.0))
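# Sketch of the lookup above (defaults as listed in the assert_close docstring
# further below): with mismatching dtypes the promoted dtype picks the tolerances.
#
#     rtol, atol = _get_default_rtol_and_atol(
#         torch.tensor(1.0, dtype=torch.float32),
#         torch.tensor(1.0, dtype=torch.float64),
#     )
#     # float32 and float64 promote to float64, so the float64 defaults are used.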
@@ -57,23 +60,23 @@ def _check_are_tensors(actual: Any, expected: Any) -> Optional[AssertionError]:
Returns:
(Optional[AssertionError]): If check did not pass.
"""
-    if not (isinstance(actual, torch.Tensor) and isinstance(expected, torch.Tensor)):
+    if not (isinstance(actual, Tensor) and isinstance(expected, Tensor)):
return AssertionError(f"Both inputs have to be tensors, but got {type(actual)} and {type(expected)} instead.")
return None
def _check_supported_tensors(
-    actual: torch.Tensor,
-    expected: torch.Tensor,
+    actual: Tensor,
+    expected: Tensor,
) -> Optional[UsageError]: # type: ignore[valid-type]
"""Checks if the tensors are supported by the current infrastructure.
All checks are temporary and will be relaxed in the future.
Args:
-        actual (torch.Tensor): Actual tensor.
-        expected (torch.Tensor): Expected tensor.
+        actual (Tensor): Actual tensor.
+        expected (Tensor): Expected tensor.
Returns:
(Optional[UsageError]): If check did not pass.
@@ -89,8 +92,8 @@ def _check_supported_tensors(
def _check_attributes_equal(
-    actual: torch.Tensor,
-    expected: torch.Tensor,
+    actual: Tensor,
+    expected: Tensor,
*,
check_device: bool = True,
check_dtype: bool = True,
@@ -102,8 +105,8 @@ def _check_attributes_equal(
:attr:`~torch.Tensor.dtype`, and :meth:`~torch.Tensor.stride` are optional and can be disabled.
Args:
-        actual (torch.Tensor): Actual tensor.
-        expected (torch.Tensor): Expected tensor.
+        actual (Tensor): Actual tensor.
+        expected (Tensor): Expected tensor.
check_device (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` are on the
same :attr:`~torch.Tensor.device` memory.
check_dtype (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` have the same
@@ -112,7 +115,7 @@ def _check_attributes_equal(
:meth:`~torch.Tensor.stride`.
Returns:
-        (Optional[AssertionError]): If check did not pass.
+        (Optional[AssertionError]): If checks did not pass.
"""
msg_fmtstr = "The values for attribute '{}' do not match: {} != {}."
@@ -131,7 +134,7 @@ def _check_attributes_equal(
return None
-def _equalize_attributes(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+def _equalize_attributes(actual: Tensor, expected: Tensor) -> Tuple[Tensor, Tensor]:
"""Equalizes some attributes of two tensors for value comparison.
If :attr:`actual` and :attr:`expected`
@@ -140,11 +143,11 @@ def _equalize_attributes(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[
:func:`torch.promote_types`.
Args:
-        actual (torch.Tensor): Actual tensor.
-        expected (torch.Tensor): Expected tensor.
+        actual (Tensor): Actual tensor.
+        expected (Tensor): Expected tensor.
Returns:
-        Tuple(torch.Tensor, torch.Tensor): Equalized tensors.
+        Tuple(Tensor, Tensor): Equalized tensors.
"""
if actual.device != expected.device:
actual = actual.cpu()
@@ -172,13 +175,13 @@ _Trace = namedtuple(
)
-def _trace_mismatches(actual: torch.Tensor, expected: torch.Tensor, mismatches: torch.Tensor) -> _Trace:
+def _trace_mismatches(actual: Tensor, expected: Tensor, mismatches: Tensor) -> _Trace:
"""Traces mismatches.
Args:
-        actual (torch.Tensor): Actual tensor.
-        expected (torch.Tensor): Expected tensor.
-        mismatches (torch.Tensor): Boolean mask of the same shape as :attr:`actual` and :attr:`expected` that indicates
+        actual (Tensor): Actual tensor.
+        expected (Tensor): Expected tensor.
+        mismatches (Tensor): Boolean mask of the same shape as :attr:`actual` and :attr:`expected` that indicates
the location of mismatches.
Returns:
@@ -220,12 +223,12 @@ def _trace_mismatches(actual: torch.Tensor, expected: torch.Tensor, mismatches:
)
-def _check_values_equal(actual: torch.Tensor, expected: torch.Tensor) -> Optional[AssertionError]:
+def _check_values_equal(actual: Tensor, expected: Tensor) -> Optional[AssertionError]:
"""Checks if the values of two tensors are bitwise equal.
Args:
-        actual (torch.Tensor): Actual tensor.
-        expected (torch.Tensor): Expected tensor.
+        actual (Tensor): Actual tensor.
+        expected (Tensor): Expected tensor.
Returns:
(Optional[AssertionError]): If check did not pass.
@@ -244,8 +247,8 @@ def _check_values_equal(actual: torch.Tensor, expected: torch.Tensor) -> Optiona
def _check_values_close(
-    actual: torch.Tensor,
-    expected: torch.Tensor,
+    actual: Tensor,
+    expected: Tensor,
*,
rtol,
atol,
@@ -253,8 +256,8 @@ def _check_values_close(
"""Checks if the values of two tensors are close up to a desired tolerance.
Args:
-        actual (torch.Tensor): Actual tensor.
-        expected (torch.Tensor): Expected tensor.
+        actual (Tensor): Actual tensor.
+        expected (Tensor): Expected tensor.
rtol (float): Relative tolerance.
atol (float): Absolute tolerance.
@@ -274,55 +277,89 @@ def _check_values_close(
)
-def assert_tensors_equal(
-    actual: torch.Tensor,
-    expected: torch.Tensor,
+def _check_tensors_equal(
+    actual: Tensor,
+    expected: Tensor,
*,
check_device: bool = True,
check_dtype: bool = True,
check_stride: bool = True,
-) -> None:
-    """Asserts that the values of two tensors are bitwise equal.
+) -> Optional[Exception]:
+    """Checks that the values of two tensors are bitwise equal.
Optionally, checks that some attributes of both tensors are equal.
-    Args:
-        actual (torch.Tensor): Actual tensor.
-        expected (torch.Tensor): Expected tensor.
-        check_device (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` are on the
-            same :attr:`~torch.Tensor.device` memory. If this check is disabled **and** :attr:`actual` and
-            :attr:`expected` are not on the same memory :attr:`~torch.Tensor.device`, they are moved CPU memory before
-            their values are compared.
-        check_dtype (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` have the same
-            :attr:`~torch.Tensor.dtype`. If this check is disabled **and** :attr:`actual` and :attr:`expected` do not
-            have the same :attr:`~torch.Tensor.dtype`, they are copied to the :class:`~torch.dtype` returned by
-            :func:`torch.promote_types` before their values are compared.
-        check_stride (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` have the same
-            stride.
+    For a description of the parameters see :func:`assert_equal`.
-    Raises:
-        UsageError: If :attr:`actual` or :attr:`expected` is complex, quantized, or sparse. This is a temporary
-            restriction and will be relaxed in the future.
-        AssertionError: If :attr:`actual` and :attr:`expected` do not have the same :attr:`~torch.Tensor.shape`.
-        AssertionError: If :attr:`check_device`, but :attr:`actual` and :attr:`expected` are not on the same
-            :attr:`~torch.Tensor.device` memory.
-        AssertionError: If :attr:`check_dtype`, but :attr:`actual` and :attr:`expected` do not have the same
-            :attr:`~torch.Tensor.dtype`.
-        AssertionError: If :attr:`check_stride`, but :attr:`actual` and :attr:`expected` do not have the same stride.
-        AssertionError: If the values of :attr:`actual` and :attr:`expected` are not bitwise equal.
-    .. seealso::
-        To assert that the values in two tensors are close but are not required to be bitwise equal, use
-        :func:`assert_tensors_close` instead.
+    Returns:
+        Optional[Exception]: If checks did not pass.
"""
exc: Optional[Exception] = _check_are_tensors(actual, expected)
if exc:
-        raise exc
+        return exc
exc = _check_supported_tensors(actual, expected)
if exc:
-        raise exc
+        return exc
exc = _check_attributes_equal(
actual, expected, check_device=check_device, check_dtype=check_dtype, check_stride=check_stride
)
if exc:
return exc
actual, expected = _equalize_attributes(actual, expected)
exc = _check_values_equal(actual, expected)
if exc:
return exc
return None
def _check_tensors_close(
actual: Tensor,
expected: Tensor,
*,
rtol: Optional[float] = None,
atol: Optional[float] = None,
check_device: bool = True,
check_dtype: bool = True,
check_stride: bool = True,
) -> Optional[Exception]:
r"""Checks that the values of two tensors are close.
Closeness is defined by
.. math::
\lvert a - b \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert b \rvert
If both tolerances, :attr:`rtol` and :attr:`atol`, are ``0``, asserts that :attr:`actual` and :attr:`expected` are
bitwise equal.
Optionally, checks that some attributes of both tensors are equal.
For a description of the parameters see :func:`assert_equal`.
Returns:
Optional[Exception]: If checks did not pass.
"""
exc: Optional[Exception] = _check_are_tensors(actual, expected)
if exc:
return exc
exc = _check_supported_tensors(actual, expected)
if exc:
return exc
if (rtol is None) ^ (atol is None):
# We require both tolerances to be omitted or specified, because specifying only one might lead to surprising
# results. Imagine setting atol=0.0 and the tensors still match because rtol>0.0.
return UsageError(
f"Both 'rtol' and 'atol' must be omitted or specified, " f"but got rtol={rtol} and atol={atol} instead."
)
elif rtol is None:
rtol, atol = _get_default_rtol_and_atol(actual, expected)
exc = _check_attributes_equal(
actual, expected, check_device=check_device, check_dtype=check_dtype, check_stride=check_stride
@@ -331,14 +368,205 @@ def assert_tensors_equal(
-        raise exc
+        return exc
     actual, expected = _equalize_attributes(actual, expected)
-    exc = _check_values_equal(actual, expected)
+    if (rtol == 0.0) and (atol == 0.0):
+        exc = _check_values_equal(actual, expected)
+    else:
+        exc = _check_values_close(actual, expected, rtol=rtol, atol=atol)
     if exc:
-        raise exc
+        return exc
+    return None
def _check_by_type(
actual: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
expected: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
check_tensors: Callable[[Tensor, Tensor], Optional[Exception]],
) -> Optional[Exception]:
"""Delegates tensor checking based on the inputs types.
Currently supports pairs of
- :class:`Tensor`'s,
- :class:`~collections.abc.Sequence`'s of :class:`Tensor`'s, and
- :class:`~collections.abc.Mapping`'s of :class:`Tensor`'s.
Args:
actual (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Actual input.
expected (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Expected input.
check_tensors (Callable[[Tensor, Tensor], Optional[Exception]]): Callable used to check if two tensors match.
In case they mismatch, it should return an :class:`Exception` with an expressive error message.
Returns:
(Optional[Exception]): :class:`UsageError` if the input types are unsupported. Additionally, any exception
returned by :attr:`check_tensors`.
"""
# _check_are_tensors() returns nothing in case both inputs are tensors and an exception otherwise. Thus, the logic
# is inverted here.
are_tensors = not _check_are_tensors(actual, expected)
if are_tensors:
return check_tensors(cast(Tensor, actual), cast(Tensor, expected))
if isinstance(actual, collections.abc.Sequence) and isinstance(expected, collections.abc.Sequence):
return _check_sequence(actual, expected, check_tensors)
elif isinstance(actual, collections.abc.Mapping) and isinstance(expected, collections.abc.Mapping):
return _check_mapping(actual, expected, check_tensors)
return UsageError(
f"Both inputs have to be tensors, or sequences or mappings of tensors, "
f"but got {type(actual)} and {type(expected)} instead."
)
E = TypeVar("E", bound=Exception)
def _amend_error_message(exc: E, msg_fmtstr: str) -> E:
"""Amends an exception message.
Args:
exc (E): Exception.
msg_fmtstr (str): Format string for the amended message.
Returns:
(E): New exception with amended error message.
"""
return type(exc)(msg_fmtstr.format(str(exc)))
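# Usage sketch: the amended message embeds the original one, and the exception
# type is preserved.
#
#     exc = AssertionError("Tensors are not equal!")  # hypothetical original error
#     amended = _amend_error_message(exc, "{}\n\nThe failure occurred at index 1 of the sequences.")
#     assert type(amended) is AssertionError
#     assert str(amended).startswith("Tensors are not equal!")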
_SEQUENCE_MSG_FMTSTR = "The failure occurred at index {} of the sequences."
def _check_sequence(
actual: Sequence[Tensor], expected: Sequence[Tensor], check_tensors: Callable[[Tensor, Tensor], Optional[Exception]]
) -> Optional[Exception]:
"""Checks if the values of two sequences of tensors match.
Args:
actual (Sequence[Tensor]): Actual sequence of tensors.
expected (Sequence[Tensor]): Expected sequence of tensors.
check_tensors (Callable[[Tensor, Tensor], Optional[Exception]]): Callable used to check if the items of
:attr:`actual` and :attr:`expected` match. In case they mismatch, it should return an :class:`Exception` with
an expressive error message.
Returns:
Optional[Exception]: :class:`AssertionError` if the sequences do not have the same length. Additionally, any
exception returned by :attr:`check_tensors`. In this case, the error message is amended to include the
first offending index.
"""
actual_len = len(actual)
expected_len = len(expected)
if actual_len != expected_len:
return AssertionError(f"The length of the sequences mismatch: {actual_len} != {expected_len}")
for idx, (actual_t, expected_t) in enumerate(zip(actual, expected)):
exc = check_tensors(actual_t, expected_t)
if exc:
return _amend_error_message(exc, f"{{}}\n\n{_SEQUENCE_MSG_FMTSTR.format(idx)}")
return None
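# For example, with hypothetical scalar tensors t1 != t2,
# _check_sequence((t1, t1), (t1, t2), check_tensors) returns the exception that
# check_tensors produced for the pair at index 1, with
# "The failure occurred at index 1 of the sequences." appended to its message.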
_MAPPING_MSG_FMTSTR = "The failure occurred for key '{}' of the mappings."
def _check_mapping(
actual: Mapping[Any, Tensor],
expected: Mapping[Any, Tensor],
check_tensors: Callable[[Tensor, Tensor], Optional[Exception]],
) -> Optional[Exception]:
"""Checks if the values of two mappings of tensors match.
Args:
actual (Mapping[Any, Tensor]): Actual mapping of tensors.
expected (Mapping[Any, Tensor]): Expected mapping of tensors.
check_tensors (Callable[[Tensor, Tensor], Optional[Exception]]): Callable used to check if the values of
:attr:`actual` and :attr:`expected` match. In case they mismatch, it should return an :class:`Exception` with
an expressive error message.
Returns:
Optional[Exception]: :class:`AssertionError` if the mappings do not have the same set of keys. Additionally,
any exception returned by :attr:`check_tensors`. In this case, the error message is amended to include the
first offending key.
"""
actual_keys = set(actual.keys())
expected_keys = set(expected.keys())
if actual_keys != expected_keys:
missing_keys = expected_keys - actual_keys
additional_keys = actual_keys - expected_keys
return AssertionError(
f"The keys of the mappings do not match:\n\n"
f"Missing keys in the actual mapping: {sorted(missing_keys)}\n"
f"Additional keys in the actual mapping: {sorted(additional_keys)}\n"
)
for key in sorted(actual_keys):
actual_t = actual[key]
expected_t = expected[key]
exc = check_tensors(actual_t, expected_t)
if exc:
return _amend_error_message(exc, f"{{}}\n\n{_MAPPING_MSG_FMTSTR.format(key)}")
return None
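# For example, comparing the hypothetical mappings {"a": t, "b": t} (actual) and
# {"a": t} (expected) returns an AssertionError listing [] as missing and ['b']
# as additional keys; with matching keys, the values are compared pairwise in
# sorted key order.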
def assert_equal(
actual: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
expected: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
*,
check_device: bool = True,
check_dtype: bool = True,
check_stride: bool = True,
) -> None:
"""Asserts that the values of tensors are bitwise equal.
Optionally, checks that some attributes of tensors are equal.
Also supports :class:`~collections.abc.Sequence`'s and :class:`~collections.abc.Mapping`'s of :class:`Tensor`'s.
Args:
actual (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Actual input.
expected (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Expected input.
check_device (bool): If ``True`` (default), asserts that corresponding tensors live in the same
:attr:`~torch.Tensor.device` memory. If this check is disabled **and** they do not live in the same
:attr:`~torch.Tensor.device` memory, they are moved to CPU memory before their values are compared.
check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same
:attr:`~torch.Tensor.dtype`. If this check is disabled **and** they do not have the same
:attr:`~torch.Tensor.dtype`, they are copied to the :class:`~torch.dtype` returned by
:func:`torch.promote_types` before their values are compared.
check_stride (bool): If ``True`` (default), asserts that the tensors have the same stride.
Raises:
UsageError: If the input pair has an unsupported type.
UsageError: If any tensor is complex, quantized, or sparse. This is a temporary restriction and
will be relaxed in the future.
AssertionError: If any corresponding tensors do not have the same :attr:`~torch.Tensor.shape`.
AssertionError: If :attr:`check_device`, but any corresponding tensors do not live in the same
:attr:`~torch.Tensor.device` memory.
AssertionError: If :attr:`check_dtype`, but any corresponding tensors do not have the same
:attr:`~torch.Tensor.dtype`.
AssertionError: If :attr:`check_stride`, but any corresponding tensors do not have the same stride.
AssertionError: If the values of any corresponding tensors are not bitwise equal.
AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their lengths do not match.
AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their sets of keys do not match.
.. seealso::
To assert that the values in tensors are close but are not required to be bitwise equal, use
:func:`assert_close` instead.
"""
check_tensors = functools.partial(
_check_tensors_equal,
check_device=check_device,
check_dtype=check_dtype,
check_stride=check_stride,
)
exc = _check_by_type(actual, expected, check_tensors)
if exc:
raise exc
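# Usage sketch with hypothetical tensors: all of the following pass, while any
# mismatch raises an AssertionError naming the offending index or key.
#
#     t = torch.tensor([1, 2])
#     torch.testing.assert_equal(t, t.clone())
#     torch.testing.assert_equal([t, t], [t.clone(), t.clone()])
#     torch.testing.assert_equal({"a": t}, {"a": t.clone()})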
def assert_tensors_close(
actual: torch.Tensor,
expected: torch.Tensor,
def assert_close(
actual: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
expected: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
*,
rtol: Optional[float] = None,
atol: Optional[float] = None,
@@ -346,14 +574,24 @@ def assert_tensors_close(
check_dtype: bool = True,
check_stride: bool = True,
) -> None:
"""Asserts that the values of two tensors are close up to a desired tolerance.
r"""Asserts that the values of tensors are close.
If both tolerances, :attr:`rtol` and :attr:`rtol`, are ``0``, asserts that :attr:`actual` and :attr:`expected` are bitwise
equal. Optionally, checks that some attributes of both tensors are equal.
Closeness is defined by
.. math::
\lvert a - b \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert b \rvert
If both tolerances, :attr:`rtol` and :attr:`rtol`, are ``0``, asserts that :attr:`actual` and :attr:`expected` are
bitwise equal.
Optionally, checks that some attributes of tensors are equal.
Also supports :class:`~collections.abc.Sequence`'s and :class:`~collections.abc.Mapping`'s of :class:`Tensor`'s.
Args:
-        actual (torch.Tensor): Actual tensor.
-        expected (torch.Tensor): Expected tensor.
+        actual (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Actual input.
+        expected (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Expected input.
rtol (Optional[float]): Relative tolerance. If specified, :attr:`atol` must also be specified. If omitted,
default values based on the :attr:`~torch.Tensor.dtype` are selected from the table below.
atol (Optional[float]): Absolute tolerance. If specified, :attr:`rtol` must also be specified. If omitted,
@@ -370,6 +608,7 @@ def assert_tensors_close(
stride.
Raises:
UsageError: If the input pair has an unsupported type.
UsageError: If :attr:`actual` or :attr:`expected` is complex, quantized, or sparse. This is a temporary
restriction and will be relaxed in the future.
AssertionError: If :attr:`actual` and :attr:`expected` do not have the same :attr:`~torch.Tensor.shape`.
@@ -379,11 +618,11 @@ def assert_tensors_close(
:attr:`~torch.Tensor.dtype`.
AssertionError: If :attr:`check_stride`, but :attr:`actual` and :attr:`expected` do not have the same stride.
AssertionError: If the values of :attr:`actual` and :attr:`expected` are not close up to the desired tolerance.
AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their lengths do not match.
AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their sets of keys do not match.
-    The following table displays the default ``rtol`` and ``atol`` for floating point :attr:`~torch.Tensor.dtype`'s.
-    For integer :attr:`~torch.Tensor.dtype`'s, ``rtol = atol = 0.0`` is used.
+    The following table displays the default ``rtol``'s and ``atol``'s. Note that the :class:`~torch.dtype` refers to
+    the promoted type in case :attr:`actual` and :attr:`expected` do not have the same :attr:`~torch.Tensor.dtype`.
+---------------------------+------------+----------+
| :class:`~torch.dtype`     | ``rtol``   | ``atol`` |
+===========================+============+==========+
@@ -402,38 +641,21 @@ def assert_tensors_close(
+---------------------------+------------+----------+
| :attr:`~torch.complex128` | ``1e-7``   | ``1e-7`` |
+---------------------------+------------+----------+
| other                     | ``0.0``    | ``0.0``  |
+---------------------------+------------+----------+
.. seealso::
-        To assert that the values in two tensors are bitwise equal, use :func:`assert_tensors_equal` instead.
+        To assert that the values in tensors are bitwise equal, use :func:`assert_equal` instead.
"""
-    exc: Optional[Exception] = _check_are_tensors(actual, expected)
-    if exc:
-        raise exc
-    exc = _check_supported_tensors(actual, expected)
-    if exc:
-        raise exc
-    if (rtol is None) ^ (atol is None):
-        # We require both tolerance to be omitted or specified, because specifying only one might lead to surprising
-        # results. Imagine setting atol=0.0 and the tensors still match because rtol>0.0.
-        raise UsageError(
-            f"Both 'rtol' and 'atol' must be omitted or specified, " f"but got rtol={rtol} and atol={atol} instead."
-        )
-    elif rtol is None:
-        rtol, atol = _get_default_rtol_and_atol(actual, expected)
-    exc = _check_attributes_equal(
-        actual, expected, check_device=check_device, check_dtype=check_dtype, check_stride=check_stride
+    check_tensors = functools.partial(
+        _check_tensors_close,
+        rtol=rtol,
+        atol=atol,
+        check_device=check_device,
+        check_dtype=check_dtype,
+        check_stride=check_stride,
     )
-    if exc:
-        raise exc
-    actual, expected = _equalize_attributes(actual, expected)
-    if (rtol == 0.0) and (atol == 0.0):
-        exc = _check_values_equal(actual, expected)
-    else:
-        exc = _check_values_close(actual, expected, rtol=rtol, atol=atol)
+    exc = _check_by_type(actual, expected, check_tensors)
if exc:
raise exc
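# Usage sketch with hypothetical values: default tolerances are taken from the
# table above, and explicit tolerances must always be passed as a pair.
#
#     actual = {"logits": torch.tensor([1.0, 2.0])}
#     expected = {"logits": torch.tensor([1.0, 2.0 + 1e-7])}
#
#     torch.testing.assert_close(actual, expected)                        # dtype-based defaults
#     torch.testing.assert_close(actual, expected, rtol=0.0, atol=1e-6)   # both or neither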