[fake_impls] fix max_seqlen return values in efficient_attention_forward (#120842)

To match the actual implementation, we should return max_seqlen_q/k, not M/N, in the sparse case:

7e185277cd/aten/src/ATen/native/transformers/cuda/attention.cu (L981-L996)

Note that although the .cu file sets max_seqlen_k = 0 in the sparse case, it actually returns max_seqlen_k or N:

7e185277cd/aten/src/ATen/native/transformers/cuda/attention.cu (L1224-L1231)
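
In other words, the fake impl should mirror that selection. As a rough Python sketch (the helper name pick_max_seqlens is illustrative only; the real change lives in meta__efficient_attention_forward, shown in the diff below):

def pick_max_seqlens(M, N, cu_seqlens_q, max_seqlen_q, max_seqlen_k):
    # Sketch, not PyTorch source: M/N are the query/key sequence dims, and the
    # cu_seqlens_q branch is the sparse (varlen) case described above.
    if cu_seqlens_q is not None:
        # sparse case: report the caller-provided max sequence length
        assert max_seqlen_q is not None
        actual_max_seqlen_q = max_seqlen_q
    else:
        # dense case: the tensor dimension is the max sequence length
        actual_max_seqlen_q = M
    # max_seqlen_k may still be None; fall back to N, matching what the .cu
    # file ultimately returns
    actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N
    return actual_max_seqlen_q, actual_max_seqlen_k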

Tests are added in the next PR (#102839, which also fixes other parts of the test_fake tests so that we can un-xfail them and actually run them).
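
For context, here is a hedged sketch of the kind of fake-vs-real comparison such a test performs (the helper below is illustrative, not the actual test_fake harness): run the op under FakeTensorMode and require that non-tensor outputs, such as the returned max_seqlen ints, match the real values instead of only matching in type.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

def check_nontensor_outputs_match(op, *args, **kwargs):
    # Illustrative sketch only; assumes op returns a tuple and that any
    # tensor arguments are positional.
    real_outs = op(*args, **kwargs)
    with FakeTensorMode() as mode:
        fake_args = [
            mode.from_tensor(a) if isinstance(a, torch.Tensor) else a
            for a in args
        ]
        fake_outs = op(*fake_args, **kwargs)
    for fake_out, real_out in zip(fake_outs, real_outs):
        if not isinstance(fake_out, torch.Tensor):
            # e.g. max_seqlen_q/k must now agree in value, not just in type
            assert fake_out == real_out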
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120842
Approved by: https://github.com/YuqingJ
ghstack dependencies: #120682
Author: David Berard
Date: 2024-02-28 20:12:15 -08:00
Committed by: PyTorch MergeBot
Parent: d1d50d2e4c
Commit: df1e855313
2 changed files with 5 additions and 2 deletions


@@ -2162,6 +2162,7 @@ class TestFakeTensor(TestCase):
             ):
                 if not isinstance(fake_out, torch.Tensor):
                     self.assertTrue(not isinstance(real_out, torch.Tensor))
+                    self.assertEqual(fake_out, real_out)
                     continue

                 self.assertTrue(isinstance(fake_out, FakeTensor))


@@ -803,8 +803,9 @@ def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs):
     value = kwargs["value"]
     cu_seqlens_q = kwargs["cu_seqlens_q"]
     max_seqlen_q = kwargs["max_seqlen_q"]
+    max_seqlen_k = kwargs["max_seqlen_k"]
     compute_log_sumexp = kwargs["compute_log_sumexp"]
-    # unused: bias, cu_seqlens_k, max_seqlen_k, dropout_p, custom_mask_type, scale, causal_diagonal, seqlen_k
+    # unused: bias, cu_seqlens_k, dropout_p, custom_mask_type, scale, causal_diagonal, seqlen_k

     def convert_tensor(t, device):
         return FakeTensor(fake_mode, t, device)
@@ -826,6 +827,7 @@ def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs):
     if cu_seqlens_q is not None:
         assert max_seqlen_q is not None
         actual_max_seqlen_q = max_seqlen_q
+    actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N
     logsumexp_dim = (
         math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0
     )
@@ -846,7 +848,7 @@ def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs):
         torch.empty((), dtype=torch.long, device="meta"), query.device
     )
-    return res, logsum_exp, seed, offset, M, N
+    return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k
 FAST_OP_IMPLEMENTATIONS = {}