[Frontend] Support simpler image input format (#9478)

This commit is contained in:
Yue Zhang
2024-10-18 23:17:07 -07:00
committed by GitHub
parent 85dc92fc98
commit c5eea3c8ba
2 changed files with 138 additions and 23 deletions

View File

@ -388,3 +388,29 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
"text": "What about these two?"
}]
}], phi3v_model_config, phi3v_tokenizer)
def test_parse_chat_messages_multiple_images_uncommon_input(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [
"What's in these images?", {
"image_url": image_url
}, {
"image_url": image_url
}
]
}], phi3v_model_config, phi3v_tokenizer)
assert conversation == [{
"role":
"user",
"content":
"<|image_1|>\n<|image_2|>\nWhat's in these images?"
}]
_assert_mm_data_is_image_input(mm_data, 2)

View File

@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
from collections import defaultdict
from functools import lru_cache, partial
from pathlib import Path
from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal,
Mapping, Optional, Tuple, TypeVar, Union, cast)
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
Literal, Mapping, Optional, Tuple, TypeVar, Union, cast)
# yapf conflicts with isort for this block
# yapf: disable
@ -59,10 +59,35 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
"""The type of the content part."""
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain image_url.
This is supported by OpenAI API, although it is not documented.
Example:
{
"image_url": "https://example.com/image.jpg"
}
"""
image_url: Required[str]
class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain audio_url.
Example:
{
"audio_url": "https://example.com/audio.mp3"
}
"""
audio_url: Required[str]
ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
ChatCompletionContentPartRefusalParam,
CustomChatCompletionContentPartParam]
CustomChatCompletionContentPartParam,
CustomChatCompletionContentSimpleImageParam,
CustomChatCompletionContentSimpleAudioParam, str]
class CustomChatCompletionMessageParam(TypedDict, total=False):
@ -387,6 +412,71 @@ _AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}
# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
"text":
lambda part: _TextParser(part).get("text", ""),
"image_url":
lambda part: _ImageParser(part).get("image_url", {}).get("url", ""),
"audio_url":
lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
"refusal":
lambda part: _RefusalParser(part).get("refusal", ""),
}
def _parse_chat_message_content_mm_part(
part: ChatCompletionContentPartParam) -> Tuple[str, str]:
"""
Parses a given multi modal content part based on its type.
Args:
part: A dict containing the content part, with a potential 'type' field.
Returns:
A tuple (part_type, content) where:
- part_type: Type of the part (e.g., 'text', 'image_url').
- content: Parsed content (e.g., text, image URL).
Raises:
ValueError: If the 'type' field is missing and no direct URL is found.
"""
assert isinstance(
part, dict) # This is needed to avoid mypy errors: part.get() from str
part_type = part.get("type", None)
if isinstance(part_type, str) and part_type in MM_PARSER_MAP:
content = MM_PARSER_MAP[part_type](part)
# Special case for 'image_url.detail'
if part_type == "image_url" and part.get("detail") != "auto":
logger.warning("'image_url.detail' is currently not supported "
"and will be ignored.")
return part_type, content
# Handle missing 'type' but provided direct URL fields.
if part_type is None:
if part.get("image_url") is not None:
image_params = cast(CustomChatCompletionContentSimpleImageParam,
part)
return "image_url", image_params.get("image_url", "")
if part.get("audio_url") is not None:
audio_params = cast(CustomChatCompletionContentSimpleAudioParam,
part)
return "audio_url", audio_params.get("audio_url", "")
# Raise an error if no 'type' or direct URL is found.
raise ValueError("Missing 'type' field in multimodal part.")
if not isinstance(part_type, str):
raise ValueError("Invalid 'type' field in multimodal part.")
return part_type, "unknown part_type content"
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
"audio_url")
def _parse_chat_message_content_parts(
role: str,
@ -402,29 +492,28 @@ def _parse_chat_message_content_parts(
has_image = False
for part in parts:
part_type = part["type"]
if part_type == "text":
text = _TextParser(part)["text"]
if isinstance(part, str): # Handle plain text parts
text = _TextParser(part)
texts.append(text)
elif part_type == "image_url":
image_url = _ImageParser(part)["image_url"]
else: # Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)
if image_url.get("detail", "auto") != "auto":
logger.warning(
"'image_url.detail' is currently not supported and "
"will be ignored.")
# if part_type is text/refusal/image_url/audio_url but
# content is empty, logg a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
logger.warning("Skipping multimodal part "
"with empty / unparsable content.")
continue
mm_parser.parse_image(image_url["url"])
has_image = True
elif part_type == "audio_url":
audio_url = _AudioParser(part)["audio_url"]
mm_parser.parse_audio(audio_url["url"])
elif part_type == "refusal":
text = _RefusalParser(part)["refusal"]
texts.append(text)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
if part_type in ("text", "refusal"):
texts.append(content)
elif part_type == "image_url":
mm_parser.parse_image(content)
has_image = True
elif part_type == "audio_url":
mm_parser.parse_audio(content)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts)
if keep_multimodal_content: