Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-21 05:34:18 +08:00
[sparse] add extra options to _cslt_sparse_mm (#137427)
Summary:
Splitting this PR into two, one for the cuSPARSELt improvements and one for the inductor lowering. This PR adds the additional cuSPARSELt bindings into PyTorch:

* `torch._cslt_sparse_mm_search` will be deprecated in a future PR, so a warning has been added.
* Added a header file for cuSPARSELtOps.cpp.
* max_id is now available in `torch.backends.cusparselt` via `torch.backends.cusparselt.get_max_alg_id()`.
* Fixed meta registrations for float8.

Test Plan:
python test/test_sparse_semi_structured.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137427
Approved by: https://github.com/cpuhrsch, https://github.com/eqy
Committed by: PyTorch MergeBot
Parent: 02990fe36b
Commit: f1451163ec
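A rough usage sketch of the options this commit adds, pieced together from the tests further down; it assumes a CUDA build with cuSPARSELt available, and the shapes are illustrative only:

import torch

# A simple 2:4-sparse fp16 operand (the [0, 0, 1, 1] pattern is borrowed from the tests).
A = torch.Tensor([0, 0, 1, 1]).tile((256, 32)).half().cuda()   # shape (256, 128)
B = torch.randn(64, 128, dtype=torch.float16, device="cuda")
A_compressed = torch._cslt_compress(A)

# New private binding: returns (alg_id, split_k, split_k_one_kernel, max_alg_id).
alg_id, split_k, split_k_one_kernel, max_alg_id = torch._C._cusparselt.mm_search(
    A_compressed, B.t(), None, None, None, False
)

# _cslt_sparse_mm now also accepts the split-k parameters found by the search.
out = torch._cslt_sparse_mm(
    A_compressed, B.t(),
    alg_id=alg_id, split_k=split_k, split_k_one_kernel=split_k_one_kernel,
)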
@@ -3371,7 +3371,7 @@
  dispatch:
    CUDA: _cslt_compress

-- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0) -> Tensor
+- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0, int split_k=1, bool split_k_one_kernel=True) -> Tensor
  dispatch:
    CUDA: _cslt_sparse_mm

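The added arguments carry defaults (alg_id=0, split_k=1, split_k_one_kernel=True), so existing call sites keep working unchanged. A minimal sketch, reusing the tensors from the snippet above:

out = torch._cslt_sparse_mm(A_compressed, B.t())  # same behaviour as before this change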
@@ -1,109 +1,97 @@
-#include <ATen/cuda/CUDAContext.h>
-#include <ATen/cuda/CUDADataType.h>
-#include <ATen/cuda/CUDASparse.h>
-#include <ATen/cuda/CUDAConfig.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/Dispatch.h>
-#include <ATen/Functions.h>
-#include <c10/core/ScalarType.h>
-#include <c10/cuda/CUDACachingAllocator.h>
-#include <c10/util/Half.h>
-#include <cusparse.h>
-#include <cstdint>
+#include <ATen/native/sparse/cuda/cuSPARSELtOps.h>

#if AT_CUSPARSELT_ENABLED()

-#include <cusparseLt.h>

namespace at::native {

// Ideally we would use the same DeviceThreadHandlePool mechanism as used in
// aten/src/ATen/cuda/CuSparseHandlePool.cpp which would handle this for us.
// However, the cuSPARSELt handle signature is different from that of
// cuSPARSE/cuBLAS, so it's not possible to reuse the existing pooling
// mechanism. Instead we have to handle our handles ourselves, which is why
// these variables are thread local. Once cuSPARSELt updates their handle
// signature to be consistent with the rest of CUDA, we can switch to using
// DeviceThreadHandlePool.
thread_local cusparseLtHandle_t handle;
thread_local bool handle_initialized = false;

at::Tensor _cslt_compress(const Tensor& sparse_input) {
  if (!handle_initialized) {
    TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle));
    handle_initialized = true;
  }
  // create sparse descriptor, dtype
  cusparseLtMatDescriptor_t sparse_input_descriptor;
  cudaDataType type;
  auto compression_factor = 9;

  switch (sparse_input.scalar_type()) {
    case at::ScalarType::Char:
      type = CUDA_R_8I;
      compression_factor = 10;
      break;
    case at::ScalarType::Half:
      type = CUDA_R_16F;
      break;
    case at::ScalarType::BFloat16:
      type = CUDA_R_16BF;
      break;
    case at::ScalarType::Float:
      type = CUDA_R_32F;
      break;
#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
    case at::ScalarType::Float8_e4m3fn:
      type = CUDA_R_8F_E4M3;
      compression_factor = 10;
      break;
#endif
    default:
      TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt compressed matrix");
      break;
  }

  // create a new compressed tensor with the same dtype as
  auto compressed_tensor =
      sparse_input.new_empty(sparse_input.numel() * compression_factor / 16);

  TORCH_CUDASPARSE_CHECK(cusparseLtStructuredDescriptorInit(
      &handle,
      &sparse_input_descriptor,
      sparse_input.size(0),
      sparse_input.size(1),
      sparse_input.size(1),
      16,
      type,
      CUSPARSE_ORDER_ROW,
      CUSPARSELT_SPARSITY_50_PERCENT));

  // compress input
  //--------------------------------------------------------------------------
  size_t compressed_size, compressed_buffer_size;
  TORCH_CUDASPARSE_CHECK(cusparseLtSpMMACompressedSize2(
      &handle,
      &sparse_input_descriptor,
      &compressed_size,
      &compressed_buffer_size));

  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  auto compressedBufferPtr = allocator.allocate(compressed_buffer_size);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  TORCH_CUDASPARSE_CHECK(cusparseLtSpMMACompress2(
      &handle,
      &sparse_input_descriptor,
      true,
      CUSPARSE_OPERATION_NON_TRANSPOSE,
      sparse_input.data_ptr(),
      compressed_tensor.data_ptr(),
      compressedBufferPtr.get(),
      stream));

  return compressed_tensor;
}
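For intuition about the compression_factor above: 2:4 compression keeps half of the values plus metadata, so the compressed buffer is 9/16 of numel for 16- and 32-bit dtypes and 10/16 for int8 and fp8e4m3. A small sketch of that arithmetic (illustrative only):

import torch

def expected_compressed_numel(numel: int, dtype: torch.dtype) -> int:
    # Mirrors the compression_factor selection in _cslt_compress above.
    compression_factor = 10 if dtype in (torch.int8, torch.float8_e4m3fn) else 9
    return numel * compression_factor // 16

assert expected_compressed_numel(128 * 128, torch.float16) == 9216
assert expected_compressed_numel(128 * 128, torch.int8) == 10240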

-std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
+std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt_sparse_mm_impl(
    const Tensor& compressed_A,
    const Tensor& dense_B,
    const std::optional<Tensor>& bias_opt,
@@ -111,12 +99,12 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
    const std::optional<c10::ScalarType> out_dtype_opt,
    bool transpose_result,
    int alg_id,
    int split_k,
    bool split_k_one_kernel,
    bool search_alg_id) {
  if (!handle_initialized) {
    TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle));
    handle_initialized = true;
  }
  // cupsarselt constructs
  cusparseLtMatmulDescriptor_t matmul;
@@ -132,134 +120,138 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
  cusparseComputeType compute_type;
  auto compression_factor = 9;

  switch (compressed_A.scalar_type()) {
    case at::ScalarType::Char:
      input_type = CUDA_R_8I;
      output_type = CUDA_R_8I;
      C_type = CUDA_R_8I;
      compute_type = CUSPARSE_COMPUTE_32I;
      compression_factor = 10;
      break;

    // cuSPARSELt v0.5.2 onwards changes CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUT_16F
    // to CUSPARSE_COMPUTE_32F
#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 502
    case at::ScalarType::Half:
      input_type = CUDA_R_16F;
      output_type = CUDA_R_16F;
      C_type = CUDA_R_16F;
      compute_type = CUSPARSE_COMPUTE_32F;
      break;
    case at::ScalarType::BFloat16:
      input_type = CUDA_R_16BF;
      output_type = CUDA_R_16BF;
      C_type = CUDA_R_16BF;
      compute_type = CUSPARSE_COMPUTE_32F;
      break;
    case at::ScalarType::Float:
      input_type = CUDA_R_32F;
      output_type = CUDA_R_32F;
      C_type = CUDA_R_32F;
      compute_type = CUSPARSE_COMPUTE_32F;
      break;
    // if cuSPARSELt >= 6.2.3, we can add Float8 support
#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
    case at::ScalarType::Float8_e4m3fn:
      input_type = CUDA_R_8F_E4M3;
      output_type = CUDA_R_8F_E4M3;
      C_type = CUDA_R_16F;
      compute_type = CUSPARSE_COMPUTE_32F;
      compression_factor = 10;
      break;
#endif
    // cuSPARSELt <= v0.5.2 uses CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUTE_16F
#else
    case at::ScalarType::Half:
      input_type = CUDA_R_16F;
      output_type = CUDA_R_16F;
      C_type = CUDA_R_16F;
      compute_type = CUSPARSE_COMPUTE_16F;
      break;
    case at::ScalarType::BFloat16:
      input_type = CUDA_R_16BF;
      output_type = CUDA_R_16BF;
      C_type = CUDA_R_16BF;
      compute_type = CUSPARSE_COMPUTE_16F;
      break;
    case at::ScalarType::Float:
      input_type = CUDA_R_32F;
      output_type = CUDA_R_32F;
      C_type = CUDA_R_32F;
      compute_type = CUSPARSE_COMPUTE_TF32;
      break;
#endif
    default:
      TORCH_CHECK(
          false,
          "Unsupported dtype for cuSPARSELt compressed matrix multiplication.");
      break;
  }
  ScalarType out_dtype = dense_B.scalar_type();
  // special check for mixed dtype support for 8 bit dtypes
  // cslt 0.5.2+: int8 int8 -> {fp16, bf16, int32} support
  if (out_dtype_opt.has_value()) {
    out_dtype = out_dtype_opt.value();
    if (input_type == CUDA_R_8I) {
      switch (out_dtype) {
        case at::ScalarType::Half:
          C_type = CUDA_R_16F;
          output_type = CUDA_R_16F;
          break;
        case at::ScalarType::BFloat16:
          C_type = CUDA_R_16BF;
          output_type = CUDA_R_16BF;
          break;
        case at::ScalarType::Int:
          C_type = CUDA_R_32I;
          output_type = CUDA_R_32I;
          break;
        default:
          TORCH_CHECK(
              false,
              "Unsupported out_dtype passed, must be one of {fp16, bf16, int32} for int8 inputs");
          break;
      }
    }
    // cslt 0.6.2+: fp8 fp8 -> {fp8, fp16, bf16, fp32} support
#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
    else if (input_type == CUDA_R_8F_E4M3) {
      switch (out_dtype) {
        case at::ScalarType::Float8_e4m3fn:
          output_type = CUDA_R_8F_E4M3;
          C_type = CUDA_R_16F;
          break;
        case at::ScalarType::Half:
          output_type = CUDA_R_16F;
          C_type = CUDA_R_16F;
          break;
        case at::ScalarType::BFloat16:
          output_type = CUDA_R_16BF;
          C_type = CUDA_R_16BF;
          break;
        case at::ScalarType::Float:
          output_type = CUDA_R_32F;
          C_type = CUDA_R_32F;
          break;
        default:
          TORCH_CHECK(
              false,
              "Unsupported out_dtype passed, must be one of {fp16, bf16, float32} for fp8 inputs");
          break;
      }
    }
#endif
    else {
      TORCH_CHECK(
          false, "out_dtype support only available for int8/fp8 inputs");
    }
  }

  int64_t k = dense_B.size(0);
  int64_t n = dense_B.size(1);
  int64_t m = (compressed_A.numel() * 16 / compression_factor) / k;

  // initialize sparse descriptor
  cusparseLtMatDescriptor_t sparse_input_descriptor;
  TORCH_CUDASPARSE_CHECK(cusparseLtStructuredDescriptorInit(
      &handle,
@@ -285,7 +277,8 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
      CUSPARSE_ORDER_ROW));

  // create result tensor
  auto res_tensor_options =
      c10::TensorOptions().dtype(out_dtype).device(dense_B.device());
  at::Tensor res = (transpose_result) ? at::empty({n, m}, res_tensor_options)
                                      : at::empty({m, n}, res_tensor_options);

@@ -295,7 +288,7 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
      &res_descriptor,
      m,
      n,
      (transpose_result) ? m : n,
      16,
      output_type,
      (transpose_result) ? CUSPARSE_ORDER_COL : CUSPARSE_ORDER_ROW));
@@ -307,7 +300,7 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
      &C_descriptor,
      m,
      n,
      (transpose_result) ? m : n,
      16,
      C_type,
      (transpose_result) ? CUSPARSE_ORDER_COL : CUSPARSE_ORDER_ROW));
@@ -317,7 +310,8 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
      &handle,
      &matmul,
      CUSPARSE_OPERATION_NON_TRANSPOSE,
      (dense_B.is_contiguous()) ? CUSPARSE_OPERATION_NON_TRANSPOSE
                                : CUSPARSE_OPERATION_TRANSPOSE,
      &sparse_input_descriptor,
      &dense_input_descriptor,
      &C_descriptor,
@@ -329,28 +323,59 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
    auto& bias = bias_opt.value();
    void* dBias = bias.data_ptr();
    TORCH_CUDASPARSE_CHECK(cusparseLtMatmulDescSetAttribute(
        &handle,
        &matmul,
        CUSPARSELT_MATMUL_BIAS_POINTER,
        &dBias,
        sizeof(dBias)));
  }

  TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSelectionInit(
      &handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT));

  // set matmul search params
  TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
      &handle,
      &alg_sel,
      CUSPARSELT_MATMUL_ALG_CONFIG_ID,
      &alg_id,
      sizeof(alg_id)));

  cusparseLtSplitKMode_t splitKMode;
  int max_alg_id;
  if (split_k != 1) {
    TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
        &handle,
        &alg_sel,
        CUSPARSELT_MATMUL_SPLIT_K,
        &split_k,
        sizeof(split_k)));

    splitKMode = split_k_one_kernel ? CUSPARSELT_SPLIT_K_MODE_ONE_KERNEL
                                    : CUSPARSELT_SPLIT_K_MODE_TWO_KERNELS;
    TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
        &handle,
        &alg_sel,
        CUSPARSELT_MATMUL_SPLIT_K_MODE,
        &splitKMode,
        sizeof(splitKMode)));
  }

  // set tensor_alpha_mode and alpha pointer for matmul
  const auto alpha_tensor = alpha_opt.has_value() ? *alpha_opt : Tensor{};
  auto alpha_ptr = &alpha;
  if (alpha_opt.has_value()) {
    if (alpha_tensor.numel() == 1) {
      alpha = alpha_tensor.item<float>();
    } else {
      tensor_alpha_mode = 1;
      TORCH_CUDASPARSE_CHECK(cusparseLtMatmulDescSetAttribute(
          &handle,
          &matmul,
          CUSPARSELT_MATMUL_ALPHA_VECTOR_SCALING,
          &tensor_alpha_mode,
          sizeof(tensor_alpha_mode)));
      alpha_ptr = static_cast<float*>(alpha_tensor.data_ptr());
    }
  }
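For readers unfamiliar with split-K: the reduction (K) dimension is partitioned into split_k chunks whose partial products are summed afterwards, either inside the same kernel or in a second reduction kernel (the ONE_KERNEL / TWO_KERNELS modes set above). A hedged torch sketch of the idea, not of the cuSPARSELt kernel itself:

import torch

def matmul_split_k(A, B, split_k=4):
    # Multiply each K-chunk independently, then reduce; numerically equivalent to A @ B.
    partials = [a @ b for a, b in zip(A.chunk(split_k, dim=1), B.chunk(split_k, dim=0))]
    return torch.stack(partials).sum(dim=0)

A, B = torch.randn(64, 128), torch.randn(128, 32)
torch.testing.assert_close(matmul_split_k(A, B), A @ B)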
@@ -365,7 +390,7 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
  auto workspacePtr = allocator.allocate(workspace_size);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (search_alg_id) {
    // run matmul search
    TORCH_CUDASPARSE_CHECK(cusparseLtMatmulSearch(
        &handle,
@@ -381,11 +406,36 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
        &stream,
        1));

    // get matmul params used
    TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgGetAttribute(
        &handle,
        &alg_sel,
        CUSPARSELT_MATMUL_ALG_CONFIG_ID,
        &alg_id,
        sizeof(alg_id)));

    TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgGetAttribute(
        &handle,
        &alg_sel,
        CUSPARSELT_MATMUL_SPLIT_K,
        &split_k,
        sizeof(split_k)));

    TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgGetAttribute(
        &handle,
        &alg_sel,
        CUSPARSELT_MATMUL_SPLIT_K_MODE,
        &splitKMode,
        sizeof(splitKMode)));

    TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgGetAttribute(
        &handle,
        &alg_sel,
        CUSPARSELT_MATMUL_ALG_CONFIG_MAX_ID,
        &max_alg_id,
        sizeof(max_alg_id)));

  } else {
    // do normal matmul
    TORCH_CUDASPARSE_CHECK(cusparseLtMatmul(
        &handle,
@@ -402,7 +452,7 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
        1));
  }

  // destroy descriptors
  TORCH_CUDASPARSE_CHECK(
      cusparseLtMatDescriptorDestroy(&sparse_input_descriptor));
  TORCH_CUDASPARSE_CHECK(
@@ -411,7 +461,12 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
  // destroy plan
  TORCH_CUDASPARSE_CHECK(cusparseLtMatmulPlanDestroy(&plan));

  return {
      res,
      alg_id,
      split_k,
      splitKMode == CUSPARSELT_SPLIT_K_MODE_ONE_KERNEL,
      max_alg_id};
}

at::Tensor _cslt_sparse_mm(
@@ -421,19 +476,21 @@ at::Tensor _cslt_sparse_mm(
    const std::optional<Tensor>& alpha_opt,
    const std::optional<c10::ScalarType> out_dtype_opt,
    bool transpose_result,
    int64_t alg_id,
    int64_t split_k,
    bool split_k_one_kernel) {
  auto result = _cslt_sparse_mm_impl(
      compressed_A,
      dense_B,
      bias_opt,
      alpha_opt,
      out_dtype_opt,
      transpose_result,
      (int)alg_id,
      (int)split_k,
      split_k_one_kernel,
      false);
  return std::get<0>(result);
}

int64_t _cslt_sparse_mm_search(
@@ -442,31 +499,34 @@ int64_t _cslt_sparse_mm_search(
    const std::optional<Tensor>& bias_opt,
    const std::optional<Tensor>& alpha_opt,
    const std::optional<c10::ScalarType> out_dtype_opt,
    bool transpose_result) {
  TORCH_WARN_ONCE(
      "torch._cslt_sparse_mm_search is deprecated and will be removed in a future PyTorch release. Please use torch._C._cusparselt.mm_search instead.");
  int alg_id_int = 0;
  int split_k = 1;
  bool split_k_one_kernel = true;
  auto result = _cslt_sparse_mm_impl(
      compressed_A,
      dense_B,
      bias_opt,
      alpha_opt,
      out_dtype_opt,
      transpose_result,
      alg_id_int,
      split_k,
      split_k_one_kernel,
      true);
  return (int64_t)std::get<1>(result);
}

} // namespace at::native

#else // No cuSPARSELt support, throw error if these functions are called.

namespace at::native {

at::Tensor _cslt_compress(const Tensor& sparse_input) {
  TORCH_CHECK(false, "cuSPARSELt not supported on your machine.");
}

at::Tensor _cslt_sparse_mm(
@@ -476,9 +536,10 @@ at::Tensor _cslt_sparse_mm(
    const std::optional<Tensor>& alpha_opt,
    const std::optional<c10::ScalarType> out_dtype,
    bool transpose_result,
    int64_t alg_id,
    int64_t split_k,
    bool split_k_one_kernel) {
  TORCH_CHECK(false, "cuSPARSELt not supported on your machine.");
}

int64_t _cslt_sparse_mm_search(
@@ -487,10 +548,8 @@ int64_t _cslt_sparse_mm_search(
    const std::optional<Tensor>& bias_opt,
    const std::optional<Tensor>& alpha_opt,
    const std::optional<c10::ScalarType> out_dtype,
    bool transpose_result) {
  TORCH_CHECK(false, "cuSPARSELt not supported on your machine.");
}

} // namespace at::native
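As the TORCH_WARN_ONCE above states, torch._cslt_sparse_mm_search keeps working but only returns an alg_id; the suggested replacement returns the full search result. A brief sketch, reusing A_compressed and B from the first snippet:

alg_id_only = torch._cslt_sparse_mm_search(A_compressed, B.t())  # deprecated, warns once
alg_id, split_k, split_k_one_kernel, max_alg_id = torch._C._cusparselt.mm_search(
    A_compressed, B.t(), None, None, None, False                 # preferred path
)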
aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.h (new file, 58 lines)
@@ -0,0 +1,58 @@
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
#include <ATen/cuda/CUDASparse.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Functions.h>
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/Half.h>
#include <cusparse.h>
#include <cstdint>

#if AT_CUSPARSELT_ENABLED()
#include <cusparseLt.h>
#endif

namespace at::native {

at::Tensor _cslt_compress(const Tensor& sparse_input);

TORCH_CUDA_CPP_API std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt_sparse_mm_impl(
    const Tensor& compressed_A,
    const Tensor& dense_B,
    const std::optional<Tensor>& bias_opt,
    const std::optional<Tensor>& alpha_opt,
    const std::optional<c10::ScalarType> out_dtype_opt,
    bool transpose_result,
    int alg_id,
    int split_k,
    bool split_k_one_kernel,
    bool search_alg_id);

at::Tensor _cslt_sparse_mm(
    const Tensor& compressed_A,
    const Tensor& dense_B,
    const std::optional<Tensor>& bias_opt,
    const std::optional<Tensor>& alpha_opt,
    const std::optional<c10::ScalarType> out_dtype_opt,
    bool transpose_result,
    int64_t alg_id,
    int64_t split_k,
    bool split_k_one_kernel);

int64_t _cslt_sparse_mm_search(
    const Tensor& compressed_A,
    const Tensor& dense_B,
    const std::optional<Tensor>& bias_opt,
    const std::optional<Tensor>& alpha_opt,
    const std::optional<c10::ScalarType> out_dtype_opt,
    bool transpose_result);

} // namespace at::native
@@ -1,253 +0,0 @@
import argparse
import random

import pandas as pd
from tqdm import tqdm

import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured


torch.set_printoptions(
    precision=2, threshold=None, edgeitems=16, linewidth=480, profile=None, sci_mode=False
)


# helper model definition for pruner
class Model(nn.Module):
    def __init__(self, m, k, dtype=None):
        super().__init__()
        # transposed so reversed
        self.linear = nn.Linear(k, m)

    def forward(self, x):
        return self.linear(x)


def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """
    This function returns a 1:2 sparse matrix of size (r, c).
    Note that this means this matrix will also be 2:4 and 4:8 sparse as well.
    """

    choices = [[0, 1], [1, 0]]
    mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]

    return (
        torch.tensor(mask_entries, dtype=dtype, device=device)
        .reshape(r, c)
        .contiguous()
    )


def test_linear(m, k, n, dtype, contiguous, backend):
    SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
    mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
    sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask
    input_tensor = torch.zeros(n, k).to(dtype).cuda()
    model = Model(m, k).to(dtype).cuda().eval()

    dense_measurement = benchmark.Timer(
        stmt="model(input_tensor)",
        globals=locals(),
    ).blocked_autorange()

    dense_output = model(input_tensor)
    print(dense_output.shape)

    # sparsify weights
    model.linear.weight = nn.Parameter(
        to_sparse_semi_structured(
            sparse_weight,
        )
    )

    sparse_output = model(input_tensor)
    print(sparse_output.shape)

    sparse_measurement = benchmark.Timer(
        stmt="model(input_tensor)",
        globals=locals(),
    ).blocked_autorange()

    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)

    return {
        "test_function": "linear",
        "m": m,
        "k": k,
        "n": n,
        "dtype": str(dtype),
        "backend": backend,
        "sparse_latency (ms)": sparse_measurement.median * 1000,
        "dense_latency (ms)": dense_measurement.median * 1000,
        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
        "correct": correct,
        "contiguous": sparse_output.is_contiguous(),
    }


def test_tensor(m, k, n, dtype, contiguous, backend):
    A = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
    B = torch.zeros(k, n).to(dtype).cuda()
    bias = torch.rand(n).to(dtype).cuda()

    sA = to_sparse_semi_structured(A)

    # torch.mm calculation
    if dtype is not torch.int8:
        dense_output = torch.mm(A, B)

        dense_measurement = benchmark.Timer(
            stmt="torch.mm(A, B)",
            globals=locals(),
        ).blocked_autorange()

    else:
        print("int8 baseline not supported")
        dense_output = torch.mm(sA, B)

        dense_measurement = benchmark.Timer(
            stmt="torch.mm(sA, B)",
            globals=locals(),
        ).blocked_autorange()

    sparse_output = torch.mm(sA, B)
    sparse_measurement = benchmark.Timer(
        stmt="torch.mm(sA, B)",
        globals=locals(),
    ).blocked_autorange()

    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)

    return {
        "test_function": "tensor",
        "m": m,
        "k": k,
        "n": n,
        "dtype": str(dtype),
        "backend": backend,
        "sparse_latency (ms)": sparse_measurement.median * 1000,
        "dense_latency (ms)": dense_measurement.median * 1000,
        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
        "correct": correct,
        "contiguous": sparse_output.is_contiguous(),
    }


if __name__ == "__main__":
    dtype_lookup = {
        "int8": torch.int8,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }

    parser = argparse.ArgumentParser(description="Semi-Structured Sparsity Benchmarks")
    parser.add_argument(
        "--mode",
        type=str,
        choices=["nvidia-bert", "nvidia-fixed-k", "nvidia-fixed-mn"],
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=dtype_lookup.keys(),
        default="fp16",
    )
    parser.add_argument(
        "--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt"
    )
    parser.add_argument("-contiguous", action="store_true")
    parser.add_argument("-e2e", action="store_true")
    parser.add_argument("-save", action="store_true")
    args = parser.parse_args()

    if args.e2e:
        eval_fn = test_linear
    else:
        eval_fn = test_tensor

    print(f"Started benchmark: {args.mode} | dtype: {args.dtype}")
    dtype = dtype_lookup[args.dtype]

    if args.mode == "nvidia-bert":
        bert_shapes = [
            (3072, 1024, 16384), (4096, 1024, 16384), (1024, 1024, 16384), (1024, 4096, 16384),
        ]
        results = (
            eval_fn(m, k, n, dtype, args.contiguous, args.backend)
            for (m, k, n) in tqdm(bert_shapes)
        )

    elif args.mode == "nvidia-fixed-k":
        mn_vals = [
            3072, 4096, 5120, 6144, 7168, 8192, 9216, 10240, 11264,
            12288, 13312, 14336, 15360, 16384, 17408, 18432, 19456, 20480,
        ]
        results = (
            eval_fn(mn, 10240, mn, dtype, args.contiguous, args.backend)
            for mn in tqdm(mn_vals)
        )

    elif args.mode == "nvidia-fixed-mn":
        k_vals = [
            2560, 3840, 5120, 6400, 7680, 8960, 10240, 11520,
            12800, 14080, 15360, 16640, 17920, 19200, 20480,
        ]
        results = (
            eval_fn(10240, k, 10240, dtype, args.contiguous, args.backend)
            for k in tqdm(k_vals)
        )

    df = pd.DataFrame.from_records(results)
    if args.save:
        save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv"
        df.to_csv(save_file)
        print(f"Finished benchmark: {args.mode} saved results to {save_file}")
    print(df)
@@ -244,18 +244,17 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
    def test_sp24_compile(self) -> None:
        x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
-        e = torch.eye(x.shape[0], x.shape[0], device="cuda", dtype=torch.float16)

-        def fn(x, e):
+        def fn(x):
            y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x)
            y = y.t()
            return x @ y

        # Eager
-        output = fn(x, e)
+        output = fn(x)
        output.backward(output)
        # Torch compile
-        output = torch.compile(fn)(x, e)
+        output = torch.compile(fn)(x)
        output.backward(output)

class TestSparseSemiStructured(TestCase):
@@ -1133,6 +1132,21 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):

        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

+    def test_cslt_sparse_mm_alpha_compile_autotune(self, device):
+        A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(torch.int8).cuda()
+        B = torch.ones((128, 256), device=device).to(torch.int8).t()
+        alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda()
+
+        A_compressed = torch._cslt_compress(A)
+        compiled_sparse_mm = torch.compile(torch._cslt_sparse_mm, mode="max-autotune")
+        sparse_result = compiled_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=torch.int32)
+
+        alpha_scaled = torch.stack([alpha] * 128).t().cpu().float()
+        dense_result = alpha_scaled * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu())
+        dense_result = dense_result.to(torch.int32)
+
+        torch.testing.assert_close(sparse_result.cpu(), dense_result, rtol=1e-3, atol=1e-3)
+
    @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32])
    def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
        A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
@@ -1149,21 +1163,6 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):

        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

-    @inference_dtypes
-    def test_cslt_sparse_mm_alg_id(self, device, dtype):
-        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
-        A_compressed = torch._cslt_compress(A)
-        B = torch.ones((128, 128), device=device).to(dtype)
-
-        A_compressed = torch._cslt_compress(A)
-        alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
-        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)
-
-        dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
-        dense_result = dense_result.to(dtype)
-
-        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
-
    @inference_dtypes
    def test_cslt_sparse_mm_search(self, device, dtype):
        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
@@ -1172,7 +1171,26 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):

        A_compressed = torch._cslt_compress(A)
        alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
-        assert alg_id in range(torch.backends.cusparselt.get_max_alg_id())
+        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)
+        dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
+        dense_result = dense_result.to(dtype)
+        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
+
+    @inference_dtypes
+    def test_csrc_cslt_sparse_mm_search(self, device, dtype):
+        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
+        A_compressed = torch._cslt_compress(A)
+        B = torch.ones((128, 128), device=device).to(dtype)
+
+        A_compressed = torch._cslt_compress(A)
+        alg_id, split_k, split_k_one_kernel, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False)
+        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(),
+                                              alg_id=alg_id,
+                                              split_k=split_k,
+                                              split_k_one_kernel=split_k_one_kernel)
+        dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
+        dense_result = dense_result.to(dtype)
+        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    def test_cusparselt_backend(self):
        version = _get_torch_cuda_version()
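Per the commit summary, the maximum algorithm id is exposed through torch.backends.cusparselt.get_max_alg_id(); a hedged sanity check in the spirit of the removed test above:

max_alg_id = torch.backends.cusparselt.get_max_alg_id()
alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
assert alg_id in range(max_alg_id)  # a searched alg_id should be a valid config id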
@@ -520,18 +520,22 @@ def meta__cslt_sparse_mm(
    alpha: Optional[Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
    transpose_result: bool = False,
+    alg_id: int = 0,
+    split_k: int = 1,
+    split_k_one_kernel: bool = False,
):
    assert dense_B.dtype in {
        torch.float32,
        torch.float16,
        torch.bfloat16,
        torch.int8,
-    }, "_cslt_sparse_mm only supports fp16, bf16, and int8"
+        torch.float8_e4m3fn,
+    }, "_cslt_sparse_mm only supports fp16, bf16, int8, and fp8e4m3"
    assert compressed_A.dtype == dense_B.dtype, "inputs must have the same dtype"
    assert len(dense_B.shape) == 2, "_cslt_sparse_mm only supports 2d inputs"

-    is_int8_input_type = compressed_A.dtype == torch.int8
-    compression_factor = 10 if is_int8_input_type else 9
+    is_8bit_input_type = compressed_A.dtype in [torch.int8, torch.float8_e4m3fn]
+    compression_factor = 10 if is_8bit_input_type else 9
    k = dense_B.size(0)
    n = dense_B.size(1)
    m = (compressed_A.numel() * 16) // (compression_factor * k)
@@ -539,11 +543,16 @@ def meta__cslt_sparse_mm(
        assert m == bias.size(0)

    if out_dtype is not None:
-        assert is_int8_input_type and out_dtype in {
-            torch.float16,
-            torch.bfloat16,
-            torch.int32,
-        }, "out_dtype is only supported for i8i8->fp16, bf16, or i32 matmul"
+        assert (
+            is_8bit_input_type
+            and out_dtype
+            in {
+                torch.float16,
+                torch.bfloat16,
+                torch.int32,
+                torch.float8_e4m3fn,
+            }
+        ), "out_dtype is not supported for {compressed_A.dtype} x {dense_B.dtype} -> {out_dtype} matmul!"
    output_shape = (n, m) if transpose_result else (m, n)
    result = dense_B.new_empty(output_shape, dtype=out_dtype)
    return result
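The meta registration above only has to reproduce the output shape and dtype; a small sketch of its shape arithmetic (illustrative, not the registration itself):

def cslt_output_shape(compressed_numel, k, n, is_8bit, transpose_result=False):
    # m is recovered from the compressed buffer size, as in meta__cslt_sparse_mm.
    compression_factor = 10 if is_8bit else 9
    m = (compressed_numel * 16) // (compression_factor * k)
    return (n, m) if transpose_result else (m, n)

assert cslt_output_shape(9216, 128, 64, is_8bit=False) == (128, 64)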
@@ -1,7 +1,7 @@
#include <torch/csrc/utils/pybind.h>

#ifdef USE_CUSPARSELT
-#include <cusparseLt.h>
+#include <ATen/native/sparse/cuda/cuSPARSELtOps.h>

namespace {

@@ -9,6 +9,34 @@ size_t getVersionInt() {
  return CUSPARSELT_VERSION;
}

+std::tuple<int64_t, int64_t, bool, int64_t> mmSearch(
+    const at::Tensor& compressed_A,
+    const at::Tensor& dense_B,
+    const std::optional<at::Tensor>& bias_opt,
+    const std::optional<at::Tensor>& alpha_opt,
+    const std::optional<c10::ScalarType> out_dtype_opt,
+    bool transpose_result) {
+  int alg_id_int = 0;
+  int split_k = 1;
+  bool split_k_one_kernel = true;
+  auto result = at::native::_cslt_sparse_mm_impl(
+      compressed_A,
+      dense_B,
+      bias_opt,
+      alpha_opt,
+      out_dtype_opt,
+      transpose_result,
+      alg_id_int,
+      split_k,
+      split_k_one_kernel,
+      true);
+  return {
+      (int64_t)std::get<1>(result),
+      (int64_t)std::get<2>(result),
+      (bool)std::get<3>(result),
+      (int64_t)std::get<4>(result)};
+}
+
} // namespace

namespace torch::cuda::shared {
@@ -17,6 +45,7 @@ void initCusparseltBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  auto cusparselt = m.def_submodule("_cusparselt", "libcusparselt.so bindings");
  cusparselt.def("getVersionInt", getVersionInt);
+  cusparselt.def("mm_search", mmSearch);
}

} // namespace torch::cuda::shared
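The _cusparselt submodule also exposes the library version as an integer, which is what gates the fp8 paths above (602 corresponds to cuSPARSELt 0.6.2); a hedged check:

import torch

if hasattr(torch._C, "_cusparselt"):
    print(torch._C._cusparselt.getVersionInt())  # e.g. 602 means the fp8e4m3 kernels are compiled in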
@@ -103,6 +103,8 @@ def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor:
        packed_t=self.packed_t,
        meta_t=self.meta_t,
        compressed_swizzled_bitmask=self.compressed_swizzled_bitmask,
+        fuse_transpose_cusparselt=self.fuse_transpose_cusparselt,
+        alg_id_cusparselt=self.alg_id_cusparselt,
        requires_grad=False,
    )