mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[sparse] semi-structured sparse + torch.compile support (#111049)
Summary: This PR adds in torch.compile support for semi-structured sparsity, using the subclass tracing @bdhirsh added. Based on wether we are using cuSPARSELt or CUTLASS, we return a different representation of the inner tensors. Test Plan: ``` python test/test_sparse_semi_structured.py -k compile ``` Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/111049 Approved by: https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
5eac44bc72
commit
702aaf8aea
@ -213,6 +213,37 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
self.sparse_tensor_cutlass = sparse_tensor_cutlass
|
||||
self.meta_tensor_cutlass = meta_tensor_cutlass
|
||||
self.transposed = transposed
|
||||
self.original_shape = original_shape
|
||||
|
||||
def __tensor_flatten__(self):
|
||||
if self.compressed_tensor_cusparselt is not None:
|
||||
return ['compressed_tensor_cusparselt'], (self.original_shape, self.transposed)
|
||||
else:
|
||||
return ['sparse_tensor_cutlass', 'meta_tensor_cutlass'], (self.original_shape, self.transposed)
|
||||
|
||||
@staticmethod
|
||||
def __tensor_unflatten__(inner_tensors, meta):
|
||||
original_shape, transposed = meta
|
||||
|
||||
if len(inner_tensors) == 2:
|
||||
sparse_tensor_cutlass = inner_tensors['sparse_tensor_cutlass']
|
||||
meta_tensor_cutlass = inner_tensors['meta_tensor_cutlass']
|
||||
compressed_tensor_cusparselt = None
|
||||
elif len(inner_tensors) == 1:
|
||||
sparse_tensor_cutlass = None
|
||||
meta_tensor_cutlass = None
|
||||
compressed_tensor_cusparselt = inner_tensors['compressed_tensor_cusparselt']
|
||||
else:
|
||||
raise RuntimeError(f"Expected 1 or 2 inner tensors but got {len(inner_tensors)}")
|
||||
|
||||
return SparseSemiStructuredTensor(
|
||||
None,
|
||||
original_shape=original_shape,
|
||||
compressed_tensor_cusparselt=compressed_tensor_cusparselt,
|
||||
sparse_tensor_cutlass=sparse_tensor_cutlass,
|
||||
meta_tensor_cutlass=meta_tensor_cutlass,
|
||||
transposed=transposed,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str: # type: ignore[override]
|
||||
"""Return string representation of SparseSemiStructuredTensor
|
||||
|
Reference in New Issue
Block a user