a workaround for failing log_stats (#708)

This commit is contained in:
Younes Belkada
2023-08-30 12:23:57 +02:00
committed by GitHub
parent 7f636c9ed7
commit 79b90e19ba
4 changed files with 13 additions and 9 deletions

View File

@ -188,7 +188,7 @@ for epoch in range(args.n_epochs):
"response": [tokenizer.decode(response) for response in responses],
"answer": batch["answer"],
}
ppo_trainer.log_stats(train_stats, texts, rewards)
ppo_trainer.log_stats(train_stats, texts, rewards, columns_to_log=["query", "response", "answer"])
reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer)
ppo_trainer.save_pretrained(f"model/{args.model_name}-gsm8k")