accelerate integration (#58)

* working v1

* add `accelerate` on requirements

* add `accelerate` on `setup.py`

* add `datasets` on `setup.py`

* small updates

- add docstring on most functions
- correct logging

* rm unneeded file

* replace with `generate`

* Update trl/trainer/accelerate_ppo.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* correct return

* add dataloader support

* add `wandb` to `setup.py`

* refactor

- remove `_build_dataset` method
- change name to `PPOTrainer`

* test

* fix test

* rename file

* refactor

* remove unneeded device assignment

* fix correct device assignment

* standardize docstrings

* add `wandb` on `dev`

* fix slow convergence

- random init seems to converge much faster

* oops

* revert fix

* revert patch

* remove unneeded reshape

* add input safety checker

* refactor

- added comments on example
- fixes CI test
- rewards should be a list of tensors
- clearer error messages
- remove build model method
- refactor log stats method

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* refactor

- added `PPOConfig` class
- docstring on `LengthSampler`
- fix test
- gather rewards when logging
- unwrap model when calling generate

* some refactor

* remove unneeded hack

* adapt dataset

* fix test

* remove rollout

* remove timing

* remove `shuffle=True`

* remove `LengthSampler` from trainer

* refactor

* remove text length sampler args from config

* change collate_fn

* fix silent bug

* rename

* move file

* refactor base trainer

* fix collate

* final bug

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
Younes Belkada
2022-12-30 09:27:25 +01:00
committed by GitHub
parent 49c4bc7c8f
commit b1279004e7
15 changed files with 1015 additions and 645 deletions

View File

@ -5,5 +5,6 @@ datasets==1.17.0
torch>=1.4.0
tqdm
transformers
accelerate
wandb==0.10.20
matplotlib==3.5.1