[Testing] Add MPS to NATIVE_DEVICES (#153835)

This would allow me to enable more opinfo tests against MPS device eventually and supposed to be a very simple test, but actually required minor adjustments to lots of test files, namely:
- Introduce `all_mps_types_and` that is very similar to `all_types_and`, but skips `float64`
- Decorate lots of tests with `@dtypesIfMPS(*all_mps_types())`
- Skip `test_from_dlpack_noncontinguous` as it currently crashes (need to be fixed)
- Add lots of `expectedFailureIfMPS`
- Delete all `@onlyNativeDeviceTypesAnd("mps")`

<sarcasm> I love how well documented this variable are </sarcasm>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153835
Approved by: https://github.com/Skylion007
This commit is contained in:
Nikita Shulga
2025-08-05 18:57:35 +00:00
committed by PyTorch MergeBot
parent 0ba09a6d34
commit e06b110f73
8 changed files with 88 additions and 9 deletions

View File

@ -11,15 +11,16 @@ from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
dtypes,
dtypesIfMPS,
expectedFailureMPS,
instantiate_device_type_tests,
onlyCPU,
onlyNativeDeviceTypes,
onlyNativeDeviceTypesAnd,
skipLazy,
skipMeta,
skipXLA,
)
from torch.testing._internal.common_dtype import (
all_mps_types_and,
all_types_and,
all_types_and_complex_and,
complex_types,
@ -157,8 +158,11 @@ class TestViewOps(TestCase):
@skipIfTorchDynamo("TorchDynamo fails with unknown reason")
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
@dtypesIfMPS(*integral_types_and(torch.cfloat, torch.float, torch.half, torch.bool))
def test_view_dtype_new(self, device, dtype):
dtypes = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
if device.startswith("mps"):
del dtypes[torch.float64]
del dtypes[torch.bool]
def generate_inputs():
@ -271,6 +275,7 @@ class TestViewOps(TestCase):
# has a greater element size than the original dtype
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
@dtypesIfMPS(*all_mps_types_and(torch.bool))
def test_view_dtype_upsize_errors(self, device, dtype):
dtype_size = torch._utils._element_size(dtype)
@ -372,6 +377,7 @@ class TestViewOps(TestCase):
@onlyNativeDeviceTypes
@dtypes(*complex_types(), torch.complex32)
@dtypesIfMPS(torch.cfloat, torch.chalf)
def test_view_as_real(self, device, dtype):
def fn(contiguous_input=True):
t = torch.randn(3, 4, dtype=dtype, device=device)
@ -398,9 +404,7 @@ class TestViewOps(TestCase):
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
@dtypesIfMPS(
*integral_types_and(torch.half, torch.bfloat16, torch.bool, torch.float32)
)
@dtypesIfMPS(*all_mps_types_and(torch.bool))
def test_view_tensor_split(self, device, dtype):
a = make_tensor((40, 30), dtype=dtype, device=device, low=-9, high=9)
a_split_dim0 = a.tensor_split(7, 0)
@ -412,6 +416,7 @@ class TestViewOps(TestCase):
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
@dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool))
def test_view_tensor_hsplit(self, device, dtype):
t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
t_hsplit = torch.hsplit(t, 2)
@ -422,6 +427,7 @@ class TestViewOps(TestCase):
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
@dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool))
def test_view_tensor_vsplit(self, device, dtype):
t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
t_vsplit = torch.vsplit(t, 2)
@ -432,6 +438,7 @@ class TestViewOps(TestCase):
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
@dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool))
def test_view_tensor_dsplit(self, device, dtype):
t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
t_dsplit = torch.dsplit(t, 2)
@ -440,9 +447,9 @@ class TestViewOps(TestCase):
t[2, 2, 2] = 7
self.assertEqual(t_dsplit[1][2, 2, 0], t[2, 2, 2])
@onlyNativeDeviceTypesAnd("mps")
@onlyNativeDeviceTypes
@dtypes(*all_types_and(torch.half, torch.bfloat16))
@dtypesIfMPS(*integral_types_and(torch.half, torch.bool, torch.float32))
@dtypesIfMPS(*all_mps_types_and(torch.bool))
def test_imag_noncomplex(self, device, dtype):
t = torch.ones((5, 5), dtype=dtype, device=device)
@ -451,6 +458,7 @@ class TestViewOps(TestCase):
@onlyNativeDeviceTypes
@dtypes(*complex_types())
@dtypesIfMPS(torch.cfloat)
def test_real_imag_view(self, device, dtype):
def compare_with_numpy(contiguous_input=True):
t = torch.randn(3, 3, dtype=dtype, device=device)
@ -481,6 +489,7 @@ class TestViewOps(TestCase):
self.assertEqual(a[5:].imag, a.imag[5:])
@onlyNativeDeviceTypes
@expectedFailureMPS
@dtypes(*complex_types())
def test_conj_imag_view(self, device, dtype) -> None:
t = _make_tensor((4, 5), dtype, device)
@ -512,6 +521,12 @@ class TestViewOps(TestCase):
all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
)
)
@dtypesIfMPS(
*product(
[torch.cfloat, torch.chalf],
all_mps_types_and(torch.cfloat, torch.chalf, torch.bool),
)
)
@suppress_warnings
def test_set_real_imag(self, device, dtypes):
x = torch.randn(10, dtype=dtypes[0], device=device)