[sparse] add extra options to _cslt_spare_mm (#137427)

Summary:

Splitting this PR into two, one for the cuSPARSELt improvements, and one
for the inductor lowering.

This PR adds in the additional cuSPARSELt bindings into pytorch.

* `torch._cslt_sparse_mm_search` will be deprecated in a future PR,
  so a warning has been added

* Added a header file for cuSPARSELtOps.cpp

* max_id is now available in `torch.backends.cusparselt` via
  `torch.backends.cusparselt.get_max_alg_id()`

* fixed meta registrations for float8

Test Plan:

python test/test_sparse_semi_structured.py

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137427
Approved by: https://github.com/cpuhrsch, https://github.com/eqy
This commit is contained in:
Jesse Cai
2024-11-21 14:34:59 -05:00
committed by PyTorch MergeBot
parent 9a72939042
commit 45b30a5aec
8 changed files with 193 additions and 292 deletions

View File

@ -3362,7 +3362,7 @@
dispatch:
CUDA: _cslt_compress
- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0) -> Tensor
- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0, int split_k=1, bool split_k_one_kernel=True) -> Tensor
dispatch:
CUDA: _cslt_sparse_mm

View File

@ -1,20 +1,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
#include <ATen/cuda/CUDASparse.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Functions.h>
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/Half.h>
#include <cusparse.h>
#include <cstdint>
#include <ATen/native/sparse/cuda/cuSPARSELtOps.h>
#if AT_CUSPARSELT_ENABLED()
#include <cusparseLt.h>
namespace at::native {
// Ideally we would use the same DeviceThreadHandlePool mechanism as used in aten/src/ATen/cuda/CuSparseHandlePool.cpp
@ -56,6 +43,7 @@ at::Tensor _cslt_compress(const Tensor& sparse_input)
#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
case at::ScalarType::Float8_e4m3fn:
type = CUDA_R_8F_E4M3;
compression_factor = 10;
break;
#endif
default:
@ -103,7 +91,7 @@ at::Tensor _cslt_compress(const Tensor& sparse_input)
return compressed_tensor;
}
std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt_sparse_mm_impl(
const Tensor& compressed_A,
const Tensor& dense_B,
const std::optional<Tensor>& bias_opt,
@ -111,6 +99,8 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
const std::optional<c10::ScalarType> out_dtype_opt,
bool transpose_result,
int alg_id,
int split_k,
bool split_k_one_kernel,
bool search_alg_id
)
{
@ -169,6 +159,7 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
output_type = CUDA_R_8F_E4M3;
C_type = CUDA_R_16F;
compute_type = CUSPARSE_COMPUTE_32F;
compression_factor = 10;
break;
#endif
// cuSPARSELt <= v0.5.2 uses CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUTE_16F
@ -335,16 +326,29 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSelectionInit(
&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT));
// set alg_id
// set matmul search params
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
&handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg_id, sizeof(alg_id)));
cusparseLtSplitKMode_t splitKMode;
int max_alg_id;
if (split_k != 1) {
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
&handle, &alg_sel, CUSPARSELT_MATMUL_SPLIT_K, &split_k, sizeof(split_k)));
splitKMode = split_k_one_kernel ? CUSPARSELT_SPLIT_K_MODE_ONE_KERNEL : CUSPARSELT_SPLIT_K_MODE_TWO_KERNELS;
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
&handle, &alg_sel, CUSPARSELT_MATMUL_SPLIT_K_MODE, &splitKMode, sizeof(splitKMode)));
}
// set tensor_alpha_mode and alpha pointer for matmul
const auto alpha_tensor = alpha_opt.has_value() ? *alpha_opt: Tensor{};
auto alpha_ptr = &alpha;
if (alpha_opt.has_value()) {
if (alpha_tensor.numel() == 1) {
alpha = alpha_tensor.item<float>();
// * static_cast<float*>(alpha_tensor.data_ptr());
// std::cout<<alpha<<std::endl;
}
else {
tensor_alpha_mode = 1;
@ -381,9 +385,23 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
&stream,
1));
// get alg_id used
// get matmul params used
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgGetAttribute(
&handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg_id, sizeof(alg_id)));
TORCH_CUDASPARSE_CHECK( cusparseLtMatmulAlgGetAttribute(&handle, &alg_sel,
CUSPARSELT_MATMUL_SPLIT_K,
&split_k, sizeof(split_k)));
TORCH_CUDASPARSE_CHECK( cusparseLtMatmulAlgGetAttribute(&handle, &alg_sel,
CUSPARSELT_MATMUL_SPLIT_K_MODE,
&splitKMode, sizeof(splitKMode)));
TORCH_CUDASPARSE_CHECK( cusparseLtMatmulAlgGetAttribute(&handle, &alg_sel,
CUSPARSELT_MATMUL_ALG_CONFIG_MAX_ID,
&max_alg_id, sizeof(max_alg_id)));
}
else {
// do normal matmul
@ -411,7 +429,7 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
// destroy plan
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulPlanDestroy(&plan));
return {alg_id, res};
return {res, alg_id, split_k, splitKMode == CUSPARSELT_SPLIT_K_MODE_ONE_KERNEL, max_alg_id};
}
at::Tensor _cslt_sparse_mm(
@ -421,7 +439,9 @@ at::Tensor _cslt_sparse_mm(
const std::optional<Tensor>& alpha_opt,
const std::optional<c10::ScalarType> out_dtype_opt,
bool transpose_result,
int64_t alg_id
int64_t alg_id,
int64_t split_k,
bool split_k_one_kernel
)
{
auto result = _cslt_sparse_mm_impl(
@ -432,8 +452,10 @@ at::Tensor _cslt_sparse_mm(
out_dtype_opt,
transpose_result,
(int) alg_id,
(int) split_k,
split_k_one_kernel,
false);
return std::get<1>(result);
return std::get<0>(result);
}
int64_t _cslt_sparse_mm_search(
@ -445,7 +467,10 @@ int64_t _cslt_sparse_mm_search(
bool transpose_result
)
{
TORCH_WARN_ONCE("torch._cslt_sparse_mm_search is deprecated and will be removed in a future PyTorch release. Please use torch._C._cusparselt.mm_search instead.");
int alg_id_int = 0;
int split_k = 1;
bool split_k_one_kernel= true;
auto result = _cslt_sparse_mm_impl(
compressed_A,
dense_B,
@ -454,11 +479,12 @@ int64_t _cslt_sparse_mm_search(
out_dtype_opt,
transpose_result,
alg_id_int,
split_k,
split_k_one_kernel,
true);
return (int64_t) std::get<0>(result);
return (int64_t) std::get<1>(result);
}
} // namespace at::native
#else // No cuSPARSELt support, throw error if these functions are called.
@ -476,7 +502,9 @@ at::Tensor _cslt_sparse_mm(
const std::optional<Tensor>& alpha_opt,
const std::optional<c10::ScalarType> out_dtype,
bool transpose_result,
int64_t alg_id)
int64_t alg_id,
int64_t split_k,
bool split_k_one_kernel)
{
TORCH_CHECK(false, "cuSPARSELt not supported on your machine.");
}

View File

@ -0,0 +1,58 @@
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
#include <ATen/cuda/CUDASparse.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Functions.h>
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/Half.h>
#include <cusparse.h>
#include <cstdint>
#if AT_CUSPARSELT_ENABLED()
#include <cusparseLt.h>
#endif
namespace at::native {
at::Tensor _cslt_compress(const Tensor& sparse_input);
TORCH_CUDA_CPP_API std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt_sparse_mm_impl(
const Tensor& compressed_A,
const Tensor& dense_B,
const std::optional<Tensor>& bias_opt,
const std::optional<Tensor>& alpha_opt,
const std::optional<c10::ScalarType> out_dtype_opt,
bool transpose_result,
int alg_id,
int split_k,
bool split_k_one_kernel,
bool search_alg_id
);
at::Tensor _cslt_sparse_mm(
const Tensor& compressed_A,
const Tensor& dense_B,
const std::optional<Tensor>& bias_opt,
const std::optional<Tensor>& alpha_opt,
const std::optional<c10::ScalarType> out_dtype_opt,
bool transpose_result,
int64_t alg_id,
int64_t split_k,
bool split_k_one_kernel
);
int64_t _cslt_sparse_mm_search(
const Tensor& compressed_A,
const Tensor& dense_B,
const std::optional<Tensor>& bias_opt,
const std::optional<Tensor>& alpha_opt,
const std::optional<c10::ScalarType> out_dtype_opt,
bool transpose_result
);
} // namespace at::native

View File

@ -1,253 +0,0 @@
import argparse
import random
import pandas as pd
from tqdm import tqdm
import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
torch.set_printoptions(
precision=2,
threshold=None,
edgeitems=16,
linewidth=480,
profile=None,
sci_mode=False,
)
# helper model definition for pruner
class Model(nn.Module):
def __init__(self, m, k, dtype=None):
super().__init__()
# transposed so reversed
self.linear = nn.Linear(k, m)
def forward(self, x):
return self.linear(x)
def rand_sparse_semi_structured_mask(
r, c, dtype=torch.float16, device="cuda", choice=None
):
"""
This function returns a 1:2 sparse matrix of size (r, c).
Note that this means this matrix will also be 2:4 and 4:8 sparse as well.
"""
choices = [[0, 1], [1, 0]]
mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]
return (
torch.tensor(mask_entries, dtype=dtype, device=device)
.reshape(r, c)
.contiguous()
)
def test_linear(m, k, n, dtype, contiguous, backend):
SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask
input_tensor = torch.zeros(n, k).to(dtype).cuda()
model = Model(m, k).to(dtype).cuda().eval()
dense_measurement = benchmark.Timer(
stmt="model(input_tensor)",
globals=locals(),
).blocked_autorange()
dense_output = model(input_tensor)
print(dense_output.shape)
# sparsify weights
model.linear.weight = nn.Parameter(
to_sparse_semi_structured(
sparse_weight,
)
)
sparse_output = model(input_tensor)
print(sparse_output.shape)
sparse_measurement = benchmark.Timer(
stmt="model(input_tensor)",
globals=locals(),
).blocked_autorange()
correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)
return {
"test_function": "linear",
"m": m,
"k": k,
"n": n,
"dtype": str(dtype),
"backend": backend,
"sparse_latency (ms)": sparse_measurement.median * 1000,
"dense_latency (ms)": dense_measurement.median * 1000,
"speedup (d/s)": dense_measurement.median / sparse_measurement.median,
"correct": correct,
"contiguous": sparse_output.is_contiguous(),
}
def test_tensor(m, k, n, dtype, contiguous, backend):
A = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
B = torch.zeros(k, n).to(dtype).cuda()
bias = torch.rand(n).to(dtype).cuda()
sA = to_sparse_semi_structured(A)
# torch.mm calculation
if dtype is not torch.int8:
dense_output = torch.mm(A, B)
dense_measurement = benchmark.Timer(
stmt="torch.mm(A, B)",
globals=locals(),
).blocked_autorange()
else:
print("int8 baseline not supported")
dense_output = torch.mm(sA, B)
dense_measurement = benchmark.Timer(
stmt="torch.mm(sA, B)",
globals=locals(),
).blocked_autorange()
sparse_output = torch.mm(sA, B)
sparse_measurement = benchmark.Timer(
stmt="torch.mm(sA, B)",
globals=locals(),
).blocked_autorange()
correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)
return {
"test_function": "tensor",
"m": m,
"k": k,
"n": n,
"dtype": str(dtype),
"backend": backend,
"sparse_latency (ms)": sparse_measurement.median * 1000,
"dense_latency (ms)": dense_measurement.median * 1000,
"speedup (d/s)": dense_measurement.median / sparse_measurement.median,
"correct": correct,
"contiguous": sparse_output.is_contiguous(),
}
if __name__ == "__main__":
dtype_lookup = {
"int8": torch.int8,
"fp16": torch.float16,
"bf16": torch.bfloat16,
"fp32": torch.float32,
}
parser = argparse.ArgumentParser(description="Semi-Structured Sparsity Benchmarks")
parser.add_argument(
"--mode",
type=str,
choices=[
"nvidia-bert",
"nvidia-fixed-k",
"nvidia-fixed-mn",
],
)
parser.add_argument(
"--dtype",
type=str,
choices=dtype_lookup.keys(),
default="fp16",
)
parser.add_argument(
"--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt"
)
parser.add_argument("-contiguous", action="store_true")
parser.add_argument("-e2e", action="store_true")
parser.add_argument("-save", action="store_true")
args = parser.parse_args()
if args.e2e:
eval_fn = test_linear
else:
eval_fn = test_tensor
print(f"Started benchmark: {args.mode} | dtype: {args.dtype}")
dtype = dtype_lookup[args.dtype]
if args.mode == "nvidia-bert":
bert_shapes = [
(3072, 1024, 16384),
(4096, 1024, 16384),
(1024, 1024, 16384),
(1024, 4096, 16384),
]
results = (
eval_fn(m, k, n, dtype, args.contiguous, args.backend)
for (m, k, n) in tqdm(bert_shapes)
)
elif args.mode == "nvidia-fixed-k":
mn_vals = [
3072,
4096,
5120,
6144,
7168,
8192,
9216,
10240,
11264,
12288,
13312,
14336,
15360,
16384,
17408,
18432,
19456,
20480,
]
results = (
eval_fn(mn, 10240, mn, dtype, args.contiguous, args.backend)
for mn in tqdm(mn_vals)
)
elif args.mode == "nvidia-fixed-mn":
k_vals = [
2560,
3840,
5120,
6400,
7680,
8960,
10240,
11520,
12800,
14080,
15360,
16640,
17920,
19200,
20480,
]
results = (
eval_fn(10240, k, 10240, dtype, args.contiguous, args.backend)
for k in tqdm(k_vals)
)
df = pd.DataFrame.from_records(results)
if args.save:
save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv"
df.to_csv(save_file)
print(f"Finished benchmark: {args.mode} saved results to {save_file}")
print(df)

View File

@ -244,18 +244,17 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
@unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
def test_sp24_compile(self) -> None:
x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
e = torch.eye(x.shape[0], x.shape[0], device="cuda", dtype=torch.float16)
def fn(x, e):
def fn(x):
y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x)
y = y.t()
return x @ y
# Eager
output = fn(x, e)
output = fn(x)
output.backward(output)
# Torch compile
output = torch.compile(fn)(x, e)
output = torch.compile(fn)(x)
output.backward(output)
class TestSparseSemiStructured(TestCase):
@ -1133,6 +1132,21 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
def test_cslt_sparse_mm_alpha_compile_autotune(self, device):
A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(torch.int8).cuda()
B = torch.ones((128, 256), device=device).to(torch.int8).t()
alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda()
A_compressed = torch._cslt_compress(A)
compiled_sparse_mm = torch.compile(torch._cslt_sparse_mm, mode="max-autotune")
sparse_result = compiled_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=torch.int32)
alpha_scaled = torch.stack([alpha] * 128).t().cpu().float()
dense_result = alpha_scaled * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu())
dense_result = dense_result.to(torch.int32)
torch.testing.assert_close(sparse_result.cpu(), dense_result, rtol=1e-3, atol=1e-3)
@parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32])
def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
@ -1156,8 +1170,9 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
B = torch.ones((128, 128), device=device).to(dtype)
A_compressed = torch._cslt_compress(A)
alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)
alg_id, split_k, split_k_one_kernel, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False)
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(),
alg_id=alg_id, split_k=split_k, split_k_one_kernel=split_k_one_kernel)
dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
dense_result = dense_result.to(dtype)
@ -1174,6 +1189,16 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
assert alg_id in range(torch.backends.cusparselt.get_max_alg_id())
@inference_dtypes
def test_csrc_cslt_sparse_mm_search(self, device, dtype):
A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
A_compressed = torch._cslt_compress(A)
B = torch.ones((128, 128), device=device).to(dtype)
A_compressed = torch._cslt_compress(A)
alg_id, _, _, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False)
assert alg_id in range(torch.backends.cusparselt.get_max_alg_id())
def test_cusparselt_backend(self):
version = _get_torch_cuda_version()
assert torch.backends.cusparselt.is_available()

View File

@ -520,30 +520,42 @@ def meta__cslt_sparse_mm(
alpha: Optional[Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
transpose_result: bool = False,
alg_id: int = 0,
split_k: int = 1,
split_k_one_kernel: bool = False,
):
assert dense_B.dtype in {
torch.float32,
torch.float16,
torch.bfloat16,
torch.int8,
}, "_cslt_sparse_mm only supports fp16, bf16, and int8"
torch.float8_e4m3fn,
}, "_cslt_sparse_mm only supports fp16, bf16, int8, and fp8e4m3"
assert compressed_A.dtype == dense_B.dtype, "inputs must have the same dtype"
assert len(dense_B.shape) == 2, "_cslt_sparse_mm only supports 2d inputs"
is_int8_input_type = compressed_A.dtype == torch.int8
compression_factor = 10 if is_int8_input_type else 9
is_8bit_input_type = compressed_A.dtype in [torch.int8, torch.float8_e4m3fn]
compression_factor = 10 if is_8bit_input_type else 9
k = dense_B.size(0)
n = dense_B.size(1)
m = (compressed_A.numel() * 16) // (compression_factor * k)
if bias is not None:
assert m == bias.size(0)
if is_8bit_input_type:
assert not dense_B.is_contiguous()
if out_dtype is not None:
assert is_int8_input_type and out_dtype in {
torch.float16,
torch.bfloat16,
torch.int32,
}, "out_dtype is only supported for i8i8->fp16, bf16, or i32 matmul"
assert (
is_8bit_input_type
and out_dtype
in {
torch.float16,
torch.bfloat16,
torch.int32,
torch.float8_e4m3fn,
}
), "out_dtype is not supported for {compressed_A.dtype} x {dense_B.dtype} -> {out_dtype} matmul!"
output_shape = (n, m) if transpose_result else (m, n)
result = dense_B.new_empty(output_shape, dtype=out_dtype)
return result

View File

@ -1,7 +1,7 @@
#include <torch/csrc/utils/pybind.h>
#ifdef USE_CUSPARSELT
#include <cusparseLt.h>
#include <ATen/native/sparse/cuda/cuSPARSELtOps.h>
namespace {
@ -9,6 +9,34 @@ size_t getVersionInt() {
return CUSPARSELT_VERSION;
}
std::tuple<int64_t, int64_t, bool, int64_t> mmSearch(
const at::Tensor& compressed_A,
const at::Tensor& dense_B,
const std::optional<at::Tensor>& bias_opt,
const std::optional<at::Tensor>& alpha_opt,
const std::optional<c10::ScalarType> out_dtype_opt,
bool transpose_result) {
int alg_id_int = 0;
int split_k = 1;
bool split_k_one_kernel = true;
auto result = at::native::_cslt_sparse_mm_impl(
compressed_A,
dense_B,
bias_opt,
alpha_opt,
out_dtype_opt,
transpose_result,
alg_id_int,
split_k,
split_k_one_kernel,
true);
return {
(int64_t)std::get<1>(result),
(int64_t)std::get<2>(result),
(bool)std::get<3>(result),
(int64_t)std::get<4>(result)};
}
} // namespace
namespace torch::cuda::shared {
@ -17,6 +45,7 @@ void initCusparseltBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
auto cusparselt = m.def_submodule("_cusparselt", "libcusparselt.so bindings");
cusparselt.def("getVersionInt", getVersionInt);
cusparselt.def("mm_search", mmSearch);
}
} // namespace torch::cuda::shared

View File

@ -103,6 +103,8 @@ def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor:
packed_t=self.packed_t,
meta_t=self.meta_t,
compressed_swizzled_bitmask=self.compressed_swizzled_bitmask,
fuse_transpose_cusparselt=self.fuse_transpose_cusparselt,
alg_id_cusparselt=self.alg_id_cusparselt,
requires_grad=False,
)