Compare commits

...

99 Commits

Author SHA1 Message Date
f2c71771cc Release: v0.8.0 (#1453)
* Release: v0.7.12

* 0.8.0 instead
2024-03-19 17:19:38 +01:00
631c33cbb3 FEAT: Update README to add DPO + CLIs (#1448)
* Update README.md

* Update README.md

* move dpo/ppo description to docs

* rework readme

* Update README.md

---------

Co-authored-by: leandro <leandro.vonwerra@spoud.io>
2024-03-19 16:55:56 +01:00
3f7ff60528 model --> model_name_or_path (#1452)
* `model` --> `model_name_or_path`

* fix style
2024-03-19 16:52:42 +01:00
1705aebeba Fix yaml parsing issue (#1450) 2024-03-19 16:07:50 +01:00
4e622a9033 chat cli (#1431)
* first draft

* move chat to cli

* fix makefile

* make script less verbose

* fix parsing

* fix style

* add more examples

* fix setup.py

* add copyright

* fix verbose init

* attribute FastChat

* add docs
2024-03-19 12:37:06 +01:00
eb2d5b2972 CI / CLI: Properly raise error when CLI tests failed (#1446)
* properly raise error

* another fix

* Update tests.yml

* Update tests-main.yml
2024-03-19 11:39:07 +01:00
f976c6d234 Before updating the tr_loss, make sure tr_loss_step is on the same device. (#1439)
* before updating the loss from dpo, make sure it's on the same device as tr_loss

* Update trl/trainer/dpo_trainer.py

Co-authored-by: guy1992l <83535508+guy1992l@users.noreply.github.com>

---------

Co-authored-by: guy1992l <83535508+guy1992l@users.noreply.github.com>
2024-03-19 10:28:44 +01:00
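
A minimal sketch of the device mismatch this fix guards against (the variable names are illustrative, not TRL's actual code):

```python
import torch

# The running loss accumulator and the per-step loss may live on different devices.
tr_loss = torch.zeros(1)          # e.g. kept on the CPU
tr_loss_step = torch.tensor(0.5)  # e.g. computed on an accelerator in real training

# Move the step loss onto the accumulator's device before adding, as the fix does.
tr_loss += tr_loss_step.to(tr_loss.device)
```
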
abc7301bab Fix PPOTrainer README example (#1441)
* Fix example

* Delete newline
2024-03-19 10:18:49 +01:00
6cfa5cfc81 fix doc build on main (#1437) 2024-03-18 14:24:02 +01:00
a2aa0f0b09 FEAT: Add CLIs in TRL ! (#1419)
* CLI V1

* v1 CLI

* add rich enhancements

* revert unindented change

* some comments

* cleaner CLI

* fix

* fix

* remove print callback

* move to cli instead of trl_cli

* revert unneeded changes

* fix test

* Update trl/commands/sft.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* remove redundant strings

* fix import issue

* fix other issues

* add packing

* add config parser

* some refactor

* cleaner

* add example config yaml file

* small refactor

* change a bit the logic

* fix issues here and there

* add CLI in docs

* move to examples/sft

* remove redundant licenses

* make it work on dpo

* set to None

* switch to accelerate and fix many things

* add docs

* more docs

* added tests

* doc clarification

* more docs

* fix CI for windows and python 3.8

* fix

* attempt to fix CI

* fix?

* test

* fix

* tweak?

* fix

* test

* another test

* fix

* test

* fix

* fix

* fix

* skip tests for windows

* test @lvwerra approach

* make dev

* revert unneeded changes

* fix sft dpo

* optimize a bit

* address final comments

* update docs

* final comment

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2024-03-18 12:20:54 +01:00
304e208f77 Create standard dataset for TRL (#1424)
* add scripts to create standard dataset

* precommit

* push changes

* add script to play with
2024-03-14 10:57:48 -04:00
4fe8b027f6 [Kto] torch_dtype kwargs fix (#1429)
* set torch_dtype from string type

* fix test
2024-03-14 13:49:44 +01:00
fb6ebb1e11 [KTO] fix tokenization bugs (#1418)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

* remove nans in metrics by gathering across machines

* fix formatting

* fix choice of mismatched examples for KL term

* describe weights

* fix hanging issue in distributed training

* linting

* move metrics to cpu

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* fix tokenization error: lack of bos

* change user warning for weight hyperparams

* minor update to docs

* reshape attention mask

* reformat

* add test for bos/eos tokens

* move dependency location

* Update tests/test_kto_trainer.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-03-14 08:22:50 +01:00
66078c7c01 CI: Fix CI on main (#1422)
* fix CI on main

* final fix
2024-03-13 13:54:22 +01:00
58c0888996 Add support for FSDP+QLoRA and DeepSpeed ZeRO3+QLoRA (#1416)
* don't do mp casting

* don't use `prepare_for_kbit` when using fsdp+qlora or dsz3+qlora

* changes to enable fsdp+qlora and dsz3+qlora

* revert

* Update sft_trainer.py

* quality

* fix deprecation using changes from PR https://github.com/huggingface/trl/pull/1415

* fixes

* quality

* Update trl/trainer/sft_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* quality

* relaunch tests

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-03-13 10:43:45 +01:00
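
For context, a rough sketch of how 4-bit quantization is usually configured for FSDP/ZeRO-3 + QLoRA setups like the one enabled here; the model name and the `bnb_4bit_quant_storage` setting are assumptions, not taken from this change:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Keep the packed 4-bit weights in the dtype used for sharding so FSDP / DeepSpeed
# ZeRO-3 can flatten them, avoiding extra mixed-precision casting.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # placeholder model
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)
```
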
486e7a4071 model init when args are given (#1413)
Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
2024-03-11 13:47:37 +01:00
7630f877f9 Fix import error from deprecation in transformers (#1415)
* Fix import error from  deprecation in transformers

* Fix import path
2024-03-11 13:23:56 +01:00
4d862da181 [KTO] fix various bugs (#1402)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

* remove nans in metrics by gathering across machines

* fix formatting

* fix choice of mismatched examples for KL term

* describe weights

* fix hanging issue in distributed training

* linting

* move metrics to cpu

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-03-08 12:04:52 +01:00
22b4f548f4 fix RM script (#1393) 2024-03-07 08:49:52 +01:00
4219cbfedc Fix the pad_token_id error (#1394)
* Fix the pad_token_id error

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Add the load_in_8bit argument in rl_training.py

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Reformat the patch

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Fix the check failed

Signed-off-by: yuanwu <yuan.wu@intel.com>

---------

Signed-off-by: yuanwu <yuan.wu@intel.com>
2024-03-05 02:18:42 +01:00
3bd02380c7 Log ddpo reward as float to fix numpy conversion during bf16 training (#1391) 2024-03-04 02:50:50 +01:00
067db7553a [KTO] prevent nans from appearing in metrics (#1386)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

* remove nans in metrics by gathering across machines

* fix formatting

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-03-01 12:19:55 +01:00
93e85ed808 [KTO] merge eval dataset only if it exists (#1383)
* merge eval dataset if it exists

* add eval dataset test
2024-03-01 12:15:14 +01:00
14e0d78807 fix bugs in KTO implementation (#1380)
* add warning for imbalanced data

* update documentation

* update script commands to be same as in dpo

* use batch_size KL examples and batch_size target examples to calculate batch_size losses

* fix deepspeed issue

* speed up forward with no_grad for KL

* add some removed metrics

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

* Update trl/trainer/kto_trainer.py

add reference to paper

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* add more detailed comments

* convert assert to ValueError

* Update kto_trainer.py

* precommit formatting

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-02-29 09:01:52 +01:00
b32656f726 FIX: Fix the CI again .. (#1374)
* Update tests-main.yml

* Update tests-main.yml

* Update tests-main.yml

* Update .github/workflows/tests-main.yml

* Update tests-main.yml

* Update tests-main.yml
2024-02-27 12:46:20 +01:00
9399bc113b Update tests-main.yml (#1373) 2024-02-27 12:07:50 +01:00
11f122ad49 Update tests-main.yml (#1372) 2024-02-27 11:45:02 +01:00
009c9a610b feature request add force_use_ref_model (#1367) 2024-02-27 11:19:16 +01:00
7712d42f8c add eval_packing (#1369) 2024-02-27 11:19:06 +01:00
7c2213b9e5 add ci message sending on TRL (#1370) 2024-02-27 11:18:55 +01:00
ddeebce176 Add some arguments for support XPU (#1366)
* Add use_bnb and load_in_4bit arguments.

Make it optional, since it is not supported on all platforms

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Change the use_reentrant default value to False

If the default value of gradient_checkpointing is True, set the
use_reentrant default value to False, because otherwise the following error
happens.

RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 191 with name base_model.model.model.layers.31.self_attn.v_proj.lora_B.default.weight has been marked as ready twice. This means that multiple autograd engine  hooks have fired for this particular parameter during this iteration.

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Add model_dtype for loading the model in model_dtype

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Reformat the patch

Signed-off-by: yuanwu <yuan.wu@intel.com>

---------

Signed-off-by: yuanwu <yuan.wu@intel.com>
2024-02-27 02:49:16 +01:00
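
A minimal sketch of the workaround described above, assuming a transformers version that accepts `gradient_checkpointing_kwargs`:

```python
from transformers import TrainingArguments

# Non-reentrant checkpointing avoids the DDP "marked as ready twice" error quoted above.
args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```
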
cf68d871cf Fix version for Python<3.8 (#1363) 2024-02-27 02:41:09 +01:00
2a2676e7ec set seed in sft/dpo/reward_modeling to make results reproducible (#1357)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-02-23 11:12:45 +01:00
ca90cba351 fix 8-bit multi-gpu training bug (#1353)
* fix 8-bit multi-gpu training bug see https://github.com/huggingface/trl/issues/1348

* Update dpo_llama2.py

make gradient_checkpointing_kwargs configurable.

* Update dpo_llama2.py

remove unnecessary config of device_map

* format with make precommit

---------

Co-authored-by: ubuntu <lili@liveremier.ai>
2024-02-23 03:58:43 +01:00
4f97fb4a74 more user-friendly (#1350) 2024-02-22 10:06:35 +01:00
a46cd84a64 Kto trainer (#1181)
* initial file

* initial tokenizer

* UnpairedPreferenceBatchSampler

* use batch_sampler

* use interleave_datasets

* add loss

* fix imports

* use SequentialSampler when training

* formatting

* add other helpers

* add prediction_step

* fix the kto pair docs

* tests

* compute_reference_log_probs

* add get_eval_dataloader

* fix typo

* kto with is_encoder_decoder true

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* fixed typo

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* renamed KTO dataset keys

* use DPOTrainer's get_batch_logps

* add get_batch_samples

* typo

* Handle last token in prompt

* Create KTOConfig class that subclasses transformers.TrainingArguments

* Update KTO tests to handle KTOConfig

* Update KTO script to use KTOConfig

* formatting

* Update docs/source/dpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/training_configs.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* use max_completion_length

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* add back get_batch_logps

* use max_completion_length

* move config to its own file

* Check tokenize params on Trainer init

* Clone labels for enc-dec model to solve RuntimeError

* formatting

* fix enc-dec later

* completion_decoder_input_ids is optional for enc-dec

* fix breaking test

* add a kl key for KL estimation with shuffled completion

* add loss and weights

* fix bug in chosen_idx

* add back metrics

* fix typos

* fix kto_loss docs

* typo

* set loss to None when there is no target completions in batch

* use nan tensor instead of none

* fix reference_logps test

* fix logits

* a bit more robust options

* log only the correct prompt-completion during eval

* Update trl/trainer/kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/kto_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/dpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* add docs for desirable_weight and undesirable_weight args

* dropout is always disabled

* remove DDP hack

* formatting

* move more arguments of trainer to config

* comment out T5 test for now

* Add docstring to KTOTrainer

* moved Config docstrings to the appropriate class

* add autodoc to markdown

* formatting

* updated copyright year

* add model tags

* do not add BOS to start of completion

* Move data_collator to KTOTrainer

* formatting

* data_collator is not in args

* shuffle_completion with specific input_columns

* remove all but the needed columns

* Update docs/source/dpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/kto.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update tests/test_kto_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* moved more args to kto_config

* fix test

* use all_exhausted strategy and shuffle after

* use KTOConfig in HfArgumentParser

* use ModelConfig

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Pablo Vicente Juan <p.vicente.juan@gmail.com>
2024-02-19 14:43:17 +01:00
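
A hedged, minimal sketch of how the resulting `KTOTrainer` is typically used; the tiny model, toy dataset, and exact argument names are assumptions and may need adjusting to the TRL version at hand:

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import KTOConfig, KTOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# KTO uses unpaired preference data: a prompt, one completion, and a binary label.
train_dataset = Dataset.from_dict(
    {
        "prompt": ["What is the capital of France?", "What is the capital of France?"],
        "completion": [" Paris.", " Berlin."],
        "label": [True, False],
    }
)

trainer = KTOTrainer(
    model,
    args=KTOConfig(output_dir="kto-out", max_steps=1, per_device_train_batch_size=2),
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()
```
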
1f56bffdf8 Update Example to reflect #aa35fec (#1333) 2024-02-18 14:10:04 +01:00
1bfe0b8fcb set dev version (#1332) 2024-02-16 09:49:05 +01:00
0f13e51efa Release: v0.7.11 (#1331) 2024-02-16 09:05:04 +01:00
1e77d8aeb2 [core / xxxTrainer] Automatic tagging (#1329)
* automatic tagging

* add comments

* fix tests

* fix
2024-02-15 14:47:32 +01:00
3b1911c2a9 add tests on transformers peft main (#1328) 2024-02-15 05:19:31 +01:00
851e7fe556 [core / DDPO] Fix diffusers import issue (#1314)
* fix

* more clean up
2024-02-15 04:45:27 +01:00
31b02d0cd0 Update README.md to clarify model requirement (#1315)
Clarify that language models must be transformers models for text. This is a bit redundant with the intro description, but attempts to better address a question that comes up (issue 1257).

Closes: #1257
2024-02-15 04:38:17 +01:00
9bc478ecbb pre-commit: replace linters + formatters with Ruff; fix some issues (#1300)
* pre-commit: replace linters + formatters with Ruff

* Don't use bare except

* Clean up `noqa`s

* Enable Ruff UP; apply auto-fixes

* Enable Ruff B; apply fixes

* Enable Ruff T with exceptions

* Enable Ruff C (complexity); autofix

* Upgrade Ruff to 0.2.0
2024-02-15 04:37:41 +01:00
29f162b86c Best practice recommendation update for dpo_trainer.mdx (#1325)
In the document as it stands, the best practice recommendations seem neither consistent nor correct.

For example, the documentation links a tweet with a recommendation to merge adaptors into a quantized model, and a script that supposedly illustrates how to apply that recommendation. But the script actually does the opposite of what the tweet recommends, first dequantizing the model. 

There are similar inconsistencies/ambiguities further in that paragraph. For example, saying that using an unquantized model would lead to lower performance (I changed it to "higher memory demand").

Overall, I updated the paragraph to improve consistency and provided links to slightly more evidence-based merging recommendations.
2024-02-14 11:43:48 +01:00
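
For context, a hedged sketch of the commonly recommended merging approach discussed here (load the adapter on non-quantized base weights, then merge); the paths are placeholders:

```python
from peft import AutoPeftModelForCausalLM

# Load the adapter on top of the base model in a regular (non-quantized) dtype,
# then merge the LoRA weights and save a standalone checkpoint.
model = AutoPeftModelForCausalLM.from_pretrained("path/to/adapter", torch_dtype="auto")
merged = model.merge_and_unload()
merged.save_pretrained("path/to/merged-model")
```
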
6852097169 Fix PPOTrainer argument train_dataset -> dataset (#1321)
Both the argument's name and its value need to be renamed.
Otherwise we get both

NameError: name 'train_dataset' is not defined

and

TypeError: PPOTrainer.__init__() got an unexpected keyword argument 'train_dataset'
2024-02-06 22:37:04 +01:00
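
A minimal sketch of the corrected call, with gpt2 and a toy dataset as placeholders; the point is only that the keyword is `dataset`, not `train_dataset`:

```python
from datasets import Dataset
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
dataset = Dataset.from_dict({"query": ["This morning I went to the "]})

# Passing `train_dataset=...` raises the TypeError quoted above; `dataset=...` is correct.
ppo_trainer = PPOTrainer(
    PPOConfig(batch_size=1, mini_batch_size=1),
    model,
    ref_model=None,
    tokenizer=tokenizer,
    dataset=dataset,
)
```
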
f12a1da74b Fix AttributeError in dpo_trainer for reference_free case in dpo_loss function (#1313)
* Update dpo_trainer.py

update reference_free parameter for dpo_loss

* Update dpo_trainer for reference_free case

fixed the docstring typo and set the device parameter on the ref_logratios tensor
2024-02-02 11:02:40 +01:00
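
A rough sketch, not TRL's actual code, of the `reference_free` branch this fix touches:

```python
import torch

def dpo_logits(
    policy_logratios: torch.Tensor,
    ref_logratios: torch.Tensor,
    reference_free: bool = False,
) -> torch.Tensor:
    if reference_free:
        # With no reference model the reference term is a zero tensor, created on
        # the same device as the policy tensor (the device issue fixed above).
        ref_logratios = torch.zeros_like(policy_logratios)
    return policy_logratios - ref_logratios
```
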
ae87b3aefa Fix typos in docs for Multi Adapter RL (MARL). (#1312)
* Fix more typos

* Fix typos in docs.
2024-02-02 07:37:08 +01:00
3f7cee7643 ENH: Run CI only if relevant files are modified (#1309)
* Update tests.yml

* Update .github/workflows/tests.yml
2024-02-01 23:49:32 +01:00
ae8431bd50 Codemod Unittest assertions to bare asserts (#1301)
* Remove stray commas from test data

* Codemod Unittest assertions to bare asserts

* Make `assertAlmostEqual` tests more idiomatic

* DRY some test strings
2024-02-01 23:49:03 +01:00
66a976c6bd Update sft_trainer.mdx to add note on launching DDP training (#1308)
As requested here: https://github.com/huggingface/trl/issues/1303#issuecomment-1920437586
2024-02-01 23:42:14 +01:00
814930377c Add num_proc arg to the eval_dataset processing (#1307) 2024-02-01 17:58:00 +01:00
88685f2cd4 Types: Fix PEP 484 implicit-optional compliance (#1297)
This was done automatically with hauntsaninja/no_implicit_optional.
2024-01-31 14:51:58 +01:00
6f40f20233 Fix DPOTrainer docstrings (#1298)
Some issues were causing the auto-generation of the API reference to fail, and the args were overlapping in the documentation page
2024-01-31 14:49:41 +01:00
036213bd85 Fix sft trainer when args is None (#1295)
* fix sft trainer when args is None

* add test

* fix
2024-01-31 03:31:53 +01:00
6042596705 Fix DPO slow tests (#1292)
* Update test_dpo_slow.py

* style
2024-01-30 10:15:46 +01:00
070c75ec54 load data only on main process + fix dpo example test (#1291) 2024-01-30 10:14:22 +01:00
b415224a4a fix DPO trainer + mistral + FA2 (#1290) 2024-01-30 08:25:29 +01:00
9186710671 fix padding in dpo trainer (#1284) 2024-01-30 08:24:48 +01:00
aa35fec099 raise value error if one passes a ref_model and a peft_config (#1289) 2024-01-30 08:06:03 +01:00
737d771941 Add multiprocessing in the DPO trainer. (#1286)
* Update dpo_trainer.py

Added support for num_proc to tokenize the training dataset.

* Update dpo_trainer.py

added type in the new num_proc variable

* added test case

* add test case

* fix type

---------

Co-authored-by: imraviagrawal <ravi.agrawal@umass.edu>
Co-authored-by: Ravi Agrawal <raviagrawal@Ravis-MacBook-Pro.local>
2024-01-30 02:55:07 +01:00
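
The option plumbs the `datasets` library's multiprocessing through to tokenization; a generic sketch of that underlying mechanism (not TRL's wrapper code, and the map function here is a made-up stand-in):

```python
from datasets import Dataset

ds = Dataset.from_dict(
    {
        "prompt": ["Hi"] * 1000,
        "chosen": [" Hello!"] * 1000,
        "rejected": [" Go away."] * 1000,
    }
)

def fake_tokenize(row):
    # Stand-in for the trainer's row-wise tokenization step.
    return {"n_chars": len(row["prompt"]) + len(row["chosen"]) + len(row["rejected"])}

# num_proc spreads the map over several worker processes; the new trainer option
# exposes this knob for the DPO tokenization pass.
ds = ds.map(fake_tokenize, num_proc=4)
```
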
ef441ea028 Update dpo_trainer.mdx (#1280) 2024-01-27 10:29:10 +01:00
af623aeba6 Fix sft ci (#1279) 2024-01-26 19:18:23 +01:00
3843cfc32f Fix SFT tuner (#1278) 2024-01-26 17:49:50 +01:00
9a71e67be9 Remove tyro (#1176)
* refactor

* Remove tyro in `ppo.py`

* quick update

* update default args

* quick push

* precommit

* refactor

* quick change

* remove tyro

* quick change

* precommit

* quick change

* fix hello_world

* remove docstring diffences

* add `module load cuda/12.1`

* push changes

* precommit

* make dpo runnable

* fix circular import

* quick fix

* refactor

* quick update

* path change

* update plots

* fix docs

* quick change

* Update trl/trainer/model_config.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/model_config.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update examples/scripts/dpo.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* address comments. use attn_implementation

* precommit

* remove duplicate code

* update peft.py

* fix test no op dep

* Update trl/trainer/utils.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* precommit

* add docs

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-01-26 07:51:15 -08:00
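
For reference, a small sketch of the `HfArgumentParser`-style parsing the example scripts generally moved to after dropping tyro; the dataclass fields are illustrative:

```python
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class ScriptArguments:
    model_name_or_path: str = field(default="gpt2", metadata={"help": "model to fine-tune"})
    learning_rate: float = field(default=1.41e-5, metadata={"help": "optimizer learning rate"})

# Parses --model_name_or_path and --learning_rate from the command line into the dataclass.
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
```
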
09ca565b24 Fix SFTTrainer bugs on TRL main (#1276)
* Update sft_trainer.py

* Update trl/trainer/sft_trainer.py
2024-01-26 13:50:37 +01:00
4edc688311 Only load data on main process (#1255)
* fix: only load data on main process

* define is_main_process once

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* avoid re-initializing PartialState on train dataset check

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* avoid re-initializing PartialState on eval dataset check

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* process dataset on main first to take advantage of caching

* fix typo in docs

* use decorator to manage state

* Revert "fix typo in docs"

This reverts commit 0880a188812a698f7106853245ce1ba96a036831.

* Revert "Revert "fix typo in docs""

This reverts commit ff7ee33fbeedcd0032b728d86a17cfcb10e43f9b.

* Revert "use decorator to manage state"

This reverts commit 7ac7a45949f621941fedc522f0d2ca7b29367c3a.

* use is_local_main_process instead of is_main_process

* fix: use context manager instead of attribute

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update trl/trainer/sft_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-01-26 10:38:07 +01:00
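
A minimal sketch of the pattern this PR converges on (the local main process prepares the dataset first so the other ranks reuse the cache); the dataset name is a placeholder:

```python
from accelerate import PartialState
from datasets import load_dataset

# Rank 0 on each node downloads/processes first; the other ranks then hit the cache.
with PartialState().local_main_process_first():
    dataset = load_dataset("imdb", split="train")
```
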
29d439a204 [DPO] average_log_prob when loss is IPO (#1265)
* average_log_prob when loss is IPO

* updated docs with the fix
2024-01-24 12:18:04 +01:00
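
A rough sketch, not TRL's exact implementation, of the distinction the fix enforces: IPO uses the length-averaged log-probability of the completion, while the standard DPO loss sums it:

```python
import torch

def completion_logps(
    per_token_logps: torch.Tensor,  # (batch, seq_len)
    loss_mask: torch.Tensor,        # 1.0 on completion tokens, 0.0 elsewhere
    average_log_prob: bool,
) -> torch.Tensor:
    summed = (per_token_logps * loss_mask).sum(-1)
    if average_log_prob:
        # IPO: normalise by the completion length.
        return summed / loss_mask.sum(-1)
    # DPO: plain sum over completion tokens.
    return summed
```
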
5760e5d3db Fix typo in extra_columns variable name (#1269)
Co-authored-by: Otto Laitila <otto.laitila@op.fi>
2024-01-23 14:46:13 +01:00
a3c5b7178a Update utils.py (#1256) 2024-01-22 15:32:29 +01:00
222d275b8a set dev version (#1254) 2024-01-19 11:58:47 +01:00
09ca7607d5 Release: v0.7.10 (#1253) 2024-01-19 11:52:51 +01:00
1e68753216 fix: fix loss_type and some args desc (#1247) 2024-01-18 17:20:52 +01:00
1f59eeb9bb Fix chatml template (#1248)
* first draft

* 64

* sourabs suggestion

* wip tests

* make style happy

* add check

* docstring

* fix docstring

* Update tests/test_model_utils.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* move tests

* add todo for abstract class

* make style happy

* add slow tests and imports

* add documentation

* update sft_trainer.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* fix template & add test

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-01-18 16:47:25 +01:00
928d14445e Add setup_chat_format for adding new special tokens to model for training chat models (#1242)
* first draft

* 64

* sourabs suggestion

* wip tests

* make style happy

* add check

* docstring

* fix docstring

* Update tests/test_model_utils.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* move tests

* add todo for abstract class

* make style happy

* add slow tests and imports

* add documentation

* update sft_trainer.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-01-18 11:05:32 +01:00
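
The helper added here is typically used as in the following minimal sketch (gpt2 is only a placeholder base model):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Adds the ChatML special tokens, sets the chat template and resizes the embeddings.
model, tokenizer = setup_chat_format(model, tokenizer)
```
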
3319993bd1 Fix weird doc bug (#1244)
* Update utils.py

* Update trl/trainer/utils.py

* Update trl/trainer/utils.py
2024-01-18 10:48:56 +01:00
4fb3d0c860 Update sft_trainer.py (#1241) 2024-01-17 15:16:07 +01:00
bcccdeb6f9 [core / SFTTrainer] Fix breaking change (#1229)
* fix breaking change

* revert

* fix

* final fix

* fix

* fix tests
2024-01-17 14:45:22 +01:00
ef209e311f [core / tests ] v1 slow tests (#1218)
* v1 slow tests

* nit

* add qlora tests for DPO

* add decorator

* release memory + log reports

* report to none to avoid seg fault issues

* update setup

* fix

* add exampel testing

* fix nit

* change temp filename

* add workflow file

* fix comment

* add slack push script

* more tests for DPO

* add dpo example tests

* another makefile command

* fix

* add paths + clean up

* nit

* Update slow-tests.yml

* trigger tests

* up

* up

* more fixes

* fix

* final fixes

* minor fixes

* oops

* add more text

* fix

* more

* trigger CI

* up

* fix

* remove

* run the tests on 2 GPUs only

* final fix SFT

* revert config files + address comments

* fix

* add Phi

* final fixes

* final fix
2024-01-17 10:17:57 +01:00
341f6a6787 fix: improve error message when pad_token_id is not configured (#1152)
* fix: improve error message when `pad_token_id` is not configured

* Add test for error raised when pad_token is None

* Fix pre-commit errors

* Fix error in the test environment
2024-01-17 09:34:20 +01:00
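
The configuration the improved error message asks for usually amounts to something like this sketch:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model

if tokenizer.pad_token_id is None:
    # Common workaround for models that ship without a dedicated padding token.
    tokenizer.pad_token = tokenizer.eos_token
```
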
97b9fa212a Update dpo_trainer.py (#1160)
Log metrics on all distributed processes
2024-01-15 15:40:44 +01:00
a7d796c9a2 Remove a repeating line in how_to_train.md (#1226) 2024-01-15 15:18:49 +01:00
fa074e6a15 Create slow-tests.yml (#1223) 2024-01-12 09:29:57 +01:00
776939dcc4 Add support for ChatML dataset format in (#1208)
* Add support for ChatML dataset format in SFTTrainer

* fix formatting

* fix tests

* more comment

* fix intent

* fix doc string

* Update dataset_formatting.py

* Update dataset_formatting.py

* add documentation

* Update sft_trainer.mdx

* add leonardos comment and more tests

* added more tests and fixed batching

* style

* comment in
2024-01-12 08:05:32 +01:00
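
A sketch of the conversational record layout this formatting support targets; the field values are made up:

```python
# One row of a "messages"-style dataset that the dataset formatting support
# can map onto a chat template automatically.
example = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What does TRL stand for?"},
        {"role": "assistant", "content": "Transformer Reinforcement Learning."},
    ]
}
```
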
163ca9f059 Refactor RewardConfig to own module (#1221)
* Refactor RewardConfig to own module

* Fix init

* Fix import
2024-01-12 17:50:37 +11:00
2eeb7b04cf [core / Docker] Add workflow to build TRL docker images (#1215)
* add docker build

* Update docker/trl-latest-gpu/Dockerfile

* Update docker/trl-source-gpu/Dockerfile
2024-01-11 11:03:43 +01:00
9f8d0e48ad Fix args type (#1214)
* fix args type

* add args desc
2024-01-10 16:35:19 +01:00
c9b7145c75 Update Unsloth SFT, DPO docs (#1213)
* Update sft_trainer.mdx

* Update sft_trainer.mdx

* Update dpo_trainer.mdx

* Update dpo_trainer.mdx

* Update sft_trainer.mdx
2024-01-10 09:08:08 +01:00
baf3c1c293 Fix FSDP error (#1196)
* Fix FSDP error

Fixes an error where, when the `loss` field of the model output is non-empty, indexing as [0] returns the loss instead of the logits. This can happen with FSDP.

* Apply suggestions from code review

force return_dict

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2024-01-09 18:21:23 +01:00
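
A small sketch of the pitfall the fix guards against: when a loss is present in the output, positional indexing returns the loss rather than the logits, so forcing a dict-style output and reading `.logits` is safer:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
batch = tokenizer("hello world", return_tensors="pt")

out = model(**batch, labels=batch["input_ids"], return_dict=True)
first = out[0]       # the loss, not the logits, because labels were passed
logits = out.logits  # unambiguous access, which is what the fix relies on
```
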
b181e401a7 Fix shape descriptions in calculate_loss method (#1204) 2024-01-09 14:24:41 +01:00
26da9e80cb Check tokenize params on DPOTrainer (#1197)
* Check if tokenizer and max len params are None

* Update warning messages for missing parameters
2024-01-09 14:10:22 +01:00
d6cc88ab2c set dev version (#1207) 2024-01-09 13:06:30 +01:00
7a95cc8696 release: v0.7.9 (#1206) 2024-01-09 13:02:31 +01:00
d1715514de Revert "Address issue #1122 (#1174)" (#1205)
This reverts commit d57d0f9ca46a63d370b91791352edda0154576f5.
2024-01-09 10:20:50 +01:00
d116887ed4 [DPOTrainer] Fix peft + DPO + bf16 if one uses generate_during_eval or pre-computed logits (#1203)
* fix peft + DPO + bf16

* fix

* revert old behaviour

* fix tests

* fix

* fix

* fix

* fix
2024-01-09 09:35:50 +01:00
a236c5750f Fix reported KL in PPO trainer (#1180)
* Fix reported KL in PPO trainer

Previously this always reported the estimated KL, even when using `kl_penalty = 'full'` (or `abs`, etc.).
Now we return the actual KL calculated in `compute_rewards()` and report that.

* fix test
2024-01-09 06:48:25 +01:00
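
A rough sketch, illustrative only, of the two KL flavours involved: the sampled-token estimator that was previously always logged, and the full KL over the vocabulary that `kl_penalty='full'` computes:

```python
import torch
import torch.nn.functional as F

def kl_estimate(logprobs: torch.Tensor, ref_logprobs: torch.Tensor) -> torch.Tensor:
    # Per-token estimator using only the sampled tokens: log p(x) - log q(x).
    return logprobs - ref_logprobs

def kl_full(logits: torch.Tensor, ref_logits: torch.Tensor) -> torch.Tensor:
    # Exact KL(p || q) over the full vocabulary at each position.
    logp = F.log_softmax(logits, dim=-1)
    logq = F.log_softmax(ref_logits, dim=-1)
    return (logp.exp() * (logp - logq)).sum(-1)
```
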
4ae35afdd6 Fix instruction token masking (#1185)
* Fix instruction token masking

Fix instruction token masking if the first instruction is tokenized differently than the others, or in general if no instruction is detected before the first response.

* Bugfix for edge case

(in case either of the templates isn't found at all, ...idxs[0] might not exist)

* Add test for instruction masking fix
2024-01-09 06:41:53 +01:00
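
For context, a minimal sketch of the collator whose masking logic is being fixed; the instruction/response templates are placeholders:

```python
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
tokenizer.pad_token = tokenizer.eos_token

# Labels outside the response spans are masked out; the fix above covers the case
# where the first instruction tokenizes differently from the later ones.
collator = DataCollatorForCompletionOnlyLM(
    response_template="### Assistant:",
    instruction_template="### Human:",
    tokenizer=tokenizer,
)
```
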
b21ed0ddbc set dev version (#1201) 2024-01-09 05:19:10 +01:00
384b868fe6 Release: v0.7.8 (#1200) 2024-01-09 05:13:26 +01:00
113 changed files with 7992 additions and 1494 deletions

.github/workflows/docker-build.yml (new file)

@ -0,0 +1,127 @@
name: Build Docker images (scheduled)
on:
workflow_dispatch:
workflow_call:
schedule:
- cron: "0 1 * * *"
concurrency:
group: docker-image-builds
cancel-in-progress: false
env:
CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}
jobs:
trl-latest:
name: "Latest TRL GPU"
runs-on: ubuntu-latest
steps:
- name: Cleanup disk
run: |
sudo ls -l /usr/local/lib/
sudo ls -l /usr/share/
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
sudo rm -rf /usr/local/lib/android
sudo rm -rf /usr/share/dotnet
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Check out code
uses: actions/checkout@v3
- name: Login to DockerHub
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push GPU
uses: docker/build-push-action@v4
with:
context: ./docker/trl-latest-gpu
push: true
tags: huggingface/trl-latest-gpu
- name: Post to a Slack channel
id: slack
#uses: slackapi/slack-github-action@v1.25.0
uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
with:
# Slack channel id, channel name, or user id to post message.
# See also: https://api.slack.com/methods/chat.postMessage#channels
channel-id: ${{ env.CI_SLACK_CHANNEL }}
# For posting a rich message using Block Kit
payload: |
{
"text": "trl-latest-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
"blocks": [
{
"type": "section",
"text": {
"type": "mrkdwn",
"text": "trl-latest-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
}
}
]
}
env:
SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
trl-source:
name: "Latest TRL + HF ecosystem from source"
runs-on: ubuntu-latest
steps:
- name: Cleanup disk
run: |
sudo ls -l /usr/local/lib/
sudo ls -l /usr/share/
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
sudo rm -rf /usr/local/lib/android
sudo rm -rf /usr/share/dotnet
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Check out code
uses: actions/checkout@v3
- name: Login to DockerHub
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push GPU
uses: docker/build-push-action@v4
with:
context: ./docker/trl-source-gpu
push: true
tags: huggingface/trl-source-gpu
- name: Post to a Slack channel
id: slack
#uses: slackapi/slack-github-action@v1.25.0
uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
with:
# Slack channel id, channel name, or user id to post message.
# See also: https://api.slack.com/methods/chat.postMessage#channels
channel-id: ${{ env.CI_SLACK_CHANNEL }}
# For posting a rich message using Block Kit
payload: |
{
"text": "trl-source-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
"blocks": [
{
"type": "section",
"text": {
"type": "mrkdwn",
"text": "trl-source-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
}
}
]
}
env:
SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

.github/workflows/slow-tests.yml (new file)

@ -0,0 +1,96 @@
name: Slow tests (on push)
on:
push:
branches: [ main ]
paths:
# Run only when python files are modified
- "trl/**.py"
- "examples/**.py"
env:
RUN_SLOW: "yes"
IS_GITHUB_CI: "1"
SLACK_API_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
jobs:
run_all_tests_single_gpu:
strategy:
fail-fast: false
matrix:
docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"]
runs-on: [self-hosted, single-gpu, nvidia-gpu, t4, ci]
env:
CUDA_VISIBLE_DEVICES: "0"
TEST_TYPE: "single_gpu_${{ matrix.docker-image-name }}"
container:
image: ${{ matrix.docker-image-name }}
options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true
defaults:
run:
shell: bash
steps:
- uses: actions/checkout@v3
- name: Pip install
run: |
source activate trl
pip install -e ".[test]" --no-deps
pip install pytest-reportlog parameterized
- name: Run slow SFT tests on single GPU
if: always()
run: |
source activate trl
make slow_tests
- name: Generate Report
if: always()
run: |
pip install slack_sdk tabulate
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
run_all_tests_multi_gpu:
strategy:
fail-fast: false
matrix:
docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"]
runs-on: [self-hosted, multi-gpu, nvidia-gpu, t4, ci]
env:
CUDA_VISIBLE_DEVICES: "0,1"
TEST_TYPE: "multi_gpu_${{ matrix.docker-image-name }}"
container:
image: ${{ matrix.docker-image-name }}
options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true
defaults:
run:
shell: bash
steps:
- uses: actions/checkout@v3
- name: Pip install
run: |
source activate trl
pip install -e ".[test]" --no-deps
pip install pytest-reportlog parameterized
- name: Run slow SFT tests on Multi GPU
if: always()
run: |
source activate trl
make slow_tests
- name: Run end-to-end examples tests on multi GPU
if: always()
run: |
source activate trl
pip install deepspeed
make test_examples
- name: Generate Reports
if: always()
run: |
pip install slack_sdk tabulate
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
python scripts/log_example_reports.py --text_file_name temp_results_sft_tests.txt >> $GITHUB_STEP_SUMMARY
python scripts/log_example_reports.py --text_file_name temp_results_dpo_tests.txt >> $GITHUB_STEP_SUMMARY
rm *.txt

.github/workflows/tests-main.yml (new file)

@ -0,0 +1,63 @@
name: tests on transformers PEFT main
on:
push:
branches: [ main ]
env:
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
jobs:
tests:
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11']
os: ['ubuntu-latest', 'windows-latest']
fail-fast: false
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: "pip"
cache-dependency-path: |
setup.py
requirements.txt
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# install PEFT & transformers from source
pip install -U git+https://github.com/huggingface/peft.git
pip install -U git+https://github.com/huggingface/transformers.git
# cpu version of pytorch
pip install ".[test, diffusers]"
- name: Test with pytest
run: |
make test
- name: Post to a Slack channel
if: always()
id: slack
#uses: slackapi/slack-github-action@v1.25.0
uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
with:
# Slack channel id, channel name, or user id to post message.
# See also: https://api.slack.com/methods/chat.postMessage#channels
channel-id: ${{ env.CI_SLACK_CHANNEL }}
# For posting a rich message using Block Kit
payload: |
{
"text": "TRL CI on transformers/PEFT main: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
"blocks": [
{
"type": "section",
"text": {
"type": "mrkdwn",
"text": "TRL CI on transformers/PEFT main: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
}
}
]
}
env:
SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

.github/workflows/tests.yml

@ -5,6 +5,13 @@ on:
branches: [ main ]
pull_request:
branches: [ main ]
paths:
# Run only when relevant files are modified
- "trl/**.py"
- "examples/**.py"
- "scripts/**.py"
- ".github/**.yml"
- "tests/**.py"
jobs:
check_code_quality:
@ -47,7 +54,7 @@ jobs:
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install -e ".[test, peft, diffusers]"
pip install ".[test, peft, diffusers]"
- name: Test with pytest
run: |
make test
@ -72,4 +79,4 @@ jobs:
pip install .[test]
- name: Test with pytest
run: |
make test
make test

.pre-commit-config.yaml

@ -1,37 +1,10 @@
repos:
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.0
hooks:
- id: isort
args:
- --profile=black
- --skip-glob=wandb/**/*
- --thirdparty=wandb
- repo: https://github.com/myint/autoflake
rev: v1.4
hooks:
- id: autoflake
args:
- -r
- --exclude=wandb,__init__.py
- --in-place
- --remove-unused-variables
- --remove-all-unused-imports
- repo: https://github.com/python/black
rev: 22.3.0
hooks:
- id: black
args:
- --line-length=119
- --target-version=py38
- --exclude=wandb
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
hooks:
- id: flake8
args:
- --ignore=E203,E501,W503,E128
- --max-line-length=119
- id: ruff
args: [ --fix ]
- id: ruff-format
# - repo: https://github.com/codespell-project/codespell
# rev: v2.1.0

CONTRIBUTING.md

@ -5,7 +5,7 @@
Before you start contributing make sure you installed all the dev tools:
```bash
pip install -e ".[dev]"
make dev
```
## Did you find a bug?

MANIFEST.in

@ -2,4 +2,4 @@ include settings.ini
include LICENSE
include CONTRIBUTING.md
include README.md
recursive-exclude * __pycache__
recursive-exclude * __pycache__

Makefile

@ -1,7 +1,16 @@
.PHONY: test precommit benchmark_core benchmark_aux
.PHONY: test precommit benchmark_core benchmark_aux common_tests slow_tests test_examples tests_gpu
check_dirs := examples tests trl
ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
COMMAND_FILES_PATH = `pwd`/commands
dev:
[ -L "$(pwd)/trl/commands/scripts" ] && unlink "$(pwd)/trl/commands/scripts" || true
pip install -e ".[dev]"
ln -s `pwd`/examples/scripts/ `pwd`/trl/commands
test:
python -m pytest -n auto --dist=loadfile -s -v ./tests/
@ -13,3 +22,22 @@ benchmark_core:
benchmark_aux:
bash ./benchmark/benchmark_aux.sh
tests_gpu:
python -m pytest tests/test_* $(if $(IS_GITHUB_CI),--report-log "common_tests.log",)
slow_tests:
python -m pytest tests/slow/test_* $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
test_examples:
touch temp_results_sft_tests.txt
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_sft.sh; \
echo $$?','$${file} >> temp_results_sft_tests.txt; \
done
touch temp_results_dpo_tests.txt
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_dpo.sh; \
echo $$?','$${file} >> temp_results_dpo_tests.txt; \
done

README.md

@ -3,7 +3,7 @@
</div>
# TRL - Transformer Reinforcement Learning
> Full stack transformer language models with reinforcement learning.
> Full stack library to fine-tune and align large language models.
<p align="center">
<a href="https://github.com/huggingface/trl/blob/main/LICENSE">
@ -20,58 +20,70 @@
## What is it?
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/TRL-readme.png">
</div>
The `trl` library is a full stack tool to fine-tune and align transformer language and diffusion models using methods such as Supervised Fine-tuning step (SFT), Reward Modeling (RM) and the Proximal Policy Optimization (PPO) as well as Direct Preference Optimization (DPO).
`trl` is a full stack library where we provide a set of tools to train transformer language models and stable diffusion models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT), Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step. The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via `transformers`. At this point, most of decoder architectures and encoder-decoder architectures are supported. Refer to the documentation or the `examples/` folder for example code snippets and how to run these tools.
**Highlights:**
- [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer): A light and friendly wrapper around `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
- [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer): A light wrapper around `transformers` Trainer to easily fine-tune language models for human preferences (Reward Modeling).
- [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer): A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.
- [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead): A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
- [Examples](https://github.com/huggingface/trl/tree/main/examples): Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [Stack-Llama example](https://huggingface.co/blog/stackllama), etc.
## How PPO works
Fine-tuning a language model via PPO consists of roughly three steps:
1. **Rollout**: The language model generates a response or continuation based on query which could be the start of a sentence.
2. **Evaluation**: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.
This process is illustrated in the sketch below:
The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library and thus allows to use any model architecture available there.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png" width="800">
<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
</div>
## Highlights
- **`Efficient and scalable`**:
- [`accelerate`](https://github.com/huggingface/accelerate) is the backbone of `trl` which allows to scale model training from a single GPU to a large scale multi-node cluster with methods such as DDP and DeepSpeed.
- [`PEFT`](https://github.com/huggingface/peft) is fully integrated and allows to train even the largest models on modest hardware with quantisation and methods such as LoRA or QLoRA.
- [`unsloth`](https://github.com/unslothai/unsloth) is also integrated and allows to significantly speed up training with dedicated kernels.
- **`CLI`**: With the [CLI](https://huggingface.co/docs/trl/clis) you can fine-tune and chat with LLMs without writing any code using a single command and a flexible config system.
- **`Trainers`**: The Trainer classes are an abstraction to apply many fine-tuning methods with ease such as the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.DPOTrainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), and [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer).
- **`AutoModels`**: The [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead) classes add an additional value head to the model which allows to train them with RL algorithms such as PPO.
- **`Examples`**: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [StackLlama example](https://huggingface.co/blog/stackllama), etc. following the [examples](https://github.com/huggingface/trl/tree/main/examples).
## Installation
### Python package
Install the library with pip:
Install the library with `pip`:
```bash
pip install trl
```
### From source
If you want to run the examples in the repository a few additional libraries are required. Clone the repository and install it with pip:
If you want to use the latest features before an official release you can install from source:
```bash
pip install git+https://github.com/huggingface/trl.git
```
### Repository
If you want to use the examples you can clone the repository with the following command:
```bash
git clone https://github.com/huggingface/trl.git
cd trl/
pip install .
```
If you wish to develop TRL, you should install in editable mode:
## Command Line Interface (CLI)
You can use TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT), Direct Preference Optimization (DPO) and test your aligned model with the chat CLI:
**SFT:**
```bash
pip install -e .
trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir opt-sft-imdb
```
**DPO:**
```bash
trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/Anthropic-hh-rlhf-processed --output_dir opt-sft-hh-rlhf
```
**Chat:**
```bash
trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
```
Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.
## How to use
For more flexibility and control over the training you can use the dedicated trainer classes to fine-tune the model in Python.
### `SFTTrainer`
This is a basic example on how to use the `SFTTrainer` from the library. The `SFTTrainer` is a light wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
@ -138,11 +150,10 @@ model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = create_reference_model(model)
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
# initialize trainer
ppo_config = PPOConfig(
batch_size=1,
)
ppo_config = PPOConfig(batch_size=1, mini_batch_size=1)
# encode a query
query_txt = "This morning I went to the "
@ -162,13 +173,50 @@ reward = [torch.tensor(1.0)]
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```
### `DPOTrainer`
`DPOTrainer` is a trainer that uses [Direct Preference Optimization algorithm](https://arxiv.org/abs/2305.18290). This is a basic example on how to use the `DPOTrainer` from the library. The `DPOTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset.
```python
# imports
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOTrainer
# load model and dataset - dataset needs to be in a specific format
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
...
# load trainer
trainer = DPOTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
)
# train
trainer.train()
```
## Development
If you want to contribute to `trl` or customizing it to your needs make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and make sure you make a dev install:
```bash
git clone https://github.com/huggingface/trl.git
cd trl/
make dev
```
## References
### Proximal Policy Optimisation
The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].
### Language models
The language models utilize the `transformers` library by 🤗 Hugging Face.
### Direct Preference Optimization
DPO is based on the original implementation of **"Direct Preference Optimization: Your Language Model is Secretly a Reward Model"** by E. Mitchell et al. \[[paper](), [code](https://github.com/eric-mitchell/direct-preference-optimization)]
## Citation

View File

@ -1,20 +1,5 @@
#### Step 1: create a work directory:
# this is necessary because another github action job will remove
# the entire directory, which slurm depends on.
# https://stackoverflow.com/questions/4632028/how-to-create-a-temporary-directory
MY_SLURM_TMP_DIR=/fsx/costa/slurm_tmpdir
mkdir -p $MY_SLURM_TMP_DIR
WORK_DIR=`mktemp -d -p "$MY_SLURM_TMP_DIR"`
cp -r "$PWD" "$WORK_DIR"
cd "$WORK_DIR/$(basename "$PWD")"
echo WORK_DIR: $WORK_DIR
#### Step 2: actual work starts:
echo PATH is $PATH
echo PYTHONPATH is $PYTHONPATH
echo which python is $(which python)
export WANDB_ENTITY=huggingface
export WANDB_PROJECT=trl
bash $BENCHMARK_SCRIPT > output.txt
# Extract Job IDs into an array

View File

@ -1,6 +1,39 @@
# hello world experiment
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --log_with wandb" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
--slurm-nodes 1 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 12 \
--slurm-template-path benchmark/trl.slurm_template
python benchmark/benchmark.py \
--command "python examples/scripts/dpo.py --model_name_or_path=gpt2 --per_device_train_batch_size 4 --max_steps 1000 --learning_rate 1e-3 --gradient_accumulation_steps 1 --logging_steps 10 --eval_steps 500 --output_dir="dpo_anthropic_hh" --optim adamw_torch --warmup_steps 150 --report_to wandb --bf16 --logging_first_step --no_remove_unused_columns" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
--slurm-nodes 1 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 12 \
--slurm-template-path benchmark/trl.slurm_template
python benchmark/benchmark.py \
--command "python examples/scripts/sft.py --model_name_or_path="facebook/opt-350m" --report_to="wandb" --learning_rate=1.41e-5 --per_device_train_batch_size=64 --gradient_accumulation_steps=16 --output_dir="sft_openassistant-guanaco" --logging_steps=1 --num_train_epochs=3 --max_steps=-1 --push_to_hub --gradient_checkpointing" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
--slurm-nodes 1 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 12 \
--slurm-template-path benchmark/trl.slurm_template
python benchmark/benchmark.py \
--command "python examples/scripts/reward_modeling.py --model_name_or_path=facebook/opt-350m --output_dir="reward_modeling_anthropic_hh" --per_device_train_batch_size=64 --num_train_epochs=1 --gradient_accumulation_steps=16 --gradient_checkpointing=True --learning_rate=1.41e-5 --report_to="wandb" --remove_unused_columns=False --optim="adamw_torch" --logging_steps=10 --evaluation_strategy="steps" --max_length=512" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \

View File

@ -9,7 +9,37 @@ python -m openrlbenchmark.rlops_multi_metrics \
--no-check-empty-runs \
--pc.ncols 2 \
--pc.ncols-legend 1 \
--output-filename benchmark/trl/$FOLDER_STRING/hello_world \
--output-filename benchmark/trl/$FOLDER_STRING/ppo \
--scan-history
python -m openrlbenchmark.rlops_multi_metrics \
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=output_dir&cen=_name_or_path&metrics=train/rewards/accuracies&metrics=train/loss' \
"gpt2$TAGS_STRING" \
--env-ids dpo_anthropic_hh \
--no-check-empty-runs \
--pc.ncols 2 \
--pc.ncols-legend 1 \
--output-filename benchmark/trl/$FOLDER_STRING/dpo \
--scan-history
python -m openrlbenchmark.rlops_multi_metrics \
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=output_dir&cen=_name_or_path&metrics=train/loss&metrics=eval/accuracy&metrics=eval/loss' \
"facebook/opt-350m$TAGS_STRING" \
--env-ids reward_modeling_anthropic_hh \
--no-check-empty-runs \
--pc.ncols 2 \
--pc.ncols-legend 1 \
--output-filename benchmark/trl/$FOLDER_STRING/reward_modeling \
--scan-history
python -m openrlbenchmark.rlops_multi_metrics \
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=output_dir&cen=_name_or_path&metrics=train/loss' \
"facebook/opt-350m$TAGS_STRING" \
--env-ids sft_openassistant-guanaco \
--no-check-empty-runs \
--pc.ncols 2 \
--pc.ncols-legend 1 \
--output-filename benchmark/trl/$FOLDER_STRING/sft \
--scan-history
python benchmark/upload_benchmark.py \

View File

@ -1,6 +1,6 @@
# compound experiments: gpt2xl + grad_accu
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_gpt2xl_grad_accu --ppo_config.model_name gpt2-xl --ppo_config.mini_batch_size 16 --ppo_config.gradient_accumulation_steps 8 --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name ppo_gpt2xl_grad_accu --model_name gpt2-xl --mini_batch_size 16 --gradient_accumulation_steps 8 --log_with wandb" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
@ -12,7 +12,7 @@ python benchmark/benchmark.py \
# compound experiments: Cerebras-GPT-6.7B + deepspeed zero2 + grad_accu
python benchmark/benchmark.py \
--command "accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml examples/scripts/ppo.py --ppo_config.exp_name ppo_Cerebras-GPT-6.7B_grad_accu_deepspeed_stage2 --ppo_config.batch_size 32 --ppo_config.mini_batch_size 32 --ppo_config.log_with wandb --ppo_config.model_name cerebras/Cerebras-GPT-6.7B --ppo_config.reward_model sentiment-analysis:cerebras/Cerebras-GPT-6.7B" \
--command "accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml examples/scripts/ppo.py --exp_name ppo_Cerebras-GPT-6.7B_grad_accu_deepspeed_stage2 --batch_size 32 --mini_batch_size 32 --log_with wandb --model_name cerebras/Cerebras-GPT-6.7B --reward_model sentiment-analysis:cerebras/Cerebras-GPT-6.7B" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \

View File

@ -1,6 +1,6 @@
## w/ and w/o gradient accumulation
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_step_grad_accu --ppo_config.mini_batch_size 1 --ppo_config.gradient_accumulation_steps 128 --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name ppo_step_grad_accu --mini_batch_size 1 --gradient_accumulation_steps 128 --log_with wandb" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
@ -12,7 +12,7 @@ python benchmark/benchmark.py \
## w/ different models (gpt2, gpt2-xl, falcon, llama2)
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_gpt2 --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name ppo_gpt2 --log_with wandb" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
@ -22,7 +22,7 @@ python benchmark/benchmark.py \
--slurm-total-cpus 12 \
--slurm-template-path benchmark/trl.slurm_template
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_falcon_rw_1b --ppo_config.model_name tiiuae/falcon-rw-1b --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name ppo_falcon_rw_1b --model_name tiiuae/falcon-rw-1b --log_with wandb" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
@ -35,7 +35,7 @@ python benchmark/benchmark.py \
## w/ and w/o PEFT
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name ppo_peft --use_peft --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name ppo_peft --use_peft --log_with wandb" \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \

View File

@ -1,6 +1,6 @@
#!/bin/bash
#SBATCH --job-name=trl
#SBATCH --partition=production-cluster
#SBATCH --partition=hopper-cpu
#SBATCH --ntasks=1
#SBATCH --output=slurm/logs/%x_%j.out

View File

@ -0,0 +1,3 @@
BENCHMARK_SCRIPT="benchmark/benchmark_level1.sh" \
BENCHMARK_PLOT_SCRIPT="benchmark/benchmark_level1_plot.sh" \
bash benchmark/benchmark_and_report.sh

View File

@ -1,16 +1,19 @@
#!/bin/bash
#SBATCH --job-name=trl
#SBATCH --partition=production-cluster
#SBATCH --partition=hopper-prod
#SBATCH --gpus-per-task={{gpus_per_task}}
#SBATCH --cpus-per-gpu={{cpus_per_gpu}}
#SBATCH --ntasks={{ntasks}}
#SBATCH --output=slurm/logs/%x_%j.out
#SBATCH --array={{array}}
#SBATCH --exclude=ip-26-0-156-239,ip-26-0-148-151,ip-26-0-146-212,ip-26-0-145-137,ip-26-0-146-249,ip-26-0-146-149,ip-26-0-147-233,ip-26-0-145-154,ip-26-0-144-35,ip-26-0-144-189,ip-26-0-146-183,ip-26-0-147-120,ip-26-0-144-95,ip-26-0-145-193
##SBATCH --exclude=ip-26-0-149-199
module load cuda/12.1
{{nodes}}
seeds={{seeds}}
seed=${seeds[$SLURM_ARRAY_TASK_ID % {{len_seeds}}]}
echo "Running task $SLURM_ARRAY_TASK_ID with seed: $seed"
srun {{command}} --ppo_config.seed $seed
srun {{command}} --seed $seed

commands/run_dpo.sh Normal file
View File

@ -0,0 +1,58 @@
#!/bin/bash
# This script runs a DPO example end-to-end on a tiny model using different possible configurations
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_dpo/"
MODEL_NAME="HuggingFaceM4/tiny-random-LlamaForCausalLM"
DATASET_NAME="trl-internal-testing/Anthropic-hh-rlhf-processed"
MAX_STEPS=5
BATCH_SIZE=2
SEQ_LEN=128
# Handle extra arguments in case one passes accelerate configs.
EXTRA_ACCELERATE_ARGS=""
EXTRA_TRAINING_ARGS="""--use_peft \
--load_in_4bit
"""
# Set your number of GPUs here
NUM_GPUS=2
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
EXTRA_ACCELERATE_ARGS=""
else
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
# For DeepSpeed configs we need to set the `--fp16` flag to comply with the configs exposed
# in `examples/accelerate_configs`, since our runners do not support bf16 mixed precision training.
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
EXTRA_TRAINING_ARGS="--fp16"
else
echo "Keeping QLoRA + PEFT"
fi
fi
CMD="""
accelerate launch $EXTRA_ACCELERATE_ARGS \
--num_processes $NUM_GPUS \
--mixed_precision 'fp16' \
`pwd`/examples/scripts/dpo.py \
--model_name_or_path $MODEL_NAME \
--dataset_name $DATASET_NAME \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BATCH_SIZE \
--max_length $SEQ_LEN \
$EXTRA_TRAINING_ARGS
"""
echo "Starting program..."
{ # try
echo $CMD
eval "$CMD"
} || { # catch
# save log for exception
echo "Operation Failed!"
exit 1
}
exit 0

commands/run_sft.sh Normal file
View File

@ -0,0 +1,59 @@
#!/bin/bash
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_sft/"
MODEL_NAME="HuggingFaceM4/tiny-random-LlamaForCausalLM"
DATASET_NAME="imdb"
MAX_STEPS=5
BATCH_SIZE=2
SEQ_LEN=128
# Handle extra arguments in case one passes accelerate configs.
EXTRA_ACCELERATE_ARGS=""
EXTRA_TRAINING_ARGS="""--use_peft \
--load_in_4bit
"""
# Set your number of GPUs here
NUM_GPUS=2
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
EXTRA_ACCELERATE_ARGS=""
else
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
# For DeepSpeed configs we need to set the `--fp16` flag to comply with the configs exposed
# in `examples/accelerate_configs`, since our runners do not support bf16 mixed precision training.
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
EXTRA_TRAINING_ARGS="--fp16"
else
echo "Keeping QLoRA + PEFT"
fi
fi
CMD="""
accelerate launch $EXTRA_ACCELERATE_ARGS \
--num_processes $NUM_GPUS \
--mixed_precision 'fp16' \
`pwd`/examples/scripts/sft.py \
--model_name_or_path $MODEL_NAME \
--dataset_name $DATASET_NAME \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BATCH_SIZE \
--max_seq_length $SEQ_LEN \
$EXTRA_TRAINING_ARGS
"""
echo "Starting program..."
{ # try
echo $CMD
eval "$CMD"
} || { # catch
# save log for exception
echo "Operation Failed!"
exit 1
}
exit 0

View File

@ -0,0 +1,66 @@
# Builds GPU docker image of PyTorch
# Uses multi-staged approach to reduce size
# Stage 1
# Use base conda image to reduce time
FROM continuumio/miniconda3:latest AS compile-image
# Specify py version
ENV PYTHON_VERSION=3.10
# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
RUN apt-get update && \
apt-get install -y curl git wget software-properties-common git-lfs && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
# Install audio-related libraries
RUN apt-get update && \
apt install -y ffmpeg
RUN apt install -y libsndfile1-dev
RUN git lfs install
# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
RUN python3 -m pip install --no-cache-dir --upgrade pip
# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
# We don't install pytorch here yet since CUDA isn't available
# instead we use the direct torch wheel
ENV PATH /opt/conda/envs/trl/bin:$PATH
# Activate our bash shell
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]
# Stage 2
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
COPY --from=compile-image /opt/conda /opt/conda
ENV PATH /opt/conda/bin:$PATH
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]
RUN source activate trl && \
python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
# Install apt libs
RUN apt-get update && \
apt-get install -y curl git wget && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
# Activate the conda env and install transformers + accelerate from source
RUN source activate trl && \
python3 -m pip install -U --no-cache-dir \
librosa \
"soundfile>=0.12.1" \
scipy \
transformers \
accelerate \
peft \
trl[test]@git+https://github.com/huggingface/trl
RUN source activate trl && \
pip freeze | grep trl
RUN echo "source activate trl" >> ~/.profile
# Activate the virtualenv
CMD ["/bin/bash"]

View File

@ -0,0 +1,66 @@
# Builds GPU docker image of PyTorch
# Uses multi-staged approach to reduce size
# Stage 1
# Use base conda image to reduce time
FROM continuumio/miniconda3:latest AS compile-image
# Specify py version
ENV PYTHON_VERSION=3.10
# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
RUN apt-get update && \
apt-get install -y curl git wget software-properties-common git-lfs && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
# Install audio-related libraries
RUN apt-get update && \
apt install -y ffmpeg
RUN apt install -y libsndfile1-dev
RUN git lfs install
# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
RUN python3 -m pip install --no-cache-dir --upgrade pip
# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
# We don't install pytorch here yet since CUDA isn't available
# instead we use the direct torch wheel
ENV PATH /opt/conda/envs/trl/bin:$PATH
# Activate our bash shell
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]
# Stage 2
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
COPY --from=compile-image /opt/conda /opt/conda
ENV PATH /opt/conda/bin:$PATH
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]
RUN source activate trl && \
python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
# Install apt libs
RUN apt-get update && \
apt-get install -y curl git wget && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
# Activate the conda env and install transformers + accelerate from source
RUN source activate trl && \
python3 -m pip install -U --no-cache-dir \
librosa \
"soundfile>=0.12.1" \
scipy \
git+https://github.com/huggingface/transformers \
git+https://github.com/huggingface/accelerate \
git+https://github.com/huggingface/peft \
trl[test]@git+https://github.com/huggingface/trl
RUN source activate trl && \
pip freeze | grep transformers
RUN echo "source activate trl" >> ~/.profile
# Activate the virtualenv
CMD ["/bin/bash"]

View File

@ -5,6 +5,8 @@
title: Quickstart
- local: installation
title: Installation
- local: clis
title: Get started with Command Line Interfaces (CLIs)
- local: how_to_train
title: PPO Training FAQ
- local: use_model
@ -29,6 +31,8 @@
title: Best of N Sampling
- local: dpo_trainer
title: DPO Trainer
- local: kto_trainer
title: KTO Trainer
- local: ddpo_trainer
title: Denoising Diffusion Policy Optimization
- local: iterative_sft_trainer

docs/source/clis.mdx Normal file
View File

@ -0,0 +1,109 @@
# Command Line Interfaces (CLIs)
You can use TRL to fine-tune your language model with Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO), or even chat with your model, using the TRL CLIs.
Currently supported CLIs are:
- `trl sft`: fine-tune a LLM on a text/instruction dataset
- `trl dpo`: fine-tune a LLM with DPO on a preference dataset
- `trl chat`: quickly spin up a LLM fine-tuned for chatting
## Fine-tuning with the CLI
Before getting started, pick a language model from the Hugging Face Hub. Supported models can be found with the "text-generation" filter on the models page. Also make sure to pick a relevant dataset for your task.
Before using the `sft` or `dpo` commands make sure to run:
```bash
accelerate config
```
and select the right configuration for your training setup (single / multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of `accelerate config` before running any CLI command.
We also recommend passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with the `trl sft` command.
```yaml
model_name_or_path:
HuggingFaceM4/tiny-random-LlamaForCausalLM
dataset_name:
imdb
dataset_text_field:
text
report_to:
none
learning_rate:
0.0001
lr_scheduler_type:
cosine
```
Save that config in a `.yaml` file and get started directly! Note that you can override the arguments from the config file by explicitly passing them to the CLI, e.g.:
```bash
trl sft --config example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts
```
This will force the use of `cosine_with_restarts` for `lr_scheduler_type`.
### Supported Arguments
We support all arguments from `transformers.TrainingArguments`. For loading your model, we support all arguments from `~trl.ModelConfig`:
[[autodoc]] ModelConfig
You can pass any of these arguments either to the CLI or the YAML file.
### Supervised Fine-tuning (SFT)
Follow the basic instructions above and run `trl sft --output_dir <output_dir> <*args>`:
```bash
trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir opt-sft-imdb
```
The SFT CLI is based on the `examples/scripts/sft.py` script.
### Direct Preference Optimization (DPO)
First, follow the basic instructions above and run `trl dpo --output_dir <output_dir> <*args>`. Make sure to process your DPO dataset in the TRL format as follows:
1- Make sure to pre-tokenize the dataset using chat templates:
```bash
python examples/datasets/tokenize_ds.py --model gpt2 --dataset yourdataset
```
You might need to adapt the `examples/datasets/tokenize_ds.py` script to use your chat template.
2- Format the dataset into TRL format (you can adapt the `examples/datasets/anthropic_hh.py`):
```bash
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
```
Once your dataset has been pushed, run the DPO CLI as follows:
```bash
trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/Anthropic-hh-rlhf-processed --output_dir opt-sft-hh-rlhf
```
The DPO CLI is based on the `examples/scripts/dpo.py` script.
## Chat interface
The chat CLI lets you quickly load the model and talk to it. Simply run the following:
```bash
trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
```
Note that the chat interface relies on the chat template of the tokenizer to format the inputs for the model. Make sure your tokenizer has a chat template defined.
Besides talking to the model there are a few commands you can use:
- **clear**: clears the current conversation and start a new one
- **example {NAME}**: load example named `{NAME}` from the config and use it as the user input
- **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';').
- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
- **save {SAVE_NAME} (optional)**: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
- **exit**: closes the interface
The default examples are defined in `examples/scripts/config/default_chat_config.yaml` but you can pass your own with `--config CONFIG_FILE` where you can also specify the default generation parameters.

View File

@ -2,9 +2,24 @@
TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) by Rafailov et al., 2023. For a full example have a look at [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py).
The first step as always is to train your SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.
## How DPO works
Fine-tuning a language model via DPO consists of two steps and is easier than PPO:
1. **Data collection**: Gather a preference dataset with positive and negative selected pairs of generations, given a prompt.
2. **Optimization**: Directly optimize the DPO loss on the preference data.
DPO-compatible datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo).
This process is illustrated in the sketch below (from [figure 1 of the original paper](https://arxiv.org/pdf/2305.18290.pdf)):
<img width="835" alt="Screenshot 2024-03-19 at 12 39 41" src="https://github.com/huggingface/trl/assets/49240599/9150fac6-3d88-4ca2-8ec6-2a6f3473216d">
Read more about DPO algorithm in the [original paper](https://arxiv.org/pdf/2305.18290.pdf).
## Expected dataset format
The DPO trainer expects a very specific format for the dataset, since the model will be trained to directly optimize the preference for which sentence is the most relevant, given two sentences. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
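The concrete `Anthropic/hh-rlhf` sample is elided from this diff; as a sketch of the expected structure, each example carries a prompt plus a chosen and a rejected completion (the values below are illustrative, not taken from the dataset):
```python
dpo_dataset_dict = {
    "prompt": ["hello", "how are you"],
    "chosen": ["hi, nice to meet you", "I am fine"],
    "rejected": ["leave me alone", "not so good"],
}
```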
@ -63,7 +78,7 @@ The DPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that
For a detailed example have a look at the `examples/scripts/dpo.py` script. At a high level we need to initialize the `DPOTrainer` with a `model` we wish to train, a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response, the `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
```py
dpo_trainer = DPOTrainer(
dpo_trainer = DPOTrainer(
model,
model_ref,
args=training_args,
@ -86,11 +101,11 @@ Given the preference data, we can fit a binary classifier according to the Bradl
The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://arxiv.org/abs/2305.10425) paper. The `DPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument and the `beta` in this case is the reciprocal of the margin.
The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer.
The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer. Note that the `beta` parameter is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gap is. As per the paper, the loss is averaged over the log-likelihoods of the completion (unlike DPO, which only sums them).
The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability that can be passed to the `DPOTrainer` via `label_smoothing` argument (between 0 and 0.5) and then a conservative DPO loss is used. Use the `loss_type="cdpo"` argument to the trainer to use it.
The [KTO](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf) loss is derived to directly maximize the utility of LLM generations instead of the log-likelihood of preferences. Thus the dataset are not necessarily preferences but rather desirable vs undesirable completions. For paired preference data as required by the `DPOTrainer`, use the `loss_type="kto_pair"` argument to the trainer to utilize this loss, while for the more general case of desired and undesirable data, use the as of yet unimplemented `KTOTrainer`.
The [KTO](https://arxiv.org/abs/2402.01306) authors directly maximize the utility of LLM generations instead of the log-likelihood of preferences. To use preference data with KTO, we recommend breaking up the n preferences into 2n examples and using [`KTOTrainer`](kto_trainer) (i.e., treating the data like an unpaired feedback dataset). Although it is possible to pass in `loss_type="kto_pair"` into DPOTrainer, this is a highly simplified version of KTO that we *do not recommend* in most cases. Please use [`KTOTrainer`](kto_trainer) when possible.
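As a hedged sketch of how these variants are selected (the model, dataset, and hyperparameter values below are illustrative, not a reference configuration), the loss is switched by passing `loss_type` when constructing the trainer:
```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
model_ref = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# tiny illustrative preference dataset
train_dataset = Dataset.from_dict({
    "prompt": ["hello", "how are you"],
    "chosen": ["hi, nice to meet you", "I am fine"],
    "rejected": ["leave me alone", "not so good"],
})

training_args = TrainingArguments(
    output_dir="./dpo-ipo",
    per_device_train_batch_size=2,
    max_steps=1,
    remove_unused_columns=False,
)

dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=0.1,
    loss_type="ipo",  # or e.g. "hinge", "kto_pair"; see the descriptions above
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
dpo_trainer.train()
```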
## Logging
@ -103,44 +118,46 @@ While training and evaluating we record the following reward metrics:
## Accelerate DPO fine-tuning using `unsloth`
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) and even full-finetuning (1.1x faster) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is compatible with `DPOTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama as well) and Mistral architectures.
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth#installation-instructions---conda). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLlamaModel` or `FastMistralModel` as follows:
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `DPOTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek, etc.) and Mistral architectures. Some benchmarks for DPO are listed below:
| GPU | Model | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
|----------|-----------------|-----------|------|------------------------|-----------------|----------------|
| A100 40G | Zephyr 7b | Ultra Chat| 1x | 1.24x | **1.88x** | -11.6% |
| Tesla T4 | Zephyr 7b | Ultra Chat| 1x | 1.09x | **1.55x** | -18.6% |
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
```python
import torch
from transformers import TrainingArguments
from trl import DPOTrainer
from unsloth import FastLlamaModel, FastMistralModel
from unsloth import FastLanguageModel
max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number.
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
# Load Llama model
model, tokenizer = FastLlamaModel.from_pretrained(
model_name = "unsloth/llama-2-7b", # Supports any llama model eg meta-llama/Llama-2-7b-hf
# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/zephyr-sft",
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
dtype = None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True, # Use 4bit quantization to reduce memory usage. Can be False.
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
# Do model patching and add fast LoRA weights
model = FastLlamaModel.get_peft_model(
model = FastLanguageModel.get_peft_model(
model,
r = 16,
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",],
lora_alpha = 16,
lora_dropout = 0, # Currently only supports dropout = 0
bias = "none", # Currently only supports bias = "none"
lora_dropout = 0, # Dropout = 0 is currently optimized
bias = "none", # Bias = "none" is currently optimized
use_gradient_checkpointing = True,
random_state = 3407,
max_seq_length = max_seq_length,
)
args = TrainingArguments(output_dir="./output")
training_args = TrainingArguments(output_dir="./output")
dpo_trainer = DPOTrainer(
model,
@ -150,7 +167,6 @@ dpo_trainer = DPOTrainer(
train_dataset=train_dataset,
tokenizer=tokenizer,
)
dpo_trainer.train()
```
@ -166,15 +182,13 @@ You have three main options (plus several variants) for how the reference model
### Downsides to merging QLoRA before DPO (approach 2)
As suggested by [Tim Dettmers](https://twitter.com/Tim_Dettmers/status/1694654191325573456), the best option for merging QLoRA adapters is to first quantize the base model, merge the adapter, then convert back to bf16. Something similar to [this script](https://github.com/jondurbin/qlora/blob/main/qmerge.py)
As suggested by [Benjamin Marie](https://medium.com/@bnjmn_marie/dont-merge-your-lora-adapter-into-a-4-bit-llm-65b6da287997), the best option for merging QLoRA adapters is to first dequantize the base model, then merge the adapter. Something similar to [this script](https://github.com/jondurbin/qlora/blob/main/qmerge.py).
You can also just merge the adapters the standard way without quantizing the base model, but then you have 1-2% reduced performance (and evidently, more issues with empty responses).
If you use the recommended approach, which quantizes the model, you're now in a situation where to use QLoRA for DPO, you will need to re-quantize the merged model again or use an unquantized merge with lower overall performance.
However, after using this approach, you will have an unquantized base model. Therefore, to use QLoRA for DPO, you will need to re-quantize the merged model or use the unquantized merge (resulting in higher memory demand).
### Using option 3 - load the adapter twice
To avoid the downsides with option 2, at the expense of slightly increased VRAM, you can load your fine-tuned adapter into the model twice, with different names, and set the model/ref adapter names in DPOTrainer.
To avoid the downsides with option 2, you can load your fine-tuned adapter into the model twice, with different names, and set the model/ref adapter names in DPOTrainer.
For example:
```python
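# NOTE: the concrete example is elided from this diff; the lines below are a sketch,
# not the exact docs listing. The model id, adapter path, and the `training_args`,
# `train_dataset`, `tokenizer` objects are illustrative placeholders.
from transformers import AutoModelForCausalLM
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", load_in_4bit=True)

# load the same fine-tuned adapter twice, under two different names
model.load_adapter("/path/to/sft/adapter", adapter_name="train")
model.load_adapter("/path/to/sft/adapter", adapter_name="reference")

# no separate ref_model: the "reference" adapter plays that role
dpo_trainer = DPOTrainer(
    model,
    ref_model=None,
    model_adapter_name="train",
    ref_adapter_name="reference",
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
```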

View File

@ -30,7 +30,6 @@ If you generate text by purely sampling from the model distribution things work
- **top-k sampling**: the model can smooth out the probability distribution, causing the top-k tokens to have a smaller probability than those of the reference model, but they still get selected
- **min_length**: this ignores the EOS token until `min_length` is reached. thus the model can assign a very low log prob to the EOS token and very high probs to all others until min_length is reached
- **min_length**: this ignores the EOS token until `min_length` is reached, thus the model can assign a very low log prob to the EOS token and very high probs to all others until min_length is reached
These are just a few examples. Why is negative KL an issue? The total reward `R` is computed `R = r - beta * KL` so if the model can learn how to drive KL-divergence negative it effectively gets a positive reward. In many cases it can be much easier to exploit such a bug in the generation than actually learning the reward function. In addition the KL can become arbitrarily small thus the actual reward can be very small compared to it.

View File

@ -0,0 +1,93 @@
# KTO Trainer
TRL supports the Kahneman-Tversky Optimization (KTO) Trainer for aligning language models with binary feedback data (e.g., upvote/downvote), as described in the [paper](https://arxiv.org/abs/2402.01306) by Kawin Ethayarajh, Winnie Xu, Niklas Muennighoff, Dan Jurafsky, and Douwe Kiela.
For a full example have a look at [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py).
Depending on how good your base model is, you may or may not need to do SFT before KTO.
This is different from standard RLHF and DPO, which always require SFT.
## Expected dataset format
The KTO trainer expects a very specific format for the dataset as it does not require pairwise preferences. Since the model will be trained to directly optimize examples that consist of a prompt, model completion, and a label to indicate whether the completion is "good" or "bad", we expect a dataset with the following columns:
- `prompt`
- `completion`
- `label`
for example:
```
kto_dataset_dict = {
"prompt": [
"Hey, hello",
"How are you",
"What is your name?",
"What is your name?",
"Which is the best programming language?",
"Which is the best programming language?",
"Which is the best programming language?",
],
"completion": [
"hi nice to meet you",
"leave me alone",
"I don't have a name",
"My name is Mary",
"Python",
"C++",
"Java",
],
"label": [
True,
False,
False,
True,
True,
False,
False,
],
}
```
where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired (`True`) or undesired (`False`).
A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
## Expected model format
The KTO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
## Using the `KTOTrainer`
For a detailed example have a look at the `examples/scripts/kto.py` script. At a high level we need to initialize the `KTOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.
The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
The `desirable_weight` and `undesirable_weight` refer to the weights placed on the losses for desirable/positive and undesirable/negative examples.
By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` * number of positives) to (`undesirable_weight` * number of negatives) is in the range 1:1 to 4:3. For example, with 100 desirable and 300 undesirable completions, keeping `undesirable_weight=1.0` and setting `desirable_weight` to roughly 3.0 to 4.0 keeps this ratio in the recommended range.
```py
training_args = KTOConfig(
beta=0.1,
desirable_weight=1.0,
undesirable_weight=1.0,
)
kto_trainer = KTOTrainer(
model,
model_ref,
args=training_args,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
```
After this one can then call:
```py
kto_trainer.train()
```
## KTOTrainer
[[autodoc]] KTOTrainer
## KTOConfig
[[autodoc]] KTOConfig

View File

@ -71,7 +71,7 @@ The `trl` library is powered by `accelerate`. As such it is best to configure an
```bash
accelerate config # will prompt you to define the training configuration
accelerate launch scripts/gpt2-sentiment_peft.py # launches training
accelerate launch examples/scripts/ppo.py --use_peft # launches training
```
## Using `trl` + `peft` and Data Parallelism
@ -140,5 +140,5 @@ python PATH_TO_SCRIPT
You can easily fine-tune Llama2 model using `SFTTrainer` and the official script! For example to fine-tune llama2-7b on the Guanaco dataset, run (tested on a single NVIDIA T4-16GB):
```bash
python examples/scripts/sft.py --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --batch_size 4 --gradient_accumulation_steps 2
python examples/scripts/sft.py --output_dir sft_openassistant-guanaco --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --per_device_train_batch_size 4 --gradient_accumulation_steps 2
```

View File

@ -1,6 +1,6 @@
# Multi Adapter RL (MARL) - a single base model for everything
Here we present an approach that uses a single base model for the entire PPO algorithm - which includes retrieving the reference logits, computing the active logits and the rewards. This feature is experimental as we did not tested the convergence of the approach. We encourage the community to let us know if they potentially face into any issue.
Here we present an approach that uses a single base model for the entire PPO algorithm - which includes retrieving the reference logits, computing the active logits and the rewards. This feature is experimental as we did not test the convergence of the approach. We encourage the community to let us know if they potentially face issues.
## Requirements
@ -48,7 +48,7 @@ trainer = PPOTrainer(
...
```
Then inside your PPO training loop, call the `compute_reward_score` method by accessing to the `model` attribute from `PPOTrainer`.
Then inside your PPO training loop, call the `compute_reward_score` method by accessing the `model` attribute from `PPOTrainer`.
```python
rewards = trainer.model.compute_reward_score(**inputs)
@ -58,8 +58,8 @@ rewards = trainer.model.compute_reward_score(**inputs)
### Control on the adapter name
If you are familiar with the `peft` library, you know that you can use multiple adapters inside the same model. What you can do is to train multiple adapters on the same base model to fine-tune on different policies.
In this case, you want to have a control on the adapter name you want to activate back, after retrieving the reward. For that, simply pass the appropriate `adapter_name` to `ppo_adapter_name` argument when calling `compute_reward_score`.
If you are familiar with the `peft` library, you know that you can use multiple adapters inside the same model. What you can do is train multiple adapters on the same base model to fine-tune on different policies.
In this case, you want to be able to control the adapter name you want to activate back, after retrieving the reward. For that, simply pass the appropriate `adapter_name` to `ppo_adapter_name` argument when calling `compute_reward_score`.
```python
adapter_name_policy_1 = "policy_1"
@ -97,4 +97,4 @@ trainer = PPOTrainer(
...
)
...
```
```

View File

@ -4,6 +4,21 @@ TRL supports the [PPO](https://arxiv.org/abs/1707.06347) Trainer for training la
The first step is to train your SFT model (see the [SFTTrainer](sft_trainer)), to ensure the data we train on is in-distribution for the PPO algorithm. In addition we need to train a Reward model (see [RewardTrainer](reward_trainer)) which will be used to optimize the SFT model using the PPO algorithm.
## How PPO works
Fine-tuning a language model via PPO consists of roughly three steps:
1. **Rollout**: The language model generates a response or continuation based on a query, which could be the start of a sentence.
2. **Evaluation**: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.
This process is illustrated in the sketch below:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png" width="800">
<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
</div>
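As a rough sketch of one iteration of this loop (assuming a `ppo_trainer` and `tokenizer` built as shown later on this page, a dataset that provides a `query` column, and some `reward_pipe` text-classification pipeline standing in for the reward model):
```python
import torch

for batch in ppo_trainer.dataloader:
    query_tensors = batch["input_ids"]

    # Rollout: generate responses from the current policy
    response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, max_new_tokens=32)
    batch["response"] = tokenizer.batch_decode(response_tensors)

    # Evaluation: score each query/response pair with any reward function
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    rewards = [torch.tensor(output["score"]) for output in reward_pipe(texts)]

    # Optimization: run a PPO step on the queries, responses and scalar rewards
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)
```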
## Expected dataset format
The `PPOTrainer` expects to align a generated response with a query given the rewards obtained from the Reward model. During each step of the PPO algorithm we sample a batch of prompts from the dataset, then use these prompts to generate responses from the SFT model. Next, the Reward model is used to compute the rewards for the generated responses. Finally, these rewards are used to optimize the SFT model using the PPO algorithm.
@ -90,7 +105,7 @@ from trl import PPOTrainer
ppo_trainer = PPOTrainer(
model=model,
config=config,
train_dataset=train_dataset,
dataset=dataset,
tokenizer=tokenizer,
)
```

View File

@ -30,7 +30,7 @@ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
# 2. initialize trainer
ppo_config = {"batch_size": 1}
ppo_config = {"mini_batch_size": 1, "batch_size": 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)

View File

@ -25,7 +25,7 @@ accelerate launch examples/scripts/ppo.py # launches training
# 3. get help text and documentation
python examples/scripts/ppo.py --help
# 4. configure logging with wandb and, say, mini_batch_size=1 and gradient_accumulation_steps=16
python examples/scripts/ppo.py --ppo_config.log_with wandb --ppo_config.mini_batch_size 1 --ppo_config.gradient_accumulation_steps 16
python examples/scripts/ppo.py --log_with wandb --mini_batch_size 1 --gradient_accumulation_steps 16
```
Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
@ -42,7 +42,7 @@ Below are some benchmark results for `examples/scripts/ppo.py`. To reproduce loc
```bash
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --log_with wandb" \
--num-seeds 5 \
--start-seed 1 \
--workers 10 \
@ -61,7 +61,7 @@ python benchmark/benchmark.py \
```bash
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_step_grad_accu --ppo_config.mini_batch_size 1 --ppo_config.gradient_accumulation_steps 128 --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_step_grad_accu --mini_batch_size 1 --gradient_accumulation_steps 128 --log_with wandb" \
--num-seeds 5 \
--start-seed 1 \
--workers 10 \
@ -79,7 +79,7 @@ python benchmark/benchmark.py \
```bash
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_gpt2 --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_gpt2 --log_with wandb" \
--num-seeds 5 \
--start-seed 1 \
--workers 10 \
@ -89,7 +89,7 @@ python benchmark/benchmark.py \
--slurm-total-cpus 12 \
--slurm-template-path benchmark/trl.slurm_template
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_gpt2xl_grad_accu --ppo_config.model_name gpt2-xl --ppo_config.mini_batch_size 16 --ppo_config.gradient_accumulation_steps 8 --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_gpt2xl_grad_accu --model_name gpt2-xl --mini_batch_size 16 --gradient_accumulation_steps 8 --log_with wandb" \
--num-seeds 5 \
--start-seed 1 \
--workers 10 \
@ -99,7 +99,7 @@ python benchmark/benchmark.py \
--slurm-total-cpus 12 \
--slurm-template-path benchmark/trl.slurm_template
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_falcon_rw_1b --ppo_config.model_name tiiuae/falcon-rw-1b --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_falcon_rw_1b --model_name tiiuae/falcon-rw-1b --log_with wandb" \
--num-seeds 5 \
--start-seed 1 \
--workers 10 \
@ -116,7 +116,7 @@ python benchmark/benchmark.py \
```
python benchmark/benchmark.py \
--command "python examples/scripts/ppo.py --ppo_config.exp_name sentiment_tuning_peft --use_peft --ppo_config.log_with wandb" \
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_peft --use_peft --log_with wandb" \
--num-seeds 5 \
--start-seed 1 \
--workers 10 \

View File

@ -154,6 +154,72 @@ response_template_ids = tokenizer.encode(response_template_with_context, add_spe
data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
```
### Add Special Tokens for Chat Format
Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
The [`setup_chat_format`] function in `trl` easily sets up a model and tokenizer for conversational AI tasks. This function:
- Adds special tokens to the tokenizer, e.g. `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation.
- Resizes the model's embedding layer to accommodate the new tokens.
- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI.
- _optionally_ you can pass `resize_to_multiple_of` to resize the embedding layer to a multiple of the `resize_to_multiple_of` argument, e.g. 64. If you want to see more formats being supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl)
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format
# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
# Set up the chat format with default 'chatml' format
model, tokenizer = setup_chat_format(model, tokenizer)
```
With our model and tokenizer set up, we can now fine-tune our model on a conversational dataset. Below is an example of how a dataset can be formatted for fine-tuning.
### Dataset format support
The [`SFTTrainer`] supports popular dataset formats. This allows you to pass the dataset directly to the trainer without any pre-processing. The following formats are supported:
* conversational format
```json
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "..."}]}
```
* instruction format
```json
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
```
If your dataset uses one of the above formats, you can directly pass it to the trainer without pre-processing. The [`SFTTrainer`] will then format the dataset for you using the defined format from the model's tokenizer with the [apply_chat_template](https://huggingface.co/docs/transformers/main/en/chat_templating#templates-for-chat-models) method.
```python
from datasets import load_dataset
from trl import SFTTrainer
...
# load jsonl dataset
dataset = load_dataset("json", data_files="path/to/dataset.jsonl", split="train")
# load dataset from the HuggingFace Hub
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")
...
trainer = SFTTrainer(
"facebook/opt-350m",
args=training_args,
train_dataset=dataset,
packing=True,
)
```
If the dataset is not in one of those formats, you can either preprocess the dataset to match the format or pass a formatting function to the `SFTTrainer` to do it for you. Let's have a look.
### Format your input prompts
For instruction fine-tuning, it is quite common to have two columns inside the dataset: one for the prompt & the other for the response.
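The full docs example is elided from this diff; the sketch below (the column names and the tiny in-memory dataset are illustrative) shows the general pattern of passing a `formatting_func` that builds one training text per row:
```python
from datasets import Dataset
from trl import SFTTrainer

# tiny in-memory dataset with "question"/"answer" columns (illustrative)
dataset = Dataset.from_dict({
    "question": ["What is the capital of France?", "Who wrote Hamlet?"],
    "answer": ["Paris", "William Shakespeare"],
})

def formatting_prompts_func(example):
    # build one training text per row from the two columns
    output_texts = []
    for i in range(len(example["question"])):
        output_texts.append(f"### Question: {example['question'][i]}\n### Answer: {example['answer'][i]}")
    return output_texts

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    max_seq_length=512,
)
trainer.train()
```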
@ -185,7 +251,7 @@ trainer = SFTTrainer(
trainer.train()
```
To preperly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example on how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
To properly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example on how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
### Packing dataset ([`ConstantLengthDataset`])
@ -205,6 +271,7 @@ trainer.train()
```
Note that if you use a packed dataset and if you pass `max_steps` in the training arguments, you will probably train your models for more than a few epochs, depending on the way you have configured the packed dataset and the training protocol. Double check that you know and understand what you are doing.
If you don't want to pack your `eval_dataset`, you can pass `eval_packing=False` to the `SFTTrainer` init method.
#### Customize your prompts using packed dataset
@ -346,11 +413,11 @@ Note that you cannot train your model using Flash Attention 1 on an arbitrary da
Below are some numbers you can get in terms of speedup and memory efficiency, using Flash Attention 1, on a single NVIDIA-T4 16GB.
| use_flash_attn_1 | model_name | max_seq_len | batch_size | time per training step |
|----------------|-------------------|-------------|------------|------------------------|
| x | facebook/opt-350m | 2048 | 8 | ~59.1s |
| | facebook/opt-350m | 2048 | 8 | **OOM** |
| x | facebook/opt-350m | 2048 | 4 | ~30.3s |
| | facebook/opt-350m | 2048 | 4 | ~148.9s |
| ---------------- | ----------------- | ----------- | ---------- | ---------------------- |
| x | facebook/opt-350m | 2048 | 8 | ~59.1s |
| | facebook/opt-350m | 2048 | 8 | **OOM** |
| x | facebook/opt-350m | 2048 | 4 | ~30.3s |
| | facebook/opt-350m | 2048 | 4 | ~148.9s |
### Using Flash Attention-2
@ -360,13 +427,13 @@ To use Flash Attention 2, first install the latest `flash-attn` package:
pip install -U flash-attn
```
And add `use_flash_attention_2=True` when calling `from_pretrained`:
And add `attn_implementation="flash_attention_2"` when calling `from_pretrained`:
```python
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_4bit=True,
use_flash_attention_2=True
attn_implementation="flash_attention_2"
)
```
@ -375,6 +442,45 @@ After loading your model, you can either train it as it is, or attach adapters a
In contrary to Flash Attention 1, the integration makes it possible to train your model on an arbitrary dataset that also includes padding tokens.
### Using model creation utility
We included a utility function to create your model.
[[autodoc]] ModelConfig
```python
import torch
from transformers import AutoModelForCausalLM
from trl import ModelConfig, SFTTrainer, get_kbit_device_map, get_peft_config, get_quantization_config
model_config = ModelConfig(
model_name_or_path="facebook/opt-350m",
attn_implementation=None, # or "flash_attention_2"
)
torch_dtype = (
model_config.torch_dtype
if model_config.torch_dtype in ["auto", None]
else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
trust_remote_code=model_config.trust_remote_code,
attn_implementation=model_config.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
trainer = SFTTrainer(
...,
model=model_config.model_name_or_path,
peft_config=get_peft_config(model_config),
)
```
### Enhance model's performances using NEFTune
NEFTune is a technique to boost the performance of chat models and was introduced by the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://arxiv.org/abs/2310.05914) from Jain et al. It consists of adding noise to the embedding vectors during training. According to the abstract of the paper:
@ -413,44 +519,48 @@ Note however, that the amount of performance gain is _dataset dependent_ and in
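The abstract quotation and the usage listing are elided from this diff; as a minimal sketch (the dataset and the noise value are illustrative), NEFTune is enabled through a single `SFTTrainer` argument:
```python
from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
    neftune_noise_alpha=5,  # add noise to the embeddings during training (NEFTune)
)
trainer.train()
```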
### Accelerate fine-tuning 2x using `unsloth`
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) and even full-finetuning (1.1x faster) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama as well) and Mistral architectures.
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth#installation-instructions---conda). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLlamaModel` or `FastMistralModel` as follows:
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks on 1x A100 listed below:
| 1 A100 40GB | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
|-----------------|-----------|-----|-------------------------|-----------------|----------------|
| Code Llama 34b | Slim Orca | 1x | 1.01x | **1.94x** | -22.7% |
| Llama-2 7b | Slim Orca | 1x | 0.96x | **1.87x** | -39.3% |
| Mistral 7b | Slim Orca | 1x | 1.17x | **1.88x** | -65.9% |
| Tiny Llama 1.1b | Alpaca | 1x | 1.55x | **2.74x** | -57.8% |
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
```python
import torch
from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import FastLlamaModel, FastMistralModel
from unsloth import FastLanguageModel
max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number.
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number
# Load Llama model
model, tokenizer = FastLlamaModel.from_pretrained(
model_name = "unsloth/llama-2-7b", # Supports any llama model eg meta-llama/Llama-2-7b-hf
# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/mistral-7b",
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
dtype = None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True, # Use 4bit quantization to reduce memory usage. Can be False
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
# Do model patching and add fast LoRA weights
model = FastLlamaModel.get_peft_model(
model = FastLanguageModel.get_peft_model(
model,
r = 16,
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",],
lora_alpha = 16,
lora_dropout = 0, # Currently only supports dropout = 0
bias = "none", # Currently only supports bias = "none"
lora_dropout = 0, # Dropout = 0 is currently optimized
bias = "none", # Bias = "none" is currently optimized
use_gradient_checkpointing = True,
random_state = 3407,
max_seq_length = max_seq_length,
)
args = TrainingArguments(output_dir="./output")
args = TrainingArguments(output_dir = "./output")
trainer = SFTTrainer(
model = model,
@ -459,7 +569,6 @@ trainer = SFTTrainer(
dataset_text_field = "text",
max_seq_length = max_seq_length,
)
trainer.train()
```
@ -474,6 +583,20 @@ Pay attention to the following best practices when training a model with that tr
- For more memory-efficient training with adapters, you can load the base model in 8bit: simply add the `load_in_8bit` argument when creating the [`SFTTrainer`], or create the base model in 8bit outside the trainer and pass it in.
- If you create a model outside the trainer, make sure not to pass the trainer any additional keyword arguments that are meant for the `from_pretrained()` method.
## Multi-GPU Training
Trainer (and thus SFTTrainer) supports multi-GPU training. If you run your script with `python script.py`, it will default to using DP as the strategy, which may be [slower than expected](https://github.com/huggingface/trl/issues/1303). To use DDP (which is generally recommended, see [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many?select-gpu=Accelerate#data-parallelism) for more info), you must launch the script with `python -m torch.distributed.launch script.py` or `accelerate launch script.py`. For DDP to work you must also check the following:
- If you're using gradient_checkpointing, add the following to the TrainingArguments: `gradient_checkpointing_kwargs={'use_reentrant': False}` (more info [here](https://github.com/huggingface/transformers/issues/26969)); a short sketch follows the snippet below.
- Ensure that the model is placed on the correct device:
```python
from accelerate import PartialState
from transformers import AutoModelForCausalLM

device_string = PartialState().process_index
model = AutoModelForCausalLM.from_pretrained(
    ...,  # your usual from_pretrained arguments
    device_map={'': device_string},
)
```
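As a concrete sketch of the first point above, the non-reentrant flag is simply forwarded through `TrainingArguments` (the output directory below is a placeholder):

```python
from transformers import TrainingArguments

# Non-reentrant gradient checkpointing avoids the DDP issue referenced above.
training_args = TrainingArguments(
    output_dir="./output",  # placeholder
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```

You would then launch the training script with `accelerate launch script.py` (or `accelerate launch --num_processes <num_gpus> script.py`) so that one process is spawned per GPU.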
## GPTQ Conversion
You may experience some issues with GPTQ quantization after completing training. Lowering `gradient_accumulation_steps` to `4` resolves most issues encountered when quantizing the trained model to GPTQ format.

20
example_config.yaml Normal file
View File

@ -0,0 +1,20 @@
# This is an example configuration file of TRL CLI, you can use it for
# SFT like that: `trl sft --config config.yaml --output_dir test-sft`
# The YAML file supports environment variables by adding an `env` field
# as below
# env:
# CUDA_VISIBLE_DEVICES: 0
model_name_or_path:
  HuggingFaceM4/tiny-random-LlamaForCausalLM
dataset_name:
  imdb
dataset_text_field:
  text
report_to:
  none
learning_rate:
  0.0001
lr_scheduler_type:
  cosine

View File

@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: "NO"
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,122 @@
import multiprocessing
import sys
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from huggingface_hub import HfApi
from huggingface_hub.repocard import RepoCard
from transformers import HfArgumentParser
"""
# debug
python -i examples/datasets/anthropic_hh.py --debug --push_to_hub
# actual push
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity trl-internal-testing
"""
api = HfApi()
@dataclass
class ScriptArguments:
debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
hf_entity: Optional[str] = field(default=None, metadata={"help": "The Hugging Face entity to use"})
hf_repo_id: Optional[str] = field(default="hh-rlhf-trl-style", metadata={"help": "The Hugging Face repository ID"})
revision: Optional[str] = field(default="0.1.0", metadata={"help": "The revision of the repository"})
update_main_revision: Optional[bool] = field(
default=True, metadata={"help": "Update the main revision of the repository"}
)
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the dataset to the Hugging Face Hub"})
# GPT-4 generated 😄 Define a function to process the input and extract the dialogue into structured format
def extract_dialogue(input_text):
# Split the input by lines and initialize variables
lines = input_text.strip().split("\n\n")
dialogue_list = []
# Iterate through each line and extract the dialogue
for line in lines:
# Check if the line starts with "Human" or "Assistant" and split accordingly
if line.startswith("Human:"):
role = "user"
content = line.replace("Human: ", "").strip()
elif line.startswith("Assistant:"):
role = "assistant"
content = line.replace("Assistant: ", "").strip()
else:
# If the line doesn't start with "Human" or "Assistant", it's part of the previous message's content
# Append it to the last message's content
dialogue_list[-1]["content"] += "\n\n" + line.strip()
continue
# Append the extracted dialogue piece to the list
dialogue_list.append({"role": role, "content": content})
return dialogue_list
if __name__ == "__main__":
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
if args.hf_entity is None:
args.hf_entity = api.whoami()["name"]
full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
ds = load_dataset("Anthropic/hh-rlhf")
if args.debug:
for key in ds:
ds[key] = ds[key].select(range(50))
def process(row):
row["chosen"] = extract_dialogue(row["chosen"])
row["rejected"] = extract_dialogue(row["rejected"])
row["prompt"] = row["chosen"][0]["content"]
return row
ds = ds.map(
process,
num_proc=1 if args.debug else multiprocessing.cpu_count(),
load_from_cache_file=False,
)
if args.push_to_hub:
revisions = ["main"] if args.update_main_revision else []
revisions.append(args.revision)
# get the command used to run the script
run_command = " ".join(["python"] + sys.argv)
for revision in revisions:
ds.push_to_hub(full_repo_id, revision=revision)
repo_full_url = f"https://huggingface.co/datasets/{full_repo_id}/tree/{revision}"
# get the name of the current file
file_name = __file__.split("/")[-1]
api.upload_file(
path_or_fileobj=__file__,
path_in_repo=file_name,
revision=revision,
repo_id=full_repo_id,
repo_type="dataset",
)
sft_card = RepoCard.load(
full_repo_id,
repo_type="dataset",
)
sft_card.text = f"""\
# TRL's Anthropic HH Dataset
We preprocess the dataset using our standard `prompt, chosen, rejected` format.
## Reproduce this dataset
1. Download the `{file_name}` from the {repo_full_url}.
2. Run `{run_command}`
"""
sft_card.push_to_hub(
full_repo_id,
repo_type="dataset",
)
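As a quick, made-up illustration of what `extract_dialogue` returns (assuming the function defined in the script above and an input in the Anthropic HH `Human:`/`Assistant:` style):

```python
# Hypothetical transcript, following the "Human:" / "Assistant:" convention of Anthropic/hh-rlhf.
sample = (
    "Human: How do I boil an egg?\n\n"
    "Assistant: Put the egg in boiling water for about seven minutes.\n\n"
    "Human: Thanks!"
)
print(extract_dialogue(sample))
# [{'role': 'user', 'content': 'How do I boil an egg?'},
#  {'role': 'assistant', 'content': 'Put the egg in boiling water for about seven minutes.'},
#  {'role': 'user', 'content': 'Thanks!'}]
```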

View File

@ -0,0 +1,113 @@
import multiprocessing
import sys
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from huggingface_hub import HfApi
from huggingface_hub.repocard import RepoCard
from transformers import HfArgumentParser
"""
# debug
python -i examples/datasets/tldr_preference.py --debug --push_to_hub
# actual push
python examples/datasets/tldr_preference.py --push_to_hub --hf_entity trl-internal-testing
"""
api = HfApi()
@dataclass
class ScriptArguments:
debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
hf_entity: Optional[str] = field(default=None, metadata={"help": "The Hugging Face entity to use"})
hf_repo_id: Optional[str] = field(
default="tldr-preference-trl-style", metadata={"help": "The Hugging Face repository ID"}
)
revision: Optional[str] = field(default="0.1.0", metadata={"help": "The revision of the repository"})
update_main_revision: Optional[bool] = field(
default=True, metadata={"help": "Update the main revision of the repository"}
)
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the dataset to the Hugging Face Hub"})
if __name__ == "__main__":
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
if args.hf_entity is None:
args.hf_entity = api.whoami()["name"]
full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
ds = load_dataset("openai/summarize_from_feedback", "comparisons")
if args.debug:
for key in ds:
ds[key] = ds[key].select(range(50))
cnndm_batches = ["batch0_cnndm", "cnndm0", "cnndm2"]
if not args.debug:
ds["validation_cnndm"] = ds["validation"].filter(lambda x: x["batch"] in cnndm_batches)
ds["validation"] = ds["validation"].filter(lambda x: x["batch"] not in cnndm_batches)
tldr_format_str = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:"
cnndm_format_str = "Article:\n{article}\n\nTL;DR:"
def process(row):
format_str = cnndm_format_str if row["batch"] in cnndm_batches else tldr_format_str
row["prompt"] = format_str.format(**row["info"])
choice = row["choice"]
chosen = row["summaries"][choice]["text"]
rejected = row["summaries"][1 - choice]["text"]
row["chosen"] = [{"role": "user", "content": row["prompt"]}, {"role": "assistant", "content": chosen}]
row["rejected"] = [{"role": "user", "content": row["prompt"]}, {"role": "assistant", "content": rejected}]
return row
ds = ds.map(
process,
num_proc=1 if args.debug else multiprocessing.cpu_count(),
load_from_cache_file=False,
)
for key in ds: # reorder columns
ds[key] = ds[key].select_columns(
["prompt", "chosen", "rejected", "info", "summaries", "choice", "worker", "batch", "split", "extra"]
)
if args.push_to_hub:
revisions = ["main"] if args.update_main_revision else []
revisions.append(args.revision)
# get the command used to run the script
run_command = " ".join(["python"] + sys.argv)
for revision in revisions:
ds.push_to_hub(full_repo_id, revision=revision)
repo_full_url = f"https://huggingface.co/datasets/{full_repo_id}/tree/{revision}"
# get the name of the current file
file_name = __file__.split("/")[-1]
api.upload_file(
path_or_fileobj=__file__,
path_in_repo=file_name,
revision=revision,
repo_id=full_repo_id,
repo_type="dataset",
)
sft_card = RepoCard.load(
full_repo_id,
repo_type="dataset",
)
sft_card.text = f"""\
# TRL's TL;DR Preference Dataset
We preprocess the dataset using our standard `prompt, chosen, rejected` format.
## Reproduce this dataset
1. Download the `{file_name}` from the {repo_full_url}.
2. Run `{run_command}`
"""
sft_card.push_to_hub(
full_repo_id,
repo_type="dataset",
)
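For reference, the prompt construction above is plain `str.format`; a minimal, made-up `info` record (not taken from the dataset) rendered with `tldr_format_str` looks like this:

```python
tldr_format_str = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:"

# Illustrative record only; real rows come from openai/summarize_from_feedback.
info = {
    "subreddit": "cooking",
    "title": "Overcooked my rice",
    "post": "I left the rice on the stove too long and now it is mush.",
}
print(tldr_format_str.format(**info))
```

The chosen and rejected summaries are then appended as `assistant` turns after this prompt, mirroring the `prompt, chosen, rejected` convention used for the Anthropic HH dataset above.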

View File

@ -0,0 +1,42 @@
import multiprocessing
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from transformers import AutoTokenizer, HfArgumentParser
"""
python -i examples/datasets/tokenize_ds.py --debug --model HuggingFaceH4/zephyr-7b-beta
python -i examples/datasets/tokenize_ds.py --debug --model gpt2
"""
@dataclass
class ScriptArguments:
debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
dataset: str = field(default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The dataset to load"})
model: str = field(default="gpt2", metadata={"help": "The model to use for tokenization"})
if __name__ == "__main__":
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
ds = load_dataset(args.dataset)
if args.debug:
for key in ds:
ds[key] = ds[key].select(range(50))
tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.chat_template is None:
tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{{ eos_token }}"
def process(row):
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
return row
ds = ds.map(
process,
num_proc=1 if args.debug else multiprocessing.cpu_count(),
load_from_cache_file=False,
)
print(ds["train"][0]["chosen"])

View File

@ -12,7 +12,7 @@ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
# 2. initialize trainer
ppo_config = {"batch_size": 1}
ppo_config = {"mini_batch_size": 1, "batch_size": 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
@ -29,7 +29,7 @@ generation_kwargs = {
"pad_token_id": tokenizer.eos_token_id,
"max_new_tokens": 20,
}
response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs)
response_tensor = ppo_trainer.generate(list(query_tensor), return_prompt=False, **generation_kwargs)
response_txt = tokenizer.decode(response_tensor[0])
# 5. define a reward for response

View File

@ -15,6 +15,7 @@ from transformers import (
Trainer,
TrainerCallback,
TrainingArguments,
set_seed,
)
from transformers.utils import PaddingStrategy
@ -89,11 +90,14 @@ class ScriptArguments:
default=False,
metadata={"help": "Whether to run eval after the first step"},
)
seed: Optional[int] = field(
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
)
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
set_seed(script_args.seed)
# Load the human stack-exchange-paired dataset for tuning the reward model.
train_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/reward", split="train")
if script_args.train_subset > 0:
@ -129,7 +133,10 @@ training_args = TrainingArguments(
logging_steps=10,
optim=script_args.optim,
lr_scheduler_type=script_args.lr_scheduler_type,
seed=script_args.seed,
)
# Load the value-head model and tokenizer.
tokenizer_name = script_args.tokenizer_name if script_args.tokenizer_name is not None else script_args.model_name
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True)

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -67,6 +66,7 @@ class ScriptArguments:
)
adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
load_in_8bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 8bit"})
parser = HfArgumentParser(ScriptArguments)
@ -163,7 +163,7 @@ dataset = build_dataset(tokenizer)
def collator(data):
return dict((key, [d[key] for d in data]) for key in data[0])
return {key: [d[key] for d in data] for key in data[0]}
# set seed before initializing value head for deterministic eval
@ -181,7 +181,7 @@ lora_config = LoraConfig(
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
config.model_name,
load_in_8bit=True,
load_in_8bit=script_args.load_in_8bit,
device_map={"": current_device},
peft_config=lora_config,
)
@ -216,11 +216,13 @@ sentiment_pipe = pipeline(
"sentiment-analysis",
model=reward_model_name,
device_map={"": current_device},
model_kwargs={"load_in_8bit": True},
model_kwargs={"load_in_8bit": script_args.load_in_8bit},
tokenizer=tokenizer,
return_token_type_ids=False,
)
if sentiment_pipe.model.config.pad_token_id is None:
sentiment_pipe.model.config.pad_token_id = sentiment_pipe.model.config.eos_token_id
# We then define the arguments to pass to the `generate` function. These arguments
# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
# the `generate` function of the trained model.

View File

@ -4,9 +4,10 @@ from dataclasses import dataclass, field
from typing import Dict, Optional
import torch
from accelerate import Accelerator
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed
from trl import DPOTrainer
@ -41,6 +42,10 @@ class ScriptArguments:
default=True, metadata={"help": "whether to use gradient checkpointing"}
)
gradient_checkpointing_use_reentrant: Optional[bool] = field(
default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"}
)
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
@ -54,6 +59,10 @@ class ScriptArguments:
output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})
load_in_4bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 4bit"})
model_dtype: Optional[str] = field(
default="float16", metadata={"help": "model_dtype[float16, bfloat16, float] for loading."}
)
# instrumentation
sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
@ -73,12 +82,15 @@ class ScriptArguments:
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)
seed: Optional[int] = field(
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
)
def get_stack_exchange_paired(
data_dir: str = "data/rl",
sanity_check: bool = False,
cache_dir: str = None,
cache_dir: Optional[str] = None,
num_proc=24,
) -> Dataset:
"""Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.
@ -123,12 +135,21 @@ if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
set_seed(script_args.seed)
# 1. load a pretrained model
torch_dtype = torch.float
if script_args.model_dtype == "float16":
torch_dtype = torch.float16
elif script_args.model_dtype == "bfloat16":
torch_dtype = torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
load_in_4bit=True,
torch_dtype=torch_dtype,
load_in_4bit=script_args.load_in_4bit,
device_map={"": Accelerator().local_process_index},
)
model.config.use_cache = False
@ -138,12 +159,6 @@ if __name__ == "__main__":
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
]
model_ref = AutoModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
load_in_4bit=True,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token
@ -181,6 +196,8 @@ if __name__ == "__main__":
bf16=True,
remove_unused_columns=False,
run_name="dpo_llama2",
gradient_checkpointing_kwargs=dict(use_reentrant=script_args.gradient_checkpointing_use_reentrant),
seed=script_args.seed,
)
peft_config = LoraConfig(
@ -203,7 +220,7 @@ if __name__ == "__main__":
# 5. initialize the DPO trainer
dpo_trainer = DPOTrainer(
model,
model_ref,
ref_model=None,
args=training_args,
beta=script_args.beta,
train_dataset=train_dataset,

View File

@ -8,7 +8,14 @@ from accelerate import Accelerator
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM, LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
HfArgumentParser,
TrainingArguments,
set_seed,
)
from trl import SFTTrainer
from trl.import_utils import is_npu_available, is_xpu_available
@ -27,6 +34,7 @@ class ScriptArguments:
seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"})
num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"})
use_bnb: Optional[bool] = field(default=True, metadata={"help": "whether to use BitsAndBytes"})
# LoraConfig
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
@ -53,6 +61,8 @@ if training_args.group_by_length and script_args.packing:
if training_args.gradient_checkpointing:
raise ValueError("gradient_checkpointing not supported")
set_seed(training_args.seed)
def chars_token_ratio(dataset, tokenizer, nb_examples=400):
"""
@ -91,7 +101,7 @@ def prepare_sample_text(example):
return text
def create_datasets(tokenizer, args):
def create_datasets(tokenizer, args, seed=None):
dataset = load_dataset(
args.dataset_name,
data_dir=args.subset,
@ -104,9 +114,9 @@ def create_datasets(tokenizer, args):
print("Loading the dataset in streaming mode")
valid_data = dataset.take(args.size_valid_set)
train_data = dataset.skip(args.size_valid_set)
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=None)
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=seed)
else:
dataset = dataset.train_test_split(test_size=0.005, seed=None)
dataset = dataset.train_test_split(test_size=0.005, seed=seed)
train_data = dataset["train"]
valid_data = dataset["test"]
print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
@ -133,11 +143,13 @@ def create_datasets(tokenizer, args):
return train_dataset, valid_dataset
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
bnb_config = None
if script_args.use_bnb:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
base_model = AutoModelForCausalLM.from_pretrained(
script_args.model_name,
@ -153,7 +165,7 @@ tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_c
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training
train_dataset, eval_dataset = create_datasets(tokenizer, script_args)
train_dataset, eval_dataset = create_datasets(tokenizer, script_args, seed=training_args.seed)
trainer = SFTTrainer(
model=base_model,

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -107,7 +106,7 @@ text_env = TextEnvironment(
)
# main training loop
for step in range(100):
for _step in range(100):
tasks, answers = generate_data(ppo_config.batch_size)
queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers)
train_stats = ppo_trainer.step(queries, responses, rewards, masks)

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -61,9 +60,9 @@ def exact_match_reward(responses, answers=None):
if match_pattern:
predicted_number = float(match_pattern[0])
if predicted_number is not None:
if np.abs((predicted_number - float(answer))) < 0.1:
if np.abs(predicted_number - float(answer)) < 0.1:
reward += 1.0
except: # noqa
except Exception:
pass
rewards.append(torch.tensor(reward))
return rewards

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -114,7 +113,7 @@ dataset = dataset.shuffle(local_seed)
def data_generator():
for i in range(len(dataset)):
yield dataset[i]["question"], [item for item in dataset[i]["answer"]["normalized_aliases"]]
yield dataset[i]["question"], list(dataset[i]["answer"]["normalized_aliases"])
gen = data_generator()
@ -123,7 +122,7 @@ gen = iter(gen)
def generate_data(n):
tasks, answers = [], []
for i in range(n):
for _i in range(n):
q, a = next(gen)
tasks.append(q)
answers.append(a)
@ -143,10 +142,14 @@ def exact_match_reward(responses, answers=None):
return rewards
def tool_fn(x):
# limit the amount of tokens
return tool(x).split("\n")[1][:600]
# text env
tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
# limit the amount if tokens
tool_fn = lambda x: tool(x).split("\n")[1][:600] # noqa
text_env = TextEnvironment(
model,
tokenizer,
@ -184,8 +187,6 @@ for i in range(args.iterations):
"answer": [", ".join(item) for item in answers],
}
all_rewards = ppo_trainer.accelerator.gather(torch.tensor(rewards, device=ppo_trainer.accelerator.device))
ppo_trainer.log_stats(
train_stats, texts, [item for item in all_rewards], columns_to_log=["query", "response", "answer"]
)
ppo_trainer.log_stats(train_stats, texts, list(all_rewards), columns_to_log=["query", "response", "answer"])
if i % 100 == 0:
ppo_trainer.save_pretrained(f"models/{args.model_name}_{args.seed}_{i}_triviaqa")

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -146,7 +145,7 @@ dataset = build_dataset(config, input_min_text_length=min_input_length, input_ma
def collator(data):
return dict((key, [d[key] for d in data]) for key in data[0])
return {key: [d[key] for d in data] for key in data[0]}
# set seed before initializing value head for deterministic eval
@ -218,7 +217,7 @@ for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
response_tensors.append(response.squeeze()[-gen_len:])
batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
# Compute sentiment score # noqa
# Compute sentiment score
texts = batch["response"]
toxicity_inputs = toxicity_tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(
ppo_trainer.accelerator.device

338
examples/scripts/chat.py Normal file
View File

@ -0,0 +1,338 @@
# flake8: noqa
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from trl.commands.cli_utils import init_zero_verbose
init_zero_verbose()
import copy
import json
import os
import pwd
import re
import time
from threading import Thread
import torch
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from trl.commands.cli_utils import ChatArguments, TrlParser, init_zero_verbose
from trl.trainer.utils import get_kbit_device_map, get_quantization_config
HELP_STRING = """\
**TRL CHAT INTERFACE**
The chat interface is a simple tool to try out a chat model.
Besides talking to the model there are several commands:
- **clear**: clears the current conversation and start a new one
- **example {NAME}**: load example named `{NAME}` from the config and use it as the user input
- **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';').
- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
- **save {SAVE_NAME} (optional)**: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
- **exit**: closes the interface
"""
SUPPORTED_GENERATION_KWARGS = [
"max_new_tokens",
"do_sample",
"num_beams",
"temperature",
"top_p",
"top_k",
"repetition_penalty",
]
SETTING_RE = r"^set\s+[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+(?:;\s*[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+)*$"
class RichInterface:
def __init__(self, model_name=None, user_name=None):
self._console = Console()
if model_name is None:
self.model_name = "assistant"
else:
self.model_name = model_name
if user_name is None:
self.user_name = "user"
else:
self.user_name = user_name
def stream_output(self, output_stream):
"""Stream output from a role."""
# This method is originally from the FastChat CLI: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py
# Create a Live context for updating the console output
text = ""
self._console.print(f"[bold blue]<{self.model_name}>:")
with Live(console=self._console, refresh_per_second=4) as live:
# Read lines from the stream
for i, outputs in enumerate(output_stream):
if not outputs or i == 0:
continue
text += outputs
# Render the accumulated text as Markdown
# NOTE: this is a workaround for rendering non-standard markdown in rich.
# The chatbot's output treats "\n" as a new line for better compatibility
# with real-world text, but rendering it as markdown would break the format,
# because standard markdown treats a single "\n" in normal text as a space.
# Our workaround is to add two spaces at the end of each line.
# This is not a perfect solution, as it introduces trailing spaces (only)
# in code blocks, but it works well, especially for console output, since
# the console generally does not care about trailing spaces.
lines = []
for line in text.splitlines():
lines.append(line)
if line.startswith("```"):
# Code block marker - do not add trailing spaces, as it would
# break the syntax highlighting
lines.append("\n")
else:
lines.append(" \n")
markdown = Markdown("".join(lines).strip(), code_theme="github-dark")
# Update the Live console output
live.update(markdown)
self._console.print()
return text
def input(self):
input = self._console.input(f"[bold red]<{self.user_name}>:\n")
self._console.print()
return input
def clear(self):
self._console.clear()
def print_user_message(self, text):
self._console.print(f"[bold red]<{self.user_name}>:[/ bold red]\n{text}")
self._console.print()
def print_green(self, text):
self._console.print(f"[bold green]{text}")
self._console.print()
def print_red(self, text):
self._console.print(f"[bold red]{text}")
self._console.print()
def print_help(self):
self._console.print(Markdown(HELP_STRING))
self._console.print()
def get_username():
return pwd.getpwuid(os.getuid())[0]
def create_default_filename(model_name):
time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
return f"{model_name}/chat_{time_str}.json"
def save_chat(chat, args, filename):
output_dict = {}
output_dict["settings"] = vars(args)
output_dict["chat_history"] = chat
folder = args.save_folder
if filename is None:
filename = create_default_filename(args.model_name_or_path)
filename = os.path.join(folder, filename)
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, "w") as f:
json.dump(output_dict, f, indent=4)
return os.path.abspath(filename)
def clear_chat_history(system_prompt):
if system_prompt is None:
chat = []
else:
chat = [{"role": "system", "content": system_prompt}]
return chat
def parse_settings(user_input, current_args, interface):
settings = user_input[4:].strip().split(";")
settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings]
settings = dict(settings)
error = False
for name in settings:
if hasattr(current_args, name):
try:
if isinstance(getattr(current_args, name), bool):
if settings[name] == "True":
settings[name] = True
elif settings[name] == "False":
settings[name] = False
else:
raise ValueError
else:
settings[name] = type(getattr(current_args, name))(settings[name])
except ValueError:
interface.print_red(
f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}."
)
else:
interface.print_red(f"There is no '{name}' setting.")
if error:
interface.print_red("There was an issue parsing the settings. No settings have been changed.")
return current_args, False
else:
for name in settings:
setattr(current_args, name, settings[name])
interface.print_green(f"Set {name} to {settings[name]}.")
time.sleep(1.5) # so the user has time to read the changes
return current_args, True
def load_model_and_tokenizer(args):
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype)
quantization_config = get_quantization_config(args)
model_kwargs = dict(
revision=args.model_revision,
trust_remote_code=args.trust_remote_code,
attn_implementation=args.attn_implementation,
torch_dtype=torch_dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs)
if getattr(model, "hf_device_map", None) is None:
model = model.to(args.device)
return model, tokenizer
def chat_cli():
parser = TrlParser(ChatArguments)
args = parser.parse_args_into_dataclasses()[0]
if args.config == "default":
args.config = os.path.join(os.path.dirname(__file__), "config/default_chat_config.yaml")
if args.config.lower() == "none":
args.config = None
args = parser.update_dataclasses_with_config([args])[0]
if args.examples is None:
args.examples = {}
current_args = copy.deepcopy(args)
if args.user is None:
user = get_username()
else:
user = args.user
model, tokenizer = load_model_and_tokenizer(args)
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
interface.clear()
chat = clear_chat_history(current_args.system_prompt)
while True:
try:
user_input = interface.input()
if user_input == "clear":
chat = clear_chat_history(current_args.system_prompt)
interface.clear()
continue
if user_input == "help":
interface.print_help()
continue
if user_input == "exit":
break
if user_input == "reset":
interface.clear()
current_args = copy.deepcopy(args)
chat = clear_chat_history(current_args.system_prompt)
continue
if user_input.startswith("save") and len(user_input.split()) < 2:
split_input = user_input.split()
if len(split_input) == 2:
filename = split_input[1]
else:
filename = None
filename = save_chat(chat, current_args, filename)
interface.print_green(f"Chat saved in {filename}!")
continue
if re.match(SETTING_RE, user_input):
current_args, success = parse_settings(user_input, current_args, interface)
if success:
chat = []
interface.clear()
continue
if user_input.startswith("example") and len(user_input.split()) == 2:
example_name = user_input.split()[1]
if example_name in current_args.examples:
interface.clear()
chat = []
interface.print_user_message(current_args.examples[example_name]["text"])
user_input = current_args.examples[example_name]["text"]
else:
interface.print_red(
f"Example {example_name} not found in list of available examples: {list(current_args.examples.keys())}."
)
continue
chat.append({"role": "user", "content": user_input})
generation_kwargs = dict(
inputs=tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
model.device
),
streamer=generation_streamer,
max_new_tokens=current_args.max_new_tokens,
do_sample=current_args.do_sample,
num_beams=current_args.num_beams,
temperature=current_args.temperature,
top_k=current_args.top_k,
top_p=current_args.top_p,
repetition_penalty=current_args.repetition_penalty,
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
model_output = interface.stream_output(generation_streamer)
thread.join()
chat.append({"role": "assistant", "content": model_output})
except KeyboardInterrupt:
break
if __name__ == "__main__":
chat_cli()
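As a small usage sketch of the `set` command handling above (the pattern is copied from `SETTING_RE` in this script):

```python
import re

# Same pattern as SETTING_RE defined in the chat script above.
SETTING_RE = r"^set\s+[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+(?:;\s*[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+)*$"

# Well-formed `set` commands match; anything else falls through to normal chat input.
print(bool(re.match(SETTING_RE, "set temperature=0.7; top_k=50")))  # True
print(bool(re.match(SETTING_RE, "set max_new_tokens=256")))         # True
print(bool(re.match(SETTING_RE, "tell me a joke")))                 # False
```

Matching inputs are handed to `parse_settings`, which casts each value to the type of the corresponding field on the current arguments before applying it.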

View File

@ -0,0 +1,13 @@
examples:
llama:
text: There is a Llama in my lawn, how can I get rid of it?
code:
text: Write a Python function that integrates any Python function f(x) numerically over an arbitrary interval [x_start, x_end].
helicopter:
text: How many helicopters can a human eat in one sitting?
numbers:
text: Count to 10 but skip every number ending with an 'e'
birds:
text: Why aren't birds real?
socks:
text: Why is it important to eat socks after meditating?

View File

@ -11,18 +11,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python examples/scripts/ddpo.py \
--num_epochs=200 \
--train_gradient_accumulation_steps=1 \
--sample_num_steps=50 \
--sample_batch_size=6 \
--train_batch_size=3 \
--sample_num_batches_per_epoch=4 \
--per_prompt_stat_tracking=True \
--per_prompt_stat_tracking_buffer_size=32 \
--tracker_project_name="stable_diffusion_training" \
--log_with="wandb"
"""
import os
from dataclasses import dataclass, field
import numpy as np
import torch
import torch.nn as nn
import tyro
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from transformers import CLIPModel, CLIPProcessor
from transformers import CLIPModel, CLIPProcessor, HfArgumentParser
from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline
from trl.import_utils import is_npu_available, is_xpu_available
@ -30,40 +40,22 @@ from trl.import_utils import is_npu_available, is_xpu_available
@dataclass
class ScriptArguments:
hf_user_access_token: str
pretrained_model: str = "runwayml/stable-diffusion-v1-5"
"""the pretrained model to use"""
pretrained_revision: str = "main"
"""the pretrained model revision to use"""
hf_hub_model_id: str = "ddpo-finetuned-stable-diffusion"
"""HuggingFace repo to save model weights to"""
hf_hub_aesthetic_model_id: str = "trl-lib/ddpo-aesthetic-predictor"
"""HuggingFace model ID for aesthetic scorer model weights"""
hf_hub_aesthetic_model_filename: str = "aesthetic-model.pth"
"""HuggingFace model filename for aesthetic scorer model weights"""
use_lora: bool = True
"""Whether to use LoRA."""
ddpo_config: DDPOConfig = field(
default_factory=lambda: DDPOConfig(
num_epochs=200,
train_gradient_accumulation_steps=1,
sample_num_steps=50,
sample_batch_size=6,
train_batch_size=3,
sample_num_batches_per_epoch=4,
per_prompt_stat_tracking=True,
per_prompt_stat_tracking_buffer_size=32,
tracker_project_name="stable_diffusion_training",
log_with="wandb",
project_kwargs={
"logging_dir": "./logs",
"automatic_checkpoint_naming": True,
"total_limit": 5,
"project_dir": "./save",
},
)
pretrained_model: str = field(
default="runwayml/stable-diffusion-v1-5", metadata={"help": "the pretrained model to use"}
)
pretrained_revision: str = field(default="main", metadata={"help": "the pretrained model revision to use"})
hf_hub_model_id: str = field(
default="ddpo-finetuned-stable-diffusion", metadata={"help": "HuggingFace repo to save model weights to"}
)
hf_hub_aesthetic_model_id: str = field(
default="trl-lib/ddpo-aesthetic-predictor",
metadata={"help": "HuggingFace model ID for aesthetic scorer model weights"},
)
hf_hub_aesthetic_model_filename: str = field(
default="aesthetic-model.pth",
metadata={"help": "HuggingFace model filename for aesthetic scorer model weights"},
)
use_lora: bool = field(default=True, metadata={"help": "Whether to use LoRA."})
class MLP(nn.Module):
@ -183,7 +175,7 @@ def image_outputs_logger(image_data, global_step, accelerate_logger):
for i, image in enumerate(images):
prompt = prompts[i]
reward = rewards[i].item()
result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0)
result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0).float()
accelerate_logger.log_images(
result,
@ -192,14 +184,21 @@ def image_outputs_logger(image_data, global_step, accelerate_logger):
if __name__ == "__main__":
args = tyro.cli(ScriptArguments)
parser = HfArgumentParser((ScriptArguments, DDPOConfig))
args, ddpo_config = parser.parse_args_into_dataclasses()
ddpo_config.project_kwargs = {
"logging_dir": "./logs",
"automatic_checkpoint_naming": True,
"total_limit": 5,
"project_dir": "./save",
}
pipeline = DefaultDDPOStableDiffusionPipeline(
args.pretrained_model, pretrained_model_revision=args.pretrained_revision, use_lora=args.use_lora
)
trainer = DDPOTrainer(
args.ddpo_config,
ddpo_config,
aesthetic_scorer(args.hf_hub_aesthetic_model_id, args.hf_hub_aesthetic_model_filename),
prompt_fn,
pipeline,
@ -208,4 +207,4 @@ if __name__ == "__main__":
trainer.train()
trainer.push_to_hub(args.hf_hub_model_id, token=args.hf_user_access_token)
trainer.push_to_hub(args.hf_hub_model_id)

View File

@ -1,4 +1,4 @@
# coding=utf-8
# flake8: noqa
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -12,187 +12,155 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# regular:
python examples/scripts/dpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 1e-3 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="dpo_anthropic_hh" \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns
# Note: you need to install transformers from main to run this script. See https://huggingface.co/docs/transformers/installation#install-from-source
# TODO: bump transformers version in requirements at next release.
# peft:
python examples/scripts/dpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 1e-3 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="dpo_anthropic_hh" \
--optim rmsprop \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns \
--use_peft \
--lora_r=16 \
--lora_alpha=16
"""
import logging
import os
from contextlib import nullcontext
# 0. imports
from dataclasses import dataclass, field
from typing import Dict, Optional
TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)
from trl.commands.cli_utils import DpoScriptArguments, init_zero_verbose, TrlParser
if TRL_USE_RICH:
init_zero_verbose()
FORMAT = "%(message)s"
from rich.console import Console
from rich.logging import RichHandler
import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer
from trl import (
DPOTrainer,
ModelConfig,
RichProgressCallback,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
# Define and parse arguments.
@dataclass
class ScriptArguments:
"""
The arguments for the DPO training script.
"""
# data parameters
beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
# training parameters
model_name_or_path: Optional[str] = field(default="gpt2", metadata={"help": "the model name"})
learning_rate: Optional[float] = field(default=1e-3, metadata={"help": "optimizer learning rate"})
per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "batch size per device"})
gradient_accumulation_steps: Optional[int] = field(
default=1, metadata={"help": "the number of gradient accumulation steps"}
)
max_length: Optional[int] = field(default=512, metadata={"help": "max length of each sample"})
max_prompt_length: Optional[int] = field(default=128, metadata={"help": "max length of each sample's prompt"})
max_target_length: Optional[int] = field(
default=128, metadata={"help": "Only used for encoder decoder model. Max target of each sample's prompt"}
)
label_pad_token_id: Optional[int] = field(default=-100, metadata={"help": "label for non response tokens"})
max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
# lora parameters
use_peft: Optional[bool] = field(default=True, metadata={"help": "Wether to use PEFT or not to train adapters"})
peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
# instrumentation
sanity_check: Optional[bool] = field(default=True, metadata={"help": "only train on 1000 samples"})
report_to: Optional[str] = field(
default=None,
metadata={
"help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
'`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
},
)
# debug argument for distributed training
ignore_bias_buffers: Optional[bool] = field(
default=False,
metadata={
"help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)
gradient_checkpointing: Optional[bool] = field(
default=False, metadata={"help": "Whether to use gradient checkpointing or no"}
)
gradient_checkpointing_kwargs: Optional[dict] = field(
default=None,
metadata={
"help": "key word arguments to be passed along `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
},
)
def extract_anthropic_prompt(prompt_and_response):
"""Extract the anthropic prompt from a prompt and response pair."""
search_term = "\n\nAssistant:"
search_term_idx = prompt_and_response.rfind(search_term)
assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
return prompt_and_response[: search_term_idx + len(search_term)]
def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
"""Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.
The dataset is converted to a dictionary with the following structure:
{
'prompt': List[str],
'chosen': List[str],
'rejected': List[str],
}
Prompts should be structured as follows:
\n\nHuman: <prompt>\n\nAssistant:
Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
"""
dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
if sanity_check:
dataset = dataset.select(range(min(len(dataset), 1000)))
def split_prompt_and_responses(sample) -> Dict[str, str]:
prompt = extract_anthropic_prompt(sample["chosen"])
return {
"prompt": prompt,
"chosen": sample["chosen"][len(prompt) :],
"rejected": sample["rejected"][len(prompt) :],
}
return dataset.map(split_prompt_and_responses)
if TRL_USE_RICH:
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)
if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
parser = TrlParser((DpoScriptArguments, TrainingArguments, ModelConfig))
args, training_args, model_config = parser.parse_args_and_config()
# 1. load a pretrained model
model = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)
# Force use our print callback
if TRL_USE_RICH:
training_args.disable_tqdm = True
console = Console()
if script_args.ignore_bias_buffers:
################
# Model & Tokenizer
################
torch_dtype = (
model_config.torch_dtype
if model_config.torch_dtype in ["auto", None]
else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
trust_remote_code=model_config.trust_remote_code,
attn_implementation=model_config.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
peft_config = get_peft_config(model_config)
if peft_config is None:
model_ref = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
else:
model_ref = None
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if args.ignore_bias_buffers:
# torch distributed hack
model._ddp_params_and_buffers_to_ignore = [
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
]
model_ref = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# 2. Load the Anthropic Helpful-Harmless dataset
train_dataset = get_hh("train", sanity_check=script_args.sanity_check)
# 3. Load evaluation dataset
eval_dataset = get_hh("test", sanity_check=script_args.sanity_check)
# 4. initialize training arguments:
training_args = TrainingArguments(
per_device_train_batch_size=script_args.per_device_train_batch_size,
max_steps=script_args.max_steps,
remove_unused_columns=False,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
learning_rate=script_args.learning_rate,
evaluation_strategy="steps",
logging_first_step=True,
logging_steps=10, # match results in blog post
eval_steps=500,
output_dir="./test",
optim="rmsprop",
warmup_steps=150,
report_to=script_args.report_to,
bf16=True,
gradient_checkpointing=script_args.gradient_checkpointing,
# TODO: uncomment that on the next transformers release
# gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
################
# Optional rich context managers
###############
init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the DPOTrainer...")
save_context = (
nullcontext()
if not TRL_USE_RICH
else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
)
if script_args.use_peft:
peft_config = LoraConfig(
r=script_args.peft_lora_r,
lora_alpha=script_args.peft_lora_alpha,
bias="none",
task_type="CAUSAL_LM",
################
# Dataset
################
train_dataset = load_dataset(args.dataset_name, split="train")
eval_dataset = load_dataset(args.dataset_name, split="test")
################
# Training
################
with init_context:
trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=args.beta,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
max_length=args.max_length,
max_target_length=args.max_target_length,
max_prompt_length=args.max_prompt_length,
generate_during_eval=args.generate_during_eval,
peft_config=get_peft_config(model_config),
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
)
else:
peft_config = None
# 5. initialize the DPO trainer
dpo_trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=script_args.beta,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
max_length=script_args.max_length,
max_target_length=script_args.max_target_length,
max_prompt_length=script_args.max_prompt_length,
generate_during_eval=True,
peft_config=peft_config,
)
trainer.train()
# 6. train
dpo_trainer.train()
with save_context:
trainer.save_model(training_args.output_dir)

152
examples/scripts/kto.py Normal file
View File

@ -0,0 +1,152 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run the KTO training script with the following command with some example arguments.
In general, the optimal configuration for KTO will be similar to that of DPO:
# regular:
python examples/scripts/kto.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 1e-3 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="kto_anthropic_hh" \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns
# peft:
python examples/scripts/kto.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 1e-3 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="kto_anthropic_hh" \
--optim rmsprop \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns \
--use_peft \
--lora_r=16 \
--lora_alpha=16
"""
from dataclasses import dataclass, field
from typing import Optional
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config
# Define and parse arguments.
@dataclass
class ScriptArguments:
"""
The arguments for the KTO training script.
"""
# debugging
sanity_check: Optional[bool] = field(default=True, metadata={"help": "only train on 1000 samples"})
def extract_anthropic_prompt(prompt_and_response):
"""Extract the anthropic prompt from a prompt and response pair."""
search_term = "\n\nAssistant:"
search_term_idx = prompt_and_response.rfind(search_term)
if search_term_idx == -1:
raise ValueError(f"Prompt and response does not contain '{search_term}'")
return prompt_and_response[: search_term_idx + len(search_term)]
def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
"""Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.
The dataset is converted to a dictionary with the following structure:
{
'prompt': List[str],
'completion': List[str],
'label': List[bool],
}
Prompts should be structured as follows:
\n\nHuman: <prompt>\n\nAssistant:
Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
"""
dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
if sanity_check:
dataset = dataset.select(range(min(len(dataset), 1000)))
flat_data = {
"prompt": [],
"completion": [],
"label": [],
}
for sample in dataset:
prompt = extract_anthropic_prompt(sample["chosen"])
flat_data["prompt"].append(prompt)
flat_data["completion"].append(sample["chosen"][len(prompt) :])
flat_data["label"].append(True)
flat_data["prompt"].append(prompt)
flat_data["completion"].append(sample["rejected"][len(prompt) :])
flat_data["label"].append(False)
return dataset.from_dict(flat_data)
if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
script_args, kto_args, model_args = parser.parse_args_into_dataclasses()
# 1. load a pretrained model
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
model_ref = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# 2. Load the Anthropic Helpful-Harmless dataset
train_dataset = get_hh("train", sanity_check=script_args.sanity_check)
# 3. Load evaluation dataset
eval_dataset = get_hh("test", sanity_check=script_args.sanity_check)
# 4. initialize the KTO trainer
kto_trainer = KTOTrainer(
model,
model_ref,
args=kto_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
peft_config=get_peft_config(model_args),
)
# 5. train
kto_trainer.train()
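Note that `get_hh` here flattens each preference pair into two unpaired examples, which is the `prompt, completion, label` format KTO expects; an illustrative (made-up) pair becomes:

```python
# Each chosen/rejected pair yields two KTO rows sharing the same prompt:
# the chosen completion labelled True, the rejected one labelled False.
# Values below are illustrative only, not taken from the dataset.
flat_data = {
    "prompt": [
        "\n\nHuman: How do I boil an egg?\n\nAssistant:",
        "\n\nHuman: How do I boil an egg?\n\nAssistant:",
    ],
    "completion": [
        " Put the egg in boiling water for about seven minutes.",
        " I'm not sure.",
    ],
    "label": [True, False],
}
```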

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -12,16 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python examples/scripts/ppo.py \
--log_with=wandb
"""
from dataclasses import dataclass, field
from typing import Optional
import torch
import tyro
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoTokenizer, pipeline
from transformers import AutoTokenizer, HfArgumentParser, pipeline
from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer, set_seed
from trl.core import LengthSampler
@ -33,42 +35,17 @@ tqdm.pandas()
@dataclass
class ScriptArguments:
ppo_config: PPOConfig = field(
default_factory=lambda: PPOConfig(
model_name="lvwerra/gpt2-imdb",
query_dataset="imdb",
reward_model="sentiment-analysis:lvwerra/distilbert-imdb",
learning_rate=1.41e-5,
log_with=None,
mini_batch_size=128,
batch_size=128,
gradient_accumulation_steps=1,
early_stopping=False,
target_kl=6.0,
kl_penalty="kl",
seed=0,
use_score_scaling=False,
use_score_norm=False,
score_clip=None,
)
)
use_seq2seq: bool = False
"""whether to use seq2seq models"""
use_peft: bool = False
"""whether to use peft"""
peft_config: Optional[LoraConfig] = field(
default_factory=lambda: LoraConfig(
r=16,
lora_alpha=16,
bias="none",
task_type="CAUSAL_LM",
),
)
use_seq2seq: bool = field(default=False, metadata={"help": "whether to use seq2seq"})
trust_remote_code: bool = field(default=False, metadata={"help": "Enable `trust_remote_code`"})
# LoraConfig
use_peft: bool = field(default=False, metadata={"help": "whether to use peft"})
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
lora_r: Optional[int] = field(default=16, metadata={"help": "the lora r parameter"})
args = tyro.cli(ScriptArguments)
parser = HfArgumentParser((ScriptArguments, PPOConfig))
args, ppo_config = parser.parse_args_into_dataclasses()
# We then define the arguments to pass to the sentiment analysis pipeline.
# We set `return_all_scores` to True to get the sentiment score for each token.
@ -113,42 +90,47 @@ def build_dataset(config, query_dataset, input_min_text_length=2, input_max_text
# We retrieve the dataloader by calling the `build_dataset` function.
dataset = build_dataset(args.ppo_config, args.ppo_config.query_dataset)
dataset = build_dataset(ppo_config, ppo_config.query_dataset)
def collator(data):
return dict((key, [d[key] for d in data]) for key in data[0])
return {key: [d[key] for d in data] for key in data[0]}
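As a quick aside, the collator simply transposes a list of per-sample dicts into a dict of lists, which is the batch format the PPO training loop consumes; a tiny self-contained check with toy values:
def collator(data):
    # Same one-liner as above: transpose list-of-dicts into dict-of-lists.
    return {key: [d[key] for d in data] for key in data[0]}

batch = collator([{"query": "a", "input_ids": [1, 2]}, {"query": "b", "input_ids": [3]}])
assert batch == {"query": ["a", "b"], "input_ids": [[1, 2], [3]]}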
# set seed before initializing value head for deterministic eval
set_seed(args.ppo_config.seed)
set_seed(ppo_config.seed)
# Now let's build the model, the reference model, and the tokenizer.
if not args.use_peft:
ref_model = trl_model_class.from_pretrained(args.ppo_config.model_name, trust_remote_code=args.trust_remote_code)
ref_model = trl_model_class.from_pretrained(ppo_config.model_name, trust_remote_code=args.trust_remote_code)
device_map = None
peft_config = None
else:
peft_config = args.peft_config
peft_config = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
bias="none",
task_type="CAUSAL_LM",
)
ref_model = None
# Copy the model to each device
device_map = {"": Accelerator().local_process_index}
model = trl_model_class.from_pretrained(
args.ppo_config.model_name,
ppo_config.model_name,
trust_remote_code=args.trust_remote_code,
device_map=device_map,
peft_config=peft_config,
)
tokenizer = AutoTokenizer.from_pretrained(args.ppo_config.model_name)
tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name)
# Some tokenizers like GPT-2's don't have a padding token by default, so we set one here.
tokenizer.pad_token_id = tokenizer.eos_token_id
# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
ppo_trainer = PPOTrainer(args.ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)
ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)
# We then build the sentiment analysis pipeline, passing the model name and the
# sentiment analysis pipeline arguments. Let's also make sure to set the device
@ -162,7 +144,7 @@ if ppo_trainer.accelerator.num_processes == 1:
else:
device = 0 if torch.cuda.is_available() else "cpu" # to avoid a `pipeline` bug
ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
task, model_name = args.ppo_config.reward_model.split(":")
task, model_name = ppo_config.reward_model.split(":")
if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
with ds_plugin.zero3_init_context_manager(enable=False):
sentiment_pipe = pipeline(task, model=model_name, device=device)
@ -188,7 +170,7 @@ generation_kwargs = {
"max_new_tokens": 32,
}
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
for _epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
query_tensors = batch["input_ids"]
# Get response from gpt2
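A side note on the reward_model value parsed above: it packs the pipeline task and the model id into a single colon-separated string, so the split recovers both parts (shown here with this script's default value):
# Default value taken from the PPOConfig defaults in this script.
reward_model = "sentiment-analysis:lvwerra/distilbert-imdb"
task, model_name = reward_model.split(":")
assert task == "sentiment-analysis"
assert model_name == "lvwerra/distilbert-imdb"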

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -97,7 +96,7 @@ dataset = create_and_prepare_dataset(tokenizer)
def collator(data):
return dict((key, [d[key] for d in data]) for key in data[0])
return {key: [d[key] for d in data] for key in data[0]}
config = PPOConfig(
@ -131,7 +130,7 @@ generation_kwargs = {
"max_new_tokens": 32,
}
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
for _epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
question_tensors = batch["input_ids"]
response_tensors = ppo_trainer.generate(

View File

@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -12,162 +11,114 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
"""
python examples/scripts/reward_modeling.py \
--model_name_or_path=facebook/opt-350m \
--output_dir="reward_modeling_anthropic_hh" \
--per_device_train_batch_size=64 \
--num_train_epochs=1 \
--gradient_accumulation_steps=16 \
--gradient_checkpointing=True \
--learning_rate=1.41e-5 \
--report_to="wandb" \
--remove_unused_columns=False \
--optim="adamw_torch" \
--logging_steps=10 \
--evaluation_strategy="steps" \
--max_length=512 \
"""
import warnings
import tyro
from accelerate import Accelerator
import torch
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
from transformers import AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser
from trl import RewardConfig, RewardTrainer, is_xpu_available
from trl import ModelConfig, RewardConfig, RewardTrainer, get_kbit_device_map, get_peft_config, get_quantization_config
tqdm.pandas()
@dataclass
class ScriptArguments:
model_name: str = "facebook/opt-350m"
"""the model name"""
dataset_name: str = "Anthropic/hh-rlhf"
"""the dataset name"""
dataset_text_field: str = "text"
"""the text field of the dataset"""
eval_split: str = "none"
"""the dataset split to evaluate on; default to 'none' (no evaluation)"""
load_in_8bit: bool = False
"""load the model in 8 bits precision"""
load_in_4bit: bool = False
"""load the model in 4 bits precision"""
trust_remote_code: bool = True
"""Enable `trust_remote_code`"""
reward_config: RewardConfig = field(
default_factory=lambda: RewardConfig(
output_dir="output",
per_device_train_batch_size=64,
num_train_epochs=1,
gradient_accumulation_steps=16,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
learning_rate=1.41e-5,
report_to="tensorboard",
remove_unused_columns=False,
optim="adamw_torch",
logging_steps=500,
evaluation_strategy="no",
max_length=512,
if __name__ == "__main__":
parser = HfArgumentParser((RewardConfig, ModelConfig))
reward_config, model_config = parser.parse_args_into_dataclasses()
reward_config.gradient_checkpointing_kwargs = dict(use_reentrant=False)
################
# Model & Tokenizer
################
torch_dtype = (
model_config.torch_dtype
if model_config.torch_dtype in ["auto", None]
else getattr(torch, model_config.torch_dtype)
)
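Put differently, the expression above keeps "auto" and None untouched and otherwise looks up the dtype attribute on torch; a minimal sketch of the same resolution:
import torch

def resolve_dtype(value):
    # Mirrors the expression above.
    return value if value in ["auto", None] else getattr(torch, value)

assert resolve_dtype("bfloat16") is torch.bfloat16
assert resolve_dtype("auto") == "auto"
assert resolve_dtype(None) is None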
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
trust_remote_code=model_config.trust_remote_code,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained(
model_config.model_name_or_path, num_labels=1, **model_kwargs
)
if model_config.lora_task_type != "SEQ_CLS":
warnings.warn(
"You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs"
" Make sure to pass --lora_task_type SEQ_CLS when using this script."
)
)
use_peft: bool = False
"""whether to use peft"""
peft_config: Optional[LoraConfig] = field(
default_factory=lambda: LoraConfig(
r=16,
lora_alpha=16,
bias="none",
task_type="SEQ_CLS",
modules_to_save=["scores"],
),
)
################
# Dataset
################
raw_datasets = load_dataset("Anthropic/hh-rlhf")
# Tokenize chosen/rejected pairs of inputs
# Adapt this section to your needs for custom datasets
args = tyro.cli(ScriptArguments)
args.reward_config.evaluation_strategy = "steps" if args.eval_split != "none" else "no"
def preprocess_function(examples):
new_examples = {
"input_ids_chosen": [],
"attention_mask_chosen": [],
"input_ids_rejected": [],
"attention_mask_rejected": [],
}
for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
tokenized_chosen = tokenizer(chosen)
tokenized_rejected = tokenizer(rejected)
new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
# Step 1: Load the model
if args.load_in_8bit and args.load_in_4bit:
raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
elif args.load_in_8bit or args.load_in_4bit:
quantization_config = BitsAndBytesConfig(load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit)
# Copy the model to each device
device_map = (
{"": f"xpu:{Accelerator().local_process_index}"}
if is_xpu_available()
else {"": Accelerator().local_process_index}
)
else:
device_map = None
quantization_config = None
return new_examples
model = AutoModelForSequenceClassification.from_pretrained(
args.model_name,
quantization_config=quantization_config,
device_map=device_map,
trust_remote_code=args.trust_remote_code,
num_labels=1,
)
# Step 2: Load the dataset and pre-process it
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
train_dataset = load_dataset(args.dataset_name, split="train")
# Tokenize chosen/rejected pairs of inputs
# Adapt this section to your needs for custom datasets
def preprocess_function(examples):
new_examples = {
"input_ids_chosen": [],
"attention_mask_chosen": [],
"input_ids_rejected": [],
"attention_mask_rejected": [],
}
for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
tokenized_chosen = tokenizer(chosen)
tokenized_rejected = tokenizer(rejected)
new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
return new_examples
# Preprocess the dataset and filter out examples that are longer than args.max_length
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
num_proc=4,
)
train_dataset = train_dataset.filter(
lambda x: len(x["input_ids_chosen"]) <= args.reward_config.max_length
and len(x["input_ids_rejected"]) <= args.reward_config.max_length
)
if args.eval_split == "none":
eval_dataset = None
else:
eval_dataset = load_dataset(args.dataset_name, split=args.eval_split)
eval_dataset = eval_dataset.map(
# Preprocess the dataset and filter out examples that are longer than args.max_length
raw_datasets = raw_datasets.map(
preprocess_function,
batched=True,
num_proc=4,
)
eval_dataset = eval_dataset.filter(
lambda x: len(x["input_ids_chosen"]) <= args.reward_config.max_length
and len(x["input_ids_rejected"]) <= args.reward_config.max_length
raw_datasets = raw_datasets.filter(
lambda x: len(x["input_ids_chosen"]) <= reward_config.max_length
and len(x["input_ids_rejected"]) <= reward_config.max_length
)
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]
# Step 4: Define the LoraConfig
if args.use_peft:
peft_config = args.peft_config
else:
peft_config = None
# Step 5: Define the Trainer
trainer = RewardTrainer(
model=model,
tokenizer=tokenizer,
args=args.reward_config,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=peft_config,
)
trainer.train()
################
# Training
################
trainer = RewardTrainer(
model=model,
tokenizer=tokenizer,
args=reward_config,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=get_peft_config(model_config),
)
trainer.train()
trainer.save_model(reward_config.output_dir)
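For intuition, the preprocessing above yields four parallel columns per chosen/rejected pair, and the length filter keeps a row only when both tokenized sequences fit within reward_config.max_length; a sketch with hypothetical token ids standing in for tokenizer output:
# Hypothetical token ids; the structure is what matters, not the values.
tokenized_row = {
    "input_ids_chosen": [101, 2204, 3437, 102],
    "attention_mask_chosen": [1, 1, 1, 1],
    "input_ids_rejected": [101, 2919, 3437, 102],
    "attention_mask_rejected": [1, 1, 1, 1],
}
max_length = 512  # reward_config.max_length in the script
keep = (
    len(tokenized_row["input_ids_chosen"]) <= max_length
    and len(tokenized_row["input_ids_rejected"]) <= max_length
)
assert keep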

View File

@ -1,4 +1,4 @@
# coding=utf-8
# flake8: noqa
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -12,150 +12,140 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import List, Optional
"""
# regular:
python examples/scripts/sft.py \
--model_name_or_path="facebook/opt-350m" \
--report_to="wandb" \
--learning_rate=1.41e-5 \
--per_device_train_batch_size=64 \
--gradient_accumulation_steps=16 \
--output_dir="sft_openassistant-guanaco" \
--logging_steps=1 \
--num_train_epochs=3 \
--max_steps=-1 \
--push_to_hub \
--gradient_checkpointing \
# peft:
python examples/scripts/sft.py \
--model_name_or_path="facebook/opt-350m" \
--report_to="wandb" \
--learning_rate=1.41e-5 \
--per_device_train_batch_size=64 \
--gradient_accumulation_steps=16 \
--output_dir="sft_openassistant-guanaco" \
--logging_steps=1 \
--num_train_epochs=3 \
--max_steps=-1 \
--push_to_hub \
--gradient_checkpointing \
--use_peft \
--lora_r=64 \
--lora_alpha=16
"""
import logging
import os
from contextlib import nullcontext
TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)
from trl.commands.cli_utils import init_zero_verbose, SftScriptArguments, TrlParser
if TRL_USE_RICH:
init_zero_verbose()
FORMAT = "%(message)s"
from rich.console import Console
from rich.logging import RichHandler
import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
from trl import SFTTrainer, is_xpu_available
from tqdm.rich import tqdm
from transformers import AutoTokenizer, TrainingArguments
from trl import (
ModelConfig,
RichProgressCallback,
SFTTrainer,
get_peft_config,
get_quantization_config,
get_kbit_device_map,
)
tqdm.pandas()
if TRL_USE_RICH:
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)
# Define and parse arguments.
@dataclass
class ScriptArguments:
"""
The name of the Causal LM model we wish to fine-tune with SFTTrainer
"""
model_name: Optional[str] = field(default="facebook/opt-350m", metadata={"help": "the model name"})
dataset_name: Optional[str] = field(
default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"}
if __name__ == "__main__":
parser = TrlParser((SftScriptArguments, TrainingArguments, ModelConfig))
args, training_args, model_config = parser.parse_args_and_config()
# Force use our print callback
if TRL_USE_RICH:
training_args.disable_tqdm = True
console = Console()
################
# Model & Tokenizer
################
torch_dtype = (
model_config.torch_dtype
if model_config.torch_dtype in ["auto", None]
else getattr(torch, model_config.torch_dtype)
)
dataset_text_field: Optional[str] = field(default="text", metadata={"help": "the text field of the dataset"})
report_to: Optional[str] = field(default="none", metadata={"help": "use 'wandb' to log with wandb"})
learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
batch_size: Optional[int] = field(default=64, metadata={"help": "the batch size"})
seq_length: Optional[int] = field(default=512, metadata={"help": "Input sequence length"})
gradient_accumulation_steps: Optional[int] = field(
default=16, metadata={"help": "the number of gradient accumulation steps"}
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
trust_remote_code=model_config.trust_remote_code,
attn_implementation=model_config.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
use_peft: Optional[bool] = field(default=False, metadata={"help": "Whether to use PEFT or not to train adapters"})
trust_remote_code: Optional[bool] = field(default=False, metadata={"help": "Enable `trust_remote_code`"})
output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"})
peft_lora_r: Optional[int] = field(default=64, metadata={"help": "the r parameter of the LoRA adapters"})
peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
logging_steps: Optional[int] = field(default=1, metadata={"help": "the number of logging steps"})
use_auth_token: Optional[bool] = field(default=True, metadata={"help": "Use HF auth token to access the model"})
num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"})
save_steps: Optional[int] = field(
default=100, metadata={"help": "Number of update steps between two checkpoint saves"}
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token
################
# Dataset
################
raw_datasets = load_dataset(args.dataset_name)
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]
################
# Optional rich context managers
###############
init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the SFTTrainer...")
save_context = (
nullcontext()
if not TRL_USE_RICH
else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
)
save_total_limit: Optional[int] = field(default=10, metadata={"help": "Limits total number of checkpoints."})
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the model to HF Hub"})
gradient_checkpointing: Optional[bool] = field(
default=False, metadata={"help": "Whether to use gradient checkpointing or not"}
)
gradient_checkpointing_kwargs: Optional[dict] = field(
default=None,
metadata={
"help": "key word arguments to be passed along `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
},
)
hub_model_id: Optional[str] = field(default=None, metadata={"help": "The name of the model on HF Hub"})
mixed_precision: Optional[str] = field(default="bf16", metadata={"help": "Mixed precision training"})
target_modules: Optional[List[str]] = field(default=None, metadata={"help": "Target modules for LoRA adapters"})
################
# Training
################
with init_context:
trainer = SFTTrainer(
model=model_config.model_name_or_path,
model_init_kwargs=model_kwargs,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
dataset_text_field=args.dataset_text_field,
max_seq_length=args.max_seq_length,
tokenizer=tokenizer,
packing=args.packing,
peft_config=get_peft_config(model_config),
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
)
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
trainer.train()
# Step 1: Load the model
if script_args.load_in_8bit and script_args.load_in_4bit:
raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
elif script_args.load_in_8bit or script_args.load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
)
# Copy the model to each device
device_map = (
{"": f"xpu:{Accelerator().local_process_index}"}
if is_xpu_available()
else {"": Accelerator().local_process_index}
)
torch_dtype = torch.bfloat16
else:
device_map = None
quantization_config = None
torch_dtype = None
model = AutoModelForCausalLM.from_pretrained(
script_args.model_name,
quantization_config=quantization_config,
device_map=device_map,
trust_remote_code=script_args.trust_remote_code,
torch_dtype=torch_dtype,
use_auth_token=script_args.use_auth_token,
)
# Step 2: Load the dataset
dataset = load_dataset(script_args.dataset_name, split="train")
# Step 3: Define the training arguments
training_args = TrainingArguments(
output_dir=script_args.output_dir,
per_device_train_batch_size=script_args.batch_size,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
learning_rate=script_args.learning_rate,
logging_steps=script_args.logging_steps,
num_train_epochs=script_args.num_train_epochs,
max_steps=script_args.max_steps,
report_to=script_args.report_to,
save_steps=script_args.save_steps,
save_total_limit=script_args.save_total_limit,
push_to_hub=script_args.push_to_hub,
hub_model_id=script_args.hub_model_id,
gradient_checkpointing=script_args.gradient_checkpointing,
# TODO: uncomment that on the next release
# gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
)
# Step 4: Define the LoraConfig
if script_args.use_peft:
peft_config = LoraConfig(
r=script_args.peft_lora_r,
lora_alpha=script_args.peft_lora_alpha,
bias="none",
task_type="CAUSAL_LM",
target_modules=script_args.target_modules,
)
else:
peft_config = None
# Step 5: Define the Trainer
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, use_fast=True)
trainer = SFTTrainer(
model=model,
args=training_args,
max_seq_length=script_args.seq_length,
train_dataset=dataset,
dataset_text_field=script_args.dataset_text_field,
peft_config=peft_config,
tokenizer=tokenizer,
)
trainer.train()
# Step 6: Save the model
trainer.save_model(script_args.output_dir)
with save_context:
trainer.save_model(training_args.output_dir)
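The init_context/save_context pattern above is a small convenience: when rich output is disabled the with-blocks become no-ops via contextlib.nullcontext, otherwise the same code runs inside a console status spinner. A stripped-down sketch of the idea (standalone, not the script itself):
from contextlib import nullcontext

USE_RICH = False  # stand-in for the TRL_USE_RICH environment toggle

if USE_RICH:
    from rich.console import Console

    console = Console()

# No-op context manager when rich is off, a status spinner otherwise.
init_context = nullcontext() if not USE_RICH else console.status("[bold green]Initializing...")

with init_context:
    pass  # trainer construction would go here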

View File

@ -1,16 +1,22 @@
[tool.black]
line-length = 119
target-version = ['py38']
[tool.ruff]
ignore = ["E501", "E741", "W605"]
select = ["E", "F", "I", "W"]
target-version = "py37"
line-length = 119
# Ignore import violations in all `__init__.py` files.
[tool.ruff.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403", "F811"]
[tool.ruff.lint]
ignore = [
"B028", # warning without explicit stacklevel
"C408", # dict() calls (stylistic)
"C901", # function complexity
"E501",
]
extend-select = ["E", "F", "I", "W", "UP", "B", "T", "C"]
[tool.ruff.isort]
[tool.ruff.lint.per-file-ignores]
# Allow prints in auxiliary scripts
"benchmark/**.py" = ["T201"]
"examples/**.py" = ["T201"]
"scripts/**.py" = ["T201"]
[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["trl"]

View File

@ -0,0 +1,140 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from datetime import date
from tabulate import tabulate
MAX_LEN_MESSAGE = 2900 # slack endpoint has a limit of 3001 characters
parser = argparse.ArgumentParser()
parser.add_argument("--slack_channel_name", default="trl-push-examples-ci")
parser.add_argument("--text_file_name", required=True)
def main(text_file_name, slack_channel_name=None):
message = ""
if os.path.isfile(text_file_name):
final_results = {}
file = open(text_file_name)
lines = file.readlines()
for line in lines:
result, config_name = line.split(",")
config_name = config_name.split("/")[-1].split(".yaml")[0]
final_results[config_name] = int(result)
no_error_payload = {
"type": "section",
"text": {
"type": "plain_text",
"text": "🌞 There were no failures on the example tests!"
if not len(final_results) == 0
else "Something went wrong there is at least one empty file - please check GH action results.",
"emoji": True,
},
}
total_num_failed = sum(final_results.values())
else:
no_error_payload = {
"type": "section",
"text": {
"type": "plain_text",
"text": "🔴 Something is wrong with the workflow please check ASAP!"
"Something went wrong there is no text file being produced. Please check ASAP.",
"emoji": True,
},
}
total_num_failed = 0
test_type_name = text_file_name.replace(".txt", "").replace("temp_results_", "").replace("_", " ").title()
payload = [
{
"type": "header",
"text": {
"type": "plain_text",
"text": "🤗 Results of the {} TRL {} example tests.".format(
os.environ.get("TEST_TYPE", ""), test_type_name
),
},
},
]
if total_num_failed > 0:
message += f"{total_num_failed} failed tests for example tests!"
for test_name, failed in final_results.items():
failed_table = tabulate(
[[test_name, "🟢" if not failed else "🔴"]],
headers=["Test Name", "Status"],
showindex="always",
tablefmt="grid",
maxcolwidths=[12],
)
message += "\n```\n" + failed_table + "\n```"
print(f"### {message}")
else:
payload.append(no_error_payload)
if os.environ.get("TEST_TYPE", "") != "":
from slack_sdk import WebClient
if len(message) > MAX_LEN_MESSAGE:
print(f"Truncating long message from {len(message)} to {MAX_LEN_MESSAGE}")
message = message[:MAX_LEN_MESSAGE] + "..."
if len(message) != 0:
md_report = {
"type": "section",
"text": {"type": "mrkdwn", "text": message},
}
payload.append(md_report)
action_button = {
"type": "section",
"text": {"type": "mrkdwn", "text": "*For more details:*"},
"accessory": {
"type": "button",
"text": {"type": "plain_text", "text": "Check Action results", "emoji": True},
"url": f"https://github.com/huggingface/trl/actions/runs/{os.environ['GITHUB_RUN_ID']}",
},
}
payload.append(action_button)
date_report = {
"type": "context",
"elements": [
{
"type": "plain_text",
"text": f"On Push - main {os.environ.get('TEST_TYPE')} test results for {date.today()}",
},
],
}
payload.append(date_report)
print(payload)
client = WebClient(token=os.environ.get("SLACK_API_TOKEN"))
client.chat_postMessage(channel=f"#{slack_channel_name}", text=message, blocks=payload)
if __name__ == "__main__":
args = parser.parse_args()
main(args.text_file_name, args.slack_channel_name)
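The text file consumed above is expected to hold one `exit_code,path/to/config.yaml` line per example run; a small sketch of the parsing with made-up file contents (the config paths are hypothetical):
# Hypothetical lines from the results file; 0 means the example passed, non-zero means it failed.
lines = ["0,examples/configs/ddp.yaml\n", "1,examples/configs/deepspeed_zero2.yaml\n"]

final_results = {}
for line in lines:
    result, config_name = line.split(",")
    config_name = config_name.split("/")[-1].split(".yaml")[0]
    final_results[config_name] = int(result)

assert final_results == {"ddp": 0, "deepspeed_zero2": 1}
total_num_failed = sum(final_results.values())  # here: one failing config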

153
scripts/log_reports.py Normal file
View File

@ -0,0 +1,153 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
from datetime import date
from pathlib import Path
from tabulate import tabulate
MAX_LEN_MESSAGE = 2900 # slack endpoint has a limit of 3001 characters
parser = argparse.ArgumentParser()
parser.add_argument("--slack_channel_name", default="trl-push-ci")
def main(slack_channel_name=None):
failed = []
passed = []
group_info = []
total_num_failed = 0
empty_file = len(list(Path().glob("*.log"))) == 0
total_empty_files = []
for log in Path().glob("*.log"):
section_num_failed = 0
i = 0
with open(log) as f:
for line in f:
line = json.loads(line)
i += 1
if line.get("nodeid", "") != "":
test = line["nodeid"]
if line.get("duration", None) is not None:
duration = f'{line["duration"]:.4f}'
if line.get("outcome", "") == "failed":
section_num_failed += 1
failed.append([test, duration, log.name.split("_")[0]])
total_num_failed += 1
else:
passed.append([test, duration, log.name.split("_")[0]])
empty_file = i == 0
group_info.append([str(log), section_num_failed, failed])
total_empty_files.append(empty_file)
os.remove(log)
failed = []
no_error_payload = {
"type": "section",
"text": {
"type": "plain_text",
"text": "🌞 There were no failures!"
if not any(total_empty_files)
else "Something went wrong there is at least one empty file - please check GH action results.",
"emoji": True,
},
}
message = ""
payload = [
{
"type": "header",
"text": {
"type": "plain_text",
"text": "🤗 Results of the {} TRL tests.".format(os.environ.get("TEST_TYPE", "")),
},
},
]
if total_num_failed > 0:
for i, (name, num_failed, failed_tests) in enumerate(group_info):
if num_failed > 0:
if num_failed == 1:
message += f"*{name}: {num_failed} failed test*\n"
else:
message += f"*{name}: {num_failed} failed tests*\n"
failed_table = []
for test in failed_tests:
failed_report = test[0].split("::")
# Truncate the last string as some test names might be long
failed_report[-1] = failed_report[-1][:30] + ".."
failed_table.append(failed_report)
failed_table = tabulate(
failed_table,
headers=["Test Location", "Test Case", "Test Name"],
showindex="always",
tablefmt="grid",
maxcolwidths=[12, 12, 12],
)
message += "\n```\n" + failed_table + "\n```"
if total_empty_files[i]:
message += f"\n*{name}: Warning! Empty file - please check the GitHub action job *\n"
print(f"### {message}")
else:
payload.append(no_error_payload)
if os.environ.get("TEST_TYPE", "") != "":
from slack_sdk import WebClient
if len(message) > MAX_LEN_MESSAGE:
message = f"There are {total_num_failed} failed tests in total ! Cannot display the entire summary - please check the action results directly"
if len(message) != 0:
md_report = {
"type": "section",
"text": {"type": "mrkdwn", "text": message},
}
payload.append(md_report)
action_button = {
"type": "section",
"text": {"type": "mrkdwn", "text": "*For more details:*"},
"accessory": {
"type": "button",
"text": {"type": "plain_text", "text": "Check Action results", "emoji": True},
"url": f"https://github.com/huggingface/trl/actions/runs/{os.environ['GITHUB_RUN_ID']}",
},
}
payload.append(action_button)
date_report = {
"type": "context",
"elements": [
{
"type": "plain_text",
"text": f"On Push main {os.environ.get('TEST_TYPE')} test results for {date.today()}",
},
],
}
payload.append(date_report)
print(payload)
client = WebClient(token=os.environ.get("SLACK_API_TOKEN"))
client.chat_postMessage(channel=f"#{slack_channel_name}", text=message, blocks=payload)
if __name__ == "__main__":
args = parser.parse_args()
main(args.slack_channel_name)
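Each *.log file scanned above is expected to contain one JSON object per line (presumably emitted by a pytest JSON-lines report plugin); a made-up line and the three fields the loop actually reads:
import json

# Hypothetical report line; only "nodeid", "duration" and "outcome" are used above.
raw = '{"nodeid": "tests/test_core.py::CoreTester::test_masked_mean", "duration": 0.0123, "outcome": "passed"}'
line = json.loads(raw)

if line.get("nodeid", "") != "":
    test = line["nodeid"]
    if line.get("duration", None) is not None:
        duration = f'{line["duration"]:.4f}'  # "0.0123"
        failed = line.get("outcome", "") == "failed"
assert not failed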

View File

@ -35,7 +35,7 @@ def main():
open_issues = repo.get_issues(state="open")
for issue in open_issues:
comments = sorted([comment for comment in issue.get_comments()], key=lambda i: i.created_at, reverse=True)
comments = sorted(issue.get_comments(), key=lambda i: i.created_at, reverse=True)
last_comment = comments[0] if len(comments) > 0 else None
if (
last_comment is not None

View File

@ -1,11 +1,2 @@
[metadata]
license_file = LICENSE
[isort]
ensure_newline_before_comments = True
force_grid_wrap = 0
include_trailing_comma = True
line_length = 119
lines_after_imports = 2
multi_line_output = 3
use_parentheses = True
license_file = LICENSE

View File

@ -53,11 +53,12 @@ To create the package for pypi.
8. Change the version in __init__.py and setup.py to X.X.X+1.dev0 (e.g. VERSION=1.18.3 -> 1.18.4.dev0).
Then push the change with a message 'set dev version'
"""
import os
from setuptools import find_packages, setup
__version__ = "0.7.8.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
__version__ = "0.8.0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
REQUIRED_PKGS = [
"torch>=1.4.0",
@ -68,7 +69,7 @@ REQUIRED_PKGS = [
"tyro>=0.5.11",
]
EXTRAS = {
"test": ["parameterized", "pytest", "pytest-xdist", "accelerate"],
"test": ["parameterized", "pytest", "pytest-xdist", "accelerate", "pytest-cov", "pytest-xdist"],
"peft": ["peft>=0.4.0"],
"diffusers": ["diffusers>=0.18.0"],
"deepspeed": ["deepspeed>=0.9.5"],
@ -79,34 +80,44 @@ EXTRAS["dev"] = []
for reqs in EXTRAS.values():
EXTRAS["dev"].extend(reqs)
setup(
name="trl",
license="Apache 2.0",
classifiers=[
"Development Status :: 2 - Pre-Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Natural Language :: English",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
],
url="https://github.com/huggingface/trl",
packages=find_packages(),
include_package_data=True,
install_requires=REQUIRED_PKGS,
extras_require=EXTRAS,
python_requires=">=3.7",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
zip_safe=False,
version=__version__,
description="Train transformer language models with reinforcement learning.",
keywords="ppo, transformers, huggingface, gpt2, language modeling, rlhf",
author="Leandro von Werra",
author_email="leandro.vonwerra@gmail.com",
)
try:
file_path = os.path.dirname(os.path.abspath(__file__))
os.symlink(os.path.join(file_path, "examples/scripts"), os.path.join(file_path, "trl/commands/scripts"))
setup(
name="trl",
license="Apache 2.0",
classifiers=[
"Development Status :: 2 - Pre-Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Natural Language :: English",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
],
url="https://github.com/huggingface/trl",
entry_points={
"console_scripts": ["trl=trl.commands.cli:main"],
},
include_package_data=True,
package_data={"trl": ["commands/scripts/config/*", "commands/scripts/*"]},
packages=find_packages(),
install_requires=REQUIRED_PKGS,
extras_require=EXTRAS,
python_requires=">=3.7",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
zip_safe=False,
version=__version__,
description="Train transformer language models with reinforcement learning.",
keywords="ppo, transformers, huggingface, gpt2, language modeling, rlhf",
author="Leandro von Werra",
author_email="leandro.vonwerra@gmail.com",
)
finally:
os.unlink(os.path.join(file_path, "trl/commands/scripts"))

0
tests/slow/__init__.py Normal file
View File

221
tests/slow/test_dpo_slow.py Normal file
View File

@ -0,0 +1,221 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import itertools
import tempfile
import unittest
import torch
from accelerate.utils.memory import release_memory
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from trl import DPOTrainer, is_peft_available
from ..testing_utils import require_bitsandbytes, require_peft, require_torch_gpu
from .testing_constants import DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST
if is_peft_available():
from peft import LoraConfig, PeftModel
@require_torch_gpu
class DPOTrainerSlowTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.dataset = load_dataset("trl-internal-testing/mlabonne-chatml-dpo-pairs-copy", split="train[:10%]")
cls.peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
r=8,
bias="none",
task_type="CAUSAL_LM",
)
cls.max_length = 128
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS)))
def test_dpo_bare_model(self, model_id, loss_type, pre_compute_logits):
"""
A test that tests the simple usage of `DPOTrainer` using a bare model in full precision.
"""
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=2,
remove_unused_columns=False,
gradient_accumulation_steps=2,
learning_rate=9e-1,
evaluation_strategy="steps",
fp16=True,
logging_strategy="no",
report_to="none",
)
# dpo train lora model
trainer = DPOTrainer(
model=model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=tokenizer,
train_dataset=self.dataset,
eval_dataset=self.dataset,
loss_type=loss_type,
precompute_ref_log_probs=pre_compute_logits,
max_length=self.max_length,
)
# train the model
trainer.train()
# save trained model or adapter
trainer.save_model()
release_memory(model, trainer)
@parameterized.expand(
list(
itertools.product(
MODELS_TO_TEST,
DPO_LOSS_TYPES,
DPO_PRECOMPUTE_LOGITS,
GRADIENT_CHECKPOINTING_KWARGS,
)
)
)
@require_peft
def test_dpo_peft_model(self, model_id, loss_type, pre_compute_logits, gradient_checkpointing_kwargs):
"""
A test that tests the simple usage of `DPOTrainer` using a peft model in full precision + different scenarios of gradient checkpointing.
"""
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=2,
remove_unused_columns=False,
gradient_accumulation_steps=2,
learning_rate=9e-1,
evaluation_strategy="steps",
fp16=True,
logging_strategy="no",
report_to="none",
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
)
# dpo train lora model
trainer = DPOTrainer(
model=model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=tokenizer,
train_dataset=self.dataset,
eval_dataset=self.dataset,
generate_during_eval=False,
loss_type=loss_type,
precompute_ref_log_probs=pre_compute_logits,
peft_config=self.peft_config,
max_length=self.max_length,
)
assert isinstance(trainer.model, PeftModel)
assert trainer.ref_model is None
# train the model
trainer.train()
# save trained model or adapter
trainer.save_model()
release_memory(model, trainer)
@parameterized.expand(
list(
itertools.product(
MODELS_TO_TEST,
DPO_LOSS_TYPES,
DPO_PRECOMPUTE_LOGITS,
GRADIENT_CHECKPOINTING_KWARGS,
)
)
)
@require_bitsandbytes
@require_peft
def test_dpo_peft_model_qlora(self, model_id, loss_type, pre_compute_logits, gradient_checkpointing_kwargs):
"""
A test that tests the simple usage of `DPOTrainer` using QLoRA + different scenarios of gradient checkpointing.
"""
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_id)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=2,
remove_unused_columns=False,
gradient_accumulation_steps=2,
learning_rate=9e-1,
evaluation_strategy="steps",
fp16=True,
logging_strategy="no",
report_to="none",
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
)
# dpo train lora model
trainer = DPOTrainer(
model=model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=tokenizer,
train_dataset=self.dataset,
eval_dataset=self.dataset,
generate_during_eval=False,
loss_type=loss_type,
precompute_ref_log_probs=pre_compute_logits,
peft_config=self.peft_config,
max_length=self.max_length,
)
assert isinstance(trainer.model, PeftModel)
assert trainer.ref_model is None
# train the model
trainer.train()
# save trained model or adapter
trainer.save_model()
release_memory(model, trainer)

390
tests/slow/test_sft_slow.py Normal file
View File

@ -0,0 +1,390 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import itertools
import tempfile
import unittest
import torch
from accelerate.utils.memory import release_memory
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from trl import SFTTrainer, is_peft_available
from trl.models.utils import setup_chat_format
from ..testing_utils import require_bitsandbytes, require_peft, require_torch_gpu, require_torch_multi_gpu
from .testing_constants import DEVICE_MAP_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST, PACKING_OPTIONS
if is_peft_available():
from peft import LoraConfig, PeftModel
@require_torch_gpu
class SFTTrainerSlowTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.train_dataset = load_dataset("imdb", split="train[:10%]")
cls.eval_dataset = load_dataset("imdb", split="test[:10%]")
cls.dataset_text_field = "text"
cls.max_seq_length = 128
cls.peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
r=8,
bias="none",
task_type="CAUSAL_LM",
)
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
def test_sft_trainer_str(self, model_name, packing):
"""
Simply tests if passing a simple str to `SFTTrainer` loads and runs the trainer
as expected.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
args = TrainingArguments(
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
per_device_train_batch_size=2,
max_steps=10,
)
trainer = SFTTrainer(
model_name,
args=args,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
packing=packing,
dataset_text_field=self.dataset_text_field,
max_seq_length=self.max_seq_length,
)
trainer.train()
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
def test_sft_trainer_transformers(self, model_name, packing):
"""
Simply tests if passing a transformers model to `SFTTrainer` loads and runs the trainer
as expected.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
args = TrainingArguments(
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
per_device_train_batch_size=2,
max_steps=10,
)
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
trainer = SFTTrainer(
model,
args=args,
tokenizer=tokenizer,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
packing=packing,
dataset_text_field=self.dataset_text_field,
max_seq_length=self.max_seq_length,
)
trainer.train()
release_memory(model, trainer)
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
@require_peft
def test_sft_trainer_peft(self, model_name, packing):
"""
Simply tests if passing a transformers model + peft config to `SFTTrainer` loads and runs the trainer
as expected.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
args = TrainingArguments(
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
per_device_train_batch_size=2,
max_steps=10,
fp16=True,
)
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
trainer = SFTTrainer(
model,
args=args,
tokenizer=tokenizer,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
packing=packing,
dataset_text_field=self.dataset_text_field,
max_seq_length=self.max_seq_length,
peft_config=self.peft_config,
)
assert isinstance(trainer.model, PeftModel)
trainer.train()
release_memory(model, trainer)
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
def test_sft_trainer_transformers_mp(self, model_name, packing):
"""
Simply tests if passing a transformers model to `SFTTrainer` loads and runs the trainer
as expected in mixed precision.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
args = TrainingArguments(
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
per_device_train_batch_size=2,
max_steps=10,
fp16=True, # this is sufficient to enable amp
)
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
trainer = SFTTrainer(
model,
args=args,
tokenizer=tokenizer,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
packing=packing,
dataset_text_field=self.dataset_text_field,
max_seq_length=self.max_seq_length,
)
trainer.train()
release_memory(model, trainer)
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS)))
def test_sft_trainer_transformers_mp_gc(self, model_name, packing, gradient_checkpointing_kwargs):
"""
Simply tests if passing a transformers model to `SFTTrainer` loads and runs the trainer
as expected in mixed precision + different scenarios of gradient_checkpointing.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
args = TrainingArguments(
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
per_device_train_batch_size=2,
max_steps=10,
fp16=True, # this is sufficient to enable amp
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
)
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
trainer = SFTTrainer(
model,
args=args,
tokenizer=tokenizer,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
packing=packing,
dataset_text_field=self.dataset_text_field,
max_seq_length=self.max_seq_length,
)
trainer.train()
release_memory(model, trainer)
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS)))
@require_peft
def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient_checkpointing_kwargs):
"""
Simply tests if passing a transformers model + PEFT to `SFTTrainer` loads and runs the trainer
as expected in mixed precision + different scenarios of gradient_checkpointing.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
args = TrainingArguments(
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
per_device_train_batch_size=2,
max_steps=10,
fp16=True, # this is sufficient to enable amp
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
)
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
trainer = SFTTrainer(
model,
args=args,
tokenizer=tokenizer,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
packing=packing,
dataset_text_field=self.dataset_text_field,
max_seq_length=self.max_seq_length,
peft_config=self.peft_config,
)
assert isinstance(trainer.model, PeftModel)
trainer.train()
release_memory(model, trainer)
@parameterized.expand(
list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, DEVICE_MAP_OPTIONS))
)
@require_torch_multi_gpu
def test_sft_trainer_transformers_mp_gc_device_map(
self, model_name, packing, gradient_checkpointing_kwargs, device_map
):
"""
Simply tests if passing a transformers model to `SFTTrainer` loads and runs the trainer
as expected in mixed precision + different scenarios of gradient_checkpointing (single, multi-gpu, etc).
"""
with tempfile.TemporaryDirectory() as tmp_dir:
args = TrainingArguments(
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
per_device_train_batch_size=2,
max_steps=10,
fp16=True, # this is sufficient to enable amp
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map)
tokenizer = AutoTokenizer.from_pretrained(model_name)
trainer = SFTTrainer(
model,
args=args,
tokenizer=tokenizer,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
packing=packing,
dataset_text_field=self.dataset_text_field,
max_seq_length=self.max_seq_length,
)
trainer.train()
release_memory(model, trainer)
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS)))
@require_peft
@require_bitsandbytes
def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gradient_checkpointing_kwargs):
"""
Simply tests if passing a transformers model + PEFT + bnb to `SFTTrainer` loads and runs the trainer
as expected in mixed precision + different scenarios of gradient_checkpointing.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
args = TrainingArguments(
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
per_device_train_batch_size=2,
max_steps=10,
fp16=True, # this is sufficient to enable amp
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
)
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
trainer = SFTTrainer(
model,
args=args,
tokenizer=tokenizer,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
packing=packing,
dataset_text_field=self.dataset_text_field,
max_seq_length=self.max_seq_length,
peft_config=self.peft_config,
)
assert isinstance(trainer.model, PeftModel)
trainer.train()
release_memory(model, trainer)
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
@require_peft
@require_bitsandbytes
def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):
"""
Simply tests if using setup_chat_format with a transformers model + peft + bnb config to `SFTTrainer` loads and runs the trainer
as expected.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
train_dataset = load_dataset("trl-internal-testing/dolly-chatml-sft", split="train")
args = TrainingArguments(
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
per_device_train_batch_size=2,
max_steps=10,
fp16=True,
)
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model, tokenizer = setup_chat_format(model, tokenizer)
trainer = SFTTrainer(
model,
args=args,
tokenizer=tokenizer,
train_dataset=train_dataset,
packing=packing,
max_seq_length=self.max_seq_length,
peft_config=self.peft_config,
)
assert isinstance(trainer.model, PeftModel)
trainer.train()
release_memory(model, trainer)

View File

@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: push them under trl-org
MODELS_TO_TEST = [
"HuggingFaceM4/tiny-random-LlamaForCausalLM",
"HuggingFaceM4/tiny-random-MistralForCausalLM",
]
# We could have skipped declaring these variables, but let's be explicit
PACKING_OPTIONS = [True, False]
GRADIENT_CHECKPOINTING_KWARGS = [None, {"use_reentrant": False}, {"use_reentrant": True}]
DEVICE_MAP_OPTIONS = [{"": 0}, "auto"]
DPO_LOSS_TYPES = ["sigmoid", "ipo", "kto_pair"]
DPO_PRECOMPUTE_LOGITS = [True, False]
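These constants feed the parameterized.expand decorators in the slow tests above; each test method is generated once per element of the cartesian product, for example:
import itertools

MODELS_TO_TEST = [
    "HuggingFaceM4/tiny-random-LlamaForCausalLM",
    "HuggingFaceM4/tiny-random-MistralForCausalLM",
]
PACKING_OPTIONS = [True, False]

# 2 models x 2 packing options -> 4 generated variants of e.g. test_sft_trainer_str.
cases = list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS))
assert len(cases) == 4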

View File

@ -59,7 +59,7 @@ class BestOfNSamplerTester(unittest.TestCase):
for q, expected_length in various_queries_formats:
results = best_of_n.generate(q)
self.assertIsInstance(results, list)
assert isinstance(results, list)
assert len(results) == expected_length
def test_different_sample_sizes_and_n_candidates_values(self):

40
tests/test_cli.py Normal file
View File

@ -0,0 +1,40 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
import sys
import unittest
@unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
def test_sft_cli():
try:
subprocess.run(
"trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name imdb --learning_rate 1e-4 --lr_scheduler_type cosine",
shell=True,
check=True,
)
except BaseException as exc:
raise AssertionError("An error occured while running the CLI, please double check") from exc
@unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
def test_dpo_cli():
try:
subprocess.run(
"trl dpo --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/Anthropic-hh-rlhf-processed --learning_rate 1e-4 --lr_scheduler_type cosine",
shell=True,
check=True,
)
except BaseException as exc:
raise AssertionError("An error occured while running the CLI, please double check") from exc

View File

@ -30,13 +30,13 @@ class CoreTester(unittest.TestCase):
cls.test_input_unmasked = cls.test_input[1:3]
def test_masked_mean(self):
self.assertEqual(torch.mean(self.test_input_unmasked), masked_mean(self.test_input, self.test_mask))
assert torch.mean(self.test_input_unmasked) == masked_mean(self.test_input, self.test_mask)
def test_masked_var(self):
self.assertEqual(torch.var(self.test_input_unmasked), masked_var(self.test_input, self.test_mask))
assert torch.var(self.test_input_unmasked) == masked_var(self.test_input, self.test_mask)
def test_masked_whiten(self):
whiten_unmasked = whiten(self.test_input_unmasked)
whiten_masked = masked_whiten(self.test_input, self.test_mask)[1:3]
diffs = (whiten_unmasked - whiten_masked).sum()
self.assertAlmostEqual(diffs, 0)
assert abs(diffs.item()) < 0.00001
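For reference, a masked mean over a boolean mask can be sketched as below; this is only an illustrative re-implementation of the semantics the test checks, not the actual trl.core code:
import torch

def masked_mean_sketch(values, mask):
    # Average only over the positions selected by the mask.
    return (values * mask).sum() / mask.sum()

values = torch.tensor([1.0, 2.0, 3.0, 4.0])
mask = torch.tensor([0.0, 1.0, 1.0, 0.0])
# Matches the plain mean over the unmasked slice, mirroring the assertion above.
assert masked_mean_sketch(values, mask) == torch.mean(values[1:3])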

View File

@ -31,18 +31,21 @@ class DataCollatorForCompletionOnlyLMTester(unittest.TestCase):
self.instruction_template = "\n### User:"
self.response_template = "\n### Assistant:"
# GPT2Tokenizer: [198, 21017, 11787, 25] -> [11787, 25]
# GPT2Tokenizer: [198, 21017, 11787, 25] -> [21017, 11787, 25]
# Llama2Tokenizer: [29871, 13, 2277, 29937, 4911, 29901] -> [2277, 29937, 4911, 29901]
# Note: If this test is ever switched to Llama2Tokenizer, this should be double checked,
# and possibly switched back to [2:] instead of [1:].
# With GPT2Tokenizer, [1:] is correct - we want the 21017 token included, which is ###.
self.tokenized_instruction_w_context = self.tokenizer.encode(
self.instruction_template, add_special_tokens=False
)[2:]
)[1:]
# GPT2Tokenizer: [198, 21017, 15286, 25] -> [15286, 25]
# Llama2Tokenizer: [29871, 13, 2277, 29937, 4007, 22137, 29901] -> [2277, 29937, 4007, 22137, 29901]
self.tokenized_response_w_context = self.tokenizer.encode(self.response_template, add_special_tokens=False)[2:]
# Plain check on string
self.assertIn(self.response_template, self.instruction)
assert self.response_template in self.instruction
self.tokenized_instruction = self.tokenizer.encode(self.instruction, add_special_tokens=False)
# Test the fix for #598
@ -57,6 +60,28 @@ class DataCollatorForCompletionOnlyLMTester(unittest.TestCase):
)
self.collator.torch_call([self.tokenized_instruction])
# Test for PR #1185
# We pass in a string where the first user template is different than the rest.
# Usually this would happen due to context-sensitive tokenization, but here we
# explicitly change the template to test the fix.
self.instruction = """## User: First instruction
### Assistant: First response
### User: Second instruction
### Assistant: Second response"""
self.tokenized_instruction = self.tokenizer.encode(self.instruction, add_special_tokens=False)
self.collator = DataCollatorForCompletionOnlyLM(
self.tokenized_response_w_context, self.tokenized_instruction_w_context, tokenizer=self.tokenizer
)
collator_output = self.collator.torch_call([self.tokenized_instruction])
collator_text = self.tokenizer.decode(
collator_output["labels"][torch.where(collator_output["labels"] != -100)]
)
expected_text = " First response\n\n Second response" ""
assert collator_text == expected_text
def test_data_collator_handling_of_long_sequences(self):
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
self.instruction = """### System: You are a helpful assistant.
@ -69,7 +94,7 @@ class DataCollatorForCompletionOnlyLMTester(unittest.TestCase):
self.collator = DataCollatorForCompletionOnlyLM(self.response_template, tokenizer=self.tokenizer)
encoded_instance = self.collator.torch_call([self.tokenized_instruction])
result = torch.all(encoded_instance["labels"] == -100)
self.assertTrue(result, "Not all values in the tensor are -100.")
assert result, "Not all values in the tensor are -100."
# check DataCollatorForCompletionOnlyLM using response template and instruction template
self.instruction_template = "\n### User:"
@ -78,4 +103,4 @@ class DataCollatorForCompletionOnlyLMTester(unittest.TestCase):
)
encoded_instance = self.collator.torch_call([self.tokenized_instruction])
result = torch.all(encoded_instance["labels"] == -100)
self.assertTrue(result, "Not all values in the tensor are -100.")
assert result, "Not all values in the tensor are -100."

View File

@ -0,0 +1,142 @@
import unittest
from typing import Callable
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl.extras.dataset_formatting import get_formatting_func_from_dataset
from trl.models.utils import ChatMlSpecialTokens, setup_chat_format
class DatasetFormattingTestCase(unittest.TestCase):
def setUp(self):
self.llama_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
self.chatml_tokenizer = AutoTokenizer.from_pretrained("philschmid/gpt2-chatml-tokenizer")
def test_get_formatting_func_from_dataset_with_chatml_messages(self):
dataset = Dataset.from_dict(
{
"messages": [
[
{"role": "system", "content": "You are helpful"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi, how can I help you?"},
]
]
}
)
# Llama tokenizer
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
assert isinstance(formatting_func, Callable)
formatted_text = formatting_func(dataset[0])
expected = "<s>[INST] <<SYS>>\nYou are helpful\n<</SYS>>\n\nHello [/INST] Hi, how can I help you? </s>"
assert formatted_text == expected
formatted_text = formatting_func(dataset[0:1])
assert formatted_text == [expected]
# ChatML tokenizer
formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer)
formatted_text = formatting_func(dataset[0])
expected = "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
assert formatted_text == expected
formatted_text = formatting_func(dataset[0:1])
assert formatted_text == [expected]
def test_get_formatting_func_from_dataset_with_chatml_conversations(self):
dataset = Dataset.from_dict(
{
"conversations": [
[
{"role": "system", "content": "You are helpful"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi, how can I help you?"},
]
]
}
)
# Llama tokenizer
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
assert isinstance(formatting_func, Callable)
formatted_text = formatting_func(dataset[0])
expected = "<s>[INST] <<SYS>>\nYou are helpful\n<</SYS>>\n\nHello [/INST] Hi, how can I help you? </s>"
assert formatted_text == expected
formatted_text = formatting_func(dataset[0:1])
assert formatted_text == [expected]
# ChatML tokenizer
formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer)
formatted_text = formatting_func(dataset[0])
expected = "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
assert formatted_text == expected
formatted_text = formatting_func(dataset[0:1])
assert formatted_text == [expected]
def test_get_formatting_func_from_dataset_with_instruction(self):
dataset = Dataset.from_list(
[{"prompt": "What is 2+2?", "completion": "4"}, {"prompt": "What is 3+3?", "completion": "6"}]
)
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
assert formatting_func is not None
assert isinstance(formatting_func, Callable)
formatted_text = formatting_func(dataset[0])
assert formatted_text == "<s>[INST] What is 2+2? [/INST] 4 </s>"
formatted_text = formatting_func(dataset[0:1])
assert formatted_text == ["<s>[INST] What is 2+2? [/INST] 4 </s>"]
def test_get_formatting_func_from_dataset_from_hub(self):
ds_1 = load_dataset("philschmid/trl-test-instruction", split="train")
ds_2 = load_dataset("philschmid/dolly-15k-oai-style", split="train")
for ds in [ds_1, ds_2]:
formatting_func = get_formatting_func_from_dataset(ds, self.llama_tokenizer)
assert formatting_func is not None
assert isinstance(formatting_func, Callable)
ds_3 = load_dataset("philschmid/guanaco-sharegpt-style", split="train")
formatting_func = get_formatting_func_from_dataset(ds_3, self.llama_tokenizer)
assert formatting_func is None
def test_get_formatting_func_from_dataset_with_unknown_format(self):
dataset = Dataset.from_dict({"text": "test"})
formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
assert formatting_func is None
class SetupChatFormatTestCase(unittest.TestCase):
def setUp(self):
self.tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
self.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
def test_setup_chat_format(self):
original_tokenizer_len = len(self.tokenizer)
modified_model, modified_tokenizer = setup_chat_format(
self.model, self.tokenizer, format="chatml", resize_to_multiple_of=64
)
_chatml = ChatMlSpecialTokens()
# Check if special tokens are correctly set
assert modified_tokenizer.eos_token == "<|im_end|>"
assert modified_tokenizer.pad_token == "<|im_end|>"
assert modified_tokenizer.bos_token == "<|im_start|>"
assert modified_tokenizer.eos_token == _chatml.eos_token
assert modified_tokenizer.pad_token == _chatml.pad_token
assert modified_tokenizer.bos_token == _chatml.bos_token
assert len(modified_tokenizer) == (original_tokenizer_len + 2)
assert (self.model.get_input_embeddings().weight.shape[0] % 64) == 0
assert self.model.get_input_embeddings().weight.shape[0] == (original_tokenizer_len + 64)
def test_example_with_setup_model(self):
modified_model, modified_tokenizer = setup_chat_format(
self.model,
self.tokenizer,
)
messages = [
{"role": "system", "content": "You are helpful"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi, how can I help you?"},
]
prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False)
assert (
prompt
== "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n"
)
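The embedding-size assertions at the end of `test_setup_chat_format` follow from rounding the enlarged vocabulary up to the requested multiple; a small sketch of that arithmetic, assuming the 32,000-token Llama test tokenizer:

import math

original_vocab = 32000   # len(tokenizer) for the Llama test tokenizer
added_tokens = 2         # <|im_start|> and <|im_end|>
multiple = 64            # resize_to_multiple_of=64

resized = math.ceil((original_vocab + added_tokens) / multiple) * multiple
assert resized % multiple == 0
assert resized == original_vocab + 64   # 32064, as the test expects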

View File

@ -68,13 +68,13 @@ class DDPOTrainerTester(unittest.TestCase):
clip_range = 0.0001
ratio = torch.tensor([1.0])
loss = self.trainer.loss(advantage, clip_range, ratio)
self.assertEqual(loss.item(), 1.0)
assert loss.item() == 1.0
def test_generate_samples(self):
samples, output_pairs = self.trainer._generate_samples(1, 2)
self.assertEqual(len(samples), 1)
self.assertEqual(len(output_pairs), 1)
self.assertEqual(len(output_pairs[0][0]), 2)
assert len(samples) == 1
assert len(output_pairs) == 1
assert len(output_pairs[0][0]) == 2
def test_calculate_loss(self):
samples, _ = self.trainer._generate_samples(1, 2)
@ -87,16 +87,16 @@ class DDPOTrainerTester(unittest.TestCase):
prompt_embeds = sample["prompt_embeds"]
advantage = torch.tensor([1.0], device=prompt_embeds.device)
self.assertEqual(latents.shape, (1, 4, 64, 64))
self.assertEqual(next_latents.shape, (1, 4, 64, 64))
self.assertEqual(log_probs.shape, (1,))
self.assertEqual(timesteps.shape, (1,))
self.assertEqual(prompt_embeds.shape, (2, 77, 32))
assert latents.shape == (1, 4, 64, 64)
assert next_latents.shape == (1, 4, 64, 64)
assert log_probs.shape == (1,)
assert timesteps.shape == (1,)
assert prompt_embeds.shape == (2, 77, 32)
loss, approx_kl, clipfrac = self.trainer.calculate_loss(
latents, timesteps, next_latents, log_probs, advantage, prompt_embeds
)
self.assertTrue(torch.isfinite(loss.cpu()))
assert torch.isfinite(loss.cpu())
@require_diffusers

View File

@ -22,7 +22,7 @@ from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokeni
from trl import DPOTrainer
from .testing_utils import require_no_wandb, require_peft
from .testing_utils import require_bitsandbytes, require_no_wandb, require_peft
class DPOTrainerTester(unittest.TestCase):
@ -129,14 +129,14 @@ class DPOTrainerTester(unittest.TestCase):
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
assert trainer.state.log_history[-1]["train_loss"] is not None
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))
assert not torch.equal(param, new_param)
def test_dpo_trainer_without_providing_ref_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@ -167,14 +167,14 @@ class DPOTrainerTester(unittest.TestCase):
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
assert trainer.state.log_history[-1]["train_loss"] is not None
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))
assert not torch.equal(param, new_param)
@require_peft
@mark.peft_test
@ -218,7 +218,7 @@ class DPOTrainerTester(unittest.TestCase):
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
assert trainer.state.log_history[-1]["train_loss"] is not None
# check the params have changed
for n, param in previous_trainable_params.items():
@ -226,7 +226,78 @@ class DPOTrainerTester(unittest.TestCase):
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))
assert not torch.equal(param, new_param)
def test_dpo_trainer_padding_token_is_none(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=1,
learning_rate=9e-1,
evaluation_strategy="steps",
)
dummy_dataset = self._init_dummy_dataset()
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
tokenizer.pad_token = None
with self.assertRaisesRegex(
ValueError,
expected_regex=r"Padding is enabled, but the tokenizer is not configured with a padding token."
r" Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\)"
r" before calling the trainer.",
):
trainer = DPOTrainer(
model=self.model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
)
trainer.train()
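The expected error message above also spells out the remedy; a minimal sketch of the recommended setup before constructing the trainer, assuming a GPT-2-style checkpoint that ships without a padding token:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")

# GPT-2-style tokenizers have no pad token by default; reusing the EOS token
# is the fix the error message recommends.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token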
def test_dpo_trainer_w_dataset_num_proc(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=1,
learning_rate=9e-1,
evaluation_strategy="steps",
)
dummy_dataset = self._init_dummy_dataset()
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
tokenizer.pad_token = None
with self.assertRaisesRegex(
ValueError,
expected_regex=r"Padding is enabled, but the tokenizer is not configured with a padding token."
r" Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\)"
r" before calling the trainer.",
):
trainer = DPOTrainer(
model=self.model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
dataset_num_proc=5,
)
trainer.train()
@require_no_wandb
def test_dpo_trainer_generate_during_eval_no_wandb(self):
@ -313,3 +384,267 @@ class DPOTrainerTester(unittest.TestCase):
AutoModelForCausalLM.from_pretrained(tmp_dir)
except OSError:
self.fail("Loading the saved peft adapter failed")
@require_peft
@require_bitsandbytes
@mark.peft_test
def test_dpo_lora_bf16_autocast_llama(self):
# Note this test only works on compute capability > 7 GPU devices
from peft import LoraConfig
model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
# lora model
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
bf16=True,
)
dummy_dataset = self._init_dummy_dataset()
# dpo train lora model with a lora config
trainer = DPOTrainer(
model=model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
peft_config=lora_config,
generate_during_eval=True,
)
# train the model
trainer.train()
# save peft adapter
trainer.save_model()
@parameterized.expand(
[
["gpt2", "sigmoid", False, False],
["gpt2", "sigmoid", False, True],
["gpt2", "sigmoid", True, False],
["gpt2", "sigmoid", True, True],
["gpt2", "ipo", False, False],
["gpt2", "ipo", False, True],
["gpt2", "ipo", True, False],
["gpt2", "ipo", True, True],
["gpt2", "kto_pair", False, False],
["gpt2", "kto_pair", False, True],
["gpt2", "kto_pair", True, False],
["gpt2", "kto_pair", True, True],
]
)
@require_bitsandbytes
@require_peft
@mark.peft_test
@unittest.skip("You need a GPU with bf16 support in order to run these tests")
def test_dpo_lora_bf16_autocast(self, name, loss_type, pre_compute, gen_during_eval):
# Note this test only works on compute capability > 7 GPU devices
from peft import LoraConfig
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
# lora model
model = AutoModelForCausalLM.from_pretrained(self.model_id, load_in_4bit=True)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
bf16=True,
)
dummy_dataset = self._init_dummy_dataset()
# dpo train lora model with a lora config
trainer = DPOTrainer(
model=model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
peft_config=lora_config,
generate_during_eval=gen_during_eval,
loss_type=loss_type,
precompute_ref_log_probs=pre_compute,
)
# train the model
trainer.train()
# save peft adapter
trainer.save_model()
@require_peft
def test_dpo_lora_tags(self):
from peft import LoraConfig
model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
# lora model
model = AutoModelForCausalLM.from_pretrained(model_id)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
)
dummy_dataset = self._init_dummy_dataset()
# dpo train lora model with a lora config
trainer = DPOTrainer(
model=model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
peft_config=lora_config,
)
assert trainer.model.model_tags == trainer._tag_names
@require_peft
def test_dpo_tags(self):
model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# lora model
model = AutoModelForCausalLM.from_pretrained(model_id)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
)
dummy_dataset = self._init_dummy_dataset()
# dpo train lora model with a lora config
trainer = DPOTrainer(
model=model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
)
assert trainer.model.model_tags == trainer._tag_names
@require_peft
@mark.peft_test
def test_dpo_lora_force_use_ref(self):
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
# lora model
model = AutoModelForCausalLM.from_pretrained(self.model_id)
model_peft = get_peft_model(model, lora_config)
ref_model = AutoModelForCausalLM.from_pretrained(self.model_id)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
)
dummy_dataset = self._init_dummy_dataset()
with self.assertRaises(ValueError):
# passing a peft_model as model and ref_model should error out,
# unless you pass `force_use_ref_model`
trainer = DPOTrainer(
model=model_peft,
ref_model=ref_model,
beta=0.1,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
peft_config=lora_config,
)
trainer = DPOTrainer(
model=model_peft,
ref_model=ref_model,
beta=0.1,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
peft_config=lora_config,
force_use_ref_model=True,
)
# train the model
trainer.train()

View File

@ -38,12 +38,12 @@ class TextHistoryTest(unittest.TestCase):
tokens = torch.tensor([1, 2, 3])
history = TextHistory(text, tokens)
self.assertEqual(history.text, text)
self.assertTrue(torch.equal(history.tokens, tokens))
self.assertTrue(torch.equal(history.token_masks, torch.zeros_like(tokens)))
assert history.text == text
assert torch.equal(history.tokens, tokens)
assert torch.equal(history.token_masks, torch.zeros_like(tokens))
history = TextHistory(text, tokens, system=False)
self.assertTrue(torch.equal(history.token_masks, torch.ones_like(tokens)))
assert torch.equal(history.token_masks, torch.ones_like(tokens))
def test_text_history_append_segment(self):
text = "Hello there!"
@ -51,26 +51,26 @@ class TextHistoryTest(unittest.TestCase):
history = TextHistory(text, tokens)
history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]), system=False)
self.assertEqual(history.text, text + "General Kenobi!")
self.assertTrue(torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6])))
self.assertTrue(torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1])))
assert history.text == (text + "General Kenobi!")
assert torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6]))
assert torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1]))
history.append_segment("You are a bold one!", torch.tensor([7, 8, 9]))
self.assertEqual(history.text, text + "General Kenobi!" + "You are a bold one!")
self.assertTrue(torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])))
self.assertTrue(torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1, 0, 0, 0])))
assert history.text == ((text + "General Kenobi!") + "You are a bold one!")
assert torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]))
assert torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1, 0, 0, 0]))
def test_text_history_complete(self):
text = "Hello there!"
tokens = torch.tensor([1, 2, 3])
history = TextHistory(text, tokens)
history.complete()
self.assertTrue(history.completed)
self.assertFalse(history.truncated)
assert history.completed
assert not history.truncated
history.complete(truncated=True)
self.assertTrue(history.completed)
self.assertTrue(history.truncated)
assert history.completed
assert history.truncated
def test_text_history_last_segment(self):
text = "Hello there!"
@ -78,7 +78,7 @@ class TextHistoryTest(unittest.TestCase):
history = TextHistory(text, tokens)
history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]))
history.append_segment("You are a bold one!", torch.tensor([7, 8, 9]))
self.assertEqual(history.last_text_segment, "You are a bold one!")
assert history.last_text_segment == "You are a bold one!"
def test_text_history_split_query_response(self):
text = "Hello there!"
@ -88,9 +88,9 @@ class TextHistoryTest(unittest.TestCase):
history.append_segment("You are a bold one!", torch.tensor([7, 8, 9]), system=True)
query, response, mask = history.split_query_response_tokens()
self.assertTrue(torch.equal(query, torch.tensor([1, 2, 3])))
self.assertTrue(torch.equal(response, torch.tensor([4, 5, 6, 7, 8, 9])))
self.assertTrue(torch.equal(mask, torch.tensor([1, 1, 1, 0, 0, 0])))
assert torch.equal(query, torch.tensor([1, 2, 3]))
assert torch.equal(response, torch.tensor([4, 5, 6, 7, 8, 9]))
assert torch.equal(mask, torch.tensor([1, 1, 1, 0, 0, 0]))
class TextEnvironmentTester(unittest.TestCase):
@ -112,10 +112,10 @@ class TextEnvironmentTester(unittest.TestCase):
reward_fn=lambda x: torch.tensor(1),
prompt="I am a prompt!\n",
)
self.assertEqual(env.prompt, "I am a prompt!\n")
self.assertEqual(list(env.tools.keys()), ["DummyTool"])
self.assertTrue(isinstance(env.tools["DummyTool"], DummyTool))
self.assertEqual(env.reward_fn("Hello there!"), 1)
assert env.prompt == "I am a prompt!\n"
assert list(env.tools.keys()) == ["DummyTool"]
assert isinstance(env.tools["DummyTool"], DummyTool)
assert env.reward_fn("Hello there!") == 1
def test_text_environment_generate(self):
generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.gpt2_tokenizer.eos_token_id}
@ -138,7 +138,7 @@ class TextEnvironmentTester(unittest.TestCase):
generations_single = [env._generate_batched([inputs], batch_size=1)[0] for inputs in model_inputs]
generations_single = self.gpt2_tokenizer.batch_decode(generations_single)
self.assertEqual(generations_single, generations_batched)
assert generations_single == generations_batched
def test_text_environment_tool_call_parsing(self):
string_valid = "Something something <request><Tool1>Hello there!<call>"
@ -155,24 +155,24 @@ class TextEnvironmentTester(unittest.TestCase):
prompt="I am a prompt!\n",
)
tool, response = env.parse_tool_call(string_valid)
self.assertEqual(tool, "Tool1")
self.assertEqual(response, "Hello there!")
assert tool == "Tool1"
assert response == "Hello there!"
tool, response = env.parse_tool_call(string_invalid_request)
self.assertEqual(tool, None)
self.assertEqual(response, None)
assert tool is None
assert response is None
tool, response = env.parse_tool_call(string_invalid_call)
self.assertEqual(tool, None)
self.assertEqual(response, None)
assert tool is None
assert response is None
tool, response = env.parse_tool_call(string_invalid_tool)
self.assertEqual(tool, None)
self.assertEqual(response, None)
assert tool is None
assert response is None
tool, response = env.parse_tool_call(string_invalid_random)
self.assertEqual(tool, None)
self.assertEqual(response, None)
assert tool is None
assert response is None
def test_text_environment_tool_truncation(self):
env = TextEnvironment(
@ -185,19 +185,19 @@ class TextEnvironmentTester(unittest.TestCase):
env.max_tool_response = 100
history = env.step(TextHistory("<request><dummy>Hello there!<call>", torch.tensor([1, 2, 3])))
self.assertEqual(len(history.last_text_segment) - len(env.response_token), 100)
assert (len(history.last_text_segment) - len(env.response_token)) == 100
env.max_tool_response = 500
history = env.step(TextHistory("<request><dummy>Hello there!<call>", torch.tensor([1, 2, 3])))
self.assertEqual(len(history.last_text_segment) - len(env.response_token), 500)
assert (len(history.last_text_segment) - len(env.response_token)) == 500
env.max_tool_response = 1001
history = env.step(TextHistory("<request><dummy>Hello there!<call>", torch.tensor([1, 2, 3])))
self.assertEqual(len(history.last_text_segment) - len(env.response_token), 1000)
assert (len(history.last_text_segment) - len(env.response_token)) == 1000
env.max_tool_response = 2000
history = env.step(TextHistory("<request><dummy>Hello there!<call>", torch.tensor([1, 2, 3])))
self.assertEqual(len(history.last_text_segment) - len(env.response_token), 1000)
assert (len(history.last_text_segment) - len(env.response_token)) == 1000
@patch.object(TextEnvironment, "generate", side_effect=dummy_generate)
def test_text_environment_max_calls(self, mock_generate):
@ -211,20 +211,20 @@ class TextEnvironmentTester(unittest.TestCase):
env.max_turns = 1
_, _, _, _, histories = env.run(["test"])
self.assertEqual(
histories[0].text, "I am a prompt!\n" + "test" + 1 * "<request><DummyTool>test<call>test<response>"
assert histories[0].text == (
("I am a prompt!\n" + "test") + (1 * "<request><DummyTool>test<call>test<response>")
)
env.max_turns = 2
_, _, _, _, histories = env.run(["test"])
self.assertEqual(
histories[0].text, "I am a prompt!\n" + "test" + 2 * "<request><DummyTool>test<call>test<response>"
assert histories[0].text == (
("I am a prompt!\n" + "test") + (2 * "<request><DummyTool>test<call>test<response>")
)
env.max_turns = 4
_, _, _, _, histories = env.run(["test"])
self.assertEqual(
histories[0].text, "I am a prompt!\n" + "test" + 4 * "<request><DummyTool>test<call>test<response>"
assert histories[0].text == (
("I am a prompt!\n" + "test") + (4 * "<request><DummyTool>test<call>test<response>")
)
def test_text_environment_compute_rewards(self):
@ -240,7 +240,7 @@ class TextEnvironmentTester(unittest.TestCase):
histories = env.compute_reward(histories)
for i in range(8):
self.assertEqual(histories[i].reward, i)
assert histories[i].reward == i
@patch.object(TextEnvironment, "generate", side_effect=dummy_generate)
def test_text_environment_run(self, mock_generate):
@ -256,18 +256,20 @@ class TextEnvironmentTester(unittest.TestCase):
task_2 = "Hello there! General Kenobi!"
query, response, response_mask, reward, histories = env.run([task_1, task_2])
self.assertEqual(len(query[0]), 9)
self.assertEqual(len(query[1]), 12)
self.assertEqual(len(response[0]), 14)
self.assertEqual(len(response[1]), 14)
self.assertEqual(response_mask[0].sum(), 2 * 3) # mocked generate always adds 3 tokens
self.assertEqual(response_mask[1].sum(), 2 * 3) # mocked generate always adds 3 tokens
self.assertEqual(reward[0], 0)
self.assertEqual(reward[1], 1)
self.assertEqual(
histories[0].text, "I am a prompt!\n" + "Hello there!" + 2 * "<request><DummyTool>test<call>test<response>"
assert len(query[0]) == 9
assert len(query[1]) == 12
assert len(response[0]) == 14
assert len(response[1]) == 14
assert response_mask[0].sum() == (2 * 3)
# mocked generate always adds 3 tokens
assert response_mask[1].sum() == (2 * 3)
# mocked generate always adds 3 tokens
assert reward[0] == 0
assert reward[1] == 1
assert histories[0].text == (
("I am a prompt!\n" + "Hello there!") + (2 * "<request><DummyTool>test<call>test<response>")
)
self.assertEqual(
histories[1].text,
"I am a prompt!\n" + "Hello there! General Kenobi!" + 2 * "<request><DummyTool>test<call>test<response>",
assert histories[1].text == (
("I am a prompt!\n" + "Hello there! General Kenobi!")
+ (2 * "<request><DummyTool>test<call>test<response>")
)
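The parsing tests above rely on a fixed string convention: the model emits "<request><ToolName>query<call>", and the environment appends the tool output followed by "<response>". A stand-alone sketch of that convention with a hypothetical regex-based parser (not the trainer's own implementation):

import re

def parse_tool_call(text: str):
    # Hypothetical parser: extract the tool name and query between the
    # <request>...<call> markers; return (None, None) if the format is off.
    match = re.search(r"<request><(.+?)>(.*?)<call>", text, re.DOTALL)
    if match is None:
        return None, None
    return match.group(1), match.group(2)

assert parse_tool_call("Something something <request><Tool1>Hello there!<call>") == ("Tool1", "Hello there!")
assert parse_tool_call("Something something <request>Hello there!<call>") == (None, None)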

View File

@ -13,6 +13,7 @@
# limitations under the License.
import tempfile
import unittest
from functools import partial
import torch
from datasets import Dataset
@ -31,15 +32,27 @@ class IterativeTrainerTester(unittest.TestCase):
cls.tokenizer.pad_token = cls.tokenizer.eos_token
# get t5 as seq2seq example:
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab"
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab-calibrated"
cls.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
def _init_tensor_dummy_dataset(self):
dummy_dataset_dict = {
"input_ids": [torch.tensor([5303, 3621]), torch.tensor([3666, 1438, 318]), torch.tensor([5303, 3621])],
"attention_mask": [torch.tensor([1, 1]), torch.tensor([1, 1, 1]), torch.tensor([1, 1])],
"labels": [torch.tensor([5303, 3621]), torch.tensor([3666, 1438, 318]), torch.tensor([5303, 3621])],
"input_ids": [
torch.tensor([5303, 3621, 3666, 1438, 318]),
torch.tensor([3666, 1438, 318, 3666, 1438, 318]),
torch.tensor([5303, 3621, 3666, 1438, 318]),
],
"attention_mask": [
torch.tensor([1, 1, 1, 1, 1]),
torch.tensor([1, 1, 1, 1, 1, 1]),
torch.tensor([1, 1, 1, 1, 1]),
],
"labels": [
torch.tensor([5303, 3621, 3666, 1438, 318]),
torch.tensor([3666, 1438, 318, 3666, 1438, 318]),
torch.tensor([5303, 3621, 3666, 1438, 318]),
],
}
dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
@ -94,11 +107,10 @@ class IterativeTrainerTester(unittest.TestCase):
tokenizer = self.t5_tokenizer
args = TrainingArguments(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=2,
output_dir=tmp_dir, per_device_train_batch_size=2, max_steps=2, learning_rate=1e-3
)
iterative_trainer = IterativeSFTTrainer(model=model, args=args, tokenizer=tokenizer)
iterative_trainer.optimizer.zero_grad = partial(iterative_trainer.optimizer.zero_grad, set_to_none=False)
iterative_trainer.step(**inputs)
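The `partial(..., set_to_none=False)` wrapper above leans on a PyTorch detail: recent versions of `optimizer.zero_grad()` drop gradients to `None` by default, whereas `set_to_none=False` keeps them as zero tensors, so gradient inspections like the `param.grad is not None` check further down still pass. A minimal sketch:

from functools import partial

import torch

param = torch.nn.Parameter(torch.ones(2))
optimizer = torch.optim.SGD([param], lr=0.1)
param.sum().backward()

# Keep zeroed gradients as tensors instead of replacing them with None.
optimizer.zero_grad = partial(optimizer.zero_grad, set_to_none=False)
optimizer.zero_grad()
assert param.grad is not None and torch.all(param.grad == 0)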

tests/test_kto_trainer.py (new file, 340 lines)
View File

@ -0,0 +1,340 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import torch
from datasets import Dataset
from parameterized import parameterized
from pytest import mark
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from trl import KTOConfig, KTOTrainer
from .testing_utils import require_no_wandb, require_peft
class KTOTrainerTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
cls.model = AutoModelForCausalLM.from_pretrained(cls.model_id)
cls.ref_model = AutoModelForCausalLM.from_pretrained(cls.model_id)
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id)
cls.tokenizer.pad_token = cls.tokenizer.eos_token
# get t5 as seq2seq example:
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab"
cls.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
cls.t5_ref_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
def _init_dummy_dataset(self):
# fmt: off
dummy_dataset_dict = {
"prompt": [
"Hey, hello",
"How are you",
"What is your name?",
"What is your name?",
"Which is the best programming language?",
"Which is the best programming language?",
"Which is the best programming language?",
],
"completion": [
"hi nice to meet you",
"leave me alone",
"I don't have a name",
"My name is Mary",
"Python",
"C++",
"Java",
],
"label": [
True,
False,
False,
True,
True,
False,
False,
],
}
# fmt: on
return Dataset.from_dict(dummy_dataset_dict)
@parameterized.expand(
[
["gpt2", True, True],
["gpt2", True, False],
# ["t5", True],
["gpt2", False, True],
["gpt2", False, False],
# ["t5", False],
]
)
def test_kto_trainer(self, name, pre_compute, eval_dataset):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = KTOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=1,
learning_rate=9e-1,
evaluation_strategy="steps",
beta=0.1,
precompute_ref_log_probs=pre_compute,
)
dummy_dataset = self._init_dummy_dataset()
if name == "gpt2":
model = self.model
ref_model = self.ref_model
tokenizer = self.tokenizer
elif name == "t5":
model = self.t5_model
ref_model = self.t5_ref_model
tokenizer = self.t5_tokenizer
trainer = KTOTrainer(
model=model,
ref_model=ref_model,
args=training_args,
tokenizer=tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset if eval_dataset else None,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))
def test_kto_trainer_tokenize_row(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = KTOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=1,
learning_rate=9e-1,
evaluation_strategy="steps",
beta=0.1,
)
dummy_dataset = self._init_dummy_dataset()
trainer = KTOTrainer(
model=self.model,
ref_model=self.ref_model,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
)
row = dummy_dataset[0]
# test that the row can be tokenized
tokenized_row = trainer.tokenize_row(row)
# Assert bos_token_id and eos_token_id (latter only for completion)
assert tokenized_row["prompt_input_ids"][0] == self.tokenizer.bos_token_id
assert tokenized_row["completion_input_ids"][0] == self.tokenizer.bos_token_id
assert tokenized_row["prompt_input_ids"][-1] != self.tokenizer.eos_token_id
assert tokenized_row["completion_input_ids"][-1] == self.tokenizer.eos_token_id
def test_kto_trainer_without_providing_ref_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = KTOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
beta=0.1,
)
dummy_dataset = self._init_dummy_dataset()
trainer = KTOTrainer(
model=self.model,
ref_model=None,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))
@require_peft
@mark.peft_test
def test_kto_trainer_without_providing_ref_model_with_lora(self):
from peft import LoraConfig
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = KTOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
beta=0.1,
)
dummy_dataset = self._init_dummy_dataset()
trainer = KTOTrainer(
model=self.model,
ref_model=None,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
peft_config=lora_config,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# check the params have changed
for n, param in previous_trainable_params.items():
if "lora" in n:
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))
@require_no_wandb
def test_kto_trainer_generate_during_eval_no_wandb(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = KTOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=1,
learning_rate=9e-1,
evaluation_strategy="steps",
beta=0.1,
generate_during_eval=True,
)
dummy_dataset = self._init_dummy_dataset()
with self.assertRaisesRegex(
ValueError,
expected_regex="`generate_during_eval=True` requires Weights and Biases to be installed."
" Please install with `pip install wandb` to resolve.",
):
KTOTrainer(
model=self.model,
ref_model=None,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
)
@require_peft
@mark.peft_test
def test_kto_lora_save(self):
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
# lora model
model = AutoModelForCausalLM.from_pretrained(self.model_id)
model_peft = get_peft_model(model, lora_config)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = KTOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
beta=0.1,
)
dummy_dataset = self._init_dummy_dataset()
# kto train lora model with a lora config
trainer = KTOTrainer(
model=model_peft,
ref_model=None,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
peft_config=lora_config,
)
# train the model
trainer.train()
# save peft adapter
trainer.save_model()
# assert that the model is loaded without giving OSError
try:
AutoModelForCausalLM.from_pretrained(tmp_dir)
except OSError:
self.fail("Loading the saved peft adapter failed")

View File

@ -15,6 +15,7 @@ import gc
import tempfile
import unittest
import pytest
import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
@ -31,7 +32,7 @@ ALL_CAUSAL_LM_MODELS = [
"trl-internal-testing/tiny-random-GPT2LMHeadModel",
"trl-internal-testing/tiny-random-CodeGenForCausalLM-sharded",
"trl-internal-testing/tiny-random-GPTNeoXForCausalLM-safetensors-sharded",
"trl-internal-testing/tiny-random-GPTNeoXForCausalLM-safetensors"
"trl-internal-testing/tiny-random-GPTNeoXForCausalLM-safetensors",
# "trl-internal-testing/tiny-random-LlamaForCausalLM", uncomment on the next transformers release
]
@ -68,7 +69,7 @@ class VHeadModelTester:
"""
for model_name in self.all_model_names:
model = self.trl_model_class.from_pretrained(model_name)
self.assertTrue(hasattr(model, "v_head"))
assert hasattr(model, "v_head")
def test_value_head_shape(self):
r"""
@ -76,7 +77,7 @@ class VHeadModelTester:
"""
for model_name in self.all_model_names:
model = self.trl_model_class.from_pretrained(model_name)
self.assertTrue(model.v_head.summary.weight.shape[0] == 1)
assert model.v_head.summary.weight.shape[0] == 1
def test_value_head_init_random(self):
r"""
@ -86,7 +87,7 @@ class VHeadModelTester:
"""
for model_name in self.all_model_names:
model = self.trl_model_class.from_pretrained(model_name)
self.assertFalse(torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias)))
assert not torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias))
def test_value_head_not_str(self):
r"""
@ -96,7 +97,7 @@ class VHeadModelTester:
for model_name in self.all_model_names:
pretrained_model = self.transformers_model_class.from_pretrained(model_name)
model = self.trl_model_class.from_pretrained(pretrained_model)
self.assertTrue(hasattr(model, "v_head"))
assert hasattr(model, "v_head")
def test_from_save_trl(self):
"""
@ -113,7 +114,7 @@ class VHeadModelTester:
# Check if the weights are the same
for key in model_from_save.state_dict():
self.assertTrue(torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key]))
assert torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key])
def test_from_save_trl_sharded(self):
"""
@ -129,7 +130,7 @@ class VHeadModelTester:
# Check if the weights are the same
for key in model_from_save.state_dict():
self.assertTrue(torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key]))
assert torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key])
def test_from_save_transformers_sharded(self):
"""
@ -146,10 +147,8 @@ class VHeadModelTester:
# Check if the weights are the same
for key in transformers_model.state_dict():
self.assertTrue(
torch.allclose(
transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key]
)
assert torch.allclose(
transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key]
)
def test_from_save_transformers(self):
@ -168,24 +167,20 @@ class VHeadModelTester:
# Check if the weights are the same
for key in transformers_model.state_dict():
self.assertTrue(
torch.allclose(
transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key]
)
assert torch.allclose(
transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key]
)
# Check if the trl model has the same keys as the transformers model
# except the v_head
for key in trl_model.state_dict():
if "v_head" not in key:
self.assertTrue(key in transformers_model.state_dict())
assert key in transformers_model.state_dict()
# check if the weights are the same
self.assertTrue(torch.allclose(trl_model.state_dict()[key], transformers_model.state_dict()[key]))
assert torch.allclose(trl_model.state_dict()[key], transformers_model.state_dict()[key])
# check if they have the same modules
self.assertTrue(
set(transformers_model_from_save.state_dict().keys()) == set(transformers_model.state_dict().keys())
)
assert set(transformers_model_from_save.state_dict().keys()) == set(transformers_model.state_dict().keys())
class CausalLMValueHeadModelTester(VHeadModelTester, unittest.TestCase):
@ -215,7 +210,7 @@ class CausalLMValueHeadModelTester(VHeadModelTester, unittest.TestCase):
# Check if the outputs are of the right size - here
# we always output 3 values - logits, loss, and value states
self.assertEqual(len(outputs), EXPECTED_OUTPUT_SIZE)
assert len(outputs) == EXPECTED_OUTPUT_SIZE
def test_dropout_config(self):
r"""
@ -228,7 +223,7 @@ class CausalLMValueHeadModelTester(VHeadModelTester, unittest.TestCase):
model = self.trl_model_class.from_pretrained(pretrained_model)
# Check if v head of the model has the same dropout as the config
self.assertEqual(model.v_head.dropout.p, pretrained_model.config.summary_dropout_prob)
assert model.v_head.dropout.p == pretrained_model.config.summary_dropout_prob
def test_dropout_kwargs(self):
r"""
@ -241,12 +236,12 @@ class CausalLMValueHeadModelTester(VHeadModelTester, unittest.TestCase):
model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs)
# Check if v head of the model has the same dropout as the config
self.assertEqual(model.v_head.dropout.p, 0.5)
assert model.v_head.dropout.p == 0.5
model = self.trl_model_class.from_pretrained(model_name, summary_dropout_prob=0.5)
# Check if v head of the model has the same dropout as the config
self.assertEqual(model.v_head.dropout.p, 0.5)
assert model.v_head.dropout.p == 0.5
def test_generate(self):
r"""
@ -263,7 +258,7 @@ class CausalLMValueHeadModelTester(VHeadModelTester, unittest.TestCase):
# Test with a model without a LM head
model_id = "trl-internal-testing/tiny-random-GPT2Model"
# This should raise a ValueError
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
pretrained_model = AutoModelForCausalLM.from_pretrained(model_id)
_ = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model.transformer)
@ -279,13 +274,11 @@ class CausalLMValueHeadModelTester(VHeadModelTester, unittest.TestCase):
lm_head_namings = self.trl_model_class.lm_head_namings
self.assertTrue(
any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings)
)
assert any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings)
for lm_head_naming in lm_head_namings:
if hasattr(trl_model.pretrained_model, lm_head_naming):
self.assertTrue(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16)
assert getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16
dummy_input = torch.LongTensor([[0, 1, 0, 1]])
@ -303,13 +296,12 @@ class CausalLMValueHeadModelTester(VHeadModelTester, unittest.TestCase):
model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(model_name + "-ppo")
# check all keys
self.assertEqual(model.state_dict().keys(), model_from_pretrained.state_dict().keys())
assert model.state_dict().keys() == model_from_pretrained.state_dict().keys()
for name, param in model.state_dict().items():
self.assertTrue(
torch.allclose(param, model_from_pretrained.state_dict()[name]),
f"Parameter {name} is not the same after push_to_hub and from_pretrained",
)
assert torch.allclose(
param, model_from_pretrained.state_dict()[name]
), f"Parameter {name} is not the same after push_to_hub and from_pretrained"
class Seq2SeqValueHeadModelTester(VHeadModelTester, unittest.TestCase):
@ -340,7 +332,7 @@ class Seq2SeqValueHeadModelTester(VHeadModelTester, unittest.TestCase):
# Check if the outputs are of the right size - here
# we always output 3 values - logits, loss, and value states
self.assertEqual(len(outputs), EXPECTED_OUTPUT_SIZE)
assert len(outputs) == EXPECTED_OUTPUT_SIZE
def test_dropout_config(self):
r"""
@ -353,7 +345,7 @@ class Seq2SeqValueHeadModelTester(VHeadModelTester, unittest.TestCase):
model = self.trl_model_class.from_pretrained(pretrained_model)
# Check if v head of the model has the same dropout as the config
self.assertEqual(model.v_head.dropout.p, pretrained_model.config.summary_dropout_prob)
assert model.v_head.dropout.p == pretrained_model.config.summary_dropout_prob
def test_dropout_kwargs(self):
r"""
@ -366,12 +358,12 @@ class Seq2SeqValueHeadModelTester(VHeadModelTester, unittest.TestCase):
model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs)
# Check if v head of the model has the same dropout as the config
self.assertEqual(model.v_head.dropout.p, 0.5)
assert model.v_head.dropout.p == 0.5
model = self.trl_model_class.from_pretrained(model_name, summary_dropout_prob=0.5)
# Check if v head of the model has the same dropout as the config
self.assertEqual(model.v_head.dropout.p, 0.5)
assert model.v_head.dropout.p == 0.5
def test_generate(self):
r"""
@ -389,7 +381,7 @@ class Seq2SeqValueHeadModelTester(VHeadModelTester, unittest.TestCase):
# Test with a model without a LM head
model_id = "trl-internal-testing/tiny-random-T5Model"
# This should raise a ValueError
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
pretrained_model = AutoModel.from_pretrained(model_id)
_ = self.trl_model_class.from_pretrained(pretrained_model)
@ -404,13 +396,12 @@ class Seq2SeqValueHeadModelTester(VHeadModelTester, unittest.TestCase):
model_from_pretrained = self.trl_model_class.from_pretrained(model_name + "-ppo")
# check all keys
self.assertEqual(model.state_dict().keys(), model_from_pretrained.state_dict().keys())
assert model.state_dict().keys() == model_from_pretrained.state_dict().keys()
for name, param in model.state_dict().items():
self.assertTrue(
torch.allclose(param, model_from_pretrained.state_dict()[name]),
f"Parameter {name} is not the same after push_to_hub and from_pretrained",
)
assert torch.allclose(
param, model_from_pretrained.state_dict()[name]
), f"Parameter {name} is not the same after push_to_hub and from_pretrained"
def test_transformers_bf16_kwargs(self):
r"""
@ -428,13 +419,11 @@ class Seq2SeqValueHeadModelTester(VHeadModelTester, unittest.TestCase):
# skip the test for FSMT as it does not support mixed-prec
continue
self.assertTrue(
any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings)
)
assert any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings)
for lm_head_naming in lm_head_namings:
if hasattr(trl_model.pretrained_model, lm_head_naming):
self.assertTrue(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16)
assert getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16
dummy_input = torch.LongTensor([[0, 1, 0, 1]])
@ -474,14 +463,14 @@ class ReferenceModelTest(unittest.TestCase):
last_ref_layer_after = ref_model.get_parameter(layer_5).data.clone()
# before optimization ref and model are identical
self.assertTrue((first_layer_before == first_ref_layer_before).all())
self.assertTrue((last_layer_before == last_ref_layer_before).all())
assert (first_layer_before == first_ref_layer_before).all()
assert (last_layer_before == last_ref_layer_before).all()
# ref model stays identical after optimization
self.assertTrue((first_ref_layer_before == first_ref_layer_after).all())
self.assertTrue((last_ref_layer_before == last_ref_layer_after).all())
assert (first_ref_layer_before == first_ref_layer_after).all()
assert (last_ref_layer_before == last_ref_layer_after).all()
# optimized model changes
self.assertTrue(not (first_layer_before == first_layer_after).all())
self.assertTrue(not (last_layer_before == last_layer_after).all())
assert not (first_layer_before == first_layer_after).all()
assert not (last_layer_before == last_layer_after).all()
def test_shared_layers(self):
layer_0 = self.layer_format.format(layer=0)
@ -506,12 +495,12 @@ class ReferenceModelTest(unittest.TestCase):
second_ref_layer_after = ref_model.get_parameter(layer_1).data.clone()
# before optimization ref and model are identical
self.assertTrue((first_layer_before == first_ref_layer_before).all())
self.assertTrue((second_layer_before == second_ref_layer_before).all())
assert (first_layer_before == first_ref_layer_before).all()
assert (second_layer_before == second_ref_layer_before).all()
# ref model stays identical after optimization
self.assertTrue((first_ref_layer_before == first_ref_layer_after).all())
self.assertTrue((second_ref_layer_before == second_ref_layer_after).all())
assert (first_ref_layer_before == first_ref_layer_after).all()
assert (second_ref_layer_before == second_ref_layer_after).all()
# first layer of optimized model stays the same
self.assertTrue((first_layer_before == first_layer_after).all())
assert (first_layer_before == first_layer_after).all()
# other layers in optimized model change
self.assertTrue(not (second_layer_before == second_layer_after).all())
assert not (second_layer_before == second_layer_after).all()
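The dropout checks above reflect that the value head is configurable at load time; a minimal sketch, assuming one of the tiny causal-LM checkpoints listed in `ALL_CAUSAL_LM_MODELS`:

from trl import AutoModelForCausalLMWithValueHead

# summary_dropout_prob is forwarded to the value head when the wrapper is loaded.
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "trl-internal-testing/tiny-random-GPT2LMHeadModel", summary_dropout_prob=0.5
)
assert model.v_head.dropout.p == 0.5
assert model.v_head.summary.weight.shape[0] == 1  # a single scalar value per token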

View File

@ -13,8 +13,10 @@
# limitations under the License.
import sys
import unittest
from functools import partial
from unittest.mock import patch
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
@ -93,15 +95,15 @@ class TestPeftDependancy(unittest.TestCase):
from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
# Check that loading a model with `peft` will raise an error
with self.assertRaises(ModuleNotFoundError):
import peft # noqa
with pytest.raises(ModuleNotFoundError):
import peft # noqa: F401
trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id) # noqa
trl_seq2seq_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(self.seq_to_seq_model_id) # noqa
_trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id)
_trl_seq2seq_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(self.seq_to_seq_model_id)
def test_imports_no_peft(self):
with patch.dict(sys.modules, {"peft": None}):
from trl import ( # noqa
from trl import ( # noqa: F401
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
PPOConfig,
@ -133,6 +135,7 @@ class TestPeftDependancy(unittest.TestCase):
tokenizer=tokenizer,
dataset=dummy_dataset,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
for query_tensor, response_tensor in dummy_dataloader:
@ -140,14 +143,14 @@ class TestPeftDependancy(unittest.TestCase):
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0), torch.tensor(0.0)]
# train model
train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
break
# check gradients are not None
for _, param in trl_model.named_parameters():
if param.requires_grad:
self.assertIsNotNone(param.grad)
assert param.grad is not None
# check expected stats
for stat in EXPECTED_STATS:
self.assertIn(stat, train_stats)
assert stat in train_stats
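The `patch.dict(sys.modules, {"peft": None})` pattern above is how these tests simulate an environment without `peft`: mapping a module name to `None` in `sys.modules` makes any later import of it fail. A minimal sketch of the mechanism on its own, using the standard-library `json` module as a stand-in:

import sys
from unittest.mock import patch

import pytest

# While the patch is active, the import machinery treats the module as missing
# and raises ImportError, even though it is actually installed.
with patch.dict(sys.modules, {"json": None}):
    with pytest.raises(ImportError):
        import json  # noqa: F401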

View File

@ -23,7 +23,7 @@ from trl import AutoModelForCausalLMWithValueHead, is_peft_available
if is_peft_available():
from peft import get_peft_model, LoraConfig
from peft import LoraConfig, get_peft_model
from .testing_utils import require_bitsandbytes, require_peft
@ -60,7 +60,7 @@ class PeftModelTester(unittest.TestCase):
model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model)
# Check that the value head has requires_grad=True
self.assertTrue(model.v_head.summary.weight.requires_grad)
assert model.v_head.summary.weight.requires_grad
def test_check_peft_model_nb_trainable_params(self):
r"""
@ -73,12 +73,12 @@ class PeftModelTester(unittest.TestCase):
# Check that the number of trainable parameters is correct
nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
self.assertEqual(nb_trainable_params, 10273)
assert nb_trainable_params == 10273
# Check that the number of trainable param for the non-peft model is correct
non_peft_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id)
nb_trainable_params = sum(p.numel() for p in non_peft_model.parameters() if p.requires_grad)
self.assertEqual(nb_trainable_params, 99578)
assert nb_trainable_params == 99578
def test_create_peft_model_from_config(self):
r"""
@ -89,13 +89,13 @@ class PeftModelTester(unittest.TestCase):
)
# Check that the number of trainable parameters is correct
nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
self.assertEqual(nb_trainable_params, 10273)
assert nb_trainable_params == 10273
causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id)
trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config)
# Check that the number of trainable parameters is correct
nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
self.assertEqual(nb_trainable_params, 10273)
assert nb_trainable_params == 10273
@require_bitsandbytes
def test_create_bnb_peft_model_from_config(self):
@ -109,10 +109,8 @@ class PeftModelTester(unittest.TestCase):
)
# Check that the number of trainable parameters is correct
nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
self.assertEqual(nb_trainable_params, 10273)
self.assertTrue(
trl_model.pretrained_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h.__class__ == Linear8bitLt
)
assert nb_trainable_params == 10273
assert trl_model.pretrained_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h.__class__ == Linear8bitLt
causal_lm_model = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id, load_in_8bit=True, device_map="auto"
@ -120,10 +118,8 @@ class PeftModelTester(unittest.TestCase):
trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config)
# Check that the number of trainable parameters is correct
nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
self.assertEqual(nb_trainable_params, 10273)
self.assertTrue(
trl_model.pretrained_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h.__class__ == Linear8bitLt
)
assert nb_trainable_params == 10273
assert trl_model.pretrained_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h.__class__ == Linear8bitLt
def test_save_pretrained_peft(self):
r"""
@ -138,31 +134,23 @@ class PeftModelTester(unittest.TestCase):
model.save_pretrained(tmp_dir)
# check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory
self.assertTrue(
os.path.isfile(f"{tmp_dir}/adapter_model.safetensors"),
msg=f"{tmp_dir}/adapter_model.safetensors does not exist",
)
self.assertTrue(
os.path.exists(f"{tmp_dir}/adapter_config.json"),
msg=f"{tmp_dir}/adapter_config.json does not exist",
)
assert os.path.isfile(
f"{tmp_dir}/adapter_model.safetensors"
), f"{tmp_dir}/adapter_model.safetensors does not exist"
assert os.path.exists(f"{tmp_dir}/adapter_config.json"), f"{tmp_dir}/adapter_config.json does not exist"
# check also for `pytorch_model.bin` and make sure it only contains `v_head` weights
self.assertTrue(
os.path.exists(f"{tmp_dir}/pytorch_model.bin"),
msg=f"{tmp_dir}/pytorch_model.bin does not exist",
)
assert os.path.exists(f"{tmp_dir}/pytorch_model.bin"), f"{tmp_dir}/pytorch_model.bin does not exist"
maybe_v_head = torch.load(f"{tmp_dir}/pytorch_model.bin")
# check that only keys that starts with `v_head` are in the dict
self.assertTrue(
all(k.startswith("v_head") for k in maybe_v_head.keys()),
msg=f"keys in {tmp_dir}/pytorch_model.bin do not start with `v_head`",
)
assert all(
k.startswith("v_head") for k in maybe_v_head.keys()
), f"keys in {tmp_dir}/pytorch_model.bin do not start with `v_head`"
model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir)
# check all the weights are the same
for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()):
self.assertTrue(torch.allclose(p1[1], p2[1]), msg=f"{p1[0]} != {p2[0]}")
assert torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}"
def test_load_pretrained_peft(self):
r"""
@ -178,19 +166,15 @@ class PeftModelTester(unittest.TestCase):
model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir)
# check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory
self.assertTrue(
os.path.isfile(f"{tmp_dir}/adapter_model.safetensors"),
msg=f"{tmp_dir}/adapter_model.safetensors does not exist",
)
self.assertTrue(
os.path.exists(f"{tmp_dir}/adapter_config.json"),
msg=f"{tmp_dir}/adapter_config.json does not exist",
)
assert os.path.isfile(
f"{tmp_dir}/adapter_model.safetensors"
), f"{tmp_dir}/adapter_model.safetensors does not exist"
assert os.path.exists(f"{tmp_dir}/adapter_config.json"), f"{tmp_dir}/adapter_config.json does not exist"
# check all the weights are the same
for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()):
if p1[0] not in ["v_head.summary.weight", "v_head.summary.bias"]:
self.assertTrue(torch.allclose(p1[1], p2[1]), msg=f"{p1[0]} != {p2[0]}")
assert torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}"
def test_continue_training_peft_model(self):
r"""
@ -205,4 +189,4 @@ class PeftModelTester(unittest.TestCase):
model = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir, is_trainable=True)
# Check that the number of trainable parameters is correct
nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
self.assertEqual(nb_trainable_params, 10273)
assert nb_trainable_params == 10273


@ -17,6 +17,7 @@ import gc
import re
import tempfile
import unittest
from functools import partial
import pytest
import torch
@ -180,7 +181,7 @@ class PPOTrainerTester(unittest.TestCase):
)
dummy_dataloader = ppo_trainer.dataloader
self.assertEqual(len(dummy_dataloader), 0)
assert len(dummy_dataloader) == 0
def test_ppo_step(self):
# initialize dataset
@ -193,6 +194,7 @@ class PPOTrainerTester(unittest.TestCase):
tokenizer=self.gpt2_tokenizer,
dataset=dummy_dataset,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
# train model with ppo
for query_tensor, response_tensor in dummy_dataloader:
@ -200,7 +202,7 @@ class PPOTrainerTester(unittest.TestCase):
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0), torch.tensor(0.0)]
# train model
train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
break
for param in ppo_trainer.model.parameters():
@ -220,6 +222,7 @@ class PPOTrainerTester(unittest.TestCase):
tokenizer=self.gpt2_tokenizer,
dataset=dummy_dataset,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
# train model with ppo
for query_tensor, response_tensor in dummy_dataloader:
@ -230,9 +233,7 @@ class PPOTrainerTester(unittest.TestCase):
response_mask = [torch.ones_like(r) for r in response_tensor]
# train model
train_stats = ppo_trainer.step(
[q for q in query_tensor], [r for r in response_tensor], reward, response_mask
)
train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward, response_mask)
break
for param in ppo_trainer.model.parameters():
@ -254,9 +255,10 @@ class PPOTrainerTester(unittest.TestCase):
tokenizer=self.gpt2_tokenizer,
dataset=dummy_dataset,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
self.assertTrue(isinstance(ppo_trainer.optimizer.optimizer, torch.optim.SGD))
assert isinstance(ppo_trainer.optimizer.optimizer, torch.optim.SGD)
# train model with ppo
for query_tensor, response_tensor in dummy_dataloader:
@ -264,15 +266,15 @@ class PPOTrainerTester(unittest.TestCase):
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0), torch.tensor(0.0)]
# train model
train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
break
for name, param in ppo_trainer.model.named_parameters():
self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient")
assert param.grad is not None, f"Parameter {name} has no gradient"
# ref model should not be trained
for name, param in ppo_trainer.ref_model.named_parameters():
self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")
assert param.grad is None, f"Parameter {name} has a gradient"
# Finally check stats
for stat in EXPECTED_STATS:
@ -293,10 +295,11 @@ class PPOTrainerTester(unittest.TestCase):
dataset=dummy_dataset,
lr_scheduler=lr_scheduler,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
self.assertTrue(isinstance(ppo_trainer.optimizer.optimizer, torch.optim.SGD))
self.assertTrue(isinstance(ppo_trainer.lr_scheduler.scheduler, torch.optim.lr_scheduler.ExponentialLR))
assert isinstance(ppo_trainer.optimizer.optimizer, torch.optim.SGD)
assert isinstance(ppo_trainer.lr_scheduler.scheduler, torch.optim.lr_scheduler.ExponentialLR)
# train model with ppo
for query_tensor, response_tensor in dummy_dataloader:
@ -304,23 +307,23 @@ class PPOTrainerTester(unittest.TestCase):
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0), torch.tensor(0.0)]
# train model
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
break
for name, param in ppo_trainer.model.named_parameters():
self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient")
assert param.grad is not None, f"Parameter {name} has no gradient"
# ref model should not be trained
for name, param in ppo_trainer.ref_model.named_parameters():
self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")
assert param.grad is None, f"Parameter {name} has a gradient"
# Finally check stats
for stat in EXPECTED_STATS:
assert stat in train_stats.keys()
# assert that the LR has increased for exponential decay
self.assertTrue(train_stats["ppo/learning_rate"] > self.ppo_config.learning_rate)
assert train_stats["ppo/learning_rate"] > self.ppo_config.learning_rate
def test_ppo_step_with_no_ref(self):
# initialize dataset
@ -334,6 +337,7 @@ class PPOTrainerTester(unittest.TestCase):
tokenizer=self.gpt2_tokenizer,
dataset=dummy_dataset,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
# train model with ppo
for query_tensor, response_tensor in dummy_dataloader:
@ -341,15 +345,15 @@ class PPOTrainerTester(unittest.TestCase):
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0), torch.tensor(0.0)]
# train model
train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
break
for name, param in ppo_trainer.model.named_parameters():
self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient")
assert param.grad is not None, f"Parameter {name} has no gradient"
# ref model should not be trained
for name, param in ppo_trainer.ref_model.named_parameters():
self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")
assert param.grad is None, f"Parameter {name} has a gradient"
# initialize a new gpt2 model:
model = AutoModelForCausalLMWithValueHead.from_pretrained(self.model_id)
@ -357,10 +361,9 @@ class PPOTrainerTester(unittest.TestCase):
if "v_head" not in name:
name = name.replace("pretrained_model.", "")
self.assertTrue(
torch.allclose(param.cpu(), model.state_dict()[name].cpu()),
f"Parameter {name} has changed from the original model",
)
assert torch.allclose(
param.cpu(), model.state_dict()[name].cpu()
), f"Parameter {name} has changed from the original model"
# Finally check stats
for stat in EXPECTED_STATS:
@ -385,6 +388,7 @@ class PPOTrainerTester(unittest.TestCase):
dataset=dummy_dataset,
num_shared_layers=num_shared_layers,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
# train model with ppo
for query_tensor, response_tensor in dummy_dataloader:
@ -392,7 +396,7 @@ class PPOTrainerTester(unittest.TestCase):
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0), torch.tensor(0.0)]
# train model
train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
break
pattern = r".*transformer\.h\.(\d+)\..*"
@ -402,15 +406,15 @@ class PPOTrainerTester(unittest.TestCase):
if re.match(pattern, name):
layer_number = int(re.match(pattern, name).groups(0)[0])
if layer_number < num_shared_layers:
self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")
assert param.grad is None, f"Parameter {name} has a gradient"
else:
self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient")
elif any([layer in name for layer in final_layers]):
self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient")
assert param.grad is not None, f"Parameter {name} has no gradient"
elif any(layer in name for layer in final_layers):
assert param.grad is not None, f"Parameter {name} has no gradient"
# ref model should not be trained
for name, param in ppo_trainer.ref_model.named_parameters():
self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")
assert param.grad is None, f"Parameter {name} has a gradient"
for stat in EXPECTED_STATS:
assert stat in train_stats.keys()
@ -452,6 +456,7 @@ class PPOTrainerTester(unittest.TestCase):
tokenizer=self.gpt2_tokenizer,
dataset=dummy_dataset,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
# train model with ppo
for query_tensor, response_tensor in dummy_dataloader:
@ -459,21 +464,21 @@ class PPOTrainerTester(unittest.TestCase):
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor([[1.0]]), torch.tensor([[0.0]])]
# train model - this should raise an error
with self.assertRaises(ValueError):
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
with pytest.raises(ValueError):
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
reward = [torch.tensor([1.0]), torch.tensor([0.0])]
# train model - this should work
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
break
# check if the gradients are computed for the model
for name, param in ppo_trainer.model.named_parameters():
self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient")
assert param.grad is not None, f"Parameter {name} has no gradient"
# ref model should not be trained
for name, param in ppo_trainer.ref_model.named_parameters():
self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")
assert param.grad is None, f"Parameter {name} has a gradient"
def test_ppo_step_input_shape(self):
"""
@ -489,6 +494,7 @@ class PPOTrainerTester(unittest.TestCase):
tokenizer=self.gpt2_tokenizer,
dataset=dummy_dataset,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
# train model with ppo
for query_tensor, response_tensor in dummy_dataloader:
@ -499,16 +505,16 @@ class PPOTrainerTester(unittest.TestCase):
bs = ppo_trainer.config.batch_size
queries, responses, _, _ = ppo_trainer._step_safety_checker(
bs, [q for q in query_tensor], [r for r in response_tensor], reward
bs, list(query_tensor), list(response_tensor), reward
)
self.assertTrue(isinstance(queries, list), f"queries should be a list, got {type(queries)}")
self.assertTrue(isinstance(responses, list), f"responses should be a list, got {type(responses)}")
assert isinstance(queries, list), f"queries should be a list, got {type(queries)}"
assert isinstance(responses, list), f"responses should be a list, got {type(responses)}"
# check the shapes
for i in range(bs):
self.assertEqual(queries[i].shape, torch.Size([7]))
self.assertEqual(responses[i].size(), torch.Size([7]))
assert queries[i].shape == torch.Size([7])
assert responses[i].size() == torch.Size([7])
break
def test_ppo_step_no_dataset(self):
@ -529,6 +535,7 @@ class PPOTrainerTester(unittest.TestCase):
ref_model=self.gpt2_model_ref,
tokenizer=self.gpt2_tokenizer,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
# train model with ppo
reward = [torch.tensor([1.0])]
# train model - this should work fine
@ -536,15 +543,15 @@ class PPOTrainerTester(unittest.TestCase):
# check gradients
for name, param in ppo_trainer.model.named_parameters():
self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient")
assert param.grad is not None, f"Parameter {name} has no gradient"
# ref model should not be trained
for name, param in ppo_trainer.ref_model.named_parameters():
self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")
assert param.grad is None, f"Parameter {name} has a gradient"
# check train stats
for stat in EXPECTED_STATS:
self.assertTrue(stat in train_stats, f"Train stats should contain {stat}")
assert stat in train_stats, f"Train stats should contain {stat}"
def test_loss_trainer(self):
"""
@ -579,7 +586,7 @@ class PPOTrainerTester(unittest.TestCase):
logits = torch.exp(all_logprobs)
vpreds = values + 0.1
score, non_score = ppo_trainer.compute_rewards(dummy_scores, all_logprobs, ref_logprobs, mask)
score, non_score, kls = ppo_trainer.compute_rewards(dummy_scores, all_logprobs, ref_logprobs, mask)
values, advantages, returns = ppo_trainer.compute_advantages(values, score, mask)
# just make sure a dummy loss is computed
@ -595,8 +602,8 @@ class PPOTrainerTester(unittest.TestCase):
returns[idx].unsqueeze(0),
)
self.assertAlmostEqual(pg_loss.item(), 2.0494, 4)
self.assertAlmostEqual(v_loss.item(), 0.07110, 4)
assert abs(pg_loss.item() - 2.0494) < 0.0001
assert abs(v_loss.item() - 0.0711) < 0.0001
# check if we get same results with masked parts removed
pg_loss_unmasked, v_loss_unmasked, _ = ppo_trainer.loss(
@ -609,8 +616,8 @@ class PPOTrainerTester(unittest.TestCase):
apply_mask(advantages[idx], mask[idx]).unsqueeze(0),
apply_mask(returns[idx], mask[idx]).unsqueeze(0),
)
self.assertAlmostEqual(pg_loss_unmasked.item(), 2.0494, 4)
self.assertAlmostEqual(v_loss_unmasked.item(), 0.07110, 4)
assert abs(pg_loss_unmasked.item() - 2.0494) < 0.0001
assert abs(v_loss_unmasked.item() - 0.0711) < 0.0001
@parameterized.expand(
[
@ -674,11 +681,11 @@ class PPOTrainerTester(unittest.TestCase):
model, dummy_queries, dummy_responses, model_inputs
)
self.assertLessEqual(abs_diff_masked_tensors(logprobs_1, logprobs_2, mask_1, mask_2), 1e-4)
self.assertLessEqual(abs_diff_masked_tensors(values_1, values_2, mask_1, mask_2), 1e-4)
assert abs_diff_masked_tensors(logprobs_1, logprobs_2, mask_1, mask_2) <= 0.0001
assert abs_diff_masked_tensors(values_1, values_2, mask_1, mask_2) <= 0.0001
self.assertLessEqual(abs_diff_masked_tensors(logprobs_0, logprobs_2[:1], mask_0, mask_2[:1]), 1e-4)
self.assertLessEqual(abs_diff_masked_tensors(values_0, values_2[:1], mask_0, mask_2[:1]), 1e-4)
assert abs_diff_masked_tensors(logprobs_0, logprobs_2[:1], mask_0, mask_2[:1]) <= 0.0001
assert abs_diff_masked_tensors(values_0, values_2[:1], mask_0, mask_2[:1]) <= 0.0001
def test_ppo_trainer_max_grad_norm(self):
"""
@ -695,7 +702,7 @@ class PPOTrainerTester(unittest.TestCase):
tokenizer=self.gpt2_tokenizer,
dataset=dummy_dataset,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
# train model with ppo
@ -704,16 +711,15 @@ class PPOTrainerTester(unittest.TestCase):
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0), torch.tensor(0.0)]
# train model
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
break
# check gradients
for name, param in ppo_trainer.model.named_parameters():
self.assertTrue(param.grad is not None, f"Parameter {name} has no gradient")
self.assertTrue(
torch.all(param.grad.abs() <= self.ppo_config.max_grad_norm),
f"Parameter {name} has a gradient larger than max_grad_norm",
)
assert param.grad is not None, f"Parameter {name} has no gradient"
assert torch.all(
param.grad.abs() <= self.ppo_config.max_grad_norm
), f"Parameter {name} has a gradient larger than max_grad_norm"
def test_ppo_trainer_kl_penalty(self):
dummy_dataset = self._init_dummy_dataset()
@ -730,7 +736,7 @@ class PPOTrainerTester(unittest.TestCase):
)
expected_output = torch.Tensor([[0.1000, -0.1000, 0.1000], [-0.1000, 0.1000, -0.2000]])
self.assertTrue(torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output))
assert torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output)
self.ppo_config.kl_penalty = "abs"
ppo_trainer = PPOTrainer(
@ -742,7 +748,7 @@ class PPOTrainerTester(unittest.TestCase):
)
expected_output = torch.Tensor([[0.1000, 0.1000, 0.1000], [0.1000, 0.1000, 0.2000]])
self.assertTrue(torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output))
assert torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output)
self.ppo_config.kl_penalty = "mse"
ppo_trainer = PPOTrainer(
@ -754,7 +760,7 @@ class PPOTrainerTester(unittest.TestCase):
)
expected_output = torch.Tensor([[0.0050, 0.0050, 0.0050], [0.0050, 0.0050, 0.0200]])
self.assertTrue(torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output))
assert torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output)
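
Editor's note: the expected tensors asserted above pin down what the element-wise kl_penalty modes compute from per-token log-probabilities: "kl" is the raw log-probability difference, "abs" its absolute value, and "mse" half of its square; the test that follows exercises the "full" option, where the penalty is instead a per-token KL over the whole next-token distribution. A stand-alone sketch of the three element-wise variants, with hypothetical inputs chosen to reproduce the expected outputs (not TRL's exact implementation):

import torch

def kl_penalty(log_probs: torch.Tensor, ref_log_probs: torch.Tensor, mode: str) -> torch.Tensor:
    # Element-wise penalties consistent with the expected outputs asserted above.
    diff = log_probs - ref_log_probs
    if mode == "kl":   # raw log-ratio
        return diff
    if mode == "abs":  # absolute log-ratio
        return diff.abs()
    if mode == "mse":  # half squared log-ratio
        return 0.5 * diff.square()
    raise ValueError(f"unknown mode: {mode}")

# Hypothetical per-token log-probs picked so the outputs match the tensors in the tests.
log_probs = torch.tensor([[0.5, 0.2, 0.1], [0.3, 0.2, 0.3]])
ref_log_probs = torch.tensor([[0.4, 0.3, 0.0], [0.4, 0.1, 0.5]])
print(kl_penalty(log_probs, ref_log_probs, "mse"))  # ~[[0.0050, 0.0050, 0.0050], [0.0050, 0.0050, 0.0200]]
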
def test_ppo_trainer_full_kl_penalty(self):
# a few more extensive tests for the full kl option as it is more involved
@ -793,8 +799,8 @@ class PPOTrainerTester(unittest.TestCase):
[[0.0, 0.0]],
)
output = ppo_trainer._kl_penalty(log_probs, ref_log_probs)
self.assertTrue(output.shape == (1, 2))
self.assertTrue(torch.allclose(output, expected_output))
assert output.shape == (1, 2)
assert torch.allclose(output, expected_output)
# test for when the two dists are almost not overlapping
log_probs = torch.Tensor(
@ -819,8 +825,8 @@ class PPOTrainerTester(unittest.TestCase):
[[4.4474, 4.4474]],
)
output = ppo_trainer._kl_penalty(log_probs, ref_log_probs)
self.assertTrue(output.shape == (1, 2))
self.assertTrue(torch.allclose(output, expected_output))
assert output.shape == (1, 2)
assert torch.allclose(output, expected_output)
# test for when the two dists are almost not overlapping
log_probs = torch.Tensor(
@ -845,8 +851,8 @@ class PPOTrainerTester(unittest.TestCase):
[[3.7361, 0.0]],
)
output = ppo_trainer._kl_penalty(log_probs, ref_log_probs)
self.assertTrue(output.shape == (1, 2))
self.assertTrue(torch.allclose(output, expected_output, atol=1e-4))
assert output.shape == (1, 2)
assert torch.allclose(output, expected_output, atol=0.0001)
@require_peft
@mark.peft_test
@ -883,8 +889,8 @@ class PPOTrainerTester(unittest.TestCase):
tokenizer=self.gpt2_tokenizer,
dataset=dummy_dataset,
)
self.assertTrue(ppo_trainer.ref_model is None)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
assert ppo_trainer.ref_model is None
dummy_dataloader = ppo_trainer.dataloader
@ -894,19 +900,19 @@ class PPOTrainerTester(unittest.TestCase):
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0), torch.tensor(0.0)]
# train model by running a step twice
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
ppo_trainer.model.train()
ppo_trainer.model.gradient_checkpointing_enable()
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
break
# check gradients
for name, param in model.named_parameters():
if "lora" in name or "v_head" in name:
self.assertTrue(param.grad is not None, f"Parameter {name} has a no gradient")
assert param.grad is not None, f"Parameter {name} has a no gradient"
else:
self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")
assert param.grad is None, f"Parameter {name} has a gradient"
@require_peft
@mark.peft_test
@ -971,8 +977,8 @@ class PPOTrainerTester(unittest.TestCase):
tokenizer=self.gpt2_tokenizer,
dataset=dummy_dataset,
)
self.assertTrue(ppo_trainer.ref_model is None)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
assert ppo_trainer.ref_model is None
dummy_dataloader = ppo_trainer.dataloader
@ -982,23 +988,23 @@ class PPOTrainerTester(unittest.TestCase):
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0), torch.tensor(0.0)]
# train model by running a step twice
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
ppo_trainer.model.train()
ppo_trainer.model.gradient_checkpointing_enable()
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
break
new_logits = ppo_trainer.model.compute_reward_score(dummy_inputs)
self.assertTrue(not torch.allclose(previous_rm_logits, new_logits[:, -1, :]))
self.assertTrue(torch.allclose(original_rm_logits, new_logits[:, -1, :]))
assert not torch.allclose(previous_rm_logits, new_logits[:, -1, :])
assert torch.allclose(original_rm_logits, new_logits[:, -1, :])
# check gradients
for name, param in model.named_parameters():
if ("lora" in name or "v_head" in name) and ("reward" not in name):
self.assertTrue(param.grad is not None, f"Parameter {name} has a no gradient")
assert param.grad is not None, f"Parameter {name} has a no gradient"
else:
self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")
assert param.grad is None, f"Parameter {name} has a gradient"
@unittest.skip("Fix by either patching `whomai()` to work in the staging endpoint or use a dummy prod user.")
def test_push_to_hub(self):
@ -1016,10 +1022,10 @@ class PPOTrainerTester(unittest.TestCase):
url = ppo_trainer.push_to_hub(repo_id=repo_id, token=self._token, api_endpoint=CI_HUB_ENDPOINT)
# Extract repo_name from the url
re_search = re.search(CI_HUB_ENDPOINT + r"/([^/]+/[^/]+)/", url)
self.assertTrue(re_search is not None)
assert re_search is not None
hub_repo_id = re_search.groups()[0]
# Check we created a Hub repo
self.assertEqual(hub_repo_id, repo_id)
assert hub_repo_id == repo_id
# Ensure all files are present
files = sorted(self._api.list_repo_files(hub_repo_id))
assert all(
@ -1057,7 +1063,7 @@ class PPOTrainerTester(unittest.TestCase):
"gpt2", device_map="balanced", max_memory={0: "500MB", 1: "500MB"}
)
self.assertTrue(set(gpt2_model.hf_device_map.values()) == {0, 1})
assert set(gpt2_model.hf_device_map.values()) == {0, 1}
# this line is very important
def make_inputs_require_grad(module, input, output):
@ -1068,7 +1074,7 @@ class PPOTrainerTester(unittest.TestCase):
peft_model = get_peft_model(gpt2_model, lora_config)
model = AutoModelForCausalLMWithValueHead.from_pretrained(peft_model)
self.assertTrue(model.is_sequential_parallel)
assert model.is_sequential_parallel
dummy_dataset = self._init_dummy_dataset()
self.ppo_config.batch_size = 2
@ -1082,7 +1088,7 @@ class PPOTrainerTester(unittest.TestCase):
dataset=dummy_dataset,
)
self.assertTrue(ppo_trainer.ref_model is None)
assert ppo_trainer.ref_model is None
dummy_dataloader = ppo_trainer.dataloader
@ -1092,19 +1098,19 @@ class PPOTrainerTester(unittest.TestCase):
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0), torch.tensor(0.0)]
# train model by running a step twice
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
ppo_trainer.model.train()
ppo_trainer.model.gradient_checkpointing_enable()
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
break
# check gradients
for name, param in model.named_parameters():
if "lora" in name or "v_head" in name:
self.assertTrue(param.grad is not None, f"Parameter {name} has a no gradient")
assert param.grad is not None, f"Parameter {name} has a no gradient"
else:
self.assertTrue(param.grad is None, f"Parameter {name} has a gradient")
assert param.grad is None, f"Parameter {name} has a gradient"
def test_generation(self):
dummy_dataset = self._init_dummy_dataset()
@ -1134,7 +1140,7 @@ class PPOTrainerTester(unittest.TestCase):
generations_single = [ppo_trainer.generate(inputs, **generation_kwargs).squeeze() for inputs in model_inputs]
generations_single = tokenizer.batch_decode(generations_single)
self.assertEqual(generations_single, generations_batched)
assert generations_single == generations_batched
def test_grad_accumulation(self):
dummy_dataset = self._init_dummy_dataset()
@ -1162,7 +1168,7 @@ class PPOTrainerTester(unittest.TestCase):
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0), torch.tensor(1.0)]
# train model by running a step twice
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
break
model_grad = gpt2_model.v_head.summary.weight
@ -1186,11 +1192,11 @@ class PPOTrainerTester(unittest.TestCase):
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0), torch.tensor(1.0)]
# train model by running a step twice
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
break
model_grad_acc = gpt2_model_clone.v_head.summary.weight
self.assertTrue(torch.allclose(model_grad_acc, model_grad, rtol=1e-3, atol=1e-3))
assert torch.allclose(model_grad_acc, model_grad, rtol=0.001, atol=0.001)
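
Editor's note: the grad-accumulation test above expects two runs (batch size 2 versus batch size 1 with two accumulation steps) to land on approximately the same v_head weights. The usual reason is that accumulating rescaled micro-batch gradients reproduces the full-batch gradient; a toy illustration of that equivalence, independent of TRL:

import torch

torch.manual_seed(0)
data = torch.randn(2, 4)
target = torch.randn(2, 1)

def final_grad(batches):
    torch.manual_seed(1)  # identical init for both runs
    model = torch.nn.Linear(4, 1)
    for x, y in batches:
        # Rescale each micro-batch loss so the accumulated gradient equals the full-batch one.
        loss = torch.nn.functional.mse_loss(model(x), y) / len(batches)
        loss.backward()
    return model.weight.grad.clone()

full_batch = final_grad([(data, target)])
accumulated = final_grad([(data[:1], target[:1]), (data[1:], target[1:])])
print(torch.allclose(full_batch, accumulated, atol=1e-6))  # True
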
@unittest.skip("Fix by either patching `whomai()` to work in the staging endpoint or use a dummy prod user.")
def test_push_to_hub_if_best_reward(self):
@ -1217,6 +1223,7 @@ class PPOTrainerTester(unittest.TestCase):
dataset=dummy_dataset,
)
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
dummy_dataloader = ppo_trainer.dataloader
# train model with ppo
for query_tensor, response_tensor in dummy_dataloader:
@ -1224,7 +1231,7 @@ class PPOTrainerTester(unittest.TestCase):
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0), torch.tensor(0.0)]
# train model
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
_ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
break
def test_batch_size_check(self):
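
Editor's note: several hunks above pin optimizer.zero_grad to set_to_none=False via functools.partial. Recent PyTorch versions default to set_to_none=True, which replaces each .grad with None after zeroing, so without the pin the "param.grad is not None" assertions would fail for reasons unrelated to training. A minimal sketch of the pattern on a hypothetical toy model:

from functools import partial

import torch

# Hypothetical toy model standing in for the PPO policy in the tests above.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Pin the keyword once; every later zero_grad() call zeroes the tensors instead of dropping them.
optimizer.zero_grad = partial(optimizer.zero_grad, set_to_none=False)

model(torch.randn(1, 4)).sum().backward()
optimizer.zero_grad()
assert all(p.grad is not None for p in model.parameters())  # grads are zero tensors, not None
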


@ -14,6 +14,7 @@
import tempfile
import unittest
import pytest
import torch
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction
@ -35,7 +36,7 @@ class RewardTrainerTester(unittest.TestCase):
def test_accuracy_metrics(self):
dummy_eval_predictions = EvalPrediction(torch.FloatTensor([[0.1, 0.9], [0.9, 0.1]]), torch.LongTensor([0, 0]))
accuracy = compute_accuracy(dummy_eval_predictions)
self.assertEqual(accuracy["accuracy"], 0.5)
assert accuracy["accuracy"] == 0.5
def test_reward_trainer(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@ -52,9 +53,9 @@ class RewardTrainerTester(unittest.TestCase):
# fmt: off
dummy_dataset_dict = {
"input_ids_chosen": [
torch.LongTensor([0, 1, 2,]),
torch.LongTensor([0, 1, 2]),
torch.LongTensor([1, 2]),
torch.LongTensor([0, 1, 2,]),
torch.LongTensor([0, 1, 2]),
torch.LongTensor([1, 2]),
],
"attention_mask_chosen": [
@ -64,9 +65,9 @@ class RewardTrainerTester(unittest.TestCase):
torch.LongTensor([1, 0]),
],
"input_ids_rejected": [
torch.LongTensor([0, 2,]),
torch.LongTensor([0, 2]),
torch.LongTensor([1, 2, 0]),
torch.LongTensor([0, 2,]),
torch.LongTensor([0, 2]),
torch.LongTensor([1, 2, 0]),
],
"attention_mask_rejected": [
@ -91,17 +92,17 @@ class RewardTrainerTester(unittest.TestCase):
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
assert trainer.state.log_history[(-1)]["train_loss"] is not None
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))
assert not torch.equal(param, new_param)
preds = trainer.predict(dummy_dataset)
self.assertEqual(preds.predictions.shape, (4, 2))
assert preds.predictions.shape == (4, 2)
@require_peft
def test_reward_trainer_peft(self):
@ -132,9 +133,9 @@ class RewardTrainerTester(unittest.TestCase):
# fmt: off
dummy_dataset_dict = {
"input_ids_chosen": [
torch.LongTensor([0, 1, 2,]),
torch.LongTensor([0, 1, 2]),
torch.LongTensor([1, 2]),
torch.LongTensor([0, 1, 2,]),
torch.LongTensor([0, 1, 2]),
torch.LongTensor([1, 2]),
],
"attention_mask_chosen": [
@ -144,9 +145,9 @@ class RewardTrainerTester(unittest.TestCase):
torch.LongTensor([1, 0]),
],
"input_ids_rejected": [
torch.LongTensor([0, 2,]),
torch.LongTensor([0, 2]),
torch.LongTensor([1, 2, 0]),
torch.LongTensor([0, 2,]),
torch.LongTensor([0, 2]),
torch.LongTensor([1, 2, 0]),
],
"attention_mask_rejected": [
@ -175,27 +176,27 @@ class RewardTrainerTester(unittest.TestCase):
# check gradients are not None
for n, param in trainer.model.named_parameters():
if any([t in n for t in trainable_params_name]):
if any(t in n for t in trainable_params_name):
previous_trainable_params[n] = param.clone()
else:
previous_non_trainable_params[n] = param.clone()
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
assert trainer.state.log_history[(-1)]["train_loss"] is not None
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
assert not torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)
# check the non trainable params have not changed
for n, param in previous_non_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
assert torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)
preds = trainer.predict(dummy_dataset)
self.assertEqual(preds.predictions.shape, (4, 2))
assert preds.predictions.shape == (4, 2)
def test_reward_trainer_assert_value_error(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@ -206,12 +207,12 @@ class RewardTrainerTester(unittest.TestCase):
remove_unused_columns=False,
)
# fmt: off
dummy_dataset_dict = {
# fmt: off
"input_ids_b": [
torch.LongTensor([0, 1, 2,]),
torch.LongTensor([0, 1, 2]),
torch.LongTensor([1, 2]),
torch.LongTensor([0, 1, 2,]),
torch.LongTensor([0, 1, 2]),
torch.LongTensor([1, 2]),
],
"attention_mask_c": [
@ -221,9 +222,9 @@ class RewardTrainerTester(unittest.TestCase):
torch.LongTensor([1, 0]),
],
"input_ids_f": [
torch.LongTensor([0, 2,]),
torch.LongTensor([0, 2]),
torch.LongTensor([1, 2, 0]),
torch.LongTensor([0, 2,]),
torch.LongTensor([0, 2]),
torch.LongTensor([1, 2, 0]),
],
"attention_mask_g": [
@ -232,8 +233,8 @@ class RewardTrainerTester(unittest.TestCase):
torch.LongTensor([1, 1]),
torch.LongTensor([1, 1, 1]),
],
# fmt: on
}
# fmt: on
dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
trainer = RewardTrainer(
@ -243,7 +244,7 @@ class RewardTrainerTester(unittest.TestCase):
train_dataset=dummy_dataset,
)
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
trainer.train()
training_args = RewardConfig(
@ -276,13 +277,13 @@ class RewardTrainerTester(unittest.TestCase):
# fmt: off
dummy_dataset_dict = {
"input_ids_chosen": [
torch.LongTensor([0, 1, 2,]),
torch.LongTensor([0, 1, 2]),
],
"attention_mask_chosen": [
torch.LongTensor([1, 1, 1]),
],
"input_ids_rejected": [
torch.LongTensor([0, 2,]),
torch.LongTensor([0, 2]),
],
"attention_mask_rejected": [
torch.LongTensor([1, 1]),
@ -306,9 +307,60 @@ class RewardTrainerTester(unittest.TestCase):
batch = trainer.data_collator(batch)
loss, outputs = trainer.compute_loss(trainer.model, batch, return_outputs=True)
self.assertAlmostEqual(
loss,
-torch.nn.functional.logsigmoid(
outputs["rewards_chosen"] - outputs["rewards_rejected"] - batch["margin"]
).mean(),
l_val = -torch.nn.functional.logsigmoid(
outputs["rewards_chosen"] - outputs["rewards_rejected"] - batch["margin"]
).mean()
assert abs(loss - l_val) < 1e-6
def test_reward_trainer_tags(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = RewardConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
)
# fmt: off
dummy_dataset_dict = {
"input_ids_chosen": [
torch.LongTensor([0, 1, 2]),
torch.LongTensor([1, 2]),
torch.LongTensor([0, 1, 2]),
torch.LongTensor([1, 2]),
],
"attention_mask_chosen": [
torch.LongTensor([1, 1, 1]),
torch.LongTensor([1, 0]),
torch.LongTensor([1, 1, 1]),
torch.LongTensor([1, 0]),
],
"input_ids_rejected": [
torch.LongTensor([0, 2]),
torch.LongTensor([1, 2, 0]),
torch.LongTensor([0, 2]),
torch.LongTensor([1, 2, 0]),
],
"attention_mask_rejected": [
torch.LongTensor([1, 1]),
torch.LongTensor([1, 1, 0]),
torch.LongTensor([1, 1]),
torch.LongTensor([1, 1, 1]),
],
}
# fmt: on
dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
trainer = RewardTrainer(
model=self.model,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
)
assert trainer.model.model_tags == trainer._tag_names
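
Editor's note: the compute_loss check above reproduces the margin-aware pairwise objective: the loss is the negative log-sigmoid of (reward_chosen - reward_rejected - margin), averaged over the batch. A self-contained sketch with made-up reward values:

import torch
import torch.nn.functional as F

# Made-up reward-model outputs for two chosen/rejected pairs, plus optional margins.
rewards_chosen = torch.tensor([1.2, 0.3])
rewards_rejected = torch.tensor([0.4, 0.9])
margin = torch.tensor([0.0, 0.5])

# Margin-aware pairwise loss, matching the expression asserted in compute_loss above.
loss = -F.logsigmoid(rewards_chosen - rewards_rejected - margin).mean()
print(round(loss.item(), 4))  # ~0.8792 for these made-up numbers
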


@ -17,6 +17,7 @@ import tempfile
import unittest
import numpy as np
import pytest
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
@ -85,6 +86,42 @@ class SFTTrainerTester(unittest.TestCase):
],
}
)
cls.dummy_chatml_dataset = Dataset.from_dict(
{
"messages": [
[
{"role": "system", "content": "You are helpful"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi, how can I help you?"},
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
{"role": "user", "content": "What is 3+3?"},
{"role": "assistant", "content": "6"},
],
[
{"role": "system", "content": "You are helpful"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi, how can I help you?"},
],
]
}
)
cls.dummy_instruction_dataset = Dataset.from_list(
[
{"prompt": "What is 2+2?", "completion": "4"},
{"prompt": "What is 3+3?", "completion": "6"},
{"prompt": "What is 4+4?", "completion": "8"},
{"prompt": "What is 2+2?", "completion": "4"},
{"prompt": "What is 3+3?", "completion": "6"},
{"prompt": "What is 4+4?", "completion": "8"},
{"prompt": "What is 2+2?", "completion": "4"},
{"prompt": "What is 3+3?", "completion": "6"},
{"prompt": "What is 4+4?", "completion": "8"},
{"prompt": "What is 2+2?", "completion": "4"},
{"prompt": "What is 3+3?", "completion": "6"},
{"prompt": "What is 4+4?", "completion": "8"},
]
)
cls.train_dataset = ConstantLengthDataset(
cls.tokenizer,
@ -112,18 +149,18 @@ class SFTTrainerTester(unittest.TestCase):
formatting_func=formatting_prompts_func,
)
self.assertTrue(len(formatted_dataset) == len(self.dummy_dataset))
self.assertTrue(len(formatted_dataset) > 0)
assert len(formatted_dataset) == len(self.dummy_dataset)
assert len(formatted_dataset) > 0
for example in formatted_dataset:
self.assertTrue("input_ids" in example)
self.assertTrue("labels" in example)
assert "input_ids" in example
assert "labels" in example
self.assertTrue(len(example["input_ids"]) == formatted_dataset.seq_length)
self.assertTrue(len(example["labels"]) == formatted_dataset.seq_length)
assert len(example["input_ids"]) == formatted_dataset.seq_length
assert len(example["labels"]) == formatted_dataset.seq_length
decoded_text = self.tokenizer.decode(example["input_ids"])
self.assertTrue(("Question" in decoded_text) and ("Answer" in decoded_text))
assert ("Question" in decoded_text) and ("Answer" in decoded_text)
def test_sft_trainer(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@ -147,10 +184,10 @@ class SFTTrainerTester(unittest.TestCase):
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
assert trainer.state.log_history[(-1)]["train_loss"] is not None
assert trainer.state.log_history[0]["eval_loss"] is not None
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
def test_sft_trainer_uncorrect_data(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@ -164,14 +201,42 @@ class SFTTrainerTester(unittest.TestCase):
per_device_train_batch_size=2,
)
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
_ = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_dataset,
packing=True,
)
# this should work since the dummy chatml include the correct format
_ = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_chatml_dataset,
max_seq_length=32, # make sure there is at least 1 packed sequence
num_of_sequences=32,
packing=True,
)
_ = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_chatml_dataset,
packing=False,
)
# this should work since the dummy instruction dataset is the correct format
_ = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_instruction_dataset,
max_seq_length=16, # make sure there is at least 1 packed sequence
packing=True,
)
_ = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_instruction_dataset,
packing=False,
)
# This should work
_ = SFTTrainer(
model=self.model,
@ -182,7 +247,7 @@ class SFTTrainerTester(unittest.TestCase):
packing=True,
)
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
# This should not work because not enough data for one sample
_ = SFTTrainer(
model=self.model,
@ -194,7 +259,7 @@ class SFTTrainerTester(unittest.TestCase):
)
# This should not work as well
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
_ = SFTTrainer(
model=self.model,
args=training_args,
@ -235,10 +300,10 @@ class SFTTrainerTester(unittest.TestCase):
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
assert trainer.state.log_history[(-1)]["train_loss"] is not None
assert trainer.state.log_history[0]["eval_loss"] is not None
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
@ -263,9 +328,9 @@ class SFTTrainerTester(unittest.TestCase):
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
assert trainer.state.log_history[(-1)]["train_loss"] is not None
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
@ -288,9 +353,9 @@ class SFTTrainerTester(unittest.TestCase):
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
assert trainer.state.log_history[(-1)]["train_loss"] is not None
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1"))
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")
def test_sft_trainer_with_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@ -314,10 +379,10 @@ class SFTTrainerTester(unittest.TestCase):
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
assert trainer.state.log_history[(-1)]["train_loss"] is not None
assert trainer.state.log_history[0]["eval_loss"] is not None
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
@ -341,9 +406,9 @@ class SFTTrainerTester(unittest.TestCase):
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
assert trainer.state.log_history[(-1)]["train_loss"] is not None
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
# with formatting_func + packed
with tempfile.TemporaryDirectory() as tmp_dir:
@ -368,9 +433,9 @@ class SFTTrainerTester(unittest.TestCase):
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
assert trainer.state.log_history[(-1)]["train_loss"] is not None
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
# with formatting_func + packed
with tempfile.TemporaryDirectory() as tmp_dir:
@ -393,9 +458,9 @@ class SFTTrainerTester(unittest.TestCase):
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
assert trainer.state.log_history[(-1)]["train_loss"] is not None
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
@ -417,9 +482,9 @@ class SFTTrainerTester(unittest.TestCase):
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
assert trainer.state.log_history[(-1)]["train_loss"] is not None
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1"))
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")
def test_sft_trainer_with_multiple_eval_datasets(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@ -446,11 +511,11 @@ class SFTTrainerTester(unittest.TestCase):
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_data1_loss"])
self.assertIsNotNone(trainer.state.log_history[1]["eval_data2_loss"])
assert trainer.state.log_history[(-1)]["train_loss"] is not None
assert trainer.state.log_history[0]["eval_data1_loss"] is not None
assert trainer.state.log_history[1]["eval_data2_loss"] is not None
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1"))
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-1")
def test_data_collator_completion_lm(self):
response_template = "### Response:\n"
@ -465,7 +530,7 @@ class SFTTrainerTester(unittest.TestCase):
labels = batch["labels"]
last_pad_idx = np.where(labels == -100)[1][-1]
result_text = self.tokenizer.decode(batch["input_ids"][0, last_pad_idx + 1 :])
self.assertEqual(result_text, "I have not been masked correctly.")
assert result_text == "I have not been masked correctly."
def test_data_collator_completion_lm_with_multiple_text(self):
tokenizer = copy.deepcopy(self.tokenizer)
@ -488,7 +553,7 @@ class SFTTrainerTester(unittest.TestCase):
labels = batch["labels"][i]
last_pad_idx = np.where(labels == -100)[0][-1]
result_text = tokenizer.decode(batch["input_ids"][i, last_pad_idx + 1 :])
self.assertEqual(result_text, "I have not been masked correctly.")
assert result_text == "I have not been masked correctly."
def test_data_collator_chat_completion_lm(self):
instruction_template = "### Human:"
@ -509,7 +574,7 @@ class SFTTrainerTester(unittest.TestCase):
labels = batch["labels"]
non_masked_tokens = batch["input_ids"][labels != -100]
result_text = self.tokenizer.decode(non_masked_tokens)
self.assertEqual(result_text, " I should not be masked. I should not be masked too.")
assert result_text == " I should not be masked. I should not be masked too."
def test_data_collator_chat_completion_lm_with_multiple_text(self):
tokenizer = copy.deepcopy(self.tokenizer)
@ -537,11 +602,11 @@ class SFTTrainerTester(unittest.TestCase):
non_masked_tokens1 = input_ids[0][labels[0] != -100]
result_text1 = tokenizer.decode(non_masked_tokens1)
self.assertEqual(result_text1, " I should not be masked.")
assert result_text1 == " I should not be masked."
non_masked_tokens2 = input_ids[1][labels[1] != -100]
result_text2 = tokenizer.decode(non_masked_tokens2)
self.assertEqual(result_text2, " I should not be masked. I should not be masked too.")
assert result_text2 == " I should not be masked. I should not be masked too."
def test_sft_trainer_infinite_with_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@ -564,15 +629,15 @@ class SFTTrainerTester(unittest.TestCase):
max_seq_length=500,
)
self.assertTrue(trainer.train_dataset.infinite)
assert trainer.train_dataset.infinite
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
assert trainer.state.log_history[(-1)]["train_loss"] is not None
assert trainer.state.log_history[0]["eval_loss"] is not None
# make sure the trainer did 5 steps
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-5"))
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-5")
def test_sft_trainer_infinite_with_model_epochs(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@ -593,14 +658,14 @@ class SFTTrainerTester(unittest.TestCase):
max_seq_length=500,
)
self.assertFalse(trainer.train_dataset.infinite)
assert not trainer.train_dataset.infinite
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
assert trainer.state.log_history[(-1)]["train_loss"] is not None
# make sure the trainer did 5 steps
self.assertTrue("model.safetensors" in os.listdir(tmp_dir + "/checkpoint-4"))
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-4")
def test_sft_trainer_with_model_neftune(self):
with tempfile.TemporaryDirectory() as tmp_dir:
@ -634,8 +699,8 @@ class SFTTrainerTester(unittest.TestCase):
torch.random.manual_seed(24)
embeds_neftune_2 = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))
self.assertFalse(torch.allclose(embeds_neftune, embeds_neftune_2))
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) > 0)
assert not torch.allclose(embeds_neftune, embeds_neftune_2)
assert len(trainer.model.get_input_embeddings()._forward_hooks) > 0
trainer.neftune_hook_handle.remove()
@ -643,7 +708,26 @@ class SFTTrainerTester(unittest.TestCase):
# Make sure forward pass works fine
_ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device))
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0)
assert len(trainer.model.get_input_embeddings()._forward_hooks) == 0
@require_peft
def test_peft_sft_trainer_str(self):
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
_ = SFTTrainer(
model=self.model_id,
args=None,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
peft_config=peft_config,
packing=True,
)
@require_peft
def test_peft_sft_trainer(self):
@ -675,16 +759,16 @@ class SFTTrainerTester(unittest.TestCase):
packing=True,
)
self.assertTrue(isinstance(trainer.model, PeftModel))
assert isinstance(trainer.model, PeftModel)
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
assert trainer.state.log_history[(-1)]["train_loss"] is not None
assert trainer.state.log_history[0]["eval_loss"] is not None
self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2"))
assert "adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
assert "adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2")
assert "model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2")
@require_peft
def test_peft_sft_trainer_gc(self):
@ -717,16 +801,16 @@ class SFTTrainerTester(unittest.TestCase):
packing=True,
)
self.assertTrue(isinstance(trainer.model, PeftModel))
assert isinstance(trainer.model, PeftModel)
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
assert trainer.state.log_history[(-1)]["train_loss"] is not None
assert trainer.state.log_history[0]["eval_loss"] is not None
self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2"))
assert "adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
assert "adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2")
assert "model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2")
@require_peft
def test_peft_sft_trainer_neftune(self):
@ -761,7 +845,7 @@ class SFTTrainerTester(unittest.TestCase):
trainer.model = trainer._trl_activate_neftune(trainer.model)
self.assertTrue(isinstance(trainer.model, PeftModel))
assert isinstance(trainer.model, PeftModel)
device = trainer.model.get_input_embeddings().weight.device
trainer.model.train()
@ -772,20 +856,127 @@ class SFTTrainerTester(unittest.TestCase):
torch.random.manual_seed(24)
embeds_neftune_2 = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device))
self.assertFalse(torch.allclose(embeds_neftune, embeds_neftune_2))
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) > 0)
assert not torch.allclose(embeds_neftune, embeds_neftune_2)
assert len(trainer.model.get_input_embeddings()._forward_hooks) > 0
trainer.neftune_hook_handle.remove()
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
assert trainer.state.log_history[(-1)]["train_loss"] is not None
assert trainer.state.log_history[0]["eval_loss"] is not None
self.assertTrue("adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2"))
self.assertTrue("model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2"))
assert "adapter_model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
assert "adapter_config.json" in os.listdir(tmp_dir + "/checkpoint-2")
assert "model.safetensors" not in os.listdir(tmp_dir + "/checkpoint-2")
# Make sure forward pass works fine to check if embeddings forward is not broken.
_ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device))
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0)
assert len(trainer.model.get_input_embeddings()._forward_hooks) == 0
@require_peft
def test_peft_sft_trainer_tag(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
max_steps=4,
eval_steps=2,
save_steps=2,
per_device_train_batch_size=2,
gradient_checkpointing=True,
)
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
trainer = SFTTrainer(
model=self.model_id,
args=training_args,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
peft_config=peft_config,
packing=True,
)
assert trainer.model.model_tags == trainer._tag_names
@require_peft
def test_sft_trainer_tag(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
max_steps=4,
eval_steps=2,
save_steps=2,
per_device_train_batch_size=2,
gradient_checkpointing=True,
)
trainer = SFTTrainer(
model=self.model_id,
args=training_args,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
packing=True,
)
assert trainer.model.model_tags == trainer._tag_names
def test_sft_trainer_eval_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
max_steps=4,
eval_steps=2,
save_steps=2,
per_device_train_batch_size=2,
gradient_checkpointing=True,
)
trainer = SFTTrainer(
model=self.model_id,
args=training_args,
train_dataset=self.dummy_chatml_dataset,
eval_dataset=self.dummy_chatml_dataset,
packing=True,
max_seq_length=32, # make sure there is at least 1 packed sequence
eval_packing=False,
)
assert len(trainer.train_dataset["input_ids"]) == 1
assert len(trainer.eval_dataset["input_ids"]) != 1
trainer = SFTTrainer(
model=self.model_id,
args=training_args,
train_dataset=self.dummy_chatml_dataset,
eval_dataset=self.dummy_chatml_dataset,
max_seq_length=32, # make sure there is at least 1 packed sequence
packing=True,
)
assert len(trainer.train_dataset["input_ids"]) == 1
assert len(trainer.eval_dataset["input_ids"]) == 1
trainer = SFTTrainer(
model=self.model_id,
args=training_args,
train_dataset=self.dummy_chatml_dataset,
eval_dataset=self.dummy_chatml_dataset,
max_seq_length=32, # make sure there is at least 1 packed sequence
packing=False,
)
assert len(trainer.train_dataset["input_ids"]) != 1
assert len(trainer.eval_dataset["input_ids"]) != 1
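
Editor's note: the eval_packing checks above rely on packing concatenating many short examples and re-cutting them into fixed-length sequences, so a tiny dataset collapses into a single packed example of max_seq_length tokens while the unpacked dataset keeps one row per example. A simplified sketch of the packing idea (the principle only, not TRL's ConstantLengthDataset), assuming already-tokenized inputs:

def pack(token_streams, seq_length):
    """Concatenate token streams and slice them into fixed-length packed examples."""
    flat = [tok for stream in token_streams for tok in stream]
    return [flat[i:i + seq_length] for i in range(0, len(flat) - seq_length + 1, seq_length)]

# Hypothetical tokenized chat turns; with seq_length=32 everything fits into one packed example.
examples = [list(range(10)), list(range(12)), list(range(14))]
packed = pack(examples, seq_length=32)
print(len(packed))  # 1 -> mirrors len(trainer.train_dataset["input_ids"]) == 1 above
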


@ -15,7 +15,13 @@ import unittest
import torch
from trl import is_diffusers_available, is_peft_available, is_wandb_available, is_xpu_available
from trl import (
is_bitsandbytes_available,
is_diffusers_available,
is_peft_available,
is_wandb_available,
is_xpu_available,
)
def require_peft(test_case):
@ -27,6 +33,15 @@ def require_peft(test_case):
return test_case
def require_bitsandbytes(test_case):
"""
Decorator marking a test that requires bnb. Skips the test if bnb is not available.
"""
if not is_bitsandbytes_available():
test_case = unittest.skip("test requires bnb")(test_case)
return test_case
def require_diffusers(test_case):
"""
Decorator marking a test that requires diffusers. Skips the test if diffusers is not available.
@ -55,17 +70,6 @@ def require_no_wandb(test_case):
return require_wandb(test_case, required=False)
def require_bitsandbytes(test_case):
"""
Decorator marking a test that requires bitsandbytes. Skips the test if bitsandbytes is not available.
"""
try:
import bitsandbytes # noqa: F401
except ImportError:
test_case = unittest.skip("test requires bitsandbytes")(test_case)
return test_case
def require_torch_multi_gpu(test_case):
"""
Decorator marking a test that requires multiple GPUs. Skips the test if there aren't enough GPUs.
@ -75,6 +79,15 @@ def require_torch_multi_gpu(test_case):
return test_case
def require_torch_gpu(test_case):
"""
Decorator marking a test that requires GPUs. Skips the test if there is no GPU.
"""
if not torch.cuda.is_available():
test_case = unittest.skip("test requires GPU")(test_case)
return test_case
def require_torch_multi_xpu(test_case):
"""
Decorator marking a test that requires multiple XPUs. Skips the test if there aren't enough XPUs.

View File

@ -1,40 +1,133 @@
# flake8: noqa
__version__ = "0.7.8.dev0"
__version__ = "0.8.0"
from .core import set_seed
from .environment import TextEnvironment, TextHistory
from .extras import BestOfNSampler
from .import_utils import (
is_diffusers_available,
is_npu_available,
is_peft_available,
is_wandb_available,
is_xpu_available,
)
from .models import (
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
PreTrainedModelWrapper,
create_reference_model,
)
from .trainer import (
DataCollatorForCompletionOnlyLM,
DPOTrainer,
IterativeSFTTrainer,
PPOConfig,
PPOTrainer,
RewardConfig,
RewardTrainer,
SFTTrainer,
)
from typing import TYPE_CHECKING
from .import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable
_import_structure = {
"core": [
"set_seed",
],
"environment": [
"TextEnvironment",
"TextHistory",
],
"extras": [
"BestOfNSampler",
],
"import_utils": [
"is_bitsandbytes_available",
"is_diffusers_available",
"is_npu_available",
"is_peft_available",
"is_wandb_available",
"is_xpu_available",
],
"models": [
"AutoModelForCausalLMWithValueHead",
"AutoModelForSeq2SeqLMWithValueHead",
"PreTrainedModelWrapper",
"create_reference_model",
"setup_chat_format",
"SUPPORTED_ARCHITECTURES",
],
"trainer": [
"DataCollatorForCompletionOnlyLM",
"DPOTrainer",
"IterativeSFTTrainer",
"KTOConfig",
"KTOTrainer",
"ModelConfig",
"PPOConfig",
"PPOTrainer",
"RewardConfig",
"RewardTrainer",
"SFTTrainer",
],
"commands": [],
"commands.utils": ["SftArgumentParser", "init_zero_verbose", "TrlParser", "DpoArgumentParser"],
"trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config", "RichProgressCallback"],
"multitask_prompt_tuning": [
"MultitaskPromptEmbedding",
"MultitaskPromptTuningConfig",
"MultitaskPromptTuningInit",
],
}
if is_diffusers_available():
from .models import (
DDPOPipelineOutput,
DDPOSchedulerOutput,
DDPOStableDiffusionPipeline,
DefaultDDPOStableDiffusionPipeline,
try:
if not is_diffusers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["models"].extend(
[
"DDPOPipelineOutput",
"DDPOSchedulerOutput",
"DDPOStableDiffusionPipeline",
"DefaultDDPOStableDiffusionPipeline",
]
)
_import_structure["trainer"].extend(["DDPOConfig", "DDPOTrainer"])
if TYPE_CHECKING:
from .core import set_seed
from .environment import TextEnvironment, TextHistory
from .extras import BestOfNSampler
from .import_utils import (
is_bitsandbytes_available,
is_diffusers_available,
is_npu_available,
is_peft_available,
is_wandb_available,
is_xpu_available,
)
from .models import (
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
PreTrainedModelWrapper,
create_reference_model,
setup_chat_format,
SUPPORTED_ARCHITECTURES,
)
from .trainer import (
DataCollatorForCompletionOnlyLM,
DPOTrainer,
IterativeSFTTrainer,
KTOConfig,
KTOTrainer,
ModelConfig,
PPOConfig,
PPOTrainer,
RewardConfig,
RewardTrainer,
SFTTrainer,
)
from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config, RichProgressCallback
from .commands.utils import init_zero_verbose, SftScriptArguments, DpoScriptArguments, TrlParser
try:
if not is_diffusers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .models import (
DDPOPipelineOutput,
DDPOSchedulerOutput,
DDPOStableDiffusionPipeline,
DefaultDDPOStableDiffusionPipeline,
)
from .trainer import DDPOConfig, DDPOTrainer
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
extra_objects={"__version__": __version__},
)
from .trainer import DDPOConfig, DDPOTrainer
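The switch to `_LazyModule` above means `import trl` no longer eagerly imports every trainer. A minimal behavioural sketch (not part of the diff), assuming the `_import_structure` shown in this file:
import sys
import trl                             # cheap: only the lazy wrapper module is registered
print("trl.trainer" in sys.modules)    # expected: False under the lazy setup above
_ = trl.SFTTrainer                     # attribute access triggers the real submodule import
print("trl.trainer" in sys.modules)    # expected: True
print(trl.__version__)                 # served from extra_objects, e.g. "0.8.0"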

trl/commands/__init__.py Normal file
View File

@ -0,0 +1,34 @@
# flake8: noqa
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
from typing import TYPE_CHECKING
from ..import_utils import _LazyModule, OptionalDependencyNotAvailable
_import_structure = {
"cli_utils": ["SftArgumentParser", "init_zero_verbose", "DpoScriptArguments", "TrlParser"],
"config_parser": ["YamlConfigParser"],
}
if TYPE_CHECKING:
from .cli_utils import SftScriptArguments, init_zero_verbose, DpoScriptArguments, TrlParser
from .config_parser import YamlConfigParser
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

trl/commands/cli.py Normal file
View File

@ -0,0 +1,71 @@
# This file is a copy of trl/examples/scripts/sft.py so that we could
# use it together with rich and the TRL CLI in a more customizable manner.
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
from subprocess import CalledProcessError
from rich.console import Console
SUPPORTED_COMMANDS = ["sft", "dpo", "chat"]
def main():
console = Console()
# Make sure to import things locally to avoid verbose output from third-party libs.
with console.status("[bold purple]Welcome! Initializing the TRL CLI..."):
from trl.commands.cli_utils import init_zero_verbose
init_zero_verbose()
command_name = sys.argv[1]
if command_name not in SUPPORTED_COMMANDS:
raise ValueError(
f"Please use one of the supported commands, got {command_name} - supported commands are {SUPPORTED_COMMANDS}"
)
trl_examples_dir = os.path.dirname(__file__)
# Force-use rich
os.environ["TRL_USE_RICH"] = "1"
if command_name == "chat":
command = f"""
python {trl_examples_dir}/scripts/{command_name}.py {" ".join(sys.argv[2:])}
"""
else:
command = f"""
accelerate launch {trl_examples_dir}/scripts/{command_name}.py {" ".join(sys.argv[2:])}
"""
try:
subprocess.run(
command.split(),
text=True,
check=True,
encoding="utf-8",
cwd=os.getcwd(),
env=os.environ.copy(),
)
except (CalledProcessError, ChildProcessError) as exc:
console.log(f"TRL - {command_name.upper()} failed on ! See the logs above for further details.")
raise ValueError("TRL CLI failed! Check the traceback above..") from exc
if __name__ == "__main__":
main()
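For reference, a hedged sketch (not part of the diff) of what the dispatch above expands a CLI call into; the argv values and the `<trl_examples_dir>` placeholder are illustrative only:
import sys
# Hypothetical invocation equivalent to: trl sft --model_name_or_path facebook/opt-350m
sys.argv = ["trl", "sft", "--model_name_or_path", "facebook/opt-350m"]
command_name = sys.argv[1]
launcher = "python" if command_name == "chat" else "accelerate launch"
print(f"{launcher} <trl_examples_dir>/scripts/{command_name}.py {' '.join(sys.argv[2:])}")
# -> accelerate launch <trl_examples_dir>/scripts/sft.py --model_name_or_path facebook/opt-350m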

trl/commands/cli_utils.py Normal file
View File

@ -0,0 +1,288 @@
# This file is a copy of trl/examples/scripts/sft.py so that we could
# use it together with rich and the TRL CLI in a more customizable manner.
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from copy import deepcopy
from dataclasses import asdict, dataclass, field, fields
from typing import Any, List
import yaml
from transformers import HfArgumentParser
class YamlConfigParser:
def __init__(self, config_path: str = None, dataclasses: List[Any] = None):
self.config = None
if config_path is not None:
with open(config_path) as yaml_file:
self.config = yaml.safe_load(yaml_file)
else:
self.config = {}
if dataclasses is None:
dataclasses = []
# We create a dummy TrainingArguments instance to compare field values
# before / after __post_init__.
# `TrainingArguments` is imported locally so that TRL's lazy imports are not broken.
from transformers import TrainingArguments
self._dummy_training_args = TrainingArguments(output_dir="dummy-training-args")
self.parse_and_set_env()
self.merge_dataclasses(dataclasses)
def parse_and_set_env(self):
if "env" in self.config:
env_vars = self.config["env"]
if isinstance(env_vars, dict):
for key, value in env_vars.items():
os.environ[key] = str(value)
else:
raise ValueError("`env` field should be a dict in the YAML file.")
def merge_dataclasses(self, dataclasses):
from transformers import TrainingArguments
dataclasses_copy = [deepcopy(dataclass) for dataclass in dataclasses]
if len(self.config) > 0:
for i, dataclass in enumerate(dataclasses):
is_hf_training_args = False
for data_class_field in fields(dataclass):
# Get the field here
field_name = data_class_field.name
field_value = getattr(dataclass, field_name)
if not isinstance(dataclass, TrainingArguments):
default_value = data_class_field.default
else:
default_value = (
getattr(self._dummy_training_args, field_name)
if field_name != "output_dir"
else field_name
)
is_hf_training_args = True
default_value_changed = field_value != default_value
if field_value is not None or field_name in self.config:
if field_name in self.config:
# In case the field value is not different from default, overwrite it
if not default_value_changed:
value_to_replace = self.config[field_name]
setattr(dataclasses_copy[i], field_name, value_to_replace)
# Otherwise do nothing
# Re-init `TrainingArguments` to handle all post-processing correctly
if is_hf_training_args:
init_signature = list(inspect.signature(TrainingArguments.__init__).parameters)
dict_dataclass = asdict(dataclasses_copy[i])
new_dict_dataclass = {k: v for k, v in dict_dataclass.items() if k in init_signature}
dataclasses_copy[i] = TrainingArguments(**new_dict_dataclass)
return dataclasses_copy
def to_string(self):
final_string = """"""
for key, value in self.config.items():
if isinstance(value, (dict, list)):
if len(value) != 0:
value = str(value)
value = value.replace("'", '"')
value = f"'{value}'"
else:
continue
final_string += f"--{key} {value} "
return final_string
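A hedged usage sketch for `YamlConfigParser`; the `ExampleArgs` dataclass and the YAML keys are hypothetical and only illustrate the merge rule (a YAML value wins only when the corresponding field was left at its default):
from dataclasses import dataclass, field

@dataclass
class ExampleArgs:
    dataset_name: str = field(default="timdettmers/openassistant-guanaco")
    packing: bool = field(default=False)

# config.yaml (hypothetical):
#   dataset_name: my-org/my-dataset
#   packing: true
#   env:
#     WANDB_PROJECT: trl-cli
args = ExampleArgs()
parser = YamlConfigParser("config.yaml")      # also exports the `env` block to os.environ
merged = parser.merge_dataclasses([args])[0]  # dataset_name / packing overridden from the YAML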
def init_zero_verbose():
"""
Perform zero verbose init - use this method on top of the CLI modules to make
"""
import logging
import warnings
from rich.logging import RichHandler
FORMAT = "%(message)s"
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.ERROR)
# Custom warning handler to redirect warnings to the logging system
def warning_handler(message, category, filename, lineno, file=None, line=None):
logging.warning(f"{filename}:{lineno}: {category.__name__}: {message}")
# Add the custom warning handler - we need to do that before importing anything to make sure the loggers work well
warnings.showwarning = warning_handler
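A small usage sketch, under the behaviour set up above: after `init_zero_verbose()` only ERROR-level records (rendered via rich) reach the console, and Python warnings are rerouted through logging:
import warnings

init_zero_verbose()
warnings.warn("routed through logging.warning, so filtered out below ERROR level")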
@dataclass
class SftScriptArguments:
dataset_name: str = field(default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"})
dataset_text_field: str = field(default="text", metadata={"help": "the text field of the dataset"})
max_seq_length: int = field(default=512, metadata={"help": "The maximum sequence length for SFT Trainer"})
packing: bool = field(default=False, metadata={"help": "Whether to apply data packing or not during training"})
config: str = field(default=None, metadata={"help": "Path to the optional config file"})
gradient_checkpointing_use_reentrant: bool = field(
default=False, metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}
)
@dataclass
class DpoScriptArguments:
dataset_name: str = field(default=None, metadata={"help": "the dataset name"})
beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
max_length: int = field(default=512, metadata={"help": "max length of each sample"})
max_prompt_length: int = field(default=128, metadata={"help": "max length of each sample's prompt"})
max_target_length: int = field(
default=128, metadata={"help": "Only used for encoder decoder model. Max target of each sample's prompt"}
)
sanity_check: bool = field(default=True, metadata={"help": "only train on 1000 samples"})
ignore_bias_buffers: bool = field(
default=False,
metadata={
"help": "debug argument for distributed training;"
"fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)
generate_during_eval: bool = field(default=False, metadata={"help": "Generate during evaluation"})
config: str = field(default=None, metadata={"help": "Path to the optional config file"})
gradient_checkpointing_use_reentrant: bool = field(
default=False, metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}
)
@dataclass
class ChatArguments:
# general settings
model_name_or_path: str = field(metadata={"help": "Name of the pre-trained model"})
user: str = field(default=None, metadata={"help": "Username to display in chat interface"})
system_prompt: str = field(default=None, metadata={"help": "System prompt"})
save_folder: str = field(default="./chat_history/", metadata={"help": "Folder to save chat history"})
device: str = field(
default="cpu",
metadata={"help": "device to use for inference."},
)
config: str = field(
default="default",
metadata={
"help": "Config file used for setting the configs. If `default` uses examples/scripts/config/default_chat_config.yaml"
},
)
examples: str = field(default=None, metadata={"help": "Empty placeholder needs to be set via config."})
# generation settings
max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate"})
do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation"})
num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search"})
temperature: float = field(default=1.0, metadata={"help": "Temperature parameter for generation"})
top_k: int = field(default=50, metadata={"help": "Value of k for top-k sampling"})
top_p: float = field(default=1.0, metadata={"help": "Value of p for nucleus sampling"})
repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty"})
# model loading
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
torch_dtype: str = field(
default=None,
metadata={
"help": (
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
"dtype will be automatically derived from the model's weights."
),
"choices": ["auto", "bfloat16", "float16", "float32"],
},
)
trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
attn_implementation: str = field(
default=None,
metadata={
"help": (
"Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`"
)
},
)
load_in_8bit: bool = field(
default=False, metadata={"help": "use 8 bit precision for the base model - works only with LoRA"}
)
load_in_4bit: bool = field(
default=False, metadata={"help": "use 4 bit precision for the base model - works only with LoRA"}
)
bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "specify the quantization type (fp4 or nf4)"})
use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"})
class TrlParser(HfArgumentParser):
def __init__(self, parsers):
"""
The TRL parser parses a list of parsers (TrainingArguments, trl.ModelConfig, etc.), creates a config
parser for users that pass a valid `config` field, and merges the values set in the config
with the processed parsers.
Args:
parsers (`List[argparse.ArgumentParser]`):
List of parsers.
"""
super().__init__(parsers)
def post_process_dataclasses(self, dataclasses):
# Apply additional post-processing in case some arguments need special care
training_args = trl_args = None
training_args_index = None
for i, dataclass_obj in enumerate(dataclasses):
if dataclass_obj.__class__.__name__ == "TrainingArguments":
training_args = dataclass_obj
training_args_index = i
elif dataclass_obj.__class__.__name__ in ("SftScriptArguments", "DpoScriptArguments"):
trl_args = dataclass_obj
else:
...
if trl_args is not None and training_args is not None:
training_args.gradient_checkpointing_kwargs = dict(
use_reentrant=trl_args.gradient_checkpointing_use_reentrant
)
dataclasses[training_args_index] = training_args
return dataclasses
def parse_args_and_config(self):
dataclasses = self.parse_args_into_dataclasses(return_remaining_strings=True)
# Pop the last element which should be the remaining strings
dataclasses = self.update_dataclasses_with_config(dataclasses[:-1])
return dataclasses
def update_dataclasses_with_config(self, dataclasses):
self.config_parser = None
for parser_dataclass in dataclasses:
if hasattr(parser_dataclass, "config"):
if self.config_parser is not None:
raise ValueError("You passed the `config` field twice! Make sure to pass `config` only once.")
self.config_parser = YamlConfigParser(parser_dataclass.config)
if self.config_parser is not None:
dataclasses = self.config_parser.merge_dataclasses(dataclasses)
dataclasses = self.post_process_dataclasses(dataclasses)
return dataclasses
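A hedged sketch of how a TRL CLI script is expected to wire these pieces together; the command line in the comment is illustrative only:
from transformers import TrainingArguments

# e.g. `python sft.py --output_dir out --config sft_config.yaml`
parser = TrlParser((SftScriptArguments, TrainingArguments))
sft_args, training_args = parser.parse_args_and_config()
# YAML values from --config fill every field still at its default, then
# TrainingArguments is re-initialized so its __post_init__ runs on the merged values.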

View File

@ -22,7 +22,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import top_k_top_p_filtering
from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
from .import_utils import is_npu_available, is_xpu_available
@ -30,12 +30,48 @@ from .import_utils import is_npu_available, is_xpu_available
try:
from collections.abc import Mapping
except ImportError:
from collections import Mapping
from collections.abc import Mapping
WANDB_PADDING = -1
def top_k_top_p_filtering(
logits: torch.FloatTensor,
top_k: int = 0,
top_p: float = 1.0,
filter_value: float = -float("Inf"),
min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
"""
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
Args:
logits: logits distribution shape (batch size, vocabulary size)
top_k (`int`, *optional*, defaults to 0):
If > 0, only keep the top k tokens with highest probability (top-k filtering)
top_p (`float`, *optional*, defaults to 1.0):
If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens we keep per batch example in the output.
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
if top_k > 0:
logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
None, logits
)
if 0 <= top_p <= 1.0:
logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
None, logits
)
return logits
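A brief usage sketch for the vendored `top_k_top_p_filtering` above; the shapes and thresholds are arbitrary:
import torch

logits = torch.randn(1, 32000)                        # (batch_size, vocab_size)
filtered = top_k_top_p_filtering(logits, top_k=50, top_p=0.95)
probs = torch.softmax(filtered, dim=-1)               # filtered positions become -inf -> probability 0
next_token = torch.multinomial(probs, num_samples=1)  # sample from the truncated distribution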
def flatten_dict(nested: Dict, sep: str = "/") -> Dict:
"""Flatten dictionary and concatenate nested keys with separator."""
@ -80,7 +116,7 @@ def stack_dicts(stats_dicts: List[Dict]) -> Dict:
def add_suffix(input_dict: Dict, suffix: str) -> Dict:
"""Add suffix to dict keys."""
return dict((k + suffix, v) for k, v in input_dict.items())
return {k + suffix: v for k, v in input_dict.items()}
def pad_to_size(tensor: torch.Tensor, size: int, dim: int = 1, padding: int = 50256) -> torch.Tensor:
@ -113,7 +149,7 @@ def whiten(values: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
return whitened
def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: bool = None) -> torch.Tensor:
def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:
"""Compute mean of tensor with a masked values."""
if axis is not None:
return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
@ -194,7 +230,7 @@ def respond_to_batch(
) -> torch.LongTensor:
"""Sample text from language model."""
input_ids = queries
for i in range(txt_len):
for _i in range(txt_len):
# Get Logits
outputs = model(input_ids)
next_token_logits = outputs[0][:, -1, :]
@ -236,7 +272,7 @@ class LengthSampler:
return np.random.choice(self.values)
class PPODecorators(object):
class PPODecorators:
optimize_device_cache = False
@classmethod

View File

@ -1,3 +1,14 @@
# flake8: noqa
from typing import TYPE_CHECKING
from ..import_utils import _LazyModule
from .base_environment import TextEnvironment, TextHistory
_import_structure = {
"base_environment": ["TextEnvironment", "TextHistory"],
}
if TYPE_CHECKING:
from .base_environment import TextEnvironment, TextHistory
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@ -14,6 +14,7 @@
import re
import warnings
from typing import Optional
import torch
from accelerate.utils import extract_model_from_parallel
@ -45,7 +46,7 @@ class StringStoppingCriteria(StoppingCriteria):
done = []
for i, decoded_generation in enumerate(decoded_generations):
sequence_complete = any([stop_string in decoded_generation for stop_string in self.stop_strings])
sequence_complete = any(stop_string in decoded_generation for stop_string in self.stop_strings)
done.append(sequence_complete)
if not sequence_complete:
self.generated_tokens[i] += 1
@ -242,7 +243,7 @@ class TextEnvironment:
if isinstance(tools, dict):
self.tools = tools
else:
self.tools = dict([(tool.__class__.__name__, tool) for tool in tools])
self.tools = {tool.__class__.__name__: tool for tool in tools}
self.reward_fn = reward_fn
self.max_length = max_length
self.request_token = "<request>"
@ -277,7 +278,7 @@ class TextEnvironment:
histories = [TextHistory(q, qt, system=True) for q, qt in zip(queries, queries_tokens)]
while any([not history.completed for history in histories]) and turns < self.max_turns:
while any(not history.completed for history in histories) and turns < self.max_turns:
histories = self.generate(histories)
histories = self.tasks_end_check(histories)
# TODO: make this parallel rather than for-loop
@ -416,7 +417,7 @@ class TextEnvironment:
self,
query_tensors,
batch_size: int = 16,
pad_to_multiple_of: int = None,
pad_to_multiple_of: Optional[int] = None,
):
"""
Generate responses for a list of query tensors.

View File

@ -13,4 +13,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .best_of_n_sampler import BestOfNSampler
from typing import TYPE_CHECKING
from ..import_utils import _LazyModule
_import_structure = {
"best_of_n_sampler": ["BestOfNSampler"],
}
if TYPE_CHECKING:
from .best_of_n_sampler import BestOfNSampler
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@ -7,7 +7,7 @@ from ..core import set_seed
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper
class BestOfNSampler(object):
class BestOfNSampler:
def __init__(
self,
model: PreTrainedModelWrapper,

View File

@ -0,0 +1,88 @@
import logging
from typing import Callable, Literal, Optional, Union
from datasets import Dataset, Value
from transformers import AutoTokenizer
from ..trainer.utils import ConstantLengthDataset
FORMAT_MAPPING = {
"chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}],
"instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
}
def conversations_formatting_function(tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"]):
r"""
return a callable function that takes in a "messages" dataset and returns a formatted dataset, based on the tokenizer
apply chat template to the dataset
"""
def format_dataset(examples):
if isinstance(examples[messages_field][0], list):
output_texts = []
for i in range(len(examples[messages_field])):
output_texts.append(tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False))
return output_texts
else:
return tokenizer.apply_chat_template(examples[messages_field], tokenize=False)
return format_dataset
def instructions_formatting_function(tokenizer: AutoTokenizer):
r"""
return a callable function that takes in an "instructions" dataset and returns a formatted dataset, based on the tokenizer
apply chat template to the dataset
"""
def format_dataset(examples):
if isinstance(examples["prompt"], list):
output_texts = []
for i in range(len(examples["prompt"])):
converted_sample = [
{"role": "user", "content": examples["prompt"][i]},
{"role": "assistant", "content": examples["completion"][i]},
]
output_texts.append(tokenizer.apply_chat_template(converted_sample, tokenize=False))
return output_texts
else:
converted_sample = [
{"role": "user", "content": examples["prompt"]},
{"role": "assistant", "content": examples["completion"]},
]
return tokenizer.apply_chat_template(converted_sample, tokenize=False)
return format_dataset
def get_formatting_func_from_dataset(
dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer
) -> Optional[Callable]:
r"""
Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
- `ChatML` with [{"role": str, "content": str}]
- `instruction` with [{"prompt": str, "completion": str}]
Args:
dataset (Dataset): User dataset
tokenizer (AutoTokenizer): Tokenizer used for formatting
Returns:
Callable: Formatting function if the dataset format is supported else None
"""
if isinstance(dataset, Dataset):
if "messages" in dataset.features:
if dataset.features["messages"] == FORMAT_MAPPING["chatml"]:
logging.info("Formatting dataset with chatml format")
return conversations_formatting_function(tokenizer, "messages")
if "conversations" in dataset.features:
if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]:
logging.info("Formatting dataset with chatml format")
return conversations_formatting_function(tokenizer, "conversations")
elif dataset.features == FORMAT_MAPPING["instruction"]:
logging.info("Formatting dataset with instruction format")
return instructions_formatting_function(tokenizer)
return None
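A hedged usage sketch for `get_formatting_func_from_dataset`; the tiny in-memory dataset is illustrative and the tokenizer is assumed to ship a chat template:
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")  # any tokenizer with a chat template
ds = Dataset.from_dict(
    {"messages": [[{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]]}
)
formatting_func = get_formatting_func_from_dataset(ds, tokenizer)  # detects the chatml layout
texts = formatting_func(ds[:])                                     # list of chat-templated strings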

View File

@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import sys
from importlib.util import find_spec
from itertools import chain
from types import ModuleType
from typing import Any
if sys.version_info < (3, 8):
@ -22,11 +27,11 @@ else:
def is_peft_available() -> bool:
return importlib.util.find_spec("peft") is not None
return find_spec("peft") is not None
def is_unsloth_available() -> bool:
return importlib.util.find_spec("unsloth") is not None
return find_spec("unsloth") is not None
def is_accelerate_greater_20_0() -> bool:
@ -41,9 +46,16 @@ def is_accelerate_greater_20_0() -> bool:
return accelerate_version >= "0.20.0"
def is_transformers_greater_than(version: str) -> bool:
_transformers_version = importlib.metadata.version("transformers")
return _transformers_version > version
def is_transformers_greater_than(current_version: str) -> bool:
if _is_python_greater_3_8:
from importlib.metadata import version
_transformers_version = version("transformers")
else:
import pkg_resources
_transformers_version = pkg_resources.get_distribution("transformers").version
return _transformers_version > current_version
def is_torch_greater_2_0() -> bool:
@ -59,23 +71,26 @@ def is_torch_greater_2_0() -> bool:
def is_diffusers_available() -> bool:
return importlib.util.find_spec("diffusers") is not None
return find_spec("diffusers") is not None
def is_bitsandbytes_available() -> bool:
return importlib.util.find_spec("bitsandbytes") is not None
import torch
# bnb can be imported without GPU but is not usable.
return find_spec("bitsandbytes") is not None and torch.cuda.is_available()
def is_torchvision_available() -> bool:
return importlib.util.find_spec("torchvision") is not None
return find_spec("torchvision") is not None
def is_rich_available() -> bool:
return importlib.util.find_spec("rich") is not None
return find_spec("rich") is not None
def is_wandb_available() -> bool:
return importlib.util.find_spec("wandb") is not None
return find_spec("wandb") is not None
def is_xpu_available() -> bool:
@ -84,7 +99,7 @@ def is_xpu_available() -> bool:
return accelerate.utils.is_xpu_available()
else:
if importlib.util.find_spec("intel_extension_for_pytorch") is None:
if find_spec("intel_extension_for_pytorch") is None:
return False
try:
import torch
@ -96,10 +111,74 @@ def is_xpu_available() -> bool:
def is_npu_available() -> bool:
"""Checks if `torch_npu` is installed and potentially if a NPU is in the environment"""
if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None:
if find_spec("torch") is None or find_spec("torch_npu") is None:
return False
import torch
import torch_npu # noqa: F401
return hasattr(torch, "npu") and torch.npu.is_available()
class _LazyModule(ModuleType):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
# Very heavily inspired by optuna.integration._IntegrationModule
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
super().__init__(name)
self._modules = set(import_structure.keys())
self._class_to_module = {}
for key, values in import_structure.items():
for value in values:
self._class_to_module[value] = key
# Needed for autocompletion in an IDE
self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
self.__file__ = module_file
self.__spec__ = module_spec
self.__path__ = [os.path.dirname(module_file)]
self._objects = {} if extra_objects is None else extra_objects
self._name = name
self._import_structure = import_structure
# Needed for autocompletion in an IDE
def __dir__(self):
result = super().__dir__()
# The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
# they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
for attr in self.__all__:
if attr not in result:
result.append(attr)
return result
def __getattr__(self, name: str) -> Any:
if name in self._objects:
return self._objects[name]
if name in self._modules:
value = self._get_module(name)
elif name in self._class_to_module.keys():
module = self._get_module(self._class_to_module[name])
value = getattr(module, name)
else:
raise AttributeError(f"module {self.__name__} has no attribute {name}")
setattr(self, name, value)
return value
def _get_module(self, module_name: str):
try:
return importlib.import_module("." + module_name, self.__name__)
except Exception as e:
raise RuntimeError(
f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
f" traceback):\n{e}"
) from e
def __reduce__(self):
return (self.__class__, (self._name, self.__file__, self._import_structure))
class OptionalDependencyNotAvailable(BaseException):
"""Internally used error class for signalling an optional dependency was not found."""

View File

@ -13,22 +13,52 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .modeling_base import PreTrainedModelWrapper, create_reference_model
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
# flake8: noqa
from typing import TYPE_CHECKING
from ..import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable
SUPPORTED_ARCHITECTURES = (
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
)
_import_structure = {
"modeling_base": ["PreTrainedModelWrapper", "create_reference_model"],
"modeling_value_head": [
"AutoModelForCausalLMWithValueHead",
"AutoModelForSeq2SeqLMWithValueHead",
],
"utils": ["setup_chat_format", "SUPPORTED_ARCHITECTURES"],
}
from ..import_utils import is_diffusers_available
try:
if not is_diffusers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_sd_base"] = [
"DDPOPipelineOutput",
"DDPOSchedulerOutput",
"DDPOStableDiffusionPipeline",
"DefaultDDPOStableDiffusionPipeline",
]
if TYPE_CHECKING:
from .modeling_base import PreTrainedModelWrapper, create_reference_model
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
from .utils import setup_chat_format, SUPPORTED_ARCHITECTURES
if is_diffusers_available():
from .modeling_sd_base import (
DDPOPipelineOutput,
DDPOSchedulerOutput,
DDPOStableDiffusionPipeline,
DefaultDDPOStableDiffusionPipeline,
)
try:
if not is_diffusers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_sd_base import (
DDPOPipelineOutput,
DDPOSchedulerOutput,
DDPOStableDiffusionPipeline,
DefaultDDPOStableDiffusionPipeline,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@ -15,6 +15,7 @@ import json
import logging
import os
from copy import deepcopy
from typing import Optional
import torch
import torch.nn as nn
@ -70,6 +71,7 @@ class PreTrainedModelWrapper(nn.Module):
supported_args: (`list`)
The list of arguments that are supported by the wrapper class.
"""
transformers_parent_class = None
supported_args = None
supported_modules = ("v_head",)
@ -377,12 +379,12 @@ class PreTrainedModelWrapper(nn.Module):
)
# load json
if is_resuming_training:
with open(index_file_name, "r") as f:
with open(index_file_name) as f:
index = json.load(f)
# check filename with `v_head` or any known extra module:
files_to_download = set()
for k, v in index["weight_map"].items():
if any([module in k for module in cls.supported_modules]):
if any(module in k for module in cls.supported_modules):
files_to_download.add(v)
is_sharded = True
@ -459,7 +461,7 @@ class PreTrainedModelWrapper(nn.Module):
"adapter_model.bin",
token=token,
)
except: # noqa
except Exception:
filename = os.path.join(adapter_model_id, "adapter_model.safetensors")
safe_loading = True
if not os.path.exists(filename):
@ -469,10 +471,11 @@ class PreTrainedModelWrapper(nn.Module):
"adapter_model.safetensors",
token=token,
)
except: # noqa
except Exception as exc:
raise ValueError(
"Could not find adapter model in the Hub, make sure you have the correct adapter model id."
)
"Could not find adapter model in the Hub, "
"make sure you have the correct adapter model id."
) from exc
else:
local_filename = filename
else:
@ -484,7 +487,7 @@ class PreTrainedModelWrapper(nn.Module):
adapter_state_dict = loading_func(local_filename, **load_kwargs)
for score_name_candidate in cls.supported_rm_modules:
if any([score_name_candidate in name for name in adapter_state_dict.keys()]):
if any(score_name_candidate in name for name in adapter_state_dict.keys()):
score_name = score_name_candidate
# we have found the correct head name and can break
break
@ -497,7 +500,7 @@ class PreTrainedModelWrapper(nn.Module):
score_dict[key_name] = param.to(cls._get_current_device())
num_labels, hidden_dim = score_dict["weight"].shape
has_bias = any(["bias" in name for name in adapter_state_dict.keys()])
has_bias = any("bias" in name for name in adapter_state_dict.keys())
score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to(
device=cls._get_current_device(),
@ -600,7 +603,7 @@ class PreTrainedModelWrapper(nn.Module):
def create_reference_model(
model: PreTrainedModelWrapper, num_shared_layers: int = None, pattern: str = None
model: PreTrainedModelWrapper, num_shared_layers: Optional[int] = None, pattern: Optional[str] = None
) -> PreTrainedModelWrapper:
"""
Creates a static reference copy of a model. Note that model will be in `.eval()` mode.
@ -635,7 +638,7 @@ def create_reference_model(
else:
for pattern_candidate in LAYER_PATTERNS:
pattern_candidate = pattern_candidate.format(layer=num_shared_layers)
if any([pattern_candidate in name for name in parameter_names]):
if any(pattern_candidate in name for name in parameter_names):
pattern = pattern_candidate
break
@ -647,7 +650,7 @@ def create_reference_model(
unshared_param_list = []
shared_parameter = True
for name, param in model.named_parameters():
for name, _param in model.named_parameters():
if pattern in name:
shared_parameter = False
if shared_parameter:
@ -660,8 +663,7 @@ def create_reference_model(
param = model.get_parameter(param_name)
param.requires_grad = False
ref_param = ref_model.get_parameter(param_name) # noqa
ref_param = param # noqa
_ref_param = ref_model.get_parameter(param_name)
# for all other parameters just make sure they don't use gradients
for param_name in unshared_param_list:

View File

@ -22,10 +22,10 @@ import numpy as np
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
from diffusers.utils import convert_state_dict_to_diffusers
from ..core import randn_tensor
from ..import_utils import is_peft_available
from .sd_utils import convert_state_dict_to_diffusers
if is_peft_available():
@ -34,7 +34,7 @@ if is_peft_available():
@dataclass
class DDPOPipelineOutput(object):
class DDPOPipelineOutput:
"""
Output class for the diffusers pipeline to be finetuned with the DDPO trainer
@ -54,7 +54,7 @@ class DDPOPipelineOutput(object):
@dataclass
class DDPOSchedulerOutput(object):
class DDPOSchedulerOutput:
"""
Output class for the diffusers scheduler to be finetuned with the DDPO trainer
@ -69,7 +69,7 @@ class DDPOSchedulerOutput(object):
log_probs: torch.Tensor
class DDPOStableDiffusionPipeline(object):
class DDPOStableDiffusionPipeline:
"""
Main class for the diffusers pipeline to be finetuned with the DDPO trainer
"""

View File

@ -85,6 +85,7 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
- **"normal"** -- Initializes the weights of the `ValueHead` with a normal distribution.
"""
transformers_parent_class = AutoModelForCausalLM
lm_head_namings = ["lm_head", "embed_out"]
supported_args = (
@ -218,7 +219,7 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
return pretrained_model_state_dict
def push_to_hub(self, *args, **kwargs):
setattr(self.pretrained_model, "v_head", self.v_head)
self.pretrained_model.v_head = self.v_head
return self.pretrained_model.push_to_hub(*args, **kwargs)
@ -276,6 +277,7 @@ class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):
kwargs:
Additional keyword arguments passed along to the `ValueHead` class.
"""
transformers_parent_class = AutoModelForSeq2SeqLM
lm_head_namings = ["lm_head", "embed_out", "output_projection"]
supported_args = (
@ -298,7 +300,7 @@ class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):
def _has_lm_head(self):
# check module names of all modules inside `pretrained_model` to find the language model head
for name, module in self.pretrained_model.named_modules():
for name, _module in self.pretrained_model.named_modules():
if any(attribute in name for attribute in self.lm_head_namings):
return True
return False
@ -374,7 +376,7 @@ class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):
return pretrained_model_state_dict
def push_to_hub(self, *args, **kwargs):
setattr(self.pretrained_model, "v_head", self.v_head)
self.pretrained_model.v_head = self.v_head
return self.pretrained_model.push_to_hub(*args, **kwargs)

trl/models/sd_utils.py Normal file
View File

@ -0,0 +1,150 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
State dict utilities: utility methods for converting state dicts easily
File copied from diffusers to avoid import issues and keep TRL compatible
with most diffusers versions.
"""
import enum
class StateDictType(enum.Enum):
"""
The mode to use when converting state dicts.
"""
DIFFUSERS_OLD = "diffusers_old"
PEFT = "peft"
PEFT_TO_DIFFUSERS = {
".q_proj.lora_B": ".q_proj.lora_linear_layer.up",
".q_proj.lora_A": ".q_proj.lora_linear_layer.down",
".k_proj.lora_B": ".k_proj.lora_linear_layer.up",
".k_proj.lora_A": ".k_proj.lora_linear_layer.down",
".v_proj.lora_B": ".v_proj.lora_linear_layer.up",
".v_proj.lora_A": ".v_proj.lora_linear_layer.down",
".out_proj.lora_B": ".out_proj.lora_linear_layer.up",
".out_proj.lora_A": ".out_proj.lora_linear_layer.down",
"to_k.lora_A": "to_k.lora.down",
"to_k.lora_B": "to_k.lora.up",
"to_q.lora_A": "to_q.lora.down",
"to_q.lora_B": "to_q.lora.up",
"to_v.lora_A": "to_v.lora.down",
"to_v.lora_B": "to_v.lora.up",
"to_out.0.lora_A": "to_out.0.lora.down",
"to_out.0.lora_B": "to_out.0.lora.up",
}
DIFFUSERS_OLD_TO_DIFFUSERS = {
".to_q_lora.up": ".q_proj.lora_linear_layer.up",
".to_q_lora.down": ".q_proj.lora_linear_layer.down",
".to_k_lora.up": ".k_proj.lora_linear_layer.up",
".to_k_lora.down": ".k_proj.lora_linear_layer.down",
".to_v_lora.up": ".v_proj.lora_linear_layer.up",
".to_v_lora.down": ".v_proj.lora_linear_layer.down",
".to_out_lora.up": ".out_proj.lora_linear_layer.up",
".to_out_lora.down": ".out_proj.lora_linear_layer.down",
}
DIFFUSERS_STATE_DICT_MAPPINGS = {
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS,
StateDictType.PEFT: PEFT_TO_DIFFUSERS,
}
KEYS_TO_ALWAYS_REPLACE = {
".processor.": ".",
}
def convert_state_dict(state_dict, mapping):
r"""
Simply iterates over the state dict and replaces the patterns in `mapping` with the corresponding values.
Args:
state_dict (`dict[str, torch.Tensor]`):
The state dict to convert.
mapping (`dict[str, str]`):
The mapping to use for conversion, the mapping should be a dictionary with the following structure:
- key: the pattern to replace
- value: the pattern to replace with
Returns:
converted_state_dict (`dict`)
The converted state dict.
"""
converted_state_dict = {}
for k, v in state_dict.items():
# First, filter out the keys that we always want to replace
for pattern in KEYS_TO_ALWAYS_REPLACE.keys():
if pattern in k:
new_pattern = KEYS_TO_ALWAYS_REPLACE[pattern]
k = k.replace(pattern, new_pattern)
for pattern in mapping.keys():
if pattern in k:
new_pattern = mapping[pattern]
k = k.replace(pattern, new_pattern)
break
converted_state_dict[k] = v
return converted_state_dict
def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs):
r"""
Converts a state dict to new diffusers format. The state dict can be from previous diffusers format
(`OLD_DIFFUSERS`), or PEFT format (`PEFT`) or new diffusers format (`DIFFUSERS`). In the last case the method will
return the state dict as is.
The method only supports the conversion from diffusers old, PEFT to diffusers new for now.
Args:
state_dict (`dict[str, torch.Tensor]`):
The state dict to convert.
original_type (`StateDictType`, *optional*):
The original type of the state dict, if not provided, the method will try to infer it automatically.
kwargs (`dict`, *args*):
Additional arguments to pass to the method.
- **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
`get_peft_model_state_dict` method:
https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
but we add it here in case we don't want to rely on that method.
"""
peft_adapter_name = kwargs.pop("adapter_name", None)
if peft_adapter_name is not None:
peft_adapter_name = "." + peft_adapter_name
else:
peft_adapter_name = ""
if original_type is None:
# Old diffusers to PEFT
if any("to_out_lora" in k for k in state_dict.keys()):
original_type = StateDictType.DIFFUSERS_OLD
elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
original_type = StateDictType.PEFT
elif any("lora_linear_layer" in k for k in state_dict.keys()):
# nothing to do
return state_dict
else:
raise ValueError("Could not automatically infer state dict type")
if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS.keys():
raise ValueError(f"Original type {original_type} is not supported")
mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type]
return convert_state_dict(state_dict, mapping)
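A minimal sketch of the key renaming performed above; the tensor is a dummy and the key uses one of the PEFT patterns from `PEFT_TO_DIFFUSERS`:
import torch

peft_state_dict = {"unet.to_q.lora_A.weight": torch.zeros(4, 4)}
diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict)
print(list(diffusers_state_dict))  # ['unet.to_q.lora.down.weight']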

trl/models/utils.py Normal file
View File

@ -0,0 +1,93 @@
from dataclasses import dataclass
from typing import Literal, Optional, Tuple
from transformers import PreTrainedModel, PreTrainedTokenizer
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
SUPPORTED_ARCHITECTURES = (
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
)
# TODO: Add Abstract Base Class if more formats are added
@dataclass
class ChatMlSpecialTokens:
"""Dataclass for special tokens used in ChatML, including system, user, assistant, bos, eos, and pad tokens."""
bos_token: str = "<|im_start|>"
eos_token: str = "<|im_end|>"
pad_token: str = "<|im_end|>"
@property
def system(self):
return f"{self.bos_token}system"
@property
def user(self):
return f"{self.bos_token}user"
@property
def assistant(self):
return f"{self.bos_token}assistant"
@property
def chat_template(self):
return (
"{% for message in messages %}"
f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + '{self.eos_token}' + '\n'}}}}"
"{% endfor %}"
"{% if add_generation_prompt %}"
f"{{{{ '{self.assistant}\n' }}}}"
"{% endif %}"
)
FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens}
def setup_chat_format(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
format: Optional[Literal["chatml"]] = "chatml",
resize_to_multiple_of: Optional[int] = None,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
"""
Set up the chat format by adding special tokens to the tokenizer, setting the chat template, and extending the model's embedding layer to account for the new special tokens.
Args:
model (`~transformers.PreTrainedModel`): The model to be modified.
tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
resize_to_multiple_of (`Optional[int]`): Number to resize the embedding layer to. Defaults to None.
Returns:
model (`~transformers.PreTrainedModel`): The modified model.
tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer.
"""
# check if format available and retrieve
if format not in FORMAT_MAPPING:
raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}")
chat_format = FORMAT_MAPPING[format]()
# set the special tokens and add them to the tokenizer
tokenizer.eos_token = chat_format.eos_token
tokenizer.pad_token = chat_format.pad_token
tokenizer.bos_token = chat_format.bos_token
tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]})
# set chat format for tokenizer
tokenizer.chat_template = chat_format.chat_template
# resize embedding layer to a multiple of 64, https://x.com/karpathy/status/1621578354024677377
model.resize_token_embeddings(
len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None
)
# Make sure to update the generation config to use the new eos & bos token
if getattr(model, "generation_config", None) is not None:
model.generation_config.bos_token_id = tokenizer.bos_token_id
model.generation_config.eos_token_id = tokenizer.eos_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id
return model, tokenizer
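A hedged usage sketch for `setup_chat_format`; the model checkpoint is illustrative only:
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model, tokenizer = setup_chat_format(model, tokenizer, format="chatml", resize_to_multiple_of=64)
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello!"}], tokenize=False, add_generation_prompt=True
)
# prompt == "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n"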

View File

@ -15,31 +15,84 @@
# limitations under the License.
# There is a circular import in the PPOTrainer if we let isort sort these
# isort: off
from .utils import (
AdaptiveKLController,
FixedKLController,
ConstantLengthDataset,
DataCollatorForCompletionOnlyLM,
RunningMoments,
disable_dropout_in_model,
peft_module_casting_to_bf16,
)
from typing import TYPE_CHECKING
from ..import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable
# isort: on
_import_structure = {
"utils": [
"AdaptiveKLController",
"FixedKLController",
"ConstantLengthDataset",
"DataCollatorForCompletionOnlyLM",
"RunningMoments",
"disable_dropout_in_model",
"peft_module_casting_to_bf16",
"RichProgressCallback",
],
"dpo_trainer": [
"DPOTrainer",
],
"iterative_sft_trainer": [
"IterativeSFTTrainer",
],
"kto_config": ["KTOConfig"],
"kto_trainer": ["KTOTrainer"],
"model_config": ["ModelConfig"],
"ppo_config": ["PPOConfig"],
"ppo_trainer": ["PPOTrainer"],
"reward_config": ["RewardConfig"],
"reward_trainer": ["RewardTrainer", "compute_accuracy"],
"sft_trainer": ["SFTTrainer"],
"base": ["BaseTrainer"],
"ddpo_config": ["DDPOConfig"],
}
from ..import_utils import is_diffusers_available
from .base import BaseTrainer
from .ddpo_config import DDPOConfig
try:
if not is_diffusers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["ddpo_trainer"] = ["DDPOTrainer"]
if is_diffusers_available():
from .ddpo_trainer import DDPOTrainer
if TYPE_CHECKING:
# isort: off
from .utils import (
AdaptiveKLController,
FixedKLController,
ConstantLengthDataset,
DataCollatorForCompletionOnlyLM,
RunningMoments,
disable_dropout_in_model,
peft_module_casting_to_bf16,
RichProgressCallback,
)
from .dpo_trainer import DPOTrainer
from .iterative_sft_trainer import IterativeSFTTrainer
from .ppo_config import PPOConfig
from .ppo_trainer import PPOTrainer
from .reward_trainer import RewardTrainer, compute_accuracy
from .sft_trainer import SFTTrainer
from .training_configs import RewardConfig
# isort: on
from .base import BaseTrainer
from .ddpo_config import DDPOConfig
from .dpo_trainer import DPOTrainer
from .iterative_sft_trainer import IterativeSFTTrainer
from .kto_config import KTOConfig
from .kto_trainer import KTOTrainer
from .model_config import ModelConfig
from .ppo_config import PPOConfig
from .ppo_trainer import PPOTrainer
from .reward_config import RewardConfig
from .reward_trainer import RewardTrainer, compute_accuracy
from .sft_trainer import SFTTrainer
try:
if not is_diffusers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .ddpo_trainer import DDPOTrainer
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

Some files were not shown because too many files have changed in this diff.