mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] Add torch.ops.aten._sparse_compressed_tensor_with_dims (#123083)
Used in https://github.com/pytorch/pytorch/pull/123084 and allows simplifying `empty_like` implementation for sparse compressed tensors (see https://github.com/pytorch/pytorch/pull/121900#issuecomment-2029835473). Pull Request resolved: https://github.com/pytorch/pytorch/pull/123083 Approved by: https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
f9b2ffa7c4
commit
72662bf05b
@ -7080,6 +7080,10 @@
|
||||
# FIXME: would be nicer if TensorOptions was optional based; not adding default arguments for options given
|
||||
# the default would never make sense.
|
||||
|
||||
- func: _sparse_compressed_tensor_with_dims(int nnz, int dense_dim, int[] size, int[] blocksize, ScalarType index_dtype, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: sparse_compressed_tensor_with_dims
|
||||
|
||||
- func: sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: sparse_compressed_tensor
|
||||
|
||||
@ -22,6 +22,7 @@
|
||||
#include <ATen/ops/_sparse_csc_tensor_unsafe_native.h>
|
||||
#include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
|
||||
#include <ATen/ops/_sparse_bsc_tensor_unsafe_native.h>
|
||||
#include <ATen/ops/_sparse_compressed_tensor_with_dims_native.h>
|
||||
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
|
||||
#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
|
||||
#include <ATen/ops/_validate_sparse_compressed_tensor_args_native.h>
|
||||
@ -351,6 +352,79 @@ static SparseCsrTensor new_compressed_tensor(const TensorOptions& options) {
|
||||
return detail::make_tensor<SparseCsrTensorImpl>(DispatchKeySet(dispatch_key), options.device(), layout, options.dtype());
|
||||
}
|
||||
|
||||
Tensor sparse_compressed_tensor_with_dims(
|
||||
int64_t nnz,
|
||||
int64_t dense_dim,
|
||||
c10::IntArrayRef size,
|
||||
c10::IntArrayRef blocksize,
|
||||
ScalarType index_dtype,
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
c10::optional<Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
// sparse_compressed_tensor_with_dims is a generalization of empty
|
||||
// that enables the specification of nnz, dense_dim, blocksize, and
|
||||
// index_dtype for sparse compressed tensors.
|
||||
//
|
||||
// sparse_compressed_tensor_with_dims indices and values tensors are
|
||||
// created as empty tensors, so the returned sparse compressed
|
||||
// tensor will not satisfy the sparse compressed tensor
|
||||
// invariants. The caller is responsible for initializing the
|
||||
// indices tensors properly.
|
||||
TORCH_CHECK(layout, "sparse_compressed_tensor_with_dims: expected sparse compressed tensor layout but got none");
|
||||
|
||||
Layout layout_ = layout.value();
|
||||
AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor_with_dims", [&]{});
|
||||
|
||||
constexpr int64_t sparse_dim = 2;
|
||||
int64_t batch_dim = size.size() - dense_dim - sparse_dim;
|
||||
TORCH_CHECK(batch_dim >= 0, "sparse_compressed_tensor_with_dims: dimensionality must be at least dense_dim(=",
|
||||
dense_dim, ") + sparse_dim(=", sparse_dim, "), but got ", size.size());
|
||||
|
||||
TORCH_CHECK(nnz >= 0, "sparse_compressed_tensor_with_dims: nnz must be non-negative, got ", nnz);
|
||||
|
||||
auto plain_indices_size = DimVector(size.slice(0, batch_dim));
|
||||
auto compressed_indices_size = DimVector(size.slice(0, batch_dim));
|
||||
auto values_size = DimVector(size.slice(0, batch_dim));
|
||||
|
||||
plain_indices_size.push_back(nnz);
|
||||
values_size.push_back(nnz);
|
||||
|
||||
if (layout_ == kSparseBsr || layout_ == kSparseBsc) {
|
||||
TORCH_CHECK(blocksize.size() == (size_t)sparse_dim, "sparse_compressed_tensor_with_dims: blocksize needs to be a tuple of size ",
|
||||
sparse_dim, ", but got ", blocksize.size());
|
||||
auto d0 = (layout_ == kSparseBsr ? 0 : 1);
|
||||
auto d1 = (layout_ == kSparseBsr ? 1 : 0);
|
||||
TORCH_CHECK(blocksize[0] > 0 && blocksize[1] > 0, "sparse_compressed_tensor_with_dims: blocksize needs to be positive, but got ", blocksize);
|
||||
auto compressed_size = size[compressedDimension(layout_, size, dense_dim)];
|
||||
auto plain_size = size[plainDimension(layout_, size, dense_dim)];
|
||||
TORCH_CHECK(compressed_size % blocksize[d0] == 0, "sparse_compressed_tensor_with_dims: dimension ",
|
||||
compressedDimension(layout_, size, dense_dim), " must be multiple of blocksize[", d0, "](=", blocksize[d0], ") but got ", compressed_size);
|
||||
TORCH_CHECK(plain_size % blocksize[d1] == 0, "sparse_compressed_tensor_with_dims: dimension ", plainDimension(layout_, size, dense_dim),
|
||||
" must be multiple of blocksize[", d1, "](=", blocksize[d1], ") but got ", plain_size);
|
||||
compressed_indices_size.push_back(compressed_size / blocksize[d0] + 1);
|
||||
values_size.append(DimVector(blocksize));
|
||||
} else {
|
||||
TORCH_CHECK(blocksize.size() == 0, "sparse_compressed_tensor_with_dims: blocksize cannot be specified for non-block layout ", layout_);
|
||||
compressed_indices_size.push_back(size[compressedDimension(layout_, size, dense_dim)] + 1);
|
||||
}
|
||||
|
||||
values_size.append(DimVector(size.slice(batch_dim + sparse_dim, dense_dim)));
|
||||
TORCH_CHECK(
|
||||
index_dtype == ScalarType::Int || index_dtype == ScalarType::Long,
|
||||
"indices dtype must be Int or Long, but got ", index_dtype);
|
||||
|
||||
TensorOptions options_ = TensorOptions().layout(Layout::Strided).device(device).pinned_memory(pin_memory);
|
||||
auto compressed_indices = at::empty(compressed_indices_size, options_.dtype(index_dtype));
|
||||
auto plain_indices = at::empty(plain_indices_size, options_.dtype(index_dtype));
|
||||
auto values = at::empty(values_size, options_.dtype(dtype));
|
||||
|
||||
TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
|
||||
SparseCsrTensor self = new_compressed_tensor(options);
|
||||
get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
|
||||
return self;
|
||||
}
|
||||
|
||||
Tensor _sparse_compressed_tensor_unsafe_symint(
|
||||
const Tensor& compressed_indices,
|
||||
const Tensor& plain_indices,
|
||||
|
||||
@ -506,6 +506,7 @@ aten::_sparse_addmm.out
|
||||
aten::_sparse_broadcast_to
|
||||
aten::_sparse_broadcast_to_copy
|
||||
aten::_sparse_broadcast_to_copy.out
|
||||
aten::_sparse_compressed_tensor_with_dims
|
||||
aten::_sparse_coo_tensor_with_dims
|
||||
aten::_sparse_coo_tensor_with_dims.out
|
||||
aten::_sparse_coo_tensor_with_dims_and_tensors
|
||||
|
||||
@ -336,6 +336,46 @@ class TestSparseCompressed(TestCase):
|
||||
", but got size"):
|
||||
torch.empty((5,), dtype=dtype, device=device, layout=layout)
|
||||
|
||||
@skipMeta
|
||||
@all_sparse_compressed_layouts()
|
||||
@dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half))
|
||||
def test_sparse_compressed_tensor_with_dims(self, layout, device, dtype):
|
||||
|
||||
def get_sparse_compressed_tensor_properties(s):
|
||||
if layout in {torch.sparse_csr, torch.sparse_bsr}:
|
||||
compressed_indices, plain_indices = s.crow_indices(), s.col_indices()
|
||||
else:
|
||||
compressed_indices, plain_indices = s.ccol_indices(), s.row_indices()
|
||||
values = s.values()
|
||||
return dict(shape=s.shape, dtype=s.dtype, device=s.device, nnz=s._nnz(), layout=s.layout,
|
||||
compressed_indices_shape=compressed_indices.shape,
|
||||
compressed_indices_dtype=compressed_indices.dtype,
|
||||
compressed_indices_device=compressed_indices.device,
|
||||
plain_indices_shape=plain_indices.shape,
|
||||
plain_indices_dtype=plain_indices.dtype,
|
||||
plain_indices_device=plain_indices.device,
|
||||
values_shape=values.shape,
|
||||
values_dtype=values.dtype,
|
||||
values_device=values.device)
|
||||
|
||||
for index_dtype in [torch.int32, torch.int64]:
|
||||
for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
|
||||
dense_dim = t.dense_dim()
|
||||
sparse_dim = t.sparse_dim()
|
||||
batch_dim = t.ndim - sparse_dim - dense_dim
|
||||
nnz = t.values().shape[batch_dim]
|
||||
if layout in {torch.sparse_bsr, torch.sparse_bsc}:
|
||||
blocksize = t.values().shape[batch_dim + 1: batch_dim + 1 + sparse_dim]
|
||||
else:
|
||||
blocksize = ()
|
||||
|
||||
e = torch.ops.aten._sparse_compressed_tensor_with_dims(nnz, dense_dim, t.shape, blocksize, index_dtype,
|
||||
dtype=dtype, layout=layout, device=device)
|
||||
|
||||
e_prop, t_prop = get_sparse_compressed_tensor_properties(e), get_sparse_compressed_tensor_properties(t)
|
||||
for k, v in e_prop.items():
|
||||
self.assertEqual(v, t_prop[k], lambda msg: f'{msg} when comparing {k}, expected {t_prop[k]}, got {v}')
|
||||
|
||||
@skipMeta
|
||||
@all_sparse_compressed_layouts()
|
||||
@dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
|
||||
|
||||
@ -892,62 +892,33 @@ class MetaConverter:
|
||||
elif is_sparse_compressed_layout(t.layout):
|
||||
is_leaf = t.is_leaf
|
||||
|
||||
def mk_meta():
|
||||
if t.layout in {torch.sparse_bsr, torch.sparse_bsc}:
|
||||
assert t.sparse_dim is not None
|
||||
assert t.dense_dim is not None
|
||||
nnz = 0
|
||||
batch_dim = t.ndim - t.sparse_dim - t.dense_dim
|
||||
batch_size = t.shape[:batch_dim]
|
||||
if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
|
||||
assert t.crow_indices is not None
|
||||
assert t.col_indices is not None
|
||||
index_dtype = t.crow_indices.dtype
|
||||
compressed_indices = torch.empty(
|
||||
t.crow_indices.shape, device="meta", dtype=index_dtype
|
||||
)
|
||||
plain_indices = torch.empty(
|
||||
(*t.col_indices.shape[:-1], nnz),
|
||||
device="meta",
|
||||
dtype=index_dtype,
|
||||
)
|
||||
else:
|
||||
assert t.ccol_indices is not None
|
||||
assert t.row_indices is not None
|
||||
index_dtype = t.ccol_indices.dtype
|
||||
compressed_indices = torch.empty(
|
||||
t.ccol_indices.shape, device="meta", dtype=index_dtype
|
||||
)
|
||||
plain_indices = torch.empty(
|
||||
(*t.row_indices.shape[:-1], nnz),
|
||||
device="meta",
|
||||
dtype=index_dtype,
|
||||
)
|
||||
assert t.values is not None
|
||||
values_shape = t.values.shape
|
||||
values = torch.empty(
|
||||
(
|
||||
*values_shape[:batch_dim],
|
||||
nnz,
|
||||
*values_shape[batch_dim + 1 :],
|
||||
),
|
||||
dtype=t.dtype,
|
||||
device="meta",
|
||||
)
|
||||
return torch.ops.aten.sparse_compressed_tensor(
|
||||
compressed_indices,
|
||||
plain_indices,
|
||||
values,
|
||||
batch_dim = t.ndim - t.sparse_dim - t.dense_dim
|
||||
blocksize = t.values.shape[batch_dim + 1 : batch_dim + 3]
|
||||
else:
|
||||
blocksize = ()
|
||||
if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
|
||||
assert t.crow_indices is not None
|
||||
index_dtype = t.crow_indices.dtype
|
||||
else:
|
||||
assert t.ccol_indices is not None
|
||||
index_dtype = t.ccol_indices.dtype
|
||||
|
||||
r = callback(
|
||||
lambda: torch.ops.aten._sparse_compressed_tensor_with_dims(
|
||||
0,
|
||||
t.dense_dim,
|
||||
t.shape,
|
||||
blocksize,
|
||||
index_dtype,
|
||||
layout=t.layout,
|
||||
dtype=t.dtype,
|
||||
device="meta",
|
||||
)
|
||||
|
||||
# `mk_meta()` is similar to `t.to(device='meta'))`
|
||||
# except `to('meta')` preserves nnz value while
|
||||
# `mk_meta` result has nnz == 0.
|
||||
r = callback(mk_meta)
|
||||
|
||||
)
|
||||
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
|
||||
if t.requires_grad:
|
||||
r.requires_grad = True
|
||||
|
||||
Reference in New Issue
Block a user