mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ATen][CPU][Sparse] Use Third-Party Eigen for sparse add and addmm (#155357)
This pull request adds the following ops for sparse matrices using Eigen library: ```python add(a_csr, b_csr) add(a_csc, b_csc) addmm(c_csr, a_csr, b_csr) addmm(c_csr, a_csr, b_csc) addmm(c_csr, a_csc, b_csc) addmm(c_csr, a_csc, b_csr) addmm(c_csc, a_csr, b_csr) addmm(c_csc, a_csr, b_csc) addmm(c_csc, a_csc, b_csc) addmm(c_csc, a_csc, b_csr) ``` Currently, the operations for sparse matrices on CPU are available through MKL only. The non-existence of MKL on `aarch64` causes the unavailability of these ops on any machines with ARM based CPUs, including Apple Silicon, AWS Graviton and NVIDIA Grace. This PR addresses this issue by using Eigen as a backend for the above ops. This is a re-factored version of my previous PR #101814. The main difference with the old one, this does not enable Eigen by default. Pull Request resolved: https://github.com/pytorch/pytorch/pull/155357 Approved by: https://github.com/pearu, https://github.com/eqy Co-authored-by: Eli Uriegas <eliuriegas@meta.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
4acdbb8311
commit
3e5b021f21
@ -279,6 +279,7 @@ header_template_rule(
|
||||
"@AT_BLAS_F2C@": "0",
|
||||
"@AT_BLAS_USE_CBLAS_DOT@": "1",
|
||||
"@AT_KLEIDIAI_ENABLED@": "0",
|
||||
"@AT_USE_EIGEN_SPARSE@": "0",
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -289,6 +289,7 @@ option(USE_PRECOMPILED_HEADERS "Use pre-compiled headers to accelerate build."
|
||||
option(USE_PROF "Use profiling" OFF)
|
||||
option(USE_PYTORCH_QNNPACK "Use ATen/QNNPACK (quantized 8-bit operators)" ON)
|
||||
option(USE_SNPE "Use Qualcomm's SNPE library" OFF)
|
||||
option(USE_EIGEN_SPARSE "Use Eigen Sparse Matrices" OFF)
|
||||
option(USE_SYSTEM_EIGEN_INSTALL
|
||||
"Use system Eigen instead of the one under third_party" OFF)
|
||||
cmake_dependent_option(
|
||||
|
@ -96,6 +96,8 @@ file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp")
|
||||
file(GLOB vulkan_cpp "vulkan/*.cpp")
|
||||
file(GLOB native_vulkan_cpp "native/vulkan/*.cpp" "native/vulkan/api/*.cpp" "native/vulkan/impl/*.cpp" "native/vulkan/ops/*.cpp")
|
||||
|
||||
file(GLOB native_eigen_cpp "native/sparse/eigen/*.cpp")
|
||||
|
||||
# Metal
|
||||
file(GLOB metal_h "metal/*.h")
|
||||
file(GLOB metal_cpp "metal/*.cpp")
|
||||
@ -341,6 +343,9 @@ if(USE_VULKAN)
|
||||
else()
|
||||
set(all_cpu_cpp ${all_cpu_cpp} ${vulkan_cpp})
|
||||
endif()
|
||||
if(USE_EIGEN_SPARSE)
|
||||
set(all_cpu_cpp ${all_cpu_cpp} ${native_eigen_cpp})
|
||||
endif()
|
||||
|
||||
if(USE_MTIA)
|
||||
set(ATen_MTIA_SRCS ${ATen_MTIA_SRCS} ${mtia_cpp} ${mtia_h} ${native_mtia_cpp} ${native_mtia_h})
|
||||
|
@ -20,3 +20,4 @@
|
||||
#define AT_BLAS_F2C() @AT_BLAS_F2C@
|
||||
#define AT_BLAS_USE_CBLAS_DOT() @AT_BLAS_USE_CBLAS_DOT@
|
||||
#define AT_KLEIDIAI_ENABLED() @AT_KLEIDIAI_ENABLED@
|
||||
#define AT_USE_EIGEN_SPARSE() @AT_USE_EIGEN_SPARSE@
|
||||
|
@ -698,6 +698,14 @@ bool Context::hasLAPACK() {
|
||||
#endif
|
||||
}
|
||||
|
||||
bool Context::hasEigenSparse() {
|
||||
#if AT_USE_EIGEN_SPARSE()
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
at::QEngine Context::qEngine() const {
|
||||
static auto _quantized_engine = []() {
|
||||
at::QEngine qengine = at::kNoQEngine;
|
||||
|
@ -133,6 +133,7 @@ class TORCH_API Context {
|
||||
static bool hasLAPACK();
|
||||
static bool hasMKLDNN();
|
||||
static bool ckSupported();
|
||||
static bool hasEigenSparse();
|
||||
static bool hasMAGMA() {
|
||||
return detail::getCUDAHooks().hasMAGMA();
|
||||
}
|
||||
@ -615,6 +616,10 @@ inline bool hasLAPACK() {
|
||||
return globalContext().hasLAPACK();
|
||||
}
|
||||
|
||||
inline bool hasEigenSparse() {
|
||||
return globalContext().hasEigenSparse();
|
||||
}
|
||||
|
||||
inline bool hasMAGMA() {
|
||||
return globalContext().hasMAGMA();
|
||||
}
|
||||
|
@ -23,6 +23,9 @@
|
||||
#include <ATen/Parallel.h>
|
||||
#endif
|
||||
|
||||
#if AT_USE_EIGEN_SPARSE()
|
||||
#include <ATen/native/sparse/eigen/SparseBlasImpl.h>
|
||||
#endif
|
||||
|
||||
namespace at::native::sparse::impl {
|
||||
|
||||
@ -442,13 +445,15 @@ void add_out_sparse_csr(
|
||||
const Tensor& mat2,
|
||||
const Scalar& alpha,
|
||||
const Tensor& result) {
|
||||
#if !AT_MKL_ENABLED()
|
||||
#if AT_USE_MKL_SPARSE()
|
||||
sparse::impl::mkl::add_out_sparse_csr(mat1, mat2, alpha, result);
|
||||
#elif AT_USE_EIGEN_SPARSE()
|
||||
sparse::impl::eigen::add_out_sparse(mat1, mat2, alpha, result);
|
||||
#else
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Calling add on a sparse CPU tensor requires compiling PyTorch with MKL. ",
|
||||
"Please use PyTorch built MKL support.");
|
||||
#else
|
||||
sparse::impl::mkl::add_out_sparse_csr(mat1, mat2, alpha, result);
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -459,7 +464,7 @@ void triangular_solve_out_sparse_csr(
|
||||
bool upper,
|
||||
bool transpose,
|
||||
bool unitriangular) {
|
||||
#if !AT_MKL_ENABLED()
|
||||
#if !AT_USE_MKL_SPARSE()
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Calling triangular_solve on a sparse CPU tensor requires compiling PyTorch with MKL. ",
|
||||
|
@ -127,6 +127,10 @@
|
||||
#include <ATen/ops/zeros_like.h>
|
||||
#endif
|
||||
|
||||
#if AT_USE_EIGEN_SPARSE()
|
||||
#include <ATen/native/sparse/eigen/SparseBlasImpl.h>
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
namespace at {
|
||||
@ -536,7 +540,12 @@ static void addmm_out_sparse_csr_native_cpu(
|
||||
auto values = sparse.values();
|
||||
|
||||
scalar_t cast_alpha = alpha.to<scalar_t>();
|
||||
// If beta is zero NaN and Inf should not be propagated to the result
|
||||
if (beta.toComplexDouble() == 0.) {
|
||||
r.zero_();
|
||||
} else {
|
||||
r.mul_(beta);
|
||||
}
|
||||
AT_DISPATCH_INDEX_TYPES(
|
||||
col_indices.scalar_type(), "csr_mm_crow_indices", [&]() {
|
||||
auto csr_accessor = csr.accessor<index_t, 1>();
|
||||
@ -648,6 +657,15 @@ Tensor& addmm_out_sparse_compressed_cpu(
|
||||
return result;
|
||||
}
|
||||
|
||||
#if AT_USE_EIGEN_SPARSE()
|
||||
if ((result.layout() == kSparseCsr || result.layout() == kSparseCsc) &&
|
||||
(mat1.layout() == kSparseCsr || mat1.layout() == kSparseCsc) &&
|
||||
(mat2.layout() == kSparseCsr || mat2.layout() == kSparseCsc)) {
|
||||
sparse::impl::eigen::addmm_out_sparse(mat1, mat2, result, alpha, beta);
|
||||
return result;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if !AT_USE_MKL_SPARSE()
|
||||
// The custom impl addmm_out_sparse_csr_native_cpu only supports CSR @
|
||||
// strided -> strided
|
||||
|
329
aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
Normal file
329
aten/src/ATen/native/sparse/eigen/SparseBlasImpl.cpp
Normal file
@ -0,0 +1,329 @@
|
||||
#include <ATen/native/sparse/eigen/SparseBlasImpl.h>
|
||||
|
||||
#if AT_USE_EIGEN_SPARSE()
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/SparseCsrTensorUtils.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#else
|
||||
#include <ATen/ops/empty_like.h>
|
||||
#endif
|
||||
|
||||
#include <c10/core/ScalarType.h>
|
||||
|
||||
#include <Eigen/SparseCore>
|
||||
|
||||
namespace at::native::sparse::impl::eigen {
|
||||
|
||||
namespace {
|
||||
|
||||
void inline sparse_indices_to_result_dtype_inplace(
|
||||
const c10::ScalarType& dtype,
|
||||
const at::Tensor& input) {
|
||||
auto [compressed_indices, plain_indices] =
|
||||
at::sparse_csr::getCompressedPlainIndices(input);
|
||||
static_cast<at::SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
|
||||
->set_member_tensors(
|
||||
compressed_indices.to(dtype),
|
||||
plain_indices.to(dtype),
|
||||
input.values(),
|
||||
input.sizes());
|
||||
}
|
||||
|
||||
void inline sparse_indices_and_values_resize(
|
||||
const at::Tensor& input,
|
||||
int64_t nnz) {
|
||||
auto [compressed_indices, plain_indices] =
|
||||
at::sparse_csr::getCompressedPlainIndices(input);
|
||||
static_cast<SparseCsrTensorImpl*>(input.unsafeGetTensorImpl())
|
||||
->set_member_tensors(
|
||||
compressed_indices,
|
||||
plain_indices.resize_({nnz}),
|
||||
input.values().resize_({nnz}),
|
||||
input.sizes());
|
||||
}
|
||||
|
||||
template <typename scalar_t, int eigen_options, typename index_t>
|
||||
const Eigen::Map<Eigen::SparseMatrix<scalar_t, eigen_options, index_t>>
|
||||
Tensor_to_Eigen(const at::Tensor& tensor) {
|
||||
int64_t rows = tensor.size(0);
|
||||
int64_t cols = tensor.size(1);
|
||||
int64_t nnz = tensor._nnz();
|
||||
TORCH_CHECK(tensor.values().is_contiguous(), "eigen accepts only contiguous tensor values");
|
||||
auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(tensor);
|
||||
index_t* c_indices_ptr = compressed_indices.data_ptr<index_t>();
|
||||
index_t* p_indices_ptr = plain_indices.data_ptr<index_t>();
|
||||
scalar_t* values_ptr = tensor.values().data_ptr<scalar_t>();
|
||||
Eigen::Map<Eigen::SparseMatrix<scalar_t, eigen_options, index_t>> map(
|
||||
rows, cols, nnz, c_indices_ptr, p_indices_ptr, values_ptr);
|
||||
return map;
|
||||
}
|
||||
|
||||
template <typename scalar_t, int eigen_options, typename index_t>
|
||||
void Eigen_to_Tensor(
|
||||
const at::Tensor& tensor,
|
||||
const Eigen::SparseMatrix<scalar_t, eigen_options, index_t>& matrix) {
|
||||
const Layout eigen_layout = (eigen_options == Eigen::RowMajor ? kSparseCsr : kSparseCsc);
|
||||
TORCH_CHECK(
|
||||
tensor.layout() == eigen_layout,
|
||||
"Eigen_to_Tensor, expected tensor be ", eigen_layout, ", but got ",
|
||||
tensor.layout());
|
||||
int64_t nnz = matrix.nonZeros();
|
||||
int64_t csize = matrix.outerSize();
|
||||
sparse_indices_and_values_resize(tensor, nnz);
|
||||
auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(tensor);
|
||||
if (nnz > 0) {
|
||||
std::memcpy(
|
||||
tensor.values().mutable_data_ptr<scalar_t>(),
|
||||
matrix.valuePtr(),
|
||||
nnz * sizeof(scalar_t));
|
||||
std::memcpy(
|
||||
plain_indices.mutable_data_ptr<index_t>(),
|
||||
matrix.innerIndexPtr(),
|
||||
nnz * sizeof(index_t));
|
||||
}
|
||||
if (csize > 0) {
|
||||
std::memcpy(
|
||||
compressed_indices.mutable_data_ptr<index_t>(),
|
||||
matrix.outerIndexPtr(),
|
||||
csize * sizeof(index_t));
|
||||
}
|
||||
compressed_indices.mutable_data_ptr<index_t>()[csize] = nnz;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void add_out_sparse_eigen(
|
||||
const at::Tensor& mat1,
|
||||
const at::Tensor& mat2,
|
||||
const at::Scalar& alpha,
|
||||
const at::Tensor& result) {
|
||||
// empty matrices
|
||||
if (mat1._nnz() == 0 && mat2._nnz() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (mat2._nnz() == 0 || alpha.toComplexDouble() == 0.) {
|
||||
sparse_indices_and_values_resize(result, mat1._nnz());
|
||||
result.copy_(mat1);
|
||||
return;
|
||||
} else if (mat1._nnz() == 0) {
|
||||
sparse_indices_and_values_resize(result, mat2._nnz());
|
||||
result.copy_(mat2);
|
||||
result.values().mul_(alpha);
|
||||
return;
|
||||
}
|
||||
|
||||
c10::ScalarType result_index_dtype = at::sparse_csr::getIndexDtype(result);
|
||||
|
||||
sparse_indices_to_result_dtype_inplace(result_index_dtype, mat1);
|
||||
sparse_indices_to_result_dtype_inplace(result_index_dtype, mat2);
|
||||
|
||||
AT_DISPATCH_INDEX_TYPES(
|
||||
result_index_dtype, "eigen_sparse_add", [&]() {
|
||||
scalar_t _alpha = alpha.to<scalar_t>();
|
||||
|
||||
if (result.layout() == kSparseCsr) {
|
||||
auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat1);
|
||||
auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
|
||||
auto mat1_mat2_eigen = (mat1_eigen + _alpha * mat2_eigen);
|
||||
Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(result, mat1_mat2_eigen);
|
||||
} else {
|
||||
auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat1);
|
||||
auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
|
||||
auto mat1_mat2_eigen = (mat1_eigen + _alpha * mat2_eigen);
|
||||
Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(result, mat1_mat2_eigen);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void addmm_out_sparse_eigen(
|
||||
const at::Tensor& mat1,
|
||||
const at::Tensor& mat2,
|
||||
const at::Tensor& result,
|
||||
const at::Scalar& alpha,
|
||||
const at::Scalar& beta) {
|
||||
// empty matrices
|
||||
if (mat1._nnz() == 0 || mat2._nnz() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// If beta is zero NaN and Inf should not be propagated to the result
|
||||
// In addition, beta = 0 lets us enable a fast-path for result = alpha * A @ B
|
||||
bool is_beta_zero = false;
|
||||
if (beta.toComplexDouble() == 0.) {
|
||||
is_beta_zero = true;
|
||||
result.values().zero_();
|
||||
} else {
|
||||
result.values().mul_(beta);
|
||||
}
|
||||
|
||||
c10::ScalarType result_index_dtype = at::sparse_csr::getIndexDtype(result);
|
||||
|
||||
sparse_indices_to_result_dtype_inplace(result_index_dtype, mat1);
|
||||
sparse_indices_to_result_dtype_inplace(result_index_dtype, mat2);
|
||||
|
||||
AT_DISPATCH_INDEX_TYPES(
|
||||
result_index_dtype, "eigen_sparse_mm", [&]() {
|
||||
typedef Eigen::SparseMatrix<scalar_t, Eigen::RowMajor, index_t> EigenCsrMatrix;
|
||||
typedef Eigen::SparseMatrix<scalar_t, Eigen::ColMajor, index_t> EigenCscMatrix;
|
||||
|
||||
at::Tensor mat1_mat2;
|
||||
if (is_beta_zero) {
|
||||
mat1_mat2 = result;
|
||||
} else {
|
||||
mat1_mat2 = at::empty_like(result, result.options());
|
||||
}
|
||||
|
||||
if (mat1_mat2.layout() == kSparseCsr) {
|
||||
if (mat1.layout() == kSparseCsr) {
|
||||
const auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat1);
|
||||
if (mat2.layout() == kSparseCsr) {
|
||||
// Out_csr = M1_csr * M2_csr
|
||||
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
|
||||
const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
|
||||
Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
|
||||
} else {
|
||||
// Out_csr = M1_csr * M2_csc
|
||||
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
|
||||
const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
|
||||
Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
|
||||
}
|
||||
} else {
|
||||
const auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat1);
|
||||
if (mat2.layout() == kSparseCsr) {
|
||||
// Out_csr = M1_csc * M2_csr
|
||||
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
|
||||
const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
|
||||
Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
|
||||
} else {
|
||||
// Out_csr = M1_csc * M2_csc
|
||||
// This multiplication will be computationally inefficient, as it will require
|
||||
// additional conversion of the output matrix from CSC to CSR format.
|
||||
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
|
||||
const EigenCsrMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
|
||||
Eigen_to_Tensor<scalar_t, Eigen::RowMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (mat1.layout() == kSparseCsr) {
|
||||
const auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat1);
|
||||
if (mat2.layout() == kSparseCsr) {
|
||||
// Out_csc = M1_csr * M2_csr
|
||||
// This multiplication will be computationally inefficient, as it will require
|
||||
// additional conversion of the output matrix from CSR to CSC format.
|
||||
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
|
||||
const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
|
||||
Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
|
||||
} else {
|
||||
// Out_csc = M1_csr * M2_csc
|
||||
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
|
||||
const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
|
||||
Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
|
||||
}
|
||||
} else {
|
||||
const auto mat1_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat1);
|
||||
if (mat2.layout() == kSparseCsr) {
|
||||
// Out_csc = M1_csc * M2_csr
|
||||
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::RowMajor, index_t>(mat2);
|
||||
const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
|
||||
Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
|
||||
} else {
|
||||
// Out_csc = M1_csc * M2_csc
|
||||
const auto mat2_eigen = Tensor_to_Eigen<scalar_t, Eigen::ColMajor, index_t>(mat2);
|
||||
const EigenCscMatrix mat1_mat2_eigen = (mat1_eigen * mat2_eigen);
|
||||
Eigen_to_Tensor<scalar_t, Eigen::ColMajor, index_t>(mat1_mat2, mat1_mat2_eigen);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (is_beta_zero) {
|
||||
result.mul_(alpha.to<scalar_t>());
|
||||
} else {
|
||||
result.add_(mat1_mat2, alpha.to<scalar_t>());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
void addmm_out_sparse(
|
||||
const at::Tensor& mat1,
|
||||
const at::Tensor& mat2,
|
||||
const at::Tensor& result,
|
||||
const at::Scalar& alpha,
|
||||
const at::Scalar& beta) {
|
||||
AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(mat1.layout(), "eigen::addmm_out_sparse:mat1", [&]{});
|
||||
AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(mat2.layout(), "eigen::addmm_out_sparse:mat2", [&]{});
|
||||
AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(result.layout(), "eigen::addmm_out_sparse:result", [&]{});
|
||||
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
|
||||
result.scalar_type(), "addmm_out_sparse_eigen", [&] {
|
||||
addmm_out_sparse_eigen<scalar_t>(mat1, mat2, result, alpha, beta);
|
||||
});
|
||||
}
|
||||
|
||||
void add_out_sparse(
|
||||
const at::Tensor& mat1,
|
||||
const at::Tensor& mat2,
|
||||
const at::Scalar& alpha,
|
||||
const at::Tensor& result) {
|
||||
TORCH_CHECK(
|
||||
(result.layout() == kSparseCsr && mat1.layout() == kSparseCsr && mat2.layout() == kSparseCsr) ||
|
||||
(result.layout() == kSparseCsc && mat1.layout() == kSparseCsc && mat2.layout() == kSparseCsc),
|
||||
"eigen::add_out_sparse: expected the same layout for all operands but got ",
|
||||
mat1.layout(),
|
||||
" + ",
|
||||
mat2.layout(),
|
||||
" -> ",
|
||||
result.layout());
|
||||
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
|
||||
result.scalar_type(), "add_out_sparse_eigen", [&] {
|
||||
add_out_sparse_eigen<scalar_t>(mat1, mat2, alpha, result);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace at::native::sparse::impl::eigen
|
||||
|
||||
#else
|
||||
|
||||
namespace at::native::sparse::impl::eigen {
|
||||
|
||||
void addmm_out_sparse(
|
||||
const at::Tensor& mat1,
|
||||
const at::Tensor& mat2,
|
||||
const at::Tensor& result,
|
||||
const at::Scalar& alpha,
|
||||
const at::Scalar& beta) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"eigen::addmm_out_sparse: Eigen was not enabled for ",
|
||||
result.layout(),
|
||||
" + ",
|
||||
mat1.layout(),
|
||||
" @ ",
|
||||
mat2.layout());
|
||||
}
|
||||
|
||||
void add_out_sparse(
|
||||
const at::Tensor& mat1,
|
||||
const at::Tensor& mat2,
|
||||
const at::Scalar& alpha,
|
||||
const at::Tensor& result) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"eigen::add_out_sparse: Eigen was not enabled for ",
|
||||
mat1.layout(),
|
||||
" + ",
|
||||
mat2.layout(),
|
||||
" -> ",
|
||||
result.layout());
|
||||
}
|
||||
|
||||
} // namespace at::native::sparse::impl::eigen
|
||||
|
||||
#endif // AT_USE_EIGEN_SPARSE()
|
29
aten/src/ATen/native/sparse/eigen/SparseBlasImpl.h
Normal file
29
aten/src/ATen/native/sparse/eigen/SparseBlasImpl.h
Normal file
@ -0,0 +1,29 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/Config.h>
|
||||
|
||||
#if AT_USE_EIGEN_SPARSE()
|
||||
#ifndef EIGEN_MPL2_ONLY
|
||||
#define EIGEN_MPL2_ONLY
|
||||
#endif
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
|
||||
namespace at::native::sparse::impl::eigen {
|
||||
|
||||
void addmm_out_sparse(
|
||||
const at::Tensor& mat1,
|
||||
const at::Tensor& mat2,
|
||||
const at::Tensor& result,
|
||||
const at::Scalar& alpha,
|
||||
const at::Scalar& beta);
|
||||
|
||||
void add_out_sparse(
|
||||
const at::Tensor& mat1,
|
||||
const at::Tensor& mat2,
|
||||
const at::Scalar& alpha,
|
||||
const at::Tensor& result);
|
||||
|
||||
} // namespace at::native::sparse::impl::eigen
|
||||
|
||||
#endif
|
@ -1148,6 +1148,9 @@ def define_buck_targets(
|
||||
"--replace",
|
||||
"@AT_KLEIDIAI_ENABLED@",
|
||||
"0",
|
||||
"--replace",
|
||||
"@AT_USE_EIGEN_SPARSE@",
|
||||
"0",
|
||||
]),
|
||||
outs = {
|
||||
"Config.h": ["Config.h"],
|
||||
|
@ -153,6 +153,7 @@ set(AT_MKLDNN_ACL_ENABLED 0)
|
||||
set(AT_MKLDNN_ENABLED 0)
|
||||
set(AT_MKL_ENABLED 0)
|
||||
set(AT_KLEIDIAI_ENABLED 0)
|
||||
set(AT_USE_EIGEN_SPARSE 0)
|
||||
# setting default preferred BLAS options if not already present.
|
||||
if(NOT INTERN_BUILD_MOBILE)
|
||||
set(BLAS "MKL" CACHE STRING "Selected BLAS library")
|
||||
@ -262,6 +263,15 @@ if(BLAS_LIBRARIES AND BLAS_CHECK_F2C)
|
||||
include(cmake/BLAS_ABI.cmake)
|
||||
endif()
|
||||
|
||||
if(USE_EIGEN_SPARSE AND BLAS_INFO STREQUAL "mkl")
|
||||
message(WARNING "Disabling USE_EIGEN_SPARSE because MKL is enabled")
|
||||
set(USE_EIGEN_SPARSE OFF)
|
||||
endif()
|
||||
|
||||
if(USE_EIGEN_SPARSE)
|
||||
set(AT_USE_EIGEN_SPARSE 1)
|
||||
endif()
|
||||
|
||||
if(NOT INTERN_BUILD_MOBILE)
|
||||
set(AT_MKL_SEQUENTIAL 0)
|
||||
set(USE_BLAS 1)
|
||||
|
@ -135,6 +135,7 @@ function(caffe2_print_configuration_summary)
|
||||
endif()
|
||||
message(STATUS " BUILD_NVFUSER : ${BUILD_NVFUSER}")
|
||||
message(STATUS " USE_EIGEN_FOR_BLAS : ${CAFFE2_USE_EIGEN_FOR_BLAS}")
|
||||
message(STATUS " USE_EIGEN_FOR_SPARSE : ${USE_EIGEN_SPARSE}")
|
||||
message(STATUS " USE_FBGEMM : ${USE_FBGEMM}")
|
||||
message(STATUS " USE_KINETO : ${USE_KINETO}")
|
||||
message(STATUS " USE_GFLAGS : ${USE_GFLAGS}")
|
||||
|
@ -2202,6 +2202,8 @@ Call this whenever a new thread is created in order to propagate values from
|
||||
set_module_attr("_has_kleidiai", at::hasKleidiAI() ? Py_True : Py_False));
|
||||
ASSERT_TRUE(
|
||||
set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False));
|
||||
ASSERT_TRUE(set_module_attr(
|
||||
"_has_eigen_sparse", at::hasEigenSparse() ? Py_True : Py_False));
|
||||
|
||||
py_module.def("_valgrind_supported_platform", []() {
|
||||
#if defined(USE_VALGRIND)
|
||||
|
Reference in New Issue
Block a user