mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[BE] Add torch.ops.aten._sparse_compressed_tensor_with_dims (#123083)
Used in https://github.com/pytorch/pytorch/pull/123084 and allows simplifying `empty_like` implementation for sparse compressed tensors (see https://github.com/pytorch/pytorch/pull/121900#issuecomment-2029835473). Pull Request resolved: https://github.com/pytorch/pytorch/pull/123083 Approved by: https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
f9b2ffa7c4
commit
72662bf05b
@ -892,62 +892,33 @@ class MetaConverter:
|
||||
elif is_sparse_compressed_layout(t.layout):
|
||||
is_leaf = t.is_leaf
|
||||
|
||||
def mk_meta():
|
||||
if t.layout in {torch.sparse_bsr, torch.sparse_bsc}:
|
||||
assert t.sparse_dim is not None
|
||||
assert t.dense_dim is not None
|
||||
nnz = 0
|
||||
batch_dim = t.ndim - t.sparse_dim - t.dense_dim
|
||||
batch_size = t.shape[:batch_dim]
|
||||
if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
|
||||
assert t.crow_indices is not None
|
||||
assert t.col_indices is not None
|
||||
index_dtype = t.crow_indices.dtype
|
||||
compressed_indices = torch.empty(
|
||||
t.crow_indices.shape, device="meta", dtype=index_dtype
|
||||
)
|
||||
plain_indices = torch.empty(
|
||||
(*t.col_indices.shape[:-1], nnz),
|
||||
device="meta",
|
||||
dtype=index_dtype,
|
||||
)
|
||||
else:
|
||||
assert t.ccol_indices is not None
|
||||
assert t.row_indices is not None
|
||||
index_dtype = t.ccol_indices.dtype
|
||||
compressed_indices = torch.empty(
|
||||
t.ccol_indices.shape, device="meta", dtype=index_dtype
|
||||
)
|
||||
plain_indices = torch.empty(
|
||||
(*t.row_indices.shape[:-1], nnz),
|
||||
device="meta",
|
||||
dtype=index_dtype,
|
||||
)
|
||||
assert t.values is not None
|
||||
values_shape = t.values.shape
|
||||
values = torch.empty(
|
||||
(
|
||||
*values_shape[:batch_dim],
|
||||
nnz,
|
||||
*values_shape[batch_dim + 1 :],
|
||||
),
|
||||
dtype=t.dtype,
|
||||
device="meta",
|
||||
)
|
||||
return torch.ops.aten.sparse_compressed_tensor(
|
||||
compressed_indices,
|
||||
plain_indices,
|
||||
values,
|
||||
batch_dim = t.ndim - t.sparse_dim - t.dense_dim
|
||||
blocksize = t.values.shape[batch_dim + 1 : batch_dim + 3]
|
||||
else:
|
||||
blocksize = ()
|
||||
if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
|
||||
assert t.crow_indices is not None
|
||||
index_dtype = t.crow_indices.dtype
|
||||
else:
|
||||
assert t.ccol_indices is not None
|
||||
index_dtype = t.ccol_indices.dtype
|
||||
|
||||
r = callback(
|
||||
lambda: torch.ops.aten._sparse_compressed_tensor_with_dims(
|
||||
0,
|
||||
t.dense_dim,
|
||||
t.shape,
|
||||
blocksize,
|
||||
index_dtype,
|
||||
layout=t.layout,
|
||||
dtype=t.dtype,
|
||||
device="meta",
|
||||
)
|
||||
|
||||
# `mk_meta()` is similar to `t.to(device='meta'))`
|
||||
# except `to('meta')` preserves nnz value while
|
||||
# `mk_meta` result has nnz == 0.
|
||||
r = callback(mk_meta)
|
||||
|
||||
)
|
||||
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
|
||||
if t.requires_grad:
|
||||
r.requires_grad = True
|
||||
|
Reference in New Issue
Block a user