Compare commits

...

913 Commits

Author SHA1 Message Date
8d4cc9427d Release: v0.13.0 2024-12-15 19:51:55 +01:00
aeca63774f 👨‍🏫 smol course links and badges (#2484)
* smol course links and badges

* try without space

* revert space
2024-12-15 19:38:48 +01:00
117c6d4b52 📥 Fix missing BitsAndBytesConfig import in doc (#2478) 2024-12-15 16:54:38 +01:00
6d4ed070f1 ☄️ Add support for Comet experiment management SDK integration (#2462)
* Added support for Comet URL integration into model cards created by trainers.

* Moved `get_comet_experiment_url()` into utils.py

* Updated Comet badge in the model card to use PNG image instead of text.

* Fixed bug related to running the PPO example during model saving. The error was: 'GPTNeoXForCausalLM' object has no attribute 'policy'. Introduced a guard check that the attribute `policy` exists.

* Implemented utility method to handle logging of tabular data to the Comet experiment.

* Implemented logging of the completions table to Comet by `PPOTrainer`.

* Implemented logging of the completions table to Comet by `WinRateCallback`.

* Implemented logging of the completions table to Comet by `RLOOTrainer` and `RewardTrainer`.

* Restored line to the main branch version.

* Moved Comet-related utility methods into `trainer/utils.py` to resolve a merge conflict with the master branch.

* Update trl/trainer/utils.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Implemented raising of `ModuleNotFoundError` when logging a table to Comet if `comet-ml` is not installed.

* import comet with other imports

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-12-13 22:08:10 +01:00
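The Comet integration above hooks into the model cards and the completion tables logged by `PPOTrainer`, `RLOOTrainer`, `RewardTrainer`, and `WinRateCallback`. A minimal sketch of enabling it from the user side, assuming `comet-ml` is installed and `COMET_API_KEY` is set; the `report_to` field is the standard `transformers` reporting hook rather than something added by this PR:

```python
import os

from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

os.environ.setdefault("COMET_PROJECT_NAME", "trl-experiments")  # optional Comet project name

dataset = load_dataset("trl-lib/Capybara", split="train")

training_args = SFTConfig(
    output_dir="Qwen2-0.5B-SFT",
    report_to="comet_ml",  # metrics (and, per this PR, completion tables and model-card links) go to Comet
)
trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", args=training_args, train_dataset=dataset)
trainer.train()
```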
cd7156fb34 👀 Add "PaliGemma 🤝 Direct Preference Optimization" in community tutorials (#2475) 2024-12-13 20:29:35 +01:00
ca850be0a2 🕹️ CLI refactor (#2380)
* Refactor main function in dpo.py

* Update setup.py and add cli.py

* Add examples to package data

* style

* Refactor setup.py file

* Add new file t.py

* Move dpo to package

* Update MANIFEST.in and setup.py, refactor trl/cli.py

* Add __init__.py to trl/scripts directory

* Add license header to __init__.py

* File moved instruction

* Add Apache License and update file path

* Move dpo.py to new location

* Refactor CLI and DPO script

* Refactor import structure in scripts package

* env

* rm config from chat arg

* rm old cli

* chat init

* test cli [skip ci]

* Add `dataset_config_name` to `ScriptArguments` (#2440)

* add missing arg

* Add test cases for 'trl sft' and 'trl dpo' commands

* Add sft.py script and update cli.py to include sft command

* Move sft script

* chat

* style [ci skip]

* kto

* rm example config

* first step on doc

* see #2442

* see #2443

* fix chat windows

* ©️ Copyrights update (#2454)

* First changes

* Other files

* Finally

* rm comment

* fix nashmd

* Fix example

* Fix example [ci skip]

* 💬 Fix chat for windows (#2443)

* fix chat for windows

* add some tests back

* Revert "add some tests back"

This reverts commit 350aef52f53f8cf34fccd7ad0f78a3dd63867e06.

* 🆔 Add `dataset_config` to `ScriptArguments` (#2440)

* dataset_config_name

* Update trl/utils.py [ci skip]

* sort import

* typo [ci skip]

* Trigger CI

* Rename `dataset_config_name` to `dataset_config`

* 🏎 Fix deepspeed preparation of `ref_model` in `OnlineDPOTrainer` (#2417)

* Remove unused deepspeed code

* add model prep back

* add deepspeed even if it doesn't work

* rm old code

* Fix config name

* Remove `make dev` in favor of `pip install -e .[dev]`

* Update script paths and remove old symlink related things

* Fix chat script path [ci skip]

* style
2024-12-13 17:52:23 +01:00
179ba53671 🐾 Process-supervised RM Trainer (#2127)
* initial skeleton

* tokenize fn

* adding bos and eos to tokenization fn

* prmtrainer

* fixing small typo in tokenize

* typo in input_ids and labels construction

* numpy dimension

* introduce the stepwise reward trainer

* update markdown files

* let user decide post step separator in config

* doc post_step_separator

* do not add post step_tokens to last step of the reasoning process

* renaming prm to stepwisereward

* formatting

* fix tokenize kwargs

* adapt test to the new post_token args

* adding example script

* fix small typo

* add create_model_card and renaming

* fixing booleans

* Adding the new stepwise_preference instead of placeholders for datasets

* formatting

* Update docs/source/_toctree.yml

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update examples/scripts/stepwise_reward_modeling.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/stepwise_reward_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/stepwise_reward_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* update push to hub

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* step_separator can't be None

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* fix suggested typos

* add citation

* reformat doc

* reordering init

* push to hub prm800k

* changing dataset in example

* change dataset format to align with the sky is blue example

* fix tokenization column names

* fix num labels in openai example

* add support for conversational dataset

* remove trailing whitespace

* replace tokenizer with processing class

* Update docs/source/dataset_formats.mdx

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* remove openai_prm800k

* Update trl/trainer/stepwise_reward_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/stepwise_reward_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update docs/source/stepwise_reward_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/stepwise_reward_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* renaming

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* renaming

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* minor renamings in docs

* using prm800k instead of openai_prm800k

* update num labels to 2 following the new format

* changing doc examples to math examples

* change reference to dataset_formats.mdx

* changing dataset config in test

* remove conversational dataset support

* remove conv dataset support

* fix bos token

* fix scriptarguments in example

* completion to completions

* remove ValueError for step_separator inside steps

* run precommit

* remove conv dataset support

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* renaming zen dataset

* remove unused printing

* unknown label column

* introduce the train on last step arg

* _tokenize support train_on_last_step

* incorporate train_on_last_step to tests

* formatting

* remove comments in trainer

* Refactor `tokenize_row`

* Update max_completion_length parameter in StepwiseRewardConfig

* Collator

* Update comment

* Update type hint

* fix table

* Remove collator

* don't need pad token id

* add error back

* max length args

* use tokenizer arg

* Update doc

* label -> labels

* fixing tokenization issues in tokenize row

* correct labels for token classification

* adding max_length to tokenize_row

* reformat tests

* adding tests for tokenize row

* fixing typos in comments

* update doc

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Add math_shepherd.py script for dataset processing

* split the dataset

* formatting

* same evaluation method for the two training methods

* adding filtering to example script

* formatting

* Add features to avoid casting labels to bool in dataset tokenization

* Update docs/source/stepwise_reward_trainer.mdx [ci skip]

* Add learning_rate parameter to StepwiseRewardConfig class

* update doc

* Remove unused setup_chat_format function

* Fix warning message in stepwise_reward_modeling.py

* Update logging steps in stepwise_reward_trainer.mdx

* little doc change [ci skip]

* Fix copyrights

* fix space after copyrights

* Update dataset loading in stepwise_reward_modeling.py

* refine compute_accuracy and proper test

* fix tests

* style

* renamings

* renaming in init

* doc renaming

* fix sorting and tag

* experimental [ci skip]

* trigger CI

* other doc fix

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-12-13 15:56:10 +01:00
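Taken together, the commits above introduce a process-supervised (stepwise) reward trainer. A minimal usage sketch follows; the names are assumed from the commit history (a `PRMConfig`/`PRMTrainer` pair, a token-classification model with two labels, and the `trl-lib/math_shepherd` dataset produced by the new `math_shepherd.py` script):

```python
from datasets import load_dataset
from transformers import AutoModelForTokenClassification, AutoTokenizer
from trl import PRMConfig, PRMTrainer

# Two labels: each reasoning step is classified as correct or incorrect.
model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
train_dataset = load_dataset("trl-lib/math_shepherd", split="train")

training_args = PRMConfig(output_dir="Qwen2-0.5B-PRM", logging_steps=25)
trainer = PRMTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```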
e3e171a26b 🔨 Support for tools for data utils (#2455)
* function calling training support for SFTTrainer

* adding tool support to data_utils

* adding test for function calling tokenizer

* reverting changes to sfttrainer and config, added maybe_apply_chat_template

* arg for maybe_apply_chat_template docstring

* Doc sectioning

* minor test modification

* minor doc modification

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-12-12 17:11:50 +01:00
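The PR above lets the chat-template helpers in `data_utils` pass a list of tools through to the tokenizer. A rough sketch of the intended usage, assuming a chat model whose template supports tool definitions; the `get_weather` function is purely illustrative:

```python
from transformers import AutoTokenizer
from trl import maybe_apply_chat_template


def get_weather(location: str) -> str:
    """
    Get the current weather in a given location.

    Args:
        location: The city to get the weather for.
    """
    return "sunny"


tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
example = {"prompt": [{"role": "user", "content": "What's the weather in Paris?"}]}

# The tool schema is injected into the rendered prompt by the tokenizer's chat template.
formatted = maybe_apply_chat_template(example, tokenizer, tools=[get_weather])
print(formatted["prompt"])
```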
b3aff441ff 🎞️ Add "Fine-tuning open AI models using Hugging Face TRL" YouTube video to community tutorials (#2467) 2024-12-12 16:40:28 +01:00
efc687db62 🛠️ Update tests and fix PPO (#2463)
* [bugfix] critic not updating

* Update ppo_trainer.py

* Update ppo_trainer.py

* add failing test

* test both policy and critic

* formatting

* fix tests

* formatting

* Update tests/test_ppo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* fix test

---------

Co-authored-by: NINGBENZHE <53843873+NINGBENZHE@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-12-12 12:53:32 +01:00
f2e362656c ⚖️ Add tests_latest.yml workflow file (#2457)
* Add tests_latest.yml workflow file

* don't check the branch

* Fix workflow
2024-12-11 18:11:41 +01:00
c9c4f18039 [bugfix] Fix DataCollatorForChatML unexpected generation prompt (#2450)
* [bugfix] Fix DataCollatorForChatML unexpected generation prompt

* Update utils.py

* Update test_utils.py

* Update tests/test_utils.py

* Update tests/test_utils.py

* Update tests/test_utils.py

* Update tests/test_utils.py

* Update test_utils.py

* Update tests/test_utils.py

* Update tests/test_utils.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-12-11 15:18:54 +01:00
460e780265 👯 Standardize model_args (#2442)
* `model_config` -> `model_args`

* sort
2024-12-10 12:51:20 +01:00
7ba118a229 🏎 Fix deepspeed preparation of ref_model in OnlineDPOTrainer (#2417)
* Remove unused deepspeed code

* add model prep back

* add deepspeed even if it doesn't work

* rm old code
2024-12-10 12:40:13 +01:00
6a05feff02 🆔 Add dataset_config to ScriptArguments (#2440)
* dataset_config_name

* Update trl/utils.py [ci skip]

* sort import

* typo [ci skip]

* Trigger CI

* Rename `dataset_config_name` to `dataset_config`
2024-12-10 11:09:26 +01:00
2f72f47191 💬 Fix chat for windows (#2443)
* fix chat for windows

* add some tests back

* Revert "add some tests back"

This reverts commit 350aef52f53f8cf34fccd7ad0f78a3dd63867e06.
2024-12-10 10:40:23 +01:00
9410874787 ©️ Copyrights update (#2454)
* First changes

* Other files

* Finally

* rm comment

* fix nashmd

* Fix example

* Fix example [ci skip]
2024-12-10 10:40:00 +01:00
9c5388b69e 🔗 Add "Open in Colab" badges in community tutorials page (#2441) 2024-12-06 10:51:55 +01:00
b02189aaa5 🗂️ Harmonize run and example batch sizes in RLOO docs (#2439)
The doc lists different gradient_accumulation_steps and per_device_train_batch_size values than the actual hyperparameters; this can be verified from the wandb run.
2024-12-04 19:19:14 +01:00
52201d3c18 🧮 Fix max_steps calculation in RLOOTrainer (#2433) 2024-12-03 21:31:32 +01:00
9ff79a65e3 🔮 Fix unused precomputed ref log probs in DPO (#2431) 2024-12-03 11:36:57 +01:00
9001a8682c 📑 Refactor TrlParser (#2412)
* refactor parser

* Only document some methods

* Update imports in cli_utils.py and remove config option in utils.py

* add `test_parse_args_and_arg_override_config` and remove unnecessary mocks [ci skip]

* fix comment [ci skip]

* fix comment [ci skip]

* Extra arg in config also returned

* fix docstring [ci skip]

* add mock back

* use `deprecate_kwarg`
2024-12-02 19:57:35 +01:00
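The refactored parser keeps the dataclass-based interface used by the example scripts. A short sketch, assuming a script that combines `ScriptArguments`, a trainer config, and `ModelConfig` (the exact trio each script uses may differ):

```python
# train_dpo.py -- run e.g. as:
#   python train_dpo.py --output_dir out --dataset_name trl-lib/ultrafeedback_binarized \
#       --model_name_or_path Qwen/Qwen2-0.5B-Instruct
# or with a YAML file: python train_dpo.py --config dpo.yaml
from trl import DPOConfig, ModelConfig, ScriptArguments, TrlParser

parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()

print(script_args.dataset_name, script_args.dataset_config, script_args.dataset_train_split)
print(model_args.model_name_or_path)
```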
f6f42651e2 🧑‍🍳 Add precompute batch size argument in DPOTrainer for reference model (#2426)
* added precompute_batch

* review-fixes

* moving up

* Update trl/trainer/dpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_config.py

* Update trl/trainer/dpo_config.py [ci skip]

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-12-02 17:17:41 +01:00
148b592313 Update modeling_base.py (#2419) 2024-11-30 12:14:36 +01:00
d6a8f2c2f6 ⚠️ Add warning guidelines and update codebase to follow best practices (#2350)
* Add guidelines for working with warnings in the codebase

* Remove unnecessary warnings and improve code initialization

* Fix warnings and improve accuracy calculation

* Add rich library dependency for text formatting

* Update LoRA weight loading warning message

* Fix logging and import issues in AlignPropConfig

* Fix warnings and improve code readability

* Remove unused import statements

* Refactor CPOTrainer class in cpo_trainer.py

* Remove unnecessary warnings and raise ValueError for missing model

* Fix warnings and improve code consistency

* Update CONTRIBUTING.md to clarify the purpose of warnings

* Fix string formatting in DataCollatorForCompletionOnlyLM class

* Update SimPO loss parameters in CPOTrainer

* Fix warnings and remove unnecessary code in ConstantLengthDataset class

* Clarify warning guidelines

* Rewrite the entire section

* Fix capitalization in CONTRIBUTING.md

* Fix formatting in CONTRIBUTING.md
2024-11-29 16:07:38 +01:00
8d9cfaafeb 🌋 Add support for LLaVA-Next in DPOTrainer (#2413)
* add support for llava-next in dpotrainer

* enable unit test

* code style

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Ignore last layer in test

---------

Co-authored-by: zesong.cwz <zesong.cwz@taobao.com>
Co-authored-by: 1rubbishyuan <2773496952@qq.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-11-29 15:53:50 +01:00
94e4135a17 🔓 Remove lm_head check in AutoModelForCausalLMWithValueHead (#2398)
* Remove lm_head check in `AutoModelForCausalLMWithValueHead`

* Style

* Remove test
2024-11-29 15:52:35 +01:00
ac267781ec 🌐 Community Tutorials (#2411)
* Add community notebooks to API documentation

* fix extension

* add table of community tutorials

* respond to feedback - fix links and split table

* add class references

* rename file and update toc

* Update docs/source/community_tutorials.md

* Update docs/source/community_tutorials.md

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-11-29 11:39:37 +01:00
2c6e0d9705 Add note about special tokens in chat templates for LoRA SFT (#2414) 2024-11-29 10:35:39 +01:00
e1d781353b 👁️ Added SFT support for SmolVLM models via standalone script sft_vlm_smol_vlm.py (#2409)
* Added SFT VLM script for SmolVLM

* Run make precommit

* Updated command example
2024-11-28 18:45:37 +01:00
a34e9bf84f 🖨 Add Script Utilities section to the documentation (#2407)
* Add script_utils.md to the documentation

* Refactor ScriptArguments class documentation

* Refactor TrlParser class to improve code organization and readability
2024-11-28 16:43:08 +01:00
c10cc8995b 🗝️ Update type hints (#2399)
* New type hint structure

* Update type hints

* Delete wrong file

* Remove dict import
2024-11-26 20:37:27 +01:00
9368dccef6 🐢 Fix slow tests (#2397)
* fix slow CI

* fix dpo

* formatting

* Apply suggestions from code review

* `setup_chat_format` may add a pad token

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-11-26 15:38:46 +01:00
43df3a485a 🧳 Move zen generation script and fix tests (#2393)
* Move zen

* step -> stepwise_supervision

* Fix train_test_split shuffle issue

* Fix tests

* Update tests/test_sft_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Fix typo in key name

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-11-26 14:08:06 +01:00
baee06f2e8 🖋️ Fix warning message formatting in KTOTrainer (#2394) 2024-11-26 13:05:25 +01:00
bbd8cbb720 🤐 Fix deprecation warnings (#2395) 2024-11-26 11:29:07 +01:00
4f937c7629 🤐 Fix deprecation warnings (#2392) 2024-11-26 11:18:43 +01:00
16fa13ce72 👮 Deprecate policy in favor of model in PPOTrainer (#2386) 2024-11-26 08:13:10 +01:00
453db5cd79 🤏 New models for tests (#2287)
* first commit

* uncomment

* other tests adaptations

* Remove unused variable in test_setup_chat_format

* Remove unused import statement

* style

* Add Bart model

* Update BCOTrainerTester class in test_bco_trainer.py

* Update model IDs and tokenizers in test files

* Add new models and processors

* Update model IDs in test files

* Fix formatting issue in test_dataset_formatting.py

* Refactor dataset formatting in test_dataset_formatting.py

* Fix dataset sequence length in SFTTrainerTester

* Remove tokenizer

* Remove print statement

* Add reward_model_path and sft_model_path to PPO trainer

* Fix tokenizer padding issue

* Add chat template for testing purposes in PaliGemma model

* Update PaliGemma model and chat template

* Increase learning rate to speed up test

* Update model names in run_dpo.sh and run_sft.sh scripts

* Update model and dataset names

* Fix formatting issue in test_dataset_formatting.py

* Fix formatting issue in test_dataset_formatting.py

* Remove unused chat template

* Update model generation script

* additional models

* Update model references in test files

* Remove unused imports in test_online_dpo_trainer.py

* Add is_llm_blender_available import and update reward_tokenizer

* Refactor test_online_dpo_trainer.py: Move skipped test case decorator

* remove models without chat templates

* Update model names in scripts and tests

* Update model_id in test_modeling_value_head.py

* Update model versions in test files

* Fix formatting issue in test_dataset_formatting.py

* Update embedding model ID in BCOTrainerTester

* Update test_online_dpo_trainer.py with reward model changes

* Update expected formatted text in test_dataset_formatting.py

* Add reward_tokenizer to TestOnlineDPOTrainer

* fix tests

* Add SIMPLE_CHAT_TEMPLATE to T5 tokenizer

* Fix dummy_text format in test_rloo_trainer.py

* Skip outdated test for chatML data collator

* Add new vision language models

* Commented out unused model IDs in test_vdpo_trainer

* Update model and vision configurations in generate_tiny_models.py and test_dpo_trainer.py

* Update model and tokenizer references

* Don't push if it already exists

* Add comment explaining test skip

* Fix model_exists function call and add new models

* Update LlavaForConditionalGeneration model and processor

* `qgallouedec` -> `trl-internal-testing`
2024-11-25 16:31:56 +01:00
ee3cbe1946 💾 Deprecate config in favor of args in PPOTrainer (#2384) 2024-11-25 14:48:08 +01:00
17e8060984 📦 Support for packing tokenized datasets for SFT (#2011)
* feat: add support for packing tokenized datasets

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* fix: address review comments

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* feat: add tests for pretokenized dataset packing

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

---------

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
2024-11-25 10:36:58 +01:00
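The PR above allows `SFTTrainer` to pack datasets that are already tokenized. A sketch of what that looks like, under the assumption that a dataset exposing an `input_ids` column skips tokenization and goes straight to packing:

```python
from datasets import Dataset
from trl import SFTConfig, SFTTrainer

# Pre-tokenized examples: only `input_ids` is needed, no text column.
train_dataset = Dataset.from_dict(
    {"input_ids": [[101, 7592, 2088, 102], [101, 2178, 2742, 102], [101, 2003, 2023, 2147, 102]]}
)

training_args = SFTConfig(output_dir="sft-packed", packing=True, max_seq_length=8)
trainer = SFTTrainer(model="Qwen/Qwen2-0.5B", args=training_args, train_dataset=train_dataset)
# Training then proceeds as usual with the packed sequences.
```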
163695e85c 🙈 Suppress warning for estimating tokens in trainers (#2389)
* Suppress warning for estimating tokens in trainer

* Suppress warning for estimating FLOPs in ORPO and Reward trainers
2024-11-24 16:55:43 +01:00
672c96546d Update log method to include start_time parameter (#2381) 2024-11-21 21:30:10 +01:00
bdeb117320 📝 Fix typo in dataset generation script (#2379) 2024-11-21 20:37:44 +01:00
6578fdc101 🔀 Add MergeModelCallBack (#2282)
* Create mergekit_utils.py

* adding mergekit as an optional dependency

* adding MergeModel to callbacks

* adding mergekit_utils dependencies to callbacks

* setting lower bound for mergekit

* setting mergekit lower bound to 0.0.5.1

* adding support for MergeModelCallBack __init__.py

* adding support for mergemodelcallback

* mergemodelcallback tests

* Update callbacks.py

* Update __init__.py

* Update __init__.py

* Update test_callbacks.py

* Update trl/trainer/callbacks.py

removing ## from docs

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/callbacks.py

removing ## from docs

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/callbacks.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* using different dataset for tests

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/mergekit_utils.py

adding types

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/mergekit_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Apply suggestions from code review

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* replacing get_last_checkpoint

* renaming Merge to merge_models

* setting mergers default value to linear

* removing unnecessary docs and comments

* adding docstring to Mergeconfig

* adding mergekits link to docstring

* precommit

* removing duplicated import

* typos in mergekit_utils docstring

* fixing tests

* making mergemodelcallback tests optional

* Make import optional

* minor

* use tmp dir in test

* sort

* Add import error checks for mergekit extra

* use a common _merge_and_maybe_push method and compat with windows path

* debug windows

* Update dependencies for mergekit and add test dependencies

* Add assertion to check if merged folder exists in the last checkpoint

* Fix temporary directory cleanup in test_callbacks.py

* Add sys import and skip test for Python versions below 3.10 due to cleanup errors with temp dir

* revert change for debug

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-11-21 14:06:45 +01:00
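A hypothetical usage sketch of the new callback, based only on the commit messages above; it assumes the optional mergekit extra is installed (`pip install trl[mergekit]`) and that the relevant names are `MergeConfig` and `MergeModelCallback` with a `merge_at_every_checkpoint` option:

```python
from trl import MergeModelCallback
from trl.mergekit_utils import MergeConfig

merge_config = MergeConfig()  # per the commits, the default merge method is linear
callback = MergeModelCallback(merge_config, merge_at_every_checkpoint=False, push_to_hub=False)

# The callback is then passed to any TRL trainer via `callbacks=[callback]`; after training
# (or at every checkpoint) it merges the fine-tuned model with its base model and can
# optionally push the merged weights to the Hub.
```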
a0066f47f8 Add start_time to _maybe_log_save_evaluate (#2373) 2024-11-20 12:49:49 +01:00
5626806aef 🧲 Use our own require_bitsandbytes (#2370)
* use our own require_bitsandbytes

* rephrase
2024-11-20 11:51:05 +01:00
bb0afc2459 remove redundant call to eval and train (#2372) 2024-11-20 11:24:41 +01:00
066fc37bd3 Fix dev install (#2369) 2024-11-19 13:30:09 +01:00
b80c1a6fb8 🎲 Move random judges in testing utilities (#2365)
* Update judges and testing utilities

* Update judges in test files

* Update judges in test files
2024-11-18 18:43:18 +01:00
b5eabbeb07 🤝 Mixture of judges (#2159)
* base judge

* adding mixture of judges

* update doc

* update doc

* formatting

* fix small typo in doc

* fix randomconstraintjudge

* replace arxiv by hf papers

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* formatting

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* fix naming in __init__

* run precommit

* adding gold answers to judges

* cgpo llm judges

* fix init

* output type

* adjust booleans in test

* adapt moj doc

* renaming and removing factuality and safety judges

* fix typo in import

* fix small typo in naming

* formatting

* Update trl/trainer/judges.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* update parameter name

* update tests

* update doc

* Update trl/trainer/judges.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update doc

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* fix alltruejudge type

* Refactor judge variable names and update test names

* Clarify judgment logic

* Fix invalid binary judgment check in AllTrueJudge class

* Fix invalid binary judgment check in AllTrueJudge class

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-11-18 16:54:57 +01:00
cbf9abcd07 🗺️ Implementation DiscoPOP Loss (#2323)
* Implement DiscoPOP Loss

* Updated DiscoPOP documentation

* Corrected docs/source/dpo_trainer.mdx

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_config.py

* Delete scripts directory

* style

* empty commit

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-11-18 14:15:00 +01:00
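The DiscoPOP loss becomes another `loss_type` option in `DPOConfig`. A minimal sketch; the string value ("discopop") and the `discopop_tau` temperature are assumed from the PR:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = DPOConfig(
    output_dir="Qwen2-0.5B-DiscoPOP",
    loss_type="discopop",  # log-ratio modulated loss discovered by DiscoPOP
    discopop_tau=0.05,     # temperature of the modulation term
)
trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```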
6f8fe59aeb 📃 Fix description for parameter "generate_during_eval" in dpo_config (#2364) 2024-11-18 14:03:02 +01:00
1293f37c5f 📉 Add PEFT support for PPOTrainer (#2344)
* Add peft/lora support for PPOTrainer

* Fix: style

* Fix: typo

* Add ppo.py PEFT example

* Fixed the optional dependencies error

* skip peft test if peft is unavailable

* Update trl/trainer/ppo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-11-18 11:54:09 +01:00
e7870dd5d6 🗃️ Use specified data_collator in RLOOTrainer and PPOTrainer (#2360)
* Fix "Use specified data_collator instead of hard-coding the option"

* Remove query_responses = [] since it's immediately overwritten afterwards.

* Use self.data_collator

* Use specified data_collator instead of hard-coded one in PPOTrainer

* Move the data_collator creation

* Run make precommit
2024-11-18 11:53:47 +01:00
21d5baf338 🔮 Inference mode in GeometricMixtureWrapper.forward (#2345)
* geom mixture model train

* use inference_mode
2024-11-18 09:58:26 +01:00
76dbb1a576 🪜 Stepwise supervision dataset type (#2148) 2024-11-18 09:58:00 +01:00
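For reference, the new dataset type pairs one prompt with a list of reasoning steps and one boolean label per step; a toy example of the expected layout (content illustrative only):

```python
# One row of a "stepwise supervision" dataset: a prompt, the completion split into steps,
# and one correctness label per step.
example = {
    "prompt": "Which number is larger, 9.8 or 9.11?",
    "completions": [
        "The fractional part 0.11 is larger than 0.8.",  # incorrect step
        "Therefore 9.11 > 9.8.",                         # incorrect step
    ],
    "labels": [False, False],
}
```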
b8c9d9c7bc ⚖️ Add use_soft_judge option to WinRateCallback (#2347)
* add `use_soft_judge` option to WinRateCallback

* formatting

* Update trl/trainer/callbacks.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* renamed soft_win_rate to avg_win_prob

* Update trl/trainer/callbacks.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* fix tests

* keep original

* formatting

* Update tests/test_callbacks.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/callbacks.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update tests/test_callbacks.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update tests/test_callbacks.py

* fix test

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-11-15 15:49:43 +01:00
623963126b 👋 Remove deprecated tokenizer argument in BCO, GKD, Iterative SFT, Nash MD and XPO (#2349) 2024-11-12 09:22:17 -04:00
2d24d35013 Adding video llm fine-tuning example (#2336)
* adding video example

* exposing more parameters

* fixing formatting

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-11-12 12:56:38 +01:00
dde20b23cf 🖨️ Fix error text in BCO and KTO tokenizing function (#2286) 2024-11-11 19:18:36 -04:00
015321e135 👈 Add tokenizer arg back and add deprecation guidelines (#2348)
* Add deprecation and backward compatibility guidelines

* Update tokenizer argument in trainer classes

* Add warning message for TRL Judges API
2024-11-11 19:06:20 -04:00
454f36d951 💣 Remove transformers version check (#2343) 2024-11-11 09:34:26 -04:00
9b7f9f3519 🪡 Various RLOO fixes (#2325) 2024-11-11 08:59:03 -04:00
518e29ca9c 🫴 Better guide users in error reporting (#2327)
* update issue template

* Add checklist for bug report template

* Fix formatting in bug report template

* Update bug report template with additional instructions for code formatting and screenshots

* Update bug report template with code formatting instructions

* Update bug report template with code examples

* Update code block placeholder in bug report template

* Update .github/ISSUE_TEMPLATE/bug-report.yml

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-11-11 08:42:16 -04:00
ac7b6cfdfa 🧞 Add output_layer to the list of lm_head_namings in AutoModelForCausalLMWithValueHead (#2328) 2024-11-11 08:16:09 -04:00
0238d96c6f DPO trainer supports num_logits_to_keep to save memory (#2129)
* Support num_logits_to_keep, which computes only the necessary logits in the forward pass.

* update doc

* bug fix

* update

* check is model supports num_logits_to_keep

* ruff format

* update test file

* peft model support

* test passed

* update

* apply use_num_logits_to_keep

* fix num_logits_to_keep compute bug

* compare all outputs

* pytest

* pass test

* use check_min_version

* format

* test_dpo_trainer_use_num_logits_to_keep passed

* add some comments

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-11-10 11:34:51 +01:00
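The feature is exposed as an opt-in flag on `DPOConfig`; a short sketch, with the flag name (`use_num_logits_to_keep`) taken from the commit messages above:

```python
from trl import DPOConfig

# Only the logits needed for the completion tokens are computed in the forward pass,
# which reduces peak memory for long prompts (requires a model that supports
# `num_logits_to_keep`).
training_args = DPOConfig(output_dir="dpo-model", use_num_logits_to_keep=True)
```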
c86b51cd12 Bump liger-kernel to fix grad acc and more features (#2333)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-11-08 12:16:33 +01:00
ac77c09223 Fix gradient_checkpointing_kwargs assignment in examples (#2331)
Co-authored-by: Ping <ping.zhu@jmuse.cn>
2024-11-07 09:28:10 +01:00
7f2ccbe3a2 fix truncating index in DPOTrainer's concatenated_forward() (#2332) 2024-11-07 09:27:32 +01:00
74e20cbbbc 🪪 Check with token_id instead of token in DPOTrainer (#2324) 2024-11-04 21:08:41 +01:00
27b9e3a93f 🪧 Fix slack notification titles (#2322) 2024-11-04 21:02:27 +01:00
dc2b8b9e90 🧽 Fix judge documentation (#2320)
* Bump dev version to `0.13.0.dev0`

* Update version number to 0.12 in CITATION.cff

* 🧽 Fix judge documentation (#2318)

* Update judge examples and documentation

* without ':'

* Clean doc

* Fix typo in example code

* Add space after Attributes

* Update attribute name in judges.py

* Add installation instructions for llm-blender library

* Update PairRMJudge attributes documentation

* Fix return type in PairRMJudge

* Revert "🧽 Fix judge documentation (#2318)"

This reverts commit 337005d95169371935fb87f1c559c7412f8472a4.

* Revert "🧽 Fix judge documentation (#2318)"

This reverts commit 337005d95169371935fb87f1c559c7412f8472a4.

* 🧽 Fix judge documentation (#2318)

* Update judge examples and documentation

* without ':'

* Clean doc

* Fix typo in example code

* Add space after Attributes

* Update attribute name in judges.py

* Add installation instructions for llm-blender library

* Update PairRMJudge attributes documentation

* Fix return type in PairRMJudge
2024-11-04 19:00:27 +01:00
5e90682836 ⚰️ Remove deprecated args, script arguments, and PPOv2 (#2306)
* Remove deprecated args

* Remove deprecated args in SFTTrainer

* Remove deprecated script argument classes

* Remove deprecated PPOv2Config and PPOv2Trainer classes

* Commented out sync_ref_model line in test_trainers_args.py
2024-11-04 16:07:26 +01:00
3b439967f4 📰 Update blog posts in documentation (#2319)
* Bump dev version to `0.13.0.dev0`

* Update version number to 0.12 in CITATION.cff

* Add publication date to blog post

* 🧽 Fix judge documentation (#2318)

* Update judge examples and documentation

* without ':'

* Clean doc

* Fix typo in example code

* Add space after Attributes

* Update attribute name in judges.py

* Add installation instructions for llm-blender library

* Update PairRMJudge attributes documentation

* Fix return type in PairRMJudge

* Revert "🧽 Fix judge documentation (#2318)"

This reverts commit 337005d95169371935fb87f1c559c7412f8472a4.

* Update blog post publication dates

* revert to p5

* Update image URLs in index.mdx

* Sort and uniform thumbnail

* Update image alignment in index.mdx
2024-11-04 16:00:27 +01:00
2f34a161cd Bump dev version to 0.13.0.dev0 (#2305)
* Bump dev version to `0.13.0.dev0`

* Update version number to 0.12 in CITATION.cff

* 🧽 Fix judge documentation (#2318)

* Update judge examples and documentation

* without ':'

* Clean doc

* Fix typo in example code

* Add space after Attributes

* Update attribute name in judges.py

* Add installation instructions for llm-blender library

* Update PairRMJudge attributes documentation

* Fix return type in PairRMJudge

* Revert "🧽 Fix judge documentation (#2318)"

This reverts commit 337005d95169371935fb87f1c559c7412f8472a4.
2024-11-04 15:59:52 +01:00
6138439df4 🧓 Specify and test min versions (#2303)
* Add conditional check for LLMBlender availability in test_judges.py

* Fix import issues and update test requirements

* Remove unused imports

* Add require_peft decorator to test cases

* Fix import_utils module to use correct package name for llm_blender

* Found min version and test

* Update Slack notification titles

* Update dependencies versions

* Update GitHub Actions workflow to include setup.py and reorder file paths

* Revert "Update Slack notification titles"

This reverts commit be02a7f2de87905e86a847540770968d0416934a.

* Update Slack notification titles

* Remove pull_request branch restriction in tests.yml

* add check code quality back

* Fix PairRMJudge model loading issue
2024-11-01 00:26:53 +01:00
d57a181163 🧩 Add optimizer_cls_and_kwargs attribute to PPOTrainer and RLOOTrainer (#2302) 2024-10-31 23:10:11 +01:00
73c3970c1f 🙅 Ensure dependency optionality (#2301)
* Add conditional check for LLMBlender availability in test_judges.py

* Fix import issues and update test requirements

* Remove unused imports

* Add require_peft decorator to test cases

* Fix import_utils module to use correct package name for llm_blender
2024-10-31 22:37:49 +01:00
013a32b396 Remove stale bot (#2300) 2024-10-31 21:16:30 +01:00
24fb32733f 🔧 Use standard unittest assertion methods (#2283)
* WIP: Partial unit test update

* Update unittest format

* Update tests/slow/test_sft_slow.py comment

* Refactor unit tests: replace pytest.raises with self.assertRaises

* Fix: Restore accidentally deleted 'ref_model' parameter in DPOTrainer

* Re-run pre-commit

* fix: Incorrectly replacing non-TestCase assert

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-10-31 15:10:43 +01:00
bb56c6e6af 💾 Fix _save_checkpoint for online methods (#2288)
* Update trainer_utils import and save strategy in online_dpo_trainer.py

* fix back-compat for online-dpo

* better comment

* Update transformers dependency to commit f33904
2024-10-31 12:35:25 +01:00
06be6f409a 🖇️ Better dependency and partitioning of CI tests (#2298)
* clean deps

* new tests

* tests

* Add tests without optional dependencies workflow

* Update dependencies in tests.yml

* cpu version of torch

* Update dependencies and installation commands

* Disable fail-fast in test workflow

* Update test matrix in workflows file

* try fix windows

* Remove "rich" from required packages in setup.py

* Update dependency installation in tests.yml

* Add torch and deepspeed installation for windows-latest

* Fix conditional statement in workflow file

* Add torch and deepspeed installation for Windows

* Fix if statement

* Update torch and deepspeed dependencies

* Update liger package requirement for non-Windows platforms

* remove scipy dep

* Add torch GPU requirement for testing_utils

* Update trl/trainer/judges.py
2024-10-31 11:08:51 +01:00
b2696578ce 🍬 Use any reward model for online methods (#2276)
* Refactor reward processing in OnlineDPOTrainer

* Refactor completion decoding and reward processing

* remove strip

* remove warning

* Add reward_tokenizer to training script

* Add reward_tokenizer and reward_processing_class to OnlineDPOTrainer test

* propagate to xpo and nash

* style

* reduce memory requirement with inference_mode

* fix tests

* pairrm judge llmblender

* setUpClass(cls)

* Add setUpClass method to TestJudges class

* truncation left for reward tokenizer

* don't log completions without eval dataset

* only eval when possible
2024-10-28 16:21:40 +01:00
0ce3b65928 🔌 Fix type hint in LogCompletionsCallback (#2285)
* Update callbacks.py for fix small python type error

* Update callbacks.py

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-10-28 11:49:35 +01:00
e155cb8a66 ⛓️💥 Don't use eval_dataset in scripts when no eval strategy (#2270) 2024-10-28 11:40:51 +01:00
ea7a1be92c 🧮 Fix the computation of KL divergence loss (#2277) 2024-10-25 18:16:02 +02:00
110d0884c7 🏁 Add bos_token_id only if it exists (#2279)
Co-authored-by: sean.jung <sean.jung@sean-ai.local>
2024-10-25 18:15:08 +02:00
57ba9b93aa 🧘 Replace F.log(F.sigmoid(log_odds) with F.logsigmoid(log_odds) (#2274)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-10-24 20:51:55 +02:00
0de75b26f2 🧼 Refactor log_reports.py for Improved Logging, File Processing, and Slack Payload Handling (#2249)
* Update log_reports.py

* comments text update

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* emoji added

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update scripts/log_reports.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update scripts/log_reports.py

* style

* Update scripts/log_reports.py

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-10-24 20:48:12 +02:00
e615974a03 ♾️ Fix test generation max_new_tokens (#2272)
* `eval_strategy="steps" if eval_dataset else "no"`

* tmp skip test

* drop `eval_strategy` in `test_sft_trainer_uncorrect_data`

* remove eval strategy

* Add parameterized test for generate method

* Revert "`eval_strategy="steps" if eval_dataset else "no"`"

This reverts commit 1e8b331fa2c222a699cb3563f44f5702a7d6f50b.

* Revert "tmp skip test"

This reverts commit 44558f84cc43e20254b567d608b44d059a14913b.

* Revert "drop `eval_strategy` in `test_sft_trainer_uncorrect_data`"

This reverts commit a1ef7016286649fce10b3665159abcbfac2219e3.

* Revert "remove eval strategy"

This reverts commit cb7fafa874b108ba91b29f15944b7c4a41705d6d.

* style

* Refactor test_generate method in test_modeling_value_head.py

* `max_new_tokens=9`
2024-10-24 20:20:01 +02:00
c2bb1eed14 Add torch_dtype to model kwargs in reward modeling example (#2266)
Update model_kwargs to include torch_dtype.

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-10-24 20:12:26 +02:00
9c376c571f [Judges] use the pair-judges in online-preference trainers (#2243)
* use the pair-judges

* add test

* Update trl/trainer/online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* decode and skip special characters

* initial nash

* return tensors

* Update trl/trainer/online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* add back the logging

* use batch_decode

* add judges api to XPO trainer

* Update tests/test_online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* judge in examples

* judge in config

* add back logs when using reward model

* typo

* add back model_scores logging when using reward model

* log scores for reward model only

* better cond on what to log

* same for rlhf reward

* Update trl/trainer/online_dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* use decode_and_strip_padding

* error if both reward and judge or none are set

* remove unused check

* Uniform way to pass conversation into judge

* heading -> leading

* LogCompletionsCallback compat with online method

* Update Online DPO doc

* check if data is conversational for judges

* update example

* remove comment

* use zip

* fix stats xpo

* Replace judge with PairRMJudge and import AutoModelForSequenceClassification

* update xpo documentation

* Remove doc duplication

* update nash doc

* XPO trl chat

* nash md doc

* HfPairwiseJudge

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-10-24 16:47:10 +02:00
16994738d0 Conversational dataset support for KTOTrainer (#2248)
* `get_batch_sample` -> `generate_from_model[_and_ref]`

* add `num_items_in_batch=None`

* `num_items_in_batch` in `training_step`

* Fix return type hint

* desc for unpair dataset util

* update example

* process in KTO

* Update doc

* KTO  doc rewrite

* fix orpo doc

* add other dataset config names in test

* update doc image

* fix links in doc

* Update reward and log probability metrics in KTOTrainer doc

* skip enc-dec test

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-10-24 14:01:41 +02:00
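With this change `KTOTrainer` accepts conversational (chat-format) examples directly and applies the chat template itself; a toy example of an unpaired, conversational row:

```python
# Unpaired preference example in conversational form: the trainer applies the chat
# template internally, so no manual formatting is needed.
example = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}],
    "completion": [{"role": "assistant", "content": "It is blue."}],
    "label": True,  # desirable completion
}
```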
99225bb6d6 Bump the minimum transformers version to v4.46 (#2245)
* Bump the minimum transformers version

* Bump version in `requirements.txt`

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-10-24 10:42:30 +02:00
88be2c07e5 🚩 setup_chat_format: throw error if there is already a template in base model (#2252)
* setup_chat_format: throw error if there was already a template

* fix lint

* clarify in docs

* fix test?

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-10-22 13:29:32 +02:00
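A quick sketch of the guarded behavior: `setup_chat_format` still adds a ChatML template and its special tokens to a bare base model, but after this PR it errors out if the tokenizer already carries a chat template instead of silently overwriting it.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format

# GPT-2 ships without a chat template, so setup succeeds and adds the ChatML tokens.
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model, tokenizer = setup_chat_format(model, tokenizer)

# Calling it on a model whose tokenizer already defines `chat_template`
# (e.g. an instruct model) now raises instead of clobbering the existing template.
```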
f2349d2af0 Adjust padding in batch generation (#2251)
* pad batch generation

* Use pad utility

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/utils.py

* reshaping

* fix test_utils.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-10-22 09:36:43 +02:00
d843b3dadd Use processing_class instead of tokenizer in LogCompletionsCallback (#2261) 2024-10-22 09:35:04 +02:00
84dab850f6 🧽 Fix typo in dataset format doc (#2259)
doc update
2024-10-21 17:06:19 +02:00
92f6d246d3 🏗️ Refactor DPO data processing (#2209)
* in progress

* refactor concatenated_inputs and concatenated_forward

* progress

* further modif

* padding side

* eos prompt enc dec

* prompt_padding_side

* drop prompt padding side collator

* working on decoder only

* dpo trainer

* Fix loss_mask type conversion bug

* bad attention mask

* try to get the same tokens as main

* fix loss mask

* fix unused col

* added comment

* raise error when padding token not set

* remove private method tests

* initial vlm support

* make it work for paligemma

* minor test updates

* style

* improve readability

* improve doc

* style

* flush left and truncate

* flush left in the code

* fix empty_cols and make max_length optional

* always add eos token

* minor changes and doc

* style

* fix docstring

* preference collator in doc

* fix doc

* optional max_completion_length

* Investigating CI failing

* style

* just dpo trainer test

* just idefics

* paligemma

* llava

* test cli

* dataset in test

* all tests

* Update trl/trainer/dpo_trainer.py

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/dpo_trainer.py

* Update trl/trainer/dpo_trainer.py

* reference to ref

* rich descriptions

* fix logits reporting

* fix truncation

* remove chat template from dpo_vlm

* `get_batch_sample` -> `generate_from_model[_and_ref]`

* add `num_items_in_batch=None`

* `num_items_in_batch` in `training_step`

* Fix return type hint

* test tokenize row

* fix test

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-10-21 12:47:33 +02:00
31b7820aad 🔀 Rename get_batch_sample and add num_items_in_batch to compute_loss (#2246) 2024-10-18 21:02:24 +02:00
b9aa965cce Enhance log report script: add error handling and logging (#2232)
* Update log_example_reports.py

1. Added logging: Imported the logging module and set up a logger in the main function. This allows for better error tracking and debugging.

2. Improved file reading: Used a with statement to ensure the file is properly closed after reading. Also added error handling to catch and log any issues when reading the file.

3. Error handling for Slack SDK import: Added a try-except block to handle cases where the slack_sdk might not be installed.

4. Enhanced Slack message sending: Added error handling and logging for the Slack message sending process. This will help identify any issues with the Slack integration.

* style

* Update log_reports.py

1. Logging: Added logging to track errors and important events.

2. Error Handling: Wrapped the log file processing in a try-except block to handle potential errors gracefully.

3. Logging Total Failed Tests: Added a log statement to report the total number of failed tests

* style

* further improve

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-10-18 19:40:30 +02:00
a67f2143c3 Update SFT examples (#2244) 2024-10-17 14:11:46 +02:00
494b4afa10 [CLI] Setting capture output to False (#2239)
* setting capture output to False

* Update trl/commands/cli.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-10-17 11:04:23 +02:00
02f4e750c0 DPO support remove_unused_columns (#2233) 2024-10-16 10:00:27 +02:00
2ba3005d1c Updated ScriptArguments warning messages (#2230) 2024-10-15 07:46:58 +02:00
7e394b03e8 🎭 Deprecate [SFT/DPO/Reward]ScriptArguments in favour of ScriptArguments (#2145)
* `DPOScriptArguments` to `ScriptArguments`

* use dataset_train_split

* Use scriptarguments

* dataset names in command lines

* use `ScriptArguments` everywhere

* ignore bias buffer to end

* remove in v0.13

* rm comment

* update test commands

* Update docs/source/rloo_trainer.md

* Update tests/test_rloo_trainer.py

* Added dataset_train_split argument to ppo.py and rloo.py

* update scripts with dataset_train_split
2024-10-14 11:14:58 +02:00
14f3613dac Update commands for code linting in contributing guidelines (#2225)
* update commands for code linting in contributing guidelines

* update docs on code formatting in contributing guidelines

* fix markdown rendering error

* Update CONTRIBUTING.md

* Update CONTRIBUTING.md

* Update CONTRIBUTING.md

* Update CONTRIBUTING.md

* Update CONTRIBUTING.md

* "sans" -> "without"

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-10-13 09:22:24 +02:00
5e24101b36 📒 Fix type/format confusions (#2223) 2024-10-11 23:39:19 +02:00
b81a6121c3 Add GKD to dataset_formats.mdx (#2222)
* Update dataset_formats.mdx

* Update dataset_formats.mdx

* Update docs/source/dataset_formats.mdx

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Modified to Prompt-completion

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-10-11 21:52:20 +02:00
7f0d246235 Add Sequence-Level KD (#2220)
* Fix templates for dpo, etc.

* Update dpo.py

Add the fix for the third issue

* make this a utility.

* Add Sequence-Level KD

* add to the docs-strings and the documentation

* reviewed

* Update docs/source/gkd_trainer.md

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-10-11 20:14:09 +02:00
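Sequence-level KD is exposed as a flag on the GKD config; a minimal sketch, with the flag name (`seq_kd`) assumed from the PR:

```python
from trl import GKDConfig

# seq_kd=True roughly means training the student on teacher-generated sequences
# (sequence-level KD) in addition to the token-level divergence; lmbda controls how
# often on-policy student generations are used.
training_args = GKDConfig(output_dir="gkd-model", seq_kd=True, lmbda=0.0)
```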
70036bf87f 🕊️ Migration PPOv2 -> PPO (#2174)
* delete old ppo

* rename ppov2 files

* PPOv2 -> PPO

* rm old doc

* rename ppo doc file

* rm old test

* rename test

* re-add v2 with deprecation

* style

* start update customization

* Lion

* Finish update customization

* remove ppo_multi_adapter

* remove ppo example

* update some doc

* rm test no peft

* rm hello world

* processing class

* Update docs/source/detoxifying_a_lm.mdx

Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com>

* Update trl/trainer/ppov2_config.py

Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com>

* Update docs/source/customization.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/detoxifying_a_lm.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* ppo to example overview

* drop lion

* remove "Use 8-bit optimizer"

* Update docs/source/customization.mdx

* Update docs/source/customization.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* it applies to all trainers

---------

Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-10-11 17:28:39 +02:00
d0aa421e5e Conversational dataset support for ORPOTrainer (#2184)
* default learning rate

* update trainer

* update test

* update script

* update dataset format

* add line in dpo doc

* update orpo doc

* refine implicit/explicit

* update demo chat
2024-10-11 17:08:28 +02:00
5375d71bbd trl env report all cuda devices (#2216) 2024-10-11 16:32:34 +02:00
6004e033a4 Updated README.md with CLI examples and additional usage instructions (#2199)
* Updated README.md with CLI examples and additional usage instructions

Added Command Line Interface (CLI) examples for SFT, DPO, and Chat features.
Improved the "How to Use" section by providing code examples for SFTTrainer and RewardTrainer.
Included installation instructions for both Python Package and source-based installation.
Refined highlights to better showcase efficiency and scalability features.
Updated the repository clone instructions for working with examples.
Added new links to CLI documentation and contribution guide for better navigation.

* Update README.md

* Update README.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update README.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update README.md

* update badges

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-10-11 16:31:38 +02:00
f436c3e1c9 Update README.md (#2180)
* Update README.md

* Update README.md

* Update README.md

Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com>

* Update README.md

* Update README.md

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com>
2024-10-11 16:14:46 +02:00
cd1aa6bdcc [Judges] Soft judges for PairRM (#2221)
* initial soft judges

* add soft-judge to PairRM

* remove comments

* fix from review
2024-10-11 15:53:42 +02:00
b3f93f0bad Report to "none" in GKD test (#2214) 2024-10-10 19:05:55 +02:00
6c32c8bfcd Improve slack reporting (#2182)
* Update log_example_reports.py

1. Added logging: Imported the logging module and set up a logger in the main function. This allows for better error tracking and debugging.

2. Improved file reading: Used a with statement to ensure the file is properly closed after reading. Also added error handling to catch and log any issues when reading the file.

3. Error handling for Slack SDK import: Added a try-except block to handle cases where the slack_sdk might not be installed.

4. Enhanced Slack message sending: Added error handling and logging for the Slack message sending process. This will help identify any issues with the Slack integration.

* style

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-10-10 17:42:06 +02:00
3107a40f16 Update incorrect data processing in DataCollatorForChatML (#2172)
* Update incorrect data processing in DataCollatorForChatML

Fix the extra BOS token and the absence of an EOS token in the returned input_ids, and potentially the absence of a target string in the returned labels.

* Update trl/trainer/utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* style

* move comment

* add test for DataCollatorForChatML

* update comment with more details

* update assert reports and comments, and add verification that the last token of input_ids should be the EOS token

* new line at the end of file for code quality

* Update tests/test_utils.py

* Update tests/test_utils.py

* Update tests/test_utils.py

* update tests

* fix test

* Update tests/test_utils.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update tests/test_utils.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* formatting

* fix typo

* simplify

* Revert "simplify"

This reverts commit 7e4006c87265665183032932ca05dffef567e38b.

* tokenize full messages

* dont add eos

* eos is in the last token

* simplify DataCollatorForChatML

* Update tests/test_utils.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-10-10 12:49:10 +02:00
419791695c Drop decoder_input_ids in DPOTrainer (#2208) 2024-10-10 10:20:40 +02:00
7e5924d17e [GKD] interpolate in prob. space (#2204)
* interpolate in prob. space

* better var names

* use logsumexp

* set beta dtype

* beta tensor
2024-10-09 12:13:18 +02:00
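A rough illustration of the change (not the trainer's exact code): the generalized-JSD mixture is now formed from the two probability distributions rather than their logits, computed stably in log space with `logsumexp`:

```python
import math

import torch


def log_mixture(student_log_probs: torch.Tensor, teacher_log_probs: torch.Tensor, beta: float = 0.5) -> torch.Tensor:
    """log(beta * p_student + (1 - beta) * p_teacher), computed in log space.

    Which distribution receives the beta weight follows the trainer's implementation;
    here it is only a convention for the sketch.
    """
    stacked = torch.stack(
        [student_log_probs + math.log(beta), teacher_log_probs + math.log(1.0 - beta)]
    )
    return torch.logsumexp(stacked, dim=0)
```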
ed9ea74b62 [DPO] Adding weighted preference optimization (WPO) (#2141)
* skeleton

* add weighting arg in config

* formatting

* fix doc

* do not compute gradients in weighting term

* fixed detach

* add WPO doc

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-10-08 19:52:54 +02:00
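WPO reweights each preference pair by the policy's own probability of producing it; it is enabled through a config flag whose name (`use_weighting`) is assumed from the PR:

```python
from trl import DPOConfig

# The weighting term is computed without gradients (see "do not compute gradients in
# weighting term" above) and scales each pair's DPO loss.
training_args = DPOConfig(output_dir="dpo-wpo", use_weighting=True)
```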
511c92c91c Get the aux_loss_coef at BCOTrainer, CPOTrainer, KTOTrainer, and ORPOTrainer initialization (#2201)
* Fix aux_loss coefficient bug of BCOTrainer

* Fix aux_loss coefficient bug of CPOTrainer

* Fix aux_loss coefficient bug of KTOTrainer

* Fix aux_loss coefficient bug of ORPOTrainer
2024-10-08 16:17:09 +02:00
c6cb6353a5 Get the aux_loss_coef at DPOTrainer initialization (#2200) 2024-10-08 16:06:48 +02:00
adb3e0560b ♾️ [CI] Use transformers from source in "tests_no_optional_dep" (#2198) 2024-10-08 12:19:04 +02:00
adf58d80d0 skip_prompt=True in TextIteratorStreamer (#2193)
* skip_prompt in `TextIteratorStreamer`

* Update trl/commands/cli.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update generation streamer in chat.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-10-07 17:38:40 +02:00
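
For reference, `skip_prompt=True` on transformers' `TextIteratorStreamer` makes the streamer yield only the newly generated text, which is what the chat CLI wants. A minimal usage sketch (the model checkpoint is a placeholder):

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "Qwen/Qwen2-0.5B-Instruct"  # placeholder; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Tell me a short joke.", return_tensors="pt")
# skip_prompt=True: don't re-emit the prompt tokens, stream only the completion.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

thread = Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 64})
thread.start()
for new_text in streamer:
    print(new_text, end="", flush=True)
thread.join()
```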
9aa022503c Update README.md (#2186)
* Update README.md

Fix grammatical errors in README.md
fixes issue #2185

Description:

I found a grammatical error in the README.md of the project. This PR fixes the error to improve the overall readability and clarity of the documentation.

Changes:
Corrected grammatical errors
Updated lines to reflect the correct grammar
Reasoning: The original text contained a grammatical error that could confuse readers. This fix ensures that the documentation is accurate and easy to understand.

Closes #2185

* Update README.md

Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com>

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com>
2024-10-07 14:30:00 +02:00
82ad390caf Fix RLOO checkpointing (#2114)
* Fix RLOO checkpointing for transformers>=4.45.0

* Add missing import

* Fix pre-commit issues

* Added test for RLOO checkpointing

* Ensure that tokenizer matches SFT and Reward model

* Pre-commit formatting

* processing class

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-10-07 13:11:17 +02:00
ac038ef03a Update CONTRIBUTING.md (#2181)
* Update CONTRIBUTING.md

* Update CONTRIBUTING.md

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-10-07 06:56:19 -04:00
51ca76b749 [CI] fix dpo gpu ci tests (#2189)
* fix dpo ci test

* color-blind
2024-10-07 10:59:43 +02:00
7005ab4d11 🃏 Model card: "unsloth" tag (#2173) 2024-10-07 10:57:05 +02:00
ffb1ab74ba Update documentation CLI Chat (#2191) 2024-10-07 10:33:51 +02:00
47d08a9626 Rename trainer arg tokenizer to processing_class (#2162) 2024-10-07 09:39:32 +02:00
70327c18e6 add trl to tag for models (#2178) 2024-10-07 08:12:44 +02:00
f05c3fa8fc minor KTO setting changes + KL batch size (#2153)
* add argument for dropout

* increase default lr

* change default lr in examples

* fix bug in calculation of KL batch size

* KL batch size should be args.per_device_train_batch_size

* Update kto_trainer.mdx with hparam recs

* typo

* allow dropout to be disabled

* update lr in sample script

* Update kto_config.py

* Update trl/trainer/kto_trainer.py

* Update docs/source/kto_trainer.mdx

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-10-06 13:13:11 +02:00
4799ba4842 Capybara replaced with ultrafeedback_binarized (#2183) 2024-10-05 18:49:48 +02:00
d45c86e2a7 Conversational dataset support for CPOTrainer (#2144)
* extract prompt and apply chat template in cpo trainer

* default learning rate

* simplify example

* update doc

* test all formats

* extend extract prompt

* improve doc format

* link in dataset formats

* Update docs/source/cpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/cpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-10-04 18:01:02 +02:00
c6b0d1358b 🗑️ Set deprecation version for DPO and SFT arguments to version 0.13 (#2170) 2024-10-04 17:46:55 +02:00
3321084e30 Update trl version in CITATION.cff (#2171) 2024-10-04 12:24:09 +02:00
a9cffc7caf Default dataset_text_field to "text" (#2078)
* clarify ConstantLengthDataset usage

* dont provide dataset text field when formatting func is provided

* kto maybe_apply_chat_template

* default text field

* doc

* remove maybe_apply_chat_template from kto example

* dataset text field always a str

* remove `dataset_text_field="text"`

* update doc
2024-10-04 10:55:47 +02:00
32a928cfc2 🏷️ Model badges in trainer documentation (#2160) 2024-10-04 10:55:06 +02:00
1a3bb372ac Fix typo in error message (#2168)
occured -> occurred
2024-10-04 09:36:52 +02:00
d4564b7c64 ↩️ Revert tokenizer hotfix #2163 2024-10-04 00:14:12 +02:00
1be4d86ccc 🩹 [Hotfix] Add setter for tokenizer (#2163) 2024-10-03 16:13:50 +02:00
78249d9de4 Conversational dataset support for DPOTrainer (#2131)
* conversational dataset support for dpo

* support standard dataset for extract prompt

* test standard dataset for extract prompt

* fix maybe

* fix maybe apply prompt

* style

* overwrite default learning rate of DPO

* style

* rlaif script

* `writer_batch_size` in `train_test_split`

* initial dpo doc refactoring

* vision data section in doc

* lil format modif

* refine Vision datasets

* refine doc

* test new loss type format

* restructure loss function

* table loss type

* simplify `unsloth`

* improve doc

* logged metrics up

* refine loss section

* Fix label_smoothing parameter in DPOConfig

* dataset for test

* update readme

* Update docs/source/dpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* try colorized code block

* refine doc style

* further refine doc

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* re add pali gemma test

* Add missing period

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-10-02 10:04:03 +02:00
5c21de30ae [CI] Don't use eval_strategy="steps" when no eval dataset (#2152)
* `eval_strategy="steps" if eval_dataset else "no"`

* tmp skip test

* drop `eval_strategy` in `test_sft_trainer_uncorrect_data`

* remove eval strategy
2024-10-01 21:46:41 +02:00
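
The one-liner quoted above avoids configuring step-wise evaluation when no eval dataset exists. A sketch of the pattern using transformers' `TrainingArguments` (which the TRL configs inherit from); names are illustrative:

```python
from transformers import TrainingArguments


def make_training_args(output_dir, eval_dataset=None):
    # Only request evaluation when there is actually something to evaluate on;
    # otherwise the Trainer errors out about a missing eval_dataset.
    return TrainingArguments(
        output_dir=output_dir,
        eval_strategy="steps" if eval_dataset is not None else "no",
        eval_steps=100,
    )
```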
0a566f0c58 🩹 Fix attention mask warning in chat CLI (#2147)
* explicit attention mask

* fix chat command
2024-10-01 10:53:18 +02:00
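
The "explicit attention mask" bullet above refers to passing the tokenizer's attention mask to `generate()` instead of letting it be inferred, which silences the transformers warning. A minimal sketch (the checkpoint is a placeholder):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2-0.5B-Instruct"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello!", return_tensors="pt")
# Passing attention_mask explicitly avoids the "attention mask is not set" warning.
output_ids = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=32,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```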
de3876577c [GKD] Set custom EOS tokens in generation config (#2142)
* Expose EOS token IDs in GKD generation

* Apply suggestions from code review

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Revert

* Refactor EOS token setting

* Remove EOS from config

* Refactor

* Add unit test

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-09-30 13:53:16 +02:00
1201aa61b4 rename example (#2139) 2024-09-27 21:45:21 +02:00
c00722ce0a 🃏 Model card for TRL (#2123)
* template and util

* test for online dpo

* template in package_data

* template in manifest

* standardize push_to_hub

* wandb badge and quick start

* bco

* xpo

* simplify `create_model_card`

* cpo

* kto

* dpo

* gkd

* orpo

* style

* nash-md

* alignprop

* bco citation

* citation template

* cpo citation

* ddpo

* fix alignprop

* dpo

* gkd citation

* kto

* online dpo citation

* orpo citation

* citation in utils

* optional citation

* reward

* optional trainer citation

* sft

* remove add_model_tags bco

* Remove unnecessary code for adding model tags

* Fix model tag issue and update URL format

* Remove unused code for adding model tags

* Add citation for XPOTrainer

* Remove unused code in SFTTrainer

* Add model card generation in RLOOTrainer

* Remove unused import and method call in reward_trainer.py

* Add model card generation

* Remove unused code and update error message in ORPOTrainer class

* Add import statements and create model card in IterativeSFTTrainer

* Add dataset name to push_to_hub() call

* Update trainer.push_to_hub() dataset names

* script args

* test

* better doc

* fix tag test

* fix test tag

* Add tags parameter to create_model_card method

* doc

* script args

* Update trl/templates/model_card.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* unittest's `assertIn` instead of `assert`

* Update trl/templates/model_card.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-09-27 15:23:05 +02:00
124189c86a Add correct label for WinRateCallback table (#2134)
Small fix to make it clear in WandB which table is which
2024-09-27 10:33:41 +02:00
d5eeaab462 arXiv to HF papers (#2133) 2024-09-27 09:00:49 +02:00
5368be1e1e 🧹 Style (#2132)
* drop `# flake8: noqa` in examples

* `__init__.py`

* fix init

* unwrap_model_for_generation

* ignore import violation in init
2024-09-26 21:02:48 +02:00
b169e1030d Add table for WinRateCallback (#2116)
* Add table for WinRateCallback

* Fix tests

* Apply suggestions from code review

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Refactor

* Remove super

* Clean

* Clean

* Apply suggestions from code review

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-09-26 19:28:44 +02:00
9af4734178 ♻️ Standardize script_args (#2130) 2024-09-26 15:23:42 +02:00
a0d714949f Tokenize row during in training_step in OnlineDPOTrainer (#2117)
* tokenize while training

* same for nashmd and xpo

* Update trl/trainer/online_dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-09-26 11:58:14 +02:00
a0e28143ec Eos token encouragement Clarification (#2128)
* Update nash_md_trainer.md

* Update online_dpo_trainer.md

* Update xpo_trainer.mdx

* Fixing XPO Script Location
2024-09-26 11:47:48 +02:00
32d9d34eb1 Standardize pushing to Hub in examples (#2126) 2024-09-26 10:00:51 +02:00
fb1b48fdbe Remove max_length from RewardDataCollatorWithPadding (#2119) 2024-09-26 09:59:12 +02:00
b5e4bc5984 Update example_overview.md (#2125) 2024-09-25 20:45:57 +02:00
7a24565d9d Generalizes VSFT script to support REDACTED (#2120)
* generalizes vsft script

* precommit

* change launch command to use accelerate

* updates docs

* rename to sft_vlm

* fix script location

* fix formatting

* comma

* add model link

* fix name

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-09-25 19:54:44 +02:00
44a06fc487 BCOTrainer conversational dataset support (#2107)
* update test

* maybe_apply_chat_template

* simplify bco example

* Update documentation

* Update examples/scripts/bco.py

* Update docs/source/bco_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-09-24 18:15:57 +02:00
a84fc5d815 Fix packing test (#2111)
* Fix pack test

* same for eval
2024-09-24 17:12:54 +02:00
80038a5a92 [online-dpo] allow parse-args as list of floats (#2108)
* use a separate argument for list of floats

* do super first

* fix docstrings

* typos

* use list of floats only

* check if it has len

* fix docstring

* fix suggestion

* fix default

* Update trl/trainer/online_dpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/xpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/nash_md_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/nash_md_config.py

* additional tests

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-09-24 16:56:27 +02:00
cece86b182 fix formatting (#2109)
* fix formatting

* formatting
2024-09-24 16:05:55 +02:00
d005980d8b Fix documentation links (#2105) 2024-09-24 15:35:29 +02:00
cc23b511e4 [RewardTrainer] Tokenize inputs within trainer (#2102)
* Pretokenize in reward modelling

* Fix README example

* Apply suggestions from code review

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Move chat template formatting inside trainer

* Refactor tests

* Fix README

* Disable wandb

* Update readme

* add comment `remove_unused_columns`

* Update trl/trainer/reward_config.py

* doc

* implicit*

* explicit

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-09-24 13:03:32 +02:00
2cad48d511 [CLI] trl env for printing system info (#2104) 2024-09-24 09:57:24 +02:00
6859e048da Fix PPO/RLOO examples (#2100) 2024-09-23 11:49:36 +02:00
92eea1f239 Clean up README and remove openrlbenchmark dependency (#2085)
* Clean up README

* Add Kashif and Quentin

* Refactor

* Apply suggestions from code review

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Add citation

* Omit benchmarks from dev install

* Remove openrlbenchmark

* Apply suggestions from code review

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-09-23 09:21:41 +02:00
663002f609 KTO: fix logits metric, add logits metric to BCOTrainer too (#2094)
Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
2024-09-21 19:08:10 +02:00
44d998b2af Fix _process_tokens for empty prompts in KTOTrainer (#2093)
The function _process_tokens in trl/trainer/kto_trainer.py crashes if the prompt_input_ids are an empty list.
- added a check for nonzero length
- added a check for nonzero length of answer_input_ids for consistency

The checks happen when determining whether to subtract 1 from max_length (which happens when BOS or EOS is already present).
2024-09-21 12:49:54 +02:00
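
A sketch of the guard described above, with illustrative names (not `KTOTrainer`'s exact code): only inspect the first/last token when the token lists are non-empty before deciding whether to reserve room for BOS/EOS.

```python
def adjusted_max_length(max_length, prompt_input_ids, answer_input_ids,
                        bos_token_id, eos_token_id):
    # Guard against empty prompts/answers before indexing into the lists.
    if len(prompt_input_ids) > 0 and prompt_input_ids[0] == bos_token_id:
        max_length -= 1  # BOS already present
    if len(answer_input_ids) > 0 and answer_input_ids[-1] == eos_token_id:
        max_length -= 1  # EOS already present
    return max_length


print(adjusted_max_length(512, [], [42], bos_token_id=1, eos_token_id=2))  # no crash on an empty prompt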
9b80f3d50c fix: device could be in meta, transformers#33154 (#2089)
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
2024-09-21 09:11:34 +02:00
2038e52c30 Fix typo in orpo example. (#2092) 2024-09-21 09:11:01 +02:00
10c2f63b2a training_args for all TrainingArguments (#2082) 2024-09-19 15:03:47 +02:00
9fb871f62f [SFT] fix neftune_noise_alpha in SFTTrainer (#1841)
* fix neftune_noise_alpha

* del neftune_noise_alpha first

* check len after removing handle

* make sure we do not load twice

* Update trl/trainer/sft_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* remove neftune from SFTTrainer as the superclass has it now

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-09-19 11:57:36 +02:00
3cec013a20 Bump dev version 2024-09-19 10:47:21 +02:00
cc80ac6b47 Fix DeepSpeed for PPOv2Trainer.save (#2080) 2024-09-19 09:29:57 +02:00
4c0c98d950 Standardize dataset naming (#2081)
* `ds`, `raw_dataset` etc -> `dataset`

* Update docs/source/detoxifying_a_lm.mdx
2024-09-19 08:59:28 +02:00
0d2bee51aa [WIP] Fix logits/chosen and logits/rejected metrics in KTOTrainer (#2077)
* fix metrics

* fix formatting

* fix "#" sign
2024-09-18 21:09:21 +02:00
6920c2d1bb Conversational dataset support for Online DPO (#2075)
* first modifications in the documentation

* Add script for processing ultrafeedback prompt dataset

* Remove unused variable in ultrafeedback.py

* style

* apply chat template within the init

* extend test

* new default lr

* nash md and xpo conv test

* Update prompt length check to 512 characters

* remove `maybe_apply_chat_template` in XPO and Nash examples

* polish online dpo doc

* better section name

* LogCompletionsCallback doc

* optional generation config

* reorder stats (consistency with online dpo)

* update online dpo doc

* format online dpo config

* format nash_md config

* update nash md

* Nash MD -> Nash-MD

* xpo doc

* doc
2024-09-18 14:10:38 +02:00
4d8267610f Use wrapped model for reference completions in WinRateCallback and set default freq to eval_steps in `LogCompletionsCallback` (#2074)
* Use wrapped model for reference completions

* Add unit test for LoRA

* Apply suggestions from code review

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Fix quality

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-09-18 13:55:49 +02:00
c3143832cb processor(prompt, images=image) to processor(images=image, text=prompt) (#2076)
* `prompt, images=image` to `images=image, text=prompt`

* special case of model being str in BCO
2024-09-17 12:09:16 +02:00
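
The change above switches vision-language processor calls from a positional prompt to explicit keyword arguments. A minimal sketch with a placeholder checkpoint and a dummy image:

```python
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")  # placeholder checkpoint
image = Image.new("RGB", (224, 224))
prompt = "USER: <image>\nWhat is in this image? ASSISTANT:"

# Before: processor(prompt, images=image)  -- relies on the positional text argument
# After:  explicit keywords, matching newer processor signatures
batch = processor(images=image, text=prompt, return_tensors="pt")
print(batch["input_ids"].shape)
```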
e74dbf2d6a Added error when ref_model and model have same id (#2057)
* Added error check to RLOO, PPOv2, OnlineDPO that ref_policy and policy should have different identities.

* Update online_dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* style

* extend to other trainers

* bco as well

* case models are strings

* add tests

* style

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-09-17 10:48:32 +02:00
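
A sketch of the identity check described above (the message text is illustrative, not TRL's exact wording):

```python
def validate_ref_model(model, ref_model):
    # The policy and its reference must be distinct objects, otherwise
    # optimizing the policy silently changes the "frozen" reference too.
    if ref_model is not None and ref_model is model:
        raise ValueError(
            "model and ref_model must be different instances; create a separate "
            "copy of the model (or pass ref_model=None) for the reference."
        )
```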
41fe228654 Minor doc fixes and comments (#2073)
* Sort toctree

* rm trainer.mdx

* add missing `

* comment

* online dpo
2024-09-16 16:42:22 +02:00
07f0e687cb Use transformers utilities when possible (#2064)
* use transformers' availability functions

* require from transformers

* rm file

* fix no peft

* fix import

* don't alter  _peft_available

* fix require_diffusers

* style

* transformers>=4.40 and add back `is_liger_kernel_available`
2024-09-16 15:56:49 +02:00
dc2bd07408 Nash md (#1853)
* initial skeleton

* initial config and class

* move TrainerCallback to callbacks.py

* initial trainer mockup

* formatting

* add back header

* script with reward model

* call ref policy forward with torch no_grad

* fix api

* clean up the configs

* use the new API

* fix typo

* get get_reward without grads

* remove unused no_grad calls

* fix formatting

* initial GeometricMixtureWrapper

* Update trl/models/modeling_base.py

Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>

* undo changes to callback

* GenerationMixin needs generation_config

* calculate score with model and mixture model outputs

* fix scores and mixture_scores tensors

* undo

* use interleaved version to calculate chosen-rejected

* Revert "use interleaved version to calcuate chosen-rejected"

This reverts commit 4a63a60971a7db173d10771548f17f650d955c2a.

* fix mixture scores

* Fix global step

* use mixture_coeff

* record scores_margin only

* fix del

* First version of Nash MD trainer

* undo

* fix formatting

* fix toc

* initial refactoring

* mixin fixes

* fix refactoring

* cleanup comments

* add log_stats

* add test

* initial docs

* fix logs

* fix missing_eos_penalty

* fix output_dir

* add peft_config to docs and super

* undo init changes

* Update docs/source/_toctree.yml

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/nash_md_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* add dataset format

* add authors

* add dynamic parameter callback

* update test

* fix comments

* test GeometricMixtureWrapper

* header

* formatting

* formatting

* add paper and abstract

* Update docs/source/nash_md_trainer.md

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* DynamicParameterCallback

* drop callback in favor of getter

* revert kto config change

* revert kto config change

* fix contribution

* `coeff` to `coef`

* log dynamic coefs

* Update docs/source/nash_md_trainer.md

* Update docs/source/nash_md_trainer.md

* fix tests

* use self.ref_model

* one-line

---------

Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Co-authored-by: Daniil Tiapkin <daniil.tiapkin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-09-16 13:46:52 +02:00
cdafc9333c [KTO] Overrides default learning_rate in KTOConfig (#2070)
* learning rate recomentations for kto

* update from suggestion

* override default lr

* add tip tag

* Update trl/trainer/kto_config.py

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-09-16 12:24:43 +02:00
40f05226de Standardizing datasets for testing (#2065)
* zen dataset

* Update dataset test bco

* some tests

* Simple chat template

* bco

* xpo

* kto

* gkd

* trainer_args

* sft

* online dpo

* orpo

* zen script
2024-09-14 22:34:15 +02:00
f6c664301d remove min_new_tokens=args.max_new_tokens (#2069) 2024-09-14 19:37:12 +02:00
08ba866c86 Fix dataset in GKD script (#2067)
I added the wrong dataset name in a prior commit 🙈
2024-09-14 12:29:13 +02:00
ebc85b2e39 PEFT support for Online DPO (#2041)
* Promote `PPOv2Trainer` and `PPOv2Config` to top-level import

* Deprecate `PPOTrainer` and `PPOConfig`

* changes

* Revert "Promote `PPOv2Trainer` and `PPOv2Config` to top-level import"

This reverts commit 96ae02a54154acd2c5c3cc873af3519fedd33d0b.

* Revert "Deprecate `PPOTrainer` and `PPOConfig`"

This reverts commit 65990deb81df1dcaeb2245f01582e8bb45511335.

* peft

* peft

* try to simplify

* revert utils changes

* update dpo script

* peft

* style

* revert gitignore

* test_online_dpo_peft

* ref model

* peft example command

* typo

* remove param.requires_grad = False for the reward model

* make `model` required arg

* update example script

* update xpo trainer

* Update examples/scripts/dpo_online.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/dpo_online.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* merge and unload

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-09-13 19:15:18 +02:00
88bede66fc Standardise API for WinRateCallback and LogCompletionsCallback (#2061)
* Use wrapped model

* Make WinRateCallback work

* Make LogCompletions work

* Make LogCompletions work

* Fix scripts

* Fix path

* Refactor

* Remove padding

* Refactor

* Fix docs

* Fix scripts

* Fix TLDR template

* Use explicit args

* Fix callback import

* Add docstring
2024-09-13 17:38:42 +02:00
7a2bbe3957 Shuffle examples before they are packed (#2037)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-09-13 14:23:24 +02:00
d47220f299 make cuda-only tests device-agnostic (#2044)
* update code

* update

* fix style

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-09-13 14:23:12 +02:00
d8324924c8 Support for SFTTrainer.evaluate() and SFTTrainer.predict() with null train_dataset (#2004)
* add null train_dataset check

* Fix pre-commit errors

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-09-13 14:22:43 +02:00
4c92ba5769 ©️ Copyrights (#2063)
* copyrights

* fail if missing
2024-09-13 14:18:47 +02:00
a5b98fcf97 Mask loss in gkd when generating from the student (#2058)
* mask loss in gkd

* fix minor issue in test

* Update tests/test_gkd_trainer.py

* fixing masking issues

* Update tests/test_gkd_trainer.py

* Update tests/test_gkd_trainer.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-09-13 11:30:59 +02:00
e51a5ac985 Add missing autodocs (#2056) 2024-09-11 21:54:28 +02:00
31b93876a7 📝 Document dataset format (#2020)
* first piece of doc

* improve readability

* some data utils and doc

* simplify prompt-only

* format

* fix path data utils

* fix example format

* simplify

* tests

* prompt-completion

* update antropic hh

* update dataset script

* implicit prompt

* additional content

* `maybe_reformat_dpo_to_kto` -> `unpair_preference_dataset`

* Preference dataset with implicit prompt

* unpair preference dataset tests

* documentation

* ...

* doc

* changes applied to dpo example

* better doc and better log error

* a bit more doc

* improve doc

* converting

* some subsections

* converting section

* further refinements

* tldr

* tldr preference

* rename

* lm-human-preferences-sentiment

* `imdb` to `stanfordnlp/imdb`

* Add script for LM human preferences descriptiveness

* Remove sentiment_descriptiveness.py script

* style

* example judge tldr with new dataset

* Style

* Dataset conversion for TRL compatibility

* further refinements

* trainers in doc

* top level for functions

* stanfordnlp/imdb

* downgrade transformers

* temp reduction of tests

* next commit

* next commit

* additional content

* proper tick format

* specify the assistant start token

* improve

* lower case

* Update titles in _toctree.yml and data_utils.mdx

* revert make change

* correct dataset ids

* expand a bit dataset formats

* skip gated repo tests

* data utilities in API

* Update docs/source/dataset_formats.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/dataset_formats.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/dataset_formats.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/dataset_formats.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* tiny internal testing for chat template testing

* specify type/format

* exclude sft trainer in doc

* Update trl/trainer/utils.py

* XPO in the doc

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-09-11 20:11:25 +02:00
85696aa64c Gkd trainer (#1814)
* initial

* initial gkd script

* fix output dir name

* smaller max_new_tokens_response size

* fix tab

* use temperature from config

* initial docs

* initial test

* add generalized_jsd_loss

* some docs

* fix order of interpolation

* use log_target=True

* fix formatting

* docstrings

* add peft example

* more docs

* formatting

* fix ordering

* use unwrap_model_for_generation

* initial DataCollatorForLastCompletionLM

* add generation inputs

* logits from the completions

* add eps to probs

* select the logits after removing the padding

* formatting

* interpolate log_probs

* add back online sampling

* update tests

* fix typos

* Update docs/source/gkd_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/gkd_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/gkd_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/gkd_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/gkd_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/gkd_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/_toctree.yml

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* use Qwen2

* Update trl/trainer/gkd_config.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/gkd_config.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update tests/test_gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* fixes

* renamed lamda to lmbda due to keyword

* fix config name

* move collator to utils

* fix formatting

* Update trl/trainer/gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* the larger the lmbda the more on policy it should be

* Use JSD instead of KL

* use DataCollatorForChatML

* fix labels

* use torch_call

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* set default collator to DataCollatorForChatML

* return only the prompts

* fix labels of generated outputs

* formatting

* fix comment

* add missing _prepare_deepspeed

* no attention mask when generating

* update test

* set a sensible max_seq_length

* set default in the collator

* Update tests/test_gkd_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update tests/test_gkd_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* fix padding

* formatting

* Update tests/test_gkd_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* fix tests

* TestGeneralizedJSDLoss

* fix typos

* use a mask to calculate jsd loss

* use the super() training_step after the inputs are created

* fix the docs

* create generate_on_policy_outputs

* loss does not need labels

* use_cache is false when gradient checkpointing is True

* use self.assert

* fix toc

* generate_on_policy_outputs needs token_id

* use papers link

* teacher_model is in eval mode so no need for disabling dropout

* log completions and use_liger

* prompt from train if no eval

* fix logging and add cache empty

* add_generation_prompt=True

* fix prompts

* Update docs/source/gkd_trainer.md

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update docs/source/gkd_trainer.md

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update docs/source/gkd_trainer.md

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update examples/scripts/gkd.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* minor doc changes

* fix temp default

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* update docs

* fix dataset format

* fix dataset format

* no need for scores in generation

* teacher_model_init_kwargs

* Update _toctree.yml

* Update docs/source/gkd_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update tests/test_gkd_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/gkd_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/gkd.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* fix

* remove rich

* add deterministic test

* fix code

* use bigger teacher model

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-09-11 19:16:59 +02:00
642c4b1855 Remove debug and sanity_check args (#2055) 2024-09-11 17:56:02 +02:00
9a6061fc2f Clean up DPO example (#2043)
* Clean up DPO example

* Fix bs

* Remove rentrant

* Fix tests

* Nuke sanity checks

* Switch dataset

* Remove sanity check from XPO
2024-09-11 17:45:00 +02:00
a8fd6dcd17 Remove RichProgressCallback from examples (#2053)
* Disable RichProgressCallback by default in examples

* Nuke rich

* Clean
2024-09-11 16:51:05 +02:00
e2966c8d99 Integrate OrpoTrainer with PyTorchXLA for faster step time on TPUs (#2001)
* make Orpotrainer run faster on tpu

* less data transfer

* train-trl.py

* fix

* set device_map=auto

* add is_torch_xla_available guards

* delete file

* address comments

* make presubmit

* Update transformer version in setup.py

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-09-11 15:11:28 +02:00
37934d70a9 Windows back in CI (#2051)
* Revert "Temporary pin the transformers hash in the CI (#2049)"

This reverts commit f8cf88ab6573699a1a49420f859fdf6aa2f10326.

* Update commit

* Apply suggestions from code review

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-09-11 14:24:07 +02:00
9c043e596b Fix logits computation in KTO trainer prediction step (#2050)
* Fix logits computation in KTO trainer prediction step

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-09-11 13:31:42 +02:00
a20e822737 Deprecate PPOTrainer (#2016)
* Promote `PPOv2Trainer` and `PPOv2Config` to top-level import

* Deprecate `PPOTrainer` and `PPOConfig`

* FutureWarning

* Update trl/trainer/ppo_config.py
2024-09-10 19:04:29 +02:00
3511856767 [XPO] xpo trainer (#1943)
* initial xpo trainer

* compute rewards and ref log probs in smaller batches

* add logging

* initial log docs

* fix global_step increment

* fix metric descriptions

* use messages API

* use training_step API

* fix logs

* add test

* add back max_new_tokens

* use max_new_tokens

* refactor

* top_k is an int

* fix formatting

* fix the loss

* fix logging

* fix logging

* fix logging

* fix loss

* calculate pi_log_ratio once

* fix stats

* fix loss

* do not log loss again

* fix docs

* add disable_dropout_in_model via flag

* comments

* revert doc change

* rm empty cache in online dpo

* improve doc xpo config

* some comment

* fix logging stats

* fix docs

* save the model

* fix model and reward model

* Update trl/trainer/xpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-09-10 16:08:30 +02:00
f8cf88ab65 Temporary pin the transformers hash in the CI (#2049)
* tmp ci fix

* Update .github/workflows/tests-main.yml

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update .github/workflows/tests.yml

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update .github/workflows/tests-main.yml

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-09-10 16:01:28 +02:00
2ee0b62cdb Change non_eos_penalty to missing_eos_penalty to be consistent across OnPolicy trainers (#2033)
* Subtract a penalty from OnPolicy Trainers if output does not contain an EOS token

* Caught a few other problems

* Updated the documentation for RLOO trainer and PPOv2Trainer

* Corrected the default type and value for missing_eos_penalty

* Made RLOO Trainer consistent with Online DPO and PPOv2

* Removed --non_eos_penalty from all documentation

* Made missing_eos_penalty examples positive (because we subtract).

* Caught two more incorrect examples

* Removed unnecessary whitespace to make ruff happy

* Update trl/trainer/utils.py

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-09-10 14:40:23 +02:00
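
For context on the rename above: the on-policy trainers subtract a positive missing_eos_penalty from the score of any completion that never emits the EOS token. A hedged sketch of that behaviour (tensor shapes and names are illustrative):

```python
import torch


def penalize_missing_eos(scores, completion_ids, eos_token_id, missing_eos_penalty=1.0):
    # scores: (batch,) rewards; completion_ids: (batch, seq_len) generated tokens.
    contains_eos = (completion_ids == eos_token_id).any(dim=-1)
    scores = scores.clone()
    scores[~contains_eos] -= missing_eos_penalty  # the penalty is positive because we subtract it
    return scores


scores = torch.tensor([1.0, 2.0])
completions = torch.tensor([[5, 6, 2], [5, 6, 7]])  # EOS id = 2; the second row never stops
print(penalize_missing_eos(scores, completions, eos_token_id=2))  # tensor([1., 1.])
```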
ac071d6225 Drop canonical dataset namespaces (#2048)
* drop canonical

* Delete ultrafeedback_prompt_only.py dataset script

* reduce diff in best_of_n

* try to revert best_of_n to make github happy

* anyway...
2024-09-10 12:12:00 +02:00
72f19c3fce fix: unpacking error in Custom Mixture of Experts model when aux_loss_enabled is set to True. (#2039)
* fix: prevent unpacking error due to additional **aux_loss** returned by **concatenated_forward** function when **aux_loss_enabled** is set to True.

* Refactor: Simplify tuple unpacking in `concatenated_forward` call in `get_batch_loss_metrics` function

* Refactor: improve code quality
2024-09-09 11:47:54 +02:00
8d7b54d4bf Fix packing doc in SFTConfig and fix error when neither dataset_text_field nor formatting_func is provided. (#2035)
* fix dataset and value error in sft

* Update trl/trainer/sft_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* move the test to the right place

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-09-09 11:39:37 +02:00
a638f73f5c Improves formatting of docstring + newlines (#2006)
* Improves formatting of docstring + newlines

* Linting fix

* Update utils.py

* Set to "Parameters" in config files

* some fixes

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-09-09 10:26:46 +02:00
8a518ee619 Remove unused functions (#2017) 2024-09-08 14:05:46 +02:00
7a67de3c1c Fix docs formatting of ˋ\timesˋ sign in ˋkto_trainer.mdxˋ (#2031)
* correct formatting of star sign in kto_trainer.mdx

The "*" symbol in markdown doesn't show. I changed it to $\times$ so the mathematical formula is clearer

* fix markdown

* one more try
2024-09-08 11:54:04 +02:00
3412f513f2 Refactor reward modelling script to work with chat models (#2026)
* Make Qwen2 work

* Make it work

* Refactor

* Add doc

* Add dataset

* Fix

* Quality
2024-09-06 13:12:38 +02:00
fc20db8873 Clean configs documentation (#1944)
* Clean BCO

* Optional[int]

* fix sft config

* alignprop config

* update tempfile to work with output_dir

* clean kto config

* intro docstring

* style

* reward config

* orpo config

* warning in trainer, not in config

* cpo config

* ppo v2

* model config

* ddpo and per_device_train_batch_size (instead of train_batch_size)

* rloo

* Online config

* tmp_dir in test_ddpo

* style

* remove to_dict and fix post-init

* batch size in test ddpo

* dpo

* style

* `Args` -> `Parameters`

* parameters

* ppo config

* dont overwrite world size

* style

* outputdir in test ppo

* output dir in ppo config

* revert non-core change (1/n)

* revert non-core changes (2/n)

* revert non-core change (3/n)

* uniform max_length

* fix uniform max_length

* beta uniform

* style

* link to `ConstantLengthDataset`

* uniform `dataset_num_proc`

* uniform `disable_dropout`

* `eval_packing` doc

* try latex and α in doc

* try title first

* doesn't work

* reorganize doc

* overview

* better latex

* is_encoder_decoder uniform

* proper ticks

* fix latex

* uniform generate_during_eval

* uniform truncation_mode

* ref_model_mixup_alpha

* ref_model_mixup_alpha and ref_model_sync_steps

* Uniform  `model_init_kwargs` and `ref_model_init_kwargs`

* rpo_alpha

* Update maximum length argument names in config files

* Update loss_type descriptions in config files

* Update max_target_length to max_completion_length in CPOConfig and CPOTrainer

* Update padding value in config files

* Update precompute_ref_log_probs flag documentation

* Fix typos and update comments in dpo_config.py and sft_config.py

* post init warning for `max_target_length`
2024-09-04 10:07:49 +02:00
7acb9c2319 Feat: Add support for APO-zero in KTOTrainer (#1952)
* feat : add kto command

* feat : add support for apo loss in KTO Trainer

* feat : make kto script compatible with dpo-formatted datasets

* fix: lint data utils

* add loss_type in kto test

* fix: data utils docstrings

* fix: add dataset reformat test

* fix: lint tests

* fix: only reference kl_logps if needed

---------

Co-authored-by: Karel D'Oosterlinck <karel@contextual.ai>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-09-04 09:31:46 +02:00
684038057e Allow WinRateCallback to be used without reference model (#2013)
* tests

* make ref model optional

* style

* remove attribute error
2024-09-04 00:05:05 +02:00
1f6a1d2f9a Remove prompts arg from WinrateCallback (#2010)
* rm prompts and add doc

* proper judge type and doc

* test for callback

* style
2024-09-03 17:24:08 +02:00
d60a1f50fe [ci] pin numpy to < 2 on win (#2009) 2024-09-03 13:03:38 +02:00
728a9a3b5f [Docs] Add Liger-Kernel usage to SFTTrainer page (#2007)
* Add Liger-Kernel usage in SFTTrainer

* initial commit

* update flaws

* fix flaws

* Update sft_trainer.mdx

* Update docs/source/sft_trainer.mdx

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Update docs/source/sft_trainer.mdx

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Update docs/source/sft_trainer.mdx

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Update sft_trainer.mdx

---------

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
2024-09-03 08:40:58 +02:00
850ddcf598 [pre-commit] update pre-commit yaml (#2002)
* update pre-commit yaml

* fix test

* use element_type
2024-09-02 19:15:25 +02:00
d57e4b7265 [Online-DPO] fixes to the training scripts and setup.py (#1997)
* fixes

* fixed typo

* add tests for liger

* fix imports

* class name
2024-08-30 22:05:14 +02:00
11f442fc05 move slow-tests CI to new cluster (#1996) 2024-08-30 12:29:21 +02:00
437e8ccaba Bump dev version 2024-08-29 14:39:18 +00:00
4dd0dc2988 Adds experimental Liger support to SFT script (#1992)
* adds cli and import utils

* updates SFT script

* adds liger model to trainer

* adds liger nightly dep

* precommit

* fix import

* Update trl/commands/cli_utils.py

* Fix quality

* moved use_liger arg to sft config

* remove arg

* remove use liger from sft trainer

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-08-29 14:48:35 +02:00
4f59e923ac Relax numpy upper bound and bump deepspeed version (#1990)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-08-29 13:17:48 +02:00
10f70fa333 Add ignore_index in DPOTrainer's nn.CrossEntropyLoss (#1987)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-08-28 16:41:41 +02:00
47ab034ca9 [DPO] tokenize and process DPO data via batches (#1914)
* tokenize and process DPO data via batches

* use helpers

* updated _process_tokens

* fixed

* incorporate build_tokenized_answer in the _tokenizer

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* fix tokenizer for is_vision_model

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* give the _tokenize the tokenizer as well as optional processor

* fix tests

* add bos and eos tokens

* add prompt_pixel_attention_mask

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* truncate by max_length

* formatting

* fix for enc-dec

* For encoder-decoder models, we need to use the prepared decoder_input_ids

* add tests for _build_tokenized_answer and _tokenize_feature

* check for EOS and BOS tokens

* formatting

* do not include pixel mask if they are not provided

* undo refactor

* undo add_bos_token_if_needed change

* refactor tokenizer into smaller helpers

* add back comments

* fix type hints

* format

* fix t5 tests

* args are never optional

* move cat to appropriate helper

* fix _truncate_tokens

* add tests for _truncate_tokens

* remove dead code

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-08-28 16:14:53 +02:00
e755eee660 Refactor Online DPO (#1839)
* online dpo trainer based on rloo trainer

* push changes

* refactor

* use `batch_generation` method

* precommit

* remove breakpoint()

* quick refactor

* push the current changes

* quick change

* refactor

* use the config name as the experiment name

* fix logging

* update online DPO docs

* use llm as a judge

* quick change

* quick fix

* cache changes

* new semantics

* style and arg order change

* rm duplicated num_epochs

* rm plot script

* num_epoch

* revert some changes

* revert changes

* revert whitespace

* rm whitespace

* revert change

* policy->model

* optional judge and reward model

* cleaning online dpo script

* warning when both reward model and judge provided

* return -1 when the judge fails

* dataset num proc

* add judges in online dpo; fix collate and process within the trainer

* lr_scheduler.step() after optimizer step

* update odpo test

* reduce nesting

* allow pickle

* generation config typing

* online dpo llm judge

* fix data collator pad token

* add space

* fix pref score

* -1 for judges

* self.model_wrapped = self.model

* onlinedpo inherits from training arguments

* num_epoch -> num_steps_in_epochs

* update -> epoch

* epoch -> step; step_in_epoch -> ppo_epoch; rm run_name

* num_steps_in_epoch -> num_ppo_epochs

* epoch_idx -> ppo_epoch_idx

* make init consistent with dpo

* try another option

* progress...

* odpo

* current progress

* log and other changes

* rename for legacy

* rename for legacy

* rename and move truncate

* rename

* new config

* LogCompletionsCallback

* style

* rename trainer

* truncate right in utils

* update example

* reward model path

* properly log

* fix example

* add generation prompt and log special tokens

* true penalty

* defaults from the paper

* Remove MPS (#1983)

* Set KV cache false when gradient checkpointing is enabled (#1984)

* Remove MPS

* Fix

* Various tweaks

* Remove padding from table

* Clean up

* Fix test

* Revert log freq

* Fix docs

* Fix tests again!

* Fix typo

* Revert

* Fix regression

* Apply suggestions from code review

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Fix DPO config test

* Fix doc tree

* Clean docs moar

* Add docstring

* raise NotImplemented error for judge

* Refactor cache clearing

---------

Co-authored-by: Michael Noukhovitch <mnoukhov@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-08-28 15:39:51 +02:00
ac31d1205e Skip the failing Online DPO test (#1989)
* Harmonisation of tests between main and PR

* disable tqdm

* skip the test

* `"Programming Language :: Python :: 3.11"` and drop 3.7

* Update .github/workflows/tests.yml

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update setup.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update .github/workflows/tests-main.yml

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-08-28 14:55:18 +02:00
c44ab6d1e9 torch.load with weights_only=True (#1988)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-28 11:13:22 +02:00
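
For reference, the safer loading call introduced above looks like this (the checkpoint path is a placeholder):

```python
import torch

# weights_only=True restricts unpickling to tensors and basic containers,
# which avoids executing arbitrary code embedded in untrusted checkpoints.
state_dict = torch.load("checkpoint.pt", map_location="cpu", weights_only=True)
```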
a15a80e0d5 gather the target model params as well (#1978) 2024-08-28 09:27:26 +02:00
264f1279fd Promote PairRMJudge to top-level import (#1985)
* allow `from trl import PairRMJudge`

* test_pair_rm_judge

* Update setup.py

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-27 21:04:05 +02:00
0cda2f2f01 Restore test (#1982) 2024-08-27 11:16:32 +02:00
e0ff66103e Update tests for _get_kl_dataset (#1974)
* Test for #1970

* style

* drop last element in the batch for test

* check prompt_input_ids not modified

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-27 11:00:43 +02:00
3a3ed88f28 Fix dataset_num_proc missing in PPOConfig (#1966)
* fix a few minor bugs in ppo.py

* dataset_num_proc as training arg

* num proc in config

* Update examples/scripts/ppo.py

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-08-27 10:59:45 +02:00
b65657f41d Fix flaky Hub tests (#1981)
* Fix flaky Hub tests

* Trigger Build

* test build
2024-08-27 10:14:39 +02:00
de024ece28 Use weights_only for load (#1933)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-08-26 18:18:38 +02:00
2fbc0f4fc2 Fix issue template path (#1973)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-26 14:40:37 +02:00
cf5168ea7c New mismatch pair creation strategy (#1970)
* new mismatch pair creation strategy

* style

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-26 13:29:22 +02:00
1e4fb80cbc Fix issue with unnecessary cached during logp calc. (#1969) 2024-08-26 12:38:58 +02:00
fe41acd6ae add arg padding_free to DataCollatorForCompletionOnlyLM (#1887)
* add arg `padding_free` to DataCollatorForCompletionOnlyLM

* Update tests/test_data_collator_completion_only.py

* Update trl/trainer/utils.py

* Update tests/test_data_collator_completion_only.py

* Update tests/test_data_collator_completion_only.py

* Update tests/test_data_collator_completion_only.py

* Update tests/test_data_collator_completion_only.py

* Update test_data_collator_completion_only.py

* Update tests/test_data_collator_completion_only.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
2024-08-26 09:48:39 +02:00
c71262c9c6 Fix issue with precompute_ref_log_probs not working when rpo_alpha is None (#1961)
* Fix issue with precompute_ref_log_probs not working when rpo_alpha is None

* Test: Add test for precompute_ref_log_probs with rpo_alpha=None

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-08-25 12:15:57 +02:00
dcee683d96 Add issue/PR templates, code of conduct & better contributing guide (#1963)
* Add issue/PR templates, code of conduct & better contributing guide

* Apply suggestions from code review

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-08-23 23:12:40 +02:00
4788e5cda5 Support LLaVA-NeXT in Vision SFT (#1959)
* support llava next

* mention version for llava-next

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-23 11:37:40 +02:00
6cea2ef964 [ODPO] Refactor training script to use messages API (#1958)
* Refactor dataset prep

* Add moar doc
2024-08-22 20:03:12 +02:00
64d9816eac Fix response truncation in examples/notebooks/gpt2-sentiment.ipynb (#1957) 2024-08-22 16:22:46 +02:00
67564fdbbe "help wanted" in label to exempt from stale (#1956)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-22 11:27:37 +02:00
e529579232 Fix global step for consistent checkpointing with global updates (#1950) 2024-08-21 10:19:37 +02:00
dc4cfab700 Log WandB tables on main process (#1951) 2024-08-20 16:42:51 +02:00
66d3a82dd2 Add a simple-to-understand example for online DPO (#1947)
* Update online_dpo_trainer.md

* Update docs/source/online_dpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/online_dpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/online_dpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/online_dpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update online_dpo_trainer.md

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-08-20 16:14:40 +02:00
3eda856371 Don't mark issues as stale if nobody answered (#1949)
* don't mark issues as stale if nobody answered

* refactor

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-20 15:13:40 +02:00
616a273ac2 Fix model wrapping for online DPO (#1946) 2024-08-19 18:17:11 +02:00
9955583829 Drop token arg in push_to_hub (#1945)
* Skip token in `push_to_hub`

* fix doc

* move comment

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-19 11:34:11 +02:00
bed205a2d2 Properly tag models when pushed to 🤗 Hub (#1940)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-18 11:16:27 +02:00
42933fa647 Optional Additional Loss to Center Reward Models' Outputs (#1932)
* Implemented Eisenstein reward model centering

* Forgot self in accessing args

* Added docstring for center_rewards_coefficient.

* Fixed bug.

* Update trl/trainer/reward_config.py

Added a reference.

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Switched to Quentin's suggestion

* Update trl/trainer/reward_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* doc

* 0.01

* style

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-17 22:44:03 +02:00
bbdef00961 Fix model to save in PPOv2 (#1776)
* fix model to save in ppov2

currently saving self.backup_model, but this should be self.model.
self.backup_model is only a temp model used to store the policy and
value function, whereas self.model should have just the policy to save.

* simplified logic

* remove unused ordereddict

* format

* fix the fix

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-17 17:47:01 +02:00
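A sketch of the fix described in #1776, keeping only the attribute names quoted in the message above; the surrounding class and method layout are hypothetical:

```python
class PPOv2TrainerSketch:
    def save_model(self, output_dir: str):
        # `self.backup_model` only temporarily holds the policy plus the value
        # function during training; what should be persisted is `self.model`,
        # i.e. the policy alone.
        self.model.save_pretrained(output_dir)  # previously saved self.backup_model
```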
0956dc17cc Add tests for DPO for VLM (#1935)
* add dpo visual test

* skip last layer of llava in test

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-16 16:29:40 +02:00
a7dc892717 Anchored preference optimization loss for DPO (#1928)
* feat: anchored pref optimization

* Update trl/trainer/dpo_trainer.py

* format and properly deprecate loss_type

* add aot in error message and reorder

* add "sppo_hard", "nca_pair" in label_smoothing warning warning

* add tests

* doc

* doc fixes

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-14 17:37:49 +02:00
b0372e66a5 Improve DPO/loss doc (#1929)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-14 16:52:26 +02:00
c1b272f4a6 minor BCO fixes (#1923)
* checkpointing BCO UDM classifier

* kto_config remove unused parameters

* BCO fix loading

* kto_config remove unused parameters

* kto_config remove unused parameters

---------

Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-08-14 15:27:13 +02:00
f05f63c1ea PartialState().local_main_process_first() when map in examples (#1926)
* `PartialState().local_main_process_first()` when map in examples

* allow load from cache

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-14 12:01:03 +02:00
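The pattern adopted in #1926, as far as this log describes it, is the `accelerate` context manager that lets the local main process run (and cache) the `map` first; a small sketch with a made-up preprocessing function:

```python
from accelerate import PartialState
from datasets import load_dataset

dataset = load_dataset("imdb", split="train")

# Only the local main process runs the map eagerly; the other ranks wait and
# then load the result from the datasets cache instead of recomputing it.
with PartialState().local_main_process_first():
    dataset = dataset.map(lambda ex: {"n_chars": len(ex["text"])}, num_proc=4)
```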
54f806b6ff Standardize dataset_num_proc usage (#1925)
* uniform dataset_num_proc

* num_proc in shuffle

* Update examples/datasets/anthropic_hh.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/ppo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/ppo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-08-13 15:10:39 +02:00
a9a756553f Add explicit library name for TRL repos (#1922) 2024-08-13 11:36:01 +02:00
96bb3deb32 fix orpo trainer loss device (#1919) 2024-08-12 15:55:23 +02:00
dbea3da917 torch.cuda.amp.autocast() -> torch.amp.autocast("cuda") (#1921)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-12 14:43:38 +02:00
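What the rename in #1921 amounts to, sketched with a toy module; the CPU fallback is added here purely for illustration:

```python
import torch

model = torch.nn.Linear(8, 2)
x = torch.randn(4, 8)

# Deprecated spelling:   with torch.cuda.amp.autocast(): ...
# Replacement spelling:  torch.amp.autocast(device_type)
device_type = "cuda" if torch.cuda.is_available() else "cpu"
with torch.amp.autocast(device_type):
    y = model.to(device_type)(x.to(device_type))
```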
150a93101b lr_scheduler.step() call after optim.step() (#1918)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-12 14:21:50 +02:00
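The ordering fixed by #1918, shown on a toy training loop (model and data are placeholders): since PyTorch 1.1, `lr_scheduler.step()` is expected to follow `optimizer.step()`.

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

for _ in range(3):
    loss = model(torch.randn(2, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()   # update the parameters first...
    scheduler.step()   # ...then advance the learning-rate schedule
```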
cbcaa46cd3 Various args and test fix (#1909)
* report to none

* simplify AlignPropTrainerTester

* rm unused marker

* Don't share setup in dpo trainer

* style

* don't share setup in test rich

* fix setup and classmethod

* fix args for sft

* test_trainer_args

* various arg fix

* report to none and vsdt simplifi

* drop generate_during_eval

* fix run_name

* style

* drop setUpClass

* style

* new ref values for ppo trainer tester

* update ref val

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-09 10:07:58 +02:00
e3fe28ee1a Fix AlignPropTrainer import (#1908) 2024-08-07 11:33:11 +02:00
fb0b9edc24 Fix GPT2 sentiment notebook reward (#1738)
* Fix reward change

* clean up notebook

* fix eval metric

* regenerate output with correct model

* swap wrong operation order

* Update gpt2-sentiment.ipynb

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-08-06 22:19:05 +02:00
fc76fe8d11 [Online-DPO] num_generation_per_prompt is fixed (#1898)
* num_generation_per_prompt is fixed

* remove unused no_grads

* removed bin

* fix scores

* fix scores

* formatting

* undo
2024-08-06 18:21:35 +02:00
b60ce797d8 Support Rank Stabilized LoRA in the ModelConfig/LoraConfig (#1877)
* feat: support RS-LoRA in the ModelConfig

* build: bump minimum peft version to support rslora

* test: add test for get_peft_config

* test: make test python 3.8 friendly

* rm unused marker

* minor changes

* simplify, clarify doc

* update deps (peft in test)

* re-ordering

* fix setup

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-06 18:02:59 +02:00
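For #1877, the user-facing effect is being able to request rank-stabilized LoRA (scaling by `alpha / sqrt(r)` instead of `alpha / r`) through the usual peft config; a sketch, assuming the flag is simply forwarded from TRL's `ModelConfig`:

```python
from peft import LoraConfig

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    use_rslora=True,        # rank-stabilized scaling: alpha / sqrt(r)
    task_type="CAUSAL_LM",
)
```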
6faf4c0d81 [RPO] use loss from v3 of paper (#1904)
* RPO loss from v3

* Update trl/trainer/dpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* fix docs

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-08-06 16:28:46 +02:00
29bd0046a9 fix process orpo example (#1903)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-06 12:57:11 +02:00
4867c2a3db Support IterableDataset for SFTTrainer (#1899)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-05 18:04:17 +02:00
332062372d Drop setUpClass in reward tester (#1895)
* drop setUp class in reward tester

* report to none

* style

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-05 16:01:43 +02:00
b580e45c94 [WIP] Drop save/load test on windows (#1897)
* just test modelling

* Trigger CI

* always trigger

* only test_from_save_trl

* parametrize

* just one model

* file

* rm ref model

* assert exists

* style

* Update Makefile

* Update tests.yml

* Update Makefile

* Update test_modeling_value_head.py

* Update test_modeling_value_head.py

* skip windows

* skip test_from_save_transformers

* also skip test_from_save_trl

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-05 16:01:06 +02:00
2004d62c5c fix serialization of RunningMoments on multiple GPUs (#1892)
Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-08-04 10:57:28 +02:00
ac7c8b1284 evaluation_strategy -> eval_strategy (#1894)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-02 16:01:35 +02:00
df12913602 Fix SFT for VLM example (#1865)
* fix vsft example commands

* fix use_cache and get tokenizer from processor

* rm unused AutoTokenizer

* Squashed commit of the following:

commit 8bd2ab82f4cedc8b3459126aa145c63180078392
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Sun Jul 28 14:06:19 2024 +0200

    Refactor judges (#1856)

    * BaseJudge -> BasePairwiseJudge

    * hf judge asyncio

    * refactor judges

    * doc

    * doc

    * doc

    * member judge


    * :inherited-members:

    * :inherited-members:

    * doc

    * give up

    * judge tldr with judge class

    * fix rank in multithread

    * format

    * improve doc

    * update doc

    * typo doc

    * doc online dpo

    * Update judge_tldr.py

    ---------

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 82b07d6b0169bb8150f2fa4ee0a58b678d597163
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Fri Jul 26 11:43:48 2024 +0200

    Llama in modelling value head tests (#1878)

commit 72bf6c21beedd95b1deb1ff95bd4d1bad5380503
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Fri Jul 26 11:33:07 2024 +0200

    Skip BigBird save and load test until next transformers version (#1874)

commit 74e54b5946b3e46c9fef516b6f5403943c7c4096
Author: Edward Beeching <edbeeching@users.noreply.github.com>
Date:   Fri Jul 26 09:36:25 2024 +0200

    fix online dpo example (#1879)

commit 393097356c3494a1310cd59b0205358723468443
Author: Rishav Dash <57321948+Rishav-hub@users.noreply.github.com>
Date:   Thu Jul 25 14:17:37 2024 +0530

    Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM (#1861)

    * Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM

    Added ```dataset_text_field``` in the SFTConfig while training

    * Update docs/source/sft_trainer.mdx

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit db8e09e3463837d6f80d593f2806c0d83d97c787
Author: Rishav Dash <57321948+Rishav-hub@users.noreply.github.com>
Date:   Thu Jul 25 14:06:57 2024 +0530

    Import missing ```setup_chat_format``` (#1862)

commit 1dae55f90f6e929500df4fc4ee5bbc0146e35574
Author: elie <97572401+eliebak@users.noreply.github.com>
Date:   Thu Jul 25 10:27:34 2024 +0200

    add fsdp_qlora config and bnb_4bit_quant_storage (#1863)

commit c8cef79e6c895c9950ad7af61897f3a89372c56d
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Jul 24 21:06:57 2024 +0200

    arXiv to HF Papers (#1870)

commit 7dcf437a1997cb1b252e8ea0b8ad7dca13261d7e
Author: Kashif Rasul <kashif.rasul@gmail.com>
Date:   Wed Jul 24 12:27:50 2024 +0200

    [online-DPO] online dpo cleanups (#1864)

    * online dpo cleanups

    * remove unused self.policy

    * add OnlineDPOTrainer and config to __init__.py

    * import from trainer

    * online dpo test

    * rename policy to model and ref_policy to ref_model

    * renamed internally

    * formatting

commit 4e85bd75a9dfca0074eef3a90130054c283eed39
Author: Costa Huang <costa.huang@outlook.com>
Date:   Thu Jul 18 14:35:31 2024 -0400

    Online DPO and Online trainer refactor (#1809)

    * online dpo trainer based on rloo trainer

    * push changes

    * refactor

    * use `batch_generation` method

    * precommit

    * remove breakpoint()

    * quick refactor

    * push the current changes

    * quick change

    * refactor

    * use the config name as the experiment name

    * fix logging

    * update online DPO docs

    * push docs

    * increment global step so tensorboard works again.

    * precommit

    * remove unused common online trainer

    * add online DPO docs

    * quick refactor

    * push changes

    * Update docs/source/online_dpo_trainer.md

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

    ---------

    Co-authored-by: Michael Noukhovitch <mnoukhov@gmail.com>
    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

commit c9d56366ede5990d690f3b7a3f249c434f3633d6
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Thu Jul 18 18:28:49 2024 +0200

    rm token (#1852)

* add section in doc

* Squashed commit of the following:

commit 890232fa2861c40d46adeaf975a4209eb04fe841
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Tue Jul 30 14:29:47 2024 +0200

    update example overview (#1883)

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 9929370dee9975f1c6d80b32198ea3e7fd0dcc06
Author: Clara Pohland <54847419+claralp@users.noreply.github.com>
Date:   Sun Jul 28 21:10:08 2024 +0200

    Move BCO to separate BCOTrainer with fixes (#1869)

    * kto_trainer: skip KL data for BCO

    * kto_trainer: BCO allow no positives or no negatives in batch

    * kto_trainer: make RunningMoments object serializable

    * add BCOTrainer

    * fix BCO UDM for non-interleaved data

    * kto_trainer: remove unused UDM part

    * bco_trainer: add tests and docs, minor fixes

    * code style fixes

    * Update docs/source/bco_trainer.mdx

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * fix BCO UDM for bfloat16

    * Update trl/trainer/bco_config.py

    * Update trl/trainer/bco_config.py

    Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

    * Update trl/trainer/utils.py

    Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

    * Update trl/trainer/bco_trainer.py

    Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

    * Update trl/trainer/bco_config.py

    * Update _toctree.yml

    * Update trl/trainer/bco_config.py

    * Update trl/trainer/bco_trainer.py

    * RunningMoments, fix multi GPU serialization

    * fix tests

    ---------

    Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
    Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

commit 6171cddee5165869af8b40b526476680cebe47ef
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Sun Jul 28 15:51:38 2024 +0200

    Re-add BigBird Pegasus save/load test (#1882)

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 33d2151f4fa37728fea9448420301a1380fee745
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Sun Jul 28 15:07:10 2024 +0200

    Re-add BigBird Pegasus save/load test (#1876)

    * skip bigbird in ci

    * readd big bird test

    * pytest parametrize

    * dont check the version

    * rm model name

    * re add big bird

    * Merge branch 'main' into readd-bigbird-save-load-test

    ---------

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 8bd2ab82f4cedc8b3459126aa145c63180078392
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Sun Jul 28 14:06:19 2024 +0200

    Refactor judges (#1856)

    * BaseJudge -> BasePairwiseJudge

    * hf judge asyncio

    * refactor judges

    * doc

    * doc

    * doc

    * member judge

    * :inherited-members:

    * :inherited-members:

    * doc

    * give up

    * judge tldr with judge class

    * fix rank in multithread

    * format

    * improve doc

    * update doc

    * typo doc

    * doc online dpo

    * Update judge_tldr.py

    ---------

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 82b07d6b0169bb8150f2fa4ee0a58b678d597163
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Fri Jul 26 11:43:48 2024 +0200

    Llama in modelling value head tests (#1878)

commit 72bf6c21beedd95b1deb1ff95bd4d1bad5380503
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Fri Jul 26 11:33:07 2024 +0200

    Skip BigBird save and load test until next transformers version (#1874)

commit 74e54b5946b3e46c9fef516b6f5403943c7c4096
Author: Edward Beeching <edbeeching@users.noreply.github.com>
Date:   Fri Jul 26 09:36:25 2024 +0200

    fix online dpo example (#1879)

commit 393097356c3494a1310cd59b0205358723468443
Author: Rishav Dash <57321948+Rishav-hub@users.noreply.github.com>
Date:   Thu Jul 25 14:17:37 2024 +0530

    Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM (#1861)

    * Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM

    Added ```dataset_text_field``` in the SFTConfig while training

    * Update docs/source/sft_trainer.mdx

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit db8e09e3463837d6f80d593f2806c0d83d97c787
Author: Rishav Dash <57321948+Rishav-hub@users.noreply.github.com>
Date:   Thu Jul 25 14:06:57 2024 +0530

    Import missing ```setup_chat_format``` (#1862)

commit 1dae55f90f6e929500df4fc4ee5bbc0146e35574
Author: elie <97572401+eliebak@users.noreply.github.com>
Date:   Thu Jul 25 10:27:34 2024 +0200

    add fsdp_qlora config and bnb_4bit_quant_storage (#1863)

commit c8cef79e6c895c9950ad7af61897f3a89372c56d
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Jul 24 21:06:57 2024 +0200

    arXiv to HF Papers (#1870)

commit 7dcf437a1997cb1b252e8ea0b8ad7dca13261d7e
Author: Kashif Rasul <kashif.rasul@gmail.com>
Date:   Wed Jul 24 12:27:50 2024 +0200

    [online-DPO] online dpo cleanups (#1864)

    * online dpo cleanups

    * remove unused self.policy

    * add OnlineDPOTrainer and config to __init__.py

    * import from trainer

    * online dpo test

    * rename policy to model and ref_policy to ref_model

    * renamed internally

    * formatting

commit 4e85bd75a9dfca0074eef3a90130054c283eed39
Author: Costa Huang <costa.huang@outlook.com>
Date:   Thu Jul 18 14:35:31 2024 -0400

    Online DPO and Online trainer refactor (#1809)

    * online dpo trainer based on rloo trainer

    * push changes

    * refactor

    * use `batch_generation` method

    * precommit

    * remove breakpoint()

    * quick refactor

    * push the current changes

    * quick change

    * refactor

    * use the config name as the experiment name

    * fix logging

    * update online DPO docs

    * push docs

    * increment global step so tensorboard works again.

    * precommit

    * remove unused common online trainer

    * add online DPO docs

    * quick refactor

    * push changes

    * Update docs/source/online_dpo_trainer.md

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

    ---------

    Co-authored-by: Michael Noukhovitch <mnoukhov@gmail.com>
    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

commit c9d56366ede5990d690f3b7a3f249c434f3633d6
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Thu Jul 18 18:28:49 2024 +0200

    rm token (#1852)

* simplify script

* doc

* use training args

* args instead of training args

* fix doc

* drop eval

* rm eval section

* re-add bigbird

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-08-02 10:31:51 +02:00
ddf4c8dc3e fix dpo_trainer bug for LLMs without bos_token in config (#1885)
* fix dpo_trainer bug for LLMs without bos_token in config

* fix adding bos_token_id bug in dpo,orpo,cpo trainers

* formatting for fixing bos_token adding bug

* Update trl/trainer/utils.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-07-31 12:42:06 +02:00
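The gist of #1885, as described above, is to stop assuming every tokenizer defines a BOS token; a standalone sketch of that guard (the helper name is made up):

```python
def maybe_prepend_bos(input_ids, bos_token_id):
    """Prepend BOS only if the tokenizer defines one and it is not already there."""
    if bos_token_id is None:
        return list(input_ids)
    if not input_ids or input_ids[0] != bos_token_id:
        return [bos_token_id] + list(input_ids)
    return list(input_ids)

print(maybe_prepend_bos([5, 6], bos_token_id=None))  # [5, 6]
print(maybe_prepend_bos([5, 6], bos_token_id=1))     # [1, 5, 6]
```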
890232fa28 update example overview (#1883)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-07-30 14:29:47 +02:00
9929370dee Move BCO to separate BCOTrainer with fixes (#1869)
* kto_trainer: skip KL data for BCO

* kto_trainer: BCO allow no positives or no negatives in batch

* kto_trainer: make RunningMoments object serializable

* add BCOTrainer

* fix BCO UDM for non-interleaved data

* kto_trainer: remove unused UDM part

* bco_trainer: add tests and docs, minor fixes

* code style fixes

* Update docs/source/bco_trainer.mdx

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* fix BCO UDM for bfloat16

* Update trl/trainer/bco_config.py

* Update trl/trainer/bco_config.py

Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

* Update trl/trainer/utils.py

Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

* Update trl/trainer/bco_trainer.py

Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

* Update trl/trainer/bco_config.py

* Update _toctree.yml

* Update trl/trainer/bco_config.py

* Update trl/trainer/bco_trainer.py

* RunningMoments, fix multi GPU serialization

* fix tests

---------

Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Seungjae Jung <seanexplode@gmail.com>
2024-07-28 21:10:08 +02:00
6171cddee5 Re-add BigBird Pegasus save/load test (#1882)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-07-28 15:51:38 +02:00
33d2151f4f Re-add BigBird Pegasus save/load test (#1876)
* skip bigbird in ci

* readd big bird test

* pytest parametrize

* dont check the version

* rm model name

* re add big bird

* Merge branch 'main' into readd-bigbird-save-load-test

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-07-28 15:07:10 +02:00
8bd2ab82f4 Refactor judges (#1856)
* BaseJudge -> BasePairwiseJudge

* hf judge asyncio

* refactor judges

* doc

* doc

* doc

* member judge

* :inherited-members:

* :inherited-members:

* doc

* give up

* judge tldr with judge class

* fix rank in multithread

* format

* improve doc

* update doc

* typo doc

* doc online dpo

* Update judge_tldr.py

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-07-28 14:06:19 +02:00
82b07d6b01 Llama in modelling value head tests (#1878) 2024-07-26 11:43:48 +02:00
72bf6c21be Skip BigBird save and load test until next transformers version (#1874) 2024-07-26 11:33:07 +02:00
74e54b5946 fix online dpo example (#1879) 2024-07-26 09:36:25 +02:00
393097356c Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM (#1861)
* Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM

Added ```dataset_text_field``` in the SFTConfig while training

* Update docs/source/sft_trainer.mdx

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-07-25 10:47:37 +02:00
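The one-line fix in #1861 boils down to passing `dataset_text_field` through `SFTConfig` when using `DataCollatorForCompletionOnlyLM`; a hedged sketch of the relevant config (output path and column name are placeholders):

```python
from trl import SFTConfig

training_args = SFTConfig(
    output_dir="out",
    dataset_text_field="text",  # column holding the raw text; previously missing
    packing=False,              # completion-only collation requires unpacked examples
)
```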
db8e09e346 Import missing ``setup_chat_format`` (#1862) 2024-07-25 10:36:57 +02:00
1dae55f90f add fsdp_qlora config and bnb_4bit_quant_storage (#1863) 2024-07-25 10:27:34 +02:00
c8cef79e6c arXiv to HF Papers (#1870) 2024-07-24 21:06:57 +02:00
7dcf437a19 [online-DPO] online dpo cleanups (#1864)
* online dpo cleanups

* remove unused self.policy

* add OnlineDPOTrainer and config to __init__.py

* import from trainer

* online dpo test

* rename policy to model and ref_policy to ref_model

* renamed internally

* formatting
2024-07-24 12:27:50 +02:00
4e85bd75a9 Online DPO and Online trainer refactor (#1809)
* online dpo trainer based on rloo trainer

* push changes

* refactor

* use `batch_generation` method

* precommit

* remove breakpoint()

* quick refactor

* push the current changes

* quick change

* refactor

* use the config name as the experiment name

* fix logging

* update online DPO docs

* push docs

* increment global step so tensorboard works again.

* precommit

* remove unused common online trainer

* add online DPO docs

* quick refactor

* push changes

* Update docs/source/online_dpo_trainer.md

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Michael Noukhovitch <mnoukhov@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-07-18 14:35:31 -04:00
c9d56366ed rm token (#1852) 2024-07-18 18:28:49 +02:00
4dce042a38 Add WinRateCallback and Judges (#1598)
* Add WinRateCallback

* Enable PairRM

* Refactor

* Streamline

* Add HF judge

* Add base judge

* Use better prompt

* Clean

* Add max tokens

* Use logging

* Add batched inference

* Squashed commit of the following:

commit 9e9dc96e676a3601882b5cf11842bd22267fd2c5
Author: Maxim Kopecki <kopecki.maxim@gmail.com>
Date:   Wed Jul 10 19:11:13 2024 +0200

    Added missing token kwarg in Peft model loading (#1825)

commit 7ddef5c1582f14f32b6dd692f8e4b904fd478038
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Jul 10 18:26:11 2024 +0200

    Make use of `trust_remote_code` consistent (#1806)

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit a9cddf8c55a0b2af101a3d18bd92f263f4ae4500
Author: Adnan Khan <AdnaneKhan@users.noreply.github.com>
Date:   Wed Jul 10 11:25:07 2024 -0400

    Delete unused benchmark.yml workflow. (#1822)

commit 2860ce5091e689bab167454453e9ddbe2337de3d
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Tue Jul 9 09:22:52 2024 +0200

    DPO Llava 1.5 and PaliGemma support (#1797)

    * llava support dpo

    * add_special_tokens=False only when possible

    * format

    * pali gemma

    * refactor size

    * remove image resize

    ---------

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 30e33bd92da1f5569493e16da8971247cc376927
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Tue Jul 9 05:37:12 2024 +0200

    upgrade gh actions (#1818)

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit d5a0d2d345ec26646ceaa06adfe6133aad18702a
Author: Costa Huang <costa.huang@outlook.com>
Date:   Mon Jul 8 11:12:41 2024 -0400

    Set dev version (#1817)

commit 314e8eb367cbfaf74c2e9717085346360e779508
Author: Puneet Singh Bhooi <puneetb@iiitd.ac.in>
Date:   Mon Jul 8 19:11:36 2024 +0530

    fix broken url in `docs\source\index.mdx` (#1813)

commit e10792032be644a65dcbcf2ebe9ec947497d4d46
Author: Costa Huang <costa.huang@outlook.com>
Date:   Mon Jul 8 09:38:09 2024 -0400

    0.9.6 release (#1816)

commit 78045dedc8678af04f4e35ffe63f37be196a435b
Author: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date:   Mon Jul 8 01:59:26 2024 +0200

    Fix `TRL_USE_RICH` environment variable handling (#1808)

    * Add `strtobool` custom implementation from `distutils`

    * Fix `TRL_USE_RICH` handling via `strtobool`

    * Run `make precommit`

commit 747612f9d3063de56b6524e5feb0c9feab21d4c4
Author: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date:   Fri Jul 5 16:28:59 2024 +0200

    Fix `torch_dtype` handling in `{DPO,SFT}Trainer` when provided via CLI (#1807)

    * Fix `torch_dtype` handling through CLI

    The `torch_dtype` is not properly handled when provided via the TRL CLI,
    since it is initially provided as a string but is then cast to
    `torch.dtype` before being passed to the `{DPO,SFT}Trainer`, which means
    that those trainers should also handle the scenario where `torch_dtype`
    is already a `torch.dtype`.

    * Add `torch_dtype` tests in `test_{dpo,sft}_trainer.py`

    * Forward contribution credits

    * Run `make precommit`

    ---------

    Co-authored-by: Tash Srivastava <yash-srivastava19@users.noreply.github.com>

commit 9e3a35bd3d85ee506d180120f01bde2229b60265
Author: Michael <mnoukhov@gmail.com>
Date:   Fri Jul 5 07:29:48 2024 -0400

    Remove extra print in reward_trainer.py (#1799)

    `print_rich_table` is called twice and the first call doesn't restrict to `num_print_samples`. Remove the first, extra call

commit 4402b36dcf79a0921a858c77375cfbb285d603c7
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Thu Jul 4 14:29:25 2024 +0200

    clean examples (#1791)

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 78f8228874d5cf9c0e68952533cb377202e1eb22
Author: Noah Tye <hi@noahtye.com>
Date:   Wed Jul 3 11:10:50 2024 -0700

    Bugfix: Preserve token fields when converting TrainingArguments to SFTConfig (#1794)

    * Preserve token fields when converting TrainingArguments to SFTConfig

    TrainingArguments.to_dict() redacts token fields, so we have to
    individually copy them over when converting to SFTConfig to avoid
    breaking push_to_hub functionality.

    Also adds a test.

    * run precommit

    * one-line args_as_dict definition per suggestion from kashif

    * generalize token copying to match TrainingArguments behavior

    * unwrap |= on dict, to support python 3.8

    * use .update instead of |= or for-loop

commit b6af2edc93b275afcee22a3eb71f9a5702ff9fd8
Author: Kashif Rasul <kashif.rasul@gmail.com>
Date:   Wed Jul 3 08:29:16 2024 +0200

    add model_init_kwargs to training_args (#1787)

commit cd85b14fbbaf7e4d9b01ef8ec19655666af20047
Author: Tommaso Buonocore <buonocore.tms@gmail.com>
Date:   Sat Jun 29 15:35:48 2024 +0200

    Fixed typo in SFT trainer docs (#1788)

    'STFConfig' instead of 'SFTConfig' appears multiple times in the doc, causing error when running the code snippets.

commit a57544f47a2fbc4940b4d49dde32f54406398c91
Author: Kashif Rasul <kashif.rasul@gmail.com>
Date:   Thu Jun 27 15:47:58 2024 +0200

    fix docs and examples (#1780)

commit b68ff96f0c74368961e194081e122959cd1f4d4d
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Jun 26 16:26:37 2024 +0200

    Visual DPO (#1647)

    * Remove extra whitespaces

    * idefics

    * vdpo

    * sft idefics

    * pad with test

    * use prompt instead of tokenizer

    * rm name main

    * support vlm in tokenize row

    * temp fix for regex in lora_target_module

    * format

    * vdpo

    * tmp float16 hard code

    * concatenated_forward support for vision

    * style and new command line

    * all-linear

    * format

    * delete old examples

    * get image

    * upcast

    * new test

    * modified test

    * new strat for tokenizer

    * rm token transfer

    * integrate vision in dpo example

    * format

    * add FDivergenceType back

    * precommit

    * pillow test dep

    * optional prompt

    * `evaluation_strategy` to `eval_strategy`

    * revert vsft change (oos)

    * update test

    * test

    * comment and support more in process

    * update process

    * update doc for vdpo

    * caution about limited support

    * Update docs/source/dpo_trainer.mdx

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * revert DPO example changes

    * cleaner way to check if a model is vision

    * comment

    * update vdpo example

    * rename

    ---------

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit c8c01cc05569f5ffea6726b2111f799a63e03aaa
Author: Mubin Manasia <48038715+Mubin17@users.noreply.github.com>
Date:   Wed Jun 26 03:23:36 2024 -0600

    Fix Documentation Overflow Issues for Long URLs in SFTConfig (#1774)

    * Update sft_config.py

    * Update sft_config.py

commit 3479606c8c6dbb5da96e4990b491e63a48fc7483
Author: Costa Huang <costa.huang@outlook.com>
Date:   Wed Jun 26 03:18:22 2024 -0400

    Remove the leading space in the tldr preference dataset (#1773)

commit 7965b7834052ab3d60a1cc5de382e2f56b3772e7
Author: Haozhe Ji <jihaozhe@gmail.com>
Date:   Tue Jun 25 22:47:32 2024 +0800

    add Efficient Exact Optimization (EXO) (#1735)

    * add exo

    * fix a detail

    * Update trl/trainer/dpo_trainer.py

    * Update trl/trainer/dpo_trainer.py

    * Update trl/trainer/dpo_trainer.py

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit 56bd1bba26ac52aad976c1a1a0b3d9e1137b18c7
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Tue Jun 25 16:14:26 2024 +0200

    `evaluation_strategy` to `eval_strategy` (#1771)

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 94d53e6617edc6434a38b2ac51c21e5da3329cda
Author: Clara Pohland <54847419+claralp@users.noreply.github.com>
Date:   Mon Jun 24 21:27:00 2024 +0200

    MoE Models: option to add load balancing loss (#1765)

    * KTO: add aux loss

    * use router_aux_loss_coef in KtoTrainer when aux_loss enabled

    * align optional aux_loss in DPO, KTO, CPO, ORPO

    * precommit changes

    * fix KL forward kwargs

    * add aux_loss doc entry

    * apply docs suggestions

    ---------

    Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>

commit b5be100ae0b37d743cd49435297f917eb54a0574
Author: Mihir Prabhudesai <mihirp1998.mp@gmail.com>
Date:   Mon Jun 24 12:05:44 2024 -0400

    Added Reward Backpropagation Support (#1585)

    * added alignprop template

    * added alignprop support

    * Update alignprop_trainer.mdx

    * Update alignprop_trainer.mdx

    * added better why statement

    * fixed inference code

    * changed self to pipeline

    * removed aesthetic classifier

    * added aesthetic to auxiliary models

    * added unseen prompt logging

    * removed unseen prompt log

    * fixed minor

    * remove not needed import in trl/__init__.py

    Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

    * fixed styling

    * updated _toctree

    ---------

    Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

commit 6e1652bc5e8ff6d348c7f06048f4102a050f1544
Author: Haoran Xu <45837851+fe1ixxu@users.noreply.github.com>
Date:   Sun Jun 23 09:54:30 2024 -0700

    Add CPO-SimPO method (#1760)

    * enable cpo-simpo

    * highlight SimPO and CPO-SimPO

    * add test for cpo_alpha

    * formatting

    * Update docs/source/cpo_trainer.mdx

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit 65374c6a711709157ea59297dce43dfb458d1c78
Author: Costa Huang <costa.huang@outlook.com>
Date:   Fri Jun 21 11:20:54 2024 -0400

    New sentiment and descriptiveness dataset (#1757)

    * push changes

    * handle edge cases where the chosen and the rejected are the same

commit 99560911123f739226b77813f27d5c90ed7f9ba2
Author: Juyoung Suk <scottsuk0306@gmail.com>
Date:   Fri Jun 21 18:01:08 2024 +0900

    Add dataset_text_field in examples/scripts/sft.py (#1758)

commit 34d273f227b30507c6d94ff1f93b6939794f38a3
Author: Costa Huang <costa.huang@outlook.com>
Date:   Thu Jun 20 13:16:43 2024 -0400

    Support num_train_epochs (#1743)

    * add a test case for num_train_epochs

    * fix ci

    * quick change

    * disable push to hub

    * debug windows ci

    * try another fix

    * skip subprocess tests on windows

commit 3bf94492a8dc84ac192f7c5206553e1460f53aa4
Author: Mert Sayar <mert.sayar@gmail.com>
Date:   Thu Jun 20 18:22:20 2024 +0300

    Fix masking of response tokens (#1718)

    The current handling of `response_masks` inside the `batch_forward_pass`
    function does not take padding into consideration, which results in a
    shape mismatch during masking. Since the response mask is a mask tensor
    over response tokens, the response tokens should not be concatenated with
    `torch.zeros(query_length)`, and the masking operation should be done
    without slicing.

    Remove the concatenation of the response mask, and remove the slicing
    from the response mask, since the response mask already has length
    `end - start + 1`, which is equal to the length of `masks[j, start:end]`.

commit ba6abee37f0f0463f6d891d63d0c2242039fc8ec
Author: idanshen <49375140+idanshen@users.noreply.github.com>
Date:   Thu Jun 20 09:14:16 2024 -0400

    Support for returning past_key_values from the model (#1742)

    * add support for returning past_key_values from the model

    * change order of  keys

commit a57e75967c2b787f42f4e402ed7ca23cd9bad9a9
Author: 1485840691 <110707330+1485840691@users.noreply.github.com>
Date:   Wed Jun 19 18:02:51 2024 +0800

    Integrate f-divergence to DPO (Follow up) (#1610)

    * Step 1: update ppo_trainer and hello_world example

    * Step 2: Refine comments and add parameter type

    * Step 2: Add missing parameter comments

    * Step 1: Organize ptx loss into a function and add ptx_loss to train_stats

    * Step 1 updates: add comment to ptx_loss function, fix a bug and add warning message

    * Step 2: 1) Add ppo_ptx training example as ppo; 2) separate pretrain data fetch and iterate

    * Step 2: Remove loss from columns_to_log in ppo_ptx example

    * Remove dataset revision when loading the imdb dataset

    * Run pre-commit and fix format issues

    * Initial draft of f-divergence fn

    * Update f-divergence to avoid overflow

    * fix test errors and comments

    * Add Unit tests for dpo loss with alpha and js div f

    * Adjust format

    * Fix test error

    * Reverse this update

    * Add test cases

    * Reverse un-needed updates

    * Update code style

    * Try to fix code fmt error

    * remove extra end line

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit ae23d40f3b4d91d60a6153825ecf0319449d34b1
Author: Shihyueh Hsu <66808901+AIR-hl@users.noreply.github.com>
Date:   Tue Jun 18 22:07:24 2024 +0800

    change the `process` function in the example of DPO (#1753)

    * change the `process` function in the example of DPO

    * fix

commit 83b367b11a308b488ff9ddcf19cf4cfd6a7db642
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Tue Jun 18 11:31:17 2024 +0200

    CI / `KTOTrainer`: Remove old tests (#1750)

    * remove old tests

    * remove datasets

    * Update test_dpo_trainer.py

    * Update test_dpo_trainer.py

commit d1ed730ab8281b1b0c78d7d61bc0f6603a9ce958
Author: Michael <mnoukhov@gmail.com>
Date:   Mon Jun 17 10:50:21 2024 -0400

    prepare deepspeed to accommodate fp16 and bf16 (#1728)

    * prepare deepspeed to accommodate fp16 and bf16

    * precommit

commit 8f8e95e25d10c433cc1f2f8c7dcfed218bb13ac7
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Mon Jun 17 16:49:00 2024 +0200

    CPO / DPO: Fix red CI (#1749)

    * fix red CI

    * precommit

commit 4e23d958f20fd4fdd795cb06c2cdb7ebea704855
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Mon Jun 17 16:41:36 2024 +0200

    fix red CI

commit 50c46205b6fe741f11959adf7ec9cc0386f406bc
Author: Kawin <kawin.ethayarajh@gmail.com>
Date:   Mon Jun 17 07:14:44 2024 -0700

    small KTO fixes (#1734)

    * add warning for imbalanced data

    * update documentation

    * update script commands to be same as in dpo

    * use batch_size KL examples and batch_size target examples to calculate batch_size losses

    * fix deepspeed issue

    * speed up forward with no_grad for KL

    * add some removed metrics

    * Update trl/trainer/kto_trainer.py

    * Update trl/trainer/kto_trainer.py

    * Update trl/trainer/kto_trainer.py

    add reference to paper

    Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * add more detailed comments

    * convert assert to ValueError

    * Update kto_trainer.py

    * precommit formatting

    * remove nans in metrics by gathering across machines

    * fix formatting

    * fix choice of mismatched examples for KL term

    * describe weights

    * fix hanging issue in distributed training

    * linting

    * move metrics to cpu

    * Update trl/trainer/kto_trainer.py

    Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

    * Update trl/trainer/kto_trainer.py

    * Update trl/trainer/kto_trainer.py

    * remove kto_pair

    * speed up data processing

    * move bco code inside

    * raise error for kto_pair argument

    * fix formatting

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
    Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
    Co-authored-by: Winnie Xu <winnie.xu97@gmail.com>

commit 6105d03f92e7069ffaa565d05418dec371569e6a
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Mon Jun 17 16:01:06 2024 +0200

    `TrlParser`: Add ignore extra args option (#1748)

    * add ignore extra args option

    * Update trl/commands/cli_utils.py

commit e247bbd7d5f57f8012ca71cfef6ad6a589874c34
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Mon Jun 17 15:16:07 2024 +0200

    CI / core: Pin `numpy` to `!=2.0.0` for CI and to users (#1747)

    * Update setup.py

    * Update setup.py

    * Update setup.py

    * Update test_best_of_n_sampler.py

    dummy commit

    * pin numpy

    * Update tests/test_best_of_n_sampler.py

    * Update setup.py

commit 3d044961960a2ab1ec1f51cfe62c6bf6b9a94807
Author: Michael <mnoukhov@gmail.com>
Date:   Mon Jun 17 08:43:33 2024 -0400

    better trl parser with yaml config (#1739)

    * working trl parser with config

    correctly overrides yaml config with command line arguments
    adds return_remaining_strings
    when return_remaining_strings is False, raises error if yaml contains
    extra args that are not in the dataclasses
    simpler and cleaner than previous yaml parsing and merging
    addresses #1733

    * lowercase trlparser

commit 2d244f8acb204cb2ddb83a4ef017ca4b1f2d366a
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Mon Jun 17 11:56:13 2024 +0200

    Workflow: Notify tests results on slack channel (#1744)

    * Update tests-main.yml

    * Update docker-build.yml

commit f5168fdbaf9cbf6a3f1bdc64dc44b9db3a9ae333
Author: Igor Melnyk <igoraries@gmail.com>
Date:   Wed Jun 12 05:54:54 2024 -0400

    adds AOT (#1701)

    * adds AOT

    * Applied format changes

    * added docs and tests

    ---------

    Co-authored-by: Igor Melnyk <igor.melnyk@ibm.com>

commit 79686e1ac701b1f5e3709a65efa8f13363bcde06
Author: jetlime <paul.houssel@yahoo.de>
Date:   Wed Jun 12 00:35:31 2024 +1000

    ktotrainer: Refuse datasets which contain only one class of labels (#1724)

    * ktotrainer: refuse dataset which contain only one class of labels

    * ktotrainer: document new dataset constraint

commit 34ebc4ccaf376c862a081ff4bb0b7e502b17b2fb
Author: Luc Georges <McPatate@users.noreply.github.com>
Date:   Mon Jun 10 11:17:54 2024 +0200

    feat(ci): add trufflehog secrets detection (#1721)

    * feat(ci): add trufflehog secrets detection

    * fix(ci): remove unnecessary permissions

commit 1d84e2b888ea0f3c1ce9c5c175f7f680d85273a8
Author: Michael <mnoukhov@gmail.com>
Date:   Fri Jun 7 11:42:08 2024 +0200

    Fix default padding_value in dpo_config.py (#1692)

    The dpo_config default padding value should be None, not 0; otherwise it overrides the padding value of any tokenizer with 0 by default.

commit 2f71b8b1e2e54184cc278f267cca1bda051f68ea
Author: Michael <mnoukhov@gmail.com>
Date:   Fri Jun 7 10:37:27 2024 +0200

    fix yaml parser for derived config classes (#1713)

    fixes #1712
    reformatted cli_utils with ruff

commit 5bcb8ad0d6eaee1b1d2f993380100c37c4421fd0
Author: Kashif Rasul <kashif.rasul@gmail.com>
Date:   Fri Jun 7 08:48:17 2024 +0100

    RDPO fix nll loss (#1705)

commit b8b972fde183ec036885738e1439cd99877c2ad5
Author: Haoran Xu <45837851+fe1ixxu@users.noreply.github.com>
Date:   Thu Jun 6 14:06:47 2024 -0700

    Add a variant of CPO, SimPO (#1703)

    * add a variant of cpo: simpo

    * correct cpo-simpo loss

    * avoid 0 int error in logging

    * add simpo description

    * Update trl/trainer/cpo_trainer.py

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * fix formatting

    * add test for simpo

    * Update docs/source/cpo_trainer.mdx

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * add a docstring for simpogamma

    * move simpo description to the above docstring

    * change simpo description in the doc

    * formatting

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit 3eb9ccb104e2c46360adb937f3f25871c167eb90
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Thu Jun 6 19:33:20 2024 +0200

    set dev version (#1710)

    * Update setup.py

    * Update __init__.py

commit 974b0d380f12c357b70265c5f2dd2c8cb39a6a3e
Author: Costa Huang <costa.huang@outlook.com>
Date:   Thu Jun 6 10:13:00 2024 -0400

    0.9.4 release (#1708)

commit 39a7d1c121d26224fd7455d3d2038e0d20831c54
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Thu Jun 6 15:50:17 2024 +0200

    SFTTrainer: Fix backward Compatibility issue with `TrainingArguments` (#1707)

    * fix BC

    * fixup

commit 0bdc63839f1abe67c56befa63251425b1ffc1ace
Author: Guilherme Freire <guilhermebfreire@gmail.com>
Date:   Thu Jun 6 14:42:58 2024 +0100

    Fixed doc string and docs for the SFTConfig update (#1706)

commit 275d33b3ef4f7afd40f79cc53591659bacfa3499
Author: Costa Huang <costa.huang@outlook.com>
Date:   Wed Jun 5 14:34:59 2024 -0400

    0.9.3 release (#1699)

commit c0819ee99fdf673e9843ef91789b928ae9050623
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Wed Jun 5 17:29:03 2024 +0200

    Update sft_trainer.py (#1698)

commit a03e7cc4e443e30eea942ca66bfce19407784f32
Author: Costa Huang <costa.huang@outlook.com>
Date:   Wed Jun 5 11:00:19 2024 -0400

    Release 0.9.2 (#1697)

    * Release: 0.9.0

    * Release

commit a13cb8952c55cfa4fc696d900a1b2a81d329c82d
Author: Costa Huang <costa.huang@outlook.com>
Date:   Wed Jun 5 10:20:54 2024 -0400

    Quick fix on GPT4-eval (#1696)

    * quick fix

    * precommit

commit 84156f179f91f519e48185414391d040112f2d34
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Mon Jun 3 20:09:05 2024 +0200

    Fix typo in DPOTrainer's warnings (#1688)

commit 4eb0b905e28857341123d5329a6ca1b9d929734f
Author: Alex Brooks <alex.brooks@ibm.com>
Date:   Mon Jun 3 10:24:32 2024 -0600

    Skip packing validation (#1673)

    * Add test for skipping preproc if packing=True

    Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

    * Allow skipping of validation for packing=True

    Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

    * Use dummy dataset in no packing preproc test

    Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

    ---------

    Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

commit 6c203f9fef50c41d27fc4ed9965df7e458f02377
Author: Alexey Rozhkov <alexisrozhkov@gmail.com>
Date:   Mon Jun 3 10:16:22 2024 +0100

    Fix overriding optimize_device_cache with optimize_cuda_cache in PPOConfig (#1690)

    * Don't override optimize_device_cache when optimize_cuda_cache is not provided
    Raise an exception when both optimize_cuda_cache and optimize_device_cache are set

    * Minor fix

commit f18253bf2d747f68acc9cd89da95c85ebf59dbb9
Author: Kashif Rasul <kashif.rasul@gmail.com>
Date:   Mon Jun 3 09:43:02 2024 +0100

    initial RPO loss (#1686)

    * initial RPO loss

    * fix sign

    * clean up

commit 151a452d14c8ebccbaf8a033812ceb2dc77f634d
Author: Samuel <s.kiegeland@gmx.de>
Date:   Wed May 29 20:29:38 2024 +0200

    Fix max completion length (#1588)

commit 488b502d31c052801eacd9a047bf3db06623e9c2
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Wed May 29 20:19:26 2024 +0200

    fix (#1678)

commit 3c0a10b1aedbe533005dbfe18f2cc8057093f80b
Author: Wang, Yi <yi.a.wang@intel.com>
Date:   Mon May 27 20:52:20 2024 +0800

    fix dataset load error (#1670)

    Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

commit b031adfdb8708f1f295eab6c3f2cb910e8fe0c23
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Fri May 24 15:20:16 2024 +0200

    FIX / PPO: Fix `enable_input_require_grads` issues with PPO models (#1664)

    * Update modeling_base.py

    * Update ppo_config.py

    * Update ppo_trainer.py

    * style

commit e7cb597230bb0c630c67790881b0808f7b16cb05
Author: Costa Huang <costa.huang@outlook.com>
Date:   Thu May 23 11:37:16 2024 -0400

    Fix ppov2 test case (#1661)

    * Fix PPOv2 / RLOO refactor's stuff

    * update terminology to use stop token

commit bc8dfbf4e2169010b3094913a1fa4f888f750111
Author: Kashif Rasul <kashif.rasul@gmail.com>
Date:   Thu May 23 15:28:04 2024 +0200

    update eval_strategy (#1662)

commit e4ed7a3a5aa0f1e1b4f78317b3c7b25e5bf597f4
Author: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date:   Thu May 23 18:34:22 2024 +0530

    do not upcast adapters when using FSDP+QLoRA (#1654)

commit 9a7efbd05126fa6a1448a95f670e8d04cac90d62
Author: syrn1k <85796210+syrn1k@users.noreply.github.com>
Date:   Thu May 23 15:58:49 2024 +0300

    🤫 TR-DPO implementation (#1593)

    * 🤫 TR-DPO implementation baseline

    * fix comments

    * docs

    * fix linters

    * test added

    * move configs to DPOConfig

    * fix typo

    * add docs

    * fix import

    * use state.global_step

    * fix order of arguments

    * make sure plugins are not none

    * Update trl/trainer/utils.py

    Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

    * Update trl/trainer/utils.py

    Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

    * checking that reference model weights have changed

    * sync_target_model as staticmethod

    * set reference model

    ---------

    Co-authored-by: Nikita Surnachev <n.surnachev@tinkoff.ru>
    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
    Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

commit b344bcea2c0b30d58ab6ebb0380647f24056ac58
Author: Anush Kini <33577829+Abilityguy@users.noreply.github.com>
Date:   Thu May 23 18:27:25 2024 +0530

    [DPO] Add 'robust' loss_type (#1653)

    * Initial commit

    * pre-commit fix

    * Minor change to comments

    * Added some documentation on how to use Robust DPO

commit 35e12dc5959fa8a08edd72b34aadcb0acb284e51
Author: Nicolinho <Nicolinho@users.noreply.github.com>
Date:   Thu May 23 14:36:15 2024 +0200

    Fix inheritance order in PPOv2Config (#1659)

    * fix inheritance order in PPOv2Config

    * fix inheritance order in rloo_config

commit 1da6be18e0e21a11ee2a2121ae744c5e2e904409
Author: Ali Bakly <anbakly@gmail.com>
Date:   Thu May 23 14:10:29 2024 +0200

    docs: correct cDPO usage in DPOTrainer (#1655)

commit e249cd802fb81cff3c4ceb1427cb666a138221d3
Author: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date:   Thu May 23 14:10:05 2024 +0200

    add support for training collator (#1658)

commit a02513c3b7085adba5fd18727296f4f4affd3ffb
Author: Zach Mueller <muellerzr@gmail.com>
Date:   Thu May 23 06:48:00 2024 -0400

    Apply deprecated `evaluation_strategy` (#1559)

    * Deprecate

    * Update tests/test_dpo_trainer.py

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit 13454d2f4b243b7260fa4ec828297812c3f975fc
Author: Costa Huang <costa.huang@outlook.com>
Date:   Wed May 22 08:31:10 2024 -0400

    PPO / Reinforce Trainers (#1540)

    * Add ppov2 trainer

    * make eos trick optional, remove unused args

    * quick fix

    * precommit

    * update debugging script

    * fix out of bound `drop_last=True`; use built-in scheduler

    * Add PPO examples

    * push changes

    * quick change

    * quick change

    * various bug fixes

    * remove unnecessary grad accumulation setting

    * push new changes

    * fix DS3 model saving

    * update ppo.py

    * refactor

    * quick change

    * refactor

    * update ppo trainer

    * refactor

    * quick test

    * add ds2 /ds3 7 processes config

    * add vllm trainer

    * quick change

    * experiment with reward normalization

    * push changes

    * quick push

    * push changes

    * push various changes

    * refactor to use ModelConfig

    * quick change

    * refactor

    * refactor

    * Simplify DS logic

    * quick update

    * remove unnecessary files

    * precommit

    * deepspeed fix; handle edge case when eos_token_id = 0

    * add PPO tldr example

    * add TL;DR example

    * fix undefined var

    * utilize all samples in rloo

    * quick setting

    * remove the unnecessary `value_model`

    * use exact_div

    * allow saving the deepspeed model

    * refactor

    * remove dead code

    * Use some shared utilities

    * add some end-to-end test cases

    * add PPOv2 docs and RLOO docs / tests

    * update docs

    * quick push

    * fix ci

    * fix type annotation for ci

    * quick update

    * update trainer docs

commit 99f2c94b2200927a1dc156f16e012dca11f865e1
Author: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date:   Wed May 15 19:55:46 2024 +0530

    don't cast the trainable lora layers to half precision (#1644)

    * don't cast the trainable lora layers to half precision

    * quality

commit 6401d080c9f97e0610678b12d3d0056347675726
Author: Wing Lian <wing.lian@gmail.com>
Date:   Tue May 14 09:41:07 2024 -0400

    Pairwise Noise Contrastive Alignment (#1632)

    * add NCA paired preference loss

    * chore: lint

    * set more lenient tolerance for integration tests

    * Update tests/test_dpo_trainer.py

    * skip test

    * fix

    ---------

    Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
    Co-authored-by: younesbelkada <younesbelkada@gmail.com>

commit d632a5b289782c7384f5275426054e79acc0b744
Author: bartoszzuk <57541034+bartoszzuk@users.noreply.github.com>
Date:   Tue May 14 12:25:54 2024 +0200

    Fixed wrong logs prefixes in KTOTrainer (#1641)

    * Fixed wrong logs prefixes in KTOTrainer

    * Pre-commit formating

commit 5aeb752053876cce64f2164a178635db08d96158
Author: Tiezhen WANG <38108242+xianbaoqian@users.noreply.github.com>
Date:   Fri May 10 23:19:15 2024 +0800

    Update sft_llama2.py to work with the latest API (#1637)

    * Update sft_llama2.py to work with the latest API

    SFTTrainer now takes an SFTConfig argument

    * Update dpo_llama2.py

    * precommit

commit b8b89783ca1ab081d25651a9a13e9358cc8e1869
Author: Ilya Gusev <phoenixilya@gmail.com>
Date:   Fri May 10 15:43:13 2024 +0200

    [ORPO] Correct label mask for pad tokens (#1625)

    * [ORPO] Correct label mask for pad tokens

    Recent [fix](57aebe9c36) for calculating NLL loss for a whole sequence introduced a bug. When input_ids are copied to labels, pad tokens are not masked.

    This PR aims to patch this by masking labels based on the attention mask.

    * -100 -> label_pad_token_id

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit 8799952876631d7c772ac80f9cbcff155da960e2
Author: Costa Huang <costa.huang@outlook.com>
Date:   Fri May 10 09:32:20 2024 -0400

    visualize rm prediction (#1636)

    * visualize rm prediction

    * quick update

    * quick check

    * quick fix

    * update eval steps

commit 3b4c24946b7d5580fd354b0e3800fc1047b82a41
Author: Xiao Yu <39458711+jasonyux@users.noreply.github.com>
Date:   Fri May 3 18:19:35 2024 -0400

    fixed adding bos and eos token unconditionally (#1591)

    * fixed adding bos and eos token unconditionally

    * fixed typo of tokenizer -> self.tokenizer. Also added update to ORPO

    * fixed code quality, and added BOS/EOS fix to KTO

    * code reformatting with pre-commit run --all-files

    * bug fix: check input id length before checking for EOS/BOS

commit 0347f583e3883f9144a959d1e6f748a4cc91cd09
Author: lewtun <lewis.c.tunstall@gmail.com>
Date:   Fri May 3 15:59:59 2024 +0200

    Fix ZeRO-3 generation context manager (#1617)

* judge refactoring and unittest

* format

* init

* doc

* format

* improve doc

* basejudge

* improve doc and add BaseAPIJudge

* Doc

* style

* refactor callback

* remove openai and pairrm judge from test

* doc

* rm dpo online example

* new prompts and completions

* skip hf judge and add hf token

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-07-18 15:16:59 +02:00
98ad01ddfd dpo vlm blog post (#1844)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-07-17 18:03:49 +02:00
fef8240c23 fix arg parsing in chat.py (#1846)
Co-authored-by: leandro <leandro.vonwerra@spoud.io>
2024-07-17 17:32:17 +02:00
915ffc7c61 add link to DPO datasets collection (#1845) 2024-07-17 11:18:35 -04:00
5828a666bf Fix issues of KTOTrainer (#1840)
* Fix issues of KTOTrainer

* Update trl/trainer/kto_trainer.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-07-17 08:46:14 +02:00
052a8e14b5 fix ppov2_trainer tensorboard log bugs (#1836) 2024-07-16 16:08:15 +02:00
a2adfb836a ref_model -> model_ref (#1835)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-07-15 18:50:29 +02:00
4ebfc5de28 refactor trainer callbacks (#1826)
* refactor trainer callbacks

* fix import

* more import fixes
2024-07-15 11:07:16 -04:00
9e9dc96e67 Added missing token kwarg in Peft model loading (#1825) 2024-07-10 19:11:13 +02:00
7ddef5c158 Make use of trust_remote_code consistent (#1806)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-07-10 18:26:11 +02:00
a9cddf8c55 Delete unused benchmark.yml workflow. (#1822) 2024-07-10 11:25:07 -04:00
2860ce5091 DPO Llava 1.5 and PaliGemma support (#1797)
* llava support dpo

* add_special_tokens=False only when possible

* format

* pali gemma

* refactor size

* remove image resize

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-07-09 09:22:52 +02:00
30e33bd92d upgrade gh actions (#1818)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-07-08 23:37:12 -04:00
d5a0d2d345 Set dev version (#1817) 2024-07-08 11:12:41 -04:00
314e8eb367 fix broken url in docs\source\index.mdx (#1813) 2024-07-08 15:41:36 +02:00
e10792032b 0.9.6 release (#1816) 2024-07-08 09:38:09 -04:00
78045dedc8 Fix TRL_USE_RICH environment variable handling (#1808)
* Add `strtobool` custom implementation from `distutils`

* Fix `TRL_USE_RICH` handling via `strtobool`

* Run `make precommit`
2024-07-07 19:59:26 -04:00
747612f9d3 Fix torch_dtype handling in {DPO,SFT}Trainer when provided via CLI (#1807)
* Fix `torch_dtype` handling through CLI

The `torch_dtype` is not properly handled when provided via the TRL CLI: it arrives as a string but is then cast to `torch.dtype` before being passed to the `{DPO,SFT}Trainer`, so those trainers must also handle the case where `torch_dtype` is already a `torch.dtype` (see the sketch after this entry).

* Add `torch_dtype` tests in `test_{dpo,sft}_trainer.py`

* Forward contribution credits

* Run `make precommit`

---------

Co-authored-by: Tash Srivastava <yash-srivastava19@users.noreply.github.com>
2024-07-05 16:28:59 +02:00
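A minimal sketch of the normalization described in #1807, assuming a hypothetical `normalize_torch_dtype` helper; it illustrates the behaviour rather than TRL's actual implementation.

```python
import torch

def normalize_torch_dtype(torch_dtype):
    """Accept either a string such as "bfloat16" or an actual torch.dtype."""
    if torch_dtype is None or torch_dtype == "auto":
        return torch_dtype
    if isinstance(torch_dtype, torch.dtype):
        return torch_dtype  # already a dtype, e.g. torch.bfloat16
    if isinstance(torch_dtype, str):
        dtype = getattr(torch, torch_dtype, None)  # "float16" -> torch.float16
        if isinstance(dtype, torch.dtype):
            return dtype
    raise ValueError(f"Invalid torch_dtype: {torch_dtype!r}")
```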
9e3a35bd3d Remove extra print in reward_trainer.py (#1799)
`print_rich_table` is called twice and the first call doesn't restrict to `num_print_samples`. Remove the first, extra call
2024-07-05 13:29:48 +02:00
4402b36dcf clean examples (#1791)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-07-04 14:29:25 +02:00
78f8228874 Bugfix: Preserve token fields when converting TrainingArguments to SFTConfig (#1794)
* Preserve token fields when converting TrainingArguments to SFTConfig

TrainingArguments.to_dict() redacts token fields, so we have to
individually copy them over when converting to SFTConfig to avoid
breaking push_to_hub functionality.

Also adds a test.

* run precommit

* one-line args_as_dict definition per suggestion from kashif

* generalize token copying to match TrainingArguments behavior

* unwrap |= on dict, to support python 3.8

* use .update instead of |= or for-loop
2024-07-03 20:10:50 +02:00
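A rough sketch of the token-copying idea described in #1794, following the bullets above (one-line `args_as_dict` definition, `.update`, matching any `*_token` field). It assumes `to_dict()` redacts token fields; the real conversion may need extra filtering.

```python
from transformers import TrainingArguments
from trl import SFTConfig

def training_args_to_sft_config(training_args: TrainingArguments) -> SFTConfig:
    # to_dict() redacts sensitive token fields, so copy them back explicitly;
    # otherwise push_to_hub would break after the conversion.
    args_as_dict = training_args.to_dict()
    args_as_dict.update(
        {k: getattr(training_args, k) for k in args_as_dict if k.endswith("_token")}
    )
    return SFTConfig(**args_as_dict)
```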
b6af2edc93 add model_init_kwargs to training_args (#1787) 2024-07-03 08:29:16 +02:00
cd85b14fbb Fixed typo in SFT trainer docs (#1788)
'STFConfig' instead of 'SFTConfig' appears multiple times in the doc, causing errors when running the code snippets.
2024-06-29 15:35:48 +02:00
a57544f47a fix docs and examples (#1780) 2024-06-27 15:47:58 +02:00
b68ff96f0c Visual DPO (#1647)
* Remove extra whitespaces

* idefics

* vdpo

* sft idefics

* pad with test

* use prompt instead of tokenizer

* rm name main

* support vlm in tokenize row

* temp fix for regex in lora_target_module

* format

* vdpo

* tmp float16 hard code

* concatenated_forward support for vision

* style and new command line

* all-linear

* format

* delete old examples

* get image

* upcast

* new test

* modified test

* new strat for tokenizer

* rm token transfer

* integrate vision in dpo example

* format

* add FDivergenceType back

* precommit

* pillow test dep

* optional prompt

* `evaluation_strategy` to `eval_strategy`

* revert vsft change (oos)

* update test

* test

* comment and support more in process

* update process

* update doc for vdpo

* caution about limited support

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* revert DPO example changes

* cleaner way to check if a model is vision

* comment

* update vdpo example

* rename

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-06-26 16:26:37 +02:00
c8c01cc055 Fix Documentation Overflow Issues for Long URLs in SFTConfig (#1774)
* Update sft_config.py

* Update sft_config.py
2024-06-26 11:23:36 +02:00
3479606c8c Remove the leading space in the tldr preference dataset (#1773) 2024-06-26 09:18:22 +02:00
7965b78340 add Efficient Exact Optimization (EXO) (#1735)
* add exo

* fix a detail

* Update trl/trainer/dpo_trainer.py

* Update trl/trainer/dpo_trainer.py

* Update trl/trainer/dpo_trainer.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-06-25 16:47:32 +02:00
56bd1bba26 evaluation_strategy to eval_strategy (#1771)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-06-25 10:14:26 -04:00
94d53e6617 MoE Models: option to add load balancing loss (#1765)
* KTO: add aux loss

* use router_aux_loss_coef in KtoTrainer when aux_loss enabled

* align optional aux_loss in DPO, KTO, CPO, ORPO

* precommit changes

* fix KL forward kwargs

* add aux_loss docs entry

* apply docs suggestions

---------

Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
2024-06-24 21:27:00 +02:00
b5be100ae0 Added Reward Backpropogation Support (#1585)
* added alignprop template

* added alignprop support

* Update alignprop_trainer.mdx

* Update alignprop_trainer.mdx

* added better why statement

* fixed inference code

* changed self to pipeline

* removed aesthetic classifier

* added aesthetic to auxiliary models

* added unseen prompt logging

* removed unseen prompt log

* fixed minor

* remove not needed import in trl/__init__.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* fixed styling

* updated _toctree

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-06-24 12:05:44 -04:00
6e1652bc5e Add CPO-SimPO method (#1760)
* enable cpo-simpo

* highlight SimPO and CPO-SimPO

* add test for cpo_alpha

* formatting

* Update docs/source/cpo_trainer.mdx

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-06-23 18:54:30 +02:00
65374c6a71 New sentiment and descriptiveness dataset (#1757)
* push changes

* handle edge cases where the chosen and the rejected are the same
2024-06-21 11:20:54 -04:00
9956091112 Add dataset_text_field in examples/scripts/sft.py (#1758) 2024-06-21 11:01:08 +02:00
34d273f227 Support num_train_epochs (#1743)
* add a test case for num_train_epochs

* fix ci

* quick change

* disable push to hub

* debug windows ci

* try another fix

* skip subprocess tests on windows
2024-06-20 13:16:43 -04:00
3bf94492a8 Fix masking of response tokens (#1718)
The current handling of `response_masks` inside the `batched_forward_pass` function does not take padding into consideration, which results in a shape mismatch during masking. Since the response mask is a mask tensor of response tokens, it should not be concatenated with `torch.zeros(query_length)`, and the masking operation should be done without slicing.

Remove the concatenation of the response mask and remove the slicing from it, since the response mask already matches the length of `masks[j, start:end]` (see the sketch after this entry).
2024-06-20 11:22:20 -04:00
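A schematic sketch of the masking change described in #1718, using illustrative tensor names rather than the exact `batched_forward_pass` code.

```python
import torch

def apply_response_mask(masks: torch.Tensor, response_masks: torch.Tensor,
                        j: int, start: int, end: int) -> torch.Tensor:
    # The response mask already matches masks[j, start:end], so it is applied
    # directly: no padding with torch.zeros(query_length) and no extra slicing.
    masks[j, start:end] = masks[j, start:end] * response_masks[j]
    return masks
```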
ba6abee37f Support for returning past_key_values from the model (#1742)
* add support for returning past_key_values from the model

* change order of  keys
2024-06-20 09:14:16 -04:00
a57e75967c Integrate f-divergence to DPO (Follow up) (#1610)
* Step 1: update ppo_trainer and hello_world example

* Step 2: Refine comments and add parameter type

* Step 2: Add missing parameter comments

* Step 1: Organize ptx loss into a function and add ptx_loss to train_stats

* Step 1 updates: add comment to ptx_loss function, fix a bug and add warning message

* Step 2: 1) Add ppo_ptx training example as ppo; 2) separate pretrain data fetch and iterate

* Step 2: Remove loss from columns_to_log in ppo_ptx example

* Remove dataset revision in load imdb dataset

* Run pre-commit and fix format issues

* Initial draft of f-divergence fn

* Update f-divergence to avoid overflow

* fix test errors and comments

* Add Unit tests for dpo loss with alpha and js div f

* Adjust format

* Fix test error

* Reverse this update

* Add test cases

* Reverse un-needed updates

* Update code style

* Try to fix code fmt error

* remove extra end line

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-06-19 12:02:51 +02:00
ae23d40f3b change the process function in the example of DPO (#1753)
* change the `process` function in the example of DPO

* fix
2024-06-18 10:07:24 -04:00
83b367b11a CI / KTOTrainer: Remove old tests (#1750)
* remove old tests

* remove datasets

* Update test_dpo_trainer.py

* Update test_dpo_trainer.py
2024-06-18 11:31:17 +02:00
d1ed730ab8 prepare deepspeed accommodate fp16 and bf16 (#1728)
* prepare deepspeed accommodate fp16 and bf16

* precommit
2024-06-17 10:50:21 -04:00
8f8e95e25d CPO / DPO: Fix red CI (#1749)
* fix red CI

* precommit
2024-06-17 10:49:00 -04:00
4e23d958f2 fix red CI 2024-06-17 16:41:36 +02:00
50c46205b6 small KTO fixes (#1734)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

* remove nans in metrics by gathering across machines

* fix formatting

* fix choice of mismatched examples for KL term

* describe weights

* fix hanging issue in distributed training

* linting

* move metrics to cpu

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* remove kto_pair

* speed up data processing

* move bco code inside

* raise error for kto_pair argument

* fix formatting

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Winnie Xu <winnie.xu97@gmail.com>
2024-06-17 10:14:44 -04:00
6105d03f92 TrlParser: Add ignore extra args option (#1748)
* add ignore extra args option

* Update trl/commands/cli_utils.py
2024-06-17 16:01:06 +02:00
e247bbd7d5 CI / core: Pin numpy to !=2.0.0 for CI and to users (#1747)
* Update setup.py

* Update setup.py

* Update setup.py

* Update test_best_of_n_sampler.py

dummy commit

* pin numpy

* Update tests/test_best_of_n_sampler.py

* Update setup.py
2024-06-17 15:16:07 +02:00
3d04496196 better trl parser with yaml config (#1739)
* working trl parser with config

Correctly overrides the YAML config with command-line arguments and adds `return_remaining_strings`; when `return_remaining_strings` is False, an error is raised if the YAML contains extra args that are not in the dataclasses. Simpler and cleaner than the previous YAML parsing and merging; addresses #1733 (see the sketch after this entry).

* lowercase trlparser
2024-06-17 14:43:33 +02:00
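A simplified sketch of the YAML-plus-CLI precedence described in #1739. This is illustrative only (not the actual `TrlParser` implementation); the helper name and the naive `--key value` parsing are assumptions.

```python
import sys
import yaml  # pyyaml

def load_config(yaml_path, cli_args=None, known_fields=None, return_remaining_strings=False):
    """Return a config dict where command-line arguments override YAML values."""
    with open(yaml_path) as f:
        config = yaml.safe_load(f) or {}
    if known_fields is not None:
        extra = set(config) - set(known_fields)
        if extra and not return_remaining_strings:
            # Mirrors raising an error when the YAML has keys no dataclass accepts.
            raise ValueError(f"Unknown keys in YAML config: {sorted(extra)}")
    # Flags such as "--learning_rate 1e-5" take precedence over the YAML file.
    tokens = iter(cli_args if cli_args is not None else sys.argv[1:])
    for token in tokens:
        if token.startswith("--"):
            config[token[2:]] = next(tokens, None)
    return config
```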
2d244f8acb Workflow: Notify tests results on slack channel (#1744)
* Update tests-main.yml

* Update docker-build.yml
2024-06-17 11:56:13 +02:00
f5168fdbaf adds AOT (#1701)
* adds AOT

* Applied format changes

* added docs and tests

---------

Co-authored-by: Igor Melnyk <igor.melnyk@ibm.com>
2024-06-12 11:54:54 +02:00
79686e1ac7 ktotrainer: Refuse datasets which contain only one class of labels (#1724)
* ktotrainer: refuse datasets which contain only one class of labels (see the sketch after this entry)

* ktotrainer: document new dataset constraint
2024-06-11 16:35:31 +02:00
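A small sketch of the kind of validation described in #1724; the `label` column follows TRL's unpaired preference format, but the check itself is illustrative.

```python
def check_kto_labels(dataset):
    # KTO needs both desirable (True) and undesirable (False) completions;
    # refuse datasets that contain only one class of labels.
    labels = set(dataset["label"])
    if len(labels) < 2:
        raise ValueError(
            "KTO requires both desirable and undesirable examples, but the "
            f"dataset contains only labels {labels}."
        )
```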
34ebc4ccaf feat(ci): add trufflehog secrets detection (#1721)
* feat(ci): add trufflehog secrets detection

* fix(ci): remove unnecessary permissions
2024-06-10 11:17:54 +02:00
1d84e2b888 Fix default padding_value in dpo_config.py (#1692)
The `dpo_config` default padding value should be None, not 0; otherwise it by default overrides the padding value of any tokenizer with 0 (see the sketch after this entry).
2024-06-07 11:42:08 +02:00
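A sketch of the fallback behaviour implied by #1692; the dataclass and helper are illustrative rather than the exact DPO config code.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class PaddingConfig:
    # None means "defer to the tokenizer"; a hard-coded 0 would silently
    # override tokenizer.pad_token_id for every tokenizer.
    padding_value: Optional[int] = None

def resolve_padding_value(config: PaddingConfig, tokenizer) -> int:
    return (
        config.padding_value
        if config.padding_value is not None
        else tokenizer.pad_token_id
    )
```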
2f71b8b1e2 fix yaml parser for derived config classes (#1713)
fixes #1712
reformatted cli_utils with ruff
2024-06-07 10:37:27 +02:00
5bcb8ad0d6 RDPO fix nll loss (#1705) 2024-06-07 09:48:17 +02:00
b8b972fde1 Add a variant of CPO, SimPO (#1703)
* add a variant of cpo: simpo

* correct cpo-simpo loss

* avoid 0 int error in logging

* add simpo description

* Update trl/trainer/cpo_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* fix formatting

* add test for simpo

* Update docs/source/cpo_trainer.mdx

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add a docstring for simpogamma

* move simpo description to the above docstring

* change simpo description in the doc

* formatting

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-06-06 17:06:47 -04:00
3eb9ccb104 set dev version (#1710)
* Update setup.py

* Update __init__.py
2024-06-06 13:33:20 -04:00
974b0d380f 0.9.4 release (#1708) 2024-06-06 10:13:00 -04:00
39a7d1c121 SFTTrainer: Fix backward Compatibility issue with TrainingArguments (#1707)
* fix BC

* fixup
2024-06-06 09:50:17 -04:00
0bdc63839f Fixed doc string and docs for the SFTConfig update (#1706) 2024-06-06 09:42:58 -04:00
275d33b3ef 0.9.3 release (#1699) 2024-06-05 14:34:59 -04:00
c0819ee99f Update sft_trainer.py (#1698) 2024-06-05 11:29:03 -04:00
a03e7cc4e4 Release 0.9.2 (#1697)
* Release: 0.9.0

* Release
2024-06-05 11:00:19 -04:00
a13cb8952c Quick fix on GPT4-eval (#1696)
* quick fix

* precommit
2024-06-05 10:20:54 -04:00
84156f179f Fix typo in DPOTrainer's warnings (#1688) 2024-06-03 14:09:05 -04:00
4eb0b905e2 Skip packing validation (#1673)
* Add test for skipping preproc if packing=True

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Allow skipping of validation for packing=True

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* Use dummy dataset in no packing preproc test

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

---------

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
2024-06-03 18:24:32 +02:00
6c203f9fef Fix overriding optimize_device_cache with optimize_cuda_cache in PPOConfig (#1690)
* Don't override optimize_device_cache when optimize_cuda_cache is not provided
Raise an exception when both optimize_cuda_cache and optimize_device_cache are set

* Minor fix
2024-06-03 11:16:22 +02:00
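A sketch of the deprecation handling described in #1690, written as a generic `__post_init__`; the attribute names match the PPOConfig options mentioned above, but the logic is illustrative.

```python
import warnings
from dataclasses import dataclass
from typing import Optional

@dataclass
class CacheConfig:
    optimize_device_cache: bool = False
    optimize_cuda_cache: Optional[bool] = None  # deprecated alias

    def __post_init__(self):
        if self.optimize_cuda_cache is not None:
            if self.optimize_device_cache:
                raise ValueError(
                    "Both optimize_device_cache and optimize_cuda_cache are set; "
                    "use only optimize_device_cache."
                )
            warnings.warn(
                "optimize_cuda_cache is deprecated, use optimize_device_cache instead.",
                DeprecationWarning,
            )
            self.optimize_device_cache = self.optimize_cuda_cache
```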
f18253bf2d intial RPO loss (#1686)
* intial RPO loss

* fix sign

* clean up
2024-06-03 09:43:02 +01:00
151a452d14 Fix max completion length (#1588) 2024-05-29 20:29:38 +02:00
488b502d31 fix (#1678) 2024-05-29 20:19:26 +02:00
3c0a10b1ae fix dataset load error (#1670)
Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
2024-05-27 14:52:20 +02:00
b031adfdb8 FIX / PPO: Fix enable_input_require_grads issues with PPO models (#1664)
* Update modeling_base.py

* Update ppo_config.py

* Update ppo_trainer.py

* style
2024-05-24 15:20:16 +02:00
e7cb597230 Fix ppov2 test case (#1661)
* Fix PPOv2 / RLOO refactor's stuff

* update terminology to use stop token
2024-05-23 11:37:16 -04:00
bc8dfbf4e2 update eval_strategy (#1662) 2024-05-23 15:28:04 +02:00
e4ed7a3a5a do not upcast adapters when using FSDP+QLoRA (#1654) 2024-05-23 15:04:22 +02:00
9a7efbd051 🤫 TR-DPO implementation (#1593)
* 🤫 TR-DPO implementation baseline

* fix comments

* docs

* fix linters

* test added

* move configs to DPOConfig

* fix typo

* add docs

* fix import

* use state.global_step

* fix order of arguments

* make sure plugins are not none

* Update trl/trainer/utils.py

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update trl/trainer/utils.py

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* checking that reference model weights have changed

* sync_target_model as staticmethod

* set reference model

---------

Co-authored-by: Nikita Surnachev <n.surnachev@tinkoff.ru>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
2024-05-23 14:58:49 +02:00
b344bcea2c [DPO] Add 'robust' loss_type (#1653)
* Initial commit

* pre-commit fix

* Minor change to comments

* Added some documentation on how to use Robust DPO
2024-05-23 14:57:25 +02:00
35e12dc595 Fix inheritance order in PPOv2Config (#1659)
* fix inheritance order in PPOv2Config

* fix inheritance order in rloo_config
2024-05-23 08:36:15 -04:00
1da6be18e0 docs: correct cDPO usage in DPOTrainer (#1655) 2024-05-23 08:10:29 -04:00
e249cd802f add support for training collator (#1658) 2024-05-23 08:10:05 -04:00
a02513c3b7 Apply deprecated evaluation_strategy (#1559)
* Deprecate

* Update tests/test_dpo_trainer.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-05-23 12:48:00 +02:00
13454d2f4b PPO / Reinforce Trainers (#1540)
* Add ppov2 trainer

* make eos trick optional, remove unused args

* quick fix

* precommit

* update debugging script

* fix out of bound `drop_last=True`; use built-in scheduler

* Add PPO examples

* push changes

* quick change

* quick change

* various bug fixes

* remove unnecessary grad accumulation setting

* push new changes

* fix DS3 model saving

* update ppo.py

* refactor

* quick change

* refactor

* update ppo trainer

* refactor

* quick test

* add ds2 /ds3 7 processes config

* add vllm trainer

* quick change

* experiment with reward normalization

* push changes

* quick push

* push changes

* push various changes

* refactor to use ModelConfig

* quick change

* refactor

* refactor

* Simplify DS logic

* quick update

* remove unnecessary files

* precommit

* deepspeed fix; handle edge case when eos_token_id = 0

* add PPO tldr example

* add TL;DR example

* fix undefined var

* utilize all samples in rloo

* quick setting

* remove the unnecessary `value_model`

* use exact_div

* allow saving the deepspeed model

* refactor

* remove dead code

* Use some shared utilities

* add some end-to-end test cases

* add PPOv2 docs and RLOO docs / tests

* update docs

* quick push

* fix ci

* fix type annotation for ci

* quick update

* update trainer docs
2024-05-22 08:31:10 -04:00
99f2c94b22 don't cast the trainable lora layers to half precision (#1644)
* don't cast the trainable lora layers to half precision

* quality
2024-05-15 16:25:46 +02:00
6401d080c9 Pairwise Noise Contrastive Alignment (#1632)
* add NCA paired preference loss

* chore: lint

* set more lenient tolerance for integration tests

* Update tests/test_dpo_trainer.py

* skip test

* fix

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
2024-05-14 15:41:07 +02:00
d632a5b289 Fixed wrong logs prefixes in KTOTrainer (#1641)
* Fixed wrong logs prefixes in KTOTrainer

* Pre-commit formating
2024-05-14 12:25:54 +02:00
5aeb752053 Update sft_llama2.py to work with the latest API (#1637)
* Update sft_llama2.py to work with the latest API

SFTTrainer now takes an SFTConfig argument

* Update dpo_llama2.py

* precommit
2024-05-10 17:19:15 +02:00
b8b89783ca [ORPO] Correct label mask for pad tokens (#1625)
* [ORPO] Correct label mask for pad tokens

A recent [fix](57aebe9c36) for calculating the NLL loss over a whole sequence introduced a bug: when input_ids are copied to labels, pad tokens are not masked.

This PR aims to patch this by masking labels based on the attention mask (see the sketch after this entry).

* -100 -> label_pad_token_id

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-05-10 15:43:13 +02:00
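A compact sketch of the masking described in #1625, assuming the usual `label_pad_token_id = -100` mentioned above.

```python
import torch

def build_labels(input_ids: torch.Tensor, attention_mask: torch.Tensor,
                 label_pad_token_id: int = -100) -> torch.Tensor:
    # Copy input_ids to labels, then mask out pad positions so they are
    # ignored by the NLL loss.
    labels = input_ids.clone()
    labels[attention_mask == 0] = label_pad_token_id
    return labels
```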
8799952876 visualize rm prediction (#1636)
* visualize rm prediction

* quick update

* quick check

* quick fix

* update eval steps
2024-05-10 09:32:20 -04:00
3b4c24946b fixed adding bos and eos token unconditionally (#1591)
* fixed adding bos and eos token unconditionally

* fixed typo of tokenizer -> self.tokenizer. Also added update to ORPO

* fixed code quality, and added BOS/EOS fix to KTO

* code reformatting with pre-commit run --all-files

* bug fix: check input id length before checking for EOS/BOS
2024-05-04 00:19:35 +02:00
0347f583e3 Fix ZeRO-3 generation context manager (#1617) 2024-05-03 15:59:59 +02:00
75de236c09 corrects loss function for Self-play Preference Optimization hard label version (#1615)
* corrects sppo hard label version

* formatting

* formatting
2024-05-03 08:09:57 +02:00
7075cec94d Update HH dataset on helpful only subset (#1613)
* Update HH dataset on helpful only subset

* format
2024-05-02 12:12:12 -04:00
adf17a5a26 support loss function for Self-play Preference Optimization (#1612)
* support loss function for Self-play Preference Optimization

* update docs

* update value error msg

* update typehint

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* include sppo in tests

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-05-02 16:06:58 +02:00
0d40e186ee Docs: Fix build main documentation (#1604)
* Fix build documentation

* Update build_pr_documentation.yml
2024-05-02 11:44:29 +02:00
683bc5af6f Excluding tests from setup.py (#1607) 2024-05-02 10:30:27 +02:00
5f0913122b Use auto device map (#1596) 2024-05-02 09:22:31 +02:00
d1aa0b6b2c [KTOTrainer] add BCO (reward shift and underlying distribution matching) (#1599)
* add `Loss Functions` section in the doc.

* add bce loss with reward shift in KTOTrainer

* add underlying distribution matching

* update example to use underlying distribution matching

* add config description

* fix 'referenced before assignment' error

* add 'bco' and 'udm' test cases

* run pre-commit

* add `scikit-learn` dependency

* raise error if sklearn is not available

* call TrainingArguments().__post_init__() for proper init
2024-04-30 14:06:45 +02:00
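A minimal sketch of the optional-dependency guard mentioned in #1599 ("raise error if sklearn is not available"); the helper name and error message are illustrative.

```python
import importlib.util

def require_sklearn():
    # BCO's underlying-distribution matching relies on scikit-learn, so fail
    # early with a clear message when it is missing.
    if importlib.util.find_spec("sklearn") is None:
        raise ImportError(
            "BCO (underlying distribution matching) requires scikit-learn. "
            "Install it with `pip install scikit-learn`."
        )
```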
d88ec14602 Update __init__.py (#1602) 2024-04-30 10:25:43 +02:00
6c18e40e97 fix typo (#1594) 2024-04-29 10:42:31 +02:00
1d0a7ea17b add warning in SFTTrainer (#1577) 2024-04-23 20:00:10 +02:00
9f68ead8cf FIX: Fix CI on transformers main (#1576)
* Update run_dpo.sh

* Update run_sft.sh

* Update clis.mdx

* Update example_config.yaml

* Update test_cli.py

* Update testing_constants.py

* Update test_dpo_trainer.py
2024-04-23 14:31:45 +02:00
f30daa4225 [SFT] add SFT Trainer Config dataclass (#1530)
* initial SFT Config

* remove pdb

* fix chat_template

* undo formatting

* add back removed commits

* fix the tests

* add back options to SftScriptArguments

* use sft_script_args

* Update trl/commands/cli_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/commands/cli_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* rename SFTScriptArguments and split names

* formatting docstrings

* docstring

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-04-23 11:55:13 +02:00
24fd8dd513 [DPO] DPOConfig class (#1554)
* initial DPOConfig

* fix doc string

* use DPOConfig

* fix missing import

* fix DpoScriptArguments

* override args config when given in init

* use DPOConfig

* fix output dir name

* override with deprecated arguments if given

* use DPOConfig in tests

* fix comment

* add custom_message

* use dataset_train_name and dataset_test_name

* beta is also in the training_args

* fix loss_type docs

* Update trl/commands/cli_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/commands/cli_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/commands/cli_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* use DPOScriptArguments

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-04-23 11:06:28 +02:00
c050ebc073 [DPO] add 'bco_pair' loss_type (#1524)
* add 'bco_pair' loss_type

* add BCO description to DPO doc

---------

Co-authored-by: sean.jung <sean.jung@seanjungui-MacBookPro.local>
2024-04-22 18:46:51 +02:00
abc0584736 fix add_special_tokens issue for data with template (#1509) 2024-04-22 18:44:10 +02:00
6d1cb85e73 set dev version (#1568) 2024-04-22 10:59:35 +02:00
e90e8d91d2 Release: v0.8.6 (#1567) 2024-04-22 10:58:13 +02:00
113aaae033 CLI: Add warning when ignored params are passed + parse config file if config if passed (#1565)
* add warning

* no need for `config` field
2024-04-22 10:48:59 +02:00
0865572748 Update __init__.py (#1557) 2024-04-18 14:51:40 +02:00
a6532a11c2 set dev version (#1556) 2024-04-18 13:58:17 +02:00
3595eb00e0 Release: v0.8.5 (#1555) 2024-04-18 13:56:36 +02:00
9afd901d0f enable multiple eos tokens (#1553) 2024-04-18 12:19:18 +02:00
e04432d5e3 FIX: make the train / test fields modulable (#1551)
* make the train / test fields modulable

* format

* fix --output_dir issue
2024-04-18 11:33:30 +02:00
75c1c47fcc set dev version (#1548) 2024-04-17 17:25:01 +02:00
a5788ac99b Release: v0.8.4 (#1547) 2024-04-17 17:19:28 +02:00
3bbe7e0407 Fixed ref model not used in PPO generation (#1534) 2024-04-17 07:22:56 -07:00
edf60e826b Update run_sft.sh (#1546) 2024-04-17 16:17:05 +02:00
5d1deb1445 CLI: Set dataset_text_field to None to allow ChatML automatic template (#1545)
* Update cli_utils.py

* Update test_cli.py
2024-04-17 14:45:14 +02:00
476c4b8dc0 [KTO] support to load the adapter twice (#1542)
Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
2024-04-16 17:43:40 +02:00
e823458a6a save_model -> save_pretrained in ppo_trainer.mdx (#1537) 2024-04-15 09:35:03 +02:00
1c0d8bca15 VSFT hotfix - adds gen prompt to template and processor to hub (#1532)
* adds gen prompt to template and processor to hub

* fixes hub model id, removes Path
2024-04-12 20:14:12 +02:00
363369a717 [CPO] fix memory leak due to retained value (#1531) 2024-04-12 15:32:01 +02:00
aba4df02c1 set dev version (#1529) 2024-04-12 12:37:34 +02:00
98226473e4 Release: v0.8.3 (#1528) 2024-04-12 12:22:05 +02:00
87f4c70e60 [CLI] fix imports (#1527) 2024-04-12 12:17:05 +02:00
995f1174da set dev version (#1523) 2024-04-11 15:51:57 +02:00
143e11123d Release: v0.8.2 (#1522) 2024-04-11 15:42:47 +02:00
346c99d222 Adds VLM Training support to SFTTrainer + VSFT script (#1518)
* adds option to skip dataset preparation in SFTTrainer

* before changing the template

* adds support for new schema

* a few fixes to data collator to support new schema

* updates args

* precommit

* adds sys prompt to chat template and other fixes

* updates template, fixes collator for multiple images

* precommit

* rename vsft to vstf_llava

* adding integration tests

* adds integration test for vsft

* precommit

* adds back chat template

* docs

* typo

* adds eval, precommit

* adds peft launch args

* formatting

* fixes no deps tests by checking if PIL lib exists

* Update __init__.py

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-04-11 15:35:59 +02:00
087fe544b0 add data for sfttrainer doc (#1521) 2024-04-11 15:08:43 +02:00
ebbd37ba99 allow pre-tokenized datasets (#1520) 2024-04-11 14:50:39 +02:00
e667550a5a Allow streaming (datasets.IterableDataset) (#1468)
* safe-guard iterabledatasets

* import datasets

* reference the correct IterableDataset

* make pre-commit
2024-04-11 11:11:07 +02:00
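A sketch of the safeguard described in #1468, with a generic preprocessing function: `datasets.IterableDataset` is processed lazily and does not support `num_proc`, hence the branch.

```python
import datasets

def map_dataset(dataset, preprocess_fn, num_proc=None):
    if isinstance(dataset, datasets.IterableDataset):
        # Streaming datasets are mapped lazily and single-process.
        return dataset.map(preprocess_fn)
    return dataset.map(preprocess_fn, num_proc=num_proc)
```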
57aebe9c36 [ORPO] Update NLL loss to use input_ids instead (#1516)
* Calculate loss on `input_ids` instead of only on response

* Use `concatenated_labels` if `is_encoder_decoder`
2024-04-09 14:10:09 +02:00
85f5fd220d correct metrics (#1514)
Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
2024-04-08 17:09:04 +02:00
4dca169404 use kwargs for RM (#1515) 2024-04-08 17:05:37 +02:00
f35b68a301 Speed up PPO with ZeRO-3 by 10x 🔥 (#1483)
* Speed up PPO by 10x 🔥

* Revert

* Clean up

* Use relative import

* Clean

* Fix typing for docs
2024-04-08 14:30:44 +02:00
5cf863576a Change the device index to device:index (#1490)
Signed-off-by: yuanwu <yuan.wu@intel.com>
2024-04-08 14:20:42 +02:00
9a28b3fd05 Fix RichProgressCallback (#1496)
* fix RichProgressCallback

* Refine code styling in RichProgressCallback tests
2024-04-04 21:13:54 +02:00
4f8057ad23 [KTO] fix interleaving, reporting, hanging bugs (#1499)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

* remove nans in metrics by gathering across machines

* fix formatting

* fix choice of mismatched examples for KL term

* describe weights

* fix hanging issue in distributed training

* linting

* move metrics to cpu

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* fix tokenization error: lack of bos

* change user warning for weight hyperparams

* minor update to docs

* reshape attention mask

* reformat

* add test for bos/eos tokens

* move dependency location

* Update tests/test_kto_trainer.py

* don't report nan metrics

* don't report nan metrics and remove data interleaving

* fix bugs in calculating metrics

* no need to gather KL term

* minor changes

* use nanmean for losses

* remove disabling of wandb

* revert changes

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-04-03 23:41:12 +02:00
ab0d11d815 Correct ppo_epochs usage (#1480)
* Correct ppo_epochs usage

The usage of `ppo_epochs` is incorrect here.

In 8534f0edf8/trl/trainer/ppo_config.py (L104C8-L104C58), `ppo_epochs` is described as "Number of optimisation epochs per batch of samples".

However, here it was used as the usual epoch number, i.e. one iteration over the training dataset (see the sketch after this entry).

* Update ppo_trainer.mdx

* Update docs/source/ppo_trainer.mdx

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-04-02 12:22:16 +02:00
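A schematic loop showing the distinction drawn in #1480: `ppo_epochs` controls optimisation passes over each batch of rollouts, not passes over the training dataset. All names here are illustrative, not the actual PPOTrainer code.

```python
def train(dataloader, num_dataset_epochs, ppo_epochs, collect_rollouts, optimize_on):
    for _ in range(num_dataset_epochs):          # passes over the dataset
        for batch in dataloader:
            rollouts = collect_rollouts(batch)   # generate and score responses
            for _ in range(ppo_epochs):          # PPO optimisation epochs
                optimize_on(rollouts)            # per batch of samples
```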
c674c66a45 Fix DPO Unsloth example (#1494) 2024-04-02 12:16:56 +02:00
45da5df53e use log1p for loss (#1491) 2024-04-02 12:06:54 +02:00
04fd8d9400 Fix typo in how_to_train.md (#1503)
Said "big" where it should say "bug".
2024-04-02 12:05:07 +02:00
bf2aed3876 add dpo link (#1502) 2024-04-02 12:04:34 +02:00
0ee349dcd4 Update KTO example to use better model and ChatML support (#1485)
* Update KTO example

* Tweak params

* Fix values

* Fix LoRA params
2024-03-27 10:47:42 +01:00
7ff6206510 Ignore chat files (#1486)
* Ignore chat files

* Update .gitignore

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Update .gitignore

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2024-03-27 10:44:23 +01:00
e4b20ecbc4 hackey update to ModelConfig to allow lora_target_modules="all-linear" (#1488)
The type hint forces a list, which raises an "all-linear" layer not found error; forcing a string makes it work. Updating the type hint to `Union[str, list[str]]` also raises a parsing error.
2024-03-27 09:04:41 +01:00
6c2f829bb7 [KTO] Use batching to speed up data processing (#1470)
* Refactor test

* Make batched tokenizer

* Make is FAST 🔥!

* Hack to the max

* Run on main process

* Refactor

* Add unit test

* f

* r

* Refactor

* Remove bs

* Refactor to tokenize once

* Add typing

* Add test for KL getter
2024-03-26 19:46:23 +01:00
c4f0f41935 Update KTO example with good dataset & chat format (#1481)
* Update KTO example with good dataset & chat format

* Add error for chat template
2024-03-25 16:56:43 +01:00
dc6a934269 add missing classes (#1479) 2024-03-24 22:08:28 +01:00
9ce7ac6925 Fix hyperparameters in KTO example (#1474)
* Fix hparams in KTO example

* Clean

* Fix
2024-03-24 14:29:22 +01:00
99553c19ae Add use_cache=False in {ORPO,CPO}Trainer.concatenated_forward (#1478)
* Add `use_cache=False` in `concatenated_forward`

Prevents `ORPOTrainer` from using the cache, as it's not required for computing the logits and conflicts with Flash Attention 2 (see the sketch after this entry)

* Add `use_cache=False` to `concatenated_forward`

Co-authored-by: Kashif Rasul <kashif@users.noreply.github.com>

---------

Co-authored-by: Kashif Rasul <kashif@users.noreply.github.com>
2024-03-24 11:33:20 +01:00
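A one-line sketch of the change in #1478, wrapped in a generic helper: the KV cache is disabled in the forward call because it is not needed for the loss and clashes with Flash Attention 2 here.

```python
def concatenated_forward(model, input_ids, attention_mask):
    # The KV cache is unnecessary for computing the logits used in the loss
    # and can conflict with Flash Attention 2, so it is disabled explicitly.
    return model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False)
```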
2ce8e45bb2 ORPO trainer (#1435)
* initial orpo skeleton

* typos

* calculate orpo loss

* fix class name

* fix tests

* fix typo

* Update docs/source/orpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/orpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/orpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* rename max_target_length

* Update examples/scripts/orpo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/orpo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/orpo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* more docs

* log log_odds_ratio and log_odds

* average_log_prob as per paper

* added logging section

* add nll_loss

* fix typo

* more verbose

* rename log_odds to log_odds_chosen

* allow datasets to be loaded

* remove dup debug arg

* tokenizer exists

* fix typo

* use trl-internal-testing/hh-rlhf-trl-style dataset

* formatting

* add missing imports

* fix output dir name

* Update examples/scripts/orpo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* move dataset_num_proc to configs

* Update trl/trainer/orpo_config.py

Co-authored-by: Alvaro Bartolome <alvarobartt@gmail.com>

* Update trl/trainer/orpo_trainer.py

Co-authored-by: Alvaro Bartolome <alvarobartt@gmail.com>

* add ORPOTrainer to readme

* fix typo

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Alvaro Bartolome <alvarobartt@gmail.com>
2024-03-22 22:07:11 +01:00
d1df79f83c Add CPOTrainer (#1382)
* add CPOTrainer

* add docs

* fix formatting

* removed precompute_ref_log_probs arg

* remove precompute_ref_log_probs

* typos

* finish cpo trainer doc

* remove redundant lines

* typo

* formatting

* compute chosen nll loss also for enc-dec models

* fix gradient error of inplace operation for enc-dec models

* formatting

* use CPOConfig

* formatting

* use model_init_kwargs from CPOConfig

* comments in example

* fix doc string

* fix typo in docstring

* update year

* fixed typo

* use preference dataset

* fix learning rate

* move dataset_num_proc to configs

* Update cpo paper link from HF: cpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* update description for CPO: cpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* remove _prepare_deepspeed for cpo

Because CPO does not need init for reference model

* Add explanation to CPO loss

* format

* fix bug when lengths are given

* add CPOTrainer to README

* fix grammer

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-03-22 21:32:45 +01:00
d10f7663b0 [peft] Update test_reward_trainer.py to fix tests (#1471)
* [peft] Update test_reward_trainer.py

Since we are requiring peft >= 0.4.0

* formatting
2024-03-22 19:12:54 +01:00
423991c204 Use the standard dataset for DPO CLI (#1456)
* Use the standard dataset

* update docs

* update dpo examples

* fix cli error

* fix CI

* use trl-internal-testing/hh-rlhf-trl-style
2024-03-20 13:14:08 -04:00
988d4c4e1a set dev version (#1463) 2024-03-20 12:30:48 +01:00
8534f0edf8 Release: v0.8.1 (#1462) 2024-03-20 11:32:06 +01:00
5095e7f948 add eos token to generate (#1459) 2024-03-20 10:30:27 +01:00
9fcf61d706 Fix chat CLI for model revisions (#1458)
* Fix chat CLI for model revisions

* Clean
2024-03-20 09:35:34 +01:00
66b043a910 set dev version (#1454) 2024-03-19 17:30:48 +01:00
f2c71771cc Release: v0.8.0 (#1453)
* Release: v0.7.12

* 0.8.0 instead
2024-03-19 17:19:38 +01:00
631c33cbb3 FEAT: Update README to add DPO + CLIs (#1448)
* Update README.md

* Update README.md

* move dpo/ppo description to docs

* rework readme

* Update README.md

---------

Co-authored-by: leandro <leandro.vonwerra@spoud.io>
2024-03-19 16:55:56 +01:00
3f7ff60528 model --> model_name_or_path (#1452)
* `model` --> `model_name_or_path`

* fix style
2024-03-19 16:52:42 +01:00
1705aebeba Fix yaml parsing issue (#1450) 2024-03-19 16:07:50 +01:00
4e622a9033 chat cli (#1431)
* first draft

* move chat to cli

* fix makefile

* make script less verbose

* fix parsing

* fix style

* add more examples

* fix setup.py

* add copyright

* fix verbose init

* attribute FastChat

* add docs
2024-03-19 12:37:06 +01:00
eb2d5b2972 CI / CLI: Properly raise error when CLI tests failed (#1446)
* properly raise error

* another fix

* Update tests.yml

* Update tests-main.yml
2024-03-19 11:39:07 +01:00
f976c6d234 Before update the tr_loss, make sure tr_loss_step is in the same device. (#1439)
* before updating the loss from DPO, make sure it is on the same device as tr_loss (see the sketch after this entry)

* Update trl/trainer/dpo_trainer.py

Co-authored-by: guy1992l <83535508+guy1992l@users.noreply.github.com>

---------

Co-authored-by: guy1992l <83535508+guy1992l@users.noreply.github.com>
2024-03-19 10:28:44 +01:00
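A tiny sketch of the fix in #1439, assuming both values are tensors: the step loss is moved to the accumulator's device before the add.

```python
import torch

def accumulate_loss(tr_loss: torch.Tensor, tr_loss_step: torch.Tensor) -> torch.Tensor:
    # Make sure the step loss lives on the same device as the running total
    # before accumulating, to avoid a device-mismatch error.
    return tr_loss + tr_loss_step.to(tr_loss.device)
```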
abc7301bab Fix PPOTrainer README example (#1441)
* Fix example

* Delete newline
2024-03-19 10:18:49 +01:00
6cfa5cfc81 fix doc build on main (#1437) 2024-03-18 14:24:02 +01:00
a2aa0f0b09 FEAT: Add CLIs in TRL ! (#1419)
* CLI V1

* v1 CLI

* add rich enhancmeents

* revert unindented change

* some comments

* cleaner CLI

* fix

* fix

* remove print callback

* move to cli instead of trl_cli

* revert unneeded changes

* fix test

* Update trl/commands/sft.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* remove redundant strings

* fix import issue

* fix other issues

* add packing

* add config parser

* some refactor

* cleaner

* add example config yaml file

* small refactor

* change a bit the logic

* fix issues here and there

* add CLI in docs

* move to examples/sft

* remove redundant licenses

* make it work on dpo

* set to None

* switch to accelerate and fix many things

* add docs

* more docs

* added tests

* doc clarification

* more docs

* fix CI for windows and python 3.8

* fix

* attempt to fix CI

* fix?

* test

* fix

* tweak?

* fix

* test

* another test

* fix

* test

* fix

* fix

* fix

* skip tests for windows

* test @lvwerra approach

* make dev

* revert unneeded changes

* fix sft dpo

* optimize a bit

* address final comments

* update docs

* final comment

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2024-03-18 12:20:54 +01:00
304e208f77 Create standard dataset for TRL (#1424)
* add scripts to create standard dataset

* precommit

* push changes

* add script to play with
2024-03-14 10:57:48 -04:00
4fe8b027f6 [Kto] torch_dtype kwargs fix (#1429)
* set torch_dtype from string type

* fix test
2024-03-14 13:49:44 +01:00
fb6ebb1e11 [KTO] fix tokenization bugs (#1418)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

* remove nans in metrics by gathering across machines

* fix formatting

* fix choice of mismatched examples for KL term

* describe weights

* fix hanging issue in distributed training

* linting

* move metrics to cpu

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* fix tokenization error: lack of bos

* change user warning for weight hyperparams

* minor update to docs

* reshape attention mask

* reformat

* add test for bos/eos tokens

* move dependency location

* Update tests/test_kto_trainer.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-03-14 08:22:50 +01:00
66078c7c01 CI: Fix CI on main (#1422)
* fix CI on main

* final fix
2024-03-13 13:54:22 +01:00
58c0888996 Add support for FSDP+QLoRA and DeepSpeed ZeRO3+QLoRA (#1416)
* don't do mp casting

* don't use `prepare_for_kbit` when using fsdp+qlora or dsz3+qlora

* changes to enable fsdp+qlora and dsz3+qlora

* revert

* Update sft_trainer.py

* quality

* fix deprecation using changes from PR https://github.com/huggingface/trl/pull/1415

* fixes

* quality

* Update trl/trainer/sft_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* quality

* relaunch tests

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-03-13 10:43:45 +01:00
486e7a4071 model init when args are given (#1413)
Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
2024-03-11 13:47:37 +01:00
7630f877f9 Fix import error from deprecation in transformers (#1415)
* Fix import error from  deprecation in transformers

* Fix import path
2024-03-11 13:23:56 +01:00
4d862da181 [KTO] fix various bugs (#1402)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

* remove nans in metrics by gathering across machines

* fix formatting

* fix choice of mismatched examples for KL term

* describe weights

* fix hanging issue in distributed training

* linting

* move metrics to cpu

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-03-08 12:04:52 +01:00
22b4f548f4 fix RM script (#1393) 2024-03-07 08:49:52 +01:00
4219cbfedc Fix the pad_token_id error (#1394)
* Fix the pad_token_id error

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Add the load_in_8bit argument in rl_training.py

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Reformate the patch

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Fix the check failed

Signed-off-by: yuanwu <yuan.wu@intel.com>

---------

Signed-off-by: yuanwu <yuan.wu@intel.com>
2024-03-05 02:18:42 +01:00
3bd02380c7 Log ddpo reward as float to fix numpy conversion during bf16 training (#1391) 2024-03-04 02:50:50 +01:00
067db7553a [KTO] prevent nans from appearing in metrics (#1386)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

* remove nans in metrics by gathering across machines

* fix formatting

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-03-01 12:19:55 +01:00
93e85ed808 [KTO] merge eval dataset only if it exists (#1383)
* merge eval dataset if it exists

* add eval dataset test
2024-03-01 12:15:14 +01:00
14e0d78807 fix bugs in KTO implementation (#1380)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-02-29 09:01:52 +01:00
b32656f726 FIX: Fix the CI again .. (#1374)
* Update tests-main.yml

* Update tests-main.yml

* Update tests-main.yml

* Update .github/workflows/tests-main.yml

* Update tests-main.yml

* Update tests-main.yml
2024-02-27 12:46:20 +01:00
9399bc113b Update tests-main.yml (#1373) 2024-02-27 12:07:50 +01:00
11f122ad49 Update tests-main.yml (#1372) 2024-02-27 11:45:02 +01:00
009c9a610b feature request add force_use_ref_model (#1367) 2024-02-27 11:19:16 +01:00
7712d42f8c add eval_packing (#1369) 2024-02-27 11:19:06 +01:00
7c2213b9e5 add ci message sending on TRL (#1370) 2024-02-27 11:18:55 +01:00
ddeebce176 Add some arguments for support XPU (#1366)
* Add use_bnb and load_in_4bit arguments.

Make them optional, since they are not supported on all platforms.

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Change the use_reentrant default value to False

If the default value of gradient_checkpointing is True, set the use_reentrant default to False, because otherwise the following error occurs (see the sketch after this entry).

RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 191 with name base_model.model.model.layers.31.self_attn.v_proj.lora_B.default.weight has been marked as ready twice. This means that multiple autograd engine  hooks have fired for this particular parameter during this iteration.

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Add model_dtype for loading the model in model_dtype

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Reformate the patch

Signed-off-by: yuanwu <yuan.wu@intel.com>

---------

Signed-off-by: yuanwu <yuan.wu@intel.com>
2024-02-27 02:49:16 +01:00
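A sketch of the `use_reentrant` workaround described in #1366, using the `gradient_checkpointing_kwargs` option available in recent transformers releases; the `output_dir` value is a placeholder.

```python
from transformers import TrainingArguments

# Non-reentrant checkpointing avoids the "marked as ready twice" DDP error
# quoted above when gradient checkpointing is enabled.
training_args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```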
cf68d871cf Fix version for Python<3.8 (#1363) 2024-02-27 02:41:09 +01:00
2a2676e7ec set seed in sft/dpo/reward_modeling to make results reproducible (#1357)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-02-23 11:12:45 +01:00
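A minimal sketch of seeding for reproducibility as in #1357, using `transformers.set_seed`, which seeds Python, NumPy, and PyTorch; the seed value is arbitrary.

```python
from transformers import set_seed

set_seed(42)  # seed python, numpy, and torch (incl. CUDA) for reproducible runs
```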
ca90cba351 fix 8-bit multi-gpu training bug (#1353)
* fix 8-bit multi-gpu training bug see https://github.com/huggingface/trl/issues/1348

* Update dpo_llama2.py

make gradient_checkpointing_kwargs configurable.

* Update dpo_llama2.py

remove unnecessary config of device_map

* format with make precommit

---------

Co-authored-by: ubuntu <lili@liveremier.ai>
2024-02-23 03:58:43 +01:00
4f97fb4a74 more userfriendly (#1350) 2024-02-22 10:06:35 +01:00
a46cd84a64 Kto trainer (#1181)
* initial file

* initial tokenizer

* UnpairedPreferenceBatchSampler

* use batch_sampler

* use interleave_datasets

* add loss

* fix imports

* use SequentialSampler when training

* formatting

* add other helpers

* add prediction_step

* fix the kto pair docs

* tests

* compute_reference_log_probs

* add get_eval_dataloader

* fix typo

* kto with is_encoder_decoder true

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* fixed typo

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* renamed KTO dataset keys

* use DPOTrainer's get_batch_logps

* add get_batch_samples

* typo

* Handle last token in prompt

* Create KTOConfig class that subclasses transformers.TrainingArguments

* Update KTO tests to handle KTOConfig

* Update KTO script to use KTOConfig

* formatting

* Update docs/source/dpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/training_configs.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* use max_completion_length

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* add back get_batch_logps

* use max_completion_length

* move config to its own file

* Check tokenize params on Trainer init

* Clone labels for enc-dec model to solve RuntimeError

* formatting

* fix enc-dec later

* completion_decoder_input_ids is optional for enc-dec

* fix breaking test

* add a kl key for KL estimation with shuffled completion

* add loss and weights

* fix bug in chosen_idx

* add back metrics

* fix typos

* fix kto_loss docs

* typo

* set loss to None when there are no target completions in the batch

* use nan tensor instead of none

* fix reference_logps test

* fix logits

* a bit more robust options

* log only the correct prompt-completion during eval

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/dpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* add docs for desirable_weight and undesirable_weight args

* dropout is always disabled

* remove DDP hack

* formatting

* move more arguments of trainer to config

* comment out T5 test for now

* Add docstring to KTOTrainer

* moved Config docstrings to the appropriate class

* add autodoc to markdown

* formatting

* updated copyright year

* add model tags

* do not add BOS to start of completion

* Move data_collator to KTOTrainer

* formatting

* data_collator is not in args

* shuffle_completion with specific input_columns

* remove all but the needed columns

* Update docs/source/dpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update tests/test_kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* moved more args to kto_config

* fix test

* use all_exhausted strategy and shuffle after

* use KTOConfig in HfArgumentParser

* use ModelConfig

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Pablo Vicente Juan <p.vicente.juan@gmail.com>
2024-02-19 14:43:17 +01:00
1f56bffdf8 Update Example to reflect #aa35fec (#1333) 2024-02-18 14:10:04 +01:00
1bfe0b8fcb set dev version (#1332) 2024-02-16 09:49:05 +01:00
0f13e51efa Release: v0.7.11 (#1331) 2024-02-16 09:05:04 +01:00
1e77d8aeb2 [core / xxxTrainer] Automatic tagging (#1329)
* automatic tagging

* add comments

* fix tests

* fix
2024-02-15 14:47:32 +01:00
3b1911c2a9 add tests on transformers peft main (#1328) 2024-02-15 05:19:31 +01:00
851e7fe556 [core / DDPO] Fix diffusers import issue (#1314)
* fix

* more clean up
2024-02-15 04:45:27 +01:00
31b02d0cd0 Update README.md to clarify model requirement (#1315)
Clarify that language models must be transformers models for text. This is a bit redundant with the intro description, but attempts to better address a question that comes up (issue 1257).

Closes: #1257
2024-02-15 04:38:17 +01:00
9bc478ecbb pre-commit: replace linters + formatters with Ruff; fix some issues (#1300)
* pre-commit: replace linters + formatters with Ruff

* Don't use bare except

* Clean up `noqa`s

* Enable Ruff UP; apply auto-fixes

* Enable Ruff B; apply fixes

* Enable Ruff T with exceptions

* Enable Ruff C (complexity); autofix

* Upgrade Ruff to 0.2.0
2024-02-15 04:37:41 +01:00
29f162b86c Best practice recommendation update for dpo_trainer.mdx (#1325)
In the document as it is now, the best practice recommendations seem neither consistent nor correct.

For example, the documentation links a tweet with a recommendation to merge adapters into a quantized model, and a script that supposedly illustrates how to apply that recommendation. But the script actually does the opposite of what the tweet recommends, first dequantizing the model.

There are similar inconsistencies/ambiguities further in that paragraph. For example, saying that using an unquantized model would lead to lower performance (I changed it to "higher memory demand").

Overall, I updated the paragraph to improve consistency and provided links to slightly more evidence-based merging recommendations.
2024-02-14 11:43:48 +01:00
6852097169 Fix PPOTrainer argument train_dataset -> dataset (#1321)
Both the argument's name and its value need to be renamed.
Otherwise we get both

NameError: name 'train_dataset' is not defined

and

TypeError: PPOTrainer.__init__() got an unexpected keyword argument 'train_dataset'
2024-02-06 22:37:04 +01:00
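For reference, a minimal sketch of the corrected call, assuming the PPOTrainer API of that era (the keyword is `dataset`, not `train_dataset`):

```python
from datasets import Dataset
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
dataset = Dataset.from_dict({"query": ["Hello there", "How are you?"]})

ppo_trainer = PPOTrainer(
    config=PPOConfig(batch_size=2, mini_batch_size=1),
    model=model,
    tokenizer=tokenizer,
    dataset=dataset,   # keyword and value both use `dataset`
)
```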
f12a1da74b Fix AttributeError in dpo_trainer for reference_free case in dpo_loss function (#1313)
* Update dpo_trainer.py

update reference_free parameter for dpo_loss

* Update dpo_trainer for reference_free case

fixed the docstring typo and set the device parameter on the ref_logratios tensor
2024-02-02 11:02:40 +01:00
ae87b3aefa Fix typos in docs for Multi Adapter RL (MARL). (#1312)
* Fix more typos

* Fix typos in docs.
2024-02-02 07:37:08 +01:00
3f7cee7643 ENH: Run CI only if relevant files are modified (#1309)
* Update tests.yml

* Update .github/workflows/tests.yml
2024-02-01 23:49:32 +01:00
ae8431bd50 Codemod Unittest assertions to bare asserts (#1301)
* Remove stray commas from test data

* Codemod Unittest assertions to bare asserts

* Make `assertAlmostEqual` tests more idiomatic

* DRY some test strings
2024-02-01 23:49:03 +01:00
66a976c6bd Update sft_trainer.mdx to add note on launching DDP training (#1308)
As requested here: https://github.com/huggingface/trl/issues/1303#issuecomment-1920437586
2024-02-01 23:42:14 +01:00
814930377c Add num_proc arg to the eval_dataset processing (#1307) 2024-02-01 17:58:00 +01:00
88685f2cd4 Types: Fix PEP 484 implicit-optional compliance (#1297)
This was done automatically with hauntsaninja/no_implicit_optional.
2024-01-31 14:51:58 +01:00
6f40f20233 Fix DPOTrainer docstrings (#1298)
Some issues were causing the auto-generation of the API reference to fail, and the args overlapped in the documentation page
2024-01-31 14:49:41 +01:00
036213bd85 Fix sft trainer when args is None (#1295)
* fix sft trainer when args is None

* add test

* fix
2024-01-31 03:31:53 +01:00
6042596705 Fix DPO slow tests (#1292)
* Update test_dpo_slow.py

* style
2024-01-30 10:15:46 +01:00
070c75ec54 load data only on main process + fix dpo example test (#1291) 2024-01-30 10:14:22 +01:00
b415224a4a fix DPO trainer + mistral + FA2 (#1290) 2024-01-30 08:25:29 +01:00
9186710671 fix padding in dpo trainer (#1284) 2024-01-30 08:24:48 +01:00
aa35fec099 raise value error if one passes a ref_model and a peft_config (#1289) 2024-01-30 08:06:03 +01:00
737d771941 Add multiprocessing in the DPO trainer. (#1286)
* Update dpo_trainer.py

Added support for num_proc to tokenize the training dataset.

* Update dpo_trainer.py

added type in the new num_proc variable

* added test case

* add test case

* fix type

---------

Co-authored-by: imraviagrawal <ravi.agrawal@umass.edu>
Co-authored-by: Ravi Agrawal <raviagrawal@Ravis-MacBook-Pro.local>
2024-01-30 02:55:07 +01:00
ef441ea028 Update dpo_trainer.mdx (#1280) 2024-01-27 10:29:10 +01:00
af623aeba6 Fix sft ci (#1279) 2024-01-26 19:18:23 +01:00
3843cfc32f Fix SFT tuner (#1278) 2024-01-26 17:49:50 +01:00
9a71e67be9 Remove tyro (#1176)
* refactor

* Remove tyro in `ppo.py`

* quick update

* update default args

* quick push

* precommit

* refactor

* quick change

* remove tyro

* quick change

* precommit

* quick change

* fix hello_world

* remove docstring differences

* add `module load cuda/12.1`

* push changes

* precommit

* make dpo runnable

* fix circular import

* quick fix

* refactor

* quick update

* path change

* update plots

* fix docs

* quick change

* Update trl/trainer/model_config.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/model_config.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/dpo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* address comments. use attn_implementation

* precommit

* remove duplicate code

* update peft.py

* fix test no op dep

* Update trl/trainer/utils.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* precommit

* add docs

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-01-26 07:51:15 -08:00
09ca565b24 FIx SFTTrainer bugs on TRL main (#1276)
* Update sft_trainer.py

* Update trl/trainer/sft_trainer.py
2024-01-26 13:50:37 +01:00
4edc688311 Only load data on main process (#1255)
* fix: only load data on main process

* define is_main_process once

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* avoid re-initializing PartialState on train dataset check

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* avoid re-initializing PartialState on eval dataset check

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* process dataset on main first to take advantage of caching

* fix typo in docs

* use decorator to manage state

* Revert "fix typo in docs"

This reverts commit 0880a188812a698f7106853245ce1ba96a036831.

* Revert "Revert "fix typo in docs""

This reverts commit ff7ee33fbeedcd0032b728d86a17cfcb10e43f9b.

* Revert "use decorator to manage state"

This reverts commit 7ac7a45949f621941fedc522f0d2ca7b29367c3a.

* use is_local_main_process instead of is_main_process

* fix: use context manager instead of attribute

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/sft_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-01-26 10:38:07 +01:00
29d439a204 [DPO] average_log_prob when loss is IPO (#1265)
* average_log_prob when loss is IPO

* updated docs with the fix
2024-01-24 12:18:04 +01:00
5760e5d3db Fix typo in extra_columns variable name (#1269)
Co-authored-by: Otto Laitila <otto.laitila@op.fi>
2024-01-23 14:46:13 +01:00
a3c5b7178a Update utils.py (#1256) 2024-01-22 15:32:29 +01:00
222d275b8a set dev version (#1254) 2024-01-19 11:58:47 +01:00
09ca7607d5 Release: v0.7.10 (#1253) 2024-01-19 11:52:51 +01:00
1e68753216 fix: fix loss_type and some args desc (#1247) 2024-01-18 17:20:52 +01:00
1f59eeb9bb Fix chatml template (#1248)
* first draft

* 64

* sourabs suggestion

* wip tests

* make style happy

* add check

* docstring

* fix docstring

* Update tests/test_model_utils.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* move tests

* add todo for abstract class

* make style happy

* add slow tests and imports

* add documentation

* Update sft_trainer.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* fix template & add test

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-01-18 16:47:25 +01:00
928d14445e Add setup_chat_format for adding new special tokens to model for training chat models (#1242)
* first draft

* 64

* sourabs suggestion

* wip tests

* make style happy

* add check

* docstring

* fix docstring

* Update tests/test_model_utils.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* move tests

* add todo for abstract class

* make style happy

* add slow tests and imports

* add documentation

* Update sft_trainer.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-01-18 11:05:32 +01:00
3319993bd1 Fix weird doc bug (#1244)
* Update utils.py

* Update trl/trainer/utils.py

* Update trl/trainer/utils.py
2024-01-18 10:48:56 +01:00
4fb3d0c860 Update sft_trainer.py (#1241) 2024-01-17 15:16:07 +01:00
bcccdeb6f9 [core / SFTTrainer] Fix breaking change (#1229)
* fix breaking change

* revert

* fix

* final fix

* fix

* fix tests
2024-01-17 14:45:22 +01:00
ef209e311f [core / tests ] v1 slow tests (#1218)
* v1 slow tests

* nit

* add qlora tests for DPO

* add decorator

* release memory + log reports

* report to none to avoid seg fault issues

* update setup

* fix

* add example testing

* fix nit

* change temp filename

* add workflow file

* fix comment

* add slack push script

* more tests for DPO

* add dpo example tests

* another makefile command

* fix

* add paths + clean up

* nit

* Update slow-tests.yml

* trigger tests

* up

* up

* more fixes

* fix

* final fixes

* minor fixes

* oops

* add more text

* fix

* more

* trigger CI

* up

* fix

* remove

* run the tests on 2 GPUs only

* final fix SFT

* revert config files + address comments

* fix

* add Phi

* final fixes

* final fix
2024-01-17 10:17:57 +01:00
341f6a6787 fix: improve error message when pad_token_id is not configured (#1152)
* fix: improve error message when `pad_token_id` is not configured

* Add test for error raised when pad_token is None

* Fix pre-commit errors

* Fix error in the test environment
2024-01-17 09:34:20 +01:00
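A minimal sketch of the configuration the improved error message asks for (reusing EOS as the pad token is just one common choice):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    # decoder-only models often ship without a pad token; reuse EOS so padding is defined
    tokenizer.pad_token = tokenizer.eos_token
```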
97b9fa212a Update dpo_trainer.py (#1160)
Log metrics on all distributed processes
2024-01-15 15:40:44 +01:00
a7d796c9a2 Remove a repeating line in how_to_train.md (#1226) 2024-01-15 15:18:49 +01:00
fa074e6a15 Create slow-tests.yml (#1223) 2024-01-12 09:29:57 +01:00
776939dcc4 Add support for ChatML dataset format in (#1208)
* Add support for ChatML dataset format in SFTTrainer

* fix formatting

* fix tests

* more comment

* fix intent

* fix doc string

* Update dataset_formatting.py

* Update dataset_formatting.py

* add documentation

* Update sft_trainer.mdx

* add leonardos comment and more tests

* added more tests and fixed batching

* style

* comment in
2024-01-12 08:05:32 +01:00
163ca9f059 Refactor RewardConfig to own module (#1221)
* Refactor RewardConfig to own module

* Fix init

* Fix import
2024-01-12 17:50:37 +11:00
2eeb7b04cf [core / Docker] Add workflow to build TRL docker images (#1215)
* add docker build

* Update docker/trl-latest-gpu/Dockerfile

* Update docker/trl-source-gpu/Dockerfile
2024-01-11 11:03:43 +01:00
9f8d0e48ad Fix args type (#1214)
* fix args type

* add args desc
2024-01-10 16:35:19 +01:00
c9b7145c75 Update Unsloth SFT, DPO docs (#1213)
* Update sft_trainer.mdx

* Update sft_trainer.mdx

* Update dpo_trainer.mdx

* Update dpo_trainer.mdx

* Update sft_trainer.mdx
2024-01-10 09:08:08 +01:00
baf3c1c293 Fix FSDP error (#1196)
* Fix FSDP error

Fixes error when `loss` field of model output is non-empty, and indexing as [0] returns loss instead of logits. Can happen with FSDP.

* Apply suggestions from code review

force return_dict

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-01-09 18:21:23 +01:00
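A sketch of the pattern behind the fix: with `return_dict=True` and labels passed, `outputs[0]` is the loss, so the logits should be read by attribute (the model name below is just a small example checkpoint):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

inputs = tokenizer("hello world", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"], return_dict=True)

loss = outputs.loss        # this is what outputs[0] would have returned here
logits = outputs.logits    # explicit attribute access is unambiguous
```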
b181e401a7 Fix shape descriptions in calculate_loss method (#1204) 2024-01-09 14:24:41 +01:00
26da9e80cb Check tokenize params on DPOTrainer (#1197)
* Check if tokenizer and max len params are None

* Update warning messages for missing parameters
2024-01-09 14:10:22 +01:00
d6cc88ab2c set dev version (#1207) 2024-01-09 13:06:30 +01:00
7a95cc8696 release: v0.7.9 (#1206) 2024-01-09 13:02:31 +01:00
d1715514de Revert "Address issue #1122 (#1174)" (#1205)
This reverts commit d57d0f9ca46a63d370b91791352edda0154576f5.
2024-01-09 10:20:50 +01:00
d116887ed4 [DPOTrainer] Fix peft + DPO + bf16 if one uses generate_during_eval or pre-computed logits (#1203)
* fix peft + DPO + bf16

* fix

* revert old behaviour

* fix tests

* fix

* fix

* fix

* fix
2024-01-09 09:35:50 +01:00
a236c5750f Fix reported KL in PPO trainer (#1180)
* Fix reported KL in PPO trainer

previously this was always reporting the estimated KL, even when using `kl_penalty = 'full'` (or `abs`, etc).
Now we return the actual KL calculated in `compute_rewards()`, and report that.

* fix test
2024-01-09 06:48:25 +01:00
4ae35afdd6 Fix instruction token masking (#1185)
* Fix instruction token masking

Fix instruction token masking if the first instruction is tokenized differently than the others, or in general if no instruction is detected before the first response.

* Bugfix for edge case

(in case either of the templates isn't found at all, ...idxs[0] might not exist)

* Add test for instruction masking fix
2024-01-09 06:41:53 +01:00
b21ed0ddbc set dev version (#1201) 2024-01-09 05:19:10 +01:00
384b868fe6 Release: v0.7.8 (#1200) 2024-01-09 05:13:26 +01:00
3267be0fcd Allow swapping PEFT adapters for target/ref model. (#1193)
* Allow swapping PEFT adapters for target/ref model.

* Update DPOTrainer docs.

* python format

* isort

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-01-08 16:12:45 +01:00
dbcb2f0021 Allow separate devices for target/ref models. (#1190)
* Allow separate devices for target/ref models.

* Remove original/duplicate.

* Cleanup original, black formatting.

---------

Co-authored-by: Jon Durbin <jonathan@convai.com>
2024-01-08 10:26:40 +01:00
d5910b0ff5 Handle last token from generation prompt (#1153)
* Handle last token from generation prompt

* Remove prints

* Reformat dpo_trainer file
2024-01-08 09:15:53 +01:00
104a02d207 SFTTrainer: follow args.remove_unused_columns (#1188) 2024-01-08 06:09:10 +01:00
ad597dbcb3 Fix misleading variable "epoch" from the training loop from PPOTrainer Doc. (#1171)
* Fix misleading variable "epoch" from PPOTrainer Doc. 

The usage of the variable “epoch” is misleading in the original doc: the dataloader does not contain the data for ALL epochs, but only one, so
"for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader))"
is misleading and does not actually store the epoch number.

The correct version comes from the TRL PPO notebook tutorial 
(https://github.com/huggingface/trl/blob/main/examples/notebooks/gpt2-sentiment-control.ipynb), which uses an outer loop to capture the epochs.

I also posted the question on the forum: https://discuss.huggingface.co/t/confusing-and-possibly-misleading-ppo-trainer-code-from-trl-api-doc-tutorial/67531

* Remove batch_id
2024-01-08 05:50:00 +01:00
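For clarity, a sketch of the corrected pattern with an explicit outer epoch loop; `ppo_trainer` and `get_rewards` are assumed to come from the usual PPO example setup:

```python
from tqdm import tqdm

num_epochs = 4                                   # illustrative value
for epoch in range(num_epochs):                  # outer loop actually tracks epochs
    for batch in tqdm(ppo_trainer.dataloader):   # the dataloader covers one epoch only
        query_tensors = batch["input_ids"]
        response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False)
        rewards = get_rewards(batch, response_tensors)   # placeholder reward function
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
        ppo_trainer.log_stats(stats, batch, rewards)
```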
d57d0f9ca4 Address issue #1122 (#1174)
* Address issue #1122

    Issue [#1122](https://github.com/huggingface/trl/issues/1122)
    takes care of an inconsistency between `_prepare_packed_dataloader`
    and `_prepare_non_packed_dataloader`

* made attention_mask field in ConstantLengthDataset a tensor
2024-01-08 05:43:34 +01:00
ec3d41b879 Fix batch all gather (#1177)
* Fix batch all gather

* quick fix
2024-01-04 17:41:52 +01:00
be32d304db Update sft_trainer.py (#1162)
Fix spelling mistakes in argument description for trl -> SFT Trainer
2024-01-04 16:33:53 +01:00
dc53b8c6b0 Correct shape (#1170) 2024-01-04 16:27:39 +01:00
20428c48ba add: support for peft in ddpo. (#1165)
* add: support for peft in ddpo.

* revert to the original modeling_base.

* style

* specify weight_name

* explicitly specify weight_name

* fix: parameter parsing

* fix: trainable_layers.

* parameterize use_lora.

* fix one more trainable_layers

* debug

* debug

* more fixes.

* manually set unet of sd_pipeline

* make trainable_layers cleaner.

* more fixes

* remove prints.

* tester class for LoRA too.
2024-01-02 12:52:36 +01:00
6614b8aa6b Minor fixes to some comments in some examples. (#1156) 2023-12-29 14:12:05 +01:00
df7b770da8 change device order of metrics (#1154) 2023-12-29 10:55:58 +01:00
18a33ffcd3 SFT Tokenizer Fix (#1142) 2023-12-27 10:25:56 +01:00
911d3658e2 [xxxTrainer] Add unsloth tag (#1130)
* add unsloth tag

* add it on all trainers

* few changes

* add in docs

* revert

* final commit
2023-12-26 16:39:10 +01:00
95ec8577df add peft_module_casting_to_bf16 in DPOTrainer (#1143)
* add peft_module_casting_to_bf16 in DPOTrainer

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* Update trl/trainer/dpo_trainer.py

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2023-12-26 11:25:53 +01:00
3539f3e3cd set dev version (#1145) 2023-12-26 10:26:15 +01:00
e451298b50 Release: v0.7.7 (#1144) 2023-12-26 10:24:47 +01:00
3efb484694 [PPOTrainer / DDPOTrainer] Fix ppo & ddpo push to Hub (#1141)
* fix ppo push to Hub

* fix also ddpo

* more tags
2023-12-26 10:06:20 +01:00
8f5b4923c8 reformatted (#1128) 2023-12-23 10:16:27 +01:00
e0dec27272 reformatted (#1129) 2023-12-23 10:13:38 +01:00
6ef785a6fb Add type hints to core.py (#1097)
* Add type hinting to core.py functions

* Fixes

* Remove unused functions

* Remove unused import
2023-12-22 17:05:20 +01:00
950ee2187d clear up the parameters of supervised_finetuning.py (#1126)
no_gradient_checkpointing is always false

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2023-12-22 17:00:28 +01:00
c1bb1f39f6 set dev version (#1135) 2023-12-22 15:09:37 +01:00
54babd9508 Release: v0.7.6 (#1134) 2023-12-22 15:03:24 +01:00
0c4edb750e [xxxTrainer] multi-tags support for tagging (#1133)
* multi-tags support for tagging

* oops
2023-12-22 14:52:16 +01:00
17ec68d980 set dev version (#1132) 2023-12-22 14:12:24 +01:00
9be5680039 Release: v0.7.5 (#1131) 2023-12-22 14:01:44 +01:00
f11e213fd8 [Docs] Add unsloth optimizations in TRL's documentation (#1119)
* add unsloth

* Update sft_trainer.mdx (#1124)

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
2023-12-22 13:45:26 +01:00
814fe396d4 rename kto loss (#1127) 2023-12-22 13:32:16 +01:00
06b7959b72 save eval_dataset for subsequent calls (#1125) 2023-12-21 17:28:56 +01:00
b07935f867 [xxxTrainer] Add tags to all trainers in TRL (#1120)
* add tags to sfttrainer

* extend it to other trainers

* add for ddpo
2023-12-21 17:04:18 +01:00
2aff709144 Update description in setup.py (#1101) 2023-12-21 15:35:12 +01:00
830cadfc4c fix gradient checkpointing when using PEFT (#1118) 2023-12-20 13:35:56 +01:00
f2acd821e0 Make prepending of bos token configurable. (#1114)
* make prepending of bos token configurable.

* address comments

* fix bug

Co-Authored-By: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/sft_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2023-12-20 11:28:50 +01:00
f100ca34cc peft_module_casting_to_bf16 util method, append_concat_token flag, remove callback PeftSavingCallback (#1110)
* SFT Trainer enhancements

* remove the callback `PeftSavingCallback`

* bump the version of transformers to `4.31.0`

* remove `PeftSavingCallback` from all places.
2023-12-19 17:43:25 +01:00
d708ec272f [Feature] Add Ascend NPU accelerator support (#1096)
* add npu support

* make precommit
2023-12-15 15:34:35 +01:00
8140129595 Updated documentation for docs/source/reward_trainer.mdx to import the correct Enum for the reward modelling LoRA config (#1092) 2023-12-15 11:24:20 +01:00
48b3ef0b7b [DPO] use ref model logprobs if it exists in the data (#885)
* use logprobs if it exists in the batch

* add features to tokenized batch if in data

* make get_batch_logps a static method

* add tokenize_batch_element dataset mapper

* Remove tokenize_batch method from DPODataCollator

* Initial sketch to precompute reference_logps

* run ref model via pytorch dataloader

* add a padding helper

* clean up the helper

* use logprob item()

* default behaviour

* clean up collator

* add docstring

* copy data back to cpu if needed

* use get_train_dataloader methods

* fix tests

* rename: more explicit variable name precompute_ref_log_probs

* improve comment

* update comment

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* refactor models into setup parameters

* parametrize precompute_ref_log_probs flag

* remove useless test

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update tests/test_dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update tests/test_dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* update function arg name

* distinguish between pad token_id and mask values

* fix tokenization #932 by @nrailg

* fix test

* undo test refactor

* new line

* undo breaking change

* Update token counter condition to allow Llama tokenizer

* Account for merged tokens on certain tokenizers such as the Llama-2 tokenizer

* Update variable name to match list value when truncating response

* map function on multi-gpu and gather

* Add test cases for DPOTrainer tokenization step

* revert since we need the prepared model

* Use gather_for_metrics on ref_logps precomputation to keep original dataset size

* Add flag to keep track of when ref_logps are precomputed

* make variable names private

* formatting

* if precompute_ref_log_probs is true one can use non-peft to populate log-probs

* Use tokenizer padding token unless padding_value is set

* Move dataset.map(tokenize_batch) outside dataloader to avoid serialization errors

* eval can be none

* move to cpu to avoid gpu oom

* remove unneeded cast to float32

* remove unneeded

* fix merge

* fix merge

* fix merge

* add precompute log-prob status via tqdm

* Truncate answer if too long once prompt has been truncated

* Add prompt_input_ids to batch to enable generation

* formatting and add lora example

* fix formatting

* Tokenize row now expects sample to have space on chosen/rejected for llama

* Revert "Tokenize row now expects sample to have space on chosen/rejected for llama"

This reverts commit dd07a10fe8c19b6ac6bbcc7b8144189756710d52.

* raise error when using zero-3 with precompute_ref_log_probs

---------

Co-authored-by: Pablo Vicente Juan <p.vicente.juan@gmail.com>
Co-authored-by: Shoaib Burq <saburq@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2023-12-12 17:16:46 +01:00
c0ce52ab26 consistency on log (#1084) 2023-12-12 10:58:21 +01:00
393dbf6749 Removing tyro in sft_llama2.py (#1081)
* refactor

* precommit
2023-12-11 11:28:20 -06:00
94fa4b022b Make CI happy (#1080)
* Update test_ppo_trainer.py

* Update test_ppo_trainer.py

* Update test_ppo_trainer.py
2023-12-11 16:52:17 +01:00
cb7819e627 add local folder support as input for rl_training. (#1078)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2023-12-11 16:37:01 +01:00
8f0fc4c8f7 Add args to SFT example (#1079) 2023-12-11 16:16:47 +01:00
d275cb431e [DPO] add KTO loss (#1075)
* add KTO loss

* fix docs

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* formatting

* add link to papers

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2023-12-11 11:41:03 +01:00
7d0a8eea4e Add missing loss_type in ValueError message (#1067) 2023-12-07 08:40:53 +01:00
5a233546ee enable multiple eval datasets (#1052)
* enable multiple eval datasets

* added test

* try to avoid infinite computation

* make sure eval set is not infinite

* downsizing the test
2023-12-06 20:26:24 +01:00
9fb00cf007 [SFTTrainer] Fix Trainer when args is None (#1064)
* fix sfttrainer when args is None

* oops
2023-12-06 19:02:09 +01:00
ee44946814 [core] Fix failing tests on main (#1065)
* fix tests on main

* fix last test
2023-12-06 18:31:02 +01:00
7f2401bd6e update doc for the compute_metrics argument of SFTTrainer (#1062) 2023-12-06 17:46:36 +01:00
23bf9d4b58 Improve PreTrainedModelWrapper._get_current_device (#1048)
* use LOCAL_RANK in _get_current_device

* use PartialState in _get_current_device

* update annotation
2023-12-05 17:47:40 +01:00
501c347083 Update doc CI (#1060) 2023-12-05 13:31:01 +01:00
f06f357e9c [SFT Trainer] precompute packed iterable into a dataset (#979)
* precompute packed iterable into a dataset

* add generator function

* fix typo

* fix style

* fix test

* fix style

* add test

* minor refactor

* fix test

* Apply suggestions from code review

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* style

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
2023-12-04 13:13:18 +01:00
4cdc03ab5c Fixing accelerator version function call. (#1056)
Co-authored-by: Partha Ghosh <pghosh@brown.is.localnet>
2023-12-04 12:39:58 +01:00
a60ceefa69 Update dpo_trainer.py (#1049) 2023-12-01 17:03:09 +01:00
baa8f09cb3 Revert "[DPO] Refactor eval logging of dpo trainer (#954)" (#1047)
This reverts commit 6d9ea38ae18c7e266f797b62de4a68a12a13aba4.
2023-12-01 10:33:31 +01:00
c859f5fa5f remove spurious optimize_cuda_cache deprecation warning on init (#1045)
Signed-off-by: Chander Govindarajan <mail@chandergovind.org>
2023-12-01 10:26:42 +01:00
481ef96293 Fixes reward and text gathering in distributed training (#850)
* adds a tensor gather on rewards

* adds dist gather on texts

* style

* adds a tensor gather on rewards

* adds dist gather on texts

* style

* simplifies gathering of rewards

* style

* simplify logic

* precommit

* Update trl/trainer/ppo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* quick change

* push changes

---------

Co-authored-by: Costa Huang <costa.huang@outlook.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2023-11-30 10:32:09 -05:00
6d9ea38ae1 [DPO] Refactor eval logging of dpo trainer (#954)
* first attempts at refactor of dpo trainer

* removed extra stuff in prediction step

* import fixes

* label names

* all working

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2023-11-30 12:09:33 +01:00
c203e47fbf spelling is hard (#1043) 2023-11-30 12:09:13 +01:00
c84e5918a6 [DPO] cDPO loss (#1035)
* add cDPO loss

* add comment

* docs

* info about label_smoothing not being used
2023-11-30 11:50:30 +01:00
4b67af37b6 Update utils.py (#1012)
* Update utils.py

update compute_accuracy to handle cases where the chosen and rejected completions get the same score, which is probably not what developers want

* Update utils.py

updated so that only a warning is issued

* Update trl/trainer/utils.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2023-11-29 16:02:50 +01:00
55d7c952c7 [DPO] IPO Training loss (#1022)
* initial IPO loss

* fix loss

* fixed comments

* added docs

* fix doc-strings

* add tests

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* fixes for review

* Added doc about beta in the Trainer's docstring

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2023-11-24 15:52:40 +01:00
3719f7a929 Add missing elements to sft_trainer document (#1029) 2023-11-23 12:34:27 +01:00
e7961e45f1 Remove duplicate data loading in rl_training.py (#1020)
We load the dataset twice, but in line 149 (new) we do `ds = train_dataset.map` anyway
2023-11-23 12:25:07 +01:00
b307faf07b [Multi-Adapter PPO] Fix and Refactor reward model adapter (#982)
* reward adapter loaded as part of init

more flexible, clearer args

* fixed script for multi gpu

unwrap model since it is DDP
downside, with reward adapter it seems we need to use
find_unused_parameters=True

* remove gradient from reward score calculation

* change supported_args back to None
2023-11-21 14:48:18 +01:00
aea1da8e2b Adds requires_grad to input for non-quantized peft models (#1006)
* Update sft_trainer.py

* style

* add tests
2023-11-20 15:57:46 +01:00
e5eb4db8b5 Update how_to_train.md (#1003)
* Update how_to_train.md

fix description about `min_new_tokens`

* Update docs/source/how_to_train.md

Co-authored-by: Costa Huang <costa.huang@outlook.com>

---------

Co-authored-by: Costa Huang <costa.huang@outlook.com>
2023-11-20 10:33:34 +01:00
28bdb6a373 Fixed wrong trigger for warning (#971)
func.__code__.co_varnames was used to count the function arguments for formatting_func. This code actually counted the function variables rather than function parameters.
2023-11-15 14:36:54 +01:00
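A quick illustration of the distinction behind that fix:

```python
def formatting_func(example):
    text = example["text"]        # a local variable, not a parameter
    return text

print(formatting_func.__code__.co_varnames)   # ('example', 'text') -> includes locals
print(formatting_func.__code__.co_argcount)   # 1 -> the actual number of parameters
```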
e140d22881 make distributed true for multiple process (#997)
* make distributed true for multiple process

* Update trl/trainer/ppo_trainer.py

distributed should have more than 1 process

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2023-11-15 11:20:25 +01:00
e23a541af9 add docs (#992) 2023-11-14 19:31:10 +01:00
be3faa768e [DataCollatorForCompletionOnlyLM] warn if eos_token_id and pad_token_id are identical (#988)
Display a warning message if the `eos_token_id` and `pad_token_id` values are the same in order to prevent unintended behavior during multi-turn training.
2023-11-14 19:24:56 +01:00
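A minimal sketch of the kind of check being described (the warning text here is illustrative, not the exact message added):

```python
import warnings

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token   # a common setup that triggers this situation

if tokenizer.pad_token_id == tokenizer.eos_token_id:
    warnings.warn(
        "pad_token_id equals eos_token_id; padding and EOS labels may be handled "
        "identically, which can cause unintended behavior in multi-turn training."
    )
```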
13679aa97e Update README.md (#994) 2023-11-14 18:29:08 +01:00
9e9f024399 Fix a bunch of outdated references to examples/ (#977) 2023-11-10 11:29:21 +01:00
c2884b5096 [Tests] Add non optional packages tests (#974)
* add non-peft tests

* change name

* test

* change

* fix test
2023-11-09 15:01:46 +01:00
2f726ce4e8 set dev version (#970) 2023-11-08 11:54:01 +01:00
a78a05d7b7 Release: v0.7.4 2023-11-08 10:30:29 +00:00
1b258247cd Pin bnb to <=0.41.1 (#968)
* pin bnb to 0.41.1

* Update setup.py

* Update setup.py
2023-11-08 11:28:17 +01:00
9c93dec05e fix peft config typehint (#967) 2023-11-08 11:11:39 +01:00
d1dad6ebda set dev version (#966) 2023-11-08 11:00:24 +01:00
8ce810250e Release: v0.7.3 (#965) 2023-11-08 10:52:47 +01:00
8e9cae8072 fix: dpo trainer ds config (#957)
* fix: dpo trainer ds config

ref_model and model shouldn't share the same DeepSpeed config, so we shouldn't modify the config directly; otherwise, something will go wrong when initializing the DeepSpeed engine

* fix: import sort

import sort by isort
2023-11-06 14:37:04 +01:00
654543a8cf Added support for custom EncoderDecoder models (#911) 2023-11-06 09:52:10 +01:00
c273b18c1c Adds model kwargs to SFT and DPO trainers (#951)
* adds model kwargs to SFT and DPO trainers

* adds checks for model_kwarg passing when model is not str

* changed warning to ValueError

* renames model_kwargs to model_init_kwargs

* corrects argument names in
2023-11-06 09:48:18 +01:00
6c6ff24926 [DPO] Merge initial peft model if trainer has a peft_config (#956)
* failing test
Co-authored-by: Shoaib Burq <saburq@gmail.com>

* merge initial peft model
2023-11-06 09:45:46 +01:00
6ff0fac2c1 Fix unwrapping peft models (#948)
* First unwrap the model and then process the input embeddings

* Changed base_model to base_model.model to stay consistent with peft model abstractions
2023-11-05 08:31:47 +01:00
951ca1841f [CI] Fix CI with new transformers release (#946)
* fix CI with transformers release

* final fix
2023-11-03 10:38:58 +01:00
cc1de9820a Introducing the Iterative Trainer (#737)
* initial skeleton

* iterative trainer for decoder only

* iterative trainer unittest

* encoder_decoder support

* fix typo in unittest

* init

* fix typo

* fix init typo

* adding loggings and safety checker

* fixed minor issues

* doc

* table of contents update

* add test for seq2seq models

* change year

* adding text as step input

* precommit

* fixing typo

* run precommit

* fixing typo in safety checker

* fix text tokenization issue

* add truncate and inherit from trainer

* remove iterative config from tests

* remove iterative config from init

* fix peft model

* change truncation side based on truncation_mode

* removed iterativeconfig autodoc

* fixed typo in trainer.mdx

* remove mention of iterative config in docs

* make sure optimizer and scheduler are created

* adding max_steps to test

* remove log_stats fn

* remove compute loss

* fixing encoder decoder detection

* fix PPODecorator

* run precommit

* fix testing

* fix small typos in iterative trainer

* adapted function log and eval
2023-11-02 17:37:48 +01:00
a64a522fcc Update dpo_trainer.py (#941) 2023-11-02 11:27:49 +01:00
5b32372b71 Optionally logging reference response (#847)
* Optionally logging reference response

* log ref rewards as well

* peft logic re-write

* fix peft test case

* refactor

* push changes

* test

* Apply suggestions from code review

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* quick fix

* black

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2023-10-31 17:55:09 -04:00
d759004e52 Fix stale bot (#935)
* Update stale.py

* Update stale.py

* fix
2023-10-31 20:10:38 +01:00
cbc6c9bb3e [core / DDP] Fix RM trainer + DDP + quantization + propagate gradient_checkpointing_kwargs in SFT & DPO (#912)
* make use of forward hooks

* correctly delete attributes

* fix RM DPP issues

* revert unneeded changes

* more fixes

* fix diff

* fix

* propagate to SFT

* Update examples/scripts/reward_modeling.py

* propagate the fix on DPO trainer

* add to example scripts

* trigger CI
2023-10-31 18:50:17 +01:00
f3cd86578b Update dpo_llama2.py (#934) 2023-10-31 18:20:53 +01:00
b763432eaf [SFTTrainer] Make sure to not conflict between transformers and TRL implementation (#933)
* standardize neftune

* up

* fix again
2023-10-31 16:04:09 +01:00
2bbd594ec5 hotfix for dpo trainer (#919)
addresses #914
2023-10-31 10:58:41 +01:00
b89b712dbf fix DPO + GC issues (#927) 2023-10-31 10:55:46 +01:00
ec9e76623e [Feature] Enable Intel XPU support (#839)
* enable xpu support

* fix bug

* review commits

* fix style

* add xpu decorator

* refactor review commit

* fix test

* review commit

* fix test

* Update benchmark.yml (#856)

* Standardise example scripts (#842)

* Standardise example scripts

* fix plotting script

* Rename run_xxx to xxx

* Fix doc

---------

Co-authored-by: Costa Huang <costa.huang@outlook.com>

* Fix version check in import_utils.py (#853)

* dont use get_peft_model if model is already peft (#857)

* merge conflict

* add xpu decorator

* resolve

* resolves

* upstream

* refactor and precommit

* fix new tests

* add device mapping for xpu

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Costa Huang <costa.huang@outlook.com>
Co-authored-by: Adam Pauls <adpauls@gmail.com>
Co-authored-by: abhishek thakur <1183441+abhishekkrthakur@users.noreply.github.com>
2023-10-31 10:15:35 +01:00
d192244f54 Bump tyro (#928) 2023-10-30 20:48:34 -04:00
051d5a1f61 updating PPOTrainer docstring (#897)
* adding the specific dict structure to the tracker_kwargs docstring to make it easy to change tracker params, like the wandb experiment name, without needing to go deep into the accelerate source

* push changes

* set default dict

* refactor

* use typing extension

---------

Co-authored-by: Laura O'Mahony <lauraomahony@L-MacBook-Pro.fritz.box>
Co-authored-by: Costa Huang <costa.huang@outlook.com>
2023-10-30 13:22:53 -04:00
2068fdcd93 Generalize NEFTune for FSDP, DDP, ... (#924)
* Update sft_trainer.py

* quality
2023-10-30 11:17:14 +01:00
02f5c1d8ce fix stackllama2 sft gradient checkpointing (#906)
* fix stackllama2 sft gradient checkpointing

* stackllama2 sft use tyro as arg parser
2023-10-25 09:58:26 -04:00
7de7db6765 deactivate MacOS CI (#913) 2023-10-24 16:06:12 +02:00
4e7d5b5abe [Update reward_trainer.py] append PeftSavingCallback if callbacks is not None (#910) 2023-10-24 14:32:45 +02:00
a90e13321b Fix broken link/markdown (#903)
* Fix broken link/markdown

* attempt to fix mps issue

* attempt fix mps issue

* test

---------

Co-authored-by: Costa Huang <costa.huang@outlook.com>
2023-10-24 14:27:03 +02:00
5b2aeca6c0 [NEFTune] Make use of forward hooks instead (#889)
* make use of forward hooks

* correctly delete attributes

* address suggestions
2023-10-24 14:18:44 +02:00
1f3314fd2f Add whiten ops before computing advantages (#887)
* Add whiten ops before computing advantages

1. From LLaMA 2 paper, it says:
```
We also find it important to whiten the final linear scores (shown here by reversing the sigmoid with the logit function) in order to increase stability and balance properly with the KL penalty term (β) above.
```
2. This function is taken from [alpaca_farm](64e489c67e/src/alpaca_farm/rl/ppo_trainer.py (L86))

* Fix type def of self

---------

Co-authored-by: Lin Junpeng <linjunpeng@sensetime.com>
2023-10-23 11:32:45 -04:00
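For reference, a sketch of a whitening op in the spirit of the one described (not necessarily the exact code that was merged):

```python
import torch

def whiten(scores: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
    """Scale scores to zero mean and unit variance; optionally add the mean back."""
    mean, var = scores.mean(), scores.var(unbiased=False)
    whitened = (scores - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        whitened = whitened + mean
    return whitened
```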
304ee70eef Fix couple broken links on lib homepage (#908) 2023-10-23 11:46:37 +02:00
0a5aee7d99 [reward_modeling] Cleaning example script (#882)
* remove load in repeated multiple times & truncation

* trigger CI
2023-10-19 16:00:20 +02:00
db592a2eb6 fix: remove useless token (#896) 2023-10-19 14:28:33 +02:00
122edc8f5d fix peft_config type (#883)
Co-authored-by: wanglei.w <wanglei.w@bytedance.com>
2023-10-18 23:45:38 +02:00
f91fb2bda2 remove duplicate key in reward_modeling.py (#890) 2023-10-18 23:45:18 +02:00
01e4ad0009 fix syntax error 2023-10-17 21:22:53 +02:00
1e56ff0f16 Fix security breach 2023-10-17 08:01:24 +02:00
c4ed3274be [SFTTrainer] Adds NEFTune into SFTTrainer (#871)
* v1 neftune

* docstring

* add doc + fix nit

* add more docs

* Apply suggestions from code review

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2023-10-17 06:58:05 +02:00
14b6bc6691 [DPO] add SLiC hinge loss to DPOTrainer (#866)
* add SLiC hinge loss

* fix links

* beta when loss is hinge is reciprocal of margin

* fix tests

* fix docs

* doc strings

* fix method name

* raise error if loss_type is not correct

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* fix formatting

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2023-10-16 16:02:57 +02:00
eb4d2f381a set dev version (#864) 2023-10-12 15:51:54 +02:00
78e08bd658 Release: 0.7.2 (#863) 2023-10-12 15:29:10 +02:00
96d4854455 Support both old and new diffusers import path (#843)
* Update modeling_sd_base.py

* Update trl/models/modeling_sd_base.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* make precommit

* cleaner approach

* oops

* better alternative

* rm uneeded file

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
2023-10-12 15:06:09 +02:00
3ef21a24e7 [core] Fix import issues (#859)
* fix import issues

* cleaner approach
2023-10-11 19:04:49 +02:00
f7707fd4c6 dont use get_peft_model if model is already peft (#857) 2023-10-11 18:58:56 +02:00
dd9b8f4189 Fix version check in import_utils.py (#853) 2023-10-11 18:55:43 +02:00
ddd318865b Standardise example scripts (#842)
* Standardise example scripts

* fix plotting script

* Rename run_xxx to xxx

* Fix doc

---------

Co-authored-by: Costa Huang <costa.huang@outlook.com>
2023-10-11 17:28:15 +02:00
8aa12d3c95 Update benchmark.yml (#856) 2023-10-11 11:06:48 -04:00
95aea7c072 Use uniform config (#817)
* Use uniform config

* quick fix

* refactor

* update docs
2023-10-09 09:15:06 -04:00
eda1f36c57 Raise error in create_reference_model() when ZeRO-3 is enabled (#840)
* Raise error when using  with ZeRO-3

* Fix

* Refactor

* Revert

* Restore remote code

* Revert example
2023-10-09 10:49:01 +02:00
ac0d5b726d add DDPO to index (#826)
* add DDPO to index

* Update index.mdx
2023-10-06 14:42:56 +02:00
6826d592ae Clarify docstrings, help messages, assert messages in merge_peft_adapter.py (#838)
An assertion was also corrected to the intended test condition
2023-10-06 11:04:58 +02:00
c058ee6f05 [MINOR:TYPOS] Update README.md (#829) 2023-10-05 14:33:20 +02:00
fbeb146eea Set trust remote code to false by default (#833) 2023-10-04 22:53:57 +02:00
98845b9282 Fix DeepSpeed ZeRO-{1,2} for DPOTrainer (#825) 2023-10-03 09:56:00 +02:00
9f6326e65a Unify sentiment documentation (#803)
* Update documentation

* update docs

* test

* format

* Update docs/source/example_overview.md

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* update

* add quantization dependency and update docs

* Update docs/source/example_overview.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/example_overview.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/example_overview.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/example_overview.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/sentiment_tuning.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/sentiment_tuning.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/sentiment_tuning.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/sentiment_tuning.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/sentiment_tuning.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/sentiment_tuning.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* update

* quick update 2

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2023-10-02 10:35:49 -04:00
7dcc71b1a6 Small fixes to the PPO trainer doc. (#811)
One outstanding issue is that ppo_trainer.save_model doesn't exist.
How do we actually save the model after training?
2023-10-02 11:01:05 +02:00
6b73adc900 add option for compute_metrics in DPOTrainer (#822) 2023-09-29 12:33:47 +02:00
249d3e3259 Add RMSProp back to DPO (#821)
* init

* add install instructions
2023-09-26 10:44:44 -07:00
ad8d50e30d init custom eval loop for further DPO evals (#766)
* init

* run

* Update custom eval loop to aid DPO debugging (#770)

* sample_during_eval -> generate_during_eval

* Remove unused return_tokens

* Add import utils for W&B, prevent test fails

* Optimize dataloader random batch selection

* Separate prompt and response in logs

Makes it much easier to quickly read the starts of the generations

* Simplify logging

* reset eval steps

* manual merge fixes

* revert merge

* remove self.max_length

* style

* fix max_length

---------

Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
2023-09-26 08:09:15 -07:00
d608fea0d1 Allow passing the token_ids as instruction_template in DataCollatorForCompletionOnlyLM (#749)
* Update utils.py

* correctly assign instruction_template in DataCollatorForCompletionOnlyLM

* correctly use instruction_token_ids in DataCollatorForCompletionOnlyLM

* DataCollatorForCompletionOnlyLM: fix instruction_template / response_template type check: handle cases where instruction_template is None

* make precommit

* Test DataCollatorForCompletionOnlyLM with pre-tokenized instruction_template
2023-09-26 11:38:30 +02:00
92b03f5fdc fixes ppo trainer generate nit (#798) 2023-09-26 10:19:29 +02:00
7877e92991 Update sft_trainer.mdx (#808) 2023-09-22 17:55:54 +02:00
1d7e3c2ae2 Update sft_trainer.mdx to highlight Flash Attention features (#807)
* Update sft_trainer.mdx

* Update sft_trainer.mdx
2023-09-22 17:42:21 +02:00
eb6aa20401 clarify PEFT docs (#797) 2023-09-21 11:22:20 +02:00
b8f0c4cf12 Add deepspeed experiment (#795)
* Add deepspeed experiment

* add deepspeed pip install

* update hello world.sh

* update comments

* remove cleanup
2023-09-20 09:32:42 -04:00
e11a45c5d8 Revert "Add default Optim to DPO example (#759)" (#799)
This reverts commit d603e7c52704054a9e7f306ae63acdafaa3d179a.
2023-09-20 10:32:55 +02:00
08cfc4179b Add margin to RM training (#719)
* Start adding margin to RM training

* Fix typo and cleanup

* Fix incompatibilities when not using margin

* Format using 'make precommit'

* Add documentation and test for reward trainer

* Run 'make precommit'

* Update docs/source/reward_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Fix missed merge conflict in reward trainer docs

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2023-09-20 10:18:38 +02:00
d603e7c527 Add default Optim to DPO example (#759)
* add optim

* make configurable
2023-09-19 07:56:52 -07:00
5d30cd4d30 Changed the default value of the log_with argument (#792)
This change avoids setting report_to="all" (the default behavior in
transformers v4), which could lead to unexpected error messages for
inexperienced users. Note that the default value of report_to will
change anyway to "none" in transformers v5.
2023-09-19 13:04:17 +02:00
46975236be Temp benchmark ci dir (#765)
* Support fork in benchmark CI

* use temporary dir for benchmark CI

* debug

* revert back

* dependency fix

* refactor script
2023-09-18 11:16:16 -04:00
9a8d52cc5a Fix type checking (#748) 2023-09-18 13:54:41 +02:00
0a6c42c12c Update benchmark.yml (#782) 2023-09-15 13:45:21 -04:00
221be13d26 Update benchmark.yml (#781) 2023-09-15 11:34:09 -04:00
a922af6927 Update benchmark.yml (#780) 2023-09-15 11:28:16 -04:00
42e7a0a824 Update benchmark.yml (#779) 2023-09-15 11:18:55 -04:00
15d52e759b Update benchmark.yml (#778) 2023-09-15 11:02:10 -04:00
24e914a0ab Update benchmark.yml (#777) 2023-09-15 10:57:08 -04:00
637612d95f Benchmark CI fix (#776) 2023-09-15 10:33:45 -04:00
35694baef2 Benchmark CI fix (#775) 2023-09-15 08:52:24 -04:00
d2f27df50a Update benchmark.yml (#773)
* Update benchmark.yml

* quick change
2023-09-15 09:40:20 +02:00
5cee9a0478 Support fork in benchmark CI (#764) 2023-09-14 08:44:36 -04:00
3f7710aed7 docs: add initial version of docs for PPOTrainer (#665)
* docs: add initial version of docs for  `PPOTrainer`

* Apply suggestions from code review Leandro

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* updated docs based on feedback leandro
- specified reference to reward model
- added batched generator
- added line of saving model
- remove reference model

* Apply suggestions from code review

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2023-09-14 10:34:19 +02:00
ca0af3944d Benchmark CI (actual) (#754)
* refactor and benchmark

* update code

* Add accelerate logging

* logs

* quick fix

* update config

* precommit

* modify training example

* fix multi-gpu all_reduce error `Tensors must be CUDA and dense`

* support more models and benchmark

* update

* add changes

* upload benchmark

* precommit

* add tyro as a dependency

* add tyro

* pre-commit

* precommit

* weird...

* lol typo

* precommit

* sigh

* push changes

* Update benchmark/README.md

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Add experiments

* upload image to tag specific folder

* add openrlbenchmark documentation

* rename

* remove unused field

* precommit

* update slurm template

* add dependency

* update dependency

* ..

* .

* quick change

* push changes

* update

* update

* remove wandb tag code

* quick change

* precommit

* update test

* update dependency

* update test

* update benchmark dependency

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2023-09-13 13:34:00 -04:00
e4f9a483d9 Refactor and benchmark (#662)
* refactor and benchmark

* update code

* Add accelerate logging

* logs

* quick fix

* update config

* precommit

* modify training example

* fix multi-gpu all_reduce error `Tensors must be CUDA and dense`

* support more models and benchmark

* update

* add changes

* upload benchmark

* precommit

* add tyro as a dependency

* add tyro

* pre-commit

* precommit

* weird...

* lol typo

* precommit

* sigh

* push changes

* Update benchmark/README.md

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Add experiments

* upload image to tag specific folder

* add openrlbenchmark documentation

* rename

* remove unused field

* precommit

* push changes

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2023-09-13 10:24:18 -04:00
80890b17be [PPOTrainer] - add comment of zero masking (from second query token) (#763)
It took a while to understand why the number of zero-masked tokens is one less than the length of the query tokens.

If I understood correctly, it is because the first logit (and state value) from the outputs refers to the second token in the query.

Hope this comment is helpful to others who run into the same question on a first reading of the code :)
2023-09-13 10:23:04 +02:00
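A toy illustration of that off-by-one, assuming the usual next-token alignment where the logit at position t scores token t+1:

```python
import torch

query_len, response_len = 4, 3
seq_len = query_len + response_len
logits = torch.randn(1, seq_len, 100)    # one logit per input position (toy vocab of 100)

# per-token log-probs/values align with tokens 1..seq_len-1, so the mask has seq_len-1 slots
mask = torch.zeros(1, seq_len - 1)
mask[:, query_len - 1:] = 1.0            # keep only positions that predict response tokens
print(int((mask == 0).sum()))            # 3 == query_len - 1 zero-masked positions
```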
cf9d2a7133 Improve benchmark CI (#760) 2023-09-13 09:29:06 +02:00
c02ce6d3f5 Extend DeepSpeed integration to ZeRO-{1,2,3} (#758)
* Generalise deepspeed

* Refactor

* Add reward model arg

* Fix pipeline tokenizer

* Fix deprecation

* Pin deepspeed lower

* Fix docs

* Revert top_k change

* Add ZeRO-3 context manager

* Revert docs change

* Fix docs

* Polish docs

* Update docs/source/customization.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2023-09-12 18:59:49 +02:00
9141aa42ba EOS token processing for multi-turn DPO (#741)
* init

* fix

* add doc

* style

* clarify example
2023-09-12 09:49:51 -07:00
05723c0b88 benchmark CI fix (#755) 2023-09-12 09:04:57 -04:00
b87ec2d5a0 update to prepare_model_for_kbit_training (#728)
* update to `prepare_model_for_kbit_training`

from the deprecated `prepare_model_for_int8_training`,
and add `use_gradient_checkpointing=args.gradient_checkpointing` to
automatically follow the gradient checkpointing choice

This is also the workaround for #694

* workaround for gradient checkpointing issue

calling model.gradient_checkpointing_enable() twice causes issues;
this workaround calls it in prepare_model_for_kbit_training and then
changes the arg to false to make sure it isn't called again in the
Hugging Face Trainer's inner loop

also changes the stack_llama_2 SFT trainer to use the correct device map for DDP
training so that you can test this issue
2023-09-12 10:56:10 +02:00
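A sketch of the pattern described above; `model`, `script_args` and `training_args` are assumed to come from the surrounding script:

```python
from peft import prepare_model_for_kbit_training

# prepares the quantized model for training and (optionally) enables gradient checkpointing
model = prepare_model_for_kbit_training(
    model, use_gradient_checkpointing=script_args.gradient_checkpointing
)

# workaround: checkpointing is already enabled above, so make sure the Hugging Face
# Trainer does not call model.gradient_checkpointing_enable() a second time
training_args.gradient_checkpointing = False
```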
27df071ad8 add benchmark ci (#752) 2023-09-11 13:35:53 -04:00
67452ef213 fix import of torch_utils (#751) 2023-09-11 18:46:19 +02:00
22a90198e5 [DPO] self.accelerator._prepare_deepspeed return tuples (#745) 2023-09-08 11:50:06 +02:00
4f81e7736d Seq2Seq model support for DPO (#586)
* dpo_collator for seq2seq models

* dpo trainer support

* refactoring

* update collator

* computes decoder input ids if possible

* decoder input ids for dpo trainer

* added test for seq2seq

* quality

* fixed typo

* fixed string padding for seq2seq

* fixed minor issues in padding

* fixed typo in dpo.py

* add docstring

* run all precommit

* fixed gradient accumulation steps in test

* reformatting

* fixing dpo tests

* update .mdx
2023-09-07 18:03:10 +02:00
14292b08af fixed metrics typo (#743) 2023-09-07 18:02:20 +02:00
453c4eca14 Enable gradient checkpointing to be disabled for reward modelling (#725)
* Enable gradient checkpointing to be disabled for reward modelling

* Update examples/scripts/reward_trainer.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Tidy docs

* Remove commas

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2023-09-06 14:08:15 +02:00
decc832d3e Add epsilon to score normalization (#727) 2023-09-06 10:28:07 +02:00
1111295776 check correctly for condition (#668) 2023-09-06 10:24:55 +02:00
c04074e248 Fix DeepSpeed ZeRO-3 in PPOTrainer (#730)
* Initialise ref model with ZeRO-3

* Fix deadlock

* Refactor & fix KL div

* Refactor

* Refactor

* Fix imports

* Add types

* Add accelerate configs

* Add more DeepSpeed configs

* Fix types

* Disable debug

* Refactor

* Add docs

* Disable eval mode for peft

* Restore eval mode

* Revert ref model prep for peft

* Update examples/scripts/README.md

Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>

* Add docs

---------

Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>
2023-09-05 11:00:49 +02:00
d484dc2a93 Refactor RewardTrainer hyperparameters into dedicated dataclass (#726)
* Refactor RewardTrainer hyperparameters into dedicated dataclass

* Revert

* Add doc string

* Fix warning

* Handle backwards compat

* Fix tests

* Add docs

* Refactor to RewardConfig

* Fix case conditions

* Fix
2023-09-05 09:05:42 +02:00
34e6948d45 [core] Bump peft to 0.4.0 (#720)
* bump peft to 0.4.0

* all of them
2023-09-01 15:01:36 +02:00
9f69f06a1c Add pyproject.toml (#690)
* example pyproject.toml

* update target to py38

* make pyproject.toml equivalent to accelerate
2023-09-01 11:42:18 +02:00
jp
5bb46687c5 Fix: RuntimeError: 'weight' must be 2-D issue (#687)
* Update dpo_trainer.py

* Fix: self.args.deepspeed > self.is_deepspeed_enabled

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2023-09-01 11:27:54 +02:00
25d6700c5e fix sft mistakes (#717) 2023-08-31 16:56:29 +02:00
4d31d0c4f8 Update docs on gms8k (#711) 2023-08-31 16:48:07 +02:00
0ff39d2a87 fix device issue (#681)
* fix device issue

* fix device issue

* fix device issue

* merge changes

* fix device issue
2023-08-31 16:37:42 +02:00
b4899b29d2 set dev version (#710) 2023-08-30 17:00:34 +02:00
6aae9e75f3 Release: VERSION (#709) 2023-08-30 12:48:10 +02:00
79b90e19ba a workaround for failing log_stats (#708) 2023-08-30 12:23:57 +02:00
7f636c9ed7 set dev version (#707) 2023-08-30 11:58:22 +02:00
98d8cc509d Release: v0.7.0 (#706) 2023-08-30 11:55:54 +02:00
9d09b3e107 TextEnvironments (#424)
* WIP skeleton

* minimal working poc

* cleanup

* rename variables

* quick typo fix

* add v1 masking (#429)

* add v1 masking

* working v1

* adapt from suggestion

* avoid warning `Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.`

* fix masking

- mask the responses from API call only

* quality

* address comments

* Update trl/environment/base.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* adapt a bit

* wip on tokenization/masking in textenv

* small fixes

* update viz

* add example

* print debug text and pass masks

* style

* format and move tensor to device

* update example

* update example

* This seems to work

* fix masking

* fix rich output to console

---------

Co-authored-by: Costa Huang <costa.huang@outlook.com>
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: leandro <leandro.vonwerra@spoud.io>

* Add masking (#461)

* add v1 masking

* working v1

* adapt from suggestion

* avoid warning `Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.`

* fix masking

- mask the responses from API call only

* quality

* address comments

* Update trl/environment/base.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* adapt a bit

* wip on tokenization/masking in textenv

* small fixes

* update viz

* add example

* print debug text and pass masks

* style

* format and move tensor to device

* update example

* update example

* This seems to work

* fix masking

* fix rich output to console

* fix batched generation

* improve stopping criteria

* improve error handling in tool call

---------

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Costa Huang <costa.huang@outlook.com>

* fix unknown tool

* fix rewards and increase bs

* remove unused script

* ugly WIP fix

* do not return modified obj for in-place operations

* do not return modified obj for in-place operations

* clean up stopping criterium

* push updates

* push update

* format, add docs

* rename file

* add kwargs to reward fn

* simplify example

* simplify example

* bug fix

* add a trivia example

* pre-commit

* max tool response length

* fix regex for multi-line

* refactor tool exceptions

* fix exceptions in tool

* add docs

* fix style

* make rich optional

* add docstrings

* add  tests

* add TextEnv tests (WIP)

* update triviaqa code

* update docs

* refactor text env

* update tests (WIP)

* add end2end test

* update docs

* upload tool demo

* refactor

* customizable system prompt

* add text env docs

* update index and toc

* fix `TextHistory` show methods

* add max length

* fix style

* fix typo

* refactor to kwargs in init and tasks to queries

* kwargs for reward docs

* Update examples/triviaqa.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update examples/tool_demo.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update docs/source/learning_tools.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update docs/source/learning_tools.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update docs/source/learning_tools.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update docs/source/text_environments.md

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update examples/triviaqa.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update examples/triviaqa.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* move to tool folder

* remove assets

* remove tool demo

* move rich import test to import utils

* add copyright

* fixes for masks in ppo trainer

* add text env api docs

* make precommit + add ppo test with mask

* move examples and add python

* fix style

* update triviaqa example

* add more docs

* update docs

* Update docs/source/learning_tools.mdx

* Apply suggestions from code review

* precommit

---------

Co-authored-by: Costa Huang <costa.huang@outlook.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: leandro von werra <leandro@hf.co>
2023-08-30 11:44:06 +02:00
336d63eb80 [Docs] fix example README.md (#705) 2023-08-30 11:27:50 +02:00
7fc970983c [DPO] fix DPO ref_model=None (#703)
* fix by @tannonk

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* add import

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2023-08-29 12:57:10 +02:00
d3bbee3ab8 set dev version (#685) 2023-08-24 11:04:07 +02:00
eb5465df7e Release: v0.6.0 (#684) 2023-08-24 10:18:46 +02:00
1c272240ac Simplify immutable TrainingArgs fix using dataclasses.replace (#682) 2023-08-24 09:50:48 +02:00
Wei
b095245830 fix PeftConfig loading from a remote repo. (#649)
* fix PeftConfig loading from a remote repo.

* failed to catch hf_hub_download() EntryNotFoundError.

At least in huggingface-hub 0.10.1, the error for "not found" is:
huggingface_hub.utils._errors.EntryNotFoundError: 404 Client Error

* pass precommit checks.

* replace some bare excepts with specific codes

* catch LocalEntryNotFoundError additionally.
2023-08-24 09:50:20 +02:00
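
A minimal sketch of catching the "not found" cases mentioned above, assuming a reasonably recent `huggingface_hub`; the repo and file names are just examples:

```python
# Illustrative only: probe the Hub for an adapter config and fall back gracefully.
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, LocalEntryNotFoundError

try:
    # "gpt2" is simply a repo that has no adapter_config.json, so this exercises
    # the not-found path; the real code probes the user's remote repo instead.
    config_path = hf_hub_download("gpt2", "adapter_config.json")
except (EntryNotFoundError, LocalEntryNotFoundError):
    config_path = None  # no remote entry and no cached copy: treat as "no PEFT adapter"
```
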
c115453fba Update sft_llama2.py (#678)
Add argument num_workers. Fixed an error on line 103 when streaming is set to False
2023-08-23 16:56:31 +02:00
16f214c58d fix immutable TrainingArguments issue (#676) 2023-08-23 10:54:59 +02:00
e9a437992e propagating eval_batch_size to TrainingArguments (#675)
Co-authored-by: Rahul Jha <rahuljha@netflix.com>
2023-08-23 10:52:25 +02:00
c837fbe5b9 Fix DPO blogpost thumbnail (#673) 2023-08-22 11:53:21 +02:00
01c4a35928 Denoising Diffusion Policy Optimization (#508)
* Broken first pre-draft

* Change structure to leverage user-definition of pipeline
 - reward function, pipeline and scheduler will be left to the user to define
 - the pipeline and scheduler contract interfaces are what the framework will define
 - none of this actually works

* Incremental progress: trying to get the set-up running e2e

* Incremental progress: successfully running code

* Incremental progress: running setup
Next steps: fix accelerate gradient acc assertion error when we set value > 1

* Formatting and code standards

* Incremental prog: break down code a bit
- new config flag to notify code of async reward fetching
- break off image handling code and leave it to the user to define how to handle it
- more code restructuring

* Incremental progress:
1. More code sectioning off into own methods (more for readability than anything else)

* Incremental progress:
1. clear up contracts
2. type the reward function and prompt function

* Code shuffling, and expansion of tracker/accelerator config args beyond wandb

* More small additions
Add tensorboard logging function
Remove wandb logging function for now
Consolidate the data that gets thrown to the logging function
Add README

* Formatting

* Formatting

* Remove print statement
Make tensorboard tracking the sole tracking for the training example

* 1. start of testing
2. more refactoring
3. start of docstrings
4. parameter rename

* Basic Tests
Formatting

* Docs according to the norm

* Docs, credits and rename file

* docs and corrections

* Put example config to respectable state

* Add recent run params

* Correct the name of the library

* Move requirements to EXTRAS

* - Add license banners
- Guard import of DDPO functions with if_diffusers_available
- doc strings for output types

* Add snippet to pull weights from huggingface + banner

* Test if passes on CI/CD

* Minor refactor

* Test dummy unet

* Possible fix for randomly disappearing attribute

* Shuffling arrangement in hopes of meeting memory requirements

* Proper Names

* Appease windows memory allocator issues for the cpu device

* Remove print statements

* Update docs/source/ddpo_trainer.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update docs/source/ddpo_trainer.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Add docstrings and correct url

* Spelling and grammar

* Add more documentation and commandline parsing for example script

* Markdown syntax correction

* Revert accidentally committed file and put the correct one

* More docs

* Remove subclassing and add docs for leftover subclassing

* Put back subclassing

* Reward metadata and more docs

* Remove save_load_save flag

* Grammar

* Update trl/trainer/ddpo_trainer.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Update tests/test_ddpo_trainer.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Update setup.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Update examples/scripts/stable_diffusion_tuning.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Edits to the readme for DDPO

* Renamed modelling_sd_base to modeling_sd_base

* Insert try and catch for bitsandbytes import

* Change to smaller model

* Correct tolerance for floating point comparison

* Remove dummy unet and move to an isfinite check

* 1. Expand interface to ensure other Stable Diffusion pipelines could be covered
2. remove extra identification

* 1. Remove most of the asserts except for one and add value error
2. Remove default run name

* Remove progress bar

* Docs

* Put back progress bar

* 1. Revert progress bar deletion completely
2. grammar
3. relocate line

* Experiment

* Remove experiment parts and format properly

* Change formatting and edit info in docs

* Grammar

* Refactor out most of the nitty-gritty of loading/saving from the trainer to the example model
Readme addition

* Docs additions

* 1. Proper formatting for the test file
2. incorporation of pull-from-hub; if it fails, try local
3. docstrings for the interface
4. highlight in the trainer that this is only ready for SD pipelines

* Resources for before and after

* Attempt at embedding images

* Post testing example script

* Consistent naming and document edits in light of new args

* Remove resources and add CDN links in html in doc file

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2023-08-21 19:24:52 +02:00
1aca98fbcf add check of arguments (#660) 2023-08-21 12:02:07 +02:00
029f961b7c Handle potentially long sequences with DataCollatorForCompletionOnlyLM (#644)
* avoid RuntimeError on long sequences

* add unittests and format

* remove dependency on external repo

* bug fix in DataCollatorForCompletionOnlyLM
2023-08-18 10:30:25 +02:00
8ec912ffa6 Add more args to SFT example (#642)
* add more args

* fix style issues
2023-08-18 10:15:43 +02:00
f360c37466 Allow for ref_model=None in DPOTrainer (#640)
* Update dpo_trainer.py

Make ref_model optional.

* add tests for ref_model=None

* better handling for ref_model=None

* Update dpo_trainer.py

Correct docstring

* move instantiation of self.ref_model closer to model

* use .disable_adapters instead of .get_base_model

* handle ref_model=None in get_batch_samples

* fix failing test in dpo_trainer due to disable_dropout_in_model

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2023-08-18 10:02:16 +02:00
217313014b Update README.md (#657)
* Update README.md

fix reward modeling example

* Update README.md

more concise fix
2023-08-17 22:00:58 +02:00
b946e875b1 Resolve various typos throughout the docs (#654)
* Resolve various typos throughout the docs

I found the first few manually, and then found the rest via codespell

* HuggingFace -> Hugging Face

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2023-08-17 12:27:54 +02:00
6dd50b45d8 Add checks on backward batch size (#651)
* Add checks on backward batch size

* add test case

* update test case

* Update citation
2023-08-17 10:35:44 +02:00
98120d6aeb Disable dropout in DPO Training (#639)
* disable dropout in dpo

* quick fix docs

* precommit

* add disable_dropout_in_model to DPOTrainer

* disable_dropout -> disable_dropout_in_model

* .

* .
2023-08-14 14:40:45 +02:00
3b2c820db6 Add score scaling/normalization/clipping (#560)
* Add reward/score scaling/normalization/clipping

* Run pre-commit to fix styles and remove some dupe code

* Make sure score module and pretrained_model have the same dtype

* Add multi_adapter_rl_v2.py

* Add log_with

* Add more verbose help message for use_score_norm

* Fix score clipping for float16

* Minor fix
2023-08-10 10:30:56 +02:00
25fd6f2313 Move repo (#628)
* update actions

* update references
2023-08-09 17:48:25 +02:00
3f1477cdc0 Improve docs (#612)
* WIP

* improve inference docs

* improve training faq

* update toctree

* fix toctree

* fix improve blog

* improve blog

* fix customization

* reword faq a bit

* reword inference a bit

* add references back

* integrate feedback from code review

* fix link in html
2023-08-08 11:45:16 +02:00
2cff1e4385 Allow already tokenized sequences for response_template in DataCollatorForCompletionOnlyLM (#622)
* Allow tokenized ids in DataCollatorForCompletionOnlyLM. Add test and docs

* Formatting

* Documentation

* Remove unused code from test

---------

Co-authored-by: Ivan Sanchez <ivan.sanchez@zyte.com>
2023-08-08 11:33:12 +02:00
d7d7902938 use log_with argument (#620) 2023-08-08 10:13:22 +02:00
77b0cc1707 [DPO] stack-llama-2 training scripts (#611)
* initial stack-llama-2 scripts

* removed unused function

* add accelerate

* link to stack-llama-2 code

* running the model

* pre-commit fixes

* use the merge_peft script

* Add section on logged metrics
2023-08-07 14:36:16 +02:00
17f22c1c20 Add docs explaining logged metrics (#616) 2023-08-04 12:50:39 -04:00
e448bb69f0 [Modeling] Add token support for hf_hub_download (#604)
* add token support for hf_hub_download

* allow to pass it to from_pretrained
2023-08-03 12:49:31 +02:00
9aa4e3ce2b set dev version (#608) 2023-08-02 10:43:27 +02:00
ca8a508913 Release: 0.5.0 (#607) 2023-08-02 10:31:43 +02:00
a00ab445ba refactor grad accum (#546)
* refactor grad accum

* quick fix

* use correct place to step optim

* push changes

* cleanup and fix division by zero in `masked_var`

* revert back changes

* use unbiased var

* deal with division by zero

* add test case

* calculate advantage only once

* format

* add warning

* add more warnings

* quick fix

* remove unhelpful warning

* fix test cases

* fix test cases

* bump version given the breaking change

* black

* refactor

* update test cases

* error out

* push changes

* remove exact div

* add comments
2023-08-01 09:00:41 -04:00
431f0c9a2f Fix comparison in DataCollatorForCompletionOnlyLM (#588) (#594)
* Add unit test to DataCollatorForCompletionOnlyLM to reproduce the bug.

* Change comparison target from `examples[i][input_ids]` to `batch[labels][i]` in DataCollatorForCompletionOnlyLM (a toy illustration follows this entry)
2023-07-31 14:13:35 +02:00
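
As a rough, hedged illustration of searching for the response template inside the collated labels (the token ids below are made up, not a real tokenizer's output):

```python
# Toy version of the template search; the real collator compares the tokenized
# response template against batch["labels"][i] rather than raw per-example input_ids.
response_template_ids = [835, 21017]  # pretend tokenization of a "### Response:" marker

def completion_start(label_row):
    """Return the index right after the response template, or None if it is absent."""
    n = len(response_template_ids)
    for i in range(len(label_row) - n + 1):
        if label_row[i : i + n] == response_template_ids:
            return i + n
    return None

batch_labels = [[101, 7592, 835, 21017, 42, 43, 44]]
print(completion_start(batch_labels[0]))  # -> 4: everything before this stays masked
```
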
64bc9bc9e6 docs: Replace SFTTrainer with RewardTrainer in comment (#589)
Likely just a copy-paste error
2023-07-28 15:37:25 +02:00
5a1e1bf06e Introducing DataCollatorForChatCompletionOnlyLM (#456)
* added DataCollatorForChatCompletionOnlyLM

* added simple test

* merged the two collators and fixed ### in completion

* fix response template

* fixing ordering in test

* quality

* fixed minor comments & make doc

* chat test back

* Update tests/test_sft_trainer.py

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2023-07-28 14:17:03 +02:00
e8dd8102d8 Update the example sft_trainer.py (#587)
Added saving the model, because by default it saves only checkpoints, not the final version.
2023-07-28 13:50:41 +02:00
1b46c61d43 [PPO] fix corner cases with PPO batch size and forward_batch_size (#563)
* fix corner cases PPO

* forward contrib credits from initial contribution

* forward contrib credits from initial discussions

---------

Co-authored-by: 1485840691-eng <1485840691-eng@users.noreply.github.com>
Co-authored-by: shubhlohiya <shubhlohiya@users.noreply.github.com>
2023-07-28 11:05:34 +02:00
3b0a1b5f8c Add missing max_seq_length arg to example sft_trainer.py (#585) 2023-07-27 18:17:43 +02:00
31658b4263 Computes the KL penalty using the entire distribution (#541)
* adds full log probs

* Adds tests, comments

* precommit

* bug all -> full

* adds option description to sentiment analysis script, fixes a few bugs
2023-07-27 12:08:24 +02:00
f7227fb296 Fix model output dim in reward trainer example (#566)
* correct glitches in reward modelling

* add the eval_split option

* correct code format
2023-07-26 11:02:23 +02:00
b3c2e73e70 [DPO] Resolve logging for DPOTrainer (#570)
* Resolve logging for DPOTrainer

* Ensure the WandB logger correctly prefixes all logs

* Run pre-commit

Whoops, hadn't run `pre-commit install` yet
2023-07-26 08:06:25 +02:00
d78d917880 Add comment to explain how the sentiment pipeline is used to run the … (#555)
* Add comment to explain how the sentiment pipeline is used to run the reward model in the StackLLaMA example

* Apply 'make precommit'
2023-07-24 18:09:45 +02:00
cdde7f71d7 Add DataCollatorForCompletionOnlyLM in the docs (#565)
* add `DataCollatorForCompletionOnlyLM` in the docs

* nit
2023-07-24 16:47:41 +02:00
51d5f08d88 add epochs and num steps on CLI (#562) 2023-07-24 14:01:54 +02:00
8762507d3a Minor typo and whitespace fixes (#559)
* [docs] remove extra whitespace

* [examples] fix help for dataset_name
2023-07-24 13:56:55 +02:00
1bd852aa8f remove unused batch_size arg (#554) 2023-07-24 13:23:33 +02:00
170d58ffce [SFTTrainer] Add warning for wrong padding_side (#550)
* add warning for wrong padding_side

* add warning

* revert

* oops
2023-07-22 10:53:16 +02:00
84c9209037 ADD: num_proc to SFTTrainer (#547)
* ADD: num_proc to SFTTrainer

* make precommit

* Update trl/trainer/sft_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/sft_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/sft_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/sft_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* add batch_size

* Update trl/trainer/sft_trainer.py

Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
2023-07-20 15:41:48 +02:00
d0fe348a0a Add use_auth_token arg to sft_trainer example (#544)
* Add use_auth_token arg to sft_trainer example

* Update examples/scripts/sft_trainer.py

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2023-07-19 21:12:18 +02:00
5857d0acc6 [examples] make the sft script more modulable (#543)
* make the script more modulable

* docs + some changes
2023-07-19 18:13:55 +02:00
fd50e063e1 [DPO] remove response/pairs from the DPO side (#540)
* remove response/pairs from the DPO side

* Simplify get_hh helper function

* removed unused import

* update tests and docs for dpo_trainer

---------

Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>
Co-authored-by: Shoaib Burq <saburq@gmail.com>
2023-07-19 17:36:24 +02:00
bcff7c2dab Relax reward trainer constraint (#539)
* relax reward trainer constraint

* Update trl/trainer/reward_trainer.py

Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>

* relax also for DPO

---------

Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
2023-07-19 14:12:23 +02:00
0e8d9f8504 fix offline case (#538) 2023-07-19 12:16:13 +02:00
7f297b38c6 all the concated batches are on same device (#528) 2023-07-18 13:21:17 +02:00
84393f3b94 DPO Trainer (#416)
* initial DPO Trainer

* typo

* initial dpo from reward trainer

* calc. log_probs from logits

* remove dpo config for now

* fix inits

* add intial DPODataCollatorWithPadding

* use the RewardDataCollatorWithPadding

* initial test

* means of loss

* add assert

* just call the train instead of step

* functional debug example before refactor

* check the params have changed

* initial DPODataCollatorWithPadding

* Data collator with masking

* going through trainer.accelerate to wrap ref_model

* style / imports

* style / imports

* `broadcast_buffers=False` fix to distributed training

* better fix for DDP issues

* arguments and style clean-up

* better doc, some light refactoring

* better imports

* initial dpo doc

* fix test

* fix formatting

* fix

* called models once

* fix tests

* add example

* fix doc string

* initial example with the Anthropic HH dataset

* refactored dpo trainer

* revert

* return metrics

* fixed tests

* updated docs

* update test

* fixed typo

* note about the beta

* added dpo authors

* fix docstrings

* add prediction_step

* remove compute_metrics and log metrics manually

* fix typo

* add DPOTrainer doc

* add dpo to toc

* ValueError

* add to index and example

* fix docs

* fix assert

---------

Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
Co-authored-by: Gaetan LOPEZ <gaetanloplat@gmail.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
2023-07-17 14:52:14 +02:00
388bdc03ac Fix sentiment nit (#517) 2023-07-14 14:11:24 +02:00
5c7bfbc8d9 [examples] Big refactor of examples and documentation (#509)
* added sfttrainer and rmtrainer example scripts.

* added few lines in the documentation.

* moved notebooks.

* delete `examples/summarization`

* remove from docs as well

* refactor sentiment tuning

* more refactoring.

* updated docs for multi-adapter RL.

* add research projects folder

* more refactor

* refactor docs.

* refactor structure

* add correct scripts all over the place

* final touches

* final touches

* updated documentation from feedback.
2023-07-14 12:00:56 +02:00
36b77ae81d Use local process index for _get_current_device() (#515)
This PR fixes a bug in `_get_current_device()` where the global process index was being returned instead of the local one. 

With this fix, it is possible to run training in **multi-node** environments and avoid the dreaded `RuntimeError: CUDA error: invalid device ordinal` :)
2023-07-14 10:53:33 +02:00
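
A sketch along the lines of that fix, returning the local (per-node) index rather than the global one; illustrative, not TRL's exact helper:

```python
# Illustrative helper: pick the per-node device index so multi-node runs stay in range.
import torch
from accelerate import Accelerator

def get_current_device():
    # local_process_index is per node, while process_index is global and can
    # exceed the number of GPUs on a single node in multi-node training.
    return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"

print(get_current_device())
```
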
2049d03e82 Put labels tensors onto GPU to fix eval bug on deepspeed (#513) 2023-07-13 11:51:21 +02:00
31b98aa5a6 set dev version 2023-07-13 08:28:52 +00:00
d06b131097 git commit -m 'Release: v0.4.7' 2023-07-13 08:17:49 +00:00
f3230902b1 [SFTTrainer] Fix the sequence length check of SFTTrainer (#512)
* fix the sequence length check of `SFTTrainer`

* forward contrib credits from initial contribution

* forward contrib credits from initial contribution

* final comments

---------

Co-authored-by: mrm8488 <mrm8488@users.noreply.github.com>
Co-authored-by: BramVanroy <BramVanroy@users.noreply.github.com>
2023-07-12 15:25:17 +02:00
bbc7eeb29c [PPOTrainer] Add prompt tuning support on TRL (#500)
* add prompt tuning support on TRL

* fix CI

* revert + add docs
2023-07-06 15:16:37 +02:00
163dae5579 [PPOTrainer] Add prefix tuning support (#501)
* add prefix tuning support

* fix CI

* better check
2023-07-06 14:56:05 +02:00
64c8db2f9a Update ppo_trainer.py (#499) 2023-07-06 10:32:19 +02:00
25d4d81801 Disable mlm by default in DataCollatorForCompletionOnlyLM, add ignore_index and docstring (#476)
* add docstring and ignore index

* hard-code mlm=False

* make precommit

* FIX: re-add mlm parameter

---------

Co-authored-by: Bram Vanroy <Bram.Vanroy@UGent.be>
2023-07-06 10:22:40 +02:00
685620ac6c correctly implement gradient checkpointing (#479)
switch to new peft api
add max_length to RewardTrainer
2023-07-06 09:26:13 +02:00
2b531b9223 Adds some options to stabilize the KL penalty (#486)
* adds options for the kl penalty

* style

* adds kl penalty to trl sentiment example args

* ppo_config -> config

* fix tests (equal -> allclose)

* style

* add a random seed option

* updates kl penalty description

---------

Co-authored-by: Costa Huang <costa.huang@outlook.com>
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2023-07-05 11:23:10 +02:00
4f7f73dd09 Remove padding in batched generation. (#487)
* fix padding

* Update examples/sentiment/scripts/gpt2-sentiment.py

* fix style

---------

Co-authored-by: leandro von werra <leandro@hf.co>
2023-07-05 10:41:06 +02:00
c60c41688e FIX: contributing guidelines command (#493)
* FIX: contributing guidelines command

* Update CONTRIBUTING.md

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Update CONTRIBUTING.md

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2023-07-04 14:27:52 +02:00
cbb98dabb1 fix typo in reward_modeling.py (#494) 2023-07-04 14:17:32 +02:00
a86eaab8e8 add ratio threshold to avoid spikes (#488) 2023-07-04 10:09:53 +02:00
aa9770c6bd Refactor README (#460)
* v1

* update

* link

* nits
2023-07-03 14:30:15 +02:00
0fe603eca1 Update sft_trainer.py (#474)
* Update sft_trainer.py

Allows the user to give their own peft model arg. https://github.com/lvwerra/trl/issues/473

* cleaner
2023-06-28 00:44:15 +02:00
843c14574f fix CI RM (#468) 2023-06-26 14:30:06 +02:00
009b82412f Debug the tortuous logic in _prepare_dataset function (#464)
* Debug the tortuous logic in `_prepare_dataset` function

There are two issues with the previous `_prepare_dataset` function.

1. Tortuous and burdensome logic: the `is_already_dataset` variable is confusing and not helpful, so remove it.
2. The comments and the logic do not match.

For instance, in the previous version the comments said "check if torch dataset ... and do nothing". However, when `dataset` is a torch.utils.data.Dataset and `packing = True`, it still moves into the _prepare_non_packed_dataloader(...) function call.

The corrected version does nothing if the dataset is already a torch dataloader/dataset/ConstantLengthDataset (a rough sketch follows this entry).

* Lint: sft_trainer.py

* Lint empty line
2023-06-24 08:43:03 +02:00
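
A rough sketch of that corrected branching; the helper names are assumptions for illustration, not the actual SFTTrainer internals:

```python
# Illustrative branching: leave ready-made torch datasets untouched, otherwise
# dispatch on the packing flag.
from torch.utils.data import Dataset, IterableDataset

def prepare_dataset(dataset, packing, packed_fn, non_packed_fn):
    if isinstance(dataset, (Dataset, IterableDataset)):
        return dataset  # already a torch dataset (e.g. a pre-built ConstantLengthDataset)
    return packed_fn(dataset) if packing else non_packed_fn(dataset)

# Tiny usage example with stand-in "tokenizers":
print(prepare_dataset(["a", "b"], packing=True,
                      packed_fn=lambda d: ("packed", d),
                      non_packed_fn=lambda d: ("non_packed", d)))
```
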
82c8f20601 Pre-commit (#448)
* Pre-commit

* modify CI

* modify make file

* temporarily disable codespell

* update make file

* update contribution guide

* push changes
2023-06-23 11:37:18 -04:00
b56e8b3277 Improve stability: change default hyperparameters 2023-06-23 09:04:24 -04:00
0161a8e602 added shuffle parameter. I found it useful to turn off shuffle here and shuffle independently of this. (#457) 2023-06-23 11:47:08 +02:00
6e34c5932b set dev version 2023-06-23 09:20:25 +00:00
e1531aa526 Release: v0.4.6 2023-06-23 09:17:31 +00:00
cb6c45474a fix google colab issue (#459) 2023-06-23 11:13:36 +02:00
fe55b440e7 set dev version 2023-06-23 08:42:20 +00:00
431456732c Release: 0.4.5 2023-06-23 08:13:50 +00:00
9679d87012 Multi adapter RL (MARL) - a single model for RM & Value Head (#373)
* fix doc

* adapt from suggestions

* working v1 multiple adapters

* style

* style && quality

* oops

* docs

* add tests and docs

* add RM script

* Apply suggestions from code review

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/0_abstraction_rl.mdx

* Apply suggestions from code review

* Update docs/source/0_abstraction_rl.mdx

* add 4bit

* replace with `reward_adapter`

* explain break

* simple comment

* fix llama tokenizer

* fixes

* fixes

* rename

* quality

* rm unneeded file

* add disclaimer

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2023-06-22 11:19:45 +02:00
099f0bf42b Add accelerate project_config passthrough (#437) 2023-06-22 10:16:34 +02:00
33f88ead0b [ConstantLengthDataset] Fix packed dataset issue (#452)
* fix packed dataset issue

* Apply suggestions from code review

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* address

* more docs

* trigger CI

* fix failing CI

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2023-06-22 10:12:55 +02:00
7705daa672 [SFTTrainer] Introducing DataCollatorForCompletionOnlyLM (#445)
* v1 of alpaca datacollator

* make sure to match the response tokens

* add test

* add it in main init

* add check

* adapt test

---------

Co-authored-by: Costa Huang <costa.huang@outlook.com>
2023-06-20 17:51:23 +02:00
fe49697e66 add stale bot (#447) 2023-06-19 17:26:17 +02:00
d1ad5405cb [SFTTrainer] Fix non packed dataset (#444)
* fix non packed dataset

* fixing tests and documentation

* Update docs/source/sft_trainer.mdx
2023-06-16 18:51:20 +02:00
1e88b84ab9 fix packing issue (#442) 2023-06-16 13:55:47 +02:00
c39207460f Drop support for Python 3.7 (#441)
* drop support for Python 3.7

* adapt
2023-06-16 13:30:01 +02:00
61af5f26b6 Fix correct gradient accumulation (#407)
* add correct grad acc

* add some tests but they fail

* test should pass

* style

* fix
2023-06-14 08:43:35 -04:00
7a89a43c3f handle the offline case (#431)
* handle the offline case

* adds warning
2023-06-13 15:36:12 +02:00
fead2c8c77 best-of-n sampler class (#375)
* First draft of best-of-n sampler class

* Formatting

* Add best-of-n class to init

* Rearrange files

* Correction

* Make sure input query is in shape

* check for numpy.ndarray type

* Fix for shapes and types AND linter fixes

* Make reward pipeline a callback for broader application

* Documentation for best-of-n sampler class usage

* Docs update for best-of-n class

* Doc fixes for best-of-n sampler class

* Remove colon from new addition

* Change user callback output type and associated side-effects of said change

* Relocate param because of collision

* Documentation update

* Make input param keyword easier to grasp

* Remove comments and add docstrings

* Tests and fixes for best_of_n sampler class

* Change input arg name

* Formatting

* Removed unnecessary cloning
2023-06-13 10:25:21 +02:00
b4bb12992e Update test_reward_trainer.py (#421) 2023-06-09 15:52:41 +02:00
b21baddc5c [doc build] Use secrets (#420) 2023-06-09 15:52:10 +02:00
216c119fa9 Enable autotag feature w/ wandb (#411)
* Enable autotag feature

* use `logging.info`

* Update trl/trainer/ppo_config.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/ppo_config.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2023-06-09 11:20:18 +02:00
a2747acc0f Add slurm utility (#412)
* Add slurm utility

* move files
2023-06-09 11:04:43 +02:00
b61a4b95a0 set dev version 2023-06-08 14:28:37 +00:00
5c5d7687d8 Release: v0.4.4 2023-06-08 14:26:14 +00:00
096f5e9da5 unpin accelerate (#418) 2023-06-08 16:25:03 +02:00
2a0ed3a596 set dev version 2023-06-08 08:55:33 +00:00
ff13c5bc6d Release: v0.4.3 2023-06-08 08:52:04 +00:00
d3e05d6490 Update setup.py (#414) 2023-06-08 10:49:03 +02:00
fadffc22bc Update test_reward_trainer.py (#410) 2023-06-07 12:22:22 +02:00
d405c87068 set dev version 2023-06-07 10:22:06 +00:00
b46716c4f5 Release: v0.4.2 2023-06-07 09:43:23 +00:00
ec8a5b7679 Remove unused imports in docs. (#406)
* remove unused var

* bug fix

* update docs, add e2e CI

* black

* isort

* CI
2023-06-06 18:06:49 +02:00
376d152d3f Resolve broken evaluation/prediction for RewardTrainer (#404)
* Implement evaluation/prediction for RewardTrainer

* Stick with unittest assertions

* Perform prediction forward calls without gradient

* Remove Literal to preserve Python 3.7 support

I recognize that I could also import from typing_extensions with a try-except,
but that feels like a bit of overkill for this (a sketch of what it would look like follows this entry).

* Remove eval_steps=1 to prevent flaky test on CI

The flaky test is caused by a division by zero when dividing by the runtime.
This is done on the transformers side, so it's not a TRL issue.
In practice, this won't happen - it only happens because both the model
and dataset are tiny.
2023-06-06 16:49:30 +02:00
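
For reference, the `typing_extensions` fallback mentioned above (and deliberately skipped in that PR) would look roughly like this; the `Side` alias is just an example:

```python
# Hedged sketch of the Literal backport pattern the author decided against.
try:
    from typing import Literal  # Python 3.8+
except ImportError:              # Python 3.7 would need the backport package
    from typing_extensions import Literal

Side = Literal["chosen", "rejected"]  # example annotation it would enable
```
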
ef57cddbc3 StackLLaMA: fix supervised finetuning and reward model training (#399)
* better reward modelling

tokenizer can be separately specified from model
removed old llama tokenizer hacks
evaluate after first step option to make nicer graphs
black + isort

* removed tokenizer hacks from supervised ft

* black and flake8
2023-06-06 10:41:07 +02:00
20111ad03a Fixed some type annotations of trl.trainer.PPoTrainer (#392)
* Fixed some type annotations of trl.trainer.PPoTrainer

- Ref model should be Optional
- The usual annotation for the Hugging Face tokenizers is PreTrainedTokenizerBase. Not using that messes up people's annotation checks.
- Fixed the comments wrt the other two points

* fix quality and style

* synced & requality & restyled
2023-06-06 10:32:37 +02:00
a4793c2ede StackLlama: fixed RL training and added args (#400)
* fixed rl training args

added steps argument and break to respect max training epochs
added more PPOConfig args to script args
removed llama tokenizer hacks
removed extra args in dataset
changed to llamatokenizer from autotokenizer
black + isort

* black and flake8

* style, quality, and switch back to AutoTokenizer
2023-06-05 10:30:20 +02:00
0ddf9f657f StackLLaMA: correctly merge peft model (#398)
* correctly merge stackllama models

correctly merge weights with peft's merge_and_unload
load sequence classification model for reward models

* style, black line length 119

* flake8
2023-06-05 10:25:53 +02:00
3138ef6f5a fix 4 bit SFT (#396) 2023-06-02 10:49:41 +02:00
a5b0414f63 keep state_dict kwargs instead of popping it in save_pretrained (#393) 2023-05-31 10:56:45 +02:00
e174bd50a5 from_pretrain with peft adapter on the hub (# 379) (#380)
* from_pretrain with peft adapter on the hub (# 379)

* Update the comment

* PR comment
2023-05-31 10:38:25 +02:00
86c117404c fix typo in ppo_trainer.py (#389)
`dataloader must be a torch.utils.data.Dataset`: `dataloader` should be `dataset`
2023-05-30 15:23:02 +02:00
a94761a02c Update customization.mdx (#390) 2023-05-30 15:22:41 +02:00
5fb5af7c34 [core] Add 4bit QLora (#383)
* add 4bit

* style
2023-05-24 13:52:38 +02:00
25fa1bd880 fix warning issue (#377) 2023-05-18 08:43:44 +02:00
6916e0d2df [docs] fix SFT doc (#367)
* fix doc

* adapt from suggestions
2023-05-15 16:26:27 +02:00
1704a864e7 Delete test_training.py (#371) 2023-05-15 16:21:28 +02:00
e547c392f9 Remove obsolete layer_norm_names parameter and add peft>=0.3.0 to requirements (#366)
* remove obsolete layer_norm_names parameter

* remove obsolete parameter layer_norm_names and add peft>=0.3.0 to requirements

* make style - oops

* typo
2023-05-15 16:08:11 +02:00
a31bad83fb add is_trainable in kwargs (#363)
Add is_trainable in kwargs to enable continued training of a peft model.
2023-05-15 16:08:00 +02:00
31cc361d17 Fix bug when loading local peft model (#342)
* Fix bug when loading local peft model 

Fix bug in https://github.com/lvwerra/trl/issues/341

* Fix loading bug when loading a LoRA model

Fix loading bug when loading a LoRA model but not resuming training

1. Implement the fix logic described in https://github.com/lvwerra/trl/pull/342#pullrequestreview-1422298054

2. Set peft lora weight to trainable.

* Remove is_trainable

Leave is_trainable to future PR.

* add test_load_pretrained_peft

Check that the model saved with peft class interface can be loaded properly.
2023-05-11 23:07:50 +02:00
ab453ec183 140/best n sampling (#326)
* Create best_of_n.ipynb

* First draft

* Refactor as ref vs ppo vs non-ppo

* Changed notebook location and added README to explain motivation

* 1. Spelling and formatting refactor
2. Minor refactor of notebook

* Formatting of notebook
2023-05-11 17:56:12 +02:00
933c91cc66 fix tensorboard issue (#330) 2023-05-11 17:45:59 +02:00
ffad0a19d0 relax negative KL constraint (#352) 2023-05-11 17:45:47 +02:00
e0172fc8ec add parameter to control max_length (to mitigate OOM errors) (#359) 2023-05-11 15:28:32 +02:00
dec9993129 stack_llama: update instructions in README, fix broken _get_submodules and save tokenizer (#358)
* update instructions in README and fix broken _get_submodules

* save tokenizer

* add note about peft>=0.3.0
2023-05-11 12:29:02 +02:00
c85cdbdbd0 Fix argument's description (#339) 2023-05-04 14:29:07 +02:00
e59cce9f81 fix sft issues (#336) 2023-05-03 12:53:32 +02:00
c60fd915c1 [core] officially support SFT (Supervised Finetuning) (#323)
* add v1

* revert

* correct filename

* add tests and final tweaks

* fix tests

* adapt from offline suggestions

* Update trl/trainer/sft_trainer.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* fixes

* remove warning

* multiple fixes

* fixes

* fix

* final fixes

* final fix

* more clarification

* Apply suggestions from code review

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* add test

* add arg

* add callback instructions

* add formatting_prompts_func

* try docs

* add CLD

* fix docstrings

* format

* Update docs/source/sft_trainer.mdx

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* remove `prepare_in_int8_kwargs`

* change `return_overflowing_tokens`

* add warnings

* address comments

* revert pretrained kwargs

* quality

* fix sft script

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2023-05-03 10:42:01 +02:00
08f550674c added doc for using torch.distributed.launch/run (#324)
* added doc for using torch.distributed.launch/run

* Update docs/source/customization.mdx

---------

Co-authored-by: Afshin Oroojlooyjadid <afshin.oroojlooyjadid@oracle.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2023-04-28 16:18:07 +02:00
52fecee883 Give a key to the wandb PPOConfig config entry (#315)
* Give a key to the wandb PPOConfig config entry

There is a lot of stuff with very generic keys in the `PPOConfig` dict, and the user may have logged a `wandb` config dict elsewhere.
I know I had that problem. To counter that, I pass the PPOConfig dict nested under the key `trl_ppo_trainer_config`, to prevent collisions and be explicit (see the sketch after this entry).

* Ran the style and lint checks:
black --line-length 119 --target-version py38 examples tests trl
isort examples tests trl
black --check --line-length 119 --target-version py38 examples tests trl
isort --check-only examples tests trl
flake8 examples tests trl
2023-04-26 22:14:55 +02:00
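
A minimal sketch of that namespacing idea; the dataclass below is a stand-in so the snippet runs on its own, not the real `PPOConfig`:

```python
# Illustrative: nest the whole config under one explicit key so generic names
# like "batch_size" cannot collide with anything else the user logs to wandb.
import dataclasses

@dataclasses.dataclass
class ToyPPOConfig:  # stand-in for trl.PPOConfig
    learning_rate: float = 1.41e-5
    batch_size: int = 256

config = ToyPPOConfig()
tracker_config = {"trl_ppo_trainer_config": dataclasses.asdict(config)}
print(tracker_config)
```
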
3cfe194e34 [core] Officially Support Reward Modeling (#303)
* v1

- add working version
- add all possible tests
- add docs

* add some contents

* clean up

* fixes

* patch test for now

* fix test

* clean up

* fix

* this time fix

* Update docs/source/trainer.mdx

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* fixe

* update

* final changes

* oops

* Update docs/source/reward_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/reward_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/reward_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* switch to chosen / rejected

* fixes

* add example

* add accuracy metric

* pass PEFT config

* refactor compute metrics

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2023-04-26 11:51:56 +02:00
ad325152cc add details on multi-GPU / multi-node (#320) 2023-04-26 11:12:15 +02:00
1f29725381 fix broken tests (#318) 2023-04-25 13:57:40 +02:00
23a06c94b8 fix DS for peft ref_model in ppo trainer (#309)
The peft ref_model is obtained by calling the `disable_adapter` method, e.g.:
```
with self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter():
    ref_logprobs, _, _, _ = self.batched_forward_pass(self.model, queries, responses, model_inputs)
```
2023-04-25 12:52:49 +02:00
5c24d5bb2e fixed typo (#312) 2023-04-25 11:38:28 +02:00
503ac5d82c clean examples folder (#294)
* clean examples folder

* Update examples/toxicity/README.md
2023-04-25 11:33:54 +02:00
ce37eadcfa Log Token distribution of Query / Response (#295)
* reset git

* move to log_step_stats, make optional

* fix stack

* reset script

* fix types

* always log, add dist
2023-04-17 17:49:14 +02:00
160d0c9d6c [t5] Fix negative kl issue (#262)
* fix negative kl issue

* fix

* make style
2023-04-14 11:50:17 +02:00
d1c7529328 Fix arguments description (#298)
* Fix arguments description

* fix-argument-description

* Fix-argument-description
2023-04-12 16:00:42 +02:00
fc468e0f35 Small improvements / fixes to toxicity example (#266)
* fixes during debugging

* Update examples/toxicity/scripts/gpt-j-6b-toxicity.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2023-04-10 14:24:06 -07:00
131e5cdd10 add functionality to push best models to the hub during training (#275)
* add functionality to push best models to the hub during training

* fix indentation

* Update tests/test_ppo_trainer.py

Co-authored-by: Nathan Lambert <nathan@huggingface.co>

* Update trl/trainer/ppo_trainer.py

Co-authored-by: Nathan Lambert <nathan@huggingface.co>

* Update trl/trainer/ppo_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* fix style

---------

Co-authored-by: Nathan Lambert <nathan@huggingface.co>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2023-04-10 11:32:53 -07:00
bb4a9800fa fix typo in gpt2-sentiment.ipynb (#293)
inital -> initial
2023-04-10 20:08:55 +02:00
3804a72e6c Fix swapped helper texts (#284) 2023-04-10 10:23:37 +02:00
a004b02c4a Add LLaMA tutorial to docs (#278)
* docs docs docs

* add truncated blog to docs
2023-04-07 08:16:42 -07:00
8b234479bc fix doc string problem in ppo trainer loss function (#279)
* fix a loss function docstring problem

`hidden_dim` should be `response_length`

* Update ppo_trainer.py
2023-04-07 10:22:02 +02:00
meg
cf20878113 Adding pointer back to Meta's LLaMA. (#277) 2023-04-06 14:04:12 -07:00
d8ae4d08c6 stack-llama (#273)
* adds the main scripts

* adds non-score reward clamping

* Adds adapter merge script.

* style

* adds non_reward clamp option to config

* reverts kl clamping

* style

* makes model name required for adapter merge

* updates merge adapter so it does not refer to HF internal llama checkpoints

* renames to stack_llama, adds clearer instructions

* updates readme, adds ds config

* Update examples/stack_llama/scripts/rl_finetuning_peft.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Update examples/stack_llama/scripts/rl_finetuning_peft.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* removes ds config, renamed scripts

* style

* updates launch commands

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2023-04-05 17:11:43 +02:00
a2749d9e0c Use active model to generate response in example on README (#269) (#271)
Co-authored-by: rmilleti <rmilleti@amazon.com>
2023-04-03 15:36:46 +02:00
ed87942a47 Add LlaMa in tests + create_reference_model (#261)
* add LlaMa in tests

* Update tests/test_modeling_value_head.py

* add warning message

---------

Co-authored-by: Nathan Lambert <nathan@huggingface.co>
2023-03-30 10:49:46 +02:00
734624274d [core] Fix ds issue (#260)
* fix ds issue

* more comments
2023-03-29 14:20:27 +02:00
237eb9c6a5 [distributed] Fix early stopping and DP (#254)
* fix ES DP

* fix coef

* wrap in a private method

* fix value

* fix trainer logic
2023-03-28 14:31:16 +02:00
2672a942a6 [core] Fix DeepSpeed zero-3 issue (#182)
* fix zero-3 issue

* Update trl/trainer/ppo_trainer.py

Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>

* adapt

* make style

* fix

* add docs

* fix

---------

Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
2023-03-28 13:43:52 +02:00
b5cce0d13e Using batched generate in sentiment scripts (#249)
Co-authored-by: gaurav.vi <gaurav.vi@media.net>
2023-03-27 12:09:50 +02:00
0b165e60bc Fix typo (#253) 2023-03-27 11:59:15 +02:00
404621f0f9 Improve logging for PPO + Docs page (#243)
* init pr

* try and fix docpreview

* fix

* try to fix tests

* nit

* fix tests

* convert to tensor
2023-03-24 09:34:57 +01:00
89df6abf21 feat(ci): enable pip cache (#198)
* feat(ci): add pip caching to CI

* feat(ci): create workflow to cleanup cache

* feat(ci): enable `pip` caching in CI
2023-03-24 09:33:43 +01:00
9523474490 PPO config __init__ is bloated (#241)
* Moving `total_ppo_epochs`, `forward_batch_size` and `log_with` to the post-init method and letting the dataclass automatically assign the other member variables (see the sketch after this entry).

* Using default factory functions for initializing dict

* Using fields + metadata for args description

* Reformatting the file using black(jupyter)

* Trying styling checks again

* Adding new args from PR 238

---------

Co-authored-by: gaurav.vi <gaurav.vi@media.net>
2023-03-24 09:33:22 +01:00
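
A compact sketch of the dataclass pattern this entry describes (field metadata for descriptions, `default_factory` for mutable defaults, derived values in `__post_init__`); the field names and formula are illustrative, not the real PPOConfig:

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ExampleConfig:  # not the real PPOConfig, just the pattern it moved to
    batch_size: int = field(default=256, metadata={"help": "samples per PPO step"})
    log_with: Optional[str] = field(default=None, metadata={"help": "tracker to use"})
    tracker_kwargs: dict = field(default_factory=dict)  # mutable default, done safely

    def __post_init__(self):
        # Derived values live here instead of a hand-written, bloated __init__.
        self.total_ppo_epochs = max(1, self.batch_size // 256)

print(ExampleConfig().total_ppo_epochs)  # -> 1
```
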
1620da371a adds early stopping (#238)
* adds early stopping

* zero opt grad

* style

* Fixed typo in early stopping property description

* Auto stash before rebase of "origin/main"
2023-03-23 15:24:04 +01:00
9b60207f0b [core] Add warning when negative KL (#239)
* add warning

* oops

* fix

* Update trl/trainer/ppo_trainer.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2023-03-22 12:18:43 +01:00
a6ebdb6e75 Reduce memory consumption in batched_forward_pass (#234)
* Reduce memory consumption by not storing logits in forward_pass

* Add docstring of return_logits
2023-03-22 10:18:18 +01:00
9c3e9e43d0 Batched generation (#228)
* add `_generate_batch`

* fix style

* omit tensor conversion

* no multiple pad by default

* add test

* stylez

* update docstring

* encoder/decoder check

* input shape safety

* moar style

---------

Co-authored-by: leandro von werra <leandro@hf.co>
2023-03-21 16:48:34 +01:00
0610711dda [core] refactor peft API (#231)
* refactor peft API

* update gpt2 peft script

* refactor

* few fixes

* fix bug

* make style

* update docs

* more update

* fix docs

* fix issues and add tests

* make style

* update dcos
2023-03-21 13:35:21 +01:00
24627e9c89 set dev version 2023-03-17 10:40:04 +00:00
258 changed files with 48087 additions and 4939 deletions

93
.github/ISSUE_TEMPLATE/bug-report.yml

@@ -0,0 +1,93 @@
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve TRL
labels: [ "bug" ]
body:
- type: markdown
attributes:
value: |
Thanks for taking the time to fill out this bug report! 🤗
🚩 If it is your first time submitting, be sure to check our [bug report guidelines](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#did-you-find-a-bug)
- type: textarea
id: system-info
attributes:
label: System Info
description: |
Please provide information about your system: platform, Python version, PyTorch version, Transformers version, devices, TRL version, ...
You can get this information by running `trl env` in your terminal.
placeholder: Copy-paste the output of `trl env`
validations:
required: true
- type: checkboxes
id: information-scripts-examples
attributes:
label: Information
description: 'The problem arises when using:'
options:
- label: "The official example scripts"
- label: "My own modified scripts"
- type: checkboxes
id: information-tasks
attributes:
label: Tasks
description: "The tasks I am working on are:"
options:
- label: "An officially supported task in the `examples` folder"
- label: "My own task or dataset (give details below)"
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Reproduction
description: |
Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
If you have code snippets, error messages, stack traces please provide them here as well.
Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
value: |
```python
from trl import ...
```
outputs:
```
Traceback (most recent call last):
File "example.py", line 42, in <module>
...
```
- type: textarea
id: expected-behavior
validations:
required: true
attributes:
label: Expected behavior
description: "A clear and concise description of what you would expect to happen."
- type: checkboxes
id: terms
attributes:
label: Checklist
description: |
Before submitting, please confirm that you've completed each of the following.
If an item doesn't apply to your issue, check it anyway to show you've reviewed it.
options:
- label: "I have checked that my issue isn't already filed (see [open issues](https://github.com/huggingface/trl/issues?q=is%3Aissue))"
required: true
- label: "I have included my system information"
required: true
- label: "Any code provided is minimal, complete, and reproducible ([more on MREs](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))"
required: true
- label: "Any code provided is properly formatted in code blocks, (no screenshot, [more on code blocks](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))"
required: true
- label: "Any traceback provided is complete"
required: true


@@ -0,0 +1,31 @@
name: "\U0001F680 Feature request"
description: Submit a proposal/request for a new TRL feature
labels: [ "Feature request" ]
body:
- type: textarea
id: feature-request
validations:
required: true
attributes:
label: Feature request
description: |
A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist.
- type: textarea
id: motivation
validations:
required: true
attributes:
label: Motivation
description: |
Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too.
- type: textarea
id: contribution
validations:
required: true
attributes:
label: Your contribution
description: |
Is there any way that you could help, e.g. by submitting a PR? Make sure to read the CONTRIBUTING.MD [readme](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md)


@@ -0,0 +1,32 @@
name: "\U0001F31F New trainer addition"
description: Submit a proposal/request to implement a new trainer for a post-training method
labels: [ "New trainer" ]
body:
- type: textarea
id: description-request
validations:
required: true
attributes:
label: Method description
description: |
Put any and all important information relative to the method
- type: checkboxes
id: information-tasks
attributes:
label: Open source status
description: |
Please note that if the method implementation isn't available or model weights with training datasets aren't available, we are less likely to implement it in `trl`.
options:
- label: "The method implementation is available"
- label: "The model weights are available"
- label: "The training datasets are available"
- type: textarea
id: additional-info
attributes:
label: Provide useful links for the implementation
description: |
Please provide information regarding the implementation, the weights, and the authors.
Please mention the authors by @gh-username if you're aware of their usernames.

32
.github/PULL_REQUEST_TEMPLATE.md

@@ -0,0 +1,32 @@
# What does this PR do?
<!--
Congratulations! You've made it this far! You're not quite done yet though.
Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution.
Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change.
Once you're done, someone will review your PR shortly. They may suggest changes to make the code even better.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#create-a-pull-request),
Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the
[documentation guidelines](https://github.com/huggingface/trl/tree/main/docs).
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.


@@ -13,7 +13,7 @@ jobs:
with:
commit_sha: ${{ github.sha }}
package: trl
repo_owner: lvwerra
version_tag_suffix: ""
custom_container: huggingface/transformers-doc-builder
secrets:
token: ${{ secrets.HUGGINGFACE_PUSH }}
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}


@@ -14,5 +14,5 @@ jobs:
commit_sha: ${{ github.event.pull_request.head.sha }}
pr_number: ${{ github.event.number }}
package: trl
repo_owner: lvwerra
version_tag_suffix: ""
version_tag_suffix: ""
custom_container: huggingface/transformers-doc-builder

.github/workflows/clear_cache.yml

@ -0,0 +1,33 @@
name: "Cleanup Cache"
on:
workflow_dispatch:
schedule:
- cron: "0 0 * * *"
jobs:
cleanup:
runs-on: ubuntu-latest
steps:
- name: Check out code
uses: actions/checkout@v4
- name: Cleanup
run: |
gh extension install actions/gh-actions-cache
REPO=${{ github.repository }}
echo "Fetching list of cache key"
cacheKeysForPR=$(gh actions-cache list -R $REPO | cut -f 1 )
## Setting this to not fail the workflow while deleting cache keys.
set +e
echo "Deleting caches..."
for cacheKey in $cacheKeysForPR
do
gh actions-cache delete $cacheKey -R $REPO --confirm
done
echo "Done"
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}


@ -1,13 +0,0 @@
name: Delete dev documentation
on:
pull_request:
types: [ closed ]
jobs:
delete:
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main
with:
pr_number: ${{ github.event.number }}
package: trl

.github/workflows/docker-build.yml

@ -0,0 +1,95 @@
name: Build Docker images (scheduled)
on:
workflow_dispatch:
workflow_call:
schedule:
- cron: "0 1 * * *"
concurrency:
group: docker-image-builds
cancel-in-progress: false
env:
CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}
jobs:
trl-latest:
name: "Latest TRL GPU"
runs-on: ubuntu-latest
steps:
- name: Cleanup disk
run: |
sudo ls -l /usr/local/lib/
sudo ls -l /usr/share/
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
sudo rm -rf /usr/local/lib/android
sudo rm -rf /usr/share/dotnet
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Check out code
uses: actions/checkout@v4
- name: Login to DockerHub
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push GPU
uses: docker/build-push-action@v4
with:
context: ./docker/trl-latest-gpu
push: true
tags: huggingface/trl-latest-gpu
- name: Post to Slack
if: always()
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: 🤗 Results of the trl-latest-gpu Docker Image build
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
trl-source:
name: "Latest TRL + HF ecosystem from source"
runs-on: ubuntu-latest
steps:
- name: Cleanup disk
run: |
sudo ls -l /usr/local/lib/
sudo ls -l /usr/share/
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
sudo rm -rf /usr/local/lib/android
sudo rm -rf /usr/share/dotnet
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Check out code
uses: actions/checkout@v4
- name: Login to DockerHub
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push GPU
uses: docker/build-push-action@v4
with:
context: ./docker/trl-source-gpu
push: true
tags: huggingface/trl-source-gpu
- name: Post to Slack
if: always()
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: 🤗 Results of the trl-source-gpu Docker Image build
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

.github/workflows/slow-tests.yml

@ -0,0 +1,98 @@
name: Slow tests (on push)
on:
push:
branches: [ main ]
paths:
# Run only when python files are modified
- "trl/**.py"
- "examples/**.py"
env:
RUN_SLOW: "yes"
IS_GITHUB_CI: "1"
SLACK_API_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
jobs:
run_all_tests_single_gpu:
strategy:
fail-fast: false
matrix:
docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"]
runs-on:
group: aws-g4dn-2xlarge
env:
CUDA_VISIBLE_DEVICES: "0"
TEST_TYPE: "single_gpu_${{ matrix.docker-image-name }}"
container:
image: ${{ matrix.docker-image-name }}
options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true
defaults:
run:
shell: bash
steps:
- uses: actions/checkout@v4
- name: Pip install
run: |
source activate trl
pip install -e ".[test]" --no-deps
pip install pytest-reportlog parameterized
- name: Run slow SFT tests on single GPU
if: always()
run: |
source activate trl
make slow_tests
- name: Generate Report
if: always()
run: |
pip install slack_sdk tabulate
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
run_all_tests_multi_gpu:
strategy:
fail-fast: false
matrix:
docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"]
runs-on:
group: aws-g4dn-2xlarge
env:
CUDA_VISIBLE_DEVICES: "0,1"
TEST_TYPE: "multi_gpu_${{ matrix.docker-image-name }}"
container:
image: ${{ matrix.docker-image-name }}
options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true
defaults:
run:
shell: bash
steps:
- uses: actions/checkout@v4
- name: Pip install
run: |
source activate trl
pip install -e ".[test]" --no-deps
pip install pytest-reportlog parameterized
- name: Run slow SFT tests on Multi GPU
if: always()
run: |
source activate trl
make slow_tests
- name: Run end-to-end examples tests on multi GPU
if: always()
run: |
source activate trl
pip install deepspeed
make test_examples
- name: Generate Reports
if: always()
run: |
pip install slack_sdk tabulate
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
python scripts/log_example_reports.py --text_file_name temp_results_sft_tests.txt >> $GITHUB_STEP_SUMMARY
python scripts/log_example_reports.py --text_file_name temp_results_dpo_tests.txt >> $GITHUB_STEP_SUMMARY
rm *.txt


@ -1,47 +1,163 @@
name: tests
name: Tests
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
paths:
# Run only when relevant files are modified
- ".github/**.yml"
- "examples/**.py"
- "scripts/**.py"
- "tests/**.py"
- "trl/**.py"
- "setup.py"
env:
TQDM_DISABLE: 1
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
jobs:
check_code_quality:
name: Check code quality
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
- uses: actions/checkout@v4
with:
python-version: "3.8"
fetch-depth: 0
submodules: recursive
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: 3.12
- uses: pre-commit/action@v3.0.1
with:
extra_args: --all-files
tests:
name: Tests
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
os: ['ubuntu-latest', 'windows-latest']
fail-fast: false
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: "pip"
cache-dependency-path: |
setup.py
requirements.txt
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[dev]
- name: Check quality
python -m pip install ".[dev]"
- name: Test with pytest
run: |
make quality
make test
- name: Post to Slack
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results with Python ${{ matrix.python-version }} on ${{ matrix.os }} with latest dependencies
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
tests:
needs: check_code_quality
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
runs-on: ${{ matrix.os }}
tests_dev:
name: Tests with dev dependencies
runs-on: 'ubuntu-latest'
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install .[test]
- name: Test with pytest
run: |
make test
- uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: "pip"
cache-dependency-path: |
setup.py
requirements.txt
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -U git+https://github.com/huggingface/accelerate.git
python -m pip install -U git+https://github.com/huggingface/datasets.git
python -m pip install -U git+https://github.com/huggingface/transformers.git
python -m pip install ".[dev]"
- name: Test with pytest
run: |
make test
- name: Post to Slack
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results with Python 3.12 on ubuntu-latest with dev dependencies
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
tests_wo_optional_deps:
name: Tests without optional dependencies
runs-on: 'ubuntu-latest'
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: "pip"
cache-dependency-path: |
setup.py
requirements.txt
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[test]"
- name: Test with pytest
run: |
make test
- name: Post to Slack
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results with Python 3.12 on ubuntu-latest without optional dependencies
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
tests_min_versions:
name: Tests with minimum versions
runs-on: 'ubuntu-latest'
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: "pip"
cache-dependency-path: |
setup.py
requirements.txt
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install accelerate==0.34.0
python -m pip install datasets==2.21.0
python -m pip install transformers==4.46.0
python -m pip install ".[dev]"
- name: Test with pytest
run: |
make test
- name: Post to Slack
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results with Python 3.12 on ubuntu-latest with minimum versions
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

.github/workflows/tests_latest.yml

@ -0,0 +1,45 @@
name: Tests latest TRL release with dev dependencies
on:
schedule:
- cron: '0 0 * * *' # Runs daily at midnight UTC
workflow_dispatch:
env:
TQDM_DISABLE: 1
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
jobs:
tests:
name: Tests latest TRL release with dev dependencies
runs-on: 'ubuntu-latest'
steps:
- name: Git checkout
uses: actions/checkout@v4
with: { ref: v0.13-release }
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: "pip"
cache-dependency-path: |
setup.py
requirements.txt
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -U git+https://github.com/huggingface/accelerate.git
python -m pip install -U git+https://github.com/huggingface/datasets.git
python -m pip install -U git+https://github.com/huggingface/transformers.git
python -m pip install ".[dev]"
- name: Test with pytest
run: |
make test
- name: Post to Slack
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results of latest TRL with Python 3.12 on ubuntu-latest with dev dependencies
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

.github/workflows/trufflehog.yml

@ -0,0 +1,15 @@
on:
push:
name: Secret Leaks
jobs:
trufflehog:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main


@ -0,0 +1,16 @@
name: Upload PR Documentation
on:
workflow_run:
workflows: ["Build PR Documentation"]
types:
- completed
jobs:
build:
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
with:
package_name: trl
secrets:
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}

.gitignore

@ -142,4 +142,4 @@ checklink/cookies.txt
# wandb files
nbs/wandb/
examples/notebooks/wandb/
wandb/
wandb/

.pre-commit-config.yaml

@ -0,0 +1,17 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.3
hooks:
- id: ruff
types_or: [ python, pyi ]
args: [ --fix ]
- id: ruff-format
types_or: [ python, pyi ]
# - repo: https://github.com/codespell-project/codespell
# rev: v2.1.0
# hooks:
# - id: codespell
# args:
# - --ignore-words-list=nd,reacher,thist,ths,magent,ba
# - --skip=docs/css/termynal.css,docs/js/termynal.js


@ -17,7 +17,13 @@ authors:
family-names: Thrush
- given-names: Nathan
family-names: Lambert
repository-code: 'https://github.com/lvwerra/trl'
- given-names: Shengyi
family-names: Huang
- given-names: Kashif
family-names: Rasul
- given-names: Quentin
family-names: Gallouédec
repository-code: 'https://github.com/huggingface/trl'
abstract: "With trl you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the transformers library by \U0001F917 Hugging Face. Therefore, pre-trained language models can be directly loaded via transformers. At this point, most decoder and encoder-decoder architectures are supported."
keywords:
- rlhf
@ -25,4 +31,4 @@ keywords:
- pytorch
- transformers
license: Apache-2.0
version: 0.2.1
version: 0.12

CODE_OF_CONDUCT.md

@ -0,0 +1,133 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall
community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or advances of
any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address,
without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
feedback@huggingface.co.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series of
actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within the
community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
Community Impact Guidelines were inspired by
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
[https://www.contributor-covenant.org/translations][translations].
[homepage]: https://www.contributor-covenant.org
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
[Mozilla CoC]: https://github.com/mozilla/diversity
[FAQ]: https://www.contributor-covenant.org/faq
[translations]: https://www.contributor-covenant.org/translations


@ -1,48 +1,338 @@
# How to contribute
# How to contribute to TRL?
## How to get started
Everyone is welcome to contribute, and we value everybody's contribution. Code
contributions are not the only way to help the community. Answering questions, helping
others, and improving the documentation are also immensely valuable.
Before you start contributing make sure you installed all the dev tools:
It also helps us if you spread the word! Reference the library in blog posts
about the awesome projects it made possible, shout out on Twitter every time it has
helped you, or simply ⭐️ the repository to say thank you.
However you choose to contribute, please be mindful and respect our
[code of conduct](https://github.com/huggingface/trl/blob/main/CODE_OF_CONDUCT.md).
**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).**
## Ways to contribute
There are several ways you can contribute to TRL:
* Fix outstanding issues with the existing code.
* Submit issues related to bugs or desired new features.
* Implement trainers for new post-training algorithms.
* Contribute to the examples or the documentation.
If you don't know where to start, there is a special [Good First
Issue](https://github.com/huggingface/trl/contribute) listing. It will give you a list of
open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over.
For something slightly more challenging, you can also take a look at the [Good Second Issue](https://github.com/huggingface/trl/labels/Good%20Second%20Issue) list. In general though, if you feel like you know what you're doing, go for it and we'll help you get there! 🚀
> All contributions are equally valuable to the community. 🥰
Before you start contributing make sure you have installed all the dev tools:
```bash
pip install -e ".[dev]"
pip install -e .[dev]
```
## Did you find a bug?
## Fixing outstanding issues
* Ensure the bug was not already reported by searching on GitHub under Issues.
* If you're unable to find an open issue addressing the problem, open a new one. Be sure to include a title and clear description, as much relevant information as possible, and a code sample or an executable test case demonstrating the expected behavior that is not occurring.
* Be sure to add the complete error messages.
If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#create-a-pull-request) and open a Pull Request!
#### Did you write a patch that fixes a bug?
## Submitting a bug-related issue or feature request
* Open a new GitHub pull request with the patch.
* Ensure that your PR includes a test that fails without your patch, and passes with it.
* Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable.
Do your best to follow these guidelines when submitting a bug-related issue or a feature request. It will make it easier for us to come back to you quickly and with good feedback.
## PR submission guidelines
### Did you find a bug?
* Keep each PR focused. While it's more convenient, do not combine several unrelated fixes together. Create as many branches as needed to keep each PR focused.
* Do not mix style changes/fixes with "functional" changes. It's very difficult to review such PRs and they will most likely get rejected.
* Do not add/remove vertical whitespace. Preserve the original style of the file you edit as much as you can.
* Do not turn an already submitted PR into your development playground. If, after you submitted a PR, you discover that more work is needed, close the PR, do the required work, and then submit a new PR. Otherwise each of your commits requires attention from the maintainers of the project.
* If, however, you submitted a PR and received a request for changes, you should proceed with commits inside that PR, so that the maintainer can see the incremental fixes and won't need to review the whole PR again. In the exceptional case where you realize it will take many commits to complete the requests, it's probably best to close the PR, do the work, and then submit it again. Use common sense when choosing one way over the other.
The TRL library is robust and reliable thanks to users who report the problems they encounter.
### Before you submit a PR
Before you report an issue, we would really appreciate it if you could **make sure the bug was not
already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code.
First you want to make sure that all the tests pass:
Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so we can quickly resolve it:
* Your **OS type and version**, **Python**, **PyTorch**, **TRL** and **Transformers** versions.
* A short, self-contained, code snippet that allows us to reproduce the bug in
less than 30s.
* The *full* traceback if an exception is raised.
* Attach any other additional information, like screenshots, you think may help.
To get the OS and software versions automatically, run the following command:
```bash
make test
trl env
```
Then before submitting your PR make sure the code quality follows the standards. You can run the following command to format and test:
### Do you want a new feature?
If there is a new feature you'd like to see in TRL, please open an issue and describe:
1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it a feature related to something you need for a project? Is it something you worked on and think it could benefit the community?
Whatever it is, we'd love to hear about it!
2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better we'll be able to help you.
3. Provide a *code snippet* that demonstrates the feature's usage.
4. If the feature is related to a paper, please include a link.
If your issue is well written we're already 80% of the way there by the time you create it.
## Do you want to implement a new trainer?
New post-training methods are published frequently and those that satisfy the following criteria are good candidates to be integrated into TRL:
* **Simplicity:** Does the new method achieve similar performance as prior methods, but with less complexity? A good example is Direct Preference Optimization (DPO) [[Rafailov et al., 2023]](https://huggingface.co/papers/2305.18290), which provided a simpler and compelling alternative to RLHF methods.
* **Efficiency:** Does the new method provide a significant improvement in training efficiency? A good example is Odds Ratio Preference Optimization (ORPO) [[Hong et al., 2024]](https://huggingface.co/papers/2403.07691), which utilizes a similar objective to DPO but requires half the GPU VRAM.
Methods that only provide incremental improvements at the expense of added complexity or compute costs are unlikely to be included in TRL.
If you want to implement a trainer for a new post-training method, first open an issue and provide the following information:
* A short description of the method and a link to the paper.
* Link to the implementation if it is open-sourced.
* Link to model weights trained with the method if they are available.
Based on the community and maintainer feedback, the next step will be to implement the trainer and config classes. See the following examples for inspiration:
* Paired preference optimisation: [`dpo_trainer.py`](./trl/trainer/dpo_trainer.py) and [`dpo_config.py`](./trl/trainer/dpo_config.py)
* RL-based optimisation: [`rloo_trainer.py`](./trl/trainer/rloo_trainer.py) and [`rloo_config.py`](./trl/trainer/rloo_config.py)
* Online optimisation: [`online_dpo_trainer.py`](./trl/trainer/online_dpo_trainer.py) and [`online_dpo_config.py`](./trl/trainer/online_dpo_config.py)
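For orientation, here is a deliberately minimal, hypothetical sketch of the config/trainer pair such a contribution usually adds. The names `MyMethodConfig`, `MyMethodTrainer`, and the `beta` hyperparameter are placeholders rather than existing TRL APIs; real trainers add dataset handling, a proper loss, tests, and documentation on top of this skeleton.
```python
# Hypothetical skeleton for a new post-training method (names are placeholders, not TRL APIs).
from dataclasses import dataclass

from transformers import Trainer, TrainingArguments


@dataclass
class MyMethodConfig(TrainingArguments):
    """Training arguments for the hypothetical `MyMethodTrainer`."""

    beta: float = 0.1  # method-specific hyperparameter


class MyMethodTrainer(Trainer):
    """Light wrapper around the 🤗 Transformers `Trainer` with a method-specific loss."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        outputs = model(**inputs)
        # Placeholder objective: scale the standard language-modeling loss by `beta`.
        loss = self.args.beta * outputs.loss
        return (loss, outputs) if return_outputs else loss
```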
## Do you want to add documentation?
We're always looking for improvements to the documentation that make it more clear and accurate. Please let us know how the documentation can be improved, such as typos, dead links, and any missing, unclear, or inaccurate content... We'll be happy to make the changes or help you contribute if you're interested!
## Submitting a pull request (PR)
Before writing code, we strongly advise you to search through the existing PRs or
issues to make sure that nobody is already working on the same thing. If you are
unsure, it is always a good idea to open an issue to get some feedback.
You will need basic `git` proficiency to be able to contribute to
TRL. `git` is not the easiest tool to use but it has the greatest
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
Git](https://git-scm.com/book/en/v2) is a very good reference.
Follow these steps to start contributing:
1. Fork the [repository](https://github.com/huggingface/trl) by
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
under your GitHub user account.
2. Clone your fork to your local disk, and add the base repository as a remote. The following command
assumes you have your public SSH key uploaded to GitHub. See the following guide for more
[information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
```bash
$ git clone git@github.com:<your Github handle>/trl.git
$ cd trl
$ git remote add upstream https://github.com/huggingface/trl.git
```
3. Create a new branch to hold your development changes, and do this for every new PR you work on.
Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
```bash
$ git checkout main
$ git fetch upstream
$ git merge upstream/main
```
Once your `main` branch is synchronized, create a new branch from it:
```bash
$ git checkout -b a-descriptive-name-for-my-changes
```
**Do not** work on the `main` branch.
4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:
```bash
$ pip install -e .[dev]
```
(If TRL was already installed in the virtual environment, remove
it with `pip uninstall trl` before reinstalling it.)
Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using
the provided Dev Container. Documentation on how to get started with dev containers is available [here](https://code.visualstudio.com/docs/remote/containers).
5. Develop the features on your branch.
As you work on the features, you should make sure that the test suite
passes. You should run the tests impacted by your changes like this (see
below an explanation regarding the environment variable):
```bash
$ pytest tests/<TEST_TO_RUN>.py
```
> For the following commands leveraging the `make` utility, we recommend using the WSL system when running on
> Windows. More information [here](https://docs.microsoft.com/en-us/windows/wsl/about).
You can also run the full suite with the following command.
```bash
$ make test
```
TRL relies on `ruff` for maintaining consistent code formatting across its source files. Before submitting any PR, you should apply automatic style corrections and run code verification checks.
We provide a `precommit` target in the `Makefile` that simplifies this process by running all required checks and optimizations on only the files modified by your PR.
To apply these checks and corrections in one step, use:
```bash
$ make precommit
```
This command runs the following:
- Executes `pre-commit` hooks to automatically fix style issues with `ruff` and other tools.
- Runs additional scripts such as adding copyright information.
If you prefer to apply the style corrections separately or review them individually, the `pre-commit` hook will handle the formatting for the files in question.
Once you're happy with your changes, add changed files using `git add` and
make a commit with `git commit` to record your changes locally:
```bash
$ git add modified_file.py
$ git commit
```
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
It is a good idea to sync your copy of the code with the original
repository regularly. This way you can quickly account for changes:
```bash
$ git fetch upstream
$ git rebase upstream/main
```
Push the changes to your account using:
```bash
$ git push -u origin a-descriptive-name-for-my-changes
```
6. Once you are satisfied (**and the checklist below is happy too**), go to the
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
to the project maintainers for review.
7. It's ok if maintainers ask you for changes. It happens to core contributors too! To ensure everyone can review your changes in the pull request, work on your local branch and push the updates to your fork. They will automatically appear in the pull request.
### Checklist
1. The title of your pull request should be a summary of its contribution;
2. If your pull request addresses an issue, please mention the issue number in
the pull request description to make sure they are linked (and people
consulting the issue know you are working on it);
3. To indicate a work in progress please prefix the title with `[WIP]`, or mark
the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate
it from PRs ready to be merged;
4. Make sure existing tests pass;
5. Add high-coverage tests. No quality testing = no merge.
### Tests
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in
the [tests folder](https://github.com/huggingface/trl/tree/main/tests).
We use `pytest` to run the tests. From the root of the
repository here's how to run tests with `pytest` for the library:
```bash
make style && make quality
$ python -m pytest -sv ./tests
```
## Do you want to contribute to the documentation?
That's how `make test` is implemented (without the `pip install` line)!
* Docs are in the `docs/` folder and can be updated there.
You can specify a smaller set of tests to test only the feature
you're working on.
### Deprecation and Backward Compatibility
Our approach to deprecation and backward compatibility is flexible and based on each feature's usage and impact. Each deprecation is carefully evaluated, aiming to balance innovation with user needs.
When a feature or component is marked for deprecation, its use will emit a warning message. This warning will include:
- **Transition Guidance**: Instructions on how to migrate to the alternative solution or replacement.
- **Removal Version**: The target version when the feature will be removed, providing users with a clear timeframe to transition.
Example:
```python
warnings.warn(
"The `Trainer.foo` method is deprecated and will be removed in version 0.14.0. "
"Please use the `Trainer.bar` class instead.",
FutureWarning,
)
```
The deprecation and removal schedule is based on each feature's usage and impact, with examples at two extremes:
- **Experimental or Low-Use Features**: For a feature that is experimental or has limited usage, backward compatibility may not be maintained between releases. Users should therefore anticipate potential breaking changes from one version to the next.
- **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling deprecation around **5 minor releases** after the initial warning.
These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs.
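For users who are mid-migration, such a `FutureWarning` can be silenced temporarily with the standard `warnings` module. The sketch below is self-contained and uses a made-up `deprecated_helper` function as a stand-in for a deprecated API:
```python
import warnings


def deprecated_helper():
    warnings.warn(
        "`deprecated_helper` is deprecated and will be removed in version 0.14.0. "
        "Please use the replacement API instead.",
        FutureWarning,
    )


# Suppress the deprecation warning only inside this block while migrating.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", FutureWarning)
    deprecated_helper()  # no FutureWarning is emitted here
```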
### Working with warnings
Warnings play a critical role in guiding users toward resolving potential issues, but they should be used thoughtfully to avoid unnecessary noise. Unlike logging, which provides informational context or operational details, warnings signal conditions that require attention and action. Overusing warnings can dilute their importance, leading users to ignore them entirely.
#### Definitions
- **Correct**: An operation is correct if it is valid, follows the intended approach, and aligns with the current best practices or guidelines within the codebase. This is the recommended or intended way to perform the operation.
- **Supported**: An operation is supported if it is technically valid and works within the current codebase, but it may not be the most efficient, optimal, or recommended way to perform the task. This includes deprecated features or legacy approaches that still work but may be phased out in the future.
#### Choosing the right message
- **Correct → No warning**:
If the operation is fully valid and expected, no message should be issued. The system is working as intended, so no warning is necessary.
- **Correct but deserves attention → No warning, possibly a log message**:
When an operation is correct but uncommon or requires special attention, providing an informational message can be helpful. This keeps users informed without implying any issue. If available, use the logger to output this message. Example:
```python
logger.info("This is an informational message about a rare but correct operation.")
```
- **Correct but very likely a mistake → Warning with option to disable**:
In rare cases, you may want to issue a warning for a correct operation that's very likely a mistake. In such cases, you must provide an option to suppress the warning. This can be done with a flag in the function. Example:
```python
def my_function(foo, bar, _warn=True):
if foo == bar:
if _warn:
warnings.warn("foo and bar are the same, this is likely a mistake. Ignore this warning by setting `_warn=False`.")
# Do something
```
- **Supported but not correct → Warning**:
If the operation is technically supported but is deprecated, suboptimal, or could cause future issues (e.g., conflicting arguments), a warning should be raised. This message should be actionable, meaning it must explain how to resolve the issue. Example:
```python
def my_function(foo, bar):
if foo and bar:
warnings.warn("Both `foo` and `bar` were provided, but only one is allowed. Ignoring `foo`. Please pass only one of these arguments.")
# Do something
```
- **Not supported → Exception**:
If the operation is invalid or unsupported, raise an exception. This indicates that the operation cannot be performed and requires immediate attention. Example:
```python
def my_function(foo, bar):
if foo and bar:
raise ValueError("Both `foo` and `bar` were provided, but only one is allowed. Please pass only one of these arguments.")
```
By following this classification, you ensure that warnings, information, and exceptions are used appropriately, providing clear guidance to the user without cluttering the system with unnecessary messages.


@ -3,3 +3,4 @@ include LICENSE
include CONTRIBUTING.md
include README.md
recursive-exclude * __pycache__
include trl/templates/*.md


@ -1,15 +1,32 @@
.PHONY: quality style test
.PHONY: test precommit common_tests slow_tests test_examples tests_gpu
check_dirs := examples tests trl
ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
COMMAND_FILES_PATH = `pwd`/commands
test:
python -m pytest -n auto --dist=loadfile -s -v ./tests/
python -m pytest -n auto --dist=loadfile -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504|not less than or equal to 0.01)' ./tests/
quality:
black --check --line-length 119 --target-version py38 $(check_dirs)
isort --check-only $(check_dirs)
flake8 $(check_dirs)
precommit:
pre-commit run --all-files
python scripts/add_copyrights.py
style:
black --line-length 119 --target-version py38 $(check_dirs)
isort $(check_dirs)
tests_gpu:
python -m pytest tests/test_* $(if $(IS_GITHUB_CI),--report-log "common_tests.log",)
slow_tests:
python -m pytest tests/slow/test_* $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
test_examples:
touch temp_results_sft_tests.txt
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_sft.sh; \
echo $$?','$${file} >> temp_results_sft_tests.txt; \
done
touch temp_results_dpo_tests.txt
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_dpo.sh; \
echo $$?','$${file} >> temp_results_dpo_tests.txt; \
done

README.md

@ -1,121 +1,219 @@
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_banner_dark.png">
</div>
# TRL - Transformer Reinforcement Learning
> Train transformer language models with reinforcement learning.
## What is it?
With `trl` you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via `transformers`. At this point, most decoder and encoder-decoder architectures are supported.
**Highlights:**
- `PPOTrainer`: A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.
- `AutoModelForCausalLMWithValueHead` & `AutoModelForSeq2SeqLMWithValueHead`: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
- Example: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier.
## How it works
Fine-tuning a language model via PPO consists of roughly three steps:
1. **Rollout**: The language model generates a response or continuation based on a query, which could be the start of a sentence.
2. **Evaluation**: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
3. **Optimization**: This is the most complex part. In the optimisation step, the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.
This process is illustrated in the sketch below:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png" width="800">
<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_banner_dark.png" alt="TRL Banner">
</div>
<hr> <br>
<h3 align="center">
<p>A comprehensive library to post-train foundation models</p>
</h3>
<p align="center">
<a href="https://github.com/huggingface/trl/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue"></a>
<a href="https://huggingface.co/docs/trl/index"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/trl/index.svg?down_color=red&down_message=offline&up_color=blue&up_message=online"></a>
<a href="https://github.com/huggingface/trl/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg"></a>
</p>
## Overview
TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the [🤗 Transformers](https://github.com/huggingface/transformers) ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled up across various hardware setups.
## Highlights
- **Efficient and scalable**:
- Leverages [🤗 Accelerate](https://github.com/huggingface/accelerate) to scale from single GPU to multi-node clusters using methods like DDP and DeepSpeed.
- Full integration with [`PEFT`](https://github.com/huggingface/peft) enables training on large models with modest hardware via quantization and LoRA/QLoRA.
- Integrates [Unsloth](https://github.com/unslothai/unsloth) for accelerating training using optimized kernels.
- **Command Line Interface (CLI)**: A simple interface lets you fine-tune and interact with models without needing to write code.
- **Trainers**: Various fine-tuning methods are easily accessible via trainers like [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`ORPOTrainer`](https://huggingface.co/docs/trl/orpo_trainer) and more.
- **AutoModels**: Use pre-defined model classes like [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) to simplify reinforcement learning (RL) with LLMs.
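For instance, a value-head model can be loaded and inspected in a few lines. This is only a minimal sketch: the model name is illustrative, and the forward pass returns the language-model logits, an optional loss, and one scalar value per token.
```python
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLMWithValueHead.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

inputs = tokenizer("TRL is a library for", return_tensors="pt")
lm_logits, _, values = model(**inputs)  # `values` holds one scalar per input token
print(lm_logits.shape, values.shape)
```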
## Installation
### Python package
Install the library with pip:
### Python Package
Install the library using `pip`:
```bash
pip install trl
```
### From source
If you want to run the examples in the repository a few additional libraries are required. Clone the repository and install it with pip:
If you want to use the latest features before an official release, you can install TRL from source:
```bash
git clone https://github.com/lvwerra/trl.git
cd trl/
pip install .
pip install git+https://github.com/huggingface/trl.git
```
If you wish to develop TRL, you should install in editable mode:
### Repository
If you want to use the examples you can clone the repository with the following command:
```bash
pip install -e .
git clone https://github.com/huggingface/trl.git
```
## Command Line Interface (CLI)
You can use the TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT) and Direct Preference Optimization (DPO), or vibe check your model with the chat CLI:
**SFT:**
```bash
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name trl-lib/Capybara \
--output_dir Qwen2.5-0.5B-SFT
```
**DPO:**
```bash
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--dataset_name argilla/Capybara-Preferences \
--output_dir Qwen2.5-0.5B-DPO
```
**Chat:**
```bash
trl chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
```
Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.
## How to use
### Example
This is a basic example of how to use the library. Based on a query, the language model creates a response, which is then evaluated. The evaluation could be a human in the loop or another model's output.
For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.
### `SFTTrainer`
Here is a basic example of how to use the `SFTTrainer`:
```python
# imports
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset
# get models
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = create_reference_model(model)
dataset = load_dataset("trl-lib/Capybara", split="train")
tokenizer = AutoTokenizer.from_pretrained('gpt2')
# initialize trainer
ppo_config = PPOConfig(
batch_size=1,
training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(
args=training_args,
model="Qwen/Qwen2.5-0.5B",
train_dataset=dataset,
)
# encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")
# get model response
response_tensor = respond_to_batch(model_ref, query_tensor)
# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)
# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0)]
# train model for one step with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
trainer.train()
```
### Advanced example: IMDB sentiment
For a detailed example, check out the example Python script `examples/sentiment/scripts/gpt2-sentiment.py`, where GPT2 is fine-tuned to generate positive movie reviews. A few examples from the language model before and after optimisation are given below:
### `RewardTrainer`
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/table_imdb_preview.png" width="800">
<p style="text-align: center;"> <b>Figure:</b> A few review continuations before and after optimisation. </p>
</div>
Here is a basic example of how to use the `RewardTrainer`:
## References
```python
from trl import RewardConfig, RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
### Proximal Policy Optimisation
The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained(
"Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
model.config.pad_token_id = tokenizer.pad_token_id
### Language models
The language models utilize the `transformers` library by 🤗 Hugging Face.
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
trainer = RewardTrainer(
args=training_args,
model=model,
processing_class=tokenizer,
train_dataset=dataset,
)
trainer.train()
```
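Once training is done, the reward model from the example above can be used to score a conversation. The snippet below continues that example and assumes the chat formatting matches the one used during training:
```python
import torch

# Score a single prompt/completion pair with the trained reward model.
text = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
    ],
    tokenize=False,
)
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    reward = model(**inputs).logits[0, 0].item()
print(f"Reward score: {reward:.3f}")
```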
### `RLOOTrainer`
`RLOOTrainer` implements a [REINFORCE-style optimization](https://huggingface.co/papers/2402.14740) for RLHF that is more performant and memory-efficient than PPO. Here is a basic example of how to use the `RLOOTrainer`:
```python
from trl import RLOOConfig, RLOOTrainer, apply_chat_template
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
reward_model = AutoModelForSequenceClassification.from_pretrained(
"Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
ref_policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback-prompt")
dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
dataset = dataset.map(lambda x: tokenizer(x["prompt"]), remove_columns="prompt")
training_args = RLOOConfig(output_dir="Qwen2.5-0.5B-RL")
trainer = RLOOTrainer(
config=training_args,
processing_class=tokenizer,
policy=policy,
ref_policy=ref_policy,
reward_model=reward_model,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
)
trainer.train()
```
### `DPOTrainer`
`DPOTrainer` implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train Llama 3 and many other models. Here is a basic example of how to use the `DPOTrainer`:
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(model=model, args=training_args, train_dataset=dataset, processing_class=tokenizer)
trainer.train()
```
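After training, the fine-tuned model from the example above can be sanity-checked with a quick generation. This is a sketch continuing that example; the prompt and generation settings are illustrative:
```python
# Generate with the DPO fine-tuned model from the example above.
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Write a haiku about open source."}],
    tokenize=False,
    add_generation_prompt=True,
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```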
## Development
If you want to contribute to `trl` or customize it to your needs, make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and do a dev install:
```bash
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .[dev]
```
## Citation
```bibtex
@misc{vonwerra2022trl,
author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert},
author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
title = {TRL: Transformer Reinforcement Learning},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/lvwerra/trl}}
howpublished = {\url{https://github.com/huggingface/trl}}
}
```
## License
This repository's source code is available under the [Apache-2.0 License](LICENSE).

commands/run_dpo.sh

@ -0,0 +1,58 @@
#!/bin/bash
# This script runs a DPO example end-to-end on a tiny model using different possible configurations
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_dpo/"
MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
DATASET_NAME="trl-internal-testing/hh-rlhf-helpful-base-trl-style"
MAX_STEPS=5
BATCH_SIZE=2
SEQ_LEN=128
# Handle extra arguments in case one passes accelerate configs.
EXTRA_ACCELERATE_ARGS=""
EXTRA_TRAINING_ARGS="""--use_peft \
--load_in_4bit
"""
# Set your number of GPUs here
NUM_GPUS=2
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
EXTRA_ACCELERATE_ARGS=""
else
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
EXTRA_TRAINING_ARGS="--fp16"
else
echo "Keeping QLoRA + PEFT"
fi
fi
CMD="""
accelerate launch $EXTRA_ACCELERATE_ARGS \
--num_processes $NUM_GPUS \
--mixed_precision 'fp16' \
`pwd`/trl/scripts/dpo.py \
--model_name_or_path $MODEL_NAME \
--dataset_name $DATASET_NAME \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BATCH_SIZE \
--max_length $SEQ_LEN \
$EXTRA_TRAINING_ARGS
"""
echo "Starting program..."
{ # try
echo $CMD
eval "$CMD"
} || { # catch
# save log for exception
echo "Operation Failed!"
exit 1
}
exit 0

commands/run_sft.sh

@ -0,0 +1,59 @@
#!/bin/bash
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_sft/"
MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
DATASET_NAME="stanfordnlp/imdb"
MAX_STEPS=5
BATCH_SIZE=2
SEQ_LEN=128
# Handle extra arguments in case one passes accelerate configs.
EXTRA_ACCELERATE_ARGS=""
EXTRA_TRAINING_ARGS="""--use_peft \
--load_in_4bit
"""
# Set your number of GPUs here
NUM_GPUS=2
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
EXTRA_ACCELERATE_ARGS=""
else
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
EXTRA_TRAINING_ARGS="--fp16"
else
echo "Keeping QLoRA + PEFT"
fi
fi
CMD="""
accelerate launch $EXTRA_ACCELERATE_ARGS \
--num_processes $NUM_GPUS \
--mixed_precision 'fp16' \
`pwd`/trl/scripts/sft.py \
--model_name_or_path $MODEL_NAME \
--dataset_name $DATASET_NAME \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BATCH_SIZE \
--max_seq_length $SEQ_LEN \
$EXTRA_TRAINING_ARGS
"""
echo "Starting program..."
{ # try
echo $CMD
eval "$CMD"
} || { # catch
# save log for exception
echo "Operation Failed!"
exit 1
}
exit 0


@ -0,0 +1,66 @@
# Builds GPU docker image of PyTorch
# Uses multi-staged approach to reduce size
# Stage 1
# Use base conda image to reduce time
FROM continuumio/miniconda3:latest AS compile-image
# Specify py version
ENV PYTHON_VERSION=3.10
# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
RUN apt-get update && \
apt-get install -y curl git wget software-properties-common git-lfs && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
# Install audio-related libraries
RUN apt-get update && \
apt install -y ffmpeg
RUN apt install -y libsndfile1-dev
RUN git lfs install
# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
RUN python3 -m pip install --no-cache-dir --upgrade pip
# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
# We don't install pytorch here yet since CUDA isn't available
# instead we use the direct torch wheel
ENV PATH /opt/conda/envs/trl/bin:$PATH
# Activate our bash shell
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]
# Stage 2
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
COPY --from=compile-image /opt/conda /opt/conda
ENV PATH /opt/conda/bin:$PATH
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]
RUN source activate trl && \
python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
# Install apt libs
RUN apt-get update && \
apt-get install -y curl git wget && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
# Activate the conda env and install transformers + accelerate from source
RUN source activate trl && \
python3 -m pip install -U --no-cache-dir \
librosa \
"soundfile>=0.12.1" \
scipy \
transformers \
accelerate \
peft \
trl[test]@git+https://github.com/huggingface/trl
RUN source activate trl && \
pip freeze | grep trl
RUN echo "source activate trl" >> ~/.profile
# Activate the virtualenv
CMD ["/bin/bash"]

View File

@ -0,0 +1,66 @@
# Builds GPU docker image of PyTorch
# Uses multi-staged approach to reduce size
# Stage 1
# Use base conda image to reduce time
FROM continuumio/miniconda3:latest AS compile-image
# Specify py version
ENV PYTHON_VERSION=3.10
# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
RUN apt-get update && \
apt-get install -y curl git wget software-properties-common git-lfs && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
# Install audio-related libraries
RUN apt-get update && \
apt install -y ffmpeg
RUN apt install -y libsndfile1-dev
RUN git lfs install
# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
RUN python3 -m pip install --no-cache-dir --upgrade pip
# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
# We don't install pytorch here yet since CUDA isn't available
# instead we use the direct torch wheel
ENV PATH /opt/conda/envs/trl/bin:$PATH
# Activate our bash shell
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]
# Stage 2
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
COPY --from=compile-image /opt/conda /opt/conda
ENV PATH /opt/conda/bin:$PATH
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]
RUN source activate trl && \
python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
# Install apt libs
RUN apt-get update && \
apt-get install -y curl git wget && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
# Activate the conda env and install transformers + accelerate from source
RUN source activate trl && \
python3 -m pip install -U --no-cache-dir \
librosa \
"soundfile>=0.12.1" \
scipy \
git+https://github.com/huggingface/transformers \
git+https://github.com/huggingface/accelerate \
git+https://github.com/huggingface/peft \
trl[test]@git+https://github.com/huggingface/trl
RUN source activate trl && \
pip freeze | grep transformers
RUN echo "source activate trl" >> ~/.profile
# Activate the virtualenv
CMD ["/bin/bash"]

View File

@ -1,26 +1,90 @@
- sections:
- sections:
- local: index
title: TRL
- local: quickstart
title: Quickstart
- local: installation
title: Installation
- local: quickstart
title: Quickstart
- local: clis
title: Get started with Command Line Interfaces (CLIs)
- local: dataset_formats
title: Dataset Formats
- local: how_to_train
title: PPO Training FAQ
- local: use_model
title: Use Trained Models
- local: customization
title: Customize your training
title: Customize the Training
- local: logging
title: Understanding Logs
title: Get started
- sections:
- sections: # Sort alphabetically
- local: alignprop_trainer
title: AlignProp
- local: bco_trainer
title: BCO
- local: cpo_trainer
title: CPO
- local: ddpo_trainer
title: DDPO
- local: dpo_trainer
title: DPO
- local: online_dpo_trainer
title: Online DPO
- local: gkd_trainer
title: GKD
- local: kto_trainer
title: KTO
- local: nash_md_trainer
title: Nash-MD
- local: orpo_trainer
title: ORPO
- local: ppo_trainer
title: PPO
- local: prm_trainer
title: PRM
- local: reward_trainer
title: Reward
- local: rloo_trainer
title: RLOO
- local: sft_trainer
title: SFT
- local: iterative_sft_trainer
title: Iterative SFT
- local: xpo_trainer
title: XPO
title: Trainers
- local: models
title: Model Classes
- local: trainer
title: Trainer Classes
- local: best_of_n
title: Best of N Sampling
- local: judges
title: Judges
- local: callbacks
title: Callbacks
- local: data_utils
title: Data Utilities
- local: text_environments
title: Text Environments
- local: script_utils
title: Script Utilities
title: API
- sections:
- sections:
- local: community_tutorials
title: Community Tutorials
- local: example_overview
title: Example Overview
- local: sentiment_tuning
title: Sentiment Tuning
- local: sentiment_tuning_peft
title: Peft support - Low rank adaption of 8 bit models
- local: summarization_reward_tuning
title: Summarization Reward Tuning
- local: lora_tuning_peft
title: Training with PEFT
- local: detoxifying_a_lm
title: Detoxifying a Language Model
- local: using_llama_models
title: Training StackLlama
- local: learning_tools
title: Learning to Use Tools
- local: multi_adapter_rl
title: Multi Adapter RLHF
title: Examples

View File

@ -0,0 +1,93 @@
# Aligning Text-to-Image Diffusion Models with Reward Backpropagation
[![](https://img.shields.io/badge/All_models-AlignProp-blue)](https://huggingface.co/models?other=alignprop,trl)
## The why
If your reward function is differentiable, directly backpropagating gradients from the reward model to the diffusion model is significantly (25x) more sample- and compute-efficient than policy gradient algorithms like DDPO.
AlignProp does full backpropagation through time, which allows updating the earlier steps of denoising via reward backpropagation.
<div style="text-align: center"><img src="https://align-prop.github.io/reward_tuning.png"/></div>
## Getting started with `examples/scripts/alignprop.py`
The `alignprop.py` script is a working example of using the `AlignProp` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`AlignPropConfig`).
**Note:** one A100 GPU is recommended to get this running. For a lower-memory setting, consider setting `truncated_backprop_rand` to `False`. With default settings this will do truncated backpropagation with K=1.
Almost every configuration parameter has a default. Only one command-line argument is required to get things up and running: the user is expected to have a [Hugging Face user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the fine-tuned model to the Hugging Face Hub. Enter the following bash command to get things running:
```bash
python alignprop.py --hf_user_access_token <token>
```
To obtain the documentation of `alignprop.py`, please run `python alignprop.py --help`
The following are things to keep in mind (the code checks this for you as well) when configuring the trainer in general, beyond the use case of the example script (see the sketch below):
- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`): the first number should be greater than or equal to 0, while the second should be less than or equal to the number of diffusion timesteps (`sample_num_steps`).
- The configurable truncated backprop absolute step (`--alignprop_config.truncated_backprop_timestep=49`): this number should be less than the number of diffusion timesteps (`sample_num_steps`), and it only matters when `truncated_backprop_rand` is set to `False`.
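As a rough illustration of these constraints, below is a minimal sketch of an `AlignPropConfig`; the field names mirror the flags above and are assumed to match the config object, and the values are only illustrative.
```python
from trl import AlignPropConfig

# Minimal sketch, not a tuned configuration
config = AlignPropConfig(
    sample_num_steps=50,                     # number of diffusion timesteps
    truncated_backprop_rand=True,            # set to False for a lower-memory, fixed-step truncation
    truncated_rand_backprop_minmax=(0, 50),  # lower bound >= 0, upper bound <= sample_num_steps
    truncated_backprop_timestep=49,          # must be < sample_num_steps; only used when truncated_backprop_rand=False
)
```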
## Setting up the image logging hook function
Expect the function to be given a dictionary with keys
```python
['image', 'prompt', 'prompt_metadata', 'rewards']
```
and `image`, `prompt`, `prompt_metadata`, `rewards` are batched.
You are free to log however you want; the use of `wandb` or `tensorboard` is recommended.
### Key terms
- `rewards` : The reward/score is a numerical value associated with the generated image and is key to steering the RL process
- `prompt` : The prompt is the text that is used to generate the image
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup, where questions and ground-truth answers (linked to the generated image) are expected alongside the generated image (see here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
- `image` : The image generated by the Stable Diffusion model
Example code for logging sampled images with `wandb` is given below.
```python
import numpy as np
from PIL import Image

# for logging these images to wandb
def image_outputs_hook(image_data, global_step, accelerate_logger):
# For the sake of this example, we only care about the last batch
# hence we extract the last element of the list
result = {}
images, prompts, rewards = [image_data['images'],image_data['prompts'],image_data['rewards']]
for i, image in enumerate(images):
pil = Image.fromarray(
(image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
)
pil = pil.resize((256, 256))
result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
accelerate_logger.log_images(
result,
step=global_step,
)
```
### Using the finetuned model
Assuming you're done with all the epochs and have pushed your model up to the Hub, you can use the finetuned model as follows:
```python
from diffusers import StableDiffusionPipeline
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipeline.to("cuda")
pipeline.load_lora_weights('mihirpd/alignprop-trl-aesthetics')
prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
results = pipeline(prompts)
for prompt, image in zip(prompts,results.images):
image.save(f"dump/{prompt}.png")
```
## Credits
This work is heavily influenced by the repo [here](https://github.com/mihirp1998/AlignProp/) and the associated paper [Aligning Text-to-Image Diffusion Models with Reward Backpropagation
by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki](https://huggingface.co/papers/2310.03739).

100
docs/source/bco_trainer.mdx Normal file
View File

@ -0,0 +1,100 @@
# BCO Trainer
[![](https://img.shields.io/badge/All_models-BCO-blue)](https://huggingface.co/models?other=bco,trl)
TRL supports the Binary Classifier Optimization (BCO).
The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0.
For a full example have a look at [`examples/scripts/bco.py`].
## Expected dataset type
The [`BCOTrainer`] requires an [unpaired preference dataset](dataset_formats#unpaired-preference).
The [`BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
## Expected model format
The BCO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
## Using the `BCOTrainer`
For a detailed example have a look at the `examples/scripts/bco.py` script. At a high level we need to initialize the `BCOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.
The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (i.e. decoder-only or encoder-decoder).
```py
training_args = BCOConfig(
beta=0.1,
)
bco_trainer = BCOTrainer(
model,
model_ref,
args=training_args,
train_dataset=train_dataset,
processing_class=tokenizer,
)
```
After this one can then call:
```py
bco_trainer.train()
```
## Underlying Distribution matching (UDM)
In practical scenarios, the thumbs-up and thumbs-down datasets are likely to have divergent underlying distributions of prompts.
Consider an LLM deployed for user feedback: if the model excels in writing tasks but underperforms in coding, the thumbs-up dataset will be dominated by writing-related prompts, while the thumbs-down dataset will contain mostly coding-related prompts.
If the prompts in your desired and undesired datasets differ a lot, it is useful to enable UDM.
Choose an embedding model and tokenizer:
```py
from functools import partial

from accelerate import Accelerator
from transformers import AutoModel, AutoTokenizer

embedding_model = AutoModel.from_pretrained(your_model_id)
embedding_tokenizer = AutoTokenizer.from_pretrained(your_model_id)

# customize this function depending on your embedding model
def embed_prompt(input_ids, attention_mask, model):
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    return outputs.last_hidden_state.mean(dim=1)

embedding_model = Accelerator().prepare_model(embedding_model)
embedding_func = partial(embed_prompt, model=embedding_model)
```
Set `prompt_sample_size` to define how many prompts are selected to train the UDM classifier and start the training with the provided embedding function:
```py
training_args = BCOConfig(
beta=0.1,
prompt_sample_size=512,
)
bco_trainer = BCOTrainer(
model,
model_ref,
args=training_args,
train_dataset=train_dataset,
processing_class=tokenizer,
embedding_func=embedding_func,
    embedding_tokenizer=embedding_tokenizer,
)
bco_trainer.train()
```
### For Mixture of Experts Models: Enabling the auxiliary loss
MOEs are the most efficient if the load is about equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
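As a hedged sketch (assuming a Mixtral-style checkpoint and that `from_pretrained` forwards these overrides to the underlying model config), enabling both options could look like:
```py
from transformers import AutoModelForCausalLM

# Sketch only: "mistralai/Mixtral-8x7B-v0.1" is an example MoE checkpoint
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    output_router_logits=True,   # expose router logits so the auxiliary load-balancing loss is computed
    router_aux_loss_coef=0.001,  # weight of the auxiliary loss in the total loss
)
```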
## BCOTrainer
[[autodoc]] BCOTrainer
## BCOConfig
[[autodoc]] BCOConfig

72
docs/source/best_of_n.mdx Normal file
View File

@ -0,0 +1,72 @@
# Best of N sampling: Alternative ways to get better model output without RL based fine-tuning
Within the extras module is the `best-of-n` sampler class that serves as an alternative method of generating better model output.
To see how it fares against RL-based fine-tuning, have a look at the comparison example in the `examples` directory.
## Usage
To get started quickly, instantiate an instance of the class with a model, a length sampler, a tokenizer and a callable that serves as a proxy reward pipeline that outputs reward scores for input queries
```python
from transformers import pipeline, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler
from trl.extras import BestOfNSampler
# `model`, `ref_model_name`, `reward_model`, `device` and `output_length_sampler` are assumed to be defined beforehand
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)
reward_pipe = pipeline("sentiment-analysis", model=reward_model, device=device)
tokenizer = AutoTokenizer.from_pretrained(ref_model_name)
tokenizer.pad_token = tokenizer.eos_token
# callable that takes a list of raw text and returns a list of corresponding reward scores
def queries_to_scores(list_of_strings):
return [output["score"] for output in reward_pipe(list_of_strings)]
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler)
```
And assuming you have a list/tensor of tokenized queries, you can generate better output by calling the `generate` method
```python
best_of_n.generate(query_tensors, device=device, **gen_kwargs)
```
The default sample size is 4, but you can change it at the time of instance initialization like so
```python
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, sample_size=8)
```
The default output is the result of taking the top scored output for each query, but you can change it to top 2 and so on by passing the `n_candidates` argument at the time of instance initialization
```python
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, n_candidates=2)
```
There is the option of setting the generation settings (like `temperature`, `pad_token_id`) at the time of instance creation as opposed to when calling the `generate` method.
This is done by passing a `GenerationConfig` from the `transformers` library at the time of initialization
```python
from transformers import GenerationConfig
generation_config = GenerationConfig(min_length=-1, top_k=0.0, top_p=1.0, do_sample=True, pad_token_id=tokenizer.eos_token_id)
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, generation_config=generation_config)
best_of_n.generate(query_tensors, device=device)
```
Furthermore, at the time of initialization you can set the seed to control the repeatability of the generation process and the number of samples to generate for each query, as sketched below.
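Below is a minimal sketch of such an initialization, assuming the constructor accepts a `seed` keyword as described above and reusing the objects defined earlier.
```python
best_of_n = BestOfNSampler(
    model,
    tokenizer,
    queries_to_scores,
    length_sampler=output_length_sampler,
    sample_size=8,  # number of candidate generations per query
    seed=42,        # assumed keyword controlling the repeatability of generation
)
```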

21
docs/source/callbacks.mdx Normal file
View File

@ -0,0 +1,21 @@
# Callbacks
## SyncRefModelCallback
[[autodoc]] SyncRefModelCallback
## RichProgressCallback
[[autodoc]] RichProgressCallback
## WinRateCallback
[[autodoc]] WinRateCallback
## LogCompletionsCallback
[[autodoc]] LogCompletionsCallback
## MergeModelCallback
[[autodoc]] MergeModelCallback

175
docs/source/clis.mdx Normal file
View File

@ -0,0 +1,175 @@
# Command Line Interfaces (CLIs)
You can use TRL to fine-tune your Language Model with Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO), or even chat with your model using the TRL CLIs.
Currently supported CLIs are:
#### Training commands
- `trl dpo`: fine-tune an LLM with DPO
- `trl kto`: fine-tune an LLM with KTO
- `trl sft`: fine-tune an LLM with SFT
#### Other commands
- `trl chat`: quickly spin up an LLM fine-tuned for chatting
- `trl env`: get the system information
## Fine-tuning with the CLI
Before getting started, pick a Language Model from the Hugging Face Hub. Supported models can be found with the "text-generation" filter on the models page. Also make sure to pick a relevant dataset for your task.
Before using the `sft` or `dpo` commands make sure to run:
```bash
accelerate config
```
and pick the right configuration for your training setup (single / multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of `accelerate config` before running any CLI command.
We also recommend passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with the `trl sft` command.
```yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
report_to: none
learning_rate: 0.0001
lr_scheduler_type: cosine
```
Save that config in a `.yaml` file and get started immediately! An example CLI config is available at `examples/cli_configs/example_config.yaml`. Note that you can overwrite the arguments from the config file by explicitly passing them to the CLI, e.g. from the root folder:
```bash
trl sft --config examples/cli_configs/example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts
```
This will force the use of `cosine_with_restarts` for `lr_scheduler_type`.
### Supported Arguments
We support all arguments from `transformers.TrainingArguments`. For loading your model, we also support all arguments from `~trl.ModelConfig`:
[[autodoc]] ModelConfig
You can pass any of these arguments either to the CLI or the YAML file.
### Supervised Fine-tuning (SFT)
Follow the basic instructions above and run `trl sft --output_dir <output_dir> <*args>`:
```bash
trl sft --model_name_or_path facebook/opt-125m --dataset_name stanfordnlp/imdb --output_dir opt-sft-imdb
```
The SFT CLI is based on the `trl/scripts/sft.py` script.
### Direct Preference Optimization (DPO)
To use the DPO CLI, you need to have a dataset in the TRL format such as
* TRL's Anthropic HH dataset: https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-helpful-base-trl-style
* TRL's OpenAI TL;DR summarization dataset: https://huggingface.co/datasets/trl-internal-testing/tldr-preference-trl-style
These datasets always have at least three columns `prompt, chosen, rejected`:
* `prompt` is a list of strings.
* `chosen` is the chosen response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
* `rejected` is the rejected response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
To do a quick start, you can run the following command:
```bash
trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style
```
The DPO CLI is based on the `trl/scripts/dpo.py` script.
#### Custom preference dataset
Format the dataset into TRL format (you can adapt the `examples/datasets/anthropic_hh.py`):
```bash
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
```
## Chat interface
The chat CLI lets you quickly load the model and talk to it. Simply run the following:
<pre><code>$ trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?
<strong><span style="color: blue;">&lt;Qwen/Qwen1.5-0.5B-Chat&gt;:</span></strong>
There isn't a "best" programming language, as everyone has different style preferences, needs, and preferences. However, some people commonly use
languages like Python, Java, C++, and JavaScript, which are popular among developers for a variety of reasons, including readability, flexibility,
and scalability. Ultimately, it depends on personal preference, needs, and goals.
</code></pre>
Note that the chat interface relies on the tokenizer's [chat template](https://huggingface.co/docs/transformers/chat_templating) to format the inputs for the model. Make sure your tokenizer has a chat template defined.
Besides talking to the model there are a few commands you can use:
- `clear`: clears the current conversation and starts a new one
- `example {NAME}`: load example named `{NAME}` from the config and use it as the user input
- `set {SETTING_NAME}={SETTING_VALUE};`: change the system prompt or generation settings (multiple settings are separated by a `;`).
- `reset`: same as clear but also resets the generation configs to defaults if they have been changed by `set`
- `save` or `save {SAVE_NAME}`: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
- `exit`: closes the interface
## Getting the system information
You can get the system information by running the following command:
```bash
trl env
```
This will print out the system information, including the GPU information, the CUDA version, the PyTorch version, the Transformers version, the TRL version, and any optional dependencies that are installed.
```txt
Copy-paste the following information when reporting an issue:
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Python version: 3.11.9
- PyTorch version: 2.4.1
- CUDA device: NVIDIA H100 80GB HBM3
- Transformers version: 4.45.0.dev0
- Accelerate version: 0.34.2
- Accelerate config:
- compute_environment: LOCAL_MACHINE
- distributed_type: DEEPSPEED
- mixed_precision: no
- use_cpu: False
- debug: False
- num_processes: 4
- machine_rank: 0
- num_machines: 1
- rdzv_backend: static
- same_network: True
- main_training_function: main
- enable_cpu_affinity: False
- deepspeed_config: {'gradient_accumulation_steps': 4, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': False, 'zero_stage': 2}
- downcast_bf16: no
- tpu_use_cluster: False
- tpu_use_sudo: False
- tpu_env: []
- Datasets version: 3.0.0
- HF Hub version: 0.24.7
- TRL version: 0.12.0.dev0+acb4d70
- bitsandbytes version: 0.41.1
- DeepSpeed version: 0.15.1
- Diffusers version: 0.30.3
- Liger-Kernel version: 0.3.0
- LLM-Blender version: 0.0.2
- OpenAI version: 1.46.0
- PEFT version: 0.12.0
```
This information is required when reporting an issue.

View File

@ -0,0 +1,26 @@
# Community Tutorials
Community tutorials are made by active members of the Hugging Face community who want to share their knowledge and expertise with others. They are a great way to learn about the library and its features, and to get started with core classes and modalities.
# Language Models
| Task | Class | Description | Author | Tutorial | Colab |
| ----------------------- | --------------- | ---------------------------------------------------------------------------------------- | -------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Instruction tuning | [`SFTTrainer`] | Fine-tuning Google Gemma LLMs using ChatML format with QLoRA | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-google-gemma) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/gemma-lora-example.ipynb) |
| Structured Generation | [`SFTTrainer`] | Fine-tuning Llama-2-7B to generate Persian product catalogs in JSON using QLoRA and PEFT | [Mohammadreza Esmaeilian](https://huggingface.co/Mohammadreza) | [Link](https://huggingface.co/learn/cookbook/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format.ipynb) |
| Preference Optimization | [`DPOTrainer`] | Align Mistral-7b using Direct Preference Optimization for human preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/Fine_tune_Mistral_7b_with_DPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb) |
| Preference Optimization | [`ORPOTrainer`] | Fine-tuning Llama 3 with ORPO combining instruction tuning and preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eHNWg9gnaXErdAa8_mcvjMupbSS6rDvi) |
<Youtube id="cnGyyM0vOes" />
# Vision Language Models
| Task | Class | Description | Author | Tutorial | Colab |
| --------------- | -------------- | ---------------------------------------------------------------------------- | ------------------------------------------------------ | -------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Visual QA | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for visual question answering on ChartQA dataset | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_trl.ipynb) |
| SEO Description | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for generating SEO-friendly descriptions from images | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-multimodal-llms-with-trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-multimodal-llms-with-trl.ipynb) |
| Visual QA | [`DPOTrainer`] | PaliGemma 🤝 Direct Preference Optimization | [Merve Noyan](https://huggingface.co/merve) | [Link](https://github.com/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) |
## Contributing
If you have a tutorial that you would like to add to this list, please open a PR to add it. We will review it and merge it if it is relevant to the community.

108
docs/source/cpo_trainer.mdx Normal file
View File

@ -0,0 +1,108 @@
# CPO Trainer
[![](https://img.shields.io/badge/All_models-CPO-blue)](https://huggingface.co/models?other=cpo,trl)
## Overview
Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high-level, CPO trains models to avoid generating adequate, but not perfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation to the DPO loss and can be applied to other domains like chat.
CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT's methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Second, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.
## Quick start
This example demonstrates how to train a model using the CPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:
<iframe
src="https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>
Below is the script to train the model:
```python
# train_cpo.py
from datasets import load_dataset
from trl import CPOConfig, CPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = CPOConfig(output_dir="Qwen2-0.5B-CPO", logging_steps=10)
trainer = CPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```
Execute the script using the following command:
```bash
accelerate launch train_cpo.py
```
## Expected dataset type
CPO requires a [preference dataset](dataset_formats#preference). The [`CPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
## Example script
We provide an example script to train a model using the CPO method. The script is available in [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py)
To test the CPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
```bash
accelerate launch examples/scripts/cpo.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name trl-lib/ultrafeedback_binarized \
--num_train_epochs 1 \
--logging_steps 25 \
--output_dir Qwen2-0.5B-CPO
```
## Logged metrics
While training and evaluating we record the following reward metrics:
* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
* `rewards/accuracies`: mean of how often the chosen rewards are greater than the corresponding rejected rewards
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
* `nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses
## CPO variants
### Simple Preference Optimization (SimPO)
The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, simply set `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`].
### CPO-SimPO
We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more details at [CPO-SimPO GitHub](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the [`CPOConfig`].
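As a quick sketch of the two variants described above (the `loss_type` and `cpo_alpha` parameters are those named in this section; the output directories and the `cpo_alpha=0.5` value are only illustrative):
```python
from trl import CPOConfig

# SimPO: reward-margin loss without BC regularization
simpo_args = CPOConfig(output_dir="Qwen2-0.5B-SimPO", loss_type="simpo", cpo_alpha=0.0)

# CPO-SimPO: SimPO loss combined with a non-zero CPO term
cpo_simpo_args = CPOConfig(output_dir="Qwen2-0.5B-CPO-SimPO", loss_type="simpo", cpo_alpha=0.5)
```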
## Loss functions
The CPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`CPOConfig`]. The following loss functions are supported:
| `loss_type=` | Description |
| -------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike DPO which is summed only). |
### For Mixture of Experts Models: Enabling the auxiliary loss
MOEs are the most efficient if the load is about equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
## CPOTrainer
[[autodoc]] CPOTrainer
## CPOConfig
[[autodoc]] CPOConfig

View File

@ -1,152 +1,163 @@
# Training customization
At `trl` we provide the possibility to give enough modularity to users to be able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques.
TRL is designed with modularity in mind so that users are able to efficiently customize the training loop for their needs. Below are some examples of how you can apply and test different techniques. Note: although these examples use the DPOTrainer, the customization applies to most (if not all) trainers.
## Use different optimizers
## Train on multiple GPUs / nodes
By default, the `PPOTrainer` creates a `torch.optim.Adam` optimizer. You can create and define a different optimizer and pass it to `PPOTrainer`:
```python
import torch
from transformers import GPT2Tokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
The trainers in TRL use 🤗 Accelerate to enable distributed training across multiple GPUs or nodes. To do so, first create an 🤗 Accelerate config file by running
# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# 2. define config
ppo_config = {'batch_size': 1, 'learning_rate':1e-5}
config = PPOConfig(**ppo_config)
# 2. Create optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate)
# 3. initialize trainer
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer)
```bash
accelerate config
```
For memory efficient fine-tuning, you can also pass `Adam8bit` optimizer from `bitsandbytes`:
and answering the questions according to your multi-gpu / multi-node setup. You can then launch distributed training by running:
```python
import torch
import bitsandbytes as bnb
from transformers import GPT2Tokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# 2. define config
ppo_config = {'batch_size': 1, 'learning_rate':1e-5}
config = PPOConfig(**ppo_config)
# 2. Create optimizer
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=config.learning_rate)
# 3. initialize trainer
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer)
```bash
accelerate launch your_script.py
```
### Use LION optimizer
We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.:
You can use the new [LION optimizer from Google](https://arxiv.org/abs/2302.06675) as well, first take the source code of the optimizer definition [here](https://github.com/lucidrains/lion-pytorch/blob/main/lion_pytorch/lion_pytorch.py), and copy it so that you can import the optimizer. Make sure to initialize the optimizer by considering the trainable parameters only for a more memory efficient training:
```python
optimizer = Lion(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.config.learning_rate)
...
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer)
```shell
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
```
We advise you to use the learning rate that you would use for `Adam` divided by 3 as pointed out [here](https://github.com/lucidrains/lion-pytorch#lion---pytorch). We observed an improvement when using this optimizer compared to classic Adam (check the full logs [here](https://wandb.ai/distill-bloom/trl/runs/lj4bheke?workspace=user-younesbelkada)):
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-lion.png">
</div>
Refer to the [examples page](https://github.com/huggingface/trl/tree/main/examples) for more details.
### Distributed training with DeepSpeed
## Add a learning rate scheduler
All of the trainers in TRL can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run:
```shell
accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_your_script.py --all_arguments_of_the_script
```
Note that for ZeRO-3, a small tweak is needed to initialize your reward model on the correct device via the `zero3_init_context_manager()` context manager. In particular, this is needed to avoid DeepSpeed hanging after a fixed number of training steps. Here is a snippet of what is involved from the [`sentiment_tuning`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) example:
You can also play with your training by adding learning rate schedulers!
```python
import torch
from transformers import GPT2Tokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
with ds_plugin.zero3_init_context_manager(enable=False):
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
else:
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
```
# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# 2. define config
ppo_config = {'batch_size': 1, 'learning_rate':1e-5}
config = PPOConfig(**ppo_config)
Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin.
# 2. Create optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
## Use different optimizers and schedulers
# 3. initialize trainer
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer, lr_scheduler=lr_scheduler)
By default, the `DPOTrainer` creates a `torch.optim.AdamW` optimizer. You can create and define a different optimizer and pass it to `DPOTrainer` as follows:
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch import optim
from trl import DPOConfig, DPOTrainer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
optimizer = optim.SGD(model.parameters(), lr=training_args.learning_rate)
trainer = DPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
tokenizer=tokenizer,
optimizers=(optimizer, None),
)
trainer.train()
```
### Add a learning rate scheduler
You can also play with your training by adding learning rate schedulers.
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch import optim
from trl import DPOConfig, DPOTrainer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
optimizer = optim.AdamW(model.parameters(), lr=training_args.learning_rate)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
trainer = DPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
tokenizer=tokenizer,
optimizers=(optimizer, lr_scheduler),
)
trainer.train()
```
## Memory efficient fine-tuning by sharing layers
Another tool you can use for more memory efficient fine-tuning is to share layers between the reference model and the model you want to train.
```python
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import create_reference_model, DPOConfig, DPOTrainer
# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m')
model_ref = create_reference_model(model, num_shared_layers=6)
tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
ref_model = create_reference_model(model, num_shared_layers=6)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:1%]")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
# 2. initialize trainer
ppo_config = {'batch_size': 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
trainer = DPOTrainer(
model=model,
ref_model=ref_model,
args=training_args,
train_dataset=dataset,
tokenizer=tokenizer,
)
trainer.train()
```
## Pass 8-bit reference models
<div>
Since `trl` supports all keyword arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning.
Since `trl` supports all key word arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning.
Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#bitsandbytes-integration-for-int8-mixedprecision-matrix-decomposition).
</div>
Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/en/peft#load-in-8bit-or-4bit).
```python
# 0. imports
# pip install bitsandbytes
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import DPOConfig, DPOTrainer
# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m')
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m', device_map="auto", load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
# 2. initialize trainer
ppo_config = {'batch_size': 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
trainer = DPOTrainer(
model=model,
ref_model=ref_model,
args=training_args,
train_dataset=dataset,
tokenizer=tokenizer,
)
trainer.train()
```
## Use the CUDA cache optimizer
When training large models, it is best to handle the CUDA cache by iteratively clearing it. To do so, simply pass `optimize_cuda_cache=True` to `PPOConfig`:
When training large models, it is best to handle the CUDA cache by iteratively clearing it. To do so, simply pass `optimize_cuda_cache=True` to `DPOConfig`:
```python
config = PPOConfig(..., optimize_cuda_cache=True)
```
training_args = DPOConfig(..., optimize_cuda_cache=True)
```

View File

@ -0,0 +1,29 @@
# Data Utilities
## is_conversational
[[autodoc]] is_conversational
## apply_chat_template
[[autodoc]] apply_chat_template
## maybe_apply_chat_template
[[autodoc]] maybe_apply_chat_template
## extract_prompt
[[autodoc]] extract_prompt
## maybe_extract_prompt
[[autodoc]] maybe_extract_prompt
## unpair_preference_dataset
[[autodoc]] unpair_preference_dataset
## maybe_unpair_preference_dataset
[[autodoc]] maybe_unpair_preference_dataset

View File

@ -0,0 +1,911 @@
# Dataset formats and types
This guide provides an overview of the dataset formats and types supported by each trainer in TRL.
## Overview of the dataset formats and types
- The *format* of a dataset refers to how the data is structured, typically categorized as either *standard* or *conversational*.
- The *type* is associated with the specific task the dataset is designed for, such as *prompt-only* or *preference*. Each type is characterized by its columns, which vary according to the task, as shown in the table.
<table>
<tr>
<th>Type \ Format</th>
<th>Standard</th>
<th>Conversational</th>
</tr>
<tr>
<td>Language modeling</td>
<td>
<pre><code>{"text": "The sky is blue."}</code></pre>
</td>
<td>
<pre><code>{"messages": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."}]}</code></pre>
</td>
</tr>
<tr>
<td>Prompt-only</td>
<td>
<pre><code>{"prompt": "The sky is"}</code></pre>
</td>
<td>
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}]}</code></pre>
</td>
</tr>
<tr>
<td>Prompt-completion</td>
<td>
<pre><code>{"prompt": "The sky is",
"completion": " blue."}</code></pre>
</td>
<td>
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is blue."}]}</code></pre>
</td>
</tr>
<tr>
<td>Preference</td>
<td>
<pre><code>{"prompt": "The sky is",
"chosen": " blue.",
"rejected": " green."}</code></pre>
or, with implicit prompt:
<pre><code>{"chosen": "The sky is blue.",
"rejected": "The sky is green."}</code></pre>
</td>
<td>
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
"chosen": [{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "assistant", "content": "It is green."}]}</code></pre>
or, with implicit prompt:
<pre><code>{"chosen": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is green."}]}</code></pre>
</td>
</tr>
<tr>
<td>Unpaired preference</td>
<td>
<pre><code>{"prompt": "The sky is",
"completion": " blue.",
"label": True}</code></pre>
</td>
<td>
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is green."}],
"label": False}</code></pre>
</td>
</tr>
<tr>
<td>Stepwise supervision</td>
<td>
<pre><code>{"prompt": "Which number is larger, 9.8 or 9.11?",
"completions": ["The fractional part of 9.8 is 0.8.",
"The fractional part of 9.11 is 0.11.",
"0.11 is greater than 0.8.",
"Hence, 9.11 > 9.8."],
"labels": [True, True, False, False]}</code></pre>
</td>
<td></td>
</tr>
</table>
### Formats
#### Standard
The standard dataset format typically consists of plain text strings. The columns in the dataset vary depending on the task. This is the format expected by TRL trainers. Below are examples of standard dataset formats for different tasks:
```python
# Language modeling
language_modeling_example = {"text": "The sky is blue."}
# Preference
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
# Unpaired preference
unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True}
```
#### Conversational
Conversational datasets are used for tasks involving dialogues or chat interactions between users and assistants. Unlike standard dataset formats, these contain sequences of messages where each message has a `role` (e.g., `"user"` or `"assistant"`) and `content` (the message text).
```python
messages = [
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
{"role": "user", "content": "I'd like to show off how chat templating works!"},
]
```
Just like standard datasets, the columns in conversational datasets vary depending on the task. Below are examples of conversational dataset formats for different tasks:
```python
# Prompt-completion
prompt_completion_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is blue."}]}
# Preference
preference_example = {
"prompt": [{"role": "user", "content": "What color is the sky?"}],
"chosen": [{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "assistant", "content": "It is green."}],
}
```
Conversational datasets are useful for training chat models, but must be converted into a standard format before being used with TRL trainers. This is typically done using chat templates specific to the model being used. For more information, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
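As a minimal sketch of that conversion, assuming the `maybe_apply_chat_template` utility listed in the Data Utilities page and the usual `datasets.map` pattern:
```python
from datasets import Dataset
from transformers import AutoTokenizer
from trl import maybe_apply_chat_template

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example model with a chat template

# A tiny conversational prompt-completion dataset
dataset = Dataset.from_dict({
    "prompt": [[{"role": "user", "content": "What color is the sky?"}]],
    "completion": [[{"role": "assistant", "content": "It is blue."}]],
})

# Render the messages into plain strings using the model's chat template
dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
print(dataset[0])  # {"prompt": "...", "completion": "..."} as template-formatted strings
```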
### Types
#### Language modeling
A language modeling dataset consists of a column `"text"` (or `"messages"` for conversational datasets) containing a full sequence of text.
```python
# Standard format
language_modeling_example = {"text": "The sky is blue."}
# Conversational format
language_modeling_example = {"messages": [
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."}
]}
```
#### Prompt-only
In a prompt-only dataset, only the initial prompt (the question or partial sentence) is provided under the key `"prompt"`. The training typically involves generating the completion based on this prompt, where the model learns to continue or complete the given input.
```python
# Standard format
prompt_only_example = {"prompt": "The sky is"}
# Conversational format
prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
```
<Tip>
While both the prompt-only and language modeling types are similar, they differ in how the input is handled. In the prompt-only type, the prompt represents a partial input that expects the model to complete or continue, while in the language modeling type, the input is treated as a complete sentence or sequence. These two types are processed differently by TRL. Below is an example showing the difference in the output of the `apply_chat_template` function for each type:
```python
from transformers import AutoTokenizer
from trl import apply_chat_template
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
# Example for prompt-only type
prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
apply_chat_template(prompt_only_example, tokenizer)
# Output: {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n'}
# Example for language modeling type
lm_example = {"messages": [{"role": "user", "content": "What color is the sky?"}]}
apply_chat_template(lm_example, tokenizer)
# Output: {'text': '<|user|>\nWhat color is the sky?<|end|>\n<|endoftext|>'}
```
- The prompt-only output includes `'<|assistant|>\n'`, indicating the beginning of the assistant's turn and signaling that the model is expected to generate a completion.
- In contrast, the language modeling output treats the input as a complete sequence and terminates it with `'<|endoftext|>'`, signaling the end of the text and not expecting any additional content.
</Tip>
#### Prompt-completion
A prompt-completion dataset includes a `"prompt"` and a `"completion"`.
```python
# Standard format
prompt_completion_example = {"prompt": "The sky is", "completion": " blue."}
# Conversational format
prompt_completion_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is blue."}]}
```
#### Preference
A preference dataset is used for tasks where the model is trained to choose between two or more possible completions to the same prompt. This dataset includes a `"prompt"`, a `"chosen"` completion, and a `"rejected"` completion. The model is trained to select the `"chosen"` response over the `"rejected"` response.
Some datasets may not include the `"prompt"` column, in which case the prompt is implicit and directly included in the `"chosen"` and `"rejected"` completions. We recommend using explicit prompts whenever possible.
```python
# Standard format
## Explicit prompt (recommended)
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
# Implicit prompt
preference_example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
# Conversational format
## Explicit prompt (recommended)
preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
"chosen": [{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "assistant", "content": "It is green."}]}
## Implicit prompt
preference_example = {"chosen": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is green."}]}
```
Some preference datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo). You can also explore the [librarian-bots' DPO Collections](https://huggingface.co/collections/librarian-bots/direct-preference-optimization-datasets-66964b12835f46289b6ef2fc) to identify preference datasets.
#### Unpaired preference
An unpaired preference dataset is similar to a preference dataset but instead of having `"chosen"` and `"rejected"` completions for the same prompt, it includes a single `"completion"` and a `"label"` indicating whether the completion is preferred or not.
```python
# Standard format
unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True}
# Conversational format
unpaired_preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is blue."}],
"label": True}
```
#### Stepwise supervision
A stepwise (or process) supervision dataset is similar to an [unpaired preference](#unpaired-preference) dataset but includes multiple steps of completions, each with its own label. This structure is useful for tasks that need detailed, step-by-step labeling, such as reasoning tasks. By evaluating each step separately and providing targeted labels, this approach helps identify precisely where the reasoning is correct and where errors occur, allowing for targeted feedback on each part of the reasoning process.
```python
stepwise_example = {
"prompt": "Which number is larger, 9.8 or 9.11?",
"completions": ["The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.", "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8."],
"labels": [True, False]
}
```
## Which dataset type to use?
Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer.
| Trainer | Expected dataset type |
| ----------------------- | ------------------------------------------------------------------------------------------------------ |
| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) |
| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) |
| [`IterativeSFTTrainer`] | [Unpaired preference](#unpaired-preference) |
| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
| [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`PPOTrainer`] | Tokenized language modeling |
| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
| [`SFTTrainer`] | [Language modeling](#language-modeling) |
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
<Tip>
TRL trainers only support standard dataset formats, [for now](https://github.com/huggingface/trl/issues/2071). If you have a conversational dataset, you must first convert it into a standard format.
For more information on how to work with conversational datasets, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
</Tip>
## Working with conversational datasets in TRL
Conversational datasets are increasingly common, especially for training chat models. However, some TRL trainers don't support conversational datasets in their raw format. (For more information, see [issue #2071](https://github.com/huggingface/trl/issues/2071).) These datasets must first be converted into a standard format.
Fortunately, TRL offers tools to easily handle this conversion, which are detailed below.
### Converting a conversational dataset into a standard dataset
To convert a conversational dataset into a standard dataset, you need to _apply a chat template_ to the dataset. A chat template is a predefined structure that typically includes placeholders for user and assistant messages. This template is provided by the tokenizer of the model you use.
For detailed instructions on using chat templating, refer to the [Chat templating section in the `transformers` documentation](https://huggingface.co/docs/transformers/en/chat_templating).
In TRL, the method you apply to convert the dataset will vary depending on the task. Fortunately, TRL provides a helper function called [`apply_chat_template`] to simplify this process. Here's an example of how to use it:
```python
from transformers import AutoTokenizer
from trl import apply_chat_template
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
example = {
"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is blue."}]
}
apply_chat_template(example, tokenizer)
# Output:
# {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n', 'completion': 'It is blue.<|end|>\n<|endoftext|>'}
```
Alternatively, you can use the [`~datasets.Dataset.map`] method to apply the template across an entire dataset:
```python
from datasets import Dataset
from transformers import AutoTokenizer
from trl import apply_chat_template
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
dataset_dict = {
"prompt": [[{"role": "user", "content": "What color is the sky?"}],
[{"role": "user", "content": "Where is the sun?"}]],
"completion": [[{"role": "assistant", "content": "It is blue."}],
[{"role": "assistant", "content": "In the sky."}]]
}
dataset = Dataset.from_dict(dataset_dict)
dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
# Output:
# {'prompt': ['<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n',
# '<|user|>\nWhere is the sun?<|end|>\n<|assistant|>\n'],
# 'completion': ['It is blue.<|end|>\n<|endoftext|>', 'In the sky.<|end|>\n<|endoftext|>']}
```
<Tip warning={true}>
We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle of a conversation.
For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] function is designed to handle these intricacies and ensure the correct application of chat templates for various tasks.
</Tip>
<Tip warning={true}>
It's important to note that chat templates are model-specific. For example, if you use the chat template from [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) with the above example, you get a different output:
```python
apply_chat_template(example, AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct"))
# Output:
# {'prompt': '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n',
# 'completion': 'It is blue.<|im_end|>\n'}
```
Always use the chat template associated with the model you're working with. Using the wrong template can lead to inaccurate or unexpected results.
</Tip>
## Using any dataset with TRL: preprocessing and conversion
Many datasets come in formats tailored to specific tasks, which might not be directly compatible with TRL. To use such datasets with TRL, you may need to preprocess and convert them into the required format.
To make this easier, we provide a set of [example scripts](https://github.com/huggingface/trl/tree/main/examples/datasets) that cover common dataset conversions.
### Example: UltraFeedback dataset
Let's take the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback) as an example. Here's a preview of the dataset:
<iframe
src="https://huggingface.co/datasets/openbmb/UltraFeedback/embed/viewer/default/train"
frameborder="0"
width="100%"
height="560px"
></iframe>
As shown above, the dataset format does not match the expected structure. It's not in a conversational format, the column names differ, and the results pertain to different models (e.g., Bard, GPT-4) and aspects (e.g., "helpfulness", "honesty").
By using the provided conversion script [`examples/datasets/ultrafeedback.py`](https://github.com/huggingface/trl/tree/main/examples/datasets/ultrafeedback.py), you can transform this dataset into an unpaired preference type, and push it to the Hub:
```sh
python examples/datasets/ultrafeedback.py --push_to_hub --repo_id trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness
```
Once converted, the dataset will look like this:
<iframe
src="https://huggingface.co/datasets/trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>
Now, you can use this dataset with TRL!
By adapting the provided scripts or creating your own, you can convert any dataset into a format compatible with TRL.
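For orientation, here is a minimal sketch of such a custom conversion. The dataset name and column names (`question`, `best_answer`) are purely hypothetical; adapt them to your own data.
```python
from datasets import load_dataset

# Hypothetical source dataset with "question" and "best_answer" columns (illustrative only)
dataset = load_dataset("my-org/my-qa-dataset", split="train")

def to_prompt_completion(example):
    # Map the task-specific columns to the conversational prompt-completion type
    return {
        "prompt": [{"role": "user", "content": example["question"]}],
        "completion": [{"role": "assistant", "content": example["best_answer"]}],
    }

dataset = dataset.map(to_prompt_completion, remove_columns=dataset.column_names)
# Optionally push the converted dataset to the Hub
dataset.push_to_hub("my-org/my-qa-dataset-prompt-completion")
```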
## Utilities for converting dataset types
This section provides example code to help you convert between different dataset types. While some conversions can be performed after applying the chat template (i.e., in the standard format), we recommend performing the conversion before applying the chat template to ensure it works consistently.
For simplicity, some of the examples below do not follow this recommendation and use the standard format. However, the conversions can be applied directly to the conversational format without modification.
| From \ To | Language modeling | Prompt-completion | Prompt-only | Preference with implicit prompt | Preference | Unpaired preference | Stepwise supervision |
| ------------------------------- | ----------------------------------------------------------------------- | ----------------------------------------------------------------------- | ----------------------------------------------------------------- | --------------------------------------------------------- | --------------------------------------------------------- | ------------------------------------------------------------------------- | -------------------- |
| Language modeling | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
| Prompt-completion | [🔗](#from-prompt-completion-to-language-modeling-dataset) | N/A | [🔗](#from-prompt-completion-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
| Prompt-only | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
| Preference with implicit prompt | [🔗](#from-preference-with-implicit-prompt-to-language-modeling-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-completion-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-only-dataset) | N/A | [🔗](#from-implicit-to-explicit-prompt-preference-dataset) | [🔗](#from-preference-with-implicit-prompt-to-unpaired-preference-dataset) | N/A |
| Preference | [🔗](#from-preference-to-language-modeling-dataset) | [🔗](#from-preference-to-prompt-completion-dataset) | [🔗](#from-preference-to-prompt-only-dataset) | [🔗](#from-explicit-to-implicit-prompt-preference-dataset) | N/A | [🔗](#from-preference-to-unpaired-preference-dataset) | N/A |
| Unpaired preference | [🔗](#from-unpaired-preference-to-language-modeling-dataset) | [🔗](#from-unpaired-preference-to-prompt-completion-dataset) | [🔗](#from-unpaired-preference-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
| Stepwise supervision | [🔗](#from-stepwise-supervision-to-language-modeling-dataset) | [🔗](#from-stepwise-supervision-to-prompt-completion-dataset) | [🔗](#from-stepwise-supervision-to-prompt-only-dataset) | N/A | N/A | [🔗](#from-stepwise-supervision-to-unpaired-preference-dataset) | N/A |
### From prompt-completion to language modeling dataset
To convert a prompt-completion dataset into a language modeling dataset, concatenate the prompt and the completion.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"prompt": ["The sky is", "The sun is"],
"completion": [" blue.", " in the sky."],
})
def concat_prompt_completion(example):
return {"text": example["prompt"] + example["completion"]}
dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"])
```
```python
>>> dataset[0]
{'text': 'The sky is blue.'}
```
### From prompt-completion to prompt-only dataset
To convert a prompt-completion dataset into a prompt-only dataset, remove the completion.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"prompt": ["The sky is", "The sun is"],
"completion": [" blue.", " in the sky."],
})
dataset = dataset.remove_columns("completion")
```
```python
>>> dataset[0]
{'prompt': 'The sky is'}
```
### From preference with implicit prompt to language modeling dataset
To convert a preference with implicit prompt dataset into a language modeling dataset, remove the rejected, and rename the column `"chosen"` to `"text"`.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"chosen": ["The sky is blue.", "The sun is in the sky."],
"rejected": ["The sky is green.", "The sun is in the sea."],
})
dataset = dataset.rename_column("chosen", "text").remove_columns("rejected")
```
```python
>>> dataset[0]
{'text': 'The sky is blue.'}
```
### From preference with implicit prompt to prompt-completion dataset
To convert a preference dataset with implicit prompt into a prompt-completion dataset, extract the prompt with [`extract_prompt`], remove the rejected, and rename the column `"chosen"` to `"completion"`.
```python
from datasets import Dataset
from trl import extract_prompt
dataset = Dataset.from_dict({
"chosen": [
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
],
"rejected": [
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
],
})
dataset = dataset.map(extract_prompt).remove_columns("rejected").rename_column("chosen", "completion")
```
```python
>>> dataset[0]
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}], 'completion': [{'role': 'assistant', 'content': 'It is blue.'}]}
```
### From preference with implicit prompt to prompt-only dataset
To convert a preference dataset with implicit prompt into a prompt-only dataset, extract the prompt with [`extract_prompt`], and remove the rejected and the chosen.
```python
from datasets import Dataset
from trl import extract_prompt
dataset = Dataset.from_dict({
"chosen": [
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
],
"rejected": [
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
],
})
dataset = dataset.map(extract_prompt).remove_columns(["chosen", "rejected"])
```
```python
>>> dataset[0]
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}]}
```
### From implicit to explicit prompt preference dataset
To convert a preference dataset with implicit prompt into a preference dataset with explicit prompt, extract the prompt with [`extract_prompt`].
```python
from datasets import Dataset
from trl import extract_prompt
dataset = Dataset.from_dict({
"chosen": [
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
],
"rejected": [
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
],
})
dataset = dataset.map(extract_prompt)
```
```python
>>> dataset[0]
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
'chosen': [{'role': 'assistant', 'content': 'It is blue.'}],
'rejected': [{'role': 'assistant', 'content': 'It is green.'}]}
```
### From preference with implicit prompt to unpaired preference dataset
To convert a preference dataset with implicit prompt into an unpaired preference dataset, extract the prompt with [`extract_prompt`], and unpair the dataset with [`unpair_preference_dataset`].
```python
from datasets import Dataset
from trl import extract_prompt, unpair_preference_dataset
dataset = Dataset.from_dict({
"chosen": [
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
],
"rejected": [
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
],
})
dataset = dataset.map(extract_prompt)
dataset = unpair_preference_dataset(dataset)
```
```python
>>> dataset[0]
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
'completion': [{'role': 'assistant', 'content': 'It is blue.'}],
'label': True}
```
### From preference to language modeling dataset
To convert a preference dataset into a language modeling dataset, remove the rejected, and concatenate the prompt and the chosen into the `"text"` column.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"prompt": ["The sky is", "The sun is"],
"chosen": [" blue.", " in the sky."],
"rejected": [" green.", " in the sea."],
})
def concat_prompt_chosen(example):
return {"text": example["prompt"] + example["chosen"]}
dataset = dataset.map(concat_prompt_chosen, remove_columns=["prompt", "chosen", "rejected"])
```
```python
>>> dataset[0]
{'text': 'The sky is blue.'}
```
### From preference to prompt-completion dataset
To convert a preference dataset into a prompt-completion dataset, remove the rejected, and rename the column `"chosen"` to `"completion"`.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"prompt": ["The sky is", "The sun is"],
"chosen": [" blue.", " in the sky."],
"rejected": [" green.", " in the sea."],
})
dataset = dataset.remove_columns("rejected").rename_column("chosen", "completion")
```
```python
>>> dataset[0]
{'prompt': 'The sky is', 'completion': ' blue.'}
```
### From preference to prompt-only dataset
To convert a preference dataset into a prompt-only dataset, remove the rejected and the chosen.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"prompt": ["The sky is", "The sun is"],
"chosen": [" blue.", " in the sky."],
"rejected": [" green.", " in the sea."],
})
dataset = dataset.remove_columns(["chosen", "rejected"])
```
```python
>>> dataset[0]
{'prompt': 'The sky is'}
```
### From explicit to implicit prompt preference dataset
To convert a preference dataset with explicit prompt into a preference dataset with implicit prompt, concatenate the prompt to both chosen and rejected, and remove the prompt.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"prompt": [
[{"role": "user", "content": "What color is the sky?"}],
[{"role": "user", "content": "Where is the sun?"}],
],
"chosen": [
[{"role": "assistant", "content": "It is blue."}],
[{"role": "assistant", "content": "In the sky."}],
],
"rejected": [
[{"role": "assistant", "content": "It is green."}],
[{"role": "assistant", "content": "In the sea."}],
],
})
def concat_prompt_to_completions(example):
return {"chosen": example["prompt"] + example["chosen"], "rejected": example["prompt"] + example["rejected"]}
dataset = dataset.map(concat_prompt_to_completions, remove_columns="prompt")
```
```python
>>> dataset[0]
{'chosen': [{'role': 'user', 'content': 'What color is the sky?'}, {'role': 'assistant', 'content': 'It is blue.'}],
'rejected': [{'role': 'user', 'content': 'What color is the sky?'}, {'role': 'assistant', 'content': 'It is green.'}]}
```
### From preference to unpaired preference dataset
To convert a preference dataset into an unpaired preference dataset, unpair the dataset with [`unpair_preference_dataset`].
```python
from datasets import Dataset
from trl import unpair_preference_dataset
dataset = Dataset.from_dict({
"prompt": [
[{"role": "user", "content": "What color is the sky?"}],
[{"role": "user", "content": "Where is the sun?"}],
],
"chosen": [
[{"role": "assistant", "content": "It is blue."}],
[{"role": "assistant", "content": "In the sky."}],
],
"rejected": [
[{"role": "assistant", "content": "It is green."}],
[{"role": "assistant", "content": "In the sea."}],
],
})
dataset = unpair_preference_dataset(dataset)
```
```python
>>> dataset[0]
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
'completion': [{'role': 'assistant', 'content': 'It is blue.'}],
'label': True}
```
### From unpaired preference to language modeling dataset
To convert an unpaired preference dataset into a language modeling dataset, concatenate the prompt and the completion into the `"text"` column, and remove the prompt, completion and label columns.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
"completion": [" blue.", " in the sky.", " green.", " in the sea."],
"label": [True, True, False, False],
})
def concatenate_prompt_completion(example):
return {"text": example["prompt"] + example["completion"]}
dataset = dataset.map(concatenate_prompt_completion).remove_columns(["prompt", "completion", "label"])
```
```python
>>> dataset[0]
{'text': 'The sky is blue.'}
```
### From unpaired preference to prompt-completion dataset
To convert an unpaired preference dataset into a prompt-completion dataset, remove the label column.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
"completion": [" blue.", " in the sky.", " green.", " in the sea."],
"label": [True, True, False, False],
})
dataset = dataset.remove_columns(["label"])
```
```python
>>> dataset[0]
{'prompt': 'The sky is', 'completion': ' blue.'}
```
### From unpaired preference to prompt-only dataset
To convert an unpaired preference dataset into a prompt-only dataset, remove the completion and the label columns.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
"completion": [" blue.", " in the sky.", " green.", " in the sea."],
"label": [True, True, False, False],
})
dataset = dataset.remove_columns(["completion", "label"])
```
```python
>>> dataset[0]
{'prompt': 'The sky is'}
```
### From stepwise supervision to language modeling dataset
To convert a stepwise supervision dataset into a language modeling dataset, concatenate the prompt and the completions into the `"text"` column.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"prompt": ["Blue light", "Water"],
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
"labels": [[True, False], [True, True]],
})
def concatenate_prompt_completions(example):
completion = "".join(example["completions"])
return {"text": example["prompt"] + completion}
dataset = dataset.map(concatenate_prompt_completions, remove_columns=["prompt", "completions", "labels"])
```
```python
>>> dataset[0]
{'text': 'Blue light scatters more in the atmosphere, so the sky is green.'}
```
### From stepwise supervision to prompt-completion dataset
To convert a stepwise supervision dataset into a prompt-completion dataset, join the completions and remove the labels.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"prompt": ["Blue light", "Water"],
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
"labels": [[True, False], [True, True]],
})
def join_completions(example):
completion = "".join(example["completions"])
return {"completion": completion}
dataset = dataset.map(join_completions, remove_columns=["completions", "labels"])
```
```python
>>> dataset[0]
{'prompt': 'Blue light', 'completion': ' scatters more in the atmosphere, so the sky is green.'}
```
### From stepwise supervision to prompt-only dataset
To convert a stepwise supervision dataset into a prompt-only dataset, remove the completions and the labels.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"prompt": ["Blue light", "Water"],
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
"labels": [[True, False], [True, True]],
})
dataset = dataset.remove_columns(["completions", "labels"])
```
```python
>>> dataset[0]
{'prompt': 'Blue light'}
```
### From stepwise supervision to unpaired preference dataset
To convert a stepwise supervision dataset into an unpaired preference dataset, join the completions and merge the labels.
The method for merging the labels depends on the specific task. In this example, we use the logical AND operation. This means that if the step labels indicate the correctness of individual steps, the resulting label will reflect the correctness of the entire sequence.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"prompt": ["Blue light", "Water"],
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
"labels": [[True, False], [True, True]],
})
def merge_completions_and_labels(example):
return {"prompt": example["prompt"], "completion": "".join(example["completions"]), "label": all(example["labels"])}
dataset = dataset.map(merge_completions_and_labels, remove_columns=["completions", "labels"])
```
```python
>>> dataset[0]
{'prompt': 'Blue light', 'completion': ' scatters more in the atmosphere, so the sky is green.', 'label': False}
```
## Vision datasets
Some trainers also support fine-tuning vision-language models (VLMs) using image-text pairs. In this scenario, it's recommended to use a conversational format, as each model handles image placeholders in text differently.
A conversational vision dataset differs from a standard conversational dataset in two key ways:
1. The dataset must contain the key `images` with the image data.
2. The `"content"` field in messages must be a list of dictionaries, where each dictionary specifies the type of data: `"image"` or `"text"`.
Example:
```python
# Textual dataset:
"content": "What color is the sky?"
# Vision dataset:
"content": [
{"type": "image"},
{"type": "text", "text": "What color is the sky in the image?"}
]
```
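Putting these two requirements together, a single record of a conversational vision dataset (here using the preference type) might look like the following sketch; the image path and exact columns are illustrative and depend on the task.
```python
from PIL import Image

# Load any image to stand in for the visual input (the path is illustrative)
image = Image.open("sky.png")

vision_preference_example = {
    "images": [image],
    "prompt": [{"role": "user", "content": [{"type": "image"},
                                            {"type": "text", "text": "What color is the sky in the image?"}]}],
    "chosen": [{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}],
    "rejected": [{"role": "assistant", "content": [{"type": "text", "text": "It is green."}]}],
}
```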
An example of a conversational vision dataset is the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset). Below is an embedded view of the dataset's training data, allowing you to explore it directly:
<iframe
src="https://huggingface.co/datasets/trl-lib/rlaif-v/embed/viewer/default/train"
frameborder="0"
width="100%"
height="560px"
></iframe>

# Denoising Diffusion Policy Optimization
[![](https://img.shields.io/badge/All_models-DDPO-blue)](https://huggingface.co/models?other=ddpo,trl)
## The why
| Before | After DDPO finetuning |
| --- | --- |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_squirrel.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_squirrel.png"/></div> |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_crab.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_crab.png"/></div> |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_starfish.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_starfish.png"/></div> |
## Getting started with Stable Diffusion finetuning with reinforcement learning
The machinery for finetuning Stable Diffusion models with reinforcement learning makes heavy use of Hugging Face's `diffusers` library. Getting started therefore requires a bit of familiarity with two core `diffusers` concepts: pipelines and schedulers.
Out of the box, `diffusers` provides neither a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning, so some adjustments need to be made.
This library provides a pipeline interface that must be implemented in order to work with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.**
A default implementation of this interface is available and can be used out of the box. If the default implementation is sufficient, or you simply want to get things moving, refer to the training example alongside this guide.
The point of the interface is to fuse the pipeline and the scheduler into a single object, keeping all the constraints in one place. The interface was designed to cater to pipelines and schedulers beyond the examples in this repository at the time of writing. The scheduler step is also a method of this pipeline interface; this may seem redundant given that the raw scheduler is accessible through the interface, but it is the only way to constrain the scheduler step output to an output type suited to the algorithm at hand (DDPO).
For a more detailed look into the interface and the associated default implementation, go [here](https://github.com/lvwerra/trl/tree/main/trl/models/modeling_sd_base.py).
Note that the default implementation has a LoRA implementation path and a non-LoRA implementation path. LoRA is enabled by default and can be turned off by passing the corresponding flag. LoRA-based training is faster, and the LoRA-specific hyperparameters responsible for model convergence aren't as finicky as in non-LoRA training.
In addition, you are expected to provide a reward function and a prompt function. The reward function is used to evaluate the generated images, and the prompt function is used to generate the prompts from which the images are generated.
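To make these expectations concrete, below is a minimal, illustrative sketch of a prompt function, a reward function, and how they might be wired into the trainer. The exact signatures and arguments are simplified assumptions; refer to `examples/scripts/ddpo.py` and the API reference below for the real interface.
```python
import torch
from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline

def prompt_fn():
    # Return a prompt and an (optionally empty) metadata payload
    return "a photo of a squirrel", {}

def reward_fn(images, prompts, prompt_metadata):
    # Score each generated image; a constant placeholder reward is used here
    return torch.ones(len(images)), {}

pipeline = DefaultDDPOStableDiffusionPipeline("runwayml/stable-diffusion-v1-5")
config = DDPOConfig(sample_batch_size=6, train_batch_size=3)
trainer = DDPOTrainer(config, reward_fn, prompt_fn, pipeline)
trainer.train()
```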
## Getting started with `examples/scripts/ddpo.py`
The `ddpo.py` script is a working example of using the `DDPOTrainer` to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`DDPOConfig`).
**Note:** one A100 GPU is recommended to get this running. Anything below an A100 will struggle to run this example script, and even if it does run with relatively small parameters, the results will most likely be poor.
Almost every configuration parameter has a default. Only one command-line argument is required to get things up and running: a [Hugging Face user access token](https://huggingface.co/docs/hub/security-tokens), which is used to upload the finetuned model to the Hugging Face Hub. Run the following bash command to get things running:
```bash
python ddpo.py --hf_user_access_token <token>
```
To see the full documentation of the script's configurable options, run `python ddpo.py --help`.
The following points are worth keeping in mind when configuring the trainer, beyond the example script use case (the code checks these for you as well):
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) should be greater than or equal to the configurable training batch size (`--ddpo_config.train_batch_size=3`)
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by the configurable train batch size (`--ddpo_config.train_batch_size=3`)
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by both the configurable gradient accumulation steps (`--ddpo_config.train_gradient_accumulation_steps=1`) and the configurable accelerator processes count
## Setting up the image logging hook function
Expect the function to be given a list of lists of the form
```python
[[image, prompt, prompt_metadata, rewards, reward_metadata], ...]
```
and `image`, `prompt`, `prompt_metadata`, `rewards`, `reward_metadata` are batched.
The last list in the lists of lists represents the last sample batch, which you are likely to want to log.
While you are free to log however you want, the use of `wandb` or `tensorboard` is recommended.
### Key terms
- `rewards` : The reward/score is a numerical value associated with the generated image and is key to steering the RL process
- `reward_metadata` : The reward metadata is the metadata associated with the reward. Think of this as extra information payload delivered alongside the reward
- `prompt` : The prompt is the text that is used to generate the image
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model consists of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground-truth answers (linked to the generated image) are expected alongside the generated image (see here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
- `image` : The image generated by the Stable Diffusion model
Example code for logging sampled images with `wandb` is given below.
```python
import numpy as np
from PIL import Image

# for logging these images to wandb
def image_outputs_hook(image_data, global_step, accelerate_logger):
# For the sake of this example, we only care about the last batch
# hence we extract the last element of the list
result = {}
images, prompts, _, rewards, _ = image_data[-1]
for i, image in enumerate(images):
pil = Image.fromarray(
(image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
)
pil = pil.resize((256, 256))
result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
accelerate_logger.log_images(
result,
step=global_step,
)
```
### Using the finetuned model
Assuming you're done with all the epochs and have pushed your model to the Hub, you can use the finetuned model as follows:
```python
import torch
from trl import DefaultDDPOStableDiffusionPipeline
pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/ddpo-finetuned-sd-model")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# memory optimization
pipeline.vae.to(device, torch.float16)
pipeline.text_encoder.to(device, torch.float16)
pipeline.unet.to(device, torch.float16)
prompts = ["squirrel", "crab", "starfish", "whale", "sponge", "plankton"]
results = pipeline(prompts)
for prompt, image in zip(prompts, results.images):
image.save(f"{prompt}.png")
```
## Credits
This work is heavily influenced by the repo [here](https://github.com/kvablack/ddpo-pytorch) and the associated paper [Training Diffusion Models with Reinforcement Learning by Kevin Black, Michael Janner, Yilun Du, Ilya Kostrikov, Sergey Levine](https://huggingface.co/papers/2305.13301).
## DDPOTrainer
[[autodoc]] DDPOTrainer
## DDPOConfig
[[autodoc]] DDPOConfig

@ -4,12 +4,12 @@ Language models (LMs) are known to sometimes generate toxic outputs. In this exa
Read this section to follow our investigation on how we can reduce toxicity in a wide range of LMs, from 125m parameters to 6B parameters!
Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/huggingface/trl/tree/main/examples/toxicity/scripts) as well as the link for the interactive demo:
| File | Description | Colab link |
|---|---| --- |
| [`gpt-j-6b-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x |
| [`evaluate-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x |
| [Interactive Space](https://huggingface.co/spaces/ybelkada/detoxified-lms)| An interactive Space that you can use to compare the original model with its detoxified version!| x |
## Context
@ -58,13 +58,13 @@ And its `continuation` value:
We want to increase the chance for the model to generate toxic outputs so we get more learning signal. For this reason, we pre-process the dataset to consider only prompts with a toxicity score greater than a threshold. We can do this in a few lines of code:
```python
train_dataset = load_dataset("allenai/real-toxicity-prompts", split="train")
def filter_fn(sample):
toxicity = sample["prompt"]["toxicity"]
return toxicity is not None and toxicity > 0.3
train_dataset = train_dataset.filter(filter_fn, batched=False)
```
### Reward function
@ -98,19 +98,15 @@ model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=
and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is pure `bfloat16` training, which is different from mixed-precision training. If you want to train a model in mixed precision, do not load the model with `torch_dtype` and instead specify the mixed-precision argument when calling `accelerate config`.
- Use shared layers: Since the PPO algorithm requires both the active and reference models to be on the same device, we have decided to use shared layers to reduce the memory footprint of the model. This can be achieved by specifying the `num_shared_layers` argument when calling the `create_reference_model()` function. For example, if you want to share the first 6 layers of the model, you can do it like this:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-shared-layers.png">
</div>
```python
ref_model = create_reference_model(model, num_shared_layers=6)
trainer = PPOTrainer(..., ref_model=ref_model)
```
In the example above, this means that the first 6 layers of the model are frozen, since these layers are shared between the active model and the reference model.
@ -155,7 +151,7 @@ We report the toxicity score of 400 sampled examples, compute its mean and stand
| `EleutherAI/gpt-neo-125m` | 0.1627 | 0.2997 |
| `ybelkada/gpt-neo-125m-detox` | **0.1148** | **0.2506** |
| --- | --- | --- |
| `EleutherAI/gpt-neo-2.7B` | 0.1884 | 0.3178 |
| `ybelkada/gpt-neo-2.7B-detox` | **0.0916** | **0.2104** |
| --- | --- | --- |
| `EleutherAI/gpt-j-6B` | 0.1699 | 0.3033 |
@ -174,7 +170,7 @@ Below are few generation examples of `gpt-j-6b-detox` model:
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-toxicity-examples.png">
</div>
The evaluation script can be found [here](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py).
### Discussions

docs/source/dpo_trainer.mdx
# DPO Trainer
[![](https://img.shields.io/badge/All_models-DPO-blue)](https://huggingface.co/models?other=dpo,trl) [![](https://img.shields.io/badge/smol_course-Chapter_2-yellow)](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment)
## Overview
TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://huggingface.co/papers/2305.18290) by [Rafael Rafailov](https://huggingface.co/rmrafailov), Archit Sharma, Eric Mitchell, [Stefano Ermon](https://huggingface.co/ermonste), [Christopher D. Manning](https://huggingface.co/manning), [Chelsea Finn](https://huggingface.co/cbfinn).
The abstract from the paper is the following:
> While large-scale unsupervised language models (LMs) learn broad world knowledge and some reasoning skills, achieving precise control of their behavior is difficult due to the completely unsupervised nature of their training. Existing methods for gaining such steerability collect human labels of the relative quality of model generations and fine-tune the unsupervised LM to align with these preferences, often with reinforcement learning from human feedback (RLHF). However, RLHF is a complex and often unstable procedure, first fitting a reward model that reflects the human preferences, and then fine-tuning the large unsupervised LM using reinforcement learning to maximize this estimated reward without drifting too far from the original model. In this paper we introduce a new parameterization of the reward model in RLHF that enables extraction of the corresponding optimal policy in closed form, allowing us to solve the standard RLHF problem with only a simple classification loss. The resulting algorithm, which we call Direct Preference Optimization (DPO), is stable, performant, and computationally lightweight, eliminating the need for sampling from the LM during fine-tuning or performing significant hyperparameter tuning. Our experiments show that DPO can fine-tune LMs to align with human preferences as well as or better than existing methods. Notably, fine-tuning with DPO exceeds PPO-based RLHF in ability to control sentiment of generations, and matches or improves response quality in summarization and single-turn dialogue while being substantially simpler to implement and train.
The first step is to train an SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.
Then, fine-tuning a language model via DPO consists of two steps and is easier than [PPO](ppo_trainer):
1. **Data collection**: Gather a [preference dataset](dataset_formats#preference) containing pairs of chosen (positive) and rejected (negative) generations for a given prompt.
2. **Optimization**: Maximize the log-likelihood of the DPO loss directly.
This process is illustrated in the sketch below (from [Figure 1 of the DPO paper](https://huggingface.co/papers/2305.18290)):
![](https://github.com/huggingface/trl/assets/49240599/9150fac6-3d88-4ca2-8ec6-2a6f3473216d)
Read more about DPO algorithm in the [original paper](https://huggingface.co/papers/2305.18290).
## Quick start
This example demonstrates how to train a model using the DPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:
<iframe
src="https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>
Below is the script to train the model:
```python
# train_dpo.py
from datasets import load_dataset
from trl import DPOConfig, DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10)
trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```
Execute the script using the following command:
```bash
accelerate launch train_dpo.py
```
Distributed across 8 GPUs, the training takes approximately 3 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/dpo-qwen2-reward-margin.png)
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-DPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-DPO
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?
<strong><span style="color: blue;">&lt;trl-lib/Qwen2-0.5B-DPO&gt;:</span></strong>
The best programming language for specific applications can vary depending on the use case and knowledge level of the programmer. Here are some general factors that can be used as input to choose the best programming language:
<strong><span style="color: green;">1</span></strong> Ease of use: Some programming languages are more user-friendly than others, such as Python, Java, or Ruby. Python is popular due to its simplicity and great scalability.
<strong><span style="color: green;">2</span></strong> Versatility: The ability to work with a wide range of data structures and frameworks can define the language as versatile.
<strong><span style="color: green;">3</span></strong> Ease of learning: Different programming languages have different learning curves, so users must be willing to take some time to master one.
<strong><span style="color: green;">4</span></strong> Community support: The broader community of developers and enthusiasts in the selected programming language can provide great support and resources.
<strong><span style="color: green;">5</span></strong> Reusability: Languages that emphasize code reuse and can be easily modifiable can be more suitable for software development.
The best programming language based on these factors is subjective and depends on what the programmer intends to accomplish.
</code></pre>
## Expected dataset type
DPO requires a [preference dataset](dataset_formats#preference). The [`DPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
Although the [`DPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section.
### Special considerations for vision-language models
The [`DPOTrainer`] supports fine-tuning vision-language models (VLMs). For these models, a vision dataset is required. To learn more about the specific format for vision datasets, refer to the [Vision dataset format](dataset_formats#vision-datasets) section.
Additionally, unlike standard text-based models where a `tokenizer` is used, for VLMs, you should replace the `tokenizer` with a `processor`.
```diff
- model = AutoModelForCausalLM.from_pretrained(model_id)
+ model = AutoModelForVision2Seq.from_pretrained(model_id)
- tokenizer = AutoTokenizer.from_pretrained(model_id)
+ processor = AutoProcessor.from_pretrained(model_id)
trainer = DPOTrainer(
model,
args=training_args,
train_dataset=train_dataset,
- processing_class=tokenizer,
+ processing_class=processor,
)
```
For a complete example of fine-tuning a vision-language model, refer to the script in [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py).
## Example script
We provide an example script to train a model using the DPO method. The script is available in [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py).
To test the DPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
```bash
accelerate launch trl/scripts/dpo.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name trl-lib/ultrafeedback_binarized \
--num_train_epochs 1 \
--logging_steps 25 \
--output_dir Qwen2-0.5B-DPO
```
## Logged metrics
While training and evaluating we record the following reward metrics:
- `rewards/chosen`: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by beta
- `rewards/rejected`: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by beta
- `rewards/accuracies`: mean of how often the chosen rewards are greater than the corresponding rejected rewards
- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
## Loss functions
The DPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`DPOConfig`]. The following loss functions are supported:
| `loss_type=` | Description |
| -------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike DPO which is summed only). |
| `"exo_pair"` | The [EXO](https://huggingface.co/papers/2402.00856) authors propose to minimize the reverse KL instead of the negative log-sigmoid loss of DPO which corresponds to forward KL. Setting non-zero `label_smoothing` (default `1e-3`) leads to a simplified version of EXO on pair-wise preferences (see Eqn. (16) of the [EXO paper](https://huggingface.co/papers/2402.00856)). The full version of EXO uses `K>2` completions generated by the SFT policy, which becomes an unbiased estimator of the PPO objective (up to a constant) when `K` is sufficiently large. |
| `"nca_pair"` | The [NCA](https://huggingface.co/papers/2402.05369) authors shows that NCA optimizes the absolute likelihood for each response rather than the relative likelihood. |
| `"robust"` | The [Robust DPO](https://huggingface.co/papers/2403.00409) authors propose an unbiased estimate of the DPO loss that is robust to preference noise in the data. Like in cDPO, it assumes that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0) |
| `"bco_pair"` | The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. For unpaired data, we recommend the dedicated [`BCOTrainer`]. |
| `"sppo_hard"` | The [SPPO](https://huggingface.co/papers/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2 and can alleviate data sparsity issues. The implementation approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser. |
| `"aot"` or `loss_type="aot_pair"` | The [AOT](https://huggingface.co/papers/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, the alignment algorithms use paired preferences at a sample level, which does not ensure alignment on the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order on the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size. |
| `"apo_zero"` or `loss_type="apo_down"` | The [APO](https://huggingface.co/papers/2408.06266) method introduces an "anchored" version of the alignment objective. There are two variants: `apo_zero` and `apo_down`. The `apo_zero` loss increases the likelihood of winning outputs while decreasing the likelihood of losing outputs, making it suitable when the model is less performant than the winning outputs. On the other hand, `apo_down` decreases the likelihood of both winning and losing outputs, but with a stronger emphasis on reducing the likelihood of losing outputs. This variant is more effective when the model is better than the winning outputs. |
| `"discopop"` | The [DiscoPOP](https://huggingface.co/papers/2406.08414) paper uses LLMs to discover more efficient offline preference optimization losses. In the paper the proposed DiscoPOP loss (which is a log-ratio modulated loss) outperformed other optimization losses on different tasks (IMDb positive text generation, Reddit TLDR summarization, and Alpaca Eval 2.0). |
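
For example, to train with the IPO loss instead of the default sigmoid loss you only need to change the config. This is a minimal sketch; the `beta` value is illustrative, not a tuned recommendation:

```python
from trl import DPOConfig

# Minimal sketch: switch the DPO objective to the IPO loss.
# beta=0.1 is an illustrative value, not a recommendation.
training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", loss_type="ipo", beta=0.1)
```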
### Label smoothing
The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0).
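For instance, if you estimate that roughly 10% of your preference labels are flipped, you could set it as in this minimal sketch (the 10% figure is an assumption for illustration):

```python
from trl import DPOConfig

# Minimal sketch: cDPO-style conservative loss, assuming ~10% of preference labels are noisy.
training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", label_smoothing=0.1)
```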
### Syncing the reference model
The [TR-DPO](https://huggingface.co/papers/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD with weight `ref_model_mixup_alpha` during DPO training. To toggle this callback use the `sync_ref_model=True` in the [`DPOConfig`].
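A minimal sketch of enabling this callback is shown below; the sync interval and mixup weight are illustrative values, not recommendations:

```python
from trl import DPOConfig

training_args = DPOConfig(
    output_dir="Qwen2-0.5B-DPO",
    sync_ref_model=True,        # enable TR-DPO-style reference model syncing
    ref_model_sync_steps=64,    # sync the reference weights every 64 steps (illustrative)
    ref_model_mixup_alpha=0.6,  # mixup weight used when updating the reference model (illustrative)
)
```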
### RPO loss
The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterative preference tuning algorithm using a loss related to the RPO loss in this [paper](https://huggingface.co/papers/2405.16436) that essentially consists of a weighted SFT loss on the chosen preferences together with the DPO loss. To use this loss, set the `rpo_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this weight to `1.0`.
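For example, to add the weighted SFT term with the paper-suggested weight:

```python
from trl import DPOConfig

# Add a weighted SFT loss on the chosen completions to the DPO loss; 1.0 is the paper-suggested weight.
training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", rpo_alpha=1.0)
```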
### WPO loss
The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`].
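A minimal sketch:

```python
from trl import DPOConfig

# Reweight preference pairs by their probability under the current policy (WPO).
training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", use_weighting=True)
```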
### For Mixture of Experts Models: Enabling the auxiliary loss
MOEs are the most efficient if the load is about equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
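A minimal sketch of enabling this when loading the model; the Mixtral checkpoint is used purely for illustration:

```python
from transformers import AutoModelForCausalLM

# Enable the load-balancing auxiliary loss for an MoE model and keep the default scaling.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    output_router_logits=True,   # add the router auxiliary loss to the training loss
    router_aux_loss_coef=0.001,  # contribution of the auxiliary loss (the default value)
)
```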
## Accelerate DPO fine-tuning using `unsloth`
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek, etc.) and Mistral architectures. Some benchmarks for DPO are listed below:
| GPU | Model | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
| -------- | --------- | ---------- | --- | --------------------- | --------- | ------------ |
| A100 40G | Zephyr 7b | Ultra Chat | 1x | 1.24x | **1.88x** | -11.6% |
| Tesla T4 | Zephyr 7b | Ultra Chat | 1x | 1.09x | **1.55x** | -18.6% |
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
```diff
from datasets import load_dataset
from trl import DPOConfig, DPOTrainer
- from transformers import AutoModelForCausalLM, AutoTokenizer
+ from unsloth import FastLanguageModel
- model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+ model, tokenizer = FastLanguageModel.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+ model = FastLanguageModel.get_peft_model(model)
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
- training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10)
+ training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10, bf16=True)
trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```
The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).
## Reference model considerations with PEFT
You have three main options (plus several variants) for how the reference model works when using PEFT, assuming the model that you would like to further enhance with DPO was tuned using (Q)LoRA.
1. Simply create two instances of the model, each loading your adapter - works fine but is very inefficient.
2. Merge the adapter into the base model, create another adapter on top, then leave the `ref_model` param null, in which case DPOTrainer will unload the adapter for reference inference - efficient, but has potential downsides discussed below.
3. Load the adapter twice with different names, then use `set_adapter` during training to swap between the adapter being DPO'd and the reference adapter - slightly less efficient compared to 2 (~adapter size VRAM overhead), but avoids the pitfalls.
### Downsides to merging QLoRA before DPO (approach 2)
As suggested by [Benjamin Marie](https://medium.com/@bnjmn_marie/dont-merge-your-lora-adapter-into-a-4-bit-llm-65b6da287997), the best option for merging QLoRA adapters is to first dequantize the base model, then merge the adapter. Something similar to [this script](https://github.com/jondurbin/qlora/blob/main/qmerge.py).
However, after using this approach, you will have an unquantized base model. Therefore, to use QLoRA for DPO, you will need to re-quantize the merged model or use the unquantized merge (resulting in higher memory demand).
### Using option 3 - load the adapter twice
To avoid the downsides with option 2, you can load your fine-tuned adapter into the model twice, with different names, and set the model/ref adapter names in [`DPOTrainer`].
For example:
```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import DPOConfig, DPOTrainer

# Load the base model.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/mixtral-8x7b-v0.1",
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.config.use_cache = False

# Load the adapter.
model = PeftModel.from_pretrained(
    model,
    "/path/to/peft",
    is_trainable=True,
    adapter_name="train",
)
# Load the adapter a second time, with a different name, which will be our reference model.
model.load_adapter("/path/to/peft", adapter_name="reference")

# Initialize the trainer, without a ref_model param.
training_args = DPOConfig(
    model_adapter_name="train",
    ref_adapter_name="reference",
)
dpo_trainer = DPOTrainer(
    model,
    args=training_args,
    ...
)
```
## DPOTrainer
[[autodoc]] DPOTrainer
## DPOConfig
[[autodoc]] DPOConfig
## PreferenceCollator
[[autodoc]] trainer.dpo_trainer.PreferenceCollator

# Examples
## Introduction
The examples should work in any of the following settings (with the same script):
- single GPU
- multi GPUs (using PyTorch distributed mode)
- multi GPUs (using DeepSpeed ZeRO-Offload stages 1, 2, & 3)
- fp16 (mixed-precision), fp32 (normal precision), or bf16 (bfloat16 precision)
To run the scripts in each of these modes, first initialize the 🤗 Accelerate configuration with `accelerate config`.

**Note:** to train with a 4-bit or 8-bit model, please run
```bash
pip install --upgrade trl[quantization]
```
## Accelerate Config
For all the examples, you'll need to generate a 🤗 Accelerate config file with:
```shell
accelerate config # will prompt you to define the training configuration
```
Then, it is encouraged to launch jobs with `accelerate launch`!
## Maintained Examples
Scripts can be used as examples of how to use TRL trainers. They are located in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) directory. Additionally, we provide examples in the [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directory. These examples are maintained and tested regularly.
| File | Description |
| ----------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`examples/scripts/alignprop.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/alignprop.py) | This script shows how to use the [`AlignPropTrainer`] to fine-tune a diffusion model. |
| [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. |
| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/ddpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ddpo.py) | This script shows how to use the [`DDPOTrainer`] to fine-tune a stable diffusion model using reinforcement learning. |
| [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. |
| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language |
| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train a reward model on your own dataset. |
| [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models so users may see unexpected behaviour in other model architectures. |
Here are also some easier-to-run colab notebooks that you can use to get started with TRL:
| File | Description |
| --------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------- |
| [`examples/notebooks/best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb) | This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. |
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. |
We also have some other examples that are less maintained but can be used as a reference:
1. **[research_projects](https://github.com/huggingface/trl/tree/main/examples/research_projects)**: Check out this folder to find the scripts used for some research projects that used TRL (LM de-toxification, Stack-Llama, etc.)
## Distributed training
All of the scripts can be run on multiple GPUs by providing the path of an 🤗 Accelerate config file when calling `accelerate launch`. To launch one of them on one or multiple GPUs, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine and `--all_arguments_of_the_script` with your arguments.)
```shell
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
```
You can also adjust the parameters of the 🤗 Accelerate config file to suit your needs (e.g. training in mixed precision).
### Distributed training with DeepSpeed
Most of the scripts can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine, `--all_arguments_of_the_script` with your arguments, and `--deepspeed_config` with the path to the DeepSpeed config file such as `examples/deepspeed_configs/deepspeed_zero1.yaml`):
```shell
accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
```

# Generalized Knowledge Distillation Trainer
[![](https://img.shields.io/badge/All_models-GKD-blue)](https://huggingface.co/models?other=gkd,trl)
## Overview
Generalized Knowledge Distillation (GKD) was proposed in [On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes](https://huggingface.co/papers/2306.13649) by Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem.
The abstract from the paper is the following:
> Knowledge distillation (KD) is widely used for compressing a teacher model to reduce its inference cost and memory footprint, by training a smaller student model. However, current KD methods for auto-regressive sequence models suffer from distribution mismatch between output sequences seen during training and those generated by the student during inference. To address this issue, we introduce Generalized Knowledge Distillation (GKD). Instead of solely relying on a fixed set of output sequences, GKD trains the student on its self-generated output sequences by leveraging feedback from the teacher on such sequences. Unlike supervised KD approaches, GKD also offers the flexibility to employ alternative loss functions between the student and teacher, which can be useful when the student lacks the expressivity to mimic the teacher's distribution. Furthermore, GKD facilitates the seamless integration of distillation with RL fine-tuning (RLHF). We demonstrate the efficacy of GKD for distilling auto-regressive language models on summarization, translation, and arithmetic reasoning tasks, and task-agnostic distillation for instruction-tuning.
The key aspects of GKD are:
1. It addresses the train-inference distribution mismatch in auto-regressive sequence models by training the student model on its self-generated output sequences.
2. GKD allows flexibility in choosing different divergence measures between student and teacher models via the generalized Jensen-Shannon Divergence (JSD), which can be useful when the student lacks the capacity to fully mimic the teacher.
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif) and [Lewis Tunstall](https://huggingface.co/lewtun).
## Usage tips
The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It requires three parameters to be set via the [`GKDConfig`], namely:
* `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and receives token-specific feedback on these sequences from the teacher. For values in between, each batch randomly uses one of the two regimes with probability given by `lmbda`.
* `seq_kd`: controls whether to perform Sequence-Level KD (which can be viewed as supervised fine-tuning on teacher-generated output). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher.
* `beta`: controls the interpolation in the generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while for `beta=1.0` the loss approximates reverse KL divergence. For values in between, it interpolates between the two.
The authors find that on-policy data (high `lmbda`) performs better and the optimal `beta` varied depending on the task and evaluation method.
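For example, the three parameters above can be set directly in the config; the values below are illustrative, not tuned recommendations:

```python
from trl import GKDConfig

training_args = GKDConfig(
    output_dir="gkd-model",
    lmbda=0.7,     # ~70% of batches use student-generated (on-policy) sequences
    beta=0.25,     # generalized JSD interpolation: 0 ≈ forward KL, 1 ≈ reverse KL
    seq_kd=False,  # no sequence-level KD on teacher-generated outputs
)
```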
> [!WARNING]
> Make sure that `attn_implementation="flash_attention_2"` when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture.
The basic API is as follows:
```python
from datasets import Dataset
from trl import GKDConfig, GKDTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer

NUM_DUMMY_SAMPLES = 100

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The model to optimise
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The teacher model to calculate the KL divergence against
teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")

train_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "Hi, how are you?"},
                {"role": "assistant", "content": "I'm great thanks"},
            ]
        ]
        * NUM_DUMMY_SAMPLES
    }
)
eval_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "What colour is the sky?"},
                {"role": "assistant", "content": "The sky is blue"},
            ]
        ]
        * NUM_DUMMY_SAMPLES
    }
)

training_args = GKDConfig(output_dir="gkd-model", per_device_train_batch_size=1)
trainer = GKDTrainer(
    model=model,
    teacher_model=teacher_model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
```
### Expected dataset type
The dataset should be formatted as a list of "messages" where each message is a list of dictionaries with the following keys:
* `role`: either `system`, `assistant` or `user`
* `content`: the message content
## GKDTrainer
[[autodoc]] GKDTrainer
## GKDConfig
[[autodoc]] GKDConfig

# Training FAQ
## What Metrics Should I Look at?
When performing classical supervised fine-tuning of language models, the loss (especially the validation loss) serves as a good indicator of the training progress. However, in Reinforcement Learning (RL), the loss becomes less informative about the model's performance, and its value may fluctuate while the actual performance improves.
To address this, we recommend focusing on two key metrics first:
**Mean Reward**: The primary goal is to maximize the reward achieved by the model during RL training.
**Objective KL Divergence**: KL divergence (Kullback-Leibler divergence) measures the dissimilarity between two probability distributions. In the context of RL training, we use it to quantify the difference between the current model and a reference model. Ideally, we want to keep the KL divergence between 0 and 10 to ensure the model's generated text remains close to what the reference model produces.
However, there are more metrics that can be useful for debugging; check out the [logging section](logging).
## Why Do We Use a Reference Model, and What's the Purpose of KL Divergence?
When training RL models, optimizing solely for reward may lead to unexpected behaviors, where the model exploits the environment in ways that don't align with good language generation. In the case of RLHF, we use a reward model trained to predict whether a generated text is highly ranked by humans.
However, the RL model being optimized against the reward model may learn patterns that yield high reward but do not represent good language. This can result in extreme cases where the model generates texts with excessive exclamation marks or emojis to maximize the reward. In some worst-case scenarios, the model may generate patterns completely unrelated to natural language yet receive high rewards, similar to adversarial attacks.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/kl-example.png">
<p style="text-align: center;"> <b>Figure:</b> Samples without a KL penalty from <a href="https://huggingface.co/papers/1909.08593">https://huggingface.co/papers/1909.08593</a>. </p>
</div>
To address this issue, we add a penalty to the reward function based on the KL divergence between the current model and the reference model. By doing this, we encourage the model to stay close to what the reference model generates.
## What Is the Concern with Negative KL Divergence?
If you generate text by purely sampling from the model distribution, things work fine in general. But when you use the `generate` method there are a few caveats: depending on the settings, it does not always purely sample, which can cause the KL divergence to go negative. Essentially, when the active model achieves `log_p_token_active < log_p_token_ref` we get negative KL divergence. This can happen in several cases:
- **top-k sampling**: the model can smooth out the probability distribution, causing the top-k tokens to have a smaller probability than those of the reference model, yet they are still selected
- **min_length**: this ignores the EOS token until `min_length` is reached, so the model can assign a very low log probability to the EOS token and very high probabilities to all other tokens until `min_length` is reached
These are just a few examples. Why is negative KL an issue? The total reward `R` is computed as `R = r - beta * KL`, so if the model can learn how to drive the KL divergence negative it effectively gets a positive reward. In many cases it can be much easier to exploit such a bug in the generation than to actually learn the reward function. In addition, the KL can become arbitrarily negative, so the actual reward can be very small compared to it.
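As a toy illustration with made-up numbers, a degenerate generation that drives the KL term negative can beat a genuinely good one:

```python
beta = 0.2

# Purely illustrative numbers.
r_good, kl_good = 1.0, 5.0                # good text that stays reasonably close to the reference
r_degenerate, kl_degenerate = 0.1, -10.0  # degenerate text that exploits the generation settings

print(r_good - beta * kl_good)              # 0.0
print(r_degenerate - beta * kl_degenerate)  # 2.1
```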
So how should you generate text for PPO training? Let's have a look!
## How to generate text for training?
In order to avoid the KL issues described above, we recommend using the following settings:
```python
generation_kwargs = {
    "min_length": -1,  # don't ignore the EOS token (see above)
    "top_k": 0.0,  # no top-k sampling
    "top_p": 1.0,  # no nucleus sampling
    "do_sample": True,  # yes, we want to sample
    "pad_token_id": tokenizer.eos_token_id,  # most decoder models don't have a padding token - use EOS token instead
    "max_new_tokens": 32,  # specify how many tokens you want to generate at most
}
```
With these settings we usually don't encounter any issues. You can also experiment with other settings, but if you encounter issues with negative KL divergence, try going back to these and see if the issues persist.
## How can you debug your own use-case?
Debugging the RL pipeline can be challenging due to its complexity. Here are some tips and suggestions to make the process easier:
- **Start from a working example**: Begin with a working example from the trl repository and gradually modify it to fit your specific use-case. Changing everything at once can make it difficult to identify the source of potential issues. For example, you can start by replacing the model in the example and once you figure out the best hyperparameters try to switch to your dataset and reward model. If you change everything at once you won't know where a potential problem comes from.
- **Start small, scale later**: Training large models can be very slow and take several hours or days until you see any improvement. For debugging this is not a convenient timescale so try to use small model variants during the development phase and scale up once that works. That being said you sometimes have to be careful as small models might not have the capacity to solve a complicated task either.
- **Start simple**: Try to start with a minimal example and build complexity from there. Your use-case might require for example a complicated reward function consisting of many different rewards - try to use one signal first and see if you can optimize that and then add more complexity after that.
- **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a bug in your post-processing or your prompt. Due to bad settings you might cut-off generations too soon. These things are very hard to see on the metrics but very obvious if you look at the generations.
- **Inspect the reward model**: If your reward is not improving over time, maybe there's an issue with the reward model. You can look at extreme cases to see if it does what it should: e.g. in the sentiment case you can check if simple positive and negative examples really get different rewards. And you can look at the distribution of your dataset. Finally, maybe the reward is dominated by the query, which the model can't affect, so you might need to normalize it (e.g. reward of query+response minus reward of the query).
These are just a few tips that we find helpful - if you have more useful tricks feel free to open a PR to add them as well!

# TRL - Transformer Reinforcement Learning
With the TRL (Transformer Reinforcement Learning) library you can train transformer language models with reinforcement learning. The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers).
TRL is a full stack library where we provide a set of tools to train transformer language models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT), Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step.
TRL supports decoder models such as GPT-2, BLOOM, and GPT-Neo, which can all be optimized using Proximal Policy Optimization (PPO). You can find installation instructions in the [installation guide](installation) and an introduction to the library in the [Quickstart section](quickstart). There is also a more [in-depth example](sentiment_tuning) to tune GPT-2 to produce positive movie reviews.
## Learn post-training
Learn post-training with the 🤗 [smol course](https://github.com/huggingface/smol-course).
## API documentation
- [Model Classes](models): *A brief overview of what each public model class does.*
- [`SFTTrainer`](sft_trainer): *Fine-tune your model easily with supervised fine-tuning using `SFTTrainer`*
- [`RewardTrainer`](reward_trainer): *Easily train your reward model using `RewardTrainer`.*
- [`PPOTrainer`](ppo_trainer): *Further fine-tune the supervised fine-tuned model using the PPO algorithm*
- [Best-of-N Sampling](best-of-n): *Use best of n sampling as an alternative way to sample predictions from your active model*
- [`DPOTrainer`](dpo_trainer): *Direct Preference Optimization training using `DPOTrainer`.*
- [`TextEnvironment`](text_environments): *Text environment to train your model using tools with RL.*
## Examples
- [Sentiment Tuning](sentiment_tuning): *Fine-tune your model to generate positive movie content*
- [Training with PEFT](lora_tuning_peft): *Memory efficient RLHF training using adapters with PEFT*
- [Detoxifying LLMs](detoxifying_a_lm): *Detoxify your language model through RLHF*
- [StackLlama](using_llama_models): *End-to-end RLHF training of a Llama model on Stack exchange dataset*
- [Learning with Tools](learning_tools): *Walkthrough of using `TextEnvironments`*
- [Multi-Adapter Training](multi_adapter_rl): *Use a single base model and multiple adapters for memory efficient end-to-end training*
## Blog posts
<div class="mt-10">
<div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5">
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo_vlm">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/dpo_vlm/thumbnail.png" alt="thumbnail" class="mt-0">
<p class="text-gray-500 text-sm">Published on July 10, 2024</p>
<p class="text-gray-700">Preference Optimization for Vision Language Models with TRL</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/putting_rl_back_in_rlhf_with_rloo">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/putting_rl_back_in_rlhf_with_rloo/thumbnail.png" alt="thumbnail" class="mt-0">
<p class="text-gray-500 text-sm">Published on June 12, 2024</p>
<p class="text-gray-700">Putting RL back in RLHF</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-ddpo">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/166_trl_ddpo/thumbnail.png" alt="thumbnail" class="mt-0">
<p class="text-gray-500 text-sm">Published on September 29, 2023</p>
<p class="text-gray-700">Finetune Stable Diffusion Models with DDPO via TRL</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo-trl">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/157_dpo_trl/dpo_thumbnail.png" alt="thumbnail" class="mt-0">
<p class="text-gray-500 text-sm">Published on August 8, 2023</p>
<p class="text-gray-700">Fine-tune Llama 2 with DPO</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/stackllama">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/138_stackllama/thumbnail.png" alt="thumbnail" class="mt-0">
<p class="text-gray-500 text-sm">Published on April 5, 2023</p>
<p class="text-gray-700">StackLLaMA: A hands-on guide to train LLaMA with RLHF</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-peft">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/133_trl_peft/thumbnail.png" alt="thumbnail" class="mt-0">
<p class="text-gray-500 text-sm">Published on March 9, 2023</p>
<p class="text-gray-700">Fine-tuning 20B LLMs with RLHF on a 24GB consumer GPU</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/rlhf">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/120_rlhf/thumbnail.png" alt="thumbnail" class="mt-0">
<p class="text-gray-500 text-sm">Published on December 9, 2022</p>
<p class="text-gray-700">Illustrating Reinforcement Learning from Human Feedback</p>
</a>
</div>
</div>

# Installation

You can install TRL with pip: `pip install trl`.
You can also install the latest version from source. First clone the repo and then run the installation with `pip`:
```bash
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .
```

# Iterative Trainer
[![](https://img.shields.io/badge/All_models-Iterative_SFT-blue)](https://huggingface.co/models?other=iterative-sft,trl)
Iterative fine-tuning is a training method that enables you to perform custom actions (for example, generation and filtering) between optimization steps. In TRL we provide an easy-to-use API to fine-tune your models in an iterative way in just a few lines of code.
## Usage
To get started quickly, instantiate a model and a tokenizer, and create the trainer:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import IterativeSFTTrainer

# model_name is the name or path of the model you want to fine-tune
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

trainer = IterativeSFTTrainer(
    model,
    tokenizer,
)
```
You have the choice to either provide a list of strings or a list of tensors to the step function.
#### Using a list of tensors as input:
```python
inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
}
trainer.step(**inputs)
```
#### Using a list of strings as input:
```python
inputs = {
    "texts": texts,
}
trainer.step(**inputs)
```
For causal language models, labels will automatically be created from `input_ids` or from `texts`. When using sequence-to-sequence models, you will have to provide your own labels or `text_labels`.
## IterativeTrainer
[[autodoc]] IterativeSFTTrainer

# Judges
<Tip warning={true}>
TRL Judges is an experimental API which is subject to change at any time.
</Tip>
TRL provides judges to easily compare two completions.
Make sure to have installed the required dependencies by running:
```bash
pip install trl[judges]
```
## Using the provided judges
TRL provides several judges out of the box. For example, you can use the `HfPairwiseJudge` to compare two completions using a pre-trained model from the Hugging Face model hub:
```python
from trl import HfPairwiseJudge
judge = HfPairwiseJudge()
judge.judge(
    prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"],
    completions=[["Paris", "Lyon"], ["Saturn", "Jupiter"]],
)  # Outputs: [0, 1]
```
## Define your own judge
To define your own judge, we provide several base classes that you can subclass. For rank-based judges, you need to subclass [`BaseRankJudge`] and implement the [`BaseRankJudge.judge`] method. For pairwise judges, you need to subclass [`BasePairwiseJudge`] and implement the [`BasePairwiseJudge.judge`] method. If you want to define a judge that doesn't fit into these categories, you need to subclass [`BaseJudge`] and implement the [`BaseJudge.judge`] method.
As an example, let's define a pairwise judge that prefers shorter completions:
```python
from trl import BasePairwiseJudge
class PrefersShorterJudge(BasePairwiseJudge):
    def judge(self, prompts, completions, shuffle_order=False):
        return [0 if len(completion[0]) <= len(completion[1]) else 1 for completion in completions]
```
You can then use this judge as follows:
```python
judge = PrefersShorterJudge()
judge.judge(
    prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"],
    completions=[["Paris", "The capital of France is Paris."], ["Jupiter is the biggest planet in the solar system.", "Jupiter"]],
)  # Outputs: [0, 1]
```
## Provided judges
### PairRMJudge
[[autodoc]] PairRMJudge
### HfPairwiseJudge
[[autodoc]] HfPairwiseJudge
### OpenAIPairwiseJudge
[[autodoc]] OpenAIPairwiseJudge
### AllTrueJudge
[[autodoc]] AllTrueJudge
## Base classes
### BaseJudge
[[autodoc]] BaseJudge
### BaseBinaryJudge
[[autodoc]] BaseBinaryJudge
### BaseRankJudge
[[autodoc]] BaseRankJudge
### BasePairwiseJudge
[[autodoc]] BasePairwiseJudge

# KTO Trainer
[![](https://img.shields.io/badge/All_models-KTO-blue)](https://huggingface.co/models?other=kto,trl)
## Overview
Kahneman-Tversky Optimization (KTO) was introduced in [KTO: Model Alignment as Prospect Theoretic Optimization](https://huggingface.co/papers/2402.01306) by [Kawin Ethayarajh](https://huggingface.co/kawine), [Winnie Xu](https://huggingface.co/xwinxu), [Niklas Muennighoff](https://huggingface.co/Muennighoff), Dan Jurafsky, [Douwe Kiela](https://huggingface.co/douwekiela).
The abstract from the paper is the following:
> Kahneman & Tversky's prospect theory tells us that humans perceive random variables in a biased but well-defined manner; for example, humans are famously loss-averse. We show that objectives for aligning LLMs with human feedback implicitly incorporate many of these biases -- the success of these objectives (e.g., DPO) over cross-entropy minimization can partly be ascribed to them being human-aware loss functions (HALOs). However, the utility functions these methods attribute to humans still differ from those in the prospect theory literature. Using a Kahneman-Tversky model of human utility, we propose a HALO that directly maximizes the utility of generations instead of maximizing the log-likelihood of preferences, as current methods do. We call this approach Kahneman-Tversky Optimization (KTO), and it matches or exceeds the performance of preference-based methods at scales from 1B to 30B. Crucially, KTO does not need preferences -- only a binary signal of whether an output is desirable or undesirable for a given input. This makes it far easier to use in the real world, where preference data is scarce and expensive.
The official code can be found in [ContextualAI/HALOs](https://github.com/ContextualAI/HALOs).
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Younes Belkada](https://huggingface.co/ybelkada), [Lewis Tunstall](https://huggingface.co/lewtun) and Pablo Vicente.
## Quick start
This example demonstrates how to train a model using the KTO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [KTO Mix 14k](https://huggingface.co/datasets/trl-lib/kto-mix-14k). You can view the data in the dataset here:
<iframe
src="https://huggingface.co/datasets/trl-lib/kto-mix-14k/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>
Below is the script to train the model:
```python
# train_kto.py
from datasets import load_dataset
from trl import KTOConfig, KTOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/kto-mix-14k", split="train")
training_args = KTOConfig(output_dir="Qwen2-0.5B-KTO", logging_steps=10)
trainer = KTOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```
Execute the script using the following command:
```bash
accelerate launch train_kto.py
```
Distributed across 8 x H100 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/kto-qwen2-reward-margin.png)
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-KTO
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?
<strong><span style="color: blue;">&lt;trl-lib/Qwen2-0.5B-KTO&gt;:</span></strong>
The best programming language can vary depending on individual preferences, industry-specific requirements, technical skills, and familiarity with the specific use case or task. Here are some widely-used programming languages that have been noted as popular and widely used:
Here are some other factors to consider when choosing a programming language for a project:
<strong><span style="color: green;">1</span> JavaScript</strong>: JavaScript is at the heart of the web and can be used for building web applications, APIs, and interactive front-end applications like frameworks like React and Angular. It's similar to C, C++, and F# in syntax structure and is accessible and easy to learn, making it a popular choice for beginners and professionals alike.
<strong><span style="color: green;">2</span> Java</strong>: Known for its object-oriented programming (OOP) and support for Java 8 and .NET, Java is used for developing enterprise-level software applications, high-performance games, as well as mobile apps, game development, and desktop applications.
<strong><span style="color: green;">3</span> C++</strong>: Known for its flexibility and scalability, C++ offers comprehensive object-oriented programming and is a popular choice for high-performance computing and other technical fields. It's a powerful platform for building real-world applications and games at scale.
<strong><span style="color: green;">4</span> Python</strong>: Developed by Guido van Rossum in 1991, Python is a high-level, interpreted, and dynamically typed language known for its simplicity, readability, and versatility.
</code></pre>
## Expected dataset format
KTO requires an [unpaired preference dataset](dataset_formats#unpaired-preference). Alternatively, you can provide a *paired* preference dataset (also known simply as a *preference dataset*). In this case, the trainer will automatically convert it to an unpaired format by separating the chosen and rejected responses, assigning `label = True` to the chosen completions and `label = False` to the rejected ones.
The [`KTOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
In theory, the dataset should contain at least one chosen and one rejected completion. However, some users have successfully run KTO using *only* chosen or only rejected data. If using only rejected data, it is advisable to adopt a conservative learning rate.
## Example script
We provide an example script to train a model using the KTO method. The script is available in [`trl/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/kto.py)
To test the KTO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [KTO Mix 14k dataset](https://huggingface.co/datasets/trl-lib/kto-mix-14k), run the following command:
```bash
accelerate launch trl/scripts/kto.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name trl-lib/kto-mix-14k \
--num_train_epochs 1 \
--logging_steps 25 \
--output_dir Qwen2-0.5B-KTO
```
## Usage tips
### For Mixture of Experts Models: Enabling the auxiliary loss
MOEs are the most efficient if the load is about equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
### Batch size recommendations
Use a per-step batch size that is at least 4, and an effective batch size between 16 and 128. Even if your effective batch size is large, if your per-step batch size is poor, then the KL estimate in KTO will be poor.
### Learning rate recommendations
Each choice of `beta` has a maximum learning rate it can tolerate before learning performance degrades. For the default setting of `beta = 0.1`, the learning rate should typically not exceed `1e-6` for most models. As `beta` decreases, the learning rate should also be reduced accordingly. In general, we strongly recommend keeping the learning rate between `5e-7` and `5e-6`. Even with small datasets, we advise against using a learning rate outside this range. Instead, opt for more epochs to achieve better results.
### Imbalanced data
The `desirable_weight` and `undesirable_weight` of the [`KTOConfig`] refer to the weights placed on the losses for desirable/positive and undesirable/negative examples.
By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` \\(\times\\) number of positives) to (`undesirable_weight` \\(\times\\) number of negatives) is in the range 1:1 to 4:3.
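For example, with 30,000 desirable and 10,000 undesirable examples (illustrative counts), keeping `desirable_weight=1.0` means `undesirable_weight` should fall roughly between 2.25 and 3.0:

```python
from trl import KTOConfig

# 1.0 * 30_000 : 2.5 * 10_000 = 30_000 : 25_000 = 1.2 : 1, which lies between 1:1 and 4:3.
training_args = KTOConfig(
    output_dir="Qwen2-0.5B-KTO",
    desirable_weight=1.0,
    undesirable_weight=2.5,
)
```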
## Logged metrics
While training and evaluating we record the following reward metrics:
- `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
- `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
- `logps/chosen`: the mean log probabilities of the chosen completions
- `logps/rejected`: the mean log probabilities of the rejected completions
- `logits/chosen`: the mean logits of the chosen completions
- `logits/rejected`: the mean logits of the rejected completions
- `kl`: the KL divergence between the policy model and the reference model
## KTOTrainer
[[autodoc]] KTOTrainer
## KTOConfig
[[autodoc]] KTOConfig

# Learning Tools (Experimental 🧪)
Using Large Language Models (LLMs) with tools has been a popular topic recently, with awesome works such as [ToolFormer](https://huggingface.co/papers/2302.04761) and [ToolBench](https://huggingface.co/papers/2305.16504). In TRL, we provide a simple example of how to teach an LLM to use tools with reinforcement learning.
Here's an overview of the scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples/research_projects/tools):
| File | Description |
|---|---|
| [`calculator.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/calculator.py) | Script to train LLM to use a calculator with reinforcement learning. |
| [`triviaqa.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/triviaqa.py) | Script to train LLM to use a wiki tool to answer questions. |
| [`python_interpreter.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/python_interpreter.py) | Script to train LLM to use python interpreter to solve math puzzles. |
<Tip warning={true}>
Note that the scripts above rely heavily on the `TextEnvironment` API which is still under active development. The API may change in the future. Please see [`TextEnvironment`](text_environment) for the related docs.
</Tip>
## Learning to Use a Calculator
The rough idea is as follows:
1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parses a text calculation like `"14 + 34"` and returns the calculated number:
```python
from transformers import AutoTokenizer, load_tool
tool = load_tool("ybelkada/simple-calculator")
tool_fn = lambda text: str(round(float(tool(text)), 2)) # rounding to 2 decimal places
```
2. Define a reward function that returns a positive reward if the tool returns the correct answer. In the script we create a dummy reward function like `reward_fn = lambda x: 1`, but we override the rewards directly later.
3. Create a prompt that shows how to use the tools:
```python
# system prompt
prompt = """\
What is 13.1-3?
<request><SimpleCalculatorTool>13.1-3<call>10.1<response>
Result=10.1<submit>
What is 4*3?
<request><SimpleCalculatorTool>4*3<call>12<response>
Result=12<submit>
What is 12.1+1?
<request><SimpleCalculatorTool>12.1+1<call>13.1<response>
Result=13.1<submit>
What is 12.1-20?
<request><SimpleCalculatorTool>12.1-20<call>-7.9<response>
Result=-7.9<submit>"""
```
4. Create a `trl.TextEnvironment` with the model:
```python
env = TextEnvironment(
    model,
    tokenizer,
    {"SimpleCalculatorTool": tool_fn},
    reward_fn,
    prompt,
    generation_kwargs=generation_kwargs,
)
```
5. Then generate some data such as `tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]` and run the environment with `queries, responses, masks, rewards, histories = env.run(tasks)`. The environment will look for the `<call>` token in the prompt and append the tool output to the response; it will also return the mask associated with the response. You can further use the `histories` to visualize the interaction between the model and the tool; `histories[0].show_text()` will show the text with color-coded tool output and `histories[0].show_tokens(tokenizer)` will visualize the tokens.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/learning_tools.png)
6. Finally, we can train the model with `train_stats = ppo_trainer.step(queries, responses, rewards, masks)`. The trainer will use the mask to ignore the tool output when computing the loss; make sure to pass that argument to `step`.
## Experiment results
We trained a model with the above script for 10 random seeds. You can reproduce the run with the following command. Feel free to remove the `--slurm-*` arguments if you don't have access to a slurm cluster.
```
WANDB_TAGS="calculator_final" python benchmark/benchmark.py \
--command "python examples/research_projects/tools/calculator.py" \
--num-seeds 10 \
--start-seed 1 \
--workers 10 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 8 \
--slurm-template-path benchmark/trl.slurm_template
```
We can then use [`openrlbenchmark`](https://github.com/openrlbenchmark/openrlbenchmark) which generates the following plot.
```
# pip install openrlbenchmark==0.2.1a5
python -m openrlbenchmark.rlops_multi_metrics \
--filters '?we=openrlbenchmark&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.tracker_project_name&cen=trl_ppo_trainer_config.value.log_with&metrics=env/reward_mean&metrics=objective/kl' \
'wandb?tag=calculator_final&cl=calculator_mask' \
--env-ids trl \
--check-empty-runs \
--pc.ncols 2 \
--pc.ncols-legend 1 \
--output-filename static/0compare \
--scan-history
```
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/learning_tools_chart.png)
As we can see, while 1-2 experiments crashed for some reason, most of the runs obtained near perfect proficiency in the calculator task.
## (Early Experiments 🧪): learning to use a wiki tool for question answering
The [ToolFormer](https://huggingface.co/papers/2302.04761) paper shows an interesting use case that utilizes a Wikipedia search tool to help answer questions. In this section, we attempt to perform similar experiments but use RL instead to teach the model to use a wiki tool on the [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) dataset.
<Tip warning={true}>
**Note that many settings are different so the results are not directly comparable.**
</Tip>
### Building a search index
Since [ToolFormer](https://huggingface.co/papers/2302.04761) did not open-source its code, we needed to first replicate the search index. Their paper mentions that the authors built the search index using a BM25 retriever that indexes the Wikipedia dump from [KILT](https://github.com/facebookresearch/KILT).
Fortunately, [`pyserini`](https://github.com/castorini/pyserini) already implements the BM25 retriever and provides a prebuilt index for the KILT Wikipedia dump. We can use the following code to search the index.
```python
from pyserini.search.lucene import LuceneSearcher
import json
searcher = LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc')
def search(query):
    hits = searcher.search(query, k=1)
    hit = hits[0]
    contents = json.loads(hit.raw)['contents']
    return contents
print(search("tennis racket"))
```
```
Racket (sports equipment)
A racket or racquet is a sports implement consisting of a handled frame with an open hoop across which a network of strings or catgut is stretched tightly. It is used for striking a ball or shuttlecock in games such as squash, tennis, racquetball, and badminton. Collectively, these games are known as racket sports. Racket design and manufacturing has changed considerably over the centuries.
The frame of rackets for all sports was traditionally made of solid wood (later laminated wood) and the strings of animal intestine known as catgut. The traditional racket size was limited by the strength and weight of the wooden frame which had to be strong enough to hold the strings and stiff enough to hit the ball or shuttle. Manufacturers started adding non-wood laminates to wood rackets to improve stiffness. Non-wood rackets were made first of steel, then of aluminum, and then carbon fiber composites. Wood is still used for real tennis, rackets, and xare. Most rackets are now made of composite materials including carbon fiber or fiberglass, metals such as titanium alloys, or ceramics.
...
```
We then deployed this snippet as a Hugging Face Space [here](https://huggingface.co/spaces/vwxyzjn/pyserini-wikipedia-kilt-doc), so that we can use the Space as a `transformers.Tool` later.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pyserini.png)
### Experiment settings
We use the following settings:
* use the `bigcode/starcoderbase` model as the base model
* use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and only use the first paragraphs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool.
* test if the response contains the answer string; if so, give a reward of 1, otherwise give a reward of 0 (see the sketch after the prompt below).
* note that this is a simplified evaluation criterion; in [ToolFormer](https://huggingface.co/papers/2302.04761), the authors check if the first 20 words of the response contain the correct answer.
* use the following prompt that demonstrates the usage of the wiki tool.
```python
prompt = """\
Answer the following question:
Q: In which branch of the arts is Patricia Neary famous?
A: Ballets
A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
Result=Ballets<submit>
Q: Who won Super Bowl XX?
A: Chicago Bears
A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
Result=Chicago Bears<submit>
Q: """
```
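For illustration, below is a minimal sketch of the exact-match reward described in the settings above (`exact_match_reward` is a hypothetical helper, not necessarily the exact function used in our runs):
```python
def exact_match_reward(responses, answers):
    """Reward of 1.0 if the gold answer string appears in the generated response, else 0.0."""
    rewards = []
    for response, answer in zip(responses, answers):
        rewards.append(1.0 if answer.strip().lower() in response.lower() else 0.0)
    return rewards
```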
### Result and Discussion
Our experiments show that the agent can learn to use the wiki tool to answer questions. The learning curves mostly trend upward, but one of the experiments did crash.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/triviaqa_learning_curves.png)
The Wandb report is available [here](https://wandb.ai/costa-huang/cleanRL/reports/TriviaQA-Final-Experiments--Vmlldzo1MjY0ODk5) for further inspection.
Note that the accuracy of the trained model is on the low end, which could be due to the following reasons:
* **incorrect searches:** When given the question `"What is Bruce Willis' real first name?"`, if the model searches for `Bruce Willis`, our wiki tool returns "Patrick Poivey (born 18 February 1948) is a French actor. He is especially known for his voice: he is the French dub voice of Bruce Willis since 1988." But a correct search would return "Walter Bruce Willis (born March 19, 1955) is an American former actor. He achieved fame with a leading role on the comedy-drama series Moonlighting (1985–1989) and appeared in over a hundred films, gaining recognition as an action hero after his portrayal of John McClane in the Die Hard franchise (1988–2013) and other roles."
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/real_first_name.png)
* **unnecessarily long responses:** The wiki tool sometimes outputs very long sequences by default. E.g., when the wiki tool searches for "Brown Act",
* our wiki tool returns "The Ralph M. Brown Act, located at California Government Code 54950 "et seq.", is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public's right to attend and participate in meetings of local legislative bodies."
* [ToolFormer](https://huggingface.co/papers/2302.04761)'s wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies.", which is more succinct.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/brown_act.png)
## (Early Experiments 🧪): solving math puzzles with python interpreter
In this section, we attempt to teach the model to use a python interpreter to solve math puzzles. The rough idea is to give the agent a prompt like the following:
```python
prompt = """\
Example of using a Python API to solve math questions.
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
<request><PythonInterpreter>
def solution():
    money_initial = 23
    bagels = 5
    bagel_cost = 3
    money_spent = bagels * bagel_cost
    money_left = money_initial - money_spent
    result = money_left
    return result
print(solution())
<call>72<response>
Result = 72 <submit>
Q: """
```
The training experiment can be found at https://wandb.ai/lvwerra/trl-gsm8k/runs/a5odv01y.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/gms8k_learning_curve.png)

docs/source/logging.mdx Normal file
@ -0,0 +1,74 @@
# Logging
As reinforcement learning algorithms are historically challenging to debug, it's important to pay careful attention to logging.
By default, the TRL [`PPOTrainer`] saves a lot of relevant information to wandb or tensorboard.
Upon initialization, pass one of these two options to the [`PPOConfig`]:
```
training_args = PPOConfig(..., report_to="wandb") # or "tensorboard"
```
If you want to log with tensorboard, add the kwarg `project_kwargs={"logging_dir": PATH_TO_LOGS}` to the PPOConfig.
## PPO Logging
Here's a brief explanation for the logged metrics provided in the data:
Key metrics to monitor. We want to maximize the reward, maintain a low KL divergence, and maximize entropy:
1. `env/reward_mean`: The average reward obtained from the environment. Alias `ppo/mean_scores`, which is used to specifically monitor the reward model.
1. `env/reward_std`: The standard deviation of the reward obtained from the environment. Alias `ppo/std_scores`, which is used to specifically monitor the reward model.
1. `env/reward_dist`: The histogram distribution of the reward obtained from the environment.
1. `objective/kl`: The mean Kullback-Leibler (KL) divergence between the old and new policies. It measures how much the new policy deviates from the old policy. The KL divergence is used to compute the KL penalty in the objective function.
1. `objective/kl_dist`: The histogram distribution of the `objective/kl`.
1. `objective/kl_coef`: The coefficient for Kullback-Leibler (KL) divergence in the objective function.
1. `ppo/mean_non_score_reward`: The **KL penalty** term, calculated from `objective/kl * objective/kl_coef`, which is added to the reward during optimization to prevent the new policy from deviating too far from the old policy.
1. `objective/entropy`: The entropy of the model's policy, calculated by `-logprobs.sum(-1).mean()`. High entropy means the model's actions are more random, which can be beneficial for exploration.
Training stats:
1. `ppo/learning_rate`: The learning rate for the PPO algorithm.
1. `ppo/policy/entropy`: The entropy of the model's policy, calculated by `pd = torch.nn.functional.softmax(logits, dim=-1); entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)`. It measures the randomness of the policy.
1. `ppo/policy/clipfrac`: The fraction of probability ratios (old policy / new policy) that fell outside the clipping range in the PPO objective. This can be used to monitor the optimization process.
1. `ppo/policy/approxkl`: The approximate KL divergence between the old and new policies, measured by `0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)`, corresponding to the `k2` estimator in http://joschu.net/blog/kl-approx.html
1. `ppo/policy/policykl`: Similar to `ppo/policy/approxkl`, but measured by `masked_mean(old_logprobs - logprobs, mask)`, corresponding to the `k1` estimator in http://joschu.net/blog/kl-approx.html
1. `ppo/policy/ratio`: The histogram distribution of the ratio between the new and old policies, used to compute the PPO objective.
1. `ppo/policy/advantages_mean`: The average of the GAE (Generalized Advantage Estimation) advantage estimates. The advantage function measures how much better an action is compared to the average action at a state.
1. `ppo/policy/advantages`: The histogram distribution of `ppo/policy/advantages_mean`.
1. `ppo/returns/mean`: The mean of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance. See https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ for more details.
1. `ppo/returns/var`: The variance of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance.
1. `ppo/val/mean`: The mean of the values, used to monitor the value function's performance.
1. `ppo/val/var` : The variance of the values, used to monitor the value function's performance.
1. `ppo/val/var_explained`: The explained variance for the value function, used to monitor the value function's performance.
1. `ppo/val/clipfrac`: The fraction of the value function's predicted values that are clipped.
1. `ppo/val/vpred`: The predicted values from the value function.
1. `ppo/val/error`: The mean squared error between the `ppo/val/vpred` and returns, used to monitor the value function's performance.
1. `ppo/loss/policy`: The policy loss for the Proximal Policy Optimization (PPO) algorithm.
1. `ppo/loss/value`: The loss for the value function in the PPO algorithm. This value quantifies how well the function estimates the expected future rewards.
1. `ppo/loss/total`: The total loss for the PPO algorithm. It is the sum of the policy loss and the value function loss.
Stats on queries, responses, and logprobs:
1. `tokens/queries_len_mean`: The average length of the query tokens.
1. `tokens/queries_len_std`: The standard deviation of the length of the query tokens.
1. `tokens/queries_dist`: The histogram distribution of the length of the query tokens.
1. `tokens/responses_len_mean`: The average length of the response tokens.
1. `tokens/responses_len_std`: The standard deviation of the length of the response tokens.
1. `tokens/responses_dist`: The histogram distribution of the length of the response tokens. (Costa: inconsistent naming, should be `tokens/responses_len_dist`)
1. `objective/logprobs`: The histogram distribution of the log probabilities of the actions taken by the model.
1. `objective/ref_logprobs`: The histogram distribution of the log probabilities of the actions taken by the reference model.
### Crucial values
During training, many values are logged; here are the most important ones:
1. `env/reward_mean`,`env/reward_std`, `env/reward_dist`: the properties of the reward distribution from the "environment" / reward model
1. `ppo/mean_non_score_reward`: The mean negated KL penalty during training (shows the delta between the reference model and the new policy over the batch in the step)
Here are some parameters that are useful to monitor for stability (when these diverge or collapse to 0, try tuning variables):
1. `ppo/loss/value`: it will spike / NaN when not going well.
1. `ppo/policy/ratio`: `ratio` being 1 is a baseline value, meaning that the probability of sampling a token is the same under the new and old policy. If the ratio is too high like 200, it means the probability of sampling a token is 200 times higher under the new policy than the old policy. This is a sign that the new policy is too different from the old policy, which will likely cause overoptimization and collapse training later on.
1. `ppo/policy/clipfrac` and `ppo/policy/approxkl`: if `ratio` is too high, the `ratio` is going to get clipped, resulting in high `clipfrac` and high `approxkl` as well.
1. `objective/kl`: it should stay positive so that the policy is not too far away from the reference policy.
1. `objective/kl_coef`: The coefficient for the KL penalty, adapted by the [`AdaptiveKLController`]. It often increases before numerical instabilities appear.


@ -0,0 +1,144 @@
# Examples of using peft with trl to finetune 8-bit models with Low Rank Adaptation (LoRA)
The notebooks and scripts in these examples show how to use Low Rank Adaptation (LoRA) to fine-tune models in a memory-efficient manner. Most PEFT methods supported in the `peft` library are compatible, but note that some, such as prompt tuning, are not supported.
For more information on LoRA, see the [original paper](https://huggingface.co/papers/2106.09685).
Here's an overview of the `peft`-enabled notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples):
| File | Task | Description | Colab link |
|---|---|---|---|
| [`stack_llama/rl_training.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py) | RLHF | Distributed fine-tuning of the 7b parameter LLaMA models with a learned reward model and `peft`. | |
| [`stack_llama/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py) | Reward Modeling | Distributed training of the 7b parameter LLaMA reward model with `peft`. | |
| [`stack_llama/supervised_finetuning.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py) | SFT | Distributed instruction/supervised fine-tuning of the 7b parameter LLaMA model with `peft`. | |
## Installation
Note: `peft` is under active development, so we install directly from its GitHub repository.
`peft` also relies on the latest version of `transformers`.
```bash
pip install trl[peft]
pip install bitsandbytes loralib
pip install git+https://github.com/huggingface/transformers.git@main
#optional: wandb
pip install wandb
```
Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
## How to use it?
Simply declare a `PeftConfig` object in your script and pass it through `.from_pretrained` to load the TRL+PEFT model.
```python
from peft import LoraConfig
from trl import AutoModelForCausalLMWithValueHead
model_id = "edbeeching/gpt-neo-125M-imdb"
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
model_id,
peft_config=lora_config,
)
```
And if you want to load your model in 8bit precision:
```python
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
config.model_name,
load_in_8bit=True,
peft_config=lora_config,
)
```
... or in 4bit precision:
```python
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
config.model_name,
peft_config=lora_config,
load_in_4bit=True,
)
```
## Launch scripts
The `trl` library is powered by `accelerate`. As such it is best to configure and launch trainings with the following commands:
```bash
accelerate config # will prompt you to define the training configuration
accelerate launch examples/scripts/ppo.py --use_peft # launches training
```
## Using `trl` + `peft` and Data Parallelism
You can scale up to as many GPUs as you want, as long as you are able to fit the training process in a single device. The only tweak you need to apply is to load the model as follows:
```python
from peft import LoraConfig
...
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
config.model_name,
peft_config=lora_config,
)
```
And if you want to load your model in 8bit precision:
```python
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
config.model_name,
peft_config=lora_config,
load_in_8bit=True,
)
```
... or in 4bit precision:
```python
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
config.model_name,
peft_config=lora_config,
load_in_4bit=True,
)
```
Finally, make sure that the rewards are computed on the correct device as well; for that, you can use `ppo_trainer.model.current_device`.
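For instance, a minimal sketch (here `batch` stands for a hypothetical dict of tokenized reward-model inputs):
```python
def move_to_current_device(ppo_trainer, batch):
    """Move a dict of tensors onto the device of the PPO model before computing rewards."""
    device = ppo_trainer.model.current_device
    return {k: v.to(device) for k, v in batch.items()}
```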
## Naive pipeline parallelism (NPP) for large models (>60B models)
The `trl` library also supports naive pipeline parallelism (NPP) for large models (>60B). NPP is a simple way to parallelize the model across multiple GPUs: the model and the adapters are loaded across multiple GPUs, and the activations and gradients are naively communicated across the GPUs. This supports `int8` models as well as other `dtype` models.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-npp.png">
</div>
### How to use NPP?
Simply load your model with a custom `device_map` argument on the `from_pretrained` to split your model across multiple devices. Check out this [nice tutorial](https://github.com/huggingface/blog/blob/main/accelerate-large-models.md) on how to properly create a `device_map` for your model.
Also make sure to have the `lm_head` module on the first GPU device, as an error may be thrown otherwise. At the time of writing, you need to install the `main` branch of `accelerate`: `pip install git+https://github.com/huggingface/accelerate.git@main` and `peft`: `pip install git+https://github.com/huggingface/peft.git@main`.
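As a rough sketch of the loading step (the model id below is a placeholder), you can let `accelerate` infer the split with `device_map="auto"`, or pass an explicit dict mapping module names to device indices if you need to pin `lm_head` to the first GPU:
```python
from peft import LoraConfig
from trl import AutoModelForCausalLMWithValueHead

lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")

# device_map="auto" asks accelerate to spread the layers over all visible GPUs;
# a custom {module_name: device_index} dict gives full control over the placement.
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "huggyllama/llama-65b",  # placeholder model id
    device_map="auto",
    peft_config=lora_config,
)
```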
### Launch scripts
Although the `trl` library is powered by `accelerate`, you should run your training script in a single process for NPP. Note that we do not support Data Parallelism together with NPP yet.
```bash
python PATH_TO_SCRIPT
```
## Fine-tuning Llama-2 model
You can easily fine-tune the Llama 2 model using `SFTTrainer` and the official script! For example, to fine-tune llama2-7b on the Guanaco dataset, run (tested on a single NVIDIA T4-16GB):
```bash
python trl/scripts/sft.py --output_dir sft_openassistant-guanaco --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --per_device_train_batch_size 4 --gradient_accumulation_steps 2
```


@ -0,0 +1,100 @@
# Multi Adapter RL (MARL) - a single base model for everything
Here we present an approach that uses a single base model for the entire PPO algorithm - which includes retrieving the reference logits, computing the active logits and the rewards. This feature is experimental as we did not test the convergence of the approach. We encourage the community to let us know if they potentially face issues.
## Requirements
You just need to install `peft` and optionally install `bitsandbytes` as well if you want to go for 8bit base models, for more memory efficient finetuning.
## Summary
This approach proceeds in three stages, which we summarize as follows:
1- Train a base model on the target domain (e.g. [IMDB dataset](https://huggingface.co/datasets/stanfordnlp/imdb)) - this is the Supervised Fine Tuning stage - it can leverage the `SFTTrainer` from TRL.
2- Train a reward model using `peft`. This is required in order to re-use the adapter during the RL optimisation process (step 3 below). We show an example of leveraging the `RewardTrainer` from TRL in [this example](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py)
3- Fine tune new adapters on the base model using PPO and the reward adapter. ("0 abstraction RL")
Make sure to use the same model (i.e. same architecture and same weights) for stages 2 & 3.
## Quickstart
Let us assume you have trained your reward adapter on the `llama-7b` model using the `RewardTrainer` and pushed the weights to the Hub under `trl-lib/llama-7b-hh-rm-adapter`.
When doing PPO, before passing the model to `PPOTrainer`, create your model as follows:
```python
from peft import LoraConfig
from trl import AutoModelForCausalLMWithValueHead, PPOTrainer

model_name = "huggyllama/llama-7b"
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"
# PPO adapter
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
model_name,
peft_config=lora_config,
reward_adapter=rm_adapter_id,
)
...
trainer = PPOTrainer(
model=model,
...
)
...
```
Then inside your PPO training loop, call the `compute_reward_score` method by accessing the `model` attribute from `PPOTrainer`.
```python
rewards = trainer.model.compute_reward_score(**inputs)
```
## Advanced usage
### Control on the adapter name
If you are familiar with the `peft` library, you know that you can use multiple adapters inside the same model. What you can do is train multiple adapters on the same base model to fine-tune on different policies.
In this case, you want to be able to control which adapter to reactivate after retrieving the reward. For that, simply pass the appropriate `adapter_name` to the `ppo_adapter_name` argument when calling `compute_reward_score`.
```python
adapter_name_policy_1 = "policy_1"
rewards = trainer.model.compute_reward_score(**inputs, ppo_adapter_name=adapter_name_policy_1)
...
```
### Using 4-bit and 8-bit base models
For more memory efficient fine-tuning, you can load your base model in 8-bit or 4-bit while keeping the adapters in the default precision (float32).
Just pass the appropriate arguments (i.e. `load_in_8bit=True` or `load_in_4bit=True`) to `AutoModelForCausalLMWithValueHead.from_pretrained` as follows (assuming you have installed `bitsandbytes`):
```python
model_name = "llama-7b"
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"
# PPO adapter
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
model_name,
peft_config=lora_config,
reward_adapter=rm_adapter_id,
load_in_8bit=True,
)
...
trainer = PPOTrainer(
model=model,
...
)
...
```


@ -0,0 +1,159 @@
# Nash-MD Trainer
[![](https://img.shields.io/badge/All_models-Nash--MD-blue)](https://huggingface.co/models?other=nash-md,trl)
## Overview
Nash-MD was proposed in the paper [Nash Learning from Human Feedback](https://huggingface.co/papers/2312.00886) by Rémi Munos, [Michal Valko](https://huggingface.co/misovalko), Daniele Calandriello, Mohammad Gheshlaghi Azar, Mark Rowland, Daniel Guo, Yunhao Tang, Matthieu Geist, Thomas Mésnard, and Andrea Michi.
The abstract from the paper is the following:
> Reinforcement learning from human feedback (RLHF) has emerged as the main paradigm for aligning large language models (LLMs) with human preferences. Typically, RLHF involves the initial step of learning a reward model from human feedback, often expressed as preferences between pairs of text generations produced by a pre-trained LLM. Subsequently, the LLM's policy is fine-tuned by optimizing it to maximize the reward model through a reinforcement learning algorithm. However, an inherent limitation of current reward models is their inability to fully represent the richness of human preferences and their dependency on the sampling distribution. In this study, we introduce an alternative pipeline for the fine-tuning of LLMs using pairwise human feedback. Our approach entails the initial learning of a preference model, which is conditioned on two inputs given a prompt, followed by the pursuit of a policy that consistently generates responses preferred over those generated by any competing policy, thus defining the Nash equilibrium of this preference model. We term this approach Nash learning from human feedback (NLHF). In the context of a tabular policy representation, we present a novel algorithmic solution, Nash-MD, founded on the principles of mirror descent. This algorithm produces a sequence of policies, with the last iteration converging to the regularized Nash equilibrium. Additionally, we explore parametric representations of policies and introduce gradient descent algorithms for deep-learning architectures. To demonstrate the effectiveness of our approach, we present experimental results involving the fine-tuning of a LLM for a text summarization task. We believe NLHF offers a compelling avenue for preference learning and policy optimization with the potential of advancing the field of aligning LLMs with human preferences.
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif) and [Daniil Tiapkin](https://huggingface.co/dtiapkin), [Pierre Ménard](https://huggingface.co/menardprr), Daniele Calandriello and [Quentin Gallouédec](https://huggingface.co/qgallouedec).
## Quick start
This example demonstrates how to train a model using the Nash-MD method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:
<iframe
src="https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>
Below is the script to train the model:
```python
# train_nash_md.py
from datasets import load_dataset
from trl import NashMDConfig, NashMDTrainer, PairRMJudge
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
training_args = NashMDConfig(output_dir="Qwen2-0.5B-NashMD", logging_steps=10)
trainer = NashMDTrainer(
model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
)
trainer.train()
```
Execute the script using the following command:
```bash
accelerate launch train_nash_md.py
```
Distributed across 8 GPUs, the training takes approximately 3 hours.
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-NashMD) performs, you can use the [TRL Chat CLI](clis#chat-interface).
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-NashMD
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?
<strong><span style="color: blue;">&lt;trl-lib/Qwen2-0.5B-NashMD&gt;:</span></strong>
The best programming language depends on personal preference, the complexity of the project, and the specific requirements of the task. Some programming languages that are often recommended include Python, Java, and JavaScript, and there are many other languages to choose from depending on individual needs.
</code></pre>
## Expected dataset type
Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`NashMDTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
## Usage tips
### Use a reward model
Instead of a judge, you can choose to use a reward model -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use. Below is a code example showing how to replace a judge with the [trl-lib/Qwen2-0.5B-Reward](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) model:
```diff
- from trl import PairRMJudge
+ from transformers import AutoModelForSequenceClassification
- judge = PairRMJudge()
+ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
trainer = NashMDTrainer(
...
- judge=judge,
+ reward_model=reward_model,
)
```
<Tip warning={true}>
Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training.
</Tip>
### Encourage EOS token generation
We may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`NashMDConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`NashMDConfig`]:
```python
training_args = NashMDConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
```
### Logging Completions
To better understand your model's behavior during training, you can log sample completions periodically using the [`LogCompletionsCallback`].
```python
trainer = NashMDTrainer(..., eval_dataset=eval_dataset)
completions_callback = LogCompletionsCallback(trainer, num_prompts=8)
trainer.add_callback(completions_callback)
```
This callback logs the model's generated completions directly to Weights & Biases.
![Logged Completions](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/wandb_completions.png)
## Example script
We provide an example script to train a model using the Nash-MD method. The script is available in [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py)
To test the Nash-MD script with the [Qwen2.5 0.5B model](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback), run the following command:
```bash
python examples/scripts/nash_md.py \
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--judge pair_rm \
--dataset_name trl-lib/ultrafeedback-prompt \
--learning_rate 5.0e-7 \
--logging_steps 25 \
--output_dir Qwen2.5-0.5B-NashMD-PairRM \
--warmup_ratio 0.1 \
--push_to_hub
```
## Logged metrics
The logged metrics are as follows:
* `loss/kl`: The mean KL divergence between the model and reference data.
* `objective/entropy`: The mean entropy of the model and reference data.
* `loss/score`: The mean reinforce score loss.
* `rewards/chosen`: The mean scores (according to the reward model) of the model completions.
* `rewards/rejected`: The mean scores (according to the reward model) of the mixture completions.
* `rewards/probabilities`: The mean probability (according to the reward model or judge) of the model completions chosen vs the mixture completion.
* `rewards/accuracies`: The accuracies of the Nash-MD's implicit reward model.
* `rewards/margins`: The mean reward margin (according to reward model) between the chosen and mixture completions.
* `logps/chosen`: The mean log probabilities of the chosen completions.
* `logps/rejected`: The mean log probabilities of the reference completions.
* `val/model_contain_eos_token`: The number of times the model's output contains the EOS token.
* `val/ref_contain_eos_token`: The number of times the mixture's output contains the EOS token.
* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`NashMDConfig`].
* `mixture_coef`: Logit mixture coefficient for the model and reference model. Typically fixed, but can be made dynamic by passing a list to [`NashMDConfig`].
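For example, here is a minimal sketch of making `beta` and `mixture_coef` dynamic by passing lists to the config (the values are purely illustrative):
```python
from trl import NashMDConfig

# A list makes the coefficient change over training instead of staying fixed.
training_args = NashMDConfig(
    output_dir="Qwen2-0.5B-NashMD",
    beta=[0.1, 0.05, 0.03],
    mixture_coef=[0.5, 0.25, 0.125],
)
```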
## NashMDTrainer
[[autodoc]] NashMDTrainer
## NashMDConfig
[[autodoc]] NashMDConfig


@ -0,0 +1,278 @@
# Online DPO Trainer
[![](https://img.shields.io/badge/All_models-Online_DPO-blue)](https://huggingface.co/models?other=online-dpo,trl)
## Overview
Online DPO was proposed in [Direct Language Model Alignment from Online AI Feedback](https://huggingface.co/papers/2402.04792) by Shangmin Guo, Biao Zhang, Tianlin Liu, Tianqi Liu, Misha Khalman, Felipe Llinares, Alexandre Rame, Thomas Mesnard, Yao Zhao, Bilal Piot, Johan Ferret, and Mathieu Blondel.
The abstract from the paper is the following:
> Direct alignment from preferences (DAP) methods, such as DPO, have recently emerged as efficient alternatives to reinforcement learning from human feedback (RLHF), that do not require a separate reward model. However, the preference datasets used in DAP methods are usually collected ahead of training and never updated, thus the feedback is purely offline. Moreover, responses in these datasets are often sampled from a language model distinct from the one being aligned, and since the model evolves over training, the alignment phase is inevitably off-policy. In this study, we posit that online feedback is key and improves DAP methods. Our method, online AI feedback (OAIF), uses an LLM as annotator: on each training iteration, we sample two responses from the current model and prompt the LLM annotator to choose which one is preferred, thus providing online feedback. Despite its simplicity, we demonstrate via human evaluation in several tasks that OAIF outperforms both offline DAP and RLHF methods. We further show that the feedback leveraged in OAIF is easily controllable, via instruction prompts to the LLM annotator.
This post-training method was contributed by [Michael Noukhovitch](https://huggingface.co/mnoukhov), [Shengyi Costa Huang](https://huggingface.co/vwxyzjn), [Quentin Gallouédec](https://huggingface.co/qgallouedec), and [Edward Beeching](https://huggingface.co/edbeeching).
## Quick start
This example demonstrates how to train a model using the online DPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:
<iframe
src="https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>
Below is the script to train the model:
```python
# train_online_dpo.py
from datasets import load_dataset
from trl import OnlineDPOConfig, OnlineDPOTrainer, PairRMJudge
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
training_args = OnlineDPOConfig(output_dir="Qwen2-0.5B-OnlineDPO", logging_steps=10)
trainer = OnlineDPOTrainer(
model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
)
trainer.train()
```
Execute the script using the following command:
```bash
accelerate launch train_online_dpo.py
```
Distributed across 8 GPUs, the training takes approximately 1 hour. You can verify the training progress by checking the reward graph. An increasing trend in both the reward for rejected and chosen completions indicates that the model is improving and generating better responses over time.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/online-dpo-qwen2.png)
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-OnlineDPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-OnlineDPO
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?
<strong><span style="color: blue;">&lt;trl-lib/Qwen2-0.5B-OnlineDPO&gt;:</span></strong>
The best programming language depends on your specific needs and priorities. Some people prefer imperative programming languages (like Haskell or Lisp), while others prefer functional programming languages (like Scala or Python). It's important to consider your work style, programming environment, and project requirements when choosing a programming language.
</code></pre>
## Expected dataset type
Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, that expects [preference dataset](dataset_formats#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
## Usage tips
### Use a reward model
Instead of a judge, you can choose to use a reward model -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use. Below is a code example showing how to replace a judge with the [trl-lib/Qwen2-0.5B-Reward](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) model:
```diff
- from trl import PairRMJudge
+ from transformers import AutoModelForSequenceClassification
- judge = PairRMJudge()
+ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
+ reward_tokenizer = AutoTokenizer.from_pretrained("trl-lib/Qwen2-0.5B-Reward")
trainer = OnlineDPOTrainer(
...
- judge=judge,
+ reward_model=reward_model,
+ reward_processing_class=reward_tokenizer,
...
)
```
### Encourage EOS token generation
When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`OnlineDPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`OnlineDPOConfig`]:
```python
training_args = OnlineDPOConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
```
### Logging Completions
To better understand your model's behavior during training, you can log sample completions periodically using the [`LogCompletionsCallback`].
```python
trainer = OnlineDPOTrainer(..., eval_dataset=eval_dataset)
completions_callback = LogCompletionsCallback(trainer, num_prompts=8)
trainer.add_callback(completions_callback)
```
This callback logs the model's generated completions directly to Weights & Biases.
![Logged Completions](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/wandb_completions.png)
## Example script
We provide an example script to train a model using the online DPO method. The script is available in [`examples/scripts/dpo_online.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_online.py)
To test the online DPO script with the [Qwen2.5 0.5B model](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback), run the following command:
```bash
python examples/scripts/dpo_online.py \
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--judge pair_rm \
--dataset_name trl-lib/ultrafeedback-prompt \
--learning_rate 5.0e-7 \
--logging_steps 25 \
--output_dir Qwen2.5-0.5B-Online-DPO-PairRM \
--warmup_ratio 0.1 \
--push_to_hub
```
## Logged metrics
The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/w4apmsi9)
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current model and reference model.
* `objective/entropy`: The mean entropy of the model, indicating the randomness of the actions chosen by the model.
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
* `objective/rlhf_reward`: The mean RLHF reward, which is `scores - non_score_reward`. The `rlhf_reward` is the ultimate objective of online DPO training. If training works as intended, this metric should keep going up.
* `objective/scores`: The mean scores returned by the reward model.
* `objective/scores_margin`: The mean score margin (according to the external reward model) between the chosen and rejected completions.
* `rewards/chosen`: The mean reward (according to online DPO's implicit reward model) of the chosen completions.
* `rewards/rejected`: The mean reward (according to online DPO's implicit reward model) of the rejected completions.
* `rewards/accuracies`: The accuracies of the online DPO's implicit reward model.
* `rewards/margins`: The mean reward margin (according to online DPO's implicit reward model) between the chosen and rejected completions.
* `logps/chosen`: The mean log probabilities of the chosen completions.
* `logps/rejected`: The mean log probabilities of the rejected completions.
* `val/contain_eos_token`: The fraction of completions which contain an EOS token.
* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`OnlineDPOConfig`].
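For example, here is a minimal sketch of making `beta` dynamic by passing a list to the config (the values are purely illustrative):
```python
from trl import OnlineDPOConfig

# A list makes beta change over training instead of staying fixed.
training_args = OnlineDPOConfig(output_dir="Qwen2-0.5B-OnlineDPO", beta=[0.1, 0.05, 0.03])
```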
## Benchmark experiments
To validate the online DPO implementation works, we ran experiments with the Pythia 1B, 2.8B, and 6.9B models on a single node of 8 x H100s. Here are the commands we used to run the experiments. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
```
# 1B Online DPO experiment
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml \
examples/scripts/dpo_online.py \
--model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
--reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
--dataset_name trl-lib/tldr \
--learning_rate 5.0e-7 \
--output_dir pythia-1b-deduped-tldr-online-dpo \
--beta 0.1 \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 2 \
--num_train_epochs 3 \
--max_new_tokens 53 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0 \
--logging_steps 20 \
--save_steps 0.1 \
--push_to_hub
# 2.8B Online DPO experiment
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
examples/scripts/dpo_online.py \
--model_name_or_path trl-lib/pythia-2.8b-deduped-tldr-sft \
--reward_model_path trl-lib/pythia-2.8b-deduped-tldr-rm \
--dataset_name trl-lib/tldr \
--learning_rate 5.0e-7 \
--output_dir pythia-2.8b-deduped-tldr-online-dpo \
--beta 0.1 \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 2 \
--num_train_epochs 3 \
--max_new_tokens 53 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0 \
--bf16 \
--logging_steps 20 \
--save_steps 0.1 \
--push_to_hub
# 6.9B Online DPO experiment
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
examples/scripts/dpo_online.py \
--model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-sft \
--reward_model_path trl-lib/pythia-6.9b-deduped-tldr-rm \
--dataset_name trl-lib/tldr \
--learning_rate 5.0e-7 \
--output_dir pythia-6.9b-deduped-tldr-online-dpo \
--beta 0.1 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--num_train_epochs 3 \
--max_new_tokens 53 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0 \
--bf16 \
--gradient_checkpointing \
--logging_steps 20 \
--save_steps 0.1 \
--push_to_hub
```
Checkpoints and experiment tracking are available at:
- [🤗 Model checkpoints](https://huggingface.co/collections/trl-lib/online-dpo-66acd3fa38a331a9cd457b07)
- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/reports/Online-DPO-experiments-for-TL-DR-summarisation--Vmlldzo5MTczMDU0)
To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
For more information on how to use judges, see [Judges](judges).
```bash
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 33.00%
python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-sft --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 41.50%
python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-1b-deduped-tldr-online-dpo --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 62.60%
python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-online-dpo --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 74.20%
```
We can then plot the RLHF scaling chart.
```python
import matplotlib.pyplot as plt
results = {
"SFT": {1.0e9: 0.21, 2.8e9: 0.27, 6.9e9: 0.316},
"online-dpo": {1.0e9: 0.542, 2.8e9: 0.746, 6.9e9: 0.796},
"offline-dpo": {1.0e9: 0.422, 2.8e9: 0.517, 6.9e9: 0.701},
}
plt.plot(list(results["SFT"].keys()), list(results["SFT"].values()), label="SFT", marker="o")
plt.plot(list(results["online-dpo"].keys()), list(results["online-dpo"].values()), label="Online-dpo with RM judge", marker="o")
plt.plot(list(results["offline-dpo"].keys()), list(results["offline-dpo"].values()), label="Offline-dpo", marker="o")
plt.axhline(y=0.5, color="black", linestyle="-.", label="Human reference summary")
plt.xscale("log")
plt.xlabel("Model size")
plt.ylabel("Win rate against reference summaries\n(according to GPT-4-0613)")
plt.title("DPO scaling by model size")
plt.legend()
plt.xlim(5e8, 1.2e10)
plt.xticks([1e9, 3e9, 1e10], ["1B", "3B", "10B"])
plt.grid(True, which="both", ls="--", c="0.7")
plt.tight_layout()
plt.show()
```
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/online_dpo_scaling.png)
The online DPO checkpoints achieve increasingly higher win rates as we scale up the model size. This is a good sign that the online DPO implementation is working as intended.
## OnlineDPOTrainer
[[autodoc]] OnlineDPOTrainer
## OnlineDPOConfig
[[autodoc]] OnlineDPOConfig

docs/source/orpo_trainer.md Normal file
@ -0,0 +1,129 @@
# ORPO Trainer
[![](https://img.shields.io/badge/All_models-ORPO-blue)](https://huggingface.co/models?other=orpo,trl) [![](https://img.shields.io/badge/smol_course-Chapter_2-yellow)](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment)
## Overview
Odds Ratio Preference Optimization (ORPO) was introduced in [ORPO: Monolithic Preference Optimization without Reference Model](https://huggingface.co/papers/2403.07691) by [Jiwoo Hong](https://huggingface.co/JW17), [Noah Lee](https://huggingface.co/nlee-208), and [James Thorne](https://huggingface.co/j6mes).
The abstract from the paper is the following:
> While recent preference alignment algorithms for language models have demonstrated promising results, supervised fine-tuning (SFT) remains imperative for achieving successful convergence. In this paper, we study the crucial role of SFT within the context of preference alignment, emphasizing that a minor penalty for the disfavored generation style is sufficient for preference-aligned SFT. Building on this foundation, we introduce a straightforward and innovative reference model-free monolithic odds ratio preference optimization algorithm, ORPO, eliminating the necessity for an additional preference alignment phase. We demonstrate, both empirically and theoretically, that the odds ratio is a sensible choice for contrasting favored and disfavored styles during SFT across the diverse sizes from 125M to 7B. Specifically, fine-tuning Phi-2 (2.7B), Llama-2 (7B), and Mistral (7B) with ORPO on the UltraFeedback alone surpasses the performance of state-of-the-art language models with more than 7B and 13B parameters: achieving up to 12.20% on AlpacaEval_{2.0} (Figure 1), 66.19% on IFEval (instruction-level loose, Table 6), and 7.32 in MT-Bench (Figure 12). We release code and model checkpoints for Mistral-ORPO-alpha (7B) and Mistral-ORPO-beta (7B).
It studies the crucial role of SFT within the context of preference alignment. Using preference data, the method posits that a minor penalty for the disfavored generation, together with a strong adaptation signal toward the chosen response via a simple log odds ratio term appended to the NLL loss, is sufficient for preference-aligned SFT.
ORPO is therefore a reference-model-free preference optimization algorithm that eliminates the need for an additional preference alignment phase, saving compute and memory.
The official code can be found in [xfactlab/orpo](https://github.com/xfactlab/orpo).
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Lewis Tunstall](https://huggingface.co/lewtun) and [Alvaro Bartolome](https://huggingface.co/alvarobartt).
## Quick start
This example demonstrates how to train a model using the ORPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:
<iframe
src="https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>
Below is the script to train the model:
```python
# train_orpo.py
from datasets import load_dataset
from trl import ORPOConfig, ORPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = ORPOConfig(output_dir="Qwen2-0.5B-ORPO", logging_steps=10)
trainer = ORPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```
Execute the script using the following command:
```bash
accelerate launch train_orpo.py
```
Distributed across 8 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/orpo-qwen2-reward-margin.png)
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-ORPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-ORPO
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?
<strong><span style="color: blue;">&lt;trl-lib/Qwen2-0.5B-ORPO&gt;:</span></strong>
It's challenging to determine the best programming language as no one language is perfect, as the complexity of a task and the type of project are significant factors. Some popular languages include Java, Python, JavaScript, and
C++. If you have specific needs or requirements for a specific project, it's important to choose the language that best suits those needs.
Here are some other factors to consider when choosing a programming language for a project:
<strong><span style="color: green;">• Language proficiency:</span></strong> A good programming language is more likely to be easy to understand and use, and will allow developers to collaborate on projects more efficiently.
<strong><span style="color: green;">• Ease of use:</span></strong> There are tools and libraries available to make programming more accessible, so developers should choose a language that can help them get started easier.
<strong><span style="color: green;">• Code readability:</span></strong> A clear and concise codebase should be easy to read and understand, especially when working with large projects.
<strong><span style="color: green;">• Tool and framework support:</span></strong> There are numerous libraries available for Python, Java, and JavaScript, along with tools like IDEs and static code analysis tools.
<strong><span style="color: green;">• Accessibility:</span></strong> Some languages and tools have features that make them more accessible to developers with disabilities, such as support for screen readers.
<strong><span style="color: green;">• Version control:</span></strong> As your projects grow and complexity increases, version control tools can be beneficial for tracking changes.
</code></pre>
## Expected dataset type
ORPO requires a [preference dataset](dataset_formats#preference). The [`ORPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
Although the [`ORPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section.
## Example script
We provide an example script to train a model using the ORPO method. The script is available in [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py)
To test the ORPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
```bash
accelerate launch examples/scripts/orpo.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name trl-lib/ultrafeedback_binarized \
--num_train_epochs 1 \
--logging_steps 25 \
--output_dir Qwen2-0.5B-ORPO
```
## Usage tips
### For Mixture of Experts Models: Enabling the auxiliary loss
MOEs are the most efficient if the load is about equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
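For example, here is a minimal sketch of loading an MoE model with the auxiliary loss enabled (the model id is illustrative; `0.001` is the default coefficient mentioned above):
```python
from transformers import AutoModelForCausalLM

# Extra kwargs not consumed by `from_pretrained` are used to update the model config.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",  # illustrative MoE model id
    output_router_logits=True,
    router_aux_loss_coef=0.001,
)
```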
## Logged metrics
While training and evaluating we record the following reward metrics:
- `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
- `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
- `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
- `log_odds_chosen`: the mean log odds ratio of the chosen responses over the rejected responses
- `log_odds_ratio`: the mean of the `log(sigmoid(log_odds_chosen))`
- `nll_loss`: the mean negative log likelihood loss from the SFT part of the loss over chosen responses
## ORPOTrainer
[[autodoc]] ORPOTrainer
## ORPOConfig
[[autodoc]] ORPOConfig

docs/source/ppo_trainer.md Normal file
@ -0,0 +1,237 @@
# PPO Trainer
[![](https://img.shields.io/badge/All_models-PPO-blue)](https://huggingface.co/models?other=ppo,trl)
TRL supports training LLMs with [Proximal Policy Optimization (PPO)](https://huggingface.co/papers/1707.06347).
References:
- [Fine-Tuning Language Models from Human Preferences](https://github.com/openai/lm-human-preferences)
- [Learning to Summarize from Human Feedback](https://github.com/openai/summarize-from-feedback)
- [The N Implementation Details of RLHF with PPO](https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo)
- [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031)
## Get started
To quickly check that the trainer can run, you can use the following command to train a PPO model with a dummy reward model.
```bash
python examples/scripts/ppo/ppo.py \
--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
--dataset_train_split descriptiveness \
--learning_rate 3e-6 \
--num_ppo_epochs 1 \
--num_mini_batches 1 \
--output_dir models/minimal/ppo \
--per_device_train_batch_size 64 \
--gradient_accumulation_steps 1 \
--total_episodes 10000 \
--model_name_or_path EleutherAI/pythia-1b-deduped \
--missing_eos_penalty 1.0
```
## Explanation of the logged metrics
The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35)
* `eps`: Tracks the number of episodes per second.
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
* `objective/scores`: The mean scores returned by the reward model / environment.
* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
* `loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward.
* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to policy/clipfrac_avg but for the value function.
* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
* `lr`: The current learning rate used by the optimizer.
* `episode`: The current global step or episode count in the training process.
## Cookbook
* Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high, like 2.0 or 1000.0, or too small, like 0.1, it means the updates between consecutive policies are too drastic. You should try to understand why this is happening and try to fix it.
* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
* Usage TIP: We recommend using the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions. A minimal configuration sketch follows this list.
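A minimal sketch combining these tips (the values are illustrative; see [`PPOConfig`] for the full list of options):
```python
from trl import PPOConfig

training_args = PPOConfig(
    output_dir="models/minimal/ppo",
    per_device_train_batch_size=8,   # lower this if you run out of memory
    gradient_accumulation_steps=8,   # raise this to keep the effective batch size constant
    missing_eos_penalty=1.0,         # the "EOS trick": penalize completions that never emit an EOS token
)
```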
## What is my model doing exactly?
To help you understand what your model is doing, we periodically log sample completions from the model. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35), they look like the following, allowing you to see the model's responses at different stages of training. By default we generate `--num_sample_generations 10` completions during training, but you can customize the number of generations.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/ppov2_completions.gif?download=true)
In the logs the sampled generations look like
```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
┃ query ┃ model response ┃ score ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
│ SUBREDDIT: r/AskReddit │ I'm in love with a friend, and │ 3.921875 │
│ │ I don't know how to get rid of │ │
│ TITLE: How do you get someone │ those feelings. I'm │ │
│ out of your head? │ desperate.<|endoftext|>[PAD][P… │ │
│ │ │ │
│ POST: Hi, │ │ │
│ I'm 22, and I have been with my │ │ │
│ girlfriend for 5 years now. We │ │ │
│ recently moved together. We've │ │ │
│ always loved each other │ │ │
│ intensely. │ │ │
│ │ │ │
│ Problem, I recently started to │ │ │
│ have feelings for an other │ │ │
│ person (a friend). This person │ │ │
│ has had a boyfriend for now 3 │ │ │
│ years, and has absolutely no │ │ │
│ ideas. Those feelings were so │ │ │
│ strong, it was hard to hide │ │ │
│ them. After 2 months of me │ │ │
│ being distant and really sad, │ │ │
│ my girlfriend forced me to say │ │ │
│ what was bothering me. I'm not │ │ │
│ a good liar, and now she knows. │ │ │
│ │ │ │
│ We decided to give us a week │ │ │
│ alone, I went to my parents. │ │ │
│ │ │ │
│ Now, I'm completely lost. I │ │ │
│ keep on thinking about this │ │ │
│ person, and I hate that. I │ │ │
│ would like for those feelings │ │ │
│ to go away, to leave me alone. │ │ │
│ But I can't. │ │ │
│ │ │ │
│ What do I do? It's been 3 │ │ │
│ months now, and I'm just │ │ │
│ desperate. │ │ │
│ │ │ │
│ TL;DR: │ │ │
├─────────────────────────────────┼─────────────────────────────────┼──────────┤
│ SUBREDDIT: r/pettyrevenge │ My mom woke me up with a loud │ 6.84375 │
│ │ TV. I blasted Gangnam Style on │ │
│ TITLE: So, my mom woke me up │ repeat, with the bass cranked │ │
│ with a loud TV. │ up as high as it could │ │
│ │ go.<|endoftext|>[PAD][PAD][PAD… │ │
│ POST: She was in her living │ │ │
│ room, watching TV. This was at │ │ │
│ about 8:30 in the morning, and │ │ │
│ she was exercising. She turned │ │ │
│ the TV up extra loud to hear it │ │ │
│ over her excercycle, and woke │ │ │
│ me up. I went in there asking │ │ │
│ for her to turn it down. She │ │ │
│ said she didn't have to; I │ │ │
│ explained that I always used │ │ │
│ headphones so she didn't have │ │ │
│ to deal with my noise and that │ │ │
│ she should give me a little │ │ │
│ more respect, given that I paid │ │ │
│ rent at the time. │ │ │
│ │ │ │
│ She disagreed. I went back to │ │ │
│ my room, rather pissed off at │ │ │
│ the lack of equality. I had no │ │ │
│ lock on my door; but I had a │ │ │
│ dresser right next to it, so I │ │ │
│ pulled one of the drawers out │ │ │
│ enough so that it caused the │ │ │
│ door to not be openable. Then, │ │ │
│ I turned my speakers up really │ │ │
│ loud and blasted Gangnam Style │ │ │
│ on repeat, with the bass │ │ │
│ cranked up as high as it could │ │ │
│ go. │ │ │
│ │ │ │
│ If you hate Gangnam Style for │ │ │
│ being overplayed, you will see │ │ │
│ why I chose that particular │ │ │
│ song. I personally don't mind │ │ │
│ it. But here's the thing about │ │ │
│ my bass; it vibrates the walls, │ │ │
│ making one hell of a lot of │ │ │
│ noise. Needless to say, my mom │ │ │
│ was not pleased and shut off │ │ │
│ the internet. But it was oh so │ │ │
│ worth it. │ │ │
│ │ │ │
│ TL;DR: │ │ │
└─────────────────────────────────┴─────────────────────────────────┴──────────┘
```
## Implementation details
This PPO implementation is based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
## Benchmark experiments
To validate that the PPO implementation works, we ran experiments on the 1B model. Here is the command we used to run the experiment. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
```bash
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
examples/scripts/ppo/ppo_tldr.py \
--output_dir models/minimal/ppo_tldr \
--learning_rate 3e-6 \
--per_device_train_batch_size 16 \
--gradient_accumulation_steps 4 \
--total_episodes 1000000 \
--model_name_or_path EleutherAI/pythia-1b-deduped \
--sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--local_rollout_forward_batch_size 16 \
--missing_eos_penalty 1.0 \
--stop_token eos
```
Checkpoints and experiment tracking are available at:
- [🤗 Model checkpoint](https://huggingface.co/vwxyzjn/ppo_tldr)
- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/runs/dd2o3g35)
To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
For more information on how to use judges, see [Judges](judges).
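As a rough sketch of what such a pairwise judge looks like from Python (assuming the `openai` package is installed and an API key is configured; the prompt and completions are made up):
```python
from trl import OpenAIPairwiseJudge

judge = OpenAIPairwiseJudge(model="gpt-4o-mini")
prompts = ["What is the capital of France?"]
completions = [["Paris is the capital of France.", "It might be Lyon, I think."]]
print(judge.judge(prompts, completions))  # e.g. [0]: the first completion is preferred
```
The commands below run this kind of judging end to end with the TL;DR evaluation script: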
```bash
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 33.00%
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 64.70%
```
The PPO checkpoint gets a 64.7% preferred rate vs. the 33.0% preferred rate of the SFT checkpoint. This is a good sign that the PPO training is working as intended.
Metrics:
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/pr-1540/ppov2.png)
```bash
# pip install openrlbenchmark==0.2.1a5
# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation
# to use it, change `?we=huggingface&wpn=trl` to your own project and `?tag=pr-1540` to your own tag
python -m openrlbenchmark.rlops_multi_metrics \
--filters '?we=huggingface&wpn=trl&xaxis=train/episode&ceik=output_dir&cen=sft_model_path&metrics=train/objective/rlhf_reward&metrics=train/objective/scores&metrics=train/objective/kl&metrics=train/objective/non_score_reward&metrics=train/objective/entropy&metrics=train/policy/approxkl_avg&metrics=train/policy/clipfrac_avg&metrics=train/loss/policy_avg&metrics=train/loss/value_avg&metrics=train/val/clipfrac_avg&metrics=train/policy/entropy_avg&metrics=train/val/ratio&metrics=train/val/ratio_var&metrics=train/val/num_eos_tokens&metrics=train/lr&metrics=train/eps' \
"cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr?tag=pr-1540" \
--env-ids models/minimal/ppo_tldr \
--pc.ncols 4 \
--pc.ncols-legend 1 \
--pc.xlabel "Episode" \
--output-filename benchmark/trl/pr-1540/ppo \
--scan-history
```
## PPOTrainer
[[autodoc]] PPOTrainer
## PPOConfig
[[autodoc]] PPOConfig

docs/source/prm_trainer.mdx (new file, 123 lines)
@@ -0,0 +1,123 @@
# PRM Trainer
<Tip warning={true}>
PRM Trainer is an experimental API which is subject to change at any time.
</Tip>
## Overview
Process-supervised Reward Models (PRM) were proposed in [Solving math word problems with process- and outcome-based feedback](https://huggingface.co/papers/2211.14275) by Jonathan Uesato, Nate Kushman, Ramana Kumar, Francis Song, Noah Siegel, Lisa Wang, Antonia Creswell, Geoffrey Irving, and Irina Higgins.
The abstract from the paper is the following:
> Recent work has shown that asking language models to generate reasoning steps improves performance on many reasoning tasks. When moving beyond prompting, this raises the question of how we should supervise such models: outcome-based approaches which supervise the final result, or process-based approaches which supervise the reasoning process itself? Differences between these approaches might naturally be expected not just in final-answer errors but also in reasoning errors, which can be difficult to detect and are problematic in many real-world domains such as education. We run the first comprehensive comparison between process- and outcome-based approaches trained on a natural language task, GSM8K. We find that pure outcome-based supervision produces similar final-answer error rates with less label supervision. However, for correct reasoning steps we find it necessary to use process-based supervision or supervision from learned reward models that emulate process-based feedback. In total, we improve the previous best results from 16.8% → 12.7% final-answer error and 14.0% → 3.4% reasoning error among final-answer-correct solutions.
This post-training method was contributed by [Gaetan Lopez](https://github.com/gaetanlop), [Lewis Tunstall](https://huggingface.co/lewtun), [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Agustín Piqueres](https://huggingface.co/plaguss).
## Quick start
This example demonstrates how to train a model using the PRM method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) as the base model. We use the stepwise supervision data from the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd). You can view the data in the dataset here:
<iframe
src="https://huggingface.co/datasets/trl-lib/math_shepherd/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>
Below is the script to train the model:
```python
# train_prm.py
from datasets import load_dataset
from trl import PRMConfig, PRMTrainer
from transformers import AutoModelForTokenClassification, AutoTokenizer
model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
train_dataset = load_dataset("trl-lib/math_shepherd", split="train[:10%]")
training_args = PRMConfig(output_dir="Qwen2-0.5B-Reward-Math-Sheperd", logging_steps=10)
trainer = PRMTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```
Execute the script using the following command:
```bash
accelerate launch train_prm.py
```
Distributed across 8 GPUs, the training takes approximately 1 hour.
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward-Math-Sheperd) performs, you can use the following script.
```python
from datasets import load_dataset
from transformers import pipeline
pipe = pipeline("token-classification", model="trl-lib/Qwen2-0.5B-Reward-Math-Sheperd")
dataset = load_dataset("trl-lib/math_shepherd")
example = {
"prompt": "Musa is the class teacher of a class of 45 students. He wants to split them into three groups by age. If a third of the class is under 11 years, and two-fifths are above 11 but under 13, how many students will be in the third group (13 years and above)?",
"completions": [
"Step 1: A third of the class is under 11 years because 11 - 1/3 = <<11-1/3=7>>7.",
"Step 2: Two-fifths of the class are above 11 but under 13 because 2/5 * 11 = <<2/5*11=8>>8.",
"Step 3: There are 45 students, so the third group will have 45 - 7 - 8 = <<45-7-8=20>>20 students. The answer is: 20",
],
"labels": [True, False, False],
}
separator = "\n" # It's important to use the same separator as the one used during training
for idx in range(1, len(example["completions"]) + 1):
    steps = example["completions"][0:idx]
    text = separator.join((example["prompt"], *steps)) + separator  # add a separator between the prompt and each step
    pred_entity = pipe(text)[-1]["entity"]
    pred = {"LABEL_0": False, "LABEL_1": True}[pred_entity]
    label = example["labels"][idx - 1]
    print(f"Step {idx}\tPredicted: {pred} \tLabel: {label}")
```
```text
Step 1 Predicted: True Label: True
Step 2 Predicted: False Label: False
Step 3 Predicted: False Label: False
```
It's a win!
## Expected dataset type
PRM requires a [stepwise supervision](dataset_formats#stepwise-supervision) dataset.
The dataset should contain the following columns: `prompt`, `completions` and `labels`, where `completions` contains a list of reasoning steps and `labels` a list of booleans or floats indicating the correctness of each step.
The [`PRMTrainer`] only supports [standard](dataset_formats#standard) dataset format.
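For illustration, a single stepwise-supervision record might look like the following (the content is made up):
```python
example = {
    "prompt": "Which number is larger, 9.8 or 9.11?",
    "completions": [
        "The integer parts are equal, so compare the fractional parts.",
        "0.11 is greater than 0.8, so 9.11 > 9.8.",
    ],
    "labels": [True, False],  # one correctness flag per reasoning step
}
```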
## Example script
We provide an example script to train a model using the PRM method. The script is available in [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py)
To use the PRM script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) on the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd), run the following command:
```bash
accelerate launch examples/scripts/prm.py \
--model_name_or_path Qwen/Qwen2-0.5B \
--dataset_name trl-lib/math_shepherd \
--num_train_epochs 1 \
--logging_steps 25 \
--output_dir Qwen2-0.5B-Reward-Math-Sheperd
```
## PRMTrainer
[[autodoc]] PRMTrainer
## PRMConfig
[[autodoc]] PRMConfig

@@ -19,30 +19,40 @@ The following code illustrates the steps above.
# 0. imports
import torch
from transformers import GPT2Tokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
# 2. initialize trainer
ppo_config = {'batch_size': 1}
ppo_config = {"mini_batch_size": 1, "batch_size": 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)
# 3. encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")
query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device)
# 4. generate model response
response_tensor = respond_to_batch(model, query_tensor)
response_txt = tokenizer.decode(response_tensor[0,:])
generation_kwargs = {
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": tokenizer.eos_token_id,
"max_new_tokens": 20,
}
response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs)
response_txt = tokenizer.decode(response_tensor[0])
# 5. define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0)]
reward = [torch.tensor(1.0, device=model.pretrained_model.device)]
# 6. train model with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

@@ -0,0 +1,90 @@
# Reward Modeling
[![](https://img.shields.io/badge/All_models-Reward_Trainer-blue)](https://huggingface.co/models?other=reward-trainer,trl)
TRL supports custom reward modeling for anyone to perform reward modeling on their dataset and model.
Check out a complete flexible example at [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py).
## Expected dataset type
The [`RewardTrainer`] requires an [*implicit prompt* preference dataset](dataset_formats#preference), meaning the dataset should only contain the columns `"chosen"` and `"rejected"` (and not `"prompt"`).
The [`RewardTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
You can also use a pretokenized dataset, in which case the dataset should contain the following columns: `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`.
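For illustration, a conversational example with an implicit prompt might look like this (the content is made up):
```python
example = {
    "chosen": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."},
    ],
    "rejected": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is green."},
    ],
}
```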
## Using the `RewardTrainer`
After preparing your dataset, you can use the [`RewardTrainer`] in the same way as the `Trainer` class from 🤗 Transformers.
You should pass an `AutoModelForSequenceClassification` model to the [`RewardTrainer`], along with a [`RewardConfig`] which configures the hyperparameters of the training.
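Below is a minimal sketch of such a run; the model, dataset, and hyperparameters are illustrative choices, not requirements:
```python
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardConfig, RewardTrainer

model_id = "Qwen/Qwen2-0.5B-Instruct"
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1)
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:  # make sure padding is defined
    tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id

# implicit-prompt preference dataset: only the "chosen" and "rejected" columns are used
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = RewardConfig(output_dir="Qwen2-0.5B-Reward", per_device_train_batch_size=2)
trainer = RewardTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```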
### Leveraging 🤗 PEFT to train a reward model
Just pass a `peft_config` in the keyword arguments of [`RewardTrainer`], and the trainer should automatically take care of converting the model into a PEFT model!
```python
from peft import LoraConfig, TaskType
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer, RewardConfig
model = AutoModelForSequenceClassification.from_pretrained("gpt2")
peft_config = LoraConfig(
task_type=TaskType.SEQ_CLS,
inference_mode=False,
r=8,
lora_alpha=32,
lora_dropout=0.1,
)
...
trainer = RewardTrainer(
model=model,
args=training_args,
processing_class=tokenizer,
train_dataset=dataset,
peft_config=peft_config,
)
trainer.train()
```
### Adding a margin to the loss
As in the [Llama 2 paper](https://huggingface.co/papers/2307.09288), you can add a margin to the loss by adding a `margin` column to the dataset. The reward collator will automatically pass it through and the loss will be computed accordingly.
```python
def add_margin(row):
    # Assume you have `score_chosen` and `score_rejected` columns that you want to use to compute the margin
    return {'margin': row['score_chosen'] - row['score_rejected']}
dataset = dataset.map(add_margin)
```
### Centering rewards
In many scenarios, it's preferable to ensure that a reward model's output is mean zero. This is often done by first calculating the model's average score and then subtracting it.
[[Eisenstein et al., 2023]](https://huggingface.co/papers/2312.09244) proposed an auxiliary loss function designed to directly learn a centered reward model. This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs:
$$\Big( R(p, r_1) + R(p, r_2) \Big)^2 $$
This auxiliary loss is combined with the main loss function, weighted by the parameter `center_rewards_coefficient` in the [`RewardConfig`]. By default, this feature is deactivated (`center_rewards_coefficient = None`).
```python
training_args = RewardConfig(
center_rewards_coefficient=0.01,
...
)
```
For reference results, please refer to PR [#1932](https://github.com/huggingface/trl/pull/1932).
## RewardTrainer
[[autodoc]] RewardTrainer
## RewardConfig
[[autodoc]] RewardConfig

docs/source/rloo_trainer.md (new file, 279 lines)
@@ -0,0 +1,279 @@
# RLOO Trainer
[![](https://img.shields.io/badge/All_models-RLOO-blue)](https://huggingface.co/models?other=rloo,trl)
TRL supports training LLMs with REINFORCE Leave-One-Out (RLOO). The idea is that instead of using a value function, RLOO generates K completions for each prompt. For each completion, RLOO uses the mean score of the other K-1 completions as a baseline to compute the advantage. RLOO also models the entire completion as a single action, whereas PPO models each token as an action. Note that REINFORCE / A2C is a special case of PPO when the number of PPO epochs is 1 and the number of mini-batches is 1, which is how we implement RLOO in TRL.
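As a toy illustration of the leave-one-out baseline described above (all numbers are made up), consider one prompt with K = 4 scored completions:
```python
# leave-one-out baseline for a single prompt with K = 4 completions (scores are illustrative)
scores = [1.0, 2.0, 3.0, 6.0]
K = len(scores)
advantages = [s - (sum(scores) - s) / (K - 1) for s in scores]
print(advantages)  # approximately [-2.67, -1.33, 0.0, 4.0]
```
The vectorized form used in the trainer is shown under "Implementation details" below.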
References:
- [Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs](https://huggingface.co/papers/2402.14740)
- [A2C is a special case of PPO](https://huggingface.co/papers/2205.09123)
- [Fine-Tuning Language Models from Human Preferences](https://github.com/openai/lm-human-preferences)
- [Learning to Summarize from Human Feedback](https://github.com/openai/summarize-from-feedback)
- [The N Implementation Details of RLHF with PPO](https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo)
- [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031)
## Get started
To verify that the RLOO trainer runs end to end, you can use the following command to train an RLOO model with a dummy reward model.
```bash
python examples/scripts/rloo/rloo.py \
--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
--dataset_train_split descriptiveness \
--learning_rate 3e-6 \
--output_dir models/minimal/rloo \
--per_device_train_batch_size 64 \
--gradient_accumulation_steps 1 \
--total_episodes 10000 \
--model_name_or_path EleutherAI/pythia-14m \
--reward_model_path EleutherAI/pythia-14m \
--missing_eos_penalty 1.0
```
## Explanation of the logged metrics
The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/u2sqci34)
<!-- * `rlhf_reward_var_per_prompt`: calculated by `rlhf_reward.var(0).mean()`. This is the variance of the rewards estimated across the `args.rloo_k` samples. Usually we expect it to go down (cause policy entropy goes down). -->
* `eps`: Tracks the number of episodes per second.
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
* `objective/scores`: The mean scores returned by the reward model / environment.
* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to policy/clipfrac_avg but for the value function.
* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
* `lr`: The current learning rate used by the optimizer.
* `episode`: The current global step or episode count in the training process.
## Cookbook
* Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high, like 2.0 or 1000.0, or too small, like 0.1, it means the updates between consecutive policies are too drastic. You should try to understand why this is happening and try to fix it.
* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
* Usage TIP: We recommend using the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions.
## What is my model doing exactly?
To help you understand what your model is doing, we periodically log sample completions from the model. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/u2sqci34), they look like the following, allowing you to see the model's responses at different stages of training. By default we generate `--num_sample_generations 10` completions during training, but you can customize the number of generations.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/ppov2_completions.gif)
In the logs the sampled generations look like
```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
┃ query ┃ model response ┃ score ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
│ SUBREDDIT: r/AskReddit │ I'm in love with a friend, and │ 3.921875 │
│ │ I don't know how to get rid of │ │
│ TITLE: How do you get someone │ those feelings. I'm │ │
│ out of your head? │ desperate.<|endoftext|>[PAD][P… │ │
│ │ │ │
│ POST: Hi, │ │ │
│ I'm 22, and I have been with my │ │ │
│ girlfriend for 5 years now. We │ │ │
│ recently moved together. We've │ │ │
│ always loved each other │ │ │
│ intensely. │ │ │
│ │ │ │
│ Problem, I recently started to │ │ │
│ have feelings for an other │ │ │
│ person (a friend). This person │ │ │
│ has had a boyfriend for now 3 │ │ │
│ years, and has absolutely no │ │ │
│ ideas. Those feelings were so │ │ │
│ strong, it was hard to hide │ │ │
│ them. After 2 months of me │ │ │
│ being distant and really sad, │ │ │
│ my girlfriend forced me to say │ │ │
│ what was bothering me. I'm not │ │ │
│ a good liar, and now she knows. │ │ │
│ │ │ │
│ We decided to give us a week │ │ │
│ alone, I went to my parents. │ │ │
│ │ │ │
│ Now, I'm completely lost. I │ │ │
│ keep on thinking about this │ │ │
│ person, and I hate that. I │ │ │
│ would like for those feelings │ │ │
│ to go away, to leave me alone. │ │ │
│ But I can't. │ │ │
│ │ │ │
│ What do I do? It's been 3 │ │ │
│ months now, and I'm just │ │ │
│ desperate. │ │ │
│ │ │ │
│ TL;DR: │ │ │
├─────────────────────────────────┼─────────────────────────────────┼──────────┤
│ SUBREDDIT: r/pettyrevenge │ My mom woke me up with a loud │ 6.84375 │
│ │ TV. I blasted Gangnam Style on │ │
│ TITLE: So, my mom woke me up │ repeat, with the bass cranked │ │
│ with a loud TV. │ up as high as it could │ │
│ │ go.<|endoftext|>[PAD][PAD][PAD… │ │
│ POST: She was in her living │ │ │
│ room, watching TV. This was at │ │ │
│ about 8:30 in the morning, and │ │ │
│ she was exercising. She turned │ │ │
│ the TV up extra loud to hear it │ │ │
│ over her excercycle, and woke │ │ │
│ me up. I went in there asking │ │ │
│ for her to turn it down. She │ │ │
│ said she didn't have to; I │ │ │
│ explained that I always used │ │ │
│ headphones so she didn't have │ │ │
│ to deal with my noise and that │ │ │
│ she should give me a little │ │ │
│ more respect, given that I paid │ │ │
│ rent at the time. │ │ │
│ │ │ │
│ She disagreed. I went back to │ │ │
│ my room, rather pissed off at │ │ │
│ the lack of equality. I had no │ │ │
│ lock on my door; but I had a │ │ │
│ dresser right next to it, so I │ │ │
│ pulled one of the drawers out │ │ │
│ enough so that it caused the │ │ │
│ door to not be openable. Then, │ │ │
│ I turned my speakers up really │ │ │
│ loud and blasted Gangnam Style │ │ │
│ on repeat, with the bass │ │ │
│ cranked up as high as it could │ │ │
│ go. │ │ │
│ │ │ │
│ If you hate Gangnam Style for │ │ │
│ being overplayed, you will see │ │ │
│ why I chose that particular │ │ │
│ song. I personally don't mind │ │ │
│ it. But here's the thing about │ │ │
│ my bass; it vibrates the walls, │ │ │
│ making one hell of a lot of │ │ │
│ noise. Needless to say, my mom │ │ │
│ was not pleased and shut off │ │ │
│ the internet. But it was oh so │ │ │
│ worth it. │ │ │
│ │ │ │
│ TL;DR: │ │ │
└─────────────────────────────────┴─────────────────────────────────┴──────────┘
```
## Implementation details
The bulk of RLOOTrainer is based on the PPO implementation, which is based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
Below is a vectorized advantage calculation for RLOO:
```python
import torch


def test_rloo_reward():
    local_batch_size = 3
    rloo_k = 4
    rlhf_reward = torch.tensor([
        1, 2, 3,   # first rlhf reward for three prompts
        2, 3, 4,   # second rlhf reward for three prompts
        5, 6, 7,   # third rlhf reward for three prompts
        8, 9, 10,  # fourth rlhf reward for three prompts
    ]).float()  # here we have 3 prompts which have 4 completions each

    # Loop-based implementation: for each completion, the baseline is the mean reward of the other K-1 completions
    advantages = torch.zeros_like(rlhf_reward)
    for i in range(0, len(advantages), local_batch_size):
        other_response_rlhf_rewards = []
        for j in range(0, len(advantages), local_batch_size):
            if i != j:
                other_response_rlhf_rewards.append(rlhf_reward[j : j + local_batch_size])
        advantages[i : i + local_batch_size] = rlhf_reward[i : i + local_batch_size] - torch.stack(other_response_rlhf_rewards).mean(0)
    assert abs(1 - (2 + 5 + 8) / 3 - advantages[0].item()) < 1e-6  # first rlhf reward for the first prompt
    assert abs(6 - (3 + 2 + 9) / 3 - advantages[7].item()) < 1e-6  # third rlhf reward for the second prompt

    # Vectorized implementation
    rlhf_reward = rlhf_reward.reshape(rloo_k, local_batch_size)
    baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
    vec_advantages = rlhf_reward - baseline
    torch.testing.assert_close(vec_advantages.flatten(), advantages)
```
## Benchmark experiments
To validate that the RLOO implementation works, we ran experiments on the 1B model. Here is the command we used to run the experiment. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
```bash
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
examples/scripts/rloo/rloo_tldr.py \
--output_dir models/minimal/rloo_tldr \
--dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
--dataset_test_split validation \
--num_ppo_epochs 2 \
--num_mini_batches 2 \
--learning_rate 3e-6 \
--per_device_train_batch_size 16 \
--gradient_accumulation_steps 16 \
--total_episodes 1000000 \
--model_name_or_path EleutherAI/pythia-1b-deduped \
--sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--local_rollout_forward_batch_size 16 \
--missing_eos_penalty 1.0 \
--stop_token eos \
--kl_coef 0.03
```
Checkpoints and experiment tracking are available at:
- [🤗 Model checkpoint](https://huggingface.co/vwxyzjn/rloo_tldr)
- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/runs/u2sqci34)
To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
For more information on how to use judges, see [Judges](judges).
```bash
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 33.00%
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 51.20%
```
The RLOO checkpoint gets a 51.2% preferred rate vs. the 33.0% preferred rate of the SFT checkpoint. This is a good sign that the RLOO training is working as intended.
Metrics:
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/pr-1540/rloo.png)
```bash
# pip install openrlbenchmark==0.2.1a5
# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation
# to use it, change `?we=huggingface&wpn=trl` to your own project and `?tag=pr-1540` to your own tag
python -m openrlbenchmark.rlops_multi_metrics \
--filters '?we=huggingface&wpn=trl&xaxis=train/episode&ceik=output_dir&cen=sft_model_path&metrics=train/objective/rlhf_reward&metrics=train/objective/scores&metrics=train/objective/kl&metrics=train/objective/non_score_reward&metrics=train/objective/entropy&metrics=train/policy/approxkl_avg&metrics=train/policy/clipfrac_avg&metrics=train/loss/policy_avg&metrics=train/policy/entropy_avg&metrics=train/val/ratio&metrics=train/val/ratio_var&metrics=train/val/num_eos_tokens&metrics=train/lr&metrics=train/eps' \
"cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr?tag=pr-1540" \
--env-ids models/minimal/rloo_tldr \
--pc.ncols 4 \
--pc.ncols-legend 1 \
--pc.xlabel "Episode" \
--output-filename benchmark/trl/pr-1540/rloo \
--scan-history
```
## RLOOTrainer
[[autodoc]] RLOOTrainer
## RLOOConfig
[[autodoc]] RLOOConfig

@@ -0,0 +1,12 @@
# Scripts Utilities
## ScriptArguments
[[autodoc]] ScriptArguments
## TrlParser
[[autodoc]] TrlParser
- parse_args_and_config
- parse_args_into_dataclasses
- set_defaults_with_config
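For illustration, a hedged sketch of how a training script typically wires these together (pairing with [`SFTConfig`] here is just an example choice):
```python
# sketch_cli.py -- parse script-level and trainer-level arguments, optionally from a --config YAML file
from trl import ScriptArguments, SFTConfig, TrlParser

parser = TrlParser((ScriptArguments, SFTConfig))
script_args, training_args = parser.parse_args_and_config()
print(script_args.dataset_name, training_args.output_dir)
```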

@@ -1,35 +1,36 @@
# Sentiment Examples
# Sentiment Tuning Examples
The notebooks and scripts in these examples show how to fine-tune a model with a sentiment classifier (such as `lvwerra/distilbert-imdb`).
Here's an overview of the notebooks and scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples):
| File | Description | Colab link |
|---|---| --- |
| [`gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb) | Fine-tune GPT2 to generate positive movie reviews. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lvwerra/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb) |
| [`gpt2-sentiment-control.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb) | Fine-tune GPT2 to generate movie reviews with controlled sentiment. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lvwerra/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb) |
| [`gpt2-sentiment.py`](https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt2-sentiment.py) | Same as the notebook, but easier to use in a multi-GPU setup. | x |
| [`t5-sentiment.py`](https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/t5-sentiment.py) | Same as the GPT2 script, but for a Seq2Seq model (T5). | x |
Here's an overview of the notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples):
## Installation
| File | Description |
|------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------|
| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb) | This script shows how to use the `PPOTrainer` to fine-tune a sentiment analysis model using IMDB dataset |
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. |
## Usage
```bash
pip install trl
#optional: wandb
pip install wandb
# 1. run directly
python examples/scripts/ppo.py
# 2. run via `accelerate` (recommended), enabling more features (e.g., multiple GPUs, deepspeed)
accelerate config # will prompt you to define the training configuration
accelerate launch examples/scripts/ppo.py # launches training
# 3. get help text and documentation
python examples/scripts/ppo.py --help
# 4. configure logging with wandb and, say, mini_batch_size=1 and gradient_accumulation_steps=16
python examples/scripts/ppo.py --log_with wandb --mini_batch_size 1 --gradient_accumulation_steps 16
```
Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
## Launch scripts
## A few notes on multi-GPU
The `trl` library is powered by `accelerate`. As such it is best to configure and launch trainings with the following commands:
```bash
accelerate config # will prompt you to define the training configuration
accelerate launch scripts/gpt2-sentiment.py # launches training
```
To run in a multi-GPU setup with DDP (Distributed Data Parallel), change the `device_map` value to `device_map={"": Accelerator().process_index}` and make sure to run your script with `accelerate launch yourscript.py`. If you want to apply naive pipeline parallelism, you can use `device_map="auto"`.
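A minimal sketch of that loading pattern (the checkpoint name is illustrative):
```python
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "lvwerra/gpt2-imdb",                           # illustrative checkpoint
    device_map={"": Accelerator().process_index},  # place the whole model on this process's GPU
)
```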

@@ -1,82 +0,0 @@
# Examples of using peft and trl to finetune 8-bit models with Low Rank Adaptation
The notebooks and scripts in these examples show how to fine-tune a model with a sentiment classifier (such as `lvwerra/distilbert-imdb`).
Here's an overview of the notebooks and scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples):
| File | Description | Colab link |
|---|---| --- |
| [`gpt2-sentiment_peft.py`](https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt2-sentiment_peft.py) | Same as the sentiment analysis example, but learning a low rank adapter on a 8-bit base model | |
| [`cm_finetune_peft_imdb.py`](https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/cm_finetune_peft_imdb.py) | Fine tuning a Low Rank Adapter on a frozen 8-bit model for text generation on the imdb dataset. | |
| [`merge_peft_adapter.py`](https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/merge_peft_adapter.py) | Merging of the adapter layers into the base models weights and storing these on the hub. | |
| [`gpt-neo-20b_sentiment_peft.py`](https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py) | Sentiment fine-tuning of a Low Rank Adapter to create positive reviews. | |
| [`gpt-neo-1b_peft.py`](https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neo-1b-multi-gpu/gpt-neo-1b_peft.py) | Sentiment fine-tuning of a Low Rank Adapter to create positive reviews using 2 GPUs. | |
## Installation
Note: peft is in active development, so we install directly from their github page.
Peft also relies on the latest version of transformers.
```bash
pip install trl[peft]
pip install bitsandbytes loralib
pip install git+https://github.com/huggingface/transformers.git@main
#optional: wandb
pip install wandb
```
Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
## Launch scripts
The `trl` library is powered by `accelerate`. As such it is best to configure and launch trainings with the following commands:
```bash
accelerate config # will prompt you to define the training configuration
accelerate launch scripts/gpt2-sentiment_peft.py # launches training
```
## Using `trl` + `peft` and Data Parallelism
You can scale up to as many GPUs as you want, as long as you are able to fit the training process in a single device. The only tweak you need to apply is to load the model as follows:
```python
from accelerate import Accelerator
...
current_device = Accelerator().process_index
pretrained_model = AutoModelForCausalLM.from_pretrained(
config.model_name, load_in_8bit=True, device_map={"": current_device}
)
```
The reason behind `device_map={"": current_device}` is that when you set `"": device_number`, `accelerate` will place the entire model on device `device_number`. This trick therefore puts the model on the correct device for each process, while the `Accelerator` object from `accelerate` takes care of initializing the distributed setup correctly.
Make sure to initialize your accelerate config by specifying that you are training in a multi-GPU setup (run `accelerate config`), and make sure to launch the training script with `accelerate launch your_script.py`.
Finally, make sure that the rewards are computed on `current_device` as well.
## Naive pipeline parallelism (NPP) for large models (>60B models)
The `trl` library also supports naive pipeline parallelism (NPP) for large models (>60B). In this paradigm, the model and the adapters are loaded across multiple GPUs and the activations and gradients are naively communicated across the GPUs. This supports `int8` models as well as other `dtype` models.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-npp.png">
</div>
### How to use NPP?
Simply load your model with a custom `device_map` argument on the `from_pretrained` to split your model across multiple devices. Check out this [nice tutorial](https://github.com/huggingface/blog/blob/main/accelerate-large-models.md) on how to properly create a `device_map` for your model.
Also make sure to have the `lm_head` module on the first GPU device, as it may throw an error if it is not on the first device. At the time of writing, you need to install the `main` branch of `accelerate` (`pip install git+https://github.com/huggingface/accelerate.git@main`) and `peft` (`pip install git+https://github.com/huggingface/peft.git@main`).
That's all you need to do to use NPP. Check out the [gpt-neo-1b_peft.py](https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neo-1b-multi-gpu/gpt-neo-1b_peft.py) example for a more detailed usage of NPP.
### Launch scripts
Although the `trl` library is powered by `accelerate`, you should run your training script in a single process. Note that we do not support Data Parallelism together with NPP yet.
```bash
python PATH_TO_SCRIPT
```

docs/source/sft_trainer.mdx (new file, 777 lines)
@@ -0,0 +1,777 @@
# Supervised Fine-tuning Trainer
[![](https://img.shields.io/badge/All_models-SFT-blue)](https://huggingface.co/models?other=sft,trl) [![](https://img.shields.io/badge/smol_course-Chapter_1-yellow)](https://github.com/huggingface/smol-course/tree/main/1_instruction_tuning)
Supervised fine-tuning (or SFT for short) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them with a few lines of code on your dataset.
Check out a complete flexible example at [`trl/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/trl/scripts/sft.py).
Experimental support for Vision Language Models is also included in the example [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft_vlm.py).
## Quickstart
If you have a dataset hosted on the 🤗 Hub, you can easily fine-tune your SFT model using [`SFTTrainer`] from TRL. Let us assume your dataset is `imdb`, the text you want to predict is inside the `text` field of the dataset, and you want to fine-tune the `facebook/opt-350m` model.
The following code-snippet takes care of all the data pre-processing and training for you:
```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
dataset = load_dataset("stanfordnlp/imdb", split="train")
training_args = SFTConfig(
max_seq_length=512,
output_dir="/tmp",
)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
args=training_args,
)
trainer.train()
```
Make sure to pass the correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
You can also construct a model outside of the trainer and pass it as follows:
```python
from transformers import AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
dataset = load_dataset("stanfordnlp/imdb", split="train")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
training_args = SFTConfig(output_dir="/tmp")
trainer = SFTTrainer(
model,
train_dataset=dataset,
args=training_args,
)
trainer.train()
```
The above snippets will use the default training arguments from the [`SFTConfig`] class. If you want to modify the defaults pass in your modification to the `SFTConfig` constructor and pass them to the trainer via the `args` argument.
## Advanced usage
### Train on completions only
You can use the `DataCollatorForCompletionOnlyLM` to train your model on the completions only. Note that this works only when `packing=False`.
To instantiate that collator for instruction data, pass a response template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on completions only on the CodeAlpaca dataset:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['instruction'])):
        text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
        output_texts.append(text)
    return output_texts
response_template = " ### Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
trainer = SFTTrainer(
model,
train_dataset=dataset,
args=SFTConfig(output_dir="/tmp"),
formatting_func=formatting_prompts_func,
data_collator=collator,
)
trainer.train()
```
To instantiate that collator for assistant style conversation data, pass a response template, an instruction template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on assistant completions only on the Open Assistant Guanaco dataset:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
instruction_template = "### Human:"
response_template = "### Assistant:"
collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_template, response_template=response_template, tokenizer=tokenizer, mlm=False)
trainer = SFTTrainer(
model,
args=SFTConfig(output_dir="/tmp"),
train_dataset=dataset,
data_collator=collator,
)
trainer.train()
```
Make sure to have a `pad_token_id` that is different from `eos_token_id`; otherwise the model may not properly predict EOS (End of Sentence) tokens during generation.
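One way to do this (a sketch; the choice of pad token string is up to you):
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

# register a dedicated pad token so it no longer coincides with the EOS token
tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id
```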
#### Using token_ids directly for `response_template`
Some tokenizers like Llama 2 (`meta-llama/Llama-2-XXb-hf`) tokenize sequences differently depending on whether they have context or not. For example:
```python
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
def print_tokens_with_ids(txt):
    tokens = tokenizer.tokenize(txt, add_special_tokens=False)
    token_ids = tokenizer.encode(txt, add_special_tokens=False)
    print(list(zip(tokens, token_ids)))
prompt = """### User: Hello\n\n### Assistant: Hi, how can I help you?"""
print_tokens_with_ids(prompt) # [..., ('▁Hello', 15043), ('<0x0A>', 13), ('<0x0A>', 13), ('##', 2277), ('#', 29937), ('▁Ass', 4007), ('istant', 22137), (':', 29901), ...]
response_template = "### Assistant:"
print_tokens_with_ids(response_template) # [('▁###', 835), ('▁Ass', 4007), ('istant', 22137), (':', 29901)]
```
In this case, and due to lack of context in `response_template`, the same string ("### Assistant:") is tokenized differently:
- Text (with context): `[2277, 29937, 4007, 22137, 29901]`
- `response_template` (without context): `[835, 4007, 22137, 29901]`
This will lead to an error when the `DataCollatorForCompletionOnlyLM` does not find the `response_template` in the dataset example text:
```
RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs tensor([ 1, 835, ...])
```
To solve this, you can tokenize the `response_template` with the same context as in the dataset, truncate it as needed and pass the `token_ids` directly to the `response_template` argument of the `DataCollatorForCompletionOnlyLM` class. For example:
```python
response_template_with_context = "\n### Assistant:" # We added context here: "\n". This is enough for this tokenizer
response_template_ids = tokenizer.encode(response_template_with_context, add_special_tokens=False)[2:] # Now we have it like in the dataset texts: `[2277, 29937, 4007, 22137, 29901]`
data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
```
### Add Special Tokens for Chat Format
Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
The [`setup_chat_format`] function in `trl` easily sets up a model and tokenizer for conversational AI tasks. This function:
- Adds special tokens to the tokenizer, e.g. `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation.
- Resizes the model's embedding layer to accommodate the new tokens.
- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI.
- _Optionally_, you can pass `resize_to_multiple_of` to resize the embedding layer to a multiple of that value, e.g. 64. If you want to see more formats supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl)
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format
# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
# Set up the chat format with default 'chatml' format
model, tokenizer = setup_chat_format(model, tokenizer)
```
With our model and tokenizer set up, we can now fine-tune our model on a conversational dataset. Below is an example of how a dataset can be formatted for fine-tuning.
### Dataset format support
The [`SFTTrainer`] supports popular dataset formats. This allows you to pass the dataset to the trainer without any pre-processing directly. The following formats are supported:
* conversational format
```json
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "..."}]}
```
* instruction format
```json
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
```
If your dataset uses one of the above formats, you can directly pass it to the trainer without pre-processing. The [`SFTTrainer`] will then format the dataset for you using the defined format from the model's tokenizer with the [apply_chat_template](https://huggingface.co/docs/transformers/main/en/chat_templating#templates-for-chat-models) method.
```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
...
# load jsonl dataset
dataset = load_dataset("json", data_files="path/to/dataset.jsonl", split="train")
# load dataset from the HuggingFace Hub
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")
...
training_args = SFTConfig(packing=True)
trainer = SFTTrainer(
"facebook/opt-350m",
args=training_args,
train_dataset=dataset,
)
```
If the dataset is not in one of those formats, you can either preprocess the dataset to match the formatting or pass a formatting function to the SFTTrainer to do it for you. Let's have a look.
### Format your input prompts
For instruction fine-tuning, it is quite common to have two columns inside the dataset: one for the prompt & the other for the response.
This allows people to format examples like [Stanford-Alpaca](https://github.com/tatsu-lab/stanford_alpaca) did as follows:
```bash
Below is an instruction ...
### Instruction
{prompt}
### Response:
{completion}
```
Let us assume your dataset has two fields, `question` and `answer`. Therefore you can just run:
```python
...
def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['question'])):
        text = f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}"
        output_texts.append(text)
    return output_texts

trainer = SFTTrainer(
    model,
    args=training_args,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
)
trainer.train()
```
To properly format your input, make sure to process all the examples by looping over them and returning a list of processed texts. Check out a full example of how to use the SFTTrainer on the alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763).
### Packing dataset ([`ConstantLengthDataset`])
[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. This is done with the [`ConstantLengthDataset`] utility class that returns constant length chunks of tokens from a stream of examples. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTConfig`] constructor.
```python
...
training_args = SFTConfig(packing=True)
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    args=training_args
)
trainer.train()
```
Note that if you use a packed dataset and if you pass `max_steps` in the training arguments, you will probably train your model for more than a few epochs, depending on the way you have configured the packed dataset and the training protocol. Double check that you know and understand what you are doing.
If you don't want to pack your `eval_dataset`, you can pass `eval_packing=False` to the `SFTConfig` init method.
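For instance, a configuration that packs the training set but leaves the evaluation set unpacked could look like the sketch below (the `train_dataset` and `eval_dataset` variables are assumed to be loaded beforehand):
```python
from trl import SFTConfig, SFTTrainer

training_args = SFTConfig(packing=True, eval_packing=False)
trainer = SFTTrainer(
    "facebook/opt-350m",
    args=training_args,
    train_dataset=train_dataset,  # packed into constant-length chunks
    eval_dataset=eval_dataset,    # left unpacked
)
```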
#### Customize your prompts using packed dataset
If your dataset has several fields that you want to combine, for example if the dataset has `question` and `answer` fields and you want to combine them, you can pass a formatting function to the trainer that will take care of that. For example:
```python
def formatting_func(example):
    text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
    return text

training_args = SFTConfig(packing=True)
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    args=training_args,
    formatting_func=formatting_func
)
trainer.train()
```
You can also customize the [`ConstantLengthDataset`] much more by directly passing the arguments to the [`SFTConfig`] constructor. Please refer to that class' signature for more information.
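For example, a configuration along the following lines tweaks how the packed chunks are built; this is a sketch, so double-check that arguments such as `max_seq_length`, `num_of_sequences`, and `chars_per_token` exist in the [`SFTConfig`] signature of your TRL version:
```python
training_args = SFTConfig(
    packing=True,
    max_seq_length=1024,    # length of each packed chunk of tokens
    num_of_sequences=1024,  # how many sequences are buffered to build the chunks
    chars_per_token=3.6,    # estimate used to size the character buffer
)
```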
### Control over the pretrained model
You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTConfig`]. For example, if you want to load a model in a different precision, analogous to
```python
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16)
...
training_args = SFTConfig(
    model_init_kwargs={
        "torch_dtype": "bfloat16",
    },
    output_dir="/tmp",
)
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    args=training_args,
)
trainer.train()
```
Note that all keyword arguments of `from_pretrained()` are supported.
### Training adapters
We also support tight integration with the 🤗 PEFT library, so that any user can conveniently train adapters and share them on the Hub instead of training the entire model.
```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig
dataset = load_dataset("trl-lib/Capybara", split="train")
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules="all-linear",
    modules_to_save=["lm_head", "embed_tokens"],
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    "Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
    args=SFTConfig(output_dir="Qwen2.5-0.5B-SFT"),
    peft_config=peft_config
)
trainer.train()
```
> [!WARNING]
> If the chat template contains special tokens like `<|im_start|>` (ChatML) or `<|eot_id|>` (Llama), the embedding layer and LM head must be included in the trainable parameters via the `modules_to_save` argument. Without this, the fine-tuned model will produce unbounded or nonsense generations. If the chat template doesn't contain special tokens (e.g. Alpaca), then the `modules_to_save` argument can be ignored or set to `None`.
You can also continue training your `PeftModel`. For that, first load a `PeftModel` outside [`SFTTrainer`] and pass it directly to the trainer without passing the `peft_config` argument.
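A minimal sketch of that workflow, where `path/to/my/adapter` is a placeholder for your own adapter checkpoint and the dataset is assumed to be loaded as in the previous examples:
```python
from peft import PeftModel
from transformers import AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
# `is_trainable=True` loads the adapter weights so that they can keep being updated
model = PeftModel.from_pretrained(base_model, "path/to/my/adapter", is_trainable=True)

trainer = SFTTrainer(
    model,  # pass the PeftModel directly ...
    train_dataset=dataset,
    args=SFTConfig(output_dir="opt-350m-sft-continued"),
    # ... and do not pass `peft_config`
)
trainer.train()
```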
### Training adapters with base 8 bit models
For that, you need to first load your 8 bit model outside the Trainer and pass a `PeftConfig` to the trainer. For example:
```python
...
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-neo-125m",
    load_in_8bit=True,
    device_map="auto",
)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    args=SFTConfig(),
    peft_config=peft_config,
)
trainer.train()
```
## Using Flash Attention and Flash Attention 2
You can benefit from Flash Attention 1 & 2 using the SFTTrainer out of the box with minimal code changes.
First, to make sure you have all the latest features from transformers, install transformers from source:
```bash
pip install -U git+https://github.com/huggingface/transformers.git
```
Note that Flash Attention currently only works on GPU and in half-precision (when using adapters, the base model is loaded in half-precision).
Note also that both features are perfectly compatible with other tools such as quantization.
### Using Flash-Attention 1
For Flash Attention 1 you can use the `BetterTransformer` API and force-dispatch the API to use the Flash Attention kernel. First, install the latest optimum package:
```bash
pip install -U optimum
```
Once you have loaded your model, wrap the `trainer.train()` call under the `with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):` context manager:
```diff
...
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    trainer.train()
```
Note that you cannot train your model using Flash Attention 1 on an arbitrary dataset, as `torch.nn.functional.scaled_dot_product_attention` does not support training with padding tokens when using Flash Attention kernels. Therefore you can only use that feature with `packing=True`. If your dataset contains padding tokens, consider switching to the Flash Attention 2 integration.
Below are some numbers you can get in terms of speedup and memory efficiency, using Flash Attention 1, on a single NVIDIA-T4 16GB.
| use_flash_attn_1 | model_name | max_seq_len | batch_size | time per training step |
| ---------------- | ----------------- | ----------- | ---------- | ---------------------- |
| x | facebook/opt-350m | 2048 | 8 | ~59.1s |
| | facebook/opt-350m | 2048 | 8 | **OOM** |
| x | facebook/opt-350m | 2048 | 4 | ~30.3s |
| | facebook/opt-350m | 2048 | 4 | ~148.9s |
### Using Flash Attention-2
To use Flash Attention 2, first install the latest `flash-attn` package:
```bash
pip install -U flash-attn
```
And add `attn_implementation="flash_attention_2"` when calling `from_pretrained`:
```python
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    attn_implementation="flash_attention_2"
)
```
If you don't use quantization, make sure your model is loaded in half-precision and dispatch your model on a supported GPU device.
After loading your model, you can either train it as it is, or attach adapters and train adapters on it in case your model is quantized.
In contrast to Flash Attention 1, the integration makes it possible to train your model on an arbitrary dataset that also includes padding tokens.
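For example, without quantization you might load the model roughly like this (a sketch; `model_id` and the device placement are assumptions):
```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # half-precision, as required by Flash Attention 2
    attn_implementation="flash_attention_2",
).to("cuda")
```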
### Using model creation utility
We included a utility function to create your model.
[[autodoc]] ModelConfig
```python
import torch
from transformers import AutoModelForCausalLM
from trl import ModelConfig, SFTTrainer, get_kbit_device_map, get_peft_config, get_quantization_config

model_args = ModelConfig(
    model_name_or_path="facebook/opt-350m",
    attn_implementation=None,  # or "flash_attention_2"
)
torch_dtype = (
    model_args.torch_dtype
    if model_args.torch_dtype in ["auto", None]
    else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
    revision=model_args.model_revision,
    trust_remote_code=model_args.trust_remote_code,
    attn_implementation=model_args.attn_implementation,
    torch_dtype=torch_dtype,
    use_cache=False if training_args.gradient_checkpointing else True,  # `training_args` is your SFTConfig
    device_map=get_kbit_device_map() if quantization_config is not None else None,
    quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
trainer = SFTTrainer(
    ...,
    model=model,  # the instantiated model; you can also pass the model name together with `model_init_kwargs`
    peft_config=get_peft_config(model_args),
)
```
### Enhance the model's performance using NEFTune
NEFTune is a technique to boost the performance of chat models that was introduced in the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://huggingface.co/papers/2310.05914) by Jain et al. It consists of adding noise to the embedding vectors during training. According to the abstract of the paper:
> Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF such as LLaMA-2-Chat benefit from additional training with NEFTune.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/neft-screenshot.png">
</div>
To use it in [`SFTTrainer`], simply pass `neftune_noise_alpha` when creating your [`SFTConfig`] instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to restore the original behaviour of the embedding layer.
```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
dataset = load_dataset("stanfordnlp/imdb", split="train")
training_args = SFTConfig(
    neftune_noise_alpha=5,
)
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    args=training_args,
)
trainer.train()
```
We have tested NEFTune by training `mistralai/Mistral-7B-v0.1` on the [OpenAssistant dataset](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) and validated that using NEFTune led to a performance boost of ~25% on MT Bench.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-neftune-mistral-7b.png">
</div>
Note, however, that the amount of performance gain is _dataset dependent_; in particular, applying NEFTune on synthetic datasets like [UltraChat](https://huggingface.co/datasets/stingning/ultrachat) typically produces smaller gains.
### Accelerate fine-tuning 2x using `unsloth`
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library, which is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek, etc.) and Mistral architectures. Some benchmarks on a single A100 are listed below:
| 1 A100 40GB | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
| --------------- | --------- | --- | --------------------- | --------- | ------------ |
| Code Llama 34b | Slim Orca | 1x | 1.01x | **1.94x** | -22.7% |
| Llama-2 7b | Slim Orca | 1x | 0.96x | **1.87x** | -39.3% |
| Mistral 7b | Slim Orca | 1x | 1.17x | **1.88x** | -65.9% |
| Tiny Llama 1.1b | Alpaca | 1x | 1.55x | **2.74x** | -57.8% |
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
```python
import torch
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel
max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number
# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/mistral-7b",
    max_seq_length=max_seq_length,
    dtype=None,  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
    load_in_4bit=True,  # Use 4bit quantization to reduce memory usage. Can be False
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0,  # Dropout = 0 is currently optimized
    bias="none",  # Bias = "none" is currently optimized
    use_gradient_checkpointing=True,
    random_state=3407,
)

training_args = SFTConfig(output_dir="./output", max_seq_length=max_seq_length)
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```
The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).
## Liger Kernel: Increase throughput by 20% and reduce memory by 60% for multi-GPU training
[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. That way, we can **4x** our context length, as described in the benchmark below. They have implemented Hugging Face compatible `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernels work out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed).
With this memory reduction, you can potentially turn off CPU offloading or gradient checkpointing to further boost performance.
| Speed Up | Memory Reduction |
|--------------------------|-------------------------|
| ![Speed up](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-tps.png) | ![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-memory.png) |
1. To use Liger Kernel in [`SFTTrainer`], first install it with:
```bash
pip install liger-kernel
```
2. Once installed, set `use_liger=True` in [`SFTConfig`]. No other changes are needed!
```python
training_args = SFTConfig(
    use_liger=True
)
```
To learn more about Liger-Kernel, visit their [official repository](https://github.com/linkedin/Liger-Kernel/).
## Best practices
Pay attention to the following best practices when training a model with that trainer:
- [`SFTTrainer`] always pads the sequences by default to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training.
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use the `prepare_in_int8_kwargs` field, or to create the `PeftModel` outside the [`SFTTrainer`] and pass it.
- For more memory-efficient training using adapters, you can load the base model in 8bit; for that, simply add the `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
- If you create a model outside the trainer, make sure not to pass to the trainer any additional keyword arguments that are related to the `from_pretrained()` method.
## Multi-GPU Training
Trainer (and thus SFTTrainer) supports multi-GPU training. If you run your script with `python script.py` it will default to using DP as the strategy, which may be [slower than expected](https://github.com/huggingface/trl/issues/1303). To use DDP (which is generally recommended, see [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many?select-gpu=Accelerate#data-parallelism) for more info) you must launch the script with `python -m torch.distributed.launch script.py` or `accelerate launch script.py`. For DDP to work you must also check the following:
- If you're using gradient_checkpointing, add the following to the TrainingArguments: `gradient_checkpointing_kwargs={'use_reentrant': False}` (more info [here](https://github.com/huggingface/transformers/issues/26969))
- Ensure that the model is placed on the correct device:
```python
from accelerate import PartialState
device_string = PartialState().process_index
model = AutoModelForCausalLM.from_pretrained(
    ...,  # your model name and other `from_pretrained` kwargs
    device_map={'': device_string}
)
```
## GPTQ Conversion
You may experience some issues with GPTQ Quantization after completing training. Lowering `gradient_accumulation_steps` to `4` will resolve most issues during the quantization process to GPTQ format.
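For example (a sketch, keeping the rest of your configuration unchanged):
```python
training_args = SFTConfig(
    gradient_accumulation_steps=4,  # a lower value tends to avoid issues when later quantizing to GPTQ
    output_dir="/tmp",
)
```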
## Extending `SFTTrainer` for Vision Language Models
`SFTTrainer` does not inherently support vision-language data. However, we provide a guide on how to tweak the trainer to support vision-language data. Specifically, you need to use a custom data collator that is compatible with vision-language data. This guide outlines the steps to make these adjustments. For a concrete example, refer to the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) which demonstrates how to fine-tune the LLaVA 1.5 model on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset.
### Preparing the Data
The data format is flexible, provided it is compatible with the custom collator that we will define later. A common approach is to use conversational data. Given that the data includes both text and images, the format needs to be adjusted accordingly. Below is an example of a conversational data format involving both text and images:
```python
images = ["obama.png"]
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Who is this?"},
            {"type": "image"}
        ]
    },
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "Barack Obama"}
        ]
    },
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is he famous for?"}
        ]
    },
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "He is the 44th President of the United States."}
        ]
    }
]
```
To illustrate how this data format will be processed using the LLaVA model, you can use the following code:
```python
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
print(processor.apply_chat_template(messages, tokenize=False))
```
The output will be formatted as follows:
```txt
Who is this? ASSISTANT: Barack Obama USER: What is he famous for? ASSISTANT: He is the 44th President of the United States.
```
<iframe src="https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft/embed/viewer/default/train" frameborder="0" width="100%" height="560px"></iframe>
### A custom collator for processing multi-modal data
Unlike the default behavior of `SFTTrainer`, processing multi-modal data is done on the fly during the data collation process. To do this, you need to define a custom collator that processes both the text and images. This collator must take a list of examples as input (see the previous section for an example of the data format) and return a batch of processed data. Below is an example of such a collator:
```python
def collate_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
    images = [example["images"][0] for example in examples]

    # Tokenize the texts and process the images
    batch = processor(texts, images, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    batch["labels"] = labels

    return batch
```
We can verify that the collator works as expected by running the following code:
```python
from datasets import load_dataset
dataset = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train")
examples = [dataset[0], dataset[1]] # Just two examples for the sake of the example
collated_data = collate_fn(examples)
print(collated_data.keys()) # dict_keys(['input_ids', 'attention_mask', 'pixel_values', 'labels'])
```
### Training the vision-language model
Now that we have prepared the data and defined the collator, we can proceed with training the model. To ensure that the data is not processed as text-only, we need to set a couple of arguments in the [`SFTConfig`], specifically `remove_unused_columns=False` and `dataset_kwargs={"skip_prepare_dataset": True}`, to avoid the default processing of the dataset. Below is an example of how to set up the `SFTTrainer`.
```python
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_dataset,
    processing_class=processor.tokenizer,
)
```
A full example of training LLaVa 1.5 on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset can be found in the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py).
- [Experiment tracking](https://wandb.ai/huggingface/trl/runs/2b2c5l7s)
- [Trained model](https://huggingface.co/HuggingFaceH4/sft-llava-1.5-7b-hf)
## SFTTrainer
[[autodoc]] SFTTrainer
## SFTConfig
[[autodoc]] SFTConfig
## Datasets
In the SFTTrainer, we smartly support `datasets.IterableDataset` in addition to other style datasets. This is useful if you are using large corpora that you do not want to save entirely to disk. The data will be tokenized and processed on the fly, even when packing is enabled.
Additionally, in the SFTTrainer, we support pre-tokenized datasets if they are `datasets.Dataset` or `datasets.IterableDataset`. In other words, if such a dataset has a column of `input_ids`, no further processing (tokenization or packing) will be done, and the dataset will be used as-is. This can be useful if you have pretokenized your dataset outside of this script and want to re-use it directly.
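For example, a streamed dataset can be passed directly to the trainer (a sketch; the dataset name is only illustrative, and `max_steps` is required because an iterable dataset has no length):
```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# `streaming=True` returns a `datasets.IterableDataset`; examples are processed on the fly
dataset = load_dataset("stanfordnlp/imdb", split="train", streaming=True)

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    args=SFTConfig(output_dir="/tmp", max_steps=500),
)
trainer.train()
```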
### ConstantLengthDataset
[[autodoc]] trainer.ConstantLengthDataset
# Summarization Example
The script in this example shows how to train a reward model for summarization, following the OpenAI Learning to Summarize from Human Feedback [paper](https://arxiv.org/abs/2009.01325). We've validated that the script can be used to train a small GPT2 to get slightly over 60% validation accuracy, which is aligned with results from the paper. The model is [here](https://huggingface.co/Tristan/gpt2_reward_summarization).
Here's an overview of the relevant files in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples):
| File | Description |
|---|---|
| `scripts/reward_summarization.py` | For tuning the reward model. |
| `scripts/ds3_reward_summarization_example_config.json` | Can be used with the reward model script to scale it up to arbitrarily big models that don't fit on a single GPU. |
## Installation
```bash
pip install trl
pip install evaluate
# optional: deepspeed
pip install deepspeed
```
```bash
# If you want your reward model to follow the Learning to Summarize from Human Feedback paper closely, then tune a GPT model on summarization and then instantiate the reward model
# with it. In other words, pass in the name of your summarization-finetuned gpt on the hub, instead of the name of the pretrained gpt2 like we do in the following examples of how
# to run this script.
# Example of running this script with the small size gpt2 on a 40GB A100 (A100's support bf16). Here, the global batch size will be 64:
python -m torch.distributed.launch --nproc_per_node=1 reward_summarization.py --bf16
# Example of running this script with the xl size gpt2 on 16 40GB A100's. Here the global batch size will still be 64:
python -m torch.distributed.launch --nproc_per_node=16 reward_summarization.py --per_device_train_batch_size=1 --per_device_eval_batch_size=1 --gradient_accumulation_steps=4 --gpt_model_name=gpt2-xl --bf16 --deepspeed=ds3_reward_summarization_example_config.json
```
# Text Environments
Text environments provide a learning ground for language agents. They allow a language model to use tools to accomplish a task, such as using a Python interpreter to answer math questions or using a search index for trivia questions. Having access to tools allows language models to solve tasks that would be very hard for the model itself but can be trivial for the appropriate tools. A good example is arithmetic with large numbers, which becomes a simple copy-paste task once you have access to a calculator.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv.png">
</div>
Let's dive into how text environments work and start with tools!
## Tools
One of the core building blocks of text environments are the tools the model can use to solve tasks. In general, tools can be any Python function that takes a string as input and returns a string. The `TextEnvironment` offers two options for tools: either go with predefined tools from `transformers.Tool` or define your own function or class with a `__call__` method. Let's have a look at both!
### `transformers.Tool`
Text environments fully support tools of the class `transformers.Tool`. The advantage of building tools in that framework is that they can easily be shared.
```Python
from transformers import load_tool
# simple calculator tool that runs +-/* operations
calc_tool = load_tool("ybelkada/simple-calculator")
# python interpreter that executes program and returns outputs
py_tool = load_tool("lvwerra/python-interpreter")
# wikipedia search index that returns best search match
wiki_tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
```
These tools are either loaded from the hub or from a local folder. Using the tool is as simple as calling them with a text query:
```Python
calc_tool("1/2")
>>> "0.5"
```
Note that both input and return values are strings to enable easy usage with a language model.
### Custom Tools
The following is an example of a tool that adds two integers:
```Python
def add(text):
    int_1, int_2 = text.split("+")
    result = int(int_1) + int(int_2)
    return str(result)

print(add("1+1"))
>>> "2"
```
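The same tool can also be written as a class with a `__call__` method, which is handy if it needs state or configuration; a minimal sketch:
```Python
class Adder:
    """Example class-based tool; anything with a string-in/string-out __call__ works."""

    def __call__(self, text):
        int_1, int_2 = text.split("+")
        return str(int(int_1) + int(int_2))

add_tool = Adder()
print(add_tool("1+1"))
>>> "2"
```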
We looked at basic examples such as a calculator, but the principle holds for more complex tools as well, such as a web search tool where you input the query and get the search results in return. Now let's look at how the model can use the tools with the call syntax.
### Call syntax
In order to have a unified way for the model to call a tool we created a simple syntax that looks as follows:
```python
"<request><TOOL_NAME>QUERY<call>TOOL_RESPONSE<response>"
```
There are a few special tokens involved, so let's decompose it: First, the model can signal that it wants to use a tool by emitting the `<request>` token. After that, we want to know the name of the tool to call, which is done by enclosing the tool name in `<>` brackets. Once we know which tool to call, the tool query follows in free text form. The `<call>` token signifies the end of the query and stops the model generation. At this point the model output is parsed and the query is sent to the tool. The environment appends the tool response to the string, followed by the `<response>` token to mark the end of the tool output.
Let's look at the concrete example of the calculator and assume its name is `Calculator` (more on how the name of a tool is inferred later):
```python
"<request><Calculator>1/2<call>0.5<response>"
```
Finally, the episode is ended and generation stops when the model generates `<submit>` which marks the interaction as completed.
Now let's have a look at how we can create a new text environment!
## Create a `TextEnvironment`
```python
prompt = """\
What is 13-3?
<request><SimpleCalculatorTool>13-3<call>10.0<response>
Result=10<submit>
"""
def reward_fn(result, answer):
"""Simplified reward function returning 1 if result matches answer and 0 otherwise."""
result_parsed = result.split("=")[1].split("<")[0]
return int(result_parsed==answer)
text_env = TextEnvironemnt(
model=model,
tokenizer=tokenizer,
tools= {"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")},
reward_fn=exact_match_reward,
prompt=prompt,
max_turns=1
max_tool_response=100
generation_kwargs={"do_sample": "true"}
)
```
Let's decompose the settings:
| Argument | Description |
|:-------------------|:----------------|
| `model` | Language model to interact with the environment and generate requests. |
| `tokenizer` | Tokenizer of language model handling tokenization of strings. |
| `tools` | A `list` or `dict` of tools. If a list is passed, the tool name is inferred from the class name; if a dict is passed, the keys are used as tool names.|
| `reward_fn` | A function that takes a string as input and returns a reward. Can have extra arguments that are passed to `.run()`, such as the ground truth.|
| `prompt` | Prompt to prepend to every task. Usually a few examples to demonstrate to the model how to use the tools in a few-shot fashion. |
| `max_turns` | Maximum number of interactions between model and tools before episode ends.|
| `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.|
| `max_length` | The maximum number of tokens to allow in an episode. |
| `generation_kwargs`| Generation settings used by the language model. |
You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools!
## Run an Episode
To run a set of queries through the text environment one can simply use the `run` method.
```python
queries = ["What is 1/2?"]
answers = ["0.5"]
queries, responses, masks, rewards, histories = text_env.run(queries, answers=answers)
```
This will execute the model/tool feedback loop for each query until either no tool is called anymore, the maximum number of turns is reached, or the maximum number of tokens in an episode is exceeded. The extra `kwargs` (e.g. `answers=answers` above) passed to `run` will be passed on to the reward function.
There are five objects that are returned by `run`:
- `queries`: a list of the tokenized queries
- `responses`: all tokens that have been generated within the environment, including model and tool tokens
- `masks`: mask that indicates which tokens have been generated by the model and which tokens are generated by the tool
- `rewards`: a list of rewards for each query/response
- `histories`: list of `TextHistory` objects, which are useful objects containing all the above and also the text equivalents
The masks are crucial for training as we don't want to optimize tokens that the model has not generated, i.e. tokens produced by the tools.
Next, we'll train a PPO step with the generated responses!
### Train
Training on episodes from the `TextEnvironment` is straightforward and simply requires forwarding all the returned variables except the `TextHistory` objects to the `step` method:
```python
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
```
## `TextHistory`
The `TextHistory` object stores the interactions between the model and the text environment. It stores the tokens and text generated in each turn, their source (model or system), as well as the rewards. Let's go through the class attributes and methods.
### Attributes
The following table summarises the available attributes of the `TextHistory` class:
| Attribute | Description |
|:-------------------|:----------------|
| `text` | The full string of the text generated in the text environment with both model and system generated text. |
| `text_spans` | A list of tuples with the spans for each model or system generated text segment. |
| `system_spans` | A list of boolean values indicating if the segment is model or system generated. |
| `tokens` | All tokens generated in text environment with both model and system generated tokens. |
| `token_spans` | Similar to `text_spans`, the `token_spans` indicate the boundaries of model and system generated tokens. |
| `token_masks` | The token masks can be used to ignore system generated tokens by masking them. |
| `completed` | Indicates if the interaction with the environment has completed. |
| `truncated` | Indicates if the interaction with the environment has completed because max length was reached. |
With these attributes you can reconstruct every interaction of the model with the `TextEnvironment`. The `TextHistory` also lets you visualize the text history. Let's have a look!
### Visualization
When the model interacts inside the `TextEnvironment` it can be useful to visualize and separate which parts of the text outputs were generated by the model and which parts come from the system and tools. For that purpose there are the two methods [`TextHistory.show_text`] and [`TextHistory.show_tokens`]. They print the text and tokens respectively and highlight the various segments using the [`rich` library](https://github.com/Textualize/rich) (make sure to install it before using these methods).
You can see that the prompt is highlighted in gray, whereas system segments such as query and tool responses are highlighted in green. All segments generated by the model are highlighted in blue, and in addition to the pure text output the reward is displayed as additional text in plum. Here is an example of `show_text`:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv_show_text.png" width=600>
</div>
Sometimes there can be tricky tokenization related issues that are hidden when showing the decoded text. Thus `TextHistory` also offers an option to display the same highlighting on the tokens directly with `show_tokens`:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv_show_tokens.png" width=800>
</div>
Note that you can turn on the colour legend by passing `show_legend=True`.
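In code, the visualizations are produced directly from a returned `TextHistory`; a sketch (whether `show_tokens` takes the tokenizer as an argument may depend on your TRL version):
```python
queries, responses, masks, rewards, histories = text_env.run(queries, answers=answers)

history = histories[0]
history.show_text(show_legend=True)  # highlighted text segments with the colour legend
history.show_tokens(tokenizer)       # the same highlighting at the token level
```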
## API Documentation
[[autodoc]] TextEnvironment
[[autodoc]] TextHistory
# Trainer
At TRL we support PPO (Proximal Policy Optimisation) with an implementation that largely follows the structure introduced in the paper "Fine-Tuning Language Models from Human Preferences" by D. Ziegler et al. [[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].
The Trainer and model classes are largely inspired from `transformers.Trainer` and `transformers.AutoModel` classes and adapted for RL.
## PPOConfig
[[autodoc]] PPOConfig
## PPOTrainer
[[autodoc]] PPOTrainer
## set_seed
[[autodoc]] set_seed
# Use model after training
Once you have trained a model using either the SFTTrainer, PPOTrainer, or DPOTrainer, you will have a fine-tuned model that can be used for text generation. In this section, we'll walk through the process of loading the fine-tuned model and generating text. If you need to run an inference server with the trained model, you can explore libraries such as [`text-generation-inference`](https://github.com/huggingface/text-generation-inference).
## Load and Generate
If you have fine-tuned a model fully, meaning without the use of PEFT, you can simply load it like any other language model in transformers. E.g. the value head that was trained during the PPO training is no longer needed, and if you load the model with the original transformer class it will be ignored:
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub
device = "cpu" # or "cuda" if you have a GPU
model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
inputs = tokenizer.encode("This movie was really", return_tensors="pt").to(device)
outputs = model.generate(inputs)
print(tokenizer.decode(outputs[0]))
```
Alternatively you can also use the pipeline:
```python
from transformers import pipeline
model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub
pipe = pipeline("text-generation", model=model_name_or_path)
print(pipe("This movie was really")[0]["generated_text"])
```
## Use Adapters with PEFT
```python
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
base_model_name = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub"
adapter_model_name = "path/to/my/adapter"
model = AutoModelForCausalLM.from_pretrained(base_model_name)
model = PeftModel.from_pretrained(model, adapter_model_name)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
```
You can also merge the adapters into the base model so you can use the model like a normal transformers model, however the checkpoint will be significantly bigger:
```python
model = AutoModelForCausalLM.from_pretrained(base_model_name)
model = PeftModel.from_pretrained(model, adapter_model_name)
model = model.merge_and_unload()
model.save_pretrained("merged_adapters")
```
Once you have the model loaded and either merged the adapters or kept them separately on top, you can run generation as with a normal model, as outlined above.
# Using LLaMA models with TRL
We've begun rolling out examples to use Meta's LLaMA models in `trl` (see [Meta's LLaMA release](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) for the original LLaMA model).
## Efficient training strategies
Even training the smallest LLaMA model requires an enormous amount of memory. Some quick math: in bf16, every parameter uses 2 bytes (in fp32, 4 bytes), in addition to 8 bytes used, e.g., in the Adam optimizer (see the [performance docs](https://huggingface.co/docs/transformers/perf_train_gpu_one#optimizer) in Transformers for more info). So a 7B parameter model would use `(2+8)*7B=70GB` just to fit in memory and would likely need more when you compute intermediate values such as attention scores. So you couldn't train the model even on a single 80GB A100 like that. You can use some tricks, like more efficient optimizers or half-precision training, to squeeze a bit more into memory, but you'll run out sooner or later.
Another option is to use Parameter-Efficient Fine-Tuning (PEFT) techniques, such as the [`peft`](https://github.com/huggingface/peft) library, which can perform low-rank adaptation (LoRA) on a model loaded in 8-bit.
For more on `peft` + `trl`, see the [docs](https://huggingface.co/docs/trl/sentiment_tuning_peft).
Loading the model in 8bit reduces the memory footprint drastically since you only need one byte per parameter for the weights (e.g. a 7B LLaMA is 7GB in memory).
Instead of training the original weights directly, LoRA adds small adapter layers on top of some specific layers (usually the attention layers); thus, the number of trainable parameters is drastically reduced.
In this scenario, a rule of thumb is to allocate ~1.2-1.4GB per billion parameters (depending on the batch size and sequence length) to fit the entire fine-tuning setup.
This enables fine-tuning larger models (up to 50-60B scale models on a NVIDIA A100 80GB) at low cost.
Now we can fit very large models into a single GPU, but the training might still be very slow.
The simplest strategy in this scenario is data parallelism: we replicate the same training setup into separate GPUs and pass different batches to each GPU.
With this, you can parallelize the forward/backward passes of the model and scale with the number of GPUs.
![chapter10_ddp.png](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/blog/stackllama/chapter10_ddp.png)
We use either the `transformers.Trainer` or `accelerate`, which both support data parallelism without any code changes, by simply passing arguments when calling the scripts with `torchrun` or `accelerate launch`. The following runs a training script with 8 GPUs on a single machine with `accelerate` and `torchrun`, respectively.
```bash
accelerate launch --multi_gpu --num_machines 1 --num_processes 8 my_accelerate_script.py
torchrun --nnodes 1 --nproc_per_node 8 my_torch_script.py
```
## Supervised fine-tuning
Before we start training reward models and tuning our model with RL, it helps if the model is already good in the domain we are interested in.
In our case, we want it to answer questions, while for other use cases, we might want it to follow instructions, in which case instruction tuning is a great idea.
The easiest way to achieve this is by continuing to train the language model with the language modeling objective on texts from the domain or task.
The [StackExchange dataset](https://huggingface.co/datasets/HuggingFaceH4/stack-exchange-preferences) is enormous (over 10 million instructions), so we can easily train the language model on a subset of it.
There is nothing special about fine-tuning the model before doing RLHF - it's just the causal language modeling objective from pretraining that we apply here.
To use the data efficiently, we use a technique called packing: instead of having one text per sample in the batch and then padding to either the longest text or the maximal context of the model, we concatenate a lot of texts with an EOS token in between and cut chunks of the context size to fill the batch without any padding.
![chapter10_preprocessing-clm.png](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/blog/stackllama/chapter10_preprocessing-clm.png)
With this approach the training is much more efficient as each token that is passed through the model is also trained, in contrast to padding tokens, which are usually masked from the loss.
If you don't have much data and are more concerned about occasionally cutting off some tokens that are overflowing the context, you can also use a classical data loader.
The packing is handled by the `ConstantLengthDataset` and we can then use the `Trainer` after loading the model with `peft`. First, we load the model in int8, prepare it for training, and then add the LoRA adapters.
```python
from accelerate import Accelerator
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM

# load model in 8bit
model = AutoModelForCausalLM.from_pretrained(
    args.model_path,
    load_in_8bit=True,
    device_map={"": Accelerator().local_process_index}
)
model = prepare_model_for_kbit_training(model)

# add LoRA to model
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
```
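For reference, the packed dataset itself can be built roughly like this (a sketch; the column name `"text"`, the sequence length, and `raw_dataset`, which stands for the StackExchange subset loaded beforehand, are assumptions):
```python
from trl.trainer import ConstantLengthDataset

train_dataset = ConstantLengthDataset(
    tokenizer,
    raw_dataset,
    dataset_text_field="text",  # column containing the raw text
    seq_length=1024,            # length of each packed chunk
    infinite=True,              # cycle over the data until training stops
)
```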
We train the model for a few thousand steps with the causal language modeling objective and save the model.
Since we will tune the model again with different objectives, we merge the adapter weights with the original model weights.
**Disclaimer:** due to LLaMA's license, we release only the adapter weights for this and the model checkpoints in the following sections.
You can apply for access to the base model's weights by filling out Meta AI's [form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform) and then converting them to the 🤗 Transformers format by running this [script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py).
Note that you'll also need to install 🤗 Transformers from source until `v4.28` is released.
Now that we have fine-tuned the model for the task, we are ready to train a reward model.
## Reward modeling and human preferences
In principle, we could fine-tune the model using RLHF directly with the human annotations.
However, this would require us to send some samples to humans for rating after each optimization iteration.
This is expensive and slow due to the number of training samples needed for convergence and the inherent latency of human reading and annotator speed.
A trick that works well instead of direct feedback is training a reward model on human annotations collected before the RL loop.
The goal of the reward model is to imitate how a human would rate a text. There are several possible strategies to build a reward model: the most straightforward way would be to predict the annotation (e.g. a rating score or a binary value for “good”/”bad”).
In practice, what works better is to predict the ranking of two examples, where the reward model is presented with two candidates `(y_k, y_j)` for a given prompt `x` and has to predict which one would be rated higher by a human annotator.
With the StackExchange dataset, we can infer which of the two answers was preferred by the users based on the score.
With that information and a pairwise ranking loss (shown in the `compute_loss` method below), we can then modify the `transformers.Trainer` by adding a custom loss function.
```python
class RewardTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0]
        rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0]
        loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean()
        if return_outputs:
            return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k}
        return loss
```
We utilize a subset of 100,000 pairs of candidates and evaluate on a held-out set of 50,000. With a modest training batch size of 4, we train the Llama model using the LoRA `peft` adapter for a single epoch using the Adam optimizer with BF16 precision. Our LoRA configuration is:
```python
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)
```
As detailed in the next section, the resulting adapter can be merged into the frozen model and saved for further downstream use.
## Reinforcement Learning from Human Feedback
With the fine-tuned language model and the reward model at hand, we are now ready to run the RL loop. It follows roughly three steps:
1. Generate responses from prompts,
2. Rate the responses with the reward model,
3. Run a reinforcement learning policy-optimization step with the ratings.
The Query and Response prompts are templated as follows before being tokenized and passed to the model:
```bash
Question: <Query>
Answer: <Response>
```
The same template was used for SFT, RM and RLHF stages.
Once more, we utilize `peft` for memory-efficient training, which offers an extra advantage in the RLHF context.
Here, the reference model and policy share the same base, the SFT model, which we load in 8-bit and freeze during training.
We exclusively optimize the policy's LoRA weights using PPO while sharing the base model's weights.
```python
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    question_tensors = batch["input_ids"]

    # sample from the policy to generate responses
    response_tensors = ppo_trainer.generate(
        question_tensors,
        return_prompt=False,
        length_sampler=output_length_sampler,
        **generation_kwargs,
    )
    batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

    # Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs]

    # Run PPO step
    stats = ppo_trainer.step(question_tensors, response_tensors, rewards)

    # Log stats to Wandb
    ppo_trainer.log_stats(stats, batch, rewards)
```
For the rest of the details and evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama).
# XPO Trainer
[![](https://img.shields.io/badge/All_models-XPO-blue)](https://huggingface.co/models?other=xpo,trl)
## Overview
Exploratory Preference Optimization (XPO) was proposed in the paper [Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF](https://huggingface.co/papers/2405.21046) by Tengyang Xie, Dylan J. Foster, Akshay Krishnamurthy, [Corby Rosset](https://huggingface.co/corbyrosset), [Ahmed Awadallah](https://huggingface.co/AhmedAwadallah), and Alexander Rakhlin. It is a simple online preference tuning method based on the DPO loss together with a reward model (RM). XPO augments the DPO objective with an exploration bonus, allowing the method to explore outside the support of the initial model and human feedback data.
The abstract from the paper is the following:
> Reinforcement learning from human feedback (RLHF) has emerged as a central tool for language model alignment. We consider online exploration in RLHF, which exploits interactive access to human or AI feedback by deliberately encouraging the model to produce diverse, maximally informative responses. By allowing RLHF to confidently stray from the pre-trained model, online exploration offers the possibility of novel, potentially super-human capabilities, but its full potential as a paradigm for language model training has yet to be realized, owing to computational and statistical bottlenecks in directly adapting existing reinforcement learning techniques. We propose a new algorithm for online exploration in RLHF, Exploratory Preference Optimization (XPO), which is simple and practical -- a one-line change to (online) Direct Preference Optimization (DPO; Rafailov et al., 2023) -- yet enjoys the strongest known provable guarantees and promising empirical performance. XPO augments the DPO objective with a novel and principled exploration bonus, empowering the algorithm to explore outside the support of the initial model and human feedback data. In theory, we show that XPO is provably sample-efficient and converges to a near-optimal language model policy under natural exploration conditions, irrespective of whether the initial model has good coverage. Our analysis, which builds on the observation that DPO implicitly performs a form of Q*-approximation (or, Bellman error minimization), combines previously disparate techniques from language modeling and theoretical reinforcement learning in a serendipitous fashion through the perspective of KL-regularized Markov decision processes. Empirically, we find that XPO is more sample-efficient than non-exploratory DPO variants in a preliminary evaluation.
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Lewis Tunstall](https://huggingface.co/lewtun).
## Quick start
This example demonstrates how to train a model using the XPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:
<iframe
src="https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>
Below is the script to train the model:
```python
# train_xpo.py
from datasets import load_dataset
from trl import PairRMJudge, XPOConfig, XPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
training_args = XPOConfig(output_dir="Qwen2-0.5B-XPO", logging_steps=10)
trainer = XPOTrainer(
    model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
)
trainer.train()
```
Execute the script using the following command:
```bash
accelerate launch train_xpo.py
```
Distributed across 8 GPUs, the training takes approximately 1 hour.
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-XPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-XPO
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?
<strong><span style="color: blue;">&lt;trl-lib/Qwen2-0.5B-XPO&gt;:</span></strong>
The best programming language depends on individual preferences and familiarity with coding concepts. Some popular languages include Python, Java, C++, and JavaScript.
</code></pre>
## Expected dataset type
XPO requires a [prompt-only dataset](dataset_formats#prompt-only). The [`XPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
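For reference, a prompt-only sample looks like this in the standard and conversational formats respectively (values are illustrative):
```python
# standard format
{"prompt": "The sky is"}

# conversational format
{"prompt": [{"role": "user", "content": "What color is the sky?"}]}
```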
## Usage tips
### Use a reward model
Instead of a judge, you can choose to use a reward model -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use. Below is a code example showing how to replace a judge with the [trl-lib/Qwen2-0.5B-Reward](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) model:
```diff
- from trl import PairRMJudge
+ from transformers import AutoModelForSequenceClassification
- judge = PairRMJudge()
+ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
trainer = XPOTrainer(
...
- judge=judge,
+ reward_model=reward_model,
)
```
<Tip warning={true}>
Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training.
</Tip>
### Encourage EOS token generation
When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`XPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`XPOConfig`]:
```python
training_args = XPOConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
```
### Logging Completions
To better understand your model's behavior during training, you can log sample completions periodically using the [`LogCompletionsCallback`].
```python
trainer = XPOTrainer(..., eval_dataset=eval_dataset)
completions_callback = LogCompletionsCallback(trainer, num_prompts=8)
trainer.add_callback(completions_callback)
```
This callback logs the model's generated completions directly to Weights & Biases.
![Logged Completions](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/wandb_completions.png)
## Example script
We provide an example script to train a model using the XPO method. The script is available in [`examples/scripts/xpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/xpo.py).
To test the XPO script with the [Qwen2.5 0.5B model](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback), run the following command:
```bash
python examples/scripts/xpo.py \
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--judge pair_rm \
--dataset_name trl-lib/ultrafeedback-prompt \
--learning_rate 5.0e-7 \
--logging_steps 25 \
--output_dir Qwen2.5-0.5B-XPO-PairRM \
--warmup_ratio 0.1 \
--push_to_hub
```
## Logged metrics
The logged metrics are as follows:
* `loss/xpo`: The mean XPO part of the full loss.
* `loss/dpo`: The mean DPO part of the full loss.
* `objective/kl`: The mean KL divergence between the model and reference data.
* `objective/entropy`: The mean entropy of the model and reference data.
* `objective/model_scores`: The mean scores (according to the reward model) of the model completions.
* `objective/ref_scores`: The mean scores (according to the reward model) of the reference completions.
* `objective/scores_margin`: The mean score margin (according to the external reward model) between the chosen and rejected completions.
* `rewards/chosen`: The mean reward (according to XPO's DPO implicit reward model) of the chosen completions.
* `rewards/rejected`: The mean reward (according to XPO's DPO implicit reward model) of the rejected completions.
* `rewards/accuracies`: The accuracies of XPO's implicit reward model.
* `rewards/margins`: The mean reward margin (according to XPO's implicit reward model) between the chosen and rejected completions.
* `logps/chosen`: The mean log probabilities of the chosen completions.
* `logps/rejected`: The mean log probabilities of the rejected completions.
* `val/model_contain_eos_token`: The number of times the model's output contains the EOS token.
* `val/ref_contain_eos_token`: The number of times the reference's output contains the EOS token.
* `alpha`: The weight of the XPO loss term. Typically fixed, but it can be made dynamic by passing a list to [`XPOConfig`].
* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but it can be made dynamic by passing a list to [`XPOConfig`] (see the example below).
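For example, here is a rough sketch of making both parameters dynamic; it assumes, as described above, that passing a list to [`XPOConfig`] provides one value per training epoch:
```python
training_args = XPOConfig(
    ...,
    alpha=[1e-5, 1e-5, 1e-4],  # assumed: one value per epoch
    beta=[0.1, 0.05, 0.05],
)
```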
## XPOTrainer
[[autodoc]] XPOTrainer
## XPOConfig
[[autodoc]] XPOConfig

View File

@ -1,66 +1,3 @@
# Sentiment Examples
# Examples
The notebooks and scripts in these examples show how to fine-tune a model with a sentiment classifier (such as `lvwerra/distilbert-imdb`).
Here's an overview of the notebooks and scripts:
| File | Description |
|---|---|
| `notebooks/gpt2-sentiment.ipynb` | Fine-tune GPT2 to generate positive movie reviews. |
| `notebooks/gpt2-sentiment-control.ipynb` | Fine-tune GPT2 to generate movie reviews with controlled sentiment. |
| `scripts/gpt2-sentiment.py` | Same as the notebook, but easier to use in a multi-GPU setup. |
| `scripts/t5-sentiment.py` | Same as GPT2 script, but for a Seq2Seq model (T5). |
## Installation
```bash
pip install trl
# optional: wandb
pip install wandb
```
Note: if you don't want to log with `wandb`, remove `log_with="wandb"` from the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
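For instance, here is a minimal sketch of disabling (or swapping) the tracker in the legacy `PPOConfig` used by these scripts; the exact keyword is the `log_with` argument mentioned above, and `model_name` follows the value used in the scripts:
```python
from trl import PPOConfig

# Pass None to disable logging, or another accelerate-supported tracker
# such as "tensorboard" instead of "wandb".
config = PPOConfig(model_name="lvwerra/gpt2-imdb", log_with=None)
```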
## Launch scripts
The `trl` library is powered by `accelerate`. As such, it is best to configure and launch training runs with the following commands:
```bash
accelerate config # will prompt you to define the training configuration
accelerate launch scripts/gpt2-sentiment.py # launches training
```
# Summarization Example
The script in this example shows how to train a reward model for summarization, following the OpenAI Learning to Summarize from Human Feedback [paper](https://arxiv.org/abs/2009.01325). We've validated that the script can be used to train a small GPT2 to get slightly over 60% validation accuracy, which is aligned with results from the paper. The model is [here](https://huggingface.co/Tristan/gpt2_reward_summarization).
Here's an overview of the files:
| File | Description |
|---|---|
| `scripts/reward_summarization.py` | For tuning the reward model. |
| `scripts/ds3_reward_summarization_example_config.json` | Can be used with the reward model script to scale it up to arbitrarily big models that don't fit on a single GPU. |
## Installation
```bash
pip install trl
pip install evaluate
# optional: deepspeed
pip install deepspeed
```
```bash
# If you want your reward model to follow the Learning to Summarize from Human Feedback paper closely, then tune a GPT model on summarization and then instantiate the reward model
# with it. In other words, pass in the name of your summarization-finetuned gpt on the hub, instead of the name of the pretrained gpt2 like we do in the following examples of how
# to run this script.
# Example of running this script with the small size gpt2 on a 40GB A100 (A100's support bf16). Here, the global batch size will be 64:
python -m torch.distributed.launch --nproc_per_node=1 reward_summarization.py --bf16
# Example of running this script with the xl size gpt2 on 16 40GB A100's. Here the global batch size will still be 64:
python -m torch.distributed.launch --nproc_per_node=16 reward_summarization.py --per_device_train_batch_size=1 --per_device_eval_batch_size=1 --gradient_accumulation_steps=4 --gpt_model_name=gpt2-xl --bf16 --deepspeed=ds3_reward_summarization_example_config.json
```
Please check out https://huggingface.co/docs/trl/example_overview for documentation on our examples.

View File

@ -0,0 +1,20 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: 1
  zero3_init_flag: false
  zero_stage: 1
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,21 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,25 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: "NO"
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,18 @@
# This is an example configuration file of TRL CLI, you can use it for
# SFT like that: `trl sft --config config.yaml --output_dir test-sft`
# The YAML file supports environment variables by adding an `env` field
# as below
# env:
# CUDA_VISIBLE_DEVICES: 0
model_name_or_path:
  Qwen/Qwen2.5-0.5B
dataset_name:
  stanfordnlp/imdb
report_to:
  none
learning_rate:
  0.0001
lr_scheduler_type:
  cosine

View File

@ -0,0 +1,96 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import dataclass
from typing import Optional
from datasets import load_dataset
from transformers import HfArgumentParser
@dataclass
class ScriptArguments:
r"""
Arguments for the script.
Args:
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/hh-rlhf-helpful-base"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""
push_to_hub: bool = False
repo_id: str = "trl-lib/hh-rlhf-helpful-base"
dataset_num_proc: Optional[int] = None
def common_start(str1: str, str2: str) -> str:
# Zip the two strings and iterate over them together
common_chars = []
for c1, c2 in zip(str1, str2):
if c1 == c2:
common_chars.append(c1)
else:
break
# Join the common characters and return as a string
return "".join(common_chars)
def extract_dialogue(example: str) -> list[dict[str, str]]:
# Extract the prompt, which corresponds to the common start of the chosen and rejected dialogues
prompt_text = common_start(example["chosen"], example["rejected"])
# The chosen and rejected may share a common start, so we need to remove the common part
if not prompt_text.endswith("\n\nAssistant: "):
prompt_text = prompt_text[: prompt_text.rfind("\n\nAssistant: ")] + "\n\nAssistant: "
# Extract the chosen and rejected lines
chosen_line = example["chosen"][len(prompt_text) :]
rejected_line = example["rejected"][len(prompt_text) :]
# Remove the generation prompt ("\n\nAssistant: ") from the prompt
prompt_text = prompt_text[: -len("\n\nAssistant: ")]
# Split the string at every occurrence of "Human: " or "Assistant: "
prompt_lines = re.split(r"(\n\nAssistant: |\n\nHuman: )", prompt_text)
# Remove the first element as it's empty
prompt_lines = prompt_lines[1:]
prompt = []
for idx in range(0, len(prompt_lines), 2):
role = "user" if prompt_lines[idx] == "\n\nHuman: " else "assistant"
content = prompt_lines[idx + 1]
prompt.append({"role": role, "content": content})
# Remove the prompt from the chosen and rejected dialogues
chosen = [{"role": "assistant", "content": chosen_line}]
rejected = [{"role": "assistant", "content": rejected_line}]
return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
dataset = load_dataset("Anthropic/hh-rlhf", data_dir="helpful-base")
dataset = dataset.map(extract_dialogue, num_proc=script_args.dataset_num_proc)
if script_args.push_to_hub:
dataset.push_to_hub(script_args.repo_id)

View File

@ -0,0 +1,81 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
from datasets import load_dataset
from transformers import AutoTokenizer, HfArgumentParser
@dataclass
class ScriptArguments:
r"""
Arguments for the script.
Args:
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/lm-human-preferences-descriptiveness"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""
push_to_hub: bool = False
repo_id: str = "trl-lib/lm-human-preferences-descriptiveness"
dataset_num_proc: Optional[int] = None
# Edge cases handling: remove the cases where all samples are the same
def samples_not_all_same(example):
return not all(example["sample0"] == example[f"sample{j}"] for j in range(1, 4))
def to_prompt_completion(example, tokenizer):
prompt = tokenizer.decode(example["query"]).strip()
best_idx = example["best"]
chosen = tokenizer.decode(example[f"sample{best_idx}"])
for rejected_idx in range(4): # take the first rejected sample that is different from the chosen one
rejected = tokenizer.decode(example[f"sample{rejected_idx}"])
if chosen != rejected:
break
assert chosen != rejected
return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
dataset = load_dataset(
"json",
data_files="https://openaipublic.blob.core.windows.net/lm-human-preferences/labels/descriptiveness/offline_5k.json",
split="train",
)
dataset = dataset.filter(samples_not_all_same, num_proc=script_args.dataset_num_proc)
dataset = dataset.map(
to_prompt_completion,
num_proc=script_args.dataset_num_proc,
remove_columns=["query", "sample0", "sample1", "sample2", "sample3", "best"],
fn_kwargs={"tokenizer": AutoTokenizer.from_pretrained("gpt2")},
)
# train_size taken from https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/launch.py#L79)
dataset = dataset.train_test_split(train_size=4992)
if script_args.push_to_hub:
dataset.push_to_hub(script_args.repo_id)

View File

@ -0,0 +1,74 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
from datasets import load_dataset
from transformers import AutoTokenizer, HfArgumentParser
@dataclass
class ScriptArguments:
r"""
Arguments for the script.
Args:
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/lm-human-preferences-sentiment"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""
push_to_hub: bool = False
repo_id: str = "trl-lib/lm-human-preferences-sentiment"
dataset_num_proc: Optional[int] = None
def to_prompt_completion(example, tokenizer):
prompt = tokenizer.decode(example["query"]).strip()
best_idx = example["best"]
chosen = tokenizer.decode(example[f"sample{best_idx}"])
for rejected_idx in range(4): # take the first rejected sample that is different from the chosen one
rejected = tokenizer.decode(example[f"sample{rejected_idx}"])
if chosen != rejected:
break
assert chosen != rejected
return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
dataset = load_dataset(
"json",
data_files="https://openaipublic.blob.core.windows.net/lm-human-preferences/labels/sentiment/offline_5k.json",
split="train",
)
dataset = dataset.map(
to_prompt_completion,
num_proc=script_args.dataset_num_proc,
remove_columns=["query", "sample0", "sample1", "sample2", "sample3", "best"],
fn_kwargs={"tokenizer": AutoTokenizer.from_pretrained("gpt2")},
)
# train_size taken from https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/launch.py#L70)
dataset = dataset.train_test_split(train_size=4992)
if script_args.push_to_hub:
dataset.push_to_hub(script_args.repo_id)

View File

@ -0,0 +1,131 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import dataclass
from itertools import chain
from typing import Optional
from datasets import load_dataset
from transformers import HfArgumentParser
@dataclass
class ScriptArguments:
r"""
Arguments for the script.
Args:
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/math_shepherd"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""
push_to_hub: bool = False
repo_id: str = "trl-lib/math_shepherd"
dataset_num_proc: Optional[int] = None
def process_example(example):
# Replace "ки" with "ⶻ" so that the size of the "input" matches the size of the "label"
    inputs = example["input"].replace("ки", "ⶻ")
    # Find the indices of the "ⶻ" characters (that should match with the indexes of the "+" or "-" in the label)
    indexes = [m.start() for m in re.finditer("ⶻ", inputs)]
# Sanity that all indexes are either "+" or "-"
assert all(example["label"][idx] in ["+", "-"] for idx in indexes)
# Get the labels
labels = [example["label"][idx] == "+" for idx in indexes]
# Split the inputs into steps (caution, the first step is missing here, it is the prompt)
steps = [inputs[i:j] for i, j in zip(chain([0], indexes), chain(indexes, [None]))]
# Remove the last step (single ⶻ)
steps = steps[:-1]
# Get the prompt (first part) and completions (rest)
prompt = steps[0]
completions = steps[1:]
# Remove the heading "ⶻ" and the final whitespace from the completions
    assert all(completion.startswith("ⶻ") for completion in completions)
completions = [completion[1:].strip() for completion in completions]
# At this point, we need to retrieve the first step from the prompt.
# First, we handle particular cases (annotation error) where we have a first label before the end of the prompt.
if prompt.startswith(
(
"Mr. Rocky",
"Parker",
"What is the smallest positive",
" The Myth",
"Let $\\mathbf{a}$",
"Find the arithmetic",
"Determine an ordered pair",
"Determine the ordered pair",
"At the Quill and Scroll stationery",
"Round to the nearest",
r"Calculate $\sqrt{10p}",
r"Simplify $\sqrt{28x}",
)
):
# Some spotted datasets errors where there is an annotation in the prompt: we remove it
labels = labels[1:]
# Then we handle the general case: we get the first step from the prompt by looking for "Step 1:" or "step 1:" or
# (less common) "?".
elif "Step 1:" in prompt:
prompt, first_step = prompt.split("Step 1:")
first_step = "Step 1:" + first_step
completions = [first_step.strip()] + completions
elif "step 1:" in prompt:
prompt, first_step = prompt.split("step 1:")
first_step = "step 1:" + first_step
completions = [first_step.strip()] + completions
elif "?" in prompt:
prompt, first_step = prompt.split("?")
prompt = prompt + "?"
completions = [first_step.strip()] + completions
else:
raise ValueError(f"Prompt can't be processed: {prompt}")
# Strip the prompt
prompt = prompt.strip()
# Sanity check that the length of the completions is the same as the length of the labels
assert len(completions) == len(labels)
return {"prompt": prompt, "completions": completions, "labels": labels}
if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
dataset = load_dataset("peiyi9979/Math-Shepherd", split="train")
dataset = dataset.map(
process_example,
remove_columns=["input", "label", "task"],
num_proc=script_args.dataset_num_proc,
)
dataset = dataset.train_test_split(test_size=0.05, seed=42)
if script_args.push_to_hub:
dataset.push_to_hub(script_args.repo_id)

View File

@ -0,0 +1,118 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
from datasets import load_dataset
from transformers import HfArgumentParser
@dataclass
class ScriptArguments:
r"""
Arguments for the script.
Args:
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/prm800k"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""
push_to_hub: bool = False
repo_id: str = "trl-lib/prm800k"
dataset_num_proc: Optional[int] = None
def process_example(example):
outputs = []
prompt = example["question"]["problem"]
# Iterate through each step
previous_completions = []
previous_labels = []
for step in example["label"]["steps"]:
if step["completions"] is None and step["human_completion"] is None and step["chosen_completion"] is None:
# happens sometimes
break
# Loop through completions
for completion_idx, completion in enumerate(step["completions"]):
            # For every completion that is not chosen, we are in a terminal state, so we can add it to the list of outputs.
if completion_idx != step["chosen_completion"]:
content = completion["text"]
completions = previous_completions[:] + [content]
label = completion["rating"] == 1
labels = previous_labels[:] + [label]
outputs.append({"prompt": prompt, "completions": completions, "labels": labels})
        # Now, expand the previous completions and labels
if step["chosen_completion"] is not None:
chosen_completion = step["completions"][step["chosen_completion"]]
label = chosen_completion["rating"] == 1
elif step["human_completion"] is not None:
chosen_completion = step["human_completion"]
label = True
else:
break
content = chosen_completion["text"]
previous_completions.append(content)
previous_labels.append(label)
# Last step: we are in a terminal state, so we can add it to the list of outputs
outputs.append({"prompt": prompt, "completions": previous_completions, "labels": previous_labels})
return outputs
def process_batch(examples):
outputs = []
batch_size = len(examples["label"])
for idx in range(batch_size):
example = {k: v[idx] for k, v in examples.items()}
outputs.extend(process_example(example))
# list of dict to dict of list
outputs = {k: [v[k] for v in outputs] for k in outputs[0]}
return outputs
if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
data_files = {
"train": "https://github.com/openai/prm800k/raw/refs/heads/main/prm800k/data/phase1_train.jsonl",
"test": "https://github.com/openai/prm800k/raw/refs/heads/main/prm800k/data/phase1_test.jsonl",
}
dataset = load_dataset("json", data_files=data_files)
dataset = dataset.map(
process_batch,
batched=True,
batch_size=10,
remove_columns=[
"labeler",
"timestamp",
"generation",
"is_quality_control_question",
"is_initial_screening_question",
"question",
"label",
],
num_proc=script_args.dataset_num_proc,
)
if script_args.push_to_hub:
dataset.push_to_hub(script_args.repo_id)

View File

@ -0,0 +1,73 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
from datasets import features, load_dataset
from transformers import HfArgumentParser
@dataclass
class ScriptArguments:
r"""
Arguments for the script.
Args:
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/rlaif-v"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""
push_to_hub: bool = False
repo_id: str = "trl-lib/rlaif-v"
dataset_num_proc: Optional[int] = None
def to_conversational(example):
"""
Convert prompt from "xxx" to [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "xxx"}]}]
and chosen and rejected from "xxx" to [{"role": "assistant", "content": [{"type": "text", "text": "xxx"}]}].
Images are wrapped into a list.
"""
prompt = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": example["question"]}]}]
chosen = [{"role": "assistant", "content": [{"type": "text", "text": example["chosen"]}]}]
rejected = [{"role": "assistant", "content": [{"type": "text", "text": example["rejected"]}]}]
return {"prompt": prompt, "images": [example["image"]], "chosen": chosen, "rejected": rejected}
if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
dataset = load_dataset("openbmb/RLAIF-V-Dataset", split="train")
dataset = dataset.map(
to_conversational,
num_proc=script_args.dataset_num_proc,
remove_columns=dataset.column_names,
writer_batch_size=128,
)
# Cast the images to Sequence[Image] to avoid bytes format
f = dataset.features
f["images"] = features.Sequence(features.Image(decode=True))
dataset = dataset.cast(f)
dataset = dataset.train_test_split(test_size=0.01, writer_batch_size=128)
if script_args.push_to_hub:
dataset.push_to_hub(script_args.repo_id)

examples/datasets/tldr.py Normal file
View File

@ -0,0 +1,67 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
from datasets import load_dataset
from transformers import HfArgumentParser
@dataclass
class ScriptArguments:
r"""
Arguments for the script.
Args:
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/tldr"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""
push_to_hub: bool = False
repo_id: str = "trl-lib/tldr"
dataset_num_proc: Optional[int] = None
def to_prompt_completion(example):
tldr_format_str = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:"
prompt = tldr_format_str.format(subreddit=example["subreddit"], title=example["title"], post=example["post"])
completion = " " + example["summary"] # Add a space to separate the prompt from the completion
return {"prompt": prompt, "completion": completion}
if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
# Filtered reddit TL;DR dataset from https://github.com/openai/summarize-from-feedback?tab=readme-ov-file#reddit-tldr-dataset
data_files = {
"train": "https://openaipublic.blob.core.windows.net/summarize-from-feedback/datasets/tldr_3_filtered/train.jsonl",
"validation": "https://openaipublic.blob.core.windows.net/summarize-from-feedback/datasets/tldr_3_filtered/valid.jsonl",
"test": "https://openaipublic.blob.core.windows.net/summarize-from-feedback/datasets/tldr_3_filtered/test.jsonl",
}
dataset = load_dataset("json", data_files=data_files)
dataset = dataset.map(
to_prompt_completion,
num_proc=script_args.dataset_num_proc,
remove_columns=["id", "subreddit", "title", "post", "summary"],
)
if script_args.push_to_hub:
dataset.push_to_hub(script_args.repo_id)

View File

@ -0,0 +1,72 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
from datasets import load_dataset
from transformers import HfArgumentParser
@dataclass
class ScriptArguments:
r"""
Arguments for the script.
Args:
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/tldr-preference"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""
push_to_hub: bool = False
repo_id: str = "trl-lib/tldr-preference"
dataset_num_proc: Optional[int] = None
def to_preference(example):
info = example["info"]
if example["batch"] in ["batch0_cnndm", "cnndm0", "cnndm2"]: # CNN Daily Mail batches
article = info["article"].replace("\n\n", "\n")
prompt = f"TITLE: {info['title']}\n\n{article}\n\nTL;DR:"
elif example["batch"] in [f"batch{i}" for i in range(3, 23)] + ["edit_b2_eval_test"]: # Reddit batches
post = info["post"].replace("\n\n", "\n")
prompt = f"SUBREDDIT: r/{info['subreddit']}\n\nTITLE: {info['title']}\n\nPOST: {post}\n\nTL;DR:"
else:
raise ValueError(f"Unknown batch: {example['batch']}")
chosen_idx = example["choice"]
rejected_idx = 1 - chosen_idx
chosen = example["summaries"][chosen_idx]["text"]
rejected = example["summaries"][rejected_idx]["text"]
return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
dataset = load_dataset("openai/summarize_from_feedback", "comparisons")
dataset = dataset.map(
to_preference,
num_proc=script_args.dataset_num_proc,
remove_columns=["info", "summaries", "choice", "worker", "batch", "split", "extra"],
)
if script_args.push_to_hub:
dataset.push_to_hub(script_args.repo_id)

View File

@ -0,0 +1,54 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from transformers import AutoTokenizer, HfArgumentParser
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
"""
python -i examples/datasets/tokenize_ds.py --model HuggingFaceH4/zephyr-7b-beta
python -i examples/datasets/tokenize_ds.py --model gpt2
"""
@dataclass
class ScriptArguments:
dataset_name: str = field(
default="trl-internal-testing/hh-rlhf-helpful-base-trl-style", metadata={"help": "The dataset to load"}
)
model: str = field(default="gpt2", metadata={"help": "The model to use for tokenization"})
dataset_num_proc: Optional[int] = field(
default=None, metadata={"help": "The number of workers to use to tokenize the data"}
)
if __name__ == "__main__":
script_args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
dataset = load_dataset(script_args.dataset_name)
tokenizer = AutoTokenizer.from_pretrained(script_args.model)
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
def process(row):
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
return row
dataset = dataset.map(process, num_proc=script_args.dataset_num_proc)
print(dataset["train"][0]["chosen"])

View File

@ -0,0 +1,68 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
from datasets import load_dataset
from transformers import HfArgumentParser
@dataclass
class ScriptArguments:
r"""
Arguments for the script.
Args:
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/ultrafeedback-prompt"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""
push_to_hub: bool = False
repo_id: str = "trl-lib/ultrafeedback-prompt"
dataset_num_proc: Optional[int] = None
def to_unpaired_preference(example):
prompt = [{"role": "user", "content": example["instruction"]}]
return {"prompt": prompt}
def drop_long_prompt(example):
if len(example["prompt"][0]["content"]) > 512:
return False
else:
return True
if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
dataset = load_dataset("openbmb/UltraFeedback", split="train")
dataset = dataset.map(
to_unpaired_preference,
remove_columns=["source", "instruction", "models", "completions", "correct_answers", "incorrect_answers"],
num_proc=script_args.dataset_num_proc,
)
dataset = dataset.filter(drop_long_prompt)
dataset = dataset.train_test_split(test_size=0.05, seed=42)
if script_args.push_to_hub:
dataset.push_to_hub(script_args.repo_id)

View File

@ -0,0 +1,102 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
from datasets import load_dataset
from transformers import HfArgumentParser
@dataclass
class ScriptArguments:
r"""
Arguments for the script.
Args:
model_name (`str`, *optional*, defaults to `"gpt-3.5-turbo"`):
Language model to target. Possible values are:
- `"alpaca-7b"`
- `"bard"`
- `"falcon-40b-instruct"`
- `"gpt-3.5-turbo"` (default)
- `"gpt-4"`
- `"llama-2-13b-chat"`
- `"llama-2-70b-chat"`
- `"llama-2-7b-chat"`
- `"mpt-30b-chat"`
- `"pythia-12b"`
- `"starchat"`
- `"ultralm-13b"`
- `"ultralm-65b"`
- `"vicuna-33b"`
- `"wizardlm-13b"`
- `"wizardlm-70b"`
- `"wizardlm-7b"`
aspect (`str`, *optional*, defaults to `"helpfulness"`):
Aspect to target. Possible values are:
- `"helpfulness"` (default)
- `"honesty"`
- `"instruction-following"`
- `"truthfulness"`
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""
model_name: str = "gpt-3.5-turbo"
aspect: str = "helpfulness"
push_to_hub: bool = False
repo_id: str = "trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness"
dataset_num_proc: Optional[int] = None
def to_unpaired_preference(example, model_name, aspect):
prompt = [{"role": "user", "content": example["instruction"]}]
model_index = example["models"].index(model_name)
response_content = example["completions"][model_index]["response"]
completion = [{"role": "assistant", "content": response_content}]
score = int(example["completions"][model_index]["annotations"][aspect]["Rating"])
label = score >= 5
return {"prompt": prompt, "completion": completion, "label": label}
if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
dataset = load_dataset("openbmb/UltraFeedback", split="train")
dataset = dataset.filter(
lambda example: script_args.model_name in example["models"],
batched=False,
num_proc=script_args.dataset_num_proc,
)
dataset = dataset.map(
to_unpaired_preference,
remove_columns=["source", "instruction", "models", "completions", "correct_answers", "incorrect_answers"],
fn_kwargs={"model_name": script_args.model_name, "aspect": script_args.aspect},
num_proc=script_args.dataset_num_proc,
)
dataset = dataset.train_test_split(test_size=0.05, seed=42)
if script_args.push_to_hub:
dataset.push_to_hub(script_args.repo_id)

View File

@ -0,0 +1,7 @@
# Notebooks
This directory contains a collection of Jupyter notebooks that demonstrate how to use the TRL library in different applications.
- [`best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb): This notebook demonstrates how to use the "Best of N" sampling strategy with TRL when fine-tuning your model with PPO.
- [`gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb): This notebook demonstrates how to reproduce the GPT2 IMDB sentiment tuning example in a Jupyter notebook.
- [`gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control example in a Jupyter notebook.

View File

@ -0,0 +1,662 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "WQpNapZNWuXP"
},
"source": [
"\n",
"**Best-of-n sampling as an alternative to RLHF**\n",
"\n",
"This notebook compares reward-model scores of prompt based responses from \n",
"1. a base model (`gpt2-imdb`)\n",
"2. `RLHF` tuned model based on this base-model \n",
"3. the base-model again from which we sample n responses to each prompt, score them and take the best scored one AKA the `best-of-n sampled` model\n",
"\n",
"Import dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vDA6qayz692w"
},
"outputs": [],
"source": [
"%pip install transformers trl"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "M1s_iNm773hM"
},
"outputs": [],
"source": [
"import torch\n",
"import pandas as pd\n",
"\n",
"from transformers import pipeline, AutoTokenizer\n",
"from datasets import load_dataset\n",
"\n",
"from trl import AutoModelForCausalLMWithValueHead\n",
"from trl.core import LengthSampler\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y7hyrIrO8tcY"
},
"source": [
"Various constants"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "MqS3OM6Q8x6g"
},
"outputs": [],
"source": [
"ref_model_name = \"lvwerra/gpt2-imdb\"\n",
"model_name = \"lvwerra/gpt2-imdb-pos-v2\"\n",
"reward_model = \"lvwerra/distilbert-imdb\"\n",
"\n",
"N_BEST_OF = 4"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c1YcXeElg6or"
},
"source": [
"Models and tokenizers"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "b855NrL181Hh"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/kashif/Github/transformers/src/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"AutoModelForCausalLMWithValueHead(\n",
" (pretrained_model): GPT2LMHeadModel(\n",
" (transformer): GPT2Model(\n",
" (wte): Embedding(50257, 768)\n",
" (wpe): Embedding(1024, 768)\n",
" (drop): Dropout(p=0.1, inplace=False)\n",
" (h): ModuleList(\n",
" (0-11): 12 x GPT2Block(\n",
" (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (attn): GPT2SdpaAttention(\n",
" (c_attn): Conv1D(nf=2304, nx=768)\n",
" (c_proj): Conv1D(nf=768, nx=768)\n",
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): GPT2MLP(\n",
" (c_fc): Conv1D(nf=3072, nx=768)\n",
" (c_proj): Conv1D(nf=768, nx=3072)\n",
" (act): NewGELUActivation()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
" )\n",
" (v_head): ValueHead(\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (summary): Linear(in_features=768, out_features=1, bias=True)\n",
" (flatten): Flatten(start_dim=1, end_dim=-1)\n",
" )\n",
")"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)\n",
"\n",
"ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)\n",
"\n",
"reward_pipe = pipeline(\"sentiment-analysis\", model=reward_model, device=device)\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(ref_model_name)\n",
"\n",
"tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
"# cuda-ize models\n",
"model.to(device)\n",
"ref_model.to(device)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Z1Cz0gCFhZYJ"
},
"source": [
"Dataset building"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "LqLVEp5p_8XM"
},
"outputs": [],
"source": [
"def build_dataset(\n",
" tokenizer,\n",
" dataset_name=\"stanfordnlp/imdb\",\n",
" input_min_text_length=2,\n",
" input_max_text_length=8,\n",
"):\n",
" # load imdb with datasets\n",
" ds = load_dataset(dataset_name, split=\"train\")\n",
" ds = ds.rename_columns({\"text\": \"review\"})\n",
" ds = ds.filter(lambda x: len(x[\"review\"]) > 200, batched=False)\n",
"\n",
" input_size = LengthSampler(input_min_text_length, input_max_text_length)\n",
"\n",
" def tokenize(sample):\n",
" sample[\"input_ids\"] = tokenizer.encode(sample[\"review\"])[: input_size()]\n",
" sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n",
" return sample\n",
"\n",
" ds = ds.map(tokenize, batched=False)\n",
" ds.set_format(type=\"torch\")\n",
" return ds\n",
"\n",
"\n",
"dataset = build_dataset(tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "AqA2McjMAxNw"
},
"outputs": [],
"source": [
"gen_kwargs = {\n",
" \"min_length\": -1,\n",
" \"top_k\": 0.0,\n",
" \"top_p\": 1.0,\n",
" \"do_sample\": True,\n",
" \"pad_token_id\": tokenizer.eos_token_id,\n",
"}\n",
"sent_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\", \"batch_size\": 16}"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "L_q4qs35AxcR"
},
"outputs": [],
"source": [
"output_min_length = 4\n",
"output_max_length = 16\n",
"output_length_sampler = LengthSampler(output_min_length, output_max_length)\n",
"\n",
"#### get a batch from the dataset\n",
"bs = 16\n",
"output_data = dict()\n",
"dataset.set_format(\"pandas\")\n",
"df_batch = dataset[:].sample(bs)\n",
"output_data[\"query\"] = df_batch[\"query\"].tolist()\n",
"query_tensors = df_batch[\"input_ids\"].tolist()\n",
"\n",
"# :: [Resp]\n",
"response_tensors_ref, response_tensors = [], []\n",
"# :: [[Resp]]\n",
"response_tensors_best_of = []"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QVfpyHnZBLKY"
},
"source": [
"\n",
"Generation using various models"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "-imZ7uEFBNbw"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n"
]
}
],
"source": [
"for i in range(bs):\n",
" gen_len = output_length_sampler()\n",
"\n",
" query = torch.tensor(query_tensors[i])\n",
"\n",
" output = ref_model.generate(\n",
" query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs\n",
" ).squeeze()\n",
" response_tensors_ref.append(tokenizer.decode(output))\n",
"\n",
" output = model.generate(\n",
" query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs\n",
" ).squeeze()\n",
" response_tensors.append(tokenizer.decode(output))\n",
"\n",
" # generating copies of the same query for the Best-of-n sampling\n",
" queries = query.repeat((N_BEST_OF, 1))\n",
" output = ref_model.generate(\n",
" queries.to(device), max_new_tokens=gen_len, **gen_kwargs\n",
" ).squeeze()\n",
" response_tensors_best_of.append(tokenizer.batch_decode(output))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Jp5FC0Y5h_Sf"
},
"source": [
"Scoring"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "PyDbbAQ0F_h7"
},
"outputs": [],
"source": [
"scores_ref = [\n",
" output[0][\"score\"] for output in reward_pipe(response_tensors_ref, **sent_kwargs)\n",
"]\n",
"scores = [output[0][\"score\"] for output in reward_pipe(response_tensors, **sent_kwargs)]\n",
"scores_best_of = []\n",
"for i, response in enumerate(response_tensors_best_of):\n",
" # base_score = scores_ref[i]\n",
" scores_best_of.append(\n",
" torch.tensor(\n",
" [output[0][\"score\"] for output in reward_pipe(response, **sent_kwargs)]\n",
" )\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 682
},
"id": "nA1GDNJEiGm-",
"outputId": "1389c686-0751-4304-dea2-b71fd68748e1"
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>query</th>\n",
" <th>response (ref)</th>\n",
" <th>scores (ref)</th>\n",
" <th>response (RLHF)</th>\n",
" <th>scores (RLHF)</th>\n",
" <th>response (best_of)</th>\n",
" <th>scores (best_of)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>This movie</td>\n",
" <td>This movie should have read some books, and</td>\n",
" <td>1.411889</td>\n",
" <td>This movie has plenty of extraordinary feature...</td>\n",
" <td>2.735337</td>\n",
" <td>This movie was unexpectedly funny and funny, you</td>\n",
" <td>2.405301</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>OK where do i begin?</td>\n",
" <td>OK where do i begin? *** Acting is decent (not...</td>\n",
" <td>1.555380</td>\n",
" <td>OK where do i begin? For all of you who are no...</td>\n",
" <td>0.019694</td>\n",
" <td>OK where do i begin? i just wanted to add some...</td>\n",
" <td>0.622912</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>I watched</td>\n",
" <td>I watched one can compare themselves upon view...</td>\n",
" <td>1.380120</td>\n",
" <td>I watched it because of its excellent cast. Th...</td>\n",
" <td>2.498309</td>\n",
" <td>I watched the trial trial for teaches us a goo...</td>\n",
" <td>2.057187</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>It's been 19 years since Gordon</td>\n",
" <td>It's been 19 years since Gordon finally left c...</td>\n",
" <td>1.554914</td>\n",
" <td>It's been 19 years since Gordon Tree has becom...</td>\n",
" <td>1.632266</td>\n",
" <td>It's been 19 years since Gordon Clarke put me ...</td>\n",
" <td>2.783458</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Just kidding</td>\n",
" <td>Just kidding; I know a lot</td>\n",
" <td>-0.069533</td>\n",
" <td>Just kidding \"Third World Snopes</td>\n",
" <td>0.944632</td>\n",
" <td>Just kidding, I didn't even</td>\n",
" <td>1.945202</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>shakespeare's plays have a way</td>\n",
" <td>shakespeare's plays have a way of weaving into...</td>\n",
" <td>1.656927</td>\n",
" <td>shakespeare's plays have a way. It's the look ...</td>\n",
" <td>1.444803</td>\n",
" <td>shakespeare's plays have a way of getting back...</td>\n",
" <td>1.834373</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>This movie is wonderful. What</td>\n",
" <td>This movie is wonderful. What could have been ...</td>\n",
" <td>2.749068</td>\n",
" <td>This movie is wonderful. What someone likes ab...</td>\n",
" <td>2.759510</td>\n",
" <td>This movie is wonderful. What a different look,</td>\n",
" <td>2.695312</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>I loved</td>\n",
" <td>I loved this film. &lt;br /&gt;&lt;</td>\n",
" <td>2.576181</td>\n",
" <td>I loved it, and I really loved Audrey</td>\n",
" <td>2.578412</td>\n",
" <td>I loved this film. Reading reviews of it</td>\n",
" <td>2.751773</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>A superb and</td>\n",
" <td>A superb and very cool drama. The novel is</td>\n",
" <td>2.910374</td>\n",
" <td>A superb and super fun movie that removes all the</td>\n",
" <td>2.783201</td>\n",
" <td>A superb and most finely acted role that I will</td>\n",
" <td>2.894923</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>I remember</td>\n",
" <td>I remember.Very poor execution but good movies</td>\n",
" <td>0.923775</td>\n",
" <td>I remember when Shelter saw some girls on TV</td>\n",
" <td>0.825408</td>\n",
" <td>I remember thinking to myself how SOMEONE who</td>\n",
" <td>1.634163</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>This su*k</td>\n",
" <td>This su*k camel down your kidd</td>\n",
" <td>1.605957</td>\n",
" <td>This su*k Dress! I loved it</td>\n",
" <td>2.345865</td>\n",
" <td>This su*k like a roll of crap</td>\n",
" <td>2.422874</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>One Stink</td>\n",
" <td>One Stink Act...&lt;br /&gt;&lt;br</td>\n",
" <td>1.456476</td>\n",
" <td>One Stinkl was a great actor, particularly</td>\n",
" <td>1.782818</td>\n",
" <td>One Stink?: Invisible of Saint Barbara, poor</td>\n",
" <td>1.667756</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>I pulled down a VHS</td>\n",
" <td>I pulled down a VHS copy and watched it with m...</td>\n",
" <td>0.756151</td>\n",
" <td>I pulled down a VHS looking a good looking, and a</td>\n",
" <td>-0.008258</td>\n",
" <td>I pulled down a VHS copy the other day and all I</td>\n",
" <td>0.992919</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>For some</td>\n",
" <td>For some alone no more Buddy Trumbull would ha...</td>\n",
" <td>0.790762</td>\n",
" <td>For some enthraled time, the film will impress...</td>\n",
" <td>2.455694</td>\n",
" <td>For some reason, a bomb crashed on the rear of...</td>\n",
" <td>0.857423</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>This one features all</td>\n",
" <td>This one features all the good elements of spi...</td>\n",
" <td>1.452079</td>\n",
" <td>This one features all kinds of wit and humor r...</td>\n",
" <td>2.743043</td>\n",
" <td>This one features all the best Birdprogram sup...</td>\n",
" <td>2.343950</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>Somehow a woman working with</td>\n",
" <td>Somehow a woman working with Jim Wynorski prof...</td>\n",
" <td>0.242172</td>\n",
" <td>Somehow a woman working with her daughter play...</td>\n",
" <td>0.092226</td>\n",
" <td>Somehow a woman working with an overweight ins...</td>\n",
" <td>1.415525</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" query \\\n",
"0 This movie \n",
"1 OK where do i begin? \n",
"2 I watched \n",
"3 It's been 19 years since Gordon \n",
"4 Just kidding \n",
"5 shakespeare's plays have a way \n",
"6 This movie is wonderful. What \n",
"7 I loved \n",
"8 A superb and \n",
"9 I remember \n",
"10 This su*k \n",
"11 One Stink \n",
"12 I pulled down a VHS \n",
"13 For some \n",
"14 This one features all \n",
"15 Somehow a woman working with \n",
"\n",
" response (ref) scores (ref) \\\n",
"0 This movie should have read some books, and 1.411889 \n",
"1 OK where do i begin? *** Acting is decent (not... 1.555380 \n",
"2 I watched one can compare themselves upon view... 1.380120 \n",
"3 It's been 19 years since Gordon finally left c... 1.554914 \n",
"4 Just kidding; I know a lot -0.069533 \n",
"5 shakespeare's plays have a way of weaving into... 1.656927 \n",
"6 This movie is wonderful. What could have been ... 2.749068 \n",
"7 I loved this film. <br />< 2.576181 \n",
"8 A superb and very cool drama. The novel is 2.910374 \n",
"9 I remember.Very poor execution but good movies 0.923775 \n",
"10 This su*k camel down your kidd 1.605957 \n",
"11 One Stink Act...<br /><br 1.456476 \n",
"12 I pulled down a VHS copy and watched it with m... 0.756151 \n",
"13 For some alone no more Buddy Trumbull would ha... 0.790762 \n",
"14 This one features all the good elements of spi... 1.452079 \n",
"15 Somehow a woman working with Jim Wynorski prof... 0.242172 \n",
"\n",
" response (RLHF) scores (RLHF) \\\n",
"0 This movie has plenty of extraordinary feature... 2.735337 \n",
"1 OK where do i begin? For all of you who are no... 0.019694 \n",
"2 I watched it because of its excellent cast. Th... 2.498309 \n",
"3 It's been 19 years since Gordon Tree has becom... 1.632266 \n",
"4 Just kidding \"Third World Snopes 0.944632 \n",
"5 shakespeare's plays have a way. It's the look ... 1.444803 \n",
"6 This movie is wonderful. What someone likes ab... 2.759510 \n",
"7 I loved it, and I really loved Audrey 2.578412 \n",
"8 A superb and super fun movie that removes all the 2.783201 \n",
"9 I remember when Shelter saw some girls on TV 0.825408 \n",
"10 This su*k Dress! I loved it 2.345865 \n",
"11 One Stinkl was a great actor, particularly 1.782818 \n",
"12 I pulled down a VHS looking a good looking, and a -0.008258 \n",
"13 For some enthraled time, the film will impress... 2.455694 \n",
"14 This one features all kinds of wit and humor r... 2.743043 \n",
"15 Somehow a woman working with her daughter play... 0.092226 \n",
"\n",
" response (best_of) scores (best_of) \n",
"0 This movie was unexpectedly funny and funny, you 2.405301 \n",
"1 OK where do i begin? i just wanted to add some... 0.622912 \n",
"2 I watched the trial trial for teaches us a goo... 2.057187 \n",
"3 It's been 19 years since Gordon Clarke put me ... 2.783458 \n",
"4 Just kidding, I didn't even 1.945202 \n",
"5 shakespeare's plays have a way of getting back... 1.834373 \n",
"6 This movie is wonderful. What a different look, 2.695312 \n",
"7 I loved this film. Reading reviews of it 2.751773 \n",
"8 A superb and most finely acted role that I will 2.894923 \n",
"9 I remember thinking to myself how SOMEONE who 1.634163 \n",
"10 This su*k like a roll of crap 2.422874 \n",
"11 One Stink?: Invisible of Saint Barbara, poor 1.667756 \n",
"12 I pulled down a VHS copy the other day and all I 0.992919 \n",
"13 For some reason, a bomb crashed on the rear of... 0.857423 \n",
"14 This one features all the best Birdprogram sup... 2.343950 \n",
"15 Somehow a woman working with an overweight ins... 1.415525 "
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output_data[\"response (ref)\"] = response_tensors_ref\n",
"output_data[\"scores (ref)\"] = scores_ref\n",
"output_data[\"response (RLHF)\"] = response_tensors\n",
"output_data[\"scores (RLHF)\"] = scores\n",
"output_data[\"response (best_of)\"] = [\n",
" response_tensors_best_of[i][a.argmax().item()] for i, a in enumerate(scores_best_of)\n",
"]\n",
"output_data[\"scores (best_of)\"] = [a.max().item() for a in scores_best_of]\n",
"\n",
"\n",
"# store results in a dataframe\n",
"df_results = pd.DataFrame(output_data)\n",
"df_results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": []
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
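
The last cell above builds the best-of-n column by keeping, for each query, the sampled completion with the highest reward score (`argmax`/`max` over `scores_best_of`). A minimal sketch of that selection step, with `generate_fn` and `reward_fn` as placeholders for the notebook's generation and sentiment-scoring calls (they are not functions from the notebook):

```python
import torch


def best_of_n(queries, generate_fn, reward_fn, n=4):
    """Keep the highest-scoring of n sampled completions per query."""
    best_responses, best_scores = [], []
    for query in queries:
        # sample n candidate completions for this query
        candidates = [generate_fn(query) for _ in range(n)]
        scores = torch.tensor([reward_fn(query, c) for c in candidates])
        best_idx = scores.argmax().item()  # index of the top-scoring candidate
        best_responses.append(candidates[best_idx])
        best_scores.append(scores.max().item())
    return best_responses, best_scores
```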

View File

@ -73,13 +73,19 @@
"import pandas as pd\n",
"from random import choices\n",
"import matplotlib.pyplot as plt\n",
"\n",
"tqdm.pandas()\n",
"\n",
"from datasets import load_dataset\n",
"\n",
"from transformers import AutoTokenizer, pipeline\n",
"\n",
"from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model"
"from trl import (\n",
" PPOTrainer,\n",
" PPOConfig,\n",
" AutoModelForCausalLMWithValueHead,\n",
" create_reference_model,\n",
")"
]
},
{
@ -95,22 +101,19 @@
"metadata": {},
"outputs": [],
"source": [
"sentiment_pipe_kwargs = {\n",
" \"top_k\": None, \n",
" \"function_to_apply\": \"none\"\n",
"}\n",
"sentiment_pipe_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\"}\n",
"\n",
"config = PPOConfig(\n",
" model_name=\"lvwerra/gpt2-imdb\",\n",
" steps=51200,\n",
" learning_rate=1.41e-5,\n",
" remove_unused_columns=False,\n",
" log_with=\"wandb\"\n",
" log_with=\"wandb\",\n",
")\n",
"\n",
"txt_in_len = 5\n",
"txt_out_len = 20\n",
"seed = 1\n"
"seed = 1"
]
},
{
@ -127,7 +130,7 @@
"metadata": {},
"source": [
"You can see that we load a GPT2 model called `gpt2_imdb`. This model was additionally fine-tuned on the IMDB dataset for 1 epoch with the huggingface [script](https://github.com/huggingface/transformers/blob/master/examples/run_language_modeling.py) (no special settings). The other parameters are mostly taken from the original paper [\"Fine-Tuning Language Models from Human Preferences\"](\n",
"https://arxiv.org/pdf/1909.08593.pdf). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models."
"https://huggingface.co/papers/1909.08593). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models."
]
},
{
@ -158,7 +161,7 @@
"outputs": [],
"source": [
"gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)\n",
"gpt2_model_ref = create_reference_model(gpt2_model)\n",
"gpt2_ref_model = create_reference_model(gpt2_model)\n",
"gpt2_tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n",
"\n",
"gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token"
@ -201,13 +204,13 @@
}
],
"source": [
"# create the dataset \n",
"# \n",
"dataset = load_dataset('imdb', split='train')\n",
"dataset = dataset.rename_columns({'text': 'review', 'label': 'sentiment'})\n",
"# create the dataset\n",
"#\n",
"dataset = load_dataset(\"stanfordnlp/imdb\", split=\"train\")\n",
"dataset = dataset.rename_columns({\"text\": \"review\", \"label\": \"sentiment\"})\n",
"# make sure the comments are are at least 500 and trim to 1000\n",
"dataset = dataset.filter(lambda x: len(x[\"review\"])>500, batched=False)\n",
"dataset = dataset.map(lambda x:{\"review\":x['review'][:1000]}, batched=False)\n",
"dataset = dataset.filter(lambda x: len(x[\"review\"]) > 500, batched=False)\n",
"dataset = dataset.map(lambda x: {\"review\": x[\"review\"][:1000]}, batched=False)\n",
"\n",
"dataset"
]
@ -241,11 +244,21 @@
}
],
"source": [
"dataset = dataset.map(lambda x:{\"input_ids\": gpt2_tokenizer.encode(' '+x['review'], return_tensors=\"pt\")[0, :txt_in_len]}, batched=False)\n",
"dataset = dataset.map(lambda x:{\"query\": gpt2_tokenizer.decode(x[\"input_ids\"])}, batched=False)\n",
"dataset = dataset.map(\n",
" lambda x: {\n",
" \"input_ids\": gpt2_tokenizer.encode(\" \" + x[\"review\"], return_tensors=\"pt\")[\n",
" 0, :txt_in_len\n",
" ]\n",
" },\n",
" batched=False,\n",
")\n",
"dataset = dataset.map(\n",
" lambda x: {\"query\": gpt2_tokenizer.decode(x[\"input_ids\"])}, batched=False\n",
")\n",
"dataset = dataset[:20480]\n",
"\n",
"from datasets import Dataset\n",
"\n",
"dataset = Dataset.from_dict(dataset)\n",
"dataset.set_format(\"pytorch\")"
]
@ -355,7 +368,9 @@
}
],
"source": [
"ppo_trainer = PPOTrainer(config, gpt2_model, gpt2_model_ref, gpt2_tokenizer, dataset, data_collator=collator)\n"
"ppo_trainer = PPOTrainer(\n",
" config, gpt2_model, gpt2_ref_model, gpt2_tokenizer, dataset, data_collator=collator\n",
")"
]
},
{
@ -373,10 +388,12 @@
"outputs": [],
"source": [
"if ppo_trainer.accelerator.num_processes == 1:\n",
" device = 0 if torch.cuda.is_available() else \"cpu\" # to avoid a `pipeline` bug\n",
" device = 0 if torch.cuda.is_available() else \"cpu\" # to avoid a `pipeline` bug\n",
"else:\n",
" device = ppo_trainer.accelerator.device\n",
"sentiment_pipe = pipeline(\"sentiment-analysis\", \"lvwerra/distilbert-imdb\", device=device)"
"sentiment_pipe = pipeline(\n",
" \"sentiment-analysis\", \"lvwerra/distilbert-imdb\", device=device\n",
")"
]
},
{
@ -404,7 +421,7 @@
}
],
"source": [
"text = 'this movie was really bad!!'\n",
"text = \"this movie was really bad!!\"\n",
"output = sentiment_pipe(text, **sentiment_pipe_kwargs)\n",
"output"
]
@ -427,7 +444,7 @@
}
],
"source": [
"text = 'this movie was really good!!'\n",
"text = \"this movie was really good!!\"\n",
"output = sentiment_pipe(text, **sentiment_pipe_kwargs)\n",
"output"
]
@ -450,7 +467,7 @@
}
],
"source": [
"text = 'this movie was a documentary'\n",
"text = \"this movie was a documentary\"\n",
"output = sentiment_pipe(text, **sentiment_pipe_kwargs)\n",
"output"
]
@ -472,7 +489,7 @@
" positive_logits = []\n",
" for out in outputs:\n",
" for element in out:\n",
" if element[\"label\"]==\"POSITIVE\":\n",
" if element[\"label\"] == \"POSITIVE\":\n",
" positive_logits.append(torch.tensor(element[\"score\"]))\n",
" return positive_logits"
]
@ -511,9 +528,14 @@
"metadata": {},
"outputs": [],
"source": [
"ctrl_str = ['[negative]', '[neutral]', '[positive]']\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\") # this should be handled by accelerate\n",
"ctrl_tokens = dict((s, gpt2_tokenizer.encode(s, return_tensors=\"pt\").squeeze().to(device)) for s in ctrl_str)"
"ctrl_str = [\"[negative]\", \"[neutral]\", \"[positive]\"]\n",
"device = torch.device(\n",
" \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
") # this should be handled by accelerate\n",
"ctrl_tokens = dict(\n",
" (s, gpt2_tokenizer.encode(s, return_tensors=\"pt\").squeeze().to(device))\n",
" for s in ctrl_str\n",
")"
]
},
{
@ -559,14 +581,14 @@
" task [positive]: reward = logit\n",
" \"\"\"\n",
" for i in range(len(logit)):\n",
" if task[i]=='[negative]':\n",
" if task[i] == \"[negative]\":\n",
" logit[i] = -logit[i]\n",
" elif task[i]=='[neutral]':\n",
" logit[i] = -2*torch.abs(logit[i])+4\n",
" elif task[i]=='[positive]':\n",
" elif task[i] == \"[neutral]\":\n",
" logit[i] = -2 * torch.abs(logit[i]) + 4\n",
" elif task[i] == \"[positive]\":\n",
" pass\n",
" else:\n",
" raise ValueError('task has to be in [0, 1, 2]!')\n",
" raise ValueError(\"task has to be in [0, 1, 2]!\")\n",
" return logit"
]
},
@ -611,7 +633,7 @@
}
],
"source": [
"pos_logit_to_reward(torch.Tensor([4,4,4]), ctrl_str)"
"pos_logit_to_reward(torch.Tensor([4, 4, 4]), ctrl_str)"
]
},
{
@ -631,7 +653,7 @@
}
],
"source": [
"pos_logit_to_reward(torch.Tensor([-4,-4,-4]), ctrl_str)"
"pos_logit_to_reward(torch.Tensor([-4, -4, -4]), ctrl_str)"
]
},
{
@ -668,14 +690,14 @@
"outputs": [],
"source": [
"generation_kwargs = {\n",
" \"min_length\":-1,\n",
" \"min_length\": -1,\n",
" \"top_k\": 0.0,\n",
" \"top_p\": 1.0,\n",
" \"do_sample\": True,\n",
" \"pad_token_id\": gpt2_tokenizer.eos_token_id,\n",
" \"max_new_tokens\": txt_out_len,\n",
" \"eos_token_id\": -1\n",
"}\n"
" \"eos_token_id\": -1,\n",
"}"
]
},
{
@ -698,7 +720,6 @@
"4. Get sentiments for query/responses from BERT\n",
"5. Optimize policy with PPO using the (query, response, reward) triplet\n",
"6. Log all the training statistics\n",
"\n",
"**Training time**\n",
"\n",
@ -724,33 +745,46 @@
"source": [
"for epoch in range(2):\n",
" for batch in tqdm(ppo_trainer.dataloader):\n",
" logs, game_data, = dict(), dict()\n",
" \n",
" (\n",
" logs,\n",
" game_data,\n",
" ) = (\n",
" dict(),\n",
" dict(),\n",
" )\n",
"\n",
" #### prepend a random control token\n",
" task_list = choices(ctrl_str, k=config.batch_size)\n",
" game_data['query'] = [t+q for t,q in zip(task_list, batch['query'])]\n",
" query_tensors = [torch.cat((ctrl_tokens[t], input_ids)) for t, input_ids in zip(task_list, batch[\"input_ids\"])]\n",
" game_data[\"query\"] = [t + q for t, q in zip(task_list, batch[\"query\"])]\n",
" query_tensors = [\n",
" torch.cat((ctrl_tokens[t], input_ids))\n",
" for t, input_ids in zip(task_list, batch[\"input_ids\"])\n",
" ]\n",
"\n",
" #### get response from gpt2\n",
" response_tensors = []\n",
" for query in query_tensors:\n",
" response = ppo_trainer.generate(query, **generation_kwargs)\n",
" response_tensors.append(response.squeeze()[-txt_out_len:])\n",
" game_data['response'] = [gpt2_tokenizer.decode(r.squeeze()) for r in response_tensors]\n",
" game_data[\"response\"] = [\n",
" gpt2_tokenizer.decode(r.squeeze()) for r in response_tensors\n",
" ]\n",
"\n",
" #### sentiment analysis\n",
" texts = [q + r for q,r in zip(batch['query'], game_data['response'])]\n",
" texts = [q + r for q, r in zip(batch[\"query\"], game_data[\"response\"])]\n",
" logits = extract_pipe_output(sentiment_pipe(texts, **sentiment_pipe_kwargs))\n",
" rewards = pos_logit_to_reward(logits, task_list)\n",
"\n",
" #### Run PPO training \n",
" #### Run PPO training\n",
" t = time.time()\n",
" stats = ppo_trainer.step(query_tensors, response_tensors, rewards)\n",
"\n",
" for cs in ctrl_str:\n",
" key = 'env/reward_'+cs.strip('[]')\n",
" stats[key] = np.mean([r.cpu().numpy() for r, t in zip(rewards, task_list) if t==cs])\n",
" ppo_trainer.log_stats(stats, game_data, rewards)\n"
" key = \"env/reward_\" + cs.strip(\"[]\")\n",
" stats[key] = np.mean(\n",
" [r.cpu().numpy() for r, t in zip(rewards, task_list) if t == cs]\n",
" )\n",
" ppo_trainer.log_stats(stats, game_data, rewards)"
]
},
{
@ -803,12 +837,14 @@
],
"source": [
"for ctrl_s in ctrl_str:\n",
" plt.hist([r for r, t in zip(logs['env/reward_dist'], task_list) if t==ctrl_s],\n",
" density=True,\n",
" alpha=0.5,\n",
" label=ctrl_s)\n",
"plt.legend(loc='best')\n",
"plt.title('reward distribution')\n",
" plt.hist(\n",
" [r for r, t in zip(logs[\"env/reward_dist\"], task_list) if t == ctrl_s],\n",
" density=True,\n",
" alpha=0.5,\n",
" label=ctrl_s,\n",
" )\n",
"plt.legend(loc=\"best\")\n",
"plt.title(\"reward distribution\")\n",
"plt.grid(True)\n",
"plt.show()"
]
@ -827,8 +863,8 @@
"metadata": {},
"outputs": [],
"source": [
"gpt2_model.save_pretrained('gpt2-imdb-ctrl')\n",
"gpt2_tokenizer.save_pretrained('gpt2-imdb-ctrl')"
"gpt2_model.save_pretrained(\"gpt2-imdb-ctrl\")\n",
"gpt2_tokenizer.save_pretrained(\"gpt2-imdb-ctrl\")"
]
}
],
@ -848,7 +884,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.9.12"
},
"vscode": {
"interpreter": {

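For reference, the reward shaping implemented by `pos_logit_to_reward` in the notebook above maps the classifier's positive-sentiment logit x to a task-dependent reward, where t is the control token prepended to the query:

```latex
r(x, t) =
\begin{cases}
-x            & t = \text{[negative]} \\
-2\,|x| + 4   & t = \text{[neutral]}  \\
\phantom{-}x  & t = \text{[positive]}
\end{cases}
```

The neutral task is therefore rewarded most (up to 4) when the logit is close to zero, while the negative and positive tasks reward strongly negative and strongly positive logits, respectively.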
View File

@ -63,6 +63,7 @@
"import torch\n",
"from tqdm import tqdm\n",
"import pandas as pd\n",
"\n",
"tqdm.pandas()\n",
"\n",
"from transformers import pipeline, AutoTokenizer\n",
@ -91,11 +92,7 @@
" log_with=\"wandb\",\n",
")\n",
"\n",
"sent_kwargs = {\n",
" \"return_all_scores\": True,\n",
" \"function_to_apply\": \"none\",\n",
" \"batch_size\": 16\n",
"}"
"sent_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\", \"batch_size\": 16}"
]
},
{
@ -105,6 +102,7 @@
"outputs": [],
"source": [
"import wandb\n",
"\n",
"wandb.init()"
]
},
@ -112,8 +110,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"You can see that we load a GPT2 model called `gpt2_imdb`. This model was additionally fine-tuned on the IMDB dataset for 1 epoch with the huggingface [script](https://github.com/huggingface/transformers/blob/master/examples/run_language_modeling.py) (no special settings). The other parameters are mostly taken from the original paper [\"Fine-Tuning Language Models from Human Preferences\"](\n",
"https://arxiv.org/pdf/1909.08593.pdf). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models."
"You can see that we load a GPT2 model called `gpt2_imdb`. This model was additionally fine-tuned on the IMDB dataset for 1 epoch with the huggingface [script](https://github.com/huggingface/transformers/blob/main/examples/legacy/run_language_modeling.py) (no special settings). The other parameters are mostly taken from the original paper [\"Fine-Tuning Language Models from Human Preferences\"](\n",
"https://huggingface.co/papers/1909.08593). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models."
]
},
{
@ -136,26 +134,22 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Reusing dataset imdb (/home/leandro/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)\n",
"Loading cached processed dataset at /home/leandro/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-ff455473e884c6a3.arrow\n"
]
}
],
"outputs": [],
"source": [
"def build_dataset(config, dataset_name=\"imdb\", input_min_text_length=2, input_max_text_length=8):\n",
"def build_dataset(\n",
" config,\n",
" dataset_name=\"stanfordnlp/imdb\",\n",
" input_min_text_length=2,\n",
" input_max_text_length=8,\n",
"):\n",
" \"\"\"\n",
" Build dataset for training. This builds the dataset from `load_dataset`, one should \n",
" Build dataset for training. This builds the dataset from `load_dataset`, one should\n",
" customize this function to train the model on its own dataset.\n",
" \n",
"\n",
" Args:\n",
" dataset_name (`str`): \n",
" dataset_name (`str`):\n",
" The name of the dataset to be loaded.\n",
" \n",
"\n",
" Returns:\n",
" dataloader (`torch.utils.data.DataLoader`):\n",
" The dataloader for the dataset.\n",
@ -163,19 +157,19 @@
" tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n",
" tokenizer.pad_token = tokenizer.eos_token\n",
" # load imdb with datasets\n",
" ds = load_dataset(dataset_name, split='train')\n",
" ds = ds.rename_columns({'text': 'review'})\n",
" ds = ds.filter(lambda x: len(x[\"review\"])>200, batched=False)\n",
" ds = load_dataset(dataset_name, split=\"train\")\n",
" ds = ds.rename_columns({\"text\": \"review\"})\n",
" ds = ds.filter(lambda x: len(x[\"review\"]) > 200, batched=False)\n",
"\n",
" input_size = LengthSampler(input_min_text_length, input_max_text_length)\n",
"\n",
" def tokenize(sample):\n",
" sample[\"input_ids\"] = tokenizer.encode(sample[\"review\"])[:input_size()]\n",
" sample[\"input_ids\"] = tokenizer.encode(sample[\"review\"])[: input_size()]\n",
" sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n",
" return sample\n",
"\n",
" ds = ds.map(tokenize, batched=False)\n",
" ds.set_format(type='torch')\n",
" ds.set_format(type=\"torch\")\n",
" return ds"
]
},
@ -187,6 +181,7 @@
"source": [
"dataset = build_dataset(config)\n",
"\n",
"\n",
"def collator(data):\n",
" return dict((key, [d[key] for d in data]) for key in data[0])"
]
@ -233,7 +228,9 @@
"metadata": {},
"outputs": [],
"source": [
"ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)"
"ppo_trainer = PPOTrainer(\n",
" config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator\n",
")"
]
},
{
@ -252,8 +249,10 @@
"source": [
"device = ppo_trainer.accelerator.device\n",
"if ppo_trainer.accelerator.num_processes == 1:\n",
" device = 0 if torch.cuda.is_available() else \"cpu\" # to avoid a `pipeline` bug\n",
"sentiment_pipe = pipeline(\"sentiment-analysis\", model=\"lvwerra/distilbert-imdb\", device=device)"
" device = 0 if torch.cuda.is_available() else \"cpu\" # to avoid a `pipeline` bug\n",
"sentiment_pipe = pipeline(\n",
" \"sentiment-analysis\", model=\"lvwerra/distilbert-imdb\", device=device\n",
")"
]
},
{
@ -271,8 +270,8 @@
{
"data": {
"text/plain": [
"[[{'label': 'NEGATIVE', 'score': 2.335048198699951},\n",
" {'label': 'POSITIVE', 'score': -2.726576566696167}]]"
"[{'label': 'NEGATIVE', 'score': 2.335048198699951},\n",
" {'label': 'POSITIVE', 'score': -2.726576328277588}]"
]
},
"execution_count": null,
@ -281,7 +280,7 @@
}
],
"source": [
"text = 'this movie was really bad!!'\n",
"text = \"this movie was really bad!!\"\n",
"sentiment_pipe(text, **sent_kwargs)"
]
},
@ -293,8 +292,8 @@
{
"data": {
"text/plain": [
"[[{'label': 'NEGATIVE', 'score': -2.2947897911071777},\n",
" {'label': 'POSITIVE', 'score': 2.557039737701416}]]"
"[{'label': 'POSITIVE', 'score': 2.557040214538574},\n",
" {'label': 'NEGATIVE', 'score': -2.294790267944336}]"
]
},
"execution_count": null,
@ -303,7 +302,7 @@
}
],
"source": [
"text = 'this movie was really good!!'\n",
"text = \"this movie was really good!!\"\n",
"sentiment_pipe(text, **sent_kwargs)"
]
},
@ -322,11 +321,11 @@
"outputs": [],
"source": [
"gen_kwargs = {\n",
" \"min_length\":-1,\n",
" \"min_length\": -1,\n",
" \"top_k\": 0.0,\n",
" \"top_p\": 1.0,\n",
" \"do_sample\": True,\n",
" \"pad_token_id\": tokenizer.eos_token_id\n",
" \"pad_token_id\": tokenizer.eos_token_id,\n",
"}"
]
},
@ -370,32 +369,39 @@
"\n",
"\n",
"generation_kwargs = {\n",
" \"min_length\":-1,\n",
" \"min_length\": -1,\n",
" \"top_k\": 0.0,\n",
" \"top_p\": 1.0,\n",
" \"do_sample\": True,\n",
" \"pad_token_id\": tokenizer.eos_token_id\n",
" \"pad_token_id\": tokenizer.eos_token_id,\n",
"}\n",
"\n",
"\n",
"for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):\n",
" query_tensors = batch['input_ids']\n",
"for epoch, batch in enumerate(tqdm(ppo_trainer.dataloader)):\n",
" query_tensors = batch[\"input_ids\"]\n",
"\n",
" #### Get response from gpt2\n",
" response_tensors = []\n",
" for query in query_tensors:\n",
" gen_len = output_length_sampler()\n",
" generation_kwargs[\"max_new_tokens\"] = gen_len\n",
" response = ppo_trainer.generate(query, **generation_kwargs)\n",
" response_tensors.append(response.squeeze()[-gen_len:])\n",
" batch['response'] = [tokenizer.decode(r.squeeze()) for r in response_tensors]\n",
" query_response = ppo_trainer.generate(query, **generation_kwargs).squeeze()\n",
" response_len = len(query_response) - len(query)\n",
" response_tensors.append(query_response[-response_len:])\n",
" batch[\"response\"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]\n",
"\n",
" #### Compute sentiment score\n",
" texts = [q + r for q,r in zip(batch['query'], batch['response'])]\n",
" texts = [q + r for q, r in zip(batch[\"query\"], batch[\"response\"])]\n",
" pipe_outputs = sentiment_pipe(texts, **sent_kwargs)\n",
" rewards = [torch.tensor(output[1][\"score\"]) for output in pipe_outputs]\n",
" positive_scores = [\n",
" item[\"score\"]\n",
" for output in pipe_outputs\n",
" for item in output\n",
" if item[\"label\"] == \"POSITIVE\"\n",
" ]\n",
" rewards = [torch.tensor(score) for score in positive_scores]\n",
"\n",
" #### Run PPO step \n",
" #### Run PPO step\n",
" stats = ppo_trainer.step(query_tensors, response_tensors, rewards)\n",
" ppo_trainer.log_stats(stats, batch, rewards)"
]
@ -405,7 +411,7 @@
"metadata": {},
"source": [
"### Training progress\n",
"If you are tracking the training progress with Weights&Biases you should see a plot similar to the one below. Check out the interactive sample report on wandb.ai: [link](https://app.wandb.ai/lvwerra/trl-showcase/runs/1jtvxb1m/).\n",
"If you are tracking the training progress with Weights&Biases you should see a plot similar to the one below. Check out the interactive sample report on wandb.ai: [link](https://wandb.ai/huggingface/trl/runs/w9l3110g).\n",
"\n",
"<div style=\"text-align: center\">\n",
"<img src='https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/gpt2_tuning_progress.png' width='800'>\n",
@ -414,7 +420,7 @@
"\n",
"One can observe how the model starts to generate more positive outputs after a few optimisation steps.\n",
"\n",
"> Note: Investigating the KL-divergence will probably show that at this point the model has not converged to the target KL-divergence, yet. To get there would require longer training or starting with a higher inital coefficient."
"> Note: Investigating the KL-divergence will probably show that at this point the model has not converged to the target KL-divergence, yet. To get there would require longer training or starting with a higher initial coefficient."
]
},
{
@ -423,7 +429,7 @@
"metadata": {},
"source": [
"## Model inspection\n",
"Let's inspect some examples from the IMDB dataset. We can use `model_ref` to compare the tuned model `model` against the model before optimisation."
"Let's inspect some examples from the IMDB dataset. We can use `ref_model` to compare the tuned model `model` against the model before optimisation."
]
},
{
@ -431,14 +437,6 @@
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/leandro/miniconda3/envs/trl/lib/python3.9/site-packages/transformers/pipelines/base.py:1075: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/html": [
@ -470,131 +468,131 @@
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Oh dear,</td>\n",
" <td>what are I saying?! I fast-forwarded through</td>\n",
" <td>I must say that I are hanging my head on this</td>\n",
" <td>-0.858954</td>\n",
" <td>-1.007609</td>\n",
" <td>I rented Zero Day</td>\n",
" <td>4 for my sister. To my surprise, the Wii caug...</td>\n",
" <td>. It is a pleasure. It is a huge leap 68 years...</td>\n",
" <td>1.736068</td>\n",
" <td>2.423731</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>I've seen</td>\n",
" <td>it, as well.&lt;br</td>\n",
" <td>three million dialogue throughout, and</td>\n",
" <td>1.996807</td>\n",
" <td>2.240883</td>\n",
" <td>The only</td>\n",
" <td>distro of her</td>\n",
" <td>special compliments is the</td>\n",
" <td>0.150852</td>\n",
" <td>0.190159</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Hi:&lt;br /&gt;&lt;br</td>\n",
" <td>/&gt;This movie is a turkey though when it comes to</td>\n",
" <td>/&gt;I also like that movie. It's so funny</td>\n",
" <td>-0.438191</td>\n",
" <td>2.415630</td>\n",
" <td>I've read a few</td>\n",
" <td>news reports about Mr. Mueller's activities b...</td>\n",
" <td>novels and I never watch this. It has a reall...</td>\n",
" <td>-1.417962</td>\n",
" <td>2.831814</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>I'm a writer</td>\n",
" <td>and I'm not going to be asked to</td>\n",
" <td>, not a screenwriter. I've written</td>\n",
" <td>-0.655991</td>\n",
" <td>-0.724324</td>\n",
" <td>This is the second British Rank film</td>\n",
" <td>, and I wouldn't be surprised anymore if it</td>\n",
" <td>that I have enjoyed, achieving it in both the</td>\n",
" <td>0.835876</td>\n",
" <td>2.205628</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>If you</td>\n",
" <td>absolutely love sensitive romance, the plot a...</td>\n",
" <td>are looking at the cinematography, the acting,</td>\n",
" <td>2.221309</td>\n",
" <td>0.148751</td>\n",
" <td>A classic</td>\n",
" <td>classic.&lt;br /&gt;&lt;br /&gt;And only this one will ha...</td>\n",
" <td>. It's a movie with a fine cast. As the beginn...</td>\n",
" <td>2.113075</td>\n",
" <td>2.739168</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>OMG this</td>\n",
" <td>casting cast. Obi cult breezy, this is</td>\n",
" <td>movie was totally wonderful, I it was the ide...</td>\n",
" <td>-1.533139</td>\n",
" <td>2.590190</td>\n",
" <td>This has to be one of the</td>\n",
" <td>worst with the differences being that for the</td>\n",
" <td>best thriller films I've seen in recent</td>\n",
" <td>-2.705339</td>\n",
" <td>2.730615</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>It's</td>\n",
" <td>unrealistic; the guy who was supposed to be E...</td>\n",
" <td>a very good film. It reminds us about over</td>\n",
" <td>-2.097017</td>\n",
" <td>2.835831</td>\n",
" <td>Happy Go Lovely is a waste</td>\n",
" <td>. Not only are extremely</td>\n",
" <td>of time, giving a</td>\n",
" <td>-2.429504</td>\n",
" <td>-2.934672</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>There is a really</td>\n",
" <td>awful laptop game!&lt;br /&gt;&lt;br /&gt;I used to</td>\n",
" <td>interesting story that set us the journey. Th...</td>\n",
" <td>-2.341743</td>\n",
" <td>2.282939</td>\n",
" <td>Wow, I just</td>\n",
" <td>can't make fun of it</td>\n",
" <td>feek it! This show</td>\n",
" <td>-2.201666</td>\n",
" <td>-0.106085</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>This is</td>\n",
" <td>my favorite part about</td>\n",
" <td>a well thought well</td>\n",
" <td>2.554794</td>\n",
" <td>2.734139</td>\n",
" <td>This movie makes several mistakes.</td>\n",
" <td>Despite being a great comedic diversion it es...</td>\n",
" <td>It's cool, wonderful - it held me into a very ...</td>\n",
" <td>-1.232380</td>\n",
" <td>2.707638</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>Wasn't</td>\n",
" <td>Wasn't it clichéd?&lt;|endoftext|&gt;</td>\n",
" <td>anyone else interested in this movie? It's a ...</td>\n",
" <td>-1.790802</td>\n",
" <td>2.631960</td>\n",
" <td>Branagh and Fish</td>\n",
" <td>burne, Drake is played</td>\n",
" <td>is a great show. Beautiful</td>\n",
" <td>0.776819</td>\n",
" <td>2.808996</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>This film is another of director Tim</td>\n",
" <td>Burton's masterpieces</td>\n",
" <td>Curry's best bombs</td>\n",
" <td>2.622917</td>\n",
" <td>2.544106</td>\n",
" <td>I might have given this movie a</td>\n",
" <td>rating of *11 when I heard that!), but it was...</td>\n",
" <td>great performance. It was truly a great movie...</td>\n",
" <td>0.276380</td>\n",
" <td>2.743328</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>I thought this movie</td>\n",
" <td>was excellent. I actually laughed 6 times and...</td>\n",
" <td>was perfect, and I believe it's almost overlo...</td>\n",
" <td>2.548022</td>\n",
" <td>2.601913</td>\n",
" <td>Really, really bad</td>\n",
" <td>with feel like there is no end to the</td>\n",
" <td>. This movie is incredibly good, with the</td>\n",
" <td>-2.639503</td>\n",
" <td>-1.568827</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>This early John Wayne</td>\n",
" <td>films looked like an abandoned police beating</td>\n",
" <td>film is a realistic portrayal of what</td>\n",
" <td>-1.742279</td>\n",
" <td>2.609762</td>\n",
" <td>What another reviewer called lack of</td>\n",
" <td>judgment, connecting into her own harsh obser...</td>\n",
" <td>suspense. Rogers and Rooney rate this as exce...</td>\n",
" <td>-1.079707</td>\n",
" <td>2.696888</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>I was</td>\n",
" <td>given an experience-a big one, almost 25</td>\n",
" <td>very happy with all the reflections and this ...</td>\n",
" <td>2.250709</td>\n",
" <td>2.558540</td>\n",
" <td>This is simply one</td>\n",
" <td>more problem of Steve</td>\n",
" <td>of the best choice</td>\n",
" <td>-1.445436</td>\n",
" <td>2.662699</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>Embarrassingly, I</td>\n",
" <td>am more at a strict conformity after getting ...</td>\n",
" <td>had never seen a movie before. There was one ...</td>\n",
" <td>-2.021666</td>\n",
" <td>-1.803383</td>\n",
" <td>\"Perhaps we can arrange a meet</td>\n",
" <td>-and-greet.&lt;br /&gt;&lt;br /&gt;Teleg</td>\n",
" <td>with spent, classic music and dance, and come...</td>\n",
" <td>0.258479</td>\n",
" <td>1.876662</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>I am a fan</td>\n",
" <td>of living on simple islands, and we have visi...</td>\n",
" <td>of many things and learned how to appreciate ...</td>\n",
" <td>1.791297</td>\n",
" <td>2.324461</td>\n",
" <td>Richard Willaims is</td>\n",
" <td>nice enough; the little black guy plays quite</td>\n",
" <td>beautifully hands on in his own spin, and</td>\n",
" <td>0.796508</td>\n",
" <td>2.820259</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
@ -602,76 +600,76 @@
],
"text/plain": [
" query \\\n",
"0 Oh dear, \n",
"1 I've seen \n",
"2 Hi:<br /><br \n",
"3 I'm a writer \n",
"4 If you \n",
"5 OMG this \n",
"6 It's \n",
"7 There is a really \n",
"8 This is \n",
"9 Wasn't \n",
"10 This film is another of director Tim \n",
"11 I thought this movie \n",
"12 This early John Wayne \n",
"13 I was \n",
"14 Embarrassingly, I \n",
"15 I am a fan \n",
"0 I rented Zero Day \n",
"1 The only \n",
"2 I've read a few \n",
"3 This is the second British Rank film \n",
"4 A classic \n",
"5 This has to be one of the \n",
"6 Happy Go Lovely is a waste \n",
"7 Wow, I just \n",
"8 This movie makes several mistakes. \n",
"9 Branagh and Fish \n",
"10 I might have given this movie a \n",
"11 Really, really bad \n",
"12 What another reviewer called lack of \n",
"13 This is simply one \n",
"14 \"Perhaps we can arrange a meet \n",
"15 Richard Willaims is \n",
"\n",
" response (before) \\\n",
"0 what are I saying?! I fast-forwarded through \n",
"1 it, as well.<br \n",
"2 />This movie is a turkey though when it comes to \n",
"3 and I'm not going to be asked to \n",
"4 absolutely love sensitive romance, the plot a... \n",
"5 casting cast. Obi cult breezy, this is \n",
"6 unrealistic; the guy who was supposed to be E... \n",
"7 awful laptop game!<br /><br />I used to \n",
"8 my favorite part about \n",
"9 Wasn't it clichéd?<|endoftext|> \n",
"10 Burton's masterpieces \n",
"11 was excellent. I actually laughed 6 times and... \n",
"12 films looked like an abandoned police beating \n",
"13 given an experience-a big one, almost 25 \n",
"14 am more at a strict conformity after getting ... \n",
"15 of living on simple islands, and we have visi... \n",
"0 4 for my sister. To my surprise, the Wii caug... \n",
"1 distro of her \n",
"2 news reports about Mr. Mueller's activities b... \n",
"3 , and I wouldn't be surprised anymore if it \n",
"4 classic.<br /><br />And only this one will ha... \n",
"5 worst with the differences being that for the \n",
"6 . Not only are extremely \n",
"7 can't make fun of it \n",
"8 Despite being a great comedic diversion it es... \n",
"9 burne, Drake is played \n",
"10 rating of *11 when I heard that!), but it was... \n",
"11 with feel like there is no end to the \n",
"12 judgment, connecting into her own harsh obser... \n",
"13 more problem of Steve \n",
"14 -and-greet.<br /><br />Teleg \n",
"15 nice enough; the little black guy plays quite \n",
"\n",
" response (after) rewards (before) \\\n",
"0 I must say that I are hanging my head on this -0.858954 \n",
"1 three million dialogue throughout, and 1.996807 \n",
"2 />I also like that movie. It's so funny -0.438191 \n",
"3 , not a screenwriter. I've written -0.655991 \n",
"4 are looking at the cinematography, the acting, 2.221309 \n",
"5 movie was totally wonderful, I it was the ide... -1.533139 \n",
"6 a very good film. It reminds us about over -2.097017 \n",
"7 interesting story that set us the journey. Th... -2.341743 \n",
"8 a well thought well 2.554794 \n",
"9 anyone else interested in this movie? It's a ... -1.790802 \n",
"10 Curry's best bombs 2.622917 \n",
"11 was perfect, and I believe it's almost overlo... 2.548022 \n",
"12 film is a realistic portrayal of what -1.742279 \n",
"13 very happy with all the reflections and this ... 2.250709 \n",
"14 had never seen a movie before. There was one ... -2.021666 \n",
"15 of many things and learned how to appreciate ... 1.791297 \n",
"0 . It is a pleasure. It is a huge leap 68 years... 1.736068 \n",
"1 special compliments is the 0.150852 \n",
"2 novels and I never watch this. It has a reall... -1.417962 \n",
"3 that I have enjoyed, achieving it in both the 0.835876 \n",
"4 . It's a movie with a fine cast. As the beginn... 2.113075 \n",
"5 best thriller films I've seen in recent -2.705339 \n",
"6 of time, giving a -2.429504 \n",
"7 feek it! This show -2.201666 \n",
"8 It's cool, wonderful - it held me into a very ... -1.232380 \n",
"9 is a great show. Beautiful 0.776819 \n",
"10 great performance. It was truly a great movie... 0.276380 \n",
"11 . This movie is incredibly good, with the -2.639503 \n",
"12 suspense. Rogers and Rooney rate this as exce... -1.079707 \n",
"13 of the best choice -1.445436 \n",
"14 with spent, classic music and dance, and come... 0.258479 \n",
"15 beautifully hands on in his own spin, and 0.796508 \n",
"\n",
" rewards (after) \n",
"0 -1.007609 \n",
"1 2.240883 \n",
"2 2.415630 \n",
"3 -0.724324 \n",
"4 0.148751 \n",
"5 2.590190 \n",
"6 2.835831 \n",
"7 2.282939 \n",
"8 2.734139 \n",
"9 2.631960 \n",
"10 2.544106 \n",
"11 2.601913 \n",
"12 2.609762 \n",
"13 2.558540 \n",
"14 -1.803383 \n",
"15 2.324461 "
"0 2.423731 \n",
"1 0.190159 \n",
"2 2.831814 \n",
"3 2.205628 \n",
"4 2.739168 \n",
"5 2.730615 \n",
"6 -2.934672 \n",
"7 -0.106085 \n",
"8 2.707638 \n",
"9 2.808996 \n",
"10 2.743328 \n",
"11 -1.568827 \n",
"12 2.696888 \n",
"13 2.662699 \n",
"14 1.876662 \n",
"15 2.820259 "
]
},
"execution_count": null,
@ -685,31 +683,56 @@
"game_data = dict()\n",
"dataset.set_format(\"pandas\")\n",
"df_batch = dataset[:].sample(bs)\n",
"game_data['query'] = df_batch['query'].tolist()\n",
"query_tensors = df_batch['input_ids'].tolist()\n",
"game_data[\"query\"] = df_batch[\"query\"].tolist()\n",
"query_tensors = df_batch[\"input_ids\"].tolist()\n",
"\n",
"response_tensors_ref, response_tensors = [], []\n",
"\n",
"#### get response from gpt2 and gpt2_ref\n",
"for i in range(bs):\n",
" query = torch.tensor(query_tensors[i]).to(device)\n",
"\n",
" gen_len = output_length_sampler()\n",
" output = ref_model.generate(torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device),\n",
" max_new_tokens=gen_len, **gen_kwargs).squeeze()[-gen_len:]\n",
" response_tensors_ref.append(output)\n",
" output = model.generate(torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device),\n",
" max_new_tokens=gen_len, **gen_kwargs).squeeze()[-gen_len:]\n",
" response_tensors.append(output)\n",
" query_response = ref_model.generate(\n",
" query.unsqueeze(0), max_new_tokens=gen_len, **gen_kwargs\n",
" ).squeeze()\n",
" response_len = len(query_response) - len(query)\n",
" response_tensors_ref.append(query_response[-response_len:])\n",
"\n",
" query_response = model.generate(\n",
" query.unsqueeze(0), max_new_tokens=gen_len, **gen_kwargs\n",
" ).squeeze()\n",
" response_len = len(query_response) - len(query)\n",
" response_tensors.append(query_response[-response_len:])\n",
"\n",
"#### decode responses\n",
"game_data['response (before)'] = [tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]\n",
"game_data['response (after)'] = [tokenizer.decode(response_tensors[i]) for i in range(bs)]\n",
"game_data[\"response (before)\"] = [\n",
" tokenizer.decode(response_tensors_ref[i]) for i in range(bs)\n",
"]\n",
"game_data[\"response (after)\"] = [\n",
" tokenizer.decode(response_tensors[i]) for i in range(bs)\n",
"]\n",
"\n",
"#### sentiment analysis of query/response pairs before/after\n",
"texts = [q + r for q,r in zip(game_data['query'], game_data['response (before)'])]\n",
"game_data['rewards (before)'] = [output[1][\"score\"] for output in sentiment_pipe(texts, **sent_kwargs)]\n",
"texts = [q + r for q, r in zip(game_data[\"query\"], game_data[\"response (before)\"])]\n",
"pipe_outputs = sentiment_pipe(texts, **sent_kwargs)\n",
"positive_scores = [\n",
" item[\"score\"]\n",
" for output in pipe_outputs\n",
" for item in output\n",
" if item[\"label\"] == \"POSITIVE\"\n",
"]\n",
"game_data[\"rewards (before)\"] = positive_scores\n",
"\n",
"texts = [q + r for q,r in zip(game_data['query'], game_data['response (after)'])]\n",
"game_data['rewards (after)'] = [output[1][\"score\"] for output in sentiment_pipe(texts, **sent_kwargs)]\n",
"texts = [q + r for q, r in zip(game_data[\"query\"], game_data[\"response (after)\"])]\n",
"pipe_outputs = sentiment_pipe(texts, **sent_kwargs)\n",
"positive_scores = [\n",
" item[\"score\"]\n",
" for output in pipe_outputs\n",
" for item in output\n",
" if item[\"label\"] == \"POSITIVE\"\n",
"]\n",
"game_data[\"rewards (after)\"] = positive_scores\n",
"\n",
"# store results in a dataframe\n",
"df_results = pd.DataFrame(game_data)\n",
@ -738,8 +761,8 @@
{
"data": {
"text/plain": [
"rewards (before) 0.156629\n",
"rewards (after) 1.686487\n",
"rewards (before) -0.512965\n",
"rewards (after) 1.676750\n",
"dtype: float64"
]
},
@ -757,8 +780,8 @@
{
"data": {
"text/plain": [
"rewards (before) -0.547091\n",
"rewards (after) 2.479868\n",
"rewards (before) -0.464427\n",
"rewards (after) 2.679794\n",
"dtype: float64"
]
},
@ -767,10 +790,10 @@
}
],
"source": [
"print('mean:')\n",
"print(\"mean:\")\n",
"display(df_results[[\"rewards (before)\", \"rewards (after)\"]].mean())\n",
"print()\n",
"print('median:')\n",
"print(\"median:\")\n",
"display(df_results[[\"rewards (before)\", \"rewards (after)\"]].median())"
]
},
@ -787,45 +810,6 @@
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/leandro/miniconda3/envs/trl/lib/python3.9/site-packages/huggingface_hub/hf_api.py:1001: FutureWarning: `create_repo` now takes `token` as an optional positional argument. Be sure to adapt your code!\n",
" warnings.warn(\n",
"Cloning https://huggingface.co/lvwerra/gpt2-imdb-pos-v2 into local empty directory.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a953a6d0c465432bbc39aca826d37aaf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Upload file pytorch_model.bin: 0%| | 32.0k/487M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"remote: Enforcing permissions... \n",
"remote: Allowed refs: all \n",
"To https://huggingface.co/lvwerra/gpt2-imdb-pos-v2\n",
" 369b075..28b9865 main -> main\n",
"\n",
"remote: Enforcing permissions... \n",
"remote: Allowed refs: all \n",
"To https://huggingface.co/lvwerra/gpt2-imdb-pos-v2\n",
" 28b9865..42792ea main -> main\n",
"\n"
]
},
{
"data": {
"text/plain": [
@ -843,16 +827,9 @@
}
],
"source": [
"model.save_pretrained('gpt2-imdb-pos-v2', push_to_hub=True)\n",
"tokenizer.save_pretrained('gpt2-imdb-pos-v2', push_to_hub=True)"
"model.save_pretrained(\"gpt2-imdb-pos-v2\", push_to_hub=True)\n",
"tokenizer.save_pretrained(\"gpt2-imdb-pos-v2\", push_to_hub=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@ -871,7 +848,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12 (main, Mar 26 2022, 15:51:15) \n[Clang 13.1.6 (clang-1316.0.21.2)]"
"version": "3.11.9"
},
"vscode": {
"interpreter": {

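One substantive change in the diff above: with `top_k=None` the sentiment pipeline returns all labels ordered by score rather than in a fixed label order (as the updated example outputs show), so the POSITIVE reward is now selected by label instead of by position. A minimal sketch of that extraction, assuming `pipe_outputs` has the pipeline's usual shape of one list of label/score dicts per input text:

```python
import torch


def positive_rewards(pipe_outputs):
    """Pick the POSITIVE score for each text, regardless of output ordering."""
    rewards = []
    for output in pipe_outputs:
        score = next(item["score"] for item in output if item["label"] == "POSITIVE")
        rewards.append(torch.tensor(score))
    return rewards
```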
View File

@ -0,0 +1,7 @@
# Research projects that use TRL
Welcome to the research projects folder! Here you can find the scripts used for research projects that use TRL and are maintained by the developers and the community (LM detoxification, Stack-Llama, etc.). Check out the READMEs in the subfolders for more information!
- [Detoxifying language models](https://github.com/huggingface/trl/tree/main/examples/research_projects/toxicity)
- [Stack-Llama](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama)
- [Stack-Llama-2](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama_2)

View File

@ -0,0 +1,18 @@
# RLHF pipeline for the creation of StackLLaMa: a Stack Exchange llama-7b model.
There were three main steps to the training process:
1. Supervised fine-tuning of the base llama-7b model to create llama-7b-se:
- `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/supervised_finetuning.py --model_path=<LLAMA_MODEL_PATH> --streaming --learning_rate 1e-5 --max_steps 5000 --output_dir ./llama-se`
2. Reward modeling using dialog pairs from the SE dataset using the llama-7b-se to create llama-7b-se-rm:
- `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/reward_modeling.py --model_name=<LLAMA_SE_MODEL>`
3. RL fine-tuning of llama-7b-se with the llama-7b-se-rm reward model:
- `accelerate launch --multi_gpu --num_machines 1 --num_processes 8 examples/research_projects/stack_llama/scripts/rl_training.py --log_with=wandb --model_name=<LLAMA_SE_MODEL> --reward_model_name=<LLAMA_SE_RM_MODEL> --adafactor=False --tokenizer_name=<LLAMA_TOKENIZER> --save_freq=100 --output_max_length=128 --batch_size=8 --gradient_accumulation_steps=8 --batched_gen=True --ppo_epochs=4 --seed=0 --learning_rate=1.4e-5 --early_stopping=True --output_dir=llama-se-rl-finetune-128-8-8-1.4e-5_adam`
LoRA layers were used at all stages to reduce memory requirements.
At each stage, the PEFT adapter layers were merged with the base model using:
```shell
python examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --adapter_model_name=XXX --base_model_name=YYY --output_name=ZZZ
```
Note that this script requires `peft>=0.3.0`.
For access to the base llama-7b model, please see Meta's [release](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) and [request form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform).

View File

@ -0,0 +1,62 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser


@dataclass
class ScriptArguments:
    """
    The input names representing the Adapter and Base model fine-tuned with PEFT, and the output name representing the
    merged model.
    """

    adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"})
    base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"})
    output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"})


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

assert script_args.adapter_model_name is not None, "please provide the name of the Adapter you would like to merge"
assert script_args.base_model_name is not None, "please provide the name of the Base model"
assert script_args.output_name is not None, "please provide the output name of the merged model"

peft_config = PeftConfig.from_pretrained(script_args.adapter_model_name)
if peft_config.task_type == "SEQ_CLS":
    # The sequence classification task is used for the reward model in PPO
    model = AutoModelForSequenceClassification.from_pretrained(
        script_args.base_model_name, num_labels=1, torch_dtype=torch.bfloat16
    )
else:
    model = AutoModelForCausalLM.from_pretrained(
        script_args.base_model_name, return_dict=True, torch_dtype=torch.bfloat16
    )

tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name)

# Load the PEFT model
model = PeftModel.from_pretrained(model, script_args.adapter_model_name)
model.eval()
model = model.merge_and_unload()

model.save_pretrained(f"{script_args.output_name}")
tokenizer.save_pretrained(f"{script_args.output_name}")
model.push_to_hub(f"{script_args.output_name}", use_temp_dir=False)
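
After `merge_and_unload()` and `save_pretrained`, the output directory contains a plain transformers checkpoint, so it can be loaded without PEFT installed. A small usage sketch for a causal-LM adapter (the path is a placeholder for whatever `--output_name` was passed; a reward model merged from a SEQ_CLS adapter would be loaded with `AutoModelForSequenceClassification` instead):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

merged_dir = "llama-se-rl-merged"  # hypothetical --output_name used above
model = AutoModelForCausalLM.from_pretrained(merged_dir)
tokenizer = AutoTokenizer.from_pretrained(merged_dir)
```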

View File

@ -0,0 +1,324 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any, Optional, Union
import evaluate
import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
HfArgumentParser,
PreTrainedTokenizerBase,
Trainer,
TrainerCallback,
TrainingArguments,
set_seed,
)
from transformers.utils import PaddingStrategy
# Define and parse arguments.
@dataclass
class ScriptArguments:
"""
These arguments vary depending on how many GPUs you have, what their capacity and features are, and what size model you want to train.
"""
local_rank: Optional[int] = field(default=-1, metadata={"help": "Used for multi-gpu"})
resume_from_checkpoint: Optional[bool] = field(
default=False,
metadata={"help": "If you want to resume training where it left off."},
)
deepspeed: Optional[str] = field(
default=None,
metadata={
"help": "Path to deepspeed config if using deepspeed. You may need this if the model that you want to train doesn't fit on a single GPU."
},
)
per_device_train_batch_size: Optional[int] = field(default=4)
per_device_eval_batch_size: Optional[int] = field(default=1)
gradient_accumulation_steps: Optional[int] = field(default=1)
learning_rate: Optional[float] = field(default=2e-5)
weight_decay: Optional[float] = field(default=0.001)
model_name: Optional[str] = field(
default="gpt2",
metadata={
"help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc."
},
)
tokenizer_name: Optional[str] = field(
default=None,
metadata={
"help": "The tokenizer for your model, if left empty will use the default for your model",
},
)
bf16: Optional[bool] = field(
default=True,
metadata={
"help": "This essentially cuts the training time in half if you want to sacrifice a little precision and have a supported GPU."
},
)
num_train_epochs: Optional[int] = field(
default=1,
metadata={"help": "The number of training epochs for the reward model."},
)
train_subset: Optional[int] = field(
default=100000,
metadata={"help": "The size of the subset of the training data to use"},
)
eval_subset: Optional[int] = field(
default=50000,
metadata={"help": "The size of the subset of the eval data to use"},
)
gradient_checkpointing: Optional[bool] = field(
default=False,
metadata={"help": "Enables gradient checkpointing."},
)
optim: Optional[str] = field(
default="adamw_hf",
metadata={"help": "The optimizer to use."},
)
lr_scheduler_type: Optional[str] = field(
default="linear",
metadata={"help": "The lr scheduler"},
)
max_length: Optional[int] = field(default=512)
eval_first_step: Optional[bool] = field(
default=False,
metadata={"help": "Whether to run eval after the first step"},
)
seed: Optional[int] = field(
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
)
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
set_seed(script_args.seed)
# Load the human stack-exchange-paired dataset for tuning the reward model.
train_dataset = load_dataset(
"lvwerra/stack-exchange-paired", data_dir="data/reward", split="train", verification_mode="no_checks"
)
if script_args.train_subset > 0:
train_dataset = train_dataset.select(range(script_args.train_subset))
eval_dataset = load_dataset(
"lvwerra/stack-exchange-paired", data_dir="data/evaluation", split="train", verification_mode="no_checks"
)
if script_args.eval_subset > 0:
eval_dataset = eval_dataset.select(range(script_args.eval_subset))
# Define the training args. Needs to be done before the model is loaded if you are using deepspeed.
model_name_split = script_args.model_name.split("/")[-1]
output_name = (
f"{model_name_split}_peft_stack-exchange-paired_rmts__{script_args.train_subset}_{script_args.learning_rate}"
)
training_args = TrainingArguments(
output_dir=output_name,
learning_rate=script_args.learning_rate,
per_device_train_batch_size=script_args.per_device_train_batch_size,
per_device_eval_batch_size=script_args.per_device_eval_batch_size,
num_train_epochs=script_args.num_train_epochs,
weight_decay=script_args.weight_decay,
eval_strategy="steps",
eval_steps=500,
save_strategy="steps",
save_steps=500,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
gradient_checkpointing=script_args.gradient_checkpointing,
deepspeed=script_args.deepspeed,
local_rank=script_args.local_rank,
remove_unused_columns=False,
label_names=[],
bf16=script_args.bf16,
logging_strategy="steps",
logging_steps=10,
optim=script_args.optim,
lr_scheduler_type=script_args.lr_scheduler_type,
seed=script_args.seed,
)
# Load the value-head model and tokenizer.
tokenizer_name = script_args.tokenizer_name if script_args.tokenizer_name is not None else script_args.model_name
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True)
tokenizer.pad_token = tokenizer.eos_token
peft_config = LoraConfig(
task_type=TaskType.SEQ_CLS,
inference_mode=False,
r=8,
lora_alpha=32,
lora_dropout=0.1,
)
model = AutoModelForSequenceClassification.from_pretrained(
script_args.model_name, num_labels=1, torch_dtype=torch.bfloat16
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
# Need to do this for gpt2, because it doesn't have an official pad token.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
model.config.use_cache = not script_args.gradient_checkpointing
num_proc = 24 # Can adjust to be higher if you have more processors.
original_columns = train_dataset.column_names
# Turn the dataset into pairs of post + summaries, where text_j is the preferred question + answer and text_k is the other.
# Then tokenize the dataset.
def preprocess_function(examples):
new_examples = {
"input_ids_j": [],
"attention_mask_j": [],
"input_ids_k": [],
"attention_mask_k": [],
}
for question, response_j, response_k in zip(examples["question"], examples["response_j"], examples["response_k"]):
tokenized_j = tokenizer("Question: " + question + "\n\nAnswer: " + response_j, truncation=True)
tokenized_k = tokenizer("Question: " + question + "\n\nAnswer: " + response_k, truncation=True)
new_examples["input_ids_j"].append(tokenized_j["input_ids"])
new_examples["attention_mask_j"].append(tokenized_j["attention_mask"])
new_examples["input_ids_k"].append(tokenized_k["input_ids"])
new_examples["attention_mask_k"].append(tokenized_k["attention_mask"])
return new_examples
# preprocess the dataset and filter out QAs that are longer than script_args.max_length
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
num_proc=num_proc,
remove_columns=original_columns,
)
train_dataset = train_dataset.filter(
lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length,
num_proc=num_proc,
)
eval_dataset = eval_dataset.map(
preprocess_function,
batched=True,
num_proc=num_proc,
remove_columns=original_columns,
)
eval_dataset = eval_dataset.filter(
lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length,
num_proc=num_proc,
)
# We need to define a special data collator that batches the data in our j vs k format.
@dataclass
class RewardDataCollatorWithPadding:
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
features_j = []
features_k = []
for feature in features:
features_j.append(
{
"input_ids": feature["input_ids_j"],
"attention_mask": feature["attention_mask_j"],
}
)
features_k.append(
{
"input_ids": feature["input_ids_k"],
"attention_mask": feature["attention_mask_k"],
}
)
batch_j = self.tokenizer.pad(
features_j,
padding=self.padding,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
batch_k = self.tokenizer.pad(
features_k,
padding=self.padding,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
batch = {
"input_ids_j": batch_j["input_ids"],
"attention_mask_j": batch_j["attention_mask"],
"input_ids_k": batch_k["input_ids"],
"attention_mask_k": batch_k["attention_mask"],
"return_loss": True,
}
return batch
# Define the metric that we'll use for validation.
accuracy = evaluate.load("accuracy")
def compute_metrics(eval_pred):
predictions, _ = eval_pred
# Here, predictions is rewards_j and rewards_k.
# We want to see how much of the time rewards_j > rewards_k.
predictions = np.argmax(predictions, axis=0)
labels = np.zeros(predictions.shape)
return accuracy.compute(predictions=predictions, references=labels)
class RewardTrainer(Trainer):
# Define how to compute the reward loss. We use the InstructGPT pairwise logloss: https://huggingface.co/papers/2203.02155
def compute_loss(self, model, inputs, return_outputs=False):
rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0]
rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0]
loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean()
if return_outputs:
return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k}
return loss
# Train the model, woohoo.
trainer = RewardTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
compute_metrics=compute_metrics,
data_collator=RewardDataCollatorWithPadding(tokenizer=tokenizer),
)
if script_args.eval_first_step:
class EvaluateFirstStepCallback(TrainerCallback):
def on_step_end(self, args, state, control, **kwargs):
if state.global_step == 1:
control.should_evaluate = True
trainer.add_callback(EvaluateFirstStepCallback())
trainer.train(script_args.resume_from_checkpoint)
print("Saving last checkpoint of the model")
model.save_pretrained(output_name + "_peft_last_checkpoint")
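
The `compute_loss` above is the InstructGPT-style pairwise loss referenced in the comment: writing r_theta(x, y) for the scalar reward the model assigns to question x with answer y, where y_j is the preferred answer and y_k the rejected one, the objective is

```latex
\mathcal{L}(\theta) = -\,\mathbb{E}_{(x,\, y_j,\, y_k)}\Bigl[\log \sigma\bigl(r_\theta(x, y_j) - r_\theta(x, y_k)\bigr)\Bigr]
```

which is exactly `-logsigmoid(rewards_j - rewards_k).mean()` over the batch.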

View File

@ -0,0 +1,268 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import Adafactor, AutoTokenizer, HfArgumentParser, pipeline
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
from trl.core import LengthSampler
tqdm.pandas()
@dataclass
class ScriptArguments:
"""
The name of the Causal LM model we wish to fine-tune with PPO
"""
# NOTE: gpt2 models use Conv1D instead of Linear layers, which are not yet supported in 8-bit mode;
# models like gpt-neo* are more suitable.
model_name: Optional[str] = field(default="", metadata={"help": "the model name"})
tokenizer_name: Optional[str] = field(default="", metadata={"help": "the tokenizer name"})
reward_model_name: Optional[str] = field(default="", metadata={"help": "the reward model name"})
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
output_max_length: Optional[int] = field(default=128, metadata={"help": "maximum length for generation"})
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
ppo_epochs: Optional[int] = field(default=4, metadata={"help": "the number of ppo epochs"})
gradient_accumulation_steps: Optional[int] = field(
default=4, metadata={"help": "the number of gradient accumulation steps"}
)
adafactor: Optional[bool] = field(default=False, metadata={"help": "whether to use the adafactor optimizer"})
early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"})
target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"})
reward_baseline: Optional[float] = field(
default=0.0,
metadata={"help": "a baseline value that is subtracted from the reward"},
)
batched_gen: Optional[bool] = field(default=False, metadata={"help": "whether to use batched text generation"})
save_freq: Optional[int] = field(default=None, metadata={"help": "save the model every n steps"})
output_dir: Optional[str] = field(default="runs/", metadata={"help": "the output directory"})
seed: Optional[int] = field(default=0, metadata={"help": "the seed"})
steps: Optional[int] = field(default=20000, metadata={"help": "the total number of training steps"})
init_kl_coef: Optional[float] = field(
default=0.2,
metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},
)
adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
load_in_8bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 8bit"})
parser = HfArgumentParser(ScriptArguments)
script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0]
reward_model_name = script_args.reward_model_name
dataset_name = "lvwerra/stack-exchange-paired"
config = PPOConfig(
steps=script_args.steps,
model_name=script_args.model_name,
learning_rate=script_args.learning_rate,
log_with=script_args.log_with,
batch_size=script_args.batch_size,
mini_batch_size=script_args.mini_batch_size,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
optimize_cuda_cache=True,
early_stopping=script_args.early_stopping,
target_kl=script_args.target_kl,
ppo_epochs=script_args.ppo_epochs,
seed=script_args.seed,
init_kl_coef=script_args.init_kl_coef,
adap_kl_ctrl=script_args.adap_kl_ctrl,
)
train_dataset = load_dataset(
"lvwerra/stack-exchange-paired", data_dir="data/rl", split="train", verification_mode="no_checks"
)
train_dataset = train_dataset.select(range(100000))
original_columns = train_dataset.column_names
# We then define the arguments to pass to the sentiment-analysis pipeline that wraps our reward model.
# We set `return_all_scores` to True to get a score for every class, and `function_to_apply` to "none" to keep the raw reward logits.
sent_kwargs = {
"return_all_scores": True,
"function_to_apply": "none",
"batch_size": 16,
"truncation": True,
}
tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name)
# Tokenizers such as GPT-2's do not define a pad token by default.
# If the tokenizer has no pad token, fall back to using the eos_token for padding.
if getattr(tokenizer, "pad_token", None) is None:
tokenizer.pad_token = tokenizer.eos_token
# Below is an example function to build the dataset. In our case, we use the
# `lvwerra/stack-exchange-paired` dataset from the Hub. One should customize this
# function to train the model on their own dataset.
def build_dataset(
tokenizer,
dataset_name="lvwerra/stack-exchange-paired",
):
"""
Build the dataset for training. This loads the dataset with `load_dataset`; one should
customize this function to train the model on their own dataset.
Args:
dataset_name (`str`):
The name of the dataset to be loaded.
Returns:
ds (`datasets.Dataset`):
The preprocessed dataset, with its format set to torch tensors.
"""
num_proc = 24
def preprocess_function(examples):
new_examples = {
"query": [],
"input_ids": [],
}
for question in examples["question"]:
query = "Question: " + question + "\n\nAnswer: "
tokenized_question = tokenizer(query, truncation=True)
new_examples["query"].append(query)
new_examples["input_ids"].append(tokenized_question["input_ids"])
return new_examples
ds = train_dataset.map(
preprocess_function,
batched=True,
num_proc=num_proc,
remove_columns=original_columns,
)
ds = ds.filter(lambda x: len(x["input_ids"]) < 512, batched=False, num_proc=num_proc)
ds.set_format(type="torch")
return ds
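# For example, a raw question "How do I sort a list in Python?" becomes the query
# "Question: How do I sort a list in Python?\n\nAnswer: " together with its token ids;
# prompts of 512 tokens or more are then dropped by the filter above.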
# We retrieve the dataset by calling the `build_dataset` function.
dataset = build_dataset(tokenizer)
def collator(data):
return {key: [d[key] for d in data] for key in data[0]}
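# e.g. collator([{"query": "q1", "input_ids": t1}, {"query": "q2", "input_ids": t2}])
# returns {"query": ["q1", "q2"], "input_ids": [t1, t2]}: a list of per-example
# dicts is transposed into a dict of per-key lists, which is the batch format used below.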
# set seed before initializing value head for deterministic eval
set_seed(config.seed)
# Now let's build the model, the reference model, and the tokenizer.
current_device = Accelerator().local_process_index
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
config.model_name,
load_in_8bit=script_args.load_in_8bit,
device_map={"": current_device},
peft_config=lora_config,
)
optimizer = None
if script_args.adafactor:
optimizer = Adafactor(
filter(lambda p: p.requires_grad, model.parameters()),
scale_parameter=False,
relative_step=False,
warmup_init=False,
lr=config.learning_rate,
)
# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
ppo_trainer = PPOTrainer(
config,
model,
ref_model=None,
tokenizer=tokenizer,
dataset=dataset,
data_collator=collator,
optimizer=optimizer,
)
# We then build the sentiment analysis pipeline using our reward model, passing the
# model name and the sentiment analysis pipeline arguments. Let's also make sure to
# set the device to the same device as the PPOTrainer.
device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
sentiment_pipe = pipeline(
"sentiment-analysis",
model=reward_model_name,
device_map={"": current_device},
model_kwargs={"load_in_8bit": script_args.load_in_8bit},
tokenizer=tokenizer,
return_token_type_ids=False,
)
if sentiment_pipe.model.config.pad_token_id is None:
sentiment_pipe.model.config.pad_token_id = sentiment_pipe.model.config.eos_token_id
# We then define the arguments to pass to the `generate` function. These arguments
# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
# the `generate` function of the trained model.
generation_kwargs = {
# "min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": tokenizer.pad_token_id,
"eos_token_id": 100_000,
}
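# NOTE: `eos_token_id` is set to an id outside the model's vocabulary, which in
# effect disables early stopping on the real EOS token, so each response runs
# until the length drawn from `output_length_sampler` below.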
output_min_length = 32
output_max_length = script_args.output_max_length
output_length_sampler = LengthSampler(output_min_length, output_max_length)
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
if epoch >= config.total_ppo_epochs:
break
question_tensors = batch["input_ids"]
response_tensors = ppo_trainer.generate(
question_tensors,
return_prompt=False,
length_sampler=output_length_sampler,
**generation_kwargs,
)
batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
# Compute reward score (using the sentiment analysis pipeline)
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs]
# Run PPO step
stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
ppo_trainer.log_stats(stats, batch, rewards)
if script_args.save_freq and epoch and epoch % script_args.save_freq == 0:
ppo_trainer.save_pretrained(script_args.output_dir + f"step_{epoch}")

View File

@ -0,0 +1,222 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, logging, set_seed
from trl import SFTTrainer
from trl.trainer import ConstantLengthDataset
"""
Fine-Tune Llama-7b on SE paired dataset
"""
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="")
parser.add_argument("--dataset_name", type=str, default="lvwerra/stack-exchange-paired")
parser.add_argument("--subset", type=str, default="data/finetune")
parser.add_argument("--split", type=str, default="train")
parser.add_argument("--size_valid_set", type=int, default=4000)
parser.add_argument("--streaming", action="store_true")
parser.add_argument("--shuffle_buffer", type=int, default=5000)
parser.add_argument("--seq_length", type=int, default=1024)
parser.add_argument("--max_steps", type=int, default=10000)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--eos_token_id", type=int, default=49152)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
parser.add_argument("--num_warmup_steps", type=int, default=100)
parser.add_argument("--weight_decay", type=float, default=0.05)
parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument("--fp16", action="store_true", default=False)
parser.add_argument("--bf16", action="store_true", default=False)
parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--num_workers", type=int, default=None)
parser.add_argument("--output_dir", type=str, default="./checkpoints")
parser.add_argument("--log_freq", default=1, type=int)
parser.add_argument("--eval_freq", default=1000, type=int)
parser.add_argument("--save_freq", default=1000, type=int)
return parser.parse_args()
def chars_token_ratio(dataset, tokenizer, nb_examples=400):
"""
Estimate the average number of characters per token in the dataset.
"""
total_characters, total_tokens = 0, 0
for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
text = prepare_sample_text(example)
total_characters += len(text)
if tokenizer.is_fast:
total_tokens += len(tokenizer(text).tokens())
else:
total_tokens += len(tokenizer.tokenize(text))
return total_characters / total_tokens
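# Illustrative example: if the 400 sampled examples contain 1,400,000 characters that
# tokenize into 400,000 tokens, the ratio is 3.5; ConstantLengthDataset uses this value
# to estimate how many characters to buffer when packing sequences of `seq_length` tokens.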
def print_trainable_parameters(model):
"""
Prints the number of trainable parameters in the model.
"""
trainable_params = 0
all_param = 0
for _, param in model.named_parameters():
all_param += param.numel()
if param.requires_grad:
trainable_params += param.numel()
print(
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
)
def prepare_sample_text(example):
"""Prepare the text from a sample of the dataset."""
text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}"
return text
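# Example of the formatted text (illustrative content):
#   "Question: How do I merge two dicts in Python?\n\nAnswer: Use the ** unpacking syntax: {**a, **b}."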
def create_datasets(tokenizer, args):
dataset = load_dataset(
args.dataset_name,
data_dir=args.subset,
split=args.split,
use_auth_token=True,
num_proc=args.num_workers if not args.streaming else None,
streaming=args.streaming,
)
if args.streaming:
print("Loading the dataset in streaming mode")
valid_data = dataset.take(args.size_valid_set)
train_data = dataset.skip(args.size_valid_set)
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
else:
dataset = dataset.train_test_split(test_size=0.005, seed=args.seed)
train_data = dataset["train"]
valid_data = dataset["test"]
print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
chars_per_token = chars_token_ratio(train_data, tokenizer)
print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
train_dataset = ConstantLengthDataset(
tokenizer,
train_data,
formatting_func=prepare_sample_text,
infinite=True,
seq_length=args.seq_length,
chars_per_token=chars_per_token,
)
valid_dataset = ConstantLengthDataset(
tokenizer,
valid_data,
formatting_func=prepare_sample_text,
infinite=False,
seq_length=args.seq_length,
chars_per_token=chars_per_token,
)
return train_dataset, valid_dataset
def run_training(args, train_data, val_data):
print("Loading the model")
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
train_data.start_iteration = 0
print("Starting main loop")
training_args = TrainingArguments(
output_dir=args.output_dir,
dataloader_drop_last=True,
eval_strategy="steps",
max_steps=args.max_steps,
eval_steps=args.eval_freq,
save_steps=args.save_freq,
logging_steps=args.log_freq,
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.batch_size,
learning_rate=args.learning_rate,
lr_scheduler_type=args.lr_scheduler_type,
warmup_steps=args.num_warmup_steps,
gradient_accumulation_steps=args.gradient_accumulation_steps,
gradient_checkpointing=args.gradient_checkpointing,
fp16=args.fp16,
bf16=args.bf16,
weight_decay=args.weight_decay,
run_name="llama-7b-finetuned",
report_to="wandb",
ddp_find_unused_parameters=False,
)
model = AutoModelForCausalLM.from_pretrained(
args.model_path, load_in_8bit=True, device_map={"": Accelerator().process_index}
)
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=train_data,
eval_dataset=val_data,
peft_config=lora_config,
packing=True,
)
print_trainable_parameters(trainer.model)
print("Training...")
trainer.train()
print("Saving last checkpoint of the model")
trainer.model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))
def main(args):
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
train_dataset, eval_dataset = create_datasets(tokenizer, args)
run_training(args, train_dataset, eval_dataset)
if __name__ == "__main__":
args = get_args()
assert args.model_path != "", "Please provide the llama model path"
set_seed(args.seed)
os.makedirs(args.output_dir, exist_ok=True)
logging.set_verbosity_error()
main(args)

View File

@ -0,0 +1,76 @@
# DPO pipeline for the creation of StackLlaMa 2: a Stack exchange llama-v2-7b model
## Prerequisites
Install all the dependencies in the `requirements.txt`:
```
$ pip install -U -r requirements.txt
```
Since we will use `accelerate` for training, make sure to run:
```
$ accelerate config
```
## Training
There are two main steps to the DPO training process:
1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se:
```
accelerate launch examples/research_projects/stack_llama_2/scripts/sft_llama2.py \
--output_dir="./sft" \
--max_steps=500 \
--logging_steps=10 \
--save_steps=10 \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=1 \
--gradient_accumulation_steps=2 \
--gradient_checkpointing=False \
--group_by_length=False \
--learning_rate=1e-4 \
--lr_scheduler_type="cosine" \
--warmup_steps=100 \
--weight_decay=0.05 \
--optim="paged_adamw_32bit" \
--bf16=True \
--remove_unused_columns=False \
--run_name="sft_llama2" \
--report_to="wandb"
```
2. Run the DPO trainer using the model saved by the previous step:
```
accelerate launch examples/research_projects/stack_llama_2/scripts/dpo_llama2.py \
--model_name_or_path="sft/final_checkpoint" \
--output_dir="dpo"
```
## Merging the adapters
To merge the adapters into the base model we can use the `merge_peft_adapter.py` helper script that comes with TRL:
```
python examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --base_model_name="meta-llama/Llama-2-7b-hf" --adapter_model_name="dpo/final_checkpoint/" --output_name="stack-llama-2"
```
which will also push the merged model to your Hugging Face Hub account.
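For reference, here is a minimal sketch of what the merge step does under the hood (assuming the adapter and output paths used above; pushing to the Hub is omitted):
```py
import torch
from peft import AutoPeftModelForCausalLM

# Load the base model with the DPO LoRA adapter attached, fold the adapter
# weights into the base weights, and save a standalone merged model.
model = AutoPeftModelForCausalLM.from_pretrained(
    "dpo/final_checkpoint", torch_dtype=torch.float16
)
merged_model = model.merge_and_unload()
merged_model.save_pretrained("stack-llama-2")
```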
## Running the model
We can load the DPO-trained LoRA adapters that were saved by the DPO training step via:
```py
import torch

from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(
"dpo/final_checkpoint",
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
load_in_4bit=True,
)
model.generate(...)
```
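For example, a minimal generation sketch (assuming the base model's tokenizer and the "Question: ...\n\nAnswer: " prompt format used during training):
```py
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
prompt = "Question: How do I merge two dictionaries in Python?\n\nAnswer: "
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128, do_sample=True, top_p=0.9)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```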

Some files were not shown because too many files have changed in this diff.