diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index c293cab2943a..8bc752a12df8 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -16,25 +16,53 @@ from functools import partial import torch.autograd.forward_ad as fwAD from torch._six import inf, nan from torch.testing._internal.common_utils import ( - TestCase, slowTest, iter_indices, TEST_WITH_ASAN, run_tests, gradcheck, - torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict, TEST_SCIPY, set_default_dtype) + TestCase, + slowTest, + iter_indices, + TEST_WITH_ASAN, + run_tests, + gradcheck, + torch_to_numpy_dtype_dict, + numpy_to_torch_dtype_dict, + TEST_SCIPY, + set_default_dtype, +) from torch.testing._internal.common_device_type import ( - expectedFailureMeta, instantiate_device_type_tests, onlyCUDA, onlyCPU, dtypes, dtypesIfCUDA, - dtypesIfCPU, deviceCountAtLeast, precisionOverride, onlyNativeDeviceTypes, - skipIf, ops, OpDTypes, skipMeta) + expectedFailureMeta, + instantiate_device_type_tests, + onlyCUDA, + onlyCPU, + dtypes, + dtypesIfCUDA, + dtypesIfCPU, + deviceCountAtLeast, + precisionOverride, + onlyNativeDeviceTypes, + skipIf, + ops, + OpDTypes, + skipMeta, +) from torch.testing import make_tensor from torch.testing._internal.common_dtype import ( - all_types_and_complex_and, all_types_and, integral_types, complex_types, integral_types_and, - floating_types_and, floating_and_complex_types, get_all_math_dtypes, + all_types_and_complex_and, + all_types_and, + integral_types, + complex_types, + integral_types_and, + floating_types_and, + floating_and_complex_types, + get_all_math_dtypes, ) from torch.testing._internal.common_methods_invocations import ( - binary_ufuncs, _NOTHING, + binary_ufuncs, + _NOTHING, generate_elementwise_binary_tensors, generate_elementwise_binary_small_value_tensors, generate_elementwise_binary_large_value_tensors, generate_elementwise_binary_extremal_value_tensors, generate_elementwise_binary_broadcasting_tensors, - generate_elementwise_binary_with_scalar_samples + generate_elementwise_binary_with_scalar_samples, ) if TEST_SCIPY: @@ -49,7 +77,9 @@ class TestBinaryUfuncs(TestCase): # Helper for comparing torch tensors and NumPy arrays # TODO: should this or assertEqual also validate that strides are equal? 
- def assertEqualHelper(self, actual, expected, msg, *, dtype, exact_dtype=True, **kwargs): + def assertEqualHelper( + self, actual, expected, msg, *, dtype, exact_dtype=True, **kwargs + ): assert isinstance(actual, torch.Tensor) # Some NumPy functions return scalars, not arrays @@ -63,31 +93,55 @@ class TestBinaryUfuncs(TestCase): # Also ops like scipy.special.erf, scipy.special.erfc, etc, promote float16 # to float32 if expected.dtype == np.float32: - assert actual.dtype in (torch.float16, torch.bfloat16, torch.float32) + assert actual.dtype in ( + torch.float16, + torch.bfloat16, + torch.float32, + ) else: assert expected.dtype == torch_to_numpy_dtype_dict[actual.dtype] - self.assertEqual(actual, - torch.from_numpy(expected).to(actual.dtype), - msg, - exact_device=False, - **kwargs) + self.assertEqual( + actual, + torch.from_numpy(expected).to(actual.dtype), + msg, + exact_device=False, + **kwargs, + ) else: self.assertEqual(actual, expected, msg, exact_device=False, **kwargs) # Tests that the function and its (array-accepting) reference produce the same # values on given tensors def _test_reference_numerics(self, dtype, op, gen, equal_nan=True): - def _helper_reference_numerics(expected, actual, msg, exact_dtype, equal_nan=True): - if not torch.can_cast(numpy_to_torch_dtype_dict[expected.dtype.type], dtype): + def _helper_reference_numerics( + expected, actual, msg, exact_dtype, equal_nan=True + ): + if not torch.can_cast( + numpy_to_torch_dtype_dict[expected.dtype.type], dtype + ): exact_dtype = False if dtype is torch.bfloat16 and expected.dtype == np.float32: # Ref: https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_utils.py#L1149 - self.assertEqualHelper(actual, expected, msg, dtype=dtype, - exact_dtype=exact_dtype, rtol=16e-3, atol=1e-5) + self.assertEqualHelper( + actual, + expected, + msg, + dtype=dtype, + exact_dtype=exact_dtype, + rtol=16e-3, + atol=1e-5, + ) else: - self.assertEqualHelper(actual, expected, msg, dtype=dtype, equal_nan=equal_nan, exact_dtype=exact_dtype) + self.assertEqualHelper( + actual, + expected, + msg, + dtype=dtype, + equal_nan=equal_nan, + exact_dtype=exact_dtype, + ) for sample in gen: # Each sample input acquired from the generator is just one lhs tensor @@ -110,22 +164,28 @@ class TestBinaryUfuncs(TestCase): return 1 if _numel(l) < 10 and _numel(r) < 10: - msg = ("Failed to produce expected results! Input lhs tensor was" - " {0}, rhs tensor was {1}, torch result is {2}, and reference result is" - " {3}.").format(l, r, actual, expected) + msg = ( + "Failed to produce expected results! Input lhs tensor was" + " {0}, rhs tensor was {1}, torch result is {2}, and reference result is" + " {3}." 
+ ).format(l, r, actual, expected) else: msg = None exact_dtype = True if isinstance(actual, torch.Tensor): - _helper_reference_numerics(expected, actual, msg, exact_dtype, equal_nan) + _helper_reference_numerics( + expected, actual, msg, exact_dtype, equal_nan + ) else: for x, y in zip(expected, actual): # testing multi-outputs results _helper_reference_numerics(x, y, msg, exact_dtype, equal_nan) # The following tests only apply to elementwise binary operators with references - binary_ufuncs_with_references = list(filter(lambda op: op.ref is not None and op.ref is not _NOTHING, binary_ufuncs)) + binary_ufuncs_with_references = list( + filter(lambda op: op.ref is not None and op.ref is not _NOTHING, binary_ufuncs) + ) @ops(binary_ufuncs_with_references) def test_reference_numerics(self, device, dtype, op): @@ -139,42 +199,84 @@ class TestBinaryUfuncs(TestCase): if dtype is torch.bool: self.skipTest("Doesn't support bool!") - gen = generate_elementwise_binary_small_value_tensors(op, device=device, dtype=dtype) + gen = generate_elementwise_binary_small_value_tensors( + op, device=device, dtype=dtype + ) self._test_reference_numerics(dtype, op, gen, equal_nan=True) # TODO: review if this skip is necessary @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") - @ops(binary_ufuncs_with_references, - allowed_dtypes=(torch.int16, torch.int32, torch.int64, torch.float16, - torch.bfloat16, torch.float32, torch.float64, torch.complex64, torch.complex128)) + @ops( + binary_ufuncs_with_references, + allowed_dtypes=( + torch.int16, + torch.int32, + torch.int64, + torch.float16, + torch.bfloat16, + torch.float32, + torch.float64, + torch.complex64, + torch.complex128, + ), + ) def test_reference_numerics_large_values(self, device, dtype, op): - gen = generate_elementwise_binary_large_value_tensors(op, device=device, dtype=dtype) + gen = generate_elementwise_binary_large_value_tensors( + op, device=device, dtype=dtype + ) self._test_reference_numerics(dtype, op, gen, equal_nan=True) # TODO: review if this skip is necessary @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") - @ops(binary_ufuncs_with_references, - allowed_dtypes=(torch.float16, torch.bfloat16, torch.float32, - torch.float64, torch.complex64, torch.complex128)) + @ops( + binary_ufuncs_with_references, + allowed_dtypes=( + torch.float16, + torch.bfloat16, + torch.float32, + torch.float64, + torch.complex64, + torch.complex128, + ), + ) def test_reference_numerics_extremal_values(self, device, dtype, op): - gen = generate_elementwise_binary_extremal_value_tensors(op, device=device, dtype=dtype) + gen = generate_elementwise_binary_extremal_value_tensors( + op, device=device, dtype=dtype + ) self._test_reference_numerics(dtype, op, gen, equal_nan=True) # tests broadcasting and noncontiguous broadcasting behavior - @ops(binary_ufuncs_with_references, allowed_dtypes=(torch.long, torch.float32,)) + @ops( + binary_ufuncs_with_references, + allowed_dtypes=( + torch.long, + torch.float32, + ), + ) def test_broadcasting(self, device, dtype, op): - gen = generate_elementwise_binary_broadcasting_tensors(op, device=device, dtype=dtype) + gen = generate_elementwise_binary_broadcasting_tensors( + op, device=device, dtype=dtype + ) self._test_reference_numerics(dtype, op, gen, equal_nan=True) - @ops(binary_ufuncs_with_references, allowed_dtypes=(torch.long, torch.float32, torch.complex64)) + @ops( + binary_ufuncs_with_references, + allowed_dtypes=(torch.long, torch.float32, torch.complex64), + ) def test_scalar_support(self, device, dtype, op): - gen = 
generate_elementwise_binary_with_scalar_samples(op, device=device, dtype=dtype) + gen = generate_elementwise_binary_with_scalar_samples( + op, device=device, dtype=dtype + ) self._test_reference_numerics(dtype, op, gen, equal_nan=True) @ops(binary_ufuncs) def test_contig_vs_every_other(self, device, dtype, op): - lhs = make_tensor((1026,), device=device, dtype=dtype, **op.lhs_make_tensor_kwargs) - rhs = make_tensor((1026,), device=device, dtype=dtype, **op.rhs_make_tensor_kwargs) + lhs = make_tensor( + (1026,), device=device, dtype=dtype, **op.lhs_make_tensor_kwargs + ) + rhs = make_tensor( + (1026,), device=device, dtype=dtype, **op.rhs_make_tensor_kwargs + ) lhs_non_contig = lhs[::2] rhs_non_contig = rhs[::2] @@ -191,8 +293,12 @@ class TestBinaryUfuncs(TestCase): @ops(binary_ufuncs) def test_contig_vs_transposed(self, device, dtype, op): - lhs = make_tensor((789, 357), device=device, dtype=dtype, **op.lhs_make_tensor_kwargs) - rhs = make_tensor((789, 357), device=device, dtype=dtype, **op.rhs_make_tensor_kwargs) + lhs = make_tensor( + (789, 357), device=device, dtype=dtype, **op.lhs_make_tensor_kwargs + ) + rhs = make_tensor( + (789, 357), device=device, dtype=dtype, **op.rhs_make_tensor_kwargs + ) lhs_non_contig = lhs.T rhs_non_contig = rhs.T @@ -211,13 +317,21 @@ class TestBinaryUfuncs(TestCase): def test_non_contig(self, device, dtype, op): shapes = ((5, 7), (1024,)) for shape in shapes: - lhs = make_tensor(shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs) - rhs = make_tensor(shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs) + lhs = make_tensor( + shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs + ) + rhs = make_tensor( + shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs + ) - lhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0] + lhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[ + ..., 0 + ] lhs_non_contig.copy_(lhs) - rhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0] + rhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[ + ..., 0 + ] rhs_non_contig.copy_(rhs) self.assertTrue(lhs.is_contiguous()) @@ -233,8 +347,12 @@ class TestBinaryUfuncs(TestCase): @ops(binary_ufuncs) def test_non_contig_index(self, device, dtype, op): shape = (2, 2, 1, 2) - lhs = make_tensor(shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs) - rhs = make_tensor(shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs) + lhs = make_tensor( + shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs + ) + rhs = make_tensor( + shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs + ) lhs_non_contig = lhs[:, 1, ...] 
lhs = lhs_non_contig.contiguous() @@ -256,8 +374,12 @@ class TestBinaryUfuncs(TestCase): def test_non_contig_expand(self, device, dtype, op): shapes = [(1, 3), (1, 7), (5, 7)] for shape in shapes: - lhs = make_tensor(shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs) - rhs = make_tensor(shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs) + lhs = make_tensor( + shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs + ) + rhs = make_tensor( + shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs + ) lhs_non_contig = lhs.clone().expand(3, -1, -1) rhs_non_contig = rhs.clone().expand(3, -1, -1) @@ -276,8 +398,12 @@ class TestBinaryUfuncs(TestCase): @ops(binary_ufuncs) def test_contig_size1(self, device, dtype, op): shape = (5, 100) - lhs = make_tensor(shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs) - rhs = make_tensor(shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs) + lhs = make_tensor( + shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs + ) + rhs = make_tensor( + shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs + ) lhs = lhs[:1, :50] lhs_alt = torch.empty(lhs.size(), device=device, dtype=dtype) @@ -300,8 +426,12 @@ class TestBinaryUfuncs(TestCase): @ops(binary_ufuncs) def test_contig_size1_large_dim(self, device, dtype, op): shape = (5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4) - lhs = make_tensor(shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs) - rhs = make_tensor(shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs) + lhs = make_tensor( + shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs + ) + rhs = make_tensor( + shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs + ) lhs = lhs[:1, :, :, :, :, :, :, :, :, :, :, :] lhs_alt = torch.empty(lhs.size(), device=device, dtype=dtype) @@ -324,8 +454,12 @@ class TestBinaryUfuncs(TestCase): @ops(binary_ufuncs) def test_batch_vs_slicing(self, device, dtype, op): shape = (32, 512) - lhs = make_tensor(shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs) - rhs = make_tensor(shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs) + lhs = make_tensor( + shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs + ) + rhs = make_tensor( + shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs + ) expected = op(lhs, rhs) @@ -349,31 +483,52 @@ class TestBinaryUfuncs(TestCase): # int x int type promotion if _supported((torch.int16, torch.int32, torch.int64)): - lhs_i16 = make_tensor((5,), device=device, dtype=torch.int16, **op.lhs_make_tensor_kwargs) - lhs_i32 = make_tensor((5,), device=device, dtype=torch.int32, **op.lhs_make_tensor_kwargs) - lhs_i64 = make_tensor((5,), device=device, dtype=torch.int64, **op.lhs_make_tensor_kwargs) - - rhs_i16 = make_tensor((5,), device=device, dtype=torch.int16, **op.rhs_make_tensor_kwargs) - rhs_i32 = make_tensor((5,), device=device, dtype=torch.int32, **op.rhs_make_tensor_kwargs) - rhs_i64 = make_tensor((5,), device=device, dtype=torch.int64, **op.rhs_make_tensor_kwargs) + lhs_i16 = make_tensor( + (5,), device=device, dtype=torch.int16, **op.lhs_make_tensor_kwargs + ) + lhs_i32 = make_tensor( + (5,), device=device, dtype=torch.int32, **op.lhs_make_tensor_kwargs + ) + lhs_i64 = make_tensor( + (5,), device=device, dtype=torch.int64, **op.lhs_make_tensor_kwargs + ) + rhs_i16 = make_tensor( + (5,), device=device, dtype=torch.int16, **op.rhs_make_tensor_kwargs + ) + rhs_i32 = make_tensor( + (5,), device=device, dtype=torch.int32, **op.rhs_make_tensor_kwargs + ) 
+ rhs_i64 = make_tensor( + (5,), device=device, dtype=torch.int64, **op.rhs_make_tensor_kwargs + ) if op.promotes_int_to_float: default_dtype = torch.get_default_dtype() self.assertEqual(op(lhs_i16, rhs_i32).dtype, default_dtype) - self.assertEqual(op(lhs_i16, rhs_i32), op(lhs_i16.to(default_dtype), rhs_i32.to(default_dtype))) + self.assertEqual( + op(lhs_i16, rhs_i32), + op(lhs_i16.to(default_dtype), rhs_i32.to(default_dtype)), + ) self.assertEqual(op(lhs_i32, rhs_i64).dtype, default_dtype) - self.assertEqual(op(lhs_i32, rhs_i64), op(lhs_i32.to(default_dtype), rhs_i64.to(default_dtype))) + self.assertEqual( + op(lhs_i32, rhs_i64), + op(lhs_i32.to(default_dtype), rhs_i64.to(default_dtype)), + ) elif op.always_returns_bool: self.assertEqual(op(lhs_i16, rhs_i32).dtype, torch.bool) self.assertEqual(op(lhs_i32, rhs_i64).dtype, torch.bool) else: # standard type promotion self.assertEqual(op(lhs_i16, rhs_i32).dtype, torch.int32) - self.assertEqual(op(lhs_i16, rhs_i32), op(lhs_i16.to(torch.int32), rhs_i32)) + self.assertEqual( + op(lhs_i16, rhs_i32), op(lhs_i16.to(torch.int32), rhs_i32) + ) self.assertEqual(op(lhs_i32, rhs_i64).dtype, torch.int64) - self.assertEqual(op(lhs_i32, rhs_i64), op(lhs_i32.to(torch.int64), rhs_i64)) + self.assertEqual( + op(lhs_i32, rhs_i64), op(lhs_i32.to(torch.int64), rhs_i64) + ) if op.supports_out: if not op.promotes_int_to_float: @@ -392,7 +547,11 @@ class TestBinaryUfuncs(TestCase): if not op.always_returns_bool: # Neither integer nor float outs can be cast to bool with self.assertRaisesRegex(RuntimeError, "can't be cast"): - op(lhs_i16, rhs_i32, out=torch.empty_like(lhs_i64, dtype=torch.bool)) + op( + lhs_i16, + rhs_i32, + out=torch.empty_like(lhs_i64, dtype=torch.bool), + ) # All these output types can be cast to any float or complex type out = torch.empty_like(lhs_i64, dtype=torch.float16) @@ -411,17 +570,27 @@ class TestBinaryUfuncs(TestCase): # float x float type promotion if _supported((torch.float32, torch.float64)): - lhs_f32 = make_tensor((5,), device=device, dtype=torch.float32, **op.lhs_make_tensor_kwargs) - lhs_f64 = make_tensor((5,), device=device, dtype=torch.float64, **op.lhs_make_tensor_kwargs) + lhs_f32 = make_tensor( + (5,), device=device, dtype=torch.float32, **op.lhs_make_tensor_kwargs + ) + lhs_f64 = make_tensor( + (5,), device=device, dtype=torch.float64, **op.lhs_make_tensor_kwargs + ) - rhs_f32 = make_tensor((5,), device=device, dtype=torch.float32, **op.rhs_make_tensor_kwargs) - rhs_f64 = make_tensor((5,), device=device, dtype=torch.float64, **op.rhs_make_tensor_kwargs) + rhs_f32 = make_tensor( + (5,), device=device, dtype=torch.float32, **op.rhs_make_tensor_kwargs + ) + rhs_f64 = make_tensor( + (5,), device=device, dtype=torch.float64, **op.rhs_make_tensor_kwargs + ) if op.always_returns_bool: self.assertEqual(op(lhs_f32, rhs_f64).dtype, torch.bool) else: # normal float type promotion self.assertEqual(op(lhs_f32, rhs_f64).dtype, torch.float64) - self.assertEqual(op(lhs_f32, rhs_f64), op(lhs_f32.to(torch.float64), rhs_f64)) + self.assertEqual( + op(lhs_f32, rhs_f64), op(lhs_f32.to(torch.float64), rhs_f64) + ) if op.supports_out: # All these output types can be cast to any float or complex type @@ -443,7 +612,11 @@ class TestBinaryUfuncs(TestCase): if not op.always_returns_bool: # float outs can't be cast to an integer dtype with self.assertRaisesRegex(RuntimeError, "can't be cast"): - op(lhs_f32, rhs_f64, out=torch.empty_like(lhs_f64, dtype=torch.int64)) + op( + lhs_f32, + rhs_f64, + out=torch.empty_like(lhs_f64, dtype=torch.int64), + ) 
else: # bool outs can be cast to an integer dtype out = torch.empty_like(lhs_f64, dtype=torch.int64) @@ -452,17 +625,27 @@ class TestBinaryUfuncs(TestCase): # complex x complex type promotion if _supported((torch.complex64, torch.complex128)): - lhs_c64 = make_tensor((5,), device=device, dtype=torch.complex64, **op.lhs_make_tensor_kwargs) - lhs_c128 = make_tensor((5,), device=device, dtype=torch.complex128, **op.lhs_make_tensor_kwargs) + lhs_c64 = make_tensor( + (5,), device=device, dtype=torch.complex64, **op.lhs_make_tensor_kwargs + ) + lhs_c128 = make_tensor( + (5,), device=device, dtype=torch.complex128, **op.lhs_make_tensor_kwargs + ) - rhs_c64 = make_tensor((5,), device=device, dtype=torch.complex64, **op.rhs_make_tensor_kwargs) - rhs_c128 = make_tensor((5,), device=device, dtype=torch.complex128, **op.rhs_make_tensor_kwargs) + rhs_c64 = make_tensor( + (5,), device=device, dtype=torch.complex64, **op.rhs_make_tensor_kwargs + ) + rhs_c128 = make_tensor( + (5,), device=device, dtype=torch.complex128, **op.rhs_make_tensor_kwargs + ) if op.always_returns_bool: self.assertEqual(op(lhs_c64, lhs_c128).dtype, torch.bool) else: # normal complex type promotion self.assertEqual(op(lhs_c64, rhs_c128).dtype, torch.complex128) - self.assertEqual(op(lhs_c64, rhs_c128), op(lhs_c64.to(torch.complex128), rhs_c128)) + self.assertEqual( + op(lhs_c64, rhs_c128), op(lhs_c64.to(torch.complex128), rhs_c128) + ) if op.supports_out: # All these output types can be cast to any or complex type @@ -475,14 +658,24 @@ class TestBinaryUfuncs(TestCase): if not op.always_returns_bool: # complex outs can't be cast to float types with self.assertRaisesRegex(RuntimeError, "can't be cast"): - op(lhs_c64, rhs_c128, out=torch.empty_like(lhs_c64, dtype=torch.float64)) + op( + lhs_c64, + rhs_c128, + out=torch.empty_like(lhs_c64, dtype=torch.float64), + ) # complex outs can't be cast to an integer dtype with self.assertRaisesRegex(RuntimeError, "can't be cast"): - op(lhs_c64, rhs_c128, out=torch.empty_like(lhs_c64, dtype=torch.int64)) + op( + lhs_c64, + rhs_c128, + out=torch.empty_like(lhs_c64, dtype=torch.int64), + ) else: # bool outs can be cast to a float type out = torch.empty_like(lhs_c64, dtype=torch.float64) - self.assertEqual(op(lhs_c64, rhs_c128, out=out).dtype, torch.float64) + self.assertEqual( + op(lhs_c64, rhs_c128, out=out).dtype, torch.float64 + ) self.assertEqual(op(lhs_c64, rhs_c128), out, exact_dtype=False) # bool outs can be cast to an integer dtype @@ -494,13 +687,17 @@ class TestBinaryUfuncs(TestCase): @ops(binary_ufuncs, allowed_dtypes=(torch.float32,)) def test_not_broadcastable(self, device, dtype, op): for shape_lhs, shape_rhs in ( - ((2,), (3,)), - ((3, 1), (2, 1)), - ((1, 3, 2), (3,)), - ((3, 1, 2), (2, 1, 2)), + ((2,), (3,)), + ((3, 1), (2, 1)), + ((1, 3, 2), (3,)), + ((3, 1, 2), (2, 1, 2)), ): - lhs = make_tensor(shape_lhs, device=device, dtype=dtype, **op.lhs_make_tensor_kwargs) - rhs = make_tensor(shape_rhs, device=device, dtype=dtype, **op.rhs_make_tensor_kwargs) + lhs = make_tensor( + shape_lhs, device=device, dtype=dtype, **op.lhs_make_tensor_kwargs + ) + rhs = make_tensor( + shape_rhs, device=device, dtype=dtype, **op.rhs_make_tensor_kwargs + ) try: broadcasted_shape = op(lhs, rhs).shape @@ -515,27 +712,48 @@ class TestBinaryUfuncs(TestCase): def test_add_broadcast_empty(self, device): # empty + empty - self.assertRaises(RuntimeError, lambda: torch.randn(5, 0, device=device) + torch.randn(0, 5, device=device)) - self.assertEqual(torch.randn(5, 0, device=device), torch.randn(0, device=device) 
+ torch.randn(5, 0, device=device)) - self.assertEqual(torch.randn(5, 0, 0, device=device), torch.randn(0, device=device) + torch.randn(5, 0, 1, device=device)) + self.assertRaises( + RuntimeError, + lambda: torch.randn(5, 0, device=device) + torch.randn(0, 5, device=device), + ) + self.assertEqual( + torch.randn(5, 0, device=device), + torch.randn(0, device=device) + torch.randn(5, 0, device=device), + ) + self.assertEqual( + torch.randn(5, 0, 0, device=device), + torch.randn(0, device=device) + torch.randn(5, 0, 1, device=device), + ) # scalar + empty - self.assertEqual(torch.randn(5, 0, 6, device=device), torch.randn((), device=device) + torch.randn(5, 0, 6, device=device)) + self.assertEqual( + torch.randn(5, 0, 6, device=device), + torch.randn((), device=device) + torch.randn(5, 0, 6, device=device), + ) # non-empty, empty - self.assertEqual(torch.randn(0, device=device), torch.randn(0, device=device) + torch.randn(1, device=device)) - self.assertEqual(torch.randn(0, 7, 0, 6, 5, 0, 7, device=device), - torch.randn(0, 7, 0, 6, 5, 0, 1, device=device) + torch.randn(1, 1, 5, 1, 7, device=device)) - self.assertRaises(RuntimeError, lambda: torch.randn(7, 0, device=device) + torch.randn(2, 1, device=device)) + self.assertEqual( + torch.randn(0, device=device), + torch.randn(0, device=device) + torch.randn(1, device=device), + ) + self.assertEqual( + torch.randn(0, 7, 0, 6, 5, 0, 7, device=device), + torch.randn(0, 7, 0, 6, 5, 0, 1, device=device) + + torch.randn(1, 1, 5, 1, 7, device=device), + ) + self.assertRaises( + RuntimeError, + lambda: torch.randn(7, 0, device=device) + torch.randn(2, 1, device=device), + ) def test_addcmul_scalars_as_floats(self, device): # zero-dim variables that don't require grad should bind to scalar arguments - x = torch.tensor(2.) 
- y = torch.tensor(3., device=device) + x = torch.tensor(2.0) + y = torch.tensor(3.0, device=device) # 3 + (3 * 3) * 2 self.assertEqual(y.addcmul(y, y, value=x), 21) - x = torch.tensor(2., requires_grad=True) + x = torch.tensor(2.0, requires_grad=True) self.assertRaises(Exception, lambda: y.addcmul(y, y, value=x)) # TODO: update to work on CUDA, too @@ -572,8 +790,8 @@ class TestBinaryUfuncs(TestCase): def test_comparison_ops_device_computation(self, device): operands = ( torch.tensor(0), - torch.tensor(2, device='cuda'), - torch.tensor([0, 2], device='cuda') + torch.tensor(2, device="cuda"), + torch.tensor([0, 2], device="cuda"), ) # Checks that comparison operators compute the correct # output device, given a combination of devices @@ -587,38 +805,49 @@ class TestBinaryUfuncs(TestCase): # TODO: update to work on CUDA, too @onlyCPU def test_comparison_ops_must_take_bool_output(self, device): - for op in [torch.lt, torch.le, torch.gt, torch.ge, torch.eq, torch.ne, - torch.logical_and, torch.logical_or, torch.logical_xor]: - self.assertEqual(op(torch.tensor([True]), torch.tensor([False])).dtype, torch.bool) + for op in [ + torch.lt, + torch.le, + torch.gt, + torch.ge, + torch.eq, + torch.ne, + torch.logical_and, + torch.logical_or, + torch.logical_xor, + ]: + self.assertEqual( + op(torch.tensor([True]), torch.tensor([False])).dtype, torch.bool + ) # TODO: update to work on CUDA, too @onlyCPU def test_comparison_ops_check_for_scalar_overflow(self, device): s = 1 << 20 t = torch.tensor([1 << 5], dtype=torch.uint8) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(t < s) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(s < t) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(t <= s) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(s <= t) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(t > s) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(s > t) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(t >= s) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(s >= t) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(t == s) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(s == t) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot 
be converted to type"): self.assertTrue(t != s) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(s != t) # TODO: update to work on CUDA, too @@ -628,29 +857,29 @@ class TestBinaryUfuncs(TestCase): t2 = torch.tensor([1 << 30], dtype=torch.int32) ts1 = torch.tensor(1 << 20, dtype=torch.int32) ts2 = torch.tensor(1 << 40, dtype=torch.int64) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(t1 < ts1) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(ts2 < t2) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(t1 <= ts1) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(ts2 <= t2) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(t1 > ts1) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(ts2 > t2) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(t1 >= ts1) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(ts2 >= t2) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(t1 == ts1) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(ts2 == t2) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(t1 != ts1) - with self.assertRaisesRegex(RuntimeError, 'value cannot be converted to type'): + with self.assertRaisesRegex(RuntimeError, "value cannot be converted to type"): self.assertTrue(ts2 != t2) # Tests that the binary operators and, or, and xor (as well as their reflected and inplace versions) @@ -658,7 +887,14 @@ class TestBinaryUfuncs(TestCase): @dtypes(*integral_types_and(torch.bool)) def test_bitwise_ops(self, device, dtype): # Tensor x Tensor and Tensor x Scalar ops - ops = (operator.and_, operator.iand, operator.or_, operator.ior, operator.xor, operator.ixor) + ops = ( + operator.and_, + operator.iand, + operator.or_, + operator.ior, + operator.xor, + operator.ixor, + ) inplace_ops = (operator.iand, operator.ior, operator.ixor) shapes = ((5,), (15, 15), (500, 500)) @@ -672,12 +908,12 @@ class TestBinaryUfuncs(TestCase): # Tests tensor x scalar case a = make_tensor(shape, device=device, dtype=dtype) - b_scalar = make_tensor((), device='cpu', dtype=dtype).item() + 
b_scalar = make_tensor((), device="cpu", dtype=dtype).item() a_np = a.cpu().clone().numpy() self.assertEqual(op(a, b_scalar), op(a_np, b_scalar)) # Tests scalar x tensor case - a_scalar = make_tensor((), device='cpu', dtype=dtype).item() + a_scalar = make_tensor((), device="cpu", dtype=dtype).item() b = make_tensor(shape, device=device, dtype=dtype) b_np = b.cpu().clone().numpy() self.assertEqual(op(a_scalar, b), op(a_scalar, b_np)) @@ -695,7 +931,7 @@ class TestBinaryUfuncs(TestCase): # Tests tensor x scalar case a = make_tensor(shape, device=device, dtype=dtype) - b_scalar = make_tensor((), device='cpu', dtype=dtype).item() + b_scalar = make_tensor((), device="cpu", dtype=dtype).item() a_np = a.cpu().clone().numpy() op(a, b_scalar) op(a_np, b_scalar) @@ -734,17 +970,23 @@ class TestBinaryUfuncs(TestCase): self.assertTrue(d_true.is_floating_point()) self.assertEqual(d_true * b, a.to(d_true.dtype)) - d_floor = torch.divide(a, b, rounding_mode='floor') + d_floor = torch.divide(a, b, rounding_mode="floor") if dtype not in (torch.bfloat16, torch.half): self.assertEqual(d_floor * b + torch.remainder(a, b), a) else: - self.assertEqual(d_floor * b + torch.remainder(a.float(), b.float()), a, - exact_dtype=False) + self.assertEqual( + d_floor * b + torch.remainder(a.float(), b.float()), + a, + exact_dtype=False, + ) - d_trunc = torch.divide(a, b, rounding_mode='trunc') + d_trunc = torch.divide(a, b, rounding_mode="trunc") rounding_unsupported = ( - dtype == torch.half and device != 'cuda' or - dtype == torch.bfloat16 and device != 'cpu') + dtype == torch.half + and device != "cuda" + or dtype == torch.bfloat16 + and device != "cpu" + ) d_ref = d_true.float() if rounding_unsupported else d_true self.assertEqual(d_trunc, d_ref.trunc().to(dtype)) @@ -752,8 +994,10 @@ class TestBinaryUfuncs(TestCase): def test_div_rounding_nonfinite(self, device, dtype): # Compare division of special floating point values against NumPy - num = torch.tensor([1.0, -1.0, 0, 0.1, -0.1, np.pi, -np.pi, np.inf, -np.inf, np.nan], - dtype=dtype) + num = torch.tensor( + [1.0, -1.0, 0, 0.1, -0.1, np.pi, -np.pi, np.inf, -np.inf, np.nan], + dtype=dtype, + ) # Divide by zero is tested seperately denom = num[num != 0] @@ -767,18 +1011,26 @@ class TestBinaryUfuncs(TestCase): an, bn = a.float().cpu().numpy(), b.float().cpu().numpy() for mode, np_ref in ((None, np.true_divide), ("floor", np.floor_divide)): - with np.errstate(all='ignore'): + with np.errstate(all="ignore"): expect = np_ref(an, bn) kwargs = dict(rounding_mode=mode) if mode is not None else {} with set_default_dtype(torch.double): actual = torch.divide(a, b, **kwargs) - self.assertEqual(actual, torch.from_numpy(expect), - exact_device=False, exact_dtype=exact_dtype) + self.assertEqual( + actual, + torch.from_numpy(expect), + exact_device=False, + exact_dtype=exact_dtype, + ) # Compare contiguous (likely vectorized) against non-contiguous (not vectorized) - a_noncontig = torch.empty([2 * i for i in a.shape], dtype=dtype, device=device)[::2, ::2] + a_noncontig = torch.empty([2 * i for i in a.shape], dtype=dtype, device=device)[ + ::2, ::2 + ] a_noncontig[:] = a - b_noncontig = torch.empty([2 * i for i in b.shape], dtype=dtype, device=device)[::2, ::2] + b_noncontig = torch.empty([2 * i for i in b.shape], dtype=dtype, device=device)[ + ::2, ::2 + ] b_noncontig[:] = b for rounding_mode in (None, "trunc", "floor"): @@ -788,9 +1040,11 @@ class TestBinaryUfuncs(TestCase): @dtypes(torch.bfloat16, torch.half, torch.float32, torch.float64) def test_divide_by_zero_rounding(self, 
device, dtype): - a = torch.tensor([1.0, -1.0, 0, 0.1, -0.1, np.pi, -np.pi, np.inf, -np.inf, np.nan], - dtype=dtype) - exact_dtype = (dtype != torch.bfloat16) + a = torch.tensor( + [1.0, -1.0, 0, 0.1, -0.1, np.pi, -np.pi, np.inf, -np.inf, np.nan], + dtype=dtype, + ) + exact_dtype = dtype != torch.bfloat16 if exact_dtype: an = a.cpu().numpy() else: @@ -800,7 +1054,7 @@ class TestBinaryUfuncs(TestCase): # NOTE: NumPy's floor_divide rounding changed in 1.20.0 to be consistent with divide expect = np.divide(an, 0) - for rounding_mode in (None, 'floor'): + for rounding_mode in (None, "floor"): # CPU scalar actual = torch.divide(a, 0, rounding_mode=rounding_mode) self.assertEqual(actual, expect, exact_dtype=exact_dtype) @@ -810,8 +1064,7 @@ class TestBinaryUfuncs(TestCase): @dtypes(*all_types_and(torch.half)) def test_div_rounding_numpy(self, device, dtype): - info = (torch.finfo(dtype) if dtype.is_floating_point - else torch.iinfo(dtype)) + info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype) low, high = info.min, info.max # Compare division of random values against NumPy @@ -832,34 +1085,39 @@ class TestBinaryUfuncs(TestCase): an, bn = a.float().cpu().numpy(), b.float().cpu().numpy() for mode, np_ref in ( - (None, np.true_divide), - ("floor", np.floor_divide), - ("trunc", lambda a, b: np.trunc(np.true_divide(a, b)).astype(a.dtype)) + (None, np.true_divide), + ("floor", np.floor_divide), + ("trunc", lambda a, b: np.trunc(np.true_divide(a, b)).astype(a.dtype)), ): - with np.errstate(all='ignore'): + with np.errstate(all="ignore"): expect = torch.from_numpy(np_ref(an, bn)) kwargs = dict(rounding_mode=mode) if mode is not None else {} # Contiguous (likely vectorized) with set_default_dtype(torch.double): actual = torch.divide(a, b, **kwargs) - self.assertEqual(actual, expect, exact_device=False, exact_dtype=exact_dtype) + self.assertEqual( + actual, expect, exact_device=False, exact_dtype=exact_dtype + ) # Non-contiguous (not vectorized) expect = expect[::2] with set_default_dtype(torch.double): actual = torch.divide(a[::2], b[::2], **kwargs) - self.assertEqual(actual, expect, exact_device=False, exact_dtype=exact_dtype) + self.assertEqual( + actual, expect, exact_device=False, exact_dtype=exact_dtype + ) # Tests that trying to add, inplace, a CUDA tensor to a CPU tensor # throws the correct error message @onlyCUDA def test_cross_device_inplace_error_msg(self, device): - a = torch.tensor(2.) 
- b = torch.tensor(2., device=device) - with self.assertRaisesRegex(RuntimeError, - "Expected all tensors to be on the same device"): + a = torch.tensor(2.0) + b = torch.tensor(2.0, device=device) + with self.assertRaisesRegex( + RuntimeError, "Expected all tensors to be on the same device" + ): a += b # TODO: refactor this test into a more generic one, it's parked here currently @@ -872,7 +1130,7 @@ class TestBinaryUfuncs(TestCase): binary_inputs = (a, b) unary_ops = (torch.ceil, torch.exp) binary_ops = (torch.add, torch.sub) - for op in (unary_ops + binary_ops): + for op in unary_ops + binary_ops: with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") inputs = unary_inputs if op in unary_ops else binary_inputs @@ -896,30 +1154,29 @@ class TestBinaryUfuncs(TestCase): t -= 1 t *= 1 t /= 1 - with self.assertWarnsOnceRegex(UserWarning, 'floor_divide'): + with self.assertWarnsOnceRegex(UserWarning, "floor_divide"): t //= 1 t %= 1 self.assertEqual(expected, t.data_ptr()) - def check_internal_mem_overlap(self, inplace_op, num_inputs, - dtype, device, - expected_failure=False): + def check_internal_mem_overlap( + self, inplace_op, num_inputs, dtype, device, expected_failure=False + ): if isinstance(inplace_op, str): inplace_op = getattr(torch.Tensor, inplace_op) input = torch.randn(1, dtype=dtype, device=device).expand(3, 3) - inputs = [input] + [torch.randn_like(input) - for i in range(num_inputs - 1)] + inputs = [input] + [torch.randn_like(input) for i in range(num_inputs - 1)] if not expected_failure: - with self.assertRaisesRegex(RuntimeError, 'single memory location'): + with self.assertRaisesRegex(RuntimeError, "single memory location"): inplace_op(*inputs) else: with self.assertRaises(AssertionError): - with self.assertRaisesRegex(RuntimeError, 'single memory location'): + with self.assertRaisesRegex(RuntimeError, "single memory location"): inplace_op(*inputs) - def unary_check_input_output_mem_overlap(self, data, sz, op, - expected_failure=False): - + def unary_check_input_output_mem_overlap( + self, data, sz, op, expected_failure=False + ): def _test(op, output, input): output_exp = torch.empty_like(output) op(input, out=output_exp) @@ -928,93 +1185,114 @@ class TestBinaryUfuncs(TestCase): # output is identical to input: _test(op, output=data[0:sz], input=data[0:sz]) # output and input are independent: - _test(op, output=data[0:sz], input=data[sz:2 * sz]) + _test(op, output=data[0:sz], input=data[sz : 2 * sz]) # output partially overlaps with input: if not expected_failure: - with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): - _test(op, data[0:sz], data[1:sz + 1]) + with self.assertRaisesRegex(RuntimeError, "unsupported operation"): + _test(op, data[0:sz], data[1 : sz + 1]) else: with self.assertRaises(AssertionError): - with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): - _test(op, data[0:sz], data[1:sz + 1]) + with self.assertRaisesRegex(RuntimeError, "unsupported operation"): + _test(op, data[0:sz], data[1 : sz + 1]) - def binary_check_input_output_mem_overlap(self, op, device, - expected_failure=False): + def binary_check_input_output_mem_overlap(self, op, device, expected_failure=False): sz = 3 data = torch.randn(2 * sz, device=device) other = torch.randn(sz, device=device) self.unary_check_input_output_mem_overlap( - data, sz, lambda input, out: op(other, input, out=out), - expected_failure=expected_failure) + data, + sz, + lambda input, out: op(other, input, out=out), + expected_failure=expected_failure, + ) 
self.unary_check_input_output_mem_overlap( - data, sz, lambda input, out: op(input, other, out=out), - expected_failure=expected_failure) + data, + sz, + lambda input, out: op(input, other, out=out), + expected_failure=expected_failure, + ) @dtypes(torch.double) def test_binary_op_mem_overlap(self, device, dtype): ops = [ - ("add", True, True, 'cpu'), - ("add", True, True, 'cuda'), - ("mul", True, True, 'cpu'), - ("mul", True, True, 'cuda'), - ("sub", True, True, 'cpu'), - ("sub", True, True, 'cuda'), - ("div", True, True, 'cpu'), - ("div", True, True, 'cuda'), - ("pow", True, True, 'cpu'), - ("pow", True, True, 'cuda'), - ("fmod", True, True, 'cpu'), - ("fmod", True, True, 'cuda'), - ("atan2", True, True, 'cpu'), - ("atan2", True, True, 'cuda'), - ("hypot", True, True, 'cpu'), - ("hypot", True, True, 'cuda'), - ("igamma", True, True, 'cpu'), - ("igamma", True, True, 'cuda'), - ("igammac", True, True, 'cpu'), - ("igammac", True, True, 'cuda'), - ("nextafter", True, True, 'cpu'), - ("nextafter", True, True, 'cuda'), - ("le", True, True, 'cpu'), - ("le", True, True, 'cuda'), - ("lt", True, True, 'cpu'), - ("lt", True, True, 'cuda'), - ("ge", True, True, 'cpu'), - ("ge", True, True, 'cuda'), - ("gt", True, True, 'cpu'), - ("gt", True, True, 'cuda'), - ("eq", True, True, 'cpu'), - ("eq", True, True, 'cuda'), - ("ne", True, True, 'cpu'), - ("ne", True, True, 'cuda'), - ("logical_and", True, True, 'cpu'), - ("logical_and", True, True, 'cuda'), - ("logical_or", True, True, 'cpu'), - ("logical_or", True, True, 'cuda'), - ("logical_xor", True, True, 'cpu'), - ("logical_xor", True, True, 'cuda'), + ("add", True, True, "cpu"), + ("add", True, True, "cuda"), + ("mul", True, True, "cpu"), + ("mul", True, True, "cuda"), + ("sub", True, True, "cpu"), + ("sub", True, True, "cuda"), + ("div", True, True, "cpu"), + ("div", True, True, "cuda"), + ("pow", True, True, "cpu"), + ("pow", True, True, "cuda"), + ("fmod", True, True, "cpu"), + ("fmod", True, True, "cuda"), + ("atan2", True, True, "cpu"), + ("atan2", True, True, "cuda"), + ("hypot", True, True, "cpu"), + ("hypot", True, True, "cuda"), + ("igamma", True, True, "cpu"), + ("igamma", True, True, "cuda"), + ("igammac", True, True, "cpu"), + ("igammac", True, True, "cuda"), + ("nextafter", True, True, "cpu"), + ("nextafter", True, True, "cuda"), + ("le", True, True, "cpu"), + ("le", True, True, "cuda"), + ("lt", True, True, "cpu"), + ("lt", True, True, "cuda"), + ("ge", True, True, "cpu"), + ("ge", True, True, "cuda"), + ("gt", True, True, "cpu"), + ("gt", True, True, "cuda"), + ("eq", True, True, "cpu"), + ("eq", True, True, "cuda"), + ("ne", True, True, "cpu"), + ("ne", True, True, "cuda"), + ("logical_and", True, True, "cpu"), + ("logical_and", True, True, "cuda"), + ("logical_or", True, True, "cpu"), + ("logical_or", True, True, "cuda"), + ("logical_xor", True, True, "cpu"), + ("logical_xor", True, True, "cuda"), ] - for (fn, has_input_output_mem_overlap_check, - has_internal_mem_overlap_check, dev) in ops: + for ( + fn, + has_input_output_mem_overlap_check, + has_internal_mem_overlap_check, + dev, + ) in ops: if dev != device: continue out_op = getattr(torch, fn) - inplace_op = getattr(torch.Tensor, fn + '_') + inplace_op = getattr(torch.Tensor, fn + "_") self.check_internal_mem_overlap( - inplace_op, 2, dtype, device, - expected_failure=not has_internal_mem_overlap_check) + inplace_op, + 2, + dtype, + device, + expected_failure=not has_internal_mem_overlap_check, + ) - self.binary_check_input_output_mem_overlap(out_op, device, - expected_failure=not 
has_input_output_mem_overlap_check) + self.binary_check_input_output_mem_overlap( + out_op, device, expected_failure=not has_input_output_mem_overlap_check + ) def _do_pow_for_exponents(self, m1, exponents, pow_fn, atol): for num in exponents: - if isinstance(num, int) and num < 0 and not m1.is_floating_point() and not m1.is_complex(): - with self.assertRaisesRegex(RuntimeError, - r'Integers to negative integer powers are not allowed\.'): + if ( + isinstance(num, int) + and num < 0 + and not m1.is_floating_point() + and not m1.is_complex() + ): + with self.assertRaisesRegex( + RuntimeError, + r"Integers to negative integer powers are not allowed\.", + ): torch.pow(m1[4], num) else: # base - tensor, exponent - number @@ -1037,7 +1315,9 @@ class TestBinaryUfuncs(TestCase): # scalar ** tensor to enforce correct handling of dtypes for __rpow__(). expected_dtype = torch.result_type(num, m1) res1 = num ** m1[4] - res2 = torch.tensor(num, dtype=expected_dtype, device=m1.device) ** m1[4] + res2 = ( + torch.tensor(num, dtype=expected_dtype, device=m1.device) ** m1[4] + ) self.assertEqual(res1, res2) self.assertEqual(res1.dtype, expected_dtype) @@ -1045,14 +1325,27 @@ class TestBinaryUfuncs(TestCase): def test_pow(self, device, dtype): m1 = torch.empty(0, dtype=dtype, device=device) if m1.is_floating_point() or m1.is_complex(): - m1 = make_tensor((100, 100), low=0, high=1, dtype=dtype, device=device) + 0.5 + m1 = ( + make_tensor((100, 100), low=0, high=1, dtype=dtype, device=device) + 0.5 + ) else: # math.pow will overflow and throw exceptions for large integers range_high = 4 if dtype in (torch.int8, torch.uint8) else 10 - m1 = make_tensor((100, 100), low=1, high=range_high, dtype=dtype, device=device) + m1 = make_tensor( + (100, 100), low=1, high=range_high, dtype=dtype, device=device + ) exponents = [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3] - complex_exponents = [-2.5j, -1.0j, 0j, 1.0j, 2.5j, 1.0 + 1.0j, -1.0 - 1.5j, 3.3j] + complex_exponents = [ + -2.5j, + -1.0j, + 0j, + 1.0j, + 2.5j, + 1.0 + 1.0j, + -1.0 - 1.5j, + 3.3j, + ] if m1.is_complex(): self._do_pow_for_exponents(m1, exponents + complex_exponents, pow, 10e-4) else: @@ -1086,7 +1379,11 @@ class TestBinaryUfuncs(TestCase): try: np_res = np.power(to_np(base), to_np(np_exponent)) - expected = torch.from_numpy(np_res) if isinstance(np_res, np.ndarray) else torch.tensor(np_res, dtype=base.dtype) + expected = ( + torch.from_numpy(np_res) + if isinstance(np_res, np.ndarray) + else torch.tensor(np_res, dtype=base.dtype) + ) except ValueError as e: err_msg = "Integers to negative integer powers are not allowed." self.assertEqual(str(e), err_msg) @@ -1095,7 +1392,7 @@ class TestBinaryUfuncs(TestCase): lambda: base.pow(exponent), lambda: base.pow_(exponent), lambda: torch.pow(base, exponent), - lambda: torch.pow(base, exponent, out=out) + lambda: torch.pow(base, exponent, out=out), ] for test_case in test_cases: self.assertRaisesRegex(RuntimeError, err_msg, test_case) @@ -1106,16 +1403,24 @@ class TestBinaryUfuncs(TestCase): actual = base.clone() # When base is a 0-dim cpu tensor and exp is a cuda tensor, we exp `pow` to work but `pow_` to fail, since # `pow` will try to create the output tensor on a cuda device, but `pow_` needs to use the cpu tensor as the output - if (isinstance(exponent, torch.Tensor) and base.dim() == 0 and base.device.type == 'cpu' and - exponent.device.type == 'cuda'): - regex = 'Expected all tensors to be on the same device, but found at least two devices, cuda.* and cpu!' 
+ if ( + isinstance(exponent, torch.Tensor) + and base.dim() == 0 + and base.device.type == "cpu" + and exponent.device.type == "cuda" + ): + regex = "Expected all tensors to be on the same device, but found at least two devices, cuda.* and cpu!" self.assertRaisesRegex(RuntimeError, regex, base.pow_, exponent) elif torch.can_cast(torch.result_type(base, exponent), base.dtype): actual2 = actual.pow_(exponent) self.assertEqual(actual, expected) self.assertEqual(actual2, expected) else: - self.assertRaisesRegex(RuntimeError, "Found dtype \\w+ but expected \\w+", lambda: actual.pow_(exponent)) + self.assertRaisesRegex( + RuntimeError, + "Found dtype \\w+ but expected \\w+", + lambda: actual.pow_(exponent), + ) actual = torch.pow(base, exponent) self.assertEqual(actual, expected.to(actual)) @@ -1129,13 +1434,16 @@ class TestBinaryUfuncs(TestCase): # a lambada that switches the inputs, because we also want to test samples inputs # where the second input is a scalar. The wrapper would need some more logic. def test_pow_scalar_base(self, device): - a = torch.arange(1, 13, dtype=torch.double, device=device).view(3, 4).requires_grad_() + a = ( + torch.arange(1, 13, dtype=torch.double, device=device) + .view(3, 4) + .requires_grad_() + ) gradcheck(lambda a: torch.pow(2, a), (a,)) # Tests pow() for integral, floating-type tensors, with integral, floating-type # exponents (tensor or scalar), respectively. noncontiguous tensors are also tested. def test_int_and_float_pow(self, device): - def _test_int_and_float_pow(dt, low, high, dev): test_cases = ( ((4, 4), 0, (4, 1)), @@ -1147,23 +1455,59 @@ class TestBinaryUfuncs(TestCase): ((), 2, ()), ) for base_shape, exp_scalar, exp_shape in test_cases: - base_tensor = make_tensor(base_shape, dtype=dt, device=dev, low=low, high=high) + base_tensor = make_tensor( + base_shape, dtype=dt, device=dev, low=low, high=high + ) # int tensors don't take negative exponents - if dt in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]: - exp_tensor = make_tensor(exp_shape, dtype=dt, device=dev, low=0, high=high) + if dt in [ + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + ]: + exp_tensor = make_tensor( + exp_shape, dtype=dt, device=dev, low=0, high=high + ) else: - exp_tensor = make_tensor(exp_shape, dtype=dt, device=dev, low=low, high=high) + exp_tensor = make_tensor( + exp_shape, dtype=dt, device=dev, low=low, high=high + ) self._test_pow(base_tensor, exp_scalar) self._test_pow(base_tensor, exp_tensor) # test non-contiguous tensors as well - base_tensor = make_tensor(base_shape, dtype=dt, device=dev, low=low, high=high, - noncontiguous=True) - if dt in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]: - exp_tensor = make_tensor(exp_shape, dtype=dt, device=dev, low=0, high=high, - noncontiguous=True) + base_tensor = make_tensor( + base_shape, + dtype=dt, + device=dev, + low=low, + high=high, + noncontiguous=True, + ) + if dt in [ + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + ]: + exp_tensor = make_tensor( + exp_shape, + dtype=dt, + device=dev, + low=0, + high=high, + noncontiguous=True, + ) else: - exp_tensor = make_tensor(exp_shape, dtype=dt, device=dev, low=low, high=high, - noncontiguous=True) + exp_tensor = make_tensor( + exp_shape, + dtype=dt, + device=dev, + low=low, + high=high, + noncontiguous=True, + ) self._test_pow(base_tensor, exp_scalar) self._test_pow(base_tensor, exp_tensor) @@ -1172,12 +1516,12 @@ class TestBinaryUfuncs(TestCase): _test_int_and_float_pow(torch.int16, 
-5, 5, device) _test_int_and_float_pow(torch.int64, -10, 10, device) _test_int_and_float_pow(torch.int32, -10, 10, device) - _test_int_and_float_pow(torch.float16, 0., 5., device) - _test_int_and_float_pow(torch.float32, 0., 10., device) - _test_int_and_float_pow(torch.float64, 0., 10., device) + _test_int_and_float_pow(torch.float16, 0.0, 5.0, device) + _test_int_and_float_pow(torch.float32, 0.0, 10.0, device) + _test_int_and_float_pow(torch.float64, 0.0, 10.0, device) # pow's output would have some NaNs as well - _test_int_and_float_pow(torch.float32, -10., 10., device) - _test_int_and_float_pow(torch.float64, -10., 10., device) + _test_int_and_float_pow(torch.float32, -10.0, 10.0, device) + _test_int_and_float_pow(torch.float64, -10.0, 10.0, device) # Tests that a Runtime error occurs when a base tensor cannot be resized # by pow's inplace variant due to PyTorch's broadcasting semantics. @@ -1188,19 +1532,33 @@ class TestBinaryUfuncs(TestCase): ((2, 1), (2, 2)), ((2, 2), (2, 1, 1)), ) - test_inputs = list((make_tensor(base_size, dtype=torch.float64, device=device, - high=10., low=0.), - make_tensor(exp_size, dtype=torch.float64, device=device, - high=10., low=0.)) - for base_size, exp_size in test_cases) + test_inputs = list( + ( + make_tensor( + base_size, dtype=torch.float64, device=device, high=10.0, low=0.0 + ), + make_tensor( + exp_size, dtype=torch.float64, device=device, high=10.0, low=0.0 + ), + ) + for base_size, exp_size in test_cases + ) for base, exponent in test_inputs: regex = "doesn't match the broadcast shape" self.assertRaisesRegex(RuntimeError, regex, base.pow_, exponent) def test_int_tensor_pow_neg_ints(self, device): - ints = [torch.iinfo(torch.int32).min, - -3, -2, -1, 0, 1, 2, 3, - torch.iinfo(torch.int32).max] + ints = [ + torch.iinfo(torch.int32).min, + -3, + -2, + -1, + 0, + 1, + 2, + 3, + torch.iinfo(torch.int32).max, + ] neg_ints = [torch.iinfo(torch.int32).min, -3, -2, -1] tensor = torch.tensor(ints, dtype=torch.int32, device=device) for pow in neg_ints: @@ -1215,16 +1573,17 @@ class TestBinaryUfuncs(TestCase): @dtypes(*[torch.float32, torch.float64]) def test_float_scalar_pow_float_tensor(self, device, dtype): - floats = [2.0, -3 / 2, -1.0, -1 / 2, -1 / 3, 0.0, - 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0] + floats = [2.0, -3 / 2, -1.0, -1 / 2, -1 / 3, 0.0, 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0] exponent_shapes = ( (1,), (2, 2), (2, 1), (2, 2, 2), ) - tensors = list(make_tensor(shape, dtype=dtype, device=device, low=0) - for shape in exponent_shapes) + tensors = list( + make_tensor(shape, dtype=dtype, device=device, low=0) + for shape in exponent_shapes + ) floats_tensor = torch.tensor(floats, dtype=dtype, device=device) for base in floats: self._test_pow(base, floats_tensor) @@ -1233,27 +1592,37 @@ class TestBinaryUfuncs(TestCase): @onlyCUDA def test_cuda_tensor_pow_scalar_tensor(self, device): - cuda_tensors = [torch.randn((3, 3), device=device), torch.tensor(3.0, device=device)] - scalar_tensors = [torch.tensor(5.0, device='cpu'), torch.tensor(-3), torch.tensor(1)] + cuda_tensors = [ + torch.randn((3, 3), device=device), + torch.tensor(3.0, device=device), + ] + scalar_tensors = [ + torch.tensor(5.0, device="cpu"), + torch.tensor(-3), + torch.tensor(1), + ] for base, exp in product(cuda_tensors, scalar_tensors): self._test_pow(base, exp) @onlyCUDA def test_cpu_tensor_pow_cuda_scalar_tensor(self, device): - cuda_tensors = [torch.tensor(5.0, device='cuda'), torch.tensor(-3, device='cuda')] + cuda_tensors = [ + torch.tensor(5.0, device="cuda"), + torch.tensor(-3, device="cuda"), 
+ ] for exp in cuda_tensors: - base = torch.randn((3, 3), device='cpu') - regex = 'Expected all tensors to be on the same device, but found at least two devices, cuda.* and cpu!' + base = torch.randn((3, 3), device="cpu") + regex = "Expected all tensors to be on the same device, but found at least two devices, cuda.* and cpu!" self.assertRaisesRegex(RuntimeError, regex, torch.pow, base, exp) for exp in cuda_tensors: # Binary ops with a cpu + cuda tensor are allowed if the cpu tensor has 0 dimension - base = torch.tensor(3.0, device='cpu') + base = torch.tensor(3.0, device="cpu") self._test_pow(base, exp) @onlyCUDA @dtypes(torch.complex64, torch.complex128) def test_pow_cuda_complex_extremal_failing(self, device, dtype): - t = torch.tensor(complex(-1., float('inf')), dtype=dtype, device=device) + t = torch.tensor(complex(-1.0, float("inf")), dtype=dtype, device=device) with self.assertRaises(AssertionError): cuda_out = t.pow(2) cpu_out = t.cpu().pow(2) @@ -1262,9 +1631,11 @@ class TestBinaryUfuncs(TestCase): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half)) def test_complex_scalar_pow_tensor(self, device, dtype): - complexes = [0.5j, 1. + 1.j, -1.5j, 2.2 - 1.6j, 1 + 0j] + complexes = [0.5j, 1.0 + 1.0j, -1.5j, 2.2 - 1.6j, 1 + 0j] first_exp = make_tensor((100,), dtype=dtype, device=device, low=-2, high=2) - second_exp = make_tensor((100,), dtype=dtype, device=device, low=-2, high=2, noncontiguous=True) + second_exp = make_tensor( + (100,), dtype=dtype, device=device, low=-2, high=2, noncontiguous=True + ) first_exp[0] = first_exp[10] = first_exp[20] = 0 second_exp[0] = second_exp[10] = second_exp[20] = 0 for base in complexes: @@ -1279,14 +1650,25 @@ class TestBinaryUfuncs(TestCase): for input in inputs: # We expect the computation to be performed in uint8 (overflowing to 0), and then cast to int64 input_tensor_uint8 = torch.tensor(input, dtype=torch.uint8, device=device) - out_uint8_computation = torch.pow(2, input_tensor_uint8, out=torch.tensor(0, dtype=torch.int64, device=device)) + out_uint8_computation = torch.pow( + 2, + input_tensor_uint8, + out=torch.tensor(0, dtype=torch.int64, device=device), + ) # Computation should run in int64, and not overflow input_tensor_int64 = torch.tensor(input, dtype=torch.int64, device=device) - out_int64_computation = torch.pow(2, input_tensor_int64, out=torch.tensor(0, dtype=torch.int64, device=device)) + out_int64_computation = torch.pow( + 2, + input_tensor_int64, + out=torch.tensor(0, dtype=torch.int64, device=device), + ) self.assertNotEqual(out_uint8_computation, out_int64_computation) - self.assertEqual(out_uint8_computation.to(dtype=torch.uint8), out_int64_computation.to(dtype=torch.uint8)) + self.assertEqual( + out_uint8_computation.to(dtype=torch.uint8), + out_int64_computation.to(dtype=torch.uint8), + ) def test_tensor_pow_tensor(self, device): def rotate(l, n): @@ -1306,26 +1688,24 @@ class TestBinaryUfuncs(TestCase): test_tensor_pow_tensor(ints, torch.int32, np.int32) test_tensor_pow_tensor(ints, torch.int64, np.int64) - floats = [-3.0, -2.0, -1.0, -1 / 2, -1 / 3, - 0.0, 1 / 3, 1 / 2, 1.0, 2.0, 3.0] + floats = [-3.0, -2.0, -1.0, -1 / 2, -1 / 3, 0.0, 1 / 3, 1 / 2, 1.0, 2.0, 3.0] test_tensor_pow_tensor(floats, torch.float16, np.float16) test_tensor_pow_tensor(floats, torch.float32, np.float32) test_tensor_pow_tensor(floats, torch.float64, np.float64) - def test_logical_xor_with_nontrivial_alignment(self, device): # test tensor that is not aligned to multiple of 16 bytes size = 128 - a = (torch.randn(size, device=device) > 
0) - b = (torch.randn(size, device=device) > 0) - c = (torch.randn(size, device=device) > 0) + a = torch.randn(size, device=device) > 0 + b = torch.randn(size, device=device) > 0 + c = torch.randn(size, device=device) > 0 non_trivial_alignment = [1, 2, 4, 8, 15] for i in non_trivial_alignment: for j in non_trivial_alignment: for k in non_trivial_alignment: - a_ = a[i: 100 + i] - b_ = b[j: 100 + j] - c_ = c[k: 100 + k] + a_ = a[i : 100 + i] + b_ = b[j : 100 + j] + c_ = c[k : 100 + k] torch.logical_xor(a_, b_, out=c_) for x, y, z in zip(a_.tolist(), b_.tolist(), c_.tolist()): self.assertEqual(x ^ y, z) @@ -1348,7 +1728,7 @@ class TestBinaryUfuncs(TestCase): @deviceCountAtLeast(2) @onlyCUDA def test_cross_device_binary_ops(self, devices): - vals = (1., (2.,)) + vals = (1.0, (2.0,)) cpu_tensor = torch.randn(2, 2) def do_test(op, a, b): @@ -1361,11 +1741,18 @@ class TestBinaryUfuncs(TestCase): with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"): op(cpu_tensor, a) - for op in (operator.add, torch.add, - operator.sub, torch.sub, - operator.mul, torch.mul, - operator.truediv, torch.true_divide, - operator.floordiv, torch.floor_divide): + for op in ( + operator.add, + torch.add, + operator.sub, + torch.sub, + operator.mul, + torch.mul, + operator.truediv, + torch.true_divide, + operator.floordiv, + torch.floor_divide, + ): for a, b in product(vals, vals): a = torch.tensor(a, device=devices[0]) b = torch.tensor(b, device=devices[1]) @@ -1378,7 +1765,7 @@ class TestBinaryUfuncs(TestCase): @deviceCountAtLeast(2) @onlyCUDA def test_binary_op_scalar_device_unspecified(self, devices): - scalar_val = torch.tensor(1.) + scalar_val = torch.tensor(1.0) for default_device in devices: with torch.cuda.device(default_device): for device in devices: @@ -1397,7 +1784,7 @@ class TestBinaryUfuncs(TestCase): # the quotient. See https://github.com/pytorch/pytorch/issues/43874. 
def _scalar_helper(python_op, torch_op): for a, b in product(range(-10, 10), range(-10, 10)): - for op in (lambda x: x * .5, lambda x: math.floor(x)): + for op in (lambda x: x * 0.5, lambda x: math.floor(x)): a = op(a) b = op(b) @@ -1424,7 +1811,7 @@ class TestBinaryUfuncs(TestCase): _scalar_helper(operator.truediv, operator.truediv) _scalar_helper(operator.truediv, torch.true_divide) - with self.assertWarnsOnceRegex(UserWarning, 'floor_divide'): + with self.assertWarnsOnceRegex(UserWarning, "floor_divide"): _scalar_helper(lambda a, b: math.trunc(a / b), operator.floordiv) _scalar_helper(lambda a, b: math.trunc(a / b), torch.floor_divide) @@ -1442,7 +1829,7 @@ class TestBinaryUfuncs(TestCase): scripted_div = torch.jit.script(_wrapped_div) scripted_floordiv = torch.jit.script(_wrapped_floordiv) for a, b in product(range(-10, 10), range(-10, 10)): - for op in (lambda x: x * .5, lambda x: math.floor(x)): + for op in (lambda x: x * 0.5, lambda x: math.floor(x)): a = op(a) b = op(b) @@ -1456,7 +1843,7 @@ class TestBinaryUfuncs(TestCase): b_t = torch.tensor(b, device=device) self.assertEqual(scripted_div(a_t, b_t), expected_div) - with self.assertWarnsOnceRegex(UserWarning, 'floor_divide'): + with self.assertWarnsOnceRegex(UserWarning, "floor_divide"): self.assertEqual(scripted_floordiv(a_t, b_t), expected_truncdiv) # Creates jitted functions of one tensor @@ -1481,13 +1868,13 @@ class TestBinaryUfuncs(TestCase): scripted_rfloordiv_scalar = torch.jit.script(_wrapped_rfloordiv_scalar) for a in range(-10, 10): - for op in (lambda x: x * .5, lambda x: math.floor(x)): + for op in (lambda x: x * 0.5, lambda x: math.floor(x)): a = op(a) a_t = torch.tensor(a, device=device) self.assertEqual(a / 5, scripted_div_scalar(a_t)) - with self.assertWarnsOnceRegex(UserWarning, 'floor_divide'): + with self.assertWarnsOnceRegex(UserWarning, "floor_divide"): self.assertEqual(math.trunc(a / 5), scripted_floordiv_scalar(a_t)) # Skips zero divisors @@ -1556,7 +1943,7 @@ class TestBinaryUfuncs(TestCase): scripted_floor_divide__scalar = torch.jit.script(_wrapped_floor_divide__scalar) for a, b in product(range(-10, 10), range(-10, 10)): - for op in (lambda x: x * .5, lambda x: math.floor(x)): + for op in (lambda x: x * 0.5, lambda x: math.floor(x)): a = op(a) b = op(b) @@ -1580,8 +1967,13 @@ class TestBinaryUfuncs(TestCase): self.assertEqual(tmp0.item(), expected_idiv) self.assertEqual(tmp1.item(), expected_idiv) - self.assertEqual(scripted_true_divide__tensor(a_t.clone(), b_t).item(), expected_idiv) - self.assertEqual(scripted_true_divide__scalar(a_t.clone()).item(), a / 5) + self.assertEqual( + scripted_true_divide__tensor(a_t.clone(), b_t).item(), + expected_idiv, + ) + self.assertEqual( + scripted_true_divide__scalar(a_t.clone()).item(), a / 5 + ) else: tmp = a_t.clone() with self.assertRaises(RuntimeError): @@ -1593,42 +1985,56 @@ class TestBinaryUfuncs(TestCase): with self.assertRaises(RuntimeError): scripted_true_divide__scalar(tmp) - if not a_t.is_floating_point() and b_t.is_floating_point(): # Inplace modification fails because a float tensor is required # if the divisor is a float tensor - with self.assertRaises(RuntimeError), self.assertWarnsOnceRegex(UserWarning, "floor_divide"): + with self.assertRaises(RuntimeError), self.assertWarnsOnceRegex( + UserWarning, "floor_divide" + ): a_t.clone().floor_divide_(b_t) - with self.assertRaises(RuntimeError), self.assertWarnsOnceRegex(UserWarning, "floor_divide"): + with self.assertRaises(RuntimeError), self.assertWarnsOnceRegex( + UserWarning, "floor_divide" + ): 
scripted_floor_divide_tensor(a_t.clone(), b_t) tmp = a_t.clone() - with self.assertRaises(RuntimeError), self.assertWarnsOnceRegex(UserWarning, "floor_divide"): + with self.assertRaises(RuntimeError), self.assertWarnsOnceRegex( + UserWarning, "floor_divide" + ): tmp //= b_t else: # Inplace modification is OK when both or neither tensor is # a float tensor with self.assertWarnsOnceRegex(UserWarning, "floor_divide"): - self.assertEqual(a_t.clone().floor_divide_(b_t).item(), expected_itruncdiv) - self.assertEqual(scripted_floor_divide__tensor(a_t.clone(), b_t).item(), expected_itruncdiv) + self.assertEqual( + a_t.clone().floor_divide_(b_t).item(), expected_itruncdiv + ) + self.assertEqual( + scripted_floor_divide__tensor(a_t.clone(), b_t).item(), + expected_itruncdiv, + ) tmp = a_t.clone() with self.assertWarnsOnceRegex(UserWarning, "floor_divide"): tmp //= b_t self.assertEqual(tmp.item(), expected_itruncdiv) with self.assertWarnsOnceRegex(UserWarning, "floor_divide"): - self.assertEqual(scripted_floor_divide__scalar(a_t), math.trunc(a / 5)) + self.assertEqual( + scripted_floor_divide__scalar(a_t), math.trunc(a / 5) + ) # Tests binary op equivalence with Python builtin ops # Also tests that reverse operations are equivalent to forward ops # NOTE: division ops are tested separately above def test_binary_ops_with_scalars(self, device): - for python_op, torch_op in ((operator.add, torch.add), - (operator.sub, torch.sub), - (operator.mul, torch.mul), - (operator.truediv, torch.div)): + for python_op, torch_op in ( + (operator.add, torch.add), + (operator.sub, torch.sub), + (operator.mul, torch.mul), + (operator.truediv, torch.div), + ): for a, b in product(range(-10, 10), range(-10, 10)): - for op in (lambda x: x * .5, lambda x: math.floor(x)): + for op in (lambda x: x * 0.5, lambda x: math.floor(x)): a = op(a) b = op(b) @@ -1645,29 +2051,56 @@ class TestBinaryUfuncs(TestCase): for args in product(vals, vals): first, second = args - first_scalar = first if not isinstance(first, torch.Tensor) else first.item() - second_scalar = second if not isinstance(second, torch.Tensor) else second.item() + first_scalar = ( + first + if not isinstance(first, torch.Tensor) + else first.item() + ) + second_scalar = ( + second + if not isinstance(second, torch.Tensor) + else second.item() + ) expected = python_op(first_scalar, second_scalar) self.assertEqual(expected, python_op(first, second)) self.assertEqual(expected, torch_op(first, second)) - @dtypes(*product(all_types_and(torch.half, torch.bfloat16, torch.bool), - all_types_and(torch.half, torch.bfloat16, torch.bool))) + @dtypes( + *product( + all_types_and(torch.half, torch.bfloat16, torch.bool), + all_types_and(torch.half, torch.bfloat16, torch.bool), + ) + ) def test_maximum_minimum_type_promotion(self, device, dtypes): a = torch.tensor((0, 1), device=device, dtype=dtypes[0]) b = torch.tensor((1, 0), device=device, dtype=dtypes[1]) - for op in (torch.maximum, torch.max, torch.fmax, torch.minimum, torch.min, torch.fmin): + for op in ( + torch.maximum, + torch.max, + torch.fmax, + torch.minimum, + torch.min, + torch.fmin, + ): result = op(a, b) self.assertEqual(result.dtype, torch.result_type(a, b)) @dtypes(*integral_types_and(torch.bool)) def test_maximum_minimum_int_and_bool(self, device, dtype): - ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum), - (torch.fmax, None, np.fmax), (torch.fmin, None, np.fmin)) + ops = ( + (torch.maximum, torch.max, np.maximum), + (torch.minimum, torch.min, np.minimum), + (torch.fmax, None, 
np.fmax), + (torch.fmin, None, np.fmin), + ) rng = np.random.default_rng() - a_np = np.array(rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype]) - b_np = np.array(rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype]) + a_np = np.array( + rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype] + ) + b_np = np.array( + rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype] + ) for torch_op, alias, numpy_op in ops: a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype) @@ -1689,8 +2122,12 @@ class TestBinaryUfuncs(TestCase): @precisionOverride({torch.bfloat16: 1e-2}) @dtypes(*(floating_types_and(torch.half, torch.bfloat16))) def test_maximum_minimum_float(self, device, dtype): - ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum), - (torch.fmax, None, np.fmax), (torch.fmin, None, np.fmin)) + ops = ( + (torch.maximum, torch.max, np.maximum), + (torch.minimum, torch.min, np.minimum), + (torch.fmax, None, np.fmax), + (torch.fmin, None, np.fmin), + ) if dtype == torch.bfloat16: a_np = np.random.randn(10).astype(np.float64) @@ -1719,10 +2156,32 @@ class TestBinaryUfuncs(TestCase): def test_maximum_minimum_float_nan_and_inf(self, device, dtype): # np.maximum and np.minimum functions compare input arrays element-wisely. # if one of the elements being compared is a NaN, then that element is returned. - ops = ((torch.maximum, torch.max, np.maximum), (torch.minimum, torch.min, np.minimum), - (torch.fmax, None, np.fmax), (torch.fmin, None, np.fmin)) - a_vals = (float('inf'), -float('inf'), float('nan'), float('inf'), float('nan'), float('nan'), 1, float('nan')) - b_vals = (-float('inf'), float('inf'), float('inf'), float('nan'), float('nan'), 0, float('nan'), -5) + ops = ( + (torch.maximum, torch.max, np.maximum), + (torch.minimum, torch.min, np.minimum), + (torch.fmax, None, np.fmax), + (torch.fmin, None, np.fmin), + ) + a_vals = ( + float("inf"), + -float("inf"), + float("nan"), + float("inf"), + float("nan"), + float("nan"), + 1, + float("nan"), + ) + b_vals = ( + -float("inf"), + float("inf"), + float("inf"), + float("nan"), + float("nan"), + 0, + float("nan"), + -5, + ) if dtype == torch.bfloat16: a_np = np.array(a_vals, dtype=np.float64) b_np = np.array(b_vals, dtype=np.float64) @@ -1751,16 +2210,32 @@ class TestBinaryUfuncs(TestCase): self.assertEqual(tensor_result, numpy_result) self.assertEqual(out, numpy_result) - @dtypes(*product(complex_types(), all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))) + @dtypes( + *product( + complex_types(), + all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), + ) + ) def test_maximum_minimum_complex(self, device, dtypes): - for torch_op in (torch.maximum, torch.minimum, torch.max, torch.min, torch.fmax, torch.fmin): - with self.assertRaisesRegex(RuntimeError, '.+not implemented for.+'): - torch_op(torch.ones(1, device=device, dtype=dtypes[0]), - torch.ones(1, device=device, dtype=dtypes[1])) + for torch_op in ( + torch.maximum, + torch.minimum, + torch.max, + torch.min, + torch.fmax, + torch.fmin, + ): + with self.assertRaisesRegex(RuntimeError, ".+not implemented for.+"): + torch_op( + torch.ones(1, device=device, dtype=dtypes[0]), + torch.ones(1, device=device, dtype=dtypes[1]), + ) - with self.assertRaisesRegex(RuntimeError, '.+not implemented for.+'): - torch_op(torch.ones(1, device=device, dtype=dtypes[1]), - torch.ones(1, device=device, dtype=dtypes[0])) + with self.assertRaisesRegex(RuntimeError, ".+not 
implemented for.+"): + torch_op( + torch.ones(1, device=device, dtype=dtypes[1]), + torch.ones(1, device=device, dtype=dtypes[0]), + ) @onlyCUDA def test_maximum_minimum_cross_device(self, device): @@ -1769,12 +2244,14 @@ class TestBinaryUfuncs(TestCase): ops = (torch.maximum, torch.minimum) for torch_op in ops: - with self.assertRaisesRegex(RuntimeError, - "Expected all tensors to be on the same device"): + with self.assertRaisesRegex( + RuntimeError, "Expected all tensors to be on the same device" + ): torch_op(a, b) - with self.assertRaisesRegex(RuntimeError, - "Expected all tensors to be on the same device"): + with self.assertRaisesRegex( + RuntimeError, "Expected all tensors to be on the same device" + ): torch_op(b, a) # test cuda tensor and cpu scalar @@ -1793,8 +2270,12 @@ class TestBinaryUfuncs(TestCase): self.assertEqual(tensor_result_1, numpy_result_1) self.assertEqual(tensor_result_2, numpy_result_2) - @dtypes(*product(floating_types_and(torch.half, torch.bfloat16), - floating_types_and(torch.half, torch.bfloat16))) + @dtypes( + *product( + floating_types_and(torch.half, torch.bfloat16), + floating_types_and(torch.half, torch.bfloat16), + ) + ) def test_maximum_and_minimum_subgradient(self, device, dtypes): def run_test(f, a, b, expected_a_grad, expected_b_grad): a = torch.tensor(a, requires_grad=True, device=device, dtype=dtypes[0]) @@ -1804,8 +2285,20 @@ class TestBinaryUfuncs(TestCase): self.assertEqual(a.grad, expected_a_grad) self.assertEqual(b.grad, expected_b_grad) - run_test(torch.maximum, [0., 1., 2.], [1., 1., 1.], [0., 0.5, 1.], [1., 0.5, 0.]) - run_test(torch.minimum, [0., 1., 2.], [1., 1., 1.], [1., 0.5, 0.], [0., 0.5, 1.]) + run_test( + torch.maximum, + [0.0, 1.0, 2.0], + [1.0, 1.0, 1.0], + [0.0, 0.5, 1.0], + [1.0, 0.5, 0.0], + ) + run_test( + torch.minimum, + [0.0, 1.0, 2.0], + [1.0, 1.0, 1.0], + [1.0, 0.5, 0.0], + [0.0, 0.5, 1.0], + ) def test_maximum_minimum_forward_ad_float32(self, device): # TODO: This should really be covered by OpInfo but it isn't. 
The problem @@ -1844,7 +2337,9 @@ class TestBinaryUfuncs(TestCase): self.assertEqual(x * y, 4.5) self.assertEqual(y * x, 4.5) - with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"): + with self.assertRaisesRegex( + RuntimeError, "can't be cast to the desired output type" + ): y *= x x *= y self.assertEqual(x, 4.5) @@ -1868,7 +2363,7 @@ class TestBinaryUfuncs(TestCase): if dtype == torch.bool: self.assertRaises(RuntimeError, lambda: m1 - m2) - elif (dtype == torch.bfloat16 or dtype == torch.half): + elif dtype == torch.bfloat16 or dtype == torch.half: # bfloat16 has a lower precision so we have to have a separate check for it self.assertEqual(m1 - m2, diff, atol=0.01, rtol=0) else: @@ -1904,27 +2399,43 @@ class TestBinaryUfuncs(TestCase): b = torch.rand(1000, dtype=dtype, device=device) # 0:250: a -- nan, b -- not nan - a[:250] = float('nan') + a[:250] = float("nan") # 250:500: a -- not nan, b -- nan - b[250:500] = float('nan') + b[250:500] = float("nan") # 500:750: a and b both nan - a[500:750] = float('nan') - b[500:750] = float('nan') + a[500:750] = float("nan") + b[500:750] = float("nan") # 750:1000: neither nan ma = torch.max(a, b) mi = torch.min(a, b) for i in range(750): - self.assertTrue(torch.isnan(ma[i]), "max(a, b): {}, a: {}, b: {}".format(ma[i], a[i], b[i])) - self.assertTrue(torch.isnan(mi[i]), "min(a, b): {}, a: {}, b: {}".format(mi[i], a[i], b[i])) + self.assertTrue( + torch.isnan(ma[i]), + "max(a, b): {}, a: {}, b: {}".format(ma[i], a[i], b[i]), + ) + self.assertTrue( + torch.isnan(mi[i]), + "min(a, b): {}, a: {}, b: {}".format(mi[i], a[i], b[i]), + ) for i in range(750, 1000): - self.assertFalse(torch.isnan(ma[i]), "max(a, b): {}, a: {}, b: {}".format(ma[i], a[i], b[i])) - self.assertFalse(torch.isnan(mi[i]), "min(a, b): {}, a: {}, b: {}".format(mi[i], a[i], b[i])) + self.assertFalse( + torch.isnan(ma[i]), + "max(a, b): {}, a: {}, b: {}".format(ma[i], a[i], b[i]), + ) + self.assertFalse( + torch.isnan(mi[i]), + "min(a, b): {}, a: {}, b: {}".format(mi[i], a[i], b[i]), + ) - @dtypes(*product(all_types_and(torch.half, torch.bfloat16, torch.bool), - all_types_and(torch.half, torch.bfloat16, torch.bool))) + @dtypes( + *product( + all_types_and(torch.half, torch.bfloat16, torch.bool), + all_types_and(torch.half, torch.bfloat16, torch.bool), + ) + ) def test_copysign(self, device, dtypes): def _test_copysign_numpy(a, b): torch_result = torch.copysign(a, b) @@ -1956,8 +2467,10 @@ class TestBinaryUfuncs(TestCase): # Special case: NaN conversions between FP32 and FP16 is not bitwise # equivalent to pass this assertion. 
if a.dtype != torch.float16 and b.dtype != torch.float16: - self.assertEqual(torch.copysign(torch.tensor(1.0), torch_result), - torch.copysign(torch.tensor(1.0), expected)) + self.assertEqual( + torch.copysign(torch.tensor(1.0), torch_result), + torch.copysign(torch.tensor(1.0), expected), + ) # Compare Result with NumPy # Type promotion @@ -1975,52 +2488,76 @@ class TestBinaryUfuncs(TestCase): _test_copysign_numpy(a, b) # 0.0/-0.0/inf/-inf/nan - cases = [0.0, -0.0, float('inf'), float('-inf'), float('nan')] + cases = [0.0, -0.0, float("inf"), float("-inf"), float("nan")] # torch.bfloat16 can not hold '-nan' # torch.half can not hold '-nan' on CUDA types = [torch.float32, torch.float64] - if device == 'cpu': + if device == "cpu": types.append(torch.float16) if dtypes[0] in types: b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9) for case in cases: - _test_copysign_numpy(torch.tensor([case], device=device, dtype=dtypes[0]), b) + _test_copysign_numpy( + torch.tensor([case], device=device, dtype=dtypes[0]), b + ) if dtypes[1] in floating_types_and(torch.half, torch.bfloat16): a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9) for case in cases: - _test_copysign_numpy(a, torch.tensor([case], device=device, dtype=dtypes[1])) + _test_copysign_numpy( + a, torch.tensor([case], device=device, dtype=dtypes[1]) + ) - @dtypes(*product(floating_types_and(torch.half, torch.bfloat16), - floating_types_and(torch.half, torch.bfloat16))) + @dtypes( + *product( + floating_types_and(torch.half, torch.bfloat16), + floating_types_and(torch.half, torch.bfloat16), + ) + ) def test_copysign_subgradient(self, device, dtypes): # Input is 0.0 - x = torch.tensor([0.0, 0.0, 0.0], dtype=dtypes[0], device=device, requires_grad=True) - y = torch.tensor([-1.0, 0.0, 1.0], dtype=dtypes[1], device=device, requires_grad=True) + x = torch.tensor( + [0.0, 0.0, 0.0], dtype=dtypes[0], device=device, requires_grad=True + ) + y = torch.tensor( + [-1.0, 0.0, 1.0], dtype=dtypes[1], device=device, requires_grad=True + ) out = torch.copysign(x, y) out.sum().backward() self.assertEqual(x.grad.tolist(), [0.0, 0.0, 0.0]) self.assertEqual(y.grad.tolist(), [0.0] * 3) # Input is -0.0 - x = torch.tensor([-0.0, -0.0, -0.0], dtype=dtypes[0], device=device, requires_grad=True) - y = torch.tensor([-1.0, 0.0, 1.0], dtype=dtypes[1], device=device, requires_grad=True) + x = torch.tensor( + [-0.0, -0.0, -0.0], dtype=dtypes[0], device=device, requires_grad=True + ) + y = torch.tensor( + [-1.0, 0.0, 1.0], dtype=dtypes[1], device=device, requires_grad=True + ) out = torch.copysign(x, y) out.sum().backward() self.assertEqual(x.grad.tolist(), [0.0, 0.0, 0.0]) self.assertEqual(y.grad.tolist(), [0.0] * 3) # Other is 0.0 - x = torch.tensor([-1.0, 0.0, 1.0], dtype=dtypes[0], device=device, requires_grad=True) - y = torch.tensor([0.0, 0.0, 0.0], dtype=dtypes[1], device=device, requires_grad=True) + x = torch.tensor( + [-1.0, 0.0, 1.0], dtype=dtypes[0], device=device, requires_grad=True + ) + y = torch.tensor( + [0.0, 0.0, 0.0], dtype=dtypes[1], device=device, requires_grad=True + ) out = torch.copysign(x, y) out.sum().backward() self.assertEqual(x.grad.tolist(), [-1.0, 0.0, 1.0]) self.assertEqual(y.grad.tolist(), [0.0] * 3) # Other is -0.0 - x = torch.tensor([-1.0, 0.0, 1.0], dtype=dtypes[0], device=device, requires_grad=True) - y = torch.tensor([-0.0, -0.0, -0.0], dtype=dtypes[1], device=device, requires_grad=True) + x = torch.tensor( + [-1.0, 0.0, 1.0], dtype=dtypes[0], device=device, requires_grad=True + ) + y = 
torch.tensor( + [-0.0, -0.0, -0.0], dtype=dtypes[1], device=device, requires_grad=True + ) out = torch.copysign(x, y) out.sum().backward() self.assertEqual(x.grad.tolist(), [1.0, 0.0, -1.0]) @@ -2028,9 +2565,10 @@ class TestBinaryUfuncs(TestCase): @dtypes(torch.bfloat16, torch.float) def test_div(self, device, dtype): - for op, method, inplace in ((torch.div, torch.Tensor.div, torch.Tensor.div_), - (torch.true_divide, torch.Tensor.true_divide, - torch.Tensor.true_divide_)): + for op, method, inplace in ( + (torch.div, torch.Tensor.div, torch.Tensor.div_), + (torch.true_divide, torch.Tensor.true_divide, torch.Tensor.true_divide_), + ): m1 = torch.randn(10, 10, dtype=torch.float, device=device).to(dtype=dtype) res1 = m1.clone() inplace(res1[:, 3], 2) @@ -2041,40 +2579,48 @@ class TestBinaryUfuncs(TestCase): if dtype == torch.bfloat16: a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device) - a2 = torch.tensor([2., 2.], dtype=dtype, device=device) - self.assertEqual(op(a1, a2), - torch.tensor([2.1, 3.1], dtype=dtype, device=device), - atol=0.01, rtol=0) + a2 = torch.tensor([2.0, 2.0], dtype=dtype, device=device) + self.assertEqual( + op(a1, a2), + torch.tensor([2.1, 3.1], dtype=dtype, device=device), + atol=0.01, + rtol=0, + ) self.assertEqual(method(a1, a2), op(a1, a2)) @dtypes(torch.bfloat16, torch.float) def test_true_divide_out(self, device, dtype): a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device) - a2 = torch.tensor([2., 2.], dtype=dtype, device=device) + a2 = torch.tensor([2.0, 2.0], dtype=dtype, device=device) res = torch.empty_like(a1) - self.assertEqual(torch.true_divide(a1, a2, out=res), - torch.tensor([2.1, 3.1], dtype=dtype, device=device), - atol=0.01, rtol=0) + self.assertEqual( + torch.true_divide(a1, a2, out=res), + torch.tensor([2.1, 3.1], dtype=dtype, device=device), + atol=0.01, + rtol=0, + ) @onlyCUDA @dtypes(torch.half) def test_divmul_scalar(self, device, dtype): - x = torch.tensor(100., device=device, dtype=dtype) + x = torch.tensor(100.0, device=device, dtype=dtype) x_ref = x.float() scale = 1e5 res = x.div(scale) expected = x_ref.div(scale) - self.assertEqual(res, expected.to(dtype), atol=0., rtol=0.) + self.assertEqual(res, expected.to(dtype), atol=0.0, rtol=0.0) x = torch.tensor(1e-5, device=device, dtype=dtype) x_ref = x.float() res = x.mul(scale) expected = x_ref.mul(scale) - self.assertEqual(res, expected.to(dtype), atol=0., rtol=0.) + self.assertEqual(res, expected.to(dtype), atol=0.0, rtol=0.0) res = scale * x - self.assertEqual(res, expected.to(dtype), atol=0., rtol=0.) 
+ self.assertEqual(res, expected.to(dtype), atol=0.0, rtol=0.0) - @dtypesIfCUDA(*set(get_all_math_dtypes('cuda')) - {torch.complex64, torch.complex128}) - @dtypes(*set(get_all_math_dtypes('cpu')) - {torch.complex64, torch.complex128}) + @dtypesIfCUDA( + *set(get_all_math_dtypes("cuda")) - {torch.complex64, torch.complex128} + ) + @dtypes(*set(get_all_math_dtypes("cpu")) - {torch.complex64, torch.complex128}) def test_floor_divide_tensor(self, device, dtype): x = torch.randn(10, device=device).mul(30).to(dtype) y = torch.arange(1, 11, dtype=dtype, device=device) @@ -2086,14 +2632,18 @@ class TestBinaryUfuncs(TestCase): self.assertEqual(z.dtype, x.dtype) self.assertEqual(z, z_alt) - @dtypesIfCUDA(*set(get_all_math_dtypes('cuda')) - {torch.complex64, torch.complex128}) - @dtypes(*set(get_all_math_dtypes('cpu')) - {torch.complex64, torch.complex128}) + @dtypesIfCUDA( + *set(get_all_math_dtypes("cuda")) - {torch.complex64, torch.complex128} + ) + @dtypes(*set(get_all_math_dtypes("cpu")) - {torch.complex64, torch.complex128}) def test_floor_divide_scalar(self, device, dtype): x = torch.randn(100, device=device).mul(10).to(dtype) with self.assertWarnsOnceRegex(UserWarning, "__floordiv__"): z = x // 3 - z_alt = torch.tensor([math.trunc(v.item() / 3.) for v in x], dtype=x.dtype, device=device) + z_alt = torch.tensor( + [math.trunc(v.item() / 3.0) for v in x], dtype=x.dtype, device=device + ) self.assertEqual(z.dtype, x.dtype) self.assertEqual(z, z_alt) @@ -2120,7 +2670,7 @@ class TestBinaryUfuncs(TestCase): self.assertEqual(o, torch.floor_divide(x.float(), y.float())) @onlyCPU - @dtypes(*get_all_math_dtypes('cpu')) + @dtypes(*get_all_math_dtypes("cpu")) def test_rdiv(self, device, dtype): if dtype is torch.float16: return @@ -2151,7 +2701,7 @@ class TestBinaryUfuncs(TestCase): x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9) zero = torch.zeros_like(x) # RuntimeError on CPU - if self.device_type == 'cpu': + if self.device_type == "cpu": with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError"): fn(x, zero) elif torch.version.hip is not None: @@ -2195,9 +2745,12 @@ class TestBinaryUfuncs(TestCase): inplace_fn(x, mod) self.assertEqual(x, exp, exact_dtype=False) except RuntimeError as e: - self.assertRegex(str(e), "result type (Half|Float|Double) " - "can't be cast to the desired output " - "type (Byte|Char|Short|Int|Long)") + self.assertRegex( + str(e), + "result type (Half|Float|Double) " + "can't be cast to the desired output " + "type (Byte|Char|Short|Int|Long)", + ) x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9) # mod with same dtype as x @@ -2209,20 +2762,30 @@ class TestBinaryUfuncs(TestCase): mods = [3, 2.3, mod, mod.t()] # mod with floating-point dtype if dtype in integral_types(): - mod_float = make_tensor((10, 10), device=device, dtype=torch.float, low=-9, high=9) + mod_float = make_tensor( + (10, 10), device=device, dtype=torch.float, low=-9, high=9 + ) mod[mod == 0] = 1 mods.append(mod_float) for dividend, mod in product([x, x.t()], mods): - _helper(dividend, mod, - ((torch.fmod, torch.Tensor.fmod_, np.fmod), - (torch.remainder, torch.Tensor.remainder_, np.remainder),)) + _helper( + dividend, + mod, + ( + (torch.fmod, torch.Tensor.fmod_, np.fmod), + (torch.remainder, torch.Tensor.remainder_, np.remainder), + ), + ) # Tests for torch.remainder(scalar, tensor) for dividend, mod in product([5, 3.14], mods): if torch.is_tensor(mod): - _helper(dividend, mod, - ((torch.remainder, torch.Tensor.remainder_, np.remainder),)) + _helper( + dividend, + 
mod, + ((torch.remainder, torch.Tensor.remainder_, np.remainder),), + ) @dtypes(torch.float, torch.double) def test_remainder_fmod_large_dividend(self, device, dtype): @@ -2234,23 +2797,45 @@ class TestBinaryUfuncs(TestCase): b = torch.tensor([bvalue], dtype=dtype, device=device) c = torch.remainder(a, b) d = torch.fmod(a, b) - self.assertTrue((b[0] > 0) == (c[0] > 0)) # remainder has same sign as divisor - self.assertTrue((a[0] > 0) == (d[0] > 0)) # fmod has same sign as dividend - self.assertTrue(abs(c[0]) < abs(b[0])) # remainder is within range of divisor - self.assertTrue(abs(d[0]) < abs(b[0])) # fmod is within range of divisor - if ((a[0] > 0) == (b[0] > 0)): - self.assertTrue(c[0] == d[0]) # remainder is same as fmod + self.assertTrue( + (b[0] > 0) == (c[0] > 0) + ) # remainder has same sign as divisor + self.assertTrue( + (a[0] > 0) == (d[0] > 0) + ) # fmod has same sign as dividend + self.assertTrue( + abs(c[0]) < abs(b[0]) + ) # remainder is within range of divisor + self.assertTrue( + abs(d[0]) < abs(b[0]) + ) # fmod is within range of divisor + if (a[0] > 0) == (b[0] > 0): + self.assertTrue(c[0] == d[0]) # remainder is same as fmod else: - self.assertTrue(abs(c[0] - d[0]) == abs(b[0])) # differ by one divisor + self.assertTrue( + abs(c[0] - d[0]) == abs(b[0]) + ) # differ by one divisor @dtypesIfCPU(torch.bfloat16, torch.float32, torch.float64) @dtypes(torch.float32, torch.float64) def test_hypot(self, device, dtype): inputs = [ - (torch.randn(10, device=device).to(dtype), torch.randn(10, device=device).to(dtype)), - (torch.randn((3, 3, 3), device=device).to(dtype), torch.randn((3, 3, 3), device=device).to(dtype)), - (torch.randn((10, 1), device=device).to(dtype), torch.randn((10, 1), device=device).to(dtype).transpose(0, 1)), - (torch.randint(100, (10, ), device=device, dtype=torch.long), torch.randn(10, device=device).to(dtype)) + ( + torch.randn(10, device=device).to(dtype), + torch.randn(10, device=device).to(dtype), + ), + ( + torch.randn((3, 3, 3), device=device).to(dtype), + torch.randn((3, 3, 3), device=device).to(dtype), + ), + ( + torch.randn((10, 1), device=device).to(dtype), + torch.randn((10, 1), device=device).to(dtype).transpose(0, 1), + ), + ( + torch.randint(100, (10,), device=device, dtype=torch.long), + torch.randn(10, device=device).to(dtype), + ), ] for input in inputs: actual = torch.hypot(input[0], input[1]) @@ -2329,8 +2914,8 @@ class TestBinaryUfuncs(TestCase): @onlyNativeDeviceTypes @dtypes(torch.bfloat16) def test_nextafter_bfloat16(self, device, dtype): - nan = float('nan') - inf = float('inf') + nan = float("nan") + inf = float("inf") cases = ( # (from, to, expected) (0, 1, 9.183549615799121e-41), @@ -2346,7 +2931,7 @@ class TestBinaryUfuncs(TestCase): (20, -3000, 19.875), (3000, -20, 2992.0), (-3000, 20, -2992.0), - (65536, 0, 65280.0) , + (65536, 0, 65280.0), (65536, inf, 66048.0), (-65536, 0, -65280.0), (-65536, -inf, -66048.0), @@ -2355,11 +2940,11 @@ class TestBinaryUfuncs(TestCase): (nan, nan, nan), (nan, inf, nan), (inf, nan, nan), - (inf, -inf, 3.3895313892515355e+38), - (-inf, inf, -3.3895313892515355e+38), - (inf, 0, 3.3895313892515355e+38), + (inf, -inf, 3.3895313892515355e38), + (-inf, inf, -3.3895313892515355e38), + (inf, 0, 3.3895313892515355e38), (0, inf, 9.183549615799121e-41), - (-inf, 0, -3.3895313892515355e+38), + (-inf, 0, -3.3895313892515355e38), (0, -inf, -9.183549615799121e-41), ) @@ -2392,10 +2977,17 @@ class TestBinaryUfuncs(TestCase): sm1 = m1[:, 4] sm2 = m2[:, 4] # view as sm1.size() - sm2.set_(sm2.storage(), 
sm2.storage_offset(), sm1.size(), (sm2.stride()[0] * 10, sm2.stride()[0])) + sm2.set_( + sm2.storage(), + sm2.storage_offset(), + sm1.size(), + (sm2.stride()[0] * 10, sm2.stride()[0]), + ) res1 = torchfn(sm1, sm2) # reference_implementation assumes 1-d sm2 - sm2.set_(sm2.storage(), sm2.storage_offset(), m2[:, 4].size(), m2[:, 4].stride()) + sm2.set_( + sm2.storage(), sm2.storage_offset(), m2[:, 4].size(), m2[:, 4].stride() + ) res2 = reference_implementation(res1.clone()) self.assertEqual(res1, res2) @@ -2417,14 +3009,16 @@ class TestBinaryUfuncs(TestCase): @onlyCPU @dtypes(torch.float) def test_cpow(self, device, dtype): - self._test_cop(torch.pow, lambda x, y: nan if x < 0 else math.pow(x, y), dtype, device) + self._test_cop( + torch.pow, lambda x, y: nan if x < 0 else math.pow(x, y), dtype, device + ) @onlyCPU @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) def test_floor_divide_zero(self, device, dtype): a = torch.tensor([0, 1], dtype=dtype, device=device) b = torch.tensor([0, 1], dtype=dtype, device=device) - with self.assertRaisesRegex(RuntimeError, 'ZeroDivisionError'): + with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError"): with self.assertWarnsOnceRegex(UserWarning, "floor_divide"): a // b @@ -2450,17 +3044,19 @@ class TestBinaryUfuncs(TestCase): x = torch.randn(*shape, device=device) * random.randint(30, 100) x = x.to(torch.bfloat16) else: - x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100) + x = torch.randn( + *shape, dtype=dtype, device=device + ) * random.randint(30, 100) x[torch.randn(*shape) > 0.5] = 0 if with_extremal and dtype.is_floating_point: # Use extremal values - x[torch.randn(*shape) > 0.5] = float('nan') - x[torch.randn(*shape) > 0.5] = float('inf') - x[torch.randn(*shape) > 0.5] = float('-inf') + x[torch.randn(*shape) > 0.5] = float("nan") + x[torch.randn(*shape) > 0.5] = float("inf") + x[torch.randn(*shape) > 0.5] = float("-inf") elif with_extremal and dtype.is_complex: - x[torch.randn(*shape) > 0.5] = complex('nan') - x[torch.randn(*shape) > 0.5] = complex('inf') - x[torch.randn(*shape) > 0.5] = complex('-inf') + x[torch.randn(*shape) > 0.5] = complex("nan") + x[torch.randn(*shape) > 0.5] = complex("inf") + x[torch.randn(*shape) > 0.5] = complex("-inf") elif dtype == torch.bool: x = torch.zeros(shape, dtype=dtype, device=device) x[torch.randn(*shape) > 0.5] = True @@ -2469,8 +3065,13 @@ class TestBinaryUfuncs(TestCase): return x - @dtypes(*tuple(itertools.combinations_with_replacement(all_types_and_complex_and(torch.half, - torch.bfloat16, torch.bool), 2))) + @dtypes( + *tuple( + itertools.combinations_with_replacement( + all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), 2 + ) + ) + ) def test_comparison_ops_type_promotion_and_broadcasting(self, device, dtypes): # issue #42660 # testing all combinations of broadcasting and type promotion @@ -2479,29 +3080,35 @@ class TestBinaryUfuncs(TestCase): # working around the fact that numpy doesn't support bfloat16 # by letting numpy treat them as float32's x_np = x if x.dtype != torch.bfloat16 else x.to(torch.float32) - y_np = y.cpu().numpy() if y.dtype != torch.bfloat16 else y.to(torch.float32).cpu().numpy() - self.compare_with_numpy(lambda inp: torch_fn(inp, y, out=out) if out else torch_fn(inp, y), - lambda inp: np_fn(inp, y_np, out=out) if out else np_fn(inp, y_np), - x_np) + y_np = ( + y.cpu().numpy() + if y.dtype != torch.bfloat16 + else y.to(torch.float32).cpu().numpy() + ) + self.compare_with_numpy( + lambda inp: torch_fn(inp, y, 
out=out) if out else torch_fn(inp, y), + lambda inp: np_fn(inp, y_np, out=out) if out else np_fn(inp, y_np), + x_np, + ) - complex_op_denylist = [torch.lt, torch.le, torch.gt, torch.ge] # complex not supported - input_sizes = [ - (1,), - (10,), - (10, 1), - (1, 10), - (4, 10), - (64, 10), - (12, 3)] - op_pairs = [(torch.lt, np.less), - (torch.le, np.less_equal), - (torch.gt, np.greater), - (torch.ge, np.greater_equal), - (torch.eq, np.equal), - (torch.ne, np.not_equal), - (torch.logical_and, np.logical_and), - (torch.logical_or, np.logical_or), - (torch.logical_xor, np.logical_xor)] + complex_op_denylist = [ + torch.lt, + torch.le, + torch.gt, + torch.ge, + ] # complex not supported + input_sizes = [(1,), (10,), (10, 1), (1, 10), (4, 10), (64, 10), (12, 3)] + op_pairs = [ + (torch.lt, np.less), + (torch.le, np.less_equal), + (torch.gt, np.greater), + (torch.ge, np.greater_equal), + (torch.eq, np.equal), + (torch.ne, np.not_equal), + (torch.logical_and, np.logical_and), + (torch.logical_or, np.logical_or), + (torch.logical_xor, np.logical_xor), + ] for size1 in input_sizes: size2 = (2,) + size1 # perform broadcasting @@ -2509,7 +3116,9 @@ class TestBinaryUfuncs(TestCase): a = self._generate_input(size1, dtypes[0], device, with_extremal) b = self._generate_input(size2, dtypes[1], device, with_extremal) for torch_op, numpy_op in op_pairs: - if (dtypes[0].is_complex or dtypes[1].is_complex) and torch_op in complex_op_denylist: + if ( + dtypes[0].is_complex or dtypes[1].is_complex + ) and torch_op in complex_op_denylist: continue # functional version of op compare_with_numpy_bin_op(torch_op, numpy_op, a, b) @@ -2518,7 +3127,9 @@ class TestBinaryUfuncs(TestCase): self.assertEqual(torch_op(a, b).dtype, torch.bool) # out version of op - out = torch.zeros(1, dtype=torch.complex128) # all casts to complex128 are safe + out = torch.zeros( + 1, dtype=torch.complex128 + ) # all casts to complex128 are safe compare_with_numpy_bin_op(torch_op, numpy_op, a, b, out=out) @onlyNativeDeviceTypes @@ -2526,10 +3137,14 @@ class TestBinaryUfuncs(TestCase): def test_signed_shift(self, device, dtype): "Ensure that signed integer bit shifting works as expected." 
a = torch.tensor([-10, 10], device=device, dtype=dtype) # [11...1110110, 1010] - expected_l = torch.tensor([-40, 40], device=device, dtype=dtype) # [11...11011000, 101000] + expected_l = torch.tensor( + [-40, 40], device=device, dtype=dtype + ) # [11...11011000, 101000] self.assertEqual(a << 2, expected_l) self.compare_with_numpy(lambda x: x << 2, lambda x: np.left_shift(x, 2), a) - expected_r = torch.tensor([-5, 5], device=device, dtype=dtype) # [1111...111011, 101] + expected_r = torch.tensor( + [-5, 5], device=device, dtype=dtype + ) # [1111...111011, 101] self.assertEqual(a >> 1, expected_r) self.compare_with_numpy(lambda x: x >> 1, lambda x: np.right_shift(x, 1), a) @@ -2542,13 +3157,23 @@ class TestBinaryUfuncs(TestCase): b_np = b.cpu().numpy() # Tensor x Tensor - self.assertEqual(torch.bitwise_and(a, b), torch.tensor(np.bitwise_and(a_np, b_np), device=device)) + self.assertEqual( + torch.bitwise_and(a, b), + torch.tensor(np.bitwise_and(a_np, b_np), device=device), + ) # Tensor x int scaler - self.assertEqual(torch.bitwise_and(a, 2), torch.tensor(np.bitwise_and(a_np, 2), device=device)) + self.assertEqual( + torch.bitwise_and(a, 2), + torch.tensor(np.bitwise_and(a_np, 2), device=device), + ) - self.assertEqual(torch.tensor([False, True, False], device=device), - torch.bitwise_and(torch.tensor([True, True, False], device=device), - torch.tensor([False, True, False], device=device))) + self.assertEqual( + torch.tensor([False, True, False], device=device), + torch.bitwise_and( + torch.tensor([True, True, False], device=device), + torch.tensor([False, True, False], device=device), + ), + ) # type promotion c = torch.zeros(2) >= 1 @@ -2580,9 +3205,13 @@ class TestBinaryUfuncs(TestCase): a.bitwise_or_(b_scalar) self.assertEqual(a, expected_res_scalar) - self.assertEqual(torch.tensor([True, True, False], device=device), - torch.bitwise_or(torch.tensor([True, True, False], device=device), - torch.tensor([False, True, False], device=device))) + self.assertEqual( + torch.tensor([True, True, False], device=device), + torch.bitwise_or( + torch.tensor([True, True, False], device=device), + torch.tensor([False, True, False], device=device), + ), + ) def test_bitwise_xor(self, device): for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): @@ -2610,9 +3239,13 @@ class TestBinaryUfuncs(TestCase): a.bitwise_xor_(b_scalar) self.assertEqual(a, expected_res_scalar) - self.assertEqual(torch.tensor([True, False, False], device=device), - torch.bitwise_xor(torch.tensor([True, True, False], device=device), - torch.tensor([False, True, False], device=device))) + self.assertEqual( + torch.tensor([True, False, False], device=device), + torch.bitwise_xor( + torch.tensor([True, True, False], device=device), + torch.tensor([False, True, False], device=device), + ), + ) @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) def test_bitwise_shift(self, device, dtype): @@ -2629,42 +3262,68 @@ class TestBinaryUfuncs(TestCase): b_np = b.cpu().numpy() # Tensor x Tensor - self.assertEqual(torch_op(a, b), torch.tensor(numpy_op(a_np, b_np), device=device)) + self.assertEqual( + torch_op(a, b), torch.tensor(numpy_op(a_np, b_np), device=device) + ) # Tensor x int scalar - self.assertEqual(torch_op(a, 2), torch.tensor(numpy_op(a_np, 2), device=device)) + self.assertEqual( + torch_op(a, 2), torch.tensor(numpy_op(a_np, 2), device=device) + ) def test_bitwise_shift_float(self, device): ops = [ - (torch.bitwise_left_shift, lambda x, y: x * 2. 
** y), - (operator.lshift, lambda x, y: x * 2. ** y), - (torch.bitwise_right_shift, lambda x, y: x / 2. ** y), - (operator.rshift, lambda x, y: x / 2. ** y), + (torch.bitwise_left_shift, lambda x, y: x * 2.0**y), + (operator.lshift, lambda x, y: x * 2.0**y), + (torch.bitwise_right_shift, lambda x, y: x / 2.0**y), + (operator.rshift, lambda x, y: x / 2.0**y), ] for torch_op, expected_op in ops: # int tensor x float a = torch.tensor([19, -20, -21, 22], dtype=torch.int64, device=device) - self.assertEqual(torch_op(a, 1.8), torch.floor(expected_op(a, 1)).to(a.dtype)) + self.assertEqual( + torch_op(a, 1.8), torch.floor(expected_op(a, 1)).to(a.dtype) + ) # float tensor x int scalar - a = torch.tensor([19.1, -20.2, -21.3, 22.4], dtype=torch.float32, device=device) + a = torch.tensor( + [19.1, -20.2, -21.3, 22.4], dtype=torch.float32, device=device + ) self.assertEqual(torch_op(a, 2), expected_op(a, 2)) # float tensor x float scalar - a = torch.tensor([19.1, -20.2, -21.3, 22.4], dtype=torch.float32, device=device) + a = torch.tensor( + [19.1, -20.2, -21.3, 22.4], dtype=torch.float32, device=device + ) self.assertEqual(torch_op(a, 2.2), expected_op(a, 2.2)) @onlyNativeDeviceTypes - @dtypes(*list(product(all_types_and(torch.half, torch.bfloat16, torch.bool), - all_types_and(torch.half, torch.bfloat16, torch.bool)))) + @dtypes( + *list( + product( + all_types_and(torch.half, torch.bfloat16, torch.bool), + all_types_and(torch.half, torch.bfloat16, torch.bool), + ) + ) + ) def test_heaviside(self, device, dtypes): input_dtype = dtypes[0] values_dtype = dtypes[1] rng = np.random.default_rng() - input = np.array(rng.integers(-10, 10, size=10), - dtype=torch_to_numpy_dtype_dict[input_dtype if (input_dtype != torch.bfloat16) else torch.float64]) + input = np.array( + rng.integers(-10, 10, size=10), + dtype=torch_to_numpy_dtype_dict[ + input_dtype if (input_dtype != torch.bfloat16) else torch.float64 + ], + ) input[0] = input[3] = input[7] = 0 - values = np.array(rng.integers(-10, 10, size=10), - dtype=torch_to_numpy_dtype_dict[values_dtype if (values_dtype != torch.bfloat16) else torch.float64]) - np_result = torch.from_numpy(np.heaviside(input, values)).to(device=device, dtype=input_dtype) + values = np.array( + rng.integers(-10, 10, size=10), + dtype=torch_to_numpy_dtype_dict[ + values_dtype if (values_dtype != torch.bfloat16) else torch.float64 + ], + ) + np_result = torch.from_numpy(np.heaviside(input, values)).to( + device=device, dtype=input_dtype + ) input = torch.from_numpy(input).to(device=device, dtype=input_dtype) values = torch.from_numpy(values).to(device=device, dtype=values_dtype) @@ -2683,13 +3342,25 @@ class TestBinaryUfuncs(TestCase): input.heaviside_(values) self.assertEqual(np_result, input) else: - with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): + with self.assertRaisesRegex( + RuntimeError, + "heaviside is not yet implemented for tensors with different dtypes.", + ): torch.heaviside(input, values) - with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): + with self.assertRaisesRegex( + RuntimeError, + "heaviside is not yet implemented for tensors with different dtypes.", + ): input.heaviside(values) - with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): + with self.assertRaisesRegex( + RuntimeError, + "heaviside is not yet implemented for tensors with different dtypes.", + ): torch.heaviside(input, 
values, out=out) - with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for tensors with different dtypes.'): + with self.assertRaisesRegex( + RuntimeError, + "heaviside is not yet implemented for tensors with different dtypes.", + ): input.heaviside_(values) @onlyCUDA @@ -2706,10 +3377,14 @@ class TestBinaryUfuncs(TestCase): x = torch.tensor([-9, 5, 0, 6, -2, 2]) y = torch.tensor(0, device=device) - with self.assertRaisesRegex(RuntimeError, 'Expected all tensors to be on the same device'): + with self.assertRaisesRegex( + RuntimeError, "Expected all tensors to be on the same device" + ): torch.heaviside(x, y) - with self.assertRaisesRegex(RuntimeError, 'Expected all tensors to be on the same device'): + with self.assertRaisesRegex( + RuntimeError, "Expected all tensors to be on the same device" + ): torch.heaviside(y, x) @dtypes(*list(product(complex_types(), complex_types()))) @@ -2723,13 +3398,21 @@ class TestBinaryUfuncs(TestCase): out = torch.empty_like(input) real = input.real - with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): + with self.assertRaisesRegex( + RuntimeError, "heaviside is not yet implemented for complex tensors." + ): torch.heaviside(input, real) - with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): + with self.assertRaisesRegex( + RuntimeError, "heaviside is not yet implemented for complex tensors." + ): real.heaviside(values) - with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): + with self.assertRaisesRegex( + RuntimeError, "heaviside is not yet implemented for complex tensors." + ): input.heaviside_(values) - with self.assertRaisesRegex(RuntimeError, 'heaviside is not yet implemented for complex tensors.'): + with self.assertRaisesRegex( + RuntimeError, "heaviside is not yet implemented for complex tensors." 
+ ): torch.heaviside(real, real, out=out) def _test_logical(self, device, dtypes, op, a_, b_, expected_res_): @@ -2744,23 +3427,41 @@ class TestBinaryUfuncs(TestCase): getattr(torch, op)(a, b, out=c) self.assertEqual(expected_res.bool(), c) - getattr(a, op + '_')(b) + getattr(a, op + "_")(b) self.assertEqual(expected_res, a) - @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), - all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))) + @dtypes( + *product( + all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), + all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), + ) + ) def test_logical_xor(self, device, dtypes): - self._test_logical(device, dtypes, 'logical_xor', [10, 0, 1, 0], [1, 0, 0, 10], [0, 0, 1, 1]) + self._test_logical( + device, dtypes, "logical_xor", [10, 0, 1, 0], [1, 0, 0, 10], [0, 0, 1, 1] + ) - @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), - all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))) + @dtypes( + *product( + all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), + all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), + ) + ) def test_logical_and(self, device, dtypes): - self._test_logical(device, dtypes, 'logical_and', [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 0, 0]) + self._test_logical( + device, dtypes, "logical_and", [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 0, 0] + ) - @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), - all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))) + @dtypes( + *product( + all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), + all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), + ) + ) def test_logical_or(self, device, dtypes): - self._test_logical(device, dtypes, 'logical_or', [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 1, 1]) + self._test_logical( + device, dtypes, "logical_or", [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 1, 1] + ) def test_remainder_overflow(self, device): # Check Integer Overflows @@ -2796,7 +3497,9 @@ class TestBinaryUfuncs(TestCase): self.assertEqual(np_outcome, mantissas) # test bounds - mantissas = torch.tensor([float('inf'), float('-inf'), float('inf'), float('nan')], device=device) + mantissas = torch.tensor( + [float("inf"), float("-inf"), float("inf"), float("nan")], device=device + ) exponents = torch.randint(0, 31, (4,), device=device, dtype=torch.int32) np_outcome = np.ldexp(mantissas.numpy(), exponents.numpy()) pt_outcome = torch.ldexp(mantissas, exponents) @@ -2805,12 +3508,17 @@ class TestBinaryUfuncs(TestCase): @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) def test_lerp(self, device, dtype): start_end_weight_shapes = [(), (5,), (5, 5)] - for shapes in product(start_end_weight_shapes, start_end_weight_shapes, start_end_weight_shapes): + for shapes in product( + start_end_weight_shapes, start_end_weight_shapes, start_end_weight_shapes + ): start = torch.randn(shapes[0], device=device, dtype=dtype) end = torch.randn(shapes[1], device=device, dtype=dtype) # Tensor weights - weights = [torch.randn(shapes[2], device=device, dtype=dtype), random.random()] + weights = [ + torch.randn(shapes[2], device=device, dtype=dtype), + random.random(), + ] if dtype.is_complex: weights += [complex(0, 1), complex(0.4, 1.2)] @@ -2818,7 +3526,7 @@ class TestBinaryUfuncs(TestCase): actual = torch.lerp(start, end, weight) actual_method = start.lerp(end, weight) self.assertEqual(actual, actual_method) - 
actual_out = torch.tensor(1., dtype=dtype, device=device) + actual_out = torch.tensor(1.0, dtype=dtype, device=device) torch.lerp(start, end, weight, out=actual_out) self.assertEqual(actual, actual_out) expected = start + weight * (end - start) @@ -2828,8 +3536,8 @@ class TestBinaryUfuncs(TestCase): @dtypes(torch.half, torch.bfloat16) def test_lerp_lowp(self, device, dtype): ref_dtype = torch.float - xvals = (0., -30000.) - yvals = (0.1, -20000.) + xvals = (0.0, -30000.0) + yvals = (0.1, -20000.0) xs = [torch.full((4,), xval, device=device, dtype=dtype) for xval in xvals] ys = [torch.full((4,), yval, device=device, dtype=dtype) for yval in yvals] weights = [70000, torch.full((4,), 8, device=device, dtype=dtype)] @@ -2839,7 +3547,7 @@ class TestBinaryUfuncs(TestCase): wref = w.float() if isinstance(w, torch.Tensor) else w actual = torch.lerp(x, y, w) expected = torch.lerp(xref, yref, wref).to(dtype) - self.assertEqual(actual, expected, atol=0., rtol=0.) + self.assertEqual(actual, expected, atol=0.0, rtol=0.0) def _test_logaddexp(self, device, dtype, base2): if base2: @@ -2871,8 +3579,16 @@ class TestBinaryUfuncs(TestCase): _test_helper(a, b) _test_helper(a[:3], b[:3]) - a = torch.tensor([float('inf'), float('-inf'), float('inf'), float("nan")], dtype=dtype, device=device) - b = torch.tensor([float('inf'), float('-inf'), float('-inf'), float("nan")], dtype=dtype, device=device) + a = torch.tensor( + [float("inf"), float("-inf"), float("inf"), float("nan")], + dtype=dtype, + device=device, + ) + b = torch.tensor( + [float("inf"), float("-inf"), float("-inf"), float("nan")], + dtype=dtype, + device=device, + ) _test_helper(a, b) @dtypes(torch.float32, torch.float64, torch.bfloat16) @@ -2950,9 +3666,15 @@ class TestBinaryUfuncs(TestCase): self.assertEqual(torch.add(one, 1).dtype, torch.uint8) # bool - m1 = torch.tensor([True, False, False, True, False, False], dtype=torch.bool, device=device) - m2 = torch.tensor([True, True, False, False, False, True], dtype=torch.bool, device=device) - expected = torch.tensor([True, True, False, True, False, True], dtype=torch.bool, device=device) + m1 = torch.tensor( + [True, False, False, True, False, False], dtype=torch.bool, device=device + ) + m2 = torch.tensor( + [True, True, False, False, False, True], dtype=torch.bool, device=device + ) + expected = torch.tensor( + [True, True, False, True, False, True], dtype=torch.bool, device=device + ) self.assertEqual(m1 + m2, expected) # fused multiply add @@ -2962,56 +3684,70 @@ class TestBinaryUfuncs(TestCase): self.assertEqual(res, expected) # bfloat16 - m1 = torch.tensor([1., 2.], dtype=torch.bfloat16) - m2 = torch.tensor([3., 4.], dtype=torch.bfloat16) - self.assertEqual(m1 + m2, torch.tensor([4., 6.], dtype=torch.bfloat16)) + m1 = torch.tensor([1.0, 2.0], dtype=torch.bfloat16) + m2 = torch.tensor([3.0, 4.0], dtype=torch.bfloat16) + self.assertEqual(m1 + m2, torch.tensor([4.0, 6.0], dtype=torch.bfloat16)) # different alpha types m1 = torch.tensor([2 + 3j, 4 + 5j], dtype=torch.complex64, device=device) m2 = torch.tensor([4 + 5j, 2 + 3j], dtype=torch.complex64, device=device) # add complex numbers with float alpha res = torch.add(m1, m2, alpha=0.1) - expected = torch.tensor([2.4000 + 3.5000j, 4.2000 + 5.3000j], dtype=torch.complex64, device=device) + expected = torch.tensor( + [2.4000 + 3.5000j, 4.2000 + 5.3000j], dtype=torch.complex64, device=device + ) self.assertEqual(res, expected) # add complex numbers with complex alpha res = torch.add(m1, m2, alpha=complex(0.1, 0.2)) - expected = torch.tensor([1.4000 + 
4.3000j, 3.6000 + 5.7000j], dtype=torch.complex64, device=device) + expected = torch.tensor( + [1.4000 + 4.3000j, 3.6000 + 5.7000j], dtype=torch.complex64, device=device + ) self.assertEqual(res, expected) # add complex numbers with integer alpha res = torch.add(m1, m2, alpha=2) - expected = torch.tensor([10. + 13.j, 8. + 11.j], dtype=torch.complex64, device=device) + expected = torch.tensor( + [10.0 + 13.0j, 8.0 + 11.0j], dtype=torch.complex64, device=device + ) self.assertEqual(res, expected) # mismatched alpha m1 = torch.tensor([1], dtype=torch.int8, device=device) m2 = torch.tensor([2], dtype=torch.int8, device=device) - self.assertRaisesRegex(RuntimeError, - r"Boolean alpha only supported for Boolean results\.", - lambda: torch.add(m1, m2, alpha=True)) - self.assertRaisesRegex(RuntimeError, - r"For integral input tensors, argument alpha must not be a floating point number\.", - lambda: torch.add(m1, m2, alpha=1.0)) + self.assertRaisesRegex( + RuntimeError, + r"Boolean alpha only supported for Boolean results\.", + lambda: torch.add(m1, m2, alpha=True), + ) + self.assertRaisesRegex( + RuntimeError, + r"For integral input tensors, argument alpha must not be a floating point number\.", + lambda: torch.add(m1, m2, alpha=1.0), + ) # mismatched alpha, float / double tensor and complex alpha msg = r"For non-complex input tensors, argument alpha must not be a complex number\." - m1 = torch.tensor([3., 4.], device=device) - m2 = torch.tensor([4., 3.], device=device) - self.assertRaisesRegex(RuntimeError, msg, - lambda: torch.add(m1, m2, alpha=complex(0.1, 0.2))) + m1 = torch.tensor([3.0, 4.0], device=device) + m2 = torch.tensor([4.0, 3.0], device=device) + self.assertRaisesRegex( + RuntimeError, msg, lambda: torch.add(m1, m2, alpha=complex(0.1, 0.2)) + ) - m1 = torch.tensor([3., 4.], dtype=torch.double, device=device) - m2 = torch.tensor([4., 3.], dtype=torch.double, device=device) - self.assertRaisesRegex(RuntimeError, msg, - lambda: torch.add(m1, m2, alpha=complex(0.1, 0.2))) + m1 = torch.tensor([3.0, 4.0], dtype=torch.double, device=device) + m2 = torch.tensor([4.0, 3.0], dtype=torch.double, device=device) + self.assertRaisesRegex( + RuntimeError, msg, lambda: torch.add(m1, m2, alpha=complex(0.1, 0.2)) + ) # complex m1 = torch.tensor((4.0000 + 4.0000j), dtype=torch.complex64) - m2 = torch.tensor(4., dtype=torch.float64) - self.assertRaisesRegex(RuntimeError, r"result type ComplexFloat can't be cast to the desired output type Double", - lambda: torch.add(m1, m1, out=m2)) - + m2 = torch.tensor(4.0, dtype=torch.float64) + self.assertRaisesRegex( + RuntimeError, + r"result type ComplexFloat can't be cast to the desired output type Double", + lambda: torch.add(m1, m1, out=m2), + ) @onlyCUDA def test_addsub_half_tensor(self, device): @@ -3026,30 +3762,44 @@ class TestBinaryUfuncs(TestCase): self.assertTrue(not (actual.isnan() or actual.isinf())) def test_sub_typing(self, device): - m1 = torch.tensor([True, False, False, True, False, False], dtype=torch.bool, device=device) - m2 = torch.tensor([True, True, False, False, False, True], dtype=torch.bool, device=device) - self.assertRaisesRegex(RuntimeError, - r"Subtraction, the `\-` operator, with two bool tensors is not supported. " - r"Use the `\^` or `logical_xor\(\)` operator instead.", - lambda: m1 - m2) - self.assertRaisesRegex(RuntimeError, - r"Subtraction, the `\-` operator, with a bool tensor is not supported. 
" - r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.", - lambda: 1 - m1) - self.assertRaisesRegex(RuntimeError, - r"Subtraction, the `\-` operator, with a bool tensor is not supported. " - r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.", - lambda: m2 - 1) + m1 = torch.tensor( + [True, False, False, True, False, False], dtype=torch.bool, device=device + ) + m2 = torch.tensor( + [True, True, False, False, False, True], dtype=torch.bool, device=device + ) + self.assertRaisesRegex( + RuntimeError, + r"Subtraction, the `\-` operator, with two bool tensors is not supported. " + r"Use the `\^` or `logical_xor\(\)` operator instead.", + lambda: m1 - m2, + ) + self.assertRaisesRegex( + RuntimeError, + r"Subtraction, the `\-` operator, with a bool tensor is not supported. " + r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.", + lambda: 1 - m1, + ) + self.assertRaisesRegex( + RuntimeError, + r"Subtraction, the `\-` operator, with a bool tensor is not supported. " + r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.", + lambda: m2 - 1, + ) # mismatched alpha m1 = torch.tensor([1], dtype=torch.int8, device=device) m2 = torch.tensor([2], dtype=torch.int8, device=device) - self.assertRaisesRegex(RuntimeError, - r"Boolean alpha only supported for Boolean results\.", - lambda: torch.sub(m1, m2, alpha=True)) - self.assertRaisesRegex(RuntimeError, - r"For integral input tensors, argument alpha must not be a floating point number\.", - lambda: torch.sub(m1, m2, alpha=1.0)) + self.assertRaisesRegex( + RuntimeError, + r"Boolean alpha only supported for Boolean results\.", + lambda: torch.sub(m1, m2, alpha=True), + ) + self.assertRaisesRegex( + RuntimeError, + r"For integral input tensors, argument alpha must not be a floating point number\.", + lambda: torch.sub(m1, m2, alpha=1.0), + ) def test_mul(self, device): m1 = torch.randn(10, 10, device=device) @@ -3062,28 +3812,58 @@ class TestBinaryUfuncs(TestCase): a1 = torch.tensor([True, False, False, True], dtype=torch.bool, device=device) a2 = torch.tensor([True, False, True, False], dtype=torch.bool, device=device) - self.assertEqual(a1 * a2, torch.tensor([True, False, False, False], dtype=torch.bool, device=device)) + self.assertEqual( + a1 * a2, + torch.tensor([True, False, False, False], dtype=torch.bool, device=device), + ) - if device == 'cpu': + if device == "cpu": a1 = torch.tensor([0.1, 0.1], dtype=torch.bfloat16, device=device) a2 = torch.tensor([1.1, 0.1], dtype=torch.bfloat16, device=device) - self.assertEqual(a1 * a2, torch.tensor([0.11, 0.01], dtype=torch.bfloat16, device=device), atol=0.01, rtol=0) + self.assertEqual( + a1 * a2, + torch.tensor([0.11, 0.01], dtype=torch.bfloat16, device=device), + atol=0.01, + rtol=0, + ) self.assertEqual(a1.mul(a2), a1 * a2) def test_bool_tensor_comparison_ops(self, device): - a = torch.tensor([True, False, True, False, True, False], dtype=torch.bool, device=device) - b = torch.tensor([True, False, True, True, True, True], dtype=torch.bool, device=device) - self.assertEqual(a == b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device)) - self.assertEqual(a != b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device)) - self.assertEqual(a < b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device)) - self.assertEqual(a > b, torch.tensor([0, 0, 0, 0, 0, 0], dtype=torch.bool, device=device)) - self.assertEqual(a >= b, 
torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device)) - self.assertEqual(a <= b, torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.bool, device=device)) - self.assertEqual(a > False, torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device)) - self.assertEqual(a == torch.tensor(True, dtype=torch.bool, device=device), - torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device)) - self.assertEqual(a == torch.tensor(0, dtype=torch.bool, device=device), - torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool, device=device)) + a = torch.tensor( + [True, False, True, False, True, False], dtype=torch.bool, device=device + ) + b = torch.tensor( + [True, False, True, True, True, True], dtype=torch.bool, device=device + ) + self.assertEqual( + a == b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device) + ) + self.assertEqual( + a != b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device) + ) + self.assertEqual( + a < b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device) + ) + self.assertEqual( + a > b, torch.tensor([0, 0, 0, 0, 0, 0], dtype=torch.bool, device=device) + ) + self.assertEqual( + a >= b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device) + ) + self.assertEqual( + a <= b, torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.bool, device=device) + ) + self.assertEqual( + a > False, torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device) + ) + self.assertEqual( + a == torch.tensor(True, dtype=torch.bool, device=device), + torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device), + ) + self.assertEqual( + a == torch.tensor(0, dtype=torch.bool, device=device), + torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool, device=device), + ) self.assertFalse(a.equal(b)) @dtypes(*all_types_and(torch.half, torch.bfloat16, torch.bool)) @@ -3120,8 +3900,11 @@ class TestBinaryUfuncs(TestCase): actual = a.atan2(b) x = a.view(-1) y = b.view(-1) - expected = torch.tensor([math.atan2(x[i].item(), y[i].item()) for i in range(x.numel())], - device=device, dtype=torch.double) + expected = torch.tensor( + [math.atan2(x[i].item(), y[i].item()) for i in range(x.numel())], + device=device, + dtype=torch.double, + ) self.assertEqual(expected, actual.view(-1), rtol=0, atol=0.02) _test_atan2_with_size((2, 2), device) @@ -3142,10 +3925,10 @@ class TestBinaryUfuncs(TestCase): _test_atan2(0, -1, math.pi / -2, device, dtype) _test_atan2(-1, 0, math.pi, device, dtype) _test_atan2(1, 0, 0, device, dtype) - _test_atan2(-1, -1, math.pi * -3 / 4 , device, dtype) - _test_atan2(1, 1, math.pi / 4 , device, dtype) - _test_atan2(1, -1, math.pi / -4 , device, dtype) - _test_atan2(-1, 1, math.pi * 3 / 4 , device, dtype) + _test_atan2(-1, -1, math.pi * -3 / 4, device, dtype) + _test_atan2(1, 1, math.pi / 4, device, dtype) + _test_atan2(1, -1, math.pi / -4, device, dtype) + _test_atan2(-1, 1, math.pi * 3 / 4, device, dtype) def test_trapezoid(self, device): def test_dx(sizes, dim, dx, device): @@ -3168,7 +3951,9 @@ class TestBinaryUfuncs(TestCase): test_dx((0, 2), 0, 1.0, device) test_dx((0, 2), 1, 1.0, device) test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device) - test_x((10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device) + test_x( + (10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device + ) test_x((1, 10), 0, [1.0], device) test_x((0, 2), 0, [], device) test_x((0, 2), 1, [1.0, 2.0], device) @@ -3177,14 +3962,12 @@ class TestBinaryUfuncs(TestCase): test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device) test_x((2, 3, 
4), 2, [1.0, 2.0, 3.0, 4.0], device) test_x((2, 2, 4), -1, [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]], device) - with self.assertRaisesRegex( - IndexError, - 'Dimension out of range'): + with self.assertRaisesRegex(IndexError, "Dimension out of range"): test_x((2, 3), 2, [], device) test_dx((2, 3), 2, 1.0, device) with self.assertRaisesRegex( - RuntimeError, - 'There must be one `x` value for each sample point'): + RuntimeError, "There must be one `x` value for each sample point" + ): test_x((2, 3), 1, [1.0, 2.0], device) test_x((2, 3), 1, [1.0, 2.0, 3.0, 4.0], device) @@ -3193,7 +3976,7 @@ class TestBinaryUfuncs(TestCase): import scipy.integrate - if hasattr(scipy.integrate, 'cumulative_trapezoid'): + if hasattr(scipy.integrate, "cumulative_trapezoid"): scipy_cumulative_trapezoid = scipy.integrate.cumulative_trapezoid else: # Older version of SciPy uses a different name scipy_cumulative_trapezoid = scipy.integrate.cumtrapz @@ -3208,14 +3991,20 @@ class TestBinaryUfuncs(TestCase): def test_x(sizes, dim, x, device): t = torch.randn(sizes, device=device) - actual = torch.cumulative_trapezoid(t, x=torch.tensor(x, device=device), dim=dim) + actual = torch.cumulative_trapezoid( + t, x=torch.tensor(x, device=device), dim=dim + ) expected = scipy_cumulative_trapezoid(t.cpu().numpy(), x=x, axis=dim) self.assertEqual(expected.shape, actual.shape) - self.assertEqual(expected, actual.cpu(), exact_dtype=False, atol=1e-4, rtol=1e-4) + self.assertEqual( + expected, actual.cpu(), exact_dtype=False, atol=1e-4, rtol=1e-4 + ) def test_empty_x(sizes, dim, x, device): t = torch.randn(sizes, device=device) - actual = torch.cumulative_trapezoid(t, x=torch.tensor(x, device=device), dim=dim) + actual = torch.cumulative_trapezoid( + t, x=torch.tensor(x, device=device), dim=dim + ) self.assertEqual(torch.empty(actual.shape), actual) test_dx((2,), -1, 1, device) @@ -3232,7 +4021,9 @@ class TestBinaryUfuncs(TestCase): test_x((2,), -1, [100, 50], device) test_x((4, 2), 0, [2, 3, 4, 5], device) test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device) - test_x((10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device) + test_x( + (10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device + ) test_x((1, 10), 0, [1.0], device) test_x((0, 2), 1, [1, 2], device) test_x((2, 3, 4), -1, [1.0, 2.0, 3.0, 4.0], device) @@ -3240,41 +4031,51 @@ class TestBinaryUfuncs(TestCase): test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device) test_x((2, 3, 4), 2, [1.0, 2.0, 3.0, 4.0], device) - test_empty_x((0, 2), 0, [], device) # SciPy failing when x == [], but our version returns empty + test_empty_x( + (0, 2), 0, [], device + ) # SciPy failing when x == [], but our version returns empty - with self.assertRaisesRegex( - IndexError, - 'Dimension out of range'): + with self.assertRaisesRegex(IndexError, "Dimension out of range"): test_x((2, 3), 2, [], device) test_dx((2, 3), 2, 1.0, device) with self.assertRaisesRegex( - RuntimeError, - 'There must be one `x` value for each sample point'): + RuntimeError, "There must be one `x` value for each sample point" + ): test_x((2, 3), 1, [1.0, 2.0], device) test_x((0, 2), 0, [1.0, 2.0], device) test_x((2, 3), 1, [1.0, 2.0, 3.0, 4.0], device) with self.assertRaisesRegex( - RuntimeError, - 'Currently, we only support dx as a real number'): - test_dx((2, 2), -1, complex(1, 1) , device) + RuntimeError, "Currently, we only support dx as a real number" + ): + test_dx((2, 2), -1, complex(1, 1), device) with self.assertRaisesRegex( - TypeError, 'received an invalid combination of 
arguments'): - actual = torch.cumulative_trapezoid(torch.randn((3, 3)), x=torch.randn((3, 3)), dx=3) + TypeError, "received an invalid combination of arguments" + ): + actual = torch.cumulative_trapezoid( + torch.randn((3, 3)), x=torch.randn((3, 3)), dx=3 + ) @skipMeta @dtypes(torch.double) def test_pow_scalar_overloads_mem_overlap(self, device, dtype): sz = 3 doubles = torch.randn(2 * sz, dtype=dtype, device=device) - self.check_internal_mem_overlap( - lambda t: t.pow_(42), 1, dtype, device) + self.check_internal_mem_overlap(lambda t: t.pow_(42), 1, dtype, device) self.unary_check_input_output_mem_overlap( - doubles, sz, lambda input, out: torch.pow(input, 42, out=out)) + doubles, sz, lambda input, out: torch.pow(input, 42, out=out) + ) self.unary_check_input_output_mem_overlap( - doubles, sz, lambda input, out: torch.pow(42, input, out=out)) + doubles, sz, lambda input, out: torch.pow(42, input, out=out) + ) - @dtypes(*list(product(all_types_and_complex_and(torch.half, torch.bfloat16), - all_types_and_complex_and(torch.half, torch.bfloat16)))) + @dtypes( + *list( + product( + all_types_and_complex_and(torch.half, torch.bfloat16), + all_types_and_complex_and(torch.half, torch.bfloat16), + ) + ) + ) def test_float_power(self, device, dtypes): def to_np(value): if isinstance(value, torch.Tensor) and value.dtype == torch.bfloat16: @@ -3283,7 +4084,11 @@ class TestBinaryUfuncs(TestCase): base_dtype = dtypes[0] exp_dtype = dtypes[1] - out_dtype = torch.complex128 if base_dtype.is_complex or exp_dtype.is_complex else torch.float64 + out_dtype = ( + torch.complex128 + if base_dtype.is_complex or exp_dtype.is_complex + else torch.float64 + ) base = make_tensor((30,), dtype=base_dtype, device=device, low=1, high=100) # Complex and real results do not agree between PyTorch and NumPy when computing negative and zero power of 0 @@ -3295,13 +4100,27 @@ class TestBinaryUfuncs(TestCase): expected = torch.from_numpy(np.float_power(to_np(base), to_np(exp))) exponents = [-2.8, -2, -1, -0.5, 0.5, 1, 2] - complex_exponents = exponents + [-2.5j, -1.0j, 1.0j, 2.5j, 1.0 + 1.0j, -1.0 - 1.5j, 3.3j] + complex_exponents = exponents + [ + -2.5j, + -1.0j, + 1.0j, + 2.5j, + 1.0 + 1.0j, + -1.0 - 1.5j, + 3.3j, + ] - for op in (torch.float_power, torch.Tensor.float_power, torch.Tensor.float_power_): + for op in ( + torch.float_power, + torch.Tensor.float_power, + torch.Tensor.float_power_, + ): # Case of Tensor x Tensor if op is torch.Tensor.float_power_ and base_dtype != out_dtype: - with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"): + with self.assertRaisesRegex( + RuntimeError, "operation's result requires dtype" + ): op(base.clone(), exp) else: result = op(base.clone(), exp) @@ -3314,24 +4133,39 @@ class TestBinaryUfuncs(TestCase): # Case of Tensor x Scalar for i in complex_exponents if exp_dtype.is_complex else exponents: - out_dtype_scalar_exp = torch.complex128 if base_dtype.is_complex or type(i) == complex else torch.float64 + out_dtype_scalar_exp = ( + torch.complex128 + if base_dtype.is_complex or type(i) == complex + else torch.float64 + ) expected_scalar_exp = torch.from_numpy(np.float_power(to_np(base), i)) - if op is torch.Tensor.float_power_ and base_dtype != out_dtype_scalar_exp: - with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"): + if ( + op is torch.Tensor.float_power_ + and base_dtype != out_dtype_scalar_exp + ): + with self.assertRaisesRegex( + RuntimeError, "operation's result requires dtype" + ): op(base.clone(), i) else: result = 
op(base.clone(), i) self.assertEqual(expected_scalar_exp, result) if op is torch.float_power: - out = torch.empty_like(base).to(device=device, dtype=out_dtype_scalar_exp) + out = torch.empty_like(base).to( + device=device, dtype=out_dtype_scalar_exp + ) op(base, i, out=out) self.assertEqual(expected_scalar_exp, out) # Case of Scalar x Tensor for i in complex_exponents if base_dtype.is_complex else exponents: - out_dtype_scalar_base = torch.complex128 if exp_dtype.is_complex or type(i) == complex else torch.float64 + out_dtype_scalar_base = ( + torch.complex128 + if exp_dtype.is_complex or type(i) == complex + else torch.float64 + ) expected_scalar_base = torch.from_numpy(np.float_power(i, to_np(exp))) result = torch.float_power(i, exp) @@ -3350,8 +4184,13 @@ class TestBinaryUfuncs(TestCase): return torch.complex128 return torch.double - test_cases = ((torch.tensor([-2, -1, 0, 1, 2], device=device), -.25), - (torch.tensor([-1.0j, 0j, 1.0j, 1.0 + 1.0j, -1.0 - 1.5j], device=device), 2.)) + test_cases = ( + (torch.tensor([-2, -1, 0, 1, 2], device=device), -0.25), + ( + torch.tensor([-1.0j, 0j, 1.0j, 1.0 + 1.0j, -1.0 - 1.5j], device=device), + 2.0, + ), + ) for base, exp in test_cases: for out_dtype in (torch.long, torch.float, torch.double, torch.cdouble): out = torch.empty(1, device=device, dtype=out_dtype) @@ -3360,18 +4199,25 @@ class TestBinaryUfuncs(TestCase): if out.dtype == required_dtype: torch.float_power(base, exp, out=out) else: - with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"): + with self.assertRaisesRegex( + RuntimeError, "operation's result requires dtype" + ): torch.float_power(base, exp, out=out) if base.dtype == required_dtype: torch.Tensor.float_power_(base.clone(), exp) else: - with self.assertRaisesRegex(RuntimeError, "operation's result requires dtype"): + with self.assertRaisesRegex( + RuntimeError, "operation's result requires dtype" + ): torch.Tensor.float_power_(base.clone(), exp) @skipIf(not TEST_SCIPY, "Scipy required for the test.") - @dtypes(*product(all_types_and(torch.half, torch.bool), - all_types_and(torch.half, torch.bool))) + @dtypes( + *product( + all_types_and(torch.half, torch.bool), all_types_and(torch.half, torch.bool) + ) + ) def test_xlogy_xlog1py(self, device, dtypes): x_dtype, y_dtype = dtypes @@ -3383,8 +4229,9 @@ class TestBinaryUfuncs(TestCase): def xlogy_inplace_variant_helper(x, y): if x.dtype in integral_types_and(torch.bool): - with self.assertRaisesRegex(RuntimeError, - "can't be cast to the desired output type"): + with self.assertRaisesRegex( + RuntimeError, "can't be cast to the desired output type" + ): x.clone().xlogy_(y) else: expected = torch.empty_like(x) @@ -3396,9 +4243,15 @@ class TestBinaryUfuncs(TestCase): x, y, z = inputs torch_fn_partial = partial(torch_fn, x) reference_fn_partial = partial(reference_fn, x.cpu().numpy()) - self.compare_with_numpy(torch_fn_partial, reference_fn_partial, x, exact_dtype=False) - self.compare_with_numpy(torch_fn_partial, reference_fn_partial, y, exact_dtype=False) - self.compare_with_numpy(torch_fn_partial, reference_fn_partial, z, exact_dtype=False) + self.compare_with_numpy( + torch_fn_partial, reference_fn_partial, x, exact_dtype=False + ) + self.compare_with_numpy( + torch_fn_partial, reference_fn_partial, y, exact_dtype=False + ) + self.compare_with_numpy( + torch_fn_partial, reference_fn_partial, z, exact_dtype=False + ) val = scalar if scalar is not None else x out_variant_helper(torch_fn, val, x) @@ -3410,8 +4263,12 @@ class TestBinaryUfuncs(TestCase): y = 
make_tensor((3, 2, 4, 5), dtype=y_dtype, device=device, low=0.5, high=1000) z = make_tensor((4, 5), dtype=y_dtype, device=device, low=0.5, high=1000) - x_1p = make_tensor((3, 2, 4, 5), dtype=x_dtype, device=device, low=-0.5, high=1000) - y_1p = make_tensor((3, 2, 4, 5), dtype=y_dtype, device=device, low=-0.5, high=1000) + x_1p = make_tensor( + (3, 2, 4, 5), dtype=x_dtype, device=device, low=-0.5, high=1000 + ) + y_1p = make_tensor( + (3, 2, 4, 5), dtype=y_dtype, device=device, low=-0.5, high=1000 + ) z_1p = make_tensor((4, 5), dtype=y_dtype, device=device, low=-0.5, high=1000) xlogy_fns = torch.xlogy, scipy.special.xlogy @@ -3428,7 +4285,10 @@ class TestBinaryUfuncs(TestCase): test_helper(*xlog1py_fns, (x_1p, y_1p, z_1p), 3.14) # Special Values Tensor-Tensor - t = torch.tensor([-1., 0., 1., 2., float('inf'), -float('inf'), float('nan')], device=device) + t = torch.tensor( + [-1.0, 0.0, 1.0, 2.0, float("inf"), -float("inf"), float("nan")], + device=device, + ) zeros = torch.zeros(7, dtype=y_dtype, device=device) def test_zeros_special_helper(torch_fn, reference_fn, scalar=False): @@ -3436,7 +4296,9 @@ class TestBinaryUfuncs(TestCase): zeros_np = 0 if scalar else zeros.cpu().numpy() torch_fn_partial = partial(torch_fn, zeros_t) reference_fn_partial = partial(reference_fn, zeros_np) - self.compare_with_numpy(torch_fn_partial, reference_fn_partial, t, exact_dtype=False) + self.compare_with_numpy( + torch_fn_partial, reference_fn_partial, t, exact_dtype=False + ) out_variant_helper(torch_fn, zeros_t, t) test_zeros_special_helper(*xlogy_fns) @@ -3453,14 +4315,14 @@ class TestBinaryUfuncs(TestCase): t = torch.randn((), dtype=torch.float32, device=device) self.assertEqual(t.dtype, torch.xlogy(t, 5).dtype) - self.assertEqual(t.dtype, torch.xlogy(t, 5.).dtype) + self.assertEqual(t.dtype, torch.xlogy(t, 5.0).dtype) self.assertEqual(t.dtype, torch.special.xlog1py(t, 5).dtype) - self.assertEqual(t.dtype, torch.special.xlog1py(t, 5.).dtype) + self.assertEqual(t.dtype, torch.special.xlog1py(t, 5.0).dtype) self.assertEqual(t.dtype, torch.xlogy(5, t).dtype) - self.assertEqual(t.dtype, torch.xlogy(5., t).dtype) + self.assertEqual(t.dtype, torch.xlogy(5.0, t).dtype) self.assertEqual(t.dtype, torch.special.xlog1py(5, t).dtype) - self.assertEqual(t.dtype, torch.special.xlog1py(5., t).dtype) + self.assertEqual(t.dtype, torch.special.xlog1py(5.0, t).dtype) @skipIf(not TEST_SCIPY, "Scipy required for the test.") def test_xlogy_xlog1py_bfloat16(self, device): @@ -3478,8 +4340,12 @@ class TestBinaryUfuncs(TestCase): y = make_tensor((3, 2, 4, 5), dtype=y_dtype, device=device, low=0.5, high=1000) z = make_tensor((4, 5), dtype=y_dtype, device=device, low=0.5, high=1000) - x_1p = make_tensor((3, 2, 4, 5), dtype=x_dtype, device=device, low=-0.8, high=1000) - y_1p = make_tensor((3, 2, 4, 5), dtype=y_dtype, device=device, low=-0.8, high=1000) + x_1p = make_tensor( + (3, 2, 4, 5), dtype=x_dtype, device=device, low=-0.8, high=1000 + ) + y_1p = make_tensor( + (3, 2, 4, 5), dtype=y_dtype, device=device, low=-0.8, high=1000 + ) z_1p = make_tensor((4, 5), dtype=y_dtype, device=device, low=-0.8, high=1000) xlogy_fns = torch.xlogy, scipy.special.xlogy @@ -3500,14 +4366,17 @@ class TestBinaryUfuncs(TestCase): _compare_helper(z_1p, 3.14, *xlog1py_fns) # Special Values Tensor-Tensor - t = torch.tensor([-1., 0., 1., 2., float('inf'), -float('inf'), float('nan')], device=device) + t = torch.tensor( + [-1.0, 0.0, 1.0, 2.0, float("inf"), -float("inf"), float("nan")], + device=device, + ) zeros = torch.tensor(7, dtype=y_dtype, 
device=device) _compare_helper(t, zeros, *xlogy_fns) - _compare_helper(t, 0., *xlogy_fns) + _compare_helper(t, 0.0, *xlogy_fns) _compare_helper(t, zeros, *xlog1py_fns) - _compare_helper(t, 0., *xlog1py_fns) + _compare_helper(t, 0.0, *xlog1py_fns) @dtypes(*product(all_types_and(torch.bool), all_types_and(torch.bool))) @skipIf(not TEST_SCIPY, "Scipy required for the test.") @@ -3522,7 +4391,7 @@ class TestBinaryUfuncs(TestCase): actual = torch.special.zeta(x, q) rtol, atol = None, None - if self.device_type == 'cpu': + if self.device_type == "cpu": rtol, atol = 1e-6, 1e-6 self.assertEqual(expected, actual, rtol=rtol, atol=atol, exact_dtype=False) @@ -3576,24 +4445,50 @@ class TestBinaryUfuncs(TestCase): tensor_binary_ops = [ - '__lt__', '__le__', - '__gt__', '__ge__', - '__eq__', '__ne__', - - '__add__', '__radd__', '__iadd__', - '__sub__', '__rsub__', '__isub__', - '__mul__', '__rmul__', '__imul__', - '__matmul__', '__rmatmul__', - '__truediv__', '__rtruediv__', '__itruediv__', - '__floordiv__', '__rfloordiv__', '__ifloordiv__', - '__mod__', '__rmod__', '__imod__', - '__pow__', '__rpow__', '__ipow__', - '__lshift__', '__rlshift__', '__ilshift__', - '__rshift__', '__rrshift__', '__irshift__', - '__and__', '__rand__', '__iand__', - '__xor__', '__rxor__', '__ixor__', - '__or__', '__ror__', '__ior__', - + "__lt__", + "__le__", + "__gt__", + "__ge__", + "__eq__", + "__ne__", + "__add__", + "__radd__", + "__iadd__", + "__sub__", + "__rsub__", + "__isub__", + "__mul__", + "__rmul__", + "__imul__", + "__matmul__", + "__rmatmul__", + "__truediv__", + "__rtruediv__", + "__itruediv__", + "__floordiv__", + "__rfloordiv__", + "__ifloordiv__", + "__mod__", + "__rmod__", + "__imod__", + "__pow__", + "__rpow__", + "__ipow__", + "__lshift__", + "__rlshift__", + "__ilshift__", + "__rshift__", + "__rrshift__", + "__irshift__", + "__and__", + "__rand__", + "__iand__", + "__xor__", + "__rxor__", + "__ixor__", + "__or__", + "__ror__", + "__ior__", # Unsupported operators # '__imatmul__', # '__divmod__', '__rdivmod__', '__idivmod__', @@ -3606,9 +4501,14 @@ def generate_not_implemented_tests(cls): # TODO: refactor to inline these _types = [ - torch.half, torch.float, torch.double, - torch.int8, torch.short, torch.int, torch.long, - torch.uint8 + torch.half, + torch.float, + torch.double, + torch.int8, + torch.short, + torch.int, + torch.long, + torch.uint8, ] def create_test_func(op): @@ -3620,12 +4520,14 @@ def generate_not_implemented_tests(cls): # Runs the tensor op on the device result = getattr(tensor, op)(UnknownType()) self.assertEqual(result, NotImplemented) + return test for op in tensor_binary_ops: test_name = "test_{}_not_implemented".format(op) assert not hasattr(cls, test_name), "{0} already in {1}".format( - test_name, cls.__name__) + test_name, cls.__name__ + ) setattr(cls, test_name, create_test_func(op)) @@ -3633,5 +4535,5 @@ def generate_not_implemented_tests(cls): generate_not_implemented_tests(TestBinaryUfuncs) instantiate_device_type_tests(TestBinaryUfuncs, globals()) -if __name__ == '__main__': +if __name__ == "__main__": run_tests() diff --git a/test/test_ops.py b/test/test_ops.py index 6768280806d3..83aebf4356be 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -8,17 +8,44 @@ import itertools import torch from torch.testing import make_tensor -from torch.testing._internal.common_dtype import floating_and_complex_types_and, all_types_and_complex_and -from torch.testing._internal.common_utils import \ - (TestCase, is_iterable_of_tensors, run_tests, IS_SANDCASTLE, clone_input_helper, - 
IS_IN_CI, suppress_warnings, noncontiguous_like, - TEST_WITH_ASAN, IS_WINDOWS, IS_FBCODE, first_sample) -from torch.testing._internal.common_methods_invocations import \ - (op_db, _NOTHING, UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, ops_and_refs) -from torch.testing._internal.common_device_type import \ - (deviceCountAtLeast, instantiate_device_type_tests, ops, - onlyCUDA, onlyNativeDeviceTypes, OpDTypes, skipMeta) -# import torch._prims as prims +from torch.testing._internal.common_dtype import ( + floating_and_complex_types_and, + all_types_and_complex_and, +) +from torch.testing._internal.common_utils import ( + TestCase, + is_iterable_of_tensors, + run_tests, + IS_SANDCASTLE, + clone_input_helper, + IS_IN_CI, + suppress_warnings, + noncontiguous_like, + TEST_WITH_ASAN, + IS_WINDOWS, + IS_FBCODE, + first_sample, +) +from torch.testing._internal.common_methods_invocations import ( + op_db, + _NOTHING, + UnaryUfuncInfo, + ReductionOpInfo, + SpectralFuncInfo, + ops_and_refs, + python_ref_db, + BinaryUfuncInfo, +) +from torch.testing._internal.common_device_type import ( + deviceCountAtLeast, + instantiate_device_type_tests, + ops, + onlyCUDA, + onlyNativeDeviceTypes, + OpDTypes, + skipMeta, +) +import torch._prims as prims import torch.testing._internal.opinfo_helper as opinfo_helper from torch.testing._internal import composite_compliance @@ -28,15 +55,25 @@ torch.set_default_dtype(torch.float32) # variant testing is only done with torch.float and torch.cfloat to avoid # excessive test times and maximize signal to noise ratio -_variant_ops = partial(ops, dtypes=OpDTypes.supported, - allowed_dtypes=(torch.float, torch.cfloat)) +_variant_ops = partial( + ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float, torch.cfloat) +) # Get names of all the operators which have ref in their entry in OpInfo (testing infra) -# except for Unary Ufuncs (separately implemented in test/test_unary_ufuncs.py) +# except for elementwise unary operators (separately implemented in test/test_unary_ufuncs.py), +# elementwise binary operators (separately implemented in test_binary_ufuncs.py), +# reduction operations (separately impelemented in test_reductions.py), # and Spectral Functions (separately implemented for only 1D as of now, in test/test_spectral_ops.py) -_ref_test_ops = list(filter(lambda op: not isinstance(op, (UnaryUfuncInfo, ReductionOpInfo, - SpectralFuncInfo)) and op.ref is not None and op.ref is not _NOTHING, op_db)) - +_ref_test_ops = tuple( + filter( + lambda op: not isinstance( + op, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo) + ) + and op.ref is not None + and op.ref is not _NOTHING, + op_db, + ) +) # Tests that apply to all operators and aren't related to any particular # system @@ -49,8 +86,10 @@ class TestCommon(TestCase): super().tearDownClass() if IS_IN_CI: - err_msg = ("The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries." - "This is OK for testing, but be sure to set the dtypes manually before landing your PR!") + err_msg = ( + "The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries." + "This is OK for testing, but be sure to set the dtypes manually before landing your PR!" + ) # Assure no opinfo entry has dynamic_dtypes filtered_ops = list(filter(opinfo_helper.is_dynamic_dtype_set, op_db)) for op in filtered_ops: @@ -68,11 +107,16 @@ class TestCommon(TestCase): # Check complex32 support only if the op claims. # TODO: Once the complex32 support is better, we should add check for complex32 unconditionally. 
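
As an aside on the `_ref_test_ops` rewrite in this hunk: it is a single-predicate filter over `op_db`. A minimal, self-contained sketch of the same pattern, using hypothetical stand-in records rather than the real OpInfo entries, could look like this:

from collections import namedtuple

# Hypothetical stand-ins for OpInfo entries and the _NOTHING sentinel; these
# are illustrative only and are not the real classes used by the test suite.
OpStub = namedtuple("OpStub", ["name", "ref", "kind"])
_NOTHING_STUB = object()

op_db_stub = (
    OpStub("add", ref=None, kind="binary"),       # no reference -> excluded
    OpStub("fft.fft", ref=sum, kind="spectral"),  # covered elsewhere -> excluded
    OpStub("where", ref=max, kind="other"),       # kept
)

# Keep only ops with a usable reference that are not already covered by the
# dedicated unary/binary/reduction/spectral suites, mirroring _ref_test_ops.
ref_test_ops_stub = tuple(
    op
    for op in op_db_stub
    if op.kind not in {"unary", "binary", "reduction", "spectral"}
    and op.ref is not None
    and op.ref is not _NOTHING_STUB
)

assert [op.name for op in ref_test_ops_stub] == ["where"]
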
device_type = torch.device(device).type - include_complex32 = ((torch.complex32,) if op.supports_dtype(torch.complex32, device_type) else ()) + include_complex32 = ( + (torch.complex32,) + if op.supports_dtype(torch.complex32, device_type) + else () + ) # dtypes to try to backward in allowed_backward_dtypes = floating_and_complex_types_and( - *((torch.half, torch.bfloat16) + include_complex32)) + *((torch.half, torch.bfloat16) + include_complex32) + ) # lists for (un)supported dtypes supported_dtypes = set() @@ -86,11 +130,14 @@ class TestCommon(TestCase): unsupported_backward_dtypes.add(dtype) for dtype in all_types_and_complex_and( - *((torch.half, torch.bfloat16, torch.bool) + include_complex32)): + *((torch.half, torch.bfloat16, torch.bool) + include_complex32) + ): # tries to acquire samples - failure indicates lack of support - requires_grad = (dtype in allowed_backward_dtypes) + requires_grad = dtype in allowed_backward_dtypes try: - samples = tuple(op.sample_inputs(device, dtype, requires_grad=requires_grad)) + samples = tuple( + op.sample_inputs(device, dtype, requires_grad=requires_grad) + ) except Exception as e: unsupported(dtype) continue @@ -113,7 +160,9 @@ class TestCommon(TestCase): result = sample.output_process_fn_grad(result) if isinstance(result, torch.Tensor): backward_tensor = result - elif isinstance(result, Sequence) and isinstance(result[0], torch.Tensor): + elif isinstance(result, Sequence) and isinstance( + result[0], torch.Tensor + ): backward_tensor = result[0] else: continue @@ -130,14 +179,15 @@ class TestCommon(TestCase): except Exception as e: unsupported_backward_dtypes.add(dtype) - # Checks that dtypes are listed correctly and generates an informative # error message supported_forward = supported_dtypes - unsupported_dtypes partially_supported_forward = supported_dtypes & unsupported_dtypes unsupported_forward = unsupported_dtypes - supported_dtypes supported_backward = supported_backward_dtypes - unsupported_backward_dtypes - partially_supported_backward = supported_backward_dtypes & unsupported_backward_dtypes + partially_supported_backward = ( + supported_backward_dtypes & unsupported_backward_dtypes + ) unsupported_backward = unsupported_backward_dtypes - supported_backward_dtypes device_type = torch.device(device).type @@ -156,17 +206,27 @@ class TestCommon(TestCase): op.name, device_type ) if len(partially_supported_forward) > 0: - msg = msg + "The following dtypes only worked on some samples during forward: {0}.\n".format( - partially_supported_forward + msg = ( + msg + + "The following dtypes only worked on some samples during forward: {0}.\n".format( + partially_supported_forward + ) ) if len(partially_supported_backward) > 0: - msg = msg + "The following dtypes only worked on some samples during backward: {0}.\n".format( - partially_supported_backward + msg = ( + msg + + "The following dtypes only worked on some samples during backward: {0}.\n".format( + partially_supported_backward + ) ) print(msg) - if (len(supported_but_unclaimed_forward) + len(claimed_but_unsupported_forward) + - len(supported_but_unclaimed_backward) + len(claimed_but_unsupported_backward)) == 0: + if ( + len(supported_but_unclaimed_forward) + + len(claimed_but_unsupported_forward) + + len(supported_but_unclaimed_backward) + + len(claimed_but_unsupported_backward) + ) == 0: return # Generates error msg @@ -174,20 +234,32 @@ class TestCommon(TestCase): op.name, device_type ) if len(supported_but_unclaimed_forward) > 0: - msg = msg + "The following dtypes worked in forward but 
are not listed by the OpInfo: {0}.\n".format( - supported_but_unclaimed_forward + msg = ( + msg + + "The following dtypes worked in forward but are not listed by the OpInfo: {0}.\n".format( + supported_but_unclaimed_forward + ) ) if len(supported_but_unclaimed_backward) > 0: - msg = msg + "The following dtypes worked in backward but are not listed by the OpInfo: {0}.\n".format( - supported_but_unclaimed_backward + msg = ( + msg + + "The following dtypes worked in backward but are not listed by the OpInfo: {0}.\n".format( + supported_but_unclaimed_backward + ) ) if len(claimed_but_unsupported_forward) > 0: - msg = msg + "The following dtypes did not work in forward but are listed by the OpInfo: {0}.\n".format( - claimed_but_unsupported_forward + msg = ( + msg + + "The following dtypes did not work in forward but are listed by the OpInfo: {0}.\n".format( + claimed_but_unsupported_forward + ) ) if len(claimed_but_unsupported_backward) > 0: - msg = msg + "The following dtypes did not work in backward but are listed by the OpInfo: {0}.\n".format( - claimed_but_unsupported_backward + msg = ( + msg + + "The following dtypes did not work in backward but are listed by the OpInfo: {0}.\n".format( + claimed_but_unsupported_backward + ) ) self.fail(msg) @@ -209,7 +281,9 @@ class TestCommon(TestCase): elif is_iterable_of_tensors(result): self.assertTrue(all(map(lambda t: t.device == cuda_device, result))) else: - self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.") + self.skipTest( + "Skipped! Only supports single tensor or iterable of tensor outputs." + ) # Tests that the function and its (ndarray-accepting) reference produce the same # values on the tensors from sample_inputs func for the corresponding op. @@ -226,28 +300,48 @@ class TestCommon(TestCase): cur_default = torch.get_default_dtype() torch.set_default_dtype(torch.double) for sample_input in op.reference_inputs(device, dtype): - self.compare_with_reference(op, op.ref, sample_input, exact_dtype=(dtype is not torch.long)) + self.compare_with_reference( + op, op.ref, sample_input, exact_dtype=(dtype is not torch.long) + ) finally: torch.set_default_dtype(cur_default) - # Tests that experimental Python References' can propagate shape, dtype, + # Tests that experimental Python References can propagate shape, dtype, # and device metadata properly. # TODO: include stride propagation. 
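
The metadata-propagation idea described in the comment above can also be illustrated outside the prims machinery with torch's "meta" device; the helper below is only an analogy for the TensorMeta-based check that follows, not something the patch itself adds:

import torch

def check_meta_propagation(op, *args):
    # Run the op once on real tensors and once on shape/dtype-only "meta"
    # tensors, then compare the metadata produced by the meta run.
    real = op(*args)
    meta_args = tuple(
        a.to("meta") if isinstance(a, torch.Tensor) else a for a in args
    )
    meta = op(*meta_args)
    assert meta.shape == real.shape
    assert meta.dtype == real.dtype

check_meta_propagation(torch.add, torch.ones(2, 3), torch.ones(2, 3))
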
- # @onlyNativeDeviceTypes - # @ops(python_ref_db) - # def test_python_reference_meta_functions(self, device, dtype, op): - # def _to_tensormeta(x): - # if isinstance(x, torch.Tensor): - # return prims.utils.TensorMeta(x) + @onlyNativeDeviceTypes + @ops(python_ref_db) + def test_python_reference_meta_functions(self, device, dtype, op): + def _to_tensormeta(x): + if isinstance(x, torch.Tensor): + return prims.utils.TensorMeta(x) - # # TODO: iterate over requires_grad true/false - # for sample in op.reference_inputs(device, dtype, requires_grad=False): - # result = op(sample.input, *sample.args, **sample.kwargs) + # TODO: iterate over requires_grad true/false + for sample in op.reference_inputs(device, dtype, requires_grad=False): + result = op(sample.input, *sample.args, **sample.kwargs) - # meta_sample = sample.transform(_to_tensormeta) - # meta_result = op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs) + meta_sample = sample.transform(_to_tensormeta) + meta_result = op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs) - # prims.utils.compare_tensor_meta(result, meta_result) + prims.utils.compare_tensor_meta(result, meta_result) + + # Tests that experimental Python References perform the same computation + # as the operators they reference. + @onlyNativeDeviceTypes + @ops(python_ref_db) + def test_python_reference_consistency(self, device, dtype, op): + for sample in op.reference_inputs(device, dtype, requires_grad=False): + actual = op(sample.input, *sample.args, **sample.kwargs) + expected = op.torch_opinfo(sample.input, *sample.args, **sample.kwargs) + + self.assertEqual( + actual, + expected, + exact_stride=True, + exact_device=True, + exact_layout=True, + exact_is_coalesced=True, + ) @skipMeta @onlyNativeDeviceTypes @@ -272,9 +366,17 @@ class TestCommon(TestCase): test_grad = dtype in op.supported_backward_dtypes(torch.device(device).type) sample_inputs = op.sample_inputs(device, dtype, requires_grad=test_grad) for sample_input in sample_inputs: - t_inp, t_args, t_kwargs = sample_input.input, sample_input.args, sample_input.kwargs + t_inp, t_args, t_kwargs = ( + sample_input.input, + sample_input.args, + sample_input.kwargs, + ) noncontig_sample = sample_input.noncontiguous() - n_inp, n_args, n_kwargs = noncontig_sample.input, noncontig_sample.args, noncontig_sample.kwargs + n_inp, n_args, n_kwargs = ( + noncontig_sample.input, + noncontig_sample.args, + noncontig_sample.kwargs, + ) # Verifies sample input tensors should have no grad or history sample_tensor = t_inp if isinstance(t_inp, torch.Tensor) else t_inp[0] @@ -300,10 +402,14 @@ class TestCommon(TestCase): grad_for_actual = noncontiguous_like(grad_for_expected) elif isinstance(expected, Sequence): # Filter output elements that do not require grad - expected = [t for t in expected - if isinstance(t, torch.Tensor) and t.requires_grad] - actual = [n for n in actual - if isinstance(n, torch.Tensor) and n.requires_grad] + expected = [ + t + for t in expected + if isinstance(t, torch.Tensor) and t.requires_grad + ] + actual = [ + n for n in actual if isinstance(n, torch.Tensor) and n.requires_grad + ] grad_for_expected = [torch.randn_like(t) for t in expected] grad_for_actual = [noncontiguous_like(n) for n in grad_for_expected] else: @@ -311,19 +417,35 @@ class TestCommon(TestCase): continue # Concatenate inputs into a tuple - t_inputs = (t_inp,) + t_args if isinstance(t_inp, torch.Tensor) else tuple(t_inp) + t_args - n_inputs = (n_inp,) + n_args if isinstance(n_inp, torch.Tensor) else tuple(n_inp) + n_args + 
t_inputs = ( + (t_inp,) + t_args + if isinstance(t_inp, torch.Tensor) + else tuple(t_inp) + t_args + ) + n_inputs = ( + (n_inp,) + n_args + if isinstance(n_inp, torch.Tensor) + else tuple(n_inp) + n_args + ) # Filter the elemnts that are tensors that require grad - t_input_tensors = [t for t in t_inputs if isinstance(t, torch.Tensor) and t.requires_grad] - n_input_tensors = [n for n in n_inputs if isinstance(n, torch.Tensor) and n.requires_grad] + t_input_tensors = [ + t for t in t_inputs if isinstance(t, torch.Tensor) and t.requires_grad + ] + n_input_tensors = [ + n for n in n_inputs if isinstance(n, torch.Tensor) and n.requires_grad + ] self.assertEqual(len(t_input_tensors), len(n_input_tensors)) # Some functions may not use all the inputs to generate gradients. One of the # few examples of this "odd" behaviour is F.hinge_embedding_loss - t_grads = torch.autograd.grad(expected, t_input_tensors, grad_for_expected, allow_unused=True) - n_grads = torch.autograd.grad(actual, n_input_tensors, grad_for_actual, allow_unused=True) + t_grads = torch.autograd.grad( + expected, t_input_tensors, grad_for_expected, allow_unused=True + ) + n_grads = torch.autograd.grad( + actual, n_input_tensors, grad_for_actual, allow_unused=True + ) msg = "Got different gradients for contiguous / non-contiguous inputs wrt input {}." for i, (t, n) in enumerate(zip(t_grads, n_grads)): @@ -339,7 +461,11 @@ class TestCommon(TestCase): supported_dtypes = op.supported_dtypes(self.device_type) if len(supported_dtypes) == 0: self.skipTest("Skipped! Op has not supported dtypes on this device.") - dtype = torch.float32 if torch.float32 in supported_dtypes else list(supported_dtypes)[0] + dtype = ( + torch.float32 + if torch.float32 in supported_dtypes + else list(supported_dtypes)[0] + ) samples = op.sample_inputs(device, dtype) for sample in samples: @@ -349,8 +475,12 @@ class TestCommon(TestCase): # Short-circuits if output is not a single tensor or an # iterable of tensors - if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(expected, include_empty=True): - self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.") + if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors( + expected, include_empty=True + ): + self.skipTest( + "Skipped! Only supports single tensor or iterable of tensor outputs." + ) # Validates the op doesn't support out if it claims not to if not op.supports_out: @@ -380,7 +510,7 @@ class TestCommon(TestCase): # NOTE: only extracts on the CPU and CUDA device types since some # device types don't have storage def _extract_data_ptrs(out): - if self.device_type != 'cpu' and self.device_type != 'cuda': + if self.device_type != "cpu" and self.device_type != "cuda": return () if isinstance(out, torch.Tensor): @@ -403,7 +533,8 @@ class TestCommon(TestCase): if compare_strides_and_data_ptrs: stride_msg = "Strides are not the same! Original strides were {0} and strides are now {1}".format( - original_strides, final_strides) + original_strides, final_strides + ) self.assertEqual(original_strides, final_strides, msg=stride_msg) self.assertEqual(original_ptrs, final_ptrs) @@ -433,7 +564,9 @@ class TestCommon(TestCase): out = _apply_out_transform(_case_zero_transform, expected) msg_fail = "Resized a non-empty tensor but did not warn about it." 
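
The resize-warning check that follows leans on PyTorch's out= behavior, where a non-empty out= tensor of the wrong shape is resized and a UserWarning is emitted. A small standalone illustration of that behavior, separate from the test harness, might be:

import warnings

import torch

a = torch.ones(3)
b = torch.ones(3)
out = torch.empty(5)  # non-empty and the wrong shape for the result

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torch.add(a, b, out=out)

assert out.shape == (3,)  # out= was resized to match the result
assert any(
    "An output with one or more elements" in str(w.message) for w in caught
)
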
if _any_nonempty(out): - with self.assertWarnsRegex(UserWarning, "An output with one or more elements", msg=msg_fail): + with self.assertWarnsRegex( + UserWarning, "An output with one or more elements", msg=msg_fail + ): op_out(out=out) # Validates ops implement the correct out= behavior @@ -452,7 +585,11 @@ class TestCommon(TestCase): supported_dtypes = op.supported_dtypes(self.device_type) if len(supported_dtypes) == 0: self.skipTest("Skipped! Op has not supported dtypes on this device.") - dtype = torch.float32 if torch.float32 in supported_dtypes else list(supported_dtypes)[0] + dtype = ( + torch.float32 + if torch.float32 in supported_dtypes + else list(supported_dtypes)[0] + ) samples = op.sample_inputs(device, dtype) for sample in samples: @@ -462,8 +599,12 @@ class TestCommon(TestCase): # Short-circuits if output is not a single tensor or an # iterable of tensors - if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(expected, include_empty=True): - self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.") + if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors( + expected, include_empty=True + ): + self.skipTest( + "Skipped! Only supports single tensor or iterable of tensor outputs." + ) # Validates the op doesn't support out if it claims not to if not op.supports_out: @@ -493,7 +634,7 @@ class TestCommon(TestCase): # NOTE: only extracts on the CPU and CUDA device types since some # device types don't have storage def _extract_data_ptrs(out): - if self.device_type != 'cpu' and self.device_type != 'cuda': + if self.device_type != "cpu" and self.device_type != "cuda": return () if isinstance(out, torch.Tensor): @@ -515,7 +656,8 @@ class TestCommon(TestCase): if compare_strides_and_data_ptrs: stride_msg = "Strides are not the same! Original strides were {0} and strides are now {1}".format( - original_strides, final_strides) + original_strides, final_strides + ) self.assertEqual(original_strides, final_strides, msg=stride_msg) self.assertEqual(original_ptrs, final_ptrs) @@ -529,7 +671,7 @@ class TestCommon(TestCase): return torch.full_like(t, info.max) except TypeError as te: # for non-integer types fills with NaN - return torch.full_like(t, float('nan')) + return torch.full_like(t, float("nan")) _compare_out(_case_zero_transform) @@ -537,10 +679,9 @@ class TestCommon(TestCase): # but noncontiguous. # Expected behavior: strides are respected and `out` storage is not changed. def _case_one_transform(t): - return make_tensor(t.shape, - dtype=t.dtype, - device=t.device, - noncontiguous=True) + return make_tensor( + t.shape, dtype=t.dtype, device=t.device, noncontiguous=True + ) _compare_out(_case_one_transform) @@ -560,16 +701,19 @@ class TestCommon(TestCase): # Verifies no warning is a resize warning for w in caught: if "An output with one or more elements" in str(w.message): - self.fail("Resizing an out= argument with no elements threw a resize warning!") + self.fail( + "Resizing an out= argument with no elements threw a resize warning!" + ) # Case 3: out= with correct shape and dtype, but wrong device. 
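
`_case_zero_transform` above relies on `torch.iinfo` raising `TypeError` for non-integer dtypes to decide between an integer-max fill and a NaN fill. A quick standalone check of that dispatch, independent of the test, could be:

import torch

def extremal_fill(t):
    # Integer dtypes get their maximum value, everything else gets NaN,
    # mirroring the try/except dispatch in _case_zero_transform above.
    try:
        info = torch.iinfo(t.dtype)
        return torch.full_like(t, info.max)
    except TypeError:
        return torch.full_like(t, float("nan"))

assert extremal_fill(torch.zeros(2, dtype=torch.int32)).eq(2**31 - 1).all()
assert extremal_fill(torch.zeros(2, dtype=torch.float32)).isnan().all()
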
wrong_device = None - if torch.device(device).type != 'cpu': - wrong_device = 'cpu' + if torch.device(device).type != "cpu": + wrong_device = "cpu" elif torch.cuda.is_available(): - wrong_device = 'cuda' + wrong_device = "cuda" if wrong_device is not None: + def _case_three_transform(t): return make_tensor(t.shape, dtype=t.dtype, device=wrong_device) @@ -587,16 +731,28 @@ class TestCommon(TestCase): # dtypes, or if an op returns multiple tensors when at least one such # tensor is a floating point or complex dtype. _dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16) - if (isinstance(expected, torch.Tensor) and expected.dtype in _dtypes or - (not isinstance(expected, torch.Tensor) and any(t.dtype in _dtypes for t in expected))): + if ( + isinstance(expected, torch.Tensor) + and expected.dtype in _dtypes + or ( + not isinstance(expected, torch.Tensor) + and any(t.dtype in _dtypes for t in expected) + ) + ): + def _case_four_transform(t): return make_tensor(t.shape, dtype=torch.long, device=t.device) out = _apply_out_transform(_case_four_transform, expected) msg_fail = "Expected RuntimeError when doing an unsafe cast!" - msg_fail = msg_fail if not isinstance(expected, torch.Tensor) else \ - ("Expected RuntimeError when doing an unsafe cast from a result of dtype " - f"{expected.dtype} into an out= with dtype torch.long") + msg_fail = ( + msg_fail + if not isinstance(expected, torch.Tensor) + else ( + "Expected RuntimeError when doing an unsafe cast from a result of dtype " + f"{expected.dtype} into an out= with dtype torch.long" + ) + ) with self.assertRaises(RuntimeError, msg=msg_fail): op_out(out=out) @@ -611,7 +767,9 @@ class TestCommon(TestCase): inplace = op.inplace_variant # list of all inplace ops: inplace variant + alias inplace variants if exist - inplace_ops = [inplace, ] + inplace_ops = [ + inplace, + ] variants = [method, inplace] for a_op in op.aliases: @@ -623,30 +781,46 @@ class TestCommon(TestCase): inplace_variants = tuple(filter(None, inplace_ops)) variants = tuple(filter(None, variants)) - _requires_grad = (dtype in op.supported_backward_dtypes(torch.device(device).type)) + _requires_grad = dtype in op.supported_backward_dtypes( + torch.device(device).type + ) include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex - samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs) + samples = op.sample_inputs( + device, + dtype, + requires_grad=_requires_grad, + include_conjugated_inputs=include_conjugated_inputs, + ) samples = list(samples) def _test_consistency_helper(samples, variants): for sample in samples: # TODO: Check grad for all Tensors requiring grad if sample.input is TensorList - tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0] + tensor = ( + sample.input + if isinstance(sample.input, torch.Tensor) + else sample.input[0] + ) # Computes function forward and backward values tensor.grad = None expected_forward = op(sample.input, *sample.args, **sample.kwargs) expected_grad = None - output_process_fn_grad = sample.output_process_fn_grad if sample.output_process_fn_grad \ + output_process_fn_grad = ( + sample.output_process_fn_grad + if sample.output_process_fn_grad else lambda x: x + ) # Skips inplace variants if the output dtype is not the same as # the input dtype skip_inplace = False - if (isinstance(expected_forward, torch.Tensor) and - expected_forward.dtype is not tensor.dtype): + if ( + isinstance(expected_forward, 
torch.Tensor) + and expected_forward.dtype is not tensor.dtype + ): skip_inplace = True # TODO: backward consistency only supported for single tensor outputs @@ -654,8 +828,9 @@ class TestCommon(TestCase): # tensor inputs # TODO: update to handle checking grads of all tensor inputs as # derived from each tensor output - if (isinstance(expected_forward, torch.Tensor) - and dtype in op.supported_backward_dtypes(torch.device(device).type)): + if isinstance( + expected_forward, torch.Tensor + ) and dtype in op.supported_backward_dtypes(torch.device(device).type): output_process_fn_grad(expected_forward).sum().backward() expected_grad = tensor.grad @@ -668,26 +843,35 @@ class TestCommon(TestCase): # Compares variant's forward # Note: copies the to-be-modified input when testing the inplace variant tensor.grad = None - cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input + cloned = ( + clone_input_helper(sample.input) + if variant in inplace_ops + else sample.input + ) if variant in inplace_ops and sample.broadcasts_input: - with self.assertRaises(RuntimeError, - msg=('inplace variant either incorrectly allowed ' - 'resizing or you have marked the sample {}' - ' incorrectly with `broadcasts_self=True'.format(sample.summary()))): - variant_forward = variant(cloned, - *sample.args, - **sample.kwargs) + with self.assertRaises( + RuntimeError, + msg=( + "inplace variant either incorrectly allowed " + "resizing or you have marked the sample {}" + " incorrectly with `broadcasts_self=True".format( + sample.summary() + ) + ), + ): + variant_forward = variant( + cloned, *sample.args, **sample.kwargs + ) continue - variant_forward = variant(cloned, - *sample.args, - **sample.kwargs) + variant_forward = variant(cloned, *sample.args, **sample.kwargs) self.assertEqual(expected_forward, variant_forward) # Compares variant's backward - if expected_grad is not None and \ - (variant not in inplace_ops or op.supports_inplace_autograd): + if expected_grad is not None and ( + variant not in inplace_ops or op.supports_inplace_autograd + ): output_process_fn_grad(variant_forward).sum().backward() self.assertEqual(expected_grad, tensor.grad) @@ -698,28 +882,45 @@ class TestCommon(TestCase): # Skips inplace variants if the output dtype is not the same as # the input dtype expected_forward = op(sample.input, *sample.args, **sample.kwargs) - tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0] + tensor = ( + sample.input + if isinstance(sample.input, torch.Tensor) + else sample.input[0] + ) skip_inplace = False - if (isinstance(expected_forward, torch.Tensor) and - expected_forward.dtype is not tensor.dtype): + if ( + isinstance(expected_forward, torch.Tensor) + and expected_forward.dtype is not tensor.dtype + ): skip_inplace = True if skip_inplace: return for variant in variants: - cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input - inp_tensor = cloned if isinstance(cloned, torch.Tensor) else cloned[0] + cloned = ( + clone_input_helper(sample.input) + if variant in inplace_ops + else sample.input + ) + inp_tensor = ( + cloned if isinstance(cloned, torch.Tensor) else cloned[0] + ) data_ptr = inp_tensor.data_ptr() - variant_forward = variant(cloned, - *sample.args, - **sample.kwargs) + variant_forward = variant(cloned, *sample.args, **sample.kwargs) # TODO Support non-tensor outputs if they exist for inplace ops - if (isinstance(variant_forward, torch.Tensor)): - self.assertEqual(data_ptr, variant_forward.data_ptr(), 
atol=0, rtol=0) + if isinstance(variant_forward, torch.Tensor): + self.assertEqual( + data_ptr, variant_forward.data_ptr(), atol=0, rtol=0 + ) else: - self.assertTrue(False, "Non-tensor outputs for inplace ops are not supported") + self.assertTrue( + False, + "Non-tensor outputs for inplace ops are not supported", + ) if len(inplace_ops) > 0: - inplace_samples = list(filter(lambda sample: not sample.broadcasts_input, samples)) + inplace_samples = list( + filter(lambda sample: not sample.broadcasts_input, samples) + ) _test_inplace_preserve_storage(inplace_samples, inplace_variants) # Reference testing for operations in complex32 against complex64. @@ -732,14 +933,21 @@ class TestCommon(TestCase): for sample in op.sample_inputs(device, dtype): actual = op(sample.input, *sample.args, **sample.kwargs) transformed_sample = sample.transform(lambda x: x.to(torch.complex64)) - expected = op(transformed_sample.input, *transformed_sample.args, **transformed_sample.kwargs) + expected = op( + transformed_sample.input, + *transformed_sample.args, + **transformed_sample.kwargs, + ) self.assertEqual(actual, expected, exact_dtype=False) + class TestCompositeCompliance(TestCase): # Checks if the operator (if it is composite) is written to support most # backends and Tensor subclasses. See "CompositeImplicitAutograd Compliance" # in aten/src/ATen/native/README.md for more details - @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, '__torch_dispatch__ does not work in fbcode') + @unittest.skipIf( + IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode" + ) @ops(op_db, allowed_dtypes=(torch.float,)) def test_operator(self, device, dtype, op): samples = op.sample_inputs(device, dtype, requires_grad=False) @@ -750,7 +958,9 @@ class TestCompositeCompliance(TestCase): composite_compliance.check_with_mode(op, args, kwargs) composite_compliance.check_all_permutations(op, args, kwargs) - @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, '__torch_dispatch__ does not work in fbcode') + @unittest.skipIf( + IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode" + ) @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,)) def test_backward(self, device, dtype, op): samples = op.sample_inputs(device, dtype, requires_grad=True) @@ -760,7 +970,9 @@ class TestCompositeCompliance(TestCase): kwargs = sample.kwargs composite_compliance.check_backward_formula(op, args, kwargs) - @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, '__torch_dispatch__ does not work in fbcode') + @unittest.skipIf( + IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode" + ) @ops(op_db, allowed_dtypes=(torch.float,)) def test_forward_ad(self, device, dtype, op): if torch.float not in op.supported_backward_dtypes(device): @@ -776,6 +988,7 @@ class TestCompositeCompliance(TestCase): kwargs = sample.kwargs composite_compliance.check_forward_ad_formula(op, args, kwargs) + class TestMathBits(TestCase): # Tests that # 1. The operator's output for physically conjugated/negated tensors and conjugate/negative view tensors @@ -787,7 +1000,17 @@ class TestMathBits(TestCase): # This test only runs for C -> R and C -> C functions # TODO: add tests for `R->C` functions # Note: This test runs for functions that take both tensors and tensorlists as input. 
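
For readers unfamiliar with the conj/neg "bit" terminology in the comments above: `torch.conj` returns a lazily conjugated view with the bit set, while `torch.conj_physical` materializes the values. A short illustration of the distinction that `_test_math_view` below exercises through OpInfo samples:

import torch

t = torch.tensor([1 + 2j, 3 - 4j])

view = torch.conj(t)                # lazy: only the conjugate bit is set
physical = torch.conj_physical(t)   # eager: imaginary parts are negated now

assert torch.is_conj(view)
assert not torch.is_conj(physical)
assert torch.equal(view.resolve_conj(), physical)
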
- def _test_math_view(self, device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set, out_type): + def _test_math_view( + self, + device, + dtype, + op, + samples, + math_op_physical, + math_op_view, + is_bit_set, + out_type, + ): inplace_variant = op.inplace_variant # helper function to clone and conjugate/negate the input if its a tensor @@ -796,7 +1019,7 @@ class TestMathBits(TestCase): # have its requires_grad set to that value. def clone_and_perform_view(input, **kwargs): if isinstance(input, torch.Tensor): - requires_grad = kwargs.get('requires_grad', input.requires_grad) + requires_grad = kwargs.get("requires_grad", input.requires_grad) with torch.no_grad(): # Ensure view represents the original sample input input = math_op_physical(input) @@ -813,7 +1036,11 @@ class TestMathBits(TestCase): return tuple(out) for sample in samples: - tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0] + tensor = ( + sample.input + if isinstance(sample.input, torch.Tensor) + else sample.input[0] + ) cloned1 = clone_and_perform_view(sample.input) # Computes function forward value with a physically conjugated/negated tensor and @@ -827,9 +1054,13 @@ class TestMathBits(TestCase): # input produces correct output, and the output tensor has the conj/neg bit set to True if inplace_variant is not None and not sample.broadcasts_input: cloned2 = clone_and_perform_view(tensor, requires_grad=False) - if (isinstance(expected_forward, torch.Tensor) and - expected_forward.dtype is tensor.dtype): - inplace_forward = inplace_variant(cloned2, *sample.args, **sample.kwargs) + if ( + isinstance(expected_forward, torch.Tensor) + and expected_forward.dtype is tensor.dtype + ): + inplace_forward = inplace_variant( + cloned2, *sample.args, **sample.kwargs + ) self.assertTrue(is_bit_set(inplace_forward)) self.assertEqual(inplace_forward, expected_forward) @@ -838,25 +1069,36 @@ class TestMathBits(TestCase): # tensor inputs # TODO: update to handle checking grads of all tensor inputs as # derived from each tensor output - if isinstance(expected_forward, torch.Tensor) and expected_forward.requires_grad: + if ( + isinstance(expected_forward, torch.Tensor) + and expected_forward.requires_grad + ): output_process_fn_grad = sample.output_process_fn_grad or (lambda x: x) expected_forward = output_process_fn_grad(expected_forward) forward_with_mathview = output_process_fn_grad(forward_with_mathview) - tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0] + tensor = ( + sample.input + if isinstance(sample.input, torch.Tensor) + else sample.input[0] + ) expected_forward.sum().backward(retain_graph=True) forward_with_mathview.sum().backward(retain_graph=True) if tensor.grad is not None: - cloned1_tensor = cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0] + cloned1_tensor = ( + cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0] + ) self.assertEqual(tensor.grad, cloned1_tensor.grad) tensor.grad, cloned1_tensor.grad = None, None # a repeat of the above test if output is not complex valued - if (out_type(expected_forward)): + if out_type(expected_forward): grad = torch.randn_like(expected_forward) expected_forward.backward(grad) - forward_with_mathview.backward(math_op_view(math_op_physical(grad))) + forward_with_mathview.backward( + math_op_view(math_op_physical(grad)) + ) self.assertEqual(tensor.grad, cloned1_tensor.grad) @@ -866,10 +1108,21 @@ class TestMathBits(TestCase): self.skipTest("Operation doesn't support conjugated inputs.") 
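
`clone_and_perform_view` above builds an input whose values equal the original sample but whose conj/neg bit is set, by applying the physical op under `no_grad` and then taking the view. A minimal sketch of that construction for the conjugate case, with a hypothetical helper name, is:

import torch

def conj_view_of(t, requires_grad=False):
    # Physically conjugate first (under no_grad), then take the conj view, so
    # the view's resolved values match the original tensor while is_conj() is
    # True -- the same trick clone_and_perform_view uses for OpInfo samples.
    with torch.no_grad():
        materialized = torch.conj_physical(t)
    view = torch.conj(materialized)
    return view.requires_grad_(requires_grad)

x = torch.tensor([1 + 1j, 2 - 3j])
y = conj_view_of(x)
assert torch.is_conj(y)
assert torch.equal(y.resolve_conj(), x)
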
math_op_physical = torch.conj_physical math_op_view = torch.conj - _requires_grad = torch.cfloat in op.supported_backward_dtypes(torch.device(device).type) + _requires_grad = torch.cfloat in op.supported_backward_dtypes( + torch.device(device).type + ) is_bit_set = torch.is_conj samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad) - self._test_math_view(device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set, torch.is_complex) + self._test_math_view( + device, + dtype, + op, + samples, + math_op_physical, + math_op_view, + is_bit_set, + torch.is_complex, + ) @ops(op_db, allowed_dtypes=(torch.double,)) def test_neg_view(self, device, dtype, op): @@ -879,8 +1132,16 @@ class TestMathBits(TestCase): math_op_view = torch._neg_view is_bit_set = torch.is_neg samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd) - self._test_math_view(device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set, - lambda x: True) + self._test_math_view( + device, + dtype, + op, + samples, + math_op_physical, + math_op_view, + is_bit_set, + lambda x: True, + ) @ops(op_db, allowed_dtypes=(torch.cdouble,)) def test_neg_conj_view(self, device, dtype, op): @@ -898,17 +1159,27 @@ class TestMathBits(TestCase): def is_bit_set(x): return torch.is_neg(x) and torch.is_conj(x) - _requires_grad = dtype in op.supported_backward_dtypes(torch.device(device).type) + _requires_grad = dtype in op.supported_backward_dtypes( + torch.device(device).type + ) samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad) # Only test one sample samples = itertools.islice(samples, 1) - self._test_math_view(device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set, - torch.is_complex) + self._test_math_view( + device, + dtype, + op, + samples, + math_op_physical, + math_op_view, + is_bit_set, + torch.is_complex, + ) instantiate_device_type_tests(TestCommon, globals()) instantiate_device_type_tests(TestCompositeCompliance, globals()) instantiate_device_type_tests(TestMathBits, globals()) -if __name__ == '__main__': +if __name__ == "__main__": run_tests() diff --git a/test/test_prims.py b/test/test_prims.py index de33d7d54750..5c77cf9d0b5c 100644 --- a/test/test_prims.py +++ b/test/test_prims.py @@ -1,85 +1,83 @@ # Owner(s): ["module: primTorch"] -# TODO: uncomment this file once CI issues with import nvfuser are resolved +from functools import partial -# from functools import partial - -# import torch -# from torch.testing import make_tensor -# from torch.testing._internal.common_utils import run_tests, TestCase -# from torch.testing._internal.common_device_type import ( -# instantiate_device_type_tests, -# onlyCUDA, -# dtypes, -# ) -# import torch._prims as prims -# from torch._prims.executor import make_traced +import torch +from torch.testing import make_tensor +from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, + onlyCUDA, + dtypes, +) +import torch._prims as prims +from torch._prims.executor import make_traced -# class TestPrims(TestCase): -# @onlyCUDA -# @dtypes(torch.float32) -# def test_broadcast_in_dim(self, device, dtype): -# def _wrapper(a, shape, broadcast_dimensions): -# return prims.broadcast_in_dim(a, shape, broadcast_dimensions) +class TestPrims(TestCase): + @onlyCUDA + @dtypes(torch.float32) + def test_broadcast_in_dim(self, device, dtype): + def _wrapper(a, shape, broadcast_dimensions): + return prims.broadcast_in_dim(a, shape, 
broadcast_dimensions) -# traced = make_traced(_wrapper) -# make_arg = partial(make_tensor, device=device, dtype=dtype) + traced = make_traced(_wrapper) + make_arg = partial(make_tensor, device=device, dtype=dtype) -# # TODO: FIXME: -# # for executor in ('aten', 'nvfuser'): -# for executor in ("aten",): -# fn = partial(traced, executor=executor) -# # Same shape -# shape = (5, 5) -# a = make_arg(shape) -# result = fn(a, shape, (0, 1)) + # TODO: FIXME: + # for executor in ('aten', 'nvfuser'): + for executor in ("aten",): + fn = partial(traced, executor=executor) + # Same shape + shape = (5, 5) + a = make_arg(shape) + result = fn(a, shape, (0, 1)) -# self.assertEqual(result.shape, a.shape) -# self.assertTrue(result.is_contiguous) -# self.assertEqual(a, result) + self.assertEqual(result.shape, a.shape) + self.assertTrue(result.is_contiguous) + self.assertEqual(a, result) -# # Error input: reordering dims -# with self.assertRaises(Exception): -# result = fn(a, shape, (1, 0)) + # Error input: reordering dims + with self.assertRaises(Exception): + result = fn(a, shape, (1, 0)) -# # Adding outermost dimensions -# a = make_arg((5, 5)) -# target_shape = (3, 3, 5, 5) -# result = fn(a, target_shape, (2, 3)) + # Adding outermost dimensions + a = make_arg((5, 5)) + target_shape = (3, 3, 5, 5) + result = fn(a, target_shape, (2, 3)) -# self.assertEqual(result.shape, target_shape) -# self.assertEqual(a.broadcast_to(target_shape), result) + self.assertEqual(result.shape, target_shape) + self.assertEqual(a.broadcast_to(target_shape), result) -# # Expands -# a = make_arg((1, 5, 1)) -# target_shape = (3, 5, 7) -# result = fn(a, target_shape, (0, 1, 2)) + # Expands + a = make_arg((1, 5, 1)) + target_shape = (3, 5, 7) + result = fn(a, target_shape, (0, 1, 2)) -# self.assertEqual(result.shape, target_shape) -# self.assertEqual(a.expand_as(result), result) + self.assertEqual(result.shape, target_shape) + self.assertEqual(a.expand_as(result), result) -# # Unsqueezes -# a = make_arg((1, 2, 3)) -# target_shape = (1, 2, 1, 3) -# result = fn(a, target_shape, (0, 1, 3)) + # Unsqueezes + a = make_arg((1, 2, 3)) + target_shape = (1, 2, 1, 3) + result = fn(a, target_shape, (0, 1, 3)) -# self.assertEqual(result.shape, target_shape) -# self.assertEqual(a.unsqueeze(2), result) + self.assertEqual(result.shape, target_shape) + self.assertEqual(a.unsqueeze(2), result) -# # Adds outermost, expands, and unsqueezes -# a = make_arg((1, 2, 3)) -# target_shape = (4, 1, 7, 2, 3, 3) -# result = fn(a, target_shape, (1, 3, 4)) + # Adds outermost, expands, and unsqueezes + a = make_arg((1, 2, 3)) + target_shape = (4, 1, 7, 2, 3, 3) + result = fn(a, target_shape, (1, 3, 4)) -# self.assertEqual(result.shape, target_shape) -# a.unsqueeze_(3) -# a.unsqueeze_(1) -# a.unsqueeze_(0) -# self.assertEqual(a.expand_as(result), result) + self.assertEqual(result.shape, target_shape) + a.unsqueeze_(3) + a.unsqueeze_(1) + a.unsqueeze_(0) + self.assertEqual(a.expand_as(result), result) -# instantiate_device_type_tests(TestPrims, globals()) +instantiate_device_type_tests(TestPrims, globals()) -# if __name__ == "__main__": -# run_tests() +if __name__ == "__main__": + run_tests() diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index a15bae17e9dd..2c1e20da8a65 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -13,7 +13,7 @@ import random from torch.testing import make_tensor from torch.testing._internal.common_utils import ( TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, 
suppress_warnings, - torch_to_numpy_dtype_dict, slowTest, + torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict, slowTest, TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS, parametrize) from torch.testing._internal.common_device_type import ( expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes, @@ -3683,13 +3683,13 @@ class TestBufferProtocol(TestCase): self.assertEqual(numpy_frombuffer.__array_interface__["data"][0], torch_frombuffer.data_ptr()) return (numpy_original, torch_frombuffer) - @dtypes(*torch_to_numpy_dtype_dict.keys()) + @dtypes(*set(numpy_to_torch_dtype_dict.values())) def test_same_type(self, device, dtype): self._run_test((), dtype) self._run_test((4,), dtype) self._run_test((10, 10), dtype) - @dtypes(*torch_to_numpy_dtype_dict.keys()) + @dtypes(*set(numpy_to_torch_dtype_dict.values())) def test_requires_grad(self, device, dtype): def _run_test_and_check_grad(requires_grad, *args, **kwargs): kwargs["requires_grad"] = requires_grad @@ -3704,14 +3704,14 @@ class TestBufferProtocol(TestCase): _run_test_and_check_grad(False, (4,), dtype) _run_test_and_check_grad(False, (10, 10), dtype) - @dtypes(*torch_to_numpy_dtype_dict.keys()) + @dtypes(*set(numpy_to_torch_dtype_dict.values())) def test_with_offset(self, device, dtype): # Offset should be valid whenever there is, at least, # one remaining element for i in range(SIZE): self._run_test(SHAPE, dtype, first=i) - @dtypes(*torch_to_numpy_dtype_dict.keys()) + @dtypes(*set(numpy_to_torch_dtype_dict.values())) def test_with_count(self, device, dtype): # Count should be valid for any valid in the interval # [-1, len(input)], except for 0 @@ -3719,7 +3719,7 @@ class TestBufferProtocol(TestCase): if i != 0: self._run_test(SHAPE, dtype, count=i) - @dtypes(*torch_to_numpy_dtype_dict.keys()) + @dtypes(*set(numpy_to_torch_dtype_dict.values())) def test_with_count_and_offset(self, device, dtype): # Explicit default count [-1, 1, 2, ..., len] for i in range(-1, SIZE + 1): @@ -3735,7 +3735,7 @@ class TestBufferProtocol(TestCase): for j in range(SIZE - i + 1): self._run_test(SHAPE, dtype, count=i, first=j) - @dtypes(*torch_to_numpy_dtype_dict.keys()) + @dtypes(*set(numpy_to_torch_dtype_dict.values())) def test_invalid_positional_args(self, device, dtype): bytes = get_dtype_size(dtype) in_bytes = SIZE * bytes @@ -3772,7 +3772,7 @@ class TestBufferProtocol(TestCase): rf"buffer length \({in_bytes} bytes\)"): self._run_test(SHAPE, dtype, count=count, first=first) - @dtypes(*torch_to_numpy_dtype_dict.keys()) + @dtypes(*set(numpy_to_torch_dtype_dict.values())) def test_shared_buffer(self, device, dtype): x = make_tensor((1,), dtype=dtype, device=device) # Modify the whole tensor @@ -3799,13 +3799,13 @@ class TestBufferProtocol(TestCase): arr[first] = x.item() - 1 self.assertEqual(arr[first:last], tensor) - @dtypes(*torch_to_numpy_dtype_dict.keys()) + @dtypes(*set(numpy_to_torch_dtype_dict.values())) def test_not_a_buffer(self, device, dtype): with self.assertRaisesRegex(ValueError, r"object does not implement Python buffer protocol."): torch.frombuffer([1, 2, 3, 4], dtype=dtype) - @dtypes(*torch_to_numpy_dtype_dict.keys()) + @dtypes(*set(numpy_to_torch_dtype_dict.values())) def test_non_writable_buffer(self, device, dtype): numpy_arr = make_tensor((1,), dtype=dtype, device=device).numpy() byte_arr = numpy_arr.tobytes() @@ -3910,7 +3910,7 @@ class TestAsArray(TestCase): self._test_alias_with_cvt(identity, device, dtype) @onlyCPU - @dtypes(*torch_to_numpy_dtype_dict.keys()) + 
@dtypes(*set(numpy_to_torch_dtype_dict.values())) def test_alias_from_numpy(self, device, dtype): self._test_alias_with_cvt(to_numpy, device, dtype) @@ -3921,7 +3921,7 @@ class TestAsArray(TestCase): self._test_alias_with_cvt(to_dlpack, device, dtype) @onlyCPU - @dtypes(*torch_to_numpy_dtype_dict.keys()) + @dtypes(*set(numpy_to_torch_dtype_dict.values())) def test_alias_from_buffer(self, device, dtype): self._test_alias_with_cvt(to_memview, device, dtype, shape=(5,), only_with_dtype=True) @@ -3959,7 +3959,7 @@ class TestAsArray(TestCase): self._test_copy_with_cvt(identity, device, dtype) @onlyCPU - @dtypes(*torch_to_numpy_dtype_dict.keys()) + @dtypes(*set(numpy_to_torch_dtype_dict.values())) def test_copy_from_numpy(self, device, dtype): self._test_copy_with_cvt(to_numpy, device, dtype) @@ -3969,7 +3969,7 @@ class TestAsArray(TestCase): self._test_copy_with_cvt(to_dlpack, device, dtype) @onlyCPU - @dtypes(*torch_to_numpy_dtype_dict.keys()) + @dtypes(*set(numpy_to_torch_dtype_dict.values())) def test_copy_from_buffer(self, device, dtype): self._test_copy_with_cvt(to_memview, device, dtype, shape=(5,), only_with_dtype=True) diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py index a157f49962d5..5ff2da736ead 100644 --- a/test/test_type_promotion.py +++ b/test/test_type_promotion.py @@ -7,15 +7,14 @@ import unittest import torch from torch.testing._internal.common_utils import (TestCase, run_tests, load_tests, - TEST_NUMPY, torch_to_numpy_dtype_dict) + TEST_NUMPY, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict) from torch.testing._internal.common_device_type import (instantiate_device_type_tests, onlyNativeDeviceTypes, dtypes, dtypesIfCUDA, onlyCPU, expectedFailureMeta, skipMeta) from torch.testing._internal.common_dtype import ( all_types_and_complex_and, all_types_and, get_all_math_dtypes, integral_types_and, floating_types_and ) -if TEST_NUMPY: - import numpy as np +import numpy as np # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for # sharding on sandcastle. 
This line silences flake warnings @@ -812,8 +811,8 @@ class TestTypePromotion(TestCase): @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @float_double_default_dtype @onlyCPU - @dtypes(*list(itertools.product(torch_to_numpy_dtype_dict.keys(), - torch_to_numpy_dtype_dict.keys()))) + @dtypes(*list(itertools.product(set(numpy_to_torch_dtype_dict.values()), + set(numpy_to_torch_dtype_dict.values())))) def test_numpy_array_binary_ufunc_promotion(self, device, dtypes): import operator np_type = torch_to_numpy_dtype_dict[dtypes[0]] diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 0c7661fa7393..3a1f77e37ca6 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -11,18 +11,47 @@ import unittest from torch._six import inf, nan from torch.testing._internal.common_utils import ( - TestCase, run_tests, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict, - suppress_warnings, TEST_SCIPY, slowTest, skipIfNoSciPy, IS_WINDOWS, gradcheck) + TestCase, + run_tests, + torch_to_numpy_dtype_dict, + numpy_to_torch_dtype_dict, + suppress_warnings, + TEST_SCIPY, + slowTest, + skipIfNoSciPy, + IS_WINDOWS, + gradcheck, + TEST_WITH_ASAN, +) from torch.testing._internal.common_methods_invocations import ( - unary_ufuncs, _NOTHING) + unary_ufuncs, + generate_elementwise_unary_tensors, + _NOTHING, + generate_elementwise_unary_small_value_tensors, + generate_elementwise_unary_large_value_tensors, + generate_elementwise_unary_extremal_value_tensors, +) from torch.testing._internal.common_device_type import ( - instantiate_device_type_tests, ops, dtypes, onlyCPU, onlyNativeDeviceTypes, - onlyCUDA, dtypesIfCUDA, precisionOverride, dtypesIfCPU) + instantiate_device_type_tests, + ops, + dtypes, + onlyCPU, + onlyNativeDeviceTypes, + onlyCUDA, + dtypesIfCUDA, + precisionOverride, + dtypesIfCPU, +) from torch.testing import make_tensor from torch.testing._internal.common_dtype import ( - floating_types_and, all_types_and_complex_and, integral_types_and, get_all_math_dtypes, - complex_types, all_types_and, floating_and_complex_types_and + floating_types_and, + all_types_and_complex_and, + integral_types_and, + get_all_math_dtypes, + complex_types, + all_types_and, + floating_and_complex_types_and, ) if TEST_SCIPY: @@ -45,140 +74,8 @@ reference_filtered_ops = list(filter(lambda op: op.ref is not _NOTHING, unary_uf # (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details # about the concept of ufuncs. -# Functions tested here: -# -# Interesting values and extremal values for different dtypes -_unsigned_int_vals = (0, 1, 55, 127) -_int_vals = (0, -1, 1, -55, 55, -127, 127, -128, 128) -_large_int_vals = (-1113, 1113, -10701, 10701) -_float_vals = (0., - -.001, .001, - -.25, .25, - -1., 1., - -math.pi / 2, math.pi / 2, - -math.pi + .00001, math.pi - .00001, - -math.pi, math.pi, - -math.pi - .00001, math.pi + .00001) -_large_float16_vals = (-501, 501, - -1001.2, 1001.2, - -13437.7, 13437.7) -_large_float_vals = _large_float16_vals + (-4988429.2, 4988429.2, -1e20, 1e20) -_float_extremals = (float('inf'), float('-inf'), float('nan')) -_medium_length = 812 -_large_size = (1029, 917) - - -# Replace values satisfying condition with a safe value. This is used to block -# out values the could cause singularity like tan(pi/2) -def replace_values_in_tensor(tensor, condition, safe_value): - mask = condition(tensor) - tensor.masked_fill_(mask, safe_value) - - -# Returns generator of tensors of different sizes filled with values in domain -# and with intested region filled with `vals`. 
This will help test different code -# paths for the given vals -# `filter_` can be either None or a tuple of (condition, safe_value). When not None -# values satisfying `condition`` will be replaced with `safe_value` in the generated -# tensor. This is useful to avoid singularities when generating inputs for tests, such -# as tan(pi/2) -def generate_tensors_from_vals(vals, device, dtype, domain, filter_): - offset = 63 - - assert _large_size[1] > (_medium_length + offset) # large tensor should be large enough - assert len(vals) < _medium_length # medium tensor should contain all vals - assert _medium_length % 4 == 0 # ensure vectorized code coverage - - if not dtype.is_complex: - # Filter values based on Operators domain. - # Note: Complex numbers don't belong to ordered field, - # so we don't filter for them. - if domain[0] is not None: - vals = list(filter(lambda x: x >= domain[0], vals)) - if domain[1] is not None: - vals = list(filter(lambda x: x < domain[1], vals)) - - if filter_ is not None: - condition, safe_value = filter_ - - # Constructs the large tensor containing vals - large_tensor = make_tensor(_large_size, device=device, dtype=dtype, low=domain[0], high=domain[1]) - - # Inserts the vals at an odd place - large_tensor[57][offset:offset + len(vals)] = torch.tensor(vals, device=device, dtype=dtype) - - if filter_ is not None: - replace_values_in_tensor(large_tensor, condition, safe_value) - - # Takes a medium sized copy of the large tensor containing vals - medium_tensor = large_tensor[57][offset:offset + _medium_length] - - if filter_ is not None: - replace_values_in_tensor(medium_tensor, condition, safe_value) - - # Constructs scalar tensors - scalar_tensors = (t.squeeze() for t in torch.split(medium_tensor, 1)) - - # Tensors with no elements - empty_sizes = ((0,), (0, 3, 3), (1, 0, 5), (6, 0, 0, 0), (3, 0, 1, 0)) - empty_tensors = (torch.empty(size, device=device, dtype=dtype) for size in empty_sizes) - - return chain(empty_tensors, scalar_tensors, (medium_tensor,), (large_tensor,)) - - -# [Note generate_numeric_tensors, generate_numeric_tensors_hard, -# and generate_numeric_tensors_extremal] -# -# Returns an iterable of contiguous tensors with the same storage on the requested -# device and with the requested dtype. -# -# This function is intended to test the non-vectorized and vectorized code -# paths of unary functions, as well as their handling of odd tensor -# sizes (like zero-dim tensors and tensors with zero elements). -# -# The iterable will include an empty tensor, tensors with no elements, -# zero dim (scalar) tensors, small 1D tensors, a medium 1D tensor, and -# a large 2D tensor. -# -# These tensors will include interesting values. The generate_numeric_tensors_hard -# tests larger values (>500) and generate_numeric_tensors_extremal tests extremal -# values like -inf, inf, and nan. -# -# The randomly generated values can be restricted by the domain -# argument. 
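[Editor's note] Illustrative sketch only, not part of the patch. The hand-rolled generators described above (generate_numeric_tensors and friends, removed below) are superseded by the shared generate_elementwise_unary_* helpers imported at the top of this file. Judging from the `t = t.input` unpacking added to _test_reference_numerics further down, those helpers appear to yield SampleInput-style objects rather than bare tensors; under that assumption, a minimal consumer looks roughly like this:

import torch
from torch.testing._internal.common_methods_invocations import (
    unary_ufuncs,
    generate_elementwise_unary_tensors,
)

# Pick any unary ufunc OpInfo; "sin" is just an example choice here.
op = next(o for o in unary_ufuncs if o.name == "sin")
for sample in generate_elementwise_unary_tensors(
    op, device="cpu", dtype=torch.float32, requires_grad=False
):
    t = sample.input  # the tensor itself lives on the sample wrapper
    print(t.shape, op(t).shape)  # OpInfo objects are callable like the op they wrap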
-def generate_numeric_tensors(device, dtype, *, - domain=(None, None), - filter_=None): - # Special-cases bool - if dtype is torch.bool: - tensors = (torch.empty(0, device=device, dtype=torch.bool), - torch.tensor(True, device=device), - torch.tensor(False, device=device), - torch.tensor((True, False), device=device), - make_tensor((_medium_length,), device=device, dtype=dtype, low=None, high=None), - make_tensor(_large_size, device=device, dtype=dtype, low=None, high=None)) - return tensors - - # Acquires dtype-specific vals - if dtype.is_floating_point or dtype.is_complex: - vals = _float_vals - - # Converts float -> complex vals if dtype is complex - if dtype.is_complex: - vals = tuple(complex(x, y) for x, y in product(vals, vals)) - elif dtype is torch.uint8: - vals = _unsigned_int_vals - else: # dtypes is a signed integer type - assert dtype in (torch.int8, torch.int16, torch.int32, torch.int64) - vals = _int_vals - - return generate_tensors_from_vals(vals, device, dtype, domain, filter_) - - -def generate_numeric_tensors_hard(device, dtype, *, - domain=(None, None), - filter_=None): +def generate_numeric_tensors_hard(device, dtype, *, domain=(None, None), filter_=None): is_signed_integral = dtype in (torch.int8, torch.int16, torch.int32, torch.int64) if not (dtype.is_floating_point or dtype.is_complex or is_signed_integral): return () @@ -190,18 +87,23 @@ def generate_numeric_tensors_hard(device, dtype, *, else: vals = _large_float_vals elif dtype.is_complex: - vals = tuple(complex(x, y) for x, y in chain(product(_large_float_vals, _large_float_vals), - product(_float_vals, _large_float_vals), - product(_large_float_vals, _float_vals))) + vals = tuple( + complex(x, y) + for x, y in chain( + product(_large_float_vals, _large_float_vals), + product(_float_vals, _large_float_vals), + product(_large_float_vals, _float_vals), + ) + ) else: vals = _large_int_vals return generate_tensors_from_vals(vals, device, dtype, domain, filter_) -def generate_numeric_tensors_extremal(device, dtype, *, - domain=(None, None), - filter_=None): +def generate_numeric_tensors_extremal( + device, dtype, *, domain=(None, None), filter_=None +): if not (dtype.is_floating_point or dtype.is_complex): return () @@ -209,9 +111,14 @@ def generate_numeric_tensors_extremal(device, dtype, *, if dtype.is_floating_point: vals = _float_extremals elif dtype.is_complex: - vals = tuple(complex(x, y) for x, y in chain(product(_float_extremals, _float_extremals), - product(_float_vals, _float_extremals), - product(_float_extremals, _float_vals))) + vals = tuple( + complex(x, y) + for x, y in chain( + product(_float_extremals, _float_extremals), + product(_float_vals, _float_extremals), + product(_float_extremals, _float_vals), + ) + ) return generate_tensors_from_vals(vals, device, dtype, domain, filter_) @@ -221,8 +128,10 @@ def generate_numeric_tensors_extremal(device, dtype, *, class TestUnaryUfuncs(TestCase): exact_dtype = True - @ops([_fn for _fn in unary_ufuncs if _fn.domain != (None, None)], - allowed_dtypes=floating_types_and(torch.bfloat16, torch.half)) + @ops( + [_fn for _fn in unary_ufuncs if _fn.domain != (None, None)], + allowed_dtypes=floating_types_and(torch.bfloat16, torch.half), + ) def test_float_domains(self, device, dtype, op): eps = (1e-5, 1e-3, 1e-1, 1, 2, 10, 20, 50, 100) @@ -240,11 +149,14 @@ class TestUnaryUfuncs(TestCase): continue result = op(lower_tensor) - self.assertEqual(result.item(), float('nan'), - msg=("input of {0} outside lower domain boundary" - " {1} produced {2}, not 
nan!").format(lower_tensor.item(), - low, - result.item())) + self.assertEqual( + result.item(), + float("nan"), + msg=( + "input of {0} outside lower domain boundary" + " {1} produced {2}, not nan!" + ).format(lower_tensor.item(), low, result.item()), + ) if high is not None: high_tensor = torch.tensor(high, device=device, dtype=dtype) @@ -256,15 +168,20 @@ class TestUnaryUfuncs(TestCase): continue result = op(higher_tensor) - self.assertEqual(result.item(), float('nan'), - msg=("input of {0} outside upper domain boundary" - " {1} produced {2}, not nan!").format(higher_tensor.item(), - high, - result.item())) + self.assertEqual( + result.item(), + float("nan"), + msg=( + "input of {0} outside upper domain boundary" + " {1} produced {2}, not nan!" + ).format(higher_tensor.item(), high, result.item()), + ) # Helper for comparing torch tensors and numpy arrays # TODO: should this or assertEqual also validate that strides are equal? - def assertEqualHelper(self, actual, expected, msg, *, dtype, exact_dtype=True, **kwargs): + def assertEqualHelper( + self, actual, expected, msg, *, dtype, exact_dtype=True, **kwargs + ): assert isinstance(actual, torch.Tensor) # Some NumPy functions return scalars, not arrays @@ -273,50 +190,96 @@ class TestUnaryUfuncs(TestCase): elif isinstance(expected, np.ndarray): # Handles exact dtype comparisons between arrays and tensors if exact_dtype: - if actual.dtype is torch.bfloat16 or expected.dtype != torch_to_numpy_dtype_dict[actual.dtype]: + if ( + actual.dtype is torch.bfloat16 + or expected.dtype != torch_to_numpy_dtype_dict[actual.dtype] + ): # Allows array dtype to be float32 when comparing with bfloat16 tensors # since NumPy doesn't support the bfloat16 dtype # Also ops like scipy.special.erf, scipy.special.erfc, etc, promote float16 # to float32 if expected.dtype == np.float32: - assert actual.dtype in (torch.float16, torch.bfloat16, torch.float32) + assert actual.dtype in ( + torch.float16, + torch.bfloat16, + torch.float32, + ) elif expected.dtype == np.float64: - assert actual.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64) + assert actual.dtype in ( + torch.float16, + torch.bfloat16, + torch.float32, + torch.float64, + ) else: - self.fail("Expected dtype {0} but got {1}!".format( - expected.dtype, actual.dtype)) + self.fail( + "Expected dtype {0} but got {1}!".format( + expected.dtype, actual.dtype + ) + ) - self.assertEqual(actual, - torch.from_numpy(expected).to(actual.dtype), - msg, - exact_device=False, - **kwargs) + self.assertEqual( + actual, + torch.from_numpy(expected).to(actual.dtype), + msg, + exact_device=False, + **kwargs + ) else: self.assertEqual(actual, expected, msg, exact_device=False, **kwargs) # Tests that the function and its (array-accepting) reference produce the same # values on given tensors def _test_reference_numerics(self, dtype, op, tensors, equal_nan=True): - def _helper_reference_numerics(expected, actual, msg, exact_dtype, equal_nan=True): - if not torch.can_cast(numpy_to_torch_dtype_dict[expected.dtype.type], dtype): + def _helper_reference_numerics( + expected, actual, msg, exact_dtype, equal_nan=True + ): + if not torch.can_cast( + numpy_to_torch_dtype_dict[expected.dtype.type], dtype + ): exact_dtype = False if dtype in [torch.uint8, torch.int8, torch.bool]: # NOTE: For these dtypes, PyTorch computes in the default scalar type (float) # while NumPy computes in float16 - self.assertEqualHelper(actual, expected, msg, dtype=dtype, - exact_dtype=exact_dtype, rtol=1e-3, atol=1e-2) + 
self.assertEqualHelper( + actual, + expected, + msg, + dtype=dtype, + exact_dtype=exact_dtype, + rtol=1e-3, + atol=1e-2, + ) elif dtype is torch.bfloat16: # Ref: https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_utils.py#L1149 - self.assertEqualHelper(actual, expected, msg, dtype=dtype, - exact_dtype=exact_dtype, rtol=16e-3, atol=1e-5) + self.assertEqualHelper( + actual, + expected, + msg, + dtype=dtype, + exact_dtype=exact_dtype, + rtol=16e-3, + atol=1e-5, + ) + else: - self.assertEqualHelper(actual, expected, msg, dtype=dtype, equal_nan=equal_nan, exact_dtype=exact_dtype) + self.assertEqualHelper( + actual, + expected, + msg, + dtype=dtype, + equal_nan=equal_nan, + exact_dtype=exact_dtype, + ) for t in tensors: + t = t.input torch_kwargs, numpy_kwargs = op.sample_kwargs(t.device, dtype, t) if dtype is torch.bfloat16: a = t.cpu().to(torch.float32).numpy() + elif dtype is torch.complex32: + a = t.cpu().to(torch.complex64).numpy() else: a = t.cpu().numpy() @@ -325,15 +288,19 @@ class TestUnaryUfuncs(TestCase): # Crafts a custom error message for smaller, printable tensors if t.numel() < 10: - msg = ("Failed to produce expected results! Input tensor was" - " {0}, torch result is {1}, and reference result is" - " {2}.").format(t, actual, expected) + msg = ( + "Failed to produce expected results! Input tensor was" + " {0}, torch result is {1}, and reference result is" + " {2}." + ).format(t, actual, expected) else: msg = None exact_dtype = True if isinstance(actual, torch.Tensor): - _helper_reference_numerics(expected, actual, msg, exact_dtype, equal_nan) + _helper_reference_numerics( + expected, actual, msg, exact_dtype, equal_nan + ) else: for x, y in zip(expected, actual): # testing multi-outputs results @@ -343,58 +310,72 @@ class TestUnaryUfuncs(TestCase): # values on a range of tensors, including empty tensors, scalar tensors, # 1D tensors and a large 2D tensor with interesting and extremal values # and noncontiguities. 
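[Editor's note] Illustrative sketch only, not part of the patch. This is the comparison pattern used by _test_reference_numerics above, in miniature: dtypes with no NumPy counterpart (bfloat16 and, with this patch, complex32) are upcast before calling the NumPy reference, and the results are compared with a loose tolerance. The helper name and the default tolerances below mirror the bfloat16 branch above but are illustrative, not normative.

import numpy as np
import torch

def check_against_numpy_ref(torch_fn, np_fn, t, rtol=16e-3, atol=1e-5):
    # Upcast dtypes that NumPy cannot represent before computing the reference.
    if t.dtype is torch.bfloat16:
        a = t.cpu().to(torch.float32).numpy()
    elif t.dtype is torch.complex32:
        a = t.cpu().to(torch.complex64).numpy()
    else:
        a = t.cpu().numpy()
    expected = torch.from_numpy(np.asarray(np_fn(a)))
    actual = torch_fn(t).cpu().to(expected.dtype)
    torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

# Example: bfloat16 sine checked against the float32 NumPy reference.
check_against_numpy_ref(torch.sin, np.sin, torch.randn(64, dtype=torch.bfloat16))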
+ @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") @suppress_warnings @ops(reference_filtered_ops) def test_reference_numerics_normal(self, device, dtype, op): - tensors = generate_numeric_tensors(device, dtype, - domain=op.domain, - filter_=op.reference_numerics_filter) + tensors = generate_elementwise_unary_tensors( + op, device=device, dtype=dtype, requires_grad=False + ) self._test_reference_numerics(dtype, op, tensors) + @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") @suppress_warnings - @ops(reference_filtered_ops, allowed_dtypes=floating_and_complex_types_and( - torch.bfloat16, torch.half, torch.int8, torch.int16, torch.int32, torch.int64 - )) - def test_reference_numerics_hard(self, device, dtype, op): - if not op.handles_large_floats: - raise self.skipTest("This op does not handle large values") + @ops(reference_filtered_ops) + def test_reference_numerics_small(self, device, dtype, op): + if dtype in (torch.bool,): + raise self.skipTest("bool has no small values") - tensors = generate_numeric_tensors_hard(device, dtype, - domain=op.domain) + tensors = generate_elementwise_unary_small_value_tensors( + op, device=device, dtype=dtype, requires_grad=False + ) self._test_reference_numerics(dtype, op, tensors) + @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") @suppress_warnings - @ops(reference_filtered_ops, - allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half)) + @ops(reference_filtered_ops) + def test_reference_numerics_large(self, device, dtype, op): + if dtype in (torch.bool, torch.uint8, torch.int8): + raise self.skipTest("bool, uint8, and int8 dtypes have no large values") + + tensors = generate_elementwise_unary_large_value_tensors( + op, device=device, dtype=dtype, requires_grad=False + ) + self._test_reference_numerics(dtype, op, tensors) + + @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") + @suppress_warnings + @ops( + reference_filtered_ops, + allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), + ) def test_reference_numerics_extremal(self, device, dtype, op): - handles_extremals = (op.handles_complex_extremals if - dtype in (torch.cfloat, torch.cdouble) else op.handles_extremals) - if not handles_extremals: - raise self.skipTest("This op does not handle extremal values") - - tensors = generate_numeric_tensors_extremal(device, dtype, - domain=op.domain) - + tensors = generate_elementwise_unary_extremal_value_tensors( + op, device=device, dtype=dtype, requires_grad=False + ) self._test_reference_numerics(dtype, op, tensors) # Tests for testing (non)contiguity consistency - @ops(unary_ufuncs) def test_contig_vs_every_other(self, device, dtype, op): - contig = make_tensor((1026,), device=device, dtype=dtype, - low=op.domain[0], high=op.domain[1]) + contig = make_tensor( + (1026,), device=device, dtype=dtype, low=op.domain[0], high=op.domain[1] + ) non_contig = contig[::2] self.assertTrue(contig.is_contiguous()) self.assertFalse(non_contig.is_contiguous()) torch_kwargs, _ = op.sample_kwargs(device, dtype, non_contig) - self.assertEqual(op(contig, **torch_kwargs)[::2], op(non_contig, **torch_kwargs)) + self.assertEqual( + op(contig, **torch_kwargs)[::2], op(non_contig, **torch_kwargs) + ) @ops(unary_ufuncs) def test_contig_vs_transposed(self, device, dtype, op): - contig = make_tensor((789, 357), device=device, dtype=dtype, - low=op.domain[0], high=op.domain[1]) + contig = make_tensor( + (789, 357), device=device, dtype=dtype, low=op.domain[0], high=op.domain[1] + ) non_contig = contig.T 
self.assertTrue(contig.is_contiguous()) @@ -407,8 +388,9 @@ class TestUnaryUfuncs(TestCase): def test_non_contig(self, device, dtype, op): shapes = [(5, 7), (1024,)] for shape in shapes: - contig = make_tensor(shape, dtype=dtype, device=device, - low=op.domain[0], high=op.domain[1]) + contig = make_tensor( + shape, dtype=dtype, device=device, low=op.domain[0], high=op.domain[1] + ) non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0] non_contig.copy_(contig) @@ -420,8 +402,13 @@ class TestUnaryUfuncs(TestCase): @ops(unary_ufuncs) def test_non_contig_index(self, device, dtype, op): - contig = make_tensor((2, 2, 1, 2), dtype=dtype, device=device, - low=op.domain[0], high=op.domain[1]) + contig = make_tensor( + (2, 2, 1, 2), + dtype=dtype, + device=device, + low=op.domain[0], + high=op.domain[1], + ) non_contig = contig[:, 1, ...] contig = non_contig.contiguous() @@ -435,8 +422,9 @@ class TestUnaryUfuncs(TestCase): def test_non_contig_expand(self, device, dtype, op): shapes = [(1, 3), (1, 7), (5, 7)] for shape in shapes: - contig = make_tensor(shape, dtype=dtype, device=device, - low=op.domain[0], high=op.domain[1]) + contig = make_tensor( + shape, dtype=dtype, device=device, low=op.domain[0], high=op.domain[1] + ) non_contig = contig.clone().expand(3, -1, -1) self.assertTrue(contig.is_contiguous()) @@ -446,13 +434,15 @@ class TestUnaryUfuncs(TestCase): contig = op(contig, **torch_kwargs) non_contig = op(non_contig, **torch_kwargs) for i in range(3): - self.assertEqual(contig, non_contig[i], - msg='non-contiguous expand[' + str(i) + ']') + self.assertEqual( + contig, non_contig[i], msg="non-contiguous expand[" + str(i) + "]" + ) @ops(unary_ufuncs) def test_contig_size1(self, device, dtype, op): - contig = make_tensor((5, 100), dtype=dtype, device=device, - low=op.domain[0], high=op.domain[1]) + contig = make_tensor( + (5, 100), dtype=dtype, device=device, low=op.domain[0], high=op.domain[1] + ) contig = contig[:1, :50] contig2 = torch.empty(contig.size(), device=device, dtype=dtype) contig2.copy_(contig) @@ -465,8 +455,13 @@ class TestUnaryUfuncs(TestCase): @ops(unary_ufuncs) def test_contig_size1_large_dim(self, device, dtype, op): - contig = make_tensor((5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4), dtype=dtype, device=device, - low=op.domain[0], high=op.domain[1]) + contig = make_tensor( + (5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4), + dtype=dtype, + device=device, + low=op.domain[0], + high=op.domain[1], + ) contig = contig[:1, :, :, :, :, :, :, :, :, :, :, :] contig2 = torch.empty(contig.size(), device=device, dtype=dtype) contig2.copy_(contig) @@ -481,8 +476,9 @@ class TestUnaryUfuncs(TestCase): # per-batch computation. @ops(unary_ufuncs) def test_batch_vs_slicing(self, device, dtype, op): - input = make_tensor((1024, 512), dtype=dtype, device=device, - low=op.domain[0], high=op.domain[1]) + input = make_tensor( + (1024, 512), dtype=dtype, device=device, low=op.domain[0], high=op.domain[1] + ) torch_kwargs, _ = op.sample_kwargs(device, dtype, input) actual = op(input, **torch_kwargs) @@ -493,11 +489,11 @@ class TestUnaryUfuncs(TestCase): @dtypes(*all_types_and(torch.bool, torch.half)) def test_nan_to_num(self, device, dtype): for contiguous in [False, True]: - x = make_tensor((64, 64), low=0., high=100., dtype=dtype, device=device) + x = make_tensor((64, 64), low=0.0, high=100.0, dtype=dtype, device=device) if dtype.is_floating_point: # Add extremal values. 
- extremals = [float('nan'), float('inf'), -float('inf')] + extremals = [float("nan"), float("inf"), -float("inf")] for idx, extremal in zip(torch.randint(0, 63, (3,)), extremals): x[idx, :] = extremal @@ -509,12 +505,16 @@ class TestUnaryUfuncs(TestCase): posinf = random.random() * 5 neginf = random.random() * 10 - self.compare_with_numpy(lambda x: x.nan_to_num(nan=nan, posinf=posinf), - lambda x: np.nan_to_num(x, nan=nan, posinf=posinf), - x) - self.compare_with_numpy(lambda x: x.nan_to_num(posinf=posinf, neginf=neginf), - lambda x: np.nan_to_num(x, posinf=posinf, neginf=neginf), - x) + self.compare_with_numpy( + lambda x: x.nan_to_num(nan=nan, posinf=posinf), + lambda x: np.nan_to_num(x, nan=nan, posinf=posinf), + x, + ) + self.compare_with_numpy( + lambda x: x.nan_to_num(posinf=posinf, neginf=neginf), + lambda x: np.nan_to_num(x, posinf=posinf, neginf=neginf), + x, + ) # Out Variant out = torch.empty_like(x) @@ -529,7 +529,7 @@ class TestUnaryUfuncs(TestCase): @dtypes(torch.cdouble) def test_complex_edge_values(self, device, dtype): # sqrt Test Reference: https://github.com/pytorch/pytorch/pull/47424 - x = torch.tensor(0. - 1.0e+20j, dtype=dtype, device=device) + x = torch.tensor(0.0 - 1.0e20j, dtype=dtype, device=device) self.compare_with_numpy(torch.sqrt, np.sqrt, x) # acos test reference: https://github.com/pytorch/pytorch/issue/42952 # Skip on Windows, as CUDA acos returns conjugate value @@ -537,7 +537,11 @@ class TestUnaryUfuncs(TestCase): if not (IS_WINDOWS and dtype == torch.cdouble and "cuda" in device): self.compare_with_numpy(torch.acos, np.arccos, x) - x = torch.tensor((-1.0e+60 if dtype == torch.cdouble else -1.0e+20) - 4988429.2j, dtype=dtype, device=device) + x = torch.tensor( + (-1.0e60 if dtype == torch.cdouble else -1.0e20) - 4988429.2j, + dtype=dtype, + device=device, + ) self.compare_with_numpy(torch.sqrt, np.sqrt, x) @unittest.skipIf(not TEST_SCIPY, "Requires SciPy") @@ -547,14 +551,28 @@ class TestUnaryUfuncs(TestCase): # Reference: # https://github.com/scipy/scipy/blob/3a8a3a1d4657254a6611e77e9c28feafa26e6645/scipy/special/tests/test_digamma.py#L22 euler = 0.57721566490153286 - dataset = [(0., -0.), - (1, -euler), - (0.5, -2 * math.log(2) - euler), - (1 / 3, -math.pi / (2 * math.sqrt(3)) - 3 * math.log(3) / 2 - euler), - (1 / 4, -math.pi / 2 - 3 * math.log(2) - euler), - (1 / 6, -math.pi * math.sqrt(3) / 2 - 2 * math.log(2) - 3 * math.log(3) / 2 - euler), - (1 / 8, -math.pi / 2 - 4 * math.log(2) - - (math.pi + math.log(2 + math.sqrt(2)) - math.log(2 - math.sqrt(2))) / math.sqrt(2) - euler)] + dataset = [ + (0.0, -0.0), + (1, -euler), + (0.5, -2 * math.log(2) - euler), + (1 / 3, -math.pi / (2 * math.sqrt(3)) - 3 * math.log(3) / 2 - euler), + (1 / 4, -math.pi / 2 - 3 * math.log(2) - euler), + ( + 1 / 6, + -math.pi * math.sqrt(3) / 2 + - 2 * math.log(2) + - 3 * math.log(3) / 2 + - euler, + ), + ( + 1 / 8, + -math.pi / 2 + - 4 * math.log(2) + - (math.pi + math.log(2 + math.sqrt(2)) - math.log(2 - math.sqrt(2))) + / math.sqrt(2) + - euler, + ), + ] x = torch.tensor(dataset, device=device, dtype=dtype) self.compare_with_numpy(torch.digamma, scipy.special.digamma, x) @@ -562,9 +580,24 @@ class TestUnaryUfuncs(TestCase): @dtypes(torch.float, torch.double) def test_digamma(self, device, dtype): # Tests pole behavior - tensor = torch.tensor([-0.999999994, -1.999999994, -2.0000000111, - -100.99999994, 0.000000111, -1931.99999994, - -0.000000111, 0, -0, -1, -2, -931], dtype=dtype, device=device) + tensor = torch.tensor( + [ + -0.999999994, + -1.999999994, + -2.0000000111, + 
-100.99999994, + 0.000000111, + -1931.99999994, + -0.000000111, + 0, + -0, + -1, + -2, + -931, + ], + dtype=dtype, + device=device, + ) self.compare_with_numpy(torch.digamma, scipy.special.digamma, tensor) @dtypes(*floating_types_and(torch.half)) @@ -584,19 +617,25 @@ class TestUnaryUfuncs(TestCase): invalid_input_dtypes = integral_types_and(torch.bool) + complex_types() for dtype in invalid_input_dtypes: input = make_tensor((50, 50), dtype=dtype, device=device) - with self.assertRaisesRegex(RuntimeError, r"torch\.frexp\(\) only supports floating-point dtypes"): + with self.assertRaisesRegex( + RuntimeError, r"torch\.frexp\(\) only supports floating-point dtypes" + ): torch.frexp(input) for dtype in floating_types_and(torch.half): input = make_tensor((50, 50), dtype=dtype, device=device) - dtypes = list(all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16)) + dtypes = list( + all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16) + ) dtypes.remove(dtype) for mantissa_dtype in dtypes: mantissa = torch.empty_like(input, dtype=mantissa_dtype) exponent = torch.empty_like(input, dtype=torch.int) - with self.assertRaisesRegex(RuntimeError, - r"torch\.frexp\(\) expects mantissa to have dtype .+ but got .+"): + with self.assertRaisesRegex( + RuntimeError, + r"torch\.frexp\(\) expects mantissa to have dtype .+ but got .+", + ): torch.frexp(input, out=(mantissa, exponent)) dtypes.append(dtype) @@ -604,8 +643,10 @@ class TestUnaryUfuncs(TestCase): for exponent_dtype in dtypes: mantissa = torch.empty_like(input) exponent = torch.empty_like(input, dtype=exponent_dtype) - with self.assertRaisesRegex(RuntimeError, - r"torch\.frexp\(\) expects exponent to have int dtype but got .+"): + with self.assertRaisesRegex( + RuntimeError, + r"torch\.frexp\(\) expects exponent to have int dtype but got .+", + ): torch.frexp(input, out=(mantissa, exponent)) def test_mvlgamma_argcheck(self, device): @@ -613,17 +654,21 @@ class TestUnaryUfuncs(TestCase): input = torch.linspace((d - 2) / 2, 10, 10, device=device) torch.mvlgamma(input, d) - with self.assertRaisesRegex(RuntimeError, r"All elements must be greater than \(p-1\)/2"): + with self.assertRaisesRegex( + RuntimeError, r"All elements must be greater than \(p-1\)/2" + ): run_test(3) def test_polygamma_neg(self, device): - with self.assertRaisesRegex(RuntimeError, r'polygamma\(n, x\) does not support negative n\.'): + with self.assertRaisesRegex( + RuntimeError, r"polygamma\(n, x\) does not support negative n\." 
+ ): torch.polygamma(-1, torch.tensor([1.0, 2.0], device=device)) # TODO resolve with opinfos @onlyCPU def test_op_invert(self, device): - res = 0xffff - torch.arange(127, dtype=torch.int8) + res = 0xFFFF - torch.arange(127, dtype=torch.int8) for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): a = torch.arange(127, dtype=dtype) self.assertEqual(res.to(dtype), ~a) @@ -640,16 +685,19 @@ class TestUnaryUfuncs(TestCase): def test_abs_angle_complex_to_float(self, device, dtype): # Constructs random complex values from random import random + random_vals = [] for multiplier in (-1, 1, -10, 10, -100, 100): for _ in range(10): - random_vals.append(complex(random() * multiplier, random() * multiplier)) + random_vals.append( + complex(random() * multiplier, random() * multiplier) + ) for vals in (random_vals, []): a = np.array(vals, dtype=torch_to_numpy_dtype_dict[dtype]) t = torch.tensor(vals, device=device, dtype=dtype) - for fn_name in ('abs', 'angle'): + for fn_name in ("abs", "angle"): torch_fn = getattr(torch, fn_name) np_fn = getattr(np, fn_name) @@ -659,12 +707,16 @@ class TestUnaryUfuncs(TestCase): self.assertEqual(np_result, torch_result, exact_dtype=True) # Tests float out - float_dtype = torch.float32 if dtype is torch.complex64 else torch.float64 + float_dtype = ( + torch.float32 if dtype is torch.complex64 else torch.float64 + ) np_float_out = np_fn(a).astype(torch_to_numpy_dtype_dict[float_dtype]) float_out = torch.empty_like(t).float() torch_fn(t, out=float_out) # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(torch.from_numpy(np_float_out), float_out.cpu()) + self.assertEqualIgnoreType( + torch.from_numpy(np_float_out), float_out.cpu() + ) # Tests float out (resized out) float_out = torch.empty(1, device=device, dtype=float_dtype) @@ -676,13 +728,17 @@ class TestUnaryUfuncs(TestCase): complex_out = torch.empty_like(t) torch_fn(t, out=complex_out) # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(torch.from_numpy(np_complex_out), complex_out.cpu()) + self.assertEqualIgnoreType( + torch.from_numpy(np_complex_out), complex_out.cpu() + ) # Tests complex out (resized out) complex_out = torch.empty(0, device=device, dtype=dtype) torch_fn(t, out=complex_out) # TODO(#38095): Replace assertEqualIgnoreType. 
See issue #38095 - self.assertEqualIgnoreType(torch.from_numpy(np_complex_out), complex_out.cpu()) + self.assertEqualIgnoreType( + torch.from_numpy(np_complex_out), complex_out.cpu() + ) # Tests long out behavior (expected failure) long_out = torch.empty(0, device=device, dtype=torch.long) @@ -690,40 +746,42 @@ class TestUnaryUfuncs(TestCase): torch_fn(t, out=long_out) # Tests inplace - if fn_name == 'abs': + if fn_name == "abs": torch_inplace_method = getattr(torch.Tensor, fn_name + "_") np_fn(a, out=a) if dtype.is_complex: - with self.assertRaisesRegex(RuntimeError, "In-place abs is not supported for complex tensors."): + with self.assertRaisesRegex( + RuntimeError, + "In-place abs is not supported for complex tensors.", + ): torch_inplace_method(t) return torch_inplace_method(t) self.assertEqual(torch.from_numpy(a), t.cpu()) # Note: angle does not have an in-place variant - if fn_name == 'angle': + if fn_name == "angle": with self.assertRaises(AttributeError): torch_inplace_method = getattr(torch.Tensor, fn_name + "_") - def check_internal_mem_overlap(self, inplace_op, num_inputs, - dtype, device, - expected_failure=False): + def check_internal_mem_overlap( + self, inplace_op, num_inputs, dtype, device, expected_failure=False + ): if isinstance(inplace_op, str): inplace_op = getattr(torch.Tensor, inplace_op) input = torch.randn(1, dtype=dtype, device=device).expand(3, 3) - inputs = [input] + [torch.randn_like(input) - for i in range(num_inputs - 1)] + inputs = [input] + [torch.randn_like(input) for i in range(num_inputs - 1)] if not expected_failure: - with self.assertRaisesRegex(RuntimeError, 'single memory location'): + with self.assertRaisesRegex(RuntimeError, "single memory location"): inplace_op(*inputs) else: with self.assertRaises(AssertionError): - with self.assertRaisesRegex(RuntimeError, 'single memory location'): + with self.assertRaisesRegex(RuntimeError, "single memory location"): inplace_op(*inputs) - def unary_check_input_output_mem_overlap(self, data, sz, op, - expected_failure=False): - + def unary_check_input_output_mem_overlap( + self, data, sz, op, expected_failure=False + ): def _test(op, output, input): output_exp = torch.empty_like(output) op(input, out=output_exp) @@ -732,15 +790,15 @@ class TestUnaryUfuncs(TestCase): # output is identical to input: _test(op, output=data[0:sz], input=data[0:sz]) # output and input are independent: - _test(op, output=data[0:sz], input=data[sz:2 * sz]) + _test(op, output=data[0:sz], input=data[sz : 2 * sz]) # output partially overlaps with input: if not expected_failure: - with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): - _test(op, data[0:sz], data[1:sz + 1]) + with self.assertRaisesRegex(RuntimeError, "unsupported operation"): + _test(op, data[0:sz], data[1 : sz + 1]) else: with self.assertRaises(AssertionError): - with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): - _test(op, data[0:sz], data[1:sz + 1]) + with self.assertRaisesRegex(RuntimeError, "unsupported operation"): + _test(op, data[0:sz], data[1 : sz + 1]) # TODO: run on non-native device types @dtypes(torch.double) @@ -750,131 +808,159 @@ class TestUnaryUfuncs(TestCase): positives = torch.randint(1, 100, (2 * sz,), device=device).double() ints = torch.randint(-100, 100, (2 * sz,), device=device) unary_mem_overlap_cases = [ - ("abs", doubles, True, True, 'cpu'), - ("abs", doubles, True, True, 'cuda'), - ("acos", doubles, True, True, 'cpu'), - ("acos", doubles, True, True, 'cuda'), - ("asin", doubles, True, True, 'cpu'), - ("asin", 
doubles, True, True, 'cuda'), - ("atan", doubles, True, True, 'cpu'), - ("atan", doubles, True, True, 'cuda'), - ("acosh", doubles, True, True, 'cpu'), - ("acosh", doubles, True, True, 'cuda'), - ("asinh", doubles, True, True, 'cpu'), - ("asinh", doubles, True, True, 'cuda'), - ("atanh", doubles, True, True, 'cpu'), - ("atanh", doubles, True, True, 'cuda'), - ("bitwise_not", ints, True, True, 'cpu'), - ("bitwise_not", ints, True, True, 'cuda'), - ("ceil", doubles, True, True, 'cpu'), - ("ceil", doubles, True, True, 'cuda'), - ("cos", doubles, True, True, 'cpu'), - ("cos", doubles, True, True, 'cuda'), - ("cosh", doubles, True, True, 'cpu'), - ("cosh", doubles, True, True, 'cuda'), - ("digamma", doubles, True, True, 'cpu'), - ("erf", doubles, True, True, 'cpu'), - ("erf", doubles, True, True, 'cuda'), - ("erfc", doubles, True, True, 'cpu'), - ("erfc", doubles, True, True, 'cuda'), - ("erfinv", doubles, True, True, 'cpu'), - ("erfinv", doubles, True, True, 'cuda'), - ("exp", doubles, True, True, 'cpu'), - ("exp", doubles, True, True, 'cuda'), - ("exp2", doubles, True, True, 'cpu'), - ("exp2", doubles, True, True, 'cuda'), - ("expm1", doubles, True, True, 'cpu'), - ("expm1", doubles, True, True, 'cuda'), - ("floor", doubles, True, True, 'cpu'), - ("floor", doubles, True, True, 'cuda'), - ("frac", doubles, True, True, 'cpu'), - ("frac", doubles, True, True, 'cuda'), - ("i0", doubles, True, True, 'cpu'), - ("i0", doubles, True, True, 'cuda'), - ("log", positives, True, True, 'cpu'), - ("log", positives, True, True, 'cuda'), - ("log10", positives, True, True, 'cpu'), - ("log10", positives, True, True, 'cuda'), - ("log1p", positives, True, True, 'cpu'), - ("log1p", positives, True, True, 'cuda'), - ("log2", positives, True, True, 'cpu'), - ("log2", positives, True, True, 'cuda'), - ("neg", doubles, True, True, 'cpu'), - ("neg", doubles, True, True, 'cuda'), - ("reciprocal", doubles, True, True, 'cpu'), - ("reciprocal", doubles, True, True, 'cuda'), - ("round", doubles, True, True, 'cpu'), - ("round", doubles, True, True, 'cuda'), - ("rsqrt", positives, True, True, 'cpu'), - ("rsqrt", positives, True, True, 'cuda'), - ("sin", doubles, True, True, 'cpu'), - ("sin", doubles, True, True, 'cuda'), - ("sinh", doubles, True, True, 'cpu'), - ("sinh", doubles, False, True, 'cuda'), - ("sigmoid", doubles, True, True, 'cpu'), - ("sigmoid", doubles, True, True, 'cuda'), - ("logit", doubles, True, True, 'cpu'), - ("logit", doubles, True, True, 'cuda'), - ("sqrt", doubles, True, True, 'cpu'), - ("sqrt", doubles, False, True, 'cuda'), - ("tan", doubles, True, True, 'cpu'), - ("tan", doubles, True, True, 'cuda'), - ("tanh", doubles, True, True, 'cpu'), - ("tanh", doubles, True, True, 'cuda'), - ("trunc", doubles, True, True, 'cpu'), - ("trunc", doubles, True, True, 'cuda') + ("abs", doubles, True, True, "cpu"), + ("abs", doubles, True, True, "cuda"), + ("acos", doubles, True, True, "cpu"), + ("acos", doubles, True, True, "cuda"), + ("asin", doubles, True, True, "cpu"), + ("asin", doubles, True, True, "cuda"), + ("atan", doubles, True, True, "cpu"), + ("atan", doubles, True, True, "cuda"), + ("acosh", doubles, True, True, "cpu"), + ("acosh", doubles, True, True, "cuda"), + ("asinh", doubles, True, True, "cpu"), + ("asinh", doubles, True, True, "cuda"), + ("atanh", doubles, True, True, "cpu"), + ("atanh", doubles, True, True, "cuda"), + ("bitwise_not", ints, True, True, "cpu"), + ("bitwise_not", ints, True, True, "cuda"), + ("ceil", doubles, True, True, "cpu"), + ("ceil", doubles, True, True, "cuda"), + ("cos", 
doubles, True, True, "cpu"), + ("cos", doubles, True, True, "cuda"), + ("cosh", doubles, True, True, "cpu"), + ("cosh", doubles, True, True, "cuda"), + ("digamma", doubles, True, True, "cpu"), + ("erf", doubles, True, True, "cpu"), + ("erf", doubles, True, True, "cuda"), + ("erfc", doubles, True, True, "cpu"), + ("erfc", doubles, True, True, "cuda"), + ("erfinv", doubles, True, True, "cpu"), + ("erfinv", doubles, True, True, "cuda"), + ("exp", doubles, True, True, "cpu"), + ("exp", doubles, True, True, "cuda"), + ("exp2", doubles, True, True, "cpu"), + ("exp2", doubles, True, True, "cuda"), + ("expm1", doubles, True, True, "cpu"), + ("expm1", doubles, True, True, "cuda"), + ("floor", doubles, True, True, "cpu"), + ("floor", doubles, True, True, "cuda"), + ("frac", doubles, True, True, "cpu"), + ("frac", doubles, True, True, "cuda"), + ("i0", doubles, True, True, "cpu"), + ("i0", doubles, True, True, "cuda"), + ("log", positives, True, True, "cpu"), + ("log", positives, True, True, "cuda"), + ("log10", positives, True, True, "cpu"), + ("log10", positives, True, True, "cuda"), + ("log1p", positives, True, True, "cpu"), + ("log1p", positives, True, True, "cuda"), + ("log2", positives, True, True, "cpu"), + ("log2", positives, True, True, "cuda"), + ("neg", doubles, True, True, "cpu"), + ("neg", doubles, True, True, "cuda"), + ("reciprocal", doubles, True, True, "cpu"), + ("reciprocal", doubles, True, True, "cuda"), + ("round", doubles, True, True, "cpu"), + ("round", doubles, True, True, "cuda"), + ("rsqrt", positives, True, True, "cpu"), + ("rsqrt", positives, True, True, "cuda"), + ("sin", doubles, True, True, "cpu"), + ("sin", doubles, True, True, "cuda"), + ("sinh", doubles, True, True, "cpu"), + ("sinh", doubles, False, True, "cuda"), + ("sigmoid", doubles, True, True, "cpu"), + ("sigmoid", doubles, True, True, "cuda"), + ("logit", doubles, True, True, "cpu"), + ("logit", doubles, True, True, "cuda"), + ("sqrt", doubles, True, True, "cpu"), + ("sqrt", doubles, False, True, "cuda"), + ("tan", doubles, True, True, "cpu"), + ("tan", doubles, True, True, "cuda"), + ("tanh", doubles, True, True, "cpu"), + ("tanh", doubles, True, True, "cuda"), + ("trunc", doubles, True, True, "cpu"), + ("trunc", doubles, True, True, "cuda"), ] - for (fn, inputs, has_input_output_mem_overlap_check, - has_internal_mem_overlap_check, dev) in unary_mem_overlap_cases: + for ( + fn, + inputs, + has_input_output_mem_overlap_check, + has_internal_mem_overlap_check, + dev, + ) in unary_mem_overlap_cases: if dev != device: continue out_fn = getattr(torch, fn) - in_fn = getattr(torch.Tensor, fn + '_') + in_fn = getattr(torch.Tensor, fn + "_") - self.unary_check_input_output_mem_overlap(inputs, sz, out_fn, - expected_failure=not has_input_output_mem_overlap_check) + self.unary_check_input_output_mem_overlap( + inputs, + sz, + out_fn, + expected_failure=not has_input_output_mem_overlap_check, + ) - self.check_internal_mem_overlap(in_fn, 1, dtype, dev, - expected_failure=not has_internal_mem_overlap_check) + self.check_internal_mem_overlap( + in_fn, + 1, + dtype, + dev, + expected_failure=not has_internal_mem_overlap_check, + ) # TODO: opinfo hardshrink @onlyCPU @dtypes(torch.float, torch.double, torch.bfloat16) def test_hardshrink(self, device, dtype): data = torch.tensor([1, 0.5, 0.3, 0.6], dtype=dtype, device=device).view(2, 2) - self.assertEqual(torch.tensor([1, 0.5, 0, 0.6], dtype=dtype, device=device).view(2, 2), - data.hardshrink(0.3)) - self.assertEqual(torch.tensor([1, 0, 0, 0.6], dtype=dtype, 
device=device).view(2, 2), - data.hardshrink(0.5)) + self.assertEqual( + torch.tensor([1, 0.5, 0, 0.6], dtype=dtype, device=device).view(2, 2), + data.hardshrink(0.3), + ) + self.assertEqual( + torch.tensor([1, 0, 0, 0.6], dtype=dtype, device=device).view(2, 2), + data.hardshrink(0.5), + ) # test default lambd=0.5 self.assertEqual(data.hardshrink(), data.hardshrink(0.5)) # test non-contiguous case - self.assertEqual(torch.tensor([1, 0, 0.5, 0.6], dtype=dtype, device=device).view(2, 2), - data.t().hardshrink(0.3)) + self.assertEqual( + torch.tensor([1, 0, 0.5, 0.6], dtype=dtype, device=device).view(2, 2), + data.t().hardshrink(0.3), + ) @onlyCPU @dtypes(torch.float, torch.double, torch.bfloat16) def test_hardshrink_edge_cases(self, device, dtype) -> None: def h(values, l_expected): for l, expected in l_expected.items(): - values_tensor = torch.tensor([float(v) for v in values], - dtype=dtype, device=device) - expected_tensor = torch.tensor([float(v) for v in expected], - dtype=dtype, device=device) - self.assertEqual(expected_tensor == values_tensor.hardshrink(l), - torch.ones_like(values_tensor, dtype=torch.bool)) + values_tensor = torch.tensor( + [float(v) for v in values], dtype=dtype, device=device + ) + expected_tensor = torch.tensor( + [float(v) for v in expected], dtype=dtype, device=device + ) + self.assertEqual( + expected_tensor == values_tensor.hardshrink(l), + torch.ones_like(values_tensor, dtype=torch.bool), + ) def test_helper(min, max): - h([0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], - {0.0: [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], - min: [0.0, 0.0, 0.0, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], - 0.1: [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, -1.0, max, -max, inf, -inf], - 1.0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, max, -max, inf, -inf], - max: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, inf, -inf], - inf: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}) + h( + [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], + { + 0.0: [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], + min: [0.0, 0.0, 0.0, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], + 0.1: [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, -1.0, max, -max, inf, -inf], + 1.0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, max, -max, inf, -inf], + max: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, inf, -inf], + inf: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + }, + ) test_helper(torch.finfo(dtype).tiny, torch.finfo(dtype).max) @@ -886,32 +972,36 @@ class TestUnaryUfuncs(TestCase): # Test for https://github.com/pytorch/pytorch/issues/17271 # This is pretty slow on my Macbook but it only takes a few # seconds on a beefy Xeon server - a = torch.exp(torch.ones(2 ** 31, dtype=dtype, device=device)) + a = torch.exp(torch.ones(2**31, dtype=dtype, device=device)) b = torch.exp(torch.ones(1, dtype=dtype, device=device)) - self.assertEqual(a, b.expand(2 ** 31)) + self.assertEqual(a, b.expand(2**31)) - @precisionOverride({torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}) + @precisionOverride( + {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002} + ) @dtypes(torch.float, torch.double, torch.bfloat16) def test_hardswish(self, device, dtype): inputValues = [-1000, -4, -3, -2, 0, 2, 3, 4, 1000] expectedOutput = np.multiply( - inputValues, - np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0) + inputValues, np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0 + ) inputTensor = torch.tensor(inputValues, dtype=dtype, device=device) - 
expectedOutputTensor = \ - torch.tensor(expectedOutput, dtype=dtype, device=device) + expectedOutputTensor = torch.tensor(expectedOutput, dtype=dtype, device=device) # normal - self.assertEqual(torch.nn.functional.hardswish(inputTensor), - expectedOutputTensor) + self.assertEqual( + torch.nn.functional.hardswish(inputTensor), expectedOutputTensor + ) # inplace inputTensorCpy = inputTensor.clone().detach() torch.nn.functional.hardswish(inputTensorCpy, inplace=True) self.assertEqual(inputTensorCpy, expectedOutputTensor) - @precisionOverride({torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}) + @precisionOverride( + {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002} + ) @dtypes(torch.float, torch.double, torch.bfloat16) def test_hardsigmoid(self, device, dtype): inputValues = [-1000, -4, -3, -2, 0, 2, 3, 4, 1000] @@ -920,20 +1010,28 @@ class TestUnaryUfuncs(TestCase): inputTensor = torch.tensor(inputValues, dtype=dtype, device=device) # normal - self.assertEqual(torch.nn.functional.hardsigmoid(inputTensor), - torch.tensor(expectedOutput, dtype=dtype, device=device)) + self.assertEqual( + torch.nn.functional.hardsigmoid(inputTensor), + torch.tensor(expectedOutput, dtype=dtype, device=device), + ) # inplace inputTensorCpy = inputTensor.clone().detach() - self.assertEqual(torch.nn.functional.hardsigmoid(inputTensorCpy, inplace=True), - torch.tensor(expectedOutput, dtype=dtype, device=device)) + self.assertEqual( + torch.nn.functional.hardsigmoid(inputTensorCpy, inplace=True), + torch.tensor(expectedOutput, dtype=dtype, device=device), + ) - @precisionOverride({torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002}) + @precisionOverride( + {torch.bfloat16: 1e-2, torch.float: 0.0002, torch.double: 0.0002} + ) @dtypes(torch.float, torch.double, torch.bfloat16) def test_hardsigmoid_backward(self, device, dtype): inputValues = [-3.0, 3.0, -2.0, 2.0, -6.0, 6.0] expectedValues = [0.0, 0.0, 1.0 / 6.0, 1.0 / 6.0, 0.0, 0.0] - inputTensor = torch.tensor(inputValues, dtype=dtype, device=device).requires_grad_() + inputTensor = torch.tensor( + inputValues, dtype=dtype, device=device + ).requires_grad_() expetedTensor = torch.tensor(expectedValues, dtype=dtype, device=device) out = torch.nn.functional.hardsigmoid(inputTensor) out.backward(torch.ones_like(inputTensor)) @@ -945,7 +1043,8 @@ class TestUnaryUfuncs(TestCase): input_np = np.random.randn(5, 8) special_input = [[-1000, -1, -0.1, 0, 0.5, 1, 2, 1000]] input_np = np.concatenate((input_np, special_input), axis=0).astype( - torch_to_numpy_dtype_dict[dtype]) + torch_to_numpy_dtype_dict[dtype] + ) expected_output_np = input_np * scipy.special.expit(input_np) expected_output = torch.from_numpy(expected_output_np).to(device) @@ -955,18 +1054,30 @@ class TestUnaryUfuncs(TestCase): rtol = 1e-6 input = torch.from_numpy(input_np).clone().contiguous().to(device) - self.assertEqual(torch.nn.functional.silu(input), expected_output, - atol=atol, rtol=rtol) - self.assertEqual(torch.nn.functional.silu(input, inplace=True), - expected_output, atol=atol, rtol=rtol) + self.assertEqual( + torch.nn.functional.silu(input), expected_output, atol=atol, rtol=rtol + ) + self.assertEqual( + torch.nn.functional.silu(input, inplace=True), + expected_output, + atol=atol, + rtol=rtol, + ) input = torch.from_numpy(input_np).clone().to(device) input_noncontig = input.transpose(0, 1) - self.assertEqual(torch.nn.functional.silu(input_noncontig), - expected_output_noncontig, atol=atol, rtol=rtol) - self.assertEqual(torch.nn.functional.silu( - 
input_noncontig, inplace=True), expected_output_noncontig, - atol=atol, rtol=rtol) + self.assertEqual( + torch.nn.functional.silu(input_noncontig), + expected_output_noncontig, + atol=atol, + rtol=rtol, + ) + self.assertEqual( + torch.nn.functional.silu(input_noncontig, inplace=True), + expected_output_noncontig, + atol=atol, + rtol=rtol, + ) # It is not obvious how to merge this into OpInfo becuase these inputs # succeed for gradcheck but are expected to fail for gradgradcheck @@ -977,10 +1088,12 @@ class TestUnaryUfuncs(TestCase): # We also need to be careful when we are very close to 0, as the # derivative's denominator is squared, and there are some floats # that are positive and whose squares are zero. - a = torch.tensor([0.0, torch.finfo(torch.double).tiny, 1.0], - dtype=dtype, - requires_grad=True, - device=device) + a = torch.tensor( + [0.0, torch.finfo(torch.double).tiny, 1.0], + dtype=dtype, + requires_grad=True, + device=device, + ) gradcheck(torch.sinc, a) @skipIfNoSciPy @@ -989,7 +1102,8 @@ class TestUnaryUfuncs(TestCase): input_np = np.random.randn(5, 8) special_input = [[-1000, -1, -0.1, 0, 0.5, 1, 2, 1000]] input_np = np.concatenate((input_np, special_input), axis=0).astype( - torch_to_numpy_dtype_dict[dtype]) + torch_to_numpy_dtype_dict[dtype] + ) expected_output_np = input_np * np.tanh(np.log1p(np.exp(input_np))) expected_output = torch.from_numpy(expected_output_np).to(device) @@ -999,34 +1113,50 @@ class TestUnaryUfuncs(TestCase): rtol = 1e-6 input = torch.from_numpy(input_np).clone().contiguous().to(device) - self.assertEqual(torch.nn.functional.mish(input), expected_output, - atol=atol, rtol=rtol) - self.assertEqual(torch.nn.functional.mish(input, inplace=True), - expected_output, atol=atol, rtol=rtol) + self.assertEqual( + torch.nn.functional.mish(input), expected_output, atol=atol, rtol=rtol + ) + self.assertEqual( + torch.nn.functional.mish(input, inplace=True), + expected_output, + atol=atol, + rtol=rtol, + ) input = torch.from_numpy(input_np).clone().to(device) input_noncontig = input.transpose(0, 1) - self.assertEqual(torch.nn.functional.mish(input_noncontig), - expected_output_noncontig, atol=atol, rtol=rtol) - self.assertEqual(torch.nn.functional.mish( - input_noncontig, inplace=True), expected_output_noncontig, - atol=atol, rtol=rtol) + self.assertEqual( + torch.nn.functional.mish(input_noncontig), + expected_output_noncontig, + atol=atol, + rtol=rtol, + ) + self.assertEqual( + torch.nn.functional.mish(input_noncontig, inplace=True), + expected_output_noncontig, + atol=atol, + rtol=rtol, + ) # do ops like threshold need a test_unary(_nonufunc) test suite? 
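The silu and mish checks above build their expected outputs from NumPy/SciPy references computed inline. A minimal standalone sketch of those two reference formulas, silu(x) = x * sigmoid(x) and mish(x) = x * tanh(softplus(x)), assuming SciPy is available (the tests already gate on TEST_SCIPY):

import numpy as np
from scipy.special import expit  # numerically stable sigmoid

x = np.array([-4.0, -0.1, 0.0, 0.5, 2.0])
silu_ref = x * expit(x)                      # reference used by test_silu
mish_ref = x * np.tanh(np.log1p(np.exp(x)))  # softplus(x) = log1p(exp(x)), used by test_mish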
@onlyCPU - @dtypes(*get_all_math_dtypes('cpu')) + @dtypes(*get_all_math_dtypes("cpu")) def test_threshold(self, device, dtype): if dtype != torch.uint8 and dtype != torch.float16 and not dtype.is_complex: # 100 is wide enough to use AVX2 instructions for all types - x = torch.randn(100, dtype=torch.float, device=device).sign().to(dtype=dtype) + x = ( + torch.randn(100, dtype=torch.float, device=device) + .sign() + .to(dtype=dtype) + ) y = torch.threshold(x, 0, 0) self.assertTrue(y.le(0).any()) - def _helper_test_igamma(self, loglo, loghi, device, dtype, - torch_fcn, scipy_fcn): + def _helper_test_igamma(self, loglo, loghi, device, dtype, torch_fcn, scipy_fcn): exp1 = 2.71828182846 - vec1 = torch.logspace(loglo, loghi, steps=500, base=exp1, - dtype=torch.float64, device=device).unsqueeze(-1) + vec1 = torch.logspace( + loglo, loghi, steps=500, base=exp1, dtype=torch.float64, device=device + ).unsqueeze(-1) vec1 = vec1.to(dtype) inputs = [ (vec1, vec1.transpose(0, 1)), @@ -1034,8 +1164,8 @@ class TestUnaryUfuncs(TestCase): (vec1, 0.5 * vec1), # test for considerable ratio (vec1, 2.0 * vec1), (vec1[::2, :], vec1[::2, :]), # contiguous/noncontiguous tests - (vec1[::2, :], vec1[:vec1.shape[0] // 2, :]), - (vec1[:vec1.shape[0] // 2, :], vec1[::2, :]), + (vec1[::2, :], vec1[: vec1.shape[0] // 2, :]), + (vec1[: vec1.shape[0] // 2, :], vec1[::2, :]), ] half_prec = dtype in [torch.bfloat16, torch.float16] for input0, input1 in inputs: @@ -1055,8 +1185,9 @@ class TestUnaryUfuncs(TestCase): # test igamma for reasonable range of values loglo = -4 # approx 0.018 loghi = 4 # approx 54.6 - self._helper_test_igamma(loglo, loghi, device, dtype, - torch.igamma, scipy.special.gammainc) + self._helper_test_igamma( + loglo, loghi, device, dtype, torch.igamma, scipy.special.gammainc + ) @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64) @dtypes(torch.float32, torch.float64) @@ -1066,8 +1197,9 @@ class TestUnaryUfuncs(TestCase): # test igammac for reasonable range of values loglo = -4 # approx 0.018 loghi = 4 # approx 54.6 - self._helper_test_igamma(loglo, loghi, device, dtype, - torch.igammac, scipy.special.gammaincc) + self._helper_test_igamma( + loglo, loghi, device, dtype, torch.igammac, scipy.special.gammaincc + ) @dtypesIfCPU(torch.float16, torch.bfloat16, torch.float32, torch.float64) @dtypes(torch.float32, torch.float64) @@ -1077,8 +1209,8 @@ class TestUnaryUfuncs(TestCase): infs = torch.zeros((3,), **tkwargs) + float("inf") zeros = torch.zeros((3,), **tkwargs) ones = torch.ones((3,), **tkwargs) - zero_to_large = torch.tensor([0., 1., 1e3], **tkwargs) - small_to_inf = torch.tensor([1e-3, 1., float("inf")], **tkwargs) + zero_to_large = torch.tensor([0.0, 1.0, 1e3], **tkwargs) + small_to_inf = torch.tensor([1e-3, 1.0, float("inf")], **tkwargs) nans = torch.zeros((3,), **tkwargs) + float("nan") inpouts = [ # (a , x), out @@ -1106,8 +1238,8 @@ class TestUnaryUfuncs(TestCase): infs = torch.zeros((3,), **tkwargs) + float("inf") zeros = torch.zeros((3,), **tkwargs) ones = torch.ones((3,), **tkwargs) - zero_to_large = torch.tensor([0., 1., 1e3], **tkwargs) - small_to_inf = torch.tensor([1e-3, 1., float("inf")], **tkwargs) + zero_to_large = torch.tensor([0.0, 1.0, 1e3], **tkwargs) + small_to_inf = torch.tensor([1e-3, 1.0, float("inf")], **tkwargs) nans = torch.zeros((3,), **tkwargs) + float("nan") inpouts = [ # (a , x), out @@ -1261,7 +1393,7 @@ class TestUnaryUfuncs(TestCase): # TODO: allow large opinfo values to be opted-into via metadata @dtypes(torch.long) def test_abs_big_number(self, 
device, dtype): - bignumber = 2 ** 31 + 1 + bignumber = 2**31 + 1 res = torch.tensor([bignumber], device=device, dtype=dtype) self.assertGreater(res.abs()[0], 0) @@ -1290,11 +1422,13 @@ class TestUnaryUfuncs(TestCase): def test_isposinf_isneginf_non_boolean_output(self, device, dtype): # test non-boolean tensors as the `out=` parameters # boolean outputs are tested in the above testcases - vals = (float('inf'), -float('inf'), 1.2) + vals = (float("inf"), -float("inf"), 1.2) t = torch.tensor(vals, device=device) for torch_op in (torch.isposinf, torch.isneginf): out = torch.empty_like(t, dtype=dtype) - with self.assertRaisesRegex(RuntimeError, 'does not support non-boolean outputs'): + with self.assertRaisesRegex( + RuntimeError, "does not support non-boolean outputs" + ): torch_op(t, out=out) def test_nonzero_empty(self, device): @@ -1332,7 +1466,12 @@ class TestUnaryUfuncs(TestCase): @dtypesIfCUDA(*floating_and_complex_types_and(torch.half, torch.bfloat16)) def test_exp(self, device, dtype): for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()): - a = torch.tensor(v, dtype=dtype, device=device) * torch.arange(18, device=device) / 3 * math.pi + a = ( + torch.tensor(v, dtype=dtype, device=device) + * torch.arange(18, device=device) + / 3 + * math.pi + ) a = a.to(dtype) # bfloat16 overflows if dtype == torch.bfloat16: @@ -1340,10 +1479,12 @@ class TestUnaryUfuncs(TestCase): self.compare_with_numpy(torch.exp, np.exp, a) if dtype.is_complex: - inf_real_zero_imag_in = torch.tensor(complex(float('inf'), 0), device=device, dtype=dtype) + inf_real_zero_imag_in = torch.tensor( + complex(float("inf"), 0), device=device, dtype=dtype + ) inf_real_zero_imag_out = torch.exp(inf_real_zero_imag_in).item() self.assertTrue(math.isinf(inf_real_zero_imag_out.real)) - if self.device_type == 'cpu': + if self.device_type == "cpu": pass # These are commented out because it cannot be consistently reproduced. # This is incorrect. It should be zero. Need fix! @@ -1357,16 +1498,20 @@ class TestUnaryUfuncs(TestCase): self.assertEqual(inf_real_zero_imag_out.imag, 0, atol=0, rtol=0) self.compare_with_numpy(torch.exp, np.exp, inf_real_zero_imag_in) - zero_real_inf_imag_in = torch.tensor(complex(0, float('inf')), device=device, dtype=dtype) + zero_real_inf_imag_in = torch.tensor( + complex(0, float("inf")), device=device, dtype=dtype + ) zero_real_inf_imag_out = torch.exp(zero_real_inf_imag_in).item() self.assertTrue(math.isnan(zero_real_inf_imag_out.real)) self.assertTrue(math.isnan(zero_real_inf_imag_out.imag)) # Ensure we are notified when NumPy changes its behavior self.compare_with_numpy(torch.exp, np.exp, zero_real_inf_imag_in) - inf_real_imag_in = torch.tensor(complex(float('inf'), float('inf')), device=device, dtype=dtype) + inf_real_imag_in = torch.tensor( + complex(float("inf"), float("inf")), device=device, dtype=dtype + ) inf_real_imag_out = torch.exp(inf_real_imag_in).item() - if self.device_type == 'cpu': + if self.device_type == "cpu": pass # This is incorrect. Need fix! https://github.com/pytorch/pytorch/issues/40590 # This is commented out because it cannot be consistently reproduced. 
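test_exp above (and continuing below) asserts C99-style special values for complex exponentials and cross-checks them against NumPy. A rough sketch of the two cases already covered, assuming NumPy follows the same conventions on the platform in use:

import math
import numpy as np

inf = float("inf")
# exp(inf + 0j): the real part is expected to be infinite
assert math.isinf(np.exp(complex(inf, 0.0)).real)
# exp(0 + inf*j): both real and imaginary parts are expected to be NaN
z = np.exp(complex(0.0, inf))
assert math.isnan(z.real) and math.isnan(z.imag)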
@@ -1377,9 +1522,11 @@ class TestUnaryUfuncs(TestCase): self.assertTrue(math.isnan(inf_real_imag_out.imag)) self.compare_with_numpy(torch.exp, np.exp, inf_real_imag_in) - inf_real_nan_imag_in = torch.tensor(complex(float('inf'), float('nan')), device=device, dtype=dtype) + inf_real_nan_imag_in = torch.tensor( + complex(float("inf"), float("nan")), device=device, dtype=dtype + ) inf_real_nan_imag_out = torch.exp(inf_real_nan_imag_in).item() - if self.device_type == 'cpu': + if self.device_type == "cpu": pass # This is incorrect. It should be inf. Need fix! https://github.com/pytorch/pytorch/issues/40590 # This is commented out because it cannot be consistently reproduced. @@ -1390,7 +1537,9 @@ class TestUnaryUfuncs(TestCase): self.assertTrue(math.isnan(inf_real_nan_imag_out.imag)) self.compare_with_numpy(torch.exp, np.exp, inf_real_nan_imag_in) - nan_real_inf_imag_in = torch.tensor(complex(float('nan'), float('inf')), device=device, dtype=dtype) + nan_real_inf_imag_in = torch.tensor( + complex(float("nan"), float("inf")), device=device, dtype=dtype + ) nan_real_inf_imag_out = torch.exp(nan_real_inf_imag_in).item() self.assertTrue(math.isnan(nan_real_inf_imag_out.real)) self.assertTrue(math.isnan(nan_real_inf_imag_out.imag)) @@ -1400,5 +1549,5 @@ class TestUnaryUfuncs(TestCase): instantiate_device_type_tests(TestUnaryUfuncs, globals()) -if __name__ == '__main__': +if __name__ == "__main__": run_tests() diff --git a/test/test_view_ops.py b/test/test_view_ops.py index 3fcfa72cf45e..2fac09a5f425 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -11,7 +11,7 @@ import random from torch.testing import make_tensor from torch.testing._internal.common_utils import ( TestCase, run_tests, suppress_warnings, gradcheck, gradgradcheck, - torch_to_numpy_dtype_dict, + numpy_to_torch_dtype_dict, ) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes, skipMeta) @@ -130,7 +130,7 @@ class TestViewOps(TestCase): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool)) def test_view_dtype_new(self, device, dtype): - dtypes = torch_to_numpy_dtype_dict.copy() + dtypes = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()} del dtypes[torch.bool] def generate_inputs(): diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index 3c5d13be6371..23ca7252533f 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -2,14 +2,15 @@ import torch from torch import Tensor import torch._prims.utils as utils -from torch._prims.utils import TensorLike, TensorLikeType, TensorMeta, ShapeType +from torch._prims.utils import ( + TensorLike, + TensorLikeType, + TensorMeta, + ShapeType, + getnvFuserDtype, +) from torch.overrides import has_torch_function, handle_torch_function -import torch._C._nvfuser as nvfuser # type: ignore[import] - -FusionDefinition = nvfuser.FusionDefinition # type: ignore[attr-defined] -DataType = nvfuser.DataType # type: ignore[attr-defined] - from typing import Sequence, Optional, Union, Callable, List, Tuple, Any from numbers import Number from functools import reduce @@ -1033,21 +1034,8 @@ def _convert_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor: return a.to(dtype) -_torch_dtype_to_nvfuser_dtype_map = { - torch.cdouble: DataType.ComplexDouble, - torch.cfloat: DataType.ComplexFloat, - torch.double: DataType.Double, - torch.float: DataType.Float, - torch.half: DataType.Half, - torch.bfloat16: DataType.BFloat16, - torch.long: DataType.Int, 
- torch.int: DataType.Int32, - torch.bool: DataType.Bool, -} - - def _convert_element_type_nvfuser(fd: Any, a: Tensor, dtype: torch.dtype) -> Tensor: - nvfuser_dtype = _torch_dtype_to_nvfuser_dtype_map[dtype] + nvfuser_dtype = getnvFuserDtype(dtype) return fd.Ops.cast(nvfuser_dtype, a) # type: ignore[attr-defined] diff --git a/torch/_prims/executor.py b/torch/_prims/executor.py index 9fd44fe4a5a8..0e721525cfdb 100644 --- a/torch/_prims/executor.py +++ b/torch/_prims/executor.py @@ -3,27 +3,11 @@ from typing import Callable import torch from torch.fx import GraphModule -from torch._prims.utils import TensorMeta +from torch._prims.utils import TensorMeta, getnvFuserDtype from torch._prims.context import PrimContext -import torch._C._nvfuser as nvfuser # type: ignore[import] - -DataType = nvfuser.DataType # type: ignore[attr-defined] -Fusion = nvfuser.Fusion # type: ignore[attr-defined] -FusionDefinition = nvfuser.FusionDefinition # type: ignore[attr-defined] - -# TODO: refactor me into a common place -_torch_dtype_to_nvfuser_dtype_map = { - torch.cdouble: DataType.ComplexDouble, - torch.cfloat: DataType.ComplexFloat, - torch.double: DataType.Double, - torch.float: DataType.Float, - torch.half: DataType.Half, - torch.bfloat16: DataType.BFloat16, - torch.long: DataType.Int, - torch.int: DataType.Int32, - torch.bool: DataType.Bool, -} +if torch.cuda.is_available(): + from torch._C._nvfuser import Fusion, FusionDefinition # type: ignore[import] def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs): @@ -37,6 +21,11 @@ def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs): gm = GraphModule({}, ctx.graph) return gm.forward(*args, **kwargs) elif executor == "nvfuser": + if not torch.cuda.is_available(): + raise RuntimeError( + "Attempting to use nvFuser trace executor but CUDA is not available!" 
+ ) + # PROTOTYPE nvfuser executor # Only accepts tensor inputs and single tensor outputs # Does not handle kwargs @@ -53,9 +42,7 @@ def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs): nv_args = [fd] for arg in args: if isinstance(arg, torch.Tensor): - x = fd.define_tensor( - arg.ndim, _torch_dtype_to_nvfuser_dtype_map[arg.dtype] - ) + x = fd.define_tensor(arg.ndim, getnvFuserDtype(arg.dtype)) fd.add_input(x) nv_args.append(x) else: diff --git a/torch/_prims/utils.py b/torch/_prims/utils.py index 3c9e6d067e85..b9235b1eb166 100644 --- a/torch/_prims/utils.py +++ b/torch/_prims/utils.py @@ -8,6 +8,32 @@ import threading import torch from torch.fx import Node +# nvFuser imports are conditional on CUDA being available +if torch.cuda.is_available(): + from torch._C._nvfuser import DataType # type: ignore[import] + + _torch_dtype_to_nvfuser_dtype_map = { + torch.cdouble: DataType.ComplexDouble, + torch.cfloat: DataType.ComplexFloat, + torch.double: DataType.Double, + torch.float: DataType.Float, + torch.half: DataType.Half, + torch.bfloat16: DataType.BFloat16, + torch.long: DataType.Int, + torch.int: DataType.Int32, + torch.bool: DataType.Bool, + } +else: + _torch_dtype_to_nvfuser_dtype_map = {} + + +def getnvFuserDtype(dtype: torch.dtype): + """ + Translates from torch.dtype to nvFuser's DataType enum + """ + return _torch_dtype_to_nvfuser_dtype_map[dtype] + + ShapeType = Union[torch.Size, List[int], Tuple[int, ...]] StrideType = Union[List[int], Tuple[int, ...]] DimsType = Union[int, List[int], Tuple[int, ...]] diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index e8405848e1e6..7dee23f5b746 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -42,7 +42,7 @@ from torch.testing._internal.common_utils import \ freeze_rng_state) import torch.testing._internal.opinfo_helper as opinfo_helper -# import torch._refs as refs # noqa: F401 +import torch._refs as refs # noqa: F401 from distutils.version import LooseVersion @@ -1350,89 +1350,6 @@ class ReductionOpInfo(OpInfo): self.result_dtype = result_dtype self.generate_args_kwargs = generate_args_kwargs - -def sample_inputs_unary(op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs): - if not op_kwargs: - op_kwargs = {} - - low, high = op_info.domain - low = low if low is None else low + op_info._domain_eps - high = high if high is None else high - op_info._domain_eps - - if op_info.supports_sparse_csr: - # Tensors with dim=2 for sparse CSR testing - yield SampleInput(make_tensor((L, L), device=device, dtype=dtype, - low=low, high=high, - requires_grad=requires_grad), kwargs=op_kwargs) - else: - # Creates a 1D, empty, and scalar tensor - for shape in ((L,), (1, 0, 3), ()): - yield SampleInput(make_tensor(shape, device=device, dtype=dtype, - low=low, high=high, - requires_grad=requires_grad), kwargs=op_kwargs) - - -# Metadata class for unary "universal functions (ufuncs)" that accept a single -# tensor and have common properties like: -class UnaryUfuncInfo(OpInfo): - """Operator information for 'universal unary functions (unary ufuncs).' 
- These are functions of a single tensor with common properties like: - - they are elementwise functions - - the input shape is the output shape - - they typically have method and inplace variants - - they typically support the out kwarg - - they typically have NumPy or SciPy references - See NumPy's universal function documentation - (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details - about the concept of ufuncs. - """ - - def __init__(self, - name, # the string name of the function - *, - ref, # a reference function - dtypes=floating_types(), - dtypesIfCUDA=None, - dtypesIfROCM=None, - domain=(None, None), # the [low, high) domain of the function - handles_large_floats=True, # whether the op correctly handles large float values (like 1e20) - handles_extremals=True, # whether the op correctly handles extremal values (like inf) - handles_complex_extremals=True, # whether the op correct handles complex extremals (like inf -infj) - supports_complex_to_float=False, # op supports casting from complex input to real output safely eg. angle - sample_inputs_func=sample_inputs_unary, - sample_kwargs=lambda device, dtype, input: ({}, {}), - supports_sparse=False, - reference_numerics_filter=None, # Filter for singular input values for test_reference_numerics_normal - **kwargs): - self._original_unary_ufunc_args = locals().copy() - - super(UnaryUfuncInfo, self).__init__(name, - dtypes=dtypes, - dtypesIfCUDA=dtypesIfCUDA, - dtypesIfROCM=dtypesIfROCM, - sample_inputs_func=sample_inputs_func, - supports_sparse=supports_sparse, - **kwargs) - self.ref = ref - self.domain = domain - self.handles_large_floats = handles_large_floats - self.handles_extremals = handles_extremals - self.handles_complex_extremals = handles_complex_extremals - self.supports_complex_to_float = supports_complex_to_float - self.reference_numerics_filter = reference_numerics_filter - - # test_unary_ufuncs.py generates its own inputs to test the consistency - # of the operator on sliced tensors, non-contig tensors, etc. - # `sample_kwargs` is a utility function to provide kwargs - # along with those inputs if required (eg. clamp). - # It should return two dictionaries, first holding kwarg for - # torch operator and second one for reference NumPy operator. - self.sample_kwargs = sample_kwargs - - # Epsilon to ensure grad and gradgrad checks don't test values - # outside a function's domain. - self._domain_eps = 1e-5 - def sample_inputs_tensor_split(op_info, device, dtype, requires_grad, **kwargs): make_input = partial(make_tensor, device=device, dtype=dtype, low=None, high=None, requires_grad=requires_grad) @@ -2080,32 +1997,48 @@ def generate_elementwise_binary_tensors(op, *, device, dtype, requires_grad=Fals # medium 1D tensor (812,), # large 2D tensor - (1029, 917) + (1029, 917), ) - make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) for shape in shapes: lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) rhs = make_arg(shape, **op.rhs_make_tensor_kwargs) yield SampleInput(lhs, args=(rhs,)) + # Returns a generator of pairs of contiguous tensors on the requested device and with # the requested dtype. # # Unlike the previous function, the values in these tensors are specified manually. 
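The small-value generator defined just below adds an optional exclude_zero argument that, when left unset, is derived from the op's rhs_make_tensor_kwargs. A minimal sketch of that defaulting pattern; resolve_exclude_zero is a hypothetical name, not part of the patch:

def resolve_exclude_zero(op, exclude_zero=None):
    # Fall back to the op's rhs make_tensor kwargs when the caller did not
    # pass a value; treat a missing attribute or key as False.
    if exclude_zero is None and hasattr(op, "rhs_make_tensor_kwargs"):
        exclude_zero = op.rhs_make_tensor_kwargs.get("exclude_zero", False)
    return bool(exclude_zero)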
-def generate_elementwise_binary_small_value_tensors(op, *, device, dtype, requires_grad=False): +def generate_elementwise_binary_small_value_tensors( + op, *, device, dtype, requires_grad=False, exclude_zero=None +): + if exclude_zero is None: + if hasattr(op, "rhs_make_tensor_kwargs"): + exclude_zero = op.rhs_make_tensor_kwargs.get("exclude_zero", False) + # defines interesting values - _unsigned_int_vals = (0, 1, 55, 127, 128, 190, 210, 220, 254, 255) - _int_vals = (0, -1, 1, -55, 55, -127, 127, -128, 128) + _unsigned_int_vals = (0, 1, 55, 127, 128, 190, 210, 220, 254) + _int_vals = (0, -1, 1, -55, 55, -127, 127, -128) _float_vals = ( - 0., - -.001, .001, - -.25, .25, - -1., 1., - -math.pi / 2, math.pi / 2, - -math.pi + .00001, math.pi - .00001, - -math.pi, math.pi, - -math.pi - .00001, math.pi + .00001 + 0.0, + -0.001, + 0.001, + -0.25, + 0.25, + -1.0, + 1.0, + -math.pi / 2, + math.pi / 2, + -math.pi + 0.00001, + math.pi - 0.00001, + -math.pi, + math.pi, + -math.pi - 0.00001, + math.pi + 0.00001, ) l_vals = [] @@ -2128,7 +2061,7 @@ def generate_elementwise_binary_small_value_tensors(op, *, device, dtype, requir for l, r in prod: l_vals.append(l) - if r == 0 and op.rhs_make_tensor_kwargs.get('exclude_zero', False): + if r == 0 and exclude_zero: r_vals.append(1) else: r_vals.append(r) @@ -2138,7 +2071,10 @@ def generate_elementwise_binary_small_value_tensors(op, *, device, dtype, requir yield SampleInput(lhs, args=(rhs,)) -def generate_elementwise_binary_large_value_tensors(op, *, device, dtype, requires_grad=False): + +def generate_elementwise_binary_large_value_tensors( + op, *, device, dtype, requires_grad=False +): _large_int_vals = (-1113, 1113, -10701, 10701) _large_float16_vals = (-501, 501, -1001.2, 1001.2, -13437.7, 13437.7) _large_float_vals = _large_float16_vals + (-4988429.2, 4988429.2, -1e20, 1e20) @@ -2170,8 +2106,11 @@ def generate_elementwise_binary_large_value_tensors(op, *, device, dtype, requir yield SampleInput(lhs, args=(rhs,)) -def generate_elementwise_binary_extremal_value_tensors(op, *, device, dtype, requires_grad=False): - _float_extremals = (float('inf'), float('-inf'), float('nan')) + +def generate_elementwise_binary_extremal_value_tensors( + op, *, device, dtype, requires_grad=False +): + _float_extremals = (float("inf"), float("-inf"), float("nan")) l_vals = [] r_vals = [] @@ -2196,9 +2135,12 @@ def generate_elementwise_binary_extremal_value_tensors(op, *, device, dtype, req yield SampleInput(lhs, args=(rhs,)) + # Returns a generator of pairs of contiguous and noncontiguous tensors that # require broadcasting -def generate_elementwise_binary_broadcasting_tensors(op, *, device, dtype, requires_grad=False): +def generate_elementwise_binary_broadcasting_tensors( + op, *, device, dtype, requires_grad=False +): shapes = ( ((1,), ()), ((2,), ()), @@ -2213,29 +2155,30 @@ def generate_elementwise_binary_broadcasting_tensors(op, *, device, dtype, requi ((3, 1, 2), (1, 3, 2)), ) - make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) for shape, noncontiguous in product(shapes, [True, False]): shape_lhs, shape_rhs = shape - lhs = make_arg(shape_lhs, - noncontiguous=noncontiguous, - **op.lhs_make_tensor_kwargs) - rhs = make_arg(shape_rhs, - noncontiguous=noncontiguous, - **op.rhs_make_tensor_kwargs) + lhs = make_arg( + shape_lhs, noncontiguous=noncontiguous, **op.lhs_make_tensor_kwargs + ) + rhs = make_arg( + shape_rhs, 
noncontiguous=noncontiguous, **op.rhs_make_tensor_kwargs + ) yield SampleInput(lhs, args=(rhs,), broadcasts_input=True) -# Returns a generator of pairs of contiguous tensors and scalars -def generate_elementwise_binary_with_scalar_samples(op, *, device, dtype, requires_grad=False): - make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) - scalar_shapes = ( - (), - (3,), - (5, 3), - (0, 1, 3), - (1, 5) +# Returns a generator of pairs of contiguous tensors and scalars +def generate_elementwise_binary_with_scalar_samples( + op, *, device, dtype, requires_grad=False +): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad ) + + scalar_shapes = ((), (3,), (5, 3), (0, 1, 3), (1, 5)) if op.supports_rhs_python_scalar: for scalar_shape in scalar_shapes: lhs = make_arg(scalar_shape, **op.lhs_make_tensor_kwargs) @@ -2255,9 +2198,14 @@ def generate_elementwise_binary_with_scalar_samples(op, *, device, dtype, requir yield SampleInput(lhs_scalar, args=(rhs_scalar,)) + # Returns a generator of pairs of noncontiguous tensors -def generate_elementwise_binary_noncontiguous_tensors(op, *, device, dtype, requires_grad=False): - make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) +def generate_elementwise_binary_noncontiguous_tensors( + op, *, device, dtype, requires_grad=False +): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) # Generic noncontiguity lhs = make_arg((1026,), noncontiguous=True, **op.lhs_make_tensor_kwargs) @@ -2273,10 +2221,7 @@ def generate_elementwise_binary_noncontiguous_tensors(op, *, device, dtype, requ yield SampleInput(lhs.T, args=(rhs.T,)) # More noncontiguity - shapes = ( - (5, 7), - (1024,) - ) + shapes = ((5, 7), (1024,)) for shape in shapes: lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) @@ -2303,11 +2248,7 @@ def generate_elementwise_binary_noncontiguous_tensors(op, *, device, dtype, requ yield SampleInput(lhs_non_contig.contiguous(), args=(rhs_non_contig,)) # Expanded tensors - shapes = ( - (1, 3), - (1, 7), - (5, 7) - ) + shapes = ((1, 3), (1, 7), (5, 7)) for shape in shapes: lhs = make_arg(shape, **op.lhs_make_tensor_kwargs) @@ -2321,7 +2262,9 @@ def generate_elementwise_binary_noncontiguous_tensors(op, *, device, dtype, requ # Sample inputs for elementwise binary operators, like add def sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs): - make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) shapes = ( ((), ()), @@ -2332,37 +2275,54 @@ def sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs) ((S, M, S), (S, M, S)), ((M, 1, S), (M, S)), ((M, 1, S), (1, M, S)), - ((0, 1, 3), (0, 10, 3)) + ((0, 1, 3), (0, 10, 3)), ) - sample_kwargs = kwargs.get('sample_kwargs', {}) + sample_kwargs = kwargs.get("sample_kwargs", {}) for shape_lhs, shape_rhs in shapes: lhs = make_arg(shape_lhs, **op.lhs_make_tensor_kwargs) rhs = make_arg(shape_rhs, **op.rhs_make_tensor_kwargs) - broadcasts_input = (shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs)) + broadcasts_input = shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs) + + yield SampleInput( + lhs, args=(rhs,), kwargs=sample_kwargs, broadcasts_input=broadcasts_input + ) - yield SampleInput(lhs, args=(rhs,), kwargs=sample_kwargs, broadcasts_input=broadcasts_input) # The base reference input 
generation for elementwise binary operations def _reference_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs): yield from op.sample_inputs_func(op, device, dtype, requires_grad, **kwargs) - yield from generate_elementwise_binary_tensors(op, device=device, dtype=dtype, requires_grad=requires_grad) + yield from generate_elementwise_binary_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) if dtype is not torch.bool: - yield from generate_elementwise_binary_small_value_tensors(op, device=device, dtype=dtype, requires_grad=requires_grad) + yield from generate_elementwise_binary_small_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) if dtype not in (torch.bool, torch.uint8, torch.int8): - yield from generate_elementwise_binary_large_value_tensors(op, device=device, dtype=dtype, requires_grad=requires_grad) - yield from generate_elementwise_binary_broadcasting_tensors(op, device=device, dtype=dtype, requires_grad=requires_grad) - yield from generate_elementwise_binary_noncontiguous_tensors(op, device=device, dtype=dtype, requires_grad=requires_grad) - yield from generate_elementwise_binary_with_scalar_samples(op, device=device, dtype=dtype, requires_grad=requires_grad) + yield from generate_elementwise_binary_large_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) + yield from generate_elementwise_binary_broadcasting_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) + yield from generate_elementwise_binary_with_scalar_samples( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) if dtype.is_floating_point or dtype.is_complex: - yield from generate_elementwise_binary_extremal_value_tensors(op, device=device, dtype=dtype, requires_grad=requires_grad) + yield from generate_elementwise_binary_extremal_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) + # Note that these references inputs use scalars for the SampleInput.input value, # and many tests require SampleInput.input be a tensor or a list of tensors def reference_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs): - gen = partial(_reference_inputs_elementwise_binary, op, device, dtype, requires_grad, **kwargs) + gen = partial( + _reference_inputs_elementwise_binary, op, device, dtype, requires_grad, **kwargs + ) # yields "normal" samples yield from gen() @@ -2371,6 +2331,11 @@ def reference_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwar for sample in gen(): yield sample.noncontiguous() + yield from generate_elementwise_binary_noncontiguous_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ) + + # A functional that extends an elementwise binary operator's bespoke error inputs # with generic error inputs for the class of elementwise binary operations def make_error_inputs_elementwise_binary(error_inputs_func): @@ -2386,12 +2351,16 @@ def make_error_inputs_elementwise_binary(error_inputs_func): si = SampleInput(2, args=(torch.tensor((1, 2, 3), device=device),)) yield ErrorInput(si, error_type=Exception, error_regex="") - if not kwargs.get('skip_two_python_scalars', False) and not op.supports_two_python_scalars: + if ( + not kwargs.get("skip_two_python_scalars", False) + and not op.supports_two_python_scalars + ): si = SampleInput(2, args=(3,)) yield ErrorInput(si, error_type=Exception, error_regex="") return error_inputs_func_wrapper + # Metadata class for binary "universal functions (ufuncs)" that accept two # 
tensor and have common properties class BinaryUfuncInfo(OpInfo): @@ -2406,18 +2375,23 @@ class BinaryUfuncInfo(OpInfo): (https://numpy.org/doc/stable/reference/ufuncs.html) for more details about the concept of ufuncs. """ - def __init__(self, name, *, - sample_inputs_func=sample_inputs_elementwise_binary, - reference_inputs_func=reference_inputs_elementwise_binary, - error_inputs_func=None, - lhs_make_tensor_kwargs=None, - rhs_make_tensor_kwargs=None, - promotes_int_to_float=False, # Set to true if the op promotes integer inputs to float - always_returns_bool=False, # Set to true if the op always returns bool tensors - supports_rhs_python_scalar=True, # Whether the operator allows Tensor x scalar inputs - supports_one_python_scalar=False, # Whether the operator allows scalar x tensor and tensor x scalar inputs - supports_two_python_scalars=False, # Whether the operator allows scalar x scalar inputs - **kwargs): + + def __init__( + self, + name, + *, + sample_inputs_func=sample_inputs_elementwise_binary, + reference_inputs_func=reference_inputs_elementwise_binary, + error_inputs_func=None, + lhs_make_tensor_kwargs=None, + rhs_make_tensor_kwargs=None, + promotes_int_to_float=False, # Set to true if the op promotes integer inputs to float + always_returns_bool=False, # Set to true if the op always returns bool tensors + supports_rhs_python_scalar=True, # Whether the operator allows Tensor x scalar inputs + supports_one_python_scalar=False, # Whether the operator allows scalar x tensor and tensor x scalar inputs + supports_two_python_scalars=False, # Whether the operator allows scalar x scalar inputs + **kwargs, + ): self._original_binary_ufunc_args = locals().copy() @@ -2425,17 +2399,20 @@ class BinaryUfuncInfo(OpInfo): # in test_binary_ufuncs, but with additional test granularity. So the # generic test_ops.py test is skipped because it's redundant. common_skips = ( - DecorateInfo(unittest.skip('Skipping redundant test.'), - 'TestCommon', - 'test_reference_testing'), + DecorateInfo( + unittest.skip("Skipping redundant test."), + "TestCommon", + "test_reference_testing", + ), ) - kwargs['skips'] = kwargs.get('skips', tuple()) + common_skips + kwargs["skips"] = kwargs.get("skips", tuple()) + common_skips super(BinaryUfuncInfo, self).__init__( name, sample_inputs_func=sample_inputs_func, reference_inputs_func=reference_inputs_func, error_inputs_func=make_error_inputs_elementwise_binary(error_inputs_func), - **kwargs) + **kwargs, + ) # [lr]hs_make_tensor_kwargs are part of the OpInfo to be able to dynamically generate valid samples later on. if lhs_make_tensor_kwargs is None: @@ -2456,9 +2433,320 @@ class BinaryUfuncInfo(OpInfo): self.supports_one_python_scalar = True if self.supports_one_python_scalar: - assert supports_rhs_python_scalar, "Can't support lhs and rhs Python scalars but not rhs scalars!" + assert ( + supports_rhs_python_scalar + ), "Can't support lhs and rhs Python scalars but not rhs scalars!" +# The following functions and classes are for testing elementwise unary operators. 
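Several of the unary helpers that follow clamp generated inputs to the op's domain and mask out values that would hit a singularity (for example tan near pi/2). A small standalone sketch of that masking idea; the helper name and threshold here are illustrative and not taken from the patch:

import math
import torch

def mask_singular_values(t, condition, safe_value):
    # Overwrite, in place, every element for which `condition` holds.
    t.masked_fill_(condition(t), safe_value)
    return t

x = torch.linspace(-2.0, 2.0, steps=9)
# Keep tan's inputs away from +/- pi/2 by replacing near-singular values with 0.
mask_singular_values(x, lambda v: (v.abs() - math.pi / 2).abs() < 1e-2, 0.0)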
+def sample_inputs_elementwise_unary( + op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs +): + if not op_kwargs: + op_kwargs = {} + + low, high = op_info.domain + low = low if low is None else low + op_info._domain_eps + high = high if high is None else high - op_info._domain_eps + + if op_info.supports_sparse_csr: + # Tensors with dim=2 for sparse CSR testing + yield SampleInput( + make_tensor( + (L, L), + device=device, + dtype=dtype, + low=low, + high=high, + requires_grad=requires_grad, + ), + kwargs=op_kwargs, + ) + else: + # Creates a 1D, empty, and scalar tensor + for shape in ((L,), (1, 0, 3), ()): + yield SampleInput( + make_tensor( + shape, + device=device, + dtype=dtype, + low=low, + high=high, + requires_grad=requires_grad, + ), + kwargs=op_kwargs, + ) + + +# Replace values satisfying condition with a safe value. This is used to block +# out values the could cause singularity like tan(pi/2) +def _replace_values_in_tensor(tensor, condition, safe_value): + mask = condition(tensor) + tensor.masked_fill_(mask, safe_value) + + +# Helper to create a unary elementwise tensor with valid inputs +def _make_unary_elementwise_tensor(shape, *, op, device, dtype, requires_grad=False): + low, high = op.domain + low = low if low is None else low + op._domain_eps + high = high if high is None else high - op._domain_eps + + make_arg = partial( + make_tensor, + low=low, + high=high, + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + + a = make_arg(shape) + + if op.reference_numerics_filter is not None and dtype is not torch.bool: + condition, safe_value = op.reference_numerics_filter + _replace_values_in_tensor(a, condition, safe_value) + + return a + + +# Restricts the values in the tensor to the domain of the +# given elementwise unary operator +def _filter_unary_elementwise_tensor(a, *, op): + # short-circuits for boolean tensors + if a.dtype is torch.bool: + return a + + low, high = op.domain + low = low if low is None else low + op._domain_eps + high = high if high is None else high - op._domain_eps + + if a.dtype is torch.uint8 and low is not None: + low = max(low, 0) + + if not a.dtype.is_floating_point and not a.dtype.is_complex: + low = math.ceil(low) if low is not None else None + high = math.floor(high) if high is not None else None + + if op.reference_numerics_filter is not None: + condition, safe_value = op.reference_numerics_filter + _replace_values_in_tensor(a, condition, safe_value) + + if low is not None or high is not None: + if a.dtype.is_complex: + a.real.clamp_(low, high) + a.imag.clamp_(low, high) + else: + a.clamp_(min=low, max=high) + + return a + + +def generate_elementwise_unary_tensors(op, *, device, dtype, requires_grad, **kwargs): + + # Special-cases bool + if dtype is torch.bool: + tensors = ( + torch.empty(0, device=device, dtype=torch.bool), + torch.tensor(True, device=device), + torch.tensor(False, device=device), + torch.tensor((True, False), device=device), + make_tensor((812,), device=device, dtype=dtype), + make_tensor((1029, 917), device=device, dtype=dtype), + ) + for a in tensors: + yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0]) + + shapes = ( + (1029, 917), + (812,), + # Empty sizes + (0,), + (0, 3, 3), + (1, 0, 5), + (6, 0, 0, 0), + (3, 0, 1, 0), + ) + + make_arg = partial( + _make_unary_elementwise_tensor, + op=op, + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + for shape in shapes: + a = make_arg(shape) + yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0]) + + +def 
generate_elementwise_unary_small_value_tensors( + op, *, device, dtype, requires_grad=False +): + for sample in generate_elementwise_binary_small_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ): + a = _filter_unary_elementwise_tensor(sample.input, op=op) + yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0]) + + +def generate_elementwise_unary_large_value_tensors( + op, *, device, dtype, requires_grad=False +): + for sample in generate_elementwise_binary_large_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ): + a = _filter_unary_elementwise_tensor(sample.input, op=op) + yield SampleInput(sample.input, kwargs=op.sample_kwargs(device, dtype, a)[0]) + + +def generate_elementwise_unary_extremal_value_tensors( + op, *, device, dtype, requires_grad=False +): + for sample in generate_elementwise_binary_extremal_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad + ): + yield SampleInput( + sample.input, kwargs=op.sample_kwargs(device, dtype, sample.input)[0] + ) + + +def generate_elementwise_unary_noncontiguous_tensors( + op, *, device, dtype, requires_grad=False +): + low, high = op.domain + low = low if low is None else low + op._domain_eps + high = high if high is None else high - op._domain_eps + + make_arg = partial( + _make_unary_elementwise_tensor, + op=op, + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + + # Generic noncontiguity + t = make_arg((1026,), noncontiguous=True) + yield SampleInput(t, kwargs=op.sample_kwargs(device, dtype, t)[0]) + + # Transposed + t = make_arg((1024, 1024)).T + yield SampleInput(t, kwargs=op.sample_kwargs(device, dtype, t)[0]) + + # Expanded tensors + shapes = ((1, 3), (1, 7), (5, 7)) + + for shape in shapes: + t = make_arg(shape) + t_non_contig = t.expand(3, -1, -1) + yield SampleInput( + t_non_contig, kwargs=op.sample_kwargs(device, dtype, t_non_contig)[0] + ) + + +# Reuses the elementwise binary generators for consistency +# TODO: in the future generalize the reference generators to handle n-ary elementwise operations +def _reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs): + yield from op.sample_inputs_func(op, device, dtype, requires_grad, **kwargs) + + yield from generate_elementwise_unary_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + + if dtype is not torch.bool: + yield from generate_elementwise_unary_small_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + if dtype not in (torch.bool, torch.uint8, torch.int8) and ( + op.handles_large_floats + or (not dtype.is_floating_point and not dtype.is_complex) + ): + yield from generate_elementwise_unary_large_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + yield from generate_elementwise_unary_extremal_value_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + + +def reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs): + gen = partial( + _reference_inputs_elementwise_unary, op, device, dtype, requires_grad, **kwargs + ) + + # yields "normal" samples + yield from gen() + + # yields noncontiguous samples + for sample in gen(): + yield sample.noncontiguous() + + yield from generate_elementwise_unary_noncontiguous_tensors( + op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs + ) + + +# Metadata class for unary "universal functions (ufuncs)" that
accept a single +# tensor and have common properties like: +class UnaryUfuncInfo(OpInfo): + """Operator information for 'universal unary functions (unary ufuncs).' + These are functions of a single tensor with common properties like: + - they are elementwise functions + - the input shape is the output shape + - they typically have method and inplace variants + - they typically support the out kwarg + - they typically have NumPy or SciPy references + See NumPy's universal function documentation + (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details + about the concept of ufuncs. + """ + + def __init__( + self, + name, # the string name of the function + *, + ref, # a reference function + dtypes=floating_types(), + dtypesIfCUDA=None, + dtypesIfROCM=None, + domain=(None, None), # the [low, high) domain of the function + handles_large_floats=True, # whether the op correctly handles large float values (like 1e20) + supports_complex_to_float=False, # op supports casting from complex input to real output safely eg. angle + sample_inputs_func=sample_inputs_elementwise_unary, + sample_kwargs=lambda device, dtype, input: ({}, {}), + supports_sparse=False, + reference_numerics_filter=None, # Filters values in the range of the domain specified above but that should not be tested + **kwargs, + ): + self._original_unary_ufunc_args = locals().copy() + + super(UnaryUfuncInfo, self).__init__( + name, + dtypes=dtypes, + dtypesIfCUDA=dtypesIfCUDA, + dtypesIfROCM=dtypesIfROCM, + sample_inputs_func=sample_inputs_func, + supports_sparse=supports_sparse, + **kwargs, + ) + self.ref = ref + self.domain = domain + self.handles_large_floats = handles_large_floats + self.supports_complex_to_float = supports_complex_to_float + self.reference_numerics_filter = reference_numerics_filter + + # test_unary_ufuncs.py generates its own inputs to test the consistency + # of the operator on sliced tensors, non-contig tensors, etc. + # `sample_kwargs` is a utility function to provide kwargs + # along with those inputs if required (eg. clamp). + # It should return two dictionaries, first holding kwarg for + # torch operator and second one for reference NumPy operator. + self.sample_kwargs = sample_kwargs + + # Epsilon to ensure grad and gradgrad checks don't test values + # outside a function's domain. + self._domain_eps = 1e-5 + def sample_inputs_add_sub(op, device, dtype, requires_grad, **kwargs): yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs) @@ -8937,10 +9225,10 @@ op_db: List[OpInfo] = [ 'test_inplace_forward_mode_AD', dtypes=(torch.cdouble,)), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.cfloat]), # Reference: https://github.com/pytorch/pytorch/issues/49224 - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', dtypes=[torch.int8], active_if=TEST_WITH_ASAN), # TODO: Fix test_out_arg_all_dtypes as torch.empty_like(expected_output) where expected_output=op(input) # We can break the logic of the loop over all possible types but it is OK. 
@@ -8963,7 +9251,6 @@ op_db: List[OpInfo] = [ aliases=('arccos', ), ref=np.arccos, domain=(-1, 1), - handles_complex_extremals=False, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), assert_autodiffed=True, @@ -8973,15 +9260,21 @@ op_db: List[OpInfo] = [ torch.bfloat16: 1e-1, torch.complex64: 1e-2}),), skips=( - # Failing with wrong imaginary sign on at least some Windows jobs DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), + # Failing with wrong imaginary sign on at least some Windows jobs + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), # Failing with wrong imaginary sign on at least some Windows jobs - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_fn_grad', dtypes=[torch.cdouble], active_if=IS_WINDOWS), @@ -9008,18 +9301,22 @@ op_db: List[OpInfo] = [ supports_forward_ad=True, supports_fwgrad_bwgrad=True, skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), # Failing with wrong imaginary sign on at least some Windows jobs - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), # Reference: https://github.com/pytorch/pytorch/issues/50692 @@ -9267,12 +9564,12 @@ op_db: List[OpInfo] 
= [ skips=( DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), @@ -9293,14 +9590,16 @@ op_db: List[OpInfo] = [ skips=( DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), DecorateInfo(unittest.skip("Skipped! 
sparse backward not supported"), @@ -9318,20 +9617,20 @@ op_db: List[OpInfo] = [ supports_sparse_csr=True, decorators=(precisionOverride({torch.bfloat16: 1e-2}),), skips=( - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', active_if=TEST_WITH_ROCM, device_type='cuda'), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', active_if=TEST_WITH_ROCM, device_type='cuda'), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), @@ -9362,16 +9661,16 @@ op_db: List[OpInfo] = [ supports_sparse=True, supports_sparse_csr=True, skips=( - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cuda', dtypes=[torch.cfloat], active_if=IS_WINDOWS), DecorateInfo(unittest.skip("Skipped! 
sparse backward not supported"), @@ -9619,12 +9918,7 @@ op_db: List[OpInfo] = [ supports_sparse=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, - supports_out=False, - skips=( - # numpy() raises TypeError: Got unsupported ScalarType ComplexHalf - DecorateInfo(unittest.expectedFailure, "TestUnaryUfuncs", "test_reference_numerics_normal", - dtypes=(torch.complex32,)), - )), + supports_out=False), UnaryUfuncInfo('conj_physical', ref=np.conj, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, @@ -9640,9 +9934,6 @@ op_db: List[OpInfo] = [ DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, )), DecorateInfo(unittest.skip("Skipped! conj_physical_ not implemented for sparse"), 'TestSparseUnaryUfuncs', 'test_inplace'), - # numpy() raises TypeError: Got unsupported ScalarType ComplexHalf - DecorateInfo(unittest.expectedFailure, "TestUnaryUfuncs", "test_reference_numerics_normal", - dtypes=(torch.complex32,)), # RuntimeError: "nonzero_count_cpu" not implemented for 'ComplexHalf' DecorateInfo(unittest.expectedFailure, "TestSparseCSR", "test_sparse_csr_consistency", dtypes=(torch.complex32,)), @@ -9726,6 +10017,11 @@ op_db: List[OpInfo] = [ supports_fwgrad_bwgrad=True, decorators=(precisionOverride({torch.bfloat16: 1e-2}),), skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS), + # This fails on CUDA but passes on ROCm + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.cdouble,), device_type='cuda'), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', @@ -9741,18 +10037,18 @@ op_db: List[OpInfo] = [ supports_fwgrad_bwgrad=True, skips=( # Reference: https://github.com/pytorch/pytorch/issues/48641 - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.int8]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', dtypes=[torch.cdouble]), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), )), @@ -9852,7 +10148,7 @@ op_db: List[OpInfo] = [ supports_fwgrad_bwgrad=True, skips=( # Reference: https://github.com/pytorch/pytorch/pull/51283#issuecomment-770614273 - 
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', dtypes=[torch.bfloat16]), )), OpInfo('diff', @@ -9921,14 +10217,14 @@ op_db: List[OpInfo] = [ # Reference: https://github.com/pytorch/pytorch/pull/50093#pullrequestreview-561791547 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', dtypes=[torch.bfloat16]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', dtypes=[torch.bfloat16, torch.cdouble]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', dtypes=[torch.bfloat16]), # Reference: https://github.com/pytorch/pytorch/issues/48010 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), ), assert_autodiffed=True, @@ -10056,13 +10352,12 @@ op_db: List[OpInfo] = [ supports_forward_ad=True, supports_fwgrad_bwgrad=True, skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64)), # 76047 DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.float32, torch.float64)), - ), - # Reference for disabling extremals - # https://github.com/pytorch/pytorch/issues/51948 - handles_extremals=False), + )), SpectralFuncInfo('fft.fft', aten_name='fft_fft', ref=np.fft.fft, @@ -10369,7 +10664,7 @@ op_db: List[OpInfo] = [ supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_i0_i1, skips=( - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', dtypes=(torch.int8,)), )), UnaryUfuncInfo('special.i0e', @@ -10397,7 +10692,7 @@ op_db: List[OpInfo] = [ skips=( DecorateInfo(unittest.skip("Incorrect result!"), 'TestUnaryUfuncs', - 'test_reference_numerics_hard', + 'test_reference_numerics_large', dtypes=(torch.int8,)), ), supports_fwgrad_bwgrad=True, @@ -10454,9 +10749,9 @@ op_db: List[OpInfo] = [ # skips test_reference_numerics due to error in Windows CI. 
# The np.frexp returns exponent as np.intc dtype on Windows platform, # and np.intc does not have the correspond torch dtype - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', active_if=IS_WINDOWS), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', active_if=IS_WINDOWS), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', active_if=IS_WINDOWS), @@ -12652,6 +12947,10 @@ op_db: List[OpInfo] = [ }), 'TestUnaryUfuncs', device_type='cuda', ), ], + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + dtypes=(torch.cfloat,), device_type='cpu'), + ) ), # TODO: combine this with the nn.functional.silu OpInfo when # complex autodiff for silu is supported or when @@ -12683,14 +12982,16 @@ op_db: List[OpInfo] = [ 'TestUnaryUfuncs', device_type='cuda', ), ], skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + dtypes=(torch.cfloat,), device_type='cpu'), # FIXME: intentionally misreports dtypes DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'), # FIXME: numpy reference diverges: Comparing (nan+nanj) and (-0+0j) DecorateInfo(unittest.skip("Skipped!"), - 'TestUnaryUfuncs', 'test_reference_numerics_hard', + 'TestUnaryUfuncs', 'test_reference_numerics_large', dtypes=(torch.complex64, torch.cdouble)), DecorateInfo(unittest.skip("Skipped!"), - 'TestUnaryUfuncs', 'test_reference_numerics_normal', + 'TestUnaryUfuncs', 'test_reference_numerics_small', dtypes=(torch.complex64,)), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', @@ -12733,10 +13034,10 @@ op_db: List[OpInfo] = [ decorators=[ DecorateInfo( precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}), - 'TestUnaryUfuncs', 'test_reference_numerics_normal'), + 'TestUnaryUfuncs', 'test_reference_numerics_small'), DecorateInfo( precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}), - 'TestUnaryUfuncs', 'test_reference_numerics_hard'), + 'TestUnaryUfuncs', 'test_reference_numerics_large'), DecorateInfo( precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), @@ -12778,13 +13079,13 @@ op_db: List[OpInfo] = [ DecorateInfo( toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1.3e-04)}), 'TestUnaryUfuncs',), ], skips=( - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', dtypes=(torch.int, torch.int8)), DecorateInfo(unittest.expectedFailure, 'TestGradients', "test_fn_fwgrad_bwgrad", dtypes=(torch.complex128,)), # pytorch computes (0+nanj), numpy computes (-5e-18-1j) for input (-501.-1.0000e+20j) DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', - "test_reference_numerics_hard", dtypes=(torch.complex64,)),), + "test_reference_numerics_large", dtypes=(torch.complex64,)),), ), UnaryUfuncInfo( 'nn.functional.tanhshrink', @@ -12798,6 +13099,8 @@ op_db: List[OpInfo] = [ supports_gradgrad=True, supports_out=False, decorators=[ + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + device_type='cpu', dtypes=[torch.cfloat, 
torch.cdouble]), DecorateInfo( toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1.6e-02)}), 'TestUnaryUfuncs',), DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), @@ -12806,10 +13109,10 @@ op_db: List[OpInfo] = [ skips=( # in each case, pytorch will produce a nan while numpy will not DecorateInfo(unittest.expectedFailure, - 'TestUnaryUfuncs', "test_reference_numerics_normal", + 'TestUnaryUfuncs', "test_reference_numerics_small", dtypes=(torch.complex64, torch.complex128), active_if=(IS_MACOS)), DecorateInfo(unittest.skip("Fails on some jobs works on others!"), - 'TestUnaryUfuncs', "test_reference_numerics_hard", + 'TestUnaryUfuncs', "test_reference_numerics_large", dtypes=(torch.complex64, torch.complex128), active_if=(IS_MACOS)), DecorateInfo(unittest.skip("Fails on some jobs works on others!"), 'TestUnaryUfuncs', "test_reference_numerics_extremal", @@ -13130,26 +13433,29 @@ op_db: List[OpInfo] = [ MvlGammaInfo(variant_test_name='mvlgamma_p_1', domain=(1, None), skips=skips_mvlgamma() + \ - (DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + (DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', dtypes=(torch.float16, torch.int8)), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', dtypes=(torch.int8,)),), sample_kwargs=lambda device, dtype, input: ({'p': 1}, {'d': 1})), MvlGammaInfo(variant_test_name='mvlgamma_p_3', domain=(2, None), skips=skips_mvlgamma(skip_redundant=True) + ( - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', dtypes=(torch.float16, torch.int8)), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', dtypes=(torch.int8,)), ), sample_kwargs=lambda device, dtype, input: ({'p': 3}, {'d': 3})), MvlGammaInfo(variant_test_name='mvlgamma_p_5', domain=(3, None), skips=skips_mvlgamma(skip_redundant=True) + ( - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', dtypes=(torch.float16, torch.int8)), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', dtypes=(torch.int8,)), ), sample_kwargs=lambda device, dtype, input: ({'p': 5}, {'d': 5})), @@ -13274,9 +13580,9 @@ op_db: List[OpInfo] = [ dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), skips=( # Reference: https://github.com/pytorch/pytorch/pull/51283#issuecomment-770614273 - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', dtypes=[torch.bfloat16]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 
'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', dtypes=[torch.bfloat16]), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', dtypes=[torch.bfloat16]), @@ -13327,7 +13633,7 @@ op_db: List[OpInfo] = [ dtypes=floating_types_and(torch.bfloat16), dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), sample_kwargs=lambda device, dtype, input: ({'decimals': 0}, {'decimals': 0}), - sample_inputs_func=partial(sample_inputs_unary, op_kwargs={'decimals': 0}), + sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': 0}), supports_forward_ad=True, supports_fwgrad_bwgrad=True, assert_autodiffed=False, @@ -13339,7 +13645,7 @@ op_db: List[OpInfo] = [ dtypes=floating_types_and(torch.bfloat16), dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), sample_kwargs=lambda device, dtype, input: ({'decimals': 3}, {'decimals': 3}), - sample_inputs_func=partial(sample_inputs_unary, op_kwargs={'decimals': 3}), + sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': 3}), skips=( # test_ops already tested for this overload with `decimals_0` opinfo entry DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), @@ -13358,7 +13664,7 @@ op_db: List[OpInfo] = [ dtypes=floating_types_and(torch.bfloat16), dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), sample_kwargs=lambda device, dtype, input: ({'decimals': -3}, {'decimals': -3}), - sample_inputs_func=partial(sample_inputs_unary, op_kwargs={'decimals': -3}), + sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': -3}), skips=( # test_ops already tested for this overload with `decimals_0` opinfo entry DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), @@ -13376,12 +13682,18 @@ op_db: List[OpInfo] = [ dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), assert_autodiffed=True, handles_large_floats=False, - handles_complex_extremals=False, supports_sparse=True, supports_sparse_csr=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, skips=( + # Fails on CUDA but passes on ROCm + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.cdouble,), device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS), DecorateInfo(unittest.skip("Skipped! 
sparse backward not supported"), 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), ), @@ -13392,14 +13704,13 @@ op_db: List[OpInfo] = [ dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), handles_large_floats=False, - handles_complex_extremals=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, decorators=(precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2}),), skips=( # Reference: https://github.com/pytorch/pytorch/issues/49133 - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', dtypes=[torch.cfloat]), )), UnaryUfuncInfo('sinh', @@ -13416,13 +13727,13 @@ op_db: List[OpInfo] = [ DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', dtypes=(torch.cdouble,)), # Reference: https://github.com/pytorch/pytorch/issues/48641 - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.int8]), DecorateInfo(unittest.skip("Skipped! 
sparse backward not supported"), 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), @@ -13457,7 +13768,7 @@ op_db: List[OpInfo] = [ DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.complex64, torch.complex128]), # Reference: https://github.com/pytorch/pytorch/issues/48486 - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.complex64]), # The complex formula might be wrong DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD', @@ -13701,17 +14012,17 @@ op_db: List[OpInfo] = [ skips=( DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.bfloat16]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.bfloat16]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', device_type='cpu', dtypes=[torch.bfloat16]), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cuda', dtypes=[torch.float64], active_if=TEST_WITH_ROCM), DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), @@ -13737,7 +14048,7 @@ op_db: List[OpInfo] = [ DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), # alias, nn.functional.tanh, will produce (because of warning string saved): @@ -13833,9 +14144,9 @@ op_db: List[OpInfo] = [ # Reference: https://github.com/pytorch/pytorch/pull/48926#issuecomment-739734774 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.bfloat16]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.bfloat16]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', device_type='cpu', dtypes=[torch.bfloat16]), DecorateInfo(unittest.skip("Skipped! 
sparse backward not supported"), 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), @@ -13871,9 +14182,9 @@ op_db: List[OpInfo] = [ # Reference: https://github.com/pytorch/pytorch/pull/49102#issuecomment-744604601 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', dtypes=[torch.bfloat16]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', dtypes=[torch.bfloat16]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', dtypes=[torch.bfloat16]), )), UnaryUfuncInfo('rsqrt', @@ -13885,7 +14196,10 @@ op_db: List[OpInfo] = [ assert_autodiffed=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, - handles_complex_extremals=False), + skips=( + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + dtypes=(torch.cfloat, torch.cdouble)), + )), UnaryUfuncInfo('sqrt', ref=np.sqrt, supports_sparse=True, @@ -13899,16 +14213,15 @@ op_db: List[OpInfo] = [ decorators=(precisionOverride({torch.bfloat16: 7e-2}),), skips=( # Reference: https://github.com/pytorch/pytorch/issues/47358 - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436 - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', dtypes=[torch.bfloat16]), DecorateInfo(unittest.skip("Skipped! 
sparse backward not supported"), 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), - ), - handles_complex_extremals=False), + )), UnaryUfuncInfo('square', ref=np.square, dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), @@ -13917,7 +14230,7 @@ op_db: List[OpInfo] = [ supports_fwgrad_bwgrad=True, skips=( # Reference: https://github.com/pytorch/pytorch/issues/52549 - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', dtypes=[torch.cfloat, torch.cdouble]), # >>> t = torch.tensor(complex(-0.01, float("inf"))) # >>> np.square(t.numpy()) @@ -13929,7 +14242,7 @@ op_db: List[OpInfo] = [ DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]), # Reference: https://github.com/pytorch/pytorch/pull/52551#issuecomment-782596181 - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', dtypes=[torch.bfloat16]), ),), OpInfo('lerp', @@ -13989,7 +14302,7 @@ op_db: List[OpInfo] = [ supports_autograd=False, skips=( # Reference: https://github.com/pytorch/pytorch/issues/66402 - DecorateInfo(unittest.expectedFailure, "TestUnaryUfuncs", "test_reference_numerics_hard", + DecorateInfo(unittest.expectedFailure, "TestUnaryUfuncs", "test_reference_numerics_large", device_type='cpu', dtypes=(torch.complex64,), active_if=not (IS_MACOS or IS_WINDOWS)), )), UnaryUfuncInfo('isinf', @@ -14320,7 +14633,7 @@ op_db: List[OpInfo] = [ DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), # Mismatch: https://github.com/pytorch/pytorch/issues/55357 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard'), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large'), ), sample_kwargs=lambda device, dtype, input: ({'n': 1}, {'n': 1}), # polygamma functions have multiple singularities at x <= 0 @@ -14342,9 +14655,9 @@ op_db: List[OpInfo] = [ DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), # Mismatch: https://github.com/pytorch/pytorch/issues/55357 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', active_if=TEST_WITH_ROCM), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', active_if=TEST_WITH_ROCM),), sample_kwargs=lambda device, dtype, input: ({'n': 2}, {'n': 2}), # polygamma functions have multiple singularities at x <= 0 @@ -14387,9 +14700,9 @@ op_db: List[OpInfo] = [ DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), # Mismatch: https://github.com/pytorch/pytorch/issues/55357 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', active_if=TEST_WITH_ROCM), - 
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', active_if=TEST_WITH_ROCM),), sample_kwargs=lambda device, dtype, input: ({'n': 4}, {'n': 4}), # polygamma functions have multiple singularities at x <= 0 @@ -15562,12 +15875,14 @@ op_db: List[OpInfo] = [ torch.bfloat16: 1e-2}),), skips=( # Reference: https://github.com/pytorch/pytorch/issues/56012 + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_normal', + dtypes=(torch.chalf,)), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', dtypes=[torch.complex64, torch.cdouble]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', - dtypes=[torch.complex64, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + dtypes=[torch.chalf, torch.complex64, torch.cdouble]), # RuntimeError: "div_true_cuda" not implemented for 'ComplexHalf' - DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_small', dtypes=(torch.complex32,)), # alias, nn.functional.sigmoid, will produce (because of warning string saved): # "RuntimeError: Expected to not find "sigmoid" but found it" @@ -15600,7 +15915,7 @@ op_db: List[OpInfo] = [ dtypes=all_types_and(torch.bool, torch.bfloat16), dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), skips=( - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', dtypes=[torch.bfloat16, torch.float16]), ), supports_inplace_autograd=False, @@ -15663,9 +15978,9 @@ op_db: List[OpInfo] = [ # Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', active_if=TEST_SCIPY and LooseVersion(scipy.__version__) < "1.4.0"), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', active_if=TEST_SCIPY and LooseVersion(scipy.__version__) < "1.4.0"), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', active_if=TEST_SCIPY and LooseVersion(scipy.__version__) < "1.4.0"), )), OpInfo("nn.functional.smooth_l1_loss", @@ -15713,14 +16028,14 @@ op_db: List[OpInfo] = [ # Reference: https://github.com/pytorch/pytorch/pull/50140#discussion_r552615345 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', dtypes=[torch.bfloat16]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.bfloat16]), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', device_type='cpu', dtypes=[torch.bfloat16]), # Reference: https://github.com/pytorch/pytorch/pull/50140#issuecomment-756150214 
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS), ), # lgamma have multiple singularities at x <= 0 @@ -17137,7 +17452,7 @@ class OpInfoPythonRefInfo(OpInfo): assert isinstance(self.torch_opinfo, OpInfo) inherited = self.torch_opinfo._original_opinfo_args - ukwargs = _inherit_constructor_args(name, op, inherited, {}) + ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) super(OpInfoPythonRefInfo, self).__init__(**ukwargs) class ElementwiseUnaryPythonRefInfo(UnaryUfuncInfo): @@ -17157,7 +17472,7 @@ class ElementwiseUnaryPythonRefInfo(UnaryUfuncInfo): assert isinstance(self.torch_opinfo, UnaryUfuncInfo) inherited = self.torch_opinfo._original_unary_ufunc_args - ukwargs = _inherit_constructor_args(name, op, inherited, {}) + ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) super(ElementwiseUnaryPythonRefInfo, self).__init__(**ukwargs) @@ -17178,7 +17493,7 @@ class ElementwiseBinaryPythonRefInfo(BinaryUfuncInfo): assert isinstance(self.torch_opinfo, BinaryUfuncInfo) inherited = self.torch_opinfo._original_binary_ufunc_args - ukwargs = _inherit_constructor_args(name, op, inherited, {}) + ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) super(ElementwiseBinaryPythonRefInfo, self).__init__(**ukwargs) @@ -17188,17 +17503,29 @@ python_ref_db = [ # # Elementwise unary OpInfos # - # ElementwiseUnaryPythonRefInfo( - # '_refs.floor', - # torch_opinfo_name='floor', - # ), - # # - # # Elementwise binary OpInfos - # # - # ElementwiseBinaryPythonRefInfo( - # '_refs.add', - # torch_opinfo_name='add', - # ), + ElementwiseUnaryPythonRefInfo( + "_refs.floor", + torch_opinfo_name="floor", + ), + # + # Elementwise binary OpInfos + # + ElementwiseBinaryPythonRefInfo( + "_refs.add", + torch_opinfo_name="add", + decorators=( + DecorateInfo( + toleranceOverride( + { + torch.bfloat16: tol(atol=1, rtol=0), + torch.float16: tol(atol=1e-2, rtol=0), + } + ), + "TestCommon", + "test_python_reference_consistency", + ), + ), + ), ] # Common operator groupings diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 9cf1f39ff9ff..13940fcf3bb0 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -871,6 +871,10 @@ if IS_WINDOWS: # Dict of torch dtype -> NumPy dtype torch_to_numpy_dtype_dict = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()} +torch_to_numpy_dtype_dict.update({ + torch.bfloat16: np.float32, + torch.complex32: np.complex64 +}) def skipIfRocm(fn): @wraps(fn)