Support llama3 eagle3 head with llama4 verifier (#25961)

Signed-off-by: rahul-tuli <rtuli@redhat.com>
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
This commit is contained in:
Rahul Tuli
2025-10-06 23:26:08 +05:30
committed by GitHub
parent 20db99cc69
commit 05f6846ede
5 changed files with 83 additions and 8 deletions

View File

@ -604,6 +604,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
"""Override to return default layers for Llama
Note: The GPU model runner will override this with layers from
the speculative config if available, providing dynamic configuration.
"""
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)

View File

@ -21,6 +21,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
from vllm.multimodal.inputs import NestedTensors
from .utils import AutoWeightsLoader, maybe_prefix
@ -241,7 +242,12 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
requires_grad=False,
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
is_multimodal: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(

View File

@ -64,7 +64,12 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .interfaces import (
MultiModalEmbeddings,
SupportsEagle3,
SupportsMultiModal,
SupportsPP,
)
from .llama4 import Llama4ForCausalLM
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
from .vision import run_dp_sharded_vision_model
@ -717,7 +722,9 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
info=Mllama4ProcessingInfo,
dummy_inputs=Mllama4DummyInputsBuilder,
)
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
class Llama4ForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3
):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
@ -767,6 +774,22 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model.make_empty_intermediate_tensors
)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
"""Set which layers should output auxiliary hidden states for EAGLE3."""
# Delegate to underlying language model (Llama4ForCausalLM)
assert hasattr(self.language_model, "set_aux_hidden_state_layers")
self.language_model.set_aux_hidden_state_layers(layers)
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
"""Get the layer indices for auxiliary hidden state outputs.
Note: The GPU model runner will override this with layers from
the speculative config if available, providing dynamic configuration.
"""
# Delegate to underlying language model (Llama4ForCausalLM)
assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers")
return self.language_model.get_eagle3_aux_hidden_state_layers()
def _parse_and_validate_image_input(
self, **kwargs: object
) -> Optional[Llama4ImagePatchInputs]:

View File

@ -21,6 +21,10 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
- draft_vocab_size: Size of the draft model's vocabulary
- target_hidden_size: Hidden size of the target model
- norm_before_residual: Whether to apply norm before residual connection
- eagle_aux_hidden_state_layer_ids: List of layer indices from the base
model to use as auxiliary inputs for the Eagle3 drafter. These layers
provide intermediate hidden states that help the drafter make better
predictions. This is the standard field used in Eagle3 checkpoints.
"""
vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
@ -28,3 +32,7 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
vllm_config["target_hidden_size"] = config_dict["target_hidden_size"]
vllm_config["norm_before_residual"] = config_dict.get("norm_before_residual", True)
vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"]
if config_dict.get("eagle_aux_hidden_state_layer_ids"):
vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
"eagle_aux_hidden_state_layer_ids"
]

View File

@ -2943,15 +2943,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logger.info("Loading drafter model...")
self.drafter.load_model(self.model)
if self.use_aux_hidden_state_outputs:
if supports_eagle3(self.model):
self.model.set_aux_hidden_state_layers(
self.model.get_eagle3_aux_hidden_state_layers()
)
else:
if not supports_eagle3(self.model):
raise RuntimeError(
"Model does not support EAGLE3 interface but "
"aux_hidden_state_outputs was requested"
)
# Try to get auxiliary layers from speculative config,
# otherwise use model's default layers
aux_layers = self._get_eagle3_aux_layers_from_config()
if aux_layers:
logger.info(
"Using auxiliary layers from speculative config: %s",
aux_layers,
)
else:
aux_layers = self.model.get_eagle3_aux_hidden_state_layers()
self.model.set_aux_hidden_state_layers(aux_layers)
time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory
logger.info(
@ -3006,6 +3015,30 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.model, self.vllm_config, CUDAGraphMode.NONE, self.device
)
def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
"""Extract Eagle3 auxiliary layer indices from speculative config.
These indices specify which hidden states from the base model should
be used as auxiliary inputs for the Eagle3 drafter model during
speculative decoding.
Returns:
Tuple of layer indices if found in draft model config,
None otherwise.
"""
if not (self.speculative_config and self.speculative_config.draft_model_config):
return None
hf_config = self.speculative_config.draft_model_config.hf_config
if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"):
return None
layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
if layer_ids and isinstance(layer_ids, (list, tuple)):
return tuple(layer_ids)
return None
def reload_weights(self) -> None:
assert getattr(self, "model", None) is not None, (
"Cannot reload weights before model is loaded."