replace onlyOnCPUAndCUDA with onlyNativeDeviceTypes (#65201)

Summary:
Reference: https://github.com/pytorch/pytorch/issues/53849

Replace `onlyOnCPUAndCUDA` with `onlyNativeDeviceTypes`, which covers the native device types `cpu`, `cuda`, and `meta`.
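
For context, a minimal, self-contained sketch of the decorator in use follows; the class name, test name, and test body are illustrative placeholders, while the imports and the `onlyNativeDeviceTypes` decorator are the ones this PR switches the test suite to.

```python
# Minimal sketch (not part of this PR's diff): a device-generic test gated by
# onlyNativeDeviceTypes. With this change it is instantiated for the native
# device types (cpu, cuda, meta) instead of only cpu and cuda. The class name,
# test name, and test body below are illustrative placeholders.
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, onlyNativeDeviceTypes)


class ExampleViewTest(TestCase):
    # Before this PR the test would have been decorated with @onlyOnCPUAndCUDA.
    @onlyNativeDeviceTypes
    def test_view_shape(self, device):
        t = torch.ones(2, 3, device=device)
        # Shape-only assertion so the check is also meaningful on the meta
        # device, whose tensors carry shape/dtype metadata but no real storage.
        self.assertEqual(t.view(6).shape, torch.Size([6]))


# Creates per-device variants (e.g. ExampleViewTestCPU) for the device types
# available in this environment.
instantiate_device_type_tests(ExampleViewTest, globals())

if __name__ == "__main__":
    run_tests()
```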

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65201

Reviewed By: mrshenli

Differential Revision: D31299718

Pulled By: mruberry

fbshipit-source-id: 2d8356450c035d6a314209ab51b2c237583920fd
Author: kshitij12345
Date: 2021-11-01 09:21:20 -07:00
Committed by: Facebook GitHub Bot
Commit: 885a8e53ba (parent: 39ad7b670e)

18 changed files with 288 additions and 263 deletions


@@ -12,7 +12,7 @@ from torch.testing import make_tensor
from torch.testing._internal.common_utils import \
(TestCase, run_tests, suppress_warnings)
from torch.testing._internal.common_device_type import \
-(instantiate_device_type_tests, onlyCPU, dtypes, onlyOnCPUAndCUDA)
+(instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes)
from torch.testing._internal.common_dtype import (
get_all_dtypes, get_all_int_dtypes, get_all_fp_dtypes, get_all_complex_dtypes
)
@@ -125,7 +125,7 @@ class TestViewOps(TestCase):
s = t.conj()
self.assertTrue(s is t)
-@onlyOnCPUAndCUDA
+@onlyNativeDeviceTypes
@dtypes(*get_all_fp_dtypes(include_bfloat16=False), torch.complex64)
def test_view_dtype(self, device, dtype):
int_dtype = {
@@ -175,7 +175,7 @@ class TestViewOps(TestCase):
self.assertFalse(t.view(torch.complex64).requires_grad)
-@onlyOnCPUAndCUDA
+@onlyNativeDeviceTypes
def test_view_as_complex(self, device):
def fn(contiguous_input=True, dim0=0, dim1=1):
t = torch.randn(3, 2, 2, device=device)
@@ -231,7 +231,7 @@ class TestViewOps(TestCase):
self.assertTrue(self.is_view_of(x, res))
self.assertEqual(res.shape, torch.Size([0]))
-@onlyOnCPUAndCUDA
+@onlyNativeDeviceTypes
@dtypes(*get_all_complex_dtypes(include_complex32=True))
def test_view_as_real(self, device, dtype):
def fn(contiguous_input=True):
@@ -269,7 +269,7 @@ class TestViewOps(TestCase):
self.assertRaises(RuntimeError, lambda: self.is_view_of(x, res))
self.assertEqual(res.shape, torch.Size([2]))
-@onlyOnCPUAndCUDA
+@onlyNativeDeviceTypes
@dtypes(*get_all_dtypes())
def test_view_tensor_split(self, device, dtype):
a = make_tensor((40, 30), device, dtype, low=-9, high=9)
@@ -280,7 +280,7 @@ class TestViewOps(TestCase):
for a_split_dim1_tensor in a_split_dim1:
self.assertTrue(self.is_view_of(a, a_split_dim1_tensor))
-@onlyOnCPUAndCUDA
+@onlyNativeDeviceTypes
@dtypes(*get_all_dtypes())
def test_view_tensor_hsplit(self, device, dtype):
t = make_tensor((4, 4, 4), device, dtype, low=-9, high=9)
@@ -290,7 +290,7 @@ class TestViewOps(TestCase):
t[2, 2, 2] = 7
self.assertEqual(t_hsplit[1][2, 0, 2], t[2, 2, 2])
-@onlyOnCPUAndCUDA
+@onlyNativeDeviceTypes
@dtypes(*get_all_dtypes())
def test_view_tensor_vsplit(self, device, dtype):
t = make_tensor((4, 4, 4), device, dtype, low=-9, high=9)
@@ -300,7 +300,7 @@ class TestViewOps(TestCase):
t[2, 2, 2] = 7
self.assertEqual(t_vsplit[1][0, 2, 2], t[2, 2, 2])
-@onlyOnCPUAndCUDA
+@onlyNativeDeviceTypes
@dtypes(*get_all_dtypes())
def test_view_tensor_dsplit(self, device, dtype):
t = make_tensor((4, 4, 4), device, dtype, low=-9, high=9)
@@ -310,7 +310,7 @@ class TestViewOps(TestCase):
t[2, 2, 2] = 7
self.assertEqual(t_dsplit[1][2, 2, 0], t[2, 2, 2])
-@onlyOnCPUAndCUDA
+@onlyNativeDeviceTypes
@dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
def test_real_imag_noncomplex(self, device, dtype):
t = torch.ones((5, 5), dtype=dtype, device=device)
@@ -321,7 +321,7 @@ class TestViewOps(TestCase):
with self.assertRaises(RuntimeError):
torch.imag(t)
-@onlyOnCPUAndCUDA
+@onlyNativeDeviceTypes
@dtypes(*get_all_complex_dtypes())
def test_real_imag_view(self, device, dtype):
def compare_with_numpy(contiguous_input=True):
@@ -352,7 +352,7 @@ class TestViewOps(TestCase):
self.assertEqual(a[5:].real, a.real[5:])
self.assertEqual(a[5:].imag, a.imag[5:])
-@onlyOnCPUAndCUDA
+@onlyNativeDeviceTypes
@dtypes(*get_all_complex_dtypes())
def test_conj_imag_view(self, device, dtype) -> None:
t = _make_tensor((4, 5,), dtype, device)
@@ -367,7 +367,7 @@ class TestViewOps(TestCase):
self.assertEqual(v_imag, t_numpy_conj.imag)
self.assertTrue(v_imag.is_neg())
-@onlyOnCPUAndCUDA
+@onlyNativeDeviceTypes
def test_conj_view_with_shared_memory(self, device) -> None:
a = _make_tensor((4, 5,), torch.cfloat, device)
b = a.conj()
@@ -377,7 +377,7 @@ class TestViewOps(TestCase):
self.assertEqual(torch.add(b, c), torch.add(b, c, out=a))
self.assertEqual(torch.add(b, c), b.add_(c))
-@onlyOnCPUAndCUDA
+@onlyNativeDeviceTypes
@dtypes(*product(get_all_complex_dtypes(), get_all_dtypes()))
@suppress_warnings
def test_set_real_imag(self, device, dtypes):
@@ -666,7 +666,7 @@ class TestViewOps(TestCase):
test_writes_propagate(t, v3)
self.assertTrue(self.is_view_of_same_base(t, v3))
-@onlyOnCPUAndCUDA
+@onlyNativeDeviceTypes
def test_flatten_nonview(self, device):
def assert_is_nonview(t, nv):
idx_t = (0,) * t.ndim
@@ -885,7 +885,7 @@ class TestOldViewOps(TestCase):
self.assertEqual((1, 0, 6, 1, 1), x.view(1, 0, 6, 1, 1).shape)
# TODO: this should be refactored into the view ops test suite
-@onlyOnCPUAndCUDA
+@onlyNativeDeviceTypes
def test_reshape(self, device):
x = torch.randn(3, 3, device=device)
self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr())
@@ -1183,7 +1183,7 @@ class TestOldViewOps(TestCase):
test_helper((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last, device)
test_helper((3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d, device)
-@onlyOnCPUAndCUDA
+@onlyNativeDeviceTypes
@dtypes(torch.int64, torch.float, torch.complex128)
def test_transpose_invalid(self, device, dtype):
for fn in (torch.swapdims, torch.swapaxes, torch.transpose):
@@ -1424,7 +1424,7 @@ class TestOldViewOps(TestCase):
x.set_(x.storage(), 0, x.size(), stride)
self.assertTrue(x.is_contiguous())
-@onlyOnCPUAndCUDA
+@onlyNativeDeviceTypes
# Skip BFloat16 since numpy does not support it
@dtypes(*get_all_dtypes(include_bfloat16=False))
def test_tensor_split_sections(self, device, dtype):
@@ -1455,7 +1455,7 @@ class TestOldViewOps(TestCase):
self.assertEqual(result_n, result1, msg=msg)
self.assertEqual(result_n, result2, msg=msg)
-@onlyOnCPUAndCUDA
+@onlyNativeDeviceTypes
# Skip BFloat16 since numpy does not support it
@dtypes(*get_all_dtypes(include_bfloat16=False))
def test_tensor_split_indices(self, device, dtype):
@@ -1500,7 +1500,7 @@ class TestOldViewOps(TestCase):
self.assertEqual(result_n, result_1, msg=msg)
self.assertEqual(result_n, result_2, msg=msg)
-@onlyOnCPUAndCUDA
+@onlyNativeDeviceTypes
def test_tensor_split_errors(self, device):
S = 10
test_cases = [