From 703e42ee4b3efed3c71e7ae7d15f0f96e05722d4 Mon Sep 17 00:00:00 2001
From: felixzhu555 <79335195+felixzhu555@users.noreply.github.com>
Date: Thu, 29 Feb 2024 14:13:08 -0800
Subject: [PATCH] Add guided decoding for OpenAI API server (#2819)

Co-authored-by: br3no
Co-authored-by: simon-mo
---
 requirements.txt                              |   1 +
 tests/entrypoints/test_guided_processors.py   |  75 ++++++
 tests/entrypoints/test_openai_server.py       | 237 ++++++++++++++++++
 vllm/engine/async_llm_engine.py               |   3 +
 vllm/entrypoints/openai/protocol.py           |  36 ++-
 vllm/entrypoints/openai/serving_chat.py       |   9 +
 vllm/entrypoints/openai/serving_completion.py |   9 +
 vllm/model_executor/guided_decoding.py        |  99 ++++++++
 .../guided_logits_processors.py               | 129 ++++++++++
 9 files changed, 597 insertions(+), 1 deletion(-)
 create mode 100644 tests/entrypoints/test_guided_processors.py
 create mode 100644 vllm/model_executor/guided_decoding.py
 create mode 100644 vllm/model_executor/guided_logits_processors.py

diff --git a/requirements.txt b/requirements.txt
index d4599ec95d9..05ec2e804e1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,4 +12,5 @@ pydantic >= 2.0  # Required for OpenAI server.
 prometheus_client >= 0.18.0
 pynvml == 11.5.0
 triton >= 2.1.0
+outlines >= 0.0.27
 cupy-cuda12x == 12.1.0  # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
diff --git a/tests/entrypoints/test_guided_processors.py b/tests/entrypoints/test_guided_processors.py
new file mode 100644
index 00000000000..5b39269916f
--- /dev/null
+++ b/tests/entrypoints/test_guided_processors.py
@@ -0,0 +1,75 @@
+# This unit test should be moved to a new
+# tests/test_guided_decoding directory.
+
+from transformers import AutoTokenizer
+import torch
+
+from vllm.model_executor.guided_logits_processors import (RegexLogitsProcessor,
+                                                          JSONLogitsProcessor)
+
+TEST_SCHEMA = {
+    "type": "object",
+    "properties": {
+        "name": {
+            "type": "string"
+        },
+        "age": {
+            "type": "integer"
+        },
+        "skills": {
+            "type": "array",
+            "items": {
+                "type": "string",
+                "maxLength": 10
+            },
+            "minItems": 3
+        },
+        "work history": {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "company": {
+                        "type": "string"
+                    },
+                    "duration": {
+                        "type": "string"
+                    },
+                    "position": {
+                        "type": "string"
+                    }
+                },
+                "required": ["company", "position"]
+            }
+        }
+    },
+    "required": ["name", "age", "skills", "work history"]
+}
+
+TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \
+             r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
+
+
+def test_guided_logits_processors():
+    """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
+    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
+    regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer)
+    json_LP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer)
+
+    regex_LP.init_state()
+    token_ids = tokenizer.encode(
+        f"Give an example IPv4 address with this regex: {TEST_REGEX}")
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    regex_LP(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert not torch.allclose(tensor, original_tensor)
+
+    json_LP.init_state()
+    token_ids = tokenizer.encode(
+        f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    json_LP(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert not torch.allclose(tensor, original_tensor)
diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py
index 72e23748997..e426cf7eed7 100644
--- a/tests/entrypoints/test_openai_server.py
+++ b/tests/entrypoints/test_openai_server.py
@@ -9,12 +9,64 @@ import ray  # using Ray for overall ease of process management, parallel request
 import openai  # use the official client for correctness check
 from huggingface_hub import snapshot_download  # downloading lora to test lora requests
+# imports for guided decoding tests
+import json
+import jsonschema
+import re
+
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
 MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 600 seconds
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # any model with a chat template should work here
 LORA_NAME = "typeof/zephyr-7b-beta-lora"  # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here
 
+TEST_SCHEMA = {
+    "type": "object",
+    "properties": {
+        "name": {
+            "type": "string"
+        },
+        "age": {
+            "type": "integer"
+        },
+        "skills": {
+            "type": "array",
+            "items": {
+                "type": "string",
+                "maxLength": 10
+            },
+            "minItems": 3
+        },
+        "work history": {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "company": {
+                        "type": "string"
+                    },
+                    "duration": {
+                        "type": "string"
+                    },
+                    "position": {
+                        "type": "string"
+                    }
+                },
+                "required": ["company", "position"]
+            }
+        }
+    },
+    "required": ["name", "age", "skills", "work history"]
+}
+
+TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \
+             r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
+
+TEST_CHOICE = [
+    "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby",
+    "Swift", "Kotlin"
+]
+
 pytestmark = pytest.mark.asyncio
 
@@ -325,6 +377,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
         max_tokens=max_tokens,
         temperature=0.0,
         logit_bias={str(token_id): 100},
+        seed=42,
     )
     assert completion.choices[0].text is not None and len(
         completion.choices[0].text) >= 5
@@ -358,5 +411,189 @@
     assert first_response != completion.choices[0].text
 
 
+async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=
+        f"Give an example JSON for an employee profile that fits this schema: {TEST_SCHEMA}",
+        n=3,
+        temperature=1.0,
+        max_tokens=500,
+        extra_body=dict(guided_json=TEST_SCHEMA))
+
+    assert completion.id is not None
+    assert completion.choices is not None and len(completion.choices) == 3
+    for i in range(3):
+        assert completion.choices[i].text is not None
+        output_json = json.loads(completion.choices[i].text)
+        jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)
+
+
+async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "Give an example JSON for an employee profile that " + \
+            f"fits this schema: {TEST_SCHEMA}"
+    }]
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=500,
+        extra_body=dict(guided_json=TEST_SCHEMA))
+    message = chat_completion.choices[0].message
+    assert message.content is not None
+    json1 = json.loads(message.content)
+    jsonschema.validate(instance=json1, schema=TEST_SCHEMA)
+
+    messages.append({"role": "assistant", "content": message.content})
+    messages.append({
+        "role":
+        "user",
+        "content":
+        "Give me another one with a different name and age"
+    })
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=500,
+        extra_body=dict(guided_json=TEST_SCHEMA))
+    message = chat_completion.choices[0].message
+    assert message.content is not None
+    json2 = json.loads(message.content)
+    jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
+    assert json1["name"] != json2["name"]
+    assert json1["age"] != json2["age"]
+
+
+async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}",
+        n=3,
+        temperature=1.0,
+        max_tokens=20,
+        extra_body=dict(guided_regex=TEST_REGEX))
+
+    assert completion.id is not None
+    assert completion.choices is not None and len(completion.choices) == 3
+    for i in range(3):
+        assert completion.choices[i].text is not None
+        assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None
+
+
+async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        f"Give an example IP address with this regex: {TEST_REGEX}"
+    }]
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=20,
+        extra_body=dict(guided_regex=TEST_REGEX))
+    ip1 = chat_completion.choices[0].message.content
+    assert ip1 is not None
+    assert re.fullmatch(TEST_REGEX, ip1) is not None
+
+    messages.append({"role": "assistant", "content": ip1})
+    messages.append({"role": "user", "content": "Give me a different one"})
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=20,
+        extra_body=dict(guided_regex=TEST_REGEX))
+    ip2 = chat_completion.choices[0].message.content
+    assert ip2 is not None
+    assert re.fullmatch(TEST_REGEX, ip2) is not None
+    assert ip1 != ip2
+
+
+async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt="The best language for type-safe systems programming is ",
+        n=2,
+        temperature=1.0,
+        max_tokens=10,
+        extra_body=dict(guided_choice=TEST_CHOICE))
+
+    assert completion.id is not None
+    assert completion.choices is not None and len(completion.choices) == 2
+    for i in range(2):
+        assert completion.choices[i].text in TEST_CHOICE
+
+
+async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        "The best language for type-safe systems programming is "
+    }]
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=10,
+        extra_body=dict(guided_choice=TEST_CHOICE))
+    choice1 = chat_completion.choices[0].message.content
+    assert choice1 in TEST_CHOICE
+
+    messages.append({"role": "assistant", "content": choice1})
+    messages.append({
+        "role": "user",
+        "content": "I disagree, pick another one"
+    })
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=10,
+        extra_body=dict(guided_choice=TEST_CHOICE))
+    choice2 = chat_completion.choices[0].message.content
+    assert choice2 in TEST_CHOICE
+    assert choice1 != choice2
+
+
+async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
+    with pytest.raises(openai.BadRequestError):
+        _ = await client.completions.create(
+            model=MODEL_NAME,
+            prompt="Give an example JSON that fits this schema: 42",
schema: 42", + extra_body=dict(guided_json=42)) + + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": + "user", + "content": + "The best language for type-safe systems programming is " + }] + with pytest.raises(openai.BadRequestError): + _ = await client.chat.completions.create(model=MODEL_NAME, + messages=messages, + extra_body=dict(guided_regex={ + 1: "Python", + 2: "C++" + })) + + with pytest.raises(openai.BadRequestError): + _ = await client.completions.create( + model=MODEL_NAME, + prompt="Give an example string that fits this regex", + extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA)) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7cba6546027..daa6419cdad 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -333,6 +333,9 @@ class AsyncLLMEngine: return (self.background_loop is not None and not self.background_loop.done()) + def get_tokenizer(self): + return self.engine.tokenizer.tokenizer + def start_background_loop(self) -> None: """Start the background loop.""" if self.is_running: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 97cfd797587..26499b8d7a6 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -3,7 +3,7 @@ import time from typing import Dict, List, Literal, Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from vllm.utils import random_uuid from vllm.sampling_params import SamplingParams @@ -86,6 +86,9 @@ class ChatCompletionRequest(BaseModel): min_p: Optional[float] = 0.0 include_stop_str_in_output: Optional[bool] = False length_penalty: Optional[float] = 1.0 + guided_json: Optional[Union[str, dict, BaseModel]] = None + guided_regex: Optional[str] = None + guided_choice: Optional[List[str]] = None def to_sampling_params(self) -> SamplingParams: if self.logprobs and not self.top_logprobs: @@ -131,6 +134,20 @@ class ChatCompletionRequest(BaseModel): logits_processors=logits_processors, ) + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + guide_count = sum([ + "guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice').") + return data + class CompletionRequest(BaseModel): model: str @@ -163,6 +180,9 @@ class CompletionRequest(BaseModel): min_p: Optional[float] = 0.0 include_stop_str_in_output: Optional[bool] = False length_penalty: Optional[float] = 1.0 + guided_json: Optional[Union[str, dict, BaseModel]] = None + guided_regex: Optional[str] = None + guided_choice: Optional[List[str]] = None def to_sampling_params(self): echo_without_generation = self.echo and self.max_tokens == 0 @@ -207,6 +227,20 @@ class CompletionRequest(BaseModel): logits_processors=logits_processors, ) + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + guide_count = sum([ + "guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + if guide_count > 1: + raise ValueError( + "You 
+                "('guided_json', 'guided_regex' or 'guided_choice').")
+        return data
+
 
 class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index e5ae39e110a..f4ad0aa5a01 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -12,6 +12,7 @@ from vllm.entrypoints.openai.protocol import (
     UsageInfo)
 from vllm.outputs import RequestOutput
 from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
+from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
 
 logger = init_logger(__name__)
 
@@ -62,6 +63,14 @@ class OpenAIServingChat(OpenAIServing):
                 prompt=prompt)
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
+            guided_decode_logits_processor = (
+                await get_guided_decoding_logits_processor(
+                    request, self.engine.get_tokenizer()))
+            if guided_decode_logits_processor:
+                if sampling_params.logits_processors is None:
+                    sampling_params.logits_processors = []
+                sampling_params.logits_processors.append(
+                    guided_decode_logits_processor)
         except ValueError as e:
             return self.create_error_response(str(e))
 
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 610f53549da..713e67793b2 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -16,6 +16,7 @@ from .protocol import (
 )
 from vllm.outputs import RequestOutput
 from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
+from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
 
 logger = init_logger(__name__)
 
@@ -286,6 +287,14 @@ class OpenAIServingCompletion(OpenAIServing):
         try:
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
+            guided_decode_logit_processor = (
+                await get_guided_decoding_logits_processor(
+                    request, self.engine.get_tokenizer()))
+            if guided_decode_logit_processor is not None:
+                if sampling_params.logits_processors is None:
+                    sampling_params.logits_processors = []
+                sampling_params.logits_processors.append(
+                    guided_decode_logit_processor)
 
             prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
 
             for i, prompt in enumerate(prompts):
diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py
new file mode 100644
index 00000000000..a8573f8bdc6
--- /dev/null
+++ b/vllm/model_executor/guided_decoding.py
@@ -0,0 +1,99 @@
+import asyncio
+import concurrent.futures
+from copy import copy
+from enum import Enum
+from functools import lru_cache
+from json import dumps as json_dumps
+from re import escape as regex_escape
+from typing import Union, Tuple
+from pydantic import BaseModel
+
+from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest
+from vllm.model_executor.guided_logits_processors import JSONLogitsProcessor, RegexLogitsProcessor
+
+
+class GuidedDecodingMode(Enum):
+    JSON = "json"
+    REGEX = "regex"
+    CHOICE = "choice"
+
+
+global_thread_pool = None  # used for generating logits processor fsm
+
+
+async def get_guided_decoding_logits_processor(
+        request: Union[CompletionRequest, ChatCompletionRequest],
+        tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
+    """
+    Given an OpenAI-compatible request, check for guided decoding parameters
+    and get the necessary logits processor for the given guide.
+    We cache logit processors by (guide, tokenizer), and on cache hit
+    we make a shallow copy to reuse the same underlying FSM.
+    """
+    global global_thread_pool
+    guide, mode = _get_guide_and_mode(request)
+    if not guide:
+        return None
+
+    if global_thread_pool is None:
+        global_thread_pool = concurrent.futures.ThreadPoolExecutor(
+            max_workers=2)
+    loop = asyncio.get_running_loop()
+
+    result = await loop.run_in_executor(global_thread_pool,
+                                        _get_cached_logits_processor, guide,
+                                        tokenizer, mode)
+
+    logits_processor = copy(result)
+    # reset logits processor's internal state
+    logits_processor.init_state()
+    return logits_processor
+
+
+def _get_guide_and_mode(
+    request: Union[CompletionRequest, ChatCompletionRequest]
+) -> Tuple[str, GuidedDecodingMode]:
+
+    if request.guided_json:
+        if not isinstance(request.guided_json, (str, dict, BaseModel)):
+            raise TypeError("JSON schema must be str, dict, or BaseModel")
+
+        json = request.guided_json
+        if isinstance(json, dict):
+            # turn dict into hashable string
+            json = json_dumps(json, sort_keys=True)
+        elif isinstance(json, BaseModel):
+            # use pydantic signature so that different model classes
+            # with the same fields will get hashed the same
+            json = str(json.__signature__)
+        return json, GuidedDecodingMode.JSON
+
+    elif request.guided_regex:
+        if not isinstance(request.guided_regex, str):
+            raise TypeError("Regex must be string")
+        return request.guided_regex, GuidedDecodingMode.REGEX
+
+    elif request.guided_choice:
+        if not isinstance(request.guided_choice, list):
+            raise TypeError("Choices must be a list")
+
+        # choice just uses regex
+        choices = [
+            regex_escape(str(choice)) for choice in request.guided_choice
+        ]
+        choices_regex = "(" + "|".join(choices) + ")"
+        return choices_regex, GuidedDecodingMode.CHOICE
+
+    else:
+        return None, None
+
+
+@lru_cache(maxsize=32)
+def _get_cached_logits_processor(guide: str, tokenizer,
+                                 mode: GuidedDecodingMode):
+    if mode == GuidedDecodingMode.JSON:
+        return JSONLogitsProcessor(guide, tokenizer)
+    elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
+        return RegexLogitsProcessor(guide, tokenizer)
+    else:
+        raise ValueError(f"Unknown guided decoding mode {mode}")
diff --git a/vllm/model_executor/guided_logits_processors.py b/vllm/model_executor/guided_logits_processors.py
new file mode 100644
index 00000000000..1b3e5e71a59
--- /dev/null
+++ b/vllm/model_executor/guided_logits_processors.py
@@ -0,0 +1,129 @@
+# Copyright 2024- the Outlines developers
+# This file is adapted from
+# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import math
+from collections import defaultdict
+from typing import Union, DefaultDict, Dict, List, Optional
+
+import torch
+from pydantic import BaseModel
+from outlines.fsm.fsm import RegexFSM
+from outlines.fsm.json_schema import build_regex_from_schema
+
+
+class RegexLogitsProcessor:
+
+    def __init__(self, regex_string: str, tokenizer):
+        """Compile the FSM that drives the regex-structured generation.
+
+        Parameters
+        ----------
+        regex_string
+            A string that represents a regular expression
+        tokenizer
+            The model's tokenizer
+
+        """
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        fsm = RegexFSM(regex_string, tokenizer)
+        self.fsm = fsm
+
+    def init_state(self):
+        """Initialize the FSM states."""
+        self.fsm_state: DefaultDict[int, int] = defaultdict(int)
+
+    def __call__(self, input_ids: List[int],
+                 scores: torch.Tensor) -> torch.Tensor:
+        """Use the FSM to bias the logits before sampling the next token."""
+
+        seq_id = hash(tuple(input_ids))
+
+        if len(input_ids) == 0:
+            self.init_state()
+        else:
+            last_token = input_ids[-1]
+            last_seq_id = hash(tuple(input_ids[:-1]))
+            self.fsm_state[seq_id] = self.fsm.next_state(
+                self.fsm_state[last_seq_id], last_token)
+
+        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
+
+        mask = torch.full((scores.shape[-1], ),
+                          -math.inf,
+                          device=scores.device)
+        mask[allowed_tokens] = 0
+        scores.add_(mask)
+
+        return scores
+
+    def adapt_tokenizer(self, tokenizer):
+        """Adapt vLLM's tokenizer to use to compile the FSM.
+
+        The API of Outlines tokenizers is slightly different from that of
+        `transformers`. In addition we need to handle the missing spaces in
+        Llama's tokenizer to be able to compile FSMs for this model.
+
+        """
+        tokenizer.vocabulary = tokenizer.get_vocab()
+        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+        def convert_token_to_string(token: str) -> str:
+            from transformers.file_utils import SPIECE_UNDERLINE
+
+            string = tokenizer.convert_tokens_to_string([token])
+
+            # A hack to handle missing spaces in HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+            return string
+
+        tokenizer.convert_token_to_string = convert_token_to_string
+
+        return tokenizer
+
+
+class JSONLogitsProcessor(RegexLogitsProcessor):
+
+    def __init__(self,
+                 schema: Union[str, Dict, BaseModel],
+                 tokenizer,
+                 whitespace_pattern: Optional[str] = None):
+        """Compile the FSM that drives the JSON-guided generation.
+
+        Parameters
+        ----------
+        schema
+            A JSON schema that encodes the structure we want the model to generate
+        tokenizer
+            The model's tokenizer
+        whitespace_pattern
+            Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
+            Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
+        """
+        if isinstance(schema, type(BaseModel)):
+            schema_str = json.dumps(schema.model_json_schema())
+        elif isinstance(schema, Dict):
+            schema_str = json.dumps(schema)
+        elif isinstance(schema, str):
+            schema_str = schema
+        else:
+            raise ValueError(
+                f"Cannot parse schema {schema}. The schema must be either " +
+                "a Pydantic object, a dictionary or a string that contains the JSON "
+                + "Schema specification")
+        regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
+        super().__init__(regex_string, tokenizer)
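
For reference, a minimal client-side sketch of the new request fields (not part of the patch itself): it assumes an OpenAI-compatible vLLM server built with this change is already running at the base_url shown, and the URL, api_key, schema, and prompts are illustrative placeholders rather than values fixed by the patch. As in the tests above, guided_json, guided_regex, and guided_choice travel through the official client's extra_body, and the request validator rejects any request that sets more than one of them.

import json

from openai import OpenAI

# Placeholder endpoint and credentials; point these at your own server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Small illustrative schema; any JSON schema accepted by guided_json works.
PROFILE_SCHEMA = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"}
    },
    "required": ["name", "age"]
}

# guided_json: constrain the completion to valid JSON matching the schema.
completion = client.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    prompt=f"Give an example JSON that fits this schema: {PROFILE_SCHEMA}",
    max_tokens=100,
    extra_body=dict(guided_json=PROFILE_SCHEMA))
print(json.loads(completion.choices[0].text))

# guided_choice: constrain a chat reply to one of the listed strings.
# Only one guided_* field may be set per request; mixing them is rejected.
chat = client.chat.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    messages=[{"role": "user", "content": "Pick a programming language."}],
    max_tokens=10,
    extra_body=dict(guided_choice=["Python", "Rust", "C++"]))
print(chat.choices[0].message.content)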