[ATen][CPU][Sparse] Use Third-Party Eigen for sparse add and addmm (#155357)

This pull request adds the following ops for sparse matrices using the Eigen library:
```python
    add(a_csr, b_csr)
    add(a_csc, b_csc)

    addmm(c_csr, a_csr, b_csr)
    addmm(c_csr, a_csr, b_csc)
    addmm(c_csr, a_csc, b_csc)
    addmm(c_csr, a_csc, b_csr)

    addmm(c_csc, a_csr, b_csr)
    addmm(c_csc, a_csr, b_csc)
    addmm(c_csc, a_csc, b_csc)
    addmm(c_csc, a_csc, b_csr)
```
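
As a usage sketch (not part of this PR; the tensor values are illustrative and assume a CPU build where either MKL or the new Eigen backend is enabled), the ops above map onto the Python API roughly as follows:
```python
import torch

a = torch.tensor([[1., 0.], [0., 2.]]).to_sparse_csr()
b = torch.tensor([[0., 3.], [4., 0.]]).to_sparse_csr()
c = torch.eye(2).to_sparse_csr()

s_csr = torch.add(a, b)                                  # add(a_csr, b_csr) -> CSR
s_csc = torch.add(a.to_sparse_csc(), b.to_sparse_csc())  # add(a_csc, b_csc) -> CSC (Eigen path)
m_csr = torch.addmm(c, a, b)                             # addmm(c_csr, a_csr, b_csr) -> CSR

print(s_csr.to_dense())
print(m_csr.to_dense())
```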

Currently, sparse-matrix operations on CPU are available only through MKL. Because MKL does not exist on `aarch64`, these ops are unavailable on ARM-based machines, including Apple Silicon, AWS Graviton and NVIDIA Grace. This PR addresses the issue by using Eigen as a backend for the ops above.

This is a refactored version of my previous PR #101814. The main difference from the old one is that this PR does not enable Eigen by default.
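
Since the backend is opt-in, here is a hedged sketch of how one might enable and detect it. The CMake option `USE_EIGEN_SPARSE` and the private `torch._C._has_eigen_sparse` flag come from the diff below; passing the option through the environment follows the usual `USE_*` build convention and is an assumption, not part of this PR:
```python
# Build-time (assumption: PyTorch's convention of forwarding USE_* env vars to CMake):
#   USE_EIGEN_SPARSE=1 python setup.py develop
#
# Run-time check via the private flag added in this PR:
import torch

if getattr(torch._C, "_has_eigen_sparse", False):
    print("sparse add/addmm will use the Eigen backend on this CPU build")
else:
    print("Eigen sparse backend not compiled in (note: MKL builds disable it)")
```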
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155357
Approved by: https://github.com/pearu, https://github.com/eqy
Author: Aidyn-A
Date: 2025-08-20 15:44:54 +00:00
Committed by: PyTorch MergeBot
Parent: 90ea9ccefe
Commit: ce048de608

13 changed files with 423 additions and 8 deletions

@@ -279,6 +279,7 @@ header_template_rule(
"@AT_BLAS_F2C@": "0",
"@AT_BLAS_USE_CBLAS_DOT@": "1",
"@AT_KLEIDIAI_ENABLED@": "0",
"@AT_USE_EIGEN_SPARSE@": "0",
},
)

@@ -289,6 +289,7 @@ option(USE_PRECOMPILED_HEADERS "Use pre-compiled headers to accelerate build."
option(USE_PROF "Use profiling" OFF)
option(USE_PYTORCH_QNNPACK "Use ATen/QNNPACK (quantized 8-bit operators)" ON)
option(USE_SNPE "Use Qualcomm's SNPE library" OFF)
option(USE_EIGEN_SPARSE "Use Eigen Sparse Matrices" OFF)
option(USE_SYSTEM_EIGEN_INSTALL
"Use system Eigen instead of the one under third_party" OFF)
cmake_dependent_option(

@@ -96,6 +96,8 @@ file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp")
file(GLOB vulkan_cpp "vulkan/*.cpp")
file(GLOB native_vulkan_cpp "native/vulkan/*.cpp" "native/vulkan/api/*.cpp" "native/vulkan/impl/*.cpp" "native/vulkan/ops/*.cpp")
file(GLOB native_eigen_cpp "native/sparse/eigen/*.cpp")
# Metal
file(GLOB metal_h "metal/*.h")
file(GLOB metal_cpp "metal/*.cpp")
@@ -341,6 +343,9 @@ if(USE_VULKAN)
else()
set(all_cpu_cpp ${all_cpu_cpp} ${vulkan_cpp})
endif()
if(USE_EIGEN_SPARSE)
set(all_cpu_cpp ${all_cpu_cpp} ${native_eigen_cpp})
endif()
if(USE_MTIA)
set(ATen_MTIA_SRCS ${ATen_MTIA_SRCS} ${mtia_cpp} ${mtia_h} ${native_mtia_cpp} ${native_mtia_h})

@@ -20,3 +20,4 @@
#define AT_BLAS_F2C() @AT_BLAS_F2C@
#define AT_BLAS_USE_CBLAS_DOT() @AT_BLAS_USE_CBLAS_DOT@
#define AT_KLEIDIAI_ENABLED() @AT_KLEIDIAI_ENABLED@
#define AT_USE_EIGEN_SPARSE() @AT_USE_EIGEN_SPARSE@

@@ -698,6 +698,14 @@ bool Context::hasLAPACK() {
#endif
}
bool Context::hasEigenSparse() {
#if AT_USE_EIGEN_SPARSE()
return true;
#else
return false;
#endif
}
at::QEngine Context::qEngine() const {
static auto _quantized_engine = []() {
at::QEngine qengine = at::kNoQEngine;

@@ -133,6 +133,7 @@ class TORCH_API Context {
static bool hasLAPACK();
static bool hasMKLDNN();
static bool ckSupported();
static bool hasEigenSparse();
static bool hasMAGMA() {
return detail::getCUDAHooks().hasMAGMA();
}
@@ -615,6 +616,10 @@ inline bool hasLAPACK() {
return globalContext().hasLAPACK();
}
inline bool hasEigenSparse() {
return globalContext().hasEigenSparse();
}
inline bool hasMAGMA() {
return globalContext().hasMAGMA();
}

@@ -23,6 +23,9 @@
#include <ATen/Parallel.h>
#endif
#if AT_USE_EIGEN_SPARSE()
#include <ATen/native/sparse/eigen/SparseBlasImpl.h>
#endif
namespace at::native::sparse::impl {
@@ -442,13 +445,15 @@ void add_out_sparse_csr(
const Tensor& mat2,
const Scalar& alpha,
const Tensor& result) {
-#if !AT_MKL_ENABLED()
-TORCH_CHECK(
-false,
-"Calling add on a sparse CPU tensor requires compiling PyTorch with MKL. ",
-"Please use PyTorch built MKL support.");
-#else
+#if AT_USE_MKL_SPARSE()
 sparse::impl::mkl::add_out_sparse_csr(mat1, mat2, alpha, result);
+#elif AT_USE_EIGEN_SPARSE()
+sparse::impl::eigen::add_out_sparse(mat1, mat2, alpha, result);
+#else
+TORCH_CHECK(
+false,
+"Calling add on a sparse CPU tensor requires compiling PyTorch with MKL. ",
+"Please use PyTorch built MKL support.");
 #endif
}
@@ -459,7 +464,7 @@ void triangular_solve_out_sparse_csr(
bool upper,
bool transpose,
bool unitriangular) {
-#if !AT_MKL_ENABLED()
+#if !AT_USE_MKL_SPARSE()
TORCH_CHECK(
false,
"Calling triangular_solve on a sparse CPU tensor requires compiling PyTorch with MKL. ",

@@ -127,6 +127,10 @@
#include <ATen/ops/zeros_like.h>
#endif
#if AT_USE_EIGEN_SPARSE()
#include <ATen/native/sparse/eigen/SparseBlasImpl.h>
#endif
#include <algorithm>
namespace at {
@@ -536,7 +540,12 @@ static void addmm_out_sparse_csr_native_cpu(
auto values = sparse.values();
scalar_t cast_alpha = alpha.to<scalar_t>();
-r.mul_(beta);
+// If beta is zero NaN and Inf should not be propagated to the result
+if (beta.toComplexDouble() == 0.) {
+r.zero_();
+} else {
+r.mul_(beta);
+}
AT_DISPATCH_INDEX_TYPES(
col_indices.scalar_type(), "csr_mm_crow_indices", [&]() {
auto csr_accessor = csr.accessor<index_t, 1>();
@@ -648,6 +657,15 @@ Tensor& addmm_out_sparse_compressed_cpu(
return result;
}
#if AT_USE_EIGEN_SPARSE()
if ((result.layout() == kSparseCsr || result.layout() == kSparseCsc) &&
(mat1.layout() == kSparseCsr || mat1.layout() == kSparseCsc) &&
(mat2.layout() == kSparseCsr || mat2.layout() == kSparseCsc)) {
sparse::impl::eigen::addmm_out_sparse(mat1, mat2, result, alpha, beta);
return result;
}
#endif
#if !AT_USE_MKL_SPARSE()
// The custom impl addmm_out_sparse_csr_native_cpu only supports CSR @
// strided -> strided

@@ -0,0 +1,329 @@
#include <ATen/native/sparse/eigen/SparseBlasImpl.h>
#if AT_USE_EIGEN_SPARSE()
#include <ATen/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/SparseCsrTensorUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty_like.h>
#endif
#include <c10/core/ScalarType.h>
#include <Eigen/SparseCore>
namespace at::native::sparse::impl::eigen {
namespace {
void inline sparse_indices_to_result_dtype_inplace(
const c10::ScalarType& dtype,
const at::Tensor& input) {
auto [compressed_indices, plain_indices] =
at::sparse_csr::getCompressedPlainIndices(input);
static_cast<at::SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
->set_member_tensors(
compressed_indices.to(dtype),
plain_indices.to(dtype),
input.values(),
input.sizes());
}
void inline sparse_indices_and_values_resize(
const at::Tensor& input,
int64_t nnz) {
auto [compressed_indices, plain_indices] =
at::sparse_csr::getCompressedPlainIndices(input);
static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
->set_member_tensors(
compressed_indices,
plain_indices.resize_({nnz}),
input.values().resize_({nnz}),
input.sizes());
}
template <typename scalar_t, int eigen_options, typename index_t>
const Eigen::Map<Eigen::SparseMatrix<scalar_t, eigen_options, index_t>>
Tensor_to_Eigen(const at::Tensor& tensor) {
int64_t rows = tensor.size(0);
int64_t cols = tensor.size(1);
int64_t nnz = tensor._nnz();
TORCH_CHECK(tensor.values().is_contiguous(), "eigen accepts only contiguous tensor values");
auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(tensor);
index_t* c_indices_ptr = compressed_indices.data_ptr<index_t>();
index_t* p_indices_ptr = plain_indices.data_ptr<index_t>();
scalar_t* values_ptr = tensor.values().data_ptr<scalar_t>();
Eigen::Map<Eigen::SparseMatrix<scalar_t, eigen_options, index_t>> map(
rows, cols, nnz, c_indices_ptr, p_indices_ptr, values_ptr);
return map;
}
template <typename scalar_t, int eigen_options, typename index_t>
void Eigen_to_Tensor(
const at::Tensor& tensor,
const Eigen::SparseMatrix<scalar_t, eigen_options, index_t>& matrix) {
const Layout eigen_layout = (eigen_options == Eigen::RowMajor ? kSparseCsr : kSparseCsc);
TORCH_CHECK(
tensor.layout() == eigen_layout,
"Eigen_to_Tensor, expected tensor be ", eigen_layout, ", but got ",
tensor.layout());
int64_t nnz = matrix.nonZeros();
int64_t csize = matrix.outerSize();
sparse_indices_and_values_resize(tensor, nnz);
auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(tensor);
if (nnz > 0) {
std::memcpy(
tensor.values().mutable_data_ptr<scalar_t>(),
matrix.valuePtr(),
nnz * sizeof(scalar_t));
std::memcpy(
plain_indices.mutable_data_ptr<index_t>(),
matrix.innerIndexPtr(),
nnz * sizeof(index_t));
}
if (csize > 0) {
std::memcpy(
compressed_indices.mutable_data_ptr<index_t>(),
matrix.outerIndexPtr(),
csize * sizeof(index_t));
}
compressed_indices.mutable_data_ptr<index_t>()[csize] = nnz;
}
template <typename scalar_t>
void add_out_sparse_eigen(
const at::Tensor& mat1,
const at::Tensor& mat2,
const at::Scalar& alpha,
const at::Tensor& result) {
// empty matrices
if (mat1._nnz() == 0 && mat2._nnz() == 0) {
return;
}
if (mat2._nnz() == 0 || alpha.toComplexDouble() == 0.) {
sparse_indices_and_values_resize(result, mat1._nnz());
result.copy_(mat1);
return;
} else if (mat1._nnz() == 0) {
sparse_indices_and_values_resize(result, mat2._nnz());
result.copy_(mat2);
result.values().mul_(alpha);
return;
}
c10::ScalarType result_index_dtype = at::sparse_csr::getIndexDtype(result);
sparse_indices_to_result_dtype_inplace(result_index_dtype, mat1);
sparse_indices_to_result_dtype_inplace(result_index_dtype, mat2);
AT_DISPATCH_INDEX_TYPES(
result_index_dtype, "eigen_sparse_add", [&]() {
scalar_t _alpha = alpha.to<scalar_t>();
if (result.layout() == kSparseCsr) {
auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat1);
auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
auto mat1_mat2_eigen = (mat1_eigen + _alpha * mat2_eigen);
Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(result, mat1_mat2_eigen);
} else {
auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat1);
auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
auto mat1_mat2_eigen = (mat1_eigen + _alpha * mat2_eigen);
Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(result, mat1_mat2_eigen);
}
});
}
template <typename scalar_t>
void addmm_out_sparse_eigen(
const at::Tensor& mat1,
const at::Tensor& mat2,
const at::Tensor& result,
const at::Scalar& alpha,
const at::Scalar& beta) {
// empty matrices
if (mat1._nnz() == 0 || mat2._nnz() == 0) {
return;
}
// If beta is zero NaN and Inf should not be propagated to the result
// In addition, beta = 0 lets us enable a fast-path for result = alpha * A @ B
bool is_beta_zero = false;
if (beta.toComplexDouble() == 0.) {
is_beta_zero = true;
result.values().zero_();
} else {
result.values().mul_(beta);
}
c10::ScalarType result_index_dtype = at::sparse_csr::getIndexDtype(result);
sparse_indices_to_result_dtype_inplace(result_index_dtype, mat1);
sparse_indices_to_result_dtype_inplace(result_index_dtype, mat2);
AT_DISPATCH_INDEX_TYPES(
result_index_dtype, "eigen_sparse_mm", [&]() {
typedef Eigen::SparseMatrix<scalar_t, Eigen::RowMajor, index_t> EigenCsrMatrix;
typedef Eigen::SparseMatrix<scalar_t, Eigen::ColMajor, index_t> EigenCscMatrix;
at::Tensor mat1_mat2;
if (is_beta_zero) {
mat1_mat2 = result;
} else {
mat1_mat2 = at::empty_like(result, result.options());
}
if (mat1_mat2.layout() == kSparseCsr) {
if (mat1.layout() == kSparseCsr) {
const auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat1);
if (mat2.layout() == kSparseCsr) {
// Out_csr = M1_csr * M2_csr
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
} else {
// Out_csr = M1_csr * M2_csc
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
}
} else {
const auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat1);
if (mat2.layout() == kSparseCsr) {
// Out_csr = M1_csc * M2_csr
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
} else {
// Out_csr = M1_csc * M2_csc
// This multiplication will be computationally inefficient, as it will require
// additional conversion of the output matrix from CSC to CSR format.
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
}
}
} else {
if (mat1.layout() == kSparseCsr) {
const auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat1);
if (mat2.layout() == kSparseCsr) {
// Out_csc = M1_csr * M2_csr
// This multiplication will be computationally inefficient, as it will require
// additional conversion of the output matrix from CSR to CSC format.
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
} else {
// Out_csc = M1_csr * M2_csc
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
}
} else {
const auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat1);
if (mat2.layout() == kSparseCsr) {
// Out_csc = M1_csc * M2_csr
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
} else {
// Out_csc = M1_csc * M2_csc
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
}
}
}
if (is_beta_zero) {
result.mul_(alpha.to<scalar_t>());
} else {
result.add_(mat1_mat2, alpha.to<scalar_t>());
}
});
}
} // anonymous namespace
void addmm_out_sparse(
const at::Tensor& mat1,
const at::Tensor& mat2,
const at::Tensor& result,
const at::Scalar& alpha,
const at::Scalar& beta) {
AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(mat1.layout(), "eigen::addmm_out_sparse:mat1", [&]{});
AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(mat2.layout(), "eigen::addmm_out_sparse:mat2", [&]{});
AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(result.layout(), "eigen::addmm_out_sparse:result", [&]{});
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
result.scalar_type(), "addmm_out_sparse_eigen", [&] {
addmm_out_sparse_eigen<scalar_t>(mat1, mat2, result, alpha, beta);
});
}
void add_out_sparse(
const at::Tensor& mat1,
const at::Tensor& mat2,
const at::Scalar& alpha,
const at::Tensor& result) {
TORCH_CHECK(
(result.layout() == kSparseCsr && mat1.layout() == kSparseCsr && mat2.layout() == kSparseCsr) ||
(result.layout() == kSparseCsc && mat1.layout() == kSparseCsc && mat2.layout() == kSparseCsc),
"eigen::add_out_sparse: expected the same layout for all operands but got ",
mat1.layout(),
" + ",
mat2.layout(),
" -> ",
result.layout());
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
result.scalar_type(), "add_out_sparse_eigen", [&] {
add_out_sparse_eigen<scalar_t>(mat1, mat2, alpha, result);
});
}
} // namespace at::native::sparse::impl::eigen
#else
namespace at::native::sparse::impl::eigen {
void addmm_out_sparse(
const at::Tensor& mat1,
const at::Tensor& mat2,
const at::Tensor& result,
const at::Scalar& alpha,
const at::Scalar& beta) {
TORCH_CHECK(
false,
"eigen::addmm_out_sparse: Eigen was not enabled for ",
result.layout(),
" + ",
mat1.layout(),
" @ ",
mat2.layout());
}
void add_out_sparse(
const at::Tensor& mat1,
const at::Tensor& mat2,
const at::Scalar& alpha,
const at::Tensor& result) {
TORCH_CHECK(
false,
"eigen::add_out_sparse: Eigen was not enabled for ",
mat1.layout(),
" + ",
mat2.layout(),
" -> ",
result.layout());
}
} // namespace at::native::sparse::impl::eigen
#endif // AT_USE_EIGEN_SPARSE()

@@ -0,0 +1,29 @@
#pragma once
#include <ATen/Config.h>
#if AT_USE_EIGEN_SPARSE()
#ifndef EIGEN_MPL2_ONLY
#define EIGEN_MPL2_ONLY
#endif
#include <ATen/Tensor.h>
namespace at::native::sparse::impl::eigen {
void addmm_out_sparse(
const at::Tensor& mat1,
const at::Tensor& mat2,
const at::Tensor& result,
const at::Scalar& alpha,
const at::Scalar& beta);
void add_out_sparse(
const at::Tensor& mat1,
const at::Tensor& mat2,
const at::Scalar& alpha,
const at::Tensor& result);
} // namespace at::native::sparse::impl::eigen
#endif

@@ -153,6 +153,7 @@ set(AT_MKLDNN_ACL_ENABLED 0)
set(AT_MKLDNN_ENABLED 0)
set(AT_MKL_ENABLED 0)
set(AT_KLEIDIAI_ENABLED 0)
set(AT_USE_EIGEN_SPARSE 0)
# setting default preferred BLAS options if not already present.
if(NOT INTERN_BUILD_MOBILE)
set(BLAS "MKL" CACHE STRING "Selected BLAS library")
@@ -262,6 +263,15 @@ if(BLAS_LIBRARIES AND BLAS_CHECK_F2C)
include(cmake/BLAS_ABI.cmake)
endif()
if(USE_EIGEN_SPARSE AND BLAS_INFO STREQUAL "mkl")
message(WARNING "Disabling USE_EIGEN_SPARSE because MKL is enabled")
set(USE_EIGEN_SPARSE OFF)
endif()
if(USE_EIGEN_SPARSE)
set(AT_USE_EIGEN_SPARSE 1)
endif()
if(NOT INTERN_BUILD_MOBILE)
set(AT_MKL_SEQUENTIAL 0)
set(USE_BLAS 1)

@@ -135,6 +135,7 @@ function(caffe2_print_configuration_summary)
endif()
message(STATUS " BUILD_NVFUSER : ${BUILD_NVFUSER}")
message(STATUS " USE_EIGEN_FOR_BLAS : ${CAFFE2_USE_EIGEN_FOR_BLAS}")
message(STATUS " USE_EIGEN_FOR_SPARSE : ${USE_EIGEN_SPARSE}")
message(STATUS " USE_FBGEMM : ${USE_FBGEMM}")
message(STATUS " USE_KINETO : ${USE_KINETO}")
message(STATUS " USE_GFLAGS : ${USE_GFLAGS}")

@@ -2202,6 +2202,8 @@ Call this whenever a new thread is created in order to propagate values from
set_module_attr("_has_kleidiai", at::hasKleidiAI() ? Py_True : Py_False));
ASSERT_TRUE(
set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False));
ASSERT_TRUE(set_module_attr(
"_has_eigen_sparse", at::hasEigenSparse() ? Py_True : Py_False));
py_module.def("_valgrind_supported_platform", []() {
#if defined(USE_VALGRIND)