* Compatibilitywith padding free and iterable dataset
* Fix collator test
* add a test for streaming
* some cleaning
* improve and fix tests
* tiny revert
* bump datasets to 3.0.0
* Added else clause to avoid NameError on optimizer_offload
* Accounted for deepspeed's renaming in 0.16.4
* Switched to packaging.version.parse over the (broken) tuple split
* Moved from NotImplementedError to RuntimeError in else clause
* Add support for additional generation kwargs in GRPO Trainer
- Extend GRPOConfig to support additional generation kwargs
- Update GRPOTrainer to incorporate additional generation parameters
- Add tests for training with additional generation kwargs for both standard and vLLM modes
* Add missing vllm_gpu_memory_utilization=0.5
* 🔧 Refactor GRPO generation parameters and configuration
- Restructure GRPOConfig to separate generation parameters
- Add support for top_p, top_k, min_p, repetition_penalty, and length_penalty
- Remove additional_generation_kwargs in favor of explicit parameters
- Update GRPOTrainer to use new generation parameter configuration
* Update tests
* Remove length_penalty and fix tests
* Update defaults and docs
- Change temperature type from Optional[float] to float
- Set default top_p to 1.0 instead of None
- Simplify parameter descriptions by removing redundant "if set to None" text
- Maintain consistent type hints and default values for generation parameters
* GRPO remove optional type hint for temperature parameter
* Remove length_penalty from sampling_kwargs dict in GRPOTrainer
* some refactoring
* top k None support
* change value of in test to amke them work
---------
Co-authored-by: Robert Veres <robert.veres@languagetool.org>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
* Deprecate liger
* remove import
* oops, shouldn't be here
* Fix other deprecations
* remove liger from gkd for now
* remove liger for teacher
---------
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* Fix logits computation in DPO trainer prediction step
* fix compute_metrics for bco and test
* same for cpo
* same from dpo
* for kto
* anf finally orpo
* Apply style fixes
---------
Co-authored-by: kyungdae-jo <kyungdae.jo@navercorp.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* ✨ Enhance GRPO logging with configurable completions sampling
- Update `GRPOConfig` to replace `log_completions` with `log_completions_steps`
- Add `print_prompt_completions_sample()` utility function for rich console logging
- Modify `GRPOTrainer` to additionally print 5 random prompt-completion pairs every log_completions_steps steps
* GRPO trainer completions logging, move wandb checks together
* Add rich availability check and use fallback in print_prompt_completions_sample when rich is not available
* Update docstrings on print_prompt_completions_sample
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* Revert back to simple log_completions bool
* GRPO log completions fully
* Remove print fallback from print_prompt_completions_sample
* Move accelerator main process check up for grpo log completions
* Explicit variable names in print_prompt_completions_sample
* Make GRPOConfig docstring match field description
* Update log_completions docs again
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* Update GRPOConfig docs to match field
* improve readibility when prompt or completions are multilines
* log reward
* prevent hanging, don't print without rich, print reward
* style
---------
Co-authored-by: Robert Veres <robert.veres@languagetool.org>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
* Add num_updates and epsilon parameters to GRPOConfig and GRPOTrainer
* test sampler
* update the loss computation
* fix eval sampler
* should work now
* buffer inputs with grad accum
* optimize when num_iterations == 1
* test
* minor comment removal and fix log metric
* beta position
* clarify comment [ci skip]
* clarify sampler doc [ci skip]
* fix collision with eval logging
* clarify
* 🔧 Optimize GRPO training by conditionally loading reference model based on beta value
* ✅ Add test for GRPOTrainer with beta=0 to ensure no reference model and KL divergence
* 🔧 Refactor GRPOTrainer code for improved readability and maintainability
* 🔧 Simplify per_token_loss calculation in GRPOTrainer for clarity
* fix test, style, and some struct for clarity
---------
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* return dataset if it's preprocessed
* add is_processed flag variable
* add test
* move test_sft_trainer_directly_with_pretokenized_data to Tester2
* Update sft_trainer.py
* no need for padding and truncation
* minor reorganization
* Update trl/trainer/sft_trainer.py
* let the collator pad
* style
* fix tests
---------
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
* fix check for AutoLigerKernelForCausalLM
* fix case where AutoLigerKernelForCausalLM is not defined
* update min liger version
* formatting
* fix win CI
- Fixed a bug where an extra `len` call inside the error message caused a `TypeError` instead of the expected `ValueError`.
- Replaced `len(len(args.reward_weights))` with the correct `len(args.reward_weights)` to properly calculate the number of reward weights.
- Ensured that a `ValueError` is now raised with an accurate and clear message when the number of reward weights does not match the number of reward functions.
This fix prevents confusion during debugging and ensures proper error handling during validation.
Tested with cases where:
- `args.reward_weights` is None (default case).
- `args.reward_weights` has mismatched lengths with `reward_funcs`.
* added reward weights for multi-reward runs in GRPO
* reward_weights are float, moved from GRPOTrainer to GRPOConfig
* minor comment fix
* minor
* fix test
* missing link
---------
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* add token accuracy metric
* fix return type
* shift tokens
* use compute_loss so that the model is called only once
* add to logs
* log from main process
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* Distribute
* fix some logic errors
* fix and document RepeatRandomSampler
* comment
* doc clarification
* fix type hint
* more readable
* fix eval
* fix tests
* roll back to distribute generation
* improve comment [ci skip]
* fix slice
* catch for eval batch size as well; fix completion_ids in vllm
* log completions
* Revert "log completions"
This reverts commit 1e4af8ffb8dda15d7596e707ac784208db88135a.
* Before the first training step, the model has no optimizer: fix ds3
* properly unwrap torch.compile-ed models with GRPO
* add test and compat with reward models
* ignore test windows
* properly unwrap torch.compile-ed models with GRPO
* add test and compat with reward models
* ignore test windows
* chore: lint
* style
---------
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* [DOCS] add GRPOTrainer to README.md
I replaced RLOOTrainer with GRPOTrainer because you thought you might want to keep it limited, but let me know if you want both.
* Update README.md
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* add eval loss logging during predition
* make sure the train and eval logs aren't mixed
* test grpo in eval
* fix tests
---------
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
* rloo custom reward function and test
* idont even know why i did that
* removing get_reward_custom
* remove get_reward_custom test
* fix code quality check
* adding test
* end this mysery already
* fix test
* initial commit
* doc on custom reward function
* test
* doc doc doc
* fix collator
* style
* links?
* I need a docdoc 🎵
* fix link
* I do like writing doc tbh
* it takes time, but it's worth it
* no return!
* type hint
* it's probably the best of both worlds [ci skip]
* new doc before implementation
* tests
* more doc
* style
* multiple pretrained funcs
* fix arg name
* main?
* example for R1
* fix script
* clearer
* import [ci skip]
* Update docs/source/grpo_trainer.md
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
---------
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* init grpo [ci skip]
* initial version
* refine args defs
* model card
* initial doc
* fix badges
* fix spaces
* try link to super in doc
* temperature, fix indexing, and std=0.0
* grpo script for cli
* peft support
* move data preparation in `compute_loss`
* weird doc trial
* fix device and some logging
* unwrap_model_for_generation for distributed setting
* Compat with distrib training
* revert grpo config doc trial (didn't work)
* test
* allow model to be str and processing_class to be none; fix loss computation
* advantage is always 0.0: don't log
* fix peft not installed
* proper reward model for testing
* fix script for cli
* add trl grpo to cli doc
* test peft
* flush left
* fix reward calculation
* new reward model
* support any reward model
* fix reward processing class def
* log reward std
* fix reward logging
* fix grad computation
* skip embed layer in test
* remove optimizer_cls_and_kwargs
* improve GRPO default args
* reduce mem usage for grpo test
* reduce mem usage in test grpo
* reduce memory usage for test
* Fix the test
* remove redondant
* fix min version
* Update test_grpo_trainer.py
* Update test_grpo_trainer.py
* Fix test, finally found the solution!
* some doc
* Update doc-builder workflow to use specific commit sha
* more doc
* advantages
* drop cancel fo no grad
* logged metrics [ci skip]
* completion col is ignored [ci skip]
* fix latex
* double space? ~?
* try a latex fix
* with branch
* Empty commit
* Empty commit
* double space seems to be the solution
* set default for max_length and max prompt lenngth and add guidelines for defaults
* remove dep kwargs
* truncate prompt in prm
* Update CONTRIBUTING.md [ci skip]
* vllm online dpo
* new arg and add back generation config [skip ci]
* import utils
* optional import and comment
* is_vllm_available
* support conv and not conv [ci skip]
* add old code back
* use func [skip ci]
* fix _generate call
* fix and dedicated func
* top k 50
* style
* add import error
* new testing model
* Update OnlineDPOTrainer class with new features
* test vllm
* fix generate tiny script
* max len arg
* fix comment [ci skip]
* revert num_return_sequences
* vllm dep
* Add require_torch_accelerator import and skip test if vllm is not available
* proper require_torch_accelerator
* add vllm section
* Add hfoption sections to speeding_up_training.md
* no, an id
* Update vllm dependency to exclude Windows platform
* Note on future release
* style
* adding readme for ultrafeedback dataset
* using ModelCard as DatasetsCard like hf datasets is understaffed
* more info in readme.md of the dataset
* generated readme for all dataset scripts
* precommit
* fixing test
* md format; corrections; generation script link
* some collections
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
* padding free
* specify dtype
* test
* warnings when not flash attention
* fix test
* remove
* docstring padding-free
* flash-attn dep
* Stronger warning
* require_flash_attn in test
* flash-attn in CI
* rm flash-attn from dep
* Remove flash-attn dependency from test workflows
* refactor
* Update .github/workflows/tests.yml
* Update trl/trainer/dpo_trainer.py
* drop require flash-attn
* fix dtype
* refine warning
* Update trl/trainer/dpo_config.py
* Add logic to compute mean logits for chosen and rejected tokens with padding-free
* format
* Update trl/trainer/dpo_trainer.py
* Update trl/trainer/dpo_trainer.py
* fix comment [ci skip]
* fix num logits to keep
* Implemented integration with Comet in `LogCompletionsCallback`. Implemented related integration test.
* Implemented integration with Comet in `CPOTrainer.evaluation_loop()` during logging of `game_log` table.
* Implemented integration with Comet in `CPOTrainer.evaluation_loop()` during logging of `game_log` table.
* Implemented integration with Comet in `DPOTrainer.evaluation_loop()` during logging of `game_log` table.
* Implemented integration with Comet in `BCOTrainer.evaluation_loop()` during logging of `game_log` table.
* Implemented integration with Comet in `KTOTrainer.evaluation_loop()` during logging of `game_log` table.
* Implemented integration with Comet in `ORPOTrainer.evaluation_loop()` during logging of `game_log` table.
* Added support for Comet URL integration into model cards created by trainers.
* Moved `get_comet_experiment_url()` into utils.py
* Updated Comet badge in the model card to use PNG image instead of text.
* Fixed bug related to running PPO example during model saving. The error as following: 'GPTNeoXForCausalLM' object has no attribute 'policy'. Introduced guard check that attribute `policy` exists.
* Implemented utility method to handle logging of tabular data to the Comet experiment.
* Implemented logging of the completions table to Comet by `PPOTrainer`.
* Implemented logging of the completions table to Comet by `WinRateCallback`.
* Implemented logging of the completions table to Comet by `RLOOTrainer` and `RewardTrainer`.
* Restored line to the main branch version.
* Moved Comet related utility methods into `trainer/utils.py` to resolve merge conflict with master branch,
* Update trl/trainer/utils.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* Implemented raising of `ModuleNotFoundError` error when logging table to Comet if `comet-ml` is not installed.
* import comet with other imports
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
* initial skeleton
* tokenize fn
* adding bos and eos to tokenization fn
* prmtrainer
* fixing small typo in tokenize
* typo in input_ids and labels construction
* numpy dimension
* introduce the stepwise reward trainer
* update markdown files
* let user decide post step separator in config
* doc post_step_separator
* do not add post step_tokens to last step of the reasoning process
* renaming prm to stepwisereward
* formatting
* fix tokenize kwargs
* adapt test to the new post_token args
* adding example script
* fix small typo
* add create_model_card and renaming
* fixing booleans
* Adding the new stepwise_preference instead of placeholders for datasets
* formatting
* Update docs/source/_toctree.yml
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* Update examples/scripts/stepwise_reward_modeling.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* Update trl/trainer/stepwise_reward_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* Update trl/trainer/stepwise_reward_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* update push to hub
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* step_separator can't be None
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* fix suggested typos
* add citation
* reformat doc
* reordering init
* push to hub prm800k
* changing dataset in example
* change dataset format to align with the sky is blue example
* fix tokenization column names
* fix num labels in openai example
* add support for conversational dataset
* remove training whitespace
* replace tokenizer with processing class
* Update docs/source/dataset_formats.mdx
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* remove openai_prm800k
* Update trl/trainer/stepwise_reward_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* Update trl/trainer/stepwise_reward_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* Update docs/source/stepwise_reward_trainer.mdx
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update docs/source/stepwise_reward_trainer.mdx
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* renaming
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* renaming
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* minor renamings in docs
* using prm800k instead of openai_prm800k
* update num labels to 2 following the new format
* changing doc examples to math examples
* change reference to dataset_formats.mdx
* changing dataset config in test
* remove conversational dataset support
* remove conv dataset support
* fix bos token
* fix scriptarguments in example
* completion to completions
* remove valuerror for step_separator inside steps
* run precommit
* remove conv dataset support
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* renaming zen dataset
* remove unused printing
* unknown label column
* introduce the train on last step arg
* _tokenize support train_on_last_step
* incorporate train_on_last_step to tests
* formatting
* remove comments in trainer
* Refactor `tokenize_row`
* Update max_completion_length parameter in StepwiseRewardConfig
* Collator
* Update comment
* Update type hint
* fix table
* Remove collator
* don't need pad token id
* add error back
* max length args
* use tokenizer arg
* Update doc
* label -> labels
* fixing tokenization issues in tokenize row
* correct labels for token classification
* adding max_length to tokenize_row
* reformat tests
* adding tests for tokenize row
* fixing typos in comments
* update doc
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* Add math_shepherd.py script for dataset processing
* split the dataset
* formatting
* same evaluation method for the two training methods
* adding filtering to example script
* formatting
* Add features to avoid casting labels to bool in dataset tokenization
* Update docs/source/stepwise_reward_trainer.mdx [ci skip]
* Add learning_rate parameter to StepwiseRewardConfig class
* update doc
* Remove unused setup_chat_format function
* Fix warning message in stepwise_reward_modeling.py
* Update logging steps in stepwise_reward_trainer.mdx
* little doc change [ci skip]
* Fix copyrights
* fix space after copyrights
* Update dataset loading in stepwise_reward_modeling.py
* refine compute_accuracy and proper test
* fix tests
* style
* renamings
* renaming in init
* doc renaming
* fix sorting and tag
* experiemental [ci skip]
* trigger CI
* other doc fix
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
* function calling training support for SFTTraining
* adding tool support to data_utils
* adding test for function calling tokenizer
* reverting changes to sfttrainer and config,added maybe_apply_chat_template
* arg for maybe_apply_chat_templates docstring
* Doc sectioning
* minor test modification
* minor doc modification
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
* refactor parser
* Only document some methods
* Update imports in cli_utils.py and remove config option in utils.py
* add `test_parse_args_and_arg_override_config` and remove unnecessary mocks [ci skip]
* fix comment [ci skip]
* fix comment [ci skip]
* Extra arg in config also returned
* fix docstring [ci skip]
* add mock back
* use `deprecate_kwarg`
* Add guidelines for working with warnings in the codebase
* Remove unnecessary warnings and improve code initialization
* Fix warnings and improve accuracy calculation
* Add rich library dependency for text formatting
* Update LoRA weight loading warning message
* Fix logging and import issues in AlignPropConfig
* Fix warnings and improve code readability
* Remove unused import statements
* Refactor CPOTrainer class in cpo_trainer.py
* Remove unnecessary warnings and raise ValueError for missing model
* Fix warnings and improve code consistency
* Update CONTRIBUTING.md to clarify the purpose of warnings
* Fix string formatting in DataCollatorForCompletionOnlyLM class
* Update SimPO loss parameters in CPOTrainer
* Fix warnings and remove unnecessary code in ConstantLengthDataset class
* Clarify warning guidelines
* Rewrite the entire section
* Fix capitalization in CONTRIBUTING.md
* Fix formatting in CONTRIBUTING.md
* Add script_utils.md to the documentation
* Refactor ScriptArguments class documentation
* Refactor TrlParser class to improve code organization and readability
* first commit
* uncomment
* other tests adaptations
* Remove unused variable in test_setup_chat_format
* Remove unused import statement
* style
* Add Bart model
* Update BCOTrainerTester class in test_bco_trainer.py
* Update model IDs and tokenizers in test files
* Add new models and processors
* Update model IDs in test files
* Fix formatting issue in test_dataset_formatting.py
* Refactor dataset formatting in test_dataset_formatting.py
* Fix dataset sequence length in SFTTrainerTester
* Remove tokenizer
* Remove print statement
* Add reward_model_path and sft_model_path to PPO trainer
* Fix tokenizer padding issue
* Add chat template for testing purposes in PaliGemma model
* Update PaliGemma model and chat template
* Increase learning rate to speed up test
* Update model names in run_dpo.sh and run_sft.sh scripts
* Update model and dataset names
* Fix formatting issue in test_dataset_formatting.py
* Fix formatting issue in test_dataset_formatting.py
* Remove unused chat template
* Update model generation script
* additional models
* Update model references in test files
* Remove unused imports in test_online_dpo_trainer.py
* Add is_llm_blender_available import and update reward_tokenizer
* Refactor test_online_dpo_trainer.py: Move skipped test case decorator
* remove models without chat templates
* Update model names in scripts and tests
* Update model_id in test_modeling_value_head.py
* Update model versions in test files
* Fix formatting issue in test_dataset_formatting.py
* Update embedding model ID in BCOTrainerTester
* Update test_online_dpo_trainer.py with reward model changes
* Update expected formatted text in test_dataset_formatting.py
* Add reward_tokenizer to TestOnlineDPOTrainer
* fix tests
* Add SIMPLE_CHAT_TEMPLATE to T5 tokenizer
* Fix dummy_text format in test_rloo_trainer.py
* Skip outdated test for chatML data collator
* Add new vision language models
* Commented out unused model IDs in test_vdpo_trainer
* Update model and vision configurations in generate_tiny_models.py and test_dpo_trainer.py
* Update model and tokenizer references
* Don't push if it already exists
* Add comment explaining test skip
* Fix model_exists function call and add new models
* Update LlavaForConditionalGeneration model and processor
* `qgallouedec` -> `trl-internal-testing`
* Create mergekit_utils.py
* adding mergekit as an optional dependancy
* adding MergeModel to callbacks
* adding mergekit_utils dependencies to callbacks
* setting lower bound for mergekit
* setting mergekit lower band to 0.0.5.1
* adding support for MergeModelCallBack __init__.py
* adding support for mergemodelcallback
* mergemodelcallback tests
* Update callbacks.py
* Update __init__.py
* Update __init__.py
* Update test_callbacks.py
* Update trl/trainer/callbacks.py
removing ## from docs
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update trl/trainer/callbacks.py
removing ## from docs
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update trl/trainer/callbacks.py
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* using different dataset for tests
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update trl/mergekit_utils.py
adding types
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update trl/mergekit_utils.py
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Apply suggestions from code review
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* replacing get_last_checkpoint
* renaming Merge to merge_models
* setting mergers default value to linear
* removing unnecessary docs and comments
* adding docstring to Mergeconfig
* adding mergekits link to docstring
* precommit
* removing duplicated import
* typos in mergekit_utils docstring
* fixing tests
* making mergemodelcallback tests optional
* Make import optional
* minor
* use tmp dir in test
* sort
* Add import error checks for mergekit extra
* use a common _merge_and_maybe_push method and compat with windows path
* debug windows
* Update dependencies for mergekit and add test dependencies
* Add assertion to check if merged folder exists in the last checkpoint
* Fix temporary directory cleanup in test_callbacks.py
* Add sys import and skip test for Python versions below 3.10 due to cleanup errors with temp dir
* revert change for debug
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* Fix "Use specified data_collator instead of hard-coding the option"
* Remove query_responses = [] since it's immediately overwritten afterwards.
* Use self.data_collator
* Use specified data_collator instead of hard-coded one in PPOTrainer
* Move the data_collator creation
* Run make precommit
* Support num_logits_to_keep, which computes necessary logits in the forward pass.
* update doc
* bug fix
* update
* check is model supports num_logits_to_keep
* ruff format
* update test file
* peft model support
* test passed
* update
* apply use_num_logits_to_keep
* fix num_logits_to_keep compute bug
* compare all outputs
* pytest
* pass test
* use check_min_version
* format
* test_dpo_trainer_use_num_logits_to_keep passed
* add some comments
---------
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* Bump dev version to `0.13.0.dev0`
* Update version number to 0.12 in CITATION.cff
* 🧽 Fix judge documentation (#2318)
* Update judge examples and documentation
* without ':'
* Clean doc
* Fix typo in example code
* Add space after Attributes
* Update attribute name in judges.py
* Add installation instructions for llm-blender library
* Update PairRMJudge attributes documentation
* Fix return type in PairRMJudge
* Revert "🧽 Fix judge documentation (#2318)"
This reverts commit 337005d95169371935fb87f1c559c7412f8472a4.
* Revert "🧽 Fix judge documentation (#2318)"
This reverts commit 337005d95169371935fb87f1c559c7412f8472a4.
* 🧽 Fix judge documentation (#2318)
* Update judge examples and documentation
* without ':'
* Clean doc
* Fix typo in example code
* Add space after Attributes
* Update attribute name in judges.py
* Add installation instructions for llm-blender library
* Update PairRMJudge attributes documentation
* Fix return type in PairRMJudge
* Bump dev version to `0.13.0.dev0`
* Update version number to 0.12 in CITATION.cff
* Add publication date to blog post
* 🧽 Fix judge documentation (#2318)
* Update judge examples and documentation
* without ':'
* Clean doc
* Fix typo in example code
* Add space after Attributes
* Update attribute name in judges.py
* Add installation instructions for llm-blender library
* Update PairRMJudge attributes documentation
* Fix return type in PairRMJudge
* Revert "🧽 Fix judge documentation (#2318)"
This reverts commit 337005d95169371935fb87f1c559c7412f8472a4.
* Update blog post publication dates
* revert to p5
* Update image URLs in index.mdx
* Sort and uniform thumbnail
* Update image alignment in index.mdx
* Bump dev version to `0.13.0.dev0`
* Update version number to 0.12 in CITATION.cff
* 🧽 Fix judge documentation (#2318)
* Update judge examples and documentation
* without ':'
* Clean doc
* Fix typo in example code
* Add space after Attributes
* Update attribute name in judges.py
* Add installation instructions for llm-blender library
* Update PairRMJudge attributes documentation
* Fix return type in PairRMJudge
* Revert "🧽 Fix judge documentation (#2318)"
This reverts commit 337005d95169371935fb87f1c559c7412f8472a4.
* Add conditional check for LLMBlender availability in test_judges.py
* Fix import issues and update test requirements
* Remove unused imports
* Add require_peft decorator to test cases
* Fix import_utils module to use correct package name for llm_blender
* Found min version and test
* Update Slack notification titles
* Update dependencies versions
* Update GitHub Actions workflow to include setup.py and reorder file paths
* Revert "Update Slack notification titles"
This reverts commit be02a7f2de87905e86a847540770968d0416934a.
* Update Slack notification titles
* Remove pull_request branch restriction in tests.yml
* add check code quality back
* Fix PairRMJudge model loading issue
* Add conditional check for LLMBlender availability in test_judges.py
* Fix import issues and update test requirements
* Remove unused imports
* Add require_peft decorator to test cases
* Fix import_utils module to use correct package name for llm_blender
* Update trainer_utils import and save strategy in online_dpo_trainer.py
* fix back-compat for online-dpo
* better comment
* Update transformers dependency to commit f33904
* clean deps
* new tests
* tests
* Add tests without optional dependencies workflow
* Update dependencies in tests.yml
* cpu version of torch
* Update dependencies and installation commands
* Disable fail-fast in test workflow
* Update test matrix in workflows file
* try fix windows
* Remove "rich" from required packages in setup.py
* Update dependency installation in tests.yml
* Add torch and deepspeed installation for windows-latest
* Fix conditional statement in workflow file
* Add torch and deepspeed installation for Windows
* Fix if statement
* Update torch and deepspeed dependencies
* Update liger package requirement for non-Windows platforms
* remove scipy dep
* Add torch GPU requirement for testing_utils
* Update trl/trainer/judges.py
* Refactor reward processing in OnlineDPOTrainer
* Refactor completion decoding and reward processing
* remove strip
* remove warning
* Add reward_tokenizer to training script
* Add reward_tokenizer and reward_processing_class to OnlineDPOTrainer test
* propagate to xpo and nash
* style
* reduce memory requirement with inference_mode
* fix tests
* pairrm judge llmblender
* setUpClass(cls)
* Add setUpClass method to TestJudges class
* truncation left for reward tokenizer
* don't logcompletion without eval dataset
* only eval when possible
* use the pair-judges
* add test
* Update trl/trainer/online_dpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* Update trl/trainer/online_dpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* decode and skip special characters
* initial nash
* return tensors
* Update trl/trainer/online_dpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* Update trl/trainer/online_dpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* Update trl/trainer/online_dpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* add back the logging
* use batch_decode
* add judges api to XPO trainer
* Update tests/test_online_dpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* judge in examples
* judge in config
* add back logs when using reward model
* typo
* add back model_scores logging when using reward model
* log scores for reward model only
* better cond on what to log
* same for rlhf reward
* Update trl/trainer/online_dpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* use decode_and_strip_padding
* error if both reward and judge or none are set
* remove unused check
* Uniform way to pass conversation into judge
* heading -> leading
* LogCompletionsCallback compat with online method
* Update Online DPO doc
* check if data is conversational for judges
* update example
* remove comment
* use zip
* fix stats xpo
* Replace judge with PairRMJudge and import AutoModelForSequenceClassification
* update xpo documentation
* Remove doc duplication
* update nash doc
* XPO trl chat
* nash md doc
* HfPairwiseJudge
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
* `get_batch_sample` -> `generate_from_model[_and_ref]`
* add `num_items_in_batch=None`
* `num_items_in_batch` in `training_step`
* Fix return type hint
* desc for unpair dataset util
* update example
* process in KTO
* Update doc
* KTO doc rewrite
* fix orpo doc
* add other dataset config names in test
* update doc image
* fix links in doc
* Update reward and log probability metrics in KTOTrainer doc
* skip enc-dec test
* Update docs/source/kto_trainer.mdx
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
---------
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* setup_chat_format: throw error if there was already a template
* fix lint
* clarify in docs
* fix test?
---------
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* in progress
* refactor concatenated_inputs and concatenated_forward
* progress
* further modif
* padding side
* eos prompt enc dec
* prompt_padding_side
* drop prompt apdding side collator
* working on decoder only
* dpo trainer
* Fix loss_mask type conversion bug
* bad attention mask
* try to get the same tokens as main
* fix loss mask
* fix unused col
* added comment
* raise error when paddind token not set
* remove private method tests
* initial vlm support
* make it work for paligemma
* minor test updates
* style
* improve readibility
* improve doc
* style
* flush left and truncate
* flush left in the code
* fix empty_cols and make max_length optional
* always add eos token
* minor changes and doc
* style
* fix docstring
* preference collator in doc
* fix doc
* optional max_completion_length
* Investigating CI failing
* style
* just dpo trainer test
* just idefics
* paligemma
* llava
* test cli
* dataset in test
* all tests
* Update trl/trainer/dpo_trainer.py
* Update trl/trainer/dpo_trainer.py
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update trl/trainer/dpo_trainer.py
* Update trl/trainer/dpo_trainer.py
* reference to ref
* rich descriptions
* fix logits reporting
* fix truncation
* remove chat template from dpo_vlm
* `get_batch_sample` -> `generate_from_model[_and_ref]`
* add `num_items_in_batch=None`
* `num_items_in_batch` in `training_step`
* Fix return type hint
* test tokenize row
* fix test
---------
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update log_example_reports.py
1. Added logging: Imported the logging module and set up a logger in the main function. This allows for better error tracking and debugging.
2. Improved file reading: Used a with statement to ensure the file is properly closed after reading. Also added error handling to catch and log any issues when reading the file.
3. Error handling for Slack SDK import: Added a try-except block to handle cases where the slack_sdk might not be installed.
4. Enhanced Slack message sending: Added error handling and logging for the Slack message sending process. This will help identify any issues with the Slack integration.
* style
* Update log_reports.py
1. Logging: Added logging to track errors and important events.
2. Error Handling: Wrapped the log file processing in a try-except block to handle potential errors gracefully.
3. Logging Total Failed Tests: Added a log statement to report the total number of failed tests
* style
* further improve
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
* `DPOScriptArguments` to `ScriptArguments`
* use dataset_train_split
* Use scriptarguments
* dataset names in command lines
* use `ScriptArguments` everywhere
* ignore biais buffer to end
* remove in v0.13
* rm comment
* update test commands
* Update docs/source/rloo_trainer.md
* Update tests/test_rloo_trainer.py
* Added dataset_train_split argument to ppo.py and rloo.py
* update scripts with dataset_train_split
* Updated README.md with CLI examples and additional usage instructions
Added Command Line Interface (CLI) examples for SFT, DPO, and Chat features.
Improved the "How to Use" section by providing code examples for SFTTrainer and RewardTrainer.
Included installation instructions for both Python Package and source-based installation.
Refined highlights to better showcase efficiency and scalability features.
Updated the repository clone instructions for working with examples.
Added new links to CLI documentation and contribution guide for better navigation.
* Update README.md
* Update README.md
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update README.md
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update README.md
* update badges
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
* Update log_example_reports.py
1. Added logging: Imported the logging module and set up a logger in the main function. This allows for better error tracking and debugging.
2. Improved file reading: Used a with statement to ensure the file is properly closed after reading. Also added error handling to catch and log any issues when reading the file.
3. Error handling for Slack SDK import: Added a try-except block to handle cases where the slack_sdk might not be installed.
4. Enhanced Slack message sending: Added error handling and logging for the Slack message sending process. This will help identify any issues with the Slack integration.
* style
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
* Update incorrect data processing in DataCollatorForChatML
Fix the extra BOS token and the absence of an EOS token in the returned input_ids, and potentially the absence of a target string in the returned labels.
* Update trl/trainer/utils.py
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* style
* move comment
* add test for DataCollatorForChatML
* update comment with more details
* update assert reports and comments, and adds verification that the last token of input_ids should be EOS token
* new line at the end of file for code quality
* Update tests/test_utils.py
* Update tests/test_utils.py
* Update tests/test_utils.py
* update tests
* fix test
* Update tests/test_utils.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* Update tests/test_utils.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* formatting
* fix typo
* simplify
* Revert "simplify"
This reverts commit 7e4006c87265665183032932ca05dffef567e38b.
* tokenize full messages
* dont add eos
* eos is in the last token
* simplify DataCollatorForChatML
* Update tests/test_utils.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
---------
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
* Update README.md
Fix grammatical errors in README.md
fixes issue #2185
Description:
I found a grammatical error in the README.md of the project. This PR fixes the error to improve the overall readability and clarity of the documentation.
Changes:
Corrected grammatical errors
Updated lines to reflect the correct grammar
Reasoning: The original text contained a grammatical error that could confuse readers. This fix ensures that the documentation is accurate and easy to understand.
Closes#2185
* Update README.md
Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com>
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com>
* add argument for dropout
* increase default lr
* change default lr in examples
* fix bug in calculation of KL batch size
* KL batch size should be args.per_device_train_batch_size
* Update kto_trainer.mdx with hparam recs
* typo
* allow dropout to be disabled
* update lr in sample scrippt
* Update kto_config.py
* Update trl/trainer/kto_trainer.py
* Update docs/source/kto_trainer.mdx
---------
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* clarify ConstantLengthDataset usage
* dont provide dataset text field when formatting func is provided
* kto maybe_apply_chat_template
* default text field
* doc
* remove maybe_apply_chat_template from kto example
* dataset text field always a str
* remove `dataset_text_field="text"`
* update doc
* conversational dataset support for dpo
* support standard dataset for extract prompt
* test standard dataset for extract prompt
* fix maybe
* fix maybe apply prompt
* style
* overwrite default learning rate of DPO
* style
* rlaif script
* `writer_batch_size` in `train_test_split`
* initial dpo doc refactoring
* vision data section in doc
* lil format modif
* refine Vision datasets
* refine doc
* test new loss type format
* restrcture loss function
* table loss type
* simplify `unsloth`
* improve doc
* looged metrics up
* refine loss section
* Fix label_smoothing parameter in DPOConfig
* dataset for test
* update readme
* Update docs/source/dpo_trainer.mdx
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* try colorized code block
* refine doc style
* further refine doc
* Update docs/source/dpo_trainer.mdx
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* re add pali gemma test
* Add missing period
---------
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* tokenize while training
* same for nashmd and xpo
* Update trl/trainer/online_dpo_trainer.py
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
---------
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
The function _process_tokens in trl/trainers/kto_trainer.py crashes if the prompt_input_ids are an empty list.
- added a check for nonzero length
- added a check for nonzero length of answer_input_ids for consistency
The checks happen when determining when subtracting 1 from max_length (happens when BOS or EOS is already present).
* fix neftune_noise_alpha
* del neftune_noise_alpha first
* check len after removing handle
* make sure we do not load twice
* Update trl/trainer/sft_trainer.py
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* remove neftune from SFTTrainer as the superclass has it now
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* learning rate recomentations for kto
* update from suggestion
* override default lr
* add tip tag
* Update trl/trainer/kto_config.py
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* make Orpotrainer run faster on tpu
* less data transfer
* train-trl.py
* fix
* set device_map=auto
* add is_torch_xla_available guards
* delete file
* address comments
* make presubmit
* Update transformer version in setup.py
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* initial xpo trainer
* compute rewards and ref log probs in smaller batches
* add logging
* initial log docs
* fix global_step increment
* fix metric descriptions
* use messages API
* use training_step API
* fix logs
* add test
* add back max_new_tokens
* use max_new_tokens
* refactor
* top_k is an int
* fix formatting
* fix the loss
* fix logging
* fix logging
* fix logging
* fix loss
* calcuate pi_log_ratio once
* fix stats
* fix loss
* do not log loss again
* fix docs
* add disable_dropout_in_model via flag
* comments
* revert doc change
* rm empty cache in online dpo
* improve doc xpo config
* some comment
* fix loggings stats
* fix docs
* save the model
* fix model and reward model
* Update trl/trainer/xpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
---------
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* Subtract a penalty from OnPolicy Trainers if output does not contain an EOS token
* Caught a few other problems
* Updated the documentation for RLOO trainer and PPOv2Trainer
* Corrected the default type and value for missing_eos_penalty
* Made RLOO Trainer consistent with Online DPO and PPOv2
* Removed --non_eos_penalty from all documentation
* Made missing_eos_penalty examples positive (because we subtract).
* Caught two more incorrect examples
* Removed unnecessary whitespace to make ruff happy
* Update trl/trainer/utils.py
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* drop canonical
* Delete ultrafeedback_prompt_only.py dataset script
* reduce dif in best_of_n
* try to revert best_of_n to make github happy
* anyway...
* fix: prevent unpackaging error due to additional **aux_loss** returned by **concatenated_forward** function when **aux_loss_enabled** is set to True.
* Refactor: Simplify tuple unpacking in `concatenated_forward` call in `get_batch_loss_metrics` function
* Refactor: improve code quality
* fix dataset and value error in sft
* Update trl/trainer/sft_trainer.py
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* move the test to the right place
---------
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* correct formatting of star sign in kto_trainer.mdx
The "*" symbol in markdown doesn't show. I changed it to $\times$ so the mathematical formula is clearer
* fix markdown
* one more try
* feat : add kto command
* feat : add support for apo loss in KTO Trainer
* feat : make kto script compatible with dpo-formatted datasets
* fix: lint data utils
* add loss_type in kto test
* fix: data utils docstrings
* fix: add dataset reformat test
* fix: lint tests
* fix: only reference kl_logps if needed
---------
Co-authored-by: Karel D'Oosterlinck <karel@contextual.ai>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* tokenize and process DPO data via batches
* use helpers
* updated _process_tokens
* fixed
* incorporate build_tokenized_answer in the _tokenizer
* Update trl/trainer/dpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* Update trl/trainer/dpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* fix tokenizer for is_vision_model
* Update trl/trainer/dpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* give the _tokenize the tokenizer as well as optional processor
* fix tests
* add bos and eos tokens
* add prompt_pixel_attention_mask
* Update trl/trainer/dpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
* truncate by max_length
* formatting
* fix for enc-dec
* For encoder-decoder models, we need to use the prepared decoder_input_ids
* add tests for _build_tokenized_answer and _tokenize_feature
* check for EOS and BOS tokens
* formatting
* do not include pixel mask if they are not provided
* undo refactor
* undo add_bos_token_if_needed change
* refactor tokenizer into smaller helpers
* add back comments
* fix type hints
* format
* fix t5 tests
* args are never optional
* move cat to appropriate helper
* fix _truncate_tokens
* add tests for _truncate_tokens
* remove dead code
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Test for #1970
* style
* drop last element in the batch for test
* check prompt_input_ids not modified
---------
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
* Fix issue with precompute_ref_log_probs not working when rpo_alpha is None
* Test: Add test for precompute_ref_log_probs with rpo_alpha=None
---------
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* fix model to save in ppov2
currently saving self.backup_model but this should be self.model
self.backup_model is only a temp model used to store the policy and
value function whereas self.model should have just the policy to save
* simplified logic
* remove unused ordereddict
* format
* fix the fix
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
* feat: support RS-LoRA in the ModelConfig
* build: bump minimum peft version to support rslora
* test: add test for get_peft_config
* test: make test python 3.8 friendly
* rm unused marker
* minor changes
* simplify, clarify doc
* update deps (peft in test)
* re-ordering
* fix setup
---------
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
* skip bigbird in ci
* readd big bird test
* pytest parametrize
* dont check the version
* rm model name
* re add big bird
* Merge branch 'main' into readd-bigbird-save-load-test
---------
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
* Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM
Added ```dataset_text_field``` in the SFTConfig while training
* Update docs/source/sft_trainer.mdx
---------
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* online dpo cleanups
* remove unused self.policy
* add OnlineDPOTrainer and config to __init__.py
* import from trainer
* online dpo test
* rename policy to model and ref_policy to ref_model
* renamed internally
* formatting
* Fix `torch_dtype` handling through CLI
The `torch_dtype` is not properly handled when provided via the TRL CLI
since it's provided initially as a string, but is then casted to
`torch.dtype` before providing it to the `{DPO,SFT}Trainer`, which means
that those trainers should handle the scenario where `torch_dtype` is a
`torch.dtype` too.
* Add `torch_dtype` tests in `test_{dpo,sft}_trainer.py`
* Forward contribution credits
* Run `make precommit`
---------
Co-authored-by: Tash Srivastava <yash-srivastava19@users.noreply.github.com>
* Preserve token fields when converting TrainingArguments to SFTConfig
TrainingArguments.to_dict() redacts token fields, so we have to
individually copy them over when converting to SFTConfig to avoid
breaking push_to_hub functionality.
Also adds a test.
* run precommit
* one-line args_as_dict definition per suggestion from kashif
* generalize token copying to match TrainingArguments behavior
* unwrap |= on dict, to support python 3.8
* use .update instead of |= or for-loop
* Remove extra whitespaces
* idefics
* vdpo
* sft idefics
* pad with test
* use prompt instead of tokenizer
* rm name main
* support vlm in tokenize row
* temp fix for regex in lora_target_module
* format
* vdpo
* tmp float16 hard code
* concatenated_forward support for vision
* style and new command line
* all-linear
* format
* delete old examples
* get image
* upcast
* new test
* modified test
* new strat for tokenizer
* rm token transfer
* integrate vision in dpo example
* format
* add FDivergenceType back
* precommit
* pillow test dep
* optional prompt
* `evaluation_strategy` to `eval_strategy`
* revert vsft change (oos)
* update test
* test
* comment and support more in process
* update process
* update doc for vdpo
* caution about limited support
* Update docs/source/dpo_trainer.mdx
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* revert DPO example changes
* cleaner way to check if a model is vision
* comment
* update vdpo example
* rename
---------
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* add a test case for num_train_epochs
* fix ci
* quick change
* disable push to hub
* debug windows ci
* try another fix
* skip subprocess tests on windows
Current handling of `response_masks` inside `batch_forward_pass`
function does not take padding into consideration which results with
shape unmatch during masking. Since response mask is a mask tensor of
response tokens, response tokens should not be concatenated with a
`torch.zeros(query_length)` and masking operation should be done without
slicing.
Remove the concatenation of the response mask, remove the slicing from
the response mask since response mask already has the length of `end -
start + 1`, which is equal to length of `masks[j, start:end]`.
* Step 1: update ppo_trainer and hello_world example
* Step 2: Refine comments and add parameter type
* Step 2: Add missing parameter comments
* Step 1: Organize ptx loss into a function and add ptx_loss to train_stats
* Step 1 updates: add comment to ptx_loss function, fix a bug and add warning message
* Step 2: 1) Add ppo_ptx trainig example as ppo; 2) separate pretrain data fetch and iterate
* Step 2: Remove loss from columns_to_log in ppo_ptx example
* Remove data set revision in load imbd dataset
* Run pre-commit and fix format issues
* Initial draft of f-divergence fn
* Update f-divergence to avoid overflow
* fix test errors and comments
* Add Unit tests for dpo loss with alpha and js div f
* Adjust format
* Fix test error
* Reverse this update
* Add test cases
* Reverse un-needed updates
* Update code style
* Try to fix code fmt error
* remove extra end line
---------
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* working trl parser with config
correctly overrides yaml config with command line arguments
adds return_remaining_strings
when return_remaining_strings is False, raises error if yaml contains
extra args that are not in the dataclasses
simpler and cleaner than previous yaml parsing and merging
addresses #1733
* lowercase trlparser
* Add test for skipping preproc if packing=True
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Allow skipping of validation for packing=True
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Use dummy dataset in no packing preproc test
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
---------
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
* Don't override optimize_device_cache when optimize_cuda_cache is not provided
Raise an exception when both optimize_cuda_cache and optimize_device_cache are set
* Minor fix
* [ORPO] Correct label mask for pad tokens
Recent [fix](57aebe9c36) for calculating NLL loss for a whole sequence introduced a bug. When input_ids are copied to labels, pad tokens are not masked.
This PR aims to path this by masking labels based on the attention mask.
* -100 -> label_pad_token_id
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
---------
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* fixed adding bos and eos token unconditionally
* fixed typo of tokenizer -> self.tokenizer. Also added update to ORPO
* fixed code quality, and added BOS/EOS fix to KTO
* code reformatting with pre-commit run --all-files
* bug fix: check input id length before checking for EOS/BOS
* add `Loss Functions` section in the doc.
* add bce loss with reward shift in KTOTrainer
* add underlying distribution matching
* update example to use underlying distribution matching
* add config description
* fix 'referenced before assignment' error
* add 'bco' and 'udm' test cases
* run pre-commit
* add `scikit-learn` dependency
* raise error is sklearn is not available
* call TrainingArguments().__post_init__() for proper init
* initial DPOConfig
* fix doc string
* use DPOConfig
* fix missing import
* fix DpoScriptArguments
* override args config when given in init
* use DPOConfig
* fix output dir name
* over-ride with depreicated arguments if given
* use DPOConfig in tests
* fix comment
* add custom_message
* use dataset_train_name and dataset_test_name
* beta is also in the training_args
* fix loss_type docs
* Update trl/commands/cli_utils.py
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update trl/commands/cli_utils.py
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update trl/commands/cli_utils.py
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* use DPOScriptArguments
---------
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* adds option to skip dataset preparation in SFTTrainer
* before changing the template
* adds support for new schema
* a few fixes to data collator to support new schema
* updates args
* precommit
* adds sys prompt to chat template and other fixes
* updates template, fixes collator for multiple images
* precommit
* rename vsft to vstf_llava
* adding integration tests
* adds integration test for vsft
* precommit
* adds back chat template
* docs
* typo
* adds eval, precommit
* adds peft launch args
* formatting
* fixes no deps tests by checking if PIL lib exists
* Update __init__.py
---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* Correct ppo_epochs usage
The usage of ppo_epochs is incorrect here.
In 8534f0edf8/trl/trainer/ppo_config.py (L104C8-L104C58)
the ppo_epochs was described as "Number of optimisation epochs per batch of samples".
However, here it is used as the usual epoch number, in which you do one iteration over the training dataset.
* Update ppo_trainer.mdx
* Update docs/source/ppo_trainer.mdx
---------
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
the type hint forces a list which raises a "all-linear" layer not found. forcing a string makes it work. updating the type hint to `Union[str, list[str]]` also raise a parsing error
* Refactor test
* Make batched tokenizer
* Make is FAST 🔥!
* Hack to the max
* Run on main process
* Refactor
* Add unit test
* f
* r
* Refactor
* Remove bs
* Refactor to tokenize once
* Add typing
* Add test for KL getter
* Add `use_cache=False` in `concatenated_forward`
Prevents `ORPOTrainer` from using the cache, as it's not required for computing the logits and runs into conflicts with Flash Attention 2
* Add `use_cache=False` to `concatenated_forward`
Co-authored-by: Kashif Rasul <kashif@users.noreply.github.com>
---------
Co-authored-by: Kashif Rasul <kashif@users.noreply.github.com>
* add CPOTrainer
* add docs
* fix formatting
* removed precompute_ref_log_probs arg
* remove precompute_ref_log_probs
* typos
* finish cpo trainer doc
* remove redundant lines
* typo
* formatting
* compute chosen nll loss also for enc-dec models
* fix gradient error of inplace operation for enc-dec models
* formatting
* use CPOConfig
* formatting
* use model_init_kwargs from CPOConfig
* comments in example
* fix doc string
* fix typo in docstring
* update year
* fixed typo
* use preference dataset
* fix learning rate
* move dataset_num_proc to configs
* Update cpo paper link from HF: cpo_trainer.mdx
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* update description for CPO: cpo_trainer.mdx
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* remove _prepare_deepspeed for cpo
Because CPO does not need init for reference model
* Add explanation to CPO loss
* format
* fix bug when lengths are given
* add CPOTrainer to README
* fix grammer
---------
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* CLI V1
* v1 CLI
* add rich enhancmeents
* revert unindented change
* some comments
* cleaner CLI
* fix
* fix
* remove print callback
* move to cli instead of trl_cli
* revert unneeded changes
* fix test
* Update trl/commands/sft.py
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* remove redundant strings
* fix import issue
* fix other issues
* add packing
* add config parser
* some refactor
* cleaner
* add example config yaml file
* small refactor
* change a bit the logic
* fix issues here and there
* add CLI in docs
* move to examples/sft
* remove redundant licenses
* make it work on dpo
* set to None
* switch to accelerate and fix many things
* add docs
* more docs
* added tests
* doc clarification
* more docs
* fix CI for windows and python 3.8
* fix
* attempt to fix CI
* fix?
* test
* fix
* tweak?
* fix
* test
* another test
* fix
* test
* fix
* fix
* fix
* skip tests for windows
* test @lvwerra approach
* make dev
* revert unneeded changes
* fix sft dpo
* optimize a bit
* address final comments
* update docs
* final comment
---------
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* Add use_bnb and load_in_4bit arguments.
Make it optional and not supported on all platforms
Signed-off-by: yuanwu <yuan.wu@intel.com>
* Change the use_reentrant default value to False
If the default value of gradient_checkpointing is True, set the
use_reentrant default value as False. Because the following error
happens.
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 191 with name base_model.model.model.layers.31.self_attn.v_proj.lora_B.default.weight has been marked as ready twice. This means that multiple autograd engine hooks have fired for this particular parameter during this iteration.
Signed-off-by: yuanwu <yuan.wu@intel.com>
* Add model_dtype for loading the model in model_dtype
Signed-off-by: yuanwu <yuan.wu@intel.com>
* Reformate the patch
Signed-off-by: yuanwu <yuan.wu@intel.com>
---------
Signed-off-by: yuanwu <yuan.wu@intel.com>
* fix 8-bit multi-gpu training bug see https://github.com/huggingface/trl/issues/1348
* Update dpo_llama2.py
make gradient_checkpointing_kwargs configurable.
* Update dpo_llama2.py
remote unnecessary config of device_map
* format with make precommit
---------
Co-authored-by: ubuntu <lili@liveremier.ai>
Clarify that language models must be transformers models for text. This is a bit redundant with intro description, but attempts to better address a question that that comes up (issue 1257).
Closes: #1257
In the document as it is now the best practice recommendations don't seem neither consistent nor correct.
For example, the documentation links a tweet with a recommendation to merge adaptors into a quantized model, and a script that supposedly illustrates how to apply that recommendation. But the script actually does the opposite of what the tweet recommends, first dequantizing the model.
There are similar inconsistencies/ambiguities further in that paragraph. For example, saying that using an unquantized model would lead to lower performance (I changed it to "higher memory demand").
Overall, I updated the paragraph to improve consistency and provided links to slightly more evidence-based merging recommendations.
Both the argument's name as well as the value need to be renamed.
Otherwise we get both
NameError: name 'train_dataset' is not defined
and
TypeError: PPOTrainer.__init__() got an unexpected keyword argument 'train_dataset'
* Update dpo_trainer.py
update reference_free parameter for dpo_loss
* Update dpo_trainer for reference_free case
updated the docstring typo and set device parameter to ref_logratios tensor
* Remove stray commas from test data
* Codemod Unittest assertions to bare asserts
* Make `assertAlmostEqual` tests more idiomatic
* DRY some test strings
* Update dpo_trainer.py
Added support for num_proc to tokenize the training dataset.
* Update dpo_trainer.py
added type in the new num_proc variable
* added test case
* add test case
* fix type
---------
Co-authored-by: imraviagrawal <ravi.agrawal@umass.edu>
Co-authored-by: Ravi Agrawal <raviagrawal@Ravis-MacBook-Pro.local>
* fix: only load data on main process
* define is_main_process once
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* avoid re-initializing PartialState on train dataset check
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* avoid re-initializing PartialState on eval dataset check
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* process dataset on main first to take advantage of caching
* fix typo in docs
* use decorator to manage state
* Revert "fix typo in docs"
This reverts commit 0880a188812a698f7106853245ce1ba96a036831.
* Revert "Revert "fix typo in docs""
This reverts commit ff7ee33fbeedcd0032b728d86a17cfcb10e43f9b.
* Revert "use decorator to manage state"
This reverts commit 7ac7a45949f621941fedc522f0d2ca7b29367c3a.
* use is_local_main_process instead of is_main_process
* fix: use context manager instead of attribute
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* Update trl/trainer/sft_trainer.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* fix: improve error message when `pad_token_id` is not configured
* Add test for error raised when pad_token is None
* Fix pre-commit errors
* Fix error in the test environment
* Fix FSDP error
Fixes error when `loss` field of model output is non-empty, and indexing as [0] returns loss instead of logits. Can happen with FSDP.
* Apply suggestions from code review
force return_dict
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* Fix reported KL in PPO trainer
previously this was always reporting the estimated KL, even when using `kl_penalty = 'full'` (or `abs`, etc).
Now we return the actual KL calculated in `compute_rewards()`, and report that.
* fix test
* Fix instruction token masking
Fix instruction token masking if the first instruction is tokenized differently than the others, or in general if no instruction is detected before the first response.
* Bugfix for edge case
(in case either of the templates isn't found at all, ...idxs[0] might not exist)
* Add test for instruction masking fix
* Allow separate devices for target/ref models.
* Remove original/duplicate.
* Cleanup original, black formatting.
---------
Co-authored-by: Jon Durbin <jonathan@convai.com>
* Address issue #1122
Issue [#1122](https://github.com/huggingface/trl/issues/1122)
takes care of an inconsistency between `_prepare_packed_dataloader`
and `_prepare_non_packed_dataloader`
* made attention_mask field in ConstantLengthDataset a tensor
* add: support for peft in ddpo.
* revert to the original modeling_base.
* style
* specify weight_name
* explicitly specify weight_name
* fix: parameter parsing
* fix: trainable_layers.
* parameterize use_lora.
* fix one more trainable_layers
* debug
* debug
* more fixes.
* manually set unet of sd_pipeline
* make trainable_layers cleaner.
* more fixes
* remove prints.
* tester class for LoRA too.
* add peft_module_casting_to_bf16 in DPOTrainer
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
* Update trl/trainer/dpo_trainer.py
---------
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
* SFT Trainer enhancements
* remove the callback `PeftSavingCallback`
* bump the version of transformers to `4.31.0`
* remove `PeftSavingCallback` from all places.
* use logprobs if it exists in the batch
* add features to tokenized batch if in data
* make get_batch_logps a static method
* add tokenize_batch_element dataset mapper
* Remove tokenize_batch method from DPODataCollator
* Initial sketch to precompute reference_logps
* run ref model via pytorch dataloader
* add a padding helper
* clean up the helper
* use logprob item()
* default behaviour
* clean up collator
* add docstring
* copy data back to cpu if needed
* use get_train_dataloader methods
* fix tests
* rename: more explicit variable name precompute_ref_log_probs
* improve comment
* update comment
* Update trl/trainer/dpo_trainer.py
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* refactor models into setup parameters
* parametrize precompute_ref_log_probs flag
* remove useless test
* Update trl/trainer/dpo_trainer.py
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update tests/test_dpo_trainer.py
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update tests/test_dpo_trainer.py
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update trl/trainer/dpo_trainer.py
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Update trl/trainer/dpo_trainer.py
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* update function arg name
* distinguish between pad token_id and mask values
* fix tokenization #932 by @nrailg
* fix test
* undo test refactor
* new line
* undo breaking change
* Update token counter condition to allow Llama tokenizer
* Acount for merged tokens on certain tokenizers such Llama-2 tokenizer
* Update variable name to match list value when truncating response
* map function on multi-gpu and gather
* Add test cases for DPOTrainer tokenization step
* revert since we need the prepeared model
* Use gather_with_metrics on ref_logps precomputation to keep original dataset size
* Add flag to keep track of when ref_logps are precomputed
* make variable names private
* formatting
* if precompute_ref_log_probs is true one can use non-peft to populate log-probs
* Use tokenizer padding token unless padding_value is set
* Move dataset.map(tokenize_batch) outside dataloader to avoid serialization errors
* eval can be none
* move to cpu to avoid gpu oom
* remove unneeded cast to float32
* remove unneeded
* fix merge
* fix merge
* fix merge
* add precompute log-prob status via tqdm
* Truncate answer if too longer once prompt has been truncated
* Add prompt_input_ids to batch to enable generation
* formatting and add lora example
* fix formatting
* Tokenize row now expects sample to have space on chosen/rejected for llama
* Revert "Tokenize row now expects sample to have space on chosen/rejected for llama"
This reverts commit dd07a10fe8c19b6ac6bbcc7b8144189756710d52.
* raise error when using zero-3 with precompute_ref_log_probs
---------
Co-authored-by: Pablo Vicente Juan <p.vicente.juan@gmail.com>
Co-authored-by: Shoaib Burq <saburq@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* first attempts at refactor of dpo trainer
* removed extra stuff in prediction step
* import fixes
* label names
* all working
---------
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* Update utils.py
update compute_accuracy to deal with the cases where str_chosen and str_rej got the same scores, which is probably what the developers don't want
* Update utils.py
updated so only warning is reserved
* Update trl/trainer/utils.py
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
---------
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* reward adapter loaded as part of init
more flexible, clearer args
* fixed script for multi gpu
unwrap model since it is DDP
downside, with reward adapter it seems we need to use
find_unused_parameters=True
* remove gradient from reward score calculation
* change supported_args back to None
func.__code__.co_varnames was used to count the function arguments for formatting_func. This code actually counted the function variables rather than function parameters.
* fix: dpo trainer ds config
ref_model and model shouldn share the same ds config, so we shouldn modify the ds config directly. or else, it will cause sth wrong when init deepspeed engine
* fix: import sort
import sort by isort
* adds model kwargs to SFT and DPO trainers
* adds checks for model_kwarg passing when model is not str
* changed warning to ValueError
* renames model_kwargs to model_init_kwargs
* corrects argument names in
* First unwrap the model and then process the input embeddings
* Changed base_model to base_model.model to stay consistent with peft model abstractions
* initial skeleton
* iterative trainer for decoder only
* iterative trainer unittest
* encoder_decoder support
* fix typo in unittest
* init
* fix typo
* fix init typo
* adding loggings and safety checker
* fixed minor issues
* doc
* table of contents update
* add test for seq2seq2 models
* change year
* adding text as step input
* precommit
* fixing typo
* run precommit
* fixing typo in safety checker
* fix text tokenization issue
* add truncate and inherit from trainer
* remove iterative config from tests
* remove iterative config from init
* fix peft model
* change truncation side based on truncation_mode
* removed iterativeconfig autodoc
* fixed typo in trainer.mdx
* remove mention of iterative config in docs
* make sure optimizer and scheduler are created
* adding max_steps to test
* remove log_stats fn
* remove compute loss
* fixing encoder decoder detection
* fix PPODecorator
* run precommit
* fix testing
* fix small typos in iterative trainer
* adapted function log and eval
* make use of forward hooks
* correctly delete attributes
* fix RM DPP issues
* revert unneeded changes
* more fixes
* fix diff
* fix
* propagate to SFT
* Update examples/scripts/reward_modeling.py
* propagate the fix on DPO trainer
* add to example scripts
* trigger CI
* adding specific dict structure to tracker_kwargs doc string to enable changing tracker params like wandb experiment name for ease, avoids needing to go deep into accelerate source
* push changes
* set default dict
* refactor
* use typing extension
---------
Co-authored-by: Laura O'Mahony <lauraomahony@L-MacBook-Pro.fritz.box>
Co-authored-by: Costa Huang <costa.huang@outlook.com>
* Add whiten ops before compute advatanges
1. From LLaMA 2 paper, it says:
```
We also find it important to whiten the final linear scores (shown here by reversing the sigmoid with the logit function) in order to increase stability and balance properly with the KL penalty term (β) above.
```
2. This function is taken from [alpaca_farm](64e489c67e/src/alpaca_farm/rl/ppo_trainer.py (L86))
* Fix type def of self
---------
Co-authored-by: Lin Junpeng <linjunpeng@sensetime.com>
* add SLiC hinge loss
* fix links
* beta when loss is hinge is reciprocal of margin
* fix tests
* fix docs
* doc strings
* fix method name
* raise error if loss_type is not correct
* Update trl/trainer/dpo_trainer.py
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* fix formatting
---------
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* init
* run
* Update custom eval loop to aid DPO debugging (#770)
* sample_during_eval -> generate_during_eval
* Remove unused return_tokens
* Add import utils for W&B, prevent test fails
* Optimize dataloader random batch selection
* Separate prompt and response in logs
Makes it much easier to quickly read the starts of the generations
* Simplify logging
* reset eval steps
* manual merge fixes
* revert merge
* remove self.max_length
* style
* fix max_length
---------
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
* Update utils.py
* correctly assign instruction_template in DataCollatorForCompletionOnlyLM
* correctly use instruction_token_ids in DataCollatorForCompletionOnlyLM
* DataCollatorForCompletionOnlyLM: fix instruction_template / response_template type check: handle cases where instruction_template is None
* make precommit
* Test DataCollatorForCompletionOnlyLM with pre-tokenized instruction_template
* Start adding margin to RM training
* Fix typo and cleanup
* Fix incompatibilities when not using margin
* Format using 'make precommit'
* Add documentation and test for reward trainer
* Run 'make precommit'
* Update docs/source/reward_trainer.mdx
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
* Fix missed merge conflict in reward trainer docs
---------
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This change avoids setting report_to="all" (the default behavior in
transformers v4), which could lead to unexpected error messages for
inexperienced users. Note that the default value of report_to will
change anyway to "none" in transformers v5.
* docs: add initial version of docs for `PPOTrainer`
* Apply suggestions from code review Leandro
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* Apply suggestions from code review
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* updated docs based on feedback leandro
- specified reference to reward model
- added batched generator
- added line of saving model
- remove reference model
* Apply suggestions from code review
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
---------
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
It took a while to understand why zero-masked tokens are one less than the length of query tokens.
If I got it correctly, it is because the first logit (and state-value) from the outputs refers to the second token in the query.
Hope this comment can be helpful to others who may encounter a similar question in the first-pass reading of the code :)
* update to `prepare_model_for_kbit_training`
from deprecated `prepare_model_for_int8_training`
and add `use_gradient_checkpointing=args.gradient_checkpointing` to
automatically follow the gradient checkpointing choice
is also the workaround for #694
* workaround for gradient checkpointing issue
calling model.gradient_checkpointing_enable() twice causes issues
this workaround calls it in prepare_model_for_kbit_training and then
changes the arg to false to make sure it isn't called again in
huggingface trainer inner loop
also changes stack_llama_2 sft trainer to use correct device map for ddp
training so that you can test this issue
* fix PeftConfig loading from a remote repo.
* failed to catch hf_hub_download() EntryNotFoundError.
At least in huggingface-hub 0.10.1, the error for "not found" is:
huggingface_hub.utils._errors.EntryNotFoundError: 404 Client Error
* pass precommit checks.
* replace some bare excepts with specific codes
* catch LocalEntryNotFoundError additionally.
* Broken first pre-draft
* Change structure to leverage user-definition of pipeline
- reward function, pipeline and scheduler will be left to the user to define
- pipeline and scheduler contract interfaces is what the framework will define
- none of this actually works
* Incremental progress: trying to get the set-up running e2e
* Incemental progress: successfully running code
* Incremental progress: running setup
Next steps: fix accelerate gardient acc assertion error when we set value > 1
* Formatting and code standards
* Incremental prog: break down code a bit
- new config flag to notify code of async reward fetching
- break off image handling code and throw it on to user to define how to handle it
- more code restructuring
* Incremental progress:
1. More code sectioning off into own methods (more for readibility than anything else)
* Incremental progress:
1. clear up contracts
2. type the reward function and prompt function
* Code shuffling and expansion of tracker, accelerator config args to beyond wandb
* More small additions
Add tensorboard logging function
Remove wandb logging function for now
Consolidate the data that get's thrown to the logging function
Add README
* Formatting
* Formatting
* Remove print statement
Make tensorboard tracking the sole tracking for the training example
* 1. start of testing
2. more refactoring
3. start of docstrings
4. parameter rename
* Basic Tests
Formatting
* Docs according to the norm
* Doocs, credits and rename file
* docs and corrections
* Put example config to respectable state
* Add recent run params
* Correct the name of the library
* Move requirements to EXTRAS
* - Add license banners
- Guard import of DDPO functions with if_diffusers_available
- doc strings for output types
* Add snippet to pull weights from huggingface + banner
* Test if passes on CI/CD
* Minor refactor
* Test dummy unet
* Possible fix for randomly disappearing attribute
* Shuffling arrangement in hopes of meeting memory requirements
* Proper Names
* Appease windows memory allocator issues for the cpu device
* Remove print statements
* Update docs/source/ddpo_trainer.mdx
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* Update docs/source/ddpo_trainer.mdx
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* Add docstrings and correct url
* Spelling and grammar
* Add more documentation and commandline parsing for example script
* Markdown synatx correction
* Revert accidentally committed file and put the correct one
* More docs
* Remove subclassing and add docs for leftoover subclassing
* Put back subclassing
* Reward metadata and more docs
* Remove save_load_save flag
* Grammar
* Update trl/trainer/ddpo_trainer.py
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* Update tests/test_ddpo_trainer.py
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* Update setup.py
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* Update examples/scripts/stable_diffusion_tuning.py
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* Edits to the readme for DDPO
* Renamed modelling_sd_base to modeling_sd_base
* Insert try and catch for bitsandbytes import
* Change to smaller model
* Correct tolerance for floating point comparison
* Remove dummy unet and move to check is isfinite
* 1. Expand interface to ensure other Stable Diffusion pipelines could be covered
2. remove extra identification
* 1. Remove most of the asserts except for one and add value error
2. Remove default run name
* Remove progress bar
* Docs
* Put back progress bar
* 1. Revert progress bar deletion completely
2. grammar
3. relocate line
* Experiment
* Remove experiment parts and format properly
* Change formatting and edit info in docs
* Grammar
* Refactor out most of nitty gritty of loading/saving from trainer to example model
Readme addition
* Docs additions
* 1. Proper formatting fr the test file
2. incorporatioon of pull frm hub if fails try local
3. doc strings for interface
4. highlight in the trainer, that this is only ready fr sd pipelines
* Resources for before and after
* Attempt at embedding images
* Post testing example script
* Consistent naming and document edits in light of new args
* Remove resources and add CDN links in html in doc file
---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* Update dpo_trainer.py
Make ref_model optional.
* add tests for ref_model=None
* better handling for ref_model=None
* Update dpo_trainer.py
Correct docstring
* move instantiation of self.ref_model closer to model
* use .disable_adapters instead of .get_base_model
* handle ref_model=None in get_batch_samples
* fix failing test in dpo_trainer due to disable_dropout_in_model
* Update trl/trainer/dpo_trainer.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* Add reward/score scaling/normalization/clipping
* Run pre-commit to fix styles and remove some dupe code
* Make sure score module and pretrained_model have the same dtype
* Add multi_adapter_rl_v2.py
* Add log_with
* Add more verbose help message for use_score_norm
* Fix score clipping for float16
* Minor fix
* WIP
* improve inference docs
* improve training faq
* update toctree
* fix toctree
* fix improve blog
* improve blog
* fix customization
* reword faq a bit
* reword inference a bit
* add references back
* integrate feedback from code review
* fix link in html
* Allow tokenized ids in DataCollatorForCompletionOnlyLM. Add test and docs
* Formatting
* Documentation
* Remove unused code from test
---------
Co-authored-by: Ivan Sanchez <ivan.sanchez@zyte.com>
* initial stack-llama-2 scripts
* removed unused function
* add accelerate
* link to stack-llama-2 code
* running the model
* pre-commit fixes
* use the merge_peft script
* Add section on logged metrics
* refactor grad accum
* quick fix
* use correct place to step optim
* push changes
* cleanup and fix division by zero in `masked_var`
* revert back changes
* use unbiased var
* deal with division by zero
* add test case
* calculate advantage only once
* format
* add warning
* add more warnings
* quick fix
* remove unhelpful warning
* fix test cases
* fix test cases
* bump version given the breaking change
* black
* refactor
* update test cases
* error out
* push changes
* remove exact div
* add comments
* Add unit test to DataCollatorForCompletionOnlyLM to reproduce the bug.
* Change comparison target from examples[i][input_ids] to batch[labels][i] in DataCollatorForCompletionOnlyLM
* added DataCollatorForChatCompletionOnlyLM
* added simple test
* merged the two collators and fixed ### in completion
* fix response template
* fixing ordering in test
* quality
* fixed minor comments & make doc
* chat test back
* Update tests/test_sft_trainer.py
---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* remove response/pairs from the DPO side
* Simplify get_hh helper function
* removed unused import
* update tests and docs for dpo_trainer
---------
Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>
Co-authored-by: Shoaib Burq <saburq@gmail.com>
* added sfttrainer and rmtrainer example scripts.
* added few lines in the documentation.
* moved notebooks.
* delete `examples/summarization`
* remove from docs as well
* refactor sentiment tuning
* more refactoring.
* updated docs for multi-adapter RL.
* add research projects folder
* more refactor
* refactor docs.
* refactor structure
* add correct scripts all over the place
* final touches
* final touches
* updated documentation from feedback.
This PR fixes a bug in `_get_current_device()` where the global process index was being returned instead of the local one.
With this fix, it is possible to run training in **multi-node** environments and avoid the dreaded `RuntimeError: CUDA error: invalid device ordinal` :)
* Debug the tortuous logic in `_prepare_dataset` function
There are two issues with the previous `_prepare_dataset` function.
1. Tortuous and burdensome logic: the `is_already_dataset` variable is confusing and not helpful. So, remove it.
2. The comments and the logics do not match.
For instance, in the previous version, the comments said "check if torch dataset ... and do nothing". However, when "dataset" is a torch.utils.data.Dataset and `packing = True`? It will still move into the _prepare_non_packed_dataloader(...) function call.
The corrected version will do nothing if the dataset is already a torch dataloader/dataset/ConstantLengthDataset.
* Lint: sft_trainer.py
* Lint empty line
* v1 of alpaca datacollator
* make sure to match the response tokens
* add test
* add it in main init
* add check
* adapt test
---------
Co-authored-by: Costa Huang <costa.huang@outlook.com>
* First draft of best-of-n sampler class
* Formatting
* Add best-of-n class to init
* Rearrange files
* Correction
* Make sure input query is in shape
* check for numpy.ndarray type
* Fix for shapes and types AND linter fixes
* Make reward pipeline a callback for more broader application
* Documentation for best-of-n sampler class usage
* Docs update for best-of-n class
* Doc fixes for best-of-n sampler class
* Remove colon from new addition
* Change user callback output type and associated side-effects of said change
* Relocate param because of collision
* Documentation update
* Make input param keyword easier to grasp
* Remove comments and add docstrings
* Tests and fixes for best_of_n sampler class
* Change input arg name
* Formatting
* Removed unnecessary cloning
description:Submit a bug report to help us improve TRL
labels:["bug"]
body:
- type:markdown
attributes:
value:|
Thanks for taking the time to fill out this bug report! 🤗
🚩 If it is your first time submitting, be sure to check our [bug report guidelines](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#did-you-find-a-bug)
- type:textarea
id:reproduction
validations:
required:true
attributes:
label:Reproduction
description:|
Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
If you have code snippets, error messages, stack traces please provide them here as well.
Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
value:|
```python
from trl import ...
```
outputs:
```
Traceback (most recent call last):
File "example.py", line 42, in <module>
...
```
- type:textarea
id:system-info
attributes:
label:System Info
description:|
Please provide information about your system: platform, Python version, PyTorch version, Transformers version, devices, TRL version, ...
You can get this information by running `trl env` in your terminal.
placeholder:Copy-paste the output of `trl env`
validations:
required:true
- type:checkboxes
id:terms
attributes:
label:Checklist
description:|
Before submitting, please confirm that you've completed each of the following.
If an item doesn't apply to your issue, check it anyway to show you've reviewed it.
options:
- label:"I have checked that my issue isn't already filed (see [open issues](https://github.com/huggingface/trl/issues?q=is%3Aissue))"
required:true
- label:"I have included my system information"
required:true
- label:"Any code provided is minimal, complete, and reproducible ([more on MREs](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))"
required:true
- label:"Any code provided is properly formatted in code blocks, (no screenshot, [more on code blocks](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))"
description:Submit a proposal/request for a new TRL feature
labels:["Feature request"]
body:
- type:textarea
id:feature-request
validations:
required:true
attributes:
label:Feature request
description:|
A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist.
- type:textarea
id:motivation
validations:
required:true
attributes:
label:Motivation
description:|
Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too.
- type:textarea
id:contribution
validations:
required:true
attributes:
label:Your contribution
description:|
Is there any way that you could help, e.g. by submitting a PR? Make sure to read the CONTRIBUTING.MD [readme](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md)
description:Submit a proposal/request to implement a new trainer for a post-training method
labels:["New trainer"]
body:
- type:textarea
id:description-request
validations:
required:true
attributes:
label:Method description
description:|
Put any and all important information relative to the method
- type:checkboxes
id:information-tasks
attributes:
label:Open source status
description:|
Please note that if the method implementation isn't available or model weights with training datasets aren't available, we are less likely to implement it in `trl`.
options:
- label:"The method implementation is available"
- label:"The model weights are available"
- label:"The training datasets are available"
- type:textarea
id:additional-info
attributes:
label:Provide useful links for the implementation
description:|
Please provide information regarding the implementation, the weights, and the authors.
Please mention the authors by @gh-username if you're aware of their usernames.
Congratulations! You've made it this far! You're not quite done yet though.
Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution.
Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change.
Once you're done, someone will review your PR shortly. They may suggest changes to make the code even better.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#create-a-pull-request),
Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
abstract: "With trl you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the transformers library by \U0001F917 Hugging Face. Therefore, pre-trained language models can be directly loaded via transformers. At this point, most decoder and encoder-decoder architectures are supported."
Everyone is welcome to contribute, and we value everybody's contribution. Code
contributions are not the only way to help the community. Answering questions, helping
others, and improving the documentation are also immensely valuable.
Before you start contributing make sure you installed all the dev tools:
It also helps us if you spread the word! Reference the library in blog posts
about the awesome projects it made possible, shout out on Twitter every time it has
helped you, or simply ⭐️ the repository to say thank you.
However you choose to contribute, please be mindful and respect our
[code of conduct](https://github.com/huggingface/trl/blob/main/CODE_OF_CONDUCT.md).
**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).**
## Ways to contribute
There are several ways you can contribute to TRL:
* Fix outstanding issues with the existing code.
* Submit issues related to bugs or desired new features.
* Implement trainers for new post-training algorithms.
* Contribute to the examples or the documentation.
If you don't know where to start, there is a special [Good First
Issue](https://github.com/huggingface/trl/labels/%F0%9F%91%B6%20good%20first%20issue) listing. It will give you a list of
open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over.
For something slightly more challenging, you can also take a look at the [Good Second Issue](https://github.com/huggingface/trl/labels/Good%20Second%20Issue) list. In general though, if you feel like you know what you're doing, go for it and we'll help you get there! 🚀
> All contributions are equally valuable to the community. 🥰
Before you start contributing make sure you have installed all the dev tools:
```bash
pip install -e ".[dev]"
pip install -e .[dev]
```
## Did you find a bug?
## Fixing outstanding issues
* Ensure the bug was not already reported by searching on GitHub under Issues.
* If you're unable to find an open issue addressing the problem, open a new one. Be sure to include a title and clear description, as much relevant information as possible, and a code sample or an executable test case demonstrating the expected behavior that is not occurring.
* Be sure to add the complete error messages.
If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#submitting-a-pull-request-pr) and open a Pull Request!
#### Did you write a patch that fixes a bug?
## Submitting a bug-related issue or feature request
* Open a new GitHub pull request with the patch.
* Ensure that your PR includes a test that fails without your patch, and pass with it.
* Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable.
Do your best to follow these guidelines when submitting a bug-related issue or a feature request. It will make it easier for us to come back to you quickly and with good feedback.
## PR submission guidelines
### Did you find a bug?
* Keep each PR focused. While it's more convenient, do not combine several unrelated fixes together. Create as many branches as needing to keep each PR focused.
* Do not mix style changes/fixes with "functional" changes. It's very difficult to review such PRs and it most likely get rejected.
* Do not add/remove vertical whitespace. Preserve the original style of the file you edit as much as you can.
* Do not turn an already submitted PR into your development playground. If after you submitted PR, you discovered that more work is needed - close the PR, do the required work and then submit a new PR. Otherwise each of your commits requires attention from maintainers of the project.
* If, however, you submitted a PR and received a request for changes, you should proceed with commits inside that PR, so that the maintainer can see the incremental fixes and won't need to review the whole PR again. In the exception case where you realize it'll take many many commits to complete the requests, then it's probably best to close the PR, do the work and then submit it again. Use common sense where you'd choose one way over another.
The TRL library is robust and reliable thanks to users who report the problems they encounter.
### Before you submit a PR
Before you report an issue, we would really appreciate it if you could **make sure the bug was not
already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code.
First you want to make sure that all the tests pass:
Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so we can quickly resolve it:
* Your **OS type and version**, **Python**, **PyTorch**, **TRL** and **Transformers** versions.
* A short, self-contained, code snippet that allows us to reproduce the bug in
less than 30s.
* The *full* traceback if an exception is raised.
* Attach any other additional information, like screenshots, you think may help.
To get the OS and software versions automatically, run the following command:
```bash
make test
trl env
```
Then before submitting your PR make sure the code quality follows the standards. You can run the following command to format and test:
### Do you want a new feature?
If there is a new feature you'd like to see in TRL, please open an issue and describe:
1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it a feature related to something you need for a project? Is it something you worked on and think it could benefit the community?
Whatever it is, we'd love to hear about it!
2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better we'll be able to help you.
3. Provide a *code snippet* that demonstrates the feature's usage.
4. If the feature is related to a paper, please include a link.
If your issue is well written we're already 80% of the way there by the time you create it.
## Do you want to implement a new trainer?
New post-training methods are published frequently and those that satisfy the following criteria are good candidates to be integrated into TRL:
* **Simplicity:** Does the new method achieve similar performance as prior methods, but with less complexity? A good example is Direct Preference Optimization (DPO) [[Rafailov et al, 2023]](https://huggingface.co/papers/2305.18290), which provided a simpler and compelling alternative to RLHF methods.
* **Efficiency:** Does the new method provide a significant improvement in training efficiency? A good example is Odds Ratio Preference Optimization (ORPO) [[Hong et al, 2023]](https://huggingface.co/papers/2403.07691), which utilizes a similar objective as DPO but requires half the GPU VRAM.
Methods that only provide incremental improvements at the expense of added complexity or compute costs are unlikely to be included in TRL.
If you want to implement a trainer for a new post-training method, first open an issue and provide the following information:
* A short description of the method and a link to the paper.
* Link to the implementation if it is open-sourced.
* Link to model weights trained with the method if they are available.
Based on the community and maintainer feedback, the next step will be to implement the trainer and config classes. See the following examples for inspiration:
* Paired preference optimisation: [`dpo_trainer.py`](./trl/trainer/dpo_trainer.py) and [`dpo_config.py`](./trl/trainer/dpo_config.py)
* RL-based optimisation: [`rloo_trainer.py](./trl/trainer/rloo_trainer.py) and [`rloo_config.py](./trl/trainer/rloo_config.py)
* Online optimisation: [`online_dpo_trainer.py`](./trl/trainer/online_dpo_trainer.py) and [`online_dpo_config.py`](./trl/trainer/online_dpo_config.py)
## Do you want to add documentation?
We're always looking for improvements to the documentation that make it more clear and accurate. Please let us know how the documentation can be improved, such as typos, dead links, and any missing, unclear, or inaccurate content... We'll be happy to make the changes or help you contribute if you're interested!
## Submitting a pull request (PR)
Before writing code, we strongly advise you to search through the existing PRs or
issues to make sure that nobody is already working on the same thing. If you are
unsure, it is always a good idea to open an issue to get some feedback.
You will need basic `git` proficiency to be able to contribute to
TRL. `git` is not the easiest tool to use but it has the greatest
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
Git](https://git-scm.com/book/en/v2) is a very good reference.
Follow these steps to start contributing:
1. Fork the [repository](https://github.com/huggingface/trl) by
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
under your GitHub user account.
2. Clone your fork to your local disk, and add the base repository as a remote. The following command
assumes you have your public SSH key uploaded to GitHub. See the following guide for more
3. Create a new branch to hold your development changes, and do this for every new PR you work on.
Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
```bash
$ git checkout main
$ git fetch upstream
$ git merge upstream/main
```
Once your `main` branch is synchronized, create a new branch from it:
4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:
```bash
$ pip install -e .[dev]
```
(If TRL was already installed in the virtual environment, remove
it with `pip uninstall trl` before reinstalling it.)
Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using
the provided Dev Container. Documentation on how to get started with dev containers is available [here](https://code.visualstudio.com/docs/remote/containers).
5. Develop the features on your branch.
As you work on the features, you should make sure that the test suite
passes. You should run the tests impacted by your changes like this (see
below an explanation regarding the environment variable):
```bash
$ pytest tests/<TEST_TO_RUN>.py
```
> For the following commands leveraging the `make` utility.
You can also run the full suite with the following command.
```bash
$ make test
```
TRL relies on `ruff` for maintaining consistent code formatting across its source files. Before submitting any PR, you should apply automatic style corrections and run code verification checks.
We provide a `precommit` target in the `Makefile` that simplifies this process by running all required checks and optimizations on only the files modified by your PR.
To apply these checks and corrections in one step, use:
```bash
$ make precommit
```
This command runs the following:
- Executes `pre-commit` hooks to automatically fix style issues with `ruff` and other tools.
- Runs additional scripts such as adding copyright information.
If you prefer to apply the style corrections separately or review them individually, the `pre-commit` hook will handle the formatting for the files in question.
Once you're happy with your changes, add changed files using `git add` and
make a commit with `git commit` to record your changes locally:
6. Once you are satisfied (**and the checklist below is happy too**), go to the
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
to the project maintainers for review.
7. It's ok if maintainers ask you for changes. It happens to core contributors too! To ensure everyone can review your changes in the pull request, work on your local branch and push the updates to your fork. They will automatically appear in the pull request.
### Checklist
1. The title of your pull request should be a summary of its contribution;
2. If your pull request addresses an issue, please mention the issue number in
the pull request description to make sure they are linked (and people
consulting the issue know you are working on it);
3. To indicate a work in progress please prefix the title with `[WIP]`, or mark
the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate
it from PRs ready to be merged;
4. Make sure existing tests pass;
5. Add high-coverage tests. No quality testing = no merge.
### Tests
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in
the [tests folder](https://github.com/huggingface/trl/tree/main/tests).
We use `pytest` to run the tests. From the root of the
repository here's how to run tests with `pytest` for the library:
```bash
make style && make quality
$ python -m pytest -sv ./tests
```
## Do you want to contribute to the documentation?
That's how `make test` is implemented (without the `pip install` line)!
* Docs are in the `docs/` folder and can be updated there.
You can specify a smaller set of tests to test only the feature
you're working on.
### Default values guidelines
1. **Use defaults when appropriate**:
Provide default values unless the parameter's value varies significantly by use case. For example, datasets or models should not have defaults, but parameters like `learning_rate` should.
2. **Prioritize proven defaults**:
Default values should align with those recommended in the original paper or method. Alternatives require strong evidence of superior performance in most cases.
3. **Ensure safety and predictability**:
Defaults must be safe, expected and reliable. Avoid settings that could lead to surprising outcomes, such as excessive memory usage or poor performance in edge cases.
4. **Balance consistency and flexibility**:
Aim for consistent defaults across similar functions or methods. However, consistency should not be preferred to point 2 or 3.
5. **Opt-in for new features**:
Do not enable new features or improvements (e.g., novel loss functions) by default. Users should explicitly opt-in to use these.
### Writing documentation
High-quality documentation is crucial for maintaining a project that is easy to use, understand, and extend. When adding new features, ensure they are thoroughly documented to maintain consistency and clarity throughout the project.
To illustrate what good documentation looks like, here’s an example of a well-documented function:
* **Line Wrapping:** Applied a consistent line wrap at column 120 to improve readability.
* **Definite Articles:** Removed definite articles where possible to streamline language. (Eg: Changed "The string to replicate" to "String to replicate")
* **Type Annotations:**
* Always include type definitions, indicating if a parameter is optional and specifying the default value.
* Note that `Optional` means that the value can be `None`, and `*optional*` means that it is not required for the user to pass a value.
E.g., for arguments that can't be `None` and aren't required:
```python
foo (`int`, *optional*, defaults to `4`):
```
For arguments that can be `None` and are required:
```python
foo (`Optional[int]`):
```
for arguments that can be `None` and aren't required:
```python
foo (`Optional[int]`, *optional*, defaults to `None`):
```
* **String Defaults:**
* Ensured that default string values are wrapped in double quotes:
```python
defaults to `"foo"`
```
* **Dictionary Typing:**
* Replaced generic `dict` type hints with more explicit `dict[str, Any]` to clarify expected key-value pairs.
* **Default Value Formatting:**
* Consistently surrounded default values with backticks for improved formatting:
```python
defaults to `4`
```
* **Sub-sectioning:** When the number of arguments is large, consider breaking them into sub-sections for better readability.
include_variance (`bool`, *optional*, defaults to `False`):
Whether to include the variance of the dataset in the results.
Returns:
`dict[str, float]`:
A dictionary containing calculated statistics such as mean, median, and optionally variance.
"""
...
```
### Deprecation and backward compatibility
Our approach to deprecation and backward compatibility is flexible and based on the feature’s usage and impact. Each deprecation is carefully evaluated, aiming to balance innovation with user needs.
When a feature or component is marked for deprecation, its use will emit a warning message. This warning will include:
- **Transition Guidance**: Instructions on how to migrate to the alternative solution or replacement.
- **Removal Version**: The target version when the feature will be removed, providing users with a clear timeframe to transition.
Example:
```python
warnings.warn(
"The `Trainer.foo` method is deprecated and will be removed in version 0.14.0. "
"Please use the `Trainer.bar` class instead.",
FutureWarning,
)
```
The deprecation and removal schedule is based on each feature's usage and impact, with examples at two extremes:
- **Experimental or Low-Use Features**: For a feature that is experimental or has limited usage, backward compatibility may not be maintained between releases. Users should therefore anticipate potential breaking changes from one version to the next.
- **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling deprecation around **5 minor releases** after the initial warning.
These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs.
### Working with warnings
Warnings play a critical role in guiding users toward resolving potential issues, but they should be used thoughtfully to avoid unnecessary noise. Unlike logging, which provides informational context or operational details, warnings signal conditions that require attention and action. Overusing warnings can dilute their importance, leading users to ignore them entirely.
#### Definitions
- **Correct**: An operation is correct if it is valid, follows the intended approach, and aligns with the current best practices or guidelines within the codebase. This is the recommended or intended way to perform the operation.
- **Supported**: An operation is supported if it is technically valid and works within the current codebase, but it may not be the most efficient, optimal, or recommended way to perform the task. This includes deprecated features or legacy approaches that still work but may be phased out in the future.
#### Choosing the right message
- **Correct → No warning**:
If the operation is fully valid and expected, no message should be issued. The system is working as intended, so no warning is necessary.
- **Correct but deserves attention → No warning, possibly a log message**:
When an operation is correct but uncommon or requires special attention, providing an informational message can be helpful. This keeps users informed without implying any issue. If available, use the logger to output this message. Example:
```python
logger.info("This is an informational message about a rare but correct operation.")
```
- **Correct but very likely a mistake → Warning with option to disable**:
In rare cases, you may want to issue a warning for a correct operation that’s very likely a mistake. In such cases, you must provide an option to suppress the warning. This can be done with a flag in the function. Example:
```python
def my_function(foo, bar, _warn=True):
if foo == bar:
if _warn:
warnings.warn("foo and bar are the same, this is likely a mistake. Ignore this warning by setting `_warn=False`.")
# Do something
```
- **Supported but not correct → Warning**:
If the operation is technically supported but is deprecated, suboptimal, or could cause future issues (e.g., conflicting arguments), a warning should be raised. This message should be actionable, meaning it must explain how to resolve the issue. Example:
```python
def my_function(foo, bar):
if foo and bar:
warnings.warn("Both `foo` and `bar` were provided, but only one is allowed. Ignoring `foo`. Please pass only one of these arguments.")
# Do something
```
- **Not supported → Exception**:
If the operation is invalid or unsupported, raise an exception. This indicates that the operation cannot be performed and requires immediate attention. Example:
```python
def my_function(foo, bar):
if foo and bar:
raise ValueError("Both `foo` and `bar` were provided, but only one is allowed. Please pass only one of these arguments.")
```
By following this classification, you ensure that warnings, information, and exceptions are used appropriately, providing clear guidance to the user without cluttering the system with unnecessary messages.
python -m pytest -n auto --dist=loadfile -s -v ./tests/
pytest -n auto -m "not slow and not low-priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)'tests/
quality:
black --check --line-length 119 --target-version py38 $(check_dirs)
isort --check-only $(check_dirs)
flake8 $(check_dirs)
precommit:
python scripts/add_copyrights.py
pre-commit run --all-files
style:
black --line-length 119 --target-version py38 $(check_dirs)
> Train transformer language models with reinforcement learning.
## What is it?
With `trl` you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via `transformers`. At this point most of decoder architectures and encoder-decoder architectures are supported.
**Highlights:**
-`PPOTrainer`: A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.
-`AutoModelForCausalLMWithValueHead`&`AutoModelForSeq2SeqLMWithValueHead`: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
- Example: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier.
## How it works
Fine-tuning a language model via PPO consists of roughly three steps:
1.**Rollout**: The language model generates a response or continuation based on query which could be the start of a sentence.
2.**Evaluation**: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
3.**Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate to far from the reference language model. The active language model is then trained with PPO.
TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the [🤗 Transformers](https://github.com/huggingface/transformers) ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled-up across various hardware setups.
## Highlights
- **Trainers**: Various fine-tuning methods are easily accessible via trainers like [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer) and more.
- **Efficient and scalable**:
- Leverages [🤗 Accelerate](https://github.com/huggingface/accelerate) to scale from single GPU to multi-node clusters using methods like [DDP](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) and [DeepSpeed](https://github.com/deepspeedai/DeepSpeed).
- Full integration with [🤗 PEFT](https://github.com/huggingface/peft) enables training on large models with modest hardware via quantization and LoRA/QLoRA.
- Integrates [🦥 Unsloth](https://github.com/unslothai/unsloth) for accelerating training using optimized kernels.
- **Command Line Interface (CLI)**: A simple interface lets you fine-tune with models without needing to write code.
## Installation
### Python package
Install the library with pip:
### Python Package
Install the library using `pip`:
```bash
pip install trl
```
### From source
If you want to run the examples in the repository a few additional libraries are required. Clone the repository and install it with pip:
If you want to use the latest features before an official release, you can install TRL from source:
If you wish to develop TRL, you should install in editable mode:
### Repository
If you want to use the examples you can clone the repository with the following command:
```bash
pip install -e .
git clone https://github.com/huggingface/trl.git
```
## How to use
## Quick Start
### Example
This is a basic example on how to use the library. Based on a query the language model creates a response which is then evaluated. The evaluation could be a human in the loop or another model's output.
For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.
### `SFTTrainer`
Here is a basic example of how to use the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer):
For a detailed example check out the example python script `examples/sentiment/scripts/gpt2-sentiment.py`, where GPT2 is fine-tuned to generate positive movie reviews. An few examples from the language models before and after optimisation are given below:
<pstyle="text-align: center;"><b>Figure:</b> A few review continuations before and after optimisation. </p>
</div>
[`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer) implements the [Group Relative Policy Optimization (GRPO) algorithm](https://huggingface.co/papers/2402.03300) that is more memory-efficient than PPO and was used to train [Deepseek AI's R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).
## References
```python
fromdatasetsimportload_dataset
fromtrlimportGRPOTrainer
### Proximal Policy Optimisation
The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].
The language models utilize the `transformers` library by 🤗 Hugging Face.
# Dummy reward function: count the number of unique characters in the completions
defreward_num_unique_chars(completions,**kwargs):
return[len(set(c))forcincompletions]
trainer=GRPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs=reward_num_unique_chars,
train_dataset=dataset,
)
trainer.train()
```
### `DPOTrainer`
[`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer) implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train [Llama 3](https://huggingface.co/papers/2407.21783) and many other models. Here is a basic example of how to use the `DPOTrainer`:
You can use the TRL Command Line Interface (CLI) to quickly get started with post-training methods like Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO):
Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.
## Development
If you want to contribute to `trl` or customize it to your needs make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and make sure you make a dev install:
```bash
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .[dev]
```
## Citation
```bibtex
@misc{vonwerra2022trl,
author={Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert},
author={Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
If your reward function is differentiable, directly backpropagating gradients from the reward models to the diffusion model is significantly more sample and compute efficient (25x) than doing policy gradient algorithm like DDPO.
AlignProp does full backpropagation through time, which allows updating the earlier steps of denoising via reward backpropagation.
## Getting started with `examples/scripts/alignprop.py`
The `alignprop.py` script is a working example of using the `AlignProp` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`AlignPropConfig`).
**Note:** one A100 GPU is recommended to get this running. For lower memory setting, consider setting truncated_backprop_rand to False. With default settings this will do truncated backpropagation with K=1.
Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post-finetuning to HuggingFace hub. The following bash command is to be entered to get things running
To obtain the documentation of `stable_diffusion_tuning.py`, please run `python stable_diffusion_tuning.py --help`
The following are things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script)
- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`) the first number should be equal and greater than 0, while the second number should equal or less to the number of diffusion timesteps (sample_num_steps)
- The configurable truncation backprop absolute step (`--alignprop_config.truncated_backprop_timestep=49`) the number should be less than the number of diffusion timesteps (sample_num_steps), it only matters when truncated_backprop_rand is set to False
## Setting up the image logging hook function
Expect the function to be given a dictionary with keys
```python
['image','prompt','prompt_metadata','rewards']
```
and `image`, `prompt`, `prompt_metadata`, `rewards`are batched.
You are free to log however you want the use of `wandb` or `tensorboard` is recommended.
### Key terms
-`rewards` : The rewards/score is a numerical associated with the generated image and is key to steering the RL process
-`prompt` : The prompt is the text that is used to generate the image
-`prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground answers (linked to the generated image) are expected with the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
-`image` : The image generated by the Stable Diffusion model
Example code for logging sampled images with `wandb` is given below.
This work is heavily influenced by the repo [here](https://github.com/mihirp1998/AlignProp/) and the associated paper [Aligning Text-to-Image Diffusion Models with Reward Backpropagation
by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki](https://huggingface.co/papers/2310.03739).
TRL supports the Binary Classifier Optimization (BCO).
The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0.
For a full example have a look at [`examples/scripts/bco.py`].
## Expected dataset type
The [`BCOTrainer`] requires an [unpaired preference dataset](dataset_formats#unpaired-preference).
The [`BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
## Expected model format
The BCO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
## Using the `BCOTrainer`
For a detailed example have a look at the `examples/scripts/bco.py` script. At a high level we need to initialize the `BCOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.
The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
```py
training_args=BCOConfig(
beta=0.1,
)
bco_trainer=BCOTrainer(
model,
model_ref,
args=training_args,
train_dataset=train_dataset,
processing_class=tokenizer,
)
```
After this one can then call:
```py
bco_trainer.train()
```
## Underlying Distribution matching (UDM)
In practical scenarios, the thumbs-up and thumbs-down datasets are likely to have divergent underlying distributions of prompts.
Consider an LLM deployed for user feedback: if the model excels in writing tasks but underperforms in coding, the thumbs-up dataset will be dominated by writing-related prompts, while the thumbs-down dataset will contain mostly coding-related prompts.
If the prompts in your desired and undesired datasets differ a lot, it is useful to enable UDM.
Set `prompt_sample_size` to define how many prompts are selected to train the UDM classifier and start the training with the provided embedding function:
```py
training_args=BCOConfig(
beta=0.1,
prompt_sample_size=512,
)
bco_trainer=BCOTrainer(
model,
model_ref,
args=training_args,
train_dataset=train_dataset,
processing_class=tokenizer,
embedding_func=embedding_func,
embedding_tokenizer=self.embedding_tokenizer,
)
bco_trainer.train()
```
### For Mixture of Experts Models: Enabling the auxiliary loss
MOEs are the most efficient if the load is about equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
# Best of N sampling: Alternative ways to get better model output without RL based fine-tuning
Within the extras module is the `best-of-n` sampler class that serves as an alternative method of generating better model output.
As to how it fares against the RL based fine-tuning, please look in the `examples` directory for a comparison example
## Usage
To get started quickly, instantiate an instance of the class with a model, a length sampler, a tokenizer and a callable that serves as a proxy reward pipeline that outputs reward scores for input queries
The default output is the result of taking the top scored output for each query, but you can change it to top 2 and so on by passing the `n_candidates` argument at the time of instance initialization
There is the option of setting the generation settings (like `temperature`, `pad_token_id`) at the time of instance creation as opposed to when calling the `generate` method.
This is done by passing a `GenerationConfig` from the `transformers` library at the time of initialization
Furthermore, at the time of initialization you can set the seed to control the repeatability of the generation process and the number of samples to generate for each query
You can use TRL to fine-tune your language model with methods like Supervised Fine-Tuning (SFT) or Direct Policy Optimization (DPO) using the command line interface (CLI).
Currently supported CLIs are:
#### Training commands
-`trl dpo`: fine-tune a LLM with DPO
-`trl grpo`: fine-tune a LLM with GRPO
-`trl kto`: fine-tune a LLM with KTO
-`trl sft`: fine-tune a LLM with SFT
#### Other commands
-`trl env`: get the system information
-`trl vllm-serve`: serve a model with vLLM
## Fine-tuning with the CLI
Before getting started, pick up a Language Model from Hugging Face Hub. Supported models can be found with the filter "text-generation" within models. Also make sure to pick up a relevant dataset for your task.
Before using the `sft` or `dpo` commands make sure to run:
```bash
accelerate config
```
and pick up the right configuration for your training setup (single / multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of `accelerate config` before running any CLI command.
We also recommend you passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with `trl sft` command.
```yaml
model_name_or_path:
Qwen/Qwen2.5-0.5B
dataset_name:
stanfordnlp/imdb
report_to:
none
learning_rate:
0.0001
lr_scheduler_type:
cosine
```
Save that config in a `.yaml` and get started immediately! An example CLI config is available as `examples/cli_configs/example_config.yaml`. Note you can overwrite the arguments from the config file by explicitly passing them to the CLI, e.g. from the root folder:
The chat interface is deprecated and will be removed in TRL 0.19. Use `transformers-cli chat` instead. For more information, see the [Transformers documentation, chat with text generation models](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
</Tip>
The chat CLI lets you quickly load the model and talk to it. Simply run the following:
There isn't a "best" programming language, as everyone has different style preferences, needs, and preferences. However, some people commonly use
languages like Python, Java, C++, and JavaScript, which are popular among developers for a variety of reasons, including readability, flexibility,
and scalability. Ultimately, it depends on personal preference, needs, and goals.
</code></pre>
Note that the chat interface relies on the tokenizer's [chat template](https://huggingface.co/docs/transformers/chat_templating) to format the inputs for the model. Make sure your tokenizer has a chat template defined.
Besides talking to the model there are a few commands you can use:
-`clear`: clears the current conversation and start a new one
-`example {NAME}`: load example named `{NAME}` from the config and use it as the user input
-`set {SETTING_NAME}={SETTING_VALUE};`: change the system prompt or generation settings (multiple settings are separated by a `;`).
-`reset`: same as clear but also resets the generation configs to defaults if they have been changed by `set`
-`save` or `save {SAVE_NAME}`: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
-`exit`: closes the interface
## Getting the system information
You can get the system information by running the following command:
```bash
trl env
```
This will print out the system information including the GPU information, the CUDA version, the PyTorch version, the transformers version, and the TRL version, and any optional dependencies that are installed.
```txt
Copy-paste the following information when reporting an issue:
Community tutorials are made by active members of the Hugging Face community who want to share their knowledge and expertise with others. They are a great way to learn about the library and its features, and to get started with core classes and modalities.
| Reinforcement Learning | [`GRPOTrainer`] | Post training an LLM for reasoning with GRPO in TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb) |
| Instruction tuning | [`SFTTrainer`] | Fine-tuning Google Gemma LLMs using ChatML format with QLoRA | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-google-gemma) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/gemma-lora-example.ipynb) |
| Structured Generation | [`SFTTrainer`] | Fine-tuning Llama-2-7B to generate Persian product catalogs in JSON using QLoRA and PEFT | [Mohammadreza Esmaeilian](https://huggingface.co/Mohammadreza) | [Link](https://huggingface.co/learn/cookbook/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format.ipynb) |
| Preference Optimization | [`DPOTrainer`] | Align Mistral-7b using Direct Preference Optimization for human preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/Fine_tune_Mistral_7b_with_DPO.html) | [](https://colab.research.google.com/github/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb) |
| Preference Optimization | [`ORPOTrainer`] | Fine-tuning Llama 3 with ORPO combining instruction tuning and preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html) | [](https://colab.research.google.com/drive/1eHNWg9gnaXErdAa8_mcvjMupbSS6rDvi) |
| Instruction tuning | [`SFTTrainer`] | How to fine-tune open LLMs in 2025 with Hugging Face | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-llms-in-2025) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-llms-in-2025.ipynb) |
| Visual QA | [`DPOTrainer`] | Fine-tuning SmolVLM using direct preference optimization (DPO) with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_dpo_smolvlm_instruct) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_dpo_smolvlm_instruct.ipynb) |
## Contributing
If you have a tutorial that you would like to add to this list, please open a PR to add it. We will review it and merge it if it is relevant to the community.
Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high-level, CPO trains models to avoid generating adequate, but not perfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation of the DPO loss and can be applied to other domains, such as chat.
CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.
## Quick start
This example demonstrates how to train a model using the CPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:
CPO requires a [preference dataset](dataset_formats#preference). The [`CPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
## Example script
We provide an example script to train a model using the CPO method. The script is available in [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py)
To test the CPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
```bash
accelerate launch examples/scripts/cpo.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name trl-lib/ultrafeedback_binarized \
--num_train_epochs 1\
--logging_steps 25\
--output_dir Qwen2-0.5B-CPO
```
## Logged metrics
While training and evaluating we record the following reward metrics:
*`rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
*`rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
*`rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
*`rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
*`nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses
## CPO variants
### Simple Preference Optimization (SimPO)
The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, we can use SimPO easily by turning on `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`].
### CPO-SimPO
We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more details at [CPO-SimPO GitHub](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the [`CPOConfig`].
## Loss functions
The CPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`CPOConfig`]. The following loss functions are supported:
| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike DPO which is summed only). |
### For Mixture of Experts Models: Enabling the auxiliary loss
MOEs are the most efficient if the load is about equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
TRL is designed with modularity in mind so that users to be able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques. Note: Although these examples use the DPOTrainer, the customization applies to most (if not all) trainers.
## Use different optimizers and schedulers
By default, the `DPOTrainer` creates a `torch.optim.AdamW` optimizer. You can create and define a different optimizer and pass it to `DPOTrainer` as follows:
Since `trl` supports all keyword arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning.
Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/en/peft#load-in-8bit-or-4bit).
When training large models, you should better handle the CUDA cache by iteratively clearing it. To do so, simply pass `optimize_cuda_cache=True` to `DPOConfig`:
At `trl` we provide the possibility to give enough modularity to users to be able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques.
## Run on multiple GPUs / nodes
We leverage `accelerate` to enable users to run their training on multiple GPUs or nodes. You should first create your accelerate config by simply running:
```bash
accelerate config
```
Then make sure you have selected multi-gpu / multi-node setup. You can then run your training by simply running:
```bash
accelerate launch your_script.py
```
Refer to the [examples page](https://github.com/lvwerra/trl/tree/main/examples) for more details
## Use different optimizers
By default, the `PPOTrainer` creates a `torch.optim.Adam` optimizer. You can create and define a different optimizer and pass it to `PPOTrainer`:
```python
import torch
from transformers import GPT2Tokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
You can use the new [LION optimizer from Google](https://arxiv.org/abs/2302.06675) as well, first take the source code of the optimizer definition [here](https://github.com/lucidrains/lion-pytorch/blob/main/lion_pytorch/lion_pytorch.py), and copy it so that you can import the optimizer. Make sure to initialize the optimizer by considering the trainable parameters only for a more memory efficient training:
We advise you to use the learning rate that you would use for `Adam` divided by 3 as pointed out [here](https://github.com/lucidrains/lion-pytorch#lion---pytorch). We observed an improvement when using this optimizer compared to classic Adam (check the full logs [here](https://wandb.ai/distill-bloom/trl/runs/lj4bheke?workspace=user-younesbelkada)):
Since `trl` supports all key word arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning.
Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#bitsandbytes-integration-for-int8-mixedprecision-matrix-decomposition).
</div>
```python
# 0. imports
# pip install bitsandbytes
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m')
When training large models, you should better handle the CUDA cache by iteratively clearing it. Do do so, simply pass `optimize_cuda_cache=True` to `PPOConfig`:
```python
config = PPOConfig(..., optimize_cuda_cache=True)
```
## Use correctly DeepSpeed stage 3:
A small tweak need to be added to your training script to use DeepSpeed stage 3 correctly. You need to properly initialize your reward model on the correct device using the `zero3_init_context_manager` context manager. Here is an example adapted for the `gpt2-sentiment` script:
torch.distributed package provides PyTorch natives method to distribute a network over several machines (mostly useful if there are several GPU nodes). It copies the model on each GPU, runs the forward and backward on each and then applies the mean of gradient of all GPUs for each one. If running torch 1.XX, you can call `torch.distributed.launch`, like
This guide provides an overview of the dataset formats and types supported by each trainer in TRL.
## Overview of the dataset formats and types
- The *format* of a dataset refers to how the data is structured, typically categorized as either *standard* or *conversational*.
- The *type* is associated with the specific task the dataset is designed for, such as *prompt-only* or *preference*. Each type is characterized by its columns, which vary according to the task, as shown in the table.
<table>
<tr>
<th>Type \ Format</th>
<th>Standard</th>
<th>Conversational</th>
</tr>
<tr>
<td>Language modeling</td>
<td>
<pre><code>{"text": "The sky is blue."}</code></pre>
</td>
<td>
<pre><code>{"messages": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."}]}</code></pre>
</td>
</tr>
<tr>
<td>Prompt-only</td>
<td>
<pre><code>{"prompt": "The sky is"}</code></pre>
</td>
<td>
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}]}</code></pre>
</td>
</tr>
<tr>
<td>Prompt-completion</td>
<td>
<pre><code>{"prompt": "The sky is",
"completion": " blue."}</code></pre>
</td>
<td>
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is blue."}]}</code></pre>
</td>
</tr>
</tr>
<tr>
<td>Preference</td>
<td>
<pre><code>{"prompt": "The sky is",
"chosen": " blue.",
"rejected": " green."}</code></pre>
or, with implicit prompt:
<pre><code>{"chosen": "The sky is blue.",
"rejected": "The sky is green."}</code></pre>
</td>
<td>
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
"chosen": [{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "assistant", "content": "It is green."}]}</code></pre>
or, with implicit prompt:
<pre><code>{"chosen": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is green."}]}</code></pre>
</td>
</tr>
<td>Unpaired preference</td>
<td>
<pre><code>{"prompt": "The sky is",
"completion": " blue.",
"label": True}</code></pre>
</td>
<td>
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is green."}],
"label": False}</code></pre>
</td>
</tr>
</tr>
<td>Stepwise supervision</td>
<td>
<pre><code>{"prompt": "Which number is larger, 9.8 or 9.11?",
"completions": ["The fractional part of 9.8 is 0.8.",
The standard dataset format typically consists of plain text strings. The columns in the dataset vary depending on the task. This is the format expected by TRL trainers. Below are examples of standard dataset formats for different tasks:
```python
# Language modeling
language_modeling_example={"text":"The sky is blue."}
Conversational datasets are used for tasks involving dialogues or chat interactions between users and assistants. Unlike standard dataset formats, these contain sequences of messages where each message has a `role` (e.g., `"user"` or `"assistant"`) and `content` (the message text).
```python
messages=[
{"role":"user","content":"Hello, how are you?"},
{"role":"assistant","content":"I'm doing great. How can I help you today?"},
{"role":"user","content":"I'd like to show off how chat templating works!"},
]
```
Just like standard datasets, the columns in conversational datasets vary depending on the task. Below are examples of conversational dataset formats for different tasks:
```python
# Prompt-completion
prompt_completion_example={"prompt":[{"role":"user","content":"What color is the sky?"}],
"completion":[{"role":"assistant","content":"It is blue."}]}
# Preference
preference_example={
"prompt":[{"role":"user","content":"What color is the sky?"}],
"chosen":[{"role":"assistant","content":"It is blue."}],
"rejected":[{"role":"assistant","content":"It is green."}],
}
```
Conversational datasets are useful for training chat models, but must be converted into a standard format before being used with TRL trainers. This is typically done using chat templates specific to the model being used. For more information, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
### Types
#### Language modeling
A language modeling dataset consists of a column `"text"` (or `"messages"` for conversational datasets) containing a full sequence of text.
```python
# Standard format
language_modeling_example={"text":"The sky is blue."}
# Conversational format
language_modeling_example={"messages":[
{"role":"user","content":"What color is the sky?"},
{"role":"assistant","content":"It is blue."}
]}
```
#### Prompt-only
In a prompt-only dataset, only the initial prompt (the question or partial sentence) is provided under the key `"prompt"`. The training typically involves generating the completion based on this prompt, where the model learns to continue or complete the given input.
```python
# Standard format
prompt_only_example={"prompt":"The sky is"}
# Conversational format
prompt_only_example={"prompt":[{"role":"user","content":"What color is the sky?"}]}
```
For examples of prompt-only datasets, refer to the [Prompt-only datasets collection](https://huggingface.co/collections/trl-lib/prompt-only-datasets-677ea25245d20252cea00368).
<Tip>
While both the prompt-only and language modeling types are similar, they differ in how the input is handled. In the prompt-only type, the prompt represents a partial input that expects the model to complete or continue, while in the language modeling type, the input is treated as a complete sentence or sequence. These two types are processed differently by TRL. Below is an example showing the difference in the output of the `apply_chat_template` function for each type:
# Output: {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n'}
# Example for language modeling type
lm_example={"messages":[{"role":"user","content":"What color is the sky?"}]}
apply_chat_template(lm_example,tokenizer)
# Output: {'text': '<|user|>\nWhat color is the sky?<|end|>\n<|endoftext|>'}
```
- The prompt-only output includes a `'<|assistant|>\n'`, indicating the beginning of the assistant’s turn and expecting the model to generate a completion.
- In contrast, the language modeling output treats the input as a complete sequence and terminates it with `'<|endoftext|>'`, signaling the end of the text and not expecting any additional content.
</Tip>
#### Prompt-completion
A prompt-completion dataset includes a `"prompt"` and a `"completion"`.
prompt_completion_example={"prompt":[{"role":"user","content":"What color is the sky?"}],
"completion":[{"role":"assistant","content":"It is blue."}]}
```
For examples of prompt-completion datasets, refer to the [Prompt-completion datasets collection](https://huggingface.co/collections/trl-lib/prompt-completion-datasets-677ea2bb20bbb6bdccada216).
#### Preference
A preference dataset is used for tasks where the model is trained to choose between two or more possible completions to the same prompt. This dataset includes a `"prompt"`, a `"chosen"` completion, and a `"rejected"` completion. The model is trained to select the `"chosen"` response over the `"rejected"` response.
Some dataset may not include the `"prompt"` column, in which case the prompt is implicit and directly included in the `"chosen"` and `"rejected"` completions. We recommend using explicit prompts whenever possible.
preference_example={"chosen":"The sky is blue.","rejected":"The sky is green."}
# Conversational format
## Explicit prompt (recommended)
preference_example={"prompt":[{"role":"user","content":"What color is the sky?"}],
"chosen":[{"role":"assistant","content":"It is blue."}],
"rejected":[{"role":"assistant","content":"It is green."}]}
## Implicit prompt
preference_example={"chosen":[{"role":"user","content":"What color is the sky?"},
{"role":"assistant","content":"It is blue."}],
"rejected":[{"role":"user","content":"What color is the sky?"},
{"role":"assistant","content":"It is green."}]}
```
For examples of preference datasets, refer to the [Preference datasets collection](https://huggingface.co/collections/trl-lib/preference-datasets-677e99b581018fcad9abd82c).
Some preference datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo). You can also explore the [librarian-bots' DPO Collections](https://huggingface.co/collections/librarian-bots/direct-preference-optimization-datasets-66964b12835f46289b6ef2fc) to identify preference datasets.
#### Unpaired preference
An unpaired preference dataset is similar to a preference dataset but instead of having `"chosen"` and `"rejected"` completions for the same prompt, it includes a single `"completion"` and a `"label"` indicating whether the completion is preferred or not.
unpaired_preference_example={"prompt":[{"role":"user","content":"What color is the sky?"}],
"completion":[{"role":"assistant","content":"It is blue."}],
"label":True}
```
For examples of unpaired preference datasets, refer to the [Unpaired preference datasets collection](https://huggingface.co/collections/trl-lib/unpaired-preference-datasets-677ea22bf5f528c125b0bcdf).
#### Stepwise supervision
A stepwise (or process) supervision dataset is similar to an [unpaired preference](#unpaired-preference) dataset but includes multiple steps of completions, each with its own label. This structure is useful for tasks that need detailed, step-by-step labeling, such as reasoning tasks. By evaluating each step separately and providing targeted labels, this approach helps identify precisely where the reasoning is correct and where errors occur, allowing for targeted feedback on each part of the reasoning process.
```python
stepwise_example={
"prompt":"Which number is larger, 9.8 or 9.11?",
"completions":["The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.","Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8."],
"labels":[True,False]
}
```
For examples of stepwise supervision datasets, refer to the [Stepwise supervision datasets collection](https://huggingface.co/collections/trl-lib/stepwise-supervision-datasets-677ea27fd4c5941beed7a96e).
## Which dataset type to use?
Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer.
| [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
<Tip>
TRL trainers only support standard dataset formats, [for now](https://github.com/huggingface/trl/issues/2071). If you have a conversational dataset, you must first convert it into a standard format.
For more information on how to work with conversational datasets, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
</Tip>
## Working with conversational datasets in TRL
Conversational datasets are increasingly common, especially for training chat models. However, some TRL trainers don't support conversational datasets in their raw format. (For more information, see [issue #2071](https://github.com/huggingface/trl/issues/2071).) These datasets must first be converted into a standard format.
Fortunately, TRL offers tools to easily handle this conversion, which are detailed below.
### Converting a conversational dataset into a standard dataset
To convert a conversational dataset into a standard dataset, you need to _apply a chat template_ to the dataset. A chat template is a predefined structure that typically includes placeholders for user and assistant messages. This template is provided by the tokenizer of the model you use.
For detailed instructions on using chat templating, refer to the [Chat templating section in the `transformers` documentation](https://huggingface.co/docs/transformers/en/chat_templating).
In TRL, the method you apply to convert the dataset will vary depending on the task. Fortunately, TRL provides a helper function called [`apply_chat_template`] to simplify this process. Here's an example of how to use it:
# {'prompt': ['<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n',
# '<|user|>\nWhere is the sun?<|end|>\n<|assistant|>\n'],
# 'completion': ['It is blue.<|end|>\n<|endoftext|>', 'In the sky.<|end|>\n<|endoftext|>']}
```
<Tipwarning={true}>
We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle of a conversation.
For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] is designed to handle these intricacies and ensure the correct application of chat templates for various tasks.
</Tip>
<Tipwarning={true}>
It's important to note that chat templates are model-specific. For example, if you use the chat template from [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) with the above example, you get a different output:
# {'prompt': '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n',
# 'completion': 'It is blue.<|im_end|>\n'}
```
Always use the chat template associated with the model you're working with. Using the wrong template can lead to inaccurate or unexpected results.
</Tip>
## Using any dataset with TRL: preprocessing and conversion
Many datasets come in formats tailored to specific tasks, which might not be directly compatible with TRL. To use such datasets with TRL, you may need to preprocess and convert them into the required format.
To make this easier, we provide a set of [example scripts](https://github.com/huggingface/trl/tree/main/examples/datasets) that cover common dataset conversions.
### Example: UltraFeedback dataset
Let’s take the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback) as an example. Here's a preview of the dataset:
As shown above, the dataset format does not match the expected structure. It’s not in a conversational format, the column names differ, and the results pertain to different models (e.g., Bard, GPT-4) and aspects (e.g., "helpfulness", "honesty").
By using the provided conversion script [`examples/datasets/ultrafeedback.py`](https://github.com/huggingface/trl/tree/main/examples/datasets/ultrafeedback.py), you can transform this dataset into an unpaired preference type, and push it to the Hub:
By adapting the provided scripts or creating your own, you can convert any dataset into a format compatible with TRL.
## Utilities for converting dataset types
This section provides example code to help you convert between different dataset types. While some conversions can be performed after applying the chat template (i.e., in the standard format), we recommend performing the conversion before applying the chat template to ensure it works consistently.
For simplicity, some of the examples below do not follow this recommendation and use the standard format. However, the conversions can be applied directly to the conversational format without modification.
| From \ To | Language modeling | Prompt-completion | Prompt-only | Preference with implicit prompt | Preference | Unpaired preference | Stepwise supervision |
To convert a prompt-completion dataset into a prompt-only dataset, remove the completion.
```python
fromdatasetsimportDataset
dataset=Dataset.from_dict({
"prompt":["The sky is","The sun is"],
"completion":[" blue."," in the sky."],
})
dataset=dataset.remove_columns("completion")
```
```python
>>>dataset[0]
{'prompt':'The sky is'}
```
### From preference with implicit prompt to language modeling dataset
To convert a preference with implicit prompt dataset into a language modeling dataset, remove the rejected, and rename the column `"chosen"` to `"text"`.
```python
fromdatasetsimportDataset
dataset=Dataset.from_dict({
"chosen":["The sky is blue.","The sun is in the sky."],
"rejected":["The sky is green.","The sun is in the sea."],
### From preference with implicit prompt to prompt-completion dataset
To convert a preference dataset with implicit prompt into a prompt-completion dataset, extract the prompt with [`extract_prompt`], remove the rejected, and rename the column `"chosen"` to `"completion"`.
```python
fromdatasetsimportDataset
fromtrlimportextract_prompt
dataset=Dataset.from_dict({
"chosen":[
[{"role":"user","content":"What color is the sky?"},{"role":"assistant","content":"It is blue."}],
[{"role":"user","content":"Where is the sun?"},{"role":"assistant","content":"In the sky."}],
],
"rejected":[
[{"role":"user","content":"What color is the sky?"},{"role":"assistant","content":"It is green."}],
[{"role":"user","content":"Where is the sun?"},{"role":"assistant","content":"In the sea."}],
{'prompt':[{'role':'user','content':'What color is the sky?'}],'completion':[{'role':'assistant','content':'It is blue.'}]}
```
### From preference with implicit prompt to prompt-only dataset
To convert a preference dataset with implicit prompt into a prompt-only dataset, extract the prompt with [`extract_prompt`], and remove the rejected and the chosen.
```python
fromdatasetsimportDataset
fromtrlimportextract_prompt
dataset=Dataset.from_dict({
"chosen":[
[{"role":"user","content":"What color is the sky?"},{"role":"assistant","content":"It is blue."}],
[{"role":"user","content":"Where is the sun?"},{"role":"assistant","content":"In the sky."}],
],
"rejected":[
[{"role":"user","content":"What color is the sky?"},{"role":"assistant","content":"It is green."}],
[{"role":"user","content":"Where is the sun?"},{"role":"assistant","content":"In the sea."}],
{'prompt':[{'role':'user','content':'What color is the sky?'}]}
```
### From implicit to explicit prompt preference dataset
To convert a preference dataset with implicit prompt into a preference dataset with explicit prompt, extract the prompt with [`extract_prompt`].
```python
fromdatasetsimportDataset
fromtrlimportextract_prompt
dataset=Dataset.from_dict({
"chosen":[
[{"role":"user","content":"What color is the sky?"},{"role":"assistant","content":"It is blue."}],
[{"role":"user","content":"Where is the sun?"},{"role":"assistant","content":"In the sky."}],
],
"rejected":[
[{"role":"user","content":"What color is the sky?"},{"role":"assistant","content":"It is green."}],
[{"role":"user","content":"Where is the sun?"},{"role":"assistant","content":"In the sea."}],
],
})
dataset=dataset.map(extract_prompt)
```
```python
>>>dataset[0]
{'prompt':[{'role':'user','content':'What color is the sky?'}],
'chosen':[{'role':'assistant','content':'It is blue.'}],
'rejected':[{'role':'assistant','content':'It is green.'}]}
```
### From preference with implicit prompt to unpaired preference dataset
To convert a preference dataset with implicit prompt into an unpaired preference dataset, extract the prompt with [`extract_prompt`], and unpair the dataset with [`unpair_preference_dataset`].
[{"role":"user","content":"What color is the sky?"},{"role":"assistant","content":"It is blue."}],
[{"role":"user","content":"Where is the sun?"},{"role":"assistant","content":"In the sky."}],
],
"rejected":[
[{"role":"user","content":"What color is the sky?"},{"role":"assistant","content":"It is green."}],
[{"role":"user","content":"Where is the sun?"},{"role":"assistant","content":"In the sea."}],
],
})
dataset=dataset.map(extract_prompt)
dataset=unpair_preference_dataset(dataset)
```
```python
>>>dataset[0]
{'prompt':[{'role':'user','content':'What color is the sky?'}],
'completion':[{'role':'assistant','content':'It is blue.'}],
'label':True}
```
<Tipwarning={true}>
Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad.
Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
This can be ensured by checking absolute rating of each completion, e.g. from a reward model.
</Tip>
### From preference to language modeling dataset
To convert a preference dataset into a language modeling dataset, remove the rejected, concatenate the prompt and the chosen into the `"text"` column.
### From explicit to implicit prompt preference dataset
To convert a preference dataset with explicit prompt into a preference dataset with implicit prompt, concatenate the prompt to both chosen and rejected, and remove the prompt.
```python
fromdatasetsimportDataset
dataset=Dataset.from_dict({
"prompt":[
[{"role":"user","content":"What color is the sky?"}],
{'chosen':[{'role':'user','content':'What color is the sky?'},{'role':'assistant','content':'It is blue.'}],
'rejected':[{'role':'user','content':'What color is the sky?'},{'role':'assistant','content':'It is green.'}]}
```
### From preference to unpaired preference dataset
To convert dataset into an unpaired preference dataset, unpair the dataset with [`unpair_preference_dataset`].
```python
fromdatasetsimportDataset
fromtrlimportunpair_preference_dataset
dataset=Dataset.from_dict({
"prompt":[
[{"role":"user","content":"What color is the sky?"}],
[{"role":"user","content":"Where is the sun?"}],
],
"chosen":[
[{"role":"assistant","content":"It is blue."}],
[{"role":"assistant","content":"In the sky."}],
],
"rejected":[
[{"role":"assistant","content":"It is green."}],
[{"role":"assistant","content":"In the sea."}],
],
})
dataset=unpair_preference_dataset(dataset)
```
```python
>>>dataset[0]
{'prompt':[{'role':'user','content':'What color is the sky?'}],
'completion':[{'role':'assistant','content':'It is blue.'}],
'label':True}
```
<Tipwarning={true}>
Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad.
Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
This can be ensured by checking absolute rating of each completion, e.g. from a reward model.
</Tip>
### From unpaired preference to language modeling dataset
To convert an unpaired preference dataset into a language modeling dataset, concatenate prompts with good completions into the `"text"` column, and remove the prompt, completion and label columns.
```python
fromdatasetsimportDataset
dataset=Dataset.from_dict({
"prompt":["The sky is","The sun is","The sky is","The sun is"],
"completion":[" blue."," in the sky."," green."," in the sea."],
### From stepwise supervision to unpaired preference dataset
To convert a stepwise supervision dataset into an unpaired preference dataset, join the completions and merge the labels.
The method for merging the labels depends on the specific task. In this example, we use the logical AND operation. This means that if the step labels indicate the correctness of individual steps, the resulting label will reflect the correctness of the entire sequence.
```python
fromdatasetsimportDataset
dataset=Dataset.from_dict({
"prompt":["Blue light","Water"],
"completions":[[" scatters more in the atmosphere,"," so the sky is green."],
[" forms a less dense structure in ice,"," which causes it to expand when it freezes."]],
{'prompt':'Blue light','completion':' scatters more in the atmosphere, so the sky is green.','label':False}
```
## Vision datasets
Some trainers also support fine-tuning vision-language models (VLMs) using image-text pairs. In this scenario, it's recommended to use a conversational format, as each model handles image placeholders in text differently.
A conversational vision dataset differs from a standard conversational dataset in two key ways:
1. The dataset must contain the key `images` with the image data.
2. The `"content"` field in messages must be a list of dictionaries, where each dictionary specifies the type of data: `"image"` or `"text"`.
Example:
```python
# Textual dataset:
"content":"What color is the sky?"
# Vision dataset:
"content":[
{"type":"image"},
{"type":"text","text":"What color is the sky in the image?"}
]
```
An example of a conversational vision dataset is the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset). Below is an embedded view of the dataset's training data, allowing you to explore it directly:
## Getting started with Stable Diffusion finetuning with reinforcement learning
The machinery for finetuning of Stable Diffusion models with reinforcement learning makes heavy use of HuggingFace's `diffusers`
library. A reason for stating this is that getting started requires a bit of familiarity with the `diffusers` library concepts, mainly two of them - pipelines and schedulers.
Right out of the box (`diffusers` library), there isn't a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning. Some adjustments need to be made.
There is a pipeline interface that is provided by this library that is required to be implemented to be used with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.**
There is a default implementation of this interface that you can use out of the box. Assuming the default implementation is sufficient and/or to get things moving, refer to the training example alongside this guide.
The point of the interface is to fuse the pipeline and the scheduler into one object which allows for minimalness in terms of having the constraints all in one place. The interface was designed in hopes of catering to pipelines and schedulers beyond the examples in this repository and elsewhere at this time of writing. Also the scheduler step is a method of this pipeline interface and this may seem redundant given that the raw scheduler is accessible via the interface but this is the only way to constrain the scheduler step output to an output type befitting of the algorithm at hand (DDPO).
For a more detailed look into the interface and the associated default implementation, go [here](https://github.com/lvwerra/trl/tree/main/trl/models/modeling_sd_base.py)
Note that the default implementation has a LoRA implementation path and a non-LoRA based implementation path. The LoRA flag enabled by default and this can be turned off by passing in the flag to do so. LORA based training is faster and the LORA associated model hyperparameters responsible for model convergence aren't as finicky as non-LORA based training.
Also in addition, there is the expectation of providing a reward function and a prompt function. The reward function is used to evaluate the generated images and the prompt function is used to generate the prompts that are used to generate the images.
## Getting started with `examples/scripts/ddpo.py`
The `ddpo.py` script is a working example of using the `DDPO` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`DDPOConfig`).
**Note:** one A100 GPU is recommended to get this running. Anything below a A100 will not be able to run this example script and even if it does via relatively smaller sized parameters, the results will most likely be poor.
Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post finetuning to HuggingFace hub. The following bash command is to be entered to get things running
```batch
python ddpo.py --hf_user_access_token <token>
```
To obtain the documentation of `stable_diffusion_tuning.py`, please run `python stable_diffusion_tuning.py --help`
The following are things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script)
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) should be greater than or equal to the configurable training batch size (`--ddpo_config.train_batch_size=3`)
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by the configurable train batch size (`--ddpo_config.train_batch_size=3`)
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by both the configurable gradient accumulation steps (`--ddpo_config.train_gradient_accumulation_steps=1`) and the configurable accelerator processes count
## Setting up the image logging hook function
Expect the function to be given a list of lists of the form
and `image`, `prompt`, `prompt_metadata`, `rewards`, `reward_metadata` are batched.
The last list in the lists of lists represents the last sample batch. You are likely to want to log this one
While you are free to log however you want the use of `wandb` or `tensorboard` is recommended.
### Key terms
-`rewards` : The rewards/score is a numerical associated with the generated image and is key to steering the RL process
-`reward_metadata` : The reward metadata is the metadata associated with the reward. Think of this as extra information payload delivered alongside the reward
-`prompt` : The prompt is the text that is used to generate the image
-`prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground answers (linked to the generated image) are expected with the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
-`image` : The image generated by the Stable Diffusion model
Example code for logging sampled images with `wandb` is given below.
Section under construction. Feel free to contribute!
</Tip>
TRL supports training with DeepSpeed, a library that implements advanced training optimization techniques. These include optimizer state partitioning, offloading, gradient partitioning, and more.
DeepSpeed integrates the [Zero Redundancy Optimizer (ZeRO)](https://huggingface.co/papers/1910.02054), which allows to scale the model size proportional to the number of devices with sustained high efficiency.
We provide ready-to-use DeepSpeed configuration files in the [`examples/accelerate_configs`](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) directory. For example, to run training with ZeRO Stage 2, use the following command:
Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin.
@ -4,12 +4,12 @@ Language models (LMs) are known to sometimes generate toxic outputs. In this exa
Read this section to follow our investigation on how we can reduce toxicity in a wide range of LMs, from 125m parameters to 6B parameters!
Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/lvwerra/trl/tree/main/examples/toxicity/scripts) as well as the link for the interactive demo:
Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/huggingface/trl/tree/main/examples/toxicity/scripts) as well as the link for the interactive demo:
| File | Description | Colab link |
|---|---| --- |
| [`gpt-j-6b-toxicity.py`](https://github.com/lvwerra/trl/blob/main/examples/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x |
| [`evaluate-toxicity.py`](https://github.com/lvwerra/trl/blob/main/examples/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x |
| [`gpt-j-6b-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x |
| [`evaluate-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x |
| [Interactive Space](https://huggingface.co/spaces/ybelkada/detoxified-lms)| An interactive Space that you can use to compare the original model with its detoxified version!| x |
## Context
@ -30,7 +30,7 @@ We selected the following models for our experiments to show that TRL can be eas
For the selection of the smallest model, we have chosen `EleutherAI/gpt-neo-125M` because it has shown to be a model that was the "most toxic" compared to other models. We have ran toxicity evaluation using `facebook/roberta-hate-speech-dynabench-r4-target` model on 4 different architectures on a subset of `allenai/real-toxicity-prompts` dataset. Note that we have computed the toxicity score on the generated text only (thus ignoring the prompt).
For the selection of the smallest model, we have chosen `EleutherAI/gpt-neo-125M` because it has shown to be a model that was the "most toxic" compared to other models. We have run toxicity evaluation using `facebook/roberta-hate-speech-dynabench-r4-target` model on 4 different architectures on a subset of `allenai/real-toxicity-prompts` dataset. Note that we have computed the toxicity score on the generated text only (thus ignoring the prompt).
| Model | Mean toxicity score |
|---|---|
@ -45,7 +45,7 @@ When doing PPO, it is very important to design the problem efficiently so that t
### Pre-processing the dataset
The dataset consist of prompts and their continuations, and each of them has an associated `toxicity` score.
The dataset consists of prompts and their continuations, and each of them has an associated `toxicity` score.
A `prompt` example:
```
@ -58,13 +58,13 @@ And its `continuation` value:
We want to increase the chance for the model to generate toxic prompts so we get more learning signal. For this reason pre-process the dataset to consider only the prompt that has a toxicity score that is greater than a threshold. We can do this in a few lines of code:
Our goal is to train models up to 6B parameters, which is about 24GB in float32! Here two tricks we use to be able to train a 6B model on a single 40GB-RAM GPU:
Our goal is to train models up to 6B parameters, which is about 24GB in float32! Here are two tricks we use to be able to train a 6B model on a single 40GB-RAM GPU:
- Use `bfloat16` precision: Simply load your model in `bfloat16` when calling `from_pretrained` and you can reduce the size of the model by 2:
@ -98,22 +98,18 @@ model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=
and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is a pure `bfloat16` training which is different from the mixed precision training. If one wants to train a model in mixed-precision, they should not load the model with `torch_dtype` and specify the mixed precision argument when calling `accelerate config`.
- Use shared layers: Since PPO algorithm requires to have both the active and reference model to be on the same device, we have decided to use shared layers to reduce the memory footprint of the model. This can be achieved by just speifying `num_shared_layers` argument when creating a `PPOTrainer`:
- Use shared layers: Since PPO algorithm requires to have both the active and reference model to be on the same device, we have decided to use shared layers to reduce the memory footprint of the model. This can be achieved by specifying `num_shared_layers` argument when calling the `create_reference_model()` function. For example, if you want to share the first 6 layers of the model, you can do it like this:
In the example above this means that the model have the 4 first layers frozen (i.e. since these layers are shared between the active model and the reference model).
In the example above this means that the model has the 4 first layers frozen (i.e. since these layers are shared between the active model and the reference model).
- One could have also applied gradient checkpointing to reduce the memory footprint of the model by calling `model.pretrained_model.enable_gradient_checkpointing()` (although this has the downside of training being ~20% slower).
@ -128,13 +124,13 @@ We have decided to keep 3 models in total that correspond to our best models:
We have used different learning rates for each model, and have found out that the largest models were quite hard to train and can easily lead to collapse mode if the learning rate is not chosen correctly (i.e. if the learning rate is too high):
As you can see the model converges nicely, but obviously we don't observe a very large improvement from the first step, as the original model is not trained to generate toxic contents.
@ -142,7 +138,7 @@ As you can see the model converges nicely, but obviously we don't observe a very
Also we have observed that training with larger `mini_batch_size` leads to smoother convergence and better results on the test set:
The evaluation script can be found [here](https://github.com/lvwerra/trl/blob/main/examples/toxicity/scripts/evaluate-toxicity.py).
The evaluation script can be found [here](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py).
### Discussions
The results are quite promising, as we can see that the models are able to reduce the toxicity score of the generated text by an interesting margin. The gap is clear for `gpt-neo-2B` model but we less so for the `gpt-j-6B` model. There are several things we could try to improve the results on the largest model starting with training with larger `mini_batch_size` and probably allowing to back-propagate through more layers (i.e. use less shared layers).
To sum up, in addition to human feedback this could be a useful additional signal when training large language models to ensure there outputs are less toxic as well as useful.
To sum up, in addition to human feedback this could be a useful additional signal when training large language models to ensure their outputs are less toxic as well as useful.
Section under construction. Feel free to contribute!
</Tip>
## Multi-GPU Training with TRL
The trainers in TRL use [🤗 Accelerate](https://github.com/huggingface/accelerate) to enable distributed training across multiple GPUs or nodes. To do so, first create an [🤗 Accelerate](https://github.com/huggingface/accelerate) config file by running
```bash
accelerate config
```
and answering the questions according to your multi-GPU / multi-node setup. You can then launch distributed training by running:
```bash
accelerate launch train.py
```
We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.:
To maintain a consistent batch size when scaling to multiple GPUs, make sure to update `per_device_train_batch_size` and `gradient_accumulation_steps` accordingly.
Example, these configurations are equivalent, and should yield the same results:
| Number of GPUs | Per device batch size | Gradient accumulation steps | Comments |
| --- | --- | --- | --- |
| 1 | 32 | 1 | Possibly high memory usage, but faster training |
| 8 | 4 | 1 | Multi-GPU to get the best of both worlds |
<Tip>
Having one model per GPU can lead to high memory usage, which may not be feasible for large models or low-memory GPUs. In such cases, you can leverage [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), which provides optimizations like model sharding, Zero Redundancy Optimizer, mixed precision training, and offloading to CPU or NVMe. Check out our [DeepSpeed Integration](deepspeed_integration.md) guide for more details.
</Tip>
## Multi-Nodes Training
We're working on a guide for multi-node training. Stay tuned! 🚀
TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://huggingface.co/papers/2305.18290) by [Rafael Rafailov](https://huggingface.co/rmrafailov), Archit Sharma, Eric Mitchell, [Stefano Ermon](https://huggingface.co/ermonste), [Christopher D. Manning](https://huggingface.co/manning), [Chelsea Finn](https://huggingface.co/cbfinn).
The abstract from the paper is the following:
> While large-scale unsupervised language models (LMs) learn broad world knowledge and some reasoning skills, achieving precise control of their behavior is difficult due to the completely unsupervised nature of their training. Existing methods for gaining such steerability collect human labels of the relative quality of model generations and fine-tune the unsupervised LM to align with these preferences, often with reinforcement learning from human feedback (RLHF). However, RLHF is a complex and often unstable procedure, first fitting a reward model that reflects the human preferences, and then fine-tuning the large unsupervised LM using reinforcement learning to maximize this estimated reward without drifting too far from the original model. In this paper we introduce a new parameterization of the reward model in RLHF that enables extraction of the corresponding optimal policy in closed form, allowing us to solve the standard RLHF problem with only a simple classification loss. The resulting algorithm, which we call Direct Preference Optimization (DPO), is stable, performant, and computationally lightweight, eliminating the need for sampling from the LM during fine-tuning or performing significant hyperparameter tuning. Our experiments show that DPO can fine-tune LMs to align with human preferences as well as or better than existing methods. Notably, fine-tuning with DPO exceeds PPO-based RLHF in ability to control sentiment of generations, and matches or improves response quality in summarization and single-turn dialogue while being substantially simpler to implement and train.
The first step is to train an SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.
Then, fine-tuning a language model via DPO consists of two steps and is easier than [PPO](ppo_trainer):
1.**Data collection**: Gather a [preference dataset](dataset_formats#preference) with positive and negative selected pairs of generation, given a prompt.
2.**Optimization**: Maximize the log-likelihood of the DPO loss directly.
This process is illustrated in the sketch below (from [Figure 1 of the DPO paper](https://huggingface.co/papers/2305.18290)):
Read more about DPO algorithm in the [original paper](https://huggingface.co/papers/2305.18290).
## Quick start
This example demonstrates how to train a model using the DPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:
Distributed across 8 GPUs, the training takes approximately 3 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-DPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
The best programming language for specific applications can vary depending on the use case and knowledge level of the programmer. Here are some general factors that can be used as input to choose the best programming language:
<strong><spanstyle="color: green;">1</span></strong> Ease of use: Some programming languages are more user-friendly than others, such as Python, Java, or Ruby. Python is popular due to its simplicity and great scalability.
<strong><spanstyle="color: green;">2</span></strong> Versatility: The ability to work with a wide range of data structures and frameworks can define the language as versatile.
<strong><spanstyle="color: green;">3</span></strong> Ease of learning: Different programming languages have different learning curves, so users must be willing to take some time to master one.
<strong><spanstyle="color: green;">4</span></strong> Community support: The broader community of developers and enthusiasts in the selected programming language can provide great support and resources.
<strong><spanstyle="color: green;">5</span></strong> Reusability: Languages that emphasize code reuse and can be easily modifiable can be more suitable for software development.
The best programming language based on these factors is subjective and depends on what the programmer intends to accomplish.
</code></pre>
## Expected dataset type
DPO requires a [preference dataset](dataset_formats#preference). The [`DPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
Although the [`DPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section.
### Special considerations for vision-language models
The [`DPOTrainer`] supports fine-tuning vision-language models (VLMs). For these models, a vision dataset is required. To learn more about the specific format for vision datasets, refer to the [Vision dataset format](dataset_formats#vision-datasets) section.
Additionally, unlike standard text-based models where a `tokenizer` is used, for VLMs, you should replace the `tokenizer` with a `processor`.
```diff
- model = AutoModelForCausalLM.from_pretrained(model_id)
+ model = AutoModelForVision2Seq.from_pretrained(model_id)
For a complete example of fine-tuning a vision-language model, refer to the script in [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py).
## Example script
We provide an example script to train a model using the DPO method. The script is available in [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py)
To test the DPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
```bash
accelerate launch trl/scripts/dpo.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name trl-lib/ultrafeedback_binarized \
--num_train_epochs 1\
--logging_steps 25\
--output_dir Qwen2-0.5B-DPO
```
## Logged metrics
While training and evaluating we record the following reward metrics:
-`rewards/chosen`: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by beta
-`rewards/rejected`: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by beta
-`rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
-`rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
## Loss functions
The DPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`DPOConfig`]. The following loss functions are supported:
| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike DPO which is summed only). |
| `"exo_pair"` | The [EXO](https://huggingface.co/papers/2402.00856) authors propose to minimize the reverse KL instead of the negative log-sigmoid loss of DPO which corresponds to forward KL. Setting non-zero `label_smoothing` (default `1e-3`) leads to a simplified version of EXO on pair-wise preferences (see Eqn. (16) of the [EXO paper](https://huggingface.co/papers/2402.00856)). The full version of EXO uses `K>2` completions generated by the SFT policy, which becomes an unbiased estimator of the PPO objective (up to a constant) when `K` is sufficiently large. |
| `"nca_pair"` | The [NCA](https://huggingface.co/papers/2402.05369) authors shows that NCA optimizes the absolute likelihood for each response rather than the relative likelihood. |
| `"robust"` | The [Robust DPO](https://huggingface.co/papers/2403.00409) authors propose an unbiased estimate of the DPO loss that is robust to preference noise in the data. Like in cDPO, it assumes that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0) |
| `"bco_pair"` | The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. For unpaired data, we recommend the dedicated [`BCOTrainer`]. |
| `"sppo_hard"` | The [SPPO](https://huggingface.co/papers/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2 and can alleviate data sparsity issues. The implementation approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser. |
| `"aot"` or `loss_type="aot_pair"` | The [AOT](https://huggingface.co/papers/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, the alignment algorithms use paired preferences at a sample level, which does not ensure alignment on the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order on the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size. |
| `"apo_zero"` or `loss_type="apo_down"` | The [APO](https://huggingface.co/papers/2408.06266) method introduces an "anchored" version of the alignment objective. There are two variants: `apo_zero` and `apo_down`. The `apo_zero` loss increases the likelihood of winning outputs while decreasing the likelihood of losing outputs, making it suitable when the model is less performant than the winning outputs. On the other hand, `apo_down` decreases the likelihood of both winning and losing outputs, but with a stronger emphasis on reducing the likelihood of losing outputs. This variant is more effective when the model is better than the winning outputs. |
| `"discopop"` | The [DiscoPOP](https://huggingface.co/papers/2406.08414) paper uses LLMs to discover more efficient offline preference optimization losses. In the paper the proposed DiscoPOP loss (which is a log-ratio modulated loss) outperformed other optimization losses on different tasks (IMDb positive text generation, Reddit TLDR summarization, and Alpaca Eval 2.0). |
### Label smoothing
The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0).
### Syncing the reference model
The [TR-DPO](https://huggingface.co/papers/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD with weight `ref_model_mixup_alpha` during DPO training. To toggle this callback use the `sync_ref_model=True` in the [`DPOConfig`].
### RPO loss
The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterative preference tuning algorithm using a loss related to the RPO loss in this [paper](https://huggingface.co/papers/2405.16436) that essentially consists of a weighted SFT loss on the chosen preferences together with the DPO loss. To use this loss, set the `rpo_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this weight to `1.0`.
### WPO loss
The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`].
### For Mixture of Experts Models: Enabling the auxiliary loss
MOEs are the most efficient if the load is about equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
## Accelerate DPO fine-tuning using `unsloth`
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks for DPO listed below:
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
```diff
from datasets import load_dataset
from trl import DPOConfig, DPOTrainer
- from transformers import AutoModelForCausalLM, AutoTokenizer
+ from unsloth import FastLanguageModel
- model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).
## Reference model considerations with PEFT
You have three main options (plus several variants) for how the reference model works when using PEFT, assuming the model that you would like to further enhance with DPO was tuned using (Q)LoRA.
1. Simply create two instances of the model, each loading your adapter - works fine but is very inefficient.
2. Merge the adapter into the base model, create another adapter on top, then leave the `ref_model` param null, in which case DPOTrainer will unload the adapter for reference inference - efficient, but has potential downsides discussed below.
3. Load the adapter twice with different names, then use `set_adapter` during training to swap between the adapter being DPO'd and the reference adapter - slightly less efficient compared to 2 (~adapter size VRAM overhead), but avoids the pitfalls.
### Downsides to merging QLoRA before DPO (approach 2)
As suggested by [Benjamin Marie](https://medium.com/@bnjmn_marie/dont-merge-your-lora-adapter-into-a-4-bit-llm-65b6da287997), the best option for merging QLoRA adapters is to first dequantize the base model, then merge the adapter. Something similar to [this script](https://github.com/jondurbin/qlora/blob/main/qmerge.py).
However, after using this approach, you will have an unquantized base model. Therefore, to use QLoRA for DPO, you will need to re-quantize the merged model or use the unquantized merge (resulting in higher memory demand).
### Using option 3 - load the adapter twice
To avoid the downsides with option 2, you can load your fine-tuned adapter into the model twice, with different names, and set the model/ref adapter names in [`DPOTrainer`].
For example:
```python
# Load the base model.
bnb_config=BitsAndBytesConfig(
load_in_4bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
model=AutoModelForCausalLM.from_pretrained(
"mistralai/mixtral-8x7b-v0.1",
load_in_4bit=True,
quantization_config=bnb_config,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
device_map="auto",
)
model.config.use_cache=False
# Load the adapter.
model=PeftModel.from_pretrained(
model,
"/path/to/peft",
is_trainable=True,
adapter_name="train",
)
# Load the adapter a second time, with a different name, which will be our reference model.
- fp16 (mixed-precision), fp32 (normal precision), or bf16 (bfloat16 precision)
To run it in each of these various modes, first initialize the accelerate
configuration with `accelerate config`
**NOTE to train with a 4-bit or 8-bit model**, please run
```bash
pip install --upgrade trl[quantization]
```
## Accelerate Config
For all the examples, you'll need to generate a 🤗 Accelerate config file with:
```shell
accelerate config # will prompt you to define the training configuration
```
Then, it is encouraged to launch jobs with `accelerate launch`!
# Maintained Examples
Scripts can be used as examples of how to use TRL trainers. They are located in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) directory. Additionally, we provide examples in the [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directory. These examples are maintained and tested regularly.
| File | Description |
| --- | --- |
| [`examples/scripts/alignprop.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/alignprop.py) | This script shows how to use the [`AlignPropTrainer`] to fine-tune a diffusion model. |
| [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. |
| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/ddpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ddpo.py) | This script shows how to use the [`DDPOTrainer`] to fine-tune a stable diffusion model using reinforcement learning. |
| [`examples/scripts/dpo_online.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_online.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. |
| [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. |
| [`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py) | This script shows how to use the [`GKDTrainer`] to fine-tune a model. |
| [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`NashMDTrainer`] to fine-tune a model. |
| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language |
| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
| [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). |
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train a Outcome Reward Model (ORM) on your own dataset. |
| [`examples/scripts/rloo/rloo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo/rloo.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language |
| [`examples/scripts/rloo/rloo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo/rloo_tldr.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
| [`examples/scripts/sft_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model. |
| [`examples/scripts/sft_video_llm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_video_llm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Video Language Model. |
| [`examples/scripts/sft_vlm_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model on vision to text tasks. |
| [`examples/scripts/sft_vlm_smol_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_smol_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a SmolVLM model. |
| [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models so users may see unexpected behaviour in other model architectures. |
| [`examples/scripts/xpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/xpo.py) | This script shows how to use the [`XPOTrainer`] to fine-tune a model. |
Here are also some easier-to-run colab notebooks that you can use to get started with TRL:
| File | Description |
| --- | --- |
| [`examples/notebooks/best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb) | This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. |
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. |
We also have some other examples that are less maintained but can be used as a reference:
1.**[research_projects](https://github.com/huggingface/trl/tree/main/examples/research_projects)**: Check out this folder to find the scripts used for some research projects that used TRL (LM de-toxification, Stack-Llama, etc.)
## Distributed training
All of the scripts can be run on multiple GPUs by providing the path of an 🤗 Accelerate config file when calling `accelerate launch`. To launch one of them on one or multiple GPUs, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine and `--all_arguments_of_the_script` with your arguments.)
You can also adjust the parameters of the 🤗 Accelerate config file to suit your needs (e.g. training in mixed precision).
### Distributed training with DeepSpeed
Most of the scripts can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine, `--all_arguments_of_the_script` with your arguments, and `--deepspeed_config` with the path to the DeepSpeed config file such as `examples/deepspeed_configs/deepspeed_zero1.yaml`):
Generalized Knowledge Distillation (GKD) was proposed in [On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes](https://huggingface.co/papers/2306.13649) by Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem.
The abstract from the paper is the following:
> Knowledge distillation (KD) is widely used for compressing a teacher model to reduce its inference cost and memory footprint, by training a smaller student model. However, current KD methods for auto-regressive sequence models suffer from distribution mismatch between output sequences seen during training and those generated by the student during inference. To address this issue, we introduce Generalized Knowledge Distillation (GKD). Instead of solely relying on a fixed set of output sequences, GKD trains the student on its self-generated output sequences by leveraging feedback from the teacher on such sequences. Unlike supervised KD approaches, GKD also offers the flexibility to employ alternative loss functions between the student and teacher, which can be useful when the student lacks the expressivity to mimic the teacher's distribution. Furthermore, GKD facilitates the seamless integration of distillation with RL fine-tuning (RLHF). We demonstrate the efficacy of GKD for distilling auto-regressive language models on summarization, translation, and arithmetic reasoning tasks, and task-agnostic distillation for instruction-tuning.
The key aspects of GKD are:
1. It addresses the train-inference distribution mismatch in auto-regressive sequence models by training the student model on its self-generated output sequences.
2. GKD allows flexibility in choosing different divergence measures between student and teacher models via the generalized Jensen-Shannon Divergence (JSD), which can be useful when the student lacks the capacity to fully mimic the teacher.
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif) and [Lewis Tunstall](https://huggingface.co/lewtun).
## Usage tips
The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs three parameters to be set via the [`GKDConfig`] namely:
*`lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and token-specific feedback on these sequences from the teacher. For values in between [0, 1] it is random between the two based on the `lmbda` value for each batch.
*`seq_kd`: controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated out). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher.
*`beta`: controls the interpolation in the generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while for `beta=1.0` the loss approximates reverse KL divergence. For values in between [0, 1] it interpolates between the two.
The authors find that on-policy data (high `lmbda`) performs better and the optimal `beta` varied depending on the task and evaluation method.
> [!WARNING]
> Make sure that `attn_implementation="flash_attention_2"` when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture.
TRL supports the GRPO Trainer for training language models, as described in the paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300) by [Zhihong Shao](https://huggingface.co/syhia), [Peiyi Wang](https://huggingface.co/peiyiwang89), [Qihao Zhu](https://huggingface.co/zqh11), Runxin Xu, [Junxiao Song](https://huggingface.co/haha-point), Mingchuan Zhang, Y. K. Li, Y. Wu, [Daya Guo](https://huggingface.co/guoday).
The abstract from the paper is the following:
> Mathematical reasoning poses a significant challenge for language models due to its complex and structured nature. In this paper, we introduce DeepSeekMath 7B, which continues pre-training DeepSeek-Coder-Base-v1.5 7B with 120B math-related tokens sourced from Common Crawl, together with natural language and code data. DeepSeekMath 7B has achieved an impressive score of 51.7% on the competition-level MATH benchmark without relying on external toolkits and voting techniques, approaching the performance level of Gemini-Ultra and GPT-4. Self-consistency over 64 samples from DeepSeekMath 7B achieves 60.9% on MATH. The mathematical reasoning capability of DeepSeekMath is attributed to two key factors: First, we harness the significant potential of publicly available web data through a meticulously engineered data selection pipeline. Second, we introduce Group Relative Policy Optimization (GRPO), a variant of Proximal Policy Optimization (PPO), that enhances mathematical reasoning abilities while concurrently optimizing the memory usage of PPO.
This post-training method was contributed by [Quentin Gallouédec](https://huggingface.co/qgallouedec).
## Quick start
This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [TLDR dataset](https://huggingface.co/datasets/trl-lib/tldr) (completion column is ignored!). You can view the data in the dataset here:
GRPO is an online learning algorithm, meaning it improves iteratively by using the data generated by the trained model itself during training. The intuition behind GRPO objective is to maximize the advantage of the generated completions, while ensuring that the model remains close to the reference policy. To understand how GRPO works, it can be broken down into four main steps: **Generating completions**, **computing the advantage**, **estimating the KL divergence**, and **computing the loss**.
At each training step, we sample a batch of prompts and generate a set of \\( G \\) completions for each prompt (denoted as \\( o_i \\)).
### Computing the advantage
For each of the \\( G \\) sequences, we compute the reward using a reward model. To align with the comparative nature of reward models—typically trained on datasets of comparisons between outputs for the same question—the advantage is calculated to reflect these relative comparisons. It is normalized as follows:
This approach gives the method its name: **Group Relative Policy Optimization (GRPO)**.
<Tip>
It was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that scaling by \\( \text{std}(\mathbf{r}) \\) may cause a question-level difficulty bias. You can disable this scaling by setting `scale_rewards=False` in [`GRPOConfig`].
</Tip>
### Estimating the KL divergence
KL divergence is estimated using the approximator introduced by [Schulman et al. (2020)](http://joschu.net/blog/kl-approx.html). The approximator is defined as follows:
Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we don't scale by \\( \frac{1}{|o_i|} \\) because it was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that this introduces a response-level length bias. More details in [loss types](#loss-types).
</Tip>
In the original paper, this formulation is generalized to account for multiple updates after each generation (denoted \\( \mu \\), can be set with `num_iterations` in [`GRPOConfig`]) by leveraging the **clipped surrogate objective**:
-`clip_ratio/high_max`: The maximum ratio of token probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\).
## Customization
### Speed up training with vLLM-powered generation
Generation is often the main bottleneck that makes training slow with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation. To enable it, first install the package with
```shell
pip install trl[vllm]
```
Then, start the vLLM server with the desired model:
```bash
trl vllm-serve --model <model_name>
```
Then, pass `use_vllm=True` in the training arguments and run the training script:
```python
fromtrlimportGRPOConfig
training_args=GRPOConfig(...,use_vllm=True)
```
For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).
### GRPO at scale: train a 70B+ Model on multiple nodes
When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include:
- **DeepSpeed ZeRO Stage 3**: ZeRO leverages data parallelism to distribute model states (weights, gradients, optimizer states) across multiple GPUs and CPUs, reducing memory and compute requirements on each device. Since large models cannot fit on a single GPU, using ZeRO Stage 3 is required for training such model. For more details, see [DeepSpeed Integration](deepspeed_integration).
- **Accelerate**: Accelerate is a library that simplifies distributed training across multiple GPUs and nodes. It provides a simple API to launch distributed training and handles the complexities of distributed training, such as data parallelism, gradient accumulation, and distributed data loading. For more details, see [Distributing Training](distributing_training).
- **vLLM**: See the previous section on how to use vLLM to speed up generation.
Below is an example SLURM script to train a 70B model with GRPO on multiple nodes. This script trains a model on 4 nodes and uses the 5th node for vLLM-powered generation.
```sh
#!/bin/bash
#SBATCH --nodes=5
#SBATCH --gres=gpu:8
# Get the list of allocated nodes
NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
# Assign the first 4 nodes for training and the 5th node for vLLM
TRAIN_NODES="${NODELIST[@]:0:4}"# Nodes 0, 1, 2, 3 for training
The [`GRPOTrainer`] supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements:
1.**Input arguments**:
- The function must accept the following as keyword arguments:
-`prompts` (contains the prompts),
-`completions` (contains the generated completions),
- All columns names (but `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.
The easiest way to comply with this requirement is to use `**kwargs` in the function signature.
- Depending on the dataset format, the input will vary:
- For [standard format](dataset_formats#standard), `prompts` and `completions` will be lists of strings.
- For [conversational format](dataset_formats#conversational), `prompts` and `completions` will be lists of message dictionaries.
2.**Return value**: The function must return a list of floats. Each float represents the reward corresponding to a single completion.
#### Example 1: Reward longer completions
Below is an example of a reward function for a standard format that rewards longer completions:
```python
defreward_func(completions,**kwargs):
"""Reward function that gives higher scores to longer completions."""
#### Example 2: Reward completions with specific format
Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the _format reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948).
It is designed for conversational format, where prompts and completions consist of structured messages.
```python
importre
defformat_reward_func(completions,**kwargs):
"""Reward function that checks if the completion has a specific format."""
#### Example 3: Reward completions based on a reference
Below is an example of a reward function that checks if the completion is correct. This example is inspired by the _accuracy reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948).
This example is designed for [standard format](dataset_formats#standard), where the dataset contains a column named `ground_truth`.
Below is an example of using multiple reward functions in the [`GRPOTrainer`]. In this example, we define two task-specific reward functions: `math_reward_func` and `coding_reward_func`. The `math_reward_func` rewards math problems based on their correctness, while the `coding_reward_func` rewards coding problems based on whether the solution works.
```python
fromdatasetsimportDataset
fromtrlimportGRPOTrainer
# Define a dataset that contains both math and coding problems
dataset=Dataset.from_list(
[
{"prompt":"What is 2+2?","task":"math"},
{"prompt":"Write a function that returns the sum of two numbers.","task":"code"},
{"prompt":"What is 3*4?","task":"math"},
{"prompt":"Write a function that returns the product of two numbers.","task":"code"},
In this example, the `math_reward_func` and `coding_reward_func` are designed to work with a mixed dataset that contains both math and coding problems. The `task` column in the dataset is used to determine which reward function to apply to each problem. If there is no relevant reward function for a sample in the dataset, the reward function will return `None` and the [`GRPOTrainer`] will continue with the valid functions and tasks. This allows the [`GRPOTrainer`] to handle multiple reward functions with different applicability.
Note that the [`GRPOTrainer`] will ignore the `None` rewards returned by the reward functions and only consider the rewards returned by the relevant functions. This ensures that the model is trained on the relevant tasks and ignores the tasks for which there is no relevant reward function.
#### Passing the reward function to the trainer
To use your custom reward function, pass it to the [`GRPOTrainer`] as follows:
```python
fromtrlimportGRPOTrainer
trainer=GRPOTrainer(
reward_funcs=reward_func,
...,
)
```
If you have multiple reward functions, you can pass them as a list:
```python
fromtrlimportGRPOTrainer
trainer=GRPOTrainer(
reward_funcs=[reward_func1,reward_func2],
...,
)
```
and the reward will be computed as the sum of the rewards from each function, or the weighted sum if `reward_weights` is provided in the config.
Note that [`GRPOTrainer`] supports multiple reward functions of different types. See the parameters documentation for more details.
When performing classical supervised fine-tuning of language models, the loss (especially the validation loss) serves as a good indicator of the training progress. However, in Reinforcement Learning (RL), the loss becomes less informative about the model's performance, and its value may fluctuate while the actual performance improves.
To address this, we recommend focusing on two key metrics first:
**Mean Reward**: The primary goal is to maximize the reward achieved by the model during RL training.
**Objective KL Divergence**: KL divergence (Kullback-Leibler divergence) measures the dissimilarity between two probability distributions. In the context of RL training, we use it to quantify the difference between the current model and a reference model. Ideally, we want to keep the KL divergence between 0 and 10 to ensure the model's generated text remains close to what the reference model produces.
However, there are more metrics that can be useful for debugging, checkout the [logging section](logging).
## Why Do We Use a Reference Model, and What's the Purpose of KL Divergence?
When training RL models, optimizing solely for reward may lead to unexpected behaviors, where the model exploits the environment in ways that don't align with good language generation. In the case of RLHF, we use a reward model trained to predict whether a generated text is highly ranked by humans.
However, the RL model being optimized against the reward model may learn patterns that yield high reward but do not represent good language. This can result in extreme cases where the model generates texts with excessive exclamation marks or emojis to maximize the reward. In some worst-case scenarios, the model may generate patterns completely unrelated to natural language yet receive high rewards, similar to adversarial attacks.
<pstyle="text-align: center;"><b>Figure:</b> Samples without a KL penalty from <ahref="https://huggingface.co/papers/1909.08593">https://huggingface.co/papers/1909.08593</a>. </p>
</div>
To address this issue, we add a penalty to the reward function based on the KL divergence between the current model and the reference model. By doing this, we encourage the model to stay close to what the reference model generates.
## What Is the Concern with Negative KL Divergence?
If you generate text by purely sampling from the model distribution things work fine in general. But when you use the `generate` method there are a few caveats because it does not always purely sample depending on the settings which can cause KL-divergence to go negative. Essentially when the active model achieves `log_p_token_active < log_p_token_ref` we get negative KL-div. This can happen in a several cases:
- **top-k sampling**: the model can smooth out the probability distribution causing the top-k tokens having a smaller probability than those of the reference model but they still are selected
- **min_length**: this ignores the EOS token until `min_length` is reached. thus the model can assign a very low log prob to the EOS token and very high probs to all others until min_length is reached
These are just a few examples. Why is negative KL an issue? The total reward `R` is computed `R = r - beta * KL` so if the model can learn how to drive KL-divergence negative it effectively gets a positive reward. In many cases it can be much easier to exploit such a bug in the generation than actually learning the reward function. In addition the KL can become arbitrarily small thus the actual reward can be very small compared to it.
So how should you generate text for PPO training? Let's have a look!
## How to generate text for training?
In order to avoid the KL issues described above we recommend to use the following settings:
```python
generation_kwargs={
"min_length":-1,# don't ignore the EOS token (see above)
"top_k":0.0,# no top-k sampling
"top_p":1.0,# no nucleus sampling
"do_sample":True,# yes, we want to sample
"pad_token_id":tokenizer.eos_token_id,# most decoder models don't have a padding token - use EOS token instead
"max_new_tokens":32,# specify how many tokens you want to generate at most
}
```
With these settings we usually don't encounter any issues. You can also experiments with other settings but if you encounter issues with negative KL-divergence try to go back to these and see if they persist.
## How can debug your own use-case?
Debugging the RL pipeline can be challenging due to its complexity. Here are some tips and suggestions to make the process easier:
- **Start from a working example**: Begin with a working example from the trl repository and gradually modify it to fit your specific use-case. Changing everything at once can make it difficult to identify the source of potential issues. For example, you can start by replacing the model in the example and once you figure out the best hyperparameters try to switch to your dataset and reward model. If you change everything at once you won't know where a potential problem comes from.
- **Start small, scale later**: Training large models can be very slow and take several hours or days until you see any improvement. For debugging this is not a convenient timescale so try to use small model variants during the development phase and scale up once that works. That being said you sometimes have to be careful as small models might not have the capacity to solve a complicated task either.
- **Start simple**: Try to start with a minimal example and build complexity from there. Your use-case might require for example a complicated reward function consisting of many different rewards - try to use one signal first and see if you can optimize that and then add more complexity after that.
- **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a bug in your post-processing or your prompt. Due to bad settings you might cut-off generations too soon. These things are very hard to see on the metrics but very obvious if you look at the generations.
- **Inspect the reward model**: If you reward is not improving over time maybe there's an issue with the reward model. You can look at extreme cases to see if it does what it should: e.g. in the sentiment case you can check if simple positive and negative examples really get different rewards. And you can look at the distribution of your dataset. Finally, maybe the reward is dominated by the query which the model can't affect so you might need to normalize this (e.g. reward of query+response minus reward of the query).
These are just a few tips that we find helpful - if you have more useful tricks feel free to open a PR to add them as well!
TRL is a full stack library where we provide a set of tools to train transformer language models with methods like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO), Reward Modeling, and more.
The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers).
You can also explore TRL-related models, datasets, and demos in the [TRL Hugging Face organization](https://huggingface.co/trl-lib).
## Learn
Learn post-training with TRL and other libraries in 🤗 [smol course](https://github.com/huggingface/smol-course).
## Contents
The documentation is organized into the following sections:
- **Getting Started**: installation and quickstart guide.
- **Conceptual Guides**: dataset formats, training FAQ, and understanding logs.
- **How-to Guides**: reducing memory usage, speeding up training, distributing training, etc.
- **Integrations**: DeepSpeed, Liger Kernel, PEFT, etc.
- **Examples**: example overview, community tutorials, etc.
With the TRL (Transformer Reinforcement Learning) library you can train transformer language models with reinforcement learning. The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers).
TRL supports decoder models such as GPT-2, BLOOM, GPT-Neo which can all be optimized using Proximal Policy Optimization (PPO). You can find installation instructions in the [installation guide](installation) and an introduction to the library in the [Quickstart section](quickstart). There is also a more [in-depth example](sentiment_tuning) to tune GPT-2 to produce positive movie reviews.
You can install TRL either from PyPI or from source:
## PyPI
Install the library with pip or [uv](https://docs.astral.sh/uv/):
<hfoptionsid="install">
<hfoptionid="uv">
uv is a fast Rust-based Python package and project manager. Refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions), .
```bash
uv pip install trl
```
</hfoption>
<hfoptionid="pip">
```bash
pip install trl
```
</hfoption>
</hfoptions>
## Source
You can also install the latest version from source. First clone the repo and then run the installation with `pip`:
```bash
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .
```
If you want the development install you can replace the pip install with the following:
Iterative fine-tuning is a training method that enables to perform custom actions (generation and filtering for example) between optimization steps. In TRL we provide an easy-to-use API to fine-tune your models in an iterative way in just a few lines of code.
## Usage
To get started quickly, instantiate an instance a model, and a tokenizer.
You have the choice to either provide a list of strings or a list of tensors to the step function.
#### Using a list of tensors as input:
```python
inputs={
"input_ids":input_ids,
"attention_mask":attention_mask
}
trainer.step(**inputs)
```
#### Using a list of strings as input:
```python
inputs={
"texts":texts
}
trainer.step(**inputs)
```
For causal language models, labels will automatically be created from input_ids or from texts. When using sequence to sequence models you will have to provide your own labels or text_labels.
TRL Judges is an experimental API which is subject to change at any time.
</Tip>
TRL provides judges to easily compare two completions.
Make sure to have installed the required dependencies by running:
```bash
pip install trl[judges]
```
## Using the provided judges
TRL provides several judges out of the box. For example, you can use the `HfPairwiseJudge` to compare two completions using a pre-trained model from the Hugging Face model hub:
```python
fromtrlimportHfPairwiseJudge
judge=HfPairwiseJudge()
judge.judge(
prompts=["What is the capital of France?","What is the biggest planet in the solar system?"],
To define your own judge, we provide several base classes that you can subclass. For rank-based judges, you need to subclass [`BaseRankJudge`] and implement the [`BaseRankJudge.judge`] method. For pairwise judges, you need to subclass [`BasePairJudge`] and implement the [`BasePairJudge.judge`] method. If you want to define a judge that doesn't fit into these categories, you need to subclass [`BaseJudge`] and implement the [`BaseJudge.judge`] method.
As an example, let's define a pairwise judge that prefers shorter completions:
Kahneman-Tversky Optimization (KTO) was introduced in [KTO: Model Alignment as Prospect Theoretic Optimization](https://huggingface.co/papers/2402.01306) by [Kawin Ethayarajh](https://huggingface.co/kawine), [Winnie Xu](https://huggingface.co/xwinxu), [Niklas Muennighoff](https://huggingface.co/Muennighoff), Dan Jurafsky, [Douwe Kiela](https://huggingface.co/douwekiela).
The abstract from the paper is the following:
> Kahneman & Tversky's prospect theory tells us that humans perceive random variables in a biased but well-defined manner; for example, humans are famously loss-averse. We show that objectives for aligning LLMs with human feedback implicitly incorporate many of these biases -- the success of these objectives (e.g., DPO) over cross-entropy minimization can partly be ascribed to them being human-aware loss functions (HALOs). However, the utility functions these methods attribute to humans still differ from those in the prospect theory literature. Using a Kahneman-Tversky model of human utility, we propose a HALO that directly maximizes the utility of generations instead of maximizing the log-likelihood of preferences, as current methods do. We call this approach Kahneman-Tversky Optimization (KTO), and it matches or exceeds the performance of preference-based methods at scales from 1B to 30B. Crucially, KTO does not need preferences -- only a binary signal of whether an output is desirable or undesirable for a given input. This makes it far easier to use in the real world, where preference data is scarce and expensive.
The official code can be found in [ContextualAI/HALOs](https://github.com/ContextualAI/HALOs).
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Younes Belkada](https://huggingface.co/ybelkada), [Lewis Tunstall](https://huggingface.co/lewtun) and Pablo Vicente.
## Quick start
This example demonstrates how to train a model using the KTO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [KTO Mix 14k](https://huggingface.co/datasets/trl-lib/kto-mix-14k). You can view the data in the dataset here:
Distributed across 8 x H100 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
The best programming language can vary depending on individual preferences, industry-specific requirements, technical skills, and familiarity with the specific use case or task. Here are some widely-used programming languages that have been noted as popular and widely used:
Here are some other factors to consider when choosing a programming language for a project:
<strong><spanstyle="color: green;">1</span> JavaScript</strong>: JavaScript is at the heart of the web and can be used for building web applications, APIs, and interactive front-end applications like frameworks like React and Angular. It's similar to C, C++, and F# in syntax structure and is accessible and easy to learn, making it a popular choice for beginners and professionals alike.
<strong><spanstyle="color: green;">2</span> Java</strong>: Known for its object-oriented programming (OOP) and support for Java 8 and .NET, Java is used for developing enterprise-level software applications, high-performance games, as well as mobile apps, game development, and desktop applications.
<strong><spanstyle="color: green;">3</span> C++</strong>: Known for its flexibility and scalability, C++ offers comprehensive object-oriented programming and is a popular choice for high-performance computing and other technical fields. It's a powerful platform for building real-world applications and games at scale.
<strong><spanstyle="color: green;">4</span> Python</strong>: Developed by Guido van Rossum in 1991, Python is a high-level, interpreted, and dynamically typed language known for its simplicity, readability, and versatility.
</code></pre>
## Expected dataset format
KTO requires an [unpaired preference dataset](dataset_formats#unpaired-preference). Alternatively, you can provide a *paired* preference dataset (also known simply as a *preference dataset*). In this case, the trainer will automatically convert it to an unpaired format by separating the chosen and rejected responses, assigning `label = True` to the chosen completions and `label = False` to the rejected ones.
The [`KTOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
In theory, the dataset should contain at least one chosen and one rejected completion. However, some users have successfully run KTO using *only* chosen or only rejected data. If using only rejected data, it is advisable to adopt a conservative learning rate.
## Example script
We provide an example script to train a model using the KTO method. The script is available in [`trl/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/kto.py)
To test the KTO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/kto-mix-14k), run the following command:
```bash
accelerate launch trl/scripts/kto.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name trl-lib/kto-mix-14k \
--num_train_epochs 1\
--logging_steps 25\
--output_dir Qwen2-0.5B-KTO
```
## Usage tips
### For Mixture of Experts Models: Enabling the auxiliary loss
MOEs are the most efficient if the load is about equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
### Batch size recommendations
Use a per-step batch size that is at least 4, and an effective batch size between 16 and 128. Even if your effective batch size is large, if your per-step batch size is poor, then the KL estimate in KTO will be poor.
### Learning rate recommendations
Each choice of `beta` has a maximum learning rate it can tolerate before learning performance degrades. For the default setting of `beta = 0.1`, the learning rate should typically not exceed `1e-6` for most models. As `beta` decreases, the learning rate should also be reduced accordingly. In general, we strongly recommend keeping the learning rate between `5e-7` and `5e-6`. Even with small datasets, we advise against using a learning rate outside this range. Instead, opt for more epochs to achieve better results.
### Imbalanced data
The `desirable_weight` and `undesirable_weight` of the [`KTOConfig`] refer to the weights placed on the losses for desirable/positive and undesirable/negative examples.
By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` \\(\times\\) number of positives) to (`undesirable_weight` \\(\times\\) number of negatives) is in the range 1:1 to 4:3.
## Logged metrics
While training and evaluating we record the following reward metrics:
-`rewards/chosen_sum`: the sum of log probabilities of the policy model for the chosen responses scaled by beta
-`rewards/rejected_sum`: the sum of log probabilities of the policy model for the rejected responses scaled by beta
-`logps/chosen_sum`: the sum of log probabilities of the chosen completions
-`logps/rejected_sum`: the sum of log probabilities of the rejected completions
-`logits/chosen_sum`: the sum of logits of the chosen completions
-`logits/rejected_sum`: the sum of logits of the rejected completions
-`count/chosen`: the count of chosen samples in a batch
-`count/rejected`: the count of rejected samples in a batch
Using Large Language Models (LLMs) with tools has been a popular topic recently with awesome works such as [ToolFormer](https://huggingface.co/papers/2302.04761) and [ToolBench](https://huggingface.co/papers/2305.16504). In TRL, we provide a simple example of how to teach LLM to use tools with reinforcement learning.
Here's an overview of the scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples/research_projects/tools):
| File | Description |
|---|---|
| [`calculator.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/calculator.py) | Script to train LLM to use a calculator with reinforcement learning. |
| [`triviaqa.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/triviaqa.py) | Script to train LLM to use a wiki tool to answer questions. |
| [`python_interpreter.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/python_interpreter.py) | Script to train LLM to use python interpreter to solve math puzzles. |
<Tipwarning={true}>
Note that the scripts above rely heavily on the `TextEnvironment` API which is still under active development. The API may change in the future. Please see [`TextEnvironment`](text_environment) for the related docs.
</Tip>
## Learning to Use a Calculator
The rough idea is as follows:
1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parse a text calculation like `"14 + 34"` and return the calculated number:
1. Define a reward function that returns a positive reward if the tool returns the correct answer. In the script we create a dummy reward function like `reward_fn = lambda x: 1`, but we override the rewards directly later.
4. Then generate some data such as `tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]` and run the environment with `queries, responses, masks, rewards, histories = env.run(tasks)`. The environment will look for the `<call>` token in the prompt and append the tool output to the response; it will also return the mask associated with the response. You can further use the `histories` to visualize the interaction between the model and the tool; `histories[0].show_text()` will show the text with color-coded tool output and `histories[0].show_tokens(tokenizer)` will show visualize the tokens.
1. Finally, we can train the model with `train_stats = ppo_trainer.step(queries, responses, rewards, masks)`. The trainer will use the mask to ignore the tool output when computing the loss, make sure to pass that argument to `step`.
## Experiment results
We trained a model with the above script for 10 random seeds. You can reproduce the run with the following command. Feel free to remove the `--slurm-*` arguments if you don't have access to a slurm cluster.
As we can see, while 1-2 experiments crashed for some reason, most of the runs obtained near perfect proficiency in the calculator task.
## (Early Experiments 🧪): learning to use a wiki tool for question answering
In the [ToolFormer](https://huggingface.co/papers/2302.04761) paper, it shows an interesting use case that utilizes a Wikipedia Search tool to help answer questions. In this section, we attempt to perform similar experiments but uses RL instead to teach the model to use a wiki tool on the [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) dataset.
<Tip warning={true}>
**Note that many settings are different so the results are not directly comparable.**
</Tip>
### Building a search index
Since [ToolFormer](https://huggingface.co/papers/2302.04761) did not open source, we needed to first replicate the search index. It is mentioned in their paper that the authors built the search index using a BM25 retriever that indexes the Wikipedia dump from [KILT](https://github.com/facebookresearch/KILT)
Fortunately, [`pyserini`](https://github.com/castorini/pyserini) already implements the BM25 retriever and provides a prebuilt index for the KILT Wikipedia dump. We can use the following code to search the index.
A racket or racquet is a sports implement consisting of a handled frame with an open hoop across which a network of strings or catgut is stretched tightly. It is used for striking a ball or shuttlecock in games such as squash, tennis, racquetball, and badminton. Collectively, these games are known as racket sports. Racket design and manufacturing has changed considerably over the centuries.
The frame of rackets for all sports was traditionally made of solid wood (later laminated wood) and the strings of animal intestine known as catgut. The traditional racket size was limited by the strength and weight of the wooden frame which had to be strong enough to hold the strings and stiff enough to hit the ball or shuttle. Manufacturers started adding non-wood laminates to wood rackets to improve stiffness. Non-wood rackets were made first of steel, then of aluminum, and then carbon fiber composites. Wood is still used for real tennis, rackets, and xare. Most rackets are now made of composite materials including carbon fiber or fiberglass, metals such as titanium alloys, or ceramics.
...
```
We then basically deployed this snippet as a Hugging Face space [here](https://huggingface.co/spaces/vwxyzjn/pyserini-wikipedia-kilt-doc), so that we can use the space as a `transformers.Tool` later.
* use the `bigcode/starcoderbase` model as the base model
* use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and only uses the first paragraphs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool.
* test if the response contain the answer string, if so, give a reward of 1, otherwise, give a reward of 0.
* notice this is a simplified evaluation criteria. In [ToolFormer](https://huggingface.co/papers/2302.04761), the authors checks if the first 20 words of the response contain the correct answer.
* used the following prompt that demonstrates the usage of the wiki tool.
```python
prompt = """\
Answer the following question:
Q: In which branch of the arts is Patricia Neary famous?
A: Ballets
A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
Result=Ballets<submit>
Q: Who won Super Bowl XX?
A: Chicago Bears
A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
Result=Chicago Bears<submit>
Q: """
```
### Result and Discussion
Our experiments show that the agent can learn to use the wiki tool to answer questions. The learning curves would go up mostly, but one of the experiment did crash.
Wandb report is [here](https://wandb.ai/costa-huang/cleanRL/reports/TriviaQA-Final-Experiments--Vmlldzo1MjY0ODk5) for further inspection.
Note that the correct rate of the trained model is on the low end, which could be due to the following reasons:
* **incorrect searches:** When given the question `"What is Bruce Willis' real first name?"` if the model searches for `Bruce Willis`, our wiki tool returns "Patrick Poivey (born 18 February 1948) is a French actor. He is especially known for his voice: he is the French dub voice of Bruce Willis since 1988.` But a correct search should be `Walter Bruce Willis (born March 19, 1955) is an American former actor. He achieved fame with a leading role on the comedy-drama series Moonlighting (1985–1989) and appeared in over a hundred films, gaining recognition as an action hero after his portrayal of John McClane in the Die Hard franchise (1988–2013) and other roles.[1][2]"
* **unnecessarily long response**: The wiki tool by default sometimes output very long sequences. E.g., when the wiki tool searches for "Brown Act"
* Our wiki tool returns "The Ralph M. Brown Act, located at California Government Code 54950 "et seq.", is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public's right to attend and participate in meetings of local legislative bodies."
* [ToolFormer](https://huggingface.co/papers/2302.04761)'s wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies." which is more succinct.
## (Early Experiments 🧪): solving math puzzles with python interpreter
In this section, we attempt to teach the model to use a python interpreter to solve math puzzles. The rough idea is to give the agent a prompt like the following:
```python
prompt = """\
Example of using a Python API to solve math questions.
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
<request><PythonInterpreter>
def solution():
money_initial = 23
bagels = 5
bagel_cost = 3
money_spent = bagels * bagel_cost
money_left = money_initial - money_spent
result = money_left
return result
print(solution())
<call>8<response>
Result = 8 <submit>
Q: """
```
Training experiment can be found at https://wandb.ai/lvwerra/trl-gsm8k/runs/a5odv01y
As reinforcement learning algorithms are historically challenging to debug, it's important to pay careful attention to logging.
By default, the TRL [`PPOTrainer`] saves a lot of relevant information to wandb or tensorboard.
Upon initialization, pass one of these two options to the [`PPOConfig`]:
```
training_args = PPOConfig(..., report_to="wandb") # or "tensorboard"
```
If you want to log with tensorboard, add the kwarg `project_kwargs={"logging_dir": PATH_TO_LOGS}` to the PPOConfig.
## PPO Logging
Here's a brief explanation for the logged metrics provided in the data:
Key metrics to monitor. We want to maximize the reward, maintain a low KL divergence, and maximize entropy:
1.`env/reward_mean`: The average reward obtained from the environment. Alias `ppo/mean_scores`, which is used to specifically monitor the reward model.
1.`env/reward_std`: The standard deviation of the reward obtained from the environment. Alias ``ppo/std_scores`, which is used to specifically monitor the reward model.
1.`env/reward_dist`: The histogram distribution of the reward obtained from the environment.
1.`objective/kl`: The mean Kullback-Leibler (KL) divergence between the old and new policies. It measures how much the new policy deviates from the old policy. The KL divergence is used to compute the KL penalty in the objective function.
1.`objective/kl_dist`: The histogram distribution of the `objective/kl`.
1.`objective/kl_coef`: The coefficient for Kullback-Leibler (KL) divergence in the objective function.
1.`ppo/mean_non_score_reward`: The **KL penalty** calculated by `objective/kl * objective/kl_coef` as the total reward for optimization to prevent the new policy from deviating too far from the old policy.
1.`objective/entropy`: The entropy of the model's policy, calculated by `-logprobs.sum(-1).mean()`. High entropy means the model's actions are more random, which can be beneficial for exploration.
Training stats:
1.`ppo/learning_rate`: The learning rate for the PPO algorithm.
1.`ppo/policy/entropy`: The entropy of the model's policy, calculated by `pd = torch.nn.functional.softmax(logits, dim=-1); entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)`. It measures the randomness of the policy.
1.`ppo/policy/clipfrac`: The fraction of probability ratios (old policy / new policy) that fell outside the clipping range in the PPO objective. This can be used to monitor the optimization process.
1.`ppo/policy/approxkl`: The approximate KL divergence between the old and new policies, measured by `0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)`, corresponding to the `k2` estimator in http://joschu.net/blog/kl-approx.html
1.`ppo/policy/policykl`: Similar to `ppo/policy/approxkl`, but measured by `masked_mean(old_logprobs - logprobs, mask)`, corresponding to the `k1` estimator in http://joschu.net/blog/kl-approx.html
1.`ppo/policy/ratio`: The histogram distribution of the ratio between the new and old policies, used to compute the PPO objective.
1.`ppo/policy/advantages_mean`: The average of the GAE (Generalized Advantage Estimation) advantage estimates. The advantage function measures how much better an action is compared to the average action at a state.
1.`ppo/policy/advantages`: The histogram distribution of `ppo/policy/advantages_mean`.
1.`ppo/returns/mean`: The mean of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance. See https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ for more details.
1.`ppo/returns/var`: The variance of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance.
1.`ppo/val/mean`: The mean of the values, used to monitor the value function's performance.
1.`ppo/val/var` : The variance of the values, used to monitor the value function's performance.
1.`ppo/val/var_explained`: The explained variance for the value function, used to monitor the value function's performance.
1.`ppo/val/clipfrac`: The fraction of the value function's predicted values that are clipped.
1.`ppo/val/vpred`: The predicted values from the value function.
1.`ppo/val/error`: The mean squared error between the `ppo/val/vpred` and returns, used to monitor the value function's performance.
1.`ppo/loss/policy`: The policy loss for the Proximal Policy Optimization (PPO) algorithm.
1.`ppo/loss/value`: The loss for the value function in the PPO algorithm. This value quantifies how well the function estimates the expected future rewards.
1.`ppo/loss/total`: The total loss for the PPO algorithm. It is the sum of the policy loss and the value function loss.
Stats on queries, responses, and logprobs:
1.`tokens/queries_len_mean`: The average length of the queries tokens.
1.`tokens/queries_len_std`: The standard deviation of the length of the queries tokens.
1.`tokens/queries_dist`: The histogram distribution of the length of the queries tokens.
1.`tokens/responses_len_mean`: The average length of the responses tokens.
1.`tokens/responses_len_std`: The standard deviation of the length of the responses tokens.
1.`tokens/responses_dist`: The histogram distribution of the length of the responses tokens. (Costa: inconsistent naming, should be `tokens/responses_len_dist`)
1.`objective/logprobs`: The histogram distribution of the log probabilities of the actions taken by the model.
1.`objective/ref_logprobs`: The histogram distribution of the log probabilities of the actions taken by the reference model.
### Crucial values
During training, many values are logged, here are the most important ones:
1.`env/reward_mean`,`env/reward_std`, `env/reward_dist`: the properties of the reward distribution from the "environment" / reward model
1.`ppo/mean_non_score_reward`: The mean negated KL penalty during training (shows the delta between the reference model and the new policy over the batch in the step)
Here are some parameters that are useful to monitor for stability (when these diverge or collapse to 0, try tuning variables):
1.`ppo/loss/value`: it will spike / NaN when not going well.
1.`ppo/policy/ratio`: `ratio` being 1 is a baseline value, meaning that the probability of sampling a token is the same under the new and old policy. If the ratio is too high like 200, it means the probability of sampling a token is 200 times higher under the new policy than the old policy. This is a sign that the new policy is too different from the old policy, which will likely cause overoptimization and collapse training later on.
1.`ppo/policy/clipfrac` and `ppo/policy/approxkl`: if `ratio` is too high, the `ratio` is going to get clipped, resulting in high `clipfrac` and high `approxkl` as well.
1.`objective/kl`: it should stay positive so that the policy is not too far away from the reference policy.
1.`objective/kl_coef`: The target coefficient with [`AdaptiveKLController`]. Often increases before numerical instabilities.
As reinforcement learning algorithms are historically challenging to debug, it's important to pay careful attention to logging.
By default, the TRL [`PPOTrainer`] saves a lot of relevant information to `wandb` or `tensorboard`.
Upon initialization, pass one of these two options to the [`PPOConfig`]:
```
config = PPOConfig(
model_name=args.model_name,
log_with=`wandb`, # or `tensorboard`
)
```
If you want to log with tensorboard, add the kwarg `accelerator_kwargs={"logging_dir": PATH_TO_LOGS}` to the PPOConfig.
## PPO Logging
### Crucial values
During training, many values are logged, here are the most important ones:
1. `env/reward_mean`,`env/reward_std`, `env/reward_dist`: the properties of the reward distribution from the "environment".
2. `ppo/mean_scores`: The mean scores directly out of the reward model.
3. `ppo/mean_non_score_reward`: The mean negated KL penalty during training (shows the delta between the reference model and the new policy over the batch in the step)
### Training stability parameters:
Here are some parameters that are useful to monitor for stability (when these diverge or collapse to 0, try tuning variables):
1. `ppo/loss/value`: The value function loss -- will spike / NaN when not going well.
2. `ppo/val/clipfrac`: The fraction of clipped values in the value function loss. This is often from 0.3 to 0.6.
3. `objective/kl_coef`: The target coefficient with [`AdaptiveKLController`]. Often increases before numerical instabilities.
# Multi Adapter RL (MARL) - a single base model for everything
Here we present an approach that uses a single base model for the entire PPO algorithm - which includes retrieving the reference logits, computing the active logits and the rewards. This feature is experimental as we did not test the convergence of the approach. We encourage the community to let us know if they potentially face issues.
## Requirements
You just need to install `peft` and optionally install `bitsandbytes` as well if you want to go for 8bit base models, for more memory efficient finetuning.
## Summary
You need to address this approach in three stages that we summarize as follows:
1- Train a base model on the target domain (e.g. [IMDB dataset](https://huggingface.co/datasets/stanfordnlp/imdb)) - this is the Supervised Fine Tuning stage - it can leverage the `SFTTrainer` from TRL.
2- Train a reward model using `peft`. This is required in order to re-use the adapter during the RL optimisation process (step 3 below). We show an example of leveraging the `RewardTrainer` from TRL in [this example](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py)
3- Fine tune new adapters on the base model using PPO and the reward adapter. ("0 abstraction RL")
Make sure to use the same model (i.e. same architecture and same weights) for the stages 2 & 3.
## Quickstart
Let us assume you have trained your reward adapter on `llama-7b` model using `RewardTrainer` and pushed the weights on the hub under `trl-lib/llama-7b-hh-rm-adapter`.
When doing PPO, before passing the model to `PPOTrainer` create your model as follows:
If you are familiar with the `peft` library, you know that you can use multiple adapters inside the same model. What you can do is train multiple adapters on the same base model to fine-tune on different policies.
In this case, you want to be able to control the adapter name you want to activate back, after retrieving the reward. For that, simply pass the appropriate `adapter_name` to `ppo_adapter_name` argument when calling `compute_reward_score`.
For more memory efficient fine-tuning, you can load your base model in 8-bit or 4-bit while keeping the adapters in the default precision (float32).
Just pass the appropriate arguments (i.e. `load_in_8bit=True` or `load_in_4bit=True`) to `AutoModelForCausalLMWithValueHead.from_pretrained` as follows (assuming you have installed `bitsandbytes`):
Nash-MD was proposed in the paper [Nash Learning from Human Feedback](https://huggingface.co/papers/2312.00886) by Rémi Munos, [Michal Valko](https://huggingface.co/misovalko), Daniele Calandriello, Mohammad Gheshlaghi Azar, Mark Rowland, Daniel Guo, Yunhao Tang, Matthieu Geist, Thomas Mésnard, and Andrea Michi.
The abstract from the paper is the following:
> Reinforcement learning from human feedback (RLHF) has emerged as the main paradigm for aligning large language models (LLMs) with human preferences. Typically, RLHF involves the initial step of learning a reward model from human feedback, often expressed as preferences between pairs of text generations produced by a pre-trained LLM. Subsequently, the LLM's policy is fine-tuned by optimizing it to maximize the reward model through a reinforcement learning algorithm. However, an inherent limitation of current reward models is their inability to fully represent the richness of human preferences and their dependency on the sampling distribution. In this study, we introduce an alternative pipeline for the fine-tuning of LLMs using pairwise human feedback. Our approach entails the initial learning of a preference model, which is conditioned on two inputs given a prompt, followed by the pursuit of a policy that consistently generates responses preferred over those generated by any competing policy, thus defining the Nash equilibrium of this preference model. We term this approach Nash learning from human feedback (NLHF). In the context of a tabular policy representation, we present a novel algorithmic solution, Nash-MD, founded on the principles of mirror descent. This algorithm produces a sequence of policies, with the last iteration converging to the regularized Nash equilibrium. Additionally, we explore parametric representations of policies and introduce gradient descent algorithms for deep-learning architectures. To demonstrate the effectiveness of our approach, we present experimental results involving the fine-tuning of a LLM for a text summarization task. We believe NLHF offers a compelling avenue for preference learning and policy optimization with the potential of advancing the field of aligning LLMs with human preferences.
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif) and [Daniil Tiapkin](https://huggingface.co/dtiapkin), [Pierre Ménard](https://huggingface.co/menardprr), Daniele Calandriello and [Quentin Gallouédec](https://huggingface.co/qgallouedec).
## Quick start
This example demonstrates how to train a model using the Nash-MD method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:
Distributed across 8 GPUs, the training takes approximately 3 hours.
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-NashMD) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
The best programming language depends on personal preference, the complexity of the project, and the specific requirements of the task. Some programming languages that are often recommended include Python, Java, and JavaScript, and there are many other languages to choose from depending on individual needs.
</code></pre>
## Expected dataset type
Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`NashMDTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
## Usage tips
### Use a reward model
Instead of a judge, you can chose to use a reward model -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use. Below is a code example showing how to replace a judge with the [trl-lib/Qwen2-0.5B-Reward](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) model:
```diff
- from trl import PairRMJudge
+ from transformers import AutoModelForSequenceClassification
Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training.
</Tip>
### Encourage EOS token generation
We may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`NashMDConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`NashMDConfig`]:
We provide an example script to train a model using the Nash-MD method. The script is available in [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py)
To test the online DPO script with the [Qwen2.5 0.5B model](https://huggingface.co/trl-lib/Qwen/Qwen2.5-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback), run the following command:
```bash
python examples/scripts/nash_md.py \
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--judge pair_rm \
--dataset_name trl-lib/ultrafeedback-prompt \
--learning_rate 5.0e-7 \
--logging_steps 25\
--output_dir Qwen2.5-0.5B-NashMD-PairRM \
--warmup_ratio 0.1 \
--push_to_hub
```
## Logged metrics
The logged metrics are as follows:
*`loss/kl`: The mean KL divergence between the model and reference data.
*`objective/entropy`: The mean entropy of the model and reference data.
*`loss/score`: The mean reinforce score loss.
*`rewards/chosen`: The mean scores (according to the reward model) of the model completions.
*`rewards/rejected`: The mean scores (according to the reward model) of the mixture completions.
*`rewards/probabilities`: The mean probability (according to the reward model or judge) of the model completions chosen vs the mixture completion.
*`rewards/accuracies`: The accuracies of the Nash-MD's implicit reward model.
*`rewards/margins`: The mean reward margin (according to reward model) between the chosen and mixture completions.
*`logps/chosen`: The mean log probabilities of the chosen completions.
*`logps/rejected`: The mean log probabilities of the reference completions.
*`val/model_contain_eos_token`: The amount of times the model's output contains the eos token.
*`val/ref_contain_eos_token`: The amount of times the mixture's output contains the eos token.
*`beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`NashMDConfig`].
*`mixture_coef`: Logit mixture coefficient for the model and reference model. Typically fixed, but can be made dynamic by passing a list to [`NashMDConfig`].
Online DPO was proposed in [Direct Language Model Alignment from Online AI Feedback](https://huggingface.co/papers/2402.04792) by Shangmin Guo, Biao Zhang, Tianlin Liu, Tianqi Liu, Misha Khalman, Felipe Llinares, Alexandre Rame, Thomas Mesnard, Yao Zhao, Bilal Piot, Johan Ferret, and Mathieu Blondel.
The abstract from the paper is the following:
> Direct alignment from preferences (DAP) methods, such as DPO, have recently emerged as efficient alternatives to reinforcement learning from human feedback (RLHF), that do not require a separate reward model. However, the preference datasets used in DAP methods are usually collected ahead of training and never updated, thus the feedback is purely offline. Moreover, responses in these datasets are often sampled from a language model distinct from the one being aligned, and since the model evolves over training, the alignment phase is inevitably off-policy. In this study, we posit that online feedback is key and improves DAP methods. Our method, online AI feedback (OAIF), uses an LLM as annotator: on each training iteration, we sample two responses from the current model and prompt the LLM annotator to choose which one is preferred, thus providing online feedback. Despite its simplicity, we demonstrate via human evaluation in several tasks that OAIF outperforms both offline DAP and RLHF methods. We further show that the feedback leveraged in OAIF is easily controllable, via instruction prompts to the LLM annotator.
This post-training method was contributed by [Michael Noukhovitch](https://huggingface.co/mnoukhov), [Shengyi Costa Huang](https://huggingface.co/vwxyzjn), [Quentin Gallouédec](https://huggingface.co/qgallouedec), and [Edward Beeching](https://huggingface.co/edbeeching).
## Quick start
This example demonstrates how to train a model using the online DPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:
Distributed across 8 GPUs, the training takes approximately 1 hour. You can verify the training progress by checking the reward graph. An increasing trend in both the reward for rejected and chosen completions indicates that the model is improving and generating better responses over time.
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-OnlineDPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
The best programming language depends on your specific needs and priorities. Some people prefer imperative programming languages (like Haskell or Lisp), while others prefer functional programming languages (like Scala or Python). It's important to consider your work style, programming environment, and project requirements when choosing a programming language.
</code></pre>
## Expected dataset type
Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, that expects [preference dataset](dataset_formats#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
## Usage tips
### Use a reward model
Instead of a judge, you can chose to use a reward model -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use. Below is a code example showing how to replace a judge with the [trl-lib/Qwen2-0.5B-Reward](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) model:
```diff
- from trl import PairRMJudge
+ from transformers import AutoModelForSequenceClassification
When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`OnlineDPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`OnlineDPOConfig`]:
We provide an example script to train a model using the online DPO method. The script is available in [`examples/scripts/dpo_online.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_online.py)
To test the online DPO script with the [Qwen2.5 0.5B model](https://huggingface.co/trl-lib/Qwen/Qwen2.5-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback), run the following command:
```bash
python examples/scripts/dpo_online.py \
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--judge pair_rm \
--dataset_name trl-lib/ultrafeedback-prompt \
--learning_rate 5.0e-7 \
--logging_steps 25\
--output_dir Qwen2.5-0.5B-Online-DPO-PairRM \
--warmup_ratio 0.1 \
--push_to_hub
```
## Logged metrics
The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/w4apmsi9)
*`objective/kl`: The mean Kullback-Leibler (KL) divergence between the current model and reference model.
*`objective/entropy`: The mean entropy of the model, indicating the randomness of the actions chosen by the model.
*`objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
*`objective/rlhf_reward`: The mean RLHF reward, which is `scores - non_score_reward`. The `rlhf_reward` is the ultimate objective of online DPO training. If training works as intended, this metric should keep going up.
*`objective/scores`: The mean scores returned by the reward model.
*`objective/scores_margin`: The mean score margin (according to the external reward model) between the chosen and rejected completions.
*`rewards/chosen`: The mean reward (according to online DPO's implicit reward model)of the chosen completions.
*`rewards/rejected`: The mean reward (according to online DPO's implicit reward model) of the rejected completions.
*`rewards/accuracies`: The accuracies of the online DPO's implicit reward model.
*`rewards/margins`: The mean reward margin (according to online DPO's implicit reward model) between the chosen and rejected completions.
*`logps/chosen`: The mean log probabilities of the chosen completions.
*`logps/rejected`: The mean log probabilities of the rejected completions.
*`val/contain_eos_token`: The fraction of completions which contain an EOS token.
*`beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`OnlineDPOConfig`].
## Benchmark experiments
To validate the online DPO implementation works, we ran experiments with the Pythia 1B, 2.8B, and 6.9B models on a single node of 8 x H100s. Here are the commands we used to run the experiments. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
For more information on how to use judges, see [Judges](judges).
The online DPO checkpoint gets increasingly more win rate as we scale up the model sizes. This is a good sign that the online DPO implementation is working as intended.
Odds Ratio Preference Optimization (ORPO) was introduced in [ORPO: Monolithic Preference Optimization without Reference Model](https://huggingface.co/papers/2403.07691) by [Jiwoo Hong](https://huggingface.co/JW17), [Noah Lee](https://huggingface.co/nlee-208), and [James Thorne](https://huggingface.co/j6mes).
The abstract from the paper is the following:
> While recent preference alignment algorithms for language models have demonstrated promising results, supervised fine-tuning (SFT) remains imperative for achieving successful convergence. In this paper, we study the crucial role of SFT within the context of preference alignment, emphasizing that a minor penalty for the disfavored generation style is sufficient for preference-aligned SFT. Building on this foundation, we introduce a straightforward and innovative reference model-free monolithic odds ratio preference optimization algorithm, ORPO, eliminating the necessity for an additional preference alignment phase. We demonstrate, both empirically and theoretically, that the odds ratio is a sensible choice for contrasting favored and disfavored styles during SFT across the diverse sizes from 125M to 7B. Specifically, fine-tuning Phi-2 (2.7B), Llama-2 (7B), and Mistral (7B) with ORPO on the UltraFeedback alone surpasses the performance of state-of-the-art language models with more than 7B and 13B parameters: achieving up to 12.20% on AlpacaEval_{2.0} (Figure 1), 66.19% on IFEval (instruction-level loose, Table 6), and 7.32 in MT-Bench (Figure 12). We release code and model checkpoints for Mistral-ORPO-alpha (7B) and Mistral-ORPO-beta (7B).
It studies the crucial role of SFT within the context of preference alignment. Using preference data the method posits that a minor penalty for the disfavored generation together with a strong adaption signal to the chosen response via a simple log odds ratio term appended to the NLL loss is sufficient for preference-aligned SFT.
Thus ORPO is a reference model-free preference optimization algorithm eliminating the necessity for an additional preference alignment phase thus saving compute and memory.
The official code can be found in [xfactlab/orpo](https://github.com/xfactlab/orpo).
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Lewis Tunstall](https://huggingface.co/lewtun) and [Alvaro Bartolome](https://huggingface.co/alvarobartt).
## Quick start
This example demonstrates how to train a model using the ORPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:
Distributed across 8 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-ORPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
It's challenging to determine the best programming language as no one language is perfect, as the complexity of a task and the type of project are significant factors. Some popular languages include Java, Python, JavaScript, and
C++. If you have specific needs or requirements for a specific project, it's important to choose the language that best suits those needs.
Here are some other factors to consider when choosing a programming language for a project:
<strong><spanstyle="color: green;">• Language proficiency:</span></strong> A good programming language is more likely to be easy to understand and use, and will allow developers to collaborate on projects more efficiently.
<strong><spanstyle="color: green;">• Ease of use:</span></strong> There are tools and libraries available to make programming more accessible, so developers should choose a language that can help them get started easier.
<strong><spanstyle="color: green;">• Code readability:</span></strong> A clear and concise codebase should be easy to read and understand, especially when working with large projects.
<strong><spanstyle="color: green;">• Tool and framework support:</span></strong> There are numerous libraries available for Python, Java, and JavaScript, along with tools like IDEs and static code analysis tools.
<strong><spanstyle="color: green;">• Accessibility:</span></strong> Some languages and tools have features that make them more accessible to developers with disabilities, such as support for screen readers.
<strong><spanstyle="color: green;">• Version control:</span></strong> As your projects grow and complexity increases, version control tools can be beneficial for tracking changes.
</code></pre>
## Expected dataset type
ORPO requires a [preference dataset](dataset_formats#preference). The [`ORPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
Although the [`ORPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section.
## Example script
We provide an example script to train a model using the ORPO method. The script is available in [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py)
To test the ORPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
```bash
accelerate launch examples/scripts/orpo.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name trl-lib/ultrafeedback_binarized \
--num_train_epochs 1\
--logging_steps 25\
--output_dir Qwen2-0.5B-ORPO
```
## Usage tips
### For Mixture of Experts Models: Enabling the auxiliary loss
MOEs are the most efficient if the load is about equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
## Logged metrics
While training and evaluating we record the following reward metrics:
-`rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
-`rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
-`rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
-`rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
-`log_odds_chosen`: the mean log odds ratio of the chosen responses over the rejected responses
-`log_odds_ratio`: the mean of the `log(sigmoid(log_odds_chosen))`
-`nll_loss`: the mean negative log likelihood loss from the SFT part of the loss over chosen responses
# Examples of using peft with trl to finetune 8-bit models with Low Rank Adaption (LoRA)
The notebooks and scripts in this examples show how to use Low Rank Adaptation (LoRA) to fine-tune models in a memory efficient manner.
For more information on LoRA, see the [original paper](https://arxiv.org/abs/2106.09685).
The notebooks and scripts in this examples show how to use Low Rank Adaptation (LoRA) to fine-tune models in a memory efficient manner. Most of PEFT methods supported in peft library but note that some PEFT methods such as Prompt tuning are not supported.
For more information on LoRA, see the [original paper](https://huggingface.co/papers/2106.09685).
Here's an overview of the `peft`-enabled notebooks and scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples):
Here's an overview of the `peft`-enabled notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples):
| File | Task | Description | Colab link |
|---|---| --- |
| [`gpt2-sentiment_peft.py`](https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt2-sentiment_peft.py) | Sentiment | Same as the sentiment analysis example, but learning a low rank adapter on a 8-bit base model | |
| [`cm_finetune_peft_imdb.py`](https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/cm_finetune_peft_imdb.py) | Sentiment | Fine tuning a low rank adapter on a frozen 8-bit model for text generation on the imdb dataset. | |
| [`merge_peft_adapter.py`](https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/merge_peft_adapter.py) | 🤗 Hub | Merging of the adapter layers into the base model’s weights and storing these on the hub. | |
| [`gpt-neo-20b_sentiment_peft.py`](https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py) | Sentiment | Sentiment fine-tuning of a low rank adapter to create positive reviews. | |
| [`gpt-neo-1b_peft.py`](https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neo-1b-multi-gpu/gpt-neo-1b_peft.py) | Sentiment | Sentiment fine-tuning of a low rank adapter to create positive reviews using 2 GPUs. | |
| [`stack_llama/rl_training.py`](https://github.com/lvwerra/trl/blob/main/examples/stack_llama/scripts/rl_training.py) | RLHF | Distributed fine-tuning of the 7b parameter LLaMA models with a learned reward model and `peft`. | |
| [`stack_llama/reward_modeling.py`](https://github.com/lvwerra/trl/blob/main/examples/stack_llama/scripts/reward_modeling.py) | Reward Modeling | Distributed training of the 7b parameter LLaMA reward model with `peft`. | |
| [`stack_llama/supervised_finetuning.py`](https://github.com/lvwerra/trl/blob/main/examples/stack_llama/scripts/supervised_finetuning.py) | SFT | Distributed instruction/supervised fine-tuning of the 7b parameter LLaMA model with `peft`. | |
| [`stack_llama/rl_training.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py) | RLHF | Distributed fine-tuning of the 7b parameter LLaMA models with a learned reward model and `peft`. | |
| [`stack_llama/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py) | Reward Modeling | Distributed training of the 7b parameter LLaMA reward model with `peft`. | |
| [`stack_llama/supervised_finetuning.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py) | SFT | Distributed instruction/supervised fine-tuning of the 7b parameter LLaMA model with `peft`. | |
## Installation
Note: peft is in active development, so we install directly from their Github page.
@ -76,7 +71,7 @@ The `trl` library is powered by `accelerate`. As such it is best to configure an
```bash
accelerate config # will prompt you to define the training configuration
accelerate launch scripts/gpt2-sentiment_peft.py # launches training
accelerate launch examples/scripts/ppo.py --use_peft # launch`es training
```
## Using `trl` + `peft` and Data Parallelism
@ -123,7 +118,7 @@ The `trl` library also supports naive pipeline parallelism (NPP) for large model
This paradigm, termed as "Naive Pipeline Parallelism" (NPP) is a simple way to parallelize the model across multiple GPUs. We load the model and the adapters across multiple GPUs and the activations and gradients will be naively communicated across the GPUs. This supports `int8` models as well as other `dtype` models.
@ -132,12 +127,18 @@ Simply load your model with a custom `device_map` argument on the `from_pretrain
Also make sure to have the `lm_head` module on the first GPU device as it may throw an error if it is not on the first device. As this time of writing, you need to install the `main` branch of `accelerate`: `pip install git+https://github.com/huggingface/accelerate.git@main` and `peft`: `pip install git+https://github.com/huggingface/peft.git@main`.
That all you need to do to use NPP. Check out the [gpt-neo-1b_peft.py](https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neo-1b-multi-gpu/gpt-neo-1b_peft.py) example for a more details usage of NPP.
### Launch scripts
Although `trl` library is powered by `accelerate`, you should run your training script in a single process. Note that we do not support Data Parallelism together with NPP yet.
```bash
python PATH_TO_SCRIPT
```
```
## Fine-tuning Llama-2 model
You can easily fine-tune Llama2 model using `SFTTrainer` and the official script! For example to fine-tune llama2-7b on the Guanaco dataset, run (tested on a single NVIDIA T4-16GB):
The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35)
*`eps`: Tracks the number of episodes per second.
*`objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
*`objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
*`objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
*`objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
*`objective/scores`: The mean scores returned by the reward model / environment.
*`policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
*`policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
*`loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
*`loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward.
*`val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to policy/clipfrac_avg but for the value function.
*`policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
*`val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
*`val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
*`val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
*`lr`: lr: The current learning rate used by the optimizer.
*`episode`: episode: The current episode count in the training process.
## Cookbook
* Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try understand why this is happening and try to fix it.
* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
* Usage TIP: We recommend to use the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions.
## What is my model doing exactly?
To help you understand what your model is doing, we periodically log some sample completions from the model. Here is an example of a completion. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35), it looks like the following, allowing you to see the model's response at different stages of training. By default we generate `--num_sample_generations 10` during training, but you can customize the number of generations.
This PPO implementation is based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
## Benchmark experiments
To validate the PPO implementation works, we ran experiment on the 1B model. Here are the command we used to run the experiment. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
For more information on how to use judges, see [Judges](judges).
The PPO checkpoint gets a 64.7% preferred rate vs the 33.0% preference rate of the SFT checkpoint. This is a good sign that the PPO training is working as intended.
PRM Trainer is an experimental API which is subject to change at any time.
</Tip>
## Overview
Process-supervised Reward Models (PRM) were proposed in [Solving math word problems with process- and outcome-based feedback](https://huggingface.co/papers/2211.14275) by Jonathan Uesato, Nate Kushman, Ramana Kumar, Francis Song, Noah Siegel, Lisa Wang, Antonia Creswell, Geoffrey Irving, and Irina Higgins.
The abstract from the paper is the following:
> Recent work has shown that asking language models to generate reasoning steps improves performance on many reasoning tasks. When moving beyond prompting, this raises the question of how we should supervise such models: outcome-based approaches which supervise the final result, or process-based approaches which supervise the reasoning process itself? Differences between these approaches might naturally be expected not just in final-answer errors but also in reasoning errors, which can be difficult to detect and are problematic in many real-world domains such as education. We run the first comprehensive comparison between process- and outcome-based approaches trained on a natural language task, GSM8K. We find that pure outcome-based supervision produces similar final-answer error rates with less label supervision. However, for correct reasoning steps we find it necessary to use processbased supervision or supervision from learned reward models that emulate process-based feedback. In total, we improve the previous best results from 16.8% → 12.7% final-answer error and 14.0% → 3.4% reasoning error among final-answer-correct solutions.
This post-training method was contributed by [Gaetan Lopez](https://github.com/gaetanlop), [Lewis Tunstall](https://huggingface.co/lewtun), [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Agustín Piqueres](https://huggingface.co/plaguss).
## Quick start
This example demonstrates how to train a model using the PRM method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) as the base model. We use the stepwise supervision data from the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd). You can view the data in the dataset here:
"prompt":"Musa is the class teacher of a class of 45 students. He wants to split them into three groups by age. If a third of the class is under 11 years, and two-fifths are above 11 but under 13, how many students will be in the third group (13 years and above)?",
"completions":[
"Step 1: A third of the class is under 11 years because 11 - 1/3 = <<11-1/3=7>>7.",
"Step 2: Two-fifths of the class are above 11 but under 13 because 2/5 * 11 = <<2/5*11=8>>8.",
"Step 3: There are 45 students, so the third group will have 45 - 7 - 8 = <<45-7-8=20>>20 students. The answer is: 20",
],
"labels":[True,False,False],
}
separator="\n"# It's important to use the same separator as the one used during training
foridxinrange(1,len(example["completions"])+1):
steps=example["completions"][0:idx]
text=separator.join((example["prompt"],*steps))+separator# Add a separator between the prompt and each steps
PRM requires a [stepwise supervision](dataset_formats#stepwise-supervision).
The dataset should contain the following columns: `prompt`, `completions` and `labels`, where `completions` contains a list of reasoning steps and `labels` a list of booleans or floats indicating the correctness of each step.
The [`PRMTrainer`] only supports [standard](dataset_formats#standard) dataset format.
## Example script
We provide an example script to train a model using the PRM method. The script is available in [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py)
To use the PRM script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) on the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd), run the following command:
@ -9,7 +9,7 @@ Fine-tuning a language model via PPO consists of roughly three steps:
3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.
The full process is illustrated in the following figure:
Section under construction. Feel free to contribute!
</Tip>
## Truncation
Sequence lengths in the dataset can vary widely. When data is batched, sequences are padded to match the longest one in the batch, which can cause high memory usage, even if most sequences are relatively short.
To reduce memory usage, it’s important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.
<hfoptionsid="dpo">
<hfoptionid="DPO">
DPO truncation is applied first to the prompt and to the completion via the `max_prompt_length` and `max_completion_length` parameters. The `max_length` parameter is then used to truncate the resulting sequence.
You can also use the `max_completion_length` parameter to truncate the completion, though this is less common since the goal is typically to preserve the completion's full length whenever possible.
To set the truncation parameter, use the following code snippet:
```python
fromtrlimportSFTConfig
training_args=SFTConfig(...,max_length=...)
```
</hfoption>
</hfoptions>
## Packing
<Tip>
This technique applies only to SFT.
</Tip>
[Truncation](#truncation) has several drawbacks:
1.**Loss of information**: Key data at the end of a sequence may be discarded.
2.**Choosing truncation length**: Too short loses data; too long undermines efficiency.
Packing, introduced in [Raffel et al., 2020](https://huggingface.co/papers/1910.10683), addresses these issues by grouping sequences instead of truncating. It concatenates and splits dataset sequences into the desired lengths.
Packing eliminates padding, preserves all sequence information, and allows for flexible sequence lengths, making it a more efficient alternative to truncation. To enable packing, use `packing=True` in the [`SFTConfig`]:
Packing may cause batch contamination, where adjacent sequences influence one another. This can be problematic for some applications. For more details, see [#1230](https://github.com/huggingface/trl/issues/1230).
</Tip>
## Padding-free
Padding-free batching is an alternative approach for reducing memory usage. In this method, a batch is first sampled and then flattened into a single sequence, avoiding padding. Unlike packing, which can result in incomplete sequences by combining parts of different samples, padding-free batching ensures that all sequences remain complete and intact.
## Disabling model gathering for generation in online methods
When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Online methods involve generating completions from the model as part of the training process. During this step, the model weights are temporarily gathered on a single GPU for generation. For very large models, this gathering can lead to out-of-memory (OOM) errors, as described in this issue: [#2250](https://github.com/huggingface/trl/issues/2250#issue-2598304204).
If you encounter this issue, you can disable the gathering of model weights for generation by setting the following parameter:
TRL supports custom reward modeling for anyone to perform reward modeling on their dataset and model.
Check out a complete flexible example at [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py).
## Expected dataset type
The [`RewardTrainer`] requires a [*implicit prompt* preference dataset](dataset_formats#preference). It means that the dataset should only contain the columns `"chosen"` and `"rejected"` (and not `"prompt"`).
The [`RewardTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
You can also use a pretokenized dataset, in which case the dataset should contain the following columns: `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`.
## Using the `RewardTrainer`
After preparing your dataset, you can use the [`RewardTrainer`] in the same way as the `Trainer` class from 🤗 Transformers.
You should pass an `AutoModelForSequenceClassification` model to the [`RewardTrainer`], along with a [`RewardConfig`] which configures the hyperparameters of the training.
### Leveraging 🤗 PEFT to train a reward model
Just pass a `peft_config` in the keyword arguments of [`RewardTrainer`], and the trainer should automatically take care of converting the model into a PEFT model!
As in the [Llama 2 paper](https://huggingface.co/papers/2307.09288), you can add a margin to the loss by adding a `margin` column to the dataset. The reward collator will automatically pass it through and the loss will be computed accordingly.
```python
defadd_margin(row):
# Assume you have a score_chosen and score_rejected columns that you want to use to compute the margin
In many scenarios, it's preferable to ensure that a reward model's output is mean zero. This is often done by first calculating the model's average score and then subtracting it.
[[Eisenstein et al., 2023]](https://huggingface.co/papers/2312.09244) proposed an auxiliary loss function designed to directly learn a centered reward model. This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs:
$$\Big( R(p, r_1) + R(p, r_2) \Big)^2 $$
This auxiliary loss is combined with the main loss function, weighted by the parameter `center_rewards_coefficient` in the `[RewardConfig]`. By default, this feature is deactivated (`center_rewards_coefficient = None`).
```python
training_args=RewardConfig(
center_rewards_coefficient=0.01,
...
)
```
For reference results, please refer PR [#1932](https://github.com/huggingface/trl/pull/1932).
TRL supports custom reward modeling for anyone to perform reward modeling on their dataset and model.
## Expected dataset format
The reward trainer expects a very specific format for the dataset. Since the model will be trained to predict which sentence is the most relevant, given two sentences. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
Therefore the final dataset object should contain two 4 entries at least if you use the default `RewardDataCollatorWithPadding` data collator. The entries should be named:
- `input_ids_chosen`
- `attention_mask_chosen`
- `input_ids_rejected`
- `attention_mask_rejected`
The `j` and `k` suffixes are used to denote the two sentences in the paired dataset.
## Using the `RewardTrainer`
After standardizing your dataset, you can use the `RewardTrainer` as a classic HugingFace Trainer.
You should pass an `AutoModelForSequenceClassification` model to the `RewardTrainer`.
### Leveraging the `peft` library to train a reward model
Just pass a `peft_config` in the key word arguments of `RewardTrainer`, and the trainer should automatically take care of converting the model into a PEFT model!
```python
from peft import LoraConfig, task_type
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments
from trl import RewardTrainer
model = AutoModelForSequenceClassification.from_pretrained("gpt2")
TRL supports training LLMs with REINFORCE Leave-One-Out (RLOO). The idea is that instead of using a value function, RLOO generates K completions for each prompt. For each completion, RLOO uses the mean scores from the other K-1 completions as a baseline to calculate the advantage. RLOO also models the entire completion as a single action, whereas PPO models each token as an action. Note that REINFORCE / A2C is a special case of PPO, when the number of PPO epochs is 1 and the number of mini-batches is 1, which is how we implement RLOO in TRL.
References:
- [Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs](https://huggingface.co/papers/2402.14740)
- [A2C is a special case of PPO](https://huggingface.co/papers/2205.09123)
- [Fine-Tuning Language Models from Human Preferences](https://github.com/openai/lm-human-preferences)
- [Learning to Summarize from Human Feedback](https://github.com/openai/summarize-from-feedback)
- [The N Implementation Details of RLHF with PPO](https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo)
- [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031)
## Get started
To just run a RLOO script to make sure the trainer can run, you can run the following command to train a RLOO model with a dummy reward model.
The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/u2sqci34)
<!-- * `rlhf_reward_var_per_prompt`: calculated by `rlhf_reward.var(0).mean()`. This is the variance of the rewards estimated across the `args.rloo_k` samples. Usually we expect it to go down (cause policy entropy goes down). -->
*`eps`: Tracks the number of episodes per second.
*`objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
*`objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
*`objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
*`objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
*`objective/scores`: The mean scores returned by the reward model / environment.
*`policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
*`policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
*`loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
*`val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to policy/clipfrac_avg but for the value function.
*`policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
*`val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
*`val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
*`val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
*`lr`: lr: The current learning rate used by the optimizer.
*`episode`: episode: The current global step or episode count in the training process.
## Cookbook
* Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try understand why this is happening and try to fix it.
* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
* Usage TIP: We recommend to use the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions.
## What is my model doing exactly?
To help you understand what your model is doing, we periodically log some sample completions from the model. Here is an example of a completion. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/u2sqci34), it looks like the following, allowing you to see the model's response at different stages of training. By default we generate `--num_sample_generations 10` during training, but you can customize the number of generations.
The bulk of RLOOTrainer is based on the PPO implementation, which is based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
Below is a vectorized advantage calculation for RLOO:
```python
deftest_rloo_reward():
local_batch_size=3
rloo_k=4
rlhf_reward=torch.tensor([
1,2,3,# first rlhf reward for three prompts
2,3,4,# second rlhf reward for three prompts
5,6,7,# third rlhf reward for three prompts
8,9,10,# fourth rlhf reward for three prompts
]).float()# here we have 3 prompts which have 4 completions each
To validate the RLOO implementation works, we ran experiment on the 1B model. Here are the command we used to run the experiment. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
For more information on how to use judges, see [Judges](judges).
The RLOO checkpoint gets a 51.2% preferred rate vs the 33.0% preference rate of the SFT checkpoint. This is a good sign that the RLOO training is working as intended.
The [Reinforce++](https://hijkzzz.notion.site/reinforce-plus-plus) report by Jian Hu suggests several optimization tricks to enhance performance and stability of RLHF. They include:
- Clipping rewards: limiting reward values within a specific range to mitigate the impact of extreme rewards on model updates, thus preventing gradient explosion
- Normalizing rewards: scaling rewards to have a mean of 0 and a standard deviation of 1, which helps in stabilizing the training process
- Normalizing advantages: scaling advantages to have a mean of 0 and a standard deviation of 1, which helps in stabilizing the training process
- Using token-level KL penalty that is defined as equation (1) of the report vs. sequence-level KL penalty (default)
These options are available via the appropriate arguments in the [`RLOOConfig`] class.
| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb) | This script shows how to use the `PPOTrainer` to fine-tune a sentiment analysis model using IMDB dataset |
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook.
## Usage
```bash
# 1. run directly
python examples/scripts/ppo.py
# 2. run via `accelerate` (recommended), enabling more features (e.g., multiple GPUs, deepspeed)
accelerate config # will prompt you to define the training configuration
accelerate launch examples/scripts/ppo.py # launches training
# 3. get help text and documentation
python examples/scripts/ppo.py --help
# 4. configure logging with wandb and, say, mini_batch_size=1 and gradient_accumulation_steps=16
Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
## Few notes on multi-GPU
To run in multi-GPU setup with DDP (distributed Data Parallel) change the `device_map` value to `device_map={"": Accelerator().process_index}` and make sure to run your script with `accelerate launch yourscript.py`. If you want to apply naive pipeline parallelism you can use `device_map="auto"`.
The notebooks and scripts in this examples show how to fine-tune a model with a sentiment classifier (such as `lvwerra/distilbert-imdb`).
Here's an overview of the notebooks and scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples):
| File | Description | Colab link |
|---|---| --- |
| [`gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb) | Fine-tune GPT2 to generate positive movie reviews. | [](https://colab.research.google.com/github/lvwerra/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb)
|
| [`gpt2-sentiment-control.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb) | Fine-tune GPT2 to generate movie reviews with controlled sentiment. | [](https://colab.research.google.com/github/lvwerra/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb)
|
| [`gpt2-sentiment.py`](https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt2-sentiment.py) | Same as the notebook, but easier to use to use in multi-GPU setup. | x |
| [`t5-sentiment.py`](https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/t5-sentiment.py) | Same as GPT2 script, but for a Seq2Seq model (T5). | x |
## Installation
```bash
pip install trl
#optional: wandb
pip install wandb
```
Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
## Launch scripts
The `trl` library is powered by `accelerate`. As such it is best to configure and launch trainings with the following commands:
```bash
accelerate config # will prompt you to define the training configuration
accelerate launch scripts/gpt2-sentiment.py # launches training
Supervised fine-tuning (SFT) is the most common step in post-training foundation models, and also one of the most effective. In TRL, we provide a simple API to train models with SFT in a few lines of code; for a complete training script, check out [`trl/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/trl/scripts/sft.py). Experimental support for Vision Language Models is also included in [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft_vlm.py).
## Quickstart
If you have a dataset hosted on the 🤗 Hub, you can easily fine-tune your SFT model using [`SFTTrainer`] from TRL. Let us assume your dataset is `imdb`, the text you want to predict is inside the `text` field of the dataset, and you want to fine-tune the `facebook/opt-350m` model.
The following code-snippet takes care of all the data pre-processing and training for you:
The above snippets will use the default training arguments from the [`SFTConfig`] class. If you want to modify the defaults pass in your modification to the `SFTConfig` constructor and pass them to the trainer via the `args` argument.
## Advanced usage
### Train on completions only
To train on completions only, simply use a [prompt-completion](#prompt-completion) dataset. In this mode, loss is computed solely on the completion part.
If you’d like to compute loss on both the prompt **and** the completion while still using a prompt-completion dataset, set `completion_only_loss=False` in the [`SFTConfig`]. This is equivalent to [converting the dataset to a language modeling](#from-prompt-completion-to-language-modeling-dataset) format.
### Add Special Tokens for Chat Format
Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
The [`setup_chat_format`] function in `trl` easily sets up a model and tokenizer for conversational AI tasks. This function:
- Adds special tokens to the tokenizer, e.g. `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation.
- Resizes the model’s embedding layer to accommodate the new tokens.
- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI.
- _optionally_ you can pass `resize_to_multiple_of` to resize the embedding layer to a multiple of the `resize_to_multiple_of` argument, e.g. 64. If you want to see more formats being supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl)
> Some base models, like those from Qwen, have a predefined chat template in the model's tokenizer. In these cases it is not necessary to apply `setup_chat_format()`, as the tokenizer already handles the formatting. However, it is necessary to align the EOS token with the chat template to ensure the model's responses terminate correctly. In these cases, specify `eos_token` in `SFTConfig`; for example, for `Qwen/Qwen2.5-1.5B` one should set `eos_token="<|im_end|>"`.
With our model and tokenizer set up, we can now fine-tune our model on a conversational dataset. Below is an example of how a dataset can be formatted for fine-tuning.
### Dataset format support
The [`SFTTrainer`] supports popular dataset formats. This allows you to pass the dataset to the trainer without any pre-processing directly. The following formats are supported:
* conversational format
```json
{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"What's the capital of France?"},{"role":"assistant","content":"..."}]}
{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"Who wrote 'Romeo and Juliet'?"},{"role":"assistant","content":"..."}]}
{"messages":[{"role":"system","content":"You are helpful"},{"role":"user","content":"How far is the Moon from Earth?"},{"role":"assistant","content":"..."}]}
If your dataset uses one of the above formats, you can directly pass it to the trainer without pre-processing. The [`SFTTrainer`] will then format the dataset for you using the defined format from the model's tokenizer with the [apply_chat_template](https://huggingface.co/docs/transformers/main/en/chat_templating#templates-for-chat-models) method.
If the dataset is not in one of those format you can either preprocess the dataset to match the formatting or pass a formatting function to the SFTTrainer to do it for you. Let's have a look.
### Format your input prompts
For instruction fine-tuning, it is quite common to have two columns inside the dataset: one for the prompt & the other for the response.
This allows people to format examples like [Stanford-Alpaca](https://github.com/tatsu-lab/stanford_alpaca) did as follows:
```bash
Below is an instruction ...
### Instruction
{prompt}
### Response:
{completion}
```
Let us assume your dataset has two fields, `question` and `answer`. Therefore you can just run:
To properly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example of how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
### Packing dataset
[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTConfig`] constructor.
```python
...
training_args=SFTConfig(packing=True)
trainer=SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
args=training_args
)
trainer.train()
```
Note that if you use a packed dataset and if you pass `max_steps` in the training arguments you will probably train your models for more than few epochs, depending on the way you have configured the packed dataset and the training protocol. Double check that you know and understand what you are doing.
If you don't want to pack your `eval_dataset`, you can pass `eval_packing=False` to the `SFTConfig` init method.
#### Customize your prompts using packed dataset
If your dataset has several fields that you want to combine, for example if the dataset has `question` and `answer` fields and you want to combine them, you can pass a formatting function to the trainer that will take care of that. For example:
You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTConfig`]. For example, if you want to load a model in a different precision, analogous to
Note that all keyword arguments of `from_pretrained()` are supported.
### Training adapters
We also support tight integration with 🤗 PEFT library so that any user can conveniently train adapters and share them on the Hub instead of training the entire model.
> If the chat template contains special tokens like `<|im_start|>` (ChatML) or `<|eot_id|>` (Llama), the embedding layer and LM head must be included in the trainable parameters via the `modules_to_save` argument. Without this, the fine-tuned model will produce unbounded or nonsense generations. If the chat template doesn't contain special tokens (e.g. Alpaca), then the `modules_to_save` argument can be ignored or set to `None`.
You can also continue training your `PeftModel`. For that, first load a `PeftModel` outside `SFTTrainer` and pass it directly to the trainer without the `peft_config` argument being passed.
### Training adapters with base 8 bit models
For that, you need to first load your 8 bit model outside the Trainer and pass a `PeftConfig` to the trainer. For example:
```python
...
peft_config=LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model=AutoModelForCausalLM.from_pretrained(
"EleutherAI/gpt-neo-125m",
load_in_8bit=True,
device_map="auto",
)
trainer=SFTTrainer(
model,
train_dataset=dataset,
args=SFTConfig(),
peft_config=peft_config,
)
trainer.train()
```
## Using Flash Attention and Flash Attention 2
You can benefit from Flash Attention 1 & 2 using SFTTrainer out of the box with minimal changes of code.
First, to make sure you have all the latest features from transformers, install transformers from source
Note that Flash Attention only works on GPU now and under half-precision regime (when using adapters, base model loaded in half-precision)
Note also both features are perfectly compatible with other tools such as quantization.
### Using Flash-Attention 1
For Flash Attention 1 you can use the `BetterTransformer` API and force-dispatch the API to use Flash Attention kernel. First, install the latest optimum package:
```bash
pip install -U optimum
```
Once you have loaded your model, wrap the `trainer.train()` call under the `with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):` context manager:
```diff
...
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
trainer.train()
```
Note that you cannot train your model using Flash Attention 1 on an arbitrary dataset as `torch.scaled_dot_product_attention` does not support training with padding tokens if you use Flash Attention kernels. Therefore you can only use that feature with `packing=True`. If your dataset contains padding tokens, consider switching to Flash Attention 2 integration.
Below are some numbers you can get in terms of speedup and memory efficiency, using Flash Attention 1, on a single NVIDIA-T4 16GB.
| use_flash_attn_1 | model_name | max_seq_len | batch_size | time per training step |
### Enhance the model's performances using NEFTune
NEFTune is a technique to boost the performance of chat models and was introduced by the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://huggingface.co/papers/2310.05914) from Jain et al. it consists of adding noise to the embedding vectors during training. According to the abstract of the paper:
> Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF such as LLaMA-2-Chat benefit from additional training with NEFTune.
To use it in `SFTTrainer` simply pass `neftune_noise_alpha` when creating your `SFTConfig` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to retrieve back the original behaviour of the embedding layer.
We have tested NEFTune by training `mistralai/Mistral-7B-v0.1` on the [OpenAssistant dataset](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) and validated that using NEFTune led to a performance boost of ~25% on MT Bench.
Note however, that the amount of performance gain is _dataset dependent_ and in particular, applying NEFTune on synthetic datasets like [UltraChat](https://huggingface.co/datasets/stingning/ultrachat) typically produces smaller gains.
### Accelerate fine-tuning 2x using `unsloth`
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks on 1x A100 listed below:
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
```python
importtorch
fromtrlimportSFTConfig,SFTTrainer
fromunslothimportFastLanguageModel
max_length=2048# Supports automatic RoPE Scaling, so choose any number
The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).
## Liger-Kernel: Increase 20% throughput and reduces 60% memory for multi-GPU training
[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduces memory usage by 60%. That way, we can **4x** our context length, as described in the benchmark below. They have implemented Hugging Face Compatible `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed).
With great memory reduction, you can potentially turn off cpu_offloading or gradient checkpointing to further boost the performance.
1. To use Liger-Kernel in [`SFTTrainer`], first install by
```bash
pip install liger-kernel
```
2. Once installed, set `use_liger_kernel` in [`SFTConfig`]. No other changes are needed!
```python
training_args=SFTConfig(
use_liger_kernel=True
)
```
To learn more about Liger-Kernel, visit their [official repository](https://github.com/linkedin/Liger-Kernel/).
## Best practices
Pay attention to the following best practices when training a model with that trainer:
- [`SFTTrainer`] always truncates by default the sequences to the `max_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training.
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
- If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
## Multi-GPU Training
Trainer (and thus SFTTrainer) supports multi-GPU training. If you run your script with `python script.py` it will default to using DP as the strategy, which may be [slower than expected](https://github.com/huggingface/trl/issues/1303). To use DDP (which is generally recommended, see [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many?select-gpu=Accelerate#data-parallelism) for more info) you must launch the script with `python -m torch.distributed.launch script.py` or `accelerate launch script.py`. For DDP to work you must also check the following:
- If you're using gradient_checkpointing, add the following to the TrainingArguments: `gradient_checkpointing_kwargs={'use_reentrant':False}` (more info [here](https://github.com/huggingface/transformers/issues/26969)
- Ensure that the model is placed on the correct device:
```python
fromaccelerateimportPartialState
device_string=PartialState().process_index
model=AutoModelForCausalLM.from_pretrained(
...
device_map={'':device_string}
)
```
## GPTQ Conversion
You may experience some issues with GPTQ Quantization after completing training. Lowering `gradient_accumulation_steps` to `4` will resolve most issues during the quantization process to GPTQ format.
## Extending `SFTTrainer` for Vision Language Models
`SFTTrainer` does not inherently support vision-language data. However, we provide a guide on how to tweak the trainer to support vision-language data. Specifically, you need to use a custom data collator that is compatible with vision-language data. This guide outlines the steps to make these adjustments. For a concrete example, refer to the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) which demonstrates how to fine-tune the LLaVA 1.5 model on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset.
### Preparing the Data
The data format is flexible, provided it is compatible with the custom collator that we will define later. A common approach is to use conversational data. Given that the data includes both text and images, the format needs to be adjusted accordingly. Below is an example of a conversational data format involving both text and images:
```python
images=["obama.png"]
messages=[
{
"role":"user",
"content":[
{"type":"text","text":"Who is this?"},
{"type":"image"}
]
},
{
"role":"assistant",
"content":[
{"type":"text","text":"Barack Obama"}
]
},
{
"role":"user",
"content":[
{"type":"text","text":"What is he famous for?"}
]
},
{
"role":"assistant",
"content":[
{"type":"text","text":"He is the 44th President of the United States."}
]
}
]
```
To illustrate how this data format will be processed using the LLaVA model, you can use the following code:
### A custom collator for processing multi-modal data
Unlike the default behavior of `SFTTrainer`, processing multi-modal data is done on the fly during the data collation process. To do this, you need to define a custom collator that processes both the text and images. This collator must take a list of examples as input (see the previous section for an example of the data format) and return a batch of processed data. Below is an example of such a collator:
```python
defcollate_fn(examples):
# Get the texts and images, and apply the chat template
Now that we have prepared the data and defined the collator, we can proceed with training the model. To ensure that the data is not processed as text-only, we need to set a couple of arguments in the `SFTConfig`, specifically `remove_unused_columns` and `skip_prepare_dataset` to `True` to avoid the default processing of the dataset. Below is an example of how to set up the `SFTTrainer`.
A full example of training LLaVa 1.5 on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset can be found in the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py).
In the SFTTrainer we smartly support `datasets.IterableDataset` in addition to other style datasets. This is useful if you are using large corpora that you do not want to save all to disk. The data will be tokenized and processed on the fly, even when packing is enabled.
Additionally, in the SFTTrainer, we support pre-tokenized datasets if they are `datasets.Dataset` or `datasets.IterableDataset`. In other words, if such a dataset has a column of `input_ids`, no further processing (tokenization or packing) will be done, and the dataset will be used as-is. This can be useful if you have pretokenized your dataset outside of this script and want to re-use it directly.
Supervised fine-tuning (or SFT for short) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them with few lines of code on your dataset.
## Quickstart
If you have a dataset hosted on the 🤗 Hub, you can easily fine-tune your SFT model using [`SFTTrainer`] from TRL. Let us assume your dataset is `imdb`, the text you want to predict is inside the `text` field of the dataset, and you want to fine-tune the `facebook/opt-350m` model.
The following code-snippet takes care of all the data pre-processing and training for you:
```python
from datasets import load_dataset
from trl import SFTTrainer
dataset = load_dataset("imdb", split="train")
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=512,
)
trainer.train()
```
Make sure to pass a correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
You can also construct a model outside of the trainer and pass it as follows:
```python
from transformers import AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTTrainer
dataset = load_dataset("imdb", split="train")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
trainer = SFTTrainer(
model,
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=512,
)
trainer.train()
```
The above snippets will use the default training arguments from the [`transformers.TrainingArguments`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) class. If you want to modify that, make sure to create your own `TrainingArguments` object and pass it to the [`SFTTrainer`] constructor as it is done on the [`supervised_finetuning.py` script](https://github.com/lvwerra/trl/blob/main/examples/stack_llama/scripts/supervised_finetuning.py) on the stack-llama example.
## Advanced usage
### Format your input prompts
For instruction fine-tuning, it is quite common to have two columns inside the dataset: one for the prompt & the other for the response.
This allows people to format examples like [Stanford-Alpaca](https://github.com/tatsu-lab/stanford_alpaca) did as follows:
```bash
Below is an instruction ...
### Instruction
{prompt}
### Response:
{completion}
```
Let us assume your dataset has two fields, `question` and `answer`. Therefore you can just run:
```python
...
def formatting_prompts_func(example):
text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
return text
trainer = SFTTrainer(
model,
train_dataset=dataset,
formatting_func=formatting_prompts_func,
)
trainer.train()
```
### Packing dataset ([`ConstantLengthDataset`])
[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. This is done with the [`ConstantLengthDataset`] utility class that returns constant length chunks of tokens from a stream of examples. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTTrainer`] constructor.
```python
...
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
dataset_text_field="text",
packing=True
)
trainer.train()
```
#### Customize your prompts using packed dataset
If your dataset has several fields that you want to combine, for example if the dataset has `question` and `answer` fields and you want to combine them, you can pass a formatting function to the trainer that will take care of that. For example:
```python
def formatting_func(example):
text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
return text
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
packing=True,
formatting_func=formatting_func
)
trainer.train()
```
You can also customize the [`ConstantLengthDataset`] much more by directly passing the arguments to the [`SFTTrainer`] constructor. Please refer to that class' signature for more information.
### Control over the pretrained model
You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTTrainer`]. For example, if you want to load a model in a different precision, analoguous to
```python
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16)
```python
...
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
dataset_text_field="text",
torch_dtype=torch.bfloat16,
)
trainer.train()
```
Note that all keyword arguments of `from_pretrained()` are supported.
### Training adapters
We also support a tight integration with 🤗 PEFT library so that any user can conveniently train adapters and share them on the Hub instead of training the entire model
```python
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
dataset = load_dataset("imdb", split="train")
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
trainer = SFTTrainer(
"EleutherAI/gpt-neo-125m",
train_dataset=dataset,
dataset_text_field="text",
peft_config=peft_config
)
trainer.train()
```
Note that in case of training adapters, we manually add a saving callback to automatically save the adapters only:
For that you need to first load your 8bit model outside the Trainer and pass a `PeftConfig` to the trainer. For example:
```python
...
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = AutoModelForCausalLM.from_pretrained(
"EleutherAI/gpt-neo-125m",
load_in_8bit=True,
device_map="auto",
)
trainer = SFTTrainer(
model,
train_dataset=dataset,
dataset_text_field="text",
torch_dtype=torch.bfloat16,
peft_config=peft_config,
)
trainer.train()
```
## Best practices
Pay attention to the following best practices when training a model with that trainer:
- [`SFTTrainer`] always pads by default the sequences to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training.
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_int8_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
- If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
Section under construction. Feel free to contribute!
</Tip>
## vLLM for fast generation in online methods
Online methods such as GRPO or Online DPO require the model to generate completions, which is often a slow process and can significantly impact training time.
To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation through, among other things, PagedAttention. TRL's online trainers support vLLM, greatly improving training speed.
To use [vLLM](https://github.com/vllm-project/vllm), first install it using:
```bash
pip install vllm
```
or
```bash
pip install "trl[vllm]"
```
<hfoptionsid="vllm examples">
<hfoptionid="Online DPO">
Then, enable it by passing `use_vllm=True` in the training arguments.
```python
fromtrlimportOnlineDPOConfig
training_args=OnlineDPOConfig(...,use_vllm=True)
```
</hfoption>
<hfoptionid="GRPO">
First, start a vLLM server by running:
```bash
trl vllm-serve --model <model_name>
```
Then, run the training script and pass `use_vllm=True` in the training arguments.
```python
fromtrlimportGRPOConfig
training_args=GRPOConfig(...,use_vllm=True)
```
You can customize the server configuration by passing additional arguments. For more information, see [vLLM integration](vllm_integration).
<Tipwarning={true}>
When using vLLM, ensure that the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation using `CUDA_VISIBLE_DEVICES`.
The script in this example show how to train a reward model for summarization, following the OpenAI Learning to Summarize from Human Feedback [paper](https://arxiv.org/abs/2009.01325). We've validated that the script can be used to train a small GPT2 to get slightly over 60% validation accuracy, which is aligned with results from the paper. The model is [here](https://huggingface.co/Tristan/gpt2_reward_summarization).
Here's an overview of the relevant files in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples):
| File | Description |
|---|---|
| `scripts/reward_summarization.py` | For tuning the reward model. |
| `scripts/ds3_reward_summarization_example_config.json` | Can be used with the reward model script to scale it up to arbitrarily big models that don't fit on a single GPU. |
## Installation
```bash
pip install trl
pip install evaluate
# optional: deepspeed
pip install deepspeed
```
```bash
# If you want your reward model to follow the Learning to Summarize from Human Feedback paper closely, then tune a GPT model on summarization and then instantiate the reward model
# with it. In other words, pass in the name of your summarization-finetuned gpt on the hub, instead of the name of the pretrained gpt2 like we do in the following examples of how
# to run this script.
# Example of running this script with the small size gpt2 on a 40GB A100 (A100's support bf16). Here, the global batch size will be 64:
Text environments provide a learning ground for language agents. It allows a language model to use tools to accomplish a task such as using a Python interpreter to answer math questions or using a search index for trivia questions. Having access to tools allows language models to solve tasks that would be very hard for the models itself but can be trivial for the appropriate tools. A good example is arithmetics of large numbers that become a simple copy-paste task once you have access to a calculator.
Let's dive into how text environments work and start with tools!
## Tools
One of the core building blocks of text environments are tools that the model can use to solve tasks. In general tools can be any Python function that takes a string as input and returns string. The `TextEnvironment` offers two options for tools: either go with predefined tools from `transformers.Tool` or define your own function or class with `__call__` method. Let's have a look at both!
### `transformers.Tool`
Text environments fully support tools of the class `transformers.Tool`. The advantage of building tools in that framework is that they can easily be shared
```Python
fromtransformersimportload_tool
# simple calculator tool that runs +-/* operations
calc_tool=load_tool("ybelkada/simple-calculator")
# python interpreter that executes program and returns outputs
py_tool=load_tool("lvwerra/python-interpreter")
# wikipedia search index that returns best search match
These tools are either loaded from the hub or from a local folder. Using the tool is as simple as calling them with a text query:
```Python
calc_tool("1/2")
>>>"0.5"
```
Note that both input and return values are strings to enable easy usage with a language model.
### Custom Tools
The following is an example of a tool that adds two integers:
```Python
defadd(text):
int_1,int_2=text.split("+")
result=int(int_1)+int(int_2)
returnstr(result)
print(add("1+1"))
>>>"2"
```
We looked at basic examples such as a calculator but the principle holds for more complex tools as well such as a web search tool where you input the query and get the search results in return. Now let's look at how the model can use the tools with the call syntax.
### Call syntax
In order to have a unified way for the model to call a tool we created a simple syntax that looks as follows:
There are a few special tokens involved so let's decompose it: First the model can signal that it wants to use a tool by emitting the `<request>` token. After that we want to know the name of the tool to call which is done by enclosing the tool name with `<>` brackets. Once we know which tool to call the tool query follows which is in free text form. The `<call>` tokens signifies the end of the query and stops the model generation. At this point the model output is parsed and the query sent to the tool. The environment appends the tool response to the string followed by the `<response>` token to show the end the tool output.
Let's look at the concrete example of the calculator and assume its name is `Calculator` (more on how the name of a tool is inferred later):
```python
"<request><Calculator>1/2<call>0.5<response>"
```
Finally, the episode is ended and generation stops when the model generates `<submit>` which marks the interaction as completed.
Now let's have a look how we can create a new text environment!
| `model` | Language model to interact with the environment and generate requests. |
| `tokenizer` | Tokenizer of language model handling tokenization of strings. |
| `tools` | `list` of `dict` of tools. If former the name of the tool is inferred from class name and otherwise it's the keys of the dictionary.|
| `reward_fn` | A function that takes a string as input and returns. Can have extra arguments that are passed to `.run()` such as ground truth.|
| `prompt` | Prompt to prepend to every task. Usually a few examples to demonstrate to the model how to use the tools in a few-shot fashion. |
| `max_turns` | Maximum number of interactions between model and tools before episode ends.|
| `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.|
| `max_length` | The maximum number of tokens to allow in an episode. |
| `generation_kwargs`| Generation settings used by the language model. |
You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools!
## Run an Episode
To run a set of queries through the text environment one can simply use the `run` method.
This will execute the model/tool feedback loop for each query until either no tool is called anymore, the maximum number of turns is reached or to maximum number of tokens in an episode is exceeded. The extra `kwargs` (e.g. `answers=answers` above) passed to `run` will be passed on to the reward function.
There are five objects that are returned by `run`:
-`queries`: a list of the tokenized queries
-`responses`: all tokens that have been generated withing the environment including model and tool tokens
-`masks`: mask that indicates which tokens have been generated by the model and which tokens are generated by the tool
-`rewards`: a list of reward for each query/response
-`histories`: list of `TextHistory` objects, which are useful objects containing all the above and also the text equivalents
The masks are crucial for training as we don't want to optimize tokens that the model has not generated which are tokens produced by the tools.
Next, we'll train a PPO step with the generated responses!
### Train
Training on episodes from the `TextEnvironment` is straight forward and simply requires forwarding all the returned variables except the `TextHistory` objects to the `step` method:
The `TextHistory` object stores the interactions between the model and the text environment. It stores tokens and text generated in each turn and their source in each turn (model or system) as well as rewards. Let's go through the class attributes and methods.
### Attributes
The following table summarises the available attributes of the `TextHistory` class:
| Attribute | Description |
|:-------------------|:----------------|
| `text` | The full string of the text generated in the text environment with both model and system generated text. |
| `text_spans` | A list of tuples with the spans for each model or system generated text segment. |
| `system_spans` | A list of boolean values indicating if the segment is model or system generated. |
| `tokens` | All tokens generated in text environment with both model and system generated tokens. |
| `token_spans` | Similar to `text_spans` the `token_spans` indicate the boundaries of model andsystem generated tokens. |
| `token_masks` | The token masks can be used to ignore system generated tokens by masking them. |
| `completed` | Indicates if the interaction with the environment has completed. |
| `truncated` | Indicates if the interaction with the environment has completed because max length was reached. |
With these attributes you can reconstruct every interaction of the model with the `TextEnvironment`. The `TextHistory` also lets you visualize the text history. Let's have a look!
### Visualization
When the model interacts inside the `TextEnvironment` it can be useful to visualize and separate which parts of the text outputs were generated by the model and which parts come from the system and tools. For that purpose there are the two methods [`TextHistory.show_text`] and [`TextHistory.show_tokens`]. They print the text and tokens respectively and highlight the various segments using the [`rich` library](https://github.com/Textualize/rich) (make sure to install it before using these methods).
You can see that the prompt is highlighted in gray, whereas system segments such as query and tool responses are highlighted in green. All segments generated by the model are highlighted in blue and in addition to the pure text output the reward is displayed as additional text in plum. Here an example of `show_text`:
Sometimes there can be tricky tokenization related issues that are hidden when showing the decoded text. Thus `TextHistory` also offers an option to display the same highlighting on the tokens directly with `show_tokens`:
At TRL we support PPO (Proximal Policy Optimisation) with an implementation that largely follows the structure introduced in the paper "Fine-Tuning Language Models from Human Preferences" by D. Ziegler et al. [[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].
The Trainer and model classes are largely inspired from `transformers.Trainer` and `transformers.AutoModel` classes and adapted for RL.
We also support a `RewardTrainer` that can be used to train a reward model.
# Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset)

## Overview
This guide walks you through the process of fine-tuning a multimodal language model (e.g., **Gemma 3**) using **Supervised Fine-Tuning (SFT)**. We cover two cases:
- **Single Image + Text**
- **Multi-Image + Text**
This guide serves as a **detailed walkthrough** and complements the existing [VLM SFT script](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py). If you're already familiar with the concepts, you can use the script directly.
We demonstrate the fine-tuning process using two datasets, but these principles extend to other **Vision-Language Models (VLMs)** and datasets.
## Understanding the Datasets
To address both **Single Image + Text** and **Multi-Image + Text** scenarios, we use two datasets that are well-suited for this task.
This dataset is a reformatted version of [LLaVA Instruct Mix](https://huggingface.co/datasets/theblackcat102/llava-instruct-mix). It consists of conversations where a user provides both **text** and a **single image** as input.
The model (referred to as the **"assistant"**) responds based on both the **visual and textual information** shared by the user. This dataset is particularly useful for training multimodal models to **understand and generate responses based on images and text**.
The **FanqingM/MMIU-Benchmark** dataset consists of:
- **Context:** Included in the system prompt.
- **Question:** Provided as part of the user's input.
- **Series of Images:** Multiple images related to the question.
- **Answer:** The model's expected response.
This dataset is designed for tasks where the model must reason over multiple images to generate an informed response based on both visual and textual inputs.
Once all dependencies are installed, we need to log in to the **Hugging Face Hub**. Since **Gemma 3** is a gated model, access permissions are required.
If you haven’t requested access yet, visit the [Model Card](https://huggingface.co/google/gemma-3-4b-it) and request it.
To log in, you’ll need to generate an [access token](https://huggingface.co/settings/tokens) from your Hugging Face account.
```bash
huggingface-cli login
```
### **Loading the Data**
As mentioned earlier, we will cover two possible use cases. While the specific procedure may vary based on the dataset, the core principles remain consistent.
This guide supports both use cases, so refer to the **Single Image + Text** or **Multi-Image + Text** sections depending on your specific scenario.
In this case, each sample in a batch consists of a **single image paired with text**. Since the dataset is already formatted for supervised fine-tuning (SFT), we can directly load it using `load_dataset`.
Gemma 3 also supports **Multi-Image + Text** scenarios, where:
- The model receives a **list of images** alongside a user message.
- The model processes **interleaved images and text** within a conversation.
For this dataset, some preprocessing is required before training.
```python
fromdatasetsimportload_dataset
dataset_name="FanqingM/MMIU-Benchmark"
# Load Dataset
dataset=load_dataset(dataset_name)
```
After loading the dataset, we need to preprocess and format it into a conversational structure. Here’s an example of how the data might look:
```python
{"role":"system","content":[{"type":"text","text":"You are a judge in a photography competition, and now you are given the four images. Please examine the details and tell which one of them is most likely to be a real photograph.\nSelect from the following choices.\nA: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]},
{"role":"user","content":images_list+[{"type":"text","text":"Which image is most likely to be a real photograph?"}]},
{"role":"assistant","content":[{"type":"text","text":"A: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]},
With this, your **Multi-Image + Text** dataset is now prepared for training.
### **Preparing for Training**
We start by loading the model and processor. In this example, we use `google/gemma-3-4b-it`, but the same process applies to its other variants and similar models.
To optimize memory usage, we configure `BitsAndBytes` to load the quantized version of the model.
attn_implementation="eager",# Important (Ref: https://github.com/huggingface/transformers/blob/c15a7adb283fa984a40558c7fe7bed30ae975cdd/src/transformers/models/gemma3/modeling_gemma3.py#L934)
quantization_config=bnb_config
)
processor=AutoProcessor.from_pretrained(model_id)
processor.tokenizer.padding_side="right"
```
Next, we set up [Quantized Low-Rank Adaptation (QLoRA)](https://huggingface.co/papers/2305.14314), an efficient fine-tuning technique for Large Language Models (LLMs) and Vision-Language Models (VLMs).
```python
frompeftimportLoraConfig,get_peft_model
# Configure QLoRA
peft_config=LoraConfig(
lora_alpha=16,
lora_dropout=0.05,
r=16,
bias="none",
target_modules="all-linear",
task_type="CAUSAL_LM",
modules_to_save=[
"lm_head",
"embed_tokens",
],
)
```
With QLoRA now set up, we need to define the training arguments for SFT. The [`SFTConfig`] class simplifies this process, providing an easy way to adjust parameters based on our specific needs.
```python
fromtrlimportSFTConfig
training_args=SFTConfig(
output_dir="gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft",# Directory to save the model and push to the Hub. Use a specific repository id (e.g., gemma-3-4b-it-trl-sft-MMIU-Benchmark for multi-image datasets).
num_train_epochs=1,# Set the number of epochs to train the model.
per_device_train_batch_size=8,# Batch size for each device (e.g., GPU) during training. multi-image -> per_device_train_batch_size=1
gradient_accumulation_steps=4,# Number of steps before performing a backward/update pass to accumulate gradients. multi-image -> gradient_accumulation_steps=1
gradient_checkpointing=True,# Enable gradient checkpointing to reduce memory usage during training.
optim="adamw_torch_fused",# Use the fused AdamW optimizer for better performance.
logging_steps=10,# Frequency of logging training progress (log every 10 steps).
save_strategy="epoch",# Save checkpoints at the end of each epoch.
learning_rate=2e-05,# Learning rate for training.
bf16=True,# Enable bfloat16 precision for training to save memory and speed up computations.
push_to_hub=True,# Automatically push the fine-tuned model to Hugging Face Hub after training.
report_to="tensorboard",# Automatically report metrics to tensorboard.
gradient_checkpointing_kwargs={"use_reentrant":False},# Set gradient checkpointing to non-reentrant to avoid issues.
dataset_kwargs={"skip_prepare_dataset":True},# Skip dataset preparation to handle preprocessing manually.
remove_unused_columns=False,# Ensure unused columns are not removed in the collator (important for batch processing).
)
```
The `collate_fn` is responsible for processing and preparing individual examples to form a batch.
Each example in the batch undergoes the following steps:
1. The **chat template** is applied to the text.
2. The **processor tokenizes** both `texts` and `images`, encoding them into tensors.
3. The **labels** for training are set as the `input_ids` of the example.
4. Certain **special tokens** are **masked (ignored)** during loss computation:
-`pad_token_id`
-`<image_token_id>`
-`<image_soft_token>` (corresponding to ID `262144`)
This process is similar across different dataset types, with a minor variation in how images are handled:
- **Single Image + Text** → A **list of images** is directly processed.
- **Multi-Image + Text** → A **list of lists of images** is used, where each batch element contains multiple images.
We save the fine-tuned model to the Hub, making it easily accessible for future use. Additionally, TRL automatically logs the training results to **Weights & Biases (Wandb)** or **TensorBoard**, depending on the chosen configuration.
<!-- Add Wandb training results -->
### Results
During and after trainig, we can inspect the results using **Weights & Biases (Wandb)** or **TensorBoard**. For example:
* [**gemma-3-4b-it-trl-sft-MMIU-Benchmark (Multi-Images+Text or Interleaving)**](https://huggingface.co/sergiopaniego/gemma-3-4b-it-trl-sft-MMIU-Benchmark)
## Limitations
Currently, fine-tuning Gemma has some [known limitations](https://github.com/huggingface/trl/issues/3121). We recommend following the procedure outlined in this guide to ensure the best results.
## References
For further reading and complementary resources, check out the following:
- [Fine-Tuning Vision-Language Models with QLoRA](https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora)
- [Fine-Tuning a Vision Language Model (Qwen2-VL-7B) with the Hugging Face Ecosystem (TRL)](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl)
Once you have trained a model using either the SFTTrainer, PPOTrainer, or DPOTrainer, you will have a fine-tuned model that can be used for text generation. In this section, we'll walk through the process of loading the fine-tuned model and generating text. If you need to run an inference server with the trained model, you can explore libraries such as [`text-generation-inference`](https://github.com/huggingface/text-generation-inference).
## Load and Generate
If you have fine-tuned a model fully, meaning without the use of PEFT you can simply load it like any other language model in transformers. E.g. the value head that was trained during the PPO training is no longer needed and if you load the model with the original transformer class it will be ignored:
You can also merge the adapters into the base model so you can use the model like a normal transformers model, however the checkpoint will be significantly bigger:
Once you have the model loaded and either merged the adapters or keep them separately on top you can run generation as with a normal model outlined above.
@ -19,7 +19,7 @@ Now we can fit very large models into a single GPU, but the training might still
The simplest strategy in this scenario is data parallelism: we replicate the same training setup into separate GPUs and pass different batches to each GPU.
With this, you can parallelize the forward/backward passes of the model and scale with the number of GPUs.
We use either the `transformers.Trainer` or `accelerate`, which both support data parallelism without any code changes, by simply passing arguments when calling the scripts with `torchrun` or `accelerate launch`. The following runs a training script with 8 GPUs on a single machine with `accelerate` and `torchrun`, respectively.
@ -38,12 +38,11 @@ The [StackExchange dataset](https://huggingface.co/datasets/HuggingFaceH4/stack-
There is nothing special about fine-tuning the model before doing RLHF - it’s just the causal language modeling objective from pretraining that we apply here.
To use the data efficiently, we use a technique called packing: instead of having one text per sample in the batch and then padding to either the longest text or the maximal context of the model, we concatenate a lot of texts with a EOS token in between and cut chunks of the context size to fill the batch without any padding.
With this approach the training is much more efficient as each token that is passed through the model is also trained in contrast to padding tokens which are usually masked from the loss.
If you don't have much data and are more concerned about occasionally cutting off some tokens that are overflowing the context you can also use a classical data loader.
The packing is handled by the `ConstantLengthDataset` and we can then use the `Trainer` after loading the model with `peft`. First, we load the model in int8, prepare it for training, and then add the LoRA adapters.
```python
# load model in 8bit
@ -52,7 +51,7 @@ model = AutoModelForCausalLM.from_pretrained(
Exploratory Preference Optimization (XPO) was proposed in the paper [Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF](https://huggingface.co/papers/2405.21046) by Tengyang Xie, Dylan J. Foster, Akshay Krishnamurthy, [Corby Rosset](https://huggingface.co/corbyrosset), [Ahmed Awadallah](https://huggingface.co/AhmedAwadallah), and Alexander Rakhlin. It is a simple online preference tuning method based on the DPO loss together with a reward model (RM). XPO augments the DPO objective with an exploration bonus allowing the method to explore outside the support of the initial model and human feedback data.
The abstract from the paper is the following:
> Reinforcement learning from human feedback (RLHF) has emerged as a central tool for language model alignment. We consider online exploration in RLHF, which exploits interactive access to human or AI feedback by deliberately encouraging the model to produce diverse, maximally informative responses. By allowing RLHF to confidently stray from the pre-trained model, online exploration offers the possibility of novel, potentially super-human capabilities, but its full potential as a paradigm for language model training has yet to be realized, owing to computational and statistical bottlenecks in directly adapting existing reinforcement learning techniques. We propose a new algorithm for online exploration in RLHF, Exploratory Preference Optimization (XPO), which is simple and practical -- a one-line change to (online) Direct Preference Optimization (DPO; Rafailov et al., 2023) -- yet enjoys the strongest known provable guarantees and promising empirical performance. XPO augments the DPO objective with a novel and principled exploration bonus, empowering the algorithm to explore outside the support of the initial model and human feedback data. In theory, we show that XPO is provably sample-efficient and converges to a near-optimal language model policy under natural exploration conditions, irrespective of whether the initial model has good coverage. Our analysis, which builds on the observation that DPO implicitly performs a form of Q*-approximation (or, Bellman error minimization), combines previously disparate techniques from language modeling and theoretical reinforcement learning in a serendipitous fashion through the perspective of KL-regularized Markov decision processes. Empirically, we find that XPO is more sample-efficient than non-exploratory DPO variants in a preliminary evaluation.
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Lewis Tunstall](https://huggingface.co/lewtun).
## Quick start
This example demonstrates how to train a model using the XPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:
Distributed across 8 GPUs, the training takes approximately 1 hour.
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-XPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
The best programming language depends on individual preferences and familiarity with coding concepts. Some popular languages include Python, Java, C++, and JavaScript.
</code></pre>
## Expected dataset type
XPO requires a [prompt-only dataset](dataset_formats#prompt-only). The [`XPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
## Usage tips
### Use a reward model
Instead of a judge, you can chose to use a reward model -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use. Below is a code example showing how to replace a judge with the [trl-lib/Qwen2-0.5B-Reward](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) model:
```diff
- from trl import PairRMJudge
+ from transformers import AutoModelForSequenceClassification
Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training.
</Tip>
### Encourage EOS token generation
When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`XPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`XPOConfig`]:
We provide an example script to train a model using the XPO method. The script is available in [`examples/scripts/xpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/xpo.py)
To test the XPO script with the [Qwen2.5 0.5B model](https://huggingface.co/trl-lib/Qwen/Qwen2.5-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback), run the following command:
```bash
python examples/scripts/xpo.py \
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--judge pair_rm \
--dataset_name trl-lib/ultrafeedback-prompt \
--learning_rate 5.0e-7 \
--logging_steps 25\
--output_dir Qwen2.5-0.5B-XPO-PairRM \
--warmup_ratio 0.1 \
--push_to_hub
```
## Logged metrics
The logged metrics are as follows:
*`loss/xpo`: The mean xpo part of the full loss.
*`loss/dpo`: The mean dpo part of the full loss.
*`objective/kl`: The mean KL divergence between the model and reference data.
*`objective/entropy`: The mean entropy of the model and reference data.
*`objective/model_scores`: The mean scores (according to the reward model) of the model completions.
*`objective/ref_scores`: The mean scores (according to the reward model) of the reference completions.
*`objective/scores_margin`: The mean score margin (according to the external reward model) between the chosen and rejected completions.
*`rewards/chosen`: The mean reward (according to XPO's DPO implicit reward model) of the chosen completions.
*`rewards/rejected`: The mean reward (according to XPO's DPO implicit reward model) of the rejected completions.
*`rewards/accuracies`: The accuracies of the XPO's implicit reward model.
*`rewards/margins`: The mean reward margin (according to online DPO's implicit reward model) between the chosen and rejected completions.
*`logps/chosen`: The mean log probabilities of the chosen completions.
*`logps/rejected`: The mean log probabilities of the rejected completions.
*`val/model_contain_eos_token`: The amount of times the model's output contains the eos token.
*`val/ref_contain_eos_token`: The amount of times the reference's output contains the eos token.
*`alpha`: The weight of the XPO loss term. Typically fixed, but can be made dynamic by passing a list to [`XPOConfig`].
*`beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`XPOConfig`].
_The best place to learn about examples in TRL is our [docs page](https://huggingface.co/docs/trl/index)!_
## Installation
```bash
pip install trl
#optional: wandb
pip install wandb
```
Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks.
You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
## Accelerate Config
For all the examples, you'll need to generate an `Accelerate` config with:
```shell
accelerate config # will prompt you to define the training configuration
```
Then, it is encouraged to launch jobs with `accelerate launch`!
## Categories
The examples are currently split over the following categories:
**1: [Sentiment](https://github.com/lvwerra/trl/tree/main/examples/sentiment)**: Fine-tune a model with a sentiment classification model.
**2: [StackOverflow](https://github.com/lvwerra/trl/tree/main/examples/stack_llama)**: Perform the full RLHF process (fine-tuning, reward model training, and RLHF) on StackOverflow data.
**3: [summarization](https://github.com/lvwerra/trl/tree/main/examples/summarization)**: Recreate OpenAI's [Learning to Summarize paper](https://proceedings.neurips.cc/paper/2020/file/1f89885d556929e98d3ef9b86448f951-Paper.pdf).
**4: [toxicity](https://github.com/lvwerra/trl/tree/main/examples/toxicity)**: Fine-tune a model to reduce the toxicity of its generations.
write about best-of-n as an alternative rlhf
**5: [best-of-n sampling](https://github.com/lvwerra/trl/tree/main/examples/best_of_n_sampling)**: Comparative demonstration of best-of-n sampling as a simpler (but relatively expensive) alternative to RLHF
Please check out https://huggingface.co/docs/trl/example_overview for documentation on our examples.
"**Best-of-n sampling as an alternative to RLHF**\n",
"\n",
"This notebook compares reward-model scores of prompt based responses from \n",
"1. a base model (`gpt2-imdb`)\n",
"2. `RLHF` tuned model based on this base-model \n",
"3. the base-model again from which we sample n responses to each prompt, score them and take the best scored one AKA the `best-of-n sampled` model\n",
Some files were not shown because too many files have changed in this diff
Show More
Reference in New Issue
Block a user
Blocking a user prevents them from interacting with repositories, such as opening or commenting on pull requests or issues. Learn more about blocking a user.