trl/examples/scripts/rloo/rloo.py

import shutil

from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
)

from trl import ModelConfig
from trl.trainer.rloo_trainer import RLOOConfig, RLOOTrainer
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
"""
python -i examples/scripts/rloo/rloo.py \
--learning_rate 3e-6 \
--num_ppo_epochs 1 \
--num_mini_batches 1 \
--output_dir models/minimal/ppo \
--per_device_train_batch_size 64 \
--gradient_accumulation_steps 1 \
--total_episodes 10000 \
--model_name_or_path EleutherAI/pythia-1b-deduped \
--non_eos_penalty \
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
examples/scripts/rloo/rloo.py \
--output_dir models/minimal/rloo \
--rloo_k 2 \
--num_ppo_epochs 1 \
--num_mini_batches 1 \
--learning_rate 3e-6 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--total_episodes 10000 \
--model_name_or_path EleutherAI/pythia-1b-deduped \
--sft_model_path EleutherAI/pythia-1b-deduped \
--reward_model_path EleutherAI/pythia-1b-deduped \
--local_rollout_forward_batch_size 1 \
--deepspeed3 \
--non_eos_penalty \
"""
if __name__ == "__main__":
    parser = HfArgumentParser((RLOOConfig, ModelConfig))
    config, model_config = parser.parse_args_into_dataclasses()
    # remove output_dir if exists
    shutil.rmtree(config.output_dir, ignore_errors=True)

    ################
    # Model & Tokenizer
    ################
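    # Decoder-only models are padded on the left so that generation starts right
    # after the prompt; a [PAD] token is added because the base tokenizer may not
    # define one, and a minimal chat template is set as a fallback when the
    # tokenizer ships without one.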
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        padding_side="left",
        trust_remote_code=model_config.trust_remote_code,
    )
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
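    # Three models take part in RLOO training: a sequence-classification reward
    # model with a single scalar output, a reference policy that stays frozen and
    # anchors the KL penalty, and the policy being optimized. Both policies are
    # initialized from the SFT checkpoint.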
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
    )
    ref_policy = AutoModelForCausalLM.from_pretrained(
        config.sft_model_path, trust_remote_code=model_config.trust_remote_code
    )
    policy = AutoModelForCausalLM.from_pretrained(
        config.sft_model_path, trust_remote_code=model_config.trust_remote_code
    )

    ################
    # Dataset
    ################
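    # The "descriptiveness" split of a small TRL test dataset is used; the last
    # 20 rows are held out for evaluation and the rest are used for training.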
    raw_datasets = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness")
    eval_samples = 20
    train_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples))
    eval_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples, len(raw_datasets)))
    dataset_text_field = "prompt"
    def prepare_dataset(dataset, tokenizer):
        """pre-tokenize the dataset before training; only collate during training"""

        def tokenize(element):
            outputs = tokenizer(
                element[dataset_text_field],
                padding=False,
            )
            return {"input_ids": outputs["input_ids"]}

        return dataset.map(
            tokenize,
            batched=True,
            remove_columns=dataset.column_names,
            load_from_cache_file=False,
            num_proc=config.dataset_num_proc,
        )

    ################
    # Training
    ################
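    # The trainer receives the pre-tokenized prompts together with the three
    # models; after training, the final policy is saved (and optionally pushed to
    # the Hub) and sample completions are generated for a quick qualitative check.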
    trainer = RLOOTrainer(
        config=config,
        tokenizer=tokenizer,
        policy=policy,
        ref_policy=ref_policy,
        reward_model=reward_model,
        train_dataset=prepare_dataset(train_dataset, tokenizer),
        eval_dataset=prepare_dataset(eval_dataset, tokenizer),
    )
    trainer.train()
    trainer.save_model(config.output_dir)

    if config.push_to_hub:
        trainer.push_to_hub()

    trainer.generate_completions()