Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Bugfix] support tie_word_embeddings for all models (#5724)
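The fix applies one pattern across the model classes below: when `config.tie_word_embeddings` is enabled, the output projection reuses the input embedding weight instead of keeping (or allocating) a separate `lm_head`; when it is disabled, a standalone `ParallelLMHead` is created. A minimal sketch of the tied case in plain PyTorch, with hypothetical `TinyCausalLM`, `embed_tokens`, and `lm_head` names that are not part of vLLM:

    import torch
    import torch.nn as nn

    class TinyCausalLM(nn.Module):
        """Toy model illustrating the tie_word_embeddings pattern (not vLLM code)."""

        def __init__(self, vocab_size: int, hidden_size: int,
                     tie_word_embeddings: bool) -> None:
            super().__init__()
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
            if tie_word_embeddings:
                # Share one tensor: both modules see the same parameter and
                # only one (vocab_size, hidden_size) matrix is stored.
                self.lm_head.weight = self.embed_tokens.weight

        def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
            hidden = self.embed_tokens(input_ids)
            return self.lm_head(hidden)

    model = TinyCausalLM(vocab_size=32, hidden_size=8, tie_word_embeddings=True)
    assert model.lm_head.weight is model.embed_tokens.weight

With tying enabled, `lm_head.weight is embed_tokens.weight` holds, which is what the per-model `if self.config.tie_word_embeddings:` branches in the hunks below establish.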
@@ -414,6 +414,8 @@ class ArcticForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
         )
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
         self.num_experts = config.num_local_experts
         self.num_experts_per_tok = config.num_experts_per_tok
         self.unpadded_vocab_size = config.vocab_size
@@ -331,6 +331,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -821,6 +821,8 @@ class BartForConditionalGeneration(nn.Module):
                  lora_config: Optional[LoRAConfig] = None):

         super().__init__()
+        # currently all existing BART models have `tie_word_embeddings` enabled
+        assert config.tie_word_embeddings
         self.config = config
         self.model = BartModel(config,
                                cache_config,
@@ -494,6 +494,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):

         super().__init__()

+        # currently all existing BLIP-2 models have `tie_word_embeddings`
+        # enabled
+        assert config.tie_word_embeddings
         self.config = config
         self.multimodal_config = multimodal_config
@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
+    ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput
@@ -276,7 +276,12 @@ class BloomForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.transformer = BloomModel(config, cache_config, quant_config)
-        self.lm_head = self.transformer.word_embeddings
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.transformer.word_embeddings
+        else:
+            self.lm_head = ParallelLMHead(self.config.vocab_size,
+                                          self.config.hidden_size)
+
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -356,6 +356,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
         self.max_position_embeddings = getattr(config, "max_sequence_length",
                                                8192)
         self.transformer = ChatGLMModel(config, cache_config, quant_config)
+        if self.config.tie_word_embeddings:
+            self.transformer.output_layer.weight = (
+                self.transformer.embedding.weight)
         self.lm_head = self.transformer.output_layer
         self.logits_processor = LogitsProcessor(config.padded_vocab_size)
         self.sampler = Sampler()
@@ -321,6 +321,9 @@ class CohereForCausalLM(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+        # currently all existing command R models have `tie_word_embeddings`
+        # enabled
+        assert config.tie_word_embeddings
         self.unpadded_vocab_size = config.vocab_size
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
@@ -362,6 +362,9 @@ class DbrxForCausalLM(nn.Module):
     ):
         super().__init__()
         self.config = config
+        if config.tie_word_embeddings:
+            raise ValueError(
+                "tie_word_embeddings is not supported for Dbrx models.")
         self.quant_config = quant_config
         self.unpadded_vocab_size = config.vocab_size
         self.transformer = DbrxModel(config, cache_config, quant_config)
@@ -380,6 +380,8 @@ class DeepseekForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -331,6 +331,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
         super().__init__()

         self.config = config
+        # currently all existing Gemma models have `tie_word_embeddings` enabled
+        assert config.tie_word_embeddings
         self.lora_config = lora_config

         self.quant_config = quant_config
@@ -323,6 +323,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
         del lora_config  # Unused.
         super().__init__()
         self.config = config
+        # currently all existing Gemma models have `tie_word_embeddings` enabled
+        assert config.tie_word_embeddings
         self.quant_config = quant_config
         self.model = Gemma2Model(config, cache_config, quant_config)
         self.logits_processor = LogitsProcessor(
@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
+    ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput
@@ -249,7 +249,11 @@ class GPT2LMHeadModel(nn.Module):
                                      cache_config,
                                      quant_config,
                                      prefix="transformer")
-        self.lm_head = self.transformer.wte
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.transformer.wte
+        else:
+            self.lm_head = ParallelLMHead(self.config.vocab_size,
+                                          self.config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
+    ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput
@@ -259,7 +259,13 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
         self.quant_config = quant_config
         self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
                                            lora_config)
-        self.lm_head = self.transformer.wte
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.transformer.wte
+        else:
+            self.lm_head = ParallelLMHead(
+                self.transformer.vocab_size,
+                self.transformer.embed_dim,
+                org_num_embeddings=self.config.vocab_size)
         self.unpadded_vocab_size = config.vocab_size
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
@@ -230,7 +230,7 @@ class GPTNeoXForCausalLM(nn.Module):

     def __init__(
         self,
-        config,
+        config: GPTNeoXConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
@@ -243,6 +243,8 @@ class GPTNeoXForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
         )
+        if self.config.tie_word_embeddings:
+            self.embed_out.weight = self.gpt_neox.embed_in.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -264,6 +264,8 @@ class InternLM2ForCausalLM(nn.Module):
         self.output = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
+        if self.config.tie_word_embeddings:
+            self.output.weight = self.model.tok_embeddings.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -37,7 +37,7 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
+    ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput
@@ -291,7 +291,11 @@ class JAISLMHeadModel(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.transformer = JAISModel(config, cache_config, quant_config)
-        self.lm_head = self.transformer.wte
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.transformer.wte
+        else:
+            self.lm_head = ParallelLMHead(self.config.vocab_size,
+                                          self.config.hidden_size)
         if hasattr(config, "width_scale"):
             self.output_logits_scale = config.width_scale
         else:
@@ -313,7 +313,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
     278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

     To reserve space in KV cache, we have to insert placeholder tokens
-    before they are inputted to the model, so the input processor prepends
+    before they are inputted to the model, so the input processor prepends
     additional image tokens (denoted as `32000`), resulting in:
     `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
     29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
@@ -331,7 +331,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
         input_ids: Flattened (concatenated) input_ids corresponding to a
             batch.
         pixel_values: The pixels in each input image.
-
+
     See also:
         :class:`LlavaImageInputs`
     """
@@ -545,7 +545,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
     9047, 13566, 29901]`.

     To reserve space in KV cache, we have to insert placeholder tokens
-    before they are inputted to the model, so the input processor prepends
+    before they are inputted to the model, so the input processor prepends
     additional image tokens (denoted as `32000`), resulting in:
     `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
     29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
@@ -566,7 +566,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
             batch.
         pixel_values: The pixels in each grid patch for each input image.
         image_sizes: The original `(height, width)` for each input image.
-
+
     See also:
         :class:`LlavaNextImageInputs`
     """
@@ -496,6 +496,10 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
+        # All MiniCPM-V models disable `tie_word_embeddings` but
+        # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
+        # check `tie_word_embeddings` until vLLM integrate MiniCPM-V model
+        # and config class
         self.config = config
         self.multimodal_config = multimodal_config
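For context on the comment above: `transformers.PretrainedConfig` defaults `tie_word_embeddings` to True, so a config class that never sets the field still reports tying as enabled. A quick check, assuming the `transformers` package is installed:

    from transformers import PretrainedConfig

    # The base config enables weight tying unless a model config overrides it.
    assert PretrainedConfig().tie_word_embeddings

This is why the MiniCPM-V path cannot simply assert on the flag until a dedicated config class is integrated.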
@@ -359,6 +359,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
             if not lora_config else lora_config.lora_vocab_padding_size,
             quant_config=quant_config,
         )
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
         self.sampler = Sampler()
@@ -347,6 +347,8 @@ class MixtralForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
+    ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput
@@ -307,7 +307,11 @@ class OPTForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = OPTModel(config, cache_config, quant_config)
-        self.lm_head = self.model.decoder.embed_tokens
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.decoder.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(config.vocab_size,
+                                          config.word_embed_proj_dim)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -262,6 +262,8 @@ class OrionForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -260,6 +260,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
         super().__init__()

         self.config = config
+        # lm_head use bias, cannot share word embeddings
+        assert not config.tie_word_embeddings
         self.lora_config = lora_config

         self.quant_config = quant_config
@@ -368,6 +368,8 @@ class Phi3SmallForCausalLM(nn.Module):
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
             quant_config=quant_config,
         )
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -449,4 +451,3 @@ class Phi3SmallForCausalLM(nn.Module):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
-        self.lm_head.weight.data.copy_(self.model.embed_tokens.weight.data)
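The removed line above copied the embedding weights into `lm_head` after every weight load; the new `if self.config.tie_word_embeddings:` branch in `__init__` ties the two parameters by reference instead, so no post-load copy is needed. A small sketch of the difference between the two approaches, using hypothetical tensor names:

    import torch

    embed = torch.nn.Parameter(torch.randn(8, 4))

    # Tying by reference: both names point at the same storage, so later
    # updates to `embed` are automatically visible through `tied`.
    tied = embed
    assert tied is embed

    # Copying data: separate storage that must be re-synced if `embed` changes.
    copied = torch.nn.Parameter(torch.empty(8, 4))
    copied.data.copy_(embed.data)
    assert copied is not embed
    assert torch.equal(copied, embed)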
@@ -477,6 +477,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -252,6 +252,8 @@ class QWenLMHeadModel(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.transformer.wte.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -385,6 +385,8 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -243,6 +243,8 @@ class StablelmForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -313,6 +313,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()