♻️ Standardize script_args (#2130)

This commit is contained in:
Quentin Gallouédec
2024-09-26 15:23:42 +02:00
committed by GitHub
parent a0d714949f
commit 9af4734178
27 changed files with 136 additions and 130 deletions

View File

@@ -45,7 +45,7 @@ class ScriptArguments:
parser = HfArgumentParser(ScriptArguments)
args = parser.parse_args_into_dataclasses()[0]
script_args = parser.parse_args_into_dataclasses()[0]
def exact_match_reward(responses, answers=None):
@@ -90,12 +90,12 @@ lora_config = LoraConfig(
# set up models
model = AutoModelForCausalLMWithValueHead.from_pretrained(
args.model_name,
script_args.model_name,
use_auth_token=True,
load_in_4bit=True,
peft_config=lora_config,
)
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_auth_token=True)
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, use_auth_token=True)
tokenizer.pad_token = tokenizer.eos_token
ds = load_dataset("openai/gsm8k", "main", split="train")
@@ -107,7 +107,7 @@ ds_test = load_dataset("openai/gsm8k", "main", split="test")
ds_test = ds_test.rename_columns({"question": "query"})
ds_test = ds_test.map(lambda x: {"answer": x["answer"].split("#### ")[1]})
test_dataloader = torch.utils.data.DataLoader(ds_test, batch_size=args.batch_size)
test_dataloader = torch.utils.data.DataLoader(ds_test, batch_size=script_args.batch_size)
# prompt
prompt = """\
@@ -138,16 +138,16 @@ generation_kwargs = {
"do_sample": True,
"pad_token_id": tokenizer.eos_token_id,
"eos_token_id": -1,
"max_new_tokens": args.max_new_tokens,
"max_new_tokens": script_args.max_new_tokens,
}
# trainer
ppo_config = PPOConfig(
batch_size=args.batch_size,
learning_rate=args.learning_rate,
mini_batch_size=args.mini_batch_size,
ppo_epochs=args.ppo_epochs,
gradient_accumulation_steps=args.gradient_accumulation_steps,
batch_size=script_args.batch_size,
learning_rate=script_args.learning_rate,
mini_batch_size=script_args.mini_batch_size,
ppo_epochs=script_args.ppo_epochs,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
log_with="wandb",
tracker_project_name="trl-gsm8k",
remove_unused_columns=False,
@@ -169,7 +169,7 @@ text_env = TextEnvironment(
)
# main training loop
for epoch in range(args.n_epochs):
for epoch in range(script_args.n_epochs):
for step, batch in enumerate(ppo_trainer.dataloader):
if (step == 0) and (epoch % 4 == 0): # evaluate every 4 epochs
reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer)
@@ -190,4 +190,4 @@ for epoch in range(args.n_epochs):
ppo_trainer.log_stats(train_stats, texts, rewards, columns_to_log=["query", "response", "answer"])
reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer)
ppo_trainer.save_pretrained(f"model/{args.model_name}-gsm8k")
ppo_trainer.save_pretrained(f"model/{script_args.model_name}-gsm8k")