mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Testing] Add MPS to NATIVE_DEVICES (#153835)
This would allow me to enable more opinfo tests against MPS device eventually and supposed to be a very simple test, but actually required minor adjustments to lots of test files, namely: - Introduce `all_mps_types_and` that is very similar to `all_types_and`, but skips `float64` - Decorate lots of tests with `@dtypesIfMPS(*all_mps_types())` - Skip `test_from_dlpack_noncontinguous` as it currently crashes (need to be fixed) - Add lots of `expectedFailureIfMPS` - Delete all `@onlyNativeDeviceTypesAnd("mps")` <sarcasm> I love how well documented this variable are </sarcasm> Pull Request resolved: https://github.com/pytorch/pytorch/pull/153835 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
0ba09a6d34
commit
e06b110f73
@ -5,6 +5,7 @@ from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_device_type import (
|
||||
deviceCountAtLeast,
|
||||
dtypes,
|
||||
dtypesIfMPS,
|
||||
instantiate_device_type_tests,
|
||||
onlyCPU,
|
||||
onlyCUDA,
|
||||
@ -13,10 +14,14 @@ from torch.testing._internal.common_device_type import (
|
||||
skipCUDAIfRocm,
|
||||
skipMeta,
|
||||
)
|
||||
from torch.testing._internal.common_dtype import all_types_and_complex_and
|
||||
from torch.testing._internal.common_dtype import (
|
||||
all_mps_types_and,
|
||||
all_types_and_complex_and,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_JETSON,
|
||||
run_tests,
|
||||
skipIfMPS,
|
||||
skipIfTorchDynamo,
|
||||
TestCase,
|
||||
)
|
||||
@ -55,6 +60,7 @@ class TestTorchDlPack(TestCase):
|
||||
torch.uint64,
|
||||
)
|
||||
)
|
||||
@dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf))
|
||||
def test_dlpack_capsule_conversion(self, device, dtype):
|
||||
x = make_tensor((5,), dtype=dtype, device=device)
|
||||
z = from_dlpack(to_dlpack(x))
|
||||
@ -72,6 +78,7 @@ class TestTorchDlPack(TestCase):
|
||||
torch.uint64,
|
||||
)
|
||||
)
|
||||
@dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf))
|
||||
def test_dlpack_protocol_conversion(self, device, dtype):
|
||||
x = make_tensor((5,), dtype=dtype, device=device)
|
||||
z = from_dlpack(x)
|
||||
@ -80,7 +87,8 @@ class TestTorchDlPack(TestCase):
|
||||
@skipMeta
|
||||
@onlyNativeDeviceTypes
|
||||
def test_dlpack_shared_storage(self, device):
|
||||
x = make_tensor((5,), dtype=torch.float64, device=device)
|
||||
dtype = torch.bfloat16 if device.startswith("mps") else torch.float64
|
||||
x = make_tensor((5,), dtype=dtype, device=device)
|
||||
z = from_dlpack(to_dlpack(x))
|
||||
z[0] = z[0] + 20.0
|
||||
self.assertEqual(z, x)
|
||||
@ -120,12 +128,14 @@ class TestTorchDlPack(TestCase):
|
||||
torch.uint64,
|
||||
)
|
||||
)
|
||||
@dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf))
|
||||
def test_from_dlpack(self, device, dtype):
|
||||
x = make_tensor((5,), dtype=dtype, device=device)
|
||||
y = torch.from_dlpack(x)
|
||||
self.assertEqual(x, y)
|
||||
|
||||
@skipMeta
|
||||
@skipIfMPS # MPS crashes with noncontiguous now
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(
|
||||
*all_types_and_complex_and(
|
||||
@ -189,6 +199,7 @@ class TestTorchDlPack(TestCase):
|
||||
torch.uint64,
|
||||
)
|
||||
)
|
||||
@dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf))
|
||||
def test_from_dlpack_dtype(self, device, dtype):
|
||||
x = make_tensor((5,), dtype=dtype, device=device)
|
||||
y = torch.from_dlpack(x)
|
||||
|
Reference in New Issue
Block a user