**TL;DR:**

- This pull request introduces support for hipSPARSELt in ROCm; the current use case is semi-structured sparsity.
- Requires **ROCm 6.4** and **gfx942/gfx950**.
- The average performance uplift (compared to the dense operation) is ~20% on ROCm 6.4, with further gains expected in upcoming releases.

### Dense vs. Sparse Performance Comparison

#### **NT (Row-major)**

**Average Uplift**: `1.20`

| M | N | K | hipsparselt-bench (us) | hipblaslt-bench get all (us) | Uplift |
|-------|--------|--------|-------------------------|-------------------------------|--------|
| 14336 | 8 | 4096 | 20.05 | 25.3 | 1.26 |
| 4096 | 8 | 14336 | 21.07 | 25.28 | 1.20 |
| 3072 | 3072 | 10240 | 299.05 | 351.82 | 1.18 |
| 3072 | 1536 | 768 | 18.56 | 20.05 | 1.08 |
| 3072 | 17664 | 768 | 163.13 | 173.91 | 1.07 |
| 3072 | 196608 | 768 | 1717.30 | 1949.63 | 1.14 |
| 3072 | 24576 | 768 | 206.84 | 242.98 | 1.17 |
| 3072 | 6144 | 768 | 53.90 | 56.88 | 1.06 |
| 3072 | 98304 | 768 | 833.77 | 962.28 | 1.15 |
| 768 | 1536 | 768 | 8.53 | 19.65 | 2.30 |
| 768 | 17664 | 768 | 46.02 | 46.84 | 1.02 |
| 768 | 196608 | 768 | 463.15 | 540.46 | 1.17 |
| 768 | 24576 | 768 | 54.32 | 59.55 | 1.10 |
| 768 | 6144 | 768 | 19.47 | 20.15 | 1.03 |
| 768 | 98304 | 768 | 231.88 | 258.73 | 1.12 |

---

#### **NN (Row-major)**

**Average Uplift**: `1.13`

| M | N | K | hipsparselt-bench (us) | hipblaslt-bench get all (us) | Uplift |
|-----|--------|-------|-------------------------|-------------------------------|--------|
| 768 | 1536 | 3072 | 27.50 | 28.78 | 1.05 |
| 768 | 17664 | 3072 | 125.06 | 158.94 | 1.27 |
| 768 | 196608 | 3072 | 1568.38 | 1767.12 | 1.13 |
| 768 | 24576 | 3072 | 171.05 | 203.49 | 1.19 |
| 768 | 6144 | 3072 | 58.72 | 60.39 | 1.03 |
| 768 | 98304 | 3072 | 787.15 | 887.60 | 1.13 |

---

This pull request introduces support for hipSPARSELt in ROCm, alongside various updates and improvements to the codebase and test suite. The changes primarily involve adding configuration flags, updating conditional checks, and ensuring compatibility with hipSPARSELt.

### ROCm and hipSPARSELt Support:

* [`BUILD.bazel`](diffhunk://#diff-7fc57714ef13c3325ce2a1130202edced92fcccc0c6db34a72f7b57f60d552a3R292): Added the `@AT_HIPSPARSELT_ENABLED@` substitution to enable hipSPARSELt support.
* [`aten/CMakeLists.txt`](diffhunk://#diff-0604597797bb21d7c39150f9429d6b2ace10b79ab308514ad03f76153ae8249bR104-R110): Introduced a conditional flag to enable hipSPARSELt support based on the ROCm version.
* [`aten/src/ATen/CMakeLists.txt`](diffhunk://#diff-ce80f3115ab2f6be5142f0678a1fc92c6b2d7727766ce44f48726c99e720f777R37): Added the `AT_HIPSPARSELT_ENABLED` configuration.
* [`aten/src/ATen/cuda/CUDAConfig.h.in`](diffhunk://#diff-8bb82da825ca87c28233abacffa1b0566c73a54990b7a77f3f5108d3718fea15R11): Defined the `AT_HIPSPARSELT_ENABLED` macro.
* `caffe2/CMakeLists.txt`, `cmake/Dependencies.cmake`, `cmake/public/LoadHIP.cmake`: Included hipSPARSELt in the ROCm dependencies. [[1]](diffhunk://#diff-c5ee05f1e918772792ff6f2a3f579fc2f182e57b1709fd786ef6dc711fd68b27R1380) [[2]](diffhunk://#diff-12e8125164bbfc7556b1781a8ed516e333cc0bf058acb7197f7415be44606c72L1084-R1084) [[3]](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5R153)

### Codebase Updates:

* [`aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp`](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R1-R6): Added hipSPARSELt support checks and initialization functions, and updated various methods to conditionally handle hipSPARSELt. [[1]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R1-R6) [[2]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R22-R67) [[3]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R78-R85) [[4]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R97-R109) [[5]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R183-R188) [[6]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3L134-R200) [[7]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R213-R222) [[8]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3L217-R285)

### Test Suite Updates:

* [`test/test_sparse_semi_structured.py`](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR50-R65): Added checks for hipSPARSELt availability and updated test conditions to skip tests not supported on ROCm. [[1]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR50-R65) [[2]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR228) [[3]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR239) [[4]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR250) [[5]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR579) [[6]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR624) [[7]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR661) [[8]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR695) [[9]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR730) [[10]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR755) [[11]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR771) [[12]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR809) [[13]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR844) [[14]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cL840-R854) [[15]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR1005)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150578
Approved by: https://github.com/jeffdaily
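As context for the test file below, here is a minimal sketch of how 2:4 semi-structured sparsity is exercised through the public `torch.sparse` API that this suite covers. It assumes a CUDA or ROCm device with a supported backend (cuSPARSELt/hipSPARSELt or CUTLASS) and is illustrative only, not code from this PR.

```python
import torch
from torch.sparse import to_sparse_semi_structured

# A 128x128 fp16 weight whose rows follow the 2:4 pattern
# (two nonzeros in every group of four elements).
weight = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
x = torch.rand(64, 128).half().cuda()

dense_out = torch.mm(x, weight.t())

# Compress the weight; subsequent matmuls against it dispatch to the
# semi-structured sparse backend (cuSPARSELt on NVIDIA GPUs,
# hipSPARSELt on supported ROCm builds).
weight_sparse = to_sparse_semi_structured(weight)
sparse_out = torch.mm(x, weight_sparse.t())

torch.testing.assert_close(dense_out, sparse_out, rtol=1e-3, atol=1e-3)
```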
# Owner(s): ["module: sparse"]
# ruff: noqa: F841
import itertools
import random
import unittest

import torch
from torch import nn
import torch.nn.functional as F

from torch.sparse import (
    SparseSemiStructuredTensor,
    SparseSemiStructuredTensorCUSPARSELT,
    SparseSemiStructuredTensorCUTLASS,
    to_sparse_semi_structured,
)

from torch.sparse._semi_structured_conversions import (
    sparse_semi_structured_from_dense_cutlass,
    _sparse_semi_structured_tile,
    _compute_compressed_swizzled_bitmask,
)

from torch.testing import make_tensor
from torch.testing._internal.common_cuda import _get_torch_cuda_version, PLATFORM_SUPPORTS_FP8, xfailIfSM89
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)

from torch.testing._internal.common_dtype import all_types_and_complex
import torch._dynamo.test_case
from torch.testing._internal.common_utils import (
    parametrize,
    run_tests,
    subtest,
    TestCase,
    TEST_WITH_ROCM,
    IS_WINDOWS,
)

from torch.testing._internal.inductor_utils import HAS_GPU

import pytest

SEMI_STRUCTURED_SUPPORTED_BACKENDS = dict()

_IS_SM8X = False
_IS_SM9X = False
_IS_HIPSPARSELT_AVAILABLE = False

if torch.cuda.is_available():
    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
    _IS_SM9X = torch.cuda.get_device_capability(0)[0] == 9
    _IS_HIPSPARSELT_AVAILABLE = torch.version.hip is not None and tuple(int(v) for v in torch.version.hip.split('.')[:2]) > (6, 4)
    # CUTLASS kernels only work for Ampere
    if _IS_SM8X:
        SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS

    # add cuSPARSELt tests if available
    if torch.backends.cusparselt.is_available() and (_IS_SM8X or _IS_SM9X or _IS_HIPSPARSELT_AVAILABLE):
        SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT

inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.int8)
training_dtypes = dtypes(torch.float16, torch.bfloat16)
parametrize_backends = parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)

atol_rtol_kw = {
    torch.float16: {
        "rtol": 1e-3,
        "atol": 1e-3,
    },
    torch.bfloat16: {
        "rtol": 1e-1,
        "atol": 1e-1,
    },
}


def sparse24_largest_mask_2d(original):
    sparse = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(original)
    return sparse.to_dense().bool()


def sparsify24_dense(original):
    return sparse24_largest_mask_2d(original) * original


def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """
    This function returns a 1:2 sparse matrix of size (r, c).
    Note that this means this matrix will also be 2:4 and 4:8 sparse as well.
    """

    choices = [[0, 1], [1, 0]]
    mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]

    return (
        torch.tensor(mask_entries, dtype=dtype, device=device)
        .reshape(r, c)
        .contiguous()
    )


def rand_sparse_semi_structured(r, c, dtype, device, choice=None):
    pattern = '2by4' if dtype != torch.float32 else '1by2'
    if pattern == '1by2':
        ksparse = 2
        choices = [
            [0, 1],
            [1, 0]
        ]
    elif pattern == '2by4':
        ksparse = 4
        choices = [
            [1, 1, 0, 0],
            [1, 0, 1, 0],
            [1, 0, 0, 1],
            [0, 1, 1, 0],
            [0, 1, 0, 1],
            [0, 0, 1, 1]
        ]
    mask_entries = [choice or random.choice(choices) for i in range(r * c // ksparse)]
    mask = torch.tensor(mask_entries, dtype=torch.bool).view(r, c).to(device)
    dense = make_tensor(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1  # To prevent zeros except where mask applied.
    dense = dense.masked_fill(~mask, 0)
    return dense
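
# Illustrative example (not used by the tests): for an fp16 row of eight
# values, a valid 2:4 pruning keeps exactly two nonzeros per group of four,
#     dense  = [a, b, c, d, e, f, g, h]
#     pruned = [a, 0, c, 0, 0, f, 0, h]
# which is what the 0/1 `choices` masks above encode.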


def rand_sparse_semi_structured_all_patterns(r, c, dtype, device):
    pattern = '2by4' if dtype != torch.float32 else '1by2'
    if pattern == '1by2':
        ksparse = 2
        choices = [
            [[0, 0], [0, 1]],
            [[0, 1], [0, 1]],
            [[1, 0], [1, 0]],
            [[1, 1], [1, 0]]
        ]
    elif pattern == '2by4':
        ksparse = 4
        choices = [
            [[0, 0, 0, 0], [0, 0, 1, 1]],
            [[0, 0, 0, 1], [0, 0, 1, 1]],
            [[0, 0, 1, 0], [0, 0, 1, 1]],
            [[0, 0, 1, 1], [0, 0, 1, 1]],
            [[0, 1, 0, 0], [0, 1, 1, 0]],
            [[0, 1, 0, 1], [0, 1, 0, 1]],
            [[0, 1, 1, 0], [0, 1, 1, 0]],
            [[0, 1, 1, 1], [0, 1, 0, 1]],
            [[1, 0, 0, 0], [1, 0, 1, 0]],
            [[1, 0, 0, 1], [1, 0, 0, 1]],
            [[1, 0, 1, 0], [1, 0, 1, 0]],
            [[1, 0, 1, 1], [1, 0, 0, 1]],
            [[1, 1, 0, 0], [1, 1, 0, 0]],
            [[1, 1, 0, 1], [1, 1, 0, 0]],
            [[1, 1, 1, 0], [1, 1, 0, 0]],
            [[1, 1, 1, 1], [1, 1, 0, 0]],
        ]
    mask_rows = [random.randint(0, len(choices) - 1) for i in range(r * c // ksparse)]

    COL_INV, COL_VAL = 0, 1
    mask_entries_inv = [choices[i][COL_INV] for i in mask_rows]
    mask_entries_val = [choices[i][COL_VAL] for i in mask_rows]
    mask_inv = torch.tensor(mask_entries_inv, dtype=torch.bool).view(r, c).to(device)
    mask_val = torch.tensor(mask_entries_val, dtype=torch.bool).view(r, c).to(device)
    dense = make_tensor(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1  # To prevent zeros except where mask below applied.
    dense_inv = dense.masked_fill(~mask_inv, 0)
    dense_val = dense_inv.masked_fill(~mask_val, 0)

    return dense_inv, dense_val


class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):

    def setUp(self):
        if len(SEMI_STRUCTURED_SUPPORTED_BACKENDS) == 0:
            self.skipTest('semi-structured sparsity has no available backend!')
        super().setUp()

    def tearDown(self):
        super().tearDown()

    @staticmethod
    def _test_mlp_contiguous_relu_compile(backend, dense_input_shape):
        """
        Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
        We expect:
            (1) The sparse tensor subclass should turn nn.Linear into `aten._structured_sparse_addmm` + `aten.contiguous()`
            (2) Inductor should fuse the .contiguous() call into the relu
        """

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(128, 128)

            def forward(self, x):
                x = self.linear(x)
                x = x.contiguous()
                return torch.nn.functional.relu(x)

        input = torch.rand(dense_input_shape, device="cuda").half()
        model = Model().eval().cuda().half()
        mod_linear = model.linear
        m, n = mod_linear.weight.shape
        mask = torch.Tensor([1, 0, 0, 1]).tile((m, n // 4)).bool().cuda()
        # set masked weight
        mod_linear.weight = nn.Parameter(mod_linear.weight * mask)

        dense_result = model(input)
        mod_linear.weight = nn.Parameter(SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].from_dense(mod_linear.weight))
        sparse_result = model(input)

        model = torch.compile(model, backend="inductor", fullgraph=True)
        sparse_compile_result = model(input)

        # test that sparse_compile_result and dense_result are numerically close
        torch.testing.assert_close(dense_result, sparse_compile_result, rtol=1e-3, atol=1e-3)
        # assert sparse and sparse_compile have the same strides,
        # as meta registrations may return contiguous tensors when the output is transposed
        # https://github.com/pytorch/pytorch/pull/114477
        assert sparse_result.stride() == sparse_compile_result.stride()

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_mlp_contiguous_relu_compile_cusparselt(self):
        """
        test for cuSPARSELt meta registrations (_cslt_sparse_mm) + torch.compile
        """
        for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
            SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cusparselt", dense_input_shape)

    @unittest.skipIf("cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cutlass not supported on this machine")
    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_mlp_contiguous_relu_compile_cutlass(self):
        """
        test for CUTLASS meta registrations (_sparse_semi_structured_addmm) + torch.compile
        """
        for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
            SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cutlass", dense_input_shape)

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_sp24_compile(self) -> None:
        x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)

        def fn(x):
            y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x)
            y = y.t()
            return x @ y

        # Eager
        output = fn(x)
        output.backward(output)
        # Torch compile
        output = torch.compile(fn)(x)
        output.backward(output)


class TestSparseSemiStructured(TestCase):

    def setUp(self):
        if len(SEMI_STRUCTURED_SUPPORTED_BACKENDS) == 0:
            self.skipTest('semi-structured sparsity has no available backend!')
        if IS_WINDOWS:
            self.skipTest("torch.compile not supported on windows")

    @inference_dtypes
    @parametrize_backends
    def test_to_sparse_semi_structured(self, dtype, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        assert A.shape == A_sparse.shape
        assert A.device == A_sparse.device
        assert A.dtype == A_sparse.dtype

        assert isinstance(A, torch.Tensor)
        assert isinstance(A_sparse, SparseSemiStructuredTensor)

    @inference_dtypes
    @parametrize_backends
    @parametrize("dense_input_shape", [(128, 1), (128, 64), (128, 128)])
    def test_mm_sparse_first_NN(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A_sparse, B) is correct for float16 and will throw error for int8
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
        if dtype is torch.int8:
            if backend == "cutlass":
                with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"):
                    sparse_result = torch.mm(A_sparse, B)
            else:
                with self.assertRaisesRegex(RuntimeError,
                                            "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
                    sparse_result = torch.mm(A_sparse, B)
        else:
            dense_result = torch.mm(A, B)
            sparse_result = torch.mm(A_sparse, B)
            torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    @parametrize_backends
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    def test_mm_sparse_first_NT(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A_sparse, B.t()) is correct for float16/bfloat16
        and will throw an error for int8 + padding
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
        if dtype is torch.int8 and dense_input_shape in {(1, 128)}:
            # padding with int8 throws an error because transposing B yields a contiguous output
            # and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS.
            if backend == "cutlass":
                with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"):
                    sparse_result = torch.mm(A_sparse, B.t())
            else:
                with self.assertRaisesRegex(RuntimeError,
                                            "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
                    sparse_result = torch.mm(A_sparse, B.t())
        elif dtype is torch.int8:
            # test transpose
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
            sparse_result = torch.mm(A_sparse, B.t())
            torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
        else:
            # test transpose
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A_sparse, B.t())
            torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize_backends
    def test_mm_sparse_first_TN(self, dtype, dense_input_shape, device, backend):
        """
        Ensure torch.mm(A_sparse.t(), B) throws error
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"`SparseSemiStructuredTensor.*` matmul: operation is not supported",
        ):
            torch.mm(A_sparse.t(), B)

    @inference_dtypes
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize_backends
    def test_mm_sparse_second_NT(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A, B_sparse.t()) is correct
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B)

        A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
        if dtype is torch.int8:
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
            sparse_result = torch.mm(A, B_sparse.t())
        else:
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A, B_sparse.t())

        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize_backends
    def test_mm_sparse_second_NN(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A, B_sparse) throws error
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B)

        A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"`SparseSemiStructuredTensor.*` matmul: operation is not supported",
        ):
            sparse_result = torch.mm(A, B_sparse)

    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
    @parametrize("inference_mode", [subtest(True), subtest(False)])
    @parametrize_backends
    def test_linear(self, dense_input_shape, inference_mode, device, backend):
        """
        Test nn.Linear has the same numerics
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        input = torch.rand((dense_input_shape), device=device).half()
        model = nn.Linear(128, 256).to(device).half()
        m, n = model.weight.shape
        mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool)
        # set masked weight
        model.weight = nn.Parameter(model.weight * mask)

        dense_result = model(input)

        model.weight = nn.Parameter(to_sparse_semi_structured(model.weight))

        if inference_mode:
            with torch.inference_mode():
                sparse_result = model(input)
        else:
            sparse_result = model(input)

        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
    @parametrize_backends
    def test_mlp(self, device, dense_input_shape, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        input = torch.rand(dense_input_shape, device=device).half()
        model = (
            nn.Sequential(
                nn.Linear(128, 256),
                nn.Linear(256, 128),
            )
            .half()
            .to(device)
        )

        for i in range(2):
            m, n = model[i].weight.shape
            mask = rand_sparse_semi_structured_mask(
                m, n, device=device, dtype=torch.bool
            )
            # set masked weight
            model[i].weight = nn.Parameter(model[i].weight * mask)

        dense_result = model(input)

        for i in range(2):
            model[i].weight = nn.Parameter(to_sparse_semi_structured(model[i].weight))

        sparse_result = model(input)

        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @parametrize_backends
    def test_values(self, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A)
        assert A_sparse.values().shape == (128, 64)
        assert (A_sparse.values() == 1).all()

    @parametrize_backends
    def test_indices(self, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A)
        assert A_sparse.indices().shape == (128, 8)

    @inference_dtypes
    @parametrize_backends
    def test_min_sparse_shape(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        config = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend]._DTYPE_SHAPE_CONSTRAINTS[dtype]
        A = rand_sparse_semi_structured_mask(config.sparse_min_rows, config.sparse_min_cols, dtype=dtype, device=device)
        A_sparse = to_sparse_semi_structured(A)
        B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype)
        if dtype == torch.int8:
            dense_res = torch.mm(A.cpu(), B.cpu()).to(device, dtype=torch.int8)
            # int8 sparse matmul not supported for R/R -> R layout, so we transpose one of the arguments to get R/C -> R
            B_t = B.t().contiguous()
            sparse_res = torch.mm(A_sparse, B_t.t())
        else:
            dense_res = torch.mm(A, B)
            sparse_res = torch.mm(A_sparse, B)
        torch.testing.assert_close(sparse_res, dense_res, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    @parametrize_backends
    def test_unsupported_shape(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(2, 2, dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.shape"):
            A_sparse = to_sparse_semi_structured(A)

    @dtypes(*all_types_and_complex())
    @parametrize_backends
    def test_unsupported_dtype(self, dtype, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype, device=device)

        if dtype not in SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend]._DTYPE_SHAPE_CONSTRAINTS:
            with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dtype"):
                A_sparse = to_sparse_semi_structured(A)
        else:
            A_sparse = to_sparse_semi_structured(A)

    @parametrize_backends
    def test_unsupported_dim(self, device, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")
        A = torch.rand(128, 128, 128, device=device, dtype=torch.float16)

        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dim"):
            A_sparse = to_sparse_semi_structured(A)


def create_random_mask(shape) -> torch.Tensor:
    r = random.Random(0)
    mask = torch.zeros(shape, dtype=torch.bool)
    for line in range(mask.shape[0]):
        for col in range(0, mask.shape[1], 4):
            sparsity = r.choice(
                [
                    [False, False, True, True],
                    [False, True, False, True],
                    [True, False, False, True],
                    [False, True, True, False],
                    [True, False, True, False],
                    [True, True, False, False],
                ]
            )
            mask[line, col : col + 4] = torch.tensor(sparsity, dtype=torch.bool)
    return mask


class TestSparseSemiStructuredTraining(TestCase):

    def setUp(self):
        if not _IS_SM8X:
            self.skipTest("SparseSemiStructuredTensor training only supported on SM8x (Ampere)")

        if IS_WINDOWS:
            self.skipTest('CUTLASS not supported on windows')

    @training_dtypes
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_prune_dense_static_sort(self, dtype) -> None:
        # Ideally we would like to clone and compare, but that won't work because the sorting order will be different
        # instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
        dense = torch.randn(128, 128, device="cuda", dtype=dtype)
        pruned = _sparse_semi_structured_tile(dense)

        # CUTLASS
        reference_cutlass = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(pruned, algorithm="largest_abs_values_greedy")
        torch.testing.assert_close(pruned, reference_cutlass.to_dense())

        packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
        packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
        meta_cutlass = meta_cutlass.as_strided(reference_cutlass.meta.shape, reference_cutlass.meta.stride())
        meta_t_cutlass = meta_t_cutlass.as_strided(reference_cutlass.meta_t.shape, reference_cutlass.meta_t.stride())
        compressed_swizzled_bitmask = _compute_compressed_swizzled_bitmask(pruned)
        compressed_swizzled_bitmask = compressed_swizzled_bitmask.as_strided(reference_cutlass.compressed_swizzled_bitmask.shape,
                                                                             reference_cutlass.compressed_swizzled_bitmask.stride())
        cutlass = SparseSemiStructuredTensorCUTLASS(dense.shape,
                                                    packed_cutlass,
                                                    meta_cutlass,
                                                    packed_t_cutlass,
                                                    meta_t_cutlass,
                                                    compressed_swizzled_bitmask)
        torch.testing.assert_close(reference_cutlass.to_dense(), cutlass.to_dense())

        # CUSPARSELT
        reference_cusparselt = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(pruned,
                                                                                            algorithm="largest_abs_values_greedy")
        torch.testing.assert_close(pruned, reference_cusparselt.to_dense())

        packed_cusparselt = torch._cslt_compress(pruned)
        packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
        cusparselt = SparseSemiStructuredTensorCUSPARSELT(dense.shape,
                                                          packed_cusparselt,
                                                          None,
                                                          packed_t_cusparselt,
                                                          None,
                                                          compressed_swizzled_bitmask)
        torch.testing.assert_close(reference_cusparselt.to_dense(), cusparselt.to_dense())

    @training_dtypes
    @parametrize_backends
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None:
        inp = torch.tensor(
            [[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
            device="cuda",
            dtype=dtype,
        )
        inp = F.pad(inp, (0, 128 - 4, 0, 128 - 4), "constant", 1)
        sInp = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort(inp, algorithm="largest_abs_values_greedy")

        mask = sInp.to_dense() / inp
        assert mask[:4, :4].int().tolist() == [
            [1, 1, 0, 0],
            [0, 1, 1, 0],
            [0, 0, 1, 1],
            [1, 0, 0, 1],
        ]

    @training_dtypes
    def test_gemm(self, dtype) -> None:
        M, N, K = 32, 32, 64
        a = torch.randn([M, K], device="cuda", dtype=dtype)
        b = torch.randn([K, N], device="cuda", dtype=dtype)
        mask = rand_sparse_semi_structured_mask(M, K, dtype=torch.bool)

        a.masked_fill_(~mask, 0)

        a_sparse = to_sparse_semi_structured(a)

        masked_a = a * mask
        ref_out = masked_a @ b
        sp24_out = a_sparse @ b
        torch.testing.assert_close(ref_out, sp24_out, **atol_rtol_kw[dtype])

    @training_dtypes
    @parametrize_backends
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
        M, N = 128, 256
        # Construct x to make sure we always have exactly 8 elements per 4x4 tile
        a = (4 * torch.arange(8))[:, None] + torch.arange(8)[None, :]
        a = a.repeat(M // 8, N // 8)
        assert a.shape == (M, N)
        a = a.cuda().to(dtype)
        b = torch.randn([a.shape[1], 128], device="cuda", dtype=dtype)

        a_sparse = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort(a)

        mask_dense = sparse24_largest_mask_2d(a).to(dtype)

        if backend == "cutlass":
            assert isinstance(a_sparse, SparseSemiStructuredTensorCUTLASS)
            (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(
                mask_dense, use_cutlass=True)

            sparse_mask = SparseSemiStructuredTensorCUTLASS(
                mask_dense.shape,
                packed=packed,
                meta=meta,
                packed_t=packed_t,
                meta_t=meta_t,
                compressed_swizzled_bitmask=bitmask,
            )
            torch.testing.assert_close(a_sparse.meta.view(torch.short), sparse_mask.meta)

        ref_gemm = (mask_dense * a) @ b
        pack_gemm = a_sparse @ b
        torch.testing.assert_close(ref_gemm, pack_gemm, **atol_rtol_kw[dtype])

    @training_dtypes
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_pack_both_ways_id(self, dtype) -> None:
        N = 512
        torch.manual_seed(0)
        a = torch.randn([N, N], dtype=dtype, device="cuda")
        b = torch.eye(N, dtype=dtype, device="cuda")

        packed, meta, packed_t, meta_t = torch._sparse_semi_structured_tile(a)[:4]
        # Heuristic to ensure we pack the same values
        torch.testing.assert_close(
            packed.to(torch.float64).sum(), packed_t.to(torch.float64).sum()
        )

        mask_dense = sparse24_largest_mask_2d(a.to(dtype))

        ref_gemm = mask_dense * a
        # Test A@B
        pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed, meta).t()
        max_diff = (ref_gemm - pack_gemm).abs().argmax()
        torch.testing.assert_close(
            ref_gemm, pack_gemm,
            **atol_rtol_kw[dtype]
        ), f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})"
        # Test A.t@B
        pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed_t, meta_t)
        max_diff = (ref_gemm - pack_gemm).abs().argmax()

        torch.testing.assert_close(
            ref_gemm, pack_gemm,
            **atol_rtol_kw[dtype]
        ), f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})"

    @training_dtypes
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_pack_both_ways_edge_case1(self, dtype) -> None:
        # In this case, the heuristic will keep 7 values out of 16
        # instead of 8. let's see how the kernel handles this
        quad = torch.tensor(
            [
                [2, -1, -2, -3],  # Should be packed as `2 <null>`
                [-1, 8, -1, 6],
                [-1, -1, 4, 5],
                [-1, 3, 7, -1],
            ],
            dtype=dtype,
            device="cuda",
        )
        a = torch.randn([32, 64], dtype=dtype, device="cuda")
        a[:4, :4] = quad
        packed, meta, packed_t, meta_t = torch._sparse_semi_structured_tile(a)[:4]
        # Check first line in A
        assert packed[0, 0].item() == 2
        assert packed[0, 1].item() == 0
        # And first column in A.t
        assert packed_t[0, 0].item() == 2
        assert packed_t[0, 1].item() == 0

    @training_dtypes
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_sp24_apply(self, dtype) -> None:
        M, N = 256, 1024
        x = torch.randn([M, N], dtype=dtype, device="cuda")
        (
            packed,
            meta,
            packed_t,
            meta_t,
            bitmask,
        ) = torch._sparse_semi_structured_tile(x)
        packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask)
        torch.testing.assert_close(packed, packed2)
        torch.testing.assert_close(packed_t, packed_t2)

    @training_dtypes
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_sp24_apply_dense(self, dtype) -> None:
        M, N = 256, 1024
        x = torch.randn([M, N], dtype=dtype, device="cuda")
        (
            packed,
            meta,
            packed_t,
            meta_t,
            bitmask,
        ) = torch._sparse_semi_structured_tile(x)

        expected = SparseSemiStructuredTensorCUTLASS(
            x.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
        ).to_dense()

        packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask)
        sparse = SparseSemiStructuredTensorCUTLASS(
            x.shape,
            packed=packed2,
            meta=meta,
            packed_t=packed_t2,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
        )

        dense = torch._sparse_semi_structured_apply_dense(x, bitmask)

        torch.testing.assert_close(dense, expected)
        torch.testing.assert_close(sparse.to_dense(), expected)

    @training_dtypes
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_sp24_matmuls(self, dtype) -> None:
        M, N, K = 64, 256, 1024
        a = torch.randn([M, K], device="cuda", dtype=dtype)
        b = torch.randn([K, N], device="cuda", dtype=dtype)
        a_m = sparse24_largest_mask_2d(a)
        b_m = sparse24_largest_mask_2d(b)
        (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(a)
        a_s = SparseSemiStructuredTensorCUTLASS(
            a.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
        )
        (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(b)
        b_s = SparseSemiStructuredTensorCUTLASS(
            b.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
        )

        torch.testing.assert_close(a_s @ b, (a * a_m) @ b, rtol=1e-1, atol=1.5e-1)
        torch.testing.assert_close(a @ b_s, a @ (b * b_m), rtol=1e-1, atol=1.5e-1)
        torch.testing.assert_close(
            a @ a_s.t(), a @ (a * a_m).t(), rtol=1e-1, atol=1.5e-1
        )
        torch.testing.assert_close(
            a_s.t() @ a, (a * a_m).t() @ a, rtol=1e-1, atol=1e-1
        )

    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_sp24_matmuls_mat_vec(self) -> None:
        a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
        b = torch.randn([128], device="cuda", dtype=torch.float16)
        a_m = sparse24_largest_mask_2d(a)
        a_s = to_sparse_semi_structured(a)

        with pytest.raises(NotImplementedError):
            torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])

    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_sp24_matmuls_bmm(self) -> None:
        a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
        b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)
        a_m = sparse24_largest_mask_2d(a)
        a_s = to_sparse_semi_structured(a)

        with pytest.raises(NotImplementedError):
            torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])


class TestSparseSemiStructuredCUTLASS(TestCase):
    """
    This contains CUTLASS specific tests for
    - torch._sparse_semi_structured_linear
    """
    def setUp(self):
        SparseSemiStructuredTensor._FORCE_CUTLASS = True
        if "cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
            self.skipTest('CUTLASS not enabled')

    def tearDown(self):
        SparseSemiStructuredTensor._FORCE_CUTLASS = False
        super().tearDown()

    @unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows don't support CUTLASS")
    @inference_dtypes
    def test_linear_cutlass(self, device, dtype):

        def run_test(batch_shape, m, n, k, device, dtype, dtype_out, add_bias, activation, rtol, atol):
            weight = rand_sparse_semi_structured(m, k, dtype, device)
            input = make_tensor((*batch_shape, n, k), dtype=dtype, device=device)
            bias = make_tensor((m,), dtype=dtype_out, device=device) if add_bias else None

            dtype_dense = torch.float32
            input_dense = input.to(dtype_dense)
            weight_dense = weight.to(dtype_dense)
            bias_dense = bias.to(dtype_dense) if add_bias else None
            output0 = torch.nn.functional.linear(input_dense, weight_dense, bias=bias_dense)
            if activation == "relu":
                relu = torch.nn.ReLU()
                output0 = relu(output0)
            elif activation == "silu":
                silu = torch.nn.SiLU()
                output0 = silu(output0)

            compressed = to_sparse_semi_structured(weight)

            weight_sparse = compressed.values()
            meta = compressed.indices()

            output1 = torch._sparse_semi_structured_linear(input, weight_sparse, meta, bias=bias, activation=activation,
                                                           out_dtype=dtype_out if dtype == torch.int8 else None)
            torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)

        if dtype == torch.float32:
            # Inputs are converted to TF32 internally for sparse GEMM,
            # so make the dense GEMM do the same for matching results.
            orig = torch.backends.cuda.matmul.allow_tf32
            torch.backends.cuda.matmul.allow_tf32 = True

        batch_shapes = [[], [3], [3, 1]]
        dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32}
        activations = [None, "relu", "silu"]
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 5e-3, 5e-3
        elif dtype == torch.float32:
            rtol, atol = 1e-3, 75e-2
        for batch_shape, m, n, k, add_bias, activation in \
                itertools.product(batch_shapes, range(3), range(3), range(3), (False, True), activations):
            if activation == "silu" and dtype == torch.int8:
                continue  # SiLU not supported for integer inputs

            m = 2 ** m * 32
            n = 2 ** n * 32
            k = 2 ** k * 128
            run_test(batch_shape, m, n, k, device, dtype, dtype_out[dtype], add_bias, activation, rtol, atol)

        if dtype == torch.float32:
            torch.backends.cuda.matmul.allow_tf32 = orig

    @unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows don't support CUTLASS")
    @parametrize("backend", ["cutlass"])
    @inference_dtypes
    def test_sparse_semi_structured_ops_cutlass(self, device, dtype, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        if backend == "cutlass" and IS_WINDOWS:
            self.skipTest("CUTLASS not supported on Windows")

        def run_test(m, n, k, device, dtype, dtype_out, use_input, rtol, atol):
            mat1 = rand_sparse_semi_structured(m, k, dtype, device)
            # mat2 transposed as int8 case supports only row-major/column-major combination
            mat2 = make_tensor((n, k), dtype=dtype, device=device).t()
            input = make_tensor((m,), dtype=dtype_out, device=device) if use_input else None

            if use_input:
                if dtype.is_floating_point:
                    alpha = 1.3
                    beta = -0.7
                else:
                    alpha = 2
                    beta = -3

            dtype_dense = torch.float32
            mat1_dense = mat1.to(dtype_dense)
            mat2_dense = mat2.to(dtype_dense)
            if not use_input:
                output0 = torch.mm(mat1_dense, mat2_dense)
            else:
                input_dense = input.to(dtype_dense)[:, None]
                output0 = torch.addmm(input_dense, mat1_dense, mat2_dense, alpha=alpha, beta=beta)

            compressed = to_sparse_semi_structured(mat1)

            mat1_sparse = compressed.values()
            mat1_meta = compressed.indices()

            if not use_input:
                output1 = torch._sparse_semi_structured_mm(mat1_sparse, mat1_meta, mat2, out_dtype=dtype_out)
            else:
                output1 = torch._sparse_semi_structured_addmm(
                    input, mat1_sparse, mat1_meta, mat2, alpha=alpha, beta=beta, out_dtype=dtype_out
                )
            torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)

        if dtype == torch.float32:
            # Inputs are converted to TF32 internally for sparse GEMM,
            # so make the dense GEMM do the same for matching results.
            orig = torch.backends.cuda.matmul.allow_tf32
            torch.backends.cuda.matmul.allow_tf32 = True

        dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32}
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 5e-3, 5e-3
        elif dtype == torch.float32:
            rtol, atol = 1e-3, 75e-2
        for m, n, k, use_input in \
                itertools.product(range(3), range(3), range(3), (False, True)):
            m = 2 ** m * 32
            n = 2 ** n * 32
            k = 2 ** k * 128
            run_test(m, n, k, device, dtype, dtype_out[dtype], use_input, rtol, atol)

        if dtype == torch.float32:
            torch.backends.cuda.matmul.allow_tf32 = orig

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @inference_dtypes
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_conversions(self, device, dtype):

        def run_test(r, c, device, dtype):
            dense_ref = rand_sparse_semi_structured(r, c, dtype, device)

            compressed = to_sparse_semi_structured(dense_ref)

            # The torch.ops.aten._to_sparse_semi_structured operator
            # uses CUTLASS to perform conversion from given dense
            # matrix to the pair of corresponding sparse and metadata
            # matrices, with the latter used here as a reference to
            # compare the metadata matrix produced by conversion
            # performed by SparseSemiStructuredTensor class
            # constructor against.
            _, meta_ref = torch.ops.aten._to_sparse_semi_structured(dense_ref)

            meta = compressed.indices()
            torch.testing.assert_close(meta, meta_ref, rtol=0, atol=0)

            dense = compressed.to_dense()
            torch.testing.assert_close(dense, dense_ref, rtol=0, atol=0)

        shapes = [[32, 128], [32, 256], [64, 128], [64, 256]]
        for r, c in shapes:
            run_test(r, c, device, dtype)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @inference_dtypes
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_conversions_all_patterns(self, device, dtype):
        r, c = 32, 128

        dense_inv, dense_val = rand_sparse_semi_structured_all_patterns(r, c, dtype, device)

        compressed = to_sparse_semi_structured(dense_inv)
        dense = compressed.to_dense()

        torch.testing.assert_close(dense, dense_val, rtol=0, atol=0)


CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32]


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    # Calculate the scale as dtype max divided by absmax
    scale = finfo.max / x.abs().max().clamp(min=1e-12)
    # scale and clamp the tensor to bring it to
    # the representative range of float8 data type
    # (as default cast is unsaturated)
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    # Return both float8 data and the inverse scale (as float),
    # as both are required as inputs to torch._scaled_mm
    return x_scl_sat.to(dtype), scale.float().reciprocal()


class TestSparseSemiStructuredCUSPARSELT(TestCase):
    """
    This contains cuSPARSELt specific tests for
        torch._cslt_compress
        torch._cslt_sparse_mm
    """
    def setUp(self):
        SparseSemiStructuredTensor._FORCE_CUTLASS = False
        if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
            self.skipTest('cuSPARSELt not enabled')

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FP8,
        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
    )
    @xfailIfSM89
    @parametrize("dense_input_shape", [(256, 128)])
    def test_sparse_fp8fp8_mm(self, dense_input_shape, device):
        if torch.backends.cusparselt.version() < 602:
            self.skipTest("fp8 matmul requires cuSPARSELt v0.6.2+")

        A = rand_sparse_semi_structured_mask(256, 128, dtype=torch.float16)
        B = torch.rand(dense_input_shape, device=device).to(torch.float16).t()

        A_fp8, A_scale = to_float8(A)
        B_fp8, B_scale = to_float8(B)
        A_fp8_sparse = to_sparse_semi_structured(A_fp8)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"`SparseSemiStructuredTensor.*_scaled_mm",
        ):
            dense_result = torch.mm(A_fp8_sparse, B_fp8)

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FP8,
        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
    )
    @xfailIfSM89
    def test_sparse_semi_structured_scaled_mm_fp8(self, device) -> None:
        (k, l, m) = (32, 64, 32)
        x = rand_sparse_semi_structured_mask(k, l, dtype=torch.float8_e4m3fn, device=device)
        y = torch.full((m, l), .25, device=device, dtype=torch.float8_e4m3fn).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.float8_e4m3fn)

        x_sparse = to_sparse_semi_structured(x)
        out_fp8_sparse = torch._scaled_mm(x_sparse, y, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.float8_e4m3fn)
        # this fails on ROCm currently because hipblaslt doesn't have amax op
        out_fp32 = out_fp8.to(torch.float32)
        out_fp32_sparse = out_fp8_sparse.to(torch.float32)
        torch.testing.assert_close(out_fp32, out_fp32_sparse, rtol=1e-1, atol=1e-1)

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FP8,
        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
    )
    @xfailIfSM89
    @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.float32])
    @parametrize("dense_input_shape", [(256, 128)])
    def test_sparse_semi_structured_scaled_mm(
        self, dense_input_shape, device, out_dtype
    ):
        A = rand_sparse_semi_structured_mask(256, 128, dtype=torch.float16)
        B = torch.rand(dense_input_shape, device=device).to(torch.float16).t()

        A_fp8, A_scale = to_float8(A)
        B_fp8, B_scale = to_float8(B)

        A_fp8_sparse = to_sparse_semi_structured(A_fp8)

        dense_result = torch._scaled_mm(
            A_fp8, B_fp8, scale_a=A_scale, scale_b=B_scale, out_dtype=out_dtype
        )
        sparse_result = torch._scaled_mm(
            A_fp8_sparse, B_fp8, scale_a=A_scale, scale_b=B_scale, out_dtype=out_dtype
        )
        torch.testing.assert_close(dense_result, sparse_result, rtol=7e-2, atol=7e-2)

    @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32])
    @parametrize("dense_input_shape", [(128, 128)])
    def test_cslt_sparse_mm_mixed_dtype(self, dense_input_shape, out_dtype, device):
        A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8)
        A_compressed = torch._cslt_compress(A)

        B = torch.rand(dense_input_shape, device=device).to(torch.int8)

        dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=out_dtype)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=out_dtype)
        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @unittest.skip("cuSPARSELt v0.6.x does not support bfloat/float16 alpha scaling")
    @training_dtypes
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_cslt_sparse_mm_alpha(self, dtype, device):
        A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(dtype).cuda()
        B = torch.ones((256, 128), device=device).to(dtype)
        alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda()
        bias = torch.ones(128, device=device).to(dtype)

        A_compressed = torch._cslt_compress(A)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha, bias=bias)

        alpha_scaled = torch.stack([alpha] * 128).t()
        dense_result = alpha_scaled * torch.mm(A.to(torch.float32), B.to(torch.float32))
        dense_result = dense_result.to(dtype)

        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32])
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_cslt_sparse_mm_alpha_compile_autotune(self, device, out_dtype):
        A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(torch.int8).to(device)
        B = torch.ones((128, 256), device=device, dtype=torch.int8).t()
        alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda()

        A_compressed = torch._cslt_compress(A)

        cslt_sparse_mm_c = torch.compile(torch._cslt_sparse_mm, mode="max-autotune")
        sparse_result = cslt_sparse_mm_c(A_compressed, B, alpha=alpha, out_dtype=out_dtype)

        # disable this otherwise inductor will attempt to reorder strides and pass a contiguous B
        @torch.compiler.disable
        def get_dense_result():
            alpha_scaled = torch.stack([alpha] * 128).t().cpu().float()
            dense_result = alpha_scaled * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu())
            dense_result = dense_result.to(out_dtype)
            return dense_result

        torch.testing.assert_close(sparse_result.cpu(), get_dense_result(), rtol=1e-3, atol=1e-3)

    @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32])
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
        A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
        B = torch.ones((128, 256), device=device).to(torch.int8).t()
        alpha = torch.Tensor([2**(-i) if out_dtype is not torch.int32 else 1
                              for i in range(128)]).cuda()

        A_compressed = torch._cslt_compress(A)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=out_dtype).cpu()

        alpha_scaled = torch.stack([alpha] * 128).t()
        dense_result = alpha_scaled.cpu() * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu())
        dense_result = dense_result.to(out_dtype)

        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    def test_cslt_sparse_mm_search(self, device, dtype):
        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        A_compressed = torch._cslt_compress(A)
        B = torch.ones((128, 128), device=device).to(dtype)

        A_compressed = torch._cslt_compress(A)
        alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)
        dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
        dense_result = dense_result.to(dtype)
        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    @inference_dtypes
    def test_csrc_cslt_sparse_mm_search(self, device, dtype):
        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        A_compressed = torch._cslt_compress(A)
        B = torch.ones((128, 128), device=device).to(dtype)

        A_compressed = torch._cslt_compress(A)
        alg_id, split_k, split_k_mode, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(),
                                              alg_id=alg_id,
                                              split_k=split_k,
                                              split_k_mode=split_k_mode)
        dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
        dense_result = dense_result.to(dtype)
        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    def test_cusparselt_backend(self):
        version = _get_torch_cuda_version()
        assert torch.backends.cusparselt.is_available()

        # CUDA 11.8 has cuSPARSELt v0.4.0 support
        if version == (11, 8):
            assert torch.backends.cusparselt.version() == 400
        # PyTorch CUDA 12.4+ using cuSPARSELt v0.6.2+
        elif version >= (12, 4):
            assert torch.backends.cusparselt.version() >= 602
        else:
            assert torch.backends.cusparselt.version() is None


if len(SEMI_STRUCTURED_SUPPORTED_BACKENDS) > 0:
    instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda")
    if "cutlass" in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
        instantiate_device_type_tests(TestSparseSemiStructuredCUTLASS, globals(), only_for="cuda")
        instantiate_device_type_tests(TestSparseSemiStructuredTraining, globals(), only_for="cuda")
    if "cusparselt" in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
        instantiate_device_type_tests(TestSparseSemiStructuredCUSPARSELT, globals(), only_for="cuda")

if __name__ == "__main__":
    run_tests()