mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Simplify ATen sparse semi-structured operators based on CUTLASS (#123473)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123473 Approved by: https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
4f244cfaa0
commit
b2a0b8c446
@ -47,7 +47,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
|
||||
- `_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend specific dense/sparse min shape constraints
|
||||
- `def from_dense()` - backend specific compression routines
|
||||
- `def _mm()` - backend specific mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_linear)
|
||||
- `def _mm()` - backend specific mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_(mm|addmm))
|
||||
"""
|
||||
|
||||
_DEFAULT_ALG_ID: int = 0
|
||||
@ -371,11 +371,12 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
|
||||
"""
|
||||
This class implements semi-structured sparsity for the CUTLASS backend.
|
||||
|
||||
|
||||
In this implementation, the specified elements and metadata are stored separately,
|
||||
in packed and meta respectively.
|
||||
|
||||
When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_linear
|
||||
and sparse_semi_structured_from_dense for conversion to the compressed format.
|
||||
When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and
|
||||
sparse_semi_structured_from_dense for conversion to the compressed format.
|
||||
"""
|
||||
|
||||
_DTYPE_SHAPE_CONSTRAINTS = {
|
||||
@ -436,9 +437,14 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
|
||||
f"`{cls_name}` matmul: operation is not supported"
|
||||
)
|
||||
else:
|
||||
res = torch._sparse_semi_structured_linear(
|
||||
B.t(), self.packed, self.meta, bias=bias
|
||||
).t()
|
||||
if bias is None:
|
||||
res = torch._sparse_semi_structured_mm(
|
||||
self.packed, self.meta, B
|
||||
)
|
||||
else:
|
||||
res = torch._sparse_semi_structured_addmm(
|
||||
bias, self.packed, self.meta, B
|
||||
)
|
||||
return res[: self.shape[0]]
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user