mirror of https://github.com/huggingface/trl.git (synced 2025-10-20 18:43:52 +08:00)
accelerate integration (#58)
* working v1
* add `accelerate` on requirements
* add `accelerate` on `setup.py`
* add `datasets` on `setup.py`
* small updates
- add docstring on most functions
- correct logging
* rm unneeded file
* replace with `generate`
* Update trl/trainer/accelerate_ppo.py
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* correct return
* add dataloader support
* add `wandb` to `setup.py`
* refactor
- remove `_build_dataset` method
- change name to `PPOTrainer`
* test
* fix test
* rename file
* refactor
* remove unneeded device assignment
* fix correct device assignment
* standardize docstrings
* add `wandb` on `dev`
* fix slow convergence
- random init seems to converge much faster
* oops
* revert fix
* revert patch
* remove unneeded reshape
* add input safety checker
* refactor
- added comments on example
- fixes CI test
- rewards should be a list of tensors
- clearer error messages
- remove build model method
- refactor log stats method
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* Apply suggestions from code review
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* refactor
- added `PPOConfig` class
- docstring on `LengthSampler`
- fix test
- gather rewards when logging
- unwrap model when calling generate
* some refactor
* remove unneeded hack
* adapt dataset
* fix test
* remove rollout
* remove timing
* remove `shuffle=True`
* remove `LengthSampler` from trainer
* refactor
* remove text length sampler args from config
* change collate_fn
* fix silent bug
* rename
* move file
* refactor base trainer
* fix collate
* final bug

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
@@ -5,5 +5,6 @@ datasets==1.17.0
 torch>=1.4.0
 tqdm
 transformers
+accelerate
 wandb==0.10.20
 matplotlib==3.5.1