mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
[training_utils] fix: allow empty image_key/video_key in rl dataset (#3281)
This commit is contained in:
@ -78,8 +78,9 @@ class CustomRLHFDataset(RLHFDataset):
|
|||||||
multi_modal_data = {}
|
multi_modal_data = {}
|
||||||
|
|
||||||
images = None
|
images = None
|
||||||
if self.image_key in row_dict and row_dict.get(self.image_key, None) is not None:
|
row_dict_images = row_dict.pop(self.image_key, None)
|
||||||
images = [Image.open(io.BytesIO(image["bytes"])) for image in row_dict.pop(self.image_key)]
|
if row_dict_images:
|
||||||
|
images = [Image.open(io.BytesIO(image["bytes"])) for image in row_dict_images]
|
||||||
|
|
||||||
# due to the image key is "image" instead of "images" in vllm, we need to use "image" here
|
# due to the image key is "image" instead of "images" in vllm, we need to use "image" here
|
||||||
# link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205 # noqa: E501
|
# link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205 # noqa: E501
|
||||||
|
@ -157,8 +157,16 @@ class RLHFDataset(Dataset):
|
|||||||
raw_prompt = self.processor.apply_chat_template(
|
raw_prompt = self.processor.apply_chat_template(
|
||||||
messages, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs
|
messages, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs
|
||||||
)
|
)
|
||||||
images = [process_image(image) for image in doc[image_key]] if image_key in doc else None
|
images = (
|
||||||
videos = [process_video(video) for video in doc[video_key]] if video_key in doc else None
|
[process_image(image) for image in doc[image_key]]
|
||||||
|
if image_key in doc and doc[image_key]
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
videos = (
|
||||||
|
[process_video(video) for video in doc[video_key]]
|
||||||
|
if video_key in doc and doc[video_key]
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
return len(processor(text=[raw_prompt], images=images, videos=videos)["input_ids"][0])
|
return len(processor(text=[raw_prompt], images=images, videos=videos)["input_ids"][0])
|
||||||
|
|
||||||
@ -230,16 +238,18 @@ class RLHFDataset(Dataset):
|
|||||||
multi_modal_data = {}
|
multi_modal_data = {}
|
||||||
|
|
||||||
images = None
|
images = None
|
||||||
if self.image_key in row_dict and row_dict.get(self.image_key, None) is not None:
|
row_dict_images = row_dict.pop(self.image_key, None)
|
||||||
images = [process_image(image) for image in row_dict.pop(self.image_key)]
|
if row_dict_images:
|
||||||
|
images = [process_image(image) for image in row_dict_images]
|
||||||
|
|
||||||
# due to the image key is "image" instead of "images" in vllm, we need to use "image" here
|
# due to the image key is "image" instead of "images" in vllm, we need to use "image" here
|
||||||
# link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205
|
# link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205
|
||||||
multi_modal_data["image"] = images
|
multi_modal_data["image"] = images
|
||||||
|
|
||||||
videos = None
|
videos = None
|
||||||
if self.video_key in row_dict and row_dict.get(self.video_key, None) is not None:
|
row_dict_videos = row_dict.pop(self.video_key, None)
|
||||||
videos = [process_video(video) for video in row_dict.pop(self.video_key)]
|
if row_dict_videos:
|
||||||
|
videos = [process_video(video) for video in row_dict_videos]
|
||||||
|
|
||||||
# due to the video key is "video" instead of "videos" in vllm, we need to use "video" here
|
# due to the video key is "video" instead of "videos" in vllm, we need to use "video" here
|
||||||
# link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205
|
# link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205
|
||||||
|
Reference in New Issue
Block a user