Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00.
[sparse][semi-structured] Add float8 dtype support to 2:4 sparsity (#136397)
Summary:

This PR adds `torch.float8_e4m3fn` support to cuSPARSELt and `to_sparse_semi_structured`. This lets users run fp8 + 2:4 sparse matmuls on Hopper GPUs with cuSPARSELt >= 0.6.2, via the `torch._scaled_mm` API.

```
A = rand_sparse_semi_structured_mask(256, 128, dtype=torch.float16)
B = torch.rand(dense_input_shape, device=device).to(torch.float16).t()

A_fp8, A_scale = to_float8(A)
B_fp8, B_scale = to_float8(B)

dense_result = torch._scaled_mm(
    A_fp8, B_fp8, scale_a=A_scale, scale_b=B_scale, out_dtype=out_dtype
)

A_fp8_sparse = to_sparse_semi_structured(A_fp8)
sparse_result = torch._scaled_mm(
    A_fp8_sparse, B_fp8, scale_a=A_scale, scale_b=B_scale, out_dtype=out_dtype
)
```

Note that to keep this consistent with normal torch behavior, calling `torch.mm(A_fp8_sparse, B_fp8)` will raise a `NotImplementedError`.

I also turned on cuSPARSELt by default and added `CUSPARSELT_MAX_ID` to the backend to make the tests a bit cleaner.

Test Plan:

```
python test/test_sparse_semi_structured.py -k scaled_mm
python test/test_sparse_semi_structured.py -k fp8
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136397
Approved by: https://github.com/drisspg
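The example above relies on a `to_float8` helper that is not shown in the commit message. A minimal sketch of such a helper, assuming simple per-tensor absmax scaling and that `scale_a`/`scale_b` expect the dequantization (inverse) scales; the helper actually used by the tests may differ in detail:

```
import torch

def to_float8(x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
    # Cast x to float8 and return (fp8_tensor, dequantization_scale).
    finfo = torch.finfo(dtype)
    # Scale so the largest magnitude in x maps to the fp8 dtype's max value.
    scale = finfo.max / x.abs().max().clamp(min=1e-12)
    # Saturate before casting, since the plain cast does not clamp out-of-range values.
    x_scaled = (x * scale).clamp(min=finfo.min, max=finfo.max)
    # Return the inverse scale, which is what scale_a / scale_b take.
    return x_scaled.to(dtype), scale.float().reciprocal()
```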
Committed by: PyTorch MergeBot
Parent: a28b40fa74
Commit: bc21689136
@@ -15,6 +15,7 @@ from torch.sparse._semi_structured_ops import (
     semi_sparse_indices,
     semi_sparse_linear,
     semi_sparse_mm,
+    semi_sparse_scaled_mm,
     semi_sparse_t,
     semi_sparse_values,
     semi_sparse_view,
@@ -54,7 +55,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
 
     _DEFAULT_ALG_ID: int = 0
     _DTYPE_SHAPE_CONSTRAINTS: Dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
-    _FORCE_CUTLASS: bool = True
+    _FORCE_CUTLASS: bool = False
     _FUSE_TRANSPOSE: bool = False
     _PROTOTYPE_WARNING_SHOWN: bool = False
 
@@ -225,6 +226,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
                 torch.ops.aten.addmm: semi_sparse_addmm,
                 torch.ops.aten.linear: semi_sparse_linear,
                 torch.ops.aten._to_copy: fallback_dispatcher,
+                torch.ops.aten._scaled_mm: semi_sparse_scaled_mm,
             }
             if custom_dispatch_table is not None:
                 cls.SPARSE_DISPATCH.update(custom_dispatch_table)
@@ -258,8 +260,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
         # check dtype
         if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS:
             raise RuntimeError(
-                f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
-                "dtype must be one of: {cls._DTYPE_SHAPE_CONSTRAINTS}"
+                f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype for {cls}!"
             )
 
         # check shape
@@ -534,6 +535,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
 
     BACKEND = "cusparselt"
     _DTYPE_SHAPE_CONSTRAINTS = {
+        torch.float8_e4m3fn: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
         torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
         torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
         torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
@@ -630,9 +632,16 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
         if bias is not None and bias.dtype != self.dtype:
             raise NotImplementedError(
                 f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, "
-                "with A.dtype=B.dtype={self.dtype} and C.dtype={B.dtype}. "
+                f"with A.dtype=B.dtype={self.dtype} and C.dtype={B.dtype}. "
                 "This operation is only supported when A, B and C have the same data type."
             )
+        # Force fp8 mm to error to be consistent with torch
+        if self.dtype == torch.float8_e4m3fn:
+            raise NotImplementedError(
+                f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, "
+                f"with A.dtype=B.dtype={self.dtype}. "
+                "mm is not supported for float8_e4m3fn, please use `torch._scaled_mm` instead."
+            )
         if self.packed is None:
             raise NotImplementedError(
                 f"`{self.__class__.__name__}` matmul: operation is not supported"
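For reference, a short usage sketch of the fp8 guard added in the last hunk. The tensor names follow the example in the summary and are assumed to be set up the same way; `out_dtype=torch.bfloat16` is just one assumed choice:

```
# A_fp8_sparse: 2:4-sparse float8_e4m3fn tensor, B_fp8: dense fp8 matrix,
# A_scale / B_scale: their dequantization scales (see the summary example).
try:
    torch.mm(A_fp8_sparse, B_fp8)  # plain mm is intentionally unsupported for fp8
except NotImplementedError:
    out = torch._scaled_mm(        # the error message points at this path instead
        A_fp8_sparse, B_fp8,
        scale_a=A_scale, scale_b=B_scale,
        out_dtype=torch.bfloat16,  # assumed out_dtype for illustration
    )
```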