mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[sparse][quant] Add support for vector alpha in cusparselt mm (#112056)
Summary: This PR adds support for passing in an alpha tensor, i.e. a vector of alpha values that is fused into the matmul as a scaling factor. ``` cusparselt_sparse_mm = alpha * (A @ B) + bias ``` This operation is necessary for quantization, where we would like to fuse one of the dequantization multiplications into the sparse matmul op. Test Plan: ``` python test/test_sparse_semi_structured.py -k alpha ``` Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/112056 Approved by: https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
f101426790
commit
4cb7dd0fc9
@ -350,7 +350,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
)
|
||||
else:
|
||||
res = torch._cslt_sparse_mm(
|
||||
input_B.compressed_tensor_cusparselt, input_A_padded.t(), bias # type: ignore[arg-type]
|
||||
input_B.compressed_tensor_cusparselt, input_A_padded.t(), bias=bias # type: ignore[arg-type]
|
||||
).t()
|
||||
return res[:row, :]
|
||||
|
||||
@ -369,7 +369,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
).t()
|
||||
else:
|
||||
res = torch._cslt_sparse_mm(
|
||||
input_A.compressed_tensor_cusparselt, input_B_padded, None # type: ignore[arg-type]
|
||||
input_A.compressed_tensor_cusparselt, input_B_padded, bias=None # type: ignore[arg-type]
|
||||
)
|
||||
return res[:, :col]
|
||||
|
||||
@ -384,7 +384,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
input_A_padded, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass
|
||||
)
|
||||
else:
|
||||
res = torch._cslt_sparse_mm(input_B.compressed_tensor_cusparselt, input_A_padded.t(), None).t() # type: ignore[arg-type]
|
||||
res = torch._cslt_sparse_mm(input_B.compressed_tensor_cusparselt, input_A_padded.t(), bias=None).t() # type: ignore[arg-type]
|
||||
|
||||
return res[:row, :]
|
||||
|
||||
@ -413,7 +413,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
res = torch._cslt_sparse_mm(
|
||||
weight.compressed_tensor_cusparselt, # type: ignore[arg-type]
|
||||
input_tensor_2d_padded.t(),
|
||||
bias
|
||||
bias=bias
|
||||
).t()
|
||||
return res[:row, :].view(*shape[:-1], -1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user