diff --git a/benchmarks/sparse/benchmark_semi_structured_sparsity.py b/benchmarks/sparse/benchmark_semi_structured_sparsity.py new file mode 100644 index 000000000000..c6753a95e678 --- /dev/null +++ b/benchmarks/sparse/benchmark_semi_structured_sparsity.py @@ -0,0 +1,245 @@ +import random +import torch +import torch.utils.benchmark as benchmark +from torch import nn +from tqdm import tqdm +import pandas as pd +import argparse +from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor + + +torch.set_printoptions( + precision=2, + threshold=None, + edgeitems=16, + linewidth=480, + profile=None, + sci_mode=False, +) + + +# helper model definition for pruner +class Model(nn.Module): + def __init__(self, m, k, dtype=None): + super().__init__() + # transposed so reversed + self.linear = nn.Linear(k, m) + + def forward(self, x): + return self.linear(x) + + +def rand_sparse_semi_structured_mask( + r, c, dtype=torch.float16, device="cuda", choice=None +): + """ + This function returns a 1:2 sparse matrix of size (r, c). + Note that this means this matrix will also be 2:4 and 4:8 sparse as well. + """ + + choices = [[0, 1], [1, 0]] + mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)] + + return ( + torch.tensor(mask_entries, dtype=dtype, device=device) + .reshape(r, c) + .contiguous() + ) + + +def test_linear(m, k, n, dtype, contiguous, backend): + SparseSemiStructuredTensor.fuse_transpose = contiguous + mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype) + sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask + input_tensor = torch.zeros(n, k).to(dtype).cuda() + model = Model(m, k).to(dtype).cuda().eval() + + dense_measurement = benchmark.Timer( + stmt="model(input_tensor)", + globals=locals(), + ).blocked_autorange() + + dense_output = model(input_tensor) + + # sparsify weights + model.linear.weight = nn.Parameter(to_sparse_semi_structured(sparse_weight, mask=mask.bool())) + + sparse_output = model(input_tensor) + + sparse_measurement = benchmark.Timer( + stmt="model(input_tensor)", + globals=locals(), + ).blocked_autorange() + + correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3) + + return { + "test_function": "linear", + "m": m, + "k": k, + "n": n, + "dtype": str(dtype), + "backend": backend, + "sparse_latency (ms)": sparse_measurement.median * 1000, + "dense_latency (ms)": dense_measurement.median * 1000, + "speedup (d/s)": dense_measurement.median / sparse_measurement.median, + "correct": correct, + "contiguous": sparse_output.is_contiguous(), + } + + +def test_tensor(m, k, n, dtype, contiguous, backend): + A = rand_sparse_semi_structured_mask(m, k, dtype=dtype) + B = torch.zeros(k, n).to(dtype).cuda() + bias = torch.rand(n).to(dtype).cuda() + + sA = to_sparse_semi_structured(A, mask=A.bool()) + + # torch.mm calculation + if dtype is not torch.int8: + dense_output = torch.mm(A, B) + + dense_measurement = benchmark.Timer( + stmt="torch.mm(A, B)", + globals=locals(), + ).blocked_autorange() + + else: + print("int8 baseline not supported") + dense_output = torch.mm(sA, B) + + dense_measurement = benchmark.Timer( + stmt="torch.mm(sA, B)", + globals=locals(), + ).blocked_autorange() + + sparse_output = torch.mm(sA, B) + sparse_measurement = benchmark.Timer( + stmt="torch.mm(sA, B)", + globals=locals(), + ).blocked_autorange() + + correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3) + + return { + "test_function": "tensor", + "m": m, + "k": k, + "n": n, + "dtype": str(dtype), + "backend": 
backend, + "sparse_latency (ms)": sparse_measurement.median * 1000, + "dense_latency (ms)": dense_measurement.median * 1000, + "speedup (d/s)": dense_measurement.median / sparse_measurement.median, + "correct": correct, + "contiguous": sparse_output.is_contiguous(), + } + + +if __name__ == "__main__": + dtype_lookup = { + "int8": torch.int8, + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + } + + parser = argparse.ArgumentParser(description="Semi-Structured Sparsity Benchmarks") + parser.add_argument( + "--mode", + type=str, + choices=[ + "nvidia-bert", + "nvidia-fixed-k", + "nvidia-fixed-mn", + ], + ) + parser.add_argument( + "--dtype", + type=str, + choices=dtype_lookup.keys(), + default="fp16", + ) + parser.add_argument( + "--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt" + ) + parser.add_argument("-contiguous", action="store_true") + parser.add_argument("-e2e", action="store_true") + parser.add_argument("-save", action="store_true") + args = parser.parse_args() + + if args.e2e: + eval_fn = test_linear + else: + eval_fn = test_tensor + + print(f"Started benchmark: {args.mode} | dtype: {args.dtype}") + dtype = dtype_lookup[args.dtype] + + if args.mode == "nvidia-bert": + bert_shapes = [ + (3072, 1024, 16384), + (4096, 1024, 16384), + (1024, 1024, 16384), + (1024, 4096, 16384), + ] + results = ( + eval_fn(m, k, n, dtype, args.contiguous, args.backend) + for (m, k, n) in tqdm(bert_shapes) + ) + + elif args.mode == "nvidia-fixed-k": + mn_vals = [ + 3072, + 4096, + 5120, + 6144, + 7168, + 8192, + 9216, + 10240, + 11264, + 12288, + 13312, + 14336, + 15360, + 16384, + 17408, + 18432, + 19456, + 20480, + ] + results = ( + eval_fn(mn, 10240, mn, dtype, args.contiguous, args.backend) + for mn in tqdm(mn_vals) + ) + + elif args.mode == "nvidia-fixed-mn": + k_vals = [ + 2560, + 3840, + 5120, + 6400, + 7680, + 8960, + 10240, + 11520, + 12800, + 14080, + 15360, + 16640, + 17920, + 19200, + 20480, + ] + results = ( + eval_fn(10240, k, 10240, dtype, args.contiguous, args.backend) + for k in tqdm(k_vals) + ) + + df = pd.DataFrame.from_records(results) + if args.save: + save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv" + df.to_csv(save_file) + print(f"Finished benchmark: {args.mode} saved results to {save_file}") + print(df) diff --git a/docs/source/sparse.rst b/docs/source/sparse.rst index c273f74b8c0b..364d457b70fe 100644 --- a/docs/source/sparse.rst +++ b/docs/source/sparse.rst @@ -24,7 +24,7 @@ matrices, pruned weights or points clouds by Tensors whose *elements are mostly zero valued*. We recognize these are important applications and aim to provide performance optimizations for these use cases via sparse storage formats. -Various sparse storage formats such as COO, CSR/CSC, LIL, etc. have been +Various sparse storage formats such as COO, CSR/CSC, semi-structured, LIL, etc. have been developed over the years. While they differ in exact layouts, they all compress data through efficient representation of zero valued elements. We call the uncompressed values *specified* in contrast to *unspecified*, @@ -67,6 +67,8 @@ indices of non-zero elements are stored in this case. PyTorch currently supports :ref:`COO`, :ref:`CSR`, :ref:`CSC`, :ref:`BSR`, and :ref:`BSC`. + +We also have a prototype implementation to support :ref: `semi-structured sparsity`. Please see the references for more details. Note that we provide slight generalizations of these formats. @@ -167,6 +169,147 @@ receiving a particular layout. 
We are working on an API to control the result layout and recognize it is an important feature to plan a more optimal path of execution for any given model. +.. _sparse-semi-structured-docs: + +Sparse Semi-Structured Tensors +++++++++++++++++++++++++++++++ + +.. warning:: + + Sparse semi-structured tensors are currently a prototype feature and subject to change. Please feel free to open an issue to report a bug or if you have feedback to share. + +Semi-structured sparsity is a sparse data layout that was first introduced in NVIDIA's Ampere architecture. It is also referred to as **fine-grained structured sparsity** or **2:4 structured sparsity**. + +This sparse layout stores `n` elements out of every `2n` elements, with `n` being determined by the width of the Tensor's data type (dtype). The most frequently used dtype is float16, where `n=2`, thus the term "2:4 structured sparsity." + +Semi-structured sparsity is explained in greater detail in `this NVIDIA blog post `_. + +In PyTorch, semi-structured sparsity is implemented via a Tensor subclass. +By subclassing, we can override ``__torch_dispatch__``, allowing us to use faster sparse kernels when performing matrix multiplication. +We can also store the tensor in its compressed form inside the subclass to reduce memory overhead. + +In this compressed form, the sparse tensor is stored by retaining only the *specified* elements and some metadata, which encodes the mask. + +.. note:: + The specified elements and metadata mask of a semi-structured sparse tensor are stored together in a single + flat compressed tensor. They are appended to each other to form a contiguous chunk of memory. + + compressed tensor = [ specified elements of original tensor | metadata_mask ] + + For an original tensor of size `(r, c)` we expect the first `r * c // 2` elements to be the kept elements + and the rest of the tensor is metadata. + + In order to make it easier for the user to view the specified elements + and mask, one can use ``.indices()`` and ``.values()`` to access the mask and specified elements respectively. + + + - ``.values()`` returns the specified elements in a tensor of size `(r, c//2)` and with the same dtype as the dense matrix. + + - ``.indices()`` returns the metadata_mask in a tensor of size `(r, c//2)` and with element type ``torch.int16`` if dtype is torch.float16 and element type ``torch.int32`` if dtype is torch.int8. + + +For 2:4 sparse tensors, the metadata overhead is minor: just 2 bits per specified element. + +.. note:: + It's important to note that ``torch.float32`` is only supported for 1:2 sparsity. Therefore, it does not follow the same formula as above. + +Here, we break down how to calculate the compression ratio (size sparse / size dense) of a 2:4 sparse tensor. + +Let `(r, c) = tensor.shape` and `e = bitwidth(tensor.dtype)`, so `e = 16` for ``torch.float16`` and ``torch.bfloat16`` and `e = 8` for ``torch.int8``. + +.. math:: + M_{dense} = r \times c \times e \\ + M_{sparse} = M_{specified} + M_{metadata} = r \times \frac{c}{2} \times e + r \times \frac{c}{2} \times 2 = \frac{rce}{2} + rc = rce(\frac{1}{2} + \frac{1}{e}) + +Using these calculations, we can determine the total memory footprint for both the original dense and the new sparse representation. + +This gives us a simple formula for the compression ratio, which depends only on the bitwidth of the tensor datatype. + +.. math:: + C = \frac{M_{sparse}}{M_{dense}} = \frac{1}{2} + \frac{1}{e} + +By using this formula, we find that the compression ratio is 56.25% for ``torch.float16`` and 62.5% for ``torch.int8``. + +Constructing Sparse Semi-Structured Tensors +------------------------------------------- + +You can transform a dense tensor into a sparse semi-structured tensor by using the ``torch.sparse.to_sparse_semi_structured`` function. + +Please also note that we only support CUDA tensors, since hardware compatibility for semi-structured sparsity is limited to NVIDIA GPUs. + + +The following datatypes are supported for semi-structured sparsity. Note that each datatype has its own shape constraints and compression factor. + +.. csv-table:: + :header: "PyTorch dtype", "Shape Constraints", "Compression Factor", "Sparsity Pattern" + :widths: 15, 45, 10, 10 + :delim: ; + + ``torch.float16``; Tensor must be 2D and (r, c) must both be a positive multiple of 64;9/16;2:4 + ``torch.int8``; Tensor must be 2D and (r, c) must both be a positive multiple of 128;10/16;2:4 + + +To construct a semi-structured sparse tensor, start by creating a regular dense tensor that adheres to a 2:4 (or semi-structured) sparse format. +To do this, we tile a small 1x4 strip to create a 128x128 dense float16 tensor. +Afterwards, we can call ``to_sparse_semi_structured`` on this matrix to compress it for accelerated inference. + + >>> from torch.sparse import to_sparse_semi_structured + >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda() + tensor([[0., 0., 1., ..., 0., 1., 1.], + [0., 0., 1., ..., 0., 1., 1.], + [0., 0., 1., ..., 0., 1., 1.], + ..., + [0., 0., 1., ..., 0., 1., 1.], + [0., 0., 1., ..., 0., 1., 1.], + [0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16) + >>> A_sparse = to_sparse_semi_structured(A, mask=A.bool()) + SparseSemiStructuredTensor(shape=torch.Size([128, 128]), transposed=False, values=tensor([[1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.], + ..., + [1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16), metadata=tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370], + [-4370, -4370, -4370, ..., -4370, -4370, -4370], + [-4370, -4370, -4370, ..., -4370, -4370, -4370], + ..., + [-4370, -4370, -4370, ..., -4370, -4370, -4370], + [-4370, -4370, -4370, ..., -4370, -4370, -4370], + [-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0', + dtype=torch.int16)) + +Sparse Semi-Structured Tensor Operations +---------------------------------------- + +Currently, the following operations are supported for semi-structured sparse tensors: + +- torch.addmm(bias, dense, sparse.t()) +- torch.mm(dense, sparse.t()) +- torch.mm(sparse, dense) +- aten.linear.default(dense, sparse, bias) +- aten.t.default(sparse) +- aten.detach.default(sparse) + +To use these ops, simply pass in the output of ``to_sparse_semi_structured(tensor)`` instead of ``tensor``, once ``tensor`` has its zeros laid out in a semi-structured sparse pattern, like this: + + >>> a = torch.Tensor([0, 0, 1, 1]).tile((64, 16)).half().cuda() + >>> b = torch.rand(64, 64).half().cuda() + >>> c = torch.mm(a, b) + >>> a_sparse = to_sparse_semi_structured(a, mask=a.bool()) + >>> torch.allclose(c, torch.mm(a_sparse, b)) + True + +Under the hood, SparseSemiStructuredTensor will call ``torch._structured_sparse_linear`` for accelerated inference using CUTLASS sparse kernels.
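+
+As another minimal sketch, the ``torch.addmm`` path can be exercised the same way, reusing ``a``, ``b``, and ``a_sparse`` from the example above (the ``bias`` tensor below is introduced purely for illustration):
+
+    >>> bias = torch.rand(64).half().cuda()
+    >>> dense_out = torch.addmm(bias, b, a.t())
+    >>> sparse_out = torch.addmm(bias, b, a_sparse.t())
+    >>> torch.allclose(dense_out, sparse_out, rtol=1e-3, atol=1e-3)
+    True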
+ +Accelerating nn.Linear with semi-structured sparsity +---------------------------------------------------- +You can accelerate the linear layers in your model if the weights are already semi-structured sparse with just a few lines of code: + + >>> input = torch.rand(64, 64).half().cuda() + >>> mask = torch.Tensor([0, 0, 1, 1]).tile((64, 16)).cuda().bool() + >>> linear = nn.Linear(64, 64).half().cuda() + >>> linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight, mask=mask)) + .. _sparse-coo-docs: @@ -992,12 +1135,18 @@ multiplication, and ``@`` is matrix multiplication. :func:`torch.mv`;no; ``M[sparse_csr] @ V[strided] -> V[strided]`` :func:`torch.matmul`; no; ``M[sparse_coo] @ M[strided] -> M[strided]`` :func:`torch.matmul`; no; ``M[sparse_csr] @ M[strided] -> M[strided]`` + :func:`torch.matmul`; no; ``M[SparseSemiStructured] @ M[strided] -> M[strided]`` + :func:`torch.matmul`; no; ``M[strided] @ M[SparseSemiStructured] -> M[strided]`` :func:`torch.mm`; no; ``M[sparse_coo] @ M[strided] -> M[strided]`` + :func:`torch.mm`; no; ``M[SparseSemiStructured] @ M[strided] -> M[strided]`` + :func:`torch.mm`; no; ``M[strided] @ M[SparseSemiStructured] -> M[strided]`` :func:`torch.sparse.mm`; yes; ``M[sparse_coo] @ M[strided] -> M[strided]`` :func:`torch.smm`; no; ``M[sparse_coo] @ M[strided] -> M[sparse_coo]`` :func:`torch.hspmm`; no; ``M[sparse_coo] @ M[strided] -> M[hybrid sparse_coo]`` :func:`torch.bmm`; no; ``T[sparse_coo] @ T[strided] -> T[strided]`` :func:`torch.addmm`; no; ``f * M[strided] + f * (M[sparse_coo] @ M[strided]) -> M[strided]`` + :func:`torch.addmm`; no; ``f * M[strided] + f * (M[SparseSemiStructured] @ M[strided]) -> M[strided]`` + :func:`torch.addmm`; no; ``f * M[strided] + f * (M[strided] @ M[SparseSemiStructured]) -> M[strided]`` :func:`torch.sparse.addmm`; yes; ``f * M[strided] + f * (M[sparse_coo] @ M[strided]) -> M[strided]`` :func:`torch.sspaddmm`; no; ``f * M[sparse_coo] + f * (M[sparse_coo] @ M[strided]) -> M[sparse_coo]`` :func:`torch.lobpcg`; no; ``GENEIG(M[sparse_coo]) -> M[strided], M[strided]`` diff --git a/test/test_sparse_semi_structured.py b/test/test_sparse_semi_structured.py new file mode 100644 index 000000000000..7f2c813287d9 --- /dev/null +++ b/test/test_sparse_semi_structured.py @@ -0,0 +1,227 @@ +# Owner(s): ["module: sparse"] +import random +import unittest + +import torch +from torch import nn + +from torch.sparse.semi_structured import ( + _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG, + SparseSemiStructuredTensor, + to_sparse_semi_structured, +) + +from torch.testing._internal.common_device_type import ( + dtypes, + instantiate_device_type_tests, +) + +from torch.testing._internal.common_dtype import all_types_and_complex + +from torch.testing._internal.common_utils import ( + parametrize, + run_tests, + subtest, + TestCase, +) + +SEMI_STRUCTURED_SUPPORTED_DTYPES = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG.keys() + +_IS_SM8X = False +if torch.cuda.is_available(): + _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8 + +def rand_sparse_semi_structured_mask( + r, c, dtype=torch.float16, device="cuda", choice=None +): + """ + This function returns a 1:2 sparse matrix of size (r, c). + Note that this means this matrix will also be 2:4 and 4:8 sparse as well. 
+ """ + + choices = [[0, 1], [1, 0]] + mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)] + + return ( + torch.tensor(mask_entries, dtype=dtype, device=device) + .reshape(r, c) + .contiguous() + ) + + +class TestSparseSemiStructured(TestCase): + + @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version") + @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES) + def test_to_sparse_semi_structured(self, dtype): + A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype) + A_sparse = to_sparse_semi_structured(A, mask=A.bool()) + + assert A.shape == A_sparse.shape + assert A.device == A_sparse.device + assert A.dtype == A_sparse.dtype + + assert isinstance(A, torch.Tensor) + assert isinstance(A_sparse, SparseSemiStructuredTensor) + + with self.assertRaisesRegex( + NotImplementedError, + "You must pass in a mask to to_sparse_semi_structured, currently mask=None.", + ): + A_sparse = to_sparse_semi_structured(A) + + @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version") + @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES) + def test_mm_sparse_first_NT(self, dtype, device): + """ + Ensure torch.mm(A_sparse, B) is correct for float16 and will throw error for int8 + Ensure torch.mm(A_sparse, B.t()) is correct + """ + A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype) + A_sparse = to_sparse_semi_structured(A, mask=A.bool()) + + B = torch.rand((128, 128), device=A_sparse.device).to(dtype) + + # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over + if dtype is torch.int8: + # This should fail + with self.assertRaisesRegex(RuntimeError, "_structured_sparse_linear"): + sparse_result = torch.mm(A_sparse, B) + + # test transpose + # NOTE: CUTLASS and cuSPARSELt have slightly different int8 behavior. 
+ # CUTLASS will output to an int32 tensor while cuSPARSELt will output to a int8 tensor + dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32) + sparse_result = torch.mm(A_sparse, B.t()) + assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3) + else: + dense_result = torch.mm(A, B) + sparse_result = torch.mm(A_sparse, B) + assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3) + # test transpose + dense_result = torch.mm(A, B.t()) + sparse_result = torch.mm(A_sparse, B.t()) + assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3) + + @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version") + @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES) + def test_mm_sparse_first_T(self, dtype, device): + """ + Ensure torch.mm(A_sparse.t(), B) throws error + """ + A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype) + A_sparse = to_sparse_semi_structured(A, mask=A.bool()) + + B = torch.rand((128, 128), device=A_sparse.device).to(dtype) + + with self.assertRaisesRegex( + NotImplementedError, + r"arg0: SparseSemiStructuredTensor\(.*transposed=True", + ): + torch.mm(A_sparse.t(), B) + + @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version") + @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES) + def test_mm_sparse_second_T(self, dtype, device): + """ + Ensure torch.mm(A, B_sparse.t()) is correct + """ + B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype) + B_sparse = to_sparse_semi_structured(B, mask=B.bool()) + + A = torch.rand((128, 128), device=B_sparse.device).to(dtype) + + # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over + if dtype is torch.int8: + dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32) + sparse_result = torch.mm(A, B_sparse.t()) + else: + dense_result = torch.mm(A, B.t()) + sparse_result = torch.mm(A, B_sparse.t()) + + assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3) + + @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version") + @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES) + def test_mm_sparse_second_NT(self, dtype, device): + """ + Ensure torch.mm(A, B_sparse) throws error + """ + B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype) + B_sparse = to_sparse_semi_structured(B, mask=B.bool()) + + A = torch.rand((128, 128), device=B_sparse.device).to(dtype) + + with self.assertRaisesRegex( + NotImplementedError, + r"arg1: SparseSemiStructuredTensor\(.*transposed=False", + ): + sparse_result = torch.mm(A, B_sparse) + + @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version") + @parametrize("inference_mode", [subtest(False), subtest(True)]) + def test_linear(self, inference_mode, device): + """ + Test nn.Linear has the same numerics + """ + input = torch.rand(128, 128, device=device).half() + model = nn.Linear(128, 128).to(device).half() + m, n = model.weight.shape + mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool) + # set masked weight + model.weight = nn.Parameter(model.weight * mask) + + dense_result = model(input) + model.weight = nn.Parameter(to_sparse_semi_structured(model.weight, mask=mask)) + + if inference_mode: + with torch.inference_mode(): + sparse_result = model(input) + else: + sparse_result = model(input) + + assert torch.allclose(dense_result, sparse_result, rtol=1e-5, atol=1e-5) + + @unittest.skipIf(not _IS_SM8X, 
"semi-structured sparsity not supported on this library version") + def test_values(self): + A = rand_sparse_semi_structured_mask(128, 128) + A_sparse = to_sparse_semi_structured(A, mask=A.bool()) + assert A_sparse.values().shape == (128, 64) + assert (A_sparse.values() == 1).all() + + @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version") + def test_indices(self): + A = rand_sparse_semi_structured_mask(128, 128) + A_sparse = to_sparse_semi_structured(A, mask=A.bool()) + assert A_sparse.indices().shape == (128, 8) + + @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version") + @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES) + def test_unsupported_shape(self, dtype, device): + A = rand_sparse_semi_structured_mask(4, 4, dtype=dtype, device=device) + with self.assertRaisesRegex(RuntimeError, "Error original_tensor.shape"): + A_sparse = to_sparse_semi_structured(A, mask=A.bool()) + + @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version") + @dtypes(*all_types_and_complex()) + def test_unsupported_dtype(self, dtype, device): + A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype, device=device) + + if dtype not in SEMI_STRUCTURED_SUPPORTED_DTYPES: + with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dtype"): + A_sparse = to_sparse_semi_structured(A, mask=A.bool()) + else: + A_sparse = to_sparse_semi_structured(A, mask=A.bool()) + + @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version") + def test_unsupported_dim(self, device): + A = torch.rand(128, 128, 128, device=device, dtype=torch.float16) + + with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dim"): + A_sparse = to_sparse_semi_structured(A, mask=A.bool()) + + +instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda") + +if __name__ == "__main__": + run_tests() diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py index 6f05dfbb2209..b1a91fb82c86 100644 --- a/torch/sparse/__init__.py +++ b/torch/sparse/__init__.py @@ -5,6 +5,9 @@ import torch from torch._C import _add_docstr, _sparse # type: ignore[attr-defined] from torch import Tensor +# Semi structured sparsity support +from .semi_structured import SparseSemiStructuredTensor, to_sparse_semi_structured + # A workaround to support both TorchScript and MyPy: from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -23,9 +26,10 @@ __all__ = [ 'sum', 'softmax', 'log_softmax', + 'SparseSemiStructuredTensor', + 'to_sparse_semi_structured', ] - addmm = _add_docstr(_sparse._sparse_addmm, r""" sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py new file mode 100644 index 000000000000..1bf13c0a7089 --- /dev/null +++ b/torch/sparse/semi_structured.py @@ -0,0 +1,389 @@ +import warnings +from collections import namedtuple +from typing import Any, Optional + +import torch + + +__all__ = [ + "to_sparse_semi_structured", + "SparseSemiStructuredTensor", +] + +_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple( + "_SEMI_STRUCTURED_SPARSE_CONFIG", "compression_factor min_size" +) +_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG = { + torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(9, 64), + torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(10, 128), +} + +_WARNING_SHOWN = False + +class SparseSemiStructuredTensor(torch.Tensor): + """This class implementes semi-structured sparsity as a Tensor subclass. 
+ + Semi-structured sparsity describes a sparsity pattern where n out of every 2n elements are zero, + with n depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained + structured sparsity. + + Currently, this class supports 2:4 sparsity for int8 and float16 dtypes. + + This subclass stores the dense tensor in a compressed form by only storing the specified elements and a metadata mask. + These two are stored next to each other in one contiguous tensor. + + We choose to store the specified elements and the metadata in a single tensor for future compatibility with cuSPARSELt. + + compressed tensor = [ specified elements of original tensor | mask_metadata ] + + For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements. + The rest of the tensor is metadata. + + This subclass also overrides __torch_dispatch__ to use _structured_sparse_linear for faster matrix multiplications + via sparse CUTLASS kernels. In the future we will also call into cuSPARSELt kernels for more performance gains. + """ + + @staticmethod + def __new__( + cls, + original_tensor: Optional[torch.Tensor], + original_shape: Optional[torch.Size] = None, + mask: Optional[torch.Tensor] = None, + compressed_tensor: Optional[torch.Tensor] = None, + transposed: bool = False, + ): + """ + Create a new instance of the class. + + When original_tensor is passed in, we compress it and store the compressed representation. + We can also create a new instance of the class from the compressed representation without the original tensor. + + Args: + original_tensor: The original dense tensor, or None if we have already compressed the tensor. + original_shape: The shape of the original dense tensor. + mask: Mask to be applied to the original tensor. + compressed_tensor: A flattened tensor to store the specified elements and mask metadata. + transposed: Whether the tensor is transposed or not. + + Returns: + torch.Tensor: A torch.Tensor wrapper subclass. + + Raises: + ValueError: If both original_tensor and compressed_tensor are None. + + """ + if original_tensor is not None: + previous_tensor = original_tensor + original_shape = original_tensor.shape + elif compressed_tensor is not None: + previous_tensor = compressed_tensor + else: + raise ValueError("Both compressed_tensor and original_tensor are None!") + + kwargs = {} + kwargs["device"] = previous_tensor.device # type: ignore[assignment] + kwargs["dtype"] = previous_tensor.dtype # type: ignore[assignment] + kwargs["layout"] = previous_tensor.layout # type: ignore[assignment] + kwargs["requires_grad"] = False # type: ignore[assignment] + + return torch.Tensor._make_wrapper_subclass(cls, original_shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + original_tensor: Optional[torch.Tensor], + original_shape: Optional[torch.Size] = None, + mask: Optional[torch.Tensor] = None, + compressed_tensor: Optional[torch.Tensor] = None, + transposed: bool = False, + ) -> None: + """SparseSemiStructuredTensor constructor. + + Args: + original_tensor: The original dense tensor, or None if we have already compressed the tensor. + original_shape: The shape of the original dense tensor. + mask: Mask to be applied to the original tensor. + compressed_tensor: A flattened tensor to store the specified elements and mask metadata. + transposed: Whether the tensor is transposed or not. + + Returns: + None + + Raises: + NotImplementedError: If ``mask=None``, as we currently do not support inferring a mask from the dense tensor.
+ RuntimeError: If original_tensor is not a supported dtype, dim, shape, or device. + """ + global _WARNING_SHOWN + if not _WARNING_SHOWN: + warnings.warn( + ( + "The PyTorch API of SparseSemiStructuredTensor is in prototype stage " + "and will change in the near future. Please open a Github issue " + "for feature requests and see our documentation on the torch.sparse " + "module for further information about the project." + ), + UserWarning, + ) + _WARNING_SHOWN = True + + # if original tensor is passed in, we need to compress it and store the compressed representation. + if original_tensor is not None: + # check if mask passed in + if mask is None: + raise NotImplementedError("You must pass in a mask to to_sparse_semi_structured, currently mask=None.") + + # check device + if not original_tensor.is_cuda: + raise RuntimeError( + ( + f"Error original_tensor.device = {original_tensor.device} is not supported! " + "Only CUDA tensors are currently supported." + ) + ) + + # check dim + if original_tensor.dim() != 2: + raise RuntimeError( + ( + f"Error original_tensor.dim = {original_tensor.dim()} is not supported! " + "Only 2d tensors are currently supported." + ) + ) + + # check dtype + if original_tensor.dtype not in _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG: + raise RuntimeError( + ( + f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! " + f"dtype must be one of: {_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG}" + ) + ) + + # check shape + m, n = original_tensor.shape + min_size = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype].min_size + if m < min_size or m % min_size or n < min_size or n % min_size: + # TODO in the future we can add in padding to support dimensions that aren't perfect multiples + raise RuntimeError( + ( + f"Error original_tensor.shape {original_tensor.shape} is not supported! " + f"Both dimensions must be at least {min_size} and a multiple of {min_size}." + ) + ) + + # This code calculates the size of the compressed tensor. + # The compression factor depends on the dtype; for 2:4 sparsity it is given by the formula below: + # compression_factor = 1/2 + 1/bitwidth(dtype) + original_size = original_tensor.nelement() + compression_factor = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[ + original_tensor.dtype + ].compression_factor + compressed_size = original_size * compression_factor // 16 + + compressed_tensor = torch.empty( + (compressed_size,), + dtype=original_tensor.dtype, + device=original_tensor.device, + ) + + # TODO This is a temporary hack to get the mask in compressed form so we can store the compressed tensor. + # In the future, we will add in a conversion function from the mask to the meta that we can use instead.
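+            # NOTE: only the `meta` returned by this call is used; the linear output computed against the all-ones (128, n) placeholder input is discarded (the fixed height of 128 is presumably just there to satisfy the kernel's shape requirements).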
+ placeholder = torch.ones( + (128, n), dtype=original_tensor.dtype, device=original_tensor.device + ) + specified = original_tensor.masked_select(mask).view(m, n // 2) + _, meta = torch._structured_sparse_linear(placeholder, specified, mask) + # set the specified elements + compressed_tensor[: m * n // 2] = specified.view(-1) + # set the metadata + compressed_tensor[m * n // 2 :] = meta.view(original_tensor.dtype).view(-1) + + # set values + self.original_tensor = None + self.compressed_tensor = compressed_tensor + self.transposed = transposed + + def __repr__(self) -> str: + """Return string representation of SparseSemiStructuredTensor. + + Returns: + str: String representation + + Raises: + None + """ + return ( + f"SparseSemiStructuredTensor(shape={self.shape}, " + f"transposed={self.transposed}, " + f"values={self.values()}, " + f"metadata={self.indices()})" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs) -> Any: + """Overload __torch_dispatch__ to use torch._structured_sparse_linear. + + `torch._structured_sparse_linear` uses accelerated sparse CUTLASS kernels. + In the future we plan to also add in support for cuSPARSELt kernels. + + Args: + func: The function being dispatched. + types: The types of the arguments. + args: The arguments passed to the function. + kwargs: The keyword arguments passed to the function. + + Returns: + Any: The result of the dispatched operation. + + Raises: + NotImplementedError: If the dispatched operation is not implemented. + """ + # Since this code runs below autograd, a detach corresponds to only returning a new object + if func is torch.ops.aten.detach.default: + return SparseSemiStructuredTensor( + args[0].original_tensor, + original_shape=args[0].shape, + mask=None, + compressed_tensor=args[0].compressed_tensor, + transposed=args[0].transposed, + ) + + # Because we cannot go from the compressed representation back to the dense representation currently, + # we just keep track of how many times we have been transposed. Depending on whether the sparse matrix + # is the first or second argument, we expect an even / odd number of calls to transpose respectively. + if func is torch.ops.aten.t.default: + return SparseSemiStructuredTensor( + args[0].original_tensor, + original_shape=args[0].shape, + mask=None, + compressed_tensor=args[0].compressed_tensor, + transposed=not args[0].transposed, + ) + + # handle addmm + if func is torch.ops.aten.addmm.default: + bias, input_A, input_B = args + + # Currently, we only support the first matrix being sparse for addmm/mm in cuSPARSELt and CUTLASS. + # CUTLASS only supports the first input to be sparse for a given matmul. + # cuSPARSELt does not have this limitation, although our implementation is only for sparse first. + + # We support second-matrix-sparse matmul by taking advantage of some transpose properties: + # this is also why we expect an odd number of transpose calls when the second matrix is sparse, versus an even number + # of transpose calls when the first matrix is sparse.
+ # F.linear(x) = addmm(bias, input, weight.t()) = b + xW' = (b + xW')'' + # = (W''x' + b')' = (Wx' + b')' = addmm(bias.T, weight, input).T + if isinstance(input_B, cls) and input_B.transposed: + result, _ = torch._structured_sparse_linear( + input_A, input_B.values(), input_B.indices(), bias=bias + ) + return result + + # handle mm + if func is torch.ops.aten.mm.default: + input_A, input_B = args + + if isinstance(input_A, cls) and not input_A.transposed: + transposed_result, _ = torch._structured_sparse_linear( + input_B.t(), input_A.values(), input_A.indices() + ) + return transposed_result.t() + + elif isinstance(input_B, cls) and input_B.transposed: + result, _ = torch._structured_sparse_linear( + input_A, input_B.values(), input_B.indices() + ) + return result + + # When torch is run with inference mode, pytorch does not decompose torch.ops.aten.linear into a .t() and addmm(), + # so we must match the aten.linear op. + # TODO see if there's a way to force pytorch to decompose the op so we don't have to handle this here. + if func is torch.ops.aten.linear.default: + input_tensor, weight, bias = args + if isinstance(weight, cls): + result, _ = torch._structured_sparse_linear( + input_tensor, weight.values(), weight.indices(), bias=bias + ) + return result + + # handle values + if func is torch.ops.aten.values.default: + m, k = args[0].shape + num_kept_elements = m * k // 2 + return args[0].compressed_tensor[:num_kept_elements].view(m, k // 2) + + # handle indices + if func is torch.ops.aten.indices.default: + m, k = args[0].shape + num_kept_elements = m * k // 2 + metadata = args[0].compressed_tensor[num_kept_elements:].view(m, -1) + + # the metadata is expected to be in different datatypes for fp16/int8 respectively for CUTLASS. + if args[0].dtype is torch.int8: + return metadata.view(torch.int32) + elif args[0].dtype is torch.float16: + return metadata.view(torch.int16) + + error_string = "\n".join( + [f"func {func} with args: "] + + [f"arg{i}: {arg}" for i, arg in enumerate(args)] + ) + raise NotImplementedError(error_string) + + +def to_sparse_semi_structured( + original_tensor: torch.Tensor, + mask: Optional[torch.Tensor] = None, + transposed: bool = False, +) -> SparseSemiStructuredTensor: + """ + This function converts a dense tensor into a sparse semi-structured tensor. + It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor. + + This function will check to ensure the dense tensor has the right dtype, size, dims, and device. + We currently only support semi-structured sparse tensors for 2d CUDA tensors. 
+ Additionally, your tensor must be a positive multiple of a block size given the dtype + + - torch.float16 (r, c) must be >= and a multiple of 64 + - torch.int8 (r, c) must be >= and a multiple of 128 + + Args: + original_tensor (Tensor): the dense tensor to convert + mask (Optional BoolTensor): boolean mask to apply to the original tensor + transposed (bool, optional): whether the dense tensor is transposed + + Returns: + SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor and mask + + Raises: + None + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda() + tensor([[0., 0., 1., ..., 0., 1., 1.], + [0., 0., 1., ..., 0., 1., 1.], + [0., 0., 1., ..., 0., 1., 1.], + ..., + [0., 0., 1., ..., 0., 1., 1.], + [0., 0., 1., ..., 0., 1., 1.], + [0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16) + >>> A_sparse = to_sparse_semi_structured(A, mask=A.bool()) + SparseSemiStructuredTensor(shape=torch.Size([128, 128]), transposed=False, values=tensor([[1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.], + ..., + [1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16), + metadata=tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370], + [-4370, -4370, -4370, ..., -4370, -4370, -4370], + [-4370, -4370, -4370, ..., -4370, -4370, -4370], + ..., + [-4370, -4370, -4370, ..., -4370, -4370, -4370], + [-4370, -4370, -4370, ..., -4370, -4370, -4370], + [-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0', + dtype=torch.int16)) + """ + return SparseSemiStructuredTensor(original_tensor, mask=mask, transposed=transposed)