Files
vllm-dev/vllm/v1/engine/processor.py
Xiaodong Wang 81eea3d348 vllm fix check on max vocab size (#22471)
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Roger Wang <hey@rogerw.me>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.me>
2025-08-31 20:57:05 +08:00

524 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections.abc import Mapping
from typing import Any, Literal, Optional, Union
from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.backend_lm_format_enforcer import (
validate_structured_output_request_lm_format_enforcer)
from vllm.v1.structured_output.backend_outlines import (
validate_structured_output_request_outlines)
from vllm.v1.structured_output.backend_xgrammar import (
validate_xgrammar_grammar)
class Processor:
def __init__(
    self,
    vllm_config: VllmConfig,
    tokenizer: TokenizerGroup,
    mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
    """Store configuration handles and build the input preprocessor.

    Args:
        vllm_config: Engine configuration. Its model, cache, LoRA and
            decoding sub-configs are kept as instance attributes.
        tokenizer: Tokenizer group kept on the instance and passed to the
            input preprocessor.
        mm_registry: Multimodal registry used to create the processor
            cache and passed to the input preprocessor.
    """
    model_config = vllm_config.model_config

    # Plain configuration handles.
    self.vllm_config = vllm_config
    self.model_config = model_config
    self.cache_config = vllm_config.cache_config
    self.lora_config = vllm_config.lora_config
    self.decoding_config = vllm_config.decoding_config
    self.tokenizer = tokenizer
    self.mm_registry = mm_registry

    # Generation-config fields are later applied to each request's
    # sampling params in `process_inputs`.
    self.generation_config_fields = (
        model_config.try_get_generation_config())

    # The processor cache is created first because the preprocessor
    # receives it as a constructor argument.
    self.mm_processor_cache = processor_cache_from_config(
        vllm_config, mm_registry)
    self.input_preprocessor = InputPreprocessor(
        model_config,
        tokenizer,
        mm_registry,
        mm_processor_cache=self.mm_processor_cache,
    )
def _validate_logprobs(
self,
params: SamplingParams,
) -> None:
max_logprobs = self.model_config.max_logprobs
if max_logprobs == -1:
return
# Validate sample logprobs.
if params.logprobs and (params.logprobs == -1
or params.logprobs > max_logprobs):
raise ValueError(
f"Requested sample logprobs of {params.logprobs}, "
f"which is greater than max allowed: {max_logprobs}")
# Validate prompt logprobs.
if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
raise ValueError(
f"Requested prompt logprobs of {params.prompt_logprobs}, "
f"which is greater than max allowed: {max_logprobs}")
def _validate_sampling_params(
self,
params: SamplingParams,
lora_request: Optional[LoRARequest],
) -> None:
self._validate_structured_output(params)
self._validate_logit_bias(params)
if params.allowed_token_ids is None:
return
if not params.allowed_token_ids:
raise ValueError("allowed_token_ids is not None and empty!")
if self.tokenizer is None:
# When skip_tokenizer_init=True, we can't validate token IDs
# Skip validation and let the model handle invalid tokens
return
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
vocab_size = len(tokenizer)
if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
raise ValueError(
"allowed_token_ids contains out-of-vocab token id!")
def _validate_logit_bias(
self,
params: SamplingParams,
) -> None:
"""Validate logit_bias token IDs are within vocabulary range."""
if not params.logit_bias:
return
vocab_size = self.model_config.get_vocab_size()
invalid_token_ids = []
for token_id in params.logit_bias:
if token_id < 0 or token_id >= vocab_size:
invalid_token_ids.append(token_id)
if invalid_token_ids:
raise ValueError(
f"token_id(s) {invalid_token_ids} in logit_bias contain "
f"out-of-vocab token ids. Vocabulary size: {vocab_size}")
def _validate_supported_sampling_params(
self,
params: SamplingParams,
) -> None:
# Best of not yet supported.
if params.best_of is not None and params.best_of > 1:
raise ValueError("vLLM V1 does not yet support best_of.")
# Logits processors not supported.
if params.logits_processors:
raise ValueError("vLLM V1 does not support per request "
"user provided logits processors.")
def _validate_params(
    self,
    params: Union[SamplingParams, PoolingParams],
    lora_request: Optional[LoRARequest],
):
    """
    Validate the request parameters.

    `PoolingParams` are accepted without further checks. Anything else is
    run through the logprobs, sampling-params and supported-feature
    validators, in that order.
    Should raise ValueError if unsupported for API Server.
    """
    if not isinstance(params, PoolingParams):
        self._validate_logprobs(params)
        self._validate_sampling_params(params, lora_request)
        self._validate_supported_sampling_params(params)
def _validate_multi_modal_uuids(self, prompt: PromptType) -> None:
"""
Validate that user-provided multi_modal_uuids align with
multi_modal_data in the incoming request prompt(s).
Only checks lengths; `None` entries are allowed and will be
auto-hashed downstream.
"""
def _validate_single_prompt(single_prompt: Union[dict, str]) -> None:
if not isinstance(single_prompt, dict):
return
mm_data = single_prompt.get("multi_modal_data")
mm_uuids = single_prompt.get("multi_modal_uuids")
if not mm_data or not mm_uuids:
return
for modality, items in mm_data.items():
if modality in mm_uuids:
data_len = len(items) if isinstance(items, list) else 1
uuid_len = len(mm_uuids[modality]) if isinstance(
mm_uuids[modality], list) else 1
if uuid_len != data_len:
raise ValueError(
f"multi_modal_uuids for modality '{modality}' "
"must have same length as data: got "
f"{uuid_len} uuids vs "
f"{data_len} items.")
else:
raise ValueError(
f"multi_modal_uuids for modality '{modality}' must "
"be provided if multi_modal_data is provided.")
# Handle explicit encoder/decoder prompts or singleton prompt
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
enc = prompt.get("encoder_prompt")
dec = prompt.get("decoder_prompt")
if enc is not None:
_validate_single_prompt(enc)
if dec is not None:
_validate_single_prompt(dec)
else:
_validate_single_prompt(prompt) # type: ignore[arg-type]
def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
def _validate_structured_output(self, params: SamplingParams) -> None:
if not params.guided_decoding or not self.decoding_config:
return
if self.model_config.skip_tokenizer_init and params.guided_decoding:
raise ValueError(
"Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'" # noqa: E501
)
engine_level_backend = self.decoding_config.backend
if params.guided_decoding.backend:
# Request-level backend selection is not supported in V1.
# The values may differ if `params` is reused and was set
# to a specific backend based on `auto` behavior in a previous
# request. We remember that it was set as a result of `auto`
# using the `_auto` option set on the backend in the params.
if (params.guided_decoding.backend != engine_level_backend
and not (engine_level_backend == "auto"
and params.guided_decoding.backend_was_auto)):
raise ValueError(
"Request-level structured output backend selection is no "
"longer supported. The request specified "
f"'{params.guided_decoding.backend}', but vLLM was "
f"initialised with '{engine_level_backend}'. This error "
"can be resolved by removing backend selection from the "
"request.")
else:
params.guided_decoding.backend = engine_level_backend
# Request content validation
if (isinstance(params.guided_decoding.choice, list)
and not params.guided_decoding.choice):
# It is invalid for choice to be an empty list
raise ValueError(f"Choice '{params.guided_decoding.choice}' "
"cannot be an empty list")
if engine_level_backend.startswith("xgrammar"):
# xgrammar with no fallback
validate_xgrammar_grammar(params)
elif engine_level_backend.startswith("guidance"):
# TODO: ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
validate_guidance_grammar(params, tokenizer=None)
elif engine_level_backend == "outlines":
# outlines backend
validate_structured_output_request_outlines(params)
elif engine_level_backend == "lm-format-enforcer":
# lm format enforcer backend
validate_structured_output_request_lm_format_enforcer(params)
else:
# NOTE: engine_level_backend must be "auto" here, because we have
# checked supported_backends above.
# "auto" is an opt-in to opinionated behavior where we try to
# choose a backend based on request contents. This is not the
# default as it is less predictable and subject to change
# between releases as feature support changes.
try:
validate_xgrammar_grammar(params)
params.guided_decoding.backend = "xgrammar"
except ValueError:
# The request either failed validation
# or includes some jsonschema feature(s) that
# are not supported in xgrammar. Fall back to guidance.
validate_guidance_grammar(params, tokenizer=None)
params.guided_decoding.backend = "guidance"
# Remember that this backend was set automatically
params.guided_decoding.backend_was_auto = True
def _maybe_build_mm_hash_overrides(
self,
request_id: str,
prompt: PromptType,
) -> Optional[dict[str, list[str]]]:
"""Build per-item multimodal hash overrides when enabled. In this case,
multimodal data items are identified by their request id, modality and
index rather than their content.
Returns a dictionary of modality -> list[str] of overrides, or None if
disabled or no multimodal data is present.
"""
def _extract_mm_data(p: PromptType):
if isinstance(p, dict) and "encoder_prompt" in p:
enc = p.get("encoder_prompt")
if isinstance(enc, dict):
return enc.get("multi_modal_data")
return None
if isinstance(p, dict):
return p.get("multi_modal_data")
return None
mm_data = _extract_mm_data(prompt)
if not mm_data:
return None
overrides: dict[str, list[str]] = {}
for modality, data in mm_data.items():
n = len(data) if isinstance(data, list) else 1
overrides[modality] = [
f"{request_id}-{modality}-{i}" for i in range(n)
]
return overrides
def process_inputs(
    self,
    request_id: str,
    prompt: PromptType,
    params: Union[SamplingParams, PoolingParams],
    arrival_time: Optional[float] = None,
    lora_request: Optional[LoRARequest] = None,
    tokenization_kwargs: Optional[dict[str, Any]] = None,
    trace_headers: Optional[Mapping[str, str]] = None,
    priority: int = 0,
    data_parallel_rank: Optional[int] = None,
) -> tuple[Optional[str], EngineCoreRequest]:
    """Validate a request and convert it into an `EngineCoreRequest`.

    Steps, in order: validate the LoRA request and params, reject
    tracing headers, range-check `data_parallel_rank`, pick multimodal
    hash overrides, preprocess the prompt, run platform and model-input
    validation, clone and finalise the sampling/pooling params, and
    flatten multimodal features sorted by position.

    Args:
        request_id: Identifier placed on the resulting request and used
            to derive multimodal identifiers when content hashing is
            bypassed.
        prompt: The raw prompt.
        params: Sampling or pooling parameters. They are cloned, so the
            caller's object is not the one stored on the request.
        arrival_time: Defaults to `time.time()` when None.
        lora_request: Optional LoRA request.
        tokenization_kwargs: Forwarded to the input preprocessor.
        trace_headers: Must be None.
        priority: Stored on the resulting request.
        data_parallel_rank: Optional rank; must lie in
            `[0, data_parallel_size)`.

    Returns:
        A tuple of the decoder inputs' `"prompt"` entry (None if absent)
        and the constructed `EngineCoreRequest`.

    Raises:
        ValueError: On any validation failure described above.
        NotImplementedError: If preprocessing yields encoder inputs.
    """
    # TODO(woosuk): Support pooling models.
    # TODO(woosuk): Support encoder-decoder models.
    self._validate_lora(lora_request)
    self._validate_params(params, lora_request)
    if trace_headers is not None:
        raise ValueError("V1 does not support tracing yet.")

    dp_size = self.vllm_config.parallel_config.data_parallel_size
    if data_parallel_rank is not None:
        if not 0 <= data_parallel_rank < dp_size:
            raise ValueError(f"data_parallel_rank {data_parallel_rank} "
                             f"is out of range [0, {dp_size}).")

    arrival_time = time.time() if arrival_time is None else arrival_time

    # Optionally generate multimodal hash overrides to avoid hashing
    # multimodal data items by their content as their identifiers.
    # NOTE: when users explicitly turn off BOTH prefix caching and input
    # processing caching, no multimodal features or embeddings will be
    # reused across requests, therefore identifying multimodal data items
    # by their content is no longer necessary, and we create uuids with
    # request id-modality-index as multimodal hash overrides.
    mm_config = self.model_config.multimodal_config
    if (mm_config and mm_config.mm_processor_cache_gb == 0
            and not self.cache_config.enable_prefix_caching):
        mm_hash_overrides = self._maybe_build_mm_hash_overrides(
            request_id, prompt)
    else:
        # Otherwise, use user-provided uuids as multimodal hash overrides
        # if provided.
        self._validate_multi_modal_uuids(prompt)
        mm_hash_overrides = (prompt.get("multi_modal_uuids")
                             if isinstance(prompt, dict) else None)

    # Process inputs, which includes:
    # 1. Tokenize text prompt, with LoRA request if one exists.
    # 2. For multimodal models with a merged preprocessor, preprocess
    #    multimodal data and expand prompt token ids accordingly.
    processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
        prompt,
        tokenization_kwargs=tokenization_kwargs,
        lora_request=lora_request,
        mm_hash_overrides=mm_hash_overrides,
    )

    # Imported at call time, as in the original implementation.
    from vllm.platforms import current_platform
    current_platform.validate_request(
        prompt=prompt,
        params=params,
        processed_inputs=processed_inputs,
    )

    eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

    self._validate_model_inputs(processed_inputs, lora_request)

    encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

    # TODO: Impl encoder-decoder
    if encoder_inputs is not None:
        raise NotImplementedError

    sampling_params = None
    pooling_params = None
    if not isinstance(params, SamplingParams):
        pooling_params = params.clone()
    else:
        # TODO: can we avoid cloning here in multiproc case?
        sampling_params = params.clone()
        # If unset max tokens, then generate up to the max_model_len.
        if sampling_params.max_tokens is None:
            prompt_len = len(decoder_inputs["prompt_token_ids"])
            sampling_params.max_tokens = (
                self.model_config.max_model_len - prompt_len)
        sampling_params.update_from_generation_config(
            self.generation_config_fields, eos_token_id)
        if self.tokenizer is not None:
            sampling_params.update_from_tokenizer(
                self.tokenizer.get_lora_tokenizer(lora_request))

    # Multimodal related.
    mm_features: Optional[list[MultiModalFeatureSpec]] = None
    if decoder_inputs["type"] == "multimodal":
        mm_inputs = decoder_inputs["mm_kwargs"]
        mm_positions = decoder_inputs["mm_placeholders"]
        mm_hashes = decoder_inputs["mm_hashes"]

        # Merge and flatten multimodal placeholders, hashes and inputs
        # from dictionaries to lists, and sort them by each item's position
        # in the input sequence.
        mm_features = [
            MultiModalFeatureSpec(
                data=mm_inputs[modality][idx],
                modality=modality,
                identifier=mm_hashes[modality][idx],
                mm_position=mm_positions[modality][idx])
            for modality, idx in argsort_mm_positions(mm_positions)
        ]

    prompt_text = decoder_inputs.get("prompt")
    engine_request = EngineCoreRequest(
        request_id=request_id,
        prompt_token_ids=decoder_inputs["prompt_token_ids"],
        mm_features=mm_features,
        sampling_params=sampling_params,
        pooling_params=pooling_params,
        eos_token_id=eos_token_id,
        arrival_time=arrival_time,
        lora_request=lora_request,
        cache_salt=decoder_inputs.get("cache_salt"),
        priority=priority,
        data_parallel_rank=data_parallel_rank,
    )
    return prompt_text, engine_request
def _validate_model_inputs(self,
                           inputs: ProcessorInputs,
                           lora_request: Optional[LoRARequest] = None):
    """Validate the encoder (if any) and decoder parts of processed inputs.

    The inputs are split with `split_enc_dec_inputs`. The encoder part is
    validated only when present; the decoder part is always validated.
    """
    enc_part, dec_part = split_enc_dec_inputs(inputs)

    if enc_part is not None:
        self._validate_model_input(
            enc_part, lora_request, prompt_type="encoder")

    self._validate_model_input(
        dec_part, lora_request, prompt_type="decoder")
def _validate_model_input(
self,
prompt_inputs: SingletonInputs,
lora_request: Optional[LoRARequest],
*,
prompt_type: Literal["encoder", "decoder"],
):
model_config = self.model_config
prompt_ids = prompt_inputs["prompt_token_ids"]
if not prompt_ids:
if prompt_type == "encoder" and model_config.is_multimodal_model:
pass # Mllama may have empty encoder inputs for text-only data
else:
raise ValueError(f"The {prompt_type} prompt cannot be empty")
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
max_input_id = max(prompt_ids, default=0)
# NOTE: tokenizer.max_token_id is the tokenizers vocab size while
# self.model_config.get_vocab_size() is the models vocab size.
# For Qwen3 models, the language model has extra tokens that do
# not exist in the tokenizer, and vice versa for multimodal
# placeholder tokens in some multimodal models.
# See https://github.com/QwenLM/Qwen3/issues/29#issuecomment-1933720399 # noqa: E501
# and https://github.com/vllm-project/vllm/pull/22471#discussion_r2312251421 # noqa: E501
# Here we take the max of the two to determine if a token id is
# truly out-of-vocabulary.
if max_input_id > max(tokenizer.max_token_id,
self.model_config.get_vocab_size() - 1):
raise ValueError(
f"Token id {max_input_id} is out of vocabulary")
max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) > max_prompt_len:
if prompt_type == "encoder" and model_config.is_multimodal_model:
mm_registry = self.input_preprocessor.mm_registry
mm_processor = mm_registry.create_processor(
model_config,
tokenizer=tokenizer,
)
assert isinstance(mm_processor, EncDecMultiModalProcessor)
if mm_processor.pad_dummy_encoder_prompt:
return # Skip encoder length check for Whisper and Donut
if model_config.is_multimodal_model:
suggestion = (
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens plus multimodal tokens. For image "
"inputs, the number of image tokens depends on the number "
"of images, and possibly their aspect ratios as well.")
else:
suggestion = (
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens.")
raise ValueError(
f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
f"longer than the maximum model length of {max_prompt_len}. "
f"{suggestion}")
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
def clear_cache(self) -> None:
    """Delegate cache clearing to the input preprocessor."""
    preprocessor = self.input_preprocessor
    preprocessor.clear_cache()