Files
trl/examples/scripts/reward_modeling.py
Kashif Rasul 850ddcf598 [pre-commit] update pre-commit yaml (#2002)
* update pre-commit yaml

* fix test

* use element_type
2024-09-02 19:15:25 +02:00

138 lines
4.9 KiB
Python

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python examples/scripts/reward_modeling.py \
--model_name_or_path=facebook/opt-350m \
--output_dir="reward_modeling_anthropic_hh" \
--per_device_train_batch_size=16 \
--num_train_epochs=1 \
--gradient_accumulation_steps=2 \
--gradient_checkpointing=True \
--learning_rate=1.41e-5 \
--report_to="wandb" \
--remove_unused_columns=False \
--optim="adamw_torch" \
--logging_steps=10 \
--eval_strategy="steps" \
--eval_steps=500 \
--max_length=512
"""
import warnings
import torch
from accelerate import PartialState
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser
from trl import ModelConfig, RewardConfig, RewardTrainer, get_kbit_device_map, get_peft_config, get_quantization_config
tqdm.pandas()

if __name__ == "__main__":
    # Parse the command line into the training config (RewardConfig) and the model config (ModelConfig).
    parser = HfArgumentParser((RewardConfig, ModelConfig))
    config, model_config = parser.parse_args_into_dataclasses()
    # Always use non-reentrant gradient checkpointing for this script.
    config.gradient_checkpointing_kwargs = dict(use_reentrant=False)

    ################
    # Model & Tokenizer
    ################
    # "auto" and None are forwarded unchanged; any other string (e.g. "bfloat16") is resolved
    # to the matching attribute of the `torch` module.
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        # A device map is only requested when the model is loaded quantized.
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
        # Previously computed above but never forwarded, so --torch_dtype had no effect.
        torch_dtype=torch_dtype,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
    )
    # num_labels=1: the classification head produces a single scalar (the reward) per sequence.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_config.model_name_or_path, num_labels=1, trust_remote_code=model_config.trust_remote_code, **model_kwargs
    )
    # Keep the model's notion of the padding id in sync with the tokenizer that pads the batches.
    # NOTE(review): sequence-classification heads generally use config.pad_token_id to locate the
    # last real token; a mismatch would read the reward from the wrong position.
    if tokenizer.pad_token_id is not None:
        model.config.pad_token_id = tokenizer.pad_token_id

    # The LoRA task type only matters when PEFT is enabled (presumably via --use_peft, which
    # get_peft_config below reads), so do not warn on plain full fine-tuning runs.
    if model_config.use_peft and model_config.lora_task_type != "SEQ_CLS":
        warnings.warn(
            "You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs."
            " Make sure to pass --lora_task_type SEQ_CLS when using this script."
        )

    ################
    # Dataset
    ################
    raw_datasets = load_dataset("Anthropic/hh-rlhf")

    # Tokenize chosen/rejected pairs of inputs.
    # Adapt this section to your needs for custom datasets.
    def preprocess_function(examples):
        """Tokenize a batch of chosen/rejected texts into the column layout RewardTrainer consumes.

        Args:
            examples: a batched slice of the dataset with "chosen" and "rejected" text columns.

        Returns:
            A dict of four parallel lists: input ids and attention masks for the chosen
            and for the rejected side of each pair.
        """
        new_examples = {
            "input_ids_chosen": [],
            "attention_mask_chosen": [],
            "input_ids_rejected": [],
            "attention_mask_rejected": [],
        }
        for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
            tokenized_chosen = tokenizer(chosen)
            tokenized_rejected = tokenizer(rejected)

            new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
            new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
            new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
            new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])

        return new_examples

    # Preprocess the dataset and filter out examples that are longer than args.max_length.
    # Compute that only on the main process for faster data processing
    # (the other processes then reuse the cached result).
    # see: https://github.com/huggingface/trl/pull/1255
    with PartialState().local_main_process_first():
        raw_datasets = raw_datasets.map(
            preprocess_function,
            batched=True,
            num_proc=config.dataset_num_proc,
        )
        # With max_length unset (None) there is nothing to filter on, and `int <= None`
        # would raise TypeError.
        if config.max_length is not None:
            raw_datasets = raw_datasets.filter(
                lambda x: len(x["input_ids_chosen"]) <= config.max_length
                and len(x["input_ids_rejected"]) <= config.max_length,
                num_proc=config.dataset_num_proc,
            )

    train_dataset = raw_datasets["train"]
    eval_dataset = raw_datasets["test"]

    ################
    # Training
    ################
    trainer = RewardTrainer(
        model=model,
        tokenizer=tokenizer,
        args=config,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_config),
    )
    trainer.train()
    trainer.save_model(config.output_dir)

    # Evaluate before any upload so the eval metrics are logged even if the push fails.
    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)
    print(metrics)

    # Respect --push_to_hub instead of always uploading (an unconditional push fails
    # for users without Hub credentials or who only want a local run).
    if config.push_to_hub:
        trainer.push_to_hub()