Frontend: Adding LM Format Enforcer support to V1 engine (#22564)
Signed-off-by: Noam Gat <noamgat@gmail.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
@@ -18,7 +18,7 @@ prometheus_client >= 0.18.0
 pillow # Required for image processing
 prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
-lm-format-enforcer >= 0.10.11, < 0.11
+lm-format-enforcer == 0.11.3
 llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
 outlines_core == 0.2.10 ; platform_machine != "s390x"
 outlines == 0.1.11 ; platform_machine == "s390x"
@@ -41,8 +41,11 @@ EAGLE_SPEC_CONFIG = {
 PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
     ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
     ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
+    ("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto",
+     None),
     ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
     ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
+    ("Qwen/Qwen2.5-1.5B-Instruct", "lm-format-enforcer", "auto", None),
     ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
     ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
     ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto",
@@ -148,7 +151,8 @@ def test_structured_output(

     generated_text = output.outputs[0].text
     assert generated_text is not None
-    assert "\n" not in generated_text
+    if guided_decoding_backend != 'lm-format-enforcer':
+        assert "\n" not in generated_text
     print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
     output_json = json.loads(generated_text)
     jsonschema.validate(instance=output_json, schema=sample_json_schema)
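Note: the newline assertion is skipped for lm-format-enforcer, presumably because that backend permits insignificant JSON whitespace in generated output. A quick illustration (not part of the diff) that such output still parses and validates:

import json

# Newlines are insignificant whitespace in JSON, so text like this still
# round-trips through json.loads and downstream schema validation.
generated_text = '{\n  "name": "vLLM"\n}'
assert json.loads(generated_text)["name"] == "vLLM"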
@@ -225,7 +229,7 @@ def test_structured_output(
     parsed_json = json.loads(generated_text)
     assert isinstance(parsed_json, dict)

-    if guided_decoding_backend != "outlines":
+    if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]:
         #
         # Test 4: Generate SQL statement using EBNF grammar
         #
@@ -439,7 +443,7 @@ def test_structured_output(
     output_json = json.loads(generated_text)
     jsonschema.validate(instance=output_json, schema=json_schema)

-    if guided_decoding_backend != "outlines":
+    if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]:
         #
         # Test 11: Generate structured output using structural_tag format
         #
@@ -3057,7 +3057,8 @@ def get_served_model_name(model: str,
     return served_model_name


-GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines"]
+GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines",
+                                "lm-format-enforcer"]


 @config
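For orientation, a minimal usage sketch (not part of this diff) of how the new Literal value is expected to be selected through vLLM's existing offline structured-output API; the model and schema below are placeholders:

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Placeholder model and schema; "lm-format-enforcer" is the new backend value.
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
          guided_decoding_backend="lm-format-enforcer")
params = SamplingParams(
    max_tokens=100,
    guided_decoding=GuidedDecodingParams(json=schema))
outputs = llm.generate("Reply with a JSON object containing a name:", params)
print(outputs[0].outputs[0].text)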
@@ -21,6 +21,8 @@ from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient
 from vllm.v1.structured_output.backend_guidance import (
     validate_guidance_grammar)
+from vllm.v1.structured_output.backend_lm_format_enforcer import (
+    validate_structured_output_request_lm_format_enforcer)
 from vllm.v1.structured_output.backend_outlines import (
     validate_structured_output_request_outlines)
 from vllm.v1.structured_output.backend_xgrammar import (
@@ -200,6 +202,9 @@ class Processor:
         elif engine_level_backend == "outlines":
             # outlines backend
             validate_structured_output_request_outlines(params)
+        elif engine_level_backend == "lm-format-enforcer":
+            # lm format enforcer backend
+            validate_structured_output_request_lm_format_enforcer(params)
         else:
             # NOTE: engine_level_backend must be "auto" here, because we have
             # checked supported_backends above.
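A sketch (not in the diff) of what the new validation path accepts and rejects, using the validator added by this PR; regex, json, and choice requests pass through, while grammar requests raise:

from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.v1.structured_output.backend_lm_format_enforcer import (
    validate_structured_output_request_lm_format_enforcer)

# Regex requests are supported and pass validation unchanged.
ok = SamplingParams(guided_decoding=GuidedDecodingParams(regex=r"\d+"))
validate_structured_output_request_lm_format_enforcer(ok)

# Grammar requests are not supported by this backend and raise ValueError.
bad = SamplingParams(
    guided_decoding=GuidedDecodingParams(grammar="root ::= 'x'"))
try:
    validate_structured_output_request_lm_format_enforcer(bad)
except ValueError as e:
    print(e)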
@@ -108,6 +108,14 @@ class StructuredOutputManager:
                 tokenizer=self.tokenizer,
                 vocab_size=vocab_size,
             )
+        elif backend == "lm-format-enforcer":
+            from vllm.v1.structured_output.backend_lm_format_enforcer import (  # noqa: E501
+                LMFormatEnforcerBackend)
+            self.backend = LMFormatEnforcerBackend(
+                self.vllm_config,
+                tokenizer=self.tokenizer,
+                vocab_size=vocab_size,
+            )
         else:
             raise ValueError(
                 f"Unsupported structured output backend: {backend}")
vllm/v1/structured_output/backend_lm_format_enforcer.py (new file, 167 lines)
@@ -0,0 +1,167 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

import ast
import json
from dataclasses import dataclass, field
from functools import lru_cache
from typing import TYPE_CHECKING

import torch
from transformers import PreTrainedTokenizerBase

from vllm.sampling_params import SamplingParams
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar,
                                                     StructuredOutputOptions)

if TYPE_CHECKING:
    import lmformatenforcer
    import lmformatenforcer.integrations.vllm as lmfe_vllm
else:
    lmformatenforcer = LazyLoader("lmformatenforcer", globals(),
                                  "lmformatenforcer")
    lmfe_vllm = LazyLoader("lmformatenforcer.integrations.vllm", globals(),
                           "lmformatenforcer.integrations.vllm")


@lru_cache
def _cached_build_vllm_token_enforcer_tokenizer_data(
        tokenizer: PreTrainedTokenizerBase,
        vocab_size: int) -> lmfe_vllm.TokenEnforcerTokenizerData:
    return lmfe_vllm.build_vllm_token_enforcer_tokenizer_data(
        tokenizer, use_bitmask=True, vocab_size=vocab_size)


@dataclass
class LMFormatEnforcerGrammar(StructuredOutputGrammar):
    token_enforcer: lmformatenforcer.TokenEnforcer
    current_tokens_prefix: list[int] = field(default_factory=list)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        original_len = len(self.current_tokens_prefix)
        for token in tokens:
            if not self.token_enforcer.get_allowed_tokens(
                    self.current_tokens_prefix).is_token_allowed(token):
                # Rollback partial updates to ensure atomicity.
                del self.current_tokens_prefix[original_len:]
                return False
            self.current_tokens_prefix.append(token)
        return True

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        for prefix_length in range(len(tokens)):
            prefix = tokens[:prefix_length]
            next_token = tokens[prefix_length]
            if not self.token_enforcer.get_allowed_tokens(
                    self.current_tokens_prefix +
                    prefix).is_token_allowed(next_token):
                break
        else:
            # The loop's else clause only runs if no token was rejected:
            # the entire candidate sequence is valid.
            return tokens

        return tokens[:prefix_length]

    def rollback(self, num_tokens: int) -> None:
        # Guard num_tokens == 0: a [:-0] slice would clear the whole prefix.
        if num_tokens > 0:
            self.current_tokens_prefix = (
                self.current_tokens_prefix[:-num_tokens])

    def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
        # The enforcer returns the allowed tokens as a packed bitmask row
        # (one bit per vocab token), matching the layout allocated in
        # allocate_token_bitmask() below.
        allowed_tokens = self.token_enforcer.get_allowed_tokens(
            self.current_tokens_prefix)
        bitmask[batch_index] = allowed_tokens.allowed_tokens

    def is_terminated(self) -> bool:
        # We are considered terminated if the prefix ends with eos_token_id
        return_value = len(
            self.current_tokens_prefix) > 0 and self.current_tokens_prefix[
                -1] == self.token_enforcer.eos_token_id
        return return_value

    def reset(self):
        self.current_tokens_prefix = []


@dataclass
class LMFormatEnforcerBackend(StructuredOutputBackend):

    def __post_init__(self):
        self.tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
            self.tokenizer, self.vocab_size)

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        character_level_parser: lmformatenforcer.CharacterLevelParser
        if request_type == StructuredOutputOptions.JSON:
            spec_dict = json.loads(grammar_spec)
            character_level_parser = lmformatenforcer.JsonSchemaParser(
                spec_dict)
        elif request_type == StructuredOutputOptions.JSON_OBJECT:
            character_level_parser = lmformatenforcer.JsonSchemaParser(None)
        elif request_type == StructuredOutputOptions.REGEX:
            character_level_parser = lmformatenforcer.RegexParser(grammar_spec)
        elif request_type == StructuredOutputOptions.CHOICE:
            choices = ast.literal_eval(grammar_spec)
            character_level_parser = lmformatenforcer.UnionParser(
                [lmformatenforcer.StringParser(choice) for choice in choices])
        else:
            raise ValueError(
                "Invalid request type for LM Format Enforcer backend "
                f"({request_type!s})")
        max_rollback_tokens = (
            self.vllm_config.speculative_config.num_speculative_tokens
            if self.vllm_config.speculative_config is not None else 0)

        if max_rollback_tokens > 0:
            raise ValueError(
                "LM Format Enforcer backend does not support speculative tokens"
            )

        token_enforcer = lmformatenforcer.TokenEnforcer(
            tokenizer_data=self.tokenizer_data,
            parser=character_level_parser,
        )
        return LMFormatEnforcerGrammar(token_enforcer)

    def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
        # One bit per vocab token, packed into int32 words; -1 has all bits
        # set, so every token starts out allowed.
        return torch.full(
            (max_num_seqs, (self.vocab_size + 31) // 32),
            -1,
            dtype=torch.int32,
            pin_memory=torch.cuda.is_available(),
        )

    def destroy(self):
        pass


def validate_structured_output_request_lm_format_enforcer(
        params: SamplingParams):
    if params.guided_decoding is None:
        return

    gd_params = params.guided_decoding

    if gd_params.regex:
        return
    elif gd_params.json:
        if isinstance(gd_params.json, str):
            try:
                # make sure schema is valid json
                json.loads(gd_params.json)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
        else:
            try:
                json.dumps(gd_params.json)
            except Exception as e:
                raise ValueError(
                    f"Error serializing guided decoding jsonschema: {e}"
                ) from e
        return
    elif gd_params.choice:
        return
    elif gd_params.grammar:
        raise ValueError("LM Format Enforcer guided decoding backend "
                         "does not support grammar specifications")