[Bugfix] Add missing sink tensor into flash attn cascade attn implementation (#26325)
@@ -607,6 +607,7 @@ class FlashAttentionImpl(AttentionImpl):
                 q_descale=layer._q_scale,
                 k_descale=layer._k_scale,
                 v_descale=layer._v_scale,
+                s_aux=self.sinks,
             )
         return output

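For context on what the added s_aux argument carries: an attention sink contributes one extra logit per head to the softmax normalization without contributing a value vector. The sketch below is a minimal reference illustration, not vLLM's kernel; the function name sdpa_with_sink and the tensor layouts are assumptions for illustration only.

import torch

def sdpa_with_sink(q, k, v, sinks, scale):
    # q, v: [num_heads, q_len, head_dim]; k: [num_heads, k_len, head_dim]
    # sinks: [num_heads], one sink logit per head (assumed layout).
    scores = torch.einsum("hqd,hkd->hqk", q, k) * scale
    # Treat the sink as a virtual key position appended to every query row.
    sink = sinks.view(-1, 1, 1).expand(-1, q.shape[1], 1)
    probs = torch.softmax(torch.cat([scores, sink], dim=-1), dim=-1)
    # The sink absorbs probability mass but has no value vector,
    # so its column is dropped before the weighted sum over values.
    return torch.einsum("hqk,hkd->hqd", probs[..., :-1], v)
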
@@ -767,6 +768,7 @@ def cascade_attention(
     q_descale: Optional[torch.Tensor] = None,
     k_descale: Optional[torch.Tensor] = None,
     v_descale: Optional[torch.Tensor] = None,
+    s_aux: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     assert alibi_slopes is None, "Cascade attention does not support ALiBi."
     # TODO: Support sliding window.

@@ -801,6 +803,9 @@ def cascade_attention(
         q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
         k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
         v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
+        # s_aux is incorporated into prefix_lse inside the GPU kernel,
+        # enabling its effect during the final attention merge.
+        s_aux=s_aux,
     )

     descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
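The comment in the hunk above refers to the log-sum-exp merge that cascade attention performs between the shared-prefix pass and the per-request suffix pass. Below is a minimal sketch of such a merge; it is not the GPU kernel, and the name merge_attn_states plus the tensor layouts are assumptions. The point it illustrates: once the kernel has folded the sink's mass into prefix_lse, the merge itself needs no extra handling of s_aux.

import torch

def merge_attn_states(out_prefix, lse_prefix, out_suffix, lse_suffix):
    # out_*: [num_tokens, num_heads, head_dim]; lse_*: [num_tokens, num_heads]
    # Each out_* is already normalized by its own softmax denominator,
    # and lse_* is the log of that denominator (sink mass included for the prefix).
    max_lse = torch.maximum(lse_prefix, lse_suffix)
    w_prefix = torch.exp(lse_prefix - max_lse).unsqueeze(-1)
    w_suffix = torch.exp(lse_suffix - max_lse).unsqueeze(-1)
    return (out_prefix * w_prefix + out_suffix * w_suffix) / (w_prefix + w_suffix)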