Compare commits

...

84 Commits

Author SHA1 Message Date
a5788ac99b Release: v0.8.4 (#1547) 2024-04-17 17:19:28 +02:00
3bbe7e0407 Fixed ref model not used in PPO generation (#1534) 2024-04-17 07:22:56 -07:00
edf60e826b Update run_sft.sh (#1546) 2024-04-17 16:17:05 +02:00
5d1deb1445 CLI: Set dataset_text_field to None to allow ChatML automatic template (#1545)
* Update cli_utils.py

* Update test_cli.py
2024-04-17 14:45:14 +02:00
476c4b8dc0 [KTO] support to load the adapter twice (#1542)
Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
2024-04-16 17:43:40 +02:00
e823458a6a save_model -> save_pretrained in ppo_trainer.mdx (#1537) 2024-04-15 09:35:03 +02:00
1c0d8bca15 VSFT hotfix - adds gen prompt to template and processor to hub (#1532)
* adds gen prompt to template and processor to hub

* fixes hub model id, removes Path
2024-04-12 20:14:12 +02:00
363369a717 [CPO] fix memory leak due to retained value (#1531) 2024-04-12 15:32:01 +02:00
aba4df02c1 set dev version (#1529) 2024-04-12 12:37:34 +02:00
98226473e4 Release: v0.8.3 (#1528) 2024-04-12 12:22:05 +02:00
87f4c70e60 [CLI] fix imports (#1527) 2024-04-12 12:17:05 +02:00
995f1174da set dev version (#1523) 2024-04-11 15:51:57 +02:00
143e11123d Release: v0.8.2 (#1522) 2024-04-11 15:42:47 +02:00
346c99d222 Adds VLM Training support to SFTTrainer + VSFT script (#1518)
* adds option to skip dataset preparation in SFTTrainer

* before changing the template

* adds support for new schema

* a few fixes to data collator to support new schema

* updates args

* precommit

* adds sys prompt to chat template and other fixes

* updates template, fixes collator for multiple images

* precommit

* rename vsft to vstf_llava

* adding integration tests

* adds integration test for vsft

* precommit

* adds back chat template

* docs

* typo

* adds eval, precommit

* adds peft launch args

* formatting

* fixes no deps tests by checking if PIL lib exists

* Update __init__.py

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-04-11 15:35:59 +02:00
087fe544b0 add data for sfttrainer doc (#1521) 2024-04-11 15:08:43 +02:00
ebbd37ba99 allow pre-tokenized datasets (#1520) 2024-04-11 14:50:39 +02:00
e667550a5a Allow streaming (datasets.IterableDataset) (#1468)
* safe-guard iterabledatasets

* import datasets

* reference the correct IterableDataset

* make pre-commit
2024-04-11 11:11:07 +02:00
57aebe9c36 [ORPO] Update NLL loss to use input_ids instead (#1516)
* Calculate loss on `input_ids` instead of only on response

* Use `concatenated_labels` if `is_encoder_decoder`
2024-04-09 14:10:09 +02:00
85f5fd220d correct metrics (#1514)
Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
2024-04-08 17:09:04 +02:00
4dca169404 use kwarfs for RM (#1515) 2024-04-08 17:05:37 +02:00
f35b68a301 Speed up PPO with ZeRO-3 by 10x 🔥 (#1483)
* Speed up PPO by 10x 🔥

* Revert

* Clean up

* Use relative import

* Clean

* Fix typing for docs
2024-04-08 14:30:44 +02:00
5cf863576a Change the device index to device:index (#1490)
Signed-off-by: yuanwu <yuan.wu@intel.com>
2024-04-08 14:20:42 +02:00
9a28b3fd05 Fix RichProgressCallback (#1496)
* fix RichProgressCallback

* Refine code styling in RichProgressCallback tests
2024-04-04 21:13:54 +02:00
4f8057ad23 [KTO] fix interleaving, reporting, hanging bugs (#1499)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

* remove nans in metrics by gathering across machines

* fix formatting

* fix choice of mismatched examples for KL term

* describe weights

* fix hanging issue in distributed training

* linting

* move metrics to cpu

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* fix tokenization error: lack of bos

* change user warning for weight hyperparams

* minor update to docs

* reshape attention mask

* reformat

* add test for bos/eos tokens

* move dependency location

* Update tests/test_kto_trainer.py

* don't report nan metrics

* don't report nan metrics and remove data interleaving

* fix bugs in calculating metrics

* no need to gather KL term

* minor changes

* use nanmean for losses

* remove disabling of wandb

* revert changes

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-04-03 23:41:12 +02:00
ab0d11d815 Correct ppo_epochs usage (#1480)
* Correct ppo_epochs usage

The usage of ppo_epochs is incorrect here.

In 8534f0edf8/trl/trainer/ppo_config.py (L104C8-L104C58)

`ppo_epochs` is described as "Number of optimisation epochs per batch of samples".

However, here it is used as the usual epoch number, where you do one pass over the training dataset.
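
To make the distinction concrete, a minimal sketch (hypothetical values; only `PPOConfig` and its `ppo_epochs` field are taken from the library) of how the two notions differ:

```python
from trl import PPOConfig

# `ppo_epochs` = optimisation epochs run on each sampled batch inside
# `PPOTrainer.step`, not passes over the whole training dataset.
config = PPOConfig(ppo_epochs=4, batch_size=8, mini_batch_size=2)

# A pass over the dataset is a separate outer loop, e.g.:
# for _ in range(num_dataset_epochs):                    # dataset epochs
#     for batch in dataloader:                           # sampled batches
#         ppo_trainer.step(queries, responses, rewards)  # runs `ppo_epochs` inner epochs
```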

* Update ppo_trainer.mdx

* Update docs/source/ppo_trainer.mdx

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-04-02 12:22:16 +02:00
c674c66a45 Fix DPO Unsloth example (#1494) 2024-04-02 12:16:56 +02:00
45da5df53e use log1p for loss (#1491) 2024-04-02 12:06:54 +02:00
04fd8d9400 Fix typo in how_to_train.md (#1503)
Said "big" where it should say "bug".
2024-04-02 12:05:07 +02:00
bf2aed3876 add dpo link (#1502) 2024-04-02 12:04:34 +02:00
0ee349dcd4 Update KTO example to use better model and ChatML support (#1485)
* Update KTO example

* Tweak params

* Fix values

* Fix LoRA params
2024-03-27 10:47:42 +01:00
7ff6206510 Ignore chat files (#1486)
* Ignore chat files

* Update .gitignore

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Update .gitignore

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2024-03-27 10:44:23 +01:00
e4b20ecbc4 hackey update to ModelConfig to allow lora_target_modules="all-linear" (#1488)
The type hint forces a list, which raises an error that the "all-linear" layer is not found. Forcing a string makes it work. Updating the type hint to `Union[str, list[str]]` also raises a parsing error.
2024-03-27 09:04:41 +01:00
6c2f829bb7 [KTO] Use batching to speed up data processing (#1470)
* Refactor test

* Make batched tokenizer

* Make is FAST 🔥!

* Hack to the max

* Run on main process

* Refactor

* Add unit test

* f

* r

* Refactor

* Remove bs

* Refactor to tokenize once

* Add typing

* Add test for KL getter
2024-03-26 19:46:23 +01:00
c4f0f41935 Update KTO example with good dataset & chat format (#1481)
* Update KTO example with good dataset & chat format

* Add error for chat template
2024-03-25 16:56:43 +01:00
dc6a934269 add missing classes (#1479) 2024-03-24 22:08:28 +01:00
9ce7ac6925 Fix hyperparameters in KTO example (#1474)
* Fix hparams in KTO example

* Clean

* Fix
2024-03-24 14:29:22 +01:00
99553c19ae Add use_cache=False in {ORPO,CPO}Trainer.concatenated_forward (#1478)
* Add `use_cache=False` in `concatenated_forward`

Prevents `ORPOTrainer` from using the cache, as it's not required for computing the logits and runs into conflicts with Flash Attention 2

* Add `use_cache=False` to `concatenated_forward`

Co-authored-by: Kashif Rasul <kashif@users.noreply.github.com>

---------

Co-authored-by: Kashif Rasul <kashif@users.noreply.github.com>
2024-03-24 11:33:20 +01:00
2ce8e45bb2 ORPO trainer (#1435)
* initial orpo skeleton

* typos

* calculate orpo loss

* fix class name

* fix tests

* fix typo

* Update docs/source/orpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/orpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/orpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* rename max_target_length

* Update examples/scripts/orpo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/orpo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/orpo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* more docs

* log log_odds_ratio and log_odds

* average_log_prob as per paper

* added logging section

* add nll_loss

* fix typo

* more verbose

* rename log_odds to log_odds_chosen

* allow datasets to be loaded

* remove dup debug arg

* tokenizer exists

* fix typo

* use trl-internal-testing/hh-rlhf-trl-style dataset

* formatting

* add missing imports

* fix output dir name

* Update examples/scripts/orpo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* move dataset_num_proc to configs

* Update trl/trainer/orpo_config.py

Co-authored-by: Alvaro Bartolome <alvarobartt@gmail.com>

* Update trl/trainer/orpo_trainer.py

Co-authored-by: Alvaro Bartolome <alvarobartt@gmail.com>

* add ORPOTrainer to readme

* fix typo

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Alvaro Bartolome <alvarobartt@gmail.com>
2024-03-22 22:07:11 +01:00
d1df79f83c Add CPOTrainer (#1382)
* add CPOTrainer

* add docs

* fix formatting

* removed precompute_ref_log_probs arg

* remove precompute_ref_log_probs

* typos

* finish cpo trainer doc

* remove redundant lines

* typo

* formatting

* compute chosen nll loss also for enc-dec models

* fix gradient error of inplace operation for enc-dec models

* formatting

* use CPOConfig

* formatting

* use model_init_kwargs from CPOConfig

* comments in example

* fix doc string

* fix typo in docstring

* update year

* fixed typo

* use preference dataset

* fix learning rate

* move dataset_num_proc to configs

* Update cpo paper link from HF: cpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* update description for CPO: cpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* remove _prepare_deepspeed for cpo

Because CPO does not need to initialize a reference model

* Add explanation to CPO loss

* format

* fix bug when lengths are given

* add CPOTrainer to README

* fix grammer

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-03-22 21:32:45 +01:00
d10f7663b0 [peft] Update test_reward_trainer.py to fix tests (#1471)
* [peft] Update test_reward_trainer.py

Since we are requiring peft >= 0.4.0

* formatting
2024-03-22 19:12:54 +01:00
423991c204 Use the standard dataset for DPO CLI (#1456)
* Use the standard dataset

* update docs

* update dpo examples

* fix cli error

* fix CI

* use trl-internal-testing/hh-rlhf-trl-style
2024-03-20 13:14:08 -04:00
988d4c4e1a set dev version (#1463) 2024-03-20 12:30:48 +01:00
8534f0edf8 Release: v0.8.1 (#1462) 2024-03-20 11:32:06 +01:00
5095e7f948 add eos token to generate (#1459) 2024-03-20 10:30:27 +01:00
9fcf61d706 Fix chat CLI for model revisions (#1458)
* Fix chat CLI for model revisions

* Clean
2024-03-20 09:35:34 +01:00
66b043a910 set dev version (#1454) 2024-03-19 17:30:48 +01:00
f2c71771cc Release: v0.8.0 (#1453)
* Release: v0.7.12

* 0.8.0 instead
2024-03-19 17:19:38 +01:00
631c33cbb3 FEAT: Update README to add DPO + CLIs (#1448)
* Update README.md

* Update README.md

* move dpo/ppo description to docs

* rework readme

* Update README.md

---------

Co-authored-by: leandro <leandro.vonwerra@spoud.io>
2024-03-19 16:55:56 +01:00
3f7ff60528 model --> model_name_or_path (#1452)
* `model` --> `model_name_or_path`

* fix style
2024-03-19 16:52:42 +01:00
1705aebeba Fix yaml parsing issue (#1450) 2024-03-19 16:07:50 +01:00
4e622a9033 chat cli (#1431)
* first draft

* move chat to cli

* fix makefile

* make script less verbose

* fix parsing

* fix style

* add more examples

* fix setup.py

* add copyright

* fix verbose init

* attribute FastChat

* add docs
2024-03-19 12:37:06 +01:00
eb2d5b2972 CI / CLI: Properly raise error when CLI tests failed (#1446)
* properly raise error

* another fix

* Update tests.yml

* Update tests-main.yml
2024-03-19 11:39:07 +01:00
f976c6d234 Before update the tr_loss, make sure tr_loss_step is in the same device. (#1439)
* before update the loss from dpo, make sure it's in the same device of tr_loss

* Update trl/trainer/dpo_trainer.py

Co-authored-by: guy1992l <83535508+guy1992l@users.noreply.github.com>

---------

Co-authored-by: guy1992l <83535508+guy1992l@users.noreply.github.com>
2024-03-19 10:28:44 +01:00
abc7301bab Fix PPOTrainer README example (#1441)
* Fix example

* Delete newline
2024-03-19 10:18:49 +01:00
6cfa5cfc81 fix doc build on main (#1437) 2024-03-18 14:24:02 +01:00
a2aa0f0b09 FEAT: Add CLIs in TRL ! (#1419)
* CLI V1

* v1 CLI

* add rich enhancmeents

* revert unindented change

* some comments

* cleaner CLI

* fix

* fix

* remove print callback

* move to cli instead of trl_cli

* revert unneeded changes

* fix test

* Update trl/commands/sft.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* remove redundant strings

* fix import issue

* fix other issues

* add packing

* add config parser

* some refactor

* cleaner

* add example config yaml file

* small refactor

* change a bit the logic

* fix issues here and there

* add CLI in docs

* move to examples/sft

* remove redundant licenses

* make it work on dpo

* set to None

* switch to accelerate and fix many things

* add docs

* more docs

* added tests

* doc clarification

* more docs

* fix CI for windows and python 3.8

* fix

* attempt to fix CI

* fix?

* test

* fix

* tweak?

* fix

* test

* another test

* fix

* test

* fix

* fix

* fix

* skip tests for windows

* test @lvwerra approach

* make dev

* revert unneeded changes

* fix sft dpo

* optimize a bit

* address final comments

* update docs

* final comment

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2024-03-18 12:20:54 +01:00
304e208f77 Create standard dataset for TRL (#1424)
* add scripts to create standard dataset

* precommit

* push changes

* add script to play with
2024-03-14 10:57:48 -04:00
4fe8b027f6 [Kto] torch_dtype kwargs fix (#1429)
* set torch_dtype from string type

* fix test
2024-03-14 13:49:44 +01:00
fb6ebb1e11 [KTO] fix tokenization bugs (#1418)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

* remove nans in metrics by gathering across machines

* fix formatting

* fix choice of mismatched examples for KL term

* describe weights

* fix hanging issue in distributed training

* linting

* move metrics to cpu

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* fix tokenization error: lack of bos

* change user warning for weight hyperparams

* minor update to docs

* reshape attention mask

* reformat

* add test for bos/eos tokens

* move dependency location

* Update tests/test_kto_trainer.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-03-14 08:22:50 +01:00
66078c7c01 CI: Fix CI on main (#1422)
* fix CI on main

* final fix
2024-03-13 13:54:22 +01:00
58c0888996 Add support for FSDP+QLoRA and DeepSpeed ZeRO3+QLoRA (#1416)
* don't do mp casting

* don't use `prepare_for_kbit` when using fsdp+qlora or dsz3+qlora

* changes to enable fsdp+qlora and dsz3+qlora

* revert

* Update sft_trainer.py

* quality

* fix deprecation using changes from PR https://github.com/huggingface/trl/pull/1415

* fixes

* quality

* Update trl/trainer/sft_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* quality

* relaunch tests

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-03-13 10:43:45 +01:00
486e7a4071 model init when args are given (#1413)
Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
2024-03-11 13:47:37 +01:00
7630f877f9 Fix import error from deprecation in transformers (#1415)
* Fix import error from  deprecation in transformers

* Fix import path
2024-03-11 13:23:56 +01:00
4d862da181 [KTO] fix various bugs (#1402)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

* remove nans in metrics by gathering across machines

* fix formatting

* fix choice of mismatched examples for KL term

* describe weights

* fix hanging issue in distributed training

* linting

* move metrics to cpu

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-03-08 12:04:52 +01:00
22b4f548f4 fix RM script (#1393) 2024-03-07 08:49:52 +01:00
4219cbfedc Fix the pad_token_id error (#1394)
* Fix the pad_token_id error

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Add the load_in_8bit argument in rl_training.py

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Reformate the patch

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Fix the check failed

Signed-off-by: yuanwu <yuan.wu@intel.com>

---------

Signed-off-by: yuanwu <yuan.wu@intel.com>
2024-03-05 02:18:42 +01:00
3bd02380c7 Log ddpo reward as float to fix numpy conversion during bf16 training (#1391) 2024-03-04 02:50:50 +01:00
067db7553a [KTO] prevent nans from appearing in metrics (#1386)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

* remove nans in metrics by gathering across machines

* fix formatting

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-03-01 12:19:55 +01:00
93e85ed808 [KTO] merge eval dataset only if it exists (#1383)
* merge eval dataset if it exists

* add eval dataset test
2024-03-01 12:15:14 +01:00
14e0d78807 fix bugs in KTO implementation (#1380)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-02-29 09:01:52 +01:00
b32656f726 FIX: Fix the CI again .. (#1374)
* Update tests-main.yml

* Update tests-main.yml

* Update tests-main.yml

* Update .github/workflows/tests-main.yml

* Update tests-main.yml

* Update tests-main.yml
2024-02-27 12:46:20 +01:00
9399bc113b Update tests-main.yml (#1373) 2024-02-27 12:07:50 +01:00
11f122ad49 Update tests-main.yml (#1372) 2024-02-27 11:45:02 +01:00
009c9a610b feature request add force_use_ref_model (#1367) 2024-02-27 11:19:16 +01:00
7712d42f8c add eval_packing (#1369) 2024-02-27 11:19:06 +01:00
7c2213b9e5 add ci message sending on TRL (#1370) 2024-02-27 11:18:55 +01:00
ddeebce176 Add some arguments for support XPU (#1366)
* Add use_bnb and load_in_4bit arguments.

Make it optional and not supported on all platforms

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Change the use_reentrant default value to False

If the default value of gradient_checkpointing is True, set the
use_reentrant default value to False, because the following error
happens:

RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 191 with name base_model.model.model.layers.31.self_attn.v_proj.lora_B.default.weight has been marked as ready twice. This means that multiple autograd engine  hooks have fired for this particular parameter during this iteration.

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Add model_dtype for loading the model in model_dtype

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Reformate the patch

Signed-off-by: yuanwu <yuan.wu@intel.com>

---------

Signed-off-by: yuanwu <yuan.wu@intel.com>
2024-02-27 02:49:16 +01:00
cf68d871cf Fix version for Python<3.8 (#1363) 2024-02-27 02:41:09 +01:00
2a2676e7ec set seed in sft/dpo/reward_modeling to make result reproducable (#1357)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-02-23 11:12:45 +01:00
ca90cba351 fix 8-bit multi-gpu training bug (#1353)
* fix 8-bit multi-gpu training bug see https://github.com/huggingface/trl/issues/1348

* Update dpo_llama2.py

make gradient_checkpointing_kwargs configurable.

* Update dpo_llama2.py

remove unnecessary config of device_map

* format with make precommit

---------

Co-authored-by: ubuntu <lili@liveremier.ai>
2024-02-23 03:58:43 +01:00
4f97fb4a74 more userfriendly (#1350) 2024-02-22 10:06:35 +01:00
a46cd84a64 Kto trainer (#1181)
* initial file

* initial tokenizer

* UnpairedPreferenceBatchSampler

* use batch_sampler

* use interleave_datasets

* add loss

* fix imports

* use SequentialSampler when training

* formatting

* add other helpers

* add prediction_step

* fix the kto pair docs

* tests

* compute_reference_log_probs

* add get_eval_dataloader

* fix typo

* kto with is_encoder_decoder true

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* fixed typo

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* renamed KTO dataset keys

* use DPOTrainer's get_batch_logps

* add get_batch_samples

* typo

* Handle last token in prompt

* Create KTOConfig class that subclasses transformers.TrainingArguments

* Update KTO tests to handle KTOConfig

* Update KTO script to use KTOConfig

* formatting

* Update docs/source/dpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/training_configs.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* use max_completion_length

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* add back get_batch_logps

* use max_completion_length

* move config to its own file

* Check tokenize params on Trainer init

* Clone labels for end-dec model to solve RuntimeError

* formatting

* fix enc-dec later

* completion_decoder_input_ids is optional for enc-dec

* fix breaking test

* add a kl key for KL estimation with shuffled completion

* add loss ad weights

* fix bug in chosen_idx

* add back metrics

* fix typos

* fix kto_loss docs

* typo

* set loss to None when there is no target completions in batch

* use nan tensor instead of none

* fix reference_logps test

* fix logits

* a bit more robust options

* log only the correct prompt-completion during eval

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/dpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* add docs for desirable_weight and undesirable_weight args

* dropout is always disabled

* remove DDP hack

* formatting

* move more arguments of trainer to config

* comment out T5 test for now

* Add docstring to KTOTrainer

* moved Config docstrings to the appropriate class

* add autodoc to markdown

* formatting

* updated copyright year

* add model tags

* do not add BOS to start of completion

* Move data_collator to KTOTrainer

* formatting

* data_collator is not in args

* shuffle_completion with specific input_columns

* remove all but the needed columns

* Update docs/source/dpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update tests/test_kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* moved more args to kto_config

* fjx test

* use all_exhausted strategy and shuffle after

* use KTOConfig in HfArgumentParser

* use ModelConfig

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Pablo Vicente Juan <p.vicente.juan@gmail.com>
2024-02-19 14:43:17 +01:00
1f56bffdf8 Update Example to reflect #aa35fec (#1333) 2024-02-18 14:10:04 +01:00
1bfe0b8fcb set dev version (#1332) 2024-02-16 09:49:05 +01:00
77 changed files with 7907 additions and 401 deletions

View File

@ -10,6 +10,9 @@ concurrency:
group: docker-image-builds
cancel-in-progress: false
env:
CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}
jobs:
trl-latest:
name: "Latest TRL GPU"
@ -42,6 +45,31 @@ jobs:
push: true
tags: huggingface/trl-latest-gpu
- name: Post to a Slack channel
id: slack
#uses: slackapi/slack-github-action@v1.25.0
uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
with:
# Slack channel id, channel name, or user id to post message.
# See also: https://api.slack.com/methods/chat.postMessage#channels
channel-id: ${{ env.CI_SLACK_CHANNEL }}
# For posting a rich message using Block Kit
payload: |
{
"text": "trl-latest-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
"blocks": [
{
"type": "section",
"text": {
"type": "mrkdwn",
"text": "trl-latest-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
}
}
]
}
env:
SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
trl-source:
name: "Latest TRL + HF ecosystem from source"
runs-on: ubuntu-latest
@ -71,4 +99,29 @@ jobs:
with:
context: ./docker/trl-source-gpu
push: true
tags: huggingface/trl-source-gpu
- name: Post to a Slack channel
id: slack
#uses: slackapi/slack-github-action@v1.25.0
uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
with:
# Slack channel id, channel name, or user id to post message.
# See also: https://api.slack.com/methods/chat.postMessage#channels
channel-id: ${{ env.CI_SLACK_CHANNEL }}
# For posting a rich message using Block Kit
payload: |
{
"text": "trl-source-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
"blocks": [
{
"type": "section",
"text": {
"type": "mrkdwn",
"text": "trl-source-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
}
}
]
}
env:
SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

View File

@ -4,12 +4,16 @@ on:
push:
branches: [ main ]
env:
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
jobs:
tests:
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11']
os: ['ubuntu-latest', 'windows-latest']
fail-fast: false
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
@ -28,7 +32,32 @@ jobs:
pip install -U git+https://github.com/huggingface/peft.git
pip install -U git+https://github.com/huggingface/transformers.git
# cpu version of pytorch
pip install -e ".[test, diffusers]"
pip install ".[test, diffusers]"
- name: Test with pytest
run: |
make test
- name: Post to a Slack channel
if: always()
id: slack
#uses: slackapi/slack-github-action@v1.25.0
uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
with:
# Slack channel id, channel name, or user id to post message.
# See also: https://api.slack.com/methods/chat.postMessage#channels
channel-id: ${{ env.CI_SLACK_CHANNEL }}
# For posting a rich message using Block Kit
payload: |
{
"text": "TRL CI on transformers/PEFT main: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
"blocks": [
{
"type": "section",
"text": {
"type": "mrkdwn",
"text": "TRL CI on transformers/PEFT main: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
}
}
]
}
env:
SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

View File

@ -54,7 +54,7 @@ jobs:
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install -e ".[test, peft, diffusers]"
pip install ".[test, peft, diffusers]"
- name: Test with pytest
run: |
make test

5
.gitignore vendored
View File

@ -143,4 +143,7 @@ checklink/cookies.txt
# wandb files
nbs/wandb/
examples/notebooks/wandb/
wandb/
# cli scripts that are symlinked from `examples/scripts`
trl/commands/scripts/

View File

@ -5,7 +5,7 @@
Before you start contributing make sure you installed all the dev tools:
```bash
pip install -e ".[dev]"
make dev
```
## Did you find a bug?

View File

@ -2,4 +2,4 @@ include settings.ini
include LICENSE
include CONTRIBUTING.md
include README.md
recursive-exclude * __pycache__

View File

@ -5,6 +5,12 @@ check_dirs := examples tests trl
ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
COMMAND_FILES_PATH = `pwd`/commands
dev:
	[ -L "$(pwd)/trl/commands/scripts" ] && unlink "$(pwd)/trl/commands/scripts" || true
	pip install -e ".[dev]"
	ln -s `pwd`/examples/scripts/ `pwd`/trl/commands

test:
	python -m pytest -n auto --dist=loadfile -s -v ./tests/

130
README.md
View File

@ -3,7 +3,7 @@
</div>
# TRL - Transformer Reinforcement Learning
> Full stack transformer language models with reinforcement learning.
> Full stack library to fine-tune and align large language models.
<p align="center">
<a href="https://github.com/huggingface/trl/blob/main/LICENSE">
@ -20,61 +20,73 @@
## What is it?
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/TRL-readme.png">
</div>
The `trl` library is a full stack tool to fine-tune and align transformer language and diffusion models using methods such as Supervised Fine-tuning (SFT), Reward Modeling (RM), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO).
`trl` is a full stack library where we provide a set of tools to train transformer language models and stable diffusion models with Reinforcement Learning, from the Supervised Fine-tuning (SFT) and Reward Modeling (RM) steps to the Proximal Policy Optimization (PPO) step. The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via `transformers`. At this point, most decoder and encoder-decoder architectures are supported. Refer to the documentation or the `examples/` folder for example code snippets and how to run these tools.
**Highlights:**
- [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer): A light and friendly wrapper around `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
- [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer): A light wrapper around `transformers` Trainer to easily fine-tune language models for human preferences (Reward Modeling).
- [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer): A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.
- [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead): A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
- [Examples](https://github.com/huggingface/trl/tree/main/examples): Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [Stack-Llama example](https://huggingface.co/blog/stackllama), etc.
## How PPO works
Fine-tuning a language model via PPO consists of roughly three steps:
1. **Rollout**: The language model generates a response or continuation based on query which could be the start of a sentence.
2. **Evaluation**: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.
This process is illustrated in the sketch below:
The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library and thus allows you to use any model architecture available there.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png" width="800">
<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
</div>
## Highlights
- **`Efficient and scalable`**:
- [`accelerate`](https://github.com/huggingface/accelerate) is the backbone of `trl`, which allows you to scale model training from a single GPU to a large multi-node cluster with methods such as DDP and DeepSpeed.
- [`PEFT`](https://github.com/huggingface/peft) is fully integrated and allows you to train even the largest models on modest hardware with quantisation and methods such as LoRA or QLoRA.
- [`unsloth`](https://github.com/unslothai/unsloth) is also integrated and allows you to significantly speed up training with dedicated kernels.
- **`CLI`**: With the [CLI](https://huggingface.co/docs/trl/clis) you can fine-tune and chat with LLMs without writing any code using a single command and a flexible config system.
- **`Trainers`**: The Trainer classes are an abstraction to apply many fine-tuning methods with ease such as the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.DPOTrainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer), [`CPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.CPOTrainer), and [`ORPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.ORPOTrainer).
- **`AutoModels`**: The [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead) classes add an additional value head to the model, which allows you to train them with RL algorithms such as PPO.
- **`Examples`**: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [StackLlama example](https://huggingface.co/blog/stackllama), etc. following the [examples](https://github.com/huggingface/trl/tree/main/examples).
## Installation
### Python package
Install the library with pip:
Install the library with `pip`:
```bash
pip install trl
```
### From source
If you want to run the examples in the repository a few additional libraries are required. Clone the repository and install it with pip:
If you want to use the latest features before an official release you can install from source:
```bash
git clone https://github.com/huggingface/trl.git
cd trl/
pip install .
pip install git+https://github.com/huggingface/trl.git
```
If you wish to develop TRL, you should install in editable mode:
### Repository
If you want to use the examples you can clone the repository with the following command:
```bash
pip install -e .
git clone https://github.com/huggingface/trl.git
```
## Command Line Interface (CLI)
You can use the TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT) and Direct Preference Optimization (DPO), and to test your aligned model with the chat CLI:
**SFT:**
```bash
trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir opt-sft-imdb
```
**DPO:**
```bash
trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/hh-rlhf-trl-style --output_dir opt-sft-hh-rlhf
```
**Chat:**
```bash
trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
```
Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.
## How to use
For more flexibility and control over the training, you can use the dedicated trainer classes to fine-tune the model in Python.
### `SFTTrainer`
This is a basic example on how to use the `SFTTrainer` from the library. The `SFTTrainer` is a light wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
This is a basic example of how to use the `SFTTrainer` from the library. The `SFTTrainer` is a light wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
```python
# imports
@ -98,7 +110,7 @@ trainer.train()
### `RewardTrainer`
This is a basic example on how to use the `RewardTrainer` from the library. The `RewardTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset.
This is a basic example of how to use the `RewardTrainer` from the library. The `RewardTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset.
```python
# imports
@ -124,7 +136,7 @@ trainer.train()
### `PPOTrainer`
This is a basic example on how to use the `PPOTrainer` from the library. Based on a query the language model creates a response which is then evaluated. The evaluation could be a human in the loop or another model's output.
This is a basic example of how to use the `PPOTrainer` from the library. Based on a query the language model creates a response which is then evaluated. The evaluation could be a human in the loop or another model's output.
```python
# imports
@ -138,11 +150,10 @@ model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = create_reference_model(model)
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
# initialize trainer
ppo_config = PPOConfig(
batch_size=1,
)
ppo_config = PPOConfig(batch_size=1, mini_batch_size=1)
# encode a query
query_txt = "This morning I went to the "
@ -162,13 +173,50 @@ reward = [torch.tensor(1.0)]
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```
### `DPOTrainer`
`DPOTrainer` is a trainer that uses the [Direct Preference Optimization algorithm](https://arxiv.org/abs/2305.18290). This is a basic example of how to use the `DPOTrainer` from the library. The `DPOTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom preference dataset.
```python
# imports
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOTrainer
# load model and dataset - dataset needs to be in a specific format
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
...
# load trainer
trainer = DPOTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
)
# train
trainer.train()
```
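
For context, a minimal sketch of what the elided dataset step could look like, assuming the `datasets` library and a preference dataset whose rows provide `prompt`, `chosen`, and `rejected` fields (the dataset name below is purely illustrative):

```python
from datasets import load_dataset

# Illustrative only: any preference dataset exposing "prompt", "chosen"
# and "rejected" columns can be plugged into the example above.
dataset = load_dataset("your-org/your-preference-dataset", split="train")
```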
## Development
If you want to contribute to `trl` or customize it to your needs, make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and do a dev install:
```bash
git clone https://github.com/huggingface/trl.git
cd trl/
make dev
```
## References
### Proximal Policy Optimisation
The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].
### Language models
The language models utilize the `transformers` library by 🤗 Hugging Face. Currently, `trl` only supports `transformers` models **for text**.
### Direct Preference Optimization
DPO is based on the original implementation of **"Direct Preference Optimization: Your Language Model is Secretly a Reward Model"** by E. Mitchell et al. \[[paper](https://arxiv.org/pdf/2305.18290.pdf), [code](https://github.com/eric-mitchell/direct-preference-optimization)]
## Citation

View File

@ -3,6 +3,7 @@
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_dpo/"
MODEL_NAME="HuggingFaceM4/tiny-random-LlamaForCausalLM"
DATASET_NAME="trl-internal-testing/hh-rlhf-trl-style"
MAX_STEPS=5
BATCH_SIZE=2
SEQ_LEN=128
@ -36,6 +37,7 @@ accelerate launch $EXTRA_ACCELERATE_ARGS \
--mixed_precision 'fp16' \
`pwd`/examples/scripts/dpo.py \
--model_name_or_path $MODEL_NAME \
--dataset_name $DATASET_NAME \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BATCH_SIZE \

View File

@ -41,6 +41,7 @@ accelerate launch $EXTRA_ACCELERATE_ARGS \
--dataset_name $DATASET_NAME \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--dataset_text_field 'text' \
--per_device_train_batch_size $BATCH_SIZE \
--max_seq_length $SEQ_LEN \
$EXTRA_TRAINING_ARGS
@ -56,4 +57,4 @@ echo "Starting program..."
echo "Operation Failed!"
exit 1
}
exit 0

View File

@ -5,6 +5,8 @@
title: Quickstart
- local: installation
title: Installation
- local: clis
title: Get started with Command Line Interfaces (CLIs)
- local: how_to_train
title: PPO Training FAQ
- local: use_model
@ -29,8 +31,14 @@
title: Best of N Sampling
- local: dpo_trainer
title: DPO Trainer
- local: kto_trainer
title: KTO Trainer
- local: cpo_trainer
title: CPO Trainer
- local: ddpo_trainer
title: Denoising Diffusion Policy Optimization
- local: orpo_trainer
title: ORPO Trainer
- local: iterative_sft_trainer
title: Iterative Supervised Fine-Tuning
- local: text_environments

119
docs/source/clis.mdx Normal file
View File

@ -0,0 +1,119 @@
# Command Line Interfaces (CLIs)
You can use TRL to fine-tune your language model with Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO), or even chat with your model, using the TRL CLIs.
Currently supported CLIs are:
- `trl sft`: fine-tune an LLM on a text/instruction dataset
- `trl dpo`: fine-tune an LLM with DPO on a preference dataset
- `trl chat`: quickly spin up an LLM fine-tuned for chatting
## Fine-tuning with the CLI
Before getting started, pick a language model from the Hugging Face Hub. Supported models can be found with the "text-generation" filter on the models page. Also make sure to pick a relevant dataset for your task.
Before using the `sft` or `dpo` commands make sure to run:
```bash
accelerate config
```
and pick the right configuration for your training setup (single / multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of `accelerate config` before running any CLI command.
We also recommend passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with the `trl sft` command.
```yaml
model_name_or_path:
  HuggingFaceM4/tiny-random-LlamaForCausalLM
dataset_name:
  imdb
dataset_text_field:
  text
report_to:
  none
learning_rate:
  0.0001
lr_scheduler_type:
  cosine
```
Save that config in a `.yaml` file and get started right away! Note that you can override the arguments from the config file by explicitly passing them to the CLI, e.g.:
```bash
trl sft --config example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts
```
This will force the use of `cosine_with_restarts` for `lr_scheduler_type`.
### Supported Arguments
We support all arguments from `transformers.TrainingArguments`; for loading your model, we support all arguments from `~trl.ModelConfig`:
[[autodoc]] ModelConfig
You can pass any of these arguments either to the CLI or the YAML file.
### Supervised Fine-tuning (SFT)
Follow the basic instructions above and run `trl sft --output_dir <output_dir> <*args>`:
```bash
trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir opt-sft-imdb
```
The SFT CLI is based on the `examples/scripts/sft.py` script.
### Direct Preference Optimization (DPO)
To use the DPO CLI, you need to have a dataset in the TRL format such as
* TRL's Anthropic HH dataset: https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-trl-style
* TRL's OpenAI TL;DR summarization dataset: https://huggingface.co/datasets/trl-internal-testing/tldr-preference-trl-style
These datasets always have at least three columns `prompt, chosen, rejected`:
* `prompt` is a list of strings.
* `chosen` is the chosen response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
* `rejected` is the rejected response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
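
As an illustration, a minimal sketch of loading and inspecting one of these datasets with the `datasets` library (the split name is assumed; the row contents follow the column description above):

```python
from datasets import load_dataset

dataset = load_dataset("trl-internal-testing/hh-rlhf-trl-style", split="train")

row = dataset[0]
print(row["prompt"])    # prompt / conversation context
print(row["chosen"])    # preferred response in chat format
print(row["rejected"])  # rejected response in chat format
```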
To do a quick start, you can run the following command:
```bash
trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-trl-style
```
The DPO CLI is based on the `examples/scripts/dpo.py` script.
#### Custom preference dataset
Format the dataset into TRL format (you can adapt the `examples/datasets/anthropic_hh.py`):
```bash
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
```
## Chat interface
The chat CLI lets you quickly load the model and talk to it. Simply run the following:
```bash
trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
```
> [!TIP]
> To use the chat CLI with the developer installation, you must run `make dev`
>
Note that the chat interface relies on the tokenizer's [chat template](https://huggingface.co/docs/transformers/chat_templating) to format the inputs for the model. Make sure your tokenizer has a chat template defined.
Besides talking to the model, there are a few commands you can use:
- **clear**: clears the current conversation and starts a new one
- **example {NAME}**: loads the example named `{NAME}` from the config and uses it as the user input
- **set {SETTING_NAME}={SETTING_VALUE};**: changes the system prompt or generation settings (multiple settings are separated by a ';').
- **reset**: same as clear, but also resets the generation configs to their defaults if they have been changed by **set**
- **save {SAVE_NAME} (optional)**: saves the current chat and settings to a file, by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml`, or to `{SAVE_NAME}` if provided
- **exit**: closes the interface
The default examples are defined in `examples/scripts/config/default_chat_config.yaml` but you can pass your own with `--config CONFIG_FILE` where you can also specify the default generation parameters.

102
docs/source/cpo_trainer.mdx Normal file
View File

@ -0,0 +1,102 @@
# CPO Trainer
Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by Haoran Xu, Amr Sharaf, Yunmo Chen, Weiting Tan, Lingfeng Shen, Benjamin Van Durme, Kenton Murray, and Young Jin Kim. At a high-level, CPO trains models to
avoid generating adequate, but not perfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation to the DPO loss and can be applied to other domains like chat.
CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT's methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Second, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.
## Expected dataset format
The CPO trainer expects a format identical to the DPO trainer, which should include three entries. These entries should be named as follows:
- `prompt`
- `chosen`
- `rejected`
for example:
```py
cpo_dataset_dict = {
"prompt": [
"hello",
"how are you",
"What is your name?",
"What is your name?",
"Which is the best programming language?",
"Which is the best programming language?",
"Which is the best programming language?",
],
"chosen": [
"hi nice to meet you",
"I am fine",
"My name is Mary",
"My name is Mary",
"Python",
"Python",
"Java",
],
"rejected": [
"leave me alone",
"I am not fine",
"Whats it to you?",
"I dont have a name",
"Javascript",
"C++",
"C++",
],
}
```
where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. As can be seen a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
## Expected model format
The CPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
## Using the `CPOTrainer`
For a detailed example have a look at the `examples/scripts/cpo.py` script. At a high level we need to initialize the `CPOTrainer` with a `model` we wish to train. **Note that CPOTrainer eliminates the need to use the reference model, simplifying the optimization process.** The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above.
```py
cpo_config = CPOConfig(
beta=0.1,
)
cpo_trainer = CPOTrainer(
model,
args=cpo_config,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
```
After this one can then call:
```py
cpo_trainer.train()
```
## Loss functions
Given the preference data, the `CPOTrainer` uses the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression.
The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://arxiv.org/abs/2305.10425) paper. The `CPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument and the `beta` in this case is the reciprocal of the margin.
The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the CPO algorithms, identify an issue with overfitting, and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer. Note that the `beta` parameter is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gap is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike CPO which is summed only).
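As a rough illustration of these three variants, here is a simplified sketch of the preference term only (it omits the additional NLL term that `CPOTrainer` adds on the chosen responses), assuming per-sequence log-probabilities from the policy:
```py
import torch
import torch.nn.functional as F

def cpo_preference_loss(chosen_logps, rejected_logps, beta=0.1, loss_type="sigmoid"):
    logits = chosen_logps - rejected_logps  # no reference model is involved
    if loss_type == "sigmoid":
        return -F.logsigmoid(beta * logits).mean()
    elif loss_type == "hinge":
        return torch.relu(1 - beta * logits).mean()
    elif loss_type == "ipo":
        return ((logits - 1 / (2 * beta)) ** 2).mean()
    raise ValueError(f"Unknown loss_type: {loss_type}")
```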
## Logging
While training and evaluating we record the following reward metrics:
* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
* `rewards/accuracies`: mean of how often the chosen rewards are greater than the corresponding rejected rewards
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
* `nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses
## CPOTrainer
[[autodoc]] CPOTrainer
## CPOConfig
[[autodoc]] CPOConfig


@ -2,9 +2,24 @@
TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) by Rafailov et al., 2023. For a full example have a look at [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py).
The first step as always is to train your SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.
## How DPO works
Fine-tuning a language model via DPO consists of two steps and is easier than PPO:
1. **Data collection**: Gather a preference dataset with positive and negative selected pairs of generation, given a prompt.
2. **Optimization**: Maximize the log-likelihood of the DPO loss directly.
DPO-compatible datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo).
This process is illustrated in the sketch below (from [figure 1 of the original paper](https://arxiv.org/pdf/2305.18290.pdf)):
<img width="835" alt="Screenshot 2024-03-19 at 12 39 41" src="https://github.com/huggingface/trl/assets/49240599/9150fac6-3d88-4ca2-8ec6-2a6f3473216d">
Read more about DPO algorithm in the [original paper](https://arxiv.org/pdf/2305.18290.pdf).
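Step 2 boils down to a logistic loss on implicit rewards computed from log-probability ratios against the reference model. A minimal sketch (not TRL's exact implementation), assuming per-sequence log-probabilities:
```py
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Implicit rewards are beta-scaled log-ratios against the reference model
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    # Maximize the margin between chosen and rejected rewards via a logistic loss
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
```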
## Expected dataset format
The DPO trainer expects a very specific format for the dataset, since the model will be trained to directly optimize the preference of which sentence is the most relevant, given two sentences. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
@ -63,7 +78,7 @@ The DPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that
For a detailed example have a look at the `examples/scripts/dpo.py` script. At a high level we need to initialize the `DPOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response; the `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (i.e. decoder-only or encoder-decoder).
```py
dpo_trainer = DPOTrainer(
dpo_trainer = DPOTrainer(
model,
model_ref,
args=training_args,
@ -90,7 +105,7 @@ The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical
The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability that can be passed to the `DPOTrainer` via the `label_smoothing` argument (between 0 and 0.5), in which case a conservative DPO loss is used. Use the `loss_type="cdpo"` argument to the trainer to use it.
The [KTO](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf) loss is derived to directly maximize the utility of LLM generations instead of the log-likelihood of preferences. Thus the dataset are not necessarily preferences but rather desirable vs undesirable completions. For paired preference data as required by the `DPOTrainer`, use the `loss_type="kto_pair"` argument to the trainer to utilize this loss, while for the more general case of desired and undesirable data, use the as of yet unimplemented `KTOTrainer`.
The [KTO](https://arxiv.org/abs/2402.01306) authors directly maximize the utility of LLM generations instead of the log-likelihood of preferences. To use preference data with KTO, we recommend breaking up the n preferences into 2n examples and using [`KTOTrainer`](kto_trainer) (i.e., treating the data like an unpaired feedback dataset). Although it is possible to pass in `loss_type="kto_pair"` into DPOTrainer, this is a highly simplified version of KTO that we *do not recommend* in most cases. Please use [`KTOTrainer`](kto_trainer) when possible.
## Logging
@ -146,7 +161,7 @@ training_args = TrainingArguments(output_dir="./output")
dpo_trainer = DPOTrainer(
model,
model_ref=None,
ref_model=None,
args=training_args,
beta=0.1,
train_dataset=train_dataset,


@ -35,6 +35,7 @@ Then, it is encouraged to launch jobs with `accelerate launch`!
| File | Description |
|------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------|
| [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py) | This script shows how to use the `SFTTrainer` to fine tune a model or adapters into a target dataset. |
| [`examples/scripts/vsft_llava.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/vsft_llava.py) | This script shows how to use the `SFTTrainer` to fine-tune a Vision Language Model in a chat setting. The script has been tested on a LLaVA 1.5 model, so users may see unexpected behaviour with other model architectures. |
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the `RewardTrainer` to train a reward model on your own dataset. |
| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) | This script shows how to use the `PPOTrainer` to fine-tune a sentiment analysis model using IMDB dataset |
| [`examples/scripts/ppo_multi_adapter.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo_multi_adapter.py) | This script shows how to use the `PPOTrainer` to train a single base model with multiple adapters. Requires you to run the example script with the reward model training beforehand. |


@ -59,7 +59,7 @@ Debugging the RL pipeline can be challenging due to its complexity. Here are som
- **Start from a working example**: Begin with a working example from the trl repository and gradually modify it to fit your specific use-case. Changing everything at once can make it difficult to identify the source of potential issues. For example, you can start by replacing the model in the example and once you figure out the best hyperparameters try to switch to your dataset and reward model. If you change everything at once you won't know where a potential problem comes from.
- **Start small, scale later**: Training large models can be very slow and take several hours or days until you see any improvement. For debugging this is not a convenient timescale so try to use small model variants during the development phase and scale up once that works. That being said you sometimes have to be careful as small models might not have the capacity to solve a complicated task either.
- **Start simple**: Try to start with a minimal example and build complexity from there. Your use-case might, for example, require a complicated reward function consisting of many different rewards - try to use one signal first, see if you can optimize that, and then add more complexity after that.
- **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a big in your post-processing or your prompt. Due to bad settings you might cut-off generations too soon. These things are very hard to see on the metrics but very obvious if you look at the generations.
- **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a bug in your post-processing or your prompt. Due to bad settings you might cut-off generations too soon. These things are very hard to see on the metrics but very obvious if you look at the generations.
- **Inspect the reward model**: If your reward is not improving over time maybe there's an issue with the reward model. You can look at extreme cases to see if it does what it should: e.g. in the sentiment case you can check if simple positive and negative examples really get different rewards. And you can look at the distribution of your dataset. Finally, maybe the reward is dominated by the query which the model can't affect, so you might need to normalize this (e.g. reward of query+response minus reward of the query).
These are just a few tips that we find helpful - if you have more useful tricks feel free to open a PR to add them as well!


@ -0,0 +1,93 @@
# KTO Trainer
TRL supports the Kahneman-Tversky Optimization (KTO) Trainer for aligning language models with binary feedback data (e.g., upvote/downvote), as described in the [paper](https://arxiv.org/abs/2402.01306) by Kawin Ethayarajh, Winnie Xu, Niklas Muennighoff, Dan Jurafsky, and Douwe Kiela.
For a full example have a look at [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py).
Depending on how good your base model is, you may or may not need to do SFT before KTO.
This is different from standard RLHF and DPO, which always require SFT.
## Expected dataset format
The KTO trainer expects a very specific format for the dataset as it does not require pairwise preferences. Since the model will be trained to directly optimize examples that consist of a prompt, model completion, and a label to indicate whether the completion is "good" or "bad", we expect a dataset with the following columns:
- `prompt`
- `completion`
- `label`
for example:
```
kto_dataset_dict = {
"prompt": [
"Hey, hello",
"How are you",
"What is your name?",
"What is your name?",
"Which is the best programming language?",
"Which is the best programming language?",
"Which is the best programming language?",
],
"completion": [
"hi nice to meet you",
"leave me alone",
"I don't have a name",
"My name is Mary",
"Python",
"C++",
"Java",
],
"label": [
True,
False,
False,
True,
True,
False,
False,
],
}
```
where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired (`True`) or undesired (`False`).
A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
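If you start from paired preference data (as used by the DPO trainer), each pair can simply be split into one desirable and one undesirable example. A minimal sketch (column names follow the formats described in these docs):
```py
def pairs_to_kto(pairs):
    """Turn (prompt, chosen, rejected) rows into unpaired prompt/completion/label rows."""
    kto_rows = []
    for row in pairs:
        kto_rows.append({"prompt": row["prompt"], "completion": row["chosen"], "label": True})
        kto_rows.append({"prompt": row["prompt"], "completion": row["rejected"], "label": False})
    return kto_rows
```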
## Expected model format
The KTO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
## Using the `KTOTrainer`
For a detailed example have a look at the `examples/scripts/kto.py` script. At a high level we need to initialize the `KTOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the desirable and undesirable completions.
The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (i.e. decoder-only or encoder-decoder).
The `desirable_weight` and `undesirable_weight` refer to the weights placed on the losses for desirable/positive and undesirable/negative examples.
By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` * number of positives) to (`undesirable_weight` * number of negatives) is in the range 1:1 to 4:3.
```py
training_args = KTOConfig(
beta=0.1,
desirable_weight=1.0,
undesirable_weight=1.0,
)
kto_trainer = KTOTrainer(
model,
model_ref,
args=training_args,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
```
After this one can then call:
```py
kto_trainer.train()
```
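Returning to the `desirable_weight` / `undesirable_weight` guidance above, a small worked example with hypothetical counts:
```py
num_desirable, num_undesirable = 200, 800  # hypothetical dataset counts
desirable_weight = 1.0
undesirable_weight = 0.25  # weighted ratio = (1.0 * 200) / (0.25 * 800) = 1.0
assert 1.0 <= (desirable_weight * num_desirable) / (undesirable_weight * num_undesirable) <= 4 / 3
```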
## KTOTrainer
[[autodoc]] KTOTrainer
## KTOConfig
[[autodoc]] KTOConfig


@ -0,0 +1,98 @@
# ORPO Trainer
[Odds Ratio Preference Optimization](https://arxiv.org/abs/2403.07691) (ORPO) by Jiwoo Hong, Noah Lee, and James Thorne studies the crucial role of SFT within the context of preference alignment. Using preference data, the method posits that a minor penalty for the disfavored generation, together with a strong adaptation signal to the chosen response via a simple log odds ratio term appended to the NLL loss, is sufficient for preference-aligned SFT.
ORPO is thus a reference-model-free preference optimization algorithm, eliminating the need for an additional preference alignment phase and thereby saving compute and memory.
The official code can be found at [xfactlab/orpo](https://github.com/xfactlab/orpo).
## Expected dataset format
The ORPO trainer expects a format identical to the DPO trainer, which should include three entries. These entries should be named as follows:
- `prompt`
- `chosen`
- `rejected`
for example:
```py
orpo_dataset_dict = {
"prompt": [
"hello",
"how are you",
"What is your name?",
"What is your name?",
"Which is the best programming language?",
"Which is the best programming language?",
"Which is the best programming language?",
],
"chosen": [
"hi nice to meet you",
"I am fine",
"My name is Mary",
"My name is Mary",
"Python",
"Python",
"Java",
],
"rejected": [
"leave me alone",
"I am not fine",
"Whats it to you?",
"I dont have a name",
"Javascript",
"C++",
"C++",
],
}
```
where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. Note that a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
## Expected model format
The ORPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
## Using the `ORPOTrainer`
For a detailed example have a look at the `examples/scripts/orpo.py` script. At a high level we need to initialize the `ORPOTrainer` with a `model` we wish to train. **Note that ORPOTrainer eliminates the need to use the reference model, simplifying the optimization process.** The `beta` corresponds to the hyperparameter `lambda` in eq. (6) of the paper and controls the weighting of the relative odds ratio loss in the standard cross-entropy loss used for SFT.
```py
orpo_config = ORPOConfig(
beta=0.1, # the lambda/alpha hyperparameter in the paper/code
)
orpo_trainer = ORPOTrainer(
model,
args=orpo_config,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
```
After this one can then call:
```py
orpo_trainer.train()
```
## Logging
While training and evaluating we record the following reward metrics:
* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
* `rewards/accuracies`: mean of how often the chosen rewards are greater than the corresponding rejected rewards
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
* `log_odds_chosen`: the mean log odds ratio of the chosen responses over the rejected responses
* `log_odds_ratio`: the mean of the `log(sigmoid(log_odds_chosen))`
* `nll_loss`: the mean negative log likelihood loss from the SFT part of the loss over chosen responses
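As a rough sketch of how the two `log_odds_*` metrics relate (assuming length-averaged per-sequence log-probabilities, so that `exp(logp)` stays in (0, 1)):
```py
import torch
import torch.nn.functional as F

def orpo_log_odds_metrics(chosen_logps, rejected_logps):
    # log(odds) = log(p) - log(1 - p); with logp < 0, log(1 - p) = log1p(-exp(logp))
    log_odds = (chosen_logps - rejected_logps) - (
        torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
    )
    log_odds_ratio = F.logsigmoid(log_odds)
    # reported as `log_odds_chosen` and `log_odds_ratio` respectively
    return log_odds.mean(), log_odds_ratio.mean()
```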
## ORPOTrainer
[[autodoc]] ORPOTrainer
## ORPOConfig
[[autodoc]] ORPOConfig


@ -4,6 +4,21 @@ TRL supports the [PPO](https://arxiv.org/abs/1707.06347) Trainer for training la
The first step is to train your SFT model (see the [SFTTrainer](sft_trainer)), to ensure the data we train on is in-distribution for the PPO algorithm. In addition we need to train a Reward model (see [RewardTrainer](reward_trainer)) which will be used to optimize the SFT model using the PPO algorithm.
## How PPO works
Fine-tuning a language model via PPO consists of roughly three steps:
1. **Rollout**: The language model generates a response or continuation based on a query, which could be the start of a sentence.
2. **Evaluation**: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.
This process is illustrated in the sketch below:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png" width="800">
<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
</div>
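To make step 3 a bit more concrete, here is a simplified sketch of the KL-shaped reward; TRL's `PPOTrainer` actually applies the KL penalty per token, whereas this collapses it to one scalar per sample for illustration:
```py
import torch

def shaped_reward(score, policy_logprobs, ref_logprobs, kl_coef=0.2):
    """score: scalar from the reward model; *_logprobs: per-token log-probs of the response."""
    kl = (policy_logprobs - ref_logprobs).sum()  # approximate KL contribution of this sample
    return score - kl_coef * kl

reward = shaped_reward(torch.tensor(1.5), -torch.rand(20), -torch.rand(20))
```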
## Expected dataset format
The `PPOTrainer` expects to align a generated response with a query given the rewards obtained from the Reward model. During each step of the PPO algorithm we sample a batch of prompts from the dataset, then use these prompts to generate responses from the SFT model. Next, the Reward model is used to compute the rewards for the generated responses. Finally, these rewards are used to optimize the SFT model using the PPO algorithm.
@ -115,7 +130,10 @@ We can then loop over all examples in the dataset and generate a response for ea
```py
from tqdm import tqdm
for epoch in tqdm(range(ppo_trainer.config.ppo_epochs), "epoch: "):
epochs = 10
for epoch in tqdm(range(epochs), "epoch: "):
for batch in tqdm(ppo_trainer.dataloader):
query_tensors = batch["input_ids"]
@ -133,7 +151,7 @@ for epoch in tqdm(range(ppo_trainer.config.ppo_epochs), "epoch: "):
ppo_trainer.log_stats(stats, batch, rewards)
#### Save model
ppo_trainer.save_model("my_ppo_model")
ppo_trainer.save_pretrained("my_ppo_model")
```
## Logging


@ -3,6 +3,7 @@
Supervised fine-tuning (or SFT for short) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them with a few lines of code on your dataset.
Check out a complete flexible example at [`examples/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft.py).
Experimental support for Vision Language Models is also included in the example [`examples/scripts/vsft_llava.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/vsft_llava.py).
## Quickstart
@ -271,6 +272,7 @@ trainer.train()
```
Note that if you use a packed dataset and if you pass `max_steps` in the training arguments you will probably train your models for more than a few epochs, depending on the way you have configured the packed dataset and the training protocol. Double check that you know and understand what you are doing.
If you don't want to pack your `eval_dataset`, you can pass `eval_packing=False` to the `SFTTrainer` init method.
#### Customize your prompts using packed dataset
@ -604,6 +606,12 @@ You may experience some issues with GPTQ Quantization after completing training.
[[autodoc]] SFTTrainer
## ConstantLengthDataset
## Datasets
In the `SFTTrainer` we support `datasets.IterableDataset` in addition to map-style datasets. This is useful if you are using large corpora that you do not want to save entirely to disk. The data will be tokenized and processed on the fly, even when packing is enabled.
Additionally, in the SFTTrainer, we support pre-tokenized datasets if they are `datasets.Dataset` or `datasets.IterableDataset`. In other words, if such a dataset has a column of `input_ids`, no further processing (tokenization or packing) will be done, and the dataset will be used as-is. This can be useful if you have pretokenized your dataset outside of this script and want to re-use it directly.
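As a minimal sketch (model, dataset and hyperparameters are just examples), streaming a text corpus straight into the trainer looks like this; note that `max_steps` must be set when the dataset has no length:
```py
from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer

# Stream the corpus instead of materializing it on disk; rows are tokenized (and packed) on the fly
train_dataset = load_dataset("imdb", split="train", streaming=True)  # datasets.IterableDataset

trainer = SFTTrainer(
    "facebook/opt-350m",
    args=TrainingArguments(output_dir="opt-sft-imdb-streaming", max_steps=500),
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=512,
    packing=True,
)
trainer.train()
```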
### ConstantLengthDataset
[[autodoc]] trainer.ConstantLengthDataset


@ -4,6 +4,47 @@ At TRL we support PPO (Proximal Policy Optimisation) with an implementation that
The Trainer and model classes are largely inspired from `transformers.Trainer` and `transformers.AutoModel` classes and adapted for RL.
We also support a `RewardTrainer` that can be used to train a reward model.
## CPOConfig
[[autodoc]] CPOConfig
## CPOTrainer
[[autodoc]] CPOTrainer
## DDPOConfig
[[autodoc]] DDPOConfig
## DDPOTrainer
[[autodoc]] DDPOTrainer
## DPOTrainer
[[autodoc]] DPOTrainer
## IterativeSFTTrainer
[[autodoc]] IterativeSFTTrainer
## KTOConfig
[[autodoc]] KTOConfig
## KTOTrainer
[[autodoc]] KTOTrainer
## ORPOConfig
[[autodoc]] ORPOConfig
## ORPOTrainer
[[autodoc]] ORPOTrainer
## PPOConfig
[[autodoc]] PPOConfig
@ -24,22 +65,6 @@ We also support a `RewardTrainer` that can be used to train a reward model.
[[autodoc]] SFTTrainer
## DPOTrainer
[[autodoc]] DPOTrainer
## DDPOConfig
[[autodoc]] DDPOConfig
## DDPOTrainer
[[autodoc]] DDPOTrainer
## IterativeSFTTrainer
[[autodoc]] IterativeSFTTrainer
## set_seed
[[autodoc]] set_seed

example_config.yaml Normal file

@ -0,0 +1,20 @@
# This is an example configuration file of TRL CLI, you can use it for
# SFT like that: `trl sft --config config.yaml --output_dir test-sft`
# The YAML file supports environment variables by adding an `env` field
# as below
# env:
# CUDA_VISIBLE_DEVICES: 0
model_name_or_path:
  HuggingFaceM4/tiny-random-LlamaForCausalLM
dataset_name:
  imdb
dataset_text_field:
  text
report_to:
  none
learning_rate:
  0.0001
lr_scheduler_type:
  cosine


@ -0,0 +1,122 @@
import multiprocessing
import sys
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from huggingface_hub import HfApi
from huggingface_hub.repocard import RepoCard
from transformers import HfArgumentParser
"""
# debug
python -i examples/datasets/anthropic_hh.py --debug --push_to_hub
# actual push
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity trl-internal-testing
"""
api = HfApi()
@dataclass
class ScriptArguments:
debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
hf_entity: Optional[str] = field(default=None, metadata={"help": "The Hugging Face entity to use"})
hf_repo_id: Optional[str] = field(default="hh-rlhf-trl-style", metadata={"help": "The Hugging Face repository ID"})
revision: Optional[str] = field(default="0.1.0", metadata={"help": "The revision of the repository"})
update_main_revision: Optional[bool] = field(
default=True, metadata={"help": "Update the main revision of the repository"}
)
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the dataset to the Hugging Face Hub"})
# GPT-4 generated 😄 Define a function to process the input and extract the dialogue into structured format
def extract_dialogue(input_text):
# Split the input by lines and initialize variables
lines = input_text.strip().split("\n\n")
dialogue_list = []
# Iterate through each line and extract the dialogue
for line in lines:
# Check if the line starts with "Human" or "Assistant" and split accordingly
if line.startswith("Human:"):
role = "user"
content = line.replace("Human: ", "").strip()
elif line.startswith("Assistant:"):
role = "assistant"
content = line.replace("Assistant: ", "").strip()
else:
# If the line doesn't start with "Human" or "Assistant", it's part of the previous message's content
# Append it to the last message's content
dialogue_list[-1]["content"] += "\n\n" + line.strip()
continue
# Append the extracted dialogue piece to the list
dialogue_list.append({"role": role, "content": content})
return dialogue_list
if __name__ == "__main__":
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
if args.hf_entity is None:
args.hf_entity = api.whoami()["name"]
full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
ds = load_dataset("Anthropic/hh-rlhf")
if args.debug:
for key in ds:
ds[key] = ds[key].select(range(50))
def process(row):
row["chosen"] = extract_dialogue(row["chosen"])
row["rejected"] = extract_dialogue(row["rejected"])
row["prompt"] = row["chosen"][0]["content"]
return row
ds = ds.map(
process,
num_proc=1 if args.debug else multiprocessing.cpu_count(),
load_from_cache_file=False,
)
if args.push_to_hub:
revisions = ["main"] if args.update_main_revision else []
revisions.append(args.revision)
# get the command used to run the script
run_command = " ".join(["python"] + sys.argv)
for revision in revisions:
ds.push_to_hub(full_repo_id, revision=revision)
repo_full_url = f"https://huggingface.co/datasets/{full_repo_id}/tree/{revision}"
# get the name of the current file
file_name = __file__.split("/")[-1]
api.upload_file(
path_or_fileobj=__file__,
path_in_repo=file_name,
revision=revision,
repo_id=full_repo_id,
repo_type="dataset",
)
sft_card = RepoCard.load(
full_repo_id,
repo_type="dataset",
)
sft_card.text = f"""\
# TRL's Anthropic HH Dataset
We preprocess the dataset using our standard `prompt, chosen, rejected` format.
## Reproduce this dataset
1. Download the `{file_name}` from the {repo_full_url}.
2. Run `{run_command}`
"""
sft_card.push_to_hub(
full_repo_id,
repo_type="dataset",
)


@ -0,0 +1,113 @@
import multiprocessing
import sys
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from huggingface_hub import HfApi
from huggingface_hub.repocard import RepoCard
from transformers import HfArgumentParser
"""
# debug
python -i examples/datasets/tldr_preference.py --debug --push_to_hub
# actual push
python examples/datasets/tldr_preference.py --push_to_hub --hf_entity trl-internal-testing
"""
api = HfApi()
@dataclass
class ScriptArguments:
debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
hf_entity: Optional[str] = field(default=None, metadata={"help": "The Hugging Face entity to use"})
hf_repo_id: Optional[str] = field(
default="tldr-preference-trl-style", metadata={"help": "The Hugging Face repository ID"}
)
revision: Optional[str] = field(default="0.1.0", metadata={"help": "The revision of the repository"})
update_main_revision: Optional[bool] = field(
default=True, metadata={"help": "Update the main revision of the repository"}
)
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the dataset to the Hugging Face Hub"})
if __name__ == "__main__":
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
if args.hf_entity is None:
args.hf_entity = api.whoami()["name"]
full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
ds = load_dataset("openai/summarize_from_feedback", "comparisons")
if args.debug:
for key in ds:
ds[key] = ds[key].select(range(50))
cnndm_batches = ["batch0_cnndm", "cnndm0", "cnndm2"]
if not args.debug:
ds["validation_cnndm"] = ds["validation"].filter(lambda x: x["batch"] in cnndm_batches)
ds["validation"] = ds["validation"].filter(lambda x: x["batch"] not in cnndm_batches)
tldr_format_str = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:"
cnndm_format_str = "Article:\n{article}\n\nTL;DR:"
def process(row):
format_str = cnndm_format_str if row["batch"] in cnndm_batches else tldr_format_str
row["prompt"] = format_str.format(**row["info"])
choice = row["choice"]
chosen = row["summaries"][choice]["text"]
rejected = row["summaries"][1 - choice]["text"]
row["chosen"] = [{"role": "user", "content": row["prompt"]}, {"role": "assistant", "content": chosen}]
row["rejected"] = [{"role": "user", "content": row["prompt"]}, {"role": "assistant", "content": rejected}]
return row
ds = ds.map(
process,
num_proc=1 if args.debug else multiprocessing.cpu_count(),
load_from_cache_file=False,
)
for key in ds: # reorder columns
ds[key] = ds[key].select_columns(
["prompt", "chosen", "rejected", "info", "summaries", "choice", "worker", "batch", "split", "extra"]
)
if args.push_to_hub:
revisions = ["main"] if args.update_main_revision else []
revisions.append(args.revision)
# get the command used to run the script
run_command = " ".join(["python"] + sys.argv)
for revision in revisions:
ds.push_to_hub(full_repo_id, revision=revision)
repo_full_url = f"https://huggingface.co/datasets/{full_repo_id}/tree/{revision}"
# get the name of the current file
file_name = __file__.split("/")[-1]
api.upload_file(
path_or_fileobj=__file__,
path_in_repo=file_name,
revision=revision,
repo_id=full_repo_id,
repo_type="dataset",
)
sft_card = RepoCard.load(
full_repo_id,
repo_type="dataset",
)
sft_card.text = f"""\
# TRL's TL;DR Preference Dataset
We preprocess the dataset using our standard `prompt, chosen, rejected` format.
## Reproduce this dataset
1. Download the `{file_name}` from the {repo_full_url}.
2. Run `{run_command}`
"""
sft_card.push_to_hub(
full_repo_id,
repo_type="dataset",
)


@ -0,0 +1,42 @@
import multiprocessing
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from transformers import AutoTokenizer, HfArgumentParser
"""
python -i examples/datasets/tokenize_ds.py --debug --model HuggingFaceH4/zephyr-7b-beta
python -i examples/datasets/tokenize_ds.py --debug --model gpt2
"""
@dataclass
class ScriptArguments:
debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
dataset: str = field(default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The dataset to load"})
model: str = field(default="gpt2", metadata={"help": "The model to use for tokenization"})
if __name__ == "__main__":
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
ds = load_dataset(args.dataset)
if args.debug:
for key in ds:
ds[key] = ds[key].select(range(50))
tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.chat_template is None:
tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{{ eos_token }}"
def process(row):
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
return row
ds = ds.map(
process,
num_proc=1 if args.debug else multiprocessing.cpu_count(),
load_from_cache_file=False,
)
print(ds["train"][0]["chosen"])


@ -15,6 +15,7 @@ from transformers import (
Trainer,
TrainerCallback,
TrainingArguments,
set_seed,
)
from transformers.utils import PaddingStrategy
@ -89,11 +90,14 @@ class ScriptArguments:
default=False,
metadata={"help": "Whether to run eval after the first step"},
)
seed: Optional[int] = field(
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
)
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
set_seed(script_args.seed)
# Load the human stack-exchange-paired dataset for tuning the reward model.
train_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/reward", split="train")
if script_args.train_subset > 0:
@ -129,7 +133,10 @@ training_args = TrainingArguments(
logging_steps=10,
optim=script_args.optim,
lr_scheduler_type=script_args.lr_scheduler_type,
seed=script_args.seed,
)
# Load the value-head model and tokenizer.
tokenizer_name = script_args.tokenizer_name if script_args.tokenizer_name is not None else script_args.model_name
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True)


@ -66,6 +66,7 @@ class ScriptArguments:
)
adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
load_in_8bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 8bit"})
parser = HfArgumentParser(ScriptArguments)
@ -180,7 +181,7 @@ lora_config = LoraConfig(
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
config.model_name,
load_in_8bit=True,
load_in_8bit=script_args.load_in_8bit,
device_map={"": current_device},
peft_config=lora_config,
)
@ -215,11 +216,13 @@ sentiment_pipe = pipeline(
"sentiment-analysis",
model=reward_model_name,
device_map={"": current_device},
model_kwargs={"load_in_8bit": True},
model_kwargs={"load_in_8bit": script_args.load_in_8bit},
tokenizer=tokenizer,
return_token_type_ids=False,
)
if sentiment_pipe.model.config.pad_token_id is None:
sentiment_pipe.model.config.pad_token_id = sentiment_pipe.model.config.eos_token_id
# We then define the arguments to pass to the `generate` function. These arguments
# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
# the `generate` function of the trained model.


@ -4,9 +4,10 @@ from dataclasses import dataclass, field
from typing import Dict, Optional
import torch
from accelerate import Accelerator
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed
from trl import DPOTrainer
@ -41,6 +42,10 @@ class ScriptArguments:
default=True, metadata={"help": "whether to use gradient checkpointing"}
)
gradient_checkpointing_use_reentrant: Optional[bool] = field(
default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"}
)
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
@ -54,6 +59,10 @@ class ScriptArguments:
output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})
load_in_4bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 4bit"})
model_dtype: Optional[str] = field(
default="float16", metadata={"help": "model_dtype[float16, bfloat16, float] for loading."}
)
# instrumentation
sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
@ -73,6 +82,9 @@ class ScriptArguments:
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)
seed: Optional[int] = field(
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
)
def get_stack_exchange_paired(
@ -123,12 +135,21 @@ if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
set_seed(script_args.seed)
# 1. load a pretrained model
torch_dtype = torch.float
if script_args.model_dtype == "float16":
torch_dtype = torch.float16
elif script_args.model_dtype == "bfloat16":
torch_dtype = torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
load_in_4bit=True,
torch_dtype=torch_dtype,
load_in_4bit=script_args.load_in_4bit,
device_map={"": Accelerator().local_process_index},
)
model.config.use_cache = False
@ -138,12 +159,6 @@ if __name__ == "__main__":
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
]
model_ref = AutoModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
load_in_4bit=True,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token
@ -181,6 +196,8 @@ if __name__ == "__main__":
bf16=True,
remove_unused_columns=False,
run_name="dpo_llama2",
gradient_checkpointing_kwargs=dict(use_reentrant=script_args.gradient_checkpointing_use_reentrant),
seed=script_args.seed,
)
peft_config = LoraConfig(
@ -203,7 +220,7 @@ if __name__ == "__main__":
# 5. initialize the DPO trainer
dpo_trainer = DPOTrainer(
model,
model_ref,
ref_model=None,
args=training_args,
beta=script_args.beta,
train_dataset=train_dataset,


@ -8,7 +8,14 @@ from accelerate import Accelerator
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM, LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
HfArgumentParser,
TrainingArguments,
set_seed,
)
from trl import SFTTrainer
from trl.import_utils import is_npu_available, is_xpu_available
@ -27,6 +34,7 @@ class ScriptArguments:
seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"})
num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"})
use_bnb: Optional[bool] = field(default=True, metadata={"help": "whether to use BitsAndBytes"})
# LoraConfig
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
@ -53,6 +61,8 @@ if training_args.group_by_length and script_args.packing:
if training_args.gradient_checkpointing:
raise ValueError("gradient_checkpointing not supported")
set_seed(training_args.seed)
def chars_token_ratio(dataset, tokenizer, nb_examples=400):
"""
@ -91,7 +101,7 @@ def prepare_sample_text(example):
return text
def create_datasets(tokenizer, args):
def create_datasets(tokenizer, args, seed=None):
dataset = load_dataset(
args.dataset_name,
data_dir=args.subset,
@ -104,9 +114,9 @@ def create_datasets(tokenizer, args):
print("Loading the dataset in streaming mode")
valid_data = dataset.take(args.size_valid_set)
train_data = dataset.skip(args.size_valid_set)
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=None)
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=seed)
else:
dataset = dataset.train_test_split(test_size=0.005, seed=None)
dataset = dataset.train_test_split(test_size=0.005, seed=seed)
train_data = dataset["train"]
valid_data = dataset["test"]
print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
@ -133,11 +143,13 @@ def create_datasets(tokenizer, args):
return train_dataset, valid_dataset
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
bnb_config = None
if script_args.use_bnb:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
base_model = AutoModelForCausalLM.from_pretrained(
script_args.model_name,
@ -153,7 +165,7 @@ tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_c
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training
train_dataset, eval_dataset = create_datasets(tokenizer, script_args)
train_dataset, eval_dataset = create_datasets(tokenizer, script_args, seed=training_args.seed)
trainer = SFTTrainer(
model=base_model,

examples/scripts/chat.py Normal file

@ -0,0 +1,340 @@
# flake8: noqa
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from trl.commands.cli_utils import init_zero_verbose
init_zero_verbose()
import copy
import json
import os
import pwd
import re
import time
from threading import Thread
import torch
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from trl.commands.cli_utils import ChatArguments, TrlParser, init_zero_verbose
from trl.trainer.utils import get_kbit_device_map, get_quantization_config
HELP_STRING = """\
**TRL CHAT INTERFACE**
The chat interface is a simple tool to try out a chat model.
Besides talking to the model there are several commands:
- **clear**: clears the current conversation and start a new one
- **example {NAME}**: load example named `{NAME}` from the config and use it as the user input
- **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';').
- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
- **save {SAVE_NAME} (optional)**: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
- **exit**: closes the interface
"""
SUPPORTED_GENERATION_KWARGS = [
"max_new_tokens",
"do_sample",
"num_beams",
"temperature",
"top_p",
"top_k",
"repetition_penalty",
]
SETTING_RE = r"^set\s+[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+(?:;\s*[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+)*$"
class RichInterface:
def __init__(self, model_name=None, user_name=None):
self._console = Console()
if model_name is None:
self.model_name = "assistant"
else:
self.model_name = model_name
if user_name is None:
self.user_name = "user"
else:
self.user_name = user_name
def stream_output(self, output_stream):
"""Stream output from a role."""
# This method is originally from the FastChat CLI: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py
# Create a Live context for updating the console output
text = ""
self._console.print(f"[bold blue]<{self.model_name}>:")
with Live(console=self._console, refresh_per_second=4) as live:
# Read lines from the stream
for i, outputs in enumerate(output_stream):
if not outputs or i == 0:
continue
text += outputs
# Render the accumulated text as Markdown
# NOTE: this is a workaround for the rendering "unstandard markdown"
# in rich. The chatbots output treat "\n" as a new line for
# better compatibility with real-world text. However, rendering
# in markdown would break the format. It is because standard markdown
# treat a single "\n" in normal text as a space.
# Our workaround is adding two spaces at the end of each line.
# This is not a perfect solution, as it would
# introduce trailing spaces (only) in code block, but it works well
# especially for console output, because in general the console does not
# care about trailing spaces.
lines = []
for line in text.splitlines():
lines.append(line)
if line.startswith("```"):
# Code block marker - do not add trailing spaces, as it would
# break the syntax highlighting
lines.append("\n")
else:
lines.append(" \n")
markdown = Markdown("".join(lines).strip(), code_theme="github-dark")
# Update the Live console output
live.update(markdown)
self._console.print()
return text
def input(self):
input = self._console.input(f"[bold red]<{self.user_name}>:\n")
self._console.print()
return input
def clear(self):
self._console.clear()
def print_user_message(self, text):
self._console.print(f"[bold red]<{self.user_name}>:[/ bold red]\n{text}")
self._console.print()
def print_green(self, text):
self._console.print(f"[bold green]{text}")
self._console.print()
def print_red(self, text):
self._console.print(f"[bold red]{text}")
self._console.print()
def print_help(self):
self._console.print(Markdown(HELP_STRING))
self._console.print()
def get_username():
return pwd.getpwuid(os.getuid())[0]
def create_default_filename(model_name):
time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
return f"{model_name}/chat_{time_str}.json"
def save_chat(chat, args, filename):
output_dict = {}
output_dict["settings"] = vars(args)
output_dict["chat_history"] = chat
folder = args.save_folder
if filename is None:
filename = create_default_filename(args.model_name_or_path)
filename = os.path.join(folder, filename)
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, "w") as f:
json.dump(output_dict, f, indent=4)
return os.path.abspath(filename)
def clear_chat_history(system_prompt):
if system_prompt is None:
chat = []
else:
chat = [{"role": "system", "content": system_prompt}]
return chat
def parse_settings(user_input, current_args, interface):
settings = user_input[4:].strip().split(";")
settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings]
settings = dict(settings)
error = False
for name in settings:
if hasattr(current_args, name):
try:
if isinstance(getattr(current_args, name), bool):
if settings[name] == "True":
settings[name] = True
elif settings[name] == "False":
settings[name] = False
else:
raise ValueError
else:
settings[name] = type(getattr(current_args, name))(settings[name])
except ValueError:
interface.print_red(
f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}."
)
else:
interface.print_red(f"There is no '{name}' setting.")
if error:
interface.print_red("There was an issue parsing the settings. No settings have been changed.")
return current_args, False
else:
for name in settings:
setattr(current_args, name, settings[name])
interface.print_green(f"Set {name} to {settings[name]}.")
time.sleep(1.5) # so the user has time to read the changes
return current_args, True
def load_model_and_tokenizer(args):
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, revision=args.model_revision)
torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype)
quantization_config = get_quantization_config(args)
model_kwargs = dict(
revision=args.model_revision,
trust_remote_code=args.trust_remote_code,
attn_implementation=args.attn_implementation,
torch_dtype=torch_dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs)
if getattr(model, "hf_device_map", None) is None:
model = model.to(args.device)
return model, tokenizer
def chat_cli():
parser = TrlParser(ChatArguments)
args = parser.parse_args_into_dataclasses()[0]
if args.config == "default":
args.config = os.path.join(os.path.dirname(__file__), "config/default_chat_config.yaml")
if args.config.lower() == "none":
args.config = None
args = parser.update_dataclasses_with_config([args])[0]
if args.examples is None:
args.examples = {}
current_args = copy.deepcopy(args)
if args.user is None:
user = get_username()
else:
user = args.user
model, tokenizer = load_model_and_tokenizer(args)
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
interface.clear()
chat = clear_chat_history(current_args.system_prompt)
while True:
try:
user_input = interface.input()
if user_input == "clear":
chat = clear_chat_history(current_args.system_prompt)
interface.clear()
continue
if user_input == "help":
interface.print_help()
continue
if user_input == "exit":
break
if user_input == "reset":
interface.clear()
current_args = copy.deepcopy(args)
chat = clear_chat_history(current_args.system_prompt)
continue
if user_input.startswith("save") and len(user_input.split()) < 2:
split_input = user_input.split()
if len(split_input) == 2:
filename = split_input[1]
else:
filename = None
filename = save_chat(chat, current_args, filename)
interface.print_green(f"Chat saved in {filename}!")
continue
if re.match(SETTING_RE, user_input):
current_args, success = parse_settings(user_input, current_args, interface)
if success:
chat = []
interface.clear()
continue
if user_input.startswith("example") and len(user_input.split()) == 2:
example_name = user_input.split()[1]
if example_name in current_args.examples:
interface.clear()
chat = []
interface.print_user_message(current_args.examples[example_name]["text"])
user_input = current_args.examples[example_name]["text"]
else:
interface.print_red(
f"Example {example_name} not found in list of available examples: {list(current_args.examples.keys())}."
)
continue
chat.append({"role": "user", "content": user_input})
generation_kwargs = dict(
inputs=tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
model.device
),
streamer=generation_streamer,
max_new_tokens=current_args.max_new_tokens,
do_sample=current_args.do_sample,
num_beams=current_args.num_beams,
temperature=current_args.temperature,
top_k=current_args.top_k,
top_p=current_args.top_p,
repetition_penalty=current_args.repetition_penalty,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
model_output = interface.stream_output(generation_streamer)
thread.join()
chat.append({"role": "assistant", "content": model_output})
except KeyboardInterrupt:
break
if __name__ == "__main__":
chat_cli()


@ -0,0 +1,13 @@
examples:
  llama:
    text: There is a Llama in my lawn, how can I get rid of it?
  code:
    text: Write a Python function that integrates any Python function f(x) numerically over an arbitrary interval [x_start, x_end].
  helicopter:
    text: How many helicopters can a human eat in one sitting?
  numbers:
    text: Count to 10 but skip every number ending with an 'e'
  birds:
    text: Why aren't birds real?
  socks:
    text: Why is it important to eat socks after meditating?

examples/scripts/cpo.py Normal file

@ -0,0 +1,121 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run the CPO training script with the following command with some example arguments.
In general, the optimal configuration for CPO will be similar to that of DPO:
# regular:
python examples/scripts/cpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 8e-6 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="gpt2-aligned-cpo" \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns
# peft:
python examples/scripts/cpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 8e-5 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="gpt2-lora-aligned-cpo" \
--optim rmsprop \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns \
--use_peft \
--lora_r=16 \
--lora_alpha=16
"""
import multiprocessing
from dataclasses import dataclass, field
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import CPOConfig, CPOTrainer, ModelConfig, get_peft_config
@dataclass
class ScriptArguments:
dataset: str = field(
default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The name of the dataset to use."}
)
if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, CPOConfig, ModelConfig))
args, cpo_args, model_config = parser.parse_args_into_dataclasses()
################
# Model & Tokenizer
################
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path)
peft_config = get_peft_config(model_config)
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
################
# Dataset
################
ds = load_dataset(args.dataset)
if cpo_args.debug:
for key in ds:
ds[key] = ds[key].select(range(50))
if tokenizer.chat_template is None:
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
def process(row):
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
return row
ds = ds.map(
process,
num_proc=1 if cpo_args.debug else multiprocessing.cpu_count(),
load_from_cache_file=False,
)
train_dataset = ds["train"]
eval_dataset = ds["test"]
################
# Training
################
trainer = CPOTrainer(
model,
args=cpo_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
peft_config=get_peft_config(model_config),
)
# train and save the model
trainer.train()
trainer.save_model(cpo_args.output_dir)
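For reference, the `process` step above simply serializes each chosen/rejected conversation with the ChatML template before handing the resulting strings to the trainer (the ORPO script below uses the same fallback). A minimal sketch of what that produces, with a made-up conversation and assuming a tokenizer that ships without a chat template:

from transformers import AutoTokenizer

# Illustrative only: reproduce the ChatML fallback used above on a toy conversation.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% endfor %}"
)
chosen = [
    {"role": "user", "content": "Which is the best programming language?"},
    {"role": "assistant", "content": "Python"},
]
print(tokenizer.apply_chat_template(chosen, tokenize=False))
# <|im_start|>user
# Which is the best programming language?<|im_end|>
# <|im_start|>assistant
# Python<|im_end|>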

View File

@ -175,7 +175,7 @@ def image_outputs_logger(image_data, global_step, accelerate_logger):
for i, image in enumerate(images):
prompt = prompts[i]
reward = rewards[i].item()
result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0)
result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0).float()
accelerate_logger.log_images(
result,

View File

@ -1,3 +1,4 @@
# flake8: noqa
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -14,9 +15,9 @@
"""
# regular:
python examples/scripts/dpo.py \
--dataset_name=trl-internal-testing/hh-rlhf-trl-style \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 1e-3 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
@ -30,9 +31,9 @@ python examples/scripts/dpo.py \
# peft:
python examples/scripts/dpo.py \
--dataset_name=trl-internal-testing/hh-rlhf-trl-style \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 1e-3 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
@ -48,76 +49,48 @@ python examples/scripts/dpo.py \
--lora_r=16 \
--lora_alpha=16
"""
from dataclasses import dataclass, field
from typing import Dict, Optional
import logging
import multiprocessing
import os
from contextlib import nullcontext
TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)
from trl.commands.cli_utils import DpoScriptArguments, init_zero_verbose, TrlParser
if TRL_USE_RICH:
init_zero_verbose()
FORMAT = "%(message)s"
from rich.console import Console
from rich.logging import RichHandler
import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer, ModelConfig, get_kbit_device_map, get_peft_config, get_quantization_config
from trl import (
DPOTrainer,
ModelConfig,
RichProgressCallback,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
@dataclass
class ScriptArguments:
beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
max_length: int = field(default=512, metadata={"help": "max length of each sample"})
max_prompt_length: int = field(default=128, metadata={"help": "max length of each sample's prompt"})
max_target_length: int = field(
default=128, metadata={"help": "Only used for encoder decoder model. Max target of each sample's prompt"}
)
sanity_check: bool = field(default=True, metadata={"help": "only train on 1000 samples"})
ignore_bias_buffers: bool = field(
default=False,
metadata={
"help": "debug argument for distributed training;"
"fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)
generate_during_eval: bool = field(default=False, metadata={"help": "Generate during evaluation"})
def extract_anthropic_prompt(prompt_and_response):
"""Extract the anthropic prompt from a prompt and response pair."""
search_term = "\n\nAssistant:"
search_term_idx = prompt_and_response.rfind(search_term)
assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
return prompt_and_response[: search_term_idx + len(search_term)]
def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: Optional[str] = None) -> Dataset:
"""Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.
The dataset is converted to a dictionary with the following structure:
{
'prompt': List[str],
'chosen': List[str],
'rejected': List[str],
}
Prompts should be structured as follows:
\n\nHuman: <prompt>\n\nAssistant:
Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
"""
dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
if sanity_check:
dataset = dataset.select(range(min(len(dataset), 1000)))
def split_prompt_and_responses(sample) -> Dict[str, str]:
prompt = extract_anthropic_prompt(sample["chosen"])
return {
"prompt": prompt,
"chosen": sample["chosen"][len(prompt) :],
"rejected": sample["rejected"][len(prompt) :],
}
return dataset.map(split_prompt_and_responses)
if TRL_USE_RICH:
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)
if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, TrainingArguments, ModelConfig))
args, training_args, model_config = parser.parse_args_into_dataclasses()
parser = TrlParser((DpoScriptArguments, TrainingArguments, ModelConfig))
args, training_args, model_config = parser.parse_args_and_config()
# Force use our print callback
if TRL_USE_RICH:
training_args.disable_tqdm = True
console = Console()
################
# Model & Tokenizer
@ -146,34 +119,66 @@ if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if tokenizer.chat_template is None:
tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{{ eos_token }}"
if args.ignore_bias_buffers:
# torch distributed hack
model._ddp_params_and_buffers_to_ignore = [
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
]
################
# Optional rich context managers
###############
init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the DPOTrainer...")
save_context = (
nullcontext()
if not TRL_USE_RICH
else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
)
################
# Dataset
################
train_dataset = get_hh("train", sanity_check=args.sanity_check)
eval_dataset = get_hh("test", sanity_check=args.sanity_check)
ds = load_dataset(args.dataset_name)
if args.sanity_check:
for key in ds:
ds[key] = ds[key].select(range(50))
def process(row):
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
return row
ds = ds.map(
process,
num_proc=multiprocessing.cpu_count(),
load_from_cache_file=False,
)
train_dataset = ds["train"]
eval_dataset = ds["test"]
################
# Training
################
trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=args.beta,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
max_length=args.max_length,
max_target_length=args.max_target_length,
max_prompt_length=args.max_prompt_length,
generate_during_eval=args.generate_during_eval,
peft_config=get_peft_config(model_config),
)
with init_context:
trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=args.beta,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
max_length=args.max_length,
max_target_length=args.max_target_length,
max_prompt_length=args.max_prompt_length,
generate_during_eval=args.generate_during_eval,
peft_config=get_peft_config(model_config),
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
)
trainer.train()
trainer.save_model(training_args.output_dir)
with save_context:
trainer.save_model(training_args.output_dir)

115
examples/scripts/kto.py Normal file
View File

@ -0,0 +1,115 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO.
# Full training:
python examples/scripts/kto.py \
--model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
--per_device_train_batch_size 16 \
--num_train_epochs 1 \
--learning_rate 1e-5 \
--lr_scheduler_type=cosine \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir=kto-aligned-model \
--warmup_ratio 0.1 \
--report_to wandb \
--bf16 \
--logging_first_step
# QLoRA:
python examples/scripts/kto.py \
--model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
--per_device_train_batch_size 8 \
--num_train_epochs 1 \
--learning_rate 1e-4 \
--lr_scheduler_type=cosine \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir=kto-aligned-model-lora \
--warmup_ratio 0.1 \
--report_to wandb \
--bf16 \
--logging_first_step \
--use_peft \
--load_in_4bit \
--lora_target_modules=all-linear \
--lora_r=16 \
--lora_alpha=16
"""
from dataclasses import dataclass
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, setup_chat_format
# Define and parse arguments.
@dataclass
class ScriptArguments:
"""
The arguments for the KTO training script.
"""
dataset_name: str = "trl-lib/kto-mix-14k"
if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
script_args, kto_args, model_args = parser.parse_args_into_dataclasses()
# Load a pretrained model
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
model_ref = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# If we are aligning a base model, we use ChatML as the default template
if tokenizer.chat_template is None:
model, tokenizer = setup_chat_format(model, tokenizer)
# Load the dataset
dataset = load_dataset(script_args.dataset_name)
# Apply chat template
def format_dataset(example):
example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
return example
formatted_dataset = dataset.map(format_dataset)
# Initialize the KTO trainer
kto_trainer = KTOTrainer(
model,
model_ref,
args=kto_args,
train_dataset=formatted_dataset["train"],
eval_dataset=formatted_dataset["test"],
tokenizer=tokenizer,
peft_config=get_peft_config(model_args),
)
# Train and push the model to the Hub
kto_trainer.train()
kto_trainer.save_model(kto_args.output_dir)
kto_trainer.push_to_hub()
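KTO consumes unpaired feedback rather than chosen/rejected pairs: after the chat template is applied, every row boils down to a prompt, a single completion, and a boolean label marking the completion as desirable or not. A toy dataset in that flat shape, mirroring the fixtures in tests/test_kto_trainer.py further down:

from datasets import Dataset

# Illustrative unpaired KTO data: prompt / completion / boolean label.
dummy_kto_dataset = Dataset.from_dict(
    {
        "prompt": ["What is your name?", "What is your name?"],
        "completion": ["My name is Mary", "I don't have a name"],
        "label": [True, False],  # True = desirable completion, False = undesirable
    }
)
print(dummy_kto_dataset[0])
# {'prompt': 'What is your name?', 'completion': 'My name is Mary', 'label': True}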

121
examples/scripts/orpo.py Normal file
View File

@ -0,0 +1,121 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run the ORPO training script with one of the following commands, using some example arguments.
In general, the optimal configuration for ORPO will be similar to that of DPO, without the need for a reference model:
# regular:
python examples/scripts/orpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 8e-6 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="gpt2-aligned-orpo" \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns
# peft:
python examples/scripts/orpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 8e-5 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="gpt2-lora-aligned-orpo" \
--optim rmsprop \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns \
--use_peft \
--lora_r=16 \
--lora_alpha=16
"""
import multiprocessing
from dataclasses import dataclass, field
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import ModelConfig, ORPOConfig, ORPOTrainer, get_peft_config
@dataclass
class ScriptArguments:
dataset: str = field(
default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The name of the dataset to use."}
)
if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, ORPOConfig, ModelConfig))
args, orpo_args, model_config = parser.parse_args_into_dataclasses()
################
# Model & Tokenizer
################
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path)
peft_config = get_peft_config(model_config)
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
################
# Dataset
################
ds = load_dataset(args.dataset)
if orpo_args.debug:
for key in ds:
ds[key] = ds[key].select(range(50))
if tokenizer.chat_template is None:
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
def process(row):
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
return row
ds = ds.map(
process,
num_proc=1 if orpo_args.debug else multiprocessing.cpu_count(),
load_from_cache_file=False,
)
train_dataset = ds["train"]
eval_dataset = ds["test"]
################
# Training
################
trainer = ORPOTrainer(
model,
args=orpo_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
peft_config=get_peft_config(model_config),
)
# train and save the model
trainer.train()
trainer.save_model(orpo_args.output_dir)
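Because ORPO folds the preference signal directly into the language-modeling loss, the trainer above is built without a reference model. A compact, self-contained sketch along the lines of tests/test_orpo_trainer.py further down (hyperparameters are illustrative):

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import ORPOConfig, ORPOTrainer

# Tiny model and tiny paired dataset; note that no ref_model is passed.
model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

pairs = Dataset.from_dict(
    {
        "prompt": ["hello", "how are you"],
        "chosen": ["hi nice to meet you", "I am fine"],
        "rejected": ["leave me alone", "I am not fine"],
    }
)
config = ORPOConfig(
    output_dir="orpo-sketch",
    per_device_train_batch_size=2,
    max_steps=3,
    remove_unused_columns=False,
    beta=0.1,
)
trainer = ORPOTrainer(model=model, args=config, tokenizer=tokenizer, train_dataset=pairs)
trainer.train()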

View File

@ -27,6 +27,7 @@ python examples/scripts/reward_modeling.py \
--evaluation_strategy="steps" \
--max_length=512 \
"""
import warnings
import torch
from datasets import load_dataset
@ -64,6 +65,12 @@ if __name__ == "__main__":
model_config.model_name_or_path, num_labels=1, **model_kwargs
)
if model_config.lora_task_type != "SEQ_CLS":
warnings.warn(
"You are using a `task_type` that is different from `SEQ_CLS` for PEFT. This will lead to silent bugs."
" Make sure to pass --lora_task_type SEQ_CLS when using this script."
)
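# When training the reward model with PEFT, the adapter must target sequence
# classification (not causal LM), which is what the warning above guards against.
# A purely illustrative config, along the lines of tests/test_reward_trainer.py:
#     from peft import LoraConfig, TaskType
#     example_seq_cls_lora_config = LoraConfig(
#         task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,
#     )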
################
# Dataset
################

View File

@ -1,3 +1,4 @@
# flake8: noqa
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -43,30 +44,50 @@ python examples/scripts/sft.py \
--lora_r=64 \
--lora_alpha=16
"""
from dataclasses import dataclass, field
import logging
import os
from contextlib import nullcontext
TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)
from trl.commands.cli_utils import init_zero_verbose, SftScriptArguments, TrlParser
if TRL_USE_RICH:
init_zero_verbose()
FORMAT = "%(message)s"
from rich.console import Console
from rich.logging import RichHandler
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments
from trl import ModelConfig, SFTTrainer, get_kbit_device_map, get_peft_config, get_quantization_config
from tqdm.rich import tqdm
from transformers import AutoTokenizer, TrainingArguments
from trl import (
ModelConfig,
RichProgressCallback,
SFTTrainer,
get_peft_config,
get_quantization_config,
get_kbit_device_map,
)
tqdm.pandas()
@dataclass
class ScriptArguments:
dataset_name: str = field(default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"})
dataset_text_field: str = field(default="text", metadata={"help": "the text field of the dataset"})
max_seq_length: int = field(default=512, metadata={"help": "The maximum sequence length for SFT Trainer"})
if TRL_USE_RICH:
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)
if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, TrainingArguments, ModelConfig))
args, training_args, model_config = parser.parse_args_into_dataclasses()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
parser = TrlParser((SftScriptArguments, TrainingArguments, ModelConfig))
args, training_args, model_config = parser.parse_args_and_config()
# Force use our print callback
if TRL_USE_RICH:
training_args.disable_tqdm = True
console = Console()
################
# Model & Tokenizer
@ -96,20 +117,35 @@ if __name__ == "__main__":
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]
################
# Optional rich context managers
###############
init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the SFTTrainer...")
save_context = (
nullcontext()
if not TRL_USE_RICH
else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
)
################
# Training
################
trainer = SFTTrainer(
model=model_config.model_name_or_path,
model_init_kwargs=model_kwargs,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
dataset_text_field="text",
max_seq_length=args.max_seq_length,
tokenizer=tokenizer,
packing=True,
peft_config=get_peft_config(model_config),
)
with init_context:
trainer = SFTTrainer(
model=model_config.model_name_or_path,
model_init_kwargs=model_kwargs,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
dataset_text_field=args.dataset_text_field,
max_seq_length=args.max_seq_length,
tokenizer=tokenizer,
packing=args.packing,
peft_config=get_peft_config(model_config),
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
)
trainer.train()
trainer.save_model(training_args.output_dir)
with save_context:
trainer.save_model(training_args.output_dir)
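With this change the text column, packing behaviour, and sequence length come from the parsed script arguments instead of being hard-coded. For reference, a self-contained sketch of an equivalent call, reusing the tiny model and imdb setup from tests/test_cli.py further down (values are illustrative):

from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer

# Minimal SFT run: tiny random model, a slice of imdb, and its "text" column.
train_dataset = load_dataset("imdb", split="train[:100]")
trainer = SFTTrainer(
    model="HuggingFaceM4/tiny-random-LlamaForCausalLM",
    args=TrainingArguments(output_dir="tmp-sft", per_device_train_batch_size=2, max_steps=1),
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=512,
)
trainer.train()
trainer.save_model("tmp-sft")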

View File

@ -0,0 +1,210 @@
# flake8: noqa
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# regular:
python examples/scripts/vsft.py \
--model_name_or_path="llava-hf/llava-1.5-7b-hf" \
--report_to="wandb" \
--learning_rate=1.4e-5 \
--per_device_train_batch_size=8 \
--gradient_accumulation_steps=1 \
--output_dir="data/vsft-llava-1.5-7b-hf" \
--logging_steps=5 \
--num_train_epochs=1 \
--push_to_hub \
--gradient_checkpointing \
--remove_unused_columns=False \
--torch_dtype=float16 \
--fp16=True \
--dataset_name=HuggingFaceH4/llava-instruct-mix-vsft \
# peft:
python examples/scripts/vsft.py \
--model_name_or_path="llava-hf/llava-1.5-7b-hf" \
--report_to="wandb" \
--learning_rate=1.4e-5 \
--per_device_train_batch_size=8 \
--gradient_accumulation_steps=1 \
--output_dir="data/vsft-llava-1.5-7b-hf" \
--logging_steps=5 \
--num_train_epochs=1 \
--push_to_hub \
--gradient_checkpointing \
--remove_unused_columns=False \
--torch_dtype=float16 \
--fp16=True \
--dataset_name=HuggingFaceH4/llava-instruct-mix-vsft \
--use_peft=True \
--lora_r=64 \
--lora_alpha=16 \
--lora_target_modules=all-linear
# evaluation:
To evaluate, first install the lmms-eval framework: pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git
then run:
accelerate launch --num_processes=8 -m lmms_eval \
--model llava_hf \
--model_args pretrained=llava-hf/llava-1.5-7b-hf \
--tasks mmbench \
--batch_size 1 \
--output_path ./logs/ \
--log_sample
"""
import logging
import os
from contextlib import nullcontext
TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)
from trl.commands.cli_utils import init_zero_verbose, SftScriptArguments, TrlParser
if TRL_USE_RICH:
init_zero_verbose()
FORMAT = "%(message)s"
from rich.console import Console
from rich.logging import RichHandler
import torch
from accelerate import Accelerator
from datasets import load_dataset
from tqdm.rich import tqdm
from transformers import AutoTokenizer, AutoProcessor, TrainingArguments, LlavaForConditionalGeneration
from trl import (
ModelConfig,
RichProgressCallback,
SFTTrainer,
get_peft_config,
get_quantization_config,
get_kbit_device_map,
)
tqdm.pandas()
if TRL_USE_RICH:
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)
if __name__ == "__main__":
parser = TrlParser((SftScriptArguments, TrainingArguments, ModelConfig))
args, training_args, model_config = parser.parse_args_and_config()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
# Force use our print callback
if TRL_USE_RICH:
training_args.disable_tqdm = True
console = Console()
################
# Model, Tokenizer & Processor
################
LLAVA_CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
torch_dtype = (
model_config.torch_dtype
if model_config.torch_dtype in ["auto", None]
else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
trust_remote_code=model_config.trust_remote_code,
attn_implementation=model_config.attn_implementation,
torch_dtype=torch_dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
tokenizer.chat_template = LLAVA_CHAT_TEMPLATE
processor = AutoProcessor.from_pretrained(model_config.model_name_or_path)
processor.tokenizer = tokenizer
model = LlavaForConditionalGeneration.from_pretrained(model_config.model_name_or_path, **model_kwargs)
################
# Create a data collator to encode text and image pairs
################
class LLavaDataCollator:
def __init__(self, processor):
self.processor = processor
def __call__(self, examples):
texts = []
images = []
for example in examples:
if len(example["images"]) > 1:
raise ValueError("This collator only supports one image per example")
messages = example["messages"]
text = self.processor.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=False
)
texts.append(text)
images.append(example["images"][0])
batch = self.processor(texts, images, return_tensors="pt", padding=True)
labels = batch["input_ids"].clone()
if self.processor.tokenizer.pad_token_id is not None:
labels[labels == self.processor.tokenizer.pad_token_id] = -100
batch["labels"] = labels
return batch
data_collator = LLavaDataCollator(processor)
################
# Dataset
################
raw_datasets = load_dataset(args.dataset_name)
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]
################
# Optional rich context managers
###############
init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the SFTTrainer...")
save_context = (
nullcontext()
if not TRL_USE_RICH
else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
)
################
# Training
################
with init_context:
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
dataset_text_field="text", # need a dummy field
tokenizer=tokenizer,
peft_config=get_peft_config(model_config),
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
data_collator=data_collator,
dataset_kwargs={"skip_prepare_dataset": True},
)
trainer.train()
with save_context:
trainer.save_model(training_args.output_dir)
trainer.push_to_hub()
if Accelerator().is_main_process:
processor.push_to_hub(training_args.hub_model_id)
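The LLavaDataCollator above expects each raw example to carry a "messages" list whose content items interleave text and image placeholders, plus exactly one PIL image; it renders the messages with the chat template, runs the processor over the text/image pairs, and copies input_ids into labels with pad positions masked to -100. An illustrative record in that shape (the conversation itself is made up):

from PIL import Image

# Hypothetical raw example in the shape LLavaDataCollator consumes.
example = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What is shown in this image?"},
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "A cat sitting on a sofa."}],
        },
    ],
    "images": [Image.new("RGB", (336, 336))],  # exactly one image per example
}
# batch = data_collator([example])  # needs the processor/template configured above;
# the batch then holds input_ids, attention_mask, pixel_values and the masked labels.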

View File

@ -1,2 +1,2 @@
[metadata]
license_file = LICENSE
license_file = LICENSE

View File

@ -53,11 +53,12 @@ To create the package for pypi.
8. Change the version in __init__.py and setup.py to X.X.X+1.dev0 (e.g. VERSION=1.18.3 -> 1.18.4.dev0).
Then push the change with a message 'set dev version'
"""
import os
from setuptools import find_packages, setup
__version__ = "0.7.11" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
__version__ = "0.8.4" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
REQUIRED_PKGS = [
"torch>=1.4.0",
@ -79,34 +80,44 @@ EXTRAS["dev"] = []
for reqs in EXTRAS.values():
EXTRAS["dev"].extend(reqs)
setup(
name="trl",
license="Apache 2.0",
classifiers=[
"Development Status :: 2 - Pre-Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Natural Language :: English",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
],
url="https://github.com/huggingface/trl",
packages=find_packages(),
include_package_data=True,
install_requires=REQUIRED_PKGS,
extras_require=EXTRAS,
python_requires=">=3.7",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
zip_safe=False,
version=__version__,
description="Train transformer language models with reinforcement learning.",
keywords="ppo, transformers, huggingface, gpt2, language modeling, rlhf",
author="Leandro von Werra",
author_email="leandro.vonwerra@gmail.com",
)
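# The example scripts are temporarily symlinked into trl/commands/scripts so that
# package_data can ship them alongside the `trl` CLI entry point; the symlink is
# removed again in the finally block below.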
try:
file_path = os.path.dirname(os.path.abspath(__file__))
os.symlink(os.path.join(file_path, "examples/scripts"), os.path.join(file_path, "trl/commands/scripts"))
setup(
name="trl",
license="Apache 2.0",
classifiers=[
"Development Status :: 2 - Pre-Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Natural Language :: English",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
],
url="https://github.com/huggingface/trl",
entry_points={
"console_scripts": ["trl=trl.commands.cli:main"],
},
include_package_data=True,
package_data={"trl": ["commands/scripts/config/*", "commands/scripts/*"]},
packages=find_packages(),
install_requires=REQUIRED_PKGS,
extras_require=EXTRAS,
python_requires=">=3.7",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
zip_safe=False,
version=__version__,
description="Train transformer language models with reinforcement learning.",
keywords="ppo, transformers, huggingface, gpt2, language modeling, rlhf",
author="Leandro von Werra",
author_email="leandro.vonwerra@gmail.com",
)
finally:
os.unlink(os.path.join(file_path, "trl/commands/scripts"))

40
tests/test_cli.py Normal file
View File

@ -0,0 +1,40 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
import sys
import unittest
@unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
def test_sft_cli():
try:
subprocess.run(
"trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name imdb --learning_rate 1e-4 --lr_scheduler_type cosine --dataset_text_field text",
shell=True,
check=True,
)
except BaseException as exc:
raise AssertionError("An error occurred while running the CLI, please double check") from exc
@unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
def test_dpo_cli():
try:
subprocess.run(
"trl dpo --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/hh-rlhf-trl-style --learning_rate 1e-4 --lr_scheduler_type cosine --sanity_check",
shell=True,
check=True,
)
except BaseException as exc:
raise AssertionError("An error occurred while running the CLI, please double check") from exc

182
tests/test_cpo_trainer.py Normal file
View File

@ -0,0 +1,182 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import torch
from datasets import Dataset
from parameterized import parameterized
from pytest import mark
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from trl import CPOConfig, CPOTrainer
from .testing_utils import require_peft
class CPOTrainerTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
cls.model = AutoModelForCausalLM.from_pretrained(cls.model_id)
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id)
cls.tokenizer.pad_token = cls.tokenizer.eos_token
# get t5 as seq2seq example:
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab"
cls.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
def _init_dummy_dataset(self):
# fmt: off
dummy_dataset_dict = {
"prompt": [
"hello",
"how are you",
"What is your name?",
"What is your name?",
"Which is the best programming language?",
"Which is the best programming language?",
"Which is the best programming language?",
"[INST] How is the stock price? [/INST]",
"[INST] How is the stock price? [/INST] ",
],
"chosen": [
"hi nice to meet you",
"I am fine",
"My name is Mary",
"My name is Mary",
"Python",
"Python",
"Python",
"$46 as of 10am EST",
"46 as of 10am EST",
],
"rejected": [
"leave me alone",
"I am not fine",
"Whats it to you?",
"I dont have a name",
"Javascript",
"C++",
"Java",
" $46 as of 10am EST",
" 46 as of 10am EST",
],
}
# fmt: on
return Dataset.from_dict(dummy_dataset_dict)
@parameterized.expand(
[
["gpt2", "sigmoid"],
["t5", "hinge"],
["gpt2", "ipo"],
["t5", "ipo"],
]
)
def test_cpo_trainer(self, name, loss_type):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = CPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=1,
learning_rate=9e-1,
evaluation_strategy="steps",
beta=0.1,
loss_type=loss_type,
)
dummy_dataset = self._init_dummy_dataset()
if name == "gpt2":
model = self.model
tokenizer = self.tokenizer
elif name == "t5":
model = self.t5_model
tokenizer = self.t5_tokenizer
training_args.is_encoder_decoder = True
trainer = CPOTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
assert trainer.state.log_history[-1]["train_loss"] is not None
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
assert not torch.equal(param, new_param)
@require_peft
@mark.peft_test
def test_cpo_trainer_with_lora(self):
from peft import LoraConfig
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = CPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
beta=0.1,
)
dummy_dataset = self._init_dummy_dataset()
trainer = CPOTrainer(
model=self.model,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
peft_config=lora_config,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
assert trainer.state.log_history[-1]["train_loss"] is not None
# check the params have changed
for n, param in previous_trainable_params.items():
if "lora" in n:
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
assert not torch.equal(param, new_param)

View File

@ -587,3 +587,64 @@ class DPOTrainerTester(unittest.TestCase):
)
assert trainer.model.model_tags == trainer._tag_names
@require_peft
@mark.peft_test
def test_dpo_lora_force_use_ref(self):
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
# lora model
model = AutoModelForCausalLM.from_pretrained(self.model_id)
model_peft = get_peft_model(model, lora_config)
ref_model = AutoModelForCausalLM.from_pretrained(self.model_id)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
)
dummy_dataset = self._init_dummy_dataset()
with self.assertRaises(ValueError):
# passing a peft_model as model and ref_model should error out,
# unless you pass `force_use_ref_model`
trainer = DPOTrainer(
model=model_peft,
ref_model=ref_model,
beta=0.1,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
peft_config=lora_config,
)
trainer = DPOTrainer(
model=model_peft,
ref_model=ref_model,
beta=0.1,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
peft_config=lora_config,
force_use_ref_model=True,
)
# train the model
trainer.train()

View File

@ -13,6 +13,7 @@
# limitations under the License.
import tempfile
import unittest
from functools import partial
import torch
from datasets import Dataset
@ -31,15 +32,27 @@ class IterativeTrainerTester(unittest.TestCase):
cls.tokenizer.pad_token = cls.tokenizer.eos_token
# get t5 as seq2seq example:
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab"
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab-calibrated"
cls.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
def _init_tensor_dummy_dataset(self):
dummy_dataset_dict = {
"input_ids": [torch.tensor([5303, 3621]), torch.tensor([3666, 1438, 318]), torch.tensor([5303, 3621])],
"attention_mask": [torch.tensor([1, 1]), torch.tensor([1, 1, 1]), torch.tensor([1, 1])],
"labels": [torch.tensor([5303, 3621]), torch.tensor([3666, 1438, 318]), torch.tensor([5303, 3621])],
"input_ids": [
torch.tensor([5303, 3621, 3666, 1438, 318]),
torch.tensor([3666, 1438, 318, 3666, 1438, 318]),
torch.tensor([5303, 3621, 3666, 1438, 318]),
],
"attention_mask": [
torch.tensor([1, 1, 1, 1, 1]),
torch.tensor([1, 1, 1, 1, 1, 1]),
torch.tensor([1, 1, 1, 1, 1]),
],
"labels": [
torch.tensor([5303, 3621, 3666, 1438, 318]),
torch.tensor([3666, 1438, 318, 3666, 1438, 318]),
torch.tensor([5303, 3621, 3666, 1438, 318]),
],
}
dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
@ -94,11 +107,10 @@ class IterativeTrainerTester(unittest.TestCase):
tokenizer = self.t5_tokenizer
args = TrainingArguments(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=2,
output_dir=tmp_dir, per_device_train_batch_size=2, max_steps=2, learning_rate=1e-3
)
iterative_trainer = IterativeSFTTrainer(model=model, args=args, tokenizer=tokenizer)
iterative_trainer.optimizer.zero_grad = partial(iterative_trainer.optimizer.zero_grad, set_to_none=False)
iterative_trainer.step(**inputs)

387
tests/test_kto_trainer.py Normal file
View File

@ -0,0 +1,387 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import torch
from datasets import Dataset
from parameterized import parameterized
from pytest import mark
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from trl import KTOConfig, KTOTrainer
from trl.trainer.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize
from .testing_utils import require_no_wandb, require_peft
class KTOTrainerTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
cls.model = AutoModelForCausalLM.from_pretrained(cls.model_id)
cls.ref_model = AutoModelForCausalLM.from_pretrained(cls.model_id)
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id)
cls.tokenizer.pad_token = cls.tokenizer.eos_token
# get t5 as seq2seq example:
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab"
cls.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
cls.t5_ref_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
def _init_dummy_dataset(self):
# fmt: off
dummy_dataset_dict = {
"prompt": [
"Hey, hello",
"How are you",
"What is your name?",
"What is your name?",
"Which is the best programming language?",
"Which is the best programming language?",
"Which is the best programming language?",
],
"completion": [
"hi nice to meet you",
"leave me alone",
"I don't have a name",
"My name is Mary",
"Python",
"C++",
"Java",
],
"label": [
True,
False,
False,
True,
True,
False,
False,
],
}
# fmt: on
return Dataset.from_dict(dummy_dataset_dict)
@parameterized.expand(
[
["gpt2", True, True],
["gpt2", True, False],
# ["t5", True],
["gpt2", False, True],
["gpt2", False, False],
# ["t5", False],
]
)
def test_kto_trainer(self, name, pre_compute, eval_dataset):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = KTOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=1,
learning_rate=9e-1,
evaluation_strategy="steps",
beta=0.1,
precompute_ref_log_probs=pre_compute,
)
dummy_dataset = self._init_dummy_dataset()
if name == "gpt2":
model = self.model
ref_model = self.ref_model
tokenizer = self.tokenizer
elif name == "t5":
model = self.t5_model
ref_model = self.t5_ref_model
tokenizer = self.t5_tokenizer
trainer = KTOTrainer(
model=model,
ref_model=ref_model,
args=training_args,
tokenizer=tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset if eval_dataset else None,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))
def test_tokenize_and_process_tokens(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = KTOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=1,
learning_rate=9e-1,
evaluation_strategy="steps",
beta=0.1,
)
dummy_dataset = self._init_dummy_dataset()
trainer = KTOTrainer(
model=self.model,
ref_model=self.ref_model,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
)
with tempfile.TemporaryDirectory() as tmp_dir:
tokenized_dataset = dummy_dataset.map(
_tokenize,
fn_kwargs={"tokenizer": trainer.tokenizer},
batched=True,
batch_size=2,
)
self.assertListEqual(tokenized_dataset["prompt"], dummy_dataset["prompt"])
self.assertListEqual(tokenized_dataset["completion"], dummy_dataset["completion"])
self.assertListEqual(tokenized_dataset["label"], dummy_dataset["label"])
self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [10814, 11])
self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1])
self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [5968, 1219, 72, 3621, 284, 1826, 345])
self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1])
# Test reversal of (prompt, completion) pairs for KL dataset
tokenized_kl_dataset = tokenized_dataset.map(_get_kl_dataset, batched=True, batch_size=2)
self.assertListEqual(
tokenized_kl_dataset["prompt_input_ids"][0], tokenized_dataset["prompt_input_ids"][0]
)
self.assertListEqual(
tokenized_kl_dataset["prompt_attention_mask"][0], tokenized_dataset["prompt_attention_mask"][0]
)
self.assertListEqual(
tokenized_kl_dataset["answer_input_ids"][0], tokenized_dataset["answer_input_ids"][1]
)
self.assertListEqual(
tokenized_kl_dataset["answer_attention_mask"][0], tokenized_dataset["answer_attention_mask"][1]
)
fn_kwargs = {
"prefix": "",
"is_encoder_decoder": trainer.is_encoder_decoder,
"tokenizer": trainer.tokenizer,
"max_length": trainer.max_length,
"truncation_mode": trainer.truncation_mode,
"label_pad_token_id": trainer.label_pad_token_id,
"max_prompt_length": trainer.max_prompt_length,
}
processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs, num_proc=2)
self.assertListEqual(processed_dataset["prompt"], dummy_dataset["prompt"])
self.assertListEqual(processed_dataset["completion"], dummy_dataset["completion"])
self.assertListEqual(processed_dataset["label"], dummy_dataset["label"])
self.assertListEqual(processed_dataset["prompt_input_ids"][0], [50256, 10814, 11])
self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1])
self.assertListEqual(
processed_dataset["completion_input_ids"][0],
[50256, 10814, 11, 5968, 1219, 72, 3621, 284, 1826, 345, 50256],
)
self.assertListEqual(
processed_dataset["completion_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
)
self.assertListEqual(
processed_dataset["completion_labels"][0],
[-100, -100, -100, 5968, 1219, 72, 3621, 284, 1826, 345, 50256],
)
def test_kto_trainer_without_providing_ref_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = KTOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
beta=0.1,
)
dummy_dataset = self._init_dummy_dataset()
trainer = KTOTrainer(
model=self.model,
ref_model=None,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))
@require_peft
@mark.peft_test
def test_kto_trainer_without_providing_ref_model_with_lora(self):
from peft import LoraConfig
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = KTOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
beta=0.1,
)
dummy_dataset = self._init_dummy_dataset()
trainer = KTOTrainer(
model=self.model,
ref_model=None,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
peft_config=lora_config,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# check the params have changed
for n, param in previous_trainable_params.items():
if "lora" in n:
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))
@require_no_wandb
def test_kto_trainer_generate_during_eval_no_wandb(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = KTOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=1,
learning_rate=9e-1,
evaluation_strategy="steps",
beta=0.1,
generate_during_eval=True,
)
dummy_dataset = self._init_dummy_dataset()
with self.assertRaisesRegex(
ValueError,
expected_regex="`generate_during_eval=True` requires Weights and Biases to be installed."
" Please install with `pip install wandb` to resolve.",
):
KTOTrainer(
model=self.model,
ref_model=None,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
)
@require_peft
@mark.peft_test
def test_kto_lora_save(self):
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
# lora model
model = AutoModelForCausalLM.from_pretrained(self.model_id)
model_peft = get_peft_model(model, lora_config)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = KTOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
beta=0.1,
)
dummy_dataset = self._init_dummy_dataset()
# kto train lora model with a lora config
trainer = KTOTrainer(
model=model_peft,
ref_model=None,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
peft_config=lora_config,
)
# train the model
trainer.train()
# save peft adapter
trainer.save_model()
# assert that the model is loaded without giving OSError
try:
AutoModelForCausalLM.from_pretrained(tmp_dir)
except OSError:
self.fail("Loading the saved peft adapter failed")

View File

@ -13,6 +13,7 @@
# limitations under the License.
import sys
import unittest
from functools import partial
from unittest.mock import patch
import pytest
@ -134,6 +135,7 @@ class TestPeftDependancy(unittest.TestCase):
tokenizer=tokenizer,
dataset=dummy_dataset,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
for query_tensor, response_tensor in dummy_dataloader:

174
tests/test_orpo_trainer.py Normal file
View File

@ -0,0 +1,174 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import torch
from datasets import Dataset
from parameterized import parameterized
from pytest import mark
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from trl import ORPOConfig, ORPOTrainer
from .testing_utils import require_peft
class ORPOTrainerTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
cls.model = AutoModelForCausalLM.from_pretrained(cls.model_id)
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id)
cls.tokenizer.pad_token = cls.tokenizer.eos_token
# get t5 as seq2seq example:
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab"
cls.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
def _init_dummy_dataset(self):
# fmt: off
dummy_dataset_dict = {
"prompt": [
"hello",
"how are you",
"What is your name?",
"What is your name?",
"Which is the best programming language?",
"Which is the best programming language?",
"Which is the best programming language?",
"[INST] How is the stock price? [/INST]",
"[INST] How is the stock price? [/INST] ",
],
"chosen": [
"hi nice to meet you",
"I am fine",
"My name is Mary",
"My name is Mary",
"Python",
"Python",
"Python",
"$46 as of 10am EST",
"46 as of 10am EST",
],
"rejected": [
"leave me alone",
"I am not fine",
"Whats it to you?",
"I dont have a name",
"Javascript",
"C++",
"Java",
" $46 as of 10am EST",
" 46 as of 10am EST",
],
}
# fmt: on
return Dataset.from_dict(dummy_dataset_dict)
@parameterized.expand([["gpt2"], ["t5"]])
def test_orpo_trainer(self, name):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = ORPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=1,
learning_rate=9e-1,
evaluation_strategy="steps",
beta=0.1,
)
dummy_dataset = self._init_dummy_dataset()
if name == "gpt2":
model = self.model
tokenizer = self.tokenizer
elif name == "t5":
model = self.t5_model
tokenizer = self.t5_tokenizer
training_args.is_encoder_decoder = True
trainer = ORPOTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
assert trainer.state.log_history[-1]["train_loss"] is not None
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
assert not torch.equal(param, new_param)
@require_peft
@mark.peft_test
def test_orpo_trainer_with_lora(self):
from peft import LoraConfig
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = ORPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
beta=0.1,
)
dummy_dataset = self._init_dummy_dataset()
trainer = ORPOTrainer(
model=self.model,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
peft_config=lora_config,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
assert trainer.state.log_history[-1]["train_loss"] is not None
# check the params have changed
for n, param in previous_trainable_params.items():
if "lora" in n:
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
assert not torch.equal(param, new_param)

View File

@ -17,6 +17,7 @@ import gc
import re
import tempfile
import unittest
from functools import partial
import pytest
import torch
@ -193,6 +194,7 @@ class PPOTrainerTester(unittest.TestCase):
tokenizer=self.gpt2_tokenizer,
dataset=dummy_dataset,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
# train model with ppo
for query_tensor, response_tensor in dummy_dataloader:
@ -220,6 +222,7 @@ class PPOTrainerTester(unittest.TestCase):
tokenizer=self.gpt2_tokenizer,
dataset=dummy_dataset,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
# train model with ppo
for query_tensor, response_tensor in dummy_dataloader:
@ -252,6 +255,7 @@ class PPOTrainerTester(unittest.TestCase):
tokenizer=self.gpt2_tokenizer,
dataset=dummy_dataset,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
assert isinstance(ppo_trainer.optimizer.optimizer, torch.optim.SGD)
@ -291,6 +295,7 @@ class PPOTrainerTester(unittest.TestCase):
dataset=dummy_dataset,
lr_scheduler=lr_scheduler,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
assert isinstance(ppo_trainer.optimizer.optimizer, torch.optim.SGD)
@ -332,6 +337,7 @@ class PPOTrainerTester(unittest.TestCase):
tokenizer=self.gpt2_tokenizer,
dataset=dummy_dataset,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
# train model with ppo
for query_tensor, response_tensor in dummy_dataloader:
@ -382,6 +388,7 @@ class PPOTrainerTester(unittest.TestCase):
dataset=dummy_dataset,
num_shared_layers=num_shared_layers,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
# train model with ppo
for query_tensor, response_tensor in dummy_dataloader:
@ -449,6 +456,7 @@ class PPOTrainerTester(unittest.TestCase):
tokenizer=self.gpt2_tokenizer,
dataset=dummy_dataset,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
# train model with ppo
for query_tensor, response_tensor in dummy_dataloader:
@ -486,6 +494,7 @@ class PPOTrainerTester(unittest.TestCase):
tokenizer=self.gpt2_tokenizer,
dataset=dummy_dataset,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
# train model with ppo
for query_tensor, response_tensor in dummy_dataloader:
@ -526,6 +535,7 @@ class PPOTrainerTester(unittest.TestCase):
ref_model=self.gpt2_model_ref,
tokenizer=self.gpt2_tokenizer,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
# train model with ppo
reward = [torch.tensor([1.0])]
# train model - this should work fine
@ -692,7 +702,7 @@ class PPOTrainerTester(unittest.TestCase):
tokenizer=self.gpt2_tokenizer,
dataset=dummy_dataset,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
# train model with ppo
@ -879,7 +889,7 @@ class PPOTrainerTester(unittest.TestCase):
tokenizer=self.gpt2_tokenizer,
dataset=dummy_dataset,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
assert ppo_trainer.ref_model is None
dummy_dataloader = ppo_trainer.dataloader
@ -967,7 +977,7 @@ class PPOTrainerTester(unittest.TestCase):
tokenizer=self.gpt2_tokenizer,
dataset=dummy_dataset,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
assert ppo_trainer.ref_model is None
dummy_dataloader = ppo_trainer.dataloader
@ -1132,6 +1142,55 @@ class PPOTrainerTester(unittest.TestCase):
assert generations_single == generations_batched
def test_generation_with_ref_model(self):
dummy_dataset = self._init_dummy_dataset()
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Negate the weights in the last layer of the ref model so it never
# outputs the same things as the primary model
ref_model = copy.deepcopy(model)
lm_head_weight = ref_model.pretrained_model.lm_head.weight
lm_head_weight.data = -lm_head_weight.data
ppo_trainer = PPOTrainer(
config=self.ppo_config,
model=model,
ref_model=ref_model,
tokenizer=tokenizer,
dataset=dummy_dataset,
)
input_texts = ["this is a test", "this is another, longer test"]
generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": tokenizer.eos_token_id}
tokenizer.pad_token = tokenizer.eos_token
model_inputs = [tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts]
generations_batched, ref_generations_batched = ppo_trainer.generate(
model_inputs, batch_size=2, generate_ref_response=True, **generation_kwargs
)
generations_batched = tokenizer.batch_decode(generations_batched)
ref_generations_batched = tokenizer.batch_decode(ref_generations_batched)
generations_single = []
ref_generations_single = []
for inputs in model_inputs:
generation, ref_generation = ppo_trainer.generate(inputs, generate_ref_response=True, **generation_kwargs)
generations_single.append(generation.squeeze())
ref_generations_single.append(ref_generation.squeeze())
generations_single = tokenizer.batch_decode(generations_single)
ref_generations_single = tokenizer.batch_decode(ref_generations_single)
assert generations_single == generations_batched
assert ref_generations_single == ref_generations_batched
assert generations_batched != ref_generations_batched
assert generations_single != ref_generations_single
def test_grad_accumulation(self):
dummy_dataset = self._init_dummy_dataset()
@ -1213,6 +1272,7 @@ class PPOTrainerTester(unittest.TestCase):
dataset=dummy_dataset,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
# train model with ppo
for query_tensor, response_tensor in dummy_dataloader:
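
The `partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)` line threaded through the hunks above forces the optimizer to keep zero-filled `.grad` tensors; recent PyTorch releases default `zero_grad()` to `set_to_none=True`, which leaves `param.grad` as `None`, and the gradient inspections these tests perform afterwards presumably rely on the tensors staying around. A minimal standalone sketch of the difference (illustrative only, not part of the diff):

# set_to_none=True drops the .grad tensors; set_to_none=False keeps them as zeros.
from functools import partial

import torch

param = torch.nn.Parameter(torch.ones(3))
optimizer = torch.optim.SGD([param], lr=0.1)

param.sum().backward()
optimizer.zero_grad(set_to_none=True)
print(param.grad)  # None

param.sum().backward()
optimizer.zero_grad = partial(optimizer.zero_grad, set_to_none=False)
optimizer.zero_grad()
print(param.grad)  # tensor([0., 0., 0.])
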

View File

@ -106,11 +106,8 @@ class RewardTrainerTester(unittest.TestCase):
@require_peft
def test_reward_trainer_peft(self):
import peft
from peft import LoraConfig, TaskType
peft_version = peft.__version__
peft_config = LoraConfig(
task_type=TaskType.SEQ_CLS,
inference_mode=False,
@ -172,7 +169,7 @@ class RewardTrainerTester(unittest.TestCase):
previous_non_trainable_params = {}
# due to a change in the way the modules to save are dealt with in PEFT.
trainable_params_name = ["lora", "score"] if peft_version < "0.3.0" else ["lora", "modules_to_save"]
trainable_params_name = ["lora", "modules_to_save"]
# check gradients are not None
for n, param in trainer.model.named_parameters():

View File

@ -0,0 +1,53 @@
import tempfile
import unittest
import torch
import torch.nn as nn
from datasets import Dataset
from transformers import Trainer, TrainingArguments
from trl.trainer.utils import RichProgressCallback
class DummyModel(nn.Module):
def __init__(self):
super().__init__()
self.a = nn.Parameter(torch.tensor(1.0))
def forward(self, x):
return self.a * x
class TestRichProgressCallback(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.dummy_model = DummyModel()
cls.dummy_train_dataset = Dataset.from_list([{"x": 1.0, "y": 2.0}] * 5)
cls.dummy_val_dataset = Dataset.from_list([{"x": 1.0, "y": 2.0}] * 101)
def test_rich_progress_callback_logging(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
per_device_eval_batch_size=2,
per_device_train_batch_size=2,
num_train_epochs=4,
evaluation_strategy="steps",
eval_steps=1,
logging_strategy="steps",
logging_steps=1,
save_strategy="no",
report_to="none",
disable_tqdm=True,
)
callbacks = [RichProgressCallback()]
trainer = Trainer(
model=self.dummy_model,
train_dataset=self.dummy_train_dataset,
eval_dataset=self.dummy_val_dataset,
args=training_args,
callbacks=callbacks,
)
trainer.train()
trainer.train()

View File

@ -19,14 +19,20 @@ import unittest
import numpy as np
import pytest
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from datasets import Dataset, Image, Sequence
from transformers import (
AutoModelForCausalLM,
AutoProcessor,
AutoTokenizer,
LlavaForConditionalGeneration,
TrainingArguments,
)
from trl import SFTTrainer
from trl.import_utils import is_peft_available
from trl.import_utils import is_peft_available, is_pil_available
from trl.trainer import ConstantLengthDataset, DataCollatorForCompletionOnlyLM
from .testing_utils import require_peft
from .testing_utils import require_peft, requires_pil
def formatting_prompts_func(example):
@ -45,6 +51,9 @@ def formatting_prompts_func_batched(example):
if is_peft_available():
from peft import LoraConfig, PeftModel
if is_pil_available():
from PIL import Image as PILImage
class SFTTrainerTester(unittest.TestCase):
r""" """
@ -123,6 +132,49 @@ class SFTTrainerTester(unittest.TestCase):
]
)
if is_pil_available():
cls.dummy_vsft_instruction_dataset = Dataset.from_dict(
{
"messages": [
[
{
"role": "user",
"content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is random noise."}],
},
{
"role": "user",
"content": [{"type": "text", "text": "Oh ye, you are right, what is 1+1"}],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "2"}],
},
],
[
{
"role": "user",
"content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is random noise."}],
},
],
],
"images": [
[PILImage.fromarray((np.random.rand(40, 50, 3) * 255).astype("uint8")).convert("RGBA")],
[PILImage.fromarray((np.random.rand(50, 60, 3) * 255).astype("uint8")).convert("RGBA")],
],
}
)
cls.dummy_vsft_instruction_dataset = cls.dummy_vsft_instruction_dataset.cast_column(
"images", Sequence(Image())
)
cls.train_dataset = ConstantLengthDataset(
cls.tokenizer,
cls.dummy_dataset,
@ -930,3 +982,145 @@ class SFTTrainerTester(unittest.TestCase):
)
assert trainer.model.model_tags == trainer._tag_names
def test_sft_trainer_eval_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
max_steps=4,
eval_steps=2,
save_steps=2,
per_device_train_batch_size=2,
gradient_checkpointing=True,
)
trainer = SFTTrainer(
model=self.model_id,
args=training_args,
train_dataset=self.dummy_chatml_dataset,
eval_dataset=self.dummy_chatml_dataset,
packing=True,
max_seq_length=32, # make sure there is at least 1 packed sequence
eval_packing=False,
)
assert len(trainer.train_dataset["input_ids"]) == 1
assert len(trainer.eval_dataset["input_ids"]) != 1
trainer = SFTTrainer(
model=self.model_id,
args=training_args,
train_dataset=self.dummy_chatml_dataset,
eval_dataset=self.dummy_chatml_dataset,
max_seq_length=32, # make sure there is at least 1 packed sequence
packing=True,
)
assert len(trainer.train_dataset["input_ids"]) == 1
assert len(trainer.eval_dataset["input_ids"]) == 1
trainer = SFTTrainer(
model=self.model_id,
args=training_args,
train_dataset=self.dummy_chatml_dataset,
eval_dataset=self.dummy_chatml_dataset,
max_seq_length=32, # make sure there is at least 1 packed sequence
packing=False,
)
assert len(trainer.train_dataset["input_ids"]) != 1
assert len(trainer.eval_dataset["input_ids"]) != 1
@requires_pil
def test_sft_trainer_skip_prepare_dataset(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
max_steps=4,
eval_steps=2,
save_steps=2,
per_device_train_batch_size=2,
gradient_checkpointing=True,
remove_unused_columns=False,
)
trainer = SFTTrainer(
model=self.model_id,
args=training_args,
train_dataset=self.dummy_vsft_instruction_dataset,
eval_dataset=self.dummy_vsft_instruction_dataset,
dataset_text_field="text", # need a dummy field
dataset_kwargs={"skip_prepare_dataset": True},
)
assert trainer.train_dataset.features == self.dummy_vsft_instruction_dataset.features
assert trainer.eval_dataset.features == self.dummy_vsft_instruction_dataset.features
@requires_pil
def test_sft_trainer_llava(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
max_steps=4,
eval_steps=2,
save_steps=2,
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
remove_unused_columns=False,
)
tiny_llava = LlavaForConditionalGeneration.from_pretrained(
"trl-internal-testing/tiny-random-LlavaForConditionalGeneration"
)
processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-random-LlavaForConditionalGeneration")
processor.tokenizer.chat_template = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}"""
class LLavaDataCollator:
def __init__(self, processor):
self.processor = processor
def __call__(self, examples):
texts = []
images = []
for example in examples:
if len(example["images"]) > 1:
raise ValueError("This collator only supports one image per example")
messages = example["messages"]
text = self.processor.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=False
)
texts.append(text)
images.append(example["images"][0])
batch = self.processor(texts, images, return_tensors="pt", padding=True)
labels = batch["input_ids"].clone()
if self.processor.tokenizer.pad_token_id is not None:
labels[labels == self.processor.tokenizer.pad_token_id] = -100
batch["labels"] = labels
return batch
data_collator = LLavaDataCollator(processor)
trainer = SFTTrainer(
model=tiny_llava,
args=training_args,
train_dataset=self.dummy_vsft_instruction_dataset,
eval_dataset=self.dummy_vsft_instruction_dataset,
dataset_text_field="text", # need a dummy field
dataset_kwargs={"skip_prepare_dataset": True},
data_collator=data_collator,
)
trainer.train()
assert trainer.state.log_history[-1]["train_loss"] is not None
assert trainer.state.log_history[0]["eval_loss"] is not None
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")

View File

@ -19,6 +19,7 @@ from trl import (
is_bitsandbytes_available,
is_diffusers_available,
is_peft_available,
is_pil_available,
is_wandb_available,
is_xpu_available,
)
@ -51,6 +52,15 @@ def require_diffusers(test_case):
return test_case
def requires_pil(test_case):
"""
Decorator marking a test that requires PIL. Skips the test if PIL is not available.
"""
if not is_pil_available():
test_case = unittest.skip("test requires PIL")(test_case)
return test_case
def require_wandb(test_case, required: bool = True):
"""
Decorator marking a test that requires wandb. Skips the test if wandb is not available.

View File

@ -1,44 +1,143 @@
# flake8: noqa
__version__ = "0.7.11"
__version__ = "0.8.4"
from .core import set_seed
from .environment import TextEnvironment, TextHistory
from .extras import BestOfNSampler
from .import_utils import (
is_bitsandbytes_available,
is_diffusers_available,
is_npu_available,
is_peft_available,
is_wandb_available,
is_xpu_available,
)
from .models import (
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
PreTrainedModelWrapper,
create_reference_model,
setup_chat_format,
)
from .trainer import (
DataCollatorForCompletionOnlyLM,
DPOTrainer,
IterativeSFTTrainer,
ModelConfig,
PPOConfig,
PPOTrainer,
RewardConfig,
RewardTrainer,
SFTTrainer,
)
from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
from typing import TYPE_CHECKING
from .import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable
_import_structure = {
"core": [
"set_seed",
],
"environment": [
"TextEnvironment",
"TextHistory",
],
"extras": [
"BestOfNSampler",
],
"import_utils": [
"is_bitsandbytes_available",
"is_diffusers_available",
"is_npu_available",
"is_peft_available",
"is_pil_available",
"is_wandb_available",
"is_xpu_available",
],
"models": [
"AutoModelForCausalLMWithValueHead",
"AutoModelForSeq2SeqLMWithValueHead",
"PreTrainedModelWrapper",
"create_reference_model",
"setup_chat_format",
"SUPPORTED_ARCHITECTURES",
],
"trainer": [
"DataCollatorForCompletionOnlyLM",
"DPOTrainer",
"CPOConfig",
"CPOTrainer",
"IterativeSFTTrainer",
"KTOConfig",
"KTOTrainer",
"ModelConfig",
"ORPOConfig",
"ORPOTrainer",
"PPOConfig",
"PPOTrainer",
"RewardConfig",
"RewardTrainer",
"SFTTrainer",
],
"commands": [],
"commands.cli_utils": ["init_zero_verbose", "SftScriptArguments", "DpoScriptArguments", "TrlParser"],
"trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config", "RichProgressCallback"],
"multitask_prompt_tuning": [
"MultitaskPromptEmbedding",
"MultitaskPromptTuningConfig",
"MultitaskPromptTuningInit",
],
}
if is_diffusers_available():
from .models import (
DDPOPipelineOutput,
DDPOSchedulerOutput,
DDPOStableDiffusionPipeline,
DefaultDDPOStableDiffusionPipeline,
try:
if not is_diffusers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["models"].extend(
[
"DDPOPipelineOutput",
"DDPOSchedulerOutput",
"DDPOStableDiffusionPipeline",
"DefaultDDPOStableDiffusionPipeline",
]
)
_import_structure["trainer"].extend(["DDPOConfig", "DDPOTrainer"])
if TYPE_CHECKING:
from .core import set_seed
from .environment import TextEnvironment, TextHistory
from .extras import BestOfNSampler
from .import_utils import (
is_bitsandbytes_available,
is_diffusers_available,
is_npu_available,
is_peft_available,
is_pil_available,
is_wandb_available,
is_xpu_available,
)
from .models import (
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
PreTrainedModelWrapper,
create_reference_model,
setup_chat_format,
SUPPORTED_ARCHITECTURES,
)
from .trainer import (
DataCollatorForCompletionOnlyLM,
DPOTrainer,
CPOConfig,
CPOTrainer,
IterativeSFTTrainer,
KTOConfig,
KTOTrainer,
ModelConfig,
ORPOConfig,
ORPOTrainer,
PPOConfig,
PPOTrainer,
RewardConfig,
RewardTrainer,
SFTTrainer,
)
from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config, RichProgressCallback
from .commands.cli_utils import init_zero_verbose, SftScriptArguments, DpoScriptArguments, TrlParser
try:
if not is_diffusers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .models import (
DDPOPipelineOutput,
DDPOSchedulerOutput,
DDPOStableDiffusionPipeline,
DefaultDDPOStableDiffusionPipeline,
)
from .trainer import DDPOConfig, DDPOTrainer
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
extra_objects={"__version__": __version__},
)
from .trainer import DDPOConfig, DDPOTrainer

34
trl/commands/__init__.py Normal file
View File

@ -0,0 +1,34 @@
# flake8: noqa
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
from typing import TYPE_CHECKING
from ..import_utils import _LazyModule, OptionalDependencyNotAvailable
_import_structure = {
"cli_utils": ["SftArgumentParser", "init_zero_verbose", "DpoScriptArguments", "TrlParser"],
"config_parser": ["YamlConfigParser"],
}
if TYPE_CHECKING:
from .cli_utils import SftScriptArguments, init_zero_verbose, DpoScriptArguments, TrlParser
from .config_parser import YamlConfigParser
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

71
trl/commands/cli.py Normal file
View File

@ -0,0 +1,71 @@
# This file is a copy of trl/examples/scripts/sft.py so that we could
# use it together with rich and the TRL CLI in a more customizable manner.
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
from subprocess import CalledProcessError
from rich.console import Console
SUPPORTED_COMMANDS = ["sft", "dpo", "chat"]
def main():
console = Console()
# Make sure to import things locally to avoid verbose output from third-party libs.
with console.status("[bold purple]Welcome! Initializing the TRL CLI..."):
from trl.commands.cli_utils import init_zero_verbose
init_zero_verbose()
command_name = sys.argv[1]
if command_name not in SUPPORTED_COMMANDS:
raise ValueError(
f"Please use one of the supported commands, got {command_name} - supported commands are {SUPPORTED_COMMANDS}"
)
trl_examples_dir = os.path.dirname(__file__)
# Force-use rich
os.environ["TRL_USE_RICH"] = "1"
if command_name == "chat":
command = f"""
python {trl_examples_dir}/scripts/{command_name}.py {" ".join(sys.argv[2:])}
"""
else:
command = f"""
accelerate launch {trl_examples_dir}/scripts/{command_name}.py {" ".join(sys.argv[2:])}
"""
try:
subprocess.run(
command.split(),
text=True,
check=True,
encoding="utf-8",
cwd=os.getcwd(),
env=os.environ.copy(),
)
except (CalledProcessError, ChildProcessError) as exc:
console.log(f"TRL - {command_name.upper()} failed on ! See the logs above for further details.")
raise ValueError("TRL CLI failed! Check the traceback above..") from exc
if __name__ == "__main__":
main()

288
trl/commands/cli_utils.py Normal file
View File

@ -0,0 +1,288 @@
# This file is a copy of trl/examples/scripts/sft.py so that we could
# use it together with rich and the TRL CLI in a more customizable manner.
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from copy import deepcopy
from dataclasses import asdict, dataclass, field, fields
from typing import Any, List
import yaml
from transformers import HfArgumentParser
class YamlConfigParser:
def __init__(self, config_path: str = None, dataclasses: List[Any] = None):
self.config = None
if config_path is not None:
with open(config_path) as yaml_file:
self.config = yaml.safe_load(yaml_file)
else:
self.config = {}
if dataclasses is None:
dataclasses = []
# We create a dummy training args to compare the values before / after
# __post_init__
# Here we import `TrainingArguments` from the local level to not
# break TRL lazy imports.
from transformers import TrainingArguments
self._dummy_training_args = TrainingArguments(output_dir="dummy-training-args")
self.parse_and_set_env()
self.merge_dataclasses(dataclasses)
def parse_and_set_env(self):
if "env" in self.config:
env_vars = self.config["env"]
if isinstance(env_vars, dict):
for key, value in env_vars.items():
os.environ[key] = str(value)
else:
raise ValueError("`env` field should be a dict in the YAML file.")
def merge_dataclasses(self, dataclasses):
from transformers import TrainingArguments
dataclasses_copy = [deepcopy(dataclass) for dataclass in dataclasses]
if len(self.config) > 0:
for i, dataclass in enumerate(dataclasses):
is_hf_training_args = False
for data_class_field in fields(dataclass):
# Get the field here
field_name = data_class_field.name
field_value = getattr(dataclass, field_name)
if not isinstance(dataclass, TrainingArguments):
default_value = data_class_field.default
else:
default_value = (
getattr(self._dummy_training_args, field_name)
if field_name != "output_dir"
else field_name
)
is_hf_training_args = True
default_value_changed = field_value != default_value
if field_value is not None or field_name in self.config:
if field_name in self.config:
# In case the field value is not different from default, overwrite it
if not default_value_changed:
value_to_replace = self.config[field_name]
setattr(dataclasses_copy[i], field_name, value_to_replace)
# Otherwise do nothing
# Re-init `TrainingArguments` to handle all post-processing correctly
if is_hf_training_args:
init_signature = list(inspect.signature(TrainingArguments.__init__).parameters)
dict_dataclass = asdict(dataclasses_copy[i])
new_dict_dataclass = {k: v for k, v in dict_dataclass.items() if k in init_signature}
dataclasses_copy[i] = TrainingArguments(**new_dict_dataclass)
return dataclasses_copy
def to_string(self):
final_string = """"""
for key, value in self.config.items():
if isinstance(value, (dict, list)):
if len(value) != 0:
value = str(value)
value = value.replace("'", '"')
value = f"'{value}'"
else:
continue
final_string += f"--{key} {value} "
return final_string
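
To make the `env` handling and default-merging above concrete, here is a hedged sketch of the YAML layout that `parse_and_set_env` expects (the file name and values are made up for illustration):

# Hypothetical YAML consumed by YamlConfigParser; `env` entries become environment variables.
import os
import tempfile

from trl.commands.cli_utils import YamlConfigParser

yaml_text = """\
env:
  TRL_USE_RICH: 1
output_dir: ./runs/demo
"""

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(yaml_text)
    config_path = f.name

parser = YamlConfigParser(config_path)   # __init__ sets the env vars as a side effect
print(os.environ["TRL_USE_RICH"])        # "1"
print(parser.config["output_dir"])       # "./runs/demo"
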
def init_zero_verbose():
"""
Perform zero verbose init - use this method on top of the CLI modules to make their output less verbose.
"""
import logging
import warnings
from rich.logging import RichHandler
FORMAT = "%(message)s"
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.ERROR)
# Custom warning handler to redirect warnings to the logging system
def warning_handler(message, category, filename, lineno, file=None, line=None):
logging.warning(f"{filename}:{lineno}: {category.__name__}: {message}")
# Add the custom warning handler - we need to do that before importing anything to make sure the loggers work well
warnings.showwarning = warning_handler
@dataclass
class SftScriptArguments:
dataset_name: str = field(default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"})
dataset_text_field: str = field(default=None, metadata={"help": "the text field of the dataset"})
max_seq_length: int = field(default=512, metadata={"help": "The maximum sequence length for SFT Trainer"})
packing: bool = field(default=False, metadata={"help": "Whether to apply data packing or not during training"})
config: str = field(default=None, metadata={"help": "Path to the optional config file"})
gradient_checkpointing_use_reentrant: bool = field(
default=False, metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}
)
@dataclass
class DpoScriptArguments:
dataset_name: str = field(default=None, metadata={"help": "the dataset name"})
beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
max_length: int = field(default=512, metadata={"help": "max length of each sample"})
max_prompt_length: int = field(default=128, metadata={"help": "max length of each sample's prompt"})
max_target_length: int = field(
default=128, metadata={"help": "Only used for encoder decoder model. Max target of each sample's prompt"}
)
sanity_check: bool = field(default=False, metadata={"help": "only train on 1000 samples"})
ignore_bias_buffers: bool = field(
default=False,
metadata={
"help": "debug argument for distributed training;"
"fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)
generate_during_eval: bool = field(default=False, metadata={"help": "Generate during evaluation"})
config: str = field(default=None, metadata={"help": "Path to the optional config file"})
gradient_checkpointing_use_reentrant: bool = field(
default=False, metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}
)
@dataclass
class ChatArguments:
# general settings
model_name_or_path: str = field(metadata={"help": "Name of the pre-trained model"})
user: str = field(default=None, metadata={"help": "Username to display in chat interface"})
system_prompt: str = field(default=None, metadata={"help": "System prompt"})
save_folder: str = field(default="./chat_history/", metadata={"help": "Folder to save chat history"})
device: str = field(
default="cpu",
metadata={"help": "device to use for inference."},
)
config: str = field(
default="default",
metadata={
"help": "Config file used for setting the configs. If `default` uses examples/scripts/config/default_chat_config.yaml"
},
)
examples: str = field(default=None, metadata={"help": "Empty placeholder needs to be set via config."})
# generation settings
max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate"})
do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation"})
num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search"})
temperature: float = field(default=1.0, metadata={"help": "Temperature parameter for generation"})
top_k: int = field(default=50, metadata={"help": "Value of k for top-k sampling"})
top_p: float = field(default=1.0, metadata={"help": "Value of p for nucleus sampling"})
repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty"})
# model loading
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
torch_dtype: str = field(
default=None,
metadata={
"help": (
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
"dtype will be automatically derived from the model's weights."
),
"choices": ["auto", "bfloat16", "float16", "float32"],
},
)
trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
attn_implementation: str = field(
default=None,
metadata={
"help": (
"Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`"
)
},
)
load_in_8bit: bool = field(
default=False, metadata={"help": "use 8 bit precision for the base model - works only with LoRA"}
)
load_in_4bit: bool = field(
default=False, metadata={"help": "use 4 bit precision for the base model - works only with LoRA"}
)
bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "specify the quantization type (fp4 or nf4)"})
use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"})
class TrlParser(HfArgumentParser):
def __init__(self, parsers):
"""
The TRL parser parses a list of parsers (TrainingArguments, trl.ModelConfig, etc.), creates a config
parser for users that pass a valid `config` field, and merges the values set in the config
with the processed parsers.
Args:
parsers (`List[argparse.ArgumentParser]`):
List of parsers.
"""
super().__init__(parsers)
def post_process_dataclasses(self, dataclasses):
# Apply additional post-processing in case some arguments need special care
training_args = trl_args = None
training_args_index = None
for i, dataclass_obj in enumerate(dataclasses):
if dataclass_obj.__class__.__name__ == "TrainingArguments":
training_args = dataclass_obj
training_args_index = i
elif dataclass_obj.__class__.__name__ in ("SftScriptArguments", "DpoScriptArguments"):
trl_args = dataclass_obj
else:
...
if trl_args is not None and training_args is not None:
training_args.gradient_checkpointing_kwargs = dict(
use_reentrant=trl_args.gradient_checkpointing_use_reentrant
)
dataclasses[training_args_index] = training_args
return dataclasses
def parse_args_and_config(self):
dataclasses = self.parse_args_into_dataclasses(return_remaining_strings=True)
# Pop the last element which should be the remaining strings
dataclasses = self.update_dataclasses_with_config(dataclasses[:-1])
return dataclasses
def update_dataclasses_with_config(self, dataclasses):
self.config_parser = None
for parser_dataclass in dataclasses:
if hasattr(parser_dataclass, "config"):
if self.config_parser is not None:
raise ValueError("You passed the `config` field twice! Make sure to pass `config` only once.")
self.config_parser = YamlConfigParser(parser_dataclass.config)
if self.config_parser is not None:
dataclasses = self.config_parser.merge_dataclasses(dataclasses)
dataclasses = self.post_process_dataclasses(dataclasses)
return dataclasses
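
A hedged sketch of how an entry-point script might wire `TrlParser` together with the dataclasses above (the actual script wiring is not part of this diff; run it with the usual flags such as `--output_dir` and `--model_name_or_path`):

# Hypothetical SFT-style entry point built on TrlParser.
from transformers import TrainingArguments

from trl import ModelConfig
from trl.commands.cli_utils import SftScriptArguments, TrlParser

parser = TrlParser((SftScriptArguments, TrainingArguments, ModelConfig))
args, training_args, model_config = parser.parse_args_and_config()

# Values from an optional `--config my.yaml` file are merged into any field
# still at its default; explicit command-line flags win otherwise.
print(args.dataset_name, training_args.output_dir, model_config.model_name_or_path)
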

View File

@ -22,7 +22,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import top_k_top_p_filtering
from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
from .import_utils import is_npu_available, is_xpu_available
@ -36,6 +36,42 @@ except ImportError:
WANDB_PADDING = -1
def top_k_top_p_filtering(
logits: torch.FloatTensor,
top_k: int = 0,
top_p: float = 1.0,
filter_value: float = -float("Inf"),
min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
"""
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
Args:
logits: logits distribution shape (batch size, vocabulary size)
top_k (`int`, *optional*, defaults to 0):
If > 0, only keep the top k tokens with highest probability (top-k filtering)
top_p (`float`, *optional*, defaults to 1.0):
If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens we keep per batch example in the output.
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
if top_k > 0:
logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
None, logits
)
if 0 <= top_p <= 1.0:
logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
None, logits
)
return logits
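
The vendored helper is meant to behave like the `top_k_top_p_filtering` function that newer transformers releases removed; a quick self-contained check with toy logits (illustrative only, and assuming, as the surrounding utilities suggest, that this file is `trl/core.py`):

# Toy example: keep only the top-2 logits, then sample from the filtered distribution.
import torch

from trl.core import top_k_top_p_filtering  # import path assumed from the surrounding code

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
filtered = top_k_top_p_filtering(logits.clone(), top_k=2)
print(filtered)  # the two lowest-scoring positions are set to -inf

probs = torch.softmax(filtered, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
print(next_token)  # always one of the two surviving tokens
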
def flatten_dict(nested: Dict, sep: str = "/") -> Dict:
"""Flatten dictionary and concatenate nested keys with separator."""

View File

@ -1,3 +1,14 @@
# flake8: noqa
from typing import TYPE_CHECKING
from ..import_utils import _LazyModule
from .base_environment import TextEnvironment, TextHistory
_import_structure = {
"base_environment": ["TextEnvironment", "TextHistory"],
}
if TYPE_CHECKING:
from .base_environment import TextEnvironment, TextHistory
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@ -13,4 +13,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .best_of_n_sampler import BestOfNSampler
from typing import TYPE_CHECKING
from ..import_utils import _LazyModule
_import_structure = {
"best_of_n_sampler": ["BestOfNSampler"],
}
if TYPE_CHECKING:
from .best_of_n_sampler import BestOfNSampler
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import sys
from importlib.util import find_spec
from itertools import chain
from types import ModuleType
from typing import Any
if sys.version_info < (3, 8):
@ -22,11 +27,11 @@ else:
def is_peft_available() -> bool:
return importlib.util.find_spec("peft") is not None
return find_spec("peft") is not None
def is_unsloth_available() -> bool:
return importlib.util.find_spec("unsloth") is not None
return find_spec("unsloth") is not None
def is_accelerate_greater_20_0() -> bool:
@ -41,9 +46,16 @@ def is_accelerate_greater_20_0() -> bool:
return accelerate_version >= "0.20.0"
def is_transformers_greater_than(version: str) -> bool:
_transformers_version = importlib.metadata.version("transformers")
return _transformers_version > version
def is_transformers_greater_than(current_version: str) -> bool:
if _is_python_greater_3_8:
from importlib.metadata import version
_transformers_version = version("transformers")
else:
import pkg_resources
_transformers_version = pkg_resources.get_distribution("transformers").version
return _transformers_version > current_version
def is_torch_greater_2_0() -> bool:
@ -59,26 +71,30 @@ def is_torch_greater_2_0() -> bool:
def is_diffusers_available() -> bool:
return importlib.util.find_spec("diffusers") is not None
return find_spec("diffusers") is not None
def is_pil_available() -> bool:
return find_spec("PIL") is not None
def is_bitsandbytes_available() -> bool:
import torch
# bnb can be imported without GPU but is not usable.
return importlib.util.find_spec("bitsandbytes") is not None and torch.cuda.is_available()
return find_spec("bitsandbytes") is not None and torch.cuda.is_available()
def is_torchvision_available() -> bool:
return importlib.util.find_spec("torchvision") is not None
return find_spec("torchvision") is not None
def is_rich_available() -> bool:
return importlib.util.find_spec("rich") is not None
return find_spec("rich") is not None
def is_wandb_available() -> bool:
return importlib.util.find_spec("wandb") is not None
return find_spec("wandb") is not None
def is_xpu_available() -> bool:
@ -87,7 +103,7 @@ def is_xpu_available() -> bool:
return accelerate.utils.is_xpu_available()
else:
if importlib.util.find_spec("intel_extension_for_pytorch") is None:
if find_spec("intel_extension_for_pytorch") is None:
return False
try:
import torch
@ -99,10 +115,74 @@ def is_xpu_available() -> bool:
def is_npu_available() -> bool:
"""Checks if `torch_npu` is installed and potentially if a NPU is in the environment"""
if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None:
if find_spec("torch") is None or find_spec("torch_npu") is None:
return False
import torch
import torch_npu # noqa: F401
return hasattr(torch, "npu") and torch.npu.is_available()
class _LazyModule(ModuleType):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
# Very heavily inspired by optuna.integration._IntegrationModule
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
super().__init__(name)
self._modules = set(import_structure.keys())
self._class_to_module = {}
for key, values in import_structure.items():
for value in values:
self._class_to_module[value] = key
# Needed for autocompletion in an IDE
self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
self.__file__ = module_file
self.__spec__ = module_spec
self.__path__ = [os.path.dirname(module_file)]
self._objects = {} if extra_objects is None else extra_objects
self._name = name
self._import_structure = import_structure
# Needed for autocompletion in an IDE
def __dir__(self):
result = super().__dir__()
# The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
# they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
for attr in self.__all__:
if attr not in result:
result.append(attr)
return result
def __getattr__(self, name: str) -> Any:
if name in self._objects:
return self._objects[name]
if name in self._modules:
value = self._get_module(name)
elif name in self._class_to_module.keys():
module = self._get_module(self._class_to_module[name])
value = getattr(module, name)
else:
raise AttributeError(f"module {self.__name__} has no attribute {name}")
setattr(self, name, value)
return value
def _get_module(self, module_name: str):
try:
return importlib.import_module("." + module_name, self.__name__)
except Exception as e:
raise RuntimeError(
f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
f" traceback):\n{e}"
) from e
def __reduce__(self):
return (self.__class__, (self._name, self.__file__, self._import_structure))
class OptionalDependencyNotAvailable(BaseException):
"""Internally used error class for signalling an optional dependency was not found."""

View File

@ -13,23 +13,52 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .modeling_base import PreTrainedModelWrapper, create_reference_model
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
from .utils import setup_chat_format
# flake8: noqa
from typing import TYPE_CHECKING
from ..import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable
SUPPORTED_ARCHITECTURES = (
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
)
_import_structure = {
"modeling_base": ["PreTrainedModelWrapper", "create_reference_model"],
"modeling_value_head": [
"AutoModelForCausalLMWithValueHead",
"AutoModelForSeq2SeqLMWithValueHead",
],
"utils": ["setup_chat_format", "SUPPORTED_ARCHITECTURES", "unwrap_model_for_generation"],
}
from ..import_utils import is_diffusers_available
try:
if not is_diffusers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_sd_base"] = [
"DDPOPipelineOutput",
"DDPOSchedulerOutput",
"DDPOStableDiffusionPipeline",
"DefaultDDPOStableDiffusionPipeline",
]
if TYPE_CHECKING:
from .modeling_base import PreTrainedModelWrapper, create_reference_model
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
from .utils import setup_chat_format, SUPPORTED_ARCHITECTURES
if is_diffusers_available():
from .modeling_sd_base import (
DDPOPipelineOutput,
DDPOSchedulerOutput,
DDPOStableDiffusionPipeline,
DefaultDDPOStableDiffusionPipeline,
)
try:
if not is_diffusers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_sd_base import (
DDPOPipelineOutput,
DDPOSchedulerOutput,
DDPOStableDiffusionPipeline,
DefaultDDPOStableDiffusionPipeline,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@ -15,6 +15,7 @@ import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
from ..import_utils import is_npu_available, is_xpu_available
from .modeling_base import PreTrainedModelWrapper
@ -245,7 +246,13 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
)
first_device = list(set(self.pretrained_model.hf_device_map.values()))[0]
if isinstance(first_device, int):
if is_npu_available():
first_device = f"npu:{first_device}"
elif is_xpu_available():
first_device = f"xpu:{first_device}"
else:
first_device = f"cuda:{first_device}"
self.v_head = self.v_head.to(first_device)
def set_device_hook(module, input, outputs):

View File

@ -1,8 +1,29 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Literal, Optional, Tuple
from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union
from accelerate.utils import is_deepspeed_available
from transformers import PreTrainedModel, PreTrainedTokenizer
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
SUPPORTED_ARCHITECTURES = (
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
)
if is_deepspeed_available():
import deepspeed
if TYPE_CHECKING:
from accelerate import Accelerator
from deepspeed.runtime.engine import DeepSpeedEngine
from torch.nn.parallel.distributed import DistributedDataParallel
from .modeling_base import PreTrainedModelWrapper
# TODO: Add Abstract Base Class if more formats are added
@dataclass
@ -76,10 +97,60 @@ def setup_chat_format(
model.resize_token_embeddings(
len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None
)
# Make sure to update the generation config to use the new eos & bos token
# Update the model config to use the new eos & bos tokens
if getattr(model, "config", None) is not None:
model.config.pad_token_id = tokenizer.pad_token_id
model.config.bos_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
# Update the generation config to use the new eos & bos token
if getattr(model, "generation_config", None) is not None:
model.generation_config.bos_token_id = tokenizer.bos_token_id
model.generation_config.eos_token_id = tokenizer.eos_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id
return model, tokenizer
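
The newly propagated ids matter once the model is saved and reloaded for generation; typical usage of the helper is unchanged (a sketch with a small stand-in checkpoint):

# Sketch: apply the ChatML chat format and check that the token ids were propagated.
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import setup_chat_format

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

model, tokenizer = setup_chat_format(model, tokenizer)

# Both the model config and the generation config now carry the new special-token ids.
assert model.config.eos_token_id == tokenizer.eos_token_id
assert model.generation_config.pad_token_id == tokenizer.pad_token_id
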
def remove_hooks(model: "DeepSpeedEngine") -> None:
"""Removes the optimizer hooks from a DeepSpeed ZeRO-3 model."""
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
optimizer_offload = model.optimizer.parameter_offload
elif model.optimizer is not None:
optimizer_offload = model.optimizer
for hook in optimizer_offload.forward_hooks:
hook.remove()
for hook in optimizer_offload.backward_hooks:
hook.remove()
optimizer_offload.forward_hooks = []
optimizer_offload.backward_hooks = []
def add_hooks(model: "DeepSpeedEngine") -> None:
"""Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
optimizer_offload = model.optimizer.parameter_offload
elif model.optimizer is not None:
optimizer_offload = model.optimizer
optimizer_offload._register_hooks_recursively(optimizer_offload.module)
@contextmanager
def unwrap_model_for_generation(
model: Union["DistributedDataParallel", "DeepSpeedEngine"], accelerator: "Accelerator", is_peft_model: bool = False
) -> Union["PreTrainedModelWrapper", "DeepSpeedEngine"]:
"""Context manager to unwrap a model for generation.
For ZeRO-3 models, we gather the weights once to speed up generation.
"""
unwrapped_model = accelerator.unwrap_model(model)
if is_peft_model:
unwrapped_model.pretrained_model.disable_adapter()
if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
with deepspeed.zero.GatheredParameters(model.parameters()):
remove_hooks(model)
yield model
add_hooks(model)
else:
yield unwrapped_model
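
A hedged sketch of how a training loop is expected to call the context manager; single-process here, so it only unwraps, while under DeepSpeed ZeRO-3 the same call also gathers the sharded weights once before generating:

# Sketch: generate from the unwrapped model inside the context manager.
from accelerate import Accelerator
from transformers import AutoTokenizer

from trl import AutoModelForCausalLMWithValueHead
from trl.models.utils import unwrap_model_for_generation

accelerator = Accelerator()
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = accelerator.prepare(AutoModelForCausalLMWithValueHead.from_pretrained("gpt2"))

query = tokenizer("The quick brown fox", return_tensors="pt").input_ids.to(accelerator.device)

with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
    response = unwrapped_model.generate(query, max_new_tokens=8, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(response[0]))
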

View File

@ -15,32 +15,92 @@
# limitations under the License.
# There is a circular import in the PPOTrainer if we let isort sort these
# isort: off
from .utils import (
AdaptiveKLController,
FixedKLController,
ConstantLengthDataset,
DataCollatorForCompletionOnlyLM,
RunningMoments,
disable_dropout_in_model,
peft_module_casting_to_bf16,
)
from typing import TYPE_CHECKING
from ..import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable
# isort: on
_import_structure = {
"utils": [
"AdaptiveKLController",
"FixedKLController",
"ConstantLengthDataset",
"DataCollatorForCompletionOnlyLM",
"RunningMoments",
"disable_dropout_in_model",
"peft_module_casting_to_bf16",
"RichProgressCallback",
],
"dpo_trainer": [
"DPOTrainer",
],
"cpo_config": ["CPOConfig"],
"cpo_trainer": ["CPOTrainer"],
"iterative_sft_trainer": [
"IterativeSFTTrainer",
],
"kto_config": ["KTOConfig"],
"kto_trainer": ["KTOTrainer"],
"model_config": ["ModelConfig"],
"orpo_config": ["ORPOConfig"],
"orpo_trainer": ["ORPOTrainer"],
"ppo_config": ["PPOConfig"],
"ppo_trainer": ["PPOTrainer"],
"reward_config": ["RewardConfig"],
"reward_trainer": ["RewardTrainer", "compute_accuracy"],
"sft_trainer": ["SFTTrainer"],
"base": ["BaseTrainer"],
"ddpo_config": ["DDPOConfig"],
}
from ..import_utils import is_diffusers_available
from .base import BaseTrainer
from .ddpo_config import DDPOConfig
try:
if not is_diffusers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["ddpo_trainer"] = ["DDPOTrainer"]
if is_diffusers_available():
from .ddpo_trainer import DDPOTrainer
if TYPE_CHECKING:
# isort: off
from .utils import (
AdaptiveKLController,
FixedKLController,
ConstantLengthDataset,
DataCollatorForCompletionOnlyLM,
RunningMoments,
disable_dropout_in_model,
peft_module_casting_to_bf16,
RichProgressCallback,
)
from .dpo_trainer import DPOTrainer
from .iterative_sft_trainer import IterativeSFTTrainer
from .model_config import ModelConfig
from .ppo_config import PPOConfig
from .ppo_trainer import PPOTrainer
from .reward_config import RewardConfig
from .reward_trainer import RewardTrainer, compute_accuracy
from .sft_trainer import SFTTrainer
# isort: on
from .base import BaseTrainer
from .ddpo_config import DDPOConfig
from .dpo_trainer import DPOTrainer
from .iterative_sft_trainer import IterativeSFTTrainer
from .cpo_config import CPOConfig
from .cpo_trainer import CPOTrainer
from .kto_config import KTOConfig
from .kto_trainer import KTOTrainer
from .model_config import ModelConfig
from .orpo_config import ORPOConfig
from .orpo_trainer import ORPOTrainer
from .ppo_config import PPOConfig
from .ppo_trainer import PPOTrainer
from .reward_config import RewardConfig
from .reward_trainer import RewardTrainer, compute_accuracy
from .sft_trainer import SFTTrainer
try:
if not is_diffusers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .ddpo_trainer import DDPOTrainer
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

78
trl/trainer/cpo_config.py Normal file
View File

@ -0,0 +1,78 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, Literal, Optional
from transformers import TrainingArguments
@dataclass
class CPOConfig(TrainingArguments):
r"""
CPOConfig collects all training arguments related to the [`CPOTrainer`] class.
Using [`HfArgumentParser`] we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.
Parameters:
max_length (`int`, defaults to `None`):
The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
max_prompt_length (`int`, defaults to `None`):
The maximum length of the prompt. This argument is required if you want to use the default data collator.
max_target_length (`int`, defaults to `None`):
The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder.
beta (`float`, defaults to 0.1):
The beta factor in CPO loss.
label_smoothing (`float`, defaults to 0):
The label smoothing factor. This argument is required if you want to use the default data collator.
loss_type (`str`, defaults to `sigmoid`):
The type of loss to use. This argument is required if you want to use the default data collator.
label_pad_token_id (`int`, defaults to `-100`):
The label pad token id. This argument is required if you want to use the default data collator.
padding_value (`int`, defaults to `None`):
The padding value if it is different to the tokenizer's pad_token_id.
truncation_mode (`str`, defaults to `keep_end`):
The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
generate_during_eval (`bool`, defaults to `False`):
Whether to sample and log generations during evaluation step.
is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
If no model is provided, we need to know if the model_init returns an encoder-decoder.
disable_dropout (`bool`, defaults to `True`):
Whether or not to disable dropouts in `model`.
model_init_kwargs (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when instantiating the model from a string
dataset_num_proc (`Optional[int]`, *optional*):
The number of workers to use to tokenize the data. Defaults to None.
"""
max_length: Optional[int] = None
max_prompt_length: Optional[int] = None
max_completion_length: Optional[int] = None
max_target_length: Optional[int] = None
beta: float = 0.1
label_smoothing: float = 0
loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
disable_dropout: bool = True
label_pad_token_id: int = -100
padding_value: int = None
truncation_mode: str = "keep_end"
generate_during_eval: bool = False
is_encoder_decoder: Optional[bool] = None
model_init_kwargs: Optional[Dict] = None
dataset_num_proc: Optional[int] = None
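
For orientation before the trainer definition that follows, a hedged sketch of how these fields are meant to be consumed (tiny model and a toy `prompt`/`chosen`/`rejected` dataset; illustrative only, not taken from the diff):

# Sketch: a minimal CPO run wiring CPOConfig into CPOTrainer.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import CPOConfig, CPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

train_dataset = Dataset.from_dict(
    {
        "prompt": ["How do I boil an egg?", "Name a prime number."],
        "chosen": ["Simmer it for about seven minutes.", "Seven."],
        "rejected": ["No idea.", "Eight."],
    }
)

training_args = CPOConfig(
    output_dir="cpo-demo",
    max_length=64,
    max_prompt_length=32,
    beta=0.1,
    per_device_train_batch_size=2,
    max_steps=2,
    report_to="none",
)

trainer = CPOTrainer(model=model, args=training_args, train_dataset=train_dataset, tokenizer=tokenizer)
trainer.train()
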

929
trl/trainer/cpo_trainer.py Normal file
View File

@ -0,0 +1,929 @@
# CPO Authors: Haoran Xu, Amr Sharaf, Yunmo Chen, Weiting Tan, Lingfeng Shen, Benjamin Van Durme, Kenton Murray, Young Jin Kim
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import random
import warnings
from collections import defaultdict
from contextlib import nullcontext
from functools import wraps
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import PartialState
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_torch_fx_proxy
from ..import_utils import is_peft_available, is_wandb_available
from .cpo_config import CPOConfig
from .utils import (
DPODataCollatorWithPadding,
disable_dropout_in_model,
pad_to_length,
peft_module_casting_to_bf16,
trl_sanitze_kwargs_for_tagging,
)
if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
if is_wandb_available():
import wandb
class CPOTrainer(Trainer):
r"""
Initialize CPOTrainer.
Args:
model (`transformers.PreTrainedModel`):
The model to train, preferably an `AutoModelForCausalLM`.
args (`CPOConfig`):
The CPO config arguments to use for training.
data_collator (`transformers.DataCollator`):
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
train_dataset (`datasets.Dataset`):
The dataset to use for training.
eval_dataset (`datasets.Dataset`):
The dataset to use for evaluation.
tokenizer (`transformers.PreTrainedTokenizerBase`):
The tokenizer to use for training. This argument is required if you want to use the default data collator.
model_init (`Callable[[], transformers.PreTrainedModel]`):
The model initializer to use for training. If None is specified, the default model initializer will be used.
callbacks (`List[transformers.TrainerCallback]`):
The callbacks to use for training.
optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
The optimizer and scheduler to use for training.
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
The function to use to preprocess the logits before computing the metrics.
peft_config (`Dict`, defaults to `None`):
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
The function to use to compute the metrics. Must take a `EvalPrediction` and return
a dictionary string to metric values.
"""
_tag_names = ["trl", "cpo"]
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
args: Optional[CPOConfig] = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional[Dict] = None,
compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
):
if args.model_init_kwargs is None:
model_init_kwargs = {}
elif not isinstance(model, str):
raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.")
else:
model_init_kwargs = args.model_init_kwargs
model_init_kwargs["torch_dtype"] = (
model_init_kwargs["torch_dtype"]
if model_init_kwargs["torch_dtype"] in ["auto", None]
else getattr(torch, model_init_kwargs["torch_dtype"])
)
if isinstance(model, str):
warnings.warn(
"You passed a model_id to the CPOTrainer. This will automatically create an "
"`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
)
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
# has been called in order to properly call autocast if needed.
self._peft_has_been_casted_to_bf16 = False
if not is_peft_available() and peft_config is not None:
raise ValueError(
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
)
elif is_peft_available() and peft_config is not None:
# if model is a peft model and we have a peft_config, we merge and unload it first
if isinstance(model, PeftModel):
model = model.merge_and_unload()
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
_support_gc_kwargs = hasattr(
args, "gradient_checkpointing_kwargs"
) and "gradient_checkpointing_kwargs" in list(
inspect.signature(prepare_model_for_kbit_training).parameters
)
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
if _support_gc_kwargs:
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
elif getattr(args, "gradient_checkpointing", False):
# For backward compatibility with older versions of transformers
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
# get peft model with the given config
model = get_peft_model(model, peft_config)
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
peft_module_casting_to_bf16(model)
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
self._peft_has_been_casted_to_bf16 = True
# For models that use gradient_checkpointing, we need to attach a hook that enables input
# to explicitly have `requires_grad=True`, otherwise training will either silently
# fail or completely fail.
elif getattr(args, "gradient_checkpointing", False):
# For backward compatibility with older versions of transformers
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
if args.generate_during_eval and not is_wandb_available():
raise ValueError(
"`generate_during_eval=True` requires Weights and Biases to be installed."
" Please install `wandb` to resolve."
)
if model is not None:
self.is_encoder_decoder = model.config.is_encoder_decoder
elif args.is_encoder_decoder is None:
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
else:
self.is_encoder_decoder = args.is_encoder_decoder
if self.is_encoder_decoder:
self.decoder_start_token_id = model.config.decoder_start_token_id
self.pad_token_id = model.config.pad_token_id
if tokenizer is None:
raise ValueError("tokenizer must be specified to tokenize a CPO dataset.")
if args.max_length is None:
warnings.warn(
"`max_length` is not set in the CPOConfig's init;"
" it will default to `512`, but you should set it yourself in the future.",
UserWarning,
)
max_length = 512
else:
max_length = args.max_length
if args.max_prompt_length is None:
warnings.warn(
"`max_prompt_length` is not set in the CPOConfig's init;"
" it will default to `128`, but you should set it yourself in the future.",
UserWarning,
)
max_prompt_length = 128
else:
max_prompt_length = args.max_prompt_length
if args.max_target_length is None and self.is_encoder_decoder:
warnings.warn(
"When using an encoder-decoder architecture, you should set `max_target_length` in the CPOConfig's init;"
" it will default to `128`, but you should set it yourself in the future.",
UserWarning,
)
max_target_length = 128
else:
max_target_length = args.max_target_length
if data_collator is None:
data_collator = DPODataCollatorWithPadding(
pad_token_id=tokenizer.pad_token_id,
label_pad_token_id=args.label_pad_token_id,
is_encoder_decoder=self.is_encoder_decoder,
)
if args.remove_unused_columns:
args.remove_unused_columns = False
# warn users
warnings.warn(
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments;"
" we have set it for you, but you should do it yourself in the future.",
UserWarning,
)
self.use_dpo_data_collator = True
else:
self.use_dpo_data_collator = False
if args.disable_dropout:
disable_dropout_in_model(model)
self.max_length = max_length
self.generate_during_eval = args.generate_during_eval
self.label_pad_token_id = args.label_pad_token_id
self.padding_value = args.padding_value if args.padding_value is not None else tokenizer.pad_token_id
self.max_prompt_length = max_prompt_length
self.truncation_mode = args.truncation_mode
self.max_target_length = max_target_length
self.tokenizer = tokenizer
if args.loss_type in ["hinge", "ipo", "kto_pair"] and args.label_smoothing > 0:
warnings.warn(
"You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
)
self.beta = args.beta
self.label_smoothing = args.label_smoothing
self.loss_type = args.loss_type
self._stored_metrics = defaultdict(lambda: defaultdict(list))
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
# tokenize the dataset
train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
if eval_dataset is not None:
eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
super().__init__(
model=model,
args=args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
model_init=model_init,
compute_metrics=compute_metrics,
callbacks=callbacks,
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
# Add tags for models that have been loaded with the correct transformers version
if hasattr(self.model, "add_model_tags"):
self.model.add_model_tags(self._tag_names)
if not hasattr(self, "accelerator"):
raise AttributeError(
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
)
def build_tokenized_answer(self, prompt, answer):
"""
Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
Reference:
https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
"""
full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
# Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
# Prepare input tokens for token by token comparison
full_input_ids = np.array(full_tokenized["input_ids"])
if len(full_input_ids) != len(full_concat_input_ids):
raise ValueError("Prompt input ids and answer input ids should have the same length.")
# On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
# can be merged together when tokenizing prompt+answer. This could result
# on the last token from the prompt being different when tokenized on its own
# vs when done as prompt+answer.
response_token_ids_start_idx = len(prompt_input_ids)
# If tokenized prompt is different than both prompt+answer, then it means the
# last token has changed due to merging.
if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
response_token_ids_start_idx -= 1
prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
if len(prompt_input_ids) != len(prompt_attention_mask):
raise ValueError("Prompt input ids and attention mask should have the same length.")
answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
return dict(
prompt_input_ids=prompt_input_ids,
prompt_attention_mask=prompt_attention_mask,
input_ids=answer_input_ids,
attention_mask=answer_attention_mask,
)
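A minimal, standalone sketch of the boundary check that `build_tokenized_answer` performs; the checkpoint name and strings below are illustrative only, and whether merging actually occurs depends on the tokenizer:

# Illustrative only, not part of the trainer.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
prompt, answer = "The capital of France is", " Paris."
joint_ids = tok(prompt + answer, add_special_tokens=False)["input_ids"]
prompt_ids = tok(prompt, add_special_tokens=False)["input_ids"]
# The prompt encoded on its own may differ from the prefix of the joint encoding
# at its last token (token merging); build_tokenized_answer moves the boundary
# back by one in that case so that prompt ids + answer ids reproduce joint_ids exactly.
if joint_ids[: len(prompt_ids)] != prompt_ids:
    print("last prompt token was merged with the answer")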
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
"""Tokenize a single row from a CPO specific dataset.
At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
in case the prompt + chosen or prompt + rejected responses is/are too long. First
we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
We also create the labels for the chosen/rejected responses, which are of length equal to
the sum of the length of the prompt and the chosen/rejected response, with
label_pad_token_id for the prompt tokens.
"""
batch = {}
prompt = feature["prompt"]
chosen = feature["chosen"]
rejected = feature["rejected"]
if not self.is_encoder_decoder:
# Check issues below for more details
# 1. https://github.com/huggingface/trl/issues/907
# 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
# 3. https://github.com/LianjiaTech/BELLE/issues/337
if not isinstance(prompt, str):
raise ValueError(f"prompt should be a str but got {type(prompt)}")
prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
if not isinstance(chosen, str):
raise ValueError(f"chosen should be a str but got {type(chosen)}")
chosen_tokens = self.build_tokenized_answer(prompt, chosen)
if not isinstance(rejected, str):
raise ValueError(f"rejected should be a str but got {type(rejected)}")
rejected_tokens = self.build_tokenized_answer(prompt, rejected)
# Last prompt token might get merged by tokenizer and
# it should not be included for generation if that happens
prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
for k, v in prompt_tokens.items():
prompt_tokens[k] = v[:prompt_len_input_ids]
# Make sure prompts only have one different token at most,
# and that their lengths only differ by 1 at most
num_diff_tokens = sum(
[a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
)
num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
if num_diff_tokens > 1 or num_diff_len > 1:
raise ValueError(
"Chosen and rejected prompt_input_ids might only differ on the "
"last token due to tokenizer merge ops."
)
# add BOS token to head of prompt
prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]
prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
# add EOS token to end of answer
chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
chosen_tokens["attention_mask"].append(1)
rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
rejected_tokens["attention_mask"].append(1)
longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
# if combined sequence is too long, truncate the prompt
for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
if self.truncation_mode == "keep_start":
for k in ["prompt_input_ids", "prompt_attention_mask"]:
answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
elif self.truncation_mode == "keep_end":
for k in ["prompt_input_ids", "prompt_attention_mask"]:
answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
else:
raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
# if that's still too long, truncate the response
for answer_tokens in [chosen_tokens, rejected_tokens]:
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
for k in ["input_ids", "attention_mask"]:
answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
# Create labels
chosen_sequence_tokens = {
k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
}
rejected_sequence_tokens = {
k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
}
chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
self.label_pad_token_id
] * len(chosen_tokens["prompt_input_ids"])
rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
self.label_pad_token_id
] * len(rejected_tokens["prompt_input_ids"])
for k, toks in {
"chosen_": chosen_sequence_tokens,
"rejected_": rejected_sequence_tokens,
"": prompt_tokens,
}.items():
for type_key, tokens in toks.items():
if type_key == "token_type_ids":
continue
batch[f"{k}{type_key}"] = tokens
else:
chosen_tokens = self.tokenizer(
chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True
)
rejected_tokens = self.tokenizer(
rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True
)
prompt_tokens = self.tokenizer(
prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
)
batch["chosen_labels"] = chosen_tokens["input_ids"]
batch["rejected_labels"] = rejected_tokens["input_ids"]
batch["prompt_input_ids"] = prompt_tokens["input_ids"]
batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
labels=torch.tensor(batch["rejected_labels"])
)
batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
labels=torch.tensor(batch["chosen_labels"])
)
return batch
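As a toy restatement of the truncation order described in the docstring above (hypothetical lengths and token ids; the real method operates on tokenizer output):

max_length, max_prompt_length, label_pad_token_id = 10, 6, -100
prompt_ids = list(range(8))                # 8 prompt tokens
response_ids = [100, 101, 102, 103, 104]   # 5 response tokens
if len(prompt_ids) + len(response_ids) > max_length:
    prompt_ids = prompt_ids[-max_prompt_length:]                   # "keep_end" prompt truncation first
if len(prompt_ids) + len(response_ids) > max_length:
    response_ids = response_ids[: max_length - max_prompt_length]  # then truncate the response
labels = [label_pad_token_id] * len(prompt_ids) + response_ids     # prompt positions are masked out
assert len(prompt_ids) + len(response_ids) <= max_length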
@staticmethod
def concatenated_inputs(
batch: Dict[str, Union[List, torch.LongTensor]],
is_encoder_decoder: bool = False,
label_pad_token_id: int = -100,
padding_value: int = 0,
device: Optional[torch.device] = None,
) -> Dict[str, torch.LongTensor]:
"""Concatenate the chosen and rejected inputs into a single tensor.
Args:
batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
is_encoder_decoder: Whether the model is an encoder-decoder model.
label_pad_token_id: The label pad token id.
padding_value: The padding value to use for the concatenated inputs_ids.
device: The device for the concatenated inputs.
Returns:
A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
"""
concatenated_batch = {}
if is_encoder_decoder:
max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
else:
max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
for k in batch:
if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
if "labels" in k or is_encoder_decoder:
pad_value = label_pad_token_id
elif k.endswith("_input_ids"):
pad_value = padding_value
elif k.endswith("_attention_mask"):
pad_value = 0
concatenated_key = k.replace("chosen", "concatenated")
concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
for k in batch:
if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
if "labels" in k or is_encoder_decoder:
pad_value = label_pad_token_id
elif k.endswith("_input_ids"):
pad_value = padding_value
elif k.endswith("_attention_mask"):
pad_value = 0
concatenated_key = k.replace("rejected", "concatenated")
concatenated_batch[concatenated_key] = torch.cat(
(
concatenated_batch[concatenated_key],
pad_to_length(batch[k], max_length, pad_value=pad_value),
),
dim=0,
).to(device=device)
if is_encoder_decoder:
concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
concatenated_batch["concatenated_attention_mask"] = (
batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
)
return concatenated_batch
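A toy call to the static method above, assuming `CPOTrainer` is importable (e.g. `from trl import CPOTrainer`); the ids and shapes are made up:

import torch
from trl import CPOTrainer

batch = {
    "chosen_input_ids": torch.tensor([[5, 6, 7]]),
    "chosen_attention_mask": torch.tensor([[1, 1, 1]]),
    "chosen_labels": torch.tensor([[-100, 6, 7]]),
    "rejected_input_ids": torch.tensor([[5, 8]]),
    "rejected_attention_mask": torch.tensor([[1, 1]]),
    "rejected_labels": torch.tensor([[-100, 8]]),
}
out = CPOTrainer.concatenated_inputs(batch)
# Each rejected tensor is padded to the longer chosen length (3) and stacked
# below its chosen counterpart along the batch dimension.
print(out["concatenated_input_ids"].shape)  # torch.Size([2, 3])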
def cpo_loss(
self,
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Compute the CPO loss for a batch of policy and reference model log probabilities.
Args:
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
Returns:
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
The losses tensor contains the CPO loss for each example in the batch.
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
"""
logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)
# The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
# We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
# calculates a conservative CPO loss.
if self.loss_type == "sigmoid":
# This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
losses = (
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
elif self.loss_type == "hinge":
losses = torch.relu(1 - self.beta * logits)
elif self.loss_type == "ipo":
# eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
losses = (logits - 1 / (2 * self.beta)) ** 2
else:
raise ValueError(
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']"
)
chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
return losses, chosen_rewards, rejected_rewards
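A quick numeric illustration of the `sigmoid` branch above, with toy log-probabilities rather than real trainer output:

import torch
import torch.nn.functional as F

beta, label_smoothing = 0.1, 0.0
policy_chosen_logps = torch.tensor([-12.0])
policy_rejected_logps = torch.tensor([-15.0])
logits = policy_chosen_logps - policy_rejected_logps  # margin of 3.0
losses = (
    -F.logsigmoid(beta * logits) * (1 - label_smoothing)
    - F.logsigmoid(-beta * logits) * label_smoothing
)
print(losses)  # tensor([0.5544]) == -log(sigmoid(0.1 * 3.0))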
@staticmethod
def get_batch_logps(
logits: torch.FloatTensor,
labels: torch.LongTensor,
average_log_prob: bool = False,
label_pad_token_id: int = -100,
is_encoder_decoder: bool = False,
) -> torch.FloatTensor:
"""Compute the log probabilities of the given labels under the given logits.
Args:
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
label_pad_token_id: The label pad token id.
is_encoder_decoder: Whether the model is an encoder-decoder model.
Returns:
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
"""
if logits.shape[:-1] != labels.shape:
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
if not is_encoder_decoder:
labels = labels[:, 1:].clone()
logits = logits[:, :-1, :]
loss_mask = labels != label_pad_token_id
# dummy token; we'll ignore the losses on these tokens later
labels[labels == label_pad_token_id] = 0
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
if average_log_prob:
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
else:
return (per_token_logps * loss_mask).sum(-1)
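A toy call to `get_batch_logps`, again assuming `CPOTrainer` is importable; a vocabulary of 4 tokens and uniform logits keep the arithmetic checkable by hand:

import torch
from trl import CPOTrainer

logits = torch.zeros(1, 3, 4)           # uniform distribution over 4 tokens at each position
labels = torch.tensor([[-100, 2, 3]])   # the first position is a masked prompt token
logps = CPOTrainer.get_batch_logps(logits, labels)
# The causal shift drops the first label; the two remaining tokens each have
# log(1/4) under uniform logits, so the summed log-probability is 2 * log(0.25).
print(logps)  # tensor([-2.7726])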
def concatenated_forward(
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
We do this to avoid doing two forward passes, because it's faster for FSDP.
"""
concatenated_batch = self.concatenated_inputs(
batch,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
padding_value=self.padding_value,
device=self.accelerator.device,
)
len_chosen = batch["chosen_labels"].shape[0]
model_kwargs = (
{
"decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
}
if self.is_encoder_decoder
else {}
)
outputs = model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
**model_kwargs,
)
all_logits = outputs.logits
def cross_entropy_loss(logits, labels):
if not self.is_encoder_decoder:
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
return loss
labels = concatenated_batch["concatenated_labels"].clone()
nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
all_logps = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
average_log_prob=self.loss_type == "ipo",
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
def get_batch_loss_metrics(
self,
model,
batch: Dict[str, Union[List, torch.LongTensor]],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = self.concatenated_forward(model, batch)
losses, chosen_rewards, rejected_rewards = self.cpo_loss(
policy_chosen_logps,
policy_rejected_logps,
)
loss = losses.mean() + policy_nll_loss
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
return loss, metrics
def compute_loss(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
if not self.use_dpo_data_collator:
warnings.warn(
"compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)
compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
with compute_loss_context_manager():
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
# force log the metrics
self.store_metrics(metrics, train_eval="train")
if return_outputs:
return (loss, metrics)
return loss
def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
"""Generate samples from the model and reference model for the given batch of inputs."""
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast
with generate_context_manager():
policy_output = model.generate(
input_ids=batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
max_length=self.max_length,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
)
policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id)
policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)
return policy_output_decoded
def prediction_step(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
):
if not self.use_dpo_data_collator:
warnings.warn(
"prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)
if ignore_keys is None:
if hasattr(model, "config"):
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
else:
ignore_keys = []
prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
with torch.no_grad(), prediction_context_manager():
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
# force log the metrics
self.store_metrics(metrics, train_eval="eval")
if prediction_loss_only:
return (loss.detach(), None, None)
# logits for the chosen and rejected samples from model
logits_dict = {
"eval_logits/chosen": metrics["eval_logits/chosen"],
"eval_logits/rejected": metrics["eval_logits/rejected"],
}
logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
return (loss.detach(), logits, labels)
def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def evaluation_loop(
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
"""
Overriding built-in evaluation loop to store metrics for each batch.
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
Works both with or without labels.
"""
# Sample and save to game log if requested (for one batch to save time)
if self.generate_during_eval:
# Generate random indices within the range of the total number of samples
num_samples = len(dataloader.dataset)
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
random_batch_dataset = dataloader.dataset.select(random_indices)
random_batch = self.data_collator(random_batch_dataset)
random_batch = self._prepare_inputs(random_batch)
policy_output_decoded = self.get_batch_samples(self.model, random_batch)
self.log(
{
"game_log": wandb.Table(
columns=["Prompt", "Policy"],
rows=[
[prompt, pol[len(prompt) :]]
for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
],
)
}
)
self.state.log_history.pop()
# Base evaluation
initial_output = super().evaluation_loop(
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
)
return initial_output
def log(self, logs: Dict[str, float]) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
Args:
logs (`Dict[str, float]`):
The values to log.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs)
def _shift_right(self, input_ids):
if self.decoder_start_token_id is None:
raise ValueError(
"model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
)
# shift inputs to the right
if is_torch_fx_proxy(input_ids):
# Item assignment is not supported natively for proxies.
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
else:
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = self.decoder_start_token_id
if self.pad_token_id is None:
raise ValueError("model.config.pad_token_id has to be defined.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
return shifted_input_ids
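A self-contained restatement of the shift above with made-up ids (`decoder_start_token_id = 0`, `pad_token_id = 1`):

import torch

labels = torch.tensor([[10, -100, 11, 12]])
shifted = labels.new_zeros(labels.shape)
shifted[..., 1:] = labels[..., :-1].clone()
shifted[..., 0] = 0                        # decoder_start_token_id
shifted.masked_fill_(shifted == -100, 1)   # label padding becomes pad_token_id
print(shifted)  # tensor([[ 0, 10,  1, 11]])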
@wraps(Trainer.push_to_hub)
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tag "cpo" when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
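A minimal end-to-end usage sketch of the trainer above; the checkpoint name and dataset contents are placeholders, and the dataset only needs `prompt`, `chosen` and `rejected` string columns:

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import CPOConfig, CPOTrainer

model_name = "gpt2"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

train_dataset = Dataset.from_dict(
    {"prompt": ["What is 2+2?"], "chosen": [" 4"], "rejected": [" 5"]}
)
args = CPOConfig(
    output_dir="cpo-sketch",
    max_length=64,
    max_prompt_length=32,
    per_device_train_batch_size=1,
    report_to="none",
)
trainer = CPOTrainer(model=model, args=args, train_dataset=train_dataset, tokenizer=tokenizer)
trainer.train()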


@@ -121,7 +121,7 @@ class DPOTrainer(Trainer):
The function to use to compute the metrics. Must take a `EvalPrediction` and return
a dictionary string to metric values.
precompute_ref_log_probs (`bool`, defaults to `False`):
Flag to precompute reference model log probabilities and evaluation datasets. This is useful if you want to train
Flag to precompute reference model log probabilities for training and evaluation datasets. This is useful if you want to train
without the reference model and reduce the total GPU memory needed.
dataset_num_proc (`Optional[int]`, *optional*):
The number of workers to use to tokenize the data. Defaults to None.
@@ -135,6 +135,8 @@ class DPOTrainer(Trainer):
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
reference_free (`bool`):
If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.
force_use_ref_model (`bool`, defaults to `False`):
In case one passes a PEFT model for the active model and you want to use a different model for the ref_model, set this flag to `True`.
"""
_tag_names = ["trl", "dpo"]
@@ -173,6 +175,7 @@ class DPOTrainer(Trainer):
model_adapter_name: Optional[str] = None,
ref_adapter_name: Optional[str] = None,
reference_free: bool = False,
force_use_ref_model: bool = False,
):
if model_init_kwargs is None:
model_init_kwargs = {}
@@ -213,10 +216,11 @@ class DPOTrainer(Trainer):
if isinstance(model, PeftModel):
model = model.merge_and_unload()
if ref_model is not None:
if ref_model is not None and not force_use_ref_model:
raise ValueError(
"You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference"
" model. Please pass `ref_model=None` in case you want to train PEFT adapters."
" model. Please pass `ref_model=None` in case you want to train PEFT adapters, or pass a ref_model with `force_use_ref_model=True` in DPOTrainer's init."
" if you want to use a different ref_model."
)
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
@@ -226,12 +230,12 @@ class DPOTrainer(Trainer):
inspect.signature(prepare_model_for_kbit_training).parameters
)
preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
if _support_gc_kwargs:
preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
elif getattr(args, "gradient_checkpointing", False):
# For backward compatibility with older versions of transformers
if hasattr(model, "enable_input_require_grads"):
@@ -250,7 +254,7 @@ class DPOTrainer(Trainer):
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
self._peft_has_been_casted_to_bf16 = True
# For models that use gradient_checkpoiting, we need to attach a hook that enables input
# For models that use gradient_checkpointing, we need to attach a hook that enables input
# to explicitly have `requires_grad=True`, otherwise training will either silently
# fail or completely fail.
elif getattr(args, "gradient_checkpointing", False):
@@ -731,10 +735,10 @@ class DPOTrainer(Trainer):
if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
labels=batch["rejected_labels"]
labels=torch.tensor(batch["rejected_labels"])
)
batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
labels=batch["chosen_labels"]
labels=torch.tensor(batch["chosen_labels"])
)
return batch
@@ -1076,6 +1080,8 @@ class DPOTrainer(Trainer):
with compute_loss_context_manager():
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
loss = loss.to(self.args.device)
# force log the metrics
self.store_metrics(metrics, train_eval="train")
@@ -1086,7 +1092,7 @@ class DPOTrainer(Trainer):
def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
"""Generate samples from the model and reference model for the given batch of inputs."""
# If one uses `generate_during_eval` with peft + bf16, we need to explictly call generate with
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast
@@ -1242,7 +1248,7 @@ class DPOTrainer(Trainer):
@wraps(Trainer.push_to_hub)
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
Overwrite the `push_to_hub` method in order to force-add the tag "dpo" when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
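A hedged sketch of the `force_use_ref_model` flag added in this diff: a LoRA policy together with a deliberately different reference model (all names and values are placeholders):

from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")            # placeholder policy model
ref_model = AutoModelForCausalLM.from_pretrained("distilgpt2")  # different reference model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

train_dataset = Dataset.from_dict(
    {"prompt": ["What is 2+2?"], "chosen": [" 4"], "rejected": [" 5"]}
)
trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,              # normally rejected together with peft_config ...
    force_use_ref_model=True,         # ... unless this new flag is set
    peft_config=LoraConfig(task_type="CAUSAL_LM"),
    args=TrainingArguments(output_dir="dpo-sketch", per_device_train_batch_size=1, report_to="none"),
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)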

trl/trainer/kto_config.py Normal file

@@ -0,0 +1,84 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, Optional
from transformers import TrainingArguments
@dataclass
class KTOConfig(TrainingArguments):
r"""
KTOConfig collects all training arguments related to the [`KTOTrainer`] class.
Using [`HfArgumentParser`] we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.
Parameters:
max_length (`int`, *optional*, defaults to `None`):
The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
max_prompt_length (`int`, *optional*, defaults to `None`):
The maximum length of the prompt. This argument is required if you want to use the default data collator.
max_completion_length (`int`, *optional*, defaults to `None`):
The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder.
beta (`float`, defaults to 0.1):
The beta factor in KTO loss. Higher beta means less divergence from the initial policy.
desirable_weight (`float`, *optional*, defaults to 1.0):
The desirable losses are weighted by this factor to counter an unequal number of desirable and undesirable pairs.
undesirable_weight (`float`, *optional*, defaults to 1.0):
The undesirable losses are weighted by this factor to counter an unequal number of desirable and undesirable pairs.
label_pad_token_id (`int`, defaults to `-100`):
The label pad token id. This argument is required if you want to use the default data collator.
padding_value (`int`, defaults to `0`):
The padding value if it is different to the tokenizer's pad_token_id.
truncation_mode (`str`, defaults to `keep_end`):
The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
generate_during_eval (`bool`, defaults to `False`):
Whether to sample and log generations during evaluation step.
is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
If no model is provided, we need to know if the model_init returns an encoder-decoder.
precompute_ref_log_probs (`bool`, defaults to `False`):
Flag to precompute reference model log probabilities for training and evaluation datasets. This is useful if you want to train
without the reference model and reduce the total GPU memory needed.
model_init_kwargs: (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when instantiating the model from a string.
ref_model_init_kwargs: (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when instantiating the ref model from a string.
dataset_num_proc: (`Optional[int]`, *optional*, defaults to `None`):
Number of processes to use for processing the datasets.
"""
max_length: Optional[int] = None
"""The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator."""
max_prompt_length: Optional[int] = None
"""The maximum length of the prompt. This argument is required if you want to use the default data collator."""
max_completion_length: Optional[int] = None
"""The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder."""
beta: float = 0.1
"""The beta factor in KTO loss. Higher beta means less divergence from the initial policy."""
desirable_weight: Optional[float] = 1.0
"""The desirable losses are weighed by this factor."""
undesirable_weight: Optional[float] = 1.0
"""The undesirable losses are weighed by this factor."""
label_pad_token_id: int = -100
padding_value: Optional[int] = None
truncation_mode: str = "keep_end"
generate_during_eval: bool = False
is_encoder_decoder: Optional[bool] = None
precompute_ref_log_probs: bool = False
model_init_kwargs: Optional[Dict] = None
ref_model_init_kwargs: Optional[Dict] = None
dataset_num_proc: Optional[int] = None
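A minimal instantiation sketch with illustrative values; following the docstring, the weights are chosen to counter an imbalanced dataset (here, roughly three desirable examples per undesirable one):

from trl import KTOConfig

kto_args = KTOConfig(
    output_dir="kto-sketch",
    max_length=512,
    max_prompt_length=128,
    beta=0.1,
    desirable_weight=1.0,
    undesirable_weight=3.0,  # up-weight the rarer undesirable examples
    per_device_train_batch_size=4,
)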

trl/trainer/kto_trainer.py Normal file

File diff suppressed because it is too large.


@@ -61,6 +61,9 @@ class ModelConfig:
default=None,
metadata={"help": ("Model layers to unfreeze & train")},
)
lora_task_type: str = field(
default="CAUSAL_LM", metadata={"help": "The task_type to pass for LoRA (use SEQ_CLS for reward modeling)"}
)
load_in_8bit: bool = field(
default=False, metadata={"help": "use 8 bit precision for the base model - works only with LoRA"}
)
@@ -82,3 +85,6 @@ class ModelConfig:
def __post_init__(self):
if self.load_in_8bit and self.load_in_4bit:
raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
if self.lora_target_modules == ["all-linear"]:
self.lora_target_modules = "all-linear"
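A sketch of how the new `lora_task_type` field is consumed, assuming the `get_peft_config` helper used by the example scripts; values are illustrative:

from trl import ModelConfig, get_peft_config

model_config = ModelConfig(use_peft=True, lora_task_type="SEQ_CLS", lora_r=16, lora_alpha=32)
peft_config = get_peft_config(model_config)  # a peft.LoraConfig with task_type="SEQ_CLS"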


@@ -0,0 +1,71 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, Optional
from transformers import TrainingArguments
@dataclass
class ORPOConfig(TrainingArguments):
r"""
ORPOConfig collects all training arguments related to the [`ORPOTrainer`] class.
Using [`HfArgumentParser`] we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.
Parameters:
max_length (`int`, defaults to `None`):
The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
max_prompt_length (`int`, defaults to `None`):
The maximum length of the prompt. This argument is required if you want to use the default data collator.
max_completion_length (`int`, defaults to `None`):
The maximum length of the completions. This argument is required if you want to use the default data collator and your model is an encoder-decoder.
beta (`float`, defaults to 0.1):
The beta factor in ORPO loss (lambda/alpha in paper/code) that is the weight of the relative loss ratio in the SFT loss.
label_pad_token_id (`int`, defaults to `-100`):
The label pad token id. This argument is required if you want to use the default data collator.
padding_value (`int`, defaults to `None`):
The padding value if it is different to the tokenizer's pad_token_id.
truncation_mode (`str`, defaults to `keep_end`):
The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
generate_during_eval (`bool`, defaults to `False`):
Whether to sample and log generations during evaluation step.
is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
If no model is provided, we need to know if the model_init returns an encoder-decoder.
disable_dropout (`bool`, defaults to `True`):
Whether or not to disable dropouts in `model`.
model_init_kwargs (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when instantiating the model from a string
dataset_num_proc (`Optional[int]`, *optional*):
The number of workers to use to tokenize the data. Defaults to None.
"""
max_length: Optional[int] = None
max_prompt_length: Optional[int] = None
max_completion_length: Optional[int] = None
beta: float = 0.1
disable_dropout: bool = True
label_pad_token_id: int = -100
padding_value: Optional[int] = None
truncation_mode: str = "keep_end"
generate_during_eval: bool = False
is_encoder_decoder: Optional[bool] = None
model_init_kwargs: Optional[Dict] = None
dataset_num_proc: Optional[int] = None
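A minimal instantiation sketch with illustrative values; `beta` is the lambda/alpha weight on the odds-ratio term described above:

from trl import ORPOConfig

orpo_args = ORPOConfig(
    output_dir="orpo-sketch",
    max_length=1024,
    max_prompt_length=512,
    beta=0.1,  # weight of the relative (odds-ratio) loss added to the SFT loss
    per_device_train_batch_size=4,
)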

trl/trainer/orpo_trainer.py Normal file

@@ -0,0 +1,955 @@
# ORPO Authors: Jiwoo Hong, Noah Lee, and James Thorne
# Official code: https://github.com/xfactlab/orpo
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import random
import warnings
from collections import defaultdict
from contextlib import nullcontext
from copy import deepcopy
from functools import wraps
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_torch_fx_proxy
from ..import_utils import is_peft_available, is_wandb_available
from ..models import PreTrainedModelWrapper
from .orpo_config import ORPOConfig
from .utils import (
DPODataCollatorWithPadding,
disable_dropout_in_model,
pad_to_length,
peft_module_casting_to_bf16,
trl_sanitze_kwargs_for_tagging,
)
if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
if is_wandb_available():
import wandb
if is_deepspeed_available():
import deepspeed
class ORPOTrainer(Trainer):
r"""
Initialize ORPOTrainer.
Args:
model (`transformers.PreTrainedModel`):
The model to train, preferably an `AutoModelForCausalLM`.
args (`ORPOConfig`):
The ORPO config arguments to use for training.
data_collator (`transformers.DataCollator`):
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
train_dataset (`datasets.Dataset`):
The dataset to use for training.
eval_dataset (`datasets.Dataset`):
The dataset to use for evaluation.
tokenizer (`transformers.PreTrainedTokenizerBase`):
The tokenizer to use for training. This argument is required if you want to use the default data collator.
model_init (`Callable[[], transformers.PreTrainedModel]`):
The model initializer to use for training. If None is specified, the default model initializer will be used.
callbacks (`List[transformers.TrainerCallback]`):
The callbacks to use for training.
optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
The optimizer and scheduler to use for training.
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
The function to use to preprocess the logits before computing the metrics.
peft_config (`Dict`, defaults to `None`):
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
The function to use to compute the metrics. Must take an `EvalPrediction` and return
a dictionary mapping strings to metric values.
"""
_tag_names = ["trl", "orpo"]
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
args: Optional[ORPOConfig] = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional[Dict] = None,
compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
):
if args.model_init_kwargs is None:
model_init_kwargs = {}
elif not isinstance(model, str):
raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
else:
model_init_kwargs = args.model_init_kwargs
model_init_kwargs["torch_dtype"] = (
model_init_kwargs["torch_dtype"]
if model_init_kwargs["torch_dtype"] in ["auto", None]
else getattr(torch, model_init_kwargs["torch_dtype"])
)
if isinstance(model, str):
warnings.warn(
"You passed a model_id to the ORPOTrainer. This will automatically create an "
"`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
)
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
# has been called in order to properly call autocast if needed.
self._peft_has_been_casted_to_bf16 = False
if not is_peft_available() and peft_config is not None:
raise ValueError(
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
)
elif is_peft_available() and peft_config is not None:
# if model is a peft model and we have a peft_config, we merge and unload it first
if isinstance(model, PeftModel):
model = model.merge_and_unload()
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
_support_gc_kwargs = hasattr(
args, "gradient_checkpointing_kwargs"
) and "gradient_checkpointing_kwargs" in list(
inspect.signature(prepare_model_for_kbit_training).parameters
)
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
if _support_gc_kwargs:
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
elif getattr(args, "gradient_checkpointing", False):
# For backward compatibility with older versions of transformers
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
# get peft model with the given config
model = get_peft_model(model, peft_config)
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
peft_module_casting_to_bf16(model)
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
self._peft_has_been_casted_to_bf16 = True
# For models that use gradient_checkpointing, we need to attach a hook that enables input
# to explicitly have `requires_grad=True`, otherwise training will either silently
# fail or completely fail.
elif getattr(args, "gradient_checkpointing", False):
# For backward compatibility with older versions of transformers
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
if args.generate_during_eval and not is_wandb_available():
raise ValueError(
"`generate_during_eval=True` requires Weights and Biases to be installed."
" Please install `wandb` to resolve."
)
if model is not None:
self.is_encoder_decoder = model.config.is_encoder_decoder
elif args.is_encoder_decoder is None:
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
else:
self.is_encoder_decoder = args.is_encoder_decoder
if self.is_encoder_decoder:
self.decoder_start_token_id = model.config.decoder_start_token_id
self.pad_token_id = model.config.pad_token_id
if tokenizer is None:
raise ValueError("tokenizer must be specified to tokenize an ORPO dataset.")
if args.max_length is None:
warnings.warn(
"`max_length` is not set in the ORPOConfig's init;"
" it will default to `512`, but you should set it yourself in the future.",
UserWarning,
)
max_length = 512
else:
max_length = args.max_length
if args.max_prompt_length is None:
warnings.warn(
"`max_prompt_length` is not set in the ORPOConfig's init;"
" it will default to `128`, but you should set it yourself in the future.",
UserWarning,
)
max_prompt_length = 128
else:
max_prompt_length = args.max_prompt_length
if args.max_completion_length is None and self.is_encoder_decoder:
warnings.warn(
"When using an encoder-decoder architecture, you should set `max_completion_length` in the ORPOConfig's init;"
" it will default to `128`, but you should set it yourself in the future.",
UserWarning,
)
self.max_completion_length = 128
else:
self.max_completion_length = args.max_completion_length
if data_collator is None:
data_collator = DPODataCollatorWithPadding(
pad_token_id=tokenizer.pad_token_id,
label_pad_token_id=args.label_pad_token_id,
is_encoder_decoder=self.is_encoder_decoder,
)
if args.remove_unused_columns:
args.remove_unused_columns = False
# warn users
warnings.warn(
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments;"
" we have set it for you, but you should do it yourself in the future.",
UserWarning,
)
self.use_dpo_data_collator = True
else:
self.use_dpo_data_collator = False
if args.disable_dropout:
disable_dropout_in_model(model)
self.max_length = max_length
self.generate_during_eval = args.generate_during_eval
self.label_pad_token_id = args.label_pad_token_id
self.padding_value = args.padding_value if args.padding_value is not None else tokenizer.pad_token_id
self.max_prompt_length = max_prompt_length
self.truncation_mode = args.truncation_mode
self.tokenizer = tokenizer
self.beta = args.beta
self._stored_metrics = defaultdict(lambda: defaultdict(list))
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
# tokenize the dataset
train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
if eval_dataset is not None:
eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
super().__init__(
model=model,
args=args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
model_init=model_init,
compute_metrics=compute_metrics,
callbacks=callbacks,
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
# Add tags for models that have been loaded with the correct transformers version
if hasattr(self.model, "add_model_tags"):
self.model.add_model_tags(self._tag_names)
if not hasattr(self, "accelerator"):
raise AttributeError(
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
)
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
if model is not None:
if hasattr(model, "config"):
hidden_size = (
max(model.config.hidden_sizes)
if getattr(model.config, "hidden_sizes", None)
else getattr(model.config, "hidden_size", None)
)
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
config_kwargs.update(
{
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
}
)
# If ZeRO-3 is used, we shard both the active and reference model.
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
if config_kwargs["zero_optimization"]["stage"] != 3:
config_kwargs["zero_optimization"]["stage"] = 0
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
model.eval()
return model
def build_tokenized_answer(self, prompt, answer):
"""
Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
Reference:
https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
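Example (illustrative sketch; `trainer` is a hypothetical instance of this class and the
concrete token ids depend entirely on the tokenizer in use):
    >>> out = trainer.build_tokenized_answer("Question: 2 + 2 =", " 4")
    >>> sorted(out.keys())
    ['attention_mask', 'input_ids', 'prompt_attention_mask', 'prompt_input_ids']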
"""
full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
# Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
# Prepare input tokens for token by token comparison
full_input_ids = np.array(full_tokenized["input_ids"])
if len(full_input_ids) != len(full_concat_input_ids):
raise ValueError("Prompt input ids and answer input ids should have the same length.")
# On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
# can be merged together when tokenizing prompt+answer. This could result
# on the last token from the prompt being different when tokenized on its own
# vs when done as prompt+answer.
response_token_ids_start_idx = len(prompt_input_ids)
# If tokenized prompt is different than both prompt+answer, then it means the
# last token has changed due to merging.
if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
response_token_ids_start_idx -= 1
prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
if len(prompt_input_ids) != len(prompt_attention_mask):
raise ValueError("Prompt input ids and attention mask should have the same length.")
answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
return dict(
prompt_input_ids=prompt_input_ids,
prompt_attention_mask=prompt_attention_mask,
input_ids=answer_input_ids,
attention_mask=answer_attention_mask,
)
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
"""Tokenize a single row from a ORPO specific dataset.
At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
in case the prompt + chosen or prompt + rejected responses is/are too long. First
we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
We also create the labels for the chosen/rejected responses, which are of length equal to
the sum of the length of the prompt and the chosen/rejected response, with
label_pad_token_id for the prompt tokens.
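Example (illustrative sketch; `trainer` is a hypothetical decoder-only instance of this class):
    >>> feature = {"prompt": "Q: 2 + 2 = ", "chosen": "4", "rejected": "5"}
    >>> batch = trainer.tokenize_row(feature)
    >>> # yields e.g. chosen_input_ids / chosen_labels, rejected_input_ids / rejected_labels,
    >>> # prompt_input_ids and the matching attention masks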
"""
batch = {}
prompt = feature["prompt"]
chosen = feature["chosen"]
rejected = feature["rejected"]
if not self.is_encoder_decoder:
# Check issues below for more details
# 1. https://github.com/huggingface/trl/issues/907
# 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
# 3. https://github.com/LianjiaTech/BELLE/issues/337
if not isinstance(prompt, str):
raise ValueError(f"prompt should be an str but got {type(prompt)}")
prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
if not isinstance(chosen, str):
raise ValueError(f"chosen should be an str but got {type(chosen)}")
chosen_tokens = self.build_tokenized_answer(prompt, chosen)
if not isinstance(rejected, str):
raise ValueError(f"rejected should be an str but got {type(rejected)}")
rejected_tokens = self.build_tokenized_answer(prompt, rejected)
# Last prompt token might get merged by tokenizer and
# it should not be included for generation if that happens
prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
for k, v in prompt_tokens.items():
prompt_tokens[k] = v[:prompt_len_input_ids]
# Make sure the chosen and rejected prompts differ by at most one token
# and that their lengths differ by at most 1
num_diff_tokens = sum(
[a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
)
num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
if num_diff_tokens > 1 or num_diff_len > 1:
raise ValueError(
"Chosen and rejected prompt_input_ids might only differ on the "
"last token due to tokenizer merge ops."
)
# add BOS token to head of prompt
prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]
prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
# add EOS token to end of answer
chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
chosen_tokens["attention_mask"].append(1)
rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
rejected_tokens["attention_mask"].append(1)
longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
# if combined sequence is too long, truncate the prompt
for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
if self.truncation_mode == "keep_start":
for k in ["prompt_input_ids", "prompt_attention_mask"]:
answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
elif self.truncation_mode == "keep_end":
for k in ["prompt_input_ids", "prompt_attention_mask"]:
answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
else:
raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
# if that's still too long, truncate the response
for answer_tokens in [chosen_tokens, rejected_tokens]:
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
for k in ["input_ids", "attention_mask"]:
answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
# Create labels
chosen_sequence_tokens = {
k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
}
rejected_sequence_tokens = {
k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
}
chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
self.label_pad_token_id
] * len(chosen_tokens["prompt_input_ids"])
rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
self.label_pad_token_id
] * len(rejected_tokens["prompt_input_ids"])
for k, toks in {
"chosen_": chosen_sequence_tokens,
"rejected_": rejected_sequence_tokens,
"": prompt_tokens,
}.items():
for type_key, tokens in toks.items():
if type_key == "token_type_ids":
continue
batch[f"{k}{type_key}"] = tokens
else:
chosen_tokens = self.tokenizer(
chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
)
rejected_tokens = self.tokenizer(
rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
)
prompt_tokens = self.tokenizer(
prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
)
batch["chosen_labels"] = chosen_tokens["input_ids"]
batch["rejected_labels"] = rejected_tokens["input_ids"]
batch["prompt_input_ids"] = prompt_tokens["input_ids"]
batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
labels=torch.tensor(batch["rejected_labels"])
)
batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
labels=torch.tensor(batch["chosen_labels"])
)
return batch
@staticmethod
def concatenated_inputs(
batch: Dict[str, Union[List, torch.LongTensor]],
is_encoder_decoder: bool = False,
label_pad_token_id: int = -100,
padding_value: int = 0,
device: Optional[torch.device] = None,
) -> Dict[str, torch.LongTensor]:
"""Concatenate the chosen and rejected inputs into a single tensor.
Args:
batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
is_encoder_decoder: Whether the model is an encoder-decoder model.
label_pad_token_id: The label pad token id.
padding_value: The padding value to use for the concatenated inputs_ids.
device: The device for the concatenated inputs.
Returns:
A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
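Example (illustrative sketch; assumes the enclosing class is named `ORPOTrainer` and uses tiny hypothetical tensors):
    >>> batch = {
    ...     "chosen_input_ids": torch.ones(2, 4, dtype=torch.long),
    ...     "rejected_input_ids": torch.ones(2, 3, dtype=torch.long),
    ... }
    >>> out = ORPOTrainer.concatenated_inputs(batch)
    >>> out["concatenated_input_ids"].shape
    torch.Size([4, 4])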
"""
concatenated_batch = {}
if is_encoder_decoder:
max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
else:
max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
for k in batch:
if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
if "labels" in k or is_encoder_decoder:
pad_value = label_pad_token_id
elif k.endswith("_input_ids"):
pad_value = padding_value
elif k.endswith("_attention_mask"):
pad_value = 0
concatenated_key = k.replace("chosen", "concatenated")
concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
for k in batch:
if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
if "labels" in k or is_encoder_decoder:
pad_value = label_pad_token_id
elif k.endswith("_input_ids"):
pad_value = padding_value
elif k.endswith("_attention_mask"):
pad_value = 0
concatenated_key = k.replace("rejected", "concatenated")
concatenated_batch[concatenated_key] = torch.cat(
(
concatenated_batch[concatenated_key],
pad_to_length(batch[k], max_length, pad_value=pad_value),
),
dim=0,
).to(device=device)
if is_encoder_decoder:
concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
concatenated_batch["concatenated_attention_mask"] = (
batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
)
return concatenated_batch
def odds_ratio_loss(
self,
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, float, float]:
"""Compute ORPO's odds ratio (OR) loss for a batch of policy model log probabilities.
Args:
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
Returns:
A tuple of five elements: (losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen).
The losses tensor contains the ORPO loss for each example in the batch.
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
log_odds_ratio is the mean `log(sigmoid(log_odds))`, returned as a float for logging purposes.
log_odds_chosen is the mean log odds of the chosen responses over the rejected responses, returned as a float for logging purposes.
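In symbols (matching the computation below, with p_w = exp(policy_chosen_logps) and
p_l = exp(policy_rejected_logps)):
    log_odds = log(p_w / (1 - p_w)) - log(p_l / (1 - p_l))
    losses   = beta * log(sigmoid(log_odds))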
"""
# Derived from Eqs. (4) and (7) of https://arxiv.org/abs/2403.07691 by using log identities and exp(log(P(y|x))) = P(y|x)
log_odds = (policy_chosen_logps - policy_rejected_logps) - (
torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
)
sig_ratio = F.sigmoid(log_odds)
ratio = torch.log(sig_ratio)
losses = self.beta * ratio
chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
return losses, chosen_rewards, rejected_rewards, torch.mean(ratio).item(), torch.mean(log_odds).item()
@staticmethod
def get_batch_logps(
logits: torch.FloatTensor,
labels: torch.LongTensor,
average_log_prob: bool = False,
label_pad_token_id: int = -100,
is_encoder_decoder: bool = False,
) -> torch.FloatTensor:
"""Compute the log probabilities of the given labels under the given logits.
Args:
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
label_pad_token_id: The label pad token id.
is_encoder_decoder: Whether the model is an encoder-decoder model.
Returns:
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
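Example (illustrative sketch; assumes the enclosing class is named `ORPOTrainer` and uses random hypothetical tensors):
    >>> logits = torch.randn(2, 5, 32)          # (batch_size, sequence_length, vocab_size)
    >>> labels = torch.randint(0, 32, (2, 5))   # (batch_size, sequence_length)
    >>> labels[:, :2] = -100                    # e.g. mask out prompt positions
    >>> logps = ORPOTrainer.get_batch_logps(logits, labels, average_log_prob=True)
    >>> logps.shape
    torch.Size([2])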
"""
if logits.shape[:-1] != labels.shape:
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
if not is_encoder_decoder:
labels = labels[:, 1:].clone()
logits = logits[:, :-1, :]
loss_mask = labels != label_pad_token_id
# dummy token; we'll ignore the losses on these tokens later
labels[labels == label_pad_token_id] = 0
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
if average_log_prob:
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
else:
return (per_token_logps * loss_mask).sum(-1)
def concatenated_forward(
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
We do this to avoid doing two forward passes, because it's faster for FSDP.
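The chosen examples occupy the first `len(batch["chosen_labels"])` rows of the concatenated batch,
so the outputs are split back into chosen and rejected halves via `len_chosen` below.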
"""
concatenated_batch = self.concatenated_inputs(
batch,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
padding_value=self.padding_value,
device=self.accelerator.device,
)
len_chosen = batch["chosen_labels"].shape[0]
model_kwargs = (
{
"decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
}
if self.is_encoder_decoder
else {}
)
outputs = model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
**model_kwargs,
)
all_logits = outputs.logits
def cross_entropy_loss(logits, labels):
if not self.is_encoder_decoder:
# Shift so that tokens < n predict n
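# i.e. the logits at position t are scored against the label at position t + 1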
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
return loss
if self.is_encoder_decoder:
labels = concatenated_batch["concatenated_labels"].clone()
else:
labels = concatenated_batch["concatenated_input_ids"].clone()
chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
all_logps = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
average_log_prob=True,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
def get_batch_loss_metrics(
self,
model,
batch: Dict[str, Union[List, torch.LongTensor]],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = self.concatenated_forward(model, batch)
losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
policy_chosen_logps, policy_rejected_logps
)
# full ORPO loss
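# `losses` is beta * log(sigmoid(log_odds)), which is <= 0, so subtracting its mean adds the
# odds-ratio penalty on top of the NLL term: L_ORPO = L_NLL + beta * (-log(sigmoid(log_odds)))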
loss = policy_nll_loss - losses.mean()
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
return loss, metrics
def compute_loss(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
if not self.use_dpo_data_collator:
warnings.warn(
"compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)
compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
with compute_loss_context_manager():
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
# force log the metrics
self.store_metrics(metrics, train_eval="train")
if return_outputs:
return (loss, metrics)
return loss
def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> List[str]:
"""Generate samples from the model for the given batch of inputs."""
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast
with generate_context_manager():
policy_output = model.generate(
input_ids=batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
max_length=self.max_length,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
)
policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id)
policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)
return policy_output_decoded
def prediction_step(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
):
if not self.use_dpo_data_collator:
warnings.warn(
"prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)
if ignore_keys is None:
if hasattr(model, "config"):
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
else:
ignore_keys = []
prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
with torch.no_grad(), prediction_context_manager():
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
# force log the metrics
self.store_metrics(metrics, train_eval="eval")
if prediction_loss_only:
return (loss.detach(), None, None)
# logits for the chosen and rejected samples from model
logits_dict = {
"eval_logits/chosen": metrics["eval_logits/chosen"],
"eval_logits/rejected": metrics["eval_logits/rejected"],
}
logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
return (loss.detach(), logits, labels)
def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def evaluation_loop(
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
"""
Overriding built-in evaluation loop to store metrics for each batch.
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
Works both with and without labels.
"""
# Sample and save to game log if requested (for one batch to save time)
if self.generate_during_eval:
# Generate random indices within the range of the total number of samples
num_samples = len(dataloader.dataset)
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
random_batch_dataset = dataloader.dataset.select(random_indices)
random_batch = self.data_collator(random_batch_dataset)
random_batch = self._prepare_inputs(random_batch)
policy_output_decoded = self.get_batch_samples(self.model, random_batch)
self.log(
{
"game_log": wandb.Table(
columns=["Prompt", "Policy"],
rows=[
[prompt, pol[len(prompt) :]]
for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
],
)
}
)
self.state.log_history.pop()
# Base evaluation
initial_output = super().evaluation_loop(
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
)
return initial_output
def log(self, logs: Dict[str, float]) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
Args:
logs (`Dict[str, float]`):
The values to log.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs)
def _shift_right(self, input_ids):
if self.decoder_start_token_id is None:
raise ValueError(
"model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
)
# shift inputs to the right
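# e.g. labels [x1, x2, x3] become decoder inputs [decoder_start_token_id, x1, x2]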
if is_torch_fx_proxy(input_ids):
# Item assignment is not supported natively for proxies.
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
else:
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = self.decoder_start_token_id
if self.pad_token_id is None:
raise ValueError("model.config.pad_token_id has to be defined.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
return shifted_input_ids
@wraps(Trainer.push_to_hub)
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tag "orpo" when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)

View File

@ -53,7 +53,12 @@ from ..core import (
stats_to_np,
)
from ..import_utils import is_npu_available, is_torch_greater_2_0, is_xpu_available
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model
from ..models import (
SUPPORTED_ARCHITECTURES,
PreTrainedModelWrapper,
create_reference_model,
unwrap_model_for_generation,
)
from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments
@ -470,15 +475,14 @@ class PPOTrainer(BaseTrainer):
**generation_kwargs,
)
if generate_ref_response:
with self.optional_peft_ctx():
ref_response = self._generate_batched(
ref_model,
query_tensor,
length_sampler=length_sampler,
batch_size=batch_size,
return_prompt=return_prompt,
**generation_kwargs,
)
ref_response = self._generate_batched(
ref_model,
query_tensor,
length_sampler=length_sampler,
batch_size=batch_size,
return_prompt=return_prompt,
**generation_kwargs,
)
else:
if len(query_tensor.shape) == 2:
@ -488,12 +492,17 @@ class PPOTrainer(BaseTrainer):
if length_sampler is not None:
generation_kwargs["max_new_tokens"] = length_sampler()
response = self.accelerator.unwrap_model(self.model).generate(
input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs
)
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
response = unwrapped_model.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs)
if generate_ref_response:
with self.optional_peft_ctx():
ref_response = ref_model.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs)
with unwrap_model_for_generation(
ref_model, self.accelerator, is_peft_model=self.is_peft_model
) as unwrapped_model:
ref_response = unwrapped_model.generate(
input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs
)
if not return_prompt and not self.is_encoder_decoder:
response = response[:, query_tensor.shape[0] :]
@ -543,7 +552,8 @@ class PPOTrainer(BaseTrainer):
return_tensors="pt",
).to(self.current_device)
generations = self.accelerator.unwrap_model(model).generate(**padded_inputs, **generation_kwargs)
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
generations = unwrapped_model.generate(**padded_inputs, **generation_kwargs)
for generation, mask in zip(generations, padded_inputs["attention_mask"]):
if not self.is_encoder_decoder:

View File

@ -137,7 +137,7 @@ class RewardTrainer(Trainer):
inspect.signature(prepare_model_for_kbit_training).parameters
)
preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
warnings.warn(
@ -145,9 +145,9 @@ class RewardTrainer(Trainer):
"please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
)
elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
model = get_peft_model(model, peft_config)
@ -196,17 +196,17 @@ class RewardTrainer(Trainer):
else:
self.use_reward_data_collator = False
super().__init__(
model,
args,
data_collator,
train_dataset,
eval_dataset,
tokenizer,
model_init,
compute_metrics,
callbacks,
optimizers,
preprocess_logits_for_metrics,
model=model,
args=args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
model_init=model_init,
compute_metrics=compute_metrics,
callbacks=callbacks,
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
# Add tags for models that have been loaded with the correct transformers version

View File

@ -17,6 +17,7 @@ import warnings
from functools import wraps
from typing import Callable, Dict, List, Optional, Tuple, Union
import datasets
import torch
import torch.nn as nn
from accelerate.state import PartialState
@ -42,6 +43,7 @@ from ..import_utils import is_peft_available
from .utils import (
ConstantLengthDataset,
DataCollatorForCompletionOnlyLM,
RichProgressCallback,
neftune_post_forward_hook,
peft_module_casting_to_bf16,
trl_sanitze_kwargs_for_tagging,
@ -116,6 +118,8 @@ class SFTTrainer(Trainer):
Dict of Optional kwargs to pass when instantiating the model from a string
dataset_kwargs: (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when creating packed or non-packed datasets
eval_packing: (`Optional[bool]`, *optional*):
Whether to pack the eval dataset as well. Defaults to `packing` if `None` is passed.
"""
_tag_names = ["trl", "sft"]
@ -124,7 +128,7 @@ class SFTTrainer(Trainer):
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
args: Optional[TrainingArguments] = None,
data_collator: Optional[DataCollator] = None,
data_collator: Optional[DataCollator] = None, # type: ignore
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
@ -146,6 +150,7 @@ class SFTTrainer(Trainer):
neftune_noise_alpha: Optional[float] = None,
model_init_kwargs: Optional[Dict] = None,
dataset_kwargs: Optional[Dict] = None,
eval_packing: Optional[bool] = None,
):
if model_init_kwargs is None:
model_init_kwargs = {}
@ -183,15 +188,26 @@ class SFTTrainer(Trainer):
inspect.signature(prepare_model_for_kbit_training).parameters
)
gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {}
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
preprare_model_kwargs = {
is_sharded_qlora = False
# Below is to support QLoRA + FSDP / DS-Zero3 - one should never call
# peft_module_casting_to_bf16 or prepare_model_for_kbit_training when doing
# QLoRA + FSDP / DS-Zero3
if getattr(model, "is_loaded_in_4bit", False):
for _, param in model.named_parameters():
if param.__class__.__name__ == "Params4bit":
is_sharded_qlora = param.data.device.type == "cpu"
break
if getattr(model, "is_loaded_in_8bit", False) or (
getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora
):
prepare_model_kwargs = {
"use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False)
}
if _support_gc_kwargs:
preprare_model_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs
prepare_model_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs
model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
if args is not None:
args = dataclasses.replace(args, gradient_checkpointing=False)
@ -210,7 +226,12 @@ class SFTTrainer(Trainer):
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
model = get_peft_model(model, peft_config)
if args is not None and args.bf16 and getattr(model, "is_loaded_in_4bit", False):
if (
args is not None
and args.bf16
and getattr(model, "is_loaded_in_4bit", False)
and not is_sharded_qlora
):
peft_module_casting_to_bf16(model)
if tokenizer is None:
@ -274,11 +295,14 @@ class SFTTrainer(Trainer):
if eval_dataset is not None:
_multiple = isinstance(eval_dataset, dict)
_eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset}
eval_packing = packing if eval_packing is None else eval_packing
for _eval_dataset_name, _eval_dataset in _eval_datasets.items():
_eval_datasets[_eval_dataset_name] = self._prepare_dataset(
_eval_dataset,
tokenizer,
packing,
eval_packing,
dataset_text_field,
max_seq_length,
formatting_func,
@ -322,6 +346,12 @@ class SFTTrainer(Trainer):
elif self.args.max_steps == -1 and packing:
self.train_dataset.infinite = False
if any(isinstance(callback, RichProgressCallback) for callback in self.callback_handler.callbacks):
for callback in self.callback_handler.callbacks:
# Remove the PrinterCallback to avoid duplicated prints in case we passed a `RichProgressCallback`
if callback.__class__.__name__ == "PrinterCallback":
self.callback_handler.pop_callback(callback)
@wraps(Trainer.train)
def train(self, *args, **kwargs):
# Activate neftune right before training.
@ -367,12 +397,27 @@ class SFTTrainer(Trainer):
remove_unused_columns=True,
append_concat_token=True,
add_special_tokens=True,
skip_prepare_dataset=False,
):
if dataset is None:
raise ValueError("The dataset should not be None")
if skip_prepare_dataset:
return dataset
# If the dataset is already preprocessed (tokenized), return as-is. Only works if dataset is
# a datasets.Dataset or datasets.IterableDataset -- not for torch Dataset
column_names = (
dataset.column_names if isinstance(dataset, (datasets.Dataset, datasets.IterableDataset)) else None
)
if column_names and "input_ids" in column_names:
return dataset
# check if torch dataset / dataloader and do nothing
if isinstance(dataset, (torch.utils.data.IterableDataset, torch.utils.data.Dataset, ConstantLengthDataset)):
# see https://github.com/huggingface/trl/pull/1468 for why datasets.IterableDataset needs a separate check
if isinstance(
dataset, (torch.utils.data.IterableDataset, torch.utils.data.Dataset, ConstantLengthDataset)
) and not isinstance(dataset, datasets.IterableDataset):
return dataset
if not packing:
@ -484,6 +529,9 @@ class SFTTrainer(Trainer):
add_special_tokens=add_special_tokens,
)
if isinstance(dataset, datasets.IterableDataset):
return constant_length_iterator
def data_generator(constant_length_iterator):
yield from constant_length_iterator

View File

@ -20,9 +20,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from accelerate import PartialState
from rich.console import Console, Group
from rich.live import Live
from rich.panel import Panel
from rich.progress import Progress
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import IterableDataset
from transformers import BitsAndBytesConfig, DataCollatorForLanguageModeling, PreTrainedTokenizerBase
from transformers.trainer import TrainerCallback
from transformers.trainer_utils import has_length
from ..import_utils import is_peft_available, is_unsloth_available, is_xpu_available
from ..trainer.model_config import ModelConfig
@ -319,7 +325,7 @@ class DPODataCollatorWithPadding:
padding_value = self.pad_token_id
elif k.endswith("_attention_mask"):
padding_value = 0
elif (k.startswith("chosen")) or (k.startswith("rejected")) or ("decoder" in k):
elif k.startswith(("chosen", "rejected", "completion")) or ("decoder" in k):
padding_value = self.label_pad_token_id
else:
raise ValueError(f"Unexpected key in batch '{k}'")
@ -715,14 +721,97 @@ def get_peft_config(model_config: ModelConfig) -> "Optional[PeftConfig]":
if model_config.use_peft is False:
return None
if not is_peft_available():
raise ValueError(
"You need to have PEFT library installed in your environment, make sure to install `peft`. "
"Make sure to run `pip install -U peft`."
)
peft_config = LoraConfig(
r=model_config.lora_r,
lora_alpha=model_config.lora_alpha,
lora_dropout=model_config.lora_dropout,
bias="none",
task_type="CAUSAL_LM",
task_type=model_config.lora_task_type,
target_modules=model_config.lora_target_modules,
modules_to_save=model_config.lora_modules_to_save,
)
return peft_config
class RichProgressCallback(TrainerCallback):
"""
A [`TrainerCallback`] that displays the progress of training or evaluation using Rich.
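Usage sketch (illustrative; any `transformers.Trainer` subclass that accepts `callbacks` works):
    trainer = SFTTrainer(..., callbacks=[RichProgressCallback()])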
"""
def __init__(self):
self.training_bar = None
self.prediction_bar = None
self.training_task_id = None
self.prediction_task_id = None
self.rich_group = None
self.rich_console = None
self.training_status = None
self.current_step = None
def on_train_begin(self, args, state, control, **kwargs):
if state.is_world_process_zero:
self.training_bar = Progress()
self.prediction_bar = Progress()
self.rich_console = Console()
self.training_status = self.rich_console.status("Nothing to log yet ...")
self.rich_group = Live(Panel(Group(self.training_bar, self.prediction_bar, self.training_status)))
self.rich_group.start()
self.training_task_id = self.training_bar.add_task("[blue]Training the model", total=state.max_steps)
self.current_step = 0
def on_step_end(self, args, state, control, **kwargs):
if state.is_world_process_zero:
self.training_bar.update(self.training_task_id, advance=state.global_step - self.current_step, update=True)
self.current_step = state.global_step
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
if state.is_world_process_zero and has_length(eval_dataloader):
if self.prediction_task_id is None:
self.prediction_task_id = self.prediction_bar.add_task(
"[blue]Predicting on the evaluation dataset", total=len(eval_dataloader)
)
self.prediction_bar.update(self.prediction_task_id, advance=1, update=True)
def on_evaluate(self, args, state, control, **kwargs):
if state.is_world_process_zero:
if self.prediction_task_id is not None:
self.prediction_bar.remove_task(self.prediction_task_id)
self.prediction_task_id = None
def on_predict(self, args, state, control, **kwargs):
if state.is_world_process_zero:
if self.prediction_task_id is not None:
self.prediction_bar.remove_task(self.prediction_task_id)
self.prediction_task_id = None
def on_log(self, args, state, control, logs=None, **kwargs):
if state.is_world_process_zero and self.training_bar is not None:
_ = logs.pop("total_flos", None)
self.training_status.update(f"[bold green]Status = {str(logs)}")
def on_train_end(self, args, state, control, **kwargs):
if state.is_world_process_zero:
self.rich_group.stop()
self.training_bar = None
self.prediction_bar = None
self.training_task_id = None
self.prediction_task_id = None
self.rich_group = None
self.rich_console = None
self.training_status = None
self.current_step = None