[Frontend][Core] Move guided decoding params into sampling params (#8252)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
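The diff below replaces the `guided_options_request` argument of `LLM.generate` with a `guided_decoding` field on `SamplingParams`. A minimal before/after sketch of the API change (the model name and regex are placeholders for illustration):

```python
from vllm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

llm = LLM(model="HuggingFaceH4/zephyr-7b-beta")  # placeholder model
sample_regex = r"\d+\.\d+\.\d+\.\d+"             # placeholder guide

# Old style (now deprecated): guided decoding passed alongside generate().
# outputs = llm.generate(prompts,
#                        sampling_params=SamplingParams(temperature=0.8),
#                        guided_options_request=dict(guided_regex=sample_regex))

# New style: guided decoding lives inside SamplingParams.
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    guided_decoding=GuidedDecodingParams(regex=sample_regex))
outputs = llm.generate(
    prompts=[f"Give an example IPv4 address with this regex: {sample_regex}"],
    sampling_params=sampling_params)
```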
@@ -7,7 +7,7 @@ import pytest
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from ...conftest import cleanup

@@ -31,14 +31,12 @@ def test_guided_regex(sample_regex, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
)
outputs = llm.generate(
prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))
guided_decoding=GuidedDecodingParams(regex=sample_regex))
outputs = llm.generate(prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)

assert outputs is not None
for output in outputs:

@@ -57,15 +55,13 @@ def test_guided_json_completion(sample_json_schema, llm):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
)
outputs = llm.generate(
prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_json=sample_json_schema))
guided_decoding=GuidedDecodingParams(json=sample_json_schema))
outputs = llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)

assert outputs is not None

@@ -86,12 +82,11 @@ def test_guided_choice_completion(sample_guided_choice, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
)
guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
outputs = llm.generate(
prompts="The best language for type-safe systems programming is ",
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_choice=sample_guided_choice))
use_tqdm=True)

assert outputs is not None
for output in outputs:

@@ -112,13 +107,13 @@ def test_guided_grammar(sample_sql_statements, llm):
temperature=0.8,
top_p=0.95,
max_tokens=1000,
)
guided_decoding=GuidedDecodingParams(grammar=sample_sql_statements))
outputs = llm.generate(
prompts=("Generate a sql state that select col_1 from "
"table_1 where it is equals to 1"),
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_grammar=sample_sql_statements))
)

assert outputs is not None
for output in outputs:

@@ -140,3 +135,28 @@ def test_guided_grammar(sample_sql_statements, llm):
assert generated_text.strip() == ground_truth

print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_options_request_deprecation_warning(sample_regex, llm):
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

with pytest.warns(DeprecationWarning, match="guided_options_request"):
llm.generate(prompts="This should fail",
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))


@pytest.mark.skip_global_cleanup
def test_validation_against_both_guided_decoding_options(sample_regex, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(regex=sample_regex))

with pytest.raises(ValueError, match="Cannot set both"):
llm.generate(prompts="This should fail",
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))
tests/model_executor/conftest.py (new file, 49 lines added)
@@ -0,0 +1,49 @@
import pytest


@pytest.fixture
def sample_regex():
return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")


@pytest.fixture
def sample_json_schema():
return {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"skills": {
"type": "array",
"items": {
"type": "string",
"maxLength": 10
},
"minItems": 3
},
"work_history": {
"type": "array",
"items": {
"type": "object",
"properties": {
"company": {
"type": "string"
},
"duration": {
"type": "number"
},
"position": {
"type": "string"
}
},
"required": ["company", "position"]
}
}
},
"required": ["name", "age", "skills", "work_history"]
}
@@ -1,14 +1,12 @@
# This unit test should be moved to a new
# tests/test_guided_decoding directory.
import pytest
import torch
from transformers import AutoTokenizer

from vllm.entrypoints.openai.protocol import CompletionRequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.sampling_params import GuidedDecodingParams


def test_guided_logits_processors(sample_regex, sample_json_schema):

@@ -44,11 +42,9 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
token_ids = tokenizer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
regex_request = CompletionRequest(model='test',
prompt=token_ids,
guided_regex=sample_regex)
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
regex_lp = await get_guided_decoding_logits_processor(
backend, regex_request, tokenizer)
regex_request, tokenizer)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)

@@ -59,14 +55,31 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
token_ids = tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
json_request = CompletionRequest(model='test',
prompt=token_ids,
guided_json=sample_json_schema)
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = await get_guided_decoding_logits_processor(
backend, json_request, tokenizer)
json_request, tokenizer)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)


def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, regex=sample_regex)

with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, json_object=True)

with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, choice=["a", "b"])

with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")
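The black-box test above relies on the logits-processor call contract: a processor is called with the generated token ids and a logits tensor and returns a tensor of the same shape. A toy stand-in with that signature (not vLLM code, purely illustrative):

```python
import torch

BANNED_TOKEN_ID = 42  # illustrative token id


def ban_token_processor(token_ids: list, logits: torch.Tensor) -> torch.Tensor:
    # Mask out one token id, mimicking how guided processors constrain sampling.
    logits[BANNED_TOKEN_ID] = -float("inf")
    return logits


scores = torch.rand(32000)
scores = ban_token_processor([1, 2, 3], scores)
assert scores[BANNED_TOKEN_ID] == -float("inf")
```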
@@ -20,6 +20,8 @@ from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams

@@ -477,6 +479,18 @@ class _AsyncLLMEngine(LLMEngine):
)
processed_inputs = self.input_processor(preprocessed_inputs)

if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
# Guided decoding has an async implementation for building logits
# processors in a separate threadpool.
# We want to invoke that here instead of using the blocking
# implementation in the LLMEngine
params = await build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=self.get_tokenizer(lora_request),
default_guided_backend=self.decoding_config.
guided_decoding_backend)

self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,

@@ -494,6 +508,36 @@ class _AsyncLLMEngine(LLMEngine):
self.model_executor.check_health()


async def build_guided_decoding_logits_processor_async(
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
default_guided_backend: str) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. Modifies sampling params in-place and returns
the modified sampling params."""
if (guided_decoding := sampling_params.guided_decoding) is None:
return sampling_params

logger.debug("Building guided decoding logits processor. "
"Params: %s", guided_decoding)

guided_decoding.backend = guided_decoding.backend or default_guided_backend

processor = await get_guided_decoding_logits_processor(
guided_params=guided_decoding, tokenizer=tokenizer)

if processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(processor)

# Unset guided decoding params after constructing the lp from them
sampling_params.guided_decoding = None

return sampling_params


class AsyncLLMEngine:
"""An asynchronous wrapper for :class:`LLMEngine`.
@@ -25,6 +25,7 @@ from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster

@@ -33,6 +34,8 @@ from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)

@@ -843,6 +846,9 @@ class LLMEngine:
raise ValueError(f"Cannot request more than "
f"{max_logprobs} logprobs.")

sampling_params = self._build_logits_processors(
sampling_params, lora_request)

# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()

@@ -1895,3 +1901,51 @@ class LLMEngine:
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens

def _build_logits_processors(
self, sampling_params: SamplingParams,
lora_request: Optional[LoRARequest]) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. Returns the modified sampling params."""

logits_processors = []
if (guided_decoding := sampling_params.guided_decoding) is not None:

logger.debug(
"Building guided decoding logits processor in "
"LLMEngine. Params: %s", guided_decoding)

tokenizer = self.get_tokenizer(lora_request=lora_request)
guided_decoding.backend = guided_decoding.backend or \
self.decoding_config.guided_decoding_backend

processor = get_local_guided_decoding_logits_processor(
guided_params=guided_decoding, tokenizer=tokenizer)
if processor:
logits_processors.append(processor)

# Unset so this doesn't get passed down to the model
sampling_params.guided_decoding = None

if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
tokenizer = self.get_tokenizer(lora_request=lora_request)

processors = get_logits_processors(
logit_bias=sampling_params.logit_bias,
allowed_token_ids=sampling_params.allowed_token_ids,
tokenizer=tokenizer)
logits_processors.extend(processors)

# Unset so these don't get passed down to the model
sampling_params.logit_bias = None
sampling_params.allowed_token_ids = None

if logits_processors:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = logits_processors
else:
sampling_params.logits_processors.extend(logits_processors)

return sampling_params
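With `_build_logits_processors` in place, `logit_bias` and `allowed_token_ids` become plain `SamplingParams` fields that the engine turns into logits processors itself; a hedged sketch (the token ids are illustrative):

```python
from vllm.sampling_params import SamplingParams

params = SamplingParams(
    temperature=0.0,
    max_tokens=16,
    logit_bias={128: -100.0},           # discourage an arbitrary token id
    allowed_token_ids=[128, 256, 512],  # or restrict sampling to a whitelist
)
```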
@@ -16,6 +16,8 @@ from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.async_llm_engine import (
build_guided_decoding_logits_processor_async)
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,

@@ -512,6 +514,18 @@ class MQLLMEngineClient:
if self._errored_with is not None:
raise ENGINE_DEAD_ERROR(self._errored_with)

# Constructing guided decoding logits processors is expensive, so we do
# it here to avoid contending with cpu resources and the GIL on the
# backend process.
if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
params = await \
build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=await self.get_tokenizer(lora_request),
default_guided_backend=self.decoding_config.guided_decoding_backend
)

# 1) Create output queue for this requests.
queue: asyncio.Queue[Union[RequestOutput,
BaseException]] = asyncio.Queue()
@@ -1,4 +1,5 @@
import itertools
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,

@@ -16,13 +17,13 @@ from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest, LLMGuidedOptions)
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
SamplingParams)
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup

@@ -798,6 +799,14 @@ class LLM:
guided_options: Optional[GuidedDecodingRequest] = None,
priority: Optional[List[int]] = None,
) -> None:
if guided_options is not None:
warnings.warn(
"guided_options_request is deprecated, use "
"SamplingParams.guided_decoding instead",
DeprecationWarning,
stacklevel=2,
)

if isinstance(prompts, (str, dict)):
# Convert a single prompt to a list.
prompts = [prompts]

@@ -813,7 +822,7 @@ class LLM:

for sp in params if isinstance(params, list) else (params, ):
if isinstance(sp, SamplingParams):
self._add_guided_processor(sp, guided_options)
self._add_guided_params(sp, guided_options)

# We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY

@@ -847,22 +856,25 @@ class LLM:
priority=priority,
)

def _add_guided_processor(
def _add_guided_params(
self,
params: SamplingParams,
guided_options: Optional[GuidedDecodingRequest] = None):
if guided_options:
if guided_options.guided_decoding_backend is None:
decoding_config = self.llm_engine.get_decoding_config()
guided_options.guided_decoding_backend = (
decoding_config.guided_decoding_backend)
guided_logits_processor = get_local_guided_decoding_logits_processor( #noqa
guided_options.guided_decoding_backend, guided_options,
self.get_tokenizer())
if guided_logits_processor:
if params.logits_processors is None:
params.logits_processors = []
params.logits_processors.append(guided_logits_processor)
if guided_options is None:
return params

if params.guided_decoding is not None:
raise ValueError("Cannot set both guided_options_request and"
"params.guided_decoding.")

params.guided_decoding = GuidedDecodingParams(
json=guided_options.guided_json,
regex=guided_options.guided_regex,
choice=guided_options.guided_choice,
grammar=guided_options.guided_grammar,
json_object=guided_options.guided_json_object,
backend=guided_options.guided_decoding_backend,
whitespace_pattern=guided_options.guided_whitespace_pattern)
return params

def _run_engine(
@@ -10,12 +10,10 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated, Required, TypedDict

from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (LogitsProcessor, RequestOutputKind,
from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
SamplingParams)
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid

# torch is mocked during docs generation,

@@ -284,10 +282,7 @@ class ChatCompletionRequest(OpenAIBaseModel):

# doc: end-chat-completion-extra-params

def to_sampling_params(
self, tokenizer: AnyTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens

@@ -296,14 +291,19 @@ class ChatCompletionRequest(OpenAIBaseModel):
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs

# We now allow logprobs being true without top_logrobs.
logits_processors = get_logits_processors(
logit_bias=self.logit_bias,
allowed_token_ids=None,
tokenizer=tokenizer,
)
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)
guided_json_object = None
if (self.response_format is not None
and self.response_format.type == "json_object"):
guided_json_object = True

guided_decoding = GuidedDecodingParams.from_optional(
json=self._get_guided_json_from_tool() or self.guided_json,
regex=self.guided_regex,
choice=self.guided_choice,
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)

return SamplingParams.from_optional(
n=self.n,

@@ -329,11 +329,29 @@ class ChatCompletionRequest(OpenAIBaseModel):
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
)
guided_decoding=guided_decoding,
logit_bias=self.logit_bias)

def _get_guided_json_from_tool(
self) -> Optional[Union[str, dict, BaseModel]]:
# user has chosen to not use any tool
if self.tool_choice == "none" or self.tools is None:
return None

# user has chosen to use a named tool
if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam:
tool_name = self.tool_choice.function.name
tools = {tool.function.name: tool.function for tool in self.tools}
if tool_name not in tools:
raise ValueError(
f"Tool '{tool_name}' has not been passed in `tools`.")
tool = tools[tool_name]
return tool.parameters

return None

@model_validator(mode="before")
@classmethod

@@ -537,10 +555,7 @@ class CompletionRequest(OpenAIBaseModel):

# doc: end-completion-extra-params

def to_sampling_params(
self, tokenizer: AnyTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens

@@ -551,13 +566,19 @@ class CompletionRequest(OpenAIBaseModel):

echo_without_generation = self.echo and self.max_tokens == 0

logits_processors = get_logits_processors(
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids,
tokenizer=tokenizer,
)
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)
guided_json_object = None
if (self.response_format is not None
and self.response_format.type == "json_object"):
guided_json_object = True

guided_decoding = GuidedDecodingParams.from_optional(
json=self.guided_json,
regex=self.guided_regex,
choice=self.guided_choice,
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)

return SamplingParams.from_optional(
n=self.n,

@@ -583,11 +604,12 @@ class CompletionRequest(OpenAIBaseModel):
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
)
guided_decoding=guided_decoding,
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids)

@model_validator(mode="before")
@classmethod
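On the OpenAI-compatible server the `guided_*` fields shown above are vLLM extensions to the request schema, so with the `openai` v1 client they travel through `extra_body`; a hedged sketch (the server URL and model name are assumptions for illustration):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.chat.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    messages=[{
        "role": "user",
        "content": "Is this review positive or negative: 'great movie!'"
    }],
    extra_body={"guided_choice": ["positive", "negative"]},
)
print(completion.choices[0].message.content)
```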
@@ -187,9 +187,6 @@ class OpenAIServingChat(OpenAIServing):
raw_request.state.request_metadata = request_metadata

try:
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))

if isinstance(prompt, str):
prompt_inputs = self._tokenize_prompt_input(
request,

@@ -208,8 +205,6 @@ class OpenAIServingChat(OpenAIServing):
assert prompt_inputs is not None

sampling_params = request.to_sampling_params(
tokenizer,
guided_decode_logits_processor,
default_max_tokens=self.max_model_len -
len(prompt_inputs["prompt_token_ids"]))
@@ -110,8 +110,6 @@ class OpenAIServingCompletion(OpenAIServing):

tokenizer = await self.engine_client.get_tokenizer(lora_request)

guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
prompts = list(
self._tokenize_prompt_input_or_inputs(
request,

@@ -123,8 +121,6 @@ class OpenAIServingCompletion(OpenAIServing):

for i, prompt_inputs in enumerate(prompts):
sampling_params = request.to_sampling_params(
tokenizer,
guided_decode_logits_processor,
default_max_tokens=self.max_model_len -
len(prompt_inputs["prompt_token_ids"]))
@@ -27,11 +27,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AtomicCounter

@@ -168,15 +166,6 @@ class OpenAIServing:
})
return json_str

async def _guided_decode_logits_processor(
self, request: Union[ChatCompletionRequest, CompletionRequest],
tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
decoding_config = await self.engine_client.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
return await get_guided_decoding_logits_processor(
guided_decoding_backend, request, tokenizer)

async def _check_model(
self,
request: AnyRequest,
@@ -1,77 +1,45 @@
from typing import Optional, Union
from typing import Optional

from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.sampling_params import LogitsProcessor
from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor


async def get_guided_decoding_logits_processor(
guided_decoding_backend: str, request: Union[CompletionRequest,
ChatCompletionRequest],
guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]:
request = _adapt_request_for_tool_use(request)

if guided_decoding_backend == 'outlines':
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines' or guided_params.grammar:
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
guided_params, tokenizer)
if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_lm_format_enforcer_guided_decoding_logits_processor)
return await get_lm_format_enforcer_guided_decoding_logits_processor(
request, tokenizer)
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)

raise ValueError(
f"Unknown guided decoding backend '{guided_decoding_backend}'. "
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")


def get_local_guided_decoding_logits_processor(
guided_decoding_backend: str, guided_options: GuidedDecodingRequest,
guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]:
# request = _adapt_request_for_tool_use(request)

if guided_decoding_backend == 'outlines':
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines' or guided_params.grammar:
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
guided_params, tokenizer)
if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_options, tokenizer)
guided_params, tokenizer)

raise ValueError(
f"Unknown guided decoding backend '{guided_decoding_backend}'. "
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")


def _adapt_request_for_tool_use(request: Union[CompletionRequest,
ChatCompletionRequest]):
# the legacy completion API does not support tool use
if type(request) is CompletionRequest:
return request

# user has chosen to not use any tool,
# OR is allowing the model to choose a tool.
if request.tool_choice == "none" or request.tool_choice == "auto":
return request

# user has chosen to use a named tool
if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
tool_name = request.tool_choice.function.name
tools = {tool.function.name: tool.function for tool in request.tools}
if tool_name not in tools:
raise ValueError(
f"Tool '{tool_name}' has not been passed in `tools`.")
tool = tools[tool_name]
request.guided_json = tool.parameters

return request
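The dispatch helpers above now take a `GuidedDecodingParams` plus a tokenizer. A hedged sketch of calling the synchronous variant directly, mirroring the unit tests (the tokenizer checkpoint is the one the tests use; the regex is illustrative):

```python
from transformers import AutoTokenizer

from vllm.model_executor.guided_decoding import (
    get_local_guided_decoding_logits_processor)
from vllm.sampling_params import GuidedDecodingParams

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
guided_params = GuidedDecodingParams(regex=r"\d+", backend="outlines")
logits_processor = get_local_guided_decoding_logits_processor(
    guided_params=guided_params, tokenizer=tokenizer)
assert logits_processor is not None
```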
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, TypedDict, Union
from pydantic import BaseModel


# These classes are deprecated, see SamplingParams
class LLMGuidedOptions(TypedDict, total=False):
guided_json: Union[Dict, BaseModel, str]
guided_regex: str
@@ -7,66 +7,13 @@ from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
TokenEnforcerTokenizerData, UnionParser)
from lmformatenforcer.integrations.vllm import (
build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.sampling_params import LogitsProcessor


async def get_lm_format_enforcer_guided_decoding_logits_processor(
request: Union[CompletionRequest, ChatCompletionRequest],
tokenizer) -> Optional[LogitsProcessor]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""

tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer)
character_level_parser: CharacterLevelParser
if request.guided_json:
schema = _normalize_json_schema_object(request.guided_json)
character_level_parser = JsonSchemaParser(schema)
elif request.guided_choice:
character_level_parser = UnionParser(
[StringParser(choice) for choice in request.guided_choice])
elif request.guided_regex:
character_level_parser = RegexParser(request.guided_regex)
elif request.guided_grammar:
# CFG grammar not supported by LMFE, revert to outlines

# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
elif (request.response_format is not None
and request.response_format.type == "json_object"):
character_level_parser = JsonSchemaParser(
None) # None means any json object
elif (request.response_format is not None
and request.response_format.type == "json_schema"
and request.response_format.json_schema is not None
and request.response_format.json_schema.json_schema is not None):
schema = _normalize_json_schema_object(
request.response_format.json_schema.json_schema)
character_level_parser = JsonSchemaParser(schema)
else:
return None

logits_processor = build_vllm_logits_processor(tokenizer_data,
character_level_parser)
return logits_processor
from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor


def get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_options: GuidedDecodingRequest,
guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters

@@ -78,23 +25,20 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer)
character_level_parser: CharacterLevelParser
if guided_options.guided_json:
schema = _normalize_json_schema_object(guided_options.guided_json)
character_level_parser = JsonSchemaParser(schema)
elif guided_options.guided_choice:
if guided_params.json:
schema_dict = _normalize_json_schema_object(guided_params.json)
character_level_parser = JsonSchemaParser(schema_dict)
elif guided_params.choice:
character_level_parser = UnionParser(
[StringParser(choice) for choice in guided_options.guided_choice])
elif guided_options.guided_regex:
character_level_parser = RegexParser(guided_options.guided_regex)
elif guided_options.guided_grammar:
# CFG grammar not supported by LMFE, revert to outlines

# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer)
elif guided_options.guided_json_object:
[StringParser(choice) for choice in guided_params.choice])
elif guided_params.regex:
character_level_parser = RegexParser(guided_params.regex)
elif guided_params.grammar:
# CFG grammar not supported by LMFE
raise ValueError("Cannot construct a guided decoding logits processor"
" using the grammar option with the"
" lm_format_enforcer backend.")
elif guided_params.json_object:
# None means any json object
character_level_parser = JsonSchemaParser(None)
else:

@@ -105,13 +49,11 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
return logits_processor


def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
def _normalize_json_schema_object(schema: Union[str, dict]) -> dict:
if isinstance(schema, str):
return json_loads(schema)
if isinstance(schema, dict):
return schema
if isinstance(schema, BaseModel):
return schema.model_json_schema()
raise AssertionError(f"Unsupported schema type {schema}")
@@ -5,16 +5,11 @@ from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Tuple, Union

from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.sampling_params import GuidedDecodingParams


class GuidedDecodingMode(Enum):

@@ -55,8 +50,7 @@ global_thread_pool = None # used for generating logits processor fsm


async def get_outlines_guided_decoding_logits_processor(
request: Union[CompletionRequest,
ChatCompletionRequest], tokenizer: PreTrainedTokenizerBase
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""

@@ -66,7 +60,7 @@ async def get_outlines_guided_decoding_logits_processor(
we make a shallow copy to reuse the same underlying FSM.
"""
global global_thread_pool
guide, mode = _get_guide_and_mode(request)
guide, mode = _get_guide_and_mode(guided_params)
if not guide or not mode:
return None

@@ -77,11 +71,11 @@ async def get_outlines_guided_decoding_logits_processor(

return await loop.run_in_executor(global_thread_pool,
_get_logits_processor, guide, tokenizer,
mode, request.guided_whitespace_pattern)
mode, guided_params.whitespace_pattern)


def get_local_outlines_guided_decoding_logits_processor(
guided_options: GuidedDecodingRequest, tokenizer: PreTrainedTokenizerBase
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""

@@ -90,65 +84,37 @@ def get_local_outlines_guided_decoding_logits_processor(
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
guide, mode = _get_guide_and_mode(guided_options)
guide, mode = _get_guide_and_mode(guided_params)
if not guide or not mode:
return None

return _get_logits_processor(guide, tokenizer, mode,
guided_options.guided_whitespace_pattern)
guided_params.whitespace_pattern)


def _get_guide_and_mode(
request: Union[CompletionRequest, ChatCompletionRequest,
GuidedDecodingRequest]
guided_params: GuidedDecodingParams
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
# if the request is a chat completion request, AND the tool choice is a
# named tool choice, do guided decoding
# using that tool as the JSON schema
if isinstance(request, ChatCompletionRequest) and isinstance(
request.tool_choice, ChatCompletionNamedToolChoiceParam):
# Guided generation for tools/functions parameters
if request.tool_choice.type == "function":
for tool in request.tools:
if (tool.type == "function" and tool.function.name
== request.tool_choice.function.name):
json = json_dumps(tool.function.parameters, sort_keys=True)
return json, GuidedDecodingMode.JSON
return None, None

elif request.guided_json:
if isinstance(request.guided_json, dict):
if guided_params.json:
if isinstance(guided_params.json, dict):
# turn dict into hashable string
json = json_dumps(request.guided_json)
elif isinstance(request.guided_json, BaseModel):
# use pydantic signature so that different model classes
# with the same fields will get hashed the same
json = str(request.guided_json.__signature__)
json = json_dumps(guided_params.json)
else:
json = request.guided_json
json = guided_params.json
return json, GuidedDecodingMode.JSON
elif request.guided_regex:
return request.guided_regex, GuidedDecodingMode.REGEX
elif request.guided_choice:
elif guided_params.regex:
return guided_params.regex, GuidedDecodingMode.REGEX
elif guided_params.choice:
# choice just uses regex
choices = [
regex_escape(str(choice)) for choice in request.guided_choice
regex_escape(str(choice)) for choice in guided_params.choice
]
choices_regex = "(" + "|".join(choices) + ")"
return choices_regex, GuidedDecodingMode.CHOICE
elif request.guided_grammar:
return request.guided_grammar, GuidedDecodingMode.GRAMMAR
elif (not isinstance(request, GuidedDecodingRequest)
and request.response_format is not None
and request.response_format.type == "json_object"):
elif guided_params.grammar:
return guided_params.grammar, GuidedDecodingMode.GRAMMAR
elif guided_params.json_object:
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
elif (not isinstance(request, GuidedDecodingRequest)
and request.response_format is not None
and request.response_format.type == "json_schema"
and request.response_format.json_schema is not None
and request.response_format.json_schema.json_schema is not None):
json = json_dumps(request.response_format.json_schema.json_schema)
return json, GuidedDecodingMode.JSON
else:
return None, None
@@ -1,11 +1,13 @@
"""Sampling parameters for text generation."""
import copy
from dataclasses import dataclass
from enum import Enum, IntEnum
from functools import cached_property
from typing import Any, Callable, Dict, List, Optional, Set, Union

import msgspec
import torch
from pydantic import BaseModel
from typing_extensions import Annotated

import vllm.envs as envs

@@ -34,6 +36,54 @@ first argument, and returns a modified tensor of logits
to sample from."""


# maybe make msgspec?
@dataclass
class GuidedDecodingParams:
"""One of these fields will be used to build a logit processor."""
json: Optional[Union[str, Dict]] = None
regex: Optional[str] = None
choice: Optional[List[str]] = None
grammar: Optional[str] = None
json_object: Optional[bool] = None
"""These are other options that can be set"""
backend: Optional[str] = None
whitespace_pattern: Optional[str] = None

@staticmethod
def from_optional(
json: Optional[Union[Dict, BaseModel, str]],
regex: Optional[str] = None,
choice: Optional[List[str]] = None,
grammar: Optional[str] = None,
json_object: Optional[bool] = None,
backend: Optional[str] = None,
whitespace_pattern: Optional[str] = None,
) -> "GuidedDecodingParams":
# Extract json schemas from pydantic models
if isinstance(json, (BaseModel, type(BaseModel))):
json = json.model_json_schema()
return GuidedDecodingParams(
json=json,
regex=regex,
choice=choice,
grammar=grammar,
json_object=json_object,
backend=backend,
whitespace_pattern=whitespace_pattern,
)

def __post_init__(self):
"""Validate that some fields are mutually exclusive."""
guide_count = sum([
self.json is not None, self.regex is not None, self.choice
is not None, self.grammar is not None, self.json_object is not None
])
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding but multiple are "
f"specified: {self.__dict__}")


class RequestOutputKind(Enum):
# Return entire output so far in every RequestOutput
CUMULATIVE = 0

@@ -124,6 +174,13 @@ class SamplingParams(
truncate_prompt_tokens: If set to an integer k, will use only the last k
tokens from the prompt (i.e., left truncation). Defaults to None
(i.e., no truncation).
guided_decoding: If provided, the engine will construct a guided
decoding logits processor from these parameters. Defaults to None.
logit_bias: If provided, the engine will construct a logits processor
that applies these logit biases. Defaults to None.
allowed_token_ids: If provided, the engine will construct a logits
processor which only retains scores for the given token ids.
Defaults to None.
"""

n: int = 1

@@ -164,6 +221,11 @@ class SamplingParams(
output_text_buffer_length: int = 0
_all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)

# Fields used to construct logits processors
guided_decoding: Optional[GuidedDecodingParams] = None
logit_bias: Optional[Dict[int, float]] = None
allowed_token_ids: Optional[List[int]] = None

@staticmethod
def from_optional(
n: Optional[int] = 1,

@@ -194,7 +256,16 @@ class SamplingParams(
truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=1)]] = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
guided_decoding: Optional[GuidedDecodingParams] = None,
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = None,
allowed_token_ids: Optional[List[int]] = None,
) -> "SamplingParams":
if logit_bias is not None:
logit_bias = {
int(token): bias
for token, bias in logit_bias.items()
}

return SamplingParams(
n=1 if n is None else n,
best_of=best_of,

@@ -226,6 +297,9 @@ class SamplingParams(
logits_processors=logits_processors,
truncate_prompt_tokens=truncate_prompt_tokens,
output_kind=output_kind,
guided_decoding=guided_decoding,
logit_bias=logit_bias,
allowed_token_ids=allowed_token_ids,
)

def __post_init__(self) -> None:

@@ -454,4 +528,5 @@ class SamplingParams(
f"skip_special_tokens={self.skip_special_tokens}, "
"spaces_between_special_tokens="
f"{self.spaces_between_special_tokens}, "
f"truncate_prompt_tokens={self.truncate_prompt_tokens})")
f"truncate_prompt_tokens={self.truncate_prompt_tokens}), "
f"guided_decoding={self.guided_decoding}")
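A hedged sketch of the new `GuidedDecodingParams` behavior defined above: `from_optional` extracts a JSON schema from a pydantic model, and combining two guide kinds is rejected in `__post_init__` (the `Employee` model is illustrative):

```python
from pydantic import BaseModel

from vllm.sampling_params import GuidedDecodingParams, SamplingParams


class Employee(BaseModel):
    name: str
    age: int


guided = GuidedDecodingParams.from_optional(json=Employee)  # schema extracted
params = SamplingParams(temperature=1.0, max_tokens=1000, guided_decoding=guided)

try:
    GuidedDecodingParams(json={"type": "object"}, regex=r"\d+")
except ValueError as err:
    print(err)  # only one kind of guided decoding may be specified
```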