Add support for checking tensor containers in torch.testing (#55385)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55385

This renames `assert_tensors_(equal|close)` to `_check_tensors_(equal|close)` and exposes two new functions: `assert_(equal|close)`. In addition to tensor pairs, the newly added functions also support the comparison of tensors in sequences or mappings. Otherwise their signature stays the same.

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D27903805

Pulled By: mruberry

fbshipit-source-id: 719d19a1d26de8d14cb25846e3d22a6ac828c80a
This commit is contained in:
Philip Meier
2021-04-24 23:35:25 -07:00
committed by Facebook GitHub Bot
parent bcef7ebd60
commit dbf3451c6e
2 changed files with 558 additions and 260 deletions

View File

@ -1,11 +1,17 @@
import torch
import collections
import functools
import itertools
import math
import os
import random
import re
import unittest
from typing import Any, Callable, Iterator, List, Mapping, Sequence, Tuple, TypeVar
import numpy as np
import torch
from torch.testing._internal.common_utils import \
(IS_SANDCASTLE, IS_WINDOWS, TestCase, make_tensor, run_tests, skipIfRocm, slowTest)
from torch.testing._internal.framework_utils import calculate_shards
@ -740,17 +746,62 @@ if __name__ == '__main__':
self.assertNotIn('OK', stderr.decode('ascii'))
T = TypeVar("T", torch.Tensor, Sequence[torch.Tensor], Mapping[Any, torch.Tensor])
class TestAsserts(TestCase):
def assert_fns(self):
return [torch.testing.assert_tensors_equal, torch.testing.assert_tensors_close]
def get_assert_fns(self) -> List[Callable]:
"""Gets assert functions to be tested.
Returns:
List(Callable): Top-level assert functions from :mod:`torch.testing`.
"""
return [torch.testing.assert_equal, torch.testing.assert_close]
def make_inputs(self, actual: torch.Tensor, expected: torch.Tensor) -> List[Tuple[T, T]]:
"""Makes inputs for assert functions based on two example tensors.
Args:
actual (torch.Tensor): Actual tensor.
expected (torch.Tensor): Expected tensor.
Returns:
List[Tuple[T, T]]: Pairs of tensors, tensor sequences (:class:`tuple`, :class:`list`), and tensor mappings
(:class:`dict`, :class:`~collections.OrderedDict`)
"""
return [
(actual, expected),
((actual,), (expected,)),
([actual], [expected]),
({"t": actual}, {"t": expected}),
(collections.OrderedDict([("t", actual)]), collections.OrderedDict([("t", expected)])),
]
def assert_fns_with_inputs(self, actual: torch.Tensor, expected: torch.Tensor) -> Iterator[Callable]:
"""Yields assert functions with with included positional inputs based on two example tensors.
.. note::
This is a valid product of combinations from :meth:`get_assert_fns` and :meth:`make_inputs`. Every test
that does not test for anything specific should iterate over this to maximize the coverage.
Args:
actual (torch.Tensor): Actual tensor.
expected (torch.Tensor): Expected tensor.
Yields:
List[Callable]: Assert functions with predefined positional inputs.
"""
for assert_fn, inputs in itertools.product(self.get_assert_fns(), self.make_inputs(actual, expected)):
yield functools.partial(assert_fn, *inputs)
@onlyCPU
def test_not_tensors(self, device):
actual = torch.empty((), device=device)
expected = np.empty(())
for fn in self.assert_fns():
with self.assertRaises(AssertionError):
for fn in self.get_assert_fns():
with self.assertRaises(UsageError):
fn(actual, expected)
@onlyCPU
@ -758,18 +809,18 @@ class TestAsserts(TestCase):
actual = torch.ones(1, dtype=torch.float32, device=device)
expected = torch.ones(1, dtype=torch.complex64, device=device)
for fn in self.assert_fns():
for fn in self.assert_fns_with_inputs(actual, expected):
with self.assertRaises(UsageError):
fn(actual, expected, check_dtype=False)
fn(check_dtype=False)
@onlyCPU
def test_sparse_support(self, device):
actual = torch.empty((), device=device)
expected = torch.sparse_coo_tensor(size=(), device=device)
for fn in self.assert_fns():
for fn in self.assert_fns_with_inputs(actual, expected):
with self.assertRaises(UsageError):
fn(actual, expected)
fn()
@onlyCPU
def test_quantized_support(self, device):
@ -778,221 +829,246 @@ class TestAsserts(TestCase):
expected = torch._empty_affine_quantized(actual.shape, scale=1, zero_point=0, dtype=torch.qint32, device=device)
expected.fill_(val)
for fn in self.assert_fns():
for fn in self.assert_fns_with_inputs(actual, expected):
with self.assertRaises(UsageError):
fn(actual, expected)
fn()
@onlyCPU
def test_mismatching_shape(self, device):
actual = torch.empty((), device=device)
expected = actual.clone().reshape((1,))
for fn in self.assert_fns():
for fn in self.assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, "shape"):
fn(actual, expected)
fn()
@onlyCUDA
def test_mismatching_device(self, device):
actual = torch.empty((), device=device)
expected = actual.clone().cpu()
for fn in self.assert_fns():
for fn in self.assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, "device"):
fn(actual, expected)
fn()
@onlyCUDA
def test_mismatching_device_no_check(self, device):
actual = torch.rand((), device=device)
expected = actual.clone().cpu()
for fn in self.assert_fns():
fn(actual, expected, check_device=False)
for fn in self.assert_fns_with_inputs(actual, expected):
fn(check_device=False)
@onlyCPU
def test_mismatching_dtype(self, device):
actual = torch.empty((), dtype=torch.float, device=device)
expected = actual.clone().to(torch.int)
for fn in self.assert_fns():
for fn in self.assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, "dtype"):
fn(actual, expected)
fn()
@onlyCPU
def test_mismatching_dtype_no_check(self, device):
actual = torch.ones((), dtype=torch.float, device=device)
expected = actual.clone().to(torch.int)
for fn in self.assert_fns():
fn(actual, expected, check_dtype=False)
for fn in self.assert_fns_with_inputs(actual, expected):
fn(check_dtype=False)
@onlyCPU
def test_mismatching_stride(self, device):
actual = torch.empty((2, 2), device=device)
expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])
for fn in self.assert_fns():
for fn in self.assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, "stride"):
fn(actual, expected)
fn()
@onlyCPU
def test_mismatching_stride_no_check(self, device):
actual = torch.rand((2, 2), device=device)
expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])
for fn in self.assert_fns():
fn(actual, expected, check_stride=False)
for fn in self.assert_fns_with_inputs(actual, expected):
fn(check_stride=False)
@onlyCPU
def test_mismatching_values(self, device):
actual = torch.tensor(1, device=device)
expected = torch.tensor(2, device=device)
for fn in self.assert_fns():
for fn in self.assert_fns_with_inputs(actual, expected):
with self.assertRaises(AssertionError):
fn()
@onlyCPU
def test_assert_equal(self, device):
actual = torch.tensor(1, device=device)
expected = actual.clone()
torch.testing.assert_equal(actual, expected)
@onlyCPU
def test_assert_close(self, device):
actual = torch.tensor(1.0, device=device)
expected = actual.clone()
torch.testing.assert_close(actual, expected)
@onlyCPU
def test_assert_close_only_rtol(self, device):
actual = torch.empty((), device=device)
expected = actual.clone()
with self.assertRaises(UsageError):
torch.testing.assert_close(actual, expected, rtol=0.0)
@onlyCPU
def test_assert_close_only_atol(self, device):
actual = torch.empty((), device=device)
expected = actual.clone()
with self.assertRaises(UsageError):
torch.testing.assert_close(actual, expected, atol=0.0)
@onlyCPU
def test_assert_close_mismatching_values_rtol(self, device):
eps = 1e-3
actual = torch.tensor(1.0, device=device)
expected = torch.tensor(1.0 + eps, device=device)
with self.assertRaises(AssertionError):
torch.testing.assert_close(actual, expected, rtol=eps / 2, atol=0.0)
@onlyCPU
def test_assert_close_matching_values_rtol(self, device):
eps = 1e-3
actual = torch.tensor(1.0, device=device)
expected = torch.tensor(1.0 + eps, device=device)
torch.testing.assert_close(actual, expected, rtol=eps * 2, atol=0.0)
@onlyCPU
def test_assert_close_mismatching_values_atol(self, device):
eps = 1e-3
actual = torch.tensor(0.0, device=device)
expected = torch.tensor(eps, device=device)
with self.assertRaises(AssertionError):
torch.testing.assert_close(actual, expected, rtol=0.0, atol=eps / 2)
@onlyCPU
def test_assert_close_matching_values_atol(self, device):
eps = 1e-3
actual = torch.tensor(0.0, device=device)
expected = torch.tensor(eps, device=device)
torch.testing.assert_close(actual, expected, rtol=0.0, atol=eps * 2)
@onlyCPU
def test_mismatching_values_msg_mismatches(self, device):
actual = torch.tensor([1, 2, 3, 4], device=device)
expected = torch.tensor([1, 2, 5, 6], device=device)
for fn in self.assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("Mismatched elements: 2 / 4 (50.0%)")):
fn()
@onlyCPU
def test_mismatching_values_msg_abs_diff(self, device):
actual = torch.tensor([[1, 2], [3, 4]], device=device)
expected = torch.tensor([[1, 2], [5, 4]], device=device)
for fn in self.assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("Greatest absolute difference: 2 at (1, 0)")):
fn()
@onlyCPU
def test_mismatching_values_msg_rel_diff(self, device):
actual = torch.tensor([[1, 2], [3, 4]], device=device)
expected = torch.tensor([[1, 4], [3, 4]], device=device)
for fn in self.assert_fns_with_inputs(actual, expected):
with self.assertRaisesRegex(AssertionError, re.escape("Greatest relative difference: 0.5 at (0, 1)")):
fn()
@onlyCPU
def test_assert_close_mismatching_values_msg_rtol(self, device):
rtol = 1e-3
actual = torch.tensor(1, device=device)
expected = torch.tensor(2, device=device)
for inputs in self.make_inputs(actual, expected):
with self.assertRaisesRegex(
AssertionError, re.escape(f"Greatest relative difference: 0.5 at 0 (up to {rtol} allowed)")
):
torch.testing.assert_close(*inputs, rtol=rtol, atol=0.0)
@onlyCPU
def test_assert_close_mismatching_values_msg_atol(self, device):
atol = 1e-3
actual = torch.tensor(1, device=device)
expected = torch.tensor(2, device=device)
for inputs in self.make_inputs(actual, expected):
with self.assertRaisesRegex(
AssertionError, re.escape(f"Greatest absolute difference: 1 at 0 (up to {atol} allowed)")
):
torch.testing.assert_close(*inputs, rtol=0.0, atol=atol)
@onlyCPU
def test_unknown_type(self, device):
actual = torch.empty((), device=device)
expected = {actual.clone()}
for fn in self.get_assert_fns():
with self.assertRaisesRegex(UsageError, str(type(expected))):
fn(actual, expected)
@onlyCPU
def test_sequence_mismatching_len(self, device):
actual = (torch.empty((), device=device),)
expected = ()
for fn in self.get_assert_fns():
with self.assertRaises(AssertionError):
fn(actual, expected)
@onlyCPU
def test_mismatching_values_msg_abs_mismatches(self, device):
actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
expected = actual.clone()
def test_sequence_mismatching_values_msg(self, device):
t1 = torch.tensor(1, device=device)
t2 = torch.tensor(2, device=device)
actual[0, 1] = 1
expected[0, 1] = 2
expected[1, 2] = 9
actual = (t1, t1)
expected = (t1, t2)
for fn in self.assert_fns():
with self.assertRaisesRegex(AssertionError, r"\s+2\s+"):
for fn in self.get_assert_fns():
with self.assertRaisesRegex(AssertionError, r"index\s+1"):
fn(actual, expected)
@onlyCPU
def test_mismatching_values_msg_rel_mismatches(self, device):
actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
expected = actual.clone()
def test_mapping_mismatching_keys(self, device):
actual = {"a": torch.empty((), device=device)}
expected = {}
actual[0, 1] = 1
expected[0, 1] = 2
expected[1, 2] = 9
for fn in self.assert_fns():
with self.assertRaisesRegex(AssertionError, r"22([.]2+)?\s*[%]"):
for fn in self.get_assert_fns():
with self.assertRaises(AssertionError):
fn(actual, expected)
@onlyCPU
def test_mismatching_values_msg_max_abs_diff(self, device):
actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
expected = actual.clone()
def test_mapping_mismatching_values_msg(self, device):
t1 = torch.tensor(1, device=device)
t2 = torch.tensor(2, device=device)
actual[0, 1] = 1
expected[0, 1] = 2
expected[1, 2] = 9
actual = {"a": t1, "b": t1}
expected = {"a": t1, "b": t2}
for fn in self.assert_fns():
with self.assertRaisesRegex(AssertionError, r"\s+4[.]0\s+"):
for fn in self.get_assert_fns():
with self.assertRaisesRegex(AssertionError, r"key\s+'b'"):
fn(actual, expected)
@onlyCPU
def test_mismatching_values_max_abs_diff_idx(self, device):
actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
expected = actual.clone()
actual[0, 1] = 1
expected[0, 1] = 2
expected[1, 2] = 9
for fn in self.assert_fns():
with self.assertRaisesRegex(AssertionError, r"1,\s*2"):
fn(actual, expected)
@onlyCPU
def test_mismatching_values_msg_max_rel_diff(self, device):
actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
expected = actual.clone()
actual[0, 1] = 1
expected[0, 1] = 2
expected[1, 2] = 9
for fn in self.assert_fns():
with self.assertRaisesRegex(AssertionError, r"\s+0[.]5\s+"):
fn(actual, expected)
@onlyCPU
def test_mismatching_values_max_rel_diff_idx(self, device):
actual = torch.full((3, 3), 5, dtype=torch.float32, device=device)
expected = actual.clone()
actual[0, 1] = 1
expected[0, 1] = 2
expected[1, 2] = 9
for fn in self.assert_fns():
with self.assertRaisesRegex(AssertionError, r"0,\s*1"):
fn(actual, expected)
@onlyCPU
def test_assert_tensors_equal(self, device):
actual = torch.tensor(1, device=device)
expected = actual.clone()
torch.testing.assert_tensors_equal(actual, expected)
@onlyCPU
def test_assert_tensors_close(self, device):
actual = torch.tensor(1.0, device=device)
expected = actual.clone()
torch.testing.assert_tensors_close(actual, expected)
@onlyCPU
def test_assert_tensors_close_only_rtol(self, device):
actual = torch.empty((), device=device)
expected = actual.clone()
with self.assertRaises(UsageError):
torch.testing.assert_tensors_close(actual, expected, rtol=0.0)
@onlyCPU
def test_assert_tensors_close_only_atol(self, device):
actual = torch.empty((), device=device)
expected = actual.clone()
with self.assertRaises(UsageError):
torch.testing.assert_tensors_close(actual, expected, atol=0.0)
@onlyCPU
def test_assert_tensors_close_mismatching_values_rtol(self, device):
eps = 1e-3
actual = torch.tensor(1.0, device=device)
expected = torch.tensor(1.0 + eps, device=device)
with self.assertRaises(AssertionError):
torch.testing.assert_tensors_close(actual, expected, rtol=eps / 2, atol=0.0)
@onlyCPU
def test_assert_tensors_close_matching_values_rtol(self, device):
eps = 1e-3
actual = torch.tensor(1.0, device=device)
expected = torch.tensor(1.0 + eps, device=device)
torch.testing.assert_tensors_close(actual, expected, rtol=eps * 2, atol=0.0)
@onlyCPU
def test_assert_tensors_close_mismatching_values_atol(self, device):
eps = 1e-3
actual = torch.tensor(0.0, device=device)
expected = torch.tensor(eps, device=device)
with self.assertRaises(AssertionError):
torch.testing.assert_tensors_close(actual, expected, rtol=0.0, atol=eps / 2)
@onlyCPU
def test_assert_tensors_close_matching_values_atol(self, device):
eps = 1e-3
actual = torch.tensor(0.0, device=device)
expected = torch.tensor(eps, device=device)
torch.testing.assert_tensors_close(actual, expected, rtol=0.0, atol=eps * 2)
instantiate_device_type_tests(TestAsserts, globals())