mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This PR adds fast semi-structured sparsification kernels to PyTorch. These kernels allow for accelerated semi-structured sparsification in PyTorch. The kernels have been added as aten native functions. In particular, three new functions have been added: * `torch._sparse_semi_structured_tile` This function will return the packed representation and metadata for both X and X', as well as the thread masks. Note that this applies 2:4 sparsity in a 4x4 tile instead of a 1x4 strip as usual. * `torch._sparse_semi_structured_apply` This function takes in an input tensor and thread masks from the above function and returns a packed representation and metadata from applying thread masks to the input tensor. * `torch._sparse_semi_structured_apply_dense` This function does the same thing as above, but instead of returning the tensor in the sparse representation it returns it in the dense representation. The subclasses have also been updated to add a new `prune_dense_static_sort` classmethod to create sparse tensors with this format. I've added some additional documentation on how to calculate the compressed tensors needed to create a SparseSemiStructuredTensor oneself. To this end, there are two new helper functions added: `sparse_semi_structured_tile` and `compute_compressed_swizzled_bitmask`. Differential Revision: [D56190801](https://our.internmc.facebook.com/intern/diff/D56190801) Pull Request resolved: https://github.com/pytorch/pytorch/pull/122350 Approved by: https://github.com/cpuhrsch
630 lines
27 KiB
Python
630 lines
27 KiB
Python
import warnings
|
|
from collections import namedtuple
|
|
from typing import Any, Optional, Tuple, List, Callable, Dict
|
|
|
|
import torch
|
|
from torch.sparse._semi_structured_conversions import (
|
|
sparse_semi_structured_from_dense_cutlass,
|
|
sparse_semi_structured_to_dense_cutlass
|
|
)
|
|
from torch.sparse._semi_structured_ops import (
|
|
fallback_dispatcher,
|
|
semi_sparse_values,
|
|
semi_sparse_indices,
|
|
semi_sparse_detach,
|
|
semi_sparse_t,
|
|
semi_sparse_view,
|
|
semi_sparse_mm,
|
|
semi_sparse_addmm,
|
|
semi_sparse_linear,
|
|
)
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = [
    "SparseSemiStructuredTensor",
    "SparseSemiStructuredTensorCUTLASS",
    "SparseSemiStructuredTensorCUSPARSELT",
    "to_sparse_semi_structured",
]
|
|
|
|
_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
|
|
"_SEMI_STRUCTURED_SPARSE_CONFIG",
|
|
"sparse_min_rows sparse_min_cols dense_min_rows dense_min_cols",
|
|
)
|
|
|
|
|
|
class SparseSemiStructuredTensor(torch.Tensor):
|
|
"""
|
|
This class implementes semi-structured sparsity as a Tensor subclass.
|
|
|
|
Semi-structured sparsity describes a sparsity pattern where n in every 2n elements are sparse,
|
|
depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
|
|
structured sparsity.
|
|
|
|
There are two backends available for semi_structred sparsity, either cuSPARSELt or CUTLASS.
|
|
This class is meant to serve as a base class for both implementations. SparseSemiStructuredCUTLASS
|
|
and SparseSemiStructuredCUSPARSELT both inherit from this class and define three backend-specific items.
|
|
Note that as such, this class cannot be insantiated directly.
|
|
|
|
-`_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend specific dense/sparse min shape constraints
|
|
- `def from_dense()` - backend specific compression routines
|
|
- `def _mm()` - backend specifc mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_(mm|addmm))
|
|
"""
|
|
|
|
_DEFAULT_ALG_ID: int = 0
|
|
_DTYPE_SHAPE_CONSTRAINTS: Dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
|
|
_FORCE_CUTLASS: bool = True
|
|
_FUSE_TRANSPOSE: bool = False
|
|
_PROTOTYPE_WARNING_SHOWN: bool = False
|
|
|
|
BACKEND: str
|
|
SPARSE_DISPATCH: Dict[Callable, Callable]
|
|
|
|
packed: Optional[torch.Tensor]
|
|
meta: Optional[torch.Tensor]
|
|
packed_t: Optional[torch.Tensor]
|
|
meta_t: Optional[torch.Tensor]
|
|
compressed_swizzled_bitmask: Optional[torch.Tensor]
|
|
fuse_transpose_cusparselt: bool
|
|
alg_id_cusparselt: int
|
|
|
|
__slots__ = ["packed", "meta", "packed_t", "meta_t", "compressed_swizzled_bitmask"]
|
|
|
|
@staticmethod
|
|
def __new__( # noqa: PYI034
|
|
cls,
|
|
shape: torch.Size,
|
|
packed: Optional[torch.Tensor],
|
|
meta: Optional[torch.Tensor],
|
|
packed_t: Optional[torch.Tensor],
|
|
meta_t: Optional[torch.Tensor],
|
|
compressed_swizzled_bitmask: Optional[torch.Tensor],
|
|
fuse_transpose_cusparselt: bool = False,
|
|
alg_id_cusparselt: int = 0,
|
|
requires_grad: bool = False,
|
|
):
|
|
"""
|
|
Create a new instance of the tensor subclass from the compressed sparse representation.
|
|
|
|
We have the option to create the subclass with the compressed representations of both X and X', for training.
|
|
For inference, we only need a single representation (either X or X'), while the corresponding other set will be None.
|
|
|
|
Depending on the backend selected, certain fields will be set to None. (CUSPARSELT vs CUTLASS)
|
|
|
|
Args:
|
|
shape: The shape of the original dense tensor
|
|
packed: The compressed representation of the original dense tensor
|
|
meta: The metadata of the original dense tensor, if it is stored separately
|
|
packed_t: The compressed representation of the transposed original dense tensor
|
|
meta_t: The metadata of the transposed original dense tensor, if it is stored separately
|
|
compressed_swizzled_bitmask: The masks used by the CUTLASS backend to determine which threads should
|
|
participate in the computation. Used for pointwise ops.
|
|
fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition
|
|
with a matmul, which is useful in the case of 2:4 sparse training.
|
|
alg_id_cusparselt: The algorithm id to use when using cuSPARSELT, will have effect on performance
|
|
|
|
Returns:
|
|
torch.Tensor: A torch.Tensor wrapper subclass.
|
|
|
|
Raises:
|
|
ValueError: If all of the tensor arguments are None.
|
|
"""
|
|
if not cls._PROTOTYPE_WARNING_SHOWN:
|
|
warnings.warn(
|
|
(
|
|
"The PyTorch API of SparseSemiStructuredTensor is in prototype stage "
|
|
"and will change in the near future. Please open a Github issue "
|
|
"for features requests and see our documentation on the torch.sparse "
|
|
"module for further information about the project."
|
|
),
|
|
UserWarning,
|
|
)
|
|
cls._PROTOTYPE_WARNING_SHOWN = True
|
|
|
|
# Because this only runs onces, we also load the dispatch table here as well.
|
|
# We can't define the dispatch table explicitly because of torch.ops import errors, so we do this instead
|
|
# But this is useful since it allows users to overload the dispatch table for debugging / testing.
|
|
cls._load_dispatch_table()
|
|
|
|
# we can also register the classes with dynamo when the warning is shown.
|
|
torch._dynamo.allow_in_graph(cls)
|
|
|
|
if packed is not None:
|
|
previous_tensor = packed
|
|
elif packed_t is not None:
|
|
previous_tensor = packed_t
|
|
else:
|
|
raise ValueError("At least one of packed or packed_t must be provided")
|
|
|
|
kwargs = {
|
|
"device": previous_tensor.device,
|
|
"dtype": previous_tensor.dtype,
|
|
"layout": previous_tensor.layout,
|
|
"requires_grad": requires_grad,
|
|
}
|
|
tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
|
|
|
|
tensor.packed = packed
|
|
tensor.meta = meta
|
|
tensor.packed_t = packed_t
|
|
tensor.meta_t = meta_t
|
|
tensor.compressed_swizzled_bitmask = compressed_swizzled_bitmask
|
|
tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt
|
|
tensor.alg_id_cusparselt = alg_id_cusparselt
|
|
return tensor
|
|
|
|
def __repr__(self) -> str: # type: ignore[override]
|
|
assert hasattr(self, "shape")
|
|
return f"{self.__class__.__name__}(shape={self.shape})"
|
|
|
|
def __tensor_flatten__(
|
|
self,
|
|
) -> Tuple[List[str], Tuple[torch.Size, bool, int, bool]]:
|
|
inner_tensors = list(
|
|
filter(lambda x: getattr(self, x) is not None, self.__slots__)
|
|
)
|
|
tensor_meta = (
|
|
self.shape,
|
|
self.fuse_transpose_cusparselt,
|
|
self.alg_id_cusparselt,
|
|
self.requires_grad,
|
|
)
|
|
return inner_tensors, tensor_meta
|
|
|
|
@classmethod
|
|
def __tensor_unflatten__(
|
|
cls,
|
|
inner_tensors,
|
|
tensor_meta : Tuple[torch.Size, bool, int, bool],
|
|
outer_size,
|
|
outer_stride,
|
|
) -> torch.Tensor:
|
|
shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
|
|
return cls(
|
|
shape=shape,
|
|
packed=inner_tensors.get("packed", None),
|
|
meta=inner_tensors.get("meta", None),
|
|
packed_t=inner_tensors.get("packed_t", None),
|
|
meta_t=inner_tensors.get("meta_t", None),
|
|
compressed_swizzled_bitmask=inner_tensors.get("compressed_swizzled_bitmask", None),
|
|
fuse_transpose_cusparselt=fuse_transpose_cusparselt,
|
|
alg_id_cusparselt=alg_id_cusparselt,
|
|
requires_grad=requires_grad,
|
|
)
|
|
|
|
__torch_function__ = torch._C._disabled_torch_function_impl
|
|
|
|
@classmethod
|
|
def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
|
|
if func._overloadpacket not in cls.SPARSE_DISPATCH:
|
|
raise NotImplementedError(
|
|
f"{cls.__name__} only supports a specific set of operations, "
|
|
f"can't perform requested op ({func.__name__})"
|
|
)
|
|
return cls.SPARSE_DISPATCH[func._overloadpacket](func, types, args, kwargs)
|
|
|
|
@classmethod
|
|
def _load_dispatch_table(cls, custom_dispatch_table=None) -> None:
|
|
"""
|
|
Loads the op overload sparse dispatch table for the current class.
|
|
"""
|
|
if getattr(cls, "SPARSE_DISPATCH", None) is None:
|
|
cls.SPARSE_DISPATCH = {
|
|
torch.ops.aten.values: semi_sparse_values,
|
|
torch.ops.aten.indices: semi_sparse_indices,
|
|
torch.ops.aten.is_same_size: fallback_dispatcher,
|
|
torch.ops.aten.detach_: fallback_dispatcher,
|
|
torch.ops.aten.detach: semi_sparse_detach,
|
|
torch.ops.aten.t: semi_sparse_t,
|
|
torch.ops.aten.view: semi_sparse_view,
|
|
torch.ops.aten.mm: semi_sparse_mm,
|
|
torch.ops.aten.matmul: semi_sparse_mm,
|
|
torch.ops.aten.addmm: semi_sparse_addmm,
|
|
torch.ops.aten.linear: semi_sparse_linear,
|
|
torch.ops.aten._to_copy: fallback_dispatcher,
|
|
}
|
|
if custom_dispatch_table is not None:
|
|
cls.SPARSE_DISPATCH.update(custom_dispatch_table)
|
|
|
|
@classmethod
|
|
def _validate_device_dim_dtype_shape(cls, original_tensor : torch.Tensor) -> None:
|
|
"""
|
|
Assert that the given tensor is valid for semi-structured sparse compression.
|
|
"""
|
|
# check device
|
|
if not original_tensor.is_cuda:
|
|
raise RuntimeError(
|
|
f"Error original_tensor.device= {original_tensor.device} is not supported! "
|
|
"Only CUDA tensors are currently supported."
|
|
)
|
|
|
|
# check dim
|
|
if original_tensor.dim() != 2:
|
|
raise RuntimeError(
|
|
f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
|
|
"Only 2d tensors are currently supported."
|
|
)
|
|
|
|
# check contiguous
|
|
if not original_tensor.is_contiguous():
|
|
raise RuntimeError(
|
|
"Error original_tensor is not contiguous!"
|
|
"Only contiguous tensors are currently supported."
|
|
)
|
|
|
|
# check dtype
|
|
if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS:
|
|
raise RuntimeError(
|
|
f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
|
|
"dtype must be one of: {cls._DTYPE_SHAPE_CONSTRAINTS}"
|
|
)
|
|
|
|
# check shape
|
|
m, n = original_tensor.shape
|
|
min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_rows
|
|
min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_cols
|
|
if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
|
|
# TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples
|
|
raise RuntimeError(
|
|
f"Error original_tensor.shape {original_tensor.shape} is not supported! "
|
|
f"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})"
|
|
)
|
|
|
|
@classmethod
|
|
def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
Calculates padding for dense tensor and pads tensor if necessary.
|
|
If padding is not required, this function returns the original tensor.
|
|
"""
|
|
# only 2d matmul
|
|
assert dense_input.dim() == 2
|
|
|
|
# check shape
|
|
m, n = dense_input.shape
|
|
min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_rows
|
|
min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_cols
|
|
|
|
# calculate padding
|
|
to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0
|
|
to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0
|
|
if to_pad_m or to_pad_n:
|
|
return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m))
|
|
else:
|
|
return dense_input
|
|
|
|
def to_dense(self):
|
|
col = self.shape[-1]
|
|
return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device))
|
|
|
|
@classmethod
|
|
def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensor":
|
|
raise NotImplementedError
|
|
|
|
def _mm(
|
|
self,
|
|
B: torch.Tensor,
|
|
*,
|
|
bias: Optional[torch.Tensor] = None,
|
|
**kwargs,
|
|
) -> torch.Tensor:
|
|
raise NotImplementedError
|
|
|
|
|
|
def to_sparse_semi_structured(
    original_tensor: torch.Tensor,
    transposed: bool = False,
) -> "SparseSemiStructuredTensor":
    """
    This function converts a dense tensor into a sparse semi-structured tensor.
    It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor.

    This function will check to ensure the dense tensor has the right dtype, size, dims, and device.
    We currently only support semi-structured sparse tensors for 2d CUDA tensors.
    Additionally, your tensor must be a positive multiple of the minimum sparse block size, given in
    `_DTYPE_SHAPE_CONSTRAINTS` for each dtype (float32, float16, bfloat16, int8).

    The backend subclass is selected from `SparseSemiStructuredTensor._FORCE_CUTLASS`:
    CUTLASS when it is set, cuSPARSELt otherwise.

    Args:
        original_tensor (Tensor): the dense tensor to convert
        transposed (bool, optional): deprecated arg to be removed in another release. Do not use.

    Returns:
        SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor

    Raises:
        DeprecationWarning: If ``transposed`` is truthy.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
        tensor([[0., 0., 1.,  ..., 0., 1., 1.],
                [0., 0., 1.,  ..., 0., 1., 1.],
                [0., 0., 1.,  ..., 0., 1., 1.],
                ...,
                [0., 0., 1.,  ..., 0., 1., 1.],
                [0., 0., 1.,  ..., 0., 1., 1.],
                [0., 0., 1.,  ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
        >>> A_sparse = to_sparse_semi_structured(A)
        SparseSemiStructuredTensor(shape=torch.Size([128, 128]))
        >>> A_sparse.values()
        tensor([[1., 1., 1.,  ..., 1., 1., 1.],
                [1., 1., 1.,  ..., 1., 1., 1.],
                [1., 1., 1.,  ..., 1., 1., 1.],
                ...,
                [1., 1., 1.,  ..., 1., 1., 1.],
                [1., 1., 1.,  ..., 1., 1., 1.],
                [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16),
        >>> A_sparse.indices()
        tensor([[-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                ...,
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370]], device='cuda:0', dtype=torch.int16))
    """
    if transposed:
        # Kept as a raised exception (not warnings.warn) so existing callers
        # observe the same behavior; only the message text is repaired.
        raise DeprecationWarning(
            "Setting transpose from to_sparse_semi_structured is deprecated and will be removed in a future release. "
            "SparseSemiStructuredTensor only support contiguous input tensors."
        )

    # set from _FORCE_CUTLASS flag
    SPARSE_SUBCLASS = (
        torch.sparse.SparseSemiStructuredTensorCUTLASS
        if SparseSemiStructuredTensor._FORCE_CUTLASS
        else torch.sparse.SparseSemiStructuredTensorCUSPARSELT
    )

    return SPARSE_SUBCLASS.from_dense(original_tensor)
|
|
|
|
class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
    """
    Semi-structured sparse tensor backed by the CUTLASS kernels.

    With this backend the kept (specified) elements and the sparsity metadata live in
    two separate tensors, `packed` and `meta`.

    Compression goes through `sparse_semi_structured_from_dense_cutlass` and matrix
    multiplication through `torch._sparse_semi_structured_(mm|addmm)`. This subclass is
    the one chosen when `_FORCE_CUTLASS` is set, or when cuSPARSELt is not available.
    """

    BACKEND = "cutlass"
    _DTYPE_SHAPE_CONSTRAINTS = {
        torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
        torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
        torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
        torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4),
    }

    @classmethod
    def from_dense(
        cls, original_tensor: torch.Tensor
    ) -> "SparseSemiStructuredTensorCUTLASS":
        """
        Validate ``original_tensor`` and compress it into the CUTLASS representation.

        Only `packed` and `meta` are populated; the transposed representation and the
        bitmask are left as None. `requires_grad` is copied from the input.
        """
        cls._validate_device_dim_dtype_shape(original_tensor)
        packed_values, packed_meta = sparse_semi_structured_from_dense_cutlass(
            original_tensor
        )
        return cls(
            original_tensor.shape,
            packed=packed_values,
            meta=packed_meta,
            packed_t=None,
            meta_t=None,
            compressed_swizzled_bitmask=None,
            requires_grad=original_tensor.requires_grad,
        )

    def to_dense(self):
        """
        Return the dense tensor.

        When `meta` is 2-dimensional the dedicated CUTLASS conversion routine is used;
        otherwise this defers to the base-class implementation.
        """
        assert self.meta is not None and self.packed is not None
        if self.meta.ndim != 2:
            return super().to_dense()
        return sparse_semi_structured_to_dense_cutlass(self.packed, self.meta)

    @classmethod
    def prune_dense_static_sort(cls, original_tensor: torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
        """
        Prune an unpruned dense tensor with a (branchless) static sort over each 4x4 tile.

        The largest values of the tile are picked greedily while keeping the 2:4 sparsity
        constraint along both rows and columns. The pruning algorithm itself is implemented
        in `_sparse_semi_structured_tile`.

        From the pruned tensor this produces `packed` / `meta` (compressed representation of
        the pruned tensor) and `packed_t` / `meta_t` (compressed representation of its
        transpose). The compressed representations cannot be transposed, so both are stored
        for the fw/bw pass respectively.

        A compressed swizzled bitmask encoding the sparsity pattern is computed as well; it
        can be used in the backward pass to mask the gradients.

        [9 1 7 4]                       [9 0 7 0]
        [1 2 3 0]                       [0 2 0 0]
        [8 3 5 4] -> prune 4x4 tile  -> [8 0 0 4] -> pack to CUTLASS semi-structured -> packed
        [1 2 6 2]                       [0 0 6 2]                                    -> metadata

                                                  -> pack to transposed CUTLASS      -> packed_t
                                                     semi-structured representation  -> metadata_t

                                                  -> compute swizzled bitmask        -> compressed_swizzled_bitmask


        Equivalent PyTorch code producing the same five outputs from the dense tensor:
        ```
        from torch.sparse import SparseSemiStructuredTensorCUTLASS
        from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask

        pruned = _sparse_semi_structured_tile(dense)
        packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
        packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
        bitmask = _compute_compressed_swizzled_bitmask(pruned)

        SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask)
        ```
        """
        # use_cutlass=True requests the CUTLASS (rather than cuSPARSELt) packing.
        tiled = torch._sparse_semi_structured_tile(
            original_tensor,
            algorithm=algorithm,
            use_cutlass=True,
        )
        packed, meta, packed_t, meta_t, bitmask = tiled

        return cls(
            original_tensor.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
            requires_grad=False,
        )

    def _mm(
        self,
        B: torch.Tensor,
        *,
        bias: Optional[torch.Tensor] = None,
        **kwargs
    ) -> torch.Tensor:
        """
        Compute ``self @ B`` (plus ``bias`` when given) with the CUTLASS kernels.

        Raises:
            ValueError: If ``B`` is itself a SparseSemiStructuredTensor.
            NotImplementedError: If either operand is not 2d, or if `packed` / `meta`
                is missing on ``self``.
        """
        if isinstance(B, SparseSemiStructuredTensor):
            raise ValueError(
                "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
            )

        if self.ndim != 2 or B.ndim != 2:
            raise NotImplementedError(
                f"`{self.__class__.__name__}` matmul: Broadcasting is not implemented"
            )

        if self.packed is None or self.meta is None:
            raise NotImplementedError(
                f"`{self.__class__.__name__}` matmul: operation is not supported"
            )

        if bias is not None:
            product = torch._sparse_semi_structured_addmm(
                bias, self.packed, self.meta, B
            )
        else:
            product = torch._sparse_semi_structured_mm(self.packed, self.meta, B)

        # Keep only the first self.shape[0] rows of the kernel output.
        return product[: self.shape[0]]
|
|
|
|
|
|
class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
    """
    The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor:
    packed = [ specified elements of original tensor | metadata ]
    For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements
    The rest of the tensor is metadata. Since there is only one tensor, we only use the packed and packed_t
    attributes respectively.

    cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
    as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
    """

    BACKEND = "cusparselt"
    _DTYPE_SHAPE_CONSTRAINTS = {
        torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
        torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
        torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
        torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(8, 8, 4, 4),
    }

    @classmethod
    def from_dense(cls, original_tensor: torch.Tensor) -> "SparseSemiStructuredTensorCUSPARSELT":
        """
        Validate ``original_tensor`` and compress it with `torch._cslt_compress`.

        Only `packed` is populated. The transpose-fusion flag and algorithm id are taken
        from the `_FUSE_TRANSPOSE` and `_DEFAULT_ALG_ID` class settings on
        `SparseSemiStructuredTensor`, and `requires_grad` is copied from the input.
        """
        cls._validate_device_dim_dtype_shape(original_tensor)
        return cls(
            shape=original_tensor.shape,
            packed=torch._cslt_compress(original_tensor),
            meta=None,
            packed_t=None,
            meta_t=None,
            compressed_swizzled_bitmask=None,
            fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE,
            alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID,
            requires_grad=original_tensor.requires_grad,
        )

    @classmethod
    def prune_dense_static_sort(cls, original_tensor: torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
        """
        This function does the same thing as described in SparseSemiStructuredTensorCUTLASS, but uses the
        cuSPARSELt metadata layout and sparse matmul.

        The only functional difference is that cuSPARSELt stores `metadata` and `packed` together into a single tensor.

        [9 1 7 4]                       [9 0 7 0]
        [1 2 3 0]                       [0 2 0 0]
        [8 3 5 4] -> prune 4x4 tile  -> [8 0 0 4] -> pack to cuSPARSELT semi-structured -> packed
        [1 2 6 2]                       [0 0 6 2]

                                                  -> pack to transposed cuSPARSELt      -> packed_t
                                                     semi-structured representation

                                                  -> compute swizzled bitmask           -> compressed_swizzled_bitmask


        The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below:
        ```
        from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
        from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask

        pruned = _sparse_semi_structured_tile(dense)
        packed_cusparselt = torch._cslt_compress(pruned)
        packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
        bitmask = _compute_compressed_swizzled_bitmask(pruned)

        SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cusparselt, None, packed_t_cusparselt, None, bitmask)
        ```
        """
        # use_cutlass=False requests the cuSPARSELt packing.
        (packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
            original_tensor,
            algorithm=algorithm,
            use_cutlass=False)

        return cls(
            original_tensor.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=compressed_swizzled_bitmask,
            requires_grad=False,
        )

    def _mm(
        self,
        B: torch.Tensor,
        *,
        bias: Optional[torch.Tensor] = None,
        **kwargs
    ) -> torch.Tensor:
        """
        Compute ``self @ B`` (plus ``bias`` when given) with `torch._cslt_sparse_mm`.

        Raises:
            ValueError: If ``B`` is itself a SparseSemiStructuredTensor.
            NotImplementedError: If either operand is not 2d, if ``B`` or ``bias`` has a
                dtype different from ``self``, or if `packed` is missing on ``self``.
        """
        if isinstance(B, SparseSemiStructuredTensor):
            raise ValueError(
                "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
            )
        if self.ndim != 2 or B.ndim != 2:
            raise NotImplementedError(
                f"`{self.__class__.__name__}` matmul: Broadcasting is not implemented"
            )
        if B.dtype != self.dtype:
            raise NotImplementedError(
                f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, "
                f"with A.dtype={self.dtype} and B.dtype={B.dtype}. "
                "This operation is only supported when A and B have the same data type."
            )
        if bias is not None and bias.dtype != self.dtype:
            # A and B are already known to share a dtype here, so the
            # mismatching one to report is the bias (C).
            raise NotImplementedError(
                f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, "
                f"with A.dtype=B.dtype={self.dtype} and C.dtype={bias.dtype}. "
                "This operation is only supported when A, B and C have the same data type."
            )
        if self.packed is None:
            raise NotImplementedError(
                f"`{self.__class__.__name__}` matmul: operation is not supported"
            )
        else:
            res = torch._cslt_sparse_mm(
                self.packed,
                B,
                bias=bias,
                transpose_result=self.fuse_transpose_cusparselt,
                alg_id=self.alg_id_cusparselt,
            )
            # With transpose fusion the kernel returns the transposed result,
            # so transpose it back before returning.
            return res.t() if self.fuse_transpose_cusparselt else res
|