[Speculators][Speculative Decoding] Add Qwen Eagle3 Support (#21835)

Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com>
This commit is contained in:
Dipika Sikka
2025-08-01 22:43:37 -04:00
committed by GitHub
parent a65f46be5e
commit 9f9c38c392
4 changed files with 46 additions and 11 deletions

View File

@ -6,11 +6,21 @@ import torch
@pytest.mark.parametrize(
"model_path",
[("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717"),
("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
[("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
def test_llama(vllm_runner, example_prompts, model_path):
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens=20)
print(vllm_outputs)
assert vllm_outputs
@pytest.mark.parametrize(
"model_path",
[("nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized")])
def test_qwen(vllm_runner, example_prompts, model_path):
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens=20)
print(vllm_outputs)
assert vllm_outputs

View File

@ -3175,10 +3175,19 @@ class SpeculativeConfig:
"speculative decoding is > 1, but got "
f"{self.disable_by_batch_size=}")
if self.method == "eagle3" and self.target_model_config and \
"llama" not in self.target_model_config.hf_text_config.model_type:
from vllm.transformers_utils.configs import SpeculatorsConfig
eagle3_target_supported = ["llama"]
if self.draft_model_config and isinstance(
self.draft_model_config.hf_config, SpeculatorsConfig):
eagle3_target_supported.append("qwen")
if self.method == "eagle3" and self.target_model_config and not any(
supported_model in
self.target_model_config.hf_text_config.model_type
for supported_model in eagle3_target_supported):
raise ValueError(
"Eagle3 is only supported for Llama models. "
f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501
f"Got {self.target_model_config.hf_text_config.model_type=}")
return self

View File

@ -330,6 +330,8 @@ class Qwen2Model(nn.Module):
else:
self.norm = PPMissingLayer()
self.aux_hidden_state_layers: tuple[int] = tuple()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
@ -350,18 +352,25 @@ class Qwen2Model(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(
positions,
hidden_states,
residual,
)
aux_hidden_states = []
for idx, layer in enumerate(
self.layers[self.start_layer:self.end_layer]):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,

View File

@ -288,6 +288,13 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)