mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00

Enable conversion of multimodal models to pooling tasks (#24451)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Committed by: simon-mo
parent: 89da8d9d09
commit: bbb70036cb

tests/models/language/pooling/test_mm_classifier_conversion.py (new file, 114 lines added)
@@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.platforms import current_platform


def test_idefics_multimodal(
    vllm_runner,
    monkeypatch,
) -> None:
    if current_platform.is_rocm():
        # ROCm Triton FA does not currently support sliding window attention
        # switch to use ROCm CK FA backend
        monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    with vllm_runner(model_name="HuggingFaceM4/Idefics3-8B-Llama3",
                     runner="pooling",
                     task="classify",
                     convert="classify",
                     load_format="dummy",
                     max_model_len=512,
                     enforce_eager=True,
                     tensor_parallel_size=1,
                     disable_log_stats=True,
                     dtype="bfloat16") as vllm_model:
        llm = vllm_model.get_llm()
        outputs = llm.classify(prompts)
        for output in outputs:
            assert len(output.outputs.probs) == 2


def update_config(config):
    config.text_config.update({
        "architectures": ["Gemma3ForSequenceClassification"],
        "classifier_from_token": ["A", "B", "C", "D", "E"],
        "method":
        "no_post_processing",
        "id2label": {
            "A": "Chair",
            "B": "Couch",
            "C": "Table",
            "D": "Bed",
            "E": "Cupboard"
        },
    })
    return config


def test_gemma_multimodal(
    vllm_runner,
    monkeypatch,
) -> None:
    if current_platform.is_rocm():
        # ROCm Triton FA does not currently support sliding window attention
        # switch to use ROCm CK FA backend
        monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")

    messages = [{
        "role":
        "system",
        "content":
        """
        You are a helpful assistant. You will be given a product description
        which may also include an image. Classify the following product into
        one of the categories:

        A = chair
        B = couch
        C = table
        D = bed
        E = cupboard

        You'll answer with exactly one letter (A, B, C, D, or E)."""
    }, {
        "role":
        "user",
        "content": [{
            "type": "image_url",
            "image_url": {
                "url":
                "https://upload.wikimedia.org/wikipedia/commons/c/c6/Set_of_fourteen_side_chairs_MET_DP110780.jpg"
            }
        }, {
            "type": "text",
            "text": "A fine 19th century piece of furniture."
        }]
    }]

    with vllm_runner(model_name="google/gemma-3-4b-it",
                     runner="pooling",
                     task="classify",
                     convert="classify",
                     load_format="auto",
                     hf_overrides=update_config,
                     override_pooler_config={"pooling_type": "LAST"},
                     max_model_len=512,
                     enforce_eager=True,
                     tensor_parallel_size=1,
                     disable_log_stats=True,
                     dtype="bfloat16") as vllm_model:

        llm = vllm_model.get_llm()
        prompts = llm.preprocess_chat(messages)

        result = llm.classify(prompts)
        assert result[0].outputs.probs[0] > 0.95
        assert all(c < 0.05 for c in result[0].outputs.probs[1:])
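The tests above drive the new conversion path through the `vllm_runner` test fixture. As a rough offline sketch of the same flow, assuming the fixture forwards the `runner` and `convert` keyword arguments to the `LLM` constructor as-is (these argument names are taken from the test above, not from separate documentation):

from vllm import LLM

# Sketch only: mirrors test_idefics_multimodal, but loads real weights
# instead of load_format="dummy".
llm = LLM(
    model="HuggingFaceM4/Idefics3-8B-Llama3",
    runner="pooling",
    convert="classify",
    max_model_len=512,
    enforce_eager=True,
    dtype="bfloat16",
)

prompts = [
    "Hello, my name is",
    "The capital of France is",
]

outputs = llm.classify(prompts)
for output in outputs:
    # The default conversion produces a two-class head, matching the
    # assertion in the test above.
    print(output.outputs.probs)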
@@ -703,6 +703,106 @@ class LLM:

        return outputs

    def preprocess_chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[TokensPrompt]:
        """
        Generate prompts for a chat conversation. The pre-processed
        prompts can then be used as input for the other LLM methods.

        Refer to `chat` for a complete description of the arguments.

        Returns:
            A list of `TokensPrompt` objects containing the tokenized
            prompt after chat template interpolation, and the
            pre-processed multi-modal inputs.
        """
        list_of_messages: list[list[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is list[list[...]]
            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is list[...]
            list_of_messages = [
                cast(list[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

        prompts: list[TokensPrompt] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data, mm_uuids = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                prompt_token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    **_chat_template_kwargs,
                )
            else:
                prompt_str = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **_chat_template_kwargs,
                )
                # Special tokens are already included in chat templates so
                # should not be added by the tokenizer in this case.
                prompt_token_ids = tokenizer.encode(prompt_str,
                                                    add_special_tokens=False)

            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_uuids is not None:
                prompt["multi_modal_uuids"] = mm_uuids

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return prompts

    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
@@ -769,80 +869,18 @@ class LLM:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: list[list[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is list[list[...]]
            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is list[...]
            list_of_messages = [
                cast(list[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        _chat_template_kwargs: dict[str, Any] = dict(
        prompts = self.preprocess_chat(
            messages=messages,
            lora_request=lora_request,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            chat_template_kwargs=chat_template_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

        prompts: list[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data, mm_uuids = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                prompt_token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    **_chat_template_kwargs,
                )
            else:
                prompt_str = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **_chat_template_kwargs,
                )
                # Special tokens are already included in chat templates so
                # should not be added by the tokenizer in this case.
                prompt_token_ids = tokenizer.encode(prompt_str,
                                                    add_special_tokens=False)

            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_uuids is not None:
                prompt["multi_modal_uuids"] = mm_uuids

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
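After this refactor, `chat` is a thin wrapper: it builds `TokensPrompt` objects with `preprocess_chat` and hands them to `generate`, which is what lets pooling entrypoints such as `classify` reuse chat-formatted (and multimodal) inputs. A minimal sketch of that equivalence for plain text generation; the model name and sampling settings are illustrative only:

from vllm import LLM, SamplingParams

llm = LLM(model="google/gemma-3-4b-it", max_model_len=512, enforce_eager=True)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Name one piece of furniture."},
]

# Roughly what llm.chat(messages) now does internally:
prompts = llm.preprocess_chat(messages)
outputs = llm.generate(prompts, SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)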
@@ -19,10 +19,11 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.linear import QKVCrossParallelLinear
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.models.adapters import (as_embedding_model,
                                                 as_reward_model,
                                                 as_seq_cls_model)
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.model_executor.models.adapters import (
    as_embedding_model, as_reward_model, as_seq_cls_model,
    try_create_mm_pooling_model_cls)
from vllm.model_executor.models.interfaces import (SupportsQuant,
                                                   supports_multimodal)
from vllm.utils import is_pin_memory_available

logger = init_logger(__name__)

@@ -183,6 +184,15 @@ def get_model_architecture(
            "performance may not be optimal.", arch)

    convert_type = model_config.convert_type
    if convert_type != "none" and supports_multimodal(model_cls):
        logger.debug_once("Detected conversion of Multi Modal model.")
        converted = try_create_mm_pooling_model_cls(model_cls)
        if converted is not None:
            logger.debug_once("Creating wrapper class to forward pooler.")
            return converted, arch
        else:
            logger.debug_once("Attempting direct conversion.")

    if convert_type == "none":
        pass
    elif convert_type == "embed":
@@ -1,12 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import ast
import inspect
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.models.config import VerifyAndUpdateConfig

@@ -129,6 +132,41 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
    return model_name + pooling_suffix


def try_create_mm_pooling_model_cls(orig_cls: _T) -> _T:

    class CallVisitor(ast.NodeVisitor):

        def __init__(self):
            self.calls = []

        def visit_Call(self, node):
            if isinstance(node.func, ast.Name):
                self.calls.append(node.func.id)
            self.generic_visit(node)

    visitor = CallVisitor()
    visitor.visit(ast.parse(inspect.getsource(orig_cls)))
    if "init_vllm_registered_model" not in visitor.calls:
        return None

    class ModelForPooling(orig_cls, VllmModelForPooling):

        is_pooling_model = True

        def __init__(
            self,
            *,
            vllm_config: "VllmConfig",
            prefix: str = "",
            **kwargs: Any,
        ) -> None:
            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)

            self.pooler = self.get_language_model().pooler

    return ModelForPooling  # type: ignore


def _create_pooling_model_cls(orig_cls: _T) -> _T:
    # Lazy import
    from .utils import AutoWeightsLoader, WeightsMapper
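The eligibility check in `try_create_mm_pooling_model_cls` above walks the model class's source with `ast` to see whether it instantiates its language backbone through `init_vllm_registered_model`; only then can the wrapper expose `self.get_language_model().pooler`. The same source-inspection pattern in isolation (the `calls_function` helper and `Example` class below are illustrative, not part of vLLM, and `inspect.getsource` needs the class to be defined in a source file):

import ast
import inspect


def calls_function(cls: type, func_name: str) -> bool:
    """Return True if the source of `cls` contains a direct call to `func_name`."""

    class CallVisitor(ast.NodeVisitor):

        def __init__(self) -> None:
            self.calls: list[str] = []

        def visit_Call(self, node: ast.Call) -> None:
            # Only calls by bare name are recorded; attribute calls such as
            # obj.method() appear as ast.Attribute and are ignored, matching
            # the check in try_create_mm_pooling_model_cls.
            if isinstance(node.func, ast.Name):
                self.calls.append(node.func.id)
            self.generic_visit(node)

    visitor = CallVisitor()
    visitor.visit(ast.parse(inspect.getsource(cls)))
    return func_name in visitor.calls


class Example:

    def build(self):
        return sorted([3, 1, 2])


print(calls_function(Example, "sorted"))                      # True
print(calls_function(Example, "init_vllm_registered_model"))  # False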
@@ -399,6 +437,7 @@ def load_weights_using_from_2_way_softmax(
    from vllm.model_executor.models.utils import AutoWeightsLoader

    model_config = model.vllm_config.model_config

    tokens = getattr(model.config, "classifier_from_token", [])
    tokens = cast(list[int], tokens)
    assert len(tokens) == 2

@@ -406,9 +445,10 @@
    if model.config.tie_word_embeddings:
        model.lm_head = model.model.embed_tokens
    else:
        quant_config = model.vllm_config.quant_config
        model.lm_head = ParallelLMHead(model.config.vocab_size,
                                       model.config.hidden_size,
                                       quant_config=model.quant_config)
                                       quant_config=quant_config)

    loader = AutoWeightsLoader(model)
    loaded_weights = loader.load_weights(weights)

@@ -452,9 +492,10 @@ def load_weights_no_post_processing(model,
    if model.config.tie_word_embeddings:
        model.lm_head = model.model.embed_tokens
    else:
        quant_config = model.vllm_config.quant_config
        model.lm_head = ParallelLMHead(model.config.vocab_size,
                                       model.config.hidden_size,
                                       quant_config=model.quant_config)
                                       quant_config=quant_config)

    loader = AutoWeightsLoader(model)
    loaded_weights = loader.load_weights(weights)
@@ -512,7 +512,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
            architectures=["Gemma3ForCausalLM"],
        )
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.language_model.logits_processor.scale *= logit_scale

        if hasattr(self.language_model, "logits_processor"):
            # The logits processor can be unset if we're using
            # automatic conversion to pooling model.
            self.language_model.logits_processor.scale *= logit_scale

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)