Support llama3 eagle3 head with llama4 verifier (#25961)
Signed-off-by: rahul-tuli <rtuli@redhat.com>
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
@@ -604,6 +604,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
         self.model.aux_hidden_state_layers = layers
 
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        """Override to return default layers for Llama
+
+        Note: The GPU model runner will override this with layers from
+        the speculative config if available, providing dynamic configuration.
+        """
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
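With this default, a verifier with 32 decoder layers hands the drafter hidden states from an early, a middle, and a near-final layer. A minimal sketch of the rule, with num_layers standing in for len(self.model.layers):

def default_eagle3_aux_layers(num_layers: int) -> tuple[int, ...]:
    # Same arithmetic as the override above: early, middle, near-final.
    return (2, num_layers // 2, num_layers - 3)

assert default_eagle3_aux_layers(32) == (2, 16, 29)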
@@ -21,6 +21,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
+from vllm.multimodal.inputs import NestedTensors
 
 from .utils import AutoWeightsLoader, maybe_prefix
@@ -241,7 +242,12 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
             requires_grad=False,
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[NestedTensors] = None,
+        is_multimodal: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
     def forward(
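The widened signature keeps the text-only draft head call-compatible with the multimodal verifier's embedding plumbing: callers may pass multimodal arguments, and the head ignores them. An illustrative call (draft_head and mm_embeds are hypothetical names, not vLLM variables):

# Both calls resolve to plain token embedding; the extra kwargs are
# accepted for interface parity and ignored by the Eagle3 head.
embeds = draft_head.get_input_embeddings(input_ids)
embeds = draft_head.get_input_embeddings(input_ids, multimodal_embeddings=mm_embeds)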
@@ -64,7 +64,12 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .interfaces import (
+    MultiModalEmbeddings,
+    SupportsEagle3,
+    SupportsMultiModal,
+    SupportsPP,
+)
 from .llama4 import Llama4ForCausalLM
 from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
 from .vision import run_dp_sharded_vision_model
@@ -717,7 +722,9 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
     info=Mllama4ProcessingInfo,
     dummy_inputs=Mllama4DummyInputsBuilder,
 )
-class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+class Llama4ForConditionalGeneration(
+    nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3
+):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -767,6 +774,22 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
             self.language_model.make_empty_intermediate_tensors
         )
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        """Set which layers should output auxiliary hidden states for EAGLE3."""
+        # Delegate to underlying language model (Llama4ForCausalLM)
+        assert hasattr(self.language_model, "set_aux_hidden_state_layers")
+        self.language_model.set_aux_hidden_state_layers(layers)
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        """Get the layer indices for auxiliary hidden state outputs.
+
+        Note: The GPU model runner will override this with layers from
+        the speculative config if available, providing dynamic configuration.
+        """
+        # Delegate to underlying language model (Llama4ForCausalLM)
+        assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers")
+        return self.language_model.get_eagle3_aux_hidden_state_layers()
+
     def _parse_and_validate_image_input(
         self, **kwargs: object
     ) -> Optional[Llama4ImagePatchInputs]:
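Because both hooks simply forward to self.language_model, the multimodal wrapper passes the runner's supports_eagle3 check exactly like a bare decoder. A self-contained sketch of how such a duck-typed check can be expressed (a stand-in protocol for illustration, not vLLM's actual interfaces module):

from typing import Protocol, runtime_checkable

@runtime_checkable
class SupportsEagle3Like(Protocol):
    # Stand-in for the SupportsEagle3 interface the wrapper now satisfies:
    # any object exposing both hooks passes an isinstance() check.
    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: ...
    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: ...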
@@ -21,6 +21,10 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
     - draft_vocab_size: Size of the draft model's vocabulary
     - target_hidden_size: Hidden size of the target model
     - norm_before_residual: Whether to apply norm before residual connection
+    - eagle_aux_hidden_state_layer_ids: List of layer indices from the base
+      model to use as auxiliary inputs for the Eagle3 drafter. These layers
+      provide intermediate hidden states that help the drafter make better
+      predictions. This is the standard field used in Eagle3 checkpoints.
     """
 
     vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
@@ -28,3 +32,7 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
         vllm_config["target_hidden_size"] = config_dict["target_hidden_size"]
     vllm_config["norm_before_residual"] = config_dict.get("norm_before_residual", True)
     vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"]
+    if config_dict.get("eagle_aux_hidden_state_layer_ids"):
+        vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
+            "eagle_aux_hidden_state_layer_ids"
+        ]
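Taken together, the two hunks above translate an Eagle3 checkpoint config into vLLM's draft-model config, now carrying the auxiliary layer IDs through. An illustrative before/after (field values are made up):

config_dict = {
    "draft_vocab_size": 32000,
    "target_hidden_size": 8192,
    "eagle_aux_hidden_state_layer_ids": [1, 23, 44],
}
vllm_config: dict = {}
update_eagle3(config_dict, vllm_config)
# vllm_config["architectures"] == ["Eagle3LlamaForCausalLM"]
# vllm_config["norm_before_residual"] == True  (default when absent)
# vllm_config["eagle_aux_hidden_state_layer_ids"] == [1, 23, 44]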
@@ -2943,15 +2943,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 logger.info("Loading drafter model...")
                 self.drafter.load_model(self.model)
             if self.use_aux_hidden_state_outputs:
-                if supports_eagle3(self.model):
-                    self.model.set_aux_hidden_state_layers(
-                        self.model.get_eagle3_aux_hidden_state_layers()
-                    )
-                else:
+                if not supports_eagle3(self.model):
                     raise RuntimeError(
                         "Model does not support EAGLE3 interface but "
                         "aux_hidden_state_outputs was requested"
                     )
+
+                # Try to get auxiliary layers from speculative config,
+                # otherwise use model's default layers
+                aux_layers = self._get_eagle3_aux_layers_from_config()
+                if aux_layers:
+                    logger.info(
+                        "Using auxiliary layers from speculative config: %s",
+                        aux_layers,
+                    )
+                else:
+                    aux_layers = self.model.get_eagle3_aux_hidden_state_layers()
+
+                self.model.set_aux_hidden_state_layers(aux_layers)
         time_after_load = time.perf_counter()
         self.model_memory_usage = m.consumed_memory
         logger.info(
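The net effect is a two-level precedence: layer IDs declared in the draft checkpoint win, and the verifier's built-in defaults are the fallback. A compact sketch of that resolution (toy stand-ins, not GPUModelRunner itself):

def resolve_aux_layers(from_config, model_defaults) -> tuple[int, ...]:
    # Checkpoint-declared layers take precedence; model defaults otherwise.
    return tuple(from_config) if from_config else tuple(model_defaults)

assert resolve_aux_layers([1, 23, 44], (2, 16, 29)) == (1, 23, 44)
assert resolve_aux_layers(None, (2, 16, 29)) == (2, 16, 29)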
@@ -3006,6 +3015,30 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.model, self.vllm_config, CUDAGraphMode.NONE, self.device
         )
 
+    def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
+        """Extract Eagle3 auxiliary layer indices from speculative config.
+
+        These indices specify which hidden states from the base model should
+        be used as auxiliary inputs for the Eagle3 drafter model during
+        speculative decoding.
+
+        Returns:
+            Tuple of layer indices if found in draft model config,
+            None otherwise.
+        """
+        if not (self.speculative_config and self.speculative_config.draft_model_config):
+            return None
+
+        hf_config = self.speculative_config.draft_model_config.hf_config
+        if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"):
+            return None
+
+        layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
+        if layer_ids and isinstance(layer_ids, (list, tuple)):
+            return tuple(layer_ids)
+
+        return None
+
     def reload_weights(self) -> None:
         assert getattr(self, "model", None) is not None, (
            "Cannot reload weights before model is loaded."
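The helper is deliberately defensive: a missing speculative config, a missing attribute, or an empty or non-sequence value all yield None, which triggers the fallback to the model's defaults. A quick behavioral sketch against mocked config objects (SimpleNamespace stands in for the real config types):

from types import SimpleNamespace

draft = SimpleNamespace(
    hf_config=SimpleNamespace(eagle_aux_hidden_state_layer_ids=[1, 23, 44])
)
spec = SimpleNamespace(draft_model_config=draft)
# Following the helper's checks: attribute present, non-empty list -> tuple.
ids = getattr(spec.draft_model_config.hf_config, "eagle_aux_hidden_state_layer_ids", None)
result = tuple(ids) if isinstance(ids, (list, tuple)) and ids else None
assert result == (1, 23, 44)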