# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import unittest

from datasets import Dataset, DatasetDict
from parameterized import parameterized
from transformers import AutoTokenizer

from trl.data_utils import (
    apply_chat_template,
    extract_prompt,
    is_conversational,
    maybe_apply_chat_template,
    maybe_extract_prompt,
    maybe_unpair_preference_dataset,
    unpair_preference_dataset,
)
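
# These tests cover the dataset-format helpers from trl.data_utils: detecting
# conversational examples, applying tokenizer chat templates, converting paired
# preference datasets to unpaired ones, and extracting implicit prompts.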


class IsConversationalTester(unittest.TestCase):
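    """Test is_conversational on every supported example format.

    The conversational fixtures cover language modeling, prompt-only, prompt-completion,
    preference (with explicit or implicit prompt), and unpaired preference examples;
    the non-conversational fixtures use plain strings for the same fields.
    """
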
    conversational_examples = [
        {  # Language modeling
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
        },
        {  # Prompt only
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
        },
        {  # Prompt-completion
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
        },
        {  # Preference
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "chosen": [{"role": "assistant", "content": "It is blue."}],
            "rejected": [{"role": "assistant", "content": "It is green."}],
        },
        {  # Preference with implicit prompt
            "chosen": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is green."},
            ],
        },
        {  # Unpaired preference
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
            "label": True,
        },
    ]

    non_conversational_examples = [
        {"prompt": "The sky is", "completion": " blue."},
        {"text": "The sky is blue."},
        {"prompt": "The sky is"},
        {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."},
        {"prompt": "The sky is", "completion": " blue.", "label": True},
    ]
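
    # For reference: a conversational example stores each turn as a
    # {"role": ..., "content": ...} dict, so for instance
    # is_conversational({"prompt": [{"role": "user", "content": "Hi"}]}) is expected to
    # return True, while plain-string fields like {"prompt": "The sky is"} are not
    # treated as conversational.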

    @parameterized.expand(itertools.product(conversational_examples))
    def test_conversational(self, example):
        self.assertTrue(is_conversational(example))

    @parameterized.expand(itertools.product(non_conversational_examples))
    def test_non_conversational(self, example):
        self.assertFalse(is_conversational(example))


class ApplyChatTemplateTester(unittest.TestCase):
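    """Test apply_chat_template and maybe_apply_chat_template with several tiny tokenizers.

    For every example format, the templated fields should come back as strings and
    extra fields such as "label" should be preserved unchanged.
    """
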
    tokenizers = [
        "trl-internal-testing/tiny-CohereForCausalLM",
        "trl-internal-testing/tiny-DbrxForCausalLM",
        "trl-internal-testing/tiny-FalconMambaForCausalLM",
        "trl-internal-testing/tiny-Gemma2ForCausalLM",
        "trl-internal-testing/tiny-GemmaForCausalLM",
        "trl-internal-testing/tiny-LlamaForCausalLM-3.1",
        "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
        "trl-internal-testing/tiny-LlamaForCausalLM-3",
        "trl-internal-testing/tiny-MistralForCausalLM-0.1",
        "trl-internal-testing/tiny-MistralForCausalLM-0.2",
        "trl-internal-testing/tiny-Phi3ForCausalLM",
        "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    ]
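
    # These are tiny test checkpoints from the trl-internal-testing organization; for
    # these tests only the tokenizer's chat template matters, not model quality.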

    conversational_examples = [
        {  # Language modeling
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
        },
        {  # Prompt only
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
        },
        {  # Prompt-completion
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
        },
        {  # Preference
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "chosen": [{"role": "assistant", "content": "It is blue."}],
            "rejected": [{"role": "assistant", "content": "It is green."}],
        },
        {  # Preference with implicit prompt
            "chosen": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is green."},
            ],
        },
        {  # Unpaired preference
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
            "label": True,
        },
    ]

    non_conversational_examples = [
        {"prompt": "The sky is", "completion": " blue."},
        {"text": "The sky is blue."},
        {"prompt": "The sky is"},
        {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."},
        {"chosen": "The sky is blue.", "rejected": "The sky is green."},
        {"prompt": "The sky is", "completion": " blue.", "label": True},
    ]

    @parameterized.expand(itertools.product(tokenizers, conversational_examples))
    def test_apply_chat_template(self, tokenizer_id, example):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        result = apply_chat_template(example, tokenizer)

        # Check that the result is a dictionary
        self.assertIsInstance(result, dict)

        # The chat template should be applied to the following keys
        for key in ["prompt", "chosen", "rejected", "completion"]:
            if key in example:
                self.assertIn(key, result)
                self.assertIsInstance(result[key], str)

        # Exception for "messages": once the chat template is applied, the key becomes "text"
        if "messages" in example:
            self.assertIn("text", result)
            self.assertIsInstance(result["text"], str)

        # The label should be kept
        if "label" in example:
            self.assertIn("label", result)
            self.assertIsInstance(result["label"], bool)
            self.assertEqual(result["label"], example["label"])

    # Test with both conversational and non-conversational examples
    @parameterized.expand(itertools.product(tokenizers, conversational_examples + non_conversational_examples))
    def test_maybe_apply_chat_template(self, tokenizer_id, example):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        result = maybe_apply_chat_template(example, tokenizer)

        # Check that the result is a dictionary
        self.assertIsInstance(result, dict)

        # The chat template should be applied to the following keys
        for key in ["prompt", "chosen", "rejected", "completion"]:
            if key in example:
                self.assertIn(key, result)
                self.assertIsInstance(result[key], str)

        # Exception for "messages": once the chat template is applied, the key becomes "text"
        if "messages" in example:
            self.assertIn("text", result)
            self.assertIsInstance(result["text"], str)

        # The label should be kept
        if "label" in example:
            self.assertIn("label", result)
            self.assertIsInstance(result["label"], bool)
            self.assertEqual(result["label"], example["label"])


class UnpairPreferenceDatasetTester(unittest.TestCase):
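    """Test unpair_preference_dataset and maybe_unpair_preference_dataset on both
    Dataset and DatasetDict inputs."""
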
    paired_dataset = Dataset.from_dict(
        {
            "prompt": ["The sky is", "The sun is"],
            "chosen": [" blue.", " in the sky."],
            "rejected": [" green.", " in the sea."],
        }
    )

    unpaired_dataset = Dataset.from_dict(
        {
            "prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
            "completion": [" blue.", " in the sky.", " green.", " in the sea."],
            "label": [True, True, False, False],
        }
    )
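
    # The unpaired fixture is the expected result of unpairing the paired one: each
    # "chosen" completion becomes a row with label=True and each "rejected" completion
    # a row with label=False, so two paired rows yield four unpaired rows.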

    def test_unpair_preference_dataset(self):
        # Test that a paired dataset is correctly converted to unpaired
        unpaired_dataset = unpair_preference_dataset(self.paired_dataset)
        self.assertEqual(
            unpaired_dataset.to_dict(),
            self.unpaired_dataset.to_dict(),
            "The paired dataset should be converted to unpaired.",
        )

    def test_unpair_preference_dataset_dict(self):
        # Test that a paired dataset dict is correctly converted to unpaired
        paired_dataset_dict = DatasetDict({"abc": self.paired_dataset})
        unpaired_dataset_dict = unpair_preference_dataset(paired_dataset_dict)
        self.assertEqual(
            unpaired_dataset_dict["abc"].to_dict(),
            self.unpaired_dataset.to_dict(),
            "The paired dataset should be converted to unpaired.",
        )

    def test_maybe_unpair_preference_dataset(self):
        # Test that a paired dataset is correctly converted to unpaired with maybe_unpair_preference_dataset
        unpaired_dataset = maybe_unpair_preference_dataset(self.paired_dataset)
        self.assertEqual(
            unpaired_dataset.to_dict(),
            self.unpaired_dataset.to_dict(),
            "The paired dataset should be converted to unpaired.",
        )

    def test_maybe_unpair_preference_dataset_dict(self):
        # Test that a paired dataset dict is correctly converted to unpaired with maybe_unpair_preference_dataset
        paired_dataset_dict = DatasetDict({"abc": self.paired_dataset})
        unpaired_dataset_dict = maybe_unpair_preference_dataset(paired_dataset_dict)
        self.assertEqual(
            unpaired_dataset_dict["abc"].to_dict(),
            self.unpaired_dataset.to_dict(),
            "The paired dataset should be converted to unpaired.",
        )

    def test_maybe_unpair_preference_dataset_already_unpaired(self):
        # Test that an already unpaired dataset remains unchanged with maybe_unpair_preference_dataset
        unpaired_dataset = maybe_unpair_preference_dataset(self.unpaired_dataset)
        self.assertEqual(
            unpaired_dataset.to_dict(),
            self.unpaired_dataset.to_dict(),
            "The unpaired dataset should remain unchanged.",
        )

    def test_maybe_unpair_preference_dataset_dict_already_unpaired(self):
        # Test that an already unpaired dataset dict remains unchanged with maybe_unpair_preference_dataset
        unpaired_dataset_dict = maybe_unpair_preference_dataset(DatasetDict({"abc": self.unpaired_dataset}))
        self.assertEqual(
            unpaired_dataset_dict["abc"].to_dict(),
            self.unpaired_dataset.to_dict(),
            "The unpaired dataset should remain unchanged.",
        )


class ExtractPromptTester(unittest.TestCase):
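    """Test extract_prompt and maybe_extract_prompt on conversational and standard
    (plain-text) preference examples with implicit prompts."""
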
    example_implicit_prompt_conversational = {
        "chosen": [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
        ],
        "rejected": [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is green."},
        ],
    }

    example_explicit_prompt_conversational = {
        "prompt": [
            {"role": "user", "content": "What color is the sky?"},
        ],
        "chosen": [
            {"role": "assistant", "content": "It is blue."},
        ],
        "rejected": [
            {"role": "assistant", "content": "It is green."},
        ],
    }

    example_implicit_prompt_standard = {
        "chosen": "The sky is blue.",
        "rejected": "The sky is green.",
    }

    example_explicit_prompt_standard = {
        "prompt": "The sky is",
        "chosen": " blue.",
        "rejected": " green.",
    }
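
    # In both fixtures, extract_prompt is expected to move the part shared by "chosen"
    # and "rejected" (the user turn, or the leading text "The sky is") into a separate
    # "prompt" field, leaving only the differing completions.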

    def test_extract_prompt_conversational(self):
        # Test that the prompt is correctly extracted from the dataset
        example_extracted_prompt = extract_prompt(self.example_implicit_prompt_conversational)
        self.assertEqual(
            example_extracted_prompt,
            self.example_explicit_prompt_conversational,
            "The prompt is not correctly extracted from the dataset.",
        )

    def test_maybe_extract_prompt_conversational(self):
        # Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt
        example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_conversational)
        self.assertEqual(
            example_extracted_prompt,
            self.example_explicit_prompt_conversational,
            "The prompt is not correctly extracted from the dataset.",
        )

    def test_maybe_extract_prompt_conversational_already_explicit(self):
        # Test that the prompt remains unchanged with maybe_extract_prompt
        example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_conversational)
        self.assertEqual(
            example_extracted_prompt,
            self.example_explicit_prompt_conversational,
            "The prompt should remain unchanged.",
        )

    def test_extract_prompt_standard(self):
        # Test that the prompt is correctly extracted from the dataset
        example_extracted_prompt = extract_prompt(self.example_implicit_prompt_standard)
        self.assertEqual(
            example_extracted_prompt,
            self.example_explicit_prompt_standard,
            "The prompt is not correctly extracted from the dataset.",
        )

    def test_maybe_extract_prompt_standard(self):
        # Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt
        example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_standard)
        self.assertEqual(
            example_extracted_prompt,
            self.example_explicit_prompt_standard,
            "The prompt is not correctly extracted from the dataset.",
        )

    def test_maybe_extract_prompt_standard_already_explicit(self):
        # Test that the prompt remains unchanged with maybe_extract_prompt
        example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_standard)
        self.assertEqual(
            example_extracted_prompt,
            self.example_explicit_prompt_standard,
            "The prompt should remain unchanged.",
        )


# Run the tests
if __name__ == "__main__":
    unittest.main()