mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Sparse-sparse matrix multiplication (CPU/CUDA) (#39526)
Summary: This PR implements matrix multiplication support for 2-d sparse tensors using the COO sparse format. The current implementation of `torch.sparse.mm` support this configuration, `torch.sparse.mm(sparse_matrix1, sparse_matrix2.to_dense())`, but this could spend a lot of memory when sparse_matrix2's shape is large. This implementation extends `torch.sparse.mm` function to support `torch.sparse.mm(sparse_matrix1, sparse_matrix2)` Resolves #[20988](https://github.com/pytorch/pytorch/issues/20988) for CPU/CUDA. - [x] sparse matmul - [x] CPU/CUDA C++ implementation - [x] unittests - [x] update torch.sparse.mm documentation - [x] autograd support The CPU sparse-sparse matmul was implemented taking as a reference this work "Sparse Matrix Multiplication Package (SMMP)". The GPU sparse-sparse matmul is based on cuSparse, there is specific code for CUSPARSE when CUSPARSE_VERSION >= 11 and old version of CUSPARSE. Both CPU/CUDA rely on the sparse-sparse matmul algorithm using the CSR indices format as it is one of the fastest algorithm. Here it is the latest benchmark (script is here) results for torch.sparse.mm (CUDA) and torch.sparse.mm (CPU) and scipy, values are float32 scalars: size | density | sparse.mm(CUDA) | sparse.mm(CPU) | scipy_coo_matmul -- | -- | -- | -- | -- (32, 10000) | 0.01 | 822.7 | 79.4 | 704.1 (32, 10000) | 0.05 | 1741.1 | 402.6 | 1155.3 (32, 10000) | 0.1 | 2956.8 | 840.8 | 1885.4 (32, 10000) | 0.25 | 6417.7 | 2832.3 | 4665.2 (512, 10000) | 0.01 | 1010.2 | 3941.3 | 26937.7 (512, 10000) | 0.05 | 2216.2 | 26903.8 | 57343.7 (512, 10000) | 0.1 | 4868.4 | 87773.7 | 117477.0 (512, 10000) | 0.25 | 16639.3 | 608105.0 | 624290.4 (1024, 10000) | 0.01 | 1224.8 | 13088.1 | 110379.2 (1024, 10000) | 0.05 | 3897.5 | 94783.9 | 236541.8 (1024, 10000) | 0.1 | 10559.1 | 405312.5 | 525483.4 (1024, 10000) | 0.25 | 57456.3 | 2424337.5 | 2729318.7 A new backward algorithm was implemented using only `sparse @ sparse` and `sparse_mask` operations. Here is some benchmarking: ``` [------------------------- sparse.mm-backward -------------------------] | sparse.backward | dense.backward ----------------------------------------------------------------------- (32, 10000) | 0.01 | 13.5 | 2.4 (32, 10000) | 0.05 | 52.3 | 2.4 (512, 10000) | 0.01 | 1016.8 | 491.5 (512, 10000) | 0.05 | 1604.3 | 492.3 (1024, 10000) | 0.01 | 2384.1 | 1963.7 (1024, 10000) | 0.05 | 3965.8 | 1951.9 ``` I added new benchmark tests. Now I am using a real dataset used in recent studies [1, 2] with different sparsity levels. ``` [---------------------------------- matmul ---------------------------------] | 0.5 | 0.7 | 0.8 | 0.9 | 0.95 | 0.98 1 threads: ------------------------------------------------------------------ (cpu) torch | 5.4 | 5.4 | 5.2 | 5.3 | 5.3 | 5.4 torch.sparse | 122.2 | 51.9 | 27.5 | 11.4 | 4.9 | 1.8 scipy | 150.1 | 87.4 | 69.2 | 56.8 | 38.4 | 17.1 (cuda) torch | 1.3 | 1.1 | 1.1 | 1.1 | 1.1 | 1.1 torch.sparse | 20.0 | 8.4 | 5.1 | 2.5 | 1.5 | 1.1 [----------------------------------- backward -----------------------------------] | 0.5 | 0.7 | 0.8 | 0.9 | 0.95 | 0.98 1 threads: ----------------------------------------------------------------------- (cpu) torch | 17.7 | 17.9 | 17.7 | 17.7 | 17.6 | 17.9 torch.sparse | 672.9 | 432.6 | 327.5 | 230.8 | 176.7 | 116.7 (cuda) torch | 3.8 | 3.6 | 3.5 | 3.5 | 3.6 | 3.5 torch.sparse | 68.8 | 46.2 | 35.6 | 24.2 | 17.8 | 11.9 Times are in milliseconds (ms). ``` In summary, I can say that the new `sparse @ sparse` backward algorithm is better as it is more about saving space than performance. Moreover, it is better than other options tested before. ## **References** 1. Trevor Gale, Matei Zaharia, Cliff Young, Erich Elsen. **Sparse GPU Kernels for Deep Learning.** Proceedings of the International Conference for High Performance Computing, 2020. [https://github.com/google-research/google-research/tree/master/sgk](https://github.com/google-research/google-research/tree/master/sgk) 2. Trevor Gale, Erich Elsen, Sara Hooker. **The State of Sparsity in Deep Neural Networks.** [https://github.com/google-research/google-research/tree/master/state_of_sparsity](https://github.com/google-research/google-research/tree/master/state_of_sparsity) Pull Request resolved: https://github.com/pytorch/pytorch/pull/39526 Reviewed By: mruberry Differential Revision: D25661239 Pulled By: ngimel fbshipit-source-id: b515ecd66d25f347d637e159d51aa45fb43b6938
This commit is contained in:
committed by
Facebook GitHub Bot
parent
3779bdec56
commit
44ce0b8883
113
aten/src/ATen/SparseTensorUtils.cpp
Normal file
113
aten/src/ATen/SparseTensorUtils.cpp
Normal file
@ -0,0 +1,113 @@
|
||||
#include <ATen/SparseTensorUtils.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/SparseTensorImpl.h>
|
||||
#include <ATen/Parallel.h>
|
||||
|
||||
namespace at { namespace sparse {
|
||||
|
||||
// NOTE [ Flatten Sparse Indices ]
|
||||
// This helper function flattens a sparse indices tensor (a Tensor) into a 1D
|
||||
// indices tensor. E.g.,
|
||||
// input = [[2, 4, 0],
|
||||
// [3, 1, 10]]
|
||||
// full_size = [2, 12]
|
||||
// output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10]
|
||||
//
|
||||
// In other words, assuming that each `indices[i, :]` is a valid index to a
|
||||
// tensor `t` of shape `full_size`. This returns the corresponding indices to
|
||||
// the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`.
|
||||
// if forceClone is true, the result will forced to be a clone of self.
|
||||
// if force_clone is true, the result will forced to be a clone of self.
|
||||
Tensor flatten_indices(const Tensor& indices, IntArrayRef full_size, bool force_clone /*= false*/) {
|
||||
int64_t sparse_dim = indices.size(0);
|
||||
if (sparse_dim == 1) {
|
||||
if (force_clone) {
|
||||
return indices.squeeze(0).clone(at::MemoryFormat::Contiguous);
|
||||
} else {
|
||||
return indices.squeeze(0);
|
||||
}
|
||||
} else {
|
||||
std::vector<int64_t> indices_mult_cpu_vec;
|
||||
indices_mult_cpu_vec.reserve(sparse_dim);
|
||||
int64_t mult = 1;
|
||||
for (int64_t i = sparse_dim - 1; i >= 0; i--) {
|
||||
indices_mult_cpu_vec[i] = mult;
|
||||
mult *= full_size[i];
|
||||
}
|
||||
auto indices_mult_cpu = at::from_blob(
|
||||
indices_mult_cpu_vec.data(),
|
||||
/*size=*/{sparse_dim, 1},
|
||||
indices.options().device(kCPU));
|
||||
// NB: must be blocking because this blob may be freed after this closure,
|
||||
// and non_blocking copy will see garbage.
|
||||
auto indices_mult = indices_mult_cpu.to(indices.device(), /*non_blocking=*/false);
|
||||
// Ideally we want matmul but matmul is slow on CPU Long and not implemented
|
||||
// on CUDA Long. So mul is faster.
|
||||
return indices.mul(indices_mult).sum(0);
|
||||
}
|
||||
}
|
||||
|
||||
// Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten Sparse Indices ],
|
||||
// except this one allows partial flatten: only flatten on specified dims. Note that
|
||||
// the flatten indices might be uncoalesced if dims_to_flatten.size() < sparse_dim.
|
||||
// Also if input indices is already coalesced, the flattened indices will also be sorted.
|
||||
//
|
||||
// args:
|
||||
// indices: sparse tensor indices
|
||||
// sizes: sparse tensor sizes
|
||||
// dims_to_flatten: a list of dim index to flatten
|
||||
//
|
||||
// Ex1:
|
||||
// indices = [[2, 4, 0],
|
||||
// [3, 1, 3]]
|
||||
// sizes = [2, 12]
|
||||
// dims_to_flatten = [0, 1]
|
||||
// new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3]
|
||||
//
|
||||
// Ex2:
|
||||
// dims_to_flatten = [1]
|
||||
// new_indices = [ 3, 1, 3 ] # uncoalesced
|
||||
Tensor flatten_indices_by_dims(const Tensor& indices, const IntArrayRef& sizes, const IntArrayRef& dims_to_flatten){
|
||||
Tensor new_indices = at::zeros({indices.size(1)}, indices.options());
|
||||
for (auto d : dims_to_flatten) {
|
||||
new_indices.mul_(sizes[d]);
|
||||
new_indices.add_(indices.select(0, d));
|
||||
}
|
||||
return new_indices;
|
||||
}
|
||||
|
||||
Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz) {
|
||||
/*
|
||||
Find the CSR representation for a row `indices` from the COO format
|
||||
Inputs:
|
||||
`indices` is the row pointer from COO indices
|
||||
`dim` is the row dimensionality
|
||||
`nnz` is the number of non-zeros
|
||||
|
||||
Output:
|
||||
`csr` is a compressed row array in a CSR format
|
||||
*/
|
||||
Tensor csr = native::zeros({dim + 1}, kLong);
|
||||
|
||||
// TODO: eliminate this conditional when zero-size dims supported correctly
|
||||
if (nnz > 0) {
|
||||
auto csr_accessor = csr.accessor<int64_t, 1>();
|
||||
// Convert the sparse matrix to CSR format
|
||||
at::parallel_for(0, nnz, 10000, [&](int64_t start, int64_t end) {
|
||||
int64_t h, hp0, hp1;
|
||||
for (auto i = start; i < end; i++) {
|
||||
hp0 = indices[i];
|
||||
hp1 = (i+1 == nnz) ? dim : indices[i+1];
|
||||
if (hp0 != hp1) {
|
||||
for (h = hp0; h < hp1; h++) {
|
||||
csr_accessor[h+1] = i+1;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
return csr;
|
||||
}
|
||||
|
||||
}} // namespace at::sparse
|
@ -2,15 +2,15 @@
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/SparseTensorImpl.h>
|
||||
#include <ATen/Parallel.h>
|
||||
|
||||
namespace at { namespace sparse {
|
||||
|
||||
// Just for documentary purposes
|
||||
using SparseTensor = Tensor;
|
||||
using LongTensor = Tensor;
|
||||
using IntTensor = Tensor;
|
||||
using SparseType = Type;
|
||||
|
||||
|
||||
// This is an internal utility function for getting at the SparseTensorImpl,
|
||||
// so that we can write sparse tensor specific accessors for special fields
|
||||
// in SparseTensor. You should only use this for writing low level
|
||||
@ -18,20 +18,20 @@ using SparseType = Type;
|
||||
// the low level setters/getters that were implemented using this.
|
||||
//
|
||||
// This may be called repeatedly, so make sure it's pretty cheap.
|
||||
inline SparseTensorImpl* get_sparse_impl(const SparseTensor& self) {
|
||||
inline SparseTensorImpl* get_sparse_impl(const SparseTensor& self) {
|
||||
AT_ASSERTM(self.is_sparse(), "_internal_get_SparseTensorImpl: not a sparse tensor");
|
||||
return static_cast<SparseTensorImpl*>(self.unsafeGetTensorImpl());
|
||||
}
|
||||
|
||||
// Takes indices and values and directly puts them into the sparse tensor, no
|
||||
// copy. This used to be called THSTensor_(_move)
|
||||
inline void alias_into_sparse(const SparseTensor& self, const LongTensor& indices, const Tensor& values) {
|
||||
inline void alias_into_sparse(const SparseTensor& self, const Tensor& indices, const Tensor& values) {
|
||||
get_sparse_impl(self)->set_indices_and_values_unsafe(indices, values);
|
||||
}
|
||||
|
||||
// Take indices and values and makes a (data) copy of them to put into the sparse
|
||||
// indices/values. This used to be called THSTensor_(_set)
|
||||
inline void copy_into_sparse(const SparseTensor& self, const LongTensor& indices, const Tensor& values, bool non_blocking) {
|
||||
inline void copy_into_sparse(const SparseTensor& self, const Tensor& indices, const Tensor& values, bool non_blocking) {
|
||||
alias_into_sparse(
|
||||
self,
|
||||
indices.to(self._indices().options(), non_blocking, /*copy=*/true),
|
||||
@ -58,7 +58,7 @@ inline Tensor new_values_with_size_of(const Tensor& values, int64_t nnz) {
|
||||
}
|
||||
|
||||
// NOTE [ Flatten Sparse Indices ]
|
||||
// This helper function flattens a sparse indices tensor (a LongTensor) into a 1D
|
||||
// This helper function flattens a sparse indices tensor (a Tensor) into a 1D
|
||||
// indices tensor. E.g.,
|
||||
// input = [[2, 4, 0],
|
||||
// [3, 1, 10]]
|
||||
@ -70,34 +70,7 @@ inline Tensor new_values_with_size_of(const Tensor& values, int64_t nnz) {
|
||||
// the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`.
|
||||
// if forceClone is true, the result will forced to be a clone of self.
|
||||
// if force_clone is true, the result will forced to be a clone of self.
|
||||
inline LongTensor flatten_indices(const Tensor& indices, IntArrayRef full_size, bool force_clone = false) {
|
||||
int64_t sparse_dim = indices.size(0);
|
||||
if (sparse_dim == 1) {
|
||||
if (force_clone) {
|
||||
return indices.squeeze(0).clone(at::MemoryFormat::Contiguous);
|
||||
} else {
|
||||
return indices.squeeze(0);
|
||||
}
|
||||
} else {
|
||||
std::vector<int64_t> indices_mult_cpu_vec;
|
||||
indices_mult_cpu_vec.reserve(sparse_dim);
|
||||
int64_t mult = 1;
|
||||
for (int64_t i = sparse_dim - 1; i >= 0; i--) {
|
||||
indices_mult_cpu_vec[i] = mult;
|
||||
mult *= full_size[i];
|
||||
}
|
||||
auto indices_mult_cpu = at::from_blob(
|
||||
indices_mult_cpu_vec.data(),
|
||||
/*size=*/{sparse_dim, 1},
|
||||
indices.options().device(kCPU));
|
||||
// NB: must be blocking because this blob may be freed after this closure,
|
||||
// and non_blocking copy will see garbage.
|
||||
auto indices_mult = indices_mult_cpu.to(indices.device(), /*non_blocking=*/false);
|
||||
// Ideally we want matmul but matmul is slow on CPU Long and not implemented
|
||||
// on CUDA Long. So mul is faster.
|
||||
return indices.mul(indices_mult).sum(0);
|
||||
}
|
||||
}
|
||||
TORCH_API Tensor flatten_indices(const Tensor& indices, IntArrayRef full_size, bool force_clone = false);
|
||||
|
||||
// Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten Sparse Indices ],
|
||||
// except this one allows partial flatten: only flatten on specified dims. Note that
|
||||
@ -119,13 +92,9 @@ inline LongTensor flatten_indices(const Tensor& indices, IntArrayRef full_size,
|
||||
// Ex2:
|
||||
// dims_to_flatten = [1]
|
||||
// new_indices = [ 3, 1, 3 ] # uncoalesced
|
||||
inline LongTensor flatten_indices_by_dims(const LongTensor& indices, const IntArrayRef& sizes, const IntArrayRef& dims_to_flatten){
|
||||
LongTensor new_indices = at::zeros({indices.size(1)}, indices.options());
|
||||
for (auto d : dims_to_flatten) {
|
||||
new_indices.mul_(sizes[d]);
|
||||
new_indices.add_(indices.select(0, d));
|
||||
}
|
||||
return new_indices;
|
||||
}
|
||||
TORCH_API Tensor flatten_indices_by_dims(const Tensor& indices, const IntArrayRef& sizes, const IntArrayRef& dims_to_flatten);
|
||||
|
||||
// Find the CSR representation for a row `indices` from the COO format
|
||||
TORCH_API Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz);
|
||||
|
||||
}} // namespace at::sparse
|
||||
|
@ -2979,6 +2979,17 @@
|
||||
- func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
|
||||
use_c10_dispatcher: full
|
||||
|
||||
- func: _sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor
|
||||
use_c10_dispatcher: full
|
||||
dispatch:
|
||||
SparseCPU: sparse_sparse_matmul_cpu
|
||||
SparseCUDA: sparse_sparse_matmul_cuda
|
||||
|
||||
- func: _sparse_matrix_mask_helper(Tensor t, Tensor mask_indices) -> Tensor
|
||||
dispatch:
|
||||
SparseCPU: sparse_matrix_mask_helper_cpu
|
||||
SparseCUDA: sparse_matrix_mask_helper_cuda
|
||||
|
||||
- func: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
|
||||
use_c10_dispatcher: full
|
||||
variants: function, method
|
||||
|
285
aten/src/ATen/native/sparse/SparseMatMul.cpp
Normal file
285
aten/src/ATen/native/sparse/SparseMatMul.cpp
Normal file
@ -0,0 +1,285 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Config.h>
|
||||
#include <ATen/NamedTensorUtils.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/SparseTensorImpl.h>
|
||||
#include <ATen/SparseTensorUtils.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace at { namespace native {
|
||||
|
||||
using namespace at::sparse;
|
||||
|
||||
/*
|
||||
This is an implementation of the SMMP algorithm:
|
||||
"Sparse Matrix Multiplication Package (SMMP)"
|
||||
|
||||
Randolph E. Bank and Craig C. Douglas
|
||||
https://doi.org/10.1007/BF02070824
|
||||
*/
|
||||
namespace {
|
||||
void csr_to_coo(const int64_t n_row, const int64_t Ap[], int64_t Bi[]) {
|
||||
/*
|
||||
Expands a compressed row pointer into a row indices array
|
||||
Inputs:
|
||||
`n_row` is the number of rows in `Ap`
|
||||
`Ap` is the row pointer
|
||||
|
||||
Output:
|
||||
`Bi` is the row indices
|
||||
*/
|
||||
for (int64_t i = 0; i < n_row; i++) {
|
||||
for (int64_t jj = Ap[i]; jj < Ap[i + 1]; jj++) {
|
||||
Bi[jj] = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int64_t _csr_matmult_maxnnz(
|
||||
const int64_t n_row,
|
||||
const int64_t n_col,
|
||||
const int64_t Ap[],
|
||||
const int64_t Aj[],
|
||||
const int64_t Bp[],
|
||||
const int64_t Bj[]) {
|
||||
/*
|
||||
Compute needed buffer size for matrix `C` in `C = A@B` operation.
|
||||
|
||||
The matrices should be in proper CSR structure, and their dimensions
|
||||
should be compatible.
|
||||
*/
|
||||
std::vector<int64_t> mask(n_col, -1);
|
||||
int64_t nnz = 0;
|
||||
for (int64_t i = 0; i < n_row; i++) {
|
||||
int64_t row_nnz = 0;
|
||||
|
||||
for (int64_t jj = Ap[i]; jj < Ap[i + 1]; jj++) {
|
||||
int64_t j = Aj[jj];
|
||||
for (int64_t kk = Bp[j]; kk < Bp[j + 1]; kk++) {
|
||||
int64_t k = Bj[kk];
|
||||
if (mask[k] != i) {
|
||||
mask[k] = i;
|
||||
row_nnz++;
|
||||
}
|
||||
}
|
||||
}
|
||||
int64_t next_nnz = nnz + row_nnz;
|
||||
nnz = next_nnz;
|
||||
}
|
||||
return nnz;
|
||||
}
|
||||
|
||||
template<class scalar_t>
|
||||
void _csr_matmult(
|
||||
const int64_t n_row,
|
||||
const int64_t n_col,
|
||||
const int64_t Ap[],
|
||||
const int64_t Aj[],
|
||||
const scalar_t Ax[],
|
||||
const int64_t Bp[],
|
||||
const int64_t Bj[],
|
||||
const scalar_t Bx[],
|
||||
int64_t Cp[],
|
||||
int64_t Cj[],
|
||||
scalar_t Cx[]) {
|
||||
/*
|
||||
Compute CSR entries for matrix C = A@B.
|
||||
|
||||
The matrices `A` and 'B' should be in proper CSR structure, and their dimensions
|
||||
should be compatible.
|
||||
|
||||
Inputs:
|
||||
`n_row` - number of row in A
|
||||
`n_col` - number of columns in B
|
||||
`Ap[n_row+1]` - row pointer
|
||||
`Aj[nnz(A)]` - column indices
|
||||
`Ax[nnz(A)] - nonzeros
|
||||
`Bp[?]` - row pointer
|
||||
`Bj[nnz(B)]` - column indices
|
||||
`Bx[nnz(B)]` - nonzeros
|
||||
Outputs:
|
||||
`Cp[n_row+1]` - row pointer
|
||||
`Cj[nnz(C)]` - column indices
|
||||
`Cx[nnz(C)]` - nonzeros
|
||||
|
||||
Note:
|
||||
Output arrays Cp, Cj, and Cx must be preallocated
|
||||
*/
|
||||
std::vector<int64_t> next(n_col, -1);
|
||||
std::vector<scalar_t> sums(n_col, 0);
|
||||
|
||||
int64_t nnz = 0;
|
||||
|
||||
Cp[0] = 0;
|
||||
|
||||
for (int64_t i = 0; i < n_row; i++) {
|
||||
int64_t head = -2;
|
||||
int64_t length = 0;
|
||||
|
||||
int64_t jj_start = Ap[i];
|
||||
int64_t jj_end = Ap[i + 1];
|
||||
for (int64_t jj = jj_start; jj < jj_end; jj++) {
|
||||
int64_t j = Aj[jj];
|
||||
scalar_t v = Ax[jj];
|
||||
|
||||
int64_t kk_start = Bp[j];
|
||||
int64_t kk_end = Bp[j + 1];
|
||||
for (int64_t kk = kk_start; kk < kk_end; kk++) {
|
||||
int64_t k = Bj[kk];
|
||||
|
||||
sums[k] += v * Bx[kk];
|
||||
|
||||
if (next[k] == -1) {
|
||||
next[k] = head;
|
||||
head = k;
|
||||
length++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t jj = 0; jj < length; jj++) {
|
||||
Cj[nnz] = head;
|
||||
Cx[nnz] = sums[head];
|
||||
nnz++;
|
||||
|
||||
int64_t temp = head;
|
||||
head = next[head];
|
||||
|
||||
next[temp] = -1; // clear arrays
|
||||
sums[temp] = 0;
|
||||
}
|
||||
|
||||
Cp[i + 1] = nnz;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename scalar_t>
|
||||
void sparse_matmul_kernel(
|
||||
Tensor& output,
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2) {
|
||||
/*
|
||||
Computes the sparse-sparse matrix multiplication between `mat1` and `mat2`, which are sparse tensors in COO format.
|
||||
*/
|
||||
|
||||
auto M = mat1.size(0);
|
||||
auto K = mat1.size(1);
|
||||
auto N = mat2.size(1);
|
||||
|
||||
auto mat1_indices_ = mat1._indices().contiguous();
|
||||
auto mat1_values = mat1._values().contiguous();
|
||||
Tensor mat1_row_indices = mat1_indices_.select(0, 0);
|
||||
Tensor mat1_col_indices = mat1_indices_.select(0, 1);
|
||||
|
||||
Tensor mat1_indptr = coo_to_csr(mat1_row_indices.data_ptr<int64_t>(), M, mat1._nnz());
|
||||
|
||||
auto mat2_indices_ = mat2._indices().contiguous();
|
||||
auto mat2_values = mat2._values().contiguous();
|
||||
Tensor mat2_row_indices = mat2_indices_.select(0, 0);
|
||||
Tensor mat2_col_indices = mat2_indices_.select(0, 1);
|
||||
|
||||
Tensor mat2_indptr = coo_to_csr(mat2_row_indices.data_ptr<int64_t>(), K, mat2._nnz());
|
||||
|
||||
auto nnz = _csr_matmult_maxnnz(M, N, mat1_indptr.data_ptr<int64_t>(), mat1_col_indices.data_ptr<int64_t>(),
|
||||
mat2_indptr.data_ptr<int64_t>(), mat2_col_indices.data_ptr<int64_t>());
|
||||
|
||||
auto output_indices = output._indices();
|
||||
auto output_values = output._values();
|
||||
|
||||
Tensor output_indptr = at::empty({M + 1}, kLong);
|
||||
at::native::resize_output(output_indices, {2, nnz});
|
||||
at::native::resize_output(output_values, nnz);
|
||||
|
||||
Tensor output_row_indices = output_indices.select(0, 0);
|
||||
Tensor output_col_indices = output_indices.select(0, 1);
|
||||
|
||||
_csr_matmult(M, N, mat1_indptr.data_ptr<int64_t>(), mat1_col_indices.data_ptr<int64_t>(), mat1_values.data_ptr<scalar_t>(),
|
||||
mat2_indptr.data_ptr<int64_t>(), mat2_col_indices.data_ptr<int64_t>(), mat2_values.data_ptr<scalar_t>(),
|
||||
output_indptr.data_ptr<int64_t>(), output_col_indices.data_ptr<int64_t>(), output_values.data_ptr<scalar_t>());
|
||||
|
||||
csr_to_coo(M, output_indptr.data_ptr<int64_t>(), output_row_indices.data_ptr<int64_t>());
|
||||
}
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
Tensor sparse_matrix_mask_helper_cpu(
|
||||
const SparseTensor& t,
|
||||
const Tensor& mask_indices
|
||||
) {
|
||||
/*
|
||||
This is a helper function which filter values from `t._values()` using the `mask_indices`.
|
||||
This CPU implementation uses a simple hash_map to filter values by matching the `mask_indices`
|
||||
with the indices at tensor input `t`.
|
||||
|
||||
Inputs:
|
||||
`t` - tensor input
|
||||
`mask_indices` - mask indices tensor
|
||||
*/
|
||||
int64_t r_nnz = mask_indices.size(1);
|
||||
auto t_v = t._values();
|
||||
Tensor r_values = at::zeros({r_nnz}, t_v.options());
|
||||
auto t_i = t._indices();
|
||||
auto t_nnz = t._nnz();
|
||||
|
||||
std::unordered_map<int64_t, int64_t> t_flatten_indices = std::unordered_map<int64_t, int64_t>{};
|
||||
|
||||
// Step 1: flatten the sparse indices `t._indices()` tensor and then map this flatten value `index` to the original position `i`
|
||||
auto t_indices_accessor = t_i.accessor<int64_t, 2>();
|
||||
for(int64_t i = 0; i < t_nnz; i++) {
|
||||
int64_t index = t_indices_accessor[0][i] * t.size(1) + t_indices_accessor[1][i];
|
||||
t_flatten_indices[index] = i;
|
||||
}
|
||||
|
||||
// Step 2: Filter `t._values()` values by matching the flatten `mask_indices` with the flatten `t._indices()` using the
|
||||
// hash_map `t_flatten_indices`
|
||||
AT_DISPATCH_FLOATING_TYPES(r_values.scalar_type(), "_sparse_matrix_mask", [&] {
|
||||
auto r_values_accessor = r_values.accessor<scalar_t, 1>();
|
||||
auto t_values = t_v.accessor<scalar_t, 1>();
|
||||
auto mask_indices_accessor = mask_indices.accessor<int64_t, 2>();
|
||||
at::parallel_for(0, r_nnz, 0, [&](int64_t start, int64_t end) {
|
||||
for (auto i = start; i < end; i++) {
|
||||
auto x = mask_indices_accessor[0][i];
|
||||
auto y = mask_indices_accessor[1][i];
|
||||
int64_t index = (x * t.size(1) + y);
|
||||
auto iter = t_flatten_indices.find(index);
|
||||
if (iter != t_flatten_indices.end()) {
|
||||
assert(iter->second < t_nnz);
|
||||
assert(i < r_nnz);
|
||||
r_values_accessor[i] = t_values[ iter->second ];
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
return r_values;
|
||||
}
|
||||
|
||||
Tensor sparse_sparse_matmul_cpu(const Tensor& mat1_, const Tensor& mat2_) {
|
||||
TORCH_INTERNAL_ASSERT(mat1_.is_sparse());
|
||||
TORCH_INTERNAL_ASSERT(mat2_.is_sparse());
|
||||
TORCH_CHECK(mat1_.dim() == 2);
|
||||
TORCH_CHECK(mat2_.dim() == 2);
|
||||
TORCH_CHECK(mat1_.dense_dim() == 0, "sparse_sparse_matmul_cpu: scalar values expected, got ", mat1_.dense_dim(), "D values");
|
||||
TORCH_CHECK(mat2_.dense_dim() == 0, "sparse_sparse_matmul_cpu: scalar values expected, got ", mat2_.dense_dim(), "D values");
|
||||
|
||||
TORCH_CHECK(
|
||||
mat1_.size(1) == mat2_.size(0), "mat1 and mat2 shapes cannot be multiplied (",
|
||||
mat1_.size(0), "x", mat1_.size(1), " and ", mat2_.size(0), "x", mat2_.size(1), ")");
|
||||
|
||||
TORCH_CHECK(mat1_.scalar_type() == mat2_.scalar_type(),
|
||||
"mat1 dtype ", mat1_.scalar_type(), " does not match mat2 dtype ", mat2_.scalar_type());
|
||||
|
||||
auto output = at::native::empty_like(mat1_);
|
||||
output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
|
||||
sparse_matmul_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
|
||||
});
|
||||
return output;
|
||||
}
|
||||
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
@ -95,7 +95,7 @@ SparseTensor new_with_dims_and_tensor_sparse(
|
||||
int64_t sparse_dim,
|
||||
int64_t dense_dim,
|
||||
ArrayRef<int64_t> size,
|
||||
const LongTensor& indices,
|
||||
const Tensor& indices,
|
||||
const Tensor& values,
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
@ -106,7 +106,7 @@ SparseTensor new_with_dims_and_tensor_sparse(
|
||||
// NOTE: There is no guarantee that `indices` and `values` don't contain AutogradMeta. However,
|
||||
// we want to maintain the invariant that `indices_` and `values_` of a sparse tensor don't
|
||||
// contain AutogradMeta, and to achieve that we shallow-copy `indices` and `values` here.
|
||||
auto indices_shallow_copy = LongTensor(indices.unsafeGetTensorImpl()->shallow_copy_and_detach(
|
||||
auto indices_shallow_copy = Tensor(indices.unsafeGetTensorImpl()->shallow_copy_and_detach(
|
||||
/*version_counter=*/indices.unsafeGetTensorImpl()->version_counter(),
|
||||
/*allow_tensor_metadata_change=*/true));
|
||||
auto values_shallow_copy = Tensor(values.unsafeGetTensorImpl()->shallow_copy_and_detach(
|
||||
@ -163,11 +163,11 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values_, const Ten
|
||||
// If the indices has elements in it, we infer the minimum sparse dimension sizes
|
||||
// as the max value of each dim in indices.
|
||||
// NB: It used to keepdim. I think that was wrong.
|
||||
LongTensor min_indices = std::get</* values */ 0>(indices.min(/* dim */ 1, /* keepdim */ false));
|
||||
LongTensor computed_indices_sizes = std::get</* values */ 0>(indices.max(/* dim */ 1, /* keepdim */ false));
|
||||
Tensor min_indices = std::get</* values */ 0>(indices.min(/* dim */ 1, /* keepdim */ false));
|
||||
Tensor computed_indices_sizes = std::get</* values */ 0>(indices.max(/* dim */ 1, /* keepdim */ false));
|
||||
computed_indices_sizes.add_(1); // len = max_index + 1
|
||||
LongTensor cpu_min_indices = min_indices.to(at::DeviceType::CPU);
|
||||
LongTensor cpu_computed_indices_sizes = computed_indices_sizes.to(at::DeviceType::CPU);
|
||||
Tensor cpu_min_indices = min_indices.to(at::DeviceType::CPU);
|
||||
Tensor cpu_computed_indices_sizes = computed_indices_sizes.to(at::DeviceType::CPU);
|
||||
auto cpu_min_indices_accessor = cpu_min_indices.accessor<int64_t, 1>();
|
||||
auto cpu_computed_indices_sizes_accessor = cpu_computed_indices_sizes.accessor<int64_t, 1>();
|
||||
for (int64_t d = 0; d < sparse_dim; d++) {
|
||||
@ -206,9 +206,9 @@ void _validate_sparse_coo_tensor_args(const Tensor& indices, const Tensor& value
|
||||
|
||||
// Check to make sure all indices are within the boundaries of `size`
|
||||
if (indices.numel() > 0) {
|
||||
LongTensor min_indices = std::get</* values */ 0>(indices.min(/* dim */ 1, /* keepdim */ false));
|
||||
LongTensor max_indices = std::get</* values */ 0>(indices.max(/* dim */ 1, /* keepdim */ false));
|
||||
LongTensor cpu_min_indices, cpu_max_indices;
|
||||
Tensor min_indices = std::get</* values */ 0>(indices.min(/* dim */ 1, /* keepdim */ false));
|
||||
Tensor max_indices = std::get</* values */ 0>(indices.max(/* dim */ 1, /* keepdim */ false));
|
||||
Tensor cpu_min_indices, cpu_max_indices;
|
||||
if (indices.is_cuda()) {
|
||||
cpu_min_indices = min_indices.to(at::DeviceType::CPU);
|
||||
cpu_max_indices = max_indices.to(at::DeviceType::CPU);
|
||||
@ -317,7 +317,7 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
|
||||
if (nz.size(1) == 0) {
|
||||
return new_with_dims_sparse(sparse_dim, dims - sparse_dim, sizes, optTypeMetaToScalarType(sparse_options.dtype_opt()), sparse_options.layout_opt(), sparse_options.device_opt(), sparse_options.pinned_memory_opt());
|
||||
}
|
||||
LongTensor indices;
|
||||
Tensor indices;
|
||||
if (sparse_dim == dims) {
|
||||
indices = nz.clone();
|
||||
} else {
|
||||
@ -375,23 +375,23 @@ SparseTensor coalesce_sparse_cpu(const SparseTensor& self) {
|
||||
return dst;
|
||||
}
|
||||
|
||||
LongTensor indices = self._indices();
|
||||
Tensor indices = self._indices();
|
||||
Tensor values = self._values().contiguous();
|
||||
int64_t sparse_dim = self.sparse_dim();
|
||||
int64_t dense_dim = self.dense_dim();
|
||||
int64_t nnz = self._nnz();
|
||||
|
||||
LongTensor indices_scalar = flatten_indices(indices, self.sizes());
|
||||
Tensor indices_scalar = flatten_indices(indices, self.sizes());
|
||||
|
||||
SparseTensor dst = new_sparse(optTypeMetaToScalarType(self.options().dtype_opt()), self.options().layout_opt(), self.options().device_opt(), self.options().pinned_memory_opt());
|
||||
get_sparse_impl(dst)->resize_(sparse_dim, dense_dim, self.sizes());
|
||||
// TODO: is there a more idiomatic way to do this?
|
||||
LongTensor newIndices = at::empty(indices.sizes(), indices.options());
|
||||
Tensor newIndices = at::empty(indices.sizes(), indices.options());
|
||||
Tensor newValues = at::empty(values.sizes(), values.options());
|
||||
alias_into_sparse(dst, newIndices, newValues);
|
||||
|
||||
LongTensor indicesBuffer;
|
||||
LongTensor indicesPermutation;
|
||||
Tensor indicesBuffer;
|
||||
Tensor indicesPermutation;
|
||||
std::tie(indicesBuffer, indicesPermutation) = indices_scalar.sort(0);
|
||||
// NB: The accessor accesses here rely on self._nnz() > 0 (tested earlier in this function)
|
||||
auto newIndicesAccessor = newIndices.accessor<int64_t, 2>();
|
||||
@ -446,7 +446,7 @@ void inline sparse_mask_out_cpu_kernel(
|
||||
const Tensor& t,
|
||||
const int64_t r_nnz,
|
||||
const int64_t sparse_dim,
|
||||
const LongTensor& mask_indices
|
||||
const Tensor& mask_indices
|
||||
) {
|
||||
auto r_values_accessor = r_values.accessor<scalar_t, 1>();
|
||||
auto mask_indices_accessor = mask_indices.accessor<int64_t, 2>();
|
||||
@ -476,7 +476,7 @@ SparseTensor& sparse_mask_out_cpu(SparseTensor& r, const Tensor& t, const Sparse
|
||||
}
|
||||
int64_t dim = t.dim();
|
||||
int64_t sparse_dim = mask.sparse_dim();
|
||||
LongTensor mask_indices = mask._indices();
|
||||
Tensor mask_indices = mask._indices();
|
||||
Tensor mask_values = mask._values();
|
||||
Tensor r_values = at::empty(mask_values.sizes(), r._values().options());
|
||||
alias_into_sparse(r, mask_indices.clone(), r_values);
|
||||
@ -492,7 +492,7 @@ SparseTensor& sparse_mask_out_cpu(SparseTensor& r, const Tensor& t, const Sparse
|
||||
|
||||
// Get a flattened sparse indices, similar to NOTE [ Flatten Sparse Indices ].
|
||||
// Keeping this implementation because it is faster than flatten_indices()
|
||||
LongTensor indices = at::zeros({mask._nnz()}, mask_indices.options());
|
||||
Tensor indices = at::zeros({mask._nnz()}, mask_indices.options());
|
||||
for (int64_t d = 0; d < mask.sparse_dim(); d++) {
|
||||
indices.mul_(mask.size(d));
|
||||
indices.add_(mask_indices.select(0, d));
|
||||
|
@ -21,26 +21,6 @@ using namespace at::sparse;
|
||||
// --------------------------------------------------------------------
|
||||
|
||||
namespace {
|
||||
LongTensor _to_csr(const int64_t* indices, int64_t dim, int64_t nnz) {
|
||||
LongTensor csr = native::zeros({dim + 1}, kLong);
|
||||
|
||||
// TODO: eliminate this conditional when zero-size dims supported correctly
|
||||
if (nnz > 0) {
|
||||
auto csr_accessor = csr.accessor<int64_t, 1>();
|
||||
// Convert the sparse matrix to CSR format
|
||||
at::parallel_for(0, nnz, 10000, [&](int64_t start, int64_t end) {
|
||||
int64_t h, hp0, hp1;
|
||||
for (auto i = start; i < end; i++) {
|
||||
hp0 = indices[i];
|
||||
hp1 = (i+1 == nnz) ? dim : indices[i+1];
|
||||
if (hp0 != hp1) for (h = hp0; h < hp1; h++) {
|
||||
csr_accessor[h+1] = i+1;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
return csr;
|
||||
}
|
||||
|
||||
inline SparseTensor get_result_tensor_for_unary_op(const SparseTensor& input) {
|
||||
if (c10::isIntegralType(input.scalar_type(), /*includeBool=*/true)) {
|
||||
@ -453,7 +433,7 @@ SparseTensor& add_out_sparse_contiguous(SparseTensor& r, const SparseTensor& t,
|
||||
bool coalesced = t.is_coalesced() && src.is_coalesced();
|
||||
int64_t sparse_dim = src.sparse_dim();
|
||||
|
||||
LongTensor r_indices = at::empty({src.sparse_dim(), max_nnz}, t._indices().options());
|
||||
Tensor r_indices = at::empty({src.sparse_dim(), max_nnz}, t._indices().options());
|
||||
|
||||
Tensor t_values = t._values().to(commonDtype);
|
||||
Tensor s_values = src._values().to(commonDtype);
|
||||
@ -548,7 +528,7 @@ SparseTensor& add_out_sparse_non_contiguous(SparseTensor& r, const SparseTensor&
|
||||
}
|
||||
});
|
||||
|
||||
LongTensor r_indices = at::cat({t._indices(), src._indices()}, 1);
|
||||
Tensor r_indices = at::cat({t._indices(), src._indices()}, 1);
|
||||
Tensor r_values = at::cat({t_values, s_values}, 0).to(r.scalar_type());
|
||||
alias_into_sparse(r, r_indices, r_values);
|
||||
|
||||
@ -640,7 +620,7 @@ Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, const SparseTen
|
||||
r.resize_as_(dense);
|
||||
SparseTensor sparse = sparse_.coalesce();
|
||||
|
||||
LongTensor indices = sparse._indices();
|
||||
Tensor indices = sparse._indices();
|
||||
Tensor values = sparse._values();
|
||||
int64_t nDim = dense.dim();
|
||||
int64_t nDimI = sparse.sparse_dim();
|
||||
@ -721,9 +701,9 @@ SparseTensor& mul_out_sparse_cpu(SparseTensor& r, const Tensor& t_, const Tensor
|
||||
int64_t t_nnz = t._nnz(), s_nnz = src._nnz();
|
||||
int64_t max_nnz = std::min(t_nnz, s_nnz); // multiply by zero is zero, and can be dropped
|
||||
int64_t sparse_dim = src.sparse_dim();
|
||||
LongTensor t_indices = t._indices();
|
||||
LongTensor src_indices = src._indices();
|
||||
LongTensor r_indices = at::empty({sparse_dim, max_nnz}, t_indices.options());
|
||||
Tensor t_indices = t._indices();
|
||||
Tensor src_indices = src._indices();
|
||||
Tensor r_indices = at::empty({sparse_dim, max_nnz}, t_indices.options());
|
||||
|
||||
int64_t match, d;
|
||||
int64_t r_i = 0, t_i = 0, s_i = 0;
|
||||
@ -889,7 +869,7 @@ Tensor& s_addmm_out_sparse_dense_cpu(
|
||||
return r;
|
||||
}
|
||||
|
||||
LongTensor indices = sparse_._indices();
|
||||
Tensor indices = sparse_._indices();
|
||||
Tensor values = sparse_._values();
|
||||
|
||||
AT_DISPATCH_ALL_TYPES(
|
||||
@ -1021,13 +1001,13 @@ SparseTensor& hspmm_out_sparse_cpu(SparseTensor& r, const SparseTensor& sparse_,
|
||||
return r;
|
||||
}
|
||||
|
||||
LongTensor indices = at::empty({1, nnz}, at::initialTensorOptions().dtype(kLong));
|
||||
Tensor indices = at::empty({1, nnz}, at::initialTensorOptions().dtype(kLong));
|
||||
|
||||
// Initialize the sparse matrix that will be used with spaddmm to send rows
|
||||
// from the dense matrix to rows of the output's value tensor
|
||||
SparseTensor newSparse = sparse.clone();
|
||||
LongTensor spIndices = newSparse._indices();
|
||||
LongTensor valueIndices = spIndices.select(0, 0);
|
||||
Tensor spIndices = newSparse._indices();
|
||||
Tensor valueIndices = spIndices.select(0, 0);
|
||||
|
||||
// Compute output indices
|
||||
auto valueIndices_accessor = valueIndices.accessor<int64_t, 1>();
|
||||
@ -1108,19 +1088,19 @@ SparseTensor& _sspaddmm_out_cpu(
|
||||
"sspaddmm: Argument #1: Expected dim 1 size ", dim_k, ", got ", t.size(1));
|
||||
|
||||
int64_t nnz = sparse._nnz();
|
||||
// We have to make indices contiguous as we use indices.data_ptr in _to_csr which assumes row-contiguous storage
|
||||
LongTensor indices = sparse._indices().contiguous();
|
||||
// We have to make indices contiguous as we use indices.data_ptr in _to_csr which assumes row-contiguous storage
|
||||
Tensor indices = sparse._indices().contiguous();
|
||||
Tensor values = sparse._values();
|
||||
|
||||
LongTensor csr = _to_csr(indices.data_ptr<int64_t>(), dim_i, nnz);
|
||||
Tensor csr = coo_to_csr(indices.data_ptr<int64_t>(), dim_i, nnz);
|
||||
|
||||
int64_t t_nnz = t._nnz();
|
||||
int64_t r_nnz = nnz * dim_k + t_nnz;
|
||||
LongTensor newi = at::empty({2, r_nnz}, kLong);
|
||||
LongTensor newv = native::zeros({r_nnz}, values.options());
|
||||
Tensor newi = at::empty({2, r_nnz}, kLong);
|
||||
Tensor newv = native::zeros({r_nnz}, values.options());
|
||||
|
||||
if (t_nnz != 0) {
|
||||
LongTensor narrowi = newi.narrow(1, 0, t_nnz);
|
||||
Tensor narrowi = newi.narrow(1, 0, t_nnz);
|
||||
Tensor narrowv = newv.narrow(0, 0, t_nnz);
|
||||
|
||||
narrowi.copy_(t._indices());
|
||||
@ -1230,7 +1210,7 @@ Tensor _sparse_sum(const SparseTensor& input, IntArrayRef dims_to_sum) {
|
||||
auto dims_to_sum_v = dims_to_sum.vec();
|
||||
maybe_wrap_dims(dims_to_sum_v, input_dim);
|
||||
|
||||
LongTensor indices = input._indices();
|
||||
Tensor indices = input._indices();
|
||||
Tensor values = input._values();
|
||||
IntArrayRef sizes = input.sizes();
|
||||
const int64_t sparse_dim = input.sparse_dim();
|
||||
@ -1266,7 +1246,7 @@ Tensor _sparse_sum(const SparseTensor& input, IntArrayRef dims_to_sum) {
|
||||
}
|
||||
else { // !sum_all_sparse_dim
|
||||
// new indices
|
||||
LongTensor new_indices;
|
||||
Tensor new_indices;
|
||||
if (sparse_dims_to_sum_size == 0) {
|
||||
new_indices = indices.clone(at::MemoryFormat::Contiguous);
|
||||
}
|
||||
@ -1348,7 +1328,7 @@ Tensor _sparse_sum_backward_cpu(const Tensor& grad_, const SparseTensor& input_,
|
||||
auto dims_to_sum_v = dims_to_sum.vec();
|
||||
maybe_wrap_dims(dims_to_sum_v, input_dim);
|
||||
|
||||
LongTensor input_indices = input._indices();
|
||||
Tensor input_indices = input._indices();
|
||||
Tensor input_values = input._values();
|
||||
IntArrayRef input_sizes = input.sizes();
|
||||
const int64_t input_sparse_dim = input.sparse_dim();
|
||||
@ -1389,7 +1369,7 @@ Tensor _sparse_sum_backward_cpu(const Tensor& grad_, const SparseTensor& input_,
|
||||
else {
|
||||
TORCH_CHECK(grad_.is_sparse(), "_sparse_sum_backward_cpu: expected grad_ Tensor to be sparse, but got dense");
|
||||
auto grad = grad_.coalesce();
|
||||
LongTensor grad_indices = grad._indices();
|
||||
Tensor grad_indices = grad._indices();
|
||||
Tensor grad_values = grad._values();
|
||||
const int64_t grad_sparse_dim = grad.sparse_dim();
|
||||
const int64_t grad_nnz = grad._nnz();
|
||||
@ -1533,12 +1513,12 @@ Tensor& bmm_out_sparse_cpu(Tensor& result, const SparseTensor& self, const Tenso
|
||||
SparseTensor self_coalesced = self.coalesce();
|
||||
|
||||
int64_t nnz = self_coalesced._nnz();
|
||||
LongTensor indices = self_coalesced._indices();
|
||||
Tensor indices = self_coalesced._indices();
|
||||
Tensor values = self_coalesced._values();
|
||||
|
||||
LongTensor indices_dim0 = indices[0];
|
||||
Tensor indices_dim0 = indices[0];
|
||||
auto indices_dim0_accessor = indices_dim0.accessor<int64_t, 1>();
|
||||
LongTensor indices_dim1_dim2 = indices.slice(0, 1, 3);
|
||||
Tensor indices_dim1_dim2 = indices.slice(0, 1, 3);
|
||||
|
||||
int64_t dim_i = self_coalesced.size(1);
|
||||
int64_t dim_j = self_coalesced.size(2);
|
||||
@ -1588,7 +1568,7 @@ Tensor& bmm_out_sparse_cpu(Tensor& result, const SparseTensor& self, const Tenso
|
||||
// Create tensors to view just the current set of matrices
|
||||
const Tensor dense_matrix = mat2[cur_mat_num];
|
||||
Tensor result_matrix = result[cur_mat_num];
|
||||
LongTensor sparse_indices = indices_dim1_dim2.slice(1, mat_el_begin_idx, mat_el_end_idx);
|
||||
Tensor sparse_indices = indices_dim1_dim2.slice(1, mat_el_begin_idx, mat_el_end_idx);
|
||||
Tensor sparse_values = values.slice(0, mat_el_begin_idx, mat_el_end_idx);
|
||||
int64_t sparse_nnz = mat_el_end_idx - mat_el_begin_idx;
|
||||
|
||||
|
@ -22,7 +22,7 @@ SparseTensor& sparse_mask_out_cuda(SparseTensor& r, const Tensor& t, const Spars
|
||||
if (mask._nnz() == 0) {
|
||||
return r.zero_();
|
||||
}
|
||||
LongTensor mask_indices = mask._indices();
|
||||
Tensor mask_indices = mask._indices();
|
||||
Tensor mask_values = mask._values();
|
||||
Tensor r_values = at::empty(mask_values.sizes(), r._values().options());
|
||||
alias_into_sparse(r, mask_indices.clone(at::MemoryFormat::Contiguous), r_values);
|
||||
@ -33,7 +33,7 @@ SparseTensor& sparse_mask_out_cuda(SparseTensor& r, const Tensor& t, const Spars
|
||||
|
||||
// Get a flattened sparse indices, similar to NOTE [ Flatten Sparse Indices ].
|
||||
// Keeping this implementation because it is faster than flatten_indices()
|
||||
LongTensor indices = at::zeros({mask._nnz()}, mask_indices.options());
|
||||
Tensor indices = at::zeros({mask._nnz()}, mask_indices.options());
|
||||
for (int64_t d = 0; d < mask.sparse_dim(); d++) {
|
||||
indices.mul_(mask.size(d));
|
||||
// This used to use a buffer but I deoptimized it
|
||||
|
@ -52,10 +52,10 @@ SparseTensor coalesce_sparse_cuda(const SparseTensor& self) {
|
||||
|
||||
// indices will be modified by Thrust, so we have to clone or use new storage
|
||||
// here.
|
||||
LongTensor indices1D = flatten_indices(self._indices(), self.sizes(), true);
|
||||
Tensor indices1D = flatten_indices(self._indices(), self.sizes(), true);
|
||||
|
||||
LongTensor origIndices = at::empty({nnz}, self._indices().options());
|
||||
LongTensor uniqueOffsets = at::empty({nnz}, self._indices().options());
|
||||
Tensor origIndices = at::empty({nnz}, self._indices().options());
|
||||
Tensor uniqueOffsets = at::empty({nnz}, self._indices().options());
|
||||
|
||||
typedef thrust::device_ptr<int64_t> thrust_ptr;
|
||||
thrust_ptr indicesIter(indices1D.data_ptr<int64_t>());
|
||||
@ -126,14 +126,14 @@ SparseTensor coalesce_sparse_cuda(const SparseTensor& self) {
|
||||
|
||||
////////////////////////////////////////////////////////////
|
||||
// unflatten indices if necessary
|
||||
LongTensor newIndices;
|
||||
Tensor newIndices;
|
||||
if (sparse_dim == 1) {
|
||||
newIndices = indices1D;
|
||||
} else {
|
||||
newIndices = at::empty({sparse_dim, newNnz}, origIndices.options());
|
||||
for (int64_t d = sparse_dim - 1; d >= 0; d--) {
|
||||
// NB: Not a select, so I can preserve the outer dimension
|
||||
LongTensor indicesSlice = newIndices.narrow(0, d, 1);
|
||||
Tensor indicesSlice = newIndices.narrow(0, d, 1);
|
||||
// Note for the porting guide: THCTensor_(copy) does NOT do normal
|
||||
// broadcasting logic; instead, it will blast the elements from one
|
||||
// to the other so long as the numel is the same
|
||||
|
@ -39,9 +39,9 @@ using at::cuda::detail::getTensorInfo;
|
||||
// --------------------------------------------------------------------
|
||||
|
||||
namespace {
|
||||
IntTensor _to_csr_int(const LongTensor& rowIndices, int64_t dim, int64_t nnz) {
|
||||
IntTensor csr = at::empty({dim+1}, CUDA(kInt));
|
||||
IntTensor rowIndicesInt = at::empty({rowIndices.size(0)}, CUDA(kInt));
|
||||
Tensor _to_csr_int(const Tensor& rowIndices, int64_t dim, int64_t nnz) {
|
||||
Tensor csr = at::empty({dim+1}, CUDA(kInt));
|
||||
Tensor rowIndicesInt = at::empty({rowIndices.size(0)}, CUDA(kInt));
|
||||
rowIndicesInt.copy_(rowIndices);
|
||||
sparse::cuda::Xcoo2csr(rowIndicesInt.data_ptr<int32_t>(), nnz, dim, csr.data_ptr<int32_t>());
|
||||
return csr;
|
||||
@ -52,13 +52,13 @@ namespace {
|
||||
// wired at all)
|
||||
|
||||
template <typename scalar_t>
|
||||
void s_addmm_out_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int64_t k, Tensor& r_, Scalar beta, const Tensor& t, Scalar alpha, LongTensor& indices, Tensor& values, const Tensor& dense) {
|
||||
void s_addmm_out_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int64_t k, Tensor& r_, Scalar beta, const Tensor& t, Scalar alpha, Tensor& indices, Tensor& values, const Tensor& dense) {
|
||||
scalar_t cast_beta = beta.to<scalar_t>();
|
||||
scalar_t cast_alpha = alpha.to<scalar_t>();
|
||||
LongTensor rowIndices = indices.select(0, 0);
|
||||
LongTensor colIndices = indices.select(0, 1);
|
||||
IntTensor csr = _to_csr_int(rowIndices, m, nnz);
|
||||
IntTensor colIndicesInt = at::empty({colIndices.size(0)}, indices.options().dtype(kInt));
|
||||
Tensor rowIndices = indices.select(0, 0);
|
||||
Tensor colIndices = indices.select(0, 1);
|
||||
Tensor csr = _to_csr_int(rowIndices, m, nnz);
|
||||
Tensor colIndicesInt = at::empty({colIndices.size(0)}, indices.options().dtype(kInt));
|
||||
colIndicesInt.copy_(colIndices);
|
||||
|
||||
Tensor r__;
|
||||
@ -147,7 +147,7 @@ Tensor& s_addmm_out_sparse_dense_cuda(Tensor& r_, const Tensor& t, const SparseT
|
||||
SparseTensor sparse = sparse_.coalesce();
|
||||
|
||||
int64_t nnz = sparse._nnz();
|
||||
LongTensor indices = sparse._indices();
|
||||
Tensor indices = sparse._indices();
|
||||
Tensor values = sparse._values();
|
||||
|
||||
|
||||
@ -247,7 +247,7 @@ SparseTensor& hspmm_out_sparse_cuda(SparseTensor& r_, const SparseTensor& sparse
|
||||
|
||||
int64_t nnz = sparse._nnz();
|
||||
|
||||
LongTensor indices = at::empty({1, nnz}, CUDA(kLong));
|
||||
Tensor indices = at::empty({1, nnz}, CUDA(kLong));
|
||||
// create values in column-major format to avoid copying in spaddmm
|
||||
Tensor values = at::empty({n, nnz}, dense.options());
|
||||
values.transpose_(0, 1);
|
||||
@ -255,8 +255,8 @@ SparseTensor& hspmm_out_sparse_cuda(SparseTensor& r_, const SparseTensor& sparse
|
||||
// why does sparse need to be cloned? If this is really necessary maybe we
|
||||
// need to fuse this with newCoalesce
|
||||
SparseTensor newSparse = sparse.clone();
|
||||
LongTensor spIndices = newSparse._indices();
|
||||
LongTensor dstIndices = spIndices.select(0, 0);
|
||||
Tensor spIndices = newSparse._indices();
|
||||
Tensor dstIndices = spIndices.select(0, 0);
|
||||
// Save destination indices to output hybrid tensor
|
||||
indices.copy_(dstIndices);
|
||||
// Replace destination indices with 0, 1, 2, 3, ... and compute output values
|
||||
@ -320,7 +320,7 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT
|
||||
r.copy_(dense_buffer);
|
||||
}
|
||||
|
||||
LongTensor indices = sparse._indices();
|
||||
Tensor indices = sparse._indices();
|
||||
int64_t nDim = dense.dim();
|
||||
int64_t nDimI = sparse.sparse_dim();
|
||||
|
||||
@ -363,7 +363,7 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT
|
||||
}
|
||||
} else {
|
||||
|
||||
LongTensor indices1D = flatten_indices(indices, sparse.sizes(), 0);
|
||||
Tensor indices1D = flatten_indices(indices, sparse.sizes(), 0);
|
||||
|
||||
// FIXME: at some point we can wrap the scale into indexAdd
|
||||
// NB: Purposely not inplace!
|
||||
@ -431,8 +431,8 @@ SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const
|
||||
// rather than merging them. This removes the need to synchronously fetch nnz
|
||||
// at the end of the operation, at the cost of having a non-coalesced result.
|
||||
// This trade-off is preferable for the common use-case of gradient accumulation.
|
||||
LongTensor t_indices_ = t._indices();
|
||||
LongTensor s_indices_ = src._indices();
|
||||
Tensor t_indices_ = t._indices();
|
||||
Tensor s_indices_ = src._indices();
|
||||
|
||||
Tensor t_values_ = t._values().to(commonDtype);
|
||||
Tensor s_values_ = src._values().to(commonDtype);
|
||||
@ -443,7 +443,7 @@ SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const
|
||||
s_values_ = s_values_.mul(value);
|
||||
}
|
||||
});
|
||||
LongTensor r_indices_ = at::cat({t_indices_, s_indices_}, 1);
|
||||
Tensor r_indices_ = at::cat({t_indices_, s_indices_}, 1);
|
||||
Tensor r_values_ = at::cat({t_values_, s_values_}, 0);
|
||||
|
||||
if (r_.scalar_type() != commonDtype) {
|
||||
@ -501,11 +501,11 @@ SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, cons
|
||||
int64_t sparse_dim = src.sparse_dim();
|
||||
auto commonDtype = at::result_type(t, src);
|
||||
TORCH_CHECK(canCast(commonDtype, r_.scalar_type()), "Can't convert result type ", commonDtype, " to output ", r_.scalar_type());
|
||||
LongTensor t_indices_ = t._indices().contiguous();
|
||||
Tensor t_indices_ = t._indices().contiguous();
|
||||
Tensor t_values_ = t._values().to(commonDtype);
|
||||
LongTensor s_indices_ = src._indices().contiguous();
|
||||
Tensor s_indices_ = src._indices().contiguous();
|
||||
Tensor s_values_ = src._values().to(commonDtype);
|
||||
LongTensor r_indices_ = at::empty({sparse_dim, max_nnz}, t_indices_.options());
|
||||
Tensor r_indices_ = at::empty({sparse_dim, max_nnz}, t_indices_.options());
|
||||
r_.resize_as_(src);
|
||||
|
||||
Tensor r_values_ = new_values_with_size_of(t_values_, max_nnz).zero_();
|
||||
@ -518,7 +518,7 @@ SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, cons
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
|
||||
TORCH_CHECK(cuda::getApplyGrid(valueSize, grid, curDevice), "mul: Argument #0: tensor too large or too many dimensions");
|
||||
|
||||
LongTensor resultNnz = at::empty({1}, CUDA(kLong));
|
||||
Tensor resultNnz = at::empty({1}, CUDA(kLong));
|
||||
AT_DISPATCH_ALL_TYPES_AND(
|
||||
at::ScalarType::Half, commonDtype, "mul_out_sparse_cuda", [&] {
|
||||
apply::valueSparseIntersectionKernel<TensorMulOp<scalar_t>, uint64_t, scalar_t>
|
||||
@ -541,7 +541,7 @@ SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, cons
|
||||
get_sparse_impl(r_)->set_indices_and_values_unsafe(r_indices_, r_values_);
|
||||
|
||||
// sync! (surely there is a more idiomatic way to do this...)
|
||||
LongTensor cpu_resultNnz = at::empty({1}, CPU(kLong));
|
||||
Tensor cpu_resultNnz = at::empty({1}, CPU(kLong));
|
||||
cpu_resultNnz.copy_(resultNnz);
|
||||
get_sparse_impl(r_)->set_nnz_and_narrow(cpu_resultNnz.accessor<int64_t, 1>()[0]);
|
||||
|
||||
@ -601,7 +601,7 @@ Tensor _sparse_sum_backward_cuda(const Tensor& grad_, const SparseTensor& input_
|
||||
auto dims_to_sum_v = dims_to_sum.vec();
|
||||
maybe_wrap_dims(dims_to_sum_v, input_dim);
|
||||
|
||||
LongTensor input_indices = input._indices();
|
||||
Tensor input_indices = input._indices();
|
||||
Tensor input_values = input._values();
|
||||
IntArrayRef input_sizes = input.sizes();
|
||||
const int64_t input_sparse_dim = input.sparse_dim();
|
||||
@ -641,7 +641,7 @@ Tensor _sparse_sum_backward_cuda(const Tensor& grad_, const SparseTensor& input_
|
||||
else {
|
||||
TORCH_CHECK(grad_.is_sparse(), "_sparse_sum_backward_cuda: expected grad_ Tensor to be sparse, but got dense");
|
||||
auto grad = grad_.coalesce();
|
||||
LongTensor grad_indices = grad._indices();
|
||||
Tensor grad_indices = grad._indices();
|
||||
Tensor grad_values = grad._values();
|
||||
const int64_t grad_sparse_dim = grad.sparse_dim();
|
||||
const int64_t grad_nnz = grad._nnz();
|
||||
@ -679,7 +679,7 @@ Tensor _sparse_sum_backward_cuda(const Tensor& grad_, const SparseTensor& input_
|
||||
thrust_ptr input_indices_iter(input_indices_1D.data_ptr<int64_t>());
|
||||
|
||||
// store lower_bound of input indices at grad indices
|
||||
LongTensor input_indices_pos = at::empty_like(input_indices_1D, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
Tensor input_indices_pos = at::empty_like(input_indices_1D, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
thrust_ptr input_indices_pos_iter(input_indices_pos.data_ptr<int64_t>());
|
||||
thrust::lower_bound(policy,
|
||||
grad_indices_iter, grad_indices_iter + grad_nnz,
|
||||
@ -767,7 +767,7 @@ __global__ void search_end_matrix_indices_cuda_kernel(
|
||||
|
||||
// Search through a 1D tensor of sorted sparse matrix
|
||||
// indices to find the end index for each matrix
|
||||
void search_end_matrix_indices(int64_t* mat_el_end_indices, int64_t num_matrices, const LongTensor& indices_1D) {
|
||||
void search_end_matrix_indices(int64_t* mat_el_end_indices, int64_t num_matrices, const Tensor& indices_1D) {
|
||||
int curDevice = -1;
|
||||
cudaGetDevice(&curDevice);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
|
||||
@ -855,10 +855,10 @@ Tensor& _bmm_out_sparse_cuda(Tensor& result, const SparseTensor& self, const Ten
|
||||
SparseTensor self_coalesced = coalesce_sparse_cuda(self);
|
||||
|
||||
int64_t nnz = self_coalesced._nnz();
|
||||
LongTensor indices = self_coalesced._indices();
|
||||
Tensor indices = self_coalesced._indices();
|
||||
Tensor values = self_coalesced._values();
|
||||
|
||||
LongTensor indices_dim0 = indices[0];
|
||||
Tensor indices_dim0 = indices[0];
|
||||
|
||||
// Need to convert dim1 and dim2 indices to 32-bit since cusparseSpMM
|
||||
// only supports 32-bit indices
|
||||
|
896
aten/src/ATen/native/sparse/cuda/SparseMatMul.cu
Normal file
896
aten/src/ATen/native/sparse/cuda/SparseMatMul.cu
Normal file
@ -0,0 +1,896 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Config.h>
|
||||
#include <ATen/NamedTensorUtils.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/SparseTensorImpl.h>
|
||||
#include <ATen/SparseTensorUtils.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <type_traits>
|
||||
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/for_each.h>
|
||||
#include <thrust/sequence.h>
|
||||
|
||||
#include <THC/THCTensorMathPointwise.cuh>
|
||||
#include <THC/THCThrustAllocator.cuh>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/CUDAUtils.h>
|
||||
#include <cusparse.h>
|
||||
#include <ATen/native/sparse/cuda/SparseCUDABlas.cuh>
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/host_vector.h>
|
||||
#include <thrust/iterator/counting_iterator.h>
|
||||
#include <thrust/functional.h>
|
||||
#include <thrust/binary_search.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
#include <thrust/iterator/discard_iterator.h>
|
||||
|
||||
|
||||
#if defined(__CUDACC__) && (CUSPARSE_VERSION >= 11000)
|
||||
#define IS_CUSPARSE11_AVAILABLE() 1
|
||||
#else
|
||||
#define IS_CUSPARSE11_AVAILABLE() 0
|
||||
#endif
|
||||
|
||||
#if IS_CUSPARSE11_AVAILABLE()
|
||||
#include <library_types.h>
|
||||
#endif
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace at::sparse;
|
||||
|
||||
Tensor _to_csr_int(const Tensor& rowIndices, int64_t dim, int64_t nnz) {
|
||||
Tensor csr = at::empty({dim + 1}, CUDA(kInt));
|
||||
Tensor rowIndicesInt = at::empty({rowIndices.size(0)}, CUDA(kInt));
|
||||
rowIndicesInt.copy_(rowIndices);
|
||||
sparse::cuda::Xcoo2csr(
|
||||
rowIndicesInt.data_ptr<int32_t>(), nnz, dim, csr.data_ptr<int32_t>());
|
||||
return csr;
|
||||
}
|
||||
|
||||
int confirm_mult_size(const std::vector<int>& mat1_size, const std::vector<int>& mat2_size) {
|
||||
TORCH_CHECK(
|
||||
mat1_size[1] == mat2_size[0],
|
||||
"mat1 and mat2 shapes cannot be multiplied (",
|
||||
mat1_size[0],
|
||||
"x",
|
||||
mat1_size[1],
|
||||
" and ",
|
||||
mat2_size[0],
|
||||
"x",
|
||||
mat2_size[1],
|
||||
")");
|
||||
return mat1_size[1];
|
||||
}
|
||||
|
||||
void create_general_description_(cusparseMatDescr_t& description_) {
|
||||
TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&description_));
|
||||
TORCH_CUDASPARSE_CHECK(cusparseSetMatType(description_, CUSPARSE_MATRIX_TYPE_GENERAL));
|
||||
TORCH_CUDASPARSE_CHECK(cusparseSetMatIndexBase(description_, CUSPARSE_INDEX_BASE_ZERO));
|
||||
}
|
||||
|
||||
// csrMatrixRef is used to have a representation of a raw CSR matrix representation
|
||||
// comming from `sparse_sparse_matmul_cuda_kernel` function.
|
||||
// Moreover this implements a RAII guard for a cusparse descriptor
|
||||
template<class scalar_t>
|
||||
struct csrMatrixRef {
|
||||
int* csr_indices_{nullptr};
|
||||
int* csr_pointers_{nullptr};
|
||||
scalar_t* csr_values_{nullptr};
|
||||
int nnz_{0};
|
||||
std::vector<int> size_{};
|
||||
|
||||
#if IS_CUSPARSE11_AVAILABLE()
|
||||
cusparseSpMatDescr_t description_{0};
|
||||
#else
|
||||
cusparseMatDescr_t description_{0};
|
||||
#endif
|
||||
|
||||
csrMatrixRef() {
|
||||
#if !IS_CUSPARSE11_AVAILABLE()
|
||||
create_general_description_(description_);
|
||||
#endif
|
||||
}
|
||||
|
||||
csrMatrixRef(
|
||||
int* csr_indices,
|
||||
int* csr_pointers,
|
||||
scalar_t* csr_values,
|
||||
int nnz,
|
||||
const std::vector<int>& size)
|
||||
: csr_indices_{csr_indices},
|
||||
csr_pointers_{csr_pointers},
|
||||
csr_values_{csr_values},
|
||||
nnz_{nnz},
|
||||
size_{size} {
|
||||
#if IS_CUSPARSE11_AVAILABLE()
|
||||
cudaDataType cuda_data_type;
|
||||
if ( std::is_same<float, scalar_t>::value ) {
|
||||
cuda_data_type = CUDA_R_32F;
|
||||
} else if ( std::is_same<double, scalar_t>::value) {
|
||||
cuda_data_type = CUDA_R_64F;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Tensor types must be either float32 or float64");
|
||||
}
|
||||
TORCH_CUDASPARSE_CHECK(cusparseCreateCsr(
|
||||
&description_,
|
||||
this->size(0),
|
||||
this->size(1),
|
||||
this->nnz_,
|
||||
this->csr_pointers_,
|
||||
this->csr_indices_,
|
||||
this->csr_values_,
|
||||
CUSPARSE_INDEX_32I,
|
||||
CUSPARSE_INDEX_32I,
|
||||
CUSPARSE_INDEX_BASE_ZERO,
|
||||
cuda_data_type));
|
||||
#else
|
||||
create_general_description_(description_);
|
||||
#endif
|
||||
}
|
||||
|
||||
~csrMatrixRef() {
|
||||
#if IS_CUSPARSE11_AVAILABLE()
|
||||
cusparseDestroySpMat(description_);
|
||||
#else
|
||||
cusparseDestroyMatDescr(description_);
|
||||
#endif
|
||||
}
|
||||
|
||||
int size(int index) const {
|
||||
return size_.at(index);
|
||||
}
|
||||
};
|
||||
|
||||
// csrOutput is used to represent the output for `CusparseMatrixMultiplyOp`
|
||||
// Note that `csrOutput` is different from `csrMatrixRef` and the purpose
|
||||
// of this was to have a materialized version of a CSR matrix.
|
||||
// Moreover this implements a RAII guard for a cusparse descriptor
|
||||
struct csrOutput {
|
||||
Tensor csr_indices_{};
|
||||
Tensor csr_pointers_{};
|
||||
at::Tensor csr_values_{};
|
||||
int nnz_{0};
|
||||
std::vector<int> size_;
|
||||
|
||||
cusparseMatDescr_t description_{0};
|
||||
|
||||
csrOutput(const std::vector<int> &size) : size_{size} {
|
||||
create_general_description_(description_);
|
||||
}
|
||||
|
||||
~csrOutput() {
|
||||
cusparseDestroyMatDescr(description_);
|
||||
}
|
||||
|
||||
int size(int index) const {
|
||||
return size_.at(index);
|
||||
}
|
||||
};
|
||||
|
||||
#if IS_CUSPARSE11_AVAILABLE()
|
||||
|
||||
// RAII guard helps to support cuSparse 11 API for `A @ B` operation
|
||||
// This generic template exists because with cuSparse the `scalar_t` type could be a double or float
|
||||
template <class scalar_t>
|
||||
struct CusparseMatrixMultiplyOp {
|
||||
|
||||
cusparseSpGEMMDescr_t spgemmDesc;
|
||||
|
||||
CusparseMatrixMultiplyOp() {
|
||||
static_assert(std::is_same<float, scalar_t>::value || std::is_same<double, scalar_t>::value,
|
||||
"cusparse csr sparse-sparse MM only supports data type of float and double.");
|
||||
// SpGEMM Computation
|
||||
TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_createDescr(&spgemmDesc));
|
||||
}
|
||||
|
||||
~CusparseMatrixMultiplyOp() {
|
||||
// destroy matrix/vector descriptors
|
||||
cusparseSpGEMM_destroyDescr(spgemmDesc);
|
||||
}
|
||||
|
||||
csrOutput operator ()(
|
||||
const csrMatrixRef<scalar_t>& A,
|
||||
const csrMatrixRef<scalar_t>& B,
|
||||
Tensor& output_values,
|
||||
Tensor& output_indices) {
|
||||
const int A_num_rows = A.size(0);
|
||||
const int A_num_cols = A.size(1);
|
||||
const int A_num_nnz = A.nnz_;
|
||||
|
||||
const int B_num_rows = B.size(0);
|
||||
const int B_num_cols = B.size(1);
|
||||
const int B_num_nnz = B.nnz_;
|
||||
|
||||
int* dA_csrOffsets = A.csr_pointers_;
|
||||
int* dA_columns = A.csr_indices_;
|
||||
scalar_t* dA_values = A.csr_values_;
|
||||
|
||||
int* dB_csrOffsets = B.csr_pointers_;
|
||||
int* dB_columns = B.csr_indices_;
|
||||
scalar_t* dB_values = B.csr_values_;
|
||||
|
||||
cudaDataType computeType;
|
||||
if ( std::is_same<float, scalar_t>::value ) {
|
||||
computeType = CUDA_R_32F;
|
||||
} else if ( std::is_same<double, scalar_t>::value) {
|
||||
computeType = CUDA_R_64F;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Tensor types must be either float32 or float64");
|
||||
}
|
||||
csrOutput out({A.size(0), B.size(1)});
|
||||
|
||||
out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));
|
||||
|
||||
int* dC_csrOffsets = out.csr_pointers_.data_ptr<int>();
|
||||
int* dC_columns = nullptr;
|
||||
scalar_t* dC_values = nullptr;
|
||||
|
||||
scalar_t alpha = 1.0f;
|
||||
scalar_t beta = 0.0f;
|
||||
cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
|
||||
csrMatrixRef<scalar_t> C(
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
/*nnz*/0,
|
||||
{A_num_rows, B_num_cols}
|
||||
);
|
||||
|
||||
//--------------------------------------------------------------------------
|
||||
// CUSPARSE APIs
|
||||
cusparseHandle_t handle = at::cuda::getCurrentCUDASparseHandle();
|
||||
void *dBuffer1 = NULL, *dBuffer2 = NULL;
|
||||
size_t bufferSize1 = 0, bufferSize2 = 0;
|
||||
|
||||
cusparseSpMatDescr_t matA = A.description_;
|
||||
cusparseSpMatDescr_t matB = B.description_;
|
||||
cusparseSpMatDescr_t matC = C.description_;
|
||||
//--------------------------------------------------------------------------
|
||||
|
||||
// ask bufferSize1 bytes for external memory
|
||||
TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_workEstimation(
|
||||
handle,
|
||||
opA,
|
||||
opB,
|
||||
&alpha,
|
||||
matA,
|
||||
matB,
|
||||
&beta,
|
||||
matC,
|
||||
computeType,
|
||||
CUSPARSE_SPGEMM_DEFAULT,
|
||||
spgemmDesc,
|
||||
&bufferSize1,
|
||||
NULL));
|
||||
|
||||
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
|
||||
|
||||
at::DataPtr dataPtr1 = allocator.allocate(bufferSize1);
|
||||
dBuffer1 = dataPtr1.get();
|
||||
// inspect the matrices A and B to understand the memory requiremnent for
|
||||
// the next step
|
||||
TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_workEstimation(
|
||||
handle,
|
||||
opA,
|
||||
opB,
|
||||
&alpha,
|
||||
matA,
|
||||
matB,
|
||||
&beta,
|
||||
matC,
|
||||
computeType,
|
||||
CUSPARSE_SPGEMM_DEFAULT,
|
||||
spgemmDesc,
|
||||
&bufferSize1,
|
||||
dBuffer1));
|
||||
|
||||
// ask bufferSize2 bytes for external memory
|
||||
TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_compute(
|
||||
handle,
|
||||
opA,
|
||||
opB,
|
||||
&alpha,
|
||||
matA,
|
||||
matB,
|
||||
&beta,
|
||||
matC,
|
||||
computeType,
|
||||
CUSPARSE_SPGEMM_DEFAULT,
|
||||
spgemmDesc,
|
||||
&bufferSize2,
|
||||
NULL));
|
||||
|
||||
at::DataPtr dataPtr2 = allocator.allocate(bufferSize2);
|
||||
dBuffer2 = dataPtr2.get();
|
||||
|
||||
// compute the intermediate product of A * B
|
||||
TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_compute(
|
||||
handle,
|
||||
opA,
|
||||
opB,
|
||||
&alpha,
|
||||
matA,
|
||||
matB,
|
||||
&beta,
|
||||
matC,
|
||||
computeType,
|
||||
CUSPARSE_SPGEMM_DEFAULT,
|
||||
spgemmDesc,
|
||||
&bufferSize2,
|
||||
dBuffer2));
|
||||
// get matrix C non-zero entries C_num_nnz1
|
||||
int64_t C_num_rows1, C_num_cols1, C_num_nnz1;
|
||||
TORCH_CUDASPARSE_CHECK(
|
||||
cusparseSpMatGetSize(matC, &C_num_rows1, &C_num_cols1, &C_num_nnz1));
|
||||
// allocate matrix C
|
||||
// allocate C offsets
|
||||
out.nnz_ = C_num_nnz1;
|
||||
|
||||
out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
|
||||
out.csr_values_ = at::empty({out.nnz_}, output_values.options());
|
||||
dC_columns = out.csr_indices_.data_ptr<int>();
|
||||
dC_values = out.csr_values_.data_ptr<scalar_t>();
|
||||
|
||||
// update matC with the new pointers
|
||||
TORCH_CUDASPARSE_CHECK(
|
||||
cusparseCsrSetPointers(matC, dC_csrOffsets, dC_columns, dC_values));
|
||||
|
||||
// copy the final products to the matrix C
|
||||
TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_copy(
|
||||
handle,
|
||||
opA,
|
||||
opB,
|
||||
&alpha,
|
||||
matA,
|
||||
matB,
|
||||
&beta,
|
||||
matC,
|
||||
computeType,
|
||||
CUSPARSE_SPGEMM_DEFAULT,
|
||||
spgemmDesc));
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template struct CusparseMatrixMultiplyOp<float>;
|
||||
|
||||
template struct CusparseMatrixMultiplyOp<double>;
|
||||
|
||||
#else // if not IS_CUSPARSE11_AVAILABLE()
|
||||
|
||||
using DcsrMatrixRef = csrMatrixRef<double>;
|
||||
using ScsrMatrixRef = csrMatrixRef<float>;
|
||||
|
||||
// RAII guard helps to support cuSparse 10 API for `A @ B` operation
|
||||
// This generic template exists because with cuSparse the `scalar_t` type could be a double or float
|
||||
template <class scalar_t>
|
||||
struct CusparseMatrixMultiplyOp {
|
||||
csrOutput operator()(
|
||||
const csrMatrixRef<scalar_t>& lhs,
|
||||
const csrMatrixRef<scalar_t>& rhs,
|
||||
Tensor &output_values,
|
||||
Tensor &output_indices)
|
||||
{
|
||||
TORCH_INTERNAL_ASSERT(false, "cusparse csr sparse-sparse MM only supports data type of float and double.");
|
||||
}
|
||||
};
|
||||
|
||||
// Specializacion for `A @ B` operation for double values with cuSparse
|
||||
template<> struct CusparseMatrixMultiplyOp<double> {
|
||||
csrgemm2Info_t gemm2Info_;
|
||||
|
||||
CusparseMatrixMultiplyOp() {
|
||||
TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_));
|
||||
}
|
||||
~CusparseMatrixMultiplyOp() {
|
||||
cusparseDestroyCsrgemm2Info(gemm2Info_);
|
||||
}
|
||||
|
||||
csrOutput operator ()(
|
||||
const DcsrMatrixRef& lhs,
|
||||
const DcsrMatrixRef& rhs,
|
||||
Tensor &output_values,
|
||||
Tensor &output_indices) {
|
||||
double alpha = 1.0;
|
||||
DcsrMatrixRef empty;
|
||||
return Dgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices);
|
||||
}
|
||||
|
||||
csrOutput Dgemm2(
|
||||
const DcsrMatrixRef& A,
|
||||
const DcsrMatrixRef& B,
|
||||
const DcsrMatrixRef& C,
|
||||
const double* alpha,
|
||||
const double* beta,
|
||||
Tensor &output_values,
|
||||
Tensor &output_indices) {
|
||||
void* buffer_{nullptr};
|
||||
cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle();
|
||||
TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST));
|
||||
|
||||
csrOutput out({A.size(0), B.size(1)});
|
||||
int innerSize = confirm_mult_size(A.size_, B.size_);
|
||||
out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));
|
||||
|
||||
// Compute needed buffer size
|
||||
size_t new_bubber_sz;
|
||||
TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2_bufferSizeExt(
|
||||
cusparseHandle_,
|
||||
out.size(0),
|
||||
out.size(1),
|
||||
innerSize,
|
||||
alpha,
|
||||
A.description_,
|
||||
A.nnz_,
|
||||
A.csr_pointers_,
|
||||
A.csr_indices_,
|
||||
B.description_,
|
||||
B.nnz_,
|
||||
B.csr_pointers_,
|
||||
B.csr_indices_,
|
||||
beta,
|
||||
C.description_,
|
||||
C.nnz_,
|
||||
C.csr_pointers_,
|
||||
C.csr_indices_,
|
||||
gemm2Info_,
|
||||
&new_bubber_sz));
|
||||
|
||||
// (Re)allocate buffer if needed
|
||||
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
|
||||
at::DataPtr data_ptr = allocator.allocate(new_bubber_sz);
|
||||
buffer_ = data_ptr.get();
|
||||
|
||||
// Find the resulting non-zero pattern.
|
||||
TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz(
|
||||
cusparseHandle_,
|
||||
out.size(0),
|
||||
out.size(1),
|
||||
innerSize,
|
||||
A.description_,
|
||||
A.nnz_,
|
||||
A.csr_pointers_,
|
||||
A.csr_indices_,
|
||||
B.description_,
|
||||
B.nnz_,
|
||||
B.csr_pointers_,
|
||||
B.csr_indices_,
|
||||
C.description_,
|
||||
C.nnz_,
|
||||
C.csr_pointers_,
|
||||
C.csr_indices_,
|
||||
out.description_,
|
||||
out.csr_pointers_.data_ptr<int>(),
|
||||
&out.nnz_,
|
||||
gemm2Info_,
|
||||
buffer_));
|
||||
|
||||
out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
|
||||
out.csr_values_ = at::empty({out.nnz_}, output_values.options());
|
||||
|
||||
// Perform the gemm2 operation for doubles
|
||||
// out = alpha ∗ A ∗ B + beta ∗ C
|
||||
TORCH_CUDASPARSE_CHECK(cusparseDcsrgemm2(
|
||||
cusparseHandle_,
|
||||
out.size(0),
|
||||
out.size(1),
|
||||
innerSize,
|
||||
alpha,
|
||||
A.description_,
|
||||
A.nnz_,
|
||||
A.csr_values_,
|
||||
A.csr_pointers_,
|
||||
A.csr_indices_,
|
||||
B.description_,
|
||||
B.nnz_,
|
||||
B.csr_values_,
|
||||
B.csr_pointers_,
|
||||
B.csr_indices_,
|
||||
beta,
|
||||
C.description_,
|
||||
C.nnz_,
|
||||
C.csr_values_,
|
||||
C.csr_pointers_,
|
||||
C.csr_indices_,
|
||||
out.description_,
|
||||
out.csr_values_.data_ptr<double>(),
|
||||
out.csr_pointers_.data_ptr<int>(),
|
||||
out.csr_indices_.data_ptr<int>(),
|
||||
gemm2Info_,
|
||||
buffer_));
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// Specializacion for `A @ B` operation for float values with cuSparse
|
||||
template<> struct CusparseMatrixMultiplyOp<float> {
|
||||
csrgemm2Info_t gemm2Info_;
|
||||
|
||||
CusparseMatrixMultiplyOp() {
|
||||
TORCH_CUDASPARSE_CHECK(cusparseCreateCsrgemm2Info(&gemm2Info_));
|
||||
|
||||
}
|
||||
~CusparseMatrixMultiplyOp() {
|
||||
cusparseDestroyCsrgemm2Info(gemm2Info_);
|
||||
}
|
||||
csrOutput operator()(
|
||||
const ScsrMatrixRef& lhs,
|
||||
const ScsrMatrixRef& rhs,
|
||||
Tensor &output_values,
|
||||
Tensor &output_indices) {
|
||||
float alpha = 1.0;
|
||||
ScsrMatrixRef empty;
|
||||
return Sgemm2(lhs, rhs, empty, &alpha, nullptr, output_values, output_indices);
|
||||
}
|
||||
|
||||
csrOutput Sgemm2(
|
||||
const ScsrMatrixRef& A,
|
||||
const ScsrMatrixRef& B,
|
||||
const ScsrMatrixRef& C,
|
||||
const float* alpha,
|
||||
const float* beta,
|
||||
Tensor &output_values,
|
||||
Tensor &output_indices) {
|
||||
void* buffer_{nullptr};
|
||||
cusparseHandle_t cusparseHandle_ = at::cuda::getCurrentCUDASparseHandle();
|
||||
TORCH_CUDASPARSE_CHECK(cusparseSetPointerMode(cusparseHandle_, CUSPARSE_POINTER_MODE_HOST));
|
||||
|
||||
csrOutput out({A.size(0), B.size(1)});
|
||||
|
||||
int innerSize = confirm_mult_size(A.size_, B.size_);
|
||||
|
||||
out.csr_pointers_ = at::empty({out.size(0) + 1}, output_indices.options().dtype(kInt));
|
||||
|
||||
// Compute needed buffer size
|
||||
size_t new_bubber_sz;
|
||||
TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2_bufferSizeExt(
|
||||
cusparseHandle_,
|
||||
out.size(0),
|
||||
out.size(1),
|
||||
innerSize,
|
||||
alpha,
|
||||
A.description_,
|
||||
A.nnz_,
|
||||
A.csr_pointers_,
|
||||
A.csr_indices_,
|
||||
B.description_,
|
||||
B.nnz_,
|
||||
B.csr_pointers_,
|
||||
B.csr_indices_,
|
||||
beta,
|
||||
C.description_,
|
||||
C.nnz_,
|
||||
C.csr_pointers_,
|
||||
C.csr_indices_,
|
||||
gemm2Info_,
|
||||
&new_bubber_sz));
|
||||
|
||||
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
|
||||
at::DataPtr data_ptr = allocator.allocate(new_bubber_sz);
|
||||
buffer_ = data_ptr.get();
|
||||
|
||||
// Find the resulting non-zero pattern.
|
||||
TORCH_CUDASPARSE_CHECK(cusparseXcsrgemm2Nnz(
|
||||
cusparseHandle_,
|
||||
out.size(0),
|
||||
out.size(1),
|
||||
innerSize,
|
||||
A.description_,
|
||||
A.nnz_,
|
||||
A.csr_pointers_,
|
||||
A.csr_indices_,
|
||||
B.description_,
|
||||
B.nnz_,
|
||||
B.csr_pointers_,
|
||||
B.csr_indices_,
|
||||
C.description_,
|
||||
C.nnz_,
|
||||
C.csr_pointers_,
|
||||
C.csr_indices_,
|
||||
out.description_,
|
||||
out.csr_pointers_.data_ptr<int>(),
|
||||
&out.nnz_,
|
||||
gemm2Info_,
|
||||
buffer_));
|
||||
|
||||
out.csr_indices_ = at::empty({out.nnz_}, output_indices.options().dtype(kInt));
|
||||
out.csr_values_ = at::empty({out.nnz_}, output_values.options());
|
||||
|
||||
// Perform the gemm2 operation for doubles
|
||||
// out = alpha ∗ A ∗ B + beta ∗ C
|
||||
TORCH_CUDASPARSE_CHECK(cusparseScsrgemm2(
|
||||
cusparseHandle_,
|
||||
out.size(0),
|
||||
out.size(1),
|
||||
innerSize,
|
||||
alpha,
|
||||
A.description_,
|
||||
A.nnz_,
|
||||
A.csr_values_,
|
||||
A.csr_pointers_,
|
||||
A.csr_indices_,
|
||||
B.description_,
|
||||
B.nnz_,
|
||||
B.csr_values_,
|
||||
B.csr_pointers_,
|
||||
B.csr_indices_,
|
||||
beta,
|
||||
C.description_,
|
||||
C.nnz_,
|
||||
C.csr_values_,
|
||||
C.csr_pointers_,
|
||||
C.csr_indices_,
|
||||
out.description_,
|
||||
out.csr_values_.data_ptr<float>(),
|
||||
out.csr_pointers_.data_ptr<int>(),
|
||||
out.csr_indices_.data_ptr<int>(),
|
||||
gemm2Info_,
|
||||
buffer_));
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
#endif // IS_CUSPARSE11_AVAILABLE()
|
||||
|
||||
template <typename scalar_t>
|
||||
void sparse_sparse_matmul_cuda_kernel(
|
||||
Tensor& result,
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2) {
|
||||
|
||||
static_assert(std::is_same<float, scalar_t>::value || std::is_same<double, scalar_t>::value,
|
||||
"sparse_sparse_matmul_cuda_kernel only supports float and double value types");
|
||||
|
||||
Tensor mat1_indices_ = mat1._indices().contiguous();
|
||||
Tensor mat1_values = mat1._values().contiguous();
|
||||
|
||||
Tensor mat1_row_indices = mat1_indices_.select(0, 0);
|
||||
Tensor mat1_col_indices = mat1_indices_.select(0, 1);
|
||||
|
||||
Tensor mat1_indptr = _to_csr_int(mat1_row_indices, mat1.size(0), mat1._nnz());
|
||||
|
||||
Tensor mat1_indices = at::empty(
|
||||
{mat1_col_indices.size(0)}, mat1_col_indices.options().dtype(kInt));
|
||||
|
||||
mat1_indices.copy_(mat1_col_indices);
|
||||
|
||||
Tensor mat2_indices_ = mat2._indices().contiguous();
|
||||
Tensor mat2_values = mat2._values().contiguous();
|
||||
Tensor mat2_row_indices = mat2_indices_.select(0, 0);
|
||||
Tensor mat2_col_indices = mat2_indices_.select(0, 1);
|
||||
|
||||
Tensor mat2_indptr = _to_csr_int(mat2_row_indices, mat2.size(0), mat2._nnz());
|
||||
Tensor mat2_indices = at::empty({mat2_col_indices.size(0)}, mat2_col_indices.options().dtype(kInt));
|
||||
mat2_indices.copy_(mat2_col_indices);
|
||||
|
||||
auto m = mat1.size(0);
|
||||
auto k1 = mat1.size(1);
|
||||
|
||||
auto k2 = mat2.size(0);
|
||||
auto n = mat2.size(1);
|
||||
TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (k1 <= INT_MAX),
|
||||
"At the moment, cusparseDcsrgemm2 only supports m, n, k, nnz with the bound [val] <= ", INT_MAX, ".",
|
||||
"If you need this, please file an issue on GitHub."
|
||||
);
|
||||
auto output_indices = result._indices();
|
||||
auto output_values = result._values();
|
||||
|
||||
if ((k1 == 0 && k2 == 0) || (n == 0 && m == 0)) {
|
||||
output_indices.zero_();
|
||||
output_values.zero_();
|
||||
return;
|
||||
}
|
||||
|
||||
csrMatrixRef<scalar_t> csr_mat1(
|
||||
mat1_indices.data_ptr<int>(),
|
||||
mat1_indptr.data_ptr<int>(),
|
||||
mat1_values.data_ptr<scalar_t>(),
|
||||
(int)mat1._nnz(),
|
||||
{(int)mat1.size(0), (int)mat1.size(1)});
|
||||
|
||||
csrMatrixRef<scalar_t> csr_mat2(
|
||||
mat2_indices.data_ptr<int>(),
|
||||
mat2_indptr.data_ptr<int>(),
|
||||
mat2_values.data_ptr<scalar_t>(),
|
||||
(int)mat2._nnz(),
|
||||
{(int)mat2.size(0), (int)mat2.size(1)});
|
||||
|
||||
// Sparse matrix multiplication
|
||||
CusparseMatrixMultiplyOp<scalar_t> op;
|
||||
csrOutput csr_output = op(csr_mat1, csr_mat2, output_values, output_indices);
|
||||
auto nnz = csr_output.nnz_;
|
||||
|
||||
output_values.set_(csr_output.csr_values_);
|
||||
output_indices.resize_({2, nnz});
|
||||
auto output_indices_accessor = output_indices.packed_accessor<int64_t, 2>();
|
||||
|
||||
auto csr_output_pointers_accessor =
|
||||
csr_output.csr_pointers_.packed_accessor<int, 1>();
|
||||
|
||||
auto csr_output_ind_accessor =
|
||||
csr_output.csr_indices_.packed_accessor<int, 1>();
|
||||
|
||||
auto major_dim = result.size(0);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
|
||||
auto policy = thrust::cuda::par(allocator).on(stream);
|
||||
|
||||
// Filling the COO row indices
|
||||
thrust::for_each(
|
||||
policy,
|
||||
thrust::make_counting_iterator(int64_t(0)),
|
||||
thrust::make_counting_iterator(int64_t(major_dim)),
|
||||
[output_indices_accessor,
|
||||
csr_output_pointers_accessor,
|
||||
major_dim,
|
||||
nnz] __device__(int64_t i) {
|
||||
auto Ap = csr_output_pointers_accessor.data();
|
||||
int64_t* indices_row = output_indices_accessor[0].data();
|
||||
|
||||
for (int jj = Ap[i]; jj < Ap[i + 1]; jj++) {
|
||||
indices_row[jj] = i;
|
||||
}
|
||||
});
|
||||
|
||||
// Filling the COO column indices
|
||||
thrust::for_each(
|
||||
policy,
|
||||
thrust::make_counting_iterator(int64_t(0)),
|
||||
thrust::make_counting_iterator(int64_t(csr_output.nnz_)),
|
||||
[output_indices_accessor,
|
||||
csr_output_pointers_accessor,
|
||||
csr_output_ind_accessor,
|
||||
major_dim,
|
||||
nnz] __device__(int64_t i) {
|
||||
int64_t* indices_col = output_indices_accessor[1].data();
|
||||
indices_col[i] = csr_output_ind_accessor[i];
|
||||
});
|
||||
}
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
Tensor sparse_matrix_mask_helper_cuda(
|
||||
const SparseTensor& t,
|
||||
const Tensor& mask_indices
|
||||
) {
|
||||
/*
|
||||
This is a helper function which filter values from `t._values()` using the `mask_indices`.
|
||||
This CUDA implementation uses `thrust::set_intersection_by_key` operation to find the intersection
|
||||
of the `mask_indices` and the `t._indices()` to then filter the values.
|
||||
|
||||
Inputs:
|
||||
`t` - tensor input
|
||||
`mask_indices` - mask indices tensor
|
||||
*/
|
||||
int64_t r_nnz = mask_indices.size(1);
|
||||
auto t_v = t._values().contiguous();
|
||||
|
||||
Tensor r_values = at::zeros({r_nnz}, t_v.options());
|
||||
|
||||
auto t_i = t._indices().contiguous();
|
||||
auto t_indices_accessor = t_i.packed_accessor<int64_t, 2>();
|
||||
auto t_nnz = t._nnz();
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
|
||||
auto policy = thrust::cuda::par(allocator).on(stream);
|
||||
|
||||
Tensor t_flatten_indices = at::empty({t_nnz}, mask_indices.options());
|
||||
auto t_flatten_indices_accessor = t_flatten_indices.packed_accessor<int64_t, 1>();
|
||||
auto t_n_cols = t.size(1);
|
||||
|
||||
// Step 1: flatten the sparse indices `t._indices()` tensor into a 1D indices tensor `t_flatten_indices`.
|
||||
thrust::for_each(
|
||||
policy,
|
||||
thrust::make_counting_iterator(int64_t(0)),
|
||||
thrust::make_counting_iterator(int64_t(t_nnz)),
|
||||
[t_indices_accessor, t_flatten_indices_accessor, t_n_cols] __device__ (int64_t i) mutable {
|
||||
auto index = t_indices_accessor[0][i] * t_n_cols + t_indices_accessor[1][i];
|
||||
t_flatten_indices_accessor[i] = index;
|
||||
});
|
||||
|
||||
Tensor mask_flatten_indices = at::empty({r_nnz}, mask_indices.options());
|
||||
auto mask_flatten_indices_accessor = mask_flatten_indices.packed_accessor<int64_t, 1>();
|
||||
auto mask_indices_accessor = mask_indices.packed_accessor<int64_t, 2>();
|
||||
|
||||
// Step 2: flatten the sparse indices `mask_indices` tensor into a 1D indices tensor `mask_flatten_indices`.
|
||||
thrust::for_each(
|
||||
policy,
|
||||
thrust::make_counting_iterator(int64_t(0)),
|
||||
thrust::make_counting_iterator(int64_t(r_nnz)),
|
||||
[mask_flatten_indices_accessor, mask_indices_accessor, t_n_cols] __device__ (int64_t i) mutable {
|
||||
auto index = mask_indices_accessor[0][i] * t_n_cols + mask_indices_accessor[1][i];
|
||||
mask_flatten_indices_accessor[i] = index;
|
||||
});
|
||||
auto max_sz = std::max(r_nnz, t_nnz);
|
||||
Tensor t_index_set = at::empty({max_sz}, mask_indices.options());
|
||||
|
||||
// Step 3: find the intersection between `t_flatten_indices` and `mask_flatten_indices` indices.
|
||||
// Note: the original positions from `t_flatten_indices` are stored in `t_index_set`
|
||||
auto result_end = thrust::set_intersection_by_key(
|
||||
policy,
|
||||
t_flatten_indices.data_ptr<int64_t>(),
|
||||
t_flatten_indices.data_ptr<int64_t>() + t_nnz,
|
||||
mask_flatten_indices.data_ptr<int64_t>(),
|
||||
mask_flatten_indices.data_ptr<int64_t>() + r_nnz,
|
||||
thrust::make_counting_iterator(int64_t(0)),
|
||||
thrust::make_discard_iterator(),
|
||||
t_index_set.data_ptr<int64_t>());
|
||||
|
||||
// new_sz is the size of the intersection of the `mask_indices` and the `t._indices()`
|
||||
auto new_sz = thrust::distance(t_index_set.data_ptr<int64_t>(), result_end.second);
|
||||
|
||||
Tensor mask_index_set = at::empty({max_sz}, mask_indices.options());
|
||||
|
||||
// Step 4: Repeat the intersection operation between `mask_flatten_indices` and `t_flatten_indices` indices.
|
||||
// But now store the positions from `mask_flatten_indices` in `mask_index_set`
|
||||
thrust::set_intersection_by_key(
|
||||
policy,
|
||||
mask_flatten_indices.data_ptr<int64_t>(),
|
||||
mask_flatten_indices.data_ptr<int64_t>() + r_nnz,
|
||||
t_flatten_indices.data_ptr<int64_t>(),
|
||||
t_flatten_indices.data_ptr<int64_t>() + t_nnz,
|
||||
thrust::make_counting_iterator(int64_t(0)),
|
||||
thrust::make_discard_iterator(),
|
||||
mask_index_set.data_ptr<int64_t>());
|
||||
|
||||
// Step 5: Filter `t._values()` values by using `mask_index_set` and `t_index_set`
|
||||
AT_DISPATCH_FLOATING_TYPES(r_values.scalar_type(), "_sparse_matrix_mask", [&] {
|
||||
auto r_values_accessor = r_values.packed_accessor<scalar_t, 1>();
|
||||
auto t_values = t_v.packed_accessor<scalar_t, 1>();
|
||||
auto mask_index_set_ptr = mask_index_set.packed_accessor<int64_t, 1>();
|
||||
auto t_index_set_ptr = t_index_set.packed_accessor<int64_t, 1>();
|
||||
thrust::for_each(
|
||||
policy,
|
||||
thrust::make_counting_iterator(int64_t(0)),
|
||||
thrust::make_counting_iterator(int64_t(new_sz)),
|
||||
[r_values_accessor, t_values, t_index_set_ptr, mask_index_set_ptr, r_nnz] __device__ (int64_t i) mutable {
|
||||
int64_t target = mask_index_set_ptr[i];
|
||||
int64_t origin = t_index_set_ptr[i];
|
||||
r_values_accessor[target] = t_values[origin];
|
||||
});
|
||||
});
|
||||
return r_values;
|
||||
}
|
||||
|
||||
Tensor sparse_sparse_matmul_cuda(const Tensor& mat1_, const Tensor& mat2_) {
|
||||
TORCH_INTERNAL_ASSERT(mat1_.is_sparse());
|
||||
TORCH_INTERNAL_ASSERT(mat2_.is_sparse());
|
||||
TORCH_CHECK(mat1_.dim() == 2);
|
||||
TORCH_CHECK(mat2_.dim() == 2);
|
||||
TORCH_CHECK(mat1_.dense_dim() == 0, "sparse_mm: scalar values expected, mat1 got ", mat1_.dense_dim(), "D values");
|
||||
TORCH_CHECK(mat2_.dense_dim() == 0, "sparse_mm: scalar values expected, mat2 got ", mat2_.dense_dim(), "D values");
|
||||
|
||||
TORCH_CHECK(
|
||||
mat1_.size(1) == mat2_.size(0), "mat1 and mat2 shapes cannot be multiplied (",
|
||||
mat1_.size(0), "x", mat1_.size(1), " and ", mat2_.size(0), "x", mat2_.size(1), ")");
|
||||
|
||||
TORCH_CHECK(mat1_.scalar_type() == mat2_.scalar_type(),
|
||||
"mat1 dtype ", mat1_.scalar_type(), " does not match mat2 dtype ", mat2_.scalar_type());
|
||||
|
||||
auto output = at::native::empty_like(mat1_);
|
||||
output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
|
||||
sparse_sparse_matmul_cuda_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
|
||||
});
|
||||
return output;
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
198
benchmarks/sparse/matmul_dlmc_bench.py
Normal file
198
benchmarks/sparse/matmul_dlmc_bench.py
Normal file
@ -0,0 +1,198 @@
|
||||
# Sparse benchmarks
|
||||
|
||||
# These benchmarks are for the sparse matrix functionality.
|
||||
# They exist for comparing the performance of sparse matrix routines
|
||||
# torch.sparse.mm(sparse, sparse)` with different backends (CPU/CUDA)
|
||||
# and with other frameworks such as scipy.
|
||||
|
||||
import sys
|
||||
from scipy import sparse
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
import argparse
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark_utils
|
||||
|
||||
def read_matrix_params(path):
|
||||
sys.stdin = open(path)
|
||||
nrows, ncols, nnz = map(lambda el: int(el), input().split(', '))
|
||||
return (nrows, ncols), nnz
|
||||
|
||||
|
||||
def load_matrix(path):
|
||||
sys.stdin = open(path)
|
||||
nrows, ncols, nnz = map(lambda el: int(el), input().split(', '))
|
||||
index_pointers = map(lambda el: int(el), input().split())
|
||||
indices = map(lambda el: int(el), input().split())
|
||||
|
||||
index_pointers = list(index_pointers)
|
||||
indices = list(indices)
|
||||
data = np.random.rand(nnz)
|
||||
coo = sparse.csr_matrix(
|
||||
(data, np.array(indices), np.array(index_pointers)),
|
||||
shape=(nrows, ncols)).tocoo()
|
||||
return torch.sparse_coo_tensor([coo.row, coo.col], coo.data, coo.shape)
|
||||
|
||||
|
||||
def scipy_coo_matmul(mat1, mat2):
|
||||
result = mat1.dot(mat2).tocoo()
|
||||
return torch.sparse_coo_tensor([result.row, result.col], result.data,
|
||||
result.shape)
|
||||
|
||||
|
||||
def to_coo_scipy(x):
|
||||
indices_1 = x._indices().numpy()
|
||||
values_1 = x._values().numpy()
|
||||
return sparse.coo_matrix((values_1, (indices_1[0], indices_1[1])),
|
||||
shape=x.shape)
|
||||
|
||||
|
||||
def torch_backward(a_dense, b_dense):
|
||||
a_dense.requires_grad = True
|
||||
b_dense.requires_grad = True
|
||||
r1 = a_dense.matmul(b_dense)
|
||||
f1 = torch.sum(r1)
|
||||
f1.backward()
|
||||
|
||||
|
||||
def sparse_torch_backward(a, b):
|
||||
a.requires_grad = True
|
||||
b.requires_grad = True
|
||||
|
||||
r2 = torch.sparse.mm(a, b)
|
||||
f2 = torch.sparse.sum(r2)
|
||||
f2.backward()
|
||||
|
||||
|
||||
def load_dataset(dataset_path, hidden_size, sparsity, n_limit=20):
|
||||
current_folder_path = f"{dataset_path}/{sparsity}"
|
||||
path = Path(current_folder_path)
|
||||
files = path.glob('**/*.smtx')
|
||||
xs = []
|
||||
ys = []
|
||||
print(dataset_path, hidden_size, sparsity)
|
||||
index = 0
|
||||
for elem in files:
|
||||
if index == n_limit:
|
||||
break
|
||||
print('.', end='')
|
||||
size, nnz = read_matrix_params(elem.as_posix())
|
||||
if size[1] == hidden_size:
|
||||
xs.append(load_matrix(elem.as_posix()))
|
||||
if size[0] == hidden_size:
|
||||
ys.append(load_matrix(elem.as_posix()))
|
||||
index += 1
|
||||
print()
|
||||
return zip(xs, ys)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
path = Path()
|
||||
parser = argparse.ArgumentParser(description='Sparse Matmul Bench')
|
||||
|
||||
parser.add_argument('--path', type=str, help='dataset path')
|
||||
parser.add_argument('--dataset',
|
||||
type=str,
|
||||
help='dataset name',
|
||||
default='random_pruning')
|
||||
parser.add_argument('--operation',
|
||||
type=str,
|
||||
help='matmul or backward',
|
||||
default='matmul')
|
||||
parser.add_argument('--output',
|
||||
type=str,
|
||||
help='dataframe output path',
|
||||
default='/tmp/matmul_bench.pkl')
|
||||
args = parser.parse_args()
|
||||
print('path =', args.path)
|
||||
print('dataset =', args.dataset)
|
||||
print('operation =', args.operation)
|
||||
print('output =', args.output)
|
||||
|
||||
dataset_path = args.path
|
||||
dataset_name = args.dataset
|
||||
dataset_path = f"{dataset_path}/{dataset_name}"
|
||||
df_output_path = args.output
|
||||
tasks = []
|
||||
if args.operation == 'matmul':
|
||||
tasks = [
|
||||
("matmul", "cpu", "torch", "torch.mm(dense_x, dense_y)"),
|
||||
("matmul", "cpu", "torch.sparse", "torch.sparse.mm(tx, ty)"),
|
||||
("matmul", "cpu", "scipy",
|
||||
"scipy_coo_matmul(scipy_varx, scipy_vary)"),
|
||||
("matmul", "cuda", "torch",
|
||||
"torch.mm(dense_cuda_x, dense_cuda_y)"),
|
||||
("matmul", "cuda", "torch.sparse",
|
||||
"torch.sparse.mm(tx_cuda, ty_cuda)"),
|
||||
]
|
||||
else:
|
||||
tasks = [
|
||||
("backward", "cpu", "torch", "torch_backward(dense_x, dense_y)"),
|
||||
("backward", "cpu", "torch.sparse",
|
||||
"sparse_torch_backward(tx, ty)"),
|
||||
("backward", "cuda", "torch",
|
||||
"torch_backward(dense_cuda_x, dense_cuda_y)"),
|
||||
("backward", "cuda", "torch.sparse",
|
||||
"sparse_torch_backward(tx_cuda, ty_cuda)"),
|
||||
]
|
||||
serialized_results = []
|
||||
repeats = 2
|
||||
timers = [
|
||||
benchmark_utils.Timer(
|
||||
stmt=stmt,
|
||||
globals={
|
||||
"scipy_coo_matmul": scipy_coo_matmul,
|
||||
"torch_backward": torch_backward,
|
||||
"sparse_torch_backward": sparse_torch_backward,
|
||||
"scipy_varx": to_coo_scipy(x),
|
||||
"scipy_vary": to_coo_scipy(y),
|
||||
"tx": x,
|
||||
"ty": y,
|
||||
"tx_cuda": x.cuda(),
|
||||
"ty_cuda": y.cuda(),
|
||||
"dense_cuda_x": x.to_dense().cuda(),
|
||||
"dense_cuda_y": y.to_dense().cuda(),
|
||||
"dense_x": x.to_dense(),
|
||||
"dense_y": y.to_dense(),
|
||||
},
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description=f"{sparsity}",
|
||||
env=device,
|
||||
# num_threads=num_threads,
|
||||
) for hidden_size in [512]
|
||||
for sparsity in [0.5, 0.7, 0.8, 0.9, 0.95, 0.98]
|
||||
for label, device, sub_label, stmt in tasks
|
||||
for num_threads in [1, 4, 8, 16]
|
||||
for x, y in load_dataset(dataset_path, hidden_size, sparsity)
|
||||
]
|
||||
measurements = []
|
||||
|
||||
for i, timer in enumerate(timers * repeats):
|
||||
m = timer.blocked_autorange(min_run_time=0.05)
|
||||
serialized_results.append(pickle.dumps(m))
|
||||
m.metadata = {
|
||||
"device": 'cuda' if m.task_spec.env.find("cuda") >= 0 else 'cpu'
|
||||
}
|
||||
measurements.append(m)
|
||||
print(f"\r{i + 1} / {len(timers) * repeats}", end="")
|
||||
sys.stdout.flush()
|
||||
print()
|
||||
|
||||
comparison = benchmark_utils.Compare(
|
||||
[pickle.loads(i) for i in serialized_results])
|
||||
|
||||
print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n")
|
||||
comparison.print()
|
||||
|
||||
print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n")
|
||||
comparison.trim_significant_figures()
|
||||
comparison.colorize()
|
||||
comparison.print()
|
||||
|
||||
table = [(m.task_spec.sub_label, m.task_spec.description,
|
||||
m.metadata["device"], m.mean) for m in measurements]
|
||||
df = pd.DataFrame(table, columns=['method', 'sparsity', 'device', 'time'])
|
||||
df.to_pickle(df_output_path)
|
14
benchmarks/sparse/test.sh
Normal file
14
benchmarks/sparse/test.sh
Normal file
@ -0,0 +1,14 @@
|
||||
#!/bin/bash
|
||||
|
||||
DATASET_ROOT_DIR=$HOME/datasets/
|
||||
|
||||
# wget https://storage.googleapis.com/sgk-sc2020/dlmc.tar.gz -P $DATASET_ROOT_DIR
|
||||
# tar -xvf $DATASET_ROOT_DIR/dlmc.tar.gz
|
||||
|
||||
echo "!! SPARSE SPMS TIME BENCHMARK!! "
|
||||
|
||||
python matmul_dlmc_bench.py --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset random_pruning --operation matmul --output /tmp/matmul_bench.pkl
|
||||
python matmul_dlmc_bench.py --path $DATASET_ROOT_DIR/dlmc/rn50 --dataset random_pruning --operation backward --output /tmp/backward_bench.pkl
|
||||
|
||||
python plot_results.py -i /tmp/matmul_bench.pkl
|
||||
python plot_results.py -i /tmp/backward_bench.pkl
|
@ -8,14 +8,18 @@ import itertools
|
||||
import functools
|
||||
import operator
|
||||
import random
|
||||
from collections import defaultdict
|
||||
import unittest
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
|
||||
do_test_empty_full, load_tests, TEST_NUMPY, IS_WINDOWS
|
||||
do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS
|
||||
from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
|
||||
from numbers import Number
|
||||
from torch.autograd.gradcheck import gradcheck
|
||||
from typing import Dict, Any
|
||||
|
||||
if TEST_SCIPY:
|
||||
import scipy.sparse
|
||||
|
||||
# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
|
||||
# sharding on sandcastle. This line silences flake warnings
|
||||
load_tests = load_tests
|
||||
@ -3008,6 +3012,157 @@ class TestSparse(TestCase):
|
||||
test_op(3, 100, [3, 4, 2, 3, 5, 2])
|
||||
test_op(4, 100, [3, 4, 2, 3, 5, 2])
|
||||
|
||||
@skipIfRocm
|
||||
def test_sparse_matmul(self):
|
||||
"""
|
||||
This function test `torch.sparse.mm` when both the mat1 and mat2 are sparse tensors.
|
||||
"""
|
||||
|
||||
def _indices2csr(indices, dim):
|
||||
nnz = len(indices)
|
||||
r = [0] * (dim + 1)
|
||||
last_i = 0
|
||||
for i in indices:
|
||||
if i != last_i:
|
||||
for _i in range(last_i, i + 1):
|
||||
r[_i + 1] = r[last_i + 1]
|
||||
last_i = i
|
||||
r[last_i + 1] += 1
|
||||
for _i in range(last_i, dim):
|
||||
r[_i + 1] = r[last_i + 1]
|
||||
assert r[-1] == nnz
|
||||
return r
|
||||
|
||||
def sparse_mm(a, b, method='scipy'):
|
||||
a = a.to('cpu')
|
||||
b = b.to('cpu')
|
||||
if method == 'scipy':
|
||||
indices_1 = a._indices().numpy()
|
||||
values_1 = a._values().numpy()
|
||||
indices_2 = b._indices().numpy()
|
||||
values_2 = b._values().numpy()
|
||||
|
||||
mat1 = scipy.sparse.coo_matrix((values_1, (indices_1[0], indices_1[1])), shape=a.shape)
|
||||
mat2 = scipy.sparse.coo_matrix((values_2, (indices_2[0], indices_2[1])), shape=b.shape)
|
||||
result = mat1.dot(mat2).tocoo()
|
||||
return torch.sparse_coo_tensor([result.row, result.col], result.data, result.shape)
|
||||
else:
|
||||
assert a.shape[1] == b.shape[0]
|
||||
n, p = a.shape
|
||||
p, m = b.shape
|
||||
indices_a = a._indices()
|
||||
values_a = a._values()
|
||||
indices_b = b._indices()
|
||||
values_b = b._values()
|
||||
nnz1 = len(indices_a[0])
|
||||
nnz2 = len(indices_b[0])
|
||||
|
||||
if a.is_coalesced() and b.is_coalesced():
|
||||
r2 = _indices2csr(indices_b[0], b.shape[0])
|
||||
d = defaultdict(values_b.numpy().dtype.type)
|
||||
for n1 in range(nnz1):
|
||||
for n2 in range(r2[indices_a[1][n1]], r2[indices_a[1][n1] + 1]):
|
||||
d[indices_a[0][n1].item(), indices_b[1][n2].item()] += values_a[n1] * values_b[n2]
|
||||
|
||||
else:
|
||||
d = defaultdict(values_b.numpy().dtype.type)
|
||||
for n1 in range(nnz1):
|
||||
for n2 in range(nnz2):
|
||||
if indices_b[0][n2] == indices_a[1][n1]:
|
||||
d[indices_a[0][n1].item(), indices_b[1][n2].item()] += values_a[n1] * values_b[n2]
|
||||
i3 = []
|
||||
j3 = []
|
||||
values = []
|
||||
for i, j in sorted(d):
|
||||
i3.append(i)
|
||||
j3.append(j)
|
||||
values.append(d[i, j])
|
||||
return torch.sparse_coo_tensor(torch.tensor([i3, j3]), torch.tensor(values), (n, m))
|
||||
|
||||
def grad_with_custom_sparsity_pattern_test_helper(sparse_dims, nnz, shape_a, shape_b):
|
||||
def test_grad_dense(a_s, b_s, g_s):
|
||||
a = a_s.to_dense().detach()
|
||||
b = b_s.to_dense().detach()
|
||||
g = g_s.to_dense().detach()
|
||||
|
||||
a.requires_grad_(True)
|
||||
b.requires_grad_(True)
|
||||
c = a @ b
|
||||
c.backward(g)
|
||||
return a.grad.sparse_mask(a_s.coalesce()), b.grad.sparse_mask(b_s.coalesce())
|
||||
|
||||
a, _, _ = self._gen_sparse(sparse_dims, nnz, shape_a)
|
||||
b, _, _ = self._gen_sparse(sparse_dims, nnz, shape_b)
|
||||
a.requires_grad_(True)
|
||||
b.requires_grad_(True)
|
||||
|
||||
c = torch.sparse.mm(a, b)
|
||||
c2 = c.to_dense().detach()
|
||||
c2 = torch.rand_like(c2)
|
||||
g = c2.sparse_mask(c.coalesce())
|
||||
|
||||
c.backward(g)
|
||||
|
||||
a_grad, b_grad = test_grad_dense(a, b, g)
|
||||
self.assertEqual(a.grad, a_grad)
|
||||
self.assertEqual(b.grad, b_grad)
|
||||
|
||||
def test_sparse_matmul(sparse_dims, nnz, shape_a, shape_b):
|
||||
a, i_a, v_a = self._gen_sparse(sparse_dims, nnz, shape_a)
|
||||
b, i_b, v_b = self._gen_sparse(sparse_dims, nnz, shape_b)
|
||||
|
||||
# python implementation
|
||||
r1 = sparse_mm(a, b, 'scipy' if TEST_SCIPY else 'direct')
|
||||
|
||||
self.assertEqual(r1.to_dense(), torch.mm(a.to_dense(), b.to_dense()))
|
||||
|
||||
# cpp implementation
|
||||
r2 = torch.sparse.mm(a, b)
|
||||
self.assertEqual(r1, r2)
|
||||
|
||||
a.requires_grad_(True)
|
||||
b.requires_grad_(True)
|
||||
|
||||
# check autograd support on sparse matmul
|
||||
def fn(D1, D2):
|
||||
return torch.sparse.mm(D1, D2).to_dense()
|
||||
|
||||
# For cuda, `nondet_tol` is set with `1e-5`
|
||||
# This is because cuSparse sometimes returns approximate zero values like `~e-323`
|
||||
# TODO: Check this cuSparse issue.
|
||||
# This happens when you do chain multiplication `torch.sparse.mm` operations
|
||||
gradcheck(fn, (a, b), check_sparse_nnz=True, nondet_tol=1e-5)
|
||||
grad_with_custom_sparsity_pattern_test_helper(sparse_dims, nnz, shape_a, shape_b)
|
||||
|
||||
def test_error_cases():
|
||||
def fn(sparse_dims, nnz, shape_a, shape_b):
|
||||
a, i_a, v_a = self._gen_sparse(sparse_dims, nnz, shape_a)
|
||||
b, i_b, v_b = self._gen_sparse(sparse_dims, nnz, shape_b)
|
||||
r2 = torch.sparse.mm(a, b)
|
||||
|
||||
# This is not a matrix
|
||||
self.assertRaises(RuntimeError, lambda: fn(3, 4, [2, 2, 2], [2, 2, 2]))
|
||||
|
||||
# Shapes does not
|
||||
self.assertRaisesRegex(RuntimeError,
|
||||
r"mat1 and mat2 shapes cannot be multiplied \(2x3 and 4x2\)",
|
||||
lambda: fn(2, 10, [2, 3], [4, 2]))
|
||||
|
||||
def different_dtypes():
|
||||
a, i_a, v_a = self._gen_sparse(2, 10, [2, 2])
|
||||
b, i_b, v_b = self._gen_sparse(2, 10, [2, 2])
|
||||
r2 = torch.sparse.mm(a.to(torch.float64), a.to(torch.float32))
|
||||
|
||||
self.assertRaisesRegex(RuntimeError, 'mat1 dtype Double does not match mat2 dtype Float', different_dtypes)
|
||||
|
||||
for n in range(2, 5):
|
||||
for m in range(2, 8):
|
||||
for p in range(2, 8):
|
||||
test_sparse_matmul(2, 10, [n, m], [m, p])
|
||||
|
||||
test_sparse_matmul(2, 0, [0, 0], [0, 0])
|
||||
test_sparse_matmul(2, 0, [0, 10], [10, 0])
|
||||
test_error_cases()
|
||||
|
||||
class TestUncoalescedSparse(TestSparse):
|
||||
def setUp(self):
|
||||
|
@ -1353,6 +1353,10 @@
|
||||
- name: _sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
|
||||
self: _sparse_softmax_backward_data(grad, result, dim, self)
|
||||
|
||||
- name: _sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor
|
||||
self: sparse_sparse_matmul_backward(grad, self, other, 0)
|
||||
other: sparse_sparse_matmul_backward(grad, self, other, 1)
|
||||
|
||||
- name: softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor
|
||||
self: softplus_backward(grad, self, beta, threshold, result)
|
||||
|
||||
|
@ -423,7 +423,11 @@ def gradcheck(
|
||||
return fail_test('grad is sparse tensor, but has incorrect dense_dim')
|
||||
gi = gi.to_dense()
|
||||
di = di.to_dense()
|
||||
if not gi.eq(0).all():
|
||||
|
||||
if check_sparse_nnz:
|
||||
if not torch.allclose(gi, torch.zeros_like(gi)):
|
||||
return fail_test('backward not multiplied by grad_output')
|
||||
elif not gi.eq(0).all():
|
||||
return fail_test('backward not multiplied by grad_output')
|
||||
if gi.dtype != di.dtype or gi.device != di.device or gi.is_sparse != di.is_sparse:
|
||||
return fail_test("grad is incorrect type")
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include <ATen/BatchedTensorImpl.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/ScalarOps.h>
|
||||
#include <ATen/SparseTensorUtils.h>
|
||||
|
||||
#include <ciso646>
|
||||
#include <algorithm>
|
||||
@ -628,6 +629,57 @@ Tensor _sparse_addmm_sparse_backward(const Tensor& grad, const Tensor& sparse_,
|
||||
return grad_sparse.sparse_mask(sparse);
|
||||
}
|
||||
|
||||
// This function return a new SparseTensor with values from Tensor `input` filtered by indices of `mask`
|
||||
// and values are ignored. `input` and `mask` are sparse matrices, a sparse tensor with sparse_dim=2 and dense_dim=2,
|
||||
// and they must have the same shape.
|
||||
// Note that the `output` must have the same `indices` as the `mask` so we are using just a clone.
|
||||
// However, to get `values` we have to use specific helper function for CPU/CUDA and use the `mask` data to filter `values`
|
||||
// That's why we created this `_sparse_matrix_mask_helper` function.
|
||||
Tensor _sparse_matrix_mask(const Tensor& input, const Tensor& mask){
|
||||
Tensor output = at::native::empty_like(mask);
|
||||
Tensor mask_indices = mask._indices().clone();
|
||||
Tensor r_values;
|
||||
if (mask._nnz() == 0) {
|
||||
r_values = at::native::zeros_like(mask._values());
|
||||
} else {
|
||||
r_values = _sparse_matrix_mask_helper(input, mask_indices.contiguous());
|
||||
}
|
||||
at::sparse::get_sparse_impl(output)->set_indices_and_values_unsafe(mask_indices, r_values);
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor sparse_sparse_matmul_backward(
|
||||
const Tensor& grad,
|
||||
const Tensor& a,
|
||||
const Tensor& b,
|
||||
int64_t grad_order) {
|
||||
/*
|
||||
To implement the backward algorithm for sparse matrix-matrix matmul (SPMM) we can start from the following definition
|
||||
for dense tensors:
|
||||
|
||||
c = a @ b
|
||||
then
|
||||
a_grad = c_grad @ b^T
|
||||
b_grad = a^T @ c_grad
|
||||
|
||||
So for sparse matrices we can use the following definition:
|
||||
|
||||
if grad_order == 0:
|
||||
a_grad = sparse_matrix_mask(c_grad @ b^T, mask=a)
|
||||
else:
|
||||
b_grad = sparse_matrix_mask(a^T @ c_grad, mask=b)
|
||||
*/
|
||||
TORCH_CHECK(
|
||||
grad_order == 0 || grad_order == 1,
|
||||
": grad_order not in [0, 1] at sparse_sparse_matmul_backward function");
|
||||
if (grad_order == 0) {
|
||||
auto a_grad = _sparse_sparse_matmul(grad, b.t());
|
||||
return _sparse_matrix_mask(a_grad.coalesce(), a.coalesce());
|
||||
}
|
||||
auto b_grad = _sparse_sparse_matmul(a.t(), grad);
|
||||
return _sparse_matrix_mask(b_grad.coalesce(), b.coalesce());
|
||||
}
|
||||
|
||||
Tensor renorm_backward(const Tensor & grad, const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm) {
|
||||
auto transposed_sizes = self.transpose(dim, 0).sizes().vec();
|
||||
auto flatten = [&](const Tensor & t) {
|
||||
|
@ -75,6 +75,7 @@ at::IntArrayRef strides_or_error(const Tensor & input, c10::string_view const &
|
||||
at::Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, at::IntArrayRef mat1_sizes, at::IntArrayRef mat1_strides, const Scalar & alpha);
|
||||
at::Tensor mm_mat2_backward(const at::Tensor & grad, const at::Tensor & mat1, at::IntArrayRef sizes, at::IntArrayRef strides, const at::Scalar & alpha);
|
||||
at::Tensor _sparse_addmm_sparse_backward(const at::Tensor& grad, const at::Tensor& sparse_, const at::Tensor& dense, const at::Scalar& alpha);
|
||||
at::Tensor sparse_sparse_matmul_backward(const at::Tensor& grad, const at::Tensor& mat1, const at::Tensor& mat2,int64_t grad_order);
|
||||
at::Tensor renorm_backward(const at::Tensor & grad, const at::Tensor & self, at::Scalar p, int64_t dim, at::Scalar maxnorm);
|
||||
at::Tensor repeat_backward(at::Tensor grad, at::IntArrayRef repeats, at::IntArrayRef input_shape);
|
||||
at::Tensor _fused_dropout_backward(at::Tensor grad, at::Tensor mask, double p1m);
|
||||
|
@ -45,15 +45,20 @@ def addmm(mat: Tensor, mat1: Tensor, mat2: Tensor,
|
||||
def mm(mat1: Tensor, mat2: Tensor) -> Tensor:
|
||||
r"""
|
||||
Performs a matrix multiplication of the sparse matrix :attr:`mat1`
|
||||
and dense matrix :attr:`mat2`. Similar to :func:`torch.mm`, If :attr:`mat1` is a
|
||||
and the (sparse or strided) matrix :attr:`mat2`. Similar to :func:`torch.mm`, If :attr:`mat1` is a
|
||||
:math:`(n \times m)` tensor, :attr:`mat2` is a :math:`(m \times p)` tensor, out will be a
|
||||
:math:`(n \times p)` dense tensor. :attr:`mat1` need to have `sparse_dim = 2`.
|
||||
:math:`(n \times p)` tensor. :attr:`mat1` need to have `sparse_dim = 2`.
|
||||
This function also supports backward for both matrices. Note that the gradients of
|
||||
:attr:`mat1` is a coalesced sparse tensor.
|
||||
|
||||
Args:
|
||||
mat1 (Tensor): the first sparse matrix to be multiplied
|
||||
mat2 (Tensor): the second dense matrix to be multiplied
|
||||
mat1 (SparseTensor): the first sparse matrix to be multiplied
|
||||
mat2 (Tensor): the second matrix to be multiplied, which could be sparse or dense
|
||||
|
||||
Shape:
|
||||
The format of the output tensor of this function follows:
|
||||
- sparse x sparse -> sparse
|
||||
- sparse x dense -> dense
|
||||
|
||||
Example::
|
||||
|
||||
@ -81,6 +86,8 @@ def mm(mat1: Tensor, mat2: Tensor) -> Tensor:
|
||||
values=tensor([ 0.1394, -0.6415, -2.1639, 0.1394, -0.6415, -2.1639]),
|
||||
size=(2, 3), nnz=6, layout=torch.sparse_coo)
|
||||
"""
|
||||
if mat1.is_sparse and mat2.is_sparse:
|
||||
return torch._sparse_sparse_matmul(mat1, mat2)
|
||||
return torch._sparse_mm(mat1, mat2)
|
||||
|
||||
|
||||
|
@ -7746,6 +7746,10 @@ CUDA_SPARSE_MAP = collections.OrderedDict(
|
||||
[
|
||||
("cusparseStatus_t", ("hipsparseStatus_t", CONV_MATH_FUNC, API_SPARSE)),
|
||||
("cusparseHandle_t", ("hipsparseHandle_t", CONV_MATH_FUNC, API_SPARSE)),
|
||||
(
|
||||
"CUSPARSE_POINTER_MODE_HOST",
|
||||
("HIPSPARSE_POINTER_MODE_HOST", CONV_NUMERIC_LITERAL, API_SPARSE),
|
||||
),
|
||||
("cusparseOperation_t", ("hipsparseOperation_t", CONV_TYPE, API_SPARSE)),
|
||||
(
|
||||
"cusparseCreateMatDescr",
|
||||
@ -7767,6 +7771,17 @@ CUDA_SPARSE_MAP = collections.OrderedDict(
|
||||
"cusparseXcsrsort_bufferSizeExt",
|
||||
("hipsparseXcsrsort_bufferSizeExt", CONV_MATH_FUNC, API_SPARSE),
|
||||
),
|
||||
("cusparseCreateCsrgemm2Info", ("hipsparseCreateCsrgemm2Info", CONV_MATH_FUNC, API_SPARSE)),
|
||||
(
|
||||
"cusparseDestroyCsrgemm2Info",
|
||||
("hipsparseDestroyCsrgemm2Info", CONV_MATH_FUNC, API_SPARSE),
|
||||
),
|
||||
("cusparseXcsrgemm2Nnz", ("hipsparseXcsrgemm2Nnz", CONV_MATH_FUNC, API_SPARSE)),
|
||||
("cusparseDcsrgemm2_bufferSizeExt", ("hipsparseDcsrgemm2_bufferSizeExt", CONV_MATH_FUNC, API_SPARSE)),
|
||||
("cusparseScsrgemm2_bufferSizeExt", ("hipsparseScsrgemm2_bufferSizeExt", CONV_MATH_FUNC, API_SPARSE)),
|
||||
("cusparseDcsrgemm2", ("hipsparseDcsrgemm2", CONV_MATH_FUNC, API_SPARSE)),
|
||||
("cusparseScsrgemm2", ("hipsparseScsrgemm2", CONV_MATH_FUNC, API_SPARSE)),
|
||||
("cusparseSetPointerMode", ("hipsparseSetPointerMode", CONV_MATH_FUNC, API_SPARSE)),
|
||||
("cusparseXcsrsort", ("hipsparseXcsrsort", CONV_MATH_FUNC, API_SPARSE)),
|
||||
(
|
||||
"cusparseXcoosort_bufferSizeExt",
|
||||
|
Reference in New Issue
Block a user