mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Apply the UP035 `ruff` rule in tests, excluding some `fx` and `dynamo` tests whose test target is the old typing syntax. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163947 Approved by: https://github.com/ezyang
		
			
				
	
	
		
			2536 lines
		
	
	
		
			100 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			2536 lines
		
	
	
		
			100 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Owner(s): ["module: tests"]
 | |
| 
 | |
| import collections
 | |
| import doctest
 | |
| import functools
 | |
| import importlib
 | |
| import inspect
 | |
| import itertools
 | |
| import math
 | |
| import os
 | |
| import re
 | |
| import subprocess
 | |
| import sys
 | |
| import unittest.mock
 | |
| from typing import Any
 | |
| from collections.abc import Callable
 | |
| from collections.abc import Iterator
 | |
| 
 | |
| import torch
 | |
| 
 | |
| from torch.testing import make_tensor
 | |
| from torch.testing._internal.common_utils import (
 | |
|     IS_FBCODE, IS_JETSON, IS_MACOS, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, slowTest,
 | |
|     parametrize, reparametrize, subtest, instantiate_parametrized_tests, dtype_name,
 | |
|     TEST_WITH_ROCM, decorateIf, skipIfRocm
 | |
| )
 | |
| from torch.testing._internal.common_device_type import \
 | |
|     (PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes,
 | |
|      get_device_type_test_bases, instantiate_device_type_tests, onlyCPU, onlyCUDA, onlyNativeDeviceTypes,
 | |
|      deviceCountAtLeast, ops, expectedFailureMeta, OpDTypes)
 | |
| from torch.testing._internal.common_methods_invocations import op_db
 | |
| from torch.testing._internal import opinfo
 | |
| from torch.testing._internal.common_dtype import all_types_and_complex_and, floating_types
 | |
| from torch.testing._internal.common_modules import modules, module_db, ModuleInfo
 | |
| from torch.testing._internal.opinfo.core import SampleInput, DecorateInfo, OpInfo
 | |
| import operator
 | |
| import string
 | |
| 
 | |
| # For testing TestCase methods and torch.testing functions
 | |
class TestTesting(TestCase):
    """Device-generic tests for ``TestCase`` helpers and ``torch.isclose``.

    Most methods take a ``device`` argument (and a ``dtype`` argument when
    decorated with ``@dtypes``); the class is expanded per device type by the
    module-level ``instantiate_device_type_tests(TestTesting, globals())`` call.
    """

    # Ensure that assertEqual handles numpy arrays properly
    @dtypes(*all_types_and_complex_and(torch.bool, torch.half))
    def test_assertEqual_numpy(self, device, dtype):
        S = 10
        test_sizes = [
            (),
            (0,),
            (S,),
            (S, S),
            (0, S),
            (S, 0)]
        for test_size in test_sizes:
            a = make_tensor(test_size, dtype=dtype, device=device, low=-5, high=5)
            a_n = a.cpu().numpy()
            msg = f'size: {test_size}'
            # Exact comparison in both argument orders, plus array vs. array.
            self.assertEqual(a_n, a, rtol=0, atol=0, msg=msg)
            self.assertEqual(a, a_n, rtol=0, atol=0, msg=msg)
            self.assertEqual(a_n, a_n, rtol=0, atol=0, msg=msg)

    def test_assertEqual_longMessage(self):
        # With longMessage=True, a user-supplied msg must follow the default message.
        actual = "actual"
        expected = "expected"

        long_message = self.longMessage
        try:
            # Capture the default error message by forcing TestCase.longMessage = False
            self.longMessage = False
            try:
                self.assertEqual(actual, expected)
            except AssertionError as error:
                default_msg = str(error)
            else:
                raise AssertionError("AssertionError not raised")

            self.longMessage = True
            extra_msg = "sentinel"
            with self.assertRaisesRegex(AssertionError, re.escape(f"{default_msg}\n{extra_msg}")):
                self.assertEqual(actual, expected, msg=extra_msg)
        finally:
            # Restore the instance setting so later assertions are unaffected.
            self.longMessage = long_message

    def _isclose_helper(self, tests, device, dtype, equal_nan, atol=1e-08, rtol=1e-05):
        """Check ``torch.isclose`` against each ``(a, b, expected)`` triple in ``tests``.

        Each scalar is wrapped in a one-element tensor of ``dtype`` on ``device``;
        the single boolean result must equal ``expected``.
        """
        for test in tests:
            a = torch.tensor((test[0],), device=device, dtype=dtype)
            b = torch.tensor((test[1],), device=device, dtype=dtype)

            actual = torch.isclose(a, b, equal_nan=equal_nan, atol=atol, rtol=rtol)
            expected = test[2]
            self.assertEqual(actual.item(), expected)

    def test_isclose_bool(self, device):
        tests = (
            (True, True, True),
            (False, False, True),
            (True, False, False),
            (False, True, False),
        )

        self._isclose_helper(tests, device, torch.bool, False)

    @dtypes(torch.uint8,
            torch.int8, torch.int16, torch.int32, torch.int64)
    def test_isclose_integer(self, device, dtype):
        tests = (
            (0, 0, True),
            (0, 1, False),
            (1, 0, False),
        )

        self._isclose_helper(tests, device, dtype, False)

        # atol and rtol tests
        tests = [
            (0, 1, True),
            (1, 0, False),
            (1, 3, True),
        ]

        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)

        # NOTE(review): presumably -1 does not survive as -1 in uint8, which is why
        # the expectations differ for that dtype — confirm.
        if dtype is torch.uint8:
            tests = [
                (-1, 1, False),
                (1, -1, False)
            ]
        else:
            tests = [
                (-1, 1, True),
                (1, -1, True)
            ]

        self._isclose_helper(tests, device, dtype, False, atol=1.5, rtol=.5)

    @onlyNativeDeviceTypes
    @dtypes(torch.float16, torch.float32, torch.float64)
    def test_isclose_float(self, device, dtype):
        tests = (
            (0, 0, True),
            (0, -1, False),
            (float('inf'), float('inf'), True),
            (-float('inf'), float('inf'), False),
            (float('inf'), float('nan'), False),
            (float('nan'), float('nan'), False),
            (0, float('nan'), False),
            (1, 1, True),
        )

        self._isclose_helper(tests, device, dtype, False)

        # atol and rtol tests
        # A coarser eps is used for half precision.
        eps = 1e-2 if dtype is torch.half else 1e-6
        tests = (
            (0, 1, True),
            (0, 1 + eps, False),
            (1, 0, False),
            (1, 3, True),
            (1 - eps, 3, False),
            (-.25, .5, True),
            (-.25 - eps, .5, False),
            (.25, -.5, True),
            (.25 + eps, -.5, False),
        )

        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)

        # equal_nan = True tests
        tests = (
            (0, float('nan'), False),
            (float('inf'), float('nan'), False),
            (float('nan'), float('nan'), True),
        )

        self._isclose_helper(tests, device, dtype, True)

    @unittest.skipIf(IS_SANDCASTLE, "Skipping because doesn't work on sandcastle")
    @dtypes(torch.complex64, torch.complex128)
    def test_isclose_complex(self, device, dtype):
        tests = (
            (complex(1, 1), complex(1, 1 + 1e-8), True),
            (complex(0, 1), complex(1, 1), False),
            (complex(1, 1), complex(1, 0), False),
            (complex(1, 1), complex(1, float('nan')), False),
            (complex(1, float('nan')), complex(1, float('nan')), False),
            (complex(1, 1), complex(1, float('inf')), False),
            (complex(float('inf'), 1), complex(1, float('inf')), False),
            (complex(-float('inf'), 1), complex(1, float('inf')), False),
            (complex(-float('inf'), 1), complex(float('inf'), 1), False),
            (complex(float('inf'), 1), complex(float('inf'), 1), True),
            (complex(float('inf'), 1), complex(float('inf'), 1 + 1e-4), False),
        )

        self._isclose_helper(tests, device, dtype, False)

        # atol and rtol tests
        eps = 1e-6
        tests = (
            # Complex versions of float tests (real part)
            (complex(0, 0), complex(1, 0), True),
            (complex(0, 0), complex(1 + eps, 0), False),
            (complex(1, 0), complex(0, 0), False),
            (complex(1, 0), complex(3, 0), True),
            (complex(1 - eps, 0), complex(3, 0), False),
            (complex(-.25, 0), complex(.5, 0), True),
            (complex(-.25 - eps, 0), complex(.5, 0), False),
            (complex(.25, 0), complex(-.5, 0), True),
            (complex(.25 + eps, 0), complex(-.5, 0), False),
            # Complex versions of float tests (imaginary part)
            (complex(0, 0), complex(0, 1), True),
            (complex(0, 0), complex(0, 1 + eps), False),
            (complex(0, 1), complex(0, 0), False),
            (complex(0, 1), complex(0, 3), True),
            (complex(0, 1 - eps), complex(0, 3), False),
            (complex(0, -.25), complex(0, .5), True),
            (complex(0, -.25 - eps), complex(0, .5), False),
            (complex(0, .25), complex(0, -.5), True),
            (complex(0, .25 + eps), complex(0, -.5), False),
        )

        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)

        # atol and rtol tests for isclose
        tests = (
            # Complex-specific tests
            (complex(1, -1), complex(-1, 1), False),
            (complex(1, -1), complex(2, -2), True),
            (complex(-math.sqrt(2), math.sqrt(2)),
             complex(-math.sqrt(.5), math.sqrt(.5)), True),
            (complex(-math.sqrt(2), math.sqrt(2)),
             complex(-math.sqrt(.501), math.sqrt(.499)), False),
            (complex(2, 4), complex(1., 8.8523607), True),
            (complex(2, 4), complex(1., 8.8523607 + eps), False),
            (complex(1, 99), complex(4, 100), True),
        )
        self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)

        # equal_nan = True tests
        tests = (
            (complex(1, 1), complex(1, float('nan')), False),
            (complex(1, 1), complex(float('nan'), 1), False),
            (complex(float('nan'), 1), complex(float('nan'), 1), True),
            (complex(float('nan'), 1), complex(1, float('nan')), True),
            (complex(float('nan'), float('nan')), complex(float('nan'), float('nan')), True),
        )
        self._isclose_helper(tests, device, dtype, True)

    # Tests that isclose with rtol or atol values less than zero throws a
    #   RuntimeError
    @dtypes(torch.bool, torch.uint8,
            torch.int8, torch.int16, torch.int32, torch.int64,
            torch.float16, torch.float32, torch.float64)
    def test_isclose_atol_rtol_greater_than_zero(self, device, dtype):
        t = torch.tensor((1,), device=device, dtype=dtype)

        with self.assertRaises(RuntimeError):
            torch.isclose(t, t, atol=-1, rtol=1)
        with self.assertRaises(RuntimeError):
            torch.isclose(t, t, atol=1, rtol=-1)
        with self.assertRaises(RuntimeError):
            torch.isclose(t, t, atol=-1, rtol=-1)

    def test_isclose_equality_shortcut(self):
        # For values >= 2**53, integers differing by 1 can no longer be differentiated by torch.float64 or lower
        # precision floating point dtypes. Thus, even with rtol == 0 and atol == 0, these tensors would be considered
        # close if they were not compared as integers.
        a = torch.tensor(2 ** 53, dtype=torch.int64)
        b = a + 1

        self.assertFalse(torch.isclose(a, b, rtol=0, atol=0))

    @dtypes(torch.float16, torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_isclose_nan_equality_shortcut(self, device, dtype):
        # With equal_nan=True, values containing NaN must compare close even with zero tolerances.
        if dtype.is_floating_point:
            a = b = torch.nan
        else:
            a = complex(torch.nan, 0)
            b = complex(0, torch.nan)

        expected = True
        tests = [(a, b, expected)]

        self._isclose_helper(tests, device, dtype, equal_nan=True, rtol=0, atol=0)

    # The following tests (test_cuda_assert_*) are added to ensure the test suite terminates early
    # when a CUDA assert is thrown, because all subsequent tests will fail if that happens.
    # These tests are slow because they spawn another process to run a test suite.
    # See: https://github.com/pytorch/pytorch/issues/49019
    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")
    @onlyCUDA
    @slowTest
    def test_cuda_assert_should_stop_common_utils_test_suite(self, device):
        # test to ensure common_utils.py override has early termination for CUDA.
        stderr = TestCase.runWithPytorchAPIUsageStderr("""\
#!/usr/bin/env python3

import torch
from torch.testing._internal.common_utils import (TestCase, run_tests, slowTest)

class TestThatContainsCUDAAssertFailure(TestCase):

    @slowTest
    def test_throw_unrecoverable_cuda_exception(self):
        x = torch.rand(10, device='cuda')
        # cause unrecoverable CUDA exception, recoverable on CPU
        y = x[torch.tensor([25])].cpu()

    @slowTest
    def test_trivial_passing_test_case_on_cpu_cuda(self):
        x1 = torch.tensor([0., 1.], device='cuda')
        x2 = torch.tensor([0., 1.], device='cpu')
        self.assertEqual(x1, x2)

if __name__ == '__main__':
    run_tests()
""")
        # should capture CUDA error
        self.assertIn('CUDA error: device-side assert triggered', stderr)
        # should run only 1 test because it throws unrecoverable error.
        self.assertIn('errors=1', stderr)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")
    @onlyCUDA
    @slowTest
    def test_cuda_assert_should_stop_common_device_type_test_suite(self, device):
        # test to ensure common_device_type.py override has early termination for CUDA.
        stderr = TestCase.runWithPytorchAPIUsageStderr("""\
#!/usr/bin/env python3

import torch
from torch.testing._internal.common_utils import (TestCase, run_tests, slowTest)
from torch.testing._internal.common_device_type import instantiate_device_type_tests

class TestThatContainsCUDAAssertFailure(TestCase):

    @slowTest
    def test_throw_unrecoverable_cuda_exception(self, device):
        x = torch.rand(10, device=device)
        # cause unrecoverable CUDA exception, recoverable on CPU
        y = x[torch.tensor([25])].cpu()

    @slowTest
    def test_trivial_passing_test_case_on_cpu_cuda(self, device):
        x1 = torch.tensor([0., 1.], device=device)
        x2 = torch.tensor([0., 1.], device='cpu')
        self.assertEqual(x1, x2)

instantiate_device_type_tests(
    TestThatContainsCUDAAssertFailure,
    globals(),
    only_for='cuda'
)

if __name__ == '__main__':
    run_tests()
""")
        # should capture CUDA error
        self.assertIn('CUDA error: device-side assert triggered', stderr)
        # should run only 1 test because it throws unrecoverable error.
        self.assertIn('errors=1', stderr)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support device side asserts")
    @onlyCUDA
    @slowTest
    def test_cuda_assert_should_not_stop_common_distributed_test_suite(self, device):
        # test to ensure common_distributed.py override should not early terminate CUDA.
        stderr = TestCase.runWithPytorchAPIUsageStderr("""\
#!/usr/bin/env python3

import torch
from torch.testing._internal.common_utils import (run_tests, slowTest)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import MultiProcessTestCase

class TestThatContainsCUDAAssertFailure(MultiProcessTestCase):

    @slowTest
    def test_throw_unrecoverable_cuda_exception(self, device):
        x = torch.rand(10, device=device)
        # cause unrecoverable CUDA exception, recoverable on CPU
        y = x[torch.tensor([25])].cpu()

    @slowTest
    def test_trivial_passing_test_case_on_cpu_cuda(self, device):
        x1 = torch.tensor([0., 1.], device=device)
        x2 = torch.tensor([0., 1.], device='cpu')
        self.assertEqual(x1, x2)

instantiate_device_type_tests(
    TestThatContainsCUDAAssertFailure,
    globals(),
    only_for='cuda'
)

if __name__ == '__main__':
    run_tests()
""")
        # we are currently disabling CUDA early termination for distributed tests.
        self.assertIn('errors=2', stderr)

    @expectedFailureMeta  # This is only supported for CPU and CUDA
    @onlyNativeDeviceTypes
    def test_get_supported_dtypes(self, device):
        # Test the `get_supported_dtypes` helper function.
        # We acquire the dtypes for few Ops dynamically and verify them against
        # the correct statically described values.
        op_names = {'atan2', 'topk', 'xlogy'}
        ops_to_test = [op for op in op_db if op.name in op_names]

        for op in ops_to_test:
            dynamic_dtypes = opinfo.utils.get_supported_dtypes(op, op.sample_inputs_func, self.device_type)
            dynamic_dispatch = opinfo.utils.dtypes_dispatch_hint(dynamic_dtypes)
            # Named `static_dtypes` (not `dtypes`) so the imported @dtypes decorator is not shadowed.
            if self.device_type == 'cpu':
                static_dtypes = op.dtypes
            else:  # device_type == 'cuda'
                static_dtypes = op.dtypesIfCUDA

            # assertEqual reports the differing dtypes on failure; assertTrue(a == b) only says "False".
            self.assertEqual(set(static_dtypes), set(dynamic_dtypes))
            self.assertEqual(set(static_dtypes), set(dynamic_dispatch.dispatch_fn()))

    @onlyCPU
    @ops(
        [
            op
            for op in op_db
            if len(
                op.supported_dtypes("cpu").symmetric_difference(
                    op.supported_dtypes("cuda")
                )
            )
            > 0
        ][:1],
        dtypes=OpDTypes.none,
    )
    def test_supported_dtypes(self, device, op):
        # Only the first op whose CPU and CUDA dtype sets differ is exercised.
        self.assertNotEqual(op.supported_dtypes("cpu"), op.supported_dtypes("cuda"))
        self.assertEqual(op.supported_dtypes("cuda"), op.supported_dtypes("cuda:0"))
        self.assertEqual(
            op.supported_dtypes(torch.device("cuda")),
            op.supported_dtypes(torch.device("cuda", index=1)),
        )

    def test_setup_and_teardown_run_for_device_specific_tests(self, device):
        # TODO: Move this (and other similar text blocks) to some fixtures/ subdir
        stderr = TestCase.runWithPytorchAPIUsageStderr(f"""\
#!/usr/bin/env python3

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests

class TestFoo(TestCase):
    @classmethod
    def setUpClass(cls):
        # store something on the test class to query during teardown
        cls.stored_thing = "called with " + cls.__name__

    @classmethod
    def tearDownClass(cls):
        # throw here so we know teardown was run
        raise RuntimeError(cls.stored_thing)

    def test_bar(self, device):
        # make sure the test can access the stored thing
        print(self.stored_thing)

instantiate_device_type_tests(TestFoo, globals(), only_for='{self.device_type}')

if __name__ == '__main__':
    run_tests()
""")
        # The teardown error must name the device-specific class, proving both hooks ran on it.
        expected_device_class_name = f"TestFoo{self.device_type.upper()}"
        expected_error_text = f"RuntimeError: called with {expected_device_class_name}"
        self.assertIn(expected_error_text, stderr)
| 
 | |
| 
 | |
# Generate the device-specific variants of TestTesting (class names follow the
# ``TestFoo{DEVICE}`` pattern checked in test_setup_and_teardown_run_for_device_specific_tests)
# into the scope given by ``globals()``.
instantiate_device_type_tests(TestTesting, globals())
 | |
| 
 | |
| 
 | |
class TestFrameworkUtils(TestCase):
    """Tests for the device-type test framework's environment-variable filtering."""

    @unittest.skipIf(IS_WINDOWS, "Skipping because doesn't work for windows")
    @unittest.skipIf(IS_SANDCASTLE, "Skipping because doesn't work on sandcastle")
    def test_filtering_env_var(self):
        # Test environment variable selected device type test generator.
        test_filter_file_template = """\
#!/usr/bin/env python3

import torch
from torch.testing._internal.common_utils import (TestCase, run_tests)
from torch.testing._internal.common_device_type import instantiate_device_type_tests

class TestEnvironmentVariable(TestCase):

    def test_trivial_passing_test(self, device):
        x1 = torch.tensor([0., 1.], device=device)
        x2 = torch.tensor([0., 1.], device='cpu')
        self.assertEqual(x1, x2)

instantiate_device_type_tests(
    TestEnvironmentVariable,
    globals(),
)

if __name__ == '__main__':
    run_tests()
"""

        def run_template(env):
            """Run the template in a subprocess with ``env`` and return its stderr as text."""
            _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env)
            # UTF-8 is a superset of ASCII; replacing undecodable bytes keeps non-ASCII
            # child output (e.g. warnings) from raising UnicodeDecodeError before the
            # assertions below get to run.
            return stderr.decode('utf-8', errors='replace')

        test_bases_count = len(get_device_type_test_bases())
        # Test without setting env var should run everything.
        env = dict(os.environ)
        for key in ('CI', PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY):
            env.pop(key, None)
        self.assertIn(f'Ran {test_bases_count} test', run_template(env))

        # Test with setting only_for should only run 1 test.
        env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY] = 'cpu'
        self.assertIn('Ran 1 test', run_template(env))

        # Test with setting except_for should run 1 less device type from default.
        del env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY]
        env[PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY] = 'cpu'
        self.assertIn(f'Ran {test_bases_count - 1} test', run_template(env))

        # Test with setting both should throw exception
        env[PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY] = 'cpu'
        self.assertNotIn('OK', run_template(env))
 | |
| 
 | |
| 
 | |
def make_assert_close_inputs(actual: Any, expected: Any) -> list[tuple[Any, Any]]:
    """Build positional input pairs for :func:`torch.testing.assert_close` from two examples.

    The result starts with the bare ``(actual, expected)`` pair, followed by the same
    pair wrapped in sequences (:class:`tuple`, :class:`list`), in mappings
    (:class:`dict`, :class:`~collections.OrderedDict`), and in nested/mixed
    combinations of those containers.

    Args:
        actual (Any): Actual input.
        expected (Any): Expected input.

    Returns:
        list[tuple[Any, Any]]: Ten ``(actual_like, expected_like)`` pairs; every
        container is freshly created for this call.
    """

    def as_ordered(value: Any) -> collections.OrderedDict:
        # Single-entry OrderedDict under the key "t", mirroring the plain-dict cases.
        return collections.OrderedDict([("t", value)])

    bare = (actual, expected)
    sequences = [
        ((actual,), (expected,)),  # tuple vs. tuple
        ([actual], [expected]),  # list vs. list
        ((actual,), [expected]),  # tuple vs. list
    ]
    mappings = [
        ({"t": actual}, {"t": expected}),  # dict vs. dict
        (as_ordered(actual), as_ordered(expected)),  # OrderedDict vs. OrderedDict
        ({"t": actual}, as_ordered(expected)),  # dict vs. OrderedDict
    ]
    nested = [
        ([(actual,)], ([expected],)),  # list of tuples vs. tuple of lists
        ([{"t": actual}], (as_ordered(expected),)),  # list of dicts vs. tuple of OrderedDicts
        ({"t": [actual]}, as_ordered((expected,))),  # dict of lists vs. OrderedDict of tuples
    ]
    return [bare, *sequences, *mappings, *nested]
 | |
| 
 | |
| 
 | |
def assert_close_with_inputs(actual: Any, expected: Any) -> Iterator[Callable]:
    """Yield :func:`torch.testing.assert_close` pre-bound to each example input pair.

    .. note::

        Every test that does not test for a specific input should iterate over this to maximize the coverage.

    Args:
        actual (Any): Actual input.
        expected (Any): Expected input.

    Yields:
        Callable: a :func:`functools.partial` of :func:`torch.testing.assert_close` whose
        positional arguments are one pair from :func:`make_assert_close_inputs`; keyword
        options (e.g. ``rtol``/``atol``, ``check_dtype``) are supplied when calling it.
    """
    yield from (
        functools.partial(torch.testing.assert_close, *pair)
        for pair in make_assert_close_inputs(actual, expected)
    )
 | |
| 
 | |
| 
 | |
| class TestAssertClose(TestCase):
 | |
|     def test_mismatching_types_subclasses(self):
 | |
|         actual = torch.rand(())
 | |
|         expected = torch.nn.Parameter(actual)
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             fn()
 | |
| 
 | |
|     def test_mismatching_types_type_equality(self):
 | |
|         actual = torch.empty(())
 | |
|         expected = torch.nn.Parameter(actual)
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             with self.assertRaisesRegex(TypeError, str(type(expected))):
 | |
|                 fn(allow_subclasses=False)
 | |
| 
 | |
|     def test_mismatching_types(self):
 | |
|         actual = torch.empty(2)
 | |
|         expected = actual.numpy()
 | |
| 
 | |
|         for fn, allow_subclasses in itertools.product(assert_close_with_inputs(actual, expected), (True, False)):
 | |
|             with self.assertRaisesRegex(TypeError, str(type(expected))):
 | |
|                 fn(allow_subclasses=allow_subclasses)
 | |
| 
 | |
|     def test_unknown_type(self):
 | |
|         actual = "0"
 | |
|         expected = "0"
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             with self.assertRaisesRegex(TypeError, str(type(actual))):
 | |
|                 fn()
 | |
| 
 | |
|     def test_mismatching_shape(self):
 | |
|         actual = torch.empty(())
 | |
|         expected = actual.clone().reshape((1,))
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             with self.assertRaisesRegex(AssertionError, "shape"):
 | |
|                 fn()
 | |
| 
 | |
|     @unittest.skipIf(not torch.backends.mkldnn.is_available(), reason="MKLDNN is not available.")
 | |
|     def test_unknown_layout(self):
 | |
|         actual = torch.empty((2, 2))
 | |
|         expected = actual.to_mkldnn()
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             with self.assertRaisesRegex(ValueError, "layout"):
 | |
|                 fn()
 | |
| 
 | |
|     def test_meta(self):
 | |
|         actual = torch.empty((2, 2), device="meta")
 | |
|         expected = torch.empty((2, 2), device="meta")
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             fn()
 | |
| 
 | |
|     def test_mismatching_layout(self):
 | |
|         strided = torch.empty((2, 2))
 | |
|         sparse_coo = strided.to_sparse()
 | |
|         sparse_csr = strided.to_sparse_csr()
 | |
| 
 | |
|         for actual, expected in itertools.combinations((strided, sparse_coo, sparse_csr), 2):
 | |
|             for fn in assert_close_with_inputs(actual, expected):
 | |
|                 with self.assertRaisesRegex(AssertionError, "layout"):
 | |
|                     fn()
 | |
| 
 | |
|     def test_mismatching_layout_no_check(self):
 | |
|         strided = torch.randn((2, 2))
 | |
|         sparse_coo = strided.to_sparse()
 | |
|         sparse_csr = strided.to_sparse_csr()
 | |
| 
 | |
|         for actual, expected in itertools.combinations((strided, sparse_coo, sparse_csr), 2):
 | |
|             for fn in assert_close_with_inputs(actual, expected):
 | |
|                 fn(check_layout=False)
 | |
| 
 | |
|     def test_mismatching_dtype(self):
 | |
|         actual = torch.empty((), dtype=torch.float)
 | |
|         expected = actual.clone().to(torch.int)
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             with self.assertRaisesRegex(AssertionError, "dtype"):
 | |
|                 fn()
 | |
| 
 | |
|     def test_mismatching_dtype_no_check(self):
 | |
|         actual = torch.ones((), dtype=torch.float)
 | |
|         expected = actual.clone().to(torch.int)
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             fn(check_dtype=False)
 | |
| 
 | |
|     def test_mismatching_stride(self):
 | |
|         actual = torch.empty((2, 2))
 | |
|         expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             with self.assertRaisesRegex(AssertionError, "stride"):
 | |
|                 fn(check_stride=True)
 | |
| 
 | |
|     def test_mismatching_stride_no_check(self):
 | |
|         actual = torch.rand((2, 2))
 | |
|         expected = torch.as_strided(actual.clone().t().contiguous(), actual.shape, actual.stride()[::-1])
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             fn()
 | |
| 
 | |
|     def test_only_rtol(self):
 | |
|         actual = torch.empty(())
 | |
|         expected = actual.clone()
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             with self.assertRaises(ValueError):
 | |
|                 fn(rtol=0.0)
 | |
| 
 | |
|     def test_only_atol(self):
 | |
|         actual = torch.empty(())
 | |
|         expected = actual.clone()
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             with self.assertRaises(ValueError):
 | |
|                 fn(atol=0.0)
 | |
| 
 | |
|     def test_mismatching_values(self):
 | |
|         actual = torch.tensor(1)
 | |
|         expected = torch.tensor(2)
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             with self.assertRaises(AssertionError):
 | |
|                 fn()
 | |
| 
 | |
|     def test_mismatching_values_rtol(self):
 | |
|         eps = 1e-3
 | |
|         actual = torch.tensor(1.0)
 | |
|         expected = torch.tensor(1.0 + eps)
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             with self.assertRaises(AssertionError):
 | |
|                 fn(rtol=eps / 2, atol=0.0)
 | |
| 
 | |
|     def test_mismatching_values_atol(self):
 | |
|         eps = 1e-3
 | |
|         actual = torch.tensor(0.0)
 | |
|         expected = torch.tensor(eps)
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             with self.assertRaises(AssertionError):
 | |
|                 fn(rtol=0.0, atol=eps / 2)
 | |
| 
 | |
|     def test_matching(self):
 | |
|         actual = torch.tensor(1.0)
 | |
|         expected = actual.clone()
 | |
| 
 | |
|         torch.testing.assert_close(actual, expected)
 | |
| 
 | |
|     def test_matching_rtol(self):
 | |
|         eps = 1e-3
 | |
|         actual = torch.tensor(1.0)
 | |
|         expected = torch.tensor(1.0 + eps)
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             fn(rtol=eps * 2, atol=0.0)
 | |
| 
 | |
|     def test_matching_atol(self):
 | |
|         eps = 1e-3
 | |
|         actual = torch.tensor(0.0)
 | |
|         expected = torch.tensor(eps)
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             fn(rtol=0.0, atol=eps * 2)
 | |
| 
 | |
|     # TODO: the code that this test was designed for was removed in https://github.com/pytorch/pytorch/pull/56058
 | |
|     #  We need to check if this test is still needed or if this behavior is now enabled by default.
 | |
|     def test_matching_conjugate_bit(self):
 | |
|         actual = torch.tensor(complex(1, 1)).conj()
 | |
|         expected = torch.tensor(complex(1, -1))
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             fn()
 | |
| 
 | |
|     def test_matching_nan(self):
 | |
|         nan = float("NaN")
 | |
| 
 | |
|         tests = (
 | |
|             (nan, nan),
 | |
|             (complex(nan, 0), complex(0, nan)),
 | |
|             (complex(nan, nan), complex(nan, 0)),
 | |
|             (complex(nan, nan), complex(nan, nan)),
 | |
|         )
 | |
| 
 | |
|         for actual, expected in tests:
 | |
|             for fn in assert_close_with_inputs(actual, expected):
 | |
|                 with self.assertRaises(AssertionError):
 | |
|                     fn()
 | |
| 
 | |
|     def test_matching_nan_with_equal_nan(self):
 | |
|         nan = float("NaN")
 | |
| 
 | |
|         tests = (
 | |
|             (nan, nan),
 | |
|             (complex(nan, 0), complex(0, nan)),
 | |
|             (complex(nan, nan), complex(nan, 0)),
 | |
|             (complex(nan, nan), complex(nan, nan)),
 | |
|         )
 | |
| 
 | |
|         for actual, expected in tests:
 | |
|             for fn in assert_close_with_inputs(actual, expected):
 | |
|                 fn(equal_nan=True)
 | |
| 
 | |
|     def test_numpy(self):
 | |
|         tensor = torch.rand(2, 2, dtype=torch.float32)
 | |
|         actual = tensor.numpy()
 | |
|         expected = actual.copy()
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             fn()
 | |
| 
 | |
|     def test_scalar(self):
 | |
|         number = torch.randint(10, size=()).item()
 | |
|         for actual, expected in itertools.product((int(number), float(number), complex(number)), repeat=2):
 | |
|             check_dtype = type(actual) is type(expected)
 | |
| 
 | |
|             for fn in assert_close_with_inputs(actual, expected):
 | |
|                 fn(check_dtype=check_dtype)
 | |
| 
 | |
|     def test_bool(self):
 | |
|         actual = torch.tensor([True, False])
 | |
|         expected = actual.clone()
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             fn()
 | |
| 
 | |
|     def test_none(self):
 | |
|         actual = expected = None
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             fn()
 | |
| 
 | |
|     def test_none_mismatch(self):
 | |
|         expected = None
 | |
| 
 | |
|         for actual in (False, 0, torch.nan, torch.tensor(torch.nan)):
 | |
|             for fn in assert_close_with_inputs(actual, expected):
 | |
|                 with self.assertRaises(AssertionError):
 | |
|                     fn()
 | |
| 
 | |
| 
 | |
|     def test_docstring_examples(self):
 | |
|         finder = doctest.DocTestFinder(verbose=False)
 | |
|         runner = doctest.DocTestRunner(verbose=False, optionflags=doctest.NORMALIZE_WHITESPACE)
 | |
|         globs = dict(torch=torch)
 | |
|         doctests = finder.find(torch.testing.assert_close, globs=globs)[0]
 | |
|         failures = []
 | |
|         runner.run(doctests, out=lambda report: failures.append(report))
 | |
|         if failures:
 | |
|             raise AssertionError(f"Doctest found {len(failures)} failures:\n\n" + "\n".join(failures))
 | |
| 
 | |
|     def test_default_tolerance_selection_mismatching_dtypes(self):
 | |
|         # If the default tolerances where selected based on the promoted dtype, i.e. float64,
 | |
|         # these tensors wouldn't be considered close.
 | |
|         actual = torch.tensor(0.99, dtype=torch.bfloat16)
 | |
|         expected = torch.tensor(1.0, dtype=torch.float64)
 | |
| 
 | |
|         for fn in assert_close_with_inputs(actual, expected):
 | |
|             fn(check_dtype=False)
 | |
| 
 | |
|     class UnexpectedException(Exception):
 | |
|         """The only purpose of this exception is to test ``assert_close``'s handling of unexpected exceptions. Thus,
 | |
|         the test should mock a component to raise this instead of the regular behavior. We avoid using a builtin
 | |
|         exception here to avoid triggering possible handling of them.
 | |
|         """
 | |
| 
 | |
|     @unittest.mock.patch("torch.testing._comparison.TensorLikePair.__init__", side_effect=UnexpectedException)
 | |
|     def test_unexpected_error_originate(self, _):
 | |
|         actual = torch.tensor(1.0)
 | |
|         expected = actual.clone()
 | |
| 
 | |
|         with self.assertRaisesRegex(RuntimeError, "unexpected exception"):
 | |
|             torch.testing.assert_close(actual, expected)
 | |
| 
 | |
|     @unittest.mock.patch("torch.testing._comparison.TensorLikePair.compare", side_effect=UnexpectedException)
 | |
|     def test_unexpected_error_compare(self, _):
 | |
|         actual = torch.tensor(1.0)
 | |
|         expected = actual.clone()
 | |
| 
 | |
|         with self.assertRaisesRegex(RuntimeError, "unexpected exception"):
 | |
|             torch.testing.assert_close(actual, expected)
 | |
| 
 | |
| 
 | |
| 
 | |
| 
 | |
class TestAssertCloseMultiDevice(TestCase):
    """``assert_close`` behavior when the two inputs live on different devices."""

    @deviceCountAtLeast(1)
    def test_mismatching_device(self, devices):
        """A tensor and its copy on another device are rejected with a device error."""
        device_pool = ("cpu", *devices)

        for src_device, dst_device in itertools.permutations(device_pool, 2):
            source = torch.empty((), device=src_device)
            moved = source.clone().to(dst_device)
            for check in assert_close_with_inputs(source, moved):
                with self.assertRaisesRegex(AssertionError, "device"):
                    check()

    @deviceCountAtLeast(1)
    def test_mismatching_device_no_check(self, devices):
        """With ``check_device=False`` a tensor and its copy on another device are close."""
        device_pool = ("cpu", *devices)

        for src_device, dst_device in itertools.permutations(device_pool, 2):
            source = torch.rand((), device=src_device)
            moved = source.clone().to(dst_device)
            for check in assert_close_with_inputs(source, moved):
                check(check_device=False)
 | |
| 
 | |
| 
 | |
| instantiate_device_type_tests(TestAssertCloseMultiDevice, globals(), only_for="cuda")  # instantiate the device-generic tests above for CUDA only, registering them in this module's globals()
 | |
| 
 | |
| 
 | |
class TestAssertCloseErrorMessage(TestCase):
    """Checks the content of the ``AssertionError`` messages produced by ``assert_close``."""

    def test_identifier_tensor_likes(self):
        """Tensor inputs are referred to as "Tensor-likes" in the message."""
        actual = torch.tensor([1, 2, 3, 4])
        expected = torch.tensor([1, 2, 5, 6])

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Tensor-likes")):
                fn()

    def test_identifier_scalars(self):
        """Python number inputs are referred to as "Scalars" in the message."""
        actual = 3
        expected = 5
        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Scalars")):
                fn()

    def test_not_equal(self):
        """With zero tolerances the message says "not equal"."""
        actual = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
        expected = torch.tensor([1, 2, 5, 6], dtype=torch.float32)

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("not equal")):
                fn(rtol=0.0, atol=0.0)

    def test_not_close(self):
        """With any non-zero tolerance the message says "not close"."""
        actual = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
        expected = torch.tensor([1, 2, 5, 6], dtype=torch.float32)

        for fn, (rtol, atol) in itertools.product(
            assert_close_with_inputs(actual, expected), ((1.3e-6, 0.0), (0.0, 1e-5), (1.3e-6, 1e-5))
        ):
            with self.assertRaisesRegex(AssertionError, re.escape("not close")):
                fn(rtol=rtol, atol=atol)

    def test_mismatched_elements(self):
        """The message reports the count and percentage of mismatched elements."""
        actual = torch.tensor([1, 2, 3, 4])
        expected = torch.tensor([1, 2, 5, 6])

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Mismatched elements: 2 / 4 (50.0%)")):
                fn()

    def test_abs_diff(self):
        """The message reports the greatest absolute difference and its index."""
        actual = torch.tensor([[1, 2], [3, 4]])
        expected = torch.tensor([[1, 2], [5, 4]])

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Greatest absolute difference: 2 at index (1, 0)")):
                fn()

    def test_small_float_dtype(self):
        """For float8 dtypes the message reports the index of the first mismatched element."""
        for dtype in [
            torch.float8_e4m3fn,
            torch.float8_e4m3fnuz,
            torch.float8_e5m2,
            torch.float8_e5m2fnuz,
            torch.float8_e8m0fnu,
        ]:
            w_vector = torch.tensor([3.14, 1.0], dtype=dtype)
            x_vector = torch.tensor([1.0, 3.14], dtype=dtype)
            y_vector = torch.tensor([3.14, 3.14], dtype=dtype)
            z_vector = torch.tensor([1.0, 3.14], dtype=dtype)

            # Append 0-3 trailing singleton dims so the reported index is checked for several ranks.
            for additional_dims in range(4):
                new_shape = list(w_vector.shape) + ([1] * additional_dims)
                w_tensor = w_vector.reshape(new_shape)
                x_tensor = x_vector.reshape(new_shape)
                y_tensor = y_vector.reshape(new_shape)
                z_tensor = z_vector.reshape(new_shape)

                # x and y first differ at position 0 of the leading dimension.
                for fn in assert_close_with_inputs(x_tensor, y_tensor):
                    expected_index = (0,) + (0,) * additional_dims
                    with self.assertRaisesRegex(
                        AssertionError, re.escape(f"The first mismatched element is at index {expected_index}")
                    ):
                        fn()

                # w and y first differ at position 1 of the leading dimension.
                for fn in assert_close_with_inputs(w_tensor, y_tensor):
                    expected_index = (1,) + (0,) * additional_dims
                    with self.assertRaisesRegex(
                        AssertionError, re.escape(f"The first mismatched element is at index {expected_index}")
                    ):
                        fn()

                # x and z hold identical values, so no error is expected.
                for fn in assert_close_with_inputs(x_tensor, z_tensor):
                    fn()

    def test_abs_diff_scalar(self):
        """For scalars the message reports the absolute difference."""
        actual = 3
        expected = 5

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Absolute difference: 2")):
                fn()

    def test_rel_diff(self):
        """The message reports the greatest relative difference and its index."""
        actual = torch.tensor([[1, 2], [3, 4]])
        expected = torch.tensor([[1, 4], [3, 4]])

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Greatest relative difference: 0.5 at index (0, 1)")):
                fn()

    def test_rel_diff_scalar(self):
        """For scalars the message reports the relative difference."""
        actual = 2
        expected = 4

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape("Relative difference: 0.5")):
                fn()

    def test_zero_div_zero(self):
        """A matching 0 / 0 element must not leak ``nan`` into the mismatch report."""
        actual = torch.tensor([1.0, 0.0])
        expected = torch.tensor([2.0, 0.0])

        for fn in assert_close_with_inputs(actual, expected):
            # The second elements match, but 0 / 0 would yield nan if they were used in the relative-difference
            # computation. assertRaisesRegex uses re.search, so a pattern such as "((?!nan).)*" matches the empty
            # prefix of any message and can never fail; inspect the message text explicitly instead.
            with self.assertRaises(AssertionError) as ctx:
                fn()
            self.assertNotIn("nan", str(ctx.exception))

    def test_rtol(self):
        """The message states the allowed relative tolerance."""
        rtol = 1e-3

        actual = torch.tensor((1, 2))
        expected = torch.tensor((2, 2))

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(f"(up to {rtol} allowed)")):
                fn(rtol=rtol, atol=0.0)

    def test_atol(self):
        """The message states the allowed absolute tolerance."""
        atol = 1e-3

        actual = torch.tensor((1, 2))
        expected = torch.tensor((2, 2))

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(f"(up to {atol} allowed)")):
                fn(rtol=0.0, atol=atol)

    def test_msg_str(self):
        """A string ``msg`` appears in the raised error."""
        msg = "Custom error message!"

        actual = torch.tensor(1)
        expected = torch.tensor(2)

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, msg):
                fn(msg=msg)

    def test_msg_callable(self):
        """A callable ``msg`` is invoked and its return value appears in the raised error."""
        msg = "Custom error message"

        actual = torch.tensor(1)
        expected = torch.tensor(2)

        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, msg):
                fn(msg=lambda _: msg)
 | |
| 
 | |
| 
 | |
class TestAssertCloseContainer(TestCase):
    """``assert_close`` applied to sequences and mappings of tensors."""

    def test_sequence_mismatching_len(self):
        """Sequences of different lengths are not close."""
        one_element = (torch.empty(()),)
        no_elements = ()

        with self.assertRaises(AssertionError):
            torch.testing.assert_close(one_element, no_elements)

    def test_sequence_mismatching_values_msg(self):
        """The message names the index of the mismatching sequence item."""
        first, second = torch.tensor(1), torch.tensor(2)

        with self.assertRaisesRegex(AssertionError, re.escape("item [1]")):
            torch.testing.assert_close((first, first), (first, second))

    def test_mapping_mismatching_keys(self):
        """Mappings with different keys are not close."""
        populated = {"a": torch.empty(())}
        empty = {}

        with self.assertRaises(AssertionError):
            torch.testing.assert_close(populated, empty)

    def test_mapping_mismatching_values_msg(self):
        """The message names the key of the mismatching mapping item."""
        first, second = torch.tensor(1), torch.tensor(2)

        with self.assertRaisesRegex(AssertionError, re.escape("item ['b']")):
            torch.testing.assert_close({"a": first, "b": first}, {"a": first, "b": second})
 | |
| 
 | |
| 
 | |
class TestAssertCloseSparseCOO(TestCase):
    """``assert_close`` on sparse COO tensors: matching inputs and per-attribute mismatch messages."""

    def _make_coo(self, indices, values):
        # Every hand-built COO tensor in this suite is 2x2.
        return torch.sparse_coo_tensor(indices, values, size=(2, 2))

    def _check_match(self, actual, expected):
        # Every comparison variant must pass.
        for fn in assert_close_with_inputs(actual, expected):
            fn()

    def _check_mismatch_msg(self, actual, expected, fragment):
        # Every comparison variant must fail with a message containing ``fragment`` literally.
        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(fragment)):
                fn()

    def test_matching_coalesced(self):
        """A coalesced COO tensor is close to its clone."""
        actual = self._make_coo(((0, 1), (1, 0)), (1, 2)).coalesce()
        self._check_match(actual, actual.clone())

    def test_matching_uncoalesced(self):
        """An uncoalesced COO tensor is close to its clone."""
        actual = self._make_coo(((0, 1), (1, 0)), (1, 2))
        self._check_match(actual, actual.clone())

    def test_mismatching_sparse_dims(self):
        """COO tensors with a different number of sparse dims are reported."""
        dense = torch.randn(2, 3, 4)
        self._check_mismatch_msg(
            dense.to_sparse(), dense.to_sparse(2), "number of sparse dimensions in sparse COO tensors"
        )

    def test_mismatching_nnz(self):
        """COO tensors with a different number of specified values are reported."""
        actual = self._make_coo(((0, 1), (1, 0)), (1, 2))
        expected = self._make_coo(((0, 1, 1), (1, 0, 0)), (1, 1, 1))
        self._check_mismatch_msg(actual, expected, "number of specified values in sparse COO tensors")

    def test_mismatching_indices_msg(self):
        """COO tensors with different indices are reported."""
        actual = self._make_coo(((0, 1), (1, 0)), (1, 2))
        expected = self._make_coo(((0, 1), (1, 1)), (1, 2))
        self._check_mismatch_msg(actual, expected, "Sparse COO indices")

    def test_mismatching_values_msg(self):
        """COO tensors with different values are reported."""
        actual = self._make_coo(((0, 1), (1, 0)), (1, 2))
        expected = self._make_coo(((0, 1), (1, 0)), (1, 3))
        self._check_mismatch_msg(actual, expected, "Sparse COO values")
 | |
| 
 | |
| 
 | |
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSR testing")
class TestAssertCloseSparseCSR(TestCase):
    """``assert_close`` on 2x2 sparse CSR tensors: matching inputs and per-component mismatch messages."""

    def _make_csr(self, crow_indices, col_indices, values):
        return torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 2))

    def _check_mismatch_msg(self, actual, expected, fragment):
        # Every comparison variant must fail with a message containing ``fragment`` literally.
        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(fragment)):
                fn()

    def test_matching(self):
        """A CSR tensor is close to its clone."""
        actual = self._make_csr((0, 1, 2), (1, 0), (1, 2))
        for fn in assert_close_with_inputs(actual, actual.clone()):
            fn()

    def test_mismatching_crow_indices_msg(self):
        """Differing ``crow_indices`` are reported."""
        col_indices, values = (0, 1), (1, 2)
        actual = self._make_csr((0, 1, 2), col_indices, values)
        expected = self._make_csr((0, 2, 2), col_indices, values)
        self._check_mismatch_msg(actual, expected, "Sparse CSR crow_indices")

    def test_mismatching_col_indices_msg(self):
        """Differing ``col_indices`` are reported."""
        crow_indices, values = (0, 1, 2), (1, 2)
        actual = self._make_csr(crow_indices, (1, 0), values)
        expected = self._make_csr(crow_indices, (1, 1), values)
        self._check_mismatch_msg(actual, expected, "Sparse CSR col_indices")

    def test_mismatching_values_msg(self):
        """Differing ``values`` are reported."""
        crow_indices, col_indices = (0, 1, 2), (1, 0)
        actual = self._make_csr(crow_indices, col_indices, (1, 2))
        expected = self._make_csr(crow_indices, col_indices, (1, 3))
        self._check_mismatch_msg(actual, expected, "Sparse CSR values")
 | |
| 
 | |
| 
 | |
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSC testing")
class TestAssertCloseSparseCSC(TestCase):
    """``assert_close`` on 2x2 sparse CSC tensors: matching inputs and per-component mismatch messages."""

    def _make_csc(self, ccol_indices, row_indices, values):
        return torch.sparse_csc_tensor(ccol_indices, row_indices, values, size=(2, 2))

    def _check_mismatch_msg(self, actual, expected, fragment):
        # Every comparison variant must fail with a message containing ``fragment`` literally.
        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(fragment)):
                fn()

    def test_matching(self):
        """A CSC tensor is close to its clone."""
        actual = self._make_csc((0, 1, 2), (1, 0), (1, 2))
        for fn in assert_close_with_inputs(actual, actual.clone()):
            fn()

    def test_mismatching_ccol_indices_msg(self):
        """Differing ``ccol_indices`` are reported."""
        row_indices, values = (0, 1), (1, 2)
        actual = self._make_csc((0, 1, 2), row_indices, values)
        expected = self._make_csc((0, 2, 2), row_indices, values)
        self._check_mismatch_msg(actual, expected, "Sparse CSC ccol_indices")

    def test_mismatching_row_indices_msg(self):
        """Differing ``row_indices`` are reported."""
        ccol_indices, values = (0, 1, 2), (1, 2)
        actual = self._make_csc(ccol_indices, (1, 0), values)
        expected = self._make_csc(ccol_indices, (1, 1), values)
        self._check_mismatch_msg(actual, expected, "Sparse CSC row_indices")

    def test_mismatching_values_msg(self):
        """Differing ``values`` are reported."""
        ccol_indices, row_indices = (0, 1, 2), (1, 0)
        actual = self._make_csc(ccol_indices, row_indices, (1, 2))
        expected = self._make_csc(ccol_indices, row_indices, (1, 3))
        self._check_mismatch_msg(actual, expected, "Sparse CSC values")
 | |
| 
 | |
| 
 | |
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support BSR testing")
class TestAssertCloseSparseBSR(TestCase):
    """``assert_close`` on 2x2 sparse BSR tensors with 1x1 blocks: matching inputs and mismatch messages."""

    def _make_bsr(self, crow_indices, col_indices, values):
        return torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(2, 2))

    def _check_mismatch_msg(self, actual, expected, fragment):
        # Every comparison variant must fail with a message containing ``fragment`` literally.
        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(fragment)):
                fn()

    def test_matching(self):
        """A BSR tensor is close to its clone."""
        actual = self._make_bsr((0, 1, 2), (1, 0), ([[1]], [[2]]))
        for fn in assert_close_with_inputs(actual, actual.clone()):
            fn()

    def test_mismatching_crow_indices_msg(self):
        """Differing ``crow_indices`` are reported."""
        col_indices, values = (0, 1), ([[1]], [[2]])
        actual = self._make_bsr((0, 1, 2), col_indices, values)
        expected = self._make_bsr((0, 2, 2), col_indices, values)
        self._check_mismatch_msg(actual, expected, "Sparse BSR crow_indices")

    def test_mismatching_col_indices_msg(self):
        """Differing ``col_indices`` are reported."""
        crow_indices, values = (0, 1, 2), ([[1]], [[2]])
        actual = self._make_bsr(crow_indices, (1, 0), values)
        expected = self._make_bsr(crow_indices, (1, 1), values)
        self._check_mismatch_msg(actual, expected, "Sparse BSR col_indices")

    def test_mismatching_values_msg(self):
        """Differing block ``values`` are reported."""
        crow_indices, col_indices = (0, 1, 2), (1, 0)
        actual = self._make_bsr(crow_indices, col_indices, ([[1]], [[2]]))
        expected = self._make_bsr(crow_indices, col_indices, ([[1]], [[3]]))
        self._check_mismatch_msg(actual, expected, "Sparse BSR values")
 | |
| 
 | |
| 
 | |
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support BSC testing")
class TestAssertCloseSparseBSC(TestCase):
    """``assert_close`` on 2x2 sparse BSC tensors with 1x1 blocks: matching inputs and mismatch messages."""

    def _make_bsc(self, ccol_indices, row_indices, values):
        return torch.sparse_bsc_tensor(ccol_indices, row_indices, values, size=(2, 2))

    def _check_mismatch_msg(self, actual, expected, fragment):
        # Every comparison variant must fail with a message containing ``fragment`` literally.
        for fn in assert_close_with_inputs(actual, expected):
            with self.assertRaisesRegex(AssertionError, re.escape(fragment)):
                fn()

    def test_matching(self):
        """A BSC tensor is close to its clone."""
        actual = self._make_bsc((0, 1, 2), (1, 0), ([[1]], [[2]]))
        for fn in assert_close_with_inputs(actual, actual.clone()):
            fn()

    def test_mismatching_ccol_indices_msg(self):
        """Differing ``ccol_indices`` are reported."""
        row_indices, values = (0, 1), ([[1]], [[2]])
        actual = self._make_bsc((0, 1, 2), row_indices, values)
        expected = self._make_bsc((0, 2, 2), row_indices, values)
        self._check_mismatch_msg(actual, expected, "Sparse BSC ccol_indices")

    def test_mismatching_row_indices_msg(self):
        """Differing ``row_indices`` are reported."""
        ccol_indices, values = (0, 1, 2), ([[1]], [[2]])
        actual = self._make_bsc(ccol_indices, (1, 0), values)
        expected = self._make_bsc(ccol_indices, (1, 1), values)
        self._check_mismatch_msg(actual, expected, "Sparse BSC row_indices")

    def test_mismatching_values_msg(self):
        """Differing block ``values`` are reported."""
        ccol_indices, row_indices = (0, 1, 2), (1, 0)
        actual = self._make_bsc(ccol_indices, row_indices, ([[1]], [[2]]))
        expected = self._make_bsc(ccol_indices, row_indices, ([[1]], [[3]]))
        self._check_mismatch_msg(actual, expected, "Sparse BSC values")
 | |
| 
 | |
| 
 | |
class TestAssertCloseQuantized(TestCase):
    """Comparison of quantized tensors through every entry point yielded by ``assert_close_with_inputs``."""

    def test_mismatching_is_quantized(self):
        # A plain tensor compared with its quantized counterpart must fail on `is_quantized`.
        plain = torch.tensor(1.0)
        quantized = torch.quantize_per_tensor(plain, scale=1.0, zero_point=0, dtype=torch.qint32)

        for check in assert_close_with_inputs(plain, quantized):
            with self.assertRaisesRegex(AssertionError, "is_quantized"):
                check()

    def test_mismatching_qscheme(self):
        # Same source data, quantized per-tensor vs. per-channel: the schemes must be reported as different.
        source = torch.tensor((1.0,))
        per_tensor = torch.quantize_per_tensor(source, scale=1.0, zero_point=0, dtype=torch.qint32)
        per_channel = torch.quantize_per_channel(
            source,
            scales=torch.tensor((1.0,)),
            zero_points=torch.tensor((0,)),
            axis=0,
            dtype=torch.qint32,
        )

        for check in assert_close_with_inputs(per_tensor, per_channel):
            with self.assertRaisesRegex(AssertionError, "qscheme"):
                check()

    def test_matching_per_tensor(self):
        # A per-tensor quantized tensor must compare equal to its own clone.
        quantized = torch.quantize_per_tensor(torch.tensor(1.0), scale=1.0, zero_point=0, dtype=torch.qint32)

        for check in assert_close_with_inputs(quantized, quantized.clone()):
            check()

    def test_matching_per_channel(self):
        # A per-channel quantized tensor must compare equal to its own clone.
        quantized = torch.quantize_per_channel(
            torch.tensor((1.0,)),
            scales=torch.tensor((1.0,)),
            zero_points=torch.tensor((0,)),
            axis=0,
            dtype=torch.qint32,
        )

        for check in assert_close_with_inputs(quantized, quantized.clone()):
            check()
 | |
| 
 | |
| 
 | |
class TestMakeTensor(TestCase):
    """Tests for ``torch.testing.make_tensor``.

    Covers shape handling, ``requires_grad``, memory layout options and the ``low`` / ``high`` /
    ``exclude_zero`` value-range options.
    """

    # Dtype decorator shared by most tests below.
    supported_dtypes = dtypes(
        torch.bool,
        torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
        torch.float16, torch.bfloat16, torch.float32, torch.float64,
        torch.complex32, torch.complex64, torch.complex128,
    )

    @supported_dtypes
    @parametrize("shape", [(), (0,), (1,), (1, 1), (2,), (2, 3), (8, 16, 32)])
    @parametrize("splat_shape", [False, True])
    def test_smoke(self, dtype, device, shape, splat_shape):
        # The shape is passed either as varargs (`make_tensor(2, 3)`) or as one sequence (`make_tensor((2, 3))`).
        # Keep the two calls as separate branches: in a call, `*a if c else b` parses as `*(a if c else b)`,
        # so a conditional inside the star argument would splat in both cases.
        if splat_shape:
            t = torch.testing.make_tensor(*shape, dtype=dtype, device=device)
        else:
            t = torch.testing.make_tensor(shape, dtype=dtype, device=device)

        self.assertIsInstance(t, torch.Tensor)
        self.assertEqual(t.shape, shape)
        self.assertEqual(t.dtype, dtype)
        self.assertEqual(t.device, torch.device(device))

    @supported_dtypes
    @parametrize("requires_grad", [False, True])
    def test_requires_grad(self, dtype, device, requires_grad):
        make_tensor = functools.partial(
            torch.testing.make_tensor,
            dtype=dtype,
            device=device,
            requires_grad=requires_grad,
        )

        # `requires_grad=True` is only accepted for floating point and complex dtypes.
        if not requires_grad or dtype.is_floating_point or dtype.is_complex:
            t = make_tensor()
            self.assertEqual(t.requires_grad, requires_grad)
        else:
            with self.assertRaisesRegex(
                    ValueError, "`requires_grad=True` is not supported for boolean and integral dtypes"
            ):
                make_tensor()

    @supported_dtypes
    @parametrize("noncontiguous", [False, True])
    @parametrize("shape", [(), (0,), (1,), (1, 1), (2,), (2, 3), (8, 16, 32)])
    def test_noncontiguous(self, dtype, device, noncontiguous, shape):
        numel = math.prod(shape)

        t = torch.testing.make_tensor(shape, dtype=dtype, device=device, noncontiguous=noncontiguous)
        # Tensors with fewer than two elements are expected to be contiguous even if `noncontiguous=True`.
        self.assertEqual(t.is_contiguous(), not noncontiguous or numel < 2)

    @supported_dtypes
    @parametrize(
        "memory_format_and_shape",
        [
            (None, (2, 3, 4)),
            (torch.contiguous_format, (2, 3, 4)),
            (torch.channels_last, (2, 3, 4, 5)),
            (torch.channels_last_3d, (2, 3, 4, 5, 6)),
            (torch.preserve_format, (2, 3, 4)),
        ],
    )
    def test_memory_format(self, dtype, device, memory_format_and_shape):
        memory_format, shape = memory_format_and_shape

        t = torch.testing.make_tensor(shape, dtype=dtype, device=device, memory_format=memory_format)

        # `memory_format=None` is checked against the plain contiguous layout.
        self.assertTrue(
            t.is_contiguous(memory_format=torch.contiguous_format if memory_format is None else memory_format)
        )

    @supported_dtypes
    def test_noncontiguous_memory_format(self, dtype, device):
        with self.assertRaisesRegex(ValueError, "`noncontiguous` and `memory_format` are mutually exclusive"):
            torch.testing.make_tensor(
                (2, 3, 4, 5),
                dtype=dtype,
                device=device,
                noncontiguous=True,
                memory_format=torch.channels_last,
            )

    @supported_dtypes
    def test_exclude_zero(self, dtype, device):
        t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, exclude_zero=True, low=-1, high=2)

        self.assertTrue((t != 0).all())

    @supported_dtypes
    def test_low_high_smoke(self, dtype, device):
        low_inclusive, high_exclusive = 0, 2

        t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low_inclusive, high=high_exclusive)
        if dtype.is_complex:
            # Check the real and imaginary parts independently.
            t = torch.view_as_real(t)

        self.assertTrue(((t >= low_inclusive) & (t < high_exclusive)).all())

    @supported_dtypes
    def test_low_high_default_smoke(self, dtype, device):
        # Ranges this test expects `make_tensor` to sample from when `low` and `high` are omitted.
        low_inclusive, high_exclusive = {
            torch.bool: (0, 2),
            torch.uint8: (0, 10),
            **dict.fromkeys([torch.int8, torch.int16, torch.int32, torch.int64], (-9, 10)),
        }.get(dtype, (-9, 9))

        # `low` / `high` are deliberately not passed: this test checks the default range.
        t = torch.testing.make_tensor(10_000, dtype=dtype, device=device)
        if dtype.is_complex:
            t = torch.view_as_real(t)

        self.assertTrue(((t >= low_inclusive) & (t < high_exclusive)).all())

    @parametrize("low_high", [(0, 0), (1, 0), (0, -1)])
    @parametrize("value_types", list(itertools.product([int, float], repeat=2)))
    @supported_dtypes
    def test_low_ge_high(self, dtype, device, low_high, value_types):
        low, high = (value_type(value) for value, value_type in zip(low_high, value_types))

        if low == high and (dtype.is_floating_point or dtype.is_complex):
            # `low == high` is still accepted (with a deprecation warning) for floating / complex dtypes and
            # yields a constant tensor.
            with self.assertWarnsRegex(
                    FutureWarning,
                    "Passing `low==high` to `torch.testing.make_tensor` for floating or complex types is deprecated",
            ):
                t = torch.testing.make_tensor(10_000, dtype=dtype, device=device, low=low, high=high)
            self.assertEqual(t, torch.full_like(t, complex(low, low) if dtype.is_complex else low))
        else:
            with self.assertRaisesRegex(ValueError, "`low` must be less than `high`"):
                torch.testing.make_tensor(dtype=dtype, device=device, low=low, high=high)

    @supported_dtypes
    @parametrize("low_high", [(None, torch.nan), (torch.nan, None), (torch.nan, torch.nan)])
    def test_low_high_nan(self, dtype, device, low_high):
        low, high = low_high

        with self.assertRaisesRegex(ValueError, "`low` and `high` cannot be NaN"):
            torch.testing.make_tensor(dtype=dtype, device=device, low=low, high=high)

    @supported_dtypes
    def test_low_high_outside_valid_range(self, dtype, device):
        make_tensor = functools.partial(torch.testing.make_tensor, dtype=dtype, device=device)

        def get_dtype_limits(dtype):
            if dtype is torch.bool:
                return 0, 1

            info = (torch.finfo if dtype.is_floating_point or dtype.is_complex else torch.iinfo)(dtype)
            # We use integer bounds here, because otherwise it would be impossible to pass `low` and `high`
            # outside their valid range. Python uses 64-bit floating point numbers, so something like
            # `torch.finfo(torch.float64).max * 2` always evaluates to `inf`. Python's `int`, on the other hand,
            # is unbounded.
            return int(info.min), int(info.max)

        lowest_inclusive, highest_inclusive = get_dtype_limits(dtype)

        # Both bounds strictly below the smallest representable value. If that minimum is 0, scaling it would
        # not leave the valid range, so use fixed negative bounds instead.
        low, high = (-2, -1) if lowest_inclusive == 0 else (lowest_inclusive * 4, lowest_inclusive * 2)
        with self.assertRaisesRegex(ValueError, ""):
            make_tensor(low=low, high=high)

        # Both bounds strictly above the largest representable value.
        with self.assertRaisesRegex(ValueError, ""):
            make_tensor(low=highest_inclusive * 2, high=highest_inclusive * 4)

    @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    def test_low_high_boolean_integral1(self, dtype, device):
        shape = (10_000,)
        eps = 1e-4

        # The only integral value inside the open interval (-1, 1) is 0.
        actual = torch.testing.make_tensor(shape, dtype=dtype, device=device, low=-(1 - eps), high=1 - eps)
        expected = torch.zeros(shape, dtype=dtype, device=device)

        torch.testing.assert_close(actual, expected)

    @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    def test_low_high_boolean_integral2(self, dtype, device):
        shape = (10_000,)
        if dtype is torch.bool:
            low = 1
        elif dtype is torch.int64:
            # Due to its internals, `make_tensor` is not able to sample `torch.iinfo(torch.int64).max`
            low = torch.iinfo(dtype).max - 1
        else:
            low = torch.iinfo(dtype).max
        # The half-open range [low, low + 1) can only produce `low`.
        high = low + 1

        actual = torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high)
        expected = torch.full(shape, low, dtype=dtype, device=device)

        torch.testing.assert_close(actual, expected)
 | |
| 
 | |
| 
 | |
# Expand the generic TestMakeTensor into concrete per-device test classes (named like
# ``TestMakeTensor<DEVICE>``), registered in this module's globals().
instantiate_device_type_tests(TestMakeTensor, globals())
 | |
| 
 | |
| 
 | |
| def _get_test_names_for_test_class(test_cls):
 | |
|     """ Convenience function to get all test names for a given test class. """
 | |
|     test_names = [f'{test_cls.__name__}.{key}' for key in test_cls.__dict__
 | |
|                   if key.startswith('test_')]
 | |
|     return sorted(test_names)
 | |
| 
 | |
| 
 | |
| def _get_test_funcs_for_test_class(test_cls):
 | |
|     """ Convenience function to get all (test function, parametrized_name) pairs for a given test class. """
 | |
|     test_funcs = [(getattr(test_cls, key), key) for key in test_cls.__dict__ if key.startswith('test_')]
 | |
|     return test_funcs
 | |
| 
 | |
| 
 | |
class TestTestParametrization(TestCase):
    """How ``instantiate_parametrized_tests`` expands ``@parametrize`` / ``@reparametrize`` / ``subtest``.

    Covers the generated test names and how per-parameter decorators are attached to them.
    """

    def test_default_names(self):

        class TestParametrized(TestCase):
            @parametrize("x", range(5))
            def test_default_names(self, x):
                pass

            @parametrize("x,y", [(1, 2), (2, 3), (3, 4)])
            def test_two_things_default_names(self, x, y):
                pass

        instantiate_parametrized_tests(TestParametrized)

        # Without a `name_fn`, every parameter is rendered as `<arg>_<value>`.
        single = [f'TestParametrized.test_default_names_x_{x}' for x in range(5)]
        paired = [
            f'TestParametrized.test_two_things_default_names_x_{x}_y_{y}'
            for x, y in ((1, 2), (2, 3), (3, 4))
        ]
        self.assertEqual(single + paired, _get_test_names_for_test_class(TestParametrized))

    def test_name_fn(self):

        class TestParametrized(TestCase):
            @parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias')
            def test_custom_names(self, bias):
                pass

            @parametrize("x", [1, 2], name_fn=str)
            @parametrize("y", [3, 4], name_fn=str)
            @parametrize("z", [5, 6], name_fn=str)
            def test_three_things_composition_custom_names(self, x, y, z):
                pass

            @parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: f'{x}__{y}')
            def test_two_things_custom_names_alternate(self, x, y):
                pass

        instantiate_parametrized_tests(TestParametrized)

        prefix = 'TestParametrized.'
        expected_test_names = [
            prefix + 'test_custom_names_bias',
            prefix + 'test_custom_names_no_bias',
            *(f'{prefix}test_three_things_composition_custom_names_{x}_{y}_{z}'
              for x, y, z in itertools.product((1, 2), (3, 4), (5, 6))),
            *(f'{prefix}test_two_things_custom_names_alternate_1__{y}' for y in (2, 3, 4)),
        ]
        self.assertEqual(expected_test_names, _get_test_names_for_test_class(TestParametrized))

    def test_reparametrize(self):

        def include_is_even_arg(test_name, param_kwargs):
            # Adds an `is_even` kwarg together with a matching `_even` / `_odd` name suffix.
            is_even = param_kwargs["x"] % 2 == 0
            suffix = "_even" if is_even else "_odd"
            yield f"{test_name}{suffix}", {**param_kwargs, "is_even": is_even}

        def exclude_odds(test_name, param_kwargs):
            # Yielding `None` drops that parametrization.
            if param_kwargs["x"] % 2 == 0:
                yield test_name, param_kwargs
            else:
                yield None

        class TestParametrized(TestCase):
            @reparametrize(parametrize("x", range(5)), include_is_even_arg)
            def test_foo(self, x, is_even):
                pass

            @reparametrize(parametrize("x", range(5)), exclude_odds)
            def test_bar(self, x):
                pass

        instantiate_parametrized_tests(TestParametrized)

        evens_only = [f'TestParametrized.test_bar_x_{x}' for x in (0, 2, 4)]
        with_suffix = [
            f'TestParametrized.test_foo_x_{x}_{"even" if x % 2 == 0 else "odd"}' for x in range(5)
        ]
        self.assertEqual(evens_only + with_suffix, _get_test_names_for_test_class(TestParametrized))

    def test_subtest_names(self):

        class TestParametrized(TestCase):
            @parametrize("bias", [subtest(True, name='bias'),
                                  subtest(False, name='no_bias')])
            def test_custom_names(self, bias):
                pass

            @parametrize("x,y", [subtest((1, 2), name='double'),
                                 subtest((1, 3), name='triple'),
                                 subtest((1, 4), name='quadruple')])
            def test_two_things_custom_names(self, x, y):
                pass

        instantiate_parametrized_tests(TestParametrized)

        expected_test_names = [
            f'TestParametrized.{name}' for name in (
                'test_custom_names_bias',
                'test_custom_names_no_bias',
                'test_two_things_custom_names_double',
                'test_two_things_custom_names_quadruple',
                'test_two_things_custom_names_triple',
            )
        ]
        self.assertEqual(expected_test_names, _get_test_names_for_test_class(TestParametrized))

    def test_apply_param_specific_decorators(self):
        # A decorator attached to one subtest must only reach the test generated for that subtest.

        def mark_applied(func):
            func._decorator_applied = True
            return func

        class TestParametrized(TestCase):
            @parametrize("x", [subtest(1, name='one'),
                               subtest(2, name='two', decorators=[mark_applied]),
                               subtest(3, name='three')])
            def test_param(self, x):
                pass

        instantiate_parametrized_tests(TestParametrized)

        for fn, attr_name in _get_test_funcs_for_test_class(TestParametrized):
            self.assertEqual(hasattr(fn, '_decorator_applied'), attr_name == 'test_param_two')

    def test_compose_param_specific_decorators(self):
        # Per-parameter decorators from stacked `@parametrize` calls must combine.

        def mark_applied(func):
            func._decorator_applied = True
            return func

        class TestParametrized(TestCase):
            @parametrize("x", [subtest(1),
                               subtest(2, decorators=[mark_applied]),
                               subtest(3)])
            @parametrize("y", [subtest(False, decorators=[mark_applied]),
                               subtest(True)])
            def test_param(self, x, y):
                pass

        instantiate_parametrized_tests(TestParametrized)

        for fn, attr_name in _get_test_funcs_for_test_class(TestParametrized):
            # The marker is expected whenever x == 2 or y is False.
            expected_marked = 'x_2' in attr_name or 'y_False' in attr_name
            self.assertEqual(hasattr(fn, '_decorator_applied'), expected_marked)

    def test_modules_decorator_misuse_error(self):
        # `@modules` only works with device-specific instantiation; plain parametrized instantiation must reject it.

        class TestParametrized(TestCase):
            @modules(module_db)
            def test_modules(self, module_info):
                pass

        with self.assertRaisesRegex(RuntimeError, 'intended to be used in a device-specific context'):
            instantiate_parametrized_tests(TestParametrized)

    def test_ops_decorator_misuse_error(self):
        # `@ops` only works with device-specific instantiation; plain parametrized instantiation must reject it.

        class TestParametrized(TestCase):
            @ops(op_db)
            def test_ops(self, device, dtype, op):
                pass

        with self.assertRaisesRegex(RuntimeError, 'intended to be used in a device-specific context'):
            instantiate_parametrized_tests(TestParametrized)

    def test_multiple_handling_of_same_param_error(self):
        # Two decorators parametrizing the same argument must be rejected.

        class TestParametrized(TestCase):
            @parametrize("x", range(3))
            @parametrize("x", range(5))
            def test_param(self, x):
                pass

        with self.assertRaisesRegex(RuntimeError, 'multiple parametrization decorators'):
            instantiate_parametrized_tests(TestParametrized)

    @parametrize("x", [1, subtest(2, decorators=[unittest.expectedFailure]), 3])
    def test_subtest_expected_failure(self, x):
        # Only the x == 2 variant is marked as an expected failure, so only it raises.
        if x == 2:
            raise RuntimeError('Boom')

    @parametrize("x", [subtest(1, decorators=[unittest.expectedFailure]), 2, 3])
    @parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])])
    def test_two_things_subtest_expected_failure(self, x, y):
        # Any variant carrying an expected-failure subtest (x == 1 or y == 6) raises.
        if x == 1 or y == 6:
            raise RuntimeError('Boom')
 | |
| 
 | |
| 
 | |
| class TestTestParametrizationDeviceType(TestCase):
 | |
|     def test_unparametrized_names(self, device):
 | |
|         # This test exists to protect against regressions in device / dtype test naming
 | |
|         # due to parametrization logic.
 | |
| 
 | |
|         device = self.device_type
 | |
| 
 | |
|         class TestParametrized(TestCase):
 | |
|             def test_device_specific(self, device):
 | |
|                 pass
 | |
| 
 | |
|             @dtypes(torch.float32, torch.float64)
 | |
|             def test_device_dtype_specific(self, device, dtype):
 | |
|                 pass
 | |
| 
 | |
|         locals_dict = dict(locals())
 | |
|         instantiate_device_type_tests(TestParametrized, locals_dict, only_for=device)
 | |
| 
 | |
|         device_cls = locals_dict[f'TestParametrized{device.upper()}']
 | |
|         expected_test_names = [name.format(device_cls.__name__, device) for name in (
 | |
|             '{}.test_device_dtype_specific_{}_float32',
 | |
|             '{}.test_device_dtype_specific_{}_float64',
 | |
|             '{}.test_device_specific_{}')
 | |
|         ]
 | |
|         test_names = _get_test_names_for_test_class(device_cls)
 | |
|         self.assertEqual(expected_test_names, test_names)
 | |
| 
 | |
|     def test_empty_param_names(self, device):
 | |
|         # If no param names are passed, ensure things still work without parametrization.
 | |
|         device = self.device_type
 | |
| 
 | |
|         class TestParametrized(TestCase):
 | |
|             @parametrize("", [])
 | |
|             def test_foo(self, device):
 | |
|                 pass
 | |
| 
 | |
|             @parametrize("", range(5))
 | |
|             def test_bar(self, device):
 | |
|                 pass
 | |
| 
 | |
|         locals_dict = dict(locals())
 | |
|         instantiate_device_type_tests(TestParametrized, locals_dict, only_for=device)
 | |
| 
 | |
|         device_cls = locals_dict[f'TestParametrized{device.upper()}']
 | |
|         expected_test_names = [name.format(device_cls.__name__, device) for name in (
 | |
|             '{}.test_bar_{}',
 | |
|             '{}.test_foo_{}')
 | |
|         ]
 | |
|         test_names = _get_test_names_for_test_class(device_cls)
 | |
|         self.assertEqual(expected_test_names, test_names)
 | |
| 
 | |
|     def test_empty_param_list(self, device):
 | |
|         # If no param values are passed, ensure a helpful error message is thrown.
 | |
|         # In the wild, this could indicate reuse of an exhausted generator.
 | |
|         device = self.device_type
 | |
| 
 | |
|         generator = (a for a in range(5))
 | |
| 
 | |
|         class TestParametrized(TestCase):
 | |
|             @parametrize("x", generator)
 | |
|             def test_foo(self, device, x):
 | |
|                 pass
 | |
| 
 | |
|             # Reuse generator from first test function.
 | |
|             @parametrize("y", generator)
 | |
|             def test_bar(self, device, y):
 | |
|                 pass
 | |
| 
 | |
|         with self.assertRaisesRegex(ValueError, 'An empty arg_values was passed'):
 | |
|             locals_dict = dict(locals())
 | |
|             instantiate_device_type_tests(TestParametrized, locals_dict, only_for=device)
 | |
| 
 | |
|     def test_default_names(self, device):
 | |
|         device = self.device_type
 | |
| 
 | |
|         class TestParametrized(TestCase):
 | |
|             @parametrize("x", range(5))
 | |
|             def test_default_names(self, device, x):
 | |
|                 pass
 | |
| 
 | |
|             @parametrize("x,y", [(1, 2), (2, 3), (3, 4)])
 | |
|             def test_two_things_default_names(self, device, x, y):
 | |
|                 pass
 | |
| 
 | |
| 
 | |
|         locals_dict = dict(locals())
 | |
|         instantiate_device_type_tests(TestParametrized, locals_dict, only_for=device)
 | |
| 
 | |
|         device_cls = locals_dict[f'TestParametrized{device.upper()}']
 | |
|         expected_test_names = [name.format(device_cls.__name__, device) for name in (
 | |
|             '{}.test_default_names_x_0_{}',
 | |
|             '{}.test_default_names_x_1_{}',
 | |
|             '{}.test_default_names_x_2_{}',
 | |
|             '{}.test_default_names_x_3_{}',
 | |
|             '{}.test_default_names_x_4_{}',
 | |
|             '{}.test_two_things_default_names_x_1_y_2_{}',
 | |
|             '{}.test_two_things_default_names_x_2_y_3_{}',
 | |
|             '{}.test_two_things_default_names_x_3_y_4_{}')
 | |
|         ]
 | |
|         test_names = _get_test_names_for_test_class(device_cls)
 | |
|         self.assertEqual(expected_test_names, test_names)
 | |
| 
 | |
|     def test_default_name_non_primitive(self, device):
 | |
|         device = self.device_type
 | |
| 
 | |
|         class TestParametrized(TestCase):
 | |
|             @parametrize("x", [1, .5, "foo", object()])
 | |
|             def test_default_names(self, device, x):
 | |
|                 pass
 | |
| 
 | |
|             @parametrize("x,y", [(1, object()), (object(), .5), (object(), object())])
 | |
|             def test_two_things_default_names(self, device, x, y):
 | |
|                 pass
 | |
| 
 | |
|         locals_dict = dict(locals())
 | |
|         instantiate_device_type_tests(TestParametrized, locals_dict, only_for=device)
 | |
| 
 | |
|         device_cls = locals_dict[f'TestParametrized{device.upper()}']
 | |
|         expected_test_names = sorted(name.format(device_cls.__name__, device) for name in (
 | |
|             '{}.test_default_names_x_1_{}',
 | |
|             '{}.test_default_names_x_0_5_{}',
 | |
|             '{}.test_default_names_x_foo_{}',
 | |
|             '{}.test_default_names_x3_{}',
 | |
|             '{}.test_two_things_default_names_x_1_y0_{}',
 | |
|             '{}.test_two_things_default_names_x1_y_0_5_{}',
 | |
|             '{}.test_two_things_default_names_x2_y2_{}')
 | |
|         )
 | |
|         test_names = _get_test_names_for_test_class(device_cls)
 | |
|         self.assertEqual(expected_test_names, test_names)
 | |
| 
 | |
|     def test_name_fn(self, device):
 | |
|         device = self.device_type
 | |
| 
 | |
|         class TestParametrized(TestCase):
 | |
|             @parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias')
 | |
|             def test_custom_names(self, device, bias):
 | |
|                 pass
 | |
| 
 | |
|             @parametrize("x", [1, 2], name_fn=str)
 | |
|             @parametrize("y", [3, 4], name_fn=str)
 | |
|             @parametrize("z", [5, 6], name_fn=str)
 | |
|             def test_three_things_composition_custom_names(self, device, x, y, z):
 | |
|                 pass
 | |
| 
 | |
|             @parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: f'{x}__{y}')
 | |
|             def test_two_things_custom_names_alternate(self, device, x, y):
 | |
|                 pass
 | |
| 
 | |
|         locals_dict = dict(locals())
 | |
|         instantiate_device_type_tests(TestParametrized, locals_dict, only_for=device)
 | |
| 
 | |
|         device_cls = locals_dict[f'TestParametrized{device.upper()}']
 | |
|         expected_test_names = [name.format(device_cls.__name__, device) for name in (
 | |
|             '{}.test_custom_names_bias_{}',
 | |
|             '{}.test_custom_names_no_bias_{}',
 | |
|             '{}.test_three_things_composition_custom_names_1_3_5_{}',
 | |
|             '{}.test_three_things_composition_custom_names_1_3_6_{}',
 | |
|             '{}.test_three_things_composition_custom_names_1_4_5_{}',
 | |
|             '{}.test_three_things_composition_custom_names_1_4_6_{}',
 | |
|             '{}.test_three_things_composition_custom_names_2_3_5_{}',
 | |
|             '{}.test_three_things_composition_custom_names_2_3_6_{}',
 | |
|             '{}.test_three_things_composition_custom_names_2_4_5_{}',
 | |
|             '{}.test_three_things_composition_custom_names_2_4_6_{}',
 | |
|             '{}.test_two_things_custom_names_alternate_1__2_{}',
 | |
|             '{}.test_two_things_custom_names_alternate_1__3_{}',
 | |
|             '{}.test_two_things_custom_names_alternate_1__4_{}')
 | |
|         ]
 | |
|         test_names = _get_test_names_for_test_class(device_cls)
 | |
|         self.assertEqual(expected_test_names, test_names)
 | |
| 
 | |
|     def test_subtest_names(self, device):
 | |
|         device = self.device_type
 | |
| 
 | |
|         class TestParametrized(TestCase):
 | |
|             @parametrize("bias", [subtest(True, name='bias'),
 | |
|                                   subtest(False, name='no_bias')])
 | |
|             def test_custom_names(self, device, bias):
 | |
|                 pass
 | |
| 
 | |
|             @parametrize("x,y", [subtest((1, 2), name='double'),
 | |
|                                  subtest((1, 3), name='triple'),
 | |
|                                  subtest((1, 4), name='quadruple')])
 | |
|             def test_two_things_custom_names(self, device, x, y):
 | |
|                 pass
 | |
| 
 | |
|         locals_dict = dict(locals())
 | |
|         instantiate_device_type_tests(TestParametrized, locals_dict, only_for=device)
 | |
| 
 | |
|         device_cls = locals_dict[f'TestParametrized{device.upper()}']
 | |
|         expected_test_names = [name.format(device_cls.__name__, device) for name in (
 | |
|             '{}.test_custom_names_bias_{}',
 | |
|             '{}.test_custom_names_no_bias_{}',
 | |
|             '{}.test_two_things_custom_names_double_{}',
 | |
|             '{}.test_two_things_custom_names_quadruple_{}',
 | |
|             '{}.test_two_things_custom_names_triple_{}')
 | |
|         ]
 | |
|         test_names = _get_test_names_for_test_class(device_cls)
 | |
|         self.assertEqual(expected_test_names, test_names)
 | |
| 
 | |
|     def test_ops_composition_names(self, device):
 | |
|         device = self.device_type
 | |
| 
 | |
|         class TestParametrized(TestCase):
 | |
|             @ops(op_db)
 | |
|             @parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled')
 | |
|             def test_op_parametrized(self, device, dtype, op, flag):
 | |
|                 pass
 | |
| 
 | |
|         locals_dict = dict(locals())
 | |
|         instantiate_device_type_tests(TestParametrized, locals_dict, only_for=device)
 | |
| 
 | |
|         device_cls = locals_dict[f'TestParametrized{device.upper()}']
 | |
|         expected_test_names = []
 | |
|         for op in op_db:
 | |
|             for dtype in op.supported_dtypes(torch.device(device).type):
 | |
|                 for flag_part in ('flag_disabled', 'flag_enabled'):
 | |
|                     expected_name = f'{device_cls.__name__}.test_op_parametrized_{op.formatted_name}_{flag_part}_{device}_{dtype_name(dtype)}'  # noqa: B950
 | |
|                     expected_test_names.append(expected_name)
 | |
| 
 | |
|         test_names = _get_test_names_for_test_class(device_cls)
 | |
|         self.assertEqual(sorted(expected_test_names), sorted(test_names))
 | |
| 
 | |
|     def test_modules_composition_names(self, device):
 | |
|         device = self.device_type
 | |
| 
 | |
|         class TestParametrized(TestCase):
 | |
|             @modules(module_db)
 | |
|             @parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled')
 | |
|             def test_module_parametrized(self, device, dtype, module_info, training, flag):
 | |
|                 pass
 | |
| 
 | |
|         locals_dict = dict(locals())
 | |
|         instantiate_device_type_tests(TestParametrized, locals_dict, only_for=device)
 | |
| 
 | |
|         device_cls = locals_dict[f'TestParametrized{device.upper()}']
 | |
|         expected_test_names = []
 | |
|         for module_info in module_db:
 | |
|             for dtype in module_info.dtypes:
 | |
|                 for flag_part in ('flag_disabled', 'flag_enabled'):
 | |
|                     expected_train_modes = (
 | |
|                         ['train_mode', 'eval_mode'] if module_info.train_and_eval_differ else [''])
 | |
|                     for training_part in expected_train_modes:
 | |
|                         expected_name = '{}.test_module_parametrized_{}{}_{}_{}_{}'.format(
 | |
|                             device_cls.__name__, module_info.formatted_name,
 | |
|                             '_' + training_part if len(training_part) > 0 else '',
 | |
|                             flag_part, device, dtype_name(dtype))
 | |
|                         expected_test_names.append(expected_name)
 | |
| 
 | |
|         test_names = _get_test_names_for_test_class(device_cls)
 | |
|         self.assertEqual(sorted(expected_test_names), sorted(test_names))
 | |
| 
 | |
|     def test_ops_decorator_applies_op_and_param_specific_decorators(self, device):
 | |
|         # Test that decorators can be applied on a per-op / per-param basis.
 | |
| 
 | |
|         # Create a test op, OpInfo entry, and decorator to apply.
 | |
|         def test_op(x):
 | |
|             return -x
 | |
| 
 | |
|         def test_dec(func):
 | |
|             func._decorator_applied = True
 | |
|             return func
 | |
| 
 | |
|         test_op_info = OpInfo(
 | |
|             'test_op',
 | |
|             op=test_op,
 | |
|             dtypes=floating_types(),
 | |
|             sample_inputs_func=lambda _: [],
 | |
|             decorators=[
 | |
|                 DecorateInfo(test_dec, 'TestParametrized', 'test_op_param',
 | |
|                              device_type='cpu', dtypes=[torch.float64],
 | |
|                              active_if=lambda p: p['x'] == 2)
 | |
|             ])
 | |
| 
 | |
|         class TestParametrized(TestCase):
 | |
|             @ops(op_db + [test_op_info])
 | |
|             @parametrize("x", [2, 3])
 | |
|             def test_op_param(self, device, dtype, op, x):
 | |
|                 pass
 | |
| 
 | |
|             @ops(op_db + [test_op_info])
 | |
|             @parametrize("y", [
 | |
|                 subtest(4),
 | |
|                 subtest(5, decorators=[test_dec])])
 | |
|             def test_other(self, device, dtype, op, y):
 | |
|                 pass
 | |
| 
 | |
|             @decorateIf(test_dec, lambda p: p['dtype'] == torch.int16)
 | |
|             @ops(op_db)
 | |
|             def test_three(self, device, dtype, op):
 | |
|                 pass
 | |
| 
 | |
|         device = self.device_type
 | |
|         locals_dict = dict(locals())
 | |
|         instantiate_device_type_tests(TestParametrized, locals_dict, only_for=device)
 | |
|         device_cls = locals_dict[f'TestParametrized{device.upper()}']
 | |
| 
 | |
|         for test_func, name in _get_test_funcs_for_test_class(device_cls):
 | |
|             should_apply = (name == 'test_op_param_test_op_x_2_cpu_float64' or
 | |
|                             ('test_other' in name and 'y_5' in name) or
 | |
|                             ('test_three' in name and name.endswith('_int16')))
 | |
|             self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply)
 | |
| 
 | |
|     def test_modules_decorator_applies_module_and_param_specific_decorators(self, device):
 | |
|         # Test that decorators can be applied on a per-module / per-param basis.
 | |
| 
 | |
|         # Create a test module, ModuleInfo entry, and decorator to apply.
 | |
|         class TestModule(torch.nn.Module):
 | |
|             def __init__(self) -> None:
 | |
|                 super().__init__()
 | |
|                 self.x = torch.nn.Parameter(torch.randn(3))
 | |
| 
 | |
|             def forward(self, y):
 | |
|                 return self.x + y
 | |
| 
 | |
|         def test_dec(func):
 | |
|             func._decorator_applied = True
 | |
|             return func
 | |
| 
 | |
|         test_module_info = ModuleInfo(
 | |
|             TestModule,
 | |
|             module_inputs_func=lambda _: [],
 | |
|             decorators=[
 | |
|                 DecorateInfo(test_dec, 'TestParametrized', 'test_module_param',
 | |
|                              device_type='cpu', dtypes=[torch.float64],
 | |
|                              active_if=lambda p: p['x'] == 2)
 | |
|             ])
 | |
| 
 | |
|         class TestParametrized(TestCase):
 | |
|             @modules(module_db + [test_module_info])
 | |
|             @parametrize("x", [2, 3])
 | |
|             def test_module_param(self, device, dtype, module_info, training, x):
 | |
|                 pass
 | |
| 
 | |
|             @modules(module_db + [test_module_info])
 | |
|             @parametrize("y", [
 | |
|                 subtest(4),
 | |
|                 subtest(5, decorators=[test_dec])])
 | |
|             def test_other(self, device, dtype, module_info, training, y):
 | |
|                 pass
 | |
| 
 | |
|             @decorateIf(test_dec, lambda p: p['dtype'] == torch.float64)
 | |
|             @modules(module_db)
 | |
|             def test_three(self, device, dtype, module_info):
 | |
|                 pass
 | |
| 
 | |
|         device = self.device_type
 | |
|         locals_dict = dict(locals())
 | |
|         instantiate_device_type_tests(TestParametrized, locals_dict, only_for=device)
 | |
|         device_cls = locals_dict[f'TestParametrized{device.upper()}']
 | |
| 
 | |
|         for test_func, name in _get_test_funcs_for_test_class(device_cls):
 | |
|             should_apply = (name == 'test_module_param_TestModule_x_2_cpu_float64' or
 | |
|                             ('test_other' in name and 'y_5' in name) or
 | |
|                             ('test_three' in name and name.endswith('float64')))
 | |
|             self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply)
 | |
| 
 | |
|     def test_param_specific_decoration(self, device):
 | |
| 
 | |
|         def test_dec(func):
 | |
|             func._decorator_applied = True
 | |
|             return func
 | |
| 
 | |
|         class TestParametrized(TestCase):
 | |
|             @decorateIf(test_dec, lambda params: params["x"] == 1 and params["y"])
 | |
|             @parametrize("x", range(5))
 | |
|             @parametrize("y", [False, True])
 | |
|             def test_param(self, x, y):
 | |
|                 pass
 | |
| 
 | |
|         device = self.device_type
 | |
|         locals_dict = dict(locals())
 | |
|         instantiate_device_type_tests(TestParametrized, locals_dict, only_for=device)
 | |
|         device_cls = locals_dict[f'TestParametrized{device.upper()}']
 | |
| 
 | |
|         for test_func, name in _get_test_funcs_for_test_class(device_cls):
 | |
|             should_apply = ('test_param_x_1_y_True' in name)
 | |
|             self.assertEqual(hasattr(test_func, '_decorator_applied'), should_apply)
 | |
| 
 | |
|     def test_dtypes_composition_valid(self, device):
 | |
|         # Test checks that @parametrize and @dtypes compose as expected when @parametrize
 | |
|         # doesn't set dtype.
 | |
| 
 | |
|         device = self.device_type
 | |
| 
 | |
|         class TestParametrized(TestCase):
 | |
|             @dtypes(torch.float32, torch.float64)
 | |
|             @parametrize("x", range(3))
 | |
|             def test_parametrized(self, x, dtype):
 | |
|                 pass
 | |
| 
 | |
|         locals_dict = dict(locals())
 | |
|         instantiate_device_type_tests(TestParametrized, locals_dict, only_for=device)
 | |
| 
 | |
|         device_cls = locals_dict[f'TestParametrized{device.upper()}']
 | |
|         expected_test_names = [name.format(device_cls.__name__, device) for name in (
 | |
|             '{}.test_parametrized_x_0_{}_float32',
 | |
|             '{}.test_parametrized_x_0_{}_float64',
 | |
|             '{}.test_parametrized_x_1_{}_float32',
 | |
|             '{}.test_parametrized_x_1_{}_float64',
 | |
|             '{}.test_parametrized_x_2_{}_float32',
 | |
|             '{}.test_parametrized_x_2_{}_float64')
 | |
|         ]
 | |
|         test_names = _get_test_names_for_test_class(device_cls)
 | |
|         self.assertEqual(sorted(expected_test_names), sorted(test_names))
 | |
| 
 | |
|     def test_dtypes_composition_invalid(self, device):
 | |
|         # Test checks that @dtypes cannot be composed with parametrization decorators when they
 | |
|         # also try to set dtype.
 | |
| 
 | |
|         device = self.device_type
 | |
| 
 | |
|         class TestParametrized(TestCase):
 | |
|             @dtypes(torch.float32, torch.float64)
 | |
|             @parametrize("dtype", [torch.int32, torch.int64])
 | |
|             def test_parametrized(self, dtype):
 | |
|                 pass
 | |
| 
 | |
|         with self.assertRaisesRegex(RuntimeError, "handled multiple times"):
 | |
|             locals_dict = dict(locals())
 | |
|             instantiate_device_type_tests(TestParametrized, locals_dict, only_for=device)
 | |
| 
 | |
|         # Verify proper error behavior with @ops + @dtypes, as both try to set dtype.
 | |
| 
 | |
|         class TestParametrized(TestCase):
 | |
|             @dtypes(torch.float32, torch.float64)
 | |
|             @ops(op_db)
 | |
|             def test_parametrized(self, op, dtype):
 | |
|                 pass
 | |
| 
 | |
|         with self.assertRaisesRegex(RuntimeError, "handled multiple times"):
 | |
|             locals_dict = dict(locals())
 | |
|             instantiate_device_type_tests(TestParametrized, locals_dict, only_for=device)
 | |
| 
 | |
|     def test_multiple_handling_of_same_param_error(self, device):
 | |
|         # Test that multiple decorators handling the same param errors out.
 | |
|         # Both @modules and @ops handle the dtype param.
 | |
| 
 | |
|         class TestParametrized(TestCase):
 | |
|             @ops(op_db)
 | |
|             @modules(module_db)
 | |
|             def test_param(self, device, dtype, op, module_info, training):
 | |
|                 pass
 | |
| 
 | |
|         with self.assertRaisesRegex(RuntimeError, "handled multiple times"):
 | |
|             locals_dict = dict(locals())
 | |
|             instantiate_device_type_tests(TestParametrized, locals_dict, only_for=device)
 | |
| 
 | |
|     @parametrize("x", [1, subtest(2, decorators=[unittest.expectedFailure]), 3])
 | |
|     def test_subtest_expected_failure(self, device, x):
 | |
|         if x == 2:
 | |
|             raise RuntimeError('Boom')
 | |
| 
 | |
|     @parametrize("x", [subtest(1, decorators=[unittest.expectedFailure]), 2, 3])
 | |
|     @parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])])
 | |
|     def test_two_things_subtest_expected_failure(self, device, x, y):
 | |
|         if x == 1 or y == 6:
 | |
|             raise RuntimeError('Boom')
 | |
| 
 | |
| 
 | |
# Materialize the @parametrize-decorated tests of TestTestParametrization.
instantiate_parametrized_tests(TestTestParametrization)
# Generate per-device-type variants of the device-type parametrization tests into
# this module's globals (the same mechanism the tests above exercise with a local
# dict, from which they read back the 'TestParametrized<DEVICE>' classes).
instantiate_device_type_tests(TestTestParametrizationDeviceType, globals())
 | |
| 
 | |
| 
 | |
class TestImports(TestCase):
    """Checks on what happens when `torch` (or one of its submodules) is imported."""

    @classmethod
    def _check_python_output(cls, program) -> str:
        """Run ``program`` via ``python -c`` in a fresh interpreter and return its output.

        ``-W always`` forces every warning to be printed, and stderr is merged into
        stdout, so the returned text is everything the child process wrote.

        Raises:
            subprocess.CalledProcessError: if the child exits with a non-zero status.
        """
        return subprocess.check_output(
            [sys.executable, "-W", "always", "-c", program],
            stderr=subprocess.STDOUT,
            # On Windows, opening the subprocess with the default CWD makes `import torch`
            # fail, so just set CWD to this script's directory
            cwd=os.path.dirname(os.path.realpath(__file__)),
        ).decode("utf-8")

    # The test is flaky on ROCm/XPU and has been opened and closed multiple times
    # https://github.com/pytorch/pytorch/issues/110040
    @skipIfRocm
    def test_circular_dependencies(self) -> None:
        """ Checks that all modules inside torch can be imported
        Prevents regression reported in https://github.com/pytorch/pytorch/issues/77441 """
        ignored_modules = ["torch.utils.tensorboard",  # depends on tensorboard
                           "torch.distributed.elastic.rendezvous",  # depends on etcd
                           "torch.backends._coreml",  # depends on pycoreml
                           "torch.contrib.",  # something weird
                           "torch.testing._internal.distributed.",  # just fails
                           "torch.ao.pruning._experimental.",  # depends on pytorch_lightning, not user-facing
                           "torch.onnx._internal",  # depends on onnx-script
                           "torch._inductor.runtime.triton_helpers",  # depends on triton
                           "torch._inductor.codegen.cuda",  # depends on cutlass
                           "torch.distributed.benchmarks",  # depends on RPC and DDP Optim
                           "torch.distributed.examples",  # requires CUDA and torchvision
                           "torch.distributed.tensor.examples",  # example scripts
                           "torch.distributed._tools.sac_ilp",  # depends on pulp
                           "torch.csrc",  # files here are devtools, not part of torch
                           "torch.include",  # torch include files after install
                           ]
        if IS_WINDOWS or IS_MACOS or IS_JETSON:
            # Distributed should be importable on Windows (except nn.api., optim. and
            # rpc.), but not on Mac or Jetson.
            if IS_MACOS or IS_JETSON:
                ignored_modules.append("torch.distributed.")
            else:
                ignored_modules.append("torch.distributed.nn.api.")
                ignored_modules.append("torch.distributed.optim.")
                ignored_modules.append("torch.distributed.rpc.")
            ignored_modules.append("torch.testing._internal.dist_utils")
            # And these both end up with transitive dependencies on distributed
            ignored_modules.append("torch.nn.parallel._replicated_tensor_ddp_interop")
            ignored_modules.append("torch.testing._internal.common_fsdp")
            ignored_modules.append("torch.testing._internal.common_distributed")

        # str.startswith accepts a tuple, so every ignored prefix is checked in one
        # call; build the tuple once instead of re-scanning the list per file.
        ignored_prefixes = tuple(ignored_modules)
        torch_dir = os.path.dirname(torch.__file__)
        for base, _, files in os.walk(torch_dir):
            # Dotted package name of `base`, e.g. "torch.nn" for <parent>/torch/nn.
            prefix = os.path.relpath(base, os.path.dirname(torch_dir)).replace(os.path.sep, ".")
            for f in files:
                # Only import Python sources, and never executable `__main__` modules.
                if not f.endswith(".py") or f == "__main__.py":
                    continue
                mod_name = f"{prefix}.{f[:-3]}" if f != "__init__.py" else prefix
                if mod_name.startswith(ignored_prefixes):
                    continue
                try:
                    mod = importlib.import_module(mod_name)
                except Exception as e:
                    raise RuntimeError(f"Failed to import {mod_name}: {e}") from e
                self.assertTrue(inspect.ismodule(mod))

    def test_lazy_imports_are_lazy(self) -> None:
        out = self._check_python_output("import sys;import torch;print(all(x not in sys.modules for x in torch._lazy_modules))")
        self.assertEqual(out.strip(), "True")

    def test_no_warning_on_import(self) -> None:
        out = self._check_python_output("import torch")
        self.assertEqual(out, "")

    def test_not_import_sympy(self) -> None:
        out = self._check_python_output("import torch;import sys;print('sympy' not in sys.modules)")
        self.assertEqual(out.strip(), "True",
                         "PyTorch should not depend on SymPy at import time as importing SymPy is *very* slow.\n"
                         "See the beginning of the following blog post for how to profile and find which file is importing sympy:\n"
                         "https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589\n\n"
                         "If you hit this error, you may want to:\n"
                         "  - Refactor your code to avoid depending on sympy files you may not need to depend\n"
                         "  - Use TYPE_CHECKING if you are using sympy + strings if you are using sympy on type annotations\n"
                         "  - Import things that depend on SymPy locally")

    @parametrize('path', ['torch', 'functorch'])
    def test_no_mutate_global_logging_on_import(self, path) -> None:
        # Calling logging.basicConfig, among other things, modifies the global
        # logging state. It is not OK to modify the global logging state on
        # `import torch` (or other submodules we own) because users do not expect it.
        # The child installs its own root handler and logs `expected` once, so any
        # extra root handler left behind by the import shows up as extra output.
        # Spelled out literally because this file does not import the `string`
        # module (`string.ascii_lowercase` raised NameError here).
        expected = "abcdefghijklmnopqrstuvwxyz"
        commands = [
            'import logging',
            f'import {path}',
            '_logger = logging.getLogger("torch_test_testing")',
            'logging.root.addHandler(logging.StreamHandler())',
            'logging.root.setLevel(logging.INFO)',
            f'_logger.info("{expected}")'
        ]
        out = self._check_python_output("; ".join(commands))
        self.assertEqual(out.strip(), expected)
 | |
| 
 | |
| class TestOpInfos(TestCase):
 | |
|     def test_sample_input(self) -> None:
 | |
|         a, b, c, d, e = (object() for _ in range(5))
 | |
| 
 | |
|         # Construction with natural syntax
 | |
|         s = SampleInput(a, b, c, d=d, e=e)
 | |
|         assert s.input is a
 | |
|         assert s.args == (b, c)
 | |
|         assert s.kwargs == dict(d=d, e=e)
 | |
| 
 | |
|         # Construction with explicit args and kwargs
 | |
|         s = SampleInput(a, args=(b,), kwargs=dict(c=c, d=d, e=e))
 | |
|         assert s.input is a
 | |
|         assert s.args == (b,)
 | |
|         assert s.kwargs == dict(c=c, d=d, e=e)
 | |
| 
 | |
|         # Construction with a mixed form will error
 | |
|         with self.assertRaises(AssertionError):
 | |
|             s = SampleInput(a, b, c, args=(d, e))
 | |
| 
 | |
|         with self.assertRaises(AssertionError):
 | |
|             s = SampleInput(a, b, c, kwargs=dict(d=d, e=e))
 | |
| 
 | |
|         with self.assertRaises(AssertionError):
 | |
|             s = SampleInput(a, args=(b, c), d=d, e=e)
 | |
| 
 | |
|         with self.assertRaises(AssertionError):
 | |
|             s = SampleInput(a, b, c=c, kwargs=dict(d=d, e=e))
 | |
| 
 | |
|         # Mixing metadata into "natural" construction will error
 | |
|         with self.assertRaises(AssertionError):
 | |
|             s = SampleInput(a, b, name="foo")
 | |
| 
 | |
|         with self.assertRaises(AssertionError):
 | |
|             s = SampleInput(a, b, output_process_fn_grad=lambda x: x)
 | |
| 
 | |
|         with self.assertRaises(AssertionError):
 | |
|             s = SampleInput(a, b, broadcasts_input=True)
 | |
| 
 | |
|         # But when only input is given, metadata is allowed for backward
 | |
|         # compatibility
 | |
|         s = SampleInput(a, broadcasts_input=True)
 | |
|         assert s.input is a
 | |
|         assert s.broadcasts_input
 | |
| 
 | |
|     def test_sample_input_metadata(self) -> None:
 | |
|         a, b = (object() for _ in range(2))
 | |
|         s1 = SampleInput(a, b=b)
 | |
|         self.assertIs(s1.output_process_fn_grad(None), None)
 | |
|         self.assertFalse(s1.broadcasts_input)
 | |
|         self.assertEqual(s1.name, "")
 | |
| 
 | |
|         s2 = s1.with_metadata(
 | |
|             output_process_fn_grad=lambda x: a,
 | |
|             broadcasts_input=True,
 | |
|             name="foo",
 | |
|         )
 | |
|         self.assertIs(s1, s2)
 | |
|         self.assertIs(s2.output_process_fn_grad(None), a)
 | |
|         self.assertTrue(s2.broadcasts_input)
 | |
|         self.assertEqual(s2.name, "foo")
 | |
| 
 | |
| 
 | |
| # Tests that validate the various sample generating functions on each OpInfo.
 | |
class TestOpInfoSampleFunctions(TestCase):

    @ops(op_db, dtypes=OpDTypes.any_one)
    def test_opinfo_sample_generators(self, device, dtype, op):
        # op.sample_inputs must hand back an Iterator; a pre-built list would fail.
        self.assertIsInstance(op.sample_inputs(device, dtype), Iterator)

    @ops(list(filter(lambda op: op.reference_inputs_func is not None, op_db)),
         dtypes=OpDTypes.any_one)
    def test_opinfo_reference_generators(self, device, dtype, op):
        # Same contract for op.reference_inputs, on ops that define reference inputs.
        self.assertIsInstance(op.reference_inputs(device, dtype), Iterator)

    @ops(list(filter(lambda op: op.error_inputs_func is not None, op_db)),
         dtypes=OpDTypes.none)
    def test_opinfo_error_generators(self, device, op):
        # Same contract for op.error_inputs, on ops that define error inputs.
        self.assertIsInstance(op.error_inputs(device), Iterator)
 | |
| 
 | |
| 
 | |
# Generate per-device-type variants of the OpInfo sample-function tests into this
# module's globals.
instantiate_device_type_tests(TestOpInfoSampleFunctions, globals())
# Materialize TestImports' @parametrize-decorated tests (the 'path' variants).
instantiate_parametrized_tests(TestImports)


if __name__ == '__main__':
    run_tests()
 |