[sparse] add extra options to _cslt_sparse_mm (#137427)

Summary:

This work is split into two PRs: this one covers the cuSPARSELt improvements,
and a follow-up covers the inductor lowering.

This PR adds the additional cuSPARSELt bindings to PyTorch:

* `torch._cslt_sparse_mm` now accepts `split_k` and `split_k_one_kernel`
  arguments in addition to `alg_id` (see the usage sketch after this list)

* `torch._cslt_sparse_mm_search` will be deprecated in a future PR, so a
  deprecation warning has been added; `torch._C._cusparselt.mm_search` is the
  intended replacement

* Added a header file (cuSPARSELtOps.h) so the cuSPARSELtOps.cpp
  implementations can be reused from other C++ code

* The maximum algorithm id is now exposed in `torch.backends.cusparselt` via
  `torch.backends.cusparselt.get_max_alg_id()`

* Fixed the meta registration of `_cslt_sparse_mm` for float8 inputs
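
A rough usage sketch of the new knobs (based on the updated tests in this
PR; shapes and dtypes are illustrative, and it assumes a cuSPARSELt-enabled
build on a supported GPU):

    import torch

    # 2:4-sparse fp16 operand: the repeating 0,0,1,1 pattern keeps every row 50% sparse.
    A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
    B = torch.ones((128, 128), dtype=torch.float16, device="cuda")

    A_compressed = torch._cslt_compress(A)

    # mm_search returns (alg_id, split_k, split_k_one_kernel, max_alg_id).
    alg_id, split_k, split_k_one_kernel, max_alg_id = torch._C._cusparselt.mm_search(
        A_compressed, B.t(), None, None, None, False
    )
    assert alg_id in range(torch.backends.cusparselt.get_max_alg_id())

    # Replay the found configuration through the new _cslt_sparse_mm arguments.
    out = torch._cslt_sparse_mm(
        A_compressed,
        B.t(),
        alg_id=alg_id,
        split_k=split_k,
        split_k_one_kernel=split_k_one_kernel,
    )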

Test Plan:

python test/test_sparse_semi_structured.py


Pull Request resolved: https://github.com/pytorch/pytorch/pull/137427
Approved by: https://github.com/cpuhrsch, https://github.com/eqy
Author: Jesse Cai
Date: 2024-11-25 11:31:05 -08:00
Committed by: PyTorch MergeBot
Parent: 02990fe36b
Commit: f1451163ec
8 changed files with 450 additions and 528 deletions


@ -3371,7 +3371,7 @@
   dispatch:
     CUDA: _cslt_compress
 
-- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0) -> Tensor
+- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0, int split_k=1, bool split_k_one_kernel=True) -> Tensor
   dispatch:
     CUDA: _cslt_sparse_mm


@ -1,109 +1,97 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/native/sparse/cuda/cuSPARSELtOps.h>
#include <ATen/cuda/CUDADataType.h>
#include <ATen/cuda/CUDASparse.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Functions.h>
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/Half.h>
#include <cusparse.h>
#include <cstdint>
#if AT_CUSPARSELT_ENABLED() #if AT_CUSPARSELT_ENABLED()
#include <cusparseLt.h>
namespace at::native { namespace at::native {
// Ideally we would use the same DeviceThreadHandlePool mechanism as used in aten/src/ATen/cuda/CuSparseHandlePool.cpp // Ideally we would use the same DeviceThreadHandlePool mechanism as used in
// which would handle this for us. However, the cuSPARSELt handle signature is different from that of cuSPARSE/cuBLAS, // aten/src/ATen/cuda/CuSparseHandlePool.cpp which would handle this for us.
// so it's not possible to reuse the existing pooling mechanism. Instead we have to handle our handles ourselves, which // However, the cuSPARSELt handle signature is different from that of
// is why these variables are thread local. Once cuSPARSELt updates their handle signature to be consistent with the rest // cuSPARSE/cuBLAS, so it's not possible to reuse the existing pooling
// of CUDA, we can switch to using DeviceThreadHandlePool. // mechanism. Instead we have to handle our handles ourselves, which is why
// these variables are thread local. Once cuSPARSELt updates their handle
// signature to be consistent with the rest of CUDA, we can switch to using
// DeviceThreadHandlePool.
thread_local cusparseLtHandle_t handle; thread_local cusparseLtHandle_t handle;
thread_local bool handle_initialized = false; thread_local bool handle_initialized = false;
at::Tensor _cslt_compress(const Tensor& sparse_input) at::Tensor _cslt_compress(const Tensor& sparse_input) {
{ if (!handle_initialized) {
if (!handle_initialized){ TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle));
TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle)); handle_initialized = true;
handle_initialized = true; }
} // create sparse descriptor, dtype
// create sparse descriptor, dtype cusparseLtMatDescriptor_t sparse_input_descriptor;
cusparseLtMatDescriptor_t sparse_input_descriptor; cudaDataType type;
cudaDataType type; auto compression_factor = 9;
auto compression_factor = 9;
switch( switch (sparse_input.scalar_type()) {
sparse_input.scalar_type() case at::ScalarType::Char:
) type = CUDA_R_8I;
{ compression_factor = 10;
case at::ScalarType::Char: break;
type = CUDA_R_8I; case at::ScalarType::Half:
compression_factor = 10; type = CUDA_R_16F;
break; break;
case at::ScalarType::Half: case at::ScalarType::BFloat16:
type = CUDA_R_16F; type = CUDA_R_16BF;
break; break;
case at::ScalarType::BFloat16: case at::ScalarType::Float:
type = CUDA_R_16BF; type = CUDA_R_32F;
break; break;
case at::ScalarType::Float:
type = CUDA_R_32F;
break;
#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 #if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
case at::ScalarType::Float8_e4m3fn: case at::ScalarType::Float8_e4m3fn:
type = CUDA_R_8F_E4M3; type = CUDA_R_8F_E4M3;
break; compression_factor = 10;
break;
#endif #endif
default: default:
TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt compressed matrix"); TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt compressed matrix");
break; break;
} }
// create a new compressed tensor with the same dtype as // create a new compressed tensor with the same dtype as
auto compressed_tensor = sparse_input.new_empty(sparse_input.numel() * compression_factor / 16); auto compressed_tensor =
sparse_input.new_empty(sparse_input.numel() * compression_factor / 16);
TORCH_CUDASPARSE_CHECK(cusparseLtStructuredDescriptorInit( TORCH_CUDASPARSE_CHECK(cusparseLtStructuredDescriptorInit(
&handle, &handle,
&sparse_input_descriptor, &sparse_input_descriptor,
sparse_input.size(0), sparse_input.size(0),
sparse_input.size(1), sparse_input.size(1),
sparse_input.size(1), sparse_input.size(1),
16, 16,
type, type,
CUSPARSE_ORDER_ROW, CUSPARSE_ORDER_ROW,
CUSPARSELT_SPARSITY_50_PERCENT)); CUSPARSELT_SPARSITY_50_PERCENT));
// compress input // compress input
//-------------------------------------------------------------------------- //--------------------------------------------------------------------------
size_t compressed_size, compressed_buffer_size; size_t compressed_size, compressed_buffer_size;
TORCH_CUDASPARSE_CHECK(cusparseLtSpMMACompressedSize2( TORCH_CUDASPARSE_CHECK(cusparseLtSpMMACompressedSize2(
&handle, &handle,
&sparse_input_descriptor, &sparse_input_descriptor,
&compressed_size, &compressed_size,
&compressed_buffer_size)); &compressed_buffer_size));
auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto compressedBufferPtr = allocator.allocate(compressed_buffer_size); auto compressedBufferPtr = allocator.allocate(compressed_buffer_size);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_CUDASPARSE_CHECK(cusparseLtSpMMACompress2( TORCH_CUDASPARSE_CHECK(cusparseLtSpMMACompress2(
&handle, &handle,
&sparse_input_descriptor, &sparse_input_descriptor,
true, true,
CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE,
sparse_input.data_ptr(), sparse_input.data_ptr(),
compressed_tensor.data_ptr(), compressed_tensor.data_ptr(),
compressedBufferPtr.get(), compressedBufferPtr.get(),
stream)); stream));
return compressed_tensor; return compressed_tensor;
} }
std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl( std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt_sparse_mm_impl(
const Tensor& compressed_A, const Tensor& compressed_A,
const Tensor& dense_B, const Tensor& dense_B,
const std::optional<Tensor>& bias_opt, const std::optional<Tensor>& bias_opt,
@ -111,12 +99,12 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
const std::optional<c10::ScalarType> out_dtype_opt, const std::optional<c10::ScalarType> out_dtype_opt,
bool transpose_result, bool transpose_result,
int alg_id, int alg_id,
bool search_alg_id int split_k,
) bool split_k_one_kernel,
{ bool search_alg_id) {
if (!handle_initialized){ if (!handle_initialized) {
TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle)); TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle));
handle_initialized = true; handle_initialized = true;
} }
// cupsarselt constructs // cupsarselt constructs
cusparseLtMatmulDescriptor_t matmul; cusparseLtMatmulDescriptor_t matmul;
@ -132,134 +120,138 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
cusparseComputeType compute_type; cusparseComputeType compute_type;
auto compression_factor = 9; auto compression_factor = 9;
switch(compressed_A.scalar_type()) switch (compressed_A.scalar_type()) {
{
case at::ScalarType::Char: case at::ScalarType::Char:
input_type = CUDA_R_8I; input_type = CUDA_R_8I;
output_type = CUDA_R_8I; output_type = CUDA_R_8I;
C_type = CUDA_R_8I; C_type = CUDA_R_8I;
compute_type = CUSPARSE_COMPUTE_32I; compute_type = CUSPARSE_COMPUTE_32I;
compression_factor = 10; compression_factor = 10;
break; break;
// cuSPARSELt v0.5.2 onwards changes CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUT_16F to CUSPARSE_COMPUTE_32F // cuSPARSELt v0.5.2 onwards changes CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUT_16F
// to CUSPARSE_COMPUTE_32F
#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 502 #if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 502
case at::ScalarType::Half: case at::ScalarType::Half:
input_type = CUDA_R_16F; input_type = CUDA_R_16F;
output_type = CUDA_R_16F; output_type = CUDA_R_16F;
C_type = CUDA_R_16F; C_type = CUDA_R_16F;
compute_type = CUSPARSE_COMPUTE_32F; compute_type = CUSPARSE_COMPUTE_32F;
break; break;
case at::ScalarType::BFloat16: case at::ScalarType::BFloat16:
input_type = CUDA_R_16BF; input_type = CUDA_R_16BF;
output_type = CUDA_R_16BF; output_type = CUDA_R_16BF;
C_type = CUDA_R_16BF; C_type = CUDA_R_16BF;
compute_type = CUSPARSE_COMPUTE_32F; compute_type = CUSPARSE_COMPUTE_32F;
break; break;
case at::ScalarType::Float: case at::ScalarType::Float:
input_type = CUDA_R_32F; input_type = CUDA_R_32F;
output_type = CUDA_R_32F; output_type = CUDA_R_32F;
C_type = CUDA_R_32F; C_type = CUDA_R_32F;
compute_type = CUSPARSE_COMPUTE_32F; compute_type = CUSPARSE_COMPUTE_32F;
break; break;
// if cuSPARSELt >= 6.2.3, we can add Float8 support // if cuSPARSELt >= 6.2.3, we can add Float8 support
#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 #if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
case at::ScalarType::Float8_e4m3fn: case at::ScalarType::Float8_e4m3fn:
input_type = CUDA_R_8F_E4M3; input_type = CUDA_R_8F_E4M3;
output_type = CUDA_R_8F_E4M3; output_type = CUDA_R_8F_E4M3;
C_type = CUDA_R_16F; C_type = CUDA_R_16F;
compute_type = CUSPARSE_COMPUTE_32F; compute_type = CUSPARSE_COMPUTE_32F;
break; compression_factor = 10;
break;
#endif #endif
// cuSPARSELt <= v0.5.2 uses CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUTE_16F // cuSPARSELt <= v0.5.2 uses CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUTE_16F
#else #else
case at::ScalarType::Half: case at::ScalarType::Half:
input_type = CUDA_R_16F; input_type = CUDA_R_16F;
output_type = CUDA_R_16F; output_type = CUDA_R_16F;
C_type = CUDA_R_16F; C_type = CUDA_R_16F;
compute_type = CUSPARSE_COMPUTE_16F; compute_type = CUSPARSE_COMPUTE_16F;
break; break;
case at::ScalarType::BFloat16: case at::ScalarType::BFloat16:
input_type = CUDA_R_16BF; input_type = CUDA_R_16BF;
output_type = CUDA_R_16BF; output_type = CUDA_R_16BF;
C_type = CUDA_R_16BF; C_type = CUDA_R_16BF;
compute_type = CUSPARSE_COMPUTE_16F; compute_type = CUSPARSE_COMPUTE_16F;
break; break;
case at::ScalarType::Float: case at::ScalarType::Float:
input_type = CUDA_R_32F; input_type = CUDA_R_32F;
output_type = CUDA_R_32F; output_type = CUDA_R_32F;
C_type = CUDA_R_32F; C_type = CUDA_R_32F;
compute_type = CUSPARSE_COMPUTE_TF32; compute_type = CUSPARSE_COMPUTE_TF32;
break; break;
#endif #endif
default: default:
TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt compressed matrix multiplication."); TORCH_CHECK(
break; false,
"Unsupported dtype for cuSPARSELt compressed matrix multiplication.");
break;
} }
ScalarType out_dtype = dense_B.scalar_type(); ScalarType out_dtype = dense_B.scalar_type();
// special check for mixed dtype support for 8 bit dtypes // special check for mixed dtype support for 8 bit dtypes
// cslt 0.5.2+: int8 int8 -> {fp16, bf16, int32} support // cslt 0.5.2+: int8 int8 -> {fp16, bf16, int32} support
if (out_dtype_opt.has_value()) { if (out_dtype_opt.has_value()) {
out_dtype = out_dtype_opt.value(); out_dtype = out_dtype_opt.value();
if (input_type == CUDA_R_8I) if (input_type == CUDA_R_8I) {
{ switch (out_dtype) {
switch (out_dtype) case at::ScalarType::Half:
{ C_type = CUDA_R_16F;
case at::ScalarType::Half: output_type = CUDA_R_16F;
C_type = CUDA_R_16F; break;
output_type = CUDA_R_16F; case at::ScalarType::BFloat16:
break; C_type = CUDA_R_16BF;
case at::ScalarType::BFloat16: output_type = CUDA_R_16BF;
C_type = CUDA_R_16BF; break;
output_type = CUDA_R_16BF; case at::ScalarType::Int:
break; C_type = CUDA_R_32I;
case at::ScalarType::Int: output_type = CUDA_R_32I;
C_type = CUDA_R_32I; break;
output_type = CUDA_R_32I; default:
break; TORCH_CHECK(
default: false,
TORCH_CHECK(false, "Unsupported out_dtype passed, must be one of {fp16, bf16, int32} for int8 inputs"); "Unsupported out_dtype passed, must be one of {fp16, bf16, int32} for int8 inputs");
break; break;
} }
} }
// cslt 0.6.2+: fp8 fp8 -> {fp8, fp16, bf16, fp32} support // cslt 0.6.2+: fp8 fp8 -> {fp8, fp16, bf16, fp32} support
#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 #if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
else if (input_type == CUDA_R_8F_E4M3) else if (input_type == CUDA_R_8F_E4M3) {
{ switch (out_dtype) {
switch (out_dtype) case at::ScalarType::Float8_e4m3fn:
{ output_type = CUDA_R_8F_E4M3;
case at::ScalarType::Float8_e4m3fn: C_type = CUDA_R_16F;
output_type = CUDA_R_8F_E4M3; break;
C_type = CUDA_R_16F; case at::ScalarType::Half:
break; output_type = CUDA_R_16F;
case at::ScalarType::Half: C_type = CUDA_R_16F;
output_type = CUDA_R_16F; break;
C_type = CUDA_R_16F; case at::ScalarType::BFloat16:
break; output_type = CUDA_R_16BF;
case at::ScalarType::BFloat16: C_type = CUDA_R_16BF;
output_type = CUDA_R_16BF; break;
C_type = CUDA_R_16BF; case at::ScalarType::Float:
break; output_type = CUDA_R_32F;
case at::ScalarType::Float: C_type = CUDA_R_32F;
output_type = CUDA_R_32F; break;
C_type = CUDA_R_32F; default:
break; TORCH_CHECK(
default: false,
TORCH_CHECK(false, "Unsupported out_dtype passed, must be one of {fp16, bf16, float32} for fp8 inputs"); "Unsupported out_dtype passed, must be one of {fp16, bf16, float32} for fp8 inputs");
break; break;
} }
} }
#endif #endif
else { else {
TORCH_CHECK(false, "out_dtype support only available for int8/fp8 inputs"); TORCH_CHECK(
false, "out_dtype support only available for int8/fp8 inputs");
} }
} }
int64_t k = dense_B.size(0); int64_t k = dense_B.size(0);
int64_t n = dense_B.size(1); int64_t n = dense_B.size(1);
int64_t m = (compressed_A.numel() * 16 / compression_factor ) / k; int64_t m = (compressed_A.numel() * 16 / compression_factor) / k;
//initialize sparse descriptor // initialize sparse descriptor
cusparseLtMatDescriptor_t sparse_input_descriptor; cusparseLtMatDescriptor_t sparse_input_descriptor;
TORCH_CUDASPARSE_CHECK(cusparseLtStructuredDescriptorInit( TORCH_CUDASPARSE_CHECK(cusparseLtStructuredDescriptorInit(
&handle, &handle,
@ -285,7 +277,8 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
CUSPARSE_ORDER_ROW)); CUSPARSE_ORDER_ROW));
// create result tensor // create result tensor
auto res_tensor_options = c10::TensorOptions().dtype(out_dtype).device(dense_B.device()); auto res_tensor_options =
c10::TensorOptions().dtype(out_dtype).device(dense_B.device());
at::Tensor res = (transpose_result) ? at::empty({n, m}, res_tensor_options) at::Tensor res = (transpose_result) ? at::empty({n, m}, res_tensor_options)
: at::empty({m, n}, res_tensor_options); : at::empty({m, n}, res_tensor_options);
@ -295,7 +288,7 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
&res_descriptor, &res_descriptor,
m, m,
n, n,
(transpose_result) ? m: n, (transpose_result) ? m : n,
16, 16,
output_type, output_type,
(transpose_result) ? CUSPARSE_ORDER_COL : CUSPARSE_ORDER_ROW)); (transpose_result) ? CUSPARSE_ORDER_COL : CUSPARSE_ORDER_ROW));
@ -307,7 +300,7 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
&C_descriptor, &C_descriptor,
m, m,
n, n,
(transpose_result) ? m: n, (transpose_result) ? m : n,
16, 16,
C_type, C_type,
(transpose_result) ? CUSPARSE_ORDER_COL : CUSPARSE_ORDER_ROW)); (transpose_result) ? CUSPARSE_ORDER_COL : CUSPARSE_ORDER_ROW));
@ -317,7 +310,8 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
&handle, &handle,
&matmul, &matmul,
CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE,
(dense_B.is_contiguous()) ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE, (dense_B.is_contiguous()) ? CUSPARSE_OPERATION_NON_TRANSPOSE
: CUSPARSE_OPERATION_TRANSPOSE,
&sparse_input_descriptor, &sparse_input_descriptor,
&dense_input_descriptor, &dense_input_descriptor,
&C_descriptor, &C_descriptor,
@ -329,28 +323,59 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
auto& bias = bias_opt.value(); auto& bias = bias_opt.value();
void* dBias = bias.data_ptr(); void* dBias = bias.data_ptr();
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulDescSetAttribute( TORCH_CUDASPARSE_CHECK(cusparseLtMatmulDescSetAttribute(
&handle, &matmul, CUSPARSELT_MATMUL_BIAS_POINTER, &dBias, sizeof(dBias))); &handle,
&matmul,
CUSPARSELT_MATMUL_BIAS_POINTER,
&dBias,
sizeof(dBias)));
} }
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSelectionInit( TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSelectionInit(
&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT)); &handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT));
// set alg_id // set matmul search params
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute( TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
&handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg_id, sizeof(alg_id))); &handle,
&alg_sel,
CUSPARSELT_MATMUL_ALG_CONFIG_ID,
&alg_id,
sizeof(alg_id)));
cusparseLtSplitKMode_t splitKMode;
int max_alg_id;
if (split_k != 1) {
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
&handle,
&alg_sel,
CUSPARSELT_MATMUL_SPLIT_K,
&split_k,
sizeof(split_k)));
splitKMode = split_k_one_kernel ? CUSPARSELT_SPLIT_K_MODE_ONE_KERNEL
: CUSPARSELT_SPLIT_K_MODE_TWO_KERNELS;
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
&handle,
&alg_sel,
CUSPARSELT_MATMUL_SPLIT_K_MODE,
&splitKMode,
sizeof(splitKMode)));
}
// set tensor_alpha_mode and alpha pointer for matmul // set tensor_alpha_mode and alpha pointer for matmul
const auto alpha_tensor = alpha_opt.has_value() ? *alpha_opt: Tensor{}; const auto alpha_tensor = alpha_opt.has_value() ? *alpha_opt : Tensor{};
auto alpha_ptr = &alpha; auto alpha_ptr = &alpha;
if (alpha_opt.has_value()) { if (alpha_opt.has_value()) {
if (alpha_tensor.numel() == 1) { if (alpha_tensor.numel() == 1) {
alpha = alpha_tensor.item<float>(); alpha = alpha_tensor.item<float>();
} } else {
else { tensor_alpha_mode = 1;
tensor_alpha_mode = 1; TORCH_CUDASPARSE_CHECK(cusparseLtMatmulDescSetAttribute(
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulDescSetAttribute( &handle,
&handle, &matmul, CUSPARSELT_MATMUL_ALPHA_VECTOR_SCALING, &tensor_alpha_mode, sizeof(tensor_alpha_mode))); &matmul,
alpha_ptr = static_cast<float*>(alpha_tensor.data_ptr()); CUSPARSELT_MATMUL_ALPHA_VECTOR_SCALING,
&tensor_alpha_mode,
sizeof(tensor_alpha_mode)));
alpha_ptr = static_cast<float*>(alpha_tensor.data_ptr());
} }
} }
@ -365,7 +390,7 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
auto workspacePtr = allocator.allocate(workspace_size); auto workspacePtr = allocator.allocate(workspace_size);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if(search_alg_id){ if (search_alg_id) {
// run matmul search // run matmul search
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulSearch( TORCH_CUDASPARSE_CHECK(cusparseLtMatmulSearch(
&handle, &handle,
@ -381,11 +406,36 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
&stream, &stream,
1)); 1));
// get alg_id used // get matmul params used
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgGetAttribute( TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgGetAttribute(
&handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg_id, sizeof(alg_id))); &handle,
} &alg_sel,
else { CUSPARSELT_MATMUL_ALG_CONFIG_ID,
&alg_id,
sizeof(alg_id)));
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgGetAttribute(
&handle,
&alg_sel,
CUSPARSELT_MATMUL_SPLIT_K,
&split_k,
sizeof(split_k)));
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgGetAttribute(
&handle,
&alg_sel,
CUSPARSELT_MATMUL_SPLIT_K_MODE,
&splitKMode,
sizeof(splitKMode)));
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgGetAttribute(
&handle,
&alg_sel,
CUSPARSELT_MATMUL_ALG_CONFIG_MAX_ID,
&max_alg_id,
sizeof(max_alg_id)));
} else {
// do normal matmul // do normal matmul
TORCH_CUDASPARSE_CHECK(cusparseLtMatmul( TORCH_CUDASPARSE_CHECK(cusparseLtMatmul(
&handle, &handle,
@ -402,7 +452,7 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
1)); 1));
} }
//destroy descriptors // destroy descriptors
TORCH_CUDASPARSE_CHECK( TORCH_CUDASPARSE_CHECK(
cusparseLtMatDescriptorDestroy(&sparse_input_descriptor)); cusparseLtMatDescriptorDestroy(&sparse_input_descriptor));
TORCH_CUDASPARSE_CHECK( TORCH_CUDASPARSE_CHECK(
@ -411,7 +461,12 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
// destroy plan // destroy plan
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulPlanDestroy(&plan)); TORCH_CUDASPARSE_CHECK(cusparseLtMatmulPlanDestroy(&plan));
return {alg_id, res}; return {
res,
alg_id,
split_k,
splitKMode == CUSPARSELT_SPLIT_K_MODE_ONE_KERNEL,
max_alg_id};
} }
at::Tensor _cslt_sparse_mm( at::Tensor _cslt_sparse_mm(
@ -421,19 +476,21 @@ at::Tensor _cslt_sparse_mm(
const std::optional<Tensor>& alpha_opt, const std::optional<Tensor>& alpha_opt,
const std::optional<c10::ScalarType> out_dtype_opt, const std::optional<c10::ScalarType> out_dtype_opt,
bool transpose_result, bool transpose_result,
int64_t alg_id int64_t alg_id,
) int64_t split_k,
{ bool split_k_one_kernel) {
auto result = _cslt_sparse_mm_impl( auto result = _cslt_sparse_mm_impl(
compressed_A, compressed_A,
dense_B, dense_B,
bias_opt, bias_opt,
alpha_opt, alpha_opt,
out_dtype_opt, out_dtype_opt,
transpose_result, transpose_result,
(int) alg_id, (int)alg_id,
false); (int)split_k,
return std::get<1>(result); split_k_one_kernel,
false);
return std::get<0>(result);
} }
int64_t _cslt_sparse_mm_search( int64_t _cslt_sparse_mm_search(
@ -442,31 +499,34 @@ int64_t _cslt_sparse_mm_search(
const std::optional<Tensor>& bias_opt, const std::optional<Tensor>& bias_opt,
const std::optional<Tensor>& alpha_opt, const std::optional<Tensor>& alpha_opt,
const std::optional<c10::ScalarType> out_dtype_opt, const std::optional<c10::ScalarType> out_dtype_opt,
bool transpose_result bool transpose_result) {
) TORCH_WARN_ONCE(
{ "torch._cslt_sparse_mm_search is deprecated and will be removed in a future PyTorch release. Please use torch._C._cusparselt.mm_search instead.");
int alg_id_int = 0; int alg_id_int = 0;
auto result = _cslt_sparse_mm_impl( int split_k = 1;
compressed_A, bool split_k_one_kernel = true;
dense_B, auto result = _cslt_sparse_mm_impl(
bias_opt, compressed_A,
alpha_opt, dense_B,
out_dtype_opt, bias_opt,
transpose_result, alpha_opt,
alg_id_int, out_dtype_opt,
true); transpose_result,
return (int64_t) std::get<0>(result); alg_id_int,
split_k,
split_k_one_kernel,
true);
return (int64_t)std::get<1>(result);
} }
} // namespace at::native } // namespace at::native
#else // No cuSPARSELt support, throw error if these functions are called. #else // No cuSPARSELt support, throw error if these functions are called.
namespace at::native { namespace at::native {
at::Tensor _cslt_compress(const Tensor& sparse_input){ at::Tensor _cslt_compress(const Tensor& sparse_input) {
TORCH_CHECK(false, "cuSPARSELt not supported on your machine."); TORCH_CHECK(false, "cuSPARSELt not supported on your machine.");
} }
at::Tensor _cslt_sparse_mm( at::Tensor _cslt_sparse_mm(
@ -476,9 +536,10 @@ at::Tensor _cslt_sparse_mm(
const std::optional<Tensor>& alpha_opt, const std::optional<Tensor>& alpha_opt,
const std::optional<c10::ScalarType> out_dtype, const std::optional<c10::ScalarType> out_dtype,
bool transpose_result, bool transpose_result,
int64_t alg_id) int64_t alg_id,
{ int64_t split_k,
TORCH_CHECK(false, "cuSPARSELt not supported on your machine."); bool split_k_one_kernel) {
TORCH_CHECK(false, "cuSPARSELt not supported on your machine.");
} }
int64_t _cslt_sparse_mm_search( int64_t _cslt_sparse_mm_search(
@ -487,10 +548,8 @@ int64_t _cslt_sparse_mm_search(
const std::optional<Tensor>& bias_opt, const std::optional<Tensor>& bias_opt,
const std::optional<Tensor>& alpha_opt, const std::optional<Tensor>& alpha_opt,
const std::optional<c10::ScalarType> out_dtype, const std::optional<c10::ScalarType> out_dtype,
bool transpose_result bool transpose_result) {
) TORCH_CHECK(false, "cuSPARSELt not supported on your machine.");
{
TORCH_CHECK(false, "cuSPARSELt not supported on your machine.");
} }
} // namespace at::native } // namespace at::native


@ -0,0 +1,58 @@
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
#include <ATen/cuda/CUDASparse.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Functions.h>
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/Half.h>
#include <cusparse.h>
#include <cstdint>

#if AT_CUSPARSELT_ENABLED()
#include <cusparseLt.h>
#endif

namespace at::native {

at::Tensor _cslt_compress(const Tensor& sparse_input);

TORCH_CUDA_CPP_API std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt_sparse_mm_impl(
    const Tensor& compressed_A,
    const Tensor& dense_B,
    const std::optional<Tensor>& bias_opt,
    const std::optional<Tensor>& alpha_opt,
    const std::optional<c10::ScalarType> out_dtype_opt,
    bool transpose_result,
    int alg_id,
    int split_k,
    bool split_k_one_kernel,
    bool search_alg_id);

at::Tensor _cslt_sparse_mm(
    const Tensor& compressed_A,
    const Tensor& dense_B,
    const std::optional<Tensor>& bias_opt,
    const std::optional<Tensor>& alpha_opt,
    const std::optional<c10::ScalarType> out_dtype_opt,
    bool transpose_result,
    int64_t alg_id,
    int64_t split_k,
    bool split_k_one_kernel);

int64_t _cslt_sparse_mm_search(
    const Tensor& compressed_A,
    const Tensor& dense_B,
    const std::optional<Tensor>& bias_opt,
    const std::optional<Tensor>& alpha_opt,
    const std::optional<c10::ScalarType> out_dtype_opt,
    bool transpose_result);

} // namespace at::native


@ -1,253 +0,0 @@
import argparse
import random
import pandas as pd
from tqdm import tqdm
import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
torch.set_printoptions(
precision=2,
threshold=None,
edgeitems=16,
linewidth=480,
profile=None,
sci_mode=False,
)
# helper model definition for pruner
class Model(nn.Module):
def __init__(self, m, k, dtype=None):
super().__init__()
# transposed so reversed
self.linear = nn.Linear(k, m)
def forward(self, x):
return self.linear(x)
def rand_sparse_semi_structured_mask(
r, c, dtype=torch.float16, device="cuda", choice=None
):
"""
This function returns a 1:2 sparse matrix of size (r, c).
Note that this means this matrix will also be 2:4 and 4:8 sparse as well.
"""
choices = [[0, 1], [1, 0]]
mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]
return (
torch.tensor(mask_entries, dtype=dtype, device=device)
.reshape(r, c)
.contiguous()
)
def test_linear(m, k, n, dtype, contiguous, backend):
SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask
input_tensor = torch.zeros(n, k).to(dtype).cuda()
model = Model(m, k).to(dtype).cuda().eval()
dense_measurement = benchmark.Timer(
stmt="model(input_tensor)",
globals=locals(),
).blocked_autorange()
dense_output = model(input_tensor)
print(dense_output.shape)
# sparsify weights
model.linear.weight = nn.Parameter(
to_sparse_semi_structured(
sparse_weight,
)
)
sparse_output = model(input_tensor)
print(sparse_output.shape)
sparse_measurement = benchmark.Timer(
stmt="model(input_tensor)",
globals=locals(),
).blocked_autorange()
correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)
return {
"test_function": "linear",
"m": m,
"k": k,
"n": n,
"dtype": str(dtype),
"backend": backend,
"sparse_latency (ms)": sparse_measurement.median * 1000,
"dense_latency (ms)": dense_measurement.median * 1000,
"speedup (d/s)": dense_measurement.median / sparse_measurement.median,
"correct": correct,
"contiguous": sparse_output.is_contiguous(),
}
def test_tensor(m, k, n, dtype, contiguous, backend):
A = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
B = torch.zeros(k, n).to(dtype).cuda()
bias = torch.rand(n).to(dtype).cuda()
sA = to_sparse_semi_structured(A)
# torch.mm calculation
if dtype is not torch.int8:
dense_output = torch.mm(A, B)
dense_measurement = benchmark.Timer(
stmt="torch.mm(A, B)",
globals=locals(),
).blocked_autorange()
else:
print("int8 baseline not supported")
dense_output = torch.mm(sA, B)
dense_measurement = benchmark.Timer(
stmt="torch.mm(sA, B)",
globals=locals(),
).blocked_autorange()
sparse_output = torch.mm(sA, B)
sparse_measurement = benchmark.Timer(
stmt="torch.mm(sA, B)",
globals=locals(),
).blocked_autorange()
correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)
return {
"test_function": "tensor",
"m": m,
"k": k,
"n": n,
"dtype": str(dtype),
"backend": backend,
"sparse_latency (ms)": sparse_measurement.median * 1000,
"dense_latency (ms)": dense_measurement.median * 1000,
"speedup (d/s)": dense_measurement.median / sparse_measurement.median,
"correct": correct,
"contiguous": sparse_output.is_contiguous(),
}
if __name__ == "__main__":
dtype_lookup = {
"int8": torch.int8,
"fp16": torch.float16,
"bf16": torch.bfloat16,
"fp32": torch.float32,
}
parser = argparse.ArgumentParser(description="Semi-Structured Sparsity Benchmarks")
parser.add_argument(
"--mode",
type=str,
choices=[
"nvidia-bert",
"nvidia-fixed-k",
"nvidia-fixed-mn",
],
)
parser.add_argument(
"--dtype",
type=str,
choices=dtype_lookup.keys(),
default="fp16",
)
parser.add_argument(
"--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt"
)
parser.add_argument("-contiguous", action="store_true")
parser.add_argument("-e2e", action="store_true")
parser.add_argument("-save", action="store_true")
args = parser.parse_args()
if args.e2e:
eval_fn = test_linear
else:
eval_fn = test_tensor
print(f"Started benchmark: {args.mode} | dtype: {args.dtype}")
dtype = dtype_lookup[args.dtype]
if args.mode == "nvidia-bert":
bert_shapes = [
(3072, 1024, 16384),
(4096, 1024, 16384),
(1024, 1024, 16384),
(1024, 4096, 16384),
]
results = (
eval_fn(m, k, n, dtype, args.contiguous, args.backend)
for (m, k, n) in tqdm(bert_shapes)
)
elif args.mode == "nvidia-fixed-k":
mn_vals = [
3072,
4096,
5120,
6144,
7168,
8192,
9216,
10240,
11264,
12288,
13312,
14336,
15360,
16384,
17408,
18432,
19456,
20480,
]
results = (
eval_fn(mn, 10240, mn, dtype, args.contiguous, args.backend)
for mn in tqdm(mn_vals)
)
elif args.mode == "nvidia-fixed-mn":
k_vals = [
2560,
3840,
5120,
6400,
7680,
8960,
10240,
11520,
12800,
14080,
15360,
16640,
17920,
19200,
20480,
]
results = (
eval_fn(10240, k, 10240, dtype, args.contiguous, args.backend)
for k in tqdm(k_vals)
)
df = pd.DataFrame.from_records(results)
if args.save:
save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv"
df.to_csv(save_file)
print(f"Finished benchmark: {args.mode} saved results to {save_file}")
print(df)


@ -244,18 +244,17 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
@unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine") @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
def test_sp24_compile(self) -> None: def test_sp24_compile(self) -> None:
x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True) x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
e = torch.eye(x.shape[0], x.shape[0], device="cuda", dtype=torch.float16)
def fn(x, e): def fn(x):
y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x) y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x)
y = y.t() y = y.t()
return x @ y return x @ y
# Eager # Eager
output = fn(x, e) output = fn(x)
output.backward(output) output.backward(output)
# Torch compile # Torch compile
output = torch.compile(fn)(x, e) output = torch.compile(fn)(x)
output.backward(output) output.backward(output)
class TestSparseSemiStructured(TestCase): class TestSparseSemiStructured(TestCase):
@ -1133,6 +1132,21 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3) torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
def test_cslt_sparse_mm_alpha_compile_autotune(self, device):
A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(torch.int8).cuda()
B = torch.ones((128, 256), device=device).to(torch.int8).t()
alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda()
A_compressed = torch._cslt_compress(A)
compiled_sparse_mm = torch.compile(torch._cslt_sparse_mm, mode="max-autotune")
sparse_result = compiled_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=torch.int32)
alpha_scaled = torch.stack([alpha] * 128).t().cpu().float()
dense_result = alpha_scaled * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu())
dense_result = dense_result.to(torch.int32)
torch.testing.assert_close(sparse_result.cpu(), dense_result, rtol=1e-3, atol=1e-3)
@parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32]) @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32])
def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device): def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda() A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
@ -1149,21 +1163,6 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3) torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
@inference_dtypes
def test_cslt_sparse_mm_alg_id(self, device, dtype):
A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
A_compressed = torch._cslt_compress(A)
B = torch.ones((128, 128), device=device).to(dtype)
A_compressed = torch._cslt_compress(A)
alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)
dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
dense_result = dense_result.to(dtype)
torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
@inference_dtypes @inference_dtypes
def test_cslt_sparse_mm_search(self, device, dtype): def test_cslt_sparse_mm_search(self, device, dtype):
A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype) A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
@ -1172,7 +1171,26 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
A_compressed = torch._cslt_compress(A) A_compressed = torch._cslt_compress(A)
alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t()) alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
assert alg_id in range(torch.backends.cusparselt.get_max_alg_id()) sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)
dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
dense_result = dense_result.to(dtype)
torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
@inference_dtypes
def test_csrc_cslt_sparse_mm_search(self, device, dtype):
A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
A_compressed = torch._cslt_compress(A)
B = torch.ones((128, 128), device=device).to(dtype)
A_compressed = torch._cslt_compress(A)
alg_id, split_k, split_k_one_kernel, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False)
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(),
alg_id=alg_id,
split_k=split_k,
split_k_one_kernel=split_k_one_kernel)
dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
dense_result = dense_result.to(dtype)
torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
def test_cusparselt_backend(self): def test_cusparselt_backend(self):
version = _get_torch_cuda_version() version = _get_torch_cuda_version()
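
For readability, this is an approximate, hedged reconstruction of what the newly
added test_cslt_sparse_mm_alpha_compile_autotune test above exercises (compiling
the op with max-autotune and a per-row alpha vector); it follows the diff, but
small details may differ:

    import torch

    A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(torch.int8).cuda()  # 128 x 256, 2:4 sparse
    B = torch.ones((128, 256), device="cuda").to(torch.int8).t()          # 256 x 128, column-major
    alpha = torch.Tensor([2 ** (-i) for i in range(128)]).cuda()          # one scale per output row

    A_compressed = torch._cslt_compress(A)
    compiled_sparse_mm = torch.compile(torch._cslt_sparse_mm, mode="max-autotune")
    sparse_result = compiled_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=torch.int32)

    # Reference: dense matmul in int64 on CPU, scaled by alpha and cast to int32.
    alpha_scaled = torch.stack([alpha] * 128).t().cpu().float()
    dense_result = alpha_scaled * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu())
    dense_result = dense_result.to(torch.int32)
    torch.testing.assert_close(sparse_result.cpu(), dense_result, rtol=1e-3, atol=1e-3)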


@ -520,18 +520,22 @@ def meta__cslt_sparse_mm(
     alpha: Optional[Tensor] = None,
     out_dtype: Optional[torch.dtype] = None,
     transpose_result: bool = False,
+    alg_id: int = 0,
+    split_k: int = 1,
+    split_k_one_kernel: bool = False,
 ):
     assert dense_B.dtype in {
         torch.float32,
         torch.float16,
         torch.bfloat16,
         torch.int8,
-    }, "_cslt_sparse_mm only supports fp16, bf16, and int8"
+        torch.float8_e4m3fn,
+    }, "_cslt_sparse_mm only supports fp16, bf16, int8, and fp8e4m3"
     assert compressed_A.dtype == dense_B.dtype, "inputs must have the same dtype"
     assert len(dense_B.shape) == 2, "_cslt_sparse_mm only supports 2d inputs"
 
-    is_int8_input_type = compressed_A.dtype == torch.int8
-    compression_factor = 10 if is_int8_input_type else 9
+    is_8bit_input_type = compressed_A.dtype in [torch.int8, torch.float8_e4m3fn]
+    compression_factor = 10 if is_8bit_input_type else 9
     k = dense_B.size(0)
     n = dense_B.size(1)
     m = (compressed_A.numel() * 16) // (compression_factor * k)
@ -539,11 +543,16 @@
         assert m == bias.size(0)
 
     if out_dtype is not None:
-        assert is_int8_input_type and out_dtype in {
-            torch.float16,
-            torch.bfloat16,
-            torch.int32,
-        }, "out_dtype is only supported for i8i8->fp16, bf16, or i32 matmul"
+        assert (
+            is_8bit_input_type
+            and out_dtype
+            in {
+                torch.float16,
+                torch.bfloat16,
+                torch.int32,
+                torch.float8_e4m3fn,
+            }
+        ), "out_dtype is not supported for {compressed_A.dtype} x {dense_B.dtype} -> {out_dtype} matmul!"
     output_shape = (n, m) if transpose_result else (m, n)
     result = dense_B.new_empty(output_shape, dtype=out_dtype)
     return result
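
With the float8-aware meta registration above, shape and dtype inference for
_cslt_sparse_mm can run without touching a real GPU, which is what fake tensors
and torch.compile rely on. A small sketch using meta-device tensors (sizes are
illustrative, and it assumes the Python meta registration is picked up for
meta-device inputs):

    import torch

    # A 128x128 int8 sparse operand compresses to 128 * 128 * 10 // 16 = 10240 elements.
    A_compressed = torch.empty(10240, dtype=torch.int8, device="meta")
    B = torch.empty(128, 128, dtype=torch.int8, device="meta")

    out = torch._cslt_sparse_mm(
        A_compressed, B, out_dtype=torch.int32, split_k=2, split_k_one_kernel=False
    )
    print(out.shape, out.dtype)  # torch.Size([128, 128]) torch.int32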


@ -1,7 +1,7 @@
 #include <torch/csrc/utils/pybind.h>
 
 #ifdef USE_CUSPARSELT
-#include <cusparseLt.h>
+#include <ATen/native/sparse/cuda/cuSPARSELtOps.h>
 
 namespace {
 
@ -9,6 +9,34 @@ size_t getVersionInt() {
   return CUSPARSELT_VERSION;
 }
 
+std::tuple<int64_t, int64_t, bool, int64_t> mmSearch(
+    const at::Tensor& compressed_A,
+    const at::Tensor& dense_B,
+    const std::optional<at::Tensor>& bias_opt,
+    const std::optional<at::Tensor>& alpha_opt,
+    const std::optional<c10::ScalarType> out_dtype_opt,
+    bool transpose_result) {
+  int alg_id_int = 0;
+  int split_k = 1;
+  bool split_k_one_kernel = true;
+  auto result = at::native::_cslt_sparse_mm_impl(
+      compressed_A,
+      dense_B,
+      bias_opt,
+      alpha_opt,
+      out_dtype_opt,
+      transpose_result,
+      alg_id_int,
+      split_k,
+      split_k_one_kernel,
+      true);
+  return {
+      (int64_t)std::get<1>(result),
+      (int64_t)std::get<2>(result),
+      (bool)std::get<3>(result),
+      (int64_t)std::get<4>(result)};
+}
+
 } // namespace
 
 namespace torch::cuda::shared {
@ -17,6 +45,7 @@ void initCusparseltBindings(PyObject* module) {
   auto m = py::handle(module).cast<py::module>();
   auto cusparselt = m.def_submodule("_cusparselt", "libcusparselt.so bindings");
   cusparselt.def("getVersionInt", getVersionInt);
+  cusparselt.def("mm_search", mmSearch);
 }
 
 } // namespace torch::cuda::shared


@ -103,6 +103,8 @@ def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor:
         packed_t=self.packed_t,
         meta_t=self.meta_t,
         compressed_swizzled_bitmask=self.compressed_swizzled_bitmask,
+        fuse_transpose_cusparselt=self.fuse_transpose_cusparselt,
+        alg_id_cusparselt=self.alg_id_cusparselt,
         requires_grad=False,
     )