FlexDecode not guarding on GQA groups correctly (#160904)

Addressing #151359

Updates flex_decode dispatch to use flex attention rather than flex decode if the number of groups is not a power of 2

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160904
Approved by: https://github.com/drisspg
This commit is contained in:
Angel Li
2025-08-20 16:32:13 +00:00
committed by PyTorch MergeBot
parent e631557518
commit 8e17709055
3 changed files with 22 additions and 2 deletions

View File

@@ -1739,6 +1739,16 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
rtol=tolerance.rtol,
)
@supported_platform
@unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
def test_not_pw_of_two(self):
    """Regression test for #151359: non-power-of-two GQA group count.

    Hq=12 query heads over Hkv=2 KV heads gives 6 groups per KV head,
    which is not a power of two, so the compiled kernel must dispatch to
    flex attention rather than flex decoding (query length 1 would
    otherwise select the decode path). Smoke test only: it passes if
    compilation and the call complete without error — presumably this
    crashed or produced wrong results before the dispatch fix.
    """
    query = torch.randn(1, 12, 1, 16, device="cuda")
    key = torch.randn(1, 2, 128, 16, device="cuda")
    value = torch.randn(1, 2, 128, 16, device="cuda")
    flex_compiled = torch.compile(flex_attention)
    flex_compiled(query, key, value, enable_gqa=True)
@supported_platform
@unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
def test_logsumexp_only_return(self):

View File

@@ -164,7 +164,7 @@ def flex_attention(
enable_gqa = V.graph.sizevars.evaluate_expr(
sympy.Ne(query.get_size()[1], key.get_size()[1]),
)
if _use_flex_decoding(query, kv_indices, kernel_options, enable_gqa):
if _use_flex_decoding(query, kv_indices, value, kernel_options, enable_gqa):
return create_flex_decoding_kernel(
query,
key,

View File

@@ -31,7 +31,7 @@ aten = torch.ops.aten
prims = torch.ops.prims
def _use_flex_decoding(query, kv_indices, kernel_options, enable_gqa) -> bool:
def _use_flex_decoding(query, kv_indices, value, kernel_options, enable_gqa) -> bool:
"""Decide which kernel to use, return true if use flex decoding kernel.
Note:
Since the number of splits is calculated based on the number of batch and head dims
@@ -60,6 +60,15 @@ def _use_flex_decoding(query, kv_indices, kernel_options, enable_gqa) -> bool:
sympy.Eq(kv_indices.get_size()[1], query.get_size()[1]),
)
)
Hq = query.get_size()[1]
Hkv = value.get_size()[1]
ratio = Hq // Hkv
pw_of_two = V.graph.sizevars.guard_or_false(
sympy.And(sympy.Gt(ratio, 0), sympy.Eq(ratio & (ratio - 1), 0))
)
return (
not force_flex
and short_query_length
@@ -67,6 +67,7 @@ def _use_flex_decoding(query, kv_indices, kernel_options, enable_gqa) -> bool:
and static_num_heads
and non_zero_length
and valid_block_mask_num_heads
and pw_of_two
)