Enable UFMT on all of torch/sparse (#130545)

Partially addresses #123062

Ran lintrunner on:
- torch/sparse

Detail:
```
$ lintrunner -a --take UFMT --all-files
ok No lint issues.
Successfully applied all patches.
```

@ezyang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130545
Approved by: https://github.com/ezyang

Committed by: PyTorch MergeBot
Parent: 7d4f50de19
Commit: 535016967a
.lintrunner.toml
```diff
@@ -1531,10 +1531,6 @@ exclude_patterns = [
     'torch/signal/__init__.py',
     'torch/signal/windows/__init__.py',
     'torch/signal/windows/windows.py',
-    'torch/sparse/__init__.py',
-    'torch/sparse/_semi_structured_conversions.py',
-    'torch/sparse/_triton_ops.py',
-    'torch/sparse/semi_structured.py',
     'torch/special/__init__.py',
     'torch/testing/_internal/__init__.py',
     'torch/testing/_internal/autocast_test_lists.py',
```
torch/sparse/__init__.py
```diff
@@ -1,23 +1,23 @@
 # mypy: allow-untyped-defs
 # The Tensor classes are added to this module by python_tensor.cpp
-from typing import Optional, Tuple, List, Union, Any
+# A workaround to support both TorchScript and MyPy:
+from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union

 import torch
-from torch._C import _add_docstr, _sparse  # type: ignore[attr-defined]
 from torch import Tensor
+from torch._C import _add_docstr, _sparse  # type: ignore[attr-defined]

 # Semi structured sparsity support
 from .semi_structured import (
     SparseSemiStructuredTensor,
     SparseSemiStructuredTensorCUSPARSELT,
     SparseSemiStructuredTensorCUTLASS,
-    to_sparse_semi_structured
+    to_sparse_semi_structured,
 )

-# A workaround to support both TorchScript and MyPy:
-from typing import TYPE_CHECKING
+
 if TYPE_CHECKING:
     from torch.types import _dtype as DType

     DimOrDims = Optional[Union[int, Tuple[int, ...], List[int]]]
 else:
     # The JIT doesn't understand Union, nor torch.dtype here
```
```diff
@@ -26,20 +26,22 @@ else:


 __all__ = [
-    'addmm',
-    'check_sparse_tensor_invariants',
-    'mm',
-    'sum',
-    'softmax',
-    'log_softmax',
-    'SparseSemiStructuredTensor',
-    'SparseSemiStructuredTensorCUTLASS',
-    'SparseSemiStructuredTensorCUSPARSELT',
-    'to_sparse_semi_structured',
-    'as_sparse_gradcheck',
+    "addmm",
+    "check_sparse_tensor_invariants",
+    "mm",
+    "sum",
+    "softmax",
+    "log_softmax",
+    "SparseSemiStructuredTensor",
+    "SparseSemiStructuredTensorCUTLASS",
+    "SparseSemiStructuredTensorCUSPARSELT",
+    "to_sparse_semi_structured",
+    "as_sparse_gradcheck",
 ]

-addmm = _add_docstr(_sparse._sparse_addmm, r"""
+addmm = _add_docstr(
+    _sparse._sparse_addmm,
+    r"""
 sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor

 This function does exact same thing as :func:`torch.addmm` in the forward,
```
```diff
@@ -58,10 +60,13 @@ Args:
     mat2 (Tensor): a dense matrix to be multiplied
     beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`)
     alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
-""")
+""",
+)


-mm = _add_docstr(_sparse._sparse_mm, r"""
+mm = _add_docstr(
+    _sparse._sparse_mm,
+    r"""
 Performs a matrix multiplication of the sparse matrix :attr:`mat1`
 and the (sparse or strided) matrix :attr:`mat2`. Similar to :func:`torch.mm`, if :attr:`mat1` is a
 :math:`(n \times m)` tensor, :attr:`mat2` is a :math:`(m \times p)` tensor, out will be a
```
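Reviewer note (not part of the diff): a minimal usage sketch of the `torch.sparse.addmm` wrapper reformatted above; the tensors and values are illustrative only.

```python
import torch

# Illustrative sketch: `mat` and `mat2` are dense, `mat1` is sparse COO.
mat = torch.zeros(2, 2)
mat1 = torch.eye(2).to_sparse()          # 2x2 sparse identity
mat2 = torch.arange(4.0).reshape(2, 2)   # dense right-hand operand
# In the forward this matches torch.addmm(mat, mat1.to_dense(), mat2).
out = torch.sparse.addmm(mat, mat1, mat2, beta=1.0, alpha=1.0)
print(out)  # tensor([[0., 1.], [2., 3.]])
```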
```diff
@@ -132,10 +137,13 @@ Example::
     >>> y2
     tensor([[0., 1.],
             [6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)
-""")
+""",
+)


-sampled_addmm = _add_docstr(_sparse.sparse_sampled_addmm, r"""
+sampled_addmm = _add_docstr(
+    _sparse.sparse_sampled_addmm,
+    r"""
 sparse.sampled_addmm(input, mat1, mat2, *, beta=1., alpha=1., out=None) -> Tensor

 Performs a matrix multiplication of the dense matrices :attr:`mat1` and :attr:`mat2` at the locations
```
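A small sketch of `torch.sparse.mm`, whose docstring wrapper is reformatted above (illustrative values, not from the diff):

```python
import torch

# Illustrative sketch of torch.sparse.mm with a sparse COO first operand.
a = torch.eye(3).to_sparse()   # (3, 3) sparse
b = torch.ones(3, 2)           # (3, 2) dense
print(torch.sparse.mm(a, b))   # dense (3, 2) result of a @ b
```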
```diff
@@ -184,10 +192,11 @@ Examples::
        col_indices=tensor([0, 1, 2]),
        values=tensor([ 0.1423, -0.3903, -0.0950]), device='cuda:0',
        size=(3, 3), nnz=3, layout=torch.sparse_csr)
-""")
+""",
+)

-def sum(input: Tensor, dim: DimOrDims = None,
-        dtype: Optional[DType] = None) -> Tensor:
+
+def sum(input: Tensor, dim: DimOrDims = None, dtype: Optional[DType] = None) -> Tensor:
     r"""Return the sum of each row of the given sparse tensor.

     Returns the sum of each row of the sparse tensor :attr:`input` in the given
```
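For context, a hedged sketch of `sampled_addmm`, whose wrapper closes in the hunk above; it mirrors the docstring's CUDA example with made-up sizes:

```python
import torch

# Illustrative sketch: sampled_addmm evaluates mat1 @ mat2 only at the
# sparsity pattern of the CSR `input` (CUDA, as in the docstring example).
input_csr = torch.eye(3, device="cuda").to_sparse_csr()
mat1 = torch.randn(3, 4, device="cuda")
mat2 = torch.randn(4, 3, device="cuda")
out = torch.sparse.sampled_addmm(input_csr, mat1, mat2)  # CSR result
```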
```diff
@@ -256,7 +265,9 @@ def sum(input: Tensor, dim: DimOrDims = None,
         return torch._sparse_sum(input, dtype=dtype)


-softmax = _add_docstr(_sparse._sparse_softmax, r"""
+softmax = _add_docstr(
+    _sparse._sparse_softmax,
+    r"""
 sparse.softmax(input, dim, *, dtype=None) -> Tensor

 Applies a softmax function.
```
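A quick sketch of the `sum` function whose signature was collapsed onto one line above (illustrative values):

```python
import torch

# Illustrative sketch of torch.sparse.sum on a sparse COO tensor.
s = torch.eye(3).to_sparse()
print(torch.sparse.sum(s))         # sum of all elements -> dense scalar tensor(3.)
print(torch.sparse.sum(s, dim=0))  # sum along dim 0 -> sparse 1-D tensor
```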
```diff
@@ -281,10 +292,13 @@ Args:
         casted to :attr:`dtype` before the operation is
         performed. This is useful for preventing data type
         overflows. Default: None
-""")
+""",
+)


-log_softmax = _add_docstr(_sparse._sparse_log_softmax, r"""
+log_softmax = _add_docstr(
+    _sparse._sparse_log_softmax,
+    r"""
 sparse.log_softmax(input, dim, *, dtype=None) -> Tensor

 Applies a softmax function followed by logarithm.
```
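A sketch of sparse `softmax` (the `log_softmax` wrapper reformatted above behaves analogously); values are illustrative:

```python
import torch

# Illustrative sketch: sparse softmax treats unspecified elements as -inf,
# so normalization runs over the specified entries of each row only.
x = torch.tensor([[0.0, 0.0, 1.0], [2.0, 0.0, 3.0]]).to_sparse()
y = torch.sparse.softmax(x, dim=1)
print(y.to_dense())
```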
```diff
@@ -299,7 +313,8 @@ Args:
         casted to :attr:`dtype` before the operation is
         performed. This is useful for preventing data type
         overflows. Default: None
-""")
+""",
+)


 spdiags = _add_docstr(
```
```diff
@@ -393,7 +408,8 @@ Specifying a positive offset::
             [0, 0, 3, 0, 0],
             [0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0]])
-""")
+""",
+)


 class check_sparse_tensor_invariants:
```
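For reference, a hedged sketch of `spdiags`, whose docstring closes in the hunk above (illustrative shapes and offsets):

```python
import torch

# Illustrative sketch of spdiags: rows of `diagonals` are placed on the given
# offsets of a 5x5 sparse matrix (0 = main diagonal, 1 = first superdiagonal).
diagonals = torch.arange(10).reshape(2, 5)
out = torch.sparse.spdiags(diagonals, torch.tensor([0, 1]), (5, 5))
print(out.to_dense())
```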
```diff
@@ -483,12 +499,14 @@ class check_sparse_tensor_invariants:
     # context manager support
     def __init__(self, enable=True):
         self.state = enable
-        self.saved_state : Optional[bool] = None
+        self.saved_state: Optional[bool] = None

     def __enter__(self):
         if self.saved_state is not None:
-            raise RuntimeError('This context manager instance is already activated.'
-                               ' Use a different context manager instance for context nesting.')
+            raise RuntimeError(
+                "This context manager instance is already activated."
+                " Use a different context manager instance for context nesting."
+            )
         self.saved_state = self.is_enabled()
         torch._C._set_check_sparse_tensor_invariants(self.state)
```
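A sketch of the context-manager protocol implemented by `__enter__` above (illustrative tensor values):

```python
import torch

# Illustrative sketch: invariant checking is enabled inside the block
# and the previous state is restored on exit.
with torch.sparse.check_sparse_tensor_invariants():
    # malformed indices would raise here instead of creating a corrupt tensor
    t = torch.sparse_coo_tensor([[0], [1]], [1.0], size=(2, 2))
print(torch.sparse.check_sparse_tensor_invariants.is_enabled())
```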
```diff
@@ -499,7 +517,6 @@ class check_sparse_tensor_invariants:

     # decorator support
     def __call__(self, mth):
-
         def test_mth(*args, **kwargs):
             with type(self)(self.state):
                 return mth(*args, **kwargs)
```
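And the decorator form handled by `__call__` above, as a minimal sketch:

```python
import torch

# Illustrative sketch: the class instance also works as a decorator.
@torch.sparse.check_sparse_tensor_invariants(enable=True)
def build(indices, values):
    return torch.sparse_coo_tensor(indices, values, size=(2, 2))

t = build([[0], [1]], [1.0])  # constructed with invariant checking enabled
```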
```diff
@@ -531,37 +548,71 @@ def as_sparse_gradcheck(gradcheck):

         Same as :func:`torch.autograd.gradcheck` but with sparse tensors inputs and outputs support.
         """
-        masked = kwargs.pop('masked', False)
-        sparse_layouts = {torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}
-        sparse_compressed_layouts = {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}
+        masked = kwargs.pop("masked", False)
+        sparse_layouts = {
+            torch.sparse_coo,
+            torch.sparse_csr,
+            torch.sparse_csc,
+            torch.sparse_bsr,
+            torch.sparse_bsc,
+        }
+        sparse_compressed_layouts = {
+            torch.sparse_csr,
+            torch.sparse_csc,
+            torch.sparse_bsr,
+            torch.sparse_bsc,
+        }
         sparse_block_layouts = {torch.sparse_bsr, torch.sparse_bsc}
-        STRIDED_REPRESENTATION = '__STRIDED_REPRESENTATION__'
+        STRIDED_REPRESENTATION = "__STRIDED_REPRESENTATION__"

         def convert_to_strided_representation(args):
             """Convert differentiable non-strided tensors to a representation containing differentiable strided tensors."""
             if not isinstance(args, (list, tuple)):
-                args = args,
+                args = (args,)
             new_args: List[Any] = []
             for obj in args:
-                if isinstance(obj, torch.Tensor) and obj.requires_grad and obj.layout in sparse_layouts:
+                if (
+                    isinstance(obj, torch.Tensor)
+                    and obj.requires_grad
+                    and obj.layout in sparse_layouts
+                ):
                     d = dict(layout=obj.layout, shape=obj.shape)
                     if not masked:
                         # Materialize unspecified elements with zero values
                         batch_dim = obj.ndim - obj.dense_dim() - obj.sparse_dim()
-                        blocksize = obj.values().shape[batch_dim + 1:batch_dim + 3] if obj.layout in sparse_block_layouts else None
-                        full_mask = torch.ones(obj.shape, device=obj.device, dtype=torch.bool).to_sparse(
-                            layout=obj.layout, blocksize=blocksize, dense_dim=obj.dense_dim())
+                        blocksize = (
+                            obj.values().shape[batch_dim + 1 : batch_dim + 3]
+                            if obj.layout in sparse_block_layouts
+                            else None
+                        )
+                        full_mask = torch.ones(
+                            obj.shape, device=obj.device, dtype=torch.bool
+                        ).to_sparse(
+                            layout=obj.layout,
+                            blocksize=blocksize,
+                            dense_dim=obj.dense_dim(),
+                        )
                         obj = obj.to_dense().sparse_mask(full_mask)
                     if obj.layout is torch.sparse_coo:
-                        d.update(indices=obj._indices(), is_coalesced=obj.is_coalesced())
+                        d.update(
+                            indices=obj._indices(), is_coalesced=obj.is_coalesced()
+                        )
                         values = obj._values()
                     elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}:
-                        d.update(compressed_indices=obj.crow_indices(), plain_indices=obj.col_indices())
+                        d.update(
+                            compressed_indices=obj.crow_indices(),
+                            plain_indices=obj.col_indices(),
+                        )
                         values = obj.values()
                     else:
-                        d.update(compressed_indices=obj.ccol_indices(), plain_indices=obj.row_indices())
+                        d.update(
+                            compressed_indices=obj.ccol_indices(),
+                            plain_indices=obj.row_indices(),
+                        )
                         values = obj.values()
-                    new_args.extend((STRIDED_REPRESENTATION, d, values.requires_grad_(True)))
+                    new_args.extend(
+                        (STRIDED_REPRESENTATION, d, values.requires_grad_(True))
+                    )
                 else:
                     new_args.append(obj)
             return tuple(new_args)
```
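A hedged sketch of the idea behind `convert_to_strided_representation` (COO case only; names are illustrative, not the diff's internals):

```python
import torch

# Illustrative: a differentiable sparse input is split into metadata plus a
# strided `values` leaf (which gradcheck can perturb), then reassembled later.
x = torch.eye(2, dtype=torch.float64).to_sparse()
meta = dict(layout=x.layout, shape=x.shape,
            indices=x._indices(), is_coalesced=x.is_coalesced())
values = x._values().detach().requires_grad_(True)  # strided leaf tensor
rebuilt = torch.sparse_coo_tensor(
    meta["indices"], values, size=meta["shape"], is_coalesced=meta["is_coalesced"]
)
```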
```diff
@@ -574,13 +625,25 @@ def as_sparse_gradcheck(gradcheck):
                 a = args.pop(0)
                 if a == STRIDED_REPRESENTATION:
                     d, values = args.pop(0), args.pop(0)
-                    if d['layout'] is torch.sparse_coo:
-                        a = torch.sparse_coo_tensor(d['indices'], values, size=d['shape'], is_coalesced=d['is_coalesced'])
-                    elif d['layout'] in sparse_compressed_layouts:
-                        a = torch.sparse_compressed_tensor(d['compressed_indices'], d['plain_indices'], values,
-                                                           size=d['shape'], layout=d['layout'])
+                    if d["layout"] is torch.sparse_coo:
+                        a = torch.sparse_coo_tensor(
+                            d["indices"],
+                            values,
+                            size=d["shape"],
+                            is_coalesced=d["is_coalesced"],
+                        )
+                    elif d["layout"] in sparse_compressed_layouts:
+                        a = torch.sparse_compressed_tensor(
+                            d["compressed_indices"],
+                            d["plain_indices"],
+                            values,
+                            size=d["shape"],
+                            layout=d["layout"],
+                        )
                     else:
-                        raise NotImplementedError(f'conversion of {d["layout"]} strided representation to tensor')
+                        raise NotImplementedError(
+                            f'conversion of {d["layout"]} strided representation to tensor'
+                        )
                 new_args.append(a)
             return tuple(new_args)
```
```diff
@@ -591,12 +654,25 @@ def as_sparse_gradcheck(gradcheck):
             # tensors:
             outputs = func(*restored_args, **kwargs)

-            strided_outputs = tuple(outputs) if isinstance(outputs, (list, tuple)) else (outputs,)
-            strided_outputs = tuple((o.to_dense(masked_grad=masked)
-                                     if isinstance(o, torch.Tensor) and o.requires_grad and o.layout in sparse_layouts else o)
-                                    for o in strided_outputs)
+            strided_outputs = (
+                tuple(outputs) if isinstance(outputs, (list, tuple)) else (outputs,)
+            )
+            strided_outputs = tuple(
+                (
+                    o.to_dense(masked_grad=masked)
+                    if isinstance(o, torch.Tensor)
+                    and o.requires_grad
+                    and o.layout in sparse_layouts
+                    else o
+                )
+                for o in strided_outputs
+            )

-            return strided_outputs if isinstance(outputs, (list, tuple)) else strided_outputs[0]
+            return (
+                strided_outputs
+                if isinstance(outputs, (list, tuple))
+                else strided_outputs[0]
+            )

         args = (func_wrapper, convert_to_strided_representation(inputs))
```
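End-to-end, the decorated gradcheck can be used like this (a hedged sketch; `lambda t: t.to_dense()` is just an illustrative differentiable function):

```python
import torch

# Illustrative: wrap gradcheck so a function taking sparse inputs can be
# checked; float64 is required by gradcheck for numerical accuracy.
gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck)
x = torch.eye(3, dtype=torch.float64).to_sparse().requires_grad_(True)
assert gradcheck(lambda t: t.to_dense(), (x,), masked=False)
```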
torch/sparse/_semi_structured_conversions.py
```diff
@@ -342,11 +342,15 @@ def _compute_compressed_swizzled_bitmask(dense):
    # [0 0 1 1]

    # reshape tensor to expand tiles into 8-bit vectors
-    bitmask_binary_representation = bitmask_4x4_chunks.reshape(*bitmask_4x4_chunks.shape[:2], 4, 2, 8)
+    bitmask_binary_representation = bitmask_4x4_chunks.reshape(
+        *bitmask_4x4_chunks.shape[:2], 4, 2, 8
+    )

    # to convert from binary representaiton, we can do a matmul with powers of two
-    powers_of_two = 2**torch.arange(8, dtype=torch.float, device="cuda")
+    powers_of_two = 2 ** torch.arange(8, dtype=torch.float, device="cuda")
    # To run on GPU: cast to float to do matmul and then cast back
-    compressed_swizzled_bitmask = (bitmask_binary_representation.to(torch.float) @ powers_of_two).to(torch.uint8)
+    compressed_swizzled_bitmask = (
+        bitmask_binary_representation.to(torch.float) @ powers_of_two
+    ).to(torch.uint8)

    return compressed_swizzled_bitmask
```
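The powers-of-two matmul in this hunk packs binary vectors into bytes; a self-contained sketch of the same trick on CPU:

```python
import torch

# Illustrative: a matmul with powers of two packs an 8-element 0/1 vector
# into a single uint8 value (1 + 4 + 8 = 13 here).
bits = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0])
powers_of_two = 2 ** torch.arange(8, dtype=torch.float)
print((bits @ powers_of_two).to(torch.uint8))  # tensor(13, dtype=torch.uint8)
```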
torch/sparse/_triton_ops.py (per the exclude list above): file diff suppressed because it is too large.
torch/sparse/semi_structured.py
```diff
@@ -1,23 +1,23 @@
 # mypy: allow-untyped-defs
 import warnings
 from collections import namedtuple
-from typing import Any, Optional, Tuple, List, Callable, Dict
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
 from torch.sparse._semi_structured_conversions import (
     sparse_semi_structured_from_dense_cutlass,
-    sparse_semi_structured_to_dense_cutlass
+    sparse_semi_structured_to_dense_cutlass,
 )
 from torch.sparse._semi_structured_ops import (
     fallback_dispatcher,
-    semi_sparse_values,
-    semi_sparse_indices,
-    semi_sparse_detach,
-    semi_sparse_t,
-    semi_sparse_view,
-    semi_sparse_mm,
     semi_sparse_addmm,
+    semi_sparse_detach,
+    semi_sparse_indices,
     semi_sparse_linear,
+    semi_sparse_mm,
+    semi_sparse_t,
+    semi_sparse_values,
+    semi_sparse_view,
 )

 __all__ = [
```
```diff
@@ -175,7 +175,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
     def __tensor_unflatten__(
         cls,
         inner_tensors,
-        tensor_meta : Tuple[torch.Size, bool, int, bool],
+        tensor_meta: Tuple[torch.Size, bool, int, bool],
         outer_size,
         outer_stride,
     ) -> torch.Tensor:
```
```diff
@@ -186,7 +186,9 @@ class SparseSemiStructuredTensor(torch.Tensor):
             meta=inner_tensors.get("meta", None),
             packed_t=inner_tensors.get("packed_t", None),
             meta_t=inner_tensors.get("meta_t", None),
-            compressed_swizzled_bitmask=inner_tensors.get("compressed_swizzled_bitmask", None),
+            compressed_swizzled_bitmask=inner_tensors.get(
+                "compressed_swizzled_bitmask", None
+            ),
             fuse_transpose_cusparselt=fuse_transpose_cusparselt,
             alg_id_cusparselt=alg_id_cusparselt,
             requires_grad=requires_grad,
```
```diff
@@ -227,7 +229,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
         cls.SPARSE_DISPATCH.update(custom_dispatch_table)

     @classmethod
-    def _validate_device_dim_dtype_shape(cls, original_tensor : torch.Tensor) -> None:
+    def _validate_device_dim_dtype_shape(cls, original_tensor: torch.Tensor) -> None:
         """
         Assert that the given tensor is valid for semi-structured sparse compression.
         """
```
```diff
@@ -297,7 +299,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
         return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device))

     @classmethod
-    def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensor":
+    def from_dense(cls, original_tensor: torch.Tensor) -> "SparseSemiStructuredTensor":
         raise NotImplementedError

     def _mm(
```
```diff
@@ -377,6 +379,7 @@ def to_sparse_semi_structured(

     return SPARSE_SUBCLASS.from_dense(original_tensor)

+
 class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
     """
     This class implements semi-structured sparsity for the CUTLASS backend.
```
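A hedged sketch of the public entry point `to_sparse_semi_structured` shown above (requires a CUDA device with 2:4 sparse kernel support; shape mirrors the upstream docstring example):

```python
import torch

# Illustrative: a dense tensor with a 2:4 pattern (2 of every 4 elements
# nonzero) is converted to one of the two semi-structured subclasses.
A = torch.Tensor([0, 0, 1, 1]).tile((128, 128)).half().cuda()
A_sparse = torch.sparse.to_sparse_semi_structured(A)
print(type(A_sparse).__name__)  # CUTLASS or cuSPARSELt subclass
```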
```diff
@@ -388,6 +391,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
     When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and
     sparse_semi_structured_from_dense for conversion to the compressed format.
     """
+
     BACKEND = "cutlass"
     _DTYPE_SHAPE_CONSTRAINTS = {
         torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
```
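The docstring above mentions the `_FORCE_CUTLASS` flag; a brief sketch of toggling it (illustrative, set before converting):

```python
import torch

# Illustrative: route conversions through the CUTLASS subclass even
# when cuSPARSELt is available.
torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = True
```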
```diff
@@ -417,13 +421,19 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):

     def to_dense(self):
         assert self.meta is not None and self.packed is not None
-        return sparse_semi_structured_to_dense_cutlass(
-            self.packed,
-            self.meta,
-        ) if self.meta.ndim == 2 else super().to_dense()
+        return (
+            sparse_semi_structured_to_dense_cutlass(
+                self.packed,
+                self.meta,
+            )
+            if self.meta.ndim == 2
+            else super().to_dense()
+        )

     @classmethod
-    def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
+    def prune_dense_static_sort(
+        cls, original_tensor: torch.Tensor, algorithm=""
+    ) -> "SparseSemiStructuredTensor":
         """
         This function takes in a unpruned dense tensor and runs a (branchless) static sort across a 4x4 tile.

```
````diff
@@ -463,10 +473,15 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
         ```
         """
         # We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.
-        (packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
-            original_tensor,
-            algorithm=algorithm,
-            use_cutlass=True)
+        (
+            packed,
+            meta,
+            packed_t,
+            meta_t,
+            compressed_swizzled_bitmask,
+        ) = torch._sparse_semi_structured_tile(
+            original_tensor, algorithm=algorithm, use_cutlass=True
+        )

         return cls(
             original_tensor.shape,
````
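A hedged usage sketch of the `prune_dense_static_sort` classmethod reformatted above (CUDA only; the dtype-dependent shape constraints from `_DTYPE_SHAPE_CONSTRAINTS` apply):

```python
import torch

# Illustrative: one-shot 2:4 pruning and packing of a half-precision
# CUDA tensor via the CUTLASS subclass.
dense = torch.randn(128, 128, dtype=torch.float16, device="cuda")
pruned = torch.sparse.SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(dense)
```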
```diff
@@ -479,11 +494,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
         )

     def _mm(
-        self,
-        B: torch.Tensor,
-        *,
-        bias: Optional[torch.Tensor] = None,
-        **kwargs
+        self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs
     ) -> torch.Tensor:
         if isinstance(B, SparseSemiStructuredTensor):
             raise ValueError(
```
```diff
@@ -500,9 +511,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
             )
         else:
             if bias is None:
-                res = torch._sparse_semi_structured_mm(
-                    self.packed, self.meta, B
-                )
+                res = torch._sparse_semi_structured_mm(self.packed, self.meta, B)
             else:
                 res = torch._sparse_semi_structured_addmm(
                     bias, self.packed, self.meta, B
```
```diff
@@ -521,6 +530,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
     cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
     as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
     """
+
     BACKEND = "cusparselt"
     _DTYPE_SHAPE_CONSTRAINTS = {
         torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
```
```diff
@@ -530,7 +540,9 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
     }

     @classmethod
-    def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensorCUSPARSELT":
+    def from_dense(
+        cls, original_tensor: torch.Tensor
+    ) -> "SparseSemiStructuredTensorCUSPARSELT":
         cls._validate_device_dim_dtype_shape(original_tensor)
         return cls(
             shape=original_tensor.shape,
```
```diff
@@ -545,7 +557,9 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
         )

     @classmethod
-    def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
+    def prune_dense_static_sort(
+        cls, original_tensor: torch.Tensor, algorithm=""
+    ) -> "SparseSemiStructuredTensor":
         """
         This function does the same thing as described in SparseSemiStructuredCUTLASS, but uses the cuSPASRELt metadata
         layout and sparse matmul.
```
````diff
@@ -576,10 +590,15 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
         SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask)
         ```
         """
-        (packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
-            original_tensor,
-            algorithm=algorithm,
-            use_cutlass=False)
+        (
+            packed,
+            meta,
+            packed_t,
+            meta_t,
+            compressed_swizzled_bitmask,
+        ) = torch._sparse_semi_structured_tile(
+            original_tensor, algorithm=algorithm, use_cutlass=False
+        )

         return cls(
             original_tensor.shape,
````
```diff
@@ -592,11 +611,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
         )

     def _mm(
-        self,
-        B: torch.Tensor,
-        *,
-        bias: Optional[torch.Tensor] = None,
-        **kwargs
+        self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs
     ) -> torch.Tensor:
         if isinstance(B, SparseSemiStructuredTensor):
             raise ValueError(
```