🎚️ Add dataset mixer (#3791)

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
This commit is contained in:
lewtun
2025-08-12 05:14:50 +02:00
committed by GitHub
parent de27d612b0
commit 72d4d82b8c
10 changed files with 528 additions and 73 deletions

View File

@ -219,6 +219,49 @@ trl dpo --config dpo_config.yaml
</hfoption>
</hfoptions>
### Using dataset mixtures
You can use dataset mixtures to combine multiple datasets into a single training dataset. This is useful for training on diverse data sources or when you want to mix different types of data.
<hfoptions id="accelerate_config">
<hfoption id="SFT">
```yaml
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
datasets:
- path: stanfordnlp/imdb
- path: roneneldan/TinyStories
```
Launch with:
```bash
trl sft --config sft_config.yaml
```
</hfoption>
<hfoption id="DPO">
```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
datasets:
- path: BAAI/Infinity-Preference
- path: argilla/Capybara-Preferences
```
Launch with:
```bash
trl dpo --config dpo_config.yaml
```
</hfoption>
</hfoptions>
To see all the available keywords for defining dataset mixtures, refer to the [`scripts.utils.DatasetConfig`] and [`DatasetMixtureConfig`] classes.
## Getting the System Information
You can get the system information by running the following command:

View File

@ -10,3 +10,15 @@
- parse_args_and_config
- parse_args_into_dataclasses
- set_defaults_with_config
## get_dataset
[[autodoc]] get_dataset
## DatasetConfig
[[autodoc]] scripts.utils.DatasetConfig
## DatasetMixtureConfig
[[autodoc]] DatasetMixtureConfig

View File

@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import tempfile
import unittest import unittest
from dataclasses import dataclass from dataclasses import dataclass
from unittest.mock import mock_open, patch from unittest.mock import mock_open, patch
from trl import TrlParser from datasets import DatasetDict, load_dataset
from trl import DatasetMixtureConfig, TrlParser, get_dataset
from trl.scripts.utils import DatasetConfig
@dataclass @dataclass
@ -262,3 +266,162 @@ class TestTrlParser(unittest.TestCase):
# Check that config values were applied to the subparser # Check that config values were applied to the subparser
self.assertEqual(result_args[0].arg1, 2) # Default from config self.assertEqual(result_args[0].arg1, 2) # Default from config
self.assertEqual(result_args[0].arg2, "config_value") # Default from config self.assertEqual(result_args[0].arg2, "config_value") # Default from config
class TestGetDataset(unittest.TestCase):
    """Tests for `get_dataset` and for parsing `DatasetMixtureConfig` from YAML via `TrlParser`.

    All cases load the small `trl-internal-testing/zen` dataset from the Hub, so they require
    network access.
    """

    def test_single_dataset_with_config(self):
        """A one-dataset mixture loads the same data as a direct `load_dataset` call."""
        mixture_config = DatasetMixtureConfig(
            datasets=[DatasetConfig(path="trl-internal-testing/zen", name="standard_language_modeling")]
        )
        result = get_dataset(mixture_config)
        expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling")
        self.assertEqual(expected["train"][:], result["train"][:])

    def test_single_dataset_preference_config(self):
        """Same as above, but with a preference-format dataset configuration."""
        mixture_config = DatasetMixtureConfig(
            datasets=[DatasetConfig(path="trl-internal-testing/zen", name="standard_preference")]
        )
        result = get_dataset(mixture_config)
        expected = load_dataset("trl-internal-testing/zen", "standard_preference")
        self.assertEqual(expected["train"][:], result["train"][:])

    def test_single_dataset_streaming(self):
        """With `streaming=True`, the mixture yields the same rows as a non-streaming load."""
        mixture_config = DatasetMixtureConfig(
            datasets=[DatasetConfig(path="trl-internal-testing/zen", name="standard_language_modeling")],
            streaming=True,
        )
        result = get_dataset(mixture_config)
        expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling")
        # Streaming datasets are iterables, so materialize them before comparing.
        self.assertEqual(expected["train"].to_list(), list(result["train"]))

    def test_dataset_mixture_basic(self):
        """Two datasets restricted to a shared `prompt` column are concatenated into one train split."""
        dataset_config1 = DatasetConfig(
            path="trl-internal-testing/zen", name="standard_prompt_completion", split="train", columns=["prompt"]
        )
        dataset_config2 = DatasetConfig(
            path="trl-internal-testing/zen", name="standard_preference", split="train", columns=["prompt"]
        )
        mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2])
        result = get_dataset(mixture_config)
        self.assertIsInstance(result, DatasetDict)
        self.assertIn("train", result)
        train_dataset = result["train"]
        # Only the selected `prompt` column survives the column selection.
        self.assertEqual(train_dataset.column_names, ["prompt"])
        prompts = train_dataset["prompt"]
        # NOTE(review): the first half is checked against `standard_preference` even though it is
        # listed second in the mixture — presumably `get_dataset` does not preserve the listed
        # order; confirm against its implementation.
        expected_first_half = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
        self.assertEqual(prompts[: len(prompts) // 2], expected_first_half["prompt"])
        expected_second_half = load_dataset("trl-internal-testing/zen", "standard_prompt_completion", split="train")
        self.assertEqual(prompts[len(prompts) // 2 :], expected_second_half["prompt"])

    def test_dataset_mixture_with_weights(self):
        """Mixing two half-sized splits (`train[:50%]`) keeps each half's rows intact.

        NOTE(review): despite the name, no explicit weights are configured here; the 50% split
        slices stand in for weighting — confirm intent.
        """
        dataset_config1 = DatasetConfig(
            path="trl-internal-testing/zen", name="standard_prompt_completion", split="train[:50%]", columns=["prompt"]
        )
        dataset_config2 = DatasetConfig(
            path="trl-internal-testing/zen", name="standard_preference", split="train[:50%]", columns=["prompt"]
        )
        mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2])
        result = get_dataset(mixture_config)
        self.assertIsInstance(result, DatasetDict)
        self.assertIn("train", result)
        train_dataset = result["train"]
        self.assertEqual(train_dataset.column_names, ["prompt"])
        prompts = train_dataset["prompt"]
        # NOTE(review): as in the basic mixture test, the halves appear in the reverse of the
        # listed order — confirm against `get_dataset`.
        expected_first_half = load_dataset("trl-internal-testing/zen", "standard_preference", split="train[:50%]")
        self.assertEqual(prompts[: len(prompts) // 2], expected_first_half["prompt"])
        expected_second_half = load_dataset(
            "trl-internal-testing/zen", "standard_prompt_completion", split="train[:50%]"
        )
        self.assertEqual(prompts[len(prompts) // 2 :], expected_second_half["prompt"])

    def test_dataset_mixture_with_test_split(self):
        """`test_split_size` carves a test split out of the loaded mixture."""
        mixture_config = DatasetMixtureConfig(
            datasets=[DatasetConfig(path="trl-internal-testing/zen", name="standard_language_modeling")],
            test_split_size=2,
        )
        result = get_dataset(mixture_config)
        self.assertIsInstance(result, DatasetDict)
        self.assertIn("train", result)
        self.assertIn("test", result)
        # 2 of the 17 source rows go to the test split; 15 remain for training.
        self.assertEqual(len(result["train"]), 15)
        self.assertEqual(len(result["test"]), 2)

    def test_empty_dataset_mixture_raises_error(self):
        """A mixture with no datasets is rejected with a descriptive ValueError."""
        mixture_config = DatasetMixtureConfig(datasets=[])
        with self.assertRaises(ValueError) as context:
            get_dataset(mixture_config)
        self.assertIn("No datasets were loaded", str(context.exception))

    def test_mixture_multiple_different_configs(self):
        """Datasets with different configs (and different source splits) can still be mixed."""
        dataset_config1 = DatasetConfig(
            path="trl-internal-testing/zen", name="conversational_preference", split="train", columns=["prompt"]
        )
        dataset_config2 = DatasetConfig(
            path="trl-internal-testing/zen", name="conversational_prompt_only", split="test"
        )
        mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2])
        result = get_dataset(mixture_config)
        self.assertIsInstance(result, DatasetDict)
        self.assertIn("train", result)
        self.assertGreater(len(result["train"]), 0)

    def test_trlparser_parses_yaml_config_correctly(self):
        """TrlParser turns a `datasets:` YAML section into DatasetConfig objects."""
        # YAML content mirroring the documented CLI config format.
        yaml_content = """
        datasets:
        - path: trl-internal-testing/zen
          name: standard_prompt_only
        - path: trl-internal-testing/zen
          name: standard_preference
          columns:
          - prompt
        """
        # Write the YAML to a temporary file so it can be passed via --config.
        with tempfile.NamedTemporaryFile("w+", suffix=".yaml") as tmpfile:
            tmpfile.write(yaml_content)
            tmpfile.flush()
            parser = TrlParser((DatasetMixtureConfig,))
            args = parser.parse_args_and_config(args=["--config", tmpfile.name])[0]
        # The parsed result must be a DatasetMixtureConfig instance.
        self.assertIsInstance(args, DatasetMixtureConfig)
        # Both dataset entries from the YAML should be present.
        self.assertEqual(len(args.datasets), 2)
        # First dataset: no columns specified.
        dataset_config1 = args.datasets[0]
        self.assertIsInstance(dataset_config1, DatasetConfig)
        self.assertEqual(dataset_config1.path, "trl-internal-testing/zen")
        self.assertEqual(dataset_config1.name, "standard_prompt_only")
        self.assertIsNone(dataset_config1.columns)  # No columns specified
        # Second dataset: restricted to the `prompt` column.
        dataset_config2 = args.datasets[1]
        self.assertIsInstance(dataset_config2, DatasetConfig)
        self.assertEqual(dataset_config2.path, "trl-internal-testing/zen")
        self.assertEqual(dataset_config2.name, "standard_preference")
        self.assertEqual(dataset_config2.columns, ["prompt"])  # Columns specified

    def test_trlparser_parses_yaml_and_loads_dataset(self):
        """End-to-end: a YAML config parsed by TrlParser can be fed straight into `get_dataset`."""
        # YAML content mirroring the documented CLI config format.
        yaml_content = """
        datasets:
        - path: trl-internal-testing/zen
          name: standard_language_modeling
        """
        # Write the YAML to a temporary file so it can be passed via --config.
        with tempfile.NamedTemporaryFile("w+", suffix=".yaml") as tmpfile:
            tmpfile.write(yaml_content)
            tmpfile.flush()
            parser = TrlParser((DatasetMixtureConfig,))
            args = parser.parse_args_and_config(args=["--config", tmpfile.name])[0]
        # Load the dataset using get_dataset and compare against a direct load.
        result = get_dataset(args)
        expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling")
        self.assertEqual(expected["train"][:], result["train"][:])

View File

@ -20,7 +20,7 @@ from .import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffus
_import_structure = { _import_structure = {
"scripts": ["init_zero_verbose", "ScriptArguments", "TrlParser"], "scripts": ["DatasetMixtureConfig", "ScriptArguments", "TrlParser", "get_dataset", "init_zero_verbose"],
"data_utils": [ "data_utils": [
"apply_chat_template", "apply_chat_template",
"extract_prompt", "extract_prompt",
@ -136,7 +136,7 @@ if TYPE_CHECKING:
create_reference_model, create_reference_model,
setup_chat_format, setup_chat_format,
) )
from .scripts import ScriptArguments, TrlParser, init_zero_verbose from .scripts import DatasetMixtureConfig, ScriptArguments, TrlParser, get_dataset, init_zero_verbose
from .trainer import ( from .trainer import (
AlignPropConfig, AlignPropConfig,
AlignPropTrainer, AlignPropTrainer,

View File

@ -18,11 +18,11 @@ from ..import_utils import _LazyModule
_import_structure = { _import_structure = {
"utils": ["init_zero_verbose", "ScriptArguments", "TrlParser"], "utils": ["DatasetMixtureConfig", "ScriptArguments", "TrlParser", "get_dataset", "init_zero_verbose"],
} }
if TYPE_CHECKING: if TYPE_CHECKING:
from .utils import ScriptArguments, TrlParser, init_zero_verbose from .utils import DatasetMixtureConfig, ScriptArguments, TrlParser, get_dataset, init_zero_verbose
else: else:
import sys import sys

View File

@ -61,17 +61,20 @@ python trl/scripts/dpo.py \
""" """
import argparse import argparse
import warnings
import torch import torch
from datasets import load_dataset from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import ( from trl import (
DatasetMixtureConfig,
DPOConfig, DPOConfig,
DPOTrainer, DPOTrainer,
ModelConfig, ModelConfig,
ScriptArguments, ScriptArguments,
TrlParser, TrlParser,
get_dataset,
get_kbit_device_map, get_kbit_device_map,
get_peft_config, get_peft_config,
get_quantization_config, get_quantization_config,
@ -79,7 +82,7 @@ from trl import (
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
def main(script_args, training_args, model_args): def main(script_args, training_args, model_args, dataset_args):
################ ################
# Model & Tokenizer # Model & Tokenizer
################### ###################
@ -118,18 +121,22 @@ def main(script_args, training_args, model_args):
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
] ]
################ # Load the dataset
# Dataset if dataset_args.datasets and script_args.dataset_name:
################ warnings.warn(
dataset = load_dataset( "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the "
script_args.dataset_name, "dataset and `dataset_name` will be ignored."
name=script_args.dataset_config, )
streaming=script_args.dataset_streaming, elif dataset_args.datasets and not script_args.dataset_name:
) dataset = get_dataset(dataset_args)
elif not dataset_args.datasets and script_args.dataset_name:
dataset = load_dataset(
script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming
)
else:
raise ValueError("Either `datasets` or `dataset_name` must be provided.")
########## # Initialize the DPO trainer
# Training
################
trainer = DPOTrainer( trainer = DPOTrainer(
model, model,
ref_model, ref_model,
@ -140,6 +147,7 @@ def main(script_args, training_args, model_args):
peft_config=peft_config, peft_config=peft_config,
) )
# Train the model
trainer.train() trainer.train()
if training_args.eval_strategy != "no": if training_args.eval_strategy != "no":
@ -147,14 +155,14 @@ def main(script_args, training_args, model_args):
trainer.log_metrics("eval", metrics) trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics) trainer.save_metrics("eval", metrics)
# Save and push to hub # Save and push to Hub
trainer.save_model(training_args.output_dir) trainer.save_model(training_args.output_dir)
if training_args.push_to_hub: if training_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name) trainer.push_to_hub(dataset_name=script_args.dataset_name)
def make_parser(subparsers: argparse._SubParsersAction = None): def make_parser(subparsers: argparse._SubParsersAction = None):
dataclass_types = (ScriptArguments, DPOConfig, ModelConfig) dataclass_types = (ScriptArguments, DPOConfig, ModelConfig, DatasetMixtureConfig)
if subparsers is not None: if subparsers is not None:
parser = subparsers.add_parser("dpo", help="Run the DPO training script", dataclass_types=dataclass_types) parser = subparsers.add_parser("dpo", help="Run the DPO training script", dataclass_types=dataclass_types)
else: else:
@ -164,5 +172,10 @@ def make_parser(subparsers: argparse._SubParsersAction = None):
if __name__ == "__main__": if __name__ == "__main__":
parser = make_parser() parser = make_parser()
script_args, training_args, model_args = parser.parse_args_and_config() # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments.
main(script_args, training_args, model_args) # To ensure that their parsing does not interfere with the script arguments, parse the arguments with
# `return_remaining_strings=True`, then ignore the remaining strings.
script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config(
return_remaining_strings=True
)
main(script_args, training_args, model_args, dataset_args)

View File

@ -23,13 +23,22 @@ import argparse
import importlib import importlib
import os import os
import sys import sys
import warnings
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional
from datasets import load_dataset from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config from trl import (
DatasetMixtureConfig,
GRPOConfig,
GRPOTrainer,
ModelConfig,
ScriptArguments,
TrlParser,
get_dataset,
get_peft_config,
)
from trl.rewards import think_format_reward from trl.rewards import think_format_reward
@ -68,22 +77,11 @@ class GRPOScriptArguments(ScriptArguments):
) )
def main(script_args, training_args, model_args): def main(script_args, training_args, model_args, dataset_args):
# Load a pretrained model
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)
# Get the reward models and functions # Get the reward models and functions
reward_funcs = [] reward_funcs = []
if script_args.reward_model_name_or_path: if script_args.reward_model_name_or_path:
reward_model = AutoModelForSequenceClassification.from_pretrained( reward_funcs.append(script_args.reward_model_name_or_path)
script_args.reward_model_name_or_path, trust_remote_code=model_args.trust_remote_code, num_labels=1
)
reward_funcs.append(reward_model)
if script_args.reward_funcs: if script_args.reward_funcs:
for func_name in script_args.reward_funcs: for func_name in script_args.reward_funcs:
@ -102,30 +100,41 @@ def main(script_args, training_args, model_args):
) )
# Load the dataset # Load the dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) if dataset_args.datasets and script_args.dataset_name:
warnings.warn(
"Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the "
"dataset and `dataset_name` will be ignored."
)
elif dataset_args.datasets and not script_args.dataset_name:
dataset = get_dataset(dataset_args)
elif not dataset_args.datasets and script_args.dataset_name:
dataset = load_dataset(
script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming
)
else:
raise ValueError("Either `datasets` or `dataset_name` must be provided.")
# Initialize the GRPO trainer # Initialize the GRPO trainer
trainer = GRPOTrainer( trainer = GRPOTrainer(
model=model, model=model_args.model_name_or_path,
reward_funcs=reward_funcs, reward_funcs=reward_funcs,
args=training_args, args=training_args,
train_dataset=dataset[script_args.dataset_train_split], train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
processing_class=tokenizer,
peft_config=get_peft_config(model_args), peft_config=get_peft_config(model_args),
) )
# Train and push the model to the Hub # Train the model
trainer.train() trainer.train()
# Save and push to hub # Save and push to Hub
trainer.save_model(training_args.output_dir) trainer.save_model(training_args.output_dir)
if training_args.push_to_hub: if training_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name) trainer.push_to_hub(dataset_name=script_args.dataset_name)
def make_parser(subparsers: argparse._SubParsersAction = None): def make_parser(subparsers: argparse._SubParsersAction = None):
dataclass_types = (GRPOScriptArguments, GRPOConfig, ModelConfig) dataclass_types = (GRPOScriptArguments, GRPOConfig, ModelConfig, DatasetMixtureConfig)
if subparsers is not None: if subparsers is not None:
parser = subparsers.add_parser("grpo", help="Run the GRPO training script", dataclass_types=dataclass_types) parser = subparsers.add_parser("grpo", help="Run the GRPO training script", dataclass_types=dataclass_types)
else: else:
@ -135,5 +144,10 @@ def make_parser(subparsers: argparse._SubParsersAction = None):
if __name__ == "__main__": if __name__ == "__main__":
parser = make_parser() parser = make_parser()
script_args, training_args, model_args = parser.parse_args_and_config() # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments.
main(script_args, training_args, model_args) # To ensure that their parsing does not interfere with the script arguments, parse the arguments with
# `return_remaining_strings=True`, then ignore the remaining strings.
script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config(
return_remaining_strings=True
)
main(script_args, training_args, model_args, dataset_args)

View File

@ -65,22 +65,25 @@ python trl/scripts/kto.py \
""" """
import argparse import argparse
import warnings
from datasets import load_dataset from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import ( from trl import (
DatasetMixtureConfig,
KTOConfig, KTOConfig,
KTOTrainer, KTOTrainer,
ModelConfig, ModelConfig,
ScriptArguments, ScriptArguments,
TrlParser, TrlParser,
get_dataset,
get_peft_config, get_peft_config,
setup_chat_format, setup_chat_format,
) )
def main(script_args, training_args, model_args): def main(script_args, training_args, model_args, dataset_args):
# Load a pretrained model # Load a pretrained model
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
@ -100,7 +103,19 @@ def main(script_args, training_args, model_args):
model, tokenizer = setup_chat_format(model, tokenizer) model, tokenizer = setup_chat_format(model, tokenizer)
# Load the dataset # Load the dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) if dataset_args.datasets and script_args.dataset_name:
warnings.warn(
"Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the "
"dataset and `dataset_name` will be ignored."
)
elif dataset_args.datasets and not script_args.dataset_name:
dataset = get_dataset(dataset_args)
elif not dataset_args.datasets and script_args.dataset_name:
dataset = load_dataset(
script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming
)
else:
raise ValueError("Either `datasets` or `dataset_name` must be provided.")
# Initialize the KTO trainer # Initialize the KTO trainer
trainer = KTOTrainer( trainer = KTOTrainer(
@ -113,17 +128,17 @@ def main(script_args, training_args, model_args):
peft_config=get_peft_config(model_args), peft_config=get_peft_config(model_args),
) )
# Train and push the model to the Hub # Train the model
trainer.train() trainer.train()
# Save and push to hub # Save and push to Hub
trainer.save_model(training_args.output_dir) trainer.save_model(training_args.output_dir)
if training_args.push_to_hub: if training_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name) trainer.push_to_hub(dataset_name=script_args.dataset_name)
def make_parser(subparsers: argparse._SubParsersAction = None): def make_parser(subparsers: argparse._SubParsersAction = None):
dataclass_types = (ScriptArguments, KTOConfig, ModelConfig) dataclass_types = (ScriptArguments, KTOConfig, ModelConfig, DatasetMixtureConfig)
if subparsers is not None: if subparsers is not None:
parser = subparsers.add_parser("kto", help="Run the KTO training script", dataclass_types=dataclass_types) parser = subparsers.add_parser("kto", help="Run the KTO training script", dataclass_types=dataclass_types)
else: else:
@ -133,5 +148,10 @@ def make_parser(subparsers: argparse._SubParsersAction = None):
if __name__ == "__main__": if __name__ == "__main__":
parser = make_parser() parser = make_parser()
script_args, training_args, model_args = parser.parse_args_and_config() # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments.
main(script_args, training_args, model_args) # To ensure that their parsing does not interfere with the script arguments, parse the arguments with
# `return_remaining_strings=True`, then ignore the remaining strings.
script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config(
return_remaining_strings=True
)
main(script_args, training_args, model_args, dataset_args)

View File

@ -61,25 +61,28 @@ python trl/scripts/sft.py \
""" """
import argparse import argparse
import warnings
from datasets import load_dataset from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
from trl import ( from trl import (
DatasetMixtureConfig,
ModelConfig, ModelConfig,
ScriptArguments, ScriptArguments,
SFTConfig, SFTConfig,
SFTTrainer, SFTTrainer,
TrlParser, TrlParser,
clone_chat_template, clone_chat_template,
get_dataset,
get_kbit_device_map, get_kbit_device_map,
get_peft_config, get_peft_config,
get_quantization_config, get_quantization_config,
) )
def main(script_args, training_args, model_args): def main(script_args, training_args, model_args, dataset_args):
################ ################
# Model init kwargs & Tokenizer # Model init kwargs & Tokenizer
################ ################
@ -116,14 +119,22 @@ def main(script_args, training_args, model_args):
# TODO: source should be passed as an argument # TODO: source should be passed as an argument
model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B") model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
################ # Load the dataset
# Dataset if dataset_args.datasets and script_args.dataset_name:
################ warnings.warn(
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the "
"dataset and `dataset_name` will be ignored."
)
elif dataset_args.datasets and not script_args.dataset_name:
dataset = get_dataset(dataset_args)
elif not dataset_args.datasets and script_args.dataset_name:
dataset = load_dataset(
script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming
)
else:
raise ValueError("Either `datasets` or `dataset_name` must be provided.")
################ # Initialize the SFT trainer
# Training
################
trainer = SFTTrainer( trainer = SFTTrainer(
model=model, model=model,
args=training_args, args=training_args,
@ -133,16 +144,17 @@ def main(script_args, training_args, model_args):
peft_config=get_peft_config(model_args), peft_config=get_peft_config(model_args),
) )
# Train the model
trainer.train() trainer.train()
# Save and push to hub # Save and push to Hub
trainer.save_model(training_args.output_dir) trainer.save_model(training_args.output_dir)
if training_args.push_to_hub: if training_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name) trainer.push_to_hub(dataset_name=script_args.dataset_name)
def make_parser(subparsers: argparse._SubParsersAction = None): def make_parser(subparsers: argparse._SubParsersAction = None):
dataclass_types = (ScriptArguments, SFTConfig, ModelConfig) dataclass_types = (ScriptArguments, SFTConfig, ModelConfig, DatasetMixtureConfig)
if subparsers is not None: if subparsers is not None:
parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types) parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types)
else: else:
@ -155,5 +167,7 @@ if __name__ == "__main__":
# When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments. # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments.
# To ensure that their parsing does not interfere with the script arguments, parse the arguments with # To ensure that their parsing does not interfere with the script arguments, parse the arguments with
# `return_remaining_strings=True`, then ignore the remaining strings. # `return_remaining_strings=True`, then ignore the remaining strings.
script_args, training_args, model_args, _ = parser.parse_args_and_config(return_remaining_strings=True) script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config(
main(script_args, training_args, model_args) return_remaining_strings=True
)
main(script_args, training_args, model_args, dataset_args)

View File

@ -23,7 +23,9 @@ from collections.abc import Iterable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, Union from typing import Optional, Union
import datasets
import yaml import yaml
from datasets import DatasetDict, concatenate_datasets
from transformers import HfArgumentParser from transformers import HfArgumentParser
from transformers.hf_argparser import DataClass, DataClassType from transformers.hf_argparser import DataClass, DataClassType
from transformers.utils import is_rich_available from transformers.utils import is_rich_available
@ -32,22 +34,121 @@ from transformers.utils import is_rich_available
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass
class DatasetConfig:
    """
    Configuration describing a single dataset to load.

    The fields mirror the signature of [`~datasets.load_dataset`] and are forwarded to it
    verbatim; refer to that function's documentation for the full semantics of each argument.

    Parameters:
        path (`str`):
            Path or name of the dataset.
        name (`str`, *optional*, defaults to `None`):
            Name of the dataset configuration to load.
        data_dir (`str`, *optional*, defaults to `None`):
            `data_dir` of the dataset configuration. When given for a generic builder (csv, text,
            etc.) or a Hub dataset with `data_files` left as `None`, it behaves like passing
            `os.path.join(data_dir, **)` as `data_files`, referencing every file in the directory.
        data_files (`str` or `Sequence` or `Mapping`, *optional*, defaults to `None`):
            Path(s) to the source data file(s).
        split (`str`, *optional*, defaults to `"train"`):
            Which split of the data to load.
        columns (`list[str]`, *optional*, defaults to `None`):
            Column names to keep from the dataset. `None` keeps every column.
    """

    path: str
    name: Optional[str] = None
    data_dir: Optional[str] = None
    data_files: Optional[Union[str, list[str], dict[str, str]]] = None
    split: str = "train"
    columns: Optional[list[str]] = None
@dataclass
class DatasetMixtureConfig:
    """
    Configuration for building one training dataset out of several source datasets.

    Through [`~transformers.HfArgumentParser`], the fields of this class become
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) options that can be
    set on the command line.

    Parameters:
        datasets (`list[DatasetConfig]`):
            Dataset configurations making up the mixture.
        streaming (`bool`, *optional*, defaults to `False`):
            Load the datasets in streaming mode when `True`.
        test_split_size (`float` or `None`, *optional*, defaults to `None`):
            Size of the held-out test split, interpreted like the `test_size` parameter of
            [`~datasets.train_test_split`]. With `None`, no train/test split is performed.

    Usage:
        With the CLI, the mixture can be declared in a YAML config file:

        ```yaml
        datasets:
          - path: ...
            name: ...
            data_dir: ...
            data_files: ...
            split: ...
            columns: ...
          - path: ...
            name: ...
            data_dir: ...
            data_files: ...
            split: ...
            columns: ...
        streaming: ...
        test_split_size: ...
        ```
    """

    datasets: list[DatasetConfig] = field(
        default_factory=list,
        metadata={"help": "List of dataset configurations to include in the mixture."},
    )
    streaming: bool = field(
        default=False,
        metadata={"help": "Whether to stream the datasets. If True, the datasets will be loaded in streaming mode."},
    )
    test_split_size: Optional[float] = field(
        default=None,
        metadata={
            "help": "Size of the test split. Refer to the `test_size` parameter in the `datasets.train_test_split` "
            "function for more details. If None, the dataset will not be split into train and test sets."
        },
    )

    def __post_init__(self):
        # Entries coming from CLI/config parsing arrive as plain dicts; normalize them to
        # DatasetConfig instances so downstream code can rely on attribute access.
        self.datasets = [
            DatasetConfig(**entry) if isinstance(entry, dict) else entry for entry in self.datasets
        ]
@dataclass @dataclass
class ScriptArguments: class ScriptArguments:
""" """
Arguments common to all scripts. Arguments common to all scripts.
Args: Args:
dataset_name (`str`): dataset_name (`str`, or `None`, *optional*, defaults to `None`):
Dataset name. Path or name of the dataset to load. If `datasets` is provided, this will be ignored.
dataset_config (`str` or `None`, *optional*, defaults to `None`): dataset_config (`str` or `None`, *optional*, defaults to `None`):
Dataset configuration name. Corresponds to the `name` argument of the [`~datasets.load_dataset`] function. Dataset configuration name. Corresponds to the `name` argument of the [`~datasets.load_dataset`] function.
If `datasets` is provided, this will be ignored.
dataset_train_split (`str`, *optional*, defaults to `"train"`): dataset_train_split (`str`, *optional*, defaults to `"train"`):
Dataset split to use for training. Dataset split to use for training. If `datasets` is provided, this will be ignored.
dataset_test_split (`str`, *optional*, defaults to `"test"`): dataset_test_split (`str`, *optional*, defaults to `"test"`):
Dataset split to use for evaluation. Dataset split to use for evaluation. If `datasets` is provided, this will be ignored.
dataset_streaming (`bool`, *optional*, defaults to `False`): dataset_streaming (`bool`, *optional*, defaults to `False`):
Whether to stream the dataset. If True, the dataset will be loaded in streaming mode. Whether to stream the dataset. If True, the dataset will be loaded in streaming mode. If `datasets` is
provided, this will be ignored.
gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `False`): gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `False`):
Whether to apply `use_reentrant` for gradient checkpointing. Whether to apply `use_reentrant` for gradient checkpointing.
ignore_bias_buffers (`bool`, *optional*, defaults to `False`): ignore_bias_buffers (`bool`, *optional*, defaults to `False`):
@ -56,19 +157,31 @@ class ScriptArguments:
https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992. https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992.
""" """
dataset_name: Optional[str] = field(default=None, metadata={"help": "Dataset name."}) dataset_name: Optional[str] = field(
default=None,
metadata={"help": "Path or name of the dataset to load. If `datasets` is provided, this will be ignored."},
)
dataset_config: Optional[str] = field( dataset_config: Optional[str] = field(
default=None, default=None,
metadata={ metadata={
"help": "Dataset configuration name. Corresponds to the `name` argument of the `datasets.load_dataset` " "help": "Dataset configuration name. Corresponds to the `name` argument of the `datasets.load_dataset` "
"function." "function. If `datasets` is provided, this will be ignored."
}, },
) )
dataset_train_split: str = field(default="train", metadata={"help": "Dataset split to use for training."}) dataset_train_split: str = field(
dataset_test_split: str = field(default="test", metadata={"help": "Dataset split to use for evaluation."}) default="train",
metadata={"help": "Dataset split to use for training. If `datasets` is provided, this will be ignored."},
)
dataset_test_split: str = field(
default="test",
metadata={"help": "Dataset split to use for evaluation. If `datasets` is provided, this will be ignored."},
)
dataset_streaming: bool = field( dataset_streaming: bool = field(
default=False, default=False,
metadata={"help": "Whether to stream the dataset. If True, the dataset will be loaded in streaming mode."}, metadata={
"help": "Whether to stream the dataset. If True, the dataset will be loaded in streaming mode. If "
"`datasets` is provided, this will be ignored."
},
) )
gradient_checkpointing_use_reentrant: bool = field( gradient_checkpointing_use_reentrant: bool = field(
default=False, default=False,
@ -282,3 +395,66 @@ def get_git_commit_hash(package_name):
return None return None
except Exception as e: except Exception as e:
return f"Error: {str(e)}" return f"Error: {str(e)}"
def get_dataset(mixture_config: DatasetMixtureConfig) -> DatasetDict:
    """
    Load a mixture of datasets based on the configuration.

    Args:
        mixture_config (`DatasetMixtureConfig`):
            Script arguments containing dataset configuration.

    Returns:
        `DatasetDict`:
            Combined dataset(s) from the mixture configuration, with optional train/test split if `test_split_size` is
            set.

    Raises:
        ValueError: If the mixture configuration contains no datasets, or if `test_split_size` is set while streaming
            is enabled (streaming datasets cannot be split).

    Example:
    ```python
    from trl import DatasetMixtureConfig, get_dataset
    from trl.scripts.utils import DatasetConfig

    mixture_config = DatasetMixtureConfig(datasets=[DatasetConfig(path="trl-lib/tldr")])
    dataset = get_dataset(mixture_config)
    print(dataset)
    ```

    ```
    DatasetDict({
        train: Dataset({
            features: ['prompt', 'completion'],
            num_rows: 116722
        })
    })
    ```
    """
    # Fail fast on an empty mixture instead of discovering it after the (empty) loading loop.
    if not mixture_config.datasets:
        raise ValueError("No datasets were loaded from the mixture configuration")

    logger.info(f"Creating dataset mixture with {len(mixture_config.datasets)} datasets")
    datasets_list = []
    for dataset_config in mixture_config.datasets:
        logger.info(f"Loading dataset for mixture: {dataset_config.path} (config name: {dataset_config.name})")
        dataset = datasets.load_dataset(
            path=dataset_config.path,
            name=dataset_config.name,
            data_dir=dataset_config.data_dir,
            data_files=dataset_config.data_files,
            split=dataset_config.split,
            streaming=mixture_config.streaming,
        )
        # Optionally restrict each dataset to a common subset of columns before concatenation.
        if dataset_config.columns is not None:
            dataset = dataset.select_columns(dataset_config.columns)
        datasets_list.append(dataset)

    combined_dataset = concatenate_datasets(datasets_list)
    if isinstance(combined_dataset, datasets.Dataset):  # IterableDataset does not have a length
        logger.info(f"Created dataset mixture with {len(combined_dataset)} examples")

    if mixture_config.test_split_size is not None:
        if not isinstance(combined_dataset, datasets.Dataset):
            # `IterableDataset` (streaming mode) has no `train_test_split`; raise a clear error instead of
            # letting the call below fail with an opaque AttributeError.
            raise ValueError(
                "`test_split_size` is not supported when `streaming` is enabled; load the datasets in "
                "non-streaming mode to split them into train and test sets."
            )
        logger.info(f"Splitting dataset into train and test sets with test size: {mixture_config.test_split_size}")
        return combined_dataset.train_test_split(test_size=mixture_config.test_split_size)
    return DatasetDict({"train": combined_dataset})