diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py
index 0bd2a9e4d527..b15fb21944f0 100644
--- a/test/test_binary_ufuncs.py
+++ b/test/test_binary_ufuncs.py
@@ -20,7 +20,7 @@ from torch.testing._internal.common_utils import (
 from torch.testing._internal.common_device_type import (
     expectedFailureMeta, instantiate_device_type_tests, onlyCUDA, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU,
     deviceCountAtLeast, precisionOverride, onlyNativeDeviceTypes,
-    skipCUDAIfRocm, skipIf, ops, OpDTypes)
+    skipCUDAIfRocm, skipIf, ops, OpDTypes, skipMeta)
 from torch.testing import make_tensor
 from torch.testing._internal.common_dtype import (
     all_types_and_complex_and, integral_types_and, get_all_dtypes, get_all_int_dtypes, get_all_math_dtypes,
@@ -1497,6 +1497,7 @@ class TestBinaryUfuncs(TestCase):
         self._test_pow(base, second_exp)
 
     @onlyNativeDeviceTypes
+    @skipMeta
     def test_pow_scalar_type_promotion(self, device):
         # Test against a scalar and non-scalar input
         inputs = [17, [17]]
@@ -3393,6 +3394,7 @@ class TestBinaryUfuncs(TestCase):
                 TypeError, 'received an invalid combination of arguments'):
             actual = torch.cumulative_trapezoid(torch.randn((3, 3)), x=torch.randn((3, 3)), dx=3)
 
+    @skipMeta
     @dtypes(torch.double)
     def test_pow_scalar_overloads_mem_overlap(self, device, dtype):
         sz = 3
diff --git a/test/test_modules.py b/test/test_modules.py
index 448f8f5fa751..b3d658a5bc5d 100644
--- a/test/test_modules.py
+++ b/test/test_modules.py
@@ -8,7 +8,7 @@ from operator import methodcaller
 
 import torch
 from torch.testing._internal.common_device_type import (
-    instantiate_device_type_tests, onlyCUDA, toleranceOverride, tol)
+    instantiate_device_type_tests, onlyCUDA, toleranceOverride, tol, skipMeta)
 from torch.testing._internal.common_modules import module_db, modules
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, gradgradcheck)
@@ -233,6 +233,7 @@ class TestModule(TestCase):
 
     @modules([module_info for module_info in module_db
               if 'inplace' in signature(module_info.module_cls).parameters])
+    @skipMeta
     def test_check_inplace(self, device, dtype, module_info):
         # Check if the inplace variant of the module gives the same result as the out of place
         # variant.
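
Note (not part of the patch above): a minimal sketch of why value-dependent tests gain `@skipMeta`, assuming a PyTorch build where the "meta" device is available. Meta tensors carry only shape/dtype/device metadata and no storage, so any assertion that must read values cannot run on them; the exact exception type and message vary between releases, which is why the sketch catches both.

    import torch

    t = torch.ones(2, 2, device="meta")
    print(t.shape, t.dtype, t.device)  # metadata is available without any real storage

    try:
        t.item()  # reading a value forces a data access, which meta tensors cannot perform
    except (NotImplementedError, RuntimeError) as error:
        print(f"value access failed as expected: {error}")
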
diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py
index 62d595373b3a..68ddec147118 100644
--- a/test/test_tensor_creation_ops.py
+++ b/test/test_tensor_creation_ops.py
@@ -14,7 +14,7 @@ from torch.testing import make_tensor
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
     torch_to_numpy_dtype_dict, slowTest,
-    TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS)
+    TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS, parametrize)
 from torch.testing._internal.common_device_type import (
     expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes,
     onlyCPU, largeTensorTest, precisionOverride, dtypes,
@@ -2786,36 +2786,44 @@ class TestTensorCreation(TestCase):
                                                  sparse_size, dtype=torch.float64)
             self.assertEqual(sparse_with_dtype.device, torch.device('cpu'))
 
+    def _test_signal_window_functions(self, name, dtype, device, **kwargs):
+        import scipy.signal as signal
+
+        torch_method = getattr(torch, name + '_window')
+        if not dtype.is_floating_point:
+            with self.assertRaisesRegex(RuntimeError, r'floating point'):
+                torch_method(3, dtype=dtype)
+            return
+        for size in [0, 1, 2, 5, 10, 50, 100, 1024, 2048]:
+            for periodic in [True, False]:
+                res = torch_method(size, periodic=periodic, **kwargs, device=device, dtype=dtype)
+                # NB: scipy always returns a float64 result
+                ref = torch.from_numpy(signal.get_window((name, *(kwargs.values())), size, fftbins=periodic))
+                self.assertEqual(res, ref, exact_dtype=False)
+        with self.assertRaisesRegex(RuntimeError, r'not implemented for sparse types'):
+            torch_method(3, layout=torch.sparse_coo)
+        self.assertTrue(torch_method(3, requires_grad=True).requires_grad)
+        self.assertFalse(torch_method(3).requires_grad)
+
     @onlyNativeDeviceTypes
     @precisionOverride({torch.bfloat16: 5e-2, torch.half: 1e-3})
     @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
     @dtypesIfCUDA(torch.float, torch.double, torch.bfloat16, torch.half, torch.long)
     @dtypes(torch.float, torch.double, torch.long)
-    def test_signal_window_functions(self, device, dtype):
-        import scipy.signal as signal
-
-        def test(name, kwargs):
-            torch_method = getattr(torch, name + '_window')
-            if not dtype.is_floating_point:
-                with self.assertRaisesRegex(RuntimeError, r'floating point'):
-                    torch_method(3, dtype=dtype)
-                return
-            for size in [0, 1, 2, 5, 10, 50, 100, 1024, 2048]:
-                for periodic in [True, False]:
-                    res = torch_method(size, periodic=periodic, **kwargs, device=device, dtype=dtype)
-                    # NB: scipy always returns a float64 result
-                    ref = torch.from_numpy(signal.get_window((name, *(kwargs.values())), size, fftbins=periodic))
-                    self.assertEqual(res, ref, exact_dtype=False)
-            with self.assertRaisesRegex(RuntimeError, r'not implemented for sparse types'):
-                torch_method(3, layout=torch.sparse_coo)
-            self.assertTrue(torch_method(3, requires_grad=True).requires_grad)
-            self.assertFalse(torch_method(3).requires_grad)
-
-        for window in ['hann', 'hamming', 'bartlett', 'blackman']:
-            test(window, kwargs={})
+    @parametrize("window", ['hann', 'hamming', 'bartlett', 'blackman'])
+    def test_signal_window_functions(self, device, dtype, window):
+        self._test_signal_window_functions(window, dtype, device)
 
+    @onlyNativeDeviceTypes
+    # See https://github.com/pytorch/pytorch/issues/72630
+    @skipMeta
+    @precisionOverride({torch.bfloat16: 5e-2, torch.half: 1e-3})
+    @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
+    @dtypesIfCUDA(torch.float, torch.double, torch.bfloat16, torch.half, torch.long)
+    @dtypes(torch.float, torch.double, torch.long)
+    def test_kaiser_window(self, device, dtype):
         for num_test in range(50):
-            test('kaiser', kwargs={'beta': random.random() * 30})
+            self._test_signal_window_functions('kaiser', dtype, device, beta=random.random() * 30)
 
     def test_tensor_factories_empty(self, device):
         # ensure we can create empty tensors from each factory function
diff --git a/test/test_testing.py b/test/test_testing.py
index 3cfef8cee395..1fe06a229340 100644
--- a/test/test_testing.py
+++ b/test/test_testing.py
@@ -574,11 +574,10 @@ class TestAssertClose(TestCase):
 
     def test_meta(self):
         actual = torch.empty((2, 2), device="meta")
-        expected = actual.clone()
+        expected = torch.empty((2, 2), device="meta")
 
         for fn in assert_close_with_inputs(actual, expected):
-            with self.assertRaisesRegex(NotImplementedError, "meta"):
-                fn()
+            fn()
 
     def test_mismatching_layout(self):
         strided = torch.empty((2, 2))
diff --git a/test/test_torch.py b/test/test_torch.py
index e2422d1477d6..16cf9e2e61f9 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -37,7 +37,7 @@ from torch.testing._internal.common_utils import (
     skipCUDAMemoryLeakCheckIf, BytesIOContext, noarchTest,
     skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
     wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard,
-    skipIfNotRegistered, bytes_to_scalar)
+    skipIfNotRegistered, bytes_to_scalar, parametrize)
 from multiprocessing.reduction import ForkingPickler
 from torch.testing._internal.common_device_type import (
     expectedFailureMeta,
@@ -793,158 +793,158 @@ class TestTorchDeviceType(TestCase):
         self.assertFalse(t1.is_set_to(t2))
         self.assertFalse(t2.is_set_to(t1))
 
-    def test_broadcast(self, device):
-
-        # all functions
-        fns = {
-            "dist", "atan2", "pow", "lerp", "add",
-            "sub", "mul", "div", "fmod", "remainder",
-            "eq", "ge", "gt", "le", "lt", "max", "min", "ne",
-            "addcdiv", "addcmul", "masked_scatter", "masked_select", "masked_fill",
-            "map", "map2", "copy"
-        }
-
+    # See https://github.com/pytorch/pytorch/issues/72650
+    @skipMeta
+    @parametrize(
+        "fn",
+        [
+            "dist", "atan2", "pow", "lerp", "add", "sub", "mul", "div", "fmod", "remainder", "eq", "ge", "gt", "le",
+            "lt", "max", "min", "ne", "addcdiv", "addcmul", "masked_scatter", "masked_select", "masked_fill", "map",
+            "map2", "copy",
+        ],
+    )
+    def test_broadcast(self, fn, device):
         # functions with three tensor arguments
         fns_3_args = {"map2"}
         fns_value_kwarg = {"addcdiv", "addcmul"}
 
-        for fn in fns:
-            (dims_small, dims_large, dims_full) = self._select_broadcastable_dims()
-            full1d = torch.randn(*dims_full, device=device).flatten().float()
-            small = torch.randn(*dims_small, device=device).float()
-            large = torch.randn(*dims_large, device=device).float()
-            small_expanded = small.expand(*dims_full)
-            large_expanded = large.expand(*dims_full)
-            small2 = None
-            small2_expanded = None
-            if fn in fns_3_args or fn in fns_value_kwarg:
-                # create another smaller tensor
-                (dims_small2, _, _) = self._select_broadcastable_dims(dims_full)
-                small2 = torch.randn(*dims_small2, device=device).float()
-                small2_expanded = small2.expand(*dims_full)
+        (dims_small, dims_large, dims_full) = self._select_broadcastable_dims()
+        full1d = torch.randn(*dims_full, device=device).flatten().float()
+        small = torch.randn(*dims_small, device=device).float()
+        large = torch.randn(*dims_large, device=device).float()
+        small_expanded = small.expand(*dims_full)
+        large_expanded = large.expand(*dims_full)
+        small2 = None
+        small2_expanded = None
+        if fn in fns_3_args or fn in fns_value_kwarg:
+            # create another smaller tensor
+            (dims_small2, _, _) = self._select_broadcastable_dims(dims_full)
+            small2 = torch.randn(*dims_small2, device=device).float()
+            small2_expanded = small2.expand(*dims_full)
 
-            if small.is_cuda and fn in ['map', 'map2']:
-                # map and map2 are not implementd on CUDA tensors
-                continue
+        if small.is_cuda and fn in ['map', 'map2']:
+            # map and map2 are not implementd on CUDA tensors
+            return
 
-            if hasattr(large_expanded, fn):
-                # run through tensor versions of functions
-                # and verify fully expanded inputs give same results
-                expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}
+        if hasattr(large_expanded, fn):
+            # run through tensor versions of functions
+            # and verify fully expanded inputs give same results
+            expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}
 
-                def tensorfn(myfn, t1, t2):
-                    if fn == "lerp":
-                        return myfn(t1, 0.5)
-                    elif fn == "masked_select":
-                        return myfn(t1 < 0)
-                    elif fn == "masked_scatter":
-                        return myfn(t1 < 0.5, full1d)
-                    elif fn == "masked_fill":
-                        return myfn(t1 < 0.5, 1.0)
-                    elif fn in fns_3_args:
-                        return myfn(1, t1, t2)
-                    elif fn in fns_value_kwarg:
-                        return myfn(t1, t2, value=1)
-                    else:
-                        return myfn(t1)
-
-                # test various orders
-                for first, second, third in [(large, small, small2), (small, large, small2),
-                                             (small2, small, large), (small2, large, small)]:
-                    if first is None:
-                        break  # ignore last iter when small2 is None
-                    method_expanded = getattr(expanded[first], fn)
-                    method = getattr(first, fn)
-                    r1 = tensorfn(method_expanded, expanded[second], expanded[third])
-                    r2 = tensorfn(method, second, third)
-                    self.assertEqual(r1, r2)
-
-            # now for torch. versions of functions
-            if hasattr(torch, fn):
-                fntorch = getattr(torch, fn)
-                expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}
-
-                def torchfn(t1, t2, t3):
-                    if fn == "lerp":
-                        return fntorch(t1, t2, 0.5)
-                    elif fn == "masked_select":
-                        return fntorch(t1, t2 < 0)
-                    elif fn == "masked_scatter":
-                        return fntorch(t1, t2 < 0.5, full1d)
-                    elif fn == "masked_fill":
-                        return fntorch(t1, t2 < 0.5, 1.0)
-                    elif fn in fns_3_args:
-                        return fntorch(t1, 1.0, t2, t3)
-                    elif fn in fns_value_kwarg:
-                        return fntorch(t1, t2, t3, value=1.0)
-                    else:
-                        return fntorch(t1, t2)
-
-                # test various orders
-                for first, second, third in [(large, small, small2), (small, large, small2),
-                                             (small2, small, large), (small2, large, small)]:
-                    if first is None:
-                        break  # ignore last iter when small2 is None
-                    r1 = torchfn(expanded[first], expanded[second], expanded[third])
-                    r2 = torchfn(first, second, third)
-                    self.assertEqual(r1, r2)
-
-            # now for in place functions
-            # in-place tensor is not broadcastable; test only guaranteed
-            # to work by broadcasting other argument(s)
-            if not hasattr(large_expanded, fn + "_"):
-                continue
-
-            # need to clone largeExpanded so we can reuse, since functions are in-place
-            large_expanded_clone = large_expanded.clone()
-
-            def tensorfn_inplace(t0, t1, t2=None):
-                t0_fn = getattr(t0, fn + "_")
+            def tensorfn(myfn, t1, t2):
                 if fn == "lerp":
-                    return t0_fn(t1, 0.5)
+                    return myfn(t1, 0.5)
+                elif fn == "masked_select":
+                    return myfn(t1 < 0)
                 elif fn == "masked_scatter":
-                    return t0_fn(t1 < 0.5, full1d)
+                    return myfn(t1 < 0.5, full1d)
                 elif fn == "masked_fill":
-                    return t0_fn(t1 < 0.5, 1.0)
-                elif fn == "map":
-                    return t0_fn(t1, lambda x, y: x + y)
-                elif fn == "map2":
-                    return t0_fn(t1, t2, lambda x, y, z: x + y + z)
+                    return myfn(t1 < 0.5, 1.0)
                 elif fn in fns_3_args:
-                    return t0_fn(1.0, t1, t2)
+                    return myfn(1, t1, t2)
                 elif fn in fns_value_kwarg:
-                    return t0_fn(t1, t2, value=1.0)
+                    return myfn(t1, t2, value=1)
                 else:
-                    return t0_fn(t1)
-            # in-place pointwise operations don't actually work if the in-place
-            # tensor is 0-strided (numpy has the same issue)
-            if (0 not in large_expanded.stride() and 0 not in large_expanded_clone.stride()):
-                r1 = tensorfn_inplace(large_expanded, small_expanded, small2_expanded)
-                r2 = tensorfn_inplace(large_expanded_clone, small, small2)
+                    return myfn(t1)
+
+            # test various orders
+            for first, second, third in [(large, small, small2), (small, large, small2),
+                                         (small2, small, large), (small2, large, small)]:
+                if first is None:
+                    break  # ignore last iter when small2 is None
+                method_expanded = getattr(expanded[first], fn)
+                method = getattr(first, fn)
+                r1 = tensorfn(method_expanded, expanded[second], expanded[third])
+                r2 = tensorfn(method, second, third)
                 self.assertEqual(r1, r2)
 
-            def broadcastable(t0, t1, t2=None):
-                try:
-                    t1.expand_as(t0)
-                    if t2 is not None:
-                        t2.expand_as(t0)
-                except RuntimeError:
-                    return False
-                return True
+        # now for torch. versions of functions
+        if hasattr(torch, fn):
+            fntorch = getattr(torch, fn)
+            expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}
 
-            def _test_in_place_broadcastable(t0, t1, t2=None):
-                if not broadcastable(t0, t1, t2):
-                    same_size = t0.numel() == t1.numel() and (t0.numel() == t2.numel() if t2 is not None else True)
-                    if not same_size:
-                        self.assertRaises(RuntimeError, lambda: tensorfn_inplace(t0, t1, t2))
+            def torchfn(t1, t2, t3):
+                if fn == "lerp":
+                    return fntorch(t1, t2, 0.5)
+                elif fn == "masked_select":
+                    return fntorch(t1, t2 < 0)
+                elif fn == "masked_scatter":
+                    return fntorch(t1, t2 < 0.5, full1d)
+                elif fn == "masked_fill":
+                    return fntorch(t1, t2 < 0.5, 1.0)
+                elif fn in fns_3_args:
+                    return fntorch(t1, 1.0, t2, t3)
+                elif fn in fns_value_kwarg:
+                    return fntorch(t1, t2, t3, value=1.0)
                 else:
-                        tensorfn_inplace(t0, t1, t2)
+                    return fntorch(t1, t2)
 
-            if fn not in fns_3_args and fn not in fns_value_kwarg:
-                _test_in_place_broadcastable(small, large_expanded)
-                _test_in_place_broadcastable(small, large)
+            # test various orders
+            for first, second, third in [(large, small, small2), (small, large, small2),
+                                         (small2, small, large), (small2, large, small)]:
+                if first is None:
+                    break  # ignore last iter when small2 is None
+                r1 = torchfn(expanded[first], expanded[second], expanded[third])
+                r2 = torchfn(first, second, third)
+                self.assertEqual(r1, r2)
+
+        # now for in place functions
+        # in-place tensor is not broadcastable; test only guaranteed
+        # to work by broadcasting other argument(s)
+        if not hasattr(large_expanded, fn + "_"):
+            return
+
+        # need to clone largeExpanded so we can reuse, since functions are in-place
+        large_expanded_clone = large_expanded.clone()
+
+        def tensorfn_inplace(t0, t1, t2=None):
+            t0_fn = getattr(t0, fn + "_")
+            if fn == "lerp":
+                return t0_fn(t1, 0.5)
+            elif fn == "masked_scatter":
+                return t0_fn(t1 < 0.5, full1d)
+            elif fn == "masked_fill":
+                return t0_fn(t1 < 0.5, 1.0)
+            elif fn == "map":
+                return t0_fn(t1, lambda x, y: x + y)
+            elif fn == "map2":
+                return t0_fn(t1, t2, lambda x, y, z: x + y + z)
+            elif fn in fns_3_args:
+                return t0_fn(1.0, t1, t2)
+            elif fn in fns_value_kwarg:
+                return t0_fn(t1, t2, value=1.0)
             else:
-                _test_in_place_broadcastable(small2, small_expanded, large_expanded)
-                _test_in_place_broadcastable(small2, small, large)
+                return t0_fn(t1)
+        # in-place pointwise operations don't actually work if the in-place
+        # tensor is 0-strided (numpy has the same issue)
+        if (0 not in large_expanded.stride() and 0 not in large_expanded_clone.stride()):
+            r1 = tensorfn_inplace(large_expanded, small_expanded, small2_expanded)
+            r2 = tensorfn_inplace(large_expanded_clone, small, small2)
+            self.assertEqual(r1, r2)
+
+        def broadcastable(t0, t1, t2=None):
+            try:
+                t1.expand_as(t0)
+                if t2 is not None:
+                    t2.expand_as(t0)
+            except RuntimeError:
+                return False
+            return True
+
+        def _test_in_place_broadcastable(t0, t1, t2=None):
+            if not broadcastable(t0, t1, t2):
+                same_size = t0.numel() == t1.numel() and (t0.numel() == t2.numel() if t2 is not None else True)
+                if not same_size:
+                    self.assertRaises(RuntimeError, lambda: tensorfn_inplace(t0, t1, t2))
+            else:
+                tensorfn_inplace(t0, t1, t2)
+
+        if fn not in fns_3_args and fn not in fns_value_kwarg:
+            _test_in_place_broadcastable(small, large_expanded)
+            _test_in_place_broadcastable(small, large)
+        else:
+            _test_in_place_broadcastable(small2, small_expanded, large_expanded)
+            _test_in_place_broadcastable(small2, small, large)
 
     @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
     @onlyCUDA
@@ -2963,7 +2963,7 @@ else:
             index = torch.tensor([0], device=device)
             x.index_fill_(1, index, 0)
             self.assertEqual(x, torch.tensor([[0, 2], [0, 5]], dtype=dtype, device=device))
-            if not x.is_complex():
+            if not x.is_complex() and not device == "meta":
                 with self.assertRaisesRegex(RuntimeError, r"Scalar"):
                     x.index_fill_(1, index, 1 + 1j)
             # Make sure that the result stays 0-dim while applied to
diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py
index 01e96a3fe112..f32a89933f08 100644
--- a/test/test_type_promotion.py
+++ b/test/test_type_promotion.py
@@ -9,7 +9,7 @@ import torch
 from torch.testing._internal.common_utils import (TestCase, run_tests, load_tests,
                                                   TEST_NUMPY, torch_to_numpy_dtype_dict)
 from torch.testing._internal.common_device_type import (instantiate_device_type_tests, onlyNativeDeviceTypes,
-                                                        dtypes, dtypesIfCUDA, onlyCPU, expectedFailureMeta)
+                                                        dtypes, dtypesIfCUDA, onlyCPU, expectedFailureMeta, skipMeta)
 from torch.testing._internal.common_dtype import (
     get_all_dtypes, get_all_math_dtypes, get_all_int_dtypes, get_all_fp_dtypes
 )
@@ -937,7 +937,11 @@ class TestTypePromotion(TestCase):
         elif op in real_only_ops and dtypes[0].is_complex:
             with self.assertRaises(RuntimeError):
                 op(t, out=out)
-        elif op in float_only_ops and (not dtypes[0].is_floating_point and not dtypes[0].is_complex):
+        elif (
+            op in float_only_ops
+            and (not dtypes[0].is_floating_point and not dtypes[0].is_complex)
+            and device != "meta"
+        ):
             with self.assertRaises(RuntimeError):
                 op(t, out=out)
         else:
@@ -947,6 +951,7 @@ class TestTypePromotion(TestCase):
     # Verifies that the out= argument doesn't affect the computation, that
     # is, out = op(...) and op(..., out=out) produce the same result.
     @onlyNativeDeviceTypes
+    @skipMeta
     def test_computation_ignores_out(self, device):
         t = torch.tensor(33000, dtype=torch.float16, device=device)
         out = torch.empty(0, dtype=torch.float64, device=device)
diff --git a/test/test_view_ops.py b/test/test_view_ops.py
index 2678db1d74d5..37d08e39e637 100644
--- a/test/test_view_ops.py
+++ b/test/test_view_ops.py
@@ -14,7 +14,7 @@ from torch.testing._internal.common_utils import (
     torch_to_numpy_dtype_dict,
 )
 from torch.testing._internal.common_device_type import \
-    (instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes)
+    (instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes, skipMeta)
 from torch.testing._internal.common_dtype import (
     get_all_dtypes, get_all_int_dtypes, get_all_fp_dtypes, get_all_complex_dtypes
 )
@@ -729,6 +729,7 @@ class TestViewOps(TestCase):
         s = t.contiguous()
         self.assertTrue(s is t)
 
+    @skipMeta
     def test_contiguous_nonview(self, device):
         t = torch.ones(5, 5, device=device)
         nv = t.t().contiguous()
@@ -754,6 +755,7 @@ class TestViewOps(TestCase):
         v[6] = 0
         self.assertEqual(t[1, 1], v[6])
 
+    @skipMeta
    def test_reshape_nonview(self, device):
         t = torch.ones(5, 5, device=device)
         nv = torch.reshape(t.t(), (25,))
@@ -806,7 +808,8 @@ class TestViewOps(TestCase):
             idx_nv = (0,) * nv.ndim
             self.assertTrue(not nv._is_view())
             nv[idx_nv] = 0
-            self.assertNotEqual(t[idx_t], nv[idx_nv])
+            if device != "meta":
+                self.assertNotEqual(t[idx_t], nv[idx_nv])
         t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3)
         nv = t.flatten(1, 3)
         assert_is_nonview(t, nv)
@@ -1027,7 +1030,9 @@ class TestOldViewOps(TestCase):
         self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1))
 
         y = torch.randn(4, 4, 4, device=device)[:, 0, :]
-        self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr())
+        # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape
+        if device != "meta":
+            self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr())
         self.assertEqual(y.contiguous().view(-1), y.reshape(-1))
         self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr())
diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py
index 993b3d1d5cb9..50101355b8cb 100644
--- a/torch/testing/_comparison.py
+++ b/torch/testing/_comparison.py
@@ -598,29 +598,12 @@ class TensorLikePair(Pair):
     def compare(self) -> None:
         actual, expected = self.actual, self.expected
 
-        with self._handle_meta_tensor_data_access():
-            self._compare_attributes(actual, expected)
-            actual, expected = self._equalize_attributes(actual, expected)
+        self._compare_attributes(actual, expected)
+        if any(input.device.type == "meta" for input in (actual, expected)):
+            return
 
-            self._compare_values(actual, expected)
-
-    @contextlib.contextmanager
-    def _handle_meta_tensor_data_access(self):
-        """Turns a vanilla :class:`NotImplementedError` stemming from data access on a meta tensor into an expressive
-        :class:`ErrorMeta`.
-
-        Although it looks like meta tensors could be handled upfront, we need to do it lazily: there are use cases
-        where a meta tensor wraps a data tensors and dispatches all operator calls to it. Thus, although the tensor is
-        a meta tensor, it behaves like a regular one.
-        """
-        try:
-            yield
-        except NotImplementedError as error:
-            if "meta" not in str(error).lower():
-                raise error
-
-            # TODO: See https://github.com/pytorch/pytorch/issues/68592
-            raise self._make_error_meta(NotImplementedError, "Comparing meta tensors is currently not supported.")
+        actual, expected = self._equalize_attributes(actual, expected)
+        self._compare_values(actual, expected)
 
     def _compare_attributes(
         self,
@@ -1103,10 +1086,15 @@ def assert_close(
 
         \lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert
 
-    and they have the same :attr:`~torch.Tensor.device` (if ``check_device`` is ``True``), same ``dtype`` (if
-    ``check_dtype`` is ``True``), and the same stride (if ``check_stride`` is ``True``). Non-finite values
-    (``-inf`` and ``inf``) are only considered close if and only if they are equal. ``NaN``'s are only considered equal
-    to each other if ``equal_nan`` is ``True``.
+    Non-finite values (``-inf`` and ``inf``) are only considered close if and only if they are equal. ``NaN``'s are
+    only considered equal to each other if ``equal_nan`` is ``True``.
+
+    In addition, they are only considered close if they have the same
+    - :attr:`~torch.Tensor.device` (if ``check_device`` is ``True``),
+    - ``dtype`` (if ``check_dtype`` is ``True``),
+    - ``layout`` (if ``check_layout`` is ``True``), and
+    - stride (if ``check_stride`` is ``True``).
+    If either ``actual`` or ``expected`` is a meta tensor, only the attribute checks will be performed.
 
     If ``actual`` and ``expected`` are sparse (either having COO or CSR layout), their strided members are
     checked individually. Indices, namely ``indices`` for COO or ``crow_indices`` and ``col_indices`` for CSR layout,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index e984d4b7f8c7..e1f3a1f50326 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -14480,9 +14480,21 @@ op_db: List[OpInfo] = [
     # These paths have different dtype support. Also JIT supports,
     # most variants but not all of them. So we split the OpInfo entries,
     # for `norm` based on the code-paths and JIT support.
-    OpInfo('norm',
-           sample_inputs_func=sample_inputs_norm,
-           dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16)),
+    OpInfo(
+        "norm",
+        sample_inputs_func=sample_inputs_norm,
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        skips=(
+            # AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast from a result
+            # of dtype torch.float32 into an out= with dtype torch.long
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestCommon",
+                "test_out",
+                device_type="meta",
+            ),
+        ),
+    ),
     OpInfo('norm',
            variant_test_name='nuc',
           sample_inputs_func=sample_inputs_norm_nuc,
@@ -14517,19 +14529,40 @@ op_db: List[OpInfo] = [
                # Arguments for call are not valid.
                DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.complex64, torch.float32,)),  # noqa: B950
            )),
-    OpInfo('norm',
-           variant_test_name='inf',
-           sample_inputs_func=sample_inputs_norm_inf,
-           dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
-           backward_dtypesIfCPU=floating_and_complex_types_and(torch.float16, torch.bfloat16),
-           skips=(
-               # https://github.com/pytorch/pytorch/issues/67517
-               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
-               # following 2 tests failed intermittenly
-               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad', device_type='cpu', dtypes=(torch.complex128,)),  # noqa: B950
-               DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_gradgrad', device_type='cpu', dtypes=(torch.complex128,)),  # noqa: B950
-           )
-           ),
+    OpInfo(
+        "norm",
+        variant_test_name="inf",
+        sample_inputs_func=sample_inputs_norm_inf,
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        backward_dtypesIfCPU=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        skips=(
+            # https://github.com/pytorch/pytorch/issues/67517
+            DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_noncontiguous_samples"),
+            # following 2 tests failed intermittenly
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestGradients",
+                "test_fn_grad",
+                device_type="cpu",
+                dtypes=(torch.complex128,),
+            ),
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestGradients",
+                "test_fn_gradgrad",
+                device_type="cpu",
+                dtypes=(torch.complex128,),
+            ),
+            # AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast from a result
+            # of dtype torch.float32 into an out= with dtype torch.long
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestCommon",
+                "test_out",
+                device_type="meta",
+            ),
+        ),
+    ),
     OpInfo('t',
           sample_inputs_func=sample_inputs_t,
           supports_out=False,
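
Note (not part of the patch above): a small usage sketch of the `torch.testing.assert_close` behavior introduced by the torch/testing/_comparison.py hunk. With this change, when either input lives on the meta device only the attribute checks (shape, dtype, device, layout, and optionally stride) run and the value comparison is skipped, while mismatching attributes still fail; this mirrors the updated `test_meta` expectation in test_testing.py. On releases before this patch, comparing meta tensors raised instead.

    import torch
    from torch.testing import assert_close

    a = torch.empty((2, 3), device="meta")
    b = torch.empty((2, 3), device="meta")
    assert_close(a, b)      # passes: attributes match and values are never read

    c = torch.empty((2, 4), device="meta")
    try:
        assert_close(a, c)  # still fails: the shapes differ
    except AssertionError as error:
        print(error)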