PartialState().local_main_process_first() when map in examples (#1926)

* `PartialState().local_main_process_first()` when map in examples

* allow loading from cache

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Author: Quentin Gallouédec
Date: 2024-08-14 12:01:03 +02:00
Committed by: GitHub
Parent: 54f806b6ff
Commit: f05f63c1ea
18 changed files with 115 additions and 86 deletions

examples/scripts/reward_modeling.py

@@ -31,6 +31,7 @@ python examples/scripts/reward_modeling.py \
 import warnings
 
 import torch
+from accelerate import PartialState
 from datasets import load_dataset
 from tqdm import tqdm
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser
@@ -99,16 +100,20 @@ if __name__ == "__main__":
         return new_examples
 
     # Preprocess the dataset and filter out examples that are longer than args.max_length
-    raw_datasets = raw_datasets.map(
-        preprocess_function,
-        batched=True,
-        num_proc=config.dataset_num_proc,
-    )
-    raw_datasets = raw_datasets.filter(
-        lambda x: len(x["input_ids_chosen"]) <= config.max_length
-        and len(x["input_ids_rejected"]) <= config.max_length,
-        num_proc=config.dataset_num_proc,
-    )
+    # Compute that only on the main process for faster data processing.
+    # see: https://github.com/huggingface/trl/pull/1255
+    with PartialState().local_main_process_first():
+        raw_datasets = raw_datasets.map(
+            preprocess_function,
+            batched=True,
+            num_proc=config.dataset_num_proc,
+        )
+        raw_datasets = raw_datasets.filter(
+            lambda x: len(x["input_ids_chosen"]) <= config.max_length
+            and len(x["input_ids_rejected"]) <= config.max_length,
+            num_proc=config.dataset_num_proc,
+        )
 
     train_dataset = raw_datasets["train"]
     eval_dataset = raw_datasets["test"]
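
The change, repeated across the touched example scripts, wraps the dataset map/filter calls in Accelerate's local_main_process_first() context manager: under a multi-process launch, the local main process enters the block first and runs the preprocessing, writing the result to the datasets Arrow cache; the other processes wait at the barrier, then execute the same calls and load the cached result instead of recomputing it ("allow loading from cache"). Below is a minimal, self-contained sketch of the pattern; the dataset ("imdb"), model ("gpt2"), and tokenize() function are hypothetical placeholders for illustration, not part of this commit.

# Minimal sketch of the pattern applied by this commit (placeholder names).
from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
dataset = load_dataset("imdb", split="train")  # placeholder dataset

def tokenize(batch):
    # Deterministic preprocessing: identical inputs produce a cache hit.
    return tokenizer(batch["text"], truncation=True)

# The local main process runs map() first and fills the Arrow cache; the
# other processes wait at the context manager's barrier, then run the same
# map() and load the cached result instead of recomputing it.
with PartialState().local_main_process_first():
    dataset = dataset.map(tokenize, batched=True)

Launched with accelerate launch, only one process per node does the preprocessing work; run as a plain python script there is a single process, so the context manager has no visible effect and the examples keep working unchanged.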