Simplify (and fix) passing of guided decoding backend options (#17008)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-04-29 20:02:23 +01:00
Committed by: GitHub
Parent: 2fa2a50bf9
Commit: a6977dbd15
17 changed files with 309 additions and 217 deletions

View File

@ -112,8 +112,8 @@ def extra_backend_options_completion(client: OpenAI, model: str):
"alan.turing@enigma.com\n")
try:
# The no-fallback option forces vLLM to use xgrammar, so when it fails
# you get a 400 with the reason why
# The guided_decoding_disable_fallback option forces vLLM to use
# xgrammar, so when it fails you get a 400 with the reason why
completion = client.chat.completions.create(
model=model,
messages=[{
@ -123,7 +123,8 @@ def extra_backend_options_completion(client: OpenAI, model: str):
extra_body={
"guided_regex": r"\w+@\w+\.com\n",
"stop": ["\n"],
"guided_decoding_backend": "xgrammar:no-fallback"
"guided_decoding_backend": "xgrammar",
"guided_decoding_disable_fallback": True,
},
)
return completion.choices[0].message.content
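
For reference, a minimal end-to-end sketch of the new request shape. The base URL, API key and model name below are placeholders for a locally running vLLM OpenAI-compatible server; only the extra_body keys come from the change above.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model name
    messages=[{
        "role": "user",
        "content": "Extract the email address from: alan.turing@enigma.com",
    }],
    extra_body={
        "guided_regex": r"\w+@\w+\.com\n",
        "stop": ["\n"],
        # Previously spelled "guided_decoding_backend": "xgrammar:no-fallback".
        "guided_decoding_backend": "xgrammar",
        "guided_decoding_disable_fallback": True,
    },
)
print(completion.choices[0].message.content)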

View File

@ -16,10 +16,11 @@ from vllm.sampling_params import GuidedDecodingParams, SamplingParams
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
GUIDED_DECODING_BACKENDS = [
"outlines",
"lm-format-enforcer",
"xgrammar:disable-any-whitespace",
"guidance:disable-any-whitespace",
# (backend, disable_any_whitespace),
("outlines", False),
("lm-format-enforcer", False),
("xgrammar", True),
("guidance", True),
]
@ -36,13 +37,17 @@ def llm():
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_regex(sample_regex, llm, guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
regex=sample_regex,
backend=guided_decoding_backend))
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
regex=sample_regex,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
@ -62,14 +67,18 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str):
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
def test_guided_json_completion(sample_json_schema, llm,
guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_json_schema,
backend=guided_decoding_backend))
guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_json_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
@ -92,14 +101,18 @@ def test_guided_json_completion(sample_json_schema, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
def test_guided_complex_json_completion(sample_complex_json_schema, llm,
guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_complex_json_schema,
backend=guided_decoding_backend))
guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_complex_json_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(prompts=[
f"Give an example JSON for an assignment grade "
f"that fits this schema: {sample_complex_json_schema}"
@ -123,14 +136,18 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
def test_guided_definition_json_completion(sample_definition_json_schema, llm,
guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_definition_json_schema,
backend=guided_decoding_backend))
guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_definition_json_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(prompts=[
f"Give an example JSON for solving 8x + 7 = -23 "
f"that fits this schema: {sample_definition_json_schema}"
@ -154,14 +171,18 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
def test_guided_enum_json_completion(sample_enum_json_schema, llm,
guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_enum_json_schema,
backend=guided_decoding_backend))
guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_enum_json_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(prompts=[
"Create a bug report JSON that fits this schema: "
f"{sample_enum_json_schema}. Make it for a high priority critical bug."
@ -195,14 +216,18 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
def test_guided_choice_completion(sample_guided_choice, llm,
guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
choice=sample_guided_choice,
backend=guided_decoding_backend))
guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
choice=sample_guided_choice,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(
prompts="The best language for type-safe systems programming is ",
sampling_params=sampling_params,
@ -221,15 +246,19 @@ def test_guided_choice_completion(sample_guided_choice, llm,
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
def test_guided_grammar(sample_sql_statements, llm,
guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
grammar=sample_sql_statements,
backend=guided_decoding_backend))
guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
grammar=sample_sql_statements,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(
prompts=("Generate a sql state that select col_1 from "
"table_1 where it is equals to 1"),
@ -300,7 +329,8 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
top_p=0.95,
guided_decoding=GuidedDecodingParams(
json=unsupported_json,
backend="xgrammar:no-fallback"))
backend="xgrammar",
disable_fallback=True))
with pytest.raises(
ValueError,
@ -312,14 +342,18 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_json_object(llm, guided_decoding_backend: str):
sampling_params = SamplingParams(temperature=1.0,
max_tokens=100,
n=2,
guided_decoding=GuidedDecodingParams(
json_object=True,
backend=guided_decoding_backend))
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
def test_guided_json_object(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=100,
n=2,
guided_decoding=GuidedDecodingParams(
json_object=True,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(
prompts=("Generate a JSON object with curly braces for a person with "
@ -337,7 +371,7 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
print(generated_text)
assert generated_text is not None
if 'disable-any-whitespace' in guided_decoding_backend:
if disable_any_whitespace:
assert "\n" not in generated_text
# Parse to verify it is valid JSON
@ -359,14 +393,18 @@ class CarDescription(BaseModel):
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
json_schema = CarDescription.model_json_schema()
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=json_schema,
backend=guided_decoding_backend))
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=json_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(
prompts="Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's",
@ -387,9 +425,10 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_number_range_json_completion(llm,
guided_decoding_backend: str):
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
def test_guided_number_range_json_completion(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sample_output_schema = {
"type": "object",
"properties": {
@ -413,8 +452,10 @@ def test_guided_number_range_json_completion(llm,
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=sample_output_schema,
backend=guided_decoding_backend),
guided_decoding=GuidedDecodingParams(
json=sample_output_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace),
)
outputs = llm.generate(
prompts=[
@ -466,8 +507,12 @@ def test_guidance_no_additional_properties(llm):
"large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
"<|im_end|>\n<|im_start|>assistant\n")
def generate_with_backend(backend):
guided_params = GuidedDecodingParams(json=schema, backend=backend)
def generate_with_backend(backend, disable_additional_properties):
guided_params = GuidedDecodingParams(
json=schema,
backend=backend,
disable_any_whitespace=True,
disable_additional_properties=disable_additional_properties)
sampling_params = SamplingParams(temperature=0,
max_tokens=256,
guided_decoding=guided_params)
@ -481,7 +526,7 @@ def test_guidance_no_additional_properties(llm):
jsonschema.validate(instance=parsed_json, schema=schema)
return parsed_json
base_generated = generate_with_backend('guidance:disable-any-whitespace')
base_generated = generate_with_backend("guidance", False)
assert "a1" in base_generated
assert "a2" in base_generated
assert "a3" in base_generated
@ -490,8 +535,7 @@ def test_guidance_no_additional_properties(llm):
assert "a5" in base_generated
assert "a6" in base_generated
generated = generate_with_backend(
'guidance:no-additional-properties,disable-any-whitespace')
generated = generate_with_backend("guidance", True)
assert "a1" in generated
assert "a2" in generated
assert "a3" in generated

View File

@ -202,12 +202,15 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
def test_guided_decoding_backend_options():
"""Test backend-specific options"""
params = GuidedDecodingParams(
backend="xgrammar:option-1,option-2,option-3")
assert params.backend_options() == ["option-1", "option-2", "option-3"]
no_fallback = GuidedDecodingParams(backend="xgrammar:option-1,no-fallback")
assert no_fallback.no_fallback()
with pytest.warns(DeprecationWarning):
guided_decoding_params = GuidedDecodingParams(
backend=
"xgrammar:no-fallback,disable-any-whitespace,no-additional-properties"
)
assert guided_decoding_params.backend == "xgrammar"
assert guided_decoding_params.disable_fallback
assert guided_decoding_params.disable_any_whitespace
assert guided_decoding_params.disable_additional_properties
def test_pickle_xgrammar_tokenizer_data():
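
The legacy colon-delimited backend strings still parse for now, but they go through the deprecated extraction path exercised by the test above. A small sketch of what that means at call time:

import warnings

from vllm.sampling_params import GuidedDecodingParams

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    params = GuidedDecodingParams(
        backend="xgrammar:no-fallback,disable-any-whitespace")

# The options are split out of the backend string...
assert params.backend == "xgrammar"
assert params.disable_fallback
assert params.disable_any_whitespace
# ...and a DeprecationWarning is emitted, as the test above asserts.
assert any(issubclass(w.category, DeprecationWarning) for w in caught)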

View File

@ -17,15 +17,12 @@ from vllm.platforms import current_platform
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
("mistralai/Ministral-8B-Instruct-2410", "xgrammar:disable-any-whitespace",
"auto"),
("mistralai/Ministral-8B-Instruct-2410", "guidance:disable-any-whitespace",
"auto"),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar:disable-any-whitespace",
"mistral"),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar:disable-any-whitespace", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral"),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto"),
#FIXME: This test is flaky on CI thus disabled
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance:disable-any-whitespace", "auto"),
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
]
PARAMS_MODELS_TOKENIZER_MODE = [
@ -73,6 +70,7 @@ def test_structured_output(
enforce_eager=enforce_eager,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend,
guided_decoding_disable_any_whitespace=True,
tokenizer_mode=tokenizer_mode)
#
@ -98,8 +96,7 @@ def test_structured_output(
generated_text = output.outputs[0].text
assert generated_text is not None
if 'disable-any-whitespace' in guided_decoding_backend:
assert "\n" not in generated_text
assert "\n" not in generated_text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=sample_json_schema)
@ -520,10 +517,11 @@ def test_structured_output_auto_mode(
def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_V1", "1")
backend = 'guidance:no-additional-properties,disable-any-whitespace'
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
max_model_len=1024,
guided_decoding_backend=backend)
guided_decoding_backend="guidance",
guided_decoding_disable_any_whitespace=True,
guided_decoding_disable_additional_properties=True)
schema = {
'type': 'object',
@ -548,7 +546,11 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
"<|im_end|>\n<|im_start|>assistant\n")
def generate_with_backend(backend):
guided_params = GuidedDecodingParams(json=schema, backend=backend)
guided_params = GuidedDecodingParams(
json=schema,
backend=backend,
disable_any_whitespace=True,
disable_additional_properties=True)
sampling_params = SamplingParams(temperature=0,
max_tokens=256,
guided_decoding=guided_params)
@ -562,8 +564,7 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
jsonschema.validate(instance=parsed_json, schema=schema)
return parsed_json
generated = generate_with_backend(
'guidance:no-additional-properties,disable-any-whitespace')
generated = generate_with_backend("guidance")
assert "a1" in generated
assert "a2" in generated
assert "a3" in generated

View File

@ -57,7 +57,8 @@ def test_unsupported_configs(monkeypatch):
with pytest.raises(NotImplementedError):
AsyncEngineArgs(
model=MODEL,
guided_decoding_backend="lm-format-enforcer:no-fallback",
guided_decoding_backend="lm-format-enforcer",
guided_decoding_disable_fallback=True,
).create_engine_config()
with pytest.raises(NotImplementedError):

View File

@ -17,12 +17,14 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
Optional, Protocol, TypeVar, Union, get_args, get_origin)
Optional, Protocol, TypeVar, Union, cast, get_args,
get_origin)
import torch
from pydantic import BaseModel, Field, PrivateAttr
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig
from typing_extensions import deprecated
import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
@ -32,7 +34,6 @@ from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
get_quantization_config)
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import CpuArchEnum, current_platform
from vllm.sampling_params import GuidedDecodingParams
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
@ -344,7 +345,7 @@ class ModelConfig:
def __init__(
self,
model: str,
task: Union[TaskOption, Literal["draft"]],
task: Literal[TaskOption, Literal["draft"]],
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
@ -701,7 +702,7 @@ class ModelConfig:
def _resolve_task(
self,
task_option: Union[TaskOption, Literal["draft"]],
task_option: Literal[TaskOption, Literal["draft"]],
) -> tuple[set[_ResolvedTask], _ResolvedTask]:
if task_option == "draft":
return {"draft"}, "draft"
@ -3185,13 +3186,36 @@ GuidedDecodingBackend = Literal[GuidedDecodingBackendV0,
class DecodingConfig:
"""Dataclass which contains the decoding strategy of the engine."""
guided_decoding_backend: GuidedDecodingBackend = \
"auto" if envs.VLLM_USE_V1 else "xgrammar"
@property
@deprecated(
"`guided_decoding_backend` is deprecated and has been renamed to "
"`backend`. This will be removed in v0.10.0. Please use the "
"`backend` argument instead.")
def guided_decoding_backend(self) -> GuidedDecodingBackend:
return self.backend
@guided_decoding_backend.setter
def guided_decoding_backend(self, value: GuidedDecodingBackend):
self.backend = value
backend: GuidedDecodingBackend = "auto" if envs.VLLM_USE_V1 else "xgrammar"
"""Which engine will be used for guided decoding (JSON schema / regex etc)
by default. With "auto", we will make opinionated choices based on request
contents and what the backend libraries currently support, so the behavior
is subject to change in each release."""
disable_fallback: bool = False
"""If `True`, vLLM will not fallback to a different backend on error."""
disable_any_whitespace: bool = False
"""If `True`, the model will not generate any whitespace during guided
decoding. This is only supported for xgrammar and guidance backends."""
disable_additional_properties: bool = False
"""If `True`, the `guidance` backend will not use `additionalProperties`
in the JSON schema. This is only supported for the `guidance` backend and
is used to better align its behaviour with `outlines` and `xgrammar`."""
reasoning_backend: Optional[str] = None
"""Select the reasoning parser depending on the model that you're using.
This is used to parse the reasoning content into OpenAI API format.
@ -3217,15 +3241,41 @@ class DecodingConfig:
return hash_str
def __post_init__(self):
backend = GuidedDecodingParams(
backend=self.guided_decoding_backend).backend_name
if ":" in self.backend:
self._extract_backend_options()
if envs.VLLM_USE_V1:
valid_guided_backends = get_args(GuidedDecodingBackendV1)
else:
valid_guided_backends = get_args(GuidedDecodingBackendV0)
if backend not in valid_guided_backends:
raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
if self.backend not in valid_guided_backends:
raise ValueError(f"Invalid backend '{self.backend}',"
f" must be one of {valid_guided_backends}")
if (self.disable_any_whitespace
and self.backend not in ("xgrammar", "guidance")):
raise ValueError("disable_any_whitespace is only supported for "
"xgrammar and guidance backends.")
if (self.disable_additional_properties and self.backend != "guidance"):
raise ValueError("disable_additional_properties is only supported "
"for the guidance backend.")
@deprecated(
"Passing guided decoding backend options inside backend in the format "
"'backend:...' is deprecated. This will be removed in v0.10.0. Please "
"use the dedicated arguments '--disable-fallback', "
"'--disable-any-whitespace' and '--disable-additional-properties' "
"instead.")
def _extract_backend_options(self):
"""Extract backend options from the backend string."""
backend, options = self.backend.split(":")
self.backend = cast(GuidedDecodingBackend, backend)
options_set = set(options.strip().split(","))
if "no-fallback" in options_set:
self.disable_fallback = True
if "disable-any-whitespace" in options_set:
self.disable_any_whitespace = True
if "no-additional-properties" in options_set:
self.disable_additional_properties = True
@dataclass
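
A rough sketch of how the reworked DecodingConfig behaves, based on the fields and checks above:

from vllm.config import DecodingConfig

# The options are now plain fields, validated in __post_init__.
cfg = DecodingConfig(backend="xgrammar", disable_any_whitespace=True)

# The old attribute name still works but is a deprecated alias for `backend`.
assert cfg.guided_decoding_backend == "xgrammar"

# Invalid combinations are rejected: disable_any_whitespace with a backend
# other than xgrammar/guidance, or disable_additional_properties with a
# backend other than guidance, both raise ValueError.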

View File

@ -18,9 +18,9 @@ from vllm import version
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
ConfigFormat, ConfigType, DecodingConfig, Device,
DeviceConfig, DistributedExecutorBackend,
GuidedDecodingBackendV1, HfOverrides,
KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ModelImpl, MultiModalConfig,
GuidedDecodingBackend, GuidedDecodingBackendV1,
HfOverrides, KVTransferConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ModelImpl, MultiModalConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
PrefixCachingHashAlgo, PromptAdapterConfig,
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
@ -317,7 +317,12 @@ class EngineArgs:
bool] = SchedulerConfig.enable_chunked_prefill
disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
guided_decoding_backend: str = DecodingConfig.guided_decoding_backend
guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend
guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback
guided_decoding_disable_any_whitespace: bool = \
DecodingConfig.disable_any_whitespace
guided_decoding_disable_additional_properties: bool = \
DecodingConfig.disable_additional_properties
logits_processor_pattern: Optional[str] = None
speculative_config: Optional[Dict[str, Any]] = None
@ -498,9 +503,17 @@ class EngineArgs:
title="DecodingConfig",
description=DecodingConfig.__doc__,
)
guided_decoding_group.add_argument("--guided-decoding-backend",
**guided_decoding_kwargs["backend"])
guided_decoding_group.add_argument(
'--guided-decoding-backend',
**guided_decoding_kwargs["guided_decoding_backend"])
"--guided-decoding-disable-fallback",
**guided_decoding_kwargs["disable_fallback"])
guided_decoding_group.add_argument(
"--guided-decoding-disable-any-whitespace",
**guided_decoding_kwargs["disable_any_whitespace"])
guided_decoding_group.add_argument(
"--guided-decoding-disable-additional-properties",
**guided_decoding_kwargs["disable_additional_properties"])
guided_decoding_group.add_argument(
"--reasoning-parser",
# This choices is a special case because it's not static
@ -1244,7 +1257,11 @@ class EngineArgs:
if self.enable_prompt_adapter else None
decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend,
backend=self.guided_decoding_backend,
disable_fallback=self.guided_decoding_disable_fallback,
disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
disable_additional_properties=\
self.guided_decoding_disable_additional_properties,
reasoning_backend=self.reasoning_parser
if self.enable_reasoning else None,
)
@ -1335,9 +1352,8 @@ class EngineArgs:
recommend_to_remove=True)
return False
# remove backend options when doing this check
if self.guided_decoding_backend.split(':')[0] \
not in get_args(GuidedDecodingBackendV1):
if self.guided_decoding_backend not in get_args(
GuidedDecodingBackendV1):
_raise_or_fallback(
feature_name=
f"--guided-decoding-backend={self.guided_decoding_backend}",

View File

@ -2091,7 +2091,7 @@ class LLMEngine:
tokenizer = self.get_tokenizer(lora_request=lora_request)
guided_decoding.backend = guided_decoding.backend or \
self.decoding_config.guided_decoding_backend
self.decoding_config.backend
if self.decoding_config.reasoning_backend is not None:
logger.debug("Building with reasoning backend %s",

View File

@ -615,9 +615,9 @@ class MQLLMEngineClient(EngineClient):
build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=await self.get_tokenizer(lora_request),
default_guided_backend=(self.decoding_config.guided_decoding_backend
default_guided_backend=(self.decoding_config.backend
if self.decoding_config
else DecodingConfig.guided_decoding_backend),
else DecodingConfig.backend),
model_config=self.model_config,
reasoning_backend=self.decoding_config.reasoning_backend,
)

View File

@ -26,8 +26,8 @@ def maybe_backend_fallback(
def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
fallback: str) -> None:
"""Change the backend to the specified fallback with a warning log,
or raise a ValueError if the `no-fallback` option is specified."""
if guided_params.no_fallback():
or raise a ValueError if the `disable_fallback` option is specified."""
if guided_params.disable_fallback:
raise ValueError(message)
logger.warning("%s Falling back to use %s instead.", message, fallback)
@ -40,7 +40,7 @@ def maybe_backend_fallback(
guided_params.backend = "xgrammar"
# lm-format-enforce doesn't support grammar, fallback to xgrammar
if guided_params.backend_name == "lm-format-enforcer":
if guided_params.backend == "lm-format-enforcer":
if guided_params.grammar is not None:
fallback_or_error(
guided_params,
@ -55,7 +55,7 @@ def maybe_backend_fallback(
"lm-format-enforcer does not support advanced JSON schema "
"features like patterns or numeric ranges.", "outlines")
if guided_params.backend_name == "xgrammar":
if guided_params.backend == "xgrammar":
from vllm.model_executor.guided_decoding.xgrammar_decoding import (
xgr_installed)
@ -87,7 +87,7 @@ def maybe_backend_fallback(
guided_params,
"xgrammar module cannot be imported successfully.", "outlines")
if (guided_params.backend_name == "outlines"
if (guided_params.backend == "outlines"
and guided_params.json_object is not None):
# outlines doesn't support json_object, fallback to guidance
fallback_or_error(guided_params,
@ -111,7 +111,7 @@ async def get_guided_decoding_logits_processor(
guided_params = maybe_backend_fallback(guided_params)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend_name == 'outlines':
if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_outlines_guided_decoding_logits_processor)
@ -122,12 +122,12 @@ async def get_guided_decoding_logits_processor(
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend_name == 'xgrammar':
if guided_params.backend == 'xgrammar':
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config, reasoner)
if guided_params.backend_name == 'guidance':
if guided_params.backend == 'guidance':
from vllm.model_executor.guided_decoding.guidance_decoding import (
get_local_guidance_guided_decoding_logits_processor)
return get_local_guidance_guided_decoding_logits_processor(
@ -152,23 +152,23 @@ def get_local_guided_decoding_logits_processor(
reasoner = reasoner_class(tokenizer)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend_name == 'outlines':
if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_params, tokenizer, reasoner)
if guided_params.backend_name == 'lm-format-enforcer':
if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend_name == 'xgrammar':
if guided_params.backend == 'xgrammar':
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config, reasoner)
if guided_params.backend_name == 'guidance':
if guided_params.backend == 'guidance':
from vllm.model_executor.guided_decoding.guidance_decoding import (
get_local_guidance_guided_decoding_logits_processor)
return get_local_guidance_guided_decoding_logits_processor(
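
With disable_fallback set, the fallback helper above raises instead of silently switching backends. A small sketch; the grammar string is illustrative.

from vllm.model_executor.guided_decoding import maybe_backend_fallback
from vllm.sampling_params import GuidedDecodingParams

params = GuidedDecodingParams(grammar='root ::= "a"',
                              backend="lm-format-enforcer",
                              disable_fallback=True)
try:
    maybe_backend_fallback(params)
except ValueError as err:
    # lm-format-enforcer does not support grammars, and fallback is disabled.
    print(f"guided decoding rejected: {err}")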

View File

@ -21,13 +21,12 @@ def get_local_guidance_guided_decoding_logits_processor(
"""
grm = ""
any_whitespace = 'disable-any-whitespace' not in \
guided_params.backend_options()
any_whitespace = not guided_params.disable_any_whitespace
if (guide_json := guided_params.json) is not None:
# Optionally set additionalProperties to False at the top-level
# By default, other backends do not allow additional top-level
# properties, so this makes guidance more similar to other backends
if 'no-additional-properties' in guided_params.backend_options():
if guided_params.disable_additional_properties:
if not isinstance(guide_json, str):
guide_json = json.dumps(guide_json)
guide_json = process_for_additional_properties(guide_json)

View File

@ -175,8 +175,7 @@ class GrammarConfig:
else:
json_str = guided_params.json
any_whitespace = 'disable-any-whitespace' not in \
guided_params.backend_options()
any_whitespace = not guided_params.disable_any_whitespace
# Check and log if model with xgrammar and whitespace have history
# of runaway generation of whitespaces.
@ -191,11 +190,10 @@ class GrammarConfig:
model_with_warn = 'Qwen'
if model_with_warn is not None and any_whitespace:
msg = (f"{model_with_warn} "
f"model detected, consider set "
f"`guided_backend=xgrammar:disable-any-whitespace` "
f"to prevent runaway generation of whitespaces.")
logger.info_once(msg)
logger.info_once(
"%s model detected, consider setting "
"`disable_any_whitespace` to prevent runaway generation "
"of whitespaces.", model_with_warn)
# Validate the schema and raise ValueError here if it is invalid.
# This is to avoid exceptions in model execution, which will crash
# the engine worker process.

View File

@ -8,6 +8,7 @@ from typing import Annotated, Any, Optional, Union
import msgspec
from pydantic import BaseModel
from typing_extensions import deprecated
from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor
@ -37,6 +38,10 @@ class GuidedDecodingParams:
json_object: Optional[bool] = None
"""These are other options that can be set"""
backend: Optional[str] = None
backend_was_auto: bool = False
disable_fallback: bool = False
disable_any_whitespace: bool = False
disable_additional_properties: bool = False
whitespace_pattern: Optional[str] = None
structural_tag: Optional[str] = None
@ -68,36 +73,6 @@ class GuidedDecodingParams:
structural_tag=structural_tag,
)
@property
def backend_name(self) -> str:
"""Return the backend name without any options.
For example if the backend is "xgrammar:no-fallback", returns "xgrammar"
"""
return (self.backend or "").split(":")[0]
def backend_options(self) -> list[str]:
"""Return the backend options as a list of strings."""
if not self.backend or ":" not in self.backend:
return []
return self.backend.split(":")[1].split(",")
def add_option(self, opt_name: str) -> None:
"""Adds an option to the backend options."""
if not self.backend:
self.backend = f":{opt_name}"
elif ":" not in self.backend:
self.backend += f":{opt_name}"
else:
options = set(self.backend_options())
options.add(opt_name)
self.backend = f"{self.backend_name}:{','.join(sorted(options))}"
def no_fallback(self) -> bool:
"""Returns True if the "no-fallback" option is supplied for the guided
decoding backend"""
return "no-fallback" in self.backend_options()
def __post_init__(self):
"""Validate that some fields are mutually exclusive."""
guide_count = sum([
@ -109,6 +84,27 @@ class GuidedDecodingParams:
"You can only use one kind of guided decoding but multiple are "
f"specified: {self.__dict__}")
if self.backend is not None and ":" in self.backend:
self._extract_backend_options()
@deprecated(
"Passing guided decoding backend options inside backend in the format "
"'backend:...' is deprecated. This will be removed in v0.10.0. Please "
"use the dedicated arguments '--disable-fallback', "
"'--disable-any-whitespace' and '--disable-additional-properties' "
"instead.")
def _extract_backend_options(self):
"""Extract backend options from the backend string."""
assert isinstance(self.backend, str)
self.backend, options = self.backend.split(":")
options_set = set(options.strip().split(","))
if "no-fallback" in options_set:
self.disable_fallback = True
if "disable-any-whitespace" in options_set:
self.disable_any_whitespace = True
if "no-additional-properties" in options_set:
self.disable_additional_properties = True
class RequestOutputKind(Enum):
# Return entire output so far in every RequestOutput
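
In short, the string-parsing helpers on GuidedDecodingParams (backend_name, backend_options(), add_option(), no_fallback()) are replaced by plain fields. A minimal sketch:

from vllm.sampling_params import GuidedDecodingParams

params = GuidedDecodingParams(
    regex=r"\w+@\w+\.com",
    backend="xgrammar",
    disable_fallback=True,
)
assert params.backend == "xgrammar"      # no options embedded in the string
assert params.disable_fallback           # was: params.no_fallback()
assert not params.disable_any_whitespace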

View File

@ -144,7 +144,7 @@ class Processor:
if not params.guided_decoding or not self.decoding_config:
return
engine_level_backend = self.decoding_config.guided_decoding_backend
engine_level_backend = self.decoding_config.backend
if params.guided_decoding.backend:
# Request-level backend selection is not supported in V1.
# The values may differ if `params` is reused and was set
@ -152,8 +152,8 @@ class Processor:
# request. We remember that it was set as a result of `auto`
# using the `_auto` option set on the backend in the params.
if (params.guided_decoding.backend != engine_level_backend
and not (engine_level_backend == "auto" and "_auto"
in params.guided_decoding.backend_options())):
and not (engine_level_backend == "auto"
and params.guided_decoding.backend_was_auto)):
raise ValueError(
"Request-level structured output backend selection is no "
"longer supported. The request specified "
@ -189,7 +189,7 @@ class Processor:
# are not supported in xgrammar. Fall back to guidance.
params.guided_decoding.backend = "guidance"
# Remember that this backend was set automatically
params.guided_decoding.add_option("_auto")
params.guided_decoding.backend_was_auto = True
def process_inputs(
self,

View File

@ -45,17 +45,17 @@ class StructuredOutputManager:
# NOTE: We only support a single backend. We do NOT support different
# backends on a per-request basis in V1 (for now, anyway...).
if self.backend is None:
backend_name = request.sampling_params.guided_decoding.backend_name
if backend_name == "xgrammar":
backend = request.sampling_params.guided_decoding.backend
if backend == "xgrammar":
from vllm.v1.structured_output.backend_xgrammar import (
XgrammarBackend)
self.backend = XgrammarBackend(self.vllm_config)
elif backend_name == "guidance":
elif backend == "guidance":
self.backend = GuidanceBackend(self.vllm_config)
else:
raise ValueError(
f"Unsupported structured output backend: {backend_name}")
f"Unsupported structured output backend: {backend}")
grammar = self.executor.submit(self._async_create_grammar, request)
request.structured_output_request.grammar = grammar # type: ignore[assignment]

View File

@ -10,7 +10,7 @@ import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
@ -65,19 +65,10 @@ class GuidanceBackend(StructuredOutputBackend):
self.vllm_config = vllm_config
self.vocab_size = vllm_config.model_config.get_vocab_size()
self.disable_any_whitespace = False
self.no_additional_properties = False
backend_options = GuidedDecodingParams(
backend=vllm_config.decoding_config.guided_decoding_backend
).backend_options()
for option in backend_options:
if option == "disable-any-whitespace":
self.disable_any_whitespace = True
elif option == "no-additional-properties":
self.no_additional_properties = True
else:
raise ValueError(
f"Unsupported option for the guidance backend: {option}")
self.disable_any_whitespace = \
vllm_config.decoding_config.disable_any_whitespace
self.disable_additional_properties = \
vllm_config.decoding_config.disable_additional_properties
tokenizer = tokenizer_group.get_lora_tokenizer(None)
self.ll_tokenizer = llguidance_hf.from_tokenizer(
@ -87,7 +78,7 @@ class GuidanceBackend(StructuredOutputBackend):
grammar_spec: str) -> StructuredOutputGrammar:
self.serialized_grammar = serialize_guidance_grammar(
request_type, grammar_spec, self.disable_any_whitespace,
self.no_additional_properties)
self.disable_additional_properties)
ll_matcher = llguidance.LLMatcher(
self.ll_tokenizer,
@ -171,11 +162,11 @@ def serialize_guidance_grammar(
request_type: StructuredOutputOptions,
grammar_spec: Union[str, dict[str, Any]],
disable_any_whitespace: bool = False,
no_additional_properties: bool = False,
disable_additional_properties: bool = False,
) -> str:
def _process_schema(grammar_spec: Union[str, dict[str, Any]], ) -> str:
if no_additional_properties:
if disable_additional_properties:
grammar_spec = process_for_additional_properties(grammar_spec)
return llguidance.LLMatcher.grammar_from_json_schema(
grammar_spec,

View File

@ -9,7 +9,7 @@ import torch
import vllm.envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import LazyLoader
@ -37,16 +37,8 @@ class XgrammarBackend(StructuredOutputBackend):
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config) # type: ignore[arg-type]
self.disable_any_whitespace = False
backend_options = GuidedDecodingParams(
backend=vllm_config.decoding_config.guided_decoding_backend
).backend_options()
for option in backend_options:
if option == "disable-any-whitespace":
self.disable_any_whitespace = True
else:
raise ValueError(
f"Unsupported option for the xgrammar backend: {option}")
self.disable_any_whitespace = \
vllm_config.decoding_config.disable_any_whitespace
tokenizer = tokenizer_group.get_lora_tokenizer(None)
self.vocab_size = vllm_config.model_config.get_vocab_size()