mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00

[Frontend][Core] Update Outlines Integration from FSM to Guide (#4109)

Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Breno Faria <breno.faria@intrafind.com>
@@ -17,6 +17,6 @@ prometheus_client >= 0.18.0
 prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer == 0.10.1
-outlines == 0.0.34 # Requires torch >= 2.1.0
+outlines >= 0.0.43 # Requires torch >= 2.1.0
 typing_extensions
 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
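
The version floor is the substance of this hunk: the outlines.fsm.guide module that the rest of this diff imports does not exist in 0.0.34. A hedged sanity probe, not part of the diff (the version boundary is inferred from the new pin):

# Hedged sketch: detect which outlines API generation is installed.
try:
    from outlines.fsm.guide import Guide  # present with the new >= 0.0.43 pin
    HAS_GUIDE_API = True
except ImportError:
    # Older releases such as 0.0.34 only ship outlines.fsm.fsm.FSM.
    HAS_GUIDE_API = False
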
@@ -63,7 +63,6 @@ def test_guided_logits_processors():
                                     tokenizer,
                                     whitespace_pattern=None)

-    regex_LP.init_state()
     token_ids = tokenizer.encode(
         f"Give an example IPv4 address with this regex: {TEST_REGEX}")
     tensor = torch.rand(32000)
@@ -72,7 +71,6 @@ def test_guided_logits_processors():
     assert tensor.shape == original_tensor.shape
     assert not torch.allclose(tensor, original_tensor)

-    json_LP.init_state()
     token_ids = tokenizer.encode(
         f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
     tensor = torch.rand(32000)
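
Both init_state() calls are deleted because a processor now arrives ready to use: BaseLogitsProcessor.__init__ sets up the per-sequence state map itself (see the class rewrite further down). A hedged sketch of the surviving test shape; the class name and assertions come from the diff, while TEST_REGEX, the tokenizer choice, and the import path are assumptions:

import torch
from transformers import AutoTokenizer

# Import path assumed from vLLM's layout at the time of this commit.
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
    RegexLogitsProcessor)

TEST_REGEX = r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}"  # illustrative pattern
tokenizer = AutoTokenizer.from_pretrained("gpt2")   # any HF tokenizer works

regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer)
token_ids = tokenizer.encode(
    f"Give an example IPv4 address with this regex: {TEST_REGEX}")
# The upstream test hardcodes 32000 (a Llama vocab size); sizing from the
# tokenizer keeps the sketch runnable with any model.
tensor = torch.rand(len(tokenizer))
original_tensor = torch.clone(tensor)

tensor = regex_LP(token_ids, tensor)  # no init_state() needed anymore
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
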
@@ -1,8 +1,6 @@
 import asyncio
 import concurrent.futures
-from copy import copy
 from enum import Enum
-from functools import lru_cache
 from json import dumps as json_dumps
 from re import escape as regex_escape
 from typing import Tuple, Union
@@ -54,8 +52,10 @@ global_thread_pool = None # used for generating logits processor fsm


 async def get_outlines_guided_decoding_logits_processor(
-        request: Union[CompletionRequest, ChatCompletionRequest],
-        tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
+    request: Union[CompletionRequest,
+                   ChatCompletionRequest], tokenizer: PreTrainedTokenizerBase
+) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
+           None]:
     """
     Given an OpenAI-compatible request, check for guided decoding parameters
     and get the necessary logits processor for the given guide.
@@ -64,7 +64,7 @@ async def get_outlines_guided_decoding_logits_processor(
     """
     global global_thread_pool
     guide, mode = _get_guide_and_mode(request)
-    if not guide:
+    if not guide or not mode:
         return None

     if global_thread_pool is None:
@@ -72,15 +72,9 @@ async def get_outlines_guided_decoding_logits_processor(
             max_workers=2)
     loop = asyncio.get_running_loop()

-    result = await loop.run_in_executor(global_thread_pool,
-                                        _get_cached_logits_processor, guide,
-                                        tokenizer, mode,
-                                        request.guided_whitespace_pattern)
-
-    logits_processor = copy(result)
-    # reset logits processor's internal state
-    logits_processor.init_state()
-    return logits_processor
+    return await loop.run_in_executor(global_thread_pool,
+                                      _get_logits_processor, guide, tokenizer,
+                                      mode, request.guided_whitespace_pattern)


 def _get_guide_and_mode(
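
With caching moved out of this coroutine, the copy-then-init_state dance disappears: the executor call now just builds a fresh processor off the event loop and returns it. A minimal, self-contained sketch of the offloading pattern; compile_processor is a hypothetical stand-in for _get_logits_processor:

import asyncio
import concurrent.futures

_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)

def compile_processor(guide: str) -> str:
    # Stand-in for _get_logits_processor: compiling a guide into a
    # token-level FSM can take seconds for large schemas, so it must
    # not run on the event loop thread.
    return f"compiled({guide})"

async def get_processor(guide: str) -> str:
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_pool, compile_processor, guide)

print(asyncio.run(get_processor(r"\d+")))  # -> compiled(\d+)
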
@@ -115,11 +109,10 @@ def _get_guide_and_mode(
         return None, None


-@lru_cache(maxsize=32)
-def _get_cached_logits_processor(guide: str,
-                                 tokenizer: PreTrainedTokenizerBase,
-                                 mode: GuidedDecodingMode,
-                                 whitespace_pattern: Union[str, None]):
+def _get_logits_processor(
+    guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode,
+    whitespace_pattern: Union[str, None]
+) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
     if mode == GuidedDecodingMode.JSON:
         return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
     elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
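
_get_logits_processor is now a plain factory: every call returns a fresh processor object, and reuse of the expensive compiled guide happens one level down, in the per-class _get_guide caches added later in this diff. A hedged sketch of how a caller might wire the result into sampling; it assumes vLLM's SamplingParams logits_processors hook and a small demo model, neither of which is part of this diff:

from vllm import LLM, SamplingParams
# Import path assumed from vLLM's layout at the time of this commit.
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
    JSONLogitsProcessor)

llm = LLM(model="facebook/opt-125m")  # assumed demo model
tokenizer = llm.get_tokenizer()

schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
processor = JSONLogitsProcessor(schema, tokenizer, whitespace_pattern=None)

params = SamplingParams(max_tokens=64, logits_processors=[processor])
outputs = llm.generate("Return a JSON object with a name field:", params)
print(outputs[0].outputs[0].text)
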
@@ -21,7 +21,7 @@ from functools import lru_cache
 from typing import Callable, DefaultDict, Dict, List, Union

 import torch
-from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
+from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write
 from outlines.fsm.json_schema import build_regex_from_schema
 from pydantic import BaseModel
 from transformers import PreTrainedTokenizerBase
@@ -29,28 +29,32 @@ from transformers import PreTrainedTokenizerBase

 class BaseLogitsProcessor:

-    def __init__(self):
-        # Child class should use initialize in their init.
-        self.fsm: FSM
-
-    def init_state(self):
-        """Initialize the FSM states."""
-        self.fsm_state: DefaultDict[int, int] = defaultdict(int)
+    def __init__(self, guide: Guide):
+        self._guide: Guide = guide
+        self._fsm_state: DefaultDict[int, int] = defaultdict(int)

     def __call__(self, input_ids: List[int],
                  scores: torch.Tensor) -> torch.Tensor:
         """Use the FSM to bias the logits before sampling the next token."""
         seq_id = hash(tuple(input_ids))

-        if len(input_ids) == 0:
-            self.init_state()
-        else:
+        if len(input_ids) > 0:
             last_token = input_ids[-1]
             last_seq_id = hash(tuple(input_ids[:-1]))
-            self.fsm_state[seq_id] = self.fsm.next_state(
-                self.fsm_state[last_seq_id], last_token)
+            self._fsm_state[seq_id] = self._guide.get_next_state(
+                state=self._fsm_state[last_seq_id], token_id=last_token)

-        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
+        instruction = self._guide.get_next_instruction(
+            state=self._fsm_state[seq_id])
+
+        if type(instruction) == Generate:
+            allowed_tokens = instruction.tokens
+        elif type(instruction) == Write:
+            # TODO: support fast forward tokens
+            allowed_tokens = [instruction.tokens[0]]
+        else:
+            raise TypeError(
+                f"Unsupported instruction type {type(instruction)}")

         mask = torch.full((scores.shape[-1], ),
                           -math.inf,
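
This is the heart of the FSM-to-Guide migration: instead of asking the FSM for allowed_token_ids directly, the processor asks the Guide for an instruction. Generate carries the full set of legal next tokens, while Write carries a forced token sequence of which only the first token is honored until fast-forward support lands (the TODO above). The masking step that follows is unchanged; a self-contained sketch of it, with the zeroing and in-place add filled in from context:

import math
from typing import List

import torch

def mask_scores(scores: torch.Tensor,
                allowed_tokens: List[int]) -> torch.Tensor:
    # Start everything at -inf, reset the allowed positions to 0, then
    # add the mask onto the raw logits in place.
    mask = torch.full((scores.shape[-1], ), -math.inf)
    mask[allowed_tokens] = 0
    return scores.add_(mask)

scores = torch.randn(10)
mask_scores(scores, [1, 4, 7])
print(scores)  # only indices 1, 4, and 7 remain finite
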
@@ -62,6 +66,13 @@ class BaseLogitsProcessor:

 class RegexLogitsProcessor(BaseLogitsProcessor):

+    @classmethod
+    @lru_cache(maxsize=32)
+    def _get_guide(cls, regex_string: str,
+                   tokenizer: PreTrainedTokenizerBase) -> Guide:
+        tokenizer = _adapt_tokenizer(tokenizer)
+        return RegexGuide(regex_string, tokenizer)
+
     def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
         """Compile the FSM that drives the regex-structured generation.

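
Caching therefore moves from the deleted module-level _get_cached_logits_processor down to the guide: processor objects stay cheap and per-request, while the compiled RegexGuide is shared through lru_cache. A generic sketch of this classmethod-over-lru_cache stacking; the Compiler class is invented for illustration:

from functools import lru_cache

class Compiler:

    @classmethod
    @lru_cache(maxsize=32)
    def _get_compiled(cls, spec: str) -> str:
        # Runs once per distinct spec; later constructions are cache hits.
        print(f"compiling {spec!r}")
        return spec.upper()

    def __init__(self, spec: str):
        self._compiled = Compiler._get_compiled(spec)

a = Compiler("ipv4")  # prints: compiling 'ipv4'
b = Compiler("ipv4")  # cache hit: fresh Compiler, shared compiled spec
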
@@ -73,9 +84,8 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
             The model's tokenizer

         """
-        tokenizer = _adapt_tokenizer(tokenizer)
-        fsm = RegexFSM(regex_string, tokenizer)
-        self.fsm = fsm
+        super().__init__(
+            RegexLogitsProcessor._get_guide(regex_string, tokenizer))


 class JSONLogitsProcessor(RegexLogitsProcessor):
@@ -115,6 +125,12 @@ class JSONLogitsProcessor(RegexLogitsProcessor):

 class CFGLogitsProcessor(BaseLogitsProcessor):

+    @classmethod
+    @lru_cache(maxsize=32)
+    def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
+        tokenizer = _adapt_tokenizer(tokenizer)
+        return CFGGuide(cfg, tokenizer)
+
     def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
         """Compile the FSM that drives the context free grammar generation.
@@ -126,17 +142,11 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
             The model's tokenizer

         """
-        tokenizer = _adapt_tokenizer(tokenizer)
-        fsm = CFGFSM(cfg, tokenizer)
-        self.fsm = fsm
-
-    def init_state(self):
-        """Initialize state with a CFGFSM copy."""
-        super().init_state()
-        self.fsm = self.fsm.copy()
+        super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer))
+        self._guide = self._guide.copy()


-@lru_cache
+@lru_cache(maxsize=32)
 def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
     """Adapt vLLM's tokenizer to use to compile the FSM.

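
One subtlety survives the rewrite: RegexGuide state is a plain integer, so a cached guide can be handed to any number of processors, but CFGGuide drags a mutable parser along, hence the immediate self._guide = self._guide.copy() right after the cache lookup. A hypothetical illustration of the hazard the copy avoids; StatefulGuide is invented for the example:

import copy

class StatefulGuide:
    """Invented stand-in for CFGGuide: advancing mutates internal state."""

    def __init__(self) -> None:
        self.position = 0

    def advance(self) -> None:
        self.position += 1

    def copy(self) -> "StatefulGuide":
        return copy.deepcopy(self)

cached = StatefulGuide()   # imagine this living in the lru_cache
request_a = cached.copy()  # each request mutates only its own copy
request_b = cached.copy()
request_a.advance()
assert request_b.position == 0  # sharing cached directly would break this
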