mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
SparseCsrCUDA: cuDSS backend for linalg.solve (#129856)
This PR switches to cuDSS library and has the same purpose of #127692, which is to add Sparse CSR tensor support to linalg.solve. Fixes #69538 Minimum example of usage: ``` import torch if __name__ == '__main__': spd = torch.rand(4, 3) A = spd.T @ spd b = torch.rand(3).to(torch.float64).cuda() A = A.to_sparse_csr().to(torch.float64).cuda() x = torch.linalg.solve(A, b) print((A @ x - b).norm()) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/129856 Approved by: https://github.com/amjames, https://github.com/lezcano, https://github.com/huydhn Co-authored-by: Zihang Fang <zhfang1108@gmail.com> Co-authored-by: Huy Do <huydhn@gmail.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
64cfcbd8a3
commit
90c821814e
25
.ci/docker/common/install_cudss.sh
Normal file
25
.ci/docker/common/install_cudss.sh
Normal file
@ -0,0 +1,25 @@
|
||||
#!/bin/bash
# Installs the NVIDIA cuDSS redistributable into /usr/local/cuda.
#
# Environment:
#   CUDA_VERSION  toolkit version string; cuDSS is only installed when its
#                 first four characters match 12.1 - 12.4.
#   TARGETARCH    optional; defaults to `uname -m`. 'amd64' / 'x86_64' select
#                 the x86_64 archive, anything else selects the sbsa archive.

set -ex

# cudss license: https://docs.nvidia.com/cuda/cudss/license.html
mkdir tmp_cudss && cd tmp_cudss

if [[ ${CUDA_VERSION:0:4} =~ ^12\.[1-4]$ ]]; then
    arch_path='sbsa'
    export TARGETARCH=${TARGETARCH:-$(uname -m)}
    # Quote both operands so an unusual TARGETARCH value cannot break the test.
    if [ "${TARGETARCH}" = 'amd64' ] || [ "${TARGETARCH}" = 'x86_64' ]; then
        arch_path='x86_64'
    fi
    CUDSS_NAME="libcudss-linux-${arch_path}-0.3.0.9_cuda12-archive"
    # -f makes curl fail on an HTTP error instead of saving the error page
    # (which would otherwise only surface later as a confusing tar failure);
    # -S keeps the error message visible despite -s.
    curl --retry 3 -fOLsS "https://developer.download.nvidia.com/compute/cudss/redist/libcudss/linux-${arch_path}/${CUDSS_NAME}.tar.xz"

    # only for cuda 12
    tar xf "${CUDSS_NAME}.tar.xz"
    cp -a "${CUDSS_NAME}"/include/* /usr/local/cuda/include/
    cp -a "${CUDSS_NAME}"/lib/* /usr/local/cuda/lib64/
fi

# Always clean up the scratch directory, even when nothing was installed.
cd ..
rm -rf tmp_cudss
ldconfig
|
@ -156,6 +156,12 @@ COPY ./common/install_cusparselt.sh install_cusparselt.sh
|
||||
RUN bash install_cusparselt.sh
|
||||
RUN rm install_cusparselt.sh
|
||||
|
||||
# Install CUDSS
|
||||
ARG CUDA_VERSION
|
||||
COPY ./common/install_cudss.sh install_cudss.sh
|
||||
RUN bash install_cudss.sh
|
||||
RUN rm install_cudss.sh
|
||||
|
||||
# Delete /usr/local/cuda-11.X/cuda-11.X symlinks
|
||||
RUN if [ -h /usr/local/cuda-11.6/cuda-11.6 ]; then rm /usr/local/cuda-11.6/cuda-11.6; fi
|
||||
RUN if [ -h /usr/local/cuda-11.7/cuda-11.7 ]; then rm /usr/local/cuda-11.7/cuda-11.7; fi
|
||||
|
@ -251,6 +251,7 @@ cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF)
|
||||
cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
|
||||
"USE_CUDNN" OFF)
|
||||
cmake_dependent_option(USE_CUSPARSELT "Use cuSPARSELt" ON "USE_CUDA" OFF)
|
||||
cmake_dependent_option(USE_CUDSS "Use cuDSS" ON "USE_CUDA" OFF)
|
||||
# Binary builds will fail for cufile due to https://github.com/pytorch/builder/issues/1924
|
||||
# Using TH_BINARY_BUILD to check whether is binary build.
|
||||
# USE_ROCM is guarded against in Dependencies.cmake because USE_ROCM is not properly defined here
|
||||
@ -1296,6 +1297,10 @@ if(BUILD_SHARED_LIBS)
|
||||
FILES ${PROJECT_SOURCE_DIR}/cmake/Modules/FindCUSPARSELT.cmake
|
||||
DESTINATION share/cmake/Caffe2/
|
||||
COMPONENT dev)
|
||||
install(
|
||||
FILES ${PROJECT_SOURCE_DIR}/cmake/Modules/FindCUDSS.cmake
|
||||
DESTINATION share/cmake/Caffe2/
|
||||
COMPONENT dev)
|
||||
install(
|
||||
FILES ${PROJECT_SOURCE_DIR}/cmake/Modules/FindSYCLToolkit.cmake
|
||||
DESTINATION share/cmake/Caffe2/
|
||||
|
@ -15,6 +15,10 @@
|
||||
#include <cusolverDn.h>
|
||||
#endif
|
||||
|
||||
#if defined(USE_CUDSS)
|
||||
#include <cudss.h>
|
||||
#endif
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#include <hipsolver/hipsolver.h>
|
||||
#endif
|
||||
@ -88,4 +92,8 @@ TORCH_CUDA_CPP_API void clearCublasWorkspaces();
|
||||
TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();
|
||||
#endif
|
||||
|
||||
#if defined(USE_CUDSS)
|
||||
TORCH_CUDA_CPP_API cudssHandle_t getCurrentCudssHandle();
|
||||
#endif
|
||||
|
||||
} // namespace at::cuda
|
||||
|
@ -64,4 +64,23 @@ C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status) {
|
||||
} // namespace solver
|
||||
#endif
|
||||
|
||||
#if defined(USE_CUDSS)
|
||||
namespace cudss {
|
||||
|
||||
// Translates a cuDSS status code into the spelling of its enumerator so it
// can be embedded in error messages. Codes that are not listed below yield a
// generic "unknown" string.
C10_EXPORT const char* cudssGetErrorMessage(cudssStatus_t status) {
  if (status == CUDSS_STATUS_SUCCESS) {
    return "CUDSS_STATUS_SUCCESS";
  }
  if (status == CUDSS_STATUS_NOT_INITIALIZED) {
    return "CUDSS_STATUS_NOT_INITIALIZED";
  }
  if (status == CUDSS_STATUS_ALLOC_FAILED) {
    return "CUDSS_STATUS_ALLOC_FAILED";
  }
  if (status == CUDSS_STATUS_INVALID_VALUE) {
    return "CUDSS_STATUS_INVALID_VALUE";
  }
  if (status == CUDSS_STATUS_NOT_SUPPORTED) {
    return "CUDSS_STATUS_NOT_SUPPORTED";
  }
  if (status == CUDSS_STATUS_EXECUTION_FAILED) {
    return "CUDSS_STATUS_EXECUTION_FAILED";
  }
  if (status == CUDSS_STATUS_INTERNAL_ERROR) {
    return "CUDSS_STATUS_INTERNAL_ERROR";
  }
  return "Unknown cudss error number";
}
|
||||
|
||||
} // namespace cudss
|
||||
#endif
|
||||
|
||||
} // namespace at::cuda
|
||||
|
@ -8,6 +8,10 @@
|
||||
#include <cusolver_common.h>
|
||||
#endif
|
||||
|
||||
#if defined(USE_CUDSS)
|
||||
#include <cudss.h>
|
||||
#endif
|
||||
|
||||
#include <ATen/Context.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
@ -73,6 +77,33 @@ const char *cusparseGetErrorString(cusparseStatus_t status);
|
||||
" when calling `" #EXPR "`"); \
|
||||
} while (0)
|
||||
|
||||
#if defined(USE_CUDSS)
|
||||
namespace at::cuda::cudss {
|
||||
// Maps a cuDSS status code to a printable string for use in error messages.
C10_EXPORT const char* cudssGetErrorMessage(cudssStatus_t error);
|
||||
} // namespace at::cuda::cudss
|
||||
|
||||
// Evaluates the cuDSS call EXPR exactly once and raises if it did not return
// CUDSS_STATUS_SUCCESS. CUDSS_STATUS_EXECUTION_FAILED is reported through
// TORCH_CHECK_LINALG with a hint about NaN input; every other failure goes
// through TORCH_CHECK.
#define TORCH_CUDSS_CHECK(EXPR)                                          \
  do {                                                                   \
    const cudssStatus_t torch_cudss_status_ = EXPR;                      \
    if (torch_cudss_status_ != CUDSS_STATUS_EXECUTION_FAILED) {          \
      TORCH_CHECK(                                                       \
          torch_cudss_status_ == CUDSS_STATUS_SUCCESS,                   \
          "cudss error: ",                                               \
          at::cuda::cudss::cudssGetErrorMessage(torch_cudss_status_),    \
          ", when calling `" #EXPR "`. ");                               \
    } else {                                                             \
      TORCH_CHECK_LINALG(                                                \
          false,                                                         \
          "cudss error: ",                                               \
          at::cuda::cudss::cudssGetErrorMessage(torch_cudss_status_),    \
          ", when calling `" #EXPR "`",                                  \
          ". This error may appear if the input matrix contains NaN. "); \
    }                                                                    \
  } while (0)
|
||||
#else
|
||||
#define TORCH_CUDSS_CHECK(EXPR) EXPR
|
||||
#endif
|
||||
|
||||
// cusolver related headers are only supported on cuda now
|
||||
#ifdef CUDART_VERSION
|
||||
|
||||
|
@ -21,6 +21,7 @@
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_spsolve.h>
|
||||
#include <ATen/ops/_cholesky_solve_helper.h>
|
||||
#include <ATen/ops/_cholesky_solve_helper_native.h>
|
||||
#include <ATen/ops/_linalg_check_errors.h>
|
||||
@ -1936,6 +1937,9 @@ Tensor& linalg_solve_out(const Tensor& A,
|
||||
Tensor linalg_solve(const Tensor& A,
|
||||
const Tensor& B,
|
||||
bool left) {
|
||||
if (A.layout() == kSparseCsr) {
|
||||
return at::_spsolve(A, B, left);
|
||||
}
|
||||
auto [result, info] = at::linalg_solve_ex(A, B, left);
|
||||
at::_linalg_check_errors(info, "torch.linalg.solve", A.dim() == 2);
|
||||
return result;
|
||||
|
52
aten/src/ATen/native/cuda/linalg/CudssHandlePool.cpp
Normal file
52
aten/src/ATen/native/cuda/linalg/CudssHandlePool.cpp
Normal file
@ -0,0 +1,52 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/detail/DeviceThreadHandles.h>
|
||||
|
||||
#if defined(USE_CUDSS)
|
||||
|
||||
namespace at::cuda {
|
||||
namespace {
|
||||
|
||||
void createCudssHandle(cudssHandle_t *handle) {
|
||||
TORCH_CUDSS_CHECK(cudssCreate(handle));
|
||||
}
|
||||
|
||||
void destroyCudssHandle(cudssHandle_t handle) {
|
||||
// this is because of something dumb in the ordering of
|
||||
// destruction. Sometimes atexit, the cuda context (or something)
|
||||
// would already be destroyed by the time this gets destroyed. It
|
||||
// happens in fbcode setting. @colesbury and @soumith decided to not destroy
|
||||
// the handle as a workaround.
|
||||
// - Comments of @soumith copied from cuDNN handle pool implementation
|
||||
#ifdef NO_CUDNN_DESTROY_HANDLE
|
||||
(void)handle; // Suppress unused variable warning
|
||||
#else
|
||||
cudssDestroy(handle);
|
||||
#endif
|
||||
}
|
||||
|
||||
using CudssPoolType = DeviceThreadHandlePool<cudssHandle_t, createCudssHandle, destroyCudssHandle>;
|
||||
|
||||
} // namespace
|
||||
|
||||
cudssHandle_t getCurrentCudssHandle() {
|
||||
c10::DeviceIndex device = 0;
|
||||
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
|
||||
|
||||
// Thread local PoolWindows are lazily-initialized
|
||||
// to avoid initialization issues that caused hangs on Windows.
|
||||
// See: https://github.com/pytorch/pytorch/pull/22405
|
||||
// This thread local unique_ptrs will be destroyed when the thread terminates,
|
||||
// releasing its reserved handles back to the pool.
|
||||
static auto pool = std::make_shared<CudssPoolType>();
|
||||
thread_local std::unique_ptr<CudssPoolType::PoolWindow> myPoolWindow(
|
||||
pool->newPoolWindow());
|
||||
|
||||
auto handle = myPoolWindow->reserve(device);
|
||||
auto stream = c10::cuda::getCurrentCUDAStream();
|
||||
TORCH_CUDSS_CHECK(cudssSetStream(handle, stream));
|
||||
return handle;
|
||||
}
|
||||
|
||||
} // namespace at::cuda
|
||||
|
||||
#endif
|
@ -14216,6 +14216,11 @@
|
||||
- func: linalg_solve(Tensor A, Tensor B, *, bool left=True) -> Tensor
|
||||
python_module: linalg
|
||||
|
||||
- func: _spsolve(Tensor A, Tensor B, *, bool left=True) -> Tensor
|
||||
python_module: sparse
|
||||
dispatch:
|
||||
SparseCsrCUDA: _sparse_csr_linear_solve
|
||||
|
||||
- func: linalg_solve.out(Tensor A, Tensor B, *, bool left=True, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: linalg
|
||||
|
||||
|
@ -705,6 +705,71 @@ struct ReductionMulOp {
|
||||
__forceinline__ scalar_t identity_cpu() const { return 1; }
|
||||
};
|
||||
|
||||
void _apply_sparse_csr_linear_solve(
|
||||
const Tensor& A,
|
||||
const Tensor& b,
|
||||
const bool left,
|
||||
const Tensor& x) {
|
||||
#if defined(USE_ROCM) || !defined(USE_CUDSS)
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Calling linear solver with sparse tensors requires compiling ",
|
||||
"PyTorch with CUDA cuDSS and is not supported in ROCm build.");
|
||||
#else
|
||||
// layout check
|
||||
TORCH_CHECK(A.is_sparse_csr(), "A must be a CSR matrix");
|
||||
TORCH_CHECK(b.layout() == kStrided, "b must be a strided tensor");
|
||||
TORCH_CHECK(x.layout() == kStrided, "x must be a strided tensor");
|
||||
// dim check
|
||||
TORCH_CHECK(b.dim() == 1, "b must be a 1D tensor");
|
||||
TORCH_CHECK(b.stride(0) == 1, "b must be a column major tensor");
|
||||
TORCH_CHECK(b.size(0) == A.size(0), "linear system size mismatch.");
|
||||
TORCH_CHECK(x.dim() == 1, "x must be a 1D tensor");
|
||||
TORCH_CHECK(x.stride(0) == 1, "x must be a column major tensor");
|
||||
TORCH_CHECK(x.size(0) == A.size(1), "linear system size mismatch.");
|
||||
TORCH_CHECK(A.dtype() == b.dtype() && A.dtype() == x.dtype(), "A, x, and b must have the same dtype");
|
||||
TORCH_CHECK(left == true, "only left == true is supported by the Sparse CSR backend")
|
||||
|
||||
Tensor crow = A.crow_indices();
|
||||
Tensor col = A.col_indices();
|
||||
if (crow.scalar_type() != ScalarType::Int) {
|
||||
crow = crow.to(crow.options().dtype(ScalarType::Int));
|
||||
col = col.to(col.options().dtype(ScalarType::Int));
|
||||
}
|
||||
int* rowOffsets = crow.data<int>();
|
||||
int* colIndices = col.data<int>();
|
||||
Tensor values = A.values();
|
||||
// cuDSS data structures and handle initialization
|
||||
cudssConfig_t config;
|
||||
cudssMatrix_t b_mt;
|
||||
cudssMatrix_t A_mt;
|
||||
cudssMatrix_t x_mt;
|
||||
cudssData_t cudss_data;
|
||||
cudssHandle_t handle = at::cuda::getCurrentCudssHandle();
|
||||
|
||||
TORCH_CUDSS_CHECK(cudssConfigCreate(&config));
|
||||
TORCH_CUDSS_CHECK(cudssDataCreate(handle, &cudss_data));
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(values.type(), "create_matrix", ([&] {
|
||||
scalar_t* values_ptr = values.data<scalar_t>();
|
||||
scalar_t* b_ptr = b.data<scalar_t>();
|
||||
scalar_t* x_ptr = x.data<scalar_t>();
|
||||
auto CUDA_R_TYP = std::is_same<scalar_t, double>::value ? CUDA_R_64F : CUDA_R_32F;
|
||||
TORCH_CUDSS_CHECK(cudssMatrixCreateDn(&b_mt, b.size(0), 1, b.size(0), b_ptr, CUDA_R_TYP, CUDSS_LAYOUT_COL_MAJOR));
|
||||
TORCH_CUDSS_CHECK(cudssMatrixCreateDn(&x_mt, x.size(0), 1, x.size(0), x_ptr, CUDA_R_TYP, CUDSS_LAYOUT_COL_MAJOR));
|
||||
TORCH_CUDSS_CHECK(cudssMatrixCreateCsr(&A_mt, A.size(0), A.size(1), A._nnz(), rowOffsets, rowOffsets + crow.size(0), colIndices, values_ptr, CUDA_R_32I, CUDA_R_TYP, CUDSS_MTYPE_GENERAL, CUDSS_MVIEW_FULL, CUDSS_BASE_ZERO));
|
||||
}));
|
||||
TORCH_CUDSS_CHECK(cudssExecute(handle, CUDSS_PHASE_ANALYSIS, config, cudss_data, A_mt, x_mt, b_mt));
|
||||
TORCH_CUDSS_CHECK(cudssExecute(handle, CUDSS_PHASE_FACTORIZATION, config, cudss_data, A_mt, x_mt, b_mt));
|
||||
TORCH_CUDSS_CHECK(cudssExecute(handle, CUDSS_PHASE_SOLVE, config, cudss_data, A_mt, x_mt, b_mt));
|
||||
// Destroy the opaque objects
|
||||
TORCH_CUDSS_CHECK(cudssConfigDestroy(config));
|
||||
TORCH_CUDSS_CHECK(cudssDataDestroy(handle, cudss_data));
|
||||
TORCH_CUDSS_CHECK(cudssMatrixDestroy(A_mt));
|
||||
TORCH_CUDSS_CHECK(cudssMatrixDestroy(x_mt));
|
||||
TORCH_CUDSS_CHECK(cudssMatrixDestroy(b_mt));
|
||||
#endif
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Tensor _sparse_csr_sum_cuda(const Tensor& input, IntArrayRef dims_to_sum, bool keepdim, std::optional<ScalarType> dtype) {
|
||||
@ -735,4 +800,12 @@ Tensor _sparse_csr_prod_cuda(const Tensor& input, IntArrayRef dims_to_reduce, bo
|
||||
return result;
|
||||
}
|
||||
|
||||
// Allocating wrapper around _apply_sparse_csr_linear_solve: makes the
// right-hand side contiguous, allocates an uninitialized result with the same
// sizes, and lets the helper fill it in.
Tensor _sparse_csr_linear_solve(const Tensor& A, const Tensor& b, const bool left) {
  const Tensor rhs = b.contiguous();
  Tensor solution = rhs.new_empty(rhs.sizes());
  _apply_sparse_csr_linear_solve(A, rhs, left, solution);
  return solution;
}
|
||||
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -1443,6 +1443,7 @@ aten_cuda_cu_source_list = [
|
||||
"aten/src/ATen/cuda/CUDABlas.cpp",
|
||||
"aten/src/ATen/cuda/CUDASparseBlas.cpp",
|
||||
"aten/src/ATen/cuda/CublasHandlePool.cpp",
|
||||
"aten/src/ATen/native/cuda/linalg/CudssHandlePool.cpp",
|
||||
"aten/src/ATen/cuda/tunable/StreamTimer.cpp",
|
||||
"aten/src/ATen/cuda/tunable/Tunable.cpp",
|
||||
"aten/src/ATen/native/cuda/Activation.cpp",
|
||||
|
@ -936,6 +936,10 @@ elseif(USE_CUDA)
|
||||
target_link_libraries(torch_cuda PRIVATE torch::cusparselt)
|
||||
target_compile_definitions(torch_cuda PRIVATE USE_CUSPARSELT)
|
||||
endif()
|
||||
if(USE_CUDSS)
|
||||
target_link_libraries(torch_cuda PRIVATE torch::cudss)
|
||||
target_compile_definitions(torch_cuda PRIVATE USE_CUDSS)
|
||||
endif()
|
||||
if(USE_NCCL)
|
||||
target_link_libraries(torch_cuda PRIVATE __caffe2_nccl)
|
||||
target_compile_definitions(torch_cuda PRIVATE USE_NCCL)
|
||||
|
67
cmake/Modules/FindCUDSS.cmake
Normal file
67
cmake/Modules/FindCUDSS.cmake
Normal file
@ -0,0 +1,67 @@
|
||||
# Find the CUDSS library
#
# The following variables are optionally searched for defaults
#  CUDSS_ROOT: Base directory where CUDSS is found
#  CUDSS_INCLUDE_DIR: Directory where CUDSS header is searched for
#  CUDSS_LIBRARY: Directory where the CUDSS library is searched for
#
# The following are set after configuration is done:
#  CUDSS_FOUND
#  CUDSS_INCLUDE_PATH
#  CUDSS_LIBRARY_PATH
#  CUDSS_VERSION (only when found; "?" if it could not be parsed from cudss.h)

include(FindPackageHandleStandardArgs)

set(CUDSS_ROOT $ENV{CUDSS_ROOT_DIR} CACHE PATH "Folder containing NVIDIA CUDSS")
# `DEFINED ENV{name}` tests the environment variable itself. The previous
# `DEFINED $ENV{...}` expanded the value first and then tested a CMake
# variable of that name, so the warning never fired as intended.
if(DEFINED ENV{CUDSS_ROOT_DIR})
  message(WARNING "CUDSS_ROOT_DIR is deprecated. Please set CUDSS_ROOT instead.")
endif()
list(APPEND CUDSS_ROOT $ENV{CUDSS_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR})

# Compatible layer for CMake <3.12. CUDSS_ROOT will be accounted in for searching paths and libraries for CMake >=3.12.
list(APPEND CMAKE_PREFIX_PATH ${CUDSS_ROOT})

set(CUDSS_INCLUDE_DIR $ENV{CUDSS_INCLUDE_DIR} CACHE PATH "Folder containing NVIDIA CUDSS header files")

find_path(CUDSS_INCLUDE_PATH cudss.h
  HINTS ${CUDSS_INCLUDE_DIR}
  PATH_SUFFIXES cuda/include cuda include)

set(CUDSS_LIBRARY $ENV{CUDSS_LIBRARY} CACHE PATH "Path to the CUDSS library file (e.g., libcudss.so)")

set(CUDSS_LIBRARY_NAME "libcudss.so")
if(MSVC)
  set(CUDSS_LIBRARY_NAME "cudss.lib")
endif()

find_library(CUDSS_LIBRARY_PATH ${CUDSS_LIBRARY_NAME}
  PATHS ${CUDSS_LIBRARY}
  PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64)

find_package_handle_standard_args(CUDSS DEFAULT_MSG CUDSS_LIBRARY_PATH CUDSS_INCLUDE_PATH)

if(CUDSS_FOUND)
  # Get CUDSS version
  file(READ "${CUDSS_INCLUDE_PATH}/cudss.h" CUDSS_HEADER_CONTENTS)
  # Accept both the CUDSS_VER_<X> and the CUDSS_VERSION_<X> macro spellings.
  # NOTE(review): confirm which spelling the shipped cudss.h actually uses.
  foreach(_cudss_component MAJOR MINOR PATCH)
    if(CUDSS_HEADER_CONTENTS MATCHES "define CUDSS_VER(SION)?_${_cudss_component} +([0-9]+)")
      set(CUDSS_VERSION_${_cudss_component} "${CMAKE_MATCH_2}")
    else()
      set(CUDSS_VERSION_${_cudss_component} "")
    endif()
  endforeach()
  unset(_cudss_component)
  # Assemble CUDSS version. Use minor version since current major version is 0.
  # Compare against the empty string: a plain `if(NOT ...)` would also treat a
  # legitimate minor version of 0 as "not found".
  if("${CUDSS_VERSION_MINOR}" STREQUAL "")
    set(CUDSS_VERSION "?")
  else()
    set(CUDSS_VERSION
      "${CUDSS_VERSION_MAJOR}.${CUDSS_VERSION_MINOR}.${CUDSS_VERSION_PATCH}")
  endif()
endif()

mark_as_advanced(CUDSS_ROOT CUDSS_INCLUDE_DIR CUDSS_LIBRARY CUDSS_VERSION)
|
@ -74,6 +74,7 @@ function(caffe2_print_configuration_summary)
|
||||
message(STATUS " CUDA static link : ${CAFFE2_STATIC_LINK_CUDA}")
|
||||
message(STATUS " USE_CUDNN : ${USE_CUDNN}")
|
||||
message(STATUS " USE_CUSPARSELT : ${USE_CUSPARSELT}")
|
||||
message(STATUS " USE_CUDSS : ${USE_CUDSS}")
|
||||
message(STATUS " USE_CUFILE : ${USE_CUFILE}")
|
||||
message(STATUS " CUDA version : ${CUDA_VERSION}")
|
||||
message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}")
|
||||
@ -102,6 +103,10 @@ function(caffe2_print_configuration_summary)
|
||||
get_target_property(__tmp torch::cusparselt INTERFACE_LINK_LIBRARIES)
|
||||
message(STATUS " cuSPARSELt library : ${__tmp}")
|
||||
endif()
|
||||
if(${USE_CUDSS})
|
||||
get_target_property(__tmp torch::cudss INTERFACE_LINK_LIBRARIES)
|
||||
message(STATUS " cuDSS library : ${__tmp}")
|
||||
endif()
|
||||
message(STATUS " nvrtc : ${CUDA_nvrtc_LIBRARY}")
|
||||
message(STATUS " CUDA include path : ${CUDA_INCLUDE_DIRS}")
|
||||
message(STATUS " NVCC executable : ${CUDA_NVCC_EXECUTABLE}")
|
||||
|
@ -248,6 +248,22 @@ else()
|
||||
message(STATUS "USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support")
|
||||
endif()
|
||||
|
||||
# cuDSS: when requested, locate the library and expose it as the imported
# interface target torch::cudss; if it cannot be found the option is switched
# off with a warning.
if(NOT USE_CUDSS)
  message(STATUS "USE_CUDSS is set to 0. Compiling without cuDSS support")
else()
  find_package(CUDSS)

  if(CUDSS_FOUND)
    add_library(torch::cudss INTERFACE IMPORTED)
    target_include_directories(torch::cudss INTERFACE ${CUDSS_INCLUDE_PATH})
    target_link_libraries(torch::cudss INTERFACE ${CUDSS_LIBRARY_PATH})
  else()
    message(WARNING
      "Cannot find CUDSS library. Turning the option off")
    set(USE_CUDSS OFF)
  endif()
endif()
|
||||
|
||||
# cufile
|
||||
if(CAFFE2_USE_CUFILE)
|
||||
add_library(torch::cufile INTERFACE IMPORTED)
|
||||
|
@ -1147,6 +1147,7 @@ multiplication, and ``@`` is matrix multiplication.
|
||||
:func:`torch.addmm`; no; ``f * M[strided] + f * (M[SparseSemiStructured] @ M[strided]) -> M[strided]``
|
||||
:func:`torch.addmm`; no; ``f * M[strided] + f * (M[strided] @ M[SparseSemiStructured]) -> M[strided]``
|
||||
:func:`torch.sparse.addmm`; yes; ``f * M[strided] + f * (M[sparse_coo] @ M[strided]) -> M[strided]``
|
||||
:func:`torch.sparse.spsolve`; no; ``SOLVE(M[sparse_csr], V[strided]) -> V[strided]``
|
||||
:func:`torch.sspaddmm`; no; ``f * M[sparse_coo] + f * (M[sparse_coo] @ M[strided]) -> M[sparse_coo]``
|
||||
:func:`torch.lobpcg`; no; ``GENEIG(M[sparse_coo]) -> M[strided], M[strided]``
|
||||
:func:`torch.pca_lowrank`; yes; ``PCA(M[sparse_coo]) -> M[strided], M[strided], M[strided]``
|
||||
@ -1292,6 +1293,7 @@ Torch functions specific to sparse Tensors
|
||||
hspmm
|
||||
smm
|
||||
sparse.softmax
|
||||
sparse.spsolve
|
||||
sparse.log_softmax
|
||||
sparse.spdiags
|
||||
|
||||
|
3
setup.py
3
setup.py
@ -38,6 +38,9 @@
|
||||
# USE_CUSPARSELT=0
|
||||
# disables the cuSPARSELt build
|
||||
#
|
||||
# USE_CUDSS=0
|
||||
# disables the cuDSS build
|
||||
#
|
||||
# USE_CUFILE=0
|
||||
# disables the cuFile build
|
||||
#
|
||||
|
@ -555,6 +555,7 @@ aten::_sparse_sum_backward
|
||||
aten::_sparse_sum_backward.out
|
||||
aten::_spdiags
|
||||
aten::_spdiags.out
|
||||
aten::_spsolve
|
||||
aten::_stack
|
||||
aten::_stack.out
|
||||
aten::_standard_gamma
|
||||
|
@ -10,8 +10,8 @@ from contextlib import redirect_stderr
|
||||
from torch.testing import make_tensor, FileCheck
|
||||
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC
|
||||
from torch.testing._internal.common_utils import \
|
||||
(TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase, run_tests,
|
||||
load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm, IS_FBCODE, IS_REMOTE_GPU,
|
||||
(TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, TEST_CUDA_CUDSS, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase,
|
||||
run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm, IS_FBCODE, IS_REMOTE_GPU,
|
||||
suppress_warnings)
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoSparseGeneric,
|
||||
@ -23,6 +23,7 @@ from torch.testing._internal.common_cuda import _get_torch_cuda_version, TEST_CU
|
||||
from torch.testing._internal.common_dtype import (
|
||||
floating_types, all_types_and_complex_and, floating_and_complex_types, floating_types_and,
|
||||
all_types_and_complex, floating_and_complex_types_and)
|
||||
from torch.testing._internal.opinfo.definitions.linalg import sample_inputs_linalg_solve
|
||||
from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse
|
||||
from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED, HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
|
||||
import operator
|
||||
@ -3487,6 +3488,46 @@ class TestSparseCSR(TestCase):
|
||||
self.assertEqual(torch.tensor(sp_matrix.indices, dtype=torch.int64), plain_indices_mth(pt_matrix))
|
||||
self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values())
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA_CUDSS, "The test requires cudss")
@dtypes(*floating_types())
def test_linalg_solve_sparse_csr_cusolver(self, device, dtype):
    # https://github.com/krshrimali/pytorch/blob/f5ee21dd87a7c5e67ba03bfd77ea22246cabdf0b/test/test_sparse_csr.py

    # Probe once with a tiny system: a build without cuDSS raises a
    # RuntimeError with a known message, in which case the test is skipped.
    try:
        spd = torch.rand(4, 3)
        A = spd.T @ spd
        b = torch.rand(3).cuda()
        A = A.to_sparse_csr().cuda()
        torch.sparse.spsolve(A, b)
    except RuntimeError as e:
        if "Calling linear solver with sparse tensors requires compiling " in str(e):
            self.skipTest("PyTorch was not built with cuDSS support")
        # Any other failure of the probe is a genuine error and must not be
        # silently ignored.
        raise

    samples = sample_inputs_linalg_solve(None, device, dtype)

    for sample in samples:
        # Only non-batched (2D) coefficient matrices are supported.
        if sample.input.ndim != 2:
            continue

        out = torch.zeros(sample.args[0].size(), dtype=dtype, device=device)
        # NOTE(review): both error branches below `break`, which ends the
        # whole loop after the first such sample, so later samples are not
        # checked -- confirm this is intended rather than `continue`.
        if sample.args[0].ndim != 1 and sample.args[0].size(-1) != 1:
            with self.assertRaisesRegex(RuntimeError, "b must be a 1D tensor"):
                out = torch.linalg.solve(sample.input.to_sparse_csr(), *sample.args, **sample.kwargs)
            break
        if not sample.args[0].numel():
            with self.assertRaisesRegex(RuntimeError,
                                        "Expected non-empty other tensor, but found empty tensor"):
                torch.linalg.solve(sample.input.to_sparse_csr(), *sample.args, **sample.kwargs, out=out)
            break

        # Reference result from the dense solver, then the same system through
        # the sparse CSR path.
        expect = torch.linalg.solve(sample.input, *sample.args, **sample.kwargs)
        sample.input = sample.input.to_sparse_csr()
        if sample.args[0].ndim != 1 and sample.args[0].size(-1) == 1:
            # The sparse backend only accepts a 1D right-hand side, so drop
            # the trailing singleton column from both the input and the
            # expected result.
            expect = expect.squeeze(-1)
            sample.args = (sample.args[0].squeeze(-1), )
        out = torch.linalg.solve(sample.input, *sample.args, **sample.kwargs)
        self.assertEqual(expect, out)
|
||||
|
||||
|
||||
def skipIfNoTriton(cls):
|
||||
from torch.utils._triton import has_triton
|
||||
|
@ -2134,6 +2134,9 @@ Letting `*` be zero or more batch dimensions,
|
||||
It is possible to compute the solution of the system :math:`XA = B` by passing the inputs
|
||||
:attr:`A` and :attr:`B` transposed and transposing the output returned by this function.
|
||||
|
||||
.. note::
|
||||
:attr:`A` is allowed to be a non-batched `torch.sparse_csr_tensor`, but only with `left=True`.
|
||||
|
||||
""" + fr"""
|
||||
.. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.solve_ex")}
|
||||
""" + r"""
|
||||
|
@ -32,6 +32,7 @@ __all__ = [
|
||||
"mm",
|
||||
"sum",
|
||||
"softmax",
|
||||
"solve",
|
||||
"log_softmax",
|
||||
"SparseSemiStructuredTensor",
|
||||
"SparseSemiStructuredTensorCUTLASS",
|
||||
@ -297,6 +298,26 @@ Args:
|
||||
)
|
||||
|
||||
|
||||
spsolve = _add_docstr(
|
||||
_sparse._spsolve,
|
||||
r"""
|
||||
sparse.spsolve(input, other, *, left=True) -> Tensor
|
||||
|
||||
Computes the solution of a square system of linear equations with
|
||||
a unique solution. Its purpose is similar to :func:`torch.linalg.solve`,
|
||||
except that the system is defined by a sparse CSR matrix with layout
|
||||
`sparse_csr`.
|
||||
|
||||
Args:
|
||||
input (Tensor): a sparse CSR matrix of shape `(n, n)` representing the
|
||||
coefficients of the linear system.
|
||||
other (Tensor): a dense vector of shape `(n, )` representing the right-hand
|
||||
side of the linear system.
|
||||
left (bool, optional): whether to solve the system for `input @ out = other`
|
||||
(default) or `out @ input = other`. Only `left=True` is supported.
|
||||
""",
|
||||
)
|
||||
|
||||
log_softmax = _add_docstr(
|
||||
_sparse._sparse_log_softmax,
|
||||
r"""
|
||||
|
@ -1427,6 +1427,8 @@ TEST_CUDA_GRAPH = TEST_CUDA and (not TEST_SKIP_CUDAGRAPH) and (
|
||||
(torch.version.hip and float(".".join(torch.version.hip.split(".")[0:2])) >= 5.3)
|
||||
)
|
||||
|
||||
TEST_CUDA_CUDSS = TEST_CUDA and (torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12)
|
||||
|
||||
def allocator_option_enabled_fn(allocator_config, _, option):
|
||||
if allocator_config is None:
|
||||
return False
|
||||
|
Reference in New Issue
Block a user