mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
🧭 HF jobs x TRL guide (#3890)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
This commit is contained in:
committed by
GitHub
parent
4b3517facc
commit
0c91515b58
@ -16,9 +16,11 @@
|
||||
# dependencies = [
|
||||
# "trl @ git+https://github.com/huggingface/trl.git",
|
||||
# "peft",
|
||||
# "trackio",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
@ -43,12 +45,16 @@ from trl import (
|
||||
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
|
||||
|
||||
|
||||
# Enable logging in a Hugging Face Space
|
||||
os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
|
||||
|
||||
|
||||
"""
|
||||
python examples/scripts/ppo/ppo_tldr.py \
|
||||
--dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
|
||||
--dataset_test_split validation \
|
||||
--learning_rate 3e-6 \
|
||||
--output_dir models/minimal/ppo_tldr \
|
||||
--output_dir pythia-1b-deduped-tldr-preference-sft-trl-style-ppo \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 64 \
|
||||
--total_episodes 30000 \
|
||||
@ -65,7 +71,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml
|
||||
examples/scripts/ppo/ppo_tldr.py \
|
||||
--dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
|
||||
--dataset_test_split validation \
|
||||
--output_dir models/minimal/ppo_tldr \
|
||||
--output_dir pythia-1b-deduped-tldr-preference-sft-trl-style-ppo \
|
||||
--learning_rate 3e-6 \
|
||||
--per_device_train_batch_size 16 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
@ -162,6 +168,7 @@ if __name__ == "__main__":
|
||||
eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=training_args.dataset_num_proc)
|
||||
|
||||
assert train_dataset[0]["input_ids"][-1] != tokenizer.eos_token_id, "The last token should not be an EOS token"
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
|
Reference in New Issue
Block a user