[training_utils] fix: allow empty image_key/video_key in rl dataset (#3281)

This commit is contained in:
ℍ𝕠𝕝𝕝𝕠𝕨 𝕄𝕒𝕟
2025-08-31 12:35:00 +03:00
committed by GitHub
parent 98676e8add
commit 14227201ec
2 changed files with 19 additions and 8 deletions

View File

@ -78,8 +78,9 @@ class CustomRLHFDataset(RLHFDataset):
multi_modal_data = {}
images = None
if self.image_key in row_dict and row_dict.get(self.image_key, None) is not None:
images = [Image.open(io.BytesIO(image["bytes"])) for image in row_dict.pop(self.image_key)]
row_dict_images = row_dict.pop(self.image_key, None)
if row_dict_images:
images = [Image.open(io.BytesIO(image["bytes"])) for image in row_dict_images]
# due to the image key is "image" instead of "images" in vllm, we need to use "image" here
# link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205 # noqa: E501

View File

@ -157,8 +157,16 @@ class RLHFDataset(Dataset):
raw_prompt = self.processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs
)
images = [process_image(image) for image in doc[image_key]] if image_key in doc else None
videos = [process_video(video) for video in doc[video_key]] if video_key in doc else None
images = (
[process_image(image) for image in doc[image_key]]
if image_key in doc and doc[image_key]
else None
)
videos = (
[process_video(video) for video in doc[video_key]]
if video_key in doc and doc[video_key]
else None
)
return len(processor(text=[raw_prompt], images=images, videos=videos)["input_ids"][0])
@ -230,16 +238,18 @@ class RLHFDataset(Dataset):
multi_modal_data = {}
images = None
if self.image_key in row_dict and row_dict.get(self.image_key, None) is not None:
images = [process_image(image) for image in row_dict.pop(self.image_key)]
row_dict_images = row_dict.pop(self.image_key, None)
if row_dict_images:
images = [process_image(image) for image in row_dict_images]
# due to the image key is "image" instead of "images" in vllm, we need to use "image" here
# link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205
multi_modal_data["image"] = images
videos = None
if self.video_key in row_dict and row_dict.get(self.video_key, None) is not None:
videos = [process_video(video) for video in row_dict.pop(self.video_key)]
row_dict_videos = row_dict.pop(self.video_key, None)
if row_dict_videos:
videos = [process_video(video) for video in row_dict_videos]
# due to the video key is "video" instead of "videos" in vllm, we need to use "video" here
# link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205