mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[sparse] Add fast semi-structured spasification kernels (#122350)
This PR adds in fast semi-structured sparsification kernels to PyTorch. These kernels allow for accelerated semi-structured sparsification kernels in PyTorch. The kernels have been added as aten native functions In particular, three new functions have been added: * `torch._sparse_semi_structured_tile` This function will return the packed representation and metadata for both X and X', as well as the thread masks. Note that this applies 2:4 sparsity in a 4x4 tile instead of a 1x4 strip as usual. * `torch._sparse_semi_structured_apply` This function takes in an input tensor and thread masks from the above function and returns a packed representation and metadata from applying thread masks to the input tensor. * `torch._sparse_semi_structured_apply_dense` This function does the same thing as above but instead of returning the tensor in the sparse representation it returns it in the dense representation The subclasses have also been updated to add a new `prune_dense_static_sort` classmethod to create sparse tensors with this format. I've added some additional documentatino on how to calculate the compressed tensors needed to create a SparseSemiStructuredTensor oneself. To this end, there are two new helper functions added: `sparse_semi_structured_tile` `compute_compressed_swizzled_bitmask` Pull Request resolved: https://github.com/pytorch/pytorch/pull/122350 Approved by: https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
d8717c2d68
commit
c63a7b5691
@ -5,7 +5,7 @@ from typing import Any, Optional, Tuple, List, Callable, Dict
|
||||
import torch
|
||||
from torch.sparse._semi_structured_conversions import (
|
||||
sparse_semi_structured_from_dense_cutlass,
|
||||
sparse_semi_structured_to_dense_cutlass,
|
||||
sparse_semi_structured_to_dense_cutlass
|
||||
)
|
||||
from torch.sparse._semi_structured_ops import (
|
||||
fallback_dispatcher,
|
||||
@ -56,17 +56,18 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
_FUSE_TRANSPOSE: bool = False
|
||||
_PROTOTYPE_WARNING_SHOWN: bool = False
|
||||
|
||||
BACKEND: str
|
||||
SPARSE_DISPATCH: Dict[Callable, Callable]
|
||||
|
||||
packed: Optional[torch.Tensor]
|
||||
meta: Optional[torch.Tensor]
|
||||
packed_t: Optional[torch.Tensor]
|
||||
meta_t: Optional[torch.Tensor]
|
||||
threads_masks: Optional[torch.Tensor]
|
||||
compressed_swizzled_bitmask: Optional[torch.Tensor]
|
||||
fuse_transpose_cusparselt: bool
|
||||
alg_id_cusparselt: int
|
||||
|
||||
__slots__ = ["packed", "meta", "packed_t", "meta_t", "threads_masks"]
|
||||
__slots__ = ["packed", "meta", "packed_t", "meta_t", "compressed_swizzled_bitmask"]
|
||||
|
||||
@staticmethod
|
||||
def __new__( # noqa: PYI034
|
||||
@ -76,7 +77,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
meta: Optional[torch.Tensor],
|
||||
packed_t: Optional[torch.Tensor],
|
||||
meta_t: Optional[torch.Tensor],
|
||||
threads_masks: Optional[torch.Tensor],
|
||||
compressed_swizzled_bitmask: Optional[torch.Tensor],
|
||||
fuse_transpose_cusparselt: bool = False,
|
||||
alg_id_cusparselt: int = 0,
|
||||
requires_grad: bool = False,
|
||||
@ -95,8 +96,8 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
meta: The metadata of the original dense tensor, if it is stored separately
|
||||
packed_t: The compressed representation of the transposed original dense tensor
|
||||
meta_t: The metadata of the transposed original dense tensor, if it is stored separately
|
||||
threads_masks: The masks used by the CUTLASS backend to determine which threads should participate in the computation.
|
||||
Used for pointwise ops.
|
||||
compressed_swizzled_bitmask: The masks used by the CUTLASS backend to determine which threads should
|
||||
participate in the computation. Used for pointwise ops.
|
||||
fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition
|
||||
with a matmul, which is useful in the case of 2:4 sparse training.
|
||||
alg_id_cusparselt: The algorithm id to use when using cuSPARSELT, will have effect on performance
|
||||
@ -124,6 +125,9 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
# But this is useful since it allows users to overload the dispatch table for debugging / testing.
|
||||
cls._load_dispatch_table()
|
||||
|
||||
# we can also register the classes with dynamo when the warning is shown.
|
||||
torch._dynamo.allow_in_graph(cls)
|
||||
|
||||
if packed is not None:
|
||||
previous_tensor = packed
|
||||
elif packed_t is not None:
|
||||
@ -143,7 +147,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
tensor.meta = meta
|
||||
tensor.packed_t = packed_t
|
||||
tensor.meta_t = meta_t
|
||||
tensor.threads_masks = threads_masks
|
||||
tensor.compressed_swizzled_bitmask = compressed_swizzled_bitmask
|
||||
tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt
|
||||
tensor.alg_id_cusparselt = alg_id_cusparselt
|
||||
return tensor
|
||||
@ -181,7 +185,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
meta=inner_tensors.get("meta", None),
|
||||
packed_t=inner_tensors.get("packed_t", None),
|
||||
meta_t=inner_tensors.get("meta_t", None),
|
||||
threads_masks=inner_tensors.get("threads_masks", None),
|
||||
compressed_swizzled_bitmask=inner_tensors.get("compressed_swizzled_bitmask", None),
|
||||
fuse_transpose_cusparselt=fuse_transpose_cusparselt,
|
||||
alg_id_cusparselt=alg_id_cusparselt,
|
||||
requires_grad=requires_grad,
|
||||
@ -216,6 +220,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
torch.ops.aten.matmul: semi_sparse_mm,
|
||||
torch.ops.aten.addmm: semi_sparse_addmm,
|
||||
torch.ops.aten.linear: semi_sparse_linear,
|
||||
torch.ops.aten._to_copy: fallback_dispatcher,
|
||||
}
|
||||
if custom_dispatch_table is not None:
|
||||
cls.SPARSE_DISPATCH.update(custom_dispatch_table)
|
||||
@ -359,13 +364,14 @@ def to_sparse_semi_structured(
|
||||
"SparseSemiStructuredTensor only support contiguous input tensors. "
|
||||
)
|
||||
|
||||
sparse_subclass = (
|
||||
# set from _FORCE_CUTLASS flag
|
||||
SPARSE_SUBCLASS = (
|
||||
torch.sparse.SparseSemiStructuredTensorCUTLASS
|
||||
if SparseSemiStructuredTensor._FORCE_CUTLASS
|
||||
else torch.sparse.SparseSemiStructuredTensorCUSPARSELT
|
||||
)
|
||||
return sparse_subclass.from_dense(original_tensor)
|
||||
|
||||
return SPARSE_SUBCLASS.from_dense(original_tensor)
|
||||
|
||||
class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
|
||||
"""
|
||||
@ -377,7 +383,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
|
||||
When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_linear
|
||||
and sparse_semi_structured_from_dense for conversion to the compressed format.
|
||||
"""
|
||||
|
||||
BACKEND = "cutlass"
|
||||
_DTYPE_SHAPE_CONSTRAINTS = {
|
||||
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
|
||||
torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
|
||||
@ -400,19 +406,71 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
|
||||
meta=meta_tensor_cutlass,
|
||||
packed_t=None,
|
||||
meta_t=None,
|
||||
threads_masks=None,
|
||||
compressed_swizzled_bitmask=None,
|
||||
requires_grad=original_tensor.requires_grad,
|
||||
)
|
||||
|
||||
def to_dense(self):
|
||||
assert self.meta is not None and self.packed is not None
|
||||
return (
|
||||
sparse_semi_structured_to_dense_cutlass(
|
||||
self.packed,
|
||||
self.meta,
|
||||
)
|
||||
if self.meta.ndim == 2
|
||||
else super().to_dense()
|
||||
return sparse_semi_structured_to_dense_cutlass(
|
||||
self.packed,
|
||||
self.meta,
|
||||
) if self.meta.ndim == 2 else super().to_dense()
|
||||
|
||||
@classmethod
|
||||
def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
|
||||
"""
|
||||
This function takes in a unpruned dense tensor and runs a (branchless) static sort across a 4x4 tile.
|
||||
|
||||
It greedily picks the largest values in the tile, upholding the 2:4 sparsity constraint across both rows and columns.
|
||||
The algorithm used to prune the matrix is implemented in `_sparse_semi_structured_tile`.
|
||||
|
||||
Then it creates the packed and meta tensors for the compressed sparse representation of the pruned dense tensor.
|
||||
It also calculates the packed_t and meta_t tensors for the compressed sparse representation of the transposed
|
||||
pruned dense tensor.
|
||||
Since we cannot transpose the compressed representations, we store both for the fw/bw pass respectively.
|
||||
|
||||
Finally, this function also computes a compressed swizzled bitmask that encodes the sparsity pattern
|
||||
This can be used in the backward pass to mask the gradients.
|
||||
|
||||
[9 1 7 4] [9 0 7 0]
|
||||
[1 2 3 0] [0 2 0 0]
|
||||
[8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to CUTLASS semi-structured -> packed
|
||||
[1 2 6 2] [0 0 6 2] -> metadata
|
||||
|
||||
-> pack to transposed CUTLASS -> packed_t
|
||||
semi-structured representation -> metadata_t
|
||||
|
||||
-> compute swizzled bitmask -> compressed_swizzled_bitmask
|
||||
|
||||
|
||||
The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below:
|
||||
```
|
||||
from torch.sparse import SparseSemiStructuredTensorCUTLASS
|
||||
from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
|
||||
|
||||
pruned = _sparse_semi_structured_tile(dense)
|
||||
packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
|
||||
packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
|
||||
bitmask = _compute_compressed_swizzled_bitmask(pruned)
|
||||
|
||||
SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask)
|
||||
```
|
||||
"""
|
||||
# We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.
|
||||
(packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
|
||||
original_tensor,
|
||||
algorithm=algorithm,
|
||||
use_cutlass=True)
|
||||
|
||||
return cls(
|
||||
original_tensor.shape,
|
||||
packed=packed,
|
||||
meta=meta,
|
||||
packed_t=packed_t,
|
||||
meta_t=meta_t,
|
||||
compressed_swizzled_bitmask=compressed_swizzled_bitmask,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
def _mm(
|
||||
@ -453,7 +511,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
|
||||
cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
|
||||
as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
|
||||
"""
|
||||
|
||||
BACKEND = "cusparselt"
|
||||
_DTYPE_SHAPE_CONSTRAINTS = {
|
||||
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
|
||||
torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
|
||||
@ -470,12 +528,59 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
|
||||
meta=None,
|
||||
packed_t=None,
|
||||
meta_t=None,
|
||||
threads_masks=None,
|
||||
compressed_swizzled_bitmask=None,
|
||||
fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE,
|
||||
alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID,
|
||||
requires_grad=original_tensor.requires_grad,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
|
||||
"""
|
||||
This function does the same thing as described in SparseSemiStructuredCUTLASS, but uses the cuSPASRELt metadata
|
||||
layout and sparse matmul.
|
||||
|
||||
The only functional difference is that cuSPARSELt stores `metadata` and `packed` together into a single tensor.
|
||||
|
||||
[9 1 7 4] [9 0 7 0]
|
||||
[1 2 3 0] [0 2 0 0]
|
||||
[8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to cuSPARSELT semi-structured -> packed
|
||||
[1 2 6 2] [0 0 6 2]
|
||||
|
||||
-> pack to transposed cuSPARSELt -> packed_t
|
||||
semi-structured representation
|
||||
|
||||
-> compute swizzled bitmask -> compressed_swizzled_bitmask
|
||||
|
||||
|
||||
The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below:
|
||||
```
|
||||
from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
|
||||
from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
|
||||
|
||||
pruned = _sparse_semi_structured_tile(dense)
|
||||
packed_cusparselt = torch._cslt_compress(pruned)
|
||||
packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
|
||||
bitmask = _compute_compressed_swizzled_bitmask(pruned)
|
||||
|
||||
SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask)
|
||||
```
|
||||
"""
|
||||
(packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
|
||||
original_tensor,
|
||||
algorithm=algorithm,
|
||||
use_cutlass=False)
|
||||
|
||||
return cls(
|
||||
original_tensor.shape,
|
||||
packed=packed,
|
||||
meta=meta,
|
||||
packed_t=packed_t,
|
||||
meta_t=meta_t,
|
||||
compressed_swizzled_bitmask=compressed_swizzled_bitmask,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
def _mm(
|
||||
self,
|
||||
B: torch.Tensor,
|
||||
|
Reference in New Issue
Block a user