Revert D34856571: [pytorch][PR] Replace get_all_ type macros with the ATen dispatch macros.

Test Plan: revert-hammer

Differential Revision:
D34856571 (3ded7b1da3)

Original commit changeset: 0dca038bcad5

Original Phabricator Diff: D34856571 (3ded7b1da3)

fbshipit-source-id: 594553fa0b710d78beba59d5d2b646f1f1270386
(cherry picked from commit 8090eb9b12dcf452a9e7dc01792a66fb91b563b6)
Author: Nikita Shulga
Date: 2022-03-15 14:36:04 -07:00
Committed by: PyTorch MergeBot
Parent: f14a0be302
Commit: ef066f0832

21 changed files with 410 additions and 375 deletions
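
For reference, the helpers on both sides of this revert live in torch.testing._internal.common_dtype, and the dispatch-style calls that the revert removes were meant to describe the same dtype sets as the restored get_all_dtypes calls. A minimal sketch of that correspondence, assuming the helper signatures at this revision (the snippet is illustrative and not part of the commit):

# Illustrative only: checks that the two helper families agree on the
# dtype sets used by the tests touched below.
import torch
from torch.testing._internal.common_dtype import (
    all_types, all_types_and, all_types_and_complex_and, get_all_dtypes)

# test_flip / test_flip_errors / test_flip_numpy: all dtypes, complex included.
assert set(get_all_dtypes()) == set(
    all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))

# test_nonzero: everything except complex.
assert set(get_all_dtypes(include_complex=False)) == set(
    all_types_and(torch.half, torch.bool, torch.bfloat16))

# test_trace on CPU: integral and default floating types only.
assert set(get_all_dtypes(include_complex=False, include_bool=False,
                          include_half=False, include_bfloat16=False)) == set(all_types())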


@@ -15,7 +15,7 @@ from torch.testing._internal.common_utils import (
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, onlyCPU, onlyCUDA, dtypes, onlyNativeDeviceTypes,
     dtypesIfCUDA, largeTensorTest)
-from torch.testing._internal.common_dtype import all_types_and_complex_and, all_types, all_types_and
+from torch.testing._internal.common_dtype import get_all_dtypes
 
 # TODO: replace with make_tensor
 def _generate_input(shape, dtype, device, with_extremal):
@@ -227,8 +227,9 @@ class TestShapeOps(TestCase):
         self.assertEqual(expected, result)
 
     @onlyNativeDeviceTypes
-    @dtypes(*all_types())
-    @dtypesIfCUDA(*all_types_and(torch.half))
+    @dtypes(*get_all_dtypes(include_complex=False, include_bool=False, include_half=False,
+                            include_bfloat16=False))
+    @dtypesIfCUDA(*get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=False))
     def test_trace(self, device, dtype):
         def test(shape):
             tensor = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9)
@@ -340,7 +341,7 @@ class TestShapeOps(TestCase):
         with self.assertRaisesRegex(RuntimeError, error_msg):
             torch.clamp(X)
 
-    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
+    @dtypes(*get_all_dtypes())
     def test_flip(self, device, dtype):
         make_from_data = partial(torch.tensor, device=device, dtype=dtype)
         make_from_size = partial(make_tensor, device=device, dtype=dtype)
@@ -439,7 +440,7 @@ class TestShapeOps(TestCase):
         for dims in test_dims:
             self.assertEqual(size, list(data.flip(dims).size()))
 
-    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
+    @dtypes(*get_all_dtypes())
     def test_flip_errors(self, device, dtype):
         make_arg = partial(make_tensor, dtype=dtype, device=device)
         data = make_arg((2, 2, 2))
@@ -457,7 +458,7 @@ class TestShapeOps(TestCase):
     def _rand_shape(self, dim, min_size, max_size):
         return tuple(torch.randint(min_size, max_size + 1, (dim,)))
 
-    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
+    @dtypes(*get_all_dtypes())
     def test_flip_numpy(self, device, dtype):
         make_arg = partial(make_tensor, dtype=dtype, device=device)
@@ -566,7 +567,7 @@ class TestShapeOps(TestCase):
             t.nonzero()
         self.assertEqual(len(w), 0)
 
-    @dtypes(*all_types_and(torch.half, torch.bool, torch.bfloat16))
+    @dtypes(*get_all_dtypes(include_complex=False))
     def test_nonzero(self, device, dtype):
         shapes = [
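
The decorators being swapped in these hunks come from torch.testing._internal.common_device_type and are expanded by instantiate_device_type_tests into one test per device/dtype combination. A rough, self-contained sketch of that pattern, assuming the standard PyTorch test harness (the class and test below are made up for illustration):

# Hypothetical example of the @dtypes / @dtypesIfCUDA pattern used above.
import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, dtypes, dtypesIfCUDA)
from torch.testing._internal.common_utils import TestCase, run_tests

class TestDtypeExample(TestCase):
    # One test instance is generated per listed dtype; the @dtypesIfCUDA
    # list takes precedence when the test runs on a CUDA device.
    @dtypes(torch.float32, torch.float64)
    @dtypesIfCUDA(torch.float32, torch.float64, torch.half)
    def test_sum(self, device, dtype):
        t = torch.ones(4, device=device, dtype=dtype)
        self.assertEqual(t.sum().item(), 4)

# Produces e.g. TestDtypeExampleCPU.test_sum_cpu_float32 and, with CUDA
# available, TestDtypeExampleCUDA.test_sum_cuda_float16, etc.
instantiate_device_type_tests(TestDtypeExample, globals())

if __name__ == "__main__":
    run_tests()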