[Bugfix] Add missing sink tensor into flash attn cascade attn implementation (#26325)

Author: Pei-Lun Liao
Date: 2025-10-07 11:56:39 -07:00
Committed by: GitHub
Parent: 8f36850f73
Commit: eb577e4655

@@ -607,6 +607,7 @@ class FlashAttentionImpl(AttentionImpl):
                 q_descale=layer._q_scale,
                 k_descale=layer._k_scale,
                 v_descale=layer._v_scale,
+                s_aux=self.sinks,
             )
         return output
@@ -767,6 +768,7 @@ def cascade_attention(
     q_descale: Optional[torch.Tensor] = None,
     k_descale: Optional[torch.Tensor] = None,
     v_descale: Optional[torch.Tensor] = None,
+    s_aux: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     assert alibi_slopes is None, "Cascade attention does not support ALiBi."
     # TODO: Support sliding window.
@@ -801,6 +803,9 @@ def cascade_attention(
         q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
         k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
         v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
+        # s_aux is incorporated into prefix_lse inside the GPU kernel,
+        # enabling its effect during the final attention merge.
+        s_aux=s_aux,
     )
     descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
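
The comment added in the last hunk says the kernel folds s_aux into prefix_lse, so the sink's probability mass is already accounted for when the prefix and suffix attention outputs are merged. As a rough host-side reference (the helper name, shapes, and the fact that this runs on CPU are illustrative assumptions, not the actual GPU kernel), the equivalent adjustment to a partial attention result looks like this:

import torch


def apply_sink_to_partial_attn(
    out: torch.Tensor,    # [num_tokens, num_heads, head_dim] partial attention output
    lse: torch.Tensor,    # [num_heads, num_tokens] log-sum-exp of its attention logits
    s_aux: torch.Tensor,  # [num_heads] learned sink logits
) -> tuple[torch.Tensor, torch.Tensor]:
    # The sink adds exp(s_aux) to the softmax denominator but contributes no
    # value vector, so the LSE grows to logaddexp(lse, s_aux) and the output
    # shrinks by exp(old_lse - new_lse).
    new_lse = torch.logaddexp(lse, s_aux.to(lse.dtype).unsqueeze(-1))
    scale = torch.exp(lse - new_lse).transpose(0, 1).unsqueeze(-1)  # [num_tokens, num_heads, 1]
    return out * scale, new_lse

Because the adjusted prefix_lse already carries the sink's mass, the standard LSE-weighted merge of the prefix and suffix outputs down-weights both halves by the sink's share of probability, which is why passing s_aux only to the prefix attention call is sufficient.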