[BE] Add torch.ops.aten._sparse_compressed_tensor_with_dims (#123083)

Used in https://github.com/pytorch/pytorch/pull/123084 and allows simplifying the `empty_like` implementation for sparse compressed tensors (see https://github.com/pytorch/pytorch/pull/121900#issuecomment-2029835473).
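
A minimal usage sketch (the sizes here are illustrative; indices and values come out uninitialized, as with `torch.empty`):

    t = torch.ops.aten._sparse_compressed_tensor_with_dims(
        6, 0, (4, 5), (), torch.int64,
        dtype=torch.float32, layout=torch.sparse_csr, device='cpu')
    # t.crow_indices(): shape (5,), t.col_indices(): shape (6,), t.values(): shape (6,)
    # the caller must fill in the indices before t satisfies the CSR invariants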

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123083
Approved by: https://github.com/cpuhrsch
Author: Pearu Peterson
Date: 2024-04-02 00:24:47 +03:00
Committed by: PyTorch MergeBot
parent f9b2ffa7c4
commit 72662bf05b
5 changed files with 138 additions and 48 deletions

aten/src/ATen/native/native_functions.yaml

@@ -7080,6 +7080,10 @@
# FIXME: would be nicer if TensorOptions was optional based; not adding default arguments for options given
# the default would never make sense.
- func: _sparse_compressed_tensor_with_dims(int nnz, int dense_dim, int[] size, int[] blocksize, ScalarType index_dtype, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
  dispatch:
    CompositeExplicitAutograd: sparse_compressed_tensor_with_dims

- func: sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
  dispatch:
    CompositeExplicitAutograd: sparse_compressed_tensor

aten/src/ATen/native/sparse/SparseCsrTensor.cpp

@@ -22,6 +22,7 @@
#include <ATen/ops/_sparse_csc_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_bsc_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_compressed_tensor_with_dims_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
#include <ATen/ops/_validate_sparse_compressed_tensor_args_native.h>
@@ -351,6 +352,79 @@ static SparseCsrTensor new_compressed_tensor(const TensorOptions& options) {
  return detail::make_tensor<SparseCsrTensorImpl>(DispatchKeySet(dispatch_key), options.device(), layout, options.dtype());
}

Tensor sparse_compressed_tensor_with_dims(
     int64_t nnz,
     int64_t dense_dim,
     c10::IntArrayRef size,
     c10::IntArrayRef blocksize,
     ScalarType index_dtype,
     c10::optional<ScalarType> dtype,
     c10::optional<Layout> layout,
     c10::optional<Device> device,
     c10::optional<bool> pin_memory) {
  // sparse_compressed_tensor_with_dims is a generalization of empty
  // that enables the specification of nnz, dense_dim, blocksize, and
  // index_dtype for sparse compressed tensors.
  //
  // sparse_compressed_tensor_with_dims indices and values tensors are
  // created as empty tensors, so the returned sparse compressed
  // tensor will not satisfy the sparse compressed tensor
  // invariants. The caller is responsible for initializing the
  // indices tensors properly.
  TORCH_CHECK(layout, "sparse_compressed_tensor_with_dims: expected sparse compressed tensor layout but got none");
  Layout layout_ = layout.value();
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor_with_dims", [&]{});

  constexpr int64_t sparse_dim = 2;
  int64_t batch_dim = size.size() - dense_dim - sparse_dim;
  TORCH_CHECK(batch_dim >= 0, "sparse_compressed_tensor_with_dims: dimensionality must be at least dense_dim(=",
              dense_dim, ") + sparse_dim(=", sparse_dim, "), but got ", size.size());

  TORCH_CHECK(nnz >= 0, "sparse_compressed_tensor_with_dims: nnz must be non-negative, got ", nnz);

  auto plain_indices_size = DimVector(size.slice(0, batch_dim));
  auto compressed_indices_size = DimVector(size.slice(0, batch_dim));
  auto values_size = DimVector(size.slice(0, batch_dim));

  plain_indices_size.push_back(nnz);
  values_size.push_back(nnz);

  if (layout_ == kSparseBsr || layout_ == kSparseBsc) {
    TORCH_CHECK(blocksize.size() == (size_t)sparse_dim, "sparse_compressed_tensor_with_dims: blocksize needs to be a tuple of size ",
                sparse_dim, ", but got ", blocksize.size());
    auto d0 = (layout_ == kSparseBsr ? 0 : 1);
    auto d1 = (layout_ == kSparseBsr ? 1 : 0);
    TORCH_CHECK(blocksize[0] > 0 && blocksize[1] > 0, "sparse_compressed_tensor_with_dims: blocksize needs to be positive, but got ", blocksize);

    auto compressed_size = size[compressedDimension(layout_, size, dense_dim)];
    auto plain_size = size[plainDimension(layout_, size, dense_dim)];
    TORCH_CHECK(compressed_size % blocksize[d0] == 0, "sparse_compressed_tensor_with_dims: dimension ",
                compressedDimension(layout_, size, dense_dim), " must be multiple of blocksize[", d0, "](=", blocksize[d0], ") but got ", compressed_size);
    TORCH_CHECK(plain_size % blocksize[d1] == 0, "sparse_compressed_tensor_with_dims: dimension ", plainDimension(layout_, size, dense_dim),
                " must be multiple of blocksize[", d1, "](=", blocksize[d1], ") but got ", plain_size);
    compressed_indices_size.push_back(compressed_size / blocksize[d0] + 1);
    values_size.append(DimVector(blocksize));
  } else {
    TORCH_CHECK(blocksize.size() == 0, "sparse_compressed_tensor_with_dims: blocksize cannot be specified for non-block layout ", layout_);
    compressed_indices_size.push_back(size[compressedDimension(layout_, size, dense_dim)] + 1);
  }

  values_size.append(DimVector(size.slice(batch_dim + sparse_dim, dense_dim)));

  TORCH_CHECK(
      index_dtype == ScalarType::Int || index_dtype == ScalarType::Long,
      "indices dtype must be Int or Long, but got ", index_dtype);

  TensorOptions options_ = TensorOptions().layout(Layout::Strided).device(device).pinned_memory(pin_memory);
  auto compressed_indices = at::empty(compressed_indices_size, options_.dtype(index_dtype));
  auto plain_indices = at::empty(plain_indices_size, options_.dtype(index_dtype));
  auto values = at::empty(values_size, options_.dtype(dtype));

  TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
  SparseCsrTensor self = new_compressed_tensor(options);
  get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
  return self;
}

Tensor _sparse_compressed_tensor_unsafe_symint(
     const Tensor& compressed_indices,
     const Tensor& plain_indices,
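
For reference, a small sketch of the shape bookkeeping above, for an assumed BSR input with size (6, 10), blocksize (2, 5), nnz 3, and no batch or dense dimensions:

    b = torch.ops.aten._sparse_compressed_tensor_with_dims(
        3, 0, (6, 10), (2, 5), torch.int32,
        dtype=torch.float64, layout=torch.sparse_bsr, device="cpu")
    b.crow_indices().shape  # (6 // blocksize[0] + 1,) == torch.Size([4])
    b.col_indices().shape   # (nnz,)                   == torch.Size([3])
    b.values().shape        # (nnz, *blocksize)        == torch.Size([3, 2, 5])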

test/expect/HasDecompTest.test_has_decomposition.expect

@@ -506,6 +506,7 @@ aten::_sparse_addmm.out
aten::_sparse_broadcast_to
aten::_sparse_broadcast_to_copy
aten::_sparse_broadcast_to_copy.out
aten::_sparse_compressed_tensor_with_dims
aten::_sparse_coo_tensor_with_dims
aten::_sparse_coo_tensor_with_dims.out
aten::_sparse_coo_tensor_with_dims_and_tensors

test/test_sparse_csr.py

@@ -336,6 +336,46 @@ class TestSparseCompressed(TestCase):
                                    ", but got size"):
            torch.empty((5,), dtype=dtype, device=device, layout=layout)

    @skipMeta
    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half))
    def test_sparse_compressed_tensor_with_dims(self, layout, device, dtype):

        def get_sparse_compressed_tensor_properties(s):
            if layout in {torch.sparse_csr, torch.sparse_bsr}:
                compressed_indices, plain_indices = s.crow_indices(), s.col_indices()
            else:
                compressed_indices, plain_indices = s.ccol_indices(), s.row_indices()
            values = s.values()
            return dict(shape=s.shape, dtype=s.dtype, device=s.device, nnz=s._nnz(), layout=s.layout,
                        compressed_indices_shape=compressed_indices.shape,
                        compressed_indices_dtype=compressed_indices.dtype,
                        compressed_indices_device=compressed_indices.device,
                        plain_indices_shape=plain_indices.shape,
                        plain_indices_dtype=plain_indices.dtype,
                        plain_indices_device=plain_indices.device,
                        values_shape=values.shape,
                        values_dtype=values.dtype,
                        values_device=values.device)

        for index_dtype in [torch.int32, torch.int64]:
            for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
                dense_dim = t.dense_dim()
                sparse_dim = t.sparse_dim()
                batch_dim = t.ndim - sparse_dim - dense_dim
                nnz = t.values().shape[batch_dim]
                if layout in {torch.sparse_bsr, torch.sparse_bsc}:
                    blocksize = t.values().shape[batch_dim + 1: batch_dim + 1 + sparse_dim]
                else:
                    blocksize = ()

                e = torch.ops.aten._sparse_compressed_tensor_with_dims(nnz, dense_dim, t.shape, blocksize, index_dtype,
                                                                       dtype=dtype, layout=layout, device=device)

                e_prop, t_prop = get_sparse_compressed_tensor_properties(e), get_sparse_compressed_tensor_properties(t)
                for k, v in e_prop.items():
                    self.assertEqual(v, t_prop[k], lambda msg: f'{msg} when comparing {k}, expected {t_prop[k]}, got {v}')

    @skipMeta
    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))

torch/_subclasses/meta_utils.py

@@ -892,62 +892,33 @@ class MetaConverter:
         elif is_sparse_compressed_layout(t.layout):
             is_leaf = t.is_leaf
 
-            def mk_meta():
-                nnz = 0
-                batch_dim = t.ndim - t.sparse_dim - t.dense_dim
-                batch_size = t.shape[:batch_dim]
-                if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
-                    assert t.crow_indices is not None
-                    assert t.col_indices is not None
-                    index_dtype = t.crow_indices.dtype
-                    compressed_indices = torch.empty(
-                        t.crow_indices.shape, device="meta", dtype=index_dtype
-                    )
-                    plain_indices = torch.empty(
-                        (*t.col_indices.shape[:-1], nnz),
-                        device="meta",
-                        dtype=index_dtype,
-                    )
-                else:
-                    assert t.ccol_indices is not None
-                    assert t.row_indices is not None
-                    index_dtype = t.ccol_indices.dtype
-                    compressed_indices = torch.empty(
-                        t.ccol_indices.shape, device="meta", dtype=index_dtype
-                    )
-                    plain_indices = torch.empty(
-                        (*t.row_indices.shape[:-1], nnz),
-                        device="meta",
-                        dtype=index_dtype,
-                    )
-                assert t.values is not None
-                values_shape = t.values.shape
-                values = torch.empty(
-                    (
-                        *values_shape[:batch_dim],
-                        nnz,
-                        *values_shape[batch_dim + 1 :],
-                    ),
-                    dtype=t.dtype,
-                    device="meta",
-                )
-                return torch.ops.aten.sparse_compressed_tensor(
-                    compressed_indices,
-                    plain_indices,
-                    values,
-                    t.shape,
-                    layout=t.layout,
-                    device="meta",
-                )
-
-            # `mk_meta()` is similar to `t.to(device='meta'))`
-            # except `to('meta')` preserves nnz value while
-            # `mk_meta` result has nnz == 0.
-            r = callback(mk_meta)
+            if t.layout in {torch.sparse_bsr, torch.sparse_bsc}:
+                assert t.sparse_dim is not None
+                assert t.dense_dim is not None
+                batch_dim = t.ndim - t.sparse_dim - t.dense_dim
+                blocksize = t.values.shape[batch_dim + 1 : batch_dim + 3]
+            else:
+                blocksize = ()
+            if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
+                assert t.crow_indices is not None
+                index_dtype = t.crow_indices.dtype
+            else:
+                assert t.ccol_indices is not None
+                index_dtype = t.ccol_indices.dtype
+            r = callback(
+                lambda: torch.ops.aten._sparse_compressed_tensor_with_dims(
+                    0,
+                    t.dense_dim,
+                    t.shape,
+                    blocksize,
+                    index_dtype,
+                    layout=t.layout,
+                    dtype=t.dtype,
+                    device="meta",
+                )
+            )
             assert safe_is_leaf(r), "the callback you passed in doesn't detach"
             if t.requires_grad:
                 r.requires_grad = True
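
As before, the meta conversion mirrors the input's shape, layout, and index dtype while fixing nnz to 0; a rough sketch of the equivalent direct call, using an assumed CSR input:

    x = torch.eye(4).to_sparse_csr()
    m = torch.ops.aten._sparse_compressed_tensor_with_dims(
        0, x.dense_dim(), x.shape, (), x.crow_indices().dtype,
        dtype=x.dtype, layout=x.layout, device="meta")
    assert m.shape == x.shape and m._nnz() == 0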