[Testing] Add MPS to NATIVE_DEVICES (#153835)

This would eventually allow me to enable more OpInfo tests against the MPS device. It was supposed to be a very simple change, but it actually required minor adjustments to lots of test files, namely:
- Introduce `all_mps_types_and`, which is very similar to `all_types_and` but skips `float64` (see the sketch after this list)
- Decorate lots of tests with `@dtypesIfMPS(*all_mps_types())`
- Skip `test_from_dlpack_noncontiguous` as it currently crashes on MPS (needs to be fixed)
- Add lots of `expectedFailureIfMPS`
- Delete all `@onlyNativeDeviceTypesAnd("mps")`
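
A rough sketch of the idea behind the new helper, assuming it simply mirrors `all_types_and` with `float64` dropped (illustrative only; the real definition lives in `torch.testing._internal.common_dtype` and the exact dtype set may differ):

```python
import torch

# Illustrative dtype list; the real helpers are built from the dispatch dtype
# sets in torch.testing._internal.common_dtype.
_ALL_TYPES = (
    torch.float32,
    torch.float64,
    torch.uint8,
    torch.int8,
    torch.int16,
    torch.int32,
    torch.int64,
)


def all_types_and(*extra):
    # Existing helper: the standard dtype set plus any explicitly listed extras.
    return _ALL_TYPES + tuple(extra)


def all_mps_types_and(*extra):
    # New helper: same idea, but float64 is dropped because MPS has no
    # double-precision support.
    return tuple(d for d in _ALL_TYPES if d != torch.float64) + tuple(extra)
```

Tests then declare an MPS-specific dtype list with `@dtypesIfMPS(*all_mps_types_and(...))` next to the existing `@dtypes(...)` decorator, as the diff below shows.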

<sarcasm> I love how well documented this variable is </sarcasm>
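
For reference, `NATIVE_DEVICES` is the tuple in `torch/testing/_internal/common_device_type.py` that decorators like `@onlyNativeDeviceTypes` check a test's device type against. A minimal sketch of the change the title describes (the pre-existing entries shown here are assumed, not copied from the source):

```python
# torch/testing/_internal/common_device_type.py (sketch, not the exact upstream tuple)
# Before this change, "mps" was absent, so @onlyNativeDeviceTypes tests never
# ran against the MPS backend:
#   NATIVE_DEVICES = ("cpu", "cuda", "meta")
# After adding it, those tests are eligible to run on MPS as well:
NATIVE_DEVICES = ("cpu", "cuda", "mps", "meta")
```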

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153835
Approved by: https://github.com/Skylion007
Author: Nikita Shulga
Date: 2025-08-05 18:57:35 +00:00
Committed by: PyTorch MergeBot
Parent: 0ba09a6d34
Commit: e06b110f73

8 changed files with 88 additions and 9 deletions

test/test_dlpack.py

@@ -5,6 +5,7 @@ from torch.testing import make_tensor
 from torch.testing._internal.common_device_type import (
     deviceCountAtLeast,
     dtypes,
+    dtypesIfMPS,
     instantiate_device_type_tests,
     onlyCPU,
     onlyCUDA,
@@ -13,10 +14,14 @@ from torch.testing._internal.common_device_type import (
     skipCUDAIfRocm,
     skipMeta,
 )
-from torch.testing._internal.common_dtype import all_types_and_complex_and
+from torch.testing._internal.common_dtype import (
+    all_mps_types_and,
+    all_types_and_complex_and,
+)
 from torch.testing._internal.common_utils import (
     IS_JETSON,
     run_tests,
+    skipIfMPS,
     skipIfTorchDynamo,
     TestCase,
 )
@@ -55,6 +60,7 @@ class TestTorchDlPack(TestCase):
             torch.uint64,
         )
     )
+    @dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf))
     def test_dlpack_capsule_conversion(self, device, dtype):
         x = make_tensor((5,), dtype=dtype, device=device)
         z = from_dlpack(to_dlpack(x))
@@ -72,6 +78,7 @@ class TestTorchDlPack(TestCase):
             torch.uint64,
         )
     )
+    @dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf))
     def test_dlpack_protocol_conversion(self, device, dtype):
         x = make_tensor((5,), dtype=dtype, device=device)
         z = from_dlpack(x)
@@ -80,7 +87,8 @@ class TestTorchDlPack(TestCase):
     @skipMeta
     @onlyNativeDeviceTypes
     def test_dlpack_shared_storage(self, device):
-        x = make_tensor((5,), dtype=torch.float64, device=device)
+        dtype = torch.bfloat16 if device.startswith("mps") else torch.float64
+        x = make_tensor((5,), dtype=dtype, device=device)
         z = from_dlpack(to_dlpack(x))
         z[0] = z[0] + 20.0
         self.assertEqual(z, x)
@@ -120,12 +128,14 @@ class TestTorchDlPack(TestCase):
             torch.uint64,
         )
     )
+    @dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf))
     def test_from_dlpack(self, device, dtype):
         x = make_tensor((5,), dtype=dtype, device=device)
         y = torch.from_dlpack(x)
         self.assertEqual(x, y)
 
     @skipMeta
+    @skipIfMPS  # MPS crashes with noncontiguous now
     @onlyNativeDeviceTypes
     @dtypes(
         *all_types_and_complex_and(
@@ -189,6 +199,7 @@ class TestTorchDlPack(TestCase):
             torch.uint64,
         )
     )
+    @dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf))
     def test_from_dlpack_dtype(self, device, dtype):
         x = make_tensor((5,), dtype=dtype, device=device)
         y = torch.from_dlpack(x)
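
A note on how these decorators combine: when the test class is instantiated for a device type, a device-specific list such as `@dtypesIfMPS(...)` overrides the generic `@dtypes(...)` list, so the `float64` (and other unsupported) entries from `all_types_and_complex_and` are never used for the MPS runs. A hedged sketch of the resulting pattern (test name and body are illustrative, not from this diff):

```python
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
@dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf))
def test_roundtrip(self, device, dtype):
    # from_dlpack round-trips should preserve values for every listed dtype
    x = make_tensor((5,), dtype=dtype, device=device)
    self.assertEqual(x, torch.from_dlpack(x))
```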