[sparse] semi-structured sparse refactor (#117302)

Summary:

This PR is a refactor of semi-structured sparsity support.

**deprecation**:

Previously, `torch.sparse.to_sparse_semi_structured` accepted a `transposed=False` kwarg. The kwarg was unused and its handling has been removed; passing it now only throws a deprecation warning.
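
A minimal before/after sketch (the tensor shape and dtype are chosen only to satisfy the usual 2:4 kernel constraints; requires a CUDA device with sparse tensor core support):

```python
import torch
from torch.sparse import to_sparse_semi_structured

# a 2:4-sparse fp16 weight: two non-zeros in every group of four elements
W = torch.tensor([[0, 0, 1, 1]], dtype=torch.float16, device="cuda").tile((128, 32))

# before: W_sparse = to_sparse_semi_structured(W, transposed=False)
# after:  drop the kwarg; passing it only raises a deprecation warning
W_sparse = to_sparse_semi_structured(W)
```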

Specifically, I've taken the subclassing implementation that xFormers
created and brought it over to PyTorch, as part of our plan to upstream
runtime 2:4 sparsity.

I've also copied over all the op support that Daniel implemented that
does not depend on the fast sparsification routines into
`_sparse_semi_structured_ops.py`.

With this subclass, all of our internal tests pass, as well as those in
xFormers.

The main change is that we now define a base subclass,
`SparseSemiStructuredTensor`, which each of the specific backends
inherits from.
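
Roughly, the hierarchy looks like this (the class names are the ones exported in this PR; the bodies below are an illustrative sketch, not the actual implementation):

```python
import torch

class SparseSemiStructuredTensor(torch.Tensor):
    """Base subclass: holds the packed values/metadata and the shared
    dispatch logic used by both backends."""
    _FORCE_CUTLASS: bool = False  # test hook to force the CUTLASS backend

class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
    """Backend that routes matmuls to the CUTLASS-based kernels."""

class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
    """Backend that routes matmuls to cuSPARSELt."""
```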

We can also now arbitrarily override the sparse dispatch table with
`_load_dispatch_table()`; the idea is that this stays general enough
that users don't need to modify PyTorch source code to get their models
working.
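
A hedged sketch of what such an override might look like; the exact signature of `_load_dispatch_table()` and the `packed` attribute are assumptions made for illustration only:

```python
import torch
from torch.sparse import SparseSemiStructuredTensor

def my_values(func, types, args=(), kwargs=None):
    # hypothetical handler: return the packed representation for aten.values
    return args[0].packed  # `packed` attribute assumed for illustration

# assumption: the custom table is merged on top of the default dispatch table,
# so extra op support can be registered without editing pytorch source
SparseSemiStructuredTensor._load_dispatch_table(
    {torch.ops.aten.values.default: my_values}
)
```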

This also adds padding support and stores `alg_id` and `fuse_transpose`
as flags on the tensor, instead of hardcoding them.
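
For example (the attribute names here are my assumption based on the refactored `torch.sparse.semi_structured` module, not spelled out in this description):

```python
# continuing from W_sparse above: the cuSPARSELt knobs now live on the tensor
print(W_sparse.alg_id_cusparselt)          # algorithm id, previously hardcoded
print(W_sparse.fuse_transpose_cusparselt)  # whether the transpose is fused into the matmul
```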

Two components in xFormers still remain to be ported over eventually:
- the autograd functions (`Sparsify24`, `Sparsify24_like`)
- the fast sparsification routines that they rely on

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117302
Approved by: https://github.com/alexsamardzic, https://github.com/HDCharles
Author: Jesse Cai
Date: 2024-02-12 17:12:20 -08:00
Committed by: PyTorch MergeBot
Parent: 2536c5186e
Commit: 16369816a2
4 changed files with 583 additions and 427 deletions

@@ -6,9 +6,10 @@ import unittest
 import torch
 from torch import nn
-from torch.sparse.semi_structured import (
-    _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG,
+from torch.sparse import (
     SparseSemiStructuredTensor,
+    SparseSemiStructuredTensorCUSPARSELT,
+    SparseSemiStructuredTensorCUTLASS,
     to_sparse_semi_structured,
 )
@@ -36,7 +37,7 @@ from torch.utils._triton import has_triton
 CUSPARSELT_NUM_ALG_IDS = 4
 CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32]
-SEMI_STRUCTURED_SUPPORTED_DTYPES = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG.keys()
+SEMI_STRUCTURED_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16, torch.float32, torch.int8]
 SEMI_STRUCTURED_SUPPORTED_BACKENDS = []
 _IS_SM8X = False
@@ -315,7 +316,7 @@ class TestSparseSemiStructured(TestCase):
         with self.assertRaisesRegex(
             NotImplementedError,
-            r"arg0: SparseSemiStructuredTensor\(.*transposed=True",
+            r"`SparseSemiStructuredTensor.*` matmul: operation is not supported",
         ):
             torch.mm(A_sparse.t(), B)
@@ -357,7 +358,7 @@ class TestSparseSemiStructured(TestCase):
         with self.assertRaisesRegex(
             NotImplementedError,
-            r"arg1: SparseSemiStructuredTensor\(.*transposed=False",
+            r"`SparseSemiStructuredTensor.*` matmul: operation is not supported",
         ):
             sparse_result = torch.mm(A, B_sparse)
@@ -438,7 +439,10 @@ class TestSparseSemiStructured(TestCase):
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
     def test_min_sparse_shape(self, dtype, device, backend):
         SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
-        config = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[dtype]
+        if backend == "cutlass":
+            config = SparseSemiStructuredTensorCUTLASS._DTYPE_SHAPE_CONSTRAINTS[dtype]
+        elif backend == "cusparselt":
+            config = SparseSemiStructuredTensorCUSPARSELT._DTYPE_SHAPE_CONSTRAINTS[dtype]
         A = rand_sparse_semi_structured_mask(config.sparse_min_rows, config.sparse_min_cols, dtype=dtype, device=device)
         A_sparse = to_sparse_semi_structured(A)
         B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype)