[SparseCompressed] support csc layout for add sparse/dense. (#115433)
`add` called with one sparse and one dense argument errors if the sparse argument does not have CSR layout. This PR makes the underlying algorithm generic over the compressed dimension so that it handles both CSR and CSC, and renames the functions to use the `sparse_compressed` qualifier rather than `sparse_csr`.

Fixes: #114807

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115433
Approved by: https://github.com/cpuhrsch, https://github.com/pearu
ghstack dependencies: #115432
committed by PyTorch MergeBot
parent 910baa3a03
commit 4b97ed2ed8
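As a quick illustration of the user-visible effect of this change, the snippet below adds a dense tensor and a CSC tensor; the shapes and values are arbitrary examples, not taken from the PR or its tests:

```python
import torch

# A dense operand and a sparse compressed operand. Before this change the
# sparse side had to be CSR; with it, CSC is accepted as well.
dense = torch.randn(3, 4)
sparse_csc = torch.randn(3, 4).relu().to_sparse_csc()

# `add` redispatches to add_out_sparse_compressed_{cpu,cuda} when one operand
# is strided and the other uses a sparse compressed layout.
out = torch.add(dense, sparse_csc, alpha=0.5)

# Reference result computed densely.
expected = dense + 0.5 * sparse_csc.to_dense()
assert torch.allclose(out, expected)
```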
@@ -564,8 +564,8 @@
   dispatch:
     SparseCPU: add_out_sparse_cpu
     SparseCUDA: add_out_sparse_cuda
-    SparseCsrCPU: add_out_sparse_csr_cpu
-    SparseCsrCUDA: add_out_sparse_csr_cuda
+    SparseCsrCPU: add_out_sparse_compressed_cpu
+    SparseCsrCUDA: add_out_sparse_compressed_cuda
     MkldnnCPU: mkldnn_add_out
     MPS: add_out_mps
   tags: pointwise

@@ -851,13 +851,14 @@ Tensor& add_sparse_csr_(
   return at::add_out(self, self, other, alpha); // redispatch!
 }
 
-static void add_out_dense_sparse_csr_cpu(
+static void add_out_dense_sparse_compressed_cpu(
     const Tensor& out,
     const Tensor& dense,
     const SparseCsrTensor& src,
     const Scalar& alpha) {
   TORCH_INTERNAL_ASSERT(dense.layout() == kStrided);
-  TORCH_INTERNAL_ASSERT(src.is_sparse_csr());
+  TORCH_INTERNAL_ASSERT(
+      src.layout() == kSparseCsr || src.layout() == kSparseCsc);
   TORCH_INTERNAL_ASSERT(dense.device() == kCPU);
 
   TORCH_CHECK(

@@ -908,8 +909,15 @@ static void add_out_dense_sparse_csr_cpu(
 
   auto valuesBuffer = src_values.to(commonDtype).reshape({-1, src_values.size(-1)});
   resultBuffer = resultBuffer.view({-1, out.size(-2), out.size(-1)});
-  auto src_crow_indices = src.crow_indices().reshape({-1, src.crow_indices().size(-1)});
-  auto src_col_indices = src.col_indices().reshape({-1, src.col_indices().size(-1)});
+  Tensor src_compressed_indices;
+  Tensor src_plain_indices;
+  std::tie(src_compressed_indices, src_plain_indices) =
+      at::sparse_csr::getCompressedPlainIndices(src);
+  src_compressed_indices =
+      src_compressed_indices.reshape({-1, src_compressed_indices.size(-1)});
+  src_plain_indices =
+      src_plain_indices.reshape({-1, src_plain_indices.size(-1)});
+  auto src_layout = src.layout();
 
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
       kComplexHalf,

@@ -921,35 +929,57 @@ static void add_out_dense_sparse_csr_cpu(
       [&valuesBuffer,
        &resultBuffer,
        &alpha,
-       &src_crow_indices,
-       &src_col_indices]() {
+       &src_compressed_indices,
+       &src_plain_indices,
+       &src_layout]() {
         AT_DISPATCH_INDEX_TYPES(
-            src_crow_indices.scalar_type(),
+            src_compressed_indices.scalar_type(),
             "csr_add_out_crow_indices",
             [&valuesBuffer,
              &resultBuffer,
             &alpha,
-             &src_crow_indices,
-             &src_col_indices]() {
-              auto batch_count = resultBuffer.dim() > 2 ? resultBuffer.size(-3) : 1;
+             &src_compressed_indices,
+             &src_plain_indices,
+             &src_layout]() {
+              auto batch_count =
+                  resultBuffer.dim() > 2 ? resultBuffer.size(-3) : 1;
               auto values_accessor = valuesBuffer.accessor<scalar_t, 2>();
               scalar_t* out_ptr = resultBuffer.data_ptr<scalar_t>();
               scalar_t cast_value = alpha.to<scalar_t>();
 
-              auto crow_indices_accessor =
-                  src_crow_indices.accessor<index_t, 2>();
-              auto col_indices_accessor =
-                  src_col_indices.accessor<index_t, 2>();
+              auto compressed_indices_accessor =
+                  src_compressed_indices.accessor<index_t, 2>();
+              auto plain_indices_accessor =
+                  src_plain_indices.accessor<index_t, 2>();
               auto out_strides = resultBuffer.strides();
+              auto const out_stride_batch = out_strides[0];
+              auto const out_stride_compressed =
+                  AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
+                      src_layout,
+                      "add_out_dense_sparse_compressed_cpu",
+                      [&out_strides] { return out_strides[1]; },
+                      [&out_strides] { return out_strides[2]; });
+              auto const out_stride_plain =
+                  AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
+                      src_layout,
+                      "add_out_dense_sparse_compressed_cpu",
+                      [&out_strides] { return out_strides[2]; },
+                      [&out_strides] { return out_strides[1]; });
 
               for (const auto batch_idx : c10::irange(batch_count)) {
-                for (const auto irow : c10::irange(src_crow_indices.size(-1) - 1)) {
-                  index_t start_index = crow_indices_accessor[batch_idx][irow];
-                  index_t end_index = crow_indices_accessor[batch_idx][irow + 1];
+                for (const auto i_compressed :
+                     c10::irange(src_compressed_indices.size(-1) - 1)) {
+                  index_t start_index =
+                      compressed_indices_accessor[batch_idx][i_compressed];
+                  index_t end_index =
+                      compressed_indices_accessor[batch_idx][i_compressed + 1];
                   for (const auto i : c10::irange(start_index, end_index)) {
-                    auto icol = col_indices_accessor[batch_idx][i];
-                    auto index = batch_idx * out_strides[0] + irow * out_strides[1] + icol * out_strides[2];
-                    out_ptr[index] += cast_value * values_accessor[batch_idx][i];
+                    auto i_plain = plain_indices_accessor[batch_idx][i];
+                    auto index = batch_idx * out_stride_batch +
+                        i_compressed * out_stride_compressed +
+                        i_plain * out_stride_plain;
+                    out_ptr[index] +=
+                        cast_value * values_accessor[batch_idx][i];
                   }
                 }
               }

@@ -960,15 +990,15 @@ static void add_out_dense_sparse_csr_cpu(
   }
 }
 
-Tensor& add_out_sparse_csr_cpu(
+Tensor& add_out_sparse_compressed_cpu(
     const Tensor& self,
     const SparseCsrTensor& other,
     const Scalar& alpha,
     SparseCsrTensor& out) {
   if (self.layout() == kStrided) {
-    add_out_dense_sparse_csr_cpu(out, self, other, alpha);
+    add_out_dense_sparse_compressed_cpu(out, self, other, alpha);
   } else if (other.layout() == kStrided) {
-    add_out_dense_sparse_csr_cpu(out, other, self, alpha);
+    add_out_dense_sparse_compressed_cpu(out, other, self, alpha);
   } else {
     TORCH_CHECK(
         self.sizes().equals(other.sizes()),

@@ -124,13 +124,14 @@ using namespace at::sparse_csr;
 // certain utiliy functions are usable from sparse COO.
 using namespace at::sparse;
 
-Tensor& add_out_dense_sparse_csr_cuda(
+Tensor& add_out_dense_sparse_compressed_cuda(
    Tensor& output,
    const Tensor& dense,
    const SparseCsrTensor& src,
    const Scalar& alpha) {
   TORCH_INTERNAL_ASSERT(dense.layout() == kStrided);
-  TORCH_INTERNAL_ASSERT(src.is_sparse_csr());
+  TORCH_INTERNAL_ASSERT(
+      src.layout() == kSparseCsr || src.layout() == kSparseCsc);
   TORCH_INTERNAL_ASSERT(dense.is_cuda());
 
   TORCH_CHECK(

@@ -182,67 +183,111 @@ Tensor& add_out_dense_sparse_csr_cuda(
 
   auto valuesBuffer = src_values.to(commonDtype).reshape({-1, src_values.size(-1)}).contiguous();
   resultBuffer = resultBuffer.view({-1, output.size(-2), output.size(-1)});
-  auto src_crow_indices = src.crow_indices().reshape({-1, src.crow_indices().size(-1)}).contiguous();
-  auto src_col_indices = src.col_indices().reshape({-1, src.col_indices().size(-1)}).contiguous();
+  Tensor src_compressed_indices;
+  Tensor src_plain_indices;
+  std::tie(src_compressed_indices, src_plain_indices) =
+      at::sparse_csr::getCompressedPlainIndices(src);
+  src_compressed_indices =
+      src_compressed_indices.reshape({-1, src_compressed_indices.size(-1)});
+  src_plain_indices =
+      src_plain_indices.reshape({-1, src_plain_indices.size(-1)});
+  auto src_layout = src.layout();
 
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
-      kComplexHalf, kHalf, kBool, kBFloat16,
+      kComplexHalf,
+      kHalf,
+      kBool,
+      kBFloat16,
       commonDtype,
       "add_out_op2_sparse_csr",
-      [&valuesBuffer, &resultBuffer, &alpha, &src_crow_indices, &src_col_indices]() {
+      [&valuesBuffer,
+       &resultBuffer,
+       &alpha,
+       &src_compressed_indices,
+       &src_plain_indices,
+       &src_layout]() {
        AT_DISPATCH_INDEX_TYPES(
-            src_crow_indices.scalar_type(),
+            src_compressed_indices.scalar_type(),
             "csr_add_out_crow_indices",
-            [&valuesBuffer, &resultBuffer, &alpha, &src_crow_indices, &src_col_indices]() {
-              auto batch_count = resultBuffer.dim() > 2 ? resultBuffer.size(-3) : 1;
-              scalar_t* values_accessor = valuesBuffer.data_ptr<scalar_t>();
-              scalar_t* out_ptr = resultBuffer.data_ptr<scalar_t>();
-              scalar_t cast_value = alpha.to<scalar_t>();
+            [&valuesBuffer,
+             &resultBuffer,
+             &alpha,
+             &src_compressed_indices,
+             &src_plain_indices,
+             &src_layout]() {
+              auto batch_count =
+                  resultBuffer.dim() > 2 ? resultBuffer.size(-3) : 1;
+              scalar_t* values_accessor = valuesBuffer.data_ptr<scalar_t>();
+              scalar_t* out_ptr = resultBuffer.data_ptr<scalar_t>();
+              scalar_t cast_value = alpha.to<scalar_t>();
 
-              index_t* crow_indices_accessor = src_crow_indices.data_ptr<index_t>();
-              index_t* col_indices_accessor = src_col_indices.data_ptr<index_t>();
-              int64_t out_storage_offset = resultBuffer.storage_offset();
+              index_t* compressed_indices_accessor =
+                  src_compressed_indices.data_ptr<index_t>();
+              index_t* plain_indices_accessor =
+                  src_plain_indices.data_ptr<index_t>();
+              int64_t out_storage_offset = resultBuffer.storage_offset();
 
-              auto out_strides = resultBuffer.strides();
-              auto out_strides0 = out_strides[0];
-              auto out_strides1 = out_strides[1];
-              auto crow_stride0 = src_crow_indices.stride(0);
-              auto col_stride0 = src_col_indices.stride(0);
-              auto val_stride0 = valuesBuffer.stride(0);
+              auto out_strides = resultBuffer.strides();
+              auto const out_stride_batch = out_strides[0];
+              auto const out_stride_compressed =
+                  AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
+                      src_layout,
+                      "add_out_dense_sparse_compressed_cpu",
+                      [&out_strides] { return out_strides[1]; },
+                      [&out_strides] { return out_strides[2]; });
+              auto const out_stride_plain =
+                  AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
+                      src_layout,
+                      "add_out_dense_sparse_compressed_cpu",
+                      [&out_strides] { return out_strides[2]; },
+                      [&out_strides] { return out_strides[1]; });
+              auto compressed_stride0 = src_compressed_indices.stride(0);
+              auto plain_stride0 = src_plain_indices.stride(0);
+              auto val_stride0 = valuesBuffer.stride(0);
 
-              cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-              at::cuda::ThrustAllocator allocator;
-              auto policy = thrust::cuda::par(allocator).on(stream);
+              cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+              at::cuda::ThrustAllocator allocator;
+              auto policy = thrust::cuda::par(allocator).on(stream);
 
-              // Note that this could be wildly imbalanced if the sparsity pattern varies a lot between rows.
-              thrust::for_each(
-                  policy,
-                  thrust::make_counting_iterator(int64_t(0)),
-                  thrust::make_counting_iterator(int64_t(src_crow_indices.size(-1) - 1)),
-                  [values_accessor,
-                  crow_indices_accessor,
-                  col_indices_accessor,
-                  out_ptr,
-                  cast_value,
-                  out_strides0,
-                  out_strides1,
-                  crow_stride0,
-                  col_stride0,
-                  val_stride0,
-                  batch_count
-                  ]__device__(int64_t irow) {
-                    for (index_t batch_idx = 0; batch_idx < batch_count; batch_idx++) {
-                      index_t start_index = crow_indices_accessor[batch_idx*crow_stride0 + irow];
-                      index_t end_index = crow_indices_accessor[batch_idx*crow_stride0 + irow + 1];
+              // Note that this could be wildly imbalanced if the sparsity
+              // pattern varies a lot between slices along the compressed
+              // dimension.
+              thrust::for_each(
+                  policy,
+                  thrust::make_counting_iterator(int64_t(0)),
+                  thrust::make_counting_iterator(
+                      int64_t(src_compressed_indices.size(-1) - 1)),
+                  [values_accessor,
+                   compressed_indices_accessor,
+                   plain_indices_accessor,
+                   out_ptr,
+                   cast_value,
+                   out_stride_batch,
+                   out_stride_compressed,
+                   out_stride_plain,
+                   compressed_stride0,
+                   plain_stride0,
+                   val_stride0,
+                   batch_count] __device__(int64_t i_compressed) {
+                    for (index_t batch_idx = 0; batch_idx < batch_count;
+                         batch_idx++) {
+                      index_t start_index = compressed_indices_accessor
+                          [batch_idx * compressed_stride0 + i_compressed];
+                      index_t end_index = compressed_indices_accessor
+                          [batch_idx * compressed_stride0 + i_compressed + 1];
 
-                      for (index_t i = start_index; i < end_index; ++i) {
-                        auto icol = col_indices_accessor[batch_idx*col_stride0 + i];
-                        auto index = batch_idx * out_strides0 + irow * out_strides1 + icol;
-                        out_ptr[index] += cast_value * values_accessor[batch_idx*val_stride0 + i];
-                      }
+                      for (index_t i = start_index; i < end_index; ++i) {
+                        auto i_plain = plain_indices_accessor
+                            [batch_idx * plain_stride0 + i];
+                        auto index = batch_idx * out_stride_batch +
+                            i_compressed * out_stride_compressed +
+                            i_plain * out_stride_plain;
+                        out_ptr[index] += cast_value *
+                            values_accessor[batch_idx * val_stride0 + i];
+                      }
                     }
-                  });
-        });
+                    }
+                  });
+            });
       });
   if (output.scalar_type() != commonDtype) {
     output.copy_(resultBuffer);

@@ -250,15 +295,15 @@ Tensor& add_out_dense_sparse_csr_cuda(
   return output;
 }
 
-Tensor& add_out_sparse_csr_cuda(
+Tensor& add_out_sparse_compressed_cuda(
    const Tensor& self,
    const SparseCsrTensor& other,
    const Scalar& alpha,
    SparseCsrTensor& out) {
   if (self.layout() == kStrided) {
-    add_out_dense_sparse_csr_cuda(out, self, other, alpha);
+    add_out_dense_sparse_compressed_cuda(out, self, other, alpha);
   } else if (other.layout() == kStrided) {
-    add_out_dense_sparse_csr_cuda(out, other, self, alpha);
+    add_out_dense_sparse_compressed_cuda(out, other, self, alpha);
   } else {
     TORCH_CHECK(
         self.sizes().equals(other.sizes()),

@@ -2109,13 +2109,19 @@ class TestSparseCSR(TestCase):
         with self.assertRaisesRegex(RuntimeError, re.escape(str(msg))):
             test(is_sparse=True)
 
+    @sparse_compressed_nonblock_layouts()
     @dtypes(torch.float, torch.double)
-    def test_add(self, device, dtype):
+    def test_add(self, device, layout, dtype):
         def _test_spadd_shape(nnz, shape):
             # sparse.to_dense() uses torch.add internally so if torch.add is wrong,
             # the dense tensor will be wrong but this test would still pass
             # there's a separate test that checks for the correctness of the .to_dense() call
-            x = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
+            x = self.genSparseCompressedTensor(shape, nnz,
+                                               dtype=dtype,
+                                               device=device,
+                                               index_dtype=torch.int32,
+                                               layout=layout,
+                                               blocksize=())
             y = torch.randn(*shape, dtype=dtype, device=device)
             r = random.random()
 

@@ -2140,6 +2146,7 @@ class TestSparseCSR(TestCase):
             self.assertEqual(res, expected)
             self.assertEqual(res_perm, expected)
 
 
         ns = [2, 5]
+        batch_shapes = [(), (2,), (2, 3)]
         for b, m, n in itertools.product(batch_shapes, ns, ns):
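
For reference, here is a plain-Python, non-batched sketch of the compressed/plain indexing scheme that the CPU and CUDA kernels above implement; the helper name and the cross-check at the end are illustrative only and not part of the PR:

```python
import torch

def add_dense_sparse_compressed_reference(dense, src, alpha=1.0):
    # The compressed dimension is rows for CSR and columns for CSC; the plain
    # indices then address the remaining dimension.
    out = dense.clone()
    if src.layout == torch.sparse_csr:
        compressed, plain = src.crow_indices(), src.col_indices()
    else:  # torch.sparse_csc
        compressed, plain = src.ccol_indices(), src.row_indices()
    values = src.values()
    for i_compressed in range(compressed.numel() - 1):
        start = int(compressed[i_compressed])
        end = int(compressed[i_compressed + 1])
        for i in range(start, end):
            i_plain = int(plain[i])
            if src.layout == torch.sparse_csr:
                out[i_compressed, i_plain] += alpha * values[i]
            else:
                out[i_plain, i_compressed] += alpha * values[i]
    return out

# Cross-check against torch.add for a CSC operand.
d = torch.randn(4, 5)
s = torch.randn(4, 5).relu().to_sparse_csc()
assert torch.allclose(add_dense_sparse_compressed_reference(d, s, 0.5),
                      torch.add(d, s, alpha=0.5))
```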