[sparse] semi-structured sparse refactor (#117302)
Summary:

This PR is a refactor of semi-structured sparsity support. Namely, I've taken the subclassing implementation that xFormers has created and brought it over to PyTorch, as part of our plan to upstream runtime 2:4 sparsity. I've also copied over all of the op support that Daniel implemented that did not depend on the fast sparsification routines, into `_sparse_semi_structured_ops.py`. With this subclass, all of our internal tests pass, as well as those in xFormers.

The main change is that we now define a base subclass, `SparseSemiStructuredTensor`, that each of the specific backends inherits from. We can also now arbitrarily override the sparse dispatch table with `_load_dispatch_table()`, the idea being that this remains general enough that users don't need to modify PyTorch source code to get their models working. This PR also adds padding support and stores `alg_id` and `fuse_transpose` as flags on the tensor, instead of hardcoding them.

**Deprecation**: `torch.sparse.to_sparse_semi_structured` previously took a kwarg `transposed=False`, which has been removed. The kwarg was unused, and passing it now triggers a deprecation warning.

Two components in xFormers still need to be ported over eventually:
- the autograd functions (`Sparsify24`, `Sparsify24_like`)
- the fast sparsification routines that they rely on

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117302
Approved by: https://github.com/alexsamardzic, https://github.com/HDCharles
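For orientation, here is a minimal sketch of the user-facing flow after this refactor, pieced together from the names exercised in the test diff below (`to_sparse_semi_structured`, `SparseSemiStructuredTensor._FORCE_CUTLASS`, `torch.mm`). The shapes, the hand-built 2:4 pattern, and the availability of a CUDA device with semi-structured sparse support are assumptions of the sketch, not part of the PR:

```python
import torch
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured

# The tests below force a specific backend this way; by default the
# subclass chooses between the CUTLASS and cuSPARSELt backends itself.
SparseSemiStructuredTensor._FORCE_CUTLASS = True

# A dense fp16 matrix that already satisfies the 2:4 pattern:
# two zeros in every contiguous group of four elements.
A = torch.tensor([0, 0, 1, 1], dtype=torch.float16, device="cuda").tile((64, 16))
A_sparse = to_sparse_semi_structured(A)  # compressed subclass instance

B = torch.rand((64, 128), dtype=torch.float16, device="cuda")

# Dispatch routes the matmul to the sparse backend; the result should
# match dense torch.mm(A, B) up to numerical tolerance.
sparse_result = torch.mm(A_sparse, B)
```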
Committed by: PyTorch MergeBot
Parent: 2536c5186e
Commit: 16369816a2
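The deprecation called out in the summary is observable at the call site. A sketch, under the assumption that supplying the removed `transposed` kwarg emits a Python warning rather than raising (the tensor construction mirrors the example above):

```python
import warnings
import torch
from torch.sparse import to_sparse_semi_structured

A = torch.tensor([0, 0, 1, 1], dtype=torch.float16, device="cuda").tile((64, 16))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Assumption: the deprecated, unused kwarg produces a warning here.
    to_sparse_semi_structured(A, transposed=True)

assert any("transposed" in str(w.message) for w in caught)
```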
@@ -6,9 +6,10 @@ import unittest
 import torch
 from torch import nn

-from torch.sparse.semi_structured import (
-    _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG,
+from torch.sparse import (
+    SparseSemiStructuredTensor,
+    SparseSemiStructuredTensorCUSPARSELT,
+    SparseSemiStructuredTensorCUTLASS,
     to_sparse_semi_structured,
 )
@@ -36,7 +37,7 @@ from torch.utils._triton import has_triton
 CUSPARSELT_NUM_ALG_IDS = 4
 CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32]

-SEMI_STRUCTURED_SUPPORTED_DTYPES = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG.keys()
+SEMI_STRUCTURED_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16, torch.float32, torch.int8]
 SEMI_STRUCTURED_SUPPORTED_BACKENDS = []

 _IS_SM8X = False
@@ -315,7 +316,7 @@ class TestSparseSemiStructured(TestCase):

         with self.assertRaisesRegex(
             NotImplementedError,
-            r"arg0: SparseSemiStructuredTensor\(.*transposed=True",
+            r"`SparseSemiStructuredTensor.*` matmul: operation is not supported",
         ):
             torch.mm(A_sparse.t(), B)
@@ -357,7 +358,7 @@ class TestSparseSemiStructured(TestCase):

         with self.assertRaisesRegex(
             NotImplementedError,
-            r"arg1: SparseSemiStructuredTensor\(.*transposed=False",
+            r"`SparseSemiStructuredTensor.*` matmul: operation is not supported",
         ):
             sparse_result = torch.mm(A, B_sparse)
@@ -438,7 +439,10 @@ class TestSparseSemiStructured(TestCase):
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
     def test_min_sparse_shape(self, dtype, device, backend):
         SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
-        config = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[dtype]
+        if backend == "cutlass":
+            config = SparseSemiStructuredTensorCUTLASS._DTYPE_SHAPE_CONSTRAINTS[dtype]
+        elif backend == "cusparselt":
+            config = SparseSemiStructuredTensorCUSPARSELT._DTYPE_SHAPE_CONSTRAINTS[dtype]
         A = rand_sparse_semi_structured_mask(config.sparse_min_rows, config.sparse_min_cols, dtype=dtype, device=device)
         A_sparse = to_sparse_semi_structured(A)
         B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype)
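Finally, to picture the `_load_dispatch_table()` hook mentioned in the summary: a sketch of overriding a single op without touching PyTorch source. The handler name `refuse_addmm`, the `(func, types, args, kwargs)` handler signature, and the premise that the method merges a user-supplied dict into a class-level table keyed by aten overload packets are all assumptions of this sketch:

```python
import torch
from torch.sparse import SparseSemiStructuredTensor

# Hypothetical handler; assumed to be invoked the way __torch_dispatch__
# invokes table entries, i.e. handler(func, types, args, kwargs).
def refuse_addmm(func, types, args=(), kwargs=None):
    raise NotImplementedError(
        "addmm on 2:4 sparse tensors is disabled in this sketch"
    )

# Assumption: the dict is merged into the subclass's dispatch table,
# overriding the stock entry for aten.addmm.
SparseSemiStructuredTensor._load_dispatch_table(
    {torch.ops.aten.addmm: refuse_addmm}
)
```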