[BE] Add torch.ops.aten._sparse_compressed_tensor_with_dims (#123083)

Used in https://github.com/pytorch/pytorch/pull/123084 and allows simplifying `empty_like` implementation for sparse compressed tensors (see https://github.com/pytorch/pytorch/pull/121900#issuecomment-2029835473).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123083
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Pearu Peterson
2024-04-02 00:24:47 +03:00
committed by PyTorch MergeBot
parent f9b2ffa7c4
commit 72662bf05b
5 changed files with 138 additions and 48 deletions

View File

@ -892,62 +892,33 @@ class MetaConverter:
elif is_sparse_compressed_layout(t.layout):
is_leaf = t.is_leaf
def mk_meta():
if t.layout in {torch.sparse_bsr, torch.sparse_bsc}:
assert t.sparse_dim is not None
assert t.dense_dim is not None
nnz = 0
batch_dim = t.ndim - t.sparse_dim - t.dense_dim
batch_size = t.shape[:batch_dim]
if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
assert t.crow_indices is not None
assert t.col_indices is not None
index_dtype = t.crow_indices.dtype
compressed_indices = torch.empty(
t.crow_indices.shape, device="meta", dtype=index_dtype
)
plain_indices = torch.empty(
(*t.col_indices.shape[:-1], nnz),
device="meta",
dtype=index_dtype,
)
else:
assert t.ccol_indices is not None
assert t.row_indices is not None
index_dtype = t.ccol_indices.dtype
compressed_indices = torch.empty(
t.ccol_indices.shape, device="meta", dtype=index_dtype
)
plain_indices = torch.empty(
(*t.row_indices.shape[:-1], nnz),
device="meta",
dtype=index_dtype,
)
assert t.values is not None
values_shape = t.values.shape
values = torch.empty(
(
*values_shape[:batch_dim],
nnz,
*values_shape[batch_dim + 1 :],
),
dtype=t.dtype,
device="meta",
)
return torch.ops.aten.sparse_compressed_tensor(
compressed_indices,
plain_indices,
values,
batch_dim = t.ndim - t.sparse_dim - t.dense_dim
blocksize = t.values.shape[batch_dim + 1 : batch_dim + 3]
else:
blocksize = ()
if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
assert t.crow_indices is not None
index_dtype = t.crow_indices.dtype
else:
assert t.ccol_indices is not None
index_dtype = t.ccol_indices.dtype
r = callback(
lambda: torch.ops.aten._sparse_compressed_tensor_with_dims(
0,
t.dense_dim,
t.shape,
blocksize,
index_dtype,
layout=t.layout,
dtype=t.dtype,
device="meta",
)
# `mk_meta()` is similar to `t.to(device='meta'))`
# except `to('meta')` preserves nnz value while
# `mk_meta` result has nnz == 0.
r = callback(mk_meta)
)
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
if t.requires_grad:
r.requires_grad = True