[V0 deprecation] Guided decoding (#21347)
Signed-off-by: Reza Barazesh <rezabarazesh@meta.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
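For orientation, the user-facing path that remains after this deprecation is structured (guided) decoding configured through SamplingParams. The sketch below is illustrative only: it mirrors the GuidedDecodingParams usage exercised by the tests in this diff, but the model name, schema, and prompt are placeholder assumptions rather than values taken from the change.

# Hedged sketch of guided decoding via SamplingParams; mirrors the API used
# by the tests in this diff. Model, schema, and prompt are illustrative.
import json

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", max_model_len=1024, seed=0)
sampling_params = SamplingParams(
    temperature=1.0,
    max_tokens=200,
    guided_decoding=GuidedDecodingParams(json=schema),
)
outputs = llm.generate(
    [f"Give an example JSON for a person that fits this schema: {schema}"],
    sampling_params,
)
print(json.loads(outputs[0].outputs[0].text))  # parses because the output is schema-constrained

If run, the generated text is constrained to valid JSON for the schema, which is the same property the deleted V0 tests below asserted.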
@ -128,11 +128,10 @@ steps:
- tests/entrypoints/offline_mode
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_collective_rpc.py
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
- VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
- VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests

- label: Entrypoints Test (API Server) # 40min
.github/CODEOWNERS (vendored): 3 changed lines
@ -10,7 +10,6 @@
/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
/vllm/model_executor/guided_decoding @mgoin @russellb @aarnphm
/vllm/multimodal @DarkLight1337 @ywang96
/vllm/vllm_flash_attn @LucasWilkinson
/vllm/lora @jeejeelee
@ -35,9 +34,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/distributed/test_pipeline_parallel.py @youkaichao
/tests/distributed/test_same_node.py @youkaichao
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo @aarnphm
/tests/entrypoints/llm/test_guided_generate.py @mgoin @russellb @aarnphm
/tests/kernels @tlrmchlsmth @WoosukKwon
/tests/model_executor/test_guided_processors.py @mgoin @russellb
/tests/models @DarkLight1337 @ywang96
/tests/multi_step @alexm-redhat @comaniac
/tests/multimodal @DarkLight1337 @ywang96
.github/mergify.yml (vendored): 3 changed lines
@ -149,9 +149,6 @@ pull_request_rules:
- files=examples/offline_inference/structured_outputs.py
- files=examples/online_serving/openai_chat_completion_structured_outputs.py
- files=examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
- files~=^vllm/model_executor/guided_decoding/
- files=tests/model_executor/test_guided_processors.py
- files=tests/entrypoints/llm/test_guided_generate.py
- files~=^tests/v1/structured_output/
- files=tests/v1/entrypoints/llm/test_guided_generate.py
- files~=^vllm/v1/structured_output/
@ -1,552 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import weakref
from enum import Enum

import jsonschema
import pytest
import regex as re
from pydantic import BaseModel

from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"

# Separate backends which support grammars vs ones
# which only support regex based constraints in tests.
GRAMMAR_DECODING_BACKENDS = [
# (backend, disable_any_whitespace),
("lm-format-enforcer", False),
("xgrammar", True),
("guidance", True),
]

ALL_DECODING_BACKENDS = ([("outlines", False)] + GRAMMAR_DECODING_BACKENDS)


@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(model=MODEL_NAME, max_model_len=1024, seed=0)

with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
del llm
cleanup_dist_env_and_memory()

@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
regex=sample_regex,
|
||||
backend=guided_decoding_backend,
|
||||
disable_any_whitespace=disable_any_whitespace))
|
||||
|
||||
outputs = llm.generate(prompts=[
|
||||
f"Give an example IPv4 address with this regex: {sample_regex}"
|
||||
] * 2,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
assert outputs is not None
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(generated_text)
|
||||
assert generated_text is not None
|
||||
assert re.fullmatch(sample_regex, generated_text) is not None
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_json_completion(sample_json_schema, llm,
|
||||
guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
json=sample_json_schema,
|
||||
backend=guided_decoding_backend,
|
||||
disable_any_whitespace=disable_any_whitespace))
|
||||
outputs = llm.generate(prompts=[
|
||||
f"Give an example JSON for an employee profile "
|
||||
f"that fits this schema: {sample_json_schema}"
|
||||
] * 2,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
assert outputs is not None
|
||||
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
output_json = json.loads(generated_text)
|
||||
jsonschema.validate(instance=output_json, schema=sample_json_schema)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_complex_json_completion(sample_complex_json_schema, llm,
|
||||
guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
json=sample_complex_json_schema,
|
||||
backend=guided_decoding_backend,
|
||||
disable_any_whitespace=disable_any_whitespace))
|
||||
outputs = llm.generate(prompts=[
|
||||
f"Give an example JSON for an assignment grade "
|
||||
f"that fits this schema: {sample_complex_json_schema}"
|
||||
] * 2,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
assert outputs is not None
|
||||
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
output_json = json.loads(generated_text)
|
||||
jsonschema.validate(instance=output_json,
|
||||
schema=sample_complex_json_schema)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_definition_json_completion(sample_definition_json_schema, llm,
|
||||
guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
json=sample_definition_json_schema,
|
||||
backend=guided_decoding_backend,
|
||||
disable_any_whitespace=disable_any_whitespace))
|
||||
outputs = llm.generate(prompts=[
|
||||
f"Give an example JSON for solving 8x + 7 = -23 "
|
||||
f"that fits this schema: {sample_definition_json_schema}"
|
||||
] * 2,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
assert outputs is not None
|
||||
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
output_json = json.loads(generated_text)
|
||||
jsonschema.validate(instance=output_json,
|
||||
schema=sample_definition_json_schema)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_enum_json_completion(sample_enum_json_schema, llm,
|
||||
guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
json=sample_enum_json_schema,
|
||||
backend=guided_decoding_backend,
|
||||
disable_any_whitespace=disable_any_whitespace))
|
||||
outputs = llm.generate(prompts=[
|
||||
"Create a bug report JSON that fits this schema: "
|
||||
f"{sample_enum_json_schema}. Make it for a high priority critical bug."
|
||||
] * 2,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
assert outputs is not None
|
||||
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
output_json = json.loads(generated_text)
|
||||
jsonschema.validate(instance=output_json,
|
||||
schema=sample_enum_json_schema)
|
||||
|
||||
# Additional assertions to verify enum values
|
||||
assert output_json["status"] in ["active", "inactive", "pending"]
|
||||
assert output_json["priority"] in ["low", "medium", "high", "critical"]
|
||||
assert output_json["category"]["type"] in [
|
||||
"bug", "feature", "improvement"
|
||||
]
|
||||
assert output_json["category"]["severity"] in [1, 2, 3, 4, 5]
|
||||
for flag in output_json["flags"]:
|
||||
assert flag in ["urgent", "blocked", "needs_review", "approved"]
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_choice_completion(sample_guided_choice, llm,
|
||||
guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
choice=sample_guided_choice,
|
||||
backend=guided_decoding_backend,
|
||||
disable_any_whitespace=disable_any_whitespace))
|
||||
outputs = llm.generate(
|
||||
prompts="The best language for type-safe systems programming is ",
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
assert outputs is not None
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(generated_text)
|
||||
assert generated_text is not None
|
||||
assert generated_text in sample_guided_choice
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GRAMMAR_DECODING_BACKENDS)
|
||||
def test_guided_grammar(sample_sql_statements, llm,
|
||||
guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
grammar=sample_sql_statements,
|
||||
backend=guided_decoding_backend,
|
||||
disable_any_whitespace=disable_any_whitespace))
|
||||
outputs = llm.generate(
|
||||
prompts=("Generate a sql state that select col_1 from "
|
||||
"table_1 where it is equals to 1"),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True,
|
||||
)
|
||||
|
||||
assert outputs is not None
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
# use Lark to parse the output, and make sure it's a valid parse tree
|
||||
from lark import Lark
|
||||
parser = Lark(sample_sql_statements)
|
||||
parser.parse(generated_text)
|
||||
|
||||
# remove spaces for comparison b/c we removed them in the grammar
|
||||
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
|
||||
" ", "")
|
||||
|
||||
assert generated_text.strip() == ground_truth
|
||||
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_guided_options_request_deprecation_warning(sample_regex, llm):
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="guided_options_request"):
|
||||
llm.generate(prompts="This should fail",
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True,
|
||||
guided_options_request=dict(guided_regex=sample_regex))
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_validation_against_both_guided_decoding_options(sample_regex, llm):
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
guided_decoding=GuidedDecodingParams(regex=sample_regex))
|
||||
|
||||
with pytest.raises(ValueError, match="Cannot set both"):
|
||||
llm.generate(prompts="This should fail",
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True,
|
||||
guided_options_request=dict(guided_regex=sample_regex))
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_disable_guided_decoding_fallback(sample_regex, llm):
|
||||
# see has_xgrammar_unsupported_json_features()
|
||||
unsupported_json = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"example": {
|
||||
"type": "string",
|
||||
"minLength": 5 # unsupported by xgrammar
|
||||
}
|
||||
}
|
||||
}
|
||||
sampling_params = SamplingParams(temperature=0.8,
|
||||
top_p=0.95,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
json=unsupported_json,
|
||||
backend="xgrammar",
|
||||
disable_fallback=True))
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="xgrammar does not support advanced JSON schema features "
|
||||
"like string length, item limits, or property bounds."):
|
||||
llm.generate(prompts="This should fail",
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
GRAMMAR_DECODING_BACKENDS)
|
||||
def test_guided_json_object(llm, guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
max_tokens=100,
|
||||
n=2,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
json_object=True,
|
||||
backend=guided_decoding_backend,
|
||||
disable_any_whitespace=disable_any_whitespace))
|
||||
|
||||
outputs = llm.generate(
|
||||
prompts=("Generate a JSON object with curly braces for a person with "
|
||||
"name and age fields for John Smith who is 31 years old."),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
assert outputs is not None
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
|
||||
for i in range(2):
|
||||
generated_text = output.outputs[i].text
|
||||
print(generated_text)
|
||||
assert generated_text is not None
|
||||
|
||||
if disable_any_whitespace:
|
||||
assert "\n" not in generated_text
|
||||
|
||||
# Parse to verify it is valid JSON
|
||||
parsed_json = json.loads(generated_text)
|
||||
# A list is not what was intended, but is still valid
|
||||
# json.
|
||||
assert isinstance(parsed_json, (dict, list))
|
||||
|
||||
|
||||
class CarType(str, Enum):
|
||||
sedan = "sedan"
|
||||
suv = "SUV"
|
||||
truck = "Truck"
|
||||
coupe = "Coupe"
|
||||
|
||||
|
||||
class CarDescription(BaseModel):
|
||||
brand: str
|
||||
model: str
|
||||
car_type: CarType
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
json_schema = CarDescription.model_json_schema()
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
json=json_schema,
|
||||
backend=guided_decoding_backend,
|
||||
disable_any_whitespace=disable_any_whitespace))
|
||||
outputs = llm.generate(
|
||||
prompts="Generate a JSON with the brand, model and car_type of"
|
||||
"the most iconic car from the 90's",
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
assert outputs is not None
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
output_json = json.loads(generated_text)
|
||||
jsonschema.validate(instance=output_json, schema=json_schema)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
|
||||
ALL_DECODING_BACKENDS)
|
||||
def test_guided_number_range_json_completion(llm, guided_decoding_backend: str,
|
||||
disable_any_whitespace: bool):
|
||||
sample_output_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {
|
||||
"type": "integer",
|
||||
"minimum": 18,
|
||||
"maximum": 99
|
||||
},
|
||||
"score": {
|
||||
"type": "number",
|
||||
"minimum": 0.0,
|
||||
"maximum": 100.0
|
||||
},
|
||||
"zipcode": {
|
||||
"type": "string",
|
||||
"pattern": r"^\d{5}(-\d{4})?$"
|
||||
},
|
||||
},
|
||||
"required": ["age", "score", "zipcode"],
|
||||
}
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
json=sample_output_schema,
|
||||
backend=guided_decoding_backend,
|
||||
disable_any_whitespace=disable_any_whitespace),
|
||||
)
|
||||
outputs = llm.generate(
|
||||
prompts=[
|
||||
"Create a JSON object for a user with age, score, and zipcode."
|
||||
] * 2,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True,
|
||||
)
|
||||
|
||||
assert outputs is not None
|
||||
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
output_json = json.loads(generated_text)
|
||||
jsonschema.validate(instance=output_json, schema=sample_output_schema)
|
||||
assert 18 <= output_json["age"] <= 99
|
||||
assert 0.0 <= output_json["score"] <= 100.0
|
||||
assert (re.fullmatch(r"^\d{5}(-\d{4})?$", output_json["zipcode"])
|
||||
is not None)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_guidance_no_additional_properties(llm):
|
||||
schema = {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'a1': {
|
||||
'type': 'string'
|
||||
},
|
||||
'a2': {
|
||||
'type': 'string'
|
||||
},
|
||||
'a3': {
|
||||
'type': 'string'
|
||||
}
|
||||
},
|
||||
'required': ['a1', 'a2', 'a3'],
|
||||
}
|
||||
|
||||
prompt = (
|
||||
"<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
|
||||
"helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
|
||||
"large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
|
||||
"<|im_end|>\n<|im_start|>assistant\n")
|
||||
|
||||
def generate_with_backend(backend, disable_additional_properties):
|
||||
guided_params = GuidedDecodingParams(
|
||||
json=schema,
|
||||
backend=backend,
|
||||
disable_any_whitespace=True,
|
||||
disable_additional_properties=disable_additional_properties)
|
||||
sampling_params = SamplingParams(temperature=0,
|
||||
max_tokens=256,
|
||||
guided_decoding=guided_params)
|
||||
|
||||
outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
|
||||
assert outputs is not None
|
||||
generated_text = outputs[0].outputs[0].text
|
||||
assert generated_text is not None
|
||||
parsed_json = json.loads(generated_text)
|
||||
assert isinstance(parsed_json, dict)
|
||||
jsonschema.validate(instance=parsed_json, schema=schema)
|
||||
return parsed_json
|
||||
|
||||
base_generated = generate_with_backend("guidance", False)
|
||||
assert "a1" in base_generated
|
||||
assert "a2" in base_generated
|
||||
assert "a3" in base_generated
|
||||
# by default additional keys are generated
|
||||
assert "a4" in base_generated
|
||||
assert "a5" in base_generated
|
||||
assert "a6" in base_generated
|
||||
|
||||
generated = generate_with_backend("guidance", True)
|
||||
assert "a1" in generated
|
||||
assert "a2" in generated
|
||||
assert "a3" in generated
|
||||
assert "a4" not in generated
|
||||
assert "a5" not in generated
|
||||
assert "a6" not in generated
|
@ -4,43 +4,11 @@
|
||||
import sys
|
||||
from contextlib import nullcontext
|
||||
|
||||
import pytest
|
||||
from vllm_test_utils import BlameResult, blame
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch):
|
||||
"""
|
||||
V1 only supports xgrammar so this is irrelevant.
|
||||
"""
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
|
||||
|
||||
def run_normal_opt125m():
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
# Create an LLM without guided decoding as a baseline.
|
||||
llm = LLM(model="facebook/opt-125m",
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.3)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
# Destroy the LLM object and free up the GPU memory.
|
||||
del llm
|
||||
cleanup_dist_env_and_memory()
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
|
||||
|
||||
def run_normal():
|
||||
@ -67,20 +35,22 @@ def run_normal():
cleanup_dist_env_and_memory()


def run_lmfe(sample_regex):
def run_xgrammar(sample_regex):
# Create an LLM with guided decoding enabled.
llm = LLM(model="distilbert/distilgpt2",
enforce_eager=True,
guided_decoding_backend="lm-format-enforcer",
guided_decoding_backend="xgrammar",
gpu_memory_utilization=0.3)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
prompt = f"Give an example IPv4 address with this regex: {sample_regex}"
guided_decoding = GuidedDecodingParams(regex=sample_regex)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=guided_decoding)
outputs = llm.generate(
prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
prompts=[prompt] * 2,
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))
)

for output in outputs:
prompt = output.prompt
@ -103,7 +73,7 @@ def test_lazy_outlines(sample_regex):
lambda: module_name in sys.modules) if use_blame else nullcontext()
with context as result:
run_normal()
run_lmfe(sample_regex)
run_xgrammar(sample_regex)
if use_blame:
assert isinstance(result, BlameResult)
print(f"the first import location is:\n{result.trace_stack}")
@ -488,7 +488,9 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guided_choice_chat(client: openai.AsyncOpenAI,
|
||||
sample_guided_choice):
|
||||
sample_guided_choice, is_v1_server: bool):
|
||||
if not is_v1_server:
|
||||
pytest.skip("Guided decoding is only supported in v1 engine")
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
@ -524,8 +526,10 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guided_json_chat(client: openai.AsyncOpenAI,
|
||||
sample_json_schema):
|
||||
async def test_guided_json_chat(client: openai.AsyncOpenAI, sample_json_schema,
|
||||
is_v1_server: bool):
|
||||
if not is_v1_server:
|
||||
pytest.skip("Guided decoding is only supported in v1 engine")
|
||||
|
||||
messages = [{
|
||||
"role": "system",
|
||||
@ -568,7 +572,10 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex):
|
||||
async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex,
|
||||
is_v1_server: bool):
|
||||
if not is_v1_server:
|
||||
pytest.skip("Guided decoding is only supported in v1 engine")
|
||||
|
||||
messages = [{
|
||||
"role": "system",
|
||||
@ -653,7 +660,10 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema):
|
||||
async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema,
|
||||
is_v1_server: bool):
|
||||
if not is_v1_server:
|
||||
pytest.skip("Tool use is only supported in v1 engine")
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
@ -741,131 +751,6 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema):
|
||||
assert json1["age"] != json2["age"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_required_tool_use(client: openai.AsyncOpenAI,
|
||||
is_v1_server: bool, model_name: str):
|
||||
if is_v1_server:
|
||||
pytest.skip(
|
||||
"tool_choice='required' requires features unsupported on V1")
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The city to find the weather for, e.g. 'Vienna'",
|
||||
"default": "Vienna",
|
||||
},
|
||||
"country": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The country that the city is in, e.g. 'Austria'",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["country", "unit"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_forecast",
|
||||
"description": "Get the weather forecast for a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The city to get the forecast for, e.g. 'Vienna'",
|
||||
"default": "Vienna",
|
||||
},
|
||||
"country": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The country that the city is in, e.g. 'Austria'",
|
||||
},
|
||||
"days": {
|
||||
"type":
|
||||
"integer",
|
||||
"description":
|
||||
"Number of days to get the forecast for (1-7)",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["country", "days", "unit"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi! How are you doing today?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'm doing well! How can I help you?"
|
||||
},
|
||||
{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Can you tell me what the current weather is in Berlin and the "\
|
||||
"forecast for the next 5 days, in fahrenheit?",
|
||||
},
|
||||
]
|
||||
|
||||
# Non-streaming test
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=model_name,
|
||||
tools=tools,
|
||||
tool_choice="required",
|
||||
)
|
||||
|
||||
assert chat_completion.choices[0].message.tool_calls is not None
|
||||
assert len(chat_completion.choices[0].message.tool_calls) > 0
|
||||
|
||||
# Streaming test
|
||||
stream = await client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=model_name,
|
||||
tools=tools,
|
||||
tool_choice="required",
|
||||
stream=True,
|
||||
)
|
||||
|
||||
output = []
|
||||
async for chunk in stream:
|
||||
if chunk.choices and chunk.choices[0].delta.tool_calls:
|
||||
output.extend(chunk.choices[0].delta.tool_calls)
|
||||
|
||||
assert len(output) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI,
|
||||
sample_json_schema):
|
||||
@ -948,7 +833,11 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_format_json_schema(client: openai.AsyncOpenAI):
|
||||
async def test_response_format_json_schema(client: openai.AsyncOpenAI,
|
||||
is_v1_server: bool):
|
||||
if not is_v1_server:
|
||||
pytest.skip(
|
||||
"JSON schema response format is only supported in v1 engine")
|
||||
prompt = 'what is 1+1? The format is "result": 2'
|
||||
# Check that this prompt cannot lead to a valid JSON without json_schema
|
||||
for _ in range(2):
|
||||
|
@ -28,7 +28,7 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# but we're not testing generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"

GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
GUIDED_DECODING_BACKENDS = ["outlines", "xgrammar", "guidance"]


@pytest.fixture(scope="module")
@ -95,6 +95,14 @@ def server(default_server_args, request):
os.environ['VLLM_USE_V1'] = original_value


@pytest.fixture
def is_v1_server(server):
import os

# For completion tests, we assume v0 since there's no explicit v1 setup
return os.environ.get('VLLM_USE_V1', '0') == '1'


@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
@ -631,7 +639,10 @@ async def test_allowed_token_ids(client: openai.AsyncOpenAI):
|
||||
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||
async def test_guided_json_completion(client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str,
|
||||
sample_json_schema):
|
||||
sample_json_schema, is_v1_server: bool):
|
||||
if not is_v1_server:
|
||||
pytest.skip("Guided decoding is only supported in v1 engine")
|
||||
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=f"Give an example JSON for an employee profile "
|
||||
@ -653,7 +664,10 @@ async def test_guided_json_completion(client: openai.AsyncOpenAI,
|
||||
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||
async def test_guided_regex_completion(client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str,
|
||||
sample_regex):
|
||||
sample_regex, is_v1_server: bool):
|
||||
if not is_v1_server:
|
||||
pytest.skip("Guided decoding is only supported in v1 engine")
|
||||
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=f"Give an example IPv4 address with this regex: {sample_regex}",
|
||||
@ -674,7 +688,11 @@ async def test_guided_regex_completion(client: openai.AsyncOpenAI,
|
||||
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||
async def test_guided_choice_completion(client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str,
|
||||
sample_guided_choice):
|
||||
sample_guided_choice,
|
||||
is_v1_server: bool):
|
||||
if not is_v1_server:
|
||||
pytest.skip("Guided decoding is only supported in v1 engine")
|
||||
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt="The best language for type-safe systems programming is ",
|
||||
@ -692,7 +710,9 @@ async def test_guided_choice_completion(client: openai.AsyncOpenAI,
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guided_grammar(client: openai.AsyncOpenAI,
|
||||
sample_sql_statements):
|
||||
sample_sql_statements, is_v1_server: bool):
|
||||
if not is_v1_server:
|
||||
pytest.skip("Guided grammar is only supported in v1 engine")
|
||||
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
@ -754,7 +774,11 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
|
||||
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||
async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str,
|
||||
sample_json_schema, sample_regex):
|
||||
sample_json_schema, sample_regex,
|
||||
is_v1_server: bool):
|
||||
if not is_v1_server:
|
||||
pytest.skip("Guided decoding is only supported in v1 engine")
|
||||
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
_ = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
|
@ -9,6 +9,11 @@ import regex as re
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v1_only(monkeypatch):
|
||||
monkeypatch.setenv('VLLM_USE_V1', '1')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_prompt():
|
||||
model_name = "gpt2"
|
||||
@ -37,24 +42,3 @@ async def test_out_of_vocab_token_ids():
|
||||
prompt=[999999],
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reject_multistep_with_guided_decoding():
|
||||
model_name = "gpt2"
|
||||
server_args = ["--enforce-eager", "--num-scheduler-steps", "8"]
|
||||
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||
client = remote_server.get_async_client()
|
||||
|
||||
with pytest.raises(
|
||||
openai.BadRequestError,
|
||||
match=re.compile(
|
||||
'.*Guided decoding .* multi-step decoding.*').pattern):
|
||||
await client.completions.create(
|
||||
model=model_name,
|
||||
prompt="Hello",
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
extra_body={"response_format": {
|
||||
"type": "json_object"
|
||||
}})
|
||||
|
@ -1,207 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import pickle
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.model_executor.guided_decoding import (
|
||||
get_guided_decoding_logits_processor,
|
||||
get_local_guided_decoding_logits_processor)
|
||||
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
|
||||
JSONLogitsProcessor, RegexLogitsProcessor)
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
|
||||
MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
|
||||
GUIDED_DECODING_BACKENDS = [
|
||||
"outlines", "lm-format-enforcer", "xgrammar", "guidance"
|
||||
]
|
||||
GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"]
|
||||
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
|
||||
|
||||
|
||||
# Initialize the tokenizer for the model here to avoid repeated loading
|
||||
@pytest.fixture(scope="module")
|
||||
def zephyr_7B_tokenzer():
|
||||
return AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def deepseek_r1_qwen_tokenizer():
|
||||
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
|
||||
|
||||
|
||||
def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex,
|
||||
sample_json_schema):
|
||||
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
|
||||
regex_LP = RegexLogitsProcessor(sample_regex,
|
||||
zephyr_7B_tokenzer,
|
||||
reasoner=None)
|
||||
json_LP = JSONLogitsProcessor(sample_json_schema,
|
||||
zephyr_7B_tokenzer,
|
||||
whitespace_pattern=None,
|
||||
reasoner=None)
|
||||
|
||||
tensor = torch.rand(32000)
|
||||
original_tensor = torch.clone(tensor)
|
||||
tensor = regex_LP([], tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert not torch.allclose(tensor, original_tensor)
|
||||
|
||||
tensor = torch.rand(32000)
|
||||
original_tensor = torch.clone(tensor)
|
||||
tensor = json_LP([], tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert not torch.allclose(tensor, original_tensor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("backend", GUIDED_DECODING_BACKENDS)
|
||||
@pytest.mark.parametrize("is_local", [True, False])
|
||||
async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
|
||||
sample_regex,
|
||||
sample_json_schema,
|
||||
zephyr_7B_tokenzer):
|
||||
|
||||
config = ModelConfig(
|
||||
MODEL_NAME,
|
||||
runner="generate",
|
||||
seed=0,
|
||||
dtype="bfloat16",
|
||||
)
|
||||
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
|
||||
|
||||
regex_lp = get_local_guided_decoding_logits_processor(
|
||||
regex_request, zephyr_7B_tokenzer, config) if is_local else \
|
||||
await get_guided_decoding_logits_processor(
|
||||
regex_request, zephyr_7B_tokenzer, config)
|
||||
assert regex_lp is not None
|
||||
tensor = torch.rand(32000)
|
||||
original_tensor = torch.clone(tensor)
|
||||
# allowed tokens at state 0
|
||||
tensor = regex_lp([], tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert not torch.allclose(tensor, original_tensor)
|
||||
|
||||
json_request = GuidedDecodingParams(json=sample_json_schema,
|
||||
backend=backend)
|
||||
json_lp = await get_guided_decoding_logits_processor(
|
||||
json_request, zephyr_7B_tokenzer, config)
|
||||
assert json_lp is not None
|
||||
tensor = torch.rand(32000)
|
||||
original_tensor = torch.clone(tensor)
|
||||
tensor = json_lp([], tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert not torch.allclose(tensor, original_tensor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("backend",
|
||||
GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT)
|
||||
@pytest.mark.parametrize("is_local", [True, False])
|
||||
@pytest.mark.parametrize("reasoning_backend", ["deepseek_r1"])
|
||||
async def test_guided_logits_processor_with_reasoning(
|
||||
backend: str, is_local: bool, reasoning_backend: str, sample_regex,
|
||||
sample_json_schema, deepseek_r1_qwen_tokenizer):
|
||||
|
||||
config = ModelConfig(
|
||||
REASONING_MODEL_NAME,
|
||||
runner="generate",
|
||||
seed=0,
|
||||
dtype="bfloat16",
|
||||
)
|
||||
token_ids = deepseek_r1_qwen_tokenizer.encode(
|
||||
"<think>here is the thinking process")
|
||||
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
|
||||
|
||||
regex_lp = get_local_guided_decoding_logits_processor(regex_request,
|
||||
deepseek_r1_qwen_tokenizer, config,
|
||||
reasoning_backend) if is_local else \
|
||||
await get_guided_decoding_logits_processor(
|
||||
regex_request, deepseek_r1_qwen_tokenizer, config,
|
||||
reasoning_backend)
|
||||
assert regex_lp is not None
|
||||
tensor = torch.rand(151664)
|
||||
original_tensor = torch.clone(tensor)
|
||||
tensor = regex_lp(token_ids, tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert torch.allclose(tensor, original_tensor)
|
||||
|
||||
token_ids = deepseek_r1_qwen_tokenizer.encode(
|
||||
"<think>here is the thinking process")
|
||||
json_request = GuidedDecodingParams(json=sample_json_schema,
|
||||
backend=backend)
|
||||
json_lp = get_local_guided_decoding_logits_processor(
|
||||
json_request, deepseek_r1_qwen_tokenizer, config,
|
||||
reasoning_backend) if is_local else \
|
||||
await get_guided_decoding_logits_processor(
|
||||
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
|
||||
assert json_lp is not None
|
||||
tensor = torch.rand(151664)
|
||||
original_tensor = torch.clone(tensor)
|
||||
tensor = json_lp(token_ids, tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert torch.allclose(tensor, original_tensor)
|
||||
|
||||
# Thinking is over, so the tensor should change.
|
||||
token_ids = deepseek_r1_qwen_tokenizer.encode(
|
||||
"<think>here is the thinking process</think>")
|
||||
json_request = GuidedDecodingParams(json=sample_json_schema,
|
||||
backend=backend)
|
||||
json_lp = get_local_guided_decoding_logits_processor(
|
||||
json_request, deepseek_r1_qwen_tokenizer, config,
|
||||
reasoning_backend) if is_local else \
|
||||
await get_guided_decoding_logits_processor(
|
||||
json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend)
|
||||
assert json_lp is not None
|
||||
tensor = torch.rand(151664)
|
||||
original_tensor = torch.clone(tensor)
|
||||
tensor = json_lp(token_ids, tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert not torch.allclose(tensor, original_tensor)
|
||||
|
||||
|
||||
def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
|
||||
with pytest.raises(ValueError,
|
||||
match="You can only use one kind of guided"):
|
||||
GuidedDecodingParams(json=sample_json_schema, regex=sample_regex)
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
match="You can only use one kind of guided"):
|
||||
GuidedDecodingParams(json=sample_json_schema, json_object=True)
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
match="You can only use one kind of guided"):
|
||||
GuidedDecodingParams(json=sample_json_schema, choice=["a", "b"])
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
match="You can only use one kind of guided"):
|
||||
GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")
|
||||
|
||||
|
||||
def test_pickle_xgrammar_tokenizer_data():
|
||||
try:
|
||||
import xgrammar as xgr
|
||||
except ImportError:
|
||||
pytest.skip("Could not import xgrammar to run test")
|
||||
|
||||
from vllm.model_executor.guided_decoding.xgrammar_decoding import (
|
||||
TokenizerData)
|
||||
tokenizer_data = TokenizerData(
|
||||
metadata=
|
||||
'{"vocab_type":2,"vocab_size":151665,"add_prefix_space":false,"stop_token_ids":[151645]}',
|
||||
encoded_vocab=['!', '"', '#', '$', '%'],
|
||||
)
|
||||
pickled = pickle.dumps(tokenizer_data)
|
||||
|
||||
assert pickled is not None
|
||||
|
||||
depickled: TokenizerData = pickle.loads(pickled)
|
||||
|
||||
assert depickled is not None
|
||||
assert json.loads(
|
||||
depickled.metadata)['vocab_type'] == xgr.VocabType.BYTE_LEVEL.value
|
@ -3,13 +3,11 @@
|
||||
import copy
|
||||
import json
|
||||
|
||||
import jsonschema
|
||||
import jsonschema.exceptions
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
|
||||
MistralToolCall, MistralToolParser)
|
||||
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||
|
||||
from ...utils import check_logprobs_close
|
||||
@ -274,53 +272,6 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
|
||||
assert parsed_message.content is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("guided_backend",
|
||||
["outlines", "lm-format-enforcer", "xgrammar"])
|
||||
def test_mistral_guided_decoding(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
vllm_runner,
|
||||
model: str,
|
||||
guided_backend: str,
|
||||
) -> None:
|
||||
with monkeypatch.context() as m:
|
||||
# Guided JSON not supported in xgrammar + V1 yet
|
||||
m.setenv("VLLM_USE_V1", "0")
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype='bfloat16',
|
||||
tokenizer_mode="mistral",
|
||||
guided_decoding_backend=guided_backend,
|
||||
) as vllm_model:
|
||||
guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA)
|
||||
params = SamplingParams(max_tokens=512,
|
||||
temperature=0.7,
|
||||
guided_decoding=guided_decoding)
|
||||
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
}, {
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
f"Give an example JSON for an employee profile that "
|
||||
f"fits this schema: {SAMPLE_JSON_SCHEMA}"
|
||||
}]
|
||||
outputs = vllm_model.llm.chat(messages, sampling_params=params)
|
||||
|
||||
generated_text = outputs[0].outputs[0].text
|
||||
json_response = json.loads(generated_text)
|
||||
assert outputs is not None
|
||||
|
||||
try:
|
||||
jsonschema.validate(instance=json_response,
|
||||
schema=SAMPLE_JSON_SCHEMA)
|
||||
except jsonschema.exceptions.ValidationError:
|
||||
pytest.fail("Generated response is not valid with JSON schema")
|
||||
|
||||
|
||||
def test_mistral_function_call_nested_json():
|
||||
"""Ensure that the function-name regex captures the entire outer-most
|
||||
JSON block, including nested braces."""
|
||||
|
@ -14,9 +14,9 @@ from vllm import LLM, SamplingParams


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
"""We can run both engines for this test."""
pass
def v1(monkeypatch):
"""Only run on vLLM v1."""
monkeypatch.setenv('VLLM_USE_V1', '1')


def _generate(
@ -56,8 +56,7 @@ def test_sampling_params_from_request_with_no_guided_decoding_backend(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("request_level_guided_decoding_backend,expected",
|
||||
[("xgrammar", "xgrammar"),
|
||||
("lm-format-enforcer", "lm-format-enforcer"),
|
||||
[("xgrammar", "xgrammar"), ("guidance", "guidance"),
|
||||
("outlines", "outlines")])
|
||||
def test_sampling_params_from_request_with_guided_decoding_backend(
|
||||
request_level_guided_decoding_backend: str, expected: str,
|
||||
|
@ -47,13 +47,6 @@ def test_unsupported_configs(monkeypatch):
|
||||
},
|
||||
).create_engine_config()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
AsyncEngineArgs(
|
||||
model=MODEL,
|
||||
guided_decoding_backend="lm-format-enforcer",
|
||||
guided_decoding_disable_fallback=True,
|
||||
).create_engine_config()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
AsyncEngineArgs(
|
||||
model=MODEL,
|
||||
|
@ -34,7 +34,6 @@ ALLOWED_FILES = set([
|
||||
'vllm/model_executor/models/registry.py',
|
||||
'tests/test_utils.py',
|
||||
'tests/tokenization/test_cached_tokenizer.py',
|
||||
'tests/model_executor/test_guided_processors.py',
|
||||
'vllm/distributed/utils.py',
|
||||
'vllm/distributed/parallel_state.py',
|
||||
'vllm/engine/multiprocessing/client.py',
|
||||
|
@ -1774,8 +1774,8 @@ class CacheConfig:
|
||||
- "builtin" is Python's built-in hash.\n
|
||||
- "sha256" is collision resistant but with certain overheads.
|
||||
This option uses Pickle for object serialization before hashing.\n
|
||||
- "sha256_cbor_64bit" provides a reproducible, cross-language compatible
|
||||
hash. It serializes objects using canonical CBOR and hashes them with
|
||||
- "sha256_cbor_64bit" provides a reproducible, cross-language compatible
|
||||
hash. It serializes objects using canonical CBOR and hashes them with
|
||||
SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256
|
||||
digest."""
|
||||
cpu_offload_gb: float = 0
|
||||
@ -3721,12 +3721,7 @@ def get_served_model_name(model: str,
return served_model_name


GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
"xgrammar", "guidance"]

GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance", "outlines"]
GuidedDecodingBackend = Literal[GuidedDecodingBackendV0,
GuidedDecodingBackendV1]
GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines"]


@config
@ -3734,7 +3729,7 @@ GuidedDecodingBackend = Literal[GuidedDecodingBackendV0,
class DecodingConfig:
"""Dataclass which contains the decoding strategy of the engine."""

backend: GuidedDecodingBackend = "auto" if envs.VLLM_USE_V1 else "xgrammar"
backend: GuidedDecodingBackend = "auto"
"""Which engine will be used for guided decoding (JSON schema / regex etc)
by default. With "auto", we will make opinionated choices based on request
contents and what the backend libraries currently support, so the behavior
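As a concrete illustration of the consolidated backend type above, the following hedged sketch selects one of the remaining backends explicitly at engine construction. The keyword arguments mirror ones already used elsewhere in this diff; the model name and choice list are placeholder assumptions.

# Hedged sketch: explicitly choosing a guided-decoding backend kept by this
# change ("auto", "xgrammar", "guidance", "outlines"). Model and choices are
# illustrative placeholders.
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    guided_decoding_backend="xgrammar",  # or "auto", "guidance", "outlines"
)
sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=16,
    guided_decoding=GuidedDecodingParams(choice=["Python", "Rust", "C++"]),
)
outputs = llm.generate(
    ["The best language for type-safe systems programming is "],
    sampling_params,
)
print(outputs[0].outputs[0].text)  # constrained to one of the listed choices

Leaving the backend at "auto" defers the choice to the engine, which is the default this change standardizes on.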
@ -3776,13 +3771,6 @@ class DecodingConfig:
return hash_str

def __post_init__(self):
if envs.VLLM_USE_V1:
valid_guided_backends = get_args(GuidedDecodingBackendV1)
else:
valid_guided_backends = get_args(GuidedDecodingBackendV0)
if self.backend not in valid_guided_backends:
raise ValueError(f"Invalid backend '{self.backend}',"
f" must be one of {valid_guided_backends}")
if (self.disable_any_whitespace
and self.backend not in ("xgrammar", "guidance")):
raise ValueError("disable_any_whitespace is only supported for "
@ -25,14 +25,14 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
|
||||
ConfigFormat, ConfigType, ConvertOption,
|
||||
DecodingConfig, DetailedTraceModules, Device,
|
||||
DeviceConfig, DistributedExecutorBackend,
|
||||
GuidedDecodingBackend, GuidedDecodingBackendV1,
|
||||
HfOverrides, KVEventsConfig, KVTransferConfig,
|
||||
LoadConfig, LogprobsMode, LoRAConfig, ModelConfig,
|
||||
ModelDType, ModelImpl, MultiModalConfig,
|
||||
ObservabilityConfig, ParallelConfig, PoolerConfig,
|
||||
PrefixCachingHashAlgo, RunnerOption, SchedulerConfig,
|
||||
SchedulerPolicy, SpeculativeConfig, TaskOption,
|
||||
TokenizerMode, VllmConfig, get_attr_docs, get_field)
|
||||
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
|
||||
KVTransferConfig, LoadConfig, LogprobsMode,
|
||||
LoRAConfig, ModelConfig, ModelDType, ModelImpl,
|
||||
MultiModalConfig, ObservabilityConfig, ParallelConfig,
|
||||
PoolerConfig, PrefixCachingHashAlgo, RunnerOption,
|
||||
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
|
||||
TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
|
||||
get_field)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.plugins import load_general_plugins
|
||||
@ -1343,14 +1343,6 @@ class EngineArgs:
|
||||
recommend_to_remove=True)
|
||||
return False
|
||||
|
||||
if self.guided_decoding_backend not in get_args(
|
||||
GuidedDecodingBackendV1):
|
||||
_raise_or_fallback(
|
||||
feature_name=
|
||||
f"--guided-decoding-backend={self.guided_decoding_backend}",
|
||||
recommend_to_remove=False)
|
||||
return False
|
||||
|
||||
# Need at least Ampere for now (FA support required).
|
||||
# Skip this check if we are running on a non-GPU platform,
|
||||
# or if the device capability is not available
|
||||
|
@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import copy
import time
import weakref
from functools import partial
@ -24,8 +23,6 @@ from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
@ -469,19 +466,6 @@ class _AsyncLLMEngine(LLMEngine):
tokenization_kwargs=tokenization_kwargs,
)

if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
# Guided decoding has an async implementation for building logits
# processors in a separate threadpool.
# We want to invoke that here instead of using the blocking
# implementation in the LLMEngine
params = await build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=await self.get_tokenizer_async(lora_request),
default_guided_backend=self.decoding_config.backend,
reasoning_backend=self.decoding_config.reasoning_backend,
model_config=self.model_config)

self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
@ -503,48 +487,6 @@ class _AsyncLLMEngine(LLMEngine):
raise NotImplementedError


async def build_guided_decoding_logits_processor_async(
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
default_guided_backend: str, reasoning_backend: Optional[str],
model_config: ModelConfig) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. Modifies sampling params in-place and returns
the modified sampling params."""
if sampling_params.guided_decoding is None:
return sampling_params

# Defensively copy sampling params since guided decoding logits
# processors can have different state for each request
sampling_params = copy.copy(sampling_params)
guided_decoding = sampling_params.guided_decoding

logger.debug(
"Building guided decoding logits processor. "
"guided_decoding: %s%s", guided_decoding,
f", reasoning_backend: {reasoning_backend}"
if reasoning_backend is not None else "")

guided_decoding.backend = guided_decoding.backend or default_guided_backend

processor = await get_guided_decoding_logits_processor(
guided_params=guided_decoding,
tokenizer=tokenizer,
reasoning_backend=reasoning_backend,
model_config=model_config)

if processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(processor)

# Unset guided decoding params after constructing the lp from them
sampling_params.guided_decoding = None

return sampling_params


class AsyncLLMEngine(EngineClient):
"""An asynchronous wrapper for [`LLMEngine`][vllm.LLMEngine].

@ -1028,7 +970,7 @@ class AsyncLLMEngine(EngineClient):
```
# Please refer to entrypoints/api_server.py for
# the complete example.

# initialize the engine and the example input
# note that engine_args here is AsyncEngineArgs instance
engine = AsyncLLMEngine.from_engine_args(engine_args)
@ -1036,13 +978,13 @@
"input": "What is LLM?",
"request_id": 0,
}

# start the generation
results_generator = engine.encode(
example_input["input"],
PoolingParams(),
example_input["request_id"])

# get the results
final_output = None
async for request_output in results_generator:
@ -1052,7 +994,7 @@ class AsyncLLMEngine(EngineClient):
# Return or raise an error
...
final_output = request_output

# Process and return the final output
...
```
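The `build_guided_decoding_logits_processor_async` helper removed above existed so that the expensive grammar/FSM compilation did not block the serving event loop. A minimal, generic sketch of that pattern under stated assumptions (`compile_grammar` is a placeholder name, not a vLLM API):

```python
# Sketch of the off-the-event-loop compilation pattern the deleted V0 helper
# relied on. The names here are placeholders used only for illustration.
import asyncio
from concurrent.futures import ThreadPoolExecutor

_pool = ThreadPoolExecutor(max_workers=4)


def compile_grammar(spec: str) -> str:
    # Stand-in for CPU-heavy work (e.g. building an FSM from a JSON schema).
    return f"compiled({spec})"


async def build_processor_async(spec: str) -> str:
    loop = asyncio.get_running_loop()
    # Run the blocking compilation in a worker thread so the event loop keeps
    # accepting requests, mirroring the removed V0 code path.
    return await loop.run_in_executor(_pool, compile_grammar, spec)


if __name__ == "__main__":
    print(asyncio.run(build_processor_async('{"type": "object"}')))
```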
@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
import time
from collections import Counter as collectionsCounter
from collections import deque
@ -36,8 +35,6 @@ from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.processing import EncDecMultiModalProcessor
@ -686,11 +683,10 @@ class LLMEngine:
"Priority scheduling is not enabled.")

if isinstance(params, SamplingParams) \
and (params.guided_decoding or params.logits_processors) \
and params.logits_processors \
and self.scheduler_config.num_scheduler_steps > 1:
raise ValueError(
"Guided decoding and logits processors are not supported "
"in multi-step decoding")
"Logits processors are not supported in multi-step decoding")

if arrival_time is None:
arrival_time = time.time()
@ -1226,7 +1222,7 @@ class LLMEngine:
engine = LLMEngine.from_engine_args(engine_args)
example_inputs = [(0, "What is LLM?",
SamplingParams(temperature=0.0))]

# Start the engine with an event loop
while True:
if example_inputs:
@ -1983,43 +1979,13 @@ class LLMEngine:
def _build_logits_processors(
self, sampling_params: SamplingParams,
lora_request: Optional[LoRARequest]) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. Returns the modified sampling params."""
"""Constructs logits processors based on the logits_bias, and
allowed_token_ids fields in sampling_params. Deletes those fields and
adds the constructed logits processors to the logits_processors field.
Returns the modified sampling params."""

logits_processors = []

if sampling_params.guided_decoding is not None:
# Defensively copy sampling params since guided decoding logits
# processors can have different state for each request
sampling_params = copy.copy(sampling_params)
guided_decoding = sampling_params.guided_decoding

logger.debug(
"Building guided decoding logits processor in "
"LLMEngine. Params: %s", guided_decoding)

tokenizer = self.get_tokenizer(lora_request=lora_request)
guided_decoding.backend = guided_decoding.backend or \
self.decoding_config.backend

if self.decoding_config.reasoning_backend:
logger.debug("Building with reasoning backend %s",
self.decoding_config.reasoning_backend)

processor = get_local_guided_decoding_logits_processor(
guided_params=guided_decoding,
tokenizer=tokenizer,
model_config=self.model_config,
reasoning_backend=self.decoding_config.reasoning_backend,
)
if processor:
logits_processors.append(processor)

# Unset so this doesn't get passed down to the model
sampling_params.guided_decoding = None

if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
tokenizer = self.get_tokenizer(lora_request=lora_request)

@ -20,8 +20,6 @@ from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.async_llm_engine import (
build_guided_decoding_logits_processor_async)
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
@ -537,22 +535,6 @@ class MQLLMEngineClient(EngineClient):
if request_id in self.output_queues:
raise ValueError(f"Request {request_id} already exists")

# Constructing guided decoding logits processors is expensive, so we do
# it here to avoid contending with cpu resources and the GIL on the
# backend process.
if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
params = await \
build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=await self.get_tokenizer(lora_request),
default_guided_backend=(self.decoding_config.backend
if self.decoding_config
else DecodingConfig.backend),
model_config=self.model_config,
reasoning_backend=self.decoding_config.reasoning_backend,
)

# 1) Create output queue for this requests.
queue: asyncio.Queue[Union[RequestOutput,
BaseException]] = asyncio.Queue()
@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools
import warnings
from collections.abc import Sequence
from contextlib import contextmanager
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
@ -40,15 +39,13 @@ from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest, LLMGuidedOptions)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
PoolingRequestOutput, RequestOutput,
ScoringRequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
SamplingParams)
from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
@ -330,8 +327,6 @@ class LLM:
*,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
...

@ -345,8 +340,6 @@ class LLM:
prompt_token_ids: Optional[list[int]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
...

@ -360,8 +353,6 @@ class LLM:
prompt_token_ids: Optional[list[list[int]]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
...

@ -376,8 +367,6 @@ class LLM:
prompt_token_ids: list[int],
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
...

@ -392,8 +381,6 @@ class LLM:
prompt_token_ids: list[list[int]],
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
...

@ -406,8 +393,6 @@ class LLM:
prompt_token_ids: Union[list[int], list[list[int]]],
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
) -> list[RequestOutput]:
...

@ -425,8 +410,6 @@ class LLM:
prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
priority: Optional[list[int]] = None,
) -> list[RequestOutput]:
"""Generates the completions for the input prompts.
@ -478,14 +461,6 @@ class LLM:
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)

if isinstance(guided_options_request, dict):
if len(guided_options_request) > 1:
raise ValueError(
"You can only use one guided decoding but multiple is "
f"specified: {guided_options_request}")
guided_options_request = GuidedDecodingRequest(
**guided_options_request)

if sampling_params is None:
# Use default sampling params.
sampling_params = self.get_default_sampling_params()
@ -507,7 +482,6 @@
params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
guided_options=guided_options_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
)
@ -1361,17 +1335,17 @@ class LLM:
of your inputs into a single list and pass it to this method.

Supports both text and multi-modal data (images, etc.) when used with
appropriate multi-modal models. For multi-modal inputs, ensure the
appropriate multi-modal models. For multi-modal inputs, ensure the
prompt structure matches the model's expected input format.

Args:
data_1: Can be a single prompt, a list of prompts or
`ScoreMultiModalParam`, which can contain either text or
multi-modal data. When a list, it must have the same length as
data_1: Can be a single prompt, a list of prompts or
`ScoreMultiModalParam`, which can contain either text or
multi-modal data. When a list, it must have the same length as
the `data_2` list.
data_2: The data to pair with the query to form the input to
data_2: The data to pair with the query to form the input to
the LLM. Can be text or multi-modal data. See [PromptType]
[vllm.inputs.PromptType] for more details about the format of
[vllm.inputs.PromptType] for more details about the format of
each prompt.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
@ -1582,17 +1556,8 @@ class LLM:
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
tokenization_kwargs: Optional[dict[str, Any]] = None,
guided_options: Optional[GuidedDecodingRequest] = None,
priority: Optional[list[int]] = None,
) -> None:
if guided_options is not None:
warnings.warn(
"guided_options_request is deprecated, use "
"SamplingParams.guided_decoding instead",
DeprecationWarning,
stacklevel=2,
)

if isinstance(prompts, (str, dict)):
# Convert a single prompt to a list.
prompts = [prompts]
@ -1608,8 +1573,6 @@ class LLM:

for sp in params if isinstance(params, Sequence) else (params, ):
if isinstance(sp, SamplingParams):
self._add_guided_params(sp, guided_options)

# We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY

@ -1647,29 +1610,6 @@ class LLM:
priority=priority,
)

def _add_guided_params(
self,
params: SamplingParams,
guided_options: Optional[GuidedDecodingRequest] = None):
if guided_options is None:
return params

if params.guided_decoding is not None:
raise ValueError("Cannot set both guided_options_request and "
"params.guided_decoding.")

params.guided_decoding = GuidedDecodingParams(
json=guided_options.guided_json,
regex=guided_options.guided_regex,
choice=guided_options.guided_choice,
grammar=guided_options.guided_grammar,
json_object=guided_options.guided_json_object,
backend=guided_options.guided_decoding_backend,
whitespace_pattern=guided_options.guided_whitespace_pattern,
structural_tag=guided_options.structural_tag,
)
return params

def _run_engine(
self,
*,
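With `guided_options_request` and the `_add_guided_params` shim gone, the request-level equivalent is the `guided_decoding` field on `SamplingParams`, which is exactly what the removed code mapped those options onto. A hedged migration sketch; the model id and schema below are illustrative assumptions, not part of this change:

```python
# Migration sketch: guided_options_request -> SamplingParams.guided_decoding.
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Hypothetical schema used only for illustration.
schema = {"type": "object", "properties": {"answer": {"type": "string"}}}

sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=128,
    guided_decoding=GuidedDecodingParams(json=schema),
)

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # assumed model id
outputs = llm.generate(["Reply with a JSON object."], sampling_params)
print(outputs[0].outputs[0].text)
```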
@ -1,192 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

from typing import TYPE_CHECKING

from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
from vllm.reasoning import ReasoningParserManager

if TYPE_CHECKING:
from transformers import PreTrainedTokenizer

from vllm.config import ModelConfig
from vllm.logits_process import LogitsProcessor
from vllm.sampling_params import GuidedDecodingParams

logger = init_logger(__name__)


def maybe_backend_fallback(
guided_params: GuidedDecodingParams) -> GuidedDecodingParams:

def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
fallback: str) -> None:
"""Change the backend to the specified fallback with a warning log,
or raise a ValueError if the `disable_fallback` option is specified."""
if guided_params.disable_fallback:
raise ValueError(message)

logger.warning("%s Falling back to use %s instead.", message, fallback)
guided_params.backend = fallback

# `auto` was added for V1 to explicitly declare a mode that has fallbacks
# in place. If that is specified with V0, treat it as `xgrammar`, as we have
# fallbacks enabled for that and it is the V0 default.
if guided_params.backend == "auto":
guided_params.backend = "xgrammar"

# lm-format-enforce doesn't support grammar, fallback to xgrammar
if guided_params.backend == "lm-format-enforcer":
if guided_params.grammar is not None:
fallback_or_error(
guided_params,
"lm-format-enforcer does not support grammar guided decoding.",
"xgrammar")

# lm-format-enforcer doesn't support some JSON schema features
elif (guided_params.json is not None
and has_lmf_unsupported_json_features(guided_params.json)):
fallback_or_error(
guided_params,
"lm-format-enforcer does not support advanced JSON schema "
"features like patterns or numeric ranges.", "outlines")

if guided_params.backend == "xgrammar":
from vllm.model_executor.guided_decoding.xgrammar_decoding import (
xgr_installed)

# xgrammar doesn't support some JSON schema features
if (guided_params.json is not None and
has_xgrammar_unsupported_json_features(guided_params.json)):
fallback_or_error(
guided_params,
"xgrammar does not support advanced JSON schema features like "
"string length, item limits, or property bounds.", "outlines")

# xgrammar only supports GBNF grammars, so we must convert Lark.
# We must check if the grammar is likely Lark and if that
# grammar is convertible to GBNF
elif (guided_params.grammar is not None
and grammar_is_likely_lark(guided_params.grammar)):
try:
convert_lark_to_gbnf(guided_params.grammar)
except Exception:
fallback_or_error(
guided_params,
"xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF.", "guidance")

# If the xgrammar module cannot be imported successfully,
# we should still allow users to use guided decoding with a fallback.
elif not xgr_installed:
fallback_or_error(
guided_params,
"xgrammar module cannot be imported successfully.", "guidance")

if guided_params.backend == "outlines":
if guided_params.json_object is not None:
# outlines doesn't support json_object, fallback to guidance
fallback_or_error(guided_params,
"outlines does not support json_object.",
"guidance")
elif guided_params.grammar is not None:
# outlines grammar support has been removed, fallback to guidance
# if it is a lark-based grammar and xgrammar otherwise
if grammar_is_likely_lark(guided_params.grammar):
fallback_or_error(guided_params,
"outlines no longer supports grammars.",
"guidance")
else:
# The grammar is likely already GBNF format.
fallback_or_error(guided_params,
"outlines no longer supports grammars.",
"xgrammar")

return guided_params


async def get_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig,
reasoning_backend: str | None = None) -> LogitsProcessor | None:

reasoner = None
if reasoning_backend:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoner = reasoner_class(tokenizer)

guided_params = maybe_backend_fallback(guided_params)

if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
guided_params, tokenizer, reasoner)
if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'xgrammar':
from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config, reasoner)
if guided_params.backend == 'guidance':
from vllm.model_executor.guided_decoding.guidance_decoding import (
get_local_guidance_guided_decoding_logits_processor)
return get_local_guidance_guided_decoding_logits_processor(
guided_params, tokenizer)
raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'"
)


def get_local_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig,
reasoning_backend: str | None = None) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)

reasoner = None
if reasoning_backend:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoner = reasoner_class(tokenizer)

if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_params, tokenizer, reasoner)
if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'xgrammar':
from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config, reasoner)
if guided_params.backend == 'guidance':
from vllm.model_executor.guided_decoding.guidance_decoding import (
get_local_guidance_guided_decoding_logits_processor)
return get_local_guidance_guided_decoding_logits_processor(
guided_params, tokenizer)

raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'"
)
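The deleted `maybe_backend_fallback` above encoded V0's backend-selection policy. A simplified restatement of that decision table for readers who need the behaviour elsewhere; the function below is an illustration of the deleted logic, not part of vLLM's API, and it omits the JSON-schema feature checks:

```python
# Simplified restatement of the removed V0 fallback policy (illustrative only).
def pick_v0_backend(backend: str, *, has_grammar: bool, grammar_is_lark: bool,
                    json_object: bool, xgr_installed: bool) -> str:
    if backend == "auto":
        backend = "xgrammar"          # V0 default
    if backend == "lm-format-enforcer" and has_grammar:
        backend = "xgrammar"          # LMFE has no grammar support
    if backend == "xgrammar" and not xgr_installed:
        backend = "guidance"          # keep structured output usable
    if backend == "outlines":
        if json_object:
            backend = "guidance"      # outlines lacks json_object
        elif has_grammar:
            # outlines dropped grammars: Lark goes to guidance, GBNF to xgrammar
            backend = "guidance" if grammar_is_lark else "xgrammar"
    return backend


assert pick_v0_backend("outlines", has_grammar=True, grammar_is_lark=True,
                       json_object=False, xgr_installed=True) == "guidance"
```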
@ -1,63 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json

import llguidance
from regex import escape as regex_escape
from transformers import PreTrainedTokenizerBase

from vllm.model_executor.guided_decoding.guidance_logits_processors import (
GuidanceLogitsProcessor)
from vllm.sampling_params import GuidedDecodingParams
from vllm.v1.structured_output.backend_guidance import (
process_for_additional_properties)


def get_local_guidance_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase) -> GuidanceLogitsProcessor:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
"""

grm = ""
any_whitespace = not guided_params.disable_any_whitespace
if (guide_json := guided_params.json) is not None:
# Optionally set additionalProperties to False at the top-level
# By default, other backends do not allow additional top-level
# properties, so this makes guidance more similar to other backends
if guided_params.disable_additional_properties:
if not isinstance(guide_json, str):
guide_json = json.dumps(guide_json)
guide_json = process_for_additional_properties(guide_json)

grm = llguidance.LLMatcher.grammar_from_json_schema(
guide_json,
overrides={"whitespace_pattern": guided_params.whitespace_pattern},
defaults={
"whitespace_flexible": any_whitespace,
})
elif guided_params.json_object:
grm = llguidance.LLMatcher.grammar_from_json_schema(
'{"type": "object"}',
overrides={"whitespace_pattern": guided_params.whitespace_pattern},
defaults={
"whitespace_flexible": any_whitespace,
})
elif guided_params.regex:
grm = llguidance.grammar_from("regex", guided_params.regex)
elif guided_params.choice:
# choice just uses regex
choices = (regex_escape(str(choice))
for choice in guided_params.choice)
choices_regex = "(" + "|".join(choices) + ")"
grm = llguidance.grammar_from("regex", choices_regex)
elif guided_params.grammar:
# this supports Lark and GBNF
grm = llguidance.grammar_from("grammar", guided_params.grammar)

if grm:
return GuidanceLogitsProcessor(grm, tokenizer)

raise ValueError("Unknown guided decoding mode")
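For reference, a small sketch of the llguidance calls the deleted helper relied on (`LLMatcher.grammar_from_json_schema` and `grammar_from`). Whether these signatures still match your installed llguidance version is an assumption; the schema and choices are illustrative:

```python
# Sketch mirroring the grammar-construction calls used by the removed helper.
import json

import llguidance
from regex import escape as regex_escape

schema = {"type": "object", "properties": {"name": {"type": "string"}}}

# JSON-schema-driven grammar, whitespace-flexible as in the removed default.
json_grammar = llguidance.LLMatcher.grammar_from_json_schema(
    json.dumps(schema), defaults={"whitespace_flexible": True})

# "Choice" decoding is just a regex of escaped alternatives.
choices = ["yes", "no", "maybe"]
choice_grammar = llguidance.grammar_from(
    "regex", "(" + "|".join(regex_escape(c) for c in choices) + ")")
```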
@ -1,104 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import os
from typing import Any

import llguidance
import llguidance.hf
import llguidance.torch
import torch
from transformers import PreTrainedTokenizerBase

from vllm.logger import init_logger

logger = init_logger(__name__)


class GuidanceLogitsProcessor:
"""Base Guidance Logits Processor"""

cached_tokenizers: dict[str, Any] = {}

def __init__(
self,
grammar: str,
tokenizer: PreTrainedTokenizerBase,
) -> None:
"""Base Guidance Logits Processor

Args:
grammar (str)
grammar to guide the generation
tokenizer (PreTrainedTokenizerBase)
model's tokenizer
"""
self.grammar = grammar
self.tokenizer = tokenizer
self.tokenizer_name = tokenizer.name_or_path
self.ll_tokenizer = None
self.ll_matcher = None
self.bitmask = None
self.new_sampling = False
self.initialized = False

def clone(self) -> "GuidanceLogitsProcessor":
cloned = copy.copy(self)
if self.initialized:
cloned.ll_matcher = llguidance.LLMatcher(
self.ll_tokenizer,  # type: ignore[assignment]
self.grammar,
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
)
self.bitmask = llguidance.torch.allocate_token_bitmask(
1, self.ll_tokenizer.vocab_size)  # type: ignore[attr-defined]
return cloned

def _initialize(self):
if self.initialized:
return

ll_tokenizer = self.cached_tokenizers.get(self.tokenizer.name_or_path,
None)
if ll_tokenizer is None:
ll_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)
self.cached_tokenizers[self.tokenizer.name_or_path] = ll_tokenizer

self.ll_tokenizer = ll_tokenizer
self.ll_matcher = llguidance.LLMatcher(
self.ll_tokenizer,
self.grammar,
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
)

# create reusable bitmask
self.bitmask = llguidance.torch.allocate_token_bitmask(
1, self.ll_tokenizer.vocab_size)  # type: ignore[attr-defined]

self.initialized = True

def __call__(
self,
input_ids: list[int],
scores: torch.Tensor,
) -> torch.Tensor:
# we initialize the guidance model here
# to avoid pickling ll_tokenizer and ll_interpreter
self._initialize()

if self.new_sampling and len(input_ids) > 0:
self.ll_matcher.consume_token(  # type: ignore[attr-defined]
input_ids[-1])
err = self.ll_matcher.get_error()  # type: ignore[attr-defined]
if err:
logger.warning("Error in LLMatcher: %s", err)

llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask,
0)
llguidance.torch.apply_token_bitmask_inplace(
scores,
self.bitmask.to(scores.device))  # type: ignore[attr-defined]

self.new_sampling = True

return scores
@ -1,41 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Optional, TypedDict, Union


# These classes are deprecated, see SamplingParams
class LLMGuidedOptions(TypedDict, total=False):
guided_json: Union[dict, str]
guided_regex: str
guided_choice: list[str]
guided_grammar: str
guided_decoding_backend: str
guided_whitespace_pattern: str
guided_json_object: bool


@dataclass
class GuidedDecodingRequest:
"""One of the fields will be used to retrieve the logit processor."""
guided_json: Optional[Union[dict, str]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[list[str]] = None
guided_grammar: Optional[str] = None
guided_decoding_backend: Optional[str] = None
guided_whitespace_pattern: Optional[str] = None
guided_json_object: Optional[bool] = None
structural_tag: Optional[str] = None

def __post_init__(self):
"""Validate that some fields are mutually exclusive."""
guide_count = sum(x is not None
for x in (self.guided_json, self.guided_regex,
self.guided_choice, self.guided_grammar,
self.guided_json_object,
self.structural_tag))
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding but multiple are "
f"specified: {self.__dict__}")
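A tiny reference illustration of the validation in the removed dataclass; the import only resolves on versions prior to this change:

```python
# Works only before this commit, since guided_fields.py is deleted here.
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest)

GuidedDecodingRequest(guided_regex=r"\d+")  # one guide kind: accepted
try:
    GuidedDecodingRequest(guided_regex=r"\d+", guided_choice=["a", "b"])
except ValueError as err:
    # "You can only use one kind of guided decoding but multiple are specified: ..."
    print(err)
```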
@ -1,67 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from functools import lru_cache
from json import loads as json_loads
from typing import Optional, Union

from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
RegexParser, StringParser,
TokenEnforcerTokenizerData, UnionParser)
from lmformatenforcer.integrations.vllm import (
build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
from transformers import PreTrainedTokenizerBase

from vllm.logits_process import LogitsProcessor
from vllm.sampling_params import GuidedDecodingParams


def get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""

tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer)
character_level_parser: CharacterLevelParser
if guided_params.json:
schema_dict = _normalize_json_schema_object(guided_params.json)
character_level_parser = JsonSchemaParser(schema_dict)
elif guided_params.choice:
character_level_parser = UnionParser(
[StringParser(choice) for choice in guided_params.choice])
elif guided_params.regex:
character_level_parser = RegexParser(guided_params.regex)
elif guided_params.grammar:
# CFG grammar not supported by LMFE
raise ValueError("Cannot construct a guided decoding logits processor"
" using the grammar option with the"
" lm_format_enforcer backend.")
elif guided_params.json_object:
# None means any json object
character_level_parser = JsonSchemaParser(None)
else:
return None

logits_processor = build_vllm_logits_processor(tokenizer_data,
character_level_parser)
return logits_processor


def _normalize_json_schema_object(schema: Union[str, dict]) -> dict:
if isinstance(schema, str):
return json_loads(schema)
if isinstance(schema, dict):
return schema
raise AssertionError(f"Unsupported schema type {schema}")


@lru_cache
def _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData:
return build_vllm_token_enforcer_tokenizer_data(tokenizer)
@ -1,117 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import concurrent.futures
import os
from enum import Enum
from json import dumps as json_dumps
from typing import Optional, Union

from regex import escape as regex_escape
from transformers import PreTrainedTokenizerBase

from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams


class GuidedDecodingMode(Enum):
JSON = "json"
REGEX = "regex"
CHOICE = "choice"


global_thread_pool = None  # used for generating logits processor fsm

# It's not yet clear that using more provides a benefit, and it could
# potentially starve other processes on the machine. We'll cap this for now and
# adjust later if testing proves it to help overcome a bottleneck.
_MAX_THREADPOOL_WORKERS = 16


async def get_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser]
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
"""
global global_thread_pool
guide, mode = _get_guide_and_mode(guided_params)
if not guide or not mode:
return None

if global_thread_pool is None:
max_workers = os.cpu_count() or 2
if max_workers > _MAX_THREADPOOL_WORKERS:
max_workers = _MAX_THREADPOOL_WORKERS
global_thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers)
loop = asyncio.get_running_loop()
return await loop.run_in_executor(global_thread_pool,
_get_logits_processor, guide, tokenizer,
mode, guided_params.whitespace_pattern,
reasoner)


def get_local_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser]
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
"""
guide, mode = _get_guide_and_mode(guided_params)
if not guide or not mode:
return None

return _get_logits_processor(guide, tokenizer, mode,
guided_params.whitespace_pattern, reasoner)


def _get_guide_and_mode(
guided_params: GuidedDecodingParams
) -> Union[tuple[str, GuidedDecodingMode], tuple[None, None]]:
if guided_params.json:
if isinstance(guided_params.json, dict):
# turn dict into hashable string
json = json_dumps(guided_params.json)
else:
json = guided_params.json
return json, GuidedDecodingMode.JSON
elif guided_params.regex:
return guided_params.regex, GuidedDecodingMode.REGEX
elif guided_params.choice:
# choice just uses regex
choices = [
regex_escape(str(choice)) for choice in guided_params.choice
]
choices_regex = "(" + "|".join(choices) + ")"
return choices_regex, GuidedDecodingMode.CHOICE
elif guided_params.grammar:
raise ValueError(
"The `outlines` guided decoding backend no longer supports grammar "
"guided generation. Please use either the `xgrammar` or `guidance` "
"backend")
else:
return None, None


def _get_logits_processor(
guide: str,
tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None],
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
reasoner)
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
return RegexLogitsProcessor(guide, tokenizer, reasoner)
else:
raise ValueError(f"Unknown guided decoding mode {mode}")
@ -1,307 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright 2024-present the Outlines developers
from __future__ import annotations

import copy
import hashlib
import importlib.metadata
import json
import os
from typing import Optional, Union

import regex as re
import torch
from cachetools import LRUCache
from diskcache import Cache
from outlines_core import Guide, Index, Vocabulary
from outlines_core.json_schema import build_regex_from_schema
from outlines_core.kernels.torch import (_apply_token_bitmask_inplace_kernel,
allocate_token_bitmask)
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
from transformers.file_utils import SPIECE_UNDERLINE
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser
from vllm.transformers_utils.tokenizer import AnyTokenizer

logger = init_logger(__name__)

CACHE = None


class BaseLogitsProcessor:

def __init__(self, guide: Guide, eos_token_id: int,
reasoner: Optional[ReasoningParser]) -> None:
self._guide: Guide = guide
self._eos_token_id: int = eos_token_id
self._reasoner: Optional[ReasoningParser] = reasoner
self._mask: Optional[torch.Tensor] = None

def __call__(self, input_ids: list[int],
scores: torch.Tensor) -> torch.Tensor:
if self._mask is None:
self._mask = allocate_token_bitmask(scores.size(-1))

# Skip the structured logits processing if reasoning is not finished.
# reasoner is not None only when `--reasoning-parser` is set.
if self._reasoner is not None and not self._reasoner.is_reasoning_end(
input_ids):
return scores

# Remove the reasoning tokens from the input_ids
# We need this because our implementation relies on the
# input_ids sequence to store the FSM state.
input_ids = (self._reasoner.extract_content_ids(input_ids)
if self._reasoner is not None else input_ids)

# Vllm V0 engine has a weird bug where we have to repeat
# the eos token id twice for generation to stop, or at least
# that is what we have to do from here in any case.
# This is a patch until a better solution can be pushed
# to outlines_core
if input_ids and input_ids[-1] != self._eos_token_id:
self._guide.advance(token_id=input_ids[-1], return_tokens=False)

self._guide.write_mask_into(
data_ptr=self._mask.data_ptr(),
numel=self._mask.numel(),
element_size=self._mask.element_size(),
)

# Any allowed tokens beyond the length of the scores will
# be ignored by the kernel, taking care of the issue with
# models such as Llama 3.2 Vision with an `<|image|>` token
# with id 128256, but scores.shape == torch.Size([128256])
_apply_token_bitmask_inplace_kernel(
logits=scores.unsqueeze(dim=0),
# mask must be on same device
mask=self._mask.to(scores.device, non_blocking=True))
self._mask.to("cpu", non_blocking=True)

return scores

def clone(self) -> BaseLogitsProcessor:
guide = copy.deepcopy(self._guide)
guide.reset()
return BaseLogitsProcessor(guide=guide,
eos_token_id=self._eos_token_id,
reasoner=self._reasoner)


class RegexLogitsProcessor(BaseLogitsProcessor):

@classmethod
def _get_guide(cls, regex_string: str,
tokenizer: PreTrainedTokenizerBase) -> Guide:
global CACHE
if CACHE is None:
CACHE = get_cache()
vocabulary = get_vocabulary(tokenizer)  # type: ignore[arg-type]
cache_key = f"{vocabulary._hash}_{regex_string}"
if CACHE is not None and cache_key in CACHE:
return Guide(CACHE[cache_key])

index = Index(regex_string, vocabulary.inner)

if CACHE is not None:
CACHE[cache_key] = index

return Guide(index)

def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[ReasoningParser]) -> None:
super().__init__(
guide=RegexLogitsProcessor._get_guide(regex_string, tokenizer),
eos_token_id=tokenizer.eos_token_id,  # type: ignore
reasoner=reasoner)


class JSONLogitsProcessor(RegexLogitsProcessor):

def __init__(self, schema: Union[str, dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Union[str, None],
reasoner: Optional[ReasoningParser]) -> None:

if isinstance(schema, type(BaseModel)):
schema_str = json.dumps(schema.model_json_schema())
elif isinstance(schema, dict):
schema_str = json.dumps(schema)
elif isinstance(schema, str):
schema_str = schema
else:
raise ValueError(
f"Cannot parse schema {schema}. The schema must be either "
f"a Pydantic object, a dictionary or a string that contains "
f"the JSON Schema specification")

regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
super().__init__(regex_string, tokenizer, reasoner)


class OutlinesVocabulary:
"""
Wrapper class for `outlines_core.Vocabulary`,
which allows us to store a hash with the vocabulary
"""

def __init__(self, vocabulary: Vocabulary) -> None:
# Actual vocabulary object
self.inner = vocabulary
# Have to do abs(hash()) because python hashes can
# be negative, and we are using hash as a cache key.
hex_str = hashlib.sha256(
vocabulary.__repr__().encode('utf-8')).hexdigest()
hash_int = int(hex_str, 16)
self._hash = hash_int


re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")
re_replacement_seq = re.compile(r"^.{0,6}<7D>+.{0,6}$")


def _reduced_vocabulary(tokenizer: AnyTokenizer,
eos_token_id: int) -> dict[bytes, list[int]]:
"""Create a map from vocabulary tokens to lists of equivalent token ids.

Returns:
A Dict of token string -> equivalent token ids
"""
unicode_to_bytes = {v: k for k, v in bytes_to_unicode().items()}

def convert_token_to_string(token: str) -> str:

string = tokenizer.convert_tokens_to_string([token])

# A hack to handle missing spaces to HF's Llama tokenizers
if (type(token) is str and token.startswith(SPIECE_UNDERLINE)
or token == "<0x20>"):
return " " + string

return string

vocabulary: dict[bytes, list[int]] = {}
empty_token_ids: list[int] = []
for token, token_idx in tokenizer.get_vocab().items():
if token in tokenizer.all_special_tokens:  # type: ignore
continue

token_str = convert_token_to_string(token)
if token_str:
if isinstance(token, (bytes, bytearray)):
# For BPE tokenizers where tokens are stored as bytes.

# safe to ignore since token_str is of type (bytearray, bytes)
# by this point.
token_bytes = bytes(token_str)  # type: ignore[arg-type]

elif "\ufffd" in token_str and not re_replacement_seq.match(
token_str):
# Handle tokens with invalid UTF-8 sequences.
if re_llama_byte_token.match(token):
# Llama-like tokenizers use <0xXX> for incomplete sequences.
token_bytes = bytes([int(token[3:5], 16)])
else:
# GPT2 tokenizers: map each byte back using unicode_to_bytes
byte_vals = [unicode_to_bytes.get(c) for c in token]
if None in byte_vals:
raise RuntimeError(
f"Cannot convert token `{token}`"
f" ({token_idx}) to bytes: {token_str}")
# safe to ignore, since if None in byte_vals,
# an error is thrown.
token_bytes = bytes(byte_vals)  # type: ignore[arg-type]
else:
token_bytes = token_str.encode('utf-8')

if token_idx != eos_token_id:
vocabulary.setdefault(token_bytes, []).append(token_idx)
else:
empty_token_ids.append(token_idx)

return vocabulary


def get_vocabulary(tokenizer: AnyTokenizer) -> Vocabulary:
"""Get the `Vocabulary` object for a given tokenizer.
"""
if hasattr(tokenizer, "_outlines_vocabulary"):
return tokenizer._outlines_vocabulary  # type: ignore

try:
if hasattr(
tokenizer,
"eos_token_id",
) and tokenizer.eos_token_id is not None:
eos_token_id = tokenizer.eos_token_id
else:
raise ValueError(
f"Error during guided decoding setup: Tokenizer"
f" ({type(tokenizer)}) has no `eos_token_id` property, "
"but `eos_token_id` is required for guided decoding"
" to work properly.")

reduced_vocab = _reduced_vocabulary(
tokenizer,
eos_token_id  #type: ignore
)
vocabulary = OutlinesVocabulary(Vocabulary(eos_token_id,
reduced_vocab))
tokenizer._outlines_vocabulary = vocabulary  # type: ignore

return vocabulary
except AttributeError as e:
raise ValueError(f"Cannot get the vocabulary of the tokenizer "
f"({type(tokenizer)}). The tokenizer should have a "
"get_vocab method.") from e


def get_cache_path() -> str:
"""Get the context object that contains previously-computed return values"""
outlines_cache_dir = os.getenv("OUTLINES_CACHE_DIR")
xdg_cache_home = os.getenv("XDG_CACHE_HOME")
home_dir = os.path.expanduser("~")

if outlines_cache_dir:
# OUTLINES_CACHE_DIR takes precedence
return outlines_cache_dir
elif xdg_cache_home:
return os.path.join(xdg_cache_home, ".cache", "outlines")
# If homedir is "/", we may be inside a container, and thus writing to
# root would be problematic, so we fallback to using a tempfile.
# Also validate the path exists, since os.path.expanduser does
# not garuntee existence.
elif os.path.isdir(home_dir) and home_dir != "/":
# Default Unix fallback: ~/.cache/outlines
return os.path.join(home_dir, ".cache", "outlines")
else:
import tempfile

# home_dir may be / inside a docker container without existing user
tempdir = tempfile.gettempdir()
return os.path.join(tempdir, ".cache", "outlines")


def get_cache():
"""Get the Cache instance to be used for index caching"""

cache_dir = get_cache_path()
if envs.VLLM_V0_USE_OUTLINES_CACHE:
logger.warning("Enabling outlines cache. This is an unbounded on-disk "
"cache. It may consume a lot of disk space and should "
"not be used with untrusted clients.")
cache = Cache(cache_dir, eviction_policy="none", cull_limit=0)
outlines_version = importlib.metadata.version("outlines_core")

cached_version = cache.get('__version__', None)
if cached_version != outlines_version:
cache.clear()
cache.set('__version__', outlines_version)
return cache
else:
return LRUCache(maxsize=128)
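The processor above fills a packed token bitmask (`write_mask_into`) and applies it with a fused kernel. A conceptual, pure-PyTorch sketch of what that masking step does, with a toy vocabulary; this is an illustration of the technique, not the outlines_core kernel itself:

```python
# Each int32 in the mask packs 32 "allowed" bits; disallowed logits are forced
# to -inf before sampling so the model can only emit grammar-legal tokens.
import torch


def apply_token_bitmask(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    vocab = logits.shape[-1]
    token_ids = torch.arange(vocab)
    word = mask[token_ids // 32]           # int32 word holding each token's bit
    bit = (word >> (token_ids % 32)) & 1   # 1 = token allowed
    return logits.masked_fill(bit == 0, float("-inf"))


vocab_size = 70
mask = torch.zeros((vocab_size + 31) // 32, dtype=torch.int32)
for allowed in (3, 40, 69):                # toy allowed-token set
    mask[allowed // 32] |= 1 << (allowed % 32)

scores = torch.randn(vocab_size)
masked = apply_token_bitmask(scores, mask)
assert torch.isfinite(masked[[3, 40, 69]]).all()
assert torch.isinf(masked[0])
```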
@ -1,242 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import regex as re
|
||||
|
||||
|
||||
def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
|
||||
"""Check if JSON schema contains features unsupported by xgrammar."""
|
||||
|
||||
def check_object(obj: dict) -> bool:
|
||||
if not isinstance(obj, dict):
|
||||
return False
|
||||
|
||||
# Check for numeric ranges
|
||||
if obj.get("type") in ("integer", "number") and ("multipleOf" in obj):
|
||||
return True
|
||||
|
||||
# Check for array unsupported keywords
|
||||
if obj.get("type") == "array" and any(key in obj for key in [
|
||||
"uniqueItems", "contains", "minContains", "maxContains",
|
||||
"minItems", "maxItems"
|
||||
]):
|
||||
return True
|
||||
|
||||
# Unsupported keywords for strings
|
||||
if obj.get("type") == "string" and any(
|
||||
key in obj for key in ["minLength", "maxLength", "format"]):
|
||||
return True
|
||||
|
||||
# Unsupported keywords for objects
|
||||
if obj.get("type") == "object" and any(key in obj for key in [
|
||||
"minProperties", "maxProperties", "propertyNames",
|
||||
"patternProperties"
|
||||
]):
|
||||
return True
|
||||
|
||||
# Recursively check all nested objects and arrays
|
||||
for value in obj.values():
|
||||
if isinstance(value, dict):
|
||||
if check_object(value):
|
||||
return True
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict) and check_object(item):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
return check_object(schema)
|
||||
|
||||
|
||||
def has_lmf_unsupported_json_features(schema: dict) -> bool:
|
||||
"""
|
||||
Check if JSON schema contains features unsupported
|
||||
by lm_format_enforcer.
|
||||
|
||||
Known issues:
|
||||
- Regex patterns:
|
||||
"grade": {
|
||||
"type": "string",
|
||||
"pattern": "^[A-D]$" # Regex pattern
|
||||
},
|
||||
"""
|
||||
|
||||
def check_object(obj: dict) -> bool:
|
||||
if not isinstance(obj, dict):
|
||||
return False
|
||||
|
||||
# Check for pattern restrictions
|
||||
if "pattern" in obj:
|
||||
return True
|
||||
|
||||
# Recursively check all nested objects and arrays
|
||||
for value in obj.values():
|
||||
if isinstance(value, dict):
|
||||
if check_object(value):
|
||||
return True
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict) and check_object(item):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
return check_object(schema)
|
||||
|
||||
|
||||
def grammar_is_likely_lark(grammar_str: str) -> bool:
|
||||
"""
|
||||
Check if grammar appears to use Lark syntax.
|
||||
|
||||
Args:
|
||||
grammar_str: Input grammar string
|
||||
|
||||
Returns:
|
||||
bool: True if grammar appears to be in Lark format, False otherwise
|
||||
|
||||
Examples:
|
||||
>>> grammar_is_likely_lark("rule: 'abc'")
|
||||
True
|
||||
>>> grammar_is_likely_lark("rule ::= 'abc'")
|
||||
False
|
||||
"""
|
||||
if not grammar_str or not isinstance(grammar_str, str):
|
||||
return False
|
||||
|
||||
for line in grammar_str.split('\n'):
|
||||
# Remove both comment styles
|
||||
line = re.sub(r'(#|//).*$', '', line).strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Look for GBNF rule definition
|
||||
if '::=' in line:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def convert_lark_to_gbnf(grammar_str: str) -> str:
    """
    Convert a Lark grammar string to GBNF format.

    GBNF reference:
    https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
    Lark grammar reference:
    https://lark-parser.readthedocs.io/en/latest/grammar.html

    Args:
        grammar_str: Input grammar in Lark format

    Returns:
        str: Converted grammar in GBNF format

    Examples:
        >>> print(convert_lark_to_gbnf("rule: 'hello'"))
        root ::= rule
        rule ::= "hello"
    """
    if not isinstance(grammar_str, str):
        raise ValueError(f"Grammar must be a string, got {type(grammar_str)}")
    if not grammar_str.strip():
        raise ValueError("Grammar string cannot be empty")

    defined_rules = set()
    referenced_rules = set()
    output_lines = []

    def clean_line(line: str) -> str:
        """Remove comments and whitespace from line."""
        return re.sub(r'(#|//).*$', '', line).strip()

    def check_quotes(text: str, rule_name: str, line_num: int) -> None:
        """Validate quote matching in text."""
        if text.count("'") % 2 != 0 or text.count('"') % 2 != 0:
            raise ValueError(
                f"Mismatched quotes in {rule_name} on line {line_num}")

    def extract_references(text: str) -> set:
        """Extract rule references from text."""
        # Remove quoted strings and special characters
        text = re.sub(r'"[^"]*"', '', text)
        text = re.sub(r'[+*?()|\[\]{}]', ' ', text)
        return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text))

    # First pass: Find root rule and validate rule definitions
    lines = [clean_line(line) for line in grammar_str.split('\n')]
    first_rule = None

    for line_num, line in enumerate(lines, 1):
        if not line or line.startswith('|'):
            continue

        if ':' in line:
            try:
                name = line.split(':', 1)[0].strip().strip('?')
                defined_rules.add(name)
                if first_rule is None:
                    first_rule = name
                if name == 'start':
                    first_rule = 'start'
            except IndexError as e:
                raise ValueError(f"Invalid rule format on line {line_num}. "
                                 "Expected 'rule_name: definition'") from e

    if not defined_rules:
        raise ValueError("No valid rules found in grammar")

    # Add root rule
    output_lines.append(f"root ::= {first_rule}")

    # Second pass: Process rule definitions and alternatives
    current_rule = None
    current_definition = []

    for line_num, line in enumerate(lines, 1):
        if not line:
            continue

        try:
            if ':' in line and not line.startswith('|'):
                # Save previous rule if exists
                if current_rule:
                    output_lines.append(
                        f"{current_rule} ::= {' | '.join(current_definition)}")

                # Process new rule
                name, definition = line.split(':', 1)
                current_rule = name.strip().strip('?')

                check_quotes(definition, f"rule '{current_rule}'", line_num)
                definition = re.sub(r"'([^']*)'", r'"\1"', definition)
                referenced_rules.update(extract_references(definition))
                current_definition = [definition.strip()]

            elif line.startswith('|'):
                if not current_rule:
                    raise ValueError(f"Alternative '|' on line {line_num} "
                                     "without a preceding rule definition")

                alt_def = line[1:].strip()
                check_quotes(alt_def, f"alternative for rule '{current_rule}'",
                             line_num)
                alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def)
                referenced_rules.update(extract_references(alt_def))
                current_definition.append(alt_def)

        except ValueError as e:
            raise ValueError(f"Error on line {line_num}: {str(e)}") from e

    # Add final rule if exists
    if current_rule:
        output_lines.append(
            f"{current_rule} ::= {' | '.join(current_definition)}")

    # Validate all rules are defined
    undefined_rules = referenced_rules - defined_rules - {'root'}
    if undefined_rules:
        raise ValueError("Referenced rules are not defined: "
                         f"{', '.join(sorted(undefined_rules))}")

    return '\n'.join(output_lines)
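
As a worked example (illustrative only, not from the deleted file), a small two-rule Lark grammar converts as follows; note the injected `root` rule and the single-quote to double-quote rewrite performed above.

lark_grammar = """
start: greeting
greeting: 'hello' | 'world'
"""
print(convert_lark_to_gbnf(lark_grammar))
# root ::= start
# start ::= greeting
# greeting ::= "hello" | "world"
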
@ -1,426 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# noqa: UP007
from __future__ import annotations

import json
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

import regex as re
import torch

import vllm.envs
from vllm.logger import init_logger

try:
    import xgrammar as xgr
    xgr_installed = True
except ImportError:
    xgr_installed = False
    pass

from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
                                                       grammar_is_likely_lark)
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer

    from vllm.config import ModelConfig
    from vllm.reasoning import ReasoningParser
    from vllm.sampling_params import GuidedDecodingParams

logger = init_logger(__name__)


def get_local_xgrammar_guided_decoding_logits_processor(
        guided_params: GuidedDecodingParams,
        tokenizer: PreTrainedTokenizer,
        model_config: ModelConfig,
        reasoner: ReasoningParser | None,
        max_threads: int = 8):
    config = GrammarConfig.from_guided_params(guided_params=guided_params,
                                              model_config=model_config,
                                              tokenizer=tokenizer,
                                              max_threads=max_threads)
    return XGrammarLogitsProcessor(config, reasoner)
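
For context, a minimal sketch of how this V0 factory was typically invoked; the tokenizer, ModelConfig, and reasoner are assumed to be supplied by the engine, and the variable names below are illustrative, not part of the deleted file.

from vllm.sampling_params import GuidedDecodingParams

params = GuidedDecodingParams(json='{"type": "object"}')
processor = get_local_xgrammar_guided_decoding_logits_processor(
    guided_params=params,
    tokenizer=hf_tokenizer,     # assumed: a HF PreTrainedTokenizer
    model_config=model_config,  # assumed: the engine's ModelConfig
    reasoner=None,              # only set when --reasoning-parser is used
)
# The V0 engine then ran `processor` as a per-step logits processor.
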
@dataclass(frozen=True)
class TokenizerData:
    """Immutable container for cached tokenizer data."""
    metadata: str
    encoded_vocab: list[str] = field(default_factory=list)


class TokenizerDataCache:
    """Cache manager for tokenizer data to avoid repeated processing."""
    _cache: dict[int, TokenizerData] = {}

    @classmethod
    def get_tokenizer_data(
        cls,
        tokenizer: PreTrainedTokenizer,
        /,
        *,
        tokenizer_hash: int,
        vocab_size: int,
    ) -> TokenizerData:

        if tokenizer_hash not in cls._cache:
            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
                tokenizer,
                # NOTE: We will need to use lm_head's vocab_size
                # to determine correct special_token_ids for this tokenizer.
                # See https://github.com/mlc-ai/xgrammar/commit/70c959fb6d9cea75aae33c414763cd0602022d92 # noqa: E501
                vocab_size=vocab_size,
            )
            metadata = json.loads(tokenizer_info.dump_metadata())

            # Vendored from xgrammar logic to get encoded_vocab
            # https://github.com/mlc-ai/xgrammar/blob/989222175c2a30fb7987d8bcce35bec1bf6817f2/python/xgrammar/tokenizer_info.py#L127 # noqa: E501
            try:
                vocab_dict = tokenizer.get_vocab()
            except AttributeError as e:
                raise ValueError(
                    f"Cannot get the vocabulary of the tokenizer "
                    f"{type(tokenizer)}. The tokenizer should have a "
                    "get_vocab method.") from e

            # maintain tokenizer's indexing
            encoded_vocab = [""] * tokenizer_info.vocab_size
            for token, idx in vocab_dict.items():
                if idx < tokenizer_info.vocab_size:
                    encoded_vocab[idx] = token

            if isinstance(tokenizer, MistralTokenizer):
                # REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
                metadata.update({
                    "vocab_type": xgr.VocabType.BYTE_FALLBACK,
                    "add_prefix_space": True
                })

            cls._cache[tokenizer_hash] = TokenizerData(
                encoded_vocab=encoded_vocab,
                metadata=json.dumps(metadata),
            )

        return cls._cache[tokenizer_hash]
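
The cache above is keyed on the tokenizer hash, so the vocabulary is only walked once per tokenizer. A rough sketch (assuming `tok` is any HF tokenizer with `get_vocab()`; in the real flow `vocab_size` came from the model config's lm_head rather than `len(tok.get_vocab())`):

data_a = TokenizerDataCache.get_tokenizer_data(tok,
                                               tokenizer_hash=hash(tok),
                                               vocab_size=len(tok.get_vocab()))
data_b = TokenizerDataCache.get_tokenizer_data(tok,
                                               tokenizer_hash=hash(tok),
                                               vocab_size=len(tok.get_vocab()))
assert data_a is data_b  # second call is a cache hit
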
class GrammarCompilerCache:
    """
    Cache for GrammarCompiler instances based on tokenizer.

    This cache reduces the overhead of creating new compiler instances when
    using the same tokenizer configuration.
    """
    _cache: dict[str, xgr.GrammarCompiler] = {}

    @classmethod
    def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler:
        cache_key = str(config.tokenizer_hash)

        if cache_key not in cls._cache:
            config_data = config.tokenizer_data

            # In TokenizerDataCache.get_tokenizer_data, a serializable
            # tokenizer_data is created and cached. This data is used to build
            # a tokenizer_info and create an xgrammar compiler.
            tokenizer_info = xgr.TokenizerInfo.from_vocab_and_metadata(
                encoded_vocab=config_data.encoded_vocab,
                metadata=config_data.metadata,
            )
            cache_size = vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024
            cls._cache[cache_key] = xgr.GrammarCompiler(
                tokenizer_info,
                max_threads=config.max_threads,
                cache_enabled=True,
                cache_limit_bytes=cache_size,
            )

        return cls._cache[cache_key]


@dataclass
class GrammarConfig:
    """Serializable configuration for grammar compilation"""
    tokenizer_hash: int
    tokenizer_data: TokenizerData
    json_str: str | None = None
    grammar_str: str | None = None
    json_object: bool | None = None
    any_whitespace: bool = True
    regex_str: str | None = None
    max_threads: int = 8

    @classmethod
    def from_guided_params(cls,
                           guided_params: GuidedDecodingParams,
                           model_config: ModelConfig,
                           tokenizer: PreTrainedTokenizer,
                           max_threads: int = 8) -> GrammarConfig:

        tokenizer_hash = hash(tokenizer)
        tokenizer_data = TokenizerDataCache.get_tokenizer_data(
            tokenizer,
            tokenizer_hash=tokenizer_hash,
            vocab_size=model_config.hf_text_config.vocab_size,
        )

        if guided_params.json:
            if not isinstance(guided_params.json, str):
                json_str = json.dumps(guided_params.json)
            else:
                json_str = guided_params.json

            any_whitespace = not guided_params.disable_any_whitespace

            # Warn if this model, used with xgrammar and unrestricted
            # whitespace, has a history of runaway whitespace generation.
            # References:
            # https://github.com/vllm-project/vllm/pull/12744
            # https://github.com/mlc-ai/xgrammar/issues/212
            model_with_warn = None

            if 'Mistral' in model_config.model:
                model_with_warn = 'Mistral'
            elif 'Qwen' in model_config.model:
                model_with_warn = 'Qwen'

            if model_with_warn is not None and any_whitespace:
                logger.info_once(
                    "%s model detected, consider setting `disable_any_whitespace` to prevent runaway generation of whitespaces.",  # noqa: E501
                    model_with_warn,
                )
            # Validate the schema and raise ValueError here if it is invalid.
            # This is to avoid exceptions in model execution, which will crash
            # the engine worker process.
            try:
                xgr.Grammar.from_json_schema(json_str,
                                             any_whitespace=any_whitespace)
            except RuntimeError as err:
                raise ValueError(str(err)) from err

            return cls(json_str=json_str,
                       tokenizer_hash=tokenizer_hash,
                       max_threads=max_threads,
                       tokenizer_data=tokenizer_data,
                       any_whitespace=any_whitespace)
        elif guided_params.grammar:
            # XGrammar only supports GBNF grammars, so we must convert Lark
            if grammar_is_likely_lark(guided_params.grammar):
                try:
                    grammar_str = convert_lark_to_gbnf(guided_params.grammar)
                except ValueError as e:
                    raise ValueError(
                        "Failed to convert the grammar from Lark to GBNF. "
                        "Please either use GBNF grammar directly or specify"
                        " --guided-decoding-backend=outlines.\n"
                        f"Conversion error: {str(e)}") from e
            else:
                grammar_str = guided_params.grammar

            # Validate the grammar and raise ValueError here if it is invalid.
            # This is to avoid exceptions in model execution, which will crash
            # the engine worker process.
            try:
                xgr.Grammar.from_ebnf(grammar_str)
            except RuntimeError as err:
                raise ValueError(str(err)) from err

            return cls(grammar_str=grammar_str,
                       tokenizer_hash=tokenizer_hash,
                       max_threads=max_threads,
                       tokenizer_data=tokenizer_data)
        elif guided_params.json_object:
            return cls(
                json_object=True,
                tokenizer_hash=tokenizer_hash,
                max_threads=max_threads,
                tokenizer_data=tokenizer_data,
            )
        elif guided_params.choice:
            choice_str = GrammarConfig.choice_as_grammar(guided_params.choice)
            try:
                xgr.Grammar.from_ebnf(choice_str)
            except RuntimeError as err:
                raise ValueError(str(err)) from err

            return cls(
                grammar_str=choice_str,
                tokenizer_hash=tokenizer_hash,
                max_threads=max_threads,
                tokenizer_data=tokenizer_data,
            )
        elif guided_params.regex:
            return cls(
                regex_str=guided_params.regex,
                tokenizer_hash=tokenizer_hash,
                max_threads=max_threads,
                tokenizer_data=tokenizer_data,
            )
        else:
            raise ValueError(
                "Currently only support JSON and EBNF grammar mode for xgrammar"
            )
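
Each of the mutually exclusive GuidedDecodingParams fields selects exactly one GrammarConfig field, leaving the rest unset. A sketch of the `choice` branch (illustrative; `model_config` and `tokenizer` are assumed to be the engine-provided objects):

cfg = GrammarConfig.from_guided_params(
    GuidedDecodingParams(choice=["yes", "no"]),
    model_config, tokenizer)
assert cfg.grammar_str == 'root ::= "yes" | "no"'
assert cfg.json_str is None and cfg.regex_str is None
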
    @staticmethod
    def escape_ebnf_string(s: str) -> str:
        """Escape special characters in a EBNF string."""
        # Escape double quotes and backslashes
        return re.sub(r'(["\\])', r'\\\1', s)

    @staticmethod
    def choice_as_grammar(choice: list[str] | None) -> str:
        if choice is None:
            raise ValueError("Choice is not set")
        escaped_choices = (GrammarConfig.escape_ebnf_string(c) for c in choice)
        grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices))
        return grammar
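
Concretely (illustrative values), `choice_as_grammar` produces a one-rule EBNF alternation, with double quotes and backslashes escaped by `escape_ebnf_string`:

GrammarConfig.choice_as_grammar(["yes", "no"])
# returns: root ::= "yes" | "no"
GrammarConfig.choice_as_grammar(['say "hi"'])
# returns: root ::= "say \"hi\""
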
    @staticmethod
    def tokenizer_info(tokenizer_data: TokenizerData) -> xgr.TokenizerInfo:
        return xgr.TokenizerInfo.from_vocab_and_metadata(
            encoded_vocab=tokenizer_data.encoded_vocab,
            metadata=tokenizer_data.metadata,
        )


@dataclass
class XGrammarLogitsProcessor:
    """Wrapper class to support pickle protocol"""
    config: GrammarConfig
    reasoner: ReasoningParser | None = None

    ctx: xgr.CompiledGrammar | None = None
    tokenizer_info: xgr.TokenizerInfo = None  # type: ignore[assignment]
    token_bitmask: torch.Tensor = None  # type: ignore[assignment]
    matchers: list[xgr.GrammarMatcher] = field(default_factory=list)
    batch_size: int = field(default=1)
    prefilled: bool = field(default=False)

    def __post_init__(self):
        if self.tokenizer_info is None:
            self.tokenizer_info = self.config.tokenizer_info(
                self.config.tokenizer_data)

    def __getstate__(self) -> dict[str, Any]:
        return {'config': self.config, 'reasoner': self.reasoner}

    def __setstate__(self, state: dict[str, Any]):
        self.config = state['config']
        self.reasoner = state['reasoner']

        self.tokenizer_info = GrammarConfig.tokenizer_info(
            self.config.tokenizer_data)
        self.ctx = None
        self.matchers = []
        self.batch_size = 1
        self.token_bitmask = None  # type: ignore[assignment]
        self.prefilled = False

    def _ensure_ctx(self):
        """Lazily initialize the processor in the worker process"""
        if self.ctx is None:
            compiler = GrammarCompilerCache.get_compiler(self.config)
            if self.config.json_str is not None:
                any_whitespace = self.config.any_whitespace
                self.ctx = compiler\
                    .compile_json_schema(self.config.json_str,
                                         any_whitespace=any_whitespace)
            elif self.config.grammar_str is not None:
                self.ctx = compiler.compile_grammar(self.config.grammar_str)
            elif self.config.json_object:
                any_whitespace = self.config.any_whitespace
                self.ctx = compiler\
                    .compile_json_schema('{"type": "object"}',
                                         any_whitespace=any_whitespace)
            elif self.config.regex_str:
                self.ctx = compiler.compile_regex(self.config.regex_str)
            else:
                raise ValueError(
                    "Invalid configuration for xgrammar logits processor")
    def __call__(self, input_ids: list[int],
                 scores: torch.Tensor) -> torch.Tensor:

        # Skip the structured logits processing if reasoning is not finished.
        # reasoner is not None only when `--reasoning-parser` is set.
        if self.reasoner is not None and \
                not self.reasoner.is_reasoning_end(
                    input_ids):
            return scores

        if self.ctx is None:
            self._ensure_ctx()

        if len(self.matchers) == 0:
            self.matchers = [
                xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
            ]
            self.token_bitmask = xgr.allocate_token_bitmask(
                self.batch_size, self.tokenizer_info.vocab_size)

        if not self.prefilled:
            # Have not sampled a token yet
            self.prefilled = True
        else:
            for i, matcher in enumerate(self.matchers):
                if not matcher.is_terminated():
                    sampled_token = input_ids[-1]
                    assert self.matchers[i].accept_token(sampled_token)

        for i, matcher in enumerate(self.matchers):
            if not matcher.is_terminated():
                # @ubospica: ideally, fill_next_token_bitmask should be
                # parallelized with model decoding
                # See https://github.com/vllm-project/vllm/pull/10785/files#r1864278303
                matcher.fill_next_token_bitmask(self.token_bitmask, i)

        # token_bitmask is a CPU tensor for use with accept_token and
        # fill_next_token_bitmask so we move it to the device of scores
        device_type = scores.device.type
        dtype = scores.dtype
        if device_type != "cuda":
            # xgrammar on cpu only supports float32 scores
            # see: https://github.com/mlc-ai/xgrammar/blob/c1b64920cad24f44f235778c1c00bb52d57da01a/python/xgrammar/kernels/apply_token_bitmask_inplace_cpu.py#L22
            scores = scores.to("cpu").float().unsqueeze(0)

        # Note: if scores and the token bitmask have mismatched dimensions,
        # the CPU kernel fails while the GPU kernel runs without error; hence
        # the unsqueeze above, to match scores to the bitmask shape.
        xgr.apply_token_bitmask_inplace(
            scores, self.token_bitmask.to(scores.device, non_blocking=True))
        if device_type != "cuda":
            scores = scores.to(dtype).to(device_type).squeeze()

        return scores
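
A rough sketch of how the processor is driven (illustrative; the V0 sampler did this internally, `processor` is assumed to come from the factory sketch earlier, and xgrammar must be installed): it is called once per decoding step with the sequence's token ids so far and the raw logits, and tokens disallowed by the grammar are masked to -inf.

import torch

vocab_size = processor.tokenizer_info.vocab_size
logits = torch.randn(vocab_size)
masked = processor([], logits)   # prefill step: no token accepted yet
# Subsequent steps pass the growing token-id list; the last sampled token
# is fed to the matcher before the next bitmask is computed and applied.
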
    def clone(self) -> XGrammarLogitsProcessor:
        """Create a new instance with shared compiled grammar
        but separate state"""
        new_processor = XGrammarLogitsProcessor(self.config, self.reasoner,
                                                None, self.tokenizer_info)

        # Share the compiled grammar context (immutable after compilation)
        new_processor.ctx = self.ctx

        # Create fresh matchers for the new sequence
        if self.ctx is not None:
            new_processor.matchers = [
                xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
            ]

        # Reuse the token bitmask allocated by the original processor
        # (same size)
        if hasattr(self, 'token_bitmask') and self.token_bitmask is not None:
            new_processor.token_bitmask = self.token_bitmask

        # Copy simple attributes
        new_processor.batch_size = self.batch_size
        # Reset prefilled state for new sequence
        new_processor.prefilled = False

        return new_processor
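
Clone semantics in brief (illustrative): the compiled grammar context is shared across clones, while per-sequence matcher state starts fresh, which is what parallel sampling of multiple sequences from one request needs.

copy = processor.clone()
assert copy.ctx is processor.ctx   # compiled grammar reused
assert copy.prefilled is False     # fresh per-sequence state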