API change for new enum in cusparseltsplitkmode-t for cusparseLT 0.7.0+ (#150536)

Changing the bool to an int to express split_k_mode. Before 0.7.0 there were only 2 cusparseLtSplitKMode_t enum values (ONE_KERNEL and TWO_KERNELS), so a boolean was enough, but since 0.7.0 there are more.

For Blackwell, a minor change to the parameter split_k_one_kernel (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp#L103) is required, since new values were introduced to the enum [cusparseLtSplitKMode_t](https://docs.nvidia.com/cuda/cusparselt/types.html#cusparseltsplitkmode-t) and a bool type is no longer sufficient to represent it (it has to be replaced with an integer).

Error we see without the change
```
RuntimeError: CUDA error: invalid value when calling `cusparseLtMatmulAlgSetAttribute( &handle, &alg_sel, CUSPARSELT_MATMUL_SPLIT_K_MODE, &splitKMode, sizeof(splitKMode))`

To execute this test, run the following from the base repo dir:
    python test/test_sparse_semi_structured.py TestSparseSemiStructuredCUSPARSELTCUDA.test_csrc_cslt_sparse_mm_search_cuda_int8
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150536
Approved by: https://github.com/jcaip, https://github.com/atalman
This commit is contained in:
Ting Lu
2025-05-14 23:36:53 +00:00
committed by PyTorch MergeBot
parent 72fee137dd
commit c2bc7e2827
7 changed files with 29 additions and 27 deletions

View File

@ -3385,7 +3385,7 @@
dispatch:
CUDA: _cslt_compress
- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0, int split_k=1, bool split_k_one_kernel=True) -> Tensor
- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0, int split_k=1, int split_k_mode=-1) -> Tensor
dispatch:
CUDA: _cslt_sparse_mm
tags: needs_fixed_stride_order

View File

@ -91,7 +91,7 @@ at::Tensor _cslt_compress(const Tensor& sparse_input) {
return compressed_tensor;
}
std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt_sparse_mm_impl(
std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
const Tensor& compressed_A,
const Tensor& dense_B,
const std::optional<Tensor>& bias_opt,
@ -100,7 +100,7 @@ std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt_sparse_mm_impl(
bool transpose_result,
int alg_id,
int split_k,
bool split_k_one_kernel,
int split_k_mode,
bool search_alg_id) {
if (!handle_initialized) {
TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle));
@ -351,14 +351,15 @@ std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt_sparse_mm_impl(
&split_k,
sizeof(split_k)));
splitKMode = split_k_one_kernel ? CUSPARSELT_SPLIT_K_MODE_ONE_KERNEL
: CUSPARSELT_SPLIT_K_MODE_TWO_KERNELS;
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
&handle,
&alg_sel,
CUSPARSELT_MATMUL_SPLIT_K_MODE,
&splitKMode,
sizeof(splitKMode)));
if (split_k_mode > 0) {
splitKMode = static_cast<cusparseLtSplitKMode_t>(split_k_mode);
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
&handle,
&alg_sel,
CUSPARSELT_MATMUL_SPLIT_K_MODE,
&splitKMode,
sizeof(splitKMode)));
}
}
// set tensor_alpha_mode and alpha pointer for matmul
@ -465,7 +466,7 @@ std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt_sparse_mm_impl(
res,
alg_id,
split_k,
splitKMode == CUSPARSELT_SPLIT_K_MODE_ONE_KERNEL,
static_cast<int64_t>(splitKMode),
max_alg_id};
}
@ -478,7 +479,7 @@ at::Tensor _cslt_sparse_mm(
bool transpose_result,
int64_t alg_id,
int64_t split_k,
bool split_k_one_kernel) {
int64_t split_k_mode) {
auto result = _cslt_sparse_mm_impl(
compressed_A,
dense_B,
@ -488,7 +489,7 @@ at::Tensor _cslt_sparse_mm(
transpose_result,
(int)alg_id,
(int)split_k,
split_k_one_kernel,
(int)split_k_mode,
false);
return std::get<0>(result);
}
@ -504,7 +505,7 @@ int64_t _cslt_sparse_mm_search(
"torch._cslt_sparse_mm_search is deprecated and will be removed in a future PyTorch release. Please use torch._C._cusparselt.mm_search instead.");
int alg_id_int = 0;
int split_k = 1;
bool split_k_one_kernel = true;
int split_k_mode = -1;
auto result = _cslt_sparse_mm_impl(
compressed_A,
dense_B,
@ -514,7 +515,7 @@ int64_t _cslt_sparse_mm_search(
transpose_result,
alg_id_int,
split_k,
split_k_one_kernel,
split_k_mode,
true);
return (int64_t)std::get<1>(result);
}
@ -538,7 +539,7 @@ at::Tensor _cslt_sparse_mm(
bool transpose_result,
int64_t alg_id,
int64_t split_k,
bool split_k_one_kernel) {
int64_t split_k_mode) {
TORCH_CHECK(false, "cuSPARSELt not supported on your machine.");
}

View File

@ -21,7 +21,7 @@ namespace at::native {
at::Tensor _cslt_compress(const Tensor& sparse_input);
TORCH_CUDA_CPP_API std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt_sparse_mm_impl(
TORCH_CUDA_CPP_API std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
const Tensor& compressed_A,
const Tensor& dense_B,
const std::optional<Tensor>& bias_opt,
@ -30,7 +30,7 @@ TORCH_CUDA_CPP_API std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt
bool transpose_result,
int alg_id,
int split_k,
bool split_k_one_kernel,
int split_k_mode,
bool search_alg_id
);
@ -43,7 +43,7 @@ at::Tensor _cslt_sparse_mm(
bool transpose_result,
int64_t alg_id,
int64_t split_k,
bool split_k_one_kernel
int64_t split_k_mode
);
int64_t _cslt_sparse_mm_search(

View File

@ -70,6 +70,7 @@ ALLOW_LIST = [
("profiler::_call_end_callbacks_on_jit_fut*", datetime.date(9999, 1, 1)),
("profiler::_record_function_enter", datetime.date(9999, 1, 1)),
("aten::_cholesky_helper", datetime.date(9999, 1, 1)),
("aten::_cslt_sparse_mm", datetime.date(9999, 1, 1)),
("aten::_lstsq_helper", datetime.date(9999, 1, 1)),
("aten::_syevd_helper", datetime.date(9999, 1, 1)),
("aten::_linalg_solve_out_helper_", datetime.date(9999, 1, 1)),

View File

@ -1207,11 +1207,11 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
B = torch.ones((128, 128), device=device).to(dtype)
A_compressed = torch._cslt_compress(A)
alg_id, split_k, split_k_one_kernel, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False)
alg_id, split_k, split_k_mode, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False)
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(),
alg_id=alg_id,
split_k=split_k,
split_k_one_kernel=split_k_one_kernel)
split_k_mode=split_k_mode)
dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
dense_result = dense_result.to(dtype)
torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

View File

@ -638,7 +638,7 @@ def meta__cslt_sparse_mm(
transpose_result: bool = False,
alg_id: int = 0,
split_k: int = 1,
split_k_one_kernel: bool = False,
split_k_mode: int = -1,
):
assert dense_B.dtype in {
torch.float32,

View File

@ -9,7 +9,7 @@ size_t getVersionInt() {
return CUSPARSELT_VERSION;
}
std::tuple<int64_t, int64_t, bool, int64_t> mmSearch(
std::tuple<int64_t, int64_t, int64_t, int64_t> mmSearch(
const at::Tensor& compressed_A,
const at::Tensor& dense_B,
const std::optional<at::Tensor>& bias_opt,
@ -18,7 +18,7 @@ std::tuple<int64_t, int64_t, bool, int64_t> mmSearch(
bool transpose_result) {
int alg_id_int = 0;
int split_k = 1;
bool split_k_one_kernel = true;
int split_k_mode = -1;
auto result = at::native::_cslt_sparse_mm_impl(
compressed_A,
dense_B,
@ -28,12 +28,12 @@ std::tuple<int64_t, int64_t, bool, int64_t> mmSearch(
transpose_result,
alg_id_int,
split_k,
split_k_one_kernel,
split_k_mode,
true);
return {
(int64_t)std::get<1>(result),
(int64_t)std::get<2>(result),
(bool)std::get<3>(result),
(int64_t)std::get<3>(result),
(int64_t)std::get<4>(result)};
}