mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-21 07:13:52 +08:00
Signed-off-by: yangxurui <yangxurui@meituan.com> Co-authored-by: yangxurui <yangxurui@meituan.com>
484 lines
17 KiB
Python
484 lines
17 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import warnings
|
|
from collections.abc import Sequence
|
|
from dataclasses import dataclass
|
|
from typing import Any, Optional, Union
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from transformers import PretrainedConfig
|
|
|
|
from vllm.config import ModelConfig, ModelDType, RunnerOption
|
|
from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
|
|
from vllm.multimodal.processing import InputProcessingContext
|
|
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
|
|
|
from .registry import HF_EXAMPLE_MODELS
|
|
|
|
# A generated sequence represented as a tuple of
# * Token ID list
# * Decoded output string
TokensText = tuple[list[int], str]
|
|
|
|
|
|
def check_outputs_equal(
    *,
    outputs_0_lst: Sequence[TokensText],
    outputs_1_lst: Sequence[TokensText],
    name_0: str,
    name_1: str,
):
    """Assert that two models produced exactly the same generations.

    Each element of ``outputs_0_lst`` / ``outputs_1_lst`` is a
    ``(token_ids, text)`` pair for one prompt. Pairs at the same position
    are compared: the text first, then the token IDs.

    Args:
        outputs_0_lst: Generations of the first model, one per prompt.
        outputs_1_lst: Generations of the second model, one per prompt.
        name_0: Label for the first model in failure messages.
        name_1: Label for the second model in failure messages.

    Raises:
        AssertionError: If the number of prompts differs, or if the text or
            the token IDs differ for any prompt.
    """
    assert len(outputs_0_lst) == len(outputs_1_lst)

    for prompt_idx, pair in enumerate(zip(outputs_0_lst, outputs_1_lst)):
        (ids_a, text_a), (ids_b, text_b) = pair

        message = (f"Test{prompt_idx}:"
                   f"\n{name_0}:\t{text_a!r}"
                   f"\n{name_1}:\t{text_b!r}")

        # Both the decoded text and the raw token IDs must match exactly.
        assert text_a == text_b, message
        assert ids_a == ids_b, message
|
|
|
|
|
|
# A generated sequence without prompt logprobs, as a tuple of:
#   * token ID list
#   * decoded output string
#   * optional per-position top sample logprobs, either as plain
#     ``{token_id: logprob}`` dicts or as vLLM ``SampleLogprobs``
TokensTextLogprobs = tuple[
    list[int],
    str,
    Optional[Union[list[dict[int, float]], SampleLogprobs]],
]

# Like ``TokensTextLogprobs`` but with tokens identified by their string
# form rather than their ID:
#   * token string list
#   * decoded output string
#   * optional per-position top sample logprobs keyed by token string
#     (plain floats or ``Logprob`` objects)
# Prompt logprobs are not included.
TextTextLogprobs = tuple[
    list[str],
    str,
    Optional[Union[list[dict[str, float]], list[dict[str, Logprob]]]],
]

# A generated sequence that also carries prompt logprobs, as a tuple of:
#   * token ID list
#   * decoded output string
#   * optional per-position top sample logprobs
#   * optional per-position top prompt logprobs; individual positions may
#     themselves be ``None``
TokensTextLogprobsPromptLogprobs = tuple[
    list[int],
    str,
    Optional[Union[list[dict[int, float]], SampleLogprobs]],
    Optional[Union[list[Optional[dict[int, float]]], PromptLogprobs]],
]
|
|
|
|
|
|
def _warn_always(msg: str) -> None:
    """Emit ``msg`` via ``warnings.warn``, showing it even when it repeats.

    ``stacklevel=3`` skips this helper and the checker that called it, so the
    warning is attributed to the code that invoked the checker.
    """
    with warnings.catch_warnings():
        # This ensures that repeated warnings are shown
        # in the output, not just the first occurrence
        warnings.simplefilter("always")
        warnings.warn(msg, stacklevel=3)


def _check_prompt_logprobs_close(
    prompt_logprobs_0,
    prompt_logprobs_1,
    name_0: str,
    name_1: str,
) -> None:
    """Assert that two prompt-logprobs lists agree position by position.

    Either both lists are ``None`` or neither is. When both are present,
    each position must be ``None`` in both, or non-``None`` in both with
    identical key sets (top-k token choices). Only the positions common to
    both lists are compared, since ``zip`` stops at the shorter one.
    """
    if prompt_logprobs_0 is None or prompt_logprobs_1 is None:
        # Prompt logprobs must be given for both sequences or for neither.
        fail_msg = (f"Prompt logprobs test:"
                    f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}"
                    f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}")
        assert (prompt_logprobs_0 is None
                and prompt_logprobs_1 is None), fail_msg
        return

    # Both lists are present, although individual elements may be `None`.
    for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate(
            zip(prompt_logprobs_0, prompt_logprobs_1)):
        fail_msg = (
            f"Prompt logprobs test:"
            f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}"
            f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}")

        if logprobs_elem_0 is None:
            # A missing entry in seq 0 must also be missing in seq 1.
            assert logprobs_elem_1 is None, fail_msg
        else:
            assert logprobs_elem_1 is not None, fail_msg
            # Top-k token choices must be the same.
            assert (set(logprobs_elem_0.keys()) == set(
                logprobs_elem_1.keys())), fail_msg


def check_logprobs_close(
    *,
    outputs_0_lst: Sequence[Union[TokensTextLogprobs,
                                  TokensTextLogprobsPromptLogprobs,
                                  TextTextLogprobs]],
    outputs_1_lst: Sequence[Union[TokensTextLogprobs,
                                  TokensTextLogprobsPromptLogprobs,
                                  TextTextLogprobs]],
    name_0: str,
    name_1: str,
    num_outputs_0_skip_tokens: int = 0,
    warn_on_mismatch: bool = True,
    always_check_logprobs: bool = False,
) -> None:
    """Compare the logprobs of two sequences generated by different models,
    which should be similar but not necessarily equal.

    Generated tokens are walked position by position. The walk stops at the
    first position where the generated token IDs differ, because the
    sequences diverge after that point.

    How sample logprobs are compared at a checked position:

    * each sequence's sampled token must appear among the other sequence's
      top logprobs, and logprobs must be present for both sequences there
    * `always_check_logprobs == True`: every position up to and including
      the first token mismatch is checked
    * `always_check_logprobs == False`: only the first position where the
      generated token IDs don't match is checked

    If every compared token matches but the output texts differ, only a
    warning is issued (when `warn_on_mismatch` is set).

    Prompt logprobs must be provided either for both input sequences, or
    for neither. If prompt logprobs are provided, then the top-logprob
    prompt token ids must match between seq0 and seq1 at every common
    prompt token offset.

    Args:
        outputs_0_lst: First sequence to compare
        outputs_1_lst: Second sequence to compare
        name_0: sequence #0 name
        name_1: sequence #1 name
        num_outputs_0_skip_tokens: If > 0, specifies the number of initial
            sequence #0 tokens & logprobs to discard before comparison,
            i.e. all of sequence #1 will be compared to sequence #0
            beginning at index num_outputs_0_skip_tokens
        warn_on_mismatch: Issue a warning if there is token-wise or
            text-wise mismatch between the two sequences
        always_check_logprobs: If true, check logprobs even when tokens
            match

    Raises:
        AssertionError: If the output lists or the per-prompt tuples differ
            in length, or if any of the checks above fails.
        ValueError: If `num_outputs_0_skip_tokens` is negative, or if an
            output tuple does not have 3 or 4 elements.
    """
    assert len(outputs_0_lst) == len(outputs_1_lst), (
        f"Number of outputs differs: {len(outputs_0_lst)} vs. "
        f"{len(outputs_1_lst)}")
    # Validate once up front (not once per prompt) so a bad value is
    # rejected even when there is nothing to compare.
    if num_outputs_0_skip_tokens < 0:
        raise ValueError("num_outputs_0_skip_tokens must be non-negative")

    # Loop through responses to each prompt.
    for prompt_idx, (outputs_0, outputs_1) in enumerate(
            zip(outputs_0_lst, outputs_1_lst)):
        assert len(outputs_0) == len(outputs_1)
        if len(outputs_0) == 3:
            # Tokens, text & sample logprobs (prompt logprobs not provided).
            output_ids_0, output_str_0, logprobs_0 = outputs_0
            output_ids_1, output_str_1, logprobs_1 = outputs_1
        elif len(outputs_0) == 4:
            # Tokens, text, sample logprobs & prompt logprobs.
            (output_ids_0, output_str_0, logprobs_0,
             prompt_logprobs_0) = outputs_0
            (output_ids_1, output_str_1, logprobs_1,
             prompt_logprobs_1) = outputs_1
            _check_prompt_logprobs_close(prompt_logprobs_0,
                                         prompt_logprobs_1, name_0, name_1)
        else:
            raise ValueError(f"Outputs tuple must have 3 or 4 elements but "
                             f"{len(outputs_0)} elements were provided: "
                             f"{outputs_0}")

        # Missing sample logprobs become per-token `None` placeholders; the
        # assertions below fail if such a position actually gets checked.
        if logprobs_0 is None:
            logprobs_0 = [None] * len(output_ids_0)
        if logprobs_1 is None:
            logprobs_1 = [None] * len(output_ids_1)

        # Skip specified number of initial sequence #0 tokens
        # & logprobs, leaving output text as-is for simplicity
        # (text mismatches may generate warnings but do not
        # cause the test to fail.)
        output_ids_0 = output_ids_0[num_outputs_0_skip_tokens:]
        logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:]

        # Loop through generated tokens.
        for idx, (output_id_0, output_id_1) in enumerate(
                zip(output_ids_0, output_ids_1)):
            is_tok_mismatch = output_id_0 != output_id_1

            if not (is_tok_mismatch or always_check_logprobs):
                continue

            logprobs_elem_0 = logprobs_0[idx]
            logprobs_elem_1 = logprobs_1[idx]

            # Each predicted token must be in top N logprobs of the other
            fail_msg = (
                f"Test{prompt_idx}:"
                f"\nMatched tokens:\t{output_ids_0[:idx]}"
                f"\n{name_0}:\t{output_str_0!r}\t{logprobs_elem_0}"
                f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}")

            assert logprobs_elem_0 is not None, fail_msg
            assert logprobs_elem_1 is not None, fail_msg
            assert output_id_0 in logprobs_elem_1, fail_msg
            assert output_id_1 in logprobs_elem_0, fail_msg

            if is_tok_mismatch:
                if warn_on_mismatch:
                    _warn_always(fail_msg)
                # Break out since sequences will now diverge. Only a real
                # mismatch may stop the walk: with always_check_logprobs,
                # matching positions must keep going so later offsets are
                # checked too.
                break
        else:
            # No divergence was found: the token outputs match,
            # so the text outputs should exactly match as well.
            if output_str_0 != output_str_1 and warn_on_mismatch:
                fail_msg = (f"Test{prompt_idx}:"
                            f"\n{name_0}:\t{output_str_0!r}"
                            f"\n{name_1}:\t{output_str_1!r}")
                _warn_always(fail_msg)
|
|
|
|
|
|
def build_model_context(
    model_id: str,
    runner: RunnerOption = "auto",
    dtype: ModelDType = "auto",
    model_config_kwargs: Optional[dict[str, Any]] = None,
    mm_processor_kwargs: Optional[dict[str, Any]] = None,
    limit_mm_per_prompt: Optional[dict[str, int]] = None,
    mm_processor_cache_gb: int = 0,
):
    """Build an ``InputProcessingContext`` for a registered example model.

    The model's entry in ``HF_EXAMPLE_MODELS`` supplies the tokenizer,
    tokenizer mode, revision, trust-remote-code flag, HF overrides and
    related settings. Availability and transformers-version checks are run
    with ``on_fail="skip"`` (presumably skipping the calling test on
    failure).

    Args:
        model_id: ID of the model being considered.
        runner: Runner option forwarded to ``ModelConfig``.
        dtype: Model dtype forwarded to ``ModelConfig``.
        model_config_kwargs: Extra keyword arguments for ``ModelConfig``.
        mm_processor_kwargs: Optional processor kwargs to be used by the
            input processor, mapper, dummy data creation, etc.
        limit_mm_per_prompt: Multimodal limits per prompt.
        mm_processor_cache_gb: Multimodal processor cache size, forwarded
            to ``ModelConfig``.

    Returns:
        InputProcessingContext for the model being considered, built from the
        new ``ModelConfig`` and its cached tokenizer.
    """
    example = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    example.check_available_online(on_fail="skip")
    example.check_transformers_version(on_fail="skip")

    extra_config_kwargs = model_config_kwargs or {}
    mm_limits = limit_mm_per_prompt or {}

    config = ModelConfig(
        model_id,
        runner=runner,
        tokenizer=example.tokenizer or model_id,
        tokenizer_mode=example.tokenizer_mode,
        revision=example.revision,
        trust_remote_code=example.trust_remote_code,
        dtype=dtype,
        seed=0,
        mm_processor_kwargs=mm_processor_kwargs,
        limit_mm_per_prompt=mm_limits,
        mm_processor_cache_gb=mm_processor_cache_gb,
        hf_overrides=example.hf_overrides,
        skip_tokenizer_init=example.skip_tokenizer_init,
        enforce_eager=example.enforce_eager,
        **extra_config_kwargs,
    )

    tokenizer = cached_tokenizer_from_config(config)
    return InputProcessingContext(config, tokenizer=tokenizer)
|
|
|
|
|
|
def check_embeddings_close(
    *,
    embeddings_0_lst: Sequence[list[float]],
    embeddings_1_lst: Sequence[list[float]],
    name_0: str,
    name_1: str,
    tol: float = 1e-3,
) -> None:
    """Assert that paired embeddings point in (nearly) the same direction.

    For every prompt, the two embedding vectors must have the same length,
    and their cosine similarity must be at least ``1 - tol``.

    Args:
        embeddings_0_lst: Embeddings from the first model, one per prompt.
        embeddings_1_lst: Embeddings from the second model, one per prompt.
        name_0: Label for the first model in failure messages.
        name_1: Label for the second model in failure messages.
        tol: Allowed shortfall of the cosine similarity below 1.

    Raises:
        AssertionError: On a prompt-count or vector-length mismatch, or when
            the similarity falls below ``1 - tol``.
    """
    assert len(embeddings_0_lst) == len(embeddings_1_lst)

    for prompt_idx, pair in enumerate(zip(embeddings_0_lst, embeddings_1_lst)):
        emb_a, emb_b = pair
        len_a, len_b = len(emb_a), len(emb_b)
        assert len_a == len_b, f"Length mismatch: {len_a} vs. {len_b}"

        vec_a = torch.tensor(emb_a)
        vec_b = torch.tensor(emb_b)
        # Both inputs are 1-D vectors, so similarity is taken along dim 0.
        sim = F.cosine_similarity(vec_a, vec_b, dim=0)

        # Only the first 16 values of each embedding are shown on failure.
        fail_msg = (f"Test{prompt_idx}:"
                    f"\nCosine similarity: \t{sim:.4f}"
                    f"\n{name_0}:\t{emb_a[:16]!r}"
                    f"\n{name_1}:\t{emb_b[:16]!r}")

        assert sim >= 1 - tol, fail_msg
|
|
|
|
|
|
def matryoshka_fy(tensor: torch.Tensor, dimensions: int):
    """Truncate embeddings to their first ``dimensions`` features and
    L2-normalize them.

    Args:
        tensor: Embeddings as a tensor or any array-like accepted by
            ``torch.as_tensor`` (e.g. a nested list of floats). The last
            axis is the feature axis.
        dimensions: Number of leading features to keep.

    Returns:
        A new tensor whose last axis is truncated to ``dimensions`` entries,
        with each embedding scaled to unit L2 norm along that axis.
    """
    # as_tensor accepts an existing Tensor without the copy-construct
    # UserWarning that torch.tensor() emits for Tensor inputs.
    tensor = torch.as_tensor(tensor)
    tensor = tensor[..., :dimensions]
    # Normalize along the same (last) axis that was truncated; dim=1 only
    # coincided with it for 2-D input and failed for 1-D input.
    return F.normalize(tensor, p=2, dim=-1)
|
|
|
|
|
|
def softmax(data):
    """Turn scores into probabilities along the last axis.

    If the last axis has size 1, a sigmoid is applied element-wise. A
    softmax over a single entry would always yield 1.0. Otherwise a softmax
    over the last axis is returned.
    """
    is_single_score = data.shape[-1] == 1
    return F.sigmoid(data) if is_single_score else F.softmax(data, dim=-1)
|
|
|
|
|
|
@dataclass
class ModelInfo:
    """Base record describing a model under test and its test settings.

    NOTE(review): field meanings below are inferred from names and defaults;
    confirm against the test suites that consume this record.
    """
    # Model identifier (presumably a Hugging Face repo ID).
    name: str
    # Architecture class name; empty string when not specified.
    architecture: str = ""
    # dtype presumably used for the vLLM run; "auto" by default.
    dtype: str = "auto"
    # dtype presumably used for the Hugging Face reference run.
    hf_dtype: str = "float32"
    # Optional config overrides (presumably applied to the HF config).
    hf_overrides: Optional[dict[str, Any]] = None
    # Expected default pooling type; empty string when not specified.
    default_pooling_type: str = ""
    # Whether tests should run for this model (presumably False skips it).
    enable_test: bool = True
|
|
|
|
|
|
@dataclass
class EmbedModelInfo(ModelInfo):
    """``ModelInfo`` with extra settings for embedding models."""
    # Reference MTEB score, if known (presumably the expected result).
    mteb_score: Optional[float] = None
    # Whether the model supports Matryoshka (truncatable) embeddings.
    is_matryoshka: bool = False
    # Truncated embedding sizes, presumably used when is_matryoshka is set.
    matryoshka_dimensions: Optional[list[int]] = None
|
|
|
|
|
|
@dataclass
class CLSPoolingEmbedModelInfo(EmbedModelInfo):
    """``EmbedModelInfo`` whose expected default pooling type is "CLS"."""
    default_pooling_type: str = "CLS"
|
|
|
|
|
|
@dataclass
class LASTPoolingEmbedModelInfo(EmbedModelInfo):
    """``EmbedModelInfo`` whose expected default pooling type is "LAST"."""
    default_pooling_type: str = "LAST"
|
|
|
|
|
|
@dataclass
class RerankModelInfo(ModelInfo):
    """``ModelInfo`` with extra settings for reranking models."""
    # Reference MTEB score, if known (presumably the expected result).
    mteb_score: Optional[float] = None
|
|
|
|
|
|
@dataclass
class CLSPoolingRerankModelInfo(RerankModelInfo):
    """``RerankModelInfo`` whose expected default pooling type is "CLS"."""
    default_pooling_type: str = "CLS"
|
|
|
|
|
|
@dataclass
class LASTPoolingRerankModelInfo(RerankModelInfo):
    """``RerankModelInfo`` whose expected default pooling type is "LAST"."""
    default_pooling_type: str = "LAST"
|
|
|
|
|
|
@dataclass
class GenerateModelInfo(ModelInfo):
    """``ModelInfo`` with extra settings for text-generation models."""
    # Overrides the base default ("float32") with "auto".
    hf_dtype: str = "auto"
    # Reference perplexity of the HF model, if known (assumed from the
    # name; TODO confirm against the consuming test).
    hf_ppl: Optional[float] = None
|
|
|
|
|
|
def dummy_hf_overrides(
    hf_config: PretrainedConfig,
    *,
    model_arch: str = "",
    exist_overrides: Optional[dict[str, Any]] = None,
    use_original_num_layers: bool = False,
) -> PretrainedConfig:
    """Shrink ``hf_config`` in place so a dummy model can be built cheaply.

    The text config is reduced to a minimal number of layers and experts.
    Any vision / encoder / audio sub-config that is present (and not
    ``None``) is reduced to a single layer.

    Args:
        hf_config: Config to modify; it is updated in place and returned.
        model_arch: Architecture name, used for per-model tweaks.
            Gemma-3n keeps 3 hidden layers, and the Longcat-Flash
            architectures keep their existing ``num_hidden_layers``.
        exist_overrides: Overrides applied to ``hf_config`` before shrinking.
        use_original_num_layers: Keep the text config's existing
            ``num_layers`` / ``num_hidden_layers`` (1 when absent) instead
            of the minimal values.

    Returns:
        The same ``hf_config`` object, after modification.
    """
    hf_config.update(exist_overrides or {})

    text_config = hf_config.get_text_config()

    # Ensure at least 2 experts per group,
    # since `grouped_topk` assumes top-2.
    n_group = getattr(text_config, "n_group", None)
    num_experts = n_group * 2 if n_group is not None else 2

    if use_original_num_layers:
        # Use the original number of layers from the config
        num_layers = getattr(text_config, "num_layers", 1)
        num_hidden_layers = getattr(text_config, "num_hidden_layers", 1)
    else:
        # Use minimal layers for testing. Gemma-3n gets three layers so both
        # a normal layer and a kv_shared_layer are exercised.
        num_layers = 1
        num_hidden_layers = (3 if model_arch
                             == "Gemma3nForConditionalGeneration" else 1)

    update_dict = {
        "num_layers": num_layers,
        "num_experts": num_experts,
        "num_experts_per_tok": 2,
        "num_local_experts": num_experts,
        # Otherwise there will not be any expert layers
        "first_k_dense_replace": 0,
        # To avoid OOM on DeepSeek-V3
        "n_routed_experts": num_experts,
        # For Gemma-3n
        "num_kv_shared_layers": 1,
    }

    # Update num_hidden_layers for non-Longcat architectures
    if model_arch not in ("LongcatFlashForCausalLM", "LongCatFlashMTPModel"):
        update_dict["num_hidden_layers"] = num_hidden_layers

    text_config.update(update_dict)

    # Shrink multimodal sub-configs, e.g. encoder_config for
    # ibm-granite/granite-speech-3.3-2b and audio_config for
    # Qwen/Qwen2-Audio-7B-Instruct. getattr(..., None) also skips configs
    # that declare the attribute but leave it as None.
    single_layer = {"num_layers": 1, "num_hidden_layers": 1}
    for attr_name, extra in (
        ("vision_config", {}),
        ("encoder_config", {}),
        ("audio_config", {"encoder_layers": 1}),
    ):
        sub_config = getattr(hf_config, attr_name, None)
        if sub_config is not None:
            sub_config.update({**single_layer, **extra})

    return hf_config
|
|
|
|
|
|
def check_transformers_version(model: str,
                               min_transformers_version: Optional[str] = None,
                               max_transformers_version: Optional[str] = None):
    """Check the installed transformers version against the given bounds.

    Builds a throwaway ``_HfExamplesInfo`` entry for ``model`` with the
    version bounds and returns the result of its
    ``check_transformers_version(on_fail="skip")``. Presumably this skips
    the calling test when the version is out of range.
    """
    # Function-local import, presumably to avoid an import cycle with
    # .registry at module load time.
    from .registry import _HfExamplesInfo

    version_bounds = {
        "min_transformers_version": min_transformers_version,
        "max_transformers_version": max_transformers_version,
    }
    example_info = _HfExamplesInfo(model, **version_bounds)
    return example_info.check_transformers_version(on_fail="skip")
|