mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 10:03:51 +08:00
helper for structured data
This commit is contained in:
@ -346,74 +346,75 @@ class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
|
||||
else:
|
||||
raise KeyError(f"Unexpected input keys in examples: {list(examples[0].keys())}.")
|
||||
|
||||
def _collate_language_modeling(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
# Handle images
|
||||
images = [example.get("images", []) for example in examples]
|
||||
# Transformers requires at least one image in the batch, otherwise it throws an error
|
||||
if all(img_list == [] for img_list in images):
|
||||
images = None
|
||||
def _has_structured_content(self, messages: list[dict]) -> tuple[bool, bool]:
|
||||
"""
|
||||
Check if messages contain structured content with images or videos.
|
||||
|
||||
# Handle videos
|
||||
videos = [example.get("videos", []) for example in examples]
|
||||
if all(vid_list == [] for vid_list in videos):
|
||||
videos = None
|
||||
|
||||
# Check if messages contain structured content with images or videos
|
||||
Returns:
|
||||
tuple[bool, bool]: (has_image_content, has_video_content)
|
||||
"""
|
||||
has_image_content = False
|
||||
has_video_content = False
|
||||
if "messages" in examples[0]: # conversational case
|
||||
|
||||
if messages and isinstance(messages, list):
|
||||
for msg in messages:
|
||||
if isinstance(msg.get("content"), list):
|
||||
for item in msg["content"]:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "image":
|
||||
has_image_content = True
|
||||
elif item.get("type") == "video":
|
||||
has_video_content = True
|
||||
if has_image_content and has_video_content:
|
||||
break
|
||||
|
||||
return has_image_content, has_video_content
|
||||
|
||||
def _collate_language_modeling(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
# Extract images and videos from examples
|
||||
images = [example.get("images", []) for example in examples]
|
||||
videos = [example.get("videos", []) for example in examples]
|
||||
images = None if all(img == [] for img in images) else images
|
||||
videos = None if all(vid == [] for vid in videos) else videos
|
||||
|
||||
# Apply chat template for conversational data
|
||||
if "messages" in examples[0]:
|
||||
messages_list = [example["messages"] for example in examples]
|
||||
if messages_list and isinstance(messages_list[0], list):
|
||||
for msg in messages_list[0]:
|
||||
if isinstance(msg.get("content"), list):
|
||||
for item in msg["content"]:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "image":
|
||||
has_image_content = True
|
||||
elif item.get("type") == "video":
|
||||
has_video_content = True
|
||||
if has_image_content and has_video_content:
|
||||
break
|
||||
# Check if messages use structured content format ({"type": "image"} or {"type": "video"})
|
||||
has_image_content, has_video_content = self._has_structured_content(messages_list[0])
|
||||
|
||||
# For images/videos with structured content, pass them to apply_chat_template
|
||||
kwargs = {}
|
||||
if has_image_content and images is not None:
|
||||
kwargs["images"] = images
|
||||
if has_video_content and videos is not None:
|
||||
kwargs["videos"] = videos
|
||||
|
||||
if kwargs:
|
||||
texts = self.processor.apply_chat_template(messages_list, **kwargs)
|
||||
else:
|
||||
texts = self.processor.apply_chat_template(messages_list)
|
||||
elif self.dataset_text_field in examples[0]: # standard case
|
||||
# For structured content, pass images/videos to apply_chat_template for extraction
|
||||
template_kwargs = {}
|
||||
if has_image_content and images:
|
||||
template_kwargs["images"] = images
|
||||
if has_video_content and videos:
|
||||
template_kwargs["videos"] = videos
|
||||
texts = self.processor.apply_chat_template(messages_list, **template_kwargs)
|
||||
elif self.dataset_text_field in examples[0]:
|
||||
texts = [example[self.dataset_text_field] for example in examples]
|
||||
has_image_content = has_video_content = False
|
||||
else:
|
||||
raise KeyError(
|
||||
"The input examples must contain either 'messages' for conversational data or 'text' for standard "
|
||||
"data."
|
||||
"The input examples must contain either 'messages' for conversational data or 'text' for standard data."
|
||||
)
|
||||
|
||||
# Process with images and videos
|
||||
# Build processor kwargs
|
||||
processor_kwargs = {
|
||||
"text": texts,
|
||||
"padding": True,
|
||||
"padding_side": "right",
|
||||
"pad_to_multiple_of": self.pad_to_multiple_of,
|
||||
"return_tensors": self.return_tensors,
|
||||
"add_special_tokens": False, # to avoid adding the BOS, twice see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
|
||||
"add_special_tokens": False,
|
||||
}
|
||||
# Pass truncation parameters to processor if max_length is set
|
||||
# The processor will handle truncation appropriately for both images and videos
|
||||
if self.max_length is not None:
|
||||
processor_kwargs["truncation"] = True
|
||||
processor_kwargs["max_length"] = self.max_length
|
||||
|
||||
# Don't pass images/videos to processor if they're already in structured content
|
||||
# The processor will extract them from the formatted text
|
||||
if images is not None and not has_image_content:
|
||||
# Add images/videos to processor only if not already in structured content
|
||||
if images and not has_image_content:
|
||||
processor_kwargs["images"] = images
|
||||
if videos is not None and not has_video_content:
|
||||
if videos and not has_video_content:
|
||||
processor_kwargs["videos"] = videos
|
||||
|
||||
output = self.processor(**processor_kwargs)
|
||||
@ -432,25 +433,28 @@ class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
|
||||
"Padding to a multiple of a value is not yet implemented for vision-language modeling and "
|
||||
"prompt-completion data yet."
|
||||
)
|
||||
# Handle images
|
||||
# Extract images and videos from examples
|
||||
images = [example.get("images", []) for example in examples]
|
||||
# Transformers requires at least one image in the batch, otherwise it throws an error
|
||||
if all(img_list == [] for img_list in images):
|
||||
images = None
|
||||
|
||||
# Handle videos
|
||||
videos = [example.get("videos", []) for example in examples]
|
||||
if all(vid_list == [] for vid_list in videos):
|
||||
videos = None
|
||||
images = None if all(img == [] for img in images) else images
|
||||
videos = None if all(vid == [] for vid in videos) else videos
|
||||
|
||||
# Apply chat template for conversational data
|
||||
if is_conversational(examples[0]):
|
||||
# Check if messages use structured content format
|
||||
first_prompt_completion = examples[0]["prompt"] + examples[0]["completion"]
|
||||
has_image_content, has_video_content = self._has_structured_content(first_prompt_completion)
|
||||
|
||||
# For non-structured content, add image placeholders (videos require structured content)
|
||||
if not (has_image_content or has_video_content):
|
||||
for example in examples:
|
||||
num_images = len(example.get("images", []))
|
||||
if num_images > 0 and not example.get("videos"):
|
||||
prepare_multimodal_messages(example["prompt"] + example["completion"], num_images=num_images)
|
||||
|
||||
if is_conversational(examples[0]): # conversational case
|
||||
for example in examples:
|
||||
num_images = len(example.get("images", []))
|
||||
num_videos = len(example.get("videos", []))
|
||||
# Only prepare multimodal messages for images; videos use native <video> tags
|
||||
if num_images > 0 and num_videos == 0:
|
||||
prepare_multimodal_messages(example["prompt"] + example["completion"], num_images=num_images)
|
||||
examples = [apply_chat_template(example, self.processor) for example in examples]
|
||||
else:
|
||||
has_image_content = has_video_content = False
|
||||
|
||||
prompts = [example["prompt"] for example in examples]
|
||||
completions = [example["completion"] for example in examples]
|
||||
@ -461,11 +465,12 @@ class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
|
||||
"padding": True,
|
||||
"padding_side": "left",
|
||||
"return_tensors": self.return_tensors,
|
||||
"add_special_tokens": False, # to avoid adding the BOS, twice see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
|
||||
"add_special_tokens": False,
|
||||
}
|
||||
if images is not None:
|
||||
# Add images/videos to processor only if not already in structured content
|
||||
if images and not has_image_content:
|
||||
prompt_kwargs["images"] = images
|
||||
if videos is not None:
|
||||
if videos and not has_video_content:
|
||||
prompt_kwargs["videos"] = videos
|
||||
|
||||
processed_prompts = self.processor(**prompt_kwargs)
|
||||
|
Reference in New Issue
Block a user