[sparse][semi-structured] Add float8 dtype support to 24 sparsity (#136397)
Summary: This PR adds `torch.float8_e4m3fn` support to cuSPARSELt and `to_sparse_semi_structured`. This lets users run fp8 + 2:4 sparse matmuls on Hopper GPUs with cuSPARSELt >= 0.6.2, via the `torch._scaled_mm` API:

```
A = rand_sparse_semi_structured_mask(256, 128, dtype=torch.float16)
B = torch.rand(dense_input_shape, device=device).to(torch.float16).t()

A_fp8, A_scale = to_float8(A)
B_fp8, B_scale = to_float8(B)

dense_result = torch._scaled_mm(
    A_fp8, B_fp8, scale_a=A_scale, scale_b=B_scale, out_dtype=out_dtype
)
A_fp8_sparse = to_sparse_semi_structured(A_fp8)
sparse_result = torch._scaled_mm(
    A_fp8_sparse, B_fp8, scale_a=A_scale, scale_b=B_scale, out_dtype=out_dtype
)
```

Note that to keep this consistent with normal torch behavior, calling `torch.mm(A_fp8_sparse, B_fp8)` will raise a NotImplementedError.

I also turned on cuSPARSELt by default and added CUSPARSELT_MAX_ID to the backend to make the tests a bit cleaner.

Test Plan:
```
python test/test_sparse_semi_structured.py -k scaled_mm
python test/test_sparse_semi_structured.py -k fp8
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136397
Approved by: https://github.com/drisspg
Committed by: PyTorch MergeBot
Parent: a28b40fa74
Commit: bc21689136
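For reference, the `to_float8` quantization helper used in the snippet above is the one this PR adds to the test suite (see the test diff below); a minimal standalone sketch of it, where `x` is any fp16/fp32 input tensor:

```python
import torch

def to_float8(x, dtype=torch.float8_e4m3fn):
    # Scale so the tensor's absmax maps onto the fp8 format's max (saturated cast).
    finfo = torch.finfo(dtype)
    scale = finfo.max / x.abs().max().clamp(min=1e-12)
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    # Return the fp8 data plus the inverse scale, the form torch._scaled_mm expects.
    return x_scl_sat.to(dtype), scale.float().reciprocal()
```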
Diff: cuSPARSELt C++ ops (`_cslt_compress`, `_cslt_sparse_mm_impl`)

@@ -53,6 +53,11 @@ at::Tensor _cslt_compress(const Tensor& sparse_input)
     case at::ScalarType::Float:
       type = CUDA_R_32F;
       break;
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
+    case at::ScalarType::Float8_e4m3fn:
+      type = CUDA_R_8F_E4M3;
+      break;
+#endif
     default:
       TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt compressed matrix");
       break;

@@ -123,15 +128,16 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
   float beta = 0.0;
   cudaDataType input_type;
   cudaDataType output_type;
+  cudaDataType C_type;
   cusparseComputeType compute_type;
   auto compression_factor = 9;

   switch(compressed_A.scalar_type())
   {
     case at::ScalarType::Char:
       input_type = CUDA_R_8I;
       output_type = CUDA_R_8I;
+      C_type = CUDA_R_8I;
       compute_type = CUSPARSE_COMPUTE_32I;
       compression_factor = 10;
       break;
@@ -141,61 +147,111 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
     case at::ScalarType::Half:
       input_type = CUDA_R_16F;
       output_type = CUDA_R_16F;
+      C_type = CUDA_R_16F;
       compute_type = CUSPARSE_COMPUTE_32F;
       break;
     case at::ScalarType::BFloat16:
       input_type = CUDA_R_16BF;
       output_type = CUDA_R_16BF;
+      C_type = CUDA_R_16BF;
       compute_type = CUSPARSE_COMPUTE_32F;
       break;
     case at::ScalarType::Float:
       input_type = CUDA_R_32F;
       output_type = CUDA_R_32F;
+      C_type = CUDA_R_32F;
       compute_type = CUSPARSE_COMPUTE_32F;
       break;
+    // if cuSPARSELt >= 6.2.3, we can add Float8 support
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
+    case at::ScalarType::Float8_e4m3fn:
+      input_type = CUDA_R_8F_E4M3;
+      output_type = CUDA_R_8F_E4M3;
+      C_type = CUDA_R_16F;
+      compute_type = CUSPARSE_COMPUTE_32F;
+      break;
+#endif
     // cuSPARSELt <= v0.5.2 uses CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUTE_16F
 #else
     case at::ScalarType::Half:
       input_type = CUDA_R_16F;
       output_type = CUDA_R_16F;
+      C_type = CUDA_R_16F;
       compute_type = CUSPARSE_COMPUTE_16F;
       break;
     case at::ScalarType::BFloat16:
       input_type = CUDA_R_16BF;
       output_type = CUDA_R_16BF;
+      C_type = CUDA_R_16BF;
       compute_type = CUSPARSE_COMPUTE_16F;
       break;
     case at::ScalarType::Float:
       input_type = CUDA_R_32F;
       output_type = CUDA_R_32F;
+      C_type = CUDA_R_32F;
       compute_type = CUSPARSE_COMPUTE_TF32;
       break;
 #endif

     default:
       TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt compressed matrix multiplication.");
       break;
   }
   ScalarType out_dtype = dense_B.scalar_type();
-  // special check for mixed dtype int8 int8 -> {fp16, bf16, int32} support
+  // special check for mixed dtype support for 8 bit dtypes
+  // cslt 0.5.2+: int8 int8 -> {fp16, bf16, int32} support
   if (out_dtype_opt.has_value()) {
     out_dtype = out_dtype_opt.value();
-    TORCH_CHECK(input_type == CUDA_R_8I, "out_dtype support only available for int8 inputs");
-    switch (out_dtype)
-    {
-      case at::ScalarType::Half:
-        output_type = CUDA_R_16F;
-        break;
-      case at::ScalarType::BFloat16:
-        output_type = CUDA_R_16BF;
-        break;
-      case at::ScalarType::Int:
-        output_type = CUDA_R_32I;
-        break;
-      default:
-        TORCH_CHECK(false, "Unsupported out_dtype passed, must be one of {fp16, bf16, int32}");
-        break;
+    if (input_type == CUDA_R_8I)
+    {
+      switch (out_dtype)
+      {
+        case at::ScalarType::Half:
+          C_type = CUDA_R_16F;
+          output_type = CUDA_R_16F;
+          break;
+        case at::ScalarType::BFloat16:
+          C_type = CUDA_R_16BF;
+          output_type = CUDA_R_16BF;
+          break;
+        case at::ScalarType::Int:
+          C_type = CUDA_R_32I;
+          output_type = CUDA_R_32I;
+          break;
+        default:
+          TORCH_CHECK(false, "Unsupported out_dtype passed, must be one of {fp16, bf16, int32} for int8 inputs");
+          break;
+      }
+    }
+    // cslt 0.6.2+: fp8 fp8 -> {fp8, fp16, bf16, fp32} support
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
+    else if (input_type == CUDA_R_8F_E4M3)
+    {
+      switch (out_dtype)
+      {
+        case at::ScalarType::Float8_e4m3fn:
+          output_type = CUDA_R_8F_E4M3;
+          C_type = CUDA_R_16F;
+          break;
+        case at::ScalarType::Half:
+          output_type = CUDA_R_16F;
+          C_type = CUDA_R_16F;
+          break;
+        case at::ScalarType::BFloat16:
+          output_type = CUDA_R_16BF;
+          C_type = CUDA_R_16BF;
+          break;
+        case at::ScalarType::Float:
+          output_type = CUDA_R_32F;
+          C_type = CUDA_R_32F;
+          break;
+        default:
+          TORCH_CHECK(false, "Unsupported out_dtype passed, must be one of {fp16, bf16, float32} for fp8 inputs");
+          break;
+      }
+    }
+#endif
+    else {
+      TORCH_CHECK(false, "out_dtype support only available for int8/fp8 inputs");
     }
   }

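To summarize what the switch above selects once cuSPARSELt >= 0.6.2 is available, here is a rough Python-side table of the CUDA/compute types chosen per sparse-operand dtype. The dict is only illustrative; the C++ switch is authoritative:

```python
import torch

# (input_type, output_type, C_type, compute_type) per sparse operand dtype,
# mirroring the switch in _cslt_sparse_mm_impl for cuSPARSELt >= 0.6.2.
CSLT_TYPE_MAP = {
    torch.int8:          ("CUDA_R_8I",      "CUDA_R_8I",      "CUDA_R_8I",   "CUSPARSE_COMPUTE_32I"),
    torch.float16:       ("CUDA_R_16F",     "CUDA_R_16F",     "CUDA_R_16F",  "CUSPARSE_COMPUTE_32F"),
    torch.bfloat16:      ("CUDA_R_16BF",    "CUDA_R_16BF",    "CUDA_R_16BF", "CUSPARSE_COMPUTE_32F"),
    torch.float32:       ("CUDA_R_32F",     "CUDA_R_32F",     "CUDA_R_32F",  "CUSPARSE_COMPUTE_32F"),
    # New in this PR: fp8 inputs compute in fp32 and use an fp16 C matrix.
    torch.float8_e4m3fn: ("CUDA_R_8F_E4M3", "CUDA_R_8F_E4M3", "CUDA_R_16F",  "CUSPARSE_COMPUTE_32F"),
}
```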
@@ -244,6 +300,18 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
       output_type,
       (transpose_result) ? CUSPARSE_ORDER_COL : CUSPARSE_ORDER_ROW));

+  // For float8, need fp16 C_descriptor, can't use FP8 for this matrix
+  cusparseLtMatDescriptor_t C_descriptor;
+  TORCH_CUDASPARSE_CHECK(cusparseLtDenseDescriptorInit(
+      &handle,
+      &C_descriptor,
+      m,
+      n,
+      (transpose_result) ? m: n,
+      16,
+      C_type,
+      (transpose_result) ? CUSPARSE_ORDER_COL : CUSPARSE_ORDER_ROW));
+
   // initialize matmul
   TORCH_CUDASPARSE_CHECK(cusparseLtMatmulDescriptorInit(
       &handle,

@@ -252,7 +320,7 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
       (dense_B.is_contiguous()) ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE,
       &sparse_input_descriptor,
       &dense_input_descriptor,
-      &res_descriptor,
+      &C_descriptor,
       &res_descriptor,
       compute_type));

@@ -273,11 +341,17 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(

   // set tensor_alpha_mode and alpha pointer for matmul
   const auto alpha_tensor = alpha_opt.has_value() ? *alpha_opt: Tensor{};
-  const auto alpha_ptr = alpha_opt.has_value() ? alpha_tensor.data_ptr(): &alpha;
+  auto alpha_ptr = &alpha;
   if (alpha_opt.has_value()) {
-    tensor_alpha_mode = 1;
-    TORCH_CUDASPARSE_CHECK(cusparseLtMatmulDescSetAttribute(
-        &handle, &matmul, CUSPARSELT_MATMUL_ALPHA_VECTOR_SCALING, &tensor_alpha_mode, sizeof(tensor_alpha_mode)));
+    if (alpha_tensor.numel() == 1) {
+      alpha = alpha_tensor.item<float>();
+    }
+    else {
+      tensor_alpha_mode = 1;
+      TORCH_CUDASPARSE_CHECK(cusparseLtMatmulDescSetAttribute(
+          &handle, &matmul, CUSPARSELT_MATMUL_ALPHA_VECTOR_SCALING, &tensor_alpha_mode, sizeof(tensor_alpha_mode)));
+      alpha_ptr = static_cast<float*>(alpha_tensor.data_ptr());
+    }
   }

   TORCH_CUDASPARSE_CHECK(
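The scalar-alpha path above is what lets per-tensor fp8 scales be folded into a single multiplier: scalar scales commute with the matmul. A small sanity-check sketch of that identity on plain fp32 tensors:

```python
import torch

# (a * A) @ (b * B) == (a * b) * (A @ B) for scalar per-tensor scales a, b,
# which is why a single alpha can carry scale_a * scale_b.
A = torch.randn(4, 8)
B = torch.randn(8, 2)
a, b = torch.tensor(0.5), torch.tensor(0.25)
torch.testing.assert_close((a * A) @ (b * B), (a * b) * (A @ B))
```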
Diff: test/test_sparse_semi_structured.py

@@ -21,7 +21,7 @@ from torch.sparse._semi_structured_conversions import (
 )

 from torch.testing import make_tensor
-from torch.testing._internal.common_cuda import _get_torch_cuda_version
+from torch.testing._internal.common_cuda import _get_torch_cuda_version, PLATFORM_SUPPORTS_FP8
 from torch.testing._internal.common_device_type import (
     dtypes,
     instantiate_device_type_tests,

@@ -1022,10 +1022,19 @@ class TestSparseSemiStructuredCUTLASS(TestCase):
         torch.testing.assert_close(dense, dense_val, rtol=0, atol=0)


-CUSPARSELT_NUM_ALG_IDS = 4
 CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32]

+def to_float8(x, dtype=torch.float8_e4m3fn):
+    finfo = torch.finfo(dtype)
+    # Calculate the scale as dtype max divided by absmax
+    scale = finfo.max / x.abs().max().clamp(min=1e-12)
+    # scale and clamp the tensor to bring it to
+    # the representative range of float8 data type
+    # (as default cast is unsaturated)
+    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    # Return both float8 data and the inverse scale (as float),
+    # as both required as inputs to torch._scaled_mm
+    return x_scl_sat.to(dtype), scale.float().reciprocal()
+
 class TestSparseSemiStructuredCUSPARSELT(TestCase):
     """
@@ -1034,10 +1043,68 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
     torch._cslt_sparse_mm
     """
     def setUp(self):
+        SparseSemiStructuredTensor._FORCE_CUTLASS = False
         if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
             self.skipTest('cuSPARSELt not enabled')

-    @parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
+    @parametrize("dense_input_shape", [(256, 128)])
+    def test_sparse_fp8fp8_mm(self, dense_input_shape, device):
+        if torch.backends.cusparselt.version() < 602:
+            self.skipTest("fp8 matmul requires cuSPARSELt v0.6.2+")
+
+        A = rand_sparse_semi_structured_mask(256, 128, dtype=torch.float16)
+        B = torch.rand(dense_input_shape, device=device).to(torch.float16).t()
+
+        A_fp8, A_scale = to_float8(A)
+        B_fp8, B_scale = to_float8(B)
+        A_fp8_sparse = to_sparse_semi_structured(A_fp8)
+
+        with self.assertRaisesRegex(
+            NotImplementedError,
+            r"`SparseSemiStructuredTensor.*_scaled_mm",
+        ):
+            dense_result = torch.mm(A_fp8_sparse, B_fp8)
+
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
+    def test_sparse_semi_structured_scaled_mm_fp8(self, device) -> None:
+        (k, l, m) = (32, 64, 32)
+        x = rand_sparse_semi_structured_mask(k, l, dtype=torch.float8_e4m3fn, device=device)
+        y = torch.full((m, l), .25, device=device, dtype=torch.float8_e4m3fn).t()
+        scale_a = torch.tensor(1.0, device=device)
+        scale_b = torch.tensor(1.0, device=device)
+        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.float8_e4m3fn)
+
+        x_sparse = to_sparse_semi_structured(x)
+        out_fp8_sparse = torch._scaled_mm(x_sparse, y, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.float8_e4m3fn)
+        # this fails on ROCm currently because hipblaslt doesn't have amax op
+        out_fp32 = out_fp8.to(torch.float32)
+        out_fp32_sparse = out_fp8_sparse.to(torch.float32)
+        torch.testing.assert_close(out_fp32, out_fp32_sparse, rtol=1e-1, atol=1e-1)
+
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
+    @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.float32])
+    @parametrize("dense_input_shape", [(256, 128)])
+    def test_sparse_semi_structured_scaled_mm(
+        self, dense_input_shape, device, out_dtype
+    ):
+        A = rand_sparse_semi_structured_mask(256, 128, dtype=torch.float16)
+        B = torch.rand(dense_input_shape, device=device).to(torch.float16).t()
+
+        A_fp8, A_scale = to_float8(A)
+        B_fp8, B_scale = to_float8(B)
+
+        A_fp8_sparse = to_sparse_semi_structured(A_fp8)
+
+        dense_result = torch._scaled_mm(
+            A_fp8, B_fp8, scale_a=A_scale, scale_b=B_scale, out_dtype=out_dtype
+        )
+        sparse_result = torch._scaled_mm(
+            A_fp8_sparse, B_fp8, scale_a=A_scale, scale_b=B_scale, out_dtype=out_dtype
+        )
+        torch.testing.assert_close(dense_result, sparse_result, rtol=7e-2, atol=7e-2)
+
+    @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32])
     @parametrize("dense_input_shape", [(128, 128)])
     def test_cslt_sparse_mm_mixed_dtype(self, dense_input_shape, out_dtype, device):
         A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8)

@@ -1066,7 +1133,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):

         torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

-    @parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
+    @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32])
     def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
         A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
         B = torch.ones((128, 256), device=device).to(torch.int8).t()

@@ -1082,17 +1149,14 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):

         torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

-    @parametrize("alg_id", range(CUSPARSELT_NUM_ALG_IDS))
     @inference_dtypes
-    def test_cslt_sparse_mm_alg_id(self, device, dtype, alg_id):
-        # alg_id=3 not supported for float32 dtype
-        if dtype == torch.float32 and alg_id == 3:
-            return
+    def test_cslt_sparse_mm_alg_id(self, device, dtype):
         A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
         A_compressed = torch._cslt_compress(A)
         B = torch.ones((128, 128), device=device).to(dtype)

         A_compressed = torch._cslt_compress(A)
+        alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
         sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)

         dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))

@@ -1102,17 +1166,13 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):

     @inference_dtypes
     def test_cslt_sparse_mm_search(self, device, dtype):
-        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
+        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
         A_compressed = torch._cslt_compress(A)
         B = torch.ones((128, 128), device=device).to(dtype)

         A_compressed = torch._cslt_compress(A)
         alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
-        # for cuSPARSELt v0.4.0 there is a bug where although there are 5 alg_ids, we run into an error
-        # when setting using the last one (4)
-        # in cuSPARSELt v0.5.0 there are only 4 alg_ids total, so we should remove the +1 here when we update.
-        # TODO Move this into the cuSPARSELt backendk
-        assert alg_id in range(CUSPARSELT_NUM_ALG_IDS + 1)
+        assert alg_id in range(torch.backends.cusparselt.get_max_alg_id())

     def test_cusparselt_backend(self):
         version = _get_torch_cuda_version()
Diff: torch.backends.cusparselt backend module

@@ -7,6 +7,7 @@ import torch
 __all__ = [
     "version",
     "is_available",
+    "get_max_alg_id",
 ]

 try:

@@ -15,13 +16,21 @@ except ImportError:
     _cusparselt = None  # type: ignore[assignment]

 __cusparselt_version: Optional[int] = None
+__MAX_ALG_ID: Optional[int] = None

 if _cusparselt is not None:

     def _init():
         global __cusparselt_version
+        global __MAX_ALG_ID
         if __cusparselt_version is None:
             __cusparselt_version = _cusparselt.getVersionInt()
+            if __cusparselt_version == 400:
+                __MAX_ALG_ID = 4
+            elif __cusparselt_version == 502:
+                __MAX_ALG_ID = 5
+            elif __cusparselt_version == 602:
+                __MAX_ALG_ID = 37
         return True

 else:

@@ -40,3 +49,9 @@ def version() -> Optional[int]:
 def is_available() -> bool:
     r"""Return a bool indicating if cuSPARSELt is currently available."""
     return torch._C._has_cusparselt
+
+
+def get_max_alg_id() -> Optional[int]:
+    if not _init():
+        return None
+    return __MAX_ALG_ID
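A quick way to see what the backend now reports (a sketch; `get_max_alg_id()` returns `None` when the cuSPARSELt extension is not built in):

```python
import torch

if torch.backends.cusparselt.is_available():
    print("cuSPARSELt version:", torch.backends.cusparselt.version())
    # Upper bound used by the tests: alg ids returned by
    # torch._cslt_sparse_mm_search should fall in range(get_max_alg_id()).
    print("max alg_id:", torch.backends.cusparselt.get_max_alg_id())
```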
Diff: torch.sparse._semi_structured_ops

@@ -14,6 +14,7 @@ __all__ = [
     "semi_sparse_mm",
     "semi_sparse_addmm",
     "semi_sparse_linear",
+    "semi_sparse_scaled_mm",
 ]


@@ -72,9 +73,11 @@ def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
         meta=self.meta_t,
         packed_t=self.packed,
         meta_t=self.meta,
-        compressed_swizzled_bitmask=self.compressed_swizzled_bitmask.transpose(0, 1)
-        if self.compressed_swizzled_bitmask is not None
-        else None,
+        compressed_swizzled_bitmask=(
+            self.compressed_swizzled_bitmask.transpose(0, 1)
+            if self.compressed_swizzled_bitmask is not None
+            else None
+        ),
         fuse_transpose_cusparselt=args[0].fuse_transpose_cusparselt,
         alg_id_cusparselt=args[0].alg_id_cusparselt,
     )

@@ -166,3 +169,27 @@ def semi_sparse_linear(func, types, args=(), kwargs=None) -> torch.Tensor:
     )

     return res.view(*shape[:-1], -1)
+
+
+def semi_sparse_scaled_mm(func, types, args=(), kwargs=None) -> torch.Tensor:
+    # pull all args, excluding use_fast_accum flag if set.
+    A, B, A_scale, B_scale, bias, scale_result, out_dtype = args[:7]
+
+    assert A.dtype == torch.float8_e4m3fn
+    assert B.dtype == torch.float8_e4m3fn
+    # only cuSPARSELt supports float8_e4m3fn currentl
+    assert isinstance(A, torch.sparse.SparseSemiStructuredTensorCUSPARSELT)
+    assert A.packed is not None
+    # Currently we only support per-tensor scaling, with float32 scales
+    assert A_scale.numel() == 1 and B_scale.numel() == 1
+    assert A_scale.dtype == torch.float32 and B_scale.dtype == torch.float32
+
+    # cuSPARSELt lacks the A and B operand scaling support, so instead we use alpha to scale the result.
+    # Note that this limits us to per-tensor scalig only.
+    sparse_result = torch._cslt_sparse_mm(
+        A.packed,
+        B,
+        alpha=A_scale * B_scale,
+        out_dtype=out_dtype,
+    )
+    return sparse_result
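Putting the pieces together, `torch._scaled_mm` on a 2:4-sparse fp8 operand now routes through `semi_sparse_scaled_mm`, which reduces to the single `_cslt_sparse_mm` call above. A hedged end-to-end sketch (needs an fp8-capable GPU and cuSPARSELt >= 0.6.2; `rand_sparse_semi_structured_mask` and `to_float8` are the test helpers from this PR):

```python
import torch
from torch.sparse import to_sparse_semi_structured

A = rand_sparse_semi_structured_mask(256, 128, dtype=torch.float16)  # 2:4 pattern, test helper
B = torch.rand((256, 128), device="cuda", dtype=torch.float16).t()   # column-major dense operand
A_fp8, A_scale = to_float8(A)
B_fp8, B_scale = to_float8(B)
A_fp8_sparse = to_sparse_semi_structured(A_fp8)

# Dispatches to semi_sparse_scaled_mm, which calls
# torch._cslt_sparse_mm(A.packed, B, alpha=A_scale * B_scale, out_dtype=...).
out = torch._scaled_mm(
    A_fp8_sparse, B_fp8, scale_a=A_scale, scale_b=B_scale, out_dtype=torch.bfloat16
)
```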
Diff: SparseSemiStructuredTensor / SparseSemiStructuredTensorCUSPARSELT (torch.sparse)

@@ -15,6 +15,7 @@ from torch.sparse._semi_structured_ops import (
     semi_sparse_indices,
     semi_sparse_linear,
     semi_sparse_mm,
+    semi_sparse_scaled_mm,
     semi_sparse_t,
     semi_sparse_values,
     semi_sparse_view,

@@ -54,7 +55,7 @@ class SparseSemiStructuredTensor(torch.Tensor):

     _DEFAULT_ALG_ID: int = 0
     _DTYPE_SHAPE_CONSTRAINTS: Dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
-    _FORCE_CUTLASS: bool = True
+    _FORCE_CUTLASS: bool = False
     _FUSE_TRANSPOSE: bool = False
     _PROTOTYPE_WARNING_SHOWN: bool = False

@@ -225,6 +226,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
             torch.ops.aten.addmm: semi_sparse_addmm,
             torch.ops.aten.linear: semi_sparse_linear,
             torch.ops.aten._to_copy: fallback_dispatcher,
+            torch.ops.aten._scaled_mm: semi_sparse_scaled_mm,
         }
         if custom_dispatch_table is not None:
             cls.SPARSE_DISPATCH.update(custom_dispatch_table)
@@ -258,8 +260,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
         # check dtype
         if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS:
             raise RuntimeError(
-                f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
-                "dtype must be one of: {cls._DTYPE_SHAPE_CONSTRAINTS}"
+                f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype for {cls}!"
             )

         # check shape

@@ -534,6 +535,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):

     BACKEND = "cusparselt"
     _DTYPE_SHAPE_CONSTRAINTS = {
+        torch.float8_e4m3fn: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
         torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
         torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
         torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),

@@ -630,9 +632,16 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
         if bias is not None and bias.dtype != self.dtype:
             raise NotImplementedError(
                 f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, "
-                "with A.dtype=B.dtype={self.dtype} and C.dtype={B.dtype}. "
+                f"with A.dtype=B.dtype={self.dtype} and C.dtype={B.dtype}. "
                 "This operation is only supported when A, B and C have the same data type."
             )
+        # Force fp8 mm to error to be consistent with torch
+        if self.dtype == torch.float8_e4m3fn:
+            raise NotImplementedError(
+                f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, "
+                f"with A.dtype=B.dtype={self.dtype}. "
+                "mm is not supported for float8_e4m3fn, please use `torch._scaled_mm` instead."
+            )
         if self.packed is None:
             raise NotImplementedError(
                 f"`{self.__class__.__name__}` matmul: operation is not supported"
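The guard added at the end keeps plain `torch.mm` on fp8 consistent with dense torch: it raises and points the user at `torch._scaled_mm`. A sketch of that behavior, reusing the tensors from the previous snippet:

```python
try:
    torch.mm(A_fp8_sparse, B_fp8)   # rejected for float8_e4m3fn
except NotImplementedError as err:
    print(err)                      # "... please use `torch._scaled_mm` instead."

# The supported fp8 path:
out = torch._scaled_mm(
    A_fp8_sparse, B_fp8, scale_a=A_scale, scale_b=B_scale, out_dtype=torch.float16
)
```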