[MPS] mps sparse mul op implementation (#162349)

Implements mps sparse mul operation as well as enables other operations such as:
1. copy_
2. div
3. sum
4. floor
5. power
6. sub
7. floor_divide

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162349
Approved by: https://github.com/pearu, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
Isalia20
2025-09-09 15:45:34 +00:00
committed by PyTorch MergeBot
parent be3b8d2ec9
commit 3ea6868049
4 changed files with 483 additions and 25 deletions

View File

@ -1108,8 +1108,8 @@ class TestSparse(TestSparseBase):
test_shape(2, 20, [3, 17, 19, 5])
test_shape(2, 20, [3, 17, 19, 0])
@expectedFailureMPS
@dtypes(torch.double, torch.cdouble)
@dtypesIfMPS(torch.float32, torch.complex64)
def test_add_sub_nnz(self, device, dtype):
# nnz should not grow unbounded (gh-34964)
x = torch.randn(10, dtype=dtype, device=device).to_sparse()
@ -1687,8 +1687,8 @@ class TestSparse(TestSparseBase):
test_shape(7, 8, 9, 20, True)
@coalescedonoff
@expectedFailureMPS
@dtypes(torch.double)
@dtypesIfMPS(torch.float32)
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error")
@gradcheck_semantics()
def test_sparse_mul(self, device, dtype, coalesced, gradcheck):
@ -1868,8 +1868,8 @@ class TestSparse(TestSparseBase):
x.norm(**kwargs)
@coalescedonoff
@expectedFailureMPS
@dtypes(torch.double)
@dtypesIfMPS(torch.float32)
@unittest.skipIf(TEST_WITH_CROSSREF, "fallback triggers cuda device error")
def test_sparse_sum(self, device, dtype, coalesced):
@ -1933,7 +1933,6 @@ class TestSparse(TestSparseBase):
S = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
run_tests(S.requires_grad_(True), test_dim)
@expectedFailureMPS
def _test_basic_ops_shape(self, nnz_x1, nnz_x2, shape_i, shape_v, dtype, device, coalesced):
shape = shape_i + (shape_v)
x1, _, _ = self._gen_sparse(len(shape_i), nnz_x1, shape, dtype, device, coalesced)
@ -2011,6 +2010,7 @@ class TestSparse(TestSparseBase):
@coalescedonoff
@dtypes(torch.double)
@dtypesIfMPS(torch.float32)
def test_basic_ops(self, device, dtype, coalesced):
def _test_basic_ops():