Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[FIXBUG ] Allow disabling rocm_aiter_fa backend for ROCm GPUs not compatible with AITER (#22795)
Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
@@ -2,7 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import ast
 from dataclasses import replace
-from typing import Optional
+from importlib.util import find_spec
+from typing import Optional, Protocol
 
 import numpy as np
 import torch
@@ -20,8 +21,6 @@ from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.platforms import current_platform
 from vllm.utils import is_pin_memory_available
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
-from vllm.v1.attention.backends.rocm_aiter_fa import (
-    AiterFlashAttentionMetadata)
 from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
                                                   TreeAttentionMetadataBuilder)
 from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
@@ -34,6 +33,17 @@ logger = init_logger(__name__)
 PADDING_SLOT_ID = -1
 
 
+class EagleAttentionMetadata(Protocol):
+    # Required attributes
+    num_actual_tokens: int
+    max_query_len: int
+    query_start_loc: torch.Tensor
+    max_seq_len: int
+    seq_lens: torch.Tensor
+    block_table: torch.Tensor
+    slot_mapping: torch.Tensor
+
+
 class EagleProposer:
 
     def __init__(
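The new EagleAttentionMetadata Protocol is used purely for static typing: any backend metadata class that exposes these attributes satisfies it, with no shared base class required. Below is a minimal standalone sketch of the same structural-typing idea, using toy classes rather than vLLM code; note the commit itself only uses the Protocol as a type annotation and does not mark it runtime_checkable.

from typing import Protocol, runtime_checkable

import torch


@runtime_checkable
class MetadataLike(Protocol):
    # Structural interface: any object exposing these attributes matches.
    num_actual_tokens: int
    slot_mapping: torch.Tensor


class FakeBackendMetadata:
    # No inheritance from MetadataLike is needed.
    def __init__(self) -> None:
        self.num_actual_tokens = 4
        self.slot_mapping = torch.zeros(4, dtype=torch.int64)


# runtime_checkable lets isinstance() verify attribute presence only.
assert isinstance(FakeBackendMetadata(), MetadataLike)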
@@ -97,6 +107,20 @@ class EagleProposer:
             dtype=self.dtype,
             device=device)
 
+        # Determine allowed attention backends once during initialization.
+        self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...]
+        if current_platform.is_rocm():
+            rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
+            # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
+            if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"):
+                from vllm.v1.attention.backends.rocm_aiter_fa import (
+                    AiterFlashAttentionMetadata)
+                rocm_types.append(AiterFlashAttentionMetadata)
+            self.allowed_attn_types = tuple(rocm_types)
+        else:
+            self.allowed_attn_types = (FlashAttentionMetadata,
+                                       TreeAttentionMetadata)
+
         # Parse the speculative token tree.
         spec_token_tree = self.speculative_config.speculative_token_tree
         self.tree_choices: list[tuple[int,
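The gating above relies on importlib.util.find_spec, which resolves a module's location without executing the module itself, so a ROCm build without a working AITER install never imports rocm_aiter_fa and therefore cannot fail on it. A standalone sketch of the same pattern follows; the helper name optional_backend_types is hypothetical and written only for this note.

from importlib.util import find_spec


def optional_backend_types(module_path: str, attr: str) -> tuple[type, ...]:
    # find_spec only locates the module; its body is not executed, so a
    # missing optional backend degrades to an empty tuple instead of
    # raising ImportError during engine construction.
    try:
        spec = find_spec(module_path)
    except ModuleNotFoundError:  # a parent package is missing entirely
        spec = None
    if spec is None:
        return ()
    module = __import__(module_path, fromlist=[attr])
    return (getattr(module, attr),)


extra = optional_backend_types(
    "vllm.v1.attention.backends.rocm_aiter_fa", "AiterFlashAttentionMetadata")
print("AITER flash-attention metadata available:", bool(extra))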
@@ -165,7 +189,7 @@
         for layer_name in self.attn_layer_names:
             per_layer_attn_metadata[layer_name] = attn_metadata
         if self.use_cuda_graph and \
-            num_tokens <= self.cudagraph_batch_sizes[-1]:
+                num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
         else:
             num_input_tokens = num_tokens
@@ -225,25 +249,13 @@
         # TODO: Currently, MTP module released by deepseek only has
         # one layer. Adapt this code to support multiple layers once
         # there's a multi-layer MTP module.
-
-        # On ROCm, both AiterFlashAttention and TritonAttention
-        # support multi-token eagle spec decode.
-        if current_platform.is_rocm():
-            assert isinstance(
-                attn_metadata,
-                (TritonAttentionMetadata, AiterFlashAttentionMetadata,
-                 FlashAttentionMetadata))
-        else:
-            # Currently, only FlashAttention supports multi-token eagle spec
-            # decode. This is because the code below makes assumptions about
-            # attn_metadata attributes available.
-            assert isinstance(attn_metadata, FlashAttentionMetadata)
+        assert isinstance(attn_metadata, self.allowed_attn_types)
 
         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
 
         if self.use_cuda_graph and \
-            batch_size <= self.cudagraph_batch_sizes[-1]:
+                batch_size <= self.cudagraph_batch_sizes[-1]:
             input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
         else:
             input_batch_size = batch_size
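With the allow-list computed once in __init__, the per-call platform branching above collapses to a single check: isinstance() accepts a tuple of classes, so adding or removing an optional backend only touches the constructor. A tiny illustration with toy classes, not vLLM types:

class TritonMeta: ...
class FlashMeta: ...
class AiterMeta: ...


# Built once at init time; AiterMeta would be appended only when the
# optional backend module resolves.
allowed_attn_types = (TritonMeta, FlashMeta)

attn_metadata = FlashMeta()
assert isinstance(attn_metadata, allowed_attn_types)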
@@ -449,7 +461,7 @@
                                                      num_tokens, -1)
 
         if self.use_cuda_graph and \
-            num_tokens <= self.cudagraph_batch_sizes[-1]:
+                num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(
                 num_tokens)
         else:
@@ -508,19 +520,19 @@
         """
         # E.g.
         #  common_attn_metadata.query_start_loc{_cpu}:
-        #     [0, q1, q1 + q2, q1 + q2 + q3]
+        #  [0, q1, q1 + q2, q1 + q2 + q3]
         #  common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
         #  num_rejected_tokens: [n1, n2, n3]
         # This function computes the intermediate values:
         #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
         # And returns:
         #  common_attn_metadata.query_start_loc{_cpu}:
-        #     [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
+        #  [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
         #  common_attn_metadata.seq_lens{_cpu}:
-        #     [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
+        #  [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
         #  token_indices: [0, 1, ..., q1 - n1 - 1,
-        #      q1, q1 + 1, ..., q1 + q2 - n2 - 1,
-        #      q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
+        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
+        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
 
         device = common_attn_metadata.query_start_loc.device
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
@@ -564,9 +576,9 @@
         old_query_start_locs_expanded = np.repeat(
             query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
         # Final token indices are:
-        #    [0, 1,                                    // req 1
-        #     q1 + 0, q1 + 1, q1 + 2, q1 + 3,          // req 2
-        #     q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2]   // req 3
+        # [0, 1,                                    // req 1
+        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,          // req 2
+        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2]   // req 3
         token_indices_np = token_offests + old_query_start_locs_expanded
         token_indices = torch.from_numpy(token_indices_np).to(
             device, non_blocking=True)
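The two hunks above only re-align comments, but the transformation they document (dropping each request's rejected tokens and rebuilding query_start_loc, seq_lens and token_indices) is easy to check numerically. The following is a standalone numpy sketch with made-up sizes, not vLLM's implementation; the source's own variable is spelled token_offests, the sketch uses token_offsets.

import numpy as np

# Made-up per-request data: q = [3, 4, 5], n = [1, 2, 0], s = [10, 20, 30].
query_start_loc = np.array([0, 3, 7, 12])      # [0, q1, q1+q2, q1+q2+q3]
seq_lens = np.array([10, 20, 30])              # [s1, s2, s3]
num_rejected_tokens = np.array([1, 2, 0])      # [n1, n2, n3]

query_lens = np.diff(query_start_loc)                        # [q1, q2, q3]
new_num_tokens_per_req = query_lens - num_rejected_tokens    # [2, 2, 5]

# Rebuilt outputs, matching the commented formulas.
new_query_start_loc = np.concatenate(
    ([0], np.cumsum(new_num_tokens_per_req)))                # [0, 2, 4, 9]
new_seq_lens = seq_lens - num_rejected_tokens + 1            # [10, 19, 31]

# token_indices keeps the first (q_i - n_i) positions of each request by
# expanding the old start offsets with np.repeat.
token_offsets = np.concatenate(
    [np.arange(n) for n in new_num_tokens_per_req])
old_starts_expanded = np.repeat(query_start_loc[:-1], new_num_tokens_per_req)
token_indices = token_offsets + old_starts_expanded   # [0 1 3 4 7 8 9 10 11]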
@@ -616,20 +628,18 @@
         target_language_model = target_model
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1 \
-            and self.model.model.embed_tokens.weight.shape \
-            == target_language_model.model.embed_tokens.weight.shape:
+                and self.model.model.embed_tokens.weight.shape \
+                == target_language_model.model.embed_tokens.weight.shape:
             logger.info(
-                "Assuming the EAGLE head shares the same vocab embedding" \
-                " with the target model."
-            )
+                "Assuming the EAGLE head shares the same vocab embedding"
+                " with the target model.")
             del self.model.model.embed_tokens
             self.model.model.embed_tokens = (
                 target_language_model.model.embed_tokens)
         else:
             logger.info(
-                "The EAGLE head's vocab embedding will be loaded separately" \
-                " from the target model."
-            )
+                "The EAGLE head's vocab embedding will be loaded separately"
+                " from the target model.")
 
         # share lm_head with the target model if needed
         # some model definition do not define lm_head explicitly
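Beyond the reflowed log calls, the hunk above shows the sharing pattern: when pipeline parallelism is off and the embedding shapes match, the EAGLE head's embed_tokens is deleted and rebound to the target model's module, so the two models reference one set of weights instead of holding copies. A minimal sketch with a toy nn.Module, hypothetical and not vLLM's model code:

import torch.nn as nn


class ToyLM(nn.Module):
    def __init__(self, vocab_size: int = 32000, hidden: int = 64) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)


target = ToyLM()
draft = ToyLM()

if draft.embed_tokens.weight.shape == target.embed_tokens.weight.shape:
    # Drop the draft's own embedding and alias the target's module instead.
    del draft.embed_tokens
    draft.embed_tokens = target.embed_tokens

# Both attributes now resolve to the same parameter storage.
assert (draft.embed_tokens.weight.data_ptr()
        == target.embed_tokens.weight.data_ptr())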