mirror of
https://github.com/huggingface/trl.git
synced 2025-10-21 02:53:59 +08:00
a workaround for failing log_stats (#708)
This commit is contained in:
@ -188,7 +188,7 @@ for epoch in range(args.n_epochs):
|
||||
"response": [tokenizer.decode(response) for response in responses],
|
||||
"answer": batch["answer"],
|
||||
}
|
||||
ppo_trainer.log_stats(train_stats, texts, rewards)
|
||||
ppo_trainer.log_stats(train_stats, texts, rewards, columns_to_log=["query", "response", "answer"])
|
||||
|
||||
reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer)
|
||||
ppo_trainer.save_pretrained(f"model/{args.model_name}-gsm8k")
|
||||
|
Reference in New Issue
Block a user