Add support for checking tensor containers in torch.testing (#55385)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55385

This renames `assert_tensors_(equal|close)` to `_check_tensors_(equal|close)` and exposes two new functions: `assert_(equal|close)`. In addition to tensor pairs, the newly added functions also support the comparison of tensors in sequences or mappings. Otherwise their signature stays the same.

Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D27903805
Pulled By: mruberry
fbshipit-source-id: 719d19a1d26de8d14cb25846e3d22a6ac828c80a
committed by: Facebook GitHub Bot
parent: bcef7ebd60
commit: dbf3451c6e
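The headline change: the new top-level functions accept containers of tensors, not just tensor pairs. A minimal usage sketch, assuming a build that includes this commit (of the two entry points, only `assert_close` ultimately shipped as public API in later releases):

    import collections

    import torch
    import torch.testing

    actual = torch.tensor([1.0, 2.0])
    expected = actual.clone()

    # Plain tensor pair, as before.
    torch.testing.assert_close(actual, expected)

    # The same pair wrapped in the newly supported containers.
    torch.testing.assert_equal((actual,), (expected,))
    torch.testing.assert_equal([actual], [expected])
    torch.testing.assert_close({"t": actual}, {"t": expected})
    torch.testing.assert_close(
        collections.OrderedDict([("t", actual)]),
        collections.OrderedDict([("t", expected)]),
    )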
test/test_testing.py

@@ -1,11 +1,17 @@
-import torch
+import collections
+import functools
+import itertools
 import math
 import os
 import random
+import re
 import unittest
+from typing import Any, Callable, Iterator, List, Mapping, Sequence, Tuple, TypeVar
 
 import numpy as np
 
+import torch
+
 from torch.testing._internal.common_utils import \
     (IS_SANDCASTLE, IS_WINDOWS, TestCase, make_tensor, run_tests, skipIfRocm, slowTest)
 from torch.testing._internal.framework_utils import calculate_shards
@@ -740,17 +746,62 @@ if __name__ == '__main__':
         self.assertNotIn('OK', stderr.decode('ascii'))
 
 
+T = TypeVar("T", torch.Tensor, Sequence[torch.Tensor], Mapping[Any, torch.Tensor])
+
+
 class TestAsserts(TestCase):
-    def assert_fns(self):
-        return [torch.testing.assert_tensors_equal, torch.testing.assert_tensors_close]
+    def get_assert_fns(self) -> List[Callable]:
+        """Gets assert functions to be tested.
+
+        Returns:
+            List(Callable): Top-level assert functions from :mod:`torch.testing`.
+        """
+        return [torch.testing.assert_equal, torch.testing.assert_close]
+
+    def make_inputs(self, actual: torch.Tensor, expected: torch.Tensor) -> List[Tuple[T, T]]:
+        """Makes inputs for assert functions based on two example tensors.
+
+        Args:
+            actual (torch.Tensor): Actual tensor.
+            expected (torch.Tensor): Expected tensor.
+
+        Returns:
+            List[Tuple[T, T]]: Pairs of tensors, tensor sequences (:class:`tuple`, :class:`list`), and tensor
+                mappings (:class:`dict`, :class:`~collections.OrderedDict`).
+        """
+        return [
+            (actual, expected),
+            ((actual,), (expected,)),
+            ([actual], [expected]),
+            ({"t": actual}, {"t": expected}),
+            (collections.OrderedDict([("t", actual)]), collections.OrderedDict([("t", expected)])),
+        ]
+
+    def assert_fns_with_inputs(self, actual: torch.Tensor, expected: torch.Tensor) -> Iterator[Callable]:
+        """Yields assert functions with included positional inputs based on two example tensors.
+
+        .. note::
+
+            This is a valid product of combinations from :meth:`get_assert_fns` and :meth:`make_inputs`. Every test
+            that does not test for anything specific should iterate over this to maximize the coverage.
+
+        Args:
+            actual (torch.Tensor): Actual tensor.
+            expected (torch.Tensor): Expected tensor.
+
+        Yields:
+            List[Callable]: Assert functions with predefined positional inputs.
+        """
+        for assert_fn, inputs in itertools.product(self.get_assert_fns(), self.make_inputs(actual, expected)):
+            yield functools.partial(assert_fn, *inputs)
 
     @onlyCPU
     def test_not_tensors(self, device):
         actual = torch.empty((), device=device)
         expected = np.empty(())
 
-        for fn in self.assert_fns():
-            with self.assertRaises(AssertionError):
+        for fn in self.get_assert_fns():
+            with self.assertRaises(UsageError):
                 fn(actual, expected)
 
     @onlyCPU
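The two fixtures above multiply out as a Cartesian product: every assert function gets bound to every container wrapping via `functools.partial`, so each generic test runs 2 functions x 5 input layouts. A standalone sketch of the pattern, with a hypothetical `check` standing in for the real assert functions:

    import functools
    import itertools

    def check(actual, expected):  # stand-in for assert_equal / assert_close
        assert actual == expected

    def make_inputs(actual, expected):
        # The same pair wrapped in each supported container, as in the fixture above.
        return [(actual, expected), ((actual,), (expected,)), ([actual], [expected])]

    def fns_with_inputs(fns, actual, expected):
        for fn, inputs in itertools.product(fns, make_inputs(actual, expected)):
            yield functools.partial(fn, *inputs)

    # 1 function x 3 wrappings -> 3 prepared callables, each now taking only keyword flags.
    for fn in fns_with_inputs([check], 1, 1):
        fn()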
@@ -758,18 +809,18 @@ class TestAsserts(TestCase):
         actual = torch.ones(1, dtype=torch.float32, device=device)
         expected = torch.ones(1, dtype=torch.complex64, device=device)
 
-        for fn in self.assert_fns():
+        for fn in self.assert_fns_with_inputs(actual, expected):
             with self.assertRaises(UsageError):
-                fn(actual, expected, check_dtype=False)
+                fn(check_dtype=False)
 
     @onlyCPU
     def test_sparse_support(self, device):
         actual = torch.empty((), device=device)
         expected = torch.sparse_coo_tensor(size=(), device=device)
 
-        for fn in self.assert_fns():
+        for fn in self.assert_fns_with_inputs(actual, expected):
             with self.assertRaises(UsageError):
-                fn(actual, expected)
+                fn()
 
     @onlyCPU
     def test_quantized_support(self, device):
@@ -778,221 +829,246 @@ class TestAsserts(TestCase):
         expected = torch._empty_affine_quantized(actual.shape, scale=1, zero_point=0, dtype=torch.qint32, device=device)
         expected.fill_(val)
 
-        for fn in self.assert_fns():
+        for fn in self.assert_fns_with_inputs(actual, expected):
             with self.assertRaises(UsageError):
-                fn(actual, expected)
+                fn()
 
     @onlyCPU
     def test_mismatching_shape(self, device):
         actual = torch.empty((), device=device)
         expected = actual.clone().reshape((1,))
 
-        for fn in self.assert_fns():
+        for fn in self.assert_fns_with_inputs(actual, expected):
             with self.assertRaisesRegex(AssertionError, "shape"):
-                fn(actual, expected)
+                fn()
 
     @onlyCUDA
     def test_mismatching_device(self, device):
         actual = torch.empty((), device=device)
         expected = actual.clone().cpu()
 
-        for fn in self.assert_fns():
+        for fn in self.assert_fns_with_inputs(actual, expected):
             with self.assertRaisesRegex(AssertionError, "device"):
-                fn(actual, expected)
+                fn()
 
     @onlyCUDA
     def test_mismatching_device_no_check(self, device):
         actual = torch.rand((), device=device)
         expected = actual.clone().cpu()
 
-        for fn in self.assert_fns():
-            fn(actual, expected, check_device=False)
+        for fn in self.assert_fns_with_inputs(actual, expected):
+            fn(check_device=False)
 
     @onlyCPU
     def test_mismatching_dtype(self, device):
         actual = torch.empty((), dtype=torch.float, device=device)
         expected = actual.clone().to(torch.int)
 
-        for fn in self.assert_fns():
+        for fn in self.assert_fns_with_inputs(actual, expected):
             with self.assertRaisesRegex(AssertionError, "dtype"):
-                fn(actual, expected)
+                fn()
 
     @onlyCPU
     def test_mismatching_dtype_no_check(self, device):
         actual = torch.ones((), dtype=torch.float, device=device)
         expected = actual.clone().to(torch.int)
 
-        for fn in self.assert_fns():
-            fn(actual, expected, check_dtype=False)
+        for fn in self.assert_fns_with_inputs(actual, expected):
+            fn(check_dtype=False)
 
     @onlyCPU
     def test_mismatching_stride(self, device):
         actual = torch.empty((2, 2), device=device)
         expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])
 
-        for fn in self.assert_fns():
+        for fn in self.assert_fns_with_inputs(actual, expected):
             with self.assertRaisesRegex(AssertionError, "stride"):
-                fn(actual, expected)
+                fn()
 
     @onlyCPU
     def test_mismatching_stride_no_check(self, device):
         actual = torch.rand((2, 2), device=device)
         expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])
 
-        for fn in self.assert_fns():
-            fn(actual, expected, check_stride=False)
+        for fn in self.assert_fns_with_inputs(actual, expected):
+            fn(check_stride=False)
 
     @onlyCPU
     def test_mismatching_values(self, device):
         actual = torch.tensor(1, device=device)
         expected = torch.tensor(2, device=device)
 
-        for fn in self.assert_fns():
+        for fn in self.assert_fns_with_inputs(actual, expected):
             with self.assertRaises(AssertionError):
-                fn(actual, expected)
+                fn()
 
+    @onlyCPU
+    def test_assert_equal(self, device):
+        actual = torch.tensor(1, device=device)
+        expected = actual.clone()
+
+        torch.testing.assert_equal(actual, expected)
+
+    @onlyCPU
+    def test_assert_close(self, device):
+        actual = torch.tensor(1.0, device=device)
+        expected = actual.clone()
+
+        torch.testing.assert_close(actual, expected)
+
+    @onlyCPU
+    def test_assert_close_only_rtol(self, device):
+        actual = torch.empty((), device=device)
+        expected = actual.clone()
+
+        with self.assertRaises(UsageError):
+            torch.testing.assert_close(actual, expected, rtol=0.0)
+
+    @onlyCPU
+    def test_assert_close_only_atol(self, device):
+        actual = torch.empty((), device=device)
+        expected = actual.clone()
+
+        with self.assertRaises(UsageError):
+            torch.testing.assert_close(actual, expected, atol=0.0)
+
+    @onlyCPU
+    def test_assert_close_mismatching_values_rtol(self, device):
+        eps = 1e-3
+        actual = torch.tensor(1.0, device=device)
+        expected = torch.tensor(1.0 + eps, device=device)
+
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(actual, expected, rtol=eps / 2, atol=0.0)
+
+    @onlyCPU
+    def test_assert_close_matching_values_rtol(self, device):
+        eps = 1e-3
+        actual = torch.tensor(1.0, device=device)
+        expected = torch.tensor(1.0 + eps, device=device)
+
+        torch.testing.assert_close(actual, expected, rtol=eps * 2, atol=0.0)
+
+    @onlyCPU
+    def test_assert_close_mismatching_values_atol(self, device):
+        eps = 1e-3
+        actual = torch.tensor(0.0, device=device)
+        expected = torch.tensor(eps, device=device)
+
+        with self.assertRaises(AssertionError):
+            torch.testing.assert_close(actual, expected, rtol=0.0, atol=eps / 2)
+
+    @onlyCPU
+    def test_assert_close_matching_values_atol(self, device):
+        eps = 1e-3
+        actual = torch.tensor(0.0, device=device)
+        expected = torch.tensor(eps, device=device)
+
+        torch.testing.assert_close(actual, expected, rtol=0.0, atol=eps * 2)
+
+    @onlyCPU
+    def test_mismatching_values_msg_mismatches(self, device):
+        actual = torch.tensor([1, 2, 3, 4], device=device)
+        expected = torch.tensor([1, 2, 5, 6], device=device)
+
+        for fn in self.assert_fns_with_inputs(actual, expected):
+            with self.assertRaisesRegex(AssertionError, re.escape("Mismatched elements: 2 / 4 (50.0%)")):
+                fn()
+
+    @onlyCPU
+    def test_mismatching_values_msg_abs_diff(self, device):
+        actual = torch.tensor([[1, 2], [3, 4]], device=device)
+        expected = torch.tensor([[1, 2], [5, 4]], device=device)
+
+        for fn in self.assert_fns_with_inputs(actual, expected):
+            with self.assertRaisesRegex(AssertionError, re.escape("Greatest absolute difference: 2 at (1, 0)")):
+                fn()
+
+    @onlyCPU
+    def test_mismatching_values_msg_rel_diff(self, device):
+        actual = torch.tensor([[1, 2], [3, 4]], device=device)
+        expected = torch.tensor([[1, 4], [3, 4]], device=device)
+
+        for fn in self.assert_fns_with_inputs(actual, expected):
+            with self.assertRaisesRegex(AssertionError, re.escape("Greatest relative difference: 0.5 at (0, 1)")):
+                fn()
+
+    @onlyCPU
+    def test_assert_close_mismatching_values_msg_rtol(self, device):
+        rtol = 1e-3
+
+        actual = torch.tensor(1, device=device)
+        expected = torch.tensor(2, device=device)
+
+        for inputs in self.make_inputs(actual, expected):
+            with self.assertRaisesRegex(
+                AssertionError, re.escape(f"Greatest relative difference: 0.5 at 0 (up to {rtol} allowed)")
+            ):
+                torch.testing.assert_close(*inputs, rtol=rtol, atol=0.0)
+
+    @onlyCPU
+    def test_assert_close_mismatching_values_msg_atol(self, device):
+        atol = 1e-3
+
+        actual = torch.tensor(1, device=device)
+        expected = torch.tensor(2, device=device)
+
+        for inputs in self.make_inputs(actual, expected):
+            with self.assertRaisesRegex(
+                AssertionError, re.escape(f"Greatest absolute difference: 1 at 0 (up to {atol} allowed)")
+            ):
+                torch.testing.assert_close(*inputs, rtol=0.0, atol=atol)
+
+    @onlyCPU
+    def test_unknown_type(self, device):
+        actual = torch.empty((), device=device)
+        expected = {actual.clone()}
+
+        for fn in self.get_assert_fns():
+            with self.assertRaisesRegex(UsageError, str(type(expected))):
+                fn(actual, expected)
+
+    @onlyCPU
+    def test_sequence_mismatching_len(self, device):
+        actual = (torch.empty((), device=device),)
+        expected = ()
+
+        for fn in self.get_assert_fns():
+            with self.assertRaises(AssertionError):
+                fn(actual, expected)
+
+    @onlyCPU
+    def test_sequence_mismatching_values_msg(self, device):
+        t1 = torch.tensor(1, device=device)
+        t2 = torch.tensor(2, device=device)
+
+        actual = (t1, t1)
+        expected = (t1, t2)
+
+        for fn in self.get_assert_fns():
+            with self.assertRaisesRegex(AssertionError, r"index\s+1"):
+                fn(actual, expected)
+
+    @onlyCPU
+    def test_mapping_mismatching_keys(self, device):
+        actual = {"a": torch.empty((), device=device)}
+        expected = {}
+
+        for fn in self.get_assert_fns():
+            with self.assertRaises(AssertionError):
+                fn(actual, expected)
+
+    @onlyCPU
+    def test_mapping_mismatching_values_msg(self, device):
+        t1 = torch.tensor(1, device=device)
+        t2 = torch.tensor(2, device=device)
+
+        actual = {"a": t1, "b": t1}
+        expected = {"a": t1, "b": t2}
+
+        for fn in self.get_assert_fns():
+            with self.assertRaisesRegex(AssertionError, r"key\s+'b'"):
+                fn(actual, expected)
-
-    @onlyCPU
-    def test_mismatching_values_msg_abs_mismatches(self, device):
-        actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
-        expected = actual.clone()
-
-        actual[0, 1] = 1
-        expected[0, 1] = 2
-        expected[1, 2] = 9
-
-        for fn in self.assert_fns():
-            with self.assertRaisesRegex(AssertionError, r"\s+2\s+"):
-                fn(actual, expected)
-
-    @onlyCPU
-    def test_mismatching_values_msg_rel_mismatches(self, device):
-        actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
-        expected = actual.clone()
-
-        actual[0, 1] = 1
-        expected[0, 1] = 2
-        expected[1, 2] = 9
-
-        for fn in self.assert_fns():
-            with self.assertRaisesRegex(AssertionError, r"22([.]2+)?\s*[%]"):
-                fn(actual, expected)
-
-    @onlyCPU
-    def test_mismatching_values_msg_max_abs_diff(self, device):
-        actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
-        expected = actual.clone()
-
-        actual[0, 1] = 1
-        expected[0, 1] = 2
-        expected[1, 2] = 9
-
-        for fn in self.assert_fns():
-            with self.assertRaisesRegex(AssertionError, r"\s+4[.]0\s+"):
-                fn(actual, expected)
-
-    @onlyCPU
-    def test_mismatching_values_max_abs_diff_idx(self, device):
-        actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
-        expected = actual.clone()
-
-        actual[0, 1] = 1
-        expected[0, 1] = 2
-        expected[1, 2] = 9
-
-        for fn in self.assert_fns():
-            with self.assertRaisesRegex(AssertionError, r"1,\s*2"):
-                fn(actual, expected)
-
-    @onlyCPU
-    def test_mismatching_values_msg_max_rel_diff(self, device):
-        actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
-        expected = actual.clone()
-
-        actual[0, 1] = 1
-        expected[0, 1] = 2
-        expected[1, 2] = 9
-
-        for fn in self.assert_fns():
-            with self.assertRaisesRegex(AssertionError, r"\s+0[.]5\s+"):
-                fn(actual, expected)
-
-    @onlyCPU
-    def test_mismatching_values_max_rel_diff_idx(self, device):
-        actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
-        expected = actual.clone()
-
-        actual[0, 1] = 1
-        expected[0, 1] = 2
-        expected[1, 2] = 9
-
-        for fn in self.assert_fns():
-            with self.assertRaisesRegex(AssertionError, r"0,\s*1"):
-                fn(actual, expected)
-
-    @onlyCPU
-    def test_assert_tensors_equal(self, device):
-        actual = torch.tensor(1, device=device)
-        expected = actual.clone()
-
-        torch.testing.assert_tensors_equal(actual, expected)
-
-    @onlyCPU
-    def test_assert_tensors_close(self, device):
-        actual = torch.tensor(1.0, device=device)
-        expected = actual.clone()
-
-        torch.testing.assert_tensors_close(actual, expected)
-
-    @onlyCPU
-    def test_assert_tensors_close_only_rtol(self, device):
-        actual = torch.empty((), device=device)
-        expected = actual.clone()
-
-        with self.assertRaises(UsageError):
-            torch.testing.assert_tensors_close(actual, expected, rtol=0.0)
-
-    @onlyCPU
-    def test_assert_tensors_close_only_atol(self, device):
-        actual = torch.empty((), device=device)
-        expected = actual.clone()
-
-        with self.assertRaises(UsageError):
-            torch.testing.assert_tensors_close(actual, expected, atol=0.0)
-
-    @onlyCPU
-    def test_assert_tensors_close_mismatching_values_rtol(self, device):
-        eps = 1e-3
-        actual = torch.tensor(1.0, device=device)
-        expected = torch.tensor(1.0 + eps, device=device)
-
-        with self.assertRaises(AssertionError):
-            torch.testing.assert_tensors_close(actual, expected, rtol=eps / 2, atol=0.0)
-
-    @onlyCPU
-    def test_assert_tensors_close_matching_values_rtol(self, device):
-        eps = 1e-3
-        actual = torch.tensor(1.0, device=device)
-        expected = torch.tensor(1.0 + eps, device=device)
-
-        torch.testing.assert_tensors_close(actual, expected, rtol=eps * 2, atol=0.0)
-
-    @onlyCPU
-    def test_assert_tensors_close_mismatching_values_atol(self, device):
-        eps = 1e-3
-        actual = torch.tensor(0.0, device=device)
-        expected = torch.tensor(eps, device=device)
-
-        with self.assertRaises(AssertionError):
-            torch.testing.assert_tensors_close(actual, expected, rtol=0.0, atol=eps / 2)
-
-    @onlyCPU
-    def test_assert_tensors_close_matching_values_atol(self, device):
-        eps = 1e-3
-        actual = torch.tensor(0.0, device=device)
-        expected = torch.tensor(eps, device=device)
-
-        torch.testing.assert_tensors_close(actual, expected, rtol=0.0, atol=eps * 2)
 
 
 instantiate_device_type_tests(TestAsserts, globals())
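The message fragments asserted above pin down the report format of the new functions. Reproducing one by hand (again assuming a build that includes this commit):

    import torch
    import torch.testing

    actual = torch.tensor([1, 2, 3, 4])
    expected = torch.tensor([1, 2, 5, 6])

    try:
        torch.testing.assert_equal(actual, expected)
    except AssertionError as exc:
        # Per test_mismatching_values_msg_mismatches, the message contains
        # "Mismatched elements: 2 / 4 (50.0%)".
        print(exc)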
torch/testing/_asserts.py

@@ -1,12 +1,15 @@
+import collections.abc
+import functools
 import sys
 from collections import namedtuple
-from typing import Any, Optional, Tuple, Type
+from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union, cast
 
 import torch
+from torch import Tensor
 
 from ._core import _unravel_index
 
-__all__ = ["assert_tensors_equal", "assert_tensors_close"]
+__all__ = ["assert_equal", "assert_close"]
 
 
 # The UsageError should be raised in case the test function is not used correctly. With this the user is able to
@@ -42,7 +45,7 @@ _DTYPE_PRECISIONS = {
 }
 
 
-def _get_default_rtol_and_atol(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[float, float]:
+def _get_default_rtol_and_atol(actual: Tensor, expected: Tensor) -> Tuple[float, float]:
     dtype = actual.dtype if actual.dtype == expected.dtype else torch.promote_types(actual.dtype, expected.dtype)
     return _DTYPE_PRECISIONS.get(dtype, (0.0, 0.0))
 
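`_get_default_rtol_and_atol` looks up a per-dtype default and falls back to exact comparison for anything not in the table. A standalone sketch with an illustrative subset of the table (the float32 values here are assumptions for illustration; the complex128 row matches the table later in this diff):

    import torch

    DTYPE_PRECISIONS = {
        torch.float32: (1.3e-6, 1e-5),   # assumed values, for illustration only
        torch.complex128: (1e-7, 1e-7),  # from the defaults table in this diff
    }

    def default_rtol_and_atol(actual, expected):
        # Mismatching dtypes resolve through the promoted dtype, exactly as above.
        dtype = actual.dtype if actual.dtype == expected.dtype else torch.promote_types(actual.dtype, expected.dtype)
        return DTYPE_PRECISIONS.get(dtype, (0.0, 0.0))

    a = torch.empty((), dtype=torch.float32)
    b = torch.empty((), dtype=torch.int32)
    print(default_rtol_and_atol(a, b))  # float32 wins the promotion -> (1.3e-06, 1e-05)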
@@ -57,23 +60,23 @@ def _check_are_tensors(actual: Any, expected: Any) -> Optional[AssertionError]:
     Returns:
         (Optional[AssertionError]): If check did not pass.
     """
-    if not (isinstance(actual, torch.Tensor) and isinstance(expected, torch.Tensor)):
+    if not (isinstance(actual, Tensor) and isinstance(expected, Tensor)):
         return AssertionError(f"Both inputs have to be tensors, but got {type(actual)} and {type(expected)} instead.")
 
     return None
 
 
 def _check_supported_tensors(
-    actual: torch.Tensor,
-    expected: torch.Tensor,
+    actual: Tensor,
+    expected: Tensor,
 ) -> Optional[UsageError]:  # type: ignore[valid-type]
     """Checks if the tensors are supported by the current infrastructure.
 
     All checks are temporary and will be relaxed in the future.
 
     Args:
-        actual (torch.Tensor): Actual tensor.
-        expected (torch.Tensor): Expected tensor.
+        actual (Tensor): Actual tensor.
+        expected (Tensor): Expected tensor.
 
     Returns:
         (Optional[UsageError]): If check did not pass.
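Note the module convention visible here: internal helpers return an `Optional` exception rather than raising it, so callers can decide whether to raise immediately or first amend the message (as the container checks added later in this diff do). A minimal sketch of the pattern:

    from typing import Any, Optional

    def check_is_int(value: Any) -> Optional[AssertionError]:
        # Return, don't raise: the caller controls when and how the error surfaces.
        if not isinstance(value, int):
            return AssertionError(f"Expected an int, but got {type(value)} instead.")
        return None

    exc = check_is_int("3")
    if exc:
        raise exc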
@@ -89,8 +92,8 @@ def _check_supported_tensors(
 
 
 def _check_attributes_equal(
-    actual: torch.Tensor,
-    expected: torch.Tensor,
+    actual: Tensor,
+    expected: Tensor,
     *,
     check_device: bool = True,
     check_dtype: bool = True,
@@ -102,8 +105,8 @@ def _check_attributes_equal(
     :attr:`~torch.Tensor.dtype`, and :meth:`~torch.Tensor.stride` are optional and can be disabled.
 
     Args:
-        actual (torch.Tensor): Actual tensor.
-        expected (torch.Tensor): Expected tensor.
+        actual (Tensor): Actual tensor.
+        expected (Tensor): Expected tensor.
         check_device (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` are on the
             same :attr:`~torch.Tensor.device` memory.
         check_dtype (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` have the same
@@ -112,7 +115,7 @@ def _check_attributes_equal(
     :meth:`~torch.Tensor.stride`.
 
     Returns:
-        (Optional[AssertionError]): If check did not pass.
+        (Optional[AssertionError]): If checks did not pass.
     """
     msg_fmtstr = "The values for attribute '{}' do not match: {} != {}."
 
@@ -131,7 +134,7 @@ def _check_attributes_equal(
     return None
 
 
-def _equalize_attributes(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+def _equalize_attributes(actual: Tensor, expected: Tensor) -> Tuple[Tensor, Tensor]:
     """Equalizes some attributes of two tensors for value comparison.
 
     If :attr:`actual` and :attr:`expected`
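`_equalize_attributes` reconciles device and dtype before the value comparison when the corresponding checks are disabled. A standalone sketch of the same steps, assuming the behavior described in the docstrings:

    import torch

    def equalize(actual, expected):
        if actual.device != expected.device:
            # Different devices: compare on CPU.
            actual, expected = actual.cpu(), expected.cpu()
        if actual.dtype != expected.dtype:
            # Different dtypes: compare in the promoted dtype.
            dtype = torch.promote_types(actual.dtype, expected.dtype)
            actual, expected = actual.to(dtype), expected.to(dtype)
        return actual, expected

    a = torch.ones((), dtype=torch.float32)
    b = torch.ones((), dtype=torch.int32)
    a, b = equalize(a, b)
    assert a.dtype == b.dtype == torch.float32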
@@ -140,11 +143,11 @@ def _equalize_attributes(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[
     :func:`torch.promote_types`.
 
     Args:
-        actual (torch.Tensor): Actual tensor.
-        expected (torch.Tensor): Expected tensor.
+        actual (Tensor): Actual tensor.
+        expected (Tensor): Expected tensor.
 
     Returns:
-        Tuple(torch.Tensor, torch.Tensor): Equalized tensors.
+        Tuple(Tensor, Tensor): Equalized tensors.
     """
     if actual.device != expected.device:
         actual = actual.cpu()
@@ -172,13 +175,13 @@ _Trace = namedtuple(
 )
 
 
-def _trace_mismatches(actual: torch.Tensor, expected: torch.Tensor, mismatches: torch.Tensor) -> _Trace:
+def _trace_mismatches(actual: Tensor, expected: Tensor, mismatches: Tensor) -> _Trace:
     """Traces mismatches.
 
     Args:
-        actual (torch.Tensor): Actual tensor.
-        expected (torch.Tensor): Expected tensor.
-        mismatches (torch.Tensor): Boolean mask of the same shape as :attr:`actual` and :attr:`expected` that indicates
+        actual (Tensor): Actual tensor.
+        expected (Tensor): Expected tensor.
+        mismatches (Tensor): Boolean mask of the same shape as :attr:`actual` and :attr:`expected` that indicates
             the location of mismatches.
 
     Returns:
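`_trace_mismatches` is what turns a boolean mismatch mask into the numbers quoted in the error messages (count, percentage, greatest absolute difference and its index). A rough standalone equivalent, using NumPy for the index unraveling instead of the private `_unravel_index` helper:

    import numpy as np
    import torch

    actual = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    expected = torch.tensor([[1.0, 2.0], [5.0, 4.0]])

    mismatches = torch.ne(actual, expected)
    total = mismatches.numel()
    n = int(mismatches.sum())
    print(f"Mismatched elements: {n} / {total} ({n / total * 100:.1f}%)")

    abs_diff = (actual - expected).abs()
    flat = int(abs_diff.argmax())
    idx = np.unravel_index(flat, tuple(actual.shape))
    print(f"Greatest absolute difference: {abs_diff.flatten()[flat].item()} at {idx}")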
@@ -220,12 +223,12 @@ def _trace_mismatches(actual: torch.Tensor, expected: torch.Tensor, mismatches:
     )
 
 
-def _check_values_equal(actual: torch.Tensor, expected: torch.Tensor) -> Optional[AssertionError]:
+def _check_values_equal(actual: Tensor, expected: Tensor) -> Optional[AssertionError]:
     """Checks if the values of two tensors are bitwise equal.
 
     Args:
-        actual (torch.Tensor): Actual tensor.
-        expected (torch.Tensor): Expected tensor.
+        actual (Tensor): Actual tensor.
+        expected (Tensor): Expected tensor.
 
     Returns:
         (Optional[AssertionError]): If check did not pass.
@@ -244,8 +247,8 @@ def _check_values_equal(actual: torch.Tensor, expected: torch.Tensor) -> Optiona
 
 
 def _check_values_close(
-    actual: torch.Tensor,
-    expected: torch.Tensor,
+    actual: Tensor,
+    expected: Tensor,
     *,
     rtol,
     atol,
@@ -253,8 +256,8 @@ def _check_values_close(
     """Checks if the values of two tensors are close up to a desired tolerance.
 
     Args:
-        actual (torch.Tensor): Actual tensor.
-        expected (torch.Tensor): Expected tensor.
+        actual (Tensor): Actual tensor.
+        expected (Tensor): Expected tensor.
         rtol (float): Relative tolerance.
         atol (float): Absolute tolerance.
 
@@ -274,55 +277,89 @@ def _check_values_close(
     )
 
 
-def assert_tensors_equal(
-    actual: torch.Tensor,
-    expected: torch.Tensor,
+def _check_tensors_equal(
+    actual: Tensor,
+    expected: Tensor,
     *,
     check_device: bool = True,
     check_dtype: bool = True,
     check_stride: bool = True,
-) -> None:
-    """Asserts that the values of two tensors are bitwise equal.
+) -> Optional[Exception]:
+    """Checks that the values of two tensors are bitwise equal.
 
     Optionally, checks that some attributes of both tensors are equal.
 
-    Args:
-        actual (torch.Tensor): Actual tensor.
-        expected (torch.Tensor): Expected tensor.
-        check_device (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` are on the
-            same :attr:`~torch.Tensor.device` memory. If this check is disabled **and** :attr:`actual` and
-            :attr:`expected` are not on the same memory :attr:`~torch.Tensor.device`, they are moved CPU memory before
-            their values are compared.
-        check_dtype (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` have the same
-            :attr:`~torch.Tensor.dtype`. If this check is disabled **and** :attr:`actual` and :attr:`expected` do not
-            have the same :attr:`~torch.Tensor.dtype`, they are copied to the :class:`~torch.dtype` returned by
-            :func:`torch.promote_types` before their values are compared.
-        check_stride (bool): If ``True`` (default), asserts that both :attr:`actual` and :attr:`expected` have the same
-            stride.
+    For a description of the parameters see :func:`assert_equal`.
 
-    Raises:
-        UsageError: If :attr:`actual` or :attr:`expected` is complex, quantized, or sparse. This is a temporary
-            restriction and will be relaxed in the future.
-        AssertionError: If :attr:`actual` and :attr:`expected` do not have the same :attr:`~torch.Tensor.shape`.
-        AssertionError: If :attr:`check_device`, but :attr:`actual` and :attr:`expected` are not on the same
-            :attr:`~torch.Tensor.device` memory.
-        AssertionError: If :attr:`check_dtype`, but :attr:`actual` and :attr:`expected` do not have the same
-            :attr:`~torch.Tensor.dtype`.
-        AssertionError: If :attr:`check_stride`, but :attr:`actual` and :attr:`expected` do not have the same stride.
-        AssertionError: If the values of :attr:`actual` and :attr:`expected` are not bitwise equal.
-
-    .. seealso::
-
-        To assert that the values in two tensors are are close but are not required to be bitwise equal, use
-        :func:`assert_tensors_close` instead.
+    Returns:
+        Optional[Exception]: If checks did not pass.
     """
     exc: Optional[Exception] = _check_are_tensors(actual, expected)
     if exc:
-        raise exc
+        return exc
 
     exc = _check_supported_tensors(actual, expected)
     if exc:
-        raise exc
+        return exc
 
+    exc = _check_attributes_equal(
+        actual, expected, check_device=check_device, check_dtype=check_dtype, check_stride=check_stride
+    )
+    if exc:
+        return exc
+    actual, expected = _equalize_attributes(actual, expected)
+
+    exc = _check_values_equal(actual, expected)
+    if exc:
+        return exc
+
+    return None
+
+
+def _check_tensors_close(
+    actual: Tensor,
+    expected: Tensor,
+    *,
+    rtol: Optional[float] = None,
+    atol: Optional[float] = None,
+    check_device: bool = True,
+    check_dtype: bool = True,
+    check_stride: bool = True,
+) -> Optional[Exception]:
+    r"""Checks that the values of two tensors are close.
+
+    Closeness is defined by
+
+    .. math::
+
+        \lvert a - b \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert b \rvert
+
+    If both tolerances, :attr:`rtol` and :attr:`atol`, are ``0``, asserts that :attr:`actual` and :attr:`expected` are
+    bitwise equal.
+
+    Optionally, checks that some attributes of both tensors are equal.
+
+    For a description of the parameters see :func:`assert_equal`.
+
+    Returns:
+        Optional[Exception]: If checks did not pass.
+    """
+    exc: Optional[Exception] = _check_are_tensors(actual, expected)
+    if exc:
+        return exc
+
+    exc = _check_supported_tensors(actual, expected)
+    if exc:
+        return exc
+
+    if (rtol is None) ^ (atol is None):
+        # We require both tolerances to be omitted or specified, because specifying only one might lead to surprising
+        # results. Imagine setting atol=0.0 and the tensors still match because rtol>0.0.
+        return UsageError(
+            f"Both 'rtol' and 'atol' must be omitted or specified, " f"but got rtol={rtol} and atol={atol} instead."
+        )
+    elif rtol is None:
+        rtol, atol = _get_default_rtol_and_atol(actual, expected)
+
     exc = _check_attributes_equal(
         actual, expected, check_device=check_device, check_dtype=check_dtype, check_stride=check_stride
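The closeness definition above is the same predicate that `torch.allclose` implements; checked by hand:

    import torch

    a, b = torch.tensor(1.0), torch.tensor(1.001)
    rtol, atol = 1e-3, 0.0

    # |a - b| <= atol + rtol * |b|
    assert (a - b).abs() <= atol + rtol * b.abs()
    assert torch.allclose(a, b, rtol=rtol, atol=atol)

    # Halving rtol flips the result, mirroring test_assert_close_mismatching_values_rtol.
    assert not torch.allclose(a, b, rtol=rtol / 2, atol=atol)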
@@ -331,14 +368,205 @@ def assert_tensors_equal(
         raise exc
     actual, expected = _equalize_attributes(actual, expected)
 
-    exc = _check_values_equal(actual, expected)
+    if (rtol == 0.0) and (atol == 0.0):
+        exc = _check_values_equal(actual, expected)
+    else:
+        exc = _check_values_close(actual, expected, rtol=rtol, atol=atol)
+    if exc:
+        return exc
+
+    return None
+
+
+def _check_by_type(
+    actual: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
+    expected: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
+    check_tensors: Callable[[Tensor, Tensor], Optional[Exception]],
+) -> Optional[Exception]:
+    """Delegates tensor checking based on the input types.
+
+    Currently supports pairs of
+
+    - :class:`Tensor`'s,
+    - :class:`~collections.abc.Sequence`'s of :class:`Tensor`'s, and
+    - :class:`~collections.abc.Mapping`'s of :class:`Tensor`'s.
+
+    Args:
+        actual (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Actual input.
+        expected (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Expected input.
+        check_tensors (Callable[[Tensor, Tensor], Optional[Exception]]): Callable used to check if two tensors match.
+            In case they mismatch, it should return an :class:`Exception` with an expressive error message.
+
+    Returns:
+        (Optional[Exception]): :class:`UsageError` if the input types are unsupported. Additionally, any exception
+            returned by :attr:`check_tensors`.
+    """
+    # _check_are_tensors() returns nothing in case both inputs are tensors and an exception otherwise. Thus, the logic
+    # is inverted here.
+    are_tensors = not _check_are_tensors(actual, expected)
+    if are_tensors:
+        return check_tensors(cast(Tensor, actual), cast(Tensor, expected))
+
+    if isinstance(actual, collections.abc.Sequence) and isinstance(expected, collections.abc.Sequence):
+        return _check_sequence(actual, expected, check_tensors)
+    elif isinstance(actual, collections.abc.Mapping) and isinstance(expected, collections.abc.Mapping):
+        return _check_mapping(actual, expected, check_tensors)
+
+    return UsageError(
+        f"Both inputs have to be tensors, or sequences or mappings of tensors, "
+        f"but got {type(actual)} and {type(expected)} instead."
+    )
+
+
+E = TypeVar("E", bound=Exception)
+
+
+def _amend_error_message(exc: E, msg_fmtstr: str) -> E:
+    """Amends an exception message.
+
+    Args:
+        exc (E): Exception.
+        msg_fmtstr: Format string for the amended message.
+
+    Returns:
+        (E): New exception with amended error message.
+    """
+    return type(exc)(msg_fmtstr.format(str(exc)))
+
+
+_SEQUENCE_MSG_FMTSTR = "The failure occurred at index {} of the sequences."
+
+
+def _check_sequence(
+    actual: Sequence[Tensor], expected: Sequence[Tensor], check_tensors: Callable[[Tensor, Tensor], Optional[Exception]]
+) -> Optional[Exception]:
+    """Checks if the values of two sequences of tensors match.
+
+    Args:
+        actual (Sequence[Tensor]): Actual sequence of tensors.
+        expected (Sequence[Tensor]): Expected sequence of tensors.
+        check_tensors (Callable[[Tensor, Tensor], Optional[Exception]]): Callable used to check if the items of
+            :attr:`actual` and :attr:`expected` match. In case they mismatch, it should return an :class:`Exception`
+            with an expressive error message.
+
+    Returns:
+        Optional[Exception]: :class:`AssertionError` if the sequences do not have the same length. Additionally, any
+            exception returned by :attr:`check_tensors`. In this case, the error message is amended to include the
+            first offending index.
+    """
+    actual_len = len(actual)
+    expected_len = len(expected)
+    if actual_len != expected_len:
+        return AssertionError(f"The length of the sequences mismatch: {actual_len} != {expected_len}")
+
+    for idx, (actual_t, expected_t) in enumerate(zip(actual, expected)):
+        exc = check_tensors(actual_t, expected_t)
+        if exc:
+            return _amend_error_message(exc, f"{{}}\n\n{_SEQUENCE_MSG_FMTSTR.format(idx)}")
+
+    return None
+
+
+_MAPPING_MSG_FMTSTR = "The failure occurred for key '{}' of the mappings."
+
+
+def _check_mapping(
+    actual: Mapping[Any, Tensor],
+    expected: Mapping[Any, Tensor],
+    check_tensors: Callable[[Tensor, Tensor], Optional[Exception]],
+) -> Optional[Exception]:
+    """Checks if the values of two mappings of tensors match.
+
+    Args:
+        actual (Mapping[Any, Tensor]): Actual mapping of tensors.
+        expected (Mapping[Any, Tensor]): Expected mapping of tensors.
+        check_tensors (Callable[[Tensor, Tensor], Optional[Exception]]): Callable used to check if the values of
+            :attr:`actual` and :attr:`expected` match. In case they mismatch, it should return an :class:`Exception`
+            with an expressive error message.
+
+    Returns:
+        Optional[Exception]: :class:`AssertionError` if the mappings do not have the same set of keys. Additionally,
+            any exception returned by :attr:`check_tensors`. In this case, the error message is amended to include
+            the first offending key.
+    """
+    actual_keys = set(actual.keys())
+    expected_keys = set(expected.keys())
+    if actual_keys != expected_keys:
+        missing_keys = expected_keys - actual_keys
+        additional_keys = actual_keys - expected_keys
+        return AssertionError(
+            f"The keys of the mappings do not match:\n\n"
+            f"Missing keys in the actual mapping: {sorted(missing_keys)}\n"
+            f"Additional keys in the actual mapping: {sorted(additional_keys)}\n"
+        )
+
+    for key in sorted(actual_keys):
+        actual_t = actual[key]
+        expected_t = expected[key]
+
+        exc = check_tensors(actual_t, expected_t)
+        if exc:
+            return _amend_error_message(exc, f"{{}}\n\n{_MAPPING_MSG_FMTSTR.format(key)}")
+
+    return None
+
+
+def assert_equal(
+    actual: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
+    expected: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
+    *,
+    check_device: bool = True,
+    check_dtype: bool = True,
+    check_stride: bool = True,
+) -> None:
+    """Asserts that the values of tensors are bitwise equal.
+
+    Optionally, checks that some attributes of tensors are equal.
+
+    Also supports :class:`~collections.abc.Sequence`'s and :class:`~collections.abc.Mapping`'s of :class:`Tensor`'s.
+
+    Args:
+        actual (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Actual input.
+        expected (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Expected input.
+        check_device (bool): If ``True`` (default), asserts that tensors live in the same
+            :attr:`~torch.Tensor.device` memory. If this check is disabled **and** they do not live in the same
+            :attr:`~torch.Tensor.device` memory, they are moved to CPU memory before their values are compared.
+        check_dtype (bool): If ``True`` (default), asserts that tensors have the same :attr:`~torch.Tensor.dtype`. If
+            this check is disabled **and** they do not have the same :attr:`~torch.Tensor.dtype`, they are copied to
+            the :class:`~torch.dtype` returned by :func:`torch.promote_types` before their values are compared.
+        check_stride (bool): If ``True`` (default), asserts that the tensors have the same stride.
+
+    Raises:
+        UsageError: If the input pair has an unsupported type.
+        UsageError: If any tensor is complex, quantized, or sparse. This is a temporary restriction and will be
+            relaxed in the future.
+        AssertionError: If any corresponding tensors do not have the same :attr:`~torch.Tensor.shape`.
+        AssertionError: If :attr:`check_device`, but any corresponding tensors do not live in the same
+            :attr:`~torch.Tensor.device` memory.
+        AssertionError: If :attr:`check_dtype`, but any corresponding tensors do not have the same
+            :attr:`~torch.Tensor.dtype`.
+        AssertionError: If :attr:`check_stride`, but any corresponding tensors do not have the same stride.
+        AssertionError: If the values of any corresponding tensors are not bitwise equal.
+        AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their length does not match.
+        AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their set of keys mismatch.
+
+    .. seealso::
+
+        To assert that the values in tensors are close but are not required to be bitwise equal, use
+        :func:`assert_close` instead.
+    """
+    check_tensors = functools.partial(
+        _check_tensors_equal,
+        check_device=check_device,
+        check_dtype=check_dtype,
+        check_stride=check_stride,
+    )
+    exc = _check_by_type(actual, expected, check_tensors)
     if exc:
         raise exc
 
 
-def assert_tensors_close(
-    actual: torch.Tensor,
-    expected: torch.Tensor,
+def assert_close(
+    actual: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
+    expected: Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]],
     *,
     rtol: Optional[float] = None,
     atol: Optional[float] = None,
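`_amend_error_message` preserves the exception type while appending the container context; combined with the format strings above it yields the two-part messages the tests look for (`index\s+1`, `key\s+'b'`). A sketch:

    _SEQUENCE_MSG_FMTSTR = "The failure occurred at index {} of the sequences."

    def amend_error_message(exc, msg_fmtstr):
        # Rebuild the exception with the decorated message, keeping its type.
        return type(exc)(msg_fmtstr.format(str(exc)))

    inner = AssertionError("Tensors are not equal!")  # assumed per-tensor message
    outer = amend_error_message(inner, f"{{}}\n\n{_SEQUENCE_MSG_FMTSTR.format(1)}")
    print(outer)
    # Tensors are not equal!
    #
    # The failure occurred at index 1 of the sequences.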
@@ -346,14 +574,24 @@ def assert_tensors_close(
     check_dtype: bool = True,
     check_stride: bool = True,
 ) -> None:
-    """Asserts that the values of two tensors are close up to a desired tolerance.
+    r"""Asserts that the values of tensors are close.
 
-    If both tolerances, :attr:`rtol` and :attr:`rtol`, are ``0``, asserts that :attr:`actual` and :attr:`expected` are bitwise
-    equal. Optionally, checks that some attributes of both tensors are equal.
+    Closeness is defined by
+
+    .. math::
+
+        \lvert a - b \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert b \rvert
+
+    If both tolerances, :attr:`rtol` and :attr:`atol`, are ``0``, asserts that :attr:`actual` and :attr:`expected` are
+    bitwise equal.
+
+    Optionally, checks that some attributes of tensors are equal.
+
+    Also supports :class:`~collections.abc.Sequence`'s and :class:`~collections.abc.Mapping`'s of :class:`Tensor`'s.
 
     Args:
-        actual (torch.Tensor): Actual tensor.
-        expected (torch.Tensor): Expected tensor.
+        actual (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Actual input.
+        expected (Union[Tensor, Sequence[Tensor], Mapping[Any, Tensor]]): Expected input.
         rtol (Optional[float]): Relative tolerance. If specified :attr:`atol` must also be specified. If omitted,
             default values based on the :attr:`~torch.Tensor.dtype` are selected with the below table.
         atol (Optional[float]): Absolute tolerance. If specified :attr:`rtol` must also be specified. If omitted,
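The rule that `rtol` and `atol` must be given together is enforced with an XOR over their `None`-ness in `_check_tensors_close` above. The same logic, standalone:

    from typing import Optional

    def resolve_tolerances(rtol: Optional[float], atol: Optional[float], defaults=(1.3e-6, 1e-5)):
        if (rtol is None) ^ (atol is None):
            # Specifying only one is rejected: atol=0.0 alone could still pass via rtol > 0.
            raise ValueError(f"Both 'rtol' and 'atol' must be omitted or specified, "
                             f"but got rtol={rtol} and atol={atol} instead.")
        if rtol is None:
            return defaults  # dtype-based defaults in the real code
        return rtol, atol

    print(resolve_tolerances(None, None))  # falls back to the defaults
    print(resolve_tolerances(1e-3, 0.0))   # explicit pair passes through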
@@ -370,6 +608,7 @@ def assert_tensors_close(
             stride.
 
     Raises:
+        UsageError: If the input pair has an unsupported type.
         UsageError: If :attr:`actual` or :attr:`expected` is complex, quantized, or sparse. This is a temporary
             restriction and will be relaxed in the future.
         AssertionError: If :attr:`actual` and :attr:`expected` do not have the same :attr:`~torch.Tensor.shape`.
@@ -379,11 +618,11 @@ def assert_tensors_close(
             :attr:`~torch.Tensor.dtype`.
         AssertionError: If :attr:`check_stride`, but :attr:`actual` and :attr:`expected` do not have the same stride.
         AssertionError: If the values of :attr:`actual` and :attr:`expected` are not close up to a desired tolerance.
+        AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their length does not match.
+        AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their set of keys mismatch.
 
-    The following table displays the default ``rtol`` and ``atol`` for floating point :attr:`~torch.Tensor.dtype`'s.
-    For integer :attr:`~torch.Tensor.dtype`'s, ``rtol = atol = 0.0`` is used.
+    The following table displays the default ``rtol``'s and ``atol``'s. Note that the :class:`~torch.dtype` refers to
+    the promoted type in case :attr:`actual` and :attr:`expected` do not have the same :attr:`~torch.Tensor.dtype`.
 
     +===========================+============+==========+
     | :class:`~torch.dtype`     | ``rtol``   | ``atol`` |
@@ -402,38 +641,21 @@ def assert_tensors_close(
     +---------------------------+------------+----------+
     | :attr:`~torch.complex128` | ``1e-7``   | ``1e-7`` |
     +---------------------------+------------+----------+
+    | other                     | ``0.0``    | ``0.0``  |
+    +---------------------------+------------+----------+
 
     .. seealso::
 
-        To assert that the values in two tensors are bitwise equal, use :func:`assert_tensors_equal` instead.
+        To assert that the values in tensors are bitwise equal, use :func:`assert_equal` instead.
     """
-    exc: Optional[Exception] = _check_are_tensors(actual, expected)
-    if exc:
-        raise exc
-
-    exc = _check_supported_tensors(actual, expected)
-    if exc:
-        raise exc
-
-    if (rtol is None) ^ (atol is None):
-        # We require both tolerance to be omitted or specified, because specifying only one might lead to surprising
-        # results. Imagine setting atol=0.0 and the tensors still match because rtol>0.0.
-        raise UsageError(
-            f"Both 'rtol' and 'atol' must be omitted or specified, " f"but got rtol={rtol} and atol={atol} instead."
-        )
-    elif rtol is None:
-        rtol, atol = _get_default_rtol_and_atol(actual, expected)
-
-    exc = _check_attributes_equal(
-        actual, expected, check_device=check_device, check_dtype=check_dtype, check_stride=check_stride
-    )
-    if exc:
-        raise exc
-    actual, expected = _equalize_attributes(actual, expected)
-
-    if (rtol == 0.0) and (atol == 0.0):
-        exc = _check_values_equal(actual, expected)
-    else:
-        exc = _check_values_close(actual, expected, rtol=rtol, atol=atol)
+    check_tensors = functools.partial(
+        _check_tensors_close,
+        rtol=rtol,
+        atol=atol,
+        check_device=check_device,
+        check_dtype=check_dtype,
+        check_stride=check_stride,
+    )
+    exc = _check_by_type(actual, expected, check_tensors)
     if exc:
         raise exc
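Both public entry points are thus just a `functools.partial` of the pair-checker handed to the type dispatcher. The same composition, sketched with hypothetical scalar stand-ins for the real helpers:

    import functools

    def check_pair_close(a, b, *, rtol, atol):  # stand-in for _check_tensors_close
        ok = abs(a - b) <= atol + rtol * abs(b)
        return None if ok else AssertionError(f"{a} and {b} are not close!")

    def check_by_type(actual, expected, check_tensors):  # stand-in for _check_by_type
        if isinstance(actual, (list, tuple)):
            for idx, (a, b) in enumerate(zip(actual, expected)):
                exc = check_tensors(a, b)
                if exc:
                    return type(exc)(f"{exc}\n\nThe failure occurred at index {idx} of the sequences.")
            return None
        return check_tensors(actual, expected)

    check = functools.partial(check_pair_close, rtol=1e-3, atol=0.0)
    exc = check_by_type([1.0, 1.0], [1.0, 2.0], check)
    if exc:
        raise exc  # AssertionError, amended with "index 1"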