[Testing] Add MPS to NATIVE_DEVICES (#153835)
This would eventually allow me to enable more OpInfo tests against the MPS device. It was supposed to be a very simple change, but it actually required minor adjustments to lots of test files, namely:

- Introduce `all_mps_types_and`, which is very similar to `all_types_and` but skips `float64` (sketched below)
- Decorate lots of tests with `@dtypesIfMPS(*all_mps_types())`
- Skip `test_from_dlpack_noncontinguous`, as it currently crashes (needs to be fixed)
- Add lots of `expectedFailureMPS`
- Delete all `@onlyNativeDeviceTypesAnd("mps")`

<sarcasm> I love how well documented this variable is </sarcasm>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153835
Approved by: https://github.com/Skylion007
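To make the first bullet concrete, here is a minimal sketch of what such a helper could look like, assuming it simply filters `float64` (which the MPS backend lacks) out of the standard dtype list. The real definition lives in `torch.testing._internal.common_dtype` and may differ:

```python
# Hypothetical sketch only: the actual all_mps_types_and is defined in
# torch.testing._internal.common_dtype and may be implemented differently.
import torch
from torch.testing._internal.common_dtype import all_types


def all_mps_types_and(*extra):
    # Mirror all_types_and, but drop float64, which MPS does not support.
    base = [d for d in all_types() if d is not torch.float64]
    return base + list(extra)


# Example: the MPS-friendly dtype list plus bool, as used in the diff below.
print(all_mps_types_and(torch.bool))
```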
committed by PyTorch MergeBot
parent 0ba09a6d34
commit e06b110f73
test/test_view_ops.py

```diff
@@ -11,15 +11,16 @@ from torch.testing import make_tensor
 from torch.testing._internal.common_device_type import (
     dtypes,
     dtypesIfMPS,
+    expectedFailureMPS,
     instantiate_device_type_tests,
     onlyCPU,
     onlyNativeDeviceTypes,
-    onlyNativeDeviceTypesAnd,
     skipLazy,
     skipMeta,
     skipXLA,
 )
 from torch.testing._internal.common_dtype import (
+    all_mps_types_and,
     all_types_and,
     all_types_and_complex_and,
     complex_types,
```
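For readers unfamiliar with these decorators: the device-generic test framework instantiates one test class per device type, taking the dtype list from `@dtypes` by default and from a device-specific override such as `@dtypesIfMPS` when one is present. A self-contained sketch of the pattern follows; the test class and assertion are invented for illustration, not taken from this diff:

```python
import torch
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfMPS,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_utils import TestCase, run_tests


class ExampleViewTest(TestCase):
    @dtypes(torch.float32, torch.float64)
    @dtypesIfMPS(torch.float32)  # MPS has no float64, so override the list
    def test_transpose_is_view(self, device, dtype):
        t = torch.ones(2, 3, device=device, dtype=dtype)
        self.assertTrue(t.t()._is_view())


# Generates ExampleViewTestCPU, ExampleViewTestMPS, ... for available devices.
instantiate_device_type_tests(ExampleViewTest, globals())

if __name__ == "__main__":
    run_tests()
```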
```diff
@@ -157,8 +158,11 @@ class TestViewOps(TestCase):
     @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
     @onlyNativeDeviceTypes
     @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
+    @dtypesIfMPS(*integral_types_and(torch.cfloat, torch.float, torch.half, torch.bool))
     def test_view_dtype_new(self, device, dtype):
         dtypes = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
+        if device.startswith("mps"):
+            del dtypes[torch.float64]
         del dtypes[torch.bool]

         def generate_inputs():
```
```diff
@@ -271,6 +275,7 @@ class TestViewOps(TestCase):
     # has a greater element size than the original dtype
     @onlyNativeDeviceTypes
     @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+    @dtypesIfMPS(*all_mps_types_and(torch.bool))
     def test_view_dtype_upsize_errors(self, device, dtype):
         dtype_size = torch._utils._element_size(dtype)

```
```diff
@@ -372,6 +377,7 @@ class TestViewOps(TestCase):

     @onlyNativeDeviceTypes
     @dtypes(*complex_types(), torch.complex32)
+    @dtypesIfMPS(torch.cfloat, torch.chalf)
     def test_view_as_real(self, device, dtype):
         def fn(contiguous_input=True):
             t = torch.randn(3, 4, dtype=dtype, device=device)
```
```diff
@@ -398,9 +404,7 @@ class TestViewOps(TestCase):

     @onlyNativeDeviceTypes
     @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
-    @dtypesIfMPS(
-        *integral_types_and(torch.half, torch.bfloat16, torch.bool, torch.float32)
-    )
+    @dtypesIfMPS(*all_mps_types_and(torch.bool))
     def test_view_tensor_split(self, device, dtype):
         a = make_tensor((40, 30), dtype=dtype, device=device, low=-9, high=9)
         a_split_dim0 = a.tensor_split(7, 0)
```
```diff
@@ -412,6 +416,7 @@ class TestViewOps(TestCase):

     @onlyNativeDeviceTypes
     @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+    @dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool))
     def test_view_tensor_hsplit(self, device, dtype):
         t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
         t_hsplit = torch.hsplit(t, 2)
```
```diff
@@ -422,6 +427,7 @@ class TestViewOps(TestCase):

     @onlyNativeDeviceTypes
     @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+    @dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool))
     def test_view_tensor_vsplit(self, device, dtype):
         t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
         t_vsplit = torch.vsplit(t, 2)
```
```diff
@@ -432,6 +438,7 @@ class TestViewOps(TestCase):

     @onlyNativeDeviceTypes
     @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+    @dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool))
     def test_view_tensor_dsplit(self, device, dtype):
         t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
         t_dsplit = torch.dsplit(t, 2)
```
```diff
@@ -440,9 +447,9 @@ class TestViewOps(TestCase):
         t[2, 2, 2] = 7
         self.assertEqual(t_dsplit[1][2, 2, 0], t[2, 2, 2])

-    @onlyNativeDeviceTypesAnd("mps")
+    @onlyNativeDeviceTypes
     @dtypes(*all_types_and(torch.half, torch.bfloat16))
-    @dtypesIfMPS(*integral_types_and(torch.half, torch.bool, torch.float32))
+    @dtypesIfMPS(*all_mps_types_and(torch.bool))
     def test_imag_noncomplex(self, device, dtype):
         t = torch.ones((5, 5), dtype=dtype, device=device)

```
```diff
@@ -451,6 +458,7 @@ class TestViewOps(TestCase):

     @onlyNativeDeviceTypes
     @dtypes(*complex_types())
+    @dtypesIfMPS(torch.cfloat)
     def test_real_imag_view(self, device, dtype):
         def compare_with_numpy(contiguous_input=True):
             t = torch.randn(3, 3, dtype=dtype, device=device)
```
```diff
@@ -481,6 +489,7 @@ class TestViewOps(TestCase):
         self.assertEqual(a[5:].imag, a.imag[5:])

     @onlyNativeDeviceTypes
+    @expectedFailureMPS
     @dtypes(*complex_types())
     def test_conj_imag_view(self, device, dtype) -> None:
         t = _make_tensor((4, 5), dtype, device)
```
```diff
@@ -512,6 +521,12 @@ class TestViewOps(TestCase):
             all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
         )
     )
+    @dtypesIfMPS(
+        *product(
+            [torch.cfloat, torch.chalf],
+            all_mps_types_and(torch.cfloat, torch.chalf, torch.bool),
+        )
+    )
     @suppress_warnings
     def test_set_real_imag(self, device, dtypes):
         x = torch.randn(10, dtype=dtypes[0], device=device)
```
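One note on the final hunk: `product` is `itertools.product`, so these decorators hand each test instantiation a *pair* of dtypes through the test's `dtypes` parameter. A standalone illustration of that pattern, with an abbreviated dtype list:

```python
from itertools import product

import torch

# Each test instantiation receives one (complex_dtype, source_dtype) pair.
pairs = product(
    [torch.cfloat, torch.chalf],
    [torch.float32, torch.half, torch.bool],
)
for complex_dtype, source_dtype in pairs:
    print(complex_dtype, source_dtype)
```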