mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[sparse] Fix semi-structured sparse shape mismatch bug (#110420)
Summary: Currently, PyTorch incorrectly calculates the size of the returned matrix when we pass a non-contiguous batched (>2d) input to the semi-structured sparse subclass. This is most common in MLP layers, where we have 2 linear layers back to back. This will lead to an error like the following: ``` RuntimeError: shape '[20, 64, 64, 3072]' is invalid for input of size 62914560 ``` Where the size of the sparse matmul result is off because we infer the output shape with the wrong tensor shape. This happens because of a bug where we did not update the subclass tensor shape when doing transpose. For semi-structured sparsity, transposing is a no-op where we just set the boolean flag, but we forgot to also update the tensor shape. Note that this error goes away in inference mode, since we avoid decomposing the aten.linear op and handle shape folding ourselves, which changes the execution path. An alternative way to fix this issue is to set TORCH_FLATTEN_LINEAR_3D=True, which will also fix this error. Test Plan: ``` python test/test_sparse_semi_structured.py -k test_mlp ``` Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/110420 Approved by: https://github.com/alexsamardzic, https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
468a73f0e3
commit
f10aab03c4
@ -52,7 +52,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
"""
|
||||
|
||||
_FUSE_TRANSPOSE = False
|
||||
_FORCE_CUTLASS = False
|
||||
_FORCE_CUTLASS = True
|
||||
_WARNING_SHOWN = False
|
||||
|
||||
@staticmethod
|
||||
@ -268,7 +268,8 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||
if func is torch.ops.aten.t.default:
|
||||
return SparseSemiStructuredTensor(
|
||||
args[0].original_tensor,
|
||||
original_shape=args[0].shape,
|
||||
# transpose shape
|
||||
original_shape=torch.Size([args[0].shape[1], args[0].shape[0]]),
|
||||
compressed_tensor_cusparselt=args[0].compressed_tensor_cusparselt,
|
||||
sparse_tensor_cutlass=args[0].sparse_tensor_cutlass,
|
||||
meta_tensor_cutlass=args[0].meta_tensor_cutlass,
|
||||
@ -438,4 +439,6 @@ def to_sparse_semi_structured(
|
||||
[-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0',
|
||||
dtype=torch.int16))
|
||||
"""
|
||||
return SparseSemiStructuredTensor(original_tensor, original_shape=original_tensor.shape, transposed=transposed)
|
||||
return SparseSemiStructuredTensor(
|
||||
original_tensor, original_shape=original_tensor.shape, transposed=transposed
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user