Frontend: Adding LM Format Enforcer support to V1 engine (#22564)

Signed-off-by: Noam Gat <noamgat@gmail.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Noam Gat authored on 2025-08-25 05:31:22 +03:00, committed by GitHub
parent 504d914314
commit 39971db3aa
6 changed files with 190 additions and 5 deletions
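
For context, a minimal offline-inference sketch of how the new backend is selected (not part of this commit; the prompt and schema are hypothetical, and guided_decoding_backend / GuidedDecodingParams are the pre-existing vLLM knobs this diff extends):

# Usage sketch, not part of this commit. Assumes vLLM's offline LLM
# entrypoint; the model choice mirrors the test matrix below.
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
          guided_decoding_backend="lm-format-enforcer")
schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}},
    "required": ["name"],
}
params = SamplingParams(
    max_tokens=100,
    guided_decoding=GuidedDecodingParams(json=schema),
)
outputs = llm.generate(["Return a JSON object with a name."], params)
print(outputs[0].outputs[0].text)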

View File

@@ -18,7 +18,7 @@ prometheus_client >= 0.18.0
pillow # Required for image processing
prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.11, < 0.11
lm-format-enforcer == 0.11.3
llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
outlines_core == 0.2.10 ; platform_machine != "s390x"
outlines == 0.1.11 ; platform_machine == "s390x"

View File

@@ -41,8 +41,11 @@ EAGLE_SPEC_CONFIG = {
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto",
None),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
("Qwen/Qwen2.5-1.5B-Instruct", "lm-format-enforcer", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto",
@@ -148,7 +151,8 @@ def test_structured_output(
generated_text = output.outputs[0].text
assert generated_text is not None
assert "\n" not in generated_text
if guided_decoding_backend != 'lm-format-enforcer':
assert "\n" not in generated_text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=sample_json_schema)
@@ -225,7 +229,7 @@ def test_structured_output(
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
if guided_decoding_backend != "outlines":
if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]:
#
# Test 4: Generate SQL statement using EBNF grammar
#
@@ -439,7 +443,7 @@ def test_structured_output(
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=json_schema)
if guided_decoding_backend != "outlines":
if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]:
#
# Test 11: Generate structured output using structural_tag format
#
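
(Note on the first hunk above: the unconditional newline assertion is relaxed because lm-format-enforcer deliberately leaves whitespace formatting to the model, so valid JSON output may legitimately contain newlines. The later hunks skip the EBNF-grammar and structural_tag tests, which this backend does not support.)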

View File

@@ -3057,7 +3057,8 @@ def get_served_model_name(model: str,
return served_model_name
GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines"]
GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines",
"lm-format-enforcer"]
@config

View File

@@ -21,6 +21,8 @@ from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.backend_lm_format_enforcer import (
validate_structured_output_request_lm_format_enforcer)
from vllm.v1.structured_output.backend_outlines import (
validate_structured_output_request_outlines)
from vllm.v1.structured_output.backend_xgrammar import (
@@ -200,6 +202,9 @@ class Processor:
elif engine_level_backend == "outlines":
# outlines backend
validate_structured_output_request_outlines(params)
elif engine_level_backend == "lm-format-enforcer":
# lm format enforcer backend
validate_structured_output_request_lm_format_enforcer(params)
else:
# NOTE: engine_level_backend must be "auto" here, because we have
# checked supported_backends above.

View File

@@ -108,6 +108,14 @@ class StructuredOutputManager:
tokenizer=self.tokenizer,
vocab_size=vocab_size,
)
elif backend == "lm-format-enforcer":
from vllm.v1.structured_output.backend_lm_format_enforcer import ( # noqa: E501
LMFormatEnforcerBackend)
self.backend = LMFormatEnforcerBackend(
self.vllm_config,
tokenizer=self.tokenizer,
vocab_size=vocab_size,
)
else:
raise ValueError(
f"Unsupported structured output backend: {backend}")

View File

@@ -0,0 +1,167 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import ast
import json
from dataclasses import dataclass, field
from functools import lru_cache
from typing import TYPE_CHECKING
import torch
from transformers import PreTrainedTokenizerBase
from vllm.sampling_params import SamplingParams
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
StructuredOutputGrammar,
StructuredOutputOptions)
if TYPE_CHECKING:
import lmformatenforcer
import lmformatenforcer.integrations.vllm as lmfe_vllm
else:
lmformatenforcer = LazyLoader("lmformatenforcer", globals(),
"lmformatenforcer")
lmfe_vllm = LazyLoader("lmformatenforcer.integrations.vllm", globals(),
"lmformatenforcer.integrations.vllm")
@lru_cache
def _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer: PreTrainedTokenizerBase,
vocab_size: int) -> lmfe_vllm.TokenEnforcerTokenizerData:
return lmfe_vllm.build_vllm_token_enforcer_tokenizer_data(
tokenizer, use_bitmask=True, vocab_size=vocab_size)
@dataclass
class LMFormatEnforcerGrammar(StructuredOutputGrammar):
token_enforcer: lmformatenforcer.TokenEnforcer
current_tokens_prefix: list[int] = field(default_factory=list)
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
original_len = len(self.current_tokens_prefix)
for token in tokens:
if not self.token_enforcer.get_allowed_tokens(
self.current_tokens_prefix).is_token_allowed(token):
# Rollback partial updates to ensure atomicity.
del self.current_tokens_prefix[original_len:]
return False
self.current_tokens_prefix.append(token)
return True
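    # Returns the longest prefix of `tokens` that the enforcer accepts,
    # without mutating grammar state; the for/else falls through to the
    # full-acceptance return only when no token was rejected.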
def validate_tokens(self, tokens: list[int]) -> list[int]:
for prefix_length in range(len(tokens)):
prefix = tokens[:prefix_length]
next_token = tokens[prefix_length]
if not self.token_enforcer.get_allowed_tokens(
self.current_tokens_prefix +
prefix).is_token_allowed(next_token):
break
else:
return tokens
return tokens[:prefix_length]
    def rollback(self, num_tokens: int) -> None:
        # Guard num_tokens == 0: slicing with [:-0] would empty the prefix.
        if num_tokens > 0:
            self.current_tokens_prefix = self.current_tokens_prefix[
                :-num_tokens]
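    # With use_bitmask=True (see the tokenizer-data construction above),
    # get_allowed_tokens() returns a packed int32 bitmask row that can be
    # copied straight into the batch-wide bitmask.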
def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
allowed_tokens = self.token_enforcer.get_allowed_tokens(
self.current_tokens_prefix)
bitmask[batch_index] = allowed_tokens.allowed_tokens
    def is_terminated(self) -> bool:
        # We are considered terminated once the accepted prefix ends with
        # eos_token_id.
        return (len(self.current_tokens_prefix) > 0
                and self.current_tokens_prefix[-1]
                == self.token_enforcer.eos_token_id)
def reset(self):
self.current_tokens_prefix = []
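# Backend entry point: builds (and caches) the tokenizer data once, then
# compiles one LMFormatEnforcerGrammar per structured-output request.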
@dataclass
class LMFormatEnforcerBackend(StructuredOutputBackend):
def __post_init__(self):
self.tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
self.tokenizer, self.vocab_size)
def compile_grammar(self, request_type: StructuredOutputOptions,
grammar_spec: str) -> StructuredOutputGrammar:
character_level_parser: lmformatenforcer.CharacterLevelParser
if request_type == StructuredOutputOptions.JSON:
spec_dict = json.loads(grammar_spec)
character_level_parser = lmformatenforcer.JsonSchemaParser(
spec_dict)
elif request_type == StructuredOutputOptions.JSON_OBJECT:
character_level_parser = lmformatenforcer.JsonSchemaParser(None)
elif request_type == StructuredOutputOptions.REGEX:
character_level_parser = lmformatenforcer.RegexParser(grammar_spec)
elif request_type == StructuredOutputOptions.CHOICE:
choices = ast.literal_eval(grammar_spec)
character_level_parser = lmformatenforcer.UnionParser(
[lmformatenforcer.StringParser(choice) for choice in choices])
        else:
            raise ValueError(
                "Invalid request type for LM Format Enforcer backend "
                f"({request_type!s})")
max_rollback_tokens = (
self.vllm_config.speculative_config.num_speculative_tokens
if self.vllm_config.speculative_config is not None else 0)
if max_rollback_tokens > 0:
raise ValueError(
"LM Format Enforcer backend does not support speculative tokens"
)
token_enforcer = lmformatenforcer.TokenEnforcer(
tokenizer_data=self.tokenizer_data,
parser=character_level_parser,
)
return LMFormatEnforcerGrammar(token_enforcer)
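    # One int32 packs 32 vocab bits; initializing with -1 (all bits set)
    # marks every token allowed until fill_bitmask() narrows a row.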
def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
return torch.full(
(max_num_seqs, (self.vocab_size + 31) // 32),
-1,
dtype=torch.int32,
pin_memory=torch.cuda.is_available(),
)
def destroy(self):
pass
def validate_structured_output_request_lm_format_enforcer(
params: SamplingParams):
if params.guided_decoding is None:
return
gd_params = params.guided_decoding
if gd_params.regex:
return
elif gd_params.json:
if isinstance(gd_params.json, str):
try:
# make sure schema is valid json
json.loads(gd_params.json)
except json.JSONDecodeError as e:
raise ValueError("Invalid JSON grammar specification.") from e
else:
try:
json.dumps(gd_params.json)
except Exception as e:
raise ValueError(
f"Error serializing guided decoding jsonschema: {e}"
) from e
return
elif gd_params.choice:
return
elif gd_params.grammar:
raise ValueError("LM Format Enforcer guided decoding backend "
"does not support grammar specifications")