[Frontend] Factor out chat message parsing (#7055)

Cyrus Leung
2024-08-03 12:31:27 +08:00
committed by GitHub
parent 69ea15e5cc
commit 8c025fa703
3 changed files with 39 additions and 27 deletions

File: vllm/entrypoints/chat_utils.py

@@ -1,7 +1,8 @@
 import codecs
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from functools import lru_cache
-from typing import Awaitable, Iterable, List, Optional, Union, cast, final
+from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast,
+                    final)
 
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -65,8 +66,7 @@ class ConversationMessage(TypedDict):
 @dataclass(frozen=True)
 class ChatMessageParseResult:
     messages: List[ConversationMessage]
-    mm_futures: List[Awaitable[MultiModalDataDict]] = field(
-        default_factory=list)
+    mm_futures: List[Awaitable[MultiModalDataDict]]
 
 
 def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
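Note on the ChatMessageParseResult change: with field(default_factory=list) removed, mm_futures becomes a required constructor argument of the frozen dataclass. The following stand-alone sketch (toy types, not the vLLM classes) illustrates the difference:

from dataclasses import dataclass, field
from typing import Any, List

@dataclass(frozen=True)
class WithDefault:
    messages: List[str]
    mm_futures: List[Any] = field(default_factory=list)  # old style: optional

@dataclass(frozen=True)
class WithoutDefault:
    messages: List[str]
    mm_futures: List[Any]  # new style: must be passed explicitly

WithDefault(messages=["hi"])                     # OK: defaults to []
WithoutDefault(messages=["hi"], mm_futures=[])   # OK: passed explicitly
# WithoutDefault(messages=["hi"])                # TypeError: missing mm_futures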
@@ -174,7 +174,7 @@ def _parse_chat_message_content_parts(
     return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
 
 
-def parse_chat_message_content(
+def _parse_chat_message_content(
     message: ChatCompletionMessageParam,
     model_config: ModelConfig,
     tokenizer: PreTrainedTokenizer,
@@ -190,3 +190,21 @@ def parse_chat_message_content(
     return _parse_chat_message_content_parts(role, content, model_config,
                                              tokenizer)
+
+
+def parse_chat_messages(
+    messages: List[ChatCompletionMessageParam],
+    model_config: ModelConfig,
+    tokenizer: PreTrainedTokenizer,
+) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
+    conversation: List[ConversationMessage] = []
+    mm_futures: List[Awaitable[MultiModalDataDict]] = []
+
+    for msg in messages:
+        parse_result = _parse_chat_message_content(msg, model_config,
+                                                   tokenizer)
+
+        conversation.extend(parse_result.messages)
+        mm_futures.extend(parse_result.mm_futures)
+
+    return conversation, mm_futures
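The new parse_chat_messages helper returns the multi-modal inputs as a list of awaitables rather than resolved data, leaving it to the caller to decide when (and whether) to await them. Below is a hedged, self-contained sketch of that pattern; the coroutine, message format, and parse function are stand-ins, not code from this commit:

import asyncio
from typing import Awaitable, Dict, List, Tuple

async def _fetch_image(url: str) -> Dict[str, str]:
    # Stand-in for a multi-modal fetch; in vLLM the futures resolve to a
    # MultiModalDataDict instead.
    await asyncio.sleep(0)
    return {"image": url}

def parse(messages: List[str]) -> Tuple[List[str], List[Awaitable[Dict[str, str]]]]:
    # Mirrors the shape of parse_chat_messages: text is aggregated eagerly,
    # while multi-modal data stays behind awaitables for the caller to resolve.
    conversation = [m.upper() for m in messages]
    mm_futures = [_fetch_image(m) for m in messages if m.startswith("http")]
    return conversation, mm_futures

async def main() -> None:
    conversation, mm_futures = parse(["hello", "http://example.com/cat.png"])
    mm_data = await asyncio.gather(*mm_futures) if mm_futures else []
    print(conversation, mm_data)

asyncio.run(main())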

File: vllm/entrypoints/openai/serving_chat.py

@@ -1,6 +1,5 @@
 import time
-from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List,
-                    Optional)
+from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Union
@@ -11,7 +10,7 @@ from vllm.config import ModelConfig
 from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
                                          load_chat_template,
-                                         parse_chat_message_content)
+                                         parse_chat_messages)
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionLogProb, ChatCompletionLogProbs,
@@ -92,15 +91,8 @@ class OpenAIServingChat(OpenAIServing):
             tokenizer = await self.async_engine_client.get_tokenizer(
                 lora_request)
 
-            conversation: List[ConversationMessage] = []
-            mm_futures: List[Awaitable[MultiModalDataDict]] = []
-
-            for msg in request.messages:
-                chat_parsed_result = parse_chat_message_content(
-                    msg, model_config, tokenizer)
-
-                conversation.extend(chat_parsed_result.messages)
-                mm_futures.extend(chat_parsed_result.mm_futures)
+            conversation, mm_futures = parse_chat_messages(
+                request.messages, model_config, tokenizer)
 
             tool_dicts = None if request.tools is None else [
                 tool.model_dump() for tool in request.tools
@@ -115,6 +107,7 @@ class OpenAIServingChat(OpenAIServing):
                 chat_template=request.chat_template or self.chat_template,
                 **(request.chat_template_kwargs or {}),
             )
+            assert isinstance(prompt, str)
         except Exception as e:
             logger.error("Error in applying chat template from request: %s", e)
             return self.create_error_response(str(e))
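The added assert narrows the prompt type before it is used: a chat-template call can yield either rendered text or token ids depending on its arguments, and this endpoint needs the string form. A generic sketch of the narrowing idiom (toy function, not the tokenizer API):

from typing import List, Union

def render(tokenize: bool) -> Union[str, List[int]]:
    # Toy stand-in for a chat-template call that may return text or token ids.
    return [1, 2, 3] if tokenize else "rendered prompt"

prompt = render(tokenize=False)
assert isinstance(prompt, str)  # fail fast, and narrow the type for checkers
print(prompt.upper())           # safe to use str methods after the assert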

File: vllm/entrypoints/openai/serving_tokenization.py

@@ -1,13 +1,11 @@
 from typing import List, Optional, Union
 
 from vllm.config import ModelConfig
+from vllm.engine.protocol import AsyncEngineClient
+from vllm.entrypoints.chat_utils import load_chat_template, parse_chat_messages
+from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
-from vllm.engine.protocol import AsyncEngineClient
-from vllm.entrypoints.chat_utils import (ConversationMessage,
-                                         load_chat_template,
-                                         parse_chat_message_content)
-from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                               DetokenizeResponse,
                                               ErrorResponse,
@@ -17,8 +15,11 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing)
+from vllm.logger import init_logger
 from vllm.utils import random_uuid
 
+logger = init_logger(__name__)
+
 
 class OpenAIServingTokenization(OpenAIServing):
@@ -62,12 +63,12 @@ class OpenAIServingTokenization(OpenAIServing):
         if isinstance(request, TokenizeChatRequest):
             model_config = self.model_config
 
-            conversation: List[ConversationMessage] = []
+            conversation, mm_futures = parse_chat_messages(
+                request.messages, model_config, tokenizer)
 
-            for message in request.messages:
-                result = parse_chat_message_content(message, model_config,
-                                                    tokenizer)
-                conversation.extend(result.messages)
+            if mm_futures:
+                logger.warning(
+                    "Multi-modal inputs are ignored during tokenization")
 
             prompt = tokenizer.apply_chat_template(
                 add_generation_prompt=request.add_generation_prompt,
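This file previously had no logger of its own; the commit adds one so that multi-modal inputs dropped during tokenization are flagged rather than silently ignored. A minimal stdlib sketch of the same pattern (vLLM uses its own init_logger helper instead of logging.getLogger):

import logging

# vLLM wraps this in init_logger(__name__); stdlib logging shows the same idea.
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

def tokenize_chat(conversation, mm_futures):
    # Tokenization has no use for multi-modal inputs, so warn about them
    # instead of dropping them without a trace.
    if mm_futures:
        logger.warning("Multi-modal inputs are ignored during tokenization")
    return " ".join(conversation).split()

print(tokenize_chat(["Hello", "world"], mm_futures=[object()]))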