Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[Speculators][Speculative Decoding] Add Qwen Eagle3 Support (#21835)
Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com>
@@ -6,11 +6,21 @@ import torch
 @pytest.mark.parametrize(
     "model_path",
-    [("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717"),
-     ("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
+    [("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
 def test_llama(vllm_runner, example_prompts, model_path):
     with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts,
                                                   max_tokens=20)
         print(vllm_outputs)
         assert vllm_outputs
+
+
+@pytest.mark.parametrize(
+    "model_path",
+    [("nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized")])
+def test_qwen(vllm_runner, example_prompts, model_path):
+    with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                  max_tokens=20)
+        print(vllm_outputs)
+        assert vllm_outputs
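Note: the new test_qwen case follows the existing test_llama pattern, passing the speculators-converted Eagle3 checkpoint directly as the model path. A minimal standalone sketch of the same flow with the offline LLM API, using the checkpoint name from the test; the prompt and sampling settings are illustrative assumptions, not part of the change:

from vllm import LLM, SamplingParams

# Sketch only: mirrors what test_qwen exercises through vllm_runner.
llm = LLM(model="nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized",
          dtype="bfloat16")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)
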
@@ -3175,10 +3175,19 @@ class SpeculativeConfig:
                     "speculative decoding is > 1, but got "
                     f"{self.disable_by_batch_size=}")
 
-        if self.method == "eagle3" and self.target_model_config and \
-                "llama" not in self.target_model_config.hf_text_config.model_type:
+        from vllm.transformers_utils.configs import SpeculatorsConfig
+
+        eagle3_target_supported = ["llama"]
+        if self.draft_model_config and isinstance(
+                self.draft_model_config.hf_config, SpeculatorsConfig):
+            eagle3_target_supported.append("qwen")
+
+        if self.method == "eagle3" and self.target_model_config and not any(
+                supported_model in
+                self.target_model_config.hf_text_config.model_type
+                for supported_model in eagle3_target_supported):
             raise ValueError(
-                "Eagle3 is only supported for Llama models. "
+                f"Eagle3 is only supported for {eagle3_target_supported} models. "  # noqa: E501
                 f"Got {self.target_model_config.hf_text_config.model_type=}")
 
         return self
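Note on the check above: eagle3_target_supported starts as ["llama"], "qwen" is appended only when the draft model carries a SpeculatorsConfig, and the comparison is a substring match against the target's model_type. A small self-contained illustration of that matching logic; the helper below is illustrative, not part of the change:

# Illustrative reimplementation of the substring check in the diff above.
def eagle3_target_ok(model_type: str, draft_is_speculators: bool) -> bool:
    eagle3_target_supported = ["llama"]
    if draft_is_speculators:
        eagle3_target_supported.append("qwen")
    return any(supported in model_type for supported in eagle3_target_supported)

assert eagle3_target_ok("llama", draft_is_speculators=False)
assert eagle3_target_ok("qwen3", draft_is_speculators=True)    # "qwen" matches "qwen3"
assert not eagle3_target_ok("qwen3", draft_is_speculators=False)
assert not eagle3_target_ok("mistral", draft_is_speculators=True)
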
@@ -330,6 +330,8 @@ class Qwen2Model(nn.Module):
         else:
             self.norm = PPMissingLayer()
 
+        self.aux_hidden_state_layers: tuple[int] = tuple()
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -350,18 +352,25 @@ class Qwen2Model(nn.Module):
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for layer in self.layers[self.start_layer:self.end_layer]:
-            hidden_states, residual = layer(
-                positions,
-                hidden_states,
-                residual,
-            )
+
+        aux_hidden_states = []
+        for idx, layer in enumerate(
+                self.layers[self.start_layer:self.end_layer]):
+            if idx in self.aux_hidden_state_layers:
+                aux_hidden_states.append(hidden_states + residual)
+            hidden_states, residual = layer(positions, hidden_states, residual)
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
                 "residual": residual
             })
 
         hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
+
         return hidden_states
 
     def load_weights(self, weights: Iterable[tuple[str,
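Note: the loop above stashes hidden_states + residual (the activation flowing into the selected layer) at each index listed in aux_hidden_state_layers, and forward then returns those auxiliary states alongside the final hidden state for the Eagle3 draft model to consume. A toy, self-contained sketch of just that collection pattern; the shapes, the layer stub, and the chosen indices are stand-ins, not values from the change:

import torch

aux_hidden_state_layers = (2, 5)          # stand-in for self.aux_hidden_state_layers
hidden_states = torch.zeros(4, 8)         # [num_tokens, hidden_size] stand-in
residual = torch.zeros(4, 8)

def toy_layer(hidden_states, residual):   # stand-in for a Qwen2 decoder layer
    return hidden_states + 1.0, residual + 0.5

aux_hidden_states = []
for idx in range(8):                      # stand-in for self.layers[start:end]
    if idx in aux_hidden_state_layers:
        aux_hidden_states.append(hidden_states + residual)
    hidden_states, residual = toy_layer(hidden_states, residual)

print(len(aux_hidden_states))             # two captured activations for the draft model
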
@@ -288,6 +288,13 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
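Note: get_eagle3_aux_hidden_state_layers picks an early, a middle, and a late decoder layer whose hidden states feed the Eagle3 draft model. As a worked example, assuming a 36-layer Qwen3-8B target (the layer count is an assumption, not stated in this diff), the default heuristic selects layers 2, 18, and 33:

# Worked example of the default layer-selection heuristic above.
num_layers = 36                               # assumed depth of the Qwen3-8B target
print((2, num_layers // 2, num_layers - 3))   # (2, 18, 33)
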