* clean deps
* new tests
* tests
* Add tests without optional dependencies workflow
* Update dependencies in tests.yml
* cpu version of torch
* Update dependencies and installation commands
* Disable fail-fast in test workflow
* Update test matrix in workflows file
* try fix windows
* Remove "rich" from required packages in setup.py
* Update dependency installation in tests.yml
* Add torch and deepspeed installation for windows-latest
* Fix conditional statement in workflow file
* Add torch and deepspeed installation for Windows
* Fix if statement
* Update torch and deepspeed dependencies
* Update liger package requirement for non-Windows platforms
* remove scipy dep
* Add torch GPU requirement for testing_utils
* Update trl/trainer/judges.py
* make Orpotrainer run faster on tpu
* less data transfer
* train-trl.py
* fix
* set device_map=auto
* add is_torch_xla_available guards
* delete file
* address comments
* make presubmit
* Update transformer version in setup.py
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Broken first pre-draft
* Change structure to leverage user-definition of pipeline
- reward function, pipeline and scheduler will be left to the user to define
- pipeline and scheduler contract interfaces is what the framework will define
- none of this actually works
* Incremental progress: trying to get the set-up running e2e
* Incemental progress: successfully running code
* Incremental progress: running setup
Next steps: fix accelerate gardient acc assertion error when we set value > 1
* Formatting and code standards
* Incremental prog: break down code a bit
- new config flag to notify code of async reward fetching
- break off image handling code and throw it on to user to define how to handle it
- more code restructuring
* Incremental progress:
1. More code sectioning off into own methods (more for readibility than anything else)
* Incremental progress:
1. clear up contracts
2. type the reward function and prompt function
* Code shuffling and expansion of tracker, accelerator config args to beyond wandb
* More small additions
Add tensorboard logging function
Remove wandb logging function for now
Consolidate the data that get's thrown to the logging function
Add README
* Formatting
* Formatting
* Remove print statement
Make tensorboard tracking the sole tracking for the training example
* 1. start of testing
2. more refactoring
3. start of docstrings
4. parameter rename
* Basic Tests
Formatting
* Docs according to the norm
* Doocs, credits and rename file
* docs and corrections
* Put example config to respectable state
* Add recent run params
* Correct the name of the library
* Move requirements to EXTRAS
* - Add license banners
- Guard import of DDPO functions with if_diffusers_available
- doc strings for output types
* Add snippet to pull weights from huggingface + banner
* Test if passes on CI/CD
* Minor refactor
* Test dummy unet
* Possible fix for randomly disappearing attribute
* Shuffling arrangement in hopes of meeting memory requirements
* Proper Names
* Appease windows memory allocator issues for the cpu device
* Remove print statements
* Update docs/source/ddpo_trainer.mdx
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* Update docs/source/ddpo_trainer.mdx
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* Add docstrings and correct url
* Spelling and grammar
* Add more documentation and commandline parsing for example script
* Markdown synatx correction
* Revert accidentally committed file and put the correct one
* More docs
* Remove subclassing and add docs for leftoover subclassing
* Put back subclassing
* Reward metadata and more docs
* Remove save_load_save flag
* Grammar
* Update trl/trainer/ddpo_trainer.py
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* Update tests/test_ddpo_trainer.py
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* Update setup.py
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* Update examples/scripts/stable_diffusion_tuning.py
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* Edits to the readme for DDPO
* Renamed modelling_sd_base to modeling_sd_base
* Insert try and catch for bitsandbytes import
* Change to smaller model
* Correct tolerance for floating point comparison
* Remove dummy unet and move to check is isfinite
* 1. Expand interface to ensure other Stable Diffusion pipelines could be covered
2. remove extra identification
* 1. Remove most of the asserts except for one and add value error
2. Remove default run name
* Remove progress bar
* Docs
* Put back progress bar
* 1. Revert progress bar deletion completely
2. grammar
3. relocate line
* Experiment
* Remove experiment parts and format properly
* Change formatting and edit info in docs
* Grammar
* Refactor out most of nitty gritty of loading/saving from trainer to example model
Readme addition
* Docs additions
* 1. Proper formatting fr the test file
2. incorporatioon of pull frm hub if fails try local
3. doc strings for interface
4. highlight in the trainer, that this is only ready fr sd pipelines
* Resources for before and after
* Attempt at embedding images
* Post testing example script
* Consistent naming and document edits in light of new args
* Remove resources and add CDN links in html in doc file
---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
- move notebooks to `examples/notebooks``
- removed `_nbdev`file
- refactored `gpt2.py` to make it work with more recent `transformers`
- update `requirements` to add recent `transformers`
* change ppo input from tensor to list of tensors for varying shapes
* update readme example with new input type
* update docs
* add listification of tensors need for new API
* replace nans in tensors for wandb compatibility
* add `listify_batch` helper function for backwards compatibility
* update sentiment example with new api
* update docs
* update library
* ignore wandb artifacts
* update requirements
* run experiment
* replace respond to batch with generate
* add experiment
* update docs
* fix action
* fix action