mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Testing] Add MPS to NATIVE_DEVICES (#153835)
This would allow me to enable more opinfo tests against MPS device eventually and supposed to be a very simple test, but actually required minor adjustments to lots of test files, namely: - Introduce `all_mps_types_and` that is very similar to `all_types_and`, but skips `float64` - Decorate lots of tests with `@dtypesIfMPS(*all_mps_types())` - Skip `test_from_dlpack_noncontinguous` as it currently crashes (need to be fixed) - Add lots of `expectedFailureIfMPS` - Delete all `@onlyNativeDeviceTypesAnd("mps")` <sarcasm> I love how well documented this variable are </sarcasm> Pull Request resolved: https://github.com/pytorch/pytorch/pull/153835 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
0ba09a6d34
commit
e06b110f73
@ -2842,6 +2842,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
|
||||
@parametrize_test("strided", [False, True])
|
||||
# Test with both contiguous and non-contiguous inputs.
|
||||
@parametrize_test("contiguous", [False, True])
|
||||
@expectedFailureMPS # No double support
|
||||
def test_conv_backend(
|
||||
self,
|
||||
device,
|
||||
|
@ -504,6 +504,7 @@ class TestPoolingNN(NNTestCase):
|
||||
|
||||
|
||||
class TestPoolingNNDeviceType(NNTestCase):
|
||||
@expectedFailureMPS # No double, float shape prop does not work
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(torch.float, torch.double)
|
||||
def test_adaptive_pooling_zero_batch(self, dtype, device):
|
||||
@ -523,6 +524,7 @@ class TestPoolingNNDeviceType(NNTestCase):
|
||||
# when output_size = 0, in adaptive_{avg, max}_pool and its variants.
|
||||
# These tests are explicitly written because ErrorInputs does not support backward calls
|
||||
# Issue: https://github.com/pytorch/pytorch/issues/78868
|
||||
@expectedFailureMPS # No double, float shape prop does not work
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(torch.float32, torch.float64)
|
||||
@dtypesIfCUDA(torch.float32, torch.float64, torch.bfloat16, torch.float16)
|
||||
@ -556,6 +558,7 @@ class TestPoolingNNDeviceType(NNTestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, error_msg):
|
||||
fn(input2, output_size).sum().backward()
|
||||
|
||||
@expectedFailureMPS # Error message does not match
|
||||
@onlyNativeDeviceTypes
|
||||
def test_adaptive_avg_pooling_backward_fails(self, device):
|
||||
grad_output = torch.randn(1, 2, 7, device=device)
|
||||
@ -582,6 +585,7 @@ class TestPoolingNNDeviceType(NNTestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "expected dimensions"):
|
||||
torch.ops.aten.adaptive_max_pool3d_backward(grad_output, input, indices)
|
||||
|
||||
@expectedFailureMPS # Op not implemented
|
||||
@onlyNativeDeviceTypes
|
||||
def test_FractionalMaxPool2d_zero_batch(self, device):
|
||||
mod = nn.FractionalMaxPool2d(3, output_ratio=(0.5, 0.5))
|
||||
@ -592,6 +596,7 @@ class TestPoolingNNDeviceType(NNTestCase):
|
||||
inp = torch.randn(1, 0, 50, 32, device=device)
|
||||
mod(inp)
|
||||
|
||||
@expectedFailureMPS # Op not implemented
|
||||
@onlyNativeDeviceTypes
|
||||
def test_FractionalMaxPool3d_zero_batch(self, device):
|
||||
mod = nn.FractionalMaxPool3d(3, output_ratio=(0.5, 0.5, 0.5)).to(device)
|
||||
@ -602,6 +607,7 @@ class TestPoolingNNDeviceType(NNTestCase):
|
||||
inp = torch.randn(1, 0, 50, 32, 32, device=device)
|
||||
mod(inp)
|
||||
|
||||
@expectedFailureMPS # Op not implemented
|
||||
@onlyNativeDeviceTypes
|
||||
def test_FractionalMaxPool2d_zero_out_size(self, device):
|
||||
mod = nn.FractionalMaxPool2d([2, 2], output_size=[0, 1])
|
||||
@ -609,6 +615,7 @@ class TestPoolingNNDeviceType(NNTestCase):
|
||||
out = mod(inp)
|
||||
self.assertEqual(out, torch.empty((16, 50, 0, 1), device=device))
|
||||
|
||||
@expectedFailureMPS # Op not implemented
|
||||
@onlyNativeDeviceTypes
|
||||
def test_FractionalMaxPool3d_zero_out_size(self, device):
|
||||
mod = nn.FractionalMaxPool3d([3, 2, 2], output_size=[0, 1, 1])
|
||||
@ -616,6 +623,7 @@ class TestPoolingNNDeviceType(NNTestCase):
|
||||
out = mod(inp)
|
||||
self.assertEqual(out, torch.empty((16, 0, 1, 1), device=device))
|
||||
|
||||
@expectedFailureMPS # Op not implemented
|
||||
@onlyNativeDeviceTypes
|
||||
def test_FractionalMaxPool2d_zero_samples(self, device):
|
||||
samples = torch.rand([0, 16, 2], device=device)
|
||||
@ -630,6 +638,7 @@ class TestPoolingNNDeviceType(NNTestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "Expect _random_samples"):
|
||||
mod(inp1)
|
||||
|
||||
@expectedFailureMPS # Op not implemented
|
||||
@onlyNativeDeviceTypes
|
||||
def test_FractionalMaxPool3d_zero_samples(self, device):
|
||||
samples = torch.rand([0, 16, 3], device=device)
|
||||
@ -823,6 +832,7 @@ torch.cuda.synchronize()
|
||||
else:
|
||||
unpool(output, indices)
|
||||
|
||||
@expectedFailureMPS
|
||||
@onlyNativeDeviceTypes
|
||||
def test_AdaptiveMaxPool_zero_batch_dim(self, device):
|
||||
inp = torch.randn(0, 16, 50, device=device)
|
||||
@ -962,6 +972,7 @@ torch.cuda.synchronize()
|
||||
c = out.size(1)
|
||||
self.assertEqual(out.stride(), [c, 1, 1, 1, 1])
|
||||
|
||||
@expectedFailureMPS # Runtime Error not raised for mps
|
||||
@expectedFailureMeta # Runtime Error not raised for meta
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(torch.uint8, torch.int8, torch.short, torch.int, torch.long)
|
||||
@ -976,6 +987,7 @@ torch.cuda.synchronize()
|
||||
with self.assertRaisesRegex(RuntimeError, "not implemented"):
|
||||
module(input)
|
||||
|
||||
@expectedFailureMPS # TODO: fixme
|
||||
@onlyNativeDeviceTypes
|
||||
@gcIfJetson
|
||||
@dtypes(torch.float, torch.double)
|
||||
@ -1123,6 +1135,7 @@ torch.cuda.synchronize()
|
||||
helper(1, 100000, 32, 32, ks=4)
|
||||
helper(1, 100000, 1, 4, ks=(1, 4)) # test for max_pool1d
|
||||
|
||||
@expectedFailureMPS # TODO: Fixme
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
|
||||
@dtypesIfCUDA(torch.half, torch.float, torch.double)
|
||||
@ -1198,6 +1211,7 @@ torch.cuda.synchronize()
|
||||
torch.channels_last,
|
||||
)
|
||||
|
||||
@expectedFailureMPS # TODO: Fixme
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
|
||||
@dtypesIfCUDA(torch.half, torch.float, torch.double)
|
||||
@ -1722,6 +1736,7 @@ torch.cuda.synchronize()
|
||||
|
||||
@dtypesIfCUDA(torch.half, torch.float, torch.double)
|
||||
@dtypes(torch.float)
|
||||
@expectedFailureMPS # Exception not raise
|
||||
@onlyNativeDeviceTypes # TODO: Fails on XLA
|
||||
@gcIfJetson
|
||||
def test_max_pool_nan_inf(self, device, dtype):
|
||||
@ -1758,6 +1773,7 @@ torch.cuda.synchronize()
|
||||
res2 = fn(x2, 1 if adaptive else 3)
|
||||
self.assertTrue(math.isinf(res2.item()))
|
||||
|
||||
@expectedFailureMPS # float64
|
||||
@expectedFailureMeta # RuntimeError: Unrecognized tensor type ID: Meta
|
||||
@onlyNativeDeviceTypes
|
||||
def test_fractional_max_pool2d(self, device):
|
||||
@ -1820,6 +1836,7 @@ torch.cuda.synchronize()
|
||||
grad_output, input, kernel_size, output_size, indices
|
||||
)
|
||||
|
||||
@expectedFailureMPS # float64
|
||||
@expectedFailureMeta # RuntimeError: Unrecognized tensor type ID: Meta
|
||||
@onlyNativeDeviceTypes
|
||||
def test_fractional_max_pool3d(self, device):
|
||||
@ -1867,6 +1884,7 @@ torch.cuda.synchronize()
|
||||
x, (2, 2, 2), output_size=output_size, _random_samples=samples
|
||||
)
|
||||
|
||||
@expectedFailureMPS # Not implemented
|
||||
@dtypesIfCUDA(torch.half, torch.float, torch.double)
|
||||
@dtypes(torch.float)
|
||||
@onlyNativeDeviceTypes # TODO: Fails on XLA
|
||||
@ -1896,6 +1914,7 @@ torch.cuda.synchronize()
|
||||
res2.backward(torch.randn_like(res2))
|
||||
self.assertTrue(math.isinf(res2.item()))
|
||||
|
||||
@expectedFailureMPS # TODO: Fix me
|
||||
@onlyNativeDeviceTypes # TODO: RuntimeError message different on XLA
|
||||
def test_pooling_zero_stride(self, device):
|
||||
for op in ("max", "avg"):
|
||||
|
@ -5,6 +5,7 @@ from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_device_type import (
|
||||
deviceCountAtLeast,
|
||||
dtypes,
|
||||
dtypesIfMPS,
|
||||
instantiate_device_type_tests,
|
||||
onlyCPU,
|
||||
onlyCUDA,
|
||||
@ -13,10 +14,14 @@ from torch.testing._internal.common_device_type import (
|
||||
skipCUDAIfRocm,
|
||||
skipMeta,
|
||||
)
|
||||
from torch.testing._internal.common_dtype import all_types_and_complex_and
|
||||
from torch.testing._internal.common_dtype import (
|
||||
all_mps_types_and,
|
||||
all_types_and_complex_and,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_JETSON,
|
||||
run_tests,
|
||||
skipIfMPS,
|
||||
skipIfTorchDynamo,
|
||||
TestCase,
|
||||
)
|
||||
@ -55,6 +60,7 @@ class TestTorchDlPack(TestCase):
|
||||
torch.uint64,
|
||||
)
|
||||
)
|
||||
@dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf))
|
||||
def test_dlpack_capsule_conversion(self, device, dtype):
|
||||
x = make_tensor((5,), dtype=dtype, device=device)
|
||||
z = from_dlpack(to_dlpack(x))
|
||||
@ -72,6 +78,7 @@ class TestTorchDlPack(TestCase):
|
||||
torch.uint64,
|
||||
)
|
||||
)
|
||||
@dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf))
|
||||
def test_dlpack_protocol_conversion(self, device, dtype):
|
||||
x = make_tensor((5,), dtype=dtype, device=device)
|
||||
z = from_dlpack(x)
|
||||
@ -80,7 +87,8 @@ class TestTorchDlPack(TestCase):
|
||||
@skipMeta
|
||||
@onlyNativeDeviceTypes
|
||||
def test_dlpack_shared_storage(self, device):
|
||||
x = make_tensor((5,), dtype=torch.float64, device=device)
|
||||
dtype = torch.bfloat16 if device.startswith("mps") else torch.float64
|
||||
x = make_tensor((5,), dtype=dtype, device=device)
|
||||
z = from_dlpack(to_dlpack(x))
|
||||
z[0] = z[0] + 20.0
|
||||
self.assertEqual(z, x)
|
||||
@ -120,12 +128,14 @@ class TestTorchDlPack(TestCase):
|
||||
torch.uint64,
|
||||
)
|
||||
)
|
||||
@dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf))
|
||||
def test_from_dlpack(self, device, dtype):
|
||||
x = make_tensor((5,), dtype=dtype, device=device)
|
||||
y = torch.from_dlpack(x)
|
||||
self.assertEqual(x, y)
|
||||
|
||||
@skipMeta
|
||||
@skipIfMPS # MPS crashes with noncontiguous now
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(
|
||||
*all_types_and_complex_and(
|
||||
@ -189,6 +199,7 @@ class TestTorchDlPack(TestCase):
|
||||
torch.uint64,
|
||||
)
|
||||
)
|
||||
@dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf))
|
||||
def test_from_dlpack_dtype(self, device, dtype):
|
||||
x = make_tensor((5,), dtype=dtype, device=device)
|
||||
y = torch.from_dlpack(x)
|
||||
|
@ -16,6 +16,7 @@ from torch.testing._internal.common_device_type import (
|
||||
dtypesIfCPU,
|
||||
dtypesIfCUDA,
|
||||
dtypesIfMPS,
|
||||
expectedFailureMPS,
|
||||
instantiate_device_type_tests,
|
||||
onlyCUDA,
|
||||
onlyNativeDeviceTypes,
|
||||
@ -183,6 +184,7 @@ class TestIndexing(TestCase):
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(torch.half, torch.double)
|
||||
@dtypesIfMPS(torch.half) # TODO: add bf16 there?
|
||||
def test_advancedindex(self, device, dtype):
|
||||
# Tests for Integer Array Indexing, Part I - Purely integer array
|
||||
# indexing
|
||||
@ -1193,6 +1195,7 @@ class TestIndexing(TestCase):
|
||||
out_cpu = func1(t, ind, val)
|
||||
self.assertEqual(out_cuda.cpu(), out_cpu)
|
||||
|
||||
@expectedFailureMPS # Doubles not supported
|
||||
@onlyNativeDeviceTypes
|
||||
def test_index_put_accumulate_duplicate_indices(self, device):
|
||||
for i in range(1, 512):
|
||||
|
@ -8766,6 +8766,7 @@ class TestNNDeviceType(NNTestCase):
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(torch.float16, torch.bfloat16, torch.float32, torch.float64)
|
||||
@dtypesIfMPS(torch.float16, torch.bfloat16, torch.float32)
|
||||
def test_rmsnorm_epsilon(self, device, dtype):
|
||||
def rms_norm_reference_fn(i, normalized_shape):
|
||||
eps = torch.finfo(i.dtype).eps
|
||||
@ -8940,6 +8941,7 @@ class TestNNDeviceType(NNTestCase):
|
||||
Y_cpu = group_norm(X.cpu())
|
||||
self.assertEqual(Y_cpu, Y, rtol=0, atol=1e-5)
|
||||
|
||||
@expectedFailureMPS # Double is not supported on MPS
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(torch.float64, torch.complex128)
|
||||
def test_pad(self, device, dtype):
|
||||
@ -8971,6 +8973,7 @@ class TestNNDeviceType(NNTestCase):
|
||||
out.fill_(4)
|
||||
self.assertTrue(torch.all(torch.abs(inputs) < 2))
|
||||
|
||||
@expectedFailureMPS # Unsupported float64/complex128
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(torch.float64, torch.complex128)
|
||||
def test_ReplicationPad_empty(self, device, dtype):
|
||||
@ -9109,6 +9112,7 @@ class TestNNDeviceType(NNTestCase):
|
||||
self.assertEqual(inp1.grad, torch.zeros_like(inp1))
|
||||
self.assertEqual(inp2.grad, torch.zeros_like(inp2))
|
||||
|
||||
@expectedFailureMPS # Double not supported
|
||||
@expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
|
||||
@onlyNativeDeviceTypes
|
||||
def test_TransformerEncoderLayer_empty(self, device):
|
||||
@ -9138,6 +9142,7 @@ class TestNNDeviceType(NNTestCase):
|
||||
_test_module_empty_input(self, encoder_layer, input, check_size=False)
|
||||
|
||||
@expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
|
||||
@expectedFailureMPS # Float64 is not supported
|
||||
@onlyNativeDeviceTypes
|
||||
def test_TransformerEncoder_empty(self, device):
|
||||
for batch_first, input_shape in [(True, (0, 10, 512)),
|
||||
@ -9148,6 +9153,7 @@ class TestNNDeviceType(NNTestCase):
|
||||
_test_module_empty_input(self, transformer_encoder, input, check_size=False)
|
||||
|
||||
@expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
|
||||
@expectedFailureMPS # Float64 is not supported
|
||||
@onlyNativeDeviceTypes
|
||||
def test_TransformerDecoderLayer_empty(self, device):
|
||||
for batch_first, memory_shape, tgt_shape in [(True, (0, 10, 512), (0, 20, 512)),
|
||||
@ -9158,6 +9164,7 @@ class TestNNDeviceType(NNTestCase):
|
||||
self._test_module_empty_inputs(decoder_layer, [tgt, memory])
|
||||
|
||||
@expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
|
||||
@expectedFailureMPS # Float64 is not supported
|
||||
@onlyNativeDeviceTypes
|
||||
def test_TransformerDecoder_empty(self, device):
|
||||
for batch_first, memory_shape, tgt_shape in [(True, (0, 10, 512), (0, 20, 512)),
|
||||
@ -9169,6 +9176,7 @@ class TestNNDeviceType(NNTestCase):
|
||||
self._test_module_empty_inputs(transformer_decoder, [tgt, memory])
|
||||
|
||||
@expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
|
||||
@expectedFailureMPS # Float64 is not supported
|
||||
@onlyNativeDeviceTypes
|
||||
def test_Transformer_empty(self, device):
|
||||
for batch_first, src_shape, tgt_shape in [(True, (10, 0, 512), (20, 0, 512))]:
|
||||
@ -9304,6 +9312,7 @@ class TestNNDeviceType(NNTestCase):
|
||||
|
||||
self.assertEqual(x.grad, ref_x.grad)
|
||||
|
||||
@expectedFailureMPS # Unimplemented margin_loss
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(torch.float, torch.double)
|
||||
def test_MarginLoss_empty(self, device, dtype):
|
||||
@ -9370,6 +9379,7 @@ class TestNNDeviceType(NNTestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, 'Expected all tensors to be on the same device'):
|
||||
F.mse_loss(i, t)
|
||||
|
||||
@expectedFailureMPS # TODO: Fixme, and raise assert on empty tensor
|
||||
@onlyNativeDeviceTypes
|
||||
def test_Unfold_empty(self, device):
|
||||
inp = torch.randn(0, 3, 3, 4, device=device)
|
||||
@ -9593,6 +9603,7 @@ class TestNNDeviceType(NNTestCase):
|
||||
verify_reduction_scalars(input, reduction, output)
|
||||
|
||||
# verify that bogus reduction strings are errors
|
||||
@expectedFailureMPS # CTCLoss unimplemented
|
||||
@onlyNativeDeviceTypes
|
||||
def test_invalid_reduction_strings(self, device):
|
||||
input = torch.randn(3, 5, requires_grad=True, device=device)
|
||||
@ -10079,6 +10090,7 @@ class TestNNDeviceType(NNTestCase):
|
||||
@parametrize_test("align_corners", [True, False])
|
||||
@parametrize_test("mode", ["bilinear", "bicubic"])
|
||||
@parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
|
||||
@expectedFailureMPS # double device type
|
||||
@onlyNativeDeviceTypes
|
||||
def test_upsamplingBiMode2d(self, device, antialias, align_corners, mode, memory_format):
|
||||
# Forward AD does not support XLA because XLA tensors don't have storage
|
||||
@ -10148,6 +10160,7 @@ class TestNNDeviceType(NNTestCase):
|
||||
@parametrize_test("num_channels", [3, 5])
|
||||
@parametrize_test("mode", ["nearest", "nearest-exact", "bilinear", "bicubic"])
|
||||
@parametrize_test("dtype", integral_types() + floating_types())
|
||||
@skipIfMPS # Error message is wrong for some dtypes
|
||||
@onlyNativeDeviceTypes
|
||||
def test_upsamplingBiMode2d_nonsupported_dtypes(self, device, antialias, num_channels, mode, dtype):
|
||||
x = torch.ones(1, num_channels, 32, 32, dtype=dtype, device=device)
|
||||
@ -11470,6 +11483,7 @@ class TestNNDeviceType(NNTestCase):
|
||||
self.assertTrue(gradcheck(F.hardsigmoid, (inputs,)))
|
||||
|
||||
# currently fails on XLA
|
||||
@expectedFailureMPS # TypeError: the MPS framework doesn't support float64
|
||||
@onlyNativeDeviceTypes
|
||||
def test_hardswish_grad(self, device):
|
||||
inputs = (torch.randn(4, 16, 16, device=device, dtype=torch.double) - 0.5) * 10
|
||||
@ -11677,6 +11691,7 @@ class TestNNDeviceType(NNTestCase):
|
||||
self._test_batchnorm_simple_average(device, dtype, torch.float)
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@expectedFailureMPS # Unsupported Border padding mode
|
||||
@dtypes(torch.float, torch.double)
|
||||
def test_grid_sample_nan_inf(self, device, dtype):
|
||||
input = torch.zeros([1, 1, 3, 3], device=device, dtype=dtype)
|
||||
@ -12789,6 +12804,7 @@ if __name__ == '__main__':
|
||||
F.threshold(x, 0.5, 0.5, inplace=True)
|
||||
F.threshold_(x, 0.5, 0.5)
|
||||
|
||||
@expectedFailureMPS # Double is unsupported
|
||||
@onlyNativeDeviceTypes
|
||||
def test_triplet_margin_with_distance_loss_default_parity(self, device):
|
||||
# Test for `nn.TripletMarginWithDistanceLoss` and
|
||||
@ -12823,6 +12839,7 @@ if __name__ == '__main__':
|
||||
self.assertTrue(gradcheck(lambda a, p, n: loss_op(a, p, n),
|
||||
(anchor, positive, negative)))
|
||||
|
||||
@expectedFailureMPS # Double is unsupported
|
||||
@onlyNativeDeviceTypes
|
||||
def test_triplet_margin_with_distance_loss(self, device):
|
||||
# Test for parity between `nn.TripletMarginWithDistanceLoss` and
|
||||
|
@ -11,15 +11,16 @@ from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_device_type import (
|
||||
dtypes,
|
||||
dtypesIfMPS,
|
||||
expectedFailureMPS,
|
||||
instantiate_device_type_tests,
|
||||
onlyCPU,
|
||||
onlyNativeDeviceTypes,
|
||||
onlyNativeDeviceTypesAnd,
|
||||
skipLazy,
|
||||
skipMeta,
|
||||
skipXLA,
|
||||
)
|
||||
from torch.testing._internal.common_dtype import (
|
||||
all_mps_types_and,
|
||||
all_types_and,
|
||||
all_types_and_complex_and,
|
||||
complex_types,
|
||||
@ -157,8 +158,11 @@ class TestViewOps(TestCase):
|
||||
@skipIfTorchDynamo("TorchDynamo fails with unknown reason")
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
|
||||
@dtypesIfMPS(*integral_types_and(torch.cfloat, torch.float, torch.half, torch.bool))
|
||||
def test_view_dtype_new(self, device, dtype):
|
||||
dtypes = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
|
||||
if device.startswith("mps"):
|
||||
del dtypes[torch.float64]
|
||||
del dtypes[torch.bool]
|
||||
|
||||
def generate_inputs():
|
||||
@ -271,6 +275,7 @@ class TestViewOps(TestCase):
|
||||
# has a greater element size than the original dtype
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
|
||||
@dtypesIfMPS(*all_mps_types_and(torch.bool))
|
||||
def test_view_dtype_upsize_errors(self, device, dtype):
|
||||
dtype_size = torch._utils._element_size(dtype)
|
||||
|
||||
@ -372,6 +377,7 @@ class TestViewOps(TestCase):
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*complex_types(), torch.complex32)
|
||||
@dtypesIfMPS(torch.cfloat, torch.chalf)
|
||||
def test_view_as_real(self, device, dtype):
|
||||
def fn(contiguous_input=True):
|
||||
t = torch.randn(3, 4, dtype=dtype, device=device)
|
||||
@ -398,9 +404,7 @@ class TestViewOps(TestCase):
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
|
||||
@dtypesIfMPS(
|
||||
*integral_types_and(torch.half, torch.bfloat16, torch.bool, torch.float32)
|
||||
)
|
||||
@dtypesIfMPS(*all_mps_types_and(torch.bool))
|
||||
def test_view_tensor_split(self, device, dtype):
|
||||
a = make_tensor((40, 30), dtype=dtype, device=device, low=-9, high=9)
|
||||
a_split_dim0 = a.tensor_split(7, 0)
|
||||
@ -412,6 +416,7 @@ class TestViewOps(TestCase):
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
|
||||
@dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool))
|
||||
def test_view_tensor_hsplit(self, device, dtype):
|
||||
t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
|
||||
t_hsplit = torch.hsplit(t, 2)
|
||||
@ -422,6 +427,7 @@ class TestViewOps(TestCase):
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
|
||||
@dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool))
|
||||
def test_view_tensor_vsplit(self, device, dtype):
|
||||
t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
|
||||
t_vsplit = torch.vsplit(t, 2)
|
||||
@ -432,6 +438,7 @@ class TestViewOps(TestCase):
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
|
||||
@dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool))
|
||||
def test_view_tensor_dsplit(self, device, dtype):
|
||||
t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
|
||||
t_dsplit = torch.dsplit(t, 2)
|
||||
@ -440,9 +447,9 @@ class TestViewOps(TestCase):
|
||||
t[2, 2, 2] = 7
|
||||
self.assertEqual(t_dsplit[1][2, 2, 0], t[2, 2, 2])
|
||||
|
||||
@onlyNativeDeviceTypesAnd("mps")
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*all_types_and(torch.half, torch.bfloat16))
|
||||
@dtypesIfMPS(*integral_types_and(torch.half, torch.bool, torch.float32))
|
||||
@dtypesIfMPS(*all_mps_types_and(torch.bool))
|
||||
def test_imag_noncomplex(self, device, dtype):
|
||||
t = torch.ones((5, 5), dtype=dtype, device=device)
|
||||
|
||||
@ -451,6 +458,7 @@ class TestViewOps(TestCase):
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(*complex_types())
|
||||
@dtypesIfMPS(torch.cfloat)
|
||||
def test_real_imag_view(self, device, dtype):
|
||||
def compare_with_numpy(contiguous_input=True):
|
||||
t = torch.randn(3, 3, dtype=dtype, device=device)
|
||||
@ -481,6 +489,7 @@ class TestViewOps(TestCase):
|
||||
self.assertEqual(a[5:].imag, a.imag[5:])
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@expectedFailureMPS
|
||||
@dtypes(*complex_types())
|
||||
def test_conj_imag_view(self, device, dtype) -> None:
|
||||
t = _make_tensor((4, 5), dtype, device)
|
||||
@ -512,6 +521,12 @@ class TestViewOps(TestCase):
|
||||
all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
|
||||
)
|
||||
)
|
||||
@dtypesIfMPS(
|
||||
*product(
|
||||
[torch.cfloat, torch.chalf],
|
||||
all_mps_types_and(torch.cfloat, torch.chalf, torch.bool),
|
||||
)
|
||||
)
|
||||
@suppress_warnings
|
||||
def test_set_real_imag(self, device, dtypes):
|
||||
x = torch.randn(10, dtype=dtypes[0], device=device)
|
||||
|
@ -121,6 +121,19 @@ def all_types_and_half():
|
||||
return _all_types_and_half
|
||||
|
||||
|
||||
_all_mps_types = (
|
||||
_dispatch_dtypes({torch.float, torch.half, torch.bfloat16}) + _integral_types
|
||||
)
|
||||
|
||||
|
||||
def all_mps_types():
|
||||
return _all_mps_types
|
||||
|
||||
|
||||
def all_mps_types_and(*dtypes):
|
||||
return _all_mps_types + _validate_dtypes(*dtypes)
|
||||
|
||||
|
||||
_float8_types = _dispatch_dtypes(
|
||||
(
|
||||
torch.float8_e4m3fn,
|
||||
|
@ -297,7 +297,7 @@ if os.getenv("SLOW_TESTS_FILE", ""):
|
||||
if os.getenv("DISABLED_TESTS_FILE", ""):
|
||||
disabled_tests_dict = maybe_load_json(os.getenv("DISABLED_TESTS_FILE", ""))
|
||||
|
||||
NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', torch._C._get_privateuse1_backend_name())
|
||||
NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', 'mps', torch._C._get_privateuse1_backend_name())
|
||||
|
||||
# used for managing devices testing for torch profiler UTs
|
||||
# for now cpu, cuda and xpu are added for testing torch profiler UTs
|
||||
|
Reference in New Issue
Block a user