# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import itertools
import textwrap
import unittest
from time import strftime

from datasets import Dataset, DatasetDict
from parameterized import parameterized
from transformers import AutoProcessor, AutoTokenizer

from trl.data_utils import (
    apply_chat_template,
    extract_prompt,
    is_conversational,
    is_conversational_from_value,
    maybe_apply_chat_template,
    maybe_convert_to_chatml,
    maybe_extract_prompt,
    maybe_unpair_preference_dataset,
    pack_dataset,
    prepare_multimodal_messages,
    truncate_dataset,
    unpair_preference_dataset,
)

from .testing_utils import TrlTestCase


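# The tests below pin the in-place behavior of `prepare_multimodal_messages`:
# plain-string content is wrapped into a list of typed parts, and the first
# user turn receives one {"type": "image"} placeholder per image. A minimal
# sketch of the call pattern (illustrative input, mirroring the assertions
# below):
#
#     messages = [{"role": "user", "content": "Describe the image."}]
#     prepare_multimodal_messages(messages, num_images=2)
#     # messages[0]["content"] == [{"type": "image"}, {"type": "image"},
#     #                            {"type": "text", "text": "Describe the image."}]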
class PrepareMultimodalMessagesTester(unittest.TestCase):
    def test_basic_user_assistant_conversation(self):
        """Test basic conversation with user and assistant messages."""
        messages = [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
        ]

        prepare_multimodal_messages(messages, num_images=1)

        expected = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
        ]

        self.assertEqual(messages, expected)

    def test_first_user_message_gets_image(self):
        """Test that only the first user message gets an image placeholder."""
        messages = [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
            {"role": "user", "content": "How about the grass?"},
        ]

        prepare_multimodal_messages(messages, num_images=1)

        expected = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
            {"role": "user", "content": [{"type": "text", "text": "How about the grass?"}]},
        ]

        self.assertEqual(messages, expected)

    def test_multiple_images(self):
        """Test that multiple images are added to the first user message."""
        messages = [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
        ]

        prepare_multimodal_messages(messages, num_images=3)

        expected = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "image"},
                    {"type": "image"},
                    {"type": "text", "text": "What color is the sky?"},
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
        ]

        self.assertEqual(messages, expected)

    def test_system_message_transformation(self):
        """Test that system messages are properly transformed."""
        messages = [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "What color is the sky?"},
        ]

        prepare_multimodal_messages(messages, num_images=1)

        expected = [
            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant"}]},
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
        ]

        self.assertEqual(messages, expected)

    def test_already_prepared_messages_unchanged(self):
        """Test that messages with list content are not modified."""
        messages = [
            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant"}]},
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
        ]

        original = copy.deepcopy(messages)
        prepare_multimodal_messages(messages, num_images=1)

        self.assertEqual(messages, original)

    def test_mixed_prepared_and_unprepared_messages(self):
        """Test handling of mixed prepared and unprepared messages."""
        messages = [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
            {"role": "user", "content": "What about the grass?"},
        ]

        prepare_multimodal_messages(messages, num_images=1)

        expected = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
            {"role": "user", "content": [{"type": "text", "text": "What about the grass?"}]},
        ]

        self.assertEqual(messages, expected)


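# `is_conversational` should return True whenever the example's "messages",
# "prompt", "completion", "chosen", or "rejected" fields hold lists of
# {"role": ..., "content": ...} dicts (optionally with a "thinking" field in
# harmony-style examples), and False when those fields are plain strings.
# A restatement in miniature of the cases below:
#
#     is_conversational({"prompt": [{"role": "user", "content": "Hi"}]})  # True
#     is_conversational({"prompt": "The sky is"})  # False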
class IsConversationalTester(TrlTestCase):
    conversational_examples = [
        {  # Language modeling
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
        },
        {  # Prompt-only
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
        },
        {  # Prompt-completion
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
        },
        {  # Preference
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "chosen": [{"role": "assistant", "content": "It is blue."}],
            "rejected": [{"role": "assistant", "content": "It is green."}],
        },
        {  # Preference with implicit prompt
            "chosen": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is green."},
            ],
        },
        {  # Unpaired preference
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
            "label": True,
        },
        {  # Language modeling with harmony
            "messages": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
        },
        {  # Prompt-only with harmony
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
        },
        {  # Prompt-completion with harmony
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
            "completion": [
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
        },
        {  # Preference with harmony
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
            "chosen": [
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "assistant", "thinking": "The user asks the color of the tree...", "content": "It is green."},
            ],
        },
        {  # Preference with implicit prompt and harmony
            "chosen": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "thinking": "The user asks the color of the tree...", "content": "It is green."},
            ],
        },
        {  # Unpaired preference with harmony
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
            "completion": [
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
            "label": True,
        },
    ]

    non_conversational_examples = [
        {"prompt": "The sky is", "completion": " blue."},
        {"text": "The sky is blue."},
        {"prompt": "The sky is"},
        {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."},
        {"prompt": "The sky is", "completion": " blue.", "label": True},
    ]

    @parameterized.expand(itertools.product(conversational_examples))
    def test_conversational(self, example):
        self.assertTrue(is_conversational(example))

    @parameterized.expand(itertools.product(non_conversational_examples))
    def test_non_conversational(self, example):
        self.assertFalse(is_conversational(example))


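# `is_conversational_from_value` targets the legacy turn layout where messages
# are stored as {"from": ..., "value": ...} dicts (under a "conversations" key
# in the positive case below); ChatML-style {"role": ..., "content": ...}
# examples and plain text should both return False.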
class IsConversationalFromValueTester(TrlTestCase):
    def test_positive_1(self):
        example = {
            "conversations": [
                {"from": "user", "value": "What color is the sky?"},
                {"from": "assistant", "value": "It is blue."},
            ],
        }
        self.assertTrue(is_conversational_from_value(example))

    def test_negative_1(self):
        example = {
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
        }
        self.assertFalse(is_conversational_from_value(example))

    def test_negative_2(self):
        example = {"text": "The sky is blue."}
        self.assertFalse(is_conversational_from_value(example))


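# `apply_chat_template` renders each conversational field to a string: the
# "prompt"/"chosen"/"rejected"/"completion" keys keep their names, a
# "messages" key comes back as "text", and a boolean "label" passes through
# unchanged. A minimal sketch using one of the tiny checkpoints listed below:
#
#     tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")
#     out = apply_chat_template({"prompt": [{"role": "user", "content": "Hi"}]}, tokenizer)
#     assert isinstance(out["prompt"], str)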
class ApplyChatTemplateTester(TrlTestCase):
    tokenizers = [
        "trl-internal-testing/tiny-CohereForCausalLM",
        # "trl-internal-testing/tiny-DbrxForCausalLM",
        "trl-internal-testing/tiny-DeepseekV3ForCausalLM",
        "trl-internal-testing/tiny-DeepseekV3ForCausalLM-0528",
        "trl-internal-testing/tiny-FalconMambaForCausalLM",
        "trl-internal-testing/tiny-Gemma2ForCausalLM",
        "trl-internal-testing/tiny-GemmaForCausalLM",
        "trl-internal-testing/tiny-GptOssForCausalLM",
        "trl-internal-testing/tiny-LlamaForCausalLM-3.1",
        "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
        "trl-internal-testing/tiny-LlamaForCausalLM-3",
        "trl-internal-testing/tiny-MistralForCausalLM-0.1",
        "trl-internal-testing/tiny-MistralForCausalLM-0.2",
        "trl-internal-testing/tiny-Phi3ForCausalLM",
        "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        "trl-internal-testing/tiny-Qwen3ForCausalLM",
    ]

    conversational_examples = [
        {  # Language modeling
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
        },
        {  # Prompt-only
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
        },
        {  # Prompt-completion
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
        },
        {  # Preference
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "chosen": [{"role": "assistant", "content": "It is blue."}],
            "rejected": [{"role": "assistant", "content": "It is green."}],
        },
        {  # Preference with implicit prompt
            "chosen": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is green."},
            ],
        },
        {  # Unpaired preference
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
            "label": True,
        },
    ]

    non_conversational_examples = [
        {"text": "The sky is blue."},  # Language modeling
        {"prompt": "The sky is"},  # Prompt-only
        {"prompt": "The sky is", "completion": " blue."},  # Prompt-completion
        {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."},  # Preference
        {"chosen": "The sky is blue.", "rejected": "The sky is green."},  # Preference with implicit prompt
        {"prompt": "The sky is", "completion": " blue.", "label": True},  # Unpaired preference
    ]

    @parameterized.expand(itertools.product(tokenizers, conversational_examples))
    def test_apply_chat_template(self, tokenizer_id, example):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        result = apply_chat_template(example, tokenizer)

        # The result should be a dictionary
        self.assertIsInstance(result, dict)

        # The chat template should be applied to the following keys
        for key in ["prompt", "chosen", "rejected", "completion"]:
            if key in example:
                self.assertIn(key, result)
                self.assertIsInstance(result[key], str)

        # Exception: for "messages", the key becomes "text" once the chat template is applied
        if "messages" in example:
            self.assertIn("text", result)
            self.assertIsInstance(result["text"], str)

        # The label should be kept
        if "label" in example:
            self.assertIn("label", result)
            self.assertIsInstance(result["label"], bool)
            self.assertEqual(result["label"], example["label"])

    # Test with both conversational and non-conversational examples
    @parameterized.expand(itertools.product(tokenizers, conversational_examples + non_conversational_examples))
    def test_maybe_apply_chat_template(self, tokenizer_id, example):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        result = maybe_apply_chat_template(example, tokenizer)

        # The result should be a dictionary
        self.assertIsInstance(result, dict)

        # The chat template should be applied to the following keys
        for key in ["prompt", "chosen", "rejected", "completion"]:
            if key in example:
                self.assertIn(key, result)
                self.assertIsInstance(result[key], str)

        # Exception: for "messages", the key becomes "text" once the chat template is applied
        if "messages" in example:
            self.assertIn("text", result)
            self.assertIsInstance(result["text"], str)

        # The label should be kept
        if "label" in example:
            self.assertIn("label", result)
            self.assertIsInstance(result["label"], bool)
            self.assertEqual(result["label"], example["label"])

    def test_apply_chat_template_with_tools(self):
        tokenizer = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2")

        # Define a dummy test tool
        def get_current_temperature(location: str):
            """
            Gets the temperature at a given location.

            Args:
                location: The location to get the temperature for
            """
            return 22.0

        # Define test case
        test_case = {
            "prompt": [
                {"content": "Whats the temperature in London?", "role": "user"},
            ]
        }
        # Test with tools
        result_with_tools = apply_chat_template(test_case, tokenizer, tools=[get_current_temperature])

        # Verify tools are included in the output
        self.assertIn("get_current_temperature", result_with_tools["prompt"])

        # Test without tools
        result_without_tools = apply_chat_template(test_case, tokenizer, tools=None)

        # Verify tools are not included in the output
        self.assertNotIn("get_current_temperature", result_without_tools["prompt"])


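# The GPT-OSS ("harmony") template is exercised below with extra kwargs that
# `apply_chat_template` passes through to the chat template (as the expected
# strings confirm): `model_identity` becomes the first line of the system
# header, `reasoning_effort` sets its "Reasoning:" line, and an assistant
# turn's "thinking" field is rendered on the analysis channel ahead of the
# final channel. The expected strings pin the exact rendering, including the
# date stamped via `strftime`.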
class ApplyChatTemplateHarmonyTester(TrlTestCase):
    def test_language_modeling(self):
        messages = {
            "messages": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
        }
        output = apply_chat_template(
            messages,
            tokenizer=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
            reasoning_effort="low",
            model_identity="You are HuggingGPT.",
        )

        # docstyle-ignore
        expected = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}

        Reasoning: low

        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions

        Respond in a friendly manner.

        <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>""")

        self.assertEqual(output["text"], expected)

    def test_prompt_only(self):
        messages = {
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
        }
        output = apply_chat_template(
            messages,
            tokenizer=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
            reasoning_effort="low",
            model_identity="You are HuggingGPT.",
        )

        # docstyle-ignore
        expected = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}

        Reasoning: low

        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions

        Respond in a friendly manner.

        <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""")

        self.assertEqual(output["prompt"], expected)

    def test_prompt_completion(self):
        messages = {
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
            "completion": [
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
        }
        output = apply_chat_template(
            messages,
            tokenizer=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
            reasoning_effort="low",
            model_identity="You are HuggingGPT.",
        )

        # docstyle-ignore
        expected_prompt = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}

        Reasoning: low

        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions

        Respond in a friendly manner.

        <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""")
        expected_completion = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>"

        self.assertEqual(output["prompt"], expected_prompt)
        self.assertEqual(output["completion"], expected_completion)

    def test_preference(self):
        messages = {
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
            "chosen": [
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "assistant", "thinking": "The user asks the color of the tree...", "content": "It is green."},
            ],
        }
        output = apply_chat_template(
            messages,
            tokenizer=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
            reasoning_effort="low",
            model_identity="You are HuggingGPT.",
        )

        # docstyle-ignore
        expected_prompt = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}

        Reasoning: low

        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions

        Respond in a friendly manner.

        <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""")
        expected_chosen = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>"
        expected_rejected = "<|channel|>analysis<|message|>The user asks the color of the tree...<|end|><|start|>assistant<|channel|>final<|message|>It is green.<|return|>"

        self.assertEqual(output["prompt"], expected_prompt)
        self.assertEqual(output["chosen"], expected_chosen)
        self.assertEqual(output["rejected"], expected_rejected)

    def test_preference_with_implicit_prompt(self):
        messages = {
            "chosen": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "thinking": "The user asks the color of the tree...", "content": "It is green."},
            ],
        }
        output = apply_chat_template(
            messages,
            tokenizer=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
            reasoning_effort="low",
            model_identity="You are HuggingGPT.",
        )

        # docstyle-ignore
        expected_chosen = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}

        Reasoning: low

        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions

        Respond in a friendly manner.

        <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>""")

        # docstyle-ignore
        expected_rejected = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}

        Reasoning: low

        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions

        Respond in a friendly manner.

        <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant<|channel|>analysis<|message|>The user asks the color of the tree...<|end|><|start|>assistant<|channel|>final<|message|>It is green.<|return|>""")

        self.assertEqual(output["chosen"], expected_chosen)
        self.assertEqual(output["rejected"], expected_rejected)

    def test_unpaired_preference(self):
        messages = {
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
            "completion": [
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
            "label": True,
        }
        output = apply_chat_template(
            messages,
            tokenizer=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
            reasoning_effort="low",
            model_identity="You are HuggingGPT.",
        )

        # docstyle-ignore
        expected_prompt = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}

        Reasoning: low

        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions

        Respond in a friendly manner.

        <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""")
        expected_completion = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>"

        self.assertEqual(output["prompt"], expected_prompt)
        self.assertEqual(output["completion"], expected_completion)
        self.assertTrue(output["label"])


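# Unpairing maps each (prompt, chosen, rejected) row to two
# (prompt, completion, label) rows: the chosen completion is labeled True and
# the rejected one False, so the 2-row paired fixture below becomes the 4-row
# unpaired fixture.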
class UnpairPreferenceDatasetTester(TrlTestCase):
    paired_dataset = Dataset.from_dict(
        {
            "prompt": ["The sky is", "The sun is"],
            "chosen": [" blue.", " in the sky."],
            "rejected": [" green.", " in the sea."],
        }
    )

    unpaired_dataset = Dataset.from_dict(
        {
            "prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
            "completion": [" blue.", " in the sky.", " green.", " in the sea."],
            "label": [True, True, False, False],
        }
    )

    def test_unpair_preference_dataset(self):
        # Test that a paired dataset is correctly converted to unpaired
        unpaired_dataset = unpair_preference_dataset(self.paired_dataset)
        self.assertEqual(
            unpaired_dataset.to_dict(),
            self.unpaired_dataset.to_dict(),
            "The paired dataset should be converted to unpaired.",
        )

    def test_unpair_preference_dataset_dict(self):
        # Test that a paired dataset dict is correctly converted to unpaired
        paired_dataset_dict = DatasetDict({"abc": self.paired_dataset})
        unpaired_dataset_dict = unpair_preference_dataset(paired_dataset_dict)
        self.assertEqual(
            unpaired_dataset_dict["abc"].to_dict(),
            self.unpaired_dataset.to_dict(),
            "The paired dataset should be converted to unpaired.",
        )

    def test_maybe_unpair_preference_dataset(self):
        # Test that a paired dataset is correctly converted to unpaired with maybe_unpair_preference_dataset
        unpaired_dataset = maybe_unpair_preference_dataset(self.paired_dataset)
        self.assertEqual(
            unpaired_dataset.to_dict(),
            self.unpaired_dataset.to_dict(),
            "The paired dataset should be converted to unpaired.",
        )

    def test_maybe_unpair_preference_dataset_dict(self):
        # Test that a paired dataset dict is correctly converted to unpaired with maybe_unpair_preference_dataset
        paired_dataset_dict = DatasetDict({"abc": self.paired_dataset})
        unpaired_dataset_dict = maybe_unpair_preference_dataset(paired_dataset_dict)
        self.assertEqual(
            unpaired_dataset_dict["abc"].to_dict(),
            self.unpaired_dataset.to_dict(),
            "The paired dataset should be converted to unpaired.",
        )

    def test_maybe_unpair_preference_dataset_already_unpaired(self):
        # Test that an already-unpaired dataset remains unchanged with maybe_unpair_preference_dataset
        unpaired_dataset = maybe_unpair_preference_dataset(self.unpaired_dataset)
        self.assertEqual(
            unpaired_dataset.to_dict(),
            self.unpaired_dataset.to_dict(),
            "The unpaired dataset should remain unchanged.",
        )

    def test_maybe_unpair_preference_dataset_dict_already_unpaired(self):
        # Test that an already-unpaired dataset dict remains unchanged with maybe_unpair_preference_dataset
        unpaired_dataset_dict = maybe_unpair_preference_dataset(DatasetDict({"abc": self.unpaired_dataset}))
        self.assertEqual(
            unpaired_dataset_dict["abc"].to_dict(),
            self.unpaired_dataset.to_dict(),
            "The unpaired dataset should remain unchanged.",
        )


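# `extract_prompt` derives an explicit "prompt" from implicit-prompt
# preference pairs by splitting off the shared prefix of "chosen" and
# "rejected": the common leading messages in conversational data, the common
# leading text otherwise. The standard-format fixtures below reduce to:
#
#     extract_prompt({"chosen": "The sky is blue.", "rejected": "The sky is green."})
#     # -> {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}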
class ExtractPromptTester(TrlTestCase):
    example_implicit_prompt_conversational = {
        "chosen": [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
        ],
        "rejected": [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is green."},
        ],
    }

    example_explicit_prompt_conversational = {
        "prompt": [
            {"role": "user", "content": "What color is the sky?"},
        ],
        "chosen": [
            {"role": "assistant", "content": "It is blue."},
        ],
        "rejected": [
            {"role": "assistant", "content": "It is green."},
        ],
    }

    example_implicit_prompt_standard = {
        "chosen": "The sky is blue.",
        "rejected": "The sky is green.",
    }

    example_explicit_prompt_standard = {
        "prompt": "The sky is",
        "chosen": " blue.",
        "rejected": " green.",
    }

    def test_extract_prompt_conversational(self):
        # Test that the prompt is correctly extracted from the dataset
        example_extracted_prompt = extract_prompt(self.example_implicit_prompt_conversational)
        self.assertEqual(
            example_extracted_prompt,
            self.example_explicit_prompt_conversational,
            "The prompt is not correctly extracted from the dataset.",
        )

    def test_maybe_extract_prompt_conversational(self):
        # Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt
        example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_conversational)
        self.assertEqual(
            example_extracted_prompt,
            self.example_explicit_prompt_conversational,
            "The prompt is not correctly extracted from the dataset.",
        )

    def test_maybe_extract_prompt_conversational_already_explicit(self):
        # Test that the prompt remains unchanged with maybe_extract_prompt
        example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_conversational)
        self.assertEqual(
            example_extracted_prompt,
            self.example_explicit_prompt_conversational,
            "The prompt should remain unchanged.",
        )

    def test_extract_prompt_standard(self):
        # Test that the prompt is correctly extracted from the dataset
        example_extracted_prompt = extract_prompt(self.example_implicit_prompt_standard)
        self.assertEqual(
            example_extracted_prompt,
            self.example_explicit_prompt_standard,
            "The prompt is not correctly extracted from the dataset.",
        )

    def test_maybe_extract_prompt_standard(self):
        # Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt
        example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_standard)
        self.assertEqual(
            example_extracted_prompt,
            self.example_explicit_prompt_standard,
            "The prompt is not correctly extracted from the dataset.",
        )

    def test_maybe_extract_prompt_standard_already_explicit(self):
        # Test that the prompt remains unchanged with maybe_extract_prompt
        example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_standard)
        self.assertEqual(
            example_extracted_prompt,
            self.example_explicit_prompt_standard,
            "The prompt should remain unchanged.",
        )


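# With strategy="wrapped", `pack_dataset` concatenates all sequences and
# re-chops the stream into fixed-size chunks, so document boundaries are not
# preserved: at seq_length=3, [1, 2, 3] + [4, 5, 6, 7] + [8] becomes
# [1, 2, 3], [4, 5, 6], [7, 8], exactly as asserted below.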
class TestPackDatasetWrapped(TrlTestCase):
    def test_with_dataset(self):
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        dataset = Dataset.from_dict(examples)
        seq_length = 3
        expected_output = {
            "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
        self.assertEqual(dataset.to_dict(), expected_output)

    def test_with_iterable_dataset(self):
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        dataset = Dataset.from_dict(examples).to_iterable_dataset()
        seq_length = 3
        expected_output = {
            "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
        num_examples = len(examples[next(iter(examples))])
        self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output)


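# strategy="bfd" instead bin-packs whole sequences (the name reads as
# best-fit decreasing; the tests only pin the behavior): each sequence lands
# in the fullest bin it still fits in, anything longer than seq_length is
# truncated first, and a "seq_lengths" column records per-bin document
# boundaries. E.g. at seq_length=4 below, [4, 5, 6, 7] fills one bin while
# [1, 2, 3] and [8] share another, giving seq_lengths [[4], [3, 1]].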
class TestPackDatasetBfd(TrlTestCase):
    def test_simple(self):
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        dataset = Dataset.from_dict(examples)
        seq_length = 4
        expected_output = {
            "input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
            "attention_mask": [[0, 0, 1, 1], [0, 1, 1, 1]],
            "seq_lengths": [[4], [3, 1]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd")
        self.assertEqual(dataset.to_dict(), expected_output)

    def test_with_iterable_dataset(self):
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        dataset = Dataset.from_dict(examples).to_iterable_dataset()
        seq_length = 4
        expected_output = {
            "input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
            "attention_mask": [[0, 0, 1, 1], [0, 1, 1, 1]],
            "seq_lengths": [[4], [3, 1]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd")
        num_examples = len(examples[next(iter(examples))])
        self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output)

    def test_with_truncation(self):
        examples = {
            "input_ids": [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10, 11], [12]],
            "attention_mask": [[1, 1, 1, 1, 1], [1, 1], [1, 1, 1, 1], [1]],
        }
        dataset = Dataset.from_dict(examples)
        seq_length = 4
        expected_output = {
            "input_ids": [[1, 2, 3, 4], [8, 9, 10, 11], [6, 7, 12]],
            "attention_mask": [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]],
            "seq_lengths": [[4], [4], [2, 1]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd")
        self.assertEqual(dataset.to_dict(), expected_output)

    def test_with_non_power_of_2(self):
        examples = {
            "input_ids": [[1, 2, 3, 4, 5], [6], [7, 8, 9, 10], [11, 12, 13]],
            "attention_mask": [[1, 0, 0, 1, 1], [0], [0, 1, 0, 0], [1, 0, 1]],
        }
        dataset = Dataset.from_dict(examples)
        seq_length = 5
        expected_output = {
            "input_ids": [[1, 2, 3, 4, 5], [7, 8, 9, 10, 6], [11, 12, 13]],
            "attention_mask": [[1, 0, 0, 1, 1], [0, 1, 0, 0, 0], [1, 0, 1]],
            "seq_lengths": [[5], [4, 1], [3]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd")
        self.assertEqual(dataset.to_dict(), expected_output)


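# `truncate_dataset` clips every list-valued column to `max_length` and leaves
# other columns (like "my_column" below) untouched. A minimal sketch:
#
#     truncate_dataset(Dataset.from_dict({"input_ids": [[1, 2, 3]]}), 2).to_dict()
#     # -> {"input_ids": [[1, 2]]}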
class TestTruncateExamples(TrlTestCase):
    def test_with_dataset(self):
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        dataset = Dataset.from_dict(examples)
        max_length = 2
        expected_output = {
            "input_ids": [[1, 2], [4, 5], [8]],
            "attention_mask": [[0, 1], [0, 0], [1]],
        }
        dataset = truncate_dataset(dataset, max_length)
        self.assertEqual(dataset.to_dict(), expected_output)

    def test_with_iterable_dataset(self):
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        dataset = Dataset.from_dict(examples).to_iterable_dataset()
        max_length = 2
        expected_output = {
            "input_ids": [[1, 2], [4, 5], [8]],
            "attention_mask": [[0, 1], [0, 0], [1]],
        }
        dataset = truncate_dataset(dataset, max_length)
        num_examples = len(examples[next(iter(examples))])
        self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output)

    def test_with_extra_column(self):
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
            "my_column": ["a", "b", "c"],
        }
        dataset = Dataset.from_dict(examples)
        max_length = 2
        expected_output = {
            "input_ids": [[1, 2], [4, 5], [8]],
            "attention_mask": [[0, 1], [0, 0], [1]],
            "my_column": ["a", "b", "c"],
        }
        dataset = truncate_dataset(dataset, max_length)
        self.assertEqual(dataset.to_dict(), expected_output)


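# `maybe_convert_to_chatml` rewrites legacy examples into ChatML form: a
# "conversations" key is renamed to "messages", and each turn's "from"/"value"
# keys become "role"/"content". Examples already in ChatML, or plain-text
# ones, pass through unchanged.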
class TestMaybeConvertToChatML(TrlTestCase):
    def test_with_conversations_key(self):
        # Particular case where the key is "conversations": we rename it to "messages"
        example = {
            "conversations": [
                {"from": "user", "value": "What color is the sky?"},
                {"from": "assistant", "value": "It is blue."},
            ]
        }
        expected_output = {
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ]
        }
        self.assertEqual(maybe_convert_to_chatml(example), expected_output)

    def test_without_conversations_key(self):
        # Same as before, but we don't rename the keys
        example = {
            "prompt": [{"from": "user", "value": "What color is the sky?"}],
            "completion": [{"from": "assistant", "value": "It is blue."}],
        }
        expected_output = {
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
        }
        self.assertEqual(maybe_convert_to_chatml(example), expected_output)

    def test_not_conversational(self):
        # When not needed, the example should remain unchanged
        example = {"text": "The sky is blue."}
        self.assertEqual(maybe_convert_to_chatml(example), example)

    def test_already_chatml(self):
        # When the example is already in ChatML format, it should remain unchanged
        example = {
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ]
        }
        self.assertEqual(maybe_convert_to_chatml(example), example)


# Run the tests
if __name__ == "__main__":
    unittest.main()