mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Bugfix] Fix PP for ChatGLM and Molmo (#9422)
This commit is contained in:
@ -425,7 +425,7 @@ Text Generation
|
||||
-
|
||||
* - :code:`MolmoForCausalLM`
|
||||
- Molmo
|
||||
- Image
|
||||
- T + I
|
||||
- :code:`allenai/Molmo-7B-D-0924`, :code:`allenai/Molmo-72B-0924`, etc.
|
||||
-
|
||||
- ✅︎
|
||||
|
@ -118,11 +118,8 @@ class PPTestSettings:
|
||||
# The values displayed here are only a rough indicator of the size of the model
|
||||
|
||||
# yapf: disable
|
||||
GENERATION_MODEL_SETTINGS = {
|
||||
# [DETAILED TESTS]
|
||||
"meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
|
||||
"microsoft/Phi-3-mini-4k-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True), # noqa: E501
|
||||
# [FAST TESTS]
|
||||
TEXT_GENERATION_MODELS = {
|
||||
# [Decoder-only]
|
||||
# Uses Llama
|
||||
# "BAAI/AquilaChat-7B": PPTestSettings.fast(),
|
||||
"Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(tp_base=8, trust_remote_code=True), # noqa: E501
|
||||
@ -151,6 +148,7 @@ GENERATION_MODEL_SETTINGS = {
|
||||
"core42/jais-13b-chat": PPTestSettings.fast(),
|
||||
# TODO: Implement PP
|
||||
# "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(),
|
||||
"meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
|
||||
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True),
|
||||
"openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True),
|
||||
# Uses Llama
|
||||
@ -163,6 +161,7 @@ GENERATION_MODEL_SETTINGS = {
|
||||
"facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
|
||||
"OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True),
|
||||
"microsoft/phi-2": PPTestSettings.fast(),
|
||||
"microsoft/Phi-3-mini-4k-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True), # noqa: E501
|
||||
"microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
|
||||
"microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
|
||||
"adept/persimmon-8b-chat": PPTestSettings.fast(),
|
||||
@ -174,40 +173,40 @@ GENERATION_MODEL_SETTINGS = {
|
||||
"upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2),
|
||||
# FIXME: Cannot load tokenizer in latest transformers version
|
||||
# "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
|
||||
# [Encoder-only]
|
||||
# TODO: Implement PP
|
||||
# "facebook/bart-base": PPTestSettings.fast(),
|
||||
}
|
||||
|
||||
EMBEDDING_MODEL_SETTINGS = { # type: ignore[var-annotated]
|
||||
# [FAST TESTS]
|
||||
EMBEDDING_MODELS = { # type: ignore[var-annotated]
|
||||
# [Text-only]
|
||||
"intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(),
|
||||
"BAAI/bge-multilingual-gemma2": PPTestSettings.fast(),
|
||||
"Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(tp_base=4, trust_remote_code=True), # noqa: E501
|
||||
}
|
||||
|
||||
MULTIMODAL_MODEL_SETTINGS = {
|
||||
# [FAST TESTS]
|
||||
MULTIMODAL_MODELS = {
|
||||
# [Decoder-only]
|
||||
"Salesforce/blip2-opt-2.7b": PPTestSettings.fast(),
|
||||
"facebook/chameleon-7b": PPTestSettings.fast(),
|
||||
"adept/fuyu-8b": PPTestSettings.fast(),
|
||||
"THUDM/glm-4v-9b": PPTestSettings.fast(trust_remote_code=True),
|
||||
"OpenGVLab/InternVL2-1B": PPTestSettings.fast(trust_remote_code=True),
|
||||
"llava-hf/llava-1.5-7b-hf": PPTestSettings.fast(),
|
||||
"llava-hf/llava-v1.6-mistral-7b-hf": PPTestSettings.fast(),
|
||||
"llava-hf/LLaVA-NeXT-Video-7B-hf": PPTestSettings.fast(),
|
||||
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
|
||||
"openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(trust_remote_code=True),
|
||||
# TODO: Implement PP
|
||||
# "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
|
||||
"allenai/Molmo-7B-D-0924": PPTestSettings.fast(trust_remote_code=True),
|
||||
"microsoft/Phi-3-vision-128k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
|
||||
"mistralai/Pixtral-12B-2409": PPTestSettings.fast(tp_base=2, tokenizer_mode="mistral"), # noqa: E501
|
||||
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True),
|
||||
"Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
|
||||
"Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
|
||||
"fixie-ai/ultravox-v0_3": PPTestSettings.fast(),
|
||||
}
|
||||
|
||||
CONDITIONAL_GENERATION_MODEL_SETTINGS = { # type: ignore[var-annotated]
|
||||
# [FAST TESTS]
|
||||
# [Encoder-decoder]
|
||||
# TODO: Implement PP
|
||||
# "facebook/bart-base": PPTestSettings.fast(),
|
||||
# "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
|
||||
}
|
||||
# yapf: enable
|
||||
|
||||
@ -323,7 +322,7 @@ def _compare_tp(
|
||||
("model_name", "parallel_setup", "distributed_backend", "task",
|
||||
"test_options"),
|
||||
[
|
||||
params for model_name, settings in GENERATION_MODEL_SETTINGS.items()
|
||||
params for model_name, settings in TEXT_GENERATION_MODELS.items()
|
||||
for params in settings.iter_params(model_name)
|
||||
if model_name in TEST_MODELS
|
||||
],
|
||||
@ -350,7 +349,7 @@ def test_tp_language_generation(
|
||||
("model_name", "parallel_setup", "distributed_backend", "task",
|
||||
"test_options"),
|
||||
[
|
||||
params for model_name, settings in EMBEDDING_MODEL_SETTINGS.items()
|
||||
params for model_name, settings in EMBEDDING_MODELS.items()
|
||||
for params in settings.iter_params(model_name)
|
||||
if model_name in TEST_MODELS
|
||||
],
|
||||
@ -377,7 +376,7 @@ def test_tp_language_embedding(
|
||||
("model_name", "parallel_setup", "distributed_backend", "task",
|
||||
"test_options"),
|
||||
[
|
||||
params for model_name, settings in MULTIMODAL_MODEL_SETTINGS.items()
|
||||
params for model_name, settings in MULTIMODAL_MODELS.items()
|
||||
for params in settings.iter_params(model_name)
|
||||
if model_name in TEST_MODELS
|
||||
],
|
||||
|
@ -13,8 +13,9 @@ from torch.nn import LayerNorm
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
|
||||
token_inputs)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@ -22,8 +23,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -39,7 +39,9 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SequenceData)
|
||||
from vllm.transformers_utils.configs import ChatGLMConfig
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsMultiModal
|
||||
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -150,6 +152,10 @@ def find_all_positions(input_ids: List[int], target: int) -> List[int]:
|
||||
|
||||
|
||||
def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return inputs
|
||||
|
||||
hf_config = ctx.get_hf_config(ChatGLMConfig)
|
||||
vision_config = getattr(hf_config, 'vision_config', None)
|
||||
|
||||
@ -161,8 +167,8 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
input_ids = inputs.get("prompt_token_ids")
|
||||
position_ids = inputs.get("position_ids")
|
||||
input_ids = inputs["prompt_token_ids"]
|
||||
|
||||
tokenizer = cached_get_tokenizer(
|
||||
ctx.model_config.model,
|
||||
trust_remote_code=ctx.model_config.trust_remote_code)
|
||||
@ -171,20 +177,19 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
raw_batch_data = tokenizer.apply_chat_template(
|
||||
conversation=[{
|
||||
"role": "user",
|
||||
"image": inputs['multi_modal_data']["image"],
|
||||
"content": inputs['prompt']
|
||||
"image": multi_modal_data["image"],
|
||||
"content": inputs['prompt'],
|
||||
}],
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
return_dict=True).data
|
||||
return_dict=True,
|
||||
).data
|
||||
except Exception:
|
||||
logger.error("Failed to process content (%s)", inputs['prompt'])
|
||||
raise
|
||||
input_ids = raw_batch_data['input_ids'][0].tolist()
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = list(range(len(input_ids)))
|
||||
boi_token_id = hf_config.boi_token_id
|
||||
eoi_token_id = hf_config.eoi_token_id
|
||||
boi_positions = find_all_positions(input_ids, boi_token_id)
|
||||
@ -193,7 +198,6 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
assert len(boi_positions) == len(eoi_positions)
|
||||
|
||||
new_input_ids = []
|
||||
new_position_ids = []
|
||||
final_processed_position = 0
|
||||
final_processed_position = 0
|
||||
|
||||
@ -201,29 +205,28 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
assert boi_position < eoi_position
|
||||
new_input_ids.extend(input_ids[final_processed_position:boi_position +
|
||||
1])
|
||||
new_position_ids.extend(
|
||||
list(range(final_processed_position, boi_position + 1)))
|
||||
new_input_ids.extend([input_ids[boi_position + 1]] *
|
||||
image_placeholder_length)
|
||||
new_position_ids.extend([boi_position + 1] * image_placeholder_length)
|
||||
final_processed_position = eoi_position
|
||||
|
||||
new_input_ids.extend(input_ids[final_processed_position:])
|
||||
new_position_ids.extend(
|
||||
list(range(final_processed_position, len(input_ids))))
|
||||
|
||||
assert len(new_input_ids) == len(new_position_ids)
|
||||
prompt = inputs.get("prompt")
|
||||
if prompt is None:
|
||||
prompt = tokenizer.decode(new_input_ids)
|
||||
|
||||
inputs["prompt_token_ids"] = new_input_ids
|
||||
inputs["position_ids"] = new_position_ids
|
||||
return inputs
|
||||
return token_inputs(
|
||||
prompt_token_ids=new_input_ids,
|
||||
prompt=prompt,
|
||||
multi_modal_data=multi_modal_data,
|
||||
)
|
||||
|
||||
|
||||
class GLMAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
config: ChatGLMConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
@ -314,7 +317,7 @@ class GLMMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
config: ChatGLMConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
@ -357,7 +360,7 @@ class GLMBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
config: ChatGLMConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
@ -428,9 +431,10 @@ class GLMTransformer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
config: ChatGLMConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.post_layer_norm = config.post_layer_norm
|
||||
@ -439,10 +443,11 @@ class GLMTransformer(nn.Module):
|
||||
self.num_layers = config.num_layers
|
||||
|
||||
# Transformer layers.
|
||||
self.layers = nn.ModuleList([
|
||||
GLMBlock(config, cache_config, quant_config)
|
||||
for i in range(self.num_layers)
|
||||
])
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
self.num_layers,
|
||||
lambda prefix: GLMBlock(config, cache_config, quant_config),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
|
||||
if self.post_layer_norm:
|
||||
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
||||
@ -450,6 +455,10 @@ class GLMTransformer(nn.Module):
|
||||
self.final_layernorm = layer_norm_func(
|
||||
config.hidden_size, eps=config.layernorm_epsilon)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(["hidden_states"],
|
||||
config.hidden_size))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -457,16 +466,16 @@ class GLMTransformer(nn.Module):
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
for i in range(self.num_layers):
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
hidden_states=hidden_states,
|
||||
position_ids=position_ids,
|
||||
kv_cache=kv_caches[i],
|
||||
kv_cache=kv_caches[i - self.start_layer],
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
# Final layer norm.
|
||||
if self.post_layer_norm:
|
||||
if get_pp_group().is_last_rank and self.post_layer_norm:
|
||||
hidden_states = self.final_layernorm(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
@ -476,7 +485,7 @@ class ChatGLMModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
config: ChatGLMConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
@ -504,6 +513,9 @@ class ChatGLMModel(nn.Module):
|
||||
else:
|
||||
self.vision = None
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.encoder.make_empty_intermediate_tensors)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> GLMImagePixelInputs:
|
||||
|
||||
@ -529,24 +541,26 @@ class ChatGLMModel(nn.Module):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs: object,
|
||||
) -> torch.Tensor:
|
||||
if intermediate_tensors is None:
|
||||
inputs_embeds = self.embedding(input_ids)
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
|
||||
inputs_embeds = self.embedding(input_ids)
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input["pixel_values"] is not None:
|
||||
pixel_values = image_input["pixel_values"].to(
|
||||
dtype=inputs_embeds.dtype)
|
||||
image_embeds = self.vision(pixel_values)
|
||||
|
||||
if image_input["pixel_values"] is not None:
|
||||
pixel_values = image_input["pixel_values"].to(
|
||||
dtype=inputs_embeds.dtype)
|
||||
image_embeds = self.vision(pixel_values)
|
||||
boi_token_id = self.config.boi_token_id
|
||||
eoi_token_id = self.config.eoi_token_id
|
||||
|
||||
boi_token_id = self.config.boi_token_id
|
||||
eoi_token_id = self.config.eoi_token_id
|
||||
|
||||
inputs_embeds = merge_glm_vision_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
vision_embeddings=image_embeds,
|
||||
boi_token_id=boi_token_id,
|
||||
eoi_token_id=eoi_token_id)
|
||||
inputs_embeds = merge_glm_vision_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
vision_embeddings=image_embeds,
|
||||
boi_token_id=boi_token_id,
|
||||
eoi_token_id=eoi_token_id)
|
||||
else:
|
||||
inputs_embeds = intermediate_tensors["hidden_states"]
|
||||
|
||||
# Run encoder.
|
||||
hidden_states = self.encoder(
|
||||
@ -555,6 +569,9 @@ class ChatGLMModel(nn.Module):
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
return hidden_states
|
||||
|
||||
|
||||
@ -562,7 +579,8 @@ class ChatGLMModel(nn.Module):
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
|
||||
class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
|
||||
SupportsMultiModal):
|
||||
packed_modules_mapping = {
|
||||
"query_key_value": ["query_key_value"],
|
||||
"dense_h_to_4h": ["dense_h_to_4h"]
|
||||
@ -610,7 +628,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, **kwargs)
|
||||
attn_metadata, intermediate_tensors,
|
||||
**kwargs)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
@ -656,6 +675,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
@ -30,21 +30,21 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||
from vllm.model_executor.models.utils import make_layers
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SequenceData)
|
||||
from vllm.transformers_utils.processor import get_processor
|
||||
|
||||
from .utils import get_vit_attn_backend
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .utils import (get_vit_attn_backend,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
|
||||
# TODO: hard-coded for now. Consider making it configurable.
|
||||
VIT_LAYERS = [-2, -9]
|
||||
@ -744,6 +744,10 @@ class MolmoModel(nn.Module):
|
||||
assert config.layer_norm_type == "rms"
|
||||
self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -925,16 +929,19 @@ def pad_images(
|
||||
|
||||
|
||||
def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
prompt = inputs.get("prompt", None)
|
||||
multi_modal_data = inputs.get("multi_modal_data", None)
|
||||
if multi_modal_data is not None:
|
||||
image = multi_modal_data.get("image", None)
|
||||
else:
|
||||
image = None
|
||||
prompt = inputs.get("prompt")
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
image = None if multi_modal_data is None else multi_modal_data.get("image")
|
||||
|
||||
processor = cached_get_processor(ctx.model_config.model,
|
||||
trust_remote_code=True,
|
||||
revision=ctx.model_config.code_revision)
|
||||
|
||||
model_config = ctx.model_config
|
||||
tokenizer = cached_get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
trust_remote_code=model_config.trust_remote_code)
|
||||
|
||||
# NOTE: message formatting for raw text prompt is only applied for
|
||||
# offline inference; for online inference, the prompt is always in
|
||||
# instruction format and tokenized.
|
||||
@ -997,9 +1004,13 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
|
||||
multi_modal_data = dict(image=image_data)
|
||||
|
||||
prompt = inputs.get("prompt")
|
||||
if prompt is None:
|
||||
prompt = tokenizer.decode(out["input_ids"])
|
||||
|
||||
return token_inputs(
|
||||
prompt_token_ids=out["input_ids"],
|
||||
prompt=inputs["prompt"],
|
||||
prompt=prompt,
|
||||
multi_modal_data=multi_modal_data,
|
||||
)
|
||||
|
||||
@ -1008,7 +1019,7 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
|
||||
class MolmoForCausalLM(nn.Module, SupportsMultiModal):
|
||||
class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -1040,6 +1051,9 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal):
|
||||
or config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self,
|
||||
**kwargs: object,
|
||||
@ -1123,31 +1137,36 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal):
|
||||
positions: torch.LongTensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs: object,
|
||||
) -> SamplerOutput:
|
||||
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
|
||||
if image_input is not None:
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
inputs_embeds = self._merge_multimodal_embeddings(
|
||||
inputs_embeds,
|
||||
image_features,
|
||||
image_input["image_input_idx"],
|
||||
image_input["seq_len"],
|
||||
)
|
||||
|
||||
if intermediate_tensors is not None:
|
||||
input_ids = None
|
||||
else:
|
||||
inputs_embeds = None
|
||||
else:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
|
||||
if image_input is not None:
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
inputs_embeds = self._merge_multimodal_embeddings(
|
||||
inputs_embeds,
|
||||
image_features,
|
||||
image_input["image_input_idx"],
|
||||
image_input["seq_len"],
|
||||
)
|
||||
|
||||
input_ids = None
|
||||
else:
|
||||
inputs_embeds = None
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
|
@ -119,5 +119,6 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self)
|
||||
loader = AutoWeightsLoader(self,
|
||||
ignore_unexpected_prefixes=["lm_head."])
|
||||
loader.load_weights(weights)
|
||||
|
@ -61,6 +61,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
|
||||
MultiModalInputs)
|
||||
from vllm.multimodal.base import MultiModalData
|
||||
from vllm.multimodal.image import cached_get_image_processor
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
from vllm.transformers_utils.config import uses_mrope
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
@ -817,7 +818,7 @@ def input_processor_for_qwen2_vl(
|
||||
min_pixels: Optional[int] = None,
|
||||
max_pixels: Optional[int] = None,
|
||||
) -> DecoderOnlyInputs:
|
||||
multi_modal_data = inputs.get("multi_modal_data", None)
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None:
|
||||
return inputs
|
||||
|
||||
@ -830,6 +831,7 @@ def input_processor_for_qwen2_vl(
|
||||
min_pixels = min_pixels if min_pixels else image_processor.min_pixels
|
||||
max_pixels = max_pixels if max_pixels else image_processor.max_pixels
|
||||
|
||||
model_config = ctx.model_config
|
||||
hf_config = ctx.get_hf_config(Qwen2VLConfig)
|
||||
|
||||
# To avoid redundant processing of vision objects (resize, rescale, etc.),
|
||||
@ -845,14 +847,11 @@ def input_processor_for_qwen2_vl(
|
||||
# return_tensors="pt")
|
||||
# prompt_token_ids = inputs["input_ids"][0].tolist()
|
||||
|
||||
prompt_token_ids = inputs.get("prompt_token_ids", None)
|
||||
if prompt_token_ids is None:
|
||||
prompt = inputs["prompt"]
|
||||
prompt_token_ids = processor.tokenizer(
|
||||
prompt,
|
||||
padding=True,
|
||||
return_tensors=None,
|
||||
)["input_ids"]
|
||||
tokenizer = cached_get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
trust_remote_code=model_config.trust_remote_code)
|
||||
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
|
||||
# Expand image pad tokens.
|
||||
|
||||
@ -894,9 +893,13 @@ def input_processor_for_qwen2_vl(
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels)
|
||||
|
||||
prompt = inputs.get("prompt")
|
||||
if prompt is None:
|
||||
prompt = tokenizer.decode(prompt_token_ids)
|
||||
|
||||
return token_inputs(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt=inputs["prompt"],
|
||||
prompt=prompt,
|
||||
multi_modal_data=multi_modal_data,
|
||||
)
|
||||
|
||||
|
@ -79,6 +79,9 @@ class AutoWeightsLoader:
|
||||
|
||||
Similarly, the weight loading logic for individual parameters can be
|
||||
overridden by defining a ``weight_loader`` method.
|
||||
|
||||
Detailed weight loading information can be viewed by setting the
|
||||
environment variable ``VLLM_LOGGING_LEVEL=DEBUG``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -136,20 +139,27 @@ class AutoWeightsLoader:
|
||||
weight_qualname = self._get_qualname(base_prefix, weight_name)
|
||||
|
||||
if self._can_skip(weight_qualname):
|
||||
logger.debug("Skipping weight %s", weight_qualname)
|
||||
|
||||
continue
|
||||
|
||||
if weight_name != "":
|
||||
if not self._can_ignore_unexpected(weight_qualname):
|
||||
raise ValueError(
|
||||
f"Attempted to load nested weight '{weight_qualname}' "
|
||||
f"into a single parameter '{base_prefix}'")
|
||||
if self._can_ignore_unexpected(weight_qualname):
|
||||
logger.debug("Ignoring weight %s", weight_qualname)
|
||||
|
||||
continue
|
||||
continue
|
||||
|
||||
raise ValueError(
|
||||
f"Attempted to load nested weight '{weight_qualname}' "
|
||||
f"into a single parameter '{base_prefix}'")
|
||||
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, weight_data)
|
||||
|
||||
logger.debug("Loaded weight %s with shape %s", weight_qualname,
|
||||
param.shape)
|
||||
|
||||
yield weight_qualname
|
||||
|
||||
def _load_module(
|
||||
@ -175,21 +185,41 @@ class AutoWeightsLoader:
|
||||
for child_prefix, child_weights in self._groupby_prefix(weights):
|
||||
prefix = self._get_qualname(base_prefix, child_prefix)
|
||||
|
||||
if self._can_skip(prefix):
|
||||
continue
|
||||
|
||||
if child_prefix in child_modules:
|
||||
if self._can_skip(prefix + "."):
|
||||
logger.debug("Skipping module %s", prefix)
|
||||
|
||||
continue
|
||||
|
||||
yield from self._load_module(prefix,
|
||||
child_modules[child_prefix],
|
||||
child_weights)
|
||||
elif child_prefix in child_params:
|
||||
if self._can_skip(prefix):
|
||||
logger.debug("Skipping param %s", prefix)
|
||||
|
||||
continue
|
||||
|
||||
yield from self._load_param(prefix, child_params[child_prefix],
|
||||
child_weights)
|
||||
else:
|
||||
if not self._can_ignore_unexpected(prefix):
|
||||
msg = (f"There is no module or parameter named '{prefix}' "
|
||||
f"in {type(self.module).__name__}")
|
||||
raise ValueError(msg)
|
||||
can_skip_module = self._can_skip(prefix + ".")
|
||||
can_skip_param = self._can_skip(prefix)
|
||||
if can_skip_module or can_skip_param:
|
||||
logger.debug("Skipping missing %s", prefix)
|
||||
|
||||
continue
|
||||
|
||||
can_ignore_module = self._can_ignore_unexpected(prefix + ".")
|
||||
can_ignore_param = self._can_ignore_unexpected(prefix)
|
||||
if can_ignore_module or can_ignore_param:
|
||||
logger.debug("Ignoring missing %s", prefix)
|
||||
|
||||
continue
|
||||
|
||||
msg = (f"There is no module or parameter named '{prefix}' "
|
||||
f"in {type(self.module).__name__}")
|
||||
raise ValueError(msg)
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
|
Reference in New Issue
Block a user