mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Enable UFMT on all of torch/sparse (#130545)
Partially addresses #123062

Ran lintrunner on:
- torch/sparse

Detail:
```
$ lintrunner -a --take UFMT --all-files
ok No lint issues.
Successfully applied all patches.
```

@ezyang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130545
Approved by: https://github.com/ezyang
committed by PyTorch MergeBot
parent 7d4f50de19
commit 535016967a
```diff
@@ -1,23 +1,23 @@
 # mypy: allow-untyped-defs
 import warnings
 from collections import namedtuple
-from typing import Any, Optional, Tuple, List, Callable, Dict
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 from torch.sparse._semi_structured_conversions import (
     sparse_semi_structured_from_dense_cutlass,
-    sparse_semi_structured_to_dense_cutlass
+    sparse_semi_structured_to_dense_cutlass,
 )
 from torch.sparse._semi_structured_ops import (
     fallback_dispatcher,
-    semi_sparse_values,
-    semi_sparse_indices,
-    semi_sparse_detach,
-    semi_sparse_t,
-    semi_sparse_view,
-    semi_sparse_mm,
     semi_sparse_addmm,
+    semi_sparse_detach,
+    semi_sparse_indices,
     semi_sparse_linear,
+    semi_sparse_mm,
+    semi_sparse_t,
+    semi_sparse_values,
+    semi_sparse_view,
 )
 
 __all__ = [
@@ -175,7 +175,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
     def __tensor_unflatten__(
         cls,
         inner_tensors,
-        tensor_meta : Tuple[torch.Size, bool, int, bool],
+        tensor_meta: Tuple[torch.Size, bool, int, bool],
         outer_size,
         outer_stride,
     ) -> torch.Tensor:
@@ -186,7 +186,9 @@ class SparseSemiStructuredTensor(torch.Tensor):
             meta=inner_tensors.get("meta", None),
             packed_t=inner_tensors.get("packed_t", None),
             meta_t=inner_tensors.get("meta_t", None),
-            compressed_swizzled_bitmask=inner_tensors.get("compressed_swizzled_bitmask", None),
+            compressed_swizzled_bitmask=inner_tensors.get(
+                "compressed_swizzled_bitmask", None
+            ),
             fuse_transpose_cusparselt=fuse_transpose_cusparselt,
             alg_id_cusparselt=alg_id_cusparselt,
             requires_grad=requires_grad,
```
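Context for the `__tensor_unflatten__` hunks above: this method is half of the tensor-subclass serialization protocol; its counterpart `__tensor_flatten__` names the inner dense tensors and returns the metadata that reappears here as `tensor_meta`. A minimal sketch of the pairing with a hypothetical two-field subclass (illustration only, not part of this commit):

```python
import torch


class PairTensor(torch.Tensor):
    """Hypothetical wrapper subclass used only to illustrate the protocol."""

    @staticmethod
    def __new__(cls, a, b):
        return torch.Tensor._make_wrapper_subclass(
            cls, a.shape, dtype=a.dtype, device=a.device
        )

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __tensor_flatten__(self):
        # names of the inner tensor attributes, plus non-tensor metadata
        return ["a", "b"], None

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, tensor_meta, outer_size, outer_stride):
        # rebuild from the dict of inner tensors, mirroring the
        # inner_tensors.get(...) calls in the hunk above
        return cls(inner_tensors["a"], inner_tensors["b"])
```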
```diff
@@ -227,7 +229,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
         cls.SPARSE_DISPATCH.update(custom_dispatch_table)
 
     @classmethod
-    def _validate_device_dim_dtype_shape(cls, original_tensor : torch.Tensor) -> None:
+    def _validate_device_dim_dtype_shape(cls, original_tensor: torch.Tensor) -> None:
         """
         Assert that the given tensor is valid for semi-structured sparse compression.
         """
```
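The validation this hunk touches checks the input against per-dtype minimum shapes (the `_SEMI_STRUCTURED_SPARSE_CONFIG` entries visible later in this diff). A rough sketch of that style of check; the field names and error messages here are assumptions for illustration, not the actual implementation:

```python
from collections import namedtuple

import torch

# Assumed field names; the real config namedtuple lives in torch/sparse.
_Config = namedtuple("_Config", ["min_rows", "min_cols"])
_CONSTRAINTS = {torch.float16: _Config(64, 64), torch.int8: _Config(32, 32)}


def check_shape(t: torch.Tensor) -> None:
    # 2:4 kernels want 2-D inputs whose sizes meet per-dtype minimums
    if t.dim() != 2:
        raise RuntimeError(f"expected a 2-D tensor, got {t.dim()}-D")
    if t.dtype not in _CONSTRAINTS:
        raise RuntimeError(f"unsupported dtype {t.dtype}")
    cfg = _CONSTRAINTS[t.dtype]
    rows, cols = t.shape
    if rows < cfg.min_rows or rows % cfg.min_rows != 0:
        raise RuntimeError(f"rows={rows} must be a multiple of {cfg.min_rows}")
    if cols < cfg.min_cols or cols % cfg.min_cols != 0:
        raise RuntimeError(f"cols={cols} must be a multiple of {cfg.min_cols}")
```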
```diff
@@ -297,7 +299,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
         return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device))
 
     @classmethod
-    def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensor":
+    def from_dense(cls, original_tensor: torch.Tensor) -> "SparseSemiStructuredTensor":
         raise NotImplementedError
 
     def _mm(
@@ -377,6 +379,7 @@ def to_sparse_semi_structured(
 
     return SPARSE_SUBCLASS.from_dense(original_tensor)
 
+
 class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
     """
     This class implements semi-structured sparsity for the CUTLASS backend.
```
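For orientation, `to_sparse_semi_structured` (whose tail the second hunk above reformats) is the public entry point that picks a backend subclass and compresses an already 2:4-sparse dense tensor. A minimal usage sketch, assuming a CUDA device with one of the sparse backends available:

```python
import torch
from torch.sparse import to_sparse_semi_structured

# A fp16 matrix already in a 2:4 pattern: two nonzeros in every group of four.
A = torch.tensor([0, 0, 1, 1], dtype=torch.float16, device="cuda").tile((128, 32))
A_sparse = to_sparse_semi_structured(A)

B = torch.randn(128, 128, dtype=torch.float16, device="cuda")
out = torch.mm(A_sparse, B)  # dispatched to the sparse matmul kernel
```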
```diff
@@ -388,6 +391,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
     When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and
     sparse_semi_structured_from_dense for conversion to the compressed format.
     """
+
     BACKEND = "cutlass"
     _DTYPE_SHAPE_CONSTRAINTS = {
         torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
```
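The `_FORCE_CUTLASS` flag this docstring mentions is a class attribute on `SparseSemiStructuredTensor`; setting it before conversion selects this subclass even when cuSPARSELt is available. A sketch, reusing the setup from the previous example:

```python
import torch
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured

# Opt into the CUTLASS backend (cuSPARSELt is preferred when present).
SparseSemiStructuredTensor._FORCE_CUTLASS = True

A = torch.tensor([0, 0, 1, 1], dtype=torch.float16, device="cuda").tile((128, 32))
A_sparse = to_sparse_semi_structured(A)  # -> SparseSemiStructuredTensorCUTLASS
```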
```diff
@@ -417,13 +421,19 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
 
     def to_dense(self):
         assert self.meta is not None and self.packed is not None
-        return sparse_semi_structured_to_dense_cutlass(
-            self.packed,
-            self.meta,
-        ) if self.meta.ndim == 2 else super().to_dense()
+        return (
+            sparse_semi_structured_to_dense_cutlass(
+                self.packed,
+                self.meta,
+            )
+            if self.meta.ndim == 2
+            else super().to_dense()
+        )
 
     @classmethod
-    def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
+    def prune_dense_static_sort(
+        cls, original_tensor: torch.Tensor, algorithm=""
+    ) -> "SparseSemiStructuredTensor":
         """
         This function takes in a unpruned dense tensor and runs a (branchless) static sort across a 4x4 tile.
 
@@ -463,10 +473,15 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
         ```
         """
         # We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.
-        (packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
-            original_tensor,
-            algorithm=algorithm,
-            use_cutlass=True)
+        (
+            packed,
+            meta,
+            packed_t,
+            meta_t,
+            compressed_swizzled_bitmask,
+        ) = torch._sparse_semi_structured_tile(
+            original_tensor, algorithm=algorithm, use_cutlass=True
+        )
 
         return cls(
             original_tensor.shape,
```
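Unlike `from_dense`, which requires the input to already be 2:4 sparse, `prune_dense_static_sort` starts from an unpruned dense tensor, chooses which two values to keep per group of four, and returns a subclass carrying packed/meta pairs for both orientations plus the swizzled bitmask. A minimal call sketch under the same CUDA/fp16 assumptions:

```python
import torch
from torch.sparse import SparseSemiStructuredTensorCUTLASS

dense = torch.randn(128, 128, dtype=torch.float16, device="cuda")

# Prune to 2:4 and compress in one step; algorithm="" keeps the default
# strategy, matching the reformatted signature above.
sparse = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(dense)

# to_dense() materializes the tensor with the pruned values zeroed out.
approx = sparse.to_dense()
```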
```diff
@@ -479,11 +494,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
         )
 
     def _mm(
-        self,
-        B: torch.Tensor,
-        *,
-        bias: Optional[torch.Tensor] = None,
-        **kwargs
+        self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs
     ) -> torch.Tensor:
         if isinstance(B, SparseSemiStructuredTensor):
             raise ValueError(
@@ -500,9 +511,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
             )
         else:
             if bias is None:
-                res = torch._sparse_semi_structured_mm(
-                    self.packed, self.meta, B
-                )
+                res = torch._sparse_semi_structured_mm(self.packed, self.meta, B)
             else:
                 res = torch._sparse_semi_structured_addmm(
                     bias, self.packed, self.meta, B
```
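The branch reformatted here is why both plain and bias-carrying matmuls work on a sparse first operand: with no bias the call lowers to `torch._sparse_semi_structured_mm`, otherwise to `torch._sparse_semi_structured_addmm`. A usage sketch under the same assumptions as the earlier examples:

```python
import torch
from torch.sparse import to_sparse_semi_structured

A = torch.tensor([0, 0, 1, 1], dtype=torch.float16, device="cuda").tile((128, 32))
A_sparse = to_sparse_semi_structured(A)
B = torch.randn(128, 64, dtype=torch.float16, device="cuda")

out = torch.mm(A_sparse, B)  # bias is None -> _sparse_semi_structured_mm
# With a bias (e.g. through torch.addmm or F.linear), the same method
# takes the _sparse_semi_structured_addmm path instead.
```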
```diff
@@ -521,6 +530,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
     cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
     as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
     """
+
     BACKEND = "cusparselt"
     _DTYPE_SHAPE_CONSTRAINTS = {
         torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
@@ -530,7 +540,9 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
     }
 
     @classmethod
-    def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensorCUSPARSELT":
+    def from_dense(
+        cls, original_tensor: torch.Tensor
+    ) -> "SparseSemiStructuredTensorCUSPARSELT":
         cls._validate_device_dim_dtype_shape(original_tensor)
         return cls(
             shape=original_tensor.shape,
@@ -545,7 +557,9 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
         )
 
     @classmethod
-    def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
+    def prune_dense_static_sort(
+        cls, original_tensor: torch.Tensor, algorithm=""
+    ) -> "SparseSemiStructuredTensor":
         """
         This function does the same thing as described in SparseSemiStructuredCUTLASS, but uses the cuSPASRELt metadata
         layout and sparse matmul.
@@ -576,10 +590,15 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
         SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask)
         ```
         """
-        (packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
-            original_tensor,
-            algorithm=algorithm,
-            use_cutlass=False)
+        (
+            packed,
+            meta,
+            packed_t,
+            meta_t,
+            compressed_swizzled_bitmask,
+        ) = torch._sparse_semi_structured_tile(
+            original_tensor, algorithm=algorithm, use_cutlass=False
+        )
 
         return cls(
             original_tensor.shape,
@@ -592,11 +611,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
         )
 
     def _mm(
-        self,
-        B: torch.Tensor,
-        *,
-        bias: Optional[torch.Tensor] = None,
-        **kwargs
+        self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs
     ) -> torch.Tensor:
         if isinstance(B, SparseSemiStructuredTensor):
             raise ValueError(
```