mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Changes including: - introduced `linter/`, `testing/`, `stats/` folders in `tools/` - move appropriate scripts into these folders - change grepped references in the pytorch/pytorch repo Next step - introduce `build/` folder for build scripts Pull Request resolved: https://github.com/pytorch/pytorch/pull/60473 Test Plan: - CI (this is important b/c pytorch/test-infra also relies on some script references). - tools/tests/ Reviewed By: albanD Differential Revision: D29352716 Pulled By: walterddr fbshipit-source-id: bad40b5ce130b35dfd9e59b8af34f9025f3285fd
1256 lines
49 KiB
Python
1256 lines
49 KiB
Python
import collections
|
|
import functools
|
|
import itertools
|
|
import math
|
|
import os
|
|
import re
|
|
import unittest
|
|
from typing import Any, Callable, Iterator, List, Tuple
|
|
|
|
import torch
|
|
|
|
from torch.testing._internal.common_utils import \
|
|
(IS_FBCODE, IS_SANDCASTLE, IS_WINDOWS, TestCase, make_tensor, run_tests, skipIfRocm, slowTest)
|
|
from torch.testing._internal.common_device_type import \
|
|
(PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes,
|
|
get_device_type_test_bases, instantiate_device_type_tests, onlyCUDA, onlyOnCPUAndCUDA,
|
|
deviceCountAtLeast)
|
|
from torch.testing._internal.common_methods_invocations import op_db
|
|
import torch.testing._internal.opinfo_helper as opinfo_helper
|
|
from torch.testing._asserts import UsageError
|
|
|
|
# For testing TestCase methods and torch.testing functions
|
|
class TestTesting(TestCase):
|
|
# Ensure that assertEqual handles numpy arrays properly
|
|
@dtypes(*torch.testing.get_all_dtypes(
    include_half=True, include_bfloat16=False, include_bool=True, include_complex=True))
def test_assertEqual_numpy(self, device, dtype):
    """Check that assertEqual accepts tensor and numpy operands in either order."""
    side = 10
    # 0-d, 1-d and 2-d shapes, including shapes with a zero-length dimension.
    shapes = ((), (0,), (side,), (side, side), (0, side), (side, 0))
    exact = dict(rtol=0, atol=0)

    for shape in shapes:
        tensor = make_tensor(shape, device, dtype, low=-5, high=5)
        array = tensor.cpu().numpy()
        label = f'size: {shape}'
        # numpy vs. tensor, tensor vs. numpy, numpy vs. numpy
        for lhs, rhs in ((array, tensor), (tensor, array), (array, array)):
            self.assertEqual(lhs, rhs, msg=label, **exact)
|
|
|
|
# Tests that when rtol or atol (including self.precision) is set, then
|
|
# the other is zeroed.
|
|
# TODO: this is legacy behavior and should be updated after test
|
|
# precisions are reviewed to be consistent with torch.isclose.
|
|
@onlyOnCPUAndCUDA
def test__comparetensors_legacy(self, device):
    """Run _compareTensors and _compareScalars with default and explicit tolerances."""
    big_lo = torch.tensor((10000000.,))
    big_hi = torch.tensor((10000002.,))

    one = torch.tensor((1.,))
    one_plus = torch.tensor((1. + 1e-5,))

    def compare_as_scalars(a, b, rtol=None, atol=None):
        # Reuse the single-element tensors as Python scalars.
        return self._compareScalars(a.item(), b.item(), rtol=rtol, atol=atol)

    # (left, right, tolerance keyword arguments, whether the comparison should succeed)
    scenarios = (
        # default tolerances
        (big_lo, big_hi, {}, True),
        # atol large enough
        (big_lo, big_hi, dict(atol=2, rtol=0), True),
        # atol too small
        (big_lo, big_hi, dict(atol=1, rtol=0), False),
        # rtol large enough
        (one, one_plus, dict(atol=0, rtol=1.05e-5), True),
        # rtol too small
        (one, one_plus, dict(atol=0, rtol=1e-5), False),
    )

    for compare in (self._compareTensors, compare_as_scalars):
        for lhs, rhs, tolerances, should_match in scenarios:
            result, _ = compare(lhs, rhs, **tolerances)
            verify = self.assertTrue if should_match else self.assertFalse
            verify(result)
|
|
|
|
@onlyOnCPUAndCUDA
def test__comparescalars_debug_msg(self, device):
    """Check the debug message _compareScalars returns for mismatching scalars."""
    # (positional arguments, keyword arguments, expected debug message)
    scenarios = (
        # float x float
        ((4., 7.), {},
         "Comparing 4.0 and 7.0 gives a difference of 3.0, "
         "but the allowed difference with rtol=1.3e-06 and "
         "atol=1e-05 is only 1.9100000000000003e-05!"),
        # complex x complex, real difference
        ((complex(1, 3), complex(3, 1)), {},
         "Comparing the real part 1.0 and 3.0 gives a difference "
         "of 2.0, but the allowed difference with rtol=1.3e-06 "
         "and atol=1e-05 is only 1.39e-05!"),
        # complex x complex, imaginary difference
        ((complex(1, 3), complex(1, 5.5)), {},
         "Comparing the imaginary part 3.0 and 5.5 gives a "
         "difference of 2.5, but the allowed difference with "
         "rtol=1.3e-06 and atol=1e-05 is only 1.715e-05!"),
        # complex x int
        ((complex(1, -2), 1), {},
         "Comparing the imaginary part -2.0 and 0.0 gives a "
         "difference of 2.0, but the allowed difference with "
         "rtol=1.3e-06 and atol=1e-05 is only 1e-05!"),
        # NaN x NaN, equal_nan=False
        ((float('nan'), float('nan')), dict(equal_nan=False),
         "Found nan and nan while comparing and either one is "
         "nan and the other isn't, or both are nan and equal_nan "
         "is False"),
    )

    for args, kwargs, expected_msg in scenarios:
        _, debug_msg = self._compareScalars(*args, **kwargs)
        self.assertEqual(debug_msg, expected_msg)
|
|
|
|
# Checks that compareTensors provides the correct debug info
|
|
@onlyOnCPUAndCUDA
def test__comparetensors_debug_msg(self, device):
    """Check the debug message _compareTensors returns for mismatching tensors."""
    # Absolute tolerance the messages are expected to report
    atol = max(1e-05, self.precision)

    def check(lhs, rhs, expected_msg, **compare_kwargs):
        # Only the debug message is inspected; the boolean result is ignored.
        _, debug_msg = self._compareTensors(lhs, rhs, **compare_kwargs)
        self.assertEqual(debug_msg, expected_msg)

    def float_tensor(data):
        return torch.tensor(data, device=device, dtype=torch.float32)

    inf, nan = float('inf'), float('nan')

    # float tensors (2D)
    check(float_tensor(((0, 6), (7, 9))),
          float_tensor(((0, 7), (7, 22))),
          f"With rtol=1.3e-06 and atol={atol}, found 2 element(s) (out of 4) "
          "whose difference(s) exceeded the margin of error (including 0 nan comparisons). "
          "The greatest difference was 13.0 (9.0 vs. 22.0), "
          "which occurred at index (1, 1).")

    # float tensors with extremal values
    check(float_tensor((inf, 5, inf)),
          float_tensor((inf, nan, -inf)),
          f"With rtol=1.3e-06 and atol={atol}, found 2 element(s) (out of 3) "
          "whose difference(s) exceeded the margin of error (including 1 nan comparisons). "
          "The greatest difference was nan (5.0 vs. nan), "
          "which occurred at index 1.")

    # float tensors with finite vs. nan differences
    check(float_tensor((20, -6)),
          float_tensor((-1, nan)),
          f"With rtol=1.3e-06 and atol={atol}, found 2 element(s) (out of 2) "
          "whose difference(s) exceeded the margin of error (including 1 nan comparisons). "
          "The greatest difference was nan (-6.0 vs. nan), "
          "which occurred at index 1.")

    # int tensors (1D)
    check(torch.tensor((1, 2, 3, 4), device=device),
          torch.tensor((2, 5, 3, 4), device=device),
          "Found 2 different element(s) (out of 4), "
          "with the greatest difference of 3 (2 vs. 5) "
          "occuring at index 1.")

    # bool tensors (0D)
    check(torch.tensor(True, device=device),
          torch.tensor(False, device=device),
          "Found 1 different element(s) (out of 1), "
          "with the greatest difference of 1 (1 vs. 0) "
          "occuring at index 0.")

    # complex tensors differing in the real part
    check(torch.tensor((1 - 1j, 4 + 3j), device=device),
          torch.tensor((1 - 1j, 1 + 3j), device=device),
          "Real parts failed to compare as equal! "
          f"With rtol=1.3e-06 and atol={atol}, "
          "found 1 element(s) (out of 2) whose difference(s) exceeded the "
          "margin of error (including 0 nan comparisons). The greatest difference was "
          "3.0 (4.0 vs. 1.0), which occurred at index 1.")

    # complex tensors differing in the imaginary part
    check(torch.tensor((1 - 1j, 4 + 3j), device=device),
          torch.tensor((1 - 1j, 4 - 21j), device=device),
          "Imaginary parts failed to compare as equal! "
          f"With rtol=1.3e-06 and atol={atol}, "
          "found 1 element(s) (out of 2) whose difference(s) exceeded the "
          "margin of error (including 0 nan comparisons). The greatest difference was "
          "24.0 (3.0 vs. -21.0), which occurred at index 1.")

    # size mismatch
    check(torch.tensor((1, 2), device=device),
          torch.tensor(3, device=device),
          "Attempted to compare equality of tensors "
          "with different sizes. Got sizes torch.Size([2]) and torch.Size([]).")

    # dtype mismatch
    check(torch.tensor((1, 2), device=device, dtype=torch.long),
          torch.tensor((1, 2), device=device, dtype=torch.float32),
          "Attempted to compare equality of tensors "
          "with different dtypes. Got dtypes torch.int64 and torch.float32.",
          exact_dtype=True)

    # device mismatch; only checked when the test itself runs on CUDA
    if self.device_type == 'cuda':
        check(torch.tensor(5, device='cpu'),
              torch.tensor(5, device=device),
              "Attempted to compare equality of tensors "
              "on different devices! Got devices cpu and cuda:0.",
              exact_device=True)
|
|
|
|
# Helper for testing _compareTensors and _compareScalars
|
|
# Works on single element tensors
|
|
def _comparetensors_helper(self, tests, device, dtype, equal_nan, exact_dtype=True, atol=1e-08, rtol=1e-05):
    """Run every case through _compareTensors and _compareScalars.

    Each entry of ``tests`` is indexed as ``(a, b, expected_result)``. Both
    values are wrapped in single-element tensors of ``dtype`` on ``device``.
    """
    tolerances = dict(rtol=rtol, atol=atol, equal_nan=equal_nan)

    for case in tests:
        lhs = torch.tensor((case[0],), device=device, dtype=dtype)
        rhs = torch.tensor((case[1],), device=device, dtype=dtype)

        # Tensor x Tensor comparison
        tensors_match, _ = self._compareTensors(lhs, rhs, exact_dtype=exact_dtype, **tolerances)
        self.assertEqual(tensors_match, case[2])

        # Scalar x Scalar comparison
        scalars_match, _ = self._compareScalars(lhs.item(), rhs.item(), **tolerances)
        self.assertEqual(scalars_match, case[2])
|
|
|
|
def _isclose_helper(self, tests, device, dtype, equal_nan, atol=1e-08, rtol=1e-05):
    """Check torch.isclose on single-element tensors built from each case.

    Each entry of ``tests`` is indexed as ``(a, b, expected_result)``.
    """
    for case in tests:
        lhs = torch.tensor((case[0],), device=device, dtype=dtype)
        rhs = torch.tensor((case[1],), device=device, dtype=dtype)

        outcome = torch.isclose(lhs, rhs, equal_nan=equal_nan, atol=atol, rtol=rtol)
        self.assertEqual(outcome.item(), case[2])
|
|
|
|
# torch.isclose is not implemented for bool tensors
|
|
# see https://github.com/pytorch/pytorch/issues/33048
|
|
def test_isclose_comparetensors_bool(self, device):
    """Bool inputs: the isclose helper must raise, the compareTensors helper must pass."""
    cases = (
        (True, True, True),
        (False, False, True),
        (True, False, False),
        (False, True, False),
    )

    # The isclose-based helper is expected to fail with a RuntimeError on bool tensors.
    with self.assertRaises(RuntimeError):
        self._isclose_helper(cases, device, torch.bool, False)

    self._comparetensors_helper(cases, device, torch.bool, False)
|
|
|
|
@dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
def test_isclose_comparetensors_integer(self, device, dtype):
    """Check isclose and the compare helpers on integer dtypes."""
    exact_cases = (
        (0, 0, True),
        (0, 1, False),
        (1, 0, False),
    )
    self._isclose_helper(exact_cases, device, dtype, False)

    # atol and rtol tests
    half = dict(atol=.5, rtol=.5)
    tolerance_cases = [
        (0, 1, True),
        (1, 0, False),
        (1, 3, True),
    ]
    self._isclose_helper(tolerance_cases, device, dtype, False, **half)
    self._comparetensors_helper(tolerance_cases, device, dtype, False, **half)

    # Cases built from -1: expected to be "not close" for uint8 and "close" for signed dtypes.
    signed = dtype is not torch.uint8
    negative_cases = [
        (-1, 1, signed),
        (1, -1, signed),
    ]
    wide = dict(atol=1.5, rtol=.5)
    self._isclose_helper(negative_cases, device, dtype, False, **wide)
    self._comparetensors_helper(negative_cases, device, dtype, False, **wide)
|
|
|
|
@onlyOnCPUAndCUDA
@dtypes(torch.float16, torch.float32, torch.float64)
def test_isclose_comparetensors_float(self, device, dtype):
    """Check isclose and the compare helpers on floating point dtypes."""
    inf, nan = float('inf'), float('nan')

    default_cases = (
        (0, 0, True),
        (0, -1, False),
        (inf, inf, True),
        (-inf, inf, False),
        (inf, nan, False),
        (nan, nan, False),
        (0, nan, False),
        (1, 1, True),
    )
    self._isclose_helper(default_cases, device, dtype, False)
    self._comparetensors_helper(default_cases, device, dtype, False)

    # atol and rtol tests; half precision uses a coarser perturbation
    eps = 1e-2 if dtype is torch.half else 1e-6
    half = dict(atol=.5, rtol=.5)
    tolerance_cases = (
        (0, 1, True),
        (0, 1 + eps, False),
        (1, 0, False),
        (1, 3, True),
        (1 - eps, 3, False),
        (-.25, .5, True),
        (-.25 - eps, .5, False),
        (.25, -.5, True),
        (.25 + eps, -.5, False),
    )
    self._isclose_helper(tolerance_cases, device, dtype, False, **half)
    self._comparetensors_helper(tolerance_cases, device, dtype, False, **half)

    # equal_nan = True tests
    nan_cases = (
        (0, nan, False),
        (inf, nan, False),
        (nan, nan, True),
    )
    self._isclose_helper(nan_cases, device, dtype, True)
    self._comparetensors_helper(nan_cases, device, dtype, True)
|
|
|
|
# torch.isclose with equal_nan=True is not implemented for complex inputs
# see https://github.com/numpy/numpy/issues/15959
# Note: compareTensors compares the real and imaginary parts of
# complex tensors separately, unlike isclose.
|
|
@dtypes(torch.complex64, torch.complex128)
def test_isclose_comparetensors_complex(self, device, dtype):
    """Check isclose and the compare helpers on complex dtypes."""
    inf, nan = float('inf'), float('nan')

    default_cases = (
        (complex(1, 1), complex(1, 1 + 1e-8), True),
        (complex(0, 1), complex(1, 1), False),
        (complex(1, 1), complex(1, 0), False),
        (complex(1, 1), complex(1, nan), False),
        (complex(1, nan), complex(1, nan), False),
        (complex(1, 1), complex(1, inf), False),
        (complex(inf, 1), complex(1, inf), False),
        (complex(-inf, 1), complex(1, inf), False),
        (complex(-inf, 1), complex(inf, 1), False),
        (complex(inf, 1), complex(inf, 1), True),
        (complex(inf, 1), complex(inf, 1 + 1e-4), False),
    )
    self._isclose_helper(default_cases, device, dtype, False)
    self._comparetensors_helper(default_cases, device, dtype, False)

    # atol and rtol tests
    eps = 1e-6
    half = dict(atol=.5, rtol=.5)
    # Real-valued scenarios that are applied first to the real part and then
    # to the imaginary part, with the other part held at zero.
    float_cases = (
        (0, 1, True),
        (0, 1 + eps, False),
        (1, 0, False),
        (1, 3, True),
        (1 - eps, 3, False),
        (-.25, .5, True),
        (-.25 - eps, .5, False),
        (.25, -.5, True),
        (.25 + eps, -.5, False),
    )
    real_part_cases = tuple((complex(a, 0), complex(b, 0), close) for a, b, close in float_cases)
    imag_part_cases = tuple((complex(0, a), complex(0, b), close) for a, b, close in float_cases)
    part_cases = real_part_cases + imag_part_cases
    self._isclose_helper(part_cases, device, dtype, False, **half)
    self._comparetensors_helper(part_cases, device, dtype, False, **half)

    # atol and rtol tests for isclose (complex-specific)
    isclose_cases = (
        (complex(1, -1), complex(-1, 1), False),
        (complex(1, -1), complex(2, -2), True),
        (complex(-math.sqrt(2), math.sqrt(2)),
         complex(-math.sqrt(.5), math.sqrt(.5)), True),
        (complex(-math.sqrt(2), math.sqrt(2)),
         complex(-math.sqrt(.501), math.sqrt(.499)), False),
        (complex(2, 4), complex(1., 8.8523607), True),
        (complex(2, 4), complex(1., 8.8523607 + eps), False),
        (complex(1, 99), complex(4, 100), True),
    )
    self._isclose_helper(isclose_cases, device, dtype, False, **half)

    # atol and rtol tests for compareTensors
    compare_cases = (
        (complex(1, -1), complex(-1, 1), False),
        (complex(1, -1), complex(2, -2), True),
        (complex(1, 99), complex(4, 100), False),
    )
    self._comparetensors_helper(compare_cases, device, dtype, False, **half)

    # equal_nan = True tests
    nan_cases = (
        (complex(1, 1), complex(1, nan), False),
        (complex(nan, 1), complex(1, nan), False),
        (complex(nan, 1), complex(nan, 1), True),
    )
    # The isclose-based helper is expected to fail with a RuntimeError here.
    with self.assertRaises(RuntimeError):
        self._isclose_helper(nan_cases, device, dtype, True)

    self._comparetensors_helper(nan_cases, device, dtype, True)
|
|
|
|
# Tests that isclose with rtol or atol values less than zero throws a
|
|
# RuntimeError
|
|
@dtypes(torch.bool, torch.uint8,
        torch.int8, torch.int16, torch.int32, torch.int64,
        torch.float16, torch.float32, torch.float64)
def test_isclose_atol_rtol_greater_than_zero(self, device, dtype):
    """torch.isclose must raise RuntimeError when atol and/or rtol is negative."""
    sample = torch.tensor((1,), device=device, dtype=dtype)

    # (atol, rtol) combinations with at least one negative value
    for atol, rtol in ((-1, 1), (1, -1), (-1, -1)):
        with self.assertRaises(RuntimeError):
            torch.isclose(sample, sample, atol=atol, rtol=rtol)
|
|
|
|
@dtypes(torch.bool, torch.long, torch.float, torch.cfloat)
def test_make_tensor(self, device, dtype):
    """Check shape, device, dtype, value range, requires_grad and contiguity of make_tensor output."""
    def check(size, low, high, requires_grad, noncontiguous):
        sample = make_tensor(size, device, dtype, low=low, high=high,
                             requires_grad=requires_grad, noncontiguous=noncontiguous)

        self.assertEqual(sample.shape, size)
        self.assertEqual(sample.device, torch.device(device))
        self.assertEqual(sample.dtype, dtype)

        # Bounds this test assumes when none are passed explicitly
        lower = -9 if low is None else low
        upper = 9 if high is None else high

        # The range is only verified for the long and float dtypes.
        if sample.numel() > 0 and dtype in [torch.long, torch.float]:
            within_bounds = sample.le(upper).logical_and(sample.ge(lower))
            self.assertTrue(within_bounds.all().item())

        if dtype in [torch.float, torch.cfloat]:
            self.assertEqual(sample.requires_grad, requires_grad)
        else:
            self.assertFalse(sample.requires_grad)

        # Tensors with at most one element are always expected to be contiguous.
        if sample.numel() > 1:
            self.assertEqual(sample.is_contiguous(), not noncontiguous)
        else:
            self.assertTrue(sample.is_contiguous())

    shapes = (tuple(), (0,), (1,), (1, 1), (2,), (2, 3), (8, 16, 32))
    for shape in shapes:
        check(shape, None, None, False, False)
        check(shape, 2, 4, True, True)
|
|
|
|
def test_assert_messages(self, device):
    """Check how _get_assert_msg combines the user message and the debug message."""
    build = self._get_assert_msg

    self.assertIsNone(build(msg=None))

    # (expected result, positional arguments, keyword arguments)
    scenarios = (
        ("\nno_debug_msg", ("no_debug_msg",), {}),
        ("no_user_msg", (), dict(msg=None, debug_msg="no_user_msg")),
        ("debug_msg\nuser_msg", (), dict(msg="user_msg", debug_msg="debug_msg")),
    )
    for expected, args, kwargs in scenarios:
        self.assertEqual(expected, build(*args, **kwargs))
|
|
|
|
# The following tests (test_cuda_assert_*) are added to ensure the test suite terminates early
# when a CUDA assert is thrown, because all subsequent tests will fail if that happens.
# These tests are slow because they spawn another process to run the test suite.
# See: https://github.com/pytorch/pytorch/issues/49019
|
|
@onlyCUDA
@slowTest
def test_cuda_assert_should_stop_common_utils_test_suite(self, device):
    """Ensure the common_utils.py override terminates early after a CUDA assert.

    The embedded script is handed to TestCase.runWithPytorchAPIUsageStderr and
    the returned text is inspected.
    """
    script = """\
#!/usr/bin/env python3

import torch
from torch.testing._internal.common_utils import (TestCase, run_tests, slowTest)

class TestThatContainsCUDAAssertFailure(TestCase):

    @slowTest
    def test_throw_unrecoverable_cuda_exception(self):
        x = torch.rand(10, device='cuda')
        # cause unrecoverable CUDA exception, recoverable on CPU
        y = x[torch.tensor([25])].cpu()

    @slowTest
    def test_trivial_passing_test_case_on_cpu_cuda(self):
        x1 = torch.tensor([0., 1.], device='cuda')
        x2 = torch.tensor([0., 1.], device='cpu')
        self.assertEqual(x1, x2)

if __name__ == '__main__':
    run_tests()
"""
    captured = TestCase.runWithPytorchAPIUsageStderr(script)

    expected_fragments = (
        # should capture CUDA error
        'CUDA error: device-side assert triggered',
        # should run only 1 test because it throws unrecoverable error.
        'Ran 1 test',
    )
    for fragment in expected_fragments:
        self.assertIn(fragment, captured)
|
|
|
|
|
|
@onlyCUDA
@slowTest
def test_cuda_assert_should_stop_common_device_type_test_suite(self, device):
    """Ensure the common_device_type.py override terminates early after a CUDA assert.

    The embedded script is handed to TestCase.runWithPytorchAPIUsageStderr and
    the returned text is inspected.
    """
    script = """\
#!/usr/bin/env python3

import torch
from torch.testing._internal.common_utils import (TestCase, run_tests, slowTest)
from torch.testing._internal.common_device_type import instantiate_device_type_tests

class TestThatContainsCUDAAssertFailure(TestCase):

    @slowTest
    def test_throw_unrecoverable_cuda_exception(self, device):
        x = torch.rand(10, device=device)
        # cause unrecoverable CUDA exception, recoverable on CPU
        y = x[torch.tensor([25])].cpu()

    @slowTest
    def test_trivial_passing_test_case_on_cpu_cuda(self, device):
        x1 = torch.tensor([0., 1.], device=device)
        x2 = torch.tensor([0., 1.], device='cpu')
        self.assertEqual(x1, x2)

instantiate_device_type_tests(
    TestThatContainsCUDAAssertFailure,
    globals(),
    only_for='cuda'
)

if __name__ == '__main__':
    run_tests()
"""
    captured = TestCase.runWithPytorchAPIUsageStderr(script)

    expected_fragments = (
        # should capture CUDA error
        'CUDA error: device-side assert triggered',
        # should run only 1 test because it throws unrecoverable error.
        'Ran 1 test',
    )
    for fragment in expected_fragments:
        self.assertIn(fragment, captured)
|
|
|
|
|
|
@onlyCUDA
@slowTest
def test_cuda_assert_should_not_stop_common_distributed_test_suite(self, device):
    """Ensure the common_distributed.py override does not terminate early after a CUDA assert.

    The embedded script is handed to TestCase.runWithPytorchAPIUsageStderr and
    the returned text is inspected.
    """
    script = """\
#!/usr/bin/env python3

import torch
from torch.testing._internal.common_utils import (run_tests, slowTest)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import MultiProcessTestCase

class TestThatContainsCUDAAssertFailure(MultiProcessTestCase):

    @slowTest
    def test_throw_unrecoverable_cuda_exception(self, device):
        x = torch.rand(10, device=device)
        # cause unrecoverable CUDA exception, recoverable on CPU
        y = x[torch.tensor([25])].cpu()

    @slowTest
    def test_trivial_passing_test_case_on_cpu_cuda(self, device):
        x1 = torch.tensor([0., 1.], device=device)
        x2 = torch.tensor([0., 1.], device='cpu')
        self.assertEqual(x1, x2)

instantiate_device_type_tests(
    TestThatContainsCUDAAssertFailure,
    globals(),
    only_for='cuda'
)

if __name__ == '__main__':
    run_tests()
"""
    captured = TestCase.runWithPytorchAPIUsageStderr(script)

    # we are currently disabling CUDA early termination for distributed tests.
    self.assertIn('Ran 2 test', captured)
|
|
|
|
@onlyOnCPUAndCUDA
def test_get_supported_dtypes(self, device):
    """Compare dynamically acquired dtypes with the statically declared ones.

    For a few ops from op_db, the dtypes returned by
    opinfo_helper.get_supported_dtypes (and the dispatch hint derived from
    them) must equal the op's dtypesIfCPU / dtypesIfCUDA entry.
    """
    names_to_test = ['atan2', 'topk', 'xlogy']
    selected_ops = [op for op in op_db if op.name in names_to_test]

    for op in selected_ops:
        dynamic_dtypes = opinfo_helper.get_supported_dtypes(op.op, op.sample_inputs_func, self.device_type)
        dynamic_dispatch = opinfo_helper.dtypes_dispatch_hint(dynamic_dtypes)
        # Any device type other than 'cpu' is treated as 'cuda' here.
        declared = op.dtypesIfCPU if self.device_type == 'cpu' else op.dtypesIfCUDA
        declared_set = set(declared)

        self.assertTrue(declared_set == set(dynamic_dtypes))
        self.assertTrue(declared_set == set(dynamic_dispatch.dispatch_fn()))
|
|
|
|
# Instantiate the device-type variants of TestTesting, passing this module's
# globals() as the scope for the generated tests.
instantiate_device_type_tests(TestTesting, globals())
|
|
|
|
|
|
class TestFrameworkUtils(TestCase):
    """Tests for the device-type test generation utilities."""

    @skipIfRocm
    @unittest.skipIf(IS_WINDOWS, "Skipping because doesn't work for windows")
    @unittest.skipIf(IS_SANDCASTLE, "Skipping because doesn't work on sandcastle")
    def test_filtering_env_var(self):
        """Check device filtering through the ONLY_FOR / EXCEPT_FOR environment variables.

        A small test script is run in a subprocess under different environments
        and the number of executed tests is read from its stderr.
        """
        script = """\
#!/usr/bin/env python3

import torch
from torch.testing._internal.common_utils import (TestCase, run_tests)
from torch.testing._internal.common_device_type import instantiate_device_type_tests

class TestEnvironmentVariable(TestCase):

    def test_trivial_passing_test(self, device):
        x1 = torch.tensor([0., 1.], device=device)
        x2 = torch.tensor([0., 1.], device='cpu')
        self.assertEqual(x1, x2)

instantiate_device_type_tests(
    TestEnvironmentVariable,
    globals(),
)

if __name__ == '__main__':
    run_tests()
"""
        device_type_count = len(get_device_type_test_bases())

        # Start from the current environment without the variables that affect filtering.
        scrubbed = ('IN_CI', PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY)
        env = {key: value for key, value in os.environ.items() if key not in scrubbed}

        def run_and_capture():
            # Reads `env` at call time, so later mutations are picked up.
            _, raw_stderr = TestCase.run_process_no_exception(script, env=env)
            return raw_stderr.decode('ascii')

        # Test without setting env var should run everything.
        self.assertIn(f'Ran {device_type_count} test', run_and_capture())

        # Test with setting only_for should only run 1 test.
        env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY] = 'cpu'
        self.assertIn('Ran 1 test', run_and_capture())

        # Test with setting except_for should run 1 less device type from default.
        del env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY]
        env[PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY] = 'cpu'
        self.assertIn(f'Ran {device_type_count - 1} test', run_and_capture())

        # Test with setting both should throw exception
        env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY] = 'cpu'
        self.assertNotIn('OK', run_and_capture())
|
|
|
|
|
|
def make_assert_close_inputs(actual: Any, expected: Any) -> List[Tuple[Any, Any]]:
    """Build input pairs for :func:`torch.testing.assert_close` from two examples.

    Args:
        actual (Any): Actual input.
        expected (Any): Expected input.

    Returns:
        List[Tuple[Any, Any]]: The plain pair, followed by the pair wrapped in sequences
            (:class:`tuple`, :class:`list`), in mappings (:class:`dict`,
            :class:`~collections.OrderedDict`) keyed by ``"t"``, and in nested combinations of both.
    """
    def as_dict(value):
        return {"t": value}

    def as_ordered_dict(value):
        return collections.OrderedDict([("t", value)])

    pairs = [(actual, expected)]

    # Sequences: tuple vs. tuple, list vs. list, tuple vs. list
    pairs.append(((actual,), (expected,)))
    pairs.append(([actual], [expected]))
    pairs.append(((actual,), [expected]))

    # Mappings: dict vs. dict, OrderedDict vs. OrderedDict, dict vs. OrderedDict
    pairs.append((as_dict(actual), as_dict(expected)))
    pairs.append((as_ordered_dict(actual), as_ordered_dict(expected)))
    pairs.append((as_dict(actual), as_ordered_dict(expected)))

    # Nested containers:
    # list of tuples vs. tuple of lists
    pairs.append(([(actual,)], ([expected],)))
    # list of dicts vs. tuple of OrderedDicts
    pairs.append(([as_dict(actual)], (as_ordered_dict(expected),)))
    # dict of lists vs. OrderedDict of tuples
    pairs.append((as_dict([actual]), as_ordered_dict((expected,))))

    return pairs
|
|
|
|
|
|
def assert_close_with_inputs(actual: Any, expected: Any) -> Iterator[Callable]:
    """Yield :func:`torch.testing.assert_close` pre-bound to inputs derived from two examples.

    .. note::

        Every test that does not test for a specific input should iterate over this to maximize the coverage.

    Args:
        actual (Any): Actual input.
        expected (Any): Expected input.

    Yields:
        Callable: :func:`torch.testing.assert_close` with one input pair from
            :func:`make_assert_close_inputs` bound as positional arguments.
    """
    for actual_input, expected_input in make_assert_close_inputs(actual, expected):
        yield functools.partial(torch.testing.assert_close, actual_input, expected_input)
|
|
|
|
|
|
class TestAssertClose(TestCase):
|
|
def test_quantized_support(self):
    """assert_close must raise UsageError when compared against a quantized tensor."""
    fill_value = 1
    plain = torch.tensor([fill_value], dtype=torch.int32)
    quantized = torch._empty_affine_quantized(plain.shape, scale=1, zero_point=0, dtype=torch.qint32)
    quantized.fill_(fill_value)

    for assert_close in assert_close_with_inputs(plain, quantized):
        with self.assertRaises(UsageError):
            assert_close()
|
|
|
|
def test_type_inequality(self):
    """A tensor vs. its list form must fail with an AssertionError matching the list type."""
    tensor = torch.empty(2)
    as_list = tensor.tolist()
    type_pattern = str(type(as_list))

    for assert_close in assert_close_with_inputs(tensor, as_list):
        with self.assertRaisesRegex(AssertionError, type_pattern):
            assert_close()
|
|
|
|
def test_unknown_type(self):
    """String inputs must raise a UsageError whose message matches the str type."""
    first = "0"
    second = "0"
    type_pattern = str(type(first))

    for assert_close in assert_close_with_inputs(first, second):
        with self.assertRaisesRegex(UsageError, type_pattern):
            assert_close()
|
|
|
|
def test_mismatching_shape(self):
    """A 0-d tensor vs. its (1,)-shaped copy must fail with a message matching "shape"."""
    scalar = torch.empty(())
    one_element = scalar.clone().reshape((1,))

    for assert_close in assert_close_with_inputs(scalar, one_element):
        with self.assertRaisesRegex(AssertionError, "shape"):
            assert_close()
|
|
|
|
@unittest.skipIf(not torch.backends.mkldnn.is_available(), reason="MKLDNN is not available.")
def test_unknown_layout(self):
    """Comparing against an MKLDNN tensor must raise UsageError."""
    strided = torch.empty((2, 2))
    mkldnn = strided.to_mkldnn()

    for assert_close in assert_close_with_inputs(strided, mkldnn):
        with self.assertRaises(UsageError):
            assert_close()
|
|
|
|
def test_mismatching_layout(self):
    """Every pair of strided, sparse COO and sparse CSR inputs must fail on "layout"."""
    dense = torch.empty((2, 2))
    layouts = (dense, dense.to_sparse(), dense.to_sparse_csr())

    for first, second in itertools.combinations(layouts, 2):
        for assert_close in assert_close_with_inputs(first, second):
            with self.assertRaisesRegex(AssertionError, "layout"):
                assert_close()
|
|
|
|
def test_mismatching_dtype(self):
    """A float tensor vs. its int conversion must fail with a message matching "dtype"."""
    as_float = torch.empty((), dtype=torch.float)
    as_int = as_float.clone().to(torch.int)

    for assert_close in assert_close_with_inputs(as_float, as_int):
        with self.assertRaisesRegex(AssertionError, "dtype"):
            assert_close()
|
|
|
|
def test_mismatching_dtype_no_check(self):
    """With check_dtype=False, equal values of different dtypes must pass."""
    as_float = torch.ones((), dtype=torch.float)
    as_int = as_float.clone().to(torch.int)

    for assert_close in assert_close_with_inputs(as_float, as_int):
        assert_close(check_dtype=False)
|
|
|
|
def test_mismatching_stride(self):
    """Tensors of equal shape but reversed strides must fail with a message matching "stride"."""
    base = torch.empty((2, 2))
    # Same shape as `base`, viewed with the strides in reverse order.
    transposed_storage = base.clone().t().contiguous()
    restrided = torch.as_strided(transposed_storage, base.shape, base.stride()[::-1])

    for assert_close in assert_close_with_inputs(base, restrided):
        with self.assertRaisesRegex(AssertionError, "stride"):
            assert_close()
|
|
|
|
def test_mismatching_stride_no_check(self):
    """With check_stride=False, tensors that differ only in stride must pass."""
    base = torch.rand((2, 2))
    # Same shape as `base`, viewed with the strides in reverse order.
    transposed_storage = base.clone().t().contiguous()
    restrided = torch.as_strided(transposed_storage, base.shape, base.stride()[::-1])

    for assert_close in assert_close_with_inputs(base, restrided):
        assert_close(check_stride=False)
|
|
|
|
def test_only_rtol(self):
    """Passing rtol without atol must raise UsageError."""
    tensor = torch.empty(())
    copy = tensor.clone()

    for assert_close in assert_close_with_inputs(tensor, copy):
        with self.assertRaises(UsageError):
            assert_close(rtol=0.0)
|
|
|
|
# Supplying atol without rtol is expected to be rejected with UsageError.
def test_only_atol(self):
    tensor = torch.empty(())
    duplicate = tensor.clone()

    for assert_fn in assert_close_with_inputs(tensor, duplicate):
        self.assertRaises(UsageError, assert_fn, atol=0.0)
|
|
|
|
# Integer tensors holding 1 and 2 are expected to fail the comparison with default tolerances.
def test_mismatching_values(self):
    one = torch.tensor(1)
    two = torch.tensor(2)

    for assert_fn in assert_close_with_inputs(one, two):
        self.assertRaises(AssertionError, assert_fn)
|
|
|
|
# The inputs are 1.0 and 1.0 + eps; with rtol set to half of eps and atol=0.0 the comparison is expected to fail.
def test_mismatching_values_rtol(self):
    eps = 1e-3
    lhs, rhs = torch.tensor(1.0), torch.tensor(1.0 + eps)

    for assert_fn in assert_close_with_inputs(lhs, rhs):
        self.assertRaises(AssertionError, assert_fn, rtol=eps / 2, atol=0.0)
|
|
|
|
# The inputs are 0.0 and eps; with atol set to half of eps and rtol=0.0 the comparison is expected to fail.
def test_mismatching_values_atol(self):
    eps = 1e-3
    lhs, rhs = torch.tensor(0.0), torch.tensor(eps)

    for assert_fn in assert_close_with_inputs(lhs, rhs):
        self.assertRaises(AssertionError, assert_fn, rtol=0.0, atol=eps / 2)
|
|
|
|
# A tensor and its clone are expected to compare as close with the default tolerances.
def test_matching(self):
    actual = torch.tensor(1.0)
    expected = actual.clone()

    # Iterate over assert_close_with_inputs like the sibling tests do, instead of calling
    # torch.testing.assert_close once directly, so the matching case is checked through every callable the
    # helper yields and gets the same coverage as the mismatching cases.
    for fn in assert_close_with_inputs(actual, expected):
        fn()
|
|
|
|
# The inputs are 1.0 and 1.0 + eps; with rtol set to twice eps and atol=0.0 the comparison is expected to pass.
def test_matching_rtol(self):
    eps = 1e-3
    lhs, rhs = torch.tensor(1.0), torch.tensor(1.0 + eps)
    tolerances = dict(rtol=eps * 2, atol=0.0)

    for assert_fn in assert_close_with_inputs(lhs, rhs):
        assert_fn(**tolerances)
|
|
|
|
# The inputs are 0.0 and eps; with atol set to twice eps and rtol=0.0 the comparison is expected to pass.
def test_matching_atol(self):
    eps = 1e-3
    lhs, rhs = torch.tensor(0.0), torch.tensor(eps)
    tolerances = dict(rtol=0.0, atol=eps * 2)

    for assert_fn in assert_close_with_inputs(lhs, rhs):
        assert_fn(**tolerances)
|
|
|
|
# A NaN tensor compared against its clone is expected to fail when equal_nan is left at its default.
def test_matching_nan(self):
    nan = torch.tensor(float("NaN"))
    nan_copy = nan.clone()

    for assert_fn in assert_close_with_inputs(nan, nan_copy):
        self.assertRaises(AssertionError, assert_fn)
|
|
|
|
# A NaN tensor compared against its clone is expected to pass once equal_nan=True is given.
def test_matching_nan_with_equal_nan(self):
    nan = torch.tensor(float("NaN"))
    nan_copy = nan.clone()

    for assert_fn in assert_close_with_inputs(nan, nan_copy):
        assert_fn(equal_nan=True)
|
|
|
|
# NumPy inputs: a float32 array obtained from a random 2x2 tensor is compared against its copy and is expected
# to pass.
def test_numpy(self):
    array = torch.rand(2, 2, dtype=torch.float32).numpy()
    array_copy = array.copy()

    for assert_fn in assert_close_with_inputs(array, array_copy):
        assert_fn()
|
|
|
|
# Python scalar inputs: one random integer below 10 is converted to int, float and complex, and every ordered
# pair of those (including same-type pairs) is compared. check_dtype is only enabled when both scalars have
# the same Python type.
def test_scalar(self):
    value = torch.randint(10, size=()).item()
    scalars = tuple(scalar_type(value) for scalar_type in (int, float, complex))

    for lhs, rhs in itertools.product(scalars, repeat=2):
        same_type = type(lhs) is type(rhs)

        for assert_fn in assert_close_with_inputs(lhs, rhs):
            assert_fn(check_dtype=same_type)
|
|
|
|
|
|
# Device handling of torch.testing.assert_close. Both tests walk over every ordered pair of distinct devices
# drawn from "cpu" plus the entries of the `devices` argument.
class TestAssertCloseMultiDevice(TestCase):
    # Tensors living on different devices are expected to fail with a message that mentions "device".
    @deviceCountAtLeast(1)
    def test_mismatching_device(self, devices):
        candidates = ("cpu", *devices)

        for source, target in itertools.permutations(candidates, 2):
            on_source = torch.empty((), device=source)
            on_target = on_source.clone().to(target)

            for assert_fn in assert_close_with_inputs(on_source, on_target):
                self.assertRaisesRegex(AssertionError, "device", assert_fn)

    # With check_device=False the same pairs (now holding a defined random value) are expected to pass.
    @deviceCountAtLeast(1)
    def test_mismatching_device_no_check(self, devices):
        candidates = ("cpu", *devices)

        for source, target in itertools.permutations(candidates, 2):
            on_source = torch.rand((), device=source)
            on_target = on_source.clone().to(target)

            for assert_fn in assert_close_with_inputs(on_source, on_target):
                assert_fn(check_device=False)
|
|
|
|
|
|
# Hand TestAssertCloseMultiDevice and this module's globals() to the device-type test framework;
# only_for="cuda" is passed to limit the instantiation to CUDA (see instantiate_device_type_tests for the
# exact semantics).
instantiate_device_type_tests(TestAssertCloseMultiDevice, globals(), only_for="cuda")
|
|
|
|
|
|
# Checks the content of the AssertionError message produced for mismatching inputs. Every test runs its check
# against each callable yielded by assert_close_with_inputs.
class TestAssertCloseErrorMessage(TestCase):
    # Two of the four elements differ, so the message is expected to report "2 / 4 (50.0%)".
    def test_mismatched_elements(self):
        actual = torch.tensor([1, 2, 3, 4])
        expected = torch.tensor([1, 2, 5, 6])

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Mismatched elements: 2 / 4 (50.0%)")):
                fn()

    # Only element (1, 0) differs (3 vs. 5), so the greatest absolute difference is expected to be 2 there.
    def test_abs_diff(self):
        actual = torch.tensor([[1, 2], [3, 4]])
        expected = torch.tensor([[1, 2], [5, 4]])

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Greatest absolute difference: 2 at (1, 0)")):
                fn()

    # Only element (0, 1) differs (2 vs. 4); the expected message reports a relative difference of 0.5 there.
    def test_rel_diff(self):
        actual = torch.tensor([[1, 2], [3, 4]])
        expected = torch.tensor([[1, 4], [3, 4]])

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Greatest relative difference: 0.5 at (0, 1)")):
                fn()

    # The second element matches (0.0 vs. 0.0) and would yield 0 / 0 if it took part in the relative difference
    # computation; the error message must therefore not contain 'nan'.
    def test_zero_div_zero(self):
        actual = torch.tensor([1.0, 0.0])
        expected = torch.tensor([2.0, 0.0])

        for fn in assert_close_with_inputs(actual, expected):
            # The message is inspected directly rather than through assertRaisesRegex: that method uses
            # re.search, and a negative-lookahead pattern such as "((?!nan).)*" also matches the empty string,
            # so it would accept every message, including one that contains 'nan'.
            with self.assertRaises(AssertionError) as cm:
                fn()
            self.assertNotIn("nan", str(cm.exception))

    # The relative tolerance passed by the caller is expected to be echoed in the message.
    def test_rtol(self):
        rtol = 1e-3

        actual = torch.tensor(1)
        expected = torch.tensor(2)

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(
                AssertionError, re.escape(f"Greatest relative difference: 0.5 at 0 (up to {rtol} allowed)")
            ):
                fn(rtol=rtol, atol=0.0)

    # The absolute tolerance passed by the caller is expected to be echoed in the message.
    def test_atol(self):
        atol = 1e-3

        actual = torch.tensor(1)
        expected = torch.tensor(2)

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(
                AssertionError, re.escape(f"Greatest absolute difference: 1 at 0 (up to {atol} allowed)")
            ):
                fn(rtol=0.0, atol=atol)

    # A custom message given as a string is expected to show up in the raised error.
    def test_msg_str(self):
        msg = "Custom error message!"

        actual = torch.tensor(1)
        expected = torch.tensor(2)

        for fn in assert_close_with_inputs(actual, expected):
            # Escaped like the other literal patterns in this class so the text is never read as a regex.
            with self.assertRaisesRegex(AssertionError, re.escape(msg)):
                fn(msg=msg)

    # A custom message given as a callable is expected to be used for the raised error.
    def test_msg_callable(self):
        msg = "Custom error message!"

        # Receives three arguments from the caller and ignores them; only the fixed text matters here.
        def make_msg(actual, expected, trace):
            return msg

        actual = torch.tensor(1)
        expected = torch.tensor(2)

        for fn in assert_close_with_inputs(actual, expected):
            # Escaped like the other literal patterns in this class so the text is never read as a regex.
            with self.assertRaisesRegex(AssertionError, re.escape(msg)):
                fn(msg=make_msg)
|
|
|
|
|
|
# Behavior of torch.testing.assert_close for sequence and mapping inputs.
class TestAssertCloseContainer(TestCase):
    # A one-element tuple compared against an empty tuple is expected to fail.
    def test_sequence_mismatching_len(self):
        one_item = (torch.empty(()),)
        no_items = ()

        self.assertRaises(AssertionError, torch.testing.assert_close, one_item, no_items)

    # The tuples differ only in their second item; the message is expected to point at index 1.
    def test_sequence_mismatching_values_msg(self):
        one = torch.tensor(1)
        two = torch.tensor(2)

        lhs = (one, one)
        rhs = (one, two)

        self.assertRaisesRegex(AssertionError, r"index\s+1", torch.testing.assert_close, lhs, rhs)

    # A dict with key "a" compared against an empty dict is expected to fail.
    def test_mapping_mismatching_keys(self):
        with_key = {"a": torch.empty(())}
        without_keys = {}

        self.assertRaises(AssertionError, torch.testing.assert_close, with_key, without_keys)

    # The dicts differ only in the value stored under "b"; the message is expected to name that key.
    def test_mapping_mismatching_values_msg(self):
        one = torch.tensor(1)
        two = torch.tensor(2)

        lhs = {"a": one, "b": one}
        rhs = {"a": one, "b": two}

        self.assertRaisesRegex(AssertionError, r"key\s+'b'", torch.testing.assert_close, lhs, rhs)
|
|
|
|
|
|
# Behavior of torch.testing.assert_close for complex inputs: NaN placement and the part named in the message.
class TestAssertCloseComplex(TestCase):
    # One input has the NaN in its imaginary part, the other in its real part; with equal_nan=True the
    # comparison is still expected to fail.
    def test_mismatching_nan_with_equal_nan(self):
        nan_in_imag = torch.tensor(complex(1, float("NaN")))
        nan_in_real = torch.tensor(complex(float("NaN"), 1))

        for assert_fn in assert_close_with_inputs(nan_in_imag, nan_in_real):
            self.assertRaises(AssertionError, assert_fn, equal_nan=True)

    # The same pair is expected to pass with equal_nan="relaxed".
    def test_mismatching_nan_with_equal_nan_relaxed(self):
        nan_in_imag = torch.tensor(complex(1, float("NaN")))
        nan_in_real = torch.tensor(complex(float("NaN"), 1))

        for assert_fn in assert_close_with_inputs(nan_in_imag, nan_in_real):
            assert_fn(equal_nan="relaxed")

    # Only the real parts differ (0 vs. 1); the message is expected to name the real part.
    def test_mismatching_values_msg_real(self):
        lhs = torch.tensor(complex(0, 1))
        rhs = torch.tensor(complex(1, 1))
        pattern = re.escape("The failure occurred for the real part")

        for assert_fn in assert_close_with_inputs(lhs, rhs):
            self.assertRaisesRegex(AssertionError, pattern, assert_fn)

    # Only the imaginary parts differ (0 vs. 1); the message is expected to name the imaginary part.
    def test_mismatching_values_msg_imag(self):
        lhs = torch.tensor(complex(1, 0))
        rhs = torch.tensor(complex(1, 1))
        pattern = re.escape("The failure occurred for the imaginary part")

        for assert_fn in assert_close_with_inputs(lhs, rhs):
            self.assertRaisesRegex(AssertionError, pattern, assert_fn)
|
|
|
|
|
|
# Behavior of torch.testing.assert_close for sparse COO tensors. All tensors are 2x2 and built with
# torch.sparse_coo_tensor from literal index/value tuples.
class TestAssertCloseSparseCOO(TestCase):
    # A coalesced tensor and its clone are expected to pass.
    def test_matching_coalesced(self):
        coalesced = torch.sparse_coo_tensor(((0, 1), (1, 0)), (1, 2), size=(2, 2)).coalesce()
        duplicate = coalesced.clone()

        for assert_fn in assert_close_with_inputs(coalesced, duplicate):
            assert_fn()

    # A tensor that was never passed through coalesce() and its clone are expected to pass.
    def test_matching_uncoalesced(self):
        uncoalesced = torch.sparse_coo_tensor(((0, 1), (1, 0)), (1, 2), size=(2, 2))
        duplicate = uncoalesced.clone()

        for assert_fn in assert_close_with_inputs(uncoalesced, duplicate):
            assert_fn()

    # Comparing a tensor with the coalesced version of its clone is expected to fail with a message that
    # mentions "is_coalesced".
    def test_mismatching_is_coalesced(self):
        uncoalesced = torch.sparse_coo_tensor(((0, 1), (1, 0)), (1, 2), size=(2, 2))
        coalesced = uncoalesced.clone().coalesce()

        for assert_fn in assert_close_with_inputs(uncoalesced, coalesced):
            self.assertRaisesRegex(AssertionError, "is_coalesced", assert_fn)

    # The second tensor lists position (1, 0) twice with value 1 each, where the first (coalesced) tensor
    # stores a single 2; with check_is_coalesced=False the comparison is expected to pass.
    def test_mismatching_is_coalesced_no_check(self):
        coalesced = torch.sparse_coo_tensor(((0, 1), (1, 0)), (1, 2), size=(2, 2)).coalesce()

        duplicated_indices = ((0, 1, 1), (1, 0, 0))
        duplicated_values = (1, 1, 1)
        uncoalesced = torch.sparse_coo_tensor(duplicated_indices, duplicated_values, size=(2, 2))

        for assert_fn in assert_close_with_inputs(coalesced, uncoalesced):
            assert_fn(check_is_coalesced=False)

    # One tensor is built from two entries and the other from three; the message is expected to mention the
    # number of specified values.
    def test_mismatching_nnz(self):
        two_entries = torch.sparse_coo_tensor(((0, 1), (1, 0)), (1, 2), size=(2, 2))
        three_entries = torch.sparse_coo_tensor(((0, 1, 1), (1, 0, 0)), (1, 1, 1), size=(2, 2))
        pattern = re.escape("number of specified values")

        for assert_fn in assert_close_with_inputs(two_entries, three_entries):
            self.assertRaisesRegex(AssertionError, pattern, assert_fn)

    # Same values but a different second index row; the message is expected to name the indices.
    def test_mismatching_indices_msg(self):
        values = (1, 2)
        lhs = torch.sparse_coo_tensor(((0, 1), (1, 0)), values, size=(2, 2))
        rhs = torch.sparse_coo_tensor(((0, 1), (1, 1)), values, size=(2, 2))
        pattern = re.escape("The failure occurred for the indices")

        for assert_fn in assert_close_with_inputs(lhs, rhs):
            self.assertRaisesRegex(AssertionError, pattern, assert_fn)

    # Same indices but a different second value; the message is expected to name the values.
    def test_mismatching_values_msg(self):
        indices = ((0, 1), (1, 0))
        lhs = torch.sparse_coo_tensor(indices, (1, 2), size=(2, 2))
        rhs = torch.sparse_coo_tensor(indices, (1, 3), size=(2, 2))
        pattern = re.escape("The failure occurred for the values")

        for assert_fn in assert_close_with_inputs(lhs, rhs):
            self.assertRaisesRegex(AssertionError, pattern, assert_fn)
|
|
|
|
|
|
# Behavior of torch.testing.assert_close for sparse CSR tensors. All tensors are 2x2 and built with
# torch.sparse_csr_tensor from literal (crow_indices, col_indices, values) tuples. The whole class is skipped
# when IS_FBCODE or IS_SANDCASTLE is set.
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSR testing")
class TestAssertCloseSparseCSR(TestCase):
    # A CSR tensor and a second tensor rebuilt from its components are expected to pass.
    def test_matching(self):
        original = torch.sparse_csr_tensor((0, 1, 2), (1, 0), (1, 2), size=(2, 2))
        # TODO: replace this by actual.clone() after https://github.com/pytorch/pytorch/issues/59285 is fixed
        rebuilt = torch.sparse_csr_tensor(
            original.crow_indices(),
            original.col_indices(),
            original.values(),
            size=original.size(),
            device=original.device,
        )

        for assert_fn in assert_close_with_inputs(original, rebuilt):
            assert_fn()

    # Only crow_indices differ ((0, 1, 2) vs. (0, 2, 2)); the message is expected to name crow_indices.
    def test_mismatching_crow_indices_msg(self):
        col_indices, values = (1, 0), (1, 2)
        lhs = torch.sparse_csr_tensor((0, 1, 2), col_indices, values, size=(2, 2))
        rhs = torch.sparse_csr_tensor((0, 2, 2), col_indices, values, size=(2, 2))
        pattern = re.escape("The failure occurred for the crow_indices")

        for assert_fn in assert_close_with_inputs(lhs, rhs):
            self.assertRaisesRegex(AssertionError, pattern, assert_fn)

    # Only col_indices differ ((1, 0) vs. (1, 1)); the message is expected to name col_indices.
    def test_mismatching_col_indices_msg(self):
        crow_indices, values = (0, 1, 2), (1, 2)
        lhs = torch.sparse_csr_tensor(crow_indices, (1, 0), values, size=(2, 2))
        rhs = torch.sparse_csr_tensor(crow_indices, (1, 1), values, size=(2, 2))
        pattern = re.escape("The failure occurred for the col_indices")

        for assert_fn in assert_close_with_inputs(lhs, rhs):
            self.assertRaisesRegex(AssertionError, pattern, assert_fn)

    # Only the values differ ((1, 2) vs. (1, 3)); the message is expected to name the values.
    def test_mismatching_values_msg(self):
        crow_indices, col_indices = (0, 1, 2), (1, 0)
        lhs = torch.sparse_csr_tensor(crow_indices, col_indices, (1, 2), size=(2, 2))
        rhs = torch.sparse_csr_tensor(crow_indices, col_indices, (1, 3), size=(2, 2))
        pattern = re.escape("The failure occurred for the values")

        for assert_fn in assert_close_with_inputs(lhs, rhs):
            self.assertRaisesRegex(AssertionError, pattern, assert_fn)
|
|
|
|
|
|
# Script entry point: call run_tests() only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    run_tests()
|