Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Export] Fix SDPA decomposition (#135297)
Summary: Update the SDPA decomposition to match the updated strides from D62009189, which aligns them with `aten._scaled_dot_product_attention_math.default` and makes the `t.permute().contiguous().permute()` workaround unnecessary.

Test Plan: CI

Differential Revision: D62278378

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135297
Approved by: https://github.com/drisspg
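For illustration only (not part of the commit): a minimal sketch of the stride contract this change relies on, using the same CPU flash-attention decomposition and tensor shapes that the test below exercises. It assumes a PyTorch build that contains this fix.

import torch

# Shapes mirror the test below: (N, H, L, E).
q = torch.randn(1, 128, 100, 64)
k = torch.randn(1, 128, 100, 64)
v = torch.randn(1, 128, 100, 64)

# First element of the decomposition's output tuple has shape (N, H, L, E).
out = torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu(
    q, k, v, 0.0, True
)[0]

# After the stride fix, the (L, N, H, E) permutation is already contiguous,
# so a t.permute(...).contiguous().permute(...) round-trip to repair the
# layout is no longer needed.
assert out.permute(2, 0, 1, 3).is_contiguous()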
Committed by: PyTorch MergeBot
Parent: 118d7e1480
Commit: 5d964a5eb7
@@ -1179,24 +1179,6 @@ class DecompOneOffTests(TestCase):
    def test_sdpa(self, device, dtype, op):
        # SDPA doesn't support float16, this is aligned with aten/src/ATen/native/transformers/attention.cpp. If we
        # add support for float16 over there we should update this test as well.

        class ScaledDotProductAttention(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(
                self, query_layer, key_layer, value_layer, mask=None, is_causal=True
            ):
                attn_output = op(
                    query_layer,
                    key_layer,
                    value_layer,
                    attn_mask=mask,
                    dropout_p=0.0,
                    is_causal=is_causal,
                )
                return attn_output

        query_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
        key_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
        value_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
@@ -1206,12 +1188,17 @@ class DecompOneOffTests(TestCase):

        for mask in masks:
            is_causal = mask is None
            attention = ScaledDotProductAttention()
            decomposed_res = (
                torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu(
                    query_layer, key_layer, value_layer, 0.0, is_causal, attn_mask=mask
                )
            )
            actual_res = decomposed_res[0]
            # Output has form (N, H, L, E), but should be contiguous on (L, N, H, E)
            # in order for subsequent view(L * N, H * E) to be valid.
            # So permute(2, 0, 1, 3) before checking that tensor is contiguous
            self.assertTrue(actual_res.permute(2, 0, 1, 3).is_contiguous())

            eager_res = op(
                query_layer,
                key_layer,
@@ -1221,9 +1208,7 @@ class DecompOneOffTests(TestCase):
                is_causal=is_causal,
            )

            self.assertTrue(
                torch.allclose(decomposed_res[0], eager_res, atol=atol, rtol=rtol)
            )
            self.assertTrue(torch.allclose(actual_res, eager_res, atol=atol, rtol=rtol))


instantiate_device_type_tests(DecompOneOffTests, globals())
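For context (an illustrative sketch, not part of the commit): the view(L * N, H * E) mentioned in the test comment above only works because the (L, N, H, E) permutation of the decomposition's output is contiguous, which is exactly what the new assertion checks. Shapes mirror the test; this assumes a build containing the fix.

import torch

N, H, L, E = 1, 128, 100, 64
query = torch.randn(N, H, L, E)
key = torch.randn(N, H, L, E)
value = torch.randn(N, H, L, E)

out = torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu(
    query, key, value, 0.0, True
)[0]

# Permute (N, H, L, E) -> (L, N, H, E); view() requires the contiguous layout
# that the stride fix provides, so no .contiguous() copy is needed here.
flat = out.permute(2, 0, 1, 3).view(L * N, H * E)
print(flat.shape)  # torch.Size([100, 8192])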