[Bugfix] support tie_word_embeddings for all models (#5724)

This commit is contained in:
Zijian Hu
2024-08-19 20:00:04 -07:00
committed by GitHub
parent 0df7ec0b2d
commit f4fc7337bf
30 changed files with 90 additions and 16 deletions

View File

@ -414,6 +414,8 @@ class ArcticForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.num_experts = config.num_local_experts
self.num_experts_per_tok = config.num_experts_per_tok
self.unpadded_vocab_size = config.vocab_size

View File

@ -331,6 +331,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

View File

@ -821,6 +821,8 @@ class BartForConditionalGeneration(nn.Module):
lora_config: Optional[LoRAConfig] = None):
super().__init__()
# currently all existing BART models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings
self.config = config
self.model = BartModel(config,
cache_config,

View File

@ -494,6 +494,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
super().__init__()
# currently all existing BLIP-2 models have `tie_word_embeddings`
# enabled
assert config.tie_word_embeddings
self.config = config
self.multimodal_config = multimodal_config

View File

@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
@ -276,7 +276,12 @@ class BloomForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.transformer = BloomModel(config, cache_config, quant_config)
self.lm_head = self.transformer.word_embeddings
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.word_embeddings
else:
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

View File

@ -356,6 +356,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self.max_position_embeddings = getattr(config, "max_sequence_length",
8192)
self.transformer = ChatGLMModel(config, cache_config, quant_config)
if self.config.tie_word_embeddings:
self.transformer.output_layer.weight = (
self.transformer.embedding.weight)
self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler()

View File

@ -321,6 +321,9 @@ class CohereForCausalLM(nn.Module):
) -> None:
super().__init__()
self.config = config
# currently all existing command R models have `tie_word_embeddings`
# enabled
assert config.tie_word_embeddings
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

View File

@ -362,6 +362,9 @@ class DbrxForCausalLM(nn.Module):
):
super().__init__()
self.config = config
if config.tie_word_embeddings:
raise ValueError(
"tie_word_embeddings is not supported for Dbrx models.")
self.quant_config = quant_config
self.unpadded_vocab_size = config.vocab_size
self.transformer = DbrxModel(config, cache_config, quant_config)

View File

@ -380,6 +380,8 @@ class DeepseekForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

View File

@ -331,6 +331,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
super().__init__()
self.config = config
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings
self.lora_config = lora_config
self.quant_config = quant_config

View File

@ -323,6 +323,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
del lora_config # Unused.
super().__init__()
self.config = config
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings
self.quant_config = quant_config
self.model = Gemma2Model(config, cache_config, quant_config)
self.logits_processor = LogitsProcessor(

View File

@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
@ -249,7 +249,11 @@ class GPT2LMHeadModel(nn.Module):
cache_config,
quant_config,
prefix="transformer")
self.lm_head = self.transformer.wte
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

View File

@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
@ -259,7 +259,13 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
self.quant_config = quant_config
self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
lora_config)
self.lm_head = self.transformer.wte
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:
self.lm_head = ParallelLMHead(
self.transformer.vocab_size,
self.transformer.embed_dim,
org_num_embeddings=self.config.vocab_size)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

View File

@ -230,7 +230,7 @@ class GPTNeoXForCausalLM(nn.Module):
def __init__(
self,
config,
config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
@ -243,6 +243,8 @@ class GPTNeoXForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.embed_out.weight = self.gpt_neox.embed_in.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

View File

@ -264,6 +264,8 @@ class InternLM2ForCausalLM(nn.Module):
self.output = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.output.weight = self.model.tok_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

View File

@ -37,7 +37,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
@ -291,7 +291,11 @@ class JAISLMHeadModel(nn.Module):
self.config = config
self.quant_config = quant_config
self.transformer = JAISModel(config, cache_config, quant_config)
self.lm_head = self.transformer.wte
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size)
if hasattr(config, "width_scale"):
self.output_logits_scale = config.width_scale
else:

View File

@ -313,7 +313,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.
To reserve space in KV cache, we have to insert placeholder tokens
before they are inputted to the model, so the input processor prepends
before they are inputted to the model, so the input processor prepends
additional image tokens (denoted as `32000`), resulting in:
`[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
@ -331,7 +331,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each input image.
See also:
:class:`LlavaImageInputs`
"""

View File

@ -545,7 +545,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
9047, 13566, 29901]`.
To reserve space in KV cache, we have to insert placeholder tokens
before they are inputted to the model, so the input processor prepends
before they are inputted to the model, so the input processor prepends
additional image tokens (denoted as `32000`), resulting in:
`[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
@ -566,7 +566,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
batch.
pixel_values: The pixels in each grid patch for each input image.
image_sizes: The original `(height, width)` for each input image.
See also:
:class:`LlavaNextImageInputs`
"""

View File

@ -496,6 +496,10 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
# All MiniCPM-V models disable `tie_word_embeddings` but
# `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
# check `tie_word_embeddings` until vLLM integrate MiniCPM-V model
# and config class
self.config = config
self.multimodal_config = multimodal_config

View File

@ -359,6 +359,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()

View File

@ -347,6 +347,8 @@ class MixtralForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

View File

@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
@ -307,7 +307,11 @@ class OPTForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = OPTModel(config, cache_config, quant_config)
self.lm_head = self.model.decoder.embed_tokens
if self.config.tie_word_embeddings:
self.lm_head = self.model.decoder.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.word_embed_proj_dim)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

View File

@ -262,6 +262,8 @@ class OrionForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

View File

@ -260,6 +260,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
super().__init__()
self.config = config
# lm_head use bias, cannot share word embeddings
assert not config.tie_word_embeddings
self.lora_config = lora_config
self.quant_config = quant_config

View File

@ -368,6 +368,8 @@ class Phi3SmallForCausalLM(nn.Module):
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@ -449,4 +451,3 @@ class Phi3SmallForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
self.lm_head.weight.data.copy_(self.model.embed_tokens.weight.data)

View File

@ -477,6 +477,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

View File

@ -252,6 +252,8 @@ class QWenLMHeadModel(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

View File

@ -385,6 +385,8 @@ class Qwen2MoeForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

View File

@ -243,6 +243,8 @@ class StablelmForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

View File

@ -313,6 +313,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()