[Frontend] Factor out chat message parsing (#7055)
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -1,7 +1,8 @@
 import codecs
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from functools import lru_cache
-from typing import Awaitable, Iterable, List, Optional, Union, cast, final
+from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast,
+                    final)
 
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -65,8 +66,7 @@ class ConversationMessage(TypedDict):
 @dataclass(frozen=True)
 class ChatMessageParseResult:
     messages: List[ConversationMessage]
-    mm_futures: List[Awaitable[MultiModalDataDict]] = field(
-        default_factory=list)
+    mm_futures: List[Awaitable[MultiModalDataDict]]
 
 
 def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
@@ -174,7 +174,7 @@ def _parse_chat_message_content_parts(
     return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
 
 
-def parse_chat_message_content(
+def _parse_chat_message_content(
     message: ChatCompletionMessageParam,
     model_config: ModelConfig,
     tokenizer: PreTrainedTokenizer,
@@ -190,3 +190,21 @@ def parse_chat_message_content(
 
     return _parse_chat_message_content_parts(role, content, model_config,
                                              tokenizer)
+
+
+def parse_chat_messages(
+    messages: List[ChatCompletionMessageParam],
+    model_config: ModelConfig,
+    tokenizer: PreTrainedTokenizer,
+) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
+    conversation: List[ConversationMessage] = []
+    mm_futures: List[Awaitable[MultiModalDataDict]] = []
+
+    for msg in messages:
+        parse_result = _parse_chat_message_content(msg, model_config,
+                                                   tokenizer)
+
+        conversation.extend(parse_result.messages)
+        mm_futures.extend(parse_result.mm_futures)
+
+    return conversation, mm_futures
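The new parse_chat_messages helper above simply folds the per-message loop that each endpoint used to carry into chat_utils, returning the parsed conversation together with the still-pending multi-modal futures. A minimal caller-side sketch, assuming an async context and a HuggingFace tokenizer with a chat template; the coroutine name and argument plumbing are illustrative, not vLLM's actual server code:

from vllm.entrypoints.chat_utils import parse_chat_messages


async def render_prompt(messages, model_config, tokenizer):
    # One call replaces the old per-call-site loop over
    # parse_chat_message_content (now private as _parse_chat_message_content).
    conversation, mm_futures = parse_chat_messages(
        messages, model_config, tokenizer)

    # Image fetches come back as awaitables so text-only requests never pay
    # for them; resolve them only if the caller needs the decoded data.
    mm_data = [await fut for fut in mm_futures]

    prompt = tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    return prompt, mm_data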
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -1,6 +1,5 @@
 import time
-from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List,
-                    Optional)
+from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Union
 
@@ -11,7 +10,7 @@ from vllm.config import ModelConfig
 from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
                                          load_chat_template,
-                                         parse_chat_message_content)
+                                         parse_chat_messages)
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionLogProb, ChatCompletionLogProbs,
@@ -92,15 +91,8 @@ class OpenAIServingChat(OpenAIServing):
             tokenizer = await self.async_engine_client.get_tokenizer(
                 lora_request)
 
-            conversation: List[ConversationMessage] = []
-            mm_futures: List[Awaitable[MultiModalDataDict]] = []
-
-            for msg in request.messages:
-                chat_parsed_result = parse_chat_message_content(
-                    msg, model_config, tokenizer)
-
-                conversation.extend(chat_parsed_result.messages)
-                mm_futures.extend(chat_parsed_result.mm_futures)
+            conversation, mm_futures = parse_chat_messages(
+                request.messages, model_config, tokenizer)
 
             tool_dicts = None if request.tools is None else [
                 tool.model_dump() for tool in request.tools
@@ -115,6 +107,7 @@ class OpenAIServingChat(OpenAIServing):
                 chat_template=request.chat_template or self.chat_template,
                 **(request.chat_template_kwargs or {}),
            )
+            assert isinstance(prompt, str)
         except Exception as e:
             logger.error("Error in applying chat template from request: %s", e)
             return self.create_error_response(str(e))
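The added assert isinstance(prompt, str) pins down the return type of apply_chat_template, which HuggingFace tokenizers declare as either rendered text or token ids depending on the tokenize flag. A standalone illustration of that behaviour; the model name is an arbitrary public example, not something taken from this diff:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
chat = [{"role": "user", "content": "Hello!"}]

as_text = tok.apply_chat_template(chat, tokenize=False,
                                  add_generation_prompt=True)
as_ids = tok.apply_chat_template(chat, tokenize=True,
                                 add_generation_prompt=True)

assert isinstance(as_text, str)  # rendered template text
assert isinstance(as_ids, list)  # token ids instead of text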
--- a/vllm/entrypoints/openai/serving_tokenization.py
+++ b/vllm/entrypoints/openai/serving_tokenization.py
@@ -1,13 +1,11 @@
 from typing import List, Optional, Union
 
 from vllm.config import ModelConfig
+from vllm.engine.protocol import AsyncEngineClient
+from vllm.entrypoints.chat_utils import load_chat_template, parse_chat_messages
+from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
-from vllm.engine.protocol import AsyncEngineClient
-from vllm.entrypoints.chat_utils import (ConversationMessage,
-                                         load_chat_template,
-                                         parse_chat_message_content)
-from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                               DetokenizeResponse,
                                               ErrorResponse,
@@ -17,8 +15,11 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing)
+from vllm.logger import init_logger
 from vllm.utils import random_uuid
 
+logger = init_logger(__name__)
+
 
 class OpenAIServingTokenization(OpenAIServing):
 
@@ -62,12 +63,12 @@ class OpenAIServingTokenization(OpenAIServing):
         if isinstance(request, TokenizeChatRequest):
             model_config = self.model_config
 
-            conversation: List[ConversationMessage] = []
+            conversation, mm_futures = parse_chat_messages(
+                request.messages, model_config, tokenizer)
 
-            for message in request.messages:
-                result = parse_chat_message_content(message, model_config,
-                                                    tokenizer)
-                conversation.extend(result.messages)
+            if mm_futures:
+                logger.warning(
+                    "Multi-modal inputs are ignored during tokenization")
 
             prompt = tokenizer.apply_chat_template(
                 add_generation_prompt=request.add_generation_prompt,
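The tokenization endpoint only needs the rendered chat text, so any multi-modal futures returned by parse_chat_messages are checked for the warning but never awaited. For reference, a request shaped like the following (hypothetical URL, not taken from this diff) is the kind of payload whose image part would populate mm_futures and trigger that warning:

messages = [{
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this image?"},
        {"type": "image_url",
         "image_url": {"url": "https://example.com/cat.png"}},
    ],
}]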