mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
🕊️ Migration PPOv2
-> PPO
(#2174)
* delete old ppo * rename ppov2 files * PPOv2 -> PPO * rm old doc * rename ppo doc file * rm old test * rename test * re-add v2 with deprecation * style * start update customization * Lion * Finish update customization * remove ppo_multi_adaptater * remove ppo example * update some doc * rm test no peft * rm hello world * processing class * Update docs/source/detoxifying_a_lm.mdx Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com> * Update trl/trainer/ppov2_config.py Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com> * Update docs/source/customization.mdx Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * Update docs/source/detoxifying_a_lm.mdx Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * po to example overview * drop lion * remove "Use 8-bit optimizer" * Update docs/source/customization.mdx * Update docs/source/customization.mdx Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * it applies to all trainers --------- Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com> Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
committed by
GitHub
parent
d0aa421e5e
commit
70036bf87f
@ -23,7 +23,7 @@ from transformers import (
|
||||
HfArgumentParser,
|
||||
)
|
||||
|
||||
from trl import ModelConfig, PPOv2Config, PPOv2Trainer
|
||||
from trl import ModelConfig, PPOConfig, PPOTrainer
|
||||
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
|
||||
|
||||
|
||||
@ -58,7 +58,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser((PPOv2Config, ModelConfig))
|
||||
parser = HfArgumentParser((PPOConfig, ModelConfig))
|
||||
training_args, model_config = parser.parse_args_into_dataclasses()
|
||||
# remove output_dir if exists
|
||||
shutil.rmtree(training_args.output_dir, ignore_errors=True)
|
||||
@ -123,7 +123,7 @@ if __name__ == "__main__":
|
||||
################
|
||||
# Training
|
||||
################
|
||||
trainer = PPOv2Trainer(
|
||||
trainer = PPOTrainer(
|
||||
config=training_args,
|
||||
processing_class=tokenizer,
|
||||
policy=policy,
|
||||
|
Reference in New Issue
Block a user