mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert D34856571: [pytorch][PR] Replace get_all_ type macros with the ATen dispatch macros.
Test Plan: revert-hammer Differential Revision: D34856571 (3ded7b1da3) Original commit changeset: 0dca038bcad5 Original Phabricator Diff: D34856571 (3ded7b1da3) fbshipit-source-id: 594553fa0b710d78beba59d5d2b646f1f1270386 (cherry picked from commit 8090eb9b12dcf452a9e7dc01792a66fb91b563b6)
This commit is contained in:
committed by
PyTorch MergeBot
parent
f14a0be302
commit
ef066f0832
@ -15,7 +15,7 @@ from torch.testing._internal.common_utils import (
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests, onlyCPU, onlyCUDA, dtypes, onlyNativeDeviceTypes,
|
||||
dtypesIfCUDA, largeTensorTest)
|
||||
from torch.testing._internal.common_dtype import all_types_and_complex_and, all_types, all_types_and
|
||||
from torch.testing._internal.common_dtype import get_all_dtypes
|
||||
|
||||
# TODO: replace with make_tensor
|
||||
def _generate_input(shape, dtype, device, with_extremal):
|
||||
@ -227,8 +227,9 @@ class TestShapeOps(TestCase):
|
||||
self.assertEqual(expected, result)
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*all_types())
|
||||
@dtypesIfCUDA(*all_types_and(torch.half))
|
||||
@dtypes(*get_all_dtypes(include_complex=False, include_bool=False, include_half=False,
|
||||
include_bfloat16=False))
|
||||
@dtypesIfCUDA(*get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=False))
|
||||
def test_trace(self, device, dtype):
|
||||
def test(shape):
|
||||
tensor = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9)
|
||||
@ -340,7 +341,7 @@ class TestShapeOps(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, error_msg):
|
||||
torch.clamp(X)
|
||||
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
|
||||
@dtypes(*get_all_dtypes())
|
||||
def test_flip(self, device, dtype):
|
||||
make_from_data = partial(torch.tensor, device=device, dtype=dtype)
|
||||
make_from_size = partial(make_tensor, device=device, dtype=dtype)
|
||||
@ -439,7 +440,7 @@ class TestShapeOps(TestCase):
|
||||
for dims in test_dims:
|
||||
self.assertEqual(size, list(data.flip(dims).size()))
|
||||
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
|
||||
@dtypes(*get_all_dtypes())
|
||||
def test_flip_errors(self, device, dtype):
|
||||
make_arg = partial(make_tensor, dtype=dtype, device=device)
|
||||
data = make_arg((2, 2, 2))
|
||||
@ -457,7 +458,7 @@ class TestShapeOps(TestCase):
|
||||
def _rand_shape(self, dim, min_size, max_size):
|
||||
return tuple(torch.randint(min_size, max_size + 1, (dim,)))
|
||||
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
|
||||
@dtypes(*get_all_dtypes())
|
||||
def test_flip_numpy(self, device, dtype):
|
||||
make_arg = partial(make_tensor, dtype=dtype, device=device)
|
||||
|
||||
@ -566,7 +567,7 @@ class TestShapeOps(TestCase):
|
||||
t.nonzero()
|
||||
self.assertEqual(len(w), 0)
|
||||
|
||||
@dtypes(*all_types_and(torch.half, torch.bool, torch.bfloat16))
|
||||
@dtypes(*get_all_dtypes(include_complex=False))
|
||||
def test_nonzero(self, device, dtype):
|
||||
|
||||
shapes = [
|
||||
|
||||
Reference in New Issue
Block a user