SparseCsrCUDA: cuDSS backend for linalg.solve (#129856)

This PR switches to the cuDSS library and has the same purpose as #127692, which is to add sparse CSR tensor support to linalg.solve.
Fixes #69538

Minimum example of usage:
```
import torch

if __name__ == '__main__':
    spd = torch.rand(4, 3)
    A = spd.T @ spd
    b = torch.rand(3).to(torch.float64).cuda()
    A = A.to_sparse_csr().to(torch.float64).cuda()

    x = torch.linalg.solve(A, b)
    print((A @ x - b).norm())

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129856
Approved by: https://github.com/amjames, https://github.com/lezcano, https://github.com/huydhn

Co-authored-by: Zihang Fang <zhfang1108@gmail.com>
Co-authored-by: Huy Do <huydhn@gmail.com>
This commit is contained in:
Zitong Zhan
2024-08-22 07:57:30 +00:00
committed by PyTorch MergeBot
parent 64cfcbd8a3
commit 90c821814e
22 changed files with 396 additions and 2 deletions

View File

@ -0,0 +1,25 @@
#!/bin/bash
# Download the NVIDIA cuDSS redistributable archive and copy its headers and
# libraries into /usr/local/cuda. The install step only runs when the first
# four characters of CUDA_VERSION are 12.1 .. 12.4; for any other value the
# script just creates and removes its scratch directory and runs ldconfig.
set -ex

# cudss license: https://docs.nvidia.com/cuda/cudss/license.html
mkdir tmp_cudss && cd tmp_cudss

if [[ ${CUDA_VERSION:0:4} =~ ^12\.[1-4]$ ]]; then
    # Use the 'sbsa' archive unless the target architecture is x86_64/amd64.
    arch_path='sbsa'
    export TARGETARCH=${TARGETARCH:-$(uname -m)}
    if [ "${TARGETARCH}" = 'amd64' ] || [ "${TARGETARCH}" = 'x86_64' ]; then
        arch_path='x86_64'
    fi
    CUDSS_NAME="libcudss-linux-${arch_path}-0.3.0.9_cuda12-archive"
    # --fail makes curl exit non-zero on an HTTP error instead of saving the
    # error page under the archive name.
    curl --retry 3 --fail -OLs "https://developer.download.nvidia.com/compute/cudss/redist/libcudss/linux-${arch_path}/${CUDSS_NAME}.tar.xz"

    # only for cuda 12
    tar xf "${CUDSS_NAME}.tar.xz"
    cp -a "${CUDSS_NAME}"/include/* /usr/local/cuda/include/
    cp -a "${CUDSS_NAME}"/lib/* /usr/local/cuda/lib64/
fi

cd ..
rm -rf tmp_cudss
ldconfig

View File

@ -156,6 +156,12 @@ COPY ./common/install_cusparselt.sh install_cusparselt.sh
RUN bash install_cusparselt.sh
RUN rm install_cusparselt.sh
# Install CUDSS
ARG CUDA_VERSION
COPY ./common/install_cudss.sh install_cudss.sh
RUN bash install_cudss.sh
RUN rm install_cudss.sh
# Delete /usr/local/cuda-11.X/cuda-11.X symlinks
RUN if [ -h /usr/local/cuda-11.6/cuda-11.6 ]; then rm /usr/local/cuda-11.6/cuda-11.6; fi
RUN if [ -h /usr/local/cuda-11.7/cuda-11.7 ]; then rm /usr/local/cuda-11.7/cuda-11.7; fi

View File

@ -251,6 +251,7 @@ cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF)
cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
"USE_CUDNN" OFF)
cmake_dependent_option(USE_CUSPARSELT "Use cuSPARSELt" ON "USE_CUDA" OFF)
cmake_dependent_option(USE_CUDSS "Use cuDSS" ON "USE_CUDA" OFF)
# Binary builds will fail for cufile due to https://github.com/pytorch/builder/issues/1924
# Using TH_BINARY_BUILD to check whether is binary build.
# USE_ROCM is guarded against in Dependencies.cmake because USE_ROCM is not properly defined here
@ -1296,6 +1297,10 @@ if(BUILD_SHARED_LIBS)
FILES ${PROJECT_SOURCE_DIR}/cmake/Modules/FindCUSPARSELT.cmake
DESTINATION share/cmake/Caffe2/
COMPONENT dev)
install(
FILES ${PROJECT_SOURCE_DIR}/cmake/Modules/FindCUDSS.cmake
DESTINATION share/cmake/Caffe2/
COMPONENT dev)
install(
FILES ${PROJECT_SOURCE_DIR}/cmake/Modules/FindSYCLToolkit.cmake
DESTINATION share/cmake/Caffe2/

View File

@ -15,6 +15,10 @@
#include <cusolverDn.h>
#endif
#if defined(USE_CUDSS)
#include <cudss.h>
#endif
#if defined(USE_ROCM)
#include <hipsolver/hipsolver.h>
#endif
@ -88,4 +92,8 @@ TORCH_CUDA_CPP_API void clearCublasWorkspaces();
TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();
#endif
#if defined(USE_CUDSS)
TORCH_CUDA_CPP_API cudssHandle_t getCurrentCudssHandle();
#endif
} // namespace at::cuda

View File

@ -64,4 +64,23 @@ C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status) {
} // namespace solver
#endif
#if defined(USE_CUDSS)
namespace cudss {

// Maps a cuDSS status code to its symbolic name as a static C string.
// Codes that are not listed below yield a generic "unknown" message.
C10_EXPORT const char* cudssGetErrorMessage(cudssStatus_t status) {
  const char* message = "Unknown cudss error number";
  switch (status) {
    case CUDSS_STATUS_SUCCESS:
      message = "CUDSS_STATUS_SUCCESS";
      break;
    case CUDSS_STATUS_NOT_INITIALIZED:
      message = "CUDSS_STATUS_NOT_INITIALIZED";
      break;
    case CUDSS_STATUS_ALLOC_FAILED:
      message = "CUDSS_STATUS_ALLOC_FAILED";
      break;
    case CUDSS_STATUS_INVALID_VALUE:
      message = "CUDSS_STATUS_INVALID_VALUE";
      break;
    case CUDSS_STATUS_NOT_SUPPORTED:
      message = "CUDSS_STATUS_NOT_SUPPORTED";
      break;
    case CUDSS_STATUS_EXECUTION_FAILED:
      message = "CUDSS_STATUS_EXECUTION_FAILED";
      break;
    case CUDSS_STATUS_INTERNAL_ERROR:
      message = "CUDSS_STATUS_INTERNAL_ERROR";
      break;
    default:
      break;
  }
  return message;
}

} // namespace cudss
#endif
} // namespace at::cuda

View File

@ -8,6 +8,10 @@
#include <cusolver_common.h>
#endif
#if defined(USE_CUDSS)
#include <cudss.h>
#endif
#include <ATen/Context.h>
#include <c10/util/Exception.h>
#include <c10/cuda/CUDAException.h>
@ -73,6 +77,33 @@ const char *cusparseGetErrorString(cusparseStatus_t status);
" when calling `" #EXPR "`"); \
} while (0)
#if defined(USE_CUDSS)
namespace at::cuda::cudss {
// Returns the symbolic name of a cuDSS status code as a C string.
C10_EXPORT const char* cudssGetErrorMessage(cudssStatus_t error);
} // namespace at::cuda::cudss
// TORCH_CUDSS_CHECK(EXPR) evaluates EXPR exactly once and raises unless the
// resulting status is CUDSS_STATUS_SUCCESS:
//  - CUDSS_STATUS_EXECUTION_FAILED is reported through TORCH_CHECK_LINALG and
//    carries a hint that the input matrix may contain NaN;
//  - every other non-success status is reported through TORCH_CHECK.
// The message contains the status name and the stringified EXPR.
// When USE_CUDSS is not defined the macro expands to EXPR with no checking.
#define TORCH_CUDSS_CHECK(EXPR) \
do { \
cudssStatus_t __err = EXPR; \
if (__err == CUDSS_STATUS_EXECUTION_FAILED) { \
TORCH_CHECK_LINALG( \
false, \
"cudss error: ", \
at::cuda::cudss::cudssGetErrorMessage(__err), \
", when calling `" #EXPR "`", \
". This error may appear if the input matrix contains NaN. ");\
} else { \
TORCH_CHECK( \
__err == CUDSS_STATUS_SUCCESS, \
"cudss error: ", \
at::cuda::cudss::cudssGetErrorMessage(__err), \
", when calling `" #EXPR "`. "); \
} \
} while (0)
#else
#define TORCH_CUDSS_CHECK(EXPR) EXPR
#endif
// cusolver related headers are only supported on cuda now
#ifdef CUDART_VERSION

View File

@ -21,6 +21,7 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_spsolve.h>
#include <ATen/ops/_cholesky_solve_helper.h>
#include <ATen/ops/_cholesky_solve_helper_native.h>
#include <ATen/ops/_linalg_check_errors.h>
@ -1936,6 +1937,9 @@ Tensor& linalg_solve_out(const Tensor& A,
Tensor linalg_solve(const Tensor& A,
const Tensor& B,
bool left) {
if (A.layout() == kSparseCsr) {
return at::_spsolve(A, B, left);
}
auto [result, info] = at::linalg_solve_ex(A, B, left);
at::_linalg_check_errors(info, "torch.linalg.solve", A.dim() == 2);
return result;

View File

@ -0,0 +1,52 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/DeviceThreadHandles.h>
#if defined(USE_CUDSS)

namespace at::cuda {
namespace {

// Pool callback: creates a new cuDSS handle and raises if creation fails.
void createCudssHandle(cudssHandle_t *handle) {
  TORCH_CUDSS_CHECK(cudssCreate(handle));
}

// Pool callback: releases a cuDSS handle unless destruction is compiled out.
void destroyCudssHandle(cudssHandle_t handle) {
  // Skipping destruction is a workaround for a destruction-ordering problem:
  // sometimes at exit the CUDA context (or something) is already gone by the
  // time the handle would be destroyed. It happens in the fbcode setting, and
  // @colesbury and @soumith decided not to destroy the handle there.
  // - Paraphrased from @soumith's comment in the cuDNN handle pool
  //   implementation; the same NO_CUDNN_DESTROY_HANDLE switch is reused here.
#ifndef NO_CUDNN_DESTROY_HANDLE
  cudssDestroy(handle);
#else
  (void)handle; // Suppress unused variable warning
#endif
}

using CudssPoolType = DeviceThreadHandlePool<cudssHandle_t, createCudssHandle, destroyCudssHandle>;

} // namespace

// Returns a cuDSS handle reserved for the calling thread on the current CUDA
// device, with the current CUDA stream already set on it.
cudssHandle_t getCurrentCudssHandle() {
  c10::DeviceIndex current_device = 0;
  AT_CUDA_CHECK(c10::cuda::GetDevice(&current_device));

  // Thread local PoolWindows are lazily-initialized
  // to avoid initialization issues that caused hangs on Windows.
  // See: https://github.com/pytorch/pytorch/pull/22405
  // The thread local unique_ptr is destroyed when the thread terminates,
  // releasing its reserved handles back to the pool.
  static auto pool = std::make_shared<CudssPoolType>();
  thread_local std::unique_ptr<CudssPoolType::PoolWindow> window(
      pool->newPoolWindow());

  cudssHandle_t handle = window->reserve(current_device);
  TORCH_CUDSS_CHECK(cudssSetStream(handle, c10::cuda::getCurrentCUDAStream()));
  return handle;
}

} // namespace at::cuda

#endif

View File

@ -14216,6 +14216,11 @@
- func: linalg_solve(Tensor A, Tensor B, *, bool left=True) -> Tensor
python_module: linalg
# Linear solve with a sparse CSR coefficient matrix `A`, exposed in the
# `sparse` Python module. Only a SparseCsrCUDA kernel is registered.
- func: _spsolve(Tensor A, Tensor B, *, bool left=True) -> Tensor
python_module: sparse
dispatch:
SparseCsrCUDA: _sparse_csr_linear_solve
- func: linalg_solve.out(Tensor A, Tensor B, *, bool left=True, Tensor(a!) out) -> Tensor(a!)
python_module: linalg

View File

@ -705,6 +705,71 @@ struct ReductionMulOp {
__forceinline__ scalar_t identity_cpu() const { return 1; }
};
// Solves the sparse linear system `A x = b` with cuDSS and writes the
// solution into `x`.
//
//   A    - non-batched sparse CSR coefficient matrix.
//   b    - strided 1-D right-hand side with unit stride, b.size(0) == A.size(0).
//   left - must be true; solving `x A = b` is not supported by this backend.
//   x    - strided 1-D output with unit stride, x.size(0) == A.size(1).
//
// A, b and x must share one floating dtype (float or double). Any violated
// precondition raises through TORCH_CHECK and any cuDSS failure raises
// through TORCH_CUDSS_CHECK. In ROCm builds, or builds without USE_CUDSS, the
// function always raises.
void _apply_sparse_csr_linear_solve(
    const Tensor& A,
    const Tensor& b,
    const bool left,
    const Tensor& x) {
#if defined(USE_ROCM) || !defined(USE_CUDSS)
  TORCH_CHECK(
      false,
      "Calling linear solver with sparse tensors requires compiling ",
      "PyTorch with CUDA cuDSS and is not supported in ROCm build.");
#else
  // layout check
  TORCH_CHECK(A.is_sparse_csr(), "A must be a CSR matrix");
  TORCH_CHECK(b.layout() == kStrided, "b must be a strided tensor");
  TORCH_CHECK(x.layout() == kStrided, "x must be a strided tensor");
  // dim check
  // A batched CSR tensor would make A.size(0)/A.size(1) and the index
  // pointers below describe something other than a single matrix.
  TORCH_CHECK(A.dim() == 2, "A must be a 2D (non-batched) CSR matrix");
  TORCH_CHECK(b.dim() == 1, "b must be a 1D tensor");
  TORCH_CHECK(b.stride(0) == 1, "b must be a column major tensor");
  TORCH_CHECK(b.size(0) == A.size(0), "linear system size mismatch.");
  TORCH_CHECK(x.dim() == 1, "x must be a 1D tensor");
  TORCH_CHECK(x.stride(0) == 1, "x must be a column major tensor");
  TORCH_CHECK(x.size(0) == A.size(1), "linear system size mismatch.");
  TORCH_CHECK(A.dtype() == b.dtype() && A.dtype() == x.dtype(), "A, x, and b must have the same dtype");
  TORCH_CHECK(left == true, "only left == true is supported by the Sparse CSR backend");

  // The matrix is described to cuDSS with 32-bit indices (CUDA_R_32I below),
  // so convert the index tensors when they use another integer type.
  Tensor crow = A.crow_indices();
  Tensor col = A.col_indices();
  if (crow.scalar_type() != ScalarType::Int) {
    crow = crow.to(crow.options().dtype(ScalarType::Int));
    col = col.to(col.options().dtype(ScalarType::Int));
  }
  int* rowOffsets = crow.data_ptr<int>();
  int* colIndices = col.data_ptr<int>();
  Tensor values = A.values();

  // Owns every cuDSS object created below so that they are released even when
  // a TORCH_CUDSS_CHECK throws part-way through. The handle itself comes from
  // the handle pool and is not destroyed here.
  struct CudssObjects {
    cudssHandle_t handle = nullptr;
    cudssConfig_t config = nullptr;
    cudssData_t data = nullptr;
    cudssMatrix_t A_mt = nullptr;
    cudssMatrix_t x_mt = nullptr;
    cudssMatrix_t b_mt = nullptr;

    ~CudssObjects() {
      // Best-effort cleanup on the error path: a destructor must not throw,
      // so the returned statuses are intentionally ignored.
      if (config != nullptr) {
        cudssConfigDestroy(config);
      }
      if (data != nullptr) {
        cudssDataDestroy(handle, data);
      }
      if (A_mt != nullptr) {
        cudssMatrixDestroy(A_mt);
      }
      if (x_mt != nullptr) {
        cudssMatrixDestroy(x_mt);
      }
      if (b_mt != nullptr) {
        cudssMatrixDestroy(b_mt);
      }
    }
  } objs;

  // cuDSS data structures and handle initialization
  objs.handle = at::cuda::getCurrentCudssHandle();
  TORCH_CUDSS_CHECK(cudssConfigCreate(&objs.config));
  TORCH_CUDSS_CHECK(cudssDataCreate(objs.handle, &objs.data));

  AT_DISPATCH_FLOATING_TYPES(values.scalar_type(), "create_matrix", ([&] {
    scalar_t* values_ptr = values.data_ptr<scalar_t>();
    scalar_t* b_ptr = b.data_ptr<scalar_t>();
    scalar_t* x_ptr = x.data_ptr<scalar_t>();
    auto CUDA_R_TYP = std::is_same<scalar_t, double>::value ? CUDA_R_64F : CUDA_R_32F;
    TORCH_CUDSS_CHECK(cudssMatrixCreateDn(&objs.b_mt, b.size(0), 1, b.size(0), b_ptr, CUDA_R_TYP, CUDSS_LAYOUT_COL_MAJOR));
    TORCH_CUDSS_CHECK(cudssMatrixCreateDn(&objs.x_mt, x.size(0), 1, x.size(0), x_ptr, CUDA_R_TYP, CUDSS_LAYOUT_COL_MAJOR));
    // NOTE(review): the row-end argument points one past the last element of
    // `crow`. Kept as in the original; confirm against the cuDSS documentation
    // whether a null row-end pointer is expected for plain 3-array CSR.
    TORCH_CUDSS_CHECK(cudssMatrixCreateCsr(&objs.A_mt, A.size(0), A.size(1), A._nnz(), rowOffsets, rowOffsets + crow.size(0), colIndices, values_ptr, CUDA_R_32I, CUDA_R_TYP, CUDSS_MTYPE_GENERAL, CUDSS_MVIEW_FULL, CUDSS_BASE_ZERO));
  }));

  TORCH_CUDSS_CHECK(cudssExecute(objs.handle, CUDSS_PHASE_ANALYSIS, objs.config, objs.data, objs.A_mt, objs.x_mt, objs.b_mt));
  TORCH_CUDSS_CHECK(cudssExecute(objs.handle, CUDSS_PHASE_FACTORIZATION, objs.config, objs.data, objs.A_mt, objs.x_mt, objs.b_mt));
  TORCH_CUDSS_CHECK(cudssExecute(objs.handle, CUDSS_PHASE_SOLVE, objs.config, objs.data, objs.A_mt, objs.x_mt, objs.b_mt));

  // Destroy the opaque objects on the success path with status checking.
  // Each slot is cleared *before* its destroy call so the guard's destructor
  // never destroys the same object a second time if the check throws.
  auto release = [](auto& slot) {
    auto taken = slot;
    slot = nullptr;
    return taken;
  };
  TORCH_CUDSS_CHECK(cudssConfigDestroy(release(objs.config)));
  TORCH_CUDSS_CHECK(cudssDataDestroy(objs.handle, release(objs.data)));
  TORCH_CUDSS_CHECK(cudssMatrixDestroy(release(objs.A_mt)));
  TORCH_CUDSS_CHECK(cudssMatrixDestroy(release(objs.x_mt)));
  TORCH_CUDSS_CHECK(cudssMatrixDestroy(release(objs.b_mt)));
#endif
}
} // namespace
Tensor _sparse_csr_sum_cuda(const Tensor& input, IntArrayRef dims_to_sum, bool keepdim, std::optional<ScalarType> dtype) {
@ -735,4 +800,12 @@ Tensor _sparse_csr_prod_cuda(const Tensor& input, IntArrayRef dims_to_reduce, bo
return result;
}
// Allocates the output for a sparse CSR linear solve and delegates the work to
// _apply_sparse_csr_linear_solve. The right-hand side is made contiguous
// first, and the result has the same sizes as that contiguous right-hand side.
Tensor _sparse_csr_linear_solve(const Tensor& A, const Tensor& b, const bool left) {
  const Tensor rhs = b.contiguous();
  Tensor solution = rhs.new_empty(rhs.sizes());
  _apply_sparse_csr_linear_solve(A, rhs, left, solution);
  return solution;
}
} // namespace at::native

View File

@ -1443,6 +1443,7 @@ aten_cuda_cu_source_list = [
"aten/src/ATen/cuda/CUDABlas.cpp",
"aten/src/ATen/cuda/CUDASparseBlas.cpp",
"aten/src/ATen/cuda/CublasHandlePool.cpp",
"aten/src/ATen/native/cuda/linalg/CudssHandlePool.cpp",
"aten/src/ATen/cuda/tunable/StreamTimer.cpp",
"aten/src/ATen/cuda/tunable/Tunable.cpp",
"aten/src/ATen/native/cuda/Activation.cpp",

View File

@ -936,6 +936,10 @@ elseif(USE_CUDA)
target_link_libraries(torch_cuda PRIVATE torch::cusparselt)
target_compile_definitions(torch_cuda PRIVATE USE_CUSPARSELT)
endif()
if(USE_CUDSS)
target_link_libraries(torch_cuda PRIVATE torch::cudss)
target_compile_definitions(torch_cuda PRIVATE USE_CUDSS)
endif()
if(USE_NCCL)
target_link_libraries(torch_cuda PRIVATE __caffe2_nccl)
target_compile_definitions(torch_cuda PRIVATE USE_NCCL)

View File

@ -0,0 +1,67 @@
# Find the CUDSS library
#
# The following variables are optionally searched for defaults
#  CUDSS_ROOT: Base directory where CUDSS is found
#  CUDSS_INCLUDE_DIR: Directory where CUDSS header is searched for
#  CUDSS_LIBRARY: Directory where CUDSS library is searched for
#
# The following are set after configuration is done:
#  CUDSS_FOUND
#  CUDSS_INCLUDE_PATH
#  CUDSS_LIBRARY_PATH
#  CUDSS_VERSION (only when CUDSS_FOUND; "?" if it cannot be parsed)

include(FindPackageHandleStandardArgs)

set(CUDSS_ROOT $ENV{CUDSS_ROOT_DIR} CACHE PATH "Folder containing NVIDIA CUDSS")
# Warn when the deprecated environment variable is actually set. The value is
# compared as a string: `if(DEFINED $ENV{...})` would instead test whether a
# CMake variable *named after the value* exists, so the warning never fired.
# The "x" prefix keeps the comparison safe regardless of policy CMP0054.
if(NOT "x$ENV{CUDSS_ROOT_DIR}" STREQUAL "x")
  message(WARNING "CUDSS_ROOT_DIR is deprecated. Please set CUDSS_ROOT instead.")
endif()
list(APPEND CUDSS_ROOT $ENV{CUDSS_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR})

# Compatible layer for CMake <3.12. CUDSS_ROOT will be accounted in for searching paths and libraries for CMake >=3.12.
list(APPEND CMAKE_PREFIX_PATH ${CUDSS_ROOT})

set(CUDSS_INCLUDE_DIR $ENV{CUDSS_INCLUDE_DIR} CACHE PATH "Folder containing NVIDIA CUDSS header files")
find_path(CUDSS_INCLUDE_PATH cudss.h
  HINTS ${CUDSS_INCLUDE_DIR}
  PATH_SUFFIXES cuda/include cuda include)

set(CUDSS_LIBRARY $ENV{CUDSS_LIBRARY} CACHE PATH "Path to the CUDSS library file (e.g., libcudss.so)")
set(CUDSS_LIBRARY_NAME "libcudss.so")
if(MSVC)
  set(CUDSS_LIBRARY_NAME "cudss.lib")
endif()
find_library(CUDSS_LIBRARY_PATH ${CUDSS_LIBRARY_NAME}
  PATHS ${CUDSS_LIBRARY}
  PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64)

find_package_handle_standard_args(CUDSS DEFAULT_MSG CUDSS_LIBRARY_PATH CUDSS_INCLUDE_PATH)

if(CUDSS_FOUND)
  # Get CUDSS version
  file(READ "${CUDSS_INCLUDE_PATH}/cudss.h" CUDSS_HEADER_CONTENTS)
  string(REGEX MATCH "define CUDSS_VER_MAJOR * +([0-9]+)"
    CUDSS_VERSION_MAJOR "${CUDSS_HEADER_CONTENTS}")
  string(REGEX REPLACE "define CUDSS_VER_MAJOR * +([0-9]+)" "\\1"
    CUDSS_VERSION_MAJOR "${CUDSS_VERSION_MAJOR}")
  string(REGEX MATCH "define CUDSS_VER_MINOR * +([0-9]+)"
    CUDSS_VERSION_MINOR "${CUDSS_HEADER_CONTENTS}")
  string(REGEX REPLACE "define CUDSS_VER_MINOR * +([0-9]+)" "\\1"
    CUDSS_VERSION_MINOR "${CUDSS_VERSION_MINOR}")
  string(REGEX MATCH "define CUDSS_VER_PATCH * +([0-9]+)"
    CUDSS_VERSION_PATCH "${CUDSS_HEADER_CONTENTS}")
  string(REGEX REPLACE "define CUDSS_VER_PATCH * +([0-9]+)" "\\1"
    CUDSS_VERSION_PATCH "${CUDSS_VERSION_PATCH}")
  # Assemble CUDSS version. Use minor version since current major version is 0.
  # Test for an empty string rather than falsiness: a parsed minor version of
  # "0" is valid and must not be reported as "?".
  if("x${CUDSS_VERSION_MINOR}" STREQUAL "x")
    set(CUDSS_VERSION "?")
  else()
    set(CUDSS_VERSION
      "${CUDSS_VERSION_MAJOR}.${CUDSS_VERSION_MINOR}.${CUDSS_VERSION_PATCH}")
  endif()
endif()

mark_as_advanced(CUDSS_ROOT CUDSS_INCLUDE_DIR CUDSS_LIBRARY CUDSS_VERSION)

View File

@ -74,6 +74,7 @@ function(caffe2_print_configuration_summary)
message(STATUS " CUDA static link : ${CAFFE2_STATIC_LINK_CUDA}")
message(STATUS " USE_CUDNN : ${USE_CUDNN}")
message(STATUS " USE_CUSPARSELT : ${USE_CUSPARSELT}")
message(STATUS " USE_CUDSS : ${USE_CUDSS}")
message(STATUS " USE_CUFILE : ${USE_CUFILE}")
message(STATUS " CUDA version : ${CUDA_VERSION}")
message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}")
@ -102,6 +103,10 @@ function(caffe2_print_configuration_summary)
get_target_property(__tmp torch::cusparselt INTERFACE_LINK_LIBRARIES)
message(STATUS " cuSPARSELt library : ${__tmp}")
endif()
if(${USE_CUDSS})
get_target_property(__tmp torch::cudss INTERFACE_LINK_LIBRARIES)
message(STATUS " cuDSS library : ${__tmp}")
endif()
message(STATUS " nvrtc : ${CUDA_nvrtc_LIBRARY}")
message(STATUS " CUDA include path : ${CUDA_INCLUDE_DIRS}")
message(STATUS " NVCC executable : ${CUDA_NVCC_EXECUTABLE}")

View File

@ -248,6 +248,22 @@ else()
message(STATUS "USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support")
endif()
# cuDSS: when enabled and found, expose it as the imported interface target
# torch::cudss; when it cannot be found, switch USE_CUDSS off.
if(NOT USE_CUDSS)
  message(STATUS "USE_CUDSS is set to 0. Compiling without cuDSS support")
else()
  find_package(CUDSS)
  if(CUDSS_FOUND)
    add_library(torch::cudss INTERFACE IMPORTED)
    target_include_directories(torch::cudss INTERFACE ${CUDSS_INCLUDE_PATH})
    target_link_libraries(torch::cudss INTERFACE ${CUDSS_LIBRARY_PATH})
  else()
    message(WARNING "Cannot find CUDSS library. Turning the option off")
    set(USE_CUDSS OFF)
  endif()
endif()
# cufile
if(CAFFE2_USE_CUFILE)
add_library(torch::cufile INTERFACE IMPORTED)

View File

@ -1147,6 +1147,7 @@ multiplication, and ``@`` is matrix multiplication.
:func:`torch.addmm`; no; ``f * M[strided] + f * (M[SparseSemiStructured] @ M[strided]) -> M[strided]``
:func:`torch.addmm`; no; ``f * M[strided] + f * (M[strided] @ M[SparseSemiStructured]) -> M[strided]``
:func:`torch.sparse.addmm`; yes; ``f * M[strided] + f * (M[sparse_coo] @ M[strided]) -> M[strided]``
:func:`torch.sparse.spsolve`; no; ``SOLVE(M[sparse_csr], V[strided]) -> V[strided]``
:func:`torch.sspaddmm`; no; ``f * M[sparse_coo] + f * (M[sparse_coo] @ M[strided]) -> M[sparse_coo]``
:func:`torch.lobpcg`; no; ``GENEIG(M[sparse_coo]) -> M[strided], M[strided]``
:func:`torch.pca_lowrank`; yes; ``PCA(M[sparse_coo]) -> M[strided], M[strided], M[strided]``
@ -1292,6 +1293,7 @@ Torch functions specific to sparse Tensors
hspmm
smm
sparse.softmax
sparse.spsolve
sparse.log_softmax
sparse.spdiags

View File

@ -38,6 +38,9 @@
# USE_CUSPARSELT=0
# disables the cuSPARSELt build
#
# USE_CUDSS=0
# disables the cuDSS build
#
# USE_CUFILE=0
# disables the cuFile build
#

View File

@ -555,6 +555,7 @@ aten::_sparse_sum_backward
aten::_sparse_sum_backward.out
aten::_spdiags
aten::_spdiags.out
aten::_spsolve
aten::_stack
aten::_stack.out
aten::_standard_gamma

View File

@ -10,8 +10,8 @@ from contextlib import redirect_stderr
from torch.testing import make_tensor, FileCheck
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC
from torch.testing._internal.common_utils import \
(TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase, run_tests,
load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm, IS_FBCODE, IS_REMOTE_GPU,
(TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, TEST_CUDA_CUDSS, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase,
run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm, IS_FBCODE, IS_REMOTE_GPU,
suppress_warnings)
from torch.testing._internal.common_device_type import \
(ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoSparseGeneric,
@ -23,6 +23,7 @@ from torch.testing._internal.common_cuda import _get_torch_cuda_version, TEST_CU
from torch.testing._internal.common_dtype import (
floating_types, all_types_and_complex_and, floating_and_complex_types, floating_types_and,
all_types_and_complex, floating_and_complex_types_and)
from torch.testing._internal.opinfo.definitions.linalg import sample_inputs_linalg_solve
from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse
from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED, HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
import operator
@ -3487,6 +3488,46 @@ class TestSparseCSR(TestCase):
self.assertEqual(torch.tensor(sp_matrix.indices, dtype=torch.int64), plain_indices_mth(pt_matrix))
self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values())
@unittest.skipIf(not TEST_CUDA_CUDSS, "The test requires cudss")
@dtypes(*floating_types())
def test_linalg_solve_sparse_csr_cusolver(self, device, dtype):
    """Compare torch.linalg.solve on a sparse CSR ``A`` with the dense result.

    A small CUDA probe solve runs first; if it reports that PyTorch was built
    without cuDSS the test is skipped. Each 2-D sample from
    ``sample_inputs_linalg_solve`` is then solved densely and via its CSR
    conversion, and the two results are compared.
    """
    # https://github.com/krshrimali/pytorch/blob/f5ee21dd87a7c5e67ba03bfd77ea22246cabdf0b/test/test_sparse_csr.py

    # Probe for cuDSS support. Only the "not built with cuDSS" error turns into
    # a skip; any other RuntimeError is a genuine failure and is re-raised
    # instead of being silently swallowed.
    try:
        spd = torch.rand(4, 3)
        A = spd.T @ spd
        b = torch.rand(3).cuda()
        A = A.to_sparse_csr().cuda()
        torch.sparse.spsolve(A, b)
    except RuntimeError as e:
        if "Calling linear solver with sparse tensors requires compiling " in str(e):
            self.skipTest("PyTorch was not built with cuDSS support")
        raise

    samples = sample_inputs_linalg_solve(None, device, dtype)

    for sample in samples:
        # The sparse CSR backend only handles non-batched coefficient matrices.
        if sample.input.ndim != 2:
            continue

        out = torch.zeros(sample.args[0].size(), dtype=dtype, device=device)
        if sample.args[0].ndim != 1 and sample.args[0].size(-1) != 1:
            with self.assertRaisesRegex(RuntimeError, "b must be a 1D tensor"):
                out = torch.linalg.solve(sample.input.to_sparse_csr(), *sample.args, **sample.kwargs)
            # NOTE(review): `break` ends the whole loop at the first such
            # sample, so later samples are never checked. Kept as in the
            # original; confirm whether `continue` was intended.
            break
        if not sample.args[0].numel():
            with self.assertRaisesRegex(RuntimeError,
                                        "Expected non-empty other tensor, but found empty tensor"):
                torch.linalg.solve(sample.input.to_sparse_csr(), *sample.args, **sample.kwargs, out=out)
            break

        expect = torch.linalg.solve(sample.input, *sample.args, **sample.kwargs)
        sample.input = sample.input.to_sparse_csr()
        # A right-hand side shaped (n, 1) is squeezed to (n,) because the
        # sparse path only accepts a 1-D `b`; the dense reference is squeezed
        # to match.
        if sample.args[0].ndim != 1 and sample.args[0].size(-1) == 1:
            expect = expect.squeeze(-1)
            sample.args = (sample.args[0].squeeze(-1), )
        out = torch.linalg.solve(sample.input, *sample.args, **sample.kwargs)
        self.assertEqual(expect, out)
def skipIfNoTriton(cls):
from torch.utils._triton import has_triton

View File

@ -2134,6 +2134,9 @@ Letting `*` be zero or more batch dimensions,
It is possible to compute the solution of the system :math:`XA = B` by passing the inputs
:attr:`A` and :attr:`B` transposed and transposing the output returned by this function.
.. note::
:attr:`A` is allowed to be a non-batched `torch.sparse_csr_tensor`, but only with `left=True`.
""" + fr"""
.. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.solve_ex")}
""" + r"""

View File

@ -32,6 +32,7 @@ __all__ = [
"mm",
"sum",
"softmax",
"solve",
"log_softmax",
"SparseSemiStructuredTensor",
"SparseSemiStructuredTensorCUTLASS",
@ -297,6 +298,26 @@ Args:
)
# NOTE(review): ``__all__`` in this module lists "solve", but the public name
# bound here is ``spsolve`` -- confirm the export list and fix the mismatch.
spsolve = _add_docstr(
    _sparse._spsolve,
    r"""
sparse.spsolve(input, other, *, left=True) -> Tensor

Computes the solution of a square system of linear equations with
a unique solution. Its purpose is similar to :func:`torch.linalg.solve`,
except that the system is defined by a sparse CSR matrix with layout
`sparse_csr`.

.. note::
    This function is only implemented for CUDA tensors and requires PyTorch
    to be built with cuDSS. :attr:`input` and :attr:`other` must have the same
    floating-point dtype (`torch.float32` or `torch.float64`).

Args:
    input (Tensor): a non-batched sparse CSR matrix of shape `(n, n)` representing
        the coefficients of the linear system.
    other (Tensor): a dense 1-D tensor of shape `(n, )` representing the right-hand
        side of the linear system.
    left (bool, optional): whether to solve the system for `input @ out = other`
        (default) or `out @ input = other`. Only `left=True` is supported.
""",
)
log_softmax = _add_docstr(
_sparse._sparse_log_softmax,
r"""

View File

@ -1427,6 +1427,8 @@ TEST_CUDA_GRAPH = TEST_CUDA and (not TEST_SKIP_CUDAGRAPH) and (
(torch.version.hip and float(".".join(torch.version.hip.split(".")[0:2])) >= 5.3)
)
# Gate for cuDSS-dependent tests: truthy only when CUDA tests are enabled and
# the CUDA major version reported by torch.version.cuda is at least 12.
# NOTE(review): this is a version-based proxy; it does not itself verify that
# PyTorch was compiled with cuDSS support.
TEST_CUDA_CUDSS = TEST_CUDA and (torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12)
def allocator_option_enabled_fn(allocator_config, _, option):
if allocator_config is None:
return False