`training_args` for all `TrainingArguments` (#2082)

Author: Quentin Gallouédec
Date: 2024-09-19 15:03:47 +02:00
Committer: GitHub
Parent: 9fb871f62f
Commit: 10c2f63b2a
27 changed files with 192 additions and 188 deletions

@@ -74,8 +74,8 @@ tqdm.pandas()
 if __name__ == "__main__":
     parser = HfArgumentParser((RewardScriptArguments, RewardConfig, ModelConfig))
-    args, config, model_config = parser.parse_args_into_dataclasses()
-    config.gradient_checkpointing_kwargs = dict(use_reentrant=False)
+    args, training_args, model_config = parser.parse_args_into_dataclasses()
+    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
     ################
     # Model & Tokenizer
@@ -138,19 +138,19 @@ if __name__ == "__main__":
         chosen_fn = conversations_formatting_function(tokenizer, "chosen")
         rejected_fn = conversations_formatting_function(tokenizer, "rejected")
         dataset = dataset.map(
-            lambda x: {"chosen": chosen_fn(x), "rejected": rejected_fn(x)}, num_proc=config.dataset_num_proc
+            lambda x: {"chosen": chosen_fn(x), "rejected": rejected_fn(x)}, num_proc=training_args.dataset_num_proc
         )
         # Tokenize inputs
         dataset = dataset.map(
             preprocess_function,
             batched=True,
-            num_proc=config.dataset_num_proc,
+            num_proc=training_args.dataset_num_proc,
         )
         # Filter out examples that are too long
         dataset = dataset.filter(
-            lambda x: len(x["input_ids_chosen"]) <= config.max_length
-            and len(x["input_ids_rejected"]) <= config.max_length,
-            num_proc=config.dataset_num_proc,
+            lambda x: len(x["input_ids_chosen"]) <= training_args.max_length
+            and len(x["input_ids_rejected"]) <= training_args.max_length,
+            num_proc=training_args.dataset_num_proc,
         )
     ##########
@@ -159,7 +159,7 @@ if __name__ == "__main__":
     trainer = RewardTrainer(
         model=model,
         tokenizer=tokenizer,
-        args=config,
+        args=training_args,
         train_dataset=dataset[args.dataset_train_split],
         eval_dataset=dataset[args.dataset_test_split],
         peft_config=get_peft_config(model_config),
@@ -169,9 +169,9 @@ if __name__ == "__main__":
     ############################
     # Save model and push to Hub
     ############################
-    trainer.save_model(config.output_dir)
+    trainer.save_model(training_args.output_dir)
     metrics = trainer.evaluate()
     trainer.log_metrics("eval", metrics)
     trainer.save_metrics("eval", metrics)
-    trainer.save_model(config.output_dir)
+    trainer.save_model(training_args.output_dir)
     trainer.push_to_hub()
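
For reference, the convention this commit standardizes across the example scripts can be summarized with the short sketch below. It is a minimal illustration rather than code from the repository: it drops the script-specific `RewardScriptArguments` dataclass, and the CLI flags in the usage line are just examples. It relies only on APIs visible in the diff above (`HfArgumentParser.parse_args_into_dataclasses`, TRL's `RewardConfig` and `ModelConfig`).

```python
# Minimal sketch (not from the repository): whatever TrainingArguments
# subclass a script uses -- RewardConfig here -- the parsed instance is
# bound to the name `training_args`, never `config`.
from transformers import HfArgumentParser
from trl import ModelConfig, RewardConfig

if __name__ == "__main__":
    parser = HfArgumentParser((RewardConfig, ModelConfig))
    training_args, model_config = parser.parse_args_into_dataclasses()

    # All downstream reads go through `training_args`, exactly as in the diff:
    print(training_args.output_dir)        # inherited from TrainingArguments
    print(training_args.max_length)        # RewardConfig field used for filtering
    print(training_args.dataset_num_proc)  # RewardConfig field used for map/filter
```

Invoked as, for example, `python sketch.py --output_dir reward_model --max_length 1024`; the point is only the naming of the parsed objects, which the rest of this commit applies uniformly to the other example scripts.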