mirror of
https://github.com/huggingface/trl.git
synced 2025-10-21 11:33:51 +08:00
* Added reward summarization example
* readme change
* Apply suggestions from code review all comments
* addressed main pr comments
* added doc pages
* additional code comment
* formatted according to style
* quick bugfix
* Update examples/summarization/scripts/reward_summarization.py
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* responded to latest round of PR feedback
* added example model to docs
* added example model to docs
---------
Co-authored-by: Nathan Lambert <nathan@huggingface.co>
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
206 lines
8.0 KiB
Python
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    PreTrainedTokenizerBase,
    HfArgumentParser,
)
from transformers.utils import PaddingStrategy
from typing import Optional, Union, List, Dict, Any
import evaluate
from dataclasses import dataclass, field
import torch.nn as nn
import numpy as np


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    These arguments vary depending on how many GPUs you have, what their capacity and features are, and what size model you want to train.
    """

    local_rank: Optional[int] = field(default=0, metadata={"help": "Used for multi-gpu"})
    resume_from_checkpoint: Optional[bool] = field(
        default=False, metadata={"help": "If you want to resume training where it left off."}
    )
    deepspeed: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to deepspeed config if using deepspeed. You may need this if the model that you want to train doesn't fit on a single GPU."
        },
    )
    per_device_train_batch_size: Optional[int] = field(default=16)
    per_device_eval_batch_size: Optional[int] = field(default=16)
    gradient_accumulation_steps: Optional[int] = field(default=4)
    learning_rate: Optional[float] = field(default=2e-5)
    weight_decay: Optional[float] = field(default=0.001)
    model_name: Optional[str] = field(
        default="gpt2",
        metadata={
            "help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc."
        },
    )
    bf16: Optional[bool] = field(
        default=False,
        metadata={
            "help": "This essentially cuts the training time in half if you want to sacrifice a little precision and have a supported GPU."
        },
    )
    num_train_epochs: Optional[int] = field(
        default=5,
        metadata={"help": "The number of training epochs for the reward model. OpenAI used 5."},
    )


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
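# HfArgumentParser exposes every ScriptArguments field above as a command-line flag. An illustrative
# single-GPU invocation (flag values are examples only; adjust them to your hardware) could be:
#   python reward_summarization.py --model_name=gpt2 --per_device_train_batch_size=8 --bf16
# For multi-GPU or DeepSpeed runs, launch the same script with torchrun or the deepspeed launcher and
# pass --deepspeed=<path to your config>.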

# Load the human comparisons dataset for tuning the reward model.
ds = load_dataset("openai/summarize_from_feedback", name="comparisons")
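# Each comparison row has an "info" dict holding the original text (a Reddit "post" or a news "article"),
# a "summaries" list with exactly two candidate summaries, and a "choice" index (0 or 1) marking the
# summary the human labeler preferred; the mapping further below relies on this structure.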

# Define the training args. Needs to be done before the model is loaded if you are using deepspeed.
training_args = TrainingArguments(
    output_dir=f"{script_args.model_name}_summarization_reward_model",
    learning_rate=script_args.learning_rate,
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    per_device_eval_batch_size=script_args.per_device_eval_batch_size,
    num_train_epochs=script_args.num_train_epochs,
    weight_decay=script_args.weight_decay,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    deepspeed=script_args.deepspeed,
    local_rank=script_args.local_rank,
    bf16=script_args.bf16,
    remove_unused_columns=False,
    label_names=[],
)
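# remove_unused_columns=False keeps the custom input_ids_j/k and attention_mask_j/k columns, which the
# Trainer would otherwise drop because they don't match the model's forward signature, and label_names=[]
# says the batches carry no label column; the pairwise loss below is computed from the two rewards instead.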

# Load the value-head model and tokenizer.
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name)
model = AutoModelForSequenceClassification.from_pretrained(script_args.model_name, num_labels=1)
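# num_labels=1 gives the sequence-classification head a single output logit, which we treat as the scalar reward.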

# Need to do this for gpt2, because it doesn't have an official pad token.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id


# Turn the dataset into pairs of post + summaries, where text_j is the preferred post + summary and text_k is the other.
def turn_into_text_classification_format(examples):
    new_examples = {"text_j": [], "text_k": []}
    for info, summaries, choice in zip(examples["info"], examples["summaries"], examples["choice"]):
        if len(summaries) != 2 or choice not in (0, 1):
            raise ValueError(
                f"There should be two summaries with a choice that's either 0 or 1. "
                f"Received {len(summaries)} summaries and choice={choice}."
            )
        original_text_field = "post" if info["post"] is not None else "article"
        new_examples["text_j"].append(
            summaries[choice]["text"] + " " + tokenizer.bos_token + " " + info[original_text_field]
        )
        new_examples["text_k"].append(
            summaries[0 if choice == 1 else 1]["text"] + " " + tokenizer.bos_token + " " + info[original_text_field]
        )

    return new_examples


num_proc = 8  # Can adjust to be higher if you have more processors. Should work even if you don't have 8 CPUs, though.
original_columns = ds["train"].column_names
ds = ds.map(turn_into_text_classification_format, batched=True, num_proc=num_proc, remove_columns=original_columns)
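# Each row now holds text_j = "<preferred summary> <bos token> <original post>" and text_k built the same
# way from the rejected summary, so the reward model always sees the summary first and the post after it.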


# Tokenize the dataset.
def preprocess_function(examples):
    tokenized_j = tokenizer(examples["text_j"], truncation=True)
    tokenized_k = tokenizer(examples["text_k"], truncation=True)
    return {
        "input_ids_j": tokenized_j["input_ids"],
        "attention_mask_j": tokenized_j["attention_mask"],
        "input_ids_k": tokenized_k["input_ids"],
        "attention_mask_k": tokenized_k["attention_mask"],
    }


tokenized_ds = ds.map(preprocess_function, batched=True, num_proc=num_proc, remove_columns=["text_j", "text_k"])
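# The tokenized sequences are truncated but not yet padded; padding to the longest sequence in each batch
# is deferred to the collator defined below.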


# We need to define a special data collator that batches the data in our j vs k format.
@dataclass
class RewardDataCollatorWithPadding:
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        features_j = []
        features_k = []
        for feature in features:
            features_j.append({"input_ids": feature["input_ids_j"], "attention_mask": feature["attention_mask_j"]})
            features_k.append({"input_ids": feature["input_ids_k"], "attention_mask": feature["attention_mask_k"]})
        batch_j = self.tokenizer.pad(
            features_j,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch_k = self.tokenizer.pad(
            features_k,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch = {
            "input_ids_j": batch_j["input_ids"],
            "attention_mask_j": batch_j["attention_mask"],
            "input_ids_k": batch_k["input_ids"],
            "attention_mask_k": batch_k["attention_mask"],
            "return_loss": True,
        }
        return batch
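# Note on the collator above: batch_j and batch_k are padded independently, so input_ids_j and input_ids_k
# may have different sequence lengths within the same batch. The extra "return_loss": True entry tells the
# Trainer that these label-free batches still produce a loss, which keeps evaluation working.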


# Define the metric that we'll use for validation.
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    predictions, _ = eval_pred
    # Here, predictions is rewards_j and rewards_k.
    # We want to see how much of the time rewards_j > rewards_k.
    predictions = np.argmax(predictions, axis=0)
    labels = np.zeros(predictions.shape)
    return accuracy.compute(predictions=predictions, references=labels)
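# Worked example, assuming the Trainer stacks the (rewards_j, rewards_k) pair returned by compute_loss
# along the first axis: for rewards_j = [2.1, 0.3] and rewards_k = [1.0, 0.9], the argmax over axis 0 is
# [0, 1]; against the all-zeros references that is an accuracy of 0.5, i.e. the preferred summary received
# the higher reward half of the time.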


class RewardTrainer(Trainer):
    # Define how to compute the reward loss.
    def compute_loss(self, model, inputs, return_outputs=False):
        rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0]
        rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0]
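        # Pairwise preference loss: -log(sigmoid(rewards_j - rewards_k)), averaged over the batch. It is
        # minimized by pushing the preferred summary's reward above the other one's; a margin of 0 costs
        # log(2) ≈ 0.693, while a margin of 1.0 costs roughly 0.313.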
        loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean()
        if return_outputs:
            return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k}
        return loss


# Train the model, woohoo.
trainer = RewardTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["validation"],
    compute_metrics=compute_metrics,
    data_collator=RewardDataCollatorWithPadding(tokenizer=tokenizer),
)

trainer.train(script_args.resume_from_checkpoint)

# Push to the hub so you can share it with people :D
model.push_to_hub(script_args.model_name)
tokenizer.push_to_hub(script_args.model_name)
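# Illustrative follow-up (not part of this script): the pushed reward model can later be reloaded to score
# a candidate summary for a given post. The repo id below is a placeholder that depends on your Hub
# username and the --model_name you trained with.
#
#   reward_tokenizer = AutoTokenizer.from_pretrained("<your-username>/<model_name>")
#   reward_model = AutoModelForSequenceClassification.from_pretrained("<your-username>/<model_name>")
#   inputs = reward_tokenizer(summary + " " + reward_tokenizer.bos_token + " " + post, return_tensors="pt")
#   reward = reward_model(**inputs).logits[0]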