[MPS] sparse add unary funcs + add for sparse tensors (#160839)

Adds several unary functions and `add` for sparse MPS tensors. Enables the unary-function tests in `test_sparse`, but not other tests yet; more ops are needed before we can fully migrate SparseMPS testing to `test_sparse.py`.
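For context, a minimal usage sketch of the kind of call this enables (illustrative only, not taken from the diff; assumes an MPS-capable build):

```python
import torch

# Hypothetical example: dense + alpha * sparse on the MPS backend.
if torch.backends.mps.is_available():
    indices = torch.tensor([[0, 1], [2, 0]], dtype=torch.int64, device="mps")
    values = torch.tensor([1.0, 2.0], dtype=torch.float32, device="mps")
    sparse = torch.sparse_coo_tensor(indices, values, (2, 3), device="mps")
    dense = torch.randn((2, 3), device="mps", dtype=torch.float32)
    out = torch.add(dense, sparse, alpha=2.0)  # dense + 2 * sparse
```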

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160839
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Author: Irakli Salia
Date: 2025-08-30 01:08:57 +00:00
Committed by: PyTorch MergeBot
Parent: ebfee60101
Commit: 8627a19adf

10 changed files with 465 additions and 83 deletions


@@ -12885,6 +12885,100 @@ class TestSparseMPS(TestCaseMPS):
        self.assertEqual(coalesced_mps._indices().cpu(), coalesced_cpu._indices())
        self.assertEqual(coalesced_mps._values().cpu(), coalesced_cpu._values())

    def test_sparse_add(self):
        # Basic dense + sparse add
        dense_mps = torch.zeros((2, 3), device="mps", dtype=torch.float32)
        sparse_mps = self._get_basic_sparse_coo(device="mps")
        dense_cpu = dense_mps.cpu()
        sparse_cpu = torch.sparse_coo_tensor(
            sparse_mps._indices().cpu(), sparse_mps._values().cpu(), sparse_mps.size(), device="cpu"
        )
        res_mps = torch.add(dense_mps, sparse_mps)
        res_cpu = torch.add(dense_cpu, sparse_cpu)
        self.assertEqual(res_mps.cpu(), res_cpu)
        # alpha scaling (integral alpha)
        res_mps = torch.add(dense_mps, sparse_mps, alpha=2)
        res_cpu = torch.add(dense_cpu, sparse_cpu, alpha=2)
        self.assertEqual(res_mps.cpu(), res_cpu)
        # alpha scaling (float alpha) with random dense
        dense2_mps = torch.randn((2, 3), device="mps", dtype=torch.float32)
        dense2_cpu = dense2_mps.cpu()
        res_mps = torch.add(dense2_mps, sparse_mps, alpha=0.5)
        res_cpu = torch.add(dense2_cpu, sparse_cpu, alpha=0.5)
        self.assertEqual(res_mps.cpu(), res_cpu)
        # nnz == 0 fast-path
        empty_indices_mps = torch.zeros((2, 0), dtype=torch.int64, device="mps")
        empty_values_mps = torch.tensor([], dtype=torch.float32, device="mps")
        empty_sparse_mps = torch.sparse_coo_tensor(empty_indices_mps, empty_values_mps, (2, 3), device="mps")
        empty_indices_cpu = empty_indices_mps.cpu()
        empty_values_cpu = empty_values_mps.cpu()
        empty_sparse_cpu = torch.sparse_coo_tensor(empty_indices_cpu, empty_values_cpu, (2, 3), device="cpu")
        res_mps = torch.add(dense2_mps, empty_sparse_mps)
        res_cpu = torch.add(dense2_cpu, empty_sparse_cpu)
        self.assertEqual(res_mps.cpu(), res_cpu)
        # 3D case to exercise view_cols > 1 path (values are 2D)
        indices3_mps = torch.tensor([[0, 1], [2, 0]], dtype=torch.int64, device="mps")
        values3_mps = torch.tensor([[1., 2., 3., 4.], [5., 6., 7., 8.]], dtype=torch.float32, device="mps")
        size3 = (2, 3, 4)
        sp3_mps = torch.sparse_coo_tensor(indices3_mps, values3_mps, size3, device="mps")
        dense3_mps = torch.randn(size3, device="mps", dtype=torch.float32)
        indices3_cpu = indices3_mps.cpu()
        values3_cpu = values3_mps.cpu()
        sp3_cpu = torch.sparse_coo_tensor(indices3_cpu, values3_cpu, size3, device="cpu")
        dense3_cpu = dense3_mps.cpu()
        res_mps = torch.add(dense3_mps, sp3_mps, alpha=1.0)
        res_cpu = torch.add(dense3_cpu, sp3_cpu, alpha=1.0)
        self.assertEqual(res_mps.cpu(), res_cpu)
        # dtype promotion: dense float32 + sparse float16
        sparse_f16_mps = torch.sparse_coo_tensor(
            sparse_mps._indices(),
            sparse_mps._values().to(torch.float16),
            sparse_mps.size(),
            device="mps",
        )
        sparse_f16_cpu = torch.sparse_coo_tensor(
            sparse_f16_mps._indices().cpu(),
            sparse_f16_mps._values().cpu(),
            sparse_f16_mps.size(),
            device="cpu",
        )
        res_mps = torch.add(dense2_mps, sparse_f16_mps, alpha=0.25)
        res_cpu = torch.add(dense2_cpu, sparse_f16_cpu, alpha=0.25)
        self.assertEqual(res_mps.cpu(), res_cpu)
        # broadcasting not supported: mismatched size should error
        bad_sparse_mps = torch.sparse_coo_tensor(
            sparse_mps._indices(), sparse_mps._values(), (2, 4), device="mps"
        )
        with self.assertRaisesRegex(RuntimeError, "same size"):
            torch.add(dense_mps, bad_sparse_mps)
        # sparse + sparse with overlap (tests concatenation + coalesce + alpha)
        s1_idx = torch.tensor([[0, 0, 1], [0, 0, 2]], dtype=torch.int64)
        s1_val = torch.tensor([1., 2., 3.], dtype=torch.float32)
        s2_idx = torch.tensor([[0, 1, 1], [0, 2, 2]], dtype=torch.int64)
        s2_val = torch.tensor([4., 5., 6.], dtype=torch.float32)
        s1_mps = torch.sparse_coo_tensor(s1_idx.to("mps"), s1_val.to("mps"), (2, 3), device="mps")
        s2_mps = torch.sparse_coo_tensor(s2_idx.to("mps"), s2_val.to("mps"), (2, 3), device="mps")
        s1_cpu = torch.sparse_coo_tensor(s1_idx, s1_val, (2, 3), device="cpu")
        s2_cpu = torch.sparse_coo_tensor(s2_idx, s2_val, (2, 3), device="cpu")
        sp_res_mps = torch.add(s1_mps, s2_mps, alpha=2.0).coalesce()
        sp_res_cpu = torch.add(s1_cpu, s2_cpu, alpha=2.0).coalesce()
        self.assertEqual(sp_res_mps.cpu(), sp_res_cpu)
# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
# This requires mps to be properly registered in the device generic test framework which is not the
# case right now.
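Beyond `add`, the commit message mentions several unary functions, whose tests are not visible in this hunk. A minimal sketch of the kind of CPU-parity check this enables (the exact op set is not shown here; `abs` is assumed to be among the enabled ops):

```python
import torch

if torch.backends.mps.is_available():
    idx = torch.tensor([[0, 1], [2, 0]], dtype=torch.int64)
    val = torch.tensor([-1.5, 2.0], dtype=torch.float32)
    sp_cpu = torch.sparse_coo_tensor(idx, val, (2, 3))
    sp_mps = torch.sparse_coo_tensor(idx.to("mps"), val.to("mps"), (2, 3))
    # Unary ops on sparse COO tensors apply to the values; compare with CPU.
    assert torch.equal(torch.abs(sp_mps).cpu().to_dense(), torch.abs(sp_cpu).to_dense())
```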