427 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			427 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # SPDX-License-Identifier: Apache-2.0
 | |
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 | |
| 
 | |
| # noqa: UP007
 | |
| from __future__ import annotations
 | |
| 
 | |
| import json
 | |
| from dataclasses import dataclass, field
 | |
| from typing import TYPE_CHECKING, Any
 | |
| 
 | |
| import regex as re
 | |
| import torch
 | |
| 
 | |
| import vllm.envs
 | |
| from vllm.logger import init_logger
 | |
| 
 | |
| try:
 | |
|     import xgrammar as xgr
 | |
|     xgr_installed = True
 | |
| except ImportError:
 | |
|     xgr_installed = False
 | |
|     pass
 | |
| 
 | |
| from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
 | |
|                                                        grammar_is_likely_lark)
 | |
| from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 | |
| 
 | |
| if TYPE_CHECKING:
 | |
|     from transformers import PreTrainedTokenizer
 | |
| 
 | |
|     from vllm.config import ModelConfig
 | |
|     from vllm.reasoning import ReasoningParser
 | |
|     from vllm.sampling_params import GuidedDecodingParams
 | |
| 
 | |
| logger = init_logger(__name__)
 | |
| 
 | |
| 
 | |
def get_local_xgrammar_guided_decoding_logits_processor(
        guided_params: GuidedDecodingParams,
        tokenizer: PreTrainedTokenizer,
        model_config: ModelConfig,
        reasoner: ReasoningParser | None,
        max_threads: int = 8):
    """Create an xgrammar-backed logits processor for a guided request.

    The guided-decoding parameters are first turned into a
    ``GrammarConfig`` (which may raise ``ValueError`` for an invalid or
    unsupported request); that config and the optional reasoner are then
    wrapped in an ``XGrammarLogitsProcessor``.

    Args:
        guided_params: The guided decoding request parameters.
        tokenizer: Tokenizer used to derive the cached tokenizer data.
        model_config: Model configuration forwarded to the grammar config.
        reasoner: Optional reasoning parser handed to the processor.
        max_threads: Thread limit forwarded to the grammar config.

    Returns:
        A new ``XGrammarLogitsProcessor``.
    """
    grammar_config = GrammarConfig.from_guided_params(
        guided_params=guided_params,
        model_config=model_config,
        tokenizer=tokenizer,
        max_threads=max_threads,
    )
    processor = XGrammarLogitsProcessor(grammar_config, reasoner)
    return processor
 | |
| 
 | |
| 
 | |
@dataclass(frozen=True)
class TokenizerData:
    """Frozen record holding the tokenizer details that get cached.

    Instances cannot be mutated after construction (``frozen=True``).
    """
    # Required metadata string for the tokenizer.
    metadata: str
    # Token strings; each instance gets its own empty list when omitted.
    encoded_vocab: list[str] = field(default_factory=list)
 | |
| 
 | |
| 
 | |
class TokenizerDataCache:
    """Class-level memo of ``TokenizerData`` keyed by a tokenizer hash.

    Building the data requires walking the whole vocabulary, so the result
    is computed once per ``tokenizer_hash`` and reused afterwards.
    """
    _cache: dict[int, TokenizerData] = {}

    @classmethod
    def get_tokenizer_data(
        cls,
        tokenizer: PreTrainedTokenizer,
        /,
        *,
        tokenizer_hash: int,
        vocab_size: int,
    ) -> TokenizerData:
        """Return the cached ``TokenizerData``, building it on a miss.

        Args:
            tokenizer: Tokenizer to extract the vocabulary from
                (positional-only).
            tokenizer_hash: Cache key identifying the tokenizer.
            vocab_size: Vocabulary size passed to xgrammar when building
                the tokenizer info.

        Returns:
            The ``TokenizerData`` stored under ``tokenizer_hash``.

        Raises:
            ValueError: If the tokenizer has no ``get_vocab`` method.
        """
        # Fast path: this tokenizer has already been processed.
        if tokenizer_hash in cls._cache:
            return cls._cache[tokenizer_hash]

        info = xgr.TokenizerInfo.from_huggingface(
            tokenizer,
            # NOTE: We will need to use lm_head's vocab_size
            # to determine correct special_token_ids for this tokenizer.
            # See https://github.com/mlc-ai/xgrammar/commit/70c959fb6d9cea75aae33c414763cd0602022d92  # noqa: E501
            vocab_size=vocab_size,
        )
        meta = json.loads(info.dump_metadata())

        # Vendored from xgrammar logic to get encoded_vocab
        # https://github.com/mlc-ai/xgrammar/blob/989222175c2a30fb7987d8bcce35bec1bf6817f2/python/xgrammar/tokenizer_info.py#L127 # noqa: E501
        try:
            token_to_id = tokenizer.get_vocab()
        except AttributeError as e:
            raise ValueError(
                f"Cannot get the vocabulary of the tokenizer "
                f"{type(tokenizer)}. The tokenizer should have a "
                "get_vocab method.") from e

        # Lay tokens out by id so the list keeps the tokenizer's indexing;
        # ids at or beyond the reported vocab size are dropped.
        size = info.vocab_size
        vocab_by_id = [""] * size
        for token, token_id in token_to_id.items():
            if token_id < size:
                vocab_by_id[token_id] = token

        if isinstance(tokenizer, MistralTokenizer):
            # REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
            meta["vocab_type"] = xgr.VocabType.BYTE_FALLBACK
            meta["add_prefix_space"] = True

        entry = TokenizerData(
            encoded_vocab=vocab_by_id,
            metadata=json.dumps(meta),
        )
        cls._cache[tokenizer_hash] = entry
        return entry
 | |
| 
 | |
| 
 | |
class GrammarCompilerCache:
    """
    Class-level store of GrammarCompiler instances, one per tokenizer.

    Reusing a compiler for the same tokenizer hash avoids rebuilding the
    tokenizer info and compiler for every request.
    """
    _cache: dict[str, xgr.GrammarCompiler] = {}

    @classmethod
    def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler:
        """Return the compiler for ``config``'s tokenizer, creating it once.

        The cache key is the string form of ``config.tokenizer_hash``.
        """
        key = str(config.tokenizer_hash)
        if key in cls._cache:
            return cls._cache[key]

        # In TokenizerDataCache.get_tokenizer_data, a serializable
        # tokenizer_data is created and cached. This data is used to build
        # a tokenizer_info and create an xgrammar compiler.
        data = config.tokenizer_data
        info = xgr.TokenizerInfo.from_vocab_and_metadata(
            encoded_vocab=data.encoded_vocab,
            metadata=data.metadata,
        )
        # The environment setting is in MiB; the compiler takes bytes.
        limit_bytes = vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024
        compiler = xgr.GrammarCompiler(
            info,
            max_threads=config.max_threads,
            cache_enabled=True,
            cache_limit_bytes=limit_bytes,
        )
        cls._cache[key] = compiler
        return compiler
 | |
| 
 | |
| 
 | |
@dataclass
class GrammarConfig:
    """Serializable configuration for grammar compilation.

    ``from_guided_params`` fills in exactly one of ``json_str``,
    ``grammar_str``, ``json_object`` or ``regex_str``; the logits processor
    later picks the matching compile path from whichever one is set.
    """
    tokenizer_hash: int
    tokenizer_data: TokenizerData
    json_str: str | None = None
    grammar_str: str | None = None
    json_object: bool | None = None
    any_whitespace: bool = True
    regex_str: str | None = None
    max_threads: int = 8

    @classmethod
    def from_guided_params(cls,
                           guided_params: GuidedDecodingParams,
                           model_config: ModelConfig,
                           tokenizer: PreTrainedTokenizer,
                           max_threads: int = 8) -> GrammarConfig:
        """Build a config from guided-decoding parameters.

        The first populated option wins, checked in this order: ``json``,
        ``grammar``, ``json_object``, ``choice``, ``regex``. Every
        user-supplied grammar (JSON schema, EBNF, choice list, regex) is
        validated here so that a bad request fails with ``ValueError``
        instead of raising later during model execution.

        Args:
            guided_params: The guided decoding request parameters.
            model_config: Supplies the model name and text vocab size.
            tokenizer: Tokenizer to hash and extract vocabulary data from.
            max_threads: Thread limit stored on the resulting config.

        Returns:
            A ``GrammarConfig`` for the selected guided-decoding mode.

        Raises:
            ValueError: If the schema/grammar/regex is invalid, the Lark
                grammar cannot be converted, or no supported option is set.
        """

        tokenizer_hash = hash(tokenizer)
        tokenizer_data = TokenizerDataCache.get_tokenizer_data(
            tokenizer,
            tokenizer_hash=tokenizer_hash,
            vocab_size=model_config.hf_text_config.vocab_size,
        )

        if guided_params.json:
            if not isinstance(guided_params.json, str):
                json_str = json.dumps(guided_params.json)
            else:
                json_str = guided_params.json

            any_whitespace = not guided_params.disable_any_whitespace

            # Check and log if model with xgrammar and whitespace have history
            # of runaway generation of whitespaces.
            # References:
            # https://github.com/vllm-project/vllm/pull/12744
            # https://github.com/mlc-ai/xgrammar/issues/212
            model_with_warn = None

            if 'Mistral' in model_config.model:
                model_with_warn = 'Mistral'
            elif 'Qwen' in model_config.model:
                model_with_warn = 'Qwen'

            if model_with_warn is not None and any_whitespace:
                logger.info_once(
                    "%s model detected, consider setting `disable_any_whitespace` to prevent runaway generation of whitespaces.",  # noqa: E501
                    model_with_warn,
                )
            # Validate the schema and raise ValueError here if it is invalid.
            # This is to avoid exceptions in model execution, which will crash
            # the engine worker process.
            try:
                xgr.Grammar.from_json_schema(json_str,
                                             any_whitespace=any_whitespace)
            except RuntimeError as err:
                raise ValueError(str(err)) from err

            return cls(json_str=json_str,
                       tokenizer_hash=tokenizer_hash,
                       max_threads=max_threads,
                       tokenizer_data=tokenizer_data,
                       any_whitespace=any_whitespace)
        elif guided_params.grammar:
            # XGrammar only supports GBNF grammars, so we must convert Lark
            if grammar_is_likely_lark(guided_params.grammar):
                try:
                    grammar_str = convert_lark_to_gbnf(guided_params.grammar)
                except ValueError as e:
                    raise ValueError(
                        "Failed to convert the grammar from Lark to GBNF. "
                        "Please either use GBNF grammar directly or specify"
                        " --guided-decoding-backend=outlines.\n"
                        f"Conversion error: {str(e)}") from e
            else:
                grammar_str = guided_params.grammar

            # Validate the grammar and raise ValueError here if it is invalid.
            # This is to avoid exceptions in model execution, which will crash
            # the engine worker process.
            try:
                xgr.Grammar.from_ebnf(grammar_str)
            except RuntimeError as err:
                raise ValueError(str(err)) from err

            return cls(grammar_str=grammar_str,
                       tokenizer_hash=tokenizer_hash,
                       max_threads=max_threads,
                       tokenizer_data=tokenizer_data)
        elif guided_params.json_object:
            # Carry the whitespace preference through: the processor reads
            # ``any_whitespace`` when compiling the generic object schema,
            # so leaving it at the default would silently ignore
            # ``disable_any_whitespace`` in this mode.
            return cls(
                json_object=True,
                tokenizer_hash=tokenizer_hash,
                max_threads=max_threads,
                tokenizer_data=tokenizer_data,
                any_whitespace=not guided_params.disable_any_whitespace,
            )
        elif guided_params.choice:
            choice_str = GrammarConfig.choice_as_grammar(guided_params.choice)
            try:
                xgr.Grammar.from_ebnf(choice_str)
            except RuntimeError as err:
                raise ValueError(str(err)) from err

            return cls(
                grammar_str=choice_str,
                tokenizer_hash=tokenizer_hash,
                max_threads=max_threads,
                tokenizer_data=tokenizer_data,
            )
        elif guided_params.regex:
            # Validate the regex up front, like the JSON and EBNF branches,
            # so an unsupported pattern is rejected with ValueError here
            # rather than failing at compile time during model execution.
            try:
                xgr.Grammar.from_regex(guided_params.regex)
            except RuntimeError as err:
                raise ValueError(str(err)) from err

            return cls(
                regex_str=guided_params.regex,
                tokenizer_hash=tokenizer_hash,
                max_threads=max_threads,
                tokenizer_data=tokenizer_data,
            )
        else:
            raise ValueError(
                "Currently only support JSON and EBNF grammar mode for xgrammar"
            )

    @staticmethod
    def escape_ebnf_string(s: str) -> str:
        """Escape special characters in a EBNF string.

        Each double quote and backslash is prefixed with a backslash so the
        text can be placed inside a double-quoted EBNF literal.
        """
        # Escape double quotes and backslashes
        return re.sub(r'(["\\])', r'\\\1', s)

    @staticmethod
    def choice_as_grammar(choice: list[str] | None) -> str:
        """Render a list of choices as a single-rule EBNF grammar.

        Produces ``root ::= "a" | "b" | ...`` with each choice escaped.

        Raises:
            ValueError: If ``choice`` is ``None``.
        """
        if choice is None:
            raise ValueError("Choice is not set")
        escaped_choices = (GrammarConfig.escape_ebnf_string(c) for c in choice)
        grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices))
        return grammar

    @staticmethod
    def tokenizer_info(tokenizer_data: TokenizerData) -> xgr.TokenizerInfo:
        """Rebuild an xgrammar ``TokenizerInfo`` from cached tokenizer data."""
        return xgr.TokenizerInfo.from_vocab_and_metadata(
            encoded_vocab=tokenizer_data.encoded_vocab,
            metadata=tokenizer_data.metadata,
        )
 | |
| 
 | |
| 
 | |
@dataclass
class XGrammarLogitsProcessor:
    """Wrapper class to support pickle protocol.

    Only ``config`` and ``reasoner`` are pickled; the compiled grammar,
    matchers and token bitmask are rebuilt lazily after unpickling. Calling
    the instance masks the given scores so that only tokens allowed by the
    grammar remain selectable.
    """
    config: GrammarConfig
    reasoner: ReasoningParser | None = None

    ctx: xgr.CompiledGrammar | None = None
    tokenizer_info: xgr.TokenizerInfo = None  # type: ignore[assignment]
    token_bitmask: torch.Tensor = None  # type: ignore[assignment]
    matchers: list[xgr.GrammarMatcher] = field(default_factory=list)
    batch_size: int = field(default=1)
    prefilled: bool = field(default=False)

    def __post_init__(self):
        # Derive the tokenizer info from the config unless one was supplied.
        if self.tokenizer_info is None:
            self.tokenizer_info = self.config.tokenizer_info(
                self.config.tokenizer_data)

    def __getstate__(self) -> dict[str, Any]:
        """Return the picklable state: just the config and the reasoner."""
        return {'config': self.config, 'reasoner': self.reasoner}

    def __setstate__(self, state: dict[str, Any]):
        """Restore from pickled state and reset all runtime-only fields."""
        self.config = state['config']
        self.reasoner = state['reasoner']

        self.tokenizer_info = GrammarConfig.tokenizer_info(
            self.config.tokenizer_data)
        self.ctx = None
        self.matchers = []
        self.batch_size = 1
        self.token_bitmask = None  # type: ignore[assignment]
        self.prefilled = False

    def _ensure_ctx(self):
        """Lazily initialize the processor in the worker process"""
        if self.ctx is None:
            compiler = GrammarCompilerCache.get_compiler(self.config)
            if self.config.json_str is not None:
                any_whitespace = self.config.any_whitespace
                self.ctx = compiler\
                    .compile_json_schema(self.config.json_str,
                                         any_whitespace=any_whitespace)
            elif self.config.grammar_str is not None:
                self.ctx = compiler.compile_grammar(self.config.grammar_str)
            elif self.config.json_object:
                any_whitespace = self.config.any_whitespace
                self.ctx = compiler\
                    .compile_json_schema('{"type": "object"}',
                                         any_whitespace=any_whitespace)
            elif self.config.regex_str:
                self.ctx = compiler.compile_regex(self.config.regex_str)
            else:
                raise ValueError(
                    "Invalid configuration for xgrammar logits processor")

    def __call__(self, input_ids: list[int],
                 scores: torch.Tensor) -> torch.Tensor:
        """Mask ``scores`` so only grammar-valid next tokens remain.

        Args:
            input_ids: Token ids generated so far; the last one is fed to
                the matchers on every call after the first.
            scores: Scores for the next token.

        Returns:
            The masked scores. Returned unchanged while the reasoner
            reports that reasoning has not ended.

        Raises:
            AssertionError: If a matcher rejects the last sampled token.
        """

        # Skip the structured logits processing if reasoning is not finished.
        # reasoner is not None only when `--reasoning-parser` is set.
        if self.reasoner is not None and \
        not self.reasoner.is_reasoning_end(
                input_ids):
            return scores

        if self.ctx is None:
            self._ensure_ctx()

        if len(self.matchers) == 0:
            self.matchers = [
                xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
            ]
            self.token_bitmask = xgr.allocate_token_bitmask(
                self.batch_size, self.tokenizer_info.vocab_size)

        if not self.prefilled:
            # Have not sampled a token yet
            self.prefilled = True
        else:
            for matcher in self.matchers:
                if not matcher.is_terminated():
                    sampled_token = input_ids[-1]
                    # accept_token advances the matcher state, so it must
                    # not live inside an `assert` (stripped under -O).
                    # AssertionError is kept for backward compatibility.
                    accepted = matcher.accept_token(sampled_token)
                    if not accepted:
                        raise AssertionError(
                            f"xgrammar matcher rejected sampled token "
                            f"{sampled_token}")

        for i, matcher in enumerate(self.matchers):
            if not matcher.is_terminated():
                # @ubospica: ideally, fill_next_token_bitmask should be
                # parallelized with model decoding
                # See https://github.com/vllm-project/vllm/pull/10785/files#r1864278303
                matcher.fill_next_token_bitmask(self.token_bitmask, i)

        # token_bitmask is a CPU tensor for use with accept_token and
        # fill_next_token_bitmask so we move it to the device of scores
        original_device = scores.device
        on_cuda = original_device.type == "cuda"
        dtype = scores.dtype
        if not on_cuda:
            # xgrammar on cpu only supports float32 scores
            # see: https://github.com/mlc-ai/xgrammar/blob/c1b64920cad24f44f235778c1c00bb52d57da01a/python/xgrammar/kernels/apply_token_bitmask_inplace_cpu.py#L22
            scores = scores.to("cpu").float().unsqueeze(0)

        # Note: In this method, if the tensors have different dimensions
        # on CPU device fails, but on GPU it runs without error. Hence the
        # unsqueeze above for scores, to match the token bitmask shape
        xgr.apply_token_bitmask_inplace(
            scores, self.token_bitmask.to(scores.device, non_blocking=True))
        if not on_cuda:
            # Undo exactly the unsqueeze(0) above (a bare squeeze() would
            # also drop any other size-1 dimension) and go back to the
            # device the scores came from, including its index.
            scores = scores.to(dtype).to(original_device).squeeze(0)

        return scores

    def clone(self) -> XGrammarLogitsProcessor:
        """Create a new instance with shared compiled grammar
          but separate state"""
        new_processor = XGrammarLogitsProcessor(self.config, self.reasoner,
                                                None, self.tokenizer_info)

        # Share the compiled grammar context (immutable after compilation)
        new_processor.ctx = self.ctx

        # Create fresh matchers for the new sequence, together with a
        # bitmask of its own. Matchers write into the bitmask on every call,
        # so sharing the source tensor would let one sequence's mask leak
        # into another (e.g. once its matcher has terminated and the mask
        # is no longer refreshed).
        if self.ctx is not None:
            new_processor.matchers = [
                xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
            ]
            new_processor.token_bitmask = xgr.allocate_token_bitmask(
                self.batch_size, self.tokenizer_info.vocab_size)

        # Copy simple attributes
        new_processor.batch_size = self.batch_size
        # Reset prefilled state for new sequence
        new_processor.prefilled = False

        return new_processor
 |