Add out_dtype support for sparse semi-structured CUTLASS back-end (#116519)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116519
Approved by: https://github.com/cpuhrsch
Committed by: PyTorch MergeBot
Parent: ba06951c66
Commit: f081c45a34
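The gist of the int8 change: with out_dtype support in place, the CUTLASS back-end returns an int8 result for int8 operands, matching cuSPARSELt, so the tests below no longer pick the expected output dtype per back-end. A minimal sketch of that behaviour, assuming a CUDA GPU with 2:4 sparse tensor core support; the 128x128 shapes, the 0/1 values, and the pruning step are illustrative choices, not taken from the diff:

import torch
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured

# Force the CUTLASS back-end, the one this commit extends.
SparseSemiStructuredTensor._FORCE_CUTLASS = True

# Small 0/1 values keep the int8 accumulation far from overflow.
A = torch.randint(0, 2, (128, 128), device="cuda", dtype=torch.int8)
B = torch.randint(0, 2, (128, 128), device="cuda", dtype=torch.int8)

# Enforce a 2:4 pattern (zero two of every four elements along each row)
# so the conversion below is valid.
A = A.reshape(-1, 4)
A[:, :2] = 0
A = A.reshape(128, 128)
A_sparse = to_sparse_semi_structured(A)

# Reference on CPU (int matmul is not supported on GPU), cast to int8 to
# match the sparse kernel's output dtype, exactly as the updated tests do.
dense_result = torch.mm(A.cpu(), B.t().cpu()).to("cuda", dtype=torch.int8)
sparse_result = torch.mm(A_sparse, B.t())
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)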
@@ -290,9 +290,7 @@ class TestSparseSemiStructured(TestCase):
             sparse_result = torch.mm(A_sparse, B.t())
         elif dtype is torch.int8:
             # test transpose
-            # NOTE: CUTLASS and cuSPARSELt have slightly different int8 behavior.
-            # CUTLASS will output to an int32 tensor while cuSPARSELt will output to a int8 tensor
-            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
+            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
             sparse_result = torch.mm(A_sparse, B.t())
             assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
         else:
@@ -335,7 +333,7 @@ class TestSparseSemiStructured(TestCase):
 
         # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
         if dtype is torch.int8:
-            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
+            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
             sparse_result = torch.mm(A, B_sparse.t())
         else:
             dense_result = torch.mm(A, B.t())
@@ -444,7 +442,7 @@ class TestSparseSemiStructured(TestCase):
         A_sparse = to_sparse_semi_structured(A)
         B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype)
         if dtype == torch.int8:
-            dense_res = torch.mm(A.cpu(), B.cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
+            dense_res = torch.mm(A.cpu(), B.cpu()).to(device, dtype=torch.int8)
             # int8 sparse matmul not supported for R/R -> R layout, so we transpose one of the arguments to get R/C -> R
             B_t = B.t().contiguous()
             sparse_res = torch.mm(A_sparse, B_t.t())
@@ -509,7 +507,8 @@ class TestSparseSemiStructured(TestCase):
         weight_sparse = compressed.values()
         meta = compressed.indices()
 
-        output1 = torch._sparse_semi_structured_linear(input, weight_sparse, meta, bias=bias, activation=activation)
+        output1 = torch._sparse_semi_structured_linear(input, weight_sparse, meta, bias=bias, activation=activation,
+                                                       out_dtype=dtype_out if dtype == torch.int8 else None)
         torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)
 
         if dtype == torch.float32:
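For the last hunk, a hedged sketch of calling the private op with the new keyword. Only the call shape (input, packed values, metadata, bias, activation, and the new out_dtype argument) comes from the test change above; the shapes, 0/1 values, pruning step, bias=None, activation="none", and the torch.int32 output dtype are assumptions for illustration:

import torch
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured

# The packed values / metadata pair below is CUTLASS-specific, so force that back-end.
SparseSemiStructuredTensor._FORCE_CUTLASS = True

weight = torch.randint(0, 2, (128, 128), device="cuda", dtype=torch.int8)
weight = weight.reshape(-1, 4)
weight[:, :2] = 0                     # enforce a 2:4 sparsity pattern
weight = weight.reshape(128, 128)

compressed = to_sparse_semi_structured(weight)
weight_sparse = compressed.values()   # packed non-zero values
meta = compressed.indices()           # 2:4 metadata

x = torch.randint(0, 2, (128, 128), device="cuda", dtype=torch.int8)

# With this commit the CUTLASS kernel takes an explicit output dtype for
# int8 operands; torch.int32 is an assumed example value here.
out = torch._sparse_semi_structured_linear(
    x, weight_sparse, meta, bias=None, activation="none",
    out_dtype=torch.int32,
)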