Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
This PR adds fast semi-structured sparsification kernels to PyTorch, enabling accelerated 2:4 sparsification of dense tensors. The kernels have been added as aten native functions. In particular, three new functions have been added:

* `torch._sparse_semi_structured_tile`: returns the packed representation and metadata for both X and X', as well as the thread masks. Note that this applies 2:4 sparsity in a 4x4 tile instead of the usual 1x4 strip.
* `torch._sparse_semi_structured_apply`: takes an input tensor and the thread masks from the function above, and returns the packed representation and metadata obtained by applying those thread masks to the input tensor.
* `torch._sparse_semi_structured_apply_dense`: does the same thing as above, but returns the result in the dense representation instead of the sparse one.

The subclasses have also been updated with a new `prune_dense_static_sort` classmethod to create sparse tensors in this format. I've added some additional documentation on how to calculate the compressed tensors needed to create a SparseSemiStructuredTensor oneself. To this end, two new helper functions have been added: `sparse_semi_structured_tile` and `compute_compressed_swizzled_bitmask`.

Differential Revision: [D56190801](https://our.internmc.facebook.com/intern/diff/D56190801)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122350
Approved by: https://github.com/cpuhrsch
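A minimal usage sketch (not part of the PR itself) based on how the tests below exercise these ops; it assumes a CUDA device with the required SM80-class support, and the shapes are purely illustrative:

```python
import torch
from torch.sparse import SparseSemiStructuredTensorCUTLASS

# Dense fp16 matrix to sparsify (shape is illustrative).
x = torch.randn(128, 128, dtype=torch.float16, device="cuda")

# Tile-based 2:4 sparsification: packed values/metadata for both x and x.t(),
# plus the compressed swizzled bitmask (thread masks).
packed, meta, packed_t, meta_t, bitmask = torch._sparse_semi_structured_tile(x)

# Re-apply the same thread masks to another tensor of the same shape,
# getting back the packed representations...
y = torch.randn_like(x)
y_packed, y_packed_t = torch._sparse_semi_structured_apply(y, bitmask)

# ...or the same result in the dense representation.
y_dense = torch._sparse_semi_structured_apply_dense(y, bitmask)

# The subclass classmethod wraps the same machinery end to end.
x_sparse = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(x)
```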
1123 lines
46 KiB
Python
# Owner(s): ["module: sparse"]
import itertools
import random
import unittest

import torch
from torch import nn
import torch.nn.functional as F

from torch.sparse import (
    SparseSemiStructuredTensor,
    SparseSemiStructuredTensorCUSPARSELT,
    SparseSemiStructuredTensorCUTLASS,
    to_sparse_semi_structured,
)

from torch.sparse._semi_structured_conversions import (
    sparse_semi_structured_from_dense_cutlass,
    _sparse_semi_structured_tile,
    _compute_compressed_swizzled_bitmask,
)

from torch.testing import make_tensor

from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)

from torch.testing._internal.common_dtype import all_types_and_complex
import torch._dynamo.test_case

from torch.testing._internal.common_utils import (
    parametrize,
    run_tests,
    subtest,
    TestCase,
    TEST_WITH_ROCM,
    IS_WINDOWS,
)

import pytest

from torch.utils._triton import has_triton

SEMI_STRUCTURED_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16, torch.float32, torch.int8]
SEMI_STRUCTURED_SUPPORTED_BACKENDS = {}

_IS_SM8X = False

if torch.cuda.is_available():
    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
    SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS

    # check if cslt is available for now using this:
    # TODO when we add cusparselt as a backend, we can update this to use torch.cusparselt.is_available()
    try:
        torch._cslt_compress(torch.ones(128, 256).cuda())
        SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT
    except Exception:
        pass

inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.float32, torch.int8)
training_dtypes = dtypes(torch.float16, torch.bfloat16)
parametrize_backends = parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)

atol_rtol_kw = {
    torch.float16: {
        "rtol": 1e-3,
        "atol": 1e-3,
    },
    torch.bfloat16: {
        "rtol": 1e-1,
        "atol": 1e-1,
    },
}


def sparse24_largest_mask_2d(original):
    sparse = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(original)
    return sparse.to_dense().bool()


def sparsify24_dense(original):
    return sparse24_largest_mask_2d(original) * original


def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """
    This function returns a 1:2 sparse matrix of size (r, c).
    Note that this means this matrix will also be 2:4 and 4:8 sparse as well.
    """

    choices = [[0, 1], [1, 0]]
    mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]

    return (
        torch.tensor(mask_entries, dtype=dtype, device=device)
        .reshape(r, c)
        .contiguous()
    )


def rand_sparse_semi_structured(r, c, dtype, device, choice=None):
    pattern = '2by4' if dtype != torch.float32 else '1by2'
    if pattern == '1by2':
        ksparse = 2
        choices = [
            [0, 1],
            [1, 0]
        ]
    elif pattern == '2by4':
        ksparse = 4
        choices = [
            [1, 1, 0, 0],
            [1, 0, 1, 0],
            [1, 0, 0, 1],
            [0, 1, 1, 0],
            [0, 1, 0, 1],
            [0, 0, 1, 1]
        ]
    mask_entries = [choice or random.choice(choices) for i in range(r * c // ksparse)]
    mask = torch.tensor(mask_entries, dtype=torch.bool).view(r, c).to(device)
    dense = make_tensor(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1  # To prevent zeros except where mask applied.
    dense = dense.masked_fill(~mask, 0)
    return dense


def rand_sparse_semi_structured_all_patterns(r, c, dtype, device):
    pattern = '2by4' if dtype != torch.float32 else '1by2'
    if pattern == '1by2':
        ksparse = 2
        choices = [
            [[0, 0], [0, 1]],
            [[0, 1], [0, 1]],
            [[1, 0], [1, 0]],
            [[1, 1], [1, 0]]
        ]
    elif pattern == '2by4':
        ksparse = 4
        choices = [
            [[0, 0, 0, 0], [0, 0, 1, 1]],
            [[0, 0, 0, 1], [0, 0, 1, 1]],
            [[0, 0, 1, 0], [0, 0, 1, 1]],
            [[0, 0, 1, 1], [0, 0, 1, 1]],
            [[0, 1, 0, 0], [0, 1, 1, 0]],
            [[0, 1, 0, 1], [0, 1, 0, 1]],
            [[0, 1, 1, 0], [0, 1, 1, 0]],
            [[0, 1, 1, 1], [0, 1, 0, 1]],
            [[1, 0, 0, 0], [1, 0, 1, 0]],
            [[1, 0, 0, 1], [1, 0, 0, 1]],
            [[1, 0, 1, 0], [1, 0, 1, 0]],
            [[1, 0, 1, 1], [1, 0, 0, 1]],
            [[1, 1, 0, 0], [1, 1, 0, 0]],
            [[1, 1, 0, 1], [1, 1, 0, 0]],
            [[1, 1, 1, 0], [1, 1, 0, 0]],
            [[1, 1, 1, 1], [1, 1, 0, 0]],
        ]
    mask_rows = [random.randint(0, len(choices) - 1) for i in range(r * c // ksparse)]

    COL_INV, COL_VAL = 0, 1
    mask_entries_inv = [choices[i][COL_INV] for i in mask_rows]
    mask_entries_val = [choices[i][COL_VAL] for i in mask_rows]
    mask_inv = torch.tensor(mask_entries_inv, dtype=torch.bool).view(r, c).to(device)
    mask_val = torch.tensor(mask_entries_val, dtype=torch.bool).view(r, c).to(device)
    dense = make_tensor(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1  # To prevent zeros except where mask below applied.
    dense_inv = dense.masked_fill(~mask_inv, 0)
    dense_val = dense_inv.masked_fill(~mask_val, 0)

    return dense_inv, dense_val


class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):

    def setUp(self):
        if not _IS_SM8X:
            self.skipTest('Only runs on SM80')
        super().setUp()

    def tearDown(self):
        super().tearDown()

    @staticmethod
    def _test_mlp_contiguous_relu_compile(backend, dense_input_shape):
        """
        Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
        We expect:
        (1) The sparse tensor subclass should turn nn.Linear into `aten._structured_sparse_addmm` + `aten.contiguous()`
        (2) Inductor should fuse the .contiguous() call into the relu
        """

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(128, 128)

            def forward(self, x):
                x = self.linear(x)
                x = x.contiguous()
                return torch.nn.functional.relu(x)

        input = torch.rand(dense_input_shape, device="cuda").half()
        model = Model().eval().cuda().half()
        mod_linear = model.linear
        m, n = mod_linear.weight.shape
        mask = torch.Tensor([1, 0, 0, 1]).tile((m, n // 4)).bool().cuda()
        # set masked weight
        mod_linear.weight = nn.Parameter(mod_linear.weight * mask)

        dense_result = model(input)
        mod_linear.weight = nn.Parameter(SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].from_dense(mod_linear.weight))
        sparse_result = model(input)

        model = torch.compile(model, backend="inductor", fullgraph=True)
        sparse_compile_result = model(input)

        # test that sparse_compile_result and dense_result are numerically close
        assert torch.allclose(dense_result, sparse_compile_result, rtol=1e-3, atol=1e-3)
        # assert sparse and sparse_compile have the same strides,
        # as meta registrations may return contiguous tensors when the output is transposed
        # https://github.com/pytorch/pytorch/pull/114477
        assert sparse_result.stride() == sparse_compile_result.stride()

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
    def test_mlp_contiguous_relu_compile_cusparselt(self):
        """
        test for cuSPARSELt meta registrations (_cslt_sparse_mm) + torch.compile
        """
        for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
            SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cusparselt", dense_input_shape)

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    def test_mlp_contiguous_relu_compile_cutlass(self):
        """
        test for CUTLASS meta registrations (_sparse_semi_structured_addmm) + torch.compile
        """
        for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
            SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cutlass", dense_input_shape)

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
    def test_sp24_compile(self) -> None:
        x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
        e = torch.eye(x.shape[0], x.shape[0], device="cuda", dtype=torch.float16)

        def fn(x, e):
            y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x)
            y = y.t()
            return x @ y

        # Eager
        output = fn(x, e)
        output.backward(output)
        # Torch compile
        output = torch.compile(fn)(x, e)
        output.backward(output)


class TestSparseSemiStructured(TestCase):

    def setUp(self):
        if not _IS_SM8X:
            self.skipTest('Only runs on SM80')
        if IS_WINDOWS:
            self.skipTest("torch.compile not supported on windows")

    @inference_dtypes
    @parametrize_backends
    def test_to_sparse_semi_structured(self, dtype, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        assert A.shape == A_sparse.shape
        assert A.device == A_sparse.device
        assert A.dtype == A_sparse.dtype

        assert isinstance(A, torch.Tensor)
        assert isinstance(A_sparse, SparseSemiStructuredTensor)

    @inference_dtypes
    @parametrize_backends
    @parametrize("dense_input_shape", [(128, 1), (128, 64), (128, 128)])
    def test_mm_sparse_first_NN(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A_sparse, B) is correct for float16 and will throw error for int8
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
        if dtype is torch.int8:
            if backend == "cutlass":
                with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"):
                    sparse_result = torch.mm(A_sparse, B)
            else:
                with self.assertRaisesRegex(RuntimeError,
                                            "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
                    sparse_result = torch.mm(A_sparse, B)
        else:
            dense_result = torch.mm(A, B)
            sparse_result = torch.mm(A_sparse, B)
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    @parametrize_backends
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    def test_mm_sparse_first_NT(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A_sparse, B.t()) is correct for float16/bfloat16
        and will throw an error for int8 + padding
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
        if dtype is torch.int8 and dense_input_shape in {(1, 128)}:
            # padding with int8 throws an error because transposing B yields a contiguous output
            # and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS.
            if backend == "cutlass":
                with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"):
                    sparse_result = torch.mm(A_sparse, B.t())
            else:
                with self.assertRaisesRegex(RuntimeError,
                                            "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
                    sparse_result = torch.mm(A_sparse, B.t())
        elif dtype is torch.int8:
            # test transpose
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
            sparse_result = torch.mm(A_sparse, B.t())
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
        else:
            # test transpose
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A_sparse, B.t())
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize_backends
    def test_mm_sparse_first_TN(self, dtype, dense_input_shape, device, backend):
        """
        Ensure torch.mm(A_sparse.t(), B) throws error
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"`SparseSemiStructuredTensor.*` matmul: operation is not supported",
        ):
            torch.mm(A_sparse.t(), B)

    @inference_dtypes
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize_backends
    def test_mm_sparse_second_NT(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A, B_sparse.t()) is correct
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B)

        A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
        if dtype is torch.int8:
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
            sparse_result = torch.mm(A, B_sparse.t())
        else:
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A, B_sparse.t())

        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize_backends
    def test_mm_sparse_second_NN(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A, B_sparse) throws error
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B)

        A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"`SparseSemiStructuredTensor.*` matmul: operation is not supported",
        ):
            sparse_result = torch.mm(A, B_sparse)

@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
|
|
@parametrize("inference_mode", [subtest(True), subtest(False)])
|
|
@parametrize_backends
|
|
def test_linear(self, dense_input_shape, inference_mode, device, backend):
|
|
"""
|
|
Test nn.Linear has the same numerics
|
|
"""
|
|
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
|
if backend == "cutlass" and IS_WINDOWS:
|
|
self.skipTest("CUTLASS not supported on Windows")
|
|
input = torch.rand((dense_input_shape), device=device).half()
|
|
model = nn.Linear(128, 256).to(device).half()
|
|
m, n = model.weight.shape
|
|
mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool)
|
|
# set masked weight
|
|
model.weight = nn.Parameter(model.weight * mask)
|
|
|
|
dense_result = model(input)
|
|
|
|
model.weight = nn.Parameter(to_sparse_semi_structured(model.weight))
|
|
|
|
if inference_mode:
|
|
with torch.inference_mode():
|
|
sparse_result = model(input)
|
|
else:
|
|
sparse_result = model(input)
|
|
|
|
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
|
|
|
|
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
|
|
@parametrize_backends
|
|
def test_mlp(self, device, dense_input_shape, backend):
|
|
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
|
|
input = torch.rand(dense_input_shape, device=device).half()
|
|
model = (
|
|
nn.Sequential(
|
|
nn.Linear(128, 256),
|
|
nn.Linear(256, 128),
|
|
)
|
|
.half()
|
|
.to(device)
|
|
)
|
|
|
|
for i in range(2):
|
|
m, n = model[i].weight.shape
|
|
mask = rand_sparse_semi_structured_mask(
|
|
m, n, device=device, dtype=torch.bool
|
|
)
|
|
# set masked weight
|
|
model[i].weight = nn.Parameter(model[i].weight * mask)
|
|
|
|
dense_result = model(input)
|
|
|
|
for i in range(2):
|
|
model[i].weight = nn.Parameter(to_sparse_semi_structured(model[i].weight))
|
|
|
|
sparse_result = model(input)
|
|
|
|
assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
|
|
|
|
    @parametrize_backends
    def test_values(self, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A)
        assert A_sparse.values().shape == (128, 64)
        assert (A_sparse.values() == 1).all()

    @parametrize_backends
    def test_indices(self, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A)
        assert A_sparse.indices().shape == (128, 8)

    @inference_dtypes
    @parametrize_backends
    def test_min_sparse_shape(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        config = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend]._DTYPE_SHAPE_CONSTRAINTS[dtype]
        A = rand_sparse_semi_structured_mask(config.sparse_min_rows, config.sparse_min_cols, dtype=dtype, device=device)
        A_sparse = to_sparse_semi_structured(A)
        B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype)
        if dtype == torch.int8:
            dense_res = torch.mm(A.cpu(), B.cpu()).to(device, dtype=torch.int8)
            # int8 sparse matmul not supported for R/R -> R layout, so we transpose one of the arguments to get R/C -> R
            B_t = B.t().contiguous()
            sparse_res = torch.mm(A_sparse, B_t.t())
        else:
            dense_res = torch.mm(A, B)
            sparse_res = torch.mm(A_sparse, B)
        assert torch.allclose(sparse_res, dense_res, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    @parametrize_backends
    def test_unsupported_shape(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(2, 2, dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.shape"):
            A_sparse = to_sparse_semi_structured(A)

    @dtypes(*all_types_and_complex())
    @parametrize_backends
    def test_unsupported_dtype(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype, device=device)

        if dtype not in SEMI_STRUCTURED_SUPPORTED_DTYPES:
            with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dtype"):
                A_sparse = to_sparse_semi_structured(A)
        else:
            A_sparse = to_sparse_semi_structured(A)

    @parametrize_backends
    def test_unsupported_dim(self, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = torch.rand(128, 128, 128, device=device, dtype=torch.float16)

        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dim"):
            A_sparse = to_sparse_semi_structured(A)


def create_random_mask(shape) -> torch.Tensor:
    r = random.Random(0)
    mask = torch.zeros(shape, dtype=torch.bool)
    for line in range(mask.shape[0]):
        for col in range(0, mask.shape[1], 4):
            sparsity = r.choice(
                [
                    [False, False, True, True],
                    [False, True, False, True],
                    [True, False, False, True],
                    [False, True, True, False],
                    [True, False, True, False],
                    [True, True, False, False],
                ]
            )
            mask[line, col : col + 4] = torch.tensor(sparsity, dtype=torch.bool)
    return mask


class TestSparseSemiStructuredTraining(TestCase):

    def setUp(self):
        if not _IS_SM8X:
            self.skipTest('Only runs on SM80')
        if IS_WINDOWS:
            self.skipTest('CUTLASS not supported on windows')

    @training_dtypes
    def test_prune_dense_static_sort(self, dtype) -> None:
        # Ideally we would like to clone and compare, but that won't work because the sorting order will be different
        # instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
        dense = torch.randn(128, 128, device="cuda", dtype=dtype)
        pruned = _sparse_semi_structured_tile(dense)

        # CUTLASS
        reference_cutlass = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(pruned, algorithm="largest_abs_values_greedy")
        assert torch.allclose(pruned, reference_cutlass.to_dense())

        packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
        packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
        meta_cutlass = meta_cutlass.as_strided(reference_cutlass.meta.shape, reference_cutlass.meta.stride())
        meta_t_cutlass = meta_t_cutlass.as_strided(reference_cutlass.meta_t.shape, reference_cutlass.meta_t.stride())
        compressed_swizzled_bitmask = _compute_compressed_swizzled_bitmask(pruned)
        compressed_swizzled_bitmask = compressed_swizzled_bitmask.as_strided(reference_cutlass.compressed_swizzled_bitmask.shape,
                                                                             reference_cutlass.compressed_swizzled_bitmask.stride())
        cutlass = SparseSemiStructuredTensorCUTLASS(dense.shape,
                                                    packed_cutlass,
                                                    meta_cutlass,
                                                    packed_t_cutlass,
                                                    meta_t_cutlass,
                                                    compressed_swizzled_bitmask)
        assert torch.allclose(reference_cutlass.to_dense(), cutlass.to_dense())

        # CUSPARSELT
        reference_cusparselt = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(pruned,
                                                                                            algorithm="largest_abs_values_greedy")
        assert torch.allclose(pruned, reference_cusparselt.to_dense())

        packed_cusparselt = torch._cslt_compress(pruned)
        packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
        cusparselt = SparseSemiStructuredTensorCUSPARSELT(dense.shape,
                                                          packed_cusparselt,
                                                          None,
                                                          packed_t_cusparselt,
                                                          None,
                                                          compressed_swizzled_bitmask)
        assert torch.allclose(reference_cusparselt.to_dense(), cusparselt.to_dense())

    @training_dtypes
    @parametrize_backends
    def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None:
        inp = torch.tensor(
            [[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
            device="cuda",
            dtype=dtype,
        )
        inp = F.pad(inp, (0, 128 - 4, 0, 128 - 4), "constant", 1)
        sInp = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort(inp, algorithm="largest_abs_values_greedy")

        mask = sInp.to_dense() / inp
        assert mask[:4, :4].int().tolist() == [
            [1, 1, 0, 0],
            [0, 1, 1, 0],
            [0, 0, 1, 1],
            [1, 0, 0, 1],
        ]

    @training_dtypes
    def test_gemm(self, dtype) -> None:
        M, N, K = 32, 32, 64
        a = torch.randn([M, K], device="cuda", dtype=dtype)
        b = torch.randn([K, N], device="cuda", dtype=dtype)
        mask = rand_sparse_semi_structured_mask(M, K, dtype=torch.bool)

        a.masked_fill_(~mask, 0)

        a_sparse = to_sparse_semi_structured(a)

        masked_a = a * mask
        ref_out = masked_a @ b
        sp24_out = a_sparse @ b
        assert torch.allclose(ref_out, sp24_out, **atol_rtol_kw[dtype])

    @training_dtypes
    @parametrize_backends
    def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
        M, N = 128, 256
        # Construct x to make sure we always have exactly 8 elements per 4x4 tile
        a = (4 * torch.arange(8))[:, None] + torch.arange(8)[None, :]
        a = a.repeat(M // 8, N // 8)
        assert a.shape == (M, N)
        a = a.cuda().to(dtype)
        b = torch.randn([a.shape[1], 128], device="cuda", dtype=dtype)

        a_sparse = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort(a)

        mask_dense = sparse24_largest_mask_2d(a).to(dtype)

        if backend == "cutlass":
            assert isinstance(a_sparse, SparseSemiStructuredTensorCUTLASS)
            (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(
                mask_dense, use_cutlass=True)

            sparse_mask = SparseSemiStructuredTensorCUTLASS(
                mask_dense.shape,
                packed=packed,
                meta=meta,
                packed_t=packed_t,
                meta_t=meta_t,
                compressed_swizzled_bitmask=bitmask,
            )
            assert torch.allclose(a_sparse.meta.view(torch.short), sparse_mask.meta)

        ref_gemm = (mask_dense * a) @ b
        pack_gemm = a_sparse @ b
        assert torch.allclose(ref_gemm, pack_gemm, **atol_rtol_kw[dtype])

    @training_dtypes
    def test_pack_both_ways_id(self, dtype) -> None:
        N = 512
        torch.manual_seed(0)
        a = torch.randn([N, N], dtype=dtype, device="cuda")
        b = torch.eye(N, dtype=dtype, device="cuda")

        packed, meta, packed_t, meta_t = torch._sparse_semi_structured_tile(a)[
            :4
        ]
        # Heuristic to ensure we pack the same values
        assert torch.allclose(
            packed.to(torch.float64).sum(), packed_t.to(torch.float64).sum()
        )

        mask_dense = sparse24_largest_mask_2d(a.to(dtype))

        ref_gemm = mask_dense * a
        # Test A@B
        pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed, meta).t()
        max_diff = (ref_gemm - pack_gemm).abs().argmax()
        assert torch.allclose(
            ref_gemm, pack_gemm,
            **atol_rtol_kw[dtype]
        ), f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})"
        # Test A.t@B
        pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed_t, meta_t)
        max_diff = (ref_gemm - pack_gemm).abs().argmax()

        assert torch.allclose(
            ref_gemm, pack_gemm,
            **atol_rtol_kw[dtype]
        ), f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})"

    @training_dtypes
    def test_pack_both_ways_edge_case1(self, dtype) -> None:
        # In this case, the heuristic will keep 7 values out of 16
        # instead of 8. let's see how the kernel handles this
        quad = torch.tensor(
            [
                [2, -1, -2, -3],  # Should be packed as `2 <null>`
                [-1, 8, -1, 6],
                [-1, -1, 4, 5],
                [-1, 3, 7, -1],
            ],
            dtype=dtype,
            device="cuda",
        )
        a = torch.randn([32, 64], dtype=dtype, device="cuda")
        a[:4, :4] = quad
        packed, meta, packed_t, meta_t = torch._sparse_semi_structured_tile(a)[:4]
        # Check first line in A
        assert packed[0, 0].item() == 2
        assert packed[0, 1].item() == 0
        # And first column in A.t
        assert packed_t[0, 0].item() == 2
        assert packed_t[0, 1].item() == 0

    @training_dtypes
    def test_sp24_apply(self, dtype) -> None:
        M, N = 256, 1024
        x = torch.randn([M, N], dtype=dtype, device="cuda")
        (
            packed,
            meta,
            packed_t,
            meta_t,
            bitmask,
        ) = torch._sparse_semi_structured_tile(x)
        packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask)
        assert torch.allclose(packed, packed2)
        assert torch.allclose(packed_t, packed_t2)

    @training_dtypes
    def test_sp24_apply_dense(self, dtype) -> None:
        M, N = 256, 1024
        x = torch.randn([M, N], dtype=dtype, device="cuda")
        (
            packed,
            meta,
            packed_t,
            meta_t,
            bitmask,
        ) = torch._sparse_semi_structured_tile(x)

        expected = SparseSemiStructuredTensorCUTLASS(
            x.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
        ).to_dense()

        packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask)
        sparse = SparseSemiStructuredTensorCUTLASS(
            x.shape,
            packed=packed2,
            meta=meta,
            packed_t=packed_t2,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
        )

        dense = torch._sparse_semi_structured_apply_dense(x, bitmask)

        assert torch.allclose(dense, expected)
        assert torch.allclose(sparse.to_dense(), expected)

    @training_dtypes
    def test_sp24_matmuls(self, dtype) -> None:
        M, N, K = 64, 256, 1024
        a = torch.randn([M, K], device="cuda", dtype=dtype)
        b = torch.randn([K, N], device="cuda", dtype=dtype)
        a_m = sparse24_largest_mask_2d(a)
        b_m = sparse24_largest_mask_2d(b)
        (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(a)
        a_s = SparseSemiStructuredTensorCUTLASS(
            a.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
        )
        (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(b)
        b_s = SparseSemiStructuredTensorCUTLASS(
            b.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
        )

        assert torch.allclose(a_s @ b, (a * a_m) @ b, rtol=1e-1, atol=1e-1)
        assert torch.allclose(a @ b_s, a @ (b * b_m), rtol=1e-1, atol=1e-1)
        assert torch.allclose(
            a @ a_s.t(), a @ (a * a_m).t(), rtol=1e-1, atol=1e-1
        )
        assert torch.allclose(
            a_s.t() @ a, (a * a_m).t() @ a, rtol=1e-1, atol=1e-1
        )

    def test_sp24_matmuls_mat_vec(self) -> None:
        a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
        b = torch.randn([128], device="cuda", dtype=torch.float16)
        a_m = sparse24_largest_mask_2d(a)
        a_s = to_sparse_semi_structured(a)

        with pytest.raises(NotImplementedError):
            assert torch.allclose(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])

    def test_sp24_matmuls_bmm(self) -> None:
        a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
        b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)
        a_m = sparse24_largest_mask_2d(a)
        a_s = to_sparse_semi_structured(a)

        with pytest.raises(NotImplementedError):
            assert torch.allclose(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])


class TestSparseSemiStructuredCUTLASS(TestCase):
    """
    This contains CUTLASS specific tests for
    - torch._sparse_semi_structured_linear
    """
    def setUp(self):
        if not _IS_SM8X:
            self.skipTest('Only runs on SM80')
        if "cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
            self.skipTest('CUTLASS not enabled')

    @unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows don't support CUTLASS")
    @inference_dtypes
    def test_linear_cutlass(self, device, dtype):

        def run_test(batch_shape, m, n, k, device, dtype, dtype_out, add_bias, activation, rtol, atol):
            weight = rand_sparse_semi_structured(m, k, dtype, device)
            input = make_tensor((*batch_shape, n, k), dtype=dtype, device=device)
            bias = make_tensor((m,), dtype=dtype_out, device=device) if add_bias else None

            dtype_dense = torch.float32
            input_dense = input.to(dtype_dense)
            weight_dense = weight.to(dtype_dense)
            bias_dense = bias.to(dtype_dense) if add_bias else None
            output0 = torch.nn.functional.linear(input_dense, weight_dense, bias=bias_dense)
            if activation == "relu":
                relu = torch.nn.ReLU()
                output0 = relu(output0)
            elif activation == "silu":
                silu = torch.nn.SiLU()
                output0 = silu(output0)

            compressed = to_sparse_semi_structured(weight)

            weight_sparse = compressed.values()
            meta = compressed.indices()

            output1 = torch._sparse_semi_structured_linear(input, weight_sparse, meta, bias=bias, activation=activation,
                                                           out_dtype=dtype_out if dtype == torch.int8 else None)
            torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)

        if dtype == torch.float32:
            # Inputs are converted to TF32 internally for sparse GEMM,
            # so make the dense GEMM do the same for matching results.
            orig = torch.backends.cuda.matmul.allow_tf32
            torch.backends.cuda.matmul.allow_tf32 = True

        batch_shapes = [[], [3], [3, 1]]
        dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32}
        activations = [None, "relu", "silu"]
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 5e-3, 5e-3
        elif dtype == torch.float32:
            rtol, atol = 1e-3, 75e-2
        for batch_shape, m, n, k, add_bias, activation in \
                itertools.product(batch_shapes, range(3), range(3), range(3), (False, True), activations):
            if activation == "silu" and dtype == torch.int8:
                continue  # SiLU not supported for integer inputs

            m = 2 ** m * 32
            n = 2 ** n * 32
            k = 2 ** k * 128
            run_test(batch_shape, m, n, k, device, dtype, dtype_out[dtype], add_bias, activation, rtol, atol)

        if dtype == torch.float32:
            torch.backends.cuda.matmul.allow_tf32 = orig

    @unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows don't support CUTLASS")
    @parametrize("backend", ["cutlass"])
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_sparse_semi_structured_ops_cutlass(self, device, dtype, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")

        def run_test(m, n, k, device, dtype, dtype_out, use_input, rtol, atol):
            mat1 = rand_sparse_semi_structured(m, k, dtype, device)
            # mat2 transposed as int8 case supports only row-major/column-major combination
            mat2 = make_tensor((n, k), dtype=dtype, device=device).t()
            input = make_tensor((m,), dtype=dtype_out, device=device) if use_input else None

            if use_input:
                if dtype.is_floating_point:
                    alpha = 1.3
                    beta = -0.7
                else:
                    alpha = 2
                    beta = -3

            dtype_dense = torch.float32
            mat1_dense = mat1.to(dtype_dense)
            mat2_dense = mat2.to(dtype_dense)
            if not use_input:
                output0 = torch.mm(mat1_dense, mat2_dense)
            else:
                input_dense = input.to(dtype_dense)[:, None]
                output0 = torch.addmm(input_dense, mat1_dense, mat2_dense, alpha=alpha, beta=beta)

            compressed = to_sparse_semi_structured(mat1)

            mat1_sparse = compressed.values()
            mat1_meta = compressed.indices()

            if not use_input:
                output1 = torch._sparse_semi_structured_mm(mat1_sparse, mat1_meta, mat2, out_dtype=dtype_out)
            else:
                output1 = torch._sparse_semi_structured_addmm(
                    input, mat1_sparse, mat1_meta, mat2, alpha=alpha, beta=beta, out_dtype=dtype_out
                )
            torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)

        if dtype == torch.float32:
            # Inputs are converted to TF32 internally for sparse GEMM,
            # so make the dense GEMM do the same for matching results.
            orig = torch.backends.cuda.matmul.allow_tf32
            torch.backends.cuda.matmul.allow_tf32 = True

        dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32}
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 5e-3, 5e-3
        elif dtype == torch.float32:
            rtol, atol = 1e-3, 75e-2
        for m, n, k, use_input in \
                itertools.product(range(3), range(3), range(3), (False, True)):
            m = 2 ** m * 32
            n = 2 ** n * 32
            k = 2 ** k * 128
            run_test(m, n, k, device, dtype, dtype_out[dtype], use_input, rtol, atol)

        if dtype == torch.float32:
            torch.backends.cuda.matmul.allow_tf32 = orig

    @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
    @inference_dtypes
    def test_conversions(self, device, dtype):

        def run_test(r, c, device, dtype):
            dense_ref = rand_sparse_semi_structured(r, c, dtype, device)

            compressed = to_sparse_semi_structured(dense_ref)

            # The torch.ops.aten._to_sparse_semi_structured operator
            # uses CUTLASS to perform conversion from given dense
            # matrix to the pair of corresponding sparse and metadata
            # matrices, with the latter used here as a reference to
            # compare the metadata matrix produced by conversion
            # performed by SparseSemiStructuredTensor class
            # constructor against.
            _, meta_ref = torch.ops.aten._to_sparse_semi_structured(dense_ref)

            meta = compressed.indices()
            torch.testing.assert_close(meta, meta_ref, rtol=0, atol=0)

            dense = compressed.to_dense()
            torch.testing.assert_close(dense, dense_ref, rtol=0, atol=0)

        shapes = [[32, 128], [32, 256], [64, 128], [64, 256]]
        for r, c in shapes:
            run_test(r, c, device, dtype)

    @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
    @inference_dtypes
    def test_conversions_all_patterns(self, device, dtype):
        r, c = 32, 128

        dense_inv, dense_val = rand_sparse_semi_structured_all_patterns(r, c, dtype, device)

        compressed = to_sparse_semi_structured(dense_inv)
        dense = compressed.to_dense()

        torch.testing.assert_close(dense, dense_val, rtol=0, atol=0)


CUSPARSELT_NUM_ALG_IDS = 4
CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32]


class TestSparseSemiStructuredCUSPARSELT(TestCase):
    """
    This contains cuSPARSELt specific tests for
        torch._cslt_compress
        torch._cslt_sparse_mm
    """
    def setUp(self):
        if not _IS_SM8X:
            self.skipTest('Only runs on SM80')
        if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
            self.skipTest('cuSPARSELt not enabled')

    @parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
    @parametrize("dense_input_shape", [(128, 128)])
    def test_cslt_sparse_mm_mixed_dtype(self, dense_input_shape, out_dtype, device):
        A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8)
        A_compressed = torch._cslt_compress(A)

        B = torch.rand(dense_input_shape, device=device).to(torch.int8)

        dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=out_dtype)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=out_dtype)
        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @training_dtypes
    def test_cslt_sparse_mm_alpha(self, dtype, device):
        A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(dtype).cuda()
        B = torch.ones((256, 128), device=device).to(dtype)
        alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda()

        A_compressed = torch._cslt_compress(A)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha)

        alpha_scaled = torch.stack([alpha] * 128).t()
        dense_result = alpha_scaled * torch.mm(A.to(torch.float32), B.to(torch.float32))
        dense_result = dense_result.to(dtype)

        assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    @parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT)
    def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
        A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
        B = torch.ones((128, 256), device=device).to(torch.int8).t()
        alpha = torch.Tensor([2**(-i) if out_dtype is not torch.int32 else 1
                              for i in range(128)]).cuda()

        A_compressed = torch._cslt_compress(A)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=out_dtype).cpu()

        alpha_scaled = torch.stack([alpha] * 128).t()
        dense_result = alpha_scaled.cpu() * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu())
        dense_result = dense_result.to(out_dtype)

        assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

@parametrize("alg_id", range(CUSPARSELT_NUM_ALG_IDS))
|
|
@inference_dtypes
|
|
def test_cslt_sparse_mm_alg_id(self, device, dtype, alg_id):
|
|
# alg_id=3 not supported for float32 dtype
|
|
if dtype == torch.float32 and alg_id == 3:
|
|
return
|
|
A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
|
|
A_compressed = torch._cslt_compress(A)
|
|
B = torch.ones((128, 128), device=device).to(dtype)
|
|
|
|
A_compressed = torch._cslt_compress(A)
|
|
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)
|
|
|
|
dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
|
|
dense_result = dense_result.to(dtype)
|
|
|
|
assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
|
|
|
|
@inference_dtypes
|
|
def test_cslt_sparse_mm_search(self, device, dtype):
|
|
A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
|
|
A_compressed = torch._cslt_compress(A)
|
|
B = torch.ones((128, 128), device=device).to(dtype)
|
|
|
|
A_compressed = torch._cslt_compress(A)
|
|
alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
|
|
# for cuSPARSELt v0.4.0 there is a bug where although there are 5 alg_ids, we run into an error
|
|
# when setting using the last one (4)
|
|
# in cuSPARSELt v0.5.0 there are only 4 alg_ids total, so we should remove the +1 here when we update.
|
|
assert alg_id in range(CUSPARSELT_NUM_ALG_IDS + 1)
|
|
|
|
instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda")
|
|
instantiate_device_type_tests(TestSparseSemiStructuredCUTLASS, globals(), only_for="cuda")
|
|
instantiate_device_type_tests(TestSparseSemiStructuredCUSPARSELT, globals(), only_for="cuda")
|
|
instantiate_device_type_tests(TestSparseSemiStructuredTraining, globals(), only_for="cuda")
|
|
|
|
if __name__ == "__main__":
|
|
run_tests()
|