Add support for checking tensor containers in torch.testing (#55385)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55385

This renames `assert_tensors_(equal|close)` to `_check_tensors_(equal|close)` and exposes two new functions: `assert_(equal|close)`. In addition to tensor pairs, the newly added functions also support the comparison of tensors in sequences or mappings. Otherwise their signature stays the same.

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D27903805

Pulled By: mruberry

fbshipit-source-id: 719d19a1d26de8d14cb25846e3d22a6ac828c80a
Committed by: Facebook GitHub Bot
Parent: bcef7ebd60
Commit: dbf3451c6e
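As a rough usage sketch of the container support described in the summary above (the tensor values and dictionary keys are invented for illustration and are not part of the diff below; only `torch.testing.assert_close` and its `check_dtype` flag come from this change), the new top-level functions accept sequences and mappings of tensors and compare them entry by entry:

    import torch

    # Mappings (and likewise tuples or lists) of tensors are compared pairwise by key/position.
    actual = {"logits": torch.tensor([1.0, 2.0]), "loss": torch.tensor(0.5)}
    expected = {"logits": torch.tensor([1.0, 2.0]), "loss": torch.tensor(0.5)}
    torch.testing.assert_close(actual, expected)  # passes; a mismatch raises AssertionError

    # Plain tensor pairs keep the previous behaviour, including the check_* toggles.
    torch.testing.assert_close(
        torch.ones(3, dtype=torch.float32),
        torch.ones(3, dtype=torch.float64),
        check_dtype=False,
    )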
@@ -1,11 +1,17 @@
import torch

import collections
import functools
import itertools
import math
import os
import random
import re
import unittest
from typing import Any, Callable, Iterator, List, Mapping, Sequence, Tuple, TypeVar

import numpy as np

import torch

from torch.testing._internal.common_utils import \
    (IS_SANDCASTLE, IS_WINDOWS, TestCase, make_tensor, run_tests, skipIfRocm, slowTest)
from torch.testing._internal.framework_utils import calculate_shards

@@ -740,17 +746,62 @@ if __name__ == '__main__':
        self.assertNotIn('OK', stderr.decode('ascii'))


T = TypeVar("T", torch.Tensor, Sequence[torch.Tensor], Mapping[Any, torch.Tensor])


class TestAsserts(TestCase):
    def assert_fns(self):
        return [torch.testing.assert_tensors_equal, torch.testing.assert_tensors_close]
    def get_assert_fns(self) -> List[Callable]:
        """Gets assert functions to be tested.

        Returns:
            List(Callable): Top-level assert functions from :mod:`torch.testing`.
        """
        return [torch.testing.assert_equal, torch.testing.assert_close]

    def make_inputs(self, actual: torch.Tensor, expected: torch.Tensor) -> List[Tuple[T, T]]:
        """Makes inputs for assert functions based on two example tensors.

        Args:
            actual (torch.Tensor): Actual tensor.
            expected (torch.Tensor): Expected tensor.

        Returns:
            List[Tuple[T, T]]: Pairs of tensors, tensor sequences (:class:`tuple`, :class:`list`), and tensor mappings
                (:class:`dict`, :class:`~collections.OrderedDict`)
        """
        return [
            (actual, expected),
            ((actual,), (expected,)),
            ([actual], [expected]),
            ({"t": actual}, {"t": expected}),
            (collections.OrderedDict([("t", actual)]), collections.OrderedDict([("t", expected)])),
        ]

    def assert_fns_with_inputs(self, actual: torch.Tensor, expected: torch.Tensor) -> Iterator[Callable]:
        """Yields assert functions with included positional inputs based on two example tensors.

        .. note::

            This is a valid product of combinations from :meth:`get_assert_fns` and :meth:`make_inputs`. Every test
            that does not test for anything specific should iterate over this to maximize the coverage.

        Args:
            actual (torch.Tensor): Actual tensor.
            expected (torch.Tensor): Expected tensor.

        Yields:
            List[Callable]: Assert functions with predefined positional inputs.
        """
        for assert_fn, inputs in itertools.product(self.get_assert_fns(), self.make_inputs(actual, expected)):
            yield functools.partial(assert_fn, *inputs)

    @onlyCPU
    def test_not_tensors(self, device):
        actual = torch.empty((), device=device)
        expected = np.empty(())

        for fn in self.assert_fns():
            with self.assertRaises(AssertionError):
        for fn in self.get_assert_fns():
            with self.assertRaises(UsageError):
                fn(actual, expected)

    @onlyCPU

@@ -758,18 +809,18 @@ class TestAsserts(TestCase):
        actual = torch.ones(1, dtype=torch.float32, device=device)
        expected = torch.ones(1, dtype=torch.complex64, device=device)

        for fn in self.assert_fns():
        for fn in self.assert_fns_with_inputs(actual, expected):
            with self.assertRaises(UsageError):
                fn(actual, expected, check_dtype=False)
                fn(check_dtype=False)

    @onlyCPU
    def test_sparse_support(self, device):
        actual = torch.empty((), device=device)
        expected = torch.sparse_coo_tensor(size=(), device=device)

        for fn in self.assert_fns():
        for fn in self.assert_fns_with_inputs(actual, expected):
            with self.assertRaises(UsageError):
                fn(actual, expected)
                fn()

    @onlyCPU
    def test_quantized_support(self, device):

@@ -778,221 +829,246 @@ class TestAsserts(TestCase):
        expected = torch._empty_affine_quantized(actual.shape, scale=1, zero_point=0, dtype=torch.qint32, device=device)
        expected.fill_(val)

        for fn in self.assert_fns():
        for fn in self.assert_fns_with_inputs(actual, expected):
            with self.assertRaises(UsageError):
                fn(actual, expected)
                fn()

    @onlyCPU
    def test_mismatching_shape(self, device):
        actual = torch.empty((), device=device)
        expected = actual.clone().reshape((1,))

        for fn in self.assert_fns():
        for fn in self.assert_fns_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, "shape"):
                fn(actual, expected)
                fn()

    @onlyCUDA
    def test_mismatching_device(self, device):
        actual = torch.empty((), device=device)
        expected = actual.clone().cpu()

        for fn in self.assert_fns():
        for fn in self.assert_fns_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, "device"):
                fn(actual, expected)
                fn()

    @onlyCUDA
    def test_mismatching_device_no_check(self, device):
        actual = torch.rand((), device=device)
        expected = actual.clone().cpu()

        for fn in self.assert_fns():
            fn(actual, expected, check_device=False)
        for fn in self.assert_fns_with_inputs(actual, expected):
            fn(check_device=False)

    @onlyCPU
    def test_mismatching_dtype(self, device):
        actual = torch.empty((), dtype=torch.float, device=device)
        expected = actual.clone().to(torch.int)

        for fn in self.assert_fns():
        for fn in self.assert_fns_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, "dtype"):
                fn(actual, expected)
                fn()

    @onlyCPU
    def test_mismatching_dtype_no_check(self, device):
        actual = torch.ones((), dtype=torch.float, device=device)
        expected = actual.clone().to(torch.int)

        for fn in self.assert_fns():
            fn(actual, expected, check_dtype=False)
        for fn in self.assert_fns_with_inputs(actual, expected):
            fn(check_dtype=False)

    @onlyCPU
    def test_mismatching_stride(self, device):
        actual = torch.empty((2, 2), device=device)
        expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])

        for fn in self.assert_fns():
        for fn in self.assert_fns_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, "stride"):
                fn(actual, expected)
                fn()

    @onlyCPU
    def test_mismatching_stride_no_check(self, device):
        actual = torch.rand((2, 2), device=device)
        expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])

        for fn in self.assert_fns():
            fn(actual, expected, check_stride=False)
        for fn in self.assert_fns_with_inputs(actual, expected):
            fn(check_stride=False)

    @onlyCPU
    def test_mismatching_values(self, device):
        actual = torch.tensor(1, device=device)
        expected = torch.tensor(2, device=device)

        for fn in self.assert_fns():
        for fn in self.assert_fns_with_inputs(actual, expected):
            with self.assertRaises(AssertionError):
                fn()

    @onlyCPU
    def test_assert_equal(self, device):
        actual = torch.tensor(1, device=device)
        expected = actual.clone()

        torch.testing.assert_equal(actual, expected)

    @onlyCPU
    def test_assert_close(self, device):
        actual = torch.tensor(1.0, device=device)
        expected = actual.clone()

        torch.testing.assert_close(actual, expected)

    @onlyCPU
    def test_assert_close_only_rtol(self, device):
        actual = torch.empty((), device=device)
        expected = actual.clone()

        with self.assertRaises(UsageError):
            torch.testing.assert_close(actual, expected, rtol=0.0)

    @onlyCPU
    def test_assert_close_only_atol(self, device):
        actual = torch.empty((), device=device)
        expected = actual.clone()

        with self.assertRaises(UsageError):
            torch.testing.assert_close(actual, expected, atol=0.0)

    @onlyCPU
    def test_assert_close_mismatching_values_rtol(self, device):
        eps = 1e-3
        actual = torch.tensor(1.0, device=device)
        expected = torch.tensor(1.0 + eps, device=device)

        with self.assertRaises(AssertionError):
            torch.testing.assert_close(actual, expected, rtol=eps / 2, atol=0.0)

    @onlyCPU
    def test_assert_close_matching_values_rtol(self, device):
        eps = 1e-3
        actual = torch.tensor(1.0, device=device)
        expected = torch.tensor(1.0 + eps, device=device)

        torch.testing.assert_close(actual, expected, rtol=eps * 2, atol=0.0)

    @onlyCPU
    def test_assert_close_mismatching_values_atol(self, device):
        eps = 1e-3
        actual = torch.tensor(0.0, device=device)
        expected = torch.tensor(eps, device=device)

        with self.assertRaises(AssertionError):
            torch.testing.assert_close(actual, expected, rtol=0.0, atol=eps / 2)

    @onlyCPU
    def test_assert_close_matching_values_atol(self, device):
        eps = 1e-3
        actual = torch.tensor(0.0, device=device)
        expected = torch.tensor(eps, device=device)

        torch.testing.assert_close(actual, expected, rtol=0.0, atol=eps * 2)

    @onlyCPU
    def test_mismatching_values_msg_mismatches(self, device):
        actual = torch.tensor([1, 2, 3, 4], device=device)
        expected = torch.tensor([1, 2, 5, 6], device=device)

        for fn in self.assert_fns_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Mismatched elements: 2 / 4 (50.0%)")):
                fn()

    @onlyCPU
    def test_mismatching_values_msg_abs_diff(self, device):
        actual = torch.tensor([[1, 2], [3, 4]], device=device)
        expected = torch.tensor([[1, 2], [5, 4]], device=device)

        for fn in self.assert_fns_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Greatest absolute difference: 2 at (1, 0)")):
                fn()

    @onlyCPU
    def test_mismatching_values_msg_rel_diff(self, device):
        actual = torch.tensor([[1, 2], [3, 4]], device=device)
        expected = torch.tensor([[1, 4], [3, 4]], device=device)

        for fn in self.assert_fns_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Greatest relative difference: 0.5 at (0, 1)")):
                fn()

    @onlyCPU
    def test_assert_close_mismatching_values_msg_rtol(self, device):
        rtol = 1e-3

        actual = torch.tensor(1, device=device)
        expected = torch.tensor(2, device=device)

        for inputs in self.make_inputs(actual, expected):
            with self.assertRaisesRegex(
                AssertionError, re.escape(f"Greatest relative difference: 0.5 at 0 (up to {rtol} allowed)")
            ):
                torch.testing.assert_close(*inputs, rtol=rtol, atol=0.0)

    @onlyCPU
    def test_assert_close_mismatching_values_msg_atol(self, device):
        atol = 1e-3

        actual = torch.tensor(1, device=device)
        expected = torch.tensor(2, device=device)

        for inputs in self.make_inputs(actual, expected):
            with self.assertRaisesRegex(
                AssertionError, re.escape(f"Greatest absolute difference: 1 at 0 (up to {atol} allowed)")
            ):
                torch.testing.assert_close(*inputs, rtol=0.0, atol=atol)

    @onlyCPU
    def test_unknown_type(self, device):
        actual = torch.empty((), device=device)
        expected = {actual.clone()}

        for fn in self.get_assert_fns():
            with self.assertRaisesRegex(UsageError, str(type(expected))):
                fn(actual, expected)

    @onlyCPU
    def test_sequence_mismatching_len(self, device):
        actual = (torch.empty((), device=device),)
        expected = ()

        for fn in self.get_assert_fns():
            with self.assertRaises(AssertionError):
                fn(actual, expected)

    @onlyCPU
    def test_mismatching_values_msg_abs_mismatches(self, device):
        actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
        expected = actual.clone()
    def test_sequence_mismatching_values_msg(self, device):
        t1 = torch.tensor(1, device=device)
        t2 = torch.tensor(2, device=device)

        actual[0, 1] = 1
        expected[0, 1] = 2
        expected[1, 2] = 9
        actual = (t1, t1)
        expected = (t1, t2)

        for fn in self.assert_fns():
            with self.assertRaisesRegex(AssertionError, r"\s+2\s+"):
        for fn in self.get_assert_fns():
            with self.assertRaisesRegex(AssertionError, r"index\s+1"):
                fn(actual, expected)

    @onlyCPU
    def test_mismatching_values_msg_rel_mismatches(self, device):
        actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
        expected = actual.clone()
    def test_mapping_mismatching_keys(self, device):
        actual = {"a": torch.empty((), device=device)}
        expected = {}

        actual[0, 1] = 1
        expected[0, 1] = 2
        expected[1, 2] = 9

        for fn in self.assert_fns():
            with self.assertRaisesRegex(AssertionError, r"22([.]2+)?\s*[%]"):
        for fn in self.get_assert_fns():
            with self.assertRaises(AssertionError):
                fn(actual, expected)

    @onlyCPU
    def test_mismatching_values_msg_max_abs_diff(self, device):
        actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
        expected = actual.clone()
    def test_mapping_mismatching_values_msg(self, device):
        t1 = torch.tensor(1, device=device)
        t2 = torch.tensor(2, device=device)

        actual[0, 1] = 1
        expected[0, 1] = 2
        expected[1, 2] = 9
        actual = {"a": t1, "b": t1}
        expected = {"a": t1, "b": t2}

        for fn in self.assert_fns():
            with self.assertRaisesRegex(AssertionError, r"\s+4[.]0\s+"):
        for fn in self.get_assert_fns():
            with self.assertRaisesRegex(AssertionError, r"key\s+'b'"):
                fn(actual, expected)

    @onlyCPU
    def test_mismatching_values_max_abs_diff_idx(self, device):
        actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
        expected = actual.clone()

        actual[0, 1] = 1
        expected[0, 1] = 2
        expected[1, 2] = 9

        for fn in self.assert_fns():
            with self.assertRaisesRegex(AssertionError, r"1,\s*2"):
                fn(actual, expected)

    @onlyCPU
    def test_mismatching_values_msg_max_rel_diff(self, device):
        actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
        expected = actual.clone()

        actual[0, 1] = 1
        expected[0, 1] = 2
        expected[1, 2] = 9

        for fn in self.assert_fns():
            with self.assertRaisesRegex(AssertionError, r"\s+0[.]5\s+"):
                fn(actual, expected)

    @onlyCPU
    def test_mismatching_values_max_rel_diff_idx(self, device):
        actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
        expected = actual.clone()

        actual[0, 1] = 1
        expected[0, 1] = 2
        expected[1, 2] = 9

        for fn in self.assert_fns():
            with self.assertRaisesRegex(AssertionError, r"0,\s*1"):
                fn(actual, expected)

    @onlyCPU
    def test_assert_tensors_equal(self, device):
        actual = torch.tensor(1, device=device)
        expected = actual.clone()

        torch.testing.assert_tensors_equal(actual, expected)

    @onlyCPU
    def test_assert_tensors_close(self, device):
        actual = torch.tensor(1.0, device=device)
        expected = actual.clone()

        torch.testing.assert_tensors_close(actual, expected)

    @onlyCPU
    def test_assert_tensors_close_only_rtol(self, device):
        actual = torch.empty((), device=device)
        expected = actual.clone()

        with self.assertRaises(UsageError):
            torch.testing.assert_tensors_close(actual, expected, rtol=0.0)

    @onlyCPU
    def test_assert_tensors_close_only_atol(self, device):
        actual = torch.empty((), device=device)
        expected = actual.clone()

        with self.assertRaises(UsageError):
            torch.testing.assert_tensors_close(actual, expected, atol=0.0)

    @onlyCPU
    def test_assert_tensors_close_mismatching_values_rtol(self, device):
        eps = 1e-3
        actual = torch.tensor(1.0, device=device)
        expected = torch.tensor(1.0 + eps, device=device)

        with self.assertRaises(AssertionError):
            torch.testing.assert_tensors_close(actual, expected, rtol=eps / 2, atol=0.0)

    @onlyCPU
    def test_assert_tensors_close_matching_values_rtol(self, device):
        eps = 1e-3
        actual = torch.tensor(1.0, device=device)
        expected = torch.tensor(1.0 + eps, device=device)

        torch.testing.assert_tensors_close(actual, expected, rtol=eps * 2, atol=0.0)

    @onlyCPU
    def test_assert_tensors_close_mismatching_values_atol(self, device):
        eps = 1e-3
        actual = torch.tensor(0.0, device=device)
        expected = torch.tensor(eps, device=device)

        with self.assertRaises(AssertionError):
            torch.testing.assert_tensors_close(actual, expected, rtol=0.0, atol=eps / 2)

    @onlyCPU
    def test_assert_tensors_close_matching_values_atol(self, device):
        eps = 1e-3
        actual = torch.tensor(0.0, device=device)
        expected = torch.tensor(eps, device=device)

        torch.testing.assert_tensors_close(actual, expected, rtol=0.0, atol=eps * 2)


instantiate_device_type_tests(TestAsserts, globals())