# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
from collections.abc import Mapping
from typing import Any, Literal, Optional, Union

from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.structured_output.backend_guidance import (
    validate_guidance_grammar)
from vllm.v1.structured_output.backend_lm_format_enforcer import (
    validate_structured_output_request_lm_format_enforcer)
from vllm.v1.structured_output.backend_outlines import (
    validate_structured_output_request_outlines)
from vllm.v1.structured_output.backend_xgrammar import (
    validate_xgrammar_grammar)


class Processor:

    def __init__(
        self,
        vllm_config: VllmConfig,
        tokenizer: TokenizerGroup,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.decoding_config = vllm_config.decoding_config
        self.tokenizer = tokenizer

        self.generation_config_fields = (
            self.model_config.try_get_generation_config())

        self.mm_registry = mm_registry
        self.mm_processor_cache = processor_cache_from_config(
            vllm_config, mm_registry)

        self.input_preprocessor = InputPreprocessor(
            self.model_config,
            self.tokenizer,
            mm_registry,
            mm_processor_cache=self.mm_processor_cache,
        )

    def _validate_logprobs(
        self,
        params: SamplingParams,
    ) -> None:
        max_logprobs = self.model_config.max_logprobs
        if max_logprobs == -1:
            return
        # Validate sample logprobs.
        if params.logprobs and (params.logprobs == -1
                                or params.logprobs > max_logprobs):
            raise ValueError(
                f"Requested sample logprobs of {params.logprobs}, "
                f"which is greater than max allowed: {max_logprobs}")

        # Validate prompt logprobs.
        if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
            raise ValueError(
                f"Requested prompt logprobs of {params.prompt_logprobs}, "
                f"which is greater than max allowed: {max_logprobs}")
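
    # Illustrative example: with model max_logprobs=5, a request with
    # SamplingParams(logprobs=3) passes the checks above, while logprobs=10
    # or logprobs=-1 (meaning "all logprobs") raises ValueError; configuring
    # max_logprobs=-1 disables the check entirely.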

    def _validate_sampling_params(
        self,
        params: SamplingParams,
        lora_request: Optional[LoRARequest],
    ) -> None:
        self._validate_structured_output(params)
        self._validate_logit_bias(params)

        if params.allowed_token_ids is None:
            return
        if not params.allowed_token_ids:
            raise ValueError("allowed_token_ids is provided but empty!")
        if self.tokenizer is None:
            # When skip_tokenizer_init=True, we can't validate token IDs.
            # Skip validation and let the model handle invalid tokens.
            return
        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
        vocab_size = len(tokenizer)
        if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
            raise ValueError(
                "allowed_token_ids contains out-of-vocab token id!")

    def _validate_logit_bias(
        self,
        params: SamplingParams,
    ) -> None:
        """Validate that logit_bias token IDs are within vocabulary range."""
        if not params.logit_bias:
            return

        vocab_size = self.model_config.get_vocab_size()
        invalid_token_ids = []

        for token_id in params.logit_bias:
            if token_id < 0 or token_id >= vocab_size:
                invalid_token_ids.append(token_id)

        if invalid_token_ids:
            raise ValueError(
                f"token_id(s) {invalid_token_ids} in logit_bias contain "
                f"out-of-vocab token ids. Vocabulary size: {vocab_size}")
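
    # Illustrative example: with a vocabulary size of 32000, a request with
    # SamplingParams(logit_bias={31999: 2.0}) passes the check above, while
    # logit_bias={32000: 2.0} raises ValueError since valid ids lie in
    # [0, 32000).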

    def _validate_supported_sampling_params(
        self,
        params: SamplingParams,
    ) -> None:
        # best_of is not yet supported.
        if params.best_of is not None and params.best_of > 1:
            raise ValueError("vLLM V1 does not yet support best_of.")
        # Logits processors are not supported.
        if params.logits_processors:
            raise ValueError("vLLM V1 does not support per-request, "
                             "user-provided logits processors.")

    def _validate_params(
        self,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest],
    ):
        """
        Validate supported SamplingParams.
        Raises ValueError if the request is unsupported by the API server.
        """

        if isinstance(params, PoolingParams):
            return

        self._validate_logprobs(params)
        self._validate_sampling_params(params, lora_request)
        self._validate_supported_sampling_params(params)

    def _validate_multi_modal_uuids(self, prompt: PromptType) -> None:
        """
        Validate that user-provided multi_modal_uuids align with
        multi_modal_data in the incoming request prompt(s).
        Only lengths are checked; `None` entries are allowed and will be
        auto-hashed downstream.
        """

        def _validate_single_prompt(single_prompt: Union[dict, str]) -> None:
            if not isinstance(single_prompt, dict):
                return
            mm_data = single_prompt.get("multi_modal_data")
            mm_uuids = single_prompt.get("multi_modal_uuids")
            if not mm_data or not mm_uuids:
                return

            for modality, items in mm_data.items():
                if modality in mm_uuids:
                    data_len = len(items) if isinstance(items, list) else 1
                    uuid_len = len(mm_uuids[modality]) if isinstance(
                        mm_uuids[modality], list) else 1
                    if uuid_len != data_len:
                        raise ValueError(
                            f"multi_modal_uuids for modality '{modality}' "
                            "must have the same length as the data: got "
                            f"{uuid_len} uuids vs {data_len} items.")
                else:
                    raise ValueError(
                        f"multi_modal_uuids for modality '{modality}' must "
                        "be provided if multi_modal_data is provided.")

        # Handle explicit encoder/decoder prompts or a singleton prompt.
        if isinstance(prompt, dict) and "encoder_prompt" in prompt:
            enc = prompt.get("encoder_prompt")
            dec = prompt.get("decoder_prompt")
            if enc is not None:
                _validate_single_prompt(enc)
            if dec is not None:
                _validate_single_prompt(dec)
        else:
            _validate_single_prompt(prompt)  # type: ignore[arg-type]
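
    # Illustrative example of the alignment rule above (hypothetical data):
    #   multi_modal_data  = {"image": [img0, img1]}
    #   multi_modal_uuids = {"image": ["uuid-a", None]}  # OK: lengths match,
    #                                                    # None auto-hashed
    #   multi_modal_uuids = {"image": ["uuid-a"]}        # ValueError: 1 vs 2
    #   multi_modal_uuids = {"video": ["uuid-b"]}        # ValueError: no
    #                                                    # uuids for "image"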

    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

    def _validate_structured_output(self, params: SamplingParams) -> None:
        if not params.guided_decoding or not self.decoding_config:
            return

        if self.model_config.skip_tokenizer_init and params.guided_decoding:
            raise ValueError(
                "Structured outputs require a tokenizer, so they can't be used with 'skip_tokenizer_init'"  # noqa: E501
            )

        engine_level_backend = self.decoding_config.backend
        if params.guided_decoding.backend:
            # Request-level backend selection is not supported in V1.
            # The values may differ if `params` is reused and was set
            # to a specific backend based on `auto` behavior in a previous
            # request. We remember that it was set as a result of `auto`
            # via the `backend_was_auto` flag on the params.
            if (params.guided_decoding.backend != engine_level_backend
                    and not (engine_level_backend == "auto"
                             and params.guided_decoding.backend_was_auto)):
                raise ValueError(
                    "Request-level structured output backend selection is no "
                    "longer supported. The request specified "
                    f"'{params.guided_decoding.backend}', but vLLM was "
                    f"initialised with '{engine_level_backend}'. This error "
                    "can be resolved by removing backend selection from the "
                    "request.")
        else:
            params.guided_decoding.backend = engine_level_backend

        # Request content validation
        if (isinstance(params.guided_decoding.choice, list)
                and not params.guided_decoding.choice):
            # It is invalid for choice to be an empty list.
            raise ValueError(f"Choice '{params.guided_decoding.choice}' "
                             "cannot be an empty list")

        if engine_level_backend.startswith("xgrammar"):
            # xgrammar with no fallback
            validate_xgrammar_grammar(params)
        elif engine_level_backend.startswith("guidance"):
            # TODO: ideally we would have the LLTokenizer here, as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without a tokenizer, these are disallowed in grammars.
            validate_guidance_grammar(params, tokenizer=None)
        elif engine_level_backend == "outlines":
            # outlines backend
            validate_structured_output_request_outlines(params)
        elif engine_level_backend == "lm-format-enforcer":
            # lm-format-enforcer backend
            validate_structured_output_request_lm_format_enforcer(params)
        else:
            # NOTE: engine_level_backend must be "auto" here, because we have
            # checked supported_backends above.
            # "auto" is an opt-in to opinionated behavior where we try to
            # choose a backend based on request contents. This is not the
            # default, as it is less predictable and subject to change
            # between releases as feature support changes.
            try:
                validate_xgrammar_grammar(params)
                params.guided_decoding.backend = "xgrammar"
            except ValueError:
                # The request either failed validation or includes some
                # jsonschema feature(s) that are not supported in xgrammar.
                # Fall back to guidance.
                validate_guidance_grammar(params, tokenizer=None)
                params.guided_decoding.backend = "guidance"
            # Remember that this backend was set automatically.
            params.guided_decoding.backend_was_auto = True
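
    # Illustrative walk-through of the "auto" fallback above: with the engine
    # initialised with backend="auto", a request whose grammar passes
    # validate_xgrammar_grammar is pinned to backend="xgrammar"; if xgrammar
    # validation raises ValueError (e.g. an unsupported jsonschema feature),
    # the request falls back to backend="guidance". In both cases
    # backend_was_auto=True, so reusing the same SamplingParams object in a
    # later request does not trigger the request-level backend error.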

    def _maybe_build_mm_hash_overrides(
        self,
        request_id: str,
        prompt: PromptType,
    ) -> Optional[dict[str, list[str]]]:
        """Build per-item multimodal hash overrides when enabled. In this
        case, multimodal data items are identified by their request id,
        modality and index rather than by their content.

        Returns a dictionary of modality -> list[str] of overrides, or None
        if disabled or no multimodal data is present.
        """

        def _extract_mm_data(p: PromptType):
            if isinstance(p, dict) and "encoder_prompt" in p:
                enc = p.get("encoder_prompt")
                if isinstance(enc, dict):
                    return enc.get("multi_modal_data")
                return None
            if isinstance(p, dict):
                return p.get("multi_modal_data")
            return None

        mm_data = _extract_mm_data(prompt)
        if not mm_data:
            return None

        overrides: dict[str, list[str]] = {}
        for modality, data in mm_data.items():
            n = len(data) if isinstance(data, list) else 1
            overrides[modality] = [
                f"{request_id}-{modality}-{i}" for i in range(n)
            ]
        return overrides
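
    # Illustrative example: for request_id="req-0" and a prompt carrying
    # multi_modal_data={"image": [img0, img1], "audio": aud0}, the method
    # above returns:
    #   {"image": ["req-0-image-0", "req-0-image-1"],
    #    "audio": ["req-0-audio-0"]}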

    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> tuple[Optional[str], EngineCoreRequest]:

        # TODO(woosuk): Support pooling models.
        # TODO(woosuk): Support encoder-decoder models.
        self._validate_lora(lora_request)
        self._validate_params(params, lora_request)
        if trace_headers is not None:
            raise ValueError("V1 does not support tracing yet.")

        data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if data_parallel_rank is not None and not (0 <= data_parallel_rank <
                                                   data_parallel_size):
            raise ValueError(f"data_parallel_rank {data_parallel_rank} "
                             f"is out of range [0, {data_parallel_size}).")

        if arrival_time is None:
            arrival_time = time.time()

        # Optionally generate multimodal hash overrides to avoid hashing
        # multimodal data items by their content as their identifiers.

        # NOTE: when users explicitly turn off BOTH prefix caching and input
        # processing caching, no multimodal features or embeddings will be
        # reused across requests; therefore, identifying multimodal data items
        # by their content is no longer necessary, and we create uuids with
        # request id-modality-index as multimodal hash overrides.
        if (self.model_config.multimodal_config and
                self.model_config.multimodal_config.mm_processor_cache_gb == 0
                and not self.cache_config.enable_prefix_caching):
            mm_hash_overrides = self._maybe_build_mm_hash_overrides(
                request_id, prompt)
        else:
            # Otherwise, use user-provided uuids as multimodal hash overrides
            # if provided.
            self._validate_multi_modal_uuids(prompt)
            if isinstance(prompt, dict):
                mm_hash_overrides = prompt.get("multi_modal_uuids")
            else:
                mm_hash_overrides = None

        # Process inputs, which includes:
        # 1. Tokenize the text prompt, with LoRA request if one exists.
        # 2. For multimodal models with a merged preprocessor, preprocess
        #    multimodal data and expand prompt token ids accordingly.
        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            mm_hash_overrides=mm_hash_overrides,
        )
        from vllm.platforms import current_platform
        current_platform.validate_request(
            prompt=prompt,
            params=params,
            processed_inputs=processed_inputs,
        )

        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        self._validate_model_inputs(processed_inputs, lora_request)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

        # TODO: Implement encoder-decoder support.
        if encoder_inputs is not None:
            raise NotImplementedError

        sampling_params = None
        pooling_params = None
        if isinstance(params, SamplingParams):
            # TODO: can we avoid cloning here in the multiproc case?
            sampling_params = params.clone()
            # If max_tokens is unset, generate up to the max_model_len.
            if sampling_params.max_tokens is None:
                sampling_params.max_tokens = (
                    self.model_config.max_model_len -
                    len(decoder_inputs["prompt_token_ids"]))
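            # Illustrative example: with max_model_len=4096 and a 1000-token
            # prompt, an unset max_tokens defaults to 4096 - 1000 = 3096.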
            sampling_params.update_from_generation_config(
                self.generation_config_fields, eos_token_id)
            if self.tokenizer is not None:
                sampling_params.update_from_tokenizer(
                    self.tokenizer.get_lora_tokenizer(lora_request))
        else:
            pooling_params = params.clone()

        # Multimodal related.
        mm_features: Optional[list[MultiModalFeatureSpec]] = None

        if decoder_inputs["type"] == "multimodal":
            decoder_mm_inputs = decoder_inputs["mm_kwargs"]
            decoder_mm_positions = decoder_inputs["mm_placeholders"]
            decoder_mm_hashes = decoder_inputs["mm_hashes"]

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's
            # position in the input sequence.
            sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)

            mm_features = []
            for modality, idx in sorted_mm_idxs:
                mm_features.append(
                    MultiModalFeatureSpec(
                        data=decoder_mm_inputs[modality][idx],
                        modality=modality,
                        identifier=decoder_mm_hashes[modality][idx],
                        mm_position=decoder_mm_positions[modality][idx]))

        return decoder_inputs.get("prompt"), EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=decoder_inputs["prompt_token_ids"],
            mm_features=mm_features,
            sampling_params=sampling_params,
            pooling_params=pooling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
            cache_salt=decoder_inputs.get("cache_salt"),
            priority=priority,
            data_parallel_rank=data_parallel_rank,
        )
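
    # Minimal usage sketch (illustrative; in practice the engine frontend
    # constructs the Processor and calls this for every incoming request).
    # Assumes `processor` is an already-constructed Processor and uses
    # hypothetical request values:
    #
    #   prompt_str, engine_request = processor.process_inputs(
    #       request_id="req-0",
    #       prompt="Hello, world!",
    #       params=SamplingParams(max_tokens=16),
    #   )
    #   # engine_request is an EngineCoreRequest ready for the engine core.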

    def _validate_model_inputs(self,
                               inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest] = None):
        encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)

        if encoder_inputs is not None:
            self._validate_model_input(encoder_inputs,
                                       lora_request,
                                       prompt_type="encoder")

        self._validate_model_input(decoder_inputs,
                                   lora_request,
                                   prompt_type="decoder")

    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        lora_request: Optional[LoRARequest],
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
        model_config = self.model_config

        prompt_ids = prompt_inputs["prompt_token_ids"]
        if not prompt_ids:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                pass  # Mllama may have empty encoder inputs for text-only data
            else:
                raise ValueError(f"The {prompt_type} prompt cannot be empty")

        if self.model_config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
            max_input_id = max(prompt_ids, default=0)

            # NOTE: tokenizer.max_token_id is the tokenizer's vocab size,
            # while self.model_config.get_vocab_size() is the model's vocab
            # size. For Qwen3 models, the language model has extra tokens
            # that do not exist in the tokenizer, and vice versa for
            # multimodal placeholder tokens in some multimodal models.
            # See https://github.com/QwenLM/Qwen3/issues/29#issuecomment-1933720399 # noqa: E501
            # and https://github.com/vllm-project/vllm/pull/22471#discussion_r2312251421 # noqa: E501

            # Here we take the max of the two to determine if a token id is
            # truly out-of-vocabulary.
            if max_input_id > max(tokenizer.max_token_id,
                                  self.model_config.get_vocab_size() - 1):
                raise ValueError(
                    f"Token id {max_input_id} is out of vocabulary")
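
        # Illustrative example: if tokenizer.max_token_id were 151645 while
        # the model's get_vocab_size() returned 151936, prompt token ids up
        # to 151935 would pass the check above.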

        max_prompt_len = self.model_config.max_model_len
        if len(prompt_ids) > max_prompt_len:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
                mm_registry = self.input_preprocessor.mm_registry
                mm_processor = mm_registry.create_processor(
                    model_config,
                    tokenizer=tokenizer,
                )
                assert isinstance(mm_processor, EncDecMultiModalProcessor)

                if mm_processor.pad_dummy_encoder_prompt:
                    return  # Skip encoder length check for Whisper and Donut

            if model_config.is_multimodal_model:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens plus multimodal tokens. For image "
                    "inputs, the number of image tokens depends on the number "
                    "of images, and possibly their aspect ratios as well.")
            else:
                suggestion = (
                    "Make sure that `max_model_len` is no smaller than the "
                    "number of text tokens.")

            raise ValueError(
                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
                f"longer than the maximum model length of {max_prompt_len}. "
                f"{suggestion}")

            # TODO: Find out how many placeholder tokens are there so we can
            # check that chunked prefill does not truncate them
            # max_batch_len = self.scheduler_config.max_num_batched_tokens

    def clear_cache(self) -> None:
        self.input_preprocessor.clear_cache()