Add guided decoding for OpenAI API server (#2819)

Co-authored-by: br3no <breno@veltefaria.de>
Co-authored-by: simon-mo <simon.mo@hey.com>
Author: felixzhu555
Date: 2024-02-29 14:13:08 -08:00
Committed by: GitHub
Parent: 29a8d6a554
Commit: 703e42ee4b

9 changed files with 597 additions and 1 deletions

View File

@@ -12,4 +12,5 @@ pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
pynvml == 11.5.0
triton >= 2.1.0
outlines >= 0.0.27
cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
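
Note: the new `outlines` pin supplies the FSM machinery used throughout this commit. As a hedged sketch of the API surface it relies on (both imports appear verbatim in vllm/model_executor/guided_logits_processors.py below; the snippet itself is illustrative, not part of the diff):

# Illustrative only: outlines >= 0.0.27 entry points used by this change.
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema

# A JSON schema is first lowered to a regex, and the regex is compiled into
# an FSM over the tokenizer's vocabulary (tokenizer adaptation omitted).
regex_string = build_regex_from_schema('{"type": "object"}')
# fsm = RegexFSM(regex_string, adapted_tokenizer)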

View File

@@ -0,0 +1,75 @@
# This unit test should be moved to a new
# tests/test_guided_decoding directory.
from transformers import AutoTokenizer
import torch
from vllm.model_executor.guided_logits_processors import (RegexLogitsProcessor,
                                                          JSONLogitsProcessor)

TEST_SCHEMA = {
    "type": "object",
    "properties": {
        "name": {
            "type": "string"
        },
        "age": {
            "type": "integer"
        },
        "skills": {
            "type": "array",
            "items": {
                "type": "string",
                "maxLength": 10
            },
            "minItems": 3
        },
        "work history": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "company": {
                        "type": "string"
                    },
                    "duration": {
                        "type": "string"
                    },
                    "position": {
                        "type": "string"
                    }
                },
                "required": ["company", "position"]
            }
        }
    },
    "required": ["name", "age", "skills", "work history"]
}

TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \
             r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"


def test_guided_logits_processors():
    """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
    regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer)
    json_LP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer)

    regex_LP.init_state()
    token_ids = tokenizer.encode(
        f"Give an example IPv4 address with this regex: {TEST_REGEX}")
    tensor = torch.rand(32000)
    original_tensor = torch.clone(tensor)
    regex_LP(token_ids, tensor)
    assert tensor.shape == original_tensor.shape
    assert not torch.allclose(tensor, original_tensor)

    json_LP.init_state()
    token_ids = tokenizer.encode(
        f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
    tensor = torch.rand(32000)
    original_tensor = torch.clone(tensor)
    json_LP(token_ids, tensor)
    assert tensor.shape == original_tensor.shape
    assert not torch.allclose(tensor, original_tensor)
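
The assertions above only establish that each processor modified the scores in place. Because RegexLogitsProcessor applies an additive 0/-inf mask (see guided_logits_processors.py at the end of this commit), a slightly stronger, hypothetical follow-up check would be:

# Hypothetical extension of the test above (not in the commit): any position
# the processor changed must have been masked to -inf, never rescaled.
changed = tensor != original_tensor
assert torch.all(tensor[changed] == -float("inf"))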

View File

@@ -9,12 +9,64 @@ import ray # using Ray for overall ease of process management, parallel request
import openai # use the official client for correctness check
from huggingface_hub import snapshot_download # downloading lora to test lora requests
# imports for guided decoding tests
import json
import jsonschema
import re
from vllm.transformers_utils.tokenizer import get_tokenizer
MAX_SERVER_START_WAIT_S = 600  # wait up to 600 seconds for the server to start
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here
LORA_NAME = "typeof/zephyr-7b-beta-lora" # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here

TEST_SCHEMA = {
    "type": "object",
    "properties": {
        "name": {
            "type": "string"
        },
        "age": {
            "type": "integer"
        },
        "skills": {
            "type": "array",
            "items": {
                "type": "string",
                "maxLength": 10
            },
            "minItems": 3
        },
        "work history": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "company": {
                        "type": "string"
                    },
                    "duration": {
                        "type": "string"
                    },
                    "position": {
                        "type": "string"
                    }
                },
                "required": ["company", "position"]
            }
        }
    },
    "required": ["name", "age", "skills", "work history"]
}

TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \
             r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"

TEST_CHOICE = [
    "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby",
    "Swift", "Kotlin"
]
pytestmark = pytest.mark.asyncio
@@ -325,6 +377,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
        max_tokens=max_tokens,
        temperature=0.0,
        logit_bias={str(token_id): 100},
        seed=42,
    )
    assert completion.choices[0].text is not None and len(
        completion.choices[0].text) >= 5
@@ -358,5 +411,189 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
    assert first_response != completion.choices[0].text


async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=
        f"Give an example JSON for an employee profile that fits this schema: {TEST_SCHEMA}",
        n=3,
        temperature=1.0,
        max_tokens=500,
        extra_body=dict(guided_json=TEST_SCHEMA))

    assert completion.id is not None
    assert completion.choices is not None and len(completion.choices) == 3
    for i in range(3):
        assert completion.choices[i].text is not None
        output_json = json.loads(completion.choices[i].text)
        jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)


async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
    messages = [{
        "role": "system",
        "content": "you are a helpful assistant"
    }, {
        "role": "user",
        "content": "Give an example JSON for an employee profile that " + \
            f"fits this schema: {TEST_SCHEMA}"
    }]
    chat_completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=500,
        extra_body=dict(guided_json=TEST_SCHEMA))
    message = chat_completion.choices[0].message
    assert message.content is not None
    json1 = json.loads(message.content)
    jsonschema.validate(instance=json1, schema=TEST_SCHEMA)

    messages.append({"role": "assistant", "content": message.content})
    messages.append({
        "role": "user",
        "content": "Give me another one with a different name and age"
    })
    chat_completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=500,
        extra_body=dict(guided_json=TEST_SCHEMA))
    message = chat_completion.choices[0].message
    assert message.content is not None
    json2 = json.loads(message.content)
    jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
    assert json1["name"] != json2["name"]
    assert json1["age"] != json2["age"]


async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}",
        n=3,
        temperature=1.0,
        max_tokens=20,
        extra_body=dict(guided_regex=TEST_REGEX))

    assert completion.id is not None
    assert completion.choices is not None and len(completion.choices) == 3
    for i in range(3):
        assert completion.choices[i].text is not None
        assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None


async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
    messages = [{
        "role": "system",
        "content": "you are a helpful assistant"
    }, {
        "role": "user",
        "content": f"Give an example IP address with this regex: {TEST_REGEX}"
    }]
    chat_completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=20,
        extra_body=dict(guided_regex=TEST_REGEX))
    ip1 = chat_completion.choices[0].message.content
    assert ip1 is not None
    assert re.fullmatch(TEST_REGEX, ip1) is not None

    messages.append({"role": "assistant", "content": ip1})
    messages.append({"role": "user", "content": "Give me a different one"})
    chat_completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=20,
        extra_body=dict(guided_regex=TEST_REGEX))
    ip2 = chat_completion.choices[0].message.content
    assert ip2 is not None
    assert re.fullmatch(TEST_REGEX, ip2) is not None
    assert ip1 != ip2


async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt="The best language for type-safe systems programming is ",
        n=2,
        temperature=1.0,
        max_tokens=10,
        extra_body=dict(guided_choice=TEST_CHOICE))

    assert completion.id is not None
    assert completion.choices is not None and len(completion.choices) == 2
    for i in range(2):
        assert completion.choices[i].text in TEST_CHOICE


async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
    messages = [{
        "role": "system",
        "content": "you are a helpful assistant"
    }, {
        "role": "user",
        "content": "The best language for type-safe systems programming is "
    }]
    chat_completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=10,
        extra_body=dict(guided_choice=TEST_CHOICE))
    choice1 = chat_completion.choices[0].message.content
    assert choice1 in TEST_CHOICE

    messages.append({"role": "assistant", "content": choice1})
    messages.append({
        "role": "user",
        "content": "I disagree, pick another one"
    })
    chat_completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=10,
        extra_body=dict(guided_choice=TEST_CHOICE))
    choice2 = chat_completion.choices[0].message.content
    assert choice2 in TEST_CHOICE
    assert choice1 != choice2


async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
    with pytest.raises(openai.BadRequestError):
        _ = await client.completions.create(
            model=MODEL_NAME,
            prompt="Give an example JSON that fits this schema: 42",
            extra_body=dict(guided_json=42))

    messages = [{
        "role": "system",
        "content": "you are a helpful assistant"
    }, {
        "role": "user",
        "content": "The best language for type-safe systems programming is "
    }]
    with pytest.raises(openai.BadRequestError):
        _ = await client.chat.completions.create(model=MODEL_NAME,
                                                 messages=messages,
                                                 extra_body=dict(guided_regex={
                                                     1: "Python",
                                                     2: "C++"
                                                 }))

    with pytest.raises(openai.BadRequestError):
        _ = await client.completions.create(
            model=MODEL_NAME,
            prompt="Give an example string that fits this regex",
            extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))


if __name__ == "__main__":
    pytest.main([__file__])
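
Outside of pytest, the same endpoints can be exercised with the stock OpenAI client against a running server. A minimal sketch (the server address, model name, and choice list are assumptions, not part of the tests):

# Standalone sketch. Assumes an OpenAI-compatible vLLM server on
# localhost:8000 serving HuggingFaceH4/zephyr-7b-beta.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    prompt="The best language for type-safe systems programming is ",
    max_tokens=10,
    # vLLM-specific parameters travel in extra_body, as in the tests above
    extra_body=dict(guided_choice=["Python", "Rust", "C++"]))
print(completion.choices[0].text)  # constrained to one of the three choices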

View File

@@ -333,6 +333,9 @@ class AsyncLLMEngine:
        return (self.background_loop is not None
                and not self.background_loop.done())

    def get_tokenizer(self):
        return self.engine.tokenizer.tokenizer

    def start_background_loop(self) -> None:
        """Start the background loop."""
        if self.is_running:

View File

@@ -3,7 +3,7 @@
import time
from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams
@ -86,6 +86,9 @@ class ChatCompletionRequest(BaseModel):
min_p: Optional[float] = 0.0
include_stop_str_in_output: Optional[bool] = False
length_penalty: Optional[float] = 1.0
guided_json: Optional[Union[str, dict, BaseModel]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[List[str]] = None
def to_sampling_params(self) -> SamplingParams:
if self.logprobs and not self.top_logprobs:
@@ -131,6 +134,20 @@ class ChatCompletionRequest(BaseModel):
            logits_processors=logits_processors,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data


class CompletionRequest(BaseModel):
    model: str
@@ -163,6 +180,9 @@ class CompletionRequest(BaseModel):
    min_p: Optional[float] = 0.0
    include_stop_str_in_output: Optional[bool] = False
    length_penalty: Optional[float] = 1.0
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None

    def to_sampling_params(self):
        echo_without_generation = self.echo and self.max_tokens == 0
@@ -207,6 +227,20 @@ class CompletionRequest(BaseModel):
            logits_processors=logits_processors,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data


class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
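
As a sketch of what check_guided_decoding_count buys: supplying more than one guide fails during request parsing, before any engine work happens. An illustrative snippet (not part of the diff; the field values are placeholders):

# Illustrative only: two guides in one request fail fast in pydantic.
import pydantic

try:
    CompletionRequest(model="m",
                      prompt="hi",
                      guided_regex=r"\d+",
                      guided_choice=["a", "b"])
except pydantic.ValidationError as err:
    print(err)  # "You can only use one kind of guided decoding ..."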

View File

@@ -12,6 +12,7 @@ from vllm.entrypoints.openai.protocol import (
    UsageInfo)
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
logger = init_logger(__name__)
@@ -62,6 +63,14 @@ class OpenAIServingChat(OpenAIServing):
                prompt=prompt)

            sampling_params = request.to_sampling_params()
            lora_request = self._maybe_get_lora(request)
            guided_decode_logits_processor = (
                await get_guided_decoding_logits_processor(
                    request, self.engine.get_tokenizer()))
            if guided_decode_logits_processor:
                if sampling_params.logits_processors is None:
                    sampling_params.logits_processors = []
                sampling_params.logits_processors.append(
                    guided_decode_logits_processor)
        except ValueError as e:
            return self.create_error_response(str(e))

View File

@@ -16,6 +16,7 @@ from .protocol import (
)
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
logger = init_logger(__name__)
@@ -286,6 +287,14 @@ class OpenAIServingCompletion(OpenAIServing):
        try:
            sampling_params = request.to_sampling_params()
            lora_request = self._maybe_get_lora(request)
            guided_decode_logit_processor = (
                await get_guided_decoding_logits_processor(
                    request, self.engine.get_tokenizer()))
            if guided_decode_logit_processor is not None:
                if sampling_params.logits_processors is None:
                    sampling_params.logits_processors = []
                sampling_params.logits_processors.append(
                    guided_decode_logit_processor)

            prompt_is_tokens, prompts = parse_prompt_format(request.prompt)

            for i, prompt in enumerate(prompts):

View File

@@ -0,0 +1,99 @@
import asyncio
import concurrent.futures
from copy import copy
from enum import Enum
from functools import lru_cache
from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Optional, Tuple, Union
from pydantic import BaseModel
from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest
from vllm.model_executor.guided_logits_processors import JSONLogitsProcessor, RegexLogitsProcessor


class GuidedDecodingMode(Enum):
    JSON = "json"
    REGEX = "regex"
    CHOICE = "choice"


global_thread_pool = None  # used for generating logits processor fsm


async def get_guided_decoding_logits_processor(
        request: Union[CompletionRequest, ChatCompletionRequest],
        tokenizer) -> Optional[Union[JSONLogitsProcessor,
                                     RegexLogitsProcessor]]:
    """
    Given an OpenAI-compatible request, check for guided decoding parameters
    and get the necessary logits processor for the given guide.
    We cache logit processors by (guide, tokenizer), and on cache hit
    we make a shallow copy to reuse the same underlying FSM.
    """
    global global_thread_pool
    guide, mode = _get_guide_and_mode(request)
    if not guide:
        return None

    if global_thread_pool is None:
        global_thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=2)
    loop = asyncio.get_running_loop()

    result = await loop.run_in_executor(global_thread_pool,
                                        _get_cached_logits_processor, guide,
                                        tokenizer, mode)

    logits_processor = copy(result)
    # reset the processor's per-request FSM state before reuse
    logits_processor.init_state()
    return logits_processor


def _get_guide_and_mode(
    request: Union[CompletionRequest, ChatCompletionRequest]
) -> Tuple[Optional[str], Optional[GuidedDecodingMode]]:
    if request.guided_json:
        if not isinstance(request.guided_json, (str, dict, BaseModel)):
            raise TypeError("JSON schema must be str, dict, or BaseModel")

        json = request.guided_json
        if isinstance(json, dict):
            # turn dict into hashable string
            json = json_dumps(json, sort_keys=True)
        elif isinstance(json, BaseModel):
            # use pydantic signature so that different model classes
            # with the same fields will get hashed the same
            json = str(json.__signature__)
        return json, GuidedDecodingMode.JSON

    elif request.guided_regex:
        if not isinstance(request.guided_regex, str):
            raise TypeError("Regex must be string")
        return request.guided_regex, GuidedDecodingMode.REGEX

    elif request.guided_choice:
        if not isinstance(request.guided_choice, list):
            raise TypeError("Choices must be a list")

        # choice is implemented as a regex alternation over the choices
        choices = [
            regex_escape(str(choice)) for choice in request.guided_choice
        ]
        choices_regex = "(" + "|".join(choices) + ")"
        return choices_regex, GuidedDecodingMode.CHOICE

    else:
        return None, None
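

# Editor's note (illustrative, not part of the diff): guided_choice is
# lowered to an escaped regex alternation so CHOICE mode can reuse the regex
# machinery unchanged. For example:
#
#   choices = [regex_escape(str(c)) for c in ["C++", "C#"]]
#   "(" + "|".join(choices) + ")"    # -> r"(C\+\+|C\#)"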


@lru_cache(maxsize=32)
def _get_cached_logits_processor(guide: str, tokenizer,
                                 mode: GuidedDecodingMode):
    if mode == GuidedDecodingMode.JSON:
        return JSONLogitsProcessor(guide, tokenizer)
    elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
        return RegexLogitsProcessor(guide, tokenizer)
    else:
        raise ValueError(f"Unknown guided decoding mode {mode}")
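
The lru_cache/copy pairing above is the crux of the design: FSM compilation (potentially expensive for large schemas) happens once per distinct (guide, tokenizer, mode) key, while each request gets a shallow copy with fresh per-sequence state. A hedged illustration (TEST_REGEX and tokenizer as in the tests earlier in this commit):

# Illustrative only: a cache hit shares the compiled FSM; copy isolates state.
a = _get_cached_logits_processor(TEST_REGEX, tokenizer, GuidedDecodingMode.REGEX)
b = _get_cached_logits_processor(TEST_REGEX, tokenizer, GuidedDecodingMode.REGEX)
assert a is b  # same object: the FSM was compiled exactly once

p1, p2 = copy(a), copy(b)
p1.init_state()
p2.init_state()
assert p1.fsm is p2.fsm  # shallow copies share the FSM
assert p1.fsm_state is not p2.fsm_state  # but per-request state is separate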

View File

@@ -0,0 +1,129 @@
# Copyright 2024- the Outlines developers
# This file is adapted from
# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
from collections import defaultdict
from typing import Union, DefaultDict, Dict, List, Optional
import torch
from pydantic import BaseModel
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema


class RegexLogitsProcessor:

    def __init__(self, regex_string: str, tokenizer):
        """Compile the FSM that drives the regex-structured generation.

        Parameters
        ----------
        regex_string
            A string that represents a regular expression
        tokenizer
            The model's tokenizer

        """
        tokenizer = self.adapt_tokenizer(tokenizer)
        fsm = RegexFSM(regex_string, tokenizer)
        self.fsm = fsm

    def init_state(self):
        """Initialize the FSM states."""
        self.fsm_state: DefaultDict[int, int] = defaultdict(int)

    def __call__(self, input_ids: List[int],
                 scores: torch.Tensor) -> torch.Tensor:
        """Use the FSM to bias the logits before sampling the next token."""
        seq_id = hash(tuple(input_ids))

        if len(input_ids) == 0:
            self.init_state()
        else:
            last_token = input_ids[-1]
            last_seq_id = hash(tuple(input_ids[:-1]))
            self.fsm_state[seq_id] = self.fsm.next_state(
                self.fsm_state[last_seq_id], last_token)

        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])

        mask = torch.full((scores.shape[-1], ),
                          -math.inf,
                          device=scores.device)
        mask[allowed_tokens] = 0
        scores.add_(mask)

        return scores
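
    # Editor's note (illustrative, not part of the diff): the masking above
    # is a plain additive 0/-inf mask. With a 5-token vocabulary where only
    # tokens 1 and 3 are allowed:
    #
    #   scores = torch.tensor([0.5, 1.2, -0.3, 2.0, 0.1])
    #   mask = torch.full((5, ), -math.inf)
    #   mask[[1, 3]] = 0
    #   scores.add_(mask)  # -> tensor([-inf, 1.2, -inf, 2.0, -inf])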

    def adapt_tokenizer(self, tokenizer):
        """Adapt vLLM's tokenizer so it can be used to compile the FSM.

        The API of Outlines tokenizers is slightly different from that of
        `transformers`. In addition, we need to handle the missing spaces in
        Llama's tokenizer output to be able to compile FSMs for this model.
        """
        tokenizer.vocabulary = tokenizer.get_vocab()
        tokenizer.special_tokens = set(tokenizer.all_special_tokens)

        def convert_token_to_string(token: str) -> str:
            from transformers.file_utils import SPIECE_UNDERLINE

            string = tokenizer.convert_tokens_to_string([token])

            # A hack to handle missing spaces in HF's Llama tokenizers
            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
                return " " + string

            return string

        tokenizer.convert_token_to_string = convert_token_to_string

        return tokenizer


class JSONLogitsProcessor(RegexLogitsProcessor):

    def __init__(self,
                 schema: Union[str, Dict, BaseModel],
                 tokenizer,
                 whitespace_pattern: Optional[str] = None):
        """Compile the FSM that drives the JSON-guided generation.

        Parameters
        ----------
        schema
            A JSON schema that encodes the structure we want the model to
            generate
        tokenizer
            The model's tokenizer
        whitespace_pattern
            Pattern to use for JSON syntactic whitespace (doesn't impact
            string literals). Example: allow only a single space or newline
            with `whitespace_pattern=r"[\n ]?"`

        """
        if isinstance(schema, type(BaseModel)):
            schema_str = json.dumps(schema.model_json_schema())
        elif isinstance(schema, Dict):
            schema_str = json.dumps(schema)
        elif isinstance(schema, str):
            schema_str = schema
        else:
            raise ValueError(
                f"Cannot parse schema {schema}. The schema must be either "
                "a Pydantic object, a dictionary or a string that contains "
                "the JSON Schema specification")
        regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
        super().__init__(regex_string, tokenizer)
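
Because the constructor accepts a Pydantic model class directly (the isinstance(schema, type(BaseModel)) branch), callers can guide generation with typed schemas. A minimal, hypothetical usage sketch (Employee is an illustrative model; in vLLM the processor is attached via SamplingParams(logits_processors=[...]) by the serving layer above):

# Hypothetical usage sketch, not part of the diff.
from pydantic import BaseModel
from transformers import AutoTokenizer

class Employee(BaseModel):
    name: str
    age: int

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
processor = JSONLogitsProcessor(Employee, tokenizer)
processor.init_state()  # fresh FSM state before the first __call__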