Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00.
[BE] Add torch.ops.aten._sparse_compressed_tensor_with_dims (#123083)
Used in https://github.com/pytorch/pytorch/pull/123084 and allows simplifying `empty_like` implementation for sparse compressed tensors (see https://github.com/pytorch/pytorch/pull/121900#issuecomment-2029835473). Pull Request resolved: https://github.com/pytorch/pytorch/pull/123083 Approved by: https://github.com/cpuhrsch
This commit is contained in the repository; it was committed by PyTorch MergeBot.
parent
f9b2ffa7c4
commit
72662bf05b
@ -336,6 +336,46 @@ class TestSparseCompressed(TestCase):
|
||||
", but got size"):
|
||||
torch.empty((5,), dtype=dtype, device=device, layout=layout)
|
||||
|
||||
@skipMeta
@all_sparse_compressed_layouts()
@dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half))
def test_sparse_compressed_tensor_with_dims(self, layout, device, dtype):
    """Exercise torch.ops.aten._sparse_compressed_tensor_with_dims.

    For each sample produced by generate_simple_inputs, construct an
    uninitialized sparse compressed tensor with the same dimensions via the
    new op and verify that every structural property (shape, dtype, device,
    nnz, layout, plus the shape/dtype/device of the compressed indices,
    plain indices, and values) matches the reference sample.
    """

    def get_sparse_compressed_tensor_properties(s):
        # CSR/BSR layouts expose row-compressed accessors; CSC/BSC expose
        # the column-compressed counterparts.
        row_major = layout in {torch.sparse_csr, torch.sparse_bsr}
        compressed_indices = s.crow_indices() if row_major else s.ccol_indices()
        plain_indices = s.col_indices() if row_major else s.row_indices()
        values = s.values()
        return dict(shape=s.shape, dtype=s.dtype, device=s.device, nnz=s._nnz(), layout=s.layout,
                    compressed_indices_shape=compressed_indices.shape,
                    compressed_indices_dtype=compressed_indices.dtype,
                    compressed_indices_device=compressed_indices.device,
                    plain_indices_shape=plain_indices.shape,
                    plain_indices_dtype=plain_indices.dtype,
                    plain_indices_device=plain_indices.device,
                    values_shape=values.shape,
                    values_dtype=values.dtype,
                    values_device=values.device)

    for index_dtype in (torch.int32, torch.int64):
        for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
            dense_dim = t.dense_dim()
            sparse_dim = t.sparse_dim()
            batch_dim = t.ndim - sparse_dim - dense_dim
            # In the values tensor, batch dims come first, then the nnz axis.
            nnz = t.values().shape[batch_dim]
            # Only blocked layouts carry a blocksize, read off the values shape.
            blocked = layout in {torch.sparse_bsr, torch.sparse_bsc}
            blocksize = t.values().shape[batch_dim + 1: batch_dim + 1 + sparse_dim] if blocked else ()

            e = torch.ops.aten._sparse_compressed_tensor_with_dims(nnz, dense_dim, t.shape, blocksize, index_dtype,
                                                                   dtype=dtype, layout=layout, device=device)

            e_prop, t_prop = get_sparse_compressed_tensor_properties(e), get_sparse_compressed_tensor_properties(t)
            for k, v in e_prop.items():
                self.assertEqual(v, t_prop[k], lambda msg: f'{msg} when comparing {k}, expected {t_prop[k]}, got {v}')
||||
@skipMeta
|
||||
@all_sparse_compressed_layouts()
|
||||
@dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
|
||||
|
Reference in New Issue
Block a user