diff --git a/docs/source/clis.md b/docs/source/clis.md index 0938dec26..6972960f0 100644 --- a/docs/source/clis.md +++ b/docs/source/clis.md @@ -219,6 +219,49 @@ trl dpo --config dpo_config.yaml +### Using dataset mixtures + +You can use dataset mixtures to combine multiple datasets into a single training dataset. This is useful for training on diverse data sources or when you want to mix different types of data. + + + + +```yaml +# sft_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +datasets: + - path: stanfordnlp/imdb + - path: roneneldan/TinyStories +``` + +Launch with: + +```bash +trl sft --config sft_config.yaml +``` + + + + +```yaml +# dpo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +datasets: + - path: BAAI/Infinity-Preference + - path: argilla/Capybara-Preferences +``` + +Launch with: + +```bash +trl dpo --config dpo_config.yaml +``` + + + + +To see all the available keywords for defining dataset mixtures, refer to the [`scripts.utils.DatasetConfig`] and [`DatasetMixtureConfig`] classes. + ## Getting the System Information You can get the system information by running the following command: diff --git a/docs/source/script_utils.md b/docs/source/script_utils.md index aba81bf9f..1ecb73756 100644 --- a/docs/source/script_utils.md +++ b/docs/source/script_utils.md @@ -10,3 +10,15 @@ - parse_args_and_config - parse_args_into_dataclasses - set_defaults_with_config + +## get_dataset + +[[autodoc]] get_dataset + +## DatasetConfig + +[[autodoc]] scripts.utils.DatasetConfig + +## DatasetMixtureConfig + +[[autodoc]] DatasetMixtureConfig diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py index 24a616500..89567213c 100644 --- a/tests/test_cli_utils.py +++ b/tests/test_cli_utils.py @@ -12,11 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import tempfile import unittest from dataclasses import dataclass from unittest.mock import mock_open, patch -from trl import TrlParser +from datasets import DatasetDict, load_dataset + +from trl import DatasetMixtureConfig, TrlParser, get_dataset +from trl.scripts.utils import DatasetConfig @dataclass @@ -262,3 +266,162 @@ class TestTrlParser(unittest.TestCase): # Check that config values were applied to the subparser self.assertEqual(result_args[0].arg1, 2) # Default from config self.assertEqual(result_args[0].arg2, "config_value") # Default from config + + +class TestGetDataset(unittest.TestCase): + def test_single_dataset_with_config(self): + mixture_config = DatasetMixtureConfig( + datasets=[DatasetConfig(path="trl-internal-testing/zen", name="standard_language_modeling")] + ) + result = get_dataset(mixture_config) + expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling") + self.assertEqual(expected["train"][:], result["train"][:]) + + def test_single_dataset_preference_config(self): + mixture_config = DatasetMixtureConfig( + datasets=[DatasetConfig(path="trl-internal-testing/zen", name="standard_preference")] + ) + result = get_dataset(mixture_config) + expected = load_dataset("trl-internal-testing/zen", "standard_preference") + self.assertEqual(expected["train"][:], result["train"][:]) + + def test_single_dataset_streaming(self): + mixture_config = DatasetMixtureConfig( + datasets=[DatasetConfig(path="trl-internal-testing/zen", name="standard_language_modeling")], + streaming=True, + ) + result = get_dataset(mixture_config) + expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling") + self.assertEqual(expected["train"].to_list(), list(result["train"])) + + def test_dataset_mixture_basic(self): + dataset_config1 = DatasetConfig( + path="trl-internal-testing/zen", name="standard_prompt_completion", split="train", columns=["prompt"] + ) + dataset_config2 = DatasetConfig( + path="trl-internal-testing/zen", 
name="standard_preference", split="train", columns=["prompt"] + ) + mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2]) + result = get_dataset(mixture_config) + self.assertIsInstance(result, DatasetDict) + self.assertIn("train", result) + train_dataset = result["train"] + self.assertEqual(train_dataset.column_names, ["prompt"]) + prompts = train_dataset["prompt"] + expected_first_half = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + self.assertEqual(prompts[: len(prompts) // 2], expected_first_half["prompt"]) + expected_second_half = load_dataset("trl-internal-testing/zen", "standard_prompt_completion", split="train") + self.assertEqual(prompts[len(prompts) // 2 :], expected_second_half["prompt"]) + + def test_dataset_mixture_with_weights(self): + dataset_config1 = DatasetConfig( + path="trl-internal-testing/zen", name="standard_prompt_completion", split="train[:50%]", columns=["prompt"] + ) + dataset_config2 = DatasetConfig( + path="trl-internal-testing/zen", name="standard_preference", split="train[:50%]", columns=["prompt"] + ) + mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2]) + result = get_dataset(mixture_config) + self.assertIsInstance(result, DatasetDict) + self.assertIn("train", result) + train_dataset = result["train"] + self.assertEqual(train_dataset.column_names, ["prompt"]) + prompts = train_dataset["prompt"] + expected_first_half = load_dataset("trl-internal-testing/zen", "standard_preference", split="train[:50%]") + self.assertEqual(prompts[: len(prompts) // 2], expected_first_half["prompt"]) + expected_second_half = load_dataset( + "trl-internal-testing/zen", "standard_prompt_completion", split="train[:50%]" + ) + self.assertEqual(prompts[len(prompts) // 2 :], expected_second_half["prompt"]) + + def test_dataset_mixture_with_test_split(self): + mixture_config = DatasetMixtureConfig( + datasets=[DatasetConfig(path="trl-internal-testing/zen", 
name="standard_language_modeling")], + test_split_size=2, + ) + result = get_dataset(mixture_config) + self.assertIsInstance(result, DatasetDict) + self.assertIn("train", result) + self.assertIn("test", result) + self.assertEqual(len(result["train"]), 15) + self.assertEqual(len(result["test"]), 2) + + def test_empty_dataset_mixture_raises_error(self): + mixture_config = DatasetMixtureConfig(datasets=[]) + + with self.assertRaises(ValueError) as context: + get_dataset(mixture_config) + + self.assertIn("No datasets were loaded", str(context.exception)) + + def test_mixture_multiple_different_configs(self): + dataset_config1 = DatasetConfig( + path="trl-internal-testing/zen", name="conversational_preference", split="train", columns=["prompt"] + ) + dataset_config2 = DatasetConfig( + path="trl-internal-testing/zen", name="conversational_prompt_only", split="test" + ) + mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2]) + result = get_dataset(mixture_config) + self.assertIsInstance(result, DatasetDict) + self.assertIn("train", result) + self.assertGreater(len(result["train"]), 0) + + def test_trlparser_parses_yaml_config_correctly(self): + # Prepare YAML content exactly like your example + yaml_content = """ + datasets: + - path: trl-internal-testing/zen + name: standard_prompt_only + - path: trl-internal-testing/zen + name: standard_preference + columns: + - prompt + """ + + # Write YAML to a temporary file + with tempfile.NamedTemporaryFile("w+", suffix=".yaml") as tmpfile: + tmpfile.write(yaml_content) + tmpfile.flush() + parser = TrlParser((DatasetMixtureConfig,)) + args = parser.parse_args_and_config(args=["--config", tmpfile.name])[0] + + # Assert that we got DatasetMixtureConfig instance + self.assertIsInstance(args, DatasetMixtureConfig) + + # Assert datasets list length + self.assertEqual(len(args.datasets), 2) + + # Check first dataset + dataset_config1 = args.datasets[0] + self.assertIsInstance(dataset_config1, DatasetConfig) + 
self.assertEqual(dataset_config1.path, "trl-internal-testing/zen") + self.assertEqual(dataset_config1.name, "standard_prompt_only") + self.assertIsNone(dataset_config1.columns) # No columns specified + + # Check second dataset + dataset_config2 = args.datasets[1] + self.assertIsInstance(dataset_config2, DatasetConfig) + self.assertEqual(dataset_config2.path, "trl-internal-testing/zen") + self.assertEqual(dataset_config2.name, "standard_preference") + self.assertEqual(dataset_config2.columns, ["prompt"]) # Columns specified + + def test_trlparser_parses_yaml_and_loads_dataset(self): + # Prepare YAML content exactly like your example + yaml_content = """ + datasets: + - path: trl-internal-testing/zen + name: standard_language_modeling + """ + + # Write YAML to a temporary file + with tempfile.NamedTemporaryFile("w+", suffix=".yaml") as tmpfile: + tmpfile.write(yaml_content) + tmpfile.flush() + parser = TrlParser((DatasetMixtureConfig,)) + args = parser.parse_args_and_config(args=["--config", tmpfile.name])[0] + + # Load the dataset using get_dataset + result = get_dataset(args) + expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling") + self.assertEqual(expected["train"][:], result["train"][:]) diff --git a/trl/__init__.py b/trl/__init__.py index 76710789e..eae4d6506 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -20,7 +20,7 @@ from .import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffus _import_structure = { - "scripts": ["init_zero_verbose", "ScriptArguments", "TrlParser"], + "scripts": ["DatasetMixtureConfig", "ScriptArguments", "TrlParser", "get_dataset", "init_zero_verbose"], "data_utils": [ "apply_chat_template", "extract_prompt", @@ -136,7 +136,7 @@ if TYPE_CHECKING: create_reference_model, setup_chat_format, ) - from .scripts import ScriptArguments, TrlParser, init_zero_verbose + from .scripts import DatasetMixtureConfig, ScriptArguments, TrlParser, get_dataset, init_zero_verbose from .trainer import ( 
AlignPropConfig, AlignPropTrainer, diff --git a/trl/scripts/__init__.py b/trl/scripts/__init__.py index 272a717db..720a91a68 100644 --- a/trl/scripts/__init__.py +++ b/trl/scripts/__init__.py @@ -18,11 +18,11 @@ from ..import_utils import _LazyModule _import_structure = { - "utils": ["init_zero_verbose", "ScriptArguments", "TrlParser"], + "utils": ["DatasetMixtureConfig", "ScriptArguments", "TrlParser", "get_dataset", "init_zero_verbose"], } if TYPE_CHECKING: - from .utils import ScriptArguments, TrlParser, init_zero_verbose + from .utils import DatasetMixtureConfig, ScriptArguments, TrlParser, get_dataset, init_zero_verbose else: import sys diff --git a/trl/scripts/dpo.py b/trl/scripts/dpo.py index 7d435ffe6..111ba3dbd 100644 --- a/trl/scripts/dpo.py +++ b/trl/scripts/dpo.py @@ -61,17 +61,20 @@ python trl/scripts/dpo.py \ """ import argparse +import warnings import torch from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer from trl import ( + DatasetMixtureConfig, DPOConfig, DPOTrainer, ModelConfig, ScriptArguments, TrlParser, + get_dataset, get_kbit_device_map, get_peft_config, get_quantization_config, @@ -79,7 +82,7 @@ from trl import ( from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE -def main(script_args, training_args, model_args): +def main(script_args, training_args, model_args, dataset_args): ################ # Model & Tokenizer ################### @@ -118,18 +121,22 @@ def main(script_args, training_args, model_args): name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool ] - ################ - # Dataset - ################ - dataset = load_dataset( - script_args.dataset_name, - name=script_args.dataset_config, - streaming=script_args.dataset_streaming, - ) + # Load the dataset + if dataset_args.datasets and script_args.dataset_name: + warnings.warn( + "Both `datasets` and `dataset_name` are provided. 
The `datasets` argument will be used to load the " + "dataset and `dataset_name` will be ignored." + ) + elif dataset_args.datasets and not script_args.dataset_name: + dataset = get_dataset(dataset_args) + elif not dataset_args.datasets and script_args.dataset_name: + dataset = load_dataset( + script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming + ) + else: + raise ValueError("Either `datasets` or `dataset_name` must be provided.") - ########## - # Training - ################ + # Initialize the DPO trainer trainer = DPOTrainer( model, ref_model, @@ -140,6 +147,7 @@ def main(script_args, training_args, model_args): peft_config=peft_config, ) + # Train the model trainer.train() if training_args.eval_strategy != "no": @@ -147,14 +155,14 @@ def main(script_args, training_args, model_args): trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) - # Save and push to hub + # Save and push to Hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: trainer.push_to_hub(dataset_name=script_args.dataset_name) def make_parser(subparsers: argparse._SubParsersAction = None): - dataclass_types = (ScriptArguments, DPOConfig, ModelConfig) + dataclass_types = (ScriptArguments, DPOConfig, ModelConfig, DatasetMixtureConfig) if subparsers is not None: parser = subparsers.add_parser("dpo", help="Run the DPO training script", dataclass_types=dataclass_types) else: @@ -164,5 +172,10 @@ def make_parser(subparsers: argparse._SubParsersAction = None): if __name__ == "__main__": parser = make_parser() - script_args, training_args, model_args = parser.parse_args_and_config() - main(script_args, training_args, model_args) + # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments. + # To ensure that their parsing does not interfere with the script arguments, parse the arguments with + # `return_remaining_strings=True`, then ignore the remaining strings. 
+ script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config( + return_remaining_strings=True + ) + main(script_args, training_args, model_args, dataset_args) diff --git a/trl/scripts/grpo.py b/trl/scripts/grpo.py index 8832c9293..8455d8a58 100644 --- a/trl/scripts/grpo.py +++ b/trl/scripts/grpo.py @@ -23,13 +23,22 @@ import argparse import importlib import os import sys +import warnings from dataclasses import dataclass, field from typing import Optional from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer -from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config +from trl import ( + DatasetMixtureConfig, + GRPOConfig, + GRPOTrainer, + ModelConfig, + ScriptArguments, + TrlParser, + get_dataset, + get_peft_config, +) from trl.rewards import think_format_reward @@ -68,22 +77,11 @@ class GRPOScriptArguments(ScriptArguments): ) -def main(script_args, training_args, model_args): - # Load a pretrained model - model = AutoModelForCausalLM.from_pretrained( - model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code - ) - tokenizer = AutoTokenizer.from_pretrained( - model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code - ) - +def main(script_args, training_args, model_args, dataset_args): # Get the reward models and functions reward_funcs = [] if script_args.reward_model_name_or_path: - reward_model = AutoModelForSequenceClassification.from_pretrained( - script_args.reward_model_name_or_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 - ) - reward_funcs.append(reward_model) + reward_funcs.append(script_args.reward_model_name_or_path) if script_args.reward_funcs: for func_name in script_args.reward_funcs: @@ -102,30 +100,41 @@ def main(script_args, training_args, model_args): ) # Load the dataset - dataset = load_dataset(script_args.dataset_name, 
name=script_args.dataset_config) + if dataset_args.datasets and script_args.dataset_name: + warnings.warn( + "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the " + "dataset and `dataset_name` will be ignored." + ) + elif dataset_args.datasets and not script_args.dataset_name: + dataset = get_dataset(dataset_args) + elif not dataset_args.datasets and script_args.dataset_name: + dataset = load_dataset( + script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming + ) + else: + raise ValueError("Either `datasets` or `dataset_name` must be provided.") # Initialize the GRPO trainer trainer = GRPOTrainer( - model=model, + model=model_args.model_name_or_path, reward_funcs=reward_funcs, args=training_args, train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, - processing_class=tokenizer, peft_config=get_peft_config(model_args), ) - # Train and push the model to the Hub + # Train the model trainer.train() - # Save and push to hub + # Save and push to Hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: trainer.push_to_hub(dataset_name=script_args.dataset_name) def make_parser(subparsers: argparse._SubParsersAction = None): - dataclass_types = (GRPOScriptArguments, GRPOConfig, ModelConfig) + dataclass_types = (GRPOScriptArguments, GRPOConfig, ModelConfig, DatasetMixtureConfig) if subparsers is not None: parser = subparsers.add_parser("grpo", help="Run the GRPO training script", dataclass_types=dataclass_types) else: @@ -135,5 +144,10 @@ def make_parser(subparsers: argparse._SubParsersAction = None): if __name__ == "__main__": parser = make_parser() - script_args, training_args, model_args = parser.parse_args_and_config() - main(script_args, training_args, model_args) + # When using the trl cli, this script may be run with additional arguments, corresponding 
accelerate arguments. + # To ensure that their parsing does not interfere with the script arguments, parse the arguments with + # `return_remaining_strings=True`, then ignore the remaining strings. + script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config( + return_remaining_strings=True + ) + main(script_args, training_args, model_args, dataset_args) diff --git a/trl/scripts/kto.py b/trl/scripts/kto.py index 58a064db6..9529aec22 100644 --- a/trl/scripts/kto.py +++ b/trl/scripts/kto.py @@ -65,22 +65,25 @@ python trl/scripts/kto.py \ """ import argparse +import warnings from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer from trl import ( + DatasetMixtureConfig, KTOConfig, KTOTrainer, ModelConfig, ScriptArguments, TrlParser, + get_dataset, get_peft_config, setup_chat_format, ) -def main(script_args, training_args, model_args): +def main(script_args, training_args, model_args, dataset_args): # Load a pretrained model model = AutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code @@ -100,7 +103,19 @@ def main(script_args, training_args, model_args): model, tokenizer = setup_chat_format(model, tokenizer) # Load the dataset - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + if dataset_args.datasets and script_args.dataset_name: + warnings.warn( + "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the " + "dataset and `dataset_name` will be ignored." 
+ ) + elif dataset_args.datasets and not script_args.dataset_name: + dataset = get_dataset(dataset_args) + elif not dataset_args.datasets and script_args.dataset_name: + dataset = load_dataset( + script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming + ) + else: + raise ValueError("Either `datasets` or `dataset_name` must be provided.") # Initialize the KTO trainer trainer = KTOTrainer( @@ -113,17 +128,17 @@ def main(script_args, training_args, model_args): peft_config=get_peft_config(model_args), ) - # Train and push the model to the Hub + # Train the model trainer.train() - # Save and push to hub + # Save and push to Hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: trainer.push_to_hub(dataset_name=script_args.dataset_name) def make_parser(subparsers: argparse._SubParsersAction = None): - dataclass_types = (ScriptArguments, KTOConfig, ModelConfig) + dataclass_types = (ScriptArguments, KTOConfig, ModelConfig, DatasetMixtureConfig) if subparsers is not None: parser = subparsers.add_parser("kto", help="Run the KTO training script", dataclass_types=dataclass_types) else: @@ -133,5 +148,10 @@ def make_parser(subparsers: argparse._SubParsersAction = None): if __name__ == "__main__": parser = make_parser() - script_args, training_args, model_args = parser.parse_args_and_config() - main(script_args, training_args, model_args) + # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments. + # To ensure that their parsing does not interfere with the script arguments, parse the arguments with + # `return_remaining_strings=True`, then ignore the remaining strings. 
+ script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config( + return_remaining_strings=True + ) + main(script_args, training_args, model_args, dataset_args) diff --git a/trl/scripts/sft.py b/trl/scripts/sft.py index b72f1ff8c..6b2f25681 100644 --- a/trl/scripts/sft.py +++ b/trl/scripts/sft.py @@ -61,25 +61,28 @@ python trl/scripts/sft.py \ """ import argparse +import warnings from datasets import load_dataset from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES from trl import ( + DatasetMixtureConfig, ModelConfig, ScriptArguments, SFTConfig, SFTTrainer, TrlParser, clone_chat_template, + get_dataset, get_kbit_device_map, get_peft_config, get_quantization_config, ) -def main(script_args, training_args, model_args): +def main(script_args, training_args, model_args, dataset_args): ################ # Model init kwargs & Tokenizer ################ @@ -116,14 +119,22 @@ def main(script_args, training_args, model_args): # TODO: source should be passed as an argument model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B") - ################ - # Dataset - ################ - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + # Load the dataset + if dataset_args.datasets and script_args.dataset_name: + warnings.warn( + "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the " + "dataset and `dataset_name` will be ignored." 
+ ) + elif dataset_args.datasets and not script_args.dataset_name: + dataset = get_dataset(dataset_args) + elif not dataset_args.datasets and script_args.dataset_name: + dataset = load_dataset( + script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming + ) + else: + raise ValueError("Either `datasets` or `dataset_name` must be provided.") - ################ - # Training - ################ + # Initialize the SFT trainer trainer = SFTTrainer( model=model, args=training_args, @@ -133,16 +144,17 @@ def main(script_args, training_args, model_args): peft_config=get_peft_config(model_args), ) + # Train the model trainer.train() - # Save and push to hub + # Save and push to Hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: trainer.push_to_hub(dataset_name=script_args.dataset_name) def make_parser(subparsers: argparse._SubParsersAction = None): - dataclass_types = (ScriptArguments, SFTConfig, ModelConfig) + dataclass_types = (ScriptArguments, SFTConfig, ModelConfig, DatasetMixtureConfig) if subparsers is not None: parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types) else: @@ -155,5 +167,7 @@ if __name__ == "__main__": # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments. # To ensure that their parsing does not interfere with the script arguments, parse the arguments with # `return_remaining_strings=True`, then ignore the remaining strings. 
- script_args, training_args, model_args, _ = parser.parse_args_and_config(return_remaining_strings=True) - main(script_args, training_args, model_args) + script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config( + return_remaining_strings=True + ) + main(script_args, training_args, model_args, dataset_args) diff --git a/trl/scripts/utils.py b/trl/scripts/utils.py index 030719ab1..8abd78eb0 100644 --- a/trl/scripts/utils.py +++ b/trl/scripts/utils.py @@ -23,7 +23,9 @@ from collections.abc import Iterable from dataclasses import dataclass, field from typing import Optional, Union +import datasets import yaml +from datasets import DatasetDict, concatenate_datasets from transformers import HfArgumentParser from transformers.hf_argparser import DataClass, DataClassType from transformers.utils import is_rich_available @@ -32,22 +34,121 @@ from transformers.utils import is_rich_available logger = logging.getLogger(__name__) +@dataclass +class DatasetConfig: + """ + Configuration for a dataset. + + This class matches the signature of [`~datasets.load_dataset`] and the arguments are used directly in the + `datasets.load_dataset` function. You can refer to the `datasets.load_dataset` documentation for more details. + + Parameters: + path (`str`): + Path or name of the dataset. + name (`str`, *optional*, defaults to `None`): + Defining the name of the dataset configuration. + data_dir (`str`, *optional*, defaults to `None`): + Defining the `data_dir` of the dataset configuration. If specified for the generic builders(csv, text + etc.) or the Hub datasets and `data_files` is `None`, the behavior is equal to passing + `os.path.join(data_dir, **)` as `data_files` to reference all the files in a directory. + data_files (`str` or `Sequence` or `Mapping`, *optional*, defaults to `None`): + Path(s) to source data file(s). + split (`str`, *optional*, defaults to `"train"`): + Which split of the data to load. 
+ columns (`list[str]`, *optional*, defaults to `None`): + List of column names to select from the dataset. If `None`, all columns are selected. + """ + + path: str + name: Optional[str] = None + data_dir: Optional[str] = None + data_files: Optional[Union[str, list[str], dict[str, str]]] = None + split: str = "train" + columns: Optional[list[str]] = None + + +@dataclass +class DatasetMixtureConfig: + """ + Configuration class for a mixture of datasets. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + datasets (`list[DatasetConfig]`): + List of dataset configurations to include in the mixture. + streaming (`bool`, *optional*, defaults to `False`): + Whether to stream the datasets. If `True`, the datasets will be loaded in streaming mode. + test_split_size (`float` or `None`, *optional*, defaults to `None`): + Size of the test split. Refer to the `test_size` parameter in the [`~datasets.train_test_split`] function + for more details. If `None`, the dataset will not be split into train and test sets. + + Usage: + When using the CLI, you can add the following section to your YAML config file: + + ```yaml + datasets: + - path: ... + name: ... + data_dir: ... + data_files: ... + split: ... + columns: ... + - path: ... + name: ... + data_dir: ... + data_files: ... + split: ... + columns: ... + streaming: ... + test_split_size: ... + ``` + """ + + datasets: list[DatasetConfig] = field( + default_factory=list, + metadata={"help": "List of dataset configurations to include in the mixture."}, + ) + streaming: bool = field( + default=False, + metadata={"help": "Whether to stream the datasets. If True, the datasets will be loaded in streaming mode."}, + ) + test_split_size: Optional[float] = field( + default=None, + metadata={ + "help": "Size of the test split. 
Refer to the `test_size` parameter in the `datasets.train_test_split` " + "function for more details. If None, the dataset will not be split into train and test sets." + }, + ) + + def __post_init__(self): + # Convert any dataset dicts (from CLI/config parsing) into DatasetConfig objects + for idx, dataset in enumerate(self.datasets): + if isinstance(dataset, dict): + # If it's a dict, convert it to DatasetConfig + self.datasets[idx] = DatasetConfig(**dataset) + + @dataclass class ScriptArguments: """ Arguments common to all scripts. Args: - dataset_name (`str`): - Dataset name. + dataset_name (`str`, or `None`, *optional*, defaults to `None`): + Path or name of the dataset to load. If `datasets` is provided, this will be ignored. dataset_config (`str` or `None`, *optional*, defaults to `None`): Dataset configuration name. Corresponds to the `name` argument of the [`~datasets.load_dataset`] function. + If `datasets` is provided, this will be ignored. dataset_train_split (`str`, *optional*, defaults to `"train"`): - Dataset split to use for training. + Dataset split to use for training. If `datasets` is provided, this will be ignored. dataset_test_split (`str`, *optional*, defaults to `"test"`): - Dataset split to use for evaluation. + Dataset split to use for evaluation. If `datasets` is provided, this will be ignored. dataset_streaming (`bool`, *optional*, defaults to `False`): - Whether to stream the dataset. If True, the dataset will be loaded in streaming mode. + Whether to stream the dataset. If True, the dataset will be loaded in streaming mode. If `datasets` is + provided, this will be ignored. gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `False`): Whether to apply `use_reentrant` for gradient checkpointing. ignore_bias_buffers (`bool`, *optional*, defaults to `False`): @@ -56,19 +157,31 @@ class ScriptArguments: https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992. 
""" - dataset_name: Optional[str] = field(default=None, metadata={"help": "Dataset name."}) + dataset_name: Optional[str] = field( + default=None, + metadata={"help": "Path or name of the dataset to load. If `datasets` is provided, this will be ignored."}, + ) dataset_config: Optional[str] = field( default=None, metadata={ "help": "Dataset configuration name. Corresponds to the `name` argument of the `datasets.load_dataset` " - "function." + "function. If `datasets` is provided, this will be ignored." }, ) - dataset_train_split: str = field(default="train", metadata={"help": "Dataset split to use for training."}) - dataset_test_split: str = field(default="test", metadata={"help": "Dataset split to use for evaluation."}) + dataset_train_split: str = field( + default="train", + metadata={"help": "Dataset split to use for training. If `datasets` is provided, this will be ignored."}, + ) + dataset_test_split: str = field( + default="test", + metadata={"help": "Dataset split to use for evaluation. If `datasets` is provided, this will be ignored."}, + ) dataset_streaming: bool = field( default=False, - metadata={"help": "Whether to stream the dataset. If True, the dataset will be loaded in streaming mode."}, + metadata={ + "help": "Whether to stream the dataset. If True, the dataset will be loaded in streaming mode. If " + "`datasets` is provided, this will be ignored." + }, ) gradient_checkpointing_use_reentrant: bool = field( default=False, @@ -282,3 +395,66 @@ def get_git_commit_hash(package_name): return None except Exception as e: return f"Error: {str(e)}" + + +def get_dataset(mixture_config: DatasetMixtureConfig) -> DatasetDict: + """ + Load a mixture of datasets based on the configuration. + + Args: + mixture_config (`DatasetMixtureConfig`): + Script arguments containing dataset configuration. + + Returns: + `DatasetDict`: + Combined dataset(s) from the mixture configuration, with optional train/test split if `test_split_size` is + set. 
+ + Example: + ```python + from trl import DatasetMixtureConfig, get_dataset + from trl.scripts.utils import DatasetConfig + + mixture_config = DatasetMixtureConfig(datasets=[DatasetConfig(path="trl-lib/tldr")]) + dataset = get_dataset(mixture_config) + print(dataset) + ``` + + ``` + DatasetDict({ + train: Dataset({ + features: ['prompt', 'completion'], + num_rows: 116722 + }) + }) + ``` + """ + logger.info(f"Creating dataset mixture with {len(mixture_config.datasets)} datasets") + datasets_list = [] + for dataset_config in mixture_config.datasets: + logger.info(f"Loading dataset for mixture: {dataset_config.path} (config name: {dataset_config.name})") + dataset = datasets.load_dataset( + path=dataset_config.path, + name=dataset_config.name, + data_dir=dataset_config.data_dir, + data_files=dataset_config.data_files, + split=dataset_config.split, + streaming=mixture_config.streaming, + ) + if dataset_config.columns is not None: + dataset = dataset.select_columns(dataset_config.columns) + datasets_list.append(dataset) + + if datasets_list: + combined_dataset = concatenate_datasets(datasets_list) + if isinstance(combined_dataset, datasets.Dataset): # IterableDataset does not have a length + logger.info(f"Created dataset mixture with {len(combined_dataset)} examples") + + if mixture_config.test_split_size is not None: + logger.info(f"Splitting dataset into train and test sets with test size: {mixture_config.test_split_size}") + combined_dataset = combined_dataset.train_test_split(test_size=mixture_config.test_split_size) + return combined_dataset + else: + return DatasetDict({"train": combined_dataset}) + else: + raise ValueError("No datasets were loaded from the mixture configuration")