mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert D34856571: [pytorch][PR] Replace get_all_
type macros with the ATen dispatch macros.
Test Plan: revert-hammer Differential Revision: D34856571 (3ded7b1da3
) Original commit changeset: 0dca038bcad5 Original Phabricator Diff: D34856571 (3ded7b1da3
) fbshipit-source-id: 594553fa0b710d78beba59d5d2b646f1f1270386 (cherry picked from commit 8090eb9b12dcf452a9e7dc01792a66fb91b563b6)
This commit is contained in:
committed by
PyTorch MergeBot
parent
f14a0be302
commit
ef066f0832
@ -16,7 +16,7 @@ from torch.testing._internal.common_utils import (
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes, skipMeta)
|
||||
from torch.testing._internal.common_dtype import (
|
||||
all_types_and_complex_and, complex_types, all_types_and, floating_and_complex_types_and,
|
||||
get_all_dtypes, get_all_int_dtypes, get_all_fp_dtypes, get_all_complex_dtypes
|
||||
)
|
||||
|
||||
# TODO: replace this with make_tensor() in common_utils.py
|
||||
@ -121,14 +121,14 @@ class TestViewOps(TestCase):
|
||||
else:
|
||||
return x.transpose(dim0, dim1)
|
||||
|
||||
@dtypes(*all_types_and(torch.half, torch.bfloat16))
|
||||
@dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
|
||||
def test_conj_self(self, device, dtype):
|
||||
t = torch.ones(5, 5, device=device)
|
||||
s = t.conj()
|
||||
self.assertTrue(s is t)
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
|
||||
@dtypes(*get_all_dtypes(include_bfloat16=False))
|
||||
def test_view_dtype_new(self, device, dtype):
|
||||
dtypes = torch_to_numpy_dtype_dict.copy()
|
||||
del dtypes[torch.bool]
|
||||
@ -210,18 +210,18 @@ class TestViewOps(TestCase):
|
||||
# because view(dtype) does not support backward yet
|
||||
# TODO: Remove this when autograd support is added
|
||||
if dtype.is_floating_point or dtype.is_complex:
|
||||
for view_dtype in floating_and_complex_types_and(torch.half, torch.bfloat16):
|
||||
for view_dtype in [*get_all_fp_dtypes(), *get_all_complex_dtypes()]:
|
||||
t = make_tensor((5, 5, 64), dtype=dtype, device=device, low=-5, high=5, requires_grad=True)
|
||||
self.assertFalse(t.view(view_dtype).requires_grad)
|
||||
|
||||
# Test the extra error checks that happen when the view dtype
|
||||
# has a greater element size than the original dtype
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
|
||||
@dtypes(*get_all_dtypes())
|
||||
def test_view_dtype_upsize_errors(self, device, dtype):
|
||||
dtype_size = torch._utils._element_size(dtype)
|
||||
|
||||
for view_dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
|
||||
for view_dtype in get_all_dtypes():
|
||||
view_dtype_size = torch._utils._element_size(view_dtype)
|
||||
if view_dtype_size <= dtype_size:
|
||||
continue
|
||||
@ -302,7 +302,7 @@ class TestViewOps(TestCase):
|
||||
self.assertEqual(res.shape, torch.Size([0]))
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*complex_types(), torch.complex32)
|
||||
@dtypes(*get_all_complex_dtypes(include_complex32=True))
|
||||
def test_view_as_real(self, device, dtype):
|
||||
def fn(contiguous_input=True):
|
||||
t = torch.randn(3, 4, dtype=dtype, device=device)
|
||||
@ -340,7 +340,7 @@ class TestViewOps(TestCase):
|
||||
self.assertEqual(res.shape, torch.Size([2]))
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
|
||||
@dtypes(*get_all_dtypes())
|
||||
def test_view_tensor_split(self, device, dtype):
|
||||
a = make_tensor((40, 30), dtype=dtype, device=device, low=-9, high=9)
|
||||
a_split_dim0 = a.tensor_split(7, 0)
|
||||
@ -351,7 +351,7 @@ class TestViewOps(TestCase):
|
||||
self.assertTrue(self.is_view_of(a, a_split_dim1_tensor))
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
|
||||
@dtypes(*get_all_dtypes())
|
||||
def test_view_tensor_hsplit(self, device, dtype):
|
||||
t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
|
||||
t_hsplit = torch.hsplit(t, 2)
|
||||
@ -361,7 +361,7 @@ class TestViewOps(TestCase):
|
||||
self.assertEqual(t_hsplit[1][2, 0, 2], t[2, 2, 2])
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
|
||||
@dtypes(*get_all_dtypes())
|
||||
def test_view_tensor_vsplit(self, device, dtype):
|
||||
t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
|
||||
t_vsplit = torch.vsplit(t, 2)
|
||||
@ -371,7 +371,7 @@ class TestViewOps(TestCase):
|
||||
self.assertEqual(t_vsplit[1][0, 2, 2], t[2, 2, 2])
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
|
||||
@dtypes(*get_all_dtypes())
|
||||
def test_view_tensor_dsplit(self, device, dtype):
|
||||
t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
|
||||
t_dsplit = torch.dsplit(t, 2)
|
||||
@ -381,7 +381,7 @@ class TestViewOps(TestCase):
|
||||
self.assertEqual(t_dsplit[1][2, 2, 0], t[2, 2, 2])
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*all_types_and(torch.half, torch.bfloat16))
|
||||
@dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
|
||||
def test_imag_noncomplex(self, device, dtype):
|
||||
t = torch.ones((5, 5), dtype=dtype, device=device)
|
||||
|
||||
@ -389,7 +389,7 @@ class TestViewOps(TestCase):
|
||||
torch.imag(t)
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*complex_types())
|
||||
@dtypes(*get_all_complex_dtypes())
|
||||
def test_real_imag_view(self, device, dtype):
|
||||
def compare_with_numpy(contiguous_input=True):
|
||||
t = torch.randn(3, 3, dtype=dtype, device=device)
|
||||
@ -420,7 +420,7 @@ class TestViewOps(TestCase):
|
||||
self.assertEqual(a[5:].imag, a.imag[5:])
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*complex_types())
|
||||
@dtypes(*get_all_complex_dtypes())
|
||||
def test_conj_imag_view(self, device, dtype) -> None:
|
||||
t = _make_tensor((4, 5,), dtype, device)
|
||||
t_numpy_conj = torch.from_numpy(t.cpu().numpy().conj()).to(device=device)
|
||||
@ -445,7 +445,7 @@ class TestViewOps(TestCase):
|
||||
self.assertEqual(torch.add(b, c), b.add_(c))
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*product(complex_types(), all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)))
|
||||
@dtypes(*product(get_all_complex_dtypes(), get_all_dtypes()))
|
||||
@suppress_warnings
|
||||
def test_set_real_imag(self, device, dtypes):
|
||||
x = torch.randn(10, dtype=dtypes[0], device=device)
|
||||
@ -1255,7 +1255,7 @@ class TestOldViewOps(TestCase):
|
||||
scalar = torch.tensor(5, device=device)
|
||||
self.assertEqual(scalar, scalar.T)
|
||||
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
|
||||
@dtypes(*(torch.testing.get_all_dtypes()))
|
||||
def test_transposes(self, device, dtype):
|
||||
for op in ("T", "H", "mT", "mH", "adjoint"):
|
||||
shapes = ((), (2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((), (2, 3),)
|
||||
@ -1271,7 +1271,7 @@ class TestOldViewOps(TestCase):
|
||||
t2 = t2.conj()
|
||||
self.assertEqual(t2, t1)
|
||||
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
|
||||
@dtypes(*(torch.testing.get_all_dtypes()))
|
||||
def test_transposes_errors(self, device, dtype):
|
||||
for op in ("H", "mT", "mH", "adjoint"):
|
||||
shapes = ((2,), (2, 3, 4)) if op == "H" else ((2,),)
|
||||
@ -1397,7 +1397,8 @@ class TestOldViewOps(TestCase):
|
||||
self.assertEqual(np_res, torch_res)
|
||||
|
||||
# TODO: are these view ops?
|
||||
@dtypes(*all_types_and_complex_and(torch.half))
|
||||
@dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) +
|
||||
get_all_complex_dtypes()))
|
||||
def test_atleast(self, device, dtype):
|
||||
self._test_atleast_dim(torch.atleast_1d, np.atleast_1d, device, dtype)
|
||||
self._test_atleast_dim(torch.atleast_2d, np.atleast_2d, device, dtype)
|
||||
@ -1534,7 +1535,7 @@ class TestOldViewOps(TestCase):
|
||||
self.assertEqual(res1, res2_numpy)
|
||||
|
||||
# Skip BFloat16 since numpy does not support it
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
|
||||
@dtypes(*get_all_dtypes(include_bfloat16=False))
|
||||
def test_broadcast_to(self, device, dtype):
|
||||
def can_broadcast(s0, s1):
|
||||
# s0.dim() <= s1.dim(), reverse s0 and s1 to compare trailing dimension
|
||||
@ -1637,7 +1638,7 @@ class TestOldViewOps(TestCase):
|
||||
self.assertEqual(tensor.view(6, 2, 1), contig_tensor.view(6, 2, 1))
|
||||
self.assertEqual(tensor.view(1, 6, 2, 1), contig_tensor.view(1, 6, 2, 1))
|
||||
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
|
||||
@dtypes(*get_all_dtypes())
|
||||
def test_reshape_view_semantics(self, device, dtype):
|
||||
tensor = make_tensor((15, 4), dtype=dtype, device=device)
|
||||
target = (20, 3)
|
||||
@ -1664,7 +1665,7 @@ class TestOldViewOps(TestCase):
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
# Skip BFloat16 since numpy does not support it
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
|
||||
@dtypes(*get_all_dtypes(include_bfloat16=False))
|
||||
def test_tensor_split_sections(self, device, dtype):
|
||||
input_sizes = [
|
||||
(0,),
|
||||
@ -1695,7 +1696,7 @@ class TestOldViewOps(TestCase):
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
# Skip BFloat16 since numpy does not support it
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
|
||||
@dtypes(*get_all_dtypes(include_bfloat16=False))
|
||||
def test_tensor_split_indices(self, device, dtype):
|
||||
input_sizes = [
|
||||
(0,),
|
||||
@ -1774,20 +1775,20 @@ class TestOldViewOps(TestCase):
|
||||
|
||||
def test_resize_all_dtypes_and_devices(self, device):
|
||||
shape = (2, 2)
|
||||
for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
|
||||
for dt in get_all_dtypes():
|
||||
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
|
||||
x.resize_(shape)
|
||||
self.assertEqual(shape, x.shape)
|
||||
|
||||
def test_resize_as_all_dtypes_and_devices(self, device):
|
||||
for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
|
||||
for dt in get_all_dtypes():
|
||||
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
|
||||
y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device)
|
||||
x.resize_as_(y)
|
||||
self.assertEqual(y.shape, x.shape)
|
||||
|
||||
def test_view_all_dtypes_and_devices(self, device):
|
||||
for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
|
||||
for dt in get_all_dtypes():
|
||||
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
|
||||
self.assertEqual(x.view(6).shape, [6])
|
||||
|
||||
|
Reference in New Issue
Block a user