Enable UFMT on all of torch/sparse (#130545)

Partially addresses #123062
Ran lintrunner on:
- torch/sparse

Detail:
```
$ lintrunner -a --take UFMT --all-files
ok No lint issues.
Successfully applied all patches.
```

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130545
Approved by: https://github.com/ezyang
This commit is contained in:
WeiChunyu-star
2024-07-15 22:35:52 +00:00
committed by PyTorch MergeBot
parent 7d4f50de19
commit 535016967a
5 changed files with 884 additions and 376 deletions

View File

@ -1531,10 +1531,6 @@ exclude_patterns = [
'torch/signal/__init__.py',
'torch/signal/windows/__init__.py',
'torch/signal/windows/windows.py',
'torch/sparse/__init__.py',
'torch/sparse/_semi_structured_conversions.py',
'torch/sparse/_triton_ops.py',
'torch/sparse/semi_structured.py',
'torch/special/__init__.py',
'torch/testing/_internal/__init__.py',
'torch/testing/_internal/autocast_test_lists.py',

View File

@ -1,23 +1,23 @@
# mypy: allow-untyped-defs
# The Tensor classes are added to this module by python_tensor.cpp
from typing import Optional, Tuple, List, Union, Any
# A workaround to support both TorchScript and MyPy:
from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union
import torch
from torch._C import _add_docstr, _sparse # type: ignore[attr-defined]
from torch import Tensor
from torch._C import _add_docstr, _sparse # type: ignore[attr-defined]
# Semi structured sparsity support
from .semi_structured import (
SparseSemiStructuredTensor,
SparseSemiStructuredTensorCUSPARSELT,
SparseSemiStructuredTensorCUTLASS,
to_sparse_semi_structured
to_sparse_semi_structured,
)
# A workaround to support both TorchScript and MyPy:
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from torch.types import _dtype as DType
DimOrDims = Optional[Union[int, Tuple[int, ...], List[int]]]
else:
# The JIT doesn't understand Union, nor torch.dtype here
@ -26,20 +26,22 @@ else:
__all__ = [
'addmm',
'check_sparse_tensor_invariants',
'mm',
'sum',
'softmax',
'log_softmax',
'SparseSemiStructuredTensor',
'SparseSemiStructuredTensorCUTLASS',
'SparseSemiStructuredTensorCUSPARSELT',
'to_sparse_semi_structured',
'as_sparse_gradcheck',
"addmm",
"check_sparse_tensor_invariants",
"mm",
"sum",
"softmax",
"log_softmax",
"SparseSemiStructuredTensor",
"SparseSemiStructuredTensorCUTLASS",
"SparseSemiStructuredTensorCUSPARSELT",
"to_sparse_semi_structured",
"as_sparse_gradcheck",
]
addmm = _add_docstr(_sparse._sparse_addmm, r"""
addmm = _add_docstr(
_sparse._sparse_addmm,
r"""
sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor
This function does exact same thing as :func:`torch.addmm` in the forward,
@ -58,10 +60,13 @@ Args:
mat2 (Tensor): a dense matrix to be multiplied
beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`)
alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
""")
""",
)
mm = _add_docstr(_sparse._sparse_mm, r"""
mm = _add_docstr(
_sparse._sparse_mm,
r"""
Performs a matrix multiplication of the sparse matrix :attr:`mat1`
and the (sparse or strided) matrix :attr:`mat2`. Similar to :func:`torch.mm`, if :attr:`mat1` is a
:math:`(n \times m)` tensor, :attr:`mat2` is a :math:`(m \times p)` tensor, out will be a
@ -132,10 +137,13 @@ Example::
>>> y2
tensor([[0., 1.],
[6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)
""")
""",
)
sampled_addmm = _add_docstr(_sparse.sparse_sampled_addmm, r"""
sampled_addmm = _add_docstr(
_sparse.sparse_sampled_addmm,
r"""
sparse.sampled_addmm(input, mat1, mat2, *, beta=1., alpha=1., out=None) -> Tensor
Performs a matrix multiplication of the dense matrices :attr:`mat1` and :attr:`mat2` at the locations
@ -184,10 +192,11 @@ Examples::
col_indices=tensor([0, 1, 2]),
values=tensor([ 0.1423, -0.3903, -0.0950]), device='cuda:0',
size=(3, 3), nnz=3, layout=torch.sparse_csr)
""")
""",
)
def sum(input: Tensor, dim: DimOrDims = None,
dtype: Optional[DType] = None) -> Tensor:
def sum(input: Tensor, dim: DimOrDims = None, dtype: Optional[DType] = None) -> Tensor:
r"""Return the sum of each row of the given sparse tensor.
Returns the sum of each row of the sparse tensor :attr:`input` in the given
@ -256,7 +265,9 @@ def sum(input: Tensor, dim: DimOrDims = None,
return torch._sparse_sum(input, dtype=dtype)
softmax = _add_docstr(_sparse._sparse_softmax, r"""
softmax = _add_docstr(
_sparse._sparse_softmax,
r"""
sparse.softmax(input, dim, *, dtype=None) -> Tensor
Applies a softmax function.
@ -281,10 +292,13 @@ Args:
casted to :attr:`dtype` before the operation is
performed. This is useful for preventing data type
overflows. Default: None
""")
""",
)
log_softmax = _add_docstr(_sparse._sparse_log_softmax, r"""
log_softmax = _add_docstr(
_sparse._sparse_log_softmax,
r"""
sparse.log_softmax(input, dim, *, dtype=None) -> Tensor
Applies a softmax function followed by logarithm.
@ -299,7 +313,8 @@ Args:
casted to :attr:`dtype` before the operation is
performed. This is useful for preventing data type
overflows. Default: None
""")
""",
)
spdiags = _add_docstr(
@ -393,7 +408,8 @@ Specifying a positive offset::
[0, 0, 3, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]])
""")
""",
)
class check_sparse_tensor_invariants:
@ -483,12 +499,14 @@ class check_sparse_tensor_invariants:
# context manager support
def __init__(self, enable=True):
self.state = enable
self.saved_state : Optional[bool] = None
self.saved_state: Optional[bool] = None
def __enter__(self):
if self.saved_state is not None:
raise RuntimeError('This context manager instance is already activated.'
' Use a different context manager instance for context nesting.')
raise RuntimeError(
"This context manager instance is already activated."
" Use a different context manager instance for context nesting."
)
self.saved_state = self.is_enabled()
torch._C._set_check_sparse_tensor_invariants(self.state)
@ -499,7 +517,6 @@ class check_sparse_tensor_invariants:
# decorator support
def __call__(self, mth):
def test_mth(*args, **kwargs):
with type(self)(self.state):
return mth(*args, **kwargs)
@ -531,37 +548,71 @@ def as_sparse_gradcheck(gradcheck):
Same as :func:`torch.autograd.gradcheck` but with sparse tensors inputs and outputs support.
"""
masked = kwargs.pop('masked', False)
sparse_layouts = {torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}
sparse_compressed_layouts = {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}
masked = kwargs.pop("masked", False)
sparse_layouts = {
torch.sparse_coo,
torch.sparse_csr,
torch.sparse_csc,
torch.sparse_bsr,
torch.sparse_bsc,
}
sparse_compressed_layouts = {
torch.sparse_csr,
torch.sparse_csc,
torch.sparse_bsr,
torch.sparse_bsc,
}
sparse_block_layouts = {torch.sparse_bsr, torch.sparse_bsc}
STRIDED_REPRESENTATION = '__STRIDED_REPRESENTATION__'
STRIDED_REPRESENTATION = "__STRIDED_REPRESENTATION__"
def convert_to_strided_representation(args):
"""Convert differentiable non-strided tensors to a representation containing differentiable strided tensors."""
if not isinstance(args, (list, tuple)):
args = args,
args = (args,)
new_args: List[Any] = []
for obj in args:
if isinstance(obj, torch.Tensor) and obj.requires_grad and obj.layout in sparse_layouts:
if (
isinstance(obj, torch.Tensor)
and obj.requires_grad
and obj.layout in sparse_layouts
):
d = dict(layout=obj.layout, shape=obj.shape)
if not masked:
# Materialize unspecified elements with zero values
batch_dim = obj.ndim - obj.dense_dim() - obj.sparse_dim()
blocksize = obj.values().shape[batch_dim + 1:batch_dim + 3] if obj.layout in sparse_block_layouts else None
full_mask = torch.ones(obj.shape, device=obj.device, dtype=torch.bool).to_sparse(
layout=obj.layout, blocksize=blocksize, dense_dim=obj.dense_dim())
blocksize = (
obj.values().shape[batch_dim + 1 : batch_dim + 3]
if obj.layout in sparse_block_layouts
else None
)
full_mask = torch.ones(
obj.shape, device=obj.device, dtype=torch.bool
).to_sparse(
layout=obj.layout,
blocksize=blocksize,
dense_dim=obj.dense_dim(),
)
obj = obj.to_dense().sparse_mask(full_mask)
if obj.layout is torch.sparse_coo:
d.update(indices=obj._indices(), is_coalesced=obj.is_coalesced())
d.update(
indices=obj._indices(), is_coalesced=obj.is_coalesced()
)
values = obj._values()
elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}:
d.update(compressed_indices=obj.crow_indices(), plain_indices=obj.col_indices())
d.update(
compressed_indices=obj.crow_indices(),
plain_indices=obj.col_indices(),
)
values = obj.values()
else:
d.update(compressed_indices=obj.ccol_indices(), plain_indices=obj.row_indices())
d.update(
compressed_indices=obj.ccol_indices(),
plain_indices=obj.row_indices(),
)
values = obj.values()
new_args.extend((STRIDED_REPRESENTATION, d, values.requires_grad_(True)))
new_args.extend(
(STRIDED_REPRESENTATION, d, values.requires_grad_(True))
)
else:
new_args.append(obj)
return tuple(new_args)
@ -574,13 +625,25 @@ def as_sparse_gradcheck(gradcheck):
a = args.pop(0)
if a == STRIDED_REPRESENTATION:
d, values = args.pop(0), args.pop(0)
if d['layout'] is torch.sparse_coo:
a = torch.sparse_coo_tensor(d['indices'], values, size=d['shape'], is_coalesced=d['is_coalesced'])
elif d['layout'] in sparse_compressed_layouts:
a = torch.sparse_compressed_tensor(d['compressed_indices'], d['plain_indices'], values,
size=d['shape'], layout=d['layout'])
if d["layout"] is torch.sparse_coo:
a = torch.sparse_coo_tensor(
d["indices"],
values,
size=d["shape"],
is_coalesced=d["is_coalesced"],
)
elif d["layout"] in sparse_compressed_layouts:
a = torch.sparse_compressed_tensor(
d["compressed_indices"],
d["plain_indices"],
values,
size=d["shape"],
layout=d["layout"],
)
else:
raise NotImplementedError(f'conversion of {d["layout"]} strided representation to tensor')
raise NotImplementedError(
f'conversion of {d["layout"]} strided representation to tensor'
)
new_args.append(a)
return tuple(new_args)
@ -591,12 +654,25 @@ def as_sparse_gradcheck(gradcheck):
# tensors:
outputs = func(*restored_args, **kwargs)
strided_outputs = tuple(outputs) if isinstance(outputs, (list, tuple)) else (outputs,)
strided_outputs = tuple((o.to_dense(masked_grad=masked)
if isinstance(o, torch.Tensor) and o.requires_grad and o.layout in sparse_layouts else o)
for o in strided_outputs)
strided_outputs = (
tuple(outputs) if isinstance(outputs, (list, tuple)) else (outputs,)
)
strided_outputs = tuple(
(
o.to_dense(masked_grad=masked)
if isinstance(o, torch.Tensor)
and o.requires_grad
and o.layout in sparse_layouts
else o
)
for o in strided_outputs
)
return strided_outputs if isinstance(outputs, (list, tuple)) else strided_outputs[0]
return (
strided_outputs
if isinstance(outputs, (list, tuple))
else strided_outputs[0]
)
args = (func_wrapper, convert_to_strided_representation(inputs))

View File

@ -342,11 +342,15 @@ def _compute_compressed_swizzled_bitmask(dense):
# [0 0 1 1]
# reshape tensor to expand tiles into 8-bit vectors
bitmask_binary_representation = bitmask_4x4_chunks.reshape(*bitmask_4x4_chunks.shape[:2], 4, 2, 8)
bitmask_binary_representation = bitmask_4x4_chunks.reshape(
*bitmask_4x4_chunks.shape[:2], 4, 2, 8
)
# to convert from binary representaiton, we can do a matmul with powers of two
powers_of_two = 2**torch.arange(8, dtype=torch.float, device="cuda")
powers_of_two = 2 ** torch.arange(8, dtype=torch.float, device="cuda")
# To run on GPU: cast to float to do matmul and then cast back
compressed_swizzled_bitmask = (bitmask_binary_representation.to(torch.float) @ powers_of_two).to(torch.uint8)
compressed_swizzled_bitmask = (
bitmask_binary_representation.to(torch.float) @ powers_of_two
).to(torch.uint8)
return compressed_swizzled_bitmask

File diff suppressed because it is too large Load Diff

View File

@ -1,23 +1,23 @@
# mypy: allow-untyped-defs
import warnings
from collections import namedtuple
from typing import Any, Optional, Tuple, List, Callable, Dict
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from torch.sparse._semi_structured_conversions import (
sparse_semi_structured_from_dense_cutlass,
sparse_semi_structured_to_dense_cutlass
sparse_semi_structured_to_dense_cutlass,
)
from torch.sparse._semi_structured_ops import (
fallback_dispatcher,
semi_sparse_values,
semi_sparse_indices,
semi_sparse_detach,
semi_sparse_t,
semi_sparse_view,
semi_sparse_mm,
semi_sparse_addmm,
semi_sparse_detach,
semi_sparse_indices,
semi_sparse_linear,
semi_sparse_mm,
semi_sparse_t,
semi_sparse_values,
semi_sparse_view,
)
__all__ = [
@ -175,7 +175,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
def __tensor_unflatten__(
cls,
inner_tensors,
tensor_meta : Tuple[torch.Size, bool, int, bool],
tensor_meta: Tuple[torch.Size, bool, int, bool],
outer_size,
outer_stride,
) -> torch.Tensor:
@ -186,7 +186,9 @@ class SparseSemiStructuredTensor(torch.Tensor):
meta=inner_tensors.get("meta", None),
packed_t=inner_tensors.get("packed_t", None),
meta_t=inner_tensors.get("meta_t", None),
compressed_swizzled_bitmask=inner_tensors.get("compressed_swizzled_bitmask", None),
compressed_swizzled_bitmask=inner_tensors.get(
"compressed_swizzled_bitmask", None
),
fuse_transpose_cusparselt=fuse_transpose_cusparselt,
alg_id_cusparselt=alg_id_cusparselt,
requires_grad=requires_grad,
@ -227,7 +229,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
cls.SPARSE_DISPATCH.update(custom_dispatch_table)
@classmethod
def _validate_device_dim_dtype_shape(cls, original_tensor : torch.Tensor) -> None:
def _validate_device_dim_dtype_shape(cls, original_tensor: torch.Tensor) -> None:
"""
Assert that the given tensor is valid for semi-structured sparse compression.
"""
@ -297,7 +299,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device))
@classmethod
def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensor":
def from_dense(cls, original_tensor: torch.Tensor) -> "SparseSemiStructuredTensor":
raise NotImplementedError
def _mm(
@ -377,6 +379,7 @@ def to_sparse_semi_structured(
return SPARSE_SUBCLASS.from_dense(original_tensor)
class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
"""
This class implements semi-structured sparsity for the CUTLASS backend.
@ -388,6 +391,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and
sparse_semi_structured_from_dense for conversion to the compressed format.
"""
BACKEND = "cutlass"
_DTYPE_SHAPE_CONSTRAINTS = {
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
@ -417,13 +421,19 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
def to_dense(self):
assert self.meta is not None and self.packed is not None
return sparse_semi_structured_to_dense_cutlass(
return (
sparse_semi_structured_to_dense_cutlass(
self.packed,
self.meta,
) if self.meta.ndim == 2 else super().to_dense()
)
if self.meta.ndim == 2
else super().to_dense()
)
@classmethod
def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
def prune_dense_static_sort(
cls, original_tensor: torch.Tensor, algorithm=""
) -> "SparseSemiStructuredTensor":
"""
This function takes in a unpruned dense tensor and runs a (branchless) static sort across a 4x4 tile.
@ -463,10 +473,15 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
```
"""
# We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.
(packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
original_tensor,
algorithm=algorithm,
use_cutlass=True)
(
packed,
meta,
packed_t,
meta_t,
compressed_swizzled_bitmask,
) = torch._sparse_semi_structured_tile(
original_tensor, algorithm=algorithm, use_cutlass=True
)
return cls(
original_tensor.shape,
@ -479,11 +494,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
)
def _mm(
self,
B: torch.Tensor,
*,
bias: Optional[torch.Tensor] = None,
**kwargs
self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs
) -> torch.Tensor:
if isinstance(B, SparseSemiStructuredTensor):
raise ValueError(
@ -500,9 +511,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
)
else:
if bias is None:
res = torch._sparse_semi_structured_mm(
self.packed, self.meta, B
)
res = torch._sparse_semi_structured_mm(self.packed, self.meta, B)
else:
res = torch._sparse_semi_structured_addmm(
bias, self.packed, self.meta, B
@ -521,6 +530,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
"""
BACKEND = "cusparselt"
_DTYPE_SHAPE_CONSTRAINTS = {
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
@ -530,7 +540,9 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
}
@classmethod
def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensorCUSPARSELT":
def from_dense(
cls, original_tensor: torch.Tensor
) -> "SparseSemiStructuredTensorCUSPARSELT":
cls._validate_device_dim_dtype_shape(original_tensor)
return cls(
shape=original_tensor.shape,
@ -545,7 +557,9 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
)
@classmethod
def prune_dense_static_sort(cls, original_tensor : torch.Tensor, algorithm="") -> "SparseSemiStructuredTensor":
def prune_dense_static_sort(
cls, original_tensor: torch.Tensor, algorithm=""
) -> "SparseSemiStructuredTensor":
"""
This function does the same thing as described in SparseSemiStructuredCUTLASS, but uses the cuSPASRELt metadata
layout and sparse matmul.
@ -576,10 +590,15 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask)
```
"""
(packed, meta, packed_t, meta_t, compressed_swizzled_bitmask) = torch._sparse_semi_structured_tile(
original_tensor,
algorithm=algorithm,
use_cutlass=False)
(
packed,
meta,
packed_t,
meta_t,
compressed_swizzled_bitmask,
) = torch._sparse_semi_structured_tile(
original_tensor, algorithm=algorithm, use_cutlass=False
)
return cls(
original_tensor.shape,
@ -592,11 +611,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
)
def _mm(
self,
B: torch.Tensor,
*,
bias: Optional[torch.Tensor] = None,
**kwargs
self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs
) -> torch.Tensor:
if isinstance(B, SparseSemiStructuredTensor):
raise ValueError(