Signed-off-by: Roger Wang <hey@rogerw.me>
This commit is contained in:
Roger Wang
2025-05-09 13:03:34 -07:00
parent 22481fbfa3
commit 32c0155774

View File

@ -27,6 +27,7 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam)
from openai.types.chat.chat_completion_content_part_input_audio_param import (
InputAudio)
from PIL import Image
from pydantic import TypeAdapter
# yapf: enable
from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
@ -87,6 +88,20 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False):
"""The type of the content part."""
# Wrapper dict that carries an in-memory image under the single required
# key "image".
class PILImage(TypedDict, total=False):
image: Required[Image.Image]
"""
A PIL.Image.Image object.
"""
# Chat content part whose image is supplied as an in-memory PIL object.
# Declared shape: {"type": "image", "image": {"image": <PIL.Image.Image>}}.
# NOTE(review): the PIL object is nested twice under "image" (this dict, then
# PILImage) — confirm that consumers unwrap both levels before using it.
class ChatCompletionContentPartPILImageParam(TypedDict, total=False):
# Wrapper dict holding the PIL image.
image: Required[PILImage]
type: Required[Literal["image"]]
"""The type of the content part."""
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain image_url.
This is supported by OpenAI API, although it is not documented.
@ -124,6 +139,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
ChatCompletionContentPartInputAudioParam,
ChatCompletionContentPartPILImageParam,
ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
CustomChatCompletionContentSimpleImageParam,
ChatCompletionContentPartImageEmbedsParam,
@ -680,6 +696,10 @@ class BaseMultiModalContentParser(ABC):
image_embeds: Union[str, dict[str, str]]) -> None:
raise NotImplementedError
# Abstract hook for an image content part given directly as a PIL object.
# The base implementation only raises; concrete parsers must override it.
@abstractmethod
def parse_pil_image(self, image: Image.Image) -> None:
raise NotImplementedError
# Abstract hook that receives the `audio_url` string of an audio content part.
# The base implementation only raises; concrete parsers must override it.
@abstractmethod
def parse_audio(self, audio_url: str) -> None:
raise NotImplementedError
@ -710,6 +730,10 @@ class MultiModalContentParser(BaseMultiModalContentParser):
placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)
def parse_pil_image(self, image: Image.Image) -> None:
    """Register an in-memory PIL image and record its placeholder.

    The image is handed to the tracker under the ``"image"`` modality, and
    whatever the tracker returns is forwarded to ``_add_placeholder``.
    """
    self._add_placeholder(self._tracker.add("image", image))
def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
if isinstance(image_embeds, dict):
@ -761,6 +785,10 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)
def parse_pil_image(self, image: Image.Image) -> None:
    """Register an already-loaded PIL image with the async tracker.

    The URL-based image path of this parser hands the tracker a coroutine,
    and ``parse_image_embeds`` hands it an ``asyncio.Future``, so the
    tracker is presumably awaiting every item it holds (confirm against the
    tracker implementation). A bare ``Image.Image`` is not awaitable, so it
    is wrapped in an already-resolved future to keep the items uniform.

    Args:
        image: The PIL image taken from an ``"image"`` content part.
    """
    # Nothing to fetch or decode: the future is resolved immediately.
    resolved: asyncio.Future[Image.Image] = asyncio.Future()
    resolved.set_result(image)

    placeholder = self._tracker.add("image", resolved)
    self._add_placeholder(placeholder)
def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()
@ -902,6 +930,8 @@ _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python
# Validator for content parts that carry raw, in-memory multimodal data
# (currently a PIL image).
# NOTE(review): pydantic normally needs `arbitrary_types_allowed` to build a
# schema for non-pydantic classes such as PIL.Image.Image — confirm this
# TypeAdapter constructs without error at import time.
_PILImageParser = TypeAdapter(ChatCompletionContentPartPILImageParam).validate_python # noqa: E501
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio]
@ -912,6 +942,8 @@ MM_PARSER_MAP: dict[
] = {
"text":
lambda part: _TextParser(part).get("text", None),
"image":
lambda part: _PILImageParser(part).get("image", None),
"image_url":
lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
"image_embeds":
@ -985,7 +1017,7 @@ def _parse_chat_message_content_mm_part(
# Content-part "type" values listed as valid for multimodal message parts.
# "image" is the part type that carries an in-memory PIL image; each type
# appears exactly once.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
                                       "image_embeds", "image",
                                       "audio_url", "input_audio", "video_url")
@ -1056,6 +1088,10 @@ def _parse_chat_message_content_part(
else:
return str_content
if part_type == "image":
image = cast(Image.Image, content)
mm_parser.parse_pil_image(image)
return {'type': 'image'} if wrap_dicts else None
if part_type == "image_url":
str_content = cast(str, content)
mm_parser.parse_image(str_content)