[MPS] zeros like, narrow and enable tests (#163011)

Adds zeros_like and narrow support for SparseMPS and enables the corresponding sparse tests on MPS.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163011
Approved by: https://github.com/malfet
Author: Isalia20
Date: 2025-09-16 17:48:02 +00:00
Committed by: PyTorch MergeBot
Parent: 559e8d1c20
Commit: 6db37d7206
2 changed files with 10 additions and 9 deletions
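With the two dispatch additions below, sparse COO tensors on the MPS backend gain working zeros_like and narrow paths. A minimal sketch of what that looks like from user code, assuming an MPS-capable macOS build (illustrative only, not part of the diff):

import torch

# Illustrative only: requires an MPS-capable build; guarded so it is a
# no-op elsewhere.
if torch.backends.mps.is_available():
    indices = torch.tensor([[0, 1, 2], [0, 1, 2]], device="mps")
    values = torch.tensor([1.0, 2.0, 3.0], device="mps")
    x = torch.sparse_coo_tensor(indices, values, (3, 3))

    z = torch.zeros_like(x)             # SparseMPS now has a sparse zeros kernel (zeros_sparse_out)
    n = torch.narrow_copy(x, 0, 0, 2)   # SparseMPS now routes to narrow_copy_sparse
    print(z.shape, n.shape)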


@@ -4372,7 +4372,7 @@
variants: function, method
dispatch:
CPU: narrow_copy_dense_cpu
-SparseCPU, SparseCUDA: narrow_copy_sparse
+SparseCPU, SparseCUDA, SparseMPS: narrow_copy_sparse
CompositeExplicitAutogradNonFunctional: narrow_copy_dense_symint
tags: view_copy
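For reference, narrow_copy_sparse slices a sparse tensor along one dimension and returns a new sparse tensor holding only the stored entries that fall in the requested range, matching a dense narrow. A small device-agnostic sketch of that invariant (shown on CPU so it runs anywhere):

import torch

# narrow_copy on a sparse COO tensor keeps only entries whose index along
# `dim` falls in [start, start + length), matching dense narrow.
x = torch.eye(4).to_sparse()
n = x.narrow_copy(0, 1, 2)  # rows 1 and 2
assert torch.equal(n.to_dense(), x.to_dense().narrow(0, 1, 2))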
@@ -6660,7 +6660,7 @@
- func: zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CompositeExplicitAutograd: zeros_out
-SparseCPU, SparseCUDA, SparseMeta: zeros_sparse_out
+SparseCPU, SparseCUDA, SparseMPS, SparseMeta: zeros_sparse_out
- func: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
dispatch:

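zeros_like on a sparse COO input produces an empty (zero-nnz) sparse tensor of the same shape and layout, which is what zeros_sparse_out provides; registering it for SparseMPS is what lets test_zeros_like run on MPS. A small CPU sketch of the invariant:

import torch

# zeros_like of a sparse tensor is a sparse tensor of the same shape with
# no stored (nonzero) values.
x = torch.eye(3).to_sparse()
z = torch.zeros_like(x)
assert z.layout == torch.sparse_coo
assert z.shape == x.shape
assert z._nnz() == 0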

@@ -479,8 +479,8 @@ class TestSparse(TestSparseBase):
"cannot set is_coalesced to true if indices correspond to uncoalesced COO tensor"):
torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), shape, True))
-@expectedFailureMPS
@dtypes(*floating_and_complex_types_and(torch.float16, torch.bfloat16))
+@dtypesIfMPS(*all_mps_types())
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error")
@gradcheck_semantics()
def test_to_dense_with_gradcheck(self, device, dtype, gradcheck):
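The dtype substitutions throughout this file exist because the MPS backend has no float64/complex128 support; @dtypesIfMPS overrides the @dtypes list when a test is instantiated for the mps device. A guarded sketch of the underlying constraint (illustrative; the exact exception type may vary across versions):

import torch

# MPS does not support 64-bit floating point, which is why the MPS dtype
# lists below use float32/complex64. Guarded so this is a no-op on
# machines without an MPS device.
if torch.backends.mps.is_available():
    try:
        torch.zeros(2, dtype=torch.float64, device="mps")
    except Exception as e:  # exact exception type may vary by version
        print("float64 rejected on MPS:", e)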
@@ -505,7 +505,8 @@ class TestSparse(TestSparseBase):
x.requires_grad_(True)
gradcheck(fn, (x,))
-for value_type in [torch.double, torch.cdouble]:
+values_types = [torch.double, torch.cdouble] if device != "mps:0" else [torch.float32, torch.complex64]
+for value_type in values_types:
i = self.index_tensor([
[0, 1, 2, 2],
[0, 0, 0, 3],
@@ -859,8 +860,8 @@ class TestSparse(TestSparseBase):
test_shape(3, 0, [0, 0, 100, 5, 5, 5, 0])
@coalescedonoff
-@expectedFailureMPS
@dtypes(torch.double, torch.cdouble, torch.bfloat16)
+@dtypesIfMPS(torch.float32, torch.complex64, torch.bfloat16)
@precisionOverride({torch.bfloat16: 2e-2})
def test_Sparse_to_Sparse_copy_(self, device, dtype, coalesced):
# This is for testing torch.copy_(SparseTensor, SparseTensor)
@@ -883,7 +884,7 @@ class TestSparse(TestSparseBase):
x1.copy_(x2)
self.assertEqual(x1_dtype, x1.dtype)
-x2 = x2.to(torch.float64)
+x2 = x2.to(torch.float64) if device != "mps:0" else x2.to(torch.float32)
x1_dtype = x1.dtype
x1.copy_(x2)
self.assertEqual(x1_dtype, x1.dtype)
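The invariant this test checks is that Tensor.copy_ keeps the destination's dtype, including for sparse COO tensors. A small CPU sketch (hypothetical tensors, not the test's own fixtures):

import torch

# copy_ converts the source values but preserves the destination dtype.
x1 = torch.sparse_coo_tensor([[0, 1]], [1.0, 2.0], (3,), dtype=torch.float32)
x2 = torch.sparse_coo_tensor([[0, 2]], [3.0, 4.0], (3,), dtype=torch.float64)
x1.copy_(x2)
assert x1.dtype == torch.float32  # destination dtype is preserved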
@@ -2275,8 +2276,8 @@ class TestSparse(TestSparseBase):
test_shape([2, 3, 4], [0, 4, 5, 6], [2, 3, 0], [9, 12])
@coalescedonoff
-@expectedFailureMPS
@dtypes(torch.double, torch.cdouble)
+@dtypesIfMPS(torch.float32, torch.complex64)
def test_zeros_like(self, device, dtype, coalesced):
def _test_zeros_like(nnzs, template_shape_i, template_shape_v=None):
template_shape_v = template_shape_v or []
@@ -2416,8 +2417,8 @@ class TestSparse(TestSparseBase):
yield [dim, start, length]
@coalescedonoff
-@expectedFailureMPS
@dtypes(torch.double, torch.cdouble)
+@dtypesIfMPS(torch.float32, torch.complex64)
def test_narrow(self, device, dtype, coalesced):
shape = [3, 3, 4, 2]
input, _, _ = self._gen_sparse(4, 19, shape, dtype, device, coalesced)
@@ -3278,8 +3279,8 @@ class TestSparse(TestSparseBase):
self.assertEqual(list(t.coalesce().values().size()), [1, 3])
@coalescedonoff
-@expectedFailureMPS
@dtypes(torch.double)
+@dtypesIfMPS(torch.float32)
def test_pickle(self, device, dtype, coalesced):
import pickle
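test_pickle round-trips a sparse tensor through pickle; this PR enables it on MPS with float32. A minimal CPU sketch of the same round-trip (the enabled test runs it on the mps device):

import pickle
import torch

# Serialize and deserialize a sparse COO tensor, then compare contents.
x = torch.sparse_coo_tensor([[0, 2], [1, 0]], [1.0, 2.0], (3, 3)).coalesce()
y = pickle.loads(pickle.dumps(x))
assert torch.equal(x.to_dense(), y.to_dense())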