mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[SparseCompressed] support csc layout for add sparse/dense. (#115433)
`add` when passed one sparse and one dense argument will error if the sparse argument does not have csr layout. This PR modifies the underlying algorithm to be generic on the compressed dimension handling both csr and csc. The functions are renamed to use the `sparse_compressed` qualifier rather than `sparse_csr` Fixes: #114807 Pull Request resolved: https://github.com/pytorch/pytorch/pull/115433 Approved by: https://github.com/cpuhrsch, https://github.com/pearu ghstack dependencies: #115432
This commit is contained in:
committed by
PyTorch MergeBot
parent
910baa3a03
commit
4b97ed2ed8
@ -2109,13 +2109,19 @@ class TestSparseCSR(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, re.escape(str(msg))):
|
||||
test(is_sparse=True)
|
||||
|
||||
@sparse_compressed_nonblock_layouts()
|
||||
@dtypes(torch.float, torch.double)
|
||||
def test_add(self, device, dtype):
|
||||
def test_add(self, device, layout, dtype):
|
||||
def _test_spadd_shape(nnz, shape):
|
||||
# sparse.to_dense() uses torch.add internally so if torch.add is wrong,
|
||||
# the dense tensor will be wrong but this test would still pass
|
||||
# there's a separate test that checks for the correctness of the .to_dense() call
|
||||
x = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
|
||||
x = self.genSparseCompressedTensor(shape, nnz,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
index_dtype=torch.int32,
|
||||
layout=layout,
|
||||
blocksize=())
|
||||
y = torch.randn(*shape, dtype=dtype, device=device)
|
||||
r = random.random()
|
||||
|
||||
@ -2140,6 +2146,7 @@ class TestSparseCSR(TestCase):
|
||||
self.assertEqual(res, expected)
|
||||
self.assertEqual(res_perm, expected)
|
||||
|
||||
|
||||
ns = [2, 5]
|
||||
batch_shapes = [(), (2,), (2, 3)]
|
||||
for b, m, n in itertools.product(batch_shapes, ns, ns):
|
||||
|
||||
Reference in New Issue
Block a user