mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[sparse] Add in out_dtype support (i8i8->bf16, i32) for cusparselt (#119296)
Summary: Adds out_dtype support for (i8i8->bf16) and (i8i8->i32) matmul with cuSPARSELt.

Test Plan:
```
python test/test_sparse_semi_structured.py -k mixed
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119296
Approved by: https://github.com/cpuhrsch, https://github.com/alexsamardzic
This commit is contained in:
committed by
PyTorch MergeBot
parent
5f69d95b2b
commit
1c1dc0e4e0
@ -34,6 +34,7 @@ from torch.testing._internal.common_utils import (
|
||||
from torch.utils._triton import has_triton
|
||||
|
||||
CUSPARSELT_NUM_ALG_IDS = 4
|
||||
CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32]
|
||||
|
||||
SEMI_STRUCTURED_SUPPORTED_DTYPES = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG.keys()
|
||||
SEMI_STRUCTURED_SUPPORTED_BACKENDS = []
|
||||
@ -596,16 +597,16 @@ class TestCUSPARSELT(TestCase):
|
||||
else:
|
||||
SparseSemiStructuredTensor._FORCE_CUTLASS = False
|
||||
|
||||
|
||||
@parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
|
||||
@parametrize("dense_input_shape", [(128, 128)])
|
||||
def test_cslt_sparse_mm_int8_in_fp16_out(self, dense_input_shape, device):
|
||||
def test_cslt_sparse_mm_mixed_dtype(self, dense_input_shape, out_dtype, device):
|
||||
A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8)
|
||||
A_compressed = torch._cslt_compress(A)
|
||||
|
||||
B = torch.rand(dense_input_shape, device=device).to(torch.int8)
|
||||
|
||||
dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=torch.float16)
|
||||
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=torch.float16)
|
||||
dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=out_dtype)
|
||||
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=out_dtype)
|
||||
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@dtypes(torch.float16, torch.bfloat16)
|
||||
@ -623,17 +624,19 @@ class TestCUSPARSELT(TestCase):
|
||||
|
||||
assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
def test_cslt_sparse_mm_alpha_int8_in_f16_out(self, device):
|
||||
@parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
|
||||
def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
|
||||
A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
|
||||
B = torch.ones((128, 256), device=device).to(torch.int8).t()
|
||||
alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda()
|
||||
alpha = torch.Tensor([2**(-i) if out_dtype is not torch.int32 else 1
|
||||
for i in range(128)]).cuda()
|
||||
|
||||
A_compressed = torch._cslt_compress(A)
|
||||
sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=torch.float16).cpu()
|
||||
sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=out_dtype).cpu()
|
||||
|
||||
alpha_scaled = torch.stack([alpha] * 128).t()
|
||||
dense_result = alpha_scaled.cpu() * torch.mm(A.to(torch.int32).cpu(), B.to(torch.int32).cpu())
|
||||
dense_result = dense_result.to(torch.float16)
|
||||
dense_result = alpha_scaled.cpu() * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu())
|
||||
dense_result = dense_result.to(out_dtype)
|
||||
|
||||
assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
Reference in New Issue
Block a user