[sparse] semi-structured sparse + torch.compile support (#111049)
Summary:

This PR adds torch.compile support for semi-structured sparsity, using the subclass tracing @bdhirsh added. Depending on whether we are using cuSPARSELt or CUTLASS, we return a different representation of the inner tensors.

Test Plan:
```
python test/test_sparse_semi_structured.py -k compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111049
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #110583
committed by PyTorch MergeBot
parent 1c30814417
commit ac02531bab
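For readers skimming the hunks below, here is a minimal sketch of the workflow this change enables: prune a linear weight to the 2:4 pattern, convert it to the sparse subclass, and run the module under torch.compile. It is illustrative only (not part of the diff); the module, shapes, and backend choice are taken from the test added below, and it assumes a CUDA device with cuSPARSELt or CUTLASS kernel support.

```python
import torch
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured

# Force the CUTLASS backend instead of the default cuSPARSELt (illustrative).
SparseSemiStructuredTensor._FORCE_CUTLASS = True

linear = torch.nn.Linear(128, 128).half().cuda().eval()

# Apply a 2:4 sparsity pattern (two nonzeros in every group of four) to the weight.
mask = torch.Tensor([1, 0, 0, 1]).tile((128, 32)).bool().cuda()
linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight * mask))

# With this PR, the subclass can be traced, so the module compiles end to end.
compiled = torch.compile(linear)
out = compiled(torch.rand(64, 128, dtype=torch.float16, device="cuda"))
```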
@@ -341,6 +341,44 @@ class TestSparseSemiStructured(TestCase):

```python
        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
    def test_mlp_contiguous_relu_compile(self, dense_input_shape, backend, device):
        """
        Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
        We expect:
        (1) The sparse tensor subclass should turn nn.Linear into `aten._structured_sparse_linear` + `aten.contiguous()`
        (2) Inductor should fuse the .contiguous() call into the relu
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(128, 128)

            def forward(self, x):
                x = self.linear(x)
                x = x.contiguous()
                return torch.nn.functional.relu(x)

        input = torch.rand(dense_input_shape, device=device).half()
        model = Model().eval().cuda().half()
        mod_linear = model.linear
        m, n = mod_linear.weight.shape
        mask = torch.Tensor([1, 0, 0, 1]).tile((m, n // 4)).bool().cuda()
        # set masked weight
        mod_linear.weight = nn.Parameter(mod_linear.weight * mask)

        dense_result = model(input)
        mod_linear.weight = nn.Parameter(to_sparse_semi_structured(mod_linear.weight))

        model = torch.compile(model)
        sparse_result = model(input)

        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_values(self, backend):
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
```
@@ -213,6 +213,37 @@ class SparseSemiStructuredTensor(torch.Tensor):

```python
        self.sparse_tensor_cutlass = sparse_tensor_cutlass
        self.meta_tensor_cutlass = meta_tensor_cutlass
        self.transposed = transposed
        self.original_shape = original_shape

    def __tensor_flatten__(self):
        if self.compressed_tensor_cusparselt is not None:
            return ['compressed_tensor_cusparselt'], (self.original_shape, self.transposed)
        else:
            return ['sparse_tensor_cutlass', 'meta_tensor_cutlass'], (self.original_shape, self.transposed)

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta):
        original_shape, transposed = meta

        if len(inner_tensors) == 2:
            sparse_tensor_cutlass = inner_tensors['sparse_tensor_cutlass']
            meta_tensor_cutlass = inner_tensors['meta_tensor_cutlass']
            compressed_tensor_cusparselt = None
        elif len(inner_tensors) == 1:
            sparse_tensor_cutlass = None
            meta_tensor_cutlass = None
            compressed_tensor_cusparselt = inner_tensors['compressed_tensor_cusparselt']
        else:
            raise RuntimeError(f"Expected 1 or 2 inner tensors but got {len(inner_tensors)}")

        return SparseSemiStructuredTensor(
            None,
            original_shape=original_shape,
            compressed_tensor_cusparselt=compressed_tensor_cusparselt,
            sparse_tensor_cutlass=sparse_tensor_cutlass,
            meta_tensor_cutlass=meta_tensor_cutlass,
            transposed=transposed,
        )

    def __repr__(self) -> str:  # type: ignore[override]
        """Return string representation of SparseSemiStructuredTensor
```
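The two branches in this hunk record a backend design choice: cuSPARSELt packs values and metadata into a single opaque compressed tensor, while CUTLASS keeps the sparse values and their metadata as two separate tensors, hence the one-vs-two inner-tensor cases. These hooks are what the subclass tracing relies on: the compiler decomposes the subclass into plain inner tensors plus metadata, traces over those, and rebuilds the subclass afterwards. A hedged round-trip sketch, assuming `st` is an existing SparseSemiStructuredTensor instance (the reconstruction details follow from the constructor call above, not from code shown in this diff):

```python
# Flatten: names of the inner tensors plus the metadata needed to rebuild.
attrs, meta = st.__tensor_flatten__()
# e.g. attrs == ['compressed_tensor_cusparselt'], meta == (original_shape, transposed)

# The tracer operates on the plain inner tensors it can handle directly.
inner_tensors = {name: getattr(st, name) for name in attrs}

# Unflatten: rebuild an equivalent subclass instance from those pieces.
rebuilt = SparseSemiStructuredTensor.__tensor_unflatten__(inner_tensors, meta)
```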