mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 10:03:51 +08:00
🎚️ Add dataset mixer (#3791)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
This commit is contained in:
@ -219,6 +219,49 @@ trl dpo --config dpo_config.yaml
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Using dataset mixtures
|
||||
|
||||
You can use dataset mixtures to combine multiple datasets into a single training dataset. This is useful for training on diverse data sources or when you want to mix different types of data.
|
||||
|
||||
<hfoptions id="accelerate_config">
|
||||
<hfoption id="SFT">
|
||||
|
||||
```yaml
|
||||
# sft_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
datasets:
|
||||
- path: stanfordnlp/imdb
|
||||
- path: roneneldan/TinyStories
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl sft --config sft_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="DPO">
|
||||
|
||||
```yaml
|
||||
# dpo_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
datasets:
|
||||
- path: BAAI/Infinity-Preference
|
||||
- path: argilla/Capybara-Preferences
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl dpo --config dpo_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
To see all the available keywords for defining dataset mixtures, refer to the [`scripts.utils.DatasetConfig`] and [`DatasetMixtureConfig`] classes.
|
||||
|
||||
## Getting the System Information
|
||||
|
||||
You can get the system information by running the following command:
|
||||
|
@ -10,3 +10,15 @@
|
||||
- parse_args_and_config
|
||||
- parse_args_into_dataclasses
|
||||
- set_defaults_with_config
|
||||
|
||||
## get_dataset
|
||||
|
||||
[[autodoc]] get_dataset
|
||||
|
||||
## DatasetConfig
|
||||
|
||||
[[autodoc]] scripts.utils.DatasetConfig
|
||||
|
||||
## DatasetMixtureConfig
|
||||
|
||||
[[autodoc]] DatasetMixtureConfig
|
||||
|
@ -12,11 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import mock_open, patch
|
||||
|
||||
from trl import TrlParser
|
||||
from datasets import DatasetDict, load_dataset
|
||||
|
||||
from trl import DatasetMixtureConfig, TrlParser, get_dataset
|
||||
from trl.scripts.utils import DatasetConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -262,3 +266,162 @@ class TestTrlParser(unittest.TestCase):
|
||||
# Check that config values were applied to the subparser
|
||||
self.assertEqual(result_args[0].arg1, 2) # Default from config
|
||||
self.assertEqual(result_args[0].arg2, "config_value") # Default from config
|
||||
|
||||
|
||||
class TestGetDataset(unittest.TestCase):
|
||||
def test_single_dataset_with_config(self):
|
||||
mixture_config = DatasetMixtureConfig(
|
||||
datasets=[DatasetConfig(path="trl-internal-testing/zen", name="standard_language_modeling")]
|
||||
)
|
||||
result = get_dataset(mixture_config)
|
||||
expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling")
|
||||
self.assertEqual(expected["train"][:], result["train"][:])
|
||||
|
||||
def test_single_dataset_preference_config(self):
|
||||
mixture_config = DatasetMixtureConfig(
|
||||
datasets=[DatasetConfig(path="trl-internal-testing/zen", name="standard_preference")]
|
||||
)
|
||||
result = get_dataset(mixture_config)
|
||||
expected = load_dataset("trl-internal-testing/zen", "standard_preference")
|
||||
self.assertEqual(expected["train"][:], result["train"][:])
|
||||
|
||||
def test_single_dataset_streaming(self):
|
||||
mixture_config = DatasetMixtureConfig(
|
||||
datasets=[DatasetConfig(path="trl-internal-testing/zen", name="standard_language_modeling")],
|
||||
streaming=True,
|
||||
)
|
||||
result = get_dataset(mixture_config)
|
||||
expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling")
|
||||
self.assertEqual(expected["train"].to_list(), list(result["train"]))
|
||||
|
||||
def test_dataset_mixture_basic(self):
|
||||
dataset_config1 = DatasetConfig(
|
||||
path="trl-internal-testing/zen", name="standard_prompt_completion", split="train", columns=["prompt"]
|
||||
)
|
||||
dataset_config2 = DatasetConfig(
|
||||
path="trl-internal-testing/zen", name="standard_preference", split="train", columns=["prompt"]
|
||||
)
|
||||
mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2])
|
||||
result = get_dataset(mixture_config)
|
||||
self.assertIsInstance(result, DatasetDict)
|
||||
self.assertIn("train", result)
|
||||
train_dataset = result["train"]
|
||||
self.assertEqual(train_dataset.column_names, ["prompt"])
|
||||
prompts = train_dataset["prompt"]
|
||||
expected_first_half = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
|
||||
self.assertEqual(prompts[: len(prompts) // 2], expected_first_half["prompt"])
|
||||
expected_second_half = load_dataset("trl-internal-testing/zen", "standard_prompt_completion", split="train")
|
||||
self.assertEqual(prompts[len(prompts) // 2 :], expected_second_half["prompt"])
|
||||
|
||||
def test_dataset_mixture_with_weights(self):
|
||||
dataset_config1 = DatasetConfig(
|
||||
path="trl-internal-testing/zen", name="standard_prompt_completion", split="train[:50%]", columns=["prompt"]
|
||||
)
|
||||
dataset_config2 = DatasetConfig(
|
||||
path="trl-internal-testing/zen", name="standard_preference", split="train[:50%]", columns=["prompt"]
|
||||
)
|
||||
mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2])
|
||||
result = get_dataset(mixture_config)
|
||||
self.assertIsInstance(result, DatasetDict)
|
||||
self.assertIn("train", result)
|
||||
train_dataset = result["train"]
|
||||
self.assertEqual(train_dataset.column_names, ["prompt"])
|
||||
prompts = train_dataset["prompt"]
|
||||
expected_first_half = load_dataset("trl-internal-testing/zen", "standard_preference", split="train[:50%]")
|
||||
self.assertEqual(prompts[: len(prompts) // 2], expected_first_half["prompt"])
|
||||
expected_second_half = load_dataset(
|
||||
"trl-internal-testing/zen", "standard_prompt_completion", split="train[:50%]"
|
||||
)
|
||||
self.assertEqual(prompts[len(prompts) // 2 :], expected_second_half["prompt"])
|
||||
|
||||
def test_dataset_mixture_with_test_split(self):
|
||||
mixture_config = DatasetMixtureConfig(
|
||||
datasets=[DatasetConfig(path="trl-internal-testing/zen", name="standard_language_modeling")],
|
||||
test_split_size=2,
|
||||
)
|
||||
result = get_dataset(mixture_config)
|
||||
self.assertIsInstance(result, DatasetDict)
|
||||
self.assertIn("train", result)
|
||||
self.assertIn("test", result)
|
||||
self.assertEqual(len(result["train"]), 15)
|
||||
self.assertEqual(len(result["test"]), 2)
|
||||
|
||||
def test_empty_dataset_mixture_raises_error(self):
|
||||
mixture_config = DatasetMixtureConfig(datasets=[])
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
get_dataset(mixture_config)
|
||||
|
||||
self.assertIn("No datasets were loaded", str(context.exception))
|
||||
|
||||
def test_mixture_multiple_different_configs(self):
|
||||
dataset_config1 = DatasetConfig(
|
||||
path="trl-internal-testing/zen", name="conversational_preference", split="train", columns=["prompt"]
|
||||
)
|
||||
dataset_config2 = DatasetConfig(
|
||||
path="trl-internal-testing/zen", name="conversational_prompt_only", split="test"
|
||||
)
|
||||
mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2])
|
||||
result = get_dataset(mixture_config)
|
||||
self.assertIsInstance(result, DatasetDict)
|
||||
self.assertIn("train", result)
|
||||
self.assertGreater(len(result["train"]), 0)
|
||||
|
||||
def test_trlparser_parses_yaml_config_correctly(self):
|
||||
# Prepare YAML content exactly like your example
|
||||
yaml_content = """
|
||||
datasets:
|
||||
- path: trl-internal-testing/zen
|
||||
name: standard_prompt_only
|
||||
- path: trl-internal-testing/zen
|
||||
name: standard_preference
|
||||
columns:
|
||||
- prompt
|
||||
"""
|
||||
|
||||
# Write YAML to a temporary file
|
||||
with tempfile.NamedTemporaryFile("w+", suffix=".yaml") as tmpfile:
|
||||
tmpfile.write(yaml_content)
|
||||
tmpfile.flush()
|
||||
parser = TrlParser((DatasetMixtureConfig,))
|
||||
args = parser.parse_args_and_config(args=["--config", tmpfile.name])[0]
|
||||
|
||||
# Assert that we got DatasetMixtureConfig instance
|
||||
self.assertIsInstance(args, DatasetMixtureConfig)
|
||||
|
||||
# Assert datasets list length
|
||||
self.assertEqual(len(args.datasets), 2)
|
||||
|
||||
# Check first dataset
|
||||
dataset_config1 = args.datasets[0]
|
||||
self.assertIsInstance(dataset_config1, DatasetConfig)
|
||||
self.assertEqual(dataset_config1.path, "trl-internal-testing/zen")
|
||||
self.assertEqual(dataset_config1.name, "standard_prompt_only")
|
||||
self.assertIsNone(dataset_config1.columns) # No columns specified
|
||||
|
||||
# Check second dataset
|
||||
dataset_config2 = args.datasets[1]
|
||||
self.assertIsInstance(dataset_config2, DatasetConfig)
|
||||
self.assertEqual(dataset_config2.path, "trl-internal-testing/zen")
|
||||
self.assertEqual(dataset_config2.name, "standard_preference")
|
||||
self.assertEqual(dataset_config2.columns, ["prompt"]) # Columns specified
|
||||
|
||||
def test_trlparser_parses_yaml_and_loads_dataset(self):
|
||||
# Prepare YAML content exactly like your example
|
||||
yaml_content = """
|
||||
datasets:
|
||||
- path: trl-internal-testing/zen
|
||||
name: standard_language_modeling
|
||||
"""
|
||||
|
||||
# Write YAML to a temporary file
|
||||
with tempfile.NamedTemporaryFile("w+", suffix=".yaml") as tmpfile:
|
||||
tmpfile.write(yaml_content)
|
||||
tmpfile.flush()
|
||||
parser = TrlParser((DatasetMixtureConfig,))
|
||||
args = parser.parse_args_and_config(args=["--config", tmpfile.name])[0]
|
||||
|
||||
# Load the dataset using get_dataset
|
||||
result = get_dataset(args)
|
||||
expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling")
|
||||
self.assertEqual(expected["train"][:], result["train"][:])
|
||||
|
@ -20,7 +20,7 @@ from .import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffus
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"scripts": ["init_zero_verbose", "ScriptArguments", "TrlParser"],
|
||||
"scripts": ["DatasetMixtureConfig", "ScriptArguments", "TrlParser", "get_dataset", "init_zero_verbose"],
|
||||
"data_utils": [
|
||||
"apply_chat_template",
|
||||
"extract_prompt",
|
||||
@ -136,7 +136,7 @@ if TYPE_CHECKING:
|
||||
create_reference_model,
|
||||
setup_chat_format,
|
||||
)
|
||||
from .scripts import ScriptArguments, TrlParser, init_zero_verbose
|
||||
from .scripts import DatasetMixtureConfig, ScriptArguments, TrlParser, get_dataset, init_zero_verbose
|
||||
from .trainer import (
|
||||
AlignPropConfig,
|
||||
AlignPropTrainer,
|
||||
|
@ -18,11 +18,11 @@ from ..import_utils import _LazyModule
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"utils": ["init_zero_verbose", "ScriptArguments", "TrlParser"],
|
||||
"utils": ["DatasetMixtureConfig", "ScriptArguments", "TrlParser", "get_dataset", "init_zero_verbose"],
|
||||
}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .utils import ScriptArguments, TrlParser, init_zero_verbose
|
||||
from .utils import DatasetMixtureConfig, ScriptArguments, TrlParser, get_dataset, init_zero_verbose
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
@ -61,17 +61,20 @@ python trl/scripts/dpo.py \
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from trl import (
|
||||
DatasetMixtureConfig,
|
||||
DPOConfig,
|
||||
DPOTrainer,
|
||||
ModelConfig,
|
||||
ScriptArguments,
|
||||
TrlParser,
|
||||
get_dataset,
|
||||
get_kbit_device_map,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
@ -79,7 +82,7 @@ from trl import (
|
||||
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
|
||||
|
||||
|
||||
def main(script_args, training_args, model_args):
|
||||
def main(script_args, training_args, model_args, dataset_args):
|
||||
################
|
||||
# Model & Tokenizer
|
||||
###################
|
||||
@ -118,18 +121,22 @@ def main(script_args, training_args, model_args):
|
||||
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
|
||||
]
|
||||
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
dataset = load_dataset(
|
||||
script_args.dataset_name,
|
||||
name=script_args.dataset_config,
|
||||
streaming=script_args.dataset_streaming,
|
||||
)
|
||||
# Load the dataset
|
||||
if dataset_args.datasets and script_args.dataset_name:
|
||||
warnings.warn(
|
||||
"Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the "
|
||||
"dataset and `dataset_name` will be ignored."
|
||||
)
|
||||
elif dataset_args.datasets and not script_args.dataset_name:
|
||||
dataset = get_dataset(dataset_args)
|
||||
elif not dataset_args.datasets and script_args.dataset_name:
|
||||
dataset = load_dataset(
|
||||
script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either `datasets` or `dataset_name` must be provided.")
|
||||
|
||||
##########
|
||||
# Training
|
||||
################
|
||||
# Initialize the DPO trainer
|
||||
trainer = DPOTrainer(
|
||||
model,
|
||||
ref_model,
|
||||
@ -140,6 +147,7 @@ def main(script_args, training_args, model_args):
|
||||
peft_config=peft_config,
|
||||
)
|
||||
|
||||
# Train the model
|
||||
trainer.train()
|
||||
|
||||
if training_args.eval_strategy != "no":
|
||||
@ -147,14 +155,14 @@ def main(script_args, training_args, model_args):
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
# Save and push to hub
|
||||
# Save and push to Hub
|
||||
trainer.save_model(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
||||
|
||||
|
||||
def make_parser(subparsers: argparse._SubParsersAction = None):
|
||||
dataclass_types = (ScriptArguments, DPOConfig, ModelConfig)
|
||||
dataclass_types = (ScriptArguments, DPOConfig, ModelConfig, DatasetMixtureConfig)
|
||||
if subparsers is not None:
|
||||
parser = subparsers.add_parser("dpo", help="Run the DPO training script", dataclass_types=dataclass_types)
|
||||
else:
|
||||
@ -164,5 +172,10 @@ def make_parser(subparsers: argparse._SubParsersAction = None):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = make_parser()
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
main(script_args, training_args, model_args)
|
||||
# When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments.
|
||||
# To ensure that their parsing does not interfere with the script arguments, parse the arguments with
|
||||
# `return_remaining_strings=True`, then ignore the remaining strings.
|
||||
script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
main(script_args, training_args, model_args, dataset_args)
|
||||
|
@ -23,13 +23,22 @@ import argparse
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
|
||||
from trl import (
|
||||
DatasetMixtureConfig,
|
||||
GRPOConfig,
|
||||
GRPOTrainer,
|
||||
ModelConfig,
|
||||
ScriptArguments,
|
||||
TrlParser,
|
||||
get_dataset,
|
||||
get_peft_config,
|
||||
)
|
||||
from trl.rewards import think_format_reward
|
||||
|
||||
|
||||
@ -68,22 +77,11 @@ class GRPOScriptArguments(ScriptArguments):
|
||||
)
|
||||
|
||||
|
||||
def main(script_args, training_args, model_args):
|
||||
# Load a pretrained model
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
||||
)
|
||||
|
||||
def main(script_args, training_args, model_args, dataset_args):
|
||||
# Get the reward models and functions
|
||||
reward_funcs = []
|
||||
if script_args.reward_model_name_or_path:
|
||||
reward_model = AutoModelForSequenceClassification.from_pretrained(
|
||||
script_args.reward_model_name_or_path, trust_remote_code=model_args.trust_remote_code, num_labels=1
|
||||
)
|
||||
reward_funcs.append(reward_model)
|
||||
reward_funcs.append(script_args.reward_model_name_or_path)
|
||||
|
||||
if script_args.reward_funcs:
|
||||
for func_name in script_args.reward_funcs:
|
||||
@ -102,30 +100,41 @@ def main(script_args, training_args, model_args):
|
||||
)
|
||||
|
||||
# Load the dataset
|
||||
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
||||
if dataset_args.datasets and script_args.dataset_name:
|
||||
warnings.warn(
|
||||
"Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the "
|
||||
"dataset and `dataset_name` will be ignored."
|
||||
)
|
||||
elif dataset_args.datasets and not script_args.dataset_name:
|
||||
dataset = get_dataset(dataset_args)
|
||||
elif not dataset_args.datasets and script_args.dataset_name:
|
||||
dataset = load_dataset(
|
||||
script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either `datasets` or `dataset_name` must be provided.")
|
||||
|
||||
# Initialize the GRPO trainer
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
model=model_args.model_name_or_path,
|
||||
reward_funcs=reward_funcs,
|
||||
args=training_args,
|
||||
train_dataset=dataset[script_args.dataset_train_split],
|
||||
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
|
||||
processing_class=tokenizer,
|
||||
peft_config=get_peft_config(model_args),
|
||||
)
|
||||
|
||||
# Train and push the model to the Hub
|
||||
# Train the model
|
||||
trainer.train()
|
||||
|
||||
# Save and push to hub
|
||||
# Save and push to Hub
|
||||
trainer.save_model(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
||||
|
||||
|
||||
def make_parser(subparsers: argparse._SubParsersAction = None):
|
||||
dataclass_types = (GRPOScriptArguments, GRPOConfig, ModelConfig)
|
||||
dataclass_types = (GRPOScriptArguments, GRPOConfig, ModelConfig, DatasetMixtureConfig)
|
||||
if subparsers is not None:
|
||||
parser = subparsers.add_parser("grpo", help="Run the GRPO training script", dataclass_types=dataclass_types)
|
||||
else:
|
||||
@ -135,5 +144,10 @@ def make_parser(subparsers: argparse._SubParsersAction = None):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = make_parser()
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
main(script_args, training_args, model_args)
|
||||
# When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments.
|
||||
# To ensure that their parsing does not interfere with the script arguments, parse the arguments with
|
||||
# `return_remaining_strings=True`, then ignore the remaining strings.
|
||||
script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
main(script_args, training_args, model_args, dataset_args)
|
||||
|
@ -65,22 +65,25 @@ python trl/scripts/kto.py \
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import warnings
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from trl import (
|
||||
DatasetMixtureConfig,
|
||||
KTOConfig,
|
||||
KTOTrainer,
|
||||
ModelConfig,
|
||||
ScriptArguments,
|
||||
TrlParser,
|
||||
get_dataset,
|
||||
get_peft_config,
|
||||
setup_chat_format,
|
||||
)
|
||||
|
||||
|
||||
def main(script_args, training_args, model_args):
|
||||
def main(script_args, training_args, model_args, dataset_args):
|
||||
# Load a pretrained model
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
||||
@ -100,7 +103,19 @@ def main(script_args, training_args, model_args):
|
||||
model, tokenizer = setup_chat_format(model, tokenizer)
|
||||
|
||||
# Load the dataset
|
||||
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
||||
if dataset_args.datasets and script_args.dataset_name:
|
||||
warnings.warn(
|
||||
"Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the "
|
||||
"dataset and `dataset_name` will be ignored."
|
||||
)
|
||||
elif dataset_args.datasets and not script_args.dataset_name:
|
||||
dataset = get_dataset(dataset_args)
|
||||
elif not dataset_args.datasets and script_args.dataset_name:
|
||||
dataset = load_dataset(
|
||||
script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either `datasets` or `dataset_name` must be provided.")
|
||||
|
||||
# Initialize the KTO trainer
|
||||
trainer = KTOTrainer(
|
||||
@ -113,17 +128,17 @@ def main(script_args, training_args, model_args):
|
||||
peft_config=get_peft_config(model_args),
|
||||
)
|
||||
|
||||
# Train and push the model to the Hub
|
||||
# Train the model
|
||||
trainer.train()
|
||||
|
||||
# Save and push to hub
|
||||
# Save and push to Hub
|
||||
trainer.save_model(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
||||
|
||||
|
||||
def make_parser(subparsers: argparse._SubParsersAction = None):
|
||||
dataclass_types = (ScriptArguments, KTOConfig, ModelConfig)
|
||||
dataclass_types = (ScriptArguments, KTOConfig, ModelConfig, DatasetMixtureConfig)
|
||||
if subparsers is not None:
|
||||
parser = subparsers.add_parser("kto", help="Run the KTO training script", dataclass_types=dataclass_types)
|
||||
else:
|
||||
@ -133,5 +148,10 @@ def make_parser(subparsers: argparse._SubParsersAction = None):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = make_parser()
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
main(script_args, training_args, model_args)
|
||||
# When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments.
|
||||
# To ensure that their parsing does not interfere with the script arguments, parse the arguments with
|
||||
# `return_remaining_strings=True`, then ignore the remaining strings.
|
||||
script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
main(script_args, training_args, model_args, dataset_args)
|
||||
|
@ -61,25 +61,28 @@ python trl/scripts/sft.py \
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import warnings
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
|
||||
|
||||
from trl import (
|
||||
DatasetMixtureConfig,
|
||||
ModelConfig,
|
||||
ScriptArguments,
|
||||
SFTConfig,
|
||||
SFTTrainer,
|
||||
TrlParser,
|
||||
clone_chat_template,
|
||||
get_dataset,
|
||||
get_kbit_device_map,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
|
||||
|
||||
def main(script_args, training_args, model_args):
|
||||
def main(script_args, training_args, model_args, dataset_args):
|
||||
################
|
||||
# Model init kwargs & Tokenizer
|
||||
################
|
||||
@ -116,14 +119,22 @@ def main(script_args, training_args, model_args):
|
||||
# TODO: source should be passed as an argument
|
||||
model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
|
||||
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
||||
# Load the dataset
|
||||
if dataset_args.datasets and script_args.dataset_name:
|
||||
warnings.warn(
|
||||
"Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the "
|
||||
"dataset and `dataset_name` will be ignored."
|
||||
)
|
||||
elif dataset_args.datasets and not script_args.dataset_name:
|
||||
dataset = get_dataset(dataset_args)
|
||||
elif not dataset_args.datasets and script_args.dataset_name:
|
||||
dataset = load_dataset(
|
||||
script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either `datasets` or `dataset_name` must be provided.")
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
# Initialize the SFT trainer
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
@ -133,16 +144,17 @@ def main(script_args, training_args, model_args):
|
||||
peft_config=get_peft_config(model_args),
|
||||
)
|
||||
|
||||
# Train the model
|
||||
trainer.train()
|
||||
|
||||
# Save and push to hub
|
||||
# Save and push to Hub
|
||||
trainer.save_model(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
||||
|
||||
|
||||
def make_parser(subparsers: argparse._SubParsersAction = None):
|
||||
dataclass_types = (ScriptArguments, SFTConfig, ModelConfig)
|
||||
dataclass_types = (ScriptArguments, SFTConfig, ModelConfig, DatasetMixtureConfig)
|
||||
if subparsers is not None:
|
||||
parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types)
|
||||
else:
|
||||
@ -155,5 +167,7 @@ if __name__ == "__main__":
|
||||
# When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments.
|
||||
# To ensure that their parsing does not interfere with the script arguments, parse the arguments with
|
||||
# `return_remaining_strings=True`, then ignore the remaining strings.
|
||||
script_args, training_args, model_args, _ = parser.parse_args_and_config(return_remaining_strings=True)
|
||||
main(script_args, training_args, model_args)
|
||||
script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
main(script_args, training_args, model_args, dataset_args)
|
||||
|
@ -23,7 +23,9 @@ from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Union
|
||||
|
||||
import datasets
|
||||
import yaml
|
||||
from datasets import DatasetDict, concatenate_datasets
|
||||
from transformers import HfArgumentParser
|
||||
from transformers.hf_argparser import DataClass, DataClassType
|
||||
from transformers.utils import is_rich_available
|
||||
@ -32,22 +34,121 @@ from transformers.utils import is_rich_available
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetConfig:
|
||||
"""
|
||||
Configuration for a dataset.
|
||||
|
||||
This class matches the signature of [`~datasets.load_dataset`] and the arguments are used directly in the
|
||||
`datasets.load_dataset` function. You can refer to the `datasets.load_dataset` documentation for more details.
|
||||
|
||||
Parameters:
|
||||
path (`str`):
|
||||
Path or name of the dataset.
|
||||
name (`str`, *optional*, defaults to `None`):
|
||||
Defining the name of the dataset configuration.
|
||||
data_dir (`str`, *optional*, defaults to `None`):
|
||||
Defining the `data_dir` of the dataset configuration. If specified for the generic builders(csv, text
|
||||
etc.) or the Hub datasets and `data_files` is `None`, the behavior is equal to passing
|
||||
`os.path.join(data_dir, **)` as `data_files` to reference all the files in a directory.
|
||||
data_files (`str` or `Sequence` or `Mapping`, *optional*, defaults to `None`):
|
||||
Path(s) to source data file(s).
|
||||
split (`str`, *optional*, defaults to `"train"`):
|
||||
Which split of the data to load.
|
||||
columns (`list[str]`, *optional*, defaults to `None`):
|
||||
List of column names to select from the dataset. If `None`, all columns are selected.
|
||||
"""
|
||||
|
||||
path: str
|
||||
name: Optional[str] = None
|
||||
data_dir: Optional[str] = None
|
||||
data_files: Optional[Union[str, list[str], dict[str, str]]] = None
|
||||
split: str = "train"
|
||||
columns: Optional[list[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetMixtureConfig:
|
||||
"""
|
||||
Configuration class for a mixture of datasets.
|
||||
|
||||
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
||||
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
||||
command line.
|
||||
|
||||
Parameters:
|
||||
datasets (`list[DatasetConfig]`):
|
||||
List of dataset configurations to include in the mixture.
|
||||
streaming (`bool`, *optional*, defaults to `False`):
|
||||
Whether to stream the datasets. If `True`, the datasets will be loaded in streaming mode.
|
||||
test_split_size (`float` or `None`, *optional*, defaults to `None`):
|
||||
Size of the test split. Refer to the `test_size` parameter in the [`~datasets.train_test_split`] function
|
||||
for more details. If `None`, the dataset will not be split into train and test sets.
|
||||
|
||||
Usage:
|
||||
When using the CLI, you can add the following section to your YAML config file:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: ...
|
||||
name: ...
|
||||
data_dir: ...
|
||||
data_files: ...
|
||||
split: ...
|
||||
columns: ...
|
||||
- path: ...
|
||||
name: ...
|
||||
data_dir: ...
|
||||
data_files: ...
|
||||
split: ...
|
||||
columns: ...
|
||||
streaming: ...
|
||||
test_split_size: ...
|
||||
```
|
||||
"""
|
||||
|
||||
datasets: list[DatasetConfig] = field(
|
||||
default_factory=list,
|
||||
metadata={"help": "List of dataset configurations to include in the mixture."},
|
||||
)
|
||||
streaming: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to stream the datasets. If True, the datasets will be loaded in streaming mode."},
|
||||
)
|
||||
test_split_size: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Size of the test split. Refer to the `test_size` parameter in the `datasets.train_test_split` "
|
||||
"function for more details. If None, the dataset will not be split into train and test sets."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Convert any dataset dicts (from CLI/config parsing) into DatasetConfig objects
|
||||
for idx, dataset in enumerate(self.datasets):
|
||||
if isinstance(dataset, dict):
|
||||
# If it's a dict, convert it to DatasetConfig
|
||||
self.datasets[idx] = DatasetConfig(**dataset)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
Arguments common to all scripts.
|
||||
|
||||
Args:
|
||||
dataset_name (`str` or `None`, *optional*, defaults to `None`):
    Path or name of the dataset to load. If `datasets` is provided, this will be ignored.
|
||||
dataset_config (`str` or `None`, *optional*, defaults to `None`):
|
||||
Dataset configuration name. Corresponds to the `name` argument of the [`~datasets.load_dataset`] function.
|
||||
If `datasets` is provided, this will be ignored.
|
||||
dataset_train_split (`str`, *optional*, defaults to `"train"`):
    Dataset split to use for training. If `datasets` is provided, this will be ignored.
|
||||
dataset_test_split (`str`, *optional*, defaults to `"test"`):
    Dataset split to use for evaluation. If `datasets` is provided, this will be ignored.
|
||||
dataset_streaming (`bool`, *optional*, defaults to `False`):
    Whether to stream the dataset. If True, the dataset will be loaded in streaming mode. If `datasets` is
    provided, this will be ignored.
|
||||
gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `False`):
|
||||
Whether to apply `use_reentrant` for gradient checkpointing.
|
||||
ignore_bias_buffers (`bool`, *optional*, defaults to `False`):
|
||||
Debug argument for distributed training. Fix for DDP issues with LM bias/mask buffers - invalid scalar
type, inplace operation. See
|
||||
https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992.
|
||||
"""
|
||||
|
||||
# CLI-facing dataset arguments shared by all TRL example scripts. Each field's
# metadata["help"] is the help text surfaced by the argument parser. Per the
# help strings, every one of these is ignored when a `datasets` mixture is
# supplied in the config.
#
# Fix: the diff merge had left duplicate definitions of `dataset_name`,
# `dataset_train_split`, and `dataset_test_split`, plus a repeated `metadata=`
# keyword in the `dataset_streaming` field() call (a SyntaxError). Only the
# current (post-mixture) versions are kept.
dataset_name: Optional[str] = field(
    default=None,
    metadata={"help": "Path or name of the dataset to load. If `datasets` is provided, this will be ignored."},
)
dataset_config: Optional[str] = field(
    default=None,
    metadata={
        "help": "Dataset configuration name. Corresponds to the `name` argument of the `datasets.load_dataset` "
        "function. If `datasets` is provided, this will be ignored."
    },
)
dataset_train_split: str = field(
    default="train",
    metadata={"help": "Dataset split to use for training. If `datasets` is provided, this will be ignored."},
)
dataset_test_split: str = field(
    default="test",
    metadata={"help": "Dataset split to use for evaluation. If `datasets` is provided, this will be ignored."},
)
dataset_streaming: bool = field(
    default=False,
    metadata={
        "help": "Whether to stream the dataset. If True, the dataset will be loaded in streaming mode. If "
        "`datasets` is provided, this will be ignored."
    },
)
|
||||
gradient_checkpointing_use_reentrant: bool = field(
|
||||
default=False,
|
||||
@ -282,3 +395,66 @@ def get_git_commit_hash(package_name):
|
||||
return None
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
|
||||
def get_dataset(mixture_config: DatasetMixtureConfig) -> DatasetDict:
    """
    Load a mixture of datasets based on the configuration.

    Args:
        mixture_config (`DatasetMixtureConfig`):
            Script arguments containing dataset configuration.

    Returns:
        `DatasetDict`:
            Combined dataset(s) from the mixture configuration, with optional train/test split if
            `test_split_size` is set.

    Raises:
        ValueError: If the mixture configuration contains no datasets.

    Example:
    ```python
    from trl import DatasetMixtureConfig, get_dataset
    from trl.scripts.utils import DatasetConfig

    mixture_config = DatasetMixtureConfig(datasets=[DatasetConfig(path="trl-lib/tldr")])
    dataset = get_dataset(mixture_config)
    print(dataset)
    ```

    ```
    DatasetDict({
        train: Dataset({
            features: ['prompt', 'completion'],
            num_rows: 116722
        })
    })
    ```
    """
    # Guard clause: an empty mixture cannot produce a dataset. Raising up front
    # (same message as before) avoids nesting the whole body in an if/else.
    if not mixture_config.datasets:
        raise ValueError("No datasets were loaded from the mixture configuration")

    logger.info(f"Creating dataset mixture with {len(mixture_config.datasets)} datasets")
    datasets_list = []
    for dataset_config in mixture_config.datasets:
        logger.info(f"Loading dataset for mixture: {dataset_config.path} (config name: {dataset_config.name})")
        dataset = datasets.load_dataset(
            path=dataset_config.path,
            name=dataset_config.name,
            data_dir=dataset_config.data_dir,
            data_files=dataset_config.data_files,
            split=dataset_config.split,
            streaming=mixture_config.streaming,
        )
        # Optionally restrict each dataset to a subset of its columns before mixing,
        # so the concatenated features line up across heterogeneous sources.
        if dataset_config.columns is not None:
            dataset = dataset.select_columns(dataset_config.columns)
        datasets_list.append(dataset)

    combined_dataset = concatenate_datasets(datasets_list)
    if isinstance(combined_dataset, datasets.Dataset):  # IterableDataset does not have a length
        logger.info(f"Created dataset mixture with {len(combined_dataset)} examples")

    if mixture_config.test_split_size is not None:
        # Fix: log message typo "Spliting" -> "Splitting".
        # NOTE(review): `train_test_split` is only available on map-style Datasets;
        # combining `streaming=True` with `test_split_size` will fail here — confirm intended.
        logger.info(f"Splitting dataset into train and test sets with test size: {mixture_config.test_split_size}")
        combined_dataset = combined_dataset.train_test_split(test_size=mixture_config.test_split_size)
        return combined_dataset
    return DatasetDict({"train": combined_dataset})
|
||||
|
Reference in New Issue
Block a user