[sparse] semi-structured sparse + torch.compile support (#111049)
Summary: This PR adds torch.compile support for semi-structured sparsity, using the subclass tracing @bdhirsh added. Depending on whether we are using cuSPARSELt or CUTLASS, we return a different representation of the inner tensors.

Test Plan:
```
python test/test_sparse_semi_structured.py -k compile
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111049
Approved by: https://github.com/cpuhrsch
Committed by: PyTorch MergeBot
Parent: 5eac44bc72
Commit: 702aaf8aea
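For context, here is a minimal end-to-end sketch of the workflow this change enables, adapted from the new test added below. It is an illustration, not part of the diff: it assumes an SM8x (Ampere or newer) GPU, fp16 weights, and the public `torch.sparse.to_sparse_semi_structured` entry point.

```python
import torch
import torch.nn as nn
from torch.sparse import to_sparse_semi_structured

# Requires an SM8x GPU; the semi-structured kernels operate on fp16 here.
model = nn.Sequential(nn.Linear(128, 128), nn.ReLU()).half().cuda().eval()

# Prune the weight to a 2:4 pattern so it can be represented as semi-structured sparse.
m, n = model[0].weight.shape
mask = torch.Tensor([1, 0, 0, 1]).tile((m, n // 4)).bool().cuda()
model[0].weight = nn.Parameter(model[0].weight * mask)

# Swap in the sparse subclass and compile; tracing goes through
# __tensor_flatten__ / __tensor_unflatten__ on the subclass (see the last hunk).
model[0].weight = nn.Parameter(to_sparse_semi_structured(model[0].weight))
compiled_model = torch.compile(model)

x = torch.rand(64, 128, device="cuda").half()
out = compiled_model(x)
```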
```diff
@@ -20,13 +20,15 @@ from torch.testing._internal.common_device_type import (
 )
 
 from torch.testing._internal.common_dtype import all_types_and_complex
+import torch._dynamo.test_case
 
 from torch.testing._internal.common_utils import (
     parametrize,
     run_tests,
     subtest,
     TestCase,
-    TEST_WITH_ROCM
+    TEST_WITH_ROCM,
+    IS_WINDOWS,
 )
 
 from torch.utils._triton import has_triton
```
```diff
@@ -115,13 +117,64 @@ def rand_dense_2by4_all_patterns(r, c, dtype, device):
     return dense_inv, dense_val
 
 
+class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
+
+    def setUp(self):
+        if not _IS_SM8X:
+            self.skipTest('Only runs on SM80')
+        super().setUp()
+
+    def tearDown(self):
+        super().tearDown()
+
+    @unittest.skipIf(IS_WINDOWS, "torch.compile not support on windows")
+    def test_mlp_contiguous_relu_compile(self):
+        """
+        Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
+        We expect:
+            (1) The sparse tensor subclass should turn nn.Linear into `aten._structured_sparse_linear` + `aten.contiguous()`
+            (2) Inductor should fuse the .contiguous() call into the relu
+        """
+
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(128, 128)
+
+            def forward(self, x):
+                x = self.linear(x)
+                x = x.contiguous()
+                return torch.nn.functional.relu(x)
+
+        def _test_mlp_contiguous_relu_compile(backend, dense_input_shape):
+            SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
+            input = torch.rand(dense_input_shape, device="cuda").half()
+
+            model = Model().eval().cuda().half()
+            mod_linear = model.linear
+            m, n = mod_linear.weight.shape
+            mask = torch.Tensor([1, 0, 0, 1]).tile((m, n // 4)).bool().cuda()
+            # set masked weight
+            mod_linear.weight = nn.Parameter(mod_linear.weight * mask)
+
+            dense_result = model(input)
+            mod_linear.weight = nn.Parameter(to_sparse_semi_structured(mod_linear.weight))
+
+            model = torch.compile(model)
+            sparse_result = model(input)
+
+            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+
+        for backend in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
+            for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
+                _test_mlp_contiguous_relu_compile(backend, dense_input_shape)
+
+
 class TestSparseSemiStructured(TestCase):
 
     def setUp(self):
         if not _IS_SM8X:
             self.skipTest('Only runs on SM80')
 
 
     @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
     def test_to_sparse_semi_structured(self, dtype, backend):
```
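A note on the weight mask in the test above: `torch.Tensor([1, 0, 0, 1]).tile((m, n // 4))` places exactly two nonzeros in every contiguous group of four elements along a row, which is the 2:4 pattern the semi-structured kernels require. A minimal CPU-only sketch of that invariant (shapes here are illustrative, not from the diff):

```python
import torch

# Build the same mask the test uses and check the 2:4 invariant on CPU.
m, n = 128, 128
mask = torch.Tensor([1, 0, 0, 1]).tile((m, n // 4)).bool()

# Every group of 4 consecutive elements in a row must contain exactly 2 nonzeros.
nonzeros_per_group = mask.view(m, n // 4, 4).sum(dim=-1)
assert torch.all(nonzeros_per_group == 2)
```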
```diff
@@ -274,7 +327,7 @@ class TestSparseSemiStructured(TestCase):
         """
         if "cusparselt" in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
             SparseSemiStructuredTensor._FORCE_CUTLASS = False
-        A = rand_sparse_semi_structured_mask(128, 256, dtype=torch.int8)
+        A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8)
         A_sparse = to_sparse_semi_structured(A)
 
         B = torch.rand(dense_input_shape, device=A_sparse.device).to(torch.int8)
```
```diff
@@ -322,7 +322,7 @@ def _sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
 # This function converts dense matrix into sparse semi-structured
 # representation, producing "compressed" matrix, in the layout used by
 # CUTLASS backend, and corresponding metadata matrix.
-def sparse_semi_structured_from_dense_cutlass(dense, compile=True):
+def sparse_semi_structured_from_dense_cutlass(dense, compile=False):
     if compile:
         from torch._dynamo.utils import is_compile_supported
         if is_compile_supported(dense.device.type):
@@ -336,7 +336,7 @@ def sparse_semi_structured_from_dense_cutlass(dense, compile=True):
 # reconstructs dense matrix from a pair of "compressed" matrix, given
 # in the layout used by CUTLASS backend, and accompanying metadata
 # matrix.
-def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered, compile=True):
+def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered, compile=False):
     if compile:
         from torch._dynamo.utils import is_compile_supported
         if is_compile_supported(sparse.device.type):
```
```diff
@@ -213,6 +213,37 @@ class SparseSemiStructuredTensor(torch.Tensor):
         self.sparse_tensor_cutlass = sparse_tensor_cutlass
         self.meta_tensor_cutlass = meta_tensor_cutlass
         self.transposed = transposed
+        self.original_shape = original_shape
+
+    def __tensor_flatten__(self):
+        if self.compressed_tensor_cusparselt is not None:
+            return ['compressed_tensor_cusparselt'], (self.original_shape, self.transposed)
+        else:
+            return ['sparse_tensor_cutlass', 'meta_tensor_cutlass'], (self.original_shape, self.transposed)
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, meta):
+        original_shape, transposed = meta
+
+        if len(inner_tensors) == 2:
+            sparse_tensor_cutlass = inner_tensors['sparse_tensor_cutlass']
+            meta_tensor_cutlass = inner_tensors['meta_tensor_cutlass']
+            compressed_tensor_cusparselt = None
+        elif len(inner_tensors) == 1:
+            sparse_tensor_cutlass = None
+            meta_tensor_cutlass = None
+            compressed_tensor_cusparselt = inner_tensors['compressed_tensor_cusparselt']
+        else:
+            raise RuntimeError(f"Expected 1 or 2 inner tensors but got {len(inner_tensors)}")
+
+        return SparseSemiStructuredTensor(
+            None,
+            original_shape=original_shape,
+            compressed_tensor_cusparselt=compressed_tensor_cusparselt,
+            sparse_tensor_cutlass=sparse_tensor_cutlass,
+            meta_tensor_cutlass=meta_tensor_cutlass,
+            transposed=transposed,
+        )
 
     def __repr__(self) -> str:  # type: ignore[override]
         """Return string representation of SparseSemiStructuredTensor
```