Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
Simplify (and fix) passing of guided decoding backend options (#17008)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
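Note: the gist of the diff below is that options formerly packed into the backend string (e.g. "xgrammar:no-fallback") become dedicated parameters. A minimal request-level sketch, using a placeholder schema and only field names introduced in this diff:

    from vllm.sampling_params import GuidedDecodingParams

    schema = {"type": "object", "properties": {"name": {"type": "string"}}}

    # Deprecated spelling: options packed into the backend string
    # (still parsed, but now emits a DeprecationWarning until removal in v0.10.0)
    old_params = GuidedDecodingParams(
        json=schema, backend="xgrammar:no-fallback,disable-any-whitespace")

    # New spelling: a plain backend name plus explicit boolean fields
    new_params = GuidedDecodingParams(
        json=schema,
        backend="xgrammar",
        disable_fallback=True,
        disable_any_whitespace=True,
    )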
@@ -112,8 +112,8 @@ def extra_backend_options_completion(client: OpenAI, model: str):
             "alan.turing@enigma.com\n")

     try:
-        # The no-fallback option forces vLLM to use xgrammar, so when it fails
-        # you get a 400 with the reason why
+        # The guided_decoding_disable_fallback option forces vLLM to use
+        # xgrammar, so when it fails you get a 400 with the reason why
         completion = client.chat.completions.create(
             model=model,
             messages=[{
@@ -123,7 +123,8 @@ def extra_backend_options_completion(client: OpenAI, model: str):
             extra_body={
                 "guided_regex": r"\w+@\w+\.com\n",
                 "stop": ["\n"],
-                "guided_decoding_backend": "xgrammar:no-fallback"
+                "guided_decoding_backend": "xgrammar",
+                "guided_decoding_disable_fallback": True,
             },
         )
         return completion.choices[0].message.content

@@ -16,10 +16,11 @@ from vllm.sampling_params import GuidedDecodingParams, SamplingParams

 MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
 GUIDED_DECODING_BACKENDS = [
-    "outlines",
-    "lm-format-enforcer",
-    "xgrammar:disable-any-whitespace",
-    "guidance:disable-any-whitespace",
+    # (backend, disable_any_whitespace),
+    ("outlines", False),
+    ("lm-format-enforcer", False),
+    ("xgrammar", True),
+    ("guidance", True),
 ]


@@ -36,13 +37,17 @@ def llm():


 @pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
-def test_guided_regex(sample_regex, llm, guided_decoding_backend: str):
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     guided_decoding=GuidedDecodingParams(
-                                         regex=sample_regex,
-                                         backend=guided_decoding_backend))
+@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
+                         GUIDED_DECODING_BACKENDS)
+def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
+                      disable_any_whitespace: bool):
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        guided_decoding=GuidedDecodingParams(
+            regex=sample_regex,
+            backend=guided_decoding_backend,
+            disable_any_whitespace=disable_any_whitespace))
     outputs = llm.generate(prompts=[
         f"Give an example IPv4 address with this regex: {sample_regex}"
     ] * 2,
@@ -62,14 +67,18 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str):


 @pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
+@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
+                         GUIDED_DECODING_BACKENDS)
 def test_guided_json_completion(sample_json_schema, llm,
-                                guided_decoding_backend: str):
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json=sample_json_schema,
-                                         backend=guided_decoding_backend))
+                                guided_decoding_backend: str,
+                                disable_any_whitespace: bool):
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(
+            json=sample_json_schema,
+            backend=guided_decoding_backend,
+            disable_any_whitespace=disable_any_whitespace))
     outputs = llm.generate(prompts=[
         f"Give an example JSON for an employee profile "
         f"that fits this schema: {sample_json_schema}"
@@ -92,14 +101,18 @@ def test_guided_json_completion(sample_json_schema, llm,


 @pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
+@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
+                         GUIDED_DECODING_BACKENDS)
 def test_guided_complex_json_completion(sample_complex_json_schema, llm,
-                                        guided_decoding_backend: str):
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json=sample_complex_json_schema,
-                                         backend=guided_decoding_backend))
+                                        guided_decoding_backend: str,
+                                        disable_any_whitespace: bool):
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(
+            json=sample_complex_json_schema,
+            backend=guided_decoding_backend,
+            disable_any_whitespace=disable_any_whitespace))
     outputs = llm.generate(prompts=[
         f"Give an example JSON for an assignment grade "
         f"that fits this schema: {sample_complex_json_schema}"
@@ -123,14 +136,18 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm,


 @pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
+@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
+                         GUIDED_DECODING_BACKENDS)
 def test_guided_definition_json_completion(sample_definition_json_schema, llm,
-                                           guided_decoding_backend: str):
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json=sample_definition_json_schema,
-                                         backend=guided_decoding_backend))
+                                           guided_decoding_backend: str,
+                                           disable_any_whitespace: bool):
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(
+            json=sample_definition_json_schema,
+            backend=guided_decoding_backend,
+            disable_any_whitespace=disable_any_whitespace))
     outputs = llm.generate(prompts=[
         f"Give an example JSON for solving 8x + 7 = -23 "
         f"that fits this schema: {sample_definition_json_schema}"
@@ -154,14 +171,18 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,


 @pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
+@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
+                         GUIDED_DECODING_BACKENDS)
 def test_guided_enum_json_completion(sample_enum_json_schema, llm,
-                                     guided_decoding_backend: str):
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json=sample_enum_json_schema,
-                                         backend=guided_decoding_backend))
+                                     guided_decoding_backend: str,
+                                     disable_any_whitespace: bool):
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(
+            json=sample_enum_json_schema,
+            backend=guided_decoding_backend,
+            disable_any_whitespace=disable_any_whitespace))
     outputs = llm.generate(prompts=[
         "Create a bug report JSON that fits this schema: "
         f"{sample_enum_json_schema}. Make it for a high priority critical bug."
@@ -195,14 +216,18 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm,


 @pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
+@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
+                         GUIDED_DECODING_BACKENDS)
 def test_guided_choice_completion(sample_guided_choice, llm,
-                                  guided_decoding_backend: str):
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     guided_decoding=GuidedDecodingParams(
-                                         choice=sample_guided_choice,
-                                         backend=guided_decoding_backend))
+                                  guided_decoding_backend: str,
+                                  disable_any_whitespace: bool):
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        guided_decoding=GuidedDecodingParams(
+            choice=sample_guided_choice,
+            backend=guided_decoding_backend,
+            disable_any_whitespace=disable_any_whitespace))
     outputs = llm.generate(
         prompts="The best language for type-safe systems programming is ",
         sampling_params=sampling_params,
@@ -221,15 +246,19 @@ def test_guided_choice_completion(sample_guided_choice, llm,


 @pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
+@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
+                         GUIDED_DECODING_BACKENDS)
 def test_guided_grammar(sample_sql_statements, llm,
-                        guided_decoding_backend: str):
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar=sample_sql_statements,
-                                         backend=guided_decoding_backend))
+                        guided_decoding_backend: str,
+                        disable_any_whitespace: bool):
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(
+            grammar=sample_sql_statements,
+            backend=guided_decoding_backend,
+            disable_any_whitespace=disable_any_whitespace))
     outputs = llm.generate(
         prompts=("Generate a sql state that select col_1 from "
                  "table_1 where it is equals to 1"),
@@ -300,7 +329,8 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
                                      top_p=0.95,
                                      guided_decoding=GuidedDecodingParams(
                                          json=unsupported_json,
-                                         backend="xgrammar:no-fallback"))
+                                         backend="xgrammar",
+                                         disable_fallback=True))

     with pytest.raises(
             ValueError,
@@ -312,14 +342,18 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):


 @pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
-def test_guided_json_object(llm, guided_decoding_backend: str):
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=100,
-                                     n=2,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json_object=True,
-                                         backend=guided_decoding_backend))
+@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
+                         GUIDED_DECODING_BACKENDS)
+def test_guided_json_object(llm, guided_decoding_backend: str,
+                            disable_any_whitespace: bool):
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=100,
+        n=2,
+        guided_decoding=GuidedDecodingParams(
+            json_object=True,
+            backend=guided_decoding_backend,
+            disable_any_whitespace=disable_any_whitespace))

     outputs = llm.generate(
         prompts=("Generate a JSON object with curly braces for a person with "
@@ -337,7 +371,7 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
         print(generated_text)
         assert generated_text is not None

-        if 'disable-any-whitespace' in guided_decoding_backend:
+        if disable_any_whitespace:
             assert "\n" not in generated_text

         # Parse to verify it is valid JSON
@@ -359,14 +393,18 @@ class CarDescription(BaseModel):


 @pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
-def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
+@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
+                         GUIDED_DECODING_BACKENDS)
+def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
+                                          disable_any_whitespace: bool):
     json_schema = CarDescription.model_json_schema()
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json=json_schema,
-                                         backend=guided_decoding_backend))
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(
+            json=json_schema,
+            backend=guided_decoding_backend,
+            disable_any_whitespace=disable_any_whitespace))
     outputs = llm.generate(
         prompts="Generate a JSON with the brand, model and car_type of"
         "the most iconic car from the 90's",
@@ -387,9 +425,10 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):


 @pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
-def test_guided_number_range_json_completion(llm,
-                                             guided_decoding_backend: str):
+@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
+                         GUIDED_DECODING_BACKENDS)
+def test_guided_number_range_json_completion(llm, guided_decoding_backend: str,
+                                             disable_any_whitespace: bool):
     sample_output_schema = {
         "type": "object",
         "properties": {
@@ -413,8 +452,10 @@ def test_guided_number_range_json_completion(llm,
     sampling_params = SamplingParams(
         temperature=1.0,
         max_tokens=1000,
-        guided_decoding=GuidedDecodingParams(json=sample_output_schema,
-                                             backend=guided_decoding_backend),
+        guided_decoding=GuidedDecodingParams(
+            json=sample_output_schema,
+            backend=guided_decoding_backend,
+            disable_any_whitespace=disable_any_whitespace),
     )
     outputs = llm.generate(
         prompts=[
@@ -466,8 +507,12 @@ def test_guidance_no_additional_properties(llm):
         "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
         "<|im_end|>\n<|im_start|>assistant\n")

-    def generate_with_backend(backend):
-        guided_params = GuidedDecodingParams(json=schema, backend=backend)
+    def generate_with_backend(backend, disable_additional_properties):
+        guided_params = GuidedDecodingParams(
+            json=schema,
+            backend=backend,
+            disable_any_whitespace=True,
+            disable_additional_properties=disable_additional_properties)
         sampling_params = SamplingParams(temperature=0,
                                          max_tokens=256,
                                          guided_decoding=guided_params)
@@ -481,7 +526,7 @@ def test_guidance_no_additional_properties(llm):
         jsonschema.validate(instance=parsed_json, schema=schema)
         return parsed_json

-    base_generated = generate_with_backend('guidance:disable-any-whitespace')
+    base_generated = generate_with_backend("guidance", False)
     assert "a1" in base_generated
     assert "a2" in base_generated
     assert "a3" in base_generated
@@ -490,8 +535,7 @@ def test_guidance_no_additional_properties(llm):
     assert "a5" in base_generated
     assert "a6" in base_generated

-    generated = generate_with_backend(
-        'guidance:no-additional-properties,disable-any-whitespace')
+    generated = generate_with_backend("guidance", True)
     assert "a1" in generated
     assert "a2" in generated
     assert "a3" in generated
@@ -202,12 +202,15 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):

 def test_guided_decoding_backend_options():
     """Test backend-specific options"""
-    params = GuidedDecodingParams(
-        backend="xgrammar:option-1,option-2,option-3")
-    assert params.backend_options() == ["option-1", "option-2", "option-3"]
-
-    no_fallback = GuidedDecodingParams(backend="xgrammar:option-1,no-fallback")
-    assert no_fallback.no_fallback()
+    with pytest.warns(DeprecationWarning):
+        guided_decoding_params = GuidedDecodingParams(
+            backend=
+            "xgrammar:no-fallback,disable-any-whitespace,no-additional-properties"
+        )
+    assert guided_decoding_params.backend == "xgrammar"
+    assert guided_decoding_params.disable_fallback
+    assert guided_decoding_params.disable_any_whitespace
+    assert guided_decoding_params.disable_additional_properties


 def test_pickle_xgrammar_tokenizer_data():
@@ -17,15 +17,12 @@ from vllm.platforms import current_platform
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams

 PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
-    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar:disable-any-whitespace",
-     "auto"),
-    ("mistralai/Ministral-8B-Instruct-2410", "guidance:disable-any-whitespace",
-     "auto"),
-    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar:disable-any-whitespace",
-     "mistral"),
-    ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar:disable-any-whitespace", "auto"),
+    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto"),
+    ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto"),
+    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral"),
+    ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto"),
     #FIXME: This test is flaky on CI thus disabled
-    #("Qwen/Qwen2.5-1.5B-Instruct", "guidance:disable-any-whitespace", "auto"),
+    #("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
 ]

 PARAMS_MODELS_TOKENIZER_MODE = [
@@ -73,6 +70,7 @@ def test_structured_output(
         enforce_eager=enforce_eager,
         max_model_len=1024,
         guided_decoding_backend=guided_decoding_backend,
+        guided_decoding_disable_any_whitespace=True,
         tokenizer_mode=tokenizer_mode)

     #
@@ -98,8 +96,7 @@ def test_structured_output(

         generated_text = output.outputs[0].text
         assert generated_text is not None
-        if 'disable-any-whitespace' in guided_decoding_backend:
-            assert "\n" not in generated_text
+        assert "\n" not in generated_text
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
         output_json = json.loads(generated_text)
         jsonschema.validate(instance=output_json, schema=sample_json_schema)
@@ -520,10 +517,11 @@ def test_structured_output_auto_mode(
 def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("VLLM_USE_V1", "1")

-    backend = 'guidance:no-additional-properties,disable-any-whitespace'
     llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
               max_model_len=1024,
-              guided_decoding_backend=backend)
+              guided_decoding_backend="guidance",
+              guided_decoding_disable_any_whitespace=True,
+              guided_decoding_disable_additional_properties=True)

     schema = {
         'type': 'object',
@@ -548,7 +546,11 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
         "<|im_end|>\n<|im_start|>assistant\n")

     def generate_with_backend(backend):
-        guided_params = GuidedDecodingParams(json=schema, backend=backend)
+        guided_params = GuidedDecodingParams(
+            json=schema,
+            backend=backend,
+            disable_any_whitespace=True,
+            disable_additional_properties=True)
         sampling_params = SamplingParams(temperature=0,
                                          max_tokens=256,
                                          guided_decoding=guided_params)
@@ -562,8 +564,7 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
         jsonschema.validate(instance=parsed_json, schema=schema)
         return parsed_json

-    generated = generate_with_backend(
-        'guidance:no-additional-properties,disable-any-whitespace')
+    generated = generate_with_backend("guidance")
     assert "a1" in generated
     assert "a2" in generated
     assert "a3" in generated
@@ -57,7 +57,8 @@ def test_unsupported_configs(monkeypatch):
     with pytest.raises(NotImplementedError):
         AsyncEngineArgs(
             model=MODEL,
-            guided_decoding_backend="lm-format-enforcer:no-fallback",
+            guided_decoding_backend="lm-format-enforcer",
+            guided_decoding_disable_fallback=True,
         ).create_engine_config()

     with pytest.raises(NotImplementedError):
@@ -17,12 +17,14 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
 from importlib.util import find_spec
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
-                    Optional, Protocol, TypeVar, Union, get_args, get_origin)
+                    Optional, Protocol, TypeVar, Union, cast, get_args,
+                    get_origin)

 import torch
 from pydantic import BaseModel, Field, PrivateAttr
 from torch.distributed import ProcessGroup, ReduceOp
 from transformers import PretrainedConfig
+from typing_extensions import deprecated

 import vllm.envs as envs
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
@@ -32,7 +34,6 @@ from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                      get_quantization_config)
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import CpuArchEnum, current_platform
-from vllm.sampling_params import GuidedDecodingParams
 from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
@@ -344,7 +345,7 @@ class ModelConfig:
     def __init__(
         self,
         model: str,
-        task: Union[TaskOption, Literal["draft"]],
+        task: Literal[TaskOption, Literal["draft"]],
         tokenizer: str,
         tokenizer_mode: str,
         trust_remote_code: bool,
@@ -701,7 +702,7 @@ class ModelConfig:

     def _resolve_task(
         self,
-        task_option: Union[TaskOption, Literal["draft"]],
+        task_option: Literal[TaskOption, Literal["draft"]],
     ) -> tuple[set[_ResolvedTask], _ResolvedTask]:
         if task_option == "draft":
             return {"draft"}, "draft"
@@ -3185,13 +3186,36 @@ GuidedDecodingBackend = Literal[GuidedDecodingBackendV0,
 class DecodingConfig:
     """Dataclass which contains the decoding strategy of the engine."""

-    guided_decoding_backend: GuidedDecodingBackend = \
-        "auto" if envs.VLLM_USE_V1 else "xgrammar"
+    @property
+    @deprecated(
+        "`guided_decoding_backend` is deprecated and has been renamed to "
+        "`backend`. This will be removed in v0.10.0. Please use the "
+        "`backend` argument instead.")
+    def guided_decoding_backend(self) -> GuidedDecodingBackend:
+        return self.backend
+
+    @guided_decoding_backend.setter
+    def guided_decoding_backend(self, value: GuidedDecodingBackend):
+        self.backend = value
+
+    backend: GuidedDecodingBackend = "auto" if envs.VLLM_USE_V1 else "xgrammar"
     """Which engine will be used for guided decoding (JSON schema / regex etc)
     by default. With "auto", we will make opinionated choices based on request
     contents and what the backend libraries currently support, so the behavior
     is subject to change in each release."""

+    disable_fallback: bool = False
+    """If `True`, vLLM will not fallback to a different backend on error."""
+
+    disable_any_whitespace: bool = False
+    """If `True`, the model will not generate any whitespace during guided
+    decoding. This is only supported for xgrammar and guidance backends."""
+
+    disable_additional_properties: bool = False
+    """If `True`, the `guidance` backend will not use `additionalProperties`
+    in the JSON schema. This is only supported for the `guidance` backend and
+    is used to better align its behaviour with `outlines` and `xgrammar`."""
+
     reasoning_backend: Optional[str] = None
     """Select the reasoning parser depending on the model that you're using.
     This is used to parse the reasoning content into OpenAI API format.
@@ -3217,15 +3241,41 @@ class DecodingConfig:
         return hash_str

     def __post_init__(self):
-        backend = GuidedDecodingParams(
-            backend=self.guided_decoding_backend).backend_name
+        if ":" in self.backend:
+            self._extract_backend_options()
+
         if envs.VLLM_USE_V1:
             valid_guided_backends = get_args(GuidedDecodingBackendV1)
         else:
             valid_guided_backends = get_args(GuidedDecodingBackendV0)
-        if backend not in valid_guided_backends:
-            raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
+        if self.backend not in valid_guided_backends:
+            raise ValueError(f"Invalid backend '{self.backend}',"
                              f" must be one of {valid_guided_backends}")
+        if (self.disable_any_whitespace
+                and self.backend not in ("xgrammar", "guidance")):
+            raise ValueError("disable_any_whitespace is only supported for "
+                             "xgrammar and guidance backends.")
+        if (self.disable_additional_properties and self.backend != "guidance"):
+            raise ValueError("disable_additional_properties is only supported "
+                             "for the guidance backend.")
+
+    @deprecated(
+        "Passing guided decoding backend options inside backend in the format "
+        "'backend:...' is deprecated. This will be removed in v0.10.0. Please "
+        "use the dedicated arguments '--disable-fallback', "
+        "'--disable-any-whitespace' and '--disable-additional-properties' "
+        "instead.")
+    def _extract_backend_options(self):
+        """Extract backend options from the backend string."""
+        backend, options = self.backend.split(":")
+        self.backend = cast(GuidedDecodingBackend, backend)
+        options_set = set(options.strip().split(","))
+        if "no-fallback" in options_set:
+            self.disable_fallback = True
+        if "disable-any-whitespace" in options_set:
+            self.disable_any_whitespace = True
+        if "no-additional-properties" in options_set:
+            self.disable_additional_properties = True


 @dataclass
@@ -18,9 +18,9 @@ from vllm import version
 from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          ConfigFormat, ConfigType, DecodingConfig, Device,
                          DeviceConfig, DistributedExecutorBackend,
-                         GuidedDecodingBackendV1, HfOverrides,
-                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
-                         ModelConfig, ModelImpl, MultiModalConfig,
+                         GuidedDecodingBackend, GuidedDecodingBackendV1,
+                         HfOverrides, KVTransferConfig, LoadConfig, LoadFormat,
+                         LoRAConfig, ModelConfig, ModelImpl, MultiModalConfig,
                          ObservabilityConfig, ParallelConfig, PoolerConfig,
                          PrefixCachingHashAlgo, PromptAdapterConfig,
                          SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
@@ -317,7 +317,12 @@ class EngineArgs:
                      bool] = SchedulerConfig.enable_chunked_prefill
     disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input

-    guided_decoding_backend: str = DecodingConfig.guided_decoding_backend
+    guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend
+    guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback
+    guided_decoding_disable_any_whitespace: bool = \
+        DecodingConfig.disable_any_whitespace
+    guided_decoding_disable_additional_properties: bool = \
+        DecodingConfig.disable_additional_properties
     logits_processor_pattern: Optional[str] = None

     speculative_config: Optional[Dict[str, Any]] = None
@@ -498,9 +503,17 @@ class EngineArgs:
         title="DecodingConfig",
         description=DecodingConfig.__doc__,
     )
+    guided_decoding_group.add_argument("--guided-decoding-backend",
+                                       **guided_decoding_kwargs["backend"])
     guided_decoding_group.add_argument(
-        '--guided-decoding-backend',
-        **guided_decoding_kwargs["guided_decoding_backend"])
+        "--guided-decoding-disable-fallback",
+        **guided_decoding_kwargs["disable_fallback"])
+    guided_decoding_group.add_argument(
+        "--guided-decoding-disable-any-whitespace",
+        **guided_decoding_kwargs["disable_any_whitespace"])
+    guided_decoding_group.add_argument(
+        "--guided-decoding-disable-additional-properties",
+        **guided_decoding_kwargs["disable_additional_properties"])
     guided_decoding_group.add_argument(
         "--reasoning-parser",
         # This choices is a special case because it's not static
@@ -1244,7 +1257,11 @@ class EngineArgs:
             if self.enable_prompt_adapter else None

         decoding_config = DecodingConfig(
-            guided_decoding_backend=self.guided_decoding_backend,
+            backend=self.guided_decoding_backend,
+            disable_fallback=self.guided_decoding_disable_fallback,
+            disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
+            disable_additional_properties=\
+                self.guided_decoding_disable_additional_properties,
             reasoning_backend=self.reasoning_parser
             if self.enable_reasoning else None,
         )
@@ -1335,9 +1352,8 @@ class EngineArgs:
                                 recommend_to_remove=True)
             return False

-        # remove backend options when doing this check
-        if self.guided_decoding_backend.split(':')[0] \
-            not in get_args(GuidedDecodingBackendV1):
+        if self.guided_decoding_backend not in get_args(
+                GuidedDecodingBackendV1):
             _raise_or_fallback(
                 feature_name=
                 f"--guided-decoding-backend={self.guided_decoding_backend}",
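Aside: the engine-level arguments added above mirror the new CLI flags. An illustrative usage sketch only (the model name is simply the one already used by the tests in this commit, not a recommendation):

    from vllm import LLM

    # Engine-wide equivalents of --guided-decoding-disable-any-whitespace and
    # --guided-decoding-disable-additional-properties
    llm = LLM(
        model="Qwen/Qwen2.5-1.5B-Instruct",
        guided_decoding_backend="guidance",
        guided_decoding_disable_any_whitespace=True,
        guided_decoding_disable_additional_properties=True,
    )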
@@ -2091,7 +2091,7 @@ class LLMEngine:

         tokenizer = self.get_tokenizer(lora_request=lora_request)
         guided_decoding.backend = guided_decoding.backend or \
-            self.decoding_config.guided_decoding_backend
+            self.decoding_config.backend

         if self.decoding_config.reasoning_backend is not None:
             logger.debug("Building with reasoning backend %s",
@@ -615,9 +615,9 @@ class MQLLMEngineClient(EngineClient):
                 build_guided_decoding_logits_processor_async(
                     sampling_params=params,
                     tokenizer=await self.get_tokenizer(lora_request),
-                    default_guided_backend=(self.decoding_config.guided_decoding_backend
+                    default_guided_backend=(self.decoding_config.backend
                                             if self.decoding_config
-                                            else DecodingConfig.guided_decoding_backend),
+                                            else DecodingConfig.backend),
                     model_config=self.model_config,
                     reasoning_backend=self.decoding_config.reasoning_backend,
                 )
@@ -26,8 +26,8 @@ def maybe_backend_fallback(
     def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
                           fallback: str) -> None:
         """Change the backend to the specified fallback with a warning log,
-        or raise a ValueError if the `no-fallback` option is specified."""
-        if guided_params.no_fallback():
+        or raise a ValueError if the `disable_fallback` option is specified."""
+        if guided_params.disable_fallback:
             raise ValueError(message)

         logger.warning("%s Falling back to use %s instead.", message, fallback)
@@ -40,7 +40,7 @@ def maybe_backend_fallback(
         guided_params.backend = "xgrammar"

     # lm-format-enforce doesn't support grammar, fallback to xgrammar
-    if guided_params.backend_name == "lm-format-enforcer":
+    if guided_params.backend == "lm-format-enforcer":
         if guided_params.grammar is not None:
             fallback_or_error(
                 guided_params,
@@ -55,7 +55,7 @@ def maybe_backend_fallback(
                 "lm-format-enforcer does not support advanced JSON schema "
                 "features like patterns or numeric ranges.", "outlines")

-    if guided_params.backend_name == "xgrammar":
+    if guided_params.backend == "xgrammar":
         from vllm.model_executor.guided_decoding.xgrammar_decoding import (
             xgr_installed)

@@ -87,7 +87,7 @@ def maybe_backend_fallback(
                 guided_params,
                 "xgrammar module cannot be imported successfully.", "outlines")

-    if (guided_params.backend_name == "outlines"
+    if (guided_params.backend == "outlines"
             and guided_params.json_object is not None):
         # outlines doesn't support json_object, fallback to guidance
         fallback_or_error(guided_params,
@@ -111,7 +111,7 @@ async def get_guided_decoding_logits_processor(
     guided_params = maybe_backend_fallback(guided_params)

     # CFG grammar not supported by LMFE, so we use outlines instead
-    if guided_params.backend_name == 'outlines':
+    if guided_params.backend == 'outlines':
         # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
         from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
             get_outlines_guided_decoding_logits_processor)
@@ -122,12 +122,12 @@ async def get_guided_decoding_logits_processor(
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
         return get_local_lm_format_enforcer_guided_decoding_logits_processor(
             guided_params, tokenizer)
-    if guided_params.backend_name == 'xgrammar':
+    if guided_params.backend == 'xgrammar':
         from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
             guided_params, tokenizer, model_config, reasoner)
-    if guided_params.backend_name == 'guidance':
+    if guided_params.backend == 'guidance':
         from vllm.model_executor.guided_decoding.guidance_decoding import (
             get_local_guidance_guided_decoding_logits_processor)
         return get_local_guidance_guided_decoding_logits_processor(
@@ -152,23 +152,23 @@ def get_local_guided_decoding_logits_processor(
         reasoner = reasoner_class(tokenizer)

     # CFG grammar not supported by LMFE, so we use outlines instead
-    if guided_params.backend_name == 'outlines':
+    if guided_params.backend == 'outlines':
         # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
         from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
             get_local_outlines_guided_decoding_logits_processor)
         return get_local_outlines_guided_decoding_logits_processor(
             guided_params, tokenizer, reasoner)
-    if guided_params.backend_name == 'lm-format-enforcer':
+    if guided_params.backend == 'lm-format-enforcer':
         from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
         return get_local_lm_format_enforcer_guided_decoding_logits_processor(
             guided_params, tokenizer)
-    if guided_params.backend_name == 'xgrammar':
+    if guided_params.backend == 'xgrammar':
         from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
             guided_params, tokenizer, model_config, reasoner)
-    if guided_params.backend_name == 'guidance':
+    if guided_params.backend == 'guidance':
         from vllm.model_executor.guided_decoding.guidance_decoding import (
             get_local_guidance_guided_decoding_logits_processor)
         return get_local_guidance_guided_decoding_logits_processor(
@@ -21,13 +21,12 @@ def get_local_guidance_guided_decoding_logits_processor(
     """

     grm = ""
-    any_whitespace = 'disable-any-whitespace' not in \
-        guided_params.backend_options()
+    any_whitespace = not guided_params.disable_any_whitespace
     if (guide_json := guided_params.json) is not None:
         # Optionally set additionalProperties to False at the top-level
         # By default, other backends do not allow additional top-level
         # properties, so this makes guidance more similar to other backends
-        if 'no-additional-properties' in guided_params.backend_options():
+        if guided_params.disable_additional_properties:
             if not isinstance(guide_json, str):
                 guide_json = json.dumps(guide_json)
             guide_json = process_for_additional_properties(guide_json)
@@ -175,8 +175,7 @@ class GrammarConfig:
             else:
                 json_str = guided_params.json

-            any_whitespace = 'disable-any-whitespace' not in \
-                guided_params.backend_options()
+            any_whitespace = not guided_params.disable_any_whitespace

             # Check and log if model with xgrammar and whitespace have history
             # of runaway generation of whitespaces.
@@ -191,11 +190,10 @@ class GrammarConfig:
                 model_with_warn = 'Qwen'

             if model_with_warn is not None and any_whitespace:
-                msg = (f"{model_with_warn} "
-                       f"model detected, consider set "
-                       f"`guided_backend=xgrammar:disable-any-whitespace` "
-                       f"to prevent runaway generation of whitespaces.")
-                logger.info_once(msg)
+                logger.info_once(
+                    "%s model detected, consider setting "
+                    "`disable_any_whitespace` to prevent runaway generation "
+                    "of whitespaces.", model_with_warn)
             # Validate the schema and raise ValueError here if it is invalid.
             # This is to avoid exceptions in model execution, which will crash
             # the engine worker process.
@@ -8,6 +8,7 @@ from typing import Annotated, Any, Optional, Union

 import msgspec
 from pydantic import BaseModel
+from typing_extensions import deprecated

 from vllm.logger import init_logger
 from vllm.logits_process import LogitsProcessor
@@ -37,6 +38,10 @@ class GuidedDecodingParams:
     json_object: Optional[bool] = None
     """These are other options that can be set"""
     backend: Optional[str] = None
+    backend_was_auto: bool = False
+    disable_fallback: bool = False
+    disable_any_whitespace: bool = False
+    disable_additional_properties: bool = False
     whitespace_pattern: Optional[str] = None
     structural_tag: Optional[str] = None

@@ -68,36 +73,6 @@ class GuidedDecodingParams:
             structural_tag=structural_tag,
         )

-    @property
-    def backend_name(self) -> str:
-        """Return the backend name without any options.
-
-        For example if the backend is "xgrammar:no-fallback", returns "xgrammar"
-        """
-        return (self.backend or "").split(":")[0]
-
-    def backend_options(self) -> list[str]:
-        """Return the backend options as a list of strings."""
-        if not self.backend or ":" not in self.backend:
-            return []
-        return self.backend.split(":")[1].split(",")
-
-    def add_option(self, opt_name: str) -> None:
-        """Adds an option to the backend options."""
-        if not self.backend:
-            self.backend = f":{opt_name}"
-        elif ":" not in self.backend:
-            self.backend += f":{opt_name}"
-        else:
-            options = set(self.backend_options())
-            options.add(opt_name)
-            self.backend = f"{self.backend_name}:{','.join(sorted(options))}"
-
-    def no_fallback(self) -> bool:
-        """Returns True if the "no-fallback" option is supplied for the guided
-        decoding backend"""
-        return "no-fallback" in self.backend_options()
-
     def __post_init__(self):
         """Validate that some fields are mutually exclusive."""
         guide_count = sum([
@@ -109,6 +84,27 @@ class GuidedDecodingParams:
                 "You can only use one kind of guided decoding but multiple are "
                 f"specified: {self.__dict__}")

+        if self.backend is not None and ":" in self.backend:
+            self._extract_backend_options()
+
+    @deprecated(
+        "Passing guided decoding backend options inside backend in the format "
+        "'backend:...' is deprecated. This will be removed in v0.10.0. Please "
+        "use the dedicated arguments '--disable-fallback', "
+        "'--disable-any-whitespace' and '--disable-additional-properties' "
+        "instead.")
+    def _extract_backend_options(self):
+        """Extract backend options from the backend string."""
+        assert isinstance(self.backend, str)
+        self.backend, options = self.backend.split(":")
+        options_set = set(options.strip().split(","))
+        if "no-fallback" in options_set:
+            self.disable_fallback = True
+        if "disable-any-whitespace" in options_set:
+            self.disable_any_whitespace = True
+        if "no-additional-properties" in options_set:
+            self.disable_additional_properties = True


 class RequestOutputKind(Enum):
     # Return entire output so far in every RequestOutput
@@ -144,7 +144,7 @@ class Processor:
         if not params.guided_decoding or not self.decoding_config:
             return

-        engine_level_backend = self.decoding_config.guided_decoding_backend
+        engine_level_backend = self.decoding_config.backend
         if params.guided_decoding.backend:
             # Request-level backend selection is not supported in V1.
             # The values may differ if `params` is reused and was set
@@ -152,8 +152,8 @@ class Processor:
             # request. We remember that it was set as a result of `auto`
             # using the `_auto` option set on the backend in the params.
             if (params.guided_decoding.backend != engine_level_backend
-                    and not (engine_level_backend == "auto" and "_auto"
-                             in params.guided_decoding.backend_options())):
+                    and not (engine_level_backend == "auto"
+                             and params.guided_decoding.backend_was_auto)):
                 raise ValueError(
                     "Request-level structured output backend selection is no "
                     "longer supported. The request specified "
@@ -189,7 +189,7 @@ class Processor:
             # are not supported in xgrammar. Fall back to guidance.
             params.guided_decoding.backend = "guidance"
             # Remember that this backend was set automatically
-            params.guided_decoding.add_option("_auto")
+            params.guided_decoding.backend_was_auto = True

     def process_inputs(
         self,
@@ -45,17 +45,17 @@ class StructuredOutputManager:
         # NOTE: We only support a single backend. We do NOT support different
         # backends on a per-request basis in V1 (for now, anyway...).
         if self.backend is None:
-            backend_name = request.sampling_params.guided_decoding.backend_name
-            if backend_name == "xgrammar":
+            backend = request.sampling_params.guided_decoding.backend
+            if backend == "xgrammar":
                 from vllm.v1.structured_output.backend_xgrammar import (
                     XgrammarBackend)

                 self.backend = XgrammarBackend(self.vllm_config)
-            elif backend_name == "guidance":
+            elif backend == "guidance":
                 self.backend = GuidanceBackend(self.vllm_config)
             else:
                 raise ValueError(
-                    f"Unsupported structured output backend: {backend_name}")
+                    f"Unsupported structured output backend: {backend}")

         grammar = self.executor.submit(self._async_create_grammar, request)
         request.structured_output_request.grammar = grammar # type: ignore[assignment]
@@ -10,7 +10,7 @@ import torch

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.sampling_params import GuidedDecodingParams, SamplingParams
+from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.utils import LazyLoader
 from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
@@ -65,19 +65,10 @@ class GuidanceBackend(StructuredOutputBackend):
         self.vllm_config = vllm_config
         self.vocab_size = vllm_config.model_config.get_vocab_size()

-        self.disable_any_whitespace = False
-        self.no_additional_properties = False
-        backend_options = GuidedDecodingParams(
-            backend=vllm_config.decoding_config.guided_decoding_backend
-        ).backend_options()
-        for option in backend_options:
-            if option == "disable-any-whitespace":
-                self.disable_any_whitespace = True
-            elif option == "no-additional-properties":
-                self.no_additional_properties = True
-            else:
-                raise ValueError(
-                    f"Unsupported option for the guidance backend: {option}")
+        self.disable_any_whitespace = \
+            vllm_config.decoding_config.disable_any_whitespace
+        self.disable_additional_properties = \
+            vllm_config.decoding_config.disable_additional_properties

         tokenizer = tokenizer_group.get_lora_tokenizer(None)
         self.ll_tokenizer = llguidance_hf.from_tokenizer(
@@ -87,7 +78,7 @@ class GuidanceBackend(StructuredOutputBackend):
                       grammar_spec: str) -> StructuredOutputGrammar:
         self.serialized_grammar = serialize_guidance_grammar(
             request_type, grammar_spec, self.disable_any_whitespace,
-            self.no_additional_properties)
+            self.disable_additional_properties)

         ll_matcher = llguidance.LLMatcher(
             self.ll_tokenizer,
@@ -171,11 +162,11 @@ def serialize_guidance_grammar(
         request_type: StructuredOutputOptions,
         grammar_spec: Union[str, dict[str, Any]],
         disable_any_whitespace: bool = False,
-        no_additional_properties: bool = False,
+        disable_additional_properties: bool = False,
 ) -> str:

     def _process_schema(grammar_spec: Union[str, dict[str, Any]], ) -> str:
-        if no_additional_properties:
+        if disable_additional_properties:
             grammar_spec = process_for_additional_properties(grammar_spec)
         return llguidance.LLMatcher.grammar_from_json_schema(
             grammar_spec,
@@ -9,7 +9,7 @@ import torch
 import vllm.envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.sampling_params import GuidedDecodingParams, SamplingParams
+from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.utils import LazyLoader
@@ -37,16 +37,8 @@ class XgrammarBackend(StructuredOutputBackend):
             scheduler_config=vllm_config.scheduler_config,
             lora_config=vllm_config.lora_config)  # type: ignore[arg-type]

-        self.disable_any_whitespace = False
-        backend_options = GuidedDecodingParams(
-            backend=vllm_config.decoding_config.guided_decoding_backend
-        ).backend_options()
-        for option in backend_options:
-            if option == "disable-any-whitespace":
-                self.disable_any_whitespace = True
-            else:
-                raise ValueError(
-                    f"Unsupported option for the xgrammar backend: {option}")
+        self.disable_any_whitespace = \
+            vllm_config.decoding_config.disable_any_whitespace

         tokenizer = tokenizer_group.get_lora_tokenizer(None)
         self.vocab_size = vllm_config.model_config.get_vocab_size()