[Testing] Add MPS to NATIVE_DEVICES (#153835)

This will eventually allow enabling more OpInfo tests against the MPS device. It was supposed to be a very simple change, but it actually required minor adjustments to lots of test files, namely:
- Introduce `all_mps_types_and`, which is very similar to `all_types_and` but skips `float64`
- Decorate lots of tests with `@dtypesIfMPS(*all_mps_types())` (see the usage sketch after this list)
- Skip `test_from_dlpack_noncontiguous` as it currently crashes (needs to be fixed)
- Add lots of `expectedFailureMPS`
- Delete all `@onlyNativeDeviceTypesAnd("mps")` decorators
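
For illustration, a minimal sketch of the decoration pattern these changes apply (the decorators and dtype helpers are the real ones touched by this PR; the test itself is hypothetical):

```python
# Hypothetical test showing the pattern applied throughout this PR;
# the decorators and dtype helpers are real, the test itself is made up.
import torch
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfMPS,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_dtype import all_mps_types_and, all_types_and
from torch.testing._internal.common_utils import TestCase, run_tests


class TestFoo(TestCase):
    @dtypes(*all_types_and(torch.half, torch.bool))
    @dtypesIfMPS(*all_mps_types_and(torch.bool))  # same set, minus float64
    def test_foo(self, device, dtype):
        self.assertEqual(torch.zeros(3, device=device, dtype=dtype).sum(), 0)


instantiate_device_type_tests(TestFoo, globals())

if __name__ == "__main__":
    run_tests()
```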

<sarcasm> I love how well documented this variable is </sarcasm>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153835
Approved by: https://github.com/Skylion007
Author: Nikita Shulga
Date: 2025-08-05 18:57:35 +00:00
Committed-by: PyTorch MergeBot
Parent: 0ba09a6d34
Commit: e06b110f73

8 changed files with 88 additions and 9 deletions

File: test/nn/test_convolution.py

@@ -2842,6 +2842,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
@parametrize_test("strided", [False, True])
# Test with both contiguous and non-contiguous inputs.
@parametrize_test("contiguous", [False, True])
@expectedFailureMPS # No double support
def test_conv_backend(
self,
device,

File: test/nn/test_pooling.py

@@ -504,6 +504,7 @@ class TestPoolingNN(NNTestCase):
class TestPoolingNNDeviceType(NNTestCase):
@expectedFailureMPS # No double, float shape prop does not work
@onlyNativeDeviceTypes
@dtypes(torch.float, torch.double)
def test_adaptive_pooling_zero_batch(self, dtype, device):
@@ -523,6 +524,7 @@ class TestPoolingNNDeviceType(NNTestCase):
# when output_size = 0, in adaptive_{avg, max}_pool and its variants.
# These tests are explicitly written because ErrorInputs does not support backward calls
# Issue: https://github.com/pytorch/pytorch/issues/78868
@expectedFailureMPS # No double, float shape prop does not work
@onlyNativeDeviceTypes
@dtypes(torch.float32, torch.float64)
@dtypesIfCUDA(torch.float32, torch.float64, torch.bfloat16, torch.float16)
@@ -556,6 +558,7 @@ class TestPoolingNNDeviceType(NNTestCase):
with self.assertRaisesRegex(RuntimeError, error_msg):
fn(input2, output_size).sum().backward()
@expectedFailureMPS # Error message does not match
@onlyNativeDeviceTypes
def test_adaptive_avg_pooling_backward_fails(self, device):
grad_output = torch.randn(1, 2, 7, device=device)
@@ -582,6 +585,7 @@ class TestPoolingNNDeviceType(NNTestCase):
with self.assertRaisesRegex(RuntimeError, "expected dimensions"):
torch.ops.aten.adaptive_max_pool3d_backward(grad_output, input, indices)
@expectedFailureMPS # Op not implemented
@onlyNativeDeviceTypes
def test_FractionalMaxPool2d_zero_batch(self, device):
mod = nn.FractionalMaxPool2d(3, output_ratio=(0.5, 0.5))
@@ -592,6 +596,7 @@ class TestPoolingNNDeviceType(NNTestCase):
inp = torch.randn(1, 0, 50, 32, device=device)
mod(inp)
@expectedFailureMPS # Op not implemented
@onlyNativeDeviceTypes
def test_FractionalMaxPool3d_zero_batch(self, device):
mod = nn.FractionalMaxPool3d(3, output_ratio=(0.5, 0.5, 0.5)).to(device)
@@ -602,6 +607,7 @@ class TestPoolingNNDeviceType(NNTestCase):
inp = torch.randn(1, 0, 50, 32, 32, device=device)
mod(inp)
@expectedFailureMPS # Op not implemented
@onlyNativeDeviceTypes
def test_FractionalMaxPool2d_zero_out_size(self, device):
mod = nn.FractionalMaxPool2d([2, 2], output_size=[0, 1])
@@ -609,6 +615,7 @@ class TestPoolingNNDeviceType(NNTestCase):
out = mod(inp)
self.assertEqual(out, torch.empty((16, 50, 0, 1), device=device))
@expectedFailureMPS # Op not implemented
@onlyNativeDeviceTypes
def test_FractionalMaxPool3d_zero_out_size(self, device):
mod = nn.FractionalMaxPool3d([3, 2, 2], output_size=[0, 1, 1])
@@ -616,6 +623,7 @@ class TestPoolingNNDeviceType(NNTestCase):
out = mod(inp)
self.assertEqual(out, torch.empty((16, 0, 1, 1), device=device))
@expectedFailureMPS # Op not implemented
@onlyNativeDeviceTypes
def test_FractionalMaxPool2d_zero_samples(self, device):
samples = torch.rand([0, 16, 2], device=device)
@@ -630,6 +638,7 @@ class TestPoolingNNDeviceType(NNTestCase):
with self.assertRaisesRegex(RuntimeError, "Expect _random_samples"):
mod(inp1)
@expectedFailureMPS # Op not implemented
@onlyNativeDeviceTypes
def test_FractionalMaxPool3d_zero_samples(self, device):
samples = torch.rand([0, 16, 3], device=device)
@@ -823,6 +832,7 @@ torch.cuda.synchronize()
else:
unpool(output, indices)
@expectedFailureMPS
@onlyNativeDeviceTypes
def test_AdaptiveMaxPool_zero_batch_dim(self, device):
inp = torch.randn(0, 16, 50, device=device)
@@ -962,6 +972,7 @@ torch.cuda.synchronize()
c = out.size(1)
self.assertEqual(out.stride(), [c, 1, 1, 1, 1])
@expectedFailureMPS # Runtime Error not raised for mps
@expectedFailureMeta # Runtime Error not raised for meta
@onlyNativeDeviceTypes
@dtypes(torch.uint8, torch.int8, torch.short, torch.int, torch.long)
@@ -976,6 +987,7 @@ torch.cuda.synchronize()
with self.assertRaisesRegex(RuntimeError, "not implemented"):
module(input)
@expectedFailureMPS # TODO: fixme
@onlyNativeDeviceTypes
@gcIfJetson
@dtypes(torch.float, torch.double)
@@ -1123,6 +1135,7 @@ torch.cuda.synchronize()
helper(1, 100000, 32, 32, ks=4)
helper(1, 100000, 1, 4, ks=(1, 4)) # test for max_pool1d
@expectedFailureMPS # TODO: Fixme
@onlyNativeDeviceTypes
@dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
@dtypesIfCUDA(torch.half, torch.float, torch.double)
@@ -1198,6 +1211,7 @@ torch.cuda.synchronize()
torch.channels_last,
)
@expectedFailureMPS # TODO: Fixme
@onlyNativeDeviceTypes
@dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
@dtypesIfCUDA(torch.half, torch.float, torch.double)
@@ -1722,6 +1736,7 @@ torch.cuda.synchronize()
@dtypesIfCUDA(torch.half, torch.float, torch.double)
@dtypes(torch.float)
@expectedFailureMPS # Exception not raise
@onlyNativeDeviceTypes # TODO: Fails on XLA
@gcIfJetson
def test_max_pool_nan_inf(self, device, dtype):
@@ -1758,6 +1773,7 @@ torch.cuda.synchronize()
res2 = fn(x2, 1 if adaptive else 3)
self.assertTrue(math.isinf(res2.item()))
@expectedFailureMPS # float64
@expectedFailureMeta # RuntimeError: Unrecognized tensor type ID: Meta
@onlyNativeDeviceTypes
def test_fractional_max_pool2d(self, device):
@@ -1820,6 +1836,7 @@ torch.cuda.synchronize()
grad_output, input, kernel_size, output_size, indices
)
@expectedFailureMPS # float64
@expectedFailureMeta # RuntimeError: Unrecognized tensor type ID: Meta
@onlyNativeDeviceTypes
def test_fractional_max_pool3d(self, device):
@@ -1867,6 +1884,7 @@ torch.cuda.synchronize()
x, (2, 2, 2), output_size=output_size, _random_samples=samples
)
@expectedFailureMPS # Not implemented
@dtypesIfCUDA(torch.half, torch.float, torch.double)
@dtypes(torch.float)
@onlyNativeDeviceTypes # TODO: Fails on XLA
@@ -1896,6 +1914,7 @@ torch.cuda.synchronize()
res2.backward(torch.randn_like(res2))
self.assertTrue(math.isinf(res2.item()))
@expectedFailureMPS # TODO: Fix me
@onlyNativeDeviceTypes # TODO: RuntimeError message different on XLA
def test_pooling_zero_stride(self, device):
for op in ("max", "avg"):

File: test/test_dlpack.py

@@ -5,6 +5,7 @@ from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
deviceCountAtLeast,
dtypes,
dtypesIfMPS,
instantiate_device_type_tests,
onlyCPU,
onlyCUDA,
@@ -13,10 +14,14 @@ from torch.testing._internal.common_device_type import (
skipCUDAIfRocm,
skipMeta,
)
from torch.testing._internal.common_dtype import all_types_and_complex_and
from torch.testing._internal.common_dtype import (
all_mps_types_and,
all_types_and_complex_and,
)
from torch.testing._internal.common_utils import (
IS_JETSON,
run_tests,
skipIfMPS,
skipIfTorchDynamo,
TestCase,
)
@@ -55,6 +60,7 @@ class TestTorchDlPack(TestCase):
torch.uint64,
)
)
@dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf))
def test_dlpack_capsule_conversion(self, device, dtype):
x = make_tensor((5,), dtype=dtype, device=device)
z = from_dlpack(to_dlpack(x))
@@ -72,6 +78,7 @@ class TestTorchDlPack(TestCase):
torch.uint64,
)
)
@dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf))
def test_dlpack_protocol_conversion(self, device, dtype):
x = make_tensor((5,), dtype=dtype, device=device)
z = from_dlpack(x)
@@ -80,7 +87,8 @@ class TestTorchDlPack(TestCase):
@skipMeta
@onlyNativeDeviceTypes
def test_dlpack_shared_storage(self, device):
x = make_tensor((5,), dtype=torch.float64, device=device)
dtype = torch.bfloat16 if device.startswith("mps") else torch.float64
x = make_tensor((5,), dtype=dtype, device=device)
z = from_dlpack(to_dlpack(x))
z[0] = z[0] + 20.0
self.assertEqual(z, x)
@@ -120,12 +128,14 @@ class TestTorchDlPack(TestCase):
torch.uint64,
)
)
@dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf))
def test_from_dlpack(self, device, dtype):
x = make_tensor((5,), dtype=dtype, device=device)
y = torch.from_dlpack(x)
self.assertEqual(x, y)
@skipMeta
@skipIfMPS # MPS crashes with noncontiguous now
@onlyNativeDeviceTypes
@dtypes(
*all_types_and_complex_and(
@@ -189,6 +199,7 @@ class TestTorchDlPack(TestCase):
torch.uint64,
)
)
@dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf))
def test_from_dlpack_dtype(self, device, dtype):
x = make_tensor((5,), dtype=dtype, device=device)
y = torch.from_dlpack(x)
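
As context for the `skipIfMPS` above, a rough sketch of the kind of noncontiguous round trip the skipped test exercises (an illustrative repro, not the actual test body):

```python
# Illustrative repro: a noncontiguous DLPack round trip of the sort that
# currently crashes on MPS (hence the skipIfMPS above); not the test body.
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

x = torch.arange(25, dtype=torch.float32, device="mps").reshape(5, 5)
y = x[:, ::2]                   # strided, noncontiguous view
z = from_dlpack(to_dlpack(y))   # crashed on MPS at the time of this commit
```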

File: test/test_indexing.py

@@ -16,6 +16,7 @@ from torch.testing._internal.common_device_type import (
dtypesIfCPU,
dtypesIfCUDA,
dtypesIfMPS,
expectedFailureMPS,
instantiate_device_type_tests,
onlyCUDA,
onlyNativeDeviceTypes,
@@ -183,6 +184,7 @@ class TestIndexing(TestCase):
@onlyNativeDeviceTypes
@dtypes(torch.half, torch.double)
@dtypesIfMPS(torch.half) # TODO: add bf16 there?
def test_advancedindex(self, device, dtype):
# Tests for Integer Array Indexing, Part I - Purely integer array
# indexing
@@ -1193,6 +1195,7 @@ class TestIndexing(TestCase):
out_cpu = func1(t, ind, val)
self.assertEqual(out_cuda.cpu(), out_cpu)
@expectedFailureMPS # Doubles not supported
@onlyNativeDeviceTypes
def test_index_put_accumulate_duplicate_indices(self, device):
for i in range(1, 512):

File: test/test_nn.py

@@ -8766,6 +8766,7 @@ class TestNNDeviceType(NNTestCase):
@onlyNativeDeviceTypes
@dtypes(torch.float16, torch.bfloat16, torch.float32, torch.float64)
@dtypesIfMPS(torch.float16, torch.bfloat16, torch.float32)
def test_rmsnorm_epsilon(self, device, dtype):
def rms_norm_reference_fn(i, normalized_shape):
eps = torch.finfo(i.dtype).eps
@@ -8940,6 +8941,7 @@ class TestNNDeviceType(NNTestCase):
Y_cpu = group_norm(X.cpu())
self.assertEqual(Y_cpu, Y, rtol=0, atol=1e-5)
@expectedFailureMPS # Double is not supported on MPS
@onlyNativeDeviceTypes
@dtypes(torch.float64, torch.complex128)
def test_pad(self, device, dtype):
@@ -8971,6 +8973,7 @@ class TestNNDeviceType(NNTestCase):
out.fill_(4)
self.assertTrue(torch.all(torch.abs(inputs) < 2))
@expectedFailureMPS # Unsupported float64/complex128
@onlyNativeDeviceTypes
@dtypes(torch.float64, torch.complex128)
def test_ReplicationPad_empty(self, device, dtype):
@@ -9109,6 +9112,7 @@ class TestNNDeviceType(NNTestCase):
self.assertEqual(inp1.grad, torch.zeros_like(inp1))
self.assertEqual(inp2.grad, torch.zeros_like(inp2))
@expectedFailureMPS # Double not supported
@expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
@onlyNativeDeviceTypes
def test_TransformerEncoderLayer_empty(self, device):
@@ -9138,6 +9142,7 @@ class TestNNDeviceType(NNTestCase):
_test_module_empty_input(self, encoder_layer, input, check_size=False)
@expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
@expectedFailureMPS # Float64 is not supported
@onlyNativeDeviceTypes
def test_TransformerEncoder_empty(self, device):
for batch_first, input_shape in [(True, (0, 10, 512)),
@@ -9148,6 +9153,7 @@ class TestNNDeviceType(NNTestCase):
_test_module_empty_input(self, transformer_encoder, input, check_size=False)
@expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
@expectedFailureMPS # Float64 is not supported
@onlyNativeDeviceTypes
def test_TransformerDecoderLayer_empty(self, device):
for batch_first, memory_shape, tgt_shape in [(True, (0, 10, 512), (0, 20, 512)),
@@ -9158,6 +9164,7 @@ class TestNNDeviceType(NNTestCase):
self._test_module_empty_inputs(decoder_layer, [tgt, memory])
@expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
@expectedFailureMPS # Float64 is not supported
@onlyNativeDeviceTypes
def test_TransformerDecoder_empty(self, device):
for batch_first, memory_shape, tgt_shape in [(True, (0, 10, 512), (0, 20, 512)),
@@ -9169,6 +9176,7 @@ class TestNNDeviceType(NNTestCase):
self._test_module_empty_inputs(transformer_decoder, [tgt, memory])
@expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
@expectedFailureMPS # Float64 is not supported
@onlyNativeDeviceTypes
def test_Transformer_empty(self, device):
for batch_first, src_shape, tgt_shape in [(True, (10, 0, 512), (20, 0, 512))]:
@@ -9304,6 +9312,7 @@ class TestNNDeviceType(NNTestCase):
self.assertEqual(x.grad, ref_x.grad)
@expectedFailureMPS # Unimplemented margin_loss
@onlyNativeDeviceTypes
@dtypes(torch.float, torch.double)
def test_MarginLoss_empty(self, device, dtype):
@@ -9370,6 +9379,7 @@ class TestNNDeviceType(NNTestCase):
with self.assertRaisesRegex(RuntimeError, 'Expected all tensors to be on the same device'):
F.mse_loss(i, t)
@expectedFailureMPS # TODO: Fixme, and raise assert on empty tensor
@onlyNativeDeviceTypes
def test_Unfold_empty(self, device):
inp = torch.randn(0, 3, 3, 4, device=device)
@@ -9593,6 +9603,7 @@ class TestNNDeviceType(NNTestCase):
verify_reduction_scalars(input, reduction, output)
# verify that bogus reduction strings are errors
@expectedFailureMPS # CTCLoss unimplemented
@onlyNativeDeviceTypes
def test_invalid_reduction_strings(self, device):
input = torch.randn(3, 5, requires_grad=True, device=device)
@@ -10079,6 +10090,7 @@ class TestNNDeviceType(NNTestCase):
@parametrize_test("align_corners", [True, False])
@parametrize_test("mode", ["bilinear", "bicubic"])
@parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
@expectedFailureMPS # double device type
@onlyNativeDeviceTypes
def test_upsamplingBiMode2d(self, device, antialias, align_corners, mode, memory_format):
# Forward AD does not support XLA because XLA tensors don't have storage
@@ -10148,6 +10160,7 @@ class TestNNDeviceType(NNTestCase):
@parametrize_test("num_channels", [3, 5])
@parametrize_test("mode", ["nearest", "nearest-exact", "bilinear", "bicubic"])
@parametrize_test("dtype", integral_types() + floating_types())
@skipIfMPS # Error message is wrong for some dtypes
@onlyNativeDeviceTypes
def test_upsamplingBiMode2d_nonsupported_dtypes(self, device, antialias, num_channels, mode, dtype):
x = torch.ones(1, num_channels, 32, 32, dtype=dtype, device=device)
@@ -11470,6 +11483,7 @@ class TestNNDeviceType(NNTestCase):
self.assertTrue(gradcheck(F.hardsigmoid, (inputs,)))
# currently fails on XLA
@expectedFailureMPS # TypeError: the MPS framework doesn't support float64
@onlyNativeDeviceTypes
def test_hardswish_grad(self, device):
inputs = (torch.randn(4, 16, 16, device=device, dtype=torch.double) - 0.5) * 10
@@ -11677,6 +11691,7 @@ class TestNNDeviceType(NNTestCase):
self._test_batchnorm_simple_average(device, dtype, torch.float)
@onlyNativeDeviceTypes
@expectedFailureMPS # Unsupported Border padding mode
@dtypes(torch.float, torch.double)
def test_grid_sample_nan_inf(self, device, dtype):
input = torch.zeros([1, 1, 3, 3], device=device, dtype=dtype)
@@ -12789,6 +12804,7 @@ if __name__ == '__main__':
F.threshold(x, 0.5, 0.5, inplace=True)
F.threshold_(x, 0.5, 0.5)
@expectedFailureMPS # Double is unsupported
@onlyNativeDeviceTypes
def test_triplet_margin_with_distance_loss_default_parity(self, device):
# Test for `nn.TripletMarginWithDistanceLoss` and
@@ -12823,6 +12839,7 @@ if __name__ == '__main__':
self.assertTrue(gradcheck(lambda a, p, n: loss_op(a, p, n),
(anchor, positive, negative)))
@expectedFailureMPS # Double is unsupported
@onlyNativeDeviceTypes
def test_triplet_margin_with_distance_loss(self, device):
# Test for parity between `nn.TripletMarginWithDistanceLoss` and

File: test/test_view_ops.py

@@ -11,15 +11,16 @@ from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
dtypes,
dtypesIfMPS,
expectedFailureMPS,
instantiate_device_type_tests,
onlyCPU,
onlyNativeDeviceTypes,
onlyNativeDeviceTypesAnd,
skipLazy,
skipMeta,
skipXLA,
)
from torch.testing._internal.common_dtype import (
all_mps_types_and,
all_types_and,
all_types_and_complex_and,
complex_types,
@@ -157,8 +158,11 @@ class TestViewOps(TestCase):
@skipIfTorchDynamo("TorchDynamo fails with unknown reason")
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
@dtypesIfMPS(*integral_types_and(torch.cfloat, torch.float, torch.half, torch.bool))
def test_view_dtype_new(self, device, dtype):
dtypes = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
if device.startswith("mps"):
del dtypes[torch.float64]
del dtypes[torch.bool]
def generate_inputs():
@@ -271,6 +275,7 @@ class TestViewOps(TestCase):
# has a greater element size than the original dtype
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
@dtypesIfMPS(*all_mps_types_and(torch.bool))
def test_view_dtype_upsize_errors(self, device, dtype):
dtype_size = torch._utils._element_size(dtype)
@@ -372,6 +377,7 @@ class TestViewOps(TestCase):
@onlyNativeDeviceTypes
@dtypes(*complex_types(), torch.complex32)
@dtypesIfMPS(torch.cfloat, torch.chalf)
def test_view_as_real(self, device, dtype):
def fn(contiguous_input=True):
t = torch.randn(3, 4, dtype=dtype, device=device)
@@ -398,9 +404,7 @@ class TestViewOps(TestCase):
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
@dtypesIfMPS(
*integral_types_and(torch.half, torch.bfloat16, torch.bool, torch.float32)
)
@dtypesIfMPS(*all_mps_types_and(torch.bool))
def test_view_tensor_split(self, device, dtype):
a = make_tensor((40, 30), dtype=dtype, device=device, low=-9, high=9)
a_split_dim0 = a.tensor_split(7, 0)
@@ -412,6 +416,7 @@ class TestViewOps(TestCase):
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
@dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool))
def test_view_tensor_hsplit(self, device, dtype):
t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
t_hsplit = torch.hsplit(t, 2)
@@ -422,6 +427,7 @@ class TestViewOps(TestCase):
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
@dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool))
def test_view_tensor_vsplit(self, device, dtype):
t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
t_vsplit = torch.vsplit(t, 2)
@@ -432,6 +438,7 @@ class TestViewOps(TestCase):
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
@dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool))
def test_view_tensor_dsplit(self, device, dtype):
t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
t_dsplit = torch.dsplit(t, 2)
@@ -440,9 +447,9 @@ class TestViewOps(TestCase):
t[2, 2, 2] = 7
self.assertEqual(t_dsplit[1][2, 2, 0], t[2, 2, 2])
@onlyNativeDeviceTypesAnd("mps")
@onlyNativeDeviceTypes
@dtypes(*all_types_and(torch.half, torch.bfloat16))
@dtypesIfMPS(*integral_types_and(torch.half, torch.bool, torch.float32))
@dtypesIfMPS(*all_mps_types_and(torch.bool))
def test_imag_noncomplex(self, device, dtype):
t = torch.ones((5, 5), dtype=dtype, device=device)
@@ -451,6 +458,7 @@ class TestViewOps(TestCase):
@onlyNativeDeviceTypes
@dtypes(*complex_types())
@dtypesIfMPS(torch.cfloat)
def test_real_imag_view(self, device, dtype):
def compare_with_numpy(contiguous_input=True):
t = torch.randn(3, 3, dtype=dtype, device=device)
@@ -481,6 +489,7 @@ class TestViewOps(TestCase):
self.assertEqual(a[5:].imag, a.imag[5:])
@onlyNativeDeviceTypes
@expectedFailureMPS
@dtypes(*complex_types())
def test_conj_imag_view(self, device, dtype) -> None:
t = _make_tensor((4, 5), dtype, device)
@@ -512,6 +521,12 @@ class TestViewOps(TestCase):
all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
)
)
@dtypesIfMPS(
*product(
[torch.cfloat, torch.chalf],
all_mps_types_and(torch.cfloat, torch.chalf, torch.bool),
)
)
@suppress_warnings
def test_set_real_imag(self, device, dtypes):
x = torch.randn(10, dtype=dtypes[0], device=device)

File: torch/testing/_internal/common_dtype.py

@@ -121,6 +121,19 @@ def all_types_and_half():
return _all_types_and_half
_all_mps_types = (
_dispatch_dtypes({torch.float, torch.half, torch.bfloat16}) + _integral_types
)
def all_mps_types():
return _all_mps_types
def all_mps_types_and(*dtypes):
return _all_mps_types + _validate_dtypes(*dtypes)
_float8_types = _dispatch_dtypes(
(
torch.float8_e4m3fn,
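
Assuming the `_integral_types` tuple above covers the standard integer dtypes, the new helpers behave roughly as follows (illustrative only; `_all_mps_types` is built from a set, so tuple ordering may differ):

```python
# Rough illustration of what the new helpers return, assuming _integral_types
# covers the standard integer dtypes; note that float64 is never included.
import torch
from torch.testing._internal.common_dtype import all_mps_types, all_mps_types_and

base = set(all_mps_types())
# ~ {torch.float32, torch.float16, torch.bfloat16,
#    torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}

extended = set(all_mps_types_and(torch.bool, torch.cfloat))
# ~ base | {torch.bool, torch.complex64}
```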

File: torch/testing/_internal/common_utils.py

@@ -297,7 +297,7 @@ if os.getenv("SLOW_TESTS_FILE", ""):
if os.getenv("DISABLED_TESTS_FILE", ""):
disabled_tests_dict = maybe_load_json(os.getenv("DISABLED_TESTS_FILE", ""))
NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', torch._C._get_privateuse1_backend_name())
NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', 'mps', torch._C._get_privateuse1_backend_name())
# used for managing devices testing for torch profiler UTs
# for now cpu, cuda and xpu are added for testing torch profiler UTs
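
For context, device-type gates such as `@onlyNativeDeviceTypes` consult this tuple when deciding which devices a test is instantiated for. A simplified sketch of that kind of check (the real decorator in common_device_type.py is more involved):

```python
# Simplified sketch of a NATIVE_DEVICES-based gate; the actual
# onlyNativeDeviceTypes decorator in common_device_type.py is more involved.
import unittest

NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', 'mps')

def only_native_device_types(fn):
    def wrapper(self, device, *args, **kwargs):
        if device.split(':')[0] not in NATIVE_DEVICES:
            raise unittest.SkipTest(f"{device} is not a native device type")
        return fn(self, device, *args, **kwargs)
    return wrapper
```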