Support for guided decoding for offline LLM (#6878)

Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
Yihuan Bu
2024-08-03 23:12:09 -04:00
committed by GitHub
parent 825b044863
commit 654bc5ca49
9 changed files with 352 additions and 12 deletions

View File

@ -0,0 +1,38 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, TypedDict, Union
from pydantic import BaseModel
class LLMGuidedOptions(TypedDict, total=False):
guided_json: Union[Dict, BaseModel, str]
guided_regex: str
guided_choice: List[str]
guided_grammar: str
guided_decoding_backend: str
guided_whitespace_pattern: str
guided_json_object: bool
@dataclass
class GuidedDecodingRequest:
"""One of the fields will be used to retrieve the logit processor."""
guided_json: Optional[Union[Dict, BaseModel, str]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[List[str]] = None
guided_grammar: Optional[str] = None
guided_decoding_backend: Optional[str] = None
guided_whitespace_pattern: Optional[str] = None
guided_json_object: Optional[bool] = None
def __post_init__(self):
"""Validate that some fields are mutually exclusive."""
guide_count = sum([
self.guided_json is not None, self.guided_regex is not None,
self.guided_choice is not None, self.guided_grammar is not None,
self.guided_json_object is not None
])
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding but multiple are "
f"specified: {self.__dict__}")