[Frontend][Core] Move guided decoding params into sampling params (#8252)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Joe Runde
2024-09-30 19:34:25 -06:00
committed by GitHub
parent bce324487a
commit 062c89e7c9
16 changed files with 441 additions and 281 deletions
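
In short: guided decoding is now configured on SamplingParams.guided_decoding (a GuidedDecodingParams) instead of being passed to generate() through the guided_options_request dict. A minimal sketch of the new offline usage, mirroring the updated tests below; the model name and regex are illustrative only:

from vllm.entrypoints.llm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

# Model choice is an example; any vLLM-supported model works the same way.
llm = LLM(model="HuggingFaceH4/zephyr-7b-beta", max_model_len=1024)

# Guided decoding options now live on SamplingParams rather than being
# passed to llm.generate(...) via guided_options_request.
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    guided_decoding=GuidedDecodingParams(regex=r"\d+\.\d+\.\d+\.\d+"),
)

outputs = llm.generate(
    prompts="Give an example IPv4 address: ",
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)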

View File

@ -7,7 +7,7 @@ import pytest
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from ...conftest import cleanup
@ -31,14 +31,12 @@ def test_guided_regex(sample_regex, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
)
outputs = llm.generate(
prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))
guided_decoding=GuidedDecodingParams(regex=sample_regex))
outputs = llm.generate(prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
@ -57,15 +55,13 @@ def test_guided_json_completion(sample_json_schema, llm):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
)
outputs = llm.generate(
prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_json=sample_json_schema))
guided_decoding=GuidedDecodingParams(json=sample_json_schema))
outputs = llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
@ -86,12 +82,11 @@ def test_guided_choice_completion(sample_guided_choice, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
)
guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
outputs = llm.generate(
prompts="The best language for type-safe systems programming is ",
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_choice=sample_guided_choice))
use_tqdm=True)
assert outputs is not None
for output in outputs:
@ -112,13 +107,13 @@ def test_guided_grammar(sample_sql_statements, llm):
temperature=0.8,
top_p=0.95,
max_tokens=1000,
)
guided_decoding=GuidedDecodingParams(grammar=sample_sql_statements))
outputs = llm.generate(
prompts=("Generate a sql state that select col_1 from "
"table_1 where it is equals to 1"),
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_grammar=sample_sql_statements))
)
assert outputs is not None
for output in outputs:
@ -140,3 +135,28 @@ def test_guided_grammar(sample_sql_statements, llm):
assert generated_text.strip() == ground_truth
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
def test_guided_options_request_deprecation_warning(sample_regex, llm):
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
with pytest.warns(DeprecationWarning, match="guided_options_request"):
llm.generate(prompts="This should fail",
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))
@pytest.mark.skip_global_cleanup
def test_validation_against_both_guided_decoding_options(sample_regex, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(regex=sample_regex))
with pytest.raises(ValueError, match="Cannot set both"):
llm.generate(prompts="This should fail",
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))

View File

@ -0,0 +1,49 @@
import pytest
@pytest.fixture
def sample_regex():
return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
@pytest.fixture
def sample_json_schema():
return {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"skills": {
"type": "array",
"items": {
"type": "string",
"maxLength": 10
},
"minItems": 3
},
"work_history": {
"type": "array",
"items": {
"type": "object",
"properties": {
"company": {
"type": "string"
},
"duration": {
"type": "number"
},
"position": {
"type": "string"
}
},
"required": ["company", "position"]
}
}
},
"required": ["name", "age", "skills", "work_history"]
}

View File

@ -1,14 +1,12 @@
# This unit test should be moved to a new
# tests/test_guided_decoding directory.
import pytest
import torch
from transformers import AutoTokenizer
from vllm.entrypoints.openai.protocol import CompletionRequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.sampling_params import GuidedDecodingParams
def test_guided_logits_processors(sample_regex, sample_json_schema):
@ -44,11 +42,9 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
token_ids = tokenizer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
regex_request = CompletionRequest(model='test',
prompt=token_ids,
guided_regex=sample_regex)
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
regex_lp = await get_guided_decoding_logits_processor(
backend, regex_request, tokenizer)
regex_request, tokenizer)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
@ -59,14 +55,31 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
token_ids = tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
json_request = CompletionRequest(model='test',
prompt=token_ids,
guided_json=sample_json_schema)
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = await get_guided_decoding_logits_processor(
backend, json_request, tokenizer)
json_request, tokenizer)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, regex=sample_regex)
with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, json_object=True)
with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, choice=["a", "b"])
with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")

View File

@ -20,6 +20,8 @@ from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
@ -477,6 +479,18 @@ class _AsyncLLMEngine(LLMEngine):
)
processed_inputs = self.input_processor(preprocessed_inputs)
if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
# Guided decoding has an async implementation for building logits
# processors in a separate threadpool.
# We want to invoke that here instead of using the blocking
# implementation in the LLMEngine
params = await build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=self.get_tokenizer(lora_request),
default_guided_backend=self.decoding_config.
guided_decoding_backend)
self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
@ -494,6 +508,36 @@ class _AsyncLLMEngine(LLMEngine):
self.model_executor.check_health()
async def build_guided_decoding_logits_processor_async(
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
default_guided_backend: str) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. Modifies sampling params in-place and returns
the modified sampling params."""
if (guided_decoding := sampling_params.guided_decoding) is None:
return sampling_params
logger.debug("Building guided decoding logits processor. "
"Params: %s", guided_decoding)
guided_decoding.backend = guided_decoding.backend or default_guided_backend
processor = await get_guided_decoding_logits_processor(
guided_params=guided_decoding, tokenizer=tokenizer)
if processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(processor)
# Unset guided decoding params after constructing the lp from them
sampling_params.guided_decoding = None
return sampling_params
class AsyncLLMEngine:
"""An asynchronous wrapper for :class:`LLMEngine`.

View File

@ -25,6 +25,7 @@ from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
@ -33,6 +34,8 @@ from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
@ -843,6 +846,9 @@ class LLMEngine:
raise ValueError(f"Cannot request more than "
f"{max_logprobs} logprobs.")
sampling_params = self._build_logits_processors(
sampling_params, lora_request)
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
@ -1895,3 +1901,51 @@ class LLMEngine:
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
def _build_logits_processors(
self, sampling_params: SamplingParams,
lora_request: Optional[LoRARequest]) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. Returns the modified sampling params."""
logits_processors = []
if (guided_decoding := sampling_params.guided_decoding) is not None:
logger.debug(
"Building guided decoding logits processor in "
"LLMEngine. Params: %s", guided_decoding)
tokenizer = self.get_tokenizer(lora_request=lora_request)
guided_decoding.backend = guided_decoding.backend or \
self.decoding_config.guided_decoding_backend
processor = get_local_guided_decoding_logits_processor(
guided_params=guided_decoding, tokenizer=tokenizer)
if processor:
logits_processors.append(processor)
# Unset so this doesn't get passed down to the model
sampling_params.guided_decoding = None
if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
tokenizer = self.get_tokenizer(lora_request=lora_request)
processors = get_logits_processors(
logit_bias=sampling_params.logit_bias,
allowed_token_ids=sampling_params.allowed_token_ids,
tokenizer=tokenizer)
logits_processors.extend(processors)
# Unset so these don't get passed down to the model
sampling_params.logit_bias = None
sampling_params.allowed_token_ids = None
if logits_processors:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = logits_processors
else:
sampling_params.logits_processors.extend(logits_processors)
return sampling_params
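
Since SamplingParams now also carries logit_bias and allowed_token_ids, the engine converts those into logits processors in the same place and clears the fields so they are not passed down to the model. A small sketch; the token ids are made up:

from vllm.sampling_params import SamplingParams

# Hypothetical token ids: bias one token strongly down while restricting
# sampling to a small allow-list.
params = SamplingParams(
    temperature=0.0,
    logit_bias={1234: -100.0},
    allowed_token_ids=[5, 42, 1234],
)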

View File

@ -16,6 +16,8 @@ from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.async_llm_engine import (
build_guided_decoding_logits_processor_async)
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
@ -512,6 +514,18 @@ class MQLLMEngineClient:
if self._errored_with is not None:
raise ENGINE_DEAD_ERROR(self._errored_with)
# Constructing guided decoding logits processors is expensive, so we do
# it here to avoid contending with cpu resources and the GIL on the
# backend process.
if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
params = await \
build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=await self.get_tokenizer(lora_request),
default_guided_backend=self.decoding_config.guided_decoding_backend
)
# 1) Create output queue for this requests.
queue: asyncio.Queue[Union[RequestOutput,
BaseException]] = asyncio.Queue()

View File

@ -1,4 +1,5 @@
import itertools
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
@ -16,13 +17,13 @@ from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest, LLMGuidedOptions)
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
SamplingParams)
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@ -798,6 +799,14 @@ class LLM:
guided_options: Optional[GuidedDecodingRequest] = None,
priority: Optional[List[int]] = None,
) -> None:
if guided_options is not None:
warnings.warn(
"guided_options_request is deprecated, use "
"SamplingParams.guided_decoding instead",
DeprecationWarning,
stacklevel=2,
)
if isinstance(prompts, (str, dict)):
# Convert a single prompt to a list.
prompts = [prompts]
@ -813,7 +822,7 @@ class LLM:
for sp in params if isinstance(params, list) else (params, ):
if isinstance(sp, SamplingParams):
self._add_guided_processor(sp, guided_options)
self._add_guided_params(sp, guided_options)
# We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY
@ -847,22 +856,25 @@ class LLM:
priority=priority,
)
def _add_guided_processor(
def _add_guided_params(
self,
params: SamplingParams,
guided_options: Optional[GuidedDecodingRequest] = None):
if guided_options:
if guided_options.guided_decoding_backend is None:
decoding_config = self.llm_engine.get_decoding_config()
guided_options.guided_decoding_backend = (
decoding_config.guided_decoding_backend)
guided_logits_processor = get_local_guided_decoding_logits_processor( #noqa
guided_options.guided_decoding_backend, guided_options,
self.get_tokenizer())
if guided_logits_processor:
if params.logits_processors is None:
params.logits_processors = []
params.logits_processors.append(guided_logits_processor)
if guided_options is None:
return params
if params.guided_decoding is not None:
raise ValueError("Cannot set both guided_options_request and"
"params.guided_decoding.")
params.guided_decoding = GuidedDecodingParams(
json=guided_options.guided_json,
regex=guided_options.guided_regex,
choice=guided_options.guided_choice,
grammar=guided_options.guided_grammar,
json_object=guided_options.guided_json_object,
backend=guided_options.guided_decoding_backend,
whitespace_pattern=guided_options.guided_whitespace_pattern)
return params
def _run_engine(
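
To illustrate the deprecation path handled above (a sketch; llm is an already-constructed LLM and the regex is a placeholder): the legacy dict still works but now emits a DeprecationWarning, and combining it with the new field raises an error.

from vllm.sampling_params import GuidedDecodingParams, SamplingParams

regex = r"\d+"

# Preferred: guided decoding configured on SamplingParams directly.
new_style = SamplingParams(guided_decoding=GuidedDecodingParams(regex=regex))
# outputs = llm.generate("Pick a number: ", sampling_params=new_style)

# Deprecated: still converted by _add_guided_params, but warns.
# outputs = llm.generate("Pick a number: ",
#                        sampling_params=SamplingParams(),
#                        guided_options_request=dict(guided_regex=regex))

# Setting both at once raises ValueError("Cannot set both ...").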

View File

@ -10,12 +10,10 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated, Required, TypedDict
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (LogitsProcessor, RequestOutputKind,
from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
SamplingParams)
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
# torch is mocked during docs generation,
@ -284,10 +282,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params
def to_sampling_params(
self, tokenizer: AnyTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
@ -296,14 +291,19 @@ class ChatCompletionRequest(OpenAIBaseModel):
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs
# We now allow logprobs being true without top_logrobs.
logits_processors = get_logits_processors(
logit_bias=self.logit_bias,
allowed_token_ids=None,
tokenizer=tokenizer,
)
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)
guided_json_object = None
if (self.response_format is not None
and self.response_format.type == "json_object"):
guided_json_object = True
guided_decoding = GuidedDecodingParams.from_optional(
json=self._get_guided_json_from_tool() or self.guided_json,
regex=self.guided_regex,
choice=self.guided_choice,
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)
return SamplingParams.from_optional(
n=self.n,
@ -329,11 +329,29 @@ class ChatCompletionRequest(OpenAIBaseModel):
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
)
guided_decoding=guided_decoding,
logit_bias=self.logit_bias)
def _get_guided_json_from_tool(
self) -> Optional[Union[str, dict, BaseModel]]:
# user has chosen to not use any tool
if self.tool_choice == "none" or self.tools is None:
return None
# user has chosen to use a named tool
if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam:
tool_name = self.tool_choice.function.name
tools = {tool.function.name: tool.function for tool in self.tools}
if tool_name not in tools:
raise ValueError(
f"Tool '{tool_name}' has not been passed in `tools`.")
tool = tools[tool_name]
return tool.parameters
return None
@model_validator(mode="before")
@classmethod
@ -537,10 +555,7 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params
def to_sampling_params(
self, tokenizer: AnyTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
@ -551,13 +566,19 @@ class CompletionRequest(OpenAIBaseModel):
echo_without_generation = self.echo and self.max_tokens == 0
logits_processors = get_logits_processors(
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids,
tokenizer=tokenizer,
)
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)
guided_json_object = None
if (self.response_format is not None
and self.response_format.type == "json_object"):
guided_json_object = True
guided_decoding = GuidedDecodingParams.from_optional(
json=self.guided_json,
regex=self.guided_regex,
choice=self.guided_choice,
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)
return SamplingParams.from_optional(
n=self.n,
@ -583,11 +604,12 @@ class CompletionRequest(OpenAIBaseModel):
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
)
guided_decoding=guided_decoding,
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids)
@model_validator(mode="before")
@classmethod
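
Conceptually, the OpenAI-compatible frontend now translates request fields into a GuidedDecodingParams carried on SamplingParams, rather than pre-building a logits processor per request. A hedged field-by-field sketch; the values are hypothetical and each comment names the request field it stands in for:

from vllm.sampling_params import GuidedDecodingParams

guided_decoding = GuidedDecodingParams.from_optional(
    json=None,                # guided_json, or a named tool's parameters
    regex=r"[a-z]+",          # guided_regex
    choice=None,              # guided_choice
    grammar=None,             # guided_grammar
    json_object=None,         # True when response_format.type == "json_object"
    backend=None,             # guided_decoding_backend (engine default if unset)
    whitespace_pattern=None)  # guided_whitespace_pattern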

View File

@ -187,9 +187,6 @@ class OpenAIServingChat(OpenAIServing):
raw_request.state.request_metadata = request_metadata
try:
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
if isinstance(prompt, str):
prompt_inputs = self._tokenize_prompt_input(
request,
@ -208,8 +205,6 @@ class OpenAIServingChat(OpenAIServing):
assert prompt_inputs is not None
sampling_params = request.to_sampling_params(
tokenizer,
guided_decode_logits_processor,
default_max_tokens=self.max_model_len -
len(prompt_inputs["prompt_token_ids"]))

View File

@ -110,8 +110,6 @@ class OpenAIServingCompletion(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer(lora_request)
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
prompts = list(
self._tokenize_prompt_input_or_inputs(
request,
@ -123,8 +121,6 @@ class OpenAIServingCompletion(OpenAIServing):
for i, prompt_inputs in enumerate(prompts):
sampling_params = request.to_sampling_params(
tokenizer,
guided_decode_logits_processor,
default_max_tokens=self.max_model_len -
len(prompt_inputs["prompt_token_ids"]))

View File

@ -27,11 +27,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AtomicCounter
@ -168,15 +166,6 @@ class OpenAIServing:
})
return json_str
async def _guided_decode_logits_processor(
self, request: Union[ChatCompletionRequest, CompletionRequest],
tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
decoding_config = await self.engine_client.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
return await get_guided_decoding_logits_processor(
guided_decoding_backend, request, tokenizer)
async def _check_model(
self,
request: AnyRequest,

View File

@ -1,77 +1,45 @@
from typing import Optional, Union
from typing import Optional
from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.sampling_params import LogitsProcessor
from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
async def get_guided_decoding_logits_processor(
guided_decoding_backend: str, request: Union[CompletionRequest,
ChatCompletionRequest],
guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]:
request = _adapt_request_for_tool_use(request)
if guided_decoding_backend == 'outlines':
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines' or guided_params.grammar:
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
guided_params, tokenizer)
if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_lm_format_enforcer_guided_decoding_logits_processor)
return await get_lm_format_enforcer_guided_decoding_logits_processor(
request, tokenizer)
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
raise ValueError(
f"Unknown guided decoding backend '{guided_decoding_backend}'. "
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")
def get_local_guided_decoding_logits_processor(
guided_decoding_backend: str, guided_options: GuidedDecodingRequest,
guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]:
# request = _adapt_request_for_tool_use(request)
if guided_decoding_backend == 'outlines':
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines' or guided_params.grammar:
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
guided_params, tokenizer)
if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_options, tokenizer)
guided_params, tokenizer)
raise ValueError(
f"Unknown guided decoding backend '{guided_decoding_backend}'. "
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")
def _adapt_request_for_tool_use(request: Union[CompletionRequest,
ChatCompletionRequest]):
# the legacy completion API does not support tool use
if type(request) is CompletionRequest:
return request
# user has chosen to not use any tool,
# OR is allowing the model to choose a tool.
if request.tool_choice == "none" or request.tool_choice == "auto":
return request
# user has chosen to use a named tool
if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
tool_name = request.tool_choice.function.name
tools = {tool.function.name: tool.function for tool in request.tools}
if tool_name not in tools:
raise ValueError(
f"Tool '{tool_name}' has not been passed in `tools`.")
tool = tools[tool_name]
request.guided_json = tool.parameters
return request
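
The backend entry points now take GuidedDecodingParams directly instead of an OpenAI request object. A small standalone sketch of the local, blocking variant; the tokenizer and backend choice are assumptions:

from transformers import AutoTokenizer
from vllm.model_executor.guided_decoding import (
    get_local_guided_decoding_logits_processor)
from vllm.sampling_params import GuidedDecodingParams

# Tokenizer is illustrative; the engine normally supplies its own.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
guided_params = GuidedDecodingParams(regex=r"\d+", backend="outlines")

logits_processor = get_local_guided_decoding_logits_processor(
    guided_params=guided_params, tokenizer=tokenizer)
assert logits_processor is not None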

View File

@ -4,6 +4,7 @@ from typing import Dict, List, Optional, TypedDict, Union
from pydantic import BaseModel
# These classes are deprecated, see SamplingParams
class LLMGuidedOptions(TypedDict, total=False):
guided_json: Union[Dict, BaseModel, str]
guided_regex: str

View File

@ -7,66 +7,13 @@ from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
TokenEnforcerTokenizerData, UnionParser)
from lmformatenforcer.integrations.vllm import (
build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.sampling_params import LogitsProcessor
async def get_lm_format_enforcer_guided_decoding_logits_processor(
request: Union[CompletionRequest, ChatCompletionRequest],
tokenizer) -> Optional[LogitsProcessor]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer)
character_level_parser: CharacterLevelParser
if request.guided_json:
schema = _normalize_json_schema_object(request.guided_json)
character_level_parser = JsonSchemaParser(schema)
elif request.guided_choice:
character_level_parser = UnionParser(
[StringParser(choice) for choice in request.guided_choice])
elif request.guided_regex:
character_level_parser = RegexParser(request.guided_regex)
elif request.guided_grammar:
# CFG grammar not supported by LMFE, revert to outlines
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
elif (request.response_format is not None
and request.response_format.type == "json_object"):
character_level_parser = JsonSchemaParser(
None) # None means any json object
elif (request.response_format is not None
and request.response_format.type == "json_schema"
and request.response_format.json_schema is not None
and request.response_format.json_schema.json_schema is not None):
schema = _normalize_json_schema_object(
request.response_format.json_schema.json_schema)
character_level_parser = JsonSchemaParser(schema)
else:
return None
logits_processor = build_vllm_logits_processor(tokenizer_data,
character_level_parser)
return logits_processor
from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
def get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_options: GuidedDecodingRequest,
guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
@ -78,23 +25,20 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer)
character_level_parser: CharacterLevelParser
if guided_options.guided_json:
schema = _normalize_json_schema_object(guided_options.guided_json)
character_level_parser = JsonSchemaParser(schema)
elif guided_options.guided_choice:
if guided_params.json:
schema_dict = _normalize_json_schema_object(guided_params.json)
character_level_parser = JsonSchemaParser(schema_dict)
elif guided_params.choice:
character_level_parser = UnionParser(
[StringParser(choice) for choice in guided_options.guided_choice])
elif guided_options.guided_regex:
character_level_parser = RegexParser(guided_options.guided_regex)
elif guided_options.guided_grammar:
# CFG grammar not supported by LMFE, revert to outlines
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer)
elif guided_options.guided_json_object:
[StringParser(choice) for choice in guided_params.choice])
elif guided_params.regex:
character_level_parser = RegexParser(guided_params.regex)
elif guided_params.grammar:
# CFG grammar not supported by LMFE
raise ValueError("Cannot construct a guided decoding logits processor"
" using the grammar option with the"
" lm_format_enforcer backend.")
elif guided_params.json_object:
# None means any json object
character_level_parser = JsonSchemaParser(None)
else:
@ -105,13 +49,11 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
return logits_processor
def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
def _normalize_json_schema_object(schema: Union[str, dict]) -> dict:
if isinstance(schema, str):
return json_loads(schema)
if isinstance(schema, dict):
return schema
if isinstance(schema, BaseModel):
return schema.model_json_schema()
raise AssertionError(f"Unsupported schema type {schema}")

View File

@ -5,16 +5,11 @@ from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Tuple, Union
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.sampling_params import GuidedDecodingParams
class GuidedDecodingMode(Enum):
@ -55,8 +50,7 @@ global_thread_pool = None # used for generating logits processor fsm
async def get_outlines_guided_decoding_logits_processor(
request: Union[CompletionRequest,
ChatCompletionRequest], tokenizer: PreTrainedTokenizerBase
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
@ -66,7 +60,7 @@ async def get_outlines_guided_decoding_logits_processor(
we make a shallow copy to reuse the same underlying FSM.
"""
global global_thread_pool
guide, mode = _get_guide_and_mode(request)
guide, mode = _get_guide_and_mode(guided_params)
if not guide or not mode:
return None
@ -77,11 +71,11 @@ async def get_outlines_guided_decoding_logits_processor(
return await loop.run_in_executor(global_thread_pool,
_get_logits_processor, guide, tokenizer,
mode, request.guided_whitespace_pattern)
mode, guided_params.whitespace_pattern)
def get_local_outlines_guided_decoding_logits_processor(
guided_options: GuidedDecodingRequest, tokenizer: PreTrainedTokenizerBase
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
@ -90,65 +84,37 @@ def get_local_outlines_guided_decoding_logits_processor(
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
guide, mode = _get_guide_and_mode(guided_options)
guide, mode = _get_guide_and_mode(guided_params)
if not guide or not mode:
return None
return _get_logits_processor(guide, tokenizer, mode,
guided_options.guided_whitespace_pattern)
guided_params.whitespace_pattern)
def _get_guide_and_mode(
request: Union[CompletionRequest, ChatCompletionRequest,
GuidedDecodingRequest]
guided_params: GuidedDecodingParams
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
# if the request is a chat completion request, AND the tool choice is a
# named tool choice, do guided decoding
# using that tool as the JSON schema
if isinstance(request, ChatCompletionRequest) and isinstance(
request.tool_choice, ChatCompletionNamedToolChoiceParam):
# Guided generation for tools/functions parameters
if request.tool_choice.type == "function":
for tool in request.tools:
if (tool.type == "function" and tool.function.name
== request.tool_choice.function.name):
json = json_dumps(tool.function.parameters, sort_keys=True)
return json, GuidedDecodingMode.JSON
return None, None
elif request.guided_json:
if isinstance(request.guided_json, dict):
if guided_params.json:
if isinstance(guided_params.json, dict):
# turn dict into hashable string
json = json_dumps(request.guided_json)
elif isinstance(request.guided_json, BaseModel):
# use pydantic signature so that different model classes
# with the same fields will get hashed the same
json = str(request.guided_json.__signature__)
json = json_dumps(guided_params.json)
else:
json = request.guided_json
json = guided_params.json
return json, GuidedDecodingMode.JSON
elif request.guided_regex:
return request.guided_regex, GuidedDecodingMode.REGEX
elif request.guided_choice:
elif guided_params.regex:
return guided_params.regex, GuidedDecodingMode.REGEX
elif guided_params.choice:
# choice just uses regex
choices = [
regex_escape(str(choice)) for choice in request.guided_choice
regex_escape(str(choice)) for choice in guided_params.choice
]
choices_regex = "(" + "|".join(choices) + ")"
return choices_regex, GuidedDecodingMode.CHOICE
elif request.guided_grammar:
return request.guided_grammar, GuidedDecodingMode.GRAMMAR
elif (not isinstance(request, GuidedDecodingRequest)
and request.response_format is not None
and request.response_format.type == "json_object"):
elif guided_params.grammar:
return guided_params.grammar, GuidedDecodingMode.GRAMMAR
elif guided_params.json_object:
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
elif (not isinstance(request, GuidedDecodingRequest)
and request.response_format is not None
and request.response_format.type == "json_schema"
and request.response_format.json_schema is not None
and request.response_format.json_schema.json_schema is not None):
json = json_dumps(request.response_format.json_schema.json_schema)
return json, GuidedDecodingMode.JSON
else:
return None, None

View File

@ -1,11 +1,13 @@
"""Sampling parameters for text generation."""
import copy
from dataclasses import dataclass
from enum import Enum, IntEnum
from functools import cached_property
from typing import Any, Callable, Dict, List, Optional, Set, Union
import msgspec
import torch
from pydantic import BaseModel
from typing_extensions import Annotated
import vllm.envs as envs
@ -34,6 +36,54 @@ first argument, and returns a modified tensor of logits
to sample from."""
# maybe make msgspec?
@dataclass
class GuidedDecodingParams:
"""One of these fields will be used to build a logit processor."""
json: Optional[Union[str, Dict]] = None
regex: Optional[str] = None
choice: Optional[List[str]] = None
grammar: Optional[str] = None
json_object: Optional[bool] = None
"""These are other options that can be set"""
backend: Optional[str] = None
whitespace_pattern: Optional[str] = None
@staticmethod
def from_optional(
json: Optional[Union[Dict, BaseModel, str]],
regex: Optional[str] = None,
choice: Optional[List[str]] = None,
grammar: Optional[str] = None,
json_object: Optional[bool] = None,
backend: Optional[str] = None,
whitespace_pattern: Optional[str] = None,
) -> "GuidedDecodingParams":
# Extract json schemas from pydantic models
if isinstance(json, (BaseModel, type(BaseModel))):
json = json.model_json_schema()
return GuidedDecodingParams(
json=json,
regex=regex,
choice=choice,
grammar=grammar,
json_object=json_object,
backend=backend,
whitespace_pattern=whitespace_pattern,
)
def __post_init__(self):
"""Validate that some fields are mutually exclusive."""
guide_count = sum([
self.json is not None, self.regex is not None, self.choice
is not None, self.grammar is not None, self.json_object is not None
])
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding but multiple are "
f"specified: {self.__dict__}")
class RequestOutputKind(Enum):
# Return entire output so far in every RequestOutput
CUMULATIVE = 0
@ -124,6 +174,13 @@ class SamplingParams(
truncate_prompt_tokens: If set to an integer k, will use only the last k
tokens from the prompt (i.e., left truncation). Defaults to None
(i.e., no truncation).
guided_decoding: If provided, the engine will construct a guided
decoding logits processor from these parameters. Defaults to None.
logit_bias: If provided, the engine will construct a logits processor
that applies these logit biases. Defaults to None.
allowed_token_ids: If provided, the engine will construct a logits
processor which only retains scores for the given token ids.
Defaults to None.
"""
n: int = 1
@ -164,6 +221,11 @@ class SamplingParams(
output_text_buffer_length: int = 0
_all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)
# Fields used to construct logits processors
guided_decoding: Optional[GuidedDecodingParams] = None
logit_bias: Optional[Dict[int, float]] = None
allowed_token_ids: Optional[List[int]] = None
@staticmethod
def from_optional(
n: Optional[int] = 1,
@ -194,7 +256,16 @@ class SamplingParams(
truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=1)]] = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
guided_decoding: Optional[GuidedDecodingParams] = None,
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = None,
allowed_token_ids: Optional[List[int]] = None,
) -> "SamplingParams":
if logit_bias is not None:
logit_bias = {
int(token): bias
for token, bias in logit_bias.items()
}
return SamplingParams(
n=1 if n is None else n,
best_of=best_of,
@ -226,6 +297,9 @@ class SamplingParams(
logits_processors=logits_processors,
truncate_prompt_tokens=truncate_prompt_tokens,
output_kind=output_kind,
guided_decoding=guided_decoding,
logit_bias=logit_bias,
allowed_token_ids=allowed_token_ids,
)
def __post_init__(self) -> None:
@ -454,4 +528,5 @@ class SamplingParams(
f"skip_special_tokens={self.skip_special_tokens}, "
"spaces_between_special_tokens="
f"{self.spaces_between_special_tokens}, "
f"truncate_prompt_tokens={self.truncate_prompt_tokens})")
f"truncate_prompt_tokens={self.truncate_prompt_tokens}), "
f"guided_decoding={self.guided_decoding}")