[sparse] semi-structured sparse + torch.compile support (#111049)

Summary:

This PR adds torch.compile support for semi-structured sparsity, using
the subclass tracing support that @bdhirsh added.

Depending on whether we are using cuSPARSELt or CUTLASS, `__tensor_flatten__`
returns a different set of inner tensors: a single compressed tensor for
cuSPARSELt, or a sparse/metadata tensor pair for CUTLASS.
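
For context, the workflow this enables looks roughly like the sketch below
(adapted from the new test in this PR; the shapes and variable names are
illustrative, and an SM80+ GPU is assumed):

```
import torch
import torch.nn as nn
from torch.sparse import to_sparse_semi_structured

# Sketch only: prune a linear weight to a 2:4 pattern, swap it for the
# semi-structured sparse subclass, and run the module under torch.compile.
model = nn.Linear(128, 128).half().cuda().eval()

# 2:4 mask: two nonzeros in every group of four elements along a row.
m, n = model.weight.shape
mask = torch.Tensor([1, 0, 0, 1]).tile((m, n // 4)).bool().cuda()
model.weight = nn.Parameter(model.weight * mask)

# Convert the pruned weight to SparseSemiStructuredTensor (cuSPARSELt or
# CUTLASS backend); tracing then goes through the new
# __tensor_flatten__ / __tensor_unflatten__ hooks added below.
model.weight = nn.Parameter(to_sparse_semi_structured(model.weight))
compiled = torch.compile(model)

x = torch.rand(64, 128, device="cuda").half()
out = compiled(x)
```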

Test Plan:
```
python test/test_sparse_semi_structured.py -k compile
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111049
Approved by: https://github.com/cpuhrsch
Author: Jesse Cai
Date: 2023-10-23 13:46:11 -07:00
Committed by: PyTorch MergeBot
parent 5eac44bc72
commit 702aaf8aea
3 changed files with 89 additions and 5 deletions

View File

@@ -20,13 +20,15 @@ from torch.testing._internal.common_device_type import (
)
from torch.testing._internal.common_dtype import all_types_and_complex
import torch._dynamo.test_case
from torch.testing._internal.common_utils import (
    parametrize,
    run_tests,
    subtest,
    TestCase,
    TEST_WITH_ROCM
    TEST_WITH_ROCM,
    IS_WINDOWS,
)
from torch.utils._triton import has_triton
@@ -115,13 +117,64 @@ def rand_dense_2by4_all_patterns(r, c, dtype, device):
    return dense_inv, dense_val

class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):

    def setUp(self):
        if not _IS_SM8X:
            self.skipTest('Only runs on SM80')
        super().setUp()

    def tearDown(self):
        super().tearDown()

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    def test_mlp_contiguous_relu_compile(self):
        """
        Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
        We expect:
            (1) The sparse tensor subclass should turn nn.Linear into `aten._structured_sparse_linear` + `aten.contiguous()`
            (2) Inductor should fuse the .contiguous() call into the relu
        """

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(128, 128)

            def forward(self, x):
                x = self.linear(x)
                x = x.contiguous()
                return torch.nn.functional.relu(x)

        def _test_mlp_contiguous_relu_compile(backend, dense_input_shape):
            SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"

            input = torch.rand(dense_input_shape, device="cuda").half()
            model = Model().eval().cuda().half()
            mod_linear = model.linear
            m, n = mod_linear.weight.shape
            mask = torch.Tensor([1, 0, 0, 1]).tile((m, n // 4)).bool().cuda()
            # set masked weight
            mod_linear.weight = nn.Parameter(mod_linear.weight * mask)

            dense_result = model(input)
            mod_linear.weight = nn.Parameter(to_sparse_semi_structured(mod_linear.weight))
            model = torch.compile(model)
            sparse_result = model(input)

            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

        for backend in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
            for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
                _test_mlp_contiguous_relu_compile(backend, dense_input_shape)


class TestSparseSemiStructured(TestCase):

    def setUp(self):
        if not _IS_SM8X:
            self.skipTest('Only runs on SM80')

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_to_sparse_semi_structured(self, dtype, backend):
@@ -274,7 +327,7 @@ class TestSparseSemiStructured(TestCase):
        """
        if "cusparselt" in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
            SparseSemiStructuredTensor._FORCE_CUTLASS = False

        A = rand_sparse_semi_structured_mask(128, 256, dtype=torch.int8)
        A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(torch.int8)

View File

@@ -322,7 +322,7 @@ def _sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
# This function converts dense matrix into sparse semi-structured
# representation, producing "compressed" matrix, in the layout used by
# CUTLASS backend, and corresponding metadata matrix.
def sparse_semi_structured_from_dense_cutlass(dense, compile=True):
def sparse_semi_structured_from_dense_cutlass(dense, compile=False):
    if compile:
        from torch._dynamo.utils import is_compile_supported
        if is_compile_supported(dense.device.type):
@@ -336,7 +336,7 @@ def sparse_semi_structured_from_dense_cutlass(dense, compile=True):
# reconstructs dense matrix from a pair of "compressed" matrix, given
# in the layout used by CUTLASS backend, and accompanying metadata
# matrix.
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered, compile=True):
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered, compile=False):
    if compile:
        from torch._dynamo.utils import is_compile_supported
        if is_compile_supported(sparse.device.type):

View File

@@ -213,6 +213,37 @@ class SparseSemiStructuredTensor(torch.Tensor):
        self.sparse_tensor_cutlass = sparse_tensor_cutlass
        self.meta_tensor_cutlass = meta_tensor_cutlass
        self.transposed = transposed
        self.original_shape = original_shape

    def __tensor_flatten__(self):
        if self.compressed_tensor_cusparselt is not None:
            return ['compressed_tensor_cusparselt'], (self.original_shape, self.transposed)
        else:
            return ['sparse_tensor_cutlass', 'meta_tensor_cutlass'], (self.original_shape, self.transposed)

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta):
        original_shape, transposed = meta
        if len(inner_tensors) == 2:
            sparse_tensor_cutlass = inner_tensors['sparse_tensor_cutlass']
            meta_tensor_cutlass = inner_tensors['meta_tensor_cutlass']
            compressed_tensor_cusparselt = None
        elif len(inner_tensors) == 1:
            sparse_tensor_cutlass = None
            meta_tensor_cutlass = None
            compressed_tensor_cusparselt = inner_tensors['compressed_tensor_cusparselt']
        else:
            raise RuntimeError(f"Expected 1 or 2 inner tensors but got {len(inner_tensors)}")

        return SparseSemiStructuredTensor(
            None,
            original_shape=original_shape,
            compressed_tensor_cusparselt=compressed_tensor_cusparselt,
            sparse_tensor_cutlass=sparse_tensor_cutlass,
            meta_tensor_cutlass=meta_tensor_cutlass,
            transposed=transposed,
        )

    def __repr__(self) -> str:  # type: ignore[override]
        """Return string representation of SparseSemiStructuredTensor