diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 5e48bfd6b9a8..9749dad663f9 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3371,7 +3371,7 @@ dispatch: CUDA: _cslt_compress -- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0, int split_k=1, bool split_k_one_kernel=True) -> Tensor +- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0) -> Tensor dispatch: CUDA: _cslt_sparse_mm diff --git a/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp b/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp index e40b68ee0bfc..ca3996f00e7a 100644 --- a/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp +++ b/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp @@ -1,97 +1,109 @@ -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #if AT_CUSPARSELT_ENABLED() +#include + namespace at::native { -// Ideally we would use the same DeviceThreadHandlePool mechanism as used in -// aten/src/ATen/cuda/CuSparseHandlePool.cpp which would handle this for us. -// However, the cuSPARSELt handle signature is different from that of -// cuSPARSE/cuBLAS, so it's not possible to reuse the existing pooling -// mechanism. Instead we have to handle our handles ourselves, which is why -// these variables are thread local. Once cuSPARSELt updates their handle -// signature to be consistent with the rest of CUDA, we can switch to using -// DeviceThreadHandlePool. +// Ideally we would use the same DeviceThreadHandlePool mechanism as used in aten/src/ATen/cuda/CuSparseHandlePool.cpp +// which would handle this for us. However, the cuSPARSELt handle signature is different from that of cuSPARSE/cuBLAS, +// so it's not possible to reuse the existing pooling mechanism. Instead we have to handle our handles ourselves, which +// is why these variables are thread local. Once cuSPARSELt updates their handle signature to be consistent with the rest +// of CUDA, we can switch to using DeviceThreadHandlePool. 
thread_local cusparseLtHandle_t handle; thread_local bool handle_initialized = false; -at::Tensor _cslt_compress(const Tensor& sparse_input) { - if (!handle_initialized) { - TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle)); - handle_initialized = true; - } - // create sparse descriptor, dtype - cusparseLtMatDescriptor_t sparse_input_descriptor; - cudaDataType type; - auto compression_factor = 9; +at::Tensor _cslt_compress(const Tensor& sparse_input) +{ + if (!handle_initialized){ + TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle)); + handle_initialized = true; + } + // create sparse descriptor, dtype + cusparseLtMatDescriptor_t sparse_input_descriptor; + cudaDataType type; + auto compression_factor = 9; - switch (sparse_input.scalar_type()) { - case at::ScalarType::Char: - type = CUDA_R_8I; - compression_factor = 10; - break; - case at::ScalarType::Half: - type = CUDA_R_16F; - break; - case at::ScalarType::BFloat16: - type = CUDA_R_16BF; - break; - case at::ScalarType::Float: - type = CUDA_R_32F; - break; + switch( + sparse_input.scalar_type() + ) + { + case at::ScalarType::Char: + type = CUDA_R_8I; + compression_factor = 10; + break; + case at::ScalarType::Half: + type = CUDA_R_16F; + break; + case at::ScalarType::BFloat16: + type = CUDA_R_16BF; + break; + case at::ScalarType::Float: + type = CUDA_R_32F; + break; #if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 - case at::ScalarType::Float8_e4m3fn: - type = CUDA_R_8F_E4M3; - compression_factor = 10; - break; + case at::ScalarType::Float8_e4m3fn: + type = CUDA_R_8F_E4M3; + break; #endif - default: - TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt compressed matrix"); - break; - } + default: + TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt compressed matrix"); + break; + } - // create a new compressed tensor with the same dtype as - auto compressed_tensor = - sparse_input.new_empty(sparse_input.numel() * compression_factor / 16); + // create a new compressed tensor with the same dtype as + auto compressed_tensor = sparse_input.new_empty(sparse_input.numel() * compression_factor / 16); - TORCH_CUDASPARSE_CHECK(cusparseLtStructuredDescriptorInit( - &handle, - &sparse_input_descriptor, - sparse_input.size(0), - sparse_input.size(1), - sparse_input.size(1), - 16, - type, - CUSPARSE_ORDER_ROW, - CUSPARSELT_SPARSITY_50_PERCENT)); + TORCH_CUDASPARSE_CHECK(cusparseLtStructuredDescriptorInit( + &handle, + &sparse_input_descriptor, + sparse_input.size(0), + sparse_input.size(1), + sparse_input.size(1), + 16, + type, + CUSPARSE_ORDER_ROW, + CUSPARSELT_SPARSITY_50_PERCENT)); - // compress input - //-------------------------------------------------------------------------- - size_t compressed_size, compressed_buffer_size; - TORCH_CUDASPARSE_CHECK(cusparseLtSpMMACompressedSize2( - &handle, - &sparse_input_descriptor, - &compressed_size, - &compressed_buffer_size)); + // compress input + //-------------------------------------------------------------------------- + size_t compressed_size, compressed_buffer_size; + TORCH_CUDASPARSE_CHECK(cusparseLtSpMMACompressedSize2( + &handle, + &sparse_input_descriptor, + &compressed_size, + &compressed_buffer_size)); - auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); - auto compressedBufferPtr = allocator.allocate(compressed_buffer_size); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); + auto compressedBufferPtr = allocator.allocate(compressed_buffer_size); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 
- TORCH_CUDASPARSE_CHECK(cusparseLtSpMMACompress2( - &handle, - &sparse_input_descriptor, - true, - CUSPARSE_OPERATION_NON_TRANSPOSE, - sparse_input.data_ptr(), - compressed_tensor.data_ptr(), - compressedBufferPtr.get(), - stream)); + TORCH_CUDASPARSE_CHECK(cusparseLtSpMMACompress2( + &handle, + &sparse_input_descriptor, + true, + CUSPARSE_OPERATION_NON_TRANSPOSE, + sparse_input.data_ptr(), + compressed_tensor.data_ptr(), + compressedBufferPtr.get(), + stream)); - return compressed_tensor; + return compressed_tensor; } -std::tuple _cslt_sparse_mm_impl( +std::tuple _cslt_sparse_mm_impl( const Tensor& compressed_A, const Tensor& dense_B, const std::optional& bias_opt, @@ -99,12 +111,12 @@ std::tuple _cslt_sparse_mm_impl( const std::optional out_dtype_opt, bool transpose_result, int alg_id, - int split_k, - bool split_k_one_kernel, - bool search_alg_id) { - if (!handle_initialized) { - TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle)); - handle_initialized = true; + bool search_alg_id +) +{ + if (!handle_initialized){ + TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle)); + handle_initialized = true; } // cupsarselt constructs cusparseLtMatmulDescriptor_t matmul; @@ -120,138 +132,134 @@ std::tuple _cslt_sparse_mm_impl( cusparseComputeType compute_type; auto compression_factor = 9; - switch (compressed_A.scalar_type()) { + switch(compressed_A.scalar_type()) + { case at::ScalarType::Char: - input_type = CUDA_R_8I; - output_type = CUDA_R_8I; - C_type = CUDA_R_8I; - compute_type = CUSPARSE_COMPUTE_32I; - compression_factor = 10; - break; + input_type = CUDA_R_8I; + output_type = CUDA_R_8I; + C_type = CUDA_R_8I; + compute_type = CUSPARSE_COMPUTE_32I; + compression_factor = 10; + break; -// cuSPARSELt v0.5.2 onwards changes CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUT_16F -// to CUSPARSE_COMPUTE_32F +// cuSPARSELt v0.5.2 onwards changes CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUT_16F to CUSPARSE_COMPUTE_32F #if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 502 case at::ScalarType::Half: - input_type = CUDA_R_16F; - output_type = CUDA_R_16F; - C_type = CUDA_R_16F; - compute_type = CUSPARSE_COMPUTE_32F; - break; + input_type = CUDA_R_16F; + output_type = CUDA_R_16F; + C_type = CUDA_R_16F; + compute_type = CUSPARSE_COMPUTE_32F; + break; case at::ScalarType::BFloat16: - input_type = CUDA_R_16BF; - output_type = CUDA_R_16BF; - C_type = CUDA_R_16BF; - compute_type = CUSPARSE_COMPUTE_32F; - break; + input_type = CUDA_R_16BF; + output_type = CUDA_R_16BF; + C_type = CUDA_R_16BF; + compute_type = CUSPARSE_COMPUTE_32F; + break; case at::ScalarType::Float: - input_type = CUDA_R_32F; - output_type = CUDA_R_32F; - C_type = CUDA_R_32F; - compute_type = CUSPARSE_COMPUTE_32F; - break; + input_type = CUDA_R_32F; + output_type = CUDA_R_32F; + C_type = CUDA_R_32F; + compute_type = CUSPARSE_COMPUTE_32F; + break; // if cuSPARSELt >= 6.2.3, we can add Float8 support #if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 case at::ScalarType::Float8_e4m3fn: - input_type = CUDA_R_8F_E4M3; - output_type = CUDA_R_8F_E4M3; - C_type = CUDA_R_16F; - compute_type = CUSPARSE_COMPUTE_32F; - compression_factor = 10; - break; + input_type = CUDA_R_8F_E4M3; + output_type = CUDA_R_8F_E4M3; + C_type = CUDA_R_16F; + compute_type = CUSPARSE_COMPUTE_32F; + break; #endif // cuSPARSELt <= v0.5.2 uses CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUTE_16F #else case at::ScalarType::Half: - input_type = CUDA_R_16F; - output_type = CUDA_R_16F; - C_type = CUDA_R_16F; - compute_type = CUSPARSE_COMPUTE_16F; - break; + input_type = CUDA_R_16F; + output_type = 
CUDA_R_16F; + C_type = CUDA_R_16F; + compute_type = CUSPARSE_COMPUTE_16F; + break; case at::ScalarType::BFloat16: - input_type = CUDA_R_16BF; - output_type = CUDA_R_16BF; - C_type = CUDA_R_16BF; - compute_type = CUSPARSE_COMPUTE_16F; - break; + input_type = CUDA_R_16BF; + output_type = CUDA_R_16BF; + C_type = CUDA_R_16BF; + compute_type = CUSPARSE_COMPUTE_16F; + break; case at::ScalarType::Float: - input_type = CUDA_R_32F; - output_type = CUDA_R_32F; - C_type = CUDA_R_32F; - compute_type = CUSPARSE_COMPUTE_TF32; - break; + input_type = CUDA_R_32F; + output_type = CUDA_R_32F; + C_type = CUDA_R_32F; + compute_type = CUSPARSE_COMPUTE_TF32; + break; #endif default: - TORCH_CHECK( - false, - "Unsupported dtype for cuSPARSELt compressed matrix multiplication."); - break; + TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt compressed matrix multiplication."); + break; } ScalarType out_dtype = dense_B.scalar_type(); // special check for mixed dtype support for 8 bit dtypes // cslt 0.5.2+: int8 int8 -> {fp16, bf16, int32} support if (out_dtype_opt.has_value()) { out_dtype = out_dtype_opt.value(); - if (input_type == CUDA_R_8I) { - switch (out_dtype) { - case at::ScalarType::Half: - C_type = CUDA_R_16F; - output_type = CUDA_R_16F; - break; - case at::ScalarType::BFloat16: - C_type = CUDA_R_16BF; - output_type = CUDA_R_16BF; - break; - case at::ScalarType::Int: - C_type = CUDA_R_32I; - output_type = CUDA_R_32I; - break; - default: - TORCH_CHECK( - false, - "Unsupported out_dtype passed, must be one of {fp16, bf16, int32} for int8 inputs"); - break; - } + if (input_type == CUDA_R_8I) + { + switch (out_dtype) + { + case at::ScalarType::Half: + C_type = CUDA_R_16F; + output_type = CUDA_R_16F; + break; + case at::ScalarType::BFloat16: + C_type = CUDA_R_16BF; + output_type = CUDA_R_16BF; + break; + case at::ScalarType::Int: + C_type = CUDA_R_32I; + output_type = CUDA_R_32I; + break; + default: + TORCH_CHECK(false, "Unsupported out_dtype passed, must be one of {fp16, bf16, int32} for int8 inputs"); + break; + } } // cslt 0.6.2+: fp8 fp8 -> {fp8, fp16, bf16, fp32} support #if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 - else if (input_type == CUDA_R_8F_E4M3) { - switch (out_dtype) { - case at::ScalarType::Float8_e4m3fn: - output_type = CUDA_R_8F_E4M3; - C_type = CUDA_R_16F; - break; - case at::ScalarType::Half: - output_type = CUDA_R_16F; - C_type = CUDA_R_16F; - break; - case at::ScalarType::BFloat16: - output_type = CUDA_R_16BF; - C_type = CUDA_R_16BF; - break; - case at::ScalarType::Float: - output_type = CUDA_R_32F; - C_type = CUDA_R_32F; - break; - default: - TORCH_CHECK( - false, - "Unsupported out_dtype passed, must be one of {fp16, bf16, float32} for fp8 inputs"); - break; - } + else if (input_type == CUDA_R_8F_E4M3) + { + switch (out_dtype) + { + case at::ScalarType::Float8_e4m3fn: + output_type = CUDA_R_8F_E4M3; + C_type = CUDA_R_16F; + break; + case at::ScalarType::Half: + output_type = CUDA_R_16F; + C_type = CUDA_R_16F; + break; + case at::ScalarType::BFloat16: + output_type = CUDA_R_16BF; + C_type = CUDA_R_16BF; + break; + case at::ScalarType::Float: + output_type = CUDA_R_32F; + C_type = CUDA_R_32F; + break; + default: + TORCH_CHECK(false, "Unsupported out_dtype passed, must be one of {fp16, bf16, float32} for fp8 inputs"); + break; + } } #endif else { - TORCH_CHECK( - false, "out_dtype support only available for int8/fp8 inputs"); + TORCH_CHECK(false, "out_dtype support only available for int8/fp8 inputs"); } } int64_t k = dense_B.size(0); int64_t n = dense_B.size(1); - 
int64_t m = (compressed_A.numel() * 16 / compression_factor) / k; + int64_t m = (compressed_A.numel() * 16 / compression_factor ) / k; - // initialize sparse descriptor + //initialize sparse descriptor cusparseLtMatDescriptor_t sparse_input_descriptor; TORCH_CUDASPARSE_CHECK(cusparseLtStructuredDescriptorInit( &handle, @@ -277,8 +285,7 @@ std::tuple _cslt_sparse_mm_impl( CUSPARSE_ORDER_ROW)); // create result tensor - auto res_tensor_options = - c10::TensorOptions().dtype(out_dtype).device(dense_B.device()); + auto res_tensor_options = c10::TensorOptions().dtype(out_dtype).device(dense_B.device()); at::Tensor res = (transpose_result) ? at::empty({n, m}, res_tensor_options) : at::empty({m, n}, res_tensor_options); @@ -288,7 +295,7 @@ std::tuple _cslt_sparse_mm_impl( &res_descriptor, m, n, - (transpose_result) ? m : n, + (transpose_result) ? m: n, 16, output_type, (transpose_result) ? CUSPARSE_ORDER_COL : CUSPARSE_ORDER_ROW)); @@ -300,7 +307,7 @@ std::tuple _cslt_sparse_mm_impl( &C_descriptor, m, n, - (transpose_result) ? m : n, + (transpose_result) ? m: n, 16, C_type, (transpose_result) ? CUSPARSE_ORDER_COL : CUSPARSE_ORDER_ROW)); @@ -310,8 +317,7 @@ std::tuple _cslt_sparse_mm_impl( &handle, &matmul, CUSPARSE_OPERATION_NON_TRANSPOSE, - (dense_B.is_contiguous()) ? CUSPARSE_OPERATION_NON_TRANSPOSE - : CUSPARSE_OPERATION_TRANSPOSE, + (dense_B.is_contiguous()) ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE, &sparse_input_descriptor, &dense_input_descriptor, &C_descriptor, @@ -323,59 +329,28 @@ std::tuple _cslt_sparse_mm_impl( auto& bias = bias_opt.value(); void* dBias = bias.data_ptr(); TORCH_CUDASPARSE_CHECK(cusparseLtMatmulDescSetAttribute( - &handle, - &matmul, - CUSPARSELT_MATMUL_BIAS_POINTER, - &dBias, - sizeof(dBias))); + &handle, &matmul, CUSPARSELT_MATMUL_BIAS_POINTER, &dBias, sizeof(dBias))); } TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSelectionInit( &handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT)); - // set matmul search params + // set alg_id TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute( - &handle, - &alg_sel, - CUSPARSELT_MATMUL_ALG_CONFIG_ID, - &alg_id, - sizeof(alg_id))); - - cusparseLtSplitKMode_t splitKMode; - int max_alg_id; - if (split_k != 1) { - TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute( - &handle, - &alg_sel, - CUSPARSELT_MATMUL_SPLIT_K, - &split_k, - sizeof(split_k))); - - splitKMode = split_k_one_kernel ? CUSPARSELT_SPLIT_K_MODE_ONE_KERNEL - : CUSPARSELT_SPLIT_K_MODE_TWO_KERNELS; - TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute( - &handle, - &alg_sel, - CUSPARSELT_MATMUL_SPLIT_K_MODE, - &splitKMode, - sizeof(splitKMode))); - } + &handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg_id, sizeof(alg_id))); // set tensor_alpha_mode and alpha pointer for matmul - const auto alpha_tensor = alpha_opt.has_value() ? *alpha_opt : Tensor{}; + const auto alpha_tensor = alpha_opt.has_value() ? 
*alpha_opt: Tensor{}; auto alpha_ptr = α if (alpha_opt.has_value()) { if (alpha_tensor.numel() == 1) { - alpha = alpha_tensor.item(); - } else { - tensor_alpha_mode = 1; - TORCH_CUDASPARSE_CHECK(cusparseLtMatmulDescSetAttribute( - &handle, - &matmul, - CUSPARSELT_MATMUL_ALPHA_VECTOR_SCALING, - &tensor_alpha_mode, - sizeof(tensor_alpha_mode))); - alpha_ptr = static_cast(alpha_tensor.data_ptr()); + alpha = alpha_tensor.item(); + } + else { + tensor_alpha_mode = 1; + TORCH_CUDASPARSE_CHECK(cusparseLtMatmulDescSetAttribute( + &handle, &matmul, CUSPARSELT_MATMUL_ALPHA_VECTOR_SCALING, &tensor_alpha_mode, sizeof(tensor_alpha_mode))); + alpha_ptr = static_cast(alpha_tensor.data_ptr()); } } @@ -390,7 +365,7 @@ std::tuple _cslt_sparse_mm_impl( auto workspacePtr = allocator.allocate(workspace_size); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (search_alg_id) { + if(search_alg_id){ // run matmul search TORCH_CUDASPARSE_CHECK(cusparseLtMatmulSearch( &handle, @@ -406,36 +381,11 @@ std::tuple _cslt_sparse_mm_impl( &stream, 1)); - // get matmul params used + // get alg_id used TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgGetAttribute( - &handle, - &alg_sel, - CUSPARSELT_MATMUL_ALG_CONFIG_ID, - &alg_id, - sizeof(alg_id))); - - TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgGetAttribute( - &handle, - &alg_sel, - CUSPARSELT_MATMUL_SPLIT_K, - &split_k, - sizeof(split_k))); - - TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgGetAttribute( - &handle, - &alg_sel, - CUSPARSELT_MATMUL_SPLIT_K_MODE, - &splitKMode, - sizeof(splitKMode))); - - TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgGetAttribute( - &handle, - &alg_sel, - CUSPARSELT_MATMUL_ALG_CONFIG_MAX_ID, - &max_alg_id, - sizeof(max_alg_id))); - - } else { + &handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg_id, sizeof(alg_id))); + } + else { // do normal matmul TORCH_CUDASPARSE_CHECK(cusparseLtMatmul( &handle, @@ -452,7 +402,7 @@ std::tuple _cslt_sparse_mm_impl( 1)); } - // destroy descriptors + //destroy descriptors TORCH_CUDASPARSE_CHECK( cusparseLtMatDescriptorDestroy(&sparse_input_descriptor)); TORCH_CUDASPARSE_CHECK( @@ -461,12 +411,7 @@ std::tuple _cslt_sparse_mm_impl( // destroy plan TORCH_CUDASPARSE_CHECK(cusparseLtMatmulPlanDestroy(&plan)); - return { - res, - alg_id, - split_k, - splitKMode == CUSPARSELT_SPLIT_K_MODE_ONE_KERNEL, - max_alg_id}; + return {alg_id, res}; } at::Tensor _cslt_sparse_mm( @@ -476,21 +421,19 @@ at::Tensor _cslt_sparse_mm( const std::optional& alpha_opt, const std::optional out_dtype_opt, bool transpose_result, - int64_t alg_id, - int64_t split_k, - bool split_k_one_kernel) { - auto result = _cslt_sparse_mm_impl( - compressed_A, - dense_B, - bias_opt, - alpha_opt, - out_dtype_opt, - transpose_result, - (int)alg_id, - (int)split_k, - split_k_one_kernel, - false); - return std::get<0>(result); + int64_t alg_id +) +{ + auto result = _cslt_sparse_mm_impl( + compressed_A, + dense_B, + bias_opt, + alpha_opt, + out_dtype_opt, + transpose_result, + (int) alg_id, + false); + return std::get<1>(result); } int64_t _cslt_sparse_mm_search( @@ -499,34 +442,31 @@ int64_t _cslt_sparse_mm_search( const std::optional& bias_opt, const std::optional& alpha_opt, const std::optional out_dtype_opt, - bool transpose_result) { - TORCH_WARN_ONCE( - "torch._cslt_sparse_mm_search is deprecated and will be removed in a future PyTorch release. 
Please use torch._C._cusparselt.mm_search instead."); - int alg_id_int = 0; - int split_k = 1; - bool split_k_one_kernel = true; - auto result = _cslt_sparse_mm_impl( - compressed_A, - dense_B, - bias_opt, - alpha_opt, - out_dtype_opt, - transpose_result, - alg_id_int, - split_k, - split_k_one_kernel, - true); - return (int64_t)std::get<1>(result); + bool transpose_result +) +{ + int alg_id_int = 0; + auto result = _cslt_sparse_mm_impl( + compressed_A, + dense_B, + bias_opt, + alpha_opt, + out_dtype_opt, + transpose_result, + alg_id_int, + true); + return (int64_t) std::get<0>(result); } + } // namespace at::native #else // No cuSPARSELt support, throw error if these functions are called. namespace at::native { -at::Tensor _cslt_compress(const Tensor& sparse_input) { - TORCH_CHECK(false, "cuSPARSELt not supported on your machine."); +at::Tensor _cslt_compress(const Tensor& sparse_input){ + TORCH_CHECK(false, "cuSPARSELt not supported on your machine."); } at::Tensor _cslt_sparse_mm( @@ -536,10 +476,9 @@ at::Tensor _cslt_sparse_mm( const std::optional& alpha_opt, const std::optional out_dtype, bool transpose_result, - int64_t alg_id, - int64_t split_k, - bool split_k_one_kernel) { - TORCH_CHECK(false, "cuSPARSELt not supported on your machine."); + int64_t alg_id) +{ + TORCH_CHECK(false, "cuSPARSELt not supported on your machine."); } int64_t _cslt_sparse_mm_search( @@ -548,8 +487,10 @@ int64_t _cslt_sparse_mm_search( const std::optional& bias_opt, const std::optional& alpha_opt, const std::optional out_dtype, - bool transpose_result) { - TORCH_CHECK(false, "cuSPARSELt not supported on your machine."); + bool transpose_result +) +{ + TORCH_CHECK(false, "cuSPARSELt not supported on your machine."); } } // namespace at::native diff --git a/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.h b/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.h deleted file mode 100644 index 00e7a8e1477d..000000000000 --- a/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.h +++ /dev/null @@ -1,58 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if AT_CUSPARSELT_ENABLED() -#include -#endif - -namespace at::native { - -at::Tensor _cslt_compress(const Tensor& sparse_input); - -TORCH_CUDA_CPP_API std::tuple _cslt_sparse_mm_impl( - const Tensor& compressed_A, - const Tensor& dense_B, - const std::optional& bias_opt, - const std::optional& alpha_opt, - const std::optional out_dtype_opt, - bool transpose_result, - int alg_id, - int split_k, - bool split_k_one_kernel, - bool search_alg_id -); - -at::Tensor _cslt_sparse_mm( - const Tensor& compressed_A, - const Tensor& dense_B, - const std::optional& bias_opt, - const std::optional& alpha_opt, - const std::optional out_dtype_opt, - bool transpose_result, - int64_t alg_id, - int64_t split_k, - bool split_k_one_kernel -); - -int64_t _cslt_sparse_mm_search( - const Tensor& compressed_A, - const Tensor& dense_B, - const std::optional& bias_opt, - const std::optional& alpha_opt, - const std::optional out_dtype_opt, - bool transpose_result -); - -} // namespace at::native diff --git a/benchmarks/sparse/benchmark_semi_structured_sparsity.py b/benchmarks/sparse/benchmark_semi_structured_sparsity.py new file mode 100644 index 000000000000..66311c40428f --- /dev/null +++ b/benchmarks/sparse/benchmark_semi_structured_sparsity.py @@ -0,0 +1,253 @@ +import argparse +import random + +import pandas as pd +from tqdm import tqdm + +import torch +import torch.utils.benchmark as benchmark +from 
torch import nn +from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured + + +torch.set_printoptions( + precision=2, + threshold=None, + edgeitems=16, + linewidth=480, + profile=None, + sci_mode=False, +) + + +# helper model definition for pruner +class Model(nn.Module): + def __init__(self, m, k, dtype=None): + super().__init__() + # transposed so reversed + self.linear = nn.Linear(k, m) + + def forward(self, x): + return self.linear(x) + + +def rand_sparse_semi_structured_mask( + r, c, dtype=torch.float16, device="cuda", choice=None +): + """ + This function returns a 1:2 sparse matrix of size (r, c). + Note that this means this matrix will also be 2:4 and 4:8 sparse as well. + """ + + choices = [[0, 1], [1, 0]] + mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)] + + return ( + torch.tensor(mask_entries, dtype=dtype, device=device) + .reshape(r, c) + .contiguous() + ) + + +def test_linear(m, k, n, dtype, contiguous, backend): + SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass" + mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype) + sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask + input_tensor = torch.zeros(n, k).to(dtype).cuda() + model = Model(m, k).to(dtype).cuda().eval() + + dense_measurement = benchmark.Timer( + stmt="model(input_tensor)", + globals=locals(), + ).blocked_autorange() + + dense_output = model(input_tensor) + print(dense_output.shape) + + # sparsify weights + model.linear.weight = nn.Parameter( + to_sparse_semi_structured( + sparse_weight, + ) + ) + + sparse_output = model(input_tensor) + print(sparse_output.shape) + + sparse_measurement = benchmark.Timer( + stmt="model(input_tensor)", + globals=locals(), + ).blocked_autorange() + + correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3) + + return { + "test_function": "linear", + "m": m, + "k": k, + "n": n, + "dtype": str(dtype), + "backend": backend, + "sparse_latency (ms)": sparse_measurement.median * 1000, + "dense_latency (ms)": dense_measurement.median * 1000, + "speedup (d/s)": dense_measurement.median / sparse_measurement.median, + "correct": correct, + "contiguous": sparse_output.is_contiguous(), + } + + +def test_tensor(m, k, n, dtype, contiguous, backend): + A = rand_sparse_semi_structured_mask(m, k, dtype=dtype) + B = torch.zeros(k, n).to(dtype).cuda() + bias = torch.rand(n).to(dtype).cuda() + + sA = to_sparse_semi_structured(A) + + # torch.mm calculation + if dtype is not torch.int8: + dense_output = torch.mm(A, B) + + dense_measurement = benchmark.Timer( + stmt="torch.mm(A, B)", + globals=locals(), + ).blocked_autorange() + + else: + print("int8 baseline not supported") + dense_output = torch.mm(sA, B) + + dense_measurement = benchmark.Timer( + stmt="torch.mm(sA, B)", + globals=locals(), + ).blocked_autorange() + + sparse_output = torch.mm(sA, B) + sparse_measurement = benchmark.Timer( + stmt="torch.mm(sA, B)", + globals=locals(), + ).blocked_autorange() + + correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3) + + return { + "test_function": "tensor", + "m": m, + "k": k, + "n": n, + "dtype": str(dtype), + "backend": backend, + "sparse_latency (ms)": sparse_measurement.median * 1000, + "dense_latency (ms)": dense_measurement.median * 1000, + "speedup (d/s)": dense_measurement.median / sparse_measurement.median, + "correct": correct, + "contiguous": sparse_output.is_contiguous(), + } + + +if __name__ == "__main__": + dtype_lookup = { + "int8": torch.int8, + "fp16": torch.float16, + 
"bf16": torch.bfloat16, + "fp32": torch.float32, + } + + parser = argparse.ArgumentParser(description="Semi-Structured Sparsity Benchmarks") + parser.add_argument( + "--mode", + type=str, + choices=[ + "nvidia-bert", + "nvidia-fixed-k", + "nvidia-fixed-mn", + ], + ) + parser.add_argument( + "--dtype", + type=str, + choices=dtype_lookup.keys(), + default="fp16", + ) + parser.add_argument( + "--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt" + ) + parser.add_argument("-contiguous", action="store_true") + parser.add_argument("-e2e", action="store_true") + parser.add_argument("-save", action="store_true") + args = parser.parse_args() + + if args.e2e: + eval_fn = test_linear + else: + eval_fn = test_tensor + + print(f"Started benchmark: {args.mode} | dtype: {args.dtype}") + dtype = dtype_lookup[args.dtype] + + if args.mode == "nvidia-bert": + bert_shapes = [ + (3072, 1024, 16384), + (4096, 1024, 16384), + (1024, 1024, 16384), + (1024, 4096, 16384), + ] + results = ( + eval_fn(m, k, n, dtype, args.contiguous, args.backend) + for (m, k, n) in tqdm(bert_shapes) + ) + + elif args.mode == "nvidia-fixed-k": + mn_vals = [ + 3072, + 4096, + 5120, + 6144, + 7168, + 8192, + 9216, + 10240, + 11264, + 12288, + 13312, + 14336, + 15360, + 16384, + 17408, + 18432, + 19456, + 20480, + ] + results = ( + eval_fn(mn, 10240, mn, dtype, args.contiguous, args.backend) + for mn in tqdm(mn_vals) + ) + + elif args.mode == "nvidia-fixed-mn": + k_vals = [ + 2560, + 3840, + 5120, + 6400, + 7680, + 8960, + 10240, + 11520, + 12800, + 14080, + 15360, + 16640, + 17920, + 19200, + 20480, + ] + results = ( + eval_fn(10240, k, 10240, dtype, args.contiguous, args.backend) + for k in tqdm(k_vals) + ) + + df = pd.DataFrame.from_records(results) + if args.save: + save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv" + df.to_csv(save_file) + print(f"Finished benchmark: {args.mode} saved results to {save_file}") + print(df) diff --git a/test/test_sparse_semi_structured.py b/test/test_sparse_semi_structured.py index e90015a5db4c..2292dca8c971 100644 --- a/test/test_sparse_semi_structured.py +++ b/test/test_sparse_semi_structured.py @@ -244,17 +244,18 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase): @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine") def test_sp24_compile(self) -> None: x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True) + e = torch.eye(x.shape[0], x.shape[0], device="cuda", dtype=torch.float16) - def fn(x): + def fn(x, e): y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x) y = y.t() return x @ y # Eager - output = fn(x) + output = fn(x, e) output.backward(output) # Torch compile - output = torch.compile(fn)(x) + output = torch.compile(fn)(x, e) output.backward(output) class TestSparseSemiStructured(TestCase): @@ -1132,21 +1133,6 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase): torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3) - def test_cslt_sparse_mm_alpha_compile_autotune(self, device): - A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(torch.int8).cuda() - B = torch.ones((128, 256), device=device).to(torch.int8).t() - alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda() - - A_compressed = torch._cslt_compress(A) - compiled_sparse_mm = torch.compile(torch._cslt_sparse_mm, mode="max-autotune") - sparse_result = compiled_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=torch.int32) - - alpha_scaled = 
torch.stack([alpha] * 128).t().cpu().float() - dense_result = alpha_scaled * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu()) - dense_result = dense_result.to(torch.int32) - - torch.testing.assert_close(sparse_result.cpu(), dense_result, rtol=1e-3, atol=1e-3) - @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32]) def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device): A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda() @@ -1163,6 +1149,21 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase): torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3) + @inference_dtypes + def test_cslt_sparse_mm_alg_id(self, device, dtype): + A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype) + A_compressed = torch._cslt_compress(A) + B = torch.ones((128, 128), device=device).to(dtype) + + A_compressed = torch._cslt_compress(A) + alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t()) + sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id) + + dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32)) + dense_result = dense_result.to(dtype) + + torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3) + @inference_dtypes def test_cslt_sparse_mm_search(self, device, dtype): A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype) @@ -1171,26 +1172,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase): A_compressed = torch._cslt_compress(A) alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t()) - sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id) - dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32)) - dense_result = dense_result.to(dtype) - torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3) - - @inference_dtypes - def test_csrc_cslt_sparse_mm_search(self, device, dtype): - A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype) - A_compressed = torch._cslt_compress(A) - B = torch.ones((128, 128), device=device).to(dtype) - - A_compressed = torch._cslt_compress(A) - alg_id, split_k, split_k_one_kernel, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False) - sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), - alg_id=alg_id, - split_k=split_k, - split_k_one_kernel=split_k_one_kernel) - dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32)) - dense_result = dense_result.to(dtype) - torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3) + assert alg_id in range(torch.backends.cusparselt.get_max_alg_id()) def test_cusparselt_backend(self): version = _get_torch_cuda_version() diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 7644a2a12568..9a15a6bd71aa 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -520,22 +520,18 @@ def meta__cslt_sparse_mm( alpha: Optional[Tensor] = None, out_dtype: Optional[torch.dtype] = None, transpose_result: bool = False, - alg_id: int = 0, - split_k: int = 1, - split_k_one_kernel: bool = False, ): assert dense_B.dtype in { torch.float32, torch.float16, torch.bfloat16, torch.int8, - torch.float8_e4m3fn, - }, "_cslt_sparse_mm only supports fp16, bf16, int8, and fp8e4m3" + }, "_cslt_sparse_mm only supports fp16, bf16, and int8" assert compressed_A.dtype == dense_B.dtype, "inputs must have the same dtype" assert len(dense_B.shape) == 2, "_cslt_sparse_mm only supports 2d inputs" - is_8bit_input_type = compressed_A.dtype in [torch.int8, torch.float8_e4m3fn] 
- compression_factor = 10 if is_8bit_input_type else 9 + is_int8_input_type = compressed_A.dtype == torch.int8 + compression_factor = 10 if is_int8_input_type else 9 k = dense_B.size(0) n = dense_B.size(1) m = (compressed_A.numel() * 16) // (compression_factor * k) @@ -543,16 +539,11 @@ def meta__cslt_sparse_mm( assert m == bias.size(0) if out_dtype is not None: - assert ( - is_8bit_input_type - and out_dtype - in { - torch.float16, - torch.bfloat16, - torch.int32, - torch.float8_e4m3fn, - } - ), "out_dtype is not supported for {compressed_A.dtype} x {dense_B.dtype} -> {out_dtype} matmul!" + assert is_int8_input_type and out_dtype in { + torch.float16, + torch.bfloat16, + torch.int32, + }, "out_dtype is only supported for i8i8->fp16, bf16, or i32 matmul" output_shape = (n, m) if transpose_result else (m, n) result = dense_B.new_empty(output_shape, dtype=out_dtype) return result diff --git a/torch/csrc/cuda/shared/cusparselt.cpp b/torch/csrc/cuda/shared/cusparselt.cpp index 02be708e9139..ca020b75a706 100644 --- a/torch/csrc/cuda/shared/cusparselt.cpp +++ b/torch/csrc/cuda/shared/cusparselt.cpp @@ -1,7 +1,7 @@ #include #ifdef USE_CUSPARSELT -#include +#include namespace { @@ -9,34 +9,6 @@ size_t getVersionInt() { return CUSPARSELT_VERSION; } -std::tuple mmSearch( - const at::Tensor& compressed_A, - const at::Tensor& dense_B, - const std::optional& bias_opt, - const std::optional& alpha_opt, - const std::optional out_dtype_opt, - bool transpose_result) { - int alg_id_int = 0; - int split_k = 1; - bool split_k_one_kernel = true; - auto result = at::native::_cslt_sparse_mm_impl( - compressed_A, - dense_B, - bias_opt, - alpha_opt, - out_dtype_opt, - transpose_result, - alg_id_int, - split_k, - split_k_one_kernel, - true); - return { - (int64_t)std::get<1>(result), - (int64_t)std::get<2>(result), - (bool)std::get<3>(result), - (int64_t)std::get<4>(result)}; -} - } // namespace namespace torch::cuda::shared { @@ -45,7 +17,6 @@ void initCusparseltBindings(PyObject* module) { auto m = py::handle(module).cast(); auto cusparselt = m.def_submodule("_cusparselt", "libcusparselt.so bindings"); cusparselt.def("getVersionInt", getVersionInt); - cusparselt.def("mm_search", mmSearch); } } // namespace torch::cuda::shared diff --git a/torch/sparse/_semi_structured_ops.py b/torch/sparse/_semi_structured_ops.py index 11a55d9d523c..eb5557bf8b0d 100644 --- a/torch/sparse/_semi_structured_ops.py +++ b/torch/sparse/_semi_structured_ops.py @@ -103,8 +103,6 @@ def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor: packed_t=self.packed_t, meta_t=self.meta_t, compressed_swizzled_bitmask=self.compressed_swizzled_bitmask, - fuse_transpose_cusparselt=self.fuse_transpose_cusparselt, - alg_id_cusparselt=self.alg_id_cusparselt, requires_grad=False, )
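
Usage note (not part of the patch): the sketch below exercises the API surface as it stands after this revert, mirroring the test_cslt_sparse_mm_alg_id test added above. It assumes a CUDA build with cuSPARSELt enabled; the 2:4 mask construction and shapes are illustrative only.

    # Minimal sketch of the post-revert cuSPARSELt path: compress, search for an
    # alg_id, then reuse it. After this revert, _cslt_sparse_mm_search returns only
    # alg_id (no split_k / split_k_one_kernel), and _cslt_sparse_mm takes no split_k
    # arguments.
    import torch

    # Build a 2:4 (semi-structured) sparse fp16 matrix: every group of four
    # contiguous elements along a row contains two zeros.
    A = torch.tensor([0, 0, 1, 1], dtype=torch.float16, device="cuda").tile((128, 32))  # 128 x 128
    B = torch.ones(128, 128, dtype=torch.float16, device="cuda")

    # Compress A into cuSPARSELt's packed representation.
    A_compressed = torch._cslt_compress(A)

    # Search for an algorithm id, then pass it to the matmul.
    alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
    sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)

    # Dense reference for comparison, as in the test.
    dense_result = torch.mm(A.float(), B.float()).to(torch.float16)
    torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)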