FlexDecode not guarding on GQA groups correctly (#160904)

Addresses #151359. Updates the flex_decode dispatch to use flex attention rather than flex decode when the number of GQA groups is not a power of 2.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160904
Approved by: https://github.com/drisspg
Committed by: PyTorch MergeBot
Parent: e631557518
Commit: 8e17709055
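
In GQA, the group count is the number of query heads per key/value head (Hq // Hkv). The fix gates the decode kernel on that ratio being a power of two, using the standard bit trick ratio & (ratio - 1) == 0. A minimal plain-Python sketch of the predicate (the diff below expresses the same check through sympy and Inductor's sizevars):

def is_power_of_two(n: int) -> bool:
    # A positive integer is a power of two iff it has exactly one set bit;
    # n & (n - 1) clears the lowest set bit, leaving zero in that case.
    return n > 0 and (n & (n - 1)) == 0

Hq, Hkv = 12, 2          # query heads, key/value heads (shapes from the new test)
ratio = Hq // Hkv        # 6 GQA groups
print(is_power_of_two(ratio))  # False -> dispatch to flex attention, not flex decode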
@@ -1739,6 +1739,16 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
                 rtol=tolerance.rtol,
             )
 
+    @supported_platform
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
+    def test_not_pw_of_two(self):
+        query = torch.randn(1, 12, 1, 16, device="cuda")
+        key = torch.randn(1, 2, 128, 16, device="cuda")
+        value = torch.randn(1, 2, 128, 16, device="cuda")
+
+        flex_compiled = torch.compile(flex_attention)
+        flex_compiled(query, key, value, enable_gqa=True)
+
     @supported_platform
     @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
     def test_logsumexp_only_return(self):
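
The new test above exercises exactly the failing shape. Outside the test harness, the same configuration can be checked against eager mode; a sketch, assuming a CUDA device and tolerances chosen for illustration:

import torch
from torch.nn.attention.flex_attention import flex_attention

# 12 query heads over 2 KV heads -> 6 groups, not a power of two (#151359).
query = torch.randn(1, 12, 1, 16, device="cuda")
key = torch.randn(1, 2, 128, 16, device="cuda")
value = torch.randn(1, 2, 128, 16, device="cuda")

eager_out = flex_attention(query, key, value, enable_gqa=True)
compiled_out = torch.compile(flex_attention)(query, key, value, enable_gqa=True)
torch.testing.assert_close(eager_out, compiled_out, atol=2e-2, rtol=2e-2)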
@@ -164,7 +164,7 @@ def flex_attention(
     enable_gqa = V.graph.sizevars.evaluate_expr(
         sympy.Ne(query.get_size()[1], key.get_size()[1]),
     )
-    if _use_flex_decoding(query, kv_indices, kernel_options, enable_gqa):
+    if _use_flex_decoding(query, kv_indices, value, kernel_options, enable_gqa):
         return create_flex_decoding_kernel(
             query,
             key,
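
The lowering change above threads value through so the decode heuristic can read the KV head count. Schematically, the choice reduces to a two-way dispatch; in this sketch, build_flex_attention_kernel is a hypothetical stand-in for the code path the real lowering falls through to:

def choose_kernel(query, key, value, kv_indices, kernel_options, enable_gqa):
    # Decode kernel only when the heuristic (including the new
    # power-of-two group check) says it is applicable.
    if _use_flex_decoding(query, kv_indices, value, kernel_options, enable_gqa):
        return create_flex_decoding_kernel(query, key, value, kv_indices, kernel_options)
    # Otherwise use the general flex attention template.
    return build_flex_attention_kernel(query, key, value, kv_indices, kernel_options)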
@@ -31,7 +31,7 @@ aten = torch.ops.aten
 prims = torch.ops.prims
 
 
-def _use_flex_decoding(query, kv_indices, kernel_options, enable_gqa) -> bool:
+def _use_flex_decoding(query, kv_indices, value, kernel_options, enable_gqa) -> bool:
     """Decide which kernel to use, return true if use flex decoding kernel.
     Note:
         Since the number of splits is calculated based of the the number of batch and head dims
@@ -60,6 +60,15 @@ def _use_flex_decoding(query, kv_indices, kernel_options, enable_gqa) -> bool:
             sympy.Eq(kv_indices.get_size()[1], query.get_size()[1]),
         )
     )
+
+    Hq = query.get_size()[1]
+    Hkv = value.get_size()[1]
+    ratio = Hq // Hkv
+
+    pw_of_two = V.graph.sizevars.guard_or_false(
+        sympy.And(sympy.Gt(ratio, 0), sympy.Eq(ratio & (ratio - 1), 0))
+    )
+
     return (
         not force_flex
         and short_query_length
@@ -67,6 +76,7 @@ def _use_flex_decoding(query, kv_indices, kernel_options, enable_gqa) -> bool:
         and static_num_heads
         and non_zero_length
         and valid_block_mask_num_heads
+        and pw_of_two
     )