Files
pytorch/test/test_testing.py
Philip Meier dbf3451c6e Add support for checking tensor containers in torch.testing (#55385)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55385

This renames `assert_tensors_(equal|close)` to `_check_tensors_(equal|close)` and exposes two new functions: `assert_(equal|close)`. In addition to tensor pairs, the newly added functions also support the comparison of tensors in sequences or mappings. Otherwise their signature stays the same.

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D27903805

Pulled By: mruberry

fbshipit-source-id: 719d19a1d26de8d14cb25846e3d22a6ac828c80a
2021-04-24 23:36:36 -07:00

1078 lines
43 KiB
Python

import collections
import functools
import itertools
import math
import os
import random
import re
import unittest
from typing import Any, Callable, Iterator, List, Mapping, Sequence, Tuple, TypeVar
import numpy as np
import torch
from torch.testing._internal.common_utils import \
(IS_SANDCASTLE, IS_WINDOWS, TestCase, make_tensor, run_tests, skipIfRocm, slowTest)
from torch.testing._internal.framework_utils import calculate_shards
from torch.testing._internal.common_device_type import \
(PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes,
get_device_type_test_bases, instantiate_device_type_tests, onlyCPU, onlyCUDA, onlyOnCPUAndCUDA)
from torch.testing._asserts import UsageError
# For testing TestCase methods and torch.testing functions
class TestTesting(TestCase):
    # Ensure that assertEqual handles numpy arrays properly
    @dtypes(*(torch.testing.get_all_dtypes(include_half=True, include_bfloat16=False,
                                           include_bool=True, include_complex=True)))
    def test_assertEqual_numpy(self, device, dtype):
        S = 10
        test_sizes = [
            (),
            (0,),
            (S,),
            (S, S),
            (0, S),
            (S, 0)]
        for test_size in test_sizes:
            a = make_tensor(test_size, device, dtype, low=-5, high=5)
            a_n = a.cpu().numpy()
            msg = f'size: {test_size}'
            # Compare array/tensor, tensor/array and array/array with zero tolerance.
            self.assertEqual(a_n, a, rtol=0, atol=0, msg=msg)
            self.assertEqual(a, a_n, rtol=0, atol=0, msg=msg)
            self.assertEqual(a_n, a_n, rtol=0, atol=0, msg=msg)

    # Tests that when rtol or atol (including self.precision) is set, then
    # the other is zeroed.
    # TODO: this is legacy behavior and should be updated after test
    # precisions are reviewed to be consistent with torch.isclose.
    @onlyOnCPUAndCUDA
    def test__comparetensors_legacy(self, device):
        a = torch.tensor((10000000.,))
        b = torch.tensor((10000002.,))

        x = torch.tensor((1.,))
        y = torch.tensor((1. + 1e-5,))

        # Helper for reusing the tensor values as scalars
        def _scalar_helper(a, b, rtol=None, atol=None):
            return self._compareScalars(a.item(), b.item(), rtol=rtol, atol=atol)

        # The same expectations must hold for the tensor and the scalar code path.
        for op in (self._compareTensors, _scalar_helper):
            # Tests default
            result, _ = op(a, b)
            self.assertTrue(result)

            # Tests setting atol
            result, _ = op(a, b, atol=2, rtol=0)
            self.assertTrue(result)

            # Tests setting atol too small
            result, _ = op(a, b, atol=1, rtol=0)
            self.assertFalse(result)

            # Tests setting rtol
            result, _ = op(x, y, atol=0, rtol=1.05e-5)
            self.assertTrue(result)

            # Tests setting rtol too small
            result, _ = op(x, y, atol=0, rtol=1e-5)
            self.assertFalse(result)

    # Checks that _compareScalars provides the correct debug info
    @onlyOnCPUAndCUDA
    def test__comparescalars_debug_msg(self, device):
        # float x float
        _, debug_msg = self._compareScalars(4., 7.)
        expected_msg = ("Comparing 4.0 and 7.0 gives a difference of 3.0, "
                        "but the allowed difference with rtol=1.3e-06 and "
                        "atol=1e-05 is only 1.9100000000000003e-05!")
        self.assertEqual(debug_msg, expected_msg)

        # complex x complex, real difference
        _, debug_msg = self._compareScalars(complex(1, 3), complex(3, 1))
        expected_msg = ("Comparing the real part 1.0 and 3.0 gives a difference "
                        "of 2.0, but the allowed difference with rtol=1.3e-06 "
                        "and atol=1e-05 is only 1.39e-05!")
        self.assertEqual(debug_msg, expected_msg)

        # complex x complex, imaginary difference
        _, debug_msg = self._compareScalars(complex(1, 3), complex(1, 5.5))
        expected_msg = ("Comparing the imaginary part 3.0 and 5.5 gives a "
                        "difference of 2.5, but the allowed difference with "
                        "rtol=1.3e-06 and atol=1e-05 is only 1.715e-05!")
        self.assertEqual(debug_msg, expected_msg)

        # complex x int
        _, debug_msg = self._compareScalars(complex(1, -2), 1)
        expected_msg = ("Comparing the imaginary part -2.0 and 0.0 gives a "
                        "difference of 2.0, but the allowed difference with "
                        "rtol=1.3e-06 and atol=1e-05 is only 1e-05!")
        self.assertEqual(debug_msg, expected_msg)

        # NaN x NaN, equal_nan=False
        _, debug_msg = self._compareScalars(float('nan'), float('nan'), equal_nan=False)
        expected_msg = ("Found nan and nan while comparing and either one is "
                        "nan and the other isn't, or both are nan and equal_nan "
                        "is False")
        self.assertEqual(debug_msg, expected_msg)

    # Checks that compareTensors provides the correct debug info
    @onlyOnCPUAndCUDA
    def test__comparetensors_debug_msg(self, device):
        # Acquires atol that will be used
        atol = max(1e-05, self.precision)

        # Checks float tensor comparisons (2D tensor)
        a = torch.tensor(((0, 6), (7, 9)), device=device, dtype=torch.float32)
        b = torch.tensor(((0, 7), (7, 22)), device=device, dtype=torch.float32)
        _, debug_msg = self._compareTensors(a, b)
        expected_msg = ("With rtol=1.3e-06 and atol={0}, found 2 element(s) (out of 4) "
                        "whose difference(s) exceeded the margin of error (including 0 nan comparisons). "
                        "The greatest difference was 13.0 (9.0 vs. 22.0), "
                        "which occurred at index (1, 1).").format(atol)
        self.assertEqual(debug_msg, expected_msg)

        # Checks float tensor comparisons (with extremal values)
        a = torch.tensor((float('inf'), 5, float('inf')), device=device, dtype=torch.float32)
        b = torch.tensor((float('inf'), float('nan'), float('-inf')), device=device, dtype=torch.float32)
        _, debug_msg = self._compareTensors(a, b)
        expected_msg = ("With rtol=1.3e-06 and atol={0}, found 2 element(s) (out of 3) "
                        "whose difference(s) exceeded the margin of error (including 1 nan comparisons). "
                        "The greatest difference was nan (5.0 vs. nan), "
                        "which occurred at index 1.").format(atol)
        self.assertEqual(debug_msg, expected_msg)

        # Checks float tensor comparisons (with finite vs nan differences)
        a = torch.tensor((20, -6), device=device, dtype=torch.float32)
        b = torch.tensor((-1, float('nan')), device=device, dtype=torch.float32)
        _, debug_msg = self._compareTensors(a, b)
        expected_msg = ("With rtol=1.3e-06 and atol={0}, found 2 element(s) (out of 2) "
                        "whose difference(s) exceeded the margin of error (including 1 nan comparisons). "
                        "The greatest difference was nan (-6.0 vs. nan), "
                        "which occurred at index 1.").format(atol)
        self.assertEqual(debug_msg, expected_msg)

        # Checks int tensor comparisons (1D tensor)
        a = torch.tensor((1, 2, 3, 4), device=device)
        b = torch.tensor((2, 5, 3, 4), device=device)
        _, debug_msg = self._compareTensors(a, b)
        expected_msg = ("Found 2 different element(s) (out of 4), "
                        "with the greatest difference of 3 (2 vs. 5) "
                        "occuring at index 1.")
        self.assertEqual(debug_msg, expected_msg)

        # Checks bool tensor comparisons (0D tensor)
        a = torch.tensor((True), device=device)
        b = torch.tensor((False), device=device)
        _, debug_msg = self._compareTensors(a, b)
        expected_msg = ("Found 1 different element(s) (out of 1), "
                        "with the greatest difference of 1 (1 vs. 0) "
                        "occuring at index 0.")
        self.assertEqual(debug_msg, expected_msg)

        # Checks complex tensor comparisons (real part)
        a = torch.tensor((1 - 1j, 4 + 3j), device=device)
        b = torch.tensor((1 - 1j, 1 + 3j), device=device)
        _, debug_msg = self._compareTensors(a, b)
        expected_msg = ("Real parts failed to compare as equal! "
                        "With rtol=1.3e-06 and atol={0}, "
                        "found 1 element(s) (out of 2) whose difference(s) exceeded the "
                        "margin of error (including 0 nan comparisons). The greatest difference was "
                        "3.0 (4.0 vs. 1.0), which occurred at index 1.").format(atol)
        self.assertEqual(debug_msg, expected_msg)

        # Checks complex tensor comparisons (imaginary part)
        a = torch.tensor((1 - 1j, 4 + 3j), device=device)
        b = torch.tensor((1 - 1j, 4 - 21j), device=device)
        _, debug_msg = self._compareTensors(a, b)
        expected_msg = ("Imaginary parts failed to compare as equal! "
                        "With rtol=1.3e-06 and atol={0}, "
                        "found 1 element(s) (out of 2) whose difference(s) exceeded the "
                        "margin of error (including 0 nan comparisons). The greatest difference was "
                        "24.0 (3.0 vs. -21.0), which occurred at index 1.").format(atol)
        self.assertEqual(debug_msg, expected_msg)

        # Checks size mismatch
        a = torch.tensor((1, 2), device=device)
        b = torch.tensor((3), device=device)
        _, debug_msg = self._compareTensors(a, b)
        expected_msg = ("Attempted to compare equality of tensors "
                        "with different sizes. Got sizes torch.Size([2]) and torch.Size([]).")
        self.assertEqual(debug_msg, expected_msg)

        # Checks dtype mismatch
        a = torch.tensor((1, 2), device=device, dtype=torch.long)
        b = torch.tensor((1, 2), device=device, dtype=torch.float32)
        _, debug_msg = self._compareTensors(a, b, exact_dtype=True)
        expected_msg = ("Attempted to compare equality of tensors "
                        "with different dtypes. Got dtypes torch.int64 and torch.float32.")
        self.assertEqual(debug_msg, expected_msg)

        # Checks device mismatch
        # (only possible when a second device besides the CPU is in play)
        if self.device_type == 'cuda':
            a = torch.tensor((5), device='cpu')
            b = torch.tensor((5), device=device)
            _, debug_msg = self._compareTensors(a, b, exact_device=True)
            expected_msg = ("Attempted to compare equality of tensors "
                            "on different devices! Got devices cpu and cuda:0.")
            self.assertEqual(debug_msg, expected_msg)

    # Helper for testing _compareTensors and _compareScalars
    # Works on single element tensors
    # Each entry of `tests` is a triple (value_a, value_b, expected_result).
    def _comparetensors_helper(self, tests, device, dtype, equal_nan, exact_dtype=True, atol=1e-08, rtol=1e-05):
        for test in tests:
            a = torch.tensor((test[0],), device=device, dtype=dtype)
            b = torch.tensor((test[1],), device=device, dtype=dtype)

            # Tensor x Tensor comparison
            compare_result, _ = self._compareTensors(a, b, rtol=rtol, atol=atol,
                                                     equal_nan=equal_nan,
                                                     exact_dtype=exact_dtype)
            self.assertEqual(compare_result, test[2])

            # Scalar x Scalar comparison
            compare_result, _ = self._compareScalars(a.item(), b.item(),
                                                     rtol=rtol, atol=atol,
                                                     equal_nan=equal_nan)
            self.assertEqual(compare_result, test[2])

    # Helper for testing torch.isclose on single element tensors
    # Each entry of `tests` is a triple (value_a, value_b, expected_result).
    def _isclose_helper(self, tests, device, dtype, equal_nan, atol=1e-08, rtol=1e-05):
        for test in tests:
            a = torch.tensor((test[0],), device=device, dtype=dtype)
            b = torch.tensor((test[1],), device=device, dtype=dtype)

            actual = torch.isclose(a, b, equal_nan=equal_nan, atol=atol, rtol=rtol)
            expected = test[2]
            self.assertEqual(actual.item(), expected)

    # torch.isclose is not implemented for bool tensors
    # see https://github.com/pytorch/pytorch/issues/33048
    def test_isclose_comparetensors_bool(self, device):
        tests = (
            (True, True, True),
            (False, False, True),
            (True, False, False),
            (False, True, False),
        )

        with self.assertRaises(RuntimeError):
            self._isclose_helper(tests, device, torch.bool, False)

        self._comparetensors_helper(tests, device, torch.bool, False)

    @dtypes(torch.uint8,
            torch.int8, torch.int16, torch.int32, torch.int64)
    def test_isclose_comparetensors_integer(self, device, dtype):
        tests = (
            (0, 0, True),
            (0, 1, False),
            (1, 0, False),
        )

        self._isclose_helper(tests, device, dtype, False)

        # atol and rtol tests
        tests = [
            (0, 1, True),
            (1, 0, False),
            (1, 3, True),
        ]

        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
        self._comparetensors_helper(tests, device, dtype, False, atol=.5, rtol=.5)

        # uint8 gets its own expectations for the negative inputs.
        if dtype is torch.uint8:
            tests = [
                (-1, 1, False),
                (1, -1, False)
            ]
        else:
            tests = [
                (-1, 1, True),
                (1, -1, True)
            ]

        self._isclose_helper(tests, device, dtype, False, atol=1.5, rtol=.5)
        self._comparetensors_helper(tests, device, dtype, False, atol=1.5, rtol=.5)

    @onlyOnCPUAndCUDA
    @dtypes(torch.float16, torch.float32, torch.float64)
    def test_isclose_comparetensors_float(self, device, dtype):
        tests = (
            (0, 0, True),
            (0, -1, False),
            (float('inf'), float('inf'), True),
            (-float('inf'), float('inf'), False),
            (float('inf'), float('nan'), False),
            (float('nan'), float('nan'), False),
            (0, float('nan'), False),
            (1, 1, True),
        )

        self._isclose_helper(tests, device, dtype, False)
        self._comparetensors_helper(tests, device, dtype, False)

        # atol and rtol tests
        # (a coarser eps is used for half precision)
        eps = 1e-2 if dtype is torch.half else 1e-6
        tests = (
            (0, 1, True),
            (0, 1 + eps, False),
            (1, 0, False),
            (1, 3, True),
            (1 - eps, 3, False),
            (-.25, .5, True),
            (-.25 - eps, .5, False),
            (.25, -.5, True),
            (.25 + eps, -.5, False),
        )

        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
        self._comparetensors_helper(tests, device, dtype, False, atol=.5, rtol=.5)

        # equal_nan = True tests
        tests = (
            (0, float('nan'), False),
            (float('inf'), float('nan'), False),
            (float('nan'), float('nan'), True),
        )

        self._isclose_helper(tests, device, dtype, True)
        self._comparetensors_helper(tests, device, dtype, True)

    # torch.isclose with equal_nan=True is not implemented for complex inputs
    # see https://github.com/numpy/numpy/issues/15959
    # Note: compareTensor will compare the real and imaginary parts of a
    # complex tensors separately, unlike isclose.
    @dtypes(torch.complex64, torch.complex128)
    def test_isclose_comparetensors_complex(self, device, dtype):
        tests = (
            (complex(1, 1), complex(1, 1 + 1e-8), True),
            (complex(0, 1), complex(1, 1), False),
            (complex(1, 1), complex(1, 0), False),
            (complex(1, 1), complex(1, float('nan')), False),
            (complex(1, float('nan')), complex(1, float('nan')), False),
            (complex(1, 1), complex(1, float('inf')), False),
            (complex(float('inf'), 1), complex(1, float('inf')), False),
            (complex(-float('inf'), 1), complex(1, float('inf')), False),
            (complex(-float('inf'), 1), complex(float('inf'), 1), False),
            (complex(float('inf'), 1), complex(float('inf'), 1), True),
            (complex(float('inf'), 1), complex(float('inf'), 1 + 1e-4), False),
        )

        self._isclose_helper(tests, device, dtype, False)
        self._comparetensors_helper(tests, device, dtype, False)

        # atol and rtol tests
        eps = 1e-6
        tests = (
            # Complex versions of float tests (real part)
            (complex(0, 0), complex(1, 0), True),
            (complex(0, 0), complex(1 + eps, 0), False),
            (complex(1, 0), complex(0, 0), False),
            (complex(1, 0), complex(3, 0), True),
            (complex(1 - eps, 0), complex(3, 0), False),
            (complex(-.25, 0), complex(.5, 0), True),
            (complex(-.25 - eps, 0), complex(.5, 0), False),
            (complex(.25, 0), complex(-.5, 0), True),
            (complex(.25 + eps, 0), complex(-.5, 0), False),
            # Complex versions of float tests (imaginary part)
            (complex(0, 0), complex(0, 1), True),
            (complex(0, 0), complex(0, 1 + eps), False),
            (complex(0, 1), complex(0, 0), False),
            (complex(0, 1), complex(0, 3), True),
            (complex(0, 1 - eps), complex(0, 3), False),
            (complex(0, -.25), complex(0, .5), True),
            (complex(0, -.25 - eps), complex(0, .5), False),
            (complex(0, .25), complex(0, -.5), True),
            (complex(0, .25 + eps), complex(0, -.5), False),
        )

        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
        self._comparetensors_helper(tests, device, dtype, False, atol=.5, rtol=.5)

        # atol and rtol tests for isclose
        tests = (
            # Complex-specific tests
            (complex(1, -1), complex(-1, 1), False),
            (complex(1, -1), complex(2, -2), True),
            (complex(-math.sqrt(2), math.sqrt(2)),
             complex(-math.sqrt(.5), math.sqrt(.5)), True),
            (complex(-math.sqrt(2), math.sqrt(2)),
             complex(-math.sqrt(.501), math.sqrt(.499)), False),
            (complex(2, 4), complex(1., 8.8523607), True),
            (complex(2, 4), complex(1., 8.8523607 + eps), False),
            (complex(1, 99), complex(4, 100), True),
        )

        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)

        # atol and rtol tests for compareTensors
        tests = (
            (complex(1, -1), complex(-1, 1), False),
            (complex(1, -1), complex(2, -2), True),
            (complex(1, 99), complex(4, 100), False),
        )

        self._comparetensors_helper(tests, device, dtype, False, atol=.5, rtol=.5)

        # equal_nan = True tests
        tests = (
            (complex(1, 1), complex(1, float('nan')), False),
            (complex(float('nan'), 1), complex(1, float('nan')), False),
            (complex(float('nan'), 1), complex(float('nan'), 1), True),
        )

        with self.assertRaises(RuntimeError):
            self._isclose_helper(tests, device, dtype, True)

        self._comparetensors_helper(tests, device, dtype, True)

    # Tests that isclose with rtol or atol values less than zero throws a
    # RuntimeError
    @dtypes(torch.bool, torch.uint8,
            torch.int8, torch.int16, torch.int32, torch.int64,
            torch.float16, torch.float32, torch.float64)
    def test_isclose_atol_rtol_greater_than_zero(self, device, dtype):
        t = torch.tensor((1,), device=device, dtype=dtype)

        with self.assertRaises(RuntimeError):
            torch.isclose(t, t, atol=-1, rtol=1)
        with self.assertRaises(RuntimeError):
            torch.isclose(t, t, atol=1, rtol=-1)
        with self.assertRaises(RuntimeError):
            torch.isclose(t, t, atol=-1, rtol=-1)

    # Checks shape, device, dtype, value range, requires_grad and contiguity
    # of the tensors produced by make_tensor
    @dtypes(torch.bool, torch.long, torch.float, torch.cfloat)
    def test_make_tensor(self, device, dtype):
        def check(size, low, high, requires_grad, noncontiguous):
            t = make_tensor(size, device, dtype, low=low, high=high,
                            requires_grad=requires_grad, noncontiguous=noncontiguous)

            self.assertEqual(t.shape, size)
            self.assertEqual(t.device, torch.device(device))
            self.assertEqual(t.dtype, dtype)

            # Bounds to verify against when none were passed
            # (presumably make_tensor's own default range -- verify against its definition).
            low = -9 if low is None else low
            high = 9 if high is None else high

            # The range is only verified for the long and float dtypes.
            if t.numel() > 0 and dtype in [torch.long, torch.float]:
                self.assertTrue(t.le(high).logical_and(t.ge(low)).all().item())

            # Only the floating point and complex dtypes are expected to honor requires_grad.
            if dtype in [torch.float, torch.cfloat]:
                self.assertEqual(t.requires_grad, requires_grad)
            else:
                self.assertFalse(t.requires_grad)

            # Tensors with at most one element are always expected to be contiguous.
            if t.numel() > 1:
                self.assertEqual(t.is_contiguous(), not noncontiguous)
            else:
                self.assertTrue(t.is_contiguous())

        for size in (tuple(), (0,), (1,), (1, 1), (2,), (2, 3), (8, 16, 32)):
            check(size, None, None, False, False)
            check(size, 2, 4, True, True)

    # Checks how _get_assert_msg combines the user message and the debug message
    def test_assert_messages(self, device):
        self.assertIsNone(self._get_assert_msg(msg=None))
        self.assertEqual("\nno_debug_msg", self._get_assert_msg("no_debug_msg"))
        self.assertEqual("no_user_msg", self._get_assert_msg(msg=None, debug_msg="no_user_msg"))
        self.assertEqual("debug_msg\nuser_msg", self._get_assert_msg(msg="user_msg", debug_msg="debug_msg"))

    # The following tests (test_cuda_assert_*) are added to ensure test suite terminates early
    # when CUDA assert was thrown. Because all subsequent test will fail if that happens.
    # These tests are slow because it spawn another process to run test suite.
    # See: https://github.com/pytorch/pytorch/issues/49019
    # NOTE: the embedded scripts are executed verbatim, so their own indentation is significant.
    @onlyCUDA
    @slowTest
    def test_cuda_assert_should_stop_common_utils_test_suite(self, device):
        # test to ensure common_utils.py override has early termination for CUDA.
        stderr = TestCase.runWithPytorchAPIUsageStderr("""\
#!/usr/bin/env python
import torch
from torch.testing._internal.common_utils import (TestCase, run_tests, slowTest)
class TestThatContainsCUDAAssertFailure(TestCase):
    @slowTest
    def test_throw_unrecoverable_cuda_exception(self):
        x = torch.rand(10, device='cuda')
        # cause unrecoverable CUDA exception, recoverable on CPU
        y = x[torch.tensor([25])].cpu()
    @slowTest
    def test_trivial_passing_test_case_on_cpu_cuda(self):
        x1 = torch.tensor([0., 1.], device='cuda')
        x2 = torch.tensor([0., 1.], device='cpu')
        self.assertEqual(x1, x2)
if __name__ == '__main__':
    run_tests()
""")
        # should capture CUDA error
        self.assertIn('CUDA error: device-side assert triggered', stderr)
        # should run only 1 test because it throws unrecoverable error.
        self.assertIn('Ran 1 test', stderr)

    @onlyCUDA
    @slowTest
    def test_cuda_assert_should_stop_common_device_type_test_suite(self, device):
        # test to ensure common_device_type.py override has early termination for CUDA.
        stderr = TestCase.runWithPytorchAPIUsageStderr("""\
#!/usr/bin/env python
import torch
from torch.testing._internal.common_utils import (TestCase, run_tests, slowTest)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
class TestThatContainsCUDAAssertFailure(TestCase):
    @slowTest
    def test_throw_unrecoverable_cuda_exception(self, device):
        x = torch.rand(10, device=device)
        # cause unrecoverable CUDA exception, recoverable on CPU
        y = x[torch.tensor([25])].cpu()
    @slowTest
    def test_trivial_passing_test_case_on_cpu_cuda(self, device):
        x1 = torch.tensor([0., 1.], device=device)
        x2 = torch.tensor([0., 1.], device='cpu')
        self.assertEqual(x1, x2)
instantiate_device_type_tests(
    TestThatContainsCUDAAssertFailure,
    globals(),
    only_for='cuda'
)
if __name__ == '__main__':
    run_tests()
""")
        # should capture CUDA error
        self.assertIn('CUDA error: device-side assert triggered', stderr)
        # should run only 1 test because it throws unrecoverable error.
        self.assertIn('Ran 1 test', stderr)

    @onlyCUDA
    @slowTest
    def test_cuda_assert_should_not_stop_common_distributed_test_suite(self, device):
        # test to ensure common_distributed.py override should not early terminate CUDA.
        stderr = TestCase.runWithPytorchAPIUsageStderr("""\
#!/usr/bin/env python
import torch
from torch.testing._internal.common_utils import (run_tests, slowTest)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import MultiProcessTestCase
class TestThatContainsCUDAAssertFailure(MultiProcessTestCase):
    @slowTest
    def test_throw_unrecoverable_cuda_exception(self, device):
        x = torch.rand(10, device=device)
        # cause unrecoverable CUDA exception, recoverable on CPU
        y = x[torch.tensor([25])].cpu()
    @slowTest
    def test_trivial_passing_test_case_on_cpu_cuda(self, device):
        x1 = torch.tensor([0., 1.], device=device)
        x2 = torch.tensor([0., 1.], device='cpu')
        self.assertEqual(x1, x2)
instantiate_device_type_tests(
    TestThatContainsCUDAAssertFailure,
    globals(),
    only_for='cuda'
)
if __name__ == '__main__':
    run_tests()
""")
        # we are currently disabling CUDA early termination for distributed tests.
        self.assertIn('Ran 2 test', stderr)
# Hand TestTesting to the device-type test generator together with this module's globals
# (presumably it registers one concrete test class per device type there -- confirm
# against common_device_type).
instantiate_device_type_tests(TestTesting, globals())
class TestFrameworkUtils(TestCase):
    # Test names handed to calculate_shards by the sharding tests below.
    tests = [
        'super_long_test',
        'long_test1',
        'long_test2',
        'normal_test1',
        'normal_test2',
        'normal_test3',
        'short_test1',
        'short_test2',
        'short_test3',
        'short_test4',
        'short_test5',
    ]

    # Complete mapping of test name -> test time used as sharding input.
    test_times = {
        'super_long_test': 55,
        'long_test1': 22,
        'long_test2': 18,
        'normal_test1': 9,
        'normal_test2': 7,
        'normal_test3': 5,
        'short_test1': 1,
        'short_test2': 0.6,
        'short_test3': 0.4,
        'short_test4': 0.3,
        'short_test5': 0.01,
    }

    # Each expected shard is a pair (total time of the shard, tests in the shard).
    def test_calculate_2_shards_with_complete_test_times(self):
        expected_shards = [
            (60, ['super_long_test', 'normal_test3']),
            (58.31, ['long_test1', 'long_test2', 'normal_test1', 'normal_test2', 'short_test1', 'short_test2',
                     'short_test3', 'short_test4', 'short_test5'])
        ]
        self.assertEqual(expected_shards, calculate_shards(2, self.tests, self.test_times))

    def test_calculate_5_shards_with_complete_test_times(self):
        expected_shards = [
            (55, ['super_long_test']),
            (22, ['long_test1', ]),
            (18, ['long_test2', ]),
            (11.31, ['normal_test1', 'short_test1', 'short_test2', 'short_test3', 'short_test4', 'short_test5']),
            (12, ['normal_test2', 'normal_test3']),
        ]
        self.assertEqual(expected_shards, calculate_shards(5, self.tests, self.test_times))

    def test_calculate_2_shards_with_incomplete_test_times(self):
        # Only the '*test1' entries keep a known time.
        incomplete_test_times = {k: v for k, v in self.test_times.items() if 'test1' in k}
        expected_shards = [
            (22, ['long_test1', 'long_test2', 'normal_test3', 'short_test3', 'short_test5']),
            (10, ['normal_test1', 'short_test1', 'super_long_test', 'normal_test2', 'short_test2', 'short_test4']),
        ]
        self.assertEqual(expected_shards, calculate_shards(2, self.tests, incomplete_test_times))

    def test_calculate_5_shards_with_incomplete_test_times(self):
        # Only the '*test1' entries keep a known time.
        incomplete_test_times = {k: v for k, v in self.test_times.items() if 'test1' in k}
        expected_shards = [
            (22, ['long_test1', 'normal_test2', 'short_test5']),
            (9, ['normal_test1', 'normal_test3']),
            (1, ['short_test1', 'short_test2']),
            (0, ['super_long_test', 'short_test3']),
            (0, ['long_test2', 'short_test4']),
        ]
        self.assertEqual(expected_shards, calculate_shards(5, self.tests, incomplete_test_times))

    def test_calculate_2_shards_against_optimal_shards(self):
        # Seed a private generator once, outside the loop: the run stays reproducible,
        # every iteration sees different times, and the global RNG is left untouched.
        rng = random.Random(120)
        for _ in range(100):
            random_times = {k: rng.random() * 10 for k in self.tests}
            # all test times except first two
            rest_of_tests = [i for k, i in random_times.items() if k != 'super_long_test' and k != 'long_test1']
            sum_of_rest = sum(rest_of_tests)
            # Make the first two tests sum up to exactly the time of all the others.
            random_times['super_long_test'] = max(sum_of_rest / 2, max(rest_of_tests))
            random_times['long_test1'] = sum_of_rest - random_times['super_long_test']
            # An optimal sharding would look like the below, but we don't need to compute this for the test:
            # optimal_shards = [
            #     (sum_of_rest, ['super_long_test', 'long_test1']),
            #     (sum_of_rest, [i for i in self.tests if i != 'super_long_test' and i != 'long_test1']),
            # ]
            calculated_shards = calculate_shards(2, self.tests, random_times)
            max_shard_time = max(calculated_shards[0][0], calculated_shards[1][0])
            if sum_of_rest != 0:
                # The calculated shard should not have a ratio worse than 7/6 for num_shards = 2
                self.assertGreaterEqual(7.0 / 6.0, max_shard_time / sum_of_rest)
                sorted_tests = sorted(self.tests)
                sorted_shard_tests = sorted(calculated_shards[0][1] + calculated_shards[1][1])
                # All the tests should be represented by some shard
                self.assertEqual(sorted_tests, sorted_shard_tests)

    @skipIfRocm
    @unittest.skipIf(IS_WINDOWS, "Skipping because doesn't work for windows")
    @unittest.skipIf(IS_SANDCASTLE, "Skipping because doesn't work on sandcastle")
    def test_filtering_env_var(self):
        # Test environment variable selected device type test generator.
        # The template is executed verbatim in a subprocess, so its indentation is significant.
        test_filter_file_template = """\
#!/usr/bin/env python
import torch
from torch.testing._internal.common_utils import (TestCase, run_tests)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
class TestEnvironmentVariable(TestCase):
    def test_trivial_passing_test(self, device):
        x1 = torch.tensor([0., 1.], device=device)
        x2 = torch.tensor([0., 1.], device='cpu')
        self.assertEqual(x1, x2)
instantiate_device_type_tests(
    TestEnvironmentVariable,
    globals(),
)
if __name__ == '__main__':
    run_tests()
"""
        test_bases_count = len(get_device_type_test_bases())
        # Test without setting env var should run everything.
        env = dict(os.environ)
        for k in ['IN_CI', PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY]:
            # Drop the variable if it is set; a missing key is fine.
            env.pop(k, None)
        _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env)
        self.assertIn(f'Ran {test_bases_count} test', stderr.decode('ascii'))

        # Test with setting only_for should only run 1 test.
        env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY] = 'cpu'
        _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env)
        self.assertIn('Ran 1 test', stderr.decode('ascii'))

        # Test with setting except_for should run 1 less device type from default.
        del env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY]
        env[PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY] = 'cpu'
        _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env)
        self.assertIn(f'Ran {test_bases_count-1} test', stderr.decode('ascii'))

        # Test with setting both should throw exception
        env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY] = 'cpu'
        _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env)
        self.assertNotIn('OK', stderr.decode('ascii'))
# Type variable constrained to the three input kinds used in the annotations below:
# a single tensor, a sequence of tensors, or a mapping whose values are tensors.
T = TypeVar("T", torch.Tensor, Sequence[torch.Tensor], Mapping[Any, torch.Tensor])
class TestAsserts(TestCase):
def get_assert_fns(self) -> List[Callable]:
    """Collects the assert functions that the tests in this class exercise.

    Returns:
        List[Callable]: ``torch.testing.assert_equal`` followed by ``torch.testing.assert_close``.
    """
    assert_fns = [torch.testing.assert_equal, torch.testing.assert_close]
    return assert_fns
def make_inputs(self, actual: torch.Tensor, expected: torch.Tensor) -> List[Tuple[T, T]]:
    """Wraps two example tensors in every input container to be tested.

    Args:
        actual (torch.Tensor): Actual tensor.
        expected (torch.Tensor): Expected tensor.

    Returns:
        List[Tuple[T, T]]: ``(actual, expected)`` pairs in this order: the bare tensors, 1-tuples, single-element
            lists, :class:`dict` s, and :class:`~collections.OrderedDict` s. Both mappings use the key ``"t"``.
    """
    bare_pair = (actual, expected)
    tuple_pair = ((actual,), (expected,))
    list_pair = ([actual], [expected])
    dict_pair = ({"t": actual}, {"t": expected})
    ordered_dict_pair = (
        collections.OrderedDict([("t", actual)]),
        collections.OrderedDict([("t", expected)]),
    )
    return [bare_pair, tuple_pair, list_pair, dict_pair, ordered_dict_pair]
def assert_fns_with_inputs(self, actual: torch.Tensor, expected: torch.Tensor) -> Iterator[Callable]:
    """Yields assert functions with their positional inputs already bound.

    .. note::

        Every combination of an assert function from :meth:`get_assert_fns` and an input pair from
        :meth:`make_inputs` is produced. Tests that do not target something specific should iterate over this
        to maximize the coverage.

    Args:
        actual (torch.Tensor): Actual tensor.
        expected (torch.Tensor): Expected tensor.

    Yields:
        Callable: Assert function with predefined positional inputs. Keyword arguments can still be passed.
    """
    assert_fns = self.get_assert_fns()
    input_pairs = self.make_inputs(actual, expected)
    # The assert function varies slowest, the input container fastest.
    for assert_fn in assert_fns:
        for input_pair in input_pairs:
            yield functools.partial(assert_fn, *input_pair)
@onlyCPU
def test_not_tensors(self, device):
    # A numpy array as the second input must be rejected with a UsageError.
    tensor = torch.empty((), device=device)
    array = np.empty(())

    for assert_fn in self.get_assert_fns():
        with self.assertRaises(UsageError):
            assert_fn(tensor, array)
@onlyCPU
def test_complex_support(self, device):
    # Comparing against a complex tensor must raise a UsageError, even with check_dtype=False.
    real = torch.ones(1, dtype=torch.float32, device=device)
    cplx = torch.ones(1, dtype=torch.complex64, device=device)

    for check in self.assert_fns_with_inputs(real, cplx):
        with self.assertRaises(UsageError):
            check(check_dtype=False)
@onlyCPU
def test_sparse_support(self, device):
    # Comparing against a sparse COO tensor must raise a UsageError.
    dense = torch.empty((), device=device)
    sparse = torch.sparse_coo_tensor(size=(), device=device)

    for check in self.assert_fns_with_inputs(dense, sparse):
        with self.assertRaises(UsageError):
            check()
@onlyCPU
def test_quantized_support(self, device):
    # Comparing against a quantized tensor must raise a UsageError.
    fill_value = 1
    plain = torch.tensor([fill_value], dtype=torch.int32, device=device)

    quantized = torch._empty_affine_quantized(
        plain.shape, scale=1, zero_point=0, dtype=torch.qint32, device=device
    )
    quantized.fill_(fill_value)

    for check in self.assert_fns_with_inputs(plain, quantized):
        with self.assertRaises(UsageError):
            check()
@onlyCPU
def test_mismatching_shape(self, device):
    # A 0-d tensor against its (1,)-shaped copy must fail with a message mentioning "shape".
    scalar = torch.empty((), device=device)
    vector = scalar.clone().reshape((1,))

    for check in self.assert_fns_with_inputs(scalar, vector):
        with self.assertRaisesRegex(AssertionError, "shape"):
            check()
@onlyCUDA
def test_mismatching_device(self, device):
    # A tensor on `device` against its CPU copy must fail with a message mentioning "device".
    on_device = torch.empty((), device=device)
    on_cpu = on_device.clone().cpu()

    for check in self.assert_fns_with_inputs(on_device, on_cpu):
        with self.assertRaisesRegex(AssertionError, "device"):
            check()
@onlyCUDA
def test_mismatching_device_no_check(self, device):
    # With check_device=False a tensor and its CPU copy must compare without raising.
    on_device = torch.rand((), device=device)
    on_cpu = on_device.clone().cpu()

    for check in self.assert_fns_with_inputs(on_device, on_cpu):
        check(check_device=False)
@onlyCPU
def test_mismatching_dtype(self, device):
    # A float tensor against its int conversion must fail with a message mentioning "dtype".
    as_float = torch.empty((), dtype=torch.float, device=device)
    as_int = as_float.clone().to(torch.int)

    for check in self.assert_fns_with_inputs(as_float, as_int):
        with self.assertRaisesRegex(AssertionError, "dtype"):
            check()
@onlyCPU
def test_mismatching_dtype_no_check(self, device):
    # With check_dtype=False a float one and an int one must compare without raising.
    as_float = torch.ones((), dtype=torch.float, device=device)
    as_int = as_float.clone().to(torch.int)

    for check in self.assert_fns_with_inputs(as_float, as_int):
        check(check_dtype=False)
@onlyCPU
def test_mismatching_stride(self, device):
    # Same shape but reversed strides must fail with a message mentioning "stride".
    actual = torch.empty((2, 2), device=device)

    transposed_copy = actual.clone().t().contiguous()
    reversed_stride = actual.stride()[::-1]
    expected = torch.as_strided(transposed_copy, actual.shape, reversed_stride)

    for check in self.assert_fns_with_inputs(actual, expected):
        with self.assertRaisesRegex(AssertionError, "stride"):
            check()
@onlyCPU
def test_mismatching_stride_no_check(self, device):
    # With check_stride=False equal values with reversed strides must compare without raising.
    actual = torch.rand((2, 2), device=device)

    transposed_copy = actual.clone().t().contiguous()
    reversed_stride = actual.stride()[::-1]
    expected = torch.as_strided(transposed_copy, actual.shape, reversed_stride)

    for check in self.assert_fns_with_inputs(actual, expected):
        check(check_stride=False)
@onlyCPU
def test_mismatching_values(self, device):
    # Tensors holding 1 and 2 must be reported as different by every assert function.
    one = torch.tensor(1, device=device)
    two = torch.tensor(2, device=device)

    for check in self.assert_fns_with_inputs(one, two):
        with self.assertRaises(AssertionError):
            check()
@onlyCPU
def test_assert_equal(self, device):
    # An integer tensor and its clone must pass assert_equal.
    original = torch.tensor(1, device=device)
    duplicate = original.clone()

    torch.testing.assert_equal(original, duplicate)
@onlyCPU
def test_assert_close(self, device):
    # A float tensor and its clone must pass assert_close with the default tolerances.
    original = torch.tensor(1.0, device=device)
    duplicate = original.clone()

    torch.testing.assert_close(original, duplicate)
# Calling `assert_close` with `rtol` but without `atol` must raise UsageError.
@onlyCPU
def test_assert_close_only_rtol(self, device):
    tensor = torch.empty((), device=device)
    duplicate = tensor.clone()

    self.assertRaises(UsageError, torch.testing.assert_close, tensor, duplicate, rtol=0.0)
# Calling `assert_close` with `atol` but without `rtol` must raise UsageError.
@onlyCPU
def test_assert_close_only_atol(self, device):
    tensor = torch.empty((), device=device)
    duplicate = tensor.clone()

    self.assertRaises(UsageError, torch.testing.assert_close, tensor, duplicate, atol=0.0)
# 1.0 vs. 1.0 + eps: a relative tolerance of eps / 2 (with atol=0) must be
# too tight, so `assert_close` has to raise an AssertionError.
@onlyCPU
def test_assert_close_mismatching_values_rtol(self, device):
    eps = 1e-3
    base = torch.tensor(1.0, device=device)
    shifted = torch.tensor(1.0 + eps, device=device)

    self.assertRaises(
        AssertionError,
        torch.testing.assert_close,
        base,
        shifted,
        rtol=eps / 2,
        atol=0.0,
    )
# 1.0 vs. 1.0 + eps: a relative tolerance of eps * 2 (with atol=0) must be
# loose enough for `assert_close` to pass.
@onlyCPU
def test_assert_close_matching_values_rtol(self, device):
    eps = 1e-3
    base = torch.tensor(1.0, device=device)
    shifted = torch.tensor(1.0 + eps, device=device)

    tolerances = {"rtol": eps * 2, "atol": 0.0}
    torch.testing.assert_close(base, shifted, **tolerances)
# 0.0 vs. eps: an absolute tolerance of eps / 2 (with rtol=0) must be too
# tight, so `assert_close` has to raise an AssertionError.
@onlyCPU
def test_assert_close_mismatching_values_atol(self, device):
    eps = 1e-3
    zero = torch.tensor(0.0, device=device)
    offset = torch.tensor(eps, device=device)

    self.assertRaises(
        AssertionError,
        torch.testing.assert_close,
        zero,
        offset,
        rtol=0.0,
        atol=eps / 2,
    )
# 0.0 vs. eps: an absolute tolerance of eps * 2 (with rtol=0) must be loose
# enough for `assert_close` to pass.
@onlyCPU
def test_assert_close_matching_values_atol(self, device):
    eps = 1e-3
    zero = torch.tensor(0.0, device=device)
    offset = torch.tensor(eps, device=device)

    tolerances = {"rtol": 0.0, "atol": eps * 2}
    torch.testing.assert_close(zero, offset, **tolerances)
# Two of the four elements differ; the error message must report the
# mismatch count and percentage.
@onlyCPU
def test_mismatching_values_msg_mismatches(self, device):
    lhs = torch.tensor([1, 2, 3, 4], device=device)
    rhs = torch.tensor([1, 2, 5, 6], device=device)
    # The expected text contains regex metacharacters, so match it literally.
    pattern = re.escape("Mismatched elements: 2 / 4 (50.0%)")

    for check in self.assert_fns_with_inputs(lhs, rhs):
        self.assertRaisesRegex(AssertionError, pattern, check)
# Only the element at (1, 0) differs (3 vs. 5); the error message must report
# that absolute difference and its index.
@onlyCPU
def test_mismatching_values_msg_abs_diff(self, device):
    lhs = torch.tensor([[1, 2], [3, 4]], device=device)
    rhs = torch.tensor([[1, 2], [5, 4]], device=device)
    # The expected text contains regex metacharacters, so match it literally.
    pattern = re.escape("Greatest absolute difference: 2 at (1, 0)")

    for check in self.assert_fns_with_inputs(lhs, rhs):
        self.assertRaisesRegex(AssertionError, pattern, check)
# Only the element at (0, 1) differs (2 vs. 4); the error message must report
# that relative difference and its index.
@onlyCPU
def test_mismatching_values_msg_rel_diff(self, device):
    lhs = torch.tensor([[1, 2], [3, 4]], device=device)
    rhs = torch.tensor([[1, 4], [3, 4]], device=device)
    # The expected text contains regex metacharacters, so match it literally.
    pattern = re.escape("Greatest relative difference: 0.5 at (0, 1)")

    for check in self.assert_fns_with_inputs(lhs, rhs):
        self.assertRaisesRegex(AssertionError, pattern, check)
# When `assert_close` fails with a non-zero `rtol`, the error message must
# state the greatest relative difference together with the allowed `rtol`.
@onlyCPU
def test_assert_close_mismatching_values_msg_rtol(self, device):
    rtol = 1e-3
    one = torch.tensor(1, device=device)
    two = torch.tensor(2, device=device)
    # Built once: the expected text does not depend on the loop variable.
    pattern = re.escape(f"Greatest relative difference: 0.5 at 0 (up to {rtol} allowed)")

    for inputs in self.make_inputs(one, two):
        with self.assertRaisesRegex(AssertionError, pattern):
            torch.testing.assert_close(*inputs, rtol=rtol, atol=0.0)
# When `assert_close` fails with a non-zero `atol`, the error message must
# state the greatest absolute difference together with the allowed `atol`.
@onlyCPU
def test_assert_close_mismatching_values_msg_atol(self, device):
    atol = 1e-3
    one = torch.tensor(1, device=device)
    two = torch.tensor(2, device=device)
    # Built once: the expected text does not depend on the loop variable.
    pattern = re.escape(f"Greatest absolute difference: 1 at 0 (up to {atol} allowed)")

    for inputs in self.make_inputs(one, two):
        with self.assertRaisesRegex(AssertionError, pattern):
            torch.testing.assert_close(*inputs, rtol=0.0, atol=atol)
# Passing a set as the second input must raise a UsageError whose message
# contains the textual form of the set's type.
@onlyCPU
def test_unknown_type(self, device):
    actual = torch.empty((), device=device)
    expected = {actual.clone()}
    # `str(type(...))` is arbitrary text, not a regex: escape it so that it is
    # matched literally, as the other message tests in this file do.
    pattern = re.escape(str(type(expected)))

    for fn in self.get_assert_fns():
        with self.assertRaisesRegex(UsageError, pattern):
            fn(actual, expected)
# A one-element tuple compared against an empty tuple must be rejected with
# an AssertionError.
@onlyCPU
def test_sequence_mismatching_len(self, device):
    one_tensor = (torch.empty((), device=device),)
    no_tensors = ()

    for assert_fn in self.get_assert_fns():
        self.assertRaises(AssertionError, assert_fn, one_tensor, no_tensors)
# The tuples differ only in their second item; the error message must match
# "index" followed by whitespace and "1".
@onlyCPU
def test_sequence_mismatching_values_msg(self, device):
    one = torch.tensor(1, device=device)
    two = torch.tensor(2, device=device)
    lhs = (one, one)
    rhs = (one, two)

    for assert_fn in self.get_assert_fns():
        self.assertRaisesRegex(AssertionError, r"index\s+1", assert_fn, lhs, rhs)
# A dict with the single key "a" compared against an empty dict must be
# rejected with an AssertionError.
@onlyCPU
def test_mapping_mismatching_keys(self, device):
    with_key = {"a": torch.empty((), device=device)}
    without_keys = {}

    for assert_fn in self.get_assert_fns():
        self.assertRaises(AssertionError, assert_fn, with_key, without_keys)
# The dicts differ only in the value stored under "b"; the error message must
# match "key" followed by whitespace and "'b'".
@onlyCPU
def test_mapping_mismatching_values_msg(self, device):
    one = torch.tensor(1, device=device)
    two = torch.tensor(2, device=device)
    lhs = {"a": one, "b": one}
    rhs = {"a": one, "b": two}

    for assert_fn in self.get_assert_fns():
        self.assertRaisesRegex(AssertionError, r"key\s+'b'", assert_fn, lhs, rhs)
# Hand TestAsserts and this module's namespace to instantiate_device_type_tests.
# NOTE(review): presumably this registers the per-device variants of the test
# class in this module -- confirm against common_device_type.
instantiate_device_type_tests(TestAsserts, globals())

# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()