[SparseCompressed] support csc layout for add sparse/dense. (#115433)

`add`, when passed one sparse and one dense argument, errors if the sparse
argument does not have CSR layout. This PR makes the underlying algorithm
generic over the compressed dimension so that it handles both CSR and CSC
layouts. The functions are renamed with the `sparse_compressed` qualifier in
place of `sparse_csr`.
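
As a quick illustration of the user-facing effect (a minimal sketch; the tensor
values are illustrative and the snippet is not part of the PR), adding a dense
tensor and a CSC tensor no longer raises:

```python
import torch

# Dense operand and a sparse operand in CSC layout.
dense = torch.zeros(3, 3)
sparse_csc = torch.tensor([[0., 1., 0.],
                           [0., 0., 2.],
                           [0., 0., 0.]]).to_sparse_csc()

# Previously only a CSR operand was accepted here; with this change
# both CSR and CSC layouts work and the result is a dense tensor.
out = torch.add(dense, sparse_csc, alpha=2.0)
```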

Fixes: #114807

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115433
Approved by: https://github.com/cpuhrsch, https://github.com/pearu
ghstack dependencies: #115432
Author: Andrew M. James
Date: 2023-12-21 17:01:19 -06:00
Committed by: PyTorch MergeBot
parent 910baa3a03
commit 4b97ed2ed8
4 changed files with 163 additions and 81 deletions

View File

@ -564,8 +564,8 @@
dispatch:
SparseCPU: add_out_sparse_cpu
SparseCUDA: add_out_sparse_cuda
SparseCsrCPU: add_out_sparse_csr_cpu
SparseCsrCUDA: add_out_sparse_csr_cuda
SparseCsrCPU: add_out_sparse_compressed_cpu
SparseCsrCUDA: add_out_sparse_compressed_cuda
MkldnnCPU: mkldnn_add_out
MPS: add_out_mps
tags: pointwise

View File

@ -851,13 +851,14 @@ Tensor& add_sparse_csr_(
return at::add_out(self, self, other, alpha); // redispatch!
}
static void add_out_dense_sparse_csr_cpu(
static void add_out_dense_sparse_compressed_cpu(
const Tensor& out,
const Tensor& dense,
const SparseCsrTensor& src,
const Scalar& alpha) {
TORCH_INTERNAL_ASSERT(dense.layout() == kStrided);
TORCH_INTERNAL_ASSERT(src.is_sparse_csr());
TORCH_INTERNAL_ASSERT(
src.layout() == kSparseCsr || src.layout() == kSparseCsc);
TORCH_INTERNAL_ASSERT(dense.device() == kCPU);
TORCH_CHECK(
@ -908,8 +909,15 @@ static void add_out_dense_sparse_csr_cpu(
auto valuesBuffer = src_values.to(commonDtype).reshape({-1, src_values.size(-1)});
resultBuffer = resultBuffer.view({-1, out.size(-2), out.size(-1)});
auto src_crow_indices = src.crow_indices().reshape({-1, src.crow_indices().size(-1)});
auto src_col_indices = src.col_indices().reshape({-1, src.col_indices().size(-1)});
Tensor src_compressed_indices;
Tensor src_plain_indices;
std::tie(src_compressed_indices, src_plain_indices) =
at::sparse_csr::getCompressedPlainIndices(src);
src_compressed_indices =
src_compressed_indices.reshape({-1, src_compressed_indices.size(-1)});
src_plain_indices =
src_plain_indices.reshape({-1, src_plain_indices.size(-1)});
auto src_layout = src.layout();
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
kComplexHalf,
@ -921,35 +929,57 @@ static void add_out_dense_sparse_csr_cpu(
[&valuesBuffer,
&resultBuffer,
&alpha,
&src_crow_indices,
&src_col_indices]() {
&src_compressed_indices,
&src_plain_indices,
&src_layout]() {
AT_DISPATCH_INDEX_TYPES(
src_crow_indices.scalar_type(),
src_compressed_indices.scalar_type(),
"csr_add_out_crow_indices",
[&valuesBuffer,
&resultBuffer,
&alpha,
&src_crow_indices,
&src_col_indices]() {
auto batch_count = resultBuffer.dim() > 2 ? resultBuffer.size(-3) : 1;
&src_compressed_indices,
&src_plain_indices,
&src_layout]() {
auto batch_count =
resultBuffer.dim() > 2 ? resultBuffer.size(-3) : 1;
auto values_accessor = valuesBuffer.accessor<scalar_t, 2>();
scalar_t* out_ptr = resultBuffer.data_ptr<scalar_t>();
scalar_t cast_value = alpha.to<scalar_t>();
auto crow_indices_accessor =
src_crow_indices.accessor<index_t, 2>();
auto col_indices_accessor =
src_col_indices.accessor<index_t, 2>();
auto compressed_indices_accessor =
src_compressed_indices.accessor<index_t, 2>();
auto plain_indices_accessor =
src_plain_indices.accessor<index_t, 2>();
auto out_strides = resultBuffer.strides();
auto const out_stride_batch = out_strides[0];
auto const out_stride_compressed =
AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
src_layout,
"add_out_dense_sparse_compressed_cpu",
[&out_strides] { return out_strides[1]; },
[&out_strides] { return out_strides[2]; });
auto const out_stride_plain =
AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
src_layout,
"add_out_dense_sparse_compressed_cpu",
[&out_strides] { return out_strides[2]; },
[&out_strides] { return out_strides[1]; });
for (const auto batch_idx : c10::irange(batch_count)) {
for (const auto irow : c10::irange(src_crow_indices.size(-1) - 1)) {
index_t start_index = crow_indices_accessor[batch_idx][irow];
index_t end_index = crow_indices_accessor[batch_idx][irow + 1];
for (const auto i_compressed :
c10::irange(src_compressed_indices.size(-1) - 1)) {
index_t start_index =
compressed_indices_accessor[batch_idx][i_compressed];
index_t end_index =
compressed_indices_accessor[batch_idx][i_compressed + 1];
for (const auto i : c10::irange(start_index, end_index)) {
auto icol = col_indices_accessor[batch_idx][i];
auto index = batch_idx * out_strides[0] + irow * out_strides[1] + icol * out_strides[2];
out_ptr[index] += cast_value * values_accessor[batch_idx][i];
auto i_plain = plain_indices_accessor[batch_idx][i];
auto index = batch_idx * out_stride_batch +
i_compressed * out_stride_compressed +
i_plain * out_stride_plain;
out_ptr[index] +=
cast_value * values_accessor[batch_idx][i];
}
}
}
@ -960,15 +990,15 @@ static void add_out_dense_sparse_csr_cpu(
}
}
Tensor& add_out_sparse_csr_cpu(
Tensor& add_out_sparse_compressed_cpu(
const Tensor& self,
const SparseCsrTensor& other,
const Scalar& alpha,
SparseCsrTensor& out) {
if (self.layout() == kStrided) {
add_out_dense_sparse_csr_cpu(out, self, other, alpha);
add_out_dense_sparse_compressed_cpu(out, self, other, alpha);
} else if (other.layout() == kStrided) {
add_out_dense_sparse_csr_cpu(out, other, self, alpha);
add_out_dense_sparse_compressed_cpu(out, other, self, alpha);
} else {
TORCH_CHECK(
self.sizes().equals(other.sizes()),

View File

@ -124,13 +124,14 @@ using namespace at::sparse_csr;
// certain utility functions are usable from sparse COO.
using namespace at::sparse;
Tensor& add_out_dense_sparse_csr_cuda(
Tensor& add_out_dense_sparse_compressed_cuda(
Tensor& output,
const Tensor& dense,
const SparseCsrTensor& src,
const Scalar& alpha) {
TORCH_INTERNAL_ASSERT(dense.layout() == kStrided);
TORCH_INTERNAL_ASSERT(src.is_sparse_csr());
TORCH_INTERNAL_ASSERT(
src.layout() == kSparseCsr || src.layout() == kSparseCsc);
TORCH_INTERNAL_ASSERT(dense.is_cuda());
TORCH_CHECK(
@ -182,67 +183,111 @@ Tensor& add_out_dense_sparse_csr_cuda(
auto valuesBuffer = src_values.to(commonDtype).reshape({-1, src_values.size(-1)}).contiguous();
resultBuffer = resultBuffer.view({-1, output.size(-2), output.size(-1)});
auto src_crow_indices = src.crow_indices().reshape({-1, src.crow_indices().size(-1)}).contiguous();
auto src_col_indices = src.col_indices().reshape({-1, src.col_indices().size(-1)}).contiguous();
Tensor src_compressed_indices;
Tensor src_plain_indices;
std::tie(src_compressed_indices, src_plain_indices) =
at::sparse_csr::getCompressedPlainIndices(src);
src_compressed_indices =
src_compressed_indices.reshape({-1, src_compressed_indices.size(-1)});
src_plain_indices =
src_plain_indices.reshape({-1, src_plain_indices.size(-1)});
auto src_layout = src.layout();
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
kComplexHalf, kHalf, kBool, kBFloat16,
kComplexHalf,
kHalf,
kBool,
kBFloat16,
commonDtype,
"add_out_op2_sparse_csr",
[&valuesBuffer, &resultBuffer, &alpha, &src_crow_indices, &src_col_indices]() {
[&valuesBuffer,
&resultBuffer,
&alpha,
&src_compressed_indices,
&src_plain_indices,
&src_layout]() {
AT_DISPATCH_INDEX_TYPES(
src_crow_indices.scalar_type(),
src_compressed_indices.scalar_type(),
"csr_add_out_crow_indices",
[&valuesBuffer, &resultBuffer, &alpha, &src_crow_indices, &src_col_indices]() {
auto batch_count = resultBuffer.dim() > 2 ? resultBuffer.size(-3) : 1;
scalar_t* values_accessor = valuesBuffer.data_ptr<scalar_t>();
scalar_t* out_ptr = resultBuffer.data_ptr<scalar_t>();
scalar_t cast_value = alpha.to<scalar_t>();
[&valuesBuffer,
&resultBuffer,
&alpha,
&src_compressed_indices,
&src_plain_indices,
&src_layout]() {
auto batch_count =
resultBuffer.dim() > 2 ? resultBuffer.size(-3) : 1;
scalar_t* values_accessor = valuesBuffer.data_ptr<scalar_t>();
scalar_t* out_ptr = resultBuffer.data_ptr<scalar_t>();
scalar_t cast_value = alpha.to<scalar_t>();
index_t* crow_indices_accessor = src_crow_indices.data_ptr<index_t>();
index_t* col_indices_accessor = src_col_indices.data_ptr<index_t>();
int64_t out_storage_offset = resultBuffer.storage_offset();
index_t* compressed_indices_accessor =
src_compressed_indices.data_ptr<index_t>();
index_t* plain_indices_accessor =
src_plain_indices.data_ptr<index_t>();
int64_t out_storage_offset = resultBuffer.storage_offset();
auto out_strides = resultBuffer.strides();
auto out_strides0 = out_strides[0];
auto out_strides1 = out_strides[1];
auto crow_stride0 = src_crow_indices.stride(0);
auto col_stride0 = src_col_indices.stride(0);
auto val_stride0 = valuesBuffer.stride(0);
auto out_strides = resultBuffer.strides();
auto const out_stride_batch = out_strides[0];
auto const out_stride_compressed =
AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
src_layout,
"add_out_dense_sparse_compressed_cpu",
[&out_strides] { return out_strides[1]; },
[&out_strides] { return out_strides[2]; });
auto const out_stride_plain =
AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
src_layout,
"add_out_dense_sparse_compressed_cpu",
[&out_strides] { return out_strides[2]; },
[&out_strides] { return out_strides[1]; });
auto compressed_stride0 = src_compressed_indices.stride(0);
auto plain_stride0 = src_plain_indices.stride(0);
auto val_stride0 = valuesBuffer.stride(0);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
at::cuda::ThrustAllocator allocator;
auto policy = thrust::cuda::par(allocator).on(stream);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
at::cuda::ThrustAllocator allocator;
auto policy = thrust::cuda::par(allocator).on(stream);
// Note that this could be wildly imbalanced if the sparsity pattern varies a lot between rows.
thrust::for_each(
policy,
thrust::make_counting_iterator(int64_t(0)),
thrust::make_counting_iterator(int64_t(src_crow_indices.size(-1) - 1)),
[values_accessor,
crow_indices_accessor,
col_indices_accessor,
out_ptr,
cast_value,
out_strides0,
out_strides1,
crow_stride0,
col_stride0,
val_stride0,
batch_count
]__device__(int64_t irow) {
for (index_t batch_idx = 0; batch_idx < batch_count; batch_idx++) {
index_t start_index = crow_indices_accessor[batch_idx*crow_stride0 + irow];
index_t end_index = crow_indices_accessor[batch_idx*crow_stride0 + irow + 1];
// Note that this could be wildly imbalanced if the sparsity
// pattern varies a lot between slices along the compressed
// dimension.
thrust::for_each(
policy,
thrust::make_counting_iterator(int64_t(0)),
thrust::make_counting_iterator(
int64_t(src_compressed_indices.size(-1) - 1)),
[values_accessor,
compressed_indices_accessor,
plain_indices_accessor,
out_ptr,
cast_value,
out_stride_batch,
out_stride_compressed,
out_stride_plain,
compressed_stride0,
plain_stride0,
val_stride0,
batch_count] __device__(int64_t i_compressed) {
for (index_t batch_idx = 0; batch_idx < batch_count;
batch_idx++) {
index_t start_index = compressed_indices_accessor
[batch_idx * compressed_stride0 + i_compressed];
index_t end_index = compressed_indices_accessor
[batch_idx * compressed_stride0 + i_compressed + 1];
for (index_t i = start_index; i < end_index; ++i) {
auto icol = col_indices_accessor[batch_idx*col_stride0 + i];
auto index = batch_idx * out_strides0 + irow * out_strides1 + icol;
out_ptr[index] += cast_value * values_accessor[batch_idx*val_stride0 + i];
}
for (index_t i = start_index; i < end_index; ++i) {
auto i_plain = plain_indices_accessor
[batch_idx * plain_stride0 + i];
auto index = batch_idx * out_stride_batch +
i_compressed * out_stride_compressed +
i_plain * out_stride_plain;
out_ptr[index] += cast_value *
values_accessor[batch_idx * val_stride0 + i];
}
});
});
}
});
});
});
if (output.scalar_type() != commonDtype) {
output.copy_(resultBuffer);
@ -250,15 +295,15 @@ Tensor& add_out_dense_sparse_csr_cuda(
return output;
}
Tensor& add_out_sparse_csr_cuda(
Tensor& add_out_sparse_compressed_cuda(
const Tensor& self,
const SparseCsrTensor& other,
const Scalar& alpha,
SparseCsrTensor& out) {
if (self.layout() == kStrided) {
add_out_dense_sparse_csr_cuda(out, self, other, alpha);
add_out_dense_sparse_compressed_cuda(out, self, other, alpha);
} else if (other.layout() == kStrided) {
add_out_dense_sparse_csr_cuda(out, other, self, alpha);
add_out_dense_sparse_compressed_cuda(out, other, self, alpha);
} else {
TORCH_CHECK(
self.sizes().equals(other.sizes()),
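
Conceptually, the CPU and CUDA kernels above run the same loop: walk the
compressed dimension, read the plain indices, and pick the output strides
according to the layout. A rough, non-batched Python sketch of that logic
(illustrative only, not the actual PyTorch internals; the helper name is made
up for this example):

```python
import torch

def add_dense_sparse_compressed_sketch(dense, src, alpha=1.0):
    # Illustrative 2D, non-batched case; assumes `dense` is contiguous.
    out = dense.clone()
    if src.layout == torch.sparse_csr:
        compressed, plain = src.crow_indices(), src.col_indices()
        stride_compressed, stride_plain = out.stride(0), out.stride(1)
    else:  # torch.sparse_csc: the compressed dim is the column, the plain dim the row
        compressed, plain = src.ccol_indices(), src.row_indices()
        stride_compressed, stride_plain = out.stride(1), out.stride(0)
    values = src.values()
    flat = out.view(-1)
    for i_compressed in range(compressed.numel() - 1):
        start, end = int(compressed[i_compressed]), int(compressed[i_compressed + 1])
        for i in range(start, end):
            i_plain = int(plain[i])
            flat[i_compressed * stride_compressed + i_plain * stride_plain] += alpha * values[i]
    return out
```

Swapping which output stride is treated as the compressed one is exactly what
the `AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS` calls in the hunks above do for
the CSC case.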

View File

@ -2109,13 +2109,19 @@ class TestSparseCSR(TestCase):
with self.assertRaisesRegex(RuntimeError, re.escape(str(msg))):
test(is_sparse=True)
@sparse_compressed_nonblock_layouts()
@dtypes(torch.float, torch.double)
def test_add(self, device, dtype):
def test_add(self, device, layout, dtype):
def _test_spadd_shape(nnz, shape):
# sparse.to_dense() uses torch.add internally so if torch.add is wrong,
# the dense tensor will be wrong but this test would still pass
# there's a separate test that checks for the correctness of the .to_dense() call
x = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
x = self.genSparseCompressedTensor(shape, nnz,
dtype=dtype,
device=device,
index_dtype=torch.int32,
layout=layout,
blocksize=())
y = torch.randn(*shape, dtype=dtype, device=device)
r = random.random()
@ -2140,6 +2146,7 @@ class TestSparseCSR(TestCase):
self.assertEqual(res, expected)
self.assertEqual(res_perm, expected)
ns = [2, 5]
batch_shapes = [(), (2,), (2, 3)]
for b, m, n in itertools.product(batch_shapes, ns, ns):