Simplify ATen sparse semi-structured operators based on CUTLASS (#123473)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123473
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Aleksandar Samardžić
2024-04-10 22:04:28 +02:00
committed by PyTorch MergeBot
parent 4f244cfaa0
commit b2a0b8c446
9 changed files with 1141 additions and 185 deletions

View File

@ -47,7 +47,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
-`_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend specific dense/sparse min shape constraints
- `def from_dense()` - backend specific compression routines
- `def _mm()` - backend specifc mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_linear)
- `def _mm()` - backend specifc mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_(mm|addmm))
"""
_DEFAULT_ALG_ID: int = 0
@ -371,11 +371,12 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
"""
This class implements semi-structured sparsity for the CUTLASS backend.
In this implementation, the specified elements and metadata are stored seprately,
in packed and meta respectively.
When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_linear
and sparse_semi_structured_from_dense for conversion to the compressed format.
When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and
sparse_semi_structured_from_dense for conversion to the compressed format.
"""
_DTYPE_SHAPE_CONSTRAINTS = {
@ -436,9 +437,14 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
f"`{cls_name}` matmul: operation is not supported"
)
else:
res = torch._sparse_semi_structured_linear(
B.t(), self.packed, self.meta, bias=bias
).t()
if bias is None:
res = torch._sparse_semi_structured_mm(
self.packed, self.meta, B
)
else:
res = torch._sparse_semi_structured_addmm(
bias, self.packed, self.meta, B
)
return res[: self.shape[0]]