mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[sparse] add extra options to _cslt_spare_mm (#137427)
Summary: Splitting this PR into two, one for the cuSPARSELt improvements, and one for the inductor lowering. This PR adds in the additional cuSPARSELt bindings into pytorch. * `torch._cslt_sparse_mm_search` will be deprecated in a future PR, so a warning has been added * Added a header file for cuSPARSELtOps.cpp * max_id is now available in `torch.backends.cusparselt` via `torch.backends.cusparselt.get_max_alg_id()` * fixed meta registrations for float8 Test Plan: python test/test_sparse_semi_structured.py Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/137427 Approved by: https://github.com/cpuhrsch, https://github.com/eqy
This commit is contained in:
committed by
PyTorch MergeBot
parent
9a72939042
commit
45b30a5aec
@ -3362,7 +3362,7 @@
|
||||
dispatch:
|
||||
CUDA: _cslt_compress
|
||||
|
||||
- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0) -> Tensor
|
||||
- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0, int split_k=1, bool split_k_one_kernel=True) -> Tensor
|
||||
dispatch:
|
||||
CUDA: _cslt_sparse_mm
|
||||
|
||||
|
@ -1,20 +1,7 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/CUDADataType.h>
|
||||
#include <ATen/cuda/CUDASparse.h>
|
||||
#include <ATen/cuda/CUDAConfig.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/Functions.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <cusparse.h>
|
||||
#include <cstdint>
|
||||
#include <ATen/native/sparse/cuda/cuSPARSELtOps.h>
|
||||
|
||||
#if AT_CUSPARSELT_ENABLED()
|
||||
|
||||
#include <cusparseLt.h>
|
||||
|
||||
namespace at::native {
|
||||
|
||||
// Ideally we would use the same DeviceThreadHandlePool mechanism as used in aten/src/ATen/cuda/CuSparseHandlePool.cpp
|
||||
@ -56,6 +43,7 @@ at::Tensor _cslt_compress(const Tensor& sparse_input)
|
||||
#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
|
||||
case at::ScalarType::Float8_e4m3fn:
|
||||
type = CUDA_R_8F_E4M3;
|
||||
compression_factor = 10;
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
@ -103,7 +91,7 @@ at::Tensor _cslt_compress(const Tensor& sparse_input)
|
||||
return compressed_tensor;
|
||||
}
|
||||
|
||||
std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
|
||||
std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt_sparse_mm_impl(
|
||||
const Tensor& compressed_A,
|
||||
const Tensor& dense_B,
|
||||
const std::optional<Tensor>& bias_opt,
|
||||
@ -111,6 +99,8 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
|
||||
const std::optional<c10::ScalarType> out_dtype_opt,
|
||||
bool transpose_result,
|
||||
int alg_id,
|
||||
int split_k,
|
||||
bool split_k_one_kernel,
|
||||
bool search_alg_id
|
||||
)
|
||||
{
|
||||
@ -169,6 +159,7 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
|
||||
output_type = CUDA_R_8F_E4M3;
|
||||
C_type = CUDA_R_16F;
|
||||
compute_type = CUSPARSE_COMPUTE_32F;
|
||||
compression_factor = 10;
|
||||
break;
|
||||
#endif
|
||||
// cuSPARSELt <= v0.5.2 uses CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUTE_16F
|
||||
@ -335,16 +326,29 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
|
||||
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSelectionInit(
|
||||
&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT));
|
||||
|
||||
// set alg_id
|
||||
// set matmul search params
|
||||
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
|
||||
&handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg_id, sizeof(alg_id)));
|
||||
|
||||
cusparseLtSplitKMode_t splitKMode;
|
||||
int max_alg_id;
|
||||
if (split_k != 1) {
|
||||
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
|
||||
&handle, &alg_sel, CUSPARSELT_MATMUL_SPLIT_K, &split_k, sizeof(split_k)));
|
||||
|
||||
splitKMode = split_k_one_kernel ? CUSPARSELT_SPLIT_K_MODE_ONE_KERNEL : CUSPARSELT_SPLIT_K_MODE_TWO_KERNELS;
|
||||
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgSetAttribute(
|
||||
&handle, &alg_sel, CUSPARSELT_MATMUL_SPLIT_K_MODE, &splitKMode, sizeof(splitKMode)));
|
||||
}
|
||||
|
||||
// set tensor_alpha_mode and alpha pointer for matmul
|
||||
const auto alpha_tensor = alpha_opt.has_value() ? *alpha_opt: Tensor{};
|
||||
auto alpha_ptr = α
|
||||
if (alpha_opt.has_value()) {
|
||||
if (alpha_tensor.numel() == 1) {
|
||||
alpha = alpha_tensor.item<float>();
|
||||
// * static_cast<float*>(alpha_tensor.data_ptr());
|
||||
// std::cout<<alpha<<std::endl;
|
||||
}
|
||||
else {
|
||||
tensor_alpha_mode = 1;
|
||||
@ -381,9 +385,23 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
|
||||
&stream,
|
||||
1));
|
||||
|
||||
// get alg_id used
|
||||
// get matmul params used
|
||||
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgGetAttribute(
|
||||
&handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg_id, sizeof(alg_id)));
|
||||
|
||||
TORCH_CUDASPARSE_CHECK( cusparseLtMatmulAlgGetAttribute(&handle, &alg_sel,
|
||||
CUSPARSELT_MATMUL_SPLIT_K,
|
||||
&split_k, sizeof(split_k)));
|
||||
|
||||
TORCH_CUDASPARSE_CHECK( cusparseLtMatmulAlgGetAttribute(&handle, &alg_sel,
|
||||
CUSPARSELT_MATMUL_SPLIT_K_MODE,
|
||||
&splitKMode, sizeof(splitKMode)));
|
||||
|
||||
TORCH_CUDASPARSE_CHECK( cusparseLtMatmulAlgGetAttribute(&handle, &alg_sel,
|
||||
CUSPARSELT_MATMUL_ALG_CONFIG_MAX_ID,
|
||||
&max_alg_id, sizeof(max_alg_id)));
|
||||
|
||||
|
||||
}
|
||||
else {
|
||||
// do normal matmul
|
||||
@ -411,7 +429,7 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
|
||||
// destroy plan
|
||||
TORCH_CUDASPARSE_CHECK(cusparseLtMatmulPlanDestroy(&plan));
|
||||
|
||||
return {alg_id, res};
|
||||
return {res, alg_id, split_k, splitKMode == CUSPARSELT_SPLIT_K_MODE_ONE_KERNEL, max_alg_id};
|
||||
}
|
||||
|
||||
at::Tensor _cslt_sparse_mm(
|
||||
@ -421,7 +439,9 @@ at::Tensor _cslt_sparse_mm(
|
||||
const std::optional<Tensor>& alpha_opt,
|
||||
const std::optional<c10::ScalarType> out_dtype_opt,
|
||||
bool transpose_result,
|
||||
int64_t alg_id
|
||||
int64_t alg_id,
|
||||
int64_t split_k,
|
||||
bool split_k_one_kernel
|
||||
)
|
||||
{
|
||||
auto result = _cslt_sparse_mm_impl(
|
||||
@ -432,8 +452,10 @@ at::Tensor _cslt_sparse_mm(
|
||||
out_dtype_opt,
|
||||
transpose_result,
|
||||
(int) alg_id,
|
||||
(int) split_k,
|
||||
split_k_one_kernel,
|
||||
false);
|
||||
return std::get<1>(result);
|
||||
return std::get<0>(result);
|
||||
}
|
||||
|
||||
int64_t _cslt_sparse_mm_search(
|
||||
@ -445,7 +467,10 @@ int64_t _cslt_sparse_mm_search(
|
||||
bool transpose_result
|
||||
)
|
||||
{
|
||||
TORCH_WARN_ONCE("torch._cslt_sparse_mm_search is deprecated and will be removed in a future PyTorch release. Please use torch._C._cusparselt.mm_search instead.");
|
||||
int alg_id_int = 0;
|
||||
int split_k = 1;
|
||||
bool split_k_one_kernel= true;
|
||||
auto result = _cslt_sparse_mm_impl(
|
||||
compressed_A,
|
||||
dense_B,
|
||||
@ -454,11 +479,12 @@ int64_t _cslt_sparse_mm_search(
|
||||
out_dtype_opt,
|
||||
transpose_result,
|
||||
alg_id_int,
|
||||
split_k,
|
||||
split_k_one_kernel,
|
||||
true);
|
||||
return (int64_t) std::get<0>(result);
|
||||
return (int64_t) std::get<1>(result);
|
||||
}
|
||||
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
#else // No cuSPARSELt support, throw error if these functions are called.
|
||||
@ -476,7 +502,9 @@ at::Tensor _cslt_sparse_mm(
|
||||
const std::optional<Tensor>& alpha_opt,
|
||||
const std::optional<c10::ScalarType> out_dtype,
|
||||
bool transpose_result,
|
||||
int64_t alg_id)
|
||||
int64_t alg_id,
|
||||
int64_t split_k,
|
||||
bool split_k_one_kernel)
|
||||
{
|
||||
TORCH_CHECK(false, "cuSPARSELt not supported on your machine.");
|
||||
}
|
||||
|
58
aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.h
Normal file
58
aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.h
Normal file
@ -0,0 +1,58 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/CUDADataType.h>
|
||||
#include <ATen/cuda/CUDASparse.h>
|
||||
#include <ATen/cuda/CUDAConfig.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/Functions.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <cusparse.h>
|
||||
#include <cstdint>
|
||||
|
||||
#if AT_CUSPARSELT_ENABLED()
|
||||
#include <cusparseLt.h>
|
||||
#endif
|
||||
|
||||
namespace at::native {
|
||||
|
||||
at::Tensor _cslt_compress(const Tensor& sparse_input);
|
||||
|
||||
TORCH_CUDA_CPP_API std::tuple<at::Tensor, int64_t, int64_t, bool, int64_t> _cslt_sparse_mm_impl(
|
||||
const Tensor& compressed_A,
|
||||
const Tensor& dense_B,
|
||||
const std::optional<Tensor>& bias_opt,
|
||||
const std::optional<Tensor>& alpha_opt,
|
||||
const std::optional<c10::ScalarType> out_dtype_opt,
|
||||
bool transpose_result,
|
||||
int alg_id,
|
||||
int split_k,
|
||||
bool split_k_one_kernel,
|
||||
bool search_alg_id
|
||||
);
|
||||
|
||||
at::Tensor _cslt_sparse_mm(
|
||||
const Tensor& compressed_A,
|
||||
const Tensor& dense_B,
|
||||
const std::optional<Tensor>& bias_opt,
|
||||
const std::optional<Tensor>& alpha_opt,
|
||||
const std::optional<c10::ScalarType> out_dtype_opt,
|
||||
bool transpose_result,
|
||||
int64_t alg_id,
|
||||
int64_t split_k,
|
||||
bool split_k_one_kernel
|
||||
);
|
||||
|
||||
int64_t _cslt_sparse_mm_search(
|
||||
const Tensor& compressed_A,
|
||||
const Tensor& dense_B,
|
||||
const std::optional<Tensor>& bias_opt,
|
||||
const std::optional<Tensor>& alpha_opt,
|
||||
const std::optional<c10::ScalarType> out_dtype_opt,
|
||||
bool transpose_result
|
||||
);
|
||||
|
||||
} // namespace at::native
|
@ -1,253 +0,0 @@
|
||||
import argparse
|
||||
import random
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark
|
||||
from torch import nn
|
||||
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
|
||||
|
||||
|
||||
torch.set_printoptions(
|
||||
precision=2,
|
||||
threshold=None,
|
||||
edgeitems=16,
|
||||
linewidth=480,
|
||||
profile=None,
|
||||
sci_mode=False,
|
||||
)
|
||||
|
||||
|
||||
# helper model definition for pruner
|
||||
class Model(nn.Module):
|
||||
def __init__(self, m, k, dtype=None):
|
||||
super().__init__()
|
||||
# transposed so reversed
|
||||
self.linear = nn.Linear(k, m)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
|
||||
def rand_sparse_semi_structured_mask(
|
||||
r, c, dtype=torch.float16, device="cuda", choice=None
|
||||
):
|
||||
"""
|
||||
This function returns a 1:2 sparse matrix of size (r, c).
|
||||
Note that this means this matrix will also be 2:4 and 4:8 sparse as well.
|
||||
"""
|
||||
|
||||
choices = [[0, 1], [1, 0]]
|
||||
mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]
|
||||
|
||||
return (
|
||||
torch.tensor(mask_entries, dtype=dtype, device=device)
|
||||
.reshape(r, c)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
|
||||
def test_linear(m, k, n, dtype, contiguous, backend):
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
|
||||
mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
|
||||
sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask
|
||||
input_tensor = torch.zeros(n, k).to(dtype).cuda()
|
||||
model = Model(m, k).to(dtype).cuda().eval()
|
||||
|
||||
dense_measurement = benchmark.Timer(
|
||||
stmt="model(input_tensor)",
|
||||
globals=locals(),
|
||||
).blocked_autorange()
|
||||
|
||||
dense_output = model(input_tensor)
|
||||
print(dense_output.shape)
|
||||
|
||||
# sparsify weights
|
||||
model.linear.weight = nn.Parameter(
|
||||
to_sparse_semi_structured(
|
||||
sparse_weight,
|
||||
)
|
||||
)
|
||||
|
||||
sparse_output = model(input_tensor)
|
||||
print(sparse_output.shape)
|
||||
|
||||
sparse_measurement = benchmark.Timer(
|
||||
stmt="model(input_tensor)",
|
||||
globals=locals(),
|
||||
).blocked_autorange()
|
||||
|
||||
correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)
|
||||
|
||||
return {
|
||||
"test_function": "linear",
|
||||
"m": m,
|
||||
"k": k,
|
||||
"n": n,
|
||||
"dtype": str(dtype),
|
||||
"backend": backend,
|
||||
"sparse_latency (ms)": sparse_measurement.median * 1000,
|
||||
"dense_latency (ms)": dense_measurement.median * 1000,
|
||||
"speedup (d/s)": dense_measurement.median / sparse_measurement.median,
|
||||
"correct": correct,
|
||||
"contiguous": sparse_output.is_contiguous(),
|
||||
}
|
||||
|
||||
|
||||
def test_tensor(m, k, n, dtype, contiguous, backend):
|
||||
A = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
|
||||
B = torch.zeros(k, n).to(dtype).cuda()
|
||||
bias = torch.rand(n).to(dtype).cuda()
|
||||
|
||||
sA = to_sparse_semi_structured(A)
|
||||
|
||||
# torch.mm calculation
|
||||
if dtype is not torch.int8:
|
||||
dense_output = torch.mm(A, B)
|
||||
|
||||
dense_measurement = benchmark.Timer(
|
||||
stmt="torch.mm(A, B)",
|
||||
globals=locals(),
|
||||
).blocked_autorange()
|
||||
|
||||
else:
|
||||
print("int8 baseline not supported")
|
||||
dense_output = torch.mm(sA, B)
|
||||
|
||||
dense_measurement = benchmark.Timer(
|
||||
stmt="torch.mm(sA, B)",
|
||||
globals=locals(),
|
||||
).blocked_autorange()
|
||||
|
||||
sparse_output = torch.mm(sA, B)
|
||||
sparse_measurement = benchmark.Timer(
|
||||
stmt="torch.mm(sA, B)",
|
||||
globals=locals(),
|
||||
).blocked_autorange()
|
||||
|
||||
correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)
|
||||
|
||||
return {
|
||||
"test_function": "tensor",
|
||||
"m": m,
|
||||
"k": k,
|
||||
"n": n,
|
||||
"dtype": str(dtype),
|
||||
"backend": backend,
|
||||
"sparse_latency (ms)": sparse_measurement.median * 1000,
|
||||
"dense_latency (ms)": dense_measurement.median * 1000,
|
||||
"speedup (d/s)": dense_measurement.median / sparse_measurement.median,
|
||||
"correct": correct,
|
||||
"contiguous": sparse_output.is_contiguous(),
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dtype_lookup = {
|
||||
"int8": torch.int8,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
"fp32": torch.float32,
|
||||
}
|
||||
|
||||
parser = argparse.ArgumentParser(description="Semi-Structured Sparsity Benchmarks")
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
choices=[
|
||||
"nvidia-bert",
|
||||
"nvidia-fixed-k",
|
||||
"nvidia-fixed-mn",
|
||||
],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
choices=dtype_lookup.keys(),
|
||||
default="fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt"
|
||||
)
|
||||
parser.add_argument("-contiguous", action="store_true")
|
||||
parser.add_argument("-e2e", action="store_true")
|
||||
parser.add_argument("-save", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.e2e:
|
||||
eval_fn = test_linear
|
||||
else:
|
||||
eval_fn = test_tensor
|
||||
|
||||
print(f"Started benchmark: {args.mode} | dtype: {args.dtype}")
|
||||
dtype = dtype_lookup[args.dtype]
|
||||
|
||||
if args.mode == "nvidia-bert":
|
||||
bert_shapes = [
|
||||
(3072, 1024, 16384),
|
||||
(4096, 1024, 16384),
|
||||
(1024, 1024, 16384),
|
||||
(1024, 4096, 16384),
|
||||
]
|
||||
results = (
|
||||
eval_fn(m, k, n, dtype, args.contiguous, args.backend)
|
||||
for (m, k, n) in tqdm(bert_shapes)
|
||||
)
|
||||
|
||||
elif args.mode == "nvidia-fixed-k":
|
||||
mn_vals = [
|
||||
3072,
|
||||
4096,
|
||||
5120,
|
||||
6144,
|
||||
7168,
|
||||
8192,
|
||||
9216,
|
||||
10240,
|
||||
11264,
|
||||
12288,
|
||||
13312,
|
||||
14336,
|
||||
15360,
|
||||
16384,
|
||||
17408,
|
||||
18432,
|
||||
19456,
|
||||
20480,
|
||||
]
|
||||
results = (
|
||||
eval_fn(mn, 10240, mn, dtype, args.contiguous, args.backend)
|
||||
for mn in tqdm(mn_vals)
|
||||
)
|
||||
|
||||
elif args.mode == "nvidia-fixed-mn":
|
||||
k_vals = [
|
||||
2560,
|
||||
3840,
|
||||
5120,
|
||||
6400,
|
||||
7680,
|
||||
8960,
|
||||
10240,
|
||||
11520,
|
||||
12800,
|
||||
14080,
|
||||
15360,
|
||||
16640,
|
||||
17920,
|
||||
19200,
|
||||
20480,
|
||||
]
|
||||
results = (
|
||||
eval_fn(10240, k, 10240, dtype, args.contiguous, args.backend)
|
||||
for k in tqdm(k_vals)
|
||||
)
|
||||
|
||||
df = pd.DataFrame.from_records(results)
|
||||
if args.save:
|
||||
save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv"
|
||||
df.to_csv(save_file)
|
||||
print(f"Finished benchmark: {args.mode} saved results to {save_file}")
|
||||
print(df)
|
@ -244,18 +244,17 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
|
||||
@unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
|
||||
def test_sp24_compile(self) -> None:
|
||||
x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
|
||||
e = torch.eye(x.shape[0], x.shape[0], device="cuda", dtype=torch.float16)
|
||||
|
||||
def fn(x, e):
|
||||
def fn(x):
|
||||
y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x)
|
||||
y = y.t()
|
||||
return x @ y
|
||||
|
||||
# Eager
|
||||
output = fn(x, e)
|
||||
output = fn(x)
|
||||
output.backward(output)
|
||||
# Torch compile
|
||||
output = torch.compile(fn)(x, e)
|
||||
output = torch.compile(fn)(x)
|
||||
output.backward(output)
|
||||
|
||||
class TestSparseSemiStructured(TestCase):
|
||||
@ -1133,6 +1132,21 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
|
||||
|
||||
torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
def test_cslt_sparse_mm_alpha_compile_autotune(self, device):
|
||||
A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(torch.int8).cuda()
|
||||
B = torch.ones((128, 256), device=device).to(torch.int8).t()
|
||||
alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda()
|
||||
|
||||
A_compressed = torch._cslt_compress(A)
|
||||
compiled_sparse_mm = torch.compile(torch._cslt_sparse_mm, mode="max-autotune")
|
||||
sparse_result = compiled_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=torch.int32)
|
||||
|
||||
alpha_scaled = torch.stack([alpha] * 128).t().cpu().float()
|
||||
dense_result = alpha_scaled * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu())
|
||||
dense_result = dense_result.to(torch.int32)
|
||||
|
||||
torch.testing.assert_close(sparse_result.cpu(), dense_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32])
|
||||
def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
|
||||
A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
|
||||
@ -1156,8 +1170,9 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
|
||||
B = torch.ones((128, 128), device=device).to(dtype)
|
||||
|
||||
A_compressed = torch._cslt_compress(A)
|
||||
alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
|
||||
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)
|
||||
alg_id, split_k, split_k_one_kernel, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False)
|
||||
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(),
|
||||
alg_id=alg_id, split_k=split_k, split_k_one_kernel=split_k_one_kernel)
|
||||
|
||||
dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
|
||||
dense_result = dense_result.to(dtype)
|
||||
@ -1174,6 +1189,16 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
|
||||
alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
|
||||
assert alg_id in range(torch.backends.cusparselt.get_max_alg_id())
|
||||
|
||||
@inference_dtypes
|
||||
def test_csrc_cslt_sparse_mm_search(self, device, dtype):
|
||||
A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
|
||||
A_compressed = torch._cslt_compress(A)
|
||||
B = torch.ones((128, 128), device=device).to(dtype)
|
||||
|
||||
A_compressed = torch._cslt_compress(A)
|
||||
alg_id, _, _, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False)
|
||||
assert alg_id in range(torch.backends.cusparselt.get_max_alg_id())
|
||||
|
||||
def test_cusparselt_backend(self):
|
||||
version = _get_torch_cuda_version()
|
||||
assert torch.backends.cusparselt.is_available()
|
||||
|
@ -520,30 +520,42 @@ def meta__cslt_sparse_mm(
|
||||
alpha: Optional[Tensor] = None,
|
||||
out_dtype: Optional[torch.dtype] = None,
|
||||
transpose_result: bool = False,
|
||||
alg_id: int = 0,
|
||||
split_k: int = 1,
|
||||
split_k_one_kernel: bool = False,
|
||||
):
|
||||
assert dense_B.dtype in {
|
||||
torch.float32,
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.int8,
|
||||
}, "_cslt_sparse_mm only supports fp16, bf16, and int8"
|
||||
torch.float8_e4m3fn,
|
||||
}, "_cslt_sparse_mm only supports fp16, bf16, int8, and fp8e4m3"
|
||||
assert compressed_A.dtype == dense_B.dtype, "inputs must have the same dtype"
|
||||
assert len(dense_B.shape) == 2, "_cslt_sparse_mm only supports 2d inputs"
|
||||
|
||||
is_int8_input_type = compressed_A.dtype == torch.int8
|
||||
compression_factor = 10 if is_int8_input_type else 9
|
||||
is_8bit_input_type = compressed_A.dtype in [torch.int8, torch.float8_e4m3fn]
|
||||
compression_factor = 10 if is_8bit_input_type else 9
|
||||
k = dense_B.size(0)
|
||||
n = dense_B.size(1)
|
||||
m = (compressed_A.numel() * 16) // (compression_factor * k)
|
||||
if bias is not None:
|
||||
assert m == bias.size(0)
|
||||
|
||||
if is_8bit_input_type:
|
||||
assert not dense_B.is_contiguous()
|
||||
|
||||
if out_dtype is not None:
|
||||
assert is_int8_input_type and out_dtype in {
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.int32,
|
||||
}, "out_dtype is only supported for i8i8->fp16, bf16, or i32 matmul"
|
||||
assert (
|
||||
is_8bit_input_type
|
||||
and out_dtype
|
||||
in {
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.int32,
|
||||
torch.float8_e4m3fn,
|
||||
}
|
||||
), "out_dtype is not supported for {compressed_A.dtype} x {dense_B.dtype} -> {out_dtype} matmul!"
|
||||
output_shape = (n, m) if transpose_result else (m, n)
|
||||
result = dense_B.new_empty(output_shape, dtype=out_dtype)
|
||||
return result
|
||||
|
@ -1,7 +1,7 @@
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
#ifdef USE_CUSPARSELT
|
||||
#include <cusparseLt.h>
|
||||
#include <ATen/native/sparse/cuda/cuSPARSELtOps.h>
|
||||
|
||||
namespace {
|
||||
|
||||
@ -9,6 +9,34 @@ size_t getVersionInt() {
|
||||
return CUSPARSELT_VERSION;
|
||||
}
|
||||
|
||||
std::tuple<int64_t, int64_t, bool, int64_t> mmSearch(
|
||||
const at::Tensor& compressed_A,
|
||||
const at::Tensor& dense_B,
|
||||
const std::optional<at::Tensor>& bias_opt,
|
||||
const std::optional<at::Tensor>& alpha_opt,
|
||||
const std::optional<c10::ScalarType> out_dtype_opt,
|
||||
bool transpose_result) {
|
||||
int alg_id_int = 0;
|
||||
int split_k = 1;
|
||||
bool split_k_one_kernel = true;
|
||||
auto result = at::native::_cslt_sparse_mm_impl(
|
||||
compressed_A,
|
||||
dense_B,
|
||||
bias_opt,
|
||||
alpha_opt,
|
||||
out_dtype_opt,
|
||||
transpose_result,
|
||||
alg_id_int,
|
||||
split_k,
|
||||
split_k_one_kernel,
|
||||
true);
|
||||
return {
|
||||
(int64_t)std::get<1>(result),
|
||||
(int64_t)std::get<2>(result),
|
||||
(bool)std::get<3>(result),
|
||||
(int64_t)std::get<4>(result)};
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace torch::cuda::shared {
|
||||
@ -17,6 +45,7 @@ void initCusparseltBindings(PyObject* module) {
|
||||
auto m = py::handle(module).cast<py::module>();
|
||||
auto cusparselt = m.def_submodule("_cusparselt", "libcusparselt.so bindings");
|
||||
cusparselt.def("getVersionInt", getVersionInt);
|
||||
cusparselt.def("mm_search", mmSearch);
|
||||
}
|
||||
|
||||
} // namespace torch::cuda::shared
|
||||
|
@ -103,6 +103,8 @@ def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor:
|
||||
packed_t=self.packed_t,
|
||||
meta_t=self.meta_t,
|
||||
compressed_swizzled_bitmask=self.compressed_swizzled_bitmask,
|
||||
fuse_transpose_cusparselt=self.fuse_transpose_cusparselt,
|
||||
alg_id_cusparselt=self.alg_id_cusparselt,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user