# Mirror of https://github.com/huggingface/trl.git (synced 2025-10-21 02:53:59 +08:00).
# 244 lines, 22 KiB, Python.
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
import numpy as np
|
|
from datasets import Dataset, Features, Image, List, Value
|
|
from transformers import HfArgumentParser
|
|
|
|
|
|
# Feature schema shared by every conversational column: a list of messages, each holding a `role` string and a
# `content` list whose parts carry a `type` string and a `text` string.
Message = List(
    {
        "content": List({"text": Value("string"), "type": Value("string")}),
        "role": Value("string"),
    }
)
|
|
|
|
|
|
@dataclass
class ScriptArguments:
    r"""
    Command-line arguments for the dataset generation script.

    Args:
        test_size (`float`, *optional*, defaults to `0.1`):
            Fraction of the dataset to include in the test split.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the dataset to the Hugging Face Hub.
        repo_id (`str`, *optional*, defaults to `"trl-internal-testing/zen-multi-image"`):
            Hugging Face repository ID to push the dataset to.
    """

    test_size: float = field(
        default=0.1,
        metadata={"help": "Fraction of the dataset to include in the test split."},
    )
    push_to_hub: bool = field(
        default=False,
        metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
    )
    repo_id: str = field(
        default="trl-internal-testing/zen-multi-image",
        metadata={"help": "Hugging Face repository ID to push the dataset to."},
    )
|
|
|
|
|
|
def _make_random_images(conversations):
    """Create one list of random RGB images per conversation, one image per `{"type": "image"}` content part.

    Only the first message of each conversation is inspected for image placeholders. Every image gets an independent
    height and width drawn from `[32, 64)` and uniformly random pixel values cast to `uint8`.
    """
    image_counts = [sum(1 for part in row[0]["content"] if part.get("type") == "image") for row in conversations]
    # All sizes are drawn before any pixel data, matching the order in which the random stream has always been used.
    sizes = [np.random.randint(32, 64, size=(count, 2)) for count in image_counts]
    return [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in s] for s in sizes]


def _build_config(columns, features, config_name, test_size, push_to_hub, repo_id):
    """Build a dataset from `columns`, split it into train/test without shuffling, and optionally push it."""
    dataset = Dataset.from_dict(columns, features=features)
    dataset = dataset.train_test_split(test_size=test_size, shuffle=False)
    if push_to_hub:
        dataset.push_to_hub(repo_id, config_name=config_name)


def main(test_size, push_to_hub, repo_id):
    r"""
    Build the multi-image "Zen of Python" conversational datasets and optionally push them to the Hub.

    Four configurations are created, in this order: `conversational_language_modeling`,
    `conversational_prompt_only`, `conversational_prompt_completion` and `conversational_preference`. Each one gets its
    own freshly generated random images and is split into train/test without shuffling.

    Args:
        test_size (`float`):
            Fraction of each dataset to include in the test split.
        push_to_hub (`bool`):
            Whether to push each dataset to the Hugging Face Hub.
        repo_id (`str`):
            Hugging Face repository ID to push the datasets to.
    """
    # fmt: off
    messages = [
        [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is better than ugly?"}]}, {"role": "assistant", "content": [{"type": "text", "text": "Beautiful."}]}],
        [{"role": "user", "content": [{"type": "text", "text": "What is better than implicit?"}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "Explicit."}]}],
        [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is better than complex?"}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "Simple."}]}],
        [{"role": "user", "content": [{"type": "image"}, {"type": "image"}, {"type": "text", "text": "What is better than complicated?"}]}, {"role": "assistant", "content": [{"type": "text", "text": "Complex."}]}],
        [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is better than nested?"}]}, {"role": "assistant", "content": [{"type": "text", "text": "Flat."}]}],
        [{"role": "user", "content": [{"type": "text", "text": "What is better than dense?"}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "Sparse."}]}],
        [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What counts?"}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "Readability."}]}],
        [{"role": "user", "content": [{"type": "text", "text": "Are special cases enough to break the rules?"}]}, {"role": "assistant", "content": [{"type": "text", "text": "No, special cases aren't special enough to break the rules."}]}],
        [{"role": "user", "content": [{"type": "text", "text": "What beats purity?"}, {"type": "image"}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "Practicality."}]}],
        [{"role": "user", "content": [{"type": "image"}, {"type": "image"}, {"type": "text", "text": "What should never pass silently?"}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "Errors."}]}],
        [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "When can errors pass silently?"}]}, {"role": "assistant", "content": [{"type": "text", "text": "When explicitly silenced."}]}],
        [{"role": "user", "content": [{"type": "text", "text": "What should you do in the face of ambiguity?"}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "Refuse the temptation to guess."}]}],
        [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "How many ways should there be to do it?"}, {"type": "image"}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "One, and preferably only one."}]}],
        [{"role": "user", "content": [{"type": "text", "text": "For whom may the way not be obvious at first?"}]}, {"role": "assistant", "content": [{"type": "text", "text": "Dutch."}]}],
        [{"role": "user", "content": [{"type": "text", "text": "What is better than never?"}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "Now is better than never."}]}],
        [{"role": "user", "content": [{"type": "text", "text": "Is"}, {"type": "image"}, {"type": "text", "text": " never better than *right* now?"}]}, {"role": "assistant", "content": [{"type": "text", "text": "Yes, often."}]}],
        [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What does it mean if the implementation is hard to explain?"}]}, {"role": "assistant", "content": [{"type": "text", "text": "It means it's a bad idea."}]}],
        [{"role": "user", "content": [{"type": "text", "text": "What does it mean if the implementation is easy to explain?"}, {"type": "image"}]}, {"role": "assistant", "content": [{"type": "text", "text": "It means it may be a good idea."}]}],
        [{"role": "user", "content": [{"type": "text", "text": "Any great ideas?"}]}, {"role": "assistant", "content": [{"type": "text", "text": "Namespaces are one honking great idea."}]}],
    ]
    # Dispreferred answers, aligned index-by-index with `messages`. There must be no trailing comma after the closing
    # bracket: it would turn this list into a 1-tuple and break the preference dataset.
    rejected = [
        [{"role": "assistant", "content": [{"type": "text", "text": "Acceptable."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "Explained."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "Very complex."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "Very complicated."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "Circular."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "Heavy."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "Looking complicated."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "Yes, special cases are special enough to break the rules."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "Nothing."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "Warnings."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "Never."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "Give up."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "As many as possible."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "French."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "Some day."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "No, never."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "It means it's a good idea."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "It means it's a bad idea."}]}],
        [{"role": "assistant", "content": [{"type": "text", "text": "Recursion."}]}],
    ]
    # fmt: on

    # The prompt and completion columns are exactly the user and assistant turns of `messages`, so they are derived
    # here instead of being written out again for every configuration.
    prompt = [[user_turn] for user_turn, _ in messages]
    completion = [[assistant_turn] for _, assistant_turn in messages]

    _build_config(
        {"messages": messages, "images": _make_random_images(messages)},
        Features(messages=Message, images=List(Image())),
        "conversational_language_modeling",
        test_size,
        push_to_hub,
        repo_id,
    )

    _build_config(
        {"prompt": prompt, "images": _make_random_images(prompt)},
        Features(prompt=Message, images=List(Image())),
        "conversational_prompt_only",
        test_size,
        push_to_hub,
        repo_id,
    )

    _build_config(
        {"prompt": prompt, "completion": completion, "images": _make_random_images(prompt)},
        Features(prompt=Message, completion=Message, images=List(Image())),
        "conversational_prompt_completion",
        test_size,
        push_to_hub,
        repo_id,
    )

    # The preferred ("chosen") answers are the same assistant turns used as completions.
    _build_config(
        {"prompt": prompt, "chosen": completion, "rejected": rejected, "images": _make_random_images(prompt)},
        Features(prompt=Message, chosen=Message, rejected=Message, images=List(Image())),
        "conversational_preference",
        test_size,
        push_to_hub,
        repo_id,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse the command line into a single `ScriptArguments` instance and forward its fields to `main`.
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]
    main(
        test_size=script_args.test_size,
        push_to_hub=script_args.push_to_hub,
        repo_id=script_args.repo_id,
    )
|