[bugfix] fix early import of flash attention (#12959)
Signed-off-by: youkaichao <youkaichao@gmail.com>
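
As the hunks below show, the shared attention utils module previously computed a module-level constant VLLM_FLASH_ATTN_VERSION inside a try/except that imported vllm.vllm_flash_attn.flash_attn_interface at import time. This commit replaces the constant with a get_flash_attn_version() helper; each attention implementation (FlashAttentionImpl, MLACommonImpl, and the V1 FlashAttentionImpl) now caches the result in __init__ as self.vllm_flash_attn_version and passes it to the kernels as fa_version. Below is a minimal sketch of the import-time vs. call-time pattern being applied; the names heavy_ext, is_version_supported, get_version, and SomeAttentionImpl are placeholders for illustration, not vLLM APIs.

# Before: the probe ran when the module was imported, so merely importing
# the utils module pulled in the compiled flash-attention extension.
#
#   try:
#       from heavy_ext import is_version_supported
#       FA_VERSION = 3 if is_version_supported(3) else 2
#   except (ImportError, AssertionError):
#       FA_VERSION = None

# After: the probe lives inside a function, so the heavy import only happens
# when a caller actually asks for the version.
def get_version():
    try:
        from heavy_ext import is_version_supported  # deferred import
        return 3 if is_version_supported(3) else 2
    except (ImportError, AssertionError):
        return None


class SomeAttentionImpl:
    def __init__(self):
        # Each implementation caches the resolved version once, at
        # construction time, instead of reading a module-level constant.
        self.fa_version = get_version()
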
@@ -14,8 +14,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadataBuilder,
                                               AttentionType)
 from vllm.attention.backends.utils import (
-    PAD_SLOT_ID, VLLM_FLASH_ATTN_VERSION, CommonAttentionState,
-    compute_slot_mapping, compute_slot_mapping_start_idx,
+    PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
+    compute_slot_mapping_start_idx, get_flash_attn_version,
     get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
     is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set,
     is_block_tables_empty)
@@ -640,6 +640,7 @@ class FlashAttentionImpl(AttentionImpl):
                 f"Head size {head_size} is not supported by FlashAttention. "
                 f"Supported head sizes are: {support_head_sizes}.")
         self.attn_type = attn_type
+        self.vllm_flash_attn_version = get_flash_attn_version()
 
     def forward(
         self,
@@ -759,7 +760,7 @@ class FlashAttentionImpl(AttentionImpl):
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
                     out=prefill_output,
-                    fa_version=VLLM_FLASH_ATTN_VERSION,
+                    fa_version=self.vllm_flash_attn_version,
                 )
             else:
                 # prefix-enabled attention
@@ -782,7 +783,7 @@ class FlashAttentionImpl(AttentionImpl):
                     block_table=prefill_meta.block_tables,
                     softcap=logits_soft_cap,
                     out=prefill_output,
-                    fa_version=VLLM_FLASH_ATTN_VERSION,
+                    fa_version=self.vllm_flash_attn_version,
                 )
 
         if decode_meta := attn_metadata.decode_metadata:
@@ -811,7 +812,7 @@ class FlashAttentionImpl(AttentionImpl):
                     softcap=logits_soft_cap,
                     block_table=decode_meta.block_tables,
                     out=decode_output,
-                    fa_version=VLLM_FLASH_ATTN_VERSION,
+                    fa_version=self.vllm_flash_attn_version,
                 )
             else:
                 # Use flash_attn_with_kvcache for normal decoding.
@@ -832,7 +833,7 @@ class FlashAttentionImpl(AttentionImpl):
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
                     out=decode_output.unsqueeze(1),
-                    fa_version=VLLM_FLASH_ATTN_VERSION,
+                    fa_version=self.vllm_flash_attn_version,
                 )
         return output
 
@@ -12,7 +12,7 @@ from vllm import envs
 from vllm.attention.backends.abstract import (AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl, T)
-from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
+from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.distributed import (get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -181,6 +181,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
+        self.vllm_flash_attn_version = get_flash_attn_version()
 
     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
@@ -515,7 +516,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             max_seqlen_k=max_prefill_seq_len,
             softmax_scale=self.scale,
             causal=True,
-            fa_version=VLLM_FLASH_ATTN_VERSION,
+            fa_version=self.vllm_flash_attn_version,
         )
         attn_output = attn_output\
             .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
@@ -587,11 +587,11 @@ def get_num_prefill_decode_query_kv_tokens(
             num_decode_query_tokens)
 
 
-try:
-    from vllm.vllm_flash_attn.flash_attn_interface import (
-        fa_version_unsupported_reason, is_fa_version_supported)
+def get_flash_attn_version():
+    try:
+        from vllm.vllm_flash_attn.flash_attn_interface import (
+            fa_version_unsupported_reason, is_fa_version_supported)
 
-    def flash_attn_version():
         # if hopper default to FA3, otherwise stick to FA2 for now
         # TODO(lucas): profile FA3 on ampere to see if it makes sense to
         # use FA3 as default for both
@@ -610,7 +610,5 @@ try:
 
         assert is_fa_version_supported(fa_version)
         return fa_version
-
-    VLLM_FLASH_ATTN_VERSION = flash_attn_version()
-except (ImportError, AssertionError):
-    VLLM_FLASH_ATTN_VERSION = None
+    except (ImportError, AssertionError):
+        return None
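
For orientation, here is a rough caller's-eye sketch of what the new helper is meant to do, assuming only what the hunk above shows (the module path, the function name, and the None fallback); this snippet is not part of the commit and the sys.modules check is only indicative.

import sys

# With the module-level try/except gone, importing the utils module by itself
# is not supposed to import the flash-attention interface anymore.
from vllm.attention.backends.utils import get_flash_attn_version

print("interface imported at module load:",
      "vllm.vllm_flash_attn.flash_attn_interface" in sys.modules)

# The import and the version probe happen here, on the first call.
# Expected result: 3 or 2 when the interface is usable, None otherwise.
fa_version = get_flash_attn_version()
print("resolved flash-attn version:", fa_version)
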
@@ -10,7 +10,7 @@ import triton.language as tl
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
-from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
+from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.utils import cdiv
 from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -132,6 +132,7 @@ class FlashAttentionImpl(AttentionImpl):
                 "encoder/decoder cross-attention "
                 "are not implemented for "
                 "FlashAttentionImpl")
+        self.vllm_flash_attn_version = get_flash_attn_version()
 
     def forward(
         self,
@@ -205,7 +206,7 @@ class FlashAttentionImpl(AttentionImpl):
             window_size=self.sliding_window,
             block_table=attn_metadata.block_table,
             softcap=self.logits_soft_cap,
-            fa_version=VLLM_FLASH_ATTN_VERSION,
+            fa_version=self.vllm_flash_attn_version,
         )
         return output
 
@@ -227,7 +228,7 @@ class FlashAttentionImpl(AttentionImpl):
             logits_soft_cap=self.logits_soft_cap,
             block_table=attn_metadata.block_table,
             common_prefix_len=attn_metadata.common_prefix_len,
-            fa_version=VLLM_FLASH_ATTN_VERSION,
+            fa_version=self.vllm_flash_attn_version,
         )
         return output