[sparse] Add out_dtype support (i8i8->bf16, i32) for cusparselt (#119296)

Summary:

Adds out_dtype support for (i8i8 -> bf16) and (i8i8 -> i32) matmul with
cuSPARSELt.
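
For context, a minimal sketch of what this enables, mirroring the calls the test below exercises. The shapes and values here are illustrative only; `torch._cslt_compress` and `torch._cslt_sparse_mm` are private ops and require a CUDA build with cuSPARSELt available:

```
# Illustrative sketch (not from the PR): int8 2:4-sparse times int8 dense,
# with the result dtype chosen via out_dtype. Needs CUDA + cuSPARSELt.
import torch

# int8 weight with a 2:4 sparsity pattern ([0, 0, x, x] in each group of four).
A = torch.tensor([0, 0, 1, 1], dtype=torch.int8, device="cuda").tile((128, 32))
B = torch.randint(-8, 8, (128, 128), dtype=torch.int8, device="cuda")

A_compressed = torch._cslt_compress(A)  # compress A into cuSPARSELt's format

# fp16 output already worked for int8 inputs; this PR adds bf16 and i32.
out_bf16 = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=torch.bfloat16)
out_i32 = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=torch.int32)
```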

Test Plan:

```
python test/test_sparse_semi_structured.py -k mixed
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119296
Approved by: https://github.com/cpuhrsch, https://github.com/alexsamardzic
Author: Jesse Cai
Date: 2024-02-06 14:18:44 -08:00
Committed by: PyTorch MergeBot
Parent: 5f69d95b2b
Commit: 1c1dc0e4e0

4 changed files with 39 additions and 22 deletions

test/test_sparse_semi_structured.py

@@ -34,6 +34,7 @@ from torch.testing._internal.common_utils import (
 from torch.utils._triton import has_triton
 CUSPARSELT_NUM_ALG_IDS = 4
+CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32]
 SEMI_STRUCTURED_SUPPORTED_DTYPES = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG.keys()
 SEMI_STRUCTURED_SUPPORTED_BACKENDS = []
@@ -596,16 +597,16 @@ class TestCUSPARSELT(TestCase):
         else:
             SparseSemiStructuredTensor._FORCE_CUTLASS = False
 
+    @parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
     @parametrize("dense_input_shape", [(128, 128)])
-    def test_cslt_sparse_mm_int8_in_fp16_out(self, dense_input_shape, device):
+    def test_cslt_sparse_mm_mixed_dtype(self, dense_input_shape, out_dtype, device):
         A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8)
         A_compressed = torch._cslt_compress(A)
         B = torch.rand(dense_input_shape, device=device).to(torch.int8)
 
-        dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=torch.float16)
-        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=torch.float16)
+        dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=out_dtype)
+        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=out_dtype)
         assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
     @dtypes(torch.float16, torch.bfloat16)
@@ -623,17 +624,19 @@ class TestCUSPARSELT(TestCase):
         assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
 
-    def test_cslt_sparse_mm_alpha_int8_in_f16_out(self, device):
+    @parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
+    def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
         A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
         B = torch.ones((128, 256), device=device).to(torch.int8).t()
-        alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda()
+        alpha = torch.Tensor([2**(-i) if out_dtype is not torch.int32 else 1
+                              for i in range(128)]).cuda()
 
         A_compressed = torch._cslt_compress(A)
-        sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=torch.float16).cpu()
+        sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=out_dtype).cpu()
 
         alpha_scaled = torch.stack([alpha] * 128).t()
-        dense_result = alpha_scaled.cpu() * torch.mm(A.to(torch.int32).cpu(), B.to(torch.int32).cpu())
-        dense_result = dense_result.to(torch.float16)
+        dense_result = alpha_scaled.cpu() * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu())
+        dense_result = dense_result.to(out_dtype)
         assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)