mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Frontend] Re-enable custom roles in Chat Completions API (#4758)
This commit is contained in:
@ -783,6 +783,36 @@ async def test_complex_message_content(server, client: openai.AsyncOpenAI):
|
||||
assert content == "2"
|
||||
|
||||
|
||||
async def test_custom_role(server, client: openai.AsyncOpenAI):
    """Send the same question under a custom role in both content forms.

    How the model reacts to a non-standard role is unspecified, so the test
    only requires that plain-string content and list-of-parts content yield
    the same completion.
    """
    plain_message = {
        "role": "my-custom-role",
        "content": "what is 1+1?",
    }
    parts_message = {
        "role": "my-custom-role",
        "content": [{
            "type": "text",
            "text": "what is 1+1?"
        }]
    }

    # temperature=0 and a fixed seed are passed so the two responses can be
    # compared for equality.
    plain_resp = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=[plain_message],  # type: ignore
        temperature=0,
        seed=0)

    parts_resp = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=[parts_message],  # type: ignore
        temperature=0,
        seed=0)

    plain_content = plain_resp.choices[0].message.content
    parts_content = parts_resp.choices[0].message.content
    assert plain_content == parts_content
|
||||
|
||||
|
||||
async def test_guided_grammar(server, client: openai.AsyncOpenAI):
|
||||
simple_sql_grammar = """
|
||||
start: select_statement
|
||||
|
@ -3,16 +3,50 @@
|
||||
import time
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
import openai.types.chat
|
||||
import torch
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import Annotated
|
||||
# pydantic needs the TypedDict from typing_extensions
|
||||
from typing_extensions import Annotated, Required, TypedDict
|
||||
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
|
||||
# Content part with a free-form "type". Only "type" is required
# (total=False), and the pydantic config below allows extra keys.
class CustomChatCompletionContentPartParam(TypedDict, total=False):
    __pydantic_config__ = ConfigDict(extra="allow")  # type: ignore

    # Discriminator naming the kind of content part.
    type: Required[str]
    """The type of the content part."""
|
||||
|
||||
|
||||
# A content part is either one of the OpenAI SDK's content part types or
# the custom part type defined in this module.
ChatCompletionContentPartParam = Union[
    openai.types.chat.ChatCompletionContentPartParam,
    CustomChatCompletionContentPartParam,
]
|
||||
|
||||
|
||||
class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""

    # Any string is accepted as a role; this is the only required key.
    role: Required[str]
    """The role of the message's author."""

    # Either a plain string or a list of content parts.
    content: Union[str, List[ChatCompletionContentPartParam]]
    """The contents of the message."""

    # Optional participant name.
    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """
|
||||
|
||||
|
||||
# A message is either one of the OpenAI SDK's message types or the
# custom-role message type defined in this module.
ChatCompletionMessageParam = Union[
    openai.types.chat.ChatCompletionMessageParam,
    CustomChatCompletionMessageParam,
]
|
||||
|
||||
|
||||
class OpenAIBaseModel(BaseModel):
|
||||
# OpenAI API does not allow extra fields
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
@ -1,15 +1,16 @@
|
||||
import codecs
|
||||
import time
|
||||
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
|
||||
Optional, Tuple, TypedDict, Union, final)
|
||||
from dataclasses import dataclass
|
||||
from typing import (AsyncGenerator, AsyncIterator, Iterable, List, Optional,
|
||||
TypedDict, Union, cast, final)
|
||||
|
||||
from fastapi import Request
|
||||
from openai.types.chat import (ChatCompletionContentPartParam,
|
||||
ChatCompletionRole)
|
||||
from openai.types.chat import ChatCompletionContentPartTextParam
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionContentPartParam, ChatCompletionMessageParam,
|
||||
ChatCompletionRequest, ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
|
||||
@ -31,6 +32,11 @@ class ConversationMessage(TypedDict):
|
||||
content: str
|
||||
|
||||
|
||||
# Immutable result of parsing one chat message.
@dataclass(frozen=True)
class ChatMessageParseResult:
    # Conversation messages produced from the parsed request message.
    messages: List[ConversationMessage]
|
||||
|
||||
|
||||
class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
def __init__(self,
|
||||
@ -77,27 +83,40 @@ class OpenAIServingChat(OpenAIServing):
|
||||
logger.warning(
|
||||
"No chat template provided. Chat API will not work.")
|
||||
|
||||
def _parse_chat_message_content_parts(
    self,
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
) -> ChatMessageParseResult:
    """Collapse a sequence of content parts into one conversation message.

    Args:
        role: Role of the message's author; any string is accepted.
        parts: Content parts of the message. Only parts whose ``"type"``
            is ``"text"`` are supported.

    Returns:
        A ``ChatMessageParseResult`` holding a single message whose content
        is the text of every part joined with newlines.

    Raises:
        NotImplementedError: If a part's type is anything other than
            ``"text"``.
    """
    texts: List[str] = []

    for part in parts:
        part_type = part["type"]
        if part_type == "text":
            # The part type is a union, so cast to the text variant before
            # reading its "text" key.
            text = cast(ChatCompletionContentPartTextParam, part)["text"]
            texts.append(text)
        else:
            raise NotImplementedError(f"Unknown part type: {part_type}")

    messages = [ConversationMessage(role=role, content="\n".join(texts))]

    return ChatMessageParseResult(messages=messages)
|
||||
|
||||
def _parse_chat_message_content(
    self,
    message: ChatCompletionMessageParam,
) -> ChatMessageParseResult:
    """Turn one request message into zero or more conversation messages.

    A message with no content produces an empty result, string content
    produces a single message, and any other content is handed to
    ``_parse_chat_message_content_parts``.
    """
    author = message["role"]
    # "content" may be absent, so read it with .get() rather than indexing.
    body = message.get("content")

    if isinstance(body, str):
        parsed = [ConversationMessage(role=author, content=body)]
        return ChatMessageParseResult(messages=parsed)

    if body is None:
        return ChatMessageParseResult(messages=[])

    return self._parse_chat_message_content_parts(author, body)
|
||||
|
||||
async def create_chat_completion(
|
||||
self, request: ChatCompletionRequest, raw_request: Request
|
||||
@ -119,11 +138,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
try:
|
||||
conversation: List[ConversationMessage] = []
|
||||
|
||||
for m in request.messages:
|
||||
messages, _ = self._parse_chat_message_content(
|
||||
m["role"], m["content"])
|
||||
for msg in request.messages:
|
||||
parsed_msg = self._parse_chat_message_content(msg)
|
||||
|
||||
conversation.extend(messages)
|
||||
conversation.extend(parsed_msg.messages)
|
||||
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
conversation=conversation,
|
||||
@ -387,4 +405,4 @@ class OpenAIServingChat(OpenAIServing):
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
return response
|
||||
return response
|
||||
|
Reference in New Issue
Block a user