mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add support for checking tensor containers in torch.testing
(#55385)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55385 This renames `assert_tensors_(equal|close)` to `_check_tensors_(equal|close)` and exposes two new functions: `assert_(equal|close)`. In addition to tensor pairs, the newly added functions also support the comparison of tensors in sequences or mappings. Otherwise their signature stays the same. Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D27903805 Pulled By: mruberry fbshipit-source-id: 719d19a1d26de8d14cb25846e3d22a6ac828c80a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
bcef7ebd60
commit
dbf3451c6e
@ -1,11 +1,17 @@
|
||||
import torch
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import unittest
|
||||
from typing import Any, Callable, Iterator, List, Mapping, Sequence, Tuple, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
|
||||
from torch.testing._internal.common_utils import \
|
||||
(IS_SANDCASTLE, IS_WINDOWS, TestCase, make_tensor, run_tests, skipIfRocm, slowTest)
|
||||
from torch.testing._internal.framework_utils import calculate_shards
|
||||
@ -740,17 +746,62 @@ if __name__ == '__main__':
|
||||
self.assertNotIn('OK', stderr.decode('ascii'))
|
||||
|
||||
|
||||
T = TypeVar("T", torch.Tensor, Sequence[torch.Tensor], Mapping[Any, torch.Tensor])
|
||||
|
||||
|
||||
class TestAsserts(TestCase):
|
||||
def assert_fns(self):
|
||||
return [torch.testing.assert_tensors_equal, torch.testing.assert_tensors_close]
|
||||
def get_assert_fns(self) -> List[Callable]:
|
||||
"""Gets assert functions to be tested.
|
||||
|
||||
Returns:
|
||||
List(Callable): Top-level assert functions from :mod:`torch.testing`.
|
||||
"""
|
||||
return [torch.testing.assert_equal, torch.testing.assert_close]
|
||||
|
||||
def make_inputs(self, actual: torch.Tensor, expected: torch.Tensor) -> List[Tuple[T, T]]:
|
||||
"""Makes inputs for assert functions based on two example tensors.
|
||||
|
||||
Args:
|
||||
actual (torch.Tensor): Actual tensor.
|
||||
expected (torch.Tensor): Expected tensor.
|
||||
|
||||
Returns:
|
||||
List[Tuple[T, T]]: Pairs of tensors, tensor sequences (:class:`tuple`, :class:`list`), and tensor mappings
|
||||
(:class:`dict`, :class:`~collections.OrderedDict`)
|
||||
"""
|
||||
return [
|
||||
(actual, expected),
|
||||
((actual,), (expected,)),
|
||||
([actual], [expected]),
|
||||
({"t": actual}, {"t": expected}),
|
||||
(collections.OrderedDict([("t", actual)]), collections.OrderedDict([("t", expected)])),
|
||||
]
|
||||
|
||||
def assert_fns_with_inputs(self, actual: torch.Tensor, expected: torch.Tensor) -> Iterator[Callable]:
|
||||
"""Yields assert functions with with included positional inputs based on two example tensors.
|
||||
|
||||
.. note::
|
||||
|
||||
This is a valid product of combinations from :meth:`get_assert_fns` and :meth:`make_inputs`. Every test
|
||||
that does not test for anything specific should iterate over this to maximize the coverage.
|
||||
|
||||
Args:
|
||||
actual (torch.Tensor): Actual tensor.
|
||||
expected (torch.Tensor): Expected tensor.
|
||||
|
||||
Yields:
|
||||
List[Callable]: Assert functions with predefined positional inputs.
|
||||
"""
|
||||
for assert_fn, inputs in itertools.product(self.get_assert_fns(), self.make_inputs(actual, expected)):
|
||||
yield functools.partial(assert_fn, *inputs)
|
||||
|
||||
@onlyCPU
|
||||
def test_not_tensors(self, device):
|
||||
actual = torch.empty((), device=device)
|
||||
expected = np.empty(())
|
||||
|
||||
for fn in self.assert_fns():
|
||||
with self.assertRaises(AssertionError):
|
||||
for fn in self.get_assert_fns():
|
||||
with self.assertRaises(UsageError):
|
||||
fn(actual, expected)
|
||||
|
||||
@onlyCPU
|
||||
@ -758,18 +809,18 @@ class TestAsserts(TestCase):
|
||||
actual = torch.ones(1, dtype=torch.float32, device=device)
|
||||
expected = torch.ones(1, dtype=torch.complex64, device=device)
|
||||
|
||||
for fn in self.assert_fns():
|
||||
for fn in self.assert_fns_with_inputs(actual, expected):
|
||||
with self.assertRaises(UsageError):
|
||||
fn(actual, expected, check_dtype=False)
|
||||
fn(check_dtype=False)
|
||||
|
||||
@onlyCPU
|
||||
def test_sparse_support(self, device):
|
||||
actual = torch.empty((), device=device)
|
||||
expected = torch.sparse_coo_tensor(size=(), device=device)
|
||||
|
||||
for fn in self.assert_fns():
|
||||
for fn in self.assert_fns_with_inputs(actual, expected):
|
||||
with self.assertRaises(UsageError):
|
||||
fn(actual, expected)
|
||||
fn()
|
||||
|
||||
@onlyCPU
|
||||
def test_quantized_support(self, device):
|
||||
@ -778,221 +829,246 @@ class TestAsserts(TestCase):
|
||||
expected = torch._empty_affine_quantized(actual.shape, scale=1, zero_point=0, dtype=torch.qint32, device=device)
|
||||
expected.fill_(val)
|
||||
|
||||
for fn in self.assert_fns():
|
||||
for fn in self.assert_fns_with_inputs(actual, expected):
|
||||
with self.assertRaises(UsageError):
|
||||
fn(actual, expected)
|
||||
fn()
|
||||
|
||||
@onlyCPU
|
||||
def test_mismatching_shape(self, device):
|
||||
actual = torch.empty((), device=device)
|
||||
expected = actual.clone().reshape((1,))
|
||||
|
||||
for fn in self.assert_fns():
|
||||
for fn in self.assert_fns_with_inputs(actual, expected):
|
||||
with self.assertRaisesRegex(AssertionError, "shape"):
|
||||
fn(actual, expected)
|
||||
fn()
|
||||
|
||||
@onlyCUDA
|
||||
def test_mismatching_device(self, device):
|
||||
actual = torch.empty((), device=device)
|
||||
expected = actual.clone().cpu()
|
||||
|
||||
for fn in self.assert_fns():
|
||||
for fn in self.assert_fns_with_inputs(actual, expected):
|
||||
with self.assertRaisesRegex(AssertionError, "device"):
|
||||
fn(actual, expected)
|
||||
fn()
|
||||
|
||||
@onlyCUDA
|
||||
def test_mismatching_device_no_check(self, device):
|
||||
actual = torch.rand((), device=device)
|
||||
expected = actual.clone().cpu()
|
||||
|
||||
for fn in self.assert_fns():
|
||||
fn(actual, expected, check_device=False)
|
||||
for fn in self.assert_fns_with_inputs(actual, expected):
|
||||
fn(check_device=False)
|
||||
|
||||
@onlyCPU
|
||||
def test_mismatching_dtype(self, device):
|
||||
actual = torch.empty((), dtype=torch.float, device=device)
|
||||
expected = actual.clone().to(torch.int)
|
||||
|
||||
for fn in self.assert_fns():
|
||||
for fn in self.assert_fns_with_inputs(actual, expected):
|
||||
with self.assertRaisesRegex(AssertionError, "dtype"):
|
||||
fn(actual, expected)
|
||||
fn()
|
||||
|
||||
@onlyCPU
|
||||
def test_mismatching_dtype_no_check(self, device):
|
||||
actual = torch.ones((), dtype=torch.float, device=device)
|
||||
expected = actual.clone().to(torch.int)
|
||||
|
||||
for fn in self.assert_fns():
|
||||
fn(actual, expected, check_dtype=False)
|
||||
for fn in self.assert_fns_with_inputs(actual, expected):
|
||||
fn(check_dtype=False)
|
||||
|
||||
@onlyCPU
|
||||
def test_mismatching_stride(self, device):
|
||||
actual = torch.empty((2, 2), device=device)
|
||||
expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])
|
||||
|
||||
for fn in self.assert_fns():
|
||||
for fn in self.assert_fns_with_inputs(actual, expected):
|
||||
with self.assertRaisesRegex(AssertionError, "stride"):
|
||||
fn(actual, expected)
|
||||
fn()
|
||||
|
||||
@onlyCPU
|
||||
def test_mismatching_stride_no_check(self, device):
|
||||
actual = torch.rand((2, 2), device=device)
|
||||
expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])
|
||||
|
||||
for fn in self.assert_fns():
|
||||
fn(actual, expected, check_stride=False)
|
||||
for fn in self.assert_fns_with_inputs(actual, expected):
|
||||
fn(check_stride=False)
|
||||
|
||||
@onlyCPU
|
||||
def test_mismatching_values(self, device):
|
||||
actual = torch.tensor(1, device=device)
|
||||
expected = torch.tensor(2, device=device)
|
||||
|
||||
for fn in self.assert_fns():
|
||||
for fn in self.assert_fns_with_inputs(actual, expected):
|
||||
with self.assertRaises(AssertionError):
|
||||
fn()
|
||||
|
||||
@onlyCPU
|
||||
def test_assert_equal(self, device):
|
||||
actual = torch.tensor(1, device=device)
|
||||
expected = actual.clone()
|
||||
|
||||
torch.testing.assert_equal(actual, expected)
|
||||
|
||||
@onlyCPU
|
||||
def test_assert_close(self, device):
|
||||
actual = torch.tensor(1.0, device=device)
|
||||
expected = actual.clone()
|
||||
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
@onlyCPU
|
||||
def test_assert_close_only_rtol(self, device):
|
||||
actual = torch.empty((), device=device)
|
||||
expected = actual.clone()
|
||||
|
||||
with self.assertRaises(UsageError):
|
||||
torch.testing.assert_close(actual, expected, rtol=0.0)
|
||||
|
||||
@onlyCPU
|
||||
def test_assert_close_only_atol(self, device):
|
||||
actual = torch.empty((), device=device)
|
||||
expected = actual.clone()
|
||||
|
||||
with self.assertRaises(UsageError):
|
||||
torch.testing.assert_close(actual, expected, atol=0.0)
|
||||
|
||||
@onlyCPU
|
||||
def test_assert_close_mismatching_values_rtol(self, device):
|
||||
eps = 1e-3
|
||||
actual = torch.tensor(1.0, device=device)
|
||||
expected = torch.tensor(1.0 + eps, device=device)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(actual, expected, rtol=eps / 2, atol=0.0)
|
||||
|
||||
@onlyCPU
|
||||
def test_assert_close_matching_values_rtol(self, device):
|
||||
eps = 1e-3
|
||||
actual = torch.tensor(1.0, device=device)
|
||||
expected = torch.tensor(1.0 + eps, device=device)
|
||||
|
||||
torch.testing.assert_close(actual, expected, rtol=eps * 2, atol=0.0)
|
||||
|
||||
@onlyCPU
|
||||
def test_assert_close_mismatching_values_atol(self, device):
|
||||
eps = 1e-3
|
||||
actual = torch.tensor(0.0, device=device)
|
||||
expected = torch.tensor(eps, device=device)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(actual, expected, rtol=0.0, atol=eps / 2)
|
||||
|
||||
@onlyCPU
|
||||
def test_assert_close_matching_values_atol(self, device):
|
||||
eps = 1e-3
|
||||
actual = torch.tensor(0.0, device=device)
|
||||
expected = torch.tensor(eps, device=device)
|
||||
|
||||
torch.testing.assert_close(actual, expected, rtol=0.0, atol=eps * 2)
|
||||
|
||||
@onlyCPU
|
||||
def test_mismatching_values_msg_mismatches(self, device):
|
||||
actual = torch.tensor([1, 2, 3, 4], device=device)
|
||||
expected = torch.tensor([1, 2, 5, 6], device=device)
|
||||
|
||||
for fn in self.assert_fns_with_inputs(actual, expected):
|
||||
with self.assertRaisesRegex(AssertionError, re.escape("Mismatched elements: 2 / 4 (50.0%)")):
|
||||
fn()
|
||||
|
||||
@onlyCPU
|
||||
def test_mismatching_values_msg_abs_diff(self, device):
|
||||
actual = torch.tensor([[1, 2], [3, 4]], device=device)
|
||||
expected = torch.tensor([[1, 2], [5, 4]], device=device)
|
||||
|
||||
for fn in self.assert_fns_with_inputs(actual, expected):
|
||||
with self.assertRaisesRegex(AssertionError, re.escape("Greatest absolute difference: 2 at (1, 0)")):
|
||||
fn()
|
||||
|
||||
@onlyCPU
|
||||
def test_mismatching_values_msg_rel_diff(self, device):
|
||||
actual = torch.tensor([[1, 2], [3, 4]], device=device)
|
||||
expected = torch.tensor([[1, 4], [3, 4]], device=device)
|
||||
|
||||
for fn in self.assert_fns_with_inputs(actual, expected):
|
||||
with self.assertRaisesRegex(AssertionError, re.escape("Greatest relative difference: 0.5 at (0, 1)")):
|
||||
fn()
|
||||
|
||||
@onlyCPU
|
||||
def test_assert_close_mismatching_values_msg_rtol(self, device):
|
||||
rtol = 1e-3
|
||||
|
||||
actual = torch.tensor(1, device=device)
|
||||
expected = torch.tensor(2, device=device)
|
||||
|
||||
for inputs in self.make_inputs(actual, expected):
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError, re.escape(f"Greatest relative difference: 0.5 at 0 (up to {rtol} allowed)")
|
||||
):
|
||||
torch.testing.assert_close(*inputs, rtol=rtol, atol=0.0)
|
||||
|
||||
@onlyCPU
|
||||
def test_assert_close_mismatching_values_msg_atol(self, device):
|
||||
atol = 1e-3
|
||||
|
||||
actual = torch.tensor(1, device=device)
|
||||
expected = torch.tensor(2, device=device)
|
||||
|
||||
for inputs in self.make_inputs(actual, expected):
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError, re.escape(f"Greatest absolute difference: 1 at 0 (up to {atol} allowed)")
|
||||
):
|
||||
torch.testing.assert_close(*inputs, rtol=0.0, atol=atol)
|
||||
|
||||
@onlyCPU
|
||||
def test_unknown_type(self, device):
|
||||
actual = torch.empty((), device=device)
|
||||
expected = {actual.clone()}
|
||||
|
||||
for fn in self.get_assert_fns():
|
||||
with self.assertRaisesRegex(UsageError, str(type(expected))):
|
||||
fn(actual, expected)
|
||||
|
||||
@onlyCPU
|
||||
def test_sequence_mismatching_len(self, device):
|
||||
actual = (torch.empty((), device=device),)
|
||||
expected = ()
|
||||
|
||||
for fn in self.get_assert_fns():
|
||||
with self.assertRaises(AssertionError):
|
||||
fn(actual, expected)
|
||||
|
||||
@onlyCPU
|
||||
def test_mismatching_values_msg_abs_mismatches(self, device):
|
||||
actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
|
||||
expected = actual.clone()
|
||||
def test_sequence_mismatching_values_msg(self, device):
|
||||
t1 = torch.tensor(1, device=device)
|
||||
t2 = torch.tensor(2, device=device)
|
||||
|
||||
actual[0, 1] = 1
|
||||
expected[0, 1] = 2
|
||||
expected[1, 2] = 9
|
||||
actual = (t1, t1)
|
||||
expected = (t1, t2)
|
||||
|
||||
for fn in self.assert_fns():
|
||||
with self.assertRaisesRegex(AssertionError, r"\s+2\s+"):
|
||||
for fn in self.get_assert_fns():
|
||||
with self.assertRaisesRegex(AssertionError, r"index\s+1"):
|
||||
fn(actual, expected)
|
||||
|
||||
@onlyCPU
|
||||
def test_mismatching_values_msg_rel_mismatches(self, device):
|
||||
actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
|
||||
expected = actual.clone()
|
||||
def test_mapping_mismatching_keys(self, device):
|
||||
actual = {"a": torch.empty((), device=device)}
|
||||
expected = {}
|
||||
|
||||
actual[0, 1] = 1
|
||||
expected[0, 1] = 2
|
||||
expected[1, 2] = 9
|
||||
|
||||
for fn in self.assert_fns():
|
||||
with self.assertRaisesRegex(AssertionError, r"22([.]2+)?\s*[%]"):
|
||||
for fn in self.get_assert_fns():
|
||||
with self.assertRaises(AssertionError):
|
||||
fn(actual, expected)
|
||||
|
||||
@onlyCPU
|
||||
def test_mismatching_values_msg_max_abs_diff(self, device):
|
||||
actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
|
||||
expected = actual.clone()
|
||||
def test_mapping_mismatching_values_msg(self, device):
|
||||
t1 = torch.tensor(1, device=device)
|
||||
t2 = torch.tensor(2, device=device)
|
||||
|
||||
actual[0, 1] = 1
|
||||
expected[0, 1] = 2
|
||||
expected[1, 2] = 9
|
||||
actual = {"a": t1, "b": t1}
|
||||
expected = {"a": t1, "b": t2}
|
||||
|
||||
for fn in self.assert_fns():
|
||||
with self.assertRaisesRegex(AssertionError, r"\s+4[.]0\s+"):
|
||||
for fn in self.get_assert_fns():
|
||||
with self.assertRaisesRegex(AssertionError, r"key\s+'b'"):
|
||||
fn(actual, expected)
|
||||
|
||||
@onlyCPU
|
||||
def test_mismatching_values_max_abs_diff_idx(self, device):
|
||||
actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
|
||||
expected = actual.clone()
|
||||
|
||||
actual[0, 1] = 1
|
||||
expected[0, 1] = 2
|
||||
expected[1, 2] = 9
|
||||
|
||||
for fn in self.assert_fns():
|
||||
with self.assertRaisesRegex(AssertionError, r"1,\s*2"):
|
||||
fn(actual, expected)
|
||||
|
||||
@onlyCPU
|
||||
def test_mismatching_values_msg_max_rel_diff(self, device):
|
||||
actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
|
||||
expected = actual.clone()
|
||||
|
||||
actual[0, 1] = 1
|
||||
expected[0, 1] = 2
|
||||
expected[1, 2] = 9
|
||||
|
||||
for fn in self.assert_fns():
|
||||
with self.assertRaisesRegex(AssertionError, r"\s+0[.]5\s+"):
|
||||
fn(actual, expected)
|
||||
|
||||
@onlyCPU
|
||||
def test_mismatching_values_max_rel_diff_idx(self, device):
|
||||
actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
|
||||
expected = actual.clone()
|
||||
|
||||
actual[0, 1] = 1
|
||||
expected[0, 1] = 2
|
||||
expected[1, 2] = 9
|
||||
|
||||
for fn in self.assert_fns():
|
||||
with self.assertRaisesRegex(AssertionError, r"0,\s*1"):
|
||||
fn(actual, expected)
|
||||
|
||||
@onlyCPU
|
||||
def test_assert_tensors_equal(self, device):
|
||||
actual = torch.tensor(1, device=device)
|
||||
expected = actual.clone()
|
||||
|
||||
torch.testing.assert_tensors_equal(actual, expected)
|
||||
|
||||
@onlyCPU
|
||||
def test_assert_tensors_close(self, device):
|
||||
actual = torch.tensor(1.0, device=device)
|
||||
expected = actual.clone()
|
||||
|
||||
torch.testing.assert_tensors_close(actual, expected)
|
||||
|
||||
@onlyCPU
|
||||
def test_assert_tensors_close_only_rtol(self, device):
|
||||
actual = torch.empty((), device=device)
|
||||
expected = actual.clone()
|
||||
|
||||
with self.assertRaises(UsageError):
|
||||
torch.testing.assert_tensors_close(actual, expected, rtol=0.0)
|
||||
|
||||
@onlyCPU
|
||||
def test_assert_tensors_close_only_atol(self, device):
|
||||
actual = torch.empty((), device=device)
|
||||
expected = actual.clone()
|
||||
|
||||
with self.assertRaises(UsageError):
|
||||
torch.testing.assert_tensors_close(actual, expected, atol=0.0)
|
||||
|
||||
@onlyCPU
|
||||
def test_assert_tensors_close_mismatching_values_rtol(self, device):
|
||||
eps = 1e-3
|
||||
actual = torch.tensor(1.0, device=device)
|
||||
expected = torch.tensor(1.0 + eps, device=device)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_tensors_close(actual, expected, rtol=eps / 2, atol=0.0)
|
||||
|
||||
@onlyCPU
|
||||
def test_assert_tensors_close_matching_values_rtol(self, device):
|
||||
eps = 1e-3
|
||||
actual = torch.tensor(1.0, device=device)
|
||||
expected = torch.tensor(1.0 + eps, device=device)
|
||||
|
||||
torch.testing.assert_tensors_close(actual, expected, rtol=eps * 2, atol=0.0)
|
||||
|
||||
@onlyCPU
|
||||
def test_assert_tensors_close_mismatching_values_atol(self, device):
|
||||
eps = 1e-3
|
||||
actual = torch.tensor(0.0, device=device)
|
||||
expected = torch.tensor(eps, device=device)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_tensors_close(actual, expected, rtol=0.0, atol=eps / 2)
|
||||
|
||||
@onlyCPU
|
||||
def test_assert_tensors_close_matching_values_atol(self, device):
|
||||
eps = 1e-3
|
||||
actual = torch.tensor(0.0, device=device)
|
||||
expected = torch.tensor(eps, device=device)
|
||||
|
||||
torch.testing.assert_tensors_close(actual, expected, rtol=0.0, atol=eps * 2)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestAsserts, globals())
|
||||
|
||||
|
@ -1,12 +1,15 @@
|
||||
import collections.abc
|
||||
import functools
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
from typing import Any, Optional, Tuple, Type
|
||||
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union, cast
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from ._core import _unravel_index
|
||||
|
||||
__all__ = ["assert_tensors_equal", "assert_tensors_close"]
|
||||
__all__ = ["assert_equal", "assert_close"]
|
||||
|
||||
|
||||
# The UsageError should be raised in case the test function is not used correctly. With this the user is able to
|
||||
@ -42,7 +45,7 @@ _DTYPE_PRECISIONS = {
|
||||
}
|
||||
|
||||
|
||||
def _get_default_rtol_and_atol(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[float, float]:
|
||||
def _get_default_rtol_and_atol(actual: Tensor, expected: Tensor) -> Tuple[float, float]:
|
||||
dtype = actual.dtype if actual.dtype == expected.dtype else torch.promote_types(actual.dtype, expected.dtype)
|
||||
return _DTYPE_PRECISIONS.get(dtype, (0.0, 0.0))
|
||||
|
||||
@ -57,23 +60,23 @@ def _check_are_tensors(actual: Any, expected: Any) -> Optional[AssertionError]:
|
||||
Returns:
|
||||
(Optional[AssertionError]): If check did not pass.
|
||||
"""
|
||||
if not (isinstance(actual, torch.Tensor) and isinstance(expected, torch.Tensor)):
|
||||
if not (isinstance(actual, Tensor) and isinstance(expected, Tensor)):
|
||||
return AssertionError(f"Both inputs have to be tensors, but got {type(actual)} and {type(expected)} instead.")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _check_supported_tensors(
|
||||
actual: torch.Tensor,
|
||||
expected: torch.Tensor,
|
||||
actual: Tensor,
|
||||
expected: Tensor,
|
||||
) -> Optional[UsageError]: # type: ignore[valid-type]
|
||||
"""Checks if the tensors are supported by the current infrastructure.
|
||||
|
||||
All checks are temporary and will be relaxed in the future.
|
||||
|
||||
Args:
|
||||
actual (torch.Tensor): Actual tensor.
|
||||
expected (torch.Tensor): Expected tensor.
|
||||
actual (Tensor): Actual tensor.
|
||||
expected (Tensor): Expected tensor.
|
||||
|
||||
Returns:
|
||||
(Optional[UsageError]): If check did not pass.
|
||||
@ -89,8 +92,8 @@ def _check_supported_tensors(
|
||||
|
||||
|
||||
def _check_attributes_equal(
|
||||
actual: torch.Tensor,
|
||||
expected: torch.Tensor,
|
||||
actual: Tensor,
|
||||
expected: Tensor,
|
||||
*,
|
||||
check_device: bool = True,
|
||||
check_dtype: bool = True,
|
||||
@ -102,8 +105,8 @@ def _check_attributes_equal(
|
||||
:attr:`~torch.Tensor.dtype`, and :meth:`~torch.Tensor.stride` are optional and can be disabled.
|
||||
|
||||
Args:
|
||||
actual (torch.Tensor): Actual tensor.
|
||||
expected (torch.Tensor): Expected tensor.
|
||||
actual (Tensor): Actual tensor.
|
||||
expected (Tensor): Expected tensor.
|
||||
check_device (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` are on the
|
||||
same :attr:`~torch.Tensor.device` memory.
|
||||
check_dtype (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` have the same
|
||||
@ -112,7 +115,7 @@ def _check_attributes_equal(
|
||||
:meth:`~torch.Tensor.stride`.
|
||||
|
||||
Returns:
|
||||
(Optional[AssertionError]): If check did not pass.
|
||||
(Optional[AssertionError]): If checks did not pass.
|
||||
"""
|
||||
msg_fmtstr = "The values for attribute '{}' do not match: {} != {}."
|
||||
|
||||
@ -131,7 +134,7 @@ def _check_attributes_equal(
|
||||
return None
|
||||
|
||||
|
||||
def _equalize_attributes(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _equalize_attributes(actual: Tensor, expected: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
"""Equalizes some attributes of two tensors for value comparison.
|
||||
|
||||
If :attr:`actual` and :attr:`expected`
|
||||
@ -140,11 +143,11 @@ def _equalize_attributes(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[
|
||||
:func:`torch.promote_types`.
|
||||
|
||||
Args:
|
||||
actual (torch.Tensor): Actual tensor.
|
||||
expected (torch.Tensor): Expected tensor.
|
||||
actual (Tensor): Actual tensor.
|
||||
expected (Tensor): Expected tensor.
|
||||
|
||||
Returns:
|
||||
Tuple(torch.Tensor, torch.Tensor): Equalized tensors.
|
||||
Tuple(Tensor, Tensor): Equalized tensors.
|
||||
"""
|
||||
if actual.device != expected.device:
|
||||
actual = actual.cpu()
|
||||
@ -172,13 +175,13 @@ _Trace = namedtuple(
|
||||
)
|
||||
|
||||
|
||||
def _trace_mismatches(actual: torch.Tensor, expected: torch.Tensor, mismatches: torch.Tensor) -> _Trace:
|
||||
def _trace_mismatches(actual: Tensor, expected: Tensor, mismatches: Tensor) -> _Trace:
|
||||
"""Traces mismatches.
|
||||
|
||||
Args:
|
||||
actual (torch.Tensor): Actual tensor.
|
||||
expected (torch.Tensor): Expected tensor.
|
||||
mismatches (torch.Tensor): Boolean mask of the same shape as :attr:`actual` and :attr:`expected` that indicates
|
||||
actual (Tensor): Actual tensor.
|
||||
expected (Tensor): Expected tensor.
|
||||
mismatches (Tensor): Boolean mask of the same shape as :attr:`actual` and :attr:`expected` that indicates
|
||||
the location of mismatches.
|
||||
|
||||
Returns:
|
||||
@ -220,12 +223,12 @@ def _trace_mismatches(actual: torch.Tensor, expected: torch.Tensor, mismatches:
|
||||
)
|
||||
|
||||
|
||||
def _check_values_equal(actual: torch.Tensor, expected: torch.Tensor) -> Optional[AssertionError]:
|
||||
def _check_values_equal(actual: Tensor, expected: Tensor) -> Optional[AssertionError]:
|
||||
"""Checks if the values of two tensors are bitwise equal.
|
||||
|
||||
Args:
|
||||
actual (torch.Tensor): Actual tensor.
|
||||
expected (torch.Tensor): Expected tensor.
|
||||
actual (Tensor): Actual tensor.
|
||||
expected (Tensor): Expected tensor.
|
||||
|
||||
Returns:
|
||||
(Optional[AssertionError]): If check did not pass.
|
||||
@ -244,8 +247,8 @@ def _check_values_equal(actual: torch.Tensor, expected: torch.Tensor) -> Optiona
|
||||
|
||||
|
||||
def _check_values_close(
|
||||
actual: torch.Tensor,
|
||||
expected: torch.Tensor,
|
||||
actual: Tensor,
|
||||
expected: Tensor,
|
||||
*,
|
||||
rtol,
|
||||
atol,
|
||||
@ -253,8 +256,8 @@ def _check_values_close(
|
||||
"""Checks if the values of two tensors are close up to a desired tolerance.
|
||||
|
||||
Args:
|
||||
actual (torch.Tensor): Actual tensor.
|
||||
expected (torch.Tensor): Expected tensor.
|
||||
actual (Tensor): Actual tensor.
|
||||
expected (Tensor): Expected tensor.
|
||||
rtol (float): Relative tolerance.
|
||||
atol (float): Absolute tolerance.
|
||||
|
||||
@ -274,55 +277,89 @@ def _check_values_close(
|
||||
)
|
||||
|
||||
|
||||
def assert_tensors_equal(
|
||||
actual: torch.Tensor,
|
||||
expected: torch.Tensor,
|
||||
def _check_tensors_equal(
|
||||
actual: Tensor,
|
||||
expected: Tensor,
|
||||
*,
|
||||
check_device: bool = True,
|
||||
check_dtype: bool = True,
|
||||
check_stride: bool = True,
|
||||
) -> None:
|
||||
"""Asserts that the values of two tensors are bitwise equal.
|
||||
) -> Optional[Exception]:
|
||||
"""Checks that the values of two tensors are bitwise equal.
|
||||
|
||||
Optionally, checks that some attributes of both tensors are equal.
|
||||
|
||||
Args:
|
||||
actual (torch.Tensor): Actual tensor.
|
||||
expected (torch.Tensor): Expected tensor.
|
||||
check_device (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` are on the
|
||||
same :attr:`~torch.Tensor.device` memory. If this check is disabled **and** :attr:`actual` and
|
||||
:attr:`expected` are not on the same memory :attr:`~torch.Tensor.device`, they are moved CPU memory before
|
||||
their values are compared.
|
||||
check_dtype (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` have the same
|
||||
:attr:`~torch.Tensor.dtype`. If this check is disabled **and** :attr:`actual` and :attr:`expected` do not
|
||||
have the same :attr:`~torch.Tensor.dtype`, they are copied to the :class:`~torch.dtype` returned by
|
||||
:func:`torch.promote_types` before their values are compared.
|
||||
check_stride (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` have the same
|
||||
stride.
|
||||
For a description of the parameters see :func:`assert_equal`.
|
||||
|
||||
Raises:
|
||||
UsageError: If :attr:`actual` or :attr:`expected` is complex, quantized, or sparse. This is a temporary
|
||||
restriction and will be relaxed in the future.
|
||||
AssertionError: If :attr:`actual` and :attr:`expected` do not have the same :attr:`~torch.Tensor.shape`.
|
||||
AssertionError: If :attr:`check_device`, but :attr:`actual` and :attr:`expected` are not on the same
|
||||
:attr:`~torch.Tensor.device` memory.
|
||||
AssertionError: If :attr:`check_dtype`, but :attr:`actual` and :attr:`expected` do not have the same
|
||||
:attr:`~torch.Tensor.dtype`.
|
||||
AssertionError: If :attr:`check_stride`, but :attr:`actual` and :attr:`expected` do not have the same stride.
|
||||
AssertionError: If the values of :attr:`actual` and :attr:`expected` are not bitwise equal.
|
||||
|
||||
.. seealso::
|
||||
|
||||
To assert that the values in two tensors are are close but are not required to be bitwise equal, use
|
||||
:func:`assert_tensors_close` instead.
|
||||
Returns:
|
||||
Optional[Exception]: If checks did not pass.
|
||||
"""
|
||||
exc: Optional[Exception] = _check_are_tensors(actual, expected)
|
||||
if exc:
|
||||
raise exc
|
||||
return exc
|
||||
|
||||
exc = _check_supported_tensors(actual, expected)
|
||||
if exc:
|
||||
raise exc
|
||||
return exc
|
||||
|
||||
exc = _check_attributes_equal(
|
||||
actual, expected, check_device=check_device, check_dtype=check_dtype, check_stride=check_stride
|
||||
)
|
||||
if exc:
|
||||
return exc
|
||||
actual, expected = _equalize_attributes(actual, expected)
|
||||
|
||||
exc = _check_values_equal(actual, expected)
|
||||
if exc:
|
||||
return exc
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _check_tensors_close(
|
||||
actual: Tensor,
|
||||
expected: Tensor,
|
||||
*,
|
||||
rtol: Optional[float] = None,
|
||||
atol: Optional[float] = None,
|
||||
check_device: bool = True,
|
||||
check_dtype: bool = True,
|
||||
check_stride: bool = True,
|
||||
) -> Optional[Exception]:
|
||||
r"""Checks that the values of two tensors are close.
|
||||
|
||||
Closeness is defined by
|
||||
|
||||
.. math::
|
||||
|
||||
\lvert a - b \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert b \rvert
|
||||
|
||||
If both tolerances, :attr:`rtol` and :attr:`rtol`, are ``0``, asserts that :attr:`actual` and :attr:`expected` are
|
||||
bitwise equal.
|
||||
|
||||
Optionally, checks that some attributes of both tensors are equal.
|
||||
|
||||
For a description of the parameters see :func:`assert_equal`.
|
||||
|
||||
Returns:
|
||||
Optional[Exception]: If checks did not pass.
|
||||
"""
|
||||
exc: Optional[Exception] = _check_are_tensors(actual, expected)
|
||||
if exc:
|
||||
return exc
|
||||
|
||||
exc = _check_supported_tensors(actual, expected)
|
||||
if exc:
|
||||
return exc
|
||||
|
||||
if (rtol is None) ^ (atol is None):
|
||||
# We require both tolerance to be omitted or specified, because specifying only one might lead to surprising
|
||||
# results. Imagine setting atol=0.0 and the tensors still match because rtol>0.0.
|
||||
return UsageError(
|
||||
f"Both 'rtol' and 'atol' must be omitted or specified, " f"but got rtol={rtol} and atol={atol} instead."
|
||||
)
|
||||
elif rtol is None:
|
||||
rtol, atol = _get_default_rtol_and_atol(actual, expected)
|
||||
|
||||
exc = _check_attributes_equal(
|
||||
actual, expected, check_device=check_device, check_dtype=check_dtype, check_stride=check_stride
|
||||
@ -331,14 +368,205 @@ def assert_tensors_equal(
|
||||
raise exc
|
||||
actual, expected = _equalize_attributes(actual, expected)
|
||||
|
||||
exc = _check_values_equal(actual, expected)
|
||||
if (rtol == 0.0) and (atol == 0.0):
|
||||
exc = _check_values_equal(actual, expected)
|
||||
else:
|
||||
exc = _check_values_close(actual, expected, rtol=rtol, atol=atol)
|
||||
if exc:
|
||||
return exc
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _check_by_type(
|
||||
actual: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
|
||||
expected: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
|
||||
check_tensors: Callable[[Tensor, Tensor], Optional[Exception]],
|
||||
) -> Optional[Exception]:
|
||||
"""Delegates tensor checking based on the inputs types.
|
||||
|
||||
Currently supports pairs of
|
||||
|
||||
- :class:`Tensor`'s,
|
||||
- :class:`~collections.abc.Sequence`'s of :class:`Tensor`'s, and
|
||||
- :class:`~collections.abc.Mapping`'s of :class:`Tensor`'s.
|
||||
|
||||
Args:
|
||||
actual (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Actual input.
|
||||
expected (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Expected input.
|
||||
check_tensors (Callable[[Tensor, Tensor], Optional[Exception]]): Callable used to check if two tensors match.
|
||||
In case they mismatch should return an :class:`Exception` with an expressive error message.
|
||||
|
||||
Returns:
|
||||
(Optional[Exception]): :class:`UsageError` if the inputs types are unsupported. Additionally, any exception
|
||||
returned by :attr:`check_tensors`.
|
||||
"""
|
||||
# _check_are_tensors() returns nothing in case both inputs are tensors and an exception otherwise. Thus, the logic
|
||||
# is inverted here.
|
||||
are_tensors = not _check_are_tensors(actual, expected)
|
||||
if are_tensors:
|
||||
return check_tensors(cast(Tensor, actual), cast(Tensor, expected))
|
||||
|
||||
if isinstance(actual, collections.abc.Sequence) and isinstance(expected, collections.abc.Sequence):
|
||||
return _check_sequence(actual, expected, check_tensors)
|
||||
elif isinstance(actual, collections.abc.Mapping) and isinstance(expected, collections.abc.Mapping):
|
||||
return _check_mapping(actual, expected, check_tensors)
|
||||
|
||||
return UsageError(
|
||||
f"Both inputs have to be tensors, or sequences or mappings of tensors, "
|
||||
f"but got {type(actual)} and {type(expected)} instead."
|
||||
)
|
||||
|
||||
|
||||
E = TypeVar("E", bound=Exception)
|
||||
|
||||
|
||||
def _amend_error_message(exc: E, msg_fmtstr: str) -> E:
|
||||
"""Amends an exception message.
|
||||
|
||||
Args:
|
||||
exc (E): Exception.
|
||||
msg_fmtstr: Format string for the amended message.
|
||||
|
||||
Returns:
|
||||
(E): New exception with amended error message.
|
||||
"""
|
||||
return type(exc)(msg_fmtstr.format(str(exc)))
|
||||
|
||||
|
||||
_SEQUENCE_MSG_FMTSTR = "The failure occurred at index {} of the sequences."
|
||||
|
||||
|
||||
def _check_sequence(
|
||||
actual: Sequence[Tensor], expected: Sequence[Tensor], check_tensors: Callable[[Tensor, Tensor], Optional[Exception]]
|
||||
) -> Optional[Exception]:
|
||||
"""Checks if the values of two sequences of tensors match.
|
||||
|
||||
Args:
|
||||
actual (Sequence[Tensor]): Actual sequence of tensors.
|
||||
expected (Sequence[Tensor]): Expected sequence of tensors.
|
||||
check_tensors (Callable[[Tensor, Tensor], Optional[Exception]]): Callable used to check if the items of
|
||||
:attr:`actual` and :attr:`expected` match. In case they mismatch should return an :class:`Exception` with
|
||||
an expressive error message.
|
||||
|
||||
Returns:
|
||||
Optional[Exception]: :class:`AssertionError` if the sequences do not have the same length. Additionally, any
|
||||
exception returned by :attr:`check_tensors`. In this case, the error message is amended to include the
|
||||
first offending index.
|
||||
"""
|
||||
actual_len = len(actual)
|
||||
expected_len = len(expected)
|
||||
if actual_len != expected_len:
|
||||
return AssertionError(f"The length of the sequences mismatch: {actual_len} != {expected_len}")
|
||||
for idx, (actual_t, expected_t) in enumerate(zip(actual, expected)):
|
||||
exc = check_tensors(actual_t, expected_t)
|
||||
if exc:
|
||||
return _amend_error_message(exc, f"{{}}\n\n{_SEQUENCE_MSG_FMTSTR.format(idx)}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
_MAPPING_MSG_FMTSTR = "The failure occurred for key '{}' of the mappings."
|
||||
|
||||
|
||||
def _check_mapping(
|
||||
actual: Mapping[Any, Tensor],
|
||||
expected: Mapping[Any, Tensor],
|
||||
check_tensors: Callable[[Tensor, Tensor], Optional[Exception]],
|
||||
) -> Optional[Exception]:
|
||||
"""Checks if the values of two mappings of tensors match.
|
||||
|
||||
Args:
|
||||
actual (Mapping[Any, Tensor]): First mapping of tensors.
|
||||
expected (Mapping[Any, Tensor]): Second mapping of tensors.
|
||||
check_tensors (Callable[[Tensor, Tensor], Optional[Exception]]): Callable used to check if the values of
|
||||
:attr:`actual` and :attr:`expected` match. In case they mismatch should return an :class:`Exception` with
|
||||
an expressive error message.
|
||||
|
||||
Returns:
|
||||
Optional[Exception]: :class:`AssertionError` if the sequences do not have the same set of keys. Additionally,
|
||||
any exception returned by :attr:`check_tensors`. In this case, the error message is amended to include the
|
||||
first offending key.
|
||||
"""
|
||||
actual_keys = set(actual.keys())
|
||||
expected_keys = set(expected.keys())
|
||||
if actual_keys != expected_keys:
|
||||
missing_keys = expected_keys - actual_keys
|
||||
additional_keys = actual_keys - expected_keys
|
||||
return AssertionError(
|
||||
f"The keys of the mappings do not match:\n\n"
|
||||
f"Missing keys in the actual mapping: {sorted(missing_keys)}\n"
|
||||
f"Additional keys in the actual mapping: {sorted(additional_keys)}\n"
|
||||
)
|
||||
for key in sorted(actual_keys):
|
||||
actual_t = actual[key]
|
||||
expected_t = expected[key]
|
||||
|
||||
exc = check_tensors(actual_t, expected_t)
|
||||
if exc:
|
||||
return _amend_error_message(exc, f"{{}}\n\n{_MAPPING_MSG_FMTSTR.format(key)}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def assert_equal(
|
||||
actual: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
|
||||
expected: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
|
||||
*,
|
||||
check_device: bool = True,
|
||||
check_dtype: bool = True,
|
||||
check_stride: bool = True,
|
||||
) -> None:
|
||||
"""Asserts that the values of tensors are bitwise equal.
|
||||
|
||||
Optionally, checks that some attributes of tensors are equal.
|
||||
|
||||
Also supports :class:`~collections.abc.Sequence`'s and :class:`~collections.abc.Mapping`'s of :class:`Tensor`'s.
|
||||
|
||||
Args:
|
||||
actual (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Actual input.
|
||||
expected (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Expected input.
|
||||
check_device (bool): If ``True`` (default), asserts that tensors live in the same :attr:`~torch.Tensor.device`
|
||||
memory. If this check is disabled **and** they do not live in the same memory :attr:`~torch.Tensor.device`,
|
||||
they are moved CPU memory before their values are compared.
|
||||
check_dtype (bool): If ``True`` (default), asserts that tensors have the same :attr:`~torch.Tensor.dtype`. If
|
||||
this check is disabled they do not have the same :attr:`~torch.Tensor.dtype`, they are copied to the
|
||||
:class:`~torch.dtype` returned by :func:`torch.promote_types` before their values are compared.
|
||||
check_stride (bool): If ``True`` (default), asserts that the tensors have the same stride.
|
||||
|
||||
Raises:
|
||||
UsageError: If the input pair has an unsupported type.
|
||||
UsageError: If any tensor is complex, quantized, or sparse. This is a temporary restriction and
|
||||
will be relaxed in the future.
|
||||
AssertionError: If any corresponding tensors do not have the same :attr:`~torch.Tensor.shape`.
|
||||
AssertionError: If :attr:`check_device`, but any corresponding tensors do not live in the same
|
||||
:attr:`~torch.Tensor.device` memory.
|
||||
AssertionError: If :attr:`check_dtype`, but any corresponding tensors do not have the same
|
||||
:attr:`~torch.Tensor.dtype`.
|
||||
AssertionError: If :attr:`check_stride`, but any corresponding tensors do not have the same stride.
|
||||
AssertionError: If the values of any corresponding tensors are not bitwise equal.
|
||||
AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their length does not match.
|
||||
AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their set of keys mismatch.
|
||||
|
||||
.. seealso::
|
||||
|
||||
To assert that the values in tensors are close but are not required to be bitwise equal, use
|
||||
:func:`assert_close` instead.
|
||||
"""
|
||||
check_tensors = functools.partial(
|
||||
_check_tensors_equal,
|
||||
check_device=check_device,
|
||||
check_dtype=check_dtype,
|
||||
check_stride=check_stride,
|
||||
)
|
||||
exc = _check_by_type(actual, expected, check_tensors)
|
||||
if exc:
|
||||
raise exc
|
||||
|
||||
|
||||
def assert_tensors_close(
|
||||
actual: torch.Tensor,
|
||||
expected: torch.Tensor,
|
||||
def assert_close(
|
||||
actual: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
|
||||
expected: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
|
||||
*,
|
||||
rtol: Optional[float] = None,
|
||||
atol: Optional[float] = None,
|
||||
@ -346,14 +574,24 @@ def assert_tensors_close(
|
||||
check_dtype: bool = True,
|
||||
check_stride: bool = True,
|
||||
) -> None:
|
||||
"""Asserts that the values of two tensors are close up to a desired tolerance.
|
||||
r"""Asserts that the values of tensors are close.
|
||||
|
||||
If both tolerances, :attr:`rtol` and :attr:`rtol`, are ``0``, asserts that :attr:`actual` and :attr:`expected` are bitwise
|
||||
equal. Optionally, checks that some attributes of both tensors are equal.
|
||||
Closeness is defined by
|
||||
|
||||
.. math::
|
||||
|
||||
\lvert a - b \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert b \rvert
|
||||
|
||||
If both tolerances, :attr:`rtol` and :attr:`rtol`, are ``0``, asserts that :attr:`actual` and :attr:`expected` are
|
||||
bitwise equal.
|
||||
|
||||
Optionally, checks that some attributes of tensors are equal.
|
||||
|
||||
Also supports :class:`~collections.abc.Sequence`'s and :class:`~collections.abc.Mapping`'s of :class:`Tensor`'s.
|
||||
|
||||
Args:
|
||||
actual (torch.Tensor): Actual tensor.
|
||||
expected (torch.Tensor): Expected tensor.
|
||||
actual (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Actual input.
|
||||
expected (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Expected input.
|
||||
rtol (Optional[float]): Relative tolerance. If specified :attr:`atol` must also be specified. If omitted,
|
||||
default values based on the :attr:`~torch.Tensor.dtype` are selected with the below table.
|
||||
atol (Optional[float]): Absolute tolerance. If specified :attr:`rtol` must also be specified. If omitted,
|
||||
@ -370,6 +608,7 @@ def assert_tensors_close(
|
||||
stride.
|
||||
|
||||
Raises:
|
||||
UsageError: If the input pair has an unsupported type.
|
||||
UsageError: If :attr:`actual` or :attr:`expected` is complex, quantized, or sparse. This is a temporary
|
||||
restriction and will be relaxed in the future.
|
||||
AssertionError: If :attr:`actual` and :attr:`expected` do not have the same :attr:`~torch.Tensor.shape`.
|
||||
@ -379,11 +618,11 @@ def assert_tensors_close(
|
||||
:attr:`~torch.Tensor.dtype`.
|
||||
AssertionError: If :attr:`check_stride`, but :attr:`actual` and :attr:`expected` do not have the same stride.
|
||||
AssertionError: If the values of :attr:`actual` and :attr:`expected` are close up to a desired tolerance.
|
||||
AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their length does not match.
|
||||
AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their set of keys mismatch.
|
||||
|
||||
|
||||
|
||||
The following table displays the default ``rtol`` and ``atol`` for floating point :attr:`~torch.Tensor.dtype`'s.
|
||||
For integer :attr:`~torch.Tensor.dtype`'s, ``rtol = atol = 0.0`` is used.
|
||||
The following table displays the default ``rtol``'s and ``atol``'s. Note that the :class:`~torch.dtype` refers to
|
||||
the promoted type in case :attr:`actual` and :attr:`expected` do not have the same :attr:`~torch.Tensor.dtype`.
|
||||
|
||||
+===========================+============+==========+
|
||||
| :class:`~torch.dtype` | ``rtol`` | ``atol`` |
|
||||
@ -402,38 +641,21 @@ def assert_tensors_close(
|
||||
+---------------------------+------------+----------+
|
||||
| :attr:`~torch.complex128` | ``1e-7`` | ``1e-7`` |
|
||||
+---------------------------+------------+----------+
|
||||
| other | ``0.0`` | ``0.0`` |
|
||||
+---------------------------+------------+----------+
|
||||
|
||||
.. seealso::
|
||||
|
||||
To assert that the values in two tensors are bitwise equal, use :func:`assert_tensors_equal` instead.
|
||||
To assert that the values in tensors are bitwise equal, use :func:`assert_equal` instead.
|
||||
"""
|
||||
exc: Optional[Exception] = _check_are_tensors(actual, expected)
|
||||
if exc:
|
||||
raise exc
|
||||
|
||||
exc = _check_supported_tensors(actual, expected)
|
||||
if exc:
|
||||
raise exc
|
||||
|
||||
if (rtol is None) ^ (atol is None):
|
||||
# We require both tolerance to be omitted or specified, because specifying only one might lead to surprising
|
||||
# results. Imagine setting atol=0.0 and the tensors still match because rtol>0.0.
|
||||
raise UsageError(
|
||||
f"Both 'rtol' and 'atol' must be omitted or specified, " f"but got rtol={rtol} and atol={atol} instead."
|
||||
)
|
||||
elif rtol is None:
|
||||
rtol, atol = _get_default_rtol_and_atol(actual, expected)
|
||||
|
||||
exc = _check_attributes_equal(
|
||||
actual, expected, check_device=check_device, check_dtype=check_dtype, check_stride=check_stride
|
||||
check_tensors = functools.partial(
|
||||
_check_tensors_close,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
check_device=check_device,
|
||||
check_dtype=check_dtype,
|
||||
check_stride=check_stride,
|
||||
)
|
||||
if exc:
|
||||
raise exc
|
||||
actual, expected = _equalize_attributes(actual, expected)
|
||||
|
||||
if (rtol == 0.0) and (atol == 0.0):
|
||||
exc = _check_values_equal(actual, expected)
|
||||
else:
|
||||
exc = _check_values_close(actual, expected, rtol=rtol, atol=atol)
|
||||
exc = _check_by_type(actual, expected, check_tensors)
|
||||
if exc:
|
||||
raise exc
|
||||
|
Reference in New Issue
Block a user