Files
pytorch/torch/sparse/semi_structured.py
Jesse Cai 14b2273b0c [sparse] Add fast semi-structured spasification kernels (#122350)
This PR adds in fast semi-structured sparsification kernels to PyTorch.

These kernels allow for accelerated semi-structured sparsification
kernels in PyTorch.

The kernels have been added as aten native functions

In particular, three new functions have been added:

* `torch._sparse_semi_structured_tile`

This function will return the packed representation and metadata for
both X and X', as well as the thread masks. Note that this applies 2:4
sparsity in a 4x4 tile instead of a 1x4 strip as usual.

* `torch._sparse_semi_structured_apply`

This function takes in an input tensor and thread masks from the above
function and returns a packed representation and metadata from applying
thread masks to the input tensor.

* `torch._sparse_semi_structured_apply_dense`

This function does the same thing as above but instead of returning the
tensor in the sparse representation it returns it in the dense
representation

The subclasses have also been updated to add a new
`prune_dense_static_sort`
classmethod to create sparse tensors with this format. I've added some
additional documentatino on how to calculate the compressed tensors
needed to create a SparseSemiStructuredTensor oneself.

To this end, there are two new helper functions added:
`sparse_semi_structured_tile`
`compute_compressed_swizzled_bitmask`

Differential Revision: [D56190801](https://our.internmc.facebook.com/intern/diff/D56190801)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122350
Approved by: https://github.com/cpuhrsch
2024-04-16 20:31:52 +00:00

630 lines
27 KiB
Python

import warnings
from collections import namedtuple
from typing import Any, Optional, Tuple, List, Callable, Dict
import torch
from torch.sparse._semi_structured_conversions import (
sparse_semi_structured_from_dense_cutlass,
sparse_semi_structured_to_dense_cutlass
)
from torch.sparse._semi_structured_ops import (
fallback_dispatcher,
semi_sparse_values,
semi_sparse_indices,
semi_sparse_detach,
semi_sparse_t,
semi_sparse_view,
semi_sparse_mm,
semi_sparse_addmm,
semi_sparse_linear,
)
__all__ = [
"SparseSemiStructuredTensor",
"SparseSemiStructuredTensorCUTLASS",
"SparseSemiStructuredTensorCUSPARSELT",
"to_sparse_semi_structured",
]
_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
"_SEMI_STRUCTURED_SPARSE_CONFIG",
"sparse_min_rows sparse_min_cols dense_min_rows dense_min_cols",
)
class SparseSemiStructuredTensor(torch.Tensor):
"""
This class implementes semi-structured sparsity as a Tensor subclass.
Semi-structured sparsity describes a sparsity pattern where n in every 2n elements are sparse,
depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
structured sparsity.
There are two backends available for semi_structred sparsity, either cuSPARSELt or CUTLASS.
This class is meant to serve as a base class for both implementations. SparseSemiStructuredCUTLASS
and SparseSemiStructuredCUSPARSELT both inherit from this class and define three backend-specific items.
Note that as such, this class cannot be insantiated directly.
-`_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend specific dense/sparse min shape constraints
- `def from_dense()` - backend specific compression routines
- `def _mm()` - backend specifc mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_(mm|addmm))
"""
_DEFAULT_ALG_ID: int = 0
_DTYPE_SHAPE_CONSTRAINTS: Dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
_FORCE_CUTLASS: bool = True
_FUSE_TRANSPOSE: bool = False
_PROTOTYPE_WARNING_SHOWN: bool = False
BACKEND: str
SPARSE_DISPATCH: Dict[Callable, Callable]
packed: Optional[torch.Tensor]
meta: Optional[torch.Tensor]
packed_t: Optional[torch.Tensor]
meta_t: Optional[torch.Tensor]
compressed_swizzled_bitmask: Optional[torch.Tensor]
fuse_transpose_cusparselt: bool
alg_id_cusparselt: int
__slots__ = ["packed", "meta", "packed_t", "meta_t", "compressed_swizzled_bitmask"]
@staticmethod
def __new__( # noqa: PYI034
cls,
shape: torch.Size,
packed: Optional[torch.Tensor],
meta: Optional[torch.Tensor],
packed_t: Optional[torch.Tensor],
meta_t: Optional[torch.Tensor],
compressed_swizzled_bitmask: Optional[torch.Tensor],
fuse_transpose_cusparselt: bool = False,
alg_id_cusparselt: int = 0,
requires_grad: bool = False,
):
"""
Create a new instance of the tensor subclass from the compressed sparse representation.
We have the option to create the subclass with the compressed representations of both X and X', for training.
For inference, we only need a single representation (either X or X'), while the corresponding other set will be None.
Depending on the backend selected, certain fields will be set to None. (CUSPARSELT vs CUTLASS)
Args:
shape: The shape of the original dense tensor
packed: The compressed representation of the original dense tensor
meta: The metadata of the original dense tensor, if it is stored separately
packed_t: The compressed representation of the transposed original dense tensor
meta_t: The metadata of the transposed original dense tensor, if it is stored separately
compressed_swizzled_bitmask: The masks used by the CUTLASS backend to determine which threads should
participate in the computation. Used for pointwise ops.
fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition
with a matmul, which is useful in the case of 2:4 sparse training.
alg_id_cusparselt: The algorithm id to use when using cuSPARSELT, will have effect on performance
Returns:
torch.Tensor: A torch.Tensor wrapper subclass.
Raises:
ValueError: If all of the tensor arguments are None.
"""
if not cls._PROTOTYPE_WARNING_SHOWN:
warnings.warn(
(
"The PyTorch API of SparseSemiStructuredTensor is in prototype stage "
"and will change in the near future. Please open a Github issue "
"for features requests and see our documentation on the torch.sparse "
"module for further information about the project."
),
UserWarning,
)
cls._PROTOTYPE_WARNING_SHOWN = True
# Because this only runs onces, we also load the dispatch table here as well.
# We can't define the dispatch table explicitly because of torch.ops import errors, so we do this instead
# But this is useful since it allows users to overload the dispatch table for debugging / testing.
cls._load_dispatch_table()
# we can also register the classes with dynamo when the warning is shown.
torch._dynamo.allow_in_graph(cls)
if packed is not None:
previous_tensor = packed
elif packed_t is not None:
previous_tensor = packed_t
else:
raise ValueError("At least one of packed or packed_t must be provided")
kwargs = {
"device": previous_tensor.device,
"dtype": previous_tensor.dtype,
"layout": previous_tensor.layout,
"requires_grad": requires_grad,
}
tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
tensor.packed = packed
tensor.meta = meta
tensor.packed_t = packed_t
tensor.meta_t = meta_t
tensor.compressed_swizzled_bitmask = compressed_swizzled_bitmask
tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt
tensor.alg_id_cusparselt = alg_id_cusparselt
return tensor
def __repr__(self) -> str: # type: ignore[override]
assert hasattr(self, "shape")
return f"{self.__class__.__name__}(shape={self.shape})"
def __tensor_flatten__(
self,
) -> Tuple[List[str], Tuple[torch.Size, bool, int, bool]]:
inner_tensors = list(
filter(lambda x: getattr(self, x) is not None, self.__slots__)
)
tensor_meta = (
self.shape,
self.fuse_transpose_cusparselt,
self.alg_id_cusparselt,
self.requires_grad,
)
return inner_tensors, tensor_meta
@classmethod
def __tensor_unflatten__(
cls,
inner_tensors,
tensor_meta : Tuple[torch.Size, bool, int, bool],
outer_size,
outer_stride,
) -> torch.Tensor:
shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
return cls(
shape=shape,
packed=inner_tensors.get("packed", None),
meta=inner_tensors.get("meta", None),
packed_t=inner_tensors.get("packed_t", None),
meta_t=inner_tensors.get("meta_t", None),
compressed_swizzled_bitmask=inner_tensors.get("compressed_swizzled_bitmask", None),
fuse_transpose_cusparselt=fuse_transpose_cusparselt,
alg_id_cusparselt=alg_id_cusparselt,
requires_grad=requires_grad,
)
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
if func._overloadpacket not in cls.SPARSE_DISPATCH:
raise NotImplementedError(
f"{cls.__name__} only supports a specific set of operations, "
f"can't perform requested op ({func.__name__})"
)
return cls.SPARSE_DISPATCH[func._overloadpacket](func, types, args, kwargs)
@classmethod
def _load_dispatch_table(cls, custom_dispatch_table=None) -> None:
"""
Loads the op overload sparse dispatch table for the current class.
"""
if getattr(cls, "SPARSE_DISPATCH", None) is None:
cls.SPARSE_DISPATCH = {
torch.ops.aten.values: semi_sparse_values,
torch.ops.aten.indices: semi_sparse_indices,
torch.ops.aten.is_same_size: fallback_dispatcher,
torch.ops.aten.detach_: fallback_dispatcher,
torch.ops.aten.detach: semi_sparse_detach,
torch.ops.aten.t: semi_sparse_t,
torch.ops.aten.view: semi_sparse_view,
torch.ops.aten.mm: semi_sparse_mm,
torch.ops.aten.matmul: semi_sparse_mm,
torch.ops.aten.addmm: semi_sparse_addmm,
torch.ops.aten.linear: semi_sparse_linear,
torch.ops.aten._to_copy: fallback_dispatcher,
}
if custom_dispatch_table is not None:
cls.SPARSE_DISPATCH.update(custom_dispatch_table)
@classmethod
def _validate_device_dim_dtype_shape(cls, original_tensor : torch.Tensor) -> None:
"""
Assert that the given tensor is valid for semi-structured sparse compression.
"""
# check device
if not original_tensor.is_cuda:
raise RuntimeError(
f"Error original_tensor.device= {original_tensor.device} is not supported! "
"Only CUDA tensors are currently supported."
)
# check dim
if original_tensor.dim() != 2:
raise RuntimeError(
f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
"Only 2d tensors are currently supported."
)
# check contiguous
if not original_tensor.is_contiguous():
raise RuntimeError(
"Error original_tensor is not contiguous!"
"Only contiguous tensors are currently supported."
)
# check dtype
if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS:
raise RuntimeError(
f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
"dtype must be one of: {cls._DTYPE_SHAPE_CONSTRAINTS}"
)
# check shape
m, n = original_tensor.shape
min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_rows
min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_cols
if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
# TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples
raise RuntimeError(
f"Error original_tensor.shape {original_tensor.shape} is not supported! "
f"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})"
)
@classmethod
def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor:
"""
Calculates padding for dense tensor and pads tensor if necessary.
If padding is not required, this function returns the original tensor.
"""
# only 2d matmul
assert dense_input.dim() == 2
# check shape
m, n = dense_input.shape
min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_rows
min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_cols
# calculate padding
to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0
to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0
if to_pad_m or to_pad_n:
return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m))
else:
return dense_input
def to_dense(self):
col = self.shape[-1]
return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device))
@classmethod
def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensor":
raise NotImplementedError
def _mm(
self,
B: torch.Tensor,
*,
bias: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
raise NotImplementedError
def to_sparse_semi_structured(
original_tensor: torch.Tensor,
transposed: bool = False,
) -> SparseSemiStructuredTensor:
"""
This function converts a dense tensor into a sparse semi-structured tensor.
It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor.
This function will check to ensure the dense tensor has the right dtype, size, dims, and device.
We currently only support semi-structured sparse tensors for 2d CUDA tensors.
Additionally, your tensor must be a positive multiple of the mininum sparse block size, given in
`_DTYPE_TO_SHAPE_CONSTRAINTS` for each dtype (float32, float16, bfloat16, int8).
Args:
original_tensor (Tensor): the dense tensor to convert
transposed (bool, optional): deprecated arg to be removed in another release. Do not use.
Returns:
SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor
Raises:
None
Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
tensor([[0., 0., 1., ..., 0., 1., 1.],
[0., 0., 1., ..., 0., 1., 1.],
[0., 0., 1., ..., 0., 1., 1.],
...,
[0., 0., 1., ..., 0., 1., 1.],
[0., 0., 1., ..., 0., 1., 1.],
[0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
>>> A_sparse = to_sparse_semi_structured(A)
SparseSemiStructuredTensor(shape=torch.Size([128, 128]))
>>> A_sparse.values()
tensor([[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16),
>>> A_sparse.indices()
tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
...,
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0', dtype=torch.int16))
"""
if transposed:
raise DeprecationWarning(
"Setting transpose from to_sparse_semi_structured is deprecated and will be removed in a future release."
"SparseSemiStructuredTensor only support contiguous input tensors. "
)
# set from _FORCE_CUTLASS flag
SPARSE_SUBCLASS = (
torch.sparse.SparseSemiStructuredTensorCUTLASS
if SparseSemiStructuredTensor._FORCE_CUTLASS
else torch.sparse.SparseSemiStructuredTensorCUSPARSELT
)
return SPARSE_SUBCLASS.from_dense(original_tensor)
class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
"""
This class implements semi-structured sparsity for the CUTLASS backend.
In this implementation, the specified elements and metadata are stored seprately,
in packed and meta respectively.
When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and
sparse_semi_structured_from_dense for conversion to the compressed format.
"""
BACKEND = "cutlass"
_DTYPE_SHAPE_CONSTRAINTS = {
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4),
}
@classmethod
def from_dense(
cls, original_tensor: torch.Tensor
) -> "SparseSemiStructuredTensorCUTLASS":
cls._validate_device_dim_dtype_shape(original_tensor)
(
sparse_tensor_cutlass,
meta_tensor_cutlass,
) = sparse_semi_structured_from_dense_cutlass(original_tensor)
return cls(
original_tensor.shape,
packed=sparse_tensor_cutlass,
meta=meta_tensor_cutlass,
packed_t=None,
meta_t=None,
compressed_swizzled_bitmask=None,
requires_grad=original_tensor.requires_grad,
)
def to_dense(self):
assert self.meta is not None and self.packed is not None
return sparse_semi_structured_to_dense_cutlass(
self.packed,
self.meta,
) if self.meta.ndim == 2 else super().to_dense()
@classmethod
def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
"""
This function takes in a unpruned dense tensor and runs a (branchless) static sort across a 4x4 tile.
It greedily picks the largest values in the tile, upholding the 2:4 sparsity constraint across both rows and columns.
The algorithm used to prune the matrix is implemented in `_sparse_semi_structured_tile`.
Then it creates the packed and meta tensors for the compressed sparse representation of the pruned dense tensor.
It also calculates the packed_t and meta_t tensors for the compressed sparse representation of the transposed
pruned dense tensor.
Since we cannot transpose the compressed representations, we store both for the fw/bw pass respectively.
Finally, this function also computes a compressed swizzled bitmask that encodes the sparsity pattern
This can be used in the backward pass to mask the gradients.
[9 1 7 4] [9 0 7 0]
[1 2 3 0] [0 2 0 0]
[8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to CUTLASS semi-structured -> packed
[1 2 6 2] [0 0 6 2] -> metadata
-> pack to transposed CUTLASS -> packed_t
semi-structured representation -> metadata_t
-> compute swizzled bitmask -> compressed_swizzled_bitmask
The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below:
```
from torch.sparse import SparseSemiStructuredTensorCUTLASS
from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
pruned = _sparse_semi_structured_tile(dense)
packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
bitmask = _compute_compressed_swizzled_bitmask(pruned)
SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask)
```
"""
# We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.
(packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
original_tensor,
algorithm=algorithm,
use_cutlass=True)
return cls(
original_tensor.shape,
packed=packed,
meta=meta,
packed_t=packed_t,
meta_t=meta_t,
compressed_swizzled_bitmask=compressed_swizzled_bitmask,
requires_grad=False,
)
def _mm(
self,
B: torch.Tensor,
*,
bias: Optional[torch.Tensor] = None,
**kwargs
) -> torch.Tensor:
if isinstance(B, SparseSemiStructuredTensor):
raise ValueError(
"`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
)
cls_name = self.__class__.__name__
if self.ndim != 2 or B.ndim != 2:
raise NotImplementedError(
f"`{cls_name}` matmul: Broadcasting is not implemented"
)
if self.packed is None or self.meta is None:
raise NotImplementedError(
f"`{cls_name}` matmul: operation is not supported"
)
else:
if bias is None:
res = torch._sparse_semi_structured_mm(
self.packed, self.meta, B
)
else:
res = torch._sparse_semi_structured_addmm(
bias, self.packed, self.meta, B
)
return res[: self.shape[0]]
class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
"""
The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor:
packed = [ specified elements of original tensor | metadata ]
For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements
The rest of the tensor is metadata. Since there is only one tensor, we only use the packed and packed_t
attributes respectively.
cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
"""
BACKEND = "cusparselt"
_DTYPE_SHAPE_CONSTRAINTS = {
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(8, 8, 4, 4),
}
@classmethod
def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensorCUSPARSELT":
cls._validate_device_dim_dtype_shape(original_tensor)
return cls(
shape=original_tensor.shape,
packed=torch._cslt_compress(original_tensor),
meta=None,
packed_t=None,
meta_t=None,
compressed_swizzled_bitmask=None,
fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE,
alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID,
requires_grad=original_tensor.requires_grad,
)
@classmethod
def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
"""
This function does the same thing as described in SparseSemiStructuredCUTLASS, but uses the cuSPASRELt metadata
layout and sparse matmul.
The only functional difference is that cuSPARSELt stores `metadata` and `packed` together into a single tensor.
[9 1 7 4] [9 0 7 0]
[1 2 3 0] [0 2 0 0]
[8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to cuSPARSELT semi-structured -> packed
[1 2 6 2] [0 0 6 2]
-> pack to transposed cuSPARSELt -> packed_t
semi-structured representation
-> compute swizzled bitmask -> compressed_swizzled_bitmask
The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below:
```
from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
pruned = _sparse_semi_structured_tile(dense)
packed_cusparselt = torch._cslt_compress(pruned)
packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
bitmask = _compute_compressed_swizzled_bitmask(pruned)
SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask)
```
"""
(packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
original_tensor,
algorithm=algorithm,
use_cutlass=False)
return cls(
original_tensor.shape,
packed=packed,
meta=meta,
packed_t=packed_t,
meta_t=meta_t,
compressed_swizzled_bitmask=compressed_swizzled_bitmask,
requires_grad=False,
)
def _mm(
self,
B: torch.Tensor,
*,
bias: Optional[torch.Tensor] = None,
**kwargs
) -> torch.Tensor:
if isinstance(B, SparseSemiStructuredTensor):
raise ValueError(
"`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
)
if self.ndim != 2 or B.ndim != 2:
raise NotImplementedError(
f"`{self.__class__.__name__}` matmul: Broadcasting is not implemented"
)
if B.dtype != self.dtype:
raise NotImplementedError(
f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, "
f"with A.dtype={self.dtype} and B.dtype={B.dtype}. "
"This operation is only supported when A and B have the same data type."
)
if bias is not None and bias.dtype != self.dtype:
raise NotImplementedError(
f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, "
"with A.dtype=B.dtype={self.dtype} and C.dtype={B.dtype}. "
"This operation is only supported when A, B and C have the same data type."
)
if self.packed is None:
raise NotImplementedError(
f"`{self.__class__.__name__}` matmul: operation is not supported"
)
else:
res = torch._cslt_sparse_mm(
self.packed,
B,
bias=bias,
transpose_result=self.fuse_transpose_cusparselt,
alg_id=self.alg_id_cusparselt,
)
return res.t() if self.fuse_transpose_cusparselt else res