[sparse][semi-structured] Add float8 dtype support to 2:4 sparsity (#136397)

Summary:

This PR adds `torch.float8_e4m3fn` support to cuSPARSELt and `to_sparse_semi_structured`.

This will let users run fp8 + 2:4 sparse matmuls on Hopper GPUs with
cuSPARSELt >= 0.6.2, via the `torch._scaled_mm` API.

```
# `rand_sparse_semi_structured_mask` and `to_float8` are helpers from
# PyTorch's test suite (test/test_sparse_semi_structured.py); a sketch of
# `to_float8` is shown below the snippet.
A = rand_sparse_semi_structured_mask(256, 128, dtype=torch.float16)
B = torch.rand(dense_input_shape, device=device).to(torch.float16).t()

# Quantize both operands to fp8 and keep their inverse scales
A_fp8, A_scale = to_float8(A)
B_fp8, B_scale = to_float8(B)

# Dense fp8 matmul, for reference
dense_result = torch._scaled_mm(
    A_fp8, B_fp8,
    scale_a=A_scale, scale_b=B_scale,
    out_dtype=out_dtype
)

# Compress the 2:4 sparse operand, then run the same scaled matmul
A_fp8_sparse = to_sparse_semi_structured(A_fp8)
sparse_result = torch._scaled_mm(
    A_fp8_sparse, B_fp8,
    scale_a=A_scale, scale_b=B_scale,
    out_dtype=out_dtype
)
```
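Here `dense_input_shape`, `device`, and `out_dtype` are taken from the test harness. For running the snippet standalone, a minimal `to_float8` in the spirit of the test helper (names and details assumed, not the exact test code) could look like:

```
import torch

def to_float8(x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
    # Map the tensor's absmax to the fp8 dtype's max representable value
    finfo = torch.finfo(dtype)
    scale = finfo.max / x.abs().max().clamp(min=1e-12)
    # Saturate before casting down to fp8
    x_scaled = (x * scale).clamp(min=finfo.min, max=finfo.max)
    # torch._scaled_mm expects the inverse scale as a float tensor
    return x_scaled.to(dtype), scale.float().reciprocal()
```

With those inverse scales passed as `scale_a`/`scale_b`, the dense and sparse results above should agree to within fp8 quantization error.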

Note that, to keep this consistent with normal torch behavior, calling
`torch.mm(A_fp8_sparse, B_fp8)` will raise a `NotImplementedError`.
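A minimal illustration of that behavior, reusing `A_fp8_sparse` and `B_fp8` from the snippet above:

```
try:
    torch.mm(A_fp8_sparse, B_fp8)
except NotImplementedError as e:
    # The error message points fp8 users at torch._scaled_mm instead
    print(e)
```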

I also turned on cuSPARSELt by default and added `CUSPARSELT_MAX_ID` to the
backend to make the tests a bit cleaner.
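Concretely, `SparseSemiStructuredTensor._FORCE_CUTLASS` now defaults to `False` (see the diff below), so the cuSPARSELt backend is picked whenever it is available. A rough sketch of inspecting the backend and opting back into CUTLASS, assuming `torch.backends.cusparselt` is exposed in this build:

```
import torch
from torch.sparse import SparseSemiStructuredTensor

# Assumption: torch.backends.cusparselt exposes is_available()/version()
if torch.backends.cusparselt.is_available():
    print("cuSPARSELt version:", torch.backends.cusparselt.version())

# The CUTLASS kernels can still be forced explicitly for 2:4 sparsity
SparseSemiStructuredTensor._FORCE_CUTLASS = True
```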

Test Plan:
```
python test/test_sparse_semi_structured.py -k scaled_mm
python test/test_sparse_semi_structured.py -k fp8
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136397
Approved by: https://github.com/drisspg
Author: Jesse Cai
Date: 2024-09-27 12:03:44 -07:00
Committed by: PyTorch MergeBot
Commit: bc21689136 (parent a28b40fa74)
5 changed files with 231 additions and 46 deletions

@@ -15,6 +15,7 @@ from torch.sparse._semi_structured_ops import (
     semi_sparse_indices,
     semi_sparse_linear,
     semi_sparse_mm,
+    semi_sparse_scaled_mm,
     semi_sparse_t,
     semi_sparse_values,
     semi_sparse_view,
@@ -54,7 +55,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
     _DEFAULT_ALG_ID: int = 0
     _DTYPE_SHAPE_CONSTRAINTS: Dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
-    _FORCE_CUTLASS: bool = True
+    _FORCE_CUTLASS: bool = False
     _FUSE_TRANSPOSE: bool = False
     _PROTOTYPE_WARNING_SHOWN: bool = False
@@ -225,6 +226,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
                 torch.ops.aten.addmm: semi_sparse_addmm,
                 torch.ops.aten.linear: semi_sparse_linear,
                 torch.ops.aten._to_copy: fallback_dispatcher,
+                torch.ops.aten._scaled_mm: semi_sparse_scaled_mm,
             }
             if custom_dispatch_table is not None:
                 cls.SPARSE_DISPATCH.update(custom_dispatch_table)
@@ -258,8 +260,7 @@
         # check dtype
         if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS:
             raise RuntimeError(
-                f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
-                "dtype must be one of: {cls._DTYPE_SHAPE_CONSTRAINTS}"
+                f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype for {cls}!"
             )

         # check shape
@@ -534,6 +535,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
     BACKEND = "cusparselt"
     _DTYPE_SHAPE_CONSTRAINTS = {
+        torch.float8_e4m3fn: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
         torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
         torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
         torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
@@ -630,9 +632,16 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
         if bias is not None and bias.dtype != self.dtype:
             raise NotImplementedError(
                 f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, "
-                "with A.dtype=B.dtype={self.dtype} and C.dtype={B.dtype}. "
+                f"with A.dtype=B.dtype={self.dtype} and C.dtype={B.dtype}. "
                 "This operation is only supported when A, B and C have the same data type."
             )
+        # Force fp8 mm to error to be consistent with torch
+        if self.dtype == torch.float8_e4m3fn:
+            raise NotImplementedError(
+                f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, "
+                f"with A.dtype=B.dtype={self.dtype}. "
+                "mm is not supported for float8_e4m3fn, please use `torch._scaled_mm` instead."
+            )
         if self.packed is None:
             raise NotImplementedError(
                 f"`{self.__class__.__name__}` matmul: operation is not supported"