Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "[core][pruning][sparse][feature] SparseSemiStructured tensor subclass (#102135)"
This reverts commit aea771de30427998e83010459b69da1ab66f0879.
Reverted https://github.com/pytorch/pytorch/pull/102135 on behalf of https://github.com/huydhn due to test_sparse_semi_structured.py::TestSparseSemiStructuredCUDA::test_mm_sparse_first_NT_cuda_int8 is still failing CUDA trunk jobs aea771de30
([comment](https://github.com/pytorch/pytorch/pull/102135#issuecomment-1608744110))
@@ -1,245 +0,0 @@
import random
import torch
import torch.utils.benchmark as benchmark
from torch import nn
from tqdm import tqdm
import pandas as pd
import argparse
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor


torch.set_printoptions(
    precision=2,
    threshold=None,
    edgeitems=16,
    linewidth=480,
    profile=None,
    sci_mode=False,
)


# helper model definition for pruner
class Model(nn.Module):
    def __init__(self, m, k, dtype=None):
        super().__init__()
        # transposed so reversed
        self.linear = nn.Linear(k, m)

    def forward(self, x):
        return self.linear(x)


def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """
    This function returns a 1:2 sparse matrix of size (r, c).
    Note that this means this matrix will also be 2:4 and 4:8 sparse as well.
    """

    choices = [[0, 1], [1, 0]]
    mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]

    return (
        torch.tensor(mask_entries, dtype=dtype, device=device)
        .reshape(r, c)
        .contiguous()
    )


def test_linear(m, k, n, dtype, contiguous, backend):
    SparseSemiStructuredTensor.fuse_transpose = contiguous
    mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
    sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask
    input_tensor = torch.zeros(n, k).to(dtype).cuda()
    model = Model(m, k).to(dtype).cuda().eval()

    dense_measurement = benchmark.Timer(
        stmt="model(input_tensor)",
        globals=locals(),
    ).blocked_autorange()

    dense_output = model(input_tensor)

    # sparsify weights
    model.linear.weight = nn.Parameter(to_sparse_semi_structured(sparse_weight, mask=mask.bool()))

    sparse_output = model(input_tensor)

    sparse_measurement = benchmark.Timer(
        stmt="model(input_tensor)",
        globals=locals(),
    ).blocked_autorange()

    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)

    return {
        "test_function": "linear",
        "m": m,
        "k": k,
        "n": n,
        "dtype": str(dtype),
        "backend": backend,
        "sparse_latency (ms)": sparse_measurement.median * 1000,
        "dense_latency (ms)": dense_measurement.median * 1000,
        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
        "correct": correct,
        "contiguous": sparse_output.is_contiguous(),
    }


def test_tensor(m, k, n, dtype, contiguous, backend):
    A = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
    B = torch.zeros(k, n).to(dtype).cuda()
    bias = torch.rand(n).to(dtype).cuda()

    sA = to_sparse_semi_structured(A, mask=A.bool())

    # torch.mm calculation
    if dtype is not torch.int8:
        dense_output = torch.mm(A, B)

        dense_measurement = benchmark.Timer(
            stmt="torch.mm(A, B)",
            globals=locals(),
        ).blocked_autorange()

    else:
        print("int8 baseline not supported")
        dense_output = torch.mm(sA, B)

        dense_measurement = benchmark.Timer(
            stmt="torch.mm(sA, B)",
            globals=locals(),
        ).blocked_autorange()

    sparse_output = torch.mm(sA, B)
    sparse_measurement = benchmark.Timer(
        stmt="torch.mm(sA, B)",
        globals=locals(),
    ).blocked_autorange()

    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)

    return {
        "test_function": "tensor",
        "m": m,
        "k": k,
        "n": n,
        "dtype": str(dtype),
        "backend": backend,
        "sparse_latency (ms)": sparse_measurement.median * 1000,
        "dense_latency (ms)": dense_measurement.median * 1000,
        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
        "correct": correct,
        "contiguous": sparse_output.is_contiguous(),
    }


if __name__ == "__main__":
    dtype_lookup = {
        "int8": torch.int8,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }

    parser = argparse.ArgumentParser(description="Semi-Structured Sparsity Benchmarks")
    parser.add_argument(
        "--mode",
        type=str,
        choices=[
            "nvidia-bert",
            "nvidia-fixed-k",
            "nvidia-fixed-mn",
        ],
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=dtype_lookup.keys(),
        default="fp16",
    )
    parser.add_argument(
        "--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt"
    )
    parser.add_argument("-contiguous", action="store_true")
    parser.add_argument("-e2e", action="store_true")
    parser.add_argument("-save", action="store_true")
    args = parser.parse_args()

    if args.e2e:
        eval_fn = test_linear
    else:
        eval_fn = test_tensor

    print(f"Started benchmark: {args.mode} | dtype: {args.dtype}")
    dtype = dtype_lookup[args.dtype]

    if args.mode == "nvidia-bert":
        bert_shapes = [
            (3072, 1024, 16384),
            (4096, 1024, 16384),
            (1024, 1024, 16384),
            (1024, 4096, 16384),
        ]
        results = (
            eval_fn(m, k, n, dtype, args.contiguous, args.backend)
            for (m, k, n) in tqdm(bert_shapes)
        )

    elif args.mode == "nvidia-fixed-k":
        mn_vals = [
            3072, 4096, 5120, 6144, 7168, 8192, 9216, 10240, 11264,
            12288, 13312, 14336, 15360, 16384, 17408, 18432, 19456, 20480,
        ]
        results = (
            eval_fn(mn, 10240, mn, dtype, args.contiguous, args.backend)
            for mn in tqdm(mn_vals)
        )

    elif args.mode == "nvidia-fixed-mn":
        k_vals = [
            2560, 3840, 5120, 6400, 7680, 8960, 10240, 11520,
            12800, 14080, 15360, 16640, 17920, 19200, 20480,
        ]
        results = (
            eval_fn(10240, k, 10240, dtype, args.contiguous, args.backend)
            for k in tqdm(k_vals)
        )

    df = pd.DataFrame.from_records(results)
    if args.save:
        save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv"
        df.to_csv(save_file)
        print(f"Finished benchmark: {args.mode} saved results to {save_file}")
    print(df)
@@ -24,7 +24,7 @@ matrices, pruned weights or point clouds by Tensors whose *elements are
mostly zero valued*. We recognize these are important applications and aim
to provide performance optimizations for these use cases via sparse storage formats.

Various sparse storage formats such as COO, CSR/CSC, semi-structured, LIL, etc. have been
Various sparse storage formats such as COO, CSR/CSC, LIL, etc. have been
developed over the years. While they differ in exact layouts, they all
compress data through efficient representation of zero valued elements.
We call the uncompressed values *specified* in contrast to *unspecified*,
@@ -67,8 +67,6 @@ indices of non-zero elements are stored in this case.

PyTorch currently supports :ref:`COO<sparse-coo-docs>`, :ref:`CSR<sparse-csr-docs>`,
:ref:`CSC<sparse-csc-docs>`, :ref:`BSR<sparse-bsr-docs>`, and :ref:`BSC<sparse-bsc-docs>`.

We also have a prototype implementation to support :ref:`semi-structured sparsity<sparse-semi-structured-docs>`.
Please see the references for more details.

Note that we provide slight generalizations of these formats.
@@ -169,147 +167,6 @@ receiving a particular layout. We are working on an API to control the result la
and recognize it is an important feature to plan a more optimal path of execution for
any given model.

.. _sparse-semi-structured-docs:

Sparse Semi-Structured Tensors
++++++++++++++++++++++++++++++

.. warning::

    Sparse semi-structured tensors are currently a prototype feature and subject to change. Please feel free to open an issue to report a bug or to share feedback.

Semi-structured sparsity is a sparse data layout that was first introduced in NVIDIA's Ampere architecture. It is also referred to as **fine-grained structured sparsity** or **2:4 structured sparsity**.

This sparse layout stores `n` elements out of every `2n` elements, with `n` determined by the width of the Tensor's data type (dtype). The most frequently used dtype is float16, where `n = 2`, hence the term "2:4 structured sparsity".

Semi-structured sparsity is explained in greater detail in `this NVIDIA blog post <https://developer.nvidia.com/blog/exploiting-ampere-structured-sparsity-with-cusparselt>`_.

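As an illustration (added here, not part of the original documentation), the 2:4 property can be checked directly: every contiguous group of four elements contains at most two nonzero values. The ``is_2_4_sparse`` helper below is a hypothetical sketch, not a PyTorch API.

>>> # hypothetical helper: True if every group of 4 elements keeps at most 2 nonzeros
>>> def is_2_4_sparse(t):
...     groups = t.reshape(-1, 4)
...     return bool(((groups != 0).sum(dim=-1) <= 2).all())
>>> is_2_4_sparse(torch.Tensor([0, 0, 1, 1]).tile((128, 32)))
True
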
In PyTorch, semi-structured sparsity is implemented via a Tensor subclass.
By subclassing, we can override ``__torch_dispatch__``, allowing us to use faster sparse kernels when performing matrix multiplication.
We can also store the tensor in its compressed form inside the subclass to reduce memory overhead.

In this compressed form, the sparse tensor is stored by retaining only the *specified* elements and some metadata, which encodes the mask.

.. note::
    The specified elements and metadata mask of a semi-structured sparse tensor are stored together in a single
    flat compressed tensor. They are appended to each other to form a contiguous chunk of memory.

    compressed tensor = [ specified elements of original tensor | metadata_mask ]

    For an original tensor of size `(r, c)` we expect the first `r * c // 2` elements to be the kept elements,
    and the rest of the tensor is metadata.

In order to make it easier for the user to view the specified elements
and mask, one can use ``.indices()`` and ``.values()`` to access the mask and specified elements respectively.

- ``.values()`` returns the specified elements in a tensor of size `(r, c//2)` and with the same dtype as the dense matrix.

- ``.indices()`` returns the metadata_mask in a tensor with element type ``torch.int16`` if dtype is torch.float16 and element type ``torch.int32`` if dtype is torch.int8.

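As a quick illustration (added here, not part of the original documentation), these accessors can be inspected on the 128x128 float16 example used later in this section; the printed shape matches the accompanying unit tests:

>>> from torch.sparse import to_sparse_semi_structured
>>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
>>> A_sparse = to_sparse_semi_structured(A, mask=A.bool())
>>> A_sparse.values().shape
torch.Size([128, 64])
>>> A_sparse.indices().dtype
torch.int16
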
For 2:4 sparse tensors, the metadata overhead is minor: just 2 bits per specified element.

.. note::
    It's important to note that ``torch.float32`` is only supported for 1:2 sparsity. Therefore, it does not follow the same formula as above.

Here, we break down how to calculate the compression ratio (size sparse / size dense) of a 2:4 sparse tensor.

Let `(r, c) = tensor.shape` and `e = bitwidth(tensor.dtype)`, so `e = 16` for ``torch.float16`` and ``torch.bfloat16`` and `e = 8` for ``torch.int8``.

.. math::
    M_{dense} = r \times c \times e \\
    M_{sparse} = M_{specified} + M_{metadata} = r \times \frac{c}{2} \times e + r \times \frac{c}{2} \times 2 = \frac{rce}{2} + rc = rce \left( \frac{1}{2} + \frac{1}{e} \right)

Using these calculations, we can determine the total memory footprint for both the original dense and the new sparse representation.

This gives us a simple formula for the compression ratio, which depends only on the bitwidth of the tensor datatype.

.. math::
    C = \frac{M_{sparse}}{M_{dense}} = \frac{1}{2} + \frac{1}{e}

Using this formula, we find that the compression ratio is 56.25% for ``torch.float16`` and 62.5% for ``torch.int8``.

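As a quick numerical sanity check (a small sketch added here, not part of the original text), the ratio can be evaluated directly from the formula above:

>>> for e in (16, 8):  # bitwidth of torch.float16 / torch.int8
...     print(f"e = {e}: C = {1 / 2 + 1 / e:.4f}")
e = 16: C = 0.5625
e = 8: C = 0.6250
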
Constructing Sparse Semi-Structured Tensors
-------------------------------------------

You can transform a dense tensor into a sparse semi-structured tensor by using the ``torch.sparse.to_sparse_semi_structured`` function.

Please also note that only CUDA tensors are supported, since hardware compatibility for semi-structured sparsity is limited to NVIDIA GPUs.

The following datatypes are supported for semi-structured sparsity. Note that each datatype has its own shape constraints and compression factor.

.. csv-table::
    :header: "PyTorch dtype", "Shape Constraints", "Compression Factor", "Sparsity Pattern"
    :widths: 15, 45, 10, 10
    :delim: ;

    ``torch.float16``; Tensor must be 2D and (r, c) must both be a positive multiple of 64;9/16;2:4
    ``torch.int8``; Tensor must be 2D and (r, c) must both be a positive multiple of 128;10/16;2:4

To construct a semi-structured sparse tensor, start by creating a regular dense tensor that adheres to a 2:4 (or semi-structured) sparse format.
To do this we tile a small 1x4 strip to create a 128x128 dense float16 tensor.
Afterwards, we can call ``to_sparse_semi_structured`` on this matrix to compress it for accelerated inference.

>>> from torch.sparse import to_sparse_semi_structured
>>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
tensor([[0., 0., 1., ..., 0., 1., 1.],
        [0., 0., 1., ..., 0., 1., 1.],
        [0., 0., 1., ..., 0., 1., 1.],
        ...,
        [0., 0., 1., ..., 0., 1., 1.],
        [0., 0., 1., ..., 0., 1., 1.],
        [0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
>>> A_sparse = to_sparse_semi_structured(A, mask=A.bool())
SparseSemiStructuredTensor(shape=torch.Size([128, 128]), transposed=False, values=tensor([[1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        ...,
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16), metadata=tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370],
        [-4370, -4370, -4370, ..., -4370, -4370, -4370],
        [-4370, -4370, -4370, ..., -4370, -4370, -4370],
        ...,
        [-4370, -4370, -4370, ..., -4370, -4370, -4370],
        [-4370, -4370, -4370, ..., -4370, -4370, -4370],
        [-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0',
        dtype=torch.int16))

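A hedged illustration of the shape constraints above (added here, not part of the original documentation): converting a tensor below the minimum block size is expected to fail with a ``RuntimeError`` that names ``original_tensor.shape``, mirroring the accompanying unit tests.

>>> bad = torch.rand(4, 4).half().cuda()  # 4x4 is far below the 64x64 minimum for torch.float16
>>> to_sparse_semi_structured(bad, mask=bad.bool())
RuntimeError: Error original_tensor.shape torch.Size([4, 4]) is not supported! ...
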
Sparse Semi-Structured Tensor Operations
----------------------------------------

Currently, the following operations are supported for semi-structured sparse tensors:

- torch.addmm(bias, dense, sparse.t())
- torch.mm(dense, sparse)
- torch.mm(sparse, dense)
- aten.linear.default(dense, sparse, bias)
- aten.t.default(sparse)
- aten.t.detach(sparse)

Once your tensor has zeros laid out in a semi-structured sparse pattern, you can use these ops by passing the output of ``to_sparse_semi_structured(tensor)`` in place of ``tensor``, like this:

>>> a = torch.Tensor([0, 0, 1, 1]).tile((64, 16)).half().cuda()
>>> b = torch.rand(64, 64).half().cuda()
>>> c = torch.mm(a, b)
>>> a_sparse = to_sparse_semi_structured(a, mask=a.bool())
>>> torch.allclose(c, torch.mm(a_sparse, b))
True

Under the hood, SparseSemiStructuredTensor will call ``torch._structured_sparse_linear`` for accelerated inference using CUTLASS sparse kernels.

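The snippet below is an added sketch (not part of the original documentation) of the ``torch.addmm(bias, dense, sparse.t())`` form listed above, reusing the tensors from the previous example:

>>> bias = torch.rand(64).half().cuda()
>>> dense_out = torch.addmm(bias, b, a.t())
>>> sparse_out = torch.addmm(bias, b, a_sparse.t())
>>> torch.allclose(dense_out, sparse_out, rtol=1e-3, atol=1e-3)
True
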
Accelerating nn.Linear with semi-structured sparsity
----------------------------------------------------
With just a few lines of code, you can accelerate the linear layers in your model if their weights are already semi-structured sparse:

>>> input = torch.rand(64, 64).half().cuda()
>>> mask = torch.Tensor([0, 0, 1, 1]).tile((64, 16)).cuda().bool()
>>> linear = nn.Linear(64, 64).half().cuda()
>>> linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight, mask=mask))

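A possible follow-up step (added here as an illustrative sketch, not part of the original example) is to run the sparsified layer and confirm that the forward pass still produces an output of the expected shape:

>>> output = linear(input)
>>> output.shape
torch.Size([64, 64])
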
.. _sparse-coo-docs:

@@ -1135,18 +992,12 @@ multiplication, and ``@`` is matrix multiplication.
    :func:`torch.mv`; no; ``M[sparse_csr] @ V[strided] -> V[strided]``
    :func:`torch.matmul`; no; ``M[sparse_coo] @ M[strided] -> M[strided]``
    :func:`torch.matmul`; no; ``M[sparse_csr] @ M[strided] -> M[strided]``
    :func:`torch.matmul`; no; ``M[SparseSemiStructured] @ M[strided] -> M[strided]``
    :func:`torch.matmul`; no; ``M[strided] @ M[SparseSemiStructured] -> M[strided]``
    :func:`torch.mm`; no; ``M[sparse_coo] @ M[strided] -> M[strided]``
    :func:`torch.mm`; no; ``M[SparseSemiStructured] @ M[strided] -> M[strided]``
    :func:`torch.mm`; no; ``M[strided] @ M[SparseSemiStructured] -> M[strided]``
    :func:`torch.sparse.mm`; yes; ``M[sparse_coo] @ M[strided] -> M[strided]``
    :func:`torch.smm`; no; ``M[sparse_coo] @ M[strided] -> M[sparse_coo]``
    :func:`torch.hspmm`; no; ``M[sparse_coo] @ M[strided] -> M[hybrid sparse_coo]``
    :func:`torch.bmm`; no; ``T[sparse_coo] @ T[strided] -> T[strided]``
    :func:`torch.addmm`; no; ``f * M[strided] + f * (M[sparse_coo] @ M[strided]) -> M[strided]``
    :func:`torch.addmm`; no; ``f * M[strided] + f * (M[SparseSemiStructured] @ M[strided]) -> M[strided]``
    :func:`torch.addmm`; no; ``f * M[strided] + f * (M[strided] @ M[SparseSemiStructured]) -> M[strided]``
    :func:`torch.sparse.addmm`; yes; ``f * M[strided] + f * (M[sparse_coo] @ M[strided]) -> M[strided]``
    :func:`torch.sspaddmm`; no; ``f * M[sparse_coo] + f * (M[sparse_coo] @ M[strided]) -> M[sparse_coo]``
    :func:`torch.lobpcg`; no; ``GENEIG(M[sparse_coo]) -> M[strided], M[strided]``
@@ -1,227 +0,0 @@
# Owner(s): ["module: sparse"]
import random
import unittest

import torch
from torch import nn

from torch.sparse.semi_structured import (
    _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG,
    SparseSemiStructuredTensor,
    to_sparse_semi_structured,
)

from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)

from torch.testing._internal.common_dtype import all_types_and_complex

from torch.testing._internal.common_utils import (
    parametrize,
    run_tests,
    subtest,
    TestCase,
)

SEMI_STRUCTURED_SUPPORTED_DTYPES = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG.keys()

_IS_SM8X = False
if torch.cuda.is_available():
    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8


def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """
    This function returns a 1:2 sparse matrix of size (r, c).
    Note that this means this matrix will also be 2:4 and 4:8 sparse as well.
    """

    choices = [[0, 1], [1, 0]]
    mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]

    return (
        torch.tensor(mask_entries, dtype=dtype, device=device)
        .reshape(r, c)
        .contiguous()
    )


class TestSparseSemiStructured(TestCase):

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_to_sparse_semi_structured(self, dtype):
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A, mask=A.bool())

        assert A.shape == A_sparse.shape
        assert A.device == A_sparse.device
        assert A.dtype == A_sparse.dtype

        assert isinstance(A, torch.Tensor)
        assert isinstance(A_sparse, SparseSemiStructuredTensor)

        with self.assertRaisesRegex(
            NotImplementedError,
            "You must pass in a mask to to_sparse_semi_structured, currently mask=None.",
        ):
            A_sparse = to_sparse_semi_structured(A)

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_mm_sparse_first_NT(self, dtype, device):
        """
        Ensure torch.mm(A_sparse, B) is correct for float16 and will throw an error for int8.
        Ensure torch.mm(A_sparse, B.t()) is correct.
        """
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A, mask=A.bool())

        B = torch.rand((128, 128), device=A_sparse.device).to(dtype)

        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
        if dtype is torch.int8:
            # This should fail
            with self.assertRaisesRegex(RuntimeError, "_structured_sparse_linear"):
                sparse_result = torch.mm(A_sparse, B)

            # test transpose
            # NOTE: CUTLASS and cuSPARSELt have slightly different int8 behavior.
            # CUTLASS will output to an int32 tensor while cuSPARSELt will output to an int8 tensor
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32)
            sparse_result = torch.mm(A_sparse, B.t())
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
        else:
            dense_result = torch.mm(A, B)
            sparse_result = torch.mm(A_sparse, B)
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
            # test transpose
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A_sparse, B.t())
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_mm_sparse_first_T(self, dtype, device):
        """
        Ensure torch.mm(A_sparse.t(), B) throws an error.
        """
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A, mask=A.bool())

        B = torch.rand((128, 128), device=A_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"arg0: SparseSemiStructuredTensor\(.*transposed=True",
        ):
            torch.mm(A_sparse.t(), B)

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_mm_sparse_second_T(self, dtype, device):
        """
        Ensure torch.mm(A, B_sparse.t()) is correct.
        """
        B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B, mask=B.bool())

        A = torch.rand((128, 128), device=B_sparse.device).to(dtype)

        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
        if dtype is torch.int8:
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32)
            sparse_result = torch.mm(A, B_sparse.t())
        else:
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A, B_sparse.t())

        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_mm_sparse_second_NT(self, dtype, device):
        """
        Ensure torch.mm(A, B_sparse) throws an error.
        """
        B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B, mask=B.bool())

        A = torch.rand((128, 128), device=B_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"arg1: SparseSemiStructuredTensor\(.*transposed=False",
        ):
            sparse_result = torch.mm(A, B_sparse)

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    @parametrize("inference_mode", [subtest(False), subtest(True)])
    def test_linear(self, inference_mode, device):
        """
        Test that nn.Linear has the same numerics.
        """
        input = torch.rand(128, 128, device=device).half()
        model = nn.Linear(128, 128).to(device).half()
        m, n = model.weight.shape
        mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool)
        # set masked weight
        model.weight = nn.Parameter(model.weight * mask)

        dense_result = model(input)
        model.weight = nn.Parameter(to_sparse_semi_structured(model.weight, mask=mask))

        if inference_mode:
            with torch.inference_mode():
                sparse_result = model(input)
        else:
            sparse_result = model(input)

        assert torch.allclose(dense_result, sparse_result, rtol=1e-5, atol=1e-5)

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    def test_values(self):
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A, mask=A.bool())
        assert A_sparse.values().shape == (128, 64)
        assert (A_sparse.values() == 1).all()

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    def test_indices(self):
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A, mask=A.bool())
        assert A_sparse.indices().shape == (128, 8)

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_unsupported_shape(self, dtype, device):
        A = rand_sparse_semi_structured_mask(4, 4, dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.shape"):
            A_sparse = to_sparse_semi_structured(A, mask=A.bool())

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    @dtypes(*all_types_and_complex())
    def test_unsupported_dtype(self, dtype, device):
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype, device=device)

        if dtype not in SEMI_STRUCTURED_SUPPORTED_DTYPES:
            with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dtype"):
                A_sparse = to_sparse_semi_structured(A, mask=A.bool())
        else:
            A_sparse = to_sparse_semi_structured(A, mask=A.bool())

    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
    def test_unsupported_dim(self, device):
        A = torch.rand(128, 128, 128, device=device, dtype=torch.float16)

        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dim"):
            A_sparse = to_sparse_semi_structured(A, mask=A.bool())


instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda")

if __name__ == "__main__":
    run_tests()
@@ -5,9 +5,6 @@ import torch
from torch._C import _add_docstr, _sparse  # type: ignore[attr-defined]
from torch import Tensor

# Semi structured sparsity support
from .semi_structured import SparseSemiStructuredTensor, to_sparse_semi_structured

# A workaround to support both TorchScript and MyPy:
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -26,10 +23,9 @@ __all__ = [
    'sum',
    'softmax',
    'log_softmax',
    'SparseSemiStructuredTensor',
    'to_sparse_semi_structured',
]


addmm = _add_docstr(_sparse._sparse_addmm, r"""
sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor
@@ -1,389 +0,0 @@
import warnings
from collections import namedtuple
from typing import Any, Optional

import torch


__all__ = [
    "to_sparse_semi_structured",
    "SparseSemiStructuredTensor",
]

_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
    "_SEMI_STRUCTURED_SPARSE_CONFIG", "compression_factor min_size"
)
_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG = {
    torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(9, 64),
    torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(10, 128),
}

_WARNING_SHOWN = False


class SparseSemiStructuredTensor(torch.Tensor):
    """This class implements semi-structured sparsity as a Tensor subclass.

    Semi-structured sparsity describes a sparsity pattern where n in every 2n elements are sparse,
    depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
    structured sparsity.

    Currently, this class supports 2:4 sparsity for int8 and float16 dtypes.

    This subclass stores the dense tensor in a compressed form by only storing the specified elements and a metadata mask.
    These two are stored next to each other in one contiguous tensor.

    We choose to store the specified elements and the metadata in a single tensor for future compatibility with cuSPARSELt.

    compressed tensor = [ specified elements of original tensor | mask_metadata ]

    For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements.
    The rest of the tensor is metadata.

    This subclass also overrides __torch_dispatch__ to use _structured_sparse_linear for faster matrix multiplications
    via sparse CUTLASS kernels. In the future we will also call into cuSPARSELt kernels for more performance gains.
    """

    @staticmethod
    def __new__(
        cls,
        original_tensor: Optional[torch.Tensor],
        original_shape: Optional[torch.Size] = None,
        mask: Optional[torch.Tensor] = None,
        compressed_tensor: Optional[torch.Tensor] = None,
        transposed: bool = False,
    ):
        """
        Create a new instance of the class.

        When original_tensor is passed in, we compress it and store the compressed representation.
        We can also create a new instance of the class from the compressed representation without the original tensor.

        Args:
            original_tensor: The original dense tensor, or None, if we have already compressed the tensor.
            original_shape: The shape of the original dense tensor.
            mask: Mask to be applied to the original tensor.
            compressed_tensor: A flattened tensor to store the specified elements and mask metadata.
            transposed: Whether the tensor is transposed or not.

        Returns:
            torch.Tensor: A torch.Tensor wrapper subclass.

        Raises:
            ValueError: If both original_tensor and compressed_tensor are None.

        """
        if original_tensor is not None:
            previous_tensor = original_tensor
            original_shape = original_tensor.shape
        elif compressed_tensor is not None:
            previous_tensor = compressed_tensor
        else:
            raise ValueError("Both compressed_tensor and original_tensor are None!")

        kwargs = {}
        kwargs["device"] = previous_tensor.device  # type: ignore[assignment]
        kwargs["dtype"] = previous_tensor.dtype  # type: ignore[assignment]
        kwargs["layout"] = previous_tensor.layout  # type: ignore[assignment]
        kwargs["requires_grad"] = False  # type: ignore[assignment]

        return torch.Tensor._make_wrapper_subclass(cls, original_shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        original_tensor: Optional[torch.Tensor],
        original_shape: Optional[torch.Size] = None,
        mask: Optional[torch.Tensor] = None,
        compressed_tensor: Optional[torch.Tensor] = None,
        transposed: bool = False,
    ) -> None:
        """SparseSemiStructuredTensor constructor.

        Args:
            original_tensor: The original dense tensor, or None, if we have already compressed the tensor.
            original_shape: The shape of the original dense tensor.
            mask: Mask to be applied to the original tensor.
            compressed_tensor: A flattened tensor to store the specified elements and mask metadata.
            transposed: Whether the tensor is transposed or not.

        Returns:
            None

        Raises:
            NotImplementedError: If ``mask=None``, as we currently do not support inferring a mask from the dense tensor.
            RuntimeError: If original_tensor is not a supported dtype, dim, shape, or device.
        """
        global _WARNING_SHOWN
        if not _WARNING_SHOWN:
            warnings.warn(
                (
                    "The PyTorch API of SparseSemiStructuredTensor is in prototype stage "
                    "and will change in the near future. Please open a Github issue "
                    "for feature requests and see our documentation on the torch.sparse "
                    "module for further information about the project."
                ),
                UserWarning,
            )
            _WARNING_SHOWN = True

        # if original tensor is passed in, we need to compress it and store the compressed representation.
        if original_tensor is not None:
            # check if mask passed in
            if mask is None:
                raise NotImplementedError("You must pass in a mask to to_sparse_semi_structured, currently mask=None.")

            # check device
            if not original_tensor.is_cuda:
                raise RuntimeError(
                    (
                        f"Error original_tensor.device= {original_tensor.device} is not supported! "
                        "Only CUDA tensors are currently supported."
                    )
                )

            # check dim
            if original_tensor.dim() != 2:
                raise RuntimeError(
                    (
                        f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
                        "Only 2d tensors are currently supported."
                    )
                )

            # check dtype
            if original_tensor.dtype not in _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG:
                raise RuntimeError(
                    (
                        f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
                        f"dtype must be one of: {_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG}"
                    )
                )

            # check shape
            m, n = original_tensor.shape
            min_size = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype].min_size
            if m < min_size or m % min_size or n < min_size or n % min_size:
                # TODO in the future we can add in padding to support dimensions that aren't perfect multiples
                raise RuntimeError(
                    (
                        f"Error original_tensor.shape {original_tensor.shape} is not supported! "
                        f"Both dimensions must be at least {min_size} and a multiple of {min_size}"
                    )
                )

            # This code calculates the size of the compressed tensor.
            # The compression factor depends on the dtype; for 2:4 sparsity it is given by
            # compression_factor = 1/2 + 1/bitwidth(dtype)
            original_size = original_tensor.nelement()
            compression_factor = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[
                original_tensor.dtype
            ].compression_factor
            compressed_size = original_size * compression_factor // 16

            compressed_tensor = torch.empty(
                (compressed_size,),
                dtype=original_tensor.dtype,
                device=original_tensor.device,
            )

            # TODO This is a temporary hack to get the mask in compressed form so we can store the compressed tensor.
            # In the future, we will add in a conversion function from the mask to the meta that we can use instead.
            placeholder = torch.ones(
                (128, n), dtype=original_tensor.dtype, device=original_tensor.device
            )
            specified = original_tensor.masked_select(mask).view(m, n // 2)
            _, meta = torch._structured_sparse_linear(placeholder, specified, mask)
            # set the specified elements
            compressed_tensor[: m * n // 2] = specified.view(-1)
            # set the metadata
            compressed_tensor[m * n // 2 :] = meta.view(original_tensor.dtype).view(-1)

        # set values
        self.original_tensor = None
        self.compressed_tensor = compressed_tensor
        self.transposed = transposed

    def __repr__(self) -> str:
        """Return string representation of SparseSemiStructuredTensor

        Returns:
            str: String representation

        Raises:
            None
        """
        return (
            f"SparseSemiStructuredTensor(shape={self.shape}, "
            f"transposed={self.transposed}, "
            f"values={self.values()}, "
            f"metadata={self.indices()})"
        )

    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
        """Overload __torch_dispatch__ to use torch._structured_sparse_linear.

        `torch._structured_sparse_linear` uses accelerated sparse CUTLASS kernels.
        In the future we plan to also add in support for cuSPARSELt kernels.

        Args:
            func: The function being dispatched.
            types: The types of the arguments.
            args: The arguments passed to the function.
            kwargs: The keyword arguments passed to the function.

        Returns:
            Any: The result of the dispatched operation.

        Raises:
            NotImplementedError: If the dispatched operation is not implemented.
        """
        # Since this code runs below autograd, a detach corresponds to only returning a new object
        if func is torch.ops.aten.detach.default:
            return SparseSemiStructuredTensor(
                args[0].original_tensor,
                original_shape=args[0].shape,
                mask=None,
                compressed_tensor=args[0].compressed_tensor,
                transposed=args[0].transposed,
            )

        # Because we cannot go from the compressed representation back to the dense representation currently,
        # we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
        # is the first or second argument, we expect an even / odd number of calls to transpose respectively.
        if func is torch.ops.aten.t.default:
            return SparseSemiStructuredTensor(
                args[0].original_tensor,
                original_shape=args[0].shape,
                mask=None,
                compressed_tensor=args[0].compressed_tensor,
                transposed=not args[0].transposed,
            )

        # handle addmm
        if func is torch.ops.aten.addmm.default:
            bias, input_A, input_B = args

            # Currently, we only support the first matrix being sparse for addmm/mm in cuSPARSELt and CUTLASS.
            # CUTLASS only supports the first input to be sparse for a given matmul.
            # cuSPARSELt does not have this limitation, although our implementation is only for sparse first.

            # We support second matrix sparse matmul by taking advantage of some transpose properties:
            # This is also why we want an odd number of transpose calls for a sparse second matrix vs an even
            # number of transpose calls for a sparse first matrix.
            # F.linear(x) = addmm(bias, input, weight.t()) = b + xW' = (b + xW')''
            #             = (W''x' + b')' = (Wx' + b')' = addmm(bias.T, weight, input).T
            if isinstance(input_B, cls) and input_B.transposed:
                result, _ = torch._structured_sparse_linear(
                    input_A, input_B.values(), input_B.indices(), bias=bias
                )
                return result

        # handle mm
        if func is torch.ops.aten.mm.default:
            input_A, input_B = args

            if isinstance(input_A, cls) and not input_A.transposed:
                transposed_result, _ = torch._structured_sparse_linear(
                    input_B.t(), input_A.values(), input_A.indices()
                )
                return transposed_result.t()

            elif isinstance(input_B, cls) and input_B.transposed:
                result, _ = torch._structured_sparse_linear(
                    input_A, input_B.values(), input_B.indices()
                )
                return result

        # When torch is run with inference mode, pytorch does not decompose torch.ops.aten.linear into a .t() and addmm(),
        # so we must match the aten.linear op.
        # TODO see if there's a way to force pytorch to decompose the op so we don't have to handle this here.
        if func is torch.ops.aten.linear.default:
            input_tensor, weight, bias = args
            if isinstance(weight, cls):
                result, _ = torch._structured_sparse_linear(
                    input_tensor, weight.values(), weight.indices(), bias=bias
                )
                return result

        # handle values
        if func is torch.ops.aten.values.default:
            m, k = args[0].shape
            num_kept_elements = m * k // 2
            return args[0].compressed_tensor[:num_kept_elements].view(m, k // 2)

        # handle indices
        if func is torch.ops.aten.indices.default:
            m, k = args[0].shape
            num_kept_elements = m * k // 2
            metadata = args[0].compressed_tensor[num_kept_elements:].view(m, -1)

            # the metadata is expected to be in different datatypes for fp16/int8 respectively for CUTLASS.
            if args[0].dtype is torch.int8:
                return metadata.view(torch.int32)
            elif args[0].dtype is torch.float16:
                return metadata.view(torch.int16)

        error_string = "\n".join(
            [f"func {func} with args: "]
            + [f"arg{i}: {arg}" for i, arg in enumerate(args)]
        )
        raise NotImplementedError(error_string)


def to_sparse_semi_structured(
    original_tensor: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    transposed: bool = False,
) -> SparseSemiStructuredTensor:
    """
    This function converts a dense tensor into a sparse semi-structured tensor.
    It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor.

    This function will check to ensure the dense tensor has the right dtype, size, dims, and device.
    We currently only support semi-structured sparse tensors for 2d CUDA tensors.
    Additionally, your tensor must be a positive multiple of a block size given the dtype:

    - torch.float16  (r, c) must both be >= 64 and a multiple of 64
    - torch.int8     (r, c) must both be >= 128 and a multiple of 128

    Args:
        original_tensor (Tensor): the dense tensor to convert
        mask (Optional BoolTensor): boolean mask to apply to the original tensor
        transposed (bool, optional): whether the dense tensor is transposed

    Returns:
        SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor and mask

    Raises:
        None

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
        tensor([[0., 0., 1., ..., 0., 1., 1.],
                [0., 0., 1., ..., 0., 1., 1.],
                [0., 0., 1., ..., 0., 1., 1.],
                ...,
                [0., 0., 1., ..., 0., 1., 1.],
                [0., 0., 1., ..., 0., 1., 1.],
                [0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
        >>> A_sparse = to_sparse_semi_structured(A, mask=A.bool())
        SparseSemiStructuredTensor(shape=torch.Size([128, 128]), transposed=False, values=tensor([[1., 1., 1., ..., 1., 1., 1.],
                [1., 1., 1., ..., 1., 1., 1.],
                [1., 1., 1., ..., 1., 1., 1.],
                ...,
                [1., 1., 1., ..., 1., 1., 1.],
                [1., 1., 1., ..., 1., 1., 1.],
                [1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16),
        metadata=tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370],
                [-4370, -4370, -4370, ..., -4370, -4370, -4370],
                [-4370, -4370, -4370, ..., -4370, -4370, -4370],
                ...,
                [-4370, -4370, -4370, ..., -4370, -4370, -4370],
                [-4370, -4370, -4370, ..., -4370, -4370, -4370],
                [-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0',
        dtype=torch.int16))
    """
    return SparseSemiStructuredTensor(original_tensor, mask=mask, transposed=transposed)