[BE] Add torch.ops.aten._sparse_compressed_tensor_with_dims (#123083)

Used in https://github.com/pytorch/pytorch/pull/123084 and allows simplifying `empty_like` implementation for sparse compressed tensors (see https://github.com/pytorch/pytorch/pull/121900#issuecomment-2029835473).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123083
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Pearu Peterson
2024-04-02 00:24:47 +03:00
committed by PyTorch MergeBot
parent f9b2ffa7c4
commit 72662bf05b
5 changed files with 138 additions and 48 deletions

View File

@ -336,6 +336,46 @@ class TestSparseCompressed(TestCase):
", but got size"):
torch.empty((5,), dtype=dtype, device=device, layout=layout)
@skipMeta
@all_sparse_compressed_layouts()
@dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half))
def test_sparse_compressed_tensor_with_dims(self, layout, device, dtype):
def get_sparse_compressed_tensor_properties(s):
if layout in {torch.sparse_csr, torch.sparse_bsr}:
compressed_indices, plain_indices = s.crow_indices(), s.col_indices()
else:
compressed_indices, plain_indices = s.ccol_indices(), s.row_indices()
values = s.values()
return dict(shape=s.shape, dtype=s.dtype, device=s.device, nnz=s._nnz(), layout=s.layout,
compressed_indices_shape=compressed_indices.shape,
compressed_indices_dtype=compressed_indices.dtype,
compressed_indices_device=compressed_indices.device,
plain_indices_shape=plain_indices.shape,
plain_indices_dtype=plain_indices.dtype,
plain_indices_device=plain_indices.device,
values_shape=values.shape,
values_dtype=values.dtype,
values_device=values.device)
for index_dtype in [torch.int32, torch.int64]:
for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
dense_dim = t.dense_dim()
sparse_dim = t.sparse_dim()
batch_dim = t.ndim - sparse_dim - dense_dim
nnz = t.values().shape[batch_dim]
if layout in {torch.sparse_bsr, torch.sparse_bsc}:
blocksize = t.values().shape[batch_dim + 1: batch_dim + 1 + sparse_dim]
else:
blocksize = ()
e = torch.ops.aten._sparse_compressed_tensor_with_dims(nnz, dense_dim, t.shape, blocksize, index_dtype,
dtype=dtype, layout=layout, device=device)
e_prop, t_prop = get_sparse_compressed_tensor_properties(e), get_sparse_compressed_tensor_properties(t)
for k, v in e_prop.items():
self.assertEqual(v, t_prop[k], lambda msg: f'{msg} when comparing {k}, expected {t_prop[k]}, got {v}')
@skipMeta
@all_sparse_compressed_layouts()
@dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))