Add out_dtype support for sparse semi-structured CUTLASS back-end (#116519)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116519
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Aleksandar Samardžić
2023-12-28 21:59:59 +01:00
committed by PyTorch MergeBot
parent ba06951c66
commit f081c45a34
5 changed files with 72 additions and 34 deletions

View File

@ -290,9 +290,7 @@ class TestSparseSemiStructured(TestCase):
sparse_result = torch.mm(A_sparse, B.t())
elif dtype is torch.int8:
# test transpose
# NOTE: CUTLASS and cuSPARSELt have slightly different int8 behavior.
# CUTLASS will output to an int32 tensor while cuSPARSELt will output to a int8 tensor
dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
sparse_result = torch.mm(A_sparse, B.t())
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
else:
@ -335,7 +333,7 @@ class TestSparseSemiStructured(TestCase):
# Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
if dtype is torch.int8:
dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
sparse_result = torch.mm(A, B_sparse.t())
else:
dense_result = torch.mm(A, B.t())
@ -444,7 +442,7 @@ class TestSparseSemiStructured(TestCase):
A_sparse = to_sparse_semi_structured(A)
B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype)
if dtype == torch.int8:
dense_res = torch.mm(A.cpu(), B.cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
dense_res = torch.mm(A.cpu(), B.cpu()).to(device, dtype=torch.int8)
# int8 sparse matmul not supported for R/R -> R layout, so we transpose one of the arguments to get R/C -> R
B_t = B.t().contiguous()
sparse_res = torch.mm(A_sparse, B_t.t())
@ -509,7 +507,8 @@ class TestSparseSemiStructured(TestCase):
weight_sparse = compressed.values()
meta = compressed.indices()
output1 = torch._sparse_semi_structured_linear(input, weight_sparse, meta, bias=bias, activation=activation)
output1 = torch._sparse_semi_structured_linear(input, weight_sparse, meta, bias=bias, activation=activation,
out_dtype=dtype_out if dtype == torch.int8 else None)
torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)
if dtype == torch.float32: