[sparse][semi-structured] Add float8 dtype support to 2:4 sparsity (#136397)

Summary:

This PR adds `torch.float8_e4m3fn` support to cuSPARSELt and `to_sparse_semi_structured`.

This lets users run fp8 + 2:4 sparse matmuls on Hopper GPUs with
cuSPARSELt >= 0.6.2, via the `torch._scaled_mm` API, for example:

```
import torch
from torch.sparse import to_sparse_semi_structured

# rand_sparse_semi_structured_mask and to_float8 are helpers from
# test/test_sparse_semi_structured.py (to_float8 is added in this PR);
# device / dense_input_shape / out_dtype are example values.
device = "cuda"
dense_input_shape = (256, 128)
out_dtype = torch.bfloat16

A = rand_sparse_semi_structured_mask(256, 128, dtype=torch.float16)
B = torch.rand(dense_input_shape, device=device).to(torch.float16).t()

A_fp8, A_scale = to_float8(A)
B_fp8, B_scale = to_float8(B)

# dense fp8 matmul
dense_result = torch._scaled_mm(
    A_fp8, B_fp8,
    scale_a=A_scale, scale_b=B_scale,
    out_dtype=out_dtype
)

# 2:4 sparse fp8 matmul
A_fp8_sparse = to_sparse_semi_structured(A_fp8)
sparse_result = torch._scaled_mm(
    A_fp8_sparse, B_fp8,
    scale_a=A_scale, scale_b=B_scale,
    out_dtype=out_dtype
)
```

Note that to keep this consistent with normal torch behavior, calling
`torch.mm(A_fp8_sparse, B_fp8)` will raise a NotImplementedError.
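A minimal sketch of the expected failure, reusing `A_fp8_sparse` and `B_fp8` from the example above:

```
try:
    torch.mm(A_fp8_sparse, B_fp8)
except NotImplementedError as e:
    # e.g. "... mm is not supported for float8_e4m3fn, please use `torch._scaled_mm` instead."
    print(e)
```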

I also turned cuSPARSELt on by default and added a `get_max_alg_id()` query
(backed by `__MAX_ALG_ID`) to the cuSPARSELt backend to make the tests a bit cleaner.
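
As a rough sketch of how the backend can be queried (`get_max_alg_id` is the new entry point added in this PR; the example values come from the version-to-max-alg-id mapping in the diff below):

```
import torch

if torch.backends.cusparselt.is_available():
    print(torch.backends.cusparselt.version())         # e.g. 602 for cuSPARSELt v0.6.2
    print(torch.backends.cusparselt.get_max_alg_id())  # e.g. 37 on cuSPARSELt v0.6.2
```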

Test Plan:
```
python test/test_sparse_semi_structured.py -k scaled_mm
python test/test_sparse_semi_structured.py -k fp8
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136397
Approved by: https://github.com/drisspg
Author: Jesse Cai
Date: 2024-09-27 12:03:44 -07:00
Committed by: PyTorch MergeBot
Commit: bc21689136 (parent a28b40fa74)
5 changed files with 231 additions and 46 deletions


@@ -53,6 +53,11 @@ at::Tensor _cslt_compress(const Tensor& sparse_input)
     case at::ScalarType::Float:
       type = CUDA_R_32F;
       break;
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
+    case at::ScalarType::Float8_e4m3fn:
+      type = CUDA_R_8F_E4M3;
+      break;
+#endif
     default:
       TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt compressed matrix");
       break;
@@ -123,15 +128,16 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
   float beta = 0.0;
   cudaDataType input_type;
   cudaDataType output_type;
+  cudaDataType C_type;
   cusparseComputeType compute_type;
   auto compression_factor = 9;
 
   switch(compressed_A.scalar_type())
   {
     case at::ScalarType::Char:
       input_type = CUDA_R_8I;
       output_type = CUDA_R_8I;
+      C_type = CUDA_R_8I;
       compute_type = CUSPARSE_COMPUTE_32I;
       compression_factor = 10;
       break;
@@ -141,61 +147,111 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
     case at::ScalarType::Half:
       input_type = CUDA_R_16F;
       output_type = CUDA_R_16F;
+      C_type = CUDA_R_16F;
       compute_type = CUSPARSE_COMPUTE_32F;
       break;
     case at::ScalarType::BFloat16:
       input_type = CUDA_R_16BF;
       output_type = CUDA_R_16BF;
+      C_type = CUDA_R_16BF;
       compute_type = CUSPARSE_COMPUTE_32F;
       break;
     case at::ScalarType::Float:
       input_type = CUDA_R_32F;
       output_type = CUDA_R_32F;
+      C_type = CUDA_R_32F;
       compute_type = CUSPARSE_COMPUTE_32F;
       break;
+    // if cuSPARSELt >= 6.2.3, we can add Float8 support
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
+    case at::ScalarType::Float8_e4m3fn:
+      input_type = CUDA_R_8F_E4M3;
+      output_type = CUDA_R_8F_E4M3;
+      C_type = CUDA_R_16F;
+      compute_type = CUSPARSE_COMPUTE_32F;
+      break;
+#endif
 // cuSPARSELt <= v0.5.2 uses CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUTE_16F
 #else
     case at::ScalarType::Half:
       input_type = CUDA_R_16F;
       output_type = CUDA_R_16F;
+      C_type = CUDA_R_16F;
       compute_type = CUSPARSE_COMPUTE_16F;
       break;
     case at::ScalarType::BFloat16:
       input_type = CUDA_R_16BF;
       output_type = CUDA_R_16BF;
+      C_type = CUDA_R_16BF;
       compute_type = CUSPARSE_COMPUTE_16F;
       break;
     case at::ScalarType::Float:
       input_type = CUDA_R_32F;
       output_type = CUDA_R_32F;
+      C_type = CUDA_R_32F;
       compute_type = CUSPARSE_COMPUTE_TF32;
       break;
 #endif
     default:
       TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt compressed matrix multiplication.");
       break;
   }
 
   ScalarType out_dtype = dense_B.scalar_type();
-  // special check for mixed dtype int8 int8 -> {fp16, bf16, int32} support
+  // special check for mixed dtype support for 8 bit dtypes
+  // cslt 0.5.2+: int8 int8 -> {fp16, bf16, int32} support
   if (out_dtype_opt.has_value()) {
     out_dtype = out_dtype_opt.value();
-    TORCH_CHECK(input_type == CUDA_R_8I, "out_dtype support only available for int8 inputs");
-    switch (out_dtype)
+    if (input_type == CUDA_R_8I)
     {
-      case at::ScalarType::Half:
-        output_type = CUDA_R_16F;
-        break;
-      case at::ScalarType::BFloat16:
-        output_type = CUDA_R_16BF;
-        break;
-      case at::ScalarType::Int:
-        output_type = CUDA_R_32I;
-        break;
-      default:
-        TORCH_CHECK(false, "Unsupported out_dtype passed, must be one of {fp16, bf16, int32}");
-        break;
+      switch (out_dtype)
+      {
+        case at::ScalarType::Half:
+          C_type = CUDA_R_16F;
+          output_type = CUDA_R_16F;
+          break;
+        case at::ScalarType::BFloat16:
+          C_type = CUDA_R_16BF;
+          output_type = CUDA_R_16BF;
+          break;
+        case at::ScalarType::Int:
+          C_type = CUDA_R_32I;
+          output_type = CUDA_R_32I;
+          break;
+        default:
+          TORCH_CHECK(false, "Unsupported out_dtype passed, must be one of {fp16, bf16, int32} for int8 inputs");
+          break;
+      }
+    }
+    // cslt 0.6.2+: fp8 fp8 -> {fp8, fp16, bf16, fp32} support
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
+    else if (input_type == CUDA_R_8F_E4M3)
+    {
+      switch (out_dtype)
+      {
+        case at::ScalarType::Float8_e4m3fn:
+          output_type = CUDA_R_8F_E4M3;
+          C_type = CUDA_R_16F;
+          break;
+        case at::ScalarType::Half:
+          output_type = CUDA_R_16F;
+          C_type = CUDA_R_16F;
+          break;
+        case at::ScalarType::BFloat16:
+          output_type = CUDA_R_16BF;
+          C_type = CUDA_R_16BF;
+          break;
+        case at::ScalarType::Float:
+          output_type = CUDA_R_32F;
+          C_type = CUDA_R_32F;
+          break;
+        default:
+          TORCH_CHECK(false, "Unsupported out_dtype passed, must be one of {fp16, bf16, float32} for fp8 inputs");
+          break;
+      }
+    }
+#endif
+    else {
+      TORCH_CHECK(false, "out_dtype support only available for int8/fp8 inputs");
     }
   }
@@ -244,6 +300,18 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
         output_type,
         (transpose_result) ? CUSPARSE_ORDER_COL : CUSPARSE_ORDER_ROW));
 
+  // For float8, need fp16 C_descriptor, can't use FP8 for this matrix
+  cusparseLtMatDescriptor_t C_descriptor;
+  TORCH_CUDASPARSE_CHECK(cusparseLtDenseDescriptorInit(
+      &handle,
+      &C_descriptor,
+      m,
+      n,
+      (transpose_result) ? m: n,
+      16,
+      C_type,
+      (transpose_result) ? CUSPARSE_ORDER_COL : CUSPARSE_ORDER_ROW));
+
   // initialize matmul
   TORCH_CUDASPARSE_CHECK(cusparseLtMatmulDescriptorInit(
       &handle,
@@ -252,7 +320,7 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
       (dense_B.is_contiguous()) ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE,
       &sparse_input_descriptor,
       &dense_input_descriptor,
-      &res_descriptor,
+      &C_descriptor,
       &res_descriptor,
       compute_type));
@@ -273,11 +341,17 @@ std::tuple<int64_t, at::Tensor> _cslt_sparse_mm_impl(
   // set tensor_alpha_mode and alpha pointer for matmul
   const auto alpha_tensor = alpha_opt.has_value() ? *alpha_opt: Tensor{};
-  const auto alpha_ptr = alpha_opt.has_value() ? alpha_tensor.data_ptr(): &alpha;
+  auto alpha_ptr = &alpha;
   if (alpha_opt.has_value()) {
-    tensor_alpha_mode = 1;
-    TORCH_CUDASPARSE_CHECK(cusparseLtMatmulDescSetAttribute(
-        &handle, &matmul, CUSPARSELT_MATMUL_ALPHA_VECTOR_SCALING, &tensor_alpha_mode, sizeof(tensor_alpha_mode)));
+    if (alpha_tensor.numel() == 1) {
+      alpha = alpha_tensor.item<float>();
+    }
+    else {
+      tensor_alpha_mode = 1;
+      TORCH_CUDASPARSE_CHECK(cusparseLtMatmulDescSetAttribute(
+          &handle, &matmul, CUSPARSELT_MATMUL_ALPHA_VECTOR_SCALING, &tensor_alpha_mode, sizeof(tensor_alpha_mode)));
+      alpha_ptr = static_cast<float*>(alpha_tensor.data_ptr());
+    }
   }
 
   TORCH_CUDASPARSE_CHECK(


@@ -21,7 +21,7 @@ from torch.sparse._semi_structured_conversions import (
 )
 from torch.testing import make_tensor
-from torch.testing._internal.common_cuda import _get_torch_cuda_version
+from torch.testing._internal.common_cuda import _get_torch_cuda_version, PLATFORM_SUPPORTS_FP8
 from torch.testing._internal.common_device_type import (
     dtypes,
     instantiate_device_type_tests,
@@ -1022,10 +1022,19 @@ class TestSparseSemiStructuredCUTLASS(TestCase):
         torch.testing.assert_close(dense, dense_val, rtol=0, atol=0)
 
-CUSPARSELT_NUM_ALG_IDS = 4
 CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32]
 
+def to_float8(x, dtype=torch.float8_e4m3fn):
+    finfo = torch.finfo(dtype)
+    # Calculate the scale as dtype max divided by absmax
+    scale = finfo.max / x.abs().max().clamp(min=1e-12)
+    # scale and clamp the tensor to bring it to
+    # the representative range of float8 data type
+    # (as default cast is unsaturated)
+    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    # Return both float8 data and the inverse scale (as float),
+    # as both required as inputs to torch._scaled_mm
+    return x_scl_sat.to(dtype), scale.float().reciprocal()
+
 class TestSparseSemiStructuredCUSPARSELT(TestCase):
     """
@@ -1034,10 +1043,68 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
     torch._cslt_sparse_mm
     """
     def setUp(self):
+        SparseSemiStructuredTensor._FORCE_CUTLASS = False
         if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
             self.skipTest('cuSPARSELt not enabled')
 
-    @parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
+    @parametrize("dense_input_shape", [(256, 128)])
+    def test_sparse_fp8fp8_mm(self, dense_input_shape, device):
+        if torch.backends.cusparselt.version() < 602:
+            self.skipTest("fp8 matmul requires cuSPARSELt v0.6.2+")
+
+        A = rand_sparse_semi_structured_mask(256, 128, dtype=torch.float16)
+        B = torch.rand(dense_input_shape, device=device).to(torch.float16).t()
+
+        A_fp8, A_scale = to_float8(A)
+        B_fp8, B_scale = to_float8(B)
+        A_fp8_sparse = to_sparse_semi_structured(A_fp8)
+
+        with self.assertRaisesRegex(
+            NotImplementedError,
+            r"`SparseSemiStructuredTensor.*_scaled_mm",
+        ):
+            dense_result = torch.mm(A_fp8_sparse, B_fp8)
+
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
+    def test_sparse_semi_structured_scaled_mm_fp8(self, device) -> None:
+        (k, l, m) = (32, 64, 32)
+        x = rand_sparse_semi_structured_mask(k, l, dtype=torch.float8_e4m3fn, device=device)
+        y = torch.full((m, l), .25, device=device, dtype=torch.float8_e4m3fn).t()
+        scale_a = torch.tensor(1.0, device=device)
+        scale_b = torch.tensor(1.0, device=device)
+        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.float8_e4m3fn)
+
+        x_sparse = to_sparse_semi_structured(x)
+        out_fp8_sparse = torch._scaled_mm(x_sparse, y, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.float8_e4m3fn)
+        # this fails on ROCm currently because hipblaslt doesn't have amax op
+        out_fp32 = out_fp8.to(torch.float32)
+        out_fp32_sparse = out_fp8_sparse.to(torch.float32)
+        torch.testing.assert_close(out_fp32, out_fp32_sparse, rtol=1e-1, atol=1e-1)
+
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
+    @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.float32])
+    @parametrize("dense_input_shape", [(256, 128)])
+    def test_sparse_semi_structured_scaled_mm(
+        self, dense_input_shape, device, out_dtype
+    ):
+        A = rand_sparse_semi_structured_mask(256, 128, dtype=torch.float16)
+        B = torch.rand(dense_input_shape, device=device).to(torch.float16).t()
+
+        A_fp8, A_scale = to_float8(A)
+        B_fp8, B_scale = to_float8(B)
+        A_fp8_sparse = to_sparse_semi_structured(A_fp8)
+
+        dense_result = torch._scaled_mm(
+            A_fp8, B_fp8, scale_a=A_scale, scale_b=B_scale, out_dtype=out_dtype
+        )
+        sparse_result = torch._scaled_mm(
+            A_fp8_sparse, B_fp8, scale_a=A_scale, scale_b=B_scale, out_dtype=out_dtype
+        )
+        torch.testing.assert_close(dense_result, sparse_result, rtol=7e-2, atol=7e-2)
+
+    @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32])
     @parametrize("dense_input_shape", [(128, 128)])
     def test_cslt_sparse_mm_mixed_dtype(self, dense_input_shape, out_dtype, device):
         A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8)
@@ -1066,7 +1133,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
 
-    @parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
+    @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32])
     def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
         A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
         B = torch.ones((128, 256), device=device).to(torch.int8).t()
@@ -1082,17 +1149,14 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
 
-    @parametrize("alg_id", range(CUSPARSELT_NUM_ALG_IDS))
     @inference_dtypes
-    def test_cslt_sparse_mm_alg_id(self, device, dtype, alg_id):
-        # alg_id=3 not supported for float32 dtype
-        if dtype == torch.float32 and alg_id == 3:
-            return
+    def test_cslt_sparse_mm_alg_id(self, device, dtype):
         A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
         A_compressed = torch._cslt_compress(A)
         B = torch.ones((128, 128), device=device).to(dtype)
         A_compressed = torch._cslt_compress(A)
 
+        alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
         sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)
 
         dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
@@ -1102,17 +1166,13 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
     @inference_dtypes
     def test_cslt_sparse_mm_search(self, device, dtype):
-        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
+        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
         A_compressed = torch._cslt_compress(A)
         B = torch.ones((128, 128), device=device).to(dtype)
         A_compressed = torch._cslt_compress(A)
 
         alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
-        # for cuSPARSELt v0.4.0 there is a bug where although there are 5 alg_ids, we run into an error
-        # when setting using the last one (4)
-        # in cuSPARSELt v0.5.0 there are only 4 alg_ids total, so we should remove the +1 here when we update.
-        # TODO Move this into the cuSPARSELt backendk
-        assert alg_id in range(CUSPARSELT_NUM_ALG_IDS + 1)
+        assert alg_id in range(torch.backends.cusparselt.get_max_alg_id())
 
     def test_cusparselt_backend(self):
         version = _get_torch_cuda_version()


@@ -7,6 +7,7 @@ import torch
 __all__ = [
     "version",
     "is_available",
+    "get_max_alg_id",
 ]
 
 try:
@@ -15,13 +16,21 @@ except ImportError:
     _cusparselt = None  # type: ignore[assignment]
 
 __cusparselt_version: Optional[int] = None
+__MAX_ALG_ID: Optional[int] = None
 
 if _cusparselt is not None:
 
     def _init():
        global __cusparselt_version
+        global __MAX_ALG_ID
        if __cusparselt_version is None:
            __cusparselt_version = _cusparselt.getVersionInt()
+            if __cusparselt_version == 400:
+                __MAX_ALG_ID = 4
+            elif __cusparselt_version == 502:
+                __MAX_ALG_ID = 5
+            elif __cusparselt_version == 602:
+                __MAX_ALG_ID = 37
        return True
 
 else:
@@ -40,3 +49,9 @@ def version() -> Optional[int]:
 def is_available() -> bool:
     r"""Return a bool indicating if cuSPARSELt is currently available."""
     return torch._C._has_cusparselt
+
+
+def get_max_alg_id() -> Optional[int]:
+    if not _init():
+        return None
+    return __MAX_ALG_ID


@@ -14,6 +14,7 @@ __all__ = [
     "semi_sparse_mm",
     "semi_sparse_addmm",
     "semi_sparse_linear",
+    "semi_sparse_scaled_mm",
 ]
@@ -72,9 +73,11 @@ def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
         meta=self.meta_t,
         packed_t=self.packed,
         meta_t=self.meta,
-        compressed_swizzled_bitmask=self.compressed_swizzled_bitmask.transpose(0, 1)
-        if self.compressed_swizzled_bitmask is not None
-        else None,
+        compressed_swizzled_bitmask=(
+            self.compressed_swizzled_bitmask.transpose(0, 1)
+            if self.compressed_swizzled_bitmask is not None
+            else None
+        ),
         fuse_transpose_cusparselt=args[0].fuse_transpose_cusparselt,
         alg_id_cusparselt=args[0].alg_id_cusparselt,
     )
@@ -166,3 +169,27 @@ def semi_sparse_linear(func, types, args=(), kwargs=None) -> torch.Tensor:
     )
 
     return res.view(*shape[:-1], -1)
+
+
+def semi_sparse_scaled_mm(func, types, args=(), kwargs=None) -> torch.Tensor:
+    # pull all args, excluding use_fast_accum flag if set.
+    A, B, A_scale, B_scale, bias, scale_result, out_dtype = args[:7]
+
+    assert A.dtype == torch.float8_e4m3fn
+    assert B.dtype == torch.float8_e4m3fn
+    # only cuSPARSELt supports float8_e4m3fn currentl
+    assert isinstance(A, torch.sparse.SparseSemiStructuredTensorCUSPARSELT)
+    assert A.packed is not None
+
+    # Currently we only support per-tensor scaling, with float32 scales
+    assert A_scale.numel() == 1 and B_scale.numel() == 1
+    assert A_scale.dtype == torch.float32 and B_scale.dtype == torch.float32
+
+    # cuSPARSELt lacks the A and B operand scaling support, so instead we use alpha to scale the result.
+    # Note that this limits us to per-tensor scalig only.
+    sparse_result = torch._cslt_sparse_mm(
+        A.packed,
+        B,
+        alpha=A_scale * B_scale,
+        out_dtype=out_dtype,
+    )
+    return sparse_result


@@ -15,6 +15,7 @@ from torch.sparse._semi_structured_ops import (
     semi_sparse_indices,
     semi_sparse_linear,
     semi_sparse_mm,
+    semi_sparse_scaled_mm,
     semi_sparse_t,
     semi_sparse_values,
     semi_sparse_view,
@@ -54,7 +55,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
     _DEFAULT_ALG_ID: int = 0
     _DTYPE_SHAPE_CONSTRAINTS: Dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
-    _FORCE_CUTLASS: bool = True
+    _FORCE_CUTLASS: bool = False
     _FUSE_TRANSPOSE: bool = False
     _PROTOTYPE_WARNING_SHOWN: bool = False
@@ -225,6 +226,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
             torch.ops.aten.addmm: semi_sparse_addmm,
             torch.ops.aten.linear: semi_sparse_linear,
             torch.ops.aten._to_copy: fallback_dispatcher,
+            torch.ops.aten._scaled_mm: semi_sparse_scaled_mm,
         }
         if custom_dispatch_table is not None:
             cls.SPARSE_DISPATCH.update(custom_dispatch_table)
@@ -258,8 +260,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
         # check dtype
         if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS:
             raise RuntimeError(
-                f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
-                "dtype must be one of: {cls._DTYPE_SHAPE_CONSTRAINTS}"
+                f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype for {cls}!"
             )
 
         # check shape
@@ -534,6 +535,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
     BACKEND = "cusparselt"
     _DTYPE_SHAPE_CONSTRAINTS = {
+        torch.float8_e4m3fn: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
         torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
         torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
         torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
@@ -630,9 +632,16 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
         if bias is not None and bias.dtype != self.dtype:
             raise NotImplementedError(
                 f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, "
-                "with A.dtype=B.dtype={self.dtype} and C.dtype={B.dtype}. "
+                f"with A.dtype=B.dtype={self.dtype} and C.dtype={B.dtype}. "
                 "This operation is only supported when A, B and C have the same data type."
             )
+        # Force fp8 mm to error to be consistent with torch
+        if self.dtype == torch.float8_e4m3fn:
+            raise NotImplementedError(
+                f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, "
+                f"with A.dtype=B.dtype={self.dtype}. "
+                "mm is not supported for float8_e4m3fn, please use `torch._scaled_mm` instead."
+            )
         if self.packed is None:
             raise NotImplementedError(
                 f"`{self.__class__.__name__}` matmul: operation is not supported"