API change for new enum values in cusparseLtSplitKMode_t for cuSPARSELt 0.7.0+ (#150536)

Change the parameter from bool to int to express split_k_mode. Before cuSPARSELt 0.7.0, cusparseLtSplitKMode_t had only two enum values, ONE_KERNEL and TWO_KERNELS, so a boolean was enough; since 0.7.0 there are more.

For Blackwell, the parameter split_k_one_kernel (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp#L103) needs a minor change: new values were introduced to the enum [cusparseLtSplitKMode_t](https://docs.nvidia.com/cuda/cusparselt/types.html#cusparseltsplitkmode-t), so a bool can no longer represent it and has to be replaced with an integer.
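
As a rough illustration of why the bool had to go, the sketch below models the enum in plain Python. `SplitKMode`, its numeric values, and `EXTRA_MODE` are hypothetical stand-ins for illustration only, not the actual cusparseLtSplitKMode_t definition; `legacy_bool_to_mode` is likewise an assumed helper that mimics how a boolean flag could select between the two original modes.

```python
from enum import IntEnum


class SplitKMode(IntEnum):
    """Hypothetical stand-in for cusparseLtSplitKMode_t (values are placeholders)."""
    ONE_KERNEL = 0
    TWO_KERNELS = 1
    EXTRA_MODE = 2  # placeholder for any mode added in cuSPARSELt 0.7.0+


def legacy_bool_to_mode(split_k_one_kernel: bool) -> SplitKMode:
    """Map the old boolean flag onto the only two modes it could express."""
    return SplitKMode.ONE_KERNEL if split_k_one_kernel else SplitKMode.TWO_KERNELS


# A bool reaches exactly two modes; an int can carry every enum member.
reachable_from_bool = {legacy_bool_to_mode(flag) for flag in (True, False)}
unreachable_from_bool = set(SplitKMode) - reachable_from_bool
```

Any mode left in `unreachable_from_bool` is one that the old split_k_one_kernel flag could never request, which is the motivation for passing split_k_mode as an integer instead.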

Error seen without the change:
```
RuntimeError: CUDA error: invalid value when calling `cusparseLtMatmulAlgSetAttribute( &handle, &alg_sel, CUSPARSELT_MATMUL_SPLIT_K_MODE, &splitKMode, sizeof(splitKMode))`

To execute this test, run the following from the base repo dir:
    python test/test_sparse_semi_structured.py TestSparseSemiStructuredCUSPARSELTCUDA.test_csrc_cslt_sparse_mm_search_cuda_int8
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150536
Approved by: https://github.com/jcaip, https://github.com/atalman
Ting Lu
2025-05-14 23:36:53 +00:00
committed by PyTorch MergeBot
parent 72fee137dd
commit c2bc7e2827
7 changed files with 29 additions and 27 deletions


@@ -1207,11 +1207,11 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         B = torch.ones((128, 128), device=device).to(dtype)
         A_compressed = torch._cslt_compress(A)
-        alg_id, split_k, split_k_one_kernel, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False)
+        alg_id, split_k, split_k_mode, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False)
         sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(),
                                               alg_id=alg_id,
                                               split_k=split_k,
-                                              split_k_one_kernel=split_k_one_kernel)
+                                              split_k_mode=split_k_mode)
         dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
         dense_result = dense_result.to(dtype)
         torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)