Add CSR (compressed sparse row) layout for sparse tensors (#50937)
Summary: Implement compressed sparse row format. Derived from the GCS implementation at https://github.com/pytorch/pytorch/pull/44190
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50937
Reviewed By: mrshenli
Differential Revision: D27439865
Pulled By: ezyang
fbshipit-source-id: 3ba3dcb9679505b980ff6a5f513e913bbae2fb1d
committed by: Facebook GitHub Bot
parent: c6d9ca0c2b
commit: 5fb1142702
@@ -130,6 +130,7 @@ genrule(
        "aten/src/ATen/RegisterMkldnnCPU.cpp",
        "aten/src/ATen/RegisterQuantizedCPU.cpp",
        "aten/src/ATen/RegisterSparseCPU.cpp",
        "aten/src/ATen/RegisterSparseCsrCPU.cpp",
        "aten/src/ATen/RegisterCompositeImplicitAutograd.cpp",
        "aten/src/ATen/RegisterMeta.cpp",
        "aten/src/ATen/RegisterCompositeExplicitAutograd.cpp",
aten/src/ATen/SparseCsrTensorImpl.cpp (new file, 153 lines)
@@ -0,0 +1,153 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/InitialTensorOptions.h>
|
||||
#include <ATen/SparseCsrTensorImpl.h>
|
||||
#include <ATen/SparseTensorImpl.h>
|
||||
#include <ATen/SparseTensorUtils.h>
|
||||
#include <ATen/core/LegacyTypeDispatch.h>
|
||||
|
||||
namespace at {
|
||||
namespace {
|
||||
DeviceType SparseCsrTensorSetToDeviceType(DispatchKeySet key_set) {
|
||||
if (key_set.has(DispatchKey::SparseCsrCPU)) {
|
||||
return kCPU;
|
||||
} else if (key_set.has(DispatchKey::SparseCsrCUDA)) {
|
||||
return kCUDA;
|
||||
} else {
|
||||
TORCH_CHECK(false,
|
||||
"Cannot construct SparseCsrTensor with non-sparse tensor type ID ",
|
||||
key_set);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
SparseCsrTensorImpl::SparseCsrTensorImpl(
|
||||
at::DispatchKeySet key_set,
|
||||
const caffe2::TypeMeta data_type)
|
||||
: SparseCsrTensorImpl(
|
||||
key_set,
|
||||
data_type,
|
||||
at::empty(
|
||||
{0},
|
||||
at::initialTensorOptions()
|
||||
.device(SparseCsrTensorSetToDeviceType(key_set))
|
||||
.dtype(ScalarType::Int)) // crow_indices
|
||||
,
|
||||
at::empty(
|
||||
{0},
|
||||
at::initialTensorOptions()
|
||||
.device(SparseCsrTensorSetToDeviceType(key_set))
|
||||
.dtype(ScalarType::Int)) // col_indices
|
||||
,
|
||||
at::empty(
|
||||
{0},
|
||||
at::initialTensorOptions()
|
||||
.device(SparseCsrTensorSetToDeviceType(key_set))
|
||||
.dtype(data_type)) // values
|
||||
) {}
|
||||
|
||||
SparseCsrTensorImpl::SparseCsrTensorImpl(
|
||||
at::DispatchKeySet key_set,
|
||||
const caffe2::TypeMeta data_type,
|
||||
at::Tensor crow_indices,
|
||||
at::Tensor col_indices,
|
||||
at::Tensor values)
|
||||
: TensorImpl(key_set, data_type, values.device()),
|
||||
crow_indices_(std::move(crow_indices)),
|
||||
col_indices_(std::move(col_indices)),
|
||||
values_(std::move(values)) {}
|
||||
|
||||
void SparseCsrTensorImpl::resize_and_clear_(
|
||||
const int64_t nnz_size,
|
||||
IntArrayRef size) {
|
||||
// call crow_indices().options() here since the struct constructor calls the
|
||||
// tensor constructor with args for device specific init.
|
||||
auto empty_crow_indices = at::empty(size[0] + 1, crow_indices().options());
|
||||
auto empty_col_indices = at::empty(nnz_size, col_indices().options());
|
||||
auto empty_values = at::empty(nnz_size, values().options());
|
||||
|
||||
crow_indices_ = empty_crow_indices;
|
||||
col_indices_ = empty_col_indices;
|
||||
values_ = empty_values;
|
||||
sizes_and_strides_.set_sizes(size);
|
||||
}
|
||||
|
||||
void SparseCsrTensorImpl::resize_as_sparse_csr_tensor_(const Tensor& src) {
|
||||
crow_indices_ = at::empty_like(
|
||||
src.crow_indices(),
|
||||
src.crow_indices().options(),
|
||||
src.crow_indices().suggest_memory_format());
|
||||
col_indices_ = at::empty_like(
|
||||
src.col_indices(),
|
||||
src.col_indices().options(),
|
||||
src.col_indices().suggest_memory_format());
|
||||
values_ = at::empty_like(
|
||||
src.values(),
|
||||
src.values().options(),
|
||||
src.values().suggest_memory_format());
|
||||
sizes_and_strides_.set_sizes(src.sizes());
|
||||
}
|
||||
|
||||
void SparseCsrTensorImpl::set_member_tensors(
|
||||
const Tensor& crow_indices,
|
||||
const Tensor& col_indices,
|
||||
const Tensor& values) {
|
||||
auto crow_indices_type = crow_indices.scalar_type();
|
||||
auto col_indices_type = col_indices.scalar_type();
|
||||
|
||||
TORCH_CHECK(
|
||||
crow_indices_type == col_indices_type,
|
||||
"both crow_indices and col_indices should have the same type.");
|
||||
TORCH_CHECK(
|
||||
crow_indices_type == kInt || crow_indices_type == kLong,
|
||||
"crow_indices and col_indices must be an int32 or int64 type, but got: ",
|
||||
crow_indices_type);
|
||||
TORCH_CHECK(
|
||||
values.scalar_type() == typeMetaToScalarType(dtype()),
|
||||
"dtype of values (",
|
||||
values.scalar_type(),
|
||||
") must match dtype of sparse tensor (",
|
||||
typeMetaToScalarType(dtype()),
|
||||
")");
|
||||
|
||||
TORCH_CHECK(
|
||||
col_indices.layout() == kStrided,
|
||||
"expected col_indices to be a strided tensor, but got indices of layout ",
|
||||
col_indices.layout());
|
||||
TORCH_CHECK(
|
||||
crow_indices.layout() == kStrided,
|
||||
"expected crow_indices to be a strided tensor, but got crow_indices of layout ",
|
||||
crow_indices.layout());
|
||||
TORCH_CHECK(
|
||||
values.layout() == kStrided && values.is_contiguous(),
|
||||
"expected values to be a strided and contiguous tensor, but got values of layout ",
|
||||
values.layout());
|
||||
|
||||
TORCH_CHECK(
|
||||
values.device().type() == device().type(),
|
||||
"device type of values (",
|
||||
values.device().type(),
|
||||
") must match device type of device().type()",
|
||||
device().type(),
|
||||
")");
|
||||
TORCH_CHECK(
|
||||
values.is_cuda() || col_indices.get_device() == crow_indices.get_device(),
|
||||
"crow_indices and col_indices devices (",
|
||||
crow_indices.get_device(),
|
||||
", ",
|
||||
col_indices.get_device(),
|
||||
") must match with the (non-cuda) device of values (",
|
||||
values.get_device(),
|
||||
")");
|
||||
|
||||
TORCH_CHECK(
|
||||
col_indices.size(0) == values.size(0),
|
||||
"col_indices and values must have equal sizes, but got col_indices.size(0): ",
|
||||
col_indices.size(0),
|
||||
", values.size(0): ",
|
||||
values.size(0));
|
||||
|
||||
crow_indices_ = crow_indices;
|
||||
col_indices_ = col_indices;
|
||||
values_ = values;
|
||||
}
|
||||
} // namespace at
|
aten/src/ATen/SparseCsrTensorImpl.h (new file, 55 lines)
@@ -0,0 +1,55 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
#include <c10/core/TensorImpl.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
namespace at {
|
||||
|
||||
// Struct implementing a sparse CSR tensor. It uses three 1-D tensors for
|
||||
// denoting the data: `crow_indices_`, `col_indices_` and `values_`.
|
||||
// The `crow_indices_` tensor is an integer tensor of shape `(size(0) + 1)`
|
||||
// that represents the compressed row indices of the CSR tensor. The
|
||||
// `col_indices_` tensor is an integer tensor of shape `(nnz())`
|
||||
// that explicitly stores the column indices of each value of the sparse
|
||||
// tensor. The `values_` tensor can be of any pytorch-supported data type
|
||||
// and has shape `(nnz())`.
|
||||
//
|
||||
// Since the main advantage of the CSR format over the COO format is speed of
|
||||
// computation, care must be taken to facilitate smooth interfacing of
|
||||
// these data structures with optimized libraries such as MKL and MAGMA.
|
||||
// Since the MKL interface for pytorch currently uses indexing with int32
|
||||
// type, it is important to make sure that the `crow_indices` and `col_indices`
|
||||
// are of type int32 when calling MKL routines such as SPMM or SPMV.
|
||||
//
|
||||
// If not calling MKL, it is fine to use 64-bit integer tensors
|
||||
// for indexing.
|
||||
struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
|
||||
Tensor crow_indices_;
|
||||
Tensor col_indices_;
|
||||
Tensor values_;
|
||||
|
||||
public:
|
||||
explicit SparseCsrTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta);
|
||||
|
||||
void resize_and_clear_(const int64_t nnz_size, IntArrayRef size);
|
||||
void resize_as_sparse_csr_tensor_(const Tensor& src);
|
||||
void set_member_tensors(
|
||||
const Tensor& crow_indices,
|
||||
const Tensor& col_indices,
|
||||
const Tensor& values);
|
||||
|
||||
const Tensor& crow_indices() const { return crow_indices_; }
|
||||
const Tensor& col_indices() const { return col_indices_; }
|
||||
const Tensor& values() const { return values_; }
|
||||
int nnz() { return values_.size(0); }
|
||||
|
||||
private:
|
||||
explicit SparseCsrTensorImpl(
|
||||
at::DispatchKeySet key_set,
|
||||
const caffe2::TypeMeta data_type,
|
||||
at::Tensor crow_indices,
|
||||
at::Tensor col_indices,
|
||||
at::Tensor values);
|
||||
};
|
||||
} // namespace at
|
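The header comment above describes the three member tensors abstractly; a small worked example makes the layout concrete. The following standalone sketch (plain C++, no ATen dependency, matrix values made up for illustration) encodes a 3 x 4 matrix with four non-zeros and walks it row by row.

#include <cstdio>

int main() {
  // Dense view of the matrix being encoded (3 x 4):
  //   [ 10  0  0 20 ]
  //   [  0  0  0  0 ]
  //   [  0 30 40  0 ]
  const int    crow_indices[] = {0, 2, 2, 4};          // length size(0) + 1
  const int    col_indices[]  = {0, 3, 1, 2};          // length nnz
  const double values[]       = {10., 20., 30., 40.};  // length nnz

  const int nrows = 3;
  for (int row = 0; row < nrows; ++row) {
    // Row `row` owns the entries in [crow_indices[row], crow_indices[row + 1]).
    for (int i = crow_indices[row]; i < crow_indices[row + 1]; ++i) {
      std::printf("A[%d][%d] = %g\n", row, col_indices[i], values[i]);
    }
  }
  return 0;
}

Note that crow_indices[0] is always 0 and the last entry equals nnz, which is exactly what the construction checks in SparseCsrTensor.cpp below verify.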
aten/src/ATen/SparseCsrTensorUtils.h (new file, 20 lines)
@@ -0,0 +1,20 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/SparseTensorUtils.h>

namespace at {
namespace sparse_csr {

using SparseCsrTensor = Tensor;

inline SparseCsrTensorImpl* get_sparse_csr_impl(const SparseCsrTensor& self) {
  AT_ASSERTM(
      self.is_sparse_csr(),
      "_internal_get_SparseCsrTensorImpl: not a sparse CSR tensor");
  return static_cast<SparseCsrTensorImpl*>(self.unsafeGetTensorImpl());
}
} // namespace sparse_csr
} // namespace at
@ -34,6 +34,10 @@ class TORCH_API DeprecatedTypeProperties {
|
||||
return layout_from_backend(backend()) == kSparse;
|
||||
}
|
||||
|
||||
bool is_sparse_csr() const {
|
||||
return layout_from_backend(backend()) == kSparseCsr;
|
||||
}
|
||||
|
||||
DeviceType device_type() const {
|
||||
return backendToDeviceType(backend_);
|
||||
}
|
||||
|
@ -402,6 +402,7 @@ _(aten, is_same_size) \
|
||||
_(aten, is_set_to) \
|
||||
_(aten, is_signed) \
|
||||
_(aten, is_sparse) \
|
||||
_(aten, is_sparse_csr) \
|
||||
_(aten, isclose) \
|
||||
_(aten, isreal) \
|
||||
_(aten, istft) \
|
||||
|
@ -85,7 +85,7 @@ Tensor& resize_as_(
|
||||
!optional_memory_format.has_value(),
|
||||
"Unsupported memory format for sparse tensor resize_as_ :",
|
||||
optional_memory_format.value());
|
||||
return native::resize_as_sparse_(self, the_template);
|
||||
return at::native::resize_as_sparse_(self, the_template);
|
||||
}
|
||||
Tensor& result = self.resize_(the_template.sizes());
|
||||
if (optional_memory_format.has_value()) {
|
||||
|
@ -30,6 +30,10 @@ bool is_sparse(const Tensor& self) {
|
||||
return self.is_sparse();
|
||||
}
|
||||
|
||||
bool is_sparse_csr(const Tensor& self) {
|
||||
return self.is_sparse_csr();
|
||||
}
|
||||
|
||||
bool is_quantized(const Tensor& self) {
|
||||
return self.is_quantized();
|
||||
}
|
||||
|
aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp (new file, 236 lines)
@@ -0,0 +1,236 @@
|
||||
#include <ATen/native/mkl/SparseCsrLinearAlgebra.h>
|
||||
|
||||
// Don't compile with MKL for MSVC/macOS since linking the sparse MKL routines
|
||||
// needs some build fixes.
|
||||
// https://github.com/pytorch/pytorch/pull/50937#issuecomment-778732740
|
||||
// Macros source:
|
||||
// https://web.archive.org/web/20191012035921/http://nadeausoftware.com/articles/2012/01/c_c_tip_how_use_compiler_predefined_macros_detect_operating_system
|
||||
#if !AT_MKL_ENABLED() || defined(_MSC_VER) || defined(__APPLE__) || \
|
||||
defined(__MACH__)
|
||||
|
||||
namespace at {
|
||||
namespace sparse_csr {
|
||||
Tensor& _sparse_mm_mkl_(
|
||||
Tensor& self,
|
||||
const SparseCsrTensor& sparse_,
|
||||
const Tensor& dense,
|
||||
const Tensor& t,
|
||||
const Scalar& alpha,
|
||||
const Scalar& beta) {
|
||||
#if _MSC_VER
|
||||
AT_ERROR("sparse_mm_mkl: MKL support is disabled on Windows");
|
||||
#elif __APPLE__ || __MACH__
|
||||
AT_ERROR("sparse_mm_mkl: MKL support is disabled on macos/iOS.");
|
||||
#else
|
||||
AT_ERROR("sparse_mm_mkl: ATen not compiled with MKL support");
|
||||
#endif
|
||||
return self; // for stopping compiler warnings.
|
||||
}
|
||||
} // namespace sparse_csr
|
||||
} // namespace at
|
||||
|
||||
#else // AT_MKL_ENABLED
|
||||
|
||||
#include <ATen/mkl/Descriptors.h>
|
||||
#include <ATen/mkl/Exceptions.h>
|
||||
#include <ATen/mkl/Limits.h>
|
||||
#include <mkl.h>
|
||||
#include <mkl_spblas.h>
|
||||
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/SparseCsrTensorImpl.h>
|
||||
|
||||
namespace at {
|
||||
namespace sparse_csr {
|
||||
|
||||
#ifdef MKL_ILP64
|
||||
static constexpr ScalarType TORCH_INT_TYPE = at::kLong;
|
||||
#else
|
||||
static constexpr ScalarType TORCH_INT_TYPE = at::kInt;
|
||||
#endif
|
||||
|
||||
class SparseCsrMKLInterface {
|
||||
private:
|
||||
sparse_matrix_t A = 0;
|
||||
matrix_descr desc;
|
||||
|
||||
public:
|
||||
SparseCsrMKLInterface(
|
||||
MKL_INT* col_indices,
|
||||
MKL_INT* crow_indices,
|
||||
double* values,
|
||||
MKL_INT nrows,
|
||||
MKL_INT ncols) {
|
||||
desc.type = SPARSE_MATRIX_TYPE_GENERAL;
|
||||
int retval = mkl_sparse_d_create_csr(
|
||||
&A,
|
||||
SPARSE_INDEX_BASE_ZERO,
|
||||
nrows,
|
||||
ncols,
|
||||
crow_indices,
|
||||
crow_indices + 1,
|
||||
col_indices,
|
||||
values);
|
||||
TORCH_CHECK(
|
||||
retval == 0,
|
||||
"mkl_sparse_d_create_csr failed with error code: ",
|
||||
retval);
|
||||
}
|
||||
|
||||
SparseCsrMKLInterface(
|
||||
MKL_INT* col_indices,
|
||||
MKL_INT* crow_indices,
|
||||
float* values,
|
||||
MKL_INT nrows,
|
||||
MKL_INT ncols) {
|
||||
desc.type = SPARSE_MATRIX_TYPE_GENERAL;
|
||||
int retval = mkl_sparse_s_create_csr(
|
||||
&A,
|
||||
SPARSE_INDEX_BASE_ZERO,
|
||||
nrows,
|
||||
ncols,
|
||||
crow_indices,
|
||||
crow_indices + 1,
|
||||
col_indices,
|
||||
values);
|
||||
TORCH_CHECK(
|
||||
retval == 0,
|
||||
"mkl_sparse_s_create_csr failed with error code: ",
|
||||
retval);
|
||||
}
|
||||
|
||||
inline void sparse_mm(
|
||||
float* res,
|
||||
float* dense,
|
||||
float alpha,
|
||||
float beta,
|
||||
MKL_INT nrows,
|
||||
MKL_INT ncols,
|
||||
MKL_INT dense_ncols) {
|
||||
int stat = mkl_sparse_s_mm(
|
||||
SPARSE_OPERATION_NON_TRANSPOSE,
|
||||
alpha,
|
||||
A,
|
||||
desc,
|
||||
SPARSE_LAYOUT_ROW_MAJOR,
|
||||
dense,
|
||||
dense_ncols,
|
||||
dense_ncols,
|
||||
beta,
|
||||
res,
|
||||
dense_ncols);
|
||||
TORCH_CHECK(stat == 0, "mkl_sparse_s_mm failed with error code: ", stat);
|
||||
}
|
||||
|
||||
inline void sparse_mm(
|
||||
double* res,
|
||||
double* dense,
|
||||
double alpha,
|
||||
double beta,
|
||||
MKL_INT nrows,
|
||||
MKL_INT ncols,
|
||||
MKL_INT dense_ncols) {
|
||||
int stat = mkl_sparse_d_mm(
|
||||
SPARSE_OPERATION_NON_TRANSPOSE,
|
||||
alpha,
|
||||
A,
|
||||
desc,
|
||||
SPARSE_LAYOUT_ROW_MAJOR,
|
||||
dense,
|
||||
dense_ncols,
|
||||
dense_ncols,
|
||||
beta,
|
||||
res,
|
||||
dense_ncols);
|
||||
TORCH_CHECK(stat == 0, "mkl_sparse_d_mm failed with error code: ", stat);
|
||||
}
|
||||
|
||||
~SparseCsrMKLInterface() {
|
||||
mkl_sparse_destroy(A);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t>
|
||||
static inline void sparse_mm_mkl_template(
|
||||
Tensor& res,
|
||||
const Tensor& col_indices,
|
||||
const Tensor& crow_indices,
|
||||
const Tensor& values,
|
||||
const Tensor& dense,
|
||||
const Tensor& t,
|
||||
const Scalar& alpha,
|
||||
const Scalar& beta,
|
||||
IntArrayRef size,
|
||||
IntArrayRef dense_size) {
|
||||
SparseCsrMKLInterface mkl_impl(
|
||||
col_indices.data_ptr<MKL_INT>(),
|
||||
crow_indices.data_ptr<MKL_INT>(),
|
||||
values.data_ptr<scalar_t>(),
|
||||
size[0],
|
||||
size[1]);
|
||||
mkl_impl.sparse_mm(
|
||||
res.data_ptr<scalar_t>(),
|
||||
dense.data_ptr<scalar_t>(),
|
||||
alpha.to<scalar_t>(),
|
||||
beta.to<scalar_t>(),
|
||||
size[0],
|
||||
size[1],
|
||||
dense_size[1]);
|
||||
}
|
||||
|
||||
static bool inline constexpr is_mkl_int32_index() {
|
||||
#ifdef MKL_ILP64
|
||||
return false;
|
||||
#else
|
||||
return true;
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor& _sparse_mm_mkl_(
|
||||
Tensor& self,
|
||||
const SparseCsrTensor& sparse_,
|
||||
const Tensor& dense,
|
||||
const Tensor& t,
|
||||
const Scalar& alpha,
|
||||
const Scalar& beta) {
|
||||
if (is_mkl_int32_index()) {
|
||||
if (sparse_.crow_indices().scalar_type() != kInt) {
|
||||
TORCH_WARN(
|
||||
"Pytorch is compiled with MKL LP64 and will convert crow_indices to int32.");
|
||||
}
|
||||
if (sparse_.col_indices().scalar_type() != kInt) {
|
||||
TORCH_WARN(
|
||||
"Pytorch is compiled with MKL LP64 and will convert col_indices to int32.");
|
||||
}
|
||||
} else { // This is for future-proofing if we ever switch to MKL ILP64.
|
||||
if (sparse_.crow_indices().scalar_type() != kLong) {
|
||||
TORCH_WARN(
|
||||
"Pytorch is compiled with MKL ILP64 and will convert crow_indices dtype to int64.");
|
||||
}
|
||||
if (sparse_.col_indices().scalar_type() != kLong) {
|
||||
TORCH_WARN(
|
||||
"Pytorch is compiled with MKL ILP64 and will convert col_indices dtype to int64.");
|
||||
}
|
||||
}
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
dense.scalar_type(), "addmm_sparse_csr_dense", [&] {
|
||||
sparse_mm_mkl_template<scalar_t>(
|
||||
self,
|
||||
sparse_.col_indices().to(TORCH_INT_TYPE),
|
||||
sparse_.crow_indices().to(TORCH_INT_TYPE),
|
||||
sparse_.values(),
|
||||
dense,
|
||||
t,
|
||||
alpha,
|
||||
beta,
|
||||
sparse_.sizes(),
|
||||
dense.sizes());
|
||||
});
|
||||
return self;
|
||||
}
|
||||
|
||||
} // namespace sparse_csr
|
||||
} // namespace at
|
||||
|
||||
#endif // AT_MKL_ENABLED
|
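One detail of the wrapper above that is easy to miss: mkl_sparse_{s,d}_create_csr receives the compressed row array twice, as crow_indices and crow_indices + 1, because MKL's four-array CSR variant expects separate rows_start and rows_end pointers. The sketch below (plain C++, no MKL required, made-up 3-row matrix) shows why two offset views of one array suffice.

#include <cstdio>

int main() {
  // Compressed row pointer for a 3-row matrix with 4 non-zeros.
  const int crow_indices[] = {0, 2, 2, 4};
  const int* rows_start = crow_indices;      // rows_start[i] = first entry of row i
  const int* rows_end   = crow_indices + 1;  // rows_end[i]   = one past the last entry of row i

  for (int row = 0; row < 3; ++row) {
    std::printf("row %d spans [%d, %d)\n", row, rows_start[row], rows_end[row]);
  }
  return 0;
}

This is also why the index width matters: with LP64 MKL (the default assumed here), MKL_INT is 32-bit, so _sparse_mm_mkl_ converts crow_indices and col_indices to int32 before handing the pointers to MKL.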
aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.h (new file, 14 lines)
@@ -0,0 +1,14 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/SparseCsrTensorUtils.h>
|
||||
|
||||
namespace at {
|
||||
namespace sparse_csr {
|
||||
Tensor& _sparse_mm_mkl_(
|
||||
Tensor& self,
|
||||
const SparseCsrTensor& sparse_,
|
||||
const Tensor& dense,
|
||||
const Tensor& t,
|
||||
const Scalar& alpha,
|
||||
const Scalar& beta);
|
||||
} // namespace sparse_csr
|
||||
} // namespace at
|
@ -330,6 +330,7 @@
|
||||
variants: function, method
|
||||
dispatch:
|
||||
SparseCPU, SparseCUDA: add_sparse
|
||||
SparseCsrCPU: add_sparse_csr
|
||||
MkldnnCPU: mkldnn_add
|
||||
|
||||
- func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
|
||||
@ -337,6 +338,7 @@
|
||||
structured_delegate: add.out
|
||||
dispatch:
|
||||
SparseCPU, SparseCUDA: add_sparse_
|
||||
SparseCsrCPU: add_sparse_csr_
|
||||
MkldnnCPU: mkldnn_add_
|
||||
|
||||
- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
|
||||
@ -346,6 +348,7 @@
|
||||
CPU, CUDA: add_out
|
||||
SparseCPU: add_out_sparse_cpu
|
||||
SparseCUDA: add_out_sparse_cuda
|
||||
SparseCsrCPU: add_out_sparse_csr_cpu
|
||||
MkldnnCPU: mkldnn_add_out
|
||||
|
||||
- func: _add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
|
||||
@ -388,6 +391,7 @@
|
||||
dispatch:
|
||||
CPU, CUDA: addmv_out
|
||||
|
||||
|
||||
- func: _addmv_impl_(Tensor(a!) self, Tensor self2, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU: addmv_impl_cpu
|
||||
@ -2551,13 +2555,14 @@
|
||||
dispatch:
|
||||
CPU: mm_cpu
|
||||
CUDA: mm_cuda
|
||||
SparseCPU, SparseCUDA: _sparse_mm
|
||||
SparseCPU, SparseCUDA, SparseCsrCPU: _sparse_mm
|
||||
|
||||
- func: mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU: mm_cpu_out
|
||||
CUDA: mm_out_cuda
|
||||
SparseCPU, SparseCUDA: _sparse_mm_out
|
||||
SparseCsrCPU: _sparse_csr_mm_out
|
||||
|
||||
- func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
|
||||
|
||||
@ -2638,7 +2643,7 @@
|
||||
variants: function, method
|
||||
dispatch:
|
||||
CPU, CUDA: mv
|
||||
SparseCPU, SparseCUDA: mv_sparse
|
||||
SparseCPU, SparseCUDA, SparseCsrCPU: mv_sparse
|
||||
|
||||
- func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
@ -4001,6 +4006,12 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: resize_as_
|
||||
|
||||
- func: resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!)
|
||||
variants: function
|
||||
dispatch:
|
||||
SparseCPU, SparseCUDA: resize_as_sparse_
|
||||
SparseCsrCPU: resize_as_sparse_csr_
|
||||
|
||||
- func: zero_(Tensor(a!) self) -> Tensor(a!)
|
||||
variants: method, function
|
||||
dispatch:
|
||||
@ -4088,6 +4099,7 @@
|
||||
CUDA: addmm_out_cuda
|
||||
SparseCPU: addmm_out_sparse_dense_cpu
|
||||
SparseCUDA: addmm_out_sparse_dense_cuda
|
||||
SparseCsrCPU: addmm_out_sparse_csr_dense_cpu
|
||||
|
||||
- func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
|
||||
variants: function, method
|
||||
@ -4096,6 +4108,7 @@
|
||||
CUDA: addmm_cuda
|
||||
SparseCPU: addmm_sparse_dense_cpu
|
||||
SparseCUDA: addmm_sparse_dense_cuda
|
||||
SparseCsrCPU: addmm_sparse_csr_dense_cpu
|
||||
|
||||
- func: addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
|
||||
variants: method
|
||||
@ -4215,9 +4228,13 @@
|
||||
# the view relation is not tracked by autograd, but the version counter is still
|
||||
# shared. In other words, their outputs are non-differentiable views of the
|
||||
# sparse tensor.
|
||||
|
||||
# FIXME: would be nicer if TensorOptions was optional-based; not adding default arguments for options given
# that the default would never make sense.
|
||||
|
||||
- func: sparse_csr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
|
||||
|
||||
- func: sparse_csr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
|
||||
|
||||
- func: sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
|
||||
|
||||
- func: sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||
@ -4255,7 +4272,7 @@
|
||||
- func: to_dense(Tensor self, ScalarType? dtype=None) -> Tensor
|
||||
variants: method
|
||||
dispatch:
|
||||
SparseCPU, SparseCUDA: sparse_to_dense
|
||||
SparseCPU, SparseCUDA, SparseCsrCPU: sparse_to_dense
|
||||
MkldnnCPU: mkldnn_to_dense
|
||||
|
||||
- func: to_dense_backward(Tensor grad, Tensor input) -> Tensor
|
||||
@ -4290,6 +4307,7 @@
|
||||
variants: method
|
||||
dispatch:
|
||||
SparseCPU, SparseCUDA: _nnz_sparse
|
||||
SparseCsrCPU: _nnz_sparse_csr
|
||||
device_guard: False
|
||||
|
||||
# NOTE: [ coalesce autograd ]
|
||||
@ -4342,6 +4360,19 @@
|
||||
variants: method
|
||||
dispatch:
|
||||
SparseCPU, SparseCUDA: values_sparse
|
||||
SparseCsrCPU: values_sparse_csr
|
||||
device_guard: False
|
||||
|
||||
- func: crow_indices(Tensor(a) self) -> Tensor(a)
|
||||
variants: method
|
||||
dispatch:
|
||||
SparseCsrCPU: crow_indices_sparse_csr
|
||||
device_guard: False
|
||||
|
||||
- func: col_indices(Tensor(a) self) -> Tensor(a)
|
||||
variants: method
|
||||
dispatch:
|
||||
SparseCsrCPU: col_indices_sparse_csr
|
||||
device_guard: False
|
||||
|
||||
- func: hspmm.out(Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
|
||||
|
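The schema entries above define the user-visible surface added by this PR: the sparse_csr_tensor constructors, the crow_indices/col_indices/values/_nnz accessors, and mm/addmm dispatch for a SparseCsrCPU operand. A rough C++ sketch of how these pieces fit together is shown below; it assumes the generated at:: wrappers for these schemas and the at::tensor ArrayRef factory, and the 2 x 3 matrix is made up for illustration.

#include <ATen/ATen.h>
#include <vector>

int main() {
  // Hypothetical 2 x 3 CSR matrix: [[1, 0, 2], [0, 3, 0]].
  auto int_opts = at::TensorOptions().dtype(at::kInt);
  at::Tensor crow = at::tensor(std::vector<int>{0, 2, 3}, int_opts);
  at::Tensor col  = at::tensor(std::vector<int>{0, 2, 1}, int_opts);
  at::Tensor vals = at::tensor(std::vector<double>{1.0, 2.0, 3.0},
                               at::TensorOptions().dtype(at::kDouble));

  at::Tensor sp = at::sparse_csr_tensor(
      crow, col, vals, {2, 3},
      at::TensorOptions().dtype(at::kDouble).layout(at::kSparseCsr));

  // Accessors registered above (method variants).
  at::Tensor rows = sp.crow_indices();
  at::Tensor vs   = sp.values();
  int64_t nnz     = sp._nnz();

  // addmm with a sparse CSR mat1 dispatches to addmm_sparse_csr_dense_cpu.
  at::Tensor bias  = at::zeros({2, 4}, vals.options());
  at::Tensor dense = at::ones({3, 4}, vals.options());
  at::Tensor out   = at::addmm(bias, sp, dense); // bias + sp @ dense
  (void)rows; (void)vs; (void)nnz; (void)out;
  return 0;
}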
aten/src/ATen/native/sparse/SparseCsrTensor.cpp (new file, 159 lines)
@@ -0,0 +1,159 @@
|
||||
// Basic functions on sparse tensors
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/InitialTensorOptions.h>
|
||||
#include <ATen/Layout.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/SparseCsrTensorImpl.h>
|
||||
#include <ATen/SparseCsrTensorUtils.h>
|
||||
#include <ATen/SparseTensorImpl.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
using namespace at::sparse_csr;
|
||||
|
||||
// Construction of CSR tensors.
|
||||
SparseCsrTensor new_csr_tensor(const TensorOptions& options) {
|
||||
// TODO: remove this comment after enabling autograd support for CSR tensor
|
||||
// constructor.
|
||||
// TORCH_INTERNAL_ASSERT(impl::variable_excluded_from_dispatch());
|
||||
TORCH_INTERNAL_ASSERT(options.layout() == kSparseCsr);
|
||||
DispatchKey dispatch_key;
|
||||
|
||||
if (options.device().is_cuda()) {
|
||||
dispatch_key = DispatchKey::SparseCsrCUDA;
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(options.device().is_cpu());
|
||||
dispatch_key = DispatchKey::SparseCsrCPU;
|
||||
}
|
||||
|
||||
return detail::make_tensor<SparseCsrTensorImpl>(
|
||||
DispatchKeySet(dispatch_key), options.dtype());
|
||||
}
|
||||
|
||||
// TODO: This constructor should probably use an ATen abstract method in order
|
||||
// to make autograd dispatch available for the CSR constructor. See the relevant
|
||||
// note in native_functions.yaml.
|
||||
Tensor sparse_csr_tensor(
|
||||
const Tensor& crow_indices,
|
||||
const Tensor& col_indices,
|
||||
const Tensor& values,
|
||||
IntArrayRef size,
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
c10::optional<Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
// See [Note: hacky wrapper removal for TensorOptions]
|
||||
TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
|
||||
TORCH_CHECK(
|
||||
options.layout() == kSparseCsr,
|
||||
"expected sparse CSR layout, but got layout ",
|
||||
options.layout());
|
||||
|
||||
AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "csr_construct_check", [&] {
|
||||
auto crow_indices_accessor = crow_indices.accessor<index_t, 1>();
|
||||
TORCH_CHECK(
|
||||
crow_indices_accessor[crow_indices.numel() - 1] <= col_indices.numel(),
|
||||
"last value of crow_indices should be less than length of col_indices.");
|
||||
TORCH_CHECK(
|
||||
crow_indices_accessor[0] == 0, "0th value of crow_indices must be 0.");
|
||||
});
|
||||
|
||||
TORCH_CHECK(
|
||||
crow_indices.dim() == 1,
|
||||
"crow_indices must have dim=1 but got crow_indices.dim()=",
|
||||
crow_indices.dim());
|
||||
TORCH_CHECK(
|
||||
col_indices.dim() == 1,
|
||||
"col_indices must have dim=1 but got col_indices.dim()=",
|
||||
col_indices.dim());
|
||||
TORCH_CHECK(
|
||||
values.dim() == 1,
|
||||
"values must have dim=1 but got values.dim()=",
|
||||
values.dim());
|
||||
|
||||
TORCH_CHECK(
|
||||
(crow_indices.numel() - 1) == size[0],
|
||||
"crow_indices.numel() must be size(0) + 1, but got: ",
|
||||
crow_indices.numel());
|
||||
|
||||
SparseCsrTensor self = new_csr_tensor(options);
|
||||
get_sparse_csr_impl(self)->resize_and_clear_(values.numel(), size);
|
||||
get_sparse_csr_impl(self)->set_member_tensors(
|
||||
crow_indices, col_indices, values);
|
||||
return self;
|
||||
}
|
||||
|
||||
Tensor sparse_csr_tensor(
|
||||
const Tensor& crow_indices,
|
||||
const Tensor& col_indices,
|
||||
const Tensor& values,
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
c10::optional<Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
// See [Note: hacky wrapper removal for TensorOptions]
|
||||
TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
|
||||
|
||||
TORCH_CHECK(
|
||||
options.layout() == kSparseCsr,
|
||||
"expected sparse CSR layout, but got layout ",
|
||||
options.layout());
|
||||
TORCH_CHECK(crow_indices.numel() >= 1, "expected crow_indices.numel() >= 1, but got ",
|
||||
crow_indices.numel());
|
||||
std::array<int64_t, 2> size;
|
||||
|
||||
if (col_indices.numel() > 0) {
|
||||
size[0] = crow_indices.numel() - 1;
|
||||
Tensor max_col_indices = std::get<0>(col_indices.max(0, false));
|
||||
size[1] = *max_col_indices.data_ptr<int64_t>() + 1;
|
||||
} else {
|
||||
size[0] = 0;
|
||||
size[1] = 0;
|
||||
}
|
||||
|
||||
return at::sparse_csr_tensor(
|
||||
crow_indices, col_indices, values, size, options);
|
||||
}
|
||||
|
||||
// Access members of CSR tensors.
|
||||
int64_t _nnz_sparse_csr(const SparseCsrTensor& self) {
|
||||
return get_sparse_csr_impl(self)->nnz();
|
||||
}
|
||||
|
||||
Tensor values_sparse_csr(const Tensor& self) {
|
||||
return get_sparse_csr_impl(self)->values().alias();
|
||||
}
|
||||
|
||||
Tensor crow_indices_sparse_csr(const Tensor& self) {
|
||||
return get_sparse_csr_impl(self)->crow_indices().alias();
|
||||
}
|
||||
|
||||
Tensor col_indices_sparse_csr(const Tensor& self) {
|
||||
return get_sparse_csr_impl(self)->col_indices().alias();
|
||||
}
|
||||
|
||||
bool _is_same_size_as_sparse_csr(
|
||||
const SparseCsrTensor& self,
|
||||
const SparseCsrTensor& src) {
|
||||
return self.sizes().equals(src.sizes());
|
||||
}
|
||||
|
||||
SparseCsrTensor& resize_as_sparse_csr_(
|
||||
SparseCsrTensor& self,
|
||||
const SparseCsrTensor& src) {
|
||||
TORCH_CHECK(
|
||||
src.is_sparse_csr() && self.is_sparse_csr(),
|
||||
"resize_as_sparse_csr_: layout for self and src must be sparse_csr but got self, src: ",
|
||||
self.layout(),
|
||||
src.layout());
|
||||
if (!_is_same_size_as_sparse_csr(self, src)) {
|
||||
get_sparse_csr_impl(self)->resize_as_sparse_csr_tensor_(src);
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
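When the size-less overload of sparse_csr_tensor above is used, the shape is inferred from the indices: size[0] = crow_indices.numel() - 1 and size[1] = max(col_indices) + 1. A worked case with made-up numbers, on plain arrays:

#include <algorithm>
#include <cstdio>

int main() {
  // Made-up inputs to the size-less overload.
  const long crow_indices[] = {0, 2, 2, 4};  // numel() == 4
  const long col_indices[]  = {0, 3, 1, 2};  // nnz == 4

  long size0 = 4 - 1;                                               // crow_indices.numel() - 1
  long size1 = *std::max_element(col_indices, col_indices + 4) + 1; // max(col_indices) + 1
  std::printf("inferred size = (%ld, %ld)\n", size0, size1);        // prints (3, 4)
  return 0;
}

Note that the implementation reads the maximum through data_ptr<int64_t>(), so this overload effectively assumes 64-bit col_indices.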
aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp (new file, 331 lines)
@@ -0,0 +1,331 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/InitialTensorOptions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/SparseCsrTensorImpl.h>
|
||||
#include <ATen/SparseCsrTensorUtils.h>
|
||||
#include <ATen/SparseTensorUtils.h>
|
||||
#include <ATen/WrapDimUtilsMulti.h>
|
||||
#include <ATen/native/BinaryOps.h>
|
||||
#include <ATen/native/CPUBlas.h>
|
||||
#include <ATen/native/mkl/SparseCsrLinearAlgebra.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
using namespace at::sparse_csr;
|
||||
// Certain utility functions are usable from sparse COO.
|
||||
using namespace at::sparse;
|
||||
|
||||
static constexpr bool is_msvc() {
|
||||
#ifdef _MSC_VER
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Functions for matrix multiplication.
|
||||
Tensor& addmm_out_sparse_csr_dense_cpu(
|
||||
const Tensor& self,
|
||||
const SparseCsrTensor& op1,
|
||||
const Tensor& op2,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha,
|
||||
Tensor& out) {
|
||||
AT_ASSERT(op1.is_sparse_csr());
|
||||
Tensor expand_self = *expand_size(self, {op1.size(0), op2.size(1)}, "addmm_out_sparse_csr");
|
||||
|
||||
AT_ASSERT(expand_self.device().type() == kCPU);
|
||||
TORCH_CHECK(
|
||||
out.device().type() == kCPU,
|
||||
"addmm: expected 'out' to be CPU tensor, but got CUDA tensor");
|
||||
TORCH_CHECK(
|
||||
op1.device().type() == kCPU,
|
||||
"addmm: expected 'mat1' to be a CPU tensor, but got a CUDA tensor");
|
||||
TORCH_CHECK(
|
||||
op2.device().type() == kCPU,
|
||||
"addmm: expected 'mat2' to be a CPU tensor, but got a CUDA tensor");
|
||||
|
||||
TORCH_CHECK(
|
||||
op1.dim() == 2,
|
||||
"addmm: 2-D matrices expected, got ",
|
||||
op1.dim(),
|
||||
"D tensor");
|
||||
TORCH_CHECK(
|
||||
op2.dim() == 2,
|
||||
"addmm: 2-D matrices expected, got ",
|
||||
op2.dim(),
|
||||
"D tensor");
|
||||
|
||||
TORCH_CHECK(
|
||||
out.is_contiguous(),
|
||||
"out argument must be contiguous, but got: ",
|
||||
out.suggest_memory_format());
|
||||
|
||||
// ixk * kxj = ixj
|
||||
int64_t dim_i = op1.size(0);
|
||||
int64_t dim_j = op2.size(1);
|
||||
int64_t dim_k = op1.size(1);
|
||||
|
||||
TORCH_CHECK(
|
||||
op2.size(0) == dim_k,
|
||||
"addmm: Expected dense matrix (op2) size(0)=",
|
||||
dim_k,
|
||||
", got ",
|
||||
op2.size(0));
|
||||
TORCH_CHECK(
|
||||
op1.size(1) == dim_k,
|
||||
"addmm: Expected sparse matrix (op1) size(1)=",
|
||||
dim_k,
|
||||
", got ",
|
||||
op1.size(1));
|
||||
out.resize_({dim_i, dim_j});
|
||||
|
||||
auto col_indices = op1.col_indices();
|
||||
auto crow_indices = op1.crow_indices();
|
||||
auto values = op1.values();
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
values.scalar_type(), "addmm_sparse_csr_dense", [&] {
|
||||
scalar_t cast_beta = beta.to<scalar_t>();
|
||||
if (!is_same_tensor(out, expand_self)) {
|
||||
out.copy_(expand_self);
|
||||
}
|
||||
if (cast_beta == 0) {
|
||||
out.zero_();
|
||||
} else {
|
||||
at::mul_out(out, expand_self, scalar_to_tensor(beta));
|
||||
}
|
||||
});
|
||||
|
||||
// Do not use MKL for Windows due to linking issues with sparse MKL routines.
|
||||
if (at::hasMKL() && !is_msvc()) {
|
||||
_sparse_mm_mkl_(out, op1, op2, expand_self, alpha, beta);
|
||||
} else {
|
||||
int64_t dense_stride0 = op1.stride(0);
|
||||
int64_t dense_stride1 = op1.stride(1);
|
||||
int64_t out_stride0 = out.stride(0);
|
||||
int64_t out_stride1 = out.stride(1);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
values.scalar_type(),
|
||||
"sparse_csr_mm_cpu",
|
||||
[&alpha,
|
||||
&beta,
|
||||
&op1,
|
||||
&out,
|
||||
&values,
|
||||
&crow_indices,
|
||||
&col_indices,
|
||||
&dense_stride0,
|
||||
&dense_stride1,
|
||||
&out_stride0,
|
||||
&out_stride1,
|
||||
&dim_k]() {
|
||||
AT_DISPATCH_INDEX_TYPES(
|
||||
crow_indices.scalar_type(),
|
||||
"csr_mm_crow_indices",
|
||||
[&alpha,
|
||||
&beta,
|
||||
&op1,
|
||||
&out,
|
||||
&values,
|
||||
&crow_indices,
|
||||
&col_indices,
|
||||
&dense_stride0,
|
||||
&dense_stride1,
|
||||
&out_stride0,
|
||||
&out_stride1,
|
||||
&dim_k]() {
|
||||
scalar_t cast_alpha = alpha.to<scalar_t>();
|
||||
scalar_t cast_beta = beta.to<scalar_t>();
|
||||
scalar_t* dense_ptr = op1.data_ptr<scalar_t>();
|
||||
scalar_t* out_ptr = out.data_ptr<scalar_t>();
|
||||
|
||||
auto col_indices_accessor = col_indices.accessor<index_t, 1>();
|
||||
auto crow_indices_accessor =
|
||||
crow_indices.accessor<index_t, 1>();
|
||||
auto values_accessor = values.accessor<scalar_t, 1>();
|
||||
|
||||
at::parallel_for(
|
||||
0,
|
||||
crow_indices.size(0) - 1,
|
||||
internal::GRAIN_SIZE,
|
||||
[&](int64_t irow_start, int64_t irow_end) {
|
||||
for (int irow = irow_start; irow < irow_end; ++irow) {
|
||||
int start_index = crow_indices_accessor[irow];
|
||||
int end_index = crow_indices_accessor[irow + 1];
|
||||
|
||||
for (int i = start_index; i < end_index; ++i) {
|
||||
auto val = values_accessor[i];
|
||||
auto icol = col_indices_accessor[i];
|
||||
|
||||
at::native::cpublas::axpy<scalar_t>(
|
||||
dim_k,
|
||||
cast_alpha * val,
|
||||
dense_ptr + icol * dense_stride0,
|
||||
dense_stride1,
|
||||
out_ptr + irow * out_stride0,
|
||||
out_stride1);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
Tensor addmm_sparse_csr_dense_cpu(
|
||||
const Tensor& self,
|
||||
const SparseCsrTensor& sparse,
|
||||
const Tensor& dense,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha) {
|
||||
Tensor r = at::empty({0}, self.options());
|
||||
at::addmm_out(r, self, sparse, dense, beta, alpha);
|
||||
return r;
|
||||
}
|
||||
|
||||
SparseCsrTensor& _sparse_csr_mm_out(
|
||||
const SparseCsrTensor& sparse,
|
||||
const Tensor& dense,
|
||||
SparseCsrTensor& result) {
|
||||
Tensor t = at::zeros({}, dense.options());
|
||||
return at::addmm_out(result, t, sparse, dense, 0.0, 1.0); // redispatch!
|
||||
}
|
||||
|
||||
Tensor _sparse_csr_addmm(
|
||||
const Tensor& t,
|
||||
const SparseCsrTensor& sparse,
|
||||
const Tensor& dense,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha) {
|
||||
// _sparse_addmm forward is functionally equivalent to addmm; it's
// just the backward that is different. This technically does an
// unnecessary redispatch; avoiding it is left for a follow-up.
|
||||
return at::addmm(t, sparse, dense, beta, alpha);
|
||||
}
|
||||
|
||||
// Functions for element-wise addition.
|
||||
Tensor add_sparse_csr(const Tensor& self, const Tensor& other, const Scalar& alpha) {
|
||||
auto commonDtype = at::result_type(self, other);
|
||||
alpha_check(commonDtype, alpha);
|
||||
Tensor result = at::empty({0}, self.options().dtype(commonDtype));
|
||||
return at::add_out(result, self, other, alpha); // redispatch!
|
||||
}
|
||||
|
||||
Tensor& add_sparse_csr_(Tensor& self, const Tensor& other, const Scalar& alpha) {
|
||||
return at::add_out(self, self, other, alpha); // redispatch!
|
||||
}
|
||||
|
||||
Tensor& add_out_dense_sparse_csr_cpu(
|
||||
Tensor& out,
|
||||
const Tensor& dense,
|
||||
const SparseCsrTensor& src,
|
||||
const Scalar& alpha) {
|
||||
AT_ASSERT(dense.layout() == kStrided);
|
||||
AT_ASSERT(src.is_sparse_csr());
|
||||
AT_ASSERT(dense.device() == kCPU);
|
||||
|
||||
TORCH_CHECK(
|
||||
out.is_contiguous(),
|
||||
"out argument must be contiguous, but got: ",
|
||||
out.suggest_memory_format());
|
||||
TORCH_CHECK(
|
||||
out.device() == kCPU,
|
||||
"add: expected 'out' to be CPU tensor, but got tensor on device: ",
|
||||
out.device());
|
||||
TORCH_CHECK(
|
||||
src.device() == kCPU,
|
||||
"add: expected 'other' to be a CPU tensor, but got tensor on device: ",
|
||||
src.device());
|
||||
|
||||
TORCH_CHECK(
|
||||
dense.sizes().equals(src.sizes()),
|
||||
"add: expected 'self' and 'other' to have same size, but self has size ",
|
||||
dense.sizes(),
|
||||
" while other has size ",
|
||||
src.sizes(),
|
||||
" (FYI: op2-sparse addition does not currently support broadcasting)");
|
||||
|
||||
auto commonDtype = promoteTypes(dense.scalar_type(), src.scalar_type());
|
||||
TORCH_CHECK(
|
||||
canCast(commonDtype, out.scalar_type()),
|
||||
"Can't convert result type ",
|
||||
commonDtype,
|
||||
" to output ",
|
||||
out.scalar_type(),
|
||||
" in add operation");
|
||||
|
||||
auto src_values = src.values().to(commonDtype);
|
||||
auto src_crow_indices = src.crow_indices();
|
||||
auto src_col_indices = src.col_indices();
|
||||
|
||||
out.resize_as_(dense);
|
||||
Tensor resultBuffer = out;
|
||||
Tensor valuesBuffer = src_values.to(commonDtype);
|
||||
|
||||
if (out.scalar_type() != commonDtype) {
|
||||
resultBuffer = dense.to(commonDtype);
|
||||
} else if (!is_same_tensor(out, dense)) {
|
||||
resultBuffer.copy_(dense);
|
||||
}
|
||||
|
||||
AT_DISPATCH_ALL_TYPES(
|
||||
commonDtype,
|
||||
"add_out_op2_sparse_csr",
|
||||
[&src_values, &out, &alpha, &src_crow_indices, &src_col_indices]() {
|
||||
AT_DISPATCH_INDEX_TYPES(
|
||||
src_crow_indices.scalar_type(),
|
||||
"csr_add_out_crow_indices",
|
||||
[&src_values, &out, &alpha, &src_crow_indices, &src_col_indices]() {
|
||||
auto values_accessor = src_values.accessor<scalar_t, 1>();
|
||||
scalar_t* out_ptr = out.data_ptr<scalar_t>();
|
||||
scalar_t cast_value = alpha.to<scalar_t>();
|
||||
|
||||
auto crow_indices_accessor =
|
||||
src_crow_indices.accessor<index_t, 1>();
|
||||
auto col_indices_accessor =
|
||||
src_col_indices.accessor<index_t, 1>();
|
||||
auto out_strides0 = out.strides()[0];
|
||||
auto out_strides1 = out.strides()[1];
|
||||
|
||||
for (int32_t irow = 0; irow < src_crow_indices.size(0) - 1;
|
||||
++irow) {
|
||||
int32_t start_index = crow_indices_accessor[irow];
|
||||
int32_t end_index = crow_indices_accessor[irow + 1];
|
||||
|
||||
for (int i = start_index; i < end_index; ++i) {
|
||||
auto icol = col_indices_accessor[i];
|
||||
auto index = out.storage_offset() + irow * out_strides0 +
|
||||
icol * out_strides1;
|
||||
out_ptr[index] += cast_value * values_accessor[i];
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
Tensor& add_out_sparse_csr_cpu(
|
||||
const Tensor& self,
|
||||
const SparseCsrTensor& other,
|
||||
const Scalar& alpha,
|
||||
SparseCsrTensor& out) {
|
||||
if (self.layout() == kStrided) {
|
||||
return add_out_dense_sparse_csr_cpu(out, self, other, alpha);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"NotImplementedError: Addition of sparse CSR tensors is not yet implemented.")
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
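In the non-MKL branch of addmm_out_sparse_csr_dense_cpu above, each stored value at (irow, icol) contributes alpha * value times row icol of the dense operand into row irow of the output, via cpublas::axpy. A de-vectorized sketch of that accumulation on plain arrays (sizes and values made up, axpy written as an explicit loop):

#include <cstdio>

int main() {
  // CSR operand A (2 x 3): [[1, 0, 2], [0, 3, 0]].
  const int    crow_indices[] = {0, 2, 3};
  const int    col_indices[]  = {0, 2, 1};
  const double values[]       = {1., 2., 3.};

  // Dense operand B (3 x 2), row-major, and output C (2 x 2) already scaled by beta.
  const double B[3][2] = {{1., 2.}, {3., 4.}, {5., 6.}};
  double C[2][2] = {{0., 0.}, {0., 0.}};
  const double alpha = 1.0;

  for (int irow = 0; irow < 2; ++irow) {
    for (int i = crow_indices[irow]; i < crow_indices[irow + 1]; ++i) {
      const double val = values[i];
      const int icol   = col_indices[i];
      // axpy: C[irow, :] += alpha * val * B[icol, :]
      for (int j = 0; j < 2; ++j) {
        C[irow][j] += alpha * val * B[icol][j];
      }
    }
  }
  std::printf("C = [[%g, %g], [%g, %g]]\n", C[0][0], C[0][1], C[1][0], C[1][1]);
  return 0;
}

The real kernel parallelizes over rows with at::parallel_for, which is safe because each output row is written by exactly one CSR row.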
@ -1,17 +1,18 @@
|
||||
// Basic functions on sparse tensors
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/InitialTensorOptions.h>
|
||||
#include <ATen/Layout.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/SparseTensorImpl.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/InitialTensorOptions.h>
|
||||
#include <ATen/SparseTensorUtils.h>
|
||||
#include <ATen/native/IndexingUtils.h>
|
||||
|
||||
#include <ATen/native/CPUBlas.h>
|
||||
|
||||
namespace at { namespace native {
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
using namespace at::sparse;
|
||||
|
||||
@ -36,7 +37,8 @@ int64_t _nnz_sparse(const SparseTensor& self) {
|
||||
}
|
||||
|
||||
// Why are there so many methods to get indices and value?
|
||||
// See Note [ Sparse: different methods to get indices and values ] in native_functions.yaml
|
||||
// See Note [ Sparse: different methods to get indices and values ] in
|
||||
// native_functions.yaml
|
||||
|
||||
Tensor _indices_sparse(const SparseTensor& self) {
|
||||
return get_sparse_impl(self)->indices();
|
||||
@ -46,20 +48,22 @@ Tensor _values_sparse(const SparseTensor& self) {
|
||||
return get_sparse_impl(self)->values();
|
||||
}
|
||||
|
||||
Tensor &_coalesced_sparse_(SparseTensor& self, bool coalesced) {
|
||||
Tensor& _coalesced_sparse_(SparseTensor& self, bool coalesced) {
|
||||
get_sparse_impl(self)->set_coalesced(coalesced);
|
||||
return self;
|
||||
}
|
||||
|
||||
Tensor indices_sparse(const Tensor& self) {
|
||||
TORCH_CHECK(self.is_coalesced(),
|
||||
"Cannot get indices on an uncoalesced tensor, please call .coalesce() first");
|
||||
TORCH_CHECK(
|
||||
self.is_coalesced(),
|
||||
"Cannot get indices on an uncoalesced tensor, please call .coalesce() first");
|
||||
return get_sparse_impl(self)->indices().alias();
|
||||
}
|
||||
|
||||
Tensor values_sparse(const Tensor& self) {
|
||||
TORCH_CHECK(self.is_coalesced(),
|
||||
"Cannot get values on an uncoalesced tensor, please call .coalesce() first");
|
||||
TORCH_CHECK(
|
||||
self.is_coalesced(),
|
||||
"Cannot get values on an uncoalesced tensor, please call .coalesce() first");
|
||||
return get_sparse_impl(self)->values().alias();
|
||||
}
|
||||
|
||||
@ -70,7 +74,11 @@ Tensor values_sparse(const Tensor& self) {
|
||||
|
||||
/*** Helper methods ***/
|
||||
|
||||
SparseTensor new_sparse(c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory) {
|
||||
SparseTensor new_sparse(
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
c10::optional<Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
AT_ASSERT(layout.has_value() && *layout == kSparse);
|
||||
DispatchKey dispatch_key;
|
||||
if (device_or_default(device).is_cuda()) {
|
||||
@ -81,13 +89,20 @@ SparseTensor new_sparse(c10::optional<ScalarType> dtype, c10::optional<Layout> l
|
||||
dispatch_key = DispatchKey::SparseCPU;
|
||||
}
|
||||
return detail::make_tensor<SparseTensorImpl>(
|
||||
DispatchKeySet(dispatch_key), scalarTypeToTypeMeta(dtype_or_default(dtype)));
|
||||
DispatchKeySet(dispatch_key),
|
||||
scalarTypeToTypeMeta(dtype_or_default(dtype)));
|
||||
}
|
||||
|
||||
/** Actual dispatched creation methods ***/
|
||||
|
||||
SparseTensor new_with_dims_sparse(int64_t sparse_dim, int64_t dense_dim, ArrayRef<int64_t> size, c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory) {
|
||||
SparseTensor new_with_dims_sparse(
|
||||
int64_t sparse_dim,
|
||||
int64_t dense_dim,
|
||||
ArrayRef<int64_t> size,
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
c10::optional<Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
SparseTensor self = new_sparse(dtype, layout, device, pin_memory);
|
||||
get_sparse_impl(self)->resize_and_clear_(sparse_dim, dense_dim, size);
|
||||
return self;
|
||||
@ -105,15 +120,18 @@ SparseTensor new_with_dims_and_tensor_sparse(
|
||||
c10::optional<bool> pin_memory) {
|
||||
SparseTensor self = new_sparse(dtype, layout, device, pin_memory);
|
||||
get_sparse_impl(self)->resize_(sparse_dim, dense_dim, size);
|
||||
// NOTE: There is no guarantee that `indices` and `values` don't contain AutogradMeta. However,
|
||||
// we want to maintain the invariant that `indices_` and `values_` of a sparse tensor don't
|
||||
// contain AutogradMeta, and to achieve that we shallow-copy `indices` and `values` here.
|
||||
auto indices_shallow_copy = Tensor(indices.unsafeGetTensorImpl()->shallow_copy_and_detach(
|
||||
/*version_counter=*/indices.unsafeGetTensorImpl()->version_counter(),
|
||||
/*allow_tensor_metadata_change=*/true));
|
||||
auto values_shallow_copy = Tensor(values.unsafeGetTensorImpl()->shallow_copy_and_detach(
|
||||
/*version_counter=*/values.unsafeGetTensorImpl()->version_counter(),
|
||||
/*allow_tensor_metadata_change=*/true));
|
||||
// NOTE: There is no guarantee that `indices` and `values` don't contain
|
||||
// AutogradMeta. However, we want to maintain the invariant that `indices_`
|
||||
// and `values_` of a sparse tensor don't contain AutogradMeta, and to achieve
|
||||
// that we shallow-copy `indices` and `values` here.
|
||||
auto indices_shallow_copy =
|
||||
Tensor(indices.unsafeGetTensorImpl()->shallow_copy_and_detach(
|
||||
/*version_counter=*/indices.unsafeGetTensorImpl()->version_counter(),
|
||||
/*allow_tensor_metadata_change=*/true));
|
||||
auto values_shallow_copy =
|
||||
Tensor(values.unsafeGetTensorImpl()->shallow_copy_and_detach(
|
||||
/*version_counter=*/values.unsafeGetTensorImpl()->version_counter(),
|
||||
/*allow_tensor_metadata_change=*/true));
|
||||
alias_into_sparse(self, indices_shallow_copy, values_shallow_copy);
|
||||
return self;
|
||||
}
|
||||
@ -121,9 +139,18 @@ SparseTensor new_with_dims_and_tensor_sparse(
|
||||
/** Public creation API that dispatch to methods above **/
|
||||
|
||||
/** Empty init **/
|
||||
Tensor empty_sparse(IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<MemoryFormat> optional_memory_format) {
|
||||
TORCH_CHECK(!pin_memory.has_value() || !*pin_memory, "Only dense CPU tensors can be pinned");
|
||||
return new_with_dims_sparse(size.size(), 0, size, dtype, layout, device, pin_memory);
|
||||
Tensor empty_sparse(
|
||||
IntArrayRef size,
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
c10::optional<Device> device,
|
||||
c10::optional<bool> pin_memory,
|
||||
c10::optional<MemoryFormat> optional_memory_format) {
|
||||
TORCH_CHECK(
|
||||
!pin_memory.has_value() || !*pin_memory,
|
||||
"Only dense CPU tensors can be pinned");
|
||||
return new_with_dims_sparse(
|
||||
size.size(), 0, size, dtype, layout, device, pin_memory);
|
||||
}
|
||||
|
||||
/* Shape init */
|
||||
@ -142,16 +169,16 @@ Tensor sparse_coo_tensor(IntArrayRef size,
|
||||
|
||||
// helper
|
||||
namespace {
|
||||
static inline Tensor expand_values_if_needed(const Tensor& values) {
|
||||
// expand
|
||||
if (values.dim() == 0) {
|
||||
// Mimic Numpy behavior here and treat it as a 1D tensor
|
||||
return values.expand({1});
|
||||
} else {
|
||||
return values;
|
||||
}
|
||||
static inline Tensor expand_values_if_needed(const Tensor& values) {
|
||||
// expand
|
||||
if (values.dim() == 0) {
|
||||
// Mimic Numpy behavior here and treat it as a 1D tensor
|
||||
return values.expand({1});
|
||||
} else {
|
||||
return values;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values_,
|
||||
c10::optional<ScalarType> dtype,
|
||||
@ -164,11 +191,21 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values_,
|
||||
Tensor values = expand_values_if_needed(values_);
|
||||
|
||||
// arg checking
|
||||
TORCH_CHECK(!options.has_layout() || options.layout() == kSparse, "expected sparse layout, but got layout ", options.layout());
|
||||
// the following checks are redundant because they are also checked in SparseTensorImpl::set_indices_and_values_unsafe
|
||||
// but we need to ensure them in order to infer the shape.
|
||||
TORCH_CHECK(indices.dim() == 2, "indices must be sparse_dim x nnz, but got: ", indices.sizes())
|
||||
TORCH_CHECK(!indices.is_sparse(), "expected indices to be a dense tensor, but got indices of layout ", indices.layout());
|
||||
TORCH_CHECK(
|
||||
!options.has_layout() || options.layout() == kSparse,
|
||||
"expected sparse layout, but got layout ",
|
||||
options.layout());
|
||||
// the following checks are redundant because they are also checked in
|
||||
// SparseTensorImpl::set_indices_and_values_unsafe but we need to ensure them
|
||||
// in order to infer the shape.
|
||||
TORCH_CHECK(
|
||||
indices.dim() == 2,
|
||||
"indices must be sparse_dim x nnz, but got: ",
|
||||
indices.sizes())
|
||||
TORCH_CHECK(
|
||||
!indices.is_sparse(),
|
||||
"expected indices to be a dense tensor, but got indices of layout ",
|
||||
indices.layout());
|
||||
|
||||
// If sizes are not given, it is inferred as max index of each dim.
|
||||
int64_t sparse_dim = indices.size(0);
|
||||
@ -176,54 +213,86 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values_,
|
||||
|
||||
std::vector<int64_t> computed_sizes(sparse_dim + dense_dim);
|
||||
if (indices.numel() > 0) {
|
||||
// If the indices has elements in it, we infer the minimum sparse dimension sizes
|
||||
// as the max value of each dim in indices.
|
||||
// NB: It used to keepdim. I think that was wrong.
|
||||
Tensor min_indices = std::get</* values */ 0>(indices.min(/* dim */ 1, /* keepdim */ false));
|
||||
Tensor computed_indices_sizes = std::get</* values */ 0>(indices.max(/* dim */ 1, /* keepdim */ false));
|
||||
// If the indices has elements in it, we infer the minimum sparse dimension
|
||||
// sizes as the max value of each dim in indices. NB: It used to keepdim. I
|
||||
// think that was wrong.
|
||||
Tensor min_indices =
|
||||
std::get</* values */ 0>(indices.min(/* dim */ 1, /* keepdim */ false));
|
||||
Tensor computed_indices_sizes =
|
||||
std::get</* values */ 0>(indices.max(/* dim */ 1, /* keepdim */ false));
|
||||
computed_indices_sizes.add_(1); // len = max_index + 1
|
||||
Tensor cpu_min_indices = min_indices.to(at::DeviceType::CPU);
|
||||
Tensor cpu_computed_indices_sizes = computed_indices_sizes.to(at::DeviceType::CPU);
|
||||
Tensor cpu_computed_indices_sizes =
|
||||
computed_indices_sizes.to(at::DeviceType::CPU);
|
||||
auto cpu_min_indices_accessor = cpu_min_indices.accessor<int64_t, 1>();
|
||||
auto cpu_computed_indices_sizes_accessor = cpu_computed_indices_sizes.accessor<int64_t, 1>();
|
||||
auto cpu_computed_indices_sizes_accessor =
|
||||
cpu_computed_indices_sizes.accessor<int64_t, 1>();
|
||||
for (int64_t d = 0; d < sparse_dim; d++) {
|
||||
int64_t min_index_in_dim = cpu_min_indices_accessor[d];
|
||||
TORCH_CHECK(min_index_in_dim >= 0,
|
||||
"found negative index ", min_index_in_dim, " for dim ", d);
|
||||
computed_sizes[static_cast<size_t>(d)] = cpu_computed_indices_sizes_accessor[d];
|
||||
TORCH_CHECK(
|
||||
min_index_in_dim >= 0,
|
||||
"found negative index ",
|
||||
min_index_in_dim,
|
||||
" for dim ",
|
||||
d);
|
||||
computed_sizes[static_cast<size_t>(d)] =
|
||||
cpu_computed_indices_sizes_accessor[d];
|
||||
}
|
||||
} else {
|
||||
// If the indices doesn't have elements in it, there is not enough information
|
||||
// to know what the minimum sparse dimension sizes should be, and in this case
|
||||
// we set them to 0
|
||||
// If the indices doesn't have elements in it, there is not enough
|
||||
// information to know what the minimum sparse dimension sizes should be,
|
||||
// and in this case we set them to 0
|
||||
for (int64_t d = 0; d < sparse_dim; d++) {
|
||||
computed_sizes[static_cast<size_t>(d)] = 0;
|
||||
}
|
||||
}
|
||||
for (int64_t d = 0; d < dense_dim; d++) {
|
||||
computed_sizes[static_cast<size_t>(sparse_dim + d)] = values.size(d+1);
|
||||
computed_sizes[static_cast<size_t>(sparse_dim + d)] = values.size(d + 1);
|
||||
}
|
||||
|
||||
return at::_sparse_coo_tensor_with_dims_and_tensors(
|
||||
sparse_dim, dense_dim, computed_sizes, indices, values, values.options().layout(kSparse));
|
||||
sparse_dim,
|
||||
dense_dim,
|
||||
computed_sizes,
|
||||
indices,
|
||||
values,
|
||||
values.options().layout(kSparse));
|
||||
}
|
||||
|
||||
void _validate_sparse_coo_tensor_args(const Tensor& indices, const Tensor& values_, ArrayRef<int64_t> size) {
|
||||
void _validate_sparse_coo_tensor_args(
|
||||
const Tensor& indices,
|
||||
const Tensor& values_,
|
||||
ArrayRef<int64_t> size) {
|
||||
Tensor values = expand_values_if_needed(values_);
|
||||
|
||||
// the following checks are redundant because they are also checked in SparseTensorImpl::set_indices_and_values_unsafe
|
||||
// but we need to ensure them in order to infer the shape.
|
||||
TORCH_CHECK(indices.dim() == 2, "indices must be sparse_dim x nnz, but got: ", indices.sizes())
|
||||
TORCH_CHECK(!indices.is_sparse(), "expected indices to be a dense tensor, but got indices of layout ", indices.layout());
|
||||
// the following checks are redundant because they are also checked in
|
||||
// SparseTensorImpl::set_indices_and_values_unsafe but we need to ensure them
|
||||
// in order to infer the shape.
|
||||
TORCH_CHECK(
|
||||
indices.dim() == 2,
|
||||
"indices must be sparse_dim x nnz, but got: ",
|
||||
indices.sizes())
|
||||
TORCH_CHECK(
|
||||
!indices.is_sparse(),
|
||||
"expected indices to be a dense tensor, but got indices of layout ",
|
||||
indices.layout());
|
||||
int64_t sparse_dim = indices.size(0);
|
||||
int64_t dense_dim = values.dim() - 1;
|
||||
TORCH_CHECK(size.size() == sparse_dim + dense_dim,
|
||||
"number of dimensions must be sparse_dim (", sparse_dim, ") + dense_dim (", dense_dim, "), but got ", size.size());
|
||||
TORCH_CHECK(
|
||||
size.size() == sparse_dim + dense_dim,
|
||||
"number of dimensions must be sparse_dim (",
|
||||
sparse_dim,
|
||||
") + dense_dim (",
|
||||
dense_dim,
|
||||
"), but got ",
|
||||
size.size());
|
||||
|
||||
// Check to make sure all indices are within the boundaries of `size`
|
||||
if (indices.numel() > 0) {
|
||||
Tensor min_indices = std::get</* values */ 0>(indices.min(/* dim */ 1, /* keepdim */ false));
|
||||
Tensor max_indices = std::get</* values */ 0>(indices.max(/* dim */ 1, /* keepdim */ false));
|
||||
Tensor min_indices =
|
||||
std::get</* values */ 0>(indices.min(/* dim */ 1, /* keepdim */ false));
|
||||
Tensor max_indices =
|
||||
std::get</* values */ 0>(indices.max(/* dim */ 1, /* keepdim */ false));
|
||||
Tensor cpu_min_indices, cpu_max_indices;
|
||||
if (indices.is_cuda()) {
|
||||
cpu_min_indices = min_indices.to(at::DeviceType::CPU);
|
||||
@ -238,17 +307,28 @@ void _validate_sparse_coo_tensor_args(const Tensor& indices, const Tensor& value
|
||||
// NB: This used to sync ndim times to access each entry; now we copy
|
||||
// everything to CPU first and then access it.
|
||||
int64_t min_index_in_dim = cpu_min_indices_accessor[d];
|
||||
TORCH_CHECK(min_index_in_dim >= 0,
|
||||
"found negative index ", min_index_in_dim, " for dim ", d);
|
||||
TORCH_CHECK(
|
||||
min_index_in_dim >= 0,
|
||||
"found negative index ",
|
||||
min_index_in_dim,
|
||||
" for dim ",
|
||||
d);
|
||||
int64_t max_index_in_dim = cpu_max_indices_accessor[d];
|
||||
int64_t dim_size = size[static_cast<size_t>(d)];
|
||||
TORCH_CHECK(max_index_in_dim < dim_size,
|
||||
"size is inconsistent with indices: for dim ", d, ", size is ", dim_size, " but found index ", max_index_in_dim);
|
||||
TORCH_CHECK(
|
||||
max_index_in_dim < dim_size,
|
||||
"size is inconsistent with indices: for dim ",
|
||||
d,
|
||||
", size is ",
|
||||
dim_size,
|
||||
" but found index ",
|
||||
max_index_in_dim);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NB: Got rid of the sizes == NULL case
|
||||
|
||||
Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values, IntArrayRef size,
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
@ -256,9 +336,11 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values, IntArrayRe
|
||||
c10::optional<bool> pin_memory) {
|
||||
// See [Note: hacky wrapper removal for TensorOptions]
|
||||
TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
|
||||
|
||||
// arg checking
|
||||
TORCH_CHECK(!options.has_layout() || options.layout() == kSparse, "expected sparse layout, but got layout ", options.layout());
|
||||
TORCH_CHECK(
|
||||
!options.has_layout() || options.layout() == kSparse,
|
||||
"expected sparse layout, but got layout ",
|
||||
options.layout());
|
||||
|
||||
at::native::_validate_sparse_coo_tensor_args(indices, values, size);
|
||||
return at::native::_sparse_coo_tensor_unsafe(
|
||||
@ -291,19 +373,31 @@ Tensor _sparse_coo_tensor_unsafe(const Tensor& indices, const Tensor& values_, I
|
||||
int64_t dense_dim = values.dim() - 1;
|
||||
|
||||
return at::_sparse_coo_tensor_with_dims_and_tensors(
|
||||
sparse_dim, dense_dim, size, indices, values, values.options().layout(kSparse));
|
||||
sparse_dim,
|
||||
dense_dim,
|
||||
size,
|
||||
indices,
|
||||
values,
|
||||
values.options().layout(kSparse));
|
||||
}
|
||||
|
||||
// NB: Deleted newWithSizeNd variants
|
||||
|
||||
SparseTensor clone_sparse(const SparseTensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
SparseTensor clone_sparse(
|
||||
const SparseTensor& self,
|
||||
c10::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
TORCH_CHECK(
|
||||
!optional_memory_format.has_value(),
|
||||
"unsupported memory format option ",
|
||||
optional_memory_format.value());
|
||||
SparseTensor other = new_with_dims_sparse(self.sparse_dim(), self.dense_dim(), self.sizes(),
|
||||
optTypeMetaToScalarType(self.options().dtype_opt()), self.options().layout_opt(),
|
||||
self.options().device_opt(), self.options().pinned_memory_opt());
|
||||
SparseTensor other = new_with_dims_sparse(
|
||||
self.sparse_dim(),
|
||||
self.dense_dim(),
|
||||
self.sizes(),
|
||||
optTypeMetaToScalarType(self.options().dtype_opt()),
|
||||
self.options().layout_opt(),
|
||||
self.options().device_opt(),
|
||||
self.options().pinned_memory_opt());
|
||||
copy_into_sparse(other, self._indices(), self._values(), true);
|
||||
return other._coalesced_(self.is_coalesced());
|
||||
}
|
||||
@ -312,21 +406,32 @@ SparseTensor clone_sparse(const SparseTensor& self, c10::optional<c10::MemoryFor
|
||||
* reshaping methods
|
||||
******************************************************************************/
|
||||
|
||||
SparseTensor& sparse_resize_(SparseTensor& self, ArrayRef<int64_t> size, int64_t sparse_dim, int64_t dense_dim) {
|
||||
SparseTensor& sparse_resize_(
|
||||
SparseTensor& self,
|
||||
ArrayRef<int64_t> size,
|
||||
int64_t sparse_dim,
|
||||
int64_t dense_dim) {
|
||||
get_sparse_impl(self)->resize_(sparse_dim, dense_dim, size);
|
||||
return self;
|
||||
}
|
||||
|
||||
SparseTensor& sparse_resize_and_clear_(SparseTensor& self, ArrayRef<int64_t> size, int64_t sparse_dim, int64_t dense_dim) {
|
||||
SparseTensor& sparse_resize_and_clear_(
|
||||
SparseTensor& self,
|
||||
ArrayRef<int64_t> size,
|
||||
int64_t sparse_dim,
|
||||
int64_t dense_dim) {
|
||||
get_sparse_impl(self)->resize_and_clear_(sparse_dim, dense_dim, size);
|
||||
return self;
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool _is_same_size_as_sparse(const SparseTensor& self, const SparseTensor& src) {
|
||||
return self.sparse_dim() == src.sparse_dim() && self.dense_dim() == src.dense_dim() && self.sizes().equals(src.sizes());
|
||||
}
|
||||
bool _is_same_size_as_sparse(
|
||||
const SparseTensor& self,
|
||||
const SparseTensor& src) {
|
||||
return self.sparse_dim() == src.sparse_dim() &&
|
||||
self.dense_dim() == src.dense_dim() && self.sizes().equals(src.sizes());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Invoked from native/Resize.cpp (no dynamic dispatch necessary)
|
||||
SparseTensor& resize_as_sparse_(SparseTensor& self, const SparseTensor& src) {
|
||||
@ -336,23 +441,33 @@ SparseTensor& resize_as_sparse_(SparseTensor& self, const SparseTensor& src) {
|
||||
return self;
|
||||
}
|
||||
|
||||
SparseTensor dense_to_sparse(const Tensor& self){
|
||||
SparseTensor dense_to_sparse(const Tensor& self) {
|
||||
return dense_to_sparse(self, self.dim());
|
||||
}
|
||||
|
||||
SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
|
||||
SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim) {
|
||||
int64_t dims = self.dim();
|
||||
// TODO: it seems like sparse_dim == 0 could be supported even if self.dim() > 0,
|
||||
// but this would take some work and doesn't seem particularly useful.
|
||||
TORCH_CHECK(sparse_dim > 0 || self.dim() == 0, "sparse_dim must be >0 if dimensionality > 0");
|
||||
TORCH_CHECK(sparse_dim <= dims,
|
||||
"sparse_dim must be less than or equal to self.dim()");
|
||||
// TODO: it seems like sparse_dim == 0 could be supported even if self.dim() >
|
||||
// 0, but this would take some work and doesn't seem particularly useful.
|
||||
TORCH_CHECK(
|
||||
sparse_dim > 0 || self.dim() == 0,
|
||||
"sparse_dim must be >0 if dimensionality > 0");
|
||||
TORCH_CHECK(
|
||||
sparse_dim <= dims,
|
||||
"sparse_dim must be less than or equal to self.dim()");
|
||||
at::TensorOptions sparse_options = self.options().layout(kSparse);
|
||||
std::vector<int64_t> sizes = self.sizes().vec();
|
||||
|
||||
Tensor nz = self.nonzero().transpose(0, 1);
|
||||
if (nz.size(1) == 0) {
|
||||
return new_with_dims_sparse(sparse_dim, dims - sparse_dim, sizes, optTypeMetaToScalarType(sparse_options.dtype_opt()), sparse_options.layout_opt(), sparse_options.device_opt(), sparse_options.pinned_memory_opt());
|
||||
return new_with_dims_sparse(
|
||||
sparse_dim,
|
||||
dims - sparse_dim,
|
||||
sizes,
|
||||
optTypeMetaToScalarType(sparse_options.dtype_opt()),
|
||||
sparse_options.layout_opt(),
|
||||
sparse_options.device_opt(),
|
||||
sparse_options.pinned_memory_opt());
|
||||
}
|
||||
Tensor indices;
|
||||
if (sparse_dim == dims) {
|
||||
@ -360,7 +475,8 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
|
||||
} else {
|
||||
Tensor i = nz.narrow(0, 0, sparse_dim);
|
||||
std::tie(indices, std::ignore, std::ignore) = unique_dim(i, 1);
|
||||
indices = indices.contiguous(); // many sparse CUDA kernels require contiguity, see issue #12633
|
||||
indices = indices.contiguous(); // many sparse CUDA kernels require
|
||||
// contiguity, see issue #12633
|
||||
}
|
||||
|
||||
Tensor values;
|
||||
@ -369,8 +485,8 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
|
||||
values = self.index(ix).squeeze(0).clone(at::MemoryFormat::Preserve);
|
||||
} else {
|
||||
AT_ASSERT(nz.sizes().equals({0, 1}));
|
||||
// In this cases, indices is a clone of nz, which is a tensor of shape (0, 1).
|
||||
// Given sparse tensor invariants, values should be shape (1,)
|
||||
// In this case, indices is a clone of nz, which is a tensor of shape (0,
// 1). Given sparse tensor invariants, values should be shape (1,)
|
||||
values = self.unsqueeze(0).clone(at::MemoryFormat::Preserve);
|
||||
}
|
||||
|
||||
@ -380,18 +496,27 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
|
||||
|
||||
// NB: Dropped the resizeNd variants
|
||||
|
||||
Tensor sparse_to_dense(const SparseTensor& self, c10::optional<ScalarType> dtype) {
|
||||
TORCH_CHECK(!dtype.has_value(), "dtype argument is not supported by sparse_to_dense");
|
||||
if(self.scalar_type() == ScalarType::Half && self.options().device().is_cpu()) {
|
||||
Tensor sparse_to_dense(
|
||||
const SparseTensor& self,
|
||||
c10::optional<ScalarType> dtype) {
|
||||
TORCH_CHECK(
|
||||
!dtype.has_value(), "dtype argument is not supported by sparse_to_dense");
|
||||
if (self.scalar_type() == ScalarType::Half &&
|
||||
self.options().device().is_cpu()) {
|
||||
TORCH_CHECK(false, "to_dense() not supported for float16 on CPU");
|
||||
}
|
||||
Tensor dst = at::zeros(self.sizes(), self.options().layout(kStrided));
|
||||
return dst.add_(self);
|
||||
}
|
||||
|
||||
SparseTensor& copy_sparse_(SparseTensor& self, const SparseTensor& src, bool non_blocking) {
|
||||
if (is_same_tensor(self, src)) return self;
|
||||
get_sparse_impl(self)->resize_(src.sparse_dim(), src.dense_dim(), src.sizes());
|
||||
SparseTensor& copy_sparse_(
|
||||
SparseTensor& self,
|
||||
const SparseTensor& src,
|
||||
bool non_blocking) {
|
||||
if (is_same_tensor(self, src))
|
||||
return self;
|
||||
get_sparse_impl(self)->resize_(
|
||||
src.sparse_dim(), src.dense_dim(), src.sizes());
|
||||
copy_into_sparse(self, src._indices(), src._values(), non_blocking);
|
||||
return self._coalesced_(src.is_coalesced());
|
||||
}
|
||||
@ -426,7 +551,11 @@ SparseTensor _coalesce_sparse_cpu(const SparseTensor& self) {
|
||||
|
||||
Tensor indices_scalar = flatten_indices(indices, self.sizes());
|
||||
|
||||
SparseTensor dst = new_sparse(optTypeMetaToScalarType(self.options().dtype_opt()), self.options().layout_opt(), self.options().device_opt(), self.options().pinned_memory_opt());
|
||||
SparseTensor dst = new_sparse(
|
||||
optTypeMetaToScalarType(self.options().dtype_opt()),
|
||||
self.options().layout_opt(),
|
||||
self.options().device_opt(),
|
||||
self.options().pinned_memory_opt());
|
||||
get_sparse_impl(dst)->resize_(sparse_dim, dense_dim, self.sizes());
|
||||
// TODO: is there a more idiomatic way to do this?
|
||||
Tensor newIndices = at::empty(indices.sizes(), indices.options());
|
||||
@ -436,38 +565,51 @@ SparseTensor _coalesce_sparse_cpu(const SparseTensor& self) {
|
||||
Tensor indicesBuffer;
|
||||
Tensor indicesPermutation;
|
||||
std::tie(indicesBuffer, indicesPermutation) = indices_scalar.sort(0);
|
||||
// NB: The accessor accesses here rely on self._nnz() > 0 (tested earlier in this function)
|
||||
// NB: The accessor accesses here rely on self._nnz() > 0 (tested earlier in
|
||||
// this function)
|
||||
auto newIndicesAccessor = newIndices.accessor<int64_t, 2>();
|
||||
auto indicesAccessor = indices.accessor<int64_t, 2>();
|
||||
auto indicesPermutationAccessor = indicesPermutation.accessor<int64_t, 1>();
|
||||
auto indicesBufferAccessor = indicesBuffer.accessor<int64_t, 1>();
|
||||
|
||||
int64_t i = -1;
|
||||
AT_DISPATCH_ALL_TYPES(
|
||||
values.scalar_type(), "coalesce", [&] {
|
||||
int64_t prev = -1;
|
||||
int64_t blockSize = values.stride(0);
|
||||
scalar_t* values_ptr = values.data_ptr<scalar_t>();
|
||||
scalar_t* newValues_ptr = newValues.data_ptr<scalar_t>();
|
||||
for (int64_t j = 0; j < nnz; j++) {
|
||||
int64_t pos = indicesPermutationAccessor[j];
|
||||
int64_t curr = indicesBufferAccessor[j];
|
||||
if (curr == prev) {
|
||||
if (values.numel() > 0) { // if values is an empty tensor, there are no elements to copy
|
||||
at::native::cpublas::axpy<scalar_t>(blockSize, 1, values_ptr + pos * blockSize, 1, newValues_ptr + i * blockSize, 1);
|
||||
}
|
||||
} else {
|
||||
++i;
|
||||
for (int64_t d = 0; d < sparse_dim; d++) {
|
||||
newIndicesAccessor[d][i] = indicesAccessor[d][pos];
|
||||
}
|
||||
if (values.numel() > 0) { // if values is an empty tensor, there are no elements to copy
|
||||
at::native::cpublas::copy<scalar_t>(blockSize, values_ptr + pos * blockSize, 1, newValues_ptr + i * blockSize, 1);
|
||||
}
|
||||
}
|
||||
prev = curr;
|
||||
AT_DISPATCH_ALL_TYPES(values.scalar_type(), "coalesce", [&] {
|
||||
int64_t prev = -1;
|
||||
int64_t blockSize = values.stride(0);
|
||||
scalar_t* values_ptr = values.data_ptr<scalar_t>();
|
||||
scalar_t* newValues_ptr = newValues.data_ptr<scalar_t>();
|
||||
for (int64_t j = 0; j < nnz; j++) {
|
||||
int64_t pos = indicesPermutationAccessor[j];
|
||||
int64_t curr = indicesBufferAccessor[j];
|
||||
if (curr == prev) {
|
||||
if (values.numel() >
|
||||
0) { // if values is an empty tensor, there are no elements to copy
|
||||
at::native::cpublas::axpy<scalar_t>(
|
||||
blockSize,
|
||||
1,
|
||||
values_ptr + pos * blockSize,
|
||||
1,
|
||||
newValues_ptr + i * blockSize,
|
||||
1);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
++i;
|
||||
for (int64_t d = 0; d < sparse_dim; d++) {
|
||||
newIndicesAccessor[d][i] = indicesAccessor[d][pos];
|
||||
}
|
||||
if (values.numel() >
|
||||
0) { // if values is an empty tensor, there are no elements to copy
|
||||
at::native::cpublas::copy<scalar_t>(
|
||||
blockSize,
|
||||
values_ptr + pos * blockSize,
|
||||
1,
|
||||
newValues_ptr + i * blockSize,
|
||||
1);
|
||||
}
|
||||
}
|
||||
prev = curr;
|
||||
}
|
||||
});
|
||||
|
||||
dst._coalesced_(true);
|
||||
get_sparse_impl(dst)->set_nnz_and_narrow(i + 1);
|
||||
@ -475,7 +617,6 @@ SparseTensor _coalesce_sparse_cpu(const SparseTensor& self) {
|
||||
return dst;
|
||||
}
|
||||
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// sparse_mask(D, S) -> S
|
||||
//
|
||||
@ -485,12 +626,11 @@ SparseTensor _coalesce_sparse_cpu(const SparseTensor& self) {
|
||||
|
||||
template <typename scalar_t>
|
||||
void inline sparse_mask_out_cpu_kernel(
|
||||
Tensor& r_values,
|
||||
const Tensor& t,
|
||||
const int64_t r_nnz,
|
||||
const int64_t sparse_dim,
|
||||
const Tensor& mask_indices
|
||||
) {
|
||||
Tensor& r_values,
|
||||
const Tensor& t,
|
||||
const int64_t r_nnz,
|
||||
const int64_t sparse_dim,
|
||||
const Tensor& mask_indices) {
|
||||
auto r_values_accessor = r_values.accessor<scalar_t, 1>();
|
||||
auto mask_indices_accessor = mask_indices.accessor<int64_t, 2>();
|
||||
scalar_t* t_ptr = t.data_ptr<scalar_t>();
|
||||
@ -506,14 +646,23 @@ void inline sparse_mask_out_cpu_kernel(
|
||||
});
|
||||
}
|
||||
|
||||
SparseTensor& sparse_mask_out_cpu(SparseTensor& r, const Tensor& t, const SparseTensor& mask) {
|
||||
SparseTensor& sparse_mask_out_cpu(
|
||||
SparseTensor& r,
|
||||
const Tensor& t,
|
||||
const SparseTensor& mask) {
|
||||
TORCH_CHECK(mask.is_coalesced(), "sparse_mask: mask is uncoalesced");
|
||||
TORCH_CHECK(mask.sizes().equals(t.sizes()), "sparse_mask: operands have incompatible sizes; self has size ",
|
||||
t.sizes(), " but mask has size ", mask.sizes());
|
||||
TORCH_CHECK(
|
||||
mask.sizes().equals(t.sizes()),
|
||||
"sparse_mask: operands have incompatible sizes; self has size ",
|
||||
t.sizes(),
|
||||
" but mask has size ",
|
||||
mask.sizes());
|
||||
AT_ASSERT(!t.is_cuda()); // we were supposed to have dispatched on this
|
||||
TORCH_CHECK(!r.is_cuda(), "sparse_mask: expected 'out' to be CPU, but got CUDA");
|
||||
TORCH_CHECK(!mask.is_cuda(), "sparse_mask: expected 'mask' to be CPU, but got CUDA");
|
||||
resize_as_sparse_(r, mask);
|
||||
TORCH_CHECK(
|
||||
!r.is_cuda(), "sparse_mask: expected 'out' to be CPU, but got CUDA");
|
||||
TORCH_CHECK(
|
||||
!mask.is_cuda(), "sparse_mask: expected 'mask' to be CPU, but got CUDA");
|
||||
at::resize_as_sparse_(r, mask);
|
||||
if (mask._nnz() == 0) {
|
||||
return r.zero_();
|
||||
}
|
||||
@ -527,14 +676,15 @@ SparseTensor& sparse_mask_out_cpu(SparseTensor& r, const Tensor& t, const Sparse
|
||||
int64_t r_nnz = mask._nnz();
|
||||
get_sparse_impl(r)->set_nnz_and_narrow(r_nnz);
|
||||
|
||||
if (t.numel() == 0) { // if t is an empty tensor, there is no need to mask its elements
|
||||
if (t.numel() ==
|
||||
0) { // if t is an empty tensor, there is no need to mask its elements
|
||||
return r;
|
||||
}
|
||||
|
||||
if (dim > sparse_dim) {
|
||||
|
||||
// Get a flattened sparse indices, similar to NOTE [ Flatten Sparse Indices ].
|
||||
// Keeping this implementation because it is faster than flatten_indices()
|
||||
// Get a flattened sparse indices, similar to NOTE [ Flatten Sparse Indices
|
||||
// ]. Keeping this implementation because it is faster than
|
||||
// flatten_indices()
|
||||
Tensor indices = at::zeros({mask._nnz()}, mask_indices.options());
|
||||
for (int64_t d = 0; d < mask.sparse_dim(); d++) {
|
||||
indices.mul_(mask.size(d));
|
||||
@ -553,11 +703,7 @@ SparseTensor& sparse_mask_out_cpu(SparseTensor& r, const Tensor& t, const Sparse
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES(r_values.scalar_type(), "sparse_mask", [&] {
|
||||
sparse_mask_out_cpu_kernel<scalar_t>(
|
||||
r_values,
|
||||
t,
|
||||
r_nnz,
|
||||
sparse_dim,
|
||||
mask_indices);
|
||||
r_values, t, r_nnz, sparse_dim, mask_indices);
|
||||
});
|
||||
}
|
||||
return r;
|
||||
@ -570,26 +716,30 @@ SparseTensor sparse_mask_cpu(const Tensor& t, const SparseTensor& mask) {
|
||||
}
|
||||
|
||||
Tensor sparse_mask_helper_cpu(
|
||||
const SparseTensor& t,
|
||||
const Tensor& mask_indices
|
||||
) {
|
||||
const SparseTensor& t,
|
||||
const Tensor& mask_indices) {
|
||||
/*
|
||||
This is a helper function which filter values from `t._values()` using the `mask_indices`.
|
||||
This CPU implementation uses a simple hash_map to filter values by matching the `mask_indices`
|
||||
with the indices at tensor input `t`.
|
||||
This is a helper function which filters values from `t._values()` using the
`mask_indices`. This CPU implementation uses a simple hash_map to filter
values by matching the `mask_indices` with the indices of the input tensor `t`.
|
||||
|
||||
Inputs:
|
||||
`t` - coalesced sparse tensor input
|
||||
`mask_indices` - mask indices tensor
|
||||
|
||||
Note: The nnz in the output tensor will be same as the `mask_indices`. So it will
|
||||
works independently if the mask is coalesced or not.
|
||||
Note: The nnz in the output tensor will be the same as in `mask_indices`, so it
works independently of whether the mask is coalesced or not.
|
||||
*/
|
||||
TORCH_CHECK(t.is_sparse(), "t: input is not a sparse tensor");
|
||||
TORCH_CHECK(t.is_coalesced(), "t: input is uncoalesced");
|
||||
TORCH_CHECK(mask_indices.dim() == t._indices().dim(), "mask_indices: operands have incompatible indices dim; self has dim ",
|
||||
t._indices().dim(), " but mask has dim ", mask_indices.dim());
|
||||
TORCH_CHECK(mask_indices.is_contiguous(), "mask_indices: mask is not contiguous");
|
||||
TORCH_CHECK(
|
||||
mask_indices.dim() == t._indices().dim(),
|
||||
"mask_indices: operands have incompatible indices dim; self has dim ",
|
||||
t._indices().dim(),
|
||||
" but mask has dim ",
|
||||
mask_indices.dim());
|
||||
TORCH_CHECK(
|
||||
mask_indices.is_contiguous(), "mask_indices: mask is not contiguous");
|
||||
|
||||
int64_t r_nnz = mask_indices.size(1);
|
||||
auto t_v = t._values();
|
||||
@ -600,31 +750,35 @@ Tensor sparse_mask_helper_cpu(
|
||||
auto t_i = t._indices();
|
||||
auto t_nnz = t._nnz();
|
||||
|
||||
std::unordered_map<int64_t, int64_t> t_flatten_indices = std::unordered_map<int64_t, int64_t>{};
|
||||
std::unordered_map<int64_t, int64_t> t_flatten_indices =
|
||||
std::unordered_map<int64_t, int64_t>{};
|
||||
auto full_size = t.sizes();
|
||||
auto ti_flattened_indices = at::sparse::flatten_indices(t_i, full_size);
|
||||
|
||||
// Step 1: flatten the sparse indices `t._indices()` tensor and then map this flatten value `index` to the original position `i`
|
||||
// Step 1: flatten the sparse indices `t._indices()` tensor and then map this
|
||||
// flatten value `index` to the original position `i`
|
||||
auto t_indices_accessor = t_i.accessor<int64_t, 2>();
|
||||
for(int64_t i = 0; i < t_nnz; i++) {
|
||||
for (int64_t i = 0; i < t_nnz; i++) {
|
||||
int64_t index = ti_flattened_indices.data_ptr<int64_t>()[i];
|
||||
t_flatten_indices[index] = i;
|
||||
}
|
||||
|
||||
// Step 2: Filter `t._values()` values by matching the flatten `mask_indices` with the flatten `t._indices()` using the
|
||||
// hash_map `t_flatten_indices`
|
||||
// Step 2: Filter `t._values()` values by matching the flatten `mask_indices`
|
||||
// with the flatten `t._indices()` using the hash_map `t_flatten_indices`
|
||||
|
||||
auto flattened_mask_indices = at::sparse::flatten_indices(mask_indices, full_size);
|
||||
auto flattened_mask_indices =
|
||||
at::sparse::flatten_indices(mask_indices, full_size);
|
||||
at::parallel_for(0, r_nnz, 0, [&](int64_t start, int64_t end) {
|
||||
for (auto i = start; i < end; i++) {
|
||||
int64_t index = flattened_mask_indices.data_ptr<int64_t>()[i];
|
||||
auto iter = t_flatten_indices.find(index);
|
||||
if (iter != t_flatten_indices.end()) {
|
||||
r_values[i] = t_v[ iter->second ];
|
||||
r_values[i] = t_v[iter->second];
|
||||
}
|
||||
}
|
||||
});
|
||||
return r_values;
|
||||
}
|
||||
|
||||
}} // namespace at::native
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
@ -819,6 +819,7 @@ void s_addmm_out_sparse_dense_worker(int64_t nnz, int64_t dim_i, int64_t dim_j,
|
||||
// r_ = alpha * sparse * dense
|
||||
scalar_t cast_alpha = alpha.to<scalar_t>();
|
||||
scalar_t cast_beta = beta.to<scalar_t>();
|
||||
|
||||
if (cast_beta == 0) {
|
||||
r.zero_();
|
||||
} else if (cast_beta == 1) {
|
||||
@ -981,8 +982,7 @@ Tensor _sparse_mm(
|
||||
// we can redispatch to addmm_out; this is NOT an implementation of
|
||||
// the sparse masking version of mm
|
||||
SparseTensor& _sparse_mm_out(const SparseTensor& sparse,
|
||||
const Tensor& dense
|
||||
,
|
||||
const Tensor& dense,
|
||||
SparseTensor& result) {
|
||||
Tensor t = at::zeros({}, dense.options());
|
||||
return at::addmm_out(result, t, sparse, dense, 0, 1); // redispatch!
|
||||
@ -1601,6 +1601,7 @@ Tensor& bmm_out_sparse_cpu(const SparseTensor& self, const Tensor& mat2, Tensor&
|
||||
Tensor sparse_values = values.slice(0, mat_el_begin_idx, mat_el_end_idx);
|
||||
int64_t sparse_nnz = mat_el_end_idx - mat_el_begin_idx;
|
||||
|
||||
|
||||
s_addmm_out_sparse_dense_worker<scalar_t>(
|
||||
sparse_nnz,
|
||||
dim_i, dim_j, dim_k,
|
||||
|
@ -416,6 +416,12 @@ class TORCH_API Tensor {
|
||||
return impl_->is_sparse();
|
||||
}
|
||||
|
||||
/// Returns if a `Tensor` has a sparse CSR backend.
|
||||
bool is_sparse_csr() const {
|
||||
// NB: this is not a native function to avoid dispatching overhead.
|
||||
return impl_->is_sparse_csr();
|
||||
}
|
||||
|
||||
/// Returns if a `Tensor` is mkldnn tensor.
|
||||
bool is_mkldnn() const {
|
||||
// NB: this is not a native function to avoid dispatching overhead.
|
||||
|
@ -1,7 +1,5 @@
|
||||
# PyTorch Benchmarks
|
||||
|
||||
NOTE: This folder is currently work in progress.
|
||||
|
||||
This folder contains scripts that produce reproducible timings of various PyTorch features.
|
||||
|
||||
It also provides mechanisms to compare PyTorch with other frameworks.
|
||||
|
5
benchmarks/sparse/README.md
Normal file
@ -0,0 +1,5 @@
|
||||
# Sparse benchmarks
|
||||
|
||||
These benchmarks exercise the sparse matrix functionality. They exist to compare
the performance of sparse matrix routines such as SpMV across different sparse
matrix formats and against other frameworks such as TensorFlow.
|
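A quick-start sketch (hypothetical usage, assuming it is run from the ``benchmarks/``
directory so that this folder is importable as the ``sparse`` package):

import torch
from sparse.utils import gen_sparse_csr, gen_sparse_coo, Event  # helpers defined in this folder

m, nnz = 1000, 10000
vec = torch.randn(m, dtype=torch.double)

for name, gen in (("csr", gen_sparse_csr), ("coo", gen_sparse_coo)):
    mat = gen((m, m), nnz)
    start, stop = Event(enable_timing=True), Event(enable_timing=True)
    start.record()
    mat.matmul(vec)   # sparse matrix-vector product being benchmarked
    stop.record()
    print(name, "elapsed:", start.elapsed_time(stop))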
3
benchmarks/sparse/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
105
benchmarks/sparse/spmm.py
Normal file
@ -0,0 +1,105 @@
|
||||
import argparse
|
||||
import sys
|
||||
import torch
|
||||
from .utils import gen_sparse_csr, gen_sparse_coo, gen_sparse_coo_and_csr, Event
|
||||
|
||||
def test_sparse_csr(m, n, k, nnz, test_count):
|
||||
start_timer = Event(enable_timing=True)
|
||||
stop_timer = Event(enable_timing=True)
|
||||
|
||||
csr = gen_sparse_csr((m, k), nnz)
|
||||
mat = torch.randn(k, n, dtype=torch.double)
|
||||
|
||||
times = []
|
||||
for _ in range(test_count):
|
||||
start_timer.record()
|
||||
csr.matmul(mat)
|
||||
stop_timer.record()
|
||||
times.append(start_timer.elapsed_time(stop_timer))
|
||||
|
||||
return sum(times) / len(times)
|
||||
|
||||
def test_sparse_coo(m, n, k, nnz, test_count):
|
||||
start_timer = Event(enable_timing=True)
|
||||
stop_timer = Event(enable_timing=True)
|
||||
|
||||
coo = gen_sparse_coo((m, k), nnz)
|
||||
mat = torch.randn(k, n, dtype=torch.double)
|
||||
|
||||
times = []
|
||||
for _ in range(test_count):
|
||||
start_timer.record()
|
||||
coo.matmul(mat)
|
||||
stop_timer.record()
|
||||
times.append(start_timer.elapsed_time(stop_timer))
|
||||
|
||||
return sum(times) / len(times)
|
||||
|
||||
def test_sparse_coo_and_csr(m, n, k, nnz, test_count):
|
||||
start = Event(enable_timing=True)
|
||||
stop = Event(enable_timing=True)
|
||||
|
||||
coo, csr = gen_sparse_coo_and_csr((m, k), nnz)
|
||||
mat = torch.randn((k, n), dtype=torch.double)
|
||||
|
||||
times = []
|
||||
for _ in range(test_count):
|
||||
start.record()
|
||||
coo.matmul(mat)
|
||||
stop.record()
|
||||
|
||||
times.append(start.elapsed_time(stop))
|
||||
|
||||
coo_mean_time = sum(times) / len(times)
|
||||
|
||||
times = []
|
||||
for _ in range(test_count):
|
||||
start.record()
|
||||
csr.matmul(mat)
|
||||
stop.record()
|
||||
times.append(start.elapsed_time(stop))
|
||||
|
||||
csr_mean_time = sum(times) / len(times)
|
||||
|
||||
return coo_mean_time, csr_mean_time
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="SpMM")
|
||||
|
||||
parser.add_argument("--format", default='csr', type=str)
|
||||
parser.add_argument("--m", default='1000', type=int)
|
||||
parser.add_argument("--n", default='1000', type=int)
|
||||
parser.add_argument("--k", default='1000', type=int)
|
||||
parser.add_argument("--nnz_ratio", default='0.1', type=float)
|
||||
parser.add_argument("--outfile", default='stdout', type=str)
|
||||
parser.add_argument("--test_count", default='10', type=int)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.outfile == 'stdout':
|
||||
outfile = sys.stdout
|
||||
elif args.outfile == 'stderr':
|
||||
outfile = sys.stderr
|
||||
else:
|
||||
outfile = open(args.outfile, "a")
|
||||
|
||||
test_count = args.test_count
|
||||
m = args.m
|
||||
n = args.n
|
||||
k = args.k
|
||||
nnz_ratio = args.nnz_ratio
|
||||
|
||||
nnz = int(nnz_ratio * m * k)
|
||||
if args.format == 'csr':
|
||||
time = test_sparse_csr(m, n, k, nnz, test_count)
|
||||
elif args.format == 'coo':
|
||||
time = test_sparse_coo(m, n, k, nnz, test_count)
|
||||
elif args.format == 'both':
|
||||
time_coo, time_csr = test_sparse_coo_and_csr(m, n, k, nnz, test_count)
|
||||
|
||||
if args.format == 'both':
|
||||
print("format=coo", " nnz_ratio=", nnz_ratio, " m=", m, " n=", n, " k=", k, " time=", time_coo, file=outfile)
|
||||
print("format=csr", " nnz_ratio=", nnz_ratio, " m=", m, " n=", n, " k=", k, " time=", time_csr, file=outfile)
|
||||
else:
|
||||
print("format=", args.format, " nnz_ratio=", nnz_ratio, " m=", m, " n=", n, " k=", k, " time=", time,
|
||||
file=outfile)
|
103
benchmarks/sparse/spmv.py
Normal file
@ -0,0 +1,103 @@
|
||||
import argparse
|
||||
import sys
|
||||
import torch
|
||||
from .utils import gen_sparse_csr, gen_sparse_coo, gen_sparse_coo_and_csr, Event
|
||||
|
||||
def test_sparse_csr(m, nnz, test_count):
|
||||
start_timer = Event(enable_timing=True)
|
||||
stop_timer = Event(enable_timing=True)
|
||||
|
||||
csr = gen_sparse_csr((m, m), nnz)
|
||||
vector = torch.randn(m, dtype=torch.double)
|
||||
|
||||
times = []
|
||||
for _ in range(test_count):
|
||||
start_timer.record()
|
||||
csr.matmul(vector)
|
||||
stop_timer.record()
|
||||
times.append(start_timer.elapsed_time(stop_timer))
|
||||
|
||||
return sum(times) / len(times)
|
||||
|
||||
def test_sparse_coo(m, nnz, test_count):
|
||||
start_timer = Event(enable_timing=True)
|
||||
stop_timer = Event(enable_timing=True)
|
||||
|
||||
coo = gen_sparse_coo((m, m), nnz)
|
||||
vector = torch.randn(m, dtype=torch.double)
|
||||
|
||||
times = []
|
||||
for _ in range(test_count):
|
||||
start_timer.record()
|
||||
coo.matmul(vector)
|
||||
stop_timer.record()
|
||||
times.append(start_timer.elapsed_time(stop_timer))
|
||||
|
||||
return sum(times) / len(times)
|
||||
|
||||
def test_sparse_coo_and_csr(m, nnz, test_count):
|
||||
start = Event(enable_timing=True)
|
||||
stop = Event(enable_timing=True)
|
||||
|
||||
coo, csr = gen_sparse_coo_and_csr((m, m), nnz)
|
||||
vector = torch.randn(m, dtype=torch.double)
|
||||
|
||||
times = []
|
||||
for _ in range(test_count):
|
||||
start.record()
|
||||
coo.matmul(vector)
|
||||
stop.record()
|
||||
|
||||
times.append(start.elapsed_time(stop))
|
||||
|
||||
coo_mean_time = sum(times) / len(times)
|
||||
|
||||
times = []
|
||||
for _ in range(test_count):
|
||||
start.record()
|
||||
csr.matmul(vector)
|
||||
stop.record()
|
||||
times.append(start.elapsed_time(stop))
|
||||
|
||||
csr_mean_time = sum(times) / len(times)
|
||||
|
||||
return coo_mean_time, csr_mean_time
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="SpMV")
|
||||
|
||||
parser.add_argument("--format", default='csr', type=str)
|
||||
parser.add_argument("--m", default='1000', type=int)
|
||||
parser.add_argument("--nnz_ratio", default='0.1', type=float)
|
||||
parser.add_argument("--outfile", default='stdout', type=str)
|
||||
parser.add_argument("--test_count", default='10', type=int)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.outfile == 'stdout':
|
||||
outfile = sys.stdout
|
||||
elif args.outfile == 'stderr':
|
||||
outfile = sys.stderr
|
||||
else:
|
||||
outfile = open(args.outfile, "a")
|
||||
|
||||
test_count = args.test_count
|
||||
m = args.m
|
||||
nnz_ratio = args.nnz_ratio
|
||||
|
||||
nnz = int(nnz_ratio * m * m)
|
||||
if args.format == 'csr':
|
||||
time = test_sparse_csr(m, nnz, test_count)
|
||||
elif args.format == 'coo':
|
||||
time = test_sparse_coo(m, nnz, test_count)
|
||||
elif args.format == 'both':
|
||||
time_coo, time_csr = test_sparse_coo_and_csr(m, nnz, test_count)
|
||||
|
||||
if args.format != 'both':
|
||||
print("format=", args.format, " nnz_ratio=", nnz_ratio, " m=", m,
|
||||
" time=", time, file=outfile)
|
||||
else:
|
||||
print("format=coo", " nnz_ratio=", nnz_ratio, " m=", m,
|
||||
" time=", time_coo, file=outfile)
|
||||
print("format=csr", " nnz_ratio=", nnz_ratio, " m=", m,
|
||||
" time=", time_csr, file=outfile)
|
41
benchmarks/sparse/test_csr.sh
Normal file
@ -0,0 +1,41 @@
|
||||
OUTFILE=spmm-no-mkl-test.txt
|
||||
PYTORCH_HOME=$1
|
||||
|
||||
cd $PYTORCH_HOME
|
||||
|
||||
echo "" >> $OUTFILE
|
||||
echo "----- USE_MKL=1 -----" >> $OUTFILE
|
||||
rm -rf build
|
||||
|
||||
export USE_MKL=1
|
||||
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
|
||||
python setup.py build --cmake-only
|
||||
ccmake build # or cmake-gui build
|
||||
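# NOTE: ccmake is interactive; this step pauses the script so the CMake
# configuration (e.g. USE_MKL) can be inspected before installing.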
|
||||
python setup.py install
|
||||
|
||||
cd benchmarks
|
||||
echo "!! SPARSE SPMM TIME BENCHMARK!! " >> $OUTFILE
|
||||
for dim0 in 1000 5000 10000; do
|
||||
for nnzr in 0.01 0.05 0.1 0.3; do
|
||||
python -m sparse.spmm --format csr --m $dim0 --n $dim0 --k $dim0 --nnz_ratio $nnzr --outfile $OUTFILE
|
||||
# python -m sparse.spmm --format coo --m $dim0 --n $dim0 --k $dim0 --nnz_ratio $nnzr --outfile $OUTFILE
|
||||
done
|
||||
done
|
||||
echo "----------------------" >> $OUTFILE
|
||||
|
||||
cd $PYTORCH_HOME
|
||||
echo "----- USE_MKL=0 ------" >> $OUTFILE
|
||||
rm -rf build
|
||||
|
||||
export USE_MKL=0
|
||||
python setup.py install
|
||||
|
||||
cd benchmarks
|
||||
for dim0 in 1000 5000 10000; do
|
||||
for nnzr in 0.01 0.05 0.1 0.3; do
|
||||
python -m sparse.spmv --format csr --m $dim0 --nnz_ratio $nnzr --outfile $OUTFILE
|
||||
python -m sparse.spmv --format coo --m $dim0 --nnz_ratio $nnzr --outfile $OUTFILE
|
||||
done
|
||||
done
|
||||
echo "----------------------" >> $OUTFILE
|
54
benchmarks/sparse/utils.py
Normal file
@ -0,0 +1,54 @@
|
||||
import torch
|
||||
import functools
|
||||
import random
|
||||
import operator
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
# shim for torch.cuda.Event when running on cpu
|
||||
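# NOTE: unlike torch.cuda.Event.elapsed_time, which reports milliseconds, this
# CPU shim returns seconds from time.perf_counter(), so only relative
# comparisons between the timed runs are meaningful.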
class Event(object):
|
||||
def __init__(self, enable_timing):
|
||||
pass
|
||||
|
||||
def record(self):
|
||||
self.time = time.perf_counter()
|
||||
|
||||
def elapsed_time(self, end_event):
|
||||
assert isinstance(end_event, Event)
|
||||
return end_event.time - self.time
|
||||
|
||||
def gen_sparse_csr(shape, nnz):
|
||||
fill_value = 0
|
||||
total_values = functools.reduce(operator.mul, shape, 1)
|
||||
dense = np.random.randn(total_values)
|
||||
fills = random.sample(list(range(total_values)), total_values - nnz)
|
||||
|
||||
for f in fills:
|
||||
dense[f] = fill_value
|
||||
dense = torch.from_numpy(dense.reshape(shape))
|
||||
|
||||
return dense.to_sparse_csr()
|
||||
|
||||
def gen_sparse_coo(shape, nnz):
|
||||
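# NOTE: the (row, col) pairs below are drawn with replacement, so duplicate
# indices are possible and the resulting COO tensor is left uncoalesced.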
dense = np.random.randn(*shape)
|
||||
values = []
|
||||
indices = [[], []]
|
||||
for n in range(nnz):
|
||||
row = random.randint(0, shape[0] - 1)
|
||||
col = random.randint(0, shape[1] - 1)
|
||||
indices[0].append(row)
|
||||
indices[1].append(col)
|
||||
values.append(dense[row, col])
|
||||
|
||||
return torch.sparse_coo_tensor(indices, values, size=shape)
|
||||
|
||||
def gen_sparse_coo_and_csr(shape, nnz):
|
||||
total_values = functools.reduce(operator.mul, shape, 1)
|
||||
dense = np.random.randn(total_values)
|
||||
fills = random.sample(list(range(total_values)), total_values - nnz)
|
||||
|
||||
for f in fills:
|
||||
dense[f] = 0
|
||||
|
||||
dense = torch.from_numpy(dense.reshape(shape))
|
||||
return dense.to_sparse(), dense.to_sparse_csr()
|
@ -34,6 +34,8 @@ enum class Backend {
|
||||
XPU,
|
||||
SparseCPU,
|
||||
SparseCUDA,
|
||||
SparseCsrCPU,
|
||||
SparseCsrCUDA,
|
||||
SparseHIP,
|
||||
SparseXPU,
|
||||
MSNPU,
|
||||
@ -74,6 +76,10 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) {
|
||||
return Backend::SparseCUDA;
|
||||
} else if (t == DispatchKey::SparseHIP) {
|
||||
return Backend::SparseHIP;
|
||||
} else if (t == DispatchKey::SparseCsrCPU) {
|
||||
return Backend::SparseCsrCPU;
|
||||
} else if (t == DispatchKey::SparseCsrCUDA) {
|
||||
return Backend::SparseCsrCUDA;
|
||||
} else if (t == DispatchKey::MkldnnCPU) {
|
||||
return Backend::MkldnnCPU;
|
||||
} else if (t == DispatchKey::QuantizedCPU) {
|
||||
@ -117,6 +123,10 @@ static inline DispatchKey backendToDispatchKey(Backend b) {
|
||||
return DispatchKey::SparseCUDA;
|
||||
case Backend::SparseHIP:
|
||||
return DispatchKey::SparseHIP;
|
||||
case Backend::SparseCsrCPU:
|
||||
return DispatchKey::SparseCsrCPU;
|
||||
case Backend::SparseCsrCUDA:
|
||||
return DispatchKey::SparseCsrCUDA;
|
||||
case Backend::MkldnnCPU:
|
||||
return DispatchKey::MkldnnCPU;
|
||||
case Backend::Vulkan:
|
||||
@ -156,6 +166,10 @@ static inline DeviceType backendToDeviceType(Backend b) {
|
||||
return DeviceType::CUDA;
|
||||
case Backend::SparseHIP:
|
||||
return DeviceType::HIP;
|
||||
case Backend::SparseCsrCPU:
|
||||
return DeviceType::CPU;
|
||||
case Backend::SparseCsrCUDA:
|
||||
return DeviceType::CUDA;
|
||||
case Backend::XPU:
|
||||
case Backend::SparseXPU:
|
||||
case Backend::QuantizedXPU:
|
||||
@ -205,6 +219,10 @@ static inline const char* toString(Backend b) {
|
||||
return "SparseHIP";
|
||||
case Backend::SparseXPU:
|
||||
return "SparseXPU";
|
||||
case Backend::SparseCsrCPU:
|
||||
return "SparseCsrCPU";
|
||||
case Backend::SparseCsrCUDA:
|
||||
return "SparseCsrCUDA";
|
||||
case Backend::MkldnnCPU:
|
||||
return "MkldnnCPU";
|
||||
case Backend::Vulkan:
|
||||
@ -234,4 +252,14 @@ static inline bool isSparse(Backend b) {
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool isSparseCsr(Backend b) {
|
||||
switch(b) {
|
||||
case Backend::SparseCsrCPU:
|
||||
case Backend::SparseCsrCUDA:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
@ -11,6 +11,7 @@ const char* toString(DispatchKey t) {
|
||||
return "CPU";
|
||||
case DispatchKey::CUDA:
|
||||
return "CUDA";
|
||||
|
||||
case DispatchKey::HIP:
|
||||
return "HIP";
|
||||
case DispatchKey::FPGA:
|
||||
@ -51,6 +52,10 @@ const char* toString(DispatchKey t) {
|
||||
return "SparseCPU";
|
||||
case DispatchKey::SparseCUDA:
|
||||
return "SparseCUDA";
|
||||
case DispatchKey::SparseCsrCPU:
|
||||
return "SparseCsrCPU";
|
||||
case DispatchKey::SparseCsrCUDA:
|
||||
return "SparseCsrCUDA";
|
||||
case DispatchKey::SparseHIP:
|
||||
return "SparseHIP";
|
||||
case DispatchKey::SparseXPU:
|
||||
|
@ -112,6 +112,9 @@ enum class DispatchKey : uint8_t {
|
||||
// [Masquerading as CUDA]
|
||||
SparseXPU, // For out of tree Intel's heterogeneous computing plug-in
|
||||
|
||||
SparseCsrCPU,
|
||||
SparseCsrCUDA,
|
||||
|
||||
NestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor
|
||||
// Here are reserved backends for user-defined backends, see Note [Private use
|
||||
// DispatchKey]
|
||||
|
@ -235,7 +235,9 @@ constexpr DispatchKeySet autogradother_backends = DispatchKeySet({
|
||||
DispatchKey::SparseCPU,
|
||||
DispatchKey::SparseCUDA,
|
||||
DispatchKey::SparseHIP,
|
||||
DispatchKey::Meta,
|
||||
DispatchKey::SparseCsrCPU,
|
||||
DispatchKey::SparseCsrCUDA,
|
||||
DispatchKey::Meta
|
||||
});
|
||||
|
||||
// The set of dispatch keys that come after autograd
|
||||
|
@ -6,10 +6,11 @@
|
||||
#include <iostream>
|
||||
|
||||
namespace c10 {
|
||||
enum class Layout : int8_t { Strided, Sparse, Mkldnn, NumOptions };
|
||||
enum class Layout : int8_t { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };
|
||||
|
||||
constexpr auto kStrided = Layout::Strided;
|
||||
constexpr auto kSparse = Layout::Sparse;
|
||||
constexpr auto kSparseCsr = Layout::SparseCsr;
|
||||
constexpr auto kMkldnn = Layout::Mkldnn;
|
||||
|
||||
inline Layout layout_from_backend(Backend backend) {
|
||||
@ -21,6 +22,9 @@ inline Layout layout_from_backend(Backend backend) {
|
||||
return Layout::Sparse;
|
||||
case Backend::MkldnnCPU:
|
||||
return Layout::Mkldnn;
|
||||
case Backend::SparseCsrCPU:
|
||||
case Backend::SparseCsrCUDA:
|
||||
return Layout::SparseCsr;
|
||||
default:
|
||||
return Layout::Strided;
|
||||
}
|
||||
@ -32,6 +36,8 @@ inline std::ostream& operator<<(std::ostream& stream, at::Layout layout) {
|
||||
return stream << "Strided";
|
||||
case at::kSparse:
|
||||
return stream << "Sparse";
|
||||
case at::kSparseCsr:
|
||||
return stream << "SparseCsr";
|
||||
case at::kMkldnn:
|
||||
return stream << "Mkldnn";
|
||||
default:
|
||||
|
@ -590,6 +590,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
key_set_.has(DispatchKey::SparseXPU);
|
||||
}
|
||||
|
||||
// Whether a tensor is in the sparse CSR format. Note that is_sparse() only covers the COO format.
|
||||
bool is_sparse_csr() const {
|
||||
return key_set_.has(DispatchKey::SparseCsrCPU) ||
|
||||
key_set_.has(DispatchKey::SparseCsrCUDA);
|
||||
}
|
||||
|
||||
bool is_quantized() const {
|
||||
// NB: This method is not virtual and avoid dispatches for performance reasons.
|
||||
return key_set_.has(DispatchKey::QuantizedCPU) ||
|
||||
|
@ -337,6 +337,10 @@ struct C10_API TensorOptions {
|
||||
return layout_ == c10::Layout::Sparse;
|
||||
}
|
||||
|
||||
bool is_sparse_csr() const {
|
||||
return layout_ == c10::Layout::SparseCsr;
|
||||
}
|
||||
|
||||
// For compatibility with legacy tensor.type() comparisons
|
||||
bool type_equal(const TensorOptions& other) const {
|
||||
return computeDispatchKey() == other.computeDispatchKey() && typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype());
|
||||
@ -660,6 +664,15 @@ inline DispatchKey computeDispatchKey(c10::optional<ScalarType> dtype, c10::opti
|
||||
default:
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Unsupported device type for mkldnn layout: ", device_.type());
|
||||
}
|
||||
case Layout::SparseCsr:
|
||||
switch(device_.type()) {
|
||||
case DeviceType::CPU:
|
||||
return DispatchKey::SparseCsrCPU;
|
||||
case DeviceType::CUDA:
|
||||
return DispatchKey::SparseCsrCUDA;
|
||||
default:
|
||||
AT_ERROR("Unsupported device type for sparse CSR layout: ", device_.type());
|
||||
}
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported layout: ", layout_);
|
||||
}
|
||||
@ -671,6 +684,8 @@ inline Layout dispatchKeyToLayout(DispatchKey dispatch_key) {
|
||||
case DispatchKey::SparseCUDA:
|
||||
case DispatchKey::SparseHIP:
|
||||
case DispatchKey::SparseXPU:
|
||||
case DispatchKey::SparseCsrCPU:
|
||||
case DispatchKey::SparseCsrCUDA:
|
||||
return Layout::Sparse;
|
||||
case DispatchKey::MkldnnCPU:
|
||||
return Layout::Mkldnn;
|
||||
|
@ -126,6 +126,7 @@ If you don't see an operation listed here, but it would help your use case, plea
|
||||
:meth:`Tensor.is_shared`,None
|
||||
":meth:`Tensor.is_signed`, :func:`torch.is_signed`",None
|
||||
:attr:`Tensor.is_sparse`,None
|
||||
:attr:`Tensor.is_sparse_csr`,None
|
||||
:func:`torch.is_tensor`,None
|
||||
:meth:`Tensor.item`,None
|
||||
":meth:`Tensor.kthvalue`, :func:`torch.kthvalue`",:ref:`removes_dimensions-doc`
|
||||
|
@ -53,8 +53,8 @@ __ https://en.wikipedia.org/wiki/Sparse_matrix
|
||||
Sparse COO tensors
|
||||
++++++++++++++++++
|
||||
|
||||
Currently, PyTorch implements the so-called Coordinate format, or COO
|
||||
format, as the default sparse storage format for storing sparse
|
||||
PyTorch implements the so-called Coordinate format, or COO
|
||||
format, as one of the storage formats for implementing sparse
|
||||
tensors. In COO format, the specified elements are stored as tuples
|
||||
of element indices and the corresponding values. In particular,
|
||||
|
||||
@ -363,6 +363,82 @@ assumption that the fill value is negative infinity.
|
||||
|
||||
.. See https://github.com/Quansight-Labs/rfcs/tree/pearu/rfc-fill-value/RFC-0004-sparse-fill-value for a new API
|
||||
|
||||
.. _sparse-csr-docs:
|
||||
|
||||
Sparse CSR Tensor
|
||||
+++++++++++++++++
|
||||
|
||||
The CSR (Compressed Sparse Row) sparse tensor format implements CSR storage
for 2-dimensional tensors. Although there is no support for N-dimensional
tensors, its primary advantage over the COO format is better use of storage and
much faster computation for operations such as sparse matrix-vector multiplication
using the MKL and MAGMA backends. CUDA support is not yet available.
|
||||
|
||||
A CSR sparse tensor consists of three 1-D tensors: ``crow_indices``, ``col_indices``
|
||||
and ``values``:
|
||||
|
||||
- The ``crow_indices`` tensor consists of compressed row indices. This is a 1-D tensor
of size ``size[0] + 1``. Its last element is the number of non-zero elements. This tensor
encodes the index in ``values`` and ``col_indices`` at which each row starts. The
difference between consecutive elements gives the number of non-zero elements in each
row, as decoded by hand in the sketch after this list.
|
||||
- The ``col_indices`` tensor contains the column indices of each value. This is a 1-D
|
||||
tensor of size ``nnz``.
|
||||
- The ``values`` tensor contains the values of the CSR tensor. This is a 1-D tensor
|
||||
of size ``nnz``.
|
||||
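To make the encoding concrete, here is a minimal illustrative sketch (not part of the
documented API) that decodes the same ``crow_indices``/``col_indices``/``values`` triple
used in the construction example below into a dense matrix by hand:

import torch

crow_indices = torch.tensor([0, 2, 4])   # row i occupies entries crow_indices[i]:crow_indices[i+1]
col_indices = torch.tensor([0, 1, 0, 1])
values = torch.tensor([1., 2., 3., 4.])

nrows = crow_indices.numel() - 1
ncols = int(col_indices.max()) + 1
dense = torch.zeros(nrows, ncols)
for row in range(nrows):
    start, end = int(crow_indices[row]), int(crow_indices[row + 1])
    # end - start is the number of stored elements in this row
    dense[row, col_indices[start:end]] = values[start:end]
# dense is now tensor([[1., 2.], [3., 4.]])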
|
||||
.. note::
|
||||
|
||||
The index tensors ``crow_indices`` and ``col_indices`` should have element type either
``torch.int64`` (default) or ``torch.int32``. If you want to use MKL-enabled matrix
operations, use ``torch.int32``. This is because PyTorch is linked against MKL LP64 by
default, which uses 32-bit integer indexing.
|
||||
|
||||
Construction of CSR tensors
|
||||
---------------------------
|
||||
|
||||
Sparse CSR matrices can be constructed directly by using the :func:`torch.sparse_csr_tensor`
method. The user must supply the row and column indices and values tensors separately.
The ``size`` argument is optional and will be deduced from ``crow_indices``
and ``col_indices`` if it is not present.
|
||||
|
||||
>>> crow_indices = torch.tensor([0, 2, 4])
|
||||
>>> col_indices = torch.tensor([0, 1, 0, 1])
|
||||
>>> values = torch.tensor([1, 2, 3, 4])
|
||||
>>> csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=torch.double)
|
||||
>>> csr
|
||||
tensor(crow_indices=tensor([0, 2, 4]),
|
||||
col_indices=tensor([0, 1, 0, 1]),
|
||||
values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
|
||||
dtype=torch.float64)
|
||||
>>> csr.to_dense()
|
||||
tensor([[1., 2.],
|
||||
[3., 4.]], dtype=torch.float64)
|
||||
|
||||
CSR Tensor Operations
|
||||
---------------------
|
||||
|
||||
The simplest way of constructing a sparse CSR tensor from a strided or sparse COO
|
||||
tensor is to use :meth:`tensor.to_sparse_csr`. Any zeros in the (strided) tensor will
|
||||
be interpreted as missing values in the sparse tensor:
|
||||
|
||||
>>> a = torch.tensor([[0, 0, 1, 0], [1, 2, 0, 0], [0, 0, 0, 0]], dtype = torch.float64)
|
||||
>>> sp = a.to_sparse_csr()
|
||||
>>> sp
|
||||
tensor(crow_indices=tensor([0, 1, 3, 3]),
|
||||
col_indices=tensor([2, 0, 1]),
|
||||
values=tensor([1., 1., 2.]), size=(3, 4), nnz=3, dtype=torch.float64)
|
||||
|
||||
Sparse matrix-vector multiplication can be performed with the
:meth:`tensor.matmul` method. This is currently the only math operation
supported on CSR tensors.
|
||||
|
||||
>>> vec = torch.randn(4, 1, dtype=torch.float64)
|
||||
>>> sp.matmul(vec)
|
||||
tensor([[0.9078],
|
||||
[1.3180],
|
||||
[0.0000]], dtype=torch.float64)
|
||||
|
||||
Supported Linear Algebra operations
|
||||
+++++++++++++++++++++++++++++++++++
|
||||
|
||||
@ -380,7 +456,9 @@ multiplication, and ``@`` is matrix multiplication.
|
||||
:delim: ;
|
||||
|
||||
:func:`torch.mv`;no; ``M[sparse_coo] @ V[strided] -> V[strided]``
|
||||
:func:`torch.mv`;no; ``M[sparse_csr] @ V[strided] -> V[strided]``
|
||||
:func:`torch.matmul`; no; ``M[sparse_coo] @ M[strided] -> M[strided]``
|
||||
:func:`torch.matmul`; no; ``M[sparse_csr] @ M[strided] -> M[strided]``
|
||||
:func:`torch.mm`; no; ``M[sparse_coo] @ M[strided] -> M[strided]``
|
||||
:func:`torch.sparse.mm`; yes; ``M[sparse_coo] @ M[strided] -> M[strided]``
|
||||
:func:`torch.smm`; no; ``M[sparse_coo] @ M[strided] -> M[sparse_coo]``
|
||||
@ -405,8 +483,6 @@ matrix arguments.
|
||||
applications can still compute this using the matrix relation ``D @
|
||||
S == (S.t() @ D.t()).t()``.
|
||||
|
||||
|
||||
|
||||
Tensor methods and sparse
|
||||
+++++++++++++++++++++++++
|
||||
|
||||
@ -420,6 +496,7 @@ The following Tensor methods are related to sparse tensors:
|
||||
Tensor.sparse_dim
|
||||
Tensor.sparse_mask
|
||||
Tensor.to_sparse
|
||||
Tensor.to_sparse_csr
|
||||
Tensor.indices
|
||||
Tensor.values
|
||||
|
||||
@ -435,6 +512,14 @@ The following Tensor methods are specific to sparse COO tensors:
|
||||
Tensor.is_coalesced
|
||||
Tensor.to_dense
|
||||
|
||||
The following methods are specific to :ref:`sparse CSR tensors <sparse-csr-docs>`:
|
||||
|
||||
.. autosummary::
|
||||
:nosignatures:
|
||||
|
||||
Tensor.crow_indices
|
||||
Tensor.col_indices
|
||||
|
||||
The following Tensor methods support sparse COO tensors:
|
||||
|
||||
:meth:`~torch.Tensor.add`
|
||||
@ -496,6 +581,7 @@ Torch functions specific to sparse Tensors
|
||||
:nosignatures:
|
||||
|
||||
sparse_coo_tensor
|
||||
sparse_csr_tensor
|
||||
sparse.sum
|
||||
sparse.addmm
|
||||
sparse.mm
|
||||
|
@ -1803,6 +1803,23 @@ graph(%Ra, %Rb):
|
||||
self.checkScript(test_sparse_addmm, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4)))
|
||||
self.checkScript(test_sparse_addmm_alpha_beta, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4)))
|
||||
|
||||
@suppress_warnings
|
||||
def test_sparse_csr_tensors(self):
|
||||
@torch.jit.ignore
|
||||
def get_sparse_csr():
|
||||
return torch.randn(3, 3).to_sparse_csr()
|
||||
|
||||
@torch.jit.script
|
||||
def test_is_sparse_csr(input):
|
||||
# type: (Tensor) -> bool
|
||||
return input.is_sparse_csr
|
||||
|
||||
script_out_is_sparse_csr = test_is_sparse_csr(get_sparse_csr())
|
||||
script_out_is_dense_csr = test_is_sparse_csr(torch.randn(3, 3))
|
||||
|
||||
self.assertEqual(script_out_is_sparse_csr, True)
|
||||
self.assertEqual(script_out_is_dense_csr, False)
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
||||
def test_device_not_equal(self):
|
||||
|
||||
|
202
test/test_sparse_csr.py
Normal file
@ -0,0 +1,202 @@
|
||||
import torch
|
||||
|
||||
torch.set_default_dtype(torch.double)
|
||||
|
||||
import functools
|
||||
import random
|
||||
import operator
|
||||
import numpy as np
|
||||
import warnings
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, load_tests
|
||||
|
||||
# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
|
||||
# sharding on sandcastle. This line silences flake warnings
|
||||
load_tests = load_tests
|
||||
|
||||
class TestSparseCSR(TestCase):
|
||||
def gen_sparse_csr(self, shape, nnz):
|
||||
total_values = functools.reduce(operator.mul, shape, 1)
|
||||
dense = np.random.randn(total_values)
|
||||
fills = random.sample(list(range(total_values)), total_values - nnz)
|
||||
|
||||
for f in fills:
|
||||
dense[f] = 0
|
||||
dense = torch.from_numpy(dense.reshape(shape))
|
||||
|
||||
return dense.to_sparse_csr()
|
||||
|
||||
def setUp(self):
|
||||
# These parameters control the various ways we can run the test.
|
||||
# We will subclass and override this method to implement CUDA
|
||||
# tests
|
||||
self.is_cuda = False
|
||||
self.device = 'cpu'
|
||||
self.index_tensor = lambda *args: torch.tensor(*args, dtype=torch.int32)
|
||||
self.value_tensor = lambda *args: torch.tensor(*args, dtype=torch.double)
|
||||
|
||||
def test_csr_layout(self):
|
||||
self.assertEqual(str(torch.sparse_csr), 'torch.sparse_csr')
|
||||
self.assertEqual(type(torch.sparse_csr), torch.layout)
|
||||
|
||||
def test_sparse_csr_constructor_shape_inference(self):
|
||||
crow_indices = [0, 2, 4]
|
||||
col_indices = [0, 1, 0, 1]
|
||||
values = [1, 2, 3, 4]
|
||||
sparse = torch.sparse_csr_tensor(torch.tensor(crow_indices, dtype=torch.int64),
|
||||
torch.tensor(col_indices, dtype=torch.int64),
|
||||
torch.tensor(values), dtype=torch.double)
|
||||
self.assertEqual(torch.tensor(crow_indices, dtype=torch.int64), sparse.crow_indices())
|
||||
self.assertEqual((len(crow_indices) - 1, max(col_indices) + 1), sparse.shape)
|
||||
|
||||
def test_sparse_csr_constructor(self):
|
||||
crow_indices = [0, 2, 4]
|
||||
col_indices = [0, 1, 0, 1]
|
||||
values = [1, 2, 3, 4]
|
||||
|
||||
sparse = torch.sparse_csr_tensor(torch.tensor(crow_indices, dtype=torch.int32),
|
||||
torch.tensor(col_indices, dtype=torch.int32),
|
||||
torch.tensor(values), size=(2, 10), dtype=torch.float)
|
||||
|
||||
self.assertEqual((2, 10), sparse.shape)
|
||||
self.assertEqual(torch.tensor(crow_indices, dtype=torch.int32), sparse.crow_indices())
|
||||
|
||||
def test_sparse_csr_print(self):
|
||||
shape_nnz = [
|
||||
((1000, 10), 10)
|
||||
]
|
||||
|
||||
printed = []
|
||||
for shape, nnz in shape_nnz:
|
||||
values_shape = torch.Size((nnz,))
|
||||
col_indices_shape = torch.Size((nnz,))
|
||||
crow_indices_shape = torch.Size((shape[0] + 1,))
|
||||
printed.append("# shape: {}".format(torch.Size(shape)))
|
||||
printed.append("# nnz: {}".format(nnz))
|
||||
printed.append("# crow_indices shape: {}".format(crow_indices_shape))
|
||||
printed.append("# col_indices shape: {}".format(col_indices_shape))
|
||||
printed.append("# values_shape: {}".format(values_shape))
|
||||
|
||||
x = self.gen_sparse_csr(shape, nnz)
|
||||
|
||||
printed.append("# sparse tensor")
|
||||
printed.append(str(x))
|
||||
printed.append("# _crow_indices")
|
||||
printed.append(str(x.crow_indices()))
|
||||
printed.append("# _col_indices")
|
||||
printed.append(str(x.col_indices()))
|
||||
printed.append("# _values")
|
||||
printed.append(str(x.values()))
|
||||
printed.append('')
|
||||
|
||||
self.assertEqual(len(printed) > 0, True)
|
||||
|
||||
def test_sparse_csr_from_dense(self):
|
||||
sp = torch.tensor([[1, 2], [3, 4]]).to_sparse_csr()
|
||||
self.assertEqual(torch.tensor([0, 2, 4], dtype=torch.int64), sp.crow_indices())
|
||||
self.assertEqual(torch.tensor([0, 1, 0, 1], dtype=torch.int64), sp.col_indices())
|
||||
self.assertEqual(torch.tensor([1, 2, 3, 4], dtype=torch.int64), sp.values())
|
||||
|
||||
dense = torch.tensor([[4, 5, 0], [0, 0, 0], [1, 0, 0]])
|
||||
sparse = dense.to_sparse_csr()
|
||||
self.assertEqual(torch.tensor([0, 2, 2, 3], dtype=torch.int64), sparse.crow_indices())
|
||||
self.assertEqual(torch.tensor([0, 1, 0], dtype=torch.int64), sparse.col_indices())
|
||||
self.assertEqual(torch.tensor([4, 5, 1]), sparse.values())
|
||||
|
||||
dense = torch.tensor([[0, 0, 0], [0, 0, 1], [1, 0, 0]])
|
||||
sparse = dense.to_sparse_csr()
|
||||
self.assertEqual(torch.tensor([0, 0, 1, 2], dtype=torch.int64), sparse.crow_indices())
|
||||
self.assertEqual(torch.tensor([2, 0], dtype=torch.int64), sparse.col_indices())
|
||||
self.assertEqual(torch.tensor([1, 1]), sparse.values())
|
||||
|
||||
dense = torch.tensor([[2, 2, 2], [2, 2, 2], [2, 2, 2]])
|
||||
sparse = dense.to_sparse_csr()
|
||||
self.assertEqual(torch.tensor([0, 3, 6, 9], dtype=torch.int64), sparse.crow_indices())
|
||||
self.assertEqual(torch.tensor([0, 1, 2] * 3, dtype=torch.int64), sparse.col_indices())
|
||||
self.assertEqual(torch.tensor([2] * 9), sparse.values())
|
||||
|
||||
def test_dense_convert(self):
|
||||
size = (5, 5)
|
||||
dense = torch.randn(size)
|
||||
sparse = dense.to_sparse_csr()
|
||||
self.assertEqual(sparse.to_dense(), dense)
|
||||
|
||||
size = (4, 6)
|
||||
dense = torch.randn(size)
|
||||
sparse = dense.to_sparse_csr()
|
||||
self.assertEqual(sparse.to_dense(), dense)
|
||||
|
||||
crow_indices = torch.tensor([0, 3, 5])
|
||||
col_indices = torch.tensor([0, 1, 2, 0, 1])
|
||||
values = torch.tensor([1, 2, 1, 3, 4])
|
||||
csr = torch.sparse_csr_tensor(crow_indices, col_indices,
|
||||
values, dtype=torch.double)
|
||||
dense = torch.tensor([[1, 2, 1], [3, 4, 0]], dtype=torch.double)
|
||||
self.assertEqual(csr.to_dense(), dense)
|
||||
|
||||
def test_coo_to_csr_convert(self):
|
||||
size = (5, 5)
|
||||
dense = torch.randn(size)
|
||||
sparse_coo = dense.to_sparse()
|
||||
sparse_csr = sparse_coo.to_sparse_csr()
|
||||
|
||||
self.assertTrue(sparse_csr.is_sparse_csr)
|
||||
self.assertEqual(sparse_csr.to_dense(), dense)
|
||||
|
||||
vec = torch.randn((5, 1))
|
||||
coo_product = sparse_coo.matmul(vec)
|
||||
csr_product = sparse_csr.matmul(vec)
|
||||
|
||||
self.assertEqual(coo_product, csr_product)
|
||||
|
||||
vec = torch.randn((100, 1))
|
||||
index = self.index_tensor([
|
||||
[1, 0, 35, 14, 39, 6, 71, 66, 40, 27],
|
||||
[92, 31, 62, 50, 22, 65, 89, 74, 56, 34],
|
||||
])
|
||||
values = self.value_tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||
coo = torch.sparse_coo_tensor(index, values, torch.Size([100, 100]))
|
||||
csr = coo.to_sparse_csr()
|
||||
|
||||
self.assertEqual(coo.matmul(vec), csr.matmul(vec))
|
||||
|
||||
def test_mkl_matvec_warnings(self):
|
||||
if torch.has_mkl:
|
||||
sp = torch.sparse_csr_tensor(torch.tensor([0, 2, 4]),
|
||||
torch.tensor([0, 1, 0, 1]),
|
||||
torch.tensor([1, 2, 3, 4], dtype=torch.double))
|
||||
vec = torch.randn((2, 1))
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
sp.matmul(vec)
|
||||
self.assertEqual(len(w), 2)
|
||||
|
||||
def test_dense_convert_error(self):
|
||||
size = (4, 2, 4)
|
||||
dense = torch.randn(size)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Only 2D"):
|
||||
sparse = dense.to_sparse_csr()
|
||||
|
||||
def test_csr_matvec(self):
|
||||
side = 100
|
||||
csr = self.gen_sparse_csr((side, side), 1000)
|
||||
vec = torch.randn(side, dtype=torch.double)
|
||||
|
||||
res = csr.matmul(vec)
|
||||
expected = csr.to_dense().matmul(vec)
|
||||
|
||||
self.assertEqual(res, expected)
|
||||
|
||||
bad_vec = torch.randn(side + 10, dtype=torch.double)
|
||||
with self.assertRaisesRegex(RuntimeError, "mv: expected"):
|
||||
csr.matmul(bad_vec)
|
||||
|
||||
def test_coo_csr_conversion(self):
|
||||
size = (5, 5)
|
||||
dense = torch.randn(size)
|
||||
coo_sparse = dense.to_sparse()
|
||||
csr_sparse = coo_sparse.to_sparse_csr()
|
||||
|
||||
self.assertEqual(csr_sparse.to_dense(), dense)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|

@ -64,9 +64,10 @@ except ImportError:

# These functions require manual Python bindings or are not exposed to Python
SKIP_PYTHON_BINDINGS = [
    'alias', 'contiguous', 'is_cuda', 'is_sparse', 'size', 'stride',
    'alias', 'contiguous', 'is_cuda', 'is_sparse', 'is_sparse_csr', 'size', 'stride',
    '.*_backward', '.*_backward_(out|input|weight|bias)', '.*_forward',
    '.*_forward_out', '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*',
    '_?sparse_csr_tensor.*',
    '_arange.*', '_range.*', '_linspace.*', '_logspace.*',
    '_sparse_add_out', '_sparse_div.*', '_sparse_mul.*', '_sparse_sub.*', '_sparse_dense_add_out',
    'index', 'unique_dim_consecutive',

@ -406,6 +406,14 @@ static std::vector<Tensor> dispatch_nonzero_numpy(const Tensor & self) {

static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* kwargs);

static PyObject * THPVariable_sparse_csr_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  jit::tracer::warn("torch.sparse_csr_tensor", jit::tracer::WARN_CONSTRUCTOR);
  return THPVariable_Wrap(torch::utils::sparse_csr_tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs));
  END_HANDLE_TH_ERRORS
}

static PyObject * THPVariable_sparse_coo_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
@ -485,6 +493,7 @@ static PyMethodDef torch_functions[] = {
  {"range", castPyCFunctionWithKeywords(THPVariable_range), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"saddmm", castPyCFunctionWithKeywords(THPVariable_sspaddmm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"sparse_coo_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"sparse_csr_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_csr_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"_sparse_coo_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_coo_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"_validate_sparse_coo_tensor_args", castPyCFunctionWithKeywords(THPVariable__validate_sparse_coo_tensor_args), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
  {"spmm", castPyCFunctionWithKeywords(THPVariable_mm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},

@ -672,6 +672,7 @@ aten_cpu_source_non_codegen_list = [
    "aten/src/ATen/ScalarOps.cpp",
    "aten/src/ATen/SequenceNumber.cpp",
    "aten/src/ATen/SparseTensorImpl.cpp",
    "aten/src/ATen/SparseCsrTensorImpl.cpp",
    "aten/src/ATen/SparseTensorUtils.cpp",
    "aten/src/ATen/TensorGeometry.cpp",
    "aten/src/ATen/TensorIndexing.cpp",
@ -720,6 +721,7 @@ aten_cpu_source_non_codegen_list = [
    "aten/src/ATen/native/DispatchStub.cpp",
    "aten/src/ATen/native/UpSample.cpp",
    "aten/src/ATen/native/mkl/LinearAlgebra.cpp",
    "aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp",
    "aten/src/ATen/native/mkl/SpectralOps.cpp",
    "aten/src/ATen/native/mkldnn/BinaryOps.cpp",
    "aten/src/ATen/native/mkldnn/Conv.cpp",
@ -967,7 +969,9 @@ aten_native_source_non_codegen_list = [
    "aten/src/ATen/native/sparse/SoftMax.cpp",
    "aten/src/ATen/native/sparse/SparseMatMul.cpp",
    "aten/src/ATen/native/sparse/SparseTensor.cpp",
    "aten/src/ATen/native/sparse/SparseCsrTensor.cpp",
    "aten/src/ATen/native/sparse/SparseTensorMath.cpp",
    "aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp",
    "aten/src/TH/THAllocator.cpp",
    "aten/src/TH/THBlas.cpp",
    "aten/src/TH/THGeneral.cpp",

@ -857,9 +857,11 @@ def main() -> None:
    dispatch_keys = [
        DispatchKey.CPU,
        DispatchKey.SparseCPU,
        DispatchKey.SparseCsrCPU,
        DispatchKey.MkldnnCPU,
        DispatchKey.CUDA,
        DispatchKey.SparseCUDA,
        DispatchKey.SparseCsrCUDA,
        DispatchKey.QuantizedCPU,
        DispatchKey.QuantizedCUDA,
        DispatchKey.CompositeImplicitAutograd,

@ -74,6 +74,8 @@ class DispatchKey(Enum):
    MkldnnCPU = auto()
    SparseCPU = auto()
    SparseCUDA = auto()
    SparseCsrCPU = auto()
    SparseCsrCUDA = auto()
    SparseHIP = auto()
    SparseXPU = auto()
    NestedTensor = auto()

@ -294,6 +294,10 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
        'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],'
                             ' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,'
                             ' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
        'sparse_csr_tensor' : ['def sparse_csr_tensor(crow_indices: Tensor, col_indices: Tensor,'
                              ' values: Tensor, size: Optional[_size]=None,'
                              ' *, dtype: Optional[_dtype]=None,'
                              ' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
        '_sparse_coo_tensor_unsafe': ['def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],'
                                      ' dtype: Optional[_dtype] = None, device: Optional[_device] = None,'
                                      ' requires_grad: bool = False) -> Tensor: ...'],
@ -446,6 +450,7 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
        'is_cuda': ['is_cuda: _bool'],
        'is_leaf': ['is_leaf: _bool'],
        'is_sparse': ['is_sparse: _bool'],
        'is_sparse_csr' : ['is_sparse_csr: _bool'],
        'is_quantized': ['is_quantized: _bool'],
        'is_meta': ['is_meta: _bool'],
        'is_mkldnn': ['is_mkldnn: _bool'],

@ -908,6 +908,40 @@ class Tensor(torch._C._TensorBase):
        # See Note [rename_ / rename API]
        return update_names(self, names, rename_map, inplace=False)

    def to_sparse_csr(self):
        """ Convert a tensor to compressed row storage format. Only works with 2D tensors.

        Examples::

            >>> dense = torch.randn(5, 5)
            >>> sparse = dense.to_sparse_csr()
            >>> sparse._nnz()
            25

        """
        shape = self.size()
        fill_value = 0
        if len(shape) != 2:
            raise RuntimeError("Only 2D tensors can be converted to the CSR format but got shape: ", shape)

        if self.is_sparse:
            coalesced_self = self.coalesce()
            row_indices = coalesced_self.indices()[0]
            ro = [0]
            i = 0
            for irow in range(self.shape[0]):
                while i < row_indices.size()[0] and row_indices[i] == irow:
                    i += 1
                ro.append(i)

            return torch.sparse_csr_tensor(torch.tensor(ro, dtype=row_indices.dtype),
                                           coalesced_self.indices()[1], coalesced_self.values(),
                                           size=coalesced_self.shape, dtype=coalesced_self.dtype)
        elif self.is_sparse_csr:
            return self
        else:
            return self.to_sparse().to_sparse_csr()

    def _update_names(self, names, inplace):
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor._update_names, (self,), self, names, inplace)
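
The loop in to_sparse_csr above builds the row pointer by counting, for each row, how many of the sorted COO row indices have been consumed. A vectorized sketch of the same idea (illustrative only; `rows` is a stand-in for the sorted row indices of a coalesced COO tensor, not a name used by the patch):

    import torch

    rows = torch.tensor([0, 0, 1, 3])                 # sorted COO row indices, 4 non-zeros
    nrows = 4
    counts = torch.bincount(rows, minlength=nrows)    # non-zeros per row: [2, 1, 0, 1]
    crow = torch.zeros(nrows + 1, dtype=torch.int64)
    crow[1:] = torch.cumsum(counts, dim=0)            # row pointer: [0, 2, 3, 3, 4]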

@ -4651,6 +4651,11 @@ add_docstr_all('is_sparse',
Is ``True`` if the Tensor uses sparse storage layout, ``False`` otherwise.
""")

add_docstr_all('is_sparse_csr',
               r"""
Is ``True`` if the Tensor uses sparse CSR storage layout, ``False`` otherwise.
""")

add_docstr_all('device',
               r"""
Is the :class:`torch.device` where this Tensor is.
@ -4711,3 +4716,39 @@ Makes a ``cls`` instance with the same data pointer as ``self``. Changes
in the output mirror changes in ``self``, and the output stays attached
to the autograd graph. ``cls`` must be a subclass of ``Tensor``.
""")

add_docstr_all('crow_indices',
               r"""
crow_indices() -> IntTensor

Returns the tensor containing the compressed row indices of the :attr:`self`
tensor when :attr:`self` is a sparse CSR tensor of layout ``sparse_csr``.
The ``crow_indices`` tensor is strictly of shape (:attr:`self`.size(0) + 1)
and of type ``int32`` or ``int64``. When using MKL routines such as sparse
matrix multiplication, it is necessary to use ``int32`` indexing in order
to avoid downcasting and potentially losing information.

Example::
    >>> csr = torch.eye(5,5).to_sparse_csr()
    >>> csr.crow_indices()
    tensor([0, 1, 2, 3, 4, 5], dtype=torch.int32)

""")

add_docstr_all('col_indices',
               r"""
col_indices() -> IntTensor

Returns the tensor containing the column indices of the :attr:`self`
tensor when :attr:`self` is a sparse CSR tensor of layout ``sparse_csr``.
The ``col_indices`` tensor is strictly of shape (:attr:`self`.nnz())
and of type ``int32`` or ``int64``. When using MKL routines such as sparse
matrix multiplication, it is necessary to use ``int32`` indexing in order
to avoid downcasting and potentially losing information.

Example::
    >>> csr = torch.eye(5,5).to_sparse_csr()
    >>> csr.col_indices()
    tensor([0, 1, 2, 3, 4], dtype=torch.int32)

""")
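
A small illustrative sketch of reading these accessors back (it reuses the eye(5) example from the docstrings above): the difference of consecutive crow_indices entries gives the number of stored elements per row.

    import torch

    csr = torch.eye(5, 5).to_sparse_csr()
    crow = csr.crow_indices()
    nnz_per_row = crow[1:] - crow[:-1]   # expected: tensor([1, 1, 1, 1, 1])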

@ -315,6 +315,29 @@ def _str_intern(inp):
        if values.numel() == 0:
            values_str += ', size=' + str(tuple(values.shape))
        tensor_str = indices_prefix + indices_str + '),\n' + ' ' * indent + values_prefix + values_str + ')'
    elif self.is_sparse_csr:
        suffixes.append('size=' + str(tuple(self.shape)))
        suffixes.append('nnz=' + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append('dtype=' + str(self.dtype))
        crow_indices_prefix = 'crow_indices=tensor('
        crow_indices = self.crow_indices().detach()
        crow_indices_str = _tensor_str(crow_indices, indent + len(crow_indices_prefix))
        if crow_indices.numel() == 0:
            crow_indices_str += ', size=' + str(tuple(crow_indices.shape))
        col_indices_prefix = 'col_indices=tensor('
        col_indices = self.col_indices().detach()
        col_indices_str = _tensor_str(col_indices, indent + len(col_indices_prefix))
        if col_indices.numel() == 0:
            col_indices_str += ', size=' + str(tuple(col_indices.shape))
        values_prefix = 'values=tensor('
        values = self.values().detach()
        values_str = _tensor_str(values, indent + len(values_prefix))
        if values.numel() == 0:
            values_str += ', size=' + str(tuple(values.shape))
        tensor_str = crow_indices_prefix + crow_indices_str + '),\n' + ' ' * indent +\
            col_indices_prefix + col_indices_str + '),\n' + ' ' * indent +\
            values_prefix + values_str + ')'
    elif self.is_quantized:
        suffixes.append('size=' + str(tuple(self.shape)))
        if not has_default_dtype:
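
The branch above mirrors the existing COO printing path, emitting one block per CSR component. A quick illustrative sketch of the kind of output it is meant to produce (exact spacing and index dtypes come from _tensor_str and the conversion path):

    import torch

    print(torch.eye(2).to_sparse_csr())
    # roughly:
    # tensor(crow_indices=tensor([0, 1, 2]),
    #        col_indices=tensor([0, 1]),
    #        values=tensor([1., 1.]), size=(2, 2), nnz=2, layout=torch.sparse_csr)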

@ -7880,6 +7880,49 @@ Example::
       [-0.0881,  0.4370,  0.2275,  1.0284]])
""".format(**common_args))

add_docstr(torch.sparse_csr_tensor,
           r"""
sparse_csr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor

Constructs a :ref:`sparse tensor in CSR (Compressed Sparse Row) format <sparse-csr-docs>` with specified
values at the given :attr:`crow_indices` and :attr:`col_indices`. Sparse matrix multiplication operations
in CSR format are typically faster than those for sparse tensors in COO format. Make sure you have a look
at :ref:`the note on the data type of the indices <sparse-csr-docs>`.

Args:
    crow_indices (array_like): One-dimensional array of size size[0] + 1. The last element
        is the number of non-zeros. This tensor encodes the index in values and col_indices
        at which the given row starts. Each successive number in the tensor
        subtracted by the number before it denotes the number of elements in a given row.
    col_indices (array_like): Column coordinates of each element in values. Strictly one
        dimensional tensor with the same length as values.
    values (array_like): Initial values for the tensor. Can be a list, tuple, NumPy ``ndarray``, scalar,
        and other types.
    size (list, tuple, :class:`torch.Size`, optional): Size of the sparse tensor. If not provided, the
        size will be inferred as the minimum size big enough to hold all non-zero elements.

Keyword args:
    dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
        Default: if None, infers data type from :attr:`values`.
    device (:class:`torch.device`, optional): the desired device of returned tensor.
        Default: if None, uses the current device for the default tensor type
        (see :func:`torch.set_default_tensor_type`). :attr:`device` will be the CPU
        for CPU tensor types and the current CUDA device for CUDA tensor types.
    {requires_grad}

Example::
    >>> crow_indices = [0, 2, 4]
    >>> col_indices = [0, 1, 0, 1]
    >>> values = [1, 2, 3, 4]
    >>> torch.sparse_csr_tensor(torch.tensor(crow_indices, dtype=torch.int64),
    ...                         torch.tensor(col_indices, dtype=torch.int64),
    ...                         torch.tensor(values), dtype=torch.double)
    tensor(crow_indices=tensor([0, 2, 4]),
           col_indices=tensor([0, 1, 0, 1]),
           values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
           dtype=torch.float64, layout=torch.sparse_csr)
""".format(**factory_common_args))
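
A short usage sketch for the size argument described above (illustrative; it assumes the constructor infers the shape as documented when size is omitted):

    import torch

    crow = torch.tensor([0, 2, 4], dtype=torch.int64)
    col = torch.tensor([0, 1, 0, 1], dtype=torch.int64)
    vals = torch.tensor([1., 2., 3., 4.])

    a = torch.sparse_csr_tensor(crow, col, vals)           # size inferred, expected (2, 2)
    b = torch.sparse_csr_tensor(crow, col, vals, (2, 2))   # size given explicitly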

add_docstr(torch.sparse_coo_tensor,
           r"""
sparse_coo_tensor(indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor

@ -599,6 +599,17 @@ PyObject *THPVariable_is_sparse(THPVariable *self, void *unused)
  END_HANDLE_TH_ERRORS
}

PyObject *THPVariable_is_sparse_csr(THPVariable *self, void *unused)
{
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject *)self)) {
    return handle_torch_function_getter(self, "is_sparse_csr");
  }
  auto& self_ = self->cdata;
  return torch::autograd::utils::wrap(self_.is_sparse_csr());
  END_HANDLE_TH_ERRORS
}

PyObject *THPVariable_is_mkldnn(THPVariable *self, void *unused)
{
  HANDLE_TH_ERRORS
@ -761,6 +772,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
  {"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
  {"is_xpu", (getter)THPVariable_is_xpu, nullptr, nullptr, nullptr},
  {"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
  {"is_sparse_csr", (getter)THPVariable_is_sparse_csr, nullptr, nullptr, nullptr},
  {"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr},
  {"is_mlc", (getter)THPVariable_is_mlc, nullptr, nullptr, nullptr},
  {"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr},
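
With the getter wired into THPVariable_properties above, the flag becomes reachable from Python. A minimal illustrative check (not part of the patch):

    import torch

    assert torch.eye(3).to_sparse_csr().is_sparse_csr
    assert not torch.eye(3).is_sparse_csr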

@ -103,25 +103,16 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
  static const PropertiesLookup builtin_properties = {
      {TypeKind::TensorType,
       {
           {"dtype", "prim"},
           {"device", "prim"},
           {"grad", "prim"},
           {"data", "prim"},
           {"shape", "prim"},
           {"is_cuda", "prim"},
           {"is_xpu", "prim"},
           {"is_sparse", "prim"},
           {"is_mkldnn", "prim"},
           {"is_mlc", "prim"},
           {"is_quantized", "prim"},
           {"is_vulkan", "prim"},
           {"is_meta", "prim"},
           {"is_leaf", "aten"},
           {"requires_grad", "prim"},
           {"layout", "prim"},
           {"T", "prim"},
           {"ndim", "prim"},
           {"name", "prim"},
           {"dtype", "prim"}, {"device", "prim"},
           {"grad", "prim"}, {"data", "prim"},
           {"shape", "prim"}, {"is_cuda", "prim"},
           {"is_xpu", "prim"}, {"is_sparse", "prim"},
           {"is_sparse_csr", "prim"}, {"is_mkldnn", "prim"},
           {"is_mlc", "prim"}, {"is_quantized", "prim"},
           {"is_vulkan", "prim"}, {"is_meta", "prim"},
           {"is_leaf", "aten"}, {"requires_grad", "prim"},
           {"layout", "prim"}, {"T", "prim"},
           {"ndim", "prim"}, {"name", "prim"},
       }},
      {TypeKind::DeviceObjType, {{"type", "prim"}, {"index", "prim"}}}};
  auto kind = value_->type()->kind();

@ -281,6 +281,14 @@ RegisterOperators reg(
           push(stack, a.is_sparse());
         },
         aliasAnalysisFromSchema()),
     Operator(
         "prim::is_sparse_csr(Tensor a) -> bool",
         [](Stack* stack) {
           at::Tensor a;
           pop(stack, a);
           push(stack, a.is_sparse_csr());
         },
         aliasAnalysisFromSchema()),
     Operator(
         "prim::is_mkldnn(Tensor a) -> bool",
         [](Stack* stack) {
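
Together with the property-table entry above, this operator is what should make the attribute scriptable. A minimal illustrative sketch (assumes a build that includes this patch):

    import torch

    @torch.jit.script
    def is_csr(t: torch.Tensor) -> bool:
        return t.is_sparse_csr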

@ -123,6 +123,14 @@ PyObject *Tensor_is_sparse(PyTensorType *self, void *unused) {
  }
}

PyObject *Tensor_is_sparse_csr(PyTensorType *self, void *unused) {
  if (self->layout->layout == at::Layout::SparseCsr) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
}

static struct PyMethodDef metaclass_methods[] = {
  {"__instancecheck__", Tensor_instancecheck, METH_O, nullptr},
  {nullptr}
@ -135,6 +143,7 @@ static struct PyGetSetDef metaclass_properties[] = {
  {"layout", (getter)Tensor_layout, nullptr, nullptr, nullptr},
  {"is_cuda", (getter)Tensor_is_cuda, nullptr, nullptr, nullptr},
  {"is_sparse", (getter)Tensor_is_sparse, nullptr, nullptr, nullptr},
  {"is_sparse_csr", (getter)Tensor_is_sparse_csr, nullptr, nullptr, nullptr},
  {nullptr}
};

@ -1,4 +1,3 @@
#include <torch/csrc/utils/tensor_layouts.h>
#include <ATen/Layout.h>
#include <c10/core/ScalarType.h>
#include <torch/csrc/DynamicTypes.h>
@ -6,6 +5,7 @@
#include <torch/csrc/Layout.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/tensor_layouts.h>

namespace torch { namespace utils {

@ -13,21 +13,29 @@ void initializeLayouts() {
  auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
  if (!torch_module) throw python_error();

  PyObject *strided_layout = THPLayout_New(at::Layout::Strided, "torch.strided");
  PyObject* strided_layout = THPLayout_New(at::Layout::Strided, "torch.strided");
  Py_INCREF(strided_layout);
  if (PyModule_AddObject(torch_module, "strided", strided_layout) != 0) {
    throw python_error();
  }
  registerLayoutObject((THPLayout*)strided_layout, at::Layout::Strided);

  PyObject *sparse_coo_layout = THPLayout_New(at::Layout::Sparse, "torch.sparse_coo");
  PyObject* sparse_coo_layout = THPLayout_New(at::Layout::Sparse, "torch.sparse_coo");
  Py_INCREF(sparse_coo_layout);
  if (PyModule_AddObject(torch_module, "sparse_coo", sparse_coo_layout) != 0) {
    throw python_error();
  }
  registerLayoutObject((THPLayout*)sparse_coo_layout, at::Layout::Sparse);

  PyObject *mkldnn_layout = THPLayout_New(at::Layout::Mkldnn, "torch._mkldnn");
  PyObject* sparse_csr_layout =
      THPLayout_New(at::Layout::SparseCsr, "torch.sparse_csr");
  Py_INCREF(sparse_csr_layout);
  if (PyModule_AddObject(torch_module, "sparse_csr", sparse_csr_layout) != 0) {
    throw python_error();
  }
  registerLayoutObject((THPLayout*)sparse_csr_layout, at::Layout::SparseCsr);

  PyObject* mkldnn_layout = THPLayout_New(at::Layout::Mkldnn, "torch._mkldnn");
  Py_INCREF(mkldnn_layout);
  if (PyModule_AddObject(torch_module, "_mkldnn", mkldnn_layout) != 0) {
    throw python_error();
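
After the registration above, torch.sparse_csr should be available as a layout object alongside torch.strided and torch.sparse_coo. An illustrative check (not part of the patch):

    import torch

    assert isinstance(torch.sparse_csr, torch.layout)
    assert torch.eye(2).to_sparse_csr().layout == torch.sparse_csr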

@ -34,6 +34,7 @@ using at::IntArrayRef;
using at::kCPU;
using at::kCUDA;
using at::kLong;
using at::kInt;
using at::Scalar;
using at::ScalarType;
using at::Storage;
@ -597,6 +598,66 @@ Tensor indexing_tensor_from_data(
  }
}

Tensor sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
  TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key)));
  static PythonArgParser parser({
    "sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
    "sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
  });
  const int NUM_ARGS = 9, CROW_INDICES_ARG = 0, COL_INDICES_ARG = 1, VALUES_ARG = 2;
  ParsedArgs<NUM_ARGS> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  THPObjectPtr crow_indices_dtype_attr(PyObject_GetAttrString(r.pyobject(CROW_INDICES_ARG), "dtype"));
  THPObjectPtr col_indices_dtype_attr(PyObject_GetAttrString(r.pyobject(COL_INDICES_ARG), "dtype"));
  at::ScalarType crow_indices_scalar_type = reinterpret_cast<THPDtype*>(
      crow_indices_dtype_attr.get())->scalar_type;
  at::ScalarType col_indices_scalar_type = reinterpret_cast<THPDtype*>(
      col_indices_dtype_attr.get())->scalar_type;

  if (r.idx == 0) {
    const int SIZE_ARRAY_ARG = 3, TYPE_INFERENCE_ARG = 4, DEVICE_TYPE_ARG = 7, REQ_GRAD_ARG = 8;
    bool type_inference = r.isNone(TYPE_INFERENCE_ARG);
    const auto inferred_options = typeIdWithDefault(r, DEVICE_TYPE_ARG, dispatch_key);
    const auto inferred_scalar_type = r.scalartypeWithDefault(TYPE_INFERENCE_ARG, scalar_type);
    at::OptionalDeviceGuard device_guard(r.deviceOptional(DEVICE_TYPE_ARG));

    Tensor values = internal_new_from_data(inferred_options, inferred_scalar_type, r.deviceOptional(DEVICE_TYPE_ARG),
                                           r.pyobject(VALUES_ARG), /*copy_variables=*/false, /*copy_numpy=*/true,
                                           /*type_inference=*/type_inference);
    Tensor crow_indices = internal_new_from_data(values.options(),
                                                 crow_indices_scalar_type, r.deviceOptional(DEVICE_TYPE_ARG), r.pyobject(CROW_INDICES_ARG),
                                                 /*copy_variables=*/false, /*copy_numpy=*/true,
                                                 /*type_inference=*/false);
    Tensor col_indices = internal_new_from_data(values.options(),
                                                col_indices_scalar_type, r.deviceOptional(DEVICE_TYPE_ARG), r.pyobject(COL_INDICES_ARG),
                                                /*copy_variables=*/false, /*copy_numpy=*/true,
                                                /*type_inference=*/false);

    return at::sparse_csr_tensor(crow_indices, col_indices, values, r.intlist(SIZE_ARRAY_ARG),
                                 values.options().layout(at::kSparseCsr)).set_requires_grad(r.toBool(REQ_GRAD_ARG));
  } else if (r.idx == 1) {
    const int TYPE_INFERENCE_ARG = 3, DEVICE_TYPE_ARG = 6, REQ_GRAD_ARG = 7;
    bool type_inference = r.isNone(TYPE_INFERENCE_ARG);
    const auto inferred_options = typeIdWithDefault(r, DEVICE_TYPE_ARG, dispatch_key);
    const auto inferred_scalar_type = r.scalartypeWithDefault(TYPE_INFERENCE_ARG, scalar_type);
    at::OptionalDeviceGuard device_guard(r.deviceOptional(DEVICE_TYPE_ARG));

    Tensor values = internal_new_from_data(inferred_options, inferred_scalar_type, r.deviceOptional(DEVICE_TYPE_ARG),
                                           r.pyobject(VALUES_ARG), /*copy_variables=*/false, /*copy_numpy=*/true,
                                           /*type_inference=*/type_inference);
    Tensor crow_indices = internal_new_from_data(values.options(),
                                                 crow_indices_scalar_type, r.deviceOptional(DEVICE_TYPE_ARG),
                                                 r.pyobject(CROW_INDICES_ARG), /*copy_variables=*/false, /*copy_numpy=*/true,
                                                 /*type_inference=*/false);
    Tensor col_indices = internal_new_from_data(values.options(), col_indices_scalar_type, r.deviceOptional(DEVICE_TYPE_ARG),
                                                r.pyobject(COL_INDICES_ARG), /*copy_variables=*/false, /*copy_numpy=*/true,
                                                /*type_inference=*/false);
    return at::sparse_csr_tensor(crow_indices, col_indices, values,
                                 values.options().layout(at::kSparseCsr)).set_requires_grad(r.toBool(REQ_GRAD_ARG));
  }
  throw std::runtime_error("sparse_csr_tensor(): invalid arguments");
}
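
Because the constructor above reads the dtype attribute of the index arguments directly, index tensors are expected to keep the dtype they were given. An illustrative sketch (int32 indices are the ones the MKL path can consume without downcasting, per the crow_indices docstring):

    import torch

    crow = torch.tensor([0, 1, 2], dtype=torch.int32)
    col = torch.tensor([0, 1], dtype=torch.int32)
    vals = torch.tensor([10., 20.])
    t = torch.sparse_csr_tensor(crow, col, vals, (2, 2))
    # t.crow_indices().dtype is expected to remain torch.int32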

// Note [Ensuring sparse values and indices match devices]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// In all places where we construct indices, we read out options from values

@ -13,6 +13,7 @@ at::Tensor indexing_tensor_from_data(
    at::ScalarType scalar_type,
    c10::optional<at::Device> device,
    PyObject* data);
at::Tensor sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
at::Tensor sparse_coo_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
at::Tensor _sparse_coo_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
void _validate_sparse_coo_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);

@ -160,6 +160,7 @@ def get_ignored_functions() -> Set[Callable]:
        torch.result_type,
        torch.scalar_tensor,
        torch.sparse_coo_tensor,
        torch.sparse_csr_tensor,
        torch.tril_indices,
        torch.triu_indices,
        torch.vander,
@ -216,6 +217,7 @@ def get_ignored_functions() -> Set[Callable]:
        Tensor._make_subclass,
        Tensor.stride,
        Tensor.unflatten,
        Tensor.to_sparse_csr,
        Tensor._reduce_ex_internal,
    }

@ -935,6 +937,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        Tensor.is_mkldnn.__get__: lambda self: -1,
        Tensor.is_quantized.__get__: lambda self: -1,
        Tensor.is_sparse.__get__: lambda self: -1,
        Tensor.is_sparse_csr.__get__: lambda self: -1,
        Tensor.is_vulkan.__get__: lambda self: -1,
        Tensor.layout.__get__: lambda self: -1,
        Tensor.name.__get__: lambda self: -1,
@ -954,6 +957,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        Tensor._indices: lambda self: -1,
        Tensor._is_view: lambda self: -1,
        Tensor._nnz: lambda self: -1,
        Tensor.crow_indices: lambda self: -1,
        Tensor.col_indices: lambda self: -1,
        Tensor._update_names: lambda self, names, inplace: -1,
        Tensor._values: lambda self: -1,
        Tensor.align_as: lambda self, other: -1,