mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] sparse add unary funcs + add for sparse tensors (#160839)
Adds several unary functions and `add` for sparse MPS tensors. Enables the unary-function tests in test_sparse; the other tests are not enabled yet, because more ops are needed before SparseMPS testing can fully migrate to `test_sparse.py`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160839 Approved by: https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
ebfee60101
commit
8627a19adf
@ -12885,6 +12885,100 @@ class TestSparseMPS(TestCaseMPS):
|
||||
self.assertEqual(coalesced_mps._indices().cpu(), coalesced_cpu._indices())
|
||||
self.assertEqual(coalesced_mps._values().cpu(), coalesced_cpu._values())
|
||||
|
||||
def test_sparse_add(self):
    """Check torch.add with sparse COO operands on MPS against the CPU result.

    Scenarios, in order: dense + sparse with default, integer and float
    alpha; a sparse operand with no stored values; a sparse tensor whose
    values are 2D; float32 dense + float16 sparse; the size-mismatch
    error; and sparse + sparse with repeated indices.
    """

    def cpu_twin(sp):
        # Rebuild a sparse tensor on the CPU from its raw indices and values.
        return torch.sparse_coo_tensor(
            sp._indices().cpu(), sp._values().cpu(), sp.size(), device="cpu"
        )

    def check_add(lhs_mps, rhs_mps, lhs_cpu, rhs_cpu, **kwargs):
        # MPS add runs first, then the CPU reference; compare on the CPU.
        got = torch.add(lhs_mps, rhs_mps, **kwargs)
        want = torch.add(lhs_cpu, rhs_cpu, **kwargs)
        self.assertEqual(got.cpu(), want)

    # Shared operands: an all-zero dense tensor and the basic sparse fixture.
    dense_mps = torch.zeros((2, 3), device="mps", dtype=torch.float32)
    sparse_mps = self._get_basic_sparse_coo(device="mps")
    dense_cpu = dense_mps.cpu()
    sparse_cpu = cpu_twin(sparse_mps)

    # Plain dense + sparse, then the same with an integer alpha.
    check_add(dense_mps, sparse_mps, dense_cpu, sparse_cpu)
    check_add(dense_mps, sparse_mps, dense_cpu, sparse_cpu, alpha=2)

    # Float alpha against a randomly filled dense operand.
    dense2_mps = torch.randn((2, 3), device="mps", dtype=torch.float32)
    dense2_cpu = dense2_mps.cpu()
    check_add(dense2_mps, sparse_mps, dense2_cpu, sparse_cpu, alpha=0.5)

    # Sparse operand with no stored values (nnz == 0).
    empty_indices_mps = torch.zeros((2, 0), dtype=torch.int64, device="mps")
    empty_values_mps = torch.tensor([], dtype=torch.float32, device="mps")
    empty_sparse_mps = torch.sparse_coo_tensor(
        empty_indices_mps, empty_values_mps, (2, 3), device="mps"
    )
    empty_sparse_cpu = torch.sparse_coo_tensor(
        empty_indices_mps.cpu(), empty_values_mps.cpu(), (2, 3), device="cpu"
    )
    check_add(dense2_mps, empty_sparse_mps, dense2_cpu, empty_sparse_cpu)

    # 3D tensor with 2D values, meant to exercise the view_cols > 1 path.
    size3 = (2, 3, 4)
    indices3_mps = torch.tensor([[0, 1], [2, 0]], dtype=torch.int64, device="mps")
    values3_mps = torch.tensor(
        [[1., 2., 3., 4.], [5., 6., 7., 8.]], dtype=torch.float32, device="mps"
    )
    sp3_mps = torch.sparse_coo_tensor(indices3_mps, values3_mps, size3, device="mps")
    dense3_mps = torch.randn(size3, device="mps", dtype=torch.float32)
    sp3_cpu = torch.sparse_coo_tensor(
        indices3_mps.cpu(), values3_mps.cpu(), size3, device="cpu"
    )
    dense3_cpu = dense3_mps.cpu()
    check_add(dense3_mps, sp3_mps, dense3_cpu, sp3_cpu, alpha=1.0)

    # Mixed dtypes: float32 dense operand with float16 sparse values.
    half_values_mps = sparse_mps._values().to(torch.float16)
    sparse_f16_mps = torch.sparse_coo_tensor(
        sparse_mps._indices(), half_values_mps, sparse_mps.size(), device="mps"
    )
    sparse_f16_cpu = cpu_twin(sparse_f16_mps)
    check_add(dense2_mps, sparse_f16_mps, dense2_cpu, sparse_f16_cpu, alpha=0.25)

    # Operands of different sizes are rejected rather than broadcast.
    bad_sparse_mps = torch.sparse_coo_tensor(
        sparse_mps._indices(), sparse_mps._values(), (2, 4), device="mps"
    )
    with self.assertRaisesRegex(RuntimeError, "same size"):
        torch.add(dense_mps, bad_sparse_mps)

    # Sparse + sparse: both inputs carry a repeated index and they share
    # positions with each other, so the coalesced sums must agree.
    s1_idx = torch.tensor([[0, 0, 1], [0, 0, 2]], dtype=torch.int64)
    s1_val = torch.tensor([1., 2., 3.], dtype=torch.float32)
    s2_idx = torch.tensor([[0, 1, 1], [0, 2, 2]], dtype=torch.int64)
    s2_val = torch.tensor([4., 5., 6.], dtype=torch.float32)

    shape = (2, 3)
    s1_mps = torch.sparse_coo_tensor(s1_idx.to("mps"), s1_val.to("mps"), shape, device="mps")
    s2_mps = torch.sparse_coo_tensor(s2_idx.to("mps"), s2_val.to("mps"), shape, device="mps")
    s1_cpu = torch.sparse_coo_tensor(s1_idx, s1_val, shape, device="cpu")
    s2_cpu = torch.sparse_coo_tensor(s2_idx, s2_val, shape, device="cpu")

    sp_res_mps = torch.add(s1_mps, s2_mps, alpha=2.0).coalesce()
    sp_res_cpu = torch.add(s1_cpu, s2_cpu, alpha=2.0).coalesce()
    self.assertEqual(sp_res_mps.cpu(), sp_res_cpu)
|
||||
|
||||
|
||||
# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
|
||||
# This requires mps to be properly registered in the device generic test framework which is not the
|
||||
|
Reference in New Issue
Block a user