Commit Graph

23 Commits

SHA1 Message Date
9df19e8a75 📜 Fix license and copyrights (#3264) 2025-04-08 15:22:58 -07:00
cf97133d51 📉 Optimize GRPO memory usage by redefining per_device_batch_size as generations per device (#2776)
* Distribute

* fix some logic errors

* fix and document RepeatRandomSampler

* comment

* doc clarification

* fix type hint

* more readable

* fix eval

* fix tests

* roll back to distribute generation

* improve comment [ci skip]

* fix slice

* catch for eval batch size as well; fix completion_ids in vllm

* log completions

* Revert "log completions"

This reverts commit 1e4af8ffb8dda15d7596e707ac784208db88135a.

* Before the first training step, the model has no optimizer: fix ds3
2025-02-06 20:20:44 +01:00
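The sampler named in this commit ("fix and document RepeatRandomSampler") is the heart of the memory change: each prompt index is repeated so a per-device batch holds whole generation groups. A minimal sketch of that idea, assuming a torch-style sampler; the class name matches the commit message but the signature and details are illustrative, not the actual TRL implementation:

```python
import torch
from torch.utils.data import Sampler


class RepeatRandomSampler(Sampler):
    """Shuffle dataset indices, then repeat each index `num_repeats`
    times so consecutive samples form one generation group."""

    def __init__(self, data_source, num_repeats: int, seed: int = 0):
        self.data_source = data_source
        self.num_repeats = num_repeats
        self.seed = seed

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed)
        for idx in torch.randperm(len(self.data_source), generator=g).tolist():
            # All generations for one prompt stay adjacent, so a per-device
            # batch of `num_repeats` items is exactly one full group.
            yield from [idx] * self.num_repeats

    def __len__(self):
        return len(self.data_source) * self.num_repeats
```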
3d2c1e49b1 Fix merge error (#2595) 2025-01-20 22:17:39 +01:00
5fd78367ae 🫣 Ignore CLI test for Python 3.9 (#2592)
* ignore cli test for python 3.9

* move import inside tests
2025-01-20 21:26:11 +01:00
0f5ffad26e 👨‍👨‍👧‍👧 GRPO (#2565)
* init grpo [ci skip]

* initial version

* refine args defs

* model card

* initial doc

* fix badges

* fix spaces

* try link to super in doc

* temperature, fix indexing, and std=0.0

* grpo script for cli

* peft support

* move data preparation in `compute_loss`

* weird doc trial

* fix device and some logging

* unwrap_model_for_generation for distributed setting

* Compat with distrib training

* revert grpo config doc trial (didn't work)

* test

* allow model to be str and processing_class to be None; fix loss computation

* advantage is always 0.0: don't log

* fix peft not installed

* proper reward model for testing

* fix script for cli

* add trl grpo to cli doc

* test peft

* flush left

* fix reward calculation

* new reward model

* support any reward model

* fix reward processing class def

* log reward std

* fix reward logging

* fix grad computation

* skip embed layer in test

* remove optimizer_cls_and_kwargs

* improve GRPO default args

* reduce mem usage for grpo test

* reduce mem usage in test grpo

* reduce memory usage for test

* Fix the test

* remove redundant

* fix min version

* Update test_grpo_trainer.py

* Update test_grpo_trainer.py

* Fix test, finally found the solution!

* some doc

* Update doc-builder workflow to use specific commit sha

* more doc

* advantages

* drop cancel for no grad

* logged metrics [ci skip]

* completion col is ignored [ci skip]

* fix latex

* double space? ~?

* try a latex fix

* with branch

* Empty commit

* Empty commit

* double space seems to be the solution
2025-01-20 19:02:15 +01:00
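Several messages above ("advantage is always 0.0: don't log", "log reward std", "std=0.0") revolve around GRPO's group-relative advantage: each reward is normalized against the mean and std of its own generation group. A minimal sketch, assuming rewards arrive ordered by prompt; the epsilon guard and exact normalization are illustrative:

```python
import torch


def group_relative_advantages(rewards: torch.Tensor, num_generations: int,
                              eps: float = 1e-4) -> torch.Tensor:
    """Normalize each reward against its own generation group. `eps`
    guards the std=0.0 case the commit messages mention (all completions
    in a group scored equally), which yields an advantage of 0.0."""
    grouped = rewards.view(-1, num_generations)   # (num_prompts, G)
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)
    return ((grouped - mean) / (std + eps)).view(-1)


# e.g. two prompts, three generations each; the second group's rewards
# are identical, so its advantages come out as zeros.
adv = group_relative_advantages(torch.tensor([1.0, 0.0, 0.5, 2.0, 2.0, 2.0]), 3)
```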
1d23ecc36f ©️ Update copyrights year (#2547)
* happy new year

* fix wandb import sort
2025-01-07 14:53:09 +01:00
ca850be0a2 🕹️ CLI refactor (#2380)
* Refactor main function in dpo.py

* Update setup.py and add cli.py

* Add examples to package data

* style

* Refactor setup.py file

* Add new file t.py

* Move dpo to package

* Update MANIFEST.in and setup.py, refactor trl/cli.py

* Add __init__.py to trl/scripts directory

* Add license header to __init__.py

* File moved instruction

* Add Apache License and update file path

* Move dpo.py to new location

* Refactor CLI and DPO script

* Refactor import structure in scripts package

* env

* rm config from chat arg

* rm old cli

* chat init

* test cli [skip ci]

* Add `datast_config_name` to `ScriptArguments` (#2440)

* add missing arg

* Add test cases for 'trl sft' and 'trl dpo' commands

* Add sft.py script and update cli.py to include sft command

* Move sft script

* chat

* style [ci skip]

* kto

* rm example config

* first step on doc

* see #2442

* see #2443

* fix chat windows

* ©️ Copyrights update (#2454)

* First changes

* Other files

* Finally

* rm comment

* fix nashmd

* Fix example

* Fix example [ci skip]

* 💬 Fix chat for windows (#2443)

* fix chat for windows

* add some tests back

* Revert "add some tests back"

This reverts commit 350aef52f53f8cf34fccd7ad0f78a3dd63867e06.

* 🆔 Add `datast_config` to `ScriptArguments` (#2440)

* datast_config_name

* Update trl/utils.py [ci skip]

* sort import

* typo [ci skip]

* Trigger CI

* Rename `dataset_config_name` to `dataset_config`

* 🏎 Fix deepspeed preparation of `ref_model` in `OnlineDPOTrainer` (#2417)

* Remove unused deepspeed code

* add model prep back

* add deepspeed even if it doesn't work

* rm old code

* Fix config name

* Remove `make dev` in favor of `pip install -e .[dev]`

* Update script paths and remove old symlink related things

* Fix chat script path [ci skip]

* style
2024-12-13 17:52:23 +01:00
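The refactor moves the training scripts into `trl/scripts` and routes them through a single `trl/cli.py` entry point. A sketch of that routing pattern, assuming each script module exposes a `main()`; the module paths follow the commit messages ("Move sft script", "Move dpo.py to new location"), everything else here is illustrative:

```python
import argparse
import importlib
import sys

# Subcommand -> implementing module; the `main()` convention is an assumption.
SCRIPTS = {
    "sft": "trl.scripts.sft",
    "dpo": "trl.scripts.dpo",
    "kto": "trl.scripts.kto",
    "chat": "trl.scripts.chat",
    "env": "trl.scripts.env",
}


def main():
    parser = argparse.ArgumentParser(prog="trl")
    parser.add_argument("command", choices=SCRIPTS)
    args, remaining = parser.parse_known_args()
    # Hand the remaining flags to the selected script, which performs its
    # own argument parsing.
    sys.argv = [args.command, *remaining]
    importlib.import_module(SCRIPTS[args.command]).main()


if __name__ == "__main__":
    main()
```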
9410874787 ©️ Copyrights update (#2454)
* First changes

* Other files

* Finally

* rm comment

* fix nashmd

* Fix example

* Fix example [ci skip]
2024-12-10 10:40:00 +01:00
453db5cd79 🤏 New models for tests (#2287)
* first commit

* uncomment

* other tests adaptations

* Remove unused variable in test_setup_chat_format

* Remove unused import statement

* style

* Add Bart model

* Update BCOTrainerTester class in test_bco_trainer.py

* Update model IDs and tokenizers in test files

* Add new models and processors

* Update model IDs in test files

* Fix formatting issue in test_dataset_formatting.py

* Refactor dataset formatting in test_dataset_formatting.py

* Fix dataset sequence length in SFTTrainerTester

* Remove tokenizer

* Remove print statement

* Add reward_model_path and sft_model_path to PPO trainer

* Fix tokenizer padding issue

* Add chat template for testing purposes in PaliGemma model

* Update PaliGemma model and chat template

* Increase learning rate to speed up test

* Update model names in run_dpo.sh and run_sft.sh scripts

* Update model and dataset names

* Fix formatting issue in test_dataset_formatting.py

* Fix formatting issue in test_dataset_formatting.py

* Remove unused chat template

* Update model generation script

* additional models

* Update model references in test files

* Remove unused imports in test_online_dpo_trainer.py

* Add is_llm_blender_available import and update reward_tokenizer

* Refactor test_online_dpo_trainer.py: Move skipped test case decorator

* remove models without chat templates

* Update model names in scripts and tests

* Update model_id in test_modeling_value_head.py

* Update model versions in test files

* Fix formatting issue in test_dataset_formatting.py

* Update embedding model ID in BCOTrainerTester

* Update test_online_dpo_trainer.py with reward model changes

* Update expected formatted text in test_dataset_formatting.py

* Add reward_tokenizer to TestOnlineDPOTrainer

* fix tests

* Add SIMPLE_CHAT_TEMPLATE to T5 tokenizer

* Fix dummy_text format in test_rloo_trainer.py

* Skip outdated test for chatML data collator

* Add new vision language models

* Commented out unused model IDs in test_vdpo_trainer

* Update model and vision configurations in generate_tiny_models.py and test_dpo_trainer.py

* Update model and tokenizer references

* Don't push if it already exists

* Add comment explaining test skip

* Fix model_exists function call and add new models

* Update LlavaForConditionalGeneration model and processor

* `qgallouedec` -> `trl-internal-testing`
2024-11-25 16:31:56 +01:00
24fb32733f 🔧 Use standard unittest assertion methods (#2283)
* WIP: Partial unit test update

* Update unittest format

* Update tests/slow/test_sft_slow.py comment

* Refactor unit tests: replace pytest.raises with self.assertRaises

* Fix: Restore accidentally deleted 'ref_model' parameter in DPOTrainer

* Re-run pre-commit

* fix: Incorrectly replacing non-TestCase assert

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-10-31 15:10:43 +01:00
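The replacement this commit performs is mechanical but worth seeing once: pytest-style assertions become `unittest.TestCase` methods, which report richer failure messages under the standard runner. A minimal before/after illustration:

```python
import unittest


class TestExample(unittest.TestCase):
    def test_raises(self):
        # Before (pytest style):
        #     with pytest.raises(ValueError):
        #         int("not a number")
        # After (standard unittest style, as in the commit):
        with self.assertRaises(ValueError):
            int("not a number")

    def test_equal(self):
        # Likewise, bare `assert a == b` becomes assertEqual.
        self.assertEqual(2 + 2, 4)
```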
92f6d246d3 🏗️ Refactor DPO data processing (#2209)
* in progress

* refactor concatenated_inputs and concatenated_forward

* progress

* further modif

* padding side

* eos prompt enc dec

* prompt_padding_side

* drop prompt padding side collator

* working on decoder only

* dpo trainer

* Fix loss_mask type conversion bug

* bad attention mask

* try to get the same tokens as main

* fix loss mask

* fix unused col

* added comment

* raise error when padding token not set

* remove private method tests

* initial vlm support

* make it work for paligemma

* minor test updates

* style

* improve readability

* improve doc

* style

* flush left and truncate

* flush left in the code

* fix empty_cols and make max_length optional

* always add eos token

* minor changes and doc

* style

* fix docstring

* preference collator in doc

* fix doc

* optional max_completion_length

* Investigating CI failing

* style

* just dpo trainer test

* just idefics

* paligemma

* llava

* test cli

* dataset in test

* all tests

* Update trl/trainer/dpo_trainer.py

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/dpo_trainer.py

* Update trl/trainer/dpo_trainer.py

* reference to ref

* rich descriptions

* fix logits reporting

* fix truncation

* remove chat template from dpo_vlm

* `get_batch_sample` -> `generate_from_model[_and_ref]`

* add `num_items_in_batch=None`

* `num_items_in_batch` in `training_step`

* Fix return type hint

* test tokenize row

* fix test

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-10-21 12:47:33 +02:00
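The "flush left" step named in the messages shifts every sequence so its real tokens start at column 0, after which trailing all-padding columns can be truncated away. A minimal sketch of the operation, assuming 0/1 attention masks; the signature is illustrative, not the TRL helper's:

```python
import torch


def flush_left(mask: torch.Tensor, *tensors: torch.Tensor):
    """Shift each row so non-padding tokens start at column 0, then drop
    columns that are all padding. `mask` and each tensor are (B, L)."""
    # Index of the first non-pad token in each row.
    first = mask.argmax(dim=1)
    mask = torch.stack(
        [torch.roll(mask[i], shifts=-int(first[i])) for i in range(mask.size(0))]
    )
    tensors = tuple(
        torch.stack([torch.roll(t[i], shifts=-int(first[i])) for i in range(t.size(0))])
        for t in tensors
    )
    # Truncate columns beyond the longest remaining sequence.
    keep = int(mask.sum(dim=1).max())
    return (mask[:, :keep], *(t[:, :keep] for t in tensors))
```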
a9cffc7caf Default dataset_text_field to "text" (#2078)
* clarify ConstantLengthDataset usage

* don't provide dataset text field when formatting func is provided

* kto maybe_apply_chat_template

* default text field

* doc

* remove maybe_apply_chat_template from kto example

* dataset text field always a str

* remove `dataset_text_field="text"`

* update doc
2024-10-04 10:55:47 +02:00
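After this change the text field is always a plain string with a sensible default, so the common case needs no argument at all. A hedged illustration against present-day TRL (the `SFTConfig` object arrived in a later refactor, so treat this as the current shape of the behavior rather than the code at this commit):

```python
from trl import SFTConfig

# `dataset_text_field` now defaults to "text"; pass it explicitly only
# when the dataset's column has a different name.
args = SFTConfig(output_dir="sft-model")
print(args.dataset_text_field)  # "text"

args = SFTConfig(output_dir="sft-model", dataset_text_field="content")
```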
1a3bb372ac Fix typo in error message (#2168)
occured -> occurred
2024-10-04 09:36:52 +02:00
2cad48d511 [CLI] trl env for printing system info (#2104) 2024-09-24 09:57:24 +02:00
9a6061fc2f Clean up DPO example (#2043)
* Clean up DPO example

* Fix bs

* Remove reentrant

* Fix tests

* Nuke sanity checks

* Switch dataset

* Remove sanity check from XPO
2024-09-11 17:45:00 +02:00
ac071d6225 Drop canonical dataset namespaces (#2048)
* drop canonical

* Delete ultrafeedback_prompt_only.py dataset script

* reduce diff in best_of_n

* try to revert best_of_n to make github happy

* anyway...
2024-09-10 12:12:00 +02:00
7075cec94d Update HH dataset on helpful only subset (#1613)
* Update HH dataset on helpful only subset

* format
2024-05-02 12:12:12 -04:00
9f68ead8cf FIX: Fix CI on transformers main (#1576)
* Update run_dpo.sh

* Update run_sft.sh

* Update clis.mdx

* Update example_config.yaml

* Update test_cli.py

* Update testing_constants.py

* Update test_dpo_trainer.py
2024-04-23 14:31:45 +02:00
24fd8dd513 [DPO] DPOConfig class (#1554)
* initial DPOConfig

* fix doc string

* use DPOConfig

* fix missing import

* fix DpoScriptArguments

* override args config when given in init

* use DPOConfig

* fix output dir name

* override with deprecated arguments if given

* use DPOConfig in tests

* fix comment

* add custom_message

* use dataset_train_name and dataset_test_name

* beta is also in the training_args

* fix loss_type docs

* Update trl/commands/cli_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/commands/cli_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/commands/cli_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* use DPOScriptArguments

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-04-23 11:06:28 +02:00
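The point of `DPOConfig` is visible in the messages ("beta is also in the training_args"): algorithm hyperparameters move off the trainer's signature onto a `TrainingArguments` subclass. A minimal illustration; the values are examples, not recommendations:

```python
from trl import DPOConfig

# DPO-specific hyperparameters such as `beta` and `loss_type` now live
# on the config object alongside the usual TrainingArguments fields.
training_args = DPOConfig(
    output_dir="dpo-model",
    beta=0.1,
    loss_type="sigmoid",
    per_device_train_batch_size=4,
)
print(training_args.beta)
```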
5d1deb1445 CLI: Set dataset_text_field to None to allow ChatML automatic template (#1545)
* Update cli_utils.py

* Update test_cli.py
2024-04-17 14:45:14 +02:00
423991c204 Use the standard dataset for DPO CLI (#1456)
* Use the standard dataset

* update docs

* update dpo examples

* fix cli error

* fix CI

* use trl-internal-testing/hh-rlhf-trl-style
2024-03-20 13:14:08 -04:00
eb2d5b2972 CI / CLI: Properly raise error when CLI tests failed (#1446)
* properly raise error

* another fix

* Update tests.yml

* Update tests-main.yml
2024-03-19 11:39:07 +01:00
a2aa0f0b09 FEAT: Add CLIs in TRL ! (#1419)
* CLI V1

* v1 CLI

* add rich enhancements

* revert unindented change

* some comments

* cleaner CLI

* fix

* fix

* remove print callback

* move to cli instead of trl_cli

* revert unneeded changes

* fix test

* Update trl/commands/sft.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* remove redundant strings

* fix import issue

* fix other issues

* add packing

* add config parser

* some refactor

* cleaner

* add example config yaml file

* small refactor

* change a bit the logic

* fix issues here and there

* add CLI in docs

* move to examples/sft

* remove redundant licenses

* make it work on dpo

* set to None

* switch to accelerate and fix many things

* add docs

* more docs

* added tests

* doc clarification

* more docs

* fix CI for windows and python 3.8

* fix

* attempt to fix CI

* fix?

* test

* fix

* tweak?

* fix

* test

* another test

* fix

* test

* fix

* fix

* fix

* skip tests for windows

* test @lvwerra approach

* make dev

* revert unneeded changes

* fix sft dpo

* optimize a bit

* address final comments

* update docs

* final comment

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2024-03-18 12:20:54 +01:00
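Among the pieces this first CLI added is a config parser ("add config parser", "add example config yaml file") that lets a YAML file stand in for long flag lists. A hypothetical sketch of that expansion step; the flag handling and file layout are assumptions, not the commit's actual code:

```python
import sys

import yaml  # PyYAML


def config_to_argv(config_path: str) -> list[str]:
    """Expand a YAML config into `--key value` pairs for the underlying
    script's own argument parser."""
    with open(config_path) as f:
        config = yaml.safe_load(f)
    argv = []
    for key, value in config.items():
        argv += [f"--{key}", str(value)]
    return argv


# e.g. `trl sft --config example_config.yaml` could rewrite itself into
# ordinary `--model_name_or_path ...`-style arguments before parsing.
if len(sys.argv) == 3 and sys.argv[1] == "--config":
    sys.argv = sys.argv[:1] + config_to_argv(sys.argv[2])
```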