Compare commits

...

332 Commits

SHA1 Message Date
55cc4b1076 Release v0.10.1 2024-08-29 16:51:11 +02:00
a879e6ad5a Release: v0.10.0 2024-08-29 13:01:20 +00:00
4dd0dc2988 Adds experimental Liger support to SFT script (#1992)
* adds cli and import utils

* updates SFT script

* adds liger model to trainer

* adds liger nightly dep

* precommit

* fix import

* Update trl/commands/cli_utils.py

* Fix quality

* moved use_liger arg to sft config

* remove arg

* remove use liger from sft trainer

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-08-29 14:48:35 +02:00
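
For orientation, a minimal sketch of the flag in use (the `use_liger` field name is taken from the commit notes; requires the `liger-kernel` package):

```python
from trl import SFTConfig

# `use_liger` per the commit notes; when enabled, the model is patched with
# Liger's fused Triton kernels before training.
config = SFTConfig(output_dir="sft-liger", use_liger=True)
```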
4f59e923ac Relax numpy upper bound and bump deepspeed version (#1990)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-08-29 13:17:48 +02:00
10f70fa333 Add ignore_index in DPOTrainer's nn.CrossEntropyLoss (#1987)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-08-28 16:41:41 +02:00
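
For context, what `ignore_index` buys, in plain PyTorch rather than the trainer's internals: label positions set to -100 are excluded from the loss, which is how padded tokens are masked out.

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 32)        # (batch, seq_len, vocab)
labels = torch.randint(0, 32, (2, 5))
labels[:, -2:] = -100                 # padded positions contribute no loss
loss = F.cross_entropy(logits.view(-1, 32), labels.view(-1), ignore_index=-100)
```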
47ab034ca9 [DPO] tokenize and process DPO data via batches (#1914)
* tokenize and process DPO data via batches

* use helpers

* updated _process_tokens

* fixed

* incorporate build_tokenized_answer in the _tokenizer

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* fix tokenizer for is_vision_model

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* give the _tokenize the tokenizer as well as optional processor

* fix tests

* add bos and eos tokens

* add prompt_pixel_attention_mask

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* truncate by max_length

* formatting

* fix for enc-dec

* For encoder-decoder models, we need to use the prepared decoder_input_ids

* add tests for _build_tokenized_answer and _tokenize_feature

* check for EOS and BOS tokens

* formatting

* do not include pixel mask if they are not provided

* undo refactor

* undo add_bos_token_if_needed change

* refactor tokenizer into smaller helpers

* add back comments

* fix type hints

* format

* fix t5 tests

* args are never optional

* move cat to appropriate helper

* fix _truncate_tokens

* add tests for _truncate_tokens

* remove dead code

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-08-28 16:14:53 +02:00
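
The PR's helpers (`_build_tokenized_answer`, `_process_tokens`, `_truncate_tokens`) are trainer internals; the underlying pattern is a batched `Dataset.map`, sketched here with illustrative column names:

```python
from datasets import Dataset
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
ds = Dataset.from_dict({"prompt": ["Hi", "Hello"], "chosen": [" there", " world"]})

def tokenize_batch(batch):
    # Tokenize entire columns at once rather than one row at a time.
    return {
        f"{col}_input_ids": tok(batch[col], add_special_tokens=False)["input_ids"]
        for col in ("prompt", "chosen")
    }

ds = ds.map(tokenize_batch, batched=True)
```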
e755eee660 Refactor Online DPO (#1839)
* online dpo trainer based on rloo trainer

* push changes

* refactor

* use `batch_generation` method

* precommit

* remove breakpoint()

* quick refactor

* push the current changes

* quick change

* refactor

* use the config name as the experiment name

* fix logging

* update online DPO docs

* use llm as a judge

* quick change

* quick fix

* cache changes

* new semantics

* style and arg order change

* rm duplicated num_epochs

* rm plot script

* num_epoch

* revert some changes

* revert changes

* revert whitespace

* rm whitespace

* revert change

* policy->model

* optional judge and reward model

* cleaning online dpo script

* warning when both reward model and judge provided

* return -1 when the judge fails

* dataset num proc

* add judges in online dpo; fix collate and process within the trainer

* lr_scheduler.step() after optimizer step

* update odpo test

* reduce nesting

* allow pickle

* generation config typing

* online dpo llm judge

* fix data collator pad token

* add space

* fix pref score

* -1 for judges

* self.model_wrapped = self.model

* onlinedpo inherits from training arguments

* num_epoch -> num_steps_in_epochs

* update -> epoch

* epoch -> step; step_in_epoch -> ppo_epoch; rm run_name

* num_steps_in_epoch -> num_ppo_epochs

* epoch_idx -> ppo_epoch_idx

* make init consistent with dpo

* try another option

* progress...

* odpo

* current progress

* log and other changes

* rename for legacy

* rename for legacy

* rename and move truncate

* rename

* new config

* LogCompletionsCallback

* style

* rename trainer

* truncate right in utils

* update example

* reward model path

* properly log

* fix example

* add generation prompt and log special tokens

* true penalty

* defaults from the paper

* Remove MPS (#1983)

* Set KV cache false when gradient checkpointing is enabled (#1984)

* Remove MPS

* Fix

* Various tweaks

* Remove padding from table

* Clean up

* Fix test

* Revert log freq

* Fix docs

* Fix tests again!

* Fix typo

* Revert

* Fix regression

* Apply suggestions from code review

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Fix DPO config test

* Fix doc tree

* Clean docs moar

* Add docstring

* raise NotImplementedError for judge

* Refactor cache clearing

---------

Co-authored-by: Michael Noukhovitch <mnoukhov@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-08-28 15:39:51 +02:00
ac31d1205e Skip the failing Online DPO test (#1989)
* Harmonisation of tests between main and PR

* disable tqdm

* skip the test

* `"Programming Language :: Python :: 3.11"` and drop 3.7

* Update .github/workflows/tests.yml

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update setup.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update .github/workflows/tests-main.yml

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-08-28 14:55:18 +02:00
c44ab6d1e9 torch.load with weights_only=True (#1988)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-28 11:13:22 +02:00
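
For context, `weights_only=True` restricts unpickling to tensors and primitive containers, so a malicious checkpoint cannot run arbitrary code at load time:

```python
import torch

torch.save({"w": torch.zeros(2)}, "ckpt.pt")
state = torch.load("ckpt.pt", weights_only=True)  # rejects arbitrary pickled objects
```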
a15a80e0d5 gather the target model params as well (#1978) 2024-08-28 09:27:26 +02:00
264f1279fd Promote PairRMJudge to top-level import (#1985)
* allow `from trl import PairRMJudge`

* test_pair_rm_judge

* Update setup.py

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-27 21:04:05 +02:00
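
A short usage sketch of the newly top-level import (the call shape follows the pairwise-judge API from #1856 and is illustrative; requires the `llm-blender` package):

```python
from trl import PairRMJudge

judge = PairRMJudge()  # wraps llm-blender's PairRM preference model
# For each prompt, returns the index (0 or 1) of the preferred completion.
ranks = judge.judge(prompts=["What is 2+2?"], completions=[["4", "5"]])
```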
0cda2f2f01 Restore test (#1982) 2024-08-27 11:16:32 +02:00
e0ff66103e Update tests for _get_kl_dataset (#1974)
* Test for #1970

* style

* drop last element in the batch for test

* check prompt_input_ids not modified

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-27 11:00:43 +02:00
3a3ed88f28 Fix dataset_num_proc missing in PPOConfig (#1966)
* fix a few minor bugs in ppo.py

* dataset_num_proc as training arg

* num proc in config

* Update examples/scripts/ppo.py

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-08-27 10:59:45 +02:00
b65657f41d Fix flaky Hub tests (#1981)
* Fix flaky Hub tests

* Trigger Build

* test build
2024-08-27 10:14:39 +02:00
de024ece28 Use weights_only for load (#1933)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-08-26 18:18:38 +02:00
2fbc0f4fc2 Fix issue template path (#1973)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-26 14:40:37 +02:00
cf5168ea7c New mismatch pair creation strategy (#1970)
* new mismatch pair creation strategy

* style

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-26 13:29:22 +02:00
1e4fb80cbc Fix issue with unnecessary cached during logp calc. (#1969) 2024-08-26 12:38:58 +02:00
fe41acd6ae add arg padding_free to DataCollatorForCompletionOnlyLM (#1887)
* add arg `padding_free` to DataCollatorForCompletionOnlyLM

* Update tests/test_data_collator_completion_only.py

* Update trl/trainer/utils.py

* Update tests/test_data_collator_completion_only.py

* Update tests/test_data_collator_completion_only.py

* Update tests/test_data_collator_completion_only.py

* Update tests/test_data_collator_completion_only.py

* Update test_data_collator_completion_only.py

* Update tests/test_data_collator_completion_only.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
2024-08-26 09:48:39 +02:00
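
A minimal sketch of the new flag, named per the PR title (padding-free batches assume an attention implementation that accepts flattened inputs, e.g. flash attention):

```python
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token
collator = DataCollatorForCompletionOnlyLM(
    response_template="### Answer:",
    tokenizer=tok,
    padding_free=True,  # concatenate sequences instead of padding them
)
```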
c71262c9c6 Fix issue with precompute_ref_log_probs not working when rpo_alpha is None (#1961)
* Fix issue with precompute_ref_log_probs not working when rpo_alpha is None

* Test: Add test for precompute_ref_log_probs with rpo_alpha=None

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-08-25 12:15:57 +02:00
dcee683d96 Add issue/PR templates, code of conduct & better contributing guide (#1963)
* Add issue/PR templates, code of conduct & better contributing guide

* Apply suggestions from code review

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-08-23 23:12:40 +02:00
4788e5cda5 Support LLaVA-NeXT in Vision SFT (#1959)
* support llava next

* mention version for llava-next

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-23 11:37:40 +02:00
6cea2ef964 [ODPO] Refactor training script to use messages API (#1958)
* Refactor dataset prep

* Add moar doc
2024-08-22 20:03:12 +02:00
64d9816eac Fix response truncation in examples/notebooks/gpt2-sentiment.ipynb (#1957) 2024-08-22 16:22:46 +02:00
67564fdbbe "help wanted" in label to exempt from stale (#1956)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-22 11:27:37 +02:00
e529579232 Fix global step for consistent checkpointing with global updates (#1950) 2024-08-21 10:19:37 +02:00
dc4cfab700 Log WandB tables on main process (#1951) 2024-08-20 16:42:51 +02:00
66d3a82dd2 Add a simple-to-understand example for online DPO (#1947)
* Update online_dpo_trainer.md

* Update docs/source/online_dpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/online_dpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/online_dpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/online_dpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update online_dpo_trainer.md

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-08-20 16:14:40 +02:00
3eda856371 Don't mark issues as stale if nobody answered (#1949)
* don't mark issues as stale if nobody answered

* refactor

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-20 15:13:40 +02:00
616a273ac2 Fix model wrapping for online DPO (#1946) 2024-08-19 18:17:11 +02:00
9955583829 Drop token arg in push_to_hub (#1945)
* Skip token in `push_to_hub`

* fix doc

* move comment

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-19 11:34:11 +02:00
bed205a2d2 Properly tag models when pushed to 🤗 Hub (#1940)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-18 11:16:27 +02:00
42933fa647 Optional Additional Loss to Center Reward Models' Outputs (#1932)
* Implemented Eisenstein reward model centering

* Forgot self in accessing args

* Added docstring for center_rewards_coefficient.

* Fixed bug.

* Update trl/trainer/reward_config.py

Added a reference.

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Switched to Quentin's suggestion

* Update trl/trainer/reward_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* doc

* 0.01

* style

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-17 22:44:03 +02:00
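
The auxiliary term, sketched in plain PyTorch; the `center_rewards_coefficient` name and the 0.01 value come from the commit bullets:

```python
import torch
import torch.nn.functional as F

rewards_chosen = torch.randn(8)
rewards_rejected = torch.randn(8)
main_loss = -F.logsigmoid(rewards_chosen - rewards_rejected).mean()
# Centering penalty: keep the reward scale anchored around zero.
aux_loss = 0.01 * torch.mean((rewards_chosen + rewards_rejected) ** 2)
loss = main_loss + aux_loss
```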
bbdef00961 Fix model to save in PPOv2 (#1776)
* fix model to save in ppov2

currently saving self.backup_model, but this should be self.model;
self.backup_model is only a temporary model used to store the policy and
value function, whereas self.model holds just the policy to save

* simplified logic

* remove unused ordereddict

* format

* fix the fix

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-17 17:47:01 +02:00
0956dc17cc Add tests for DPO for VLM (#1935)
* add dpo visual test

* skip last layer of llava in test

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-16 16:29:40 +02:00
a7dc892717 Anchored preference optimization loss for DPO (#1928)
* feat: anchored pref optimization

* Update trl/trainer/dpo_trainer.py

* format and properly deprecate loss_type

* add aot in error message and reorder

* add "sppo_hard", "nca_pair" in label_smoothing warning warning

* add tests

* doc

* doc fixes

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-14 17:37:49 +02:00
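
A hedged sketch of the anchored loss in its "zero" setting, following the APO formulation; the exact `loss_type` strings live in `DPOConfig` and may differ by version:

```python
import torch

def apo_zero_loss(chosen_logratios, rejected_logratios, beta=0.1):
    # Anchor at zero: push chosen log-ratios up and rejected log-ratios down.
    losses_chosen = 1 - torch.sigmoid(beta * chosen_logratios)
    losses_rejected = torch.sigmoid(beta * rejected_logratios)
    return (losses_chosen + losses_rejected).mean()

loss = apo_zero_loss(torch.randn(4), torch.randn(4))
```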
b0372e66a5 Improve DPO/loss doc (#1929)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-14 16:52:26 +02:00
c1b272f4a6 minor BCO fixes (#1923)
* checkpointing BCO UDM classifier

* kto_config remove unused parameters

* BCO fix loading

* kto_config remove unused parameters

* kto_config remove unused parameters

---------

Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-08-14 15:27:13 +02:00
f05f63c1ea PartialState().local_main_process_first() when map in examples (#1926)
* `PartialState().local_main_process_first()` when map in examples

* allow load from cache

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-14 12:01:03 +02:00
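
What the pattern does: only the local main process runs the `map` (and writes the cache) first, and the remaining ranks then load from the cache instead of recomputing. A minimal sketch:

```python
from accelerate import PartialState
from datasets import Dataset

ds = Dataset.from_dict({"text": ["a", "bb"]})
with PartialState().local_main_process_first():
    ds = ds.map(lambda ex: {"n_chars": len(ex["text"])})
```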
54f806b6ff Standardize dataset_num_proc usage (#1925)
* uniform dataset_num_proc

* num_proc in shuffle

* Update examples/datasets/anthropic_hh.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/ppo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/ppo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-08-13 15:10:39 +02:00
a9a756553f Add explicit library name for TRL repos (#1922) 2024-08-13 11:36:01 +02:00
96bb3deb32 fix orpo trainer loss device (#1919) 2024-08-12 15:55:23 +02:00
dbea3da917 torch.cuda.amp.autocast() -> torch.amp.autocast("cuda") (#1921)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-12 14:43:38 +02:00
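
The migration in miniature: `torch.cuda.amp.autocast()` is deprecated in favor of the device-agnostic `torch.amp.autocast`, which takes the device type as its first argument.

```python
import torch

model = torch.nn.Linear(4, 4).cuda()
x = torch.randn(2, 4, device="cuda")
with torch.amp.autocast("cuda"):  # replaces torch.cuda.amp.autocast()
    y = model(x)
```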
150a93101b lr_scheduler.step() call after optim.step() (#1918)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-12 14:21:50 +02:00
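
The ordering this fix enforces, on a toy loop; recent PyTorch warns if `lr_scheduler.step()` is called before `optimizer.step()`:

```python
import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=10)

for _ in range(3):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    opt.step()      # update parameters first...
    sched.step()    # ...then advance the learning-rate schedule
    opt.zero_grad()
```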
cbcaa46cd3 Various args and test fix (#1909)
* report to none

* simplify AlignPropTrainerTester

* rm unused marker

* Don't share setup in dpo trainer

* style

* don't share setup in test rich

* fix setup and classmethod

* fix args for sft

* test_trainer_args

* various arg fix

* report to none and vsft simplification

* drop generate_during_eval

* fix run_name

* style

* drop setUpClass

* style

* new ref values for ppo trainer tester

* update ref val

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-09 10:07:58 +02:00
e3fe28ee1a Fix AlignPropTrainer import (#1908) 2024-08-07 11:33:11 +02:00
fb0b9edc24 Fix GPT2 sentiment notebook reward (#1738)
* Fix reward change

* clean up notebook

* fix eval metric

* regenerate output with correct model

* swap wrong operation order

* Update gpt2-sentiment.ipynb

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-08-06 22:19:05 +02:00
fc76fe8d11 [Online-DPO] num_generation_per_prompt is fixed (#1898)
* num_generation_per_prompt is fixed

* remove unused no_grads

* removed bin

* fix scores

* fix scores

* formatting

* undo
2024-08-06 18:21:35 +02:00
b60ce797d8 Support Rank Stabilized LoRA in the ModelConfig/LoraConfig (#1877)
* feat: support RS-LoRA in the ModelConfig

* build: bump minimum peft version to support rslora

* test: add test for get_peft_config

* test: make test python 3.8 friendly

* rm unused marker

* minor changes

* simplify, clarify doc

* update deps (peft in test)

* re-ordering

* fix setup

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-06 18:02:59 +02:00
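
A sketch of the new knob (the `use_rslora` field name is from the PR; it maps to peft's `LoraConfig(use_rslora=True)`, which scales adapters by `lora_alpha / sqrt(r)` instead of `lora_alpha / r`):

```python
from trl import ModelConfig, get_peft_config

model_config = ModelConfig(use_peft=True, lora_r=16, lora_alpha=32, use_rslora=True)
peft_config = get_peft_config(model_config)  # returns a peft LoraConfig
```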
6faf4c0d81 [RPO] use loss from v3 of paper (#1904)
* RPO loss from v3

* Update trl/trainer/dpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* fix docs

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-08-06 16:28:46 +02:00
29bd0046a9 fix process orpo example (#1903)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-06 12:57:11 +02:00
4867c2a3db Support IterableDataset for SFTTrainer (#1899)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-05 18:04:17 +02:00
332062372d Drop setUpClass in reward tester (#1895)
* drop setUp class in reward tester

* report to none

* style

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-05 16:01:43 +02:00
b580e45c94 [WIP] Drop save/load test on windows (#1897)
* just test modelling

* Trigger CI

* always trigger

* only test_from_save_trl

* parametrize

* just one model

* file

* rm ref model

* assert exists

* style

* Update Makefile

* Update tests.yml

* Update Makefile

* Update test_modeling_value_head.py

* Update test_modeling_value_head.py

* skip windows

* skip test_from_save_transformers

* also skip test_from_save_trl

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-05 16:01:06 +02:00
2004d62c5c fix serialization of RunningMoments on multiple GPUs (#1892)
Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-08-04 10:57:28 +02:00
ac7c8b1284 evaluation_strategy -> eval_strategy (#1894)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-02 16:01:35 +02:00
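
The rename in one line; `evaluation_strategy` is deprecated in transformers in favor of `eval_strategy`:

```python
from transformers import TrainingArguments

args = TrainingArguments(output_dir="out", eval_strategy="steps", eval_steps=500)
```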
df12913602 Fix SFT for VLM example (#1865)
* fix vsft example commands

* fix use_cache and get tokenizer from processor

* rm unused AutoTokenizer

* Squashed commit of the following:

commit 8bd2ab82f4cedc8b3459126aa145c63180078392
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Sun Jul 28 14:06:19 2024 +0200

    Refactor judges (#1856)

    * BaseJudge -> BasePairwiseJudge

    * hf judge asyncio

    * refactor judges

    * doc

    * doc

    * doc

    * member judge

    * :inherited-members:

    * :inherited-members:

    * doc

    * give up

    * judge tldr with judge class

    * fix rank in multithread

    * format

    * improve doc

    * update doc

    * typo doc

    * doc online dpo

    * Update judge_tldr.py

    ---------

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 82b07d6b0169bb8150f2fa4ee0a58b678d597163
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Fri Jul 26 11:43:48 2024 +0200

    Llama in modelling value head tests (#1878)

commit 72bf6c21beedd95b1deb1ff95bd4d1bad5380503
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Fri Jul 26 11:33:07 2024 +0200

    Skip BigBird save and load test until next transformers version (#1874)

commit 74e54b5946b3e46c9fef516b6f5403943c7c4096
Author: Edward Beeching <edbeeching@users.noreply.github.com>
Date:   Fri Jul 26 09:36:25 2024 +0200

    fix online dpo example (#1879)

commit 393097356c3494a1310cd59b0205358723468443
Author: Rishav Dash <57321948+Rishav-hub@users.noreply.github.com>
Date:   Thu Jul 25 14:17:37 2024 +0530

    Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM (#1861)

    * Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM

    Added `dataset_text_field` in the SFTConfig while training

    * Update docs/source/sft_trainer.mdx

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit db8e09e3463837d6f80d593f2806c0d83d97c787
Author: Rishav Dash <57321948+Rishav-hub@users.noreply.github.com>
Date:   Thu Jul 25 14:06:57 2024 +0530

    Import missing `setup_chat_format` (#1862)

commit 1dae55f90f6e929500df4fc4ee5bbc0146e35574
Author: elie <97572401+eliebak@users.noreply.github.com>
Date:   Thu Jul 25 10:27:34 2024 +0200

    add fsdp_qlora config and bnb_4bit_quant_storage (#1863)

commit c8cef79e6c895c9950ad7af61897f3a89372c56d
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Jul 24 21:06:57 2024 +0200

    arXiv to HF Papers (#1870)

commit 7dcf437a1997cb1b252e8ea0b8ad7dca13261d7e
Author: Kashif Rasul <kashif.rasul@gmail.com>
Date:   Wed Jul 24 12:27:50 2024 +0200

    [online-DPO] online dpo cleanups (#1864)

    * online dpo cleanups

    * remove unused self.policy

    * add OnlineDPOTrainer and config to __init__.py

    * import from trainer

    * online dpo test

    * rename policy to model and ref_policy to ref_model

    * renamed internally

    * formatting

commit 4e85bd75a9dfca0074eef3a90130054c283eed39
Author: Costa Huang <costa.huang@outlook.com>
Date:   Thu Jul 18 14:35:31 2024 -0400

    Online DPO and Online trainer refactor (#1809)

    * online dpo trainer based on rloo trainer

    * push changes

    * refactor

    * use `batch_generation` method

    * precommit

    * remove breakpoint()

    * quick refactor

    * push the current changes

    * quick change

    * refactor

    * use the config name as the experiment name

    * fix logging

    * update online DPO docs

    * push docs

    * increment global step so tensorboard works again.

    * precommit

    * remove unused common online trainer

    * add online DPO docs

    * quick refactor

    * push changes

    * Update docs/source/online_dpo_trainer.md

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

    ---------

    Co-authored-by: Michael Noukhovitch <mnoukhov@gmail.com>
    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

commit c9d56366ede5990d690f3b7a3f249c434f3633d6
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Thu Jul 18 18:28:49 2024 +0200

    rm token (#1852)

* add section in doc

* Squashed commit of the following:

commit 890232fa2861c40d46adeaf975a4209eb04fe841
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Tue Jul 30 14:29:47 2024 +0200

    update example overview (#1883)

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 9929370dee9975f1c6d80b32198ea3e7fd0dcc06
Author: Clara Pohland <54847419+claralp@users.noreply.github.com>
Date:   Sun Jul 28 21:10:08 2024 +0200

    Move BCO to separate BCOTrainer with fixes (#1869)

    * kto_trainer: skip KL data for BCO

    * kto_trainer: BCO allow no positives or no negatives in batch

    * kto_trainer: make RunningMoments object serializable

    * add BCOTrainer

    * fix BCO UDM for not interleaved data

    * kto_trainer: remove unused UDM part

    * bco_trainer: add tests and docs, minor fixes

    * code style fixes

    * Update docs/source/bco_trainer.mdx

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * fix BCO UDM for bfloat16

    * Update trl/trainer/bco_config.py

    * Update trl/trainer/bco_config.py

    Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

    * Update trl/trainer/utils.py

    Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

    * Update trl/trainer/bco_trainer.py

    Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

    * Update trl/trainer/bco_config.py

    * Update _toctree.yml

    * Update trl/trainer/bco_config.py

    * Update trl/trainer/bco_trainer.py

    * RunningMoments, fix multi GPU serialization

    * fix tests

    ---------

    Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
    Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

commit 6171cddee5165869af8b40b526476680cebe47ef
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Sun Jul 28 15:51:38 2024 +0200

    Re-add BigBird Pegasus save/load test (#1882)

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 33d2151f4fa37728fea9448420301a1380fee745
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Sun Jul 28 15:07:10 2024 +0200

    Re-add BigBird Pegasus save/load test (#1876)

    * skip bigbird in ci

    * readd big bird test

    * pytest parametrize

    * dont check the version

    * rm model name

    * re add big bird

    * Merge branch 'main' into readd-bigbird-save-load-test

    ---------

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

* simplify script

* doc

* use training args

* args instead of training args

* fix doc

* drop eval

* rm eval section

* re-add bigbird

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-02 10:31:51 +02:00
ddf4c8dc3e fix dpo_trainer bug for LLMs without bos_token in config (#1885)
* fix dpo_trainer bug for LLMs without bos_token in config

* fix adding bos_token_id bug in dpo,orpo,cpo trainers

* formatting for fixing bos_token adding bug

* Update trl/trainer/utils.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-07-31 12:42:06 +02:00
890232fa28 update example overview (#1883)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-07-30 14:29:47 +02:00
9929370dee Move BCO to separate BCOTrainer with fixes (#1869)
* kto_trainer: skip KL data for BCO

* kto_trainer: BCO allow no positives or no negatives in batch

* kto_trainer: make RunningMoments object serializable

* add BCOTrainer

* fix BCO UDM for not interleaved data

* kto_trainer: remove unused UDM part

* bco_trainer: add tests and docs, minor fixes

* code style fixes

* Update docs/source/bco_trainer.mdx

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* fix BCO UDM for bfloat16

* Update trl/trainer/bco_config.py

* Update trl/trainer/bco_config.py

Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

* Update trl/trainer/utils.py

Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

* Update trl/trainer/bco_trainer.py

Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

* Update trl/trainer/bco_config.py

* Update _toctree.yml

* Update trl/trainer/bco_config.py

* Update trl/trainer/bco_trainer.py

* RunningMoments, fix multi GPU serialization

* fix tests

---------

Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Seungjae Jung <seanexplode@gmail.com>
2024-07-28 21:10:08 +02:00
6171cddee5 Re-add BigBird Pegasus save/load test (#1882)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-07-28 15:51:38 +02:00
33d2151f4f Re-add BigBird Pegasus save/load test (#1876)
* skip bigbird in ci

* readd big bird test

* pytest parametrize

* dont check the version

* rm model name

* re add big bird

* Merge branch 'main' into readd-bigbird-save-load-test

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-07-28 15:07:10 +02:00
8bd2ab82f4 Refactor judges (#1856)
* BaseJudge -> BasePairwiseJudge

* hf judge asyncio

* refactor judges

* doc

* doc

* doc

* member judge

* :inherited-members:

* :inherited-members:

* doc

* give up

* judge tldr with judge class

* fix rank in multithread

* format

* improve doc

* update doc

* typo doc

* doc online dpo

* Update judge_tldr.py

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-07-28 14:06:19 +02:00
82b07d6b01 Llama in modelling value head tests (#1878) 2024-07-26 11:43:48 +02:00
72bf6c21be Skip BigBird save and load test until next transformers version (#1874) 2024-07-26 11:33:07 +02:00
74e54b5946 fix online dpo example (#1879) 2024-07-26 09:36:25 +02:00
393097356c Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM (#1861)
* Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM

Added `dataset_text_field` in the SFTConfig while training

* Update docs/source/sft_trainer.mdx

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-07-25 10:47:37 +02:00
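
The gist of the fix, as a hedged sketch: pass `dataset_text_field` on the `SFTConfig` so the trainer knows which dataset column holds the raw text.

```python
from trl import SFTConfig

config = SFTConfig(output_dir="sft-out", dataset_text_field="text")
```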
db8e09e346 Import missing `setup_chat_format` (#1862) 2024-07-25 10:36:57 +02:00
1dae55f90f add fsdp_qlora config and bnb_4bit_quant_storage (#1863) 2024-07-25 10:27:34 +02:00
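
A sketch of the FSDP+QLoRA quantization config this enables; the parameter names are the standard `BitsAndBytesConfig` ones, and the storage dtype should match the dtype FSDP shards in:

```python
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,  # storage dtype for the 4-bit weights
)
```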
c8cef79e6c arXiv to HF Papers (#1870) 2024-07-24 21:06:57 +02:00
7dcf437a19 [online-DPO] online dpo cleanups (#1864)
* online dpo cleanups

* remove unused self.policy

* add OnlineDPOTrainer and config to __init__.py

* import from trainer

* online dpo test

* rename policy to model and ref_policy to ref_model

* renamed internally

* formatting
2024-07-24 12:27:50 +02:00
4e85bd75a9 Online DPO and Online trainer refactor (#1809)
* online dpo trainer based on rloo trainer

* push changes

* refactor

* use `batch_generation` method

* precommit

* remove breakpoint()

* quick refactor

* push the current changes

* quick change

* refactor

* use the config name as the experiment name

* fix logging

* update online DPO docs

* push docs

* increment global step so tensorboard works again.

* precommit

* remove unused common online trainer

* add online DPO docs

* quick refactor

* push changes

* Update docs/source/online_dpo_trainer.md

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Michael Noukhovitch <mnoukhov@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-07-18 14:35:31 -04:00
c9d56366ed rm token (#1852) 2024-07-18 18:28:49 +02:00
4dce042a38 Add WinRateCallback and Judges (#1598)
* Add WinRateCallback

* Enable PairRM

* Refactor

* Streamline

* Add HF judge

* Add base judge

* Use better prompt

* Clean

* Add max tokens

* Use logging

* Add batched inference

* Squashed commit of the following:

commit 9e9dc96e676a3601882b5cf11842bd22267fd2c5
Author: Maxim Kopecki <kopecki.maxim@gmail.com>
Date:   Wed Jul 10 19:11:13 2024 +0200

    Added missing token kwarg in Peft model loading (#1825)

commit 7ddef5c1582f14f32b6dd692f8e4b904fd478038
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Jul 10 18:26:11 2024 +0200

    Make use of `trust_remote_code` consistent (#1806)

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit a9cddf8c55a0b2af101a3d18bd92f263f4ae4500
Author: Adnan Khan <AdnaneKhan@users.noreply.github.com>
Date:   Wed Jul 10 11:25:07 2024 -0400

    Delete unused benchmark.yml workflow. (#1822)

commit 2860ce5091e689bab167454453e9ddbe2337de3d
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Tue Jul 9 09:22:52 2024 +0200

    DPO Llava 1.5 and PaliGemma support (#1797)

    * llava support dpo

    * add_special_tokens=False only when possible

    * format

    * pali gemma

    * refactor size

    * remove image resize

    ---------

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 30e33bd92da1f5569493e16da8971247cc376927
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Tue Jul 9 05:37:12 2024 +0200

    upgrade gh actions (#1818)

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit d5a0d2d345ec26646ceaa06adfe6133aad18702a
Author: Costa Huang <costa.huang@outlook.com>
Date:   Mon Jul 8 11:12:41 2024 -0400

    Set dev version (#1817)

commit 314e8eb367cbfaf74c2e9717085346360e779508
Author: Puneet Singh Bhooi <puneetb@iiitd.ac.in>
Date:   Mon Jul 8 19:11:36 2024 +0530

    fix broken url in `docs\source\index.mdx` (#1813)

commit e10792032be644a65dcbcf2ebe9ec947497d4d46
Author: Costa Huang <costa.huang@outlook.com>
Date:   Mon Jul 8 09:38:09 2024 -0400

    0.9.6 release (#1816)

commit 78045dedc8678af04f4e35ffe63f37be196a435b
Author: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date:   Mon Jul 8 01:59:26 2024 +0200

    Fix `TRL_USE_RICH` environment variable handling (#1808)

    * Add `strtobool` custom implementation from `distutils`

    * Fix `TRL_USE_RICH` handling via `strtobool`

    * Run `make precommit`

commit 747612f9d3063de56b6524e5feb0c9feab21d4c4
Author: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date:   Fri Jul 5 16:28:59 2024 +0200

    Fix `torch_dtype` handling in `{DPO,SFT}Trainer` when provided via CLI (#1807)

    * Fix `torch_dtype` handling through CLI

    The `torch_dtype` is not properly handled when provided via the TRL CLI,
    since it arrives as a string but is then cast to `torch.dtype` before
    being passed to the `{DPO,SFT}Trainer`, which means those trainers must
    also handle the scenario where `torch_dtype` is a `torch.dtype`.

    * Add `torch_dtype` tests in `test_{dpo,sft}_trainer.py`

    * Forward contribution credits

    * Run `make precommit`

    ---------

    Co-authored-by: Tash Srivastava <yash-srivastava19@users.noreply.github.com>

commit 9e3a35bd3d85ee506d180120f01bde2229b60265
Author: Michael <mnoukhov@gmail.com>
Date:   Fri Jul 5 07:29:48 2024 -0400

    Remove extra print in reward_trainer.py (#1799)

    `print_rich_table` is called twice and the first call doesn't restrict to `num_print_samples`. Remove the first, extra call

commit 4402b36dcf79a0921a858c77375cfbb285d603c7
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Thu Jul 4 14:29:25 2024 +0200

    clean examples (#1791)

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 78f8228874d5cf9c0e68952533cb377202e1eb22
Author: Noah Tye <hi@noahtye.com>
Date:   Wed Jul 3 11:10:50 2024 -0700

    Bugfix: Preserve token fields when converting TrainingArguments to SFTConfig (#1794)

    * Preserve token fields when converting TrainingArguments to SFTConfig

    TrainingArguments.to_dict() redacts token fields, so we have to
    individually copy them over when converting to SFTConfig to avoid
    breaking push_to_hub functionality.

    Also adds a test.

    * run precommit

    * one-line args_as_dict definition per suggestion from kashif

    * generalize token copying to match TrainingArguments behavior

    * unwrap |= on dict, to support python 3.8

    * use .update instead of |= or for-loop

commit b6af2edc93b275afcee22a3eb71f9a5702ff9fd8
Author: Kashif Rasul <kashif.rasul@gmail.com>
Date:   Wed Jul 3 08:29:16 2024 +0200

    add model_init_kwargs to training_args (#1787)

commit cd85b14fbbaf7e4d9b01ef8ec19655666af20047
Author: Tommaso Buonocore <buonocore.tms@gmail.com>
Date:   Sat Jun 29 15:35:48 2024 +0200

    Fixed typo in SFT trainer docs (#1788)

    'STFConfig' instead of 'SFTConfig' appears multiple times in the doc, causing errors when running the code snippets.

commit a57544f47a2fbc4940b4d49dde32f54406398c91
Author: Kashif Rasul <kashif.rasul@gmail.com>
Date:   Thu Jun 27 15:47:58 2024 +0200

    fix docs and examples (#1780)

commit b68ff96f0c74368961e194081e122959cd1f4d4d
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Jun 26 16:26:37 2024 +0200

    Visual DPO (#1647)

    * Remove extra whitespaces

    * idefics

    * vdpo

    * sft idefics

    * pad with test

    * use prompt instead of tokenizer

    * rm name main

    * support vlm in tokenize row

    * temp fix for regex in lora_target_module

    * format

    * vdpo

    * tmp float16 hard code

    * concatenated_forward support for vision

    * style and new command line

    * all-linear

    * format

    * delete old examples

    * get image

    * upcast

    * new test

    * modified test

    * new strat for tokenizer

    * rm token transfer

    * integrate vision in dpo example

    * format

    * add FDivergenceType back

    * precommit

    * pillow test dep

    * optional prompt

    * `evaluation_strategy` to `eval_strategy`

    * revert vsft change (oos)

    * update test

    * test

    * comment and support more in process

    * update process

    * update doc for vdpo

    * caution about limited support

    * Update docs/source/dpo_trainer.mdx

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * revert DPO example changes

    * cleaner way to check if a model is vision

    * comment

    * update vdpo example

    * rename

    ---------

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit c8c01cc05569f5ffea6726b2111f799a63e03aaa
Author: Mubin Manasia <48038715+Mubin17@users.noreply.github.com>
Date:   Wed Jun 26 03:23:36 2024 -0600

    Fix Documentation Overflow Issues for Long URLs in SFTConfig (#1774)

    * Update sft_config.py

    * Update sft_config.py

commit 3479606c8c6dbb5da96e4990b491e63a48fc7483
Author: Costa Huang <costa.huang@outlook.com>
Date:   Wed Jun 26 03:18:22 2024 -0400

    Remove the leading space in the tldr preference dataset (#1773)

commit 7965b7834052ab3d60a1cc5de382e2f56b3772e7
Author: Haozhe Ji <jihaozhe@gmail.com>
Date:   Tue Jun 25 22:47:32 2024 +0800

    add Efficient Exact Optimization (EXO) (#1735)

    * add exo

    * fix a detail

    * Update trl/trainer/dpo_trainer.py

    * Update trl/trainer/dpo_trainer.py

    * Update trl/trainer/dpo_trainer.py

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit 56bd1bba26ac52aad976c1a1a0b3d9e1137b18c7
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Tue Jun 25 16:14:26 2024 +0200

    `evaluation_strategy` to `eval_strategy` (#1771)

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 94d53e6617edc6434a38b2ac51c21e5da3329cda
Author: Clara Pohland <54847419+claralp@users.noreply.github.com>
Date:   Mon Jun 24 21:27:00 2024 +0200

    MoE Models: option to add load balancing loss (#1765)

    * KTO: add aux loss

    * use router_aux_loss_coef in KtoTrainer when aux_loss enabled

    * align optional aux_loss in DPO, KTO, CPO, ORPO

    * precommit changes

    * fix KL forward kwargs

    * add aux_loss doc entry

    * apply docs suggestions

    ---------

    Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>

commit b5be100ae0b37d743cd49435297f917eb54a0574
Author: Mihir Prabhudesai <mihirp1998.mp@gmail.com>
Date:   Mon Jun 24 12:05:44 2024 -0400

    Added Reward Backpropagation Support (#1585)

    * added alignprop template

    * added alignprop support

    * Update alignprop_trainer.mdx

    * Update alignprop_trainer.mdx

    * added better why statement

    * fixed inference code

    * changed self to pipeline

    * removed aesthetic classifier

    * added aesthetic to auxiliary models

    * added unseen prompt logging

    * removed unseen prompt log

    * fixed minor

    * remove not needed import in trl/__init__.py

    Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

    * fixed styling

    * updated _toctree

    ---------

    Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

commit 6e1652bc5e8ff6d348c7f06048f4102a050f1544
Author: Haoran Xu <45837851+fe1ixxu@users.noreply.github.com>
Date:   Sun Jun 23 09:54:30 2024 -0700

    Add CPO-SimPO method (#1760)

    * enable cpo-simpo

    * highlight SimPO and CPO-SimPO

    * add test for cpo_alpha

    * formatting

    * Update docs/source/cpo_trainer.mdx

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit 65374c6a711709157ea59297dce43dfb458d1c78
Author: Costa Huang <costa.huang@outlook.com>
Date:   Fri Jun 21 11:20:54 2024 -0400

    New sentiment and descriptiveness dataset (#1757)

    * push changes

    * handle edge cases where the chosen and the rejected are the same

commit 99560911123f739226b77813f27d5c90ed7f9ba2
Author: Juyoung Suk <scottsuk0306@gmail.com>
Date:   Fri Jun 21 18:01:08 2024 +0900

    Add dataset_text_field in examples/scripts/sft.py (#1758)

commit 34d273f227b30507c6d94ff1f93b6939794f38a3
Author: Costa Huang <costa.huang@outlook.com>
Date:   Thu Jun 20 13:16:43 2024 -0400

    Support num_train_epochs (#1743)

    * add a test case for num_train_epochs

    * fix ci

    * quick change

    * disable push to hub

    * debug windows ci

    * try another fix

    * skip subprocess tests on windows

commit 3bf94492a8dc84ac192f7c5206553e1460f53aa4
Author: Mert Sayar <mert.sayar@gmail.com>
Date:   Thu Jun 20 18:22:20 2024 +0300

    Fix masking of response tokens (#1718)

    The current handling of `response_masks` inside the `batch_forward_pass`
    function does not take padding into account, which results in a shape
    mismatch during masking. Since the response mask is a mask tensor over
    response tokens, it should not be concatenated with
    `torch.zeros(query_length)`, and the masking operation should be done
    without slicing.

    Remove the concatenation of the response mask, and remove the slicing
    from it, since the response mask already has length `end - start + 1`,
    equal to the length of `masks[j, start:end]`.

commit ba6abee37f0f0463f6d891d63d0c2242039fc8ec
Author: idanshen <49375140+idanshen@users.noreply.github.com>
Date:   Thu Jun 20 09:14:16 2024 -0400

    Support for returning past_key_values from the model (#1742)

    * add support for returning past_key_values from the model

    * change order of  keys

commit a57e75967c2b787f42f4e402ed7ca23cd9bad9a9
Author: 1485840691 <110707330+1485840691@users.noreply.github.com>
Date:   Wed Jun 19 18:02:51 2024 +0800

    Integrate f-divergence to DPO (Follow up) (#1610)

    * Step 1: update ppo_trainer and hello_world example

    * Step 2: Refine comments and add parameter type

    * Step 2: Add missing parameter comments

    * Step 1: Organize ptx loss into a function and add ptx_loss to train_stats

    * Step 1 updates: add comment to ptx_loss function, fix a bug and add warning message

    * Step 2: 1) Add ppo_ptx training example as ppo; 2) separate pretrain data fetch and iterate

    * Step 2: Remove loss from columns_to_log in ppo_ptx example

    * Remove dataset revision in load imdb dataset

    * Run pre-commit and fix format issues

    * Initial draft of f-divergence fn

    * Update f-divergence to avoid overflow

    * fix test errors and comments

    * Add Unit tests for dpo loss with alpha and js div f

    * Adjust format

    * Fix test error

    * Reverse this update

    * Add test cases

    * Reverse un-needed updates

    * Update code style

    * Try to fix code fmt error

    * remove extra end line

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit ae23d40f3b4d91d60a6153825ecf0319449d34b1
Author: Shihyueh Hsu <66808901+AIR-hl@users.noreply.github.com>
Date:   Tue Jun 18 22:07:24 2024 +0800

    change the `process` function in the example of DPO (#1753)

    * change the `process` function in the example of DPO

    * fix

commit 83b367b11a308b488ff9ddcf19cf4cfd6a7db642
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Tue Jun 18 11:31:17 2024 +0200

    CI / `KTOTrainer`: Remove old tests (#1750)

    * remove old tests

    * remove datasets

    * Update test_dpo_trainer.py

    * Update test_dpo_trainer.py

commit d1ed730ab8281b1b0c78d7d61bc0f6603a9ce958
Author: Michael <mnoukhov@gmail.com>
Date:   Mon Jun 17 10:50:21 2024 -0400

    prepare deepspeed accommodate fp16 and bf16 (#1728)

    * prepare deepspeed accommodate fp16 and bf16

    * precommit

commit 8f8e95e25d10c433cc1f2f8c7dcfed218bb13ac7
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Mon Jun 17 16:49:00 2024 +0200

    CPO / DPO: Fix red CI (#1749)

    * fix red CI

    * precommit

commit 4e23d958f20fd4fdd795cb06c2cdb7ebea704855
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Mon Jun 17 16:41:36 2024 +0200

    fix red CI

commit 50c46205b6fe741f11959adf7ec9cc0386f406bc
Author: Kawin <kawin.ethayarajh@gmail.com>
Date:   Mon Jun 17 07:14:44 2024 -0700

    small KTO fixes (#1734)

    * add warning for imbalanced data

    * update documentation

    * update script commands to be same as in dpo

    * use batch_size KL examples and batch_size target examples to calculate batch_size losses

    * fix deepspeed issue

    * speed up forward with no_grad for KL

    * add some removed metrics

    * Update trl/trainer/kto_trainer.py

    * Update trl/trainer/kto_trainer.py

    * Update trl/trainer/kto_trainer.py

    add reference to paper

    Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * add more detailed comments

    * convert assert to ValueError

    * Update kto_trainer.py

    * precommit formatting

    * remove nans in metrics by gathering across machines

    * fix formatting

    * fix choice of mismatched examples for KL term

    * describe weights

    * fix hanging issue in distributed training

    * linting

    * move metrics to cpu

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

    * Update trl/trainer/kto_trainer.py

    * Update trl/trainer/kto_trainer.py

    * remove kto_pair

    * speed up data processing

    * move bco code inside

    * raise error for kto_pair argument

    * fix formatting

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
    Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
    Co-authored-by: Winnie Xu <winnie.xu97@gmail.com>

commit 6105d03f92e7069ffaa565d05418dec371569e6a
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Mon Jun 17 16:01:06 2024 +0200

    `TrlParser`: Add ignore extra args option (#1748)

    * add ignore extra args option

    * Update trl/commands/cli_utils.py

commit e247bbd7d5f57f8012ca71cfef6ad6a589874c34
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Mon Jun 17 15:16:07 2024 +0200

    CI / core: Pin `numpy` to `!=2.0.0` for CI and to users (#1747)

    * Update setup.py

    * Update setup.py

    * Update setup.py

    * Update test_best_of_n_sampler.py

    dummy commit

    * pin numpy

    * Update tests/test_best_of_n_sampler.py

    * Update setup.py

commit 3d044961960a2ab1ec1f51cfe62c6bf6b9a94807
Author: Michael <mnoukhov@gmail.com>
Date:   Mon Jun 17 08:43:33 2024 -0400

    better trl parser with yaml config (#1739)

    * working trl parser with config

    correctly overrides yaml config with command line arguments
    adds return_remaining_strings
    when return_remaining_strings is False, raises error if yaml contains
    extra args that are not in the dataclasses
    simpler and cleaner than previous yaml parsing and merging
    addresses #1733

    * lowercase trlparser

commit 2d244f8acb204cb2ddb83a4ef017ca4b1f2d366a
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Mon Jun 17 11:56:13 2024 +0200

    Workflow: Notify tests results on slack channel (#1744)

    * Update tests-main.yml

    * Update docker-build.yml

commit f5168fdbaf9cbf6a3f1bdc64dc44b9db3a9ae333
Author: Igor Melnyk <igoraries@gmail.com>
Date:   Wed Jun 12 05:54:54 2024 -0400

    adds AOT (#1701)

    * adds AOT

    * Applied format changes

    * added docs and tests

    ---------

    Co-authored-by: Igor Melnyk <igor.melnyk@ibm.com>

commit 79686e1ac701b1f5e3709a65efa8f13363bcde06
Author: jetlime <paul.houssel@yahoo.de>
Date:   Wed Jun 12 00:35:31 2024 +1000

    ktotrainer: Refuse datasets which contain only one class of labels (#1724)

    * ktotrainer: refuse dataset which contain only one class of labels

    * ktotrainer: document new dataset constraint

commit 34ebc4ccaf376c862a081ff4bb0b7e502b17b2fb
Author: Luc Georges <McPatate@users.noreply.github.com>
Date:   Mon Jun 10 11:17:54 2024 +0200

    feat(ci): add trufflehog secrets detection (#1721)

    * feat(ci): add trufflehog secrets detection

    * fix(ci): remove unnecessary permissions

commit 1d84e2b888ea0f3c1ce9c5c175f7f680d85273a8
Author: Michael <mnoukhov@gmail.com>
Date:   Fri Jun 7 11:42:08 2024 +0200

    Fix default padding_value in dpo_config.py (#1692)

    dpo_config's default padding value should be None, not 0; otherwise it overrides any tokenizer's padding value with 0 by default

commit 2f71b8b1e2e54184cc278f267cca1bda051f68ea
Author: Michael <mnoukhov@gmail.com>
Date:   Fri Jun 7 10:37:27 2024 +0200

    fix yaml parser for derived config classes (#1713)

    fixes #1712
    reformatted cli_utils with ruff

commit 5bcb8ad0d6eaee1b1d2f993380100c37c4421fd0
Author: Kashif Rasul <kashif.rasul@gmail.com>
Date:   Fri Jun 7 08:48:17 2024 +0100

    RDPO fix nll loss (#1705)

commit b8b972fde183ec036885738e1439cd99877c2ad5
Author: Haoran Xu <45837851+fe1ixxu@users.noreply.github.com>
Date:   Thu Jun 6 14:06:47 2024 -0700

    Add a variant of CPO, SimPO (#1703)

    * add a variant of cpo: simpo

    * correct cpo-simpo loss

    * avoid 0 int error in logging

    * add simpo description

    * Update trl/trainer/cpo_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * fix formatting

    * add test for simpo

    * Update docs/source/cpo_trainer.mdx

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * add a docstring for simpogamma

    * move simpo description to the above docstring

    * change simpo description in the doc

    * formatting

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit 3eb9ccb104e2c46360adb937f3f25871c167eb90
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Thu Jun 6 19:33:20 2024 +0200

    set dev version (#1710)

    * Update setup.py

    * Update __init__.py

commit 974b0d380f12c357b70265c5f2dd2c8cb39a6a3e
Author: Costa Huang <costa.huang@outlook.com>
Date:   Thu Jun 6 10:13:00 2024 -0400

    0.9.4 release (#1708)

commit 39a7d1c121d26224fd7455d3d2038e0d20831c54
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Thu Jun 6 15:50:17 2024 +0200

    SFTTrainer: Fix backward Compatibility issue with `TrainingArguments` (#1707)

    * fix BC

    * fixup

commit 0bdc63839f1abe67c56befa63251425b1ffc1ace
Author: Guilherme Freire <guilhermebfreire@gmail.com>
Date:   Thu Jun 6 14:42:58 2024 +0100

    Fixed doc string and docs for the SFTConfig update (#1706)

commit 275d33b3ef4f7afd40f79cc53591659bacfa3499
Author: Costa Huang <costa.huang@outlook.com>
Date:   Wed Jun 5 14:34:59 2024 -0400

    0.9.3 release (#1699)

commit c0819ee99fdf673e9843ef91789b928ae9050623
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Wed Jun 5 17:29:03 2024 +0200

    Update sft_trainer.py (#1698)

commit a03e7cc4e443e30eea942ca66bfce19407784f32
Author: Costa Huang <costa.huang@outlook.com>
Date:   Wed Jun 5 11:00:19 2024 -0400

    Release 0.9.2 (#1697)

    * Release: 0.9.0

    * Release

commit a13cb8952c55cfa4fc696d900a1b2a81d329c82d
Author: Costa Huang <costa.huang@outlook.com>
Date:   Wed Jun 5 10:20:54 2024 -0400

    Quick fix on GPT4-eval (#1696)

    * quick fix

    * precommit

commit 84156f179f91f519e48185414391d040112f2d34
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Mon Jun 3 20:09:05 2024 +0200

    Fix typo in DPOTrainer's warnings (#1688)

commit 4eb0b905e28857341123d5329a6ca1b9d929734f
Author: Alex Brooks <alex.brooks@ibm.com>
Date:   Mon Jun 3 10:24:32 2024 -0600

    Skip packing validation (#1673)

    * Add test for skipping preproc if packing=True

    Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

    * Allow skipping of validation for packing=True

    Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

    * Use dummy dataset in no packing preproc test

    Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

    ---------

    Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

commit 6c203f9fef50c41d27fc4ed9965df7e458f02377
Author: Alexey Rozhkov <alexisrozhkov@gmail.com>
Date:   Mon Jun 3 10:16:22 2024 +0100

    Fix overriding optimize_device_cache with optimize_cuda_cache in PPOConfig (#1690)

    * Don't override optimize_device_cache when optimize_cuda_cache is not provided
    Raise an exception when both optimize_cuda_cache and optimize_device_cache are set

    * Minor fix

commit f18253bf2d747f68acc9cd89da95c85ebf59dbb9
Author: Kashif Rasul <kashif.rasul@gmail.com>
Date:   Mon Jun 3 09:43:02 2024 +0100

    intial RPO loss (#1686)

    * intial RPO loss

    * fix sign

    * clean up

commit 151a452d14c8ebccbaf8a033812ceb2dc77f634d
Author: Samuel <s.kiegeland@gmx.de>
Date:   Wed May 29 20:29:38 2024 +0200

    Fix max completion length (#1588)

commit 488b502d31c052801eacd9a047bf3db06623e9c2
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Wed May 29 20:19:26 2024 +0200

    fix (#1678)

commit 3c0a10b1aedbe533005dbfe18f2cc8057093f80b
Author: Wang, Yi <yi.a.wang@intel.com>
Date:   Mon May 27 20:52:20 2024 +0800

    fix dataset load error (#1670)

    Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

commit b031adfdb8708f1f295eab6c3f2cb910e8fe0c23
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Fri May 24 15:20:16 2024 +0200

    FIX / PPO: Fix `enable_input_require_grads` issues with PPO models (#1664)

    * Update modeling_base.py

    * Update ppo_config.py

    * Update ppo_trainer.py

    * style

commit e7cb597230bb0c630c67790881b0808f7b16cb05
Author: Costa Huang <costa.huang@outlook.com>
Date:   Thu May 23 11:37:16 2024 -0400

    Fix ppov2 test case (#1661)

    * Fix PPOv2 / RLOO refactor's stuff

    * update terminology to use stop token

commit bc8dfbf4e2169010b3094913a1fa4f888f750111
Author: Kashif Rasul <kashif.rasul@gmail.com>
Date:   Thu May 23 15:28:04 2024 +0200

    update eval_strategy (#1662)

commit e4ed7a3a5aa0f1e1b4f78317b3c7b25e5bf597f4
Author: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date:   Thu May 23 18:34:22 2024 +0530

    do not upcast adapters when using FSDP+QLoRA (#1654)

commit 9a7efbd05126fa6a1448a95f670e8d04cac90d62
Author: syrn1k <85796210+syrn1k@users.noreply.github.com>
Date:   Thu May 23 15:58:49 2024 +0300

    🤫 TR-DPO implementation (#1593)

    * 🤫 TR-DPO implementation baseline

    * fix comments

    * docs

    * fix linters

    * test added

    * move configs to DPOConfig

    * fix typo

    * add docs

    * fix import

    * use state.global_step

    * fix order of arguments

    * make sure plugins are not none

    * Update trl/trainer/utils.py

    Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

    * Update trl/trainer/utils.py

    Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

    * checking that reference model weights have changed

    * sync_target_model as staticmethod

    * set reference model

    ---------

    Co-authored-by: Nikita Surnachev <n.surnachev@tinkoff.ru>
    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
    Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

commit b344bcea2c0b30d58ab6ebb0380647f24056ac58
Author: Anush Kini <33577829+Abilityguy@users.noreply.github.com>
Date:   Thu May 23 18:27:25 2024 +0530

    [DPO] Add 'robust' loss_type (#1653)

    * Initial commit

    * pre-commit fix

    * Minor change to comments

    * Added some documentation on how to use Robust DPO

commit 35e12dc5959fa8a08edd72b34aadcb0acb284e51
Author: Nicolinho <Nicolinho@users.noreply.github.com>
Date:   Thu May 23 14:36:15 2024 +0200

    Fix inheritance order in PPOv2Config (#1659)

    * fix inheritance order in PPOv2Config

    * fix inheritance order in rloo_config

commit 1da6be18e0e21a11ee2a2121ae744c5e2e904409
Author: Ali Bakly <anbakly@gmail.com>
Date:   Thu May 23 14:10:29 2024 +0200

    docs: correct cDPO usage in DPOTrainer (#1655)

commit e249cd802fb81cff3c4ceb1427cb666a138221d3
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Thu May 23 14:10:05 2024 +0200

    add support for training collator (#1658)

commit a02513c3b7085adba5fd18727296f4f4affd3ffb
Author: Zach Mueller <muellerzr@gmail.com>
Date:   Thu May 23 06:48:00 2024 -0400

    Apply deprecated `evaluation_strategy` (#1559)

    * Deprecate

    * Update tests/test_dpo_trainer.py

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit 13454d2f4b243b7260fa4ec828297812c3f975fc
Author: Costa Huang <costa.huang@outlook.com>
Date:   Wed May 22 08:31:10 2024 -0400

    PPO / Reinforce Trainers (#1540)

    * Add ppov2 trainer

    * make eos trick optional, remove unused args

    * quick fix

    * precommit

    * update debugging script

    * fix out of bound `drop_last=True`; use built-in scheduler

    * Add PPO examples

    * push changes

    * quick change

    * quick change

    * various bug fixes

    * remove unnecessary grad accumulation setting

    * push new changes

    * fix DS3 model saving

    * update ppo.py

    * refactor

    * quick change

    * refactor

    * update ppo trainer

    * refactor

    * quick test

    * add ds2 /ds3 7 processes config

    * add vllm trainer

    * quick change

    * experiment with reward normalization

    * push changes

    * quick push

    * push changes

    * push various changes

    * refactor to use ModelConfig

    * quick change

    * refactor

    * refactor

    * Simplify DS logic

    * quick update

    * remove unnecessary files

    * precommit

    * deepspeed fix; handle edge case when eos_token_id = 0

    * add PPO tldr example

    * add TL;DR example

    * fix undefined var

    * utilize all samples in rloo

    * quick setting

    * remove the unnecessary `value_model`

    * use exact_div

    * allow saving the deepspeed model

    * refactor

    * remove dead code

    * Use some shared utilities

    * add some end-to-end test cases

    * add PPOv2 docs and RLOO docs / tests

    * update docs

    * quikc push

    * fix ci

    * fix type annotation for ci

    * quick update

    * update trainer docs

commit 99f2c94b2200927a1dc156f16e012dca11f865e1
Author: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date:   Wed May 15 19:55:46 2024 +0530

    don't cast the trainable lora layers to half precision (#1644)

    * don't cast the trainable lora layers to half precision

    * quality

commit 6401d080c9f97e0610678b12d3d0056347675726
Author: Wing Lian <wing.lian@gmail.com>
Date:   Tue May 14 09:41:07 2024 -0400

    Pairwise Noise Contrastive Alignment (#1632)

    * add NCA paired preference loss

    * chore: lint

    * set more lenient tolerance for integration tests

    * Update tests/test_dpo_trainer.py

    * skip test

    * fix

    ---------

    Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
    Co-authored-by: younesbelkada <younesbelkada@gmail.com>

commit d632a5b289782c7384f5275426054e79acc0b744
Author: bartoszzuk <57541034+bartoszzuk@users.noreply.github.com>
Date:   Tue May 14 12:25:54 2024 +0200

    Fixed wrong logs prefixes in KTOTrainer (#1641)

    * Fixed wrong logs prefixes in KTOTrainer

    * Pre-commit formating

commit 5aeb752053876cce64f2164a178635db08d96158
Author: Tiezhen WANG <38108242+xianbaoqian@users.noreply.github.com>
Date:   Fri May 10 23:19:15 2024 +0800

    Update sft_llama2.py to work with the latest API (#1637)

    * Update sft_llama2.py to work with the latest API

    SFTTrainer now takes a STFConfig argument

    * Update dpo_llama2.py

    * precommit

commit b8b89783ca1ab081d25651a9a13e9358cc8e1869
Author: Ilya Gusev <phoenixilya@gmail.com>
Date:   Fri May 10 15:43:13 2024 +0200

    [ORPO] Correct label mask for pad tokens (#1625)

    * [ORPO] Correct label mask for pad tokens

    Recent [fix](57aebe9c36) for calculating NLL loss for a whole sequence introduced a bug. When input_ids are copied to labels, pad tokens are not masked.

    This PR aims to path this by masking labels based on the attention mask.

    * -100 -> label_pad_token_id

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit 8799952876631d7c772ac80f9cbcff155da960e2
Author: Costa Huang <costa.huang@outlook.com>
Date:   Fri May 10 09:32:20 2024 -0400

    visualize rm prediction (#1636)

    * visualize rm prediction

    * quick update

    * quick check

    * quick fix

    * update eval steps

commit 3b4c24946b7d5580fd354b0e3800fc1047b82a41
Author: Xiao Yu <39458711+jasonyux@users.noreply.github.com>
Date:   Fri May 3 18:19:35 2024 -0400

    fixed adding bos and eos token unconditionally (#1591)

    * fixed adding bos and eos token unconditionally

    * fixed typo of tokenizer -> self.tokenizer. Also added update to ORPO

    * fixed code quality, and added BOS/EOS fix to KTO

    * code reformatting with pre-commit run --all-files

    * bug fix: check input id length before checking for EOS/BOS

commit 0347f583e3883f9144a959d1e6f748a4cc91cd09
Author: lewtun <lewis.c.tunstall@gmail.com>
Date:   Fri May 3 15:59:59 2024 +0200

    Fix ZeRO-3 generation context manager (#1617)

* judge refactoring and unittest

* format

* init

* doc

* format

* improve doc

* basejudge

* improve doc and add BaseAPIJudge

* Doc

* style

* refactor callback

* remove openai and pairrm judge from test

* doc

* rm dpo online example

* new prompts and completions

* skip hf judge and add hf token

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-07-18 15:16:59 +02:00
98ad01ddfd dpo vlm blog post (#1844)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-07-17 18:03:49 +02:00
fef8240c23 fix arg parsing in chat.py (#1846)
Co-authored-by: leandro <leandro.vonwerra@spoud.io>
2024-07-17 17:32:17 +02:00
915ffc7c61 add link to DPO datasets collection (#1845) 2024-07-17 11:18:35 -04:00
5828a666bf Fix issues of KTOTrainer (#1840)
* Fix issues of KTOTrainer

* Update trl/trainer/kto_trainer.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-07-17 08:46:14 +02:00
052a8e14b5 fix ppov2_trainer tensorboard log bugs (#1836) 2024-07-16 16:08:15 +02:00
a2adfb836a ref_model -> model_ref (#1835)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-07-15 18:50:29 +02:00
4ebfc5de28 refactor trainer callbacks (#1826)
* refactor trainer callbacks

* fix import

* more import fixes
2024-07-15 11:07:16 -04:00
9e9dc96e67 Added missing token kwarg in Peft model loading (#1825) 2024-07-10 19:11:13 +02:00
7ddef5c158 Make use of trust_remote_code consistent (#1806)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-07-10 18:26:11 +02:00
a9cddf8c55 Delete unused benchmark.yml workflow. (#1822) 2024-07-10 11:25:07 -04:00
2860ce5091 DPO Llava 1.5 and PaliGemma support (#1797)
* llava support dpo

* add_special_tokens=False only when possible

* format

* pali gemma

* refactor size

* remove image resize

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-07-09 09:22:52 +02:00
30e33bd92d upgrade gh actions (#1818)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-07-08 23:37:12 -04:00
d5a0d2d345 Set dev version (#1817) 2024-07-08 11:12:41 -04:00
314e8eb367 fix broken url in docs\source\index.mdx (#1813) 2024-07-08 15:41:36 +02:00
e10792032b 0.9.6 release (#1816) 2024-07-08 09:38:09 -04:00
78045dedc8 Fix TRL_USE_RICH environment variable handling (#1808)
* Add `strtobool` custom implementation from `distutils`

* Fix `TRL_USE_RICH` handling via `strtobool`

* Run `make precommit`
2024-07-07 19:59:26 -04:00
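For reference, a minimal sketch of the kind of `strtobool` replacement described above (`distutils.util.strtobool` is deprecated and removed in Python 3.12; the exact helper in TRL may differ):

```python
import os

def strtobool(val: str) -> bool:
    # Mirrors the semantics of the removed distutils.util.strtobool.
    val = val.lower()
    if val in ("y", "yes", "t", "true", "on", "1"):
        return True
    if val in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {val!r}")

# Parse the environment variable robustly instead of comparing raw strings.
USE_RICH = strtobool(os.environ.get("TRL_USE_RICH", "0"))
```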
747612f9d3 Fix torch_dtype handling in {DPO,SFT}Trainer when provided via CLI (#1807)
* Fix `torch_dtype` handling through CLI

The `torch_dtype` is not properly handled when provided via the TRL CLI
since it's provided initially as a string, but is then cast to
`torch.dtype` before providing it to the `{DPO,SFT}Trainer`, which means
that those trainers should handle the scenario where `torch_dtype` is a
`torch.dtype` too.

* Add `torch_dtype` tests in `test_{dpo,sft}_trainer.py`

* Forward contribution credits

* Run `make precommit`

---------

Co-authored-by: Tash Srivastava <yash-srivastava19@users.noreply.github.com>
2024-07-05 16:28:59 +02:00
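A hedged sketch of the dtype handling the fix above describes (the helper name is hypothetical; the trainers' actual code may differ):

```python
import torch

def resolve_torch_dtype(torch_dtype):
    # The CLI hands over a string ("float16", "bfloat16", "auto", ...),
    # while programmatic callers may already pass a torch.dtype.
    if torch_dtype is None or torch_dtype == "auto" or isinstance(torch_dtype, torch.dtype):
        return torch_dtype
    return getattr(torch, torch_dtype)  # e.g. "float16" -> torch.float16
```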
9e3a35bd3d Remove extra print in reward_trainer.py (#1799)
`print_rich_table` is called twice, and the first call doesn't restrict to `num_print_samples`. Remove the first, extra call.
2024-07-05 13:29:48 +02:00
4402b36dcf clean examples (#1791)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-07-04 14:29:25 +02:00
78f8228874 Bugfix: Preserve token fields when converting TrainingArguments to SFTConfig (#1794)
* Preserve token fields when converting TrainingArguments to SFTConfig

TrainingArguments.to_dict() redacts token fields, so we have to
individually copy them over when converting to SFTConfig to avoid
breaking push_to_hub functionality.

Also adds a test.

* run precommit

* one-line args_as_dict definition per suggestion from kashif

* generalize token copying to match TrainingArguments behavior

* unwrap |= on dict, to support python 3.8

* use .update instead of |= or for-loop
2024-07-03 20:10:50 +02:00
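A minimal sketch of the token-field copying described above (helper name hypothetical; the real conversion handles more fields):

```python
def training_args_to_sft_config_dict(training_args):
    # TrainingArguments.to_dict() redacts *_token fields, so copy them back
    # from the live object before building the SFTConfig.
    args_as_dict = training_args.to_dict()
    args_as_dict.update(
        {k: getattr(training_args, k) for k in args_as_dict if k.endswith("_token")}
    )
    return args_as_dict
```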
b6af2edc93 add model_init_kwargs to training_args (#1787) 2024-07-03 08:29:16 +02:00
cd85b14fbb Fixed typo in SFT trainer docs (#1788)
'STFConfig' instead of 'SFTConfig' appears multiple times in the doc, causing errors when running the code snippets.
2024-06-29 15:35:48 +02:00
a57544f47a fix docs and examples (#1780) 2024-06-27 15:47:58 +02:00
b68ff96f0c Visual DPO (#1647)
* Remove extra whitespaces

* idefics

* vdpo

* sft idefics

* pad with test

* use prompt instead of tokenizer

* rm name main

* support vlm in tokenize row

* temp fix for regex in lora_target_module

* format

* vdpo

* tmp float16 hard code

* concatenated_forward support for vision

* style and new command line

* all-linear

* format

* delete old examples

* get image

* upcast

* new test

* modified test

* new strat for tokenizer

* rm token transfer

* integrate vision in dpo example

* format

* add FDivergenceType back

* precommit

* pillow test dep

* optional prompt

* `evaluation_strategy` to `eval_strategy`

* revert vsft change (oos)

* update test

* test

* comment and support more in process

* update process

* update doc for vdpo

* caution about limited support

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* revert DPO example changes

* cleaner way to check if a model is vision

* comment

* update vdpo example

* rename

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-06-26 16:26:37 +02:00
c8c01cc055 Fix Documentation Overflow Issues for Long URLs in SFTConfig (#1774)
* Update sft_config.py

* Update sft_config.py
2024-06-26 11:23:36 +02:00
3479606c8c Remove the leading space in the tldr preference dataset (#1773) 2024-06-26 09:18:22 +02:00
7965b78340 add Efficient Exact Optimization (EXO) (#1735)
* add exo

* fix a detail

* Update trl/trainer/dpo_trainer.py

* Update trl/trainer/dpo_trainer.py

* Update trl/trainer/dpo_trainer.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-06-25 16:47:32 +02:00
56bd1bba26 evaluation_strategy to eval_strategy (#1771)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-06-25 10:14:26 -04:00
94d53e6617 MoE Models: option to add load balancing loss (#1765)
* KTO: add aux loss

* use router_aux_loss_coef in KtoTrainer when aux_loss enabled

* align optional aux_loss in DPO, KTO, CPO, ORPO

* precommit changes

* fix KL forward kwargs

* add aux_loss doc entry

* apply docs suggestions

---------

Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
2024-06-24 21:27:00 +02:00
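As a rough illustration of the optional aux-loss wiring described in the entry above (a sketch; the function name is illustrative and each trainer integrates it per loss type):

```python
def add_router_aux_loss(loss, model_outputs, router_aux_loss_coef):
    # MoE models run with output_router_logits=True expose a load-balancing
    # aux_loss; fold it into the preference loss when enabled.
    return loss + router_aux_loss_coef * model_outputs.aux_loss
```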
b5be100ae0 Added Reward Backpropagation Support (#1585)
* added alignprop template

* added alignprop support

* Update alignprop_trainer.mdx

* Update alignprop_trainer.mdx

* added better why statement

* fixed inference code

* changed self to pipeline

* removed aesthetic classifier

* added aesthetic to auxiliary models

* added unseen prompt logging

* removed unseen prompt log

* fixed minor

* remove not needed import in trl/__init__.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* fixed styling

* updated _toctree

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-06-24 12:05:44 -04:00
6e1652bc5e Add CPO-SimPO method (#1760)
* enable cpo-simpo

* highlight SimPO and CPO-SimPO

* add test for cpo_alpha

* formatting

* Update docs/source/cpo_trainer.mdx

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-06-23 18:54:30 +02:00
65374c6a71 New sentiment and descriptiveness dataset (#1757)
* push changes

* handle edge cases where the chosen and the rejected are the same
2024-06-21 11:20:54 -04:00
9956091112 Add dataset_text_field in examples/scripts/sft.py (#1758) 2024-06-21 11:01:08 +02:00
34d273f227 Support num_train_epochs (#1743)
* add a test case for num_train_epochs

* fix ci

* quick change

* disable push to hub

* debug windows ci

* try another fix

* skip subprocess tests on windows
2024-06-20 13:16:43 -04:00
3bf94492a8 Fix masking of response tokens (#1718)
The current handling of `response_masks` inside the `batch_forward_pass`
function does not take padding into consideration, which results in a
shape mismatch during masking. Since the response mask is a mask tensor
over response tokens only, it should not be concatenated with a
`torch.zeros(query_length)` tensor, and the masking operation should be
done without slicing.

Remove the concatenation of the response mask, and remove the slicing
from the response mask, since the response mask already has the length
`end - start + 1`, which is equal to the length of `masks[j, start:end]`.
2024-06-20 11:22:20 -04:00
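A toy sketch of the fix described in the entry above (shapes assumed; `start:end` marks the response span):

```python
import torch

masks = torch.ones(2, 10)                        # token masks over query + response
response_mask = torch.tensor([1.0, 1.0, 0.0, 1.0])
start, end = 6, 10                               # response token positions
# Apply the response mask directly: it already matches the response slice,
# so no zero-padding for the query and no extra slicing is needed.
masks[0, start:end] = masks[0, start:end] * response_mask
```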
ba6abee37f Support for returning past_key_values from the model (#1742)
* add support for returning past_key_values from the model

* change order of keys
2024-06-20 09:14:16 -04:00
a57e75967c Integrate f-divergence to DPO (Follow up) (#1610)
* Step 1: update ppo_trainer and hello_world example

* Step 2: Refine comments and add parameter type

* Step 2: Add missing parameter comments

* Step 1: Organize ptx loss into a function and add ptx_loss to train_stats

* Step 1 updates: add comment to ptx_loss function, fix a bug and add warning message

* Step 2: 1) Add ppo_ptx training example as ppo; 2) separate pretrain data fetch and iterate

* Step 2: Remove loss from columns_to_log in ppo_ptx example

* Remove dataset revision in load imdb dataset

* Run pre-commit and fix format issues

* Initial draft of f-divergence fn

* Update f-divergence to avoid overflow

* fix test errors and comments

* Add Unit tests for dpo loss with alpha and js div f

* Adjust format

* Fix test error

* Reverse this update

* Add test cases

* Reverse un-needed updates

* Update code style

* Try to fix code fmt error

* remove extra end line

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-06-19 12:02:51 +02:00
ae23d40f3b change the process function in the example of DPO (#1753)
* change the `process` function in the example of DPO

* fix
2024-06-18 10:07:24 -04:00
83b367b11a CI / KTOTrainer: Remove old tests (#1750)
* remove old tests

* remove datasets

* Update test_dpo_trainer.py

* Update test_dpo_trainer.py
2024-06-18 11:31:17 +02:00
d1ed730ab8 prepare deepspeed to accommodate fp16 and bf16 (#1728)
* prepare deepspeed to accommodate fp16 and bf16

* precommit
2024-06-17 10:50:21 -04:00
8f8e95e25d CPO / DPO: Fix red CI (#1749)
* fix red CI

* precommit
2024-06-17 10:49:00 -04:00
4e23d958f2 fix red CI 2024-06-17 16:41:36 +02:00
50c46205b6 small KTO fixes (#1734)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

* remove nans in metrics by gathering across machines

* fix formatting

* fix choice of mismatched examples for KL term

* describe weights

* fix hanging issue in distributed training

* linting

* move metrics to cpu

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* remove kto_pair

* speed up data processing

* move bco code inside

* raise error for kto_pair argument

* fix formatting

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Winnie Xu <winnie.xu97@gmail.com>
2024-06-17 10:14:44 -04:00
6105d03f92 TrlParser: Add ignore extra args option (#1748)
* add ignore extra args option

* Update trl/commands/cli_utils.py
2024-06-17 16:01:06 +02:00
e247bbd7d5 CI / core: Pin numpy to !=2.0.0 for CI and to users (#1747)
* Update setup.py

* Update setup.py

* Update setup.py

* Update test_best_of_n_sampler.py

dummy commit

* pin numpy

* Update tests/test_best_of_n_sampler.py

* Update setup.py
2024-06-17 15:16:07 +02:00
3d04496196 better trl parser with yaml config (#1739)
* working trl parser with config

correctly overrides yaml config with command line arguments
adds return_remaining_strings
when return_remaining_strings is False, raises an error if the yaml contains extra args that are not in the dataclasses
simpler and cleaner than previous yaml parsing and merging
addresses #1733

* lowercase trlparser
2024-06-17 14:43:33 +02:00
2d244f8acb Workflow: Notify tests results on slack channel (#1744)
* Update tests-main.yml

* Update docker-build.yml
2024-06-17 11:56:13 +02:00
f5168fdbaf adds AOT (#1701)
* adds AOT

* Applied format changes

* added docs and tests

---------

Co-authored-by: Igor Melnyk <igor.melnyk@ibm.com>
2024-06-12 11:54:54 +02:00
79686e1ac7 ktotrainer: Refuse datasets which contain only one class of labels (#1724)
* ktotrainer: refuse dataset which contain only one class of labels

* ktotrainer: document new dataset constraint
2024-06-11 16:35:31 +02:00
34ebc4ccaf feat(ci): add trufflehog secrets detection (#1721)
* feat(ci): add trufflehog secrets detection

* fix(ci): remove unnecessary permissions
2024-06-10 11:17:54 +02:00
1d84e2b888 Fix default padding_value in dpo_config.py (#1692)
The dpo_config default padding value should be None, not 0; otherwise it silently overrides the padding value of every tokenizer with 0.
2024-06-07 11:42:08 +02:00
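A sketch of the intended fallback semantics (class and helper names hypothetical):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class DPOConfigSketch:
    padding_value: Optional[int] = None  # None -> defer to the tokenizer

def resolve_padding_value(config: DPOConfigSketch, tokenizer_pad_token_id: int) -> int:
    # Only fall back to the tokenizer when the user did not set a value,
    # instead of unconditionally overriding it with 0.
    return tokenizer_pad_token_id if config.padding_value is None else config.padding_value
```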
2f71b8b1e2 fix yaml parser for derived config classes (#1713)
fixes #1712
reformatted cli_utils with ruff
2024-06-07 10:37:27 +02:00
5bcb8ad0d6 RDPO fix nll loss (#1705) 2024-06-07 09:48:17 +02:00
b8b972fde1 Add a variant of CPO, SimPO (#1703)
* add a variant of cpo: simpo

* correct cpo-simpo loss

* avoid 0 int error in logging

* add simpo description

* Update trl/trainer/cpo_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* fix formatting

* add test for simpo

* Update docs/source/cpo_trainer.mdx

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add a docstring for simpogamma

* move simpo description to the above docstring

* change simpo description in the doc

* formatting

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-06-06 17:06:47 -04:00
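For orientation, the SimPO objective from the paper, as a hedged sketch (the CPOTrainer's exact variable names differ):

```python
import torch.nn.functional as F

def simpo_loss(chosen_avg_logps, rejected_avg_logps, beta=2.0, gamma=0.5):
    # SimPO: a length-normalized log-prob margin with target reward margin
    # gamma and no reference model.
    logits = beta * (chosen_avg_logps - rejected_avg_logps) - gamma
    return -F.logsigmoid(logits)
```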
3eb9ccb104 set dev version (#1710)
* Update setup.py

* Update __init__.py
2024-06-06 13:33:20 -04:00
974b0d380f 0.9.4 release (#1708) 2024-06-06 10:13:00 -04:00
39a7d1c121 SFTTrainer: Fix backward Compatibility issue with TrainingArguments (#1707)
* fix BC

* fixup
2024-06-06 09:50:17 -04:00
0bdc63839f Fixed doc string and docs for the SFTConfig update (#1706) 2024-06-06 09:42:58 -04:00
275d33b3ef 0.9.3 release (#1699) 2024-06-05 14:34:59 -04:00
c0819ee99f Update sft_trainer.py (#1698) 2024-06-05 11:29:03 -04:00
a03e7cc4e4 Release 0.9.2 (#1697)
* Release: 0.9.0

* Release
2024-06-05 11:00:19 -04:00
a13cb8952c Quick fix on GPT4-eval (#1696)
* quick fix

* precommit
2024-06-05 10:20:54 -04:00
84156f179f Fix typo in DPOTrainer's warnings (#1688) 2024-06-03 14:09:05 -04:00
4eb0b905e2 Skip packing validation (#1673)
* Add test for skipping preproc if packing=True

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Allow skipping of validation for packing=True

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Use dummy dataset in no packing preproc test

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

---------

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
2024-06-03 18:24:32 +02:00
6c203f9fef Fix overriding optimize_device_cache with optimize_cuda_cache in PPOConfig (#1690)
* Don't override optimize_device_cache when optimize_cuda_cache is not provided
Raise an exception when both optimize_cuda_cache and optimize_device_cache are set

* Minor fix
2024-06-03 11:16:22 +02:00
f18253bf2d initial RPO loss (#1686)
* initial RPO loss

* fix sign

* clean up
2024-06-03 09:43:02 +01:00
151a452d14 Fix max completion length (#1588) 2024-05-29 20:29:38 +02:00
488b502d31 fix (#1678) 2024-05-29 20:19:26 +02:00
3c0a10b1ae fix dataset load error (#1670)
Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
2024-05-27 14:52:20 +02:00
b031adfdb8 FIX / PPO: Fix enable_input_require_grads issues with PPO models (#1664)
* Update modeling_base.py

* Update ppo_config.py

* Update ppo_trainer.py

* style
2024-05-24 15:20:16 +02:00
e7cb597230 Fix ppov2 test case (#1661)
* Fix PPOv2 / RLOO refactor's stuff

* update terminology to use stop token
2024-05-23 11:37:16 -04:00
bc8dfbf4e2 update eval_strategy (#1662) 2024-05-23 15:28:04 +02:00
e4ed7a3a5a do not upcast adapters when using FSDP+QLoRA (#1654) 2024-05-23 15:04:22 +02:00
9a7efbd051 🤫 TR-DPO implementation (#1593)
* 🤫 TR-DPO implementation baseline

* fix comments

* docs

* fix linters

* test added

* move configs to DPOConfig

* fix typo

* add docs

* fix import

* use state.global_step

* fix order of arguments

* make sure plugins are not none

* Update trl/trainer/utils.py

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update trl/trainer/utils.py

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* checking that reference model weights have changed

* sync_target_model as staticmethod

* set reference model

---------

Co-authored-by: Nikita Surnachev <n.surnachev@tinkoff.ru>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
2024-05-23 14:58:49 +02:00
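A sketch of the reference-model synchronization idea behind TR-DPO (the paper uses both soft and hard updates; the parameter names here are illustrative):

```python
import torch

def sync_target_model(model, ref_model, alpha=0.6):
    # Soft update: ref <- alpha * policy + (1 - alpha) * ref.
    with torch.no_grad():
        for ref_param, param in zip(ref_model.parameters(), model.parameters()):
            ref_param.data.mul_(1.0 - alpha).add_(param.data, alpha=alpha)
```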
b344bcea2c [DPO] Add 'robust' loss_type (#1653)
* Initial commit

* pre-commit fix

* Minor change to comments

* Added some documentation on how to use Robust DPO
2024-05-23 14:57:25 +02:00
35e12dc595 Fix inheritance order in PPOv2Config (#1659)
* fix inheritance order in PPOv2Config

* fix inheritance order in rloo_config
2024-05-23 08:36:15 -04:00
1da6be18e0 docs: correct cDPO usage in DPOTrainer (#1655) 2024-05-23 08:10:29 -04:00
e249cd802f add support for training collator (#1658) 2024-05-23 08:10:05 -04:00
a02513c3b7 Apply deprecated evaluation_strategy (#1559)
* Deprecate

* Update tests/test_dpo_trainer.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-05-23 12:48:00 +02:00
13454d2f4b PPO / Reinforce Trainers (#1540)
* Add ppov2 trainer

* make eos trick optional, remove unused args

* quick fix

* precommit

* update debugging script

* fix out of bound `drop_last=True`; use built-in scheduler

* Add PPO examples

* push changes

* quick change

* quick change

* various bug fixes

* remove unnecessary grad accumulation setting

* push new changes

* fix DS3 model saving

* update ppo.py

* refactor

* quick change

* refactor

* update ppo trainer

* refactor

* quick test

* add ds2 /ds3 7 processes config

* add vllm trainer

* quick change

* experiment with reward normalization

* push changes

* quick push

* push changes

* push various changes

* refactor to use ModelConfig

* quick change

* refactor

* refactor

* Simplify DS logic

* quick update

* remove unnecessary files

* precommit

* deepspeed fix; handle edge case when eos_token_id = 0

* add PPO tldr example

* add TL;DR example

* fix undefined var

* utilize all samples in rloo

* quick setting

* remove the unnecessary `value_model`

* use exact_div

* allow saving the deepspeed model

* refactor

* remove dead code

* Use some shared utilities

* add some end-to-end test cases

* add PPOv2 docs and RLOO docs / tests

* update docs

* quick push

* fix ci

* fix type annotation for ci

* quick update

* update trainer docs
2024-05-22 08:31:10 -04:00
99f2c94b22 don't cast the trainable lora layers to half precision (#1644)
* don't cast the trainable lora layers to half precision

* quality
2024-05-15 16:25:46 +02:00
6401d080c9 Pairwise Noise Contrastive Alignment (#1632)
* add NCA paired preference loss

* chore: lint

* set more lenient tolerance for integration tests

* Update tests/test_dpo_trainer.py

* skip test

* fix

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
2024-05-14 15:41:07 +02:00
d632a5b289 Fixed wrong logs prefixes in KTOTrainer (#1641)
* Fixed wrong logs prefixes in KTOTrainer

* Pre-commit formatting
2024-05-14 12:25:54 +02:00
5aeb752053 Update sft_llama2.py to work with the latest API (#1637)
* Update sft_llama2.py to work with the latest API

SFTTrainer now takes an SFTConfig argument

* Update dpo_llama2.py

* precommit
2024-05-10 17:19:15 +02:00
b8b89783ca [ORPO] Correct label mask for pad tokens (#1625)
* [ORPO] Correct label mask for pad tokens

Recent [fix](57aebe9c36) for calculating NLL loss for a whole sequence introduced a bug. When input_ids are copied to labels, pad tokens are not masked.

This PR aims to patch this by masking labels based on the attention mask.

* -100 -> label_pad_token_id

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-05-10 15:43:13 +02:00
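A toy sketch of the masking fix described above (values illustrative):

```python
import torch

label_pad_token_id = -100
input_ids = torch.tensor([[5, 6, 7, 0, 0]])        # 0 = pad
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])
# Mask pad positions when deriving labels from input_ids, instead of
# copying the pad tokens into the loss verbatim.
labels = torch.where(
    attention_mask.bool(), input_ids, torch.full_like(input_ids, label_pad_token_id)
)
```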
8799952876 visualize rm prediction (#1636)
* visualize rm prediction

* quick update

* quick check

* quick fix

* update eval steps
2024-05-10 09:32:20 -04:00
3b4c24946b fixed adding bos and eos token unconditionally (#1591)
* fixed adding bos and eos token unconditionally

* fixed typo of tokenizer -> self.tokenizer. Also added update to ORPO

* fixed code quality, and added BOS/EOS fix to KTO

* code reformatting with pre-commit run --all-files

* bug fix: check input id length before checking for EOS/BOS
2024-05-04 00:19:35 +02:00
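A sketch of the guard described in the last bullet above (helper name hypothetical):

```python
def add_eos_if_needed(input_ids, eos_token_id):
    # Only append EOS when it is not already the last token; checking the
    # length first avoids indexing into an empty sequence.
    if len(input_ids) == 0 or input_ids[-1] != eos_token_id:
        input_ids = input_ids + [eos_token_id]
    return input_ids
```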
0347f583e3 Fix ZeRO-3 generation context manager (#1617) 2024-05-03 15:59:59 +02:00
75de236c09 corrects loss function for Self-play Preference Optimization hard label version (#1615)
* corrects sppo hard label version

* formatting

* formatting
2024-05-03 08:09:57 +02:00
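A hedged sketch of the hard-label SPPO loss being corrected (following the SPPO paper; the in-tree formulation may differ slightly):

```python
def sppo_hard_loss(policy_chosen_logps, policy_rejected_logps,
                   ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Drive the chosen log-ratio toward +1/(2*beta) and the rejected
    # log-ratio toward -1/(2*beta), treating the pair as hard labels.
    chosen_ratio = policy_chosen_logps - ref_chosen_logps
    rejected_ratio = policy_rejected_logps - ref_rejected_logps
    return (chosen_ratio - 0.5 / beta) ** 2 + (rejected_ratio + 0.5 / beta) ** 2
```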
7075cec94d Update HH dataset on helpful only subset (#1613)
* Update HH dataset on helpful only subset

* format
2024-05-02 12:12:12 -04:00
adf17a5a26 support loss function for Self-play Preference Optimization (#1612)
* support loss function for Self-play Preference Optimization

* update docs

* update value error msg

* update typehint

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* include sppo in tests

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-05-02 16:06:58 +02:00
0d40e186ee Docs: Fix build main documentation (#1604)
* Fix build documentation

* Update build_pr_documentation.yml
2024-05-02 11:44:29 +02:00
683bc5af6f Excluding tests from setup.py (#1607) 2024-05-02 10:30:27 +02:00
5f0913122b Use auto device map (#1596) 2024-05-02 09:22:31 +02:00
d1aa0b6b2c [KTOTrainer] add BCO (reward shift and underlying distribution matching) (#1599)
* add `Loss Functions` section in the doc.

* add bce loss with reward shift in KTOTrainer

* add underlying distribution matching

* update example to use underlying distribution matching

* add config description

* fix 'referenced before assignment' error

* add 'bco' and 'udm' test cases

* run pre-commit

* add `scikit-learn` dependency

* raise error if sklearn is not available

* call TrainingArguments().__post_init__() for proper init
2024-04-30 14:06:45 +02:00
d88ec14602 Update __init__.py (#1602) 2024-04-30 10:25:43 +02:00
6c18e40e97 fix typo (#1594) 2024-04-29 10:42:31 +02:00
1d0a7ea17b add warning in SFTTrainer (#1577) 2024-04-23 20:00:10 +02:00
9f68ead8cf FIX: Fix CI on transformers main (#1576)
* Update run_dpo.sh

* Update run_sft.sh

* Update clis.mdx

* Update example_config.yaml

* Update test_cli.py

* Update testing_constants.py

* Update test_dpo_trainer.py
2024-04-23 14:31:45 +02:00
f30daa4225 [SFT] add SFT Trainer Config dataclass (#1530)
* initial SFT Config

* remove pdb

* fix chat_template

* undo formatting

* add back removed commits

* fix the tests

* add back options to SftScriptArguments

* use sft_script_args

* Update trl/commands/cli_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/commands/cli_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* rename SFTScriptArguments and split names

* formatting docstrings

* docstring

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-04-23 11:55:13 +02:00
24fd8dd513 [DPO] DPOConfig class (#1554)
* initial DPOConfig

* fix doc string

* use DPOConfig

* fix missing import

* fix DpoScriptArguments

* override args config when given in init

* use DPOConfig

* fix output dir name

* override with deprecated arguments if given

* use DPOConfig in tests

* fix comment

* add custom_message

* use dataset_train_name and dataset_test_name

* beta is also in the training_args

* fix loss_type docs

* Update trl/commands/cli_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/commands/cli_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/commands/cli_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* use DPOScriptArguments

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-04-23 11:06:28 +02:00
c050ebc073 [DPO] add 'bco_pair' loss_type (#1524)
* add 'bco_pair' loss_type

* add BCO description to DPO doc

---------

Co-authored-by: sean.jung <sean.jung@seanjungui-MacBookPro.local>
2024-04-22 18:46:51 +02:00
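A hedged sketch of a BCO-style paired loss (the baseline `delta` here is a batch mean; the trainer may track a running statistic instead):

```python
import torch
import torch.nn.functional as F

def bco_pair_loss(chosen_logratios, rejected_logratios, beta=0.1):
    chosen_rewards = beta * chosen_logratios
    rejected_rewards = beta * rejected_logratios
    # Reward shift: binary cross-entropy around a moving reward baseline.
    delta = torch.cat([chosen_rewards, rejected_rewards]).mean().detach()
    return -F.logsigmoid(chosen_rewards - delta) - F.logsigmoid(-(rejected_rewards - delta))
```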
abc0584736 fix add_special_tokens issue for data with template (#1509) 2024-04-22 18:44:10 +02:00
6d1cb85e73 set dev version (#1568) 2024-04-22 10:59:35 +02:00
e90e8d91d2 Release: v0.8.6 (#1567) 2024-04-22 10:58:13 +02:00
113aaae033 CLI: Add warning when ignored params are passed + parse config file if config if passed (#1565)
* add warning

* no need for `config` field
2024-04-22 10:48:59 +02:00
0865572748 Update __init__.py (#1557) 2024-04-18 14:51:40 +02:00
a6532a11c2 set dev version (#1556) 2024-04-18 13:58:17 +02:00
3595eb00e0 Release: v0.8.5 (#1555) 2024-04-18 13:56:36 +02:00
9afd901d0f enable multiple eos tokens (#1553) 2024-04-18 12:19:18 +02:00
e04432d5e3 FIX: make the train / test fields modulable (#1551)
* make the train / test fields modulable

* format

* fix --output_dir issue
2024-04-18 11:33:30 +02:00
75c1c47fcc set dev version (#1548) 2024-04-17 17:25:01 +02:00
a5788ac99b Release: v0.8.4 (#1547) 2024-04-17 17:19:28 +02:00
3bbe7e0407 Fixed ref model not used in PPO generation (#1534) 2024-04-17 07:22:56 -07:00
edf60e826b Update run_sft.sh (#1546) 2024-04-17 16:17:05 +02:00
5d1deb1445 CLI: Set dataset_text_field to None to allow ChatML automatic template (#1545)
* Update cli_utils.py

* Update test_cli.py
2024-04-17 14:45:14 +02:00
476c4b8dc0 [KTO] support to load the adapter twice (#1542)
Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
2024-04-16 17:43:40 +02:00
e823458a6a save_model -> save_pretrained in ppo_trainer.mdx (#1537) 2024-04-15 09:35:03 +02:00
1c0d8bca15 VSFT hotfix - adds gen prompt to template and processor to hub (#1532)
* adds gen prompt to template and processor to hub

* fixes hub model id, removes Path
2024-04-12 20:14:12 +02:00
363369a717 [CPO] fix memory leak due to retained value (#1531) 2024-04-12 15:32:01 +02:00
aba4df02c1 set dev version (#1529) 2024-04-12 12:37:34 +02:00
98226473e4 Release: v0.8.3 (#1528) 2024-04-12 12:22:05 +02:00
87f4c70e60 [CLI] fix imports (#1527) 2024-04-12 12:17:05 +02:00
995f1174da set dev version (#1523) 2024-04-11 15:51:57 +02:00
143e11123d Release: v0.8.2 (#1522) 2024-04-11 15:42:47 +02:00
346c99d222 Adds VLM Training support to SFTTrainer + VSFT script (#1518)
* adds option to skip dataset preparation in SFTTrainer

* before changing the template

* adds support for new schema

* a few fixes to data collator to support new schema

* updates args

* precommit

* adds sys prompt to chat template and other fixes

* updates template, fixes collator for multiple images

* precommit

* rename vsft to vsft_llava

* adding integration tests

* adds integration test for vsft

* precommit

* adds back chat template

* docs

* typo

* adds eval, precommit

* adds peft launch args

* formatting

* fixes no deps tests by checking if PIL lib exists

* Update __init__.py

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-04-11 15:35:59 +02:00
087fe544b0 add data for sfttrainer doc (#1521) 2024-04-11 15:08:43 +02:00
ebbd37ba99 allow pre-tokenized datasets (#1520) 2024-04-11 14:50:39 +02:00
e667550a5a Allow streaming (datasets.IterableDataset) (#1468)
* safe-guard iterabledatasets

* import datasets

* reference the correct IterableDataset

* make pre-commit
2024-04-11 11:11:07 +02:00
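A small sketch of the safeguard mentioned above: streaming datasets are instances of `datasets.IterableDataset`, not `torch.utils.data.IterableDataset`, so the check must reference the right class (helper name hypothetical):

```python
import datasets

def is_streaming(dataset) -> bool:
    # True for load_dataset(..., streaming=True) outputs.
    return isinstance(dataset, datasets.IterableDataset)
```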
57aebe9c36 [ORPO] Update NLL loss to use input_ids instead (#1516)
* Calculate loss on `input_ids` instead of only on response

* Use `concatenated_labels` if `is_encoder_decoder`
2024-04-09 14:10:09 +02:00
85f5fd220d correct metrics (#1514)
Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
2024-04-08 17:09:04 +02:00
4dca169404 use kwargs for RM (#1515) 2024-04-08 17:05:37 +02:00
f35b68a301 Speed up PPO with ZeRO-3 by 10x 🔥 (#1483)
* Speed up PPO by 10x 🔥

* Revert

* Clean up

* Use relative import

* Clean

* Fix typing for docs
2024-04-08 14:30:44 +02:00
5cf863576a Change the device index to device:index (#1490)
Signed-off-by: yuanwu <yuan.wu@intel.com>
2024-04-08 14:20:42 +02:00
9a28b3fd05 Fix RichProgressCallback (#1496)
* fix RichProgressCallback

* Refine code styling in RichProgressCallback tests
2024-04-04 21:13:54 +02:00
4f8057ad23 [KTO] fix interleaving, reporting, hanging bugs (#1499)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

* remove nans in metrics by gathering across machines

* fix formatting

* fix choice of mismatched examples for KL term

* describe weights

* fix hanging issue in distributed training

* linting

* move metrics to cpu

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* fix tokenization error: lack of bos

* change user warning for weight hyperparams

* minor update to docs

* reshape attention mask

* reformat

* add test for bos/eos tokens

* move dependency location

* Update tests/test_kto_trainer.py

* don't report nan metrics

* don't report nan metrics and remove data interleaving

* fix bugs in calculating metrics

* no need to gather KL term

* minor changes

* use nanmean for losses

* remove disabling of wandb

* revert changes

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-04-03 23:41:12 +02:00
ab0d11d815 Correct ppo_epochs usage (#1480)
* Correct ppo_epochs usage

The usage of ppo_epochs is incorrect here. 

In 8534f0edf8/trl/trainer/ppo_config.py (L104C8-L104C58)

ppo_epochs was described as the "Number of optimisation epochs per batch of samples".

However, here it is used as the usual epoch number, in which you do one iteration over the training dataset.

* Update ppo_trainer.mdx

* Update docs/source/ppo_trainer.mdx

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-04-02 12:22:16 +02:00
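To make the distinction concrete, a toy sketch of the intended semantics (stand-in names throughout):

```python
class ConfigSketch:
    ppo_epochs = 4  # optimisation epochs per batch of samples, not dataset epochs

config = ConfigSketch()
batches = [["rollout_a"], ["rollout_b"]]     # stand-in for the dataloader

for batch in batches:                        # one pass over the training data
    for _ in range(config.ppo_epochs):       # ppo_epochs passes over the SAME batch
        pass                                 # PPO optimisation step on `batch`
```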
c674c66a45 Fix DPO Unsloth example (#1494) 2024-04-02 12:16:56 +02:00
45da5df53e use log1p for loss (#1491) 2024-04-02 12:06:54 +02:00
04fd8d9400 Fix typo in how_to_train.md (#1503)
Said "big" where it should say "bug".
2024-04-02 12:05:07 +02:00
bf2aed3876 add dpo link (#1502) 2024-04-02 12:04:34 +02:00
0ee349dcd4 Update KTO example to use better model and ChatML support (#1485)
* Update KTO example

* Tweak params

* Fix values

* Fix LoRA params
2024-03-27 10:47:42 +01:00
7ff6206510 Ignore chat files (#1486)
* Ignore chat files

* Update .gitignore

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Update .gitignore

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2024-03-27 10:44:23 +01:00
e4b20ecbc4 hacky update to ModelConfig to allow lora_target_modules="all-linear" (#1488)
The type hint forces a list, which raises an "all-linear" layer-not-found error; forcing a string makes it work. Updating the type hint to `Union[str, list[str]]` also raises a parsing error.
2024-03-27 09:04:41 +01:00
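For context, peft treats the bare string `"all-linear"` as a special value meaning "target every linear layer", so it must not be wrapped in a list (a minimal example):

```python
from peft import LoraConfig

# A list like ["all-linear"] would make peft search for a module literally
# named "all-linear"; the bare string triggers the special-case handling.
lora_config = LoraConfig(target_modules="all-linear")
```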
6c2f829bb7 [KTO] Use batching to speed up data processing (#1470)
* Refactor test

* Make batched tokenizer

* Make it FAST 🔥!

* Hack to the max

* Run on main process

* Refactor

* Add unit test

* f

* r

* Refactor

* Remove bs

* Refactor to tokenize once

* Add typing

* Add test for KL getter
2024-03-26 19:46:23 +01:00
c4f0f41935 Update KTO example with good dataset & chat format (#1481)
* Update KTO example with good dataset & chat format

* Add error for chat template
2024-03-25 16:56:43 +01:00
dc6a934269 add missing classes (#1479) 2024-03-24 22:08:28 +01:00
9ce7ac6925 Fix hyperparameters in KTO example (#1474)
* Fix hparams in KTO example

* Clean

* Fix
2024-03-24 14:29:22 +01:00
99553c19ae Add use_cache=False in {ORPO,CPO}Trainer.concatenated_forward (#1478)
* Add `use_cache=False` in `concatenated_forward`

Prevents `ORPOTrainer` from using the cache, as it's not required for computing the logits and it conflicts with Flash Attention 2

* Add `use_cache=False` to `concatenated_forward`

Co-authored-by: Kashif Rasul <kashif@users.noreply.github.com>

---------

Co-authored-by: Kashif Rasul <kashif@users.noreply.github.com>
2024-03-24 11:33:20 +01:00
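A sketch of the call pattern after the change (argument plumbing simplified; the function name is illustrative):

```python
def concatenated_forward_sketch(model, input_ids, attention_mask):
    # use_cache=False: the KV cache is not needed for computing logits
    # during training and conflicts with Flash Attention 2.
    return model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False).logits
```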
2ce8e45bb2 ORPO trainer (#1435)
* initial orpo skeleton

* typos

* calculate orpo loss

* fix class name

* fix tests

* fix typo

* Update docs/source/orpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/orpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/orpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* rename max_target_length

* Update examples/scripts/orpo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/orpo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/orpo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* more docs

* log log_odds_ratio and log_odds

* average_log_prob as per paper

* added logging section

* add nll_loss

* fix typo

* more verbose

* rename log_odds to log_odds_chosen

* allow datasets to be loaded

* remove dup debug arg

* tokenizer exists

* fix typo

* use trl-internal-testing/hh-rlhf-trl-style dataset

* formatting

* add missing imports

* fix output dir name

* Update examples/scripts/orpo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* move dataset_num_proc to configs

* Update trl/trainer/orpo_config.py

Co-authored-by: Alvaro Bartolome <alvarobartt@gmail.com>

* Update trl/trainer/orpo_trainer.py

Co-authored-by: Alvaro Bartolome <alvarobartt@gmail.com>

* add ORPOTrainer to readme

* fix typo

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Alvaro Bartolome <alvarobartt@gmail.com>
2024-03-22 22:07:11 +01:00
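For orientation, a hedged sketch of the ORPO odds-ratio term that the entry above logs (`log_odds_chosen`, `log_odds_ratio`), following the paper; the trainer combines it with the chosen-response NLL loss:

```python
import torch
import torch.nn.functional as F

def odds_ratio_loss(chosen_avg_logps, rejected_avg_logps, beta=0.1):
    # odds(p) = p / (1 - p); the loss rewards higher odds for chosen responses.
    log_odds = (chosen_avg_logps - rejected_avg_logps) - (
        torch.log1p(-torch.exp(chosen_avg_logps))
        - torch.log1p(-torch.exp(rejected_avg_logps))
    )
    log_odds_ratio = F.logsigmoid(log_odds)
    return -beta * log_odds_ratio  # added to the NLL loss on the chosen response
```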
d1df79f83c Add CPOTrainer (#1382)
* add CPOTrainer

* add docs

* fix formatting

* removed precompute_ref_log_probs arg

* remove precompute_ref_log_probs

* typos

* finish cpo trainer doc

* remove redundant lines

* typo

* formatting

* compute chosen nll loss also for enc-dec models

* fix gradient error of inplace operation for enc-dec models

* formatting

* use CPOConfig

* formatting

* use model_init_kwargs from CPOConfig

* comments in example

* fix doc string

* fix typo in docstring

* update year

* fixed typo

* use preference dataset

* fix learning rate

* move dataset_num_proc to configs

* Update cpo paper link from HF: cpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* update description for CPO: cpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* remove _prepare_deepspeed for cpo

Because CPO does not need init for reference model

* Add explanation to CPO loss

* format

* fix bug when lengths are given

* add CPOTrainer to README

* fix grammar

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-03-22 21:32:45 +01:00
d10f7663b0 [peft] Update test_reward_trainer.py to fix tests (#1471)
* [peft] Update test_reward_trainer.py

Since we are requiring peft >= 0.4.0

* formatting
2024-03-22 19:12:54 +01:00
423991c204 Use the standard dataset for DPO CLI (#1456)
* Use the standard dataset

* update docs

* update dpo examples

* fix cli error

* fix CI

* use trl-internal-testing/hh-rlhf-trl-style
2024-03-20 13:14:08 -04:00
988d4c4e1a set dev version (#1463) 2024-03-20 12:30:48 +01:00
8534f0edf8 Release: v0.8.1 (#1462) 2024-03-20 11:32:06 +01:00
5095e7f948 add eos token to generate (#1459) 2024-03-20 10:30:27 +01:00
9fcf61d706 Fix chat CLI for model revisions (#1458)
* Fix chat CLI for model revisions

* Clean
2024-03-20 09:35:34 +01:00
66b043a910 set dev version (#1454) 2024-03-19 17:30:48 +01:00
f2c71771cc Release: v0.8.0 (#1453)
* Release: v0.7.12

* 0.8.0 instead
2024-03-19 17:19:38 +01:00
631c33cbb3 FEAT: Update README to add DPO + CLIs (#1448)
* Update README.md

* Update README.md

* move dpo/ppo description to docs

* rework readme

* Update README.md

---------

Co-authored-by: leandro <leandro.vonwerra@spoud.io>
2024-03-19 16:55:56 +01:00
3f7ff60528 model --> model_name_or_path (#1452)
* `model` --> `model_name_or_path`

* fix style
2024-03-19 16:52:42 +01:00
1705aebeba Fix yaml parsing issue (#1450) 2024-03-19 16:07:50 +01:00
4e622a9033 chat cli (#1431)
* first draft

* move chat to cli

* fix makefile

* make script less verbose

* fix parsing

* fix style

* add more examples

* fix setup.py

* add copyright

* fix verbose init

* attribute FastChat

* add docs
2024-03-19 12:37:06 +01:00
eb2d5b2972 CI / CLI: Properly raise error when CLI tests failed (#1446)
* properly raise error

* another fix

* Update tests.yml

* Update tests-main.yml
2024-03-19 11:39:07 +01:00
f976c6d234 Before updating the tr_loss, make sure tr_loss_step is on the same device. (#1439)
* before updating the loss from dpo, make sure it's on the same device as tr_loss

* Update trl/trainer/dpo_trainer.py

Co-authored-by: guy1992l <83535508+guy1992l@users.noreply.github.com>

---------

Co-authored-by: guy1992l <83535508+guy1992l@users.noreply.github.com>
2024-03-19 10:28:44 +01:00
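The device alignment in a nutshell (toy tensors):

```python
import torch

tr_loss = torch.zeros(())                    # running loss accumulator
tr_loss_step = torch.tensor(0.5)             # may live on another device
tr_loss += tr_loss_step.to(tr_loss.device)   # align devices before accumulating
```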
abc7301bab Fix PPOTrainer README example (#1441)
* Fix example

* Delete newline
2024-03-19 10:18:49 +01:00
6cfa5cfc81 fix doc build on main (#1437) 2024-03-18 14:24:02 +01:00
a2aa0f0b09 FEAT: Add CLIs in TRL ! (#1419)
* CLI V1

* v1 CLI

* add rich enhancements

* revert unindented change

* some comments

* cleaner CLI

* fix

* fix

* remove print callback

* move to cli instead of trl_cli

* revert unneeded changes

* fix test

* Update trl/commands/sft.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* remove redundant strings

* fix import issue

* fix other issues

* add packing

* add config parser

* some refactor

* cleaner

* add example config yaml file

* small refactor

* change a bit the logic

* fix issues here and there

* add CLI in docs

* move to examples/sft

* remove redundant licenses

* make it work on dpo

* set to None

* switch to accelerate and fix many things

* add docs

* more docs

* added tests

* doc clarification

* more docs

* fix CI for windows and python 3.8

* fix

* attempt to fix CI

* fix?

* test

* fix

* tweak?

* fix

* test

* another test

* fix

* test

* fix

* fix

* fix

* skip tests for windows

* test @lvwerra approach

* make dev

* revert unneeded changes

* fix sft dpo

* optimize a bit

* address final comments

* update docs

* final comment

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2024-03-18 12:20:54 +01:00
304e208f77 Create standard dataset for TRL (#1424)
* add scripts to create standard dataset

* precommit

* push changes

* add script to play with
2024-03-14 10:57:48 -04:00
4fe8b027f6 [Kto] torch_dtype kwargs fix (#1429)
* set torch_dtype from string type

* fix test
2024-03-14 13:49:44 +01:00
fb6ebb1e11 [KTO] fix tokenization bugs (#1418)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

* remove nans in metrics by gathering across machines

* fix formatting

* fix choice of mismatched examples for KL term

* describe weights

* fix hanging issue in distributed training

* linting

* move metrics to cpu

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* fix tokenization error: lack of bos

* change user warning for weight hyperparams

* minor update to docs

* reshape attention mask

* reformat

* add test for bos/eos tokens

* move dependency location

* Update tests/test_kto_trainer.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-03-14 08:22:50 +01:00
66078c7c01 CI: Fix CI on main (#1422)
* fix CI on main

* final fix
2024-03-13 13:54:22 +01:00
58c0888996 Add support for FSDP+QLoRA and DeepSpeed ZeRO3+QLoRA (#1416)
* don't do mp casting

* don't use `prepare_for_kbit` when using fsdp+qlora or dsz3+qlora

* changes to enable fsdp+qlora and dsz3+qlora

* revert

* Update sft_trainer.py

* quality

* fix deprecation using changes from PR https://github.com/huggingface/trl/pull/1415

* fixes

* quality

* Update trl/trainer/sft_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* quality

* relaunch tests

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-03-13 10:43:45 +01:00
486e7a4071 model init when args are given (#1413)
Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
2024-03-11 13:47:37 +01:00
7630f877f9 Fix import error from deprecation in transformers (#1415)
* Fix import error from  deprecation in transformers

* Fix import path
2024-03-11 13:23:56 +01:00
4d862da181 [KTO] fix various bugs (#1402)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

* remove nans in metrics by gathering across machines

* fix formatting

* fix choice of mismatched examples for KL term

* describe weights

* fix hanging issue in distributed training

* linting

* move metrics to cpu

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-03-08 12:04:52 +01:00
22b4f548f4 fix RM script (#1393) 2024-03-07 08:49:52 +01:00
4219cbfedc Fix the pad_token_id error (#1394)
* Fix the pad_token_id error

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Add the load_in_8bit argument in rl_training.py

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Reformat the patch

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Fix the check failed

Signed-off-by: yuanwu <yuan.wu@intel.com>

---------

Signed-off-by: yuanwu <yuan.wu@intel.com>
2024-03-05 02:18:42 +01:00
3bd02380c7 Log ddpo reward as float to fix numpy conversion during bf16 training (#1391) 2024-03-04 02:50:50 +01:00
067db7553a [KTO] prevent nans from appearing in metrics (#1386)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

* remove nans in metrics by gathering across machines

* fix formatting

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-03-01 12:19:55 +01:00
93e85ed808 [KTO] merge eval dataset only if it exists (#1383)
* merge eval dataset if it exists

* add eval dataset test
2024-03-01 12:15:14 +01:00
14e0d78807 fix bugs in KTO implementation (#1380)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-02-29 09:01:52 +01:00
b32656f726 FIX: Fix the CI again .. (#1374)
* Update tests-main.yml

* Update tests-main.yml

* Update tests-main.yml

* Update .github/workflows/tests-main.yml

* Update tests-main.yml

* Update tests-main.yml
2024-02-27 12:46:20 +01:00
9399bc113b Update tests-main.yml (#1373) 2024-02-27 12:07:50 +01:00
11f122ad49 Update tests-main.yml (#1372) 2024-02-27 11:45:02 +01:00
009c9a610b feature request add force_use_ref_model (#1367) 2024-02-27 11:19:16 +01:00
7712d42f8c add eval_packing (#1369) 2024-02-27 11:19:06 +01:00
7c2213b9e5 add ci message sending on TRL (#1370) 2024-02-27 11:18:55 +01:00
ddeebce176 Add some arguments for support XPU (#1366)
* Add use_bnb and load_in_4bit arguments.

Make them optional, since they are not supported on all platforms

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Change the use_reentrant default value to False

If the default value of gradient_checkpointing is True, set the
use_reentrant default value to False, because otherwise the following
error occurs.

RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 191 with name base_model.model.model.layers.31.self_attn.v_proj.lora_B.default.weight has been marked as ready twice. This means that multiple autograd engine  hooks have fired for this particular parameter during this iteration.

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Add model_dtype for loading the model in model_dtype

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Reformat the patch

Signed-off-by: yuanwu <yuan.wu@intel.com>

---------

Signed-off-by: yuanwu <yuan.wu@intel.com>
2024-02-27 02:49:16 +01:00
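For readers hitting the DDP error quoted above, a minimal sketch of the corresponding configuration, shown with standard transformers arguments for illustration:

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,
    # Reentrant checkpointing can mark a DDP parameter ready twice, producing
    # the RuntimeError quoted above; the non-reentrant variant avoids it.
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```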
cf68d871cf Fix version for Python<3.8 (#1363) 2024-02-27 02:41:09 +01:00
2a2676e7ec set seed in sft/dpo/reward_modeling to make results reproducible (#1357)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-02-23 11:12:45 +01:00
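The seeding itself is a one-liner; a minimal sketch using the standard transformers helper:

```python
from transformers import set_seed

# Seeds Python's `random`, NumPy, and PyTorch in one call.
set_seed(42)
```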
ca90cba351 fix 8-bit multi-gpu training bug (#1353)
* fix 8-bit multi-gpu training bug see https://github.com/huggingface/trl/issues/1348

* Update dpo_llama2.py

make gradient_checkpointing_kwargs configurable.

* Update dpo_llama2.py

remove unnecessary config of device_map

* format with make precommit

---------

Co-authored-by: ubuntu <lili@liveremier.ai>
2024-02-23 03:58:43 +01:00
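A sketch of the usual fix for the linked 8-bit multi-GPU issue: pin a full copy of the model to the current process's GPU instead of letting `device_map="auto"` shard it across devices. The checkpoint name is a placeholder, and `load_in_8bit` requires bitsandbytes.

```python
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # placeholder checkpoint
    load_in_8bit=True,
    # One full copy of the model per process, on that process's own GPU.
    device_map={"": Accelerator().local_process_index},
)
```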
4f97fb4a74 more user-friendly (#1350) 2024-02-22 10:06:35 +01:00
a46cd84a64 Kto trainer (#1181)
* initial file

* initial tokenizer

* UnpairedPreferenceBatchSampler

* use batch_sampler

* use interleave_datasets

* add loss

* fix imports

* use SequentialSampler when training

* formatting

* add other helpers

* add prediction_step

* fix the kto pair docs

* tests

* compute_reference_log_probs

* add get_eval_dataloader

* fix typo

* kto with is_encoder_decoder true

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* fixed typo

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* renamed KTO dataset keys

* use DPOTrainer's get_batch_logps

* add get_batch_samples

* typo

* Handle last token in prompt

* Create KTOConfig class that subclasses transformers.TrainingArguments

* Update KTO tests to handle KTOConfig

* Update KTO script to use KTOConfig

* formatting

* Update docs/source/dpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/training_configs.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* use max_completion_length

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* add back get_batch_logps

* use max_completion_length

* move config to its own file

* Check tokenize params on Trainer init

* Clone labels for enc-dec model to solve RuntimeError

* formatting

* fix enc-dec later

* completion_decoder_input_ids is optional for enc-dec

* fix breaking test

* add a kl key for KL estimation with shuffled completion

* add loss and weights

* fix bug in chosen_idx

* add back metrics

* fix typos

* fix kto_loss docs

* typo

* set loss to None when there are no target completions in the batch

* use nan tensor instead of none

* fix reference_logps test

* fix logits

* a bit more robust options

* log only the correct prompt-completion during eval

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/dpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* add docs for desirable_weight and undesirable_weight args

* dropout is always disabled

* remove DDP hack

* formatting

* move more arguments of trainer to config

* comment out T5 test for now

* Add docstring to KTOTrainer

* moved Config docstrings to the appropriate class

* add autodoc to markdown

* formatting

* updated copyright year

* add model tags

* do not add BOS to start of completion

* Move data_collator to KTOTrainer

* formatting

* data_collator is not in args

* shuffle_completion with specific input_columns

* remove all but the needed columns

* Update docs/source/dpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update tests/test_kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* moved more args to kto_config

* fix test

* use all_exhausted strategy and shuffle after

* use KTOConfig in HfArgumentParser

* use ModelConfig

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Pablo Vicente Juan <p.vicente.juan@gmail.com>
2024-02-19 14:43:17 +01:00
1f56bffdf8 Update Example to reflect #aa35fec (#1333) 2024-02-18 14:10:04 +01:00
1bfe0b8fcb set dev version (#1332) 2024-02-16 09:49:05 +01:00
0f13e51efa Release: v0.7.11 (#1331) 2024-02-16 09:05:04 +01:00
1e77d8aeb2 [core / xxxTrainer] Automatic tagging (#1329)
* automatic tagging

* add comments

* fix tests

* fix
2024-02-15 14:47:32 +01:00
3b1911c2a9 add tests on transformers peft main (#1328) 2024-02-15 05:19:31 +01:00
851e7fe556 [core / DDPO] Fix diffusers import issue (#1314)
* fix

* more clean up
2024-02-15 04:45:27 +01:00
31b02d0cd0 Update README.md to clarify model requirement (#1315)
Clarify that language models must be transformers models for text. This is somewhat redundant with the intro description, but better addresses a question that comes up repeatedly (issue 1257).

Closes: #1257
2024-02-15 04:38:17 +01:00
9bc478ecbb pre-commit: replace linters + formatters with Ruff; fix some issues (#1300)
* pre-commit: replace linters + formatters with Ruff

* Don't use bare except

* Clean up `noqa`s

* Enable Ruff UP; apply auto-fixes

* Enable Ruff B; apply fixes

* Enable Ruff T with exceptions

* Enable Ruff C (complexity); autofix

* Upgrade Ruff to 0.2.0
2024-02-15 04:37:41 +01:00
29f162b86c Best practice recommendation update for dpo_trainer.mdx (#1325)
In the document as it stands, the best-practice recommendations seem neither consistent nor correct.

For example, the documentation links a tweet recommending that adapters be merged into a quantized model, and a script that supposedly illustrates how to apply that recommendation. But the script actually does the opposite of what the tweet recommends, first dequantizing the model.

There are similar inconsistencies/ambiguities further in that paragraph. For example, it says that using an unquantized model would lead to lower performance (I changed it to "higher memory demand").

Overall, I updated the paragraph to improve consistency and provided links to slightly more evidence-based merging recommendations.
2024-02-14 11:43:48 +01:00
6852097169 Fix PPOTrainer argument train_dataset -> dataset (#1321)
Both the argument's name and its value need to be renamed.
Otherwise we get both

NameError: name 'train_dataset' is not defined

and

TypeError: PPOTrainer.__init__() got an unexpected keyword argument 'train_dataset'
2024-02-06 22:37:04 +01:00
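A minimal sketch of the corrected call from that fix, using a tiny placeholder model and toy data:

```python
from datasets import Dataset
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

dataset = Dataset.from_dict({"query": ["Hello", "How are you?"]})

ppo_trainer = PPOTrainer(
    config=PPOConfig(batch_size=2, mini_batch_size=1),
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    dataset=dataset,  # the keyword is `dataset`, not `train_dataset`
)
```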
f12a1da74b Fix AttributeError in dpo_trainer for reference_free case in dpo_loss function (#1313)
* Update dpo_trainer.py

update reference_free parameter for dpo_loss

* Update dpo_trainer for reference_free case

updated the docstring typo and set device parameter to ref_logratios tensor
2024-02-02 11:02:40 +01:00
ae87b3aefa Fix typos in docs for Multi Adapter RL (MARL). (#1312)
* Fix more typos

* Fix typos in docs.
2024-02-02 07:37:08 +01:00
3f7cee7643 ENH: Run CI only if relevant files are modified (#1309)
* Update tests.yml

* Update .github/workflows/tests.yml
2024-02-01 23:49:32 +01:00
ae8431bd50 Codemod Unittest assertions to bare asserts (#1301)
* Remove stray commas from test data

* Codemod Unittest assertions to bare asserts

* Make `assertAlmostEqual` tests more idiomatic

* DRY some test strings
2024-02-01 23:49:03 +01:00
66a976c6bd Update sft_trainer.mdx to add note on launching DDP training (#1308)
As requested here: https://github.com/huggingface/trl/issues/1303#issuecomment-1920437586
2024-02-01 23:42:14 +01:00
814930377c Add num_proc arg to the eval_dataset processing (#1307) 2024-02-01 17:58:00 +01:00
88685f2cd4 Types: Fix PEP 484 implicit-optional compliance (#1297)
This was done automatically with hauntsaninja/no_implicit_optional.
2024-01-31 14:51:58 +01:00
6f40f20233 Fix DPOTrainer docstrings (#1298)
Some issues were causing the auto-generation of the API reference to fail, and the args overlapped on the documentation page
2024-01-31 14:49:41 +01:00
036213bd85 Fix sft trainer when args is None (#1295)
* fix sft trainer when args is None

* add test

* fix
2024-01-31 03:31:53 +01:00
6042596705 Fix DPO slow tests (#1292)
* Update test_dpo_slow.py

* style
2024-01-30 10:15:46 +01:00
070c75ec54 load data only on main process + fix dpo example test (#1291) 2024-01-30 10:14:22 +01:00
b415224a4a fix DPO trainer + mistral + FA2 (#1290) 2024-01-30 08:25:29 +01:00
9186710671 fix padding in dpo trainer (#1284) 2024-01-30 08:24:48 +01:00
aa35fec099 raise value error if one passes a ref_model and a peft_config (#1289) 2024-01-30 08:06:03 +01:00
737d771941 Add multiprocessing in the DPO trainer. (#1286)
* Update dpo_trainer.py

Added support for num_proc to tokenize the training dataset.

* Update dpo_trainer.py

added type annotation for the new num_proc variable

* added test case

* add test case

* fix type

---------

Co-authored-by: imraviagrawal <ravi.agrawal@umass.edu>
Co-authored-by: Ravi Agrawal <raviagrawal@Ravis-MacBook-Pro.local>
2024-01-30 02:55:07 +01:00
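A sketch of how the multiprocessing option is used; treat the argument name `dataset_num_proc` as an assumption (it is how later TRL versions spell it), and the model and data below are toy placeholders.

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

train_dataset = Dataset.from_dict({
    "prompt": ["Hi", "Hey"],
    "chosen": [" hello!", " hi there!"],
    "rejected": [" bye.", " go away."],
})

trainer = DPOTrainer(
    model=model,
    args=TrainingArguments(output_dir="out", per_device_train_batch_size=1),
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    dataset_num_proc=4,  # tokenize the dataset with 4 worker processes
)
```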
ef441ea028 Update dpo_trainer.mdx (#1280) 2024-01-27 10:29:10 +01:00
af623aeba6 Fix sft ci (#1279) 2024-01-26 19:18:23 +01:00
3843cfc32f Fix SFT tuner (#1278) 2024-01-26 17:49:50 +01:00
9a71e67be9 Remove tyro (#1176)
* refactor

* Remove tyro in `ppo.py`

* quick update

* update default args

* quick push

* precommit

* refactor

* quick change

* remove tyro

* quick change

* precommit

* quick change

* fix hello_world

* remove docstring differences

* add `module load cuda/12.1`

* push changes

* precommit

* make dpo runnable

* fix circular import

* quick fix

* refactor

* quick update

* path change

* update plots

* fix docs

* quick change

* Update trl/trainer/model_config.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/model_config.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/dpo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* address comments. use attn_implementation

* precommit

* remove duplicate code

* update peft.py

* fix test no op dep

* Update trl/trainer/utils.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* precommit

* add docs

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-01-26 07:51:15 -08:00
09ca565b24 FIx SFTTrainer bugs on TRL main (#1276)
* Update sft_trainer.py

* Update trl/trainer/sft_trainer.py
2024-01-26 13:50:37 +01:00
4edc688311 Only load data on main process (#1255)
* fix: only load data on main process

* define is_main_process once

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* avoid re-initializing PartialState on train dataset check

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* avoid re-initializing PartialState on eval dataset check

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* process dataset on main first to take advantage of caching

* fix typo in docs

* use decorator to manage state

* Revert "fix typo in docs"

This reverts commit 0880a188812a698f7106853245ce1ba96a036831.

* Revert "Revert "fix typo in docs""

This reverts commit ff7ee33fbeedcd0032b728d86a17cfcb10e43f9b.

* Revert "use decorator to manage state"

This reverts commit 7ac7a45949f621941fedc522f0d2ca7b29367c3a.

* use is_local_main_process instead of is_main_process

* fix: use context manager instead of attribute

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/sft_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-01-26 10:38:07 +01:00
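The pattern behind this change, sketched with accelerate's `PartialState`; the dataset and the map function are placeholders:

```python
from accelerate import PartialState
from datasets import load_dataset

# Run the map on the local main process first; other ranks then reuse the
# cached result instead of re-tokenizing the same dataset in parallel.
with PartialState().local_main_process_first():
    dataset = load_dataset("imdb", split="train[:1%]")
    dataset = dataset.map(lambda ex: {"text": ex["text"].lower()})
```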
29d439a204 [DPO] average_log_prob when loss is IPO (#1265)
* average_log_prob when loss is IPO

* updated docs with the fix
2024-01-24 12:18:04 +01:00
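A minimal sketch of what the fix changes: per-token log-probs are length-normalized (averaged) rather than summed when the loss type is IPO. Function and argument names are ours, for illustration.

```python
import torch

def sequence_logps(token_logps: torch.Tensor, mask: torch.Tensor, average: bool) -> torch.Tensor:
    # token_logps, mask: (batch, seq_len); mask is 1 on completion tokens.
    summed = (token_logps * mask).sum(-1)
    if average:  # IPO uses the length-normalized log-probability
        return summed / mask.sum(-1)
    return summed
```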
5760e5d3db Fix typo in extra_columns variable name (#1269)
Co-authored-by: Otto Laitila <otto.laitila@op.fi>
2024-01-23 14:46:13 +01:00
a3c5b7178a Update utils.py (#1256) 2024-01-22 15:32:29 +01:00
222d275b8a set dev version (#1254) 2024-01-19 11:58:47 +01:00
09ca7607d5 Release: v0.7.10 (#1253) 2024-01-19 11:52:51 +01:00
1e68753216 fix: fix loss_type and some args desc (#1247) 2024-01-18 17:20:52 +01:00
1f59eeb9bb Fix chatml template (#1248)
* first draft

* 64

* Sourab's suggestion

* wip tests

* make style happy

* add check

* docstring

* fix docstring

* Update tests/test_model_utils.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* move tests

* add todo for abstract class

* make style happy

* add slow tests and imports

* add documentation

* Update sft_trainer.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* fix template & add test

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-01-18 16:47:25 +01:00
928d14445e Add setup_chat_format for adding new special tokens to model for training chat models (#1242)
* first draft

* 64

* Sourab's suggestion

* wip tests

* make style happy

* add check

* docstring

* fix docstring

* Update tests/test_model_utils.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* move tests

* add todo for abstract class

* make style happy

* add slow tests and imports

* add documentation

* Update sft_trainer.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-01-18 11:05:32 +01:00
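Typical usage of the new helper; the checkpoint is a placeholder:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Adds the ChatML special tokens, sets the chat template, and resizes the
# token embeddings to match.
model, tokenizer = setup_chat_format(model, tokenizer)
```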
3319993bd1 Fix weird doc bug (#1244)
* Update utils.py

* Update trl/trainer/utils.py

* Update trl/trainer/utils.py
2024-01-18 10:48:56 +01:00
4fb3d0c860 Update sft_trainer.py (#1241) 2024-01-17 15:16:07 +01:00
bcccdeb6f9 [core / SFTTrainer] Fix breaking change (#1229)
* fix breaking change

* revert

* fix

* final fix

* fix

* fix tests
2024-01-17 14:45:22 +01:00
ef209e311f [core / tests ] v1 slow tests (#1218)
* v1 slow tests

* nit

* add qlora tests for DPO

* add decorator

* release memory + log reports

* report to none to avoid seg fault issues

* update setup

* fix

* add example testing

* fix nit

* change temp filename

* add workflow file

* fix comment

* add slack push script

* more tests for DPO

* add dpo example tests

* another makefile command

* fix

* add paths + clean up

* nit

* Update slow-tests.yml

* trigger tests

* up

* up

* more fixes

* fix

* final fixes

* minor fixes

* oops

* add more text

* fix

* more

* trigger CI

* up

* fix

* remove

* run the tests on 2 GPUs only

* final fix SFT

* revert config files + address comments

* fix

* add Phi

* final fixes

* final fix
2024-01-17 10:17:57 +01:00
341f6a6787 fix: improve error message when pad_token_id is not configured (#1152)
* fix: improve error message when `pad_token_id` is not configured

* Add test for error raised when pad_token is None

* Fix pre-commit errors

* Fix error in the test environment
2024-01-17 09:34:20 +01:00
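The fix the improved error message points users toward, sketched for a GPT-2-style tokenizer that ships without a pad token:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

if tokenizer.pad_token is None:
    # Reuse EOS for padding; GPT-2 has no dedicated pad token.
    tokenizer.pad_token = tokenizer.eos_token
```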
97b9fa212a Update dpo_trainer.py (#1160)
Log metrics on all distributed processes
2024-01-15 15:40:44 +01:00
a7d796c9a2 Remove a repeating line in how_to_train.md (#1226) 2024-01-15 15:18:49 +01:00
fa074e6a15 Create slow-tests.yml (#1223) 2024-01-12 09:29:57 +01:00
776939dcc4 Add support for ChatML dataset format in (#1208)
* Add support for ChatML dataset format in
SFTTrainer

* fix formatting

* fix tests

* more comment

* fix indent

* fix doc string

* Update dataset_formatting.py

* Update dataset_formatting.py

* add documentation

* Update sft_trainer.mdx

* add leonardos comment and more tests

* added more tests and fixed batching

* style

* comment in
2024-01-12 08:05:32 +01:00
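The dataset shape this PR teaches the SFTTrainer to recognize, sketched as a one-example toy dataset:

```python
from datasets import Dataset

# Conversational (ChatML-style) format: one list of role/content messages
# per example, under a "messages" column.
dataset = Dataset.from_dict({
    "messages": [
        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the capital of France?"},
            {"role": "assistant", "content": "Paris."},
        ],
    ]
})
```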
163ca9f059 Refactor RewardConfig to own module (#1221)
* Refactor RewardConfig to own module

* Fix init

* Fix import
2024-01-12 17:50:37 +11:00
2eeb7b04cf [core / Docker] Add workflow to build TRL docker images (#1215)
* add docker build

* Update docker/trl-latest-gpu/Dockerfile

* Update docker/trl-source-gpu/Dockerfile
2024-01-11 11:03:43 +01:00
9f8d0e48ad Fix args type (#1214)
* fix args type

* add args desc
2024-01-10 16:35:19 +01:00
c9b7145c75 Update Unsloth SFT, DPO docs (#1213)
* Update sft_trainer.mdx

* Update sft_trainer.mdx

* Update dpo_trainer.mdx

* Update dpo_trainer.mdx

* Update sft_trainer.mdx
2024-01-10 09:08:08 +01:00
baf3c1c293 Fix FSDP error (#1196)
* Fix FSDP error

Fixes an error where, when the `loss` field of the model output is non-empty, indexing the output as [0] returns the loss instead of the logits. This can happen with FSDP.

* Apply suggestions from code review

force return_dict

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-01-09 18:21:23 +01:00
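A sketch of the failure mode and the fix: with a populated `loss` field (as can happen under FSDP), `outputs[0]` may be the loss rather than the logits, so the fix forces a dict output and reads the logits by name. Model and inputs below are placeholders.

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = torch.tensor([[15496, 995]])  # arbitrary token ids

outputs = model(input_ids=input_ids, return_dict=True)
logits = outputs.logits  # fetched by name, never by position
```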
b181e401a7 Fix shape descriptions in calculate_loss method (#1204) 2024-01-09 14:24:41 +01:00
26da9e80cb Check tokenize params on DPOTrainer (#1197)
* Check if tokenizer and max len params are None

* Update warning messages for missing parameters
2024-01-09 14:10:22 +01:00
d6cc88ab2c set dev version (#1207) 2024-01-09 13:06:30 +01:00
7a95cc8696 release: v0.7.9 (#1206) 2024-01-09 13:02:31 +01:00
d1715514de Revert "Address issue #1122 (#1174)" (#1205)
This reverts commit d57d0f9ca46a63d370b91791352edda0154576f5.
2024-01-09 10:20:50 +01:00
d116887ed4 [DPOTrainer] Fix peft + DPO + bf16 if one uses generate_during_eval or pre-computed logits (#1203)
* fix peft + DPO + bf16

* fix

* revert old behaviour

* fix tests

* fix

* fix

* fix

* fix
2024-01-09 09:35:50 +01:00
a236c5750f Fix reported KL in PPO trainer (#1180)
* Fix reported KL in PPO trainer

Previously this always reported the estimated KL, even when using `kl_penalty = 'full'` (or `abs`, etc.).
Now we return the actual KL calculated in `compute_rewards()` and report that.

* fix test
2024-01-09 06:48:25 +01:00
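For context, a sketch of the two quantities the fix distinguishes: the single-sample estimate that was logged before, and the exact per-token KL used when `kl_penalty = 'full'`. Function names are ours, for illustration.

```python
import torch
import torch.nn.functional as F

def estimated_kl(logp_taken: torch.Tensor, logq_taken: torch.Tensor) -> torch.Tensor:
    # Single-sample estimate from the log-probs of the sampled tokens only.
    return logp_taken - logq_taken

def full_kl(logits_p: torch.Tensor, logits_q: torch.Tensor) -> torch.Tensor:
    # Exact KL(p || q) per token, summed over the full vocabulary.
    logp = F.log_softmax(logits_p, dim=-1)
    logq = F.log_softmax(logits_q, dim=-1)
    return (logp.exp() * (logp - logq)).sum(-1)
```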
4ae35afdd6 Fix instruction token masking (#1185)
* Fix instruction token masking

Fix instruction token masking if the first instruction is tokenized differently than the others, or in general if no instruction is detected before the first response.

* Bugfix for edge case

(in case either of the templates isn't found at all, ...idxs[0] might not exist)

* Add test for instruction masking fix
2024-01-09 06:41:53 +01:00
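The collator this fix concerns, with illustrative templates; the template strings are assumptions, since they depend on your prompt format:

```python
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Masks all tokens except the assistant responses; the templates mark where
# instructions and responses begin in each (possibly multi-turn) example.
collator = DataCollatorForCompletionOnlyLM(
    response_template="### Assistant:",
    instruction_template="### Human:",
    tokenizer=tokenizer,
)
```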
b21ed0ddbc set dev version (#1201) 2024-01-09 05:19:10 +01:00
384b868fe6 Release: v0.7.8 (#1200) 2024-01-09 05:13:26 +01:00
193 changed files with 24508 additions and 2803 deletions

.github/ISSUE_TEMPLATE/bug-report.yml (new file)

@@ -0,0 +1,67 @@
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve TRL
labels: [ "bug" ]
body:
- type: markdown
attributes:
value: |
Thanks for taking the time to fill out this bug report! 🤗
Before you submit your bug report:
- If it is your first time submitting, be sure to check our [bug report guidelines](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#did-you-find-a-bug)
- type: textarea
id: system-info
attributes:
label: System Info
description: Please share your system info with us. You can run the command `transformers-cli env` and copy-paste its output below.
placeholder: trl version, transformers version, platform, python version, ...
validations:
required: true
- type: checkboxes
id: information-scripts-examples
attributes:
label: Information
description: 'The problem arises when using:'
options:
- label: "The official example scripts"
- label: "My own modified scripts"
- type: checkboxes
id: information-tasks
attributes:
label: Tasks
description: "The tasks I am working on are:"
options:
- label: "An officially supported task in the `examples` folder"
- label: "My own task or dataset (give details below)"
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Reproduction
description: |
Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
If you have code snippets, error messages, stack traces please provide them here as well.
Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
placeholder: |
Steps to reproduce the behavior:
1.
2.
3.
- type: textarea
id: expected-behavior
validations:
required: true
attributes:
label: Expected behavior
description: "A clear and concise description of what you would expect to happen."

@@ -0,0 +1,31 @@
name: "\U0001F680 Feature request"
description: Submit a proposal/request for a new TRL feature
labels: [ "Feature request" ]
body:
- type: textarea
id: feature-request
validations:
required: true
attributes:
label: Feature request
description: |
A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist.
- type: textarea
id: motivation
validations:
required: true
attributes:
label: Motivation
description: |
Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too.
- type: textarea
id: contribution
validations:
required: true
attributes:
label: Your contribution
description: |
Is there any way that you could help, e.g. by submitting a PR? Make sure to read the CONTRIBUTING.MD [readme](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md)

@@ -0,0 +1,32 @@
name: "\U0001F31F New trainer addition"
description: Submit a proposal/request to implement a new trainer for a post-training method
labels: [ "New trainer" ]
body:
- type: textarea
id: description-request
validations:
required: true
attributes:
label: Method description
description: |
Put any and all important information relative to the method
- type: checkboxes
id: information-tasks
attributes:
label: Open source status
description: |
Please note that if the method implementation isn't available or model weights with training datasets aren't available, we are less likely to implement it in `trl`.
options:
- label: "The method implementation is available"
- label: "The model weights are available"
- label: "The training datasets are available"
- type: textarea
id: additional-info
attributes:
label: Provide useful links for the implementation
description: |
Please provide information regarding the implementation, the weights, and the authors.
Please mention the authors by @gh-username if you're aware of their usernames.

.github/PULL_REQUEST_TEMPLATE.md (new file)

@@ -0,0 +1,32 @@
# What does this PR do?
<!--
Congratulations! You've made it this far! You're not quite done yet though.
Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution.
Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change.
Once you're done, someone will review your PR shortly. They may suggest changes to make the code even better.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#create-a-pull-request),
Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the
[documentation guidelines](https://github.com/huggingface/trl/tree/main/docs).
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@@ -1,107 +0,0 @@
name: "Benchmark on Comment"
# https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows
on:
issue_comment:
types: [created]
jobs:
Benchmark:
strategy:
fail-fast: true
matrix:
python-version: [3.9]
os: [self-hosted]
name: Benchmark
# Only run if it's a PR and the comment contains /Benchmark
if: github.event.issue.pull_request && startsWith(github.event.comment.body, '/benchmark-trl-experiments') && contains(FromJSON('["vwxyzjn", "younesbelkada", "lvwerra", "lewtun"]'), github.actor)
runs-on: ${{ matrix.os }}
steps:
- name: Get branch of PR
uses: xt0rted/pull-request-comment-branch@v1
id: comment-branch
- name: Set latest commit status as pending
uses: myrotvorets/set-commit-status-action@master
with:
sha: ${{ steps.comment-branch.outputs.head_sha }}
token: ${{ secrets.GITHUB_TOKEN }}
status: pending
- name: Checkout `main` branch
uses: actions/checkout@v3
- name: Checkout PR branch
run: gh pr checkout $PR_NUMBER
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.issue.number }}
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
# - name: Cleanup pip packages (specific to self-hosted runners)
# run: |
# echo PATH is $PATH
# echo PYTHONPATH is $PYTHONPATH
# echo which python is $(which python)
# echo which pip is $(which pip)
# pip_list=$(pip list --format=freeze | grep -v "^pip==" | grep -v "^setuptools==")
# if [ ! -z "$pip_list" ]; then
# echo "$pip_list" | xargs pip uninstall -y
# fi
- name: Print python dependencies
run: pip list --format=freeze
- name: Install dependencies
run: |
pip install .[test,benchmark]
- name: Login
run: wandb login ${{ secrets.WANDB_API_KEY }} && huggingface-cli login --token ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
- name: Run benchmark
env:
GITHUB_CONTEXT: ${{ toJson(github) }}
PERSONAL_ACCESS_TOKEN_GITHUB: ${{ secrets.PERSONAL_ACCESS_TOKEN_GITHUB }}
run: |
COMMENT="${{ github.event.comment.body }}"
if [[ "$COMMENT" == *"/benchmark-trl-experiments benchmark/benchmark_level1.sh"* ]]; then
echo "Running benchmark/benchmark_level1.sh"
BENCHMARK_SCRIPT="benchmark/benchmark_level1.sh" BENCHMARK_PLOT_SCRIPT="benchmark/benchmark_level1_plot.sh" bash benchmark/benchmark_and_report.sh
elif [[ "$COMMENT" == *"/benchmark-trl-experiments benchmark/benchmark_level2.sh"* ]]; then
echo "Running benchmark/benchmark_level2.sh"
BENCHMARK_SCRIPT="benchmark/benchmark_level2.sh" BENCHMARK_PLOT_SCRIPT="benchmark/benchmark_level2_plot.sh" bash benchmark/benchmark_and_report.sh
elif [[ "$COMMENT" == *"/benchmark-trl-experiments benchmark/benchmark_level3.sh"* ]]; then
echo "Running benchmark/benchmark_level3.sh"
BENCHMARK_SCRIPT="benchmark/benchmark_level3.sh" BENCHMARK_PLOT_SCRIPT="benchmark/benchmark_level3_plot.sh" bash benchmark/benchmark_and_report.sh
else
echo "Invalid command in comment. Skipping execution."
fi
# send message to PR
- name: Setup Node.js 16
uses: actions/setup-node@v3
with:
node-version: 16
- name: Add workflow result as comment on PR
uses: actions/github-script@v6
if: always()
with:
script: |
const name = '${{ github.workflow }}';
const url = '${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}';
const success = '${{ job.status }}' === 'success';
const body = `${name}: ${success ? 'succeeded ✅' : 'failed ❌'}\n${url}`;
await github.rest.issues.createComment({
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
body: body
})
- name: Set latest commit status as ${{ job.status }}
uses: myrotvorets/set-commit-status-action@master
if: always()
with:
sha: ${{ steps.comment-branch.outputs.head_sha }}
token: ${{ secrets.GITHUB_TOKEN }}
status: ${{ job.status }}

@@ -14,5 +14,6 @@ jobs:
commit_sha: ${{ github.sha }}
package: trl
version_tag_suffix: ""
custom_container: huggingface/transformers-doc-builder
secrets:
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}

@@ -14,4 +14,5 @@ jobs:
commit_sha: ${{ github.event.pull_request.head.sha }}
pr_number: ${{ github.event.number }}
package: trl
version_tag_suffix: ""
version_tag_suffix: ""
custom_container: huggingface/transformers-doc-builder

@@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check out code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Cleanup
run: |

.github/workflows/docker-build.yml (new file)

@@ -0,0 +1,95 @@
name: Build Docker images (scheduled)
on:
workflow_dispatch:
workflow_call:
schedule:
- cron: "0 1 * * *"
concurrency:
group: docker-image-builds
cancel-in-progress: false
env:
CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}
jobs:
trl-latest:
name: "Latest TRL GPU"
runs-on: ubuntu-latest
steps:
- name: Cleanup disk
run: |
sudo ls -l /usr/local/lib/
sudo ls -l /usr/share/
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
sudo rm -rf /usr/local/lib/android
sudo rm -rf /usr/share/dotnet
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Check out code
uses: actions/checkout@v4
- name: Login to DockerHub
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push GPU
uses: docker/build-push-action@v4
with:
context: ./docker/trl-latest-gpu
push: true
tags: huggingface/trl-latest-gpu
- name: Post to Slack
if: always()
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: 🤗 Results of the trl-latest-gpu Docker Image build
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
trl-source:
name: "Latest TRL + HF ecosystem from source"
runs-on: ubuntu-latest
steps:
- name: Cleanup disk
run: |
sudo ls -l /usr/local/lib/
sudo ls -l /usr/share/
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
sudo rm -rf /usr/local/lib/android
sudo rm -rf /usr/share/dotnet
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Check out code
uses: actions/checkout@v4
- name: Login to DockerHub
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push GPU
uses: docker/build-push-action@v4
with:
context: ./docker/trl-source-gpu
push: true
tags: huggingface/trl-source-gpu
- name: Post to Slack
if: always()
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: 🤗 Results of the trl-source-gpu Docker Image build
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

.github/workflows/slow-tests.yml (new file)

@@ -0,0 +1,96 @@
name: Slow tests (on push)
on:
push:
branches: [ main ]
paths:
# Run only when python files are modified
- "trl/**.py"
- "examples/**.py"
env:
RUN_SLOW: "yes"
IS_GITHUB_CI: "1"
SLACK_API_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
jobs:
run_all_tests_single_gpu:
strategy:
fail-fast: false
matrix:
docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"]
runs-on: [self-hosted, single-gpu, nvidia-gpu, t4, ci]
env:
CUDA_VISIBLE_DEVICES: "0"
TEST_TYPE: "single_gpu_${{ matrix.docker-image-name }}"
container:
image: ${{ matrix.docker-image-name }}
options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true
defaults:
run:
shell: bash
steps:
- uses: actions/checkout@v4
- name: Pip install
run: |
source activate trl
pip install -e ".[test]" --no-deps
pip install pytest-reportlog parameterized
- name: Run slow SFT tests on single GPU
if: always()
run: |
source activate trl
make slow_tests
- name: Generate Report
if: always()
run: |
pip install slack_sdk tabulate
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
run_all_tests_multi_gpu:
strategy:
fail-fast: false
matrix:
docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"]
runs-on: [self-hosted, multi-gpu, nvidia-gpu, t4, ci]
env:
CUDA_VISIBLE_DEVICES: "0,1"
TEST_TYPE: "multi_gpu_${{ matrix.docker-image-name }}"
container:
image: ${{ matrix.docker-image-name }}
options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true
defaults:
run:
shell: bash
steps:
- uses: actions/checkout@v4
- name: Pip install
run: |
source activate trl
pip install -e ".[test]" --no-deps
pip install pytest-reportlog parameterized
- name: Run slow SFT tests on Multi GPU
if: always()
run: |
source activate trl
make slow_tests
- name: Run end-to-end examples tests on multi GPU
if: always()
run: |
source activate trl
pip install deepspeed
make test_examples
- name: Generate Reports
if: always()
run: |
pip install slack_sdk tabulate
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
python scripts/log_example_reports.py --text_file_name temp_results_sft_tests.txt >> $GITHUB_STEP_SUMMARY
python scripts/log_example_reports.py --text_file_name temp_results_dpo_tests.txt >> $GITHUB_STEP_SUMMARY
rm *.txt

@@ -12,10 +12,10 @@
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: 3.8

.github/workflows/tests-main.yml (new file)

@@ -0,0 +1,46 @@
name: tests on transformers PEFT main
on:
push:
branches: [ main ]
env:
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
jobs:
tests:
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11']
os: ['ubuntu-latest', 'windows-latest']
fail-fast: false
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: "pip"
cache-dependency-path: |
setup.py
requirements.txt
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# install PEFT & transformers from source
pip install -U git+https://github.com/huggingface/peft.git
pip install -U git+https://github.com/huggingface/transformers.git
# cpu version of pytorch
pip install ".[test, diffusers]"
- name: Test with pytest
run: |
make test
- name: Post to Slack
if: always()
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: 🤗 Results of the TRL CI on transformers/PEFT main
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

@@ -5,6 +5,16 @@ on:
branches: [ main ]
pull_request:
branches: [ main ]
paths:
# Run only when relevant files are modified
- "trl/**.py"
- "examples/**.py"
- "scripts/**.py"
- ".github/**.yml"
- "tests/**.py"
env:
TQDM_DISABLE: 1
jobs:
check_code_quality:
@@ -14,15 +24,15 @@
python-version: [3.9]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: recursive
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- uses: pre-commit/action@v2.0.3
- uses: pre-commit/action@v3.0.1
with:
extra_args: --all-files
@@ -30,13 +40,13 @@
needs: check_code_quality
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10']
python-version: ['3.9', '3.10', '3.11']
os: ['ubuntu-latest', 'windows-latest']
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: "pip"
@@ -46,8 +56,11 @@
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# install PEFT & transformers from source
pip install -U git+https://github.com/huggingface/peft.git
pip install -U git+https://github.com/huggingface/transformers.git
# cpu version of pytorch
pip install -e ".[test, peft, diffusers]"
pip install ".[test, diffusers]"
- name: Test with pytest
run: |
make test
@@ -56,9 +69,9 @@
needs: check_code_quality
runs-on: 'ubuntu-latest'
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python 3.9
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: '3.9'
cache: "pip"
@@ -72,4 +85,4 @@
pip install .[test]
- name: Test with pytest
run: |
make test
make test

.github/workflows/trufflehog.yml (new file)

@@ -0,0 +1,15 @@
on:
push:
name: Secret Leaks
jobs:
trufflehog:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main

.gitignore

@@ -143,4 +143,7 @@ checklink/cookies.txt
# wandb files
nbs/wandb/
examples/notebooks/wandb/
wandb/
wandb/
# cli scripts that are symlinked from `examples/scripts`
trl/commands/scripts/

@@ -1,37 +1,10 @@
repos:
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.0
hooks:
- id: isort
args:
- --profile=black
- --skip-glob=wandb/**/*
- --thirdparty=wandb
- repo: https://github.com/myint/autoflake
rev: v1.4
hooks:
- id: autoflake
args:
- -r
- --exclude=wandb,__init__.py
- --in-place
- --remove-unused-variables
- --remove-all-unused-imports
- repo: https://github.com/python/black
rev: 22.3.0
hooks:
- id: black
args:
- --line-length=119
- --target-version=py38
- --exclude=wandb
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
hooks:
- id: flake8
args:
- --ignore=E203,E501,W503,E128
- --max-line-length=119
- id: ruff
args: [ --fix ]
- id: ruff-format
# - repo: https://github.com/codespell-project/codespell
# rev: v2.1.0

CODE_OF_CONDUCT.md (new file)

@@ -0,0 +1,133 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall
community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or advances of
any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address,
without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
feedback@huggingface.co.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series of
actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within the
community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
Community Impact Guidelines were inspired by
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
[https://www.contributor-covenant.org/translations][translations].
[homepage]: https://www.contributor-covenant.org
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
[Mozilla CoC]: https://github.com/mozilla/diversity
[FAQ]: https://www.contributor-covenant.org/faq
[translations]: https://www.contributor-covenant.org/translations

@@ -1,53 +1,258 @@
# How to contribute
# How to contribute to TRL?
## How to get started
Everyone is welcome to contribute, and we value everybody's contribution. Code
contributions are not the only way to help the community. Answering questions, helping
others, and improving the documentation are also immensely valuable.
Before you start contributing make sure you installed all the dev tools:
It also helps us if you spread the word! Reference the library in blog posts
about the awesome projects it made possible, shout out on Twitter every time it has
helped you, or simply ⭐️ the repository to say thank you.
However you choose to contribute, please be mindful and respect our
[code of conduct](https://github.com/huggingface/trl/blob/main/CODE_OF_CONDUCT.md).
**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).**
## Ways to contribute
There are several ways you can contribute to TRL:
* Fix outstanding issues with the existing code.
* Submit issues related to bugs or desired new features.
* Implement trainers for new post-training algorithms.
* Contribute to the examples or to the documentation.
If you don't know where to start, there is a special [Good First
Issue](https://github.com/huggingface/trl/contribute) listing. It will give you a list of
open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over.
For something slightly more challenging, you can also take a look at the [Good Second Issue](https://github.com/huggingface/trl/labels/Good%20Second%20Issue) list. In general though, if you feel like you know what you're doing, go for it and we'll help you get there! 🚀
> All contributions are equally valuable to the community. 🥰
Before you start contributing make sure you have installed all the dev tools:
```bash
pip install -e ".[dev]"
make dev
```
## Fixing outstanding issues
If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#submitting-a-pull-request-pr) and open a Pull Request!
## Submitting a bug-related issue or feature request
Do your best to follow these guidelines when submitting a bug-related issue or a feature request. It will make it easier for us to come back to you quickly and with good feedback.
### Did you find a bug?
The TRL library is robust and reliable thanks to users who report the problems they encounter.
Before you report an issue, we would really appreciate it if you could **make sure the bug was not already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code.
Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so we can quickly resolve it:
* Your **OS type and version**, **Python**, **PyTorch**, **TRL** and **Transformers** versions.
* A short, self-contained code snippet that allows us to reproduce the bug in less than 30s.
* The *full* traceback and complete error messages if an exception is raised.
* Any other additional information, like screenshots, you think may help.
To get the OS and software versions automatically, run the following command:
```bash
transformers-cli env
```
### Do you want a new feature?
If there is a new feature you'd like to see in TRL, please open an issue and describe:
1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it a feature related to something you need for a project? Is it something you worked on and think it could benefit the community?
Whatever it is, we'd love to hear about it!
2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better we'll be able to help you.
3. Provide a *code snippet* that demonstrates the feature's usage.
4. If the feature is related to a paper, please include a link.
If your issue is well written we're already 80% of the way there by the time you create it.
## Do you want to implement a new trainer?
New post-training methods are published on a frequent basis and those which satisfy the following criteria are good candidates to be integrated in TRL:
* **Simplicity:** does the new method achieve similar performance as prior methods, but with less complexity? A good example is [Direct Preference Optimization](https://arxiv.org/abs/2305.18290) (DPO), which provided a simpler and compelling alternative to RLHF methods.
* **Efficiency:** does the new method provide a significant improvement in training efficiency? A good example is [Odds Ratio Preference Optimization](https://arxiv.org/abs/2403.07691v2), which utilises a similar objective as DPO, but requires half the GPU VRAM.
Methods which only provide incremental improvements at the expense of added complexity or compute costs are unlikely to be included in TRL.
If you want to implement a trainer for a new post-training method, first open an issue and provide the following information:
* A short description of the method and a link to the paper.
* Link to the implementation if it is open-sourced.
* Link to model weights trained with the method if they are available.
Based on the community and maintainer feedback, the next step will be to implement the trainer and config classes. See the following examples for inspiration:
* Paired preference optimisation: [`dpo_trainer.py`](./trl/trainer/dpo_trainer.py) and [`dpo_config.py`](./trl/trainer/dpo_config.py)
* RL-based optimisation: [`rloo_trainer.py`](./trl/trainer/rloo_trainer.py) and [`rloo_config.py`](./trl/trainer/rloo_config.py)
* Online optimisation: [`online_dpo_trainer.py`](./trl/trainer/online_dpo_trainer.py) and [`online_dpo_config.py`](./trl/trainer/online_dpo_config.py)
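If it helps to see the shape of that convention, here is a minimal sketch of a config/trainer pair. The names `MyMethodConfig` and `MyMethodTrainer` are hypothetical, and a real TRL trainer involves considerably more (data collation, loss bookkeeping, logging); this only illustrates the config-plus-trainer split used by the examples above.
```python
from dataclasses import dataclass
from transformers import Trainer, TrainingArguments


@dataclass
class MyMethodConfig(TrainingArguments):
    """Hypothetical config: the standard training arguments plus method-specific hyperparameters."""

    beta: float = 0.1  # e.g. the temperature of an implicit reward


class MyMethodTrainer(Trainer):
    """Hypothetical trainer: override `compute_loss` with the method's objective."""

    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        loss = outputs.loss  # replace with the new post-training objective
        return (loss, outputs) if return_outputs else loss
```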
## Do you want to add documentation?
We're always looking for improvements that make the documentation clearer and more accurate. Please let us know how it can be improved, such as typos, dead links, and any missing, unclear, or inaccurate content. We'll be happy to make the changes or help you make a contribution if you're interested!
## Submitting a pull request (PR)
Before writing code, we strongly advise you to search through the existing PRs or issues to make sure that nobody is already working on the same thing. If you are unsure, it is always a good idea to open an issue to get some feedback.
A few guidelines to keep in mind:
* Keep each PR focused. While it may be more convenient, do not combine several unrelated fixes in a single PR. Create as many branches as needed to keep each PR focused.
* Do not mix style changes/fixes with "functional" changes. Such PRs are very difficult to review and will most likely get rejected.
* Do not add or remove vertical whitespace. Preserve the original style of the file you edit as much as you can.
* Do not turn an already submitted PR into your development playground. If, after you submit a PR, you discover that more work is needed, close the PR, do the required work, and then submit a new PR. Otherwise each of your commits requires attention from the maintainers of the project.
* If, however, you submitted a PR and received a request for changes, you should proceed with commits inside that PR, so that the maintainer can see the incremental fixes and won't need to review the whole PR again. In the exceptional case where you realize it will take many commits to complete the requests, it's probably best to close the PR, do the work, and then submit it again. Use common sense where you'd choose one way over another.
* If your PR fixes a bug, include a test that fails without your patch and passes with it, and make sure the PR description clearly describes the problem and solution, including the relevant issue number if applicable.
You will need basic `git` proficiency to be able to contribute to
TRL. `git` is not the easiest tool to use but it has the greatest
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
Git](https://git-scm.com/book/en/v2) is a very good reference.
Follow these steps to start contributing:
1. Fork the [repository](https://github.com/huggingface/trl) by
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
under your GitHub user account.
2. Clone your fork to your local disk, and add the base repository as a remote. The following command
assumes you have your public SSH key uploaded to GitHub (see this
[guide](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository) for more information).
```bash
$ git clone git@github.com:<your Github handle>/trl.git
$ cd trl
$ git remote add upstream https://github.com/huggingface/trl.git
```
3. Create a new branch to hold your development changes, and do this for every new PR you work on.
Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
```bash
$ git checkout main
$ git fetch upstream
$ git merge upstream/main
```
Once your `main` branch is synchronized, create a new branch from it:
```bash
$ git checkout -b a-descriptive-name-for-my-changes
```
**Do not** work on the `main` branch.
4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:
```bash
$ make dev
```
(If TRL was already installed in the virtual environment, remove
it with `pip uninstall trl` before reinstalling it.)
Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using
the provided Dev Container. Documentation on how to get started with dev containers is available [here](https://code.visualstudio.com/docs/remote/containers).
5. Develop the features on your branch.
As you work on the features, you should make sure that the test suite
passes. Run the tests impacted by your changes like this:
```bash
$ pytest tests/<TEST_TO_RUN>.py
```
> For the following commands leveraging the `make` utility, we recommend using the WSL system when running on
> Windows. More information [here](https://docs.microsoft.com/en-us/windows/wsl/about).
You can also run the full suite with the following command.
```bash
$ make test
```
TRL relies on `ruff` to format its source code
consistently. After you make changes, apply automatic style corrections and code verifications
that can't be automated in one go with:
```bash
$ make precommit
```
This target is also optimized to only work with files modified by the PR you're working on.
Once you're happy with your changes, add changed files using `git add` and
make a commit with `git commit` to record your changes locally:
```bash
$ git add modified_file.py
$ git commit
```
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
It is a good idea to sync your copy of the code with the original
repository regularly. This way you can quickly account for changes:
```bash
$ git fetch upstream
$ git rebase upstream/main
```
Push the changes to your account using:
```bash
$ git push -u origin a-descriptive-name-for-my-changes
```
6. Once you are satisfied (**and the checklist below is happy too**), go to the
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
to the project maintainers for review.
7. It's OK if maintainers ask you for changes. It happens to core contributors
too! To make the changes visible to everyone in the pull request, work in your local
branch and push them to your fork. They will automatically appear in
the pull request.
### Checklist
1. The title of your pull request should be a summary of its contribution;
2. If your pull request addresses an issue, please mention the issue number in
the pull request description to make sure they are linked (and people
consulting the issue know you are working on it);
3. To indicate a work in progress please prefix the title with `[WIP]`, or mark
the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate
it from PRs ready to be merged;
4. Make sure existing tests pass;
5. Add high-coverage tests. No quality testing = no merge.
### Tests
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in
the [tests folder](https://github.com/huggingface/trl/tree/main/tests).
We use `pytest` to run the tests. From the root of the
repository, here's how to run them:
```bash
$ python -m pytest -sv ./tests
```
In fact, that's how `make test` is implemented (sans the `pip install` line)!
You can specify a smaller set of tests in order to test only the feature
you're working on.


@ -2,4 +2,4 @@ include settings.ini
include LICENSE
include CONTRIBUTING.md
include README.md
recursive-exclude * __pycache__


@ -1,9 +1,18 @@
.PHONY: test precommit benchmark_core benchmark_aux common_tests slow_tests test_examples tests_gpu
check_dirs := examples tests trl
ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
COMMAND_FILES_PATH = `pwd`/commands
dev:
[ -L "$(pwd)/trl/commands/scripts" ] && unlink "$(pwd)/trl/commands/scripts" || true
pip install -e ".[dev]"
ln -s `pwd`/examples/scripts/ `pwd`/trl/commands
test:
python -m pytest -n auto --dist=loadfile -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504|not less than or equal to 0.01)' ./tests/
precommit:
pre-commit run --all-files
@ -13,3 +22,22 @@ benchmark_core:
benchmark_aux:
bash ./benchmark/benchmark_aux.sh
tests_gpu:
python -m pytest tests/test_* $(if $(IS_GITHUB_CI),--report-log "common_tests.log",)
slow_tests:
python -m pytest tests/slow/test_* $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
test_examples:
touch temp_results_sft_tests.txt
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_sft.sh; \
echo $$?','$${file} >> temp_results_sft_tests.txt; \
done
touch temp_results_dpo_tests.txt
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_dpo.sh; \
echo $$?','$${file} >> temp_results_dpo_tests.txt; \
done

README.md

@ -3,7 +3,7 @@
</div>
# TRL - Transformer Reinforcement Learning
> Full stack library to fine-tune and align large language models.
<p align="center">
<a href="https://github.com/huggingface/trl/blob/main/LICENSE">
@ -20,61 +20,73 @@
## What is it?
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/TRL-readme.png">
</div>
The `trl` library is a full stack tool to fine-tune and align transformer language and diffusion models using methods such as Supervised Fine-tuning (SFT), Reward Modeling (RM), and Proximal Policy Optimization (PPO), as well as Direct Preference Optimization (DPO).
The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library and thus allows the use of any model architecture available there.
## How PPO works
Fine-tuning a language model via PPO consists of roughly three steps:
1. **Rollout**: The language model generates a response or continuation based on a query, which could be the start of a sentence.
2. **Evaluation**: The query and response are evaluated with a function, model, human feedback, or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.
This process is illustrated in the sketch below:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png" width="800">
<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
</div>
## Highlights
- **`Efficient and scalable`**:
- [`accelerate`](https://github.com/huggingface/accelerate) is the backbone of `trl`, allowing model training to scale from a single GPU to a large multi-node cluster with methods such as DDP and DeepSpeed.
- [`PEFT`](https://github.com/huggingface/peft) is fully integrated and allows training even the largest models on modest hardware with quantisation and methods such as LoRA or QLoRA.
- [`unsloth`](https://github.com/unslothai/unsloth) is also integrated and allows significantly faster training with dedicated kernels.
- **`CLI`**: With the [CLI](https://huggingface.co/docs/trl/clis) you can fine-tune and chat with LLMs without writing any code using a single command and a flexible config system.
- **`Trainers`**: The Trainer classes are an abstraction to apply many fine-tuning methods with ease such as the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.DPOTrainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer), [`CPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.CPOTrainer), and [`ORPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.ORPOTrainer).
- **`AutoModels`**: The [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead) classes add an additional value head to the model, which allows training them with RL algorithms such as PPO (see the short example after this list).
- **`Examples`**: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [StackLlama example](https://huggingface.co/blog/stackllama), etc. following the [examples](https://github.com/huggingface/trl/tree/main/examples).
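For instance, loading a causal LM with a value head is a one-liner (a minimal sketch using the documented `from_pretrained` API):
```python
from trl import AutoModelForCausalLMWithValueHead

# Wraps a causal LM and adds a scalar value head on top,
# which RL trainers such as PPO use as the value function.
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
```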
## Installation
### Python package
Install the library with `pip`:
```bash
pip install trl
```
### From source
If you want to use the latest features before an official release you can install from source:
```bash
pip install git+https://github.com/huggingface/trl.git
```
### Repository
If you want to use the examples you can clone the repository with the following command:
```bash
git clone https://github.com/huggingface/trl.git
```
## Command Line Interface (CLI)
You can use the TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT) and Direct Preference Optimization (DPO), or test your aligned model with the chat CLI:
**SFT:**
```bash
trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir opt-sft-imdb
```
**DPO:**
```bash
trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style --output_dir opt-sft-hh-rlhf
```
**Chat:**
```bash
trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
```
Read more about the CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.
## How to use
For more flexibility and control over the training, you can use the dedicated trainer classes to fine-tune the model in Python.
### `SFTTrainer`
This is a basic example of how to use the `SFTTrainer` from the library. The `SFTTrainer` is a light wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
```python
# imports
@ -98,7 +110,7 @@ trainer.train()
### `RewardTrainer`
This is a basic example of how to use the `RewardTrainer` from the library. The `RewardTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset.
```python
# imports
@ -124,7 +136,7 @@ trainer.train()
### `PPOTrainer`
This is a basic example of how to use the `PPOTrainer` from the library. Based on a query the language model creates a response which is then evaluated. The evaluation could be a human in the loop or another model's output.
```python
# imports
@ -135,14 +147,13 @@ from trl.core import respond_to_batch
# get models
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
ref_model = create_reference_model(model)
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
# initialize trainer
ppo_config = PPOConfig(batch_size=1, mini_batch_size=1)
# encode a query
query_txt = "This morning I went to the "
@ -152,7 +163,7 @@ query_tensor = tokenizer.encode(query_txt, return_tensors="pt")
response_tensor = respond_to_batch(model, query_tensor)
# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer)
# define a reward for response
# (this could be any reward such as human feedback or output from another model)
@ -162,13 +173,50 @@ reward = [torch.tensor(1.0)]
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```
### `DPOTrainer`
`DPOTrainer` implements the [Direct Preference Optimization algorithm](https://huggingface.co/papers/2305.18290). This is a basic example of how to use the `DPOTrainer` from the library. The `DPOTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom preference dataset.
```python
# imports
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOTrainer
# load model and dataset - dataset needs to be in a specific format
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
...
# load trainer
trainer = DPOTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
)
# train
trainer.train()
```
## Development
If you want to contribute to `trl` or customize it to your needs, make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and do a dev install:
```bash
git clone https://github.com/huggingface/trl.git
cd trl/
make dev
```
## References
### Proximal Policy Optimisation
The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://huggingface.co/papers/1909.08593), [code](https://github.com/openai/lm-human-preferences)].
### Direct Preference Optimization
DPO is based on the original implementation of **"Direct Preference Optimization: Your Language Model is Secretly a Reward Model"** by E. Mitchell et al. \[[paper](https://huggingface.co/papers/2305.18290), [code](https://github.com/eric-mitchell/direct-preference-optimization)]
### Language models
The language models utilize the `transformers` library by 🤗 Hugging Face.
## Citation


@ -1,20 +1,5 @@
#### Step 1: create a work directory:
# this is necessary because another github action job will remove
# the entire directory, which slurm depends on.
# https://stackoverflow.com/questions/4632028/how-to-create-a-temporary-directory
MY_SLURM_TMP_DIR=/fsx/costa/slurm_tmpdir
mkdir -p $MY_SLURM_TMP_DIR
WORK_DIR=`mktemp -d -p "$MY_SLURM_TMP_DIR"`
cp -r "$PWD" "$WORK_DIR"
cd "$WORK_DIR/$(basename "$PWD")"
echo WORK_DIR: $WORK_DIR
#### Step 2: actual work starts:
echo PATH is $PATH
echo PYTHONPATH is $PYTHONPATH
echo which python is $(which python)
export WANDB_ENTITY=huggingface
export WANDB_PROJECT=trl
bash $BENCHMARK_SCRIPT > output.txt
# Extract Job IDs into an array


@ -1,6 +1,39 @@
# hello world experiment
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --log_with wandb" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
--slurm-nodes 1 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 12 \
--slurm-template-path benchmark/trl.slurm_template
python benchmark/benchmark.py \
--command "python examples/scripts/dpo.py --model_name_or_path=gpt2 --per_device_train_batch_size 4 --max_steps 1000 --learning_rate 1e-3 --gradient_accumulation_steps 1 --logging_steps 10 --eval_steps 500 --output_dir="dpo_anthropic_hh" --optim adamw_torch --warmup_steps 150 --report_to wandb --bf16 --logging_first_step --no_remove_unused_columns" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
--slurm-nodes 1 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 12 \
--slurm-template-path benchmark/trl.slurm_template
python benchmark/benchmark.py \
--command "python examples/scripts/sft.py --model_name_or_path="facebook/opt-350m" --report_to="wandb" --learning_rate=1.41e-5 --per_device_train_batch_size=64 --gradient_accumulation_steps=16 --output_dir="sft_openassistant-guanaco" --logging_steps=1 --num_train_epochs=3 --max_steps=-1 --push_to_hub --gradient_checkpointing" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
--slurm-nodes 1 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 12 \
--slurm-template-path benchmark/trl.slurm_template
python benchmark/benchmark.py \
--command "python examples/scripts/reward_modeling.py --model_name_or_path=facebook/opt-350m --output_dir="reward_modeling_anthropic_hh" --per_device_train_batch_size=64 --num_train_epochs=1 --gradient_accumulation_steps=16 --gradient_checkpointing=True --learning_rate=1.41e-5 --report_to="wandb" --remove_unused_columns=False --optim="adamw_torch" --logging_steps=10 --eval_strategy="steps" --max_length=512" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \


@ -9,7 +9,37 @@ python -m openrlbenchmark.rlops_multi_metrics \
--no-check-empty-runs \
--pc.ncols 2 \
--pc.ncols-legend 1 \
--output-filename benchmark/trl/$FOLDER_STRING/ppo \
--scan-history
python -m openrlbenchmark.rlops_multi_metrics \
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=output_dir&cen=_name_or_path&metrics=train/rewards/accuracies&metrics=train/loss' \
"gpt2$TAGS_STRING" \
--env-ids dpo_anthropic_hh \
--no-check-empty-runs \
--pc.ncols 2 \
--pc.ncols-legend 1 \
--output-filename benchmark/trl/$FOLDER_STRING/dpo \
--scan-history
python -m openrlbenchmark.rlops_multi_metrics \
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=output_dir&cen=_name_or_path&metrics=train/loss&metrics=eval/accuracy&metrics=eval/loss' \
"facebook/opt-350m$TAGS_STRING" \
--env-ids reward_modeling_anthropic_hh \
--no-check-empty-runs \
--pc.ncols 2 \
--pc.ncols-legend 1 \
--output-filename benchmark/trl/$FOLDER_STRING/reward_modeling \
--scan-history
python -m openrlbenchmark.rlops_multi_metrics \
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=output_dir&cen=_name_or_path&metrics=train/loss' \
"facebook/opt-350m$TAGS_STRING" \
--env-ids sft_openassistant-guanaco \
--no-check-empty-runs \
--pc.ncols 2 \
--pc.ncols-legend 1 \
--output-filename benchmark/trl/$FOLDER_STRING/sft \
--scan-history
python benchmark/upload_benchmark.py \


@ -1,6 +1,6 @@
# compound experiments: gpt2xl + grad_accu
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_gpt2xl_grad_accu --ppo_config.model_name gpt2-xl --ppo_config.mini_batch_size 16 --ppo_config.gradient_accumulation_steps 8 --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name ppo_gpt2xl_grad_accu --model_name gpt2-xl --mini_batch_size 16 --gradient_accumulation_steps 8 --log_with wandb" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
@ -12,7 +12,7 @@ python benchmark/benchmark.py \
# compound experiments: Cerebras-GPT-6.7B + deepspeed zero2 + grad_accu
python benchmark/benchmark.py \
--command "accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml examples/scripts/ppo.py --ppo_config.exp_name ppo_Cerebras-GPT-6.7B_grad_accu_deepspeed_stage2 --ppo_config.batch_size 32 --ppo_config.mini_batch_size 32 --ppo_config.log_with wandb --ppo_config.model_name cerebras/Cerebras-GPT-6.7B --ppo_config.reward_model sentiment-analysis:cerebras/Cerebras-GPT-6.7B" \
--command "accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml examples/scripts/ppo.py --exp_name ppo_Cerebras-GPT-6.7B_grad_accu_deepspeed_stage2 --batch_size 32 --mini_batch_size 32 --log_with wandb --model_name cerebras/Cerebras-GPT-6.7B --reward_model sentiment-analysis:cerebras/Cerebras-GPT-6.7B" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \


@ -1,6 +1,6 @@
## w/ and w/o gradient accumulation
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --exp_name ppo_step_grad_accu --mini_batch_size 1 --gradient_accumulation_steps 128 --log_with wandb" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
@ -12,7 +12,7 @@ python benchmark/benchmark.py \
## w/ different models (gpt2, gpt2-xl, falcon, llama2)
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --exp_name ppo_gpt2 --log_with wandb" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
@ -22,7 +22,7 @@ python benchmark/benchmark.py \
--slurm-total-cpus 12 \
--slurm-template-path benchmark/trl.slurm_template
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --exp_name ppo_falcon_rw_1b --model_name tiiuae/falcon-rw-1b --log_with wandb" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
@ -35,7 +35,7 @@ python benchmark/benchmark.py \
## w/ and w/o PEFT
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --exp_name ppo_peft --use_peft --log_with wandb" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \


@ -1,6 +1,6 @@
#!/bin/bash
#SBATCH --job-name=trl
#SBATCH --partition=hopper-cpu
#SBATCH --ntasks=1
#SBATCH --output=slurm/logs/%x_%j.out


@ -0,0 +1,3 @@
BENCHMARK_SCRIPT="benchmark/benchmark_level1.sh" \
BENCHMARK_PLOT_SCRIPT="benchmark/benchmark_level1_plot.sh" \
bash benchmark/benchmark_and_report.sh


@ -1,16 +1,19 @@
#!/bin/bash
#SBATCH --job-name=trl
#SBATCH --partition=hopper-prod
#SBATCH --gpus-per-task={{gpus_per_task}}
#SBATCH --cpus-per-gpu={{cpus_per_gpu}}
#SBATCH --ntasks={{ntasks}}
#SBATCH --output=slurm/logs/%x_%j.out
#SBATCH --array={{array}}
#SBATCH --exclude=ip-26-0-156-239,ip-26-0-148-151,ip-26-0-146-212,ip-26-0-145-137,ip-26-0-146-249,ip-26-0-146-149,ip-26-0-147-233,ip-26-0-145-154,ip-26-0-144-35,ip-26-0-144-189,ip-26-0-146-183,ip-26-0-147-120,ip-26-0-144-95,ip-26-0-145-193
##SBATCH --exclude=ip-26-0-149-199
module load cuda/12.1
{{nodes}}
seeds={{seeds}}
seed=${seeds[$SLURM_ARRAY_TASK_ID % {{len_seeds}}]}
echo "Running task $SLURM_ARRAY_TASK_ID with seed: $seed"
srun {{command}} --seed $seed

commands/run_dpo.sh

@ -0,0 +1,58 @@
#!/bin/bash
# This script runs a DPO example end-to-end on a tiny model using different possible configurations
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_dpo/"
MODEL_NAME="trl-internal-testing/tiny-random-LlamaForCausalLM"
DATASET_NAME="trl-internal-testing/hh-rlhf-helpful-base-trl-style"
MAX_STEPS=5
BATCH_SIZE=2
SEQ_LEN=128
# Handle extra arguments in case one passes accelerate configs.
EXTRA_ACCELERATE_ARGS=""
EXTRA_TRAINING_ARGS="""--use_peft \
--load_in_4bit
"""
# Set your number of GPUs here
NUM_GPUS=2
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
EXTRA_ACCELERATE_ARGS=""
else
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
EXTRA_TRAINING_ARGS="--fp16"
else
echo "Keeping QLoRA + PEFT"
fi
fi
CMD="""
accelerate launch $EXTRA_ACCELERATE_ARGS \
--num_processes $NUM_GPUS \
--mixed_precision 'fp16' \
`pwd`/examples/scripts/dpo.py \
--model_name_or_path $MODEL_NAME \
--dataset_name $DATASET_NAME \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BATCH_SIZE \
--max_length $SEQ_LEN \
$EXTRA_TRAINING_ARGS
"""
echo "Starting program..."
{ # try
echo $CMD
eval "$CMD"
} || { # catch
# save log for exception
echo "Operation Failed!"
exit 1
}
exit 0

commands/run_sft.sh

@ -0,0 +1,60 @@
#!/bin/bash
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_sft/"
MODEL_NAME="trl-internal-testing/tiny-random-LlamaForCausalLM"
DATASET_NAME="imdb"
MAX_STEPS=5
BATCH_SIZE=2
SEQ_LEN=128
# Handle extra arguments in case one passes accelerate configs.
EXTRA_ACCELERATE_ARGS=""
EXTRA_TRAINING_ARGS="""--use_peft \
--load_in_4bit
"""
# Set your number of GPUs here
NUM_GPUS=2
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
EXTRA_ACCELERATE_ARGS=""
else
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
EXTRA_TRAINING_ARGS="--fp16"
else
echo "Keeping QLoRA + PEFT"
fi
fi
CMD="""
accelerate launch $EXTRA_ACCELERATE_ARGS \
--num_processes $NUM_GPUS \
--mixed_precision 'fp16' \
`pwd`/examples/scripts/sft.py \
--model_name $MODEL_NAME \
--dataset_name $DATASET_NAME \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--dataset_text_field 'text' \
--per_device_train_batch_size $BATCH_SIZE \
--max_seq_length $SEQ_LEN \
$EXTRA_TRAINING_ARGS
"""
echo "Starting program..."
{ # try
echo $CMD
eval "$CMD"
} || { # catch
# save log for exception
echo "Operation Failed!"
exit 1
}
exit 0


@ -0,0 +1,66 @@
# Builds GPU docker image of PyTorch
# Uses multi-staged approach to reduce size
# Stage 1
# Use base conda image to reduce time
FROM continuumio/miniconda3:latest AS compile-image
# Specify py version
ENV PYTHON_VERSION=3.10
# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
RUN apt-get update && \
apt-get install -y curl git wget software-properties-common git-lfs && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
# Install audio-related libraries
RUN apt-get update && \
apt install -y ffmpeg
RUN apt install -y libsndfile1-dev
RUN git lfs install
# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
RUN python3 -m pip install --no-cache-dir --upgrade pip
# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
# We don't install pytorch here yet since CUDA isn't available
# instead we use the direct torch wheel
ENV PATH /opt/conda/envs/trl/bin:$PATH
# Activate our bash shell
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]
# Stage 2
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
COPY --from=compile-image /opt/conda /opt/conda
ENV PATH /opt/conda/bin:$PATH
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]
RUN source activate trl && \
python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
# Install apt libs
RUN apt-get update && \
apt-get install -y curl git wget && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
# Activate the conda env and install transformers + accelerate from source
RUN source activate trl && \
python3 -m pip install -U --no-cache-dir \
librosa \
"soundfile>=0.12.1" \
scipy \
transformers \
accelerate \
peft \
trl[test]@git+https://github.com/huggingface/trl
RUN source activate trl && \
pip freeze | grep trl
RUN echo "source activate trl" >> ~/.profile
# Activate the virtualenv
CMD ["/bin/bash"]


@ -0,0 +1,66 @@
# Builds GPU docker image of PyTorch
# Uses multi-staged approach to reduce size
# Stage 1
# Use base conda image to reduce time
FROM continuumio/miniconda3:latest AS compile-image
# Specify py version
ENV PYTHON_VERSION=3.10
# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
RUN apt-get update && \
apt-get install -y curl git wget software-properties-common git-lfs && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
# Install audio-related libraries
RUN apt-get update && \
apt install -y ffmpeg
RUN apt install -y libsndfile1-dev
RUN git lfs install
# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
RUN python3 -m pip install --no-cache-dir --upgrade pip
# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
# We don't install pytorch here yet since CUDA isn't available
# instead we use the direct torch wheel
ENV PATH /opt/conda/envs/trl/bin:$PATH
# Activate our bash shell
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]
# Stage 2
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
COPY --from=compile-image /opt/conda /opt/conda
ENV PATH /opt/conda/bin:$PATH
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]
RUN source activate trl && \
python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
# Install apt libs
RUN apt-get update && \
apt-get install -y curl git wget && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
# Activate the conda env and install transformers + accelerate from source
RUN source activate trl && \
python3 -m pip install -U --no-cache-dir \
librosa \
"soundfile>=0.12.1" \
scipy \
git+https://github.com/huggingface/transformers \
git+https://github.com/huggingface/accelerate \
git+https://github.com/huggingface/peft \
trl[test]@git+https://github.com/huggingface/trl
RUN source activate trl && \
pip freeze | grep transformers
RUN echo "source activate trl" >> ~/.profile
# Activate the virtualenv
CMD ["/bin/bash"]


@ -5,6 +5,8 @@
title: Quickstart
- local: installation
title: Installation
- local: clis
title: Get started with Command Line Interfaces (CLIs)
- local: how_to_train
title: PPO Training FAQ
- local: use_model
@ -25,14 +27,34 @@
title: Supervised Fine-Tuning
- local: ppo_trainer
title: PPO Trainer
- local: ppov2_trainer
title: PPOv2 Trainer
- local: rloo_trainer
title: RLOO Trainer
- local: best_of_n
title: Best of N Sampling
- local: dpo_trainer
title: DPO Trainer
- local: online_dpo_trainer
title: Online DPO Trainer
- local: kto_trainer
title: KTO Trainer
- local: bco_trainer
title: BCO Trainer
- local: cpo_trainer
title: CPO Trainer
- local: ddpo_trainer
title: Denoising Diffusion Policy Optimization
- local: alignprop_trainer
title: AlignProp Trainer
- local: orpo_trainer
title: ORPO Trainer
- local: iterative_sft_trainer
title: Iterative Supervised Fine-Tuning
- local: callbacks
title: Callback Classes
- local: judges
title: Judge Classes
- local: text_environments
title: Text Environments
title: API


@ -0,0 +1,91 @@
# Aligning Text-to-Image Diffusion Models with Reward Backpropagation
## The why
If your reward function is differentiable, directly backpropagating gradients from the reward model to the diffusion model is significantly more sample- and compute-efficient (25x) than using a policy gradient algorithm like DDPO.
AlignProp does full backpropagation through time, which allows updating the earlier steps of denoising via reward backpropagation.
<div style="text-align: center"><img src="https://align-prop.github.io/reward_tuning.png"/></div>
## Getting started with `examples/scripts/alignprop.py`
The `alignprop.py` script is a working example of using the `AlignProp` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`AlignPropConfig`).
**Note:** one A100 GPU is recommended to get this running. For lower-memory settings, consider setting `truncated_backprop_rand` to `False`. With default settings this will do truncated backpropagation with K=1.
Almost every configuration parameter has a default. Only one command-line flag is required to get things up and running: a [Hugging Face user access token](https://huggingface.co/docs/hub/security-tokens), which will be used to upload the fine-tuned model to the Hugging Face Hub. Run the following to get started:
```bash
python alignprop.py --hf_user_access_token <token>
```
To see the full documentation of `alignprop.py`, run `python alignprop.py --help`.
Beyond the example script, keep the following in mind while configuring the trainer (the code checks this for you as well); see the config sketch after this list:
- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`): the first number must be greater than or equal to 0, while the second must be less than or equal to the number of diffusion timesteps (`sample_num_steps`).
- The configurable truncation backprop absolute step (`--alignprop_config.truncated_backprop_timestep=49`): the number must be less than the number of diffusion timesteps (`sample_num_steps`); it only matters when `truncated_backprop_rand` is set to `False`.
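As a quick illustration, the constraints above could be expressed through the config like this. This is a sketch assuming `AlignPropConfig` exposes these fields under the names used by the flags above; check the `AlignPropConfig` reference for the exact API.
```python
from trl import AlignPropConfig

# Illustrative values chosen to satisfy the constraints described above.
config = AlignPropConfig(
    sample_num_steps=50,                     # number of diffusion timesteps
    truncated_backprop_rand=True,            # randomize the truncation point
    truncated_rand_backprop_minmax=(0, 50),  # must lie within [0, sample_num_steps]
)
```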
## Setting up the image logging hook function
Expect the function to be given a dictionary with keys
```python
['image', 'prompt', 'prompt_metadata', 'rewards']
```
where `image`, `prompt`, `prompt_metadata`, and `rewards` are batched.
You are free to log however you want; the use of `wandb` or `tensorboard` is recommended.
### Key terms
- `rewards` : The reward/score is a numerical value associated with the generated image and is key to steering the RL process
- `prompt` : The prompt is the text that is used to generate the image
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup, where questions and ground-truth answers (linked to the generated image) are expected with the generated image (see here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
- `image` : The image generated by the Stable Diffusion model
Example code for logging sampled images with `wandb` is given below.
```python
# for logging these images to wandb
from PIL import Image
import numpy as np

def image_outputs_hook(image_data, global_step, accelerate_logger):
    # For the sake of this example, we only care about the last batch,
    # hence we extract the last element of the list
    result = {}
    images, prompts, rewards = image_data["images"], image_data["prompts"], image_data["rewards"]
    for i, image in enumerate(images):
        pil = Image.fromarray(
            (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
        )
        pil = pil.resize((256, 256))
        result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
    accelerate_logger.log_images(
        result,
        step=global_step,
    )
```
### Using the finetuned model
Assuming you're done with all the epochs and have pushed your model up to the hub, you can use the fine-tuned model as follows:
```python
from diffusers import StableDiffusionPipeline
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipeline.to("cuda")
pipeline.load_lora_weights('mihirpd/alignprop-trl-aesthetics')
prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
results = pipeline(prompts)
for prompt, image in zip(prompts,results.images):
image.save(f"dump/{prompt}.png")
```
## Credits
This work is heavily influenced by the repo [here](https://github.com/mihirp1998/AlignProp/) and the associated paper [Aligning Text-to-Image Diffusion Models with Reward Backpropagation
by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki](https://huggingface.co/papers/2310.03739).

docs/source/bco_trainer.mdx

@ -0,0 +1,139 @@
# BCO Trainer
TRL supports Binary Classifier Optimization (BCO).
The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0.
For a full example have a look at `examples/scripts/bco.py`.
## Expected dataset format
The BCO trainer expects a very specific format for the dataset as it does not require pairwise preferences. Since the model will be trained to directly optimize examples that consist of a prompt, model completion, and a label to indicate whether the completion is "good" or "bad", we expect a dataset with the following columns:
- `prompt`
- `completion`
- `label`
for example:
```
bco_dataset_dict = {
"prompt": [
"Hey, hello",
"How are you",
"What is your name?",
"What is your name?",
"Which is the best programming language?",
"Which is the best programming language?",
"Which is the best programming language?",
],
"completion": [
"hi nice to meet you",
"leave me alone",
"I don't have a name",
"My name is Mary",
"Python",
"C++",
"Java",
],
"label": [
True,
False,
False,
True,
True,
False,
False,
],
}
```
where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired (`True`) or undesired (`False`).
A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays. It is required that the dataset contains at least one desirable and one undesirable completion.
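If you want to materialize such a dictionary quickly, it can be turned into a 🤗 `Dataset` directly (a minimal sketch using the `datasets` library, reusing the `bco_dataset_dict` defined above):
```python
from datasets import Dataset

# Columns: prompt, completion, label — matching the format described above.
train_dataset = Dataset.from_dict(bco_dataset_dict)
```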
## Expected model format
The BCO trainer expects a model of type `AutoModelForCausalLM`, compared to PPO, which expects `AutoModelForCausalLMWithValueHead` for the value function.
## Using the `BCOTrainer`
For a detailed example have a look at the `examples/scripts/bco.py` script. At a high level we need to initialize the `BCOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.
The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
```py
training_args = BCOConfig(
beta=0.1,
)
bco_trainer = BCOTrainer(
model,
model_ref,
args=training_args,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
```
After this one can then call:
```py
bco_trainer.train()
```
## Underlying Distribution matching (UDM)
In practical scenarios, the thumbs-up and thumbs-down datasets are likely to have divergent underlying distributions of prompts.
Consider an LLM deployed for user feedback: if the model excels in writing tasks but underperforms in coding, the thumbs-up dataset will be dominated by writing-related prompts, while the thumbs-down dataset will contain mostly coding-related prompts.
If the prompts in your desired and undesired datasets differ a lot, it is useful to enable UDM.
Choose an embedding model and tokenizer:
```py
from functools import partial
from accelerate import Accelerator
from transformers import AutoModel, AutoTokenizer

embedding_model = AutoModel.from_pretrained(your_model_id)
embedding_tokenizer = AutoTokenizer.from_pretrained(your_model_id)

# customize this function depending on your embedding model
def embed_prompt(input_ids, attention_mask, model):
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    return outputs.last_hidden_state.mean(dim=1)

embedding_model = Accelerator().prepare_model(embedding_model)
embedding_func = partial(embed_prompt, model=embedding_model)
```
Set `prompt_sample_size` to define how many prompts are selected to train the UDM classifier and start the training with the provided embedding function:
```py
training_args = BCOConfig(
beta=0.1,
prompt_sample_size=512,
)
bco_trainer = BCOTrainer(
model,
model_ref,
args=training_args,
train_dataset=train_dataset,
tokenizer=tokenizer,
embedding_func=embedding_func,
embedding_tokenizer=embedding_tokenizer,
)
bco_trainer.train()
```
### For Mixture of Experts Models: Enabling the auxiliary loss
MoEs are the most efficient if the load is about equally distributed between experts.
To ensure that we train MoEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g. `MixtralConfig`).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
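For a Mixtral model, for example, this could look like the following sketch; `output_router_logits` and `router_aux_loss_coef` are standard `transformers` config fields for MoE architectures, passed through `from_pretrained` here for illustration:
```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    output_router_logits=True,   # add the load-balancing auxiliary loss to the output
    router_aux_loss_coef=0.001,  # weight of the auxiliary loss in the total loss
)
```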
## BCOTrainer
[[autodoc]] BCOTrainer
## BCOConfig
[[autodoc]] BCOConfig

docs/source/callbacks.mdx

@ -0,0 +1,13 @@
# Callbacks
## SyncRefModelCallback
[[autodoc]] SyncRefModelCallback
## RichProgressCallback
[[autodoc]] RichProgressCallback
## WinRateCallback
[[autodoc]] WinRateCallback

docs/source/clis.mdx

@ -0,0 +1,119 @@
# Command Line Interfaces (CLIs)
You can use TRL to fine-tune your language model with Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO), or even chat with your model using the TRL CLIs.
Currently supported CLIs are:
- `trl sft`: fine-tune an LLM on a text/instruction dataset
- `trl dpo`: fine-tune an LLM with DPO on a preference dataset
- `trl chat`: quickly spin up an LLM fine-tuned for chatting
## Fine-tuning with the CLI
Before getting started, pick a language model from the Hugging Face Hub. Supported models can be found with the "text-generation" filter within models. Also make sure to pick a relevant dataset for your task.
Before using the `sft` or `dpo` commands make sure to run:
```bash
accelerate config
```
and pick the right configuration for your training setup (single / multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of `accelerate config` before running any CLI command.
We also recommend passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with the `trl sft` command.
```yaml
model_name_or_path:
trl-internal-testing/tiny-random-LlamaForCausalLM
dataset_name:
imdb
dataset_text_field:
text
report_to:
none
learning_rate:
0.0001
lr_scheduler_type:
cosine
```
Save that config in a `.yaml` file and get started immediately! An example CLI config is available as `examples/cli_configs/example_config.yaml`. Note you can overwrite the arguments from the config file by explicitly passing them to the CLI, e.g. from the root folder:
```bash
trl sft --config examples/cli_configs/example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts
```
This will force-use `cosine_with_restarts` for `lr_scheduler_type`.
### Supported Arguments
We support all arguments from `transformers.TrainingArguments`. For loading your model, we also support all arguments from `~trl.ModelConfig`:
[[autodoc]] ModelConfig
You can pass any of these arguments either to the CLI or the YAML file.
### Supervised Fine-tuning (SFT)
Follow the basic instructions above and run `trl sft --output_dir <output_dir> <*args>`:
```bash
trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir opt-sft-imdb
```
The SFT CLI is based on the `examples/scripts/sft.py` script.
### Direct Preference Optimization (DPO)
To use the DPO CLI, you need to have a dataset in the TRL format such as
* TRL's Anthropic HH dataset: https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-helpful-base-trl-style
* TRL's OpenAI TL;DR summarization dataset: https://huggingface.co/datasets/trl-internal-testing/tldr-preference-trl-style
These datasets always have at least three columns `prompt, chosen, rejected`:
* `prompt` is a list of strings.
* `chosen` is the chosen response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating).
* `rejected` is the rejected response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating); a single example in this style is sketched below.
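A single training example in this style might look like the following (an illustrative sketch, not taken from either dataset):
```python
# One illustrative preference example in the "TRL style": `chosen` and `rejected`
# are lists of chat messages, while `prompt` is plain text.
example = {
    "prompt": "What is the capital of France?",
    "chosen": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
    ],
    "rejected": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "It might be Lyon, but I am not sure."},
    ],
}
```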
To do a quick start, you can run the following command:
```bash
trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style
```
The DPO CLI is based on the `examples/scripts/dpo.py` script.
#### Custom preference dataset
Format the dataset into TRL format (you can adapt the `examples/datasets/anthropic_hh.py`):
```bash
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
```
## Chat interface
The chat CLI lets you quickly load the model and talk to it. Simply run the following:
```bash
trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
```
> [!TIP]
> To use the chat CLI with the developer installation, you must run `make dev`
>
Note that the chat interface relies on the tokenizer's [chat template](https://huggingface.co/docs/transformers/chat_templating) to format the inputs for the model. Make sure your tokenizer has a chat template defined.
Besides talking to the model there are a few commands you can use:
- **clear**: clears the current conversation and starts a new one
- **example {NAME}**: loads the example named `{NAME}` from the config and uses it as the user input
- **set {SETTING_NAME}={SETTING_VALUE};**: changes the system prompt or generation settings (multiple settings are separated by a `;`); see the example below
- **reset**: same as **clear**, but also resets the generation configs to their defaults if they have been changed by **set**
- **save {SAVE_NAME} (optional)**: saves the current chat and settings to file, by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml`, or to `{SAVE_NAME}` if provided
- **exit**: closes the interface
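For example, `set temperature=0.7; max_new_tokens=256;` would lower the sampling temperature and cap the generation length (these two setting names mirror the usual generation config and are illustrative).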
The default examples are defined in `examples/scripts/config/default_chat_config.yaml` but you can pass your own with `--config CONFIG_FILE` where you can also specify the default generation parameters.

---

docs/source/cpo_trainer.mdx (new file)
# CPO Trainer
Contrastive Preference Optimization (CPO) was introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by Haoran Xu, Amr Sharaf, Yunmo Chen, Weiting Tan, Lingfeng Shen, Benjamin Van Durme, Kenton Murray, and Young Jin Kim. At a high level, CPO trains models to avoid generating adequate, but not perfect, translations in Machine Translation (MT) tasks. However, CPO is a general approximation of the DPO loss and can be applied to other domains like chat.
CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT's methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Second, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.
## SimPO
The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the `CPOTrainer`. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, simply set `loss_type="simpo"` and `cpo_alpha=0` in the `CPOConfig`.
## CPO-SimPO
We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more details at [CPO-SimPO Github](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the CPOConfig.
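In config terms, the two variants differ only in `cpo_alpha` (a minimal sketch; the `output_dir` values and the non-zero `cpo_alpha=0.5` are illustrative):

```python
from trl import CPOConfig

# SimPO: reward margin + length normalization, no BC regularization
simpo_args = CPOConfig(output_dir="cpo-simpo", loss_type="simpo", cpo_alpha=0.0)

# CPO-SimPO: keep a non-zero cpo_alpha to retain the behavior-cloning term
cpo_simpo_args = CPOConfig(output_dir="cpo-simpo-mix", loss_type="simpo", cpo_alpha=0.5)
```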
## Expected dataset format
The CPO trainer expects a format identical to the DPO trainer, which should include three entries. These entries should be named as follows:
- `prompt`
- `chosen`
- `rejected`
for example:
```py
cpo_dataset_dict = {
    "prompt": [
        "hello",
        "how are you",
        "What is your name?",
        "What is your name?",
        "Which is the best programming language?",
        "Which is the best programming language?",
        "Which is the best programming language?",
    ],
    "chosen": [
        "hi nice to meet you",
        "I am fine",
        "My name is Mary",
        "My name is Mary",
        "Python",
        "Python",
        "Java",
    ],
    "rejected": [
        "leave me alone",
        "I am not fine",
        "Whats it to you?",
        "I dont have a name",
        "Javascript",
        "C++",
        "C++",
    ],
}
```
where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses, and `rejected` contains the corresponding negative (rejected) responses. As can be seen, a prompt can have multiple responses, and this is reflected in the repeated entries in the dictionary's value arrays.
## Expected model format
The CPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
## Using the `CPOTrainer`
For a detailed example have a look at the `examples/scripts/cpo.py` script. At a high level we need to initialize the `CPOTrainer` with a `model` we wish to train. **Note that CPOTrainer eliminates the need to use the reference model, simplifying the optimization process.** The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above.
```py
cpo_config = CPOConfig(
    beta=0.1,
)

cpo_trainer = CPOTrainer(
    model,
    args=cpo_config,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
```
After this one can then call:
```py
cpo_trainer.train()
```
## Loss functions
Given the preference data, the `CPOTrainer` uses the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression.
The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. The `CPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument and the `beta` in this case is the reciprocal of the margin.
The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the CPO algorithms, identify an issue with overfitting, and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer. Note that the `beta` parameter is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair, and thus the smaller the `beta`, the larger this gap is. As per the paper, the loss is averaged over the log-likelihoods of the completion (unlike CPO, where it is only summed).
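Switching between these losses is a one-line change in the config (a sketch; the `output_dir` values are illustrative):

```python
from trl import CPOConfig

# hinge loss from SLiC: beta is the reciprocal of the margin
hinge_args = CPOConfig(output_dir="cpo-hinge", loss_type="hinge", beta=0.1)

# IPO loss: averaged over the completion log-likelihoods
ipo_args = CPOConfig(output_dir="cpo-ipo", loss_type="ipo", beta=0.1)
```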
### For Mixture of Experts Models: Enabling the auxiliary loss
MOEs are the most efficient if the load is about equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
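For instance, with a Mixtral model this looks roughly as follows (a sketch; the keyword overrides are forwarded to the underlying `MixtralConfig`):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    output_router_logits=True,    # add the load-balancing auxiliary loss to the total loss
    router_aux_loss_coef=0.001,   # weight of the auxiliary loss (the default)
)
```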
## Logging
While training and evaluating we record the following reward metrics:
* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
* `rewards/accuracies`: mean of how often the chosen rewards are greater than the corresponding rejected rewards
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
* `nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses
## CPOTrainer
[[autodoc]] CPOTrainer
## CPOConfig
[[autodoc]] CPOConfig

---
```python
import torch
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead

# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# 2. define config
...
optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate)

# 3. initialize trainer
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer)
```
For memory-efficient fine-tuning, you can also pass the `Adam8bit` optimizer from `bitsandbytes`:
```python
import bitsandbytes as bnb
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead

# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# 2. define config
...
config = PPOConfig(**ppo_config)
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=config.learning_rate)

# 3. initialize trainer
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer)
```
### Use LION optimizer
You can also use the new [LION optimizer from Google](https://huggingface.co/papers/2302.06675). First, take the source code of the optimizer definition [here](https://github.com/lucidrains/lion-pytorch/blob/main/lion_pytorch/lion_pytorch.py) and copy it so that you can import the optimizer. Make sure to initialize the optimizer by considering the trainable parameters only, for more memory-efficient training:
```python
optimizer = Lion(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)
...
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer)
```
We advise you to use the learning rate that you would use for `Adam` divided by 3 as pointed out [here](https://github.com/lucidrains/lion-pytorch#lion---pytorch). We observed an improvement when using this optimizer compared to classic Adam (check the full logs [here](https://wandb.ai/distill-bloom/trl/runs/lj4bheke?workspace=user-younesbelkada)):
```python
import torch
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead

# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# 2. define config
...
optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

# 3. initialize trainer
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer, lr_scheduler=lr_scheduler)
```
## Memory efficient fine-tuning by sharing layers
```python
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model

# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m')
ref_model = create_reference_model(model, num_shared_layers=6)
tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')

# 2. initialize trainer
ppo_config = {'batch_size': 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)
```
## Pass 8-bit reference models
```python
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead

# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m')
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m', device_map="auto", load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')

# 2. initialize trainer
ppo_config = {'batch_size': 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)
```
## Use the CUDA cache optimizer
```python
config = PPOConfig(..., optimize_cuda_cache=True)
```
## Use score scaling/normalization/clipping
As suggested by [Secrets of RLHF in Large Language Models Part I: PPO](https://huggingface.co/papers/2307.04964), we support score (aka reward) scaling/normalization/clipping to improve training stability via `PPOConfig`:
```python
from trl import PPOConfig

# enable reward (score) scaling and normalization, and clip scores at 0.5
ppo_config = {
    "use_score_scaling": True,
    "use_score_norm": True,
    "score_clip": 0.5,
}
config = PPOConfig(**ppo_config)
```

---
## Credits
This work is heavily influenced by the repo [here](https://github.com/kvablack/ddpo-pytorch) and the associated paper [Training Diffusion Models with Reinforcement Learning by Kevin Black, Michael Janner, Yilan Du, Ilya Kostrikov, Sergey Levine](https://huggingface.co/papers/2305.13301).

---

We report the toxicity score of 400 sampled examples, compute its mean and standard deviation:

| Model | Mean toxicity score | Std toxicity score |
| --- | --- | --- |
| `EleutherAI/gpt-neo-125m` | 0.1627 | 0.2997 |
| `ybelkada/gpt-neo-125m-detox` | **0.1148** | **0.2506** |
| --- | --- | --- |
| `EleutherAI/gpt-neo-2.7B` | 0.1884 | 0.3178 |
| `ybelkada/gpt-neo-2.7B-detox` | **0.0916** | **0.2104** |
| --- | --- | --- |
| `EleutherAI/gpt-j-6B` | 0.1699 | 0.3033 |

---
# DPO Trainer
TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://huggingface.co/papers/2305.18290) by Rafailov et al., 2023. For a full example have a look at [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py).
The first step, as always, is to train your SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.
## How DPO works
Fine-tuning a language model via DPO consists of two steps and is easier than PPO:
1. **Data collection**: Gather a preference dataset with positive and negative pairs of generations, given a prompt.
2. **Optimization**: Maximize the log-likelihood of the DPO loss directly.
DPO-compatible datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo). You can also explore the [librarian-bots/direct-preference-optimization-datasets](https://huggingface.co/collections/librarian-bots/direct-preference-optimization-datasets-66964b12835f46289b6ef2fc) Collection to identify datasets that are likely to support DPO training.
This process is illustrated in the sketch below (from [figure 1 of the original paper](https://huggingface.co/papers/2305.18290)):
<img width="835" alt="Screenshot 2024-03-19 at 12 39 41" src="https://github.com/huggingface/trl/assets/49240599/9150fac6-3d88-4ca2-8ec6-2a6f3473216d">
Read more about DPO algorithm in the [original paper](https://huggingface.co/papers/2305.18290).
## Expected dataset format
The DPO trainer expects a very specific format for the dataset, since the model will be trained to directly optimize the preference for which sentence is the most relevant, given two sentences. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/rlhf-antropic-example.png" width="50%">
</div>
Therefore the final dataset object should contain these 3 entries if you use the default [`DPODataCollatorWithPadding`] data collator. The entries should be named:
- `prompt`
- `chosen`
- `rejected`
Here, the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses, and `rejected` contains the corresponding negative (rejected) responses. A prompt can have multiple responses, and this is reflected in the repeated entries in the dictionary's value arrays.
[`DPOTrainer`] can be used to fine-tune visual language models (VLMs). In this case, the dataset must also contain the key `images`, and the trainer's `tokenizer` is the VLM's `processor`. For example, for Idefics2, the processor expects the dataset to have the following format:
Note: Currently, VLM support is exclusive to Idefics2 and does not extend to other VLMs.
```py
dpo_dataset_dict = {
    'images': [
        [Image.open('beach.jpg')],
        [Image.open('street.jpg')],
    ],
    'prompt': [
        'The image <image> shows',
        '<image> The image depicts',
    ],
    'chosen': [
        'a sunny beach with palm trees.',
        'a busy street with several cars and buildings.',
    ],
    'rejected': [
        'a snowy mountain with skiers.',
        'a calm countryside with green fields.',
    ],
}
```
## Expected model format
The DPO trainer expects a model of `AutoModelForCausalLM` or `AutoModelForVision2Seq`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
## Using the `DPOTrainer`
For a detailed example have a look at the `examples/scripts/dpo.py` script. At a high level we need to initialize the [`DPOTrainer`] with a `model` we wish to train, a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response, the `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
```py
training_args = DPOConfig(
    beta=0.1,
)

dpo_trainer = DPOTrainer(
    model,
    ref_model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,  # for visual language models, use tokenizer=processor instead
)
```
After this one can then call:
```py
dpo_trainer.train()
```

Note that the `beta` is the temperature parameter for the DPO loss, typically something in the range of `0.1` to `0.5`.
## Loss functions
Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. To use this loss, set the `loss_type="sigmoid"` (default) in the [`DPOConfig`].
The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. To use this loss, set the `loss_type="hinge"` in the [`DPOConfig`]. In this case, the `beta` is the reciprocal of the margin.
The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms, identify an issue with overfitting, and propose an alternative loss. To use the loss, set the `loss_type="ipo"` in the [`DPOConfig`]. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair, and thus the smaller the `beta`, the larger this gap is. As per the paper, the loss is averaged over the log-likelihoods of the completion (unlike DPO, where it is only summed).
The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0).
The [EXO](https://huggingface.co/papers/2402.00856) authors propose to minimize the reverse KL instead of the negative log-sigmoid loss of DPO which corresponds to forward KL. To use the loss set the `loss_type="exo_pair"` in the [`DPOConfig`]. Setting non-zero `label_smoothing` (default `1e-3`) leads to a simplified version of EXO on pair-wise preferences (see Eqn. (16) of the [EXO paper](https://huggingface.co/papers/2402.00856)). The full version of EXO uses `K>2` completions generated by the SFT policy, which becomes an unbiased estimator of the PPO objective (up to a constant) when `K` is sufficiently large.
The [NCA](https://huggingface.co/papers/2402.05369) authors show that NCA optimizes the absolute likelihood for each response rather than the relative likelihood. To use the loss, set `loss_type="nca_pair"` in the [`DPOConfig`].
The [Robust DPO](https://huggingface.co/papers/2403.00409) authors propose an unbiased estimate of the DPO loss that is robust to preference noise in the data. Like in cDPO, it assumes that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0) and set the `loss_type="robust"` in the [`DPOConfig`].
The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. To use this loss, set the `loss_type="bco_pair"` in the [`DPOConfig`].
The [TR-DPO](https://huggingface.co/papers/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD with weight `ref_model_mixup_alpha` during DPO training. To toggle this callback use the `sync_ref_model=True` in the [`DPOConfig`].
The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterative preference tuning algorithm using a loss related to the RPO loss in this [paper](https://huggingface.co/papers/2405.16436) that essentially consists of a weighted SFT loss on the chosen preferences together with the DPO loss. To use this loss, set the `rpo_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this weight to 1.0.
The [SPPO](https://huggingface.co/papers/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2 and can alleviate data sparsity issues. The implementation approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser. To use this loss, set the `loss_type="sppo_hard"` in the [`DPOConfig`].
The [AOT](https://huggingface.co/papers/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, the alignment algorithms use paired preferences at a sample level, which does not ensure alignment on the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order on the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size.
The [APO](https://huggingface.co/papers/2408.06266) method introduces an "anchored" version of the alignment objective. There are two variants: `apo_zero` and `apo_down`. The `apo_zero` loss increases the likelihood of winning outputs while decreasing the likelihood of losing outputs, making it suitable when the model is less performant than the winning outputs. On the other hand, `apo_down` decreases the likelihood of both winning and losing outputs, but with a stronger emphasis on reducing the likelihood of losing outputs. This variant is more effective when the model is better than the winning outputs. To use these losses, set `loss_type="apo_zero"` or `loss_type="apo_down"` in the [`DPOConfig`].
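All of these variants are selected the same way through the [`DPOConfig`] (a minimal sketch; the `output_dir` and the assumed 20% label-noise rate are illustrative):

```python
from trl import DPOConfig

# e.g. the Robust DPO loss with an assumed 20% label-noise rate
training_args = DPOConfig(
    output_dir="dpo-robust",
    loss_type="robust",
    label_smoothing=0.2,
    beta=0.1,
)
```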
### For Mixture of Experts Models: Enabling the auxiliary loss
MOEs are the most efficient if the load is about equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
## Logging
While training and evaluating we record the following reward metrics:
- `rewards/chosen`: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by beta
- `rewards/rejected`: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by beta
- `rewards/accuracies`: mean of how often the chosen rewards are greater than the corresponding rejected rewards
- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
## Accelerate DPO fine-tuning using `unsloth`
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `DPOTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek, etc.) and Mistral architectures. Some benchmarks for DPO are listed below:
| GPU | Model | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
| -------- | --------- | ---------- | --- | ---------------------- | ---------- | ------------- |
| A100 40G | Zephyr 7b | Ultra Chat | 1x | 1.24x | **1.88x** | -11.6% |
| Tesla T4 | Zephyr 7b | Ultra Chat | 1x | 1.09x | **1.55x** | -18.6% |
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
```python
import torch
from trl import DPOConfig, DPOTrainer
from unsloth import FastLanguageModel

max_seq_length = 2048  # Supports automatic RoPE Scaling, so choose any number.

# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/zephyr-sft",
    max_seq_length = max_seq_length,
    dtype = None,  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
    load_in_4bit = True,  # Use 4bit quantization to reduce memory usage. Can be False.
    # token = "hf_...",  # use one if using gated models like meta-llama/Llama-2-7b-hf
)

# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0,  # Dropout = 0 is currently optimized
    bias = "none",  # Bias = "none" is currently optimized
    use_gradient_checkpointing = True,
    random_state = 3407,
    max_seq_length = max_seq_length,
)

training_args = DPOConfig(
    output_dir="./output",
    beta=0.1,
)

dpo_trainer = DPOTrainer(
    model,
    ref_model=None,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
dpo_trainer.train()
```
The saved model is fully compatible with Hugging Face's transformers library.
You have three main options (plus several variants) for how the reference model works when using PEFT, assuming the model that you would like to further enhance with DPO was tuned using (Q)LoRA.
1. Simply create two instances of the model, each loading your adapter - works fine but is very inefficient.
2. Merge the adapter into the base model, create another adapter on top, then leave the `ref_model` param null, in which case DPOTrainer will unload the adapter for reference inference - efficient, but has potential downsides discussed below.
3. Load the adapter twice with different names, then use `set_adapter` during training to swap between the adapter being DPO'd and the reference adapter - slightly less efficient compared to 2 (~adapter size VRAM overhead), but avoids the pitfalls.
### Downsides to merging QLoRA before DPO (approach 2)
As suggested by [Benjamin Marie](https://medium.com/@bnjmn_marie/dont-merge-your-lora-adapter-into-a-4-bit-llm-65b6da287997), the best option for merging QLoRA adapters is to first dequantize the base model, then merge the adapter. Something similar to [this script](https://github.com/jondurbin/qlora/blob/main/qmerge.py).
You can also just merge the adapters the standard way without quantizing the base model, but then you have 1-2% reduced performance (and evidently, more issues with empty responses).
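For reference, the standard merge looks roughly like this (a sketch; the model id and paths are placeholders):

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("base-model-id")       # unquantized base model
model = PeftModel.from_pretrained(base, "/path/to/qlora-adapter")  # attach the adapter
merged = model.merge_and_unload()                                  # fold the LoRA weights into the base
merged.save_pretrained("/path/to/merged-model")
```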
However, after using this approach, you will have an unquantized base model. Therefore, to use QLoRA for DPO, you will need to re-quantize the merged model or use the unquantized merge (resulting in higher memory demand).
### Using option 3 - load the adapter twice
To avoid the downsides with option 2, you can load your fine-tuned adapter into the model twice, with different names, and set the model/ref adapter names in [`DPOTrainer`].
For example:
```python
# Load the base model.
bnb_config = BitsAndBytesConfig(
    ...
)
...
model = PeftModel.from_pretrained(
    ...
)
model.load_adapter("/path/to/peft", adapter_name="reference")

# Initialize the trainer, without a ref_model param.
training_args = DPOConfig(
    model_adapter_name="train",
    ref_adapter_name="reference",
)
dpo_trainer = DPOTrainer(
    model,
    args=training_args,
    ...
)
```
## DPOTrainer
[[autodoc]] DPOTrainer
## DPOConfig
[[autodoc]] DPOConfig

---

Then, it is encouraged to launch jobs with `accelerate launch`!
# Maintained Examples
| File | Description |
| ----------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`examples/scripts/alignprop.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/alignprop.py) | This script shows how to use the [`AlignPropTrainer`] to fine-tune a diffusion model. |
| [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. |
| [`examples/scripts/chat.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/chat.py) | This script allows you to load and use a model as a chatbot. |
| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/ddpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ddpo.py) | This script shows how to use the [`DDPOTrainer`] to fine-tune a stable diffusion model using reinforcement learning. |
| [`examples/scripts/dpo_visual.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_visual.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. |
| [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`KTOTrainer`] to fine-tune a model. |
| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/ppo_multi_adapter.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo_multi_adapter.py) | This script shows how to use the [`PPOTrainer`] to train a single base model with multiple adapters. Requires you to run the example script with the reward model training beforehand. |
| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a sentiment analysis model using [IMDB dataset](https://huggingface.co/datasets/stanfordnlp/imdb). |
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train a reward model on your own dataset. |
| [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a model or adapters into a target dataset. |
| [`examples/scripts/vsft_llava.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/vsft_llava.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested on a [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf) model so users may see unexpected behaviour in other model architectures. |
Here are also some easier-to-run colab notebooks that you can use to get started with TRL:
| File | Description |
| --------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------- |
| [`examples/notebooks/best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb) | This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. |
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. |
We also have some other examples that are less maintained but can be used as a reference:

---

However, the RL model being optimized against the reward model may learn patterns that exploit the reward model:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/kl-example.png">
<p style="text-align: center;"> <b>Figure:</b> Samples without a KL penalty from <a href="https://huggingface.co/papers/1909.08593">https://huggingface.co/papers/1909.08593</a>. </p>
</div>
To address this issue, we add a penalty to the reward function based on the KL divergence between the current model and the reference model. By doing this, we encourage the model to stay close to what the reference model generates.
If you generate text by purely sampling from the model distribution things work fine in general, but there are a few generation settings that can drive the KL-divergence negative:
- **top-k sampling**: the model can smooth out the probability distribution causing the top-k tokens having a smaller probability than those of the reference model but they still are selected
- **min_length**: this ignores the EOS token until `min_length` is reached, thus the model can assign a very low log prob to the EOS token and very high probs to all others until min_length is reached
These are just a few examples. Why is negative KL an issue? The total reward `R` is computed as `R = r - beta * KL`, so if the model can learn how to drive the KL-divergence negative, it effectively gets a positive reward. In many cases it can be much easier to exploit such a bug in the generation than to actually learn the reward function. In addition, the KL can become arbitrarily small, so the actual reward can be very small compared to it.
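For instance, with `beta = 0.2`, a raw reward `r = 0.1`, and `KL = -0.5`, the total reward is `R = 0.1 - 0.2 * (-0.5) = 0.2`, twice the actual reward, even though the generations got no better.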
Debugging the RL pipeline can be challenging due to its complexity. Here are some tips and suggestions:
- **Start from a working example**: Begin with a working example from the trl repository and gradually modify it to fit your specific use-case. Changing everything at once can make it difficult to identify the source of potential issues. For example, you can start by replacing the model in the example and once you figure out the best hyperparameters try to switch to your dataset and reward model. If you change everything at once you won't know where a potential problem comes from.
- **Start small, scale later**: Training large models can be very slow and take several hours or days until you see any improvement. For debugging this is not a convenient timescale so try to use small model variants during the development phase and scale up once that works. That being said you sometimes have to be careful as small models might not have the capacity to solve a complicated task either.
- **Start simple**: Try to start with a minimal example and build complexity from there. Your use-case might require for example a complicated reward function consisting of many different rewards - try to use one signal first and see if you can optimize that and then add more complexity after that.
- **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a bug in your post-processing or your prompt. Due to bad settings you might cut-off generations too soon. These things are very hard to see on the metrics but very obvious if you look at the generations.
- **Inspect the reward model**: If your reward is not improving over time, maybe there's an issue with the reward model. You can look at extreme cases to see if it does what it should: e.g. in the sentiment case you can check if simple positive and negative examples really get different rewards. And you can look at the distribution of your dataset. Finally, maybe the reward is dominated by the query, which the model can't affect, so you might need to normalize it (e.g. reward of query+response minus reward of the query).
These are just a few tips that we find helpful - if you have more useful tricks feel free to open a PR to add them as well!

---

Check the appropriate sections of the documentation depending on your needs:
- [`PPOTrainer`](ppo_trainer): *Further fine-tune the supervised fine-tuned model using PPO algorithm*
- [Best-of-N Sampling](best-of-n): *Use best of n sampling as an alternative way to sample predictions from your active model*
- [`DPOTrainer`](dpo_trainer): *Direct Preference Optimization training using `DPOTrainer`.*
- [`TextEnvironment`](text_environments): *Text environment to train your model using tools with RL.*
## Examples
<div class="mt-10">
<div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5">
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo_vlm">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/dpo_vlm/thumbnail.png" alt="thumbnail">
<p class="text-gray-700">Preference Optimization for Vision Language Models with TRL</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/rlhf">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/120_rlhf/thumbnail.png" alt="thumbnail">
<p class="text-gray-700">Illustrating Reinforcement Learning from Human Feedback</p>

---

docs/source/judges.mdx (new file)
# Judges
TRL provides judges to easily compare two completions.
Make sure to have installed the required dependencies by running:
```bash
pip install trl[llm_judge]
```
## Using the provided judges
TRL provides several judges out of the box. For example, you can use the `HfPairwiseJudge` to compare two completions using a pre-trained model from the Hugging Face model hub:
```python
from trl import HfPairwiseJudge
judge = HfPairwiseJudge()
judge.judge(
    prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"],
    completions=[["Paris", "Lyon"], ["Saturn", "Jupiter"]],
)  # Outputs: [0, 1]
```
## Define your own judge
To define your own judge, we provide several base classes that you can subclass. For rank-based judges, you need to subclass [`BaseRankJudge`] and implement the [`BaseRankJudge.judge`] method. For pairwise judges, you need to subclass [`BasePairwiseJudge`] and implement the [`BasePairwiseJudge.judge`] method. If you want to define a judge that doesn't fit into these categories, you need to subclass [`BaseJudge`] and implement the [`BaseJudge.judge`] method.
As an example, let's define a pairwise judge that prefers shorter completions:
```python
from trl import BasePairwiseJudge
class PrefersShorterJudge(BasePairwiseJudge):
    def judge(self, prompts, completions, shuffle_order=False):
        # return the index of the preferred (here: shorter) completion for each pair
        return [0 if len(completion[0]) <= len(completion[1]) else 1 for completion in completions]
```
You can then use this judge as follows:
```python
judge = PrefersShorterJudge()
judge.judge(
    prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"],
    completions=[["Paris", "The capital of France is Paris."], ["Jupiter is the biggest planet in the solar system.", "Jupiter"]],
)  # Outputs: [0, 1]
```
## BaseJudge
[[autodoc]] BaseJudge
## BaseRankJudge
[[autodoc]] BaseRankJudge
## BasePairwiseJudge
[[autodoc]] BasePairwiseJudge
## RandomRankJudge
[[autodoc]] RandomRankJudge
## RandomPairwiseJudge
[[autodoc]] RandomPairwiseJudge
## PairRMJudge
[[autodoc]] PairRMJudge
## HfPairwiseJudge
[[autodoc]] HfPairwiseJudge
## OpenAIPairwiseJudge
[[autodoc]] OpenAIPairwiseJudge

---

docs/source/kto_trainer.mdx (new file)
# KTO Trainer
TRL supports the Kahneman-Tversky Optimization (KTO) Trainer for aligning language models with binary feedback data (e.g., upvote/downvote), as described in the [paper](https://huggingface.co/papers/2402.01306) by Kawin Ethayarajh, Winnie Xu, Niklas Muennighoff, Dan Jurafsky, and Douwe Kiela.
For a full example have a look at [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py).
Depending on how good your base model is, you may or may not need to do SFT before KTO.
This is different from standard RLHF and DPO, which always require SFT.
## Expected dataset format
The KTO trainer expects a very specific format for the dataset as it does not require pairwise preferences. Since the model will be trained to directly optimize examples that consist of a prompt, model completion, and a label to indicate whether the completion is "good" or "bad", we expect a dataset with the following columns:
- `prompt`
- `completion`
- `label`
for example:
```py
kto_dataset_dict = {
    "prompt": [
        "Hey, hello",
        "How are you",
        "What is your name?",
        "What is your name?",
        "Which is the best programming language?",
        "Which is the best programming language?",
        "Which is the best programming language?",
    ],
    "completion": [
        "hi nice to meet you",
        "leave me alone",
        "I don't have a name",
        "My name is Mary",
        "Python",
        "C++",
        "Java",
    ],
    "label": [
        True,
        False,
        False,
        True,
        True,
        False,
        False,
    ],
}
```
where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired (`True`) or undesired (`False`).
A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays. It is required that the dataset contains at least one desirable and one undesirable completion.
## Expected model format
The KTO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
## Using the `KTOTrainer`
For a detailed example have a look at the `examples/scripts/kto.py` script. At a high level we need to initialize the `KTOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.
The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
The `desirable_weight` and `undesirable_weight` refer to the weights placed on the losses for desirable/positive and undesirable/negative examples.
By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` * number of positives) to (`undesirable_weight` * number of negatives) is in the range 1:1 to 4:3.
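For example, with 100 desirable and 300 undesirable completions, keeping `undesirable_weight=1.0` and choosing a `desirable_weight` between 3.0 and 4.0 keeps the ratio between (100 × 3.0):(300 × 1.0) = 1:1 and (100 × 4.0):(300 × 1.0) = 4:3 (the counts here are illustrative).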
```py
training_args = KTOConfig(
    beta=0.1,
    desirable_weight=1.0,
    undesirable_weight=1.0,
)

kto_trainer = KTOTrainer(
    model,
    ref_model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
```
After this one can then call:
```py
kto_trainer.train()
```
### For Mixture of Experts Models: Enabling the auxiliary loss
MOEs are the most efficient if the load is about equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
## KTOTrainer
[[autodoc]] KTOTrainer
## KTOConfig
[[autodoc]] KTOConfig

---
# Learning Tools (Experimental 🧪)
Using Large Language Models (LLMs) with tools has been a popular topic recently with awesome works such as [ToolFormer](https://huggingface.co/papers/2302.04761) and [ToolBench](https://huggingface.co/papers/2305.16504). In TRL, we provide a simple example of how to teach LLM to use tools with reinforcement learning.
Here's an overview of the scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples/research_projects/tools):
@ -108,7 +108,7 @@ As we can see, while 1-2 experiments crashed for some reason, most of the runs o
## (Early Experiments 🧪): learning to use a wiki tool for question answering
In the [ToolFormer](https://arxiv.org/abs/2302.04761) paper, it shows an interesting use case that utilizes a Wikipedia Search tool to help answer questions. In this section, we attempt to perform similar experiments but uses RL instead to teach the model to use a wiki tool on the [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) dataset.
The [ToolFormer](https://huggingface.co/papers/2302.04761) paper shows an interesting use case that utilizes a Wikipedia search tool to help answer questions. In this section, we attempt to perform similar experiments but use RL instead to teach the model to use a wiki tool on the [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) dataset.
<Tip warning={true}>
@ -121,7 +121,7 @@ In the [ToolFormer](https://arxiv.org/abs/2302.04761) paper, it shows an interes
### Building a search index
Since [ToolFormer](https://arxiv.org/abs/2302.04761) did not open source, we needed to first replicate the search index. It is mentioned in their paper that the authors built the search index using a BM25 retriever that indexes the Wikipedia dump from [KILT](https://github.com/facebookresearch/KILT)
Since [ToolFormer](https://huggingface.co/papers/2302.04761) was not open-sourced, we needed to replicate the search index first. Their paper mentions that the authors built the search index using a BM25 retriever that indexes the Wikipedia dump from [KILT](https://github.com/facebookresearch/KILT).
Fortunately, [`pyserini`](https://github.com/castorini/pyserini) already implements the BM25 retriever and provides a prebuilt index for the KILT Wikipedia dump. We can use the following code to search the index.
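As a rough sketch (the prebuilt index name is an assumption based on pyserini's published KILT indexes, not necessarily the exact code used here), a BM25 query could look like:
```python
# Sketch only: query the prebuilt KILT Wikipedia index with BM25.
from pyserini.search.lucene import LuceneSearcher

searcher = LuceneSearcher.from_prebuilt_index("wikipedia-kilt-doc")  # assumed index name
hits = searcher.search("Ralph M. Brown Act", k=1)
print(hits[0].docid, hits[0].score)
```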
@ -155,7 +155,7 @@ We use the following settings:
* use the `bigcode/starcoderbase` model as the base model
* use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and use only the first paragraph of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool.
* test whether the response contains the answer string; if so, give a reward of 1, otherwise give a reward of 0.
* notice this is a simplified evaluation criteria. In [ToolFormer](https://arxiv.org/abs/2302.04761), the authors checks if the first 20 words of the response contain the correct answer.
* note that this is a simplified evaluation criterion; in [ToolFormer](https://huggingface.co/papers/2302.04761), the authors check whether the first 20 words of the response contain the correct answer.
* use the following prompt that demonstrates the usage of the wiki tool.
```python
prompt = """\
@ -194,7 +194,7 @@ Note that the correct rate of the trained model is on the low end, which could b
* **unnecessarily long response**: The wiki tool by default sometimes outputs very long sequences. E.g., when the wiki tool searches for "Brown Act"
* Our wiki tool returns "The Ralph M. Brown Act, located at California Government Code 54950 "et seq.", is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public's right to attend and participate in meetings of local legislative bodies."
* [ToolFormer](https://arxiv.org/abs/2302.04761)'s wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies." which is more succinct.
* [ToolFormer](https://huggingface.co/papers/2302.04761)'s wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies." which is more succinct.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/brown_act.png)
@ -230,5 +230,3 @@ Q: """
The training experiment can be found at https://wandb.ai/lvwerra/trl-gsm8k/runs/a5odv01y
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/gms8k_learning_curve.png)

View File

@ -1,7 +1,7 @@
# Examples of using peft with trl to finetune 8-bit models with Low Rank Adaptation (LoRA)
The notebooks and scripts in these examples show how to use Low Rank Adaptation (LoRA) to fine-tune models in a memory-efficient manner. Most PEFT methods are supported in the peft library, but note that some PEFT methods, such as prompt tuning, are not supported.
For more information on LoRA, see the [original paper](https://arxiv.org/abs/2106.09685).
For more information on LoRA, see the [original paper](https://huggingface.co/papers/2106.09685).
Here's an overview of the `peft`-enabled notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples):
@ -71,7 +71,7 @@ The `trl` library is powered by `accelerate`. As such it is best to configure an
```bash
accelerate config # will prompt you to define the training configuration
accelerate launch scripts/gpt2-sentiment_peft.py # launches training
accelerate launch examples/scripts/ppo.py --use_peft # launches training
```
## Using `trl` + `peft` and Data Parallelism
@ -140,5 +140,5 @@ python PATH_TO_SCRIPT
You can easily fine-tune the Llama2 model using `SFTTrainer` and the official script! For example, to fine-tune llama2-7b on the Guanaco dataset, run (tested on a single NVIDIA T4-16GB):
```bash
python examples/scripts/sft.py --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --batch_size 4 --gradient_accumulation_steps 2
python examples/scripts/sft.py --output_dir sft_openassistant-guanaco --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --per_device_train_batch_size 4 --gradient_accumulation_steps 2
```

View File

@ -1,6 +1,6 @@
# Multi Adapter RL (MARL) - a single base model for everything
Here we present an approach that uses a single base model for the entire PPO algorithm - which includes retrieving the reference logits, computing the active logits and the rewards. This feature is experimental as we did not tested the convergence of the approach. We encourage the community to let us know if they potentially face into any issue.
Here we present an approach that uses a single base model for the entire PPO algorithm - which includes retrieving the reference logits, computing the active logits and the rewards. This feature is experimental as we did not test the convergence of the approach. We encourage the community to let us know if they face any issues.
## Requirements
@ -48,7 +48,7 @@ trainer = PPOTrainer(
...
```
Then inside your PPO training loop, call the `compute_reward_score` method by accessing to the `model` attribute from `PPOTrainer`.
Then inside your PPO training loop, call the `compute_reward_score` method by accessing the `model` attribute from `PPOTrainer`.
```python
rewards = trainer.model.compute_reward_score(**inputs)
@ -58,8 +58,8 @@ rewards = trainer.model.compute_reward_score(**inputs)
### Control on the adapter name
If you are familiar with the `peft` library, you know that you can use multiple adapters inside the same model. What you can do is to train multiple adapters on the same base model to fine-tune on different policies.
In this case, you want to have a control on the adapter name you want to activate back, after retrieving the reward. For that, simply pass the appropriate `adapter_name` to `ppo_adapter_name` argument when calling `compute_reward_score`.
If you are familiar with the `peft` library, you know that you can use multiple adapters inside the same model. What you can do is train multiple adapters on the same base model to fine-tune on different policies.
In this case, you want to be able to control which adapter to activate again after retrieving the reward. To do so, simply pass the appropriate `adapter_name` to the `ppo_adapter_name` argument when calling `compute_reward_score`.
```python
adapter_name_policy_1 = "policy_1"
@ -97,4 +97,4 @@ trainer = PPOTrainer(
...
)
...
```
```

View File

@ -0,0 +1,250 @@
# Online DPO Trainer
## Overview
Online DPO was proposed in [Direct Language Model Alignment from Online AI Feedback](https://huggingface.co/papers/2402.04792) by Shangmin Guo, Biao Zhang, Tianlin Liu, Tianqi Liu, Misha Khalman, Felipe Llinares, Alexandre Rame, Thomas Mesnard, Yao Zhao, Bilal Piot, Johan Ferret, and Mathieu Blondel.
The abstract from the paper is the following:
> Direct alignment from preferences (DAP) methods, such as DPO, have recently emerged as efficient alternatives to reinforcement learning from human feedback (RLHF), that do not require a separate reward model. However, the preference datasets used in DAP methods are usually collected ahead of training and never updated, thus the feedback is purely offline. Moreover, responses in these datasets are often sampled from a language model distinct from the one being aligned, and since the model evolves over training, the alignment phase is inevitably off-policy. In this study, we posit that online feedback is key and improves DAP methods. Our method, online AI feedback (OAIF), uses an LLM as annotator: on each training iteration, we sample two responses from the current model and prompt the LLM annotator to choose which one is preferred, thus providing online feedback. Despite its simplicity, we demonstrate via human evaluation in several tasks that OAIF outperforms both offline DAP and RLHF methods. We further show that the feedback leveraged in OAIF is easily controllable, via instruction prompts to the LLM annotator.
The current implementation uses reward models for scoring completions -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use.
This post-training method was contributed by [Michael Noukhovitch](https://huggingface.co/mnoukhov), [Shengyi Costa Huang](https://huggingface.co/vwxyzjn), [Quentin Gallouédec](https://huggingface.co/qgallouedec), and [Edward Beeching](https://huggingface.co/edbeeching).
## Usage tips
> [!WARNING]
> Make sure that the SFT model and reward model use the _same_ chat template. Otherwise, you may find the model completions are scored incorrectly during training.
The basic API is as follows:
```python
from datasets import Dataset
from trl import OnlineDPOConfig, OnlineDPOTrainer
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
)
NUM_DUMMY_SAMPLES = 100
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
# The model to optimise
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct")
# The reference model to calculate the KL divergence against
ref_model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct")
# The model to score completions with. In practice, you will need a reward model.
reward_model = AutoModelForSequenceClassification.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct", num_labels=1)
train_dataset = Dataset.from_dict(
{"prompt": ["Q: Hi how are you? A:"] * NUM_DUMMY_SAMPLES})
eval_dataset = Dataset.from_dict(
{"prompt": ["Q: What do you like to eat A:"] * NUM_DUMMY_SAMPLES})
args = OnlineDPOConfig(output_dir="online-dpo-model")
trainer = OnlineDPOTrainer(
model=model,
ref_model=ref_model,
reward_model=reward_model,
args=args,
tokenizer=tokenizer,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()
```
To test the online DPO script with 1B parameter models, run:
```bash
python examples/scripts/dpo_online.py \
--model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
--reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
--dataset_name trl-lib/tldr \
--learning_rate 5.0e-7 \
--output_dir pythia-1b-tldr-online-dpo \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 32 \
--num_train_epochs 3 \
--max_new_tokens 53 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0 \
--push_to_hub
```
Tips:
* `objective/rlhf_reward` is the ultimate objective of online DPO training. If training works as intended, this metric should keep going up.
* We recommend using the "EOS trick" via the `--missing_eos_penalty` argument, which subtracts a fixed scalar penalty from the rewards of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions; a minimal config sketch follows this list.
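A minimal sketch of enabling the penalty on the config (the value mirrors the command above):
```python
from trl import OnlineDPOConfig

args = OnlineDPOConfig(
    output_dir="online-dpo-model",
    missing_eos_penalty=1.0,  # subtracted from the reward when no EOS token is generated
)
```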
### Expected dataset format
Unlike offline DPO, where one provides a dataset with chosen and rejected columns, online DPO only requires a dataset of prompts to generate the completions from. The [`OnlineDPOTrainer`] assumes that the dataset is preprocessed for model inference, so typically you will need to wrap your prompts in the messages format and then apply the chat template as follows:
```python
def prepare_dataset(row):
"""Apply chat template to messages"""
row["prompt"] = tokenizer.apply_chat_template(row["prompt"], tokenize=False, add_generation_prompt=True)
return row
dataset = dataset.map(prepare_dataset)
```
### Explanation of the logged metrics
The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35).
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current model and reference model.
* `objective/entropy`: The mean entropy of the model, indicating the randomness of the actions chosen by the model.
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward` (see the sketch after this list).
* `objective/scores`: The mean scores returned by the reward model / environment.
* `objective/scores_margin`: The mean score margin (according to the external reward model) between the chosen and rejected completions.
* `rewards/accuracies`: The accuracies of the online DPO's implicit reward model.
* `rewards/chosen`: The mean reward (according to online DPO's implicit reward model) of the chosen completions.
* `rewards/rejected`: The mean reward (according to online DPO's implicit reward model) of the rejected completions.
* `rewards/margins`: The mean reward margin (according to online DPO's implicit reward model) between the chosen and rejected completions.
* `logps/chosen`: The mean log probabilities of the chosen completions.
* `logps/rejected`: The mean log probabilities of the rejected completions.
* `val/contain_eos_token`: The fraction of completions which contain an EOS token.
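As an illustrative sketch of how the reward quantities above fit together (shapes and values are made up; this is not TRL's internal code):
```python
import torch

beta = 0.05             # hypothetical KL penalty coefficient
kl = torch.rand(4, 16)  # per-token KL divergence for a batch of 4 completions
scores = torch.rand(4)  # reward-model scores, one per completion

non_score_reward = beta * kl.sum(1)      # `objective/non_score_reward` logs the mean
rlhf_reward = scores - non_score_reward  # `objective/rlhf_reward` logs the mean
print(non_score_reward.mean().item(), rlhf_reward.mean().item())
```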
## What is my model doing exactly?
To help you understand what your model is doing, we periodically log some sample completions from the model via [`LogCompletionsCallback`]. You can find an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/hlzevfro?nw=nwuserlewtun), which allows you to see the model's responses at different stages of training. By default we generate during training, but you can customize the number of prompts used for generation in [`LogCompletionsCallback`].
## Implementation details
Many online implementation details are borrowed from the [`PPOv2Trainer`], which is itself based on [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
## Benchmark experiments
To validate that the online DPO implementation works, we ran experiments with the Pythia 1B, 2.8B, and 6.9B models on a single node of 8 x H100s. Here are the commands we used to run the experiments. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
```
# 1B Online DPO experiment
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml \
examples/scripts/dpo_online.py \
--model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
--reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
--dataset_name trl-lib/tldr \
--learning_rate 5.0e-7 \
--output_dir pythia-1b-deduped-tldr-online-dpo \
--beta 0.1 \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 2 \
--num_train_epochs 3 \
--max_new_tokens 53 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0 \
--logging_steps 20 \
--save_steps 0.1 \
--push_to_hub
# 2.8B Online DPO experiment
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
examples/scripts/dpo_online.py \
--model_name_or_path trl-lib/pythia-2.8b-deduped-tldr-sft \
--reward_model_path trl-lib/pythia-2.8b-deduped-tldr-rm \
--dataset_name trl-lib/tldr \
--learning_rate 5.0e-7 \
--output_dir pythia-2.8b-deduped-tldr-online-dpo \
--beta 0.1 \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 2 \
--num_train_epochs 3 \
--max_new_tokens 53 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0 \
--bf16 \
--logging_steps 20 \
--save_steps 0.1 \
--push_to_hub
# 6.9B Online DPO experiment
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
examples/scripts/dpo_online.py \
--model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-sft \
--reward_model_path trl-lib/pythia-6.9b-deduped-tldr-rm \
--dataset_name trl-lib/tldr \
--learning_rate 5.0e-7 \
--output_dir pythia-6.9b-deduped-tldr-online-dpo \
--beta 0.1 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--num_train_epochs 3 \
--max_new_tokens 53 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0 \
--bf16 \
--gradient_checkpointing \
--logging_steps 20 \
--save_steps 0.1 \
--push_to_hub
```
Checkpoints and experiment tracking are available at:
- [🤗 Model checkpoints](https://huggingface.co/collections/trl-lib/online-dpo-66acd3fa38a331a9cd457b07)
- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/reports/Online-DPO-experiments-for-TL-DR-summarisation--Vmlldzo5MTczMDU0)
To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
For more information on how to use judges, see [Judges](judges).
```bash
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 33.00%
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-sft --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 41.50%
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-1b-deduped-tldr-online-dpo --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 62.60%
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-online-dpo --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 74.20%
```
We can then plot the RLHF scaling chart.
```python
import matplotlib.pyplot as plt
results = {
"SFT": {1.0e9: 0.21, 2.8e9: 0.27, 6.9e9: 0.316},
"online-dpo": {1.0e9: 0.542, 2.8e9: 0.746, 6.9e9: 0.796},
"offline-dpo": {1.0e9: 0.422, 2.8e9: 0.517, 6.9e9: 0.701},
}
plt.plot(list(results["SFT"].keys()), list(results["SFT"].values()), label="SFT", marker="o")
plt.plot(list(results["online-dpo"].keys()), list(results["online-dpo"].values()), label="Online-dpo with RM judge", marker="o")
plt.plot(list(results["offline-dpo"].keys()), list(results["offline-dpo"].values()), label="Offline-dpo", marker="o")
plt.axhline(y=0.5, color="black", linestyle="-.", label="Human reference summary")
plt.xscale("log")
plt.xlabel("Model size")
plt.ylabel("Win rate against reference summaries\n(according to GPT-4-0613)")
plt.title("DPO scaling by model size")
plt.legend()
plt.xlim(5e8, 1.2e10)
plt.xticks([1e9, 3e9, 1e10], ["1B", "3B", "10B"])
plt.grid(True, which="both", ls="--", c="0.7")
plt.tight_layout()
plt.show()
```
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/online_dpo_scaling.png)
The online DPO checkpoints achieve increasingly higher win rates as we scale up the model size. This is a good sign that the online DPO implementation is working as intended.
## OnlineDPOTrainer
[[autodoc]] OnlineDPOTrainer
## OnlineDPOConfig
[[autodoc]] OnlineDPOConfig

docs/source/orpo_trainer.md Normal file
View File

@ -0,0 +1,106 @@
# ORPO Trainer
[Odds Ratio Preference Optimization](https://huggingface.co/papers/2403.07691) (ORPO) by Jiwoo Hong, Noah Lee, and James Thorne studies the crucial role of SFT within the context of preference alignment. Using preference data, the method posits that a minor penalty for the disfavored generation, together with a strong adaptation signal to the chosen response via a simple log odds ratio term appended to the NLL loss, is sufficient for preference-aligned SFT.
Thus ORPO is a reference-model-free preference optimization algorithm, eliminating the need for an additional preference alignment phase and thereby saving compute and memory.
The official code can be found at [xfactlab/orpo](https://github.com/xfactlab/orpo).
## Expected dataset format
The ORPO trainer expects a format identical to the DPO trainer, which should include three entries. These entries should be named as follows:
- `prompt`
- `chosen`
- `rejected`
for example:
```py
orpo_dataset_dict = {
"prompt": [
"hello",
"how are you",
"What is your name?",
"What is your name?",
"Which is the best programming language?",
"Which is the best programming language?",
"Which is the best programming language?",
],
"chosen": [
"hi nice to meet you",
"I am fine",
"My name is Mary",
"My name is Mary",
"Python",
"Python",
"Java",
],
"rejected": [
"leave me alone",
"I am not fine",
"Whats it to you?",
"I dont have a name",
"Javascript",
"C++",
"C++",
],
}
```
where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses, and `rejected` contains the corresponding negative (rejected) responses. Note that a prompt can have multiple responses, and this is reflected in the entries being repeated in the dictionary's value arrays.
## Expected model format
The ORPO trainer expects a model of type `AutoModelForCausalLM`, whereas PPO expects `AutoModelForCausalLMWithValueHead` for the value function.
## Using the `ORPOTrainer`
For a detailed example, have a look at the `examples/scripts/orpo.py` script. At a high level, we need to initialize the `ORPOTrainer` with a `model` we wish to train. **Note that ORPOTrainer eliminates the need to use the reference model, simplifying the optimization process.** The `beta` corresponds to the hyperparameter `lambda` in eq. (6) of the paper and controls the weighting of the relative odds ratio loss against the standard cross-entropy loss used for SFT.
```py
orpo_config = ORPOConfig(
beta=0.1, # the lambda/alpha hyperparameter in the paper/code
)
orpo_trainer = ORPOTrainer(
model,
args=orpo_config,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
```
After this, one can call:
```py
orpo_trainer.train()
```
### For Mixture of Experts Models: Enabling the auxiliary loss
MoEs are most efficient when the load is roughly evenly distributed across experts.
To ensure that we train MoEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g. `MixtralConfig`).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
## Logging
While training and evaluating we record the following reward metrics:
* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
* `rewards/accuracies`: mean of how often the chosen rewards are greater than the corresponding rejected rewards
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
* `log_odds_chosen`: the mean log odds ratio of the chosen responses over the rejected responses
* `log_odds_ratio`: the mean of the `log(sigmoid(log_odds_chosen))` (see the sketch after this list)
* `nll_loss`: the mean negative log likelihood loss from the SFT part of the loss over chosen responses
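As an illustrative sketch of how the odds-ratio quantities above relate (the log probabilities are made up; this is not TRL's internal code):
```py
import torch
import torch.nn.functional as F

# Hypothetical mean per-token log probs of the chosen and rejected responses.
chosen_logps = torch.tensor([-0.8])
rejected_logps = torch.tensor([-1.5])

def log_odds(logps):
    # log odds of a response: log(p / (1 - p)), computed stably in log space
    return logps - torch.log1p(-torch.exp(logps))

log_odds_chosen = log_odds(chosen_logps) - log_odds(rejected_logps)  # `log_odds_chosen`
log_odds_ratio = F.logsigmoid(log_odds_chosen).mean()                # `log_odds_ratio`
odds_ratio_loss = -log_odds_ratio  # weighted by `beta` and added to the NLL loss
```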
## ORPOTrainer
[[autodoc]] ORPOTrainer
## ORPOConfig
[[autodoc]] ORPOConfig

View File

@ -1,9 +1,24 @@
# PPO Trainer
TRL supports the [PPO](https://arxiv.org/abs/1707.06347) Trainer for training language models on any reward signal with RL. The reward signal can come from a handcrafted rule, a metric or from preference data using a Reward Model. For a full example have a look at [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb). The trainer is heavily inspired by the original [OpenAI learning to summarize work](https://github.com/openai/summarize-from-feedback).
TRL supports the [PPO](https://huggingface.co/papers/1707.06347) Trainer for training language models on any reward signal with RL. The reward signal can come from a handcrafted rule, a metric, or preference data using a Reward Model. For a full example, have a look at [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb). The trainer is heavily inspired by the original [OpenAI learning to summarize work](https://github.com/openai/summarize-from-feedback).
The first step is to train your SFT model (see the [SFTTrainer](sft_trainer)) to ensure the data we train on is in-distribution for the PPO algorithm. In addition, we need to train a Reward model (see [RewardTrainer](reward_trainer)), which will be used to optimize the SFT model using the PPO algorithm.
## How PPO works
Fine-tuning a language model via PPO consists of roughly three steps:
1. **Rollout**: The language model generates a response or continuation based on a query, which could be the start of a sentence.
2. **Evaluation**: The query and response are evaluated with a function, a model, human feedback, or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
3. **Optimization**: This is the most complex part. In the optimization step, the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model (a rough reward sketch follows the figure below). The active language model is then trained with PPO.
This process is illustrated in the sketch below:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png" width="800">
<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
</div>
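As a rough illustration of the reward used in step 3 (shapes and the coefficient are assumptions; this is not TRL's internal code), the per-token KL penalty is combined with the scalar score:
```py
import torch

logprobs = torch.randn(1, 8)      # policy log probs of the generated tokens
ref_logprobs = torch.randn(1, 8)  # reference-model log probs of the same tokens
score = torch.tensor(0.7)         # scalar reward for the full response
kl_coef = 0.05                    # hypothetical KL penalty coefficient

kl = logprobs - ref_logprobs  # per-token KL estimate
rewards = -kl_coef * kl       # the KL penalty keeps the policy near the reference
rewards[:, -1] += score       # the scalar score is added at the final token
```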
## Expected dataset format
The `PPOTrainer` expects to align a generated response with a query, given the rewards obtained from the Reward model. During each step of the PPO algorithm, we sample a batch of prompts from the dataset and use these prompts to generate responses from the SFT model. Next, the Reward model is used to compute rewards for the generated responses. Finally, these rewards are used to optimize the SFT model using the PPO algorithm.
@ -90,7 +105,7 @@ from trl import PPOTrainer
ppo_trainer = PPOTrainer(
model=model,
config=config,
train_dataset=train_dataset,
dataset=dataset,
tokenizer=tokenizer,
)
```
@ -115,7 +130,10 @@ We can then loop over all examples in the dataset and generate a response for ea
```py
from tqdm import tqdm
for epoch in tqdm(range(ppo_trainer.config.ppo_epochs), "epoch: "):
epochs = 10
for epoch in tqdm(range(epochs), "epoch: "):
for batch in tqdm(ppo_trainer.dataloader):
query_tensors = batch["input_ids"]
@ -133,7 +151,7 @@ for epoch in tqdm(range(ppo_trainer.config.ppo_epochs), "epoch: "):
ppo_trainer.log_stats(stats, batch, rewards)
#### Save model
ppo_trainer.save_model("my_ppo_model")
ppo_trainer.save_pretrained("my_ppo_model")
```
## Logging

View File

@ -0,0 +1,225 @@
# PPOv2 Trainer
TRL supports training LLMs with [Proximal Policy Optimization (PPO)](https://huggingface.co/papers/1707.06347).
References:
- [Fine-Tuning Language Models from Human Preferences](https://github.com/openai/lm-human-preferences)
- [Learning to Summarize from Human Feedback](https://github.com/openai/summarize-from-feedback)
- [The N Implementation Details of RLHF with PPO](https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo)
- [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031)
## Get started
To quickly verify that the trainer can run, use the following command to train a PPO model with a dummy reward model.
```bash
python examples/scripts/ppo/ppo.py \
--learning_rate 3e-6 \
--num_ppo_epochs 1 \
--num_mini_batches 1 \
--output_dir models/minimal/ppo \
--per_device_train_batch_size 64 \
--gradient_accumulation_steps 1 \
--total_episodes 10000 \
--model_name_or_path EleutherAI/pythia-1b-deduped \
--non_eos_penalty
```
## Explanation of the logged metrics
The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35).
* `eps`: Tracks the number of episodes per second.
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
* `objective/scores`: The mean scores returned by the reward model / environment.
* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
* `loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward.
* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to policy/clipfrac_avg but for the value function.
* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
* `lr`: The current learning rate used by the optimizer.
* `episode`: The current global step or episode count in the training process.
## Cookbook
* Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. If this `ratio` is too high (like 2.0 or 1000.0) or too small (like 0.1), it means the updates between consecutive policies are too drastic. You should try to understand why this is happening and fix it.
* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
* Usage TIP: We recommend using the "EOS trick" via `--non_eos_penalty --stop_token eos`, which replaces the score of completions that do not end with an EOS token with a static scalar penalty `--penalty_reward_value`. This can help the model learn to generate more coherent completions.
## What is my model doing exactly?
To help you understand what your model is doing, we periodically log some sample completions from the model. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35), the completions look like the following, allowing you to see the model's responses at different stages of training. By default we generate `--num_sample_generations 10` completions during training, but you can customize the number of generations.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/ppov2_completions.gif?download=true)
In the logs the sampled generations look like
```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
┃ query ┃ model response ┃ score ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
│ SUBREDDIT: r/AskReddit │ I'm in love with a friend, and │ 3.921875 │
│ │ I don't know how to get rid of │ │
│ TITLE: How do you get someone │ those feelings. I'm │ │
│ out of your head? │ desperate.<|endoftext|>[PAD][P… │ │
│ │ │ │
│ POST: Hi, │ │ │
│ I'm 22, and I have been with my │ │ │
│ girlfriend for 5 years now. We │ │ │
│ recently moved together. We've │ │ │
│ always loved each other │ │ │
│ intensely. │ │ │
│ │ │ │
│ Problem, I recently started to │ │ │
│ have feelings for an other │ │ │
│ person (a friend). This person │ │ │
│ has had a boyfriend for now 3 │ │ │
│ years, and has absolutely no │ │ │
│ ideas. Those feelings were so │ │ │
│ strong, it was hard to hide │ │ │
│ them. After 2 months of me │ │ │
│ being distant and really sad, │ │ │
│ my girlfriend forced me to say │ │ │
│ what was bothering me. I'm not │ │ │
│ a good liar, and now she knows. │ │ │
│ │ │ │
│ We decided to give us a week │ │ │
│ alone, I went to my parents. │ │ │
│ │ │ │
│ Now, I'm completely lost. I │ │ │
│ keep on thinking about this │ │ │
│ person, and I hate that. I │ │ │
│ would like for those feelings │ │ │
│ to go away, to leave me alone. │ │ │
│ But I can't. │ │ │
│ │ │ │
│ What do I do? It's been 3 │ │ │
│ months now, and I'm just │ │ │
│ desperate. │ │ │
│ │ │ │
│ TL;DR: │ │ │
├─────────────────────────────────┼─────────────────────────────────┼──────────┤
│ SUBREDDIT: r/pettyrevenge │ My mom woke me up with a loud │ 6.84375 │
│ │ TV. I blasted Gangnam Style on │ │
│ TITLE: So, my mom woke me up │ repeat, with the bass cranked │ │
│ with a loud TV. │ up as high as it could │ │
│ │ go.<|endoftext|>[PAD][PAD][PAD… │ │
│ POST: She was in her living │ │ │
│ room, watching TV. This was at │ │ │
│ about 8:30 in the morning, and │ │ │
│ she was exercising. She turned │ │ │
│ the TV up extra loud to hear it │ │ │
│ over her excercycle, and woke │ │ │
│ me up. I went in there asking │ │ │
│ for her to turn it down. She │ │ │
│ said she didn't have to; I │ │ │
│ explained that I always used │ │ │
│ headphones so she didn't have │ │ │
│ to deal with my noise and that │ │ │
│ she should give me a little │ │ │
│ more respect, given that I paid │ │ │
│ rent at the time. │ │ │
│ │ │ │
│ She disagreed. I went back to │ │ │
│ my room, rather pissed off at │ │ │
│ the lack of equality. I had no │ │ │
│ lock on my door; but I had a │ │ │
│ dresser right next to it, so I │ │ │
│ pulled one of the drawers out │ │ │
│ enough so that it caused the │ │ │
│ door to not be openable. Then, │ │ │
│ I turned my speakers up really │ │ │
│ loud and blasted Gangnam Style │ │ │
│ on repeat, with the bass │ │ │
│ cranked up as high as it could │ │ │
│ go. │ │ │
│ │ │ │
│ If you hate Gangnam Style for │ │ │
│ being overplayed, you will see │ │ │
│ why I chose that particular │ │ │
│ song. I personally don't mind │ │ │
│ it. But here's the thing about │ │ │
│ my bass; it vibrates the walls, │ │ │
│ making one hell of a lot of │ │ │
│ noise. Needless to say, my mom │ │ │
│ was not pleased and shut off │ │ │
│ the internet. But it was oh so │ │ │
│ worth it. │ │ │
│ │ │ │
│ TL;DR: │ │ │
└─────────────────────────────────┴─────────────────────────────────┴──────────┘
```
## Implementation details
This PPOv2 implementation is based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
## Benchmark experiments
To validate that the PPO implementation works, we ran an experiment with the 1B model. Here is the command we used to run the experiment. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
```
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
examples/scripts/ppo/ppo_tldr.py \
--output_dir models/minimal/ppo_tldr \
--learning_rate 3e-6 \
--per_device_train_batch_size 16 \
--gradient_accumulation_steps 4 \
--total_episodes 1000000 \
--model_name_or_path EleutherAI/pythia-1b-deduped \
--sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--local_rollout_forward_batch_size 16 \
--non_eos_penalty \
--stop_token eos
```
Checkpoints and experiment tracking are available at:
- [🤗 Model checkpoint](https://huggingface.co/vwxyzjn/ppo_tldr)
- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/runs/dd2o3g35)
To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
For more information on how to use judges, see [Judges](judges).
```bash
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 33.00%
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 64.70%
```
The PPO checkpoint gets a 64.7% preferred rate vs. the 33.0% preferred rate of the SFT checkpoint. This is a good sign that the PPO training is working as intended.
Metrics:
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/pr-1540/ppov2.png)
```bash
# pip install openrlbenchmark==0.2.1a5
# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation
# to use it, change `?we=huggingface&wpn=trl` to your own project and `?tag=pr-1540` to your own tag
python -m openrlbenchmark.rlops_multi_metrics \
--filters '?we=huggingface&wpn=trl&xaxis=train/episode&ceik=output_dir&cen=sft_model_path&metrics=train/objective/rlhf_reward&metrics=train/objective/scores&metrics=train/objective/kl&metrics=train/objective/non_score_reward&metrics=train/objective/entropy&metrics=train/policy/approxkl_avg&metrics=train/policy/clipfrac_avg&metrics=train/loss/policy_avg&metrics=train/loss/value_avg&metrics=train/val/clipfrac_avg&metrics=train/policy/entropy_avg&metrics=train/val/ratio&metrics=train/val/ratio_var&metrics=train/val/num_eos_tokens&metrics=train/lr&metrics=train/eps' \
"cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr?tag=pr-1540" \
--env-ids models/minimal/ppo_tldr \
--pc.ncols 4 \
--pc.ncols-legend 1 \
--pc.xlabel "Episode" \
--output-filename benchmark/trl/pr-1540/ppov2 \
--scan-history
```

View File

@ -25,14 +25,14 @@ from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
# 2. initialize trainer
ppo_config = {"batch_size": 1}
ppo_config = {"mini_batch_size": 1, "batch_size": 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)
# 3. encode a query
query_txt = "This morning I went to the "

View File

@ -68,6 +68,25 @@ def add_margin(row):
dataset = dataset.map(add_margin)
```
### Centering rewards
In many scenarios, it's preferable to ensure that a reward model's output is mean zero. This is often done by first calculating the model's average score and then subtracting it.
[[Eisenstein et al., 2023]](https://huggingface.co/papers/2312.09244) proposed an auxiliary loss function designed to directly learn a centered reward model. This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs:
$$\Big( R(p, r_1) + R(p, r_2) \Big)^2 $$
This auxiliary loss is combined with the main loss function, weighted by the parameter `center_rewards_coefficient` in [`RewardConfig`]. By default, this feature is deactivated (`center_rewards_coefficient = None`).
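As an illustrative sketch (made-up reward values; this is not TRL's internal implementation), the centering term enters a pairwise reward-modeling loss like this:
```python
import torch
import torch.nn.functional as F

rewards_chosen = torch.tensor([1.2])    # hypothetical R(p, r_1)
rewards_rejected = torch.tensor([0.3])  # hypothetical R(p, r_2)
center_rewards_coefficient = 0.01

main_loss = -F.logsigmoid(rewards_chosen - rewards_rejected).mean()
aux_loss = ((rewards_chosen + rewards_rejected) ** 2).mean()
loss = main_loss + center_rewards_coefficient * aux_loss
```
To enable it in TRL, set the coefficient on the config: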
```python
reward_config = RewardConfig(
center_rewards_coefficient=0.01,
...
)
```
For reference results, please refer to PR [#1932](https://github.com/huggingface/trl/pull/1932).
## RewardConfig
[[autodoc]] RewardConfig

docs/source/rloo_trainer.md Normal file
View File

@ -0,0 +1,265 @@
# RLOO Trainer
TRL supports training LLMs with REINFORCE Leave-One-Out (RLOO). The idea is that instead of using a value function, RLOO generates K completions for each prompt. For each completion, RLOO uses the mean score of the other K-1 completions as a baseline to calculate the advantage. RLOO also models the entire completion as a single action, whereas PPO models each token as an action. Note that REINFORCE / A2C is a special case of PPO, when the number of PPO epochs is 1 and the number of mini-batches is 1, which is how we implement RLOO in TRL.
References:
- [Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs](https://huggingface.co/papers/2402.14740)
- [A2C is a special case of PPO](https://huggingface.co/papers/2205.09123)
- [Fine-Tuning Language Models from Human Preferences](https://github.com/openai/lm-human-preferences)
- [Learning to Summarize from Human Feedback](https://github.com/openai/summarize-from-feedback)
- [The N Implementation Details of RLHF with PPO](https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo)
- [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031)
## Get started
To quickly verify that the trainer can run, use the following command to train an RLOO model with a dummy reward model.
```bash
python examples/scripts/rloo/rloo.py \
--learning_rate 3e-6 \
--output_dir models/minimal/rloo \
--per_device_train_batch_size 64 \
--gradient_accumulation_steps 1 \
--total_episodes 10000 \
--model_name_or_path EleutherAI/pythia-14m \
--reward_model_path EleutherAI/pythia-14m \
--non_eos_penalty
```
## Explanation of the logged metrics
The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/u2sqci34).
<!-- * `rlhf_reward_var_per_prompt`: calculated by `rlhf_reward.var(0).mean()`. This is the variance of the rewards estimated across the `args.rloo_k` samples. Usually we expect it to go down (cause policy entropy goes down). -->
* `eps`: Tracks the number of episodes per second.
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
* `objective/scores`: The mean scores returned by the reward model / environment.
* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to policy/clipfrac_avg but for the value function.
* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
* `lr`: The current learning rate used by the optimizer.
* `episode`: The current global step or episode count in the training process.
## Cookbook
* Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. If this `ratio` is too high (like 2.0 or 1000.0) or too small (like 0.1), it means the updates between consecutive policies are too drastic. You should try to understand why this is happening and fix it.
* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
* Usage TIP: We recommend using the "EOS trick" via `--non_eos_penalty --stop_token eos`, which replaces the score of completions that do not end with an EOS token with a static scalar penalty `--penalty_reward_value`. This can help the model learn to generate more coherent completions.
## What is my model doing exactly?
To help you understand what your model is doing, we periodically log some sample completions from the model. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/u2sqci34), the completions look like the following, allowing you to see the model's responses at different stages of training. By default we generate `--num_sample_generations 10` completions during training, but you can customize the number of generations.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/ppov2_completions.gif)
In the logs the sampled generations look like
```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
┃ query ┃ model response ┃ score ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
│ SUBREDDIT: r/AskReddit │ I'm in love with a friend, and │ 3.921875 │
│ │ I don't know how to get rid of │ │
│ TITLE: How do you get someone │ those feelings. I'm │ │
│ out of your head? │ desperate.<|endoftext|>[PAD][P… │ │
│ │ │ │
│ POST: Hi, │ │ │
│ I'm 22, and I have been with my │ │ │
│ girlfriend for 5 years now. We │ │ │
│ recently moved together. We've │ │ │
│ always loved each other │ │ │
│ intensely. │ │ │
│ │ │ │
│ Problem, I recently started to │ │ │
│ have feelings for an other │ │ │
│ person (a friend). This person │ │ │
│ has had a boyfriend for now 3 │ │ │
│ years, and has absolutely no │ │ │
│ ideas. Those feelings were so │ │ │
│ strong, it was hard to hide │ │ │
│ them. After 2 months of me │ │ │
│ being distant and really sad, │ │ │
│ my girlfriend forced me to say │ │ │
│ what was bothering me. I'm not │ │ │
│ a good liar, and now she knows. │ │ │
│ │ │ │
│ We decided to give us a week │ │ │
│ alone, I went to my parents. │ │ │
│ │ │ │
│ Now, I'm completely lost. I │ │ │
│ keep on thinking about this │ │ │
│ person, and I hate that. I │ │ │
│ would like for those feelings │ │ │
│ to go away, to leave me alone. │ │ │
│ But I can't. │ │ │
│ │ │ │
│ What do I do? It's been 3 │ │ │
│ months now, and I'm just │ │ │
│ desperate. │ │ │
│ │ │ │
│ TL;DR: │ │ │
├─────────────────────────────────┼─────────────────────────────────┼──────────┤
│ SUBREDDIT: r/pettyrevenge │ My mom woke me up with a loud │ 6.84375 │
│ │ TV. I blasted Gangnam Style on │ │
│ TITLE: So, my mom woke me up │ repeat, with the bass cranked │ │
│ with a loud TV. │ up as high as it could │ │
│ │ go.<|endoftext|>[PAD][PAD][PAD… │ │
│ POST: She was in her living │ │ │
│ room, watching TV. This was at │ │ │
│ about 8:30 in the morning, and │ │ │
│ she was exercising. She turned │ │ │
│ the TV up extra loud to hear it │ │ │
│ over her excercycle, and woke │ │ │
│ me up. I went in there asking │ │ │
│ for her to turn it down. She │ │ │
│ said she didn't have to; I │ │ │
│ explained that I always used │ │ │
│ headphones so she didn't have │ │ │
│ to deal with my noise and that │ │ │
│ she should give me a little │ │ │
│ more respect, given that I paid │ │ │
│ rent at the time. │ │ │
│ │ │ │
│ She disagreed. I went back to │ │ │
│ my room, rather pissed off at │ │ │
│ the lack of equality. I had no │ │ │
│ lock on my door; but I had a │ │ │
│ dresser right next to it, so I │ │ │
│ pulled one of the drawers out │ │ │
│ enough so that it caused the │ │ │
│ door to not be openable. Then, │ │ │
│ I turned my speakers up really │ │ │
│ loud and blasted Gangnam Style │ │ │
│ on repeat, with the bass │ │ │
│ cranked up as high as it could │ │ │
│ go. │ │ │
│ │ │ │
│ If you hate Gangnam Style for │ │ │
│ being overplayed, you will see │ │ │
│ why I chose that particular │ │ │
│ song. I personally don't mind │ │ │
│ it. But here's the thing about │ │ │
│ my bass; it vibrates the walls, │ │ │
│ making one hell of a lot of │ │ │
│ noise. Needless to say, my mom │ │ │
│ was not pleased and shut off │ │ │
│ the internet. But it was oh so │ │ │
│ worth it. │ │ │
│ │ │ │
│ TL;DR: │ │ │
└─────────────────────────────────┴─────────────────────────────────┴──────────┘
```
## Implementation details
The bulk of RLOOTrainer is based on the PPO implementation, which is based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
Below is a vectorized advantage calculation for RLOO:
```python
import torch


def test_rloo_reward():
    local_batch_size = 3
    rloo_k = 4
    rlhf_reward = torch.tensor([
        1, 2, 3,   # first rlhf reward for three prompts
        2, 3, 4,   # second rlhf reward for three prompts
        5, 6, 7,   # third rlhf reward for three prompts
        8, 9, 10,  # fourth rlhf reward for three prompts
    ]).float()  # here we have 3 prompts which have 4 completions each

    # Loop-based implementation: for each completion, the baseline is the
    # mean reward of the other rloo_k - 1 completions for the same prompt.
    advantages = torch.zeros_like(rlhf_reward)
    for i in range(0, len(advantages), local_batch_size):
        other_response_rlhf_rewards = []
        for j in range(0, len(advantages), local_batch_size):
            if i != j:
                other_response_rlhf_rewards.append(rlhf_reward[j : j + local_batch_size])
        advantages[i : i + local_batch_size] = rlhf_reward[i : i + local_batch_size] - torch.stack(other_response_rlhf_rewards).mean(0)
    assert abs(1 - (2 + 5 + 8) / 3 - advantages[0].item()) < 1e-6  # first rlhf reward for the first prompt
    assert abs(6 - (3 + 2 + 9) / 3 - advantages[7].item()) < 1e-6  # third rlhf reward for the second prompt

    # Vectorized implementation: reshape to (rloo_k, local_batch_size) so each
    # column holds the rloo_k rewards for one prompt.
    rlhf_reward = rlhf_reward.reshape(rloo_k, local_batch_size)
    baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
    vec_advantages = rlhf_reward - baseline
    torch.testing.assert_close(vec_advantages.flatten(), advantages)
```
## Benchmark experiments
To validate that the RLOO implementation works, we ran an experiment with the 1B model. Here is the command we used to run the experiment. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
```
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
examples/scripts/rloo/rloo_tldr.py \
--output_dir models/minimal/rloo_tldr \
--num_ppo_epochs 2 \
--num_mini_batches 2 \
--learning_rate 3e-6 \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 8 \
--total_episodes 1000000 \
--model_name_or_path EleutherAI/pythia-1b-deduped \
--sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--local_rollout_forward_batch_size 16 \
--non_eos_penalty \
--stop_token eos \
--kl_coef 0.03
```
Checkpoints and experiment tracking are available at:
- [🤗 Model checkpoint](https://huggingface.co/vwxyzjn/rloo_tldr)
- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/runs/u2sqci34)
To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
For more information on how to use judges, see [Judges](judges).
```bash
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 33.00%
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 51.20%
```
The RLOO checkpoint gets a 51.2% preferred rate vs. the 33.0% preferred rate of the SFT checkpoint. This is a good sign that the RLOO training is working as intended.
Metrics:
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/pr-1540/rloo.png)
```bash
# pip install openrlbenchmark==0.2.1a5
# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation
# to use it, change `?we=huggingface&wpn=trl` to your own project and `?tag=pr-1540` to your own tag
python -m openrlbenchmark.rlops_multi_metrics \
--filters '?we=huggingface&wpn=trl&xaxis=train/episode&ceik=output_dir&cen=sft_model_path&metrics=train/objective/rlhf_reward&metrics=train/objective/scores&metrics=train/objective/kl&metrics=train/objective/non_score_reward&metrics=train/objective/entropy&metrics=train/policy/approxkl_avg&metrics=train/policy/clipfrac_avg&metrics=train/loss/policy_avg&metrics=train/policy/entropy_avg&metrics=train/val/ratio&metrics=train/val/ratio_var&metrics=train/val/num_eos_tokens&metrics=train/lr&metrics=train/eps' \
"cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr?tag=pr-1540" \
--env-ids models/minimal/rloo_tldr \
--pc.ncols 4 \
--pc.ncols-legend 1 \
--pc.xlabel "Episode" \
--output-filename benchmark/trl/pr-1540/rloo \
--scan-history
```

View File

@ -25,7 +25,7 @@ accelerate launch examples/scripts/ppo.py # launches training
# 3. get help text and documentation
python examples/scripts/ppo.py --help
# 4. configure logging with wandb and, say, mini_batch_size=1 and gradient_accumulation_steps=16
python examples/scripts/ppo.py --ppo_config.log_with wandb --ppo_config.mini_batch_size 1 --ppo_config.gradient_accumulation_steps 16
python examples/scripts/ppo.py --log_with wandb --mini_batch_size 1 --gradient_accumulation_steps 16
```
Note: if you don't want to log with `wandb`, remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
@ -42,7 +42,7 @@ Below are some benchmark results for `examples/scripts/ppo.py`. To reproduce loc
```bash
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --log_with wandb" \
--num-seeds 5 \
--start-seed 1 \
--workers 10 \
@ -61,7 +61,7 @@ python benchmark/benchmark.py \
```bash
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_step_grad_accu --ppo_config.mini_batch_size 1 --ppo_config.gradient_accumulation_steps 128 --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_step_grad_accu --mini_batch_size 1 --gradient_accumulation_steps 128 --log_with wandb" \
--num-seeds 5 \
--start-seed 1 \
--workers 10 \
@ -79,7 +79,7 @@ python benchmark/benchmark.py \
```bash
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_gpt2 --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_gpt2 --log_with wandb" \
--num-seeds 5 \
--start-seed 1 \
--workers 10 \
@ -89,7 +89,7 @@ python benchmark/benchmark.py \
--slurm-total-cpus 12 \
--slurm-template-path benchmark/trl.slurm_template
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_gpt2xl_grad_accu --ppo_config.model_name gpt2-xl --ppo_config.mini_batch_size 16 --ppo_config.gradient_accumulation_steps 8 --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_gpt2xl_grad_accu --model_name gpt2-xl --mini_batch_size 16 --gradient_accumulation_steps 8 --log_with wandb" \
--num-seeds 5 \
--start-seed 1 \
--workers 10 \
@ -99,7 +99,7 @@ python benchmark/benchmark.py \
--slurm-total-cpus 12 \
--slurm-template-path benchmark/trl.slurm_template
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_falcon_rw_1b --ppo_config.model_name tiiuae/falcon-rw-1b --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_falcon_rw_1b --model_name tiiuae/falcon-rw-1b --log_with wandb" \
--num-seeds 5 \
--start-seed 1 \
--workers 10 \
@ -116,7 +116,7 @@ python benchmark/benchmark.py \
```
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_peft --use_peft --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_peft --use_peft --log_with wandb" \
--num-seeds 5 \
--start-seed 1 \
--workers 10 \
View File
@ -3,6 +3,7 @@
Supervised fine-tuning (or SFT for short) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them with few lines of code on your dataset.
Check out a complete flexible example at [`examples/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft.py).
Experimental support for Vision Language Models is also included in the example [`examples/scripts/vsft_llava.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/vsft_llava.py).
## Quickstart
@ -11,42 +12,47 @@ The following code-snippet takes care of all the data pre-processing and trainin
```python
from datasets import load_dataset
from trl import SFTTrainer
from trl import SFTConfig, SFTTrainer
dataset = load_dataset("imdb", split="train")
sft_config = SFTConfig(
dataset_text_field="text",
max_seq_length=512,
output_dir="/tmp",
)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=512,
args=sft_config,
)
trainer.train()
```
Make sure to pass a correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
Make sure to pass the correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
You can also construct a model outside of the trainer and pass it as follows:
```python
from transformers import AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTTrainer
from trl import SFTConfig, SFTTrainer
dataset = load_dataset("imdb", split="train")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
sft_config = SFTConfig(output_dir="/tmp")
trainer = SFTTrainer(
model,
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=512,
args=sft_config,
)
trainer.train()
```
The above snippets will use the default training arguments from the [`transformers.TrainingArguments`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) class. If you want to modify that, make sure to create your own `TrainingArguments` object and pass it to the [`SFTTrainer`] constructor as it is done on the [`supervised_finetuning.py` script](https://github.com/huggingface/trl/blob/main/examples/stack_llama/scripts/supervised_finetuning.py) on the stack-llama example.
The above snippets will use the default training arguments from the [`SFTConfig`] class. If you want to modify the defaults, pass your modifications to the `SFTConfig` constructor and hand them to the trainer via the `args` argument.
## Advanced usage
@ -58,7 +64,7 @@ To instantiate that collator for instruction data, pass a response template and
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")
@ -78,6 +84,7 @@ collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenize
trainer = SFTTrainer(
model,
train_dataset=dataset,
args=SFTConfig(output_dir="/tmp"),
formatting_func=formatting_prompts_func,
data_collator=collator,
)
@ -90,7 +97,7 @@ To instantiate that collator for assistant style conversation data, pass a respo
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
@ -103,8 +110,11 @@ collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_temp
trainer = SFTTrainer(
model,
args=SFTConfig(
output_dir="/tmp",
dataset_text_field = "text",
),
train_dataset=dataset,
dataset_text_field="text",
data_collator=collator,
)
@ -115,7 +125,7 @@ Make sure to have a `pad_token_id` which is different from `eos_token_id` which
#### Using token_ids directly for `response_template`
Some tokenizers like Llama 2 (`meta-llama/Llama-2-XXb-hf`) tokenize sequences differently depending whether they have context or not. For example:
Some tokenizers like Llama 2 (`meta-llama/Llama-2-XXb-hf`) tokenize sequences differently depending on whether they have context or not. For example:
```python
from transformers import AutoTokenizer
@ -145,7 +155,7 @@ RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs
```
To solve this, you can tokenize the `response_template` with the same context than in the dataset, truncate it as needed and pass the `token_ids` directly to the `response_template` argument of the `DataCollatorForCompletionOnlyLM` class. For example:
To solve this, you can tokenize the `response_template` with the same context as in the dataset, truncate it as needed and pass the `token_ids` directly to the `response_template` argument of the `DataCollatorForCompletionOnlyLM` class. For example:
```python
response_template_with_context = "\n### Assistant:" # We added context here: "\n". This is enough for this tokenizer
@ -154,6 +164,73 @@ response_template_ids = tokenizer.encode(response_template_with_context, add_spe
data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
```
### Add Special Tokens for Chat Format
Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system, and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
The [`setup_chat_format`] function in `trl` easily sets up a model and tokenizer for conversational AI tasks. This function:
- Adds special tokens to the tokenizer, e.g. `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation.
- Resizes the model's embedding layer to accommodate the new tokens.
- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI.
- _Optionally_, you can pass `resize_to_multiple_of` to resize the embedding layer to a multiple of that value, e.g. 64. If you want to see more formats being supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl)
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format
# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
# Set up the chat format with default 'chatml' format
model, tokenizer = setup_chat_format(model, tokenizer)
```
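With the template set, you can inspect the resulting format directly. A minimal sketch, using a made-up two-turn conversation:
```python
# the `messages` below are a made-up example
messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi! How can I help you?"},
]
# render the conversation with the chatml template set by `setup_chat_format`
print(tokenizer.apply_chat_template(messages, tokenize=False))
```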
With our model and tokenizer set up, we can now fine-tune our model on a conversational dataset. Below is an example of how a dataset can be formatted for fine-tuning.
### Dataset format support
The [`SFTTrainer`] supports popular dataset formats. This allows you to pass the dataset directly to the trainer without any pre-processing. The following formats are supported:
* conversational format
```json
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "..."}]}
```
* instruction format
```json
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
```
If your dataset uses one of the above formats, you can directly pass it to the trainer without pre-processing. The [`SFTTrainer`] will then format the dataset for you using the defined format from the model's tokenizer with the [apply_chat_template](https://huggingface.co/docs/transformers/main/en/chat_templating#templates-for-chat-models) method.
```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
...
# load jsonl dataset
dataset = load_dataset("json", data_files="path/to/dataset.jsonl", split="train")
# load dataset from the HuggingFace Hub
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")
...
sft_config = SFTConfig(packing=True)
trainer = SFTTrainer(
"facebook/opt-350m",
args=sft_config,
train_dataset=dataset,
)
```
If the dataset is not in one of those formats, you can either preprocess the dataset to match the formatting or pass a formatting function to the SFTTrainer to do it for you. Let's have a look.
### Format your input prompts
For instruction fine-tuning, it is quite common to have two columns inside the dataset: one for the prompt & the other for the response.
@ -179,32 +256,34 @@ def formatting_prompts_func(example):
trainer = SFTTrainer(
model,
args=sft_config,
train_dataset=dataset,
formatting_func=formatting_prompts_func,
)
trainer.train()
```
To preperly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example on how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
To properly format your input, make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example of how to use SFTTrainer on the alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
### Packing dataset ([`ConstantLengthDataset`])
[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. This is done with the [`ConstantLengthDataset`] utility class that returns constant length chunks of tokens from a stream of examples. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTTrainer`] constructor.
[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. This is done with the [`ConstantLengthDataset`] utility class that returns constant length chunks of tokens from a stream of examples. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTConfig`] constructor.
```python
...
sft_config = SFTConfig(packing=True, dataset_text_field="text",)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
dataset_text_field="text",
packing=True
args=sft_config
)
trainer.train()
```
Note that if you use a packed dataset and pass `max_steps` in the training arguments, you will probably train your models for more than a few epochs, depending on how you have configured the packed dataset and the training protocol. Double check that you know and understand what you are doing.
If you don't want to pack your `eval_dataset`, you can pass `eval_packing=False` to the `SFTConfig` init method.
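For example, a minimal sketch:
```python
# pack the training set but leave the eval set unpacked
sft_config = SFTConfig(
    output_dir="/tmp",
    dataset_text_field="text",
    packing=True,
    eval_packing=False,
)
```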
#### Customize your prompts using packed dataset
@ -215,35 +294,37 @@ def formatting_func(example):
text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
return text
sft_config = SFTConfig(packing=True)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
packing=True,
args=sft_config,
formatting_func=formatting_func
)
trainer.train()
```
You can also customize the [`ConstantLengthDataset`] much more by directly passing the arguments to the [`SFTTrainer`] constructor. Please refer to that class' signature for more information.
You can also customize the [`ConstantLengthDataset`] much more by directly passing the arguments to the [`SFTConfig`] constructor. Please refer to that class' signature for more information.
### Control over the pretrained model
You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTTrainer`]. For example, if you want to load a model in a different precision, analogous to
You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTConfig`]. For example, if you want to load a model in a different precision, analogous to
```python
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16)
```
```python
...
sft_config = SFTConfig(
model_init_kwargs={
"torch_dtype": "bfloat16",
},
output_dir="/tmp",
)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
dataset_text_field="text",
model_init_kwargs={
"torch_dtype": torch.bfloat16,
},
args=sft_config,
)
trainer.train()
@ -252,11 +333,11 @@ Note that all keyword arguments of `from_pretrained()` are supported.
### Training adapters
We also support a tight integration with 🤗 PEFT library so that any user can conveniently train adapters and share them on the Hub instead of training the entire model
We also support tight integration with the 🤗 PEFT library so that any user can conveniently train adapters and share them on the Hub instead of training the entire model
```python
from datasets import load_dataset
from trl import SFTTrainer
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig
dataset = load_dataset("imdb", split="train")
@ -272,7 +353,7 @@ peft_config = LoraConfig(
trainer = SFTTrainer(
"EleutherAI/gpt-neo-125m",
train_dataset=dataset,
dataset_text_field="text",
args=SFTConfig(output_dir="/tmp"),
peft_config=peft_config
)
@ -283,7 +364,7 @@ You can also continue training your `PeftModel`. For that, first load a `PeftMod
### Training adapters with base 8 bit models
For that you need to first load your 8bit model outside the Trainer and pass a `PeftConfig` to the trainer. For example:
For that, you need to first load your 8 bit model outside the Trainer and pass a `PeftConfig` to the trainer. For example:
```python
...
@ -305,7 +386,7 @@ model = AutoModelForCausalLM.from_pretrained(
trainer = SFTTrainer(
model,
train_dataset=dataset,
dataset_text_field="text",
args=SFTConfig(),
peft_config=peft_config,
)
@ -346,11 +427,11 @@ Note that you cannot train your model using Flash Attention 1 on an arbitrary da
Below are some numbers you can get in terms of speedup and memory efficiency, using Flash Attention 1, on a single NVIDIA-T4 16GB.
| use_flash_attn_1 | model_name | max_seq_len | batch_size | time per training step |
|----------------|-------------------|-------------|------------|------------------------|
| x | facebook/opt-350m | 2048 | 8 | ~59.1s |
| | facebook/opt-350m | 2048 | 8 | **OOM** |
| x | facebook/opt-350m | 2048 | 4 | ~30.3s |
| | facebook/opt-350m | 2048 | 4 | ~148.9s |
| ---------------- | ----------------- | ----------- | ---------- | ---------------------- |
| x | facebook/opt-350m | 2048 | 8 | ~59.1s |
| | facebook/opt-350m | 2048 | 8 | **OOM** |
| x | facebook/opt-350m | 2048 | 4 | ~30.3s |
| | facebook/opt-350m | 2048 | 4 | ~148.9s |
### Using Flash Attention-2
@ -360,24 +441,60 @@ To use Flash Attention 2, first install the latest `flash-attn` package:
pip install -U flash-attn
```
And add `use_flash_attention_2=True` when calling `from_pretrained`:
And add `attn_implementation="flash_attention_2"` when calling `from_pretrained`:
```python
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_4bit=True,
use_flash_attention_2=True
attn_implementation="flash_attention_2"
)
```
If you don't use quantization, make sure your model is loaded in half-precision and dispatched on a supported GPU device.
After loading your model, you can either train it as is, or attach adapters and train only the adapters if your model is quantized.
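A minimal sketch of the non-quantized path (`model_id` is assumed to be defined, and the GPU to support bfloat16):
```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_id,  # assumed to be defined earlier
    torch_dtype=torch.bfloat16,  # half-precision instead of quantization
    attn_implementation="flash_attention_2",
    device_map="auto",  # dispatch to a supported GPU
)
```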
In contrary to Flash Attention 1, the integration makes it possible to train your model on an arbitrary dataset that also includes padding tokens.
In contrast to Flash Attention 1, the integration makes it possible to train your model on an arbitrary dataset that also includes padding tokens.
### Enhance model's performances using NEFTune
NEFTune is a technique to boost the performance of chat models and was introduced by the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://arxiv.org/abs/2310.05914) from Jain et al. it consists of adding noise to the embedding vectors during training. According to the abstract of the paper:
### Using model creation utility
We include a utility function to create your model.
[[autodoc]] ModelConfig
```python
import torch
from transformers import AutoModelForCausalLM
from trl import ModelConfig, SFTTrainer, get_kbit_device_map, get_peft_config, get_quantization_config

model_config = ModelConfig(
    model_name_or_path="facebook/opt-350m",
    attn_implementation=None,  # or "flash_attention_2"
)
torch_dtype = (
    model_config.torch_dtype
    if model_config.torch_dtype in ["auto", None]
    else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
    revision=model_config.model_revision,
    trust_remote_code=model_config.trust_remote_code,
    attn_implementation=model_config.attn_implementation,
    torch_dtype=torch_dtype,
    use_cache=False if training_args.gradient_checkpointing else True,  # `training_args` is your SFTConfig instance
    device_map=get_kbit_device_map() if quantization_config is not None else None,
    quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
trainer = SFTTrainer(
    ...,
    model=model,  # pass the instantiated model so `model_kwargs` take effect
    peft_config=get_peft_config(model_config),
)
```
### Enhance the model's performance using NEFTune
NEFTune is a technique to boost the performance of chat models and was introduced by the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://huggingface.co/papers/2310.05914) from Jain et al. It consists of adding noise to the embedding vectors during training. According to the abstract of the paper:
> Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF such as LLaMA-2-Chat benefit from additional training with NEFTune.
@ -385,20 +502,21 @@ NEFTune is a technique to boost the performance of chat models and was introduce
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/neft-screenshot.png">
</div>
To use it in `SFTTrainer` simply pass `neftune_noise_alpha` when creating your `SFTTrainer` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to retrieve back the original behaviour of the embedding layer.
To use it in `SFTTrainer` simply pass `neftune_noise_alpha` when creating your `SFTConfig` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to restore the original behaviour of the embedding layer.
```python
from datasets import load_dataset
from trl import SFTTrainer
from trl import SFTConfig, SFTTrainer
dataset = load_dataset("imdb", split="train")
sft_config = SFTConfig(
neftune_noise_alpha=5,
)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=512,
neftune_noise_alpha=5,
args=sft_config,
)
trainer.train()
```
@ -413,53 +531,64 @@ Note however, that the amount of performance gain is _dataset dependent_ and in
### Accelerate fine-tuning 2x using `unsloth`
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) and even full-finetuning (1.1x faster) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama as well) and Mistral architectures.
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth#installation-instructions---conda). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLlamaModel` or `FastMistralModel` as follows:
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek, etc.) and Mistral architectures. Some benchmarks on 1x A100 are listed below:
| 1 A100 40GB | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
| --------------- | --------- | --- | --------------------- | --------- | ------------ |
| Code Llama 34b | Slim Orca | 1x | 1.01x | **1.94x** | -22.7% |
| Llama-2 7b | Slim Orca | 1x | 0.96x | **1.87x** | -39.3% |
| Mistral 7b | Slim Orca | 1x | 1.17x | **1.88x** | -65.9% |
| Tiny Llama 1.1b | Alpaca | 1x | 1.55x | **2.74x** | -57.8% |
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
```python
import torch
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel
from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import FastLlamaModel, FastMistralModel
max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number
max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number.
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
# Load Llama model
model, tokenizer = FastLlamaModel.from_pretrained(
model_name = "unsloth/llama-2-7b", # Supports any llama model eg meta-llama/Llama-2-7b-hf
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/mistral-7b",
max_seq_length=max_seq_length,
dtype=None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit=True, # Use 4bit quantization to reduce memory usage. Can be False
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
# Do model patching and add fast LoRA weights
model = FastLlamaModel.get_peft_model(
model = FastLanguageModel.get_peft_model(
model,
r = 16,
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",],
lora_alpha = 16,
lora_dropout = 0, # Currently only supports dropout = 0
bias = "none", # Currently only supports bias = "none"
use_gradient_checkpointing = True,
random_state = 3407,
max_seq_length = max_seq_length,
r=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
lora_alpha=16,
lora_dropout=0, # Dropout = 0 is currently optimized
bias="none", # Bias = "none" is currently optimized
use_gradient_checkpointing=True,
random_state=3407,
)
args = TrainingArguments(output_dir="./output")
args = SFTConfig(
output_dir="./output",
max_seq_length=max_seq_length,
dataset_text_field="text",
)
trainer = SFTTrainer(
model = model,
args = args,
train_dataset = dataset,
dataset_text_field = "text",
max_seq_length = max_seq_length,
model=model,
args=args,
train_dataset=dataset,
)
trainer.train()
```
@ -469,19 +598,155 @@ The saved model is fully compatible with Hugging Face's transformers library. Le
Pay attention to the following best practices when training a model with that trainer:
- [`SFTTrainer`] always pads by default the sequences to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training.
- [`SFTTrainer`] always pads the sequences by default to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training.
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
- If you create a model outside the trainer, make sure not to pass to the trainer any additional keyword arguments that are related to the `from_pretrained()` method.
## Multi-GPU Training
Trainer (and thus SFTTrainer) supports multi-GPU training. If you run your script with `python script.py` it will default to using DP as the strategy, which may be [slower than expected](https://github.com/huggingface/trl/issues/1303). To use DDP (which is generally recommended, see [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many?select-gpu=Accelerate#data-parallelism) for more info) you must launch the script with `python -m torch.distributed.launch script.py` or `accelerate launch script.py`. For DDP to work you must also check the following:
- If you're using gradient_checkpointing, add the following to the TrainingArguments: `gradient_checkpointing_kwargs={'use_reentrant': False}` (more info [here](https://github.com/huggingface/transformers/issues/26969)); see the sketch after this list
- Ensure that the model is placed on the correct device:
```python
from accelerate import PartialState
device_string = PartialState().process_index
model = AutoModelForCausalLM.from_pretrained(
...
device_map={'':device_string}
)
```
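Putting both points together, a minimal sketch (the model name and config values are assumptions, not recommendations):
```python
from accelerate import PartialState
from transformers import AutoModelForCausalLM
from trl import SFTConfig

# place the model on the device that belongs to this DDP process
device_string = PartialState().process_index
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # assumed model, swap in your own
    device_map={"": device_string},
)

# DDP-friendly gradient checkpointing settings
sft_config = SFTConfig(
    output_dir="/tmp",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```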
## GPTQ Conversion
You may experience some issues with GPTQ Quantization after completing training. Lowering `gradient_accumulation_steps` to `4` will resolve most issues during the quantization process to GPTQ format.
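For instance, a minimal sketch of the corresponding setting (the other arguments are assumptions):
```python
from trl import SFTConfig

# keep gradient accumulation low if you plan to quantize to GPTQ after training
sft_config = SFTConfig(output_dir="/tmp", gradient_accumulation_steps=4)
```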
## Extending `SFTTrainer` for Vision Language Models
`SFTTrainer` does not inherently support vision-language data. However, we provide a guide on how to tweak the trainer to support vision-language data. Specifically, you need to use a custom data collator that is compatible with vision-language data. This guide outlines the steps to make these adjustments. For a concrete example, refer to the script [`examples/scripts/vsft_llava.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/vsft_llava.py) which demonstrates how to fine-tune the LLaVA 1.5 model on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset.
### Preparing the Data
The data format is flexible, provided it is compatible with the custom collator that we will define later. A common approach is to use conversational data. Given that the data includes both text and images, the format needs to be adjusted accordingly. Below is an example of a conversational data format involving both text and images:
```python
images = ["obama.png"]
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Who is this?"},
{"type": "image"}
]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Barack Obama"}
]
},
{
"role": "user",
"content": [
{"type": "text", "text": "What is he famous for?"}
]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "He is the 44th President of the United States."}
]
}
]
```
To illustrate how this data format will be processed using the LLaVA model, you can use the following code:
```python
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
print(processor.apply_chat_template(messages, tokenize=False))
```
The output will be formatted as follows:
```txt
Who is this? ASSISTANT: Barack Obama USER: What is he famous for? ASSISTANT: He is the 44th President of the United States.
```
<iframe src="https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft/embed/viewer/default/train" frameborder="0" width="100%" height="560px"></iframe>
### A custom collator for processing multi-modal data
Unlike the default behavior of `SFTTrainer`, processing multi-modal data is done on the fly during the data collation process. To do this, you need to define a custom collator that processes both the text and images. This collator must take a list of examples as input (see the previous section for an example of the data format) and return a batch of processed data. Below is an example of such a collator:
```python
def collate_fn(examples):
# Get the texts and images, and apply the chat template
texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
images = [example["images"][0] for example in examples]
# Tokenize the texts and process the images
batch = processor(texts, images, return_tensors="pt", padding=True)
# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100
batch["labels"] = labels
return batch
```
We can verify that the collator works as expected by running the following code:
```python
from datasets import load_dataset
dataset = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train")
examples = [dataset[0], dataset[1]] # Just two examples for the sake of the example
collated_data = collate_fn(examples)
print(collated_data.keys()) # dict_keys(['input_ids', 'attention_mask', 'pixel_values', 'labels'])
```
### Training the vision-language model
Now that we have prepared the data and defined the collator, we can proceed with training the model. To ensure that the data is not processed as text-only, we need to set a couple of arguments in the `SFTConfig`, specifically `dataset_text_field` and `remove_unused_columns`. We also need to set `skip_prepare_dataset` to `True` to avoid the default processing of the dataset. Below is an example of how to set up the `SFTTrainer`.
```python
args.dataset_text_field = "" # needs a dummy field
args.remove_unused_columns = False
args.dataset_kwargs = {"skip_prepare_dataset": True}
trainer = SFTTrainer(
model=model,
args=args,
data_collator=collate_fn,
train_dataset=train_dataset,
tokenizer=processor.tokenizer,
)
```
A full example of training LLaVa 1.5 on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset can be found in the script [`examples/scripts/vsft_llava.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/vsft_llava.py).
- [Experiment tracking](https://wandb.ai/huggingface/trl/runs/2b2c5l7s)
- [Trained model](https://huggingface.co/HuggingFaceH4/sft-llava-1.5-7b-hf)
## SFTTrainer
[[autodoc]] SFTTrainer
## ConstantLengthDataset
## SFTConfig
[[autodoc]] SFTConfig
## Datasets
The SFTTrainer supports `datasets.IterableDataset` in addition to map-style datasets. This is useful if you are using large corpora that you do not want to save entirely to disk. The data will be tokenized and processed on the fly, even when packing is enabled.
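A minimal sketch, assuming the `imdb` dataset used earlier in this page:
```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# streaming=True returns a datasets.IterableDataset; nothing is saved to disk
dataset = load_dataset("imdb", split="train", streaming=True)

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    args=SFTConfig(
        output_dir="/tmp",
        dataset_text_field="text",
        max_steps=500,  # an iterable dataset has no length, so set an explicit step budget
        packing=True,
    ),
)
```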
Additionally, in the SFTTrainer, we support pre-tokenized datasets if they are `datasets.Dataset` or `datasets.IterableDataset`. In other words, if such a dataset has a column of `input_ids`, no further processing (tokenization or packing) will be done, and the dataset will be used as-is. This can be useful if you have pretokenized your dataset outside of this script and want to re-use it directly.
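A minimal sketch of the pre-tokenized path (the token ids below are placeholders):
```python
from datasets import Dataset
from trl import SFTConfig, SFTTrainer

# a dataset with an `input_ids` column is used as-is: no tokenization or packing is applied
pretokenized = Dataset.from_dict({"input_ids": [[101, 2023, 2003, 102], [101, 2178, 2742, 102]]})

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=pretokenized,
    args=SFTConfig(output_dir="/tmp"),
)
```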
### ConstantLengthDataset
[[autodoc]] trainer.ConstantLengthDataset
View File
@ -1,9 +1,50 @@
# Trainer
At TRL we support PPO (Proximal Policy Optimisation) with an implementation that largely follows the structure introduced in the paper "Fine-Tuning Language Models from Human Preferences" by D. Ziegler et al. [[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].
At TRL we support PPO (Proximal Policy Optimisation) with an implementation that largely follows the structure introduced in the paper "Fine-Tuning Language Models from Human Preferences" by D. Ziegler et al. [[paper](https://huggingface.co/papers/1909.08593), [code](https://github.com/openai/lm-human-preferences)].
The Trainer and model classes are largely inspired by the `transformers.Trainer` and `transformers.AutoModel` classes and adapted for RL.
We also support a `RewardTrainer` that can be used to train a reward model.
## CPOConfig
[[autodoc]] CPOConfig
## CPOTrainer
[[autodoc]] CPOTrainer
## DDPOConfig
[[autodoc]] DDPOConfig
## DDPOTrainer
[[autodoc]] DDPOTrainer
## DPOTrainer
[[autodoc]] DPOTrainer
## IterativeSFTTrainer
[[autodoc]] IterativeSFTTrainer
## KTOConfig
[[autodoc]] KTOConfig
## KTOTrainer
[[autodoc]] KTOTrainer
## ORPOConfig
[[autodoc]] ORPOConfig
## ORPOTrainer
[[autodoc]] ORPOTrainer
## PPOConfig
[[autodoc]] PPOConfig
@ -24,22 +65,6 @@ We also support a `RewardTrainer` that can be used to train a reward model.
[[autodoc]] SFTTrainer
## DPOTrainer
[[autodoc]] DPOTrainer
## DDPOConfig
[[autodoc]] DDPOConfig
## DDPOTrainer
[[autodoc]] DDPOTrainer
## IterativeSFTTrainer
[[autodoc]] IterativeSFTTrainer
## set_seed
[[autodoc]] set_seed
View File
@ -2,7 +2,6 @@ compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
gradient_accumulation_steps: 1
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
View File
@ -2,7 +2,6 @@ compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
gradient_accumulation_steps: 1
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: true
@ -12,7 +11,7 @@ distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
View File
@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: false
fsdp_offload_params: true
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
View File
@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: "NO"
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
View File
@ -0,0 +1,20 @@
# This is an example configuration file of TRL CLI, you can use it for
# SFT like that: `trl sft --config config.yaml --output_dir test-sft`
# The YAML file supports environment variables by adding an `env` field
# as below
# env:
# CUDA_VISIBLE_DEVICES: 0
model_name_or_path: trl-internal-testing/tiny-random-LlamaForCausalLM
dataset_name: imdb
dataset_text_field: text
report_to: none
learning_rate: 0.0001
lr_scheduler_type: cosine
View File
@ -0,0 +1,122 @@
import sys
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from huggingface_hub import HfApi
from huggingface_hub.repocard import RepoCard
from transformers import HfArgumentParser
"""
# debug
python -i examples/datasets/anthropic_hh.py --debug --push_to_hub
# actual push
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity trl-internal-testing
"""
api = HfApi()
@dataclass
class ScriptArguments:
debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
hf_entity: Optional[str] = field(default=None, metadata={"help": "The Hugging Face entity to use"})
hf_repo_id: Optional[str] = field(
default="hh-rlhf-helpful-base-trl-style", metadata={"help": "The Hugging Face repository ID"}
)
revision: Optional[str] = field(default="0.1.0", metadata={"help": "The revision of the repository"})
update_main_revision: Optional[bool] = field(
default=True, metadata={"help": "Update the main revision of the repository"}
)
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the dataset to the Hugging Face Hub"})
dataset_num_proc: Optional[int] = field(
default=None, metadata={"help": "The number of workers to use for dataset processing"}
)
# GPT-4 generated 😄 Define a function to process the input and extract the dialogue into structured format
def extract_dialogue(input_text):
# Split the input by lines and initialize variables
lines = input_text.strip().split("\n\n")
dialogue_list = []
# Iterate through each line and extract the dialogue
for line in lines:
# Check if the line starts with "Human" or "Assistant" and split accordingly
if line.startswith("Human:"):
role = "user"
content = line.replace("Human: ", "").strip()
elif line.startswith("Assistant:"):
role = "assistant"
content = line.replace("Assistant: ", "").strip()
else:
# If the line doesn't start with "Human" or "Assistant", it's part of the previous message's content
# Append it to the last message's content
dialogue_list[-1]["content"] += "\n\n" + line.strip()
continue
# Append the extracted dialogue piece to the list
dialogue_list.append({"role": role, "content": content})
return dialogue_list
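# Example (hypothetical input):
#   extract_dialogue("Human: Hi\n\nAssistant: Hello there")
#   -> [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello there"}]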
if __name__ == "__main__":
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
if args.hf_entity is None:
args.hf_entity = api.whoami()["name"]
full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
ds = load_dataset("Anthropic/hh-rlhf", data_dir="helpful-base")
if args.debug:
for key in ds:
ds[key] = ds[key].select(range(50))
def process(row):
row["chosen"] = extract_dialogue(row["chosen"])
row["rejected"] = extract_dialogue(row["rejected"])
row["prompt"] = row["chosen"][0]["content"]
return row
ds = ds.map(process, num_proc=args.dataset_num_proc)
if args.push_to_hub:
revisions = ["main"] if args.update_main_revision else []
revisions.append(args.revision)
# get the command used to run the script
run_command = " ".join(["python"] + sys.argv)
for revision in revisions:
ds.push_to_hub(full_repo_id, revision=revision)
repo_full_url = f"https://huggingface.co/datasets/{full_repo_id}/tree/{revision}"
# get the name of the current file
file_name = __file__.split("/")[-1]
api.upload_file(
path_or_fileobj=__file__,
path_in_repo=file_name,
revision=revision,
repo_id=full_repo_id,
repo_type="dataset",
)
sft_card = RepoCard.load(
full_repo_id,
repo_type="dataset",
)
sft_card.text = f"""\
# TRL's Anthropic HH Dataset
We preprocess the dataset using our standard `prompt, chosen, rejected` format.
## Reproduce this dataset
1. Download the `{file_name}` from the {repo_full_url}.
2. Run `{run_command}`
"""
sft_card.push_to_hub(
full_repo_id,
repo_type="dataset",
)
View File
@ -0,0 +1,188 @@
import sys
from dataclasses import dataclass, field
from typing import Optional
from datasets import Dataset, DatasetDict
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.repocard import RepoCard
from transformers import AutoTokenizer, HfArgumentParser
"""
# debug
python -i examples/datasets/sentiment_descriptiveness.py --push_to_hub
# actual push
python examples/datasets/sentiment_descriptiveness.py \
--hf_repo_id sentiment-trl-style \
--task sentiment \
--push_to_hub \
--hf_entity trl-internal-testing
python examples/datasets/sentiment_descriptiveness.py \
--hf_repo_id descriptiveness-trl-style \
--task descriptiveness \
--push_to_hub \
--hf_entity trl-internal-testing
"""
api = HfApi()
@dataclass
class ScriptArguments:
debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
hf_entity: Optional[str] = field(default=None, metadata={"help": "The Hugging Face entity to use"})
hf_repo_id: Optional[str] = field(
default="sentiment-trl-style", metadata={"help": "The Hugging Face repository ID"}
)
revision: Optional[str] = field(default="0.1.0", metadata={"help": "The revision of the repository"})
update_main_revision: Optional[bool] = field(
default=True, metadata={"help": "Update the main revision of the repository"}
)
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the dataset to the Hugging Face Hub"})
task: str = field(default="sentiment", metadata={"help": "The task of the dataset"})
dataset_num_proc: Optional[int] = field(
default=None, metadata={"help": "The number of workers to use to tokenize the data"}
)
task_to_filename = {
"sentiment": "sentiment/offline_5k.json",
"descriptiveness": "descriptiveness/offline_5k.json",
}
def deduplicate_query(ds):
    # keep only the first occurrence of each query so every prompt is unique
    query = set()
    ranges = []
    for i in range(len(ds)):
        query_str = str(ds[i]["query"])
        if query_str not in query:
            query.add(query_str)
            ranges.append(i)
    return ds.select(ranges)
if __name__ == "__main__":
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
if args.hf_entity is None:
args.hf_entity = api.whoami()["name"]
full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
model_name = "gpt2"
dataset_tokenizer = AutoTokenizer.from_pretrained(model_name)  # the tokenizer the dataset was generated with
################
# Dataset
################
json = hf_hub_download(
repo_id="vwxyzjn/lm-human-preferences",
repo_type="dataset",
filename=task_to_filename[args.task],
)
MAGIC_TRAIN_NUMBER = 4992 # taken from https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/launch.py#L70
individual_ds = Dataset.from_json(json)
individual_ds = deduplicate_query(individual_ds)
ds = DatasetDict(
{
"train": individual_ds.select(range(MAGIC_TRAIN_NUMBER)),
"test": individual_ds.select(range(MAGIC_TRAIN_NUMBER, len(individual_ds))),
}
)
MAX_DEBUG_SAMPLES = 50
if args.debug:
for key in ds:
ds[key] = ds[key].select(range(min(MAX_DEBUG_SAMPLES, len(ds[key]))))
# columns are `['sample2', 'sample3', 'sample0', 'query', 'sample1', 'best']`
NUM_SAMPLES = 4
# edge cases handling: remove the cases where all samples are the same
def filter(row):
best_idx = row["best"]
chosen_sample = row[f"sample{best_idx}"]
if all(chosen_sample == row[f"sample{j}"] for j in range(NUM_SAMPLES)):
return False
else:
return True
print("=== Before filtering ===", ds)
ds = ds.filter(filter, num_proc=args.dataset_num_proc)
print("=== After filtering ===", ds)
# here we simply take the preferred sample as the chosen one and the first non-preferred sample as the rejected one
def process(row):
for j in range(NUM_SAMPLES):
row[f"sample{j}"] = dataset_tokenizer.batch_decode(row[f"sample{j}"])
row["prompt"] = dataset_tokenizer.batch_decode(row["query"])
row["prompt"] = [item.strip() for item in row["prompt"]]
row["chosen"] = []
row["rejected"] = []
for i in range(len(row["best"])):
best_idx = row["best"][i]
chosen_sample = row[f"sample{best_idx}"][i].strip()
row["chosen"].append(
[
{"role": "user", "content": row["prompt"][i].strip()},
{"role": "assistant", "content": chosen_sample},
]
)
# find the first rejected sample which is different from the chosen one
rejected_idx = -1
for k in range(NUM_SAMPLES):
if k != best_idx and row[f"sample{k}"][i].strip() != chosen_sample:
rejected_idx = k
break
rejected_sample = row[f"sample{rejected_idx}"][i].strip()
assert rejected_idx != -1, "No rejected sample found! This should not happen!"
row["rejected"].append(
[
{"role": "user", "content": row["prompt"][i].strip()},
{"role": "assistant", "content": rejected_sample},
]
)
assert chosen_sample != rejected_sample
return row
ds = ds.map(process, batched=True, num_proc=args.dataset_num_proc)
for key in ds: # reorder columns
ds[key] = ds[key].select_columns(["prompt", "chosen", "rejected"])
if args.push_to_hub:
revisions = ["main"] if args.update_main_revision else []
revisions.append(args.revision)
# get the command used to run the script
run_command = " ".join(["python"] + sys.argv)
for revision in revisions:
ds.push_to_hub(full_repo_id, revision=revision)
repo_full_url = f"https://huggingface.co/datasets/{full_repo_id}/tree/{revision}"
# get the name of the current file
file_name = __file__.split("/")[-1]
api.upload_file(
path_or_fileobj=__file__,
path_in_repo=file_name,
revision=revision,
repo_id=full_repo_id,
repo_type="dataset",
)
sft_card = RepoCard.load(
full_repo_id,
repo_type="dataset",
)
sft_card.text = f"""\
# TRL's Preference Dataset: {args.task}
The dataset comes from https://huggingface.co/papers/1909.08593, one of the earliest RLHF work from OpenAI.
We preprocess the dataset using our standard `prompt, chosen, rejected` format.
## Reproduce this dataset
1. Download the `{file_name}` from the {repo_full_url}.
2. Run `{run_command}`
"""
sft_card.push_to_hub(
full_repo_id,
repo_type="dataset",
)
View File
@ -0,0 +1,185 @@
import sys
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from huggingface_hub import HfApi
from huggingface_hub.repocard import RepoCard
from transformers import HfArgumentParser
"""
# debug
python -i examples/datasets/tldr_preference.py --debug --push_to_hub
# actual push
python examples/datasets/tldr_preference.py --push_to_hub --hf_entity trl-internal-testing
"""
api = HfApi()
@dataclass
class ScriptArguments:
debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
hf_entity: Optional[str] = field(default=None, metadata={"help": "The Hugging Face entity to use"})
hf_repo_id: Optional[str] = field(
default="tldr-preference-trl-style", metadata={"help": "The Hugging Face repository ID"}
)
sft_hf_repo_id: Optional[str] = field(
default="tldr-preference-sft-trl-style", metadata={"help": "The Hugging Face repository ID"}
)
revision: Optional[str] = field(default="0.1.0", metadata={"help": "The revision of the repository"})
update_main_revision: Optional[bool] = field(
default=True, metadata={"help": "Update the main revision of the repository"}
)
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the dataset to the Hugging Face Hub"})
dataset_num_proc: Optional[int] = field(
default=None, metadata={"help": "The number of workers to use to tokenize the data"}
)
if __name__ == "__main__":
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
if args.hf_entity is None:
args.hf_entity = api.whoami()["name"]
full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
full_sft_repo_id = f"{args.hf_entity}/{args.sft_hf_repo_id}"
################
# Preference dataset
################
ds = load_dataset("openai/summarize_from_feedback", "comparisons")
if args.debug:
for key in ds:
ds[key] = ds[key].select(range(50))
cnndm_batches = ["batch0_cnndm", "cnndm0", "cnndm2"]
if not args.debug:
ds["validation_cnndm"] = ds["validation"].filter(
lambda x: x["batch"] in cnndm_batches, num_proc=args.dataset_num_proc
)
ds["validation"] = ds["validation"].filter(
lambda x: x["batch"] not in cnndm_batches, num_proc=args.dataset_num_proc
)
tldr_format_str = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:"
cnndm_format_str = "Article:\n{article}\n\nTL;DR:"
def process(row):
format_str = cnndm_format_str if row["batch"] in cnndm_batches else tldr_format_str
row["prompt"] = format_str.format(**row["info"])
choice = row["choice"]
# need to remove the leading space
chosen = row["summaries"][choice]["text"].strip()
rejected = row["summaries"][1 - choice]["text"].strip()
row["chosen"] = [{"role": "user", "content": row["prompt"]}, {"role": "assistant", "content": chosen}]
row["rejected"] = [{"role": "user", "content": row["prompt"]}, {"role": "assistant", "content": rejected}]
return row
ds = ds.map(process, num_proc=args.dataset_num_proc)
for key in ds: # reorder columns
ds[key] = ds[key].select_columns(
["prompt", "chosen", "rejected", "info", "summaries", "choice", "worker", "batch", "split", "extra"]
)
if args.push_to_hub:
revisions = ["main"] if args.update_main_revision else []
revisions.append(args.revision)
# get the command used to run the script
run_command = " ".join(["python"] + sys.argv)
for revision in revisions:
ds.push_to_hub(full_repo_id, revision=revision)
repo_full_url = f"https://huggingface.co/datasets/{full_repo_id}/tree/{revision}"
# get the name of the current file
file_name = __file__.split("/")[-1]
api.upload_file(
path_or_fileobj=__file__,
path_in_repo=file_name,
revision=revision,
repo_id=full_repo_id,
repo_type="dataset",
)
preference_card = RepoCard.load(
full_repo_id,
repo_type="dataset",
)
preference_card.text = f"""\
# TRL's TL;DR Preference Dataset
We preprocess the dataset using our standard `prompt, chosen, rejected` format.
## Source of the dataset
We take the dataset from https://huggingface.co/datasets/openai/summarize_from_feedback.
## Reproduce this dataset
1. Download the `{file_name}` from the {repo_full_url}.
2. Run `{run_command}`
"""
preference_card.push_to_hub(
full_repo_id,
repo_type="dataset",
)
################
# SFT dataset
################
sft_ds = load_dataset("vwxyzjn/summarize_from_feedback_tldr_3_filtered")
if args.debug:
for key in sft_ds:
sft_ds[key] = sft_ds[key].select(range(50))
def sft_process(row):
row["prompt"] = tldr_format_str.format(**row)
row["messages"] = [
{"role": "user", "content": row["prompt"]},
{"role": "assistant", "content": row["summary"]},
]
return row
sft_ds = sft_ds.map(sft_process, num_proc=args.dataset_num_proc)
for key in sft_ds: # reorder columns
sft_ds[key] = sft_ds[key].select_columns(["prompt", "messages", "id", "subreddit", "title", "post", "summary"])
if args.push_to_hub:
revisions = ["main"] if args.update_main_revision else []
revisions.append(args.revision)
# get the command used to run the script
run_command = " ".join(["python"] + sys.argv)
for revision in revisions:
sft_ds.push_to_hub(full_sft_repo_id, revision=revision)
repo_full_url = f"https://huggingface.co/datasets/{full_sft_repo_id}/tree/{revision}"
# get the name of the current file
file_name = __file__.split("/")[-1]
api.upload_file(
path_or_fileobj=__file__,
path_in_repo=file_name,
revision=revision,
repo_id=full_sft_repo_id,
repo_type="dataset",
)
sft_card = RepoCard.load(
full_sft_repo_id,
repo_type="dataset",
)
sft_card.text = f"""\
# TRL's TL;DR SFT Dataset
We preprocess the dataset using our standard `prompt, messages` format.
## Source of the dataset
We take the dataset from https://huggingface.co/datasets/vwxyzjn/summarize_from_feedback_tldr_3_filtered.
## Reproduce this dataset
1. Download the `{file_name}` from the {repo_full_url}.
2. Run `{run_command}`
"""
View File
@ -0,0 +1,42 @@
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from transformers import AutoTokenizer, HfArgumentParser
"""
python -i examples/datasets/tokenize_ds.py --debug --model HuggingFaceH4/zephyr-7b-beta
python -i examples/datasets/tokenize_ds.py --debug --model gpt2
"""
@dataclass
class ScriptArguments:
debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
dataset: str = field(
default="trl-internal-testing/hh-rlhf-helpful-base-trl-style", metadata={"help": "The dataset to load"}
)
model: str = field(default="gpt2", metadata={"help": "The model to use for tokenization"})
dataset_num_proc: Optional[int] = field(
default=None, metadata={"help": "The number of workers to use to tokenize the data"}
)
if __name__ == "__main__":
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
ds = load_dataset(args.dataset)
if args.debug:
for key in ds:
ds[key] = ds[key].select(range(50))
tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.chat_template is None:
tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{{ eos_token }}"
def process(row):
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
return row
ds = ds.map(process, num_proc=args.dataset_num_proc)
print(ds["train"][0]["chosen"])
View File
@ -7,14 +7,14 @@ from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
# 2. initialize trainer
ppo_config = {"batch_size": 1}
ppo_config = {"mini_batch_size": 1, "batch_size": 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)
# 3. encode a query
query_txt = "This morning I went to the "
@ -29,7 +29,7 @@ generation_kwargs = {
"pad_token_id": tokenizer.eos_token_id,
"max_new_tokens": 20,
}
response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs)
response_tensor = ppo_trainer.generate(list(query_tensor), return_prompt=False, **generation_kwargs)
response_txt = tokenizer.decode(response_tensor[0])
# 5. define a reward for response
View File
@ -121,7 +121,7 @@
"metadata": {},
"source": [
"You can see that we load a GPT2 model called `gpt2_imdb`. This model was additionally fine-tuned on the IMDB dataset for 1 epoch with the huggingface [script](https://github.com/huggingface/transformers/blob/master/examples/run_language_modeling.py) (no special settings). The other parameters are mostly taken from the original paper [\"Fine-Tuning Language Models from Human Preferences\"](\n",
"https://arxiv.org/pdf/1909.08593.pdf). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models."
"https://huggingface.co/papers/1909.08593). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models."
]
},
{
@ -152,7 +152,7 @@
"outputs": [],
"source": [
"gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)\n",
"gpt2_model_ref = create_reference_model(gpt2_model)\n",
"gpt2_ref_model = create_reference_model(gpt2_model)\n",
"gpt2_tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n",
"\n",
"gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token"
@ -353,7 +353,7 @@
}
],
"source": [
"ppo_trainer = PPOTrainer(config, gpt2_model, gpt2_model_ref, gpt2_tokenizer, dataset, data_collator=collator)"
"ppo_trainer = PPOTrainer(config, gpt2_model, gpt2_ref_model, gpt2_tokenizer, dataset, data_collator=collator)"
]
},
{

View File

@ -92,7 +92,7 @@
" log_with=\"wandb\",\n",
")\n",
"\n",
"sent_kwargs = {\"return_all_scores\": True, \"function_to_apply\": \"none\", \"batch_size\": 16}"
"sent_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\", \"batch_size\": 16}"
]
},
{
@ -110,8 +110,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"You can see that we load a GPT2 model called `gpt2_imdb`. This model was additionally fine-tuned on the IMDB dataset for 1 epoch with the huggingface [script](https://github.com/huggingface/transformers/blob/master/examples/run_language_modeling.py) (no special settings). The other parameters are mostly taken from the original paper [\"Fine-Tuning Language Models from Human Preferences\"](\n",
"https://arxiv.org/pdf/1909.08593.pdf). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models."
"You can see that we load a GPT2 model called `gpt2_imdb`. This model was additionally fine-tuned on the IMDB dataset for 1 epoch with the huggingface [script](https://github.com/huggingface/transformers/blob/main/examples/legacy/run_language_modeling.py) (no special settings). The other parameters are mostly taken from the original paper [\"Fine-Tuning Language Models from Human Preferences\"](\n",
"https://huggingface.co/papers/1909.08593). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models."
]
},
{
@ -134,16 +134,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Reusing dataset imdb (/home/leandro/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)\n",
"Loading cached processed dataset at /home/leandro/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-ff455473e884c6a3.arrow\n"
]
}
],
"outputs": [],
"source": [
"def build_dataset(config, dataset_name=\"imdb\", input_min_text_length=2, input_max_text_length=8):\n",
" \"\"\"\n",
@ -270,8 +261,8 @@
{
"data": {
"text/plain": [
"[[{'label': 'NEGATIVE', 'score': 2.335048198699951},\n",
" {'label': 'POSITIVE', 'score': -2.726576566696167}]]"
"[{'label': 'NEGATIVE', 'score': 2.335048198699951},\n",
" {'label': 'POSITIVE', 'score': -2.726576328277588}]"
]
},
"execution_count": null,
@ -292,8 +283,8 @@
{
"data": {
"text/plain": [
"[[{'label': 'NEGATIVE', 'score': -2.2947897911071777},\n",
" {'label': 'POSITIVE', 'score': 2.557039737701416}]]"
"[{'label': 'POSITIVE', 'score': 2.557040214538574},\n",
" {'label': 'NEGATIVE', 'score': -2.294790267944336}]"
]
},
"execution_count": null,
@ -371,7 +362,7 @@
"}\n",
"\n",
"\n",
"for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):\n",
"for epoch, batch in enumerate(tqdm(ppo_trainer.dataloader)):\n",
" query_tensors = batch[\"input_ids\"]\n",
"\n",
" #### Get response from gpt2\n",
@ -379,14 +370,16 @@
" for query in query_tensors:\n",
" gen_len = output_length_sampler()\n",
" generation_kwargs[\"max_new_tokens\"] = gen_len\n",
" response = ppo_trainer.generate(query, **generation_kwargs)\n",
" response_tensors.append(response.squeeze()[-gen_len:])\n",
" query_response = ppo_trainer.generate(query, **generation_kwargs).squeeze()\n",
" response_len = len(query_response) - len(query)\n",
" response_tensors.append(query_response[-response_len:])\n",
" batch[\"response\"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]\n",
"\n",
" #### Compute sentiment score\n",
" texts = [q + r for q, r in zip(batch[\"query\"], batch[\"response\"])]\n",
" pipe_outputs = sentiment_pipe(texts, **sent_kwargs)\n",
" rewards = [torch.tensor(output[1][\"score\"]) for output in pipe_outputs]\n",
" positive_scores = [item[\"score\"] for output in pipe_outputs for item in output if item[\"label\"] == \"POSITIVE\"]\n",
" rewards = [torch.tensor(score) for score in positive_scores]\n",
"\n",
" #### Run PPO step\n",
" stats = ppo_trainer.step(query_tensors, response_tensors, rewards)\n",
@ -398,7 +391,7 @@
"metadata": {},
"source": [
"### Training progress\n",
"If you are tracking the training progress with Weights&Biases you should see a plot similar to the one below. Check out the interactive sample report on wandb.ai: [link](https://app.wandb.ai/huggingface/trl-showcase/runs/1jtvxb1m/).\n",
"If you are tracking the training progress with Weights&Biases you should see a plot similar to the one below. Check out the interactive sample report on wandb.ai: [link](https://wandb.ai/huggingface/trl/runs/w9l3110g).\n",
"\n",
"<div style=\"text-align: center\">\n",
"<img src='https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/gpt2_tuning_progress.png' width='800'>\n",
@ -416,7 +409,7 @@
"metadata": {},
"source": [
"## Model inspection\n",
"Let's inspect some examples from the IMDB dataset. We can use `model_ref` to compare the tuned model `model` against the model before optimisation."
"Let's inspect some examples from the IMDB dataset. We can use `ref_model` to compare the tuned model `model` against the model before optimisation."
]
},
{
@ -424,14 +417,6 @@
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/leandro/miniconda3/envs/trl/lib/python3.9/site-packages/transformers/pipelines/base.py:1075: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/html": [
@ -463,131 +448,131 @@
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Oh dear,</td>\n",
" <td>what are I saying?! I fast-forwarded through</td>\n",
" <td>I must say that I are hanging my head on this</td>\n",
" <td>-0.858954</td>\n",
" <td>-1.007609</td>\n",
" <td>I rented Zero Day</td>\n",
" <td>4 for my sister. To my surprise, the Wii caug...</td>\n",
" <td>. It is a pleasure. It is a huge leap 68 years...</td>\n",
" <td>1.736068</td>\n",
" <td>2.423731</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>I've seen</td>\n",
" <td>it, as well.&lt;br</td>\n",
" <td>three million dialogue throughout, and</td>\n",
" <td>1.996807</td>\n",
" <td>2.240883</td>\n",
" <td>The only</td>\n",
" <td>distro of her</td>\n",
" <td>special compliments is the</td>\n",
" <td>0.150852</td>\n",
" <td>0.190159</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Hi:&lt;br /&gt;&lt;br</td>\n",
" <td>/&gt;This movie is a turkey though when it comes to</td>\n",
" <td>/&gt;I also like that movie. It's so funny</td>\n",
" <td>-0.438191</td>\n",
" <td>2.415630</td>\n",
" <td>I've read a few</td>\n",
" <td>news reports about Mr. Mueller's activities b...</td>\n",
" <td>novels and I never watch this. It has a reall...</td>\n",
" <td>-1.417962</td>\n",
" <td>2.831814</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>I'm a writer</td>\n",
" <td>and I'm not going to be asked to</td>\n",
" <td>, not a screenwriter. I've written</td>\n",
" <td>-0.655991</td>\n",
" <td>-0.724324</td>\n",
" <td>This is the second British Rank film</td>\n",
" <td>, and I wouldn't be surprised anymore if it</td>\n",
" <td>that I have enjoyed, achieving it in both the</td>\n",
" <td>0.835876</td>\n",
" <td>2.205628</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>If you</td>\n",
" <td>absolutely love sensitive romance, the plot a...</td>\n",
" <td>are looking at the cinematography, the acting,</td>\n",
" <td>2.221309</td>\n",
" <td>0.148751</td>\n",
" <td>A classic</td>\n",
" <td>classic.&lt;br /&gt;&lt;br /&gt;And only this one will ha...</td>\n",
" <td>. It's a movie with a fine cast. As the beginn...</td>\n",
" <td>2.113075</td>\n",
" <td>2.739168</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>OMG this</td>\n",
" <td>casting cast. Obi cult breezy, this is</td>\n",
" <td>movie was totally wonderful, I it was the ide...</td>\n",
" <td>-1.533139</td>\n",
" <td>2.590190</td>\n",
" <td>This has to be one of the</td>\n",
" <td>worst with the differences being that for the</td>\n",
" <td>best thriller films I've seen in recent</td>\n",
" <td>-2.705339</td>\n",
" <td>2.730615</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>It's</td>\n",
" <td>unrealistic; the guy who was supposed to be E...</td>\n",
" <td>a very good film. It reminds us about over</td>\n",
" <td>-2.097017</td>\n",
" <td>2.835831</td>\n",
" <td>Happy Go Lovely is a waste</td>\n",
" <td>. Not only are extremely</td>\n",
" <td>of time, giving a</td>\n",
" <td>-2.429504</td>\n",
" <td>-2.934672</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>There is a really</td>\n",
" <td>awful laptop game!&lt;br /&gt;&lt;br /&gt;I used to</td>\n",
" <td>interesting story that set us the journey. Th...</td>\n",
" <td>-2.341743</td>\n",
" <td>2.282939</td>\n",
" <td>Wow, I just</td>\n",
" <td>can't make fun of it</td>\n",
" <td>feek it! This show</td>\n",
" <td>-2.201666</td>\n",
" <td>-0.106085</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>This is</td>\n",
" <td>my favorite part about</td>\n",
" <td>a well thought well</td>\n",
" <td>2.554794</td>\n",
" <td>2.734139</td>\n",
" <td>This movie makes several mistakes.</td>\n",
" <td>Despite being a great comedic diversion it es...</td>\n",
" <td>It's cool, wonderful - it held me into a very ...</td>\n",
" <td>-1.232380</td>\n",
" <td>2.707638</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>Wasn't</td>\n",
" <td>Wasn't it clichéd?&lt;|endoftext|&gt;</td>\n",
" <td>anyone else interested in this movie? It's a ...</td>\n",
" <td>-1.790802</td>\n",
" <td>2.631960</td>\n",
" <td>Branagh and Fish</td>\n",
" <td>burne, Drake is played</td>\n",
" <td>is a great show. Beautiful</td>\n",
" <td>0.776819</td>\n",
" <td>2.808996</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>This film is another of director Tim</td>\n",
" <td>Burton's masterpieces</td>\n",
" <td>Curry's best bombs</td>\n",
" <td>2.622917</td>\n",
" <td>2.544106</td>\n",
" <td>I might have given this movie a</td>\n",
" <td>rating of *11 when I heard that!), but it was...</td>\n",
" <td>great performance. It was truly a great movie...</td>\n",
" <td>0.276380</td>\n",
" <td>2.743328</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>I thought this movie</td>\n",
" <td>was excellent. I actually laughed 6 times and...</td>\n",
" <td>was perfect, and I believe it's almost overlo...</td>\n",
" <td>2.548022</td>\n",
" <td>2.601913</td>\n",
" <td>Really, really bad</td>\n",
" <td>with feel like there is no end to the</td>\n",
" <td>. This movie is incredibly good, with the</td>\n",
" <td>-2.639503</td>\n",
" <td>-1.568827</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>This early John Wayne</td>\n",
" <td>films looked like an abandoned police beating</td>\n",
" <td>film is a realistic portrayal of what</td>\n",
" <td>-1.742279</td>\n",
" <td>2.609762</td>\n",
" <td>What another reviewer called lack of</td>\n",
" <td>judgment, connecting into her own harsh obser...</td>\n",
" <td>suspense. Rogers and Rooney rate this as exce...</td>\n",
" <td>-1.079707</td>\n",
" <td>2.696888</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>I was</td>\n",
" <td>given an experience-a big one, almost 25</td>\n",
" <td>very happy with all the reflections and this ...</td>\n",
" <td>2.250709</td>\n",
" <td>2.558540</td>\n",
" <td>This is simply one</td>\n",
" <td>more problem of Steve</td>\n",
" <td>of the best choice</td>\n",
" <td>-1.445436</td>\n",
" <td>2.662699</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>Embarrassingly, I</td>\n",
" <td>am more at a strict conformity after getting ...</td>\n",
" <td>had never seen a movie before. There was one ...</td>\n",
" <td>-2.021666</td>\n",
" <td>-1.803383</td>\n",
" <td>\"Perhaps we can arrange a meet</td>\n",
" <td>-and-greet.&lt;br /&gt;&lt;br /&gt;Teleg</td>\n",
" <td>with spent, classic music and dance, and come...</td>\n",
" <td>0.258479</td>\n",
" <td>1.876662</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>I am a fan</td>\n",
" <td>of living on simple islands, and we have visi...</td>\n",
" <td>of many things and learned how to appreciate ...</td>\n",
" <td>1.791297</td>\n",
" <td>2.324461</td>\n",
" <td>Richard Willaims is</td>\n",
" <td>nice enough; the little black guy plays quite</td>\n",
" <td>beautifully hands on in his own spin, and</td>\n",
" <td>0.796508</td>\n",
" <td>2.820259</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
@ -595,76 +580,76 @@
],
"text/plain": [
" query \\\n",
"0 Oh dear, \n",
"1 I've seen \n",
"2 Hi:<br /><br \n",
"3 I'm a writer \n",
"4 If you \n",
"5 OMG this \n",
"6 It's \n",
"7 There is a really \n",
"8 This is \n",
"9 Wasn't \n",
"10 This film is another of director Tim \n",
"11 I thought this movie \n",
"12 This early John Wayne \n",
"13 I was \n",
"14 Embarrassingly, I \n",
"15 I am a fan \n",
"0 I rented Zero Day \n",
"1 The only \n",
"2 I've read a few \n",
"3 This is the second British Rank film \n",
"4 A classic \n",
"5 This has to be one of the \n",
"6 Happy Go Lovely is a waste \n",
"7 Wow, I just \n",
"8 This movie makes several mistakes. \n",
"9 Branagh and Fish \n",
"10 I might have given this movie a \n",
"11 Really, really bad \n",
"12 What another reviewer called lack of \n",
"13 This is simply one \n",
"14 \"Perhaps we can arrange a meet \n",
"15 Richard Willaims is \n",
"\n",
" response (before) \\\n",
"0 what are I saying?! I fast-forwarded through \n",
"1 it, as well.<br \n",
"2 />This movie is a turkey though when it comes to \n",
"3 and I'm not going to be asked to \n",
"4 absolutely love sensitive romance, the plot a... \n",
"5 casting cast. Obi cult breezy, this is \n",
"6 unrealistic; the guy who was supposed to be E... \n",
"7 awful laptop game!<br /><br />I used to \n",
"8 my favorite part about \n",
"9 Wasn't it clichéd?<|endoftext|> \n",
"10 Burton's masterpieces \n",
"11 was excellent. I actually laughed 6 times and... \n",
"12 films looked like an abandoned police beating \n",
"13 given an experience-a big one, almost 25 \n",
"14 am more at a strict conformity after getting ... \n",
"15 of living on simple islands, and we have visi... \n",
"0 4 for my sister. To my surprise, the Wii caug... \n",
"1 distro of her \n",
"2 news reports about Mr. Mueller's activities b... \n",
"3 , and I wouldn't be surprised anymore if it \n",
"4 classic.<br /><br />And only this one will ha... \n",
"5 worst with the differences being that for the \n",
"6 . Not only are extremely \n",
"7 can't make fun of it \n",
"8 Despite being a great comedic diversion it es... \n",
"9 burne, Drake is played \n",
"10 rating of *11 when I heard that!), but it was... \n",
"11 with feel like there is no end to the \n",
"12 judgment, connecting into her own harsh obser... \n",
"13 more problem of Steve \n",
"14 -and-greet.<br /><br />Teleg \n",
"15 nice enough; the little black guy plays quite \n",
"\n",
" response (after) rewards (before) \\\n",
"0 I must say that I are hanging my head on this -0.858954 \n",
"1 three million dialogue throughout, and 1.996807 \n",
"2 />I also like that movie. It's so funny -0.438191 \n",
"3 , not a screenwriter. I've written -0.655991 \n",
"4 are looking at the cinematography, the acting, 2.221309 \n",
"5 movie was totally wonderful, I it was the ide... -1.533139 \n",
"6 a very good film. It reminds us about over -2.097017 \n",
"7 interesting story that set us the journey. Th... -2.341743 \n",
"8 a well thought well 2.554794 \n",
"9 anyone else interested in this movie? It's a ... -1.790802 \n",
"10 Curry's best bombs 2.622917 \n",
"11 was perfect, and I believe it's almost overlo... 2.548022 \n",
"12 film is a realistic portrayal of what -1.742279 \n",
"13 very happy with all the reflections and this ... 2.250709 \n",
"14 had never seen a movie before. There was one ... -2.021666 \n",
"15 of many things and learned how to appreciate ... 1.791297 \n",
"0 . It is a pleasure. It is a huge leap 68 years... 1.736068 \n",
"1 special compliments is the 0.150852 \n",
"2 novels and I never watch this. It has a reall... -1.417962 \n",
"3 that I have enjoyed, achieving it in both the 0.835876 \n",
"4 . It's a movie with a fine cast. As the beginn... 2.113075 \n",
"5 best thriller films I've seen in recent -2.705339 \n",
"6 of time, giving a -2.429504 \n",
"7 feek it! This show -2.201666 \n",
"8 It's cool, wonderful - it held me into a very ... -1.232380 \n",
"9 is a great show. Beautiful 0.776819 \n",
"10 great performance. It was truly a great movie... 0.276380 \n",
"11 . This movie is incredibly good, with the -2.639503 \n",
"12 suspense. Rogers and Rooney rate this as exce... -1.079707 \n",
"13 of the best choice -1.445436 \n",
"14 with spent, classic music and dance, and come... 0.258479 \n",
"15 beautifully hands on in his own spin, and 0.796508 \n",
"\n",
" rewards (after) \n",
"0 -1.007609 \n",
"1 2.240883 \n",
"2 2.415630 \n",
"3 -0.724324 \n",
"4 0.148751 \n",
"5 2.590190 \n",
"6 2.835831 \n",
"7 2.282939 \n",
"8 2.734139 \n",
"9 2.631960 \n",
"10 2.544106 \n",
"11 2.601913 \n",
"12 2.609762 \n",
"13 2.558540 \n",
"14 -1.803383 \n",
"15 2.324461 "
"0 2.423731 \n",
"1 0.190159 \n",
"2 2.831814 \n",
"3 2.205628 \n",
"4 2.739168 \n",
"5 2.730615 \n",
"6 -2.934672 \n",
"7 -0.106085 \n",
"8 2.707638 \n",
"9 2.808996 \n",
"10 2.743328 \n",
"11 -1.568827 \n",
"12 2.696888 \n",
"13 2.662699 \n",
"14 1.876662 \n",
"15 2.820259 "
]
},
"execution_count": null,
@ -685,15 +670,16 @@
"\n",
"#### get response from gpt2 and gpt2_ref\n",
"for i in range(bs):\n",
" query = torch.tensor(query_tensors[i]).to(device)\n",
"\n",
" gen_len = output_length_sampler()\n",
" output = ref_model.generate(\n",
" torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs\n",
" ).squeeze()[-gen_len:]\n",
" response_tensors_ref.append(output)\n",
" output = model.generate(\n",
" torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs\n",
" ).squeeze()[-gen_len:]\n",
" response_tensors.append(output)\n",
" query_response = ref_model.generate(query.unsqueeze(0), max_new_tokens=gen_len, **gen_kwargs).squeeze()\n",
" response_len = len(query_response) - len(query)\n",
" response_tensors_ref.append(query_response[-response_len:])\n",
"\n",
" query_response = model.generate(query.unsqueeze(0), max_new_tokens=gen_len, **gen_kwargs).squeeze()\n",
" response_len = len(query_response) - len(query)\n",
" response_tensors.append(query_response[-response_len:])\n",
"\n",
"#### decode responses\n",
"game_data[\"response (before)\"] = [tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]\n",
@ -701,10 +687,14 @@
"\n",
"#### sentiment analysis of query/response pairs before/after\n",
"texts = [q + r for q, r in zip(game_data[\"query\"], game_data[\"response (before)\"])]\n",
"game_data[\"rewards (before)\"] = [output[1][\"score\"] for output in sentiment_pipe(texts, **sent_kwargs)]\n",
"pipe_outputs = sentiment_pipe(texts, **sent_kwargs)\n",
"positive_scores = [item[\"score\"] for output in pipe_outputs for item in output if item[\"label\"] == \"POSITIVE\"]\n",
"game_data[\"rewards (before)\"] = positive_scores\n",
"\n",
"texts = [q + r for q, r in zip(game_data[\"query\"], game_data[\"response (after)\"])]\n",
"game_data[\"rewards (after)\"] = [output[1][\"score\"] for output in sentiment_pipe(texts, **sent_kwargs)]\n",
"pipe_outputs = sentiment_pipe(texts, **sent_kwargs)\n",
"positive_scores = [item[\"score\"] for output in pipe_outputs for item in output if item[\"label\"] == \"POSITIVE\"]\n",
"game_data[\"rewards (after)\"] = positive_scores\n",
"\n",
"# store results in a dataframe\n",
"df_results = pd.DataFrame(game_data)\n",
@ -733,8 +723,8 @@
{
"data": {
"text/plain": [
"rewards (before) 0.156629\n",
"rewards (after) 1.686487\n",
"rewards (before) -0.512965\n",
"rewards (after) 1.676750\n",
"dtype: float64"
]
},
@ -752,8 +742,8 @@
{
"data": {
"text/plain": [
"rewards (before) -0.547091\n",
"rewards (after) 2.479868\n",
"rewards (before) -0.464427\n",
"rewards (after) 2.679794\n",
"dtype: float64"
]
},
@ -782,45 +772,6 @@
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/leandro/miniconda3/envs/trl/lib/python3.9/site-packages/huggingface_hub/hf_api.py:1001: FutureWarning: `create_repo` now takes `token` as an optional positional argument. Be sure to adapt your code!\n",
" warnings.warn(\n",
"Cloning https://huggingface.co/lvwerra/gpt2-imdb-pos-v2 into local empty directory.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a953a6d0c465432bbc39aca826d37aaf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Upload file pytorch_model.bin: 0%| | 32.0k/487M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"remote: Enforcing permissions... \n",
"remote: Allowed refs: all \n",
"To https://huggingface.co/lvwerra/gpt2-imdb-pos-v2\n",
" 369b075..28b9865 main -> main\n",
"\n",
"remote: Enforcing permissions... \n",
"remote: Allowed refs: all \n",
"To https://huggingface.co/lvwerra/gpt2-imdb-pos-v2\n",
" 28b9865..42792ea main -> main\n",
"\n"
]
},
{
"data": {
"text/plain": [
@ -841,13 +792,6 @@
"model.save_pretrained(\"gpt2-imdb-pos-v2\", push_to_hub=True)\n",
"tokenizer.save_pretrained(\"gpt2-imdb-pos-v2\", push_to_hub=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@ -866,7 +810,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12 (main, Mar 26 2022, 15:51:15) \n[Clang 13.1.6 (clang-1316.0.21.2)]"
"version": "3.11.9"
},
"vscode": {
"interpreter": {

View File

@ -15,6 +15,7 @@ from transformers import (
Trainer,
TrainerCallback,
TrainingArguments,
set_seed,
)
from transformers.utils import PaddingStrategy
@ -89,16 +90,23 @@ class ScriptArguments:
default=False,
metadata={"help": "Whether to run eval after the first step"},
)
seed: Optional[int] = field(
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
)
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
set_seed(script_args.seed)
# Load the human stack-exchange-paired dataset for tuning the reward model.
train_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/reward", split="train")
train_dataset = load_dataset(
"lvwerra/stack-exchange-paired", data_dir="data/reward", split="train", verification_mode="no_checks"
)
if script_args.train_subset > 0:
train_dataset = train_dataset.select(range(script_args.train_subset))
eval_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/evaluation", split="train")
eval_dataset = load_dataset(
"lvwerra/stack-exchange-paired", data_dir="data/evaluation", split="train", verification_mode="no_checks"
)
if script_args.eval_subset > 0:
eval_dataset = eval_dataset.select(range(script_args.eval_subset))
# Define the training args. Needs to be done before the model is loaded if you are using deepspeed.
@ -114,7 +122,7 @@ training_args = TrainingArguments(
per_device_eval_batch_size=script_args.per_device_eval_batch_size,
num_train_epochs=script_args.num_train_epochs,
weight_decay=script_args.weight_decay,
evaluation_strategy="steps",
eval_strategy="steps",
eval_steps=500,
save_strategy="steps",
save_steps=500,
@ -129,7 +137,10 @@ training_args = TrainingArguments(
logging_steps=10,
optim=script_args.optim,
lr_scheduler_type=script_args.lr_scheduler_type,
seed=script_args.seed,
)
# Load the value-head model and tokenizer.
tokenizer_name = script_args.tokenizer_name if script_args.tokenizer_name is not None else script_args.model_name
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True)
@ -187,7 +198,8 @@ train_dataset = train_dataset.map(
remove_columns=original_columns,
)
train_dataset = train_dataset.filter(
lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length
lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length,
num_proc=num_proc,
)
eval_dataset = eval_dataset.map(
@ -197,7 +209,8 @@ eval_dataset = eval_dataset.map(
remove_columns=original_columns,
)
eval_dataset = eval_dataset.filter(
lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length
lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length,
num_proc=num_proc,
)
@ -264,7 +277,7 @@ def compute_metrics(eval_pred):
class RewardTrainer(Trainer):
# Define how to compute the reward loss. We use the InstructGPT pairwise logloss: https://arxiv.org/abs/2203.02155
# Define how to compute the reward loss. We use the InstructGPT pairwise logloss: https://huggingface.co/papers/2203.02155
def compute_loss(self, model, inputs, return_outputs=False):
rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0]
rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0]
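The hunk above ends just before the loss itself. Under the InstructGPT pairwise formulation named in the comment, the remainder of `compute_loss` would look roughly like this (a sketch, assuming `torch.nn` is imported as `nn` as elsewhere in the script):

def compute_loss(self, model, inputs, return_outputs=False):
    rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0]
    rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0]
    # pairwise logloss: push the chosen (j) reward above the rejected (k) reward
    loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean()
    if return_outputs:
        return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k}
    return loss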

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -67,6 +66,7 @@ class ScriptArguments:
)
adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
load_in_8bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 8bit"})
parser = HfArgumentParser(ScriptArguments)
@ -90,7 +90,9 @@ config = PPOConfig(
adap_kl_ctrl=script_args.adap_kl_ctrl,
)
train_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/rl", split="train")
train_dataset = load_dataset(
"lvwerra/stack-exchange-paired", data_dir="data/rl", split="train", verification_mode="no_checks"
)
train_dataset = train_dataset.select(range(100000))
original_columns = train_dataset.column_names
@ -152,7 +154,7 @@ def build_dataset(
num_proc=num_proc,
remove_columns=original_columns,
)
ds = ds.filter(lambda x: len(x["input_ids"]) < 512, batched=False)
ds = ds.filter(lambda x: len(x["input_ids"]) < 512, batched=False, num_proc=num_proc)
ds.set_format(type="torch")
return ds
@ -163,7 +165,7 @@ dataset = build_dataset(tokenizer)
def collator(data):
return dict((key, [d[key] for d in data]) for key in data[0])
return {key: [d[key] for d in data] for key in data[0]}
# set seed before initializing value head for deterministic eval
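Both the old and the new `collator` lines transpose a list of feature dicts into a dict of feature lists; the rewrite only swaps `dict(generator)` for the equivalent dict comprehension. For example:

def collator(data):
    return {key: [d[key] for d in data] for key in data[0]}

print(collator([{"q": "a", "r": 1}, {"q": "b", "r": 2}]))
# {'q': ['a', 'b'], 'r': [1, 2]}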
@ -181,7 +183,7 @@ lora_config = LoraConfig(
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
config.model_name,
load_in_8bit=True,
load_in_8bit=script_args.load_in_8bit,
device_map={"": current_device},
peft_config=lora_config,
)
@ -216,11 +218,13 @@ sentiment_pipe = pipeline(
"sentiment-analysis",
model=reward_model_name,
device_map={"": current_device},
model_kwargs={"load_in_8bit": True},
model_kwargs={"load_in_8bit": script_args.load_in_8bit},
tokenizer=tokenizer,
return_token_type_ids=False,
)
if sentiment_pipe.model.config.pad_token_id is None:
sentiment_pipe.model.config.pad_token_id = sentiment_pipe.model.config.eos_token_id
# We then define the arguments to pass to the `generate` function. These arguments
# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
# the `generate` function of the trained model.

View File

@ -148,7 +148,7 @@ def run_training(args, train_data, val_data):
training_args = TrainingArguments(
output_dir=args.output_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
eval_strategy="steps",
max_steps=args.max_steps,
eval_steps=args.eval_freq,
save_steps=args.save_freq,

View File

@ -4,11 +4,12 @@ from dataclasses import dataclass, field
from typing import Dict, Optional
import torch
from accelerate import Accelerator
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed
from trl import DPOTrainer
from trl import DPOConfig, DPOTrainer
# Define and parse arguments.
@ -41,6 +42,10 @@ class ScriptArguments:
default=True, metadata={"help": "whether to use gradient checkpointing"}
)
gradient_checkpointing_use_reentrant: Optional[bool] = field(
default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"}
)
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
@ -54,6 +59,10 @@ class ScriptArguments:
output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})
load_in_4bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 4bit"})
model_dtype: Optional[str] = field(
default="float16", metadata={"help": "model_dtype[float16, bfloat16, float] for loading."}
)
# instrumentation
sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
@ -73,12 +82,15 @@ class ScriptArguments:
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)
seed: Optional[int] = field(
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
)
def get_stack_exchange_paired(
data_dir: str = "data/rl",
sanity_check: bool = False,
cache_dir: str = None,
cache_dir: Optional[str] = None,
num_proc=24,
) -> Dataset:
"""Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.
@ -98,6 +110,7 @@ def get_stack_exchange_paired(
split="train",
cache_dir=cache_dir,
data_dir=data_dir,
verification_mode="no_checks",
)
original_columns = dataset.column_names
@ -123,12 +136,21 @@ if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
set_seed(script_args.seed)
# 1. load a pretrained model
torch_dtype = torch.float
if script_args.model_dtype == "float16":
torch_dtype = torch.float16
elif script_args.model_dtype == "bfloat16":
torch_dtype = torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
load_in_4bit=True,
torch_dtype=torch_dtype,
load_in_4bit=script_args.load_in_4bit,
device_map={"": Accelerator().local_process_index},
)
model.config.use_cache = False
@ -138,12 +160,6 @@ if __name__ == "__main__":
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
]
model_ref = AutoModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
load_in_4bit=True,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token
@ -151,18 +167,20 @@ if __name__ == "__main__":
train_dataset = get_stack_exchange_paired(data_dir="data/rl", sanity_check=script_args.sanity_check)
train_dataset = train_dataset.filter(
lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
num_proc=script_args.num_proc,
)
# 3. Load evaluation dataset
eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation", sanity_check=True)
eval_dataset = eval_dataset.filter(
lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
num_proc=script_args.num_proc,
)
# 4. initialize training arguments:
training_args = TrainingArguments(
training_args = DPOConfig(
per_device_train_batch_size=script_args.per_device_train_batch_size,
per_device_eval_batch_size=script_args.per_device_eval_batch_size,
max_steps=script_args.max_steps,
@ -171,7 +189,7 @@ if __name__ == "__main__":
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
gradient_checkpointing=script_args.gradient_checkpointing,
learning_rate=script_args.learning_rate,
evaluation_strategy="steps",
eval_strategy="steps",
eval_steps=script_args.eval_steps,
output_dir=script_args.output_dir,
report_to=script_args.report_to,
@ -181,6 +199,8 @@ if __name__ == "__main__":
bf16=True,
remove_unused_columns=False,
run_name="dpo_llama2",
gradient_checkpointing_kwargs=dict(use_reentrant=script_args.gradient_checkpointing_use_reentrant),
seed=script_args.seed,
)
peft_config = LoraConfig(
@ -203,7 +223,7 @@ if __name__ == "__main__":
# 5. initialize the DPO trainer
dpo_trainer = DPOTrainer(
model,
model_ref,
ref_model=None,
args=training_args,
beta=script_args.beta,
train_dataset=train_dataset,
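Passing `ref_model=None` works here because a `peft_config` is supplied: DPOTrainer can compute reference log-probabilities from the base weights with the LoRA adapter disabled, which is why the separately loaded `model_ref` is removed earlier in this diff. A condensed sketch of the pattern (the names `model`, `train_dataset`, `tokenizer`, and `peft_config` are assumed to be defined as in the script):

from trl import DPOConfig, DPOTrainer

trainer = DPOTrainer(
    model,            # policy: base model plus LoRA adapter
    ref_model=None,   # reference: base weights with the adapter disabled
    args=DPOConfig(output_dir="./results"),  # illustrative config
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
)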

View File

@ -8,9 +8,15 @@ from accelerate import Accelerator
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM, LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
HfArgumentParser,
set_seed,
)
from trl import SFTTrainer
from trl import SFTConfig, SFTTrainer
from trl.import_utils import is_npu_available, is_xpu_available
from trl.trainer import ConstantLengthDataset
@ -26,7 +32,7 @@ class ScriptArguments:
shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"})
seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"})
num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"})
use_bnb: Optional[bool] = field(default=True, metadata={"help": "whether to use BitsAndBytes"})
# LoraConfig
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
@ -34,7 +40,7 @@ class ScriptArguments:
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
parser = HfArgumentParser((ScriptArguments, TrainingArguments))
parser = HfArgumentParser((ScriptArguments, SFTConfig))
script_args, training_args = parser.parse_args_into_dataclasses()
peft_config = LoraConfig(
r=script_args.lora_r,
@ -45,7 +51,7 @@ peft_config = LoraConfig(
task_type="CAUSAL_LM",
)
if training_args.group_by_length and script_args.packing:
if training_args.group_by_length and training_args.packing:
raise ValueError("Cannot use both packing and group by length")
# `gradient_checkpointing` was True by default until `1f3314`, but it's actually not used.
@ -53,6 +59,8 @@ if training_args.group_by_length and script_args.packing:
if training_args.gradient_checkpointing:
raise ValueError("gradient_checkpointing not supported")
set_seed(training_args.seed)
def chars_token_ratio(dataset, tokenizer, nb_examples=400):
"""
@ -91,7 +99,7 @@ def prepare_sample_text(example):
return text
def create_datasets(tokenizer, args):
def create_datasets(tokenizer, args, seed=None):
dataset = load_dataset(
args.dataset_name,
data_dir=args.subset,
@ -104,9 +112,9 @@ def create_datasets(tokenizer, args):
print("Loading the dataset in streaming mode")
valid_data = dataset.take(args.size_valid_set)
train_data = dataset.skip(args.size_valid_set)
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=None)
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=seed)
else:
dataset = dataset.train_test_split(test_size=0.005, seed=None)
dataset = dataset.train_test_split(test_size=0.005, seed=seed)
train_data = dataset["train"]
valid_data = dataset["test"]
print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
@ -133,11 +141,13 @@ def create_datasets(tokenizer, args):
return train_dataset, valid_dataset
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
bnb_config = None
if script_args.use_bnb:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
base_model = AutoModelForCausalLM.from_pretrained(
script_args.model_name,
@ -153,15 +163,15 @@ tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_c
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training
train_dataset, eval_dataset = create_datasets(tokenizer, script_args)
train_dataset, eval_dataset = create_datasets(tokenizer, script_args, seed=training_args.seed)
trainer = SFTTrainer(
model=base_model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=peft_config,
packing=script_args.packing,
max_seq_length=None,
formatting_func=prepare_sample_text,
tokenizer=tokenizer,
args=training_args,
)
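With this change, `packing` and `max_seq_length` live on `SFTConfig` rather than being passed to `SFTTrainer` directly. An equivalent programmatic setup would look roughly like the following (values are illustrative):

from trl import SFTConfig

training_args = SFTConfig(
    output_dir="./results",  # illustrative
    packing=True,            # previously script_args.packing, passed to SFTTrainer
    max_seq_length=None,     # previously a direct SFTTrainer keyword argument
)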

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -59,7 +58,7 @@ def exact_match_reward(responses, answers=None):
# set up models
model_id = "gpt2"
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id)
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained(model_id)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
@ -94,7 +93,7 @@ ppo_config = PPOConfig(
mini_batch_size=64,
log_with="wandb",
)
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)
ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer)
# text env
text_env = TextEnvironment(
@ -107,7 +106,7 @@ text_env = TextEnvironment(
)
# main training loop
for step in range(100):
for _step in range(100):
tasks, answers = generate_data(ppo_config.batch_size)
queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers)
train_stats = ppo_trainer.step(queries, responses, rewards, masks)

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -61,9 +60,9 @@ def exact_match_reward(responses, answers=None):
if match_pattern:
predicted_number = float(match_pattern[0])
if predicted_number is not None:
if np.abs((predicted_number - float(answer))) < 0.1:
if np.abs(predicted_number - float(answer)) < 0.1:
reward += 1.0
except: # noqa
except Exception:
pass
rewards.append(torch.tensor(reward))
return rewards
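The tolerance check above counts a prediction as correct when it falls within 0.1 of the reference answer. A standalone illustration of just that comparison:

import numpy as np

def within_tolerance(predicted_number: float, answer: str, tol: float = 0.1) -> bool:
    return bool(np.abs(predicted_number - float(answer)) < tol)

print(within_tolerance(3.14, "3.1"))  # True: |3.14 - 3.1| = 0.04 < 0.1
print(within_tolerance(3.3, "3.1"))   # False: |3.3 - 3.1| = 0.2 >= 0.1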

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -114,7 +113,7 @@ dataset = dataset.shuffle(local_seed)
def data_generator():
for i in range(len(dataset)):
yield dataset[i]["question"], [item for item in dataset[i]["answer"]["normalized_aliases"]]
yield dataset[i]["question"], list(dataset[i]["answer"]["normalized_aliases"])
gen = data_generator()
@ -123,7 +122,7 @@ gen = iter(gen)
def generate_data(n):
tasks, answers = [], []
for i in range(n):
for _i in range(n):
q, a = next(gen)
tasks.append(q)
answers.append(a)
@ -143,10 +142,14 @@ def exact_match_reward(responses, answers=None):
return rewards
def tool_fn(x):
# limit the amount of tokens
return tool(x).split("\n")[1][:600]
# text env
tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
# limit the amount if tokens
tool_fn = lambda x: tool(x).split("\n")[1][:600] # noqa
text_env = TextEnvironment(
model,
tokenizer,
@ -184,8 +187,6 @@ for i in range(args.iterations):
"answer": [", ".join(item) for item in answers],
}
all_rewards = ppo_trainer.accelerator.gather(torch.tensor(rewards, device=ppo_trainer.accelerator.device))
ppo_trainer.log_stats(
train_stats, texts, [item for item in all_rewards], columns_to_log=["query", "response", "answer"]
)
ppo_trainer.log_stats(train_stats, texts, list(all_rewards), columns_to_log=["query", "response", "answer"])
if i % 100 == 0:
ppo_trainer.save_pretrained(f"models/{args.model_name}_{args.seed}_{i}_triviaqa")

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -146,7 +145,7 @@ dataset = build_dataset(config, input_min_text_length=min_input_length, input_ma
def collator(data):
return dict((key, [d[key] for d in data]) for key in data[0])
return {key: [d[key] for d in data] for key in data[0]}
# set seed before initializing value head for deterministic eval
@ -218,7 +217,7 @@ for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
response_tensors.append(response.squeeze()[-gen_len:])
batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
# Compute sentiment score # noqa
# Compute sentiment score
texts = batch["response"]
toxicity_inputs = toxicity_tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(
ppo_trainer.accelerator.device

View File

@ -0,0 +1,129 @@
# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Total batch size = 128 = 4 (num_gpus) * 8 (per_device_batch) * 4 (accumulation steps)
Feel free to reduce the batch size, or increase truncated_rand_backprop_min to a higher value, to reduce memory usage.
CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/scripts/alignprop.py \
--num_epochs=20 \
--train_gradient_accumulation_steps=4 \
--sample_num_steps=50 \
--train_batch_size=8 \
--tracker_project_name="stable_diffusion_training" \
--log_with="wandb"
"""
from dataclasses import dataclass, field
import numpy as np
from transformers import HfArgumentParser
from trl import AlignPropConfig, AlignPropTrainer, DefaultDDPOStableDiffusionPipeline
from trl.models.auxiliary_modules import aesthetic_scorer
@dataclass
class ScriptArguments:
pretrained_model: str = field(
default="runwayml/stable-diffusion-v1-5", metadata={"help": "the pretrained model to use"}
)
pretrained_revision: str = field(default="main", metadata={"help": "the pretrained model revision to use"})
hf_hub_model_id: str = field(
default="alignprop-finetuned-stable-diffusion", metadata={"help": "HuggingFace repo to save model weights to"}
)
hf_hub_aesthetic_model_id: str = field(
default="trl-lib/ddpo-aesthetic-predictor",
metadata={"help": "HuggingFace model ID for aesthetic scorer model weights"},
)
hf_hub_aesthetic_model_filename: str = field(
default="aesthetic-model.pth",
metadata={"help": "HuggingFace model filename for aesthetic scorer model weights"},
)
use_lora: bool = field(default=True, metadata={"help": "Whether to use LoRA."})
# list of example prompts to feed stable diffusion
animals = [
"cat",
"dog",
"horse",
"monkey",
"rabbit",
"zebra",
"spider",
"bird",
"sheep",
"deer",
"cow",
"goat",
"lion",
"frog",
"chicken",
"duck",
"goose",
"bee",
"pig",
"turkey",
"fly",
"llama",
"camel",
"bat",
"gorilla",
"hedgehog",
"kangaroo",
]
def prompt_fn():
return np.random.choice(animals), {}
def image_outputs_logger(image_pair_data, global_step, accelerate_logger):
# For the sake of this example, we will only log the last batch of images
# and associated data
result = {}
images, prompts, _ = [image_pair_data["images"], image_pair_data["prompts"], image_pair_data["rewards"]]
for i, image in enumerate(images[:4]):
prompt = prompts[i]
result[f"{prompt}"] = image.unsqueeze(0).float()
accelerate_logger.log_images(
result,
step=global_step,
)
if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, AlignPropConfig))
args, alignprop_config = parser.parse_args_into_dataclasses()
alignprop_config.project_kwargs = {
"logging_dir": "./logs",
"automatic_checkpoint_naming": True,
"total_limit": 5,
"project_dir": "./save",
}
pipeline = DefaultDDPOStableDiffusionPipeline(
args.pretrained_model, pretrained_model_revision=args.pretrained_revision, use_lora=args.use_lora
)
trainer = AlignPropTrainer(
alignprop_config,
aesthetic_scorer(args.hf_hub_aesthetic_model_id, args.hf_hub_aesthetic_model_filename),
prompt_fn,
pipeline,
image_samples_hook=image_outputs_logger,
)
trainer.train()
trainer.push_to_hub(args.hf_hub_model_id)

examples/scripts/bco.py Normal file
View File

@ -0,0 +1,232 @@
"""
Run the BCO training script with the commands below. In general, the optimal configuration for BCO will be similar to that of KTO.
# Full training:
python examples/scripts/bco.py \
--model_name_or_path=nnheui/stablelm-2-1_6b-sft-full \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 32 \
--num_train_epochs 1 \
--learning_rate 1e-6 \
--gradient_checkpointing \
--gradient_accumulation_steps 1 \
--logging_steps 0.01 \
--eval_steps 0.2 \
--save_strategy no \
--output_dir=bco-aligned-model \
--logging_first_step \
--max_length 2048 \
--max_prompt_length 1536 \
--max_completion_length 1024 \
--no_remove_unused_columns \
--warmup_ratio 0.1 \
--bf16 \
--report_to wandb
# QLoRA:
python examples/scripts/bco.py \
--model_name_or_path=nnheui/stablelm-2-1_6b-sft-full \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 32 \
--num_train_epochs 1 \
--learning_rate 1e-6 \
--gradient_checkpointing \
--gradient_accumulation_steps 1 \
--logging_steps 0.01 \
--eval_steps 0.2 \
--save_strategy no \
--output_dir=bco-aligned-model-lora \
--logging_first_step \
--warmup_ratio 0.1 \
--report_to wandb \
--max_length 2048 \
--max_prompt_length 1536 \
--max_completion_length 1024 \
--no_remove_unused_columns \
--bf16 \
--use_peft \
--load_in_4bit \
--lora_target_modules=all-linear \
--lora_r=16 \
--lora_alpha=16
"""
import logging
from dataclasses import dataclass
from functools import partial
from typing import Literal, Optional
import torch
import torch.nn.functional as F
from accelerate import Accelerator, PartialState
from datasets import Dataset, load_dataset
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, PreTrainedModel
from trl import BCOConfig, BCOTrainer, ModelConfig, get_peft_config, setup_chat_format
# Define and parse arguments.
@dataclass
class ScriptArguments:
"""
The arguments for the BCO training script.
"""
llm_name: Literal["gpt-3.5-turbo", "llama-2-7b-chat", "llama-2-70b-chat"] = "gpt-3.5-turbo"
def build_helpfulness_dataset(llm_name: str, num_proc: Optional[int] = None) -> Dataset:
"""
Filter `llm_name` completions and binarize given their helpfulness score.
If helpfulness score is 5, it is desirable. Otherwise, it is undesirable.
"""
def get_model_rating(example, metric: str, llm_name: str):
try:
model_index = example["models"].index(llm_name)
return {metric: int(example["completions"][model_index]["annotations"][metric]["Rating"])}
except ValueError as e:
logging.warning(e)
return -1
def get_model_response(example, llm_name: str):
try:
model_index = example["models"].index(llm_name)
return {"response": example["completions"][model_index]["response"]}
except ValueError as e:
logging.warning(e)
return -1
dataset = load_dataset("openbmb/UltraFeedback")["train"]
ds = dataset.filter(lambda example: llm_name in example["models"], batched=False, num_proc=num_proc)
ds = ds.filter(
lambda example: len(example["models"]) == len(example["completions"]), batched=False, num_proc=num_proc
)
METRIC = "helpfulness"
ds = ds.map(
get_model_rating,
batched=False,
fn_kwargs={"metric": METRIC, "llm_name": llm_name},
num_proc=num_proc,
)
ds = ds.map(
get_model_response,
batched=False,
fn_kwargs={"llm_name": llm_name},
num_proc=num_proc,
)
ds = ds.select_columns(["source", "instruction", "response", "helpfulness"])
ds = ds.rename_columns({"instruction": "prompt", "response": "completion"})
ds = ds.map(lambda example: {"label": example["helpfulness"] >= 5}, batched=False, num_proc=num_proc)
ds = ds.map(
lambda example: {"prompt": [{"role": "user", "content": example["prompt"]}]},
batched=False,
num_proc=num_proc,
)
dataset = ds.train_test_split(test_size=0.05, seed=42)
return dataset
def embed_prompt(input_ids: torch.LongTensor, attention_mask: torch.LongTensor, model: PreTrainedModel):
"""
Borrowed from https://huggingface.co/nomic-ai/nomic-embed-text-v1.5#transformers
"""
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
with torch.no_grad():
model_output = model(input_ids=input_ids, attention_mask=attention_mask)
embeddings = mean_pooling(model_output, attention_mask)
matryoshka_dim = 512
# normalize embeddings
embeddings = F.normalize(embeddings, p=2, dim=1)
embeddings = F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
embeddings = embeddings[:, :matryoshka_dim]
return embeddings
if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, BCOConfig, ModelConfig))
script_args, bco_args, model_args = parser.parse_args_into_dataclasses()
bco_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
# Load a pretrained model
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)
ref_model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# If we are aligning a base model, we use ChatML as the default template
if tokenizer.chat_template is None:
model, tokenizer = setup_chat_format(model, tokenizer)
# Apply chat template
def format_dataset(example):
example["prompt"] = tokenizer.apply_chat_template(
example["prompt"], tokenize=False, add_generation_prompt=True
)
return example
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
# Load the dataset
dataset = build_helpfulness_dataset(script_args.llm_name, num_proc=bco_args.dataset_num_proc)
formatted_dataset = dataset.map(format_dataset, batched=False, num_proc=bco_args.dataset_num_proc)
accelerator = Accelerator()
embedding_model = AutoModel.from_pretrained(
"nomic-ai/nomic-embed-text-v1.5",
trust_remote_code=model_args.trust_remote_code,
safe_serialization=True,
torch_dtype=torch.bfloat16,
device_map="auto",
)
embedding_model = accelerator.prepare_model(embedding_model)
embedding_tokenizer = AutoTokenizer.from_pretrained(
"bert-base-uncased", trust_remote_code=model_args.trust_remote_code
)
embedding_func = partial(
embed_prompt,
model=embedding_model,
)
# Initialize the BCO trainer
bco_trainer = BCOTrainer(
model,
ref_model,
args=bco_args,
train_dataset=formatted_dataset["train"],
eval_dataset=formatted_dataset["test"],
tokenizer=tokenizer,
peft_config=get_peft_config(model_args),
embedding_func=embedding_func,
embedding_tokenizer=embedding_tokenizer,
)
# Train and push the model to the Hub
bco_trainer.train()
bco_trainer.save_model(bco_args.output_dir)

examples/scripts/chat.py Normal file
View File

@ -0,0 +1,367 @@
# flake8: noqa
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from trl.commands.cli_utils import init_zero_verbose
init_zero_verbose()
import copy
import json
import os
import sys
import pwd
import re
import time
from threading import Thread
import torch
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from trl.commands.cli_utils import ChatArguments, TrlParser, init_zero_verbose
from trl.trainer.utils import get_quantization_config
HELP_STRING = """\
**TRL CHAT INTERFACE**
The chat interface is a simple tool to try out a chat model.
Besides talking to the model there are several commands:
- **clear**: clears the current conversation and starts a new one
- **example {NAME}**: load example named `{NAME}` from the config and use it as the user input
- **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';').
- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
- **save {SAVE_NAME} (optional)**: save the current chat and settings to a file (by default `./chat_history/{MODEL_NAME}/chat_{DATETIME}.json`, or `{SAVE_NAME}` if provided)
- **exit**: closes the interface
"""
SUPPORTED_GENERATION_KWARGS = [
"max_new_tokens",
"do_sample",
"num_beams",
"temperature",
"top_p",
"top_k",
"repetition_penalty",
]
SETTING_RE = r"^set\s+[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+(?:;\s*[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+)*$"
class RichInterface:
def __init__(self, model_name=None, user_name=None):
self._console = Console()
if model_name is None:
self.model_name = "assistant"
else:
self.model_name = model_name
if user_name is None:
self.user_name = "user"
else:
self.user_name = user_name
def stream_output(self, output_stream):
"""Stream output from a role."""
# This method is originally from the FastChat CLI: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py
# Create a Live context for updating the console output
text = ""
self._console.print(f"[bold blue]<{self.model_name}>:")
with Live(console=self._console, refresh_per_second=4) as live:
# Read lines from the stream
for i, outputs in enumerate(output_stream):
if not outputs or i == 0:
continue
text += outputs
# Render the accumulated text as Markdown
# NOTE: this is a workaround for rendering non-standard Markdown in rich.
# The chatbot's output treats "\n" as a new line for better compatibility
# with real-world text, but standard Markdown treats a single "\n" in
# normal text as a space, so rendering it directly would break the format.
# Our workaround is to add two spaces at the end of each line. This is not
# a perfect solution, as it introduces trailing spaces (only) in code
# blocks, but it works well, especially for console output, since the
# console generally does not care about trailing spaces.
lines = []
for line in text.splitlines():
lines.append(line)
if line.startswith("```"):
# Code block marker - do not add trailing spaces, as it would
# break the syntax highlighting
lines.append("\n")
else:
lines.append(" \n")
markdown = Markdown("".join(lines).strip(), code_theme="github-dark")
# Update the Live console output
live.update(markdown)
self._console.print()
return text
def input(self):
input = self._console.input(f"[bold red]<{self.user_name}>:\n")
self._console.print()
return input
def clear(self):
self._console.clear()
def print_user_message(self, text):
self._console.print(f"[bold red]<{self.user_name}>:[/ bold red]\n{text}")
self._console.print()
def print_green(self, text):
self._console.print(f"[bold green]{text}")
self._console.print()
def print_red(self, text):
self._console.print(f"[bold red]{text}")
self._console.print()
def print_help(self):
self._console.print(Markdown(HELP_STRING))
self._console.print()
def get_username():
return pwd.getpwuid(os.getuid())[0]
def create_default_filename(model_name):
time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
return f"{model_name}/chat_{time_str}.json"
def save_chat(chat, args, filename):
output_dict = {}
output_dict["settings"] = vars(args)
output_dict["chat_history"] = chat
folder = args.save_folder
if filename is None:
filename = create_default_filename(args.model_name_or_path)
filename = os.path.join(folder, filename)
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, "w") as f:
json.dump(output_dict, f, indent=4)
return os.path.abspath(filename)
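# Editor's illustration (not part of the original script): the saved file is JSON
# of the form
#     {
#         "settings": {...},          # vars(args), i.e. the ChatArguments fields
#         "chat_history": [{"role": "user", "content": "..."}, ...]
#     }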
def clear_chat_history(system_prompt):
if system_prompt is None:
chat = []
else:
chat = [{"role": "system", "content": system_prompt}]
return chat
def parse_settings(user_input, current_args, interface):
settings = user_input[4:].strip().split(";")
    # Split each pair on the first '=' only, so values may themselves contain '='.
    settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings]
settings = dict(settings)
error = False
for name in settings:
if hasattr(current_args, name):
try:
if isinstance(getattr(current_args, name), bool):
if settings[name] == "True":
settings[name] = True
elif settings[name] == "False":
settings[name] = False
else:
raise ValueError
else:
settings[name] = type(getattr(current_args, name))(settings[name])
            except ValueError:
                error = True
                interface.print_red(
                    f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}."
                )
        else:
            error = True
            interface.print_red(f"There is no '{name}' setting.")
if error:
interface.print_red("There was an issue parsing the settings. No settings have been changed.")
return current_args, False
else:
for name in settings:
setattr(current_args, name, settings[name])
interface.print_green(f"Set {name} to {settings[name]}.")
time.sleep(1.5) # so the user has time to read the changes
return current_args, True
def load_model_and_tokenizer(args):
tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path,
revision=args.model_revision,
trust_remote_code=args.trust_remote_code,
)
torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype)
quantization_config = get_quantization_config(args)
model_kwargs = dict(
revision=args.model_revision,
attn_implementation=args.attn_implementation,
torch_dtype=torch_dtype,
device_map="auto",
quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path, trust_remote_code=args.trust_remote_code, **model_kwargs
)
if getattr(model, "hf_device_map", None) is None:
model = model.to(args.device)
return model, tokenizer
def parse_eos_tokens(tokenizer, eos_tokens, eos_token_ids):
if tokenizer.pad_token_id is None:
pad_token_id = tokenizer.eos_token_id
else:
pad_token_id = tokenizer.pad_token_id
all_eos_token_ids = []
if eos_tokens is not None:
all_eos_token_ids.extend(tokenizer.convert_tokens_to_ids(eos_tokens.split(",")))
if eos_token_ids is not None:
all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")])
if len(all_eos_token_ids) == 0:
all_eos_token_ids.append(tokenizer.eos_token_id)
return pad_token_id, all_eos_token_ids
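# Editor's illustration (not part of the original script), assuming a tokenizer
# whose vocabulary contains "</s>": extra stop tokens can be given as token
# strings, raw ids, or both, and the EOS id is the fallback pad id.
#
#     pad_id, eos_ids = parse_eos_tokens(tokenizer, "</s>", "0,2")
#     # pad_id  == tokenizer.pad_token_id (or tokenizer.eos_token_id if unset)
#     # eos_ids == [tokenizer.convert_tokens_to_ids("</s>"), 0, 2]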
def chat_cli():
parser = TrlParser(ChatArguments)
if "--config" not in sys.argv:
sys.argv.append("--config")
sys.argv.append(os.path.join(os.path.dirname(__file__), "config/default_chat_config.yaml"))
args = parser.parse_args_and_config()[0]
if args.examples is None:
args.examples = {}
current_args = copy.deepcopy(args)
if args.user is None:
user = get_username()
else:
user = args.user
model, tokenizer = load_model_and_tokenizer(args)
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
pad_token_id, eos_token_ids = parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
interface.clear()
chat = clear_chat_history(current_args.system_prompt)
while True:
try:
user_input = interface.input()
if user_input == "clear":
chat = clear_chat_history(current_args.system_prompt)
interface.clear()
continue
if user_input == "help":
interface.print_help()
continue
if user_input == "exit":
break
if user_input == "reset":
interface.clear()
current_args = copy.deepcopy(args)
chat = clear_chat_history(current_args.system_prompt)
continue
if user_input.startswith("save") and len(user_input.split()) < 2:
split_input = user_input.split()
if len(split_input) == 2:
filename = split_input[1]
else:
filename = None
filename = save_chat(chat, current_args, filename)
interface.print_green(f"Chat saved in {filename}!")
continue
if re.match(SETTING_RE, user_input):
current_args, success = parse_settings(user_input, current_args, interface)
if success:
                    chat = clear_chat_history(current_args.system_prompt)
interface.clear()
continue
if user_input.startswith("example") and len(user_input.split()) == 2:
example_name = user_input.split()[1]
if example_name in current_args.examples:
interface.clear()
chat = []
interface.print_user_message(current_args.examples[example_name]["text"])
user_input = current_args.examples[example_name]["text"]
else:
interface.print_red(
f"Example {example_name} not found in list of available examples: {list(current_args.examples.keys())}."
)
continue
chat.append({"role": "user", "content": user_input})
generation_kwargs = dict(
inputs=tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
model.device
),
streamer=generation_streamer,
max_new_tokens=current_args.max_new_tokens,
do_sample=current_args.do_sample,
num_beams=current_args.num_beams,
temperature=current_args.temperature,
top_k=current_args.top_k,
top_p=current_args.top_p,
repetition_penalty=current_args.repetition_penalty,
pad_token_id=pad_token_id,
eos_token_id=eos_token_ids,
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
model_output = interface.stream_output(generation_streamer)
thread.join()
chat.append({"role": "assistant", "content": model_output})
except KeyboardInterrupt:
break
if __name__ == "__main__":
chat_cli()

examples/scripts/config/default_chat_config.yaml (new file)

@@ -0,0 +1,13 @@
examples:
llama:
text: There is a Llama in my lawn, how can I get rid of it?
code:
text: Write a Python function that integrates any Python function f(x) numerically over an arbitrary interval [x_start, x_end].
helicopter:
text: How many helicopters can a human eat in one sitting?
numbers:
text: Count to 10 but skip every number ending with an 'e'
birds:
text: Why aren't birds real?
socks:
text: Why is it important to eat socks after meditating?

examples/scripts/cpo.py (new file)

@@ -0,0 +1,125 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run the CPO training script with one of the following commands, which include some example arguments.
In general, the optimal configuration for CPO will be similar to that of DPO:
# regular:
python examples/scripts/cpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 8e-6 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="gpt2-aligned-cpo" \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns
# peft:
python examples/scripts/cpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 8e-5 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="gpt2-lora-aligned-cpo" \
--optim rmsprop \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns \
--use_peft \
--lora_r=16 \
--lora_alpha=16
"""
from dataclasses import dataclass, field
from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import CPOConfig, CPOTrainer, ModelConfig, get_peft_config
@dataclass
class ScriptArguments:
dataset: str = field(
default="trl-internal-testing/hh-rlhf-helpful-base-trl-style",
metadata={"help": "The name of the dataset to use."},
)
if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, CPOConfig, ModelConfig))
args, cpo_args, model_config = parser.parse_args_into_dataclasses()
################
# Model & Tokenizer
################
model = AutoModelForCausalLM.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
)
tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
################
# Dataset
################
ds = load_dataset(args.dataset)
if cpo_args.debug:
for key in ds:
ds[key] = ds[key].select(range(50))
if tokenizer.chat_template is None:
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
def process(row):
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
return row
    # Compute this only on the main process, for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
ds = ds.map(process, num_proc=cpo_args.dataset_num_proc)
train_dataset = ds["train"]
eval_dataset = ds["test"]
################
# Training
################
trainer = CPOTrainer(
model,
args=cpo_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
peft_config=get_peft_config(model_config),
)
# train and save the model
trainer.train()
trainer.save_model(cpo_args.output_dir)

examples/scripts/ddpo.py

@@ -11,18 +11,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python examples/scripts/ddpo.py \
--num_epochs=200 \
--train_gradient_accumulation_steps=1 \
--sample_num_steps=50 \
--sample_batch_size=6 \
--train_batch_size=3 \
--sample_num_batches_per_epoch=4 \
--per_prompt_stat_tracking=True \
--per_prompt_stat_tracking_buffer_size=32 \
--tracker_project_name="stable_diffusion_training" \
--log_with="wandb"
"""
import os
from dataclasses import dataclass, field
import numpy as np
import torch
import torch.nn as nn
import tyro
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from transformers import CLIPModel, CLIPProcessor
from transformers import CLIPModel, CLIPProcessor, HfArgumentParser
from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline
from trl.import_utils import is_npu_available, is_xpu_available
@@ -30,40 +40,22 @@ from trl.import_utils import is_npu_available, is_xpu_available
@dataclass
class ScriptArguments:
hf_user_access_token: str
pretrained_model: str = "runwayml/stable-diffusion-v1-5"
"""the pretrained model to use"""
pretrained_revision: str = "main"
"""the pretrained model revision to use"""
hf_hub_model_id: str = "ddpo-finetuned-stable-diffusion"
"""HuggingFace repo to save model weights to"""
hf_hub_aesthetic_model_id: str = "trl-lib/ddpo-aesthetic-predictor"
"""HuggingFace model ID for aesthetic scorer model weights"""
hf_hub_aesthetic_model_filename: str = "aesthetic-model.pth"
"""HuggingFace model filename for aesthetic scorer model weights"""
use_lora: bool = True
"""Whether to use LoRA."""
ddpo_config: DDPOConfig = field(
default_factory=lambda: DDPOConfig(
num_epochs=200,
train_gradient_accumulation_steps=1,
sample_num_steps=50,
sample_batch_size=6,
train_batch_size=3,
sample_num_batches_per_epoch=4,
per_prompt_stat_tracking=True,
per_prompt_stat_tracking_buffer_size=32,
tracker_project_name="stable_diffusion_training",
log_with="wandb",
project_kwargs={
"logging_dir": "./logs",
"automatic_checkpoint_naming": True,
"total_limit": 5,
"project_dir": "./save",
},
)
pretrained_model: str = field(
default="runwayml/stable-diffusion-v1-5", metadata={"help": "the pretrained model to use"}
)
pretrained_revision: str = field(default="main", metadata={"help": "the pretrained model revision to use"})
hf_hub_model_id: str = field(
default="ddpo-finetuned-stable-diffusion", metadata={"help": "HuggingFace repo to save model weights to"}
)
hf_hub_aesthetic_model_id: str = field(
default="trl-lib/ddpo-aesthetic-predictor",
metadata={"help": "HuggingFace model ID for aesthetic scorer model weights"},
)
hf_hub_aesthetic_model_filename: str = field(
default="aesthetic-model.pth",
metadata={"help": "HuggingFace model filename for aesthetic scorer model weights"},
)
use_lora: bool = field(default=True, metadata={"help": "Whether to use LoRA."})
class MLP(nn.Module):
@@ -101,7 +93,7 @@ class AestheticScorer(torch.nn.Module):
cached_path = hf_hub_download(model_id, model_filename)
except EntryNotFoundError:
cached_path = os.path.join(model_id, model_filename)
state_dict = torch.load(cached_path, map_location=torch.device("cpu"))
state_dict = torch.load(cached_path, map_location=torch.device("cpu"), weights_only=True)
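        # Editor's note (not part of the original diff): weights_only=True restricts
        # torch.load to plain tensor/primitive data, avoiding arbitrary pickle code
        # execution when loading untrusted checkpoints.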
self.mlp.load_state_dict(state_dict)
self.dtype = dtype
self.eval()
@@ -183,7 +175,7 @@ def image_outputs_logger(image_data, global_step, accelerate_logger):
for i, image in enumerate(images):
prompt = prompts[i]
reward = rewards[i].item()
result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0)
result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0).float()
accelerate_logger.log_images(
result,
@@ -192,14 +184,21 @@ def image_outputs_logger(image_data, global_step, accelerate_logger):
if __name__ == "__main__":
args = tyro.cli(ScriptArguments)
parser = HfArgumentParser((ScriptArguments, DDPOConfig))
args, ddpo_config = parser.parse_args_into_dataclasses()
ddpo_config.project_kwargs = {
"logging_dir": "./logs",
"automatic_checkpoint_naming": True,
"total_limit": 5,
"project_dir": "./save",
}
pipeline = DefaultDDPOStableDiffusionPipeline(
args.pretrained_model, pretrained_model_revision=args.pretrained_revision, use_lora=args.use_lora
)
trainer = DDPOTrainer(
args.ddpo_config,
ddpo_config,
aesthetic_scorer(args.hf_hub_aesthetic_model_id, args.hf_hub_aesthetic_model_filename),
prompt_fn,
pipeline,
@@ -208,4 +207,4 @@ if __name__ == "__main__":
trainer.train()
trainer.push_to_hub(args.hf_hub_model_id, token=args.hf_user_access_token)
trainer.push_to_hub(args.hf_hub_model_id)

examples/scripts/dpo.py

@@ -1,4 +1,4 @@
# coding=utf-8
# flake8: noqa
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,187 +12,177 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# regular:
python examples/scripts/dpo.py \
--dataset_name=trl-internal-testing/hh-rlhf-helpful-base-trl-style \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--learning_rate 1e-3 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="dpo_anthropic_hh" \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns
# Note: you need to install transformers from main to run this script. See https://huggingface.co/docs/transformers/installation#install-from-source
# TODO: bump transformers version in requirements at next release.
# peft:
python examples/scripts/dpo.py \
--dataset_name=trl-internal-testing/hh-rlhf-helpful-base-trl-style \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--learning_rate 1e-3 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="dpo_anthropic_hh" \
--optim rmsprop \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns \
--use_peft \
--lora_r=16 \
--lora_alpha=16
"""
# 0. imports
from dataclasses import dataclass, field
from typing import Dict, Optional
import logging
import multiprocessing
import os
from contextlib import nullcontext
from trl.commands.cli_utils import DPOScriptArguments, init_zero_verbose, TrlParser
from trl.env_utils import strtobool
TRL_USE_RICH = strtobool(os.getenv("TRL_USE_RICH", "0"))
if TRL_USE_RICH:
init_zero_verbose()
FORMAT = "%(message)s"
from rich.console import Console
from rich.logging import RichHandler
import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
from trl import DPOTrainer
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import PartialState
from trl import (
DPOConfig,
DPOTrainer,
ModelConfig,
RichProgressCallback,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
# Define and parse arguments.
@dataclass
class ScriptArguments:
"""
The arguments for the DPO training script.
"""
# data parameters
beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
# training parameters
model_name_or_path: Optional[str] = field(default="gpt2", metadata={"help": "the model name"})
learning_rate: Optional[float] = field(default=1e-3, metadata={"help": "optimizer learning rate"})
per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "batch size per device"})
gradient_accumulation_steps: Optional[int] = field(
default=1, metadata={"help": "the number of gradient accumulation steps"}
)
max_length: Optional[int] = field(default=512, metadata={"help": "max length of each sample"})
max_prompt_length: Optional[int] = field(default=128, metadata={"help": "max length of each sample's prompt"})
    max_target_length: Optional[int] = field(
        default=128, metadata={"help": "Only used for encoder-decoder models. Max length of each sample's target"}
    )
label_pad_token_id: Optional[int] = field(default=-100, metadata={"help": "label for non response tokens"})
max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
# lora parameters
    use_peft: Optional[bool] = field(default=True, metadata={"help": "Whether to use PEFT to train adapters"})
peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
# instrumentation
sanity_check: Optional[bool] = field(default=True, metadata={"help": "only train on 1000 samples"})
report_to: Optional[str] = field(
default=None,
metadata={
"help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
'`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
},
)
# debug argument for distributed training
ignore_bias_buffers: Optional[bool] = field(
default=False,
metadata={
"help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)
gradient_checkpointing: Optional[bool] = field(
default=False, metadata={"help": "Whether to use gradient checkpointing or no"}
)
gradient_checkpointing_kwargs: Optional[dict] = field(
default=None,
metadata={
"help": "key word arguments to be passed along `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
},
)
def extract_anthropic_prompt(prompt_and_response):
"""Extract the anthropic prompt from a prompt and response pair."""
search_term = "\n\nAssistant:"
search_term_idx = prompt_and_response.rfind(search_term)
assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
return prompt_and_response[: search_term_idx + len(search_term)]
def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
"""Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.
The dataset is converted to a dictionary with the following structure:
{
'prompt': List[str],
'chosen': List[str],
'rejected': List[str],
}
Prompts should be structured as follows:
\n\nHuman: <prompt>\n\nAssistant:
Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
"""
dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
if sanity_check:
dataset = dataset.select(range(min(len(dataset), 1000)))
def split_prompt_and_responses(sample) -> Dict[str, str]:
prompt = extract_anthropic_prompt(sample["chosen"])
return {
"prompt": prompt,
"chosen": sample["chosen"][len(prompt) :],
"rejected": sample["rejected"][len(prompt) :],
}
return dataset.map(split_prompt_and_responses)
if TRL_USE_RICH:
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)
if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
parser = TrlParser((DPOScriptArguments, DPOConfig, ModelConfig))
args, training_args, model_config = parser.parse_args_and_config()
# 1. load a pretrained model
model = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)
# Force use our print callback
if TRL_USE_RICH:
training_args.disable_tqdm = True
console = Console()
if script_args.ignore_bias_buffers:
################
# Model & Tokenizer
################
torch_dtype = (
model_config.torch_dtype
if model_config.torch_dtype in ["auto", None]
else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
attn_implementation=model_config.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
)
peft_config = get_peft_config(model_config)
if peft_config is None:
ref_model = AutoModelForCausalLM.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
)
else:
ref_model = None
tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if tokenizer.chat_template is None:
tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{{ eos_token }}"
if args.ignore_bias_buffers:
# torch distributed hack
model._ddp_params_and_buffers_to_ignore = [
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
]
model_ref = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# 2. Load the Anthropic Helpful-Harmless dataset
train_dataset = get_hh("train", sanity_check=script_args.sanity_check)
# 3. Load evaluation dataset
eval_dataset = get_hh("test", sanity_check=script_args.sanity_check)
# 4. initialize training arguments:
training_args = TrainingArguments(
per_device_train_batch_size=script_args.per_device_train_batch_size,
max_steps=script_args.max_steps,
remove_unused_columns=False,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
learning_rate=script_args.learning_rate,
evaluation_strategy="steps",
logging_first_step=True,
logging_steps=10, # match results in blog post
eval_steps=500,
output_dir="./test",
optim="rmsprop",
warmup_steps=150,
report_to=script_args.report_to,
bf16=True,
gradient_checkpointing=script_args.gradient_checkpointing,
# TODO: uncomment that on the next transformers release
# gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
################
# Optional rich context managers
###############
init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the DPOTrainer...")
save_context = (
nullcontext()
if not TRL_USE_RICH
else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
)
if script_args.use_peft:
peft_config = LoraConfig(
r=script_args.peft_lora_r,
lora_alpha=script_args.peft_lora_alpha,
bias="none",
task_type="CAUSAL_LM",
################
# Dataset
################
ds = load_dataset(args.dataset_name)
if args.sanity_check:
for key in ds:
ds[key] = ds[key].select(range(50))
def process(row):
row["prompt"] = tokenizer.apply_chat_template(row["chosen"][:-1], tokenize=False)
row["chosen"] = tokenizer.apply_chat_template([row["chosen"][-1]], tokenize=False)
row["rejected"] = tokenizer.apply_chat_template([row["rejected"][-1]], tokenize=False)
return row
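    # Editor's illustration (not part of the original script): for a row whose
    # "chosen" is [{"role": "user", "content": "Q"}, {"role": "assistant", "content": "A"}],
    # process() puts the templated conversation up to "Q" in row["prompt"] and only
    # the templated final assistant turn in row["chosen"] / row["rejected"].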
    # Compute this only on the main process, for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
ds = ds.map(process, num_proc=training_args.dataset_num_proc)
train_dataset = ds[args.dataset_train_split]
eval_dataset = ds[args.dataset_test_split]
################
# Training
################
with init_context:
trainer = DPOTrainer(
model,
ref_model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
peft_config=peft_config,
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
)
else:
peft_config = None
# 5. initialize the DPO trainer
dpo_trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=script_args.beta,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
max_length=script_args.max_length,
max_target_length=script_args.max_target_length,
max_prompt_length=script_args.max_prompt_length,
generate_during_eval=True,
peft_config=peft_config,
)
trainer.train()
# 6. train
dpo_trainer.train()
with save_context:
trainer.save_model(training_args.output_dir)

examples/scripts/dpo_online.py (new file)

@@ -0,0 +1,108 @@
# flake8: noqa
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
python examples/scripts/dpo_online.py \
--model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
--reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
--dataset_name trl-lib/tldr \
--learning_rate 5.0e-7 \
--output_dir pythia-1b-tldr-online-dpo \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 32 \
--num_train_epochs 3 \
--completion_length 53 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0 \
--push_to_hub
"""
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from accelerate import PartialState
from trl import (
DPOScriptArguments,
ModelConfig,
OnlineDPOConfig,
OnlineDPOTrainer,
get_kbit_device_map,
get_quantization_config,
)
from trl.commands.cli_utils import TrlParser
from trl.trainer.callbacks import LogCompletionsCallback
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
if __name__ == "__main__":
parser = TrlParser((DPOScriptArguments, OnlineDPOConfig, ModelConfig))
args, training_args, model_config = parser.parse_args_and_config()
args.gradient_checkpointing_kwargs = {"use_reentrant": True}
torch_dtype = (
model_config.torch_dtype
if model_config.torch_dtype in ["auto", None]
else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
attn_implementation=model_config.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
)
ref_model = AutoModelForCausalLM.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
)
reward_model = AutoModelForSequenceClassification.from_pretrained(
training_args.reward_model_path, num_labels=1, trust_remote_code=model_config.trust_remote_code
)
tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path,
padding_side="left",
trust_remote_code=model_config.trust_remote_code,
)
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
dataset = load_dataset(args.dataset_name)
def prepare_dataset(row):
row["prompt"] = tokenizer.apply_chat_template(row["prompt"], tokenize=False, add_generation_prompt=True)
return row
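    # Editor's note (not part of the original script): add_generation_prompt=True
    # appends the assistant turn header to the templated prompt, so generation
    # continues inside an assistant turn instead of predicting a new role tag.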
with PartialState().local_main_process_first():
dataset = dataset.map(prepare_dataset, num_proc=training_args.dataset_num_proc)
prompts = dataset[args.dataset_test_split]["prompt"][:8]
trainer = OnlineDPOTrainer(
model=model,
ref_model=ref_model,
reward_model=reward_model,
args=training_args,
train_dataset=dataset[args.dataset_train_split],
eval_dataset=dataset[args.dataset_test_split],
tokenizer=tokenizer,
)
log_completions_callback = LogCompletionsCallback(prompts)
trainer.add_callback(log_completions_callback)
trainer.train()

examples/scripts/dpo_visual.py (new file)

@@ -0,0 +1,178 @@
# flake8: noqa
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
accelerate launch examples/scripts/dpo_visual.py \
--dataset_name HuggingFaceH4/rlaif-v_formatted \
--model_name_or_path HuggingFaceM4/idefics2-8b \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 32 \
--dataset_num_proc 32 \
--output_dir dpo_idefics_rlaif-v \
--bf16 \
--torch_dtype bfloat16 \
--gradient_checkpointing \
--use_peft \
--lora_target_modules=all-linear
"""
import logging
import os
from contextlib import nullcontext
from trl.env_utils import strtobool

TRL_USE_RICH = strtobool(os.getenv("TRL_USE_RICH", "0"))

from trl.commands.cli_utils import DPOScriptArguments, init_zero_verbose, TrlParser
from accelerate import PartialState
if TRL_USE_RICH:
init_zero_verbose()
FORMAT = "%(message)s"
from rich.console import Console
from rich.logging import RichHandler
import torch
from datasets import load_dataset
from transformers import AutoModelForVision2Seq, AutoProcessor
from trl import (
DPOConfig,
DPOTrainer,
ModelConfig,
RichProgressCallback,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
if TRL_USE_RICH:
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)
if __name__ == "__main__":
parser = TrlParser((DPOScriptArguments, DPOConfig, ModelConfig))
args, training_args, model_config = parser.parse_args_and_config()
# Force use our print callback
if TRL_USE_RICH:
training_args.disable_tqdm = True
console = Console()
################
# Model & Tokenizer
################
torch_dtype = (
model_config.torch_dtype
if model_config.torch_dtype in ["auto", None]
else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
attn_implementation=model_config.attn_implementation,
torch_dtype=torch_dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
model = AutoModelForVision2Seq.from_pretrained(
model_config.model_name_or_path,
trust_remote_code=model_config.trust_remote_code,
**model_kwargs,
)
peft_config = get_peft_config(model_config)
if peft_config is None:
ref_model = AutoModelForVision2Seq.from_pretrained(
model_config.model_name_or_path,
trust_remote_code=model_config.trust_remote_code,
**model_kwargs,
)
else:
ref_model = None
processor = AutoProcessor.from_pretrained(
model_config.model_name_or_path,
trust_remote_code=model_config.trust_remote_code,
do_image_splitting=False,
)
tokenizer = processor.tokenizer
# Set up the chat template
if model.config.model_type == "idefics2":
pass # the processor already has a valid chat template
elif model.config.model_type == "paligemma":
processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}<|im_start|>{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] if item['type'] == 'text' %}{{ item['text'] }}<|im_end|>{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
elif model.config.model_type == "llava":
processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if args.ignore_bias_buffers:
# torch distributed hack
model._ddp_params_and_buffers_to_ignore = [
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
]
################
# Optional rich context managers
###############
init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the DPOTrainer...")
save_context = (
nullcontext()
if not TRL_USE_RICH
else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
)
################
# Dataset
################
ds = load_dataset(args.dataset_name)
if args.sanity_check:
for key in ds:
ds[key] = ds[key].select(range(50))
def process(row):
row["prompt"] = processor.apply_chat_template(row["prompt"], tokenize=False)
row["chosen"] = processor.apply_chat_template(row["chosen"], tokenize=False)
row["rejected"] = processor.apply_chat_template(row["rejected"], tokenize=False)
return row
    # Compute this only on the main process, for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
ds = ds.map(process, num_proc=training_args.dataset_num_proc)
train_dataset = ds[args.dataset_train_split]
eval_dataset = ds[args.dataset_test_split]
################
# Training
################
with init_context:
trainer = DPOTrainer(
model,
ref_model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=processor,
peft_config=peft_config,
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
)
trainer.train()
with save_context:
trainer.save_model(training_args.output_dir)

examples/scripts/evals/judge_tldr.py (new file)

@@ -0,0 +1,74 @@
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from transformers import HfArgumentParser
from vllm import LLM, SamplingParams
from trl import HfPairwiseJudge, OpenAIPairwiseJudge
"""
Examples:
python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --num_examples 1000
Model win rate: 31.40%
python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --judge_model gpt-3.5-turbo-0125 --num_examples 1000
Model win rate: 51.60%
python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 51.20%
python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --num_examples 1000
Model win rate: 46.30%
python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --judge_model gpt-3.5-turbo-0125 --num_examples 1000
Model win rate: 52.50%
python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 63.00%
"""
@dataclass
class ScriptArguments:
model_name_or_path: str = field(metadata={"help": "The model name or path to the model to evaluate."})
judge_model: str = field(
default="meta-llama/Meta-Llama-3-70B-Instruct",
metadata={
"help": "The model name or path to the model to use as a judge. E.g., 'gpt-3.5-turbo-0125', 'meta-llama/Meta-Llama-3-70B-Instruct'."
},
)
num_examples: Optional[int] = field(default=None, metadata={"help": "The number of examples to evaluate."})
# Parse the arguments
parser = HfArgumentParser(ScriptArguments)
args = parser.parse_args_into_dataclasses()[0]
# Load the dataset
raw_dataset = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style", split="test")
if args.num_examples is not None:
raw_dataset = raw_dataset.select(range(args.num_examples))
# Extract the prompts and reference completions
prompts = raw_dataset["prompt"]
reference_completions = [message[-1]["content"] for message in raw_dataset["messages"]]
# Generate the model completions
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=200) # very generous max token length
llm = LLM(model=args.model_name_or_path, tensor_parallel_size=1)
outputs = llm.generate(prompts, sampling_params)
model_completions = [output.outputs[0].text.strip() for output in outputs]
# Judge the outputs
if "gpt" in args.judge_model:
judge = OpenAIPairwiseJudge(args.judge_model)
else:
judge = HfPairwiseJudge(args.judge_model)
completions = [[c0, c1] for c0, c1 in zip(reference_completions, model_completions)]
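# Editor's note (not part of the original script): each pair is ordered
# [reference_completion, model_completion], so a verdict of 1 means the judge
# preferred the model's completion; the win rate below counts those verdicts.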
best_idxs = judge.judge(prompts, completions)
model_win_rate = best_idxs.count(1) / len(best_idxs)
print(f"Model win rate: {model_win_rate*100:.2f}%")

examples/scripts/kto.py (new file)

@@ -0,0 +1,125 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO.
# Full training:
python examples/scripts/kto.py \
--model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
--per_device_train_batch_size 16 \
--num_train_epochs 1 \
--learning_rate 1e-5 \
--lr_scheduler_type=cosine \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir=kto-aligned-model \
--warmup_ratio 0.1 \
--report_to wandb \
--bf16 \
--logging_first_step
# QLoRA:
python examples/scripts/kto.py \
--model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
--per_device_train_batch_size 8 \
--num_train_epochs 1 \
--learning_rate 1e-4 \
--lr_scheduler_type=cosine \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir=kto-aligned-model-lora \
--warmup_ratio 0.1 \
--report_to wandb \
--bf16 \
--logging_first_step \
--use_peft \
--load_in_4bit \
--lora_target_modules=all-linear \
--lora_r=16 \
--lora_alpha=16
"""
from dataclasses import dataclass
from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, setup_chat_format
# Define and parse arguments.
@dataclass
class ScriptArguments:
"""
The arguments for the KTO training script.
"""
dataset_name: str = "trl-lib/kto-mix-14k"
if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
script_args, kto_args, model_args = parser.parse_args_into_dataclasses()
# Load a pretrained model
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)
ref_model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# If we are aligning a base model, we use ChatML as the default template
if tokenizer.chat_template is None:
model, tokenizer = setup_chat_format(model, tokenizer)
# Load the dataset
dataset = load_dataset(script_args.dataset_name)
# Apply chat template
def format_dataset(example):
example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
return example
    # Compute this only on the main process, for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
formatted_dataset = dataset.map(format_dataset, num_proc=kto_args.dataset_num_proc)
# Initialize the KTO trainer
kto_trainer = KTOTrainer(
model,
ref_model,
args=kto_args,
train_dataset=formatted_dataset["train"],
eval_dataset=formatted_dataset["test"],
tokenizer=tokenizer,
peft_config=get_peft_config(model_args),
)
# Train and push the model to the Hub
kto_trainer.train()
kto_trainer.save_model(kto_args.output_dir)
kto_trainer.push_to_hub()

examples/scripts/orpo.py (new file)

@@ -0,0 +1,126 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run the ORPO training script with one of the following commands, which include some example arguments.
In general, the optimal configuration for ORPO will be similar to that of DPO, without the need for a reference model:
# regular:
python examples/scripts/orpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 8e-6 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="gpt2-aligned-orpo" \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns
# peft:
python examples/scripts/orpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 8e-5 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="gpt2-lora-aligned-orpo" \
--optim rmsprop \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns \
--use_peft \
--lora_r=16 \
--lora_alpha=16
"""
from dataclasses import dataclass, field
from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import ModelConfig, ORPOConfig, ORPOTrainer, get_peft_config
@dataclass
class ScriptArguments:
dataset: str = field(
default="trl-internal-testing/hh-rlhf-helpful-base-trl-style",
metadata={"help": "The name of the dataset to use."},
)
if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, ORPOConfig, ModelConfig))
args, orpo_args, model_config = parser.parse_args_into_dataclasses()
################
# Model & Tokenizer
################
model = AutoModelForCausalLM.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
)
tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
################
# Dataset
################
ds = load_dataset(args.dataset)
if orpo_args.debug:
for key in ds:
ds[key] = ds[key].select(range(50))
if tokenizer.chat_template is None:
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
def process(row):
row["prompt"] = tokenizer.apply_chat_template(row["chosen"][:-1], tokenize=False)
row["chosen"] = tokenizer.apply_chat_template([row["chosen"][-1]], tokenize=False)
row["rejected"] = tokenizer.apply_chat_template([row["rejected"][-1]], tokenize=False)
return row
    # Compute this only on the main process, for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
        ds = ds.map(process, num_proc=orpo_args.dataset_num_proc)
train_dataset = ds["train"]
eval_dataset = ds["test"]
################
# Training
################
trainer = ORPOTrainer(
model,
args=orpo_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
peft_config=get_peft_config(model_config),
)
# train and save the model
trainer.train()
trainer.save_model(orpo_args.output_dir)

examples/scripts/ppo.py

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,16 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python examples/scripts/ppo.py \
--log_with=wandb
"""
from dataclasses import dataclass, field
from typing import Optional
import torch
import tyro
from accelerate import Accelerator
from accelerate import Accelerator, PartialState
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoTokenizer, pipeline
from transformers import AutoTokenizer, HfArgumentParser, pipeline
from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer, set_seed
from trl.core import LengthSampler
@@ -33,42 +35,17 @@ tqdm.pandas()
@dataclass
class ScriptArguments:
ppo_config: PPOConfig = field(
default_factory=lambda: PPOConfig(
model_name="lvwerra/gpt2-imdb",
query_dataset="imdb",
reward_model="sentiment-analysis:lvwerra/distilbert-imdb",
learning_rate=1.41e-5,
log_with=None,
mini_batch_size=128,
batch_size=128,
gradient_accumulation_steps=1,
early_stopping=False,
target_kl=6.0,
kl_penalty="kl",
seed=0,
use_score_scaling=False,
use_score_norm=False,
score_clip=None,
)
)
use_seq2seq: bool = False
"""whether to use seq2seq models"""
use_peft: bool = False
"""whether to use peft"""
peft_config: Optional[LoraConfig] = field(
default_factory=lambda: LoraConfig(
r=16,
lora_alpha=16,
bias="none",
task_type="CAUSAL_LM",
),
)
use_seq2seq: bool = field(default=False, metadata={"help": "whether to use seq2seq"})
trust_remote_code: bool = field(default=False, metadata={"help": "Enable `trust_remote_code`"})
# LoraConfig
use_peft: bool = field(default=False, metadata={"help": "whether to use peft"})
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
lora_r: Optional[int] = field(default=16, metadata={"help": "the lora r parameter"})
args = tyro.cli(ScriptArguments)
parser = HfArgumentParser((ScriptArguments, PPOConfig))
args, ppo_config = parser.parse_args_into_dataclasses()
# We then define the arguments to pass to the sentiment analysis pipeline.
# We set `return_all_scores` to True to get the score for every sentiment class.
@@ -76,11 +53,14 @@ sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_si
trl_model_class = AutoModelForCausalLMWithValueHead if not args.use_seq2seq else AutoModelForSeq2SeqLMWithValueHead
tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name)
tokenizer.pad_token = tokenizer.eos_token
# Below is an example function to build the dataset. In our case, we use the IMDB dataset
# from the `datasets` library. One should customize this function to train the model on
# its own dataset.
def build_dataset(config, query_dataset, input_min_text_length=2, input_max_text_length=8):
def build_dataset(query_dataset, dataset_num_proc, input_min_text_length=2, input_max_text_length=8):
"""
Build dataset for training. This builds the dataset from `load_dataset`, one should
customize this function to train the model on its own dataset.
@@ -93,12 +73,10 @@ def build_dataset(config, query_dataset, input_min_text_length=2, input_max_text
dataloader (`torch.utils.data.DataLoader`):
The dataloader for the dataset.
"""
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token
# load imdb with datasets
ds = load_dataset(query_dataset, split="train")
ds = ds.rename_columns({"text": "review"})
ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)
ds = ds.filter(lambda x: len(x["review"]) > 200, num_proc=dataset_num_proc)
input_size = LengthSampler(input_min_text_length, input_max_text_length)
@@ -107,48 +85,56 @@ def build_dataset(config, query_dataset, input_min_text_length=2, input_max_text
sample["query"] = tokenizer.decode(sample["input_ids"])
return sample
ds = ds.map(tokenize, batched=False)
ds = ds.map(tokenize, num_proc=dataset_num_proc)
ds.set_format(type="torch")
return ds
# We retrieve the dataloader by calling the `build_dataset` function.
dataset = build_dataset(args.ppo_config, args.ppo_config.query_dataset)
# Compute this only on the main process, for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
dataset = build_dataset(ppo_config.query_dataset, ppo_config.dataset_num_proc)
def collator(data):
return dict((key, [d[key] for d in data]) for key in data[0])
return {key: [d[key] for d in data] for key in data[0]}
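# Editor's illustration (not part of the original script): the collator turns a
# list of feature dicts into a dict of lists, e.g.
#     collator([{"query": "a"}, {"query": "b"}]) == {"query": ["a", "b"]}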
# set seed before initializing value head for deterministic eval
set_seed(args.ppo_config.seed)
set_seed(ppo_config.seed)
# Now let's build the model, the reference model, and the tokenizer.
if not args.use_peft:
ref_model = trl_model_class.from_pretrained(args.ppo_config.model_name, trust_remote_code=args.trust_remote_code)
ref_model = trl_model_class.from_pretrained(ppo_config.model_name, trust_remote_code=args.trust_remote_code)
device_map = None
peft_config = None
else:
peft_config = args.peft_config
peft_config = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
bias="none",
task_type="CAUSAL_LM",
)
ref_model = None
# Copy the model to each device
device_map = {"": Accelerator().local_process_index}
model = trl_model_class.from_pretrained(
args.ppo_config.model_name,
ppo_config.model_name,
trust_remote_code=args.trust_remote_code,
device_map=device_map,
peft_config=peft_config,
)
tokenizer = AutoTokenizer.from_pretrained(args.ppo_config.model_name)
tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name)
# Some tokenizers like GPT-2's don't have a padding token by default, so we set one here.
tokenizer.pad_token_id = tokenizer.eos_token_id
# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
ppo_trainer = PPOTrainer(args.ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)
ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)
# We then build the sentiment analysis pipeline, passing the model name and the
# sentiment analysis pipeline arguments. Let's also make sure to set the device
@@ -162,7 +148,7 @@ if ppo_trainer.accelerator.num_processes == 1:
else:
device = 0 if torch.cuda.is_available() else "cpu" # to avoid a `pipeline` bug
ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
task, model_name = args.ppo_config.reward_model.split(":")
task, model_name = ppo_config.reward_model.split(":")
if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
with ds_plugin.zero3_init_context_manager(enable=False):
sentiment_pipe = pipeline(task, model=model_name, device=device)
@@ -188,7 +174,7 @@ generation_kwargs = {
"max_new_tokens": 32,
}
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
for batch in tqdm(ppo_trainer.dataloader):
query_tensors = batch["input_ids"]
# Get response from gpt2

examples/scripts/ppo/ppo.py (new file)

@@ -0,0 +1,123 @@
import shutil
from accelerate import PartialState
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
HfArgumentParser,
)
from trl import ModelConfig
from trl.trainer.ppov2_trainer import PPOv2Config, PPOv2Trainer
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
"""
python -i examples/scripts/ppo/ppo.py \
--learning_rate 3e-6 \
--output_dir models/minimal/ppo \
--per_device_train_batch_size 64 \
--gradient_accumulation_steps 1 \
--total_episodes 10000 \
--model_name_or_path EleutherAI/pythia-1b-deduped \
--non_eos_penalty \
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
examples/scripts/ppo/ppo.py \
--output_dir models/minimal/ppo \
--num_ppo_epochs 1 \
--num_mini_batches 1 \
--learning_rate 3e-6 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--total_episodes 10000 \
--model_name_or_path EleutherAI/pythia-1b-deduped \
--sft_model_path EleutherAI/pythia-1b-deduped \
--reward_model_path EleutherAI/pythia-1b-deduped \
--local_rollout_forward_batch_size 1 \
--deepspeed3 \
--non_eos_penalty \
"""
if __name__ == "__main__":
parser = HfArgumentParser((PPOv2Config, ModelConfig))
config, model_config = parser.parse_args_into_dataclasses()
# remove output_dir if exists
shutil.rmtree(config.output_dir, ignore_errors=True)
################
# Model & Tokenizer
################
tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path,
padding_side="left",
trust_remote_code=model_config.trust_remote_code,
)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
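# SIMPLE_QUERY_CHAT_TEMPLATE is presumably a minimal fallback template so that
# base models shipping without a chat template can still be used with
# chat-template-based formatting; note that "[PAD]" may be a brand-new token,
# which grows the tokenizer vocabulary.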
value_model = AutoModelForSequenceClassification.from_pretrained(
config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
)
reward_model = AutoModelForSequenceClassification.from_pretrained(
config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
)
ref_policy = AutoModelForCausalLM.from_pretrained(
config.sft_model_path, trust_remote_code=model_config.trust_remote_code
)
policy = AutoModelForCausalLM.from_pretrained(
config.sft_model_path, trust_remote_code=model_config.trust_remote_code
)
################
# Dataset
################
raw_datasets = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness")
eval_samples = 20
train_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples))
eval_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples, len(raw_datasets)))
dataset_text_field = "prompt"
def prepare_dataset(dataset, tokenizer):
"""pre-tokenize the dataset before training; only collate during training"""
def tokenize(element):
outputs = tokenizer(
element[dataset_text_field],
padding=False,
)
return {"input_ids": outputs["input_ids"]}
return dataset.map(
tokenize,
batched=True,
remove_columns=dataset.column_names,
num_proc=config.dataset_num_proc,
)
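# Illustration (hypothetical values): after prepare_dataset, each example is
# reduced to {"input_ids": [...]}, e.g. {"input_ids": [5661, 318, 257]};
# the actual ids depend on the tokenizer.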
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
train_dataset = prepare_dataset(train_dataset, tokenizer)
eval_dataset = prepare_dataset(eval_dataset, tokenizer)
################
# Training
################
trainer = PPOv2Trainer(
config=config,
tokenizer=tokenizer,
policy=policy,
ref_policy=ref_policy,
reward_model=reward_model,
value_model=value_model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()
trainer.save_model(config.output_dir)
if config.push_to_hub:
trainer.push_to_hub()
trainer.generate_completions()

examples/scripts/ppo/ppo_tldr.py (new file)

@@ -0,0 +1,132 @@
import shutil
from accelerate import PartialState
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
HfArgumentParser,
)
from trl import ModelConfig
from trl.trainer.ppov2_trainer import PPOv2Config, PPOv2Trainer
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
"""
python examples/scripts/ppo/ppo_tldr.py \
--learning_rate 3e-6 \
--output_dir models/minimal/ppo \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 64 \
--total_episodes 30000 \
--model_name_or_path EleutherAI/pythia-1b-deduped \
--sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--non_eos_penalty \
--stop_token eos \
--response_length 53 \
--sanity_check

accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
examples/scripts/ppo/ppo_tldr.py \
--output_dir models/minimal/ppo_tldr \
--learning_rate 3e-6 \
--per_device_train_batch_size 16 \
--gradient_accumulation_steps 4 \
--total_episodes 1000000 \
--model_name_or_path EleutherAI/pythia-1b-deduped \
--sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--local_rollout_forward_batch_size 16 \
--non_eos_penalty \
--stop_token eos
"""
if __name__ == "__main__":
parser = HfArgumentParser((PPOv2Config, ModelConfig))
config, model_config = parser.parse_args_into_dataclasses()
# remove output_dir if exists
shutil.rmtree(config.output_dir, ignore_errors=True)
################
# Model & Tokenizer
################
tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path,
padding_side="left",
trust_remote_code=model_config.trust_remote_code,
)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
value_model = AutoModelForSequenceClassification.from_pretrained(
config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
)
reward_model = AutoModelForSequenceClassification.from_pretrained(
config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
)
ref_policy = AutoModelForCausalLM.from_pretrained(
config.sft_model_path, trust_remote_code=model_config.trust_remote_code
)
policy = AutoModelForCausalLM.from_pretrained(
config.sft_model_path, trust_remote_code=model_config.trust_remote_code
)
################
# Dataset
################
raw_datasets = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style")
if config.sanity_check:
for key in raw_datasets:
raw_datasets[key] = raw_datasets[key].select(range(1000))
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["validation"]
def prepare_dataset(dataset, tokenizer):
"""pre-tokenize the dataset before training; only collate during training"""
def tokenize(element):
input_ids = tokenizer.apply_chat_template(
element["messages"][:1],
padding=False,
add_generation_prompt=True,
)
return {"input_ids": input_ids, "lengths": len(input_ids)}
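# The extra "lengths" field is kept only so that over-long prompts can be
# dropped by the <= 512 filter applied after tokenization below.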
return dataset.map(
tokenize,
remove_columns=dataset.column_names,
load_from_cache_file=not config.sanity_check,
num_proc=config.dataset_num_proc,
)
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
train_dataset = prepare_dataset(train_dataset, tokenizer)
eval_dataset = prepare_dataset(eval_dataset, tokenizer)
# filtering
train_dataset = train_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=config.dataset_num_proc)
eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=config.dataset_num_proc)
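# Sanity check: the prompt should end with the generation prompt rather than
# EOS; an EOS-terminated prompt would ask the policy to continue a sequence
# that is already finished.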
assert train_dataset[0]["input_ids"][-1] != tokenizer.eos_token_id, "The last token should not be an EOS token"
################
# Training
################
trainer = PPOv2Trainer(
config=config,
tokenizer=tokenizer,
policy=policy,
ref_policy=ref_policy,
reward_model=reward_model,
value_model=value_model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()
trainer.save_model(config.output_dir)
if config.push_to_hub:
trainer.push_to_hub()
trainer.generate_completions()


@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,6 +15,7 @@ from dataclasses import dataclass, field
from typing import Optional
import torch
from accelerate import PartialState
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
@@ -49,13 +49,16 @@ class ScriptArguments:
default=False, metadata={"help": "Use score normalization. Only applicable if use_score_scaling is True"}
)
score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"})
dataset_num_proc: Optional[int] = field(
default=None, metadata={"help": "The number of workers to use to tokenize the data"}
)
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
def create_and_prepare_dataset(tokenizer):
def create_and_prepare_dataset(tokenizer, num_proc):
dataset = load_dataset(script_args.dataset_name, split="train[:1%]")
input_size = LengthSampler(input_min_text_length, input_max_text_length)
@@ -66,7 +69,7 @@ def create_and_prepare_dataset(tokenizer):
example["query"] = tokenizer.decode(example["input_ids"])
return example
dataset = dataset.map(tokenize, batched=False)
dataset = dataset.map(tokenize, batched=False, num_proc=num_proc)
dataset.set_format("torch")
return dataset
@@ -93,11 +96,14 @@ tokenizer = AutoTokenizer.from_pretrained(script_args.model_name)
tokenizer.pad_token = tokenizer.eos_token
dataset = create_and_prepare_dataset(tokenizer)
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
dataset = create_and_prepare_dataset(tokenizer, script_args.dataset_num_proc)
def collator(data):
return dict((key, [d[key] for d in data]) for key in data[0])
return {key: [d[key] for d in data] for key in data[0]}
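# Toy illustration (hypothetical data):
# collator([{"input_ids": [1, 2], "query": "a"}, {"input_ids": [3], "query": "b"}])
# -> {"input_ids": [[1, 2], [3]], "query": ["a", "b"]}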
config = PPOConfig(
@@ -131,7 +137,7 @@ generation_kwargs = {
"max_new_tokens": 32,
}
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
for _epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
question_tensors = batch["input_ids"]
response_tensors = ppo_trainer.generate(

examples/scripts/reward_modeling.py

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,162 +11,126 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
"""
python examples/scripts/reward_modeling.py \
--model_name_or_path=facebook/opt-350m \
--output_dir="reward_modeling_anthropic_hh" \
--per_device_train_batch_size=16 \
--num_train_epochs=1 \
--gradient_accumulation_steps=2 \
--gradient_checkpointing=True \
--learning_rate=1.41e-5 \
--report_to="wandb" \
--remove_unused_columns=False \
--optim="adamw_torch" \
--logging_steps=10 \
--eval_strategy="steps" \
--eval_steps=500 \
--max_length=512
"""
import warnings
import tyro
from accelerate import Accelerator
import torch
from accelerate import PartialState
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
from transformers import AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser
from trl import RewardConfig, RewardTrainer, is_xpu_available
from trl import ModelConfig, RewardConfig, RewardTrainer, get_kbit_device_map, get_peft_config, get_quantization_config
tqdm.pandas()
@dataclass
class ScriptArguments:
model_name: str = "facebook/opt-350m"
"""the model name"""
dataset_name: str = "Anthropic/hh-rlhf"
"""the dataset name"""
dataset_text_field: str = "text"
"""the text field of the dataset"""
eval_split: str = "none"
"""the dataset split to evaluate on; default to 'none' (no evaluation)"""
load_in_8bit: bool = False
"""load the model in 8 bits precision"""
load_in_4bit: bool = False
"""load the model in 4 bits precision"""
trust_remote_code: bool = True
"""Enable `trust_remote_code`"""
reward_config: RewardConfig = field(
default_factory=lambda: RewardConfig(
output_dir="output",
per_device_train_batch_size=64,
num_train_epochs=1,
gradient_accumulation_steps=16,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
learning_rate=1.41e-5,
report_to="tensorboard",
remove_unused_columns=False,
optim="adamw_torch",
logging_steps=500,
evaluation_strategy="no",
max_length=512,
if __name__ == "__main__":
parser = HfArgumentParser((RewardConfig, ModelConfig))
config, model_config = parser.parse_args_into_dataclasses()
config.gradient_checkpointing_kwargs = dict(use_reentrant=False)
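# use_reentrant=False opts into PyTorch's non-reentrant activation
# checkpointing (presumably set here to avoid the limitations of the
# default reentrant variant during reward-model training).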
################
# Model & Tokenizer
################
torch_dtype = (
model_config.torch_dtype
if model_config.torch_dtype in ["auto", None]
else getattr(torch, model_config.torch_dtype)
)
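# e.g. (illustration) "bfloat16" -> torch.bfloat16, while "auto" and None
# are passed through unchanged.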
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
)
model = AutoModelForSequenceClassification.from_pretrained(
model_config.model_name_or_path, num_labels=1, trust_remote_code=model_config.trust_remote_code, **model_kwargs
)
if model_config.lora_task_type != "SEQ_CLS":
warnings.warn(
"You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs"
" Make sure to pass --lora_task_type SEQ_CLS when using this script."
)
################
# Dataset
################
raw_datasets = load_dataset("Anthropic/hh-rlhf")
# Tokenize chosen/rejected pairs of inputs
# Adapt this section to your needs for custom datasets
def preprocess_function(examples):
new_examples = {
"input_ids_chosen": [],
"attention_mask_chosen": [],
"input_ids_rejected": [],
"attention_mask_rejected": [],
}
for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
tokenized_chosen = tokenizer(chosen)
tokenized_rejected = tokenizer(rejected)
new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
return new_examples
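# Illustration (hypothetical values): one chosen/rejected pair yields
# {"input_ids_chosen": [[...]], "attention_mask_chosen": [[1, ...]],
#  "input_ids_rejected": [[...]], "attention_mask_rejected": [[1, ...]]}.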
# Preprocess the dataset and filter out examples that are longer than args.max_length
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
raw_datasets = raw_datasets.map(
preprocess_function,
batched=True,
num_proc=config.dataset_num_proc,
)
raw_datasets = raw_datasets.filter(
lambda x: len(x["input_ids_chosen"]) <= config.max_length
and len(x["input_ids_rejected"]) <= config.max_length,
num_proc=config.dataset_num_proc,
)
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]
################
# Training
################
trainer = RewardTrainer(
model=model,
tokenizer=tokenizer,
args=config,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=get_peft_config(model_config),
)
use_peft: bool = False
"""whether to use peft"""
peft_config: Optional[LoraConfig] = field(
default_factory=lambda: LoraConfig(
r=16,
lora_alpha=16,
bias="none",
task_type="SEQ_CLS",
modules_to_save=["scores"],
),
)
args = tyro.cli(ScriptArguments)
args.reward_config.evaluation_strategy = "steps" if args.eval_split != "none" else "no"
# Step 1: Load the model
if args.load_in_8bit and args.load_in_4bit:
raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
elif args.load_in_8bit or args.load_in_4bit:
quantization_config = BitsAndBytesConfig(load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit)
# Copy the model to each device
device_map = (
{"": f"xpu:{Accelerator().local_process_index}"}
if is_xpu_available()
else {"": Accelerator().local_process_index}
)
else:
device_map = None
quantization_config = None
model = AutoModelForSequenceClassification.from_pretrained(
args.model_name,
quantization_config=quantization_config,
device_map=device_map,
trust_remote_code=args.trust_remote_code,
num_labels=1,
)
# Step 2: Load the dataset and pre-process it
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
train_dataset = load_dataset(args.dataset_name, split="train")
# Tokenize chosen/rejected pairs of inputs
# Adapt this section to your needs for custom datasets
def preprocess_function(examples):
new_examples = {
"input_ids_chosen": [],
"attention_mask_chosen": [],
"input_ids_rejected": [],
"attention_mask_rejected": [],
}
for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
tokenized_chosen = tokenizer(chosen)
tokenized_rejected = tokenizer(rejected)
new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
return new_examples
# Preprocess the dataset and filter out examples that are longer than args.max_length
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
num_proc=4,
)
train_dataset = train_dataset.filter(
lambda x: len(x["input_ids_chosen"]) <= args.reward_config.max_length
and len(x["input_ids_rejected"]) <= args.reward_config.max_length
)
if args.eval_split == "none":
eval_dataset = None
else:
eval_dataset = load_dataset(args.dataset_name, split=args.eval_split)
eval_dataset = eval_dataset.map(
preprocess_function,
batched=True,
num_proc=4,
)
eval_dataset = eval_dataset.filter(
lambda x: len(x["input_ids_chosen"]) <= args.reward_config.max_length
and len(x["input_ids_rejected"]) <= args.reward_config.max_length
)
# Step 4: Define the LoraConfig
if args.use_peft:
peft_config = args.peft_config
else:
peft_config = None
# Step 5: Define the Trainer
trainer = RewardTrainer(
model=model,
tokenizer=tokenizer,
args=args.reward_config,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=peft_config,
)
trainer.train()
trainer.train()
trainer.save_model(config.output_dir)
trainer.push_to_hub()
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
print(metrics)

examples/scripts/rloo/rloo.py (new file)

@@ -0,0 +1,121 @@
import shutil
from accelerate import PartialState
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
HfArgumentParser,
)
from trl import ModelConfig
from trl.trainer.rloo_trainer import RLOOConfig, RLOOTrainer
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
"""
python -i examples/scripts/rloo/rloo.py \
--learning_rate 3e-6 \
--num_ppo_epochs 1 \
--num_mini_batches 1 \
--output_dir models/minimal/ppo \
--per_device_train_batch_size 64 \
--gradient_accumulation_steps 1 \
--total_episodes 10000 \
--model_name_or_path EleutherAI/pythia-1b-deduped \
--non_eos_penalty

accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
examples/scripts/rloo/rloo.py \
--output_dir models/minimal/rloo \
--rloo_k 2 \
--num_ppo_epochs 1 \
--num_mini_batches 1 \
--learning_rate 3e-6 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--total_episodes 10000 \
--model_name_or_path EleutherAI/pythia-1b-deduped \
--sft_model_path EleutherAI/pythia-1b-deduped \
--reward_model_path EleutherAI/pythia-1b-deduped \
--local_rollout_forward_batch_size 1 \
--deepspeed3 \
--non_eos_penalty
"""
if __name__ == "__main__":
parser = HfArgumentParser((RLOOConfig, ModelConfig))
config, model_config = parser.parse_args_into_dataclasses()
# remove output_dir if exists
shutil.rmtree(config.output_dir, ignore_errors=True)
################
# Model & Tokenizer
################
tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path,
padding_side="left",
trust_remote_code=model_config.trust_remote_code,
)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
reward_model = AutoModelForSequenceClassification.from_pretrained(
config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
)
ref_policy = AutoModelForCausalLM.from_pretrained(
config.sft_model_path, trust_remote_code=model_config.trust_remote_code
)
policy = AutoModelForCausalLM.from_pretrained(
config.sft_model_path, trust_remote_code=model_config.trust_remote_code
)
################
# Dataset
################
raw_datasets = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness")
eval_samples = 20
train_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples))
eval_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples, len(raw_datasets)))
dataset_text_field = "prompt"
def prepare_dataset(dataset, tokenizer):
"""pre-tokenize the dataset before training; only collate during training"""
def tokenize(element):
outputs = tokenizer(
element[dataset_text_field],
padding=False,
)
return {"input_ids": outputs["input_ids"]}
return dataset.map(
tokenize,
batched=True,
remove_columns=dataset.column_names,
num_proc=config.dataset_num_proc,
)
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
train_dataset = prepare_dataset(train_dataset, tokenizer)
eval_dataset = prepare_dataset(eval_dataset, tokenizer)
################
# Training
################
trainer = RLOOTrainer(
config=config,
tokenizer=tokenizer,
policy=policy,
ref_policy=ref_policy,
reward_model=reward_model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
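# Note: unlike the PPOv2 scripts above, no value model is passed here;
# RLOO presumably estimates its baseline from the k rollouts generated per
# prompt (see --rloo_k) instead of training a separate value head.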
trainer.train()
trainer.save_model(config.output_dir)
if config.push_to_hub:
trainer.push_to_hub()
trainer.generate_completions()

Some files were not shown because too many files have changed in this diff.