Revert "[testing] Update dispatch macros"

This reverts commit eed19a0f38a81015ca50dd25e997b1c6e223d46b.

Reverted https://github.com/pytorch/pytorch/pull/74289 on behalf of https://github.com/malfet
PyTorch MergeBot
2022-03-30 19:52:37 +00:00
parent c6102048b8
commit 2e4152b118
21 changed files with 417 additions and 381 deletions
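
For context on what the revert touches: both decorator styles enumerate the same dtype sets, so the hunks below swap notation rather than test coverage. A minimal sketch of the equivalence (my illustration, not part of the commit; it assumes the default arguments of the helpers in torch.testing._internal.common_dtype at the time):

import torch
from torch.testing._internal.common_dtype import (
    all_types_and_complex_and,  # new style: compose an explicit dtype tuple
    get_all_dtypes,             # legacy style: return the full dtype list
)

# Both spell out the same set: all integer, floating-point, and complex
# dtypes plus half, bfloat16, and bool (complex32 excluded by default in both).
new_style = set(all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
old_style = set(get_all_dtypes())
assert new_style == old_style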

test/test_view_ops.py

@@ -16,7 +16,7 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes, skipMeta)
from torch.testing._internal.common_dtype import (
-all_types_and_complex_and, complex_types, all_types_and, floating_and_complex_types_and,
+get_all_dtypes, get_all_int_dtypes, get_all_fp_dtypes, get_all_complex_dtypes
)
# TODO: replace this with make_tensor() in common_utils.py
@@ -121,14 +121,14 @@ class TestViewOps(TestCase):
else:
return x.transpose(dim0, dim1)
-@dtypes(*all_types_and(torch.half, torch.bfloat16))
+@dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
def test_conj_self(self, device, dtype):
t = torch.ones(5, 5, device=device)
s = t.conj()
self.assertTrue(s is t)
@onlyNativeDeviceTypes
-@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
+@dtypes(*get_all_dtypes(include_bfloat16=False))
def test_view_dtype_new(self, device, dtype):
dtypes = torch_to_numpy_dtype_dict.copy()
del dtypes[torch.bool]
@@ -210,18 +210,18 @@ class TestViewOps(TestCase):
# because view(dtype) does not support backward yet
# TODO: Remove this when autograd support is added
if dtype.is_floating_point or dtype.is_complex:
-for view_dtype in floating_and_complex_types_and(torch.half, torch.bfloat16):
+for view_dtype in [*get_all_fp_dtypes(), *get_all_complex_dtypes()]:
t = make_tensor((5, 5, 64), dtype=dtype, device=device, low=-5, high=5, requires_grad=True)
self.assertFalse(t.view(view_dtype).requires_grad)
# Test the extra error checks that happen when the view dtype
# has a greater element size than the original dtype
@onlyNativeDeviceTypes
-@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+@dtypes(*get_all_dtypes())
def test_view_dtype_upsize_errors(self, device, dtype):
dtype_size = torch._utils._element_size(dtype)
-for view_dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
+for view_dtype in get_all_dtypes():
view_dtype_size = torch._utils._element_size(view_dtype)
if view_dtype_size <= dtype_size:
continue
@@ -302,7 +302,7 @@ class TestViewOps(TestCase):
self.assertEqual(res.shape, torch.Size([0]))
@onlyNativeDeviceTypes
-@dtypes(*complex_types(), torch.complex32)
+@dtypes(*get_all_complex_dtypes(include_complex32=True))
def test_view_as_real(self, device, dtype):
def fn(contiguous_input=True):
t = torch.randn(3, 4, dtype=dtype, device=device)
@@ -340,7 +340,7 @@ class TestViewOps(TestCase):
self.assertEqual(res.shape, torch.Size([2]))
@onlyNativeDeviceTypes
-@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+@dtypes(*get_all_dtypes())
def test_view_tensor_split(self, device, dtype):
a = make_tensor((40, 30), dtype=dtype, device=device, low=-9, high=9)
a_split_dim0 = a.tensor_split(7, 0)
@@ -351,7 +351,7 @@ class TestViewOps(TestCase):
self.assertTrue(self.is_view_of(a, a_split_dim1_tensor))
@onlyNativeDeviceTypes
-@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+@dtypes(*get_all_dtypes())
def test_view_tensor_hsplit(self, device, dtype):
t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
t_hsplit = torch.hsplit(t, 2)
@@ -361,7 +361,7 @@ class TestViewOps(TestCase):
self.assertEqual(t_hsplit[1][2, 0, 2], t[2, 2, 2])
@onlyNativeDeviceTypes
-@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+@dtypes(*get_all_dtypes())
def test_view_tensor_vsplit(self, device, dtype):
t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
t_vsplit = torch.vsplit(t, 2)
@@ -371,7 +371,7 @@ class TestViewOps(TestCase):
self.assertEqual(t_vsplit[1][0, 2, 2], t[2, 2, 2])
@onlyNativeDeviceTypes
-@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+@dtypes(*get_all_dtypes())
def test_view_tensor_dsplit(self, device, dtype):
t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
t_dsplit = torch.dsplit(t, 2)
@@ -381,7 +381,7 @@ class TestViewOps(TestCase):
self.assertEqual(t_dsplit[1][2, 2, 0], t[2, 2, 2])
@onlyNativeDeviceTypes
-@dtypes(*all_types_and(torch.half, torch.bfloat16))
+@dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
def test_imag_noncomplex(self, device, dtype):
t = torch.ones((5, 5), dtype=dtype, device=device)
@@ -389,7 +389,7 @@ class TestViewOps(TestCase):
torch.imag(t)
@onlyNativeDeviceTypes
-@dtypes(*complex_types())
+@dtypes(*get_all_complex_dtypes())
def test_real_imag_view(self, device, dtype):
def compare_with_numpy(contiguous_input=True):
t = torch.randn(3, 3, dtype=dtype, device=device)
@@ -420,7 +420,7 @@ class TestViewOps(TestCase):
self.assertEqual(a[5:].imag, a.imag[5:])
@onlyNativeDeviceTypes
-@dtypes(*complex_types())
+@dtypes(*get_all_complex_dtypes())
def test_conj_imag_view(self, device, dtype) -> None:
t = _make_tensor((4, 5,), dtype, device)
t_numpy_conj = torch.from_numpy(t.cpu().numpy().conj()).to(device=device)
@@ -445,7 +445,7 @@ class TestViewOps(TestCase):
self.assertEqual(torch.add(b, c), b.add_(c))
@onlyNativeDeviceTypes
-@dtypes(*product(complex_types(), all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)))
+@dtypes(*product(get_all_complex_dtypes(), get_all_dtypes()))
@suppress_warnings
def test_set_real_imag(self, device, dtypes):
x = torch.randn(10, dtype=dtypes[0], device=device)
@@ -1264,7 +1264,7 @@ class TestOldViewOps(TestCase):
scalar = torch.tensor(5, device=device)
self.assertEqual(scalar, scalar.T)
-@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+@dtypes(*(torch.testing.get_all_dtypes()))
def test_transposes(self, device, dtype):
for op in ("T", "H", "mT", "mH", "adjoint"):
shapes = ((), (2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((), (2, 3),)
@@ -1280,7 +1280,7 @@ class TestOldViewOps(TestCase):
t2 = t2.conj()
self.assertEqual(t2, t1)
-@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+@dtypes(*(torch.testing.get_all_dtypes()))
def test_transposes_errors(self, device, dtype):
for op in ("H", "mT", "mH", "adjoint"):
shapes = ((2,), (2, 3, 4)) if op == "H" else ((2,),)
@@ -1406,7 +1406,8 @@ class TestOldViewOps(TestCase):
self.assertEqual(np_res, torch_res)
# TODO: are these view ops?
-@dtypes(*all_types_and_complex_and(torch.half))
+@dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) +
+get_all_complex_dtypes()))
def test_atleast(self, device, dtype):
self._test_atleast_dim(torch.atleast_1d, np.atleast_1d, device, dtype)
self._test_atleast_dim(torch.atleast_2d, np.atleast_2d, device, dtype)
@@ -1543,7 +1544,7 @@ class TestOldViewOps(TestCase):
self.assertEqual(res1, res2_numpy)
# Skip BFloat16 since numpy does not support it
-@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
+@dtypes(*get_all_dtypes(include_bfloat16=False))
def test_broadcast_to(self, device, dtype):
def can_broadcast(s0, s1):
# s0.dim() <= s1.dim(), reverse s0 and s1 to compare trailing dimension
@@ -1646,7 +1647,7 @@ class TestOldViewOps(TestCase):
self.assertEqual(tensor.view(6, 2, 1), contig_tensor.view(6, 2, 1))
self.assertEqual(tensor.view(1, 6, 2, 1), contig_tensor.view(1, 6, 2, 1))
-@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+@dtypes(*get_all_dtypes())
def test_reshape_view_semantics(self, device, dtype):
tensor = make_tensor((15, 4), dtype=dtype, device=device)
target = (20, 3)
@@ -1673,7 +1674,7 @@ class TestOldViewOps(TestCase):
@onlyNativeDeviceTypes
# Skip BFloat16 since numpy does not support it
-@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
+@dtypes(*get_all_dtypes(include_bfloat16=False))
def test_tensor_split_sections(self, device, dtype):
input_sizes = [
(0,),
@@ -1704,7 +1705,7 @@ class TestOldViewOps(TestCase):
@onlyNativeDeviceTypes
# Skip BFloat16 since numpy does not support it
-@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
+@dtypes(*get_all_dtypes(include_bfloat16=False))
def test_tensor_split_indices(self, device, dtype):
input_sizes = [
(0,),
@@ -1783,20 +1784,20 @@ class TestOldViewOps(TestCase):
def test_resize_all_dtypes_and_devices(self, device):
shape = (2, 2)
-for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
+for dt in get_all_dtypes():
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
x.resize_(shape)
self.assertEqual(shape, x.shape)
def test_resize_as_all_dtypes_and_devices(self, device):
-for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
+for dt in get_all_dtypes():
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device)
x.resize_as_(y)
self.assertEqual(y.shape, x.shape)
def test_view_all_dtypes_and_devices(self, device):
-for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
+for dt in get_all_dtypes():
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
self.assertEqual(x.view(6).shape, [6])
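
A usage note on the @dtypes(*product(...)) hunk in test_set_real_imag above: the decorator accepts dtype tuples as well as single dtypes, and product(...) parametrizes the test over (complex dtype, other dtype) pairs. A standalone sketch (again my illustration, not part of the commit, assuming the same common_dtype helpers):

from itertools import product
from torch.testing._internal.common_dtype import (
    get_all_complex_dtypes, get_all_dtypes,
)

# Each pair is handed to the test as its `dtypes` argument, so one test
# body runs once per (complex_dtype, any_dtype) combination.
pairs = list(product(get_all_complex_dtypes(), get_all_dtypes()))
for complex_dtype, other_dtype in pairs[:3]:
    print(complex_dtype, other_dtype)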