Mirror of https://github.com/huggingface/trl.git, synced 2025-10-21 11:33:51 +08:00

♻️ Standardize script_args (#2130)
committed by GitHub
parent a0d714949f
commit 9af4734178
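This commit renames the parsed argument object in the GSM8K tool-use example from args to script_args, in line with the script_args naming the title says is being standardized. A minimal sketch of the parsing pattern follows; the field names are taken from the diff below, while the types, defaults, and help strings are illustrative assumptions rather than the repository's actual values.

# Sketch of the standardized argument-parsing pattern (field names from the diff;
# types and defaults here are assumptions, not the script's real values).
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
    model_name: Optional[str] = field(default=None, metadata={"help": "model name or path"})
    batch_size: int = field(default=16, metadata={"help": "PPO batch size"})
    mini_batch_size: int = field(default=4, metadata={"help": "PPO mini-batch size"})
    learning_rate: float = field(default=1e-5, metadata={"help": "learning rate"})
    ppo_epochs: int = field(default=1, metadata={"help": "PPO optimization epochs per batch"})
    gradient_accumulation_steps: int = field(default=1, metadata={"help": "gradient accumulation steps"})
    n_epochs: int = field(default=32, metadata={"help": "passes over the training set"})
    max_new_tokens: int = field(default=256, metadata={"help": "max tokens to generate"})


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]  # renamed from args in this commit

Because HfArgumentParser turns each dataclass field into a command-line flag (--model_name, --batch_size, and so on), the rename does not change how the script is invoked.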
@@ -45,7 +45,7 @@ class ScriptArguments:
 
 
 parser = HfArgumentParser(ScriptArguments)
-args = parser.parse_args_into_dataclasses()[0]
+script_args = parser.parse_args_into_dataclasses()[0]
 
 
 def exact_match_reward(responses, answers=None):
@@ -90,12 +90,12 @@ lora_config = LoraConfig(
 
 # set up models
 model = AutoModelForCausalLMWithValueHead.from_pretrained(
-    args.model_name,
+    script_args.model_name,
     use_auth_token=True,
     load_in_4bit=True,
     peft_config=lora_config,
 )
-tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_auth_token=True)
+tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, use_auth_token=True)
 tokenizer.pad_token = tokenizer.eos_token
 
 ds = load_dataset("openai/gsm8k", "main", split="train")
@@ -107,7 +107,7 @@ ds_test = load_dataset("openai/gsm8k", "main", split="test")
 ds_test = ds_test.rename_columns({"question": "query"})
 ds_test = ds_test.map(lambda x: {"answer": x["answer"].split("#### ")[1]})
 
-test_dataloader = torch.utils.data.DataLoader(ds_test, batch_size=args.batch_size)
+test_dataloader = torch.utils.data.DataLoader(ds_test, batch_size=script_args.batch_size)
 
 # prompt
 prompt = """\
@@ -138,16 +138,16 @@ generation_kwargs = {
     "do_sample": True,
     "pad_token_id": tokenizer.eos_token_id,
     "eos_token_id": -1,
-    "max_new_tokens": args.max_new_tokens,
+    "max_new_tokens": script_args.max_new_tokens,
 }
 
 # trainer
 ppo_config = PPOConfig(
-    batch_size=args.batch_size,
-    learning_rate=args.learning_rate,
-    mini_batch_size=args.mini_batch_size,
-    ppo_epochs=args.ppo_epochs,
-    gradient_accumulation_steps=args.gradient_accumulation_steps,
+    batch_size=script_args.batch_size,
+    learning_rate=script_args.learning_rate,
+    mini_batch_size=script_args.mini_batch_size,
+    ppo_epochs=script_args.ppo_epochs,
+    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
     log_with="wandb",
     tracker_project_name="trl-gsm8k",
     remove_unused_columns=False,
@@ -169,7 +169,7 @@ text_env = TextEnvironment(
 )
 
 # main training loop
-for epoch in range(args.n_epochs):
+for epoch in range(script_args.n_epochs):
     for step, batch in enumerate(ppo_trainer.dataloader):
         if (step == 0) and (epoch % 4 == 0): # evaluate every 4 epochs
             reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer)
@@ -190,4 +190,4 @@ for epoch in range(args.n_epochs):
         ppo_trainer.log_stats(train_stats, texts, rewards, columns_to_log=["query", "response", "answer"])
 
 reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer)
-ppo_trainer.save_pretrained(f"model/{args.model_name}-gsm8k")
+ppo_trainer.save_pretrained(f"model/{script_args.model_name}-gsm8k")