mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Bugfix] Fix CFGGuide and use outlines for grammars that can't convert to GBNF (#11389)
Signed-off-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
@ -174,11 +174,6 @@ def test_guided_choice_completion(sample_guided_choice, llm,
|
||||
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||
def test_guided_grammar(sample_sql_statements, llm,
|
||||
guided_decoding_backend: str):
|
||||
if guided_decoding_backend == "outlines":
|
||||
pytest.skip("Outlines backend fails in this test case with:\n"
|
||||
"AttributeError: Error in model execution: 'ParserConf' "
|
||||
"object has no attribute 'deterministic'")
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=1000,
|
||||
|
@ -3,6 +3,9 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.guided_decoding.utils import (
|
||||
convert_lark_to_gbnf, grammar_is_likely_lark,
|
||||
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -15,76 +18,6 @@ if TYPE_CHECKING:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
    """Check if JSON schema contains features unsupported by xgrammar."""

    # Numeric-range keywords that xgrammar cannot enforce.
    numeric_bounds = ("minimum", "maximum", "exclusiveMinimum",
                      "exclusiveMaximum", "multipleOf")

    def _scan(node: dict) -> bool:
        # Non-dict nodes carry no schema keywords of interest.
        if not isinstance(node, dict):
            return False

        # Regex "pattern" restrictions are unsupported.
        if "pattern" in node:
            return True

        # Numeric range constraints on integer/number types are unsupported.
        if node.get("type") in ("integer", "number"):
            if any(bound in node for bound in numeric_bounds):
                return True

        # Recurse into nested schema objects, including dicts inside lists.
        for child in node.values():
            if isinstance(child, dict) and _scan(child):
                return True
            if isinstance(child, list):
                if any(
                        isinstance(elem, dict) and _scan(elem)
                        for elem in child):
                    return True

        return False

    return _scan(schema)
|
||||
|
||||
|
||||
def has_lmf_unsupported_json_features(schema: dict) -> bool:
    """
    Check if JSON schema contains features unsupported
    by lm_format_enforcer.

    Known issues:
    - Regex patterns:
       "grade": {
            "type": "string",
            "pattern": "^[A-D]$"  # Regex pattern
        },
    """

    def _scan(node: dict) -> bool:
        # Only dict nodes can hold schema keywords.
        if not isinstance(node, dict):
            return False

        # Regex "pattern" restrictions are the known unsupported feature.
        if "pattern" in node:
            return True

        # Recurse into nested schema objects, including dicts inside lists.
        for child in node.values():
            if isinstance(child, dict):
                if _scan(child):
                    return True
            elif isinstance(child, list):
                if any(
                        isinstance(elem, dict) and _scan(elem)
                        for elem in child):
                    return True

        return False

    return _scan(schema)
|
||||
|
||||
|
||||
def maybe_backend_fallback(
|
||||
guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
|
||||
# lm-format-enforce doesn't support grammar, fallback to xgrammar
|
||||
@ -127,6 +60,20 @@ def maybe_backend_fallback(
|
||||
"Falling back to use outlines instead.")
|
||||
guided_params.backend = "outlines"
|
||||
|
||||
# xgrammar only supports GBNF grammars, so we must convert Lark.
|
||||
# We must check if the grammar is likely Lark and if that
|
||||
# grammar is convertible to GBNF
|
||||
elif (guided_params.grammar is not None
|
||||
and grammar_is_likely_lark(guided_params.grammar)):
|
||||
try:
|
||||
convert_lark_to_gbnf(guided_params.grammar)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"xgrammar does not support Lark grammars and the "
|
||||
"grammar failed to convert to GBNF. "
|
||||
"Falling back to use outlines instead.")
|
||||
guided_params.backend = "outlines"
|
||||
|
||||
if (guided_params.backend == "outlines"
|
||||
and guided_params.json_object is not None):
|
||||
# outlines doesn't support json_object, fallback to xgrammar
|
||||
|
@ -21,10 +21,11 @@ from typing import Callable, DefaultDict, Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from lark import Lark
|
||||
from outlines import grammars
|
||||
from outlines.caching import cache
|
||||
from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write
|
||||
from outlines.fsm.guide import (CFGGuide, CFGState, Generate, Guide,
|
||||
RegexGuide, Write)
|
||||
from outlines.fsm.parsing import PartialLark
|
||||
from outlines_core.fsm.json_schema import build_regex_from_schema
|
||||
from pydantic import BaseModel
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
@ -34,7 +35,9 @@ class BaseLogitsProcessor:
|
||||
|
||||
def __init__(self, guide: Guide):
|
||||
self._guide: Guide = guide
|
||||
self._fsm_state: DefaultDict[int, int] = defaultdict(int)
|
||||
# CFGState is used for the FSM state for CFGGuide
|
||||
self._fsm_state: DefaultDict[int, Union[int,
|
||||
CFGState]] = defaultdict(int)
|
||||
|
||||
def __call__(self, input_ids: List[int],
|
||||
scores: torch.Tensor) -> torch.Tensor:
|
||||
@ -54,15 +57,13 @@ class BaseLogitsProcessor:
|
||||
# On the first time this is called, we simply re-create
|
||||
# the Lark object.
|
||||
if isinstance(self._guide, CFGGuide):
|
||||
self._guide.parser = Lark(
|
||||
self._guide.parser = PartialLark(
|
||||
self._guide.cfg_string,
|
||||
parser="lalr",
|
||||
lexer="contextual",
|
||||
propagate_positions=False,
|
||||
maybe_placeholders=False,
|
||||
regex=True,
|
||||
import_paths=[grammars.GRAMMAR_PATH],
|
||||
)
|
||||
self._fsm_state[seq_id] = CFGState(
|
||||
parser_state=self._guide.parser.parse(""), prev_token=None)
|
||||
|
||||
instruction = self._guide.get_next_instruction(
|
||||
state=self._fsm_state[seq_id])
|
||||
@ -200,7 +201,8 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
|
||||
string = tokenizer.convert_tokens_to_string([token])
|
||||
|
||||
# A hack to handle missing spaces to HF's Llama tokenizers
|
||||
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
|
||||
if (type(token) is str and token.startswith(SPIECE_UNDERLINE)
|
||||
or token == "<0x20>"):
|
||||
return " " + string
|
||||
|
||||
return string
|
||||
@ -211,6 +213,9 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
|
||||
"""Sync vLLM's decoder with the outlines by returning list."""
|
||||
|
||||
def new_decoder(inp_tokens: List[int]) -> List[str]:
|
||||
if (isinstance(inp_tokens, list) and len(inp_tokens) == 1
|
||||
and isinstance(inp_tokens[0], list)):
|
||||
inp_tokens = inp_tokens[0]
|
||||
return [decoder(inp_tokens)]
|
||||
|
||||
return new_decoder
|
||||
|
@ -1,6 +1,76 @@
|
||||
import re
|
||||
|
||||
|
||||
def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
    """Check if JSON schema contains features unsupported by xgrammar."""

    # Keywords expressing numeric ranges, which xgrammar cannot enforce.
    range_keys = ("minimum", "maximum", "exclusiveMinimum",
                  "exclusiveMaximum", "multipleOf")

    def _walk(node: dict) -> bool:
        # Anything that is not a dict cannot carry schema keywords.
        if not isinstance(node, dict):
            return False

        # Regex "pattern" restrictions are unsupported.
        if "pattern" in node:
            return True

        # Numeric range constraints on integer/number types are unsupported.
        if node.get("type") in ("integer", "number") and any(
                key in node for key in range_keys):
            return True

        # Descend into every nested schema object, including those in lists.
        for value in node.values():
            if isinstance(value, dict):
                if _walk(value):
                    return True
            elif isinstance(value, list):
                for entry in value:
                    if isinstance(entry, dict) and _walk(entry):
                        return True

        return False

    return _walk(schema)
|
||||
|
||||
|
||||
def has_lmf_unsupported_json_features(schema: dict) -> bool:
    """
    Check if JSON schema contains features unsupported
    by lm_format_enforcer.

    Known issues:
    - Regex patterns:
       "grade": {
            "type": "string",
            "pattern": "^[A-D]$"  # Regex pattern
        },
    """

    def _walk(node: dict) -> bool:
        # Anything that is not a dict cannot carry schema keywords.
        if not isinstance(node, dict):
            return False

        # Regex "pattern" restrictions are the known unsupported feature.
        if "pattern" in node:
            return True

        # Descend into every nested schema object, including those in lists.
        for value in node.values():
            if isinstance(value, dict):
                if _walk(value):
                    return True
            elif isinstance(value, list):
                for entry in value:
                    if isinstance(entry, dict) and _walk(entry):
                        return True

        return False

    return _walk(schema)
|
||||
|
||||
|
||||
def grammar_is_likely_lark(grammar_str: str) -> bool:
|
||||
"""
|
||||
Check if grammar appears to use Lark syntax.
|
@ -14,8 +14,8 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from vllm.model_executor.guided_decoding.xgrammar_utils import (
|
||||
convert_lark_to_gbnf, grammar_is_likely_lark)
|
||||
from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
|
||||
grammar_is_likely_lark)
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
Reference in New Issue
Block a user