[Bugfix][Speculative Decoding] Extend Eagle quantization config fix to llama_eagle.py (#26590)

Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Author: Rahul Tuli
Date: 2025-10-13 22:47:13 +05:30
Committed by: GitHub
Parent: 134f70b3ed
Commit: e3b90c1ba2


@@ -12,6 +12,7 @@ from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
@@ -37,6 +38,17 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         del self.input_layernorm
         self.input_layernorm = nn.Identity()
+
+    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
+        """Use drafter's quantization config instead of verifier's."""
+        draft_model_config = vllm_config.speculative_config.draft_model_config
+        draft_load_config = vllm_config.load_config
+
+        return (
+            VllmConfig.get_quantization_config(draft_model_config, draft_load_config)
+            if draft_model_config
+            else None
+        )
 
 
 @support_torch_compile
 class LlamaModel(nn.Module):
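
For readers skimming the diff, here is a minimal, self-contained sketch of the behavior the override is meant to guarantee: the Eagle drafter layer resolves its quantization from the speculative config's draft_model_config instead of the verifier's config, and falls back to None when no draft model config is present. The Sketch* classes and the drafter_quant_config helper below are illustrative stand-ins invented for this example, not vLLM's real types or API.

# Illustrative sketch only: stand-in types, not vLLM's real config classes.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchModelConfig:
    # e.g. "fp8" for a quantized checkpoint, None for an unquantized one
    quantization: Optional[str] = None


@dataclass
class SketchSpeculativeConfig:
    draft_model_config: Optional[SketchModelConfig] = None


@dataclass
class SketchVllmConfig:
    speculative_config: SketchSpeculativeConfig
    # quantization of the verifier (target) model
    target_quantization: Optional[str] = None

    @staticmethod
    def get_quantization_config(model_config: SketchModelConfig) -> Optional[str]:
        return model_config.quantization


def drafter_quant_config(cfg: SketchVllmConfig) -> Optional[str]:
    """Mirror of the override in the diff: prefer the drafter's own config."""
    draft = cfg.speculative_config.draft_model_config
    return SketchVllmConfig.get_quantization_config(draft) if draft else None


# Verifier quantized (fp8) but drafter unquantized: the Eagle layer should see
# None, not the verifier's "fp8".
cfg = SketchVllmConfig(
    speculative_config=SketchSpeculativeConfig(
        draft_model_config=SketchModelConfig(quantization=None)
    ),
    target_quantization="fp8",
)
assert drafter_quant_config(cfg) is None

# No draft model config at all: the override falls back to None as well.
cfg_no_draft = SketchVllmConfig(speculative_config=SketchSpeculativeConfig())
assert drafter_quant_config(cfg_no_draft) is None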