Enable UFMT on all of torch/sparse (#130545)

Partially addresses #123062
Ran lintrunner on:
- torch/sparse

Detail:
```
$ lintrunner -a --take UFMT --all-files
ok No lint issues.
Successfully applied all patches.
```

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130545
Approved by: https://github.com/ezyang
This commit is contained in:
WeiChunyu-star
2024-07-15 22:35:52 +00:00
committed by PyTorch MergeBot
parent 7d4f50de19
commit 535016967a
5 changed files with 884 additions and 376 deletions

View File

@ -1,23 +1,23 @@
# mypy: allow-untyped-defs
import warnings
from collections import namedtuple
from typing import Any, Optional, Tuple, List, Callable, Dict
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from torch.sparse._semi_structured_conversions import (
sparse_semi_structured_from_dense_cutlass,
sparse_semi_structured_to_dense_cutlass
sparse_semi_structured_to_dense_cutlass,
)
from torch.sparse._semi_structured_ops import (
fallback_dispatcher,
semi_sparse_values,
semi_sparse_indices,
semi_sparse_detach,
semi_sparse_t,
semi_sparse_view,
semi_sparse_mm,
semi_sparse_addmm,
semi_sparse_detach,
semi_sparse_indices,
semi_sparse_linear,
semi_sparse_mm,
semi_sparse_t,
semi_sparse_values,
semi_sparse_view,
)
__all__ = [
@ -175,7 +175,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
def __tensor_unflatten__(
cls,
inner_tensors,
tensor_meta : Tuple[torch.Size, bool, int, bool],
tensor_meta: Tuple[torch.Size, bool, int, bool],
outer_size,
outer_stride,
) -> torch.Tensor:
@ -186,7 +186,9 @@ class SparseSemiStructuredTensor(torch.Tensor):
meta=inner_tensors.get("meta", None),
packed_t=inner_tensors.get("packed_t", None),
meta_t=inner_tensors.get("meta_t", None),
compressed_swizzled_bitmask=inner_tensors.get("compressed_swizzled_bitmask", None),
compressed_swizzled_bitmask=inner_tensors.get(
"compressed_swizzled_bitmask", None
),
fuse_transpose_cusparselt=fuse_transpose_cusparselt,
alg_id_cusparselt=alg_id_cusparselt,
requires_grad=requires_grad,
@ -227,7 +229,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
cls.SPARSE_DISPATCH.update(custom_dispatch_table)
@classmethod
def _validate_device_dim_dtype_shape(cls, original_tensor : torch.Tensor) -> None:
def _validate_device_dim_dtype_shape(cls, original_tensor: torch.Tensor) -> None:
"""
Assert that the given tensor is valid for semi-structured sparse compression.
"""
@ -297,7 +299,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device))
@classmethod
def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensor":
def from_dense(cls, original_tensor: torch.Tensor) -> "SparseSemiStructuredTensor":
raise NotImplementedError
def _mm(
@ -377,6 +379,7 @@ def to_sparse_semi_structured(
return SPARSE_SUBCLASS.from_dense(original_tensor)
class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
"""
This class implements semi-structured sparsity for the CUTLASS backend.
@ -388,6 +391,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and
sparse_semi_structured_from_dense for conversion to the compressed format.
"""
BACKEND = "cutlass"
_DTYPE_SHAPE_CONSTRAINTS = {
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
@ -417,13 +421,19 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
def to_dense(self):
assert self.meta is not None and self.packed is not None
return sparse_semi_structured_to_dense_cutlass(
self.packed,
self.meta,
) if self.meta.ndim == 2 else super().to_dense()
return (
sparse_semi_structured_to_dense_cutlass(
self.packed,
self.meta,
)
if self.meta.ndim == 2
else super().to_dense()
)
@classmethod
def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
def prune_dense_static_sort(
cls, original_tensor: torch.Tensor, algorithm=""
) -> "SparseSemiStructuredTensor":
"""
This function takes in a unpruned dense tensor and runs a (branchless) static sort across a 4x4 tile.
@ -463,10 +473,15 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
```
"""
# We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.
(packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
original_tensor,
algorithm=algorithm,
use_cutlass=True)
(
packed,
meta,
packed_t,
meta_t,
compressed_swizzled_bitmask,
) = torch._sparse_semi_structured_tile(
original_tensor, algorithm=algorithm, use_cutlass=True
)
return cls(
original_tensor.shape,
@ -479,11 +494,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
)
def _mm(
self,
B: torch.Tensor,
*,
bias: Optional[torch.Tensor] = None,
**kwargs
self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs
) -> torch.Tensor:
if isinstance(B, SparseSemiStructuredTensor):
raise ValueError(
@ -500,9 +511,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
)
else:
if bias is None:
res = torch._sparse_semi_structured_mm(
self.packed, self.meta, B
)
res = torch._sparse_semi_structured_mm(self.packed, self.meta, B)
else:
res = torch._sparse_semi_structured_addmm(
bias, self.packed, self.meta, B
@ -521,6 +530,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
"""
BACKEND = "cusparselt"
_DTYPE_SHAPE_CONSTRAINTS = {
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
@ -530,7 +540,9 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
}
@classmethod
def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensorCUSPARSELT":
def from_dense(
cls, original_tensor: torch.Tensor
) -> "SparseSemiStructuredTensorCUSPARSELT":
cls._validate_device_dim_dtype_shape(original_tensor)
return cls(
shape=original_tensor.shape,
@ -545,7 +557,9 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
)
@classmethod
def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
def prune_dense_static_sort(
cls, original_tensor: torch.Tensor, algorithm=""
) -> "SparseSemiStructuredTensor":
"""
This function does the same thing as described in SparseSemiStructuredCUTLASS, but uses the cuSPASRELt metadata
layout and sparse matmul.
@ -576,10 +590,15 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask)
```
"""
(packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
original_tensor,
algorithm=algorithm,
use_cutlass=False)
(
packed,
meta,
packed_t,
meta_t,
compressed_swizzled_bitmask,
) = torch._sparse_semi_structured_tile(
original_tensor, algorithm=algorithm, use_cutlass=False
)
return cls(
original_tensor.shape,
@ -592,11 +611,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
)
def _mm(
self,
B: torch.Tensor,
*,
bias: Optional[torch.Tensor] = None,
**kwargs
self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs
) -> torch.Tensor:
if isinstance(B, SparseSemiStructuredTensor):
raise ValueError(