[Export] Fix SDPA decomposition (#135297)

Summary: Update the SDPA decomposition to match the updated output strides from D62009189, which aligns them with `aten._scaled_dot_product_attention_math.default`; as a result, the `t.permute().contiguous().permute()` workaround is no longer necessary.
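
As a minimal sketch of the property the updated test relies on (the shapes, dtype, tolerances, and the direct call into the decomposition are illustrative assumptions modeled on the test below, not a canonical recipe): the output of the scaled_dot_product_flash_attention_for_cpu decomposition should already be contiguous when permuted to (L, N, H, E), so no extra permute/contiguous/permute round-trip is required.

    import torch
    import torch._decomp.decompositions as decomps
    import torch.nn.functional as F

    # Illustrative (N, H, L, E) inputs on CPU in float32; the real test
    # parametrizes over device, dtype, op, and optional attention masks.
    q = torch.randn(1, 128, 100, 64)
    k = torch.randn(1, 128, 100, 64)
    v = torch.randn(1, 128, 100, 64)

    # Call the decomposition directly, as the test does; it returns a tuple
    # whose first element is the attention output.
    decomposed = decomps.scaled_dot_product_flash_attention_for_cpu(
        q, k, v, 0.0, True  # dropout_p=0.0, is_causal=True
    )
    out = decomposed[0]

    # Output has form (N, H, L, E) but should be contiguous when permuted to
    # (L, N, H, E), so a later view(L * N, H * E) is valid without an explicit
    # .contiguous() call.
    assert out.permute(2, 0, 1, 3).is_contiguous()

    # The decomposed result should still match eager SDPA (loose, illustrative
    # tolerances; the test uses dtype-dependent atol/rtol).
    eager = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
    assert torch.allclose(out, eager, atol=1e-4, rtol=1e-4)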

Test Plan: CI

Differential Revision: D62278378

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135297
Approved by: https://github.com/drisspg
Author: Sidney Tsang
Date: 2024-09-11 20:21:59 +00:00
Committed by: PyTorch MergeBot
Commit: 5d964a5eb7 (parent: 118d7e1480)
2 changed files with 14 additions and 35 deletions

@@ -1179,24 +1179,6 @@ class DecompOneOffTests(TestCase):
     def test_sdpa(self, device, dtype, op):
         # SDPA doesn't support float16, this is aligned with aten/src/ATen/native/transformers/attention.cpp. If we
         # add support for float16 over there we should update this test as well.
-        class ScaledDotProductAttention(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-            def forward(
-                self, query_layer, key_layer, value_layer, mask=None, is_causal=True
-            ):
-                attn_output = op(
-                    query_layer,
-                    key_layer,
-                    value_layer,
-                    attn_mask=mask,
-                    dropout_p=0.0,
-                    is_causal=is_causal,
-                )
-                return attn_output
         query_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
         key_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
         value_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
@@ -1206,12 +1188,17 @@ class DecompOneOffTests(TestCase):
         for mask in masks:
             is_causal = mask is None
-            attention = ScaledDotProductAttention()
             decomposed_res = (
                 torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu(
                     query_layer, key_layer, value_layer, 0.0, is_causal, attn_mask=mask
                 )
             )
+            actual_res = decomposed_res[0]
+            # Output has form (N, H, L, E), but should be contiguous on (L, N, H, E)
+            # in order for subsequent view(L * N, H * E) to be valid.
+            # So permute(2, 0, 1, 3) before checking that the tensor is contiguous
+            self.assertTrue(actual_res.permute(2, 0, 1, 3).is_contiguous())
             eager_res = op(
                 query_layer,
                 key_layer,
@@ -1221,9 +1208,7 @@ class DecompOneOffTests(TestCase):
                 is_causal=is_causal,
             )
-            self.assertTrue(
-                torch.allclose(decomposed_res[0], eager_res, atol=atol, rtol=rtol)
-            )
+            self.assertTrue(torch.allclose(actual_res, eager_res, atol=atol, rtol=rtol))

 instantiate_device_type_tests(DecompOneOffTests, globals())