[V0 deprecation] Guided decoding (#21347)

Signed-off-by: Reza Barazesh <rezabarazesh@meta.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Authored by Reza Barazesh on 2025-07-29 03:15:30 -07:00, committed by GitHub
parent a4528f0cac
commit 37efc63b64
29 changed files with 103 additions and 2809 deletions
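
For orientation: guided (structured-output) decoding in vLLM is requested per request by attaching a GuidedDecodingParams object to SamplingParams. This commit removes the V0-only plumbing around that API (the deprecated guided_options_request keyword, the V0 logits-processor builders, the lm-format-enforcer backend and their tests) while the request-level API itself stays. A minimal sketch of the surviving usage, assembled from the test code below; the model name and regex are illustrative, not mandated by this commit:

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Exactly one of regex=, json=, choice=, grammar= or json_object= may be set;
# backend= is optional and falls back to the engine-level default ("auto").
params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    guided_decoding=GuidedDecodingParams(regex=r"\d+\.\d+\.\d+\.\d+",
                                         backend="xgrammar"),
)

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", max_model_len=1024)
outputs = llm.generate(prompts=["Give an example IPv4 address: "],
                       sampling_params=params)
print(outputs[0].outputs[0].text)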


@@ -128,11 +128,10 @@ steps:
- tests/entrypoints/offline_mode
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_collective_rpc.py
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
- VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
- VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
- label: Entrypoints Test (API Server) # 40min

.github/CODEOWNERS

@@ -10,7 +10,6 @@
/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
/vllm/model_executor/guided_decoding @mgoin @russellb @aarnphm
/vllm/multimodal @DarkLight1337 @ywang96
/vllm/vllm_flash_attn @LucasWilkinson
/vllm/lora @jeejeelee
@@ -35,9 +34,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/distributed/test_pipeline_parallel.py @youkaichao
/tests/distributed/test_same_node.py @youkaichao
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo @aarnphm
/tests/entrypoints/llm/test_guided_generate.py @mgoin @russellb @aarnphm
/tests/kernels @tlrmchlsmth @WoosukKwon
/tests/model_executor/test_guided_processors.py @mgoin @russellb
/tests/models @DarkLight1337 @ywang96
/tests/multi_step @alexm-redhat @comaniac
/tests/multimodal @DarkLight1337 @ywang96

.github/mergify.yml

@@ -149,9 +149,6 @@ pull_request_rules:
- files=examples/offline_inference/structured_outputs.py
- files=examples/online_serving/openai_chat_completion_structured_outputs.py
- files=examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
- files~=^vllm/model_executor/guided_decoding/
- files=tests/model_executor/test_guided_processors.py
- files=tests/entrypoints/llm/test_guided_generate.py
- files~=^tests/v1/structured_output/
- files=tests/v1/entrypoints/llm/test_guided_generate.py
- files~=^vllm/v1/structured_output/


@@ -1,552 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import weakref
from enum import Enum
import jsonschema
import pytest
import regex as re
from pydantic import BaseModel
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
# Separate backends which support grammars vs ones
# which only support regex based constraints in tests.
GRAMMAR_DECODING_BACKENDS = [
# (backend, disable_any_whitespace),
("lm-format-enforcer", False),
("xgrammar", True),
("guidance", True),
]
ALL_DECODING_BACKENDS = ([("outlines", False)] + GRAMMAR_DECODING_BACKENDS)
@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(model=MODEL_NAME, max_model_len=1024, seed=0)
with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
del llm
cleanup_dist_env_and_memory()
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
ALL_DECODING_BACKENDS)
def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
regex=sample_regex,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
print(generated_text)
assert generated_text is not None
assert re.fullmatch(sample_regex, generated_text) is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
ALL_DECODING_BACKENDS)
def test_guided_json_completion(sample_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_json_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=sample_json_schema)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
ALL_DECODING_BACKENDS)
def test_guided_complex_json_completion(sample_complex_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_complex_json_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(prompts=[
f"Give an example JSON for an assignment grade "
f"that fits this schema: {sample_complex_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json,
schema=sample_complex_json_schema)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
ALL_DECODING_BACKENDS)
def test_guided_definition_json_completion(sample_definition_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_definition_json_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(prompts=[
f"Give an example JSON for solving 8x + 7 = -23 "
f"that fits this schema: {sample_definition_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json,
schema=sample_definition_json_schema)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
ALL_DECODING_BACKENDS)
def test_guided_enum_json_completion(sample_enum_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_enum_json_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(prompts=[
"Create a bug report JSON that fits this schema: "
f"{sample_enum_json_schema}. Make it for a high priority critical bug."
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json,
schema=sample_enum_json_schema)
# Additional assertions to verify enum values
assert output_json["status"] in ["active", "inactive", "pending"]
assert output_json["priority"] in ["low", "medium", "high", "critical"]
assert output_json["category"]["type"] in [
"bug", "feature", "improvement"
]
assert output_json["category"]["severity"] in [1, 2, 3, 4, 5]
for flag in output_json["flags"]:
assert flag in ["urgent", "blocked", "needs_review", "approved"]
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
ALL_DECODING_BACKENDS)
def test_guided_choice_completion(sample_guided_choice, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
choice=sample_guided_choice,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(
prompts="The best language for type-safe systems programming is ",
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
print(generated_text)
assert generated_text is not None
assert generated_text in sample_guided_choice
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GRAMMAR_DECODING_BACKENDS)
def test_guided_grammar(sample_sql_statements, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
grammar=sample_sql_statements,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(
prompts=("Generate a sql state that select col_1 from "
"table_1 where it is equals to 1"),
sampling_params=sampling_params,
use_tqdm=True,
)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
# use Lark to parse the output, and make sure it's a valid parse tree
from lark import Lark
parser = Lark(sample_sql_statements)
parser.parse(generated_text)
# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
" ", "")
assert generated_text.strip() == ground_truth
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
def test_guided_options_request_deprecation_warning(sample_regex, llm):
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
with pytest.warns(DeprecationWarning, match="guided_options_request"):
llm.generate(prompts="This should fail",
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))
@pytest.mark.skip_global_cleanup
def test_validation_against_both_guided_decoding_options(sample_regex, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(regex=sample_regex))
with pytest.raises(ValueError, match="Cannot set both"):
llm.generate(prompts="This should fail",
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))
@pytest.mark.skip_global_cleanup
def test_disable_guided_decoding_fallback(sample_regex, llm):
# see has_xgrammar_unsupported_json_features()
unsupported_json = {
"type": "object",
"properties": {
"example": {
"type": "string",
"minLength": 5 # unsupported by xgrammar
}
}
}
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
json=unsupported_json,
backend="xgrammar",
disable_fallback=True))
with pytest.raises(
ValueError,
match="xgrammar does not support advanced JSON schema features "
"like string length, item limits, or property bounds."):
llm.generate(prompts="This should fail",
sampling_params=sampling_params,
use_tqdm=True)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GRAMMAR_DECODING_BACKENDS)
def test_guided_json_object(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=100,
n=2,
guided_decoding=GuidedDecodingParams(
json_object=True,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(
prompts=("Generate a JSON object with curly braces for a person with "
"name and age fields for John Smith who is 31 years old."),
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
for i in range(2):
generated_text = output.outputs[i].text
print(generated_text)
assert generated_text is not None
if disable_any_whitespace:
assert "\n" not in generated_text
# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
# A list is not what was intended, but is still valid
# json.
assert isinstance(parsed_json, (dict, list))
class CarType(str, Enum):
sedan = "sedan"
suv = "SUV"
truck = "Truck"
coupe = "Coupe"
class CarDescription(BaseModel):
brand: str
model: str
car_type: CarType
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
ALL_DECODING_BACKENDS)
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
json_schema = CarDescription.model_json_schema()
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=json_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))
outputs = llm.generate(
prompts="Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's",
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=json_schema)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
ALL_DECODING_BACKENDS)
def test_guided_number_range_json_completion(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sample_output_schema = {
"type": "object",
"properties": {
"age": {
"type": "integer",
"minimum": 18,
"maximum": 99
},
"score": {
"type": "number",
"minimum": 0.0,
"maximum": 100.0
},
"zipcode": {
"type": "string",
"pattern": r"^\d{5}(-\d{4})?$"
},
},
"required": ["age", "score", "zipcode"],
}
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_output_schema,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace),
)
outputs = llm.generate(
prompts=[
"Create a JSON object for a user with age, score, and zipcode."
] * 2,
sampling_params=sampling_params,
use_tqdm=True,
)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=sample_output_schema)
assert 18 <= output_json["age"] <= 99
assert 0.0 <= output_json["score"] <= 100.0
assert (re.fullmatch(r"^\d{5}(-\d{4})?$", output_json["zipcode"])
is not None)
@pytest.mark.skip_global_cleanup
def test_guidance_no_additional_properties(llm):
schema = {
'type': 'object',
'properties': {
'a1': {
'type': 'string'
},
'a2': {
'type': 'string'
},
'a3': {
'type': 'string'
}
},
'required': ['a1', 'a2', 'a3'],
}
prompt = (
"<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
"helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
"large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
"<|im_end|>\n<|im_start|>assistant\n")
def generate_with_backend(backend, disable_additional_properties):
guided_params = GuidedDecodingParams(
json=schema,
backend=backend,
disable_any_whitespace=True,
disable_additional_properties=disable_additional_properties)
sampling_params = SamplingParams(temperature=0,
max_tokens=256,
guided_decoding=guided_params)
outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
assert outputs is not None
generated_text = outputs[0].outputs[0].text
assert generated_text is not None
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
jsonschema.validate(instance=parsed_json, schema=schema)
return parsed_json
base_generated = generate_with_backend("guidance", False)
assert "a1" in base_generated
assert "a2" in base_generated
assert "a3" in base_generated
# by default additional keys are generated
assert "a4" in base_generated
assert "a5" in base_generated
assert "a6" in base_generated
generated = generate_with_backend("guidance", True)
assert "a1" in generated
assert "a2" in generated
assert "a3" in generated
assert "a4" not in generated
assert "a5" not in generated
assert "a6" not in generated


@@ -4,43 +4,11 @@
import sys
from contextlib import nullcontext
import pytest
from vllm_test_utils import BlameResult, blame
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
V1 only supports xgrammar so this is irrelevant.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')
def run_normal_opt125m():
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM without guided decoding as a baseline.
llm = LLM(model="facebook/opt-125m",
enforce_eager=True,
gpu_memory_utilization=0.3)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# Destroy the LLM object and free up the GPU memory.
del llm
cleanup_dist_env_and_memory()
from vllm.sampling_params import GuidedDecodingParams
def run_normal():
@@ -67,20 +35,22 @@ def run_normal():
cleanup_dist_env_and_memory()
def run_lmfe(sample_regex):
def run_xgrammar(sample_regex):
# Create an LLM with guided decoding enabled.
llm = LLM(model="distilbert/distilgpt2",
enforce_eager=True,
guided_decoding_backend="lm-format-enforcer",
guided_decoding_backend="xgrammar",
gpu_memory_utilization=0.3)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
prompt = f"Give an example IPv4 address with this regex: {sample_regex}"
guided_decoding = GuidedDecodingParams(regex=sample_regex)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=guided_decoding)
outputs = llm.generate(
prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
prompts=[prompt] * 2,
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))
)
for output in outputs:
prompt = output.prompt
@@ -103,7 +73,7 @@ def test_lazy_outlines(sample_regex):
lambda: module_name in sys.modules) if use_blame else nullcontext()
with context as result:
run_normal()
run_lmfe(sample_regex)
run_xgrammar(sample_regex)
if use_blame:
assert isinstance(result, BlameResult)
print(f"the first import location is:\n{result.trace_stack}")
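
The hunk above is the migration this commit expects of any direct LLM.generate caller: the deprecated guided_options_request keyword (deleted from the LLM.generate entrypoint further down) gives way to a constraint carried on SamplingParams. A hedged before/after sketch; the model, regex, and prompt are illustrative stand-ins for the fixtures the tests use:

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

llm = LLM(model="distilbert/distilgpt2", enforce_eager=True,
          gpu_memory_utilization=0.3)
sample_regex = r"\d+\.\d+\.\d+\.\d+"
prompt = f"Give an example IPv4 address with this regex: {sample_regex}"

# Before (V0 era): the constraint rode along as a separate keyword, which
# already emitted a DeprecationWarning; this commit removes it entirely.
# outputs = llm.generate(prompts=[prompt] * 2,
#                        sampling_params=SamplingParams(temperature=0.8,
#                                                       top_p=0.95),
#                        guided_options_request=dict(guided_regex=sample_regex))

# After: the constraint lives on SamplingParams itself.
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    guided_decoding=GuidedDecodingParams(regex=sample_regex),
)
outputs = llm.generate(prompts=[prompt] * 2, sampling_params=sampling_params)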


@@ -488,7 +488,9 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
async def test_guided_choice_chat(client: openai.AsyncOpenAI,
sample_guided_choice):
sample_guided_choice, is_v1_server: bool):
if not is_v1_server:
pytest.skip("Guided decoding is only supported in v1 engine")
messages = [{
"role": "system",
"content": "you are a helpful assistant"
@@ -524,8 +526,10 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
async def test_guided_json_chat(client: openai.AsyncOpenAI,
sample_json_schema):
async def test_guided_json_chat(client: openai.AsyncOpenAI, sample_json_schema,
is_v1_server: bool):
if not is_v1_server:
pytest.skip("Guided decoding is only supported in v1 engine")
messages = [{
"role": "system",
@@ -568,7 +572,10 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex):
async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex,
is_v1_server: bool):
if not is_v1_server:
pytest.skip("Guided decoding is only supported in v1 engine")
messages = [{
"role": "system",
@@ -653,7 +660,10 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema):
async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema,
is_v1_server: bool):
if not is_v1_server:
pytest.skip("Tool use is only supported in v1 engine")
messages = [{
"role": "system",
"content": "you are a helpful assistant"
@@ -741,131 +751,6 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema):
assert json1["age"] != json2["age"]
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_required_tool_use(client: openai.AsyncOpenAI,
is_v1_server: bool, model_name: str):
if is_v1_server:
pytest.skip(
"tool_choice='required' requires features unsupported on V1")
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description":
"The city to find the weather for, e.g. 'Vienna'",
"default": "Vienna",
},
"country": {
"type":
"string",
"description":
"The country that the city is in, e.g. 'Austria'",
},
"unit": {
"type": "string",
"description":
"The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["country", "unit"],
},
},
},
{
"type": "function",
"function": {
"name": "get_forecast",
"description": "Get the weather forecast for a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description":
"The city to get the forecast for, e.g. 'Vienna'",
"default": "Vienna",
},
"country": {
"type":
"string",
"description":
"The country that the city is in, e.g. 'Austria'",
},
"days": {
"type":
"integer",
"description":
"Number of days to get the forecast for (1-7)",
},
"unit": {
"type": "string",
"description":
"The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["country", "days", "unit"],
},
},
},
]
messages = [
{
"role": "user",
"content": "Hi! How are you doing today?"
},
{
"role": "assistant",
"content": "I'm doing well! How can I help you?"
},
{
"role":
"user",
"content":
"Can you tell me what the current weather is in Berlin and the "\
"forecast for the next 5 days, in fahrenheit?",
},
]
# Non-streaming test
chat_completion = await client.chat.completions.create(
messages=messages,
model=model_name,
tools=tools,
tool_choice="required",
)
assert chat_completion.choices[0].message.tool_calls is not None
assert len(chat_completion.choices[0].message.tool_calls) > 0
# Streaming test
stream = await client.chat.completions.create(
messages=messages,
model=model_name,
tools=tools,
tool_choice="required",
stream=True,
)
output = []
async for chunk in stream:
if chunk.choices and chunk.choices[0].delta.tool_calls:
output.extend(chunk.choices[0].delta.tool_calls)
assert len(output) > 0
@pytest.mark.asyncio
async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI,
sample_json_schema):
@@ -948,7 +833,11 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):
@pytest.mark.asyncio
async def test_response_format_json_schema(client: openai.AsyncOpenAI):
async def test_response_format_json_schema(client: openai.AsyncOpenAI,
is_v1_server: bool):
if not is_v1_server:
pytest.skip(
"JSON schema response format is only supported in v1 engine")
prompt = 'what is 1+1? The format is "result": 2'
# Check that this prompt cannot lead to a valid JSON without json_schema
for _ in range(2):


@@ -28,7 +28,7 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# but we're not testing generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
GUIDED_DECODING_BACKENDS = ["outlines", "xgrammar", "guidance"]
@pytest.fixture(scope="module")
@@ -95,6 +95,14 @@ def server(default_server_args, request):
os.environ['VLLM_USE_V1'] = original_value
@pytest.fixture
def is_v1_server(server):
import os
# For completion tests, we assume v0 since there's no explicit v1 setup
return os.environ.get('VLLM_USE_V1', '0') == '1'
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
@@ -631,7 +639,10 @@ async def test_allowed_token_ids(client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_json_completion(client: openai.AsyncOpenAI,
guided_decoding_backend: str,
sample_json_schema):
sample_json_schema, is_v1_server: bool):
if not is_v1_server:
pytest.skip("Guided decoding is only supported in v1 engine")
completion = await client.completions.create(
model=MODEL_NAME,
prompt=f"Give an example JSON for an employee profile "
@@ -653,7 +664,10 @@ async def test_guided_json_completion(client: openai.AsyncOpenAI,
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_regex_completion(client: openai.AsyncOpenAI,
guided_decoding_backend: str,
sample_regex):
sample_regex, is_v1_server: bool):
if not is_v1_server:
pytest.skip("Guided decoding is only supported in v1 engine")
completion = await client.completions.create(
model=MODEL_NAME,
prompt=f"Give an example IPv4 address with this regex: {sample_regex}",
@@ -674,7 +688,11 @@ async def test_guided_regex_completion(client: openai.AsyncOpenAI,
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_choice_completion(client: openai.AsyncOpenAI,
guided_decoding_backend: str,
sample_guided_choice):
sample_guided_choice,
is_v1_server: bool):
if not is_v1_server:
pytest.skip("Guided decoding is only supported in v1 engine")
completion = await client.completions.create(
model=MODEL_NAME,
prompt="The best language for type-safe systems programming is ",
@@ -692,7 +710,9 @@ async def test_guided_choice_completion(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
async def test_guided_grammar(client: openai.AsyncOpenAI,
sample_sql_statements):
sample_sql_statements, is_v1_server: bool):
if not is_v1_server:
pytest.skip("Guided grammar is only supported in v1 engine")
completion = await client.completions.create(
model=MODEL_NAME,
@@ -754,7 +774,11 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
guided_decoding_backend: str,
sample_json_schema, sample_regex):
sample_json_schema, sample_regex,
is_v1_server: bool):
if not is_v1_server:
pytest.skip("Guided decoding is only supported in v1 engine")
with pytest.raises(openai.BadRequestError):
_ = await client.completions.create(
model=MODEL_NAME,


@@ -9,6 +9,11 @@ import regex as re
from ...utils import RemoteOpenAIServer
@pytest.fixture(scope="function", autouse=True)
def use_v1_only(monkeypatch):
monkeypatch.setenv('VLLM_USE_V1', '1')
@pytest.mark.asyncio
async def test_empty_prompt():
model_name = "gpt2"
@@ -37,24 +42,3 @@ async def test_out_of_vocab_token_ids():
prompt=[999999],
max_tokens=5,
temperature=0.0)
@pytest.mark.asyncio
async def test_reject_multistep_with_guided_decoding():
model_name = "gpt2"
server_args = ["--enforce-eager", "--num-scheduler-steps", "8"]
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
with pytest.raises(
openai.BadRequestError,
match=re.compile(
'.*Guided decoding .* multi-step decoding.*').pattern):
await client.completions.create(
model=model_name,
prompt="Hello",
max_tokens=5,
temperature=0.0,
extra_body={"response_format": {
"type": "json_object"
}})


@@ -1,207 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pickle
import pytest
import torch
from transformers import AutoTokenizer
from vllm.config import ModelConfig
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor,
get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.sampling_params import GuidedDecodingParams
MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
GUIDED_DECODING_BACKENDS = [
"outlines", "lm-format-enforcer", "xgrammar", "guidance"
]
GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"]
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
# Initialize the tokenizer for the model here to avoid repeated loading
@pytest.fixture(scope="module")
def zephyr_7B_tokenzer():
return AutoTokenizer.from_pretrained(MODEL_NAME)
@pytest.fixture(scope="module")
def deepseek_r1_qwen_tokenizer():
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex,
sample_json_schema):
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
regex_LP = RegexLogitsProcessor(sample_regex,
zephyr_7B_tokenzer,
reasoner=None)
json_LP = JSONLogitsProcessor(sample_json_schema,
zephyr_7B_tokenzer,
whitespace_pattern=None,
reasoner=None)
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = regex_LP([], tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_LP([], tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
@pytest.mark.asyncio
@pytest.mark.parametrize("backend", GUIDED_DECODING_BACKENDS)
@pytest.mark.parametrize("is_local", [True, False])
async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
sample_regex,
sample_json_schema,
zephyr_7B_tokenzer):
config = ModelConfig(
MODEL_NAME,
runner="generate",
seed=0,
dtype="bfloat16",
)
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
regex_lp = get_local_guided_decoding_logits_processor(
regex_request, zephyr_7B_tokenzer, config) if is_local else \
await get_guided_decoding_logits_processor(
regex_request, zephyr_7B_tokenzer, config)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
# allowed tokens at state 0
tensor = regex_lp([], tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = await get_guided_decoding_logits_processor(
json_request, zephyr_7B_tokenzer, config)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_lp([], tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
@pytest.mark.asyncio
@pytest.mark.parametrize("backend",
GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT)
@pytest.mark.parametrize("is_local", [True, False])
@pytest.mark.parametrize("reasoning_backend", ["deepseek_r1"])
async def test_guided_logits_processor_with_reasoning(
backend: str, is_local: bool, reasoning_backend: str, sample_regex,
sample_json_schema, deepseek_r1_qwen_tokenizer):
config = ModelConfig(
REASONING_MODEL_NAME,
runner="generate",
seed=0,
dtype="bfloat16",
)
token_ids = deepseek_r1_qwen_tokenizer.encode(
"<think>here is the thinking process")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
regex_lp = get_local_guided_decoding_logits_processor(regex_request,
deepseek_r1_qwen_tokenizer, config,
reasoning_backend) if is_local else \
await get_guided_decoding_logits_processor(
regex_request, deepseek_r1_qwen_tokenizer, config,
reasoning_backend)
assert regex_lp is not None
tensor = torch.rand(151664)
original_tensor = torch.clone(tensor)
tensor = regex_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert torch.allclose(tensor, original_tensor)
token_ids = deepseek_r1_qwen_tokenizer.encode(
"<think>here is the thinking process")
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = get_local_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config,
reasoning_backend) if is_local else \
await get_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
assert json_lp is not None
tensor = torch.rand(151664)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert torch.allclose(tensor, original_tensor)
# Thinking is over, so the tensor should change.
token_ids = deepseek_r1_qwen_tokenizer.encode(
"<think>here is the thinking process</think>")
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = get_local_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config,
reasoning_backend) if is_local else \
await get_guided_decoding_logits_processor(
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
assert json_lp is not None
tensor = torch.rand(151664)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, regex=sample_regex)
with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, json_object=True)
with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, choice=["a", "b"])
with pytest.raises(ValueError,
match="You can only use one kind of guided"):
GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")
def test_pickle_xgrammar_tokenizer_data():
try:
import xgrammar as xgr
except ImportError:
pytest.skip("Could not import xgrammar to run test")
from vllm.model_executor.guided_decoding.xgrammar_decoding import (
TokenizerData)
tokenizer_data = TokenizerData(
metadata=
'{"vocab_type":2,"vocab_size":151665,"add_prefix_space":false,"stop_token_ids":[151645]}',
encoded_vocab=['!', '"', '#', '$', '%'],
)
pickled = pickle.dumps(tokenizer_data)
assert pickled is not None
depickled: TokenizerData = pickle.loads(pickled)
assert depickled is not None
assert json.loads(
depickled.metadata)['vocab_type'] == xgr.VocabType.BYTE_LEVEL.value
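
The deleted file above exercised the V0 mechanism at its lowest level: a guided-decoding "logits processor" is a callable that receives the previously generated token ids and the raw logits tensor and returns a tensor of the same shape with disallowed tokens masked, and V0 appended such callables to SamplingParams.logits_processors (see the engine code removed further down). A toy sketch of that call contract, independent of any real backend and purely illustrative:

import torch

def toy_guided_logits_processor(past_token_ids: list[int],
                                logits: torch.Tensor) -> torch.Tensor:
    # A real backend (outlines, xgrammar, ...) would mask every token that
    # violates its regex/JSON/grammar state; here we only forbid token id 0
    # to show the interface shape the tests above exercise as lp([], tensor).
    logits[0] = float("-inf")
    return logits

logits = torch.rand(32000)   # vocab-sized tensor, as in the tests above
out = toy_guided_logits_processor([], logits)
assert out.shape == logits.shape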


@@ -3,13 +3,11 @@
import copy
import json
import jsonschema
import jsonschema.exceptions
import pytest
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolCall, MistralToolParser)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import MistralTokenizer
from ...utils import check_logprobs_close
@@ -274,53 +272,6 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
assert parsed_message.content is None
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("guided_backend",
["outlines", "lm-format-enforcer", "xgrammar"])
def test_mistral_guided_decoding(
monkeypatch: pytest.MonkeyPatch,
vllm_runner,
model: str,
guided_backend: str,
) -> None:
with monkeypatch.context() as m:
# Guided JSON not supported in xgrammar + V1 yet
m.setenv("VLLM_USE_V1", "0")
with vllm_runner(
model,
dtype='bfloat16',
tokenizer_mode="mistral",
guided_decoding_backend=guided_backend,
) as vllm_model:
guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA)
params = SamplingParams(max_tokens=512,
temperature=0.7,
guided_decoding=guided_decoding)
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {SAMPLE_JSON_SCHEMA}"
}]
outputs = vllm_model.llm.chat(messages, sampling_params=params)
generated_text = outputs[0].outputs[0].text
json_response = json.loads(generated_text)
assert outputs is not None
try:
jsonschema.validate(instance=json_response,
schema=SAMPLE_JSON_SCHEMA)
except jsonschema.exceptions.ValidationError:
pytest.fail("Generated response is not valid with JSON schema")
def test_mistral_function_call_nested_json():
"""Ensure that the function-name regex captures the entire outer-most
JSON block, including nested braces."""


@@ -14,9 +14,9 @@ from vllm import LLM, SamplingParams
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
"""We can run both engines for this test."""
pass
def v1(monkeypatch):
"""Only run on vLLM v1."""
monkeypatch.setenv('VLLM_USE_V1', '1')
def _generate(


@@ -56,8 +56,7 @@ def test_sampling_params_from_request_with_no_guided_decoding_backend(
@pytest.mark.parametrize("request_level_guided_decoding_backend,expected",
[("xgrammar", "xgrammar"),
("lm-format-enforcer", "lm-format-enforcer"),
[("xgrammar", "xgrammar"), ("guidance", "guidance"),
("outlines", "outlines")])
def test_sampling_params_from_request_with_guided_decoding_backend(
request_level_guided_decoding_backend: str, expected: str,


@@ -47,13 +47,6 @@ def test_unsupported_configs(monkeypatch):
},
).create_engine_config()
with pytest.raises(NotImplementedError):
AsyncEngineArgs(
model=MODEL,
guided_decoding_backend="lm-format-enforcer",
guided_decoding_disable_fallback=True,
).create_engine_config()
with pytest.raises(NotImplementedError):
AsyncEngineArgs(
model=MODEL,


@@ -34,7 +34,6 @@ ALLOWED_FILES = set([
'vllm/model_executor/models/registry.py',
'tests/test_utils.py',
'tests/tokenization/test_cached_tokenizer.py',
'tests/model_executor/test_guided_processors.py',
'vllm/distributed/utils.py',
'vllm/distributed/parallel_state.py',
'vllm/engine/multiprocessing/client.py',


@@ -3721,12 +3721,7 @@ def get_served_model_name(model: str,
return served_model_name
GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
"xgrammar", "guidance"]
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance", "outlines"]
GuidedDecodingBackend = Literal[GuidedDecodingBackendV0,
GuidedDecodingBackendV1]
GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines"]
@config
@@ -3734,7 +3729,7 @@ GuidedDecodingBackend = Literal[GuidedDecodingBackendV0,
class DecodingConfig:
"""Dataclass which contains the decoding strategy of the engine."""
backend: GuidedDecodingBackend = "auto" if envs.VLLM_USE_V1 else "xgrammar"
backend: GuidedDecodingBackend = "auto"
"""Which engine will be used for guided decoding (JSON schema / regex etc)
by default. With "auto", we will make opinionated choices based on request
contents and what the backend libraries currently support, so the behavior
@@ -3776,13 +3771,6 @@ class DecodingConfig:
return hash_str
def __post_init__(self):
if envs.VLLM_USE_V1:
valid_guided_backends = get_args(GuidedDecodingBackendV1)
else:
valid_guided_backends = get_args(GuidedDecodingBackendV0)
if self.backend not in valid_guided_backends:
raise ValueError(f"Invalid backend '{self.backend}',"
f" must be one of {valid_guided_backends}")
if (self.disable_any_whitespace
and self.backend not in ("xgrammar", "guidance")):
raise ValueError("disable_any_whitespace is only supported for "
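
With the V0/V1 split removed there is a single backend literal and a single default. A short sketch of how DecodingConfig behaves after this hunk, assuming only the fields visible here (backend, disable_any_whitespace) and the validation kept in __post_init__:

from vllm.config import DecodingConfig

cfg = DecodingConfig()
assert cfg.backend == "auto"   # previously "xgrammar" when VLLM_USE_V1=0

# disable_any_whitespace is still restricted to xgrammar/guidance.
DecodingConfig(backend="guidance", disable_any_whitespace=True)   # fine
try:
    DecodingConfig(backend="outlines", disable_any_whitespace=True)
except ValueError as err:
    print(err)   # "disable_any_whitespace is only supported for ..."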


@@ -25,14 +25,14 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
ConfigFormat, ConfigType, ConvertOption,
DecodingConfig, DetailedTraceModules, Device,
DeviceConfig, DistributedExecutorBackend,
GuidedDecodingBackend, GuidedDecodingBackendV1,
HfOverrides, KVEventsConfig, KVTransferConfig,
LoadConfig, LogprobsMode, LoRAConfig, ModelConfig,
ModelDType, ModelImpl, MultiModalConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
PrefixCachingHashAlgo, RunnerOption, SchedulerConfig,
SchedulerPolicy, SpeculativeConfig, TaskOption,
TokenizerMode, VllmConfig, get_attr_docs, get_field)
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
KVTransferConfig, LoadConfig, LogprobsMode,
LoRAConfig, ModelConfig, ModelDType, ModelImpl,
MultiModalConfig, ObservabilityConfig, ParallelConfig,
PoolerConfig, PrefixCachingHashAlgo, RunnerOption,
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
get_field)
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
@@ -1343,14 +1343,6 @@ class EngineArgs:
recommend_to_remove=True)
return False
if self.guided_decoding_backend not in get_args(
GuidedDecodingBackendV1):
_raise_or_fallback(
feature_name=
f"--guided-decoding-backend={self.guided_decoding_backend}",
recommend_to_remove=False)
return False
# Need at least Ampere for now (FA support required).
# Skip this check if we are running on a non-GPU platform,
# or if the device capability is not available
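
On the EngineArgs side, the separate check against GuidedDecodingBackendV1 disappears because the single literal already constrains the choices. A sketch of the surviving knob, reusing the AsyncEngineArgs pattern from the arg_utils test hunk above; the model name is illustrative, and the final attribute access assumes the returned VllmConfig exposes decoding_config:

from vllm.engine.arg_utils import AsyncEngineArgs

# "auto", "xgrammar", "guidance" and "outlines" remain valid choices;
# "lm-format-enforcer" is no longer part of GuidedDecodingBackend.
engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                              guided_decoding_backend="xgrammar")
vllm_config = engine_args.create_engine_config()
print(vllm_config.decoding_config.backend)   # -> "xgrammar"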


@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import copy
import time
import weakref
from functools import partial
@@ -24,8 +23,6 @@ from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
@@ -469,19 +466,6 @@ class _AsyncLLMEngine(LLMEngine):
tokenization_kwargs=tokenization_kwargs,
)
if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
# Guided decoding has an async implementation for building logits
# processors in a separate threadpool.
# We want to invoke that here instead of using the blocking
# implementation in the LLMEngine
params = await build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=await self.get_tokenizer_async(lora_request),
default_guided_backend=self.decoding_config.backend,
reasoning_backend=self.decoding_config.reasoning_backend,
model_config=self.model_config)
self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
@@ -503,48 +487,6 @@ class _AsyncLLMEngine(LLMEngine):
raise NotImplementedError
async def build_guided_decoding_logits_processor_async(
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
default_guided_backend: str, reasoning_backend: Optional[str],
model_config: ModelConfig) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. Modifies sampling params in-place and returns
the modified sampling params."""
if sampling_params.guided_decoding is None:
return sampling_params
# Defensively copy sampling params since guided decoding logits
# processors can have different state for each request
sampling_params = copy.copy(sampling_params)
guided_decoding = sampling_params.guided_decoding
logger.debug(
"Building guided decoding logits processor. "
"guided_decoding: %s%s", guided_decoding,
f", reasoning_backend: {reasoning_backend}"
if reasoning_backend is not None else "")
guided_decoding.backend = guided_decoding.backend or default_guided_backend
processor = await get_guided_decoding_logits_processor(
guided_params=guided_decoding,
tokenizer=tokenizer,
reasoning_backend=reasoning_backend,
model_config=model_config)
if processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(processor)
# Unset guided decoding params after constructing the lp from them
sampling_params.guided_decoding = None
return sampling_params
class AsyncLLMEngine(EngineClient):
"""An asynchronous wrapper for [`LLMEngine`][vllm.LLMEngine].


@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import time
from collections import Counter as collectionsCounter
from collections import deque
@@ -36,8 +35,6 @@ from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.processing import EncDecMultiModalProcessor
@@ -686,11 +683,10 @@ class LLMEngine:
"Priority scheduling is not enabled.")
if isinstance(params, SamplingParams) \
and (params.guided_decoding or params.logits_processors) \
and params.logits_processors \
and self.scheduler_config.num_scheduler_steps > 1:
raise ValueError(
"Guided decoding and logits processors are not supported "
"in multi-step decoding")
"Logits processors are not supported in multi-step decoding")
if arrival_time is None:
arrival_time = time.time()
@@ -1983,43 +1979,13 @@ def _build_logits_processors(
def _build_logits_processors(
self, sampling_params: SamplingParams,
lora_request: Optional[LoRARequest]) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. Returns the modified sampling params."""
"""Constructs logits processors based on the logits_bias, and
allowed_token_ids fields in sampling_params. Deletes those fields and
adds the constructed logits processors to the logits_processors field.
Returns the modified sampling params."""
logits_processors = []
if sampling_params.guided_decoding is not None:
# Defensively copy sampling params since guided decoding logits
# processors can have different state for each request
sampling_params = copy.copy(sampling_params)
guided_decoding = sampling_params.guided_decoding
logger.debug(
"Building guided decoding logits processor in "
"LLMEngine. Params: %s", guided_decoding)
tokenizer = self.get_tokenizer(lora_request=lora_request)
guided_decoding.backend = guided_decoding.backend or \
self.decoding_config.backend
if self.decoding_config.reasoning_backend:
logger.debug("Building with reasoning backend %s",
self.decoding_config.reasoning_backend)
processor = get_local_guided_decoding_logits_processor(
guided_params=guided_decoding,
tokenizer=tokenizer,
model_config=self.model_config,
reasoning_backend=self.decoding_config.reasoning_backend,
)
if processor:
logits_processors.append(processor)
# Unset so this doesn't get passed down to the model
sampling_params.guided_decoding = None
if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
tokenizer = self.get_tokenizer(lora_request=lora_request)


@@ -20,8 +20,6 @@ from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.async_llm_engine import (
build_guided_decoding_logits_processor_async)
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
@@ -537,22 +535,6 @@ class MQLLMEngineClient(EngineClient):
if request_id in self.output_queues:
raise ValueError(f"Request {request_id} already exists")
# Constructing guided decoding logits processors is expensive, so we do
# it here to avoid contending with cpu resources and the GIL on the
# backend process.
if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
params = await \
build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=await self.get_tokenizer(lora_request),
default_guided_backend=(self.decoding_config.backend
if self.decoding_config
else DecodingConfig.backend),
model_config=self.model_config,
reasoning_backend=self.decoding_config.reasoning_backend,
)
# 1) Create output queue for this requests.
queue: asyncio.Queue[Union[RequestOutput,
BaseException]] = asyncio.Queue()


@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import warnings
from collections.abc import Sequence
from contextlib import contextmanager
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
@@ -40,15 +39,13 @@ from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest, LLMGuidedOptions)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
PoolingRequestOutput, RequestOutput,
ScoringRequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
SamplingParams)
from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
@@ -330,8 +327,6 @@ class LLM:
*,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
...
@@ -345,8 +340,6 @@
prompt_token_ids: Optional[list[int]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
...
@@ -360,8 +353,6 @@
prompt_token_ids: Optional[list[list[int]]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
...
@@ -376,8 +367,6 @@
prompt_token_ids: list[int],
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
...
@@ -392,8 +381,6 @@
prompt_token_ids: list[list[int]],
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
...
@@ -406,8 +393,6 @@
prompt_token_ids: Union[list[int], list[list[int]]],
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
...
@@ -425,8 +410,6 @@
prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
priority: Optional[list[int]] = None,
) -> list[RequestOutput]:
"""Generates the completions for the input prompts.
@@ -478,14 +461,6 @@
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)
if isinstance(guided_options_request, dict):
if len(guided_options_request) > 1:
raise ValueError(
"You can only use one guided decoding but multiple is "
f"specified: {guided_options_request}")
guided_options_request = GuidedDecodingRequest(
**guided_options_request)
if sampling_params is None:
# Use default sampling params.
sampling_params = self.get_default_sampling_params()
@@ -507,7 +482,6 @@
params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
guided_options=guided_options_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
)
@@ -1582,17 +1556,8 @@
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
tokenization_kwargs: Optional[dict[str, Any]] = None,
guided_options: Optional[GuidedDecodingRequest] = None,
priority: Optional[list[int]] = None,
) -> None:
if guided_options is not None:
warnings.warn(
"guided_options_request is deprecated, use "
"SamplingParams.guided_decoding instead",
DeprecationWarning,
stacklevel=2,
)
if isinstance(prompts, (str, dict)):
# Convert a single prompt to a list.
prompts = [prompts]
@@ -1608,8 +1573,6 @@
for sp in params if isinstance(params, Sequence) else (params, ):
if isinstance(sp, SamplingParams):
self._add_guided_params(sp, guided_options)
# We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY
@@ -1647,29 +1610,6 @@
priority=priority,
)
def _add_guided_params(
self,
params: SamplingParams,
guided_options: Optional[GuidedDecodingRequest] = None):
if guided_options is None:
return params
if params.guided_decoding is not None:
raise ValueError("Cannot set both guided_options_request and "
"params.guided_decoding.")
params.guided_decoding = GuidedDecodingParams(
json=guided_options.guided_json,
regex=guided_options.guided_regex,
choice=guided_options.guided_choice,
grammar=guided_options.guided_grammar,
json_object=guided_options.guided_json_object,
backend=guided_options.guided_decoding_backend,
whitespace_pattern=guided_options.guided_whitespace_pattern,
structural_tag=guided_options.structural_tag,
)
return params
def _run_engine(
self,
*,


@@ -1,192 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import TYPE_CHECKING
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
from vllm.reasoning import ReasoningParserManager
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.logits_process import LogitsProcessor
from vllm.sampling_params import GuidedDecodingParams
logger = init_logger(__name__)
def maybe_backend_fallback(
guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
fallback: str) -> None:
"""Change the backend to the specified fallback with a warning log,
or raise a ValueError if the `disable_fallback` option is specified."""
if guided_params.disable_fallback:
raise ValueError(message)
logger.warning("%s Falling back to use %s instead.", message, fallback)
guided_params.backend = fallback
# `auto` was added for V1 to explicitly declare a mode that has fallbacks
# in place. If that is specified with V0, treat it as `xgrammar`, as we have
# fallbacks enabled for that and it is the V0 default.
if guided_params.backend == "auto":
guided_params.backend = "xgrammar"
# lm-format-enforcer doesn't support grammar, fall back to xgrammar
if guided_params.backend == "lm-format-enforcer":
if guided_params.grammar is not None:
fallback_or_error(
guided_params,
"lm-format-enforcer does not support grammar guided decoding.",
"xgrammar")
# lm-format-enforcer doesn't support some JSON schema features
elif (guided_params.json is not None
and has_lmf_unsupported_json_features(guided_params.json)):
fallback_or_error(
guided_params,
"lm-format-enforcer does not support advanced JSON schema "
"features like patterns or numeric ranges.", "outlines")
if guided_params.backend == "xgrammar":
from vllm.model_executor.guided_decoding.xgrammar_decoding import (
xgr_installed)
# xgrammar doesn't support some JSON schema features
if (guided_params.json is not None and
has_xgrammar_unsupported_json_features(guided_params.json)):
fallback_or_error(
guided_params,
"xgrammar does not support advanced JSON schema features like "
"string length, item limits, or property bounds.", "outlines")
# xgrammar only supports GBNF grammars, so we must convert Lark.
# We must check if the grammar is likely Lark and if that
# grammar is convertible to GBNF
elif (guided_params.grammar is not None
and grammar_is_likely_lark(guided_params.grammar)):
try:
convert_lark_to_gbnf(guided_params.grammar)
except Exception:
fallback_or_error(
guided_params,
"xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF.", "guidance")
# If the xgrammar module cannot be imported successfully,
# we should still allow users to use guided decoding with a fallback.
elif not xgr_installed:
fallback_or_error(
guided_params,
"xgrammar module cannot be imported successfully.", "guidance")
if guided_params.backend == "outlines":
if guided_params.json_object is not None:
# outlines doesn't support json_object, fall back to guidance
fallback_or_error(guided_params,
"outlines does not support json_object.",
"guidance")
elif guided_params.grammar is not None:
# outlines grammar support has been removed, so fall back to guidance
# if it is a Lark-based grammar and to xgrammar otherwise
if grammar_is_likely_lark(guided_params.grammar):
fallback_or_error(guided_params,
"outlines no longer supports grammars.",
"guidance")
else:
# The grammar is likely already GBNF format.
fallback_or_error(guided_params,
"outlines no longer supports grammars.",
"xgrammar")
return guided_params
async def get_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig,
reasoning_backend: str | None = None) -> LogitsProcessor | None:
reasoner = None
if reasoning_backend:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoner = reasoner_class(tokenizer)
guided_params = maybe_backend_fallback(guided_params)
if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
guided_params, tokenizer, reasoner)
if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'xgrammar':
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config, reasoner)
if guided_params.backend == 'guidance':
from vllm.model_executor.guided_decoding.guidance_decoding import (
get_local_guidance_guided_decoding_logits_processor)
return get_local_guidance_guided_decoding_logits_processor(
guided_params, tokenizer)
raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines', 'lm-format-enforcer', 'xgrammar', 'guidance'"
)
def get_local_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig,
reasoning_backend: str | None = None) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)
reasoner = None
if reasoning_backend:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoner = reasoner_class(tokenizer)
if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_params, tokenizer, reasoner)
if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'xgrammar':
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config, reasoner)
if guided_params.backend == 'guidance':
from vllm.model_executor.guided_decoding.guidance_decoding import (
get_local_guidance_guided_decoding_logits_processor)
return get_local_guidance_guided_decoding_logits_processor(
guided_params, tokenizer)
raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines', 'lm-format-enforcer', 'xgrammar', 'guidance'"
)
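A small sketch of how the fallback logic above behaved, assuming a build that still ships this package and has xgrammar installed; on V0, "auto" was rewritten to the old default backend:

from vllm.model_executor.guided_decoding import maybe_backend_fallback
from vllm.sampling_params import GuidedDecodingParams

params = maybe_backend_fallback(
    GuidedDecodingParams(json={"type": "object"}, backend="auto"))
print(params.backend)  # "xgrammar" when the xgrammar module imports cleanly

# With disable_fallback=True an unsupported combination raises instead of
# silently switching: lm-format-enforcer has no grammar support, so passing
# this request to maybe_backend_fallback would raise ValueError.
strict = GuidedDecodingParams(grammar='root ::= "a"',
                              backend="lm-format-enforcer",
                              disable_fallback=True)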

View File

@ -1,63 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import llguidance
from regex import escape as regex_escape
from transformers import PreTrainedTokenizerBase
from vllm.model_executor.guided_decoding.guidance_logits_processors import (
GuidanceLogitsProcessor)
from vllm.sampling_params import GuidedDecodingParams
from vllm.v1.structured_output.backend_guidance import (
process_for_additional_properties)
def get_local_guidance_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase) -> GuidanceLogitsProcessor:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
"""
grm = ""
any_whitespace = not guided_params.disable_any_whitespace
if (guide_json := guided_params.json) is not None:
# Optionally set additionalProperties to False at the top-level
# By default, other backends do not allow additional top-level
# properties, so this makes guidance more similar to other backends
if guided_params.disable_additional_properties:
if not isinstance(guide_json, str):
guide_json = json.dumps(guide_json)
guide_json = process_for_additional_properties(guide_json)
grm = llguidance.LLMatcher.grammar_from_json_schema(
guide_json,
overrides={"whitespace_pattern": guided_params.whitespace_pattern},
defaults={
"whitespace_flexible": any_whitespace,
})
elif guided_params.json_object:
grm = llguidance.LLMatcher.grammar_from_json_schema(
'{"type": "object"}',
overrides={"whitespace_pattern": guided_params.whitespace_pattern},
defaults={
"whitespace_flexible": any_whitespace,
})
elif guided_params.regex:
grm = llguidance.grammar_from("regex", guided_params.regex)
elif guided_params.choice:
# choice just uses regex
choices = (regex_escape(str(choice))
for choice in guided_params.choice)
choices_regex = "(" + "|".join(choices) + ")"
grm = llguidance.grammar_from("regex", choices_regex)
elif guided_params.grammar:
# this supports Lark and GBNF
grm = llguidance.grammar_from("grammar", guided_params.grammar)
if grm:
return GuidanceLogitsProcessor(grm, tokenizer)
raise ValueError("Unknown guided decoding mode")
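A brief usage sketch for the helper above, assuming llguidance is installed; the tokenizer name is a placeholder:

from transformers import AutoTokenizer
from vllm.model_executor.guided_decoding.guidance_decoding import (
    get_local_guidance_guided_decoding_logits_processor)
from vllm.sampling_params import GuidedDecodingParams

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
params = GuidedDecodingParams(choice=["yes", "no"])
processor = get_local_guidance_guided_decoding_logits_processor(params, tokenizer)
# The choice list is compiled to the regex "(yes|no)" and wrapped in a
# GuidanceLogitsProcessor that masks disallowed tokens at each decode step.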

View File

@ -1,104 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import os
from typing import Any
import llguidance
import llguidance.hf
import llguidance.torch
import torch
from transformers import PreTrainedTokenizerBase
from vllm.logger import init_logger
logger = init_logger(__name__)
class GuidanceLogitsProcessor:
"""Base Guidance Logits Processor"""
cached_tokenizers: dict[str, Any] = {}
def __init__(
self,
grammar: str,
tokenizer: PreTrainedTokenizerBase,
) -> None:
"""Base Guidance Logits Processor
Args:
grammar (str): grammar to guide the generation
tokenizer (PreTrainedTokenizerBase): the model's tokenizer
"""
self.grammar = grammar
self.tokenizer = tokenizer
self.tokenizer_name = tokenizer.name_or_path
self.ll_tokenizer = None
self.ll_matcher = None
self.bitmask = None
self.new_sampling = False
self.initialized = False
def clone(self) -> "GuidanceLogitsProcessor":
cloned = copy.copy(self)
if self.initialized:
cloned.ll_matcher = llguidance.LLMatcher(
self.ll_tokenizer, # type: ignore[assignment]
self.grammar,
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
)
self.bitmask = llguidance.torch.allocate_token_bitmask(
1, self.ll_tokenizer.vocab_size) # type: ignore[attr-defined]
return cloned
def _initialize(self):
if self.initialized:
return
ll_tokenizer = self.cached_tokenizers.get(self.tokenizer.name_or_path,
None)
if ll_tokenizer is None:
ll_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)
self.cached_tokenizers[self.tokenizer.name_or_path] = ll_tokenizer
self.ll_tokenizer = ll_tokenizer
self.ll_matcher = llguidance.LLMatcher(
self.ll_tokenizer,
self.grammar,
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
)
# create reusable bitmask
self.bitmask = llguidance.torch.allocate_token_bitmask(
1, self.ll_tokenizer.vocab_size) # type: ignore[attr-defined]
self.initialized = True
def __call__(
self,
input_ids: list[int],
scores: torch.Tensor,
) -> torch.Tensor:
# Initialize the llguidance objects lazily here to avoid having to
# pickle ll_tokenizer and ll_matcher.
self._initialize()
if self.new_sampling and len(input_ids) > 0:
self.ll_matcher.consume_token( # type: ignore[attr-defined]
input_ids[-1])
err = self.ll_matcher.get_error() # type: ignore[attr-defined]
if err:
logger.warning("Error in LLMatcher: %s", err)
llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask,
0)
llguidance.torch.apply_token_bitmask_inplace(
scores,
self.bitmask.to(scores.device)) # type: ignore[attr-defined]
self.new_sampling = True
return scores
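A construction-only sketch of the processor above, assuming llguidance and a placeholder Hugging Face tokenizer; the heavy llguidance objects are built lazily on the first call so the instance stays picklable:

import llguidance
from transformers import AutoTokenizer
from vllm.model_executor.guided_decoding.guidance_logits_processors import (
    GuidanceLogitsProcessor)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
grammar = llguidance.grammar_from("regex", r"[0-9]+")
processor = GuidanceLogitsProcessor(grammar, tokenizer)

# Before the first __call__ nothing is initialized, so clone() is just a
# shallow copy; once initialized, clone() also builds a fresh LLMatcher.
fresh = processor.clone()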

View File

@ -1,41 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional, TypedDict, Union
# These classes are deprecated, see SamplingParams
class LLMGuidedOptions(TypedDict, total=False):
guided_json: Union[dict, str]
guided_regex: str
guided_choice: list[str]
guided_grammar: str
guided_decoding_backend: str
guided_whitespace_pattern: str
guided_json_object: bool
@dataclass
class GuidedDecodingRequest:
"""One of the fields will be used to retrieve the logit processor."""
guided_json: Optional[Union[dict, str]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[list[str]] = None
guided_grammar: Optional[str] = None
guided_decoding_backend: Optional[str] = None
guided_whitespace_pattern: Optional[str] = None
guided_json_object: Optional[bool] = None
structural_tag: Optional[str] = None
def __post_init__(self):
"""Validate that some fields are mutually exclusive."""
guide_count = sum(x is not None
for x in (self.guided_json, self.guided_regex,
self.guided_choice, self.guided_grammar,
self.guided_json_object,
self.structural_tag))
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding but multiple are "
f"specified: {self.__dict__}")

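A short sketch of the mutual-exclusion check above (pure dataclass logic; the module path is assumed from the file layout of this change):

from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest)

# One guide at a time is fine.
req = GuidedDecodingRequest(guided_regex=r"\d{4}-\d{2}-\d{2}")

# Mixing kinds fails fast in __post_init__.
try:
    GuidedDecodingRequest(guided_regex=r"\d+", guided_choice=["a", "b"])
except ValueError as exc:
    print(exc)  # "You can only use one kind of guided decoding but multiple ..."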
View File

@ -1,67 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import lru_cache
from json import loads as json_loads
from typing import Optional, Union
from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
RegexParser, StringParser,
TokenEnforcerTokenizerData, UnionParser)
from lmformatenforcer.integrations.vllm import (
build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
from transformers import PreTrainedTokenizerBase
from vllm.logits_process import LogitsProcessor
from vllm.sampling_params import GuidedDecodingParams
def get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
The token enforcer tokenizer data is cached per tokenizer (via lru_cache),
so repeated requests with the same tokenizer avoid rebuilding it.
"""
tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer)
character_level_parser: CharacterLevelParser
if guided_params.json:
schema_dict = _normalize_json_schema_object(guided_params.json)
character_level_parser = JsonSchemaParser(schema_dict)
elif guided_params.choice:
character_level_parser = UnionParser(
[StringParser(choice) for choice in guided_params.choice])
elif guided_params.regex:
character_level_parser = RegexParser(guided_params.regex)
elif guided_params.grammar:
# CFG grammar not supported by LMFE
raise ValueError("Cannot construct a guided decoding logits processor"
" using the grammar option with the"
" lm_format_enforcer backend.")
elif guided_params.json_object:
# None means any json object
character_level_parser = JsonSchemaParser(None)
else:
return None
logits_processor = build_vllm_logits_processor(tokenizer_data,
character_level_parser)
return logits_processor
def _normalize_json_schema_object(schema: Union[str, dict]) -> dict:
if isinstance(schema, str):
return json_loads(schema)
if isinstance(schema, dict):
return schema
raise AssertionError(f"Unsupported schema type {schema}")
@lru_cache
def _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData:
return build_vllm_token_enforcer_tokenizer_data(tokenizer)
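A usage sketch for the helper above, assuming lm-format-enforcer is installed and using a placeholder tokenizer:

from transformers import AutoTokenizer
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
    get_local_lm_format_enforcer_guided_decoding_logits_processor)
from vllm.sampling_params import GuidedDecodingParams

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
params = GuidedDecodingParams(json=schema)
processor = get_local_lm_format_enforcer_guided_decoding_logits_processor(
    params, tokenizer)
# A second call with the same tokenizer reuses the lru_cache'd tokenizer data.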

View File

@ -1,117 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import concurrent.futures
import os
from enum import Enum
from json import dumps as json_dumps
from typing import Optional, Union
from regex import escape as regex_escape
from transformers import PreTrainedTokenizerBase
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams
class GuidedDecodingMode(Enum):
JSON = "json"
REGEX = "regex"
CHOICE = "choice"
global_thread_pool = None # used for generating logits processor fsm
# It's not yet clear that using more provides a benefit, and it could
# potentially starve other processes on the machine. We'll cap this for now and
# adjust later if testing proves it to help overcome a bottleneck.
_MAX_THREADPOOL_WORKERS = 16
async def get_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser]
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
"""
global global_thread_pool
guide, mode = _get_guide_and_mode(guided_params)
if not guide or not mode:
return None
if global_thread_pool is None:
max_workers = os.cpu_count() or 2
if max_workers > _MAX_THREADPOOL_WORKERS:
max_workers = _MAX_THREADPOOL_WORKERS
global_thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers)
loop = asyncio.get_running_loop()
return await loop.run_in_executor(global_thread_pool,
_get_logits_processor, guide, tokenizer,
mode, guided_params.whitespace_pattern,
reasoner)
def get_local_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser]
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
"""
guide, mode = _get_guide_and_mode(guided_params)
if not guide or not mode:
return None
return _get_logits_processor(guide, tokenizer, mode,
guided_params.whitespace_pattern, reasoner)
def _get_guide_and_mode(
guided_params: GuidedDecodingParams
) -> Union[tuple[str, GuidedDecodingMode], tuple[None, None]]:
if guided_params.json:
if isinstance(guided_params.json, dict):
# turn dict into hashable string
json = json_dumps(guided_params.json)
else:
json = guided_params.json
return json, GuidedDecodingMode.JSON
elif guided_params.regex:
return guided_params.regex, GuidedDecodingMode.REGEX
elif guided_params.choice:
# choice just uses regex
choices = [
regex_escape(str(choice)) for choice in guided_params.choice
]
choices_regex = "(" + "|".join(choices) + ")"
return choices_regex, GuidedDecodingMode.CHOICE
elif guided_params.grammar:
raise ValueError(
"The `outlines` guided decoding backend no longer supports grammar "
"guided generation. Please use either the `xgrammar` or `guidance` "
"backend")
else:
return None, None
def _get_logits_processor(
guide: str,
tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None],
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
reasoner)
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
return RegexLogitsProcessor(guide, tokenizer, reasoner)
else:
raise ValueError(f"Unknown guided decoding mode {mode}")
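For illustration, how the private helper above maps a choice list onto a regex guide, assuming outlines_core and the module's caching dependencies are installed:

from vllm.model_executor.guided_decoding.outlines_decoding import (
    GuidedDecodingMode, _get_guide_and_mode)
from vllm.sampling_params import GuidedDecodingParams

guide, mode = _get_guide_and_mode(GuidedDecodingParams(choice=["yes", "no"]))
print(guide)                              # (yes|no)
print(mode is GuidedDecodingMode.CHOICE)  # True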

View File

@ -1,307 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright 2024-present the Outlines developers
from __future__ import annotations
import copy
import hashlib
import importlib.metadata
import json
import os
from typing import Optional, Union
import regex as re
import torch
from cachetools import LRUCache
from diskcache import Cache
from outlines_core import Guide, Index, Vocabulary
from outlines_core.json_schema import build_regex_from_schema
from outlines_core.kernels.torch import (_apply_token_bitmask_inplace_kernel,
allocate_token_bitmask)
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
from transformers.file_utils import SPIECE_UNDERLINE
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
CACHE = None
class BaseLogitsProcessor:
def __init__(self, guide: Guide, eos_token_id: int,
reasoner: Optional[ReasoningParser]) -> None:
self._guide: Guide = guide
self._eos_token_id: int = eos_token_id
self._reasoner: Optional[ReasoningParser] = reasoner
self._mask: Optional[torch.Tensor] = None
def __call__(self, input_ids: list[int],
scores: torch.Tensor) -> torch.Tensor:
if self._mask is None:
self._mask = allocate_token_bitmask(scores.size(-1))
# Skip the structured logits processing if reasoning is not finished.
# reasoner is not None only when `--reasoning-parser` is set.
if self._reasoner is not None and not self._reasoner.is_reasoning_end(
input_ids):
return scores
# Remove the reasoning tokens from the input_ids
# We need this because our implementation relies on the
# input_ids sequence to store the FSM state.
input_ids = (self._reasoner.extract_content_ids(input_ids)
if self._reasoner is not None else input_ids)
# The vLLM V0 engine has a quirk where the EOS token id gets repeated before
# generation stops (or at least that is what we have to handle from here),
# so the guide is not advanced on EOS. This is a patch until a better
# solution can be pushed to outlines_core.
if input_ids and input_ids[-1] != self._eos_token_id:
self._guide.advance(token_id=input_ids[-1], return_tokens=False)
self._guide.write_mask_into(
data_ptr=self._mask.data_ptr(),
numel=self._mask.numel(),
element_size=self._mask.element_size(),
)
# Any allowed tokens beyond the length of the scores will
# be ignored by the kernel, taking care of the issue with
# models such as Llama 3.2 Vision with an `<|image|>` token
# with id 128256, but scores.shape == torch.Size([128256])
_apply_token_bitmask_inplace_kernel(
logits=scores.unsqueeze(dim=0),
# mask must be on same device
mask=self._mask.to(scores.device, non_blocking=True))
self._mask.to("cpu", non_blocking=True)
return scores
def clone(self) -> BaseLogitsProcessor:
guide = copy.deepcopy(self._guide)
guide.reset()
return BaseLogitsProcessor(guide=guide,
eos_token_id=self._eos_token_id,
reasoner=self._reasoner)
class RegexLogitsProcessor(BaseLogitsProcessor):
@classmethod
def _get_guide(cls, regex_string: str,
tokenizer: PreTrainedTokenizerBase) -> Guide:
global CACHE
if CACHE is None:
CACHE = get_cache()
vocabulary = get_vocabulary(tokenizer) # type: ignore[arg-type]
cache_key = f"{vocabulary._hash}_{regex_string}"
if CACHE is not None and cache_key in CACHE:
return Guide(CACHE[cache_key])
index = Index(regex_string, vocabulary.inner)
if CACHE is not None:
CACHE[cache_key] = index
return Guide(index)
def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser]) -> None:
super().__init__(
guide=RegexLogitsProcessor._get_guide(regex_string, tokenizer),
eos_token_id=tokenizer.eos_token_id, # type: ignore
reasoner=reasoner)
class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema: Union[str, dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Union[str, None],
reasoner: Optional[ReasoningParser]) -> None:
if isinstance(schema, type(BaseModel)):
schema_str = json.dumps(schema.model_json_schema())
elif isinstance(schema, dict):
schema_str = json.dumps(schema)
elif isinstance(schema, str):
schema_str = schema
else:
raise ValueError(
f"Cannot parse schema {schema}. The schema must be either "
f"a Pydantic object, a dictionary or a string that contains "
f"the JSON Schema specification")
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
super().__init__(regex_string, tokenizer, reasoner)
class OutlinesVocabulary:
"""
Wrapper class for `outlines_core.Vocabulary`,
which allows us to store a hash with the vocabulary
"""
def __init__(self, vocabulary: Vocabulary) -> None:
# Actual vocabulary object
self.inner = vocabulary
# Use a SHA-256-derived integer rather than Python's built-in hash(), which
# can be negative and is not stable across processes; the value is used as
# part of a cache key.
hex_str = hashlib.sha256(
vocabulary.__repr__().encode('utf-8')).hexdigest()
hash_int = int(hex_str, 16)
self._hash = hash_int
re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")
re_replacement_seq = re.compile(r"^.{0,6}\ufffd+.{0,6}$")
def _reduced_vocabulary(tokenizer: AnyTokenizer,
eos_token_id: int) -> dict[bytes, list[int]]:
"""Create a map from vocabulary tokens to lists of equivalent token ids.
Returns:
A dict of token bytes -> equivalent token ids
"""
unicode_to_bytes = {v: k for k, v in bytes_to_unicode().items()}
def convert_token_to_string(token: str) -> str:
string = tokenizer.convert_tokens_to_string([token])
# A hack to handle missing spaces in HF's Llama tokenizers
if (type(token) is str and token.startswith(SPIECE_UNDERLINE)
or token == "<0x20>"):
return " " + string
return string
vocabulary: dict[bytes, list[int]] = {}
empty_token_ids: list[int] = []
for token, token_idx in tokenizer.get_vocab().items():
if token in tokenizer.all_special_tokens: # type: ignore
continue
token_str = convert_token_to_string(token)
if token_str:
if isinstance(token, (bytes, bytearray)):
# For BPE tokenizers where tokens are stored as bytes.
# safe to ignore since token_str is of type (bytearray, bytes)
# by this point.
token_bytes = bytes(token_str) # type: ignore[arg-type]
elif "\ufffd" in token_str and not re_replacement_seq.match(
token_str):
# Handle tokens with invalid UTF-8 sequences.
if re_llama_byte_token.match(token):
# Llama-like tokenizers use <0xXX> for incomplete sequences.
token_bytes = bytes([int(token[3:5], 16)])
else:
# GPT2 tokenizers: map each byte back using unicode_to_bytes
byte_vals = [unicode_to_bytes.get(c) for c in token]
if None in byte_vals:
raise RuntimeError(
f"Cannot convert token `{token}`"
f" ({token_idx}) to bytes: {token_str}")
# safe to ignore, since if None in byte_vals,
# an error is thrown.
token_bytes = bytes(byte_vals) # type: ignore[arg-type]
else:
token_bytes = token_str.encode('utf-8')
if token_idx != eos_token_id:
vocabulary.setdefault(token_bytes, []).append(token_idx)
else:
empty_token_ids.append(token_idx)
return vocabulary
def get_vocabulary(tokenizer: AnyTokenizer) -> Vocabulary:
"""Get the `Vocabulary` object for a given tokenizer.
"""
if hasattr(tokenizer, "_outlines_vocabulary"):
return tokenizer._outlines_vocabulary # type: ignore
try:
if hasattr(
tokenizer,
"eos_token_id",
) and tokenizer.eos_token_id is not None:
eos_token_id = tokenizer.eos_token_id
else:
raise ValueError(
f"Error during guided decoding setup: Tokenizer"
f" ({type(tokenizer)}) has no `eos_token_id` property, "
"but `eos_token_id` is required for guided decoding"
" to work properly.")
reduced_vocab = _reduced_vocabulary(
tokenizer,
eos_token_id #type: ignore
)
vocabulary = OutlinesVocabulary(Vocabulary(eos_token_id,
reduced_vocab))
tokenizer._outlines_vocabulary = vocabulary # type: ignore
return vocabulary
except AttributeError as e:
raise ValueError(f"Cannot get the vocabulary of the tokenizer "
f"({type(tokenizer)}). The tokenizer should have a "
"get_vocab method.") from e
def get_cache_path() -> str:
"""Return the directory used to store the outlines index cache."""
outlines_cache_dir = os.getenv("OUTLINES_CACHE_DIR")
xdg_cache_home = os.getenv("XDG_CACHE_HOME")
home_dir = os.path.expanduser("~")
if outlines_cache_dir:
# OUTLINES_CACHE_DIR takes precedence
return outlines_cache_dir
elif xdg_cache_home:
return os.path.join(xdg_cache_home, ".cache", "outlines")
# If homedir is "/", we may be inside a container, and thus writing to
# root would be problematic, so we fall back to using a tempfile.
# Also validate the path exists, since os.path.expanduser does
# not guarantee existence.
elif os.path.isdir(home_dir) and home_dir != "/":
# Default Unix fallback: ~/.cache/outlines
return os.path.join(home_dir, ".cache", "outlines")
else:
import tempfile
# home_dir may be / inside a docker container without existing user
tempdir = tempfile.gettempdir()
return os.path.join(tempdir, ".cache", "outlines")
def get_cache():
"""Get the Cache instance to be used for index caching"""
cache_dir = get_cache_path()
if envs.VLLM_V0_USE_OUTLINES_CACHE:
logger.warning("Enabling outlines cache. This is an unbounded on-disk "
"cache. It may consume a lot of disk space and should "
"not be used with untrusted clients.")
cache = Cache(cache_dir, eviction_policy="none", cull_limit=0)
outlines_version = importlib.metadata.version("outlines_core")
cached_version = cache.get('__version__', None)
if cached_version != outlines_version:
cache.clear()
cache.set('__version__', outlines_version)
return cache
else:
return LRUCache(maxsize=128)
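A quick sketch of the cache-path precedence implemented above, assuming the module's dependencies (outlines_core, diskcache, cachetools) are installed: OUTLINES_CACHE_DIR wins over XDG_CACHE_HOME, which wins over the home-directory default.

import os
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
    get_cache_path)

os.environ.pop("OUTLINES_CACHE_DIR", None)
os.environ["XDG_CACHE_HOME"] = "/tmp/xdg"
print(get_cache_path())  # /tmp/xdg/.cache/outlines

os.environ["OUTLINES_CACHE_DIR"] = "/tmp/outlines-cache"
print(get_cache_path())  # /tmp/outlines-cache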

View File

@ -1,242 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import regex as re
def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
"""Check if JSON schema contains features unsupported by xgrammar."""
def check_object(obj: dict) -> bool:
if not isinstance(obj, dict):
return False
# Check for numeric ranges
if obj.get("type") in ("integer", "number") and ("multipleOf" in obj):
return True
# Check for array unsupported keywords
if obj.get("type") == "array" and any(key in obj for key in [
"uniqueItems", "contains", "minContains", "maxContains",
"minItems", "maxItems"
]):
return True
# Unsupported keywords for strings
if obj.get("type") == "string" and any(
key in obj for key in ["minLength", "maxLength", "format"]):
return True
# Unsupported keywords for objects
if obj.get("type") == "object" and any(key in obj for key in [
"minProperties", "maxProperties", "propertyNames",
"patternProperties"
]):
return True
# Recursively check all nested objects and arrays
for value in obj.values():
if isinstance(value, dict):
if check_object(value):
return True
elif isinstance(value, list):
for item in value:
if isinstance(item, dict) and check_object(item):
return True
return False
return check_object(schema)
def has_lmf_unsupported_json_features(schema: dict) -> bool:
"""
Check if JSON schema contains features unsupported
by lm_format_enforcer.
Known issues:
- Regex patterns:
"grade": {
"type": "string",
"pattern": "^[A-D]$" # Regex pattern
},
"""
def check_object(obj: dict) -> bool:
if not isinstance(obj, dict):
return False
# Check for pattern restrictions
if "pattern" in obj:
return True
# Recursively check all nested objects and arrays
for value in obj.values():
if isinstance(value, dict):
if check_object(value):
return True
elif isinstance(value, list):
for item in value:
if isinstance(item, dict) and check_object(item):
return True
return False
return check_object(schema)
def grammar_is_likely_lark(grammar_str: str) -> bool:
"""
Check if grammar appears to use Lark syntax.
Args:
grammar_str: Input grammar string
Returns:
bool: True if grammar appears to be in Lark format, False otherwise
Examples:
>>> grammar_is_likely_lark("rule: 'abc'")
True
>>> grammar_is_likely_lark("rule ::= 'abc'")
False
"""
if not grammar_str or not isinstance(grammar_str, str):
return False
for line in grammar_str.split('\n'):
# Remove both comment styles
line = re.sub(r'(#|//).*$', '', line).strip()
if not line:
continue
# Look for GBNF rule definition
if '::=' in line:
return False
return True
def convert_lark_to_gbnf(grammar_str: str) -> str:
"""
Convert a Lark grammar string to GBNF format.
GBNF reference:
https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
Lark grammar reference:
https://lark-parser.readthedocs.io/en/latest/grammar.html
Args:
grammar_str: Input grammar in Lark format
Returns:
str: Converted grammar in GBNF format
Examples:
>>> print(convert_lark_to_gbnf("rule: 'hello'"))
root ::= rule
rule ::= "hello"
"""
if not isinstance(grammar_str, str):
raise ValueError(f"Grammar must be a string, got {type(grammar_str)}")
if not grammar_str.strip():
raise ValueError("Grammar string cannot be empty")
defined_rules = set()
referenced_rules = set()
output_lines = []
def clean_line(line: str) -> str:
"""Remove comments and whitespace from line."""
return re.sub(r'(#|//).*$', '', line).strip()
def check_quotes(text: str, rule_name: str, line_num: int) -> None:
"""Validate quote matching in text."""
if text.count("'") % 2 != 0 or text.count('"') % 2 != 0:
raise ValueError(
f"Mismatched quotes in {rule_name} on line {line_num}")
def extract_references(text: str) -> set:
"""Extract rule references from text."""
# Remove quoted strings and special characters
text = re.sub(r'"[^"]*"', '', text)
text = re.sub(r'[+*?()|\[\]{}]', ' ', text)
return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text))
# First pass: Find root rule and validate rule definitions
lines = [clean_line(line) for line in grammar_str.split('\n')]
first_rule = None
for line_num, line in enumerate(lines, 1):
if not line or line.startswith('|'):
continue
if ':' in line:
try:
name = line.split(':', 1)[0].strip().strip('?')
defined_rules.add(name)
if first_rule is None:
first_rule = name
if name == 'start':
first_rule = 'start'
except IndexError as e:
raise ValueError(f"Invalid rule format on line {line_num}. "
"Expected 'rule_name: definition'") from e
if not defined_rules:
raise ValueError("No valid rules found in grammar")
# Add root rule
output_lines.append(f"root ::= {first_rule}")
# Second pass: Process rule definitions and alternatives
current_rule = None
current_definition = []
for line_num, line in enumerate(lines, 1):
if not line:
continue
try:
if ':' in line and not line.startswith('|'):
# Save previous rule if exists
if current_rule:
output_lines.append(
f"{current_rule} ::= {' | '.join(current_definition)}")
# Process new rule
name, definition = line.split(':', 1)
current_rule = name.strip().strip('?')
check_quotes(definition, f"rule '{current_rule}'", line_num)
definition = re.sub(r"'([^']*)'", r'"\1"', definition)
referenced_rules.update(extract_references(definition))
current_definition = [definition.strip()]
elif line.startswith('|'):
if not current_rule:
raise ValueError(f"Alternative '|' on line {line_num} "
"without a preceding rule definition")
alt_def = line[1:].strip()
check_quotes(alt_def, f"alternative for rule '{current_rule}'",
line_num)
alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def)
referenced_rules.update(extract_references(alt_def))
current_definition.append(alt_def)
except ValueError as e:
raise ValueError(f"Error on line {line_num}: {str(e)}") from e
# Add final rule if exists
if current_rule:
output_lines.append(
f"{current_rule} ::= {' | '.join(current_definition)}")
# Validate all rules are defined
undefined_rules = referenced_rules - defined_rules - {'root'}
if undefined_rules:
raise ValueError("Referenced rules are not defined: "
f"{', '.join(sorted(undefined_rules))}")
return '\n'.join(output_lines)
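Two small checks that exercise the helpers above (pure string and JSON logic, nothing else required):

from vllm.model_executor.guided_decoding.utils import (
    convert_lark_to_gbnf, grammar_is_likely_lark,
    has_xgrammar_unsupported_json_features)

# A string length bound is one of the schema features xgrammar could not handle.
print(has_xgrammar_unsupported_json_features({"type": "string", "minLength": 3}))
# -> True

lark = "start: 'yes' | 'no'"
if grammar_is_likely_lark(lark):
    print(convert_lark_to_gbnf(lark))
# root ::= start
# start ::= "yes" | "no"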

View File

@ -1,426 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# noqa: UP007
from __future__ import annotations
import json
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import regex as re
import torch
import vllm.envs
from vllm.logger import init_logger
try:
import xgrammar as xgr
xgr_installed = True
except ImportError:
xgr_installed = False
pass
from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
grammar_is_likely_lark)
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams
logger = init_logger(__name__)
def get_local_xgrammar_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig,
reasoner: ReasoningParser | None,
max_threads: int = 8):
config = GrammarConfig.from_guided_params(guided_params=guided_params,
model_config=model_config,
tokenizer=tokenizer,
max_threads=max_threads)
return XGrammarLogitsProcessor(config, reasoner)
@dataclass(frozen=True)
class TokenizerData:
"""Immutable container for cached tokenizer data."""
metadata: str
encoded_vocab: list[str] = field(default_factory=list)
class TokenizerDataCache:
"""Cache manager for tokenizer data to avoid repeated processing."""
_cache: dict[int, TokenizerData] = {}
@classmethod
def get_tokenizer_data(
cls,
tokenizer: PreTrainedTokenizer,
/,
*,
tokenizer_hash: int,
vocab_size: int,
) -> TokenizerData:
if tokenizer_hash not in cls._cache:
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
tokenizer,
# NOTE: We will need to use lm_head's vocab_size
# to determine correct special_token_ids for this tokenizer.
# See https://github.com/mlc-ai/xgrammar/commit/70c959fb6d9cea75aae33c414763cd0602022d92 # noqa: E501
vocab_size=vocab_size,
)
metadata = json.loads(tokenizer_info.dump_metadata())
# Vendored from xgrammar logic to get encoded_vocab
# https://github.com/mlc-ai/xgrammar/blob/989222175c2a30fb7987d8bcce35bec1bf6817f2/python/xgrammar/tokenizer_info.py#L127 # noqa: E501
try:
vocab_dict = tokenizer.get_vocab()
except AttributeError as e:
raise ValueError(
f"Cannot get the vocabulary of the tokenizer "
f"{type(tokenizer)}. The tokenizer should have a "
"get_vocab method.") from e
# maintain tokenizer's indexing
encoded_vocab = [""] * tokenizer_info.vocab_size
for token, idx in vocab_dict.items():
if idx < tokenizer_info.vocab_size:
encoded_vocab[idx] = token
if isinstance(tokenizer, MistralTokenizer):
# REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
metadata.update({
"vocab_type": xgr.VocabType.BYTE_FALLBACK,
"add_prefix_space": True
})
cls._cache[tokenizer_hash] = TokenizerData(
encoded_vocab=encoded_vocab,
metadata=json.dumps(metadata),
)
return cls._cache[tokenizer_hash]
class GrammarCompilerCache:
"""
Cache for GrammarCompiler instances based on tokenizer.
This cache reduces the overhead of creating new compiler instances when
using the same tokenizer configuration.
"""
_cache: dict[str, xgr.GrammarCompiler] = {}
@classmethod
def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler:
cache_key = str(config.tokenizer_hash)
if cache_key not in cls._cache:
config_data = config.tokenizer_data
# In TokenizerDataCache.get_tokenizer_data, a serializable
# tokenizer_data is created and cached. This data is used to build
# a tokenizer_info and create an xgrammar compiler.
tokenizer_info = xgr.TokenizerInfo.from_vocab_and_metadata(
encoded_vocab=config_data.encoded_vocab,
metadata=config_data.metadata,
)
cache_size = vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024
cls._cache[cache_key] = xgr.GrammarCompiler(
tokenizer_info,
max_threads=config.max_threads,
cache_enabled=True,
cache_limit_bytes=cache_size,
)
return cls._cache[cache_key]
@dataclass
class GrammarConfig:
"""Serializable configuration for grammar compilation"""
tokenizer_hash: int
tokenizer_data: TokenizerData
json_str: str | None = None
grammar_str: str | None = None
json_object: bool | None = None
any_whitespace: bool = True
regex_str: str | None = None
max_threads: int = 8
@classmethod
def from_guided_params(cls,
guided_params: GuidedDecodingParams,
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
max_threads: int = 8) -> GrammarConfig:
tokenizer_hash = hash(tokenizer)
tokenizer_data = TokenizerDataCache.get_tokenizer_data(
tokenizer,
tokenizer_hash=tokenizer_hash,
vocab_size=model_config.hf_text_config.vocab_size,
)
if guided_params.json:
if not isinstance(guided_params.json, str):
json_str = json.dumps(guided_params.json)
else:
json_str = guided_params.json
any_whitespace = not guided_params.disable_any_whitespace
# Warn for models that, when used with xgrammar and flexible whitespace,
# have a history of runaway whitespace generation.
# References:
# https://github.com/vllm-project/vllm/pull/12744
# https://github.com/mlc-ai/xgrammar/issues/212
model_with_warn = None
if 'Mistral' in model_config.model:
model_with_warn = 'Mistral'
elif 'Qwen' in model_config.model:
model_with_warn = 'Qwen'
if model_with_warn is not None and any_whitespace:
logger.info_once(
"%s model detected, consider setting `disable_any_whitespace` to prevent runaway generation of whitespaces.", # noqa: E501
model_with_warn,
)
# Validate the schema and raise ValueError here if it is invalid.
# This is to avoid exceptions in model execution, which will crash
# the engine worker process.
try:
xgr.Grammar.from_json_schema(json_str,
any_whitespace=any_whitespace)
except RuntimeError as err:
raise ValueError(str(err)) from err
return cls(json_str=json_str,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads,
tokenizer_data=tokenizer_data,
any_whitespace=any_whitespace)
elif guided_params.grammar:
# XGrammar only supports GBNF grammars, so we must convert Lark
if grammar_is_likely_lark(guided_params.grammar):
try:
grammar_str = convert_lark_to_gbnf(guided_params.grammar)
except ValueError as e:
raise ValueError(
"Failed to convert the grammar from Lark to GBNF. "
"Please either use GBNF grammar directly or specify"
" --guided-decoding-backend=outlines.\n"
f"Conversion error: {str(e)}") from e
else:
grammar_str = guided_params.grammar
# Validate the grammar and raise ValueError here if it is invalid.
# This is to avoid exceptions in model execution, which will crash
# the engine worker process.
try:
xgr.Grammar.from_ebnf(grammar_str)
except RuntimeError as err:
raise ValueError(str(err)) from err
return cls(grammar_str=grammar_str,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads,
tokenizer_data=tokenizer_data)
elif guided_params.json_object:
return cls(
json_object=True,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads,
tokenizer_data=tokenizer_data,
)
elif guided_params.choice:
choice_str = GrammarConfig.choice_as_grammar(guided_params.choice)
try:
xgr.Grammar.from_ebnf(choice_str)
except RuntimeError as err:
raise ValueError(str(err)) from err
return cls(
grammar_str=choice_str,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads,
tokenizer_data=tokenizer_data,
)
elif guided_params.regex:
return cls(
regex_str=guided_params.regex,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads,
tokenizer_data=tokenizer_data,
)
else:
raise ValueError(
"Currently only support JSON and EBNF grammar mode for xgrammar"
)
@staticmethod
def escape_ebnf_string(s: str) -> str:
"""Escape special characters in an EBNF string."""
# Escape double quotes and backslashes
return re.sub(r'(["\\])', r'\\\1', s)
@staticmethod
def choice_as_grammar(choice: list[str] | None) -> str:
if choice is None:
raise ValueError("Choice is not set")
escaped_choices = (GrammarConfig.escape_ebnf_string(c) for c in choice)
grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices))
return grammar
@staticmethod
def tokenizer_info(tokenizer_data: TokenizerData) -> xgr.TokenizerInfo:
return xgr.TokenizerInfo.from_vocab_and_metadata(
encoded_vocab=tokenizer_data.encoded_vocab,
metadata=tokenizer_data.metadata,
)
@dataclass
class XGrammarLogitsProcessor:
"""Wrapper class to support pickle protocol"""
config: GrammarConfig
reasoner: ReasoningParser | None = None
ctx: xgr.CompiledGrammar | None = None
tokenizer_info: xgr.TokenizerInfo = None # type: ignore[assignment]
token_bitmask: torch.Tensor = None # type: ignore[assignment]
matchers: list[xgr.GrammarMatcher] = field(default_factory=list)
batch_size: int = field(default=1)
prefilled: bool = field(default=False)
def __post_init__(self):
if self.tokenizer_info is None:
self.tokenizer_info = self.config.tokenizer_info(
self.config.tokenizer_data)
def __getstate__(self) -> dict[str, Any]:
return {'config': self.config, 'reasoner': self.reasoner}
def __setstate__(self, state: dict[str, Any]):
self.config = state['config']
self.reasoner = state['reasoner']
self.tokenizer_info = GrammarConfig.tokenizer_info(
self.config.tokenizer_data)
self.ctx = None
self.matchers = []
self.batch_size = 1
self.token_bitmask = None # type: ignore[assignment]
self.prefilled = False
def _ensure_ctx(self):
"""Lazily initialize the processor in the worker process"""
if self.ctx is None:
compiler = GrammarCompilerCache.get_compiler(self.config)
if self.config.json_str is not None:
any_whitespace = self.config.any_whitespace
self.ctx = compiler\
.compile_json_schema(self.config.json_str,
any_whitespace=any_whitespace)
elif self.config.grammar_str is not None:
self.ctx = compiler.compile_grammar(self.config.grammar_str)
elif self.config.json_object:
any_whitespace = self.config.any_whitespace
self.ctx = compiler\
.compile_json_schema('{"type": "object"}',
any_whitespace=any_whitespace)
elif self.config.regex_str:
self.ctx = compiler.compile_regex(self.config.regex_str)
else:
raise ValueError(
"Invalid configuration for xgrammar logits processor")
def __call__(self, input_ids: list[int],
scores: torch.Tensor) -> torch.Tensor:
# Skip the structured logits processing if reasoning is not finished.
# reasoner is not None only when `--reasoning-parser` is set.
if self.reasoner is not None and \
not self.reasoner.is_reasoning_end(
input_ids):
return scores
if self.ctx is None:
self._ensure_ctx()
if len(self.matchers) == 0:
self.matchers = [
xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
]
self.token_bitmask = xgr.allocate_token_bitmask(
self.batch_size, self.tokenizer_info.vocab_size)
if not self.prefilled:
# Have not sampled a token yet
self.prefilled = True
else:
for i, matcher in enumerate(self.matchers):
if not matcher.is_terminated():
sampled_token = input_ids[-1]
assert self.matchers[i].accept_token(sampled_token)
for i, matcher in enumerate(self.matchers):
if not matcher.is_terminated():
# @ubospica: ideally, fill_next_token_bitmask should be
# parallelized with model decoding
# See https://github.com/vllm-project/vllm/pull/10785/files#r1864278303
matcher.fill_next_token_bitmask(self.token_bitmask, i)
# token_bitmask is a CPU tensor for use with accept_token and
# fill_next_token_bitmask so we move it to the device of scores
device_type = scores.device.type
dtype = scores.dtype
if device_type != "cuda":
# xgrammar on cpu only supports float32 scores
# see: https://github.com/mlc-ai/xgrammar/blob/c1b64920cad24f44f235778c1c00bb52d57da01a/python/xgrammar/kernels/apply_token_bitmask_inplace_cpu.py#L22
scores = scores.to("cpu").float().unsqueeze(0)
# Note: if the tensors have different dimensions, this call fails on the
# CPU device but runs without error on GPU. Hence the unsqueeze above for
# scores, to match the token bitmask shape.
xgr.apply_token_bitmask_inplace(
scores, self.token_bitmask.to(scores.device, non_blocking=True))
if device_type != "cuda":
scores = scores.to(dtype).to(device_type).squeeze()
return scores
def clone(self) -> XGrammarLogitsProcessor:
"""Create a new instance with shared compiled grammar
but separate state"""
new_processor = XGrammarLogitsProcessor(self.config, self.reasoner,
None, self.tokenizer_info)
# Share the compiled grammar context (immutable after compilation)
new_processor.ctx = self.ctx
# Create fresh matchers for the new sequence
if self.ctx is not None:
new_processor.matchers = [
xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
]
# Create a new token bitmask with the same size
if hasattr(self, 'token_bitmask') and self.token_bitmask is not None:
new_processor.token_bitmask = self.token_bitmask
# Copy simple attributes
new_processor.batch_size = self.batch_size
# Reset prefilled state for new sequence
new_processor.prefilled = False
return new_processor
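Finally, a sketch of the choice-to-GBNF conversion above; the static method needs no compiled grammar, and xgrammar itself is optional at import time:

from vllm.model_executor.guided_decoding.xgrammar_decoding import GrammarConfig

# Choices are escaped and joined into a single-rule GBNF grammar.
print(GrammarConfig.choice_as_grammar(['say "hi"', "bye"]))
# root ::= "say \"hi\"" | "bye"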