Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-05 08:24:57 +08:00)
[fake_impls] fix max_seqlen return values in efficient_attention_forward (#120842)
To match the actual implementation, we should return the max_seqlen_q/k, not M, N, when in the sparse case:

7e185277cd/aten/src/ATen/native/transformers/cuda/attention.cu (L981-L996)

Note that although the .cu file sets max_seqlen_k = 0 in the sparse case, it actually returns max_seqlen_k or N:

7e185277cd/aten/src/ATen/native/transformers/cuda/attention.cu (L1224-L1231)

Tests: added in the next PR (#102839, which also fixes other parts of the test_fake tests so that we can un-xfail them and actually run the tests).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120842
Approved by: https://github.com/YuqingJ
ghstack dependencies: #120682
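For context, here is a rough sketch of the return-value selection this change implements in the fake impl. The helper name is hypothetical and this is not the actual PyTorch code; the dense-case fallback to M and N is an assumption based on the previous behavior:

def _pick_max_seqlens(cu_seqlens_q, max_seqlen_q, max_seqlen_k, M, N):
    # Sparse/varlen case: mirror the real kernel and return max_seqlen_q and
    # max_seqlen_k, falling back to N when max_seqlen_k is None.
    if cu_seqlens_q is not None:
        assert max_seqlen_q is not None
        return max_seqlen_q, max_seqlen_k if max_seqlen_k is not None else N
    # Dense case: assumed to keep returning the tensor sizes M and N.
    return M, N

For example, _pick_max_seqlens(None, None, None, M=32, N=64) gives (32, 64), while a sparse call with max_seqlen_q=16 and max_seqlen_k=None gives (16, 64).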
Committed by: PyTorch MergeBot
Parent: d1d50d2e4c
Commit: df1e855313
@@ -2162,6 +2162,7 @@ class TestFakeTensor(TestCase):
        ):
            if not isinstance(fake_out, torch.Tensor):
                self.assertTrue(not isinstance(real_out, torch.Tensor))
                self.assertEqual(fake_out, real_out)
                continue

            self.assertTrue(isinstance(fake_out, FakeTensor))
@@ -803,8 +803,9 @@ def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs):
     value = kwargs["value"]
     cu_seqlens_q = kwargs["cu_seqlens_q"]
     max_seqlen_q = kwargs["max_seqlen_q"]
+    max_seqlen_k = kwargs["max_seqlen_k"]
     compute_log_sumexp = kwargs["compute_log_sumexp"]
-    # unused: bias, cu_seqlens_k, max_seqlen_k, dropout_p, custom_mask_type, scale, causal_diagonal, seqlen_k
+    # unused: bias, cu_seqlens_k, dropout_p, custom_mask_type, scale, causal_diagonal, seqlen_k

     def convert_tensor(t, device):
         return FakeTensor(fake_mode, t, device)
@@ -826,6 +827,7 @@ def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs):
     if cu_seqlens_q is not None:
         assert max_seqlen_q is not None
         actual_max_seqlen_q = max_seqlen_q
+        actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N
         logsumexp_dim = (
             math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0
         )
@@ -846,7 +848,7 @@ def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs):
         torch.empty((), dtype=torch.long, device="meta"), query.device
     )

-    return res, logsum_exp, seed, offset, M, N
+    return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k


 FAST_OP_IMPLEMENTATIONS = {}