mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
API change for new enum in cusparseLtSplitKMode_t for cuSPARSELt 0.7.0+ (#150536)
Changes the bool to an int to express split_k_mode. Before 0.7.0 there were only 2 cusparseLtSplitKMode_t enum values, ONE_KERNEL and TWO_KERNELS, so a boolean was enough, but since 0.7.0 there are more. For Blackwell, a minor change is required to the parameter split_k_one_kernel (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp#L103), since new values were introduced to the enum [cusparseLtSplitKMode_t](https://docs.nvidia.com/cuda/cusparselt/types.html#cusparseltsplitkmode-t) and a bool type is no longer sufficient to represent it (it has to be replaced with an integer). https://docs.nvidia.com/cuda/cusparselt/types.html#cusparseltsplitkmode-t Error seen without the change: ``` RuntimeError: CUDA error: invalid value when calling `cusparseLtMatmulAlgSetAttribute( &handle, &alg_sel, CUSPARSELT_MATMUL_SPLIT_K_MODE, &splitKMode, sizeof(splitKMode))` To execute this test, run the following from the base repo dir: python test/test_sparse_semi_structured.py TestSparseSemiStructuredCUSPARSELTCUDA.test_csrc_cslt_sparse_mm_search_cuda_int8 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/150536 Approved by: https://github.com/jcaip, https://github.com/atalman
This commit is contained in:
committed by
PyTorch MergeBot
parent
72fee137dd
commit
c2bc7e2827
@ -3385,7 +3385,7 @@
|
||||
dispatch:
|
||||
CUDA: _cslt_compress
|
||||
|
||||
- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0, int split_k=1, bool split_k_one_kernel=True) -> Tensor
|
||||
- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0, int split_k=1, int split_k_mode=-1) -> Tensor
|
||||
dispatch:
|
||||
CUDA: _cslt_sparse_mm
|
||||
tags: needs_fixed_stride_order
|
||||
|
@ -91,7 +91,7 @@ at::Tensor _cslt_compress(const Tensor& sparse_input) {
|
||||
return compressed_tensor;
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt_sparse_mm_impl(
|
||||
std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
|
||||
const Tensor& compressed_A,
|
||||
const Tensor& dense_B,
|
||||
const std::optional<Tensor>& bias_opt,
|
||||
@ -100,7 +100,7 @@ std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt_sparse_mm_impl(
|
||||
bool transpose_result,
|
||||
int alg_id,
|
||||
int split_k,
|
||||
bool split_k_one_kernel,
|
||||
int split_k_mode,
|
||||
bool search_alg_id) {
|
||||
if (!handle_initialized) {
|
||||
TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle));
|
||||
@ -351,14 +351,15 @@ std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt_sparse_mm_impl(
|
||||
&split_k,
|
||||
sizeof(split_k)));
|
||||
|
||||
splitKMode = split_k_one_kernel ? CUSPARSELT_SPLIT_K_MODE_ONE_KERNEL
|
||||
: CUSPARSELT_SPLIT_K_MODE_TWO_KERNELS;
|
||||
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
|
||||
&handle,
|
||||
&alg_sel,
|
||||
CUSPARSELT_MATMUL_SPLIT_K_MODE,
|
||||
&splitKMode,
|
||||
sizeof(splitKMode)));
|
||||
if (split_k_mode > 0) {
|
||||
splitKMode = static_cast<cusparseLtSplitKMode_t>(split_k_mode);
|
||||
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
|
||||
&handle,
|
||||
&alg_sel,
|
||||
CUSPARSELT_MATMUL_SPLIT_K_MODE,
|
||||
&splitKMode,
|
||||
sizeof(splitKMode)));
|
||||
}
|
||||
}
|
||||
|
||||
// set tensor_alpha_mode and alpha pointer for matmul
|
||||
@ -465,7 +466,7 @@ std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt_sparse_mm_impl(
|
||||
res,
|
||||
alg_id,
|
||||
split_k,
|
||||
splitKMode == CUSPARSELT_SPLIT_K_MODE_ONE_KERNEL,
|
||||
static_cast<int64_t>(splitKMode),
|
||||
max_alg_id};
|
||||
}
|
||||
|
||||
@ -478,7 +479,7 @@ at::Tensor _cslt_sparse_mm(
|
||||
bool transpose_result,
|
||||
int64_t alg_id,
|
||||
int64_t split_k,
|
||||
bool split_k_one_kernel) {
|
||||
int64_t split_k_mode) {
|
||||
auto result = _cslt_sparse_mm_impl(
|
||||
compressed_A,
|
||||
dense_B,
|
||||
@ -488,7 +489,7 @@ at::Tensor _cslt_sparse_mm(
|
||||
transpose_result,
|
||||
(int)alg_id,
|
||||
(int)split_k,
|
||||
split_k_one_kernel,
|
||||
(int)split_k_mode,
|
||||
false);
|
||||
return std::get<0>(result);
|
||||
}
|
||||
@ -504,7 +505,7 @@ int64_t _cslt_sparse_mm_search(
|
||||
"torch._cslt_sparse_mm_search is deprecated and will be removed in a future PyTorch release. Please use torch._C._cusparselt.mm_search instead.");
|
||||
int alg_id_int = 0;
|
||||
int split_k = 1;
|
||||
bool split_k_one_kernel = true;
|
||||
int split_k_mode = -1;
|
||||
auto result = _cslt_sparse_mm_impl(
|
||||
compressed_A,
|
||||
dense_B,
|
||||
@ -514,7 +515,7 @@ int64_t _cslt_sparse_mm_search(
|
||||
transpose_result,
|
||||
alg_id_int,
|
||||
split_k,
|
||||
split_k_one_kernel,
|
||||
split_k_mode,
|
||||
true);
|
||||
return (int64_t)std::get<1>(result);
|
||||
}
|
||||
@ -538,7 +539,7 @@ at::Tensor _cslt_sparse_mm(
|
||||
bool transpose_result,
|
||||
int64_t alg_id,
|
||||
int64_t split_k,
|
||||
bool split_k_one_kernel) {
|
||||
int64_t split_k_mode) {
|
||||
TORCH_CHECK(false, "cuSPARSELt not supported on your machine.");
|
||||
}
|
||||
|
||||
|
@ -21,7 +21,7 @@ namespace at::native {
|
||||
|
||||
at::Tensor _cslt_compress(const Tensor& sparse_input);
|
||||
|
||||
TORCH_CUDA_CPP_API std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt_sparse_mm_impl(
|
||||
TORCH_CUDA_CPP_API std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
|
||||
const Tensor& compressed_A,
|
||||
const Tensor& dense_B,
|
||||
const std::optional<Tensor>& bias_opt,
|
||||
@ -30,7 +30,7 @@ TORCH_CUDA_CPP_API std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt
|
||||
bool transpose_result,
|
||||
int alg_id,
|
||||
int split_k,
|
||||
bool split_k_one_kernel,
|
||||
int split_k_mode,
|
||||
bool search_alg_id
|
||||
);
|
||||
|
||||
@ -43,7 +43,7 @@ at::Tensor _cslt_sparse_mm(
|
||||
bool transpose_result,
|
||||
int64_t alg_id,
|
||||
int64_t split_k,
|
||||
bool split_k_one_kernel
|
||||
int64_t split_k_mode
|
||||
);
|
||||
|
||||
int64_t _cslt_sparse_mm_search(
|
||||
|
@ -70,6 +70,7 @@ ALLOW_LIST = [
|
||||
("profiler::_call_end_callbacks_on_jit_fut*", datetime.date(9999, 1, 1)),
|
||||
("profiler::_record_function_enter", datetime.date(9999, 1, 1)),
|
||||
("aten::_cholesky_helper", datetime.date(9999, 1, 1)),
|
||||
("aten::_cslt_sparse_mm", datetime.date(9999, 1, 1)),
|
||||
("aten::_lstsq_helper", datetime.date(9999, 1, 1)),
|
||||
("aten::_syevd_helper", datetime.date(9999, 1, 1)),
|
||||
("aten::_linalg_solve_out_helper_", datetime.date(9999, 1, 1)),
|
||||
|
@ -1207,11 +1207,11 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
|
||||
B = torch.ones((128, 128), device=device).to(dtype)
|
||||
|
||||
A_compressed = torch._cslt_compress(A)
|
||||
alg_id, split_k, split_k_one_kernel, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False)
|
||||
alg_id, split_k, split_k_mode, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False)
|
||||
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(),
|
||||
alg_id=alg_id,
|
||||
split_k=split_k,
|
||||
split_k_one_kernel=split_k_one_kernel)
|
||||
split_k_mode=split_k_mode)
|
||||
dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
|
||||
dense_result = dense_result.to(dtype)
|
||||
torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
|
||||
|
@ -638,7 +638,7 @@ def meta__cslt_sparse_mm(
|
||||
transpose_result: bool = False,
|
||||
alg_id: int = 0,
|
||||
split_k: int = 1,
|
||||
split_k_one_kernel: bool = False,
|
||||
split_k_mode: int = -1,
|
||||
):
|
||||
assert dense_B.dtype in {
|
||||
torch.float32,
|
||||
|
@ -9,7 +9,7 @@ size_t getVersionInt() {
|
||||
return CUSPARSELT_VERSION;
|
||||
}
|
||||
|
||||
std::tuple<int64_t, int64_t, bool, int64_t> mmSearch(
|
||||
std::tuple<int64_t, int64_t, int64_t, int64_t> mmSearch(
|
||||
const at::Tensor& compressed_A,
|
||||
const at::Tensor& dense_B,
|
||||
const std::optional<at::Tensor>& bias_opt,
|
||||
@ -18,7 +18,7 @@ std::tuple<int64_t, int64_t, bool, int64_t> mmSearch(
|
||||
bool transpose_result) {
|
||||
int alg_id_int = 0;
|
||||
int split_k = 1;
|
||||
bool split_k_one_kernel = true;
|
||||
int split_k_mode = -1;
|
||||
auto result = at::native::_cslt_sparse_mm_impl(
|
||||
compressed_A,
|
||||
dense_B,
|
||||
@ -28,12 +28,12 @@ std::tuple<int64_t, int64_t, bool, int64_t> mmSearch(
|
||||
transpose_result,
|
||||
alg_id_int,
|
||||
split_k,
|
||||
split_k_one_kernel,
|
||||
split_k_mode,
|
||||
true);
|
||||
return {
|
||||
(int64_t)std::get<1>(result),
|
||||
(int64_t)std::get<2>(result),
|
||||
(bool)std::get<3>(result),
|
||||
(int64_t)std::get<3>(result),
|
||||
(int64_t)std::get<4>(result)};
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user