[sparse][quant] Add support for vector alpha in cusparselt mm (#112056)

Summary:

This PR adds in support for passing in a alpha Tensor, which represents
a tensor of alpha values to fuse into the matmul.

```
cusparselt_sparse_mm = alpha A @ B + bias
```

This operation is necessary for quantization, where we would like to
fuse one of the dequant matmuls into the sparse op.

Test Plan:

```
python test/test_sparse_semi_structured -k alpha
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112056
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Jesse Cai
2023-11-29 18:44:49 -08:00
committed by PyTorch MergeBot
parent f101426790
commit 4cb7dd0fc9
5 changed files with 96 additions and 37 deletions

View File

@ -350,7 +350,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
)
else:
res = torch._cslt_sparse_mm(
input_B.compressed_tensor_cusparselt, input_A_padded.t(), bias # type: ignore[arg-type]
input_B.compressed_tensor_cusparselt, input_A_padded.t(), bias=bias # type: ignore[arg-type]
).t()
return res[:row, :]
@ -369,7 +369,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
).t()
else:
res = torch._cslt_sparse_mm(
input_A.compressed_tensor_cusparselt, input_B_padded, None # type: ignore[arg-type]
input_A.compressed_tensor_cusparselt, input_B_padded, bias=None # type: ignore[arg-type]
)
return res[:, :col]
@ -384,7 +384,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
input_A_padded, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass
)
else:
res = torch._cslt_sparse_mm(input_B.compressed_tensor_cusparselt, input_A_padded.t(), None).t() # type: ignore[arg-type]
res = torch._cslt_sparse_mm(input_B.compressed_tensor_cusparselt, input_A_padded.t(), bias=None).t() # type: ignore[arg-type]
return res[:row, :]
@ -413,7 +413,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
res = torch._cslt_sparse_mm(
weight.compressed_tensor_cusparselt, # type: ignore[arg-type]
input_tensor_2d_padded.t(),
bias
bias=bias
).t()
return res[:row, :].view(*shape[:-1], -1)