Add CSR (compressed sparse row) layout for sparse tensors (#50937)

Summary:
Implement compressed sparse row format. Derived from the GCS implementation at https://github.com/pytorch/pytorch/pull/44190

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50937

Reviewed By: mrshenli

Differential Revision: D27439865

Pulled By: ezyang

fbshipit-source-id: 3ba3dcb9679505b980ff6a5f513e913bbae2fb1d
Author: Sameer Deshmukh
Date: 2021-04-12 10:07:56 -07:00
Committed by: Facebook GitHub Bot
Parent: c6d9ca0c2b
Commit: 5fb1142702
52 changed files with 2310 additions and 201 deletions

View File

@ -130,6 +130,7 @@ genrule(
"aten/src/ATen/RegisterMkldnnCPU.cpp",
"aten/src/ATen/RegisterQuantizedCPU.cpp",
"aten/src/ATen/RegisterSparseCPU.cpp",
"aten/src/ATen/RegisterSparseCsrCPU.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutograd.cpp",
"aten/src/ATen/RegisterMeta.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutograd.cpp",

View File

@ -0,0 +1,153 @@
#include <ATen/ATen.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/core/LegacyTypeDispatch.h>
namespace at {
namespace {
DeviceType SparseCsrTensorSetToDeviceType(DispatchKeySet key_set) {
if (key_set.has(DispatchKey::SparseCsrCPU)) {
return kCPU;
} else if (key_set.has(DispatchKey::SparseCsrCUDA)) {
return kCUDA;
} else {
TORCH_CHECK(false,
"Cannot construct SparseCsrTensor with non-sparse tensor type ID ",
key_set);
}
}
} // namespace
SparseCsrTensorImpl::SparseCsrTensorImpl(
at::DispatchKeySet key_set,
const caffe2::TypeMeta data_type)
: SparseCsrTensorImpl(
key_set,
data_type,
at::empty(
{0},
at::initialTensorOptions()
.device(SparseCsrTensorSetToDeviceType(key_set))
.dtype(ScalarType::Int)) // crow_indices
,
at::empty(
{0},
at::initialTensorOptions()
.device(SparseCsrTensorSetToDeviceType(key_set))
.dtype(ScalarType::Int)) // col_indices
,
at::empty(
{0},
at::initialTensorOptions()
.device(SparseCsrTensorSetToDeviceType(key_set))
.dtype(data_type)) // values
) {}
SparseCsrTensorImpl::SparseCsrTensorImpl(
at::DispatchKeySet key_set,
const caffe2::TypeMeta data_type,
at::Tensor crow_indices,
at::Tensor col_indices,
at::Tensor values)
: TensorImpl(key_set, data_type, values.device()),
crow_indices_(std::move(crow_indices)),
col_indices_(std::move(col_indices)),
values_(std::move(values)) {}
void SparseCsrTensorImpl::resize_and_clear_(
const int64_t nnz_size,
IntArrayRef size) {
// Call crow_indices().options() here since the struct constructor calls the
// tensor constructor with args for device-specific init.
auto empty_crow_indices = at::empty(size[0] + 1, crow_indices().options());
auto empty_col_indices = at::empty(nnz_size, col_indices().options());
auto empty_values = at::empty(nnz_size, values().options());
crow_indices_ = empty_crow_indices;
col_indices_ = empty_col_indices;
values_ = empty_values;
sizes_and_strides_.set_sizes(size);
}
void SparseCsrTensorImpl::resize_as_sparse_csr_tensor_(const Tensor& src) {
crow_indices_ = at::empty_like(
src.crow_indices(),
src.crow_indices().options(),
src.crow_indices().suggest_memory_format());
col_indices_ = at::empty_like(
src.col_indices(),
src.col_indices().options(),
src.col_indices().suggest_memory_format());
values_ = at::empty_like(
src.values(),
src.values().options(),
src.values().suggest_memory_format());
sizes_and_strides_.set_sizes(src.sizes());
}
void SparseCsrTensorImpl::set_member_tensors(
const Tensor& crow_indices,
const Tensor& col_indices,
const Tensor& values) {
auto crow_indices_type = crow_indices.scalar_type();
auto col_indices_type = col_indices.scalar_type();
TORCH_CHECK(
crow_indices_type == col_indices_type,
"both crow_indices and col_indices should have the same type.");
TORCH_CHECK(
crow_indices_type == kInt || crow_indices_type == kLong,
"crow_indices and col_indices must be an int32 or int64 type, but got: ",
crow_indices_type);
TORCH_CHECK(
values.scalar_type() == typeMetaToScalarType(dtype()),
"dtype of values (",
values.scalar_type(),
") must match dtype of sparse tensor (",
typeMetaToScalarType(dtype()),
")");
TORCH_CHECK(
col_indices.layout() == kStrided,
"expected col_indices to be a strided tensor, but got indices of layout ",
col_indices.layout());
TORCH_CHECK(
crow_indices.layout() == kStrided,
"expected crow_indices to be a strided tensor, but got crow_indices of layout ",
crow_indices.layout());
TORCH_CHECK(
values.layout() == kStrided && values.is_contiguous(),
"expected values to be a strided and contiguous tensor, but got values of layout ",
values.layout());
TORCH_CHECK(
values.device().type() == device().type(),
"device type of values (",
values.device().type(),
") must match device type of device().type()",
device().type(),
")");
TORCH_CHECK(
values.is_cuda() || col_indices.get_device() == crow_indices.get_device(),
"crow_indices and col_indices devices (",
crow_indices.get_device(),
", ",
col_indices.get_device(),
") must match with the (non-cuda) device of values (",
values.get_device(),
")");
TORCH_CHECK(
col_indices.size(0) == values.size(0),
"col_indices and values must have equal sizes, but got col_indices.size(0): ",
col_indices.size(0),
", values.size(0): ",
values.size(0));
crow_indices_ = crow_indices;
col_indices_ = col_indices;
values_ = values;
}
} // namespace at
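
As an illustration of what `set_member_tensors` accepts, here is a minimal plain-C++ sketch of the same consistency rules (the helper and struct names are hypothetical, and the device checks are omitted for brevity): the two index tensors must share one integral dtype (int32 or int64), `values` must match the tensor's dtype and be contiguous, and `col_indices` and `values` must have the same length.

  #include <cassert>
  #include <cstddef>

  // Sketch only: dtypes modeled as a small enum; `tensor_dtype` stands in
  // for the impl's dtype().
  enum class Dtype { Int32, Int64, Float32, Float64 };

  struct CsrMembers {
    Dtype crow_dtype;
    Dtype col_dtype;
    Dtype values_dtype;
    std::size_t col_len;
    std::size_t values_len;
    bool values_contiguous;
  };

  bool acceptable_member_tensors(const CsrMembers& m, Dtype tensor_dtype) {
    const bool index_dtypes_ok =
        m.crow_dtype == m.col_dtype &&
        (m.crow_dtype == Dtype::Int32 || m.crow_dtype == Dtype::Int64);
    const bool values_ok =
        m.values_dtype == tensor_dtype && m.values_contiguous;
    const bool lengths_ok = m.col_len == m.values_len;
    return index_dtypes_ok && values_ok && lengths_ok;
  }

  int main() {
    CsrMembers m{Dtype::Int32, Dtype::Int32, Dtype::Float64, 4, 4, true};
    assert(acceptable_member_tensors(m, Dtype::Float64));
    return 0;
  }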

View File

@ -0,0 +1,55 @@
#pragma once
#include <ATen/Tensor.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Exception.h>
namespace at {
// Struct implementing a sparse CSR tensor. It uses three 1-D tensors for
// denoting the data: `crow_indices_`, `col_indices_` and `values_`.
// The `crow_indices_` tensor is an integer tensor of shape `(size(0) + 1)`
// that represents the compressed row indices of the CSR tensor. The
// `col_indices_` tensor is an integer tensor of shape `(nnz())`
// that explicitly stores the column indices of each value of the sparse
// tensor. The `values_` tensor can be of any PyTorch-supported data type
// and has shape `(nnz())`.
//
// Since the main advantage of the CSR format over the COO format is speed of
// computation, care must be taken to facilitate smooth interfacing of
// these data structures with optimized libraries such as MKL and MAGMA.
// Since the MKL interface for PyTorch currently uses int32 indexing, it is
// important to make sure that `crow_indices` and `col_indices` are of type
// int32 when calling MKL routines such as SPMM or SPMV.
//
// If MKL is not being called, it is fine to use 64-bit integer tensors for
// indexing.
struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
Tensor crow_indices_;
Tensor col_indices_;
Tensor values_;
public:
explicit SparseCsrTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta);
void resize_and_clear_(const int64_t nnz_size, IntArrayRef size);
void resize_as_sparse_csr_tensor_(const Tensor& src);
void set_member_tensors(
const Tensor& crow_indices,
const Tensor& col_indices,
const Tensor& values);
const Tensor& crow_indices() const { return crow_indices_; }
const Tensor& col_indices() const { return col_indices_; }
const Tensor& values() const { return values_; }
int nnz() { return values_.size(0); }
private:
explicit SparseCsrTensorImpl(
at::DispatchKeySet key_set,
const caffe2::TypeMeta data_type,
at::Tensor crow_indices,
at::Tensor col_indices,
at::Tensor values);
};
} // namespace at
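
To make the shape conventions above concrete, here is a small worked example (illustration only; plain C++ with hypothetical array names) showing how a 3x4 dense matrix maps onto `crow_indices_`, `col_indices_` and `values_`:

  // The 3x4 matrix
  //   [ 1 0 2 0 ]
  //   [ 0 0 0 0 ]
  //   [ 0 3 0 4 ]
  // has nnz() == 4 and is encoded as follows.
  #include <cassert>
  #include <cstddef>
  #include <cstdint>
  #include <vector>

  int main() {
    const int64_t rows = 3;
    // Row i owns the nonzeros at positions crow[i] .. crow[i + 1].
    std::vector<int32_t> crow_indices = {0, 2, 2, 4}; // shape (size(0) + 1)
    std::vector<int32_t> col_indices = {0, 2, 1, 3};  // shape (nnz())
    std::vector<double> values = {1, 2, 3, 4};        // shape (nnz())

    // Invariants implied by the layout description above.
    assert(crow_indices.size() == static_cast<std::size_t>(rows) + 1);
    assert(crow_indices.front() == 0);
    assert(static_cast<std::size_t>(crow_indices.back()) == values.size());
    assert(col_indices.size() == values.size());
    return 0;
  }

Note that the empty row 1 shows up only as a repeated entry in `crow_indices` (2, 2); it consumes no storage in `col_indices` or `values`.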

View File

@ -0,0 +1,20 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/SparseTensorUtils.h>
namespace at {
namespace sparse_csr {
using SparseCsrTensor = Tensor;
inline SparseCsrTensorImpl* get_sparse_csr_impl(const SparseCsrTensor& self) {
AT_ASSERTM(
self.is_sparse_csr(),
"_internal_get_SparseCsrTensorImpl: not a sparse CSR tensor");
return static_cast<SparseCsrTensorImpl*>(self.unsafeGetTensorImpl());
}
} // namespace sparse_csr
} // namespace at

View File

@ -34,6 +34,10 @@ class TORCH_API DeprecatedTypeProperties {
return layout_from_backend(backend()) == kSparse;
}
bool is_sparse_csr() const {
return layout_from_backend(backend()) == kSparseCsr;
}
DeviceType device_type() const {
return backendToDeviceType(backend_);
}

View File

@ -402,6 +402,7 @@ _(aten, is_same_size) \
_(aten, is_set_to) \
_(aten, is_signed) \
_(aten, is_sparse) \
_(aten, is_sparse_csr) \
_(aten, isclose) \
_(aten, isreal) \
_(aten, istft) \

View File

@ -85,7 +85,7 @@ Tensor& resize_as_(
!optional_memory_format.has_value(),
"Unsupported memory format for sparse tensor resize_as_ :",
optional_memory_format.value());
return native::resize_as_sparse_(self, the_template);
return at::native::resize_as_sparse_(self, the_template);
}
Tensor& result = self.resize_(the_template.sizes());
if (optional_memory_format.has_value()) {

View File

@ -30,6 +30,10 @@ bool is_sparse(const Tensor& self) {
return self.is_sparse();
}
bool is_sparse_csr(const Tensor& self) {
return self.is_sparse_csr();
}
bool is_quantized(const Tensor& self) {
return self.is_quantized();
}

View File

@ -0,0 +1,236 @@
#include <ATen/native/mkl/SparseCsrLinearAlgebra.h>
// Don't compile with MKL for MSVC/macOS since linking the sparse MKL routines
// needs some build fixes.
// https://github.com/pytorch/pytorch/pull/50937#issuecomment-778732740
// Macros source:
// https://web.archive.org/web/20191012035921/http://nadeausoftware.com/articles/2012/01/c_c_tip_how_use_compiler_predefined_macros_detect_operating_system
#if !AT_MKL_ENABLED() || defined(_MSC_VER) || defined(__APPLE__) || \
defined(__MACH__)
namespace at {
namespace sparse_csr {
Tensor& _sparse_mm_mkl_(
Tensor& self,
const SparseCsrTensor& sparse_,
const Tensor& dense,
const Tensor& t,
const Scalar& alpha,
const Scalar& beta) {
#if _MSC_VER
AT_ERROR("sparse_mm_mkl: MKL support is disabled on Windows");
#elif __APPLE__ || __MACH__
AT_ERROR("sparse_mm_mkl: MKL support is disabled on macos/iOS.");
#else
AT_ERROR("sparse_mm_mkl: ATen not compiled with MKL support");
#endif
return self; // suppress compiler warnings about a missing return value.
}
} // namespace sparse_csr
} // namespace at
#else // AT_MKL_ENABLED
#include <ATen/mkl/Descriptors.h>
#include <ATen/mkl/Exceptions.h>
#include <ATen/mkl/Limits.h>
#include <mkl.h>
#include <mkl_spblas.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/SparseCsrTensorImpl.h>
namespace at {
namespace sparse_csr {
#ifdef MKL_ILP64
static constexpr ScalarType TORCH_INT_TYPE = at::kLong;
#else
static constexpr ScalarType TORCH_INT_TYPE = at::kInt;
#endif
class SparseCsrMKLInterface {
private:
sparse_matrix_t A = 0;
matrix_descr desc;
public:
SparseCsrMKLInterface(
MKL_INT* col_indices,
MKL_INT* crow_indices,
double* values,
MKL_INT nrows,
MKL_INT ncols) {
desc.type = SPARSE_MATRIX_TYPE_GENERAL;
int retval = mkl_sparse_d_create_csr(
&A,
SPARSE_INDEX_BASE_ZERO,
nrows,
ncols,
crow_indices,
crow_indices + 1,
col_indices,
values);
TORCH_CHECK(
retval == 0,
"mkl_sparse_d_create_csr failed with error code: ",
retval);
}
SparseCsrMKLInterface(
MKL_INT* col_indices,
MKL_INT* crow_indices,
float* values,
MKL_INT nrows,
MKL_INT ncols) {
desc.type = SPARSE_MATRIX_TYPE_GENERAL;
int retval = mkl_sparse_s_create_csr(
&A,
SPARSE_INDEX_BASE_ZERO,
nrows,
ncols,
crow_indices,
crow_indices + 1,
col_indices,
values);
TORCH_CHECK(
retval == 0,
"mkl_sparse_s_create_csr failed with error code: ",
retval);
}
inline void sparse_mm(
float* res,
float* dense,
float alpha,
float beta,
MKL_INT nrows,
MKL_INT ncols,
MKL_INT dense_ncols) {
int stat = mkl_sparse_s_mm(
SPARSE_OPERATION_NON_TRANSPOSE,
alpha,
A,
desc,
SPARSE_LAYOUT_ROW_MAJOR,
dense,
dense_ncols,
dense_ncols,
beta,
res,
dense_ncols);
TORCH_CHECK(stat == 0, "mkl_sparse_s_mm failed with error code: ", stat);
}
inline void sparse_mm(
double* res,
double* dense,
double alpha,
double beta,
MKL_INT nrows,
MKL_INT ncols,
MKL_INT dense_ncols) {
int stat = mkl_sparse_d_mm(
SPARSE_OPERATION_NON_TRANSPOSE,
alpha,
A,
desc,
SPARSE_LAYOUT_ROW_MAJOR,
dense,
dense_ncols,
dense_ncols,
beta,
res,
dense_ncols);
TORCH_CHECK(stat == 0, "mkl_sparse_d_mm failed with error code: ", stat);
}
~SparseCsrMKLInterface() {
mkl_sparse_destroy(A);
}
};
template <typename scalar_t>
static inline void sparse_mm_mkl_template(
Tensor& res,
const Tensor& col_indices,
const Tensor& crow_indices,
const Tensor& values,
const Tensor& dense,
const Tensor& t,
const Scalar& alpha,
const Scalar& beta,
IntArrayRef size,
IntArrayRef dense_size) {
SparseCsrMKLInterface mkl_impl(
col_indices.data_ptr<MKL_INT>(),
crow_indices.data_ptr<MKL_INT>(),
values.data_ptr<scalar_t>(),
size[0],
size[1]);
mkl_impl.sparse_mm(
res.data_ptr<scalar_t>(),
dense.data_ptr<scalar_t>(),
alpha.to<scalar_t>(),
beta.to<scalar_t>(),
size[0],
size[1],
dense_size[1]);
}
static bool inline constexpr is_mkl_int32_index() {
#ifdef MKL_ILP64
return false;
#else
return true;
#endif
}
Tensor& _sparse_mm_mkl_(
Tensor& self,
const SparseCsrTensor& sparse_,
const Tensor& dense,
const Tensor& t,
const Scalar& alpha,
const Scalar& beta) {
if (is_mkl_int32_index()) {
if (sparse_.crow_indices().scalar_type() != kInt) {
TORCH_WARN(
"PyTorch is compiled with MKL LP64 and will convert crow_indices to int32.");
}
if (sparse_.col_indices().scalar_type() != kInt) {
TORCH_WARN(
"PyTorch is compiled with MKL LP64 and will convert col_indices to int32.");
}
} else { // This is for future-proofing if we ever switch to MKL ILP64.
if (sparse_.crow_indices().scalar_type() != kLong) {
TORCH_WARN(
"PyTorch is compiled with MKL ILP64 and will convert crow_indices dtype to int64.");
}
if (sparse_.col_indices().scalar_type() != kLong) {
TORCH_WARN(
"PyTorch is compiled with MKL ILP64 and will convert col_indices dtype to int64.");
}
}
AT_DISPATCH_FLOATING_TYPES(
dense.scalar_type(), "addmm_sparse_csr_dense", [&] {
sparse_mm_mkl_template<scalar_t>(
self,
sparse_.col_indices().to(TORCH_INT_TYPE),
sparse_.crow_indices().to(TORCH_INT_TYPE),
sparse_.values(),
dense,
t,
alpha,
beta,
sparse_.sizes(),
dense.sizes());
});
return self;
}
} // namespace sparse_csr
} // namespace at
#endif // AT_MKL_ENABLED
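
For readers unfamiliar with the inspector-executor Sparse BLAS calls wrapped above: the class asks MKL for `res = alpha * A_csr * dense + beta * res`, with `dense` and `res` in row-major layout and leading dimension `dense_ncols`. A plain-C++ sketch of that product (illustration only, hypothetical names, no MKL dependency) is:

  #include <cstdint>
  #include <vector>

  // Reference for the operation delegated to mkl_sparse_s_mm / mkl_sparse_d_mm
  // in the wrapper: res = alpha * A * dense + beta * res, row-major operands.
  void csr_mm_reference(
      const std::vector<int32_t>& crow,  // nrows + 1
      const std::vector<int32_t>& col,   // nnz
      const std::vector<double>& val,    // nnz
      const std::vector<double>& dense,  // ncols x dense_ncols, row-major
      std::vector<double>& res,          // nrows x dense_ncols, row-major
      int64_t nrows,
      int64_t dense_ncols,
      double alpha,
      double beta) {
    for (int64_t i = 0; i < nrows; ++i) {
      for (int64_t j = 0; j < dense_ncols; ++j) {
        res[i * dense_ncols + j] *= beta;
      }
      for (int32_t p = crow[i]; p < crow[i + 1]; ++p) {
        const double a = alpha * val[p];
        const int64_t k = col[p]; // column of A == row of the dense operand
        for (int64_t j = 0; j < dense_ncols; ++j) {
          res[i * dense_ncols + j] += a * dense[k * dense_ncols + j];
        }
      }
    }
  }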

View File

@ -0,0 +1,14 @@
#include <ATen/ATen.h>
#include <ATen/SparseCsrTensorUtils.h>
namespace at {
namespace sparse_csr {
Tensor& _sparse_mm_mkl_(
Tensor& self,
const SparseCsrTensor& sparse_,
const Tensor& dense,
const Tensor& t,
const Scalar& alpha,
const Scalar& beta);
} // namespace sparse_csr
} // namespace at

View File

@ -330,6 +330,7 @@
variants: function, method
dispatch:
SparseCPU, SparseCUDA: add_sparse
SparseCsrCPU: add_sparse_csr
MkldnnCPU: mkldnn_add
- func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
@ -337,6 +338,7 @@
structured_delegate: add.out
dispatch:
SparseCPU, SparseCUDA: add_sparse_
SparseCsrCPU: add_sparse_csr_
MkldnnCPU: mkldnn_add_
- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
@ -346,6 +348,7 @@
CPU, CUDA: add_out
SparseCPU: add_out_sparse_cpu
SparseCUDA: add_out_sparse_cuda
SparseCsrCPU: add_out_sparse_csr_cpu
MkldnnCPU: mkldnn_add_out
- func: _add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
@ -388,6 +391,7 @@
dispatch:
CPU, CUDA: addmv_out
- func: _addmv_impl_(Tensor(a!) self, Tensor self2, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
dispatch:
CPU: addmv_impl_cpu
@ -2551,13 +2555,14 @@
dispatch:
CPU: mm_cpu
CUDA: mm_cuda
SparseCPU, SparseCUDA: _sparse_mm
SparseCPU, SparseCUDA, SparseCsrCPU: _sparse_mm
- func: mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU: mm_cpu_out
CUDA: mm_out_cuda
SparseCPU, SparseCUDA: _sparse_mm_out
SparseCsrCPU: _sparse_csr_mm_out
- func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
@ -2638,7 +2643,7 @@
variants: function, method
dispatch:
CPU, CUDA: mv
SparseCPU, SparseCUDA: mv_sparse
SparseCPU, SparseCUDA, SparseCsrCPU: mv_sparse
- func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
@ -4001,6 +4006,12 @@
dispatch:
CompositeExplicitAutograd: resize_as_
- func: resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!)
variants: function
dispatch:
SparseCPU, SparseCUDA: resize_as_sparse_
SparseCsrCPU: resize_as_sparse_csr_
- func: zero_(Tensor(a!) self) -> Tensor(a!)
variants: method, function
dispatch:
@ -4088,6 +4099,7 @@
CUDA: addmm_out_cuda
SparseCPU: addmm_out_sparse_dense_cpu
SparseCUDA: addmm_out_sparse_dense_cuda
SparseCsrCPU: addmm_out_sparse_csr_dense_cpu
- func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
variants: function, method
@ -4096,6 +4108,7 @@
CUDA: addmm_cuda
SparseCPU: addmm_sparse_dense_cpu
SparseCUDA: addmm_sparse_dense_cuda
SparseCsrCPU: addmm_sparse_csr_dense_cpu
- func: addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
variants: method
@ -4215,9 +4228,13 @@
# the view relation is not tracked by autograd, but the version counter is still
# shared. In other words, their outputs are non-differentiable views of the
# sparse tensor.
# FIXME: would be nicer if TensorOptions was optional based; not adding default arguments for options given
# the default would never make sense.
- func: sparse_csr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
- func: sparse_csr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
- func: sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
- func: sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
@ -4255,7 +4272,7 @@
- func: to_dense(Tensor self, ScalarType? dtype=None) -> Tensor
variants: method
dispatch:
SparseCPU, SparseCUDA: sparse_to_dense
SparseCPU, SparseCUDA, SparseCsrCPU: sparse_to_dense
MkldnnCPU: mkldnn_to_dense
- func: to_dense_backward(Tensor grad, Tensor input) -> Tensor
@ -4290,6 +4307,7 @@
variants: method
dispatch:
SparseCPU, SparseCUDA: _nnz_sparse
SparseCsrCPU: _nnz_sparse_csr
device_guard: False
# NOTE: [ coalesce autograd ]
@ -4342,6 +4360,19 @@
variants: method
dispatch:
SparseCPU, SparseCUDA: values_sparse
SparseCsrCPU: values_sparse_csr
device_guard: False
- func: crow_indices(Tensor(a) self) -> Tensor(a)
variants: method
dispatch:
SparseCsrCPU: crow_indices_sparse_csr
device_guard: False
- func: col_indices(Tensor(a) self) -> Tensor(a)
variants: method
dispatch:
SparseCsrCPU: col_indices_sparse_csr
device_guard: False
- func: hspmm.out(Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)

View File

@ -0,0 +1,159 @@
// Basic functions on sparse tensors
#include <ATen/ATen.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/Layout.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/SparseTensorImpl.h>
namespace at {
namespace native {
using namespace at::sparse_csr;
// Construction of CSR tensors.
SparseCsrTensor new_csr_tensor(const TensorOptions& options) {
// TODO: remove this comment after enabling autograd support for CSR tensor
// constructor.
// TORCH_INTERNAL_ASSERT(impl::variable_excluded_from_dispatch());
TORCH_INTERNAL_ASSERT(options.layout() == kSparseCsr);
DispatchKey dispatch_key;
if (options.device().is_cuda()) {
dispatch_key = DispatchKey::SparseCsrCUDA;
} else {
TORCH_INTERNAL_ASSERT(options.device().is_cpu());
dispatch_key = DispatchKey::SparseCsrCPU;
}
return detail::make_tensor<SparseCsrTensorImpl>(
DispatchKeySet(dispatch_key), options.dtype());
}
// TODO: This constructor should probably use an ATen abstract method in order
// to make autograd dispatch available for the CSR constructor. See the relevant
// note in native_functions.yaml.
Tensor sparse_csr_tensor(
const Tensor& crow_indices,
const Tensor& col_indices,
const Tensor& values,
IntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory) {
// See [Note: hacky wrapper removal for TensorOptions]
TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
TORCH_CHECK(
options.layout() == kSparseCsr,
"expected sparse CSR layout, but got layout ",
options.layout());
AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "csr_construct_check", [&] {
auto crow_indices_accessor = crow_indices.accessor<index_t, 1>();
TORCH_CHECK(
crow_indices_accessor[crow_indices.numel() - 1] <= col_indices.numel(),
"last value of crow_indices should be less than length of col_indices.");
TORCH_CHECK(
crow_indices_accessor[0] == 0, "0th value of crow_indices must be 0.");
});
TORCH_CHECK(
crow_indices.dim() == 1,
"crow_indices must have dim=1 but got crow_indices.dim()=",
crow_indices.dim());
TORCH_CHECK(
col_indices.dim() == 1,
"col_indices must have dim=1 but got col_indices.dim()=",
col_indices.dim());
TORCH_CHECK(
values.dim() == 1,
"values must have dim=1 but got values.dim()=",
values.dim());
TORCH_CHECK(
(crow_indices.numel() - 1) == size[0],
"crow_indices.numel() must be size(0) + 1, but got: ",
crow_indices.numel());
SparseCsrTensor self = new_csr_tensor(options);
get_sparse_csr_impl(self)->resize_and_clear_(values.numel(), size);
get_sparse_csr_impl(self)->set_member_tensors(
crow_indices, col_indices, values);
return self;
}
Tensor sparse_csr_tensor(
const Tensor& crow_indices,
const Tensor& col_indices,
const Tensor& values,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory) {
// See [Note: hacky wrapper removal for TensorOptions]
TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
TORCH_CHECK(
options.layout() == kSparseCsr,
"expected sparse CSR layout, but got layout ",
options.layout());
TORCH_CHECK(crow_indices.numel() >= 1, "expected crow_indices.numel() >= 1, but got ",
crow_indices.numel());
std::array<int64_t, 2> size;
if (col_indices.numel() > 0) {
size[0] = crow_indices.numel() - 1;
Tensor max_col_indices = std::get<0>(col_indices.max(0, false));
size[1] = *max_col_indices.data_ptr<int64_t>() + 1;
} else {
size[0] = 0;
size[1] = 0;
}
return at::sparse_csr_tensor(
crow_indices, col_indices, values, size, options);
}
// Access members of CSR tensors.
int64_t _nnz_sparse_csr(const SparseCsrTensor& self) {
return get_sparse_csr_impl(self)->nnz();
}
Tensor values_sparse_csr(const Tensor& self) {
return get_sparse_csr_impl(self)->values().alias();
}
Tensor crow_indices_sparse_csr(const Tensor& self) {
return get_sparse_csr_impl(self)->crow_indices().alias();
}
Tensor col_indices_sparse_csr(const Tensor& self) {
return get_sparse_csr_impl(self)->col_indices().alias();
}
bool _is_same_size_as_sparse_csr(
const SparseCsrTensor& self,
const SparseCsrTensor& src) {
return self.sizes().equals(src.sizes());
}
SparseCsrTensor& resize_as_sparse_csr_(
SparseCsrTensor& self,
const SparseCsrTensor& src) {
TORCH_CHECK(
src.is_sparse_csr() && self.is_sparse_csr(),
"resize_as_sparse_csr_: layout for self and src must be sparse_csr but got self, src: ",
self.layout(),
src.layout());
if (!_is_same_size_as_sparse_csr(self, src)) {
get_sparse_csr_impl(self)->resize_as_sparse_csr_tensor_(src);
}
return self;
}
} // namespace native
} // namespace at
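
As a usage sketch of the constructor and accessors defined above (a sketch, not a definitive example: it assumes a libtorch build containing this patch, the generated `at::sparse_csr_tensor` overload taking `TensorOptions` that this file itself calls, the `at::kSparseCsr` layout constant referenced elsewhere in the patch, and the standard `at::from_blob` factory):

  #include <ATen/ATen.h>
  #include <cstdint>
  #include <iostream>
  #include <vector>

  int main() {
    // CSR encoding of the 3x4 matrix from the header illustration.
    std::vector<int64_t> crow = {0, 2, 2, 4};
    std::vector<int64_t> col = {0, 2, 1, 3};
    std::vector<double> vals = {1.0, 2.0, 3.0, 4.0};

    auto index_opts = at::TensorOptions().dtype(at::kLong);
    at::Tensor crow_indices =
        at::from_blob(crow.data(), {static_cast<int64_t>(crow.size())}, index_opts).clone();
    at::Tensor col_indices =
        at::from_blob(col.data(), {static_cast<int64_t>(col.size())}, index_opts).clone();
    at::Tensor values =
        at::from_blob(vals.data(), {static_cast<int64_t>(vals.size())},
                      at::TensorOptions().dtype(at::kDouble)).clone();

    // The options must carry the new kSparseCsr layout, as checked above.
    at::Tensor csr = at::sparse_csr_tensor(
        crow_indices, col_indices, values, /*size=*/{3, 4},
        at::TensorOptions().dtype(at::kDouble).layout(at::kSparseCsr));

    std::cout << csr._nnz() << "\n";                        // 4
    std::cout << csr.crow_indices().numel() << "\n";        // 4 (size(0) + 1)
    std::cout << csr.values().sum().item<double>() << "\n"; // 10
    return 0;
  }

The size-less overload above would instead infer the shape as `{crow_indices.numel() - 1, col_indices.max() + 1}` when `col_indices` is non-empty.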

View File

@ -0,0 +1,331 @@
#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/mkl/SparseCsrLinearAlgebra.h>
#include <algorithm>
namespace at {
namespace native {
using namespace at::sparse_csr;
// Certain utility functions are reusable from sparse COO.
using namespace at::sparse;
static constexpr bool is_msvc() {
#ifdef _MSC_VER
return true;
#else
return false;
#endif
}
// Functions for matrix multiplication.
Tensor& addmm_out_sparse_csr_dense_cpu(
const Tensor& self,
const SparseCsrTensor& op1,
const Tensor& op2,
const Scalar& beta,
const Scalar& alpha,
Tensor& out) {
AT_ASSERT(op1.is_sparse_csr());
Tensor expand_self = *expand_size(self, {op1.size(0), op2.size(1)}, "addmm_out_sparse_csr");
AT_ASSERT(expand_self.device().type() == kCPU);
TORCH_CHECK(
out.device().type() == kCPU,
"addmm: expected 'out' to be CPU tensor, but got CUDA tensor");
TORCH_CHECK(
op1.device().type() == kCPU,
"addmm: expected 'mat1' to be a CPU tensor, but got a CUDA tensor");
TORCH_CHECK(
op2.device().type() == kCPU,
"addmm: expected 'mat2' to be a CPU tensor, but got a CUDA tensor");
TORCH_CHECK(
op1.dim() == 2,
"addmm: 2-D matrices expected, got ",
op1.dim(),
"D tensor");
TORCH_CHECK(
op2.dim() == 2,
"addmm: 2-D matrices expected, got ",
op2.dim(),
"D tensor");
TORCH_CHECK(
out.is_contiguous(),
"out argument must be contiguous, but got: ",
out.suggest_memory_format());
// ixk * kxj = ixj
int64_t dim_i = op1.size(0);
int64_t dim_j = op2.size(1);
int64_t dim_k = op1.size(1);
TORCH_CHECK(
op2.size(0) == dim_k,
"addmm: Expected dense matrix (op2) size(0)=",
dim_k,
", got ",
op2.size(0));
TORCH_CHECK(
op1.size(1) == dim_k,
"addmm: Expected sparse matrix (op1) size(1)=",
dim_k,
", got ",
op1.size(1));
out.resize_({dim_i, dim_j});
auto col_indices = op1.col_indices();
auto crow_indices = op1.crow_indices();
auto values = op1.values();
AT_DISPATCH_FLOATING_TYPES(
values.scalar_type(), "addmm_sparse_csr_dense", [&] {
scalar_t cast_beta = beta.to<scalar_t>();
if (!is_same_tensor(out, expand_self)) {
out.copy_(expand_self);
}
if (cast_beta == 0) {
out.zero_();
} else {
at::mul_out(out, expand_self, scalar_to_tensor(beta));
}
});
// Do not use MKL for Windows due to linking issues with sparse MKL routines.
if (at::hasMKL() && !is_msvc()) {
_sparse_mm_mkl_(out, op1, op2, expand_self, alpha, beta);
} else {
int64_t dense_stride0 = op1.stride(0);
int64_t dense_stride1 = op1.stride(1);
int64_t out_stride0 = out.stride(0);
int64_t out_stride1 = out.stride(1);
AT_DISPATCH_FLOATING_TYPES(
values.scalar_type(),
"sparse_csr_mm_cpu",
[&alpha,
&beta,
&op1,
&out,
&values,
&crow_indices,
&col_indices,
&dense_stride0,
&dense_stride1,
&out_stride0,
&out_stride1,
&dim_k]() {
AT_DISPATCH_INDEX_TYPES(
crow_indices.scalar_type(),
"csr_mm_crow_indices",
[&alpha,
&beta,
&op1,
&out,
&values,
&crow_indices,
&col_indices,
&dense_stride0,
&dense_stride1,
&out_stride0,
&out_stride1,
&dim_k]() {
scalar_t cast_alpha = alpha.to<scalar_t>();
scalar_t cast_beta = beta.to<scalar_t>();
scalar_t* dense_ptr = op1.data_ptr<scalar_t>();
scalar_t* out_ptr = out.data_ptr<scalar_t>();
auto col_indices_accessor = col_indices.accessor<index_t, 1>();
auto crow_indices_accessor =
crow_indices.accessor<index_t, 1>();
auto values_accessor = values.accessor<scalar_t, 1>();
at::parallel_for(
0,
crow_indices.size(0) - 1,
internal::GRAIN_SIZE,
[&](int64_t irow_start, int64_t irow_end) {
for (int irow = irow_start; irow < irow_end; ++irow) {
int start_index = crow_indices_accessor[irow];
int end_index = crow_indices_accessor[irow + 1];
for (int i = start_index; i < end_index; ++i) {
auto val = values_accessor[i];
auto icol = col_indices_accessor[i];
at::native::cpublas::axpy<scalar_t>(
dim_k,
cast_alpha * val,
dense_ptr + icol * dense_stride0,
dense_stride1,
out_ptr + irow * out_stride0,
out_stride1);
}
}
});
});
});
}
return out;
}
Tensor addmm_sparse_csr_dense_cpu(
const Tensor& self,
const SparseCsrTensor& sparse,
const Tensor& dense,
const Scalar& beta,
const Scalar& alpha) {
Tensor r = at::empty({0}, self.options());
at::addmm_out(r, self, sparse, dense, beta, alpha);
return r;
}
SparseCsrTensor& _sparse_csr_mm_out(
const SparseCsrTensor& sparse,
const Tensor& dense,
SparseCsrTensor& result) {
Tensor t = at::zeros({}, dense.options());
return at::addmm_out(result, t, sparse, dense, 0.0, 1.0); // redispatch!
}
Tensor _sparse_csr_addmm(
const Tensor& t,
const SparseCsrTensor& sparse,
const Tensor& dense,
const Scalar& beta,
const Scalar& alpha) {
// _sparse_addmm forward is functionally equivalent to addmm; it's
// just the backward that is different. This technically does an
// unnecessary redispatch, I was too lazy to make it not do that
return at::addmm(t, sparse, dense, beta, alpha);
}
// Functions for element-wise addition.
Tensor add_sparse_csr(const Tensor& self, const Tensor& other, const Scalar& alpha) {
auto commonDtype = at::result_type(self, other);
alpha_check(commonDtype, alpha);
Tensor result = at::empty({0}, self.options().dtype(commonDtype));
return at::add_out(result, self, other, alpha); // redispatch!
}
Tensor& add_sparse_csr_(Tensor& self, const Tensor& other, const Scalar& alpha) {
return at::add_out(self, self, other, alpha); // redispatch!
}
Tensor& add_out_dense_sparse_csr_cpu(
Tensor& out,
const Tensor& dense,
const SparseCsrTensor& src,
const Scalar& alpha) {
AT_ASSERT(dense.layout() == kStrided);
AT_ASSERT(src.is_sparse_csr());
AT_ASSERT(dense.device() == kCPU);
TORCH_CHECK(
out.is_contiguous(),
"out argument must be contiguous, but got: ",
out.suggest_memory_format());
TORCH_CHECK(
out.device() == kCPU,
"add: expected 'out' to be CPU tensor, but got tensor on device: ",
out.device());
TORCH_CHECK(
src.device() == kCPU,
"add: expected 'other' to be a CPU tensor, but got tensor on device: ",
src.device());
TORCH_CHECK(
dense.sizes().equals(src.sizes()),
"add: expected 'self' and 'other' to have same size, but self has size ",
dense.sizes(),
" while other has size ",
src.sizes(),
" (FYI: op2-sparse addition does not currently support broadcasting)");
auto commonDtype = promoteTypes(dense.scalar_type(), src.scalar_type());
TORCH_CHECK(
canCast(commonDtype, out.scalar_type()),
"Can't convert result type ",
commonDtype,
" to output ",
out.scalar_type(),
" in add operation");
auto src_values = src.values().to(commonDtype);
auto src_crow_indices = src.crow_indices();
auto src_col_indices = src.col_indices();
out.resize_as_(dense);
Tensor resultBuffer = out;
Tensor valuesBuffer = src_values.to(commonDtype);
if (out.scalar_type() != commonDtype) {
resultBuffer = dense.to(commonDtype);
} else if (!is_same_tensor(out, dense)) {
resultBuffer.copy_(dense);
}
AT_DISPATCH_ALL_TYPES(
commonDtype,
"add_out_op2_sparse_csr",
[&src_values, &out, &alpha, &src_crow_indices, &src_col_indices]() {
AT_DISPATCH_INDEX_TYPES(
src_crow_indices.scalar_type(),
"csr_add_out_crow_indices",
[&src_values, &out, &alpha, &src_crow_indices, &src_col_indices]() {
auto values_accessor = src_values.accessor<scalar_t, 1>();
scalar_t* out_ptr = out.data_ptr<scalar_t>();
scalar_t cast_value = alpha.to<scalar_t>();
auto crow_indices_accessor =
src_crow_indices.accessor<index_t, 1>();
auto col_indices_accessor =
src_col_indices.accessor<index_t, 1>();
auto out_strides0 = out.strides()[0];
auto out_strides1 = out.strides()[1];
for (int32_t irow = 0; irow < src_crow_indices.size(0) - 1;
++irow) {
int32_t start_index = crow_indices_accessor[irow];
int32_t end_index = crow_indices_accessor[irow + 1];
for (int i = start_index; i < end_index; ++i) {
auto icol = col_indices_accessor[i];
auto index = out.storage_offset() + irow * out_strides0 +
icol * out_strides1;
out_ptr[index] += cast_value * values_accessor[i];
}
}
});
});
return out;
}
Tensor& add_out_sparse_csr_cpu(
const Tensor& self,
const SparseCsrTensor& other,
const Scalar& alpha,
SparseCsrTensor& out) {
if (self.layout() == kStrided) {
return add_out_dense_sparse_csr_cpu(out, self, other, alpha);
} else {
TORCH_CHECK(
false,
"NotImplementedError: Addition of sparse CSR tensors is not yet implemented.")
}
return out;
}
} // namespace native
} // namespace at
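
The non-MKL addmm fallback and `add_out_dense_sparse_csr_cpu` above both reduce to the same traversal: walk `crow_indices` row by row and scatter scaled entries into a strided dense output. As a compact reference for the element-wise add path, here is a plain-C++ sketch (illustration only, hypothetical names, contiguous row-major output assumed):

  #include <cstdint>
  #include <vector>

  // Reference for the core of add_out_dense_sparse_csr_cpu:
  // out += alpha * csr, with `out` stored contiguously in row-major order.
  void add_csr_into_dense_reference(
      std::vector<double>& out,           // rows x cols, row-major
      const std::vector<int32_t>& crow,   // rows + 1
      const std::vector<int32_t>& col,    // nnz
      const std::vector<double>& values,  // nnz
      int64_t cols,
      double alpha) {
    const int64_t rows = static_cast<int64_t>(crow.size()) - 1;
    for (int64_t irow = 0; irow < rows; ++irow) {
      for (int32_t p = crow[irow]; p < crow[irow + 1]; ++p) {
        out[irow * cols + col[p]] += alpha * values[p];
      }
    }
  }

The addmm fallback does the analogous per-nonzero work, but instead of touching a single output element it adds `alpha * value` times a whole row of the dense operand via `cpublas::axpy`.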

View File

@ -1,17 +1,18 @@
// Basic functions on sparse tensors
#include <ATen/ATen.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/Layout.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/NativeFunctions.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/CPUBlas.h>
namespace at { namespace native {
namespace at {
namespace native {
using namespace at::sparse;
@ -36,7 +37,8 @@ int64_t _nnz_sparse(const SparseTensor& self) {
}
// Why are there so many methods to get indices and value?
// See Note [ Sparse: different methods to get indices and values ] in native_functions.yaml
// See Note [ Sparse: different methods to get indices and values ] in
// native_functions.yaml
Tensor _indices_sparse(const SparseTensor& self) {
return get_sparse_impl(self)->indices();
@ -46,20 +48,22 @@ Tensor _values_sparse(const SparseTensor& self) {
return get_sparse_impl(self)->values();
}
Tensor &_coalesced_sparse_(SparseTensor& self, bool coalesced) {
Tensor& _coalesced_sparse_(SparseTensor& self, bool coalesced) {
get_sparse_impl(self)->set_coalesced(coalesced);
return self;
}
Tensor indices_sparse(const Tensor& self) {
TORCH_CHECK(self.is_coalesced(),
"Cannot get indices on an uncoalesced tensor, please call .coalesce() first");
TORCH_CHECK(
self.is_coalesced(),
"Cannot get indices on an uncoalesced tensor, please call .coalesce() first");
return get_sparse_impl(self)->indices().alias();
}
Tensor values_sparse(const Tensor& self) {
TORCH_CHECK(self.is_coalesced(),
"Cannot get values on an uncoalesced tensor, please call .coalesce() first");
TORCH_CHECK(
self.is_coalesced(),
"Cannot get values on an uncoalesced tensor, please call .coalesce() first");
return get_sparse_impl(self)->values().alias();
}
@ -70,7 +74,11 @@ Tensor values_sparse(const Tensor& self) {
/*** Helper methods ***/
SparseTensor new_sparse(c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory) {
SparseTensor new_sparse(
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory) {
AT_ASSERT(layout.has_value() && *layout == kSparse);
DispatchKey dispatch_key;
if (device_or_default(device).is_cuda()) {
@ -81,13 +89,20 @@ SparseTensor new_sparse(c10::optional<ScalarType> dtype, c10::optional<Layout> l
dispatch_key = DispatchKey::SparseCPU;
}
return detail::make_tensor<SparseTensorImpl>(
DispatchKeySet(dispatch_key), scalarTypeToTypeMeta(dtype_or_default(dtype)));
DispatchKeySet(dispatch_key),
scalarTypeToTypeMeta(dtype_or_default(dtype)));
}
/** Actual dispatched creation methods ***/
SparseTensor new_with_dims_sparse(int64_t sparse_dim, int64_t dense_dim, ArrayRef<int64_t> size, c10::optional<ScalarType> dtype,
c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory) {
SparseTensor new_with_dims_sparse(
int64_t sparse_dim,
int64_t dense_dim,
ArrayRef<int64_t> size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory) {
SparseTensor self = new_sparse(dtype, layout, device, pin_memory);
get_sparse_impl(self)->resize_and_clear_(sparse_dim, dense_dim, size);
return self;
@ -105,15 +120,18 @@ SparseTensor new_with_dims_and_tensor_sparse(
c10::optional<bool> pin_memory) {
SparseTensor self = new_sparse(dtype, layout, device, pin_memory);
get_sparse_impl(self)->resize_(sparse_dim, dense_dim, size);
// NOTE: There is no guarantee that `indices` and `values` don't contain AutogradMeta. However,
// we want to maintain the invariant that `indices_` and `values_` of a sparse tensor don't
// contain AutogradMeta, and to achieve that we shallow-copy `indices` and `values` here.
auto indices_shallow_copy = Tensor(indices.unsafeGetTensorImpl()->shallow_copy_and_detach(
/*version_counter=*/indices.unsafeGetTensorImpl()->version_counter(),
/*allow_tensor_metadata_change=*/true));
auto values_shallow_copy = Tensor(values.unsafeGetTensorImpl()->shallow_copy_and_detach(
/*version_counter=*/values.unsafeGetTensorImpl()->version_counter(),
/*allow_tensor_metadata_change=*/true));
// NOTE: There is no guarantee that `indices` and `values` don't contain
// AutogradMeta. However, we want to maintain the invariant that `indices_`
// and `values_` of a sparse tensor don't contain AutogradMeta, and to achieve
// that we shallow-copy `indices` and `values` here.
auto indices_shallow_copy =
Tensor(indices.unsafeGetTensorImpl()->shallow_copy_and_detach(
/*version_counter=*/indices.unsafeGetTensorImpl()->version_counter(),
/*allow_tensor_metadata_change=*/true));
auto values_shallow_copy =
Tensor(values.unsafeGetTensorImpl()->shallow_copy_and_detach(
/*version_counter=*/values.unsafeGetTensorImpl()->version_counter(),
/*allow_tensor_metadata_change=*/true));
alias_into_sparse(self, indices_shallow_copy, values_shallow_copy);
return self;
}
@ -121,9 +139,18 @@ SparseTensor new_with_dims_and_tensor_sparse(
/** Public creation API that dispatch to methods above **/
/** Empty init **/
Tensor empty_sparse(IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<MemoryFormat> optional_memory_format) {
TORCH_CHECK(!pin_memory.has_value() || !*pin_memory, "Only dense CPU tensors can be pinned");
return new_with_dims_sparse(size.size(), 0, size, dtype, layout, device, pin_memory);
Tensor empty_sparse(
IntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory,
c10::optional<MemoryFormat> optional_memory_format) {
TORCH_CHECK(
!pin_memory.has_value() || !*pin_memory,
"Only dense CPU tensors can be pinned");
return new_with_dims_sparse(
size.size(), 0, size, dtype, layout, device, pin_memory);
}
/* Shape init */
@ -142,16 +169,16 @@ Tensor sparse_coo_tensor(IntArrayRef size,
// helper
namespace {
static inline Tensor expand_values_if_needed(const Tensor& values) {
// expand
if (values.dim() == 0) {
// Mimic Numpy behavior here and treat it as a 1D tensor
return values.expand({1});
} else {
return values;
}
static inline Tensor expand_values_if_needed(const Tensor& values) {
// expand
if (values.dim() == 0) {
// Mimic Numpy behavior here and treat it as a 1D tensor
return values.expand({1});
} else {
return values;
}
}
} // namespace
Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values_,
c10::optional<ScalarType> dtype,
@ -164,11 +191,21 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values_,
Tensor values = expand_values_if_needed(values_);
// arg checking
TORCH_CHECK(!options.has_layout() || options.layout() == kSparse, "expected sparse layout, but got layout ", options.layout());
// the following checks are redundant because they are also checked in SparseTensorImpl::set_indices_and_values_unsafe
// but we need to ensure them in order to infer the shape.
TORCH_CHECK(indices.dim() == 2, "indices must be sparse_dim x nnz, but got: ", indices.sizes())
TORCH_CHECK(!indices.is_sparse(), "expected indices to be a dense tensor, but got indices of layout ", indices.layout());
TORCH_CHECK(
!options.has_layout() || options.layout() == kSparse,
"expected sparse layout, but got layout ",
options.layout());
// the following checks are redundant because they are also checked in
// SparseTensorImpl::set_indices_and_values_unsafe but we need to ensure them
// in order to infer the shape.
TORCH_CHECK(
indices.dim() == 2,
"indices must be sparse_dim x nnz, but got: ",
indices.sizes())
TORCH_CHECK(
!indices.is_sparse(),
"expected indices to be a dense tensor, but got indices of layout ",
indices.layout());
// If sizes are not given, it is inferred as max index of each dim.
int64_t sparse_dim = indices.size(0);
@ -176,54 +213,86 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values_,
std::vector<int64_t> computed_sizes(sparse_dim + dense_dim);
if (indices.numel() > 0) {
// If the indices has elements in it, we infer the minimum sparse dimension sizes
// as the max value of each dim in indices.
// NB: It used to keepdim. I think that was wrong.
Tensor min_indices = std::get</* values */ 0>(indices.min(/* dim */ 1, /* keepdim */ false));
Tensor computed_indices_sizes = std::get</* values */ 0>(indices.max(/* dim */ 1, /* keepdim */ false));
// If the indices has elements in it, we infer the minimum sparse dimension
// sizes as the max value of each dim in indices. NB: It used to keepdim. I
// think that was wrong.
Tensor min_indices =
std::get</* values */ 0>(indices.min(/* dim */ 1, /* keepdim */ false));
Tensor computed_indices_sizes =
std::get</* values */ 0>(indices.max(/* dim */ 1, /* keepdim */ false));
computed_indices_sizes.add_(1); // len = max_index + 1
Tensor cpu_min_indices = min_indices.to(at::DeviceType::CPU);
Tensor cpu_computed_indices_sizes = computed_indices_sizes.to(at::DeviceType::CPU);
Tensor cpu_computed_indices_sizes =
computed_indices_sizes.to(at::DeviceType::CPU);
auto cpu_min_indices_accessor = cpu_min_indices.accessor<int64_t, 1>();
auto cpu_computed_indices_sizes_accessor = cpu_computed_indices_sizes.accessor<int64_t, 1>();
auto cpu_computed_indices_sizes_accessor =
cpu_computed_indices_sizes.accessor<int64_t, 1>();
for (int64_t d = 0; d < sparse_dim; d++) {
int64_t min_index_in_dim = cpu_min_indices_accessor[d];
TORCH_CHECK(min_index_in_dim >= 0,
"found negative index ", min_index_in_dim, " for dim ", d);
computed_sizes[static_cast<size_t>(d)] = cpu_computed_indices_sizes_accessor[d];
TORCH_CHECK(
min_index_in_dim >= 0,
"found negative index ",
min_index_in_dim,
" for dim ",
d);
computed_sizes[static_cast<size_t>(d)] =
cpu_computed_indices_sizes_accessor[d];
}
} else {
// If the indices doesn't have elements in it, there is not enough information
// to know what the minimum sparse dimension sizes should be, and in this case
// we set them to 0
// If the indices doesn't have elements in it, there is not enough
// information to know what the minimum sparse dimension sizes should be,
// and in this case we set them to 0
for (int64_t d = 0; d < sparse_dim; d++) {
computed_sizes[static_cast<size_t>(d)] = 0;
}
}
for (int64_t d = 0; d < dense_dim; d++) {
computed_sizes[static_cast<size_t>(sparse_dim + d)] = values.size(d+1);
computed_sizes[static_cast<size_t>(sparse_dim + d)] = values.size(d + 1);
}
return at::_sparse_coo_tensor_with_dims_and_tensors(
sparse_dim, dense_dim, computed_sizes, indices, values, values.options().layout(kSparse));
sparse_dim,
dense_dim,
computed_sizes,
indices,
values,
values.options().layout(kSparse));
}
void _validate_sparse_coo_tensor_args(const Tensor& indices, const Tensor& values_, ArrayRef<int64_t> size) {
void _validate_sparse_coo_tensor_args(
const Tensor& indices,
const Tensor& values_,
ArrayRef<int64_t> size) {
Tensor values = expand_values_if_needed(values_);
// the following checks are redundant because they are also checked in SparseTensorImpl::set_indices_and_values_unsafe
// but we need to ensure them in order to infer the shape.
TORCH_CHECK(indices.dim() == 2, "indices must be sparse_dim x nnz, but got: ", indices.sizes())
TORCH_CHECK(!indices.is_sparse(), "expected indices to be a dense tensor, but got indices of layout ", indices.layout());
// the following checks are redundant because they are also checked in
// SparseTensorImpl::set_indices_and_values_unsafe but we need to ensure them
// in order to infer the shape.
TORCH_CHECK(
indices.dim() == 2,
"indices must be sparse_dim x nnz, but got: ",
indices.sizes())
TORCH_CHECK(
!indices.is_sparse(),
"expected indices to be a dense tensor, but got indices of layout ",
indices.layout());
int64_t sparse_dim = indices.size(0);
int64_t dense_dim = values.dim() - 1;
TORCH_CHECK(size.size() == sparse_dim + dense_dim,
"number of dimensions must be sparse_dim (", sparse_dim, ") + dense_dim (", dense_dim, "), but got ", size.size());
TORCH_CHECK(
size.size() == sparse_dim + dense_dim,
"number of dimensions must be sparse_dim (",
sparse_dim,
") + dense_dim (",
dense_dim,
"), but got ",
size.size());
// Check to make sure all indices are within the boundaries of `size`
if (indices.numel() > 0) {
Tensor min_indices = std::get</* values */ 0>(indices.min(/* dim */ 1, /* keepdim */ false));
Tensor max_indices = std::get</* values */ 0>(indices.max(/* dim */ 1, /* keepdim */ false));
Tensor min_indices =
std::get</* values */ 0>(indices.min(/* dim */ 1, /* keepdim */ false));
Tensor max_indices =
std::get</* values */ 0>(indices.max(/* dim */ 1, /* keepdim */ false));
Tensor cpu_min_indices, cpu_max_indices;
if (indices.is_cuda()) {
cpu_min_indices = min_indices.to(at::DeviceType::CPU);
@ -238,17 +307,28 @@ void _validate_sparse_coo_tensor_args(const Tensor& indices, const Tensor& value
// NB: This used to sync ndim times to access each entry; now we copy
// everything to CPU first and then access it.
int64_t min_index_in_dim = cpu_min_indices_accessor[d];
TORCH_CHECK(min_index_in_dim >= 0,
"found negative index ", min_index_in_dim, " for dim ", d);
TORCH_CHECK(
min_index_in_dim >= 0,
"found negative index ",
min_index_in_dim,
" for dim ",
d);
int64_t max_index_in_dim = cpu_max_indices_accessor[d];
int64_t dim_size = size[static_cast<size_t>(d)];
TORCH_CHECK(max_index_in_dim < dim_size,
"size is inconsistent with indices: for dim ", d, ", size is ", dim_size, " but found index ", max_index_in_dim);
TORCH_CHECK(
max_index_in_dim < dim_size,
"size is inconsistent with indices: for dim ",
d,
", size is ",
dim_size,
" but found index ",
max_index_in_dim);
}
}
}
// NB: Got rid of the sizes == NULL case
Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values, IntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
@ -256,9 +336,11 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values, IntArrayRe
c10::optional<bool> pin_memory) {
// See [Note: hacky wrapper removal for TensorOptions]
TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
// arg checking
TORCH_CHECK(!options.has_layout() || options.layout() == kSparse, "expected sparse layout, but got layout ", options.layout());
TORCH_CHECK(
!options.has_layout() || options.layout() == kSparse,
"expected sparse layout, but got layout ",
options.layout());
at::native::_validate_sparse_coo_tensor_args(indices, values, size);
return at::native::_sparse_coo_tensor_unsafe(
@ -291,19 +373,31 @@ Tensor _sparse_coo_tensor_unsafe(const Tensor& indices, const Tensor& values_, I
int64_t dense_dim = values.dim() - 1;
return at::_sparse_coo_tensor_with_dims_and_tensors(
sparse_dim, dense_dim, size, indices, values, values.options().layout(kSparse));
sparse_dim,
dense_dim,
size,
indices,
values,
values.options().layout(kSparse));
}
// NB: Deleted newWithSizeNd variants
SparseTensor clone_sparse(const SparseTensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
SparseTensor clone_sparse(
const SparseTensor& self,
c10::optional<c10::MemoryFormat> optional_memory_format) {
TORCH_CHECK(
!optional_memory_format.has_value(),
"unsupported memory format option ",
optional_memory_format.value());
SparseTensor other = new_with_dims_sparse(self.sparse_dim(), self.dense_dim(), self.sizes(),
optTypeMetaToScalarType(self.options().dtype_opt()), self.options().layout_opt(),
self.options().device_opt(), self.options().pinned_memory_opt());
SparseTensor other = new_with_dims_sparse(
self.sparse_dim(),
self.dense_dim(),
self.sizes(),
optTypeMetaToScalarType(self.options().dtype_opt()),
self.options().layout_opt(),
self.options().device_opt(),
self.options().pinned_memory_opt());
copy_into_sparse(other, self._indices(), self._values(), true);
return other._coalesced_(self.is_coalesced());
}
@ -312,21 +406,32 @@ SparseTensor clone_sparse(const SparseTensor& self, c10::optional<c10::MemoryFor
* reshaping methods
******************************************************************************/
SparseTensor& sparse_resize_(SparseTensor& self, ArrayRef<int64_t> size, int64_t sparse_dim, int64_t dense_dim) {
SparseTensor& sparse_resize_(
SparseTensor& self,
ArrayRef<int64_t> size,
int64_t sparse_dim,
int64_t dense_dim) {
get_sparse_impl(self)->resize_(sparse_dim, dense_dim, size);
return self;
}
SparseTensor& sparse_resize_and_clear_(SparseTensor& self, ArrayRef<int64_t> size, int64_t sparse_dim, int64_t dense_dim) {
SparseTensor& sparse_resize_and_clear_(
SparseTensor& self,
ArrayRef<int64_t> size,
int64_t sparse_dim,
int64_t dense_dim) {
get_sparse_impl(self)->resize_and_clear_(sparse_dim, dense_dim, size);
return self;
}
namespace {
bool _is_same_size_as_sparse(const SparseTensor& self, const SparseTensor& src) {
return self.sparse_dim() == src.sparse_dim() && self.dense_dim() == src.dense_dim() && self.sizes().equals(src.sizes());
}
bool _is_same_size_as_sparse(
const SparseTensor& self,
const SparseTensor& src) {
return self.sparse_dim() == src.sparse_dim() &&
self.dense_dim() == src.dense_dim() && self.sizes().equals(src.sizes());
}
} // namespace
// Invoked from native/Resize.cpp (no dynamic dispatch necessary)
SparseTensor& resize_as_sparse_(SparseTensor& self, const SparseTensor& src) {
@ -336,23 +441,33 @@ SparseTensor& resize_as_sparse_(SparseTensor& self, const SparseTensor& src) {
return self;
}
SparseTensor dense_to_sparse(const Tensor& self){
SparseTensor dense_to_sparse(const Tensor& self) {
return dense_to_sparse(self, self.dim());
}
SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim) {
int64_t dims = self.dim();
// TODO: it seems like sparse_dim == 0 could be supported even if self.dim() > 0,
// but this would take some work and doesn't seem particularly useful.
TORCH_CHECK(sparse_dim > 0 || self.dim() == 0, "sparse_dim must be >0 if dimensionality > 0");
TORCH_CHECK(sparse_dim <= dims,
"sparse_dim must be less than or equal to self.dim()");
// TODO: it seems like sparse_dim == 0 could be supported even if self.dim() >
// 0, but this would take some work and doesn't seem particularly useful.
TORCH_CHECK(
sparse_dim > 0 || self.dim() == 0,
"sparse_dim must be >0 if dimensionality > 0");
TORCH_CHECK(
sparse_dim <= dims,
"sparse_dim must be less than or equal to self.dim()");
at::TensorOptions sparse_options = self.options().layout(kSparse);
std::vector<int64_t> sizes = self.sizes().vec();
Tensor nz = self.nonzero().transpose(0, 1);
if (nz.size(1) == 0) {
return new_with_dims_sparse(sparse_dim, dims - sparse_dim, sizes, optTypeMetaToScalarType(sparse_options.dtype_opt()), sparse_options.layout_opt(), sparse_options.device_opt(), sparse_options.pinned_memory_opt());
return new_with_dims_sparse(
sparse_dim,
dims - sparse_dim,
sizes,
optTypeMetaToScalarType(sparse_options.dtype_opt()),
sparse_options.layout_opt(),
sparse_options.device_opt(),
sparse_options.pinned_memory_opt());
}
Tensor indices;
if (sparse_dim == dims) {
@ -360,7 +475,8 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
} else {
Tensor i = nz.narrow(0, 0, sparse_dim);
std::tie(indices, std::ignore, std::ignore) = unique_dim(i, 1);
indices = indices.contiguous(); // many sparse CUDA kernels require contiguity, see issue #12633
indices = indices.contiguous(); // many sparse CUDA kernels require
// contiguity, see issue #12633
}
Tensor values;
@ -369,8 +485,8 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
values = self.index(ix).squeeze(0).clone(at::MemoryFormat::Preserve);
} else {
AT_ASSERT(nz.sizes().equals({0, 1}));
// In this cases, indices is a clone of nz, which is a tensor of shape (0, 1).
// Given sparse tensor invariants, values should be shape (1,)
// In this case, indices is a clone of nz, which is a tensor of shape (0,
// 1). Given sparse tensor invariants, values should be of shape (1,)
values = self.unsqueeze(0).clone(at::MemoryFormat::Preserve);
}
@ -380,18 +496,27 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
// NB: Dropped the resizeNd variants
Tensor sparse_to_dense(const SparseTensor& self, c10::optional<ScalarType> dtype) {
TORCH_CHECK(!dtype.has_value(), "dtype argument is not supported by sparse_to_dense");
if(self.scalar_type() == ScalarType::Half && self.options().device().is_cpu()) {
Tensor sparse_to_dense(
const SparseTensor& self,
c10::optional<ScalarType> dtype) {
TORCH_CHECK(
!dtype.has_value(), "dtype argument is not supported by sparse_to_dense");
if (self.scalar_type() == ScalarType::Half &&
self.options().device().is_cpu()) {
TORCH_CHECK(false, "to_dense() not supported for float16 on CPU");
}
Tensor dst = at::zeros(self.sizes(), self.options().layout(kStrided));
return dst.add_(self);
}
SparseTensor& copy_sparse_(SparseTensor& self, const SparseTensor& src, bool non_blocking) {
if (is_same_tensor(self, src)) return self;
get_sparse_impl(self)->resize_(src.sparse_dim(), src.dense_dim(), src.sizes());
SparseTensor& copy_sparse_(
SparseTensor& self,
const SparseTensor& src,
bool non_blocking) {
if (is_same_tensor(self, src))
return self;
get_sparse_impl(self)->resize_(
src.sparse_dim(), src.dense_dim(), src.sizes());
copy_into_sparse(self, src._indices(), src._values(), non_blocking);
return self._coalesced_(src.is_coalesced());
}
@ -426,7 +551,11 @@ SparseTensor _coalesce_sparse_cpu(const SparseTensor& self) {
Tensor indices_scalar = flatten_indices(indices, self.sizes());
SparseTensor dst = new_sparse(optTypeMetaToScalarType(self.options().dtype_opt()), self.options().layout_opt(), self.options().device_opt(), self.options().pinned_memory_opt());
SparseTensor dst = new_sparse(
optTypeMetaToScalarType(self.options().dtype_opt()),
self.options().layout_opt(),
self.options().device_opt(),
self.options().pinned_memory_opt());
get_sparse_impl(dst)->resize_(sparse_dim, dense_dim, self.sizes());
// TODO: is there a more idiomatic way to do this?
Tensor newIndices = at::empty(indices.sizes(), indices.options());
@ -436,38 +565,51 @@ SparseTensor _coalesce_sparse_cpu(const SparseTensor& self) {
Tensor indicesBuffer;
Tensor indicesPermutation;
std::tie(indicesBuffer, indicesPermutation) = indices_scalar.sort(0);
// NB: The accessor accesses here rely on self._nnz() > 0 (tested earlier in this function)
// NB: The accessor accesses here rely on self._nnz() > 0 (tested earlier in
// this function)
auto newIndicesAccessor = newIndices.accessor<int64_t, 2>();
auto indicesAccessor = indices.accessor<int64_t, 2>();
auto indicesPermutationAccessor = indicesPermutation.accessor<int64_t, 1>();
auto indicesBufferAccessor = indicesBuffer.accessor<int64_t, 1>();
int64_t i = -1;
AT_DISPATCH_ALL_TYPES(
values.scalar_type(), "coalesce", [&] {
int64_t prev = -1;
int64_t blockSize = values.stride(0);
scalar_t* values_ptr = values.data_ptr<scalar_t>();
scalar_t* newValues_ptr = newValues.data_ptr<scalar_t>();
for (int64_t j = 0; j < nnz; j++) {
int64_t pos = indicesPermutationAccessor[j];
int64_t curr = indicesBufferAccessor[j];
if (curr == prev) {
if (values.numel() > 0) { // if values is an empty tensor, there are no elements to copy
at::native::cpublas::axpy<scalar_t>(blockSize, 1, values_ptr + pos * blockSize, 1, newValues_ptr + i * blockSize, 1);
}
} else {
++i;
for (int64_t d = 0; d < sparse_dim; d++) {
newIndicesAccessor[d][i] = indicesAccessor[d][pos];
}
if (values.numel() > 0) { // if values is an empty tensor, there are no elements to copy
at::native::cpublas::copy<scalar_t>(blockSize, values_ptr + pos * blockSize, 1, newValues_ptr + i * blockSize, 1);
}
}
prev = curr;
AT_DISPATCH_ALL_TYPES(values.scalar_type(), "coalesce", [&] {
int64_t prev = -1;
int64_t blockSize = values.stride(0);
scalar_t* values_ptr = values.data_ptr<scalar_t>();
scalar_t* newValues_ptr = newValues.data_ptr<scalar_t>();
for (int64_t j = 0; j < nnz; j++) {
int64_t pos = indicesPermutationAccessor[j];
int64_t curr = indicesBufferAccessor[j];
if (curr == prev) {
if (values.numel() >
0) { // if values is an empty tensor, there are no elements to copy
at::native::cpublas::axpy<scalar_t>(
blockSize,
1,
values_ptr + pos * blockSize,
1,
newValues_ptr + i * blockSize,
1);
}
});
} else {
++i;
for (int64_t d = 0; d < sparse_dim; d++) {
newIndicesAccessor[d][i] = indicesAccessor[d][pos];
}
if (values.numel() >
0) { // if values is an empty tensor, there are no elements to copy
at::native::cpublas::copy<scalar_t>(
blockSize,
values_ptr + pos * blockSize,
1,
newValues_ptr + i * blockSize,
1);
}
}
prev = curr;
}
});
dst._coalesced_(true);
get_sparse_impl(dst)->set_nnz_and_narrow(i + 1);
@ -475,7 +617,6 @@ SparseTensor _coalesce_sparse_cpu(const SparseTensor& self) {
return dst;
}
// --------------------------------------------------------------------
// sparse_mask(D, S) -> S
//
@ -485,12 +626,11 @@ SparseTensor _coalesce_sparse_cpu(const SparseTensor& self) {
template <typename scalar_t>
void inline sparse_mask_out_cpu_kernel(
Tensor& r_values,
const Tensor& t,
const int64_t r_nnz,
const int64_t sparse_dim,
const Tensor& mask_indices
) {
Tensor& r_values,
const Tensor& t,
const int64_t r_nnz,
const int64_t sparse_dim,
const Tensor& mask_indices) {
auto r_values_accessor = r_values.accessor<scalar_t, 1>();
auto mask_indices_accessor = mask_indices.accessor<int64_t, 2>();
scalar_t* t_ptr = t.data_ptr<scalar_t>();
@ -506,14 +646,23 @@ void inline sparse_mask_out_cpu_kernel(
});
}
SparseTensor& sparse_mask_out_cpu(SparseTensor& r, const Tensor& t, const SparseTensor& mask) {
SparseTensor& sparse_mask_out_cpu(
SparseTensor& r,
const Tensor& t,
const SparseTensor& mask) {
TORCH_CHECK(mask.is_coalesced(), "sparse_mask: mask is uncoalesced");
TORCH_CHECK(mask.sizes().equals(t.sizes()), "sparse_mask: operands have incompatible sizes; self has size ",
t.sizes(), " but mask has size ", mask.sizes());
TORCH_CHECK(
mask.sizes().equals(t.sizes()),
"sparse_mask: operands have incompatible sizes; self has size ",
t.sizes(),
" but mask has size ",
mask.sizes());
AT_ASSERT(!t.is_cuda()); // we were supposed to have dispatched on this
TORCH_CHECK(!r.is_cuda(), "sparse_mask: expected 'out' to be CPU, but got CUDA");
TORCH_CHECK(!mask.is_cuda(), "sparse_mask: expected 'mask' to be CPU, but got CUDA");
resize_as_sparse_(r, mask);
TORCH_CHECK(
!r.is_cuda(), "sparse_mask: expected 'out' to be CPU, but got CUDA");
TORCH_CHECK(
!mask.is_cuda(), "sparse_mask: expected 'mask' to be CPU, but got CUDA");
at::resize_as_sparse_(r, mask);
if (mask._nnz() == 0) {
return r.zero_();
}
@ -527,14 +676,15 @@ SparseTensor& sparse_mask_out_cpu(SparseTensor& r, const Tensor& t, const Sparse
int64_t r_nnz = mask._nnz();
get_sparse_impl(r)->set_nnz_and_narrow(r_nnz);
if (t.numel() == 0) { // if t is an empty tensor, there is no need to mask its elements
if (t.numel() ==
0) { // if t is an empty tensor, there is no need to mask its elements
return r;
}
if (dim > sparse_dim) {
// Get a flattened sparse indices, similar to NOTE [ Flatten Sparse Indices ].
// Keeping this implementation because it is faster than flatten_indices()
// Get a flattened sparse indices, similar to NOTE [ Flatten Sparse Indices
// ]. Keeping this implementation because it is faster than
// flatten_indices()
Tensor indices = at::zeros({mask._nnz()}, mask_indices.options());
for (int64_t d = 0; d < mask.sparse_dim(); d++) {
indices.mul_(mask.size(d));
@ -553,11 +703,7 @@ SparseTensor& sparse_mask_out_cpu(SparseTensor& r, const Tensor& t, const Sparse
} else {
AT_DISPATCH_ALL_TYPES(r_values.scalar_type(), "sparse_mask", [&] {
sparse_mask_out_cpu_kernel<scalar_t>(
r_values,
t,
r_nnz,
sparse_dim,
mask_indices);
r_values, t, r_nnz, sparse_dim, mask_indices);
});
}
return r;
@ -570,26 +716,30 @@ SparseTensor sparse_mask_cpu(const Tensor& t, const SparseTensor& mask) {
}
Tensor sparse_mask_helper_cpu(
const SparseTensor& t,
const Tensor& mask_indices
) {
const SparseTensor& t,
const Tensor& mask_indices) {
/*
This is a helper function which filter values from `t._values()` using the `mask_indices`.
This CPU implementation uses a simple hash_map to filter values by matching the `mask_indices`
with the indices at tensor input `t`.
  This is a helper function which filters values from `t._values()` using the
  `mask_indices`. This CPU implementation uses a simple hash_map to filter
  values by matching the `mask_indices` with the indices of the input tensor `t`.
Inputs:
`t` - coalesced sparse tensor input
`mask_indices` - mask indices tensor
Note: The nnz in the output tensor will be same as the `mask_indices`. So it will
works independently if the mask is coalesced or not.
  Note: The nnz in the output tensor will be the same as that of `mask_indices`. So it
  works independently of whether the mask is coalesced or not.
*/
TORCH_CHECK(t.is_sparse(), "t: input is not a sparse tensor");
TORCH_CHECK(t.is_coalesced(), "t: input is uncoalesced");
TORCH_CHECK(mask_indices.dim() == t._indices().dim(), "mask_indices: operands have incompatible indices dim; self has dim ",
t._indices().dim(), " but mask has dim ", mask_indices.dim());
TORCH_CHECK(mask_indices.is_contiguous(), "mask_indices: mask is not contiguous");
TORCH_CHECK(
mask_indices.dim() == t._indices().dim(),
"mask_indices: operands have incompatible indices dim; self has dim ",
t._indices().dim(),
" but mask has dim ",
mask_indices.dim());
TORCH_CHECK(
mask_indices.is_contiguous(), "mask_indices: mask is not contiguous");
int64_t r_nnz = mask_indices.size(1);
auto t_v = t._values();
@ -600,31 +750,35 @@ Tensor sparse_mask_helper_cpu(
auto t_i = t._indices();
auto t_nnz = t._nnz();
std::unordered_map<int64_t, int64_t> t_flatten_indices = std::unordered_map<int64_t, int64_t>{};
std::unordered_map<int64_t, int64_t> t_flatten_indices =
std::unordered_map<int64_t, int64_t>{};
auto full_size = t.sizes();
auto ti_flattened_indices = at::sparse::flatten_indices(t_i, full_size);
// Step 1: flatten the sparse indices `t._indices()` tensor and then map this flatten value `index` to the original position `i`
// Step 1: flatten the sparse indices `t._indices()` tensor and then map this
// flatten value `index` to the original position `i`
auto t_indices_accessor = t_i.accessor<int64_t, 2>();
for(int64_t i = 0; i < t_nnz; i++) {
for (int64_t i = 0; i < t_nnz; i++) {
int64_t index = ti_flattened_indices.data_ptr<int64_t>()[i];
t_flatten_indices[index] = i;
}
// Step 2: Filter `t._values()` values by matching the flatten `mask_indices` with the flatten `t._indices()` using the
// hash_map `t_flatten_indices`
// Step 2: Filter `t._values()` values by matching the flatten `mask_indices`
// with the flatten `t._indices()` using the hash_map `t_flatten_indices`
auto flattened_mask_indices = at::sparse::flatten_indices(mask_indices, full_size);
auto flattened_mask_indices =
at::sparse::flatten_indices(mask_indices, full_size);
at::parallel_for(0, r_nnz, 0, [&](int64_t start, int64_t end) {
for (auto i = start; i < end; i++) {
int64_t index = flattened_mask_indices.data_ptr<int64_t>()[i];
auto iter = t_flatten_indices.find(index);
if (iter != t_flatten_indices.end()) {
r_values[i] = t_v[ iter->second ];
r_values[i] = t_v[iter->second];
}
}
});
return r_values;
}
}} // namespace at::native
} // namespace native
} // namespace at

View File

@ -819,6 +819,7 @@ void s_addmm_out_sparse_dense_worker(int64_t nnz, int64_t dim_i, int64_t dim_j,
// r_ = alpha * sparse * dense
scalar_t cast_alpha = alpha.to<scalar_t>();
scalar_t cast_beta = beta.to<scalar_t>();
if (cast_beta == 0) {
r.zero_();
} else if (cast_beta == 1) {
@ -981,8 +982,7 @@ Tensor _sparse_mm(
// we can redispatch to addmm_out; this is NOT an implementation of
// the sparse masking version of mm
SparseTensor& _sparse_mm_out(const SparseTensor& sparse,
const Tensor& dense
,
const Tensor& dense,
SparseTensor& result) {
Tensor t = at::zeros({}, dense.options());
return at::addmm_out(result, t, sparse, dense, 0, 1); // redispatch!
@ -1601,6 +1601,7 @@ Tensor& bmm_out_sparse_cpu(const SparseTensor& self, const Tensor& mat2, Tensor&
Tensor sparse_values = values.slice(0, mat_el_begin_idx, mat_el_end_idx);
int64_t sparse_nnz = mat_el_end_idx - mat_el_begin_idx;
s_addmm_out_sparse_dense_worker<scalar_t>(
sparse_nnz,
dim_i, dim_j, dim_k,

View File

@ -416,6 +416,12 @@ class TORCH_API Tensor {
return impl_->is_sparse();
}
  /// Returns if a `Tensor` has a sparse CSR backend.
bool is_sparse_csr() const {
// NB: this is not a native function to avoid dispatching overhead.
return impl_->is_sparse_csr();
}
/// Returns if a `Tensor` is mkldnn tensor.
bool is_mkldnn() const {
// NB: this is not a native function to avoid dispatching overhead.

View File

@ -1,7 +1,5 @@
# PyTorch Benchmarks
NOTE: This folder is currently work in progress.
This folder contains scripts that produce reproducible timings of various PyTorch features.
It also provides mechanisms to compare PyTorch with other frameworks.

View File

@ -0,0 +1,5 @@
# Sparse benchmarks

These benchmarks exercise the sparse matrix functionality. They exist for
comparing the performance of sparse matrix routines such as SpMV across various
sparse matrix formats and against other frameworks such as TensorFlow.
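As a minimal sketch of how these pieces fit together (an illustration only, assuming it
is run from inside this directory; `gen_sparse_csr` and `Event` are the helpers defined
in `utils.py` below), a single CSR SpMV can be timed like this:

    import torch
    from utils import Event, gen_sparse_csr

    csr = gen_sparse_csr((1000, 1000), 10000)    # 1000 x 1000 matrix, ~1% non-zero
    vec = torch.randn(1000, dtype=torch.double)  # helpers generate double-precision values

    start, stop = Event(enable_timing=True), Event(enable_timing=True)
    start.record()
    csr.matmul(vec)                              # sparse matrix-vector product
    stop.record()
    print("csr spmv time:", start.elapsed_time(stop))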

View File

@ -0,0 +1,3 @@
if __name__ == "__main__":
pass

benchmarks/sparse/spmm.py Normal file
View File

@ -0,0 +1,105 @@
import argparse
import sys
import torch
from .utils import gen_sparse_csr, gen_sparse_coo, gen_sparse_coo_and_csr, Event
def test_sparse_csr(m, n, k, nnz, test_count):
start_timer = Event(enable_timing=True)
stop_timer = Event(enable_timing=True)
csr = gen_sparse_csr((m, k), nnz)
mat = torch.randn(k, n, dtype=torch.double)
times = []
for _ in range(test_count):
start_timer.record()
csr.matmul(mat)
stop_timer.record()
times.append(start_timer.elapsed_time(stop_timer))
return sum(times) / len(times)
def test_sparse_coo(m, n, k, nnz, test_count):
start_timer = Event(enable_timing=True)
stop_timer = Event(enable_timing=True)
coo = gen_sparse_coo((m, k), nnz)
mat = torch.randn(k, n, dtype=torch.double)
times = []
for _ in range(test_count):
start_timer.record()
coo.matmul(mat)
stop_timer.record()
times.append(start_timer.elapsed_time(stop_timer))
return sum(times) / len(times)
def test_sparse_coo_and_csr(m, n, k, nnz, test_count):
start = Event(enable_timing=True)
stop = Event(enable_timing=True)
coo, csr = gen_sparse_coo_and_csr((m, k), nnz)
mat = torch.randn((k, n), dtype=torch.double)
times = []
for _ in range(test_count):
start.record()
coo.matmul(mat)
stop.record()
times.append(start.elapsed_time(stop))
coo_mean_time = sum(times) / len(times)
times = []
for _ in range(test_count):
start.record()
csr.matmul(mat)
stop.record()
times.append(start.elapsed_time(stop))
csr_mean_time = sum(times) / len(times)
return coo_mean_time, csr_mean_time
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="SpMM")
parser.add_argument("--format", default='csr', type=str)
parser.add_argument("--m", default='1000', type=int)
parser.add_argument("--n", default='1000', type=int)
parser.add_argument("--k", default='1000', type=int)
parser.add_argument("--nnz_ratio", default='0.1', type=float)
parser.add_argument("--outfile", default='stdout', type=str)
parser.add_argument("--test_count", default='10', type=int)
args = parser.parse_args()
if args.outfile == 'stdout':
outfile = sys.stdout
elif args.outfile == 'stderr':
outfile = sys.stderr
else:
outfile = open(args.outfile, "a")
test_count = args.test_count
m = args.m
n = args.n
k = args.k
nnz_ratio = args.nnz_ratio
nnz = int(nnz_ratio * m * k)
if args.format == 'csr':
time = test_sparse_csr(m, n, k, nnz, test_count)
elif args.format == 'coo':
time = test_sparse_coo(m, n, k, nnz, test_count)
elif args.format == 'both':
        time_coo, time_csr = test_sparse_coo_and_csr(m, n, k, nnz, test_count)
if args.format == 'both':
print("format=coo", " nnz_ratio=", nnz_ratio, " m=", m, " n=", n, " k=", k, " time=", time_coo, file=outfile)
print("format=csr", " nnz_ratio=", nnz_ratio, " m=", m, " n=", n, " k=", k, " time=", time_csr, file=outfile)
else:
print("format=", args.format, " nnz_ratio=", nnz_ratio, " m=", m, " n=", n, " k=", k, " time=", time,
file=outfile)

benchmarks/sparse/spmv.py Normal file
View File

@ -0,0 +1,103 @@
import argparse
import sys
import torch
from .utils import gen_sparse_csr, gen_sparse_coo, gen_sparse_coo_and_csr, Event
def test_sparse_csr(m, nnz, test_count):
start_timer = Event(enable_timing=True)
stop_timer = Event(enable_timing=True)
csr = gen_sparse_csr((m, m), nnz)
vector = torch.randn(m, dtype=torch.double)
times = []
for _ in range(test_count):
start_timer.record()
csr.matmul(vector)
stop_timer.record()
times.append(start_timer.elapsed_time(stop_timer))
return sum(times) / len(times)
def test_sparse_coo(m, nnz, test_count):
start_timer = Event(enable_timing=True)
stop_timer = Event(enable_timing=True)
coo = gen_sparse_coo((m, m), nnz)
vector = torch.randn(m, dtype=torch.double)
times = []
for _ in range(test_count):
start_timer.record()
coo.matmul(vector)
stop_timer.record()
times.append(start_timer.elapsed_time(stop_timer))
return sum(times) / len(times)
def test_sparse_coo_and_csr(m, nnz, test_count):
start = Event(enable_timing=True)
stop = Event(enable_timing=True)
coo, csr = gen_sparse_coo_and_csr((m, m), nnz)
vector = torch.randn(m, dtype=torch.double)
times = []
for _ in range(test_count):
start.record()
coo.matmul(vector)
stop.record()
times.append(start.elapsed_time(stop))
coo_mean_time = sum(times) / len(times)
times = []
for _ in range(test_count):
start.record()
csr.matmul(vector)
stop.record()
times.append(start.elapsed_time(stop))
csr_mean_time = sum(times) / len(times)
return coo_mean_time, csr_mean_time
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="SpMV")
parser.add_argument("--format", default='csr', type=str)
parser.add_argument("--m", default='1000', type=int)
parser.add_argument("--nnz_ratio", default='0.1', type=float)
parser.add_argument("--outfile", default='stdout', type=str)
parser.add_argument("--test_count", default='10', type=int)
args = parser.parse_args()
if args.outfile == 'stdout':
outfile = sys.stdout
elif args.outfile == 'stderr':
outfile = sys.stderr
else:
outfile = open(args.outfile, "a")
test_count = args.test_count
m = args.m
nnz_ratio = args.nnz_ratio
nnz = int(nnz_ratio * m * m)
if args.format == 'csr':
time = test_sparse_csr(m, nnz, test_count)
elif args.format == 'coo':
time = test_sparse_coo(m, nnz, test_count)
elif args.format == 'both':
time_coo, time_csr = test_sparse_coo_and_csr(m, nnz, test_count)
if args.format != 'both':
print("format=", args.format, " nnz_ratio=", nnz_ratio, " m=", m,
" time=", time, file=outfile)
else:
print("format=coo", " nnz_ratio=", nnz_ratio, " m=", m,
" time=", time_coo, file=outfile)
print("format=csr", " nnz_ratio=", nnz_ratio, " m=", m,
" time=", time_csr, file=outfile)

View File

@ -0,0 +1,41 @@
OUTFILE=spmm-no-mkl-test.txt
PYTORCH_HOME=$1
cd $PYTORCH_HOME
echo "" >> $OUTFILE
echo "----- USE_MKL=1 -----" >> $OUTFILE
rm -rf build
export USE_MKL=1
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
python setup.py build --cmake-only
ccmake build # or cmake-gui build
python setup.py install
cd benchmarks
echo "!! SPARSE SPMM TIME BENCHMARK!! " >> $OUTFILE
for dim0 in 1000 5000 10000; do
for nnzr in 0.01 0.05 0.1 0.3; do
python -m sparse.spmm --format csr --m $dim0 --n $dim0 --k $dim0 --nnz_ratio $nnzr --outfile $OUTFILE
# python -m sparse.spmm --format coo --m $dim0 --n $dim0 --k $dim0 --nnz_ratio $nnzr --outfile $OUTFILE
done
done
echo "----------------------" >> $OUTFILE
cd $PYTORCH_HOME
echo "----- USE_MKL=0 ------" >> $OUTFILE
rm -rf build
export USE_MKL=0
python setup.py install
cd benchmarks
for dim0 in 1000 5000 10000; do
for nnzr in 0.01 0.05 0.1 0.3; do
python -m sparse.spmv --format csr --m $dim0 --nnz_ratio $nnzr --outfile $OUTFILE
python -m sparse.spmv --format coo --m $dim0 --nnz_ratio $nnzr --outfile $OUTFILE
done
done
echo "----------------------" >> $OUTFILE

View File

@ -0,0 +1,54 @@
import torch
import functools
import random
import operator
import numpy as np
import time
# shim for torch.cuda.Event when running on cpu
class Event(object):
def __init__(self, enable_timing):
pass
def record(self):
self.time = time.perf_counter()
def elapsed_time(self, end_event):
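        # Note: unlike torch.cuda.Event.elapsed_time, which reports milliseconds,
        # this CPU shim reports seconds (a difference of time.perf_counter() values).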
assert isinstance(end_event, Event)
return end_event.time - self.time
def gen_sparse_csr(shape, nnz):
fill_value = 0
total_values = functools.reduce(operator.mul, shape, 1)
dense = np.random.randn(total_values)
fills = random.sample(list(range(total_values)), total_values - nnz)
for f in fills:
dense[f] = fill_value
dense = torch.from_numpy(dense.reshape(shape))
return dense.to_sparse_csr()
def gen_sparse_coo(shape, nnz):
dense = np.random.randn(*shape)
values = []
indices = [[], []]
for n in range(nnz):
row = random.randint(0, shape[0] - 1)
col = random.randint(0, shape[1] - 1)
indices[0].append(row)
indices[1].append(col)
values.append(dense[row, col])
return torch.sparse_coo_tensor(indices, values, size=shape)
def gen_sparse_coo_and_csr(shape, nnz):
total_values = functools.reduce(operator.mul, shape, 1)
dense = np.random.randn(total_values)
fills = random.sample(list(range(total_values)), total_values - nnz)
for f in fills:
dense[f] = 0
dense = torch.from_numpy(dense.reshape(shape))
return dense.to_sparse(), dense.to_sparse_csr()

View File

@ -34,6 +34,8 @@ enum class Backend {
XPU,
SparseCPU,
SparseCUDA,
SparseCsrCPU,
SparseCsrCUDA,
SparseHIP,
SparseXPU,
MSNPU,
@ -74,6 +76,10 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) {
return Backend::SparseCUDA;
} else if (t == DispatchKey::SparseHIP) {
return Backend::SparseHIP;
} else if (t == DispatchKey::SparseCsrCPU) {
return Backend::SparseCsrCPU;
} else if (t == DispatchKey::SparseCsrCUDA) {
return Backend::SparseCsrCUDA;
} else if (t == DispatchKey::MkldnnCPU) {
return Backend::MkldnnCPU;
} else if (t == DispatchKey::QuantizedCPU) {
@ -117,6 +123,10 @@ static inline DispatchKey backendToDispatchKey(Backend b) {
return DispatchKey::SparseCUDA;
case Backend::SparseHIP:
return DispatchKey::SparseHIP;
case Backend::SparseCsrCPU:
return DispatchKey::SparseCsrCPU;
case Backend::SparseCsrCUDA:
return DispatchKey::SparseCsrCUDA;
case Backend::MkldnnCPU:
return DispatchKey::MkldnnCPU;
case Backend::Vulkan:
@ -156,6 +166,10 @@ static inline DeviceType backendToDeviceType(Backend b) {
return DeviceType::CUDA;
case Backend::SparseHIP:
return DeviceType::HIP;
case Backend::SparseCsrCPU:
return DeviceType::CPU;
case Backend::SparseCsrCUDA:
return DeviceType::CUDA;
case Backend::XPU:
case Backend::SparseXPU:
case Backend::QuantizedXPU:
@ -205,6 +219,10 @@ static inline const char* toString(Backend b) {
return "SparseHIP";
case Backend::SparseXPU:
return "SparseXPU";
case Backend::SparseCsrCPU:
return "SparseCsrCPU";
case Backend::SparseCsrCUDA:
return "SparseCsrCUDA";
case Backend::MkldnnCPU:
return "MkldnnCPU";
case Backend::Vulkan:
@ -234,4 +252,14 @@ static inline bool isSparse(Backend b) {
}
}
static inline bool isSparseCsr(Backend b) {
switch(b) {
case Backend::SparseCsrCPU:
case Backend::SparseCsrCUDA:
return true;
default:
return false;
}
}
} // namespace c10

View File

@ -11,6 +11,7 @@ const char* toString(DispatchKey t) {
return "CPU";
case DispatchKey::CUDA:
return "CUDA";
case DispatchKey::HIP:
return "HIP";
case DispatchKey::FPGA:
@ -51,6 +52,10 @@ const char* toString(DispatchKey t) {
return "SparseCPU";
case DispatchKey::SparseCUDA:
return "SparseCUDA";
case DispatchKey::SparseCsrCPU:
return "SparseCsrCPU";
case DispatchKey::SparseCsrCUDA:
return "SparseCsrCUDA";
case DispatchKey::SparseHIP:
return "SparseHIP";
case DispatchKey::SparseXPU:

View File

@ -112,6 +112,9 @@ enum class DispatchKey : uint8_t {
// [Masquerading as CUDA]
SparseXPU, // For out of tree Intel's heterogeneous computing plug-in
SparseCsrCPU,
SparseCsrCUDA,
NestedTensor, // lives out of tree at https://github.com/pytorch/nestedtensor
// Here are reserved backends for user-defined backends, see Note [Private use
// DispatchKey]

View File

@ -235,7 +235,9 @@ constexpr DispatchKeySet autogradother_backends = DispatchKeySet({
DispatchKey::SparseCPU,
DispatchKey::SparseCUDA,
DispatchKey::SparseHIP,
DispatchKey::Meta,
DispatchKey::SparseCsrCPU,
DispatchKey::SparseCsrCUDA,
DispatchKey::Meta
});
// The set of dispatch keys that come after autograd

View File

@ -6,10 +6,11 @@
#include <iostream>
namespace c10 {
enum class Layout : int8_t { Strided, Sparse, Mkldnn, NumOptions };
enum class Layout : int8_t { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };
constexpr auto kStrided = Layout::Strided;
constexpr auto kSparse = Layout::Sparse;
constexpr auto kSparseCsr = Layout::SparseCsr;
constexpr auto kMkldnn = Layout::Mkldnn;
inline Layout layout_from_backend(Backend backend) {
@ -21,6 +22,9 @@ inline Layout layout_from_backend(Backend backend) {
return Layout::Sparse;
case Backend::MkldnnCPU:
return Layout::Mkldnn;
case Backend::SparseCsrCPU:
case Backend::SparseCsrCUDA:
return Layout::SparseCsr;
default:
return Layout::Strided;
}
@ -32,6 +36,8 @@ inline std::ostream& operator<<(std::ostream& stream, at::Layout layout) {
return stream << "Strided";
case at::kSparse:
return stream << "Sparse";
case at::kSparseCsr:
return stream << "SparseCsr";
case at::kMkldnn:
return stream << "Mkldnn";
default:

View File

@ -590,6 +590,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
key_set_.has(DispatchKey::SparseXPU);
}
  // Note: is_sparse above checks the COO format only; use is_sparse_csr for the CSR format.
bool is_sparse_csr() const {
return key_set_.has(DispatchKey::SparseCsrCPU) ||
key_set_.has(DispatchKey::SparseCsrCUDA);
}
bool is_quantized() const {
// NB: This method is not virtual and avoid dispatches for performance reasons.
return key_set_.has(DispatchKey::QuantizedCPU) ||

View File

@ -337,6 +337,10 @@ struct C10_API TensorOptions {
return layout_ == c10::Layout::Sparse;
}
bool is_sparse_csr() const {
return layout_ == c10::Layout::SparseCsr;
}
// For compatibility with legacy tensor.type() comparisons
bool type_equal(const TensorOptions& other) const {
return computeDispatchKey() == other.computeDispatchKey() && typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype());
@ -660,6 +664,15 @@ inline DispatchKey computeDispatchKey(c10::optional<ScalarType> dtype, c10::opti
default:
TORCH_CHECK_NOT_IMPLEMENTED(false, "Unsupported device type for mkldnn layout: ", device_.type());
}
case Layout::SparseCsr:
switch(device_.type()) {
case DeviceType::CPU:
return DispatchKey::SparseCsrCPU;
case DeviceType::CUDA:
return DispatchKey::SparseCsrCUDA;
default:
AT_ERROR("Unsupported device type for sparse CSR layout: ", device_.type());
}
default:
TORCH_CHECK(false, "Unsupported layout: ", layout_);
}
@ -671,6 +684,8 @@ inline Layout dispatchKeyToLayout(DispatchKey dispatch_key) {
case DispatchKey::SparseCUDA:
case DispatchKey::SparseHIP:
case DispatchKey::SparseXPU:
case DispatchKey::SparseCsrCPU:
case DispatchKey::SparseCsrCUDA:
return Layout::Sparse;
case DispatchKey::MkldnnCPU:
return Layout::Mkldnn;

View File

@ -126,6 +126,7 @@ If you don't see an operation listed here, but it would help your use case, plea
:meth:`Tensor.is_shared`,None
":meth:`Tensor.is_signed`, :func:`torch.is_signed`",None
:attr:`Tensor.is_sparse`,None
:attr:`Tensor.is_sparse_csr`,None
:func:`torch.is_tensor`,None
:meth:`Tensor.item`,None
":meth:`Tensor.kthvalue`, :func:`torch.kthvalue`",:ref:`removes_dimensions-doc`

View File

@ -53,8 +53,8 @@ __ https://en.wikipedia.org/wiki/Sparse_matrix
Sparse COO tensors
++++++++++++++++++
Currently, PyTorch implements the so-called Coordinate format, or COO
format, as the default sparse storage format for storing sparse
PyTorch implements the so-called Coordinate format, or COO
format, as one of the storage formats for implementing sparse
tensors. In COO format, the specified elements are stored as tuples
of element indices and the corresponding values. In particular,
@ -363,6 +363,82 @@ assumption that the fill value is negative infinity.
.. See https://github.com/Quansight-Labs/rfcs/tree/pearu/rfc-fill-value/RFC-0004-sparse-fill-value for a new API
.. _sparse-csr-docs:
Sparse CSR Tensor
+++++++++++++++++
The CSR (Compressed Sparse Row) sparse tensor format stores 2-dimensional tensors in
compressed row form. Although there is no support for N-dimensional tensors, its primary
advantage over the COO format is better use of storage and much faster computation of
operations such as sparse matrix-vector multiplication using the MKL and MAGMA backends.
CUDA support does not exist as of now.
A CSR sparse tensor consists of three 1-D tensors: ``crow_indices``, ``col_indices``
and ``values``:
- The ``crow_indices`` tensor consists of compressed row indices. This is a 1-D tensor
  of size ``size[0] + 1``, and its last element is the number of non-zeros. For each row,
  it encodes the index in ``values`` and ``col_indices`` at which that row starts, so the
  difference between two consecutive entries is the number of elements in the
  corresponding row (see the small example after this list).
- The ``col_indices`` tensor contains the column indices of each value. This is a 1-D
tensor of size ``nnz``.
- The ``values`` tensor contains the values of the CSR tensor. This is a 1-D tensor
of size ``nnz``.
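As an illustration of how the three tensors fit together (this uses the same 2 x 2 tensor
as in the construction example below), the per-row element counts and the entries of any
particular row can be read off from ``crow_indices`` directly:

>>> crow_indices = torch.tensor([0, 2, 4])
>>> col_indices = torch.tensor([0, 1, 0, 1])
>>> values = torch.tensor([1., 2., 3., 4.])
>>> crow_indices[1:] - crow_indices[:-1]  # number of stored elements per row
tensor([2, 2])
>>> values[crow_indices[0]:crow_indices[1]]  # values stored in row 0
tensor([1., 2.])
>>> col_indices[crow_indices[0]:crow_indices[1]]  # their column indices
tensor([0, 1])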
.. note::
    The index tensors ``crow_indices`` and ``col_indices`` should have element type either
    ``torch.int64`` (default) or ``torch.int32``. If you want to use MKL-enabled matrix
    operations, use ``torch.int32``. This is because PyTorch is linked against MKL LP64
    by default, which uses 32-bit integer indexing.
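For example, passing index tensors with an explicit ``torch.int32`` dtype to the
constructor described below should keep the indices in ``int32`` and thereby keep them
usable with the MKL path:

>>> csr = torch.sparse_csr_tensor(torch.tensor([0, 2, 4], dtype=torch.int32),
...                               torch.tensor([0, 1, 0, 1], dtype=torch.int32),
...                               torch.tensor([1., 2., 3., 4.]))
>>> csr.crow_indices().dtype
torch.int32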
Construction of CSR tensors
---------------------------
Sparse CSR matrices can be directly constructed by using the :func:`torch.sparse_csr_tensor`
function. The user must supply the row and column indices and values tensors separately.
The ``size`` argument is optional and will be deduced from the ``crow_indices``
and ``col_indices`` if it is not present.
>>> crow_indices = torch.tensor([0, 2, 4])
>>> col_indices = torch.tensor([0, 1, 0, 1])
>>> values = torch.tensor([1, 2, 3, 4])
>>> csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=torch.double)
>>> csr
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
dtype=torch.float64)
>>> csr.to_dense()
tensor([[1., 2.],
[3., 4.]], dtype=torch.float64)
CSR Tensor Operations
---------------------
The simplest way of constructing a sparse CSR tensor from a strided or sparse COO
tensor is to use :meth:`tensor.to_sparse_csr`. Any zeros in the (strided) tensor will
be interpreted as missing values in the sparse tensor:
>>> a = torch.tensor([[0, 0, 1, 0], [1, 2, 0, 0], [0, 0, 0, 0]], dtype = torch.float64)
>>> sp = a.to_sparse_csr()
>>> sp
tensor(crow_indices=tensor([0, 1, 3, 3]),
col_indices=tensor([2, 0, 1]),
values=tensor([1., 1., 2.]), size=(3, 4), nnz=3, dtype=torch.float64)
The sparse matrix-vector multiplication can be performed with the
:meth:`tensor.matmul` method. This is currently the only math operation
supported on CSR tensors.
>>> vec = torch.randn(4, 1, dtype=torch.float64)
>>> sp.matmul(vec)
tensor([[0.9078],
[1.3180],
[0.0000]], dtype=torch.float64)
Supported Linear Algebra operations
+++++++++++++++++++++++++++++++++++
@ -380,7 +456,9 @@ multiplication, and ``@`` is matrix multiplication.
:delim: ;
:func:`torch.mv`;no; ``M[sparse_coo] @ V[strided] -> V[strided]``
:func:`torch.mv`;no; ``M[sparse_csr] @ V[strided] -> V[strided]``
:func:`torch.matmul`; no; ``M[sparse_coo] @ M[strided] -> M[strided]``
:func:`torch.matmul`; no; ``M[sparse_csr] @ M[strided] -> M[strided]``
:func:`torch.mm`; no; ``M[sparse_coo] @ M[strided] -> M[strided]``
:func:`torch.sparse.mm`; yes; ``M[sparse_coo] @ M[strided] -> M[strided]``
:func:`torch.smm`; no; ``M[sparse_coo] @ M[strided] -> M[sparse_coo]``
@ -405,8 +483,6 @@ matrix arguments.
applications can still compute this using the matrix relation ``D @
S == (S.t() @ D.t()).t()``.
Tensor methods and sparse
+++++++++++++++++++++++++
@ -420,6 +496,7 @@ The following Tensor methods are related to sparse tensors:
Tensor.sparse_dim
Tensor.sparse_mask
Tensor.to_sparse
Tensor.to_sparse_csr
Tensor.indices
Tensor.values
@ -435,6 +512,14 @@ The following Tensor methods are specific to sparse COO tensors:
Tensor.is_coalesced
Tensor.to_dense
The following methods are specific to :ref:`sparse CSR tensors <sparse-csr-docs>`:
.. autosummary::
:nosignatures:
Tensor.crow_indices
Tensor.col_indices
The following Tensor methods support sparse COO tensors:
:meth:`~torch.Tensor.add`
@ -496,6 +581,7 @@ Torch functions specific to sparse Tensors
:nosignatures:
sparse_coo_tensor
sparse_csr_tensor
sparse.sum
sparse.addmm
sparse.mm

View File

@ -1803,6 +1803,23 @@ graph(%Ra, %Rb):
self.checkScript(test_sparse_addmm, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4)))
self.checkScript(test_sparse_addmm_alpha_beta, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4)))
@suppress_warnings
def test_sparse_csr_tensors(self):
@torch.jit.ignore
def get_sparse_csr():
return torch.randn(3, 3).to_sparse_csr()
@torch.jit.script
def test_is_sparse_csr(input):
# type: (Tensor) -> bool
return input.is_sparse_csr
script_out_is_sparse_csr = test_is_sparse_csr(get_sparse_csr())
script_out_is_dense_csr = test_is_sparse_csr(torch.randn(3, 3))
self.assertEqual(script_out_is_sparse_csr, True)
self.assertEqual(script_out_is_dense_csr, False)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
def test_device_not_equal(self):

test/test_sparse_csr.py Normal file
View File

@ -0,0 +1,202 @@
import torch
torch.set_default_dtype(torch.double)
import functools
import random
import operator
import numpy as np
import warnings
from torch.testing._internal.common_utils import TestCase, run_tests, load_tests
# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
class TestSparseCSR(TestCase):
def gen_sparse_csr(self, shape, nnz):
total_values = functools.reduce(operator.mul, shape, 1)
dense = np.random.randn(total_values)
fills = random.sample(list(range(total_values)), total_values - nnz)
for f in fills:
dense[f] = 0
dense = torch.from_numpy(dense.reshape(shape))
return dense.to_sparse_csr()
def setUp(self):
# These parameters control the various ways we can run the test.
# We will subclass and override this method to implement CUDA
# tests
self.is_cuda = False
self.device = 'cpu'
self.index_tensor = lambda *args: torch.tensor(*args, dtype=torch.int32)
self.value_tensor = lambda *args: torch.tensor(*args, dtype=torch.double)
def test_csr_layout(self):
self.assertEqual(str(torch.sparse_csr), 'torch.sparse_csr')
self.assertEqual(type(torch.sparse_csr), torch.layout)
def test_sparse_csr_constructor_shape_inference(self):
crow_indices = [0, 2, 4]
col_indices = [0, 1, 0, 1]
values = [1, 2, 3, 4]
sparse = torch.sparse_csr_tensor(torch.tensor(crow_indices, dtype=torch.int64),
torch.tensor(col_indices, dtype=torch.int64),
torch.tensor(values), dtype=torch.double)
self.assertEqual(torch.tensor(crow_indices, dtype=torch.int64), sparse.crow_indices())
self.assertEqual((len(crow_indices) - 1, max(col_indices) + 1), sparse.shape)
def test_sparse_csr_constructor(self):
crow_indices = [0, 2, 4]
col_indices = [0, 1, 0, 1]
values = [1, 2, 3, 4]
sparse = torch.sparse_csr_tensor(torch.tensor(crow_indices, dtype=torch.int32),
torch.tensor(col_indices, dtype=torch.int32),
torch.tensor(values), size=(2, 10), dtype=torch.float)
self.assertEqual((2, 10), sparse.shape)
self.assertEqual(torch.tensor(crow_indices, dtype=torch.int32), sparse.crow_indices())
def test_sparse_csr_print(self):
shape_nnz = [
((1000, 10), 10)
]
printed = []
for shape, nnz in shape_nnz:
values_shape = torch.Size((nnz,))
col_indices_shape = torch.Size((nnz,))
crow_indices_shape = torch.Size((shape[0] + 1,))
printed.append("# shape: {}".format(torch.Size(shape)))
printed.append("# nnz: {}".format(nnz))
printed.append("# crow_indices shape: {}".format(crow_indices_shape))
printed.append("# col_indices shape: {}".format(col_indices_shape))
printed.append("# values_shape: {}".format(values_shape))
x = self.gen_sparse_csr(shape, nnz)
printed.append("# sparse tensor")
printed.append(str(x))
printed.append("# _crow_indices")
printed.append(str(x.crow_indices()))
printed.append("# _col_indices")
printed.append(str(x.col_indices()))
printed.append("# _values")
printed.append(str(x.values()))
printed.append('')
self.assertEqual(len(printed) > 0, True)
def test_sparse_csr_from_dense(self):
sp = torch.tensor([[1, 2], [3, 4]]).to_sparse_csr()
self.assertEqual(torch.tensor([0, 2, 4], dtype=torch.int64), sp.crow_indices())
self.assertEqual(torch.tensor([0, 1, 0, 1], dtype=torch.int64), sp.col_indices())
self.assertEqual(torch.tensor([1, 2, 3, 4], dtype=torch.int64), sp.values())
dense = torch.tensor([[4, 5, 0], [0, 0, 0], [1, 0, 0]])
sparse = dense.to_sparse_csr()
self.assertEqual(torch.tensor([0, 2, 2, 3], dtype=torch.int64), sparse.crow_indices())
self.assertEqual(torch.tensor([0, 1, 0], dtype=torch.int64), sparse.col_indices())
self.assertEqual(torch.tensor([4, 5, 1]), sparse.values())
dense = torch.tensor([[0, 0, 0], [0, 0, 1], [1, 0, 0]])
sparse = dense.to_sparse_csr()
self.assertEqual(torch.tensor([0, 0, 1, 2], dtype=torch.int64), sparse.crow_indices())
self.assertEqual(torch.tensor([2, 0], dtype=torch.int64), sparse.col_indices())
self.assertEqual(torch.tensor([1, 1]), sparse.values())
dense = torch.tensor([[2, 2, 2], [2, 2, 2], [2, 2, 2]])
sparse = dense.to_sparse_csr()
self.assertEqual(torch.tensor([0, 3, 6, 9], dtype=torch.int64), sparse.crow_indices())
self.assertEqual(torch.tensor([0, 1, 2] * 3, dtype=torch.int64), sparse.col_indices())
self.assertEqual(torch.tensor([2] * 9), sparse.values())
def test_dense_convert(self):
size = (5, 5)
dense = torch.randn(size)
sparse = dense.to_sparse_csr()
self.assertEqual(sparse.to_dense(), dense)
size = (4, 6)
dense = torch.randn(size)
sparse = dense.to_sparse_csr()
self.assertEqual(sparse.to_dense(), dense)
crow_indices = torch.tensor([0, 3, 5])
col_indices = torch.tensor([0, 1, 2, 0, 1])
values = torch.tensor([1, 2, 1, 3, 4])
csr = torch.sparse_csr_tensor(crow_indices, col_indices,
values, dtype=torch.double)
dense = torch.tensor([[1, 2, 1], [3, 4, 0]], dtype=torch.double)
self.assertEqual(csr.to_dense(), dense)
def test_coo_to_csr_convert(self):
size = (5, 5)
dense = torch.randn(size)
sparse_coo = dense.to_sparse()
sparse_csr = sparse_coo.to_sparse_csr()
self.assertTrue(sparse_csr.is_sparse_csr)
self.assertEqual(sparse_csr.to_dense(), dense)
vec = torch.randn((5, 1))
coo_product = sparse_coo.matmul(vec)
csr_product = sparse_csr.matmul(vec)
self.assertEqual(coo_product, csr_product)
vec = torch.randn((100, 1))
index = self.index_tensor([
[1, 0, 35, 14, 39, 6, 71, 66, 40, 27],
[92, 31, 62, 50, 22, 65, 89, 74, 56, 34],
])
values = self.value_tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
coo = torch.sparse_coo_tensor(index, values, torch.Size([100, 100]))
csr = coo.to_sparse_csr()
self.assertEqual(coo.matmul(vec), csr.matmul(vec))
def test_mkl_matvec_warnings(self):
if torch.has_mkl:
sp = torch.sparse_csr_tensor(torch.tensor([0, 2, 4]),
torch.tensor([0, 1, 0, 1]),
torch.tensor([1, 2, 3, 4], dtype=torch.double))
vec = torch.randn((2, 1))
with warnings.catch_warnings(record=True) as w:
sp.matmul(vec)
self.assertEqual(len(w), 2)
def test_dense_convert_error(self):
size = (4, 2, 4)
dense = torch.randn(size)
with self.assertRaisesRegex(RuntimeError, "Only 2D"):
sparse = dense.to_sparse_csr()
def test_csr_matvec(self):
side = 100
csr = self.gen_sparse_csr((side, side), 1000)
vec = torch.randn(side, dtype=torch.double)
res = csr.matmul(vec)
expected = csr.to_dense().matmul(vec)
self.assertEqual(res, expected)
bad_vec = torch.randn(side + 10, dtype=torch.double)
with self.assertRaisesRegex(RuntimeError, "mv: expected"):
csr.matmul(bad_vec)
def test_coo_csr_conversion(self):
size = (5, 5)
dense = torch.randn(size)
coo_sparse = dense.to_sparse()
csr_sparse = coo_sparse.to_sparse_csr()
self.assertEqual(csr_sparse.to_dense(), dense)
if __name__ == '__main__':
run_tests()

View File

@ -64,9 +64,10 @@ except ImportError:
# These functions require manual Python bindings or are not exposed to Python
SKIP_PYTHON_BINDINGS = [
'alias', 'contiguous', 'is_cuda', 'is_sparse', 'size', 'stride',
'alias', 'contiguous', 'is_cuda', 'is_sparse', 'is_sparse_csr', 'size', 'stride',
'.*_backward', '.*_backward_(out|input|weight|bias)', '.*_forward',
'.*_forward_out', '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*',
'_?sparse_csr_tensor.*',
'_arange.*', '_range.*', '_linspace.*', '_logspace.*',
'_sparse_add_out', '_sparse_div.*', '_sparse_mul.*', '_sparse_sub.*', '_sparse_dense_add_out',
'index', 'unique_dim_consecutive',

View File

@ -406,6 +406,14 @@ static std::vector<Tensor> dispatch_nonzero_numpy(const Tensor & self) {
static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* kwargs);
static PyObject * THPVariable_sparse_csr_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
jit::tracer::warn("torch.sparse_csr_tensor", jit::tracer::WARN_CONSTRUCTOR);
return THPVariable_Wrap(torch::utils::sparse_csr_tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs));
END_HANDLE_TH_ERRORS
}
static PyObject * THPVariable_sparse_coo_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
@ -485,6 +493,7 @@ static PyMethodDef torch_functions[] = {
{"range", castPyCFunctionWithKeywords(THPVariable_range), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"saddmm", castPyCFunctionWithKeywords(THPVariable_sspaddmm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"sparse_coo_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"sparse_csr_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_csr_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"_sparse_coo_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_coo_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"_validate_sparse_coo_tensor_args", castPyCFunctionWithKeywords(THPVariable__validate_sparse_coo_tensor_args), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"spmm", castPyCFunctionWithKeywords(THPVariable_mm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},

View File

@ -672,6 +672,7 @@ aten_cpu_source_non_codegen_list = [
"aten/src/ATen/ScalarOps.cpp",
"aten/src/ATen/SequenceNumber.cpp",
"aten/src/ATen/SparseTensorImpl.cpp",
"aten/src/ATen/SparseCsrTensorImpl.cpp",
"aten/src/ATen/SparseTensorUtils.cpp",
"aten/src/ATen/TensorGeometry.cpp",
"aten/src/ATen/TensorIndexing.cpp",
@ -720,6 +721,7 @@ aten_cpu_source_non_codegen_list = [
"aten/src/ATen/native/DispatchStub.cpp",
"aten/src/ATen/native/UpSample.cpp",
"aten/src/ATen/native/mkl/LinearAlgebra.cpp",
"aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp",
"aten/src/ATen/native/mkl/SpectralOps.cpp",
"aten/src/ATen/native/mkldnn/BinaryOps.cpp",
"aten/src/ATen/native/mkldnn/Conv.cpp",
@ -967,7 +969,9 @@ aten_native_source_non_codegen_list = [
"aten/src/ATen/native/sparse/SoftMax.cpp",
"aten/src/ATen/native/sparse/SparseMatMul.cpp",
"aten/src/ATen/native/sparse/SparseTensor.cpp",
"aten/src/ATen/native/sparse/SparseCsrTensor.cpp",
"aten/src/ATen/native/sparse/SparseTensorMath.cpp",
"aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp",
"aten/src/TH/THAllocator.cpp",
"aten/src/TH/THBlas.cpp",
"aten/src/TH/THGeneral.cpp",

View File

@ -857,9 +857,11 @@ def main() -> None:
dispatch_keys = [
DispatchKey.CPU,
DispatchKey.SparseCPU,
DispatchKey.SparseCsrCPU,
DispatchKey.MkldnnCPU,
DispatchKey.CUDA,
DispatchKey.SparseCUDA,
DispatchKey.SparseCsrCUDA,
DispatchKey.QuantizedCPU,
DispatchKey.QuantizedCUDA,
DispatchKey.CompositeImplicitAutograd,

View File

@ -74,6 +74,8 @@ class DispatchKey(Enum):
MkldnnCPU = auto()
SparseCPU = auto()
SparseCUDA = auto()
SparseCsrCPU = auto()
SparseCsrCUDA = auto()
SparseHIP = auto()
SparseXPU = auto()
NestedTensor = auto()

View File

@ -294,6 +294,10 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],'
' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,'
' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
'sparse_csr_tensor' : ['def sparse_csr_tensor(crow_indices: Tensor, col_indices: Tensor,'
' values: Tensor, size: Optional[_size]=None,'
' *, dtype: Optional[_dtype]=None,'
' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
'_sparse_coo_tensor_unsafe': ['def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],'
' dtype: Optional[_dtype] = None, device: Optional[_device] = None,'
' requires_grad: bool = False) -> Tensor: ...'],
@ -446,6 +450,7 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
'is_cuda': ['is_cuda: _bool'],
'is_leaf': ['is_leaf: _bool'],
'is_sparse': ['is_sparse: _bool'],
'is_sparse_csr' : ['is_sparse_csr: _bool'],
'is_quantized': ['is_quantized: _bool'],
'is_meta': ['is_meta: _bool'],
'is_mkldnn': ['is_mkldnn: _bool'],

View File

@ -908,6 +908,40 @@ class Tensor(torch._C._TensorBase):
# See Note [rename_ / rename API]
return update_names(self, names, rename_map, inplace=False)
def to_sparse_csr(self):
""" Convert a tensor to compressed row storage format. Only works with 2D tensors.
Examples::
>>> dense = torch.randn(5, 5)
>>> sparse = dense.to_sparse_csr()
>>> sparse._nnz()
25
"""
shape = self.size()
fill_value = 0
if len(shape) != 2:
raise RuntimeError("Only 2D tensors can be converted to the CSR format but got shape: ", shape)
if self.is_sparse:
coalesced_self = self.coalesce()
row_indices = coalesced_self.indices()[0]
ro = [0]
i = 0
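            # Build the compressed row pointer: for each row, advance i past that
            # row's entries (row_indices is sorted after coalesce()) and record the
            # running count, so len(ro) == shape[0] + 1 and ro[-1] == nnz.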
for irow in range(self.shape[0]):
while i < row_indices.size()[0] and row_indices[i] == irow:
i += 1
ro.append(i)
return torch.sparse_csr_tensor(torch.tensor(ro, dtype=row_indices.dtype),
coalesced_self.indices()[1], coalesced_self.values(),
size=coalesced_self.shape, dtype=coalesced_self.dtype)
elif self.is_sparse_csr:
return self
else:
return self.to_sparse().to_sparse_csr()
def _update_names(self, names, inplace):
if has_torch_function_unary(self):
return handle_torch_function(Tensor._update_names, (self,), self, names, inplace)

View File

@ -4651,6 +4651,11 @@ add_docstr_all('is_sparse',
Is ``True`` if the Tensor uses sparse storage layout, ``False`` otherwise.
""")
add_docstr_all('is_sparse_csr',
r"""
Is ``True`` if the Tensor uses sparse CSR storage layout, ``False`` otherwise.
""")
add_docstr_all('device',
r"""
Is the :class:`torch.device` where this Tensor is.
@ -4711,3 +4716,39 @@ Makes a ``cls`` instance with the same data pointer as ``self``. Changes
in the output mirror changes in ``self``, and the output stays attached
to the autograd graph. ``cls`` must be a subclass of ``Tensor``.
""")
add_docstr_all('crow_indices',
r"""
crow_indices() -> IntTensor
Returns the tensor containing the compressed row indices of the :attr:`self`
tensor when :attr:`self` is a sparse CSR tensor of layout ``sparse_csr``.
The ``crow_indices`` tensor is strictly of shape (:attr:`self`.size(0) + 1)
and of type ``int32`` or ``int64``. When using MKL routines such as sparse
matrix multiplication, it is necessary to use ``int32`` indexing in order
to avoid downcasting and potentially losing information.
Example::
>>> csr = torch.eye(5,5).to_sparse_csr()
>>> csr.crow_indices()
tensor([0, 1, 2, 3, 4, 5], dtype=torch.int32)
""")
add_docstr_all('col_indices',
r"""
col_indices() -> IntTensor
Returns the tensor containing the column indices of the :attr:`self`
tensor when :attr:`self` is a sparse CSR tensor of layout ``sparse_csr``.
The ``col_indices`` tensor is strictly of shape (:attr:`self`.nnz())
and of type ``int32`` or ``int64``. When using MKL routines such as sparse
matrix multiplication, it is necessary to use ``int32`` indexing in order
to avoid downcasting and potentially losing information.
Example::
>>> csr = torch.eye(5,5).to_sparse_csr()
>>> csr.col_indices()
tensor([0, 1, 2, 3, 4], dtype=torch.int32)
""")

View File

@ -315,6 +315,29 @@ def _str_intern(inp):
if values.numel() == 0:
values_str += ', size=' + str(tuple(values.shape))
tensor_str = indices_prefix + indices_str + '),\n' + ' ' * indent + values_prefix + values_str + ')'
elif self.is_sparse_csr:
suffixes.append('size=' + str(tuple(self.shape)))
suffixes.append('nnz=' + str(self._nnz()))
if not has_default_dtype:
suffixes.append('dtype=' + str(self.dtype))
crow_indices_prefix = 'crow_indices=tensor('
crow_indices = self.crow_indices().detach()
crow_indices_str = _tensor_str(crow_indices, indent + len(crow_indices_prefix))
if crow_indices.numel() == 0:
crow_indices_str += ', size=' + str(tuple(crow_indices.shape))
col_indices_prefix = 'col_indices=tensor('
col_indices = self.col_indices().detach()
col_indices_str = _tensor_str(col_indices, indent + len(col_indices_prefix))
if col_indices.numel() == 0:
col_indices_str += ', size=' + str(tuple(col_indices.shape))
values_prefix = 'values=tensor('
values = self.values().detach()
values_str = _tensor_str(values, indent + len(values_prefix))
if values.numel() == 0:
values_str += ', size=' + str(tuple(values.shape))
tensor_str = crow_indices_prefix + crow_indices_str + '),\n' + ' ' * indent +\
col_indices_prefix + col_indices_str + '),\n' + ' ' * indent +\
values_prefix + values_str + ')'
elif self.is_quantized:
suffixes.append('size=' + str(tuple(self.shape)))
if not has_default_dtype:

View File

@ -7880,6 +7880,49 @@ Example::
[-0.0881, 0.4370, 0.2275, 1.0284]])
""".format(**common_args))
add_docstr(torch.sparse_csr_tensor,
r"""
sparse_csr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor
Constructs a :ref:`sparse tensor in CSR (Compressed Sparse Row) format <sparse-csr-docs>` with specified
values at the given :attr:`crow_indices` and :attr:`col_indices`. Sparse matrix multiplication operations
in CSR format are typically faster than those for sparse tensors in COO format. Make sure you have a look
at :ref:`the note on the data type of the indices <sparse-csr-docs>`.
Args:
    crow_indices (array_like): One-dimensional array of size ``size[0] + 1``. The last element
        is the number of non-zeros. This tensor encodes the index in ``values`` and ``col_indices``
        at which each row starts; the difference between two consecutive entries is the
        number of elements in the corresponding row.
    col_indices (array_like): Column coordinates of each element in ``values``. Strictly a
        one-dimensional tensor with the same length as ``values``.
    values (array_like): Initial values for the tensor. Can be a list, tuple, NumPy ``ndarray``, scalar,
        and other types.
size (list, tuple, :class:`torch.Size`, optional): Size of the sparse tensor. If not provided, the
size will be inferred as the minimum size big enough to hold all non-zero elements.
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
Default: if None, infers data type from :attr:`values`.
device (:class:`torch.device`, optional): the desired device of returned tensor.
Default: if None, uses the current device for the default tensor type
(see :func:`torch.set_default_tensor_type`). :attr:`device` will be the CPU
for CPU tensor types and the current CUDA device for CUDA tensor types.
{requires_grad}
Example ::
>>> crow_indices = [0, 2, 4]
>>> col_indices = [0, 1, 0, 1]
>>> values = [1, 2, 3, 4]
>>> torch.sparse_csr_tensor(torch.tensor(crow_indices, dtype=torch.int64),
... torch.tensor(col_indices, dtype=torch.int64),
... torch.tensor(values), dtype=torch.double)
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
dtype=torch.float64, layout=torch.sparse_csr)
""".format(**factory_common_args))
add_docstr(torch.sparse_coo_tensor,
r"""
sparse_coo_tensor(indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor

View File

@ -599,6 +599,17 @@ PyObject *THPVariable_is_sparse(THPVariable *self, void *unused)
END_HANDLE_TH_ERRORS
}
PyObject *THPVariable_is_sparse_csr(THPVariable *self, void *unused)
{
HANDLE_TH_ERRORS
if (check_has_torch_function((PyObject *)self)) {
return handle_torch_function_getter(self, "is_sparse_csr");
}
auto& self_ = self->cdata;
return torch::autograd::utils::wrap(self_.is_sparse_csr());
END_HANDLE_TH_ERRORS
}
PyObject *THPVariable_is_mkldnn(THPVariable *self, void *unused)
{
HANDLE_TH_ERRORS
@ -761,6 +772,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
{"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
{"is_xpu", (getter)THPVariable_is_xpu, nullptr, nullptr, nullptr},
{"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
{"is_sparse_csr", (getter)THPVariable_is_sparse_csr, nullptr, nullptr, nullptr},
{"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr},
{"is_mlc", (getter)THPVariable_is_mlc, nullptr, nullptr, nullptr},
{"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr},

View File

@ -103,25 +103,16 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
static const PropertiesLookup builtin_properties = {
{TypeKind::TensorType,
{
{"dtype", "prim"},
{"device", "prim"},
{"grad", "prim"},
{"data", "prim"},
{"shape", "prim"},
{"is_cuda", "prim"},
{"is_xpu", "prim"},
{"is_sparse", "prim"},
{"is_mkldnn", "prim"},
{"is_mlc", "prim"},
{"is_quantized", "prim"},
{"is_vulkan", "prim"},
{"is_meta", "prim"},
{"is_leaf", "aten"},
{"requires_grad", "prim"},
{"layout", "prim"},
{"T", "prim"},
{"ndim", "prim"},
{"name", "prim"},
{"dtype", "prim"}, {"device", "prim"},
{"grad", "prim"}, {"data", "prim"},
{"shape", "prim"}, {"is_cuda", "prim"},
{"is_xpu", "prim"}, {"is_sparse", "prim"},
{"is_sparse_csr", "prim"}, {"is_mkldnn", "prim"},
{"is_mlc", "prim"}, {"is_quantized", "prim"},
{"is_vulkan", "prim"}, {"is_meta", "prim"},
{"is_leaf", "aten"}, {"requires_grad", "prim"},
{"layout", "prim"}, {"T", "prim"},
{"ndim", "prim"}, {"name", "prim"},
}},
{TypeKind::DeviceObjType, {{"type", "prim"}, {"index", "prim"}}}};
auto kind = value_->type()->kind();

View File

@ -281,6 +281,14 @@ RegisterOperators reg(
push(stack, a.is_sparse());
},
aliasAnalysisFromSchema()),
Operator(
"prim::is_sparse_csr(Tensor a) -> bool",
[](Stack* stack) {
at::Tensor a;
pop(stack, a);
push(stack, a.is_sparse_csr());
},
aliasAnalysisFromSchema()),
Operator(
"prim::is_mkldnn(Tensor a) -> bool",
[](Stack* stack) {

View File

@ -123,6 +123,14 @@ PyObject *Tensor_is_sparse(PyTensorType *self, void *unused) {
}
}
PyObject *Tensor_is_sparse_csr(PyTensorType *self, void *unused) {
if (self->layout->layout == at::Layout::SparseCsr) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
}
static struct PyMethodDef metaclass_methods[] = {
{"__instancecheck__", Tensor_instancecheck, METH_O, nullptr},
{nullptr}
@ -135,6 +143,7 @@ static struct PyGetSetDef metaclass_properties[] = {
{"layout", (getter)Tensor_layout, nullptr, nullptr, nullptr},
{"is_cuda", (getter)Tensor_is_cuda, nullptr, nullptr, nullptr},
{"is_sparse", (getter)Tensor_is_sparse, nullptr, nullptr, nullptr},
{"is_sparse_csr",(getter)Tensor_is_sparse_csr, nullptr, nullptr, nullptr},
{nullptr}
};

View File

@ -1,4 +1,3 @@
#include <torch/csrc/utils/tensor_layouts.h>
#include <ATen/Layout.h>
#include <c10/core/ScalarType.h>
#include <torch/csrc/DynamicTypes.h>
@ -6,6 +5,7 @@
#include <torch/csrc/Layout.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/tensor_layouts.h>
namespace torch { namespace utils {
@ -13,21 +13,29 @@ void initializeLayouts() {
auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
if (!torch_module) throw python_error();
PyObject *strided_layout = THPLayout_New(at::Layout::Strided, "torch.strided");
PyObject* strided_layout = THPLayout_New(at::Layout::Strided, "torch.strided");
Py_INCREF(strided_layout);
if (PyModule_AddObject(torch_module, "strided", strided_layout) != 0) {
throw python_error();
}
registerLayoutObject((THPLayout*)strided_layout, at::Layout::Strided);
PyObject *sparse_coo_layout = THPLayout_New(at::Layout::Sparse, "torch.sparse_coo");
PyObject* sparse_coo_layout = THPLayout_New(at::Layout::Sparse, "torch.sparse_coo");
Py_INCREF(sparse_coo_layout);
if (PyModule_AddObject(torch_module, "sparse_coo", sparse_coo_layout) != 0) {
throw python_error();
}
registerLayoutObject((THPLayout*)sparse_coo_layout, at::Layout::Sparse);
PyObject *mkldnn_layout = THPLayout_New(at::Layout::Mkldnn, "torch._mkldnn");
PyObject* sparse_csr_layout =
THPLayout_New(at::Layout::SparseCsr, "torch.sparse_csr");
Py_INCREF(sparse_csr_layout);
if (PyModule_AddObject(torch_module, "sparse_csr", sparse_csr_layout) != 0) {
throw python_error();
}
registerLayoutObject((THPLayout*)sparse_csr_layout, at::Layout::SparseCsr);
PyObject* mkldnn_layout = THPLayout_New(at::Layout::Mkldnn, "torch._mkldnn");
Py_INCREF(mkldnn_layout);
if (PyModule_AddObject(torch_module, "_mkldnn", mkldnn_layout) != 0) {
throw python_error();

View File

@ -34,6 +34,7 @@ using at::IntArrayRef;
using at::kCPU;
using at::kCUDA;
using at::kLong;
using at::kInt;
using at::Scalar;
using at::ScalarType;
using at::Storage;
@ -597,6 +598,66 @@ Tensor indexing_tensor_from_data(
}
}
Tensor sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key)));
static PythonArgParser parser({
"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
});
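  // Both overloads take (crow_indices, col_indices, values); the first also takes an
  // explicit size, while the second lets at::sparse_csr_tensor infer the size from
  // the indices.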
const int NUM_ARGS = 9, CROW_INDICES_ARG = 0, COL_INDICES_ARG = 1, VALUES_ARG = 2;
ParsedArgs<NUM_ARGS> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
THPObjectPtr crow_indices_dtype_attr(PyObject_GetAttrString(r.pyobject(CROW_INDICES_ARG), "dtype"));
THPObjectPtr col_indices_dtype_attr(PyObject_GetAttrString(r.pyobject(COL_INDICES_ARG), "dtype"));
at::ScalarType crow_indices_scalar_type = reinterpret_cast<THPDtype*>(
crow_indices_dtype_attr.get())->scalar_type;
at::ScalarType col_indices_scalar_type = reinterpret_cast<THPDtype*>(
col_indices_dtype_attr.get())->scalar_type;
if (r.idx == 0) {
const int SIZE_ARRAY_ARG = 3, TYPE_INFERENCE_ARG = 4, DEVICE_TYPE_ARG = 7, REQ_GRAD_ARG = 8;
bool type_inference = r.isNone(TYPE_INFERENCE_ARG);
const auto inferred_options = typeIdWithDefault(r, DEVICE_TYPE_ARG, dispatch_key);
const auto inferred_scalar_type = r.scalartypeWithDefault(TYPE_INFERENCE_ARG, scalar_type);
at::OptionalDeviceGuard device_guard(r.deviceOptional(DEVICE_TYPE_ARG));
Tensor values = internal_new_from_data(inferred_options, inferred_scalar_type, r.deviceOptional(DEVICE_TYPE_ARG),
r.pyobject(VALUES_ARG), /*copy_variables=*/false, /*copy_numpy=*/true,
/*type_inference=*/type_inference);
Tensor crow_indices = internal_new_from_data(values.options(),
crow_indices_scalar_type, r.deviceOptional(DEVICE_TYPE_ARG), r.pyobject(CROW_INDICES_ARG),
/*copy_variables=*/false, /*copy_numpy=*/true,
/*type_inference=*/false);
Tensor col_indices = internal_new_from_data(values.options(),
col_indices_scalar_type, r.deviceOptional(DEVICE_TYPE_ARG), r.pyobject(COL_INDICES_ARG),
/*copy_variables=*/false, /*copy_numpy=*/true,
/*type_inference=*/false);
return at::sparse_csr_tensor(crow_indices, col_indices, values, r.intlist(SIZE_ARRAY_ARG),
values.options().layout(at::kSparseCsr)).set_requires_grad(r.toBool(REQ_GRAD_ARG));
} else if (r.idx == 1) {
const int TYPE_INFERENCE_ARG = 3, DEVICE_TYPE_ARG = 6, REQ_GRAD_ARG = 7;
bool type_inference = r.isNone(TYPE_INFERENCE_ARG);
const auto inferred_options = typeIdWithDefault(r, DEVICE_TYPE_ARG, dispatch_key);
const auto inferred_scalar_type = r.scalartypeWithDefault(TYPE_INFERENCE_ARG, scalar_type);
at::OptionalDeviceGuard device_guard(r.deviceOptional(DEVICE_TYPE_ARG));
Tensor values = internal_new_from_data(inferred_options, inferred_scalar_type, r.deviceOptional(DEVICE_TYPE_ARG),
r.pyobject(VALUES_ARG), /*copy_variables=*/false, /*copy_numpy=*/true,
/*type_inference=*/type_inference);
Tensor crow_indices = internal_new_from_data(values.options(),
crow_indices_scalar_type, r.deviceOptional(DEVICE_TYPE_ARG),
r.pyobject(CROW_INDICES_ARG), /*copy_variables=*/false, /*copy_numpy=*/true,
/*type_inference=*/false);
Tensor col_indices = internal_new_from_data(values.options(), col_indices_scalar_type, r.deviceOptional(DEVICE_TYPE_ARG),
r.pyobject(COL_INDICES_ARG), /*copy_variables=*/false, /*copy_numpy=*/true,
/*type_inference=*/false);
return at::sparse_csr_tensor(crow_indices, col_indices, values,
values.options().layout(at::kSparseCsr)).set_requires_grad(r.toBool(REQ_GRAD_ARG));
}
throw std::runtime_error("sparse_csr_tensor(): invalid arguments");
}
// Note [Ensuring sparse values and indices match devices]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// In all places where we construct indices, we read out options from values

View File

@ -13,6 +13,7 @@ at::Tensor indexing_tensor_from_data(
at::ScalarType scalar_type,
c10::optional<at::Device> device,
PyObject* data);
at::Tensor sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
at::Tensor sparse_coo_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
at::Tensor _sparse_coo_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
void _validate_sparse_coo_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);

View File

@ -160,6 +160,7 @@ def get_ignored_functions() -> Set[Callable]:
torch.result_type,
torch.scalar_tensor,
torch.sparse_coo_tensor,
torch.sparse_csr_tensor,
torch.tril_indices,
torch.triu_indices,
torch.vander,
@ -216,6 +217,7 @@ def get_ignored_functions() -> Set[Callable]:
Tensor._make_subclass,
Tensor.stride,
Tensor.unflatten,
Tensor.to_sparse_csr,
Tensor._reduce_ex_internal,
}
@ -935,6 +937,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
Tensor.is_mkldnn.__get__: lambda self: -1,
Tensor.is_quantized.__get__: lambda self: -1,
Tensor.is_sparse.__get__: lambda self: -1,
Tensor.is_sparse_csr.__get__: lambda self: -1,
Tensor.is_vulkan.__get__: lambda self: -1,
Tensor.layout.__get__: lambda self: -1,
Tensor.name.__get__: lambda self: -1,
@ -954,6 +957,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
Tensor._indices: lambda self: -1,
Tensor._is_view: lambda self: -1,
Tensor._nnz: lambda self: -1,
Tensor.crow_indices: lambda self: -1,
Tensor.col_indices: lambda self: -1,
Tensor._update_names: lambda self, names, inplace: -1,
Tensor._values: lambda self: -1,
Tensor.align_as: lambda self, other: -1,