mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[V1] Structured Outputs + Thinking compatibility (#16577)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz> Co-authored-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
@ -141,10 +141,10 @@ Remember to check whether the `reasoning_content` exists in the response before
|
||||
The reasoning content is also available in the structured output. The structured output engine like `xgrammar` will use the reasoning content to generate structured output. It is only supported in v0 engine now.
|
||||
|
||||
```bash
|
||||
VLLM_USE_V1=0 vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1
|
||||
vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1
|
||||
```
|
||||
|
||||
Please note that the `VLLM_USE_V1` environment variable must be set to `0` to use the v0 engine.
|
||||
The following is an example client:
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
@ -1,3 +1,4 @@
|
||||
# ruff: noqa: E501
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
@ -5,17 +6,22 @@ from __future__ import annotations
|
||||
import json
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import jsonschema
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from tests.reasoning.utils import run_reasoning_extraction
|
||||
from vllm.entrypoints.llm import LLM
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
|
||||
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import TokenizerMode
|
||||
|
||||
NGRAM_SPEC_CONFIG = {
|
||||
"model": "[ngram]",
|
||||
"num_speculative_tokens": 5,
|
||||
@ -514,6 +520,88 @@ Make the response as short as possible.
|
||||
f"{generated_text!r}\nError: {str(e)}")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, guided_decoding_backend, tokenizer_mode, reasoning_parser, speculative_config", # noqa: E501
|
||||
[
|
||||
("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto",
|
||||
"deepseek_r1", NGRAM_SPEC_CONFIG),
|
||||
("Qwen/Qwen3-1.7B", "xgrammar", "auto", "deepseek_r1", None),
|
||||
],
|
||||
)
|
||||
def test_structured_output_with_reasoning_matrices(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
guided_decoding_backend: str,
|
||||
tokenizer_mode: TokenizerMode,
|
||||
reasoning_parser: str,
|
||||
model_name: str,
|
||||
speculative_config: dict[str, Any] | None,
|
||||
):
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
if current_platform.is_tpu() and speculative_config:
|
||||
pytest.skip("TPU does not support speculative decoding")
|
||||
|
||||
# Use a single LLM instance for several scenarios to
|
||||
# speed up the test suite.
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
# Don't use eager execution on TPUs because we want to test for no
|
||||
# recompilation at runtime
|
||||
enforce_eager=bool(not current_platform.is_tpu()),
|
||||
max_model_len=1024,
|
||||
max_num_seqs=16,
|
||||
guided_decoding_backend=guided_decoding_backend,
|
||||
guided_decoding_disable_any_whitespace=True,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
reasoning_parser=reasoning_parser,
|
||||
speculative_config=speculative_config,
|
||||
)
|
||||
tokenizer = llm.get_tokenizer(None)
|
||||
reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)(
|
||||
tokenizer=tokenizer)
|
||||
|
||||
reasoning_prompt = "Solve the following math problem step-by-step, then provide the final answer as JSON object with a single key 'result'. Make sure to correct your reasoning if there are any issue should it arise.\nProblem: What is 5 * 8 + 2?" # noqa: E501
|
||||
reasoning_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"result": {
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
"required": ["result"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
if "Qwen3" in model_name:
|
||||
reasoning_prompt += "<think>\n"
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.1,
|
||||
max_tokens=8192,
|
||||
guided_decoding=GuidedDecodingParams(json=reasoning_schema),
|
||||
)
|
||||
outputs = llm.generate(
|
||||
[reasoning_prompt],
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True,
|
||||
)
|
||||
|
||||
assert outputs is not None
|
||||
output = outputs[0]
|
||||
assert output is not None and isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
reasoning_content, content = run_reasoning_extraction(
|
||||
reasoner, [generated_text])
|
||||
print(
|
||||
f"Prompt: {prompt!r}\nReasoning: {reasoning_content!r}\nContent: {content!r}"
|
||||
)
|
||||
|
||||
assert content is not None and reasoning_content is not None
|
||||
output_json = json.loads(content)
|
||||
jsonschema.validate(instance=output_json, schema=reasoning_schema)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("model_name, tokenizer_mode",
|
||||
PARAMS_MODELS_TOKENIZER_MODE)
|
||||
|
@ -4024,7 +4024,7 @@ class VllmConfig:
|
||||
"""LoRA configuration."""
|
||||
speculative_config: Optional[SpeculativeConfig] = None
|
||||
"""Speculative decoding configuration."""
|
||||
decoding_config: Optional[DecodingConfig] = None
|
||||
decoding_config: DecodingConfig = field(default_factory=DecodingConfig)
|
||||
"""Decoding configuration."""
|
||||
observability_config: Optional[ObservabilityConfig] = None
|
||||
"""Observability configuration."""
|
||||
|
@ -1,5 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Sequence
|
||||
@ -33,7 +35,7 @@ class ReasoningParser:
|
||||
return self.model_tokenizer.get_vocab()
|
||||
|
||||
@abstractmethod
|
||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
|
||||
"""
|
||||
Check if the reasoning content ends in the input_ids.
|
||||
|
||||
@ -106,7 +108,7 @@ class ReasoningParserManager:
|
||||
reasoning_parsers: dict[str, type] = {}
|
||||
|
||||
@classmethod
|
||||
def get_reasoning_parser(cls, name) -> type:
|
||||
def get_reasoning_parser(cls, name: str | None) -> type[ReasoningParser]:
|
||||
"""
|
||||
Get reasoning parser by name which is registered by `register_module`.
|
||||
|
||||
|
@ -758,7 +758,8 @@ class Scheduler(SchedulerInterface):
|
||||
# the outer lists can be of length > 1.
|
||||
new_logprobs = logprobs.slice(req_index, req_index + 1)
|
||||
|
||||
if new_token_ids and request.use_structured_output:
|
||||
if new_token_ids and self.structured_output_manager.should_advance(
|
||||
request):
|
||||
# NOTE: structured_output_request
|
||||
# should not be None if use_structured_output, we have
|
||||
# check above, so safe to ignore type warning
|
||||
@ -767,11 +768,10 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
# Add newly generated spec token ids to the request.
|
||||
if spec_token_ids is not None:
|
||||
if request.use_structured_output:
|
||||
if self.structured_output_manager.should_advance(request):
|
||||
metadata = request.structured_output_request
|
||||
assert metadata is not None and metadata.grammar is not None
|
||||
# Needs to happen after new_token_ids are accepted.
|
||||
request.spec_token_ids = metadata.grammar.validate_tokens(
|
||||
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr]
|
||||
spec_token_ids[req_index])
|
||||
else:
|
||||
request.spec_token_ids = spec_token_ids[req_index]
|
||||
|
@ -7,16 +7,23 @@ from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
from vllm.utils import LazyLoader
|
||||
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
|
||||
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
|
||||
StructuredOutputGrammar)
|
||||
from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
|
||||
from vllm.reasoning import ReasoningParser
|
||||
from vllm.v1.request import Request
|
||||
else:
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -26,9 +33,11 @@ class StructuredOutputManager:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
self.backend: Optional[StructuredOutputBackend] = None
|
||||
self.reasoner: Optional[ReasoningParser] = None
|
||||
self.vllm_config = vllm_config
|
||||
|
||||
self._grammar_bitmask: Optional[torch.Tensor] = None
|
||||
self._full_mask = torch.tensor(-1, dtype=torch.int32)
|
||||
|
||||
# The default max_workers if not specified is the number of CPUs * 5,
|
||||
# which is way too high since these tasks are CPU-bound, not I/O bound.
|
||||
@ -36,24 +45,43 @@ class StructuredOutputManager:
|
||||
# compilation, so we set it to half the number of CPUs.
|
||||
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self.tokenizer = init_tokenizer_from_configs(
|
||||
model_config=self.vllm_config.model_config,
|
||||
scheduler_config=self.vllm_config.scheduler_config,
|
||||
lora_config=self.vllm_config.lora_config,
|
||||
).get_lora_tokenizer(None)
|
||||
reasoning_backend = vllm_config.decoding_config.reasoning_backend
|
||||
if reasoning_backend:
|
||||
reasoner_cls = ReasoningParserManager.get_reasoning_parser(
|
||||
reasoning_backend)
|
||||
self.reasoner = reasoner_cls(tokenizer=self.tokenizer)
|
||||
|
||||
def grammar_init(self, request: Request) -> None:
|
||||
if request.structured_output_request is None:
|
||||
return
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert request.sampling_params.guided_decoding is not None
|
||||
|
||||
# Initialize the backend the first time it is needed.
|
||||
#
|
||||
# NOTE: We only support a single backend. We do NOT support different
|
||||
# backends on a per-request basis in V1 (for now, anyway...).
|
||||
if self.backend is None:
|
||||
backend = request.sampling_params.guided_decoding.backend
|
||||
vocab_size = self.vllm_config.model_config.get_vocab_size()
|
||||
if backend == "xgrammar":
|
||||
from vllm.v1.structured_output.backend_xgrammar import (
|
||||
XgrammarBackend)
|
||||
|
||||
self.backend = XgrammarBackend(self.vllm_config)
|
||||
self.backend = XgrammarBackend(
|
||||
self.vllm_config,
|
||||
tokenizer=self.tokenizer,
|
||||
vocab_size=vocab_size,
|
||||
)
|
||||
elif backend == "guidance":
|
||||
self.backend = GuidanceBackend(self.vllm_config)
|
||||
self.backend = GuidanceBackend(
|
||||
self.vllm_config,
|
||||
tokenizer=self.tokenizer,
|
||||
vocab_size=vocab_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported structured output backend: {backend}")
|
||||
@ -87,14 +115,14 @@ class StructuredOutputManager:
|
||||
if not structured_output_request_ids:
|
||||
return None
|
||||
|
||||
max_num_spec_tokens = 0
|
||||
if self.vllm_config.speculative_config is not None:
|
||||
max_num_spec_tokens = \
|
||||
self.vllm_config.speculative_config.num_speculative_tokens
|
||||
|
||||
if self._grammar_bitmask is None:
|
||||
assert self.backend is not None
|
||||
max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
|
||||
if self.vllm_config.speculative_config is not None:
|
||||
max_num_spec_tokens = self.vllm_config.\
|
||||
speculative_config.num_speculative_tokens
|
||||
else:
|
||||
max_num_spec_tokens = 0
|
||||
|
||||
# Allocate a bitmask for each token needing to be checked:
|
||||
# one for each speculative position, and one more for the
|
||||
@ -103,6 +131,7 @@ class StructuredOutputManager:
|
||||
self.backend.allocate_token_bitmask(
|
||||
max_batch_size * (1 + max_num_spec_tokens))
|
||||
|
||||
bitmask_tensor = self._grammar_bitmask
|
||||
# Generate a batched bitmask for all structured output requests.
|
||||
# When speculative decoding is enabled, we need to include multiple
|
||||
# masks for each request, one for each possible bonus token position.
|
||||
@ -110,16 +139,30 @@ class StructuredOutputManager:
|
||||
cumulative_index = 0
|
||||
ordered_seq = sorted(structured_output_request_ids.items(),
|
||||
key=lambda x: x[1])
|
||||
|
||||
# Note that for thinking support, we will need to
|
||||
# reset the relevant part of the bitmask for consequent
|
||||
# request here.
|
||||
bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_(
|
||||
self._full_mask)
|
||||
|
||||
# NOTE: This outer loop can likely be parallelized to improve
|
||||
# performance of bitmask generation for large batches.
|
||||
for req_id, _ in ordered_seq:
|
||||
request = requests[req_id].structured_output_request
|
||||
assert request is not None and request.grammar is not None
|
||||
if TYPE_CHECKING:
|
||||
assert request is not None
|
||||
assert request.grammar is not None
|
||||
|
||||
apply_bitmask = (
|
||||
request.reasoning_ended if self.reasoner is not None else True
|
||||
) # noqa: E501
|
||||
|
||||
state_advancements = 0
|
||||
req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
|
||||
for i, token in enumerate(req_tokens):
|
||||
if not request.grammar.is_terminated():
|
||||
request.grammar.fill_bitmask(self._grammar_bitmask,
|
||||
if apply_bitmask and not request.grammar.is_terminated():
|
||||
request.grammar.fill_bitmask(bitmask_tensor,
|
||||
cumulative_index)
|
||||
if token is not None:
|
||||
# In order to generate the correct bitmask for each
|
||||
@ -132,15 +175,41 @@ class StructuredOutputManager:
|
||||
if state_advancements > 0:
|
||||
request.grammar.rollback(state_advancements)
|
||||
|
||||
bitmask_tensor = self._grammar_bitmask
|
||||
if cumulative_index < self._grammar_bitmask.shape[0]:
|
||||
bitmask_tensor = self._grammar_bitmask[:cumulative_index]
|
||||
if cumulative_index < bitmask_tensor.shape[0]:
|
||||
bitmask_tensor = bitmask_tensor[:cumulative_index]
|
||||
|
||||
# After finishing with the xgrammar operations, we convert to
|
||||
# np.ndarray, because that is much more efficient for serialization
|
||||
# and deserialization when sending this to the GPU workers.
|
||||
return bitmask_tensor.numpy()
|
||||
|
||||
def should_advance(self, request: Request) -> bool:
|
||||
if not request.use_structured_output:
|
||||
return False
|
||||
|
||||
# To determine whether we can advance the FSM.
|
||||
# Supports thinking usage where we skip the reasoning components.
|
||||
if TYPE_CHECKING:
|
||||
assert request.structured_output_request is not None
|
||||
assert request.structured_output_request.grammar is not None
|
||||
# by default, we should always advance
|
||||
# for cases that doesn't uses thinking mode.
|
||||
if self.reasoner is not None:
|
||||
structured_req = request.structured_output_request
|
||||
|
||||
if structured_req.reasoning_ended:
|
||||
return True
|
||||
|
||||
# Check if reasoning ends in *this* step
|
||||
if self.reasoner.is_reasoning_end(request.all_token_ids):
|
||||
# Reasoning just ended, so we shouldn't advanced til
|
||||
# next pass
|
||||
structured_req.reasoning_ended = True
|
||||
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def clear_backend(self) -> None:
|
||||
if self.backend is not None:
|
||||
self.backend.destroy()
|
||||
|
@ -1,5 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
@ -8,10 +10,8 @@ from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
from vllm.utils import LazyLoader
|
||||
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
|
||||
StructuredOutputGrammar,
|
||||
@ -54,25 +54,17 @@ def process_for_additional_properties(
|
||||
return guide_json_obj
|
||||
|
||||
|
||||
@dataclass
|
||||
class GuidanceBackend(StructuredOutputBackend):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
self.vllm_config = vllm_config
|
||||
tokenizer_group = init_tokenizer_from_configs(
|
||||
model_config=vllm_config.model_config,
|
||||
scheduler_config=vllm_config.scheduler_config,
|
||||
lora_config=vllm_config.lora_config) # type: ignore[arg-type]
|
||||
self.vllm_config = vllm_config
|
||||
self.vocab_size = vllm_config.model_config.get_vocab_size()
|
||||
|
||||
def __post_init__(self):
|
||||
self.disable_any_whitespace = \
|
||||
vllm_config.decoding_config.disable_any_whitespace
|
||||
self.vllm_config.decoding_config.disable_any_whitespace
|
||||
self.disable_additional_properties = \
|
||||
vllm_config.decoding_config.disable_additional_properties
|
||||
self.vllm_config.decoding_config.disable_additional_properties
|
||||
|
||||
tokenizer = tokenizer_group.get_lora_tokenizer(None)
|
||||
self.ll_tokenizer = llguidance_hf.from_tokenizer(
|
||||
tokenizer, self.vocab_size)
|
||||
self.tokenizer, self.vocab_size)
|
||||
|
||||
def compile_grammar(self, request_type: StructuredOutputOptions,
|
||||
grammar_spec: str) -> StructuredOutputGrammar:
|
||||
|
@ -1,10 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
|
||||
class StructuredOutputOptions(enum.Enum):
|
||||
JSON = enum.auto()
|
||||
@ -85,9 +93,14 @@ class StructuredOutputGrammar(ABC):
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class StructuredOutputBackend(ABC):
|
||||
"""Engine-level backend for structured output requests."""
|
||||
|
||||
vllm_config: VllmConfig
|
||||
tokenizer: AnyTokenizer
|
||||
vocab_size: int
|
||||
|
||||
@abstractmethod
|
||||
def compile_grammar(self, request_type: StructuredOutputOptions,
|
||||
grammar_spec: str) -> StructuredOutputGrammar:
|
||||
@ -104,7 +117,7 @@ class StructuredOutputBackend(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def allocate_token_bitmask(self, max_num_seqs: int):
|
||||
def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
|
||||
"""
|
||||
Allocates a token bitmask for the specified maximum number of sequences.
|
||||
|
||||
|
@ -1,5 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@ -7,10 +9,8 @@ from typing import TYPE_CHECKING, Any
|
||||
import torch
|
||||
|
||||
import vllm.envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils import LazyLoader
|
||||
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
|
||||
@ -28,61 +28,49 @@ else:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class XgrammarBackend(StructuredOutputBackend):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
self.vllm_config = vllm_config
|
||||
tokenizer_group = init_tokenizer_from_configs(
|
||||
model_config=vllm_config.model_config,
|
||||
scheduler_config=vllm_config.scheduler_config,
|
||||
lora_config=vllm_config.lora_config) # type: ignore[arg-type]
|
||||
|
||||
def __post_init__(self):
|
||||
self.disable_any_whitespace = \
|
||||
vllm_config.decoding_config.disable_any_whitespace
|
||||
self.vllm_config.decoding_config.disable_any_whitespace
|
||||
|
||||
self.num_speculative_tokens = 0
|
||||
if self.vllm_config.speculative_config is not None:
|
||||
self.num_speculative_tokens = \
|
||||
self.vllm_config.speculative_config.num_speculative_tokens
|
||||
|
||||
tokenizer = tokenizer_group.get_lora_tokenizer(None)
|
||||
self.vocab_size = vllm_config.model_config.get_vocab_size()
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
if isinstance(self.tokenizer, MistralTokenizer):
|
||||
# NOTE: ideally, xgrammar should handle this accordingly.
|
||||
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
|
||||
try:
|
||||
if tokenizer.is_tekken:
|
||||
encoded_vocab = tokenizer._vocab
|
||||
if self.tokenizer.is_tekken:
|
||||
encoded_vocab = self.tokenizer._vocab
|
||||
else:
|
||||
encoded_vocab = [
|
||||
token for token, _ in sorted(
|
||||
tokenizer.get_vocab().items(),
|
||||
self.tokenizer.get_vocab().items(),
|
||||
key=lambda x: x[1],
|
||||
)
|
||||
]
|
||||
stop_token_ids = None
|
||||
if hasattr(
|
||||
tokenizer,
|
||||
if (hasattr(
|
||||
self.tokenizer,
|
||||
"eos_token_id",
|
||||
) and tokenizer.eos_token_id is not None:
|
||||
stop_token_ids = [tokenizer.eos_token_id]
|
||||
) and self.tokenizer.eos_token_id is not None):
|
||||
stop_token_ids = [self.tokenizer.eos_token_id]
|
||||
except AttributeError as e:
|
||||
raise ValueError(
|
||||
f"Cannot get the vocabulary of the tokenizer "
|
||||
f"{type(tokenizer)}. The tokenizer should have a "
|
||||
f"{type(self.tokenizer)}. The tokenizer should have a "
|
||||
"get_vocab method.") from e
|
||||
tokenizer_info = xgr.TokenizerInfo( # type: ignore
|
||||
encoded_vocab=encoded_vocab,
|
||||
# NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
|
||||
vocab_type=xgr.VocabType.RAW
|
||||
if tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK,
|
||||
if self.tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK,
|
||||
vocab_size=self.vocab_size,
|
||||
stop_token_ids=stop_token_ids,
|
||||
add_prefix_space=True,
|
||||
)
|
||||
else:
|
||||
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
|
||||
tokenizer,
|
||||
self.tokenizer,
|
||||
vocab_size=self.vocab_size,
|
||||
)
|
||||
self.compiler = xgr.GrammarCompiler(
|
||||
@ -92,6 +80,11 @@ class XgrammarBackend(StructuredOutputBackend):
|
||||
cache_limit_bytes=vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024,
|
||||
)
|
||||
|
||||
self.num_speculative_tokens = 0
|
||||
if self.vllm_config.speculative_config is not None:
|
||||
self.num_speculative_tokens = \
|
||||
self.vllm_config.speculative_config.num_speculative_tokens
|
||||
|
||||
def compile_grammar(self, request_type: StructuredOutputOptions,
|
||||
grammar_spec: str) -> StructuredOutputGrammar:
|
||||
if request_type == StructuredOutputOptions.JSON:
|
||||
|
@ -20,6 +20,7 @@ class StructuredOutputRequest:
|
||||
sampling_params: SamplingParams
|
||||
_grammar: Optional[Union[Future[StructuredOutputGrammar],
|
||||
StructuredOutputGrammar]] = None
|
||||
reasoning_ended: bool = False
|
||||
|
||||
def _check_grammar_completion(self) -> bool:
|
||||
# NOTE: We have to lazy import to gate circular imports
|
||||
|
Reference in New Issue
Block a user