mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
training_args
for all TrainingArguments
(#2082)
This commit is contained in:
committed by
GitHub
parent
9fb871f62f
commit
10c2f63b2a
@ -74,8 +74,8 @@ tqdm.pandas()
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser((RewardScriptArguments, RewardConfig, ModelConfig))
|
||||
args, config, model_config = parser.parse_args_into_dataclasses()
|
||||
config.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
||||
args, training_args, model_config = parser.parse_args_into_dataclasses()
|
||||
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
||||
|
||||
################
|
||||
# Model & Tokenizer
|
||||
@ -138,19 +138,19 @@ if __name__ == "__main__":
|
||||
chosen_fn = conversations_formatting_function(tokenizer, "chosen")
|
||||
rejected_fn = conversations_formatting_function(tokenizer, "rejected")
|
||||
dataset = dataset.map(
|
||||
lambda x: {"chosen": chosen_fn(x), "rejected": rejected_fn(x)}, num_proc=config.dataset_num_proc
|
||||
lambda x: {"chosen": chosen_fn(x), "rejected": rejected_fn(x)}, num_proc=training_args.dataset_num_proc
|
||||
)
|
||||
# Tokenize inputs
|
||||
dataset = dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=config.dataset_num_proc,
|
||||
num_proc=training_args.dataset_num_proc,
|
||||
)
|
||||
# Filter out examples that are too long
|
||||
dataset = dataset.filter(
|
||||
lambda x: len(x["input_ids_chosen"]) <= config.max_length
|
||||
and len(x["input_ids_rejected"]) <= config.max_length,
|
||||
num_proc=config.dataset_num_proc,
|
||||
lambda x: len(x["input_ids_chosen"]) <= training_args.max_length
|
||||
and len(x["input_ids_rejected"]) <= training_args.max_length,
|
||||
num_proc=training_args.dataset_num_proc,
|
||||
)
|
||||
|
||||
##########
|
||||
@ -159,7 +159,7 @@ if __name__ == "__main__":
|
||||
trainer = RewardTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
args=config,
|
||||
args=training_args,
|
||||
train_dataset=dataset[args.dataset_train_split],
|
||||
eval_dataset=dataset[args.dataset_test_split],
|
||||
peft_config=get_peft_config(model_config),
|
||||
@ -169,9 +169,9 @@ if __name__ == "__main__":
|
||||
############################
|
||||
# Save model and push to Hub
|
||||
############################
|
||||
trainer.save_model(config.output_dir)
|
||||
trainer.save_model(training_args.output_dir)
|
||||
metrics = trainer.evaluate()
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
trainer.save_model(config.output_dir)
|
||||
trainer.save_model(training_args.output_dir)
|
||||
trainer.push_to_hub()
|
||||
|
Reference in New Issue
Block a user