Simplify weight loading in Transformers backend (#21382)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-07-23 04:29:43 +01:00
committed by GitHub
parent 3ec7170ff1
commit f154bb9ff0
7 changed files with 53 additions and 76 deletions

View File

@ -177,7 +177,7 @@ TEXT_GENERATION_MODELS = {
"ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
"meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
# Tests TransformersForCausalLM
"ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(),
"hmellor/Ilama-3.2-1B": PPTestSettings.fast(),
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
"openbmb/MiniCPM3-4B": PPTestSettings.fast(),
# Uses Llama
@ -249,7 +249,7 @@ TEST_MODELS = [
# [LANGUAGE GENERATION]
"microsoft/Phi-3.5-MoE-instruct",
"meta-llama/Llama-3.2-1B-Instruct",
"ArthurZ/Ilama-3.2-1B",
"hmellor/Ilama-3.2-1B",
"ibm/PowerLM-3b",
"deepseek-ai/DeepSeek-V2-Lite-Chat",
# [LANGUAGE EMBEDDING]

View File

@ -9,7 +9,7 @@ from vllm.platforms import current_platform
from ..utils import create_new_process_for_each_test, multi_gpu_test
MODEL_PATH = "ArthurZ/ilama-3.2-1B"
MODEL_PATH = "hmellor/Ilama-3.2-1B"
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501

View File

@ -500,7 +500,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
}
_TRANSFORMERS_MODELS = {
"TransformersForCausalLM": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
"TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
"TransformersForMultimodalLM": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"),
}

View File

@ -56,7 +56,7 @@ def check_implementation(
"model,model_impl",
[
("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
("ArthurZ/Ilama-3.2-1B", "auto"), # CUSTOM CODE
("hmellor/Ilama-3.2-1B", "auto"), # CUSTOM CODE
]) # trust_remote_code=True by default
def test_models(
hf_runner: type[HfRunner],

View File

@ -624,13 +624,9 @@ class SupportsQuant:
instance.quant_config = quant_config
# apply model mappings to config for proper config-model matching
# NOTE: `TransformersForCausalLM` is not supported due to how this
# class defines `hf_to_vllm_mapper` as a post-init `@property`.
# After this is fixed, get `instance.hf_to_vllm_mapper` directly
if getattr(instance, "hf_to_vllm_mapper", None) is not None:
instance.quant_config.apply_vllm_mapper(
instance.hf_to_vllm_mapper)
if getattr(instance, "packed_modules_mapping", None) is not None:
if (hf_to_vllm_mapper := instance.hf_to_vllm_mapper) is not None:
instance.quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
if instance.packed_modules_mapping is not None:
instance.quant_config.packed_modules_mapping.update(
instance.packed_modules_mapping)

View File

@ -414,7 +414,7 @@ class ConfigOverride:
setattr(self.config, key, value)
class TransformersModel(nn.Module):
class TransformersModel:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@ -454,9 +454,6 @@ class TransformersModel(nn.Module):
# method after v4.54.0 is released
self.text_config._attn_implementation = "vllm"
with init_on_device_without_buffers("meta"), config_override:
# FIXME(Isotr0py): We need to refactor this part in the future to
# avoid registering an extra model layer, otherwise we will need a
# weights mapper to rename weights.
self.model: PreTrainedModel = AutoModel.from_config(
config,
torch_dtype=model_config.dtype,
@ -620,9 +617,6 @@ class TransformersModel(nn.Module):
for child in module.children():
self.init_parameters(child)
def get_input_embeddings(self) -> nn.Module:
return self.model.get_input_embeddings()
def forward(
self,
input_ids: Optional[torch.Tensor],
@ -694,7 +688,9 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
self.config = config
self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)
self.transformers_model = TransformersModel(vllm_config=vllm_config,
prefix=prefix)
self.model = self.transformers_model.model
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
@ -716,22 +712,7 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
# FIXME(Isotr0py): Don't use any weights mapper for Transformers backend,
# this makes thing complicated. We need to remove this mapper after refactor
# `TransformersModel` in the future.
# NOTE: `SupportsQuant` can be updated after property decorator is removed
@property
def hf_to_vllm_mapper(self):
prefix_mapper = {
name: "model." + name
for name, _ in self.model.model.named_children()
}
return WeightsMapper(
orig_to_new_substr={"model.": "model.model."},
orig_to_new_prefix=prefix_mapper,
)
self.transformers_model.make_empty_intermediate_tensors)
def forward(
self,
@ -740,8 +721,9 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
model_output = self.transformers_model.forward(input_ids, positions,
intermediate_tensors,
inputs_embeds)
return model_output
def compute_logits(
@ -755,12 +737,10 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
skip_prefixes = ["lm_head."
] if self.config.tie_word_embeddings else None
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights)
@MULTIMODAL_REGISTRY.register_processor(
@ -772,6 +752,29 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"]
# Backwards compatibility for prev released models. State dicts back then
# had different formats and cannot be loaded with `AutoModel` mapping as is
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"language_model.model": "model.language_model",
"text_model.model": "model.text_model",
"vision_tower": "model.vision_tower",
"vqmodel": "model.vqmodel",
"visual": "model.visual",
"vision_model": "model.vision_model",
"vision_embed_tokens": "model.vision_embed_tokens",
"image_newline": "model.image_newline",
"multi_modal_projector": "model.multi_modal_projector",
"text_model.lm_head": "lm_head",
"language_model.lm_head": "lm_head",
# Qwen models used "model" as the name for the language model.
# Therefore, we must map each of submodule explicitly to avoid
# conflicts with newer models that use "model.language_model".
"model.embed_tokens": "model.language_model.embed_tokens",
"model.layers": "model.language_model.layers",
"model.norm": "model.language_model.norm",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: PretrainedConfig = vllm_config.model_config.hf_config
@ -780,7 +783,9 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
self.config = config
self.dtype = vllm_config.model_config.dtype
self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)
self.transformers_model = TransformersModel(vllm_config=vllm_config,
prefix=prefix)
self.model = self.transformers_model.model
text_config = config.get_text_config()
if get_pp_group().is_last_rank:
@ -803,32 +808,7 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
@property
def hf_to_vllm_mapper(self):
# Backwards compatibility for prev released models
# State dicts back then had different formats
# and cannot be loaded with `AutoModel` mapping
# as is
prefix_mapper = {
"language_model.model": "model.language_model",
"text_model.model": "model.text_model",
"vision_tower": "model.vision_tower",
"vqmodel": "model.vqmodel",
"vision_model": "model.vision_model",
"vision_embed_tokens": "model.vision_embed_tokens",
"image_newline": "model.image_newline",
"multi_modal_projector": "model.multi_modal_projector",
"text_model.lm_head": "lm_head",
"language_model.lm_head": "lm_head",
}
# Don't change the order for QwenVL
if 'Qwen2' in self.config.__class__.__name__:
prefix_mapper["model"] = "model.language_model"
prefix_mapper["visual"] = "model.visual"
return WeightsMapper(orig_to_new_prefix=prefix_mapper, )
self.transformers_model.make_empty_intermediate_tensors)
def forward(
self,
@ -848,8 +828,9 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
input_ids, multimodal_embeds)
input_ids = None
model_output = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
model_output = self.transformers_model.forward(input_ids, positions,
intermediate_tensors,
inputs_embeds)
return model_output
def compute_logits(
@ -898,7 +879,7 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
if isinstance(num_image_patches, list):
num_image_patches = torch.cat(num_image_patches)
vision_embeddings = self.model.model.get_image_features(
vision_embeddings = self.model.get_image_features(
pixel_values,
**{
k: v.flatten(0, 1)
@ -928,7 +909,7 @@ class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
input_ids: torch.Tensor,
multimodal_embeddings=None,
) -> torch.Tensor:
inputs_embeds = self.model.model.get_input_embeddings()(input_ids)
inputs_embeds = self.model.get_input_embeddings()(input_ids)
if (multimodal_embeddings is not None
and len(multimodal_embeddings) != 0):
mask = (input_ids == self.config.image_token_id)

View File

@ -10,7 +10,7 @@ MODELS_ON_S3 = [
"allenai/OLMoE-1B-7B-0924-Instruct",
"amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test",
"AMead10/Llama-3.2-1B-Instruct-AWQ",
"ArthurZ/Ilama-3.2-1B",
"hmellor/Ilama-3.2-1B",
"BAAI/bge-base-en-v1.5",
"BAAI/bge-multilingual-gemma2",
"BAAI/bge-reranker-v2-m3",