[core][pruning][sparse][feature] SparseSemiStructured tensor subclass (#102135)

This PR adds support for semi-structured sparsity via a tensor
subclass. It currently uses the CUTLASS kernels merged in PR #100881.

In the future we plan to add cuSPARSELt support (see the other PRs in
the stack), which will give us larger performance gains.

This PR adds two things:
- a tensor subclass, `SparseSemiStructuredTensor`, which stores the
  sparse tensor in compressed form and overrides `__torch_dispatch__`.
- a conversion function that takes in a dense tensor and a
  semi-structured sparse bool mask and creates an instance of the
  subclass.

**SparseSemiStructuredTensor**

The subclass stores the dense tensor in a contiguous flattened tensor
for future compatibility with cuSPARSELt, which expects this format.
Note that the CUTLASS kernels do not have this limitation, as the
specified values and the metadata are passed separately to
`_structured_sparse_linear`. In the future we can use the cuSPARSELt bindings
[here](https://github.com/pytorch/pytorch/pull/103700) for faster matmul, better dtype coverage, and relaxed shape
constraints.
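
As a sketch of the layout (the slicing below mirrors the `values()`/`indices()`
handling later in this PR; names and sizes are illustrative):

```
import torch

m, k = 64, 64
# One flat buffer per sparse tensor: [ specified elements | metadata ].
# Sizes follow the 2:4 fp16 layout: m*k//2 kept values, 2 bits of metadata each.
compressed_tensor = torch.empty(m * k // 2 + m * k // 16, dtype=torch.float16)
values = compressed_tensor[: m * k // 2].view(m, k // 2)
metadata = compressed_tensor[m * k // 2 :].view(m, -1).view(torch.int16)
```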

Since we currently don't have a way to go back from the sparse
representation to the dense representation, and we store the weights in
compressed form, we don't have a great way to handle `.t()`.

Instead, we keep track of how often transpose has been called on our
tensor and throw an error if the count is unexpected: when the first
argument is sparse, we expect an even number of calls to transpose,
while when the second argument is sparse, we expect an odd number.
This is because we support matrix multiplications with a sparse second
argument by using transpose properties.
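
A minimal illustration of the transpose identity we rely on (plain dense
tensors here, since only the algebra matters):

```
import torch

# A kernel that only supports a sparse *first* operand can still compute
# dense @ sparse, because (A @ B) == (B.t() @ A.t()).t()
A = torch.rand(4, 8)
B = torch.rand(8, 4)
assert torch.allclose(A @ B, (B.t() @ A.t()).t(), atol=1e-6)
```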

**to_sparse_semi_structured**

This is a conversion function to convert a dense tensor and a
semi-structured sparse bool mask into a subclass. Currently, we must
pass in a bool mask, since we can't infer it: there may be
additional zero elements in the dense tensor, so `tensor != 0` is not
necessarily 2:4 sparse.

Once we add either a method to derive the mask from the dense tensor or
cuSPARSELt, we will no longer need to pass in the mask. cuSPARSELt has its
own helper functions to create the metadata mask.
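
For example, here is a sketch of why inferring the mask can fail: a weight
kept by the 2:4 mask may itself be zero.

```
import torch

# Pruned to 2:4 with mask [0, 0, 1, 1], but one *kept* weight happens to be 0.
w = torch.tensor([0.0, 0.0, 0.5, 0.0])
mask = torch.tensor([False, False, True, True])
print(w != 0)  # tensor([False, False,  True, False]) -- not the original 2:4 mask
```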

**User Details**

We have implemented support for the following ops for `torch.float16`
and `torch.int8`:
```
torch.addmm(bias, dense, sparse.t())
torch.mm(dense, sparse)
torch.mm(sparse, dense)
aten.linear.default
aten.t.default
aten.detach.default
```
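
For example, the tensor-level call patterns look like this (a sketch assuming
a CUDA device with SM8x support; names here are illustrative):

```
import torch
from torch.sparse import to_sparse_semi_structured

A = torch.Tensor([0, 0, 1, 1]).tile((64, 16)).half().cuda()  # 64x64, already 2:4 sparse
B = torch.rand(64, 64).half().cuda()
bias = torch.rand(64).half().cuda()
A_sparse = to_sparse_semi_structured(A, mask=A.bool())

out = torch.mm(A_sparse, B)               # sparse x dense
out = torch.mm(B, A_sparse.t())           # dense x sparse, via the transpose trick
out = torch.addmm(bias, B, A_sparse.t())  # fused bias add
```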

The end user interface to accelerate an nn.Linear module with the
subclass would look like this:

```
import torch
from torch import nn
from torch.sparse import to_sparse_semi_structured

mask = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).cuda().bool()
linear = nn.Linear(128, 128).half().cuda()
# zero out the weights that the mask prunes away
linear.weight = nn.Parameter(linear.weight * mask)

linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight,
                                                       mask=mask))
```

This also updates tests and the `torch.sparse` module docstring to
reflect these changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102135
Approved by: https://github.com/albanD
Author: Jesse Cai
Date: 2023-06-27 09:01:27 -07:00
Committed by: PyTorch MergeBot
Parent: 39868b0578
Commit: 2da6cae43c

5 changed files with 1016 additions and 2 deletions


@@ -0,0 +1,245 @@
import random
import torch
import torch.utils.benchmark as benchmark
from torch import nn
from tqdm import tqdm
import pandas as pd
import argparse
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor

torch.set_printoptions(
    precision=2,
    threshold=None,
    edgeitems=16,
    linewidth=480,
    profile=None,
    sci_mode=False,
)


# helper model definition for pruner
class Model(nn.Module):
    def __init__(self, m, k, dtype=None):
        super().__init__()
        # transposed so reversed
        self.linear = nn.Linear(k, m)

    def forward(self, x):
        return self.linear(x)


def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """
    This function returns a 1:2 sparse matrix of size (r, c).
    Note that this means this matrix will also be 2:4 and 4:8 sparse as well.
    """
    choices = [[0, 1], [1, 0]]
    mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]
    return (
        torch.tensor(mask_entries, dtype=dtype, device=device)
        .reshape(r, c)
        .contiguous()
    )
def test_linear(m, k, n, dtype, contiguous, backend):
    SparseSemiStructuredTensor.fuse_transpose = contiguous
    mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
    sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask
    input_tensor = torch.zeros(n, k).to(dtype).cuda()
    model = Model(m, k).to(dtype).cuda().eval()

    dense_measurement = benchmark.Timer(
        stmt="model(input_tensor)",
        globals=locals(),
    ).blocked_autorange()
    dense_output = model(input_tensor)

    # sparsify weights
    model.linear.weight = nn.Parameter(
        to_sparse_semi_structured(sparse_weight, mask=mask.bool())
    )
    sparse_output = model(input_tensor)
    sparse_measurement = benchmark.Timer(
        stmt="model(input_tensor)",
        globals=locals(),
    ).blocked_autorange()

    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)

    return {
        "test_function": "linear",
        "m": m,
        "k": k,
        "n": n,
        "dtype": str(dtype),
        "backend": backend,
        "sparse_latency (ms)": sparse_measurement.median * 1000,
        "dense_latency (ms)": dense_measurement.median * 1000,
        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
        "correct": correct,
        "contiguous": sparse_output.is_contiguous(),
    }
def test_tensor(m, k, n, dtype, contiguous, backend):
    A = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
    B = torch.zeros(k, n).to(dtype).cuda()
    bias = torch.rand(n).to(dtype).cuda()

    sA = to_sparse_semi_structured(A, mask=A.bool())

    # torch.mm calculation
    if dtype is not torch.int8:
        dense_output = torch.mm(A, B)
        dense_measurement = benchmark.Timer(
            stmt="torch.mm(A, B)",
            globals=locals(),
        ).blocked_autorange()
    else:
        print("int8 baseline not supported")
        dense_output = torch.mm(sA, B)
        dense_measurement = benchmark.Timer(
            stmt="torch.mm(sA, B)",
            globals=locals(),
        ).blocked_autorange()

    sparse_output = torch.mm(sA, B)
    sparse_measurement = benchmark.Timer(
        stmt="torch.mm(sA, B)",
        globals=locals(),
    ).blocked_autorange()

    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)

    return {
        "test_function": "tensor",
        "m": m,
        "k": k,
        "n": n,
        "dtype": str(dtype),
        "backend": backend,
        "sparse_latency (ms)": sparse_measurement.median * 1000,
        "dense_latency (ms)": dense_measurement.median * 1000,
        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
        "correct": correct,
        "contiguous": sparse_output.is_contiguous(),
    }
if __name__ == "__main__":
    dtype_lookup = {
        "int8": torch.int8,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }

    parser = argparse.ArgumentParser(description="Semi-Structured Sparsity Benchmarks")
    parser.add_argument(
        "--mode",
        type=str,
        choices=[
            "nvidia-bert",
            "nvidia-fixed-k",
            "nvidia-fixed-mn",
        ],
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=dtype_lookup.keys(),
        default="fp16",
    )
    parser.add_argument(
        "--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt"
    )
    parser.add_argument("-contiguous", action="store_true")
    parser.add_argument("-e2e", action="store_true")
    parser.add_argument("-save", action="store_true")
    args = parser.parse_args()

    if args.e2e:
        eval_fn = test_linear
    else:
        eval_fn = test_tensor

    print(f"Started benchmark: {args.mode} | dtype: {args.dtype}")
    dtype = dtype_lookup[args.dtype]

    if args.mode == "nvidia-bert":
        bert_shapes = [
            (3072, 1024, 16384),
            (4096, 1024, 16384),
            (1024, 1024, 16384),
            (1024, 4096, 16384),
        ]
        results = (
            eval_fn(m, k, n, dtype, args.contiguous, args.backend)
            for (m, k, n) in tqdm(bert_shapes)
        )
    elif args.mode == "nvidia-fixed-k":
        mn_vals = [
            3072, 4096, 5120, 6144, 7168, 8192, 9216, 10240, 11264,
            12288, 13312, 14336, 15360, 16384, 17408, 18432, 19456, 20480,
        ]
        results = (
            eval_fn(mn, 10240, mn, dtype, args.contiguous, args.backend)
            for mn in tqdm(mn_vals)
        )
    elif args.mode == "nvidia-fixed-mn":
        k_vals = [
            2560, 3840, 5120, 6400, 7680, 8960, 10240, 11520,
            12800, 14080, 15360, 16640, 17920, 19200, 20480,
        ]
        results = (
            eval_fn(10240, k, 10240, dtype, args.contiguous, args.backend)
            for k in tqdm(k_vals)
        )

    df = pd.DataFrame.from_records(results)
    if args.save:
        save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv"
        df.to_csv(save_file)
        print(f"Finished benchmark: {args.mode} saved results to {save_file}")
    print(df)


@@ -24,7 +24,7 @@ matrices, pruned weights or point clouds by Tensors whose *elements are
mostly zero valued*. We recognize these are important applications and aim
to provide performance optimizations for these use cases via sparse storage formats.
Various sparse storage formats such as COO, CSR/CSC, semi-structured, LIL, etc. have been
developed over the years. While they differ in exact layouts, they all
compress data through efficient representation of zero valued elements.
We call the uncompressed values *specified* in contrast to *unspecified*,
@@ -67,6 +67,8 @@ indices of non-zero elements are stored in this case.
PyTorch currently supports :ref:`COO<sparse-coo-docs>`, :ref:`CSR<sparse-csr-docs>`,
:ref:`CSC<sparse-csc-docs>`, :ref:`BSR<sparse-bsr-docs>`, and :ref:`BSC<sparse-bsc-docs>`.
We also have a prototype implementation to support :ref:`semi-structured sparsity<sparse-semi-structured-docs>`.
Please see the references for more details.
Note that we provide slight generalizations of these formats.
@@ -167,6 +169,147 @@ receiving a particular layout. We are working on an API to control the result layout
and recognize it is an important feature to plan a more optimal path of execution for
any given model.
.. _sparse-semi-structured-docs:
Sparse Semi-Structured Tensors
++++++++++++++++++++++++++++++
.. warning::

   Sparse semi-structured tensors are currently a prototype feature and subject to change. Please feel free to open an issue to report a bug or if you have feedback to share.
Semi-structured sparsity is a sparse data layout that was first introduced in NVIDIA's Ampere architecture. It is also referred to as **fine-grained structured sparsity** or **2:4 structured sparsity**.
This sparse layout stores `n` elements out of every `2n` elements, with `n` being determined by the width of the Tensor's data type (dtype). The most frequently used dtype is float16, where `n=2`, thus the term "2:4 structured sparsity."
Semi-structured sparsity is explained in greater detail in `this NVIDIA blog post <https://developer.nvidia.com/blog/exploiting-ampere-structured-sparsity-with-cusparselt>`_.
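
For example, for ``torch.float16`` (`n=2`), a valid pattern keeps at most two nonzero values in every contiguous block of four:

>>> # a valid 2:4-sparse pattern, viewed as blocks of four
>>> torch.tensor([1., 2., 0., 0., 0., 3., 0., 4.]).view(2, 4)
tensor([[1., 2., 0., 0.],
        [0., 3., 0., 4.]])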
In PyTorch, semi-structured sparsity is implemented via a Tensor subclass.
By subclassing, we can override ``__torch_dispatch__``, allowing us to use faster sparse kernels when performing matrix multiplication.
We can also store the tensor in its compressed form inside the subclass to reduce memory overhead.
In this compressed form, the sparse tensor is stored by retaining only the *specified* elements and some metadata, which encodes the mask.
.. note::

   The specified elements and metadata mask of a semi-structured sparse tensor are stored together in a single
   flat compressed tensor. They are appended to each other to form a contiguous chunk of memory.

   compressed tensor = [ specified elements of original tensor | metadata_mask ]

   For an original tensor of size `(r, c)` we expect the first `r * c // 2` elements to be the kept elements
   and the rest of the tensor to be metadata.
In order to make it easier for the user to view the specified elements
and mask, one can use ``.indices()`` and ``.values()`` to access the mask and specified elements respectively.
- ``.values()`` returns the specified elements in a tensor of size `(r, c//2)` and with the same dtype as the dense matrix.
- ``.indices()`` returns the metadata_mask in a tensor with element type ``torch.int16`` if dtype is ``torch.float16`` and element type ``torch.int32`` if dtype is ``torch.int8`` (for ``torch.float16`` the shape is `(r, c//16)`, since each specified element takes 2 bits of metadata).
For 2:4 sparse tensors, the metadata overhead is minor: just 2 bits per specified element.
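
Mirroring the shape assertions in this PR's tests (CUDA device assumed):

>>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
>>> A_sparse = to_sparse_semi_structured(A, mask=A.bool())
>>> A_sparse.values().shape
torch.Size([128, 64])
>>> A_sparse.indices().shape
torch.Size([128, 8])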
.. note::

   It's important to note that ``torch.float32`` is only supported for 1:2 sparsity. Therefore, it does not follow the same formula as above.
Here, we break down how to calculate the compression ratio (size sparse / size dense) of a 2:4 sparse tensor.
Let `(r, c) = tensor.shape` and `e = bitwidth(tensor.dtype)`, so `e = 16` for ``torch.float16`` and ``torch.bfloat16`` and `e = 8` for ``torch.int8``.
.. math::

   M_{dense} = r \times c \times e \\
   M_{sparse} = M_{specified} + M_{metadata} = r \times \frac{c}{2} \times e + r \times \frac{c}{2} \times 2 = \frac{rce}{2} + rc = rce \left( \frac{1}{2} + \frac{1}{e} \right)
Using these calculations, we can determine the total memory footprint for both the original dense and the new sparse representation.
This gives us a simple formula for the compression ratio, which is dependent only on the bitwidth of the tensor datatype.
.. math::

   C = \frac{M_{sparse}}{M_{dense}} = \frac{1}{2} + \frac{1}{e}
By using this formula, we find that the compression ratio is 56.25% for ``torch.float16`` and 62.5% for ``torch.int8``.
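
A quick numeric check of this formula (plain Python):

>>> [1 / 2 + 1 / e for e in (16, 8)]
[0.5625, 0.625]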
Constructing Sparse Semi-Structured Tensors
-------------------------------------------
You can transform a dense tensor into a sparse semi-structured tensor by using the ``torch.sparse.to_sparse_semi_structured`` function.
Please also note that we only support CUDA tensors since hardware compatibility for semi-structured sparsity is limited to NVIDIA GPUs.
The following datatypes are supported for semi-structured sparsity. Note that each datatype has its own shape constraints and compression factor.
.. csv-table::
   :header: "PyTorch dtype", "Shape Constraints", "Compression Factor", "Sparsity Pattern"
   :widths: 15, 45, 10, 10
   :delim: ;

   ``torch.float16``; Tensor must be 2D and (r, c) must both be a positive multiple of 64; 9/16; 2:4
   ``torch.int8``; Tensor must be 2D and (r, c) must both be a positive multiple of 128; 10/16; 2:4
To construct a semi-structured sparse tensor, start by creating a regular dense tensor that adheres to a 2:4 (or semi-structured) sparse format.
To do this we tile a small 1x4 strip to create a 128x128 dense float16 tensor.
Afterwards, we can call ``to_sparse_semi_structured`` on this matrix to compress it for accelerated inference.
>>> from torch.sparse import to_sparse_semi_structured
>>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
>>> A
tensor([[0., 0., 1., ..., 0., 1., 1.],
[0., 0., 1., ..., 0., 1., 1.],
[0., 0., 1., ..., 0., 1., 1.],
...,
[0., 0., 1., ..., 0., 1., 1.],
[0., 0., 1., ..., 0., 1., 1.],
[0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
>>> A_sparse = to_sparse_semi_structured(A, mask=A.bool())
>>> A_sparse
SparseSemiStructuredTensor(shape=torch.Size([128, 128]), transposed=False, values=tensor([[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16), metadata=tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
...,
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0',
dtype=torch.int16))
Sparse Semi-Structured Tensor Operations
----------------------------------------
Currently, the following operations are supported for semi-structured sparse tensors:
- torch.addmm(bias, dense, sparse.t())
- torch.mm(dense, sparse)
- torch.mm(sparse, dense)
- aten.linear.default(dense, sparse, bias)
- aten.t.default(sparse)
- aten.detach.default(sparse)
To use these ops, simply pass the output of ``to_sparse_semi_structured(tensor)`` instead of using ``tensor`` once your tensor has 0s in a semi-structured sparse format, like this:
>>> a = torch.Tensor([0, 0, 1, 1]).tile((64, 16)).half().cuda()
>>> b = torch.rand(64, 64).half().cuda()
>>> c = torch.mm(a, b)
>>> a_sparse = to_sparse_semi_structured(a, mask=a.bool())
>>> torch.allclose(c, torch.mm(a_sparse, b))
True
Under the hood, SparseSemiStructuredTensor will call ``torch._structured_sparse_linear`` for accelerated inference using CUTLASS sparse kernels.
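
Roughly, ``torch.mm(a_sparse, b)`` from the example above reduces to the following call into this private op (a sketch; internal API, subject to change):

>>> out, _ = torch._structured_sparse_linear(b.t(), a_sparse.values(), a_sparse.indices())
>>> torch.allclose(c, out.t())
True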
Accelerating nn.Linear with semi-structured sparsity
----------------------------------------------------
You can accelerate the linear layers in your model if the weights are already semi-structured sparse with just a few lines of code:
>>> input = torch.rand(64, 64).half().cuda()
>>> mask = torch.Tensor([0, 0, 1, 1]).tile((64, 16)).cuda().bool()
>>> linear = nn.Linear(64, 64).half().cuda()
>>> linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight, mask=mask))
>>> output = linear(input)
.. _sparse-coo-docs:
@@ -992,12 +1135,18 @@ multiplication, and ``@`` is matrix multiplication.
:func:`torch.mv`;no; ``M[sparse_csr] @ V[strided] -> V[strided]``
:func:`torch.matmul`; no; ``M[sparse_coo] @ M[strided] -> M[strided]``
:func:`torch.matmul`; no; ``M[sparse_csr] @ M[strided] -> M[strided]``
:func:`torch.matmul`; no; ``M[SparseSemiStructured] @ M[strided] -> M[strided]``
:func:`torch.matmul`; no; ``M[strided] @ M[SparseSemiStructured] -> M[strided]``
:func:`torch.mm`; no; ``M[sparse_coo] @ M[strided] -> M[strided]``
:func:`torch.mm`; no; ``M[SparseSemiStructured] @ M[strided] -> M[strided]``
:func:`torch.mm`; no; ``M[strided] @ M[SparseSemiStructured] -> M[strided]``
:func:`torch.sparse.mm`; yes; ``M[sparse_coo] @ M[strided] -> M[strided]``
:func:`torch.smm`; no; ``M[sparse_coo] @ M[strided] -> M[sparse_coo]``
:func:`torch.hspmm`; no; ``M[sparse_coo] @ M[strided] -> M[hybrid sparse_coo]``
:func:`torch.bmm`; no; ``T[sparse_coo] @ T[strided] -> T[strided]``
:func:`torch.addmm`; no; ``f * M[strided] + f * (M[sparse_coo] @ M[strided]) -> M[strided]``
:func:`torch.addmm`; no; ``f * M[strided] + f * (M[SparseSemiStructured] @ M[strided]) -> M[strided]``
:func:`torch.addmm`; no; ``f * M[strided] + f * (M[strided] @ M[SparseSemiStructured]) -> M[strided]``
:func:`torch.sparse.addmm`; yes; ``f * M[strided] + f * (M[sparse_coo] @ M[strided]) -> M[strided]``
:func:`torch.sspaddmm`; no; ``f * M[sparse_coo] + f * (M[sparse_coo] @ M[strided]) -> M[sparse_coo]``
:func:`torch.lobpcg`; no; ``GENEIG(M[sparse_coo]) -> M[strided], M[strided]``


@@ -0,0 +1,227 @@
# Owner(s): ["module: sparse"]
import random
import unittest

import torch
from torch import nn

from torch.sparse.semi_structured import (
    _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG,
    SparseSemiStructuredTensor,
    to_sparse_semi_structured,
)

from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)

from torch.testing._internal.common_dtype import all_types_and_complex

from torch.testing._internal.common_utils import (
    parametrize,
    run_tests,
    subtest,
    TestCase,
)

SEMI_STRUCTURED_SUPPORTED_DTYPES = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG.keys()

_IS_SM8X = False
if torch.cuda.is_available():
    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8


def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """
    This function returns a 1:2 sparse matrix of size (r, c).
    Note that this means this matrix will also be 2:4 and 4:8 sparse as well.
    """
    choices = [[0, 1], [1, 0]]
    mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]
    return (
        torch.tensor(mask_entries, dtype=dtype, device=device)
        .reshape(r, c)
        .contiguous()
    )
class TestSparseSemiStructured(TestCase):
    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_to_sparse_semi_structured(self, dtype):
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A, mask=A.bool())

        assert A.shape == A_sparse.shape
        assert A.device == A_sparse.device
        assert A.dtype == A_sparse.dtype

        assert isinstance(A, torch.Tensor)
        assert isinstance(A_sparse, SparseSemiStructuredTensor)

        with self.assertRaisesRegex(
            NotImplementedError,
            "You must pass in a mask to to_sparse_semi_structured, currently mask=None.",
        ):
            A_sparse = to_sparse_semi_structured(A)

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_mm_sparse_first_NT(self, dtype, device):
        """
        Ensure torch.mm(A_sparse, B) is correct for float16 and will throw an error for int8.
        Ensure torch.mm(A_sparse, B.t()) is correct.
        """
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A, mask=A.bool())

        B = torch.rand((128, 128), device=A_sparse.device).to(dtype)

        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
        if dtype is torch.int8:
            # This should fail
            with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_cutlass_dispatch_layouts"):
                sparse_result = torch.mm(A_sparse, B)

            # test transpose
            # NOTE: CUTLASS and cuSPARSELt have slightly different int8 behavior.
            # CUTLASS will output to an int32 tensor while cuSPARSELt will output to an int8 tensor.
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32)
            sparse_result = torch.mm(A_sparse, B.t())
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
        else:
            dense_result = torch.mm(A, B)
            sparse_result = torch.mm(A_sparse, B)
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

            # test transpose
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A_sparse, B.t())
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_mm_sparse_first_T(self, dtype, device):
        """
        Ensure torch.mm(A_sparse.t(), B) throws an error.
        """
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A, mask=A.bool())

        B = torch.rand((128, 128), device=A_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"arg0: SparseSemiStructuredTensor\(.*transposed=True",
        ):
            torch.mm(A_sparse.t(), B)

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_mm_sparse_second_T(self, dtype, device):
        """
        Ensure torch.mm(A, B_sparse.t()) is correct.
        """
        B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B, mask=B.bool())

        A = torch.rand((128, 128), device=B_sparse.device).to(dtype)

        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
        if dtype is torch.int8:
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32)
            sparse_result = torch.mm(A, B_sparse.t())
        else:
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A, B_sparse.t())

        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_mm_sparse_second_NT(self, dtype, device):
        """
        Ensure torch.mm(A, B_sparse) throws an error.
        """
        B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B, mask=B.bool())

        A = torch.rand((128, 128), device=B_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"arg1: SparseSemiStructuredTensor\(.*transposed=False",
        ):
            sparse_result = torch.mm(A, B_sparse)

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    @parametrize("inference_mode", [subtest(False), subtest(True)])
    def test_linear(self, inference_mode, device):
        """
        Test that nn.Linear has the same numerics.
        """
        input = torch.rand(128, 128, device=device).half()
        model = nn.Linear(128, 128).to(device).half()
        m, n = model.weight.shape
        mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool)
        # set masked weight
        model.weight = nn.Parameter(model.weight * mask)

        dense_result = model(input)

        model.weight = nn.Parameter(to_sparse_semi_structured(model.weight, mask=mask))

        if inference_mode:
            with torch.inference_mode():
                sparse_result = model(input)
        else:
            sparse_result = model(input)

        assert torch.allclose(dense_result, sparse_result, rtol=1e-5, atol=1e-5)

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    def test_values(self):
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A, mask=A.bool())
        assert A_sparse.values().shape == (128, 64)
        assert (A_sparse.values() == 1).all()

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    def test_indices(self):
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A, mask=A.bool())
        assert A_sparse.indices().shape == (128, 8)

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_unsupported_shape(self, dtype, device):
        A = rand_sparse_semi_structured_mask(4, 4, dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.shape"):
            A_sparse = to_sparse_semi_structured(A, mask=A.bool())

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    @dtypes(*all_types_and_complex())
    def test_unsupported_dtype(self, dtype, device):
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype, device=device)

        if dtype not in SEMI_STRUCTURED_SUPPORTED_DTYPES:
            with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dtype"):
                A_sparse = to_sparse_semi_structured(A, mask=A.bool())
        else:
            A_sparse = to_sparse_semi_structured(A, mask=A.bool())

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    def test_unsupported_dim(self, device):
        A = torch.rand(128, 128, 128, device=device, dtype=torch.float16)

        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dim"):
            A_sparse = to_sparse_semi_structured(A, mask=A.bool())


instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda")

if __name__ == "__main__":
    run_tests()


@@ -5,6 +5,9 @@ import torch
from torch._C import _add_docstr, _sparse  # type: ignore[attr-defined]
from torch import Tensor

# Semi structured sparsity support
from .semi_structured import SparseSemiStructuredTensor, to_sparse_semi_structured

# A workaround to support both TorchScript and MyPy:
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -23,9 +26,10 @@ __all__ = [
    'sum',
    'softmax',
    'log_softmax',
    'SparseSemiStructuredTensor',
    'to_sparse_semi_structured',
]
addmm = _add_docstr(_sparse._sparse_addmm, r"""
sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor


@@ -0,0 +1,389 @@
import warnings
from collections import namedtuple
from typing import Any, Optional

import torch

__all__ = [
    "to_sparse_semi_structured",
    "SparseSemiStructuredTensor",
]

_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
    "_SEMI_STRUCTURED_SPARSE_CONFIG", "compression_factor min_size"
)
_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG = {
    torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(9, 64),
    torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(10, 128),
}

_WARNING_SHOWN = False


class SparseSemiStructuredTensor(torch.Tensor):
    """This class implements semi-structured sparsity as a Tensor subclass.

    Semi-structured sparsity describes a sparsity pattern where n in every 2n elements are sparse,
    depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
    structured sparsity.

    Currently, this class supports 2:4 sparsity for int8 and float16 dtypes.

    This subclass stores the dense tensor in a compressed form by only storing the specified elements and a metadata mask.
    These two are stored next to each other in one contiguous tensor.

    We choose to store the specified elements and the metadata in a single tensor for future compatibility with cuSPARSELt.

    compressed tensor = [ specified elements of original tensor | mask_metadata ]

    For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements.
    The rest of the tensor is metadata.

    This subclass also overrides __torch_dispatch__ to use _structured_sparse_linear for faster matrix multiplications
    via sparse CUTLASS kernels. In the future we will also call into cuSPARSELt kernels for more performance gains.
    """
    @staticmethod
    def __new__(
        cls,
        original_tensor: Optional[torch.Tensor],
        original_shape: Optional[torch.Size] = None,
        mask: Optional[torch.Tensor] = None,
        compressed_tensor: Optional[torch.Tensor] = None,
        transposed: bool = False,
    ):
        """
        Create a new instance of the class.

        When original_tensor is passed in, we compress it and store the compressed representation.
        We can also create a new instance of the class from the compressed representation without the original tensor.

        Args:
            original_tensor: The original dense tensor, or None, if we have already compressed the tensor.
            original_shape: The shape of the original dense tensor
            mask: Mask to be applied to the original tensor.
            compressed_tensor: A flattened tensor to store the specified elements and mask metadata.
            transposed: Whether the tensor is transposed or not.

        Returns:
            torch.Tensor: A torch.Tensor wrapper subclass.

        Raises:
            ValueError: If both original_tensor and compressed_tensor are None.
        """
        if original_tensor is not None:
            previous_tensor = original_tensor
            original_shape = original_tensor.shape
        elif compressed_tensor is not None:
            previous_tensor = compressed_tensor
        else:
            raise ValueError("Both compressed_tensor and original_tensor are None!")

        kwargs = {}
        kwargs["device"] = previous_tensor.device  # type: ignore[assignment]
        kwargs["dtype"] = previous_tensor.dtype  # type: ignore[assignment]
        kwargs["layout"] = previous_tensor.layout  # type: ignore[assignment]
        kwargs["requires_grad"] = False  # type: ignore[assignment]

        return torch.Tensor._make_wrapper_subclass(cls, original_shape, **kwargs)  # type: ignore[attr-defined]
    def __init__(
        self,
        original_tensor: Optional[torch.Tensor],
        original_shape: Optional[torch.Size] = None,
        mask: Optional[torch.Tensor] = None,
        compressed_tensor: Optional[torch.Tensor] = None,
        transposed: bool = False,
    ) -> None:
        """SparseSemiStructuredTensor constructor.

        Args:
            original_tensor: The original dense tensor, or None, if we have already compressed the tensor.
            original_shape: The shape of the original dense tensor
            mask: Mask to be applied to the original tensor.
            compressed_tensor: A flattened tensor to store the specified elements and mask metadata.
            transposed: Whether the tensor is transposed or not.

        Returns:
            None

        Raises:
            NotImplementedError: If ``mask=None``, as we currently do not support inferring a mask from the dense tensor.
            RuntimeError: If original_tensor is not a supported dtype, dim, shape, or device.
        """
        global _WARNING_SHOWN
        if not _WARNING_SHOWN:
            warnings.warn(
                (
                    "The PyTorch API of SparseSemiStructuredTensor is in prototype stage "
                    "and will change in the near future. Please open a Github issue "
                    "for feature requests and see our documentation on the torch.sparse "
                    "module for further information about the project."
                ),
                UserWarning,
            )
            _WARNING_SHOWN = True

        # if the original tensor is passed in, we need to compress it and store the compressed representation.
        if original_tensor is not None:
            # check if mask passed in
            if mask is None:
                raise NotImplementedError(
                    "You must pass in a mask to to_sparse_semi_structured, currently mask=None."
                )

            # check device
            if not original_tensor.is_cuda:
                raise RuntimeError(
                    f"Error original_tensor.device= {original_tensor.device} is not supported! "
                    "Only CUDA tensors are currently supported."
                )

            # check dim
            if original_tensor.dim() != 2:
                raise RuntimeError(
                    f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
                    "Only 2d tensors are currently supported."
                )

            # check dtype
            if original_tensor.dtype not in _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG:
                raise RuntimeError(
                    f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
                    f"dtype must be one of: {_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG}"
                )

            # check shape
            m, n = original_tensor.shape
            min_size = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype].min_size
            if m < min_size or m % min_size or n < min_size or n % min_size:
                # TODO in the future we can add in padding to support dimensions that aren't perfect multiples
                raise RuntimeError(
                    f"Error original_tensor.shape {original_tensor.shape} is not supported! "
                    f"Both dimensions must be a positive multiple of {min_size}"
                )

            # This code calculates the size of the compressed tensor.
            # The compression factor differs by dtype; for 2:4 sparsity it is given by:
            # compression_factor = 1/2 + 1/bitwidth(dtype)
            original_size = original_tensor.nelement()
            compression_factor = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[
                original_tensor.dtype
            ].compression_factor
            compressed_size = original_size * compression_factor // 16

            compressed_tensor = torch.empty(
                (compressed_size,),
                dtype=original_tensor.dtype,
                device=original_tensor.device,
            )

            # TODO This is a temporary hack to get the mask in compressed form so we can store the compressed tensor.
            # In the future, we will add in a conversion function from the mask to the meta that we can use instead.
            placeholder = torch.ones(
                (128, n), dtype=original_tensor.dtype, device=original_tensor.device
            )
            specified = original_tensor.masked_select(mask).view(m, n // 2)
            _, meta = torch._structured_sparse_linear(placeholder, specified, mask)
            # set the specified elements
            compressed_tensor[: m * n // 2] = specified.view(-1)
            # set the metadata
            compressed_tensor[m * n // 2 :] = meta.view(original_tensor.dtype).view(-1)

        # set values
        self.original_tensor = None
        self.compressed_tensor = compressed_tensor
        self.transposed = transposed
    def __repr__(self) -> str:
        """Return string representation of SparseSemiStructuredTensor

        Returns:
            str: String representation

        Raises:
            None
        """
        return (
            f"SparseSemiStructuredTensor(shape={self.shape}, "
            f"transposed={self.transposed}, "
            f"values={self.values()}, "
            f"metadata={self.indices()})"
        )
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
        """Overload __torch_dispatch__ to use torch._structured_sparse_linear.

        `torch._structured_sparse_linear` uses accelerated sparse CUTLASS kernels.
        In the future we plan to also add in support for cuSPARSELt kernels.

        Args:
            func: The function being dispatched.
            types: The types of the arguments.
            args: The arguments passed to the function.
            kwargs: The keyword arguments passed to the function.

        Returns:
            Any: The result of the dispatched operation.

        Raises:
            NotImplementedError: If the dispatched operation is not implemented.
        """
        # Since this code runs below autograd, a detach corresponds to only returning a new object
        if func is torch.ops.aten.detach.default:
            return SparseSemiStructuredTensor(
                args[0].original_tensor,
                original_shape=args[0].shape,
                mask=None,
                compressed_tensor=args[0].compressed_tensor,
                transposed=args[0].transposed,
            )

        # Because we cannot go from the compressed representation back to the dense representation currently,
        # we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
        # is the first or second argument, we expect an even / odd number of calls to transpose respectively.
        if func is torch.ops.aten.t.default:
            return SparseSemiStructuredTensor(
                args[0].original_tensor,
                original_shape=args[0].shape,
                mask=None,
                compressed_tensor=args[0].compressed_tensor,
                transposed=not args[0].transposed,
            )

        # handle addmm
        if func is torch.ops.aten.addmm.default:
            bias, input_A, input_B = args

            # Currently, we only support the first matrix being sparse for addmm/mm in cuSPARSELt and CUTLASS.
            # CUTLASS only supports the first input to be sparse for a given matmul.
            # cuSPARSELt does not have this limitation, although our implementation is only for sparse first.
            # We support second-matrix-sparse matmul by taking advantage of some transpose properties.
            # This is also why we want an odd number of transpose calls for second-matrix sparse vs. an even
            # number of transpose calls for first-matrix sparse:
            # F.linear(x) = addmm(bias, input, weight.t()) = b + xW' = (b + xW')''
            #             = (W''x' + b')' = (Wx' + b')' = addmm(bias.T, weight, input).T
            if isinstance(input_B, cls) and input_B.transposed:
                result, _ = torch._structured_sparse_linear(
                    input_A, input_B.values(), input_B.indices(), bias=bias
                )
                return result

        # handle mm
        if func is torch.ops.aten.mm.default:
            input_A, input_B = args

            if isinstance(input_A, cls) and not input_A.transposed:
                transposed_result, _ = torch._structured_sparse_linear(
                    input_B.t(), input_A.values(), input_A.indices()
                )
                return transposed_result.t()

            elif isinstance(input_B, cls) and input_B.transposed:
                result, _ = torch._structured_sparse_linear(
                    input_A, input_B.values(), input_B.indices()
                )
                return result

        # When torch is run with inference mode, pytorch does not decompose torch.ops.aten.linear into a .t() and addmm(),
        # so we must match the aten.linear op.
        # TODO see if there's a way to force pytorch to decompose the op so we don't have to handle this here.
        if func is torch.ops.aten.linear.default:
            input_tensor, weight, bias = args
            if isinstance(weight, cls):
                result, _ = torch._structured_sparse_linear(
                    input_tensor, weight.values(), weight.indices(), bias=bias
                )
                return result

        # handle values
        if func is torch.ops.aten.values.default:
            m, k = args[0].shape
            num_kept_elements = m * k // 2
            return args[0].compressed_tensor[:num_kept_elements].view(m, k // 2)

        # handle indices
        if func is torch.ops.aten.indices.default:
            m, k = args[0].shape
            num_kept_elements = m * k // 2
            metadata = args[0].compressed_tensor[num_kept_elements:].view(m, -1)

            # the metadata is expected to be in different datatypes for fp16/int8 respectively for CUTLASS.
            if args[0].dtype is torch.int8:
                return metadata.view(torch.int32)
            elif args[0].dtype is torch.float16:
                return metadata.view(torch.int16)

        error_string = "\n".join(
            [f"func {func} with args: "]
            + [f"arg{i}: {arg}" for i, arg in enumerate(args)]
        )
        raise NotImplementedError(error_string)
def to_sparse_semi_structured(
    original_tensor: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    transposed: bool = False,
) -> SparseSemiStructuredTensor:
    """
    This function converts a dense tensor into a sparse semi-structured tensor.
    It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor.

    This function will check to ensure the dense tensor has the right dtype, size, dims, and device.
    We currently only support semi-structured sparse tensors for 2d CUDA tensors.
    Additionally, your tensor must be a positive multiple of a block size given the dtype

    - torch.float16: (r, c) must both be a positive multiple of 64
    - torch.int8: (r, c) must both be a positive multiple of 128

    Args:
        original_tensor (Tensor): the dense tensor to convert
        mask (Optional BoolTensor): boolean mask to apply to the original tensor
        transposed (bool, optional): whether the dense tensor is transposed

    Returns:
        SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor and mask

    Raises:
        NotImplementedError: If ``mask=None`` (raised by the SparseSemiStructuredTensor constructor).
        RuntimeError: If original_tensor is not a supported dtype, dim, shape, or device.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
        >>> A
        tensor([[0., 0., 1., ..., 0., 1., 1.],
                [0., 0., 1., ..., 0., 1., 1.],
                [0., 0., 1., ..., 0., 1., 1.],
                ...,
                [0., 0., 1., ..., 0., 1., 1.],
                [0., 0., 1., ..., 0., 1., 1.],
                [0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
        >>> A_sparse = to_sparse_semi_structured(A, mask=A.bool())
        >>> A_sparse
        SparseSemiStructuredTensor(shape=torch.Size([128, 128]), transposed=False, values=tensor([[1., 1., 1., ..., 1., 1., 1.],
                [1., 1., 1., ..., 1., 1., 1.],
                [1., 1., 1., ..., 1., 1., 1.],
                ...,
                [1., 1., 1., ..., 1., 1., 1.],
                [1., 1., 1., ..., 1., 1., 1.],
                [1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16),
        metadata=tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370],
                [-4370, -4370, -4370, ..., -4370, -4370, -4370],
                [-4370, -4370, -4370, ..., -4370, -4370, -4370],
                ...,
                [-4370, -4370, -4370, ..., -4370, -4370, -4370],
                [-4370, -4370, -4370, ..., -4370, -4370, -4370],
                [-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0',
               dtype=torch.int16))
    """
    return SparseSemiStructuredTensor(original_tensor, mask=mask, transposed=transposed)