Mirror of https://github.com/huggingface/trl.git (synced 2025-10-20 18:43:52 +08:00)
PartialState().local_main_process_first() when map in examples (#1926)
* `PartialState().local_main_process_first()` when map in examples
* allow load from cache

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
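For context, this is the pattern the commit applies across the example scripts: under a multi-process `accelerate` launch, every rank would otherwise run the same `map`/`filter` preprocessing redundantly. Wrapping it in `PartialState().local_main_process_first()` lets the local main process run it first and populate the `datasets` cache, after which the remaining ranks re-run the same calls and hit the cache. A minimal runnable sketch of that pattern; the tokenizer and dataset below (`gpt2`, `imdb`) are placeholders, not part of the commit:

# A minimal sketch of the pattern, assuming an `accelerate launch` context;
# the tokenizer and dataset are placeholders, not part of the commit.
from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
dataset = load_dataset("imdb", split="train")

def tokenize(batch):
    return tokenizer(batch["text"])

# The local main process enters the block first and writes the map result to
# the datasets cache; the other local ranks wait, then run the same map and
# load the cached result instead of recomputing it.
with PartialState().local_main_process_first():
    dataset = dataset.map(tokenize, batched=True)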
committed by GitHub
parent 54f806b6ff
commit f05f63c1ea
@@ -31,6 +31,7 @@ python examples/scripts/reward_modeling.py \
 import warnings
 
 import torch
+from accelerate import PartialState
 from datasets import load_dataset
 from tqdm import tqdm
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser
@@ -99,16 +100,20 @@ if __name__ == "__main__":
         return new_examples
 
     # Preprocess the dataset and filter out examples that are longer than args.max_length
-    raw_datasets = raw_datasets.map(
-        preprocess_function,
-        batched=True,
-        num_proc=config.dataset_num_proc,
-    )
-    raw_datasets = raw_datasets.filter(
-        lambda x: len(x["input_ids_chosen"]) <= config.max_length
-        and len(x["input_ids_rejected"]) <= config.max_length,
-        num_proc=config.dataset_num_proc,
-    )
+    # Compute that only on the main process for faster data processing.
+    # see: https://github.com/huggingface/trl/pull/1255
+    with PartialState().local_main_process_first():
+        raw_datasets = raw_datasets.map(
+            preprocess_function,
+            batched=True,
+            num_proc=config.dataset_num_proc,
+        )
+        raw_datasets = raw_datasets.filter(
+            lambda x: len(x["input_ids_chosen"]) <= config.max_length
+            and len(x["input_ids_rejected"]) <= config.max_length,
+            num_proc=config.dataset_num_proc,
+        )
 
     train_dataset = raw_datasets["train"]
     eval_dataset = raw_datasets["test"]
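The second commit-message bullet, "allow load from cache", is the flip side of the same mechanism: because the non-main ranks execute `map`/`filter` only after the main rank has finished, the fingerprint-based caching in `datasets` resolves their calls to cache loads rather than recomputation. A small ordering sketch (not from the commit), meant to be run with something like `accelerate launch --num_processes 2`:

# Ordering sketch: on each node, the local main process prints before any
# other local rank enters the block.
from accelerate import PartialState

state = PartialState()
with state.local_main_process_first():
    print(f"rank {state.process_index} is preprocessing")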