Mirror of https://github.com/vllm-project/vllm.git
Synced 2025-10-20 14:53:52 +08:00
[Bugfix][Speculative Decoding] Extend Eagle quantization config fix to llama_eagle.py (#26590)
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
@@ -12,6 +12,7 @@ from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
@@ -37,6 +38,17 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         del self.input_layernorm
         self.input_layernorm = nn.Identity()
 
+    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
+        """Use drafter's quantization config instead of verifier's."""
+        draft_model_config = vllm_config.speculative_config.draft_model_config
+        draft_load_config = vllm_config.load_config
+
+        return (
+            VllmConfig.get_quantization_config(draft_model_config, draft_load_config)
+            if draft_model_config
+            else None
+        )
+
 
 @support_torch_compile
 class LlamaModel(nn.Module):
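For context, the new get_quant_config override makes the EAGLE drafter resolve quantization from its own draft_model_config instead of inheriting the verifier's. Below is a minimal, self-contained sketch of that resolution pattern; every name in it (ModelCfg, SpecCfg, EngineCfg, resolve_drafter_quant) is an illustrative stand-in, not a vLLM class or API.

# Minimal sketch of the drafter-vs-verifier quantization resolution pattern.
# All types and functions here are illustrative stand-ins, not vLLM APIs.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelCfg:
    name: str
    quantization: Optional[str] = None  # e.g. "fp8", "awq", or None


@dataclass
class SpecCfg:
    draft_model_config: Optional[ModelCfg] = None


@dataclass
class EngineCfg:
    model_config: ModelCfg                      # verifier (target) model
    speculative_config: Optional[SpecCfg] = None


def resolve_drafter_quant(cfg: EngineCfg) -> Optional[str]:
    """Return the drafter's own quantization setting, never the verifier's."""
    spec = cfg.speculative_config
    draft = spec.draft_model_config if spec else None
    # Mirrors get_quant_config(): fall back to None when no draft config exists.
    return draft.quantization if draft else None


if __name__ == "__main__":
    cfg = EngineCfg(
        model_config=ModelCfg("target-llama", quantization="fp8"),
        speculative_config=SpecCfg(ModelCfg("eagle-head", quantization=None)),
    )
    # The unquantized EAGLE head is not forced onto the verifier's fp8 config.
    assert resolve_drafter_quant(cfg) is None
    print("drafter quantization:", resolve_drafter_quant(cfg))

As in the patch itself, the helper returns None when no draft model config is present, so non-speculative runs are unaffected.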