[BugFix] FA2 MLA Accuracy Issue (#18807)

Signed-off-by: LucasWilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson
2025-05-28 04:59:39 -04:00
committed by simon-mo
parent 5873877241
commit 5fbbfe9a4c
3 changed files with 16 additions and 8 deletions

View File

@ -143,6 +143,14 @@ void merge_attn_states_launcher(torch::Tensor& output,
const uint pack_size = 16 / sizeof(scalar_t);
TORCH_CHECK(head_size % pack_size == 0,
"headsize must be multiple of pack_size:", pack_size);
TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1,
"output heads must be contiguous in memory");
TORCH_CHECK(
prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1,
"prefix_output heads must be contiguous in memory");
TORCH_CHECK(
suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1,
"suffix_output heads must be contiguous in memory");
float* output_lse_ptr = nullptr;
if (output_lse.has_value()) {
output_lse_ptr = output_lse.value().data_ptr<float>();

View File

@ -1093,10 +1093,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
if isinstance(attn_out, tuple):
attn_out, *rest = attn_out
# unpad if necessary
if self._pad_v:
attn_out = attn_out[..., :v.shape[-1]]
# Remain consistent with old `flash_attn_varlen_func` where there
# is only one output tensor if `return_softmax_lse` is False.
if return_softmax_lse:
@ -1294,6 +1290,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
suffix_lse=suffix_lse,
)
# unpad if necessary
if self._pad_v:
output = output[..., :v.shape[-1]]
return output.flatten(start_dim=-2)
@abstractmethod

View File

@ -653,10 +653,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if isinstance(attn_out, tuple):
attn_out, lse = attn_out[0], attn_out[1]
# unpad if necessary
if self._pad_v:
attn_out = attn_out[..., :v.shape[-1]]
# Remain consistent with old `flash_attn_varlen_func` where there
# is only one output tensor if `return_softmax_lse` is False.
if return_softmax_lse:
@ -839,6 +835,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
suffix_lse=suffix_lse,
)
# unpad if necessary
if self._pad_v:
output = output[..., :v.shape[-1]]
return output.flatten(start_dim=-2)
@abstractmethod