[sparse] Add fast semi-structured spasification kernels (#122350)

This PR adds in fast semi-structured sparsification kernels to PyTorch.

These kernels allow for accelerated semi-structured sparsification
kernels in PyTorch.

The kernels have been added as aten native functions

In particular, three new functions have been added:

* `torch._sparse_semi_structured_tile`

This function will return the packed representation and metadata for
both X and X', as well as the thread masks. Note that this applies 2:4
sparsity in a 4x4 tile instead of a 1x4 strip as usual.

* `torch._sparse_semi_structured_apply`

This function takes in an input tensor and thread masks from the above
function and returns a packed representation and metadata from applying
thread masks to the input tensor.

* `torch._sparse_semi_structured_apply_dense`

This function does the same thing as above but instead of returning the
tensor in the sparse representation it returns it in the dense
representation

The subclasses have also been updated to add a new
`prune_dense_static_sort`
classmethod to create sparse tensors with this format. I've added some
additional documentatino on how to calculate the compressed tensors
needed to create a SparseSemiStructuredTensor oneself.

To this end, there are two new helper functions added:
`sparse_semi_structured_tile`
`compute_compressed_swizzled_bitmask`

Differential Revision: [D56190801](https://our.internmc.facebook.com/intern/diff/D56190801)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122350
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Jesse Cai
2024-04-12 11:44:58 -07:00
committed by PyTorch MergeBot
parent 383d2d1f6c
commit 14b2273b0c
12 changed files with 2062 additions and 126 deletions

View File

@ -5,7 +5,7 @@ from typing import Any, Optional, Tuple, List, Callable, Dict
import torch
from torch.sparse._semi_structured_conversions import (
sparse_semi_structured_from_dense_cutlass,
sparse_semi_structured_to_dense_cutlass,
sparse_semi_structured_to_dense_cutlass
)
from torch.sparse._semi_structured_ops import (
fallback_dispatcher,
@ -56,17 +56,18 @@ class SparseSemiStructuredTensor(torch.Tensor):
_FUSE_TRANSPOSE: bool = False
_PROTOTYPE_WARNING_SHOWN: bool = False
BACKEND: str
SPARSE_DISPATCH: Dict[Callable, Callable]
packed: Optional[torch.Tensor]
meta: Optional[torch.Tensor]
packed_t: Optional[torch.Tensor]
meta_t: Optional[torch.Tensor]
threads_masks: Optional[torch.Tensor]
compressed_swizzled_bitmask: Optional[torch.Tensor]
fuse_transpose_cusparselt: bool
alg_id_cusparselt: int
__slots__ = ["packed", "meta", "packed_t", "meta_t", "threads_masks"]
__slots__ = ["packed", "meta", "packed_t", "meta_t", "compressed_swizzled_bitmask"]
@staticmethod
def __new__( # noqa: PYI034
@ -76,7 +77,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
meta: Optional[torch.Tensor],
packed_t: Optional[torch.Tensor],
meta_t: Optional[torch.Tensor],
threads_masks: Optional[torch.Tensor],
compressed_swizzled_bitmask: Optional[torch.Tensor],
fuse_transpose_cusparselt: bool = False,
alg_id_cusparselt: int = 0,
requires_grad: bool = False,
@ -95,8 +96,8 @@ class SparseSemiStructuredTensor(torch.Tensor):
meta: The metadata of the original dense tensor, if it is stored separately
packed_t: The compressed representation of the transposed original dense tensor
meta_t: The metadata of the transposed original dense tensor, if it is stored separately
threads_masks: The masks used by the CUTLASS backend to determine which threads should participate in the computation.
Used for pointwise ops.
compressed_swizzled_bitmask: The masks used by the CUTLASS backend to determine which threads should
participate in the computation. Used for pointwise ops.
fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition
with a matmul, which is useful in the case of 2:4 sparse training.
alg_id_cusparselt: The algorithm id to use when using cuSPARSELT, will have effect on performance
@ -124,6 +125,9 @@ class SparseSemiStructuredTensor(torch.Tensor):
# But this is useful since it allows users to overload the dispatch table for debugging / testing.
cls._load_dispatch_table()
# we can also register the classes with dynamo when the warning is shown.
torch._dynamo.allow_in_graph(cls)
if packed is not None:
previous_tensor = packed
elif packed_t is not None:
@ -143,7 +147,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
tensor.meta = meta
tensor.packed_t = packed_t
tensor.meta_t = meta_t
tensor.threads_masks = threads_masks
tensor.compressed_swizzled_bitmask = compressed_swizzled_bitmask
tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt
tensor.alg_id_cusparselt = alg_id_cusparselt
return tensor
@ -181,7 +185,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
meta=inner_tensors.get("meta", None),
packed_t=inner_tensors.get("packed_t", None),
meta_t=inner_tensors.get("meta_t", None),
threads_masks=inner_tensors.get("threads_masks", None),
compressed_swizzled_bitmask=inner_tensors.get("compressed_swizzled_bitmask", None),
fuse_transpose_cusparselt=fuse_transpose_cusparselt,
alg_id_cusparselt=alg_id_cusparselt,
requires_grad=requires_grad,
@ -216,6 +220,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
torch.ops.aten.matmul: semi_sparse_mm,
torch.ops.aten.addmm: semi_sparse_addmm,
torch.ops.aten.linear: semi_sparse_linear,
torch.ops.aten._to_copy: fallback_dispatcher,
}
if custom_dispatch_table is not None:
cls.SPARSE_DISPATCH.update(custom_dispatch_table)
@ -359,13 +364,14 @@ def to_sparse_semi_structured(
"SparseSemiStructuredTensor only support contiguous input tensors. "
)
sparse_subclass = (
# set from _FORCE_CUTLASS flag
SPARSE_SUBCLASS = (
torch.sparse.SparseSemiStructuredTensorCUTLASS
if SparseSemiStructuredTensor._FORCE_CUTLASS
else torch.sparse.SparseSemiStructuredTensorCUSPARSELT
)
return sparse_subclass.from_dense(original_tensor)
return SPARSE_SUBCLASS.from_dense(original_tensor)
class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
"""
@ -378,7 +384,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and
sparse_semi_structured_from_dense for conversion to the compressed format.
"""
BACKEND = "cutlass"
_DTYPE_SHAPE_CONSTRAINTS = {
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
@ -401,19 +407,71 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
meta=meta_tensor_cutlass,
packed_t=None,
meta_t=None,
threads_masks=None,
compressed_swizzled_bitmask=None,
requires_grad=original_tensor.requires_grad,
)
def to_dense(self):
assert self.meta is not None and self.packed is not None
return (
sparse_semi_structured_to_dense_cutlass(
self.packed,
self.meta,
)
if self.meta.ndim == 2
else super().to_dense()
return sparse_semi_structured_to_dense_cutlass(
self.packed,
self.meta,
) if self.meta.ndim == 2 else super().to_dense()
@classmethod
def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
"""
This function takes in a unpruned dense tensor and runs a (branchless) static sort across a 4x4 tile.
It greedily picks the largest values in the tile, upholding the 2:4 sparsity constraint across both rows and columns.
The algorithm used to prune the matrix is implemented in `_sparse_semi_structured_tile`.
Then it creates the packed and meta tensors for the compressed sparse representation of the pruned dense tensor.
It also calculates the packed_t and meta_t tensors for the compressed sparse representation of the transposed
pruned dense tensor.
Since we cannot transpose the compressed representations, we store both for the fw/bw pass respectively.
Finally, this function also computes a compressed swizzled bitmask that encodes the sparsity pattern
This can be used in the backward pass to mask the gradients.
[9 1 7 4] [9 0 7 0]
[1 2 3 0] [0 2 0 0]
[8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to CUTLASS semi-structured -> packed
[1 2 6 2] [0 0 6 2] -> metadata
-> pack to transposed CUTLASS -> packed_t
semi-structured representation -> metadata_t
-> compute swizzled bitmask -> compressed_swizzled_bitmask
The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below:
```
from torch.sparse import SparseSemiStructuredTensorCUTLASS
from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
pruned = _sparse_semi_structured_tile(dense)
packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
bitmask = _compute_compressed_swizzled_bitmask(pruned)
SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask)
```
"""
# We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.
(packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
original_tensor,
algorithm=algorithm,
use_cutlass=True)
return cls(
original_tensor.shape,
packed=packed,
meta=meta,
packed_t=packed_t,
meta_t=meta_t,
compressed_swizzled_bitmask=compressed_swizzled_bitmask,
requires_grad=False,
)
def _mm(
@ -459,7 +517,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
"""
BACKEND = "cusparselt"
_DTYPE_SHAPE_CONSTRAINTS = {
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
@ -476,12 +534,59 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
meta=None,
packed_t=None,
meta_t=None,
threads_masks=None,
compressed_swizzled_bitmask=None,
fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE,
alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID,
requires_grad=original_tensor.requires_grad,
)
@classmethod
def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
"""
This function does the same thing as described in SparseSemiStructuredCUTLASS, but uses the cuSPASRELt metadata
layout and sparse matmul.
The only functional difference is that cuSPARSELt stores `metadata` and `packed` together into a single tensor.
[9 1 7 4] [9 0 7 0]
[1 2 3 0] [0 2 0 0]
[8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to cuSPARSELT semi-structured -> packed
[1 2 6 2] [0 0 6 2]
-> pack to transposed cuSPARSELt -> packed_t
semi-structured representation
-> compute swizzled bitmask -> compressed_swizzled_bitmask
The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below:
```
from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
pruned = _sparse_semi_structured_tile(dense)
packed_cusparselt = torch._cslt_compress(pruned)
packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
bitmask = _compute_compressed_swizzled_bitmask(pruned)
SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask)
```
"""
(packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
original_tensor,
algorithm=algorithm,
use_cutlass=False)
return cls(
original_tensor.shape,
packed=packed,
meta=meta,
packed_t=packed_t,
meta_t=meta_t,
compressed_swizzled_bitmask=compressed_swizzled_bitmask,
requires_grad=False,
)
def _mm(
self,
B: torch.Tensor,