Mirror of https://github.com/huggingface/trl.git, synced 2025-10-20 18:43:52 +08:00

Compare commits: 31 commits (online-dpo ... a932e2796d)

SHA1:
a932e2796d
04fd1203af
19d2f97932
31caf64778
8e2d5516ca
94aac4a101
26b7c2507e
aa25c2697c
93c7d88563
c7c041ecc8
ef40c047aa
7e0adbc552
773afd9314
966b397201
927cf6ba46
56cb6ccf76
49c8f14b06
cefbacb30e
fae245a062
2aa9506c69
d6eeb290d9
1684ef279a
aab21eb5e7
b997a31981
86d1963cc1
039d526d24
bcd059a384
0e57b4a9df
98488e0946
f45e86571b
f5827928a0
7 .github/workflows/slow-tests.yml (vendored)
@@ -102,13 +102,6 @@ jobs:
source .venv/bin/activate
make slow_tests

- name: Run end-to-end examples tests on multi GPU
if: always()
run: |
source .venv/bin/activate
uv pip install deepspeed
make test_examples

- name: Generate Reports
if: always()
run: |
2 .github/workflows/tests.yml (vendored)
@@ -129,7 +129,7 @@ jobs:
uv pip install -U git+https://github.com/huggingface/accelerate.git
uv pip install -U git+https://github.com/huggingface/datasets.git
uv pip install -U git+https://github.com/huggingface/transformers.git

uv pip install -U git+https://github.com/huggingface/peft.git

- name: Test with pytest
run: |
2 .github/workflows/tests_latest.yml (vendored)
@@ -24,7 +24,7 @@ jobs:
steps:
- name: Git checkout
uses: actions/checkout@v4
with: { ref: v0.23-release }
with: { ref: v0.24-release }

- name: Set up Python 3.12
uses: actions/setup-python@v5

@@ -31,4 +31,4 @@ keywords:
- pytorch
- transformers
license: Apache-2.0
version: "0.23"
version: "0.24"
16 Makefile
@@ -1,9 +1,8 @@
.PHONY: test precommit common_tests slow_tests test_examples tests_gpu test_experimental
.PHONY: test precommit common_tests slow_tests tests_gpu test_experimental

check_dirs := examples tests trl

ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
COMMAND_FILES_PATH = `pwd`/commands

test:
	pytest -n auto -m "not slow and not low_priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' tests/
@@ -16,18 +15,5 @@ precommit:
slow_tests:
	pytest -m "slow" tests/ $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)

test_examples:
	touch temp_results_sft_tests.txt
	for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
		TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_sft.sh; \
		echo $$?','$${file} >> temp_results_sft_tests.txt; \
	done

	touch temp_results_dpo_tests.txt
	for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
		TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_dpo.sh; \
		echo $$?','$${file} >> temp_results_dpo_tests.txt; \
	done

test_experimental:
	pytest -k "experimental"
@@ -1,58 +0,0 @@
#!/bin/bash
# This script runs a DPO example end-to-end on a tiny model using different possible configurations
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_dpo/"
MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
DATASET_NAME="trl-internal-testing/hh-rlhf-helpful-base-trl-style"
MAX_STEPS=5
BATCH_SIZE=2
SEQ_LEN=128

# Handle extra arguments in case one passes accelerate configs.
EXTRA_ACCELERATE_ARGS=""
EXTRA_TRAINING_ARGS="""--use_peft \
    --load_in_4bit
"""

# Set your number of GPUs here
NUM_GPUS=2

if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
  EXTRA_ACCELERATE_ARGS=""
else
  EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
  # For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
  # on `examples/accelerate_configs`, and our runners do not support bf16 mixed precision training.
  if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
    EXTRA_TRAINING_ARGS="--fp16"
  else
    echo "Keeping QLoRA + PEFT"
  fi
fi


CMD="""
accelerate launch $EXTRA_ACCELERATE_ARGS \
    --num_processes $NUM_GPUS \
    --mixed_precision 'fp16' \
    `pwd`/trl/scripts/dpo.py \
    --model_name_or_path $MODEL_NAME \
    --dataset_name $DATASET_NAME \
    --output_dir $OUTPUT_DIR \
    --max_steps $MAX_STEPS \
    --per_device_train_batch_size $BATCH_SIZE \
    --max_length $SEQ_LEN \
    $EXTRA_TRAINING_ARGS
"""

echo "Starting program..."

{ # try
    echo $CMD
    eval "$CMD"
} || { # catch
    # save log for exception
    echo "Operation Failed!"
    exit 1
}
exit 0
@@ -1,59 +0,0 @@
#!/bin/bash
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_sft/"
MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
DATASET_NAME="stanfordnlp/imdb"
MAX_STEPS=5
BATCH_SIZE=2
SEQ_LEN=128


# Handle extra arguments in case one passes accelerate configs.
EXTRA_ACCELERATE_ARGS=""
EXTRA_TRAINING_ARGS="""--use_peft \
    --load_in_4bit
"""

# Set your number of GPUs here
NUM_GPUS=2

if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
  EXTRA_ACCELERATE_ARGS=""
else
  EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
  # For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
  # on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
  if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
    EXTRA_TRAINING_ARGS="--fp16"
  else
    echo "Keeping QLoRA + PEFT"
  fi
fi


CMD="""
accelerate launch $EXTRA_ACCELERATE_ARGS \
    --num_processes $NUM_GPUS \
    --mixed_precision 'fp16' \
    `pwd`/trl/scripts/sft.py \
    --model_name $MODEL_NAME \
    --dataset_name $DATASET_NAME \
    --output_dir $OUTPUT_DIR \
    --max_steps $MAX_STEPS \
    --per_device_train_batch_size $BATCH_SIZE \
    --max_length $SEQ_LEN \
    $EXTRA_TRAINING_ARGS
"""

echo "Starting program..."

{ # try
    echo $CMD
    eval "$CMD"
} || { # catch
    # save log for exception
    echo "Operation Failed!"
    exit 1
}
exit 0
@@ -13,10 +13,6 @@
title: Paper Index
- local: experimental
title: Experimental
- local: how_to_train
title: Training FAQ
- local: logging
title: Understanding Logs
title: Conceptual Guides
- sections:
- local: clis
@@ -59,10 +55,6 @@
title: LoRA Without Regret
- local: sentiment_tuning
title: Sentiment Tuning
- local: using_llama_models
title: Training StackLlama
- local: detoxifying_a_lm
title: Detoxifying a Language Model
- local: multi_adapter_rl
title: Multi Adapter RLHF
title: Examples
@@ -1,5 +1,8 @@
# Best of N sampling: Alternative ways to get better model output without RL based fine-tuning

> [!WARNING]
> Best-of-N sampling is deprecated and will be removed in TRL 0.25.0.

Within the extras module is the `best-of-n` sampler class, which serves as an alternative method of generating better model output.
For how it fares against RL-based fine-tuning, see the comparison example in the `examples` directory.

@@ -44,7 +47,7 @@ best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=o
```

There is the option of setting the generation settings (like `temperature`, `pad_token_id`) at the time of instance creation as opposed to when calling the `generate` method.
This is done by passing a `GenerationConfig` from the `transformers` library at the time of initialization
This is done by passing a [`~transformers.GenerationConfig`] from the `transformers` library at the time of initialization

```python
@@ -112,7 +112,7 @@ trainer.train()

## Use the accelerator cache optimizer

When training large models, you should better handle the accelerator cache by iteratively clearing it. To do so, simply pass `optimize_device_cache=True` to `DPOConfig`:
When training large models, you should better handle the accelerator cache by iteratively clearing it. To do so, simply pass `optimize_device_cache=True` to [`DPOConfig`]:

```python
training_args = DPOConfig(..., optimize_device_cache=True)
@@ -1,201 +0,0 @@
# Detoxifying a Language Model using PPO

Language models (LMs) are known to sometimes generate toxic outputs. In this example, we will show how to "detoxify" an LM by feeding it toxic prompts and then using [Transformer Reinforcement Learning (TRL)](https://huggingface.co/docs/trl/index) with Proximal Policy Optimization (PPO) to reduce its toxicity.

Read this section to follow our investigation on how we can reduce toxicity in a wide range of LMs, from 125M to 6B parameters!

Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/huggingface/trl/tree/main/examples/toxicity/scripts) as well as the link to the interactive demo:

| File | Description | Colab link |
| --- | --- | --- |
| [`gpt-j-6b-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x |
| [`evaluate-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x |
| [Interactive Space](https://huggingface.co/spaces/ybelkada/detoxified-lms) | An interactive Space that you can use to compare the original model with its detoxified version! | x |

## Context

Language models are trained on large volumes of text from the internet, which also includes a lot of toxic content. Naturally, language models pick up the toxic patterns during training. Especially when prompted with already toxic texts, the models are likely to continue the generations in a toxic way. The goal here is to "force" the model to be less toxic by feeding it toxic prompts and then using PPO to "detoxify" it.

### Computing toxicity scores

In order to optimize a model with PPO we need to define a reward. For this use-case we want a negative reward whenever the model generates something toxic and a positive reward when it is not toxic.
Therefore, we used [`facebook/roberta-hate-speech-dynabench-r4-target`](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r4-target), which is a RoBERTa model fine-tuned to classify between "neutral" and "toxic" text, as our toxic prompts classifier.
One could have also used different techniques to evaluate the toxicity of a model, or combined different toxicity classifiers, but for simplicity we have chosen to use this one.
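
For illustration, here is a minimal sketch of loading that classifier with the standard `transformers` APIs. This is not the exact code from the example scripts; only the model id comes from the text above, and the variable names are ours.

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

toxicity_model_id = "facebook/roberta-hate-speech-dynabench-r4-target"
toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_id)
toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_id)
```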

### Selection of models

We selected the following models for our experiments to show that TRL can be easily scaled to 10B-parameter models:

* [`EleutherAI/gpt-neo-125M`](https://huggingface.co/EleutherAI/gpt-neo-125M) (125 million parameters)
* [`EleutherAI/gpt-neo-2.7B`](https://huggingface.co/EleutherAI/gpt-neo-2.7B) (2.7 billion parameters)
* [`EleutherAI/gpt-j-6B`](https://huggingface.co/EleutherAI/gpt-j-6B) (6 billion parameters)

For the smallest model, we chose `EleutherAI/gpt-neo-125M` because it turned out to be the "most toxic" of the models we compared. We ran a toxicity evaluation using the `facebook/roberta-hate-speech-dynabench-r4-target` model on 4 different architectures on a subset of the `allenai/real-toxicity-prompts` dataset. Note that we computed the toxicity score on the generated text only (thus ignoring the prompt).

| Model | Mean toxicity score |
| --- | --- |
| `gpt2` | 0.01602 |
| `facebook/opt-350m` | 0.01628 |
| `bigscience/bloom-560m` | 0.00767 |
| `EleutherAI/gpt-neo-125M` | **0.02016** |

## Designing the problem

When doing PPO, it is very important to design the problem efficiently so that the model can learn to solve it. Let's cover the topics that were important for the model to converge.

### Pre-processing the dataset

The dataset consists of prompts and their continuations, and each of them has an associated `toxicity` score.

A `prompt` example:

```python
{
    "text": "Bronx officer Luis Rios alleged on-the-job tall tales were compounded by his lying about them in court, and",
    "profanity": 0.08360514,
    "sexually_explicit": 0.118823394,
    "identity_attack": 0.09608547,
    "flirtation": 0.23102224,
    "threat": 0.13215046,
    "insult": 0.10130461,
    "severe_toxicity": 0.04068885,
    "toxicity": 0.16534281
}
```

And its `continuation` value:

```python
{
    "text": " cost federal prosecutors all the drug evidence that the cop collected against an armed suspect — 16 baggies of cocaine during a strip search.",
    "severe_toxicity": 0.067997746,
    "toxicity": 0.1694093,
    "profanity": 0.11931301,
    "sexually_explicit": 0.12521537,
    "identity_attack": 0.09268324,
    "flirtation": 0.13452998,
    "threat": 0.31312028,
    "insult": 0.10761123
}
```

We want to increase the chance for the model to generate toxic prompts so we get more learning signal. For this reason, we pre-process the dataset to keep only prompts with a toxicity score greater than a threshold. We can do this in a few lines of code:

```python
train_dataset = load_dataset("allenai/real-toxicity-prompts", split="train")

def filter_fn(sample):
    toxicity = sample["prompt"]["toxicity"]
    return toxicity is not None and toxicity > 0.3

train_dataset = train_dataset.filter(filter_fn, batched=False)
```

### Reward function

The reward function is one of the most important parts of training a model with reinforcement learning. It is the function that tells the model whether it is doing well or not.
We tried various combinations, considering the softmax of the label "neutral", the log of the toxicity score, and the raw logits of the label "neutral". We found that convergence was much smoother with the raw logits of the "neutral" label.

```python
logits = toxicity_model(**toxicity_inputs).logits.float()
rewards = (logits[:, 0]).tolist()
```
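
The snippet above assumes the generated texts have already been tokenized for the classifier. Here is a hedged sketch of that missing step, reusing the illustrative `toxicity_tokenizer` and `toxicity_model` names from the loading sketch earlier (not necessarily the names used in the original script):

```python
import torch

# `texts` is the list of generated continuations to score (optionally prompt + continuation).
toxicity_inputs = toxicity_tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    logits = toxicity_model(**toxicity_inputs).logits.float()

# As in the snippet above, the raw logit of the "neutral" label (index 0) is used as the reward.
rewards = logits[:, 0].tolist()
```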

### Impact of input prompts length

We found that training the model with a short context (5 to 8 tokens) or a long context (15 to 20 tokens) does not impact convergence; however, when trained with longer prompts, the model tends to generate more toxic prompts.
As a compromise, we chose a context window of 10 to 15 tokens for training.

### How to deal with OOM issues

Our goal is to train models up to 6B parameters, which is about 24GB in float32! Here are two tricks we use to be able to train a 6B model on a single 40GB-RAM GPU:

* Use `bfloat16` precision: Simply load your model in `bfloat16` when calling `from_pretrained` and you can reduce the memory size of the model by a factor of 2:

```python
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", dtype=torch.bfloat16)
```

and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is pure `bfloat16` training, which is different from mixed-precision training. If one wants to train a model in mixed precision, they should not load the model with `dtype` but instead specify the mixed precision argument when calling `accelerate config`.

* Use shared layers: Since the PPO algorithm requires both the active and reference models to be on the same device, we decided to use shared layers to reduce the memory footprint of the model. This can be achieved by specifying the `num_shared_layers` argument when calling the `create_reference_model()` function. For example, if you want to share the first 6 layers of the model, you can do it like this:

```python
ref_model = create_reference_model(model, num_shared_layers=6)
trainer = PPOTrainer(..., ref_model=ref_model)
```

In the example above this means that the first 6 layers of the model are frozen (i.e. these layers are shared between the active model and the reference model).

* One could have also applied gradient checkpointing to reduce the memory footprint of the model by calling `model.pretrained_model.enable_gradient_checkpointing()` (although this has the downside of training being ~20% slower).

## Training the model

We kept 3 models in total, corresponding to our best checkpoints:

* [`ybelkada/gpt-neo-125m-detox`](https://huggingface.co/ybelkada/gpt-neo-125m-detox)
* [`ybelkada/gpt-neo-2.7B-detox`](https://huggingface.co/ybelkada/gpt-neo-2.7B-detox)
* [`ybelkada/gpt-j-6b-detox`](https://huggingface.co/ybelkada/gpt-j-6b-detox)

We used a different learning rate for each model, and found that the largest models were quite hard to train and can easily collapse if the learning rate is not chosen correctly (i.e. if it is too high):

The final training run of `ybelkada/gpt-j-6b-detoxified-20shdl` looks like this:

As you can see the model converges nicely, but obviously we don't observe a very large improvement from the first step, as the original model is not trained to generate toxic content.

We also observed that training with a larger `mini_batch_size` leads to smoother convergence and better results on the test set:

## Results

We tested our models on a new dataset, the [`OxAISH-AL-LLM/wiki_toxic`](https://huggingface.co/datasets/OxAISH-AL-LLM/wiki_toxic) dataset. We fed each model a toxic prompt from it (a sample with the label "toxic"), generated 30 new tokens as in the training loop, and measured the toxicity score using `evaluate`'s [`toxicity` metric](https://huggingface.co/spaces/ybelkada/toxicity).
We sampled 400 examples, computed the mean and standard deviation of their toxicity scores, and report the results in the table below:

| Model | Mean toxicity score | Std toxicity score |
| --- | --- | --- |
| `EleutherAI/gpt-neo-125m` | 0.1627 | 0.2997 |
| `ybelkada/gpt-neo-125m-detox` | **0.1148** | **0.2506** |
| | | |
| `EleutherAI/gpt-neo-2.7B` | 0.1884 | 0.3178 |
| `ybelkada/gpt-neo-2.7B-detox` | **0.0916** | **0.2104** |
| | | |
| `EleutherAI/gpt-j-6B` | 0.1699 | 0.3033 |
| `ybelkada/gpt-j-6b-detox` | **0.1510** | **0.2798** |

<div class="column" style="text-align:center">
<figure>
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-final-barplot.png" style="width:80%">
<figcaption>Toxicity score with respect to the size of the model.</figcaption>
</figure>
</div>

Below are a few generation examples of the `gpt-j-6b-detox` model:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-toxicity-examples.png">
</div>

The evaluation script can be found in [`examples/research_projects/toxicity/scripts/evaluate-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py).

### Discussions

The results are quite promising, as we can see that the models are able to reduce the toxicity score of the generated text by an interesting margin. The gap is clear for the `gpt-neo-2.7B` model, but less so for the `gpt-j-6B` model. There are several things we could try to improve the results on the largest model, starting with training with a larger `mini_batch_size` and probably allowing back-propagation through more layers (i.e. using fewer shared layers).

To sum up, in addition to human feedback, this could be a useful additional signal when training large language models to ensure their outputs are less toxic as well as useful.

### Limitations

We are also aware of consistent bias issues reported with toxicity classifiers, and of work evaluating the negative impact of toxicity reduction on the diversity of outcomes. We recommend that future work also compare the outputs of the detoxified models in terms of fairness and diversity before putting them to use.

## What is next?

You can download the model and use it out of the box with `transformers`, or play with the Space that compares the outputs of the models before and after detoxification: [ybelkada/detoxified-lms](https://huggingface.co/spaces/ybelkada/detoxified-lms).
@@ -70,8 +70,6 @@ Here are also some easier-to-run colab notebooks that you can use to get started
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example in a Jupyter notebook. |
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example in a Jupyter notebook. |

We also have some other examples that are less maintained but can be used as a reference in [research_projects](https://github.com/huggingface/trl/tree/main/examples/research_projects). Check out this folder to find the scripts used for some research projects that used TRL (LM de-toxification, Stack-Llama, etc.)

## Distributed training

All the scripts can be run on multiple GPUs by providing the path of an 🤗 Accelerate config file when calling `accelerate launch`. To launch one of them on one or multiple GPUs, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine and `--all_arguments_of_the_script` with your arguments).
@@ -1,65 +0,0 @@
# Training FAQ

## What Metrics Should I Look at?

When performing classical supervised fine-tuning of language models, the loss (especially the validation loss) serves as a good indicator of the training progress. However, in Reinforcement Learning (RL), the loss becomes less informative about the model's performance, and its value may fluctuate while the actual performance improves.

To address this, we recommend focusing on two key metrics first:

**Mean Reward**: The primary goal is to maximize the reward achieved by the model during RL training.
**Objective KL Divergence**: KL divergence (Kullback-Leibler divergence) measures the dissimilarity between two probability distributions. In the context of RL training, we use it to quantify the difference between the current model and a reference model. Ideally, we want to keep the KL divergence between 0 and 10 to ensure the model's generated text remains close to what the reference model produces.

However, there are more metrics that can be useful for debugging; check out the [logging section](logging).

## Why Do We Use a Reference Model, and What's the Purpose of KL Divergence?

When training RL models, optimizing solely for reward may lead to unexpected behaviors, where the model exploits the environment in ways that don't align with good language generation. In the case of RLHF, we use a reward model trained to predict whether a generated text is highly ranked by humans.

However, the RL model being optimized against the reward model may learn patterns that yield high reward but do not represent good language. This can result in extreme cases where the model generates texts with excessive exclamation marks or emojis to maximize the reward. In some worst-case scenarios, the model may generate patterns completely unrelated to natural language yet receive high rewards, similar to adversarial attacks.

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kl-example.png">
<p style="text-align: center;"> <b>Figure:</b> Samples without a KL penalty from <a href="https://huggingface.co/papers/1909.08593">https://huggingface.co/papers/1909.08593</a>. </p>
</div>

To address this issue, we add a penalty to the reward function based on the KL divergence between the current model and the reference model. By doing this, we encourage the model to stay close to what the reference model generates.

## What Is the Concern with Negative KL Divergence?

If you generate text by purely sampling from the model distribution, things work fine in general. But when you use the `generate` method there are a few caveats: depending on the settings, it does not always sample purely, which can cause the KL divergence to go negative. Essentially, when the active model achieves `log_p_token_active < log_p_token_ref`, we get negative KL divergence. This can happen in several cases:

- **top-k sampling**: the model can smooth out the probability distribution, causing the top-k tokens to have a smaller probability than those of the reference model while still being selected
- **min_length**: this ignores the EOS token until `min_length` is reached. Thus the model can assign a very low log probability to the EOS token and very high probabilities to all other tokens until `min_length` is reached

These are just a few examples. Why is negative KL an issue? The total reward `R` is computed as `R = r - beta * KL`, so if the model can learn how to drive the KL divergence negative, it effectively gets a positive reward. In many cases it can be much easier to exploit such a bug in the generation than to actually learn the reward function. In addition, the KL can become arbitrarily small, so the actual reward can be very small compared to it.
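
To make the formula concrete, here is a small illustrative sketch of how a per-token KL penalty enters the total reward. This is a toy computation with made-up numbers, not TRL's internal implementation.

```python
import torch

beta = 0.1
logp_active = torch.tensor([-1.2, -0.8, -2.0])  # log-probs of the sampled tokens under the active model
logp_ref = torch.tensor([-1.0, -1.1, -1.5])     # log-probs of the same tokens under the reference model

kl = logp_active - logp_ref      # per-token KL estimate; can go negative in the cases listed above
non_score_reward = -beta * kl    # the KL penalty term of R = r - beta * KL
rewards = non_score_reward.clone()
rewards[-1] += 0.7               # the scalar score r from the reward model, added on the final token
```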

So how should you generate text for PPO training? Let's have a look!

## How to generate text for training?

In order to avoid the KL issues described above, we recommend using the following settings:

```python
generation_kwargs = {
    "min_length": -1, # don't ignore the EOS token (see above)
    "top_k": 0.0, # no top-k sampling
    "top_p": 1.0, # no nucleus sampling
    "do_sample": True, # yes, we want to sample
    "pad_token_id": tokenizer.eos_token_id, # most decoder models don't have a padding token - use EOS token instead
    "max_new_tokens": 32, # specify how many tokens you want to generate at most
}
```
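
As an illustrative usage sketch (assuming `model` and `tokenizer` are a Hugging Face causal LM and its tokenizer already loaded in your script; the prompt is made up):

```python
query_tensor = tokenizer("The movie was", return_tensors="pt").input_ids
response_tensor = model.generate(query_tensor, **generation_kwargs)
print(tokenizer.decode(response_tensor[0], skip_special_tokens=True))
```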

With these settings we usually don't encounter any issues. You can also experiment with other settings, but if you encounter issues with negative KL divergence, try going back to these and see if they persist.

## How can you debug your own use-case?

Debugging the RL pipeline can be challenging due to its complexity. Here are some tips and suggestions to make the process easier:

- **Start from a working example**: Begin with a working example from the trl repository and gradually modify it to fit your specific use-case. Changing everything at once can make it difficult to identify the source of potential issues. For example, you can start by replacing the model in the example and, once you figure out the best hyperparameters, try switching to your dataset and reward model. If you change everything at once you won't know where a potential problem comes from.
- **Start small, scale later**: Training large models can be very slow and take several hours or days until you see any improvement. For debugging this is not a convenient timescale, so try to use small model variants during the development phase and scale up once that works. That being said, you sometimes have to be careful, as small models might not have the capacity to solve a complicated task either.
- **Start simple**: Try to start with a minimal example and build complexity from there. Your use-case might require, for example, a complicated reward function consisting of many different rewards: try to use one signal first, see if you can optimize that, and then add more complexity after that.
- **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a bug in your post-processing or your prompt. Due to bad settings you might cut off generations too soon. These things are very hard to see in the metrics but very obvious if you look at the generations.
- **Inspect the reward model**: If your reward is not improving over time, maybe there's an issue with the reward model. You can look at extreme cases to see if it does what it should: e.g. in the sentiment case you can check if simple positive and negative examples really get different rewards. And you can look at the distribution of your dataset. Finally, maybe the reward is dominated by the query, which the model can't affect, so you might need to normalize it (e.g. reward of query+response minus reward of the query).

These are just a few tips that we find helpful - if you have more useful tricks feel free to open a PR to add them as well!
@@ -13,7 +13,7 @@ pip install trl[judges]

## Using the provided judges

TRL provides several judges out of the box. For example, you can use the `HfPairwiseJudge` to compare two completions using a pre-trained model from the Hugging Face model hub:
TRL provides several judges out of the box. For example, you can use the [`HfPairwiseJudge`] to compare two completions using a pre-trained model from the Hugging Face model hub:

```python
from trl import HfPairwiseJudge
@@ -1,106 +0,0 @@
# Logging

As reinforcement learning algorithms are historically challenging to debug, it's important to pay careful attention to logging.
By default, TRL trainers like [`PPOTrainer`] and [`GRPOTrainer`] save a lot of relevant information to supported experiment trackers like Trackio, Weights & Biases (wandb) or TensorBoard.

Upon initialization, pass the `report_to` argument to the respective configuration object (e.g., [`PPOConfig`] for `PPOTrainer`, or [`GRPOConfig`] for `GRPOTrainer`):

```python
# For PPOTrainer
ppo_config = PPOConfig(
    # ...,
    report_to="trackio"  # or "wandb" or "tensorboard"
)

# For GRPOTrainer
grpo_config = GRPOConfig(
    # ...,
    report_to="trackio"  # or "wandb" or "tensorboard"
)
```

If you want to log with TensorBoard, you might also need to specify logging directories, for example, by adding `logging_dir=PATH_TO_LOGS` to the configuration object (e.g., `PPOConfig` or `GRPOConfig`).
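
For example, a minimal sketch (assuming `logging_dir` is accepted here because these config classes extend `transformers.TrainingArguments`, which defines it):

```python
from trl import PPOConfig

ppo_config = PPOConfig(
    # ...,
    report_to="tensorboard",
    logging_dir="./runs",  # your PATH_TO_LOGS
)
```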

## PPO Logging

Here's a brief explanation for the logged metrics provided in the data:

* `eps`: Tracks the number of episodes per second.
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
* `objective/scores`: The mean scores returned by the reward model / environment.
* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
* `loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward.
* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to `policy/clipfrac_avg` but for the value function.
* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
* `lr`: The current learning rate used by the optimizer.
* `episode`: The current episode count in the training process.

### Crucial values

During training, many values are logged; here are the most important ones:

1. `objective/scores`: The mean scores returned by the reward model / environment.
1. `objective/rlhf_reward`: The mean RLHF reward. This is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
1. `objective/non_score_reward`: The mean reward from non-score-related sources (e.g., KL penalty).

Here are some parameters that are useful to monitor for stability (when these diverge or collapse to 0, try tuning variables):

1. `loss/value_avg`: The average value loss. It will spike / NaN when not going well.
1. `val/ratio`: The mean ratio of the current policy probability to the old policy probability. This number should float around 1.0. If this `ratio` is too high (e.g., 2.0 or 1000.0) or too small (e.g., 0.1), it means the updates between consecutive policies are too drastic.
1. `policy/clipfrac_avg` and `policy/approxkl_avg`: If `val/ratio` is too high, the `ratio` is going to get clipped, resulting in high `policy/clipfrac_avg` and high `policy/approxkl_avg` as well.
1. `objective/kl`: The mean KL divergence. It should stay positive and ideally not too large, so that the policy is not too far away from the reference policy.

## GRPO Logging

Here's a brief explanation for the logged metrics provided in the data for the GRPO trainer:

* `num_tokens`: Total number of input tokens processed during training so far.

### Completions

* `completions/mean_length`: Mean length of all generated completions (including those not ending with an EOS token).
* `completions/min_length`: Minimum length among all generated completions.
* `completions/max_length`: Maximum length among all generated completions.
* `completions/clipped_ratio`: The ratio of completions that did not end with an EOS token before reaching the maximum generation length (i.e., they were truncated).
* `completions/mean_terminated_length`: Mean length of only those completions that successfully ended with an EOS token.
* `completions/min_terminated_length`: Minimum length among completions that ended with an EOS token.
* `completions/max_terminated_length`: Maximum length among completions that ended with an EOS token.

### Rewards

* `rewards/{reward_func_name}/mean`: The mean reward obtained from a specific, named reward function (e.g., `rewards/my_custom_reward/mean`). This is logged for each reward function used.
* `rewards/{reward_func_name}/std`: The standard deviation of rewards from a specific, named reward function.
* `reward`: The overall mean of the (potentially weighted and, if `args.scale_rewards` is true, normalized) rewards, after group-wise normalization (advantages).
* `reward_std`: The standard deviation of the (potentially weighted) rewards *before* group-wise normalization for advantages.

### Policy and Loss Metrics

* `kl`: The mean Kullback-Leibler (KL) divergence between the current policy and the reference policy. This is logged only if `beta` (the KL coefficient in `GRPOConfig`) is non-zero.
* `entropy`: Average entropy of token predictions across generated completions.
* If Liger GRPOLoss is used (`use_liger_loss: True` in `GRPOConfig`):
  * `clip_ratio`: The fraction of policy updates where the probability ratio was clipped according to the GRPO loss's epsilon bounds.
* If standard GRPOLoss is used (`use_liger_loss: False`):
  * `clip_ratio/low_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the lower bound `1 - epsilon_low` (occurs when advantage is negative and ratio is below the bound).
  * `clip_ratio/low_min`: The minimum observed fraction for `clip_ratio/low_mean` across batches/processes.
  * `clip_ratio/high_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the upper bound `1 + epsilon_high` (occurs when advantage is positive and ratio is above the bound).
  * `clip_ratio/high_max`: The maximum observed fraction for `clip_ratio/high_mean` across batches/processes.
  * `clip_ratio/region_mean`: The mean fraction of instances where the probability ratio was clipped at either the lower or upper bound.

### Crucial GRPO values

During GRPO training, monitor these values for insights into performance and stability:

* `reward`: This is the primary objective. It reflects the (group-wise normalized) rewards the policy is achieving. It should generally increase during successful training.
* `kl`: If `beta > 0`, this tracks the divergence from the reference model. Keep an eye on it to ensure the policy doesn't stray too far, which can lead to instability.
* `clip_ratio/*` (either `clip_ratio` for Liger loss or the more detailed `clip_ratio/...` metrics for standard loss): These indicate how often the policy updates are being constrained by the GRPO clipping mechanism. Very high values might suggest that the policy is trying to change too drastically (potentially due to large advantages or a learning rate that's too high) or that the epsilon clipping range is too restrictive.
* `completions/clipped_ratio`: A high ratio here indicates that the model is frequently generating completions that are cut off by `max_completion_length` rather than naturally ending with an EOS token. This might suggest issues with learning sequence termination or that `max_completion_length` is too short.
* `rewards/{reward_func_name}/mean`: Monitoring the mean of individual reward functions can help diagnose which aspects of the desired behavior the model is learning or struggling with, especially when using multiple reward sources.
* `entropy`: Measures how uncertain the policy is in its action choices; higher entropy suggests more exploration. A collapse in entropy means the policy is becoming overconfident and deterministic, often too early. This can stall learning by reducing exploration and making updates overly biased. Stable but non-zero entropy is usually a sign that the policy retains flexibility and continues to explore.
@@ -90,7 +90,7 @@ model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_name,
    peft_config=lora_config,
    reward_adapter=rm_adapter_id,
    load_in_8bit=True,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

...
@@ -338,7 +338,7 @@ training_args = DPOConfig(
)
```

For the unpaired version, the user should utilize `BCOConfig` and `BCOTrainer`.
For the unpaired version, the user should utilize [`BCOConfig`] and [`BCOTrainer`].

### Self-Play Preference Optimization for Language Model Alignment
@@ -3,14 +3,6 @@
The notebooks and scripts in these examples show how to use Low Rank Adaptation (LoRA) to fine-tune models in a memory-efficient manner. Most PEFT methods from the peft library are supported, but note that some PEFT methods, such as prompt tuning, are not.
For more information on LoRA, see the [original paper](https://huggingface.co/papers/2106.09685).

Here's an overview of the `peft`-enabled notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples):

| File | Task | Description | Colab link |
| --- | --- | --- | --- |
| [`stack_llama/rl_training.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py) | RLHF | Distributed fine-tuning of the 7b parameter LLaMA models with a learned reward model and `peft`. | |
| [`stack_llama/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py) | Reward Modeling | Distributed training of the 7b parameter LLaMA reward model with `peft`. | |
| [`stack_llama/supervised_finetuning.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py) | SFT | Distributed instruction/supervised fine-tuning of the 7b parameter LLaMA model with `peft`. | |

## Installation

Note: peft is in active development, so we install directly from their GitHub repository.
@@ -28,7 +20,7 @@ Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scr

## How to use it?

Simply declare a `PeftConfig` object in your script and pass it through `.from_pretrained` to load the TRL+PEFT model.
Simply declare a [`~peft.PeftConfig`] object in your script and pass it through `.from_pretrained` to load the TRL+PEFT model.

```python
from peft import LoraConfig
@@ -91,7 +91,6 @@ trl reward --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
- [SFT Trainer](sft_trainer) - Complete SFT guide
- [DPO Trainer](dpo_trainer) - Preference alignment
- [GRPO Trainer](grpo_trainer) - Group relative policy optimization
- [Training FAQ](how_to_train) - Common questions

### 🚀 Scale Up

@@ -141,4 +140,4 @@ Try adjusting the learning rate:
training_args = SFTConfig(learning_rate=2e-5)  # Good starting point
```

For more help, see our [Training FAQ](how_to_train) or open an [issue on GitHub](https://github.com/huggingface/trl/issues).
For more help, open an [issue on GitHub](https://github.com/huggingface/trl/issues).
@@ -77,7 +77,7 @@ Packing, introduced in [Raffel et al., 2020](https://huggingface.co/papers/1910.
Packing reduces padding by merging several sequences in one row when possible. We use an advanced method to be near-optimal in the way we pack the dataset. To enable packing, use `packing=True` in the [`SFTConfig`].

> [!TIP]
> In TRL 0.18 and earlier, packing used a more aggressive method that reduced padding to almost nothing, but had the downside of breaking sequence continuity for a large fraction of the dataset. To revert to this strategy, use `packing_strategy="wrapped"` in `SFTConfig`.
> In TRL 0.18 and earlier, packing used a more aggressive method that reduced padding to almost nothing, but had the downside of breaking sequence continuity for a large fraction of the dataset. To revert to this strategy, use `packing_strategy="wrapped"` in [`SFTConfig`].

```python
from trl import SFTConfig
@@ -2,14 +2,14 @@

This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`] and [`RLOOTrainer`].

## Format rewards
## accuracy_reward

### think_format_reward
[[autodoc]] rewards.accuracy_reward

## think_format_reward

[[autodoc]] rewards.think_format_reward

## Other rewards

### get_soft_overlong_punishment
## get_soft_overlong_punishment

[[autodoc]] rewards.get_soft_overlong_punishment
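
For orientation, here is a hedged usage sketch of passing one of these reward functions to [`GRPOTrainer`]. The dataset and model names are illustrative choices, not prescribed by this page, and the exact signatures live in the autodoc entries above.

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
from trl.rewards import think_format_reward

# Illustrative dataset and model; swap in your own.
dataset = load_dataset("trl-lib/tldr", split="train")

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=think_format_reward,  # reward functions from this module are plain callables
    args=GRPOConfig(output_dir="Qwen2.5-0.5B-GRPO"),
    train_dataset=dataset,
)
trainer.train()
```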

@@ -1,159 +0,0 @@
# Using LLaMA models with TRL

We've begun rolling out examples to use Meta's LLaMA models in `trl` (see [Meta's LLaMA release](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) for the original LLaMA model).

## Efficient training strategies

Even training the smallest LLaMA model requires an enormous amount of memory. Some quick math: in bf16, every parameter uses 2 bytes (in fp32, 4 bytes) in addition to 8 bytes used, e.g., in the Adam optimizer (see the [performance docs](https://huggingface.co/docs/transformers/perf_train_gpu_one#optimizer) in Transformers for more info). So a 7B parameter model would use `(2+8)*7B=70GB` just to fit in memory and would likely need more when you compute intermediate values such as attention scores. So you couldn't train the model even on a single 80GB A100 like that. You can use some tricks, like more efficient optimizers or half-precision training, to squeeze a bit more into memory, but you'll run out sooner or later.
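
The arithmetic above, spelled out as a quick illustrative calculation:

```python
# Rough memory needed just for weights + Adam state when fully fine-tuning a 7B model,
# using the per-parameter costs quoted above (2 bytes for bf16 weights + 8 bytes for Adam).
n_params = 7e9
bytes_per_param = 2 + 8
print(f"{n_params * bytes_per_param / 1e9:.0f} GB")  # ~70 GB, before activations such as attention scores
```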

Another option is to use Parameter-Efficient Fine-Tuning (PEFT) techniques, such as the [`peft`](https://github.com/huggingface/peft) library, which can perform low-rank adaptation (LoRA) on a model loaded in 8-bit.
For more on `peft` + `trl`, see the [Peft integration](peft_integration) docs.

Loading the model in 8-bit reduces the memory footprint drastically since you only need one byte per parameter for the weights (e.g. a 7B LLaMA model takes 7GB in memory).
Instead of training the original weights directly, LoRA adds small adapter layers on top of some specific layers (usually the attention layers); thus, the number of trainable parameters is drastically reduced.

In this scenario, a rule of thumb is to allocate ~1.2-1.4GB per billion parameters (depending on the batch size and sequence length) to fit the entire fine-tuning setup.
This enables fine-tuning larger models (up to 50-60B scale models on an NVIDIA A100 80GB) at low cost.

Now we can fit very large models into a single GPU, but the training might still be very slow.
The simplest strategy in this scenario is data parallelism: we replicate the same training setup onto separate GPUs and pass different batches to each GPU.
With this, you can parallelize the forward/backward passes of the model and scale with the number of GPUs.

We use either the `transformers.Trainer` or `accelerate`, which both support data parallelism without any code changes, by simply passing arguments when calling the scripts with `torchrun` or `accelerate launch`. The following runs a training script with 8 GPUs on a single machine with `accelerate` and `torchrun`, respectively.

```bash
accelerate launch --multi_gpu --num_machines 1 --num_processes 8 my_accelerate_script.py
torchrun --nnodes 1 --nproc_per_node 8 my_torch_script.py
```

## Supervised fine-tuning

Before we start training reward models and tuning our model with RL, it helps if the model is already good in the domain we are interested in.
In our case, we want it to answer questions, while for other use cases, we might want it to follow instructions, in which case instruction tuning is a great idea.
The easiest way to achieve this is by continuing to train the language model with the language modeling objective on texts from the domain or task.
The [StackExchange dataset](https://huggingface.co/datasets/HuggingFaceH4/stack-exchange-preferences) is enormous (over 10 million instructions), so we can easily train the language model on a subset of it.

There is nothing special about fine-tuning the model before doing RLHF - it's just the causal language modeling objective from pretraining that we apply here.
To use the data efficiently, we use a technique called packing: instead of having one text per sample in the batch and then padding to either the longest text or the maximal context of the model, we concatenate a lot of texts with an EOS token in between and cut chunks of the context size to fill the batch without any padding.

With this approach the training is much more efficient, as each token that is passed through the model is also trained, in contrast to padding tokens, which are usually masked from the loss.
If you don't have much data and are more concerned about occasionally cutting off some tokens that are overflowing the context, you can also use a classical data loader.
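
A minimal sketch of the packing idea just described (illustrative only, not the exact implementation used in the example scripts): tokenize the texts, join them with an EOS token, and slice the resulting stream into fixed-size chunks.

```python
def pack_dataset(texts, tokenizer, chunk_size=1024):
    """Concatenate tokenized texts with EOS in between and cut equal-length chunks (no padding)."""
    stream = []
    for text in texts:
        stream.extend(tokenizer(text)["input_ids"])
        stream.append(tokenizer.eos_token_id)
    # drop the tail that doesn't fill a complete chunk
    n_chunks = len(stream) // chunk_size
    return [stream[i * chunk_size : (i + 1) * chunk_size] for i in range(n_chunks)]
```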
|
||||
|
||||
```python
|
||||
# load model in 8bit
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_path,
|
||||
load_in_8bit=True,
|
||||
device_map={"": Accelerator().local_process_index}
|
||||
)
|
||||
model = prepare_model_for_kbit_training(model)
|
||||
|
||||
# add LoRA to model
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
model = get_peft_model(model, config)
|
||||
```
|
||||
|
||||
We train the model for a few thousand steps with the causal language modeling objective and save the model.
|
||||
Since we will tune the model again with different objectives, we merge the adapter weights with the original model weights.
|
||||
|
||||
**Disclaimer:** due to LLaMA's license, we release only the adapter weights for this and the model checkpoints in the following sections.
|
||||
You can apply for access to the base model's weights by filling out Meta AI's [form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform) and then converting them to the 🤗 Transformers format by running this [script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py).
|
||||
Note that you'll also need to install 🤗 Transformers from source until the `v4.28` is released.
|
||||
|
||||
Now that we have fine-tuned the model for the task, we are ready to train a reward model.
|
||||
|
||||
## Reward modeling and human preferences
|
||||
|
||||
In principle, we could fine-tune the model using RLHF directly with the human annotations.
|
||||
However, this would require us to send some samples to humans for rating after each optimization iteration.
|
||||
This is expensive and slow due to the number of training samples needed for convergence and the inherent latency of human reading and annotator speed.
|
||||
|
||||
A trick that works well instead of direct feedback is training a reward model on human annotations collected before the RL loop.
|
||||
The goal of the reward model is to imitate how a human would rate a text. There are several possible strategies to build a reward model: the most straightforward way would be to predict the annotation (e.g. a rating score or a binary value for “good”/”bad”).
|
||||
In practice, what works better is to predict the ranking of two examples, where the reward model is presented with two candidates `(y_k, y_j)` for a given prompt `x` and has to predict which one would be rated higher by a human annotator.
|
||||
|
||||
With the StackExchange dataset, we can infer which of the two answers was preferred by the users based on the score.
|
||||
With that information and the loss defined above, we can then modify the `transformers.Trainer` by adding a custom loss function.
|
||||
|
||||
```python
|
||||
class RewardTrainer(Trainer):
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0]
|
||||
rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0]
|
||||
loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean()
|
||||
if return_outputs:
|
||||
return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k}
|
||||
return loss
|
||||
```
|
||||
|
||||
We utilize a subset of a 100,000 pair of candidates and evaluate on a held-out set of 50,000. With a modest training batch size of 4, we train the Llama model using the LoRA `peft` adapter for a single epoch using the Adam optimizer with BF16 precision. Our LoRA configuration is:
|
||||
|
||||
```python
|
||||
peft_config = LoraConfig(
|
||||
task_type=TaskType.SEQ_CLS,
|
||||
inference_mode=False,
|
||||
r=8,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.1,
|
||||
)
|
||||
```
|
||||
|
||||
As detailed in the next section, the resulting adapter can be merged into the frozen model and saved for further downstream use.
|
||||
|
||||
## Reinforcement Learning from Human Feedback
|
||||
|
||||
With the fine-tuned language model and the reward model at hand, we are now ready to run the RL loop. It follows roughly three steps:
|
||||
|
||||
1. Generate responses from prompts,
|
||||
2. Rate the responses with the reward model,
|
||||
3. Run a reinforcement learning policy-optimization step with the ratings.
|
||||
|
||||
The Query and Response prompts are templated as follows before being tokenized and passed to the model:
|
||||
|
||||
```bash
|
||||
Question: <Query>
|
||||
|
||||
Answer: <Response>
|
||||
```
|
||||
|
||||
The same template was used for SFT, RM and RLHF stages.
|
||||
Once more, we utilize `peft` for memory-efficient training, which offers an extra advantage in the RLHF context.
|
||||
Here, the reference model and policy share the same base, the SFT model, which we load in 8-bit and freeze during training.
|
||||
We exclusively optimize the policy's LoRA weights using PPO while sharing the base model's weights.
|
||||
|
||||
```python
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    question_tensors = batch["input_ids"]

    # sample from the policy to generate responses
    response_tensors = ppo_trainer.generate(
        question_tensors,
        return_prompt=False,
        length_sampler=output_length_sampler,
        **generation_kwargs,
    )
    batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

    # Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs]

    # Run PPO step
    stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
    # Log stats to Wandb
    ppo_trainer.log_stats(stats, batch, rewards)
```
For the rest of the details and evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama).
694
examples/notebooks/grpo_qwen3_vl.ipynb
Normal file
@ -0,0 +1,694 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "-J8iGzLf4rUJ"
|
||||
},
|
||||
"source": [
|
||||
"# GRPO Qwen3-VL with QLoRA using TRL\n",
|
||||
"\n",
|
||||
"[](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_qwen3_vl.ipynb)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"With [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl), you can fine-tune cutting edge vision language models. It comes with support for quantized parameter efficient fine-tuning technique **QLoRA**, so we can use free Colab (T4 GPU) to fine-tune models like [Qwen3-VL](https://huggingface.co/collections/Qwen/qwen3-vl-68d2a7c1b8a8afce4ebd2dbe).\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! \n",
|
||||
"- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n",
|
||||
"- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)\n",
|
||||
"- [More Qwen3-VL Fine-tuning Examples (including TRL scripts)](https://github.com/QwenLM/Qwen3-VL/tree/main/qwen-vl-finetune/)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "NvrzGRnu48Vz"
|
||||
},
|
||||
"source": [
|
||||
"## Install dependencies\n",
|
||||
"\n",
|
||||
"We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "8CfZlUevmkg7"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -Uq \"trl[peft]\" bitsandbytes trackio math_verify"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "gpzI6omi7728"
|
||||
},
|
||||
"source": [
|
||||
"### Log in to Hugging Face\n",
|
||||
"\n",
|
||||
"Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "4Ncx0wYtnYCW"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from huggingface_hub import notebook_login\n",
|
||||
"\n",
|
||||
"notebook_login()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "V_Zylc4t79-n"
|
||||
},
|
||||
"source": [
|
||||
"## Load dataset\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"We'll load the [**lmms-lab/multimodal-open-r1-8k-verified**](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset from the Hugging Face Hub using the `datasets` library.\n",
|
||||
"\n",
|
||||
"This dataset contains maths problems with the image representing the problem, along with the solution in thinking format specially tailored for VLMs. By training our model with this dataset, it'll improve its maths and thinking reasoning.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "TzXogU24F_QR"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"\n",
|
||||
"dataset_id = 'lmms-lab/multimodal-open-r1-8k-verified'\n",
|
||||
"train_dataset = load_dataset(dataset_id, split='train[:5%]')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "gVV7RoRN8zk5"
|
||||
},
|
||||
"source": [
|
||||
"In addition to the `problem` and `image` columns, we also include a custom system prompt to tell the model how we'd like the generation.\n",
|
||||
"\n",
|
||||
"The system prompt is extracted from DeepSeek R1. Refer to [this previous recipe](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl) for more details.\n",
|
||||
"\n",
|
||||
"We convert the dataset samples into conversation samples, including the system prompt and one image and problem description per sample, since this is how the GRPO trainer expects them.\n",
|
||||
"\n",
|
||||
"We also set `padding_side=\"left\"` to ensure that generated completions during training are concatenated directly after the prompt, which is essential for GRPO to correctly compare token-level probabilities between preferred and rejected responses."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ZT1JfiiTGExB"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoProcessor\n",
|
||||
"\n",
|
||||
"model_name = \"Qwen/Qwen3-VL-4B-Instruct\" # \"Qwen/Qwen3-VL-8B-Instruct\"\n",
|
||||
"processor = AutoProcessor.from_pretrained(model_name, padding_side=\"left\")\n",
|
||||
"\n",
|
||||
"SYSTEM_PROMPT = (\n",
|
||||
" \"You are a helpful AI Assistant that provides well-reasoned and detailed responses. \"\n",
|
||||
" \"You first think about the reasoning process as an internal monologue and then provide the user with the answer. \"\n",
|
||||
" \"Respond in the following format: <think>\\n...\\n</think>\\n<answer>\\n...\\n</answer>\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def make_conversation(example):\n",
|
||||
" conversation = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"system\",\n",
|
||||
" \"content\": [{\"type\": \"text\", \"text\": SYSTEM_PROMPT}],\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\"type\": \"image\", \"image\": example[\"image\"]},\n",
|
||||
" {\"type\": \"text\", \"text\": example[\"problem\"]},\n",
|
||||
" ],\n",
|
||||
" },\n",
|
||||
" ]\n",
|
||||
" prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)\n",
|
||||
" return {\n",
|
||||
" \"prompt\": prompt,\n",
|
||||
" \"image\": example[\"image\"],\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"train_dataset = train_dataset.map(make_conversation)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "5txAuMAa8ock"
|
||||
},
|
||||
"source": [
|
||||
"Let's review one example to understand the internal structure:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "PDXQd5Jk2Bqe"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "hzSR_56wxKDA"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset = train_dataset.remove_columns(['problem', 'original_question', 'original_answer'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "T9rCkeqDODba"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "YY3uMp909Eqy"
|
||||
},
|
||||
"source": [
|
||||
"## Load model and configure LoRA/QLoRA\n",
|
||||
"\n",
|
||||
"This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "gt05dgXgm9QR"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import Qwen3VLForConditionalGeneration, BitsAndBytesConfig\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"model = Qwen3VLForConditionalGeneration.from_pretrained(\n",
|
||||
" model_name, dtype=\"auto\",\n",
|
||||
" device_map=\"auto\",\n",
|
||||
" quantization_config=BitsAndBytesConfig(\n",
|
||||
" load_in_4bit=True,\n",
|
||||
" bnb_4bit_use_double_quant=True,\n",
|
||||
" bnb_4bit_quant_type=\"nf4\",\n",
|
||||
" bnb_4bit_compute_dtype=torch.float16\n",
|
||||
" ),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "WZGf-GF09Gsc"
|
||||
},
|
||||
"source": [
|
||||
"The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter** — a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ME1im5gh2LFg"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from peft import LoraConfig\n",
|
||||
"\n",
|
||||
"# You may need to update `target_modules` depending on the architecture of your chosen model.\n",
|
||||
"# For example, different VLMs might have different attention/projection layer names.\n",
|
||||
"peft_config = LoraConfig(\n",
|
||||
" r=8,\n",
|
||||
" lora_alpha=32,\n",
|
||||
" lora_dropout=0.1,\n",
|
||||
" target_modules=[\"q_proj\", \"v_proj\"],\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "mDq4V6dN9MGk"
|
||||
},
|
||||
"source": [
|
||||
"## Train model\n",
|
||||
"\n",
|
||||
"We'll configure **GRPO** using `GRPOConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL GRPOConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.GRPOConfig).\n",
|
||||
"\n",
|
||||
"First, we need to define the rewards functions that the training algorithm will use to improve the model. In this case, we'll include two reward functions.\n",
|
||||
"We'll use a format reward that will reward the model when the output includes `<think>` and `<answer>` tags and additionally a length-based reward to discourage overthinking. Both functions have been extracted from [here](https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "Dqp3TfUwHUxW"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import re\n",
|
||||
"\n",
|
||||
"def format_reward(completions, **kwargs):\n",
|
||||
" \"\"\"Reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags.\"\"\"\n",
|
||||
" pattern = r\"^<think>\\n.*?\\n</think>\\n<answer>\\n.*?\\n</answer>$\"\n",
|
||||
" matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions]\n",
|
||||
" return [1.0 if match else 0.0 for match in matches]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "rxNPUp7RBFcz"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from math_verify import LatexExtractionConfig, parse, verify\n",
|
||||
"from latex2sympy2_extended import NormalizationConfig\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def len_reward(completions, solution, **kwargs) -> float:\n",
|
||||
" \"\"\"Compute length-based rewards to discourage overthinking and promote token efficiency.\n",
|
||||
"\n",
|
||||
" Taken from the Kimi 1.5 tech report: https://huggingface.co/papers/2501.12599\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" completions: List of model completions\n",
|
||||
" solution: List of ground truth solutions\n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" List of rewards where:\n",
|
||||
" - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)\n",
|
||||
" - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))\n",
|
||||
" \"\"\"\n",
|
||||
" contents = completions\n",
|
||||
"\n",
|
||||
" # First check correctness of answers\n",
|
||||
" correctness = []\n",
|
||||
" for content, sol in zip(contents, solution):\n",
|
||||
" gold_parsed = parse(\n",
|
||||
" sol,\n",
|
||||
" extraction_mode=\"first_match\",\n",
|
||||
" extraction_config=[LatexExtractionConfig()],\n",
|
||||
" )\n",
|
||||
" if len(gold_parsed) == 0:\n",
|
||||
" # Skip unparseable examples\n",
|
||||
" correctness.append(True) # Treat as correct to avoid penalizing\n",
|
||||
" print(\"Failed to parse gold solution: \", sol)\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" answer_parsed = parse(\n",
|
||||
" content,\n",
|
||||
" extraction_config=[\n",
|
||||
" LatexExtractionConfig(\n",
|
||||
" normalization_config=NormalizationConfig(\n",
|
||||
" nits=False,\n",
|
||||
" malformed_operators=False,\n",
|
||||
" basic_latex=True,\n",
|
||||
" equations=True,\n",
|
||||
" boxed=True,\n",
|
||||
" units=True,\n",
|
||||
" ),\n",
|
||||
" boxed_match_priority=0,\n",
|
||||
" try_extract_without_anchor=False,\n",
|
||||
" )\n",
|
||||
" ],\n",
|
||||
" extraction_mode=\"first_match\",\n",
|
||||
" )\n",
|
||||
" correctness.append(verify(answer_parsed, gold_parsed))\n",
|
||||
"\n",
|
||||
" # Calculate lengths\n",
|
||||
" lengths = [len(content) for content in contents]\n",
|
||||
" min_len = min(lengths)\n",
|
||||
" max_len = max(lengths)\n",
|
||||
"\n",
|
||||
" # If all responses have the same length, return zero rewards\n",
|
||||
" if max_len == min_len:\n",
|
||||
" return [0.0] * len(completions)\n",
|
||||
"\n",
|
||||
" rewards = []\n",
|
||||
" for length, is_correct in zip(lengths, correctness):\n",
|
||||
" lambda_val = 0.5 - (length - min_len) / (max_len - min_len)\n",
|
||||
"\n",
|
||||
" if is_correct:\n",
|
||||
" reward = lambda_val\n",
|
||||
" else:\n",
|
||||
" reward = min(0, lambda_val)\n",
|
||||
"\n",
|
||||
" rewards.append(float(reward))\n",
|
||||
"\n",
|
||||
" return rewards\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "9xBL7Rni9LZb"
|
||||
},
|
||||
"source": [
|
||||
"After defining the reward function(s), we can define the `GRPOConfig`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "OEmRM0rIHXQ4"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from trl import GRPOConfig\n",
|
||||
"\n",
|
||||
"output_dir = \"Qwen3-VL-4B-Instruct-trl-grpo\"\n",
|
||||
"\n",
|
||||
"# Configure training arguments using GRPOConfig\n",
|
||||
"training_args = GRPOConfig(\n",
|
||||
" learning_rate=2e-5,\n",
|
||||
" #num_train_epochs=1,\n",
|
||||
" max_steps=100, # Number of dataset passes. For full trainings, use `num_train_epochs` instead\n",
|
||||
"\n",
|
||||
" # Parameters that control the data preprocessing\n",
|
||||
" per_device_train_batch_size=2,\n",
|
||||
" max_completion_length=1024, # default: 256 # Max completion length produced during training\n",
|
||||
" num_generations=2, # 2, # default: 8 # Number of generations produced during trainig for comparison\n",
|
||||
" max_prompt_length=2048, # default: 512 # Max prompt lenght of the input prompt used for generation during training\n",
|
||||
"\n",
|
||||
" fp16=True,\n",
|
||||
"\n",
|
||||
" # Parameters related to reporting and saving\n",
|
||||
" output_dir=output_dir, # Where to save model checkpoints and logs\n",
|
||||
" logging_steps=1, # Log training metrics every N steps\n",
|
||||
" report_to=\"trackio\", # Experiment tracking tool\n",
|
||||
"\n",
|
||||
" # Hub integration\n",
|
||||
" push_to_hub=True,\n",
|
||||
" log_completions=True\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "O0q3myQg927v"
|
||||
},
|
||||
"source": [
|
||||
"Configure the GRPO Trainer. We pass the previously configured `training_args`. We don't use eval dataset to maintain memory usage low but you can configure it."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "z5JxkmS9HqD5",
|
||||
"outputId": "2b39338e-2194-4829-fc54-5e286566fd28"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/usr/local/lib/python3.12/dist-packages/peft/mapping_func.py:73: UserWarning: You are trying to modify a model with PEFT for a second time. If you want to reload the model with a different config, make sure to call `.unload()` before.\n",
|
||||
" warnings.warn(\n",
|
||||
"/usr/local/lib/python3.12/dist-packages/peft/tuners/tuners_utils.py:196: UserWarning: Already found a `peft_config` attribute in the model. This will lead to having multiple adapters in the model. Make sure to know what you are doing!\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from trl import GRPOTrainer\n",
|
||||
"\n",
|
||||
"trainer = GRPOTrainer(\n",
|
||||
" model=model,\n",
|
||||
" reward_funcs=[format_reward, len_reward],\n",
|
||||
" args=training_args,\n",
|
||||
" train_dataset=train_dataset,\n",
|
||||
" peft_config=peft_config,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "kQC7Q5kg95xq"
|
||||
},
|
||||
"source": [
|
||||
"Show memory stats before training"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "naG_7qlYyBP6"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gpu_stats = torch.cuda.get_device_properties(0)\n",
|
||||
"start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
|
||||
"max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n",
|
||||
"\n",
|
||||
"print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n",
|
||||
"print(f\"{start_gpu_memory} GB of memory reserved.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "YazYtLAe97Dc"
|
||||
},
|
||||
"source": [
|
||||
"And train!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "pbJXrhA0ywra"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer_stats = trainer.train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "SmcYN5yW99IP"
|
||||
},
|
||||
"source": [
|
||||
"Show memory stats after training"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "TrrwP4ADMmrp"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
|
||||
"used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n",
|
||||
"used_percentage = round(used_memory / max_memory * 100, 3)\n",
|
||||
"lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n",
|
||||
"\n",
|
||||
"print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n",
|
||||
"print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n",
|
||||
"print(f\"Peak reserved memory = {used_memory} GB.\")\n",
|
||||
"print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n",
|
||||
"print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n",
|
||||
"print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "saarW87Y9_-R"
|
||||
},
|
||||
"source": [
|
||||
"## Saving fine tuned model\n",
|
||||
"\n",
|
||||
"In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "71A8aqEyyETA"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer.save_model(output_dir)\n",
|
||||
"trainer.push_to_hub(dataset_name=dataset_id)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "nfqvO0qw-OvS"
|
||||
},
|
||||
"source": [
|
||||
"## Load the fine-tuned model and run inference\n",
|
||||
"\n",
|
||||
"Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "R8T2uFQVyFeH"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import Qwen3VLForConditionalGeneration, AutoProcessor\n",
|
||||
"from peft import PeftModel\n",
|
||||
"\n",
|
||||
"base_model = model_name\n",
|
||||
"adapter_model = f\"{output_dir}\" # Replace with your HF username or organization\n",
|
||||
"\n",
|
||||
"model = Qwen3VLForConditionalGeneration.from_pretrained(base_model, dtype=\"auto\", device_map=\"auto\")\n",
|
||||
"model = PeftModel.from_pretrained(model, adapter_model)\n",
|
||||
"\n",
|
||||
"processor = AutoProcessor.from_pretrained(base_model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "dPBHP0CpLa6K"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "cG5-ccGRyHgo"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"\n",
|
||||
"dataset_id = 'lmms-lab/multimodal-open-r1-8k-verified'\n",
|
||||
"train_dataset = load_dataset(dataset_id, split='train[:5%]')\n",
|
||||
"\n",
|
||||
"problem = train_dataset[0]['problem']\n",
|
||||
"image = train_dataset[0]['image']\n",
|
||||
"\n",
|
||||
"messages = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"system\", \"content\": [\n",
|
||||
" {\"type\": \"text\", \"text\": SYSTEM_PROMPT}\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\"type\": \"image\", \"image\": image},\n",
|
||||
" {\"type\": \"text\", \"text\": problem},\n",
|
||||
" ],\n",
|
||||
" },\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "r_70q_8lLgfV"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"messages"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "PX92MjqlyIwB"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"inputs = processor.apply_chat_template(\n",
|
||||
" messages,\n",
|
||||
" tokenize=True,\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
" return_dict=True,\n",
|
||||
" return_tensors=\"pt\"\n",
|
||||
").to(model.device)\n",
|
||||
"\n",
|
||||
"# Inference: Generation of the output\n",
|
||||
"generated_ids = model.generate(**inputs, max_new_tokens=500)\n",
|
||||
"generated_ids_trimmed = [\n",
|
||||
" out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)\n",
|
||||
"]\n",
|
||||
"output_text = processor.batch_decode(\n",
|
||||
" generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False\n",
|
||||
")\n",
|
||||
"print(output_text)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"gpuType": "T4",
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
515
examples/notebooks/sft_qwen_vl.ipynb
Normal file
File diff suppressed because one or more lines are too long
@ -1,7 +0,0 @@
# Research projects that use TRL

Welcome to the research projects folder! Here you can find the scripts used for some research projects that use TRL and are maintained by the developers and the community (LM de-toxification, Stack-Llama, etc.). Check out the READMEs in the subfolders for more information!

- [Detoxifying language models](https://github.com/huggingface/trl/tree/main/examples/research_projects/toxicity)
- [Stack-Llama](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama)
- [Stack-Llama-2](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama_2)
@ -1,15 +0,0 @@
# LayerSkip Training Recipe

Implements the training recipe as described in the [LayerSkip paper](https://huggingface.co/papers/2404.16710).

## Run training
```
cd scripts
python layer_skip_sft.py
```

## Run benchmark
```
cd scripts
python benchmark_layer_skip.py
```
@ -1,77 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import config
|
||||
import torch
|
||||
from torch.utils import benchmark
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
def generate_tokens(model, inputs):
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
do_sample=False,
|
||||
max_new_tokens=64,
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
||||
def generate_tokens_with_assistance(model, inputs, assistant_early_exit):
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
assistant_early_exit=assistant_early_exit,
|
||||
do_sample=False,
|
||||
max_new_tokens=64,
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ckpt = config.hub_model_id
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto", dtype=torch.bfloat16)
|
||||
tokenizer = AutoTokenizer.from_pretrained(ckpt)
|
||||
|
||||
prompt = "### Instruction: What are my alarms for the rest of the day?\n ### Response: "
|
||||
|
||||
results = []
|
||||
label = "Generation Times"
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="generate_tokens(model, inputs)",
|
||||
setup="from __main__ import generate_tokens",
|
||||
globals={"model": model, "inputs": inputs},
|
||||
num_threads=torch.get_num_threads(),
|
||||
label=label,
|
||||
sub_label="no layer skip",
|
||||
description="generation",
|
||||
).blocked_autorange()
|
||||
)
|
||||
|
||||
for i in range(1, model.config.num_hidden_layers):
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="generate_tokens_with_assistance(model, inputs, assistant_early_exit)",
|
||||
setup="from __main__ import generate_assistant_tokens",
|
||||
globals={"model": model, "assistant_early_exit": i, "inputs": inputs},
|
||||
num_threads=torch.get_num_threads(),
|
||||
label=label,
|
||||
sub_label=f"layer skip {i}",
|
||||
description="generation",
|
||||
).blocked_autorange()
|
||||
)
|
||||
|
||||
benchmark.Compare(results).print()
|
@ -1,28 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from huggingface_hub import whoami
|
||||
|
||||
|
||||
model_name = "unsloth/Llama-3.2-3B"
|
||||
tokenizer_name = "unsloth/Llama-3.2-3B"
|
||||
dataset_name = "WillHeld/top_v2"
|
||||
|
||||
output_root_dir = "./checkpoints/"
|
||||
hub_model_id = f"{whoami()['name']}/layerskip-{model_name.split('/')[1]}-{dataset_name.split('/')[1]}"
|
||||
output_dir = f"{output_root_dir}/{hub_model_id}"
|
||||
|
||||
per_device_train_batch_size = 8
|
||||
gradient_accumulation_steps = 1
|
||||
learning_rate = 2e-5
|
@ -1,48 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from trl import SFTTrainer
|
||||
|
||||
|
||||
class LayerSkipSFTTrainer(SFTTrainer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.early_exit_layer = 0 # initialize with 0
|
||||
self.always_last_layer = True
|
||||
self.early_exit_loss_scale = 1.0
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
||||
self.early_exit_layer = (
|
||||
self.early_exit_layer % (model.config.num_hidden_layers - 1)
|
||||
) + 1 # rotates between [1, num_hidden_layers-1]
|
||||
bs, seqlen = inputs.input_ids.shape
|
||||
|
||||
labels = inputs.pop("labels")
|
||||
outputs = model(**inputs, output_hidden_states=True)
|
||||
|
||||
hidden_state = outputs["hidden_states"][self.early_exit_layer].to(model.dtype)
|
||||
if self.early_exit_layer != model.config.num_hidden_layers:
|
||||
hidden_state = model.model.norm(hidden_state)
|
||||
logits = model.lm_head(hidden_state)
|
||||
loss_early = model.loss_function(logits=logits, labels=labels, vocab_size=model.vocab_size)
|
||||
|
||||
if self.always_last_layer:
|
||||
loss_last = model.loss_function(logits=outputs["logits"], labels=labels, vocab_size=model.vocab_size)
|
||||
loss = self.early_exit_loss_scale * loss_early.to(loss_last.device) + 1.0 * loss_last
|
||||
# normalize loss scales
|
||||
loss = loss / (1.0 + self.early_exit_loss_scale)
|
||||
else:
|
||||
loss = loss_early
|
||||
|
||||
return loss
|
@ -1,90 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import config
|
||||
import torch
|
||||
from custom_trainer import LayerSkipSFTTrainer
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from trl import DataCollatorForCompletionOnlyLM, SFTConfig
|
||||
|
||||
|
||||
def formatting_prompts_func(example):
|
||||
text = f"### Instruction: {example['utterance']}\n ### Response: {example['semantic_parse']}"
|
||||
|
||||
# Inject eos_token as a string before tokenization, because they are not always added
|
||||
# See: https://github.com/huggingface/transformers/issues/22794 and
|
||||
# https://github.com/huggingface/trl/issues/1623
|
||||
if tokenizer.eos_token: # usually something like "</s>" for GPT2 or "<|endoftext|>"
|
||||
text += f"{tokenizer.eos_token}"
|
||||
|
||||
return text
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# load the dataset
|
||||
print("[INFO] loading the dataset...")
|
||||
train_dataset = load_dataset(config.dataset_name, split="train")
|
||||
|
||||
print(f"output_root_dir: {config.output_root_dir}")
|
||||
print(f"hub_model_id: {config.hub_model_id}")
|
||||
|
||||
# load the model and tokenizer
|
||||
print("[INFO] loading the model and tokenizer...")
|
||||
model = AutoModelForCausalLM.from_pretrained(config.model_name, device_map="auto", dtype=torch.bfloat16)
|
||||
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, add_eos_token=True)
|
||||
|
||||
# adding pad and eos tokens if not provided in the tokenizer
|
||||
if tokenizer.pad_token is None:
|
||||
# Add '[PAD]' token if it doesn't exist
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
model.config.pad_token_id = tokenizer.pad_token_id
|
||||
|
||||
if tokenizer.eos_token is None or tokenizer.eos_token == tokenizer.bos_token:
|
||||
# Add '[EOS]' token if it doesn't exist
|
||||
tokenizer.add_special_tokens({"eos_token": "[EOS]"})
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
model.config.eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
response_template = " ### Response:"
|
||||
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
|
||||
|
||||
args = SFTConfig(
|
||||
do_train=True,
|
||||
bf16=True,
|
||||
max_seq_length=None,
|
||||
per_device_train_batch_size=config.per_device_train_batch_size,
|
||||
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
||||
learning_rate=config.learning_rate,
|
||||
packing=False,
|
||||
num_train_epochs=1.0,
|
||||
report_to="none",
|
||||
push_to_hub=True,
|
||||
hub_model_id=config.hub_model_id,
|
||||
output_dir=config.output_dir,
|
||||
save_steps=1000,
|
||||
save_total_limit=2,
|
||||
)
|
||||
|
||||
trainer = LayerSkipSFTTrainer(
|
||||
model,
|
||||
train_dataset=train_dataset,
|
||||
args=args,
|
||||
formatting_func=formatting_prompts_func,
|
||||
data_collator=collator,
|
||||
)
|
||||
|
||||
trainer.train()
|
@ -1,18 +0,0 @@
|
||||
# RLHF pipeline for the creation of StackLLaMa: a Stack exchange llama-7b model.
|
||||
There were three main steps to the training process:
|
||||
1. Supervised fine-tuning of the base llama-7b model to create llama-7b-se:
|
||||
- `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/supervised_finetuning.py --model_path=<LLAMA_MODEL_PATH> --streaming --learning_rate 1e-5 --max_steps 5000 --output_dir ./llama-se`
|
||||
2. Reward modeling using dialog pairs from the SE dataset using the llama-7b-se to create llama-7b-se-rm:
|
||||
- `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/reward_modeling.py --model_name=<LLAMA_SE_MODEL>`
|
||||
3. RL fine-tuning of llama-7b-se with the llama-7b-se-rm reward model:
|
||||
- `accelerate launch --multi_gpu --num_machines 1 --num_processes 8 examples/research_projects/stack_llama/scripts/rl_training.py --log_with=wandb --model_name=<LLAMA_SE_MODEL> --reward_model_name=<LLAMA_SE_RM_MODEL> --adafactor=False --tokenizer_name=<LLAMA_TOKENIZER> --save_freq=100 --output_max_length=128 --batch_size=8 --gradient_accumulation_steps=8 --batched_gen=True --ppo_epochs=4 --seed=0 --learning_rate=1.4e-5 --early_stopping=True --output_dir=llama-se-rl-finetune-128-8-8-1.4e-5_adam`
|
||||
|
||||
|
||||
LoRA layers were used at all stages to reduce memory requirements.
|
||||
At each stage the peft adapter layers were merged with the base model, using:
|
||||
```shell
|
||||
python examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --adapter_model_name=XXX --base_model_name=YYY --output_name=ZZZ
|
||||
```
|
||||
Note that this script requires `peft>=0.3.0`.
|
||||
|
||||
For access to the base llama-7b model, please see Meta's [release](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) and [request form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform).
|
@ -1,60 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from peft import PeftConfig, PeftModel
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The input names representing the Adapter and Base model fine-tuned with PEFT, and the output name representing the
|
||||
merged model.
|
||||
"""
|
||||
|
||||
adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"})
|
||||
base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"})
|
||||
output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"})
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
assert script_args.adapter_model_name is not None, "please provide the name of the Adapter you would like to merge"
|
||||
assert script_args.base_model_name is not None, "please provide the name of the Base model"
|
||||
assert script_args.output_name is not None, "please provide the output name of the merged model"
|
||||
|
||||
peft_config = PeftConfig.from_pretrained(script_args.adapter_model_name)
|
||||
if peft_config.task_type == "SEQ_CLS":
|
||||
# The sequence classification task is used for the reward model in PPO
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
script_args.base_model_name, num_labels=1, dtype=torch.bfloat16
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(script_args.base_model_name, return_dict=True, dtype=torch.bfloat16)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name)
|
||||
|
||||
# Load the PEFT model
|
||||
model = PeftModel.from_pretrained(model, script_args.adapter_model_name)
|
||||
model.eval()
|
||||
|
||||
model = model.merge_and_unload()
|
||||
|
||||
model.save_pretrained(f"{script_args.output_name}")
|
||||
tokenizer.save_pretrained(f"{script_args.output_name}")
|
||||
model.push_to_hub(f"{script_args.output_name}", use_temp_dir=False)
|
@ -1,321 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig, TaskType, get_peft_model
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
HfArgumentParser,
|
||||
PreTrainedTokenizerBase,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
|
||||
# Define and parse arguments.
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
These arguments vary depending on how many GPUs you have, what their capacity and features are, and what size model you want to train.
|
||||
"""
|
||||
|
||||
local_rank: Optional[int] = field(default=-1, metadata={"help": "Used for multi-gpu"})
|
||||
resume_from_checkpoint: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "If you want to resume training where it left off."},
|
||||
)
|
||||
deepspeed: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Path to deepspeed config if using deepspeed. You may need this if the model that you want to train doesn't fit on a single GPU."
|
||||
},
|
||||
)
|
||||
per_device_train_batch_size: Optional[int] = field(default=4)
|
||||
per_device_eval_batch_size: Optional[int] = field(default=1)
|
||||
gradient_accumulation_steps: Optional[int] = field(default=1)
|
||||
learning_rate: Optional[float] = field(default=2e-5)
|
||||
weight_decay: Optional[float] = field(default=0.001)
|
||||
model_name: Optional[str] = field(
|
||||
default="gpt2",
|
||||
metadata={
|
||||
"help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc."
|
||||
},
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The tokenizer for your model, if left empty will use the default for your model",
|
||||
},
|
||||
)
|
||||
bf16: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "This essentially cuts the training time in half if you want to sacrifice a little precision and have a supported GPU."
|
||||
},
|
||||
)
|
||||
num_train_epochs: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={"help": "The number of training epochs for the reward model."},
|
||||
)
|
||||
train_subset: Optional[int] = field(
|
||||
default=100000,
|
||||
metadata={"help": "The size of the subset of the training data to use"},
|
||||
)
|
||||
eval_subset: Optional[int] = field(
|
||||
default=50000,
|
||||
metadata={"help": "The size of the subset of the eval data to use"},
|
||||
)
|
||||
gradient_checkpointing: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Enables gradient checkpointing."},
|
||||
)
|
||||
optim: Optional[str] = field(
|
||||
default="adamw_hf",
|
||||
metadata={"help": "The optimizer to use."},
|
||||
)
|
||||
lr_scheduler_type: Optional[str] = field(
|
||||
default="linear",
|
||||
metadata={"help": "The lr scheduler"},
|
||||
)
|
||||
max_length: Optional[int] = field(default=512)
|
||||
eval_first_step: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to run eval after the first step"},
|
||||
)
|
||||
seed: Optional[int] = field(
|
||||
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
|
||||
)
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
set_seed(script_args.seed)
|
||||
# Load the human stack-exchange-paired dataset for tuning the reward model.
|
||||
train_dataset = load_dataset(
|
||||
"lvwerra/stack-exchange-paired", data_dir="data/reward", split="train", verification_mode="no_checks"
|
||||
)
|
||||
if script_args.train_subset > 0:
|
||||
train_dataset = train_dataset.select(range(script_args.train_subset))
|
||||
eval_dataset = load_dataset(
|
||||
"lvwerra/stack-exchange-paired", data_dir="data/evaluation", split="train", verification_mode="no_checks"
|
||||
)
|
||||
if script_args.eval_subset > 0:
|
||||
eval_dataset = eval_dataset.select(range(script_args.eval_subset))
|
||||
# Define the training args. Needs to be done before the model is loaded if you are using deepspeed.
|
||||
model_name_split = script_args.model_name.split("/")[-1]
|
||||
output_name = (
|
||||
f"{model_name_split}_peft_stack-exchange-paired_rmts__{script_args.train_subset}_{script_args.learning_rate}"
|
||||
)
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir=output_name,
|
||||
learning_rate=script_args.learning_rate,
|
||||
per_device_train_batch_size=script_args.per_device_train_batch_size,
|
||||
per_device_eval_batch_size=script_args.per_device_eval_batch_size,
|
||||
num_train_epochs=script_args.num_train_epochs,
|
||||
weight_decay=script_args.weight_decay,
|
||||
eval_strategy="steps",
|
||||
eval_steps=500,
|
||||
save_strategy="steps",
|
||||
save_steps=500,
|
||||
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
||||
gradient_checkpointing=script_args.gradient_checkpointing,
|
||||
deepspeed=script_args.deepspeed,
|
||||
local_rank=script_args.local_rank,
|
||||
remove_unused_columns=False,
|
||||
label_names=[],
|
||||
bf16=script_args.bf16,
|
||||
logging_strategy="steps",
|
||||
optim=script_args.optim,
|
||||
lr_scheduler_type=script_args.lr_scheduler_type,
|
||||
seed=script_args.seed,
|
||||
)
|
||||
|
||||
|
||||
# Load the value-head model and tokenizer.
|
||||
tokenizer_name = script_args.tokenizer_name if script_args.tokenizer_name is not None else script_args.model_name
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
|
||||
peft_config = LoraConfig(
|
||||
task_type=TaskType.SEQ_CLS,
|
||||
inference_mode=False,
|
||||
r=8,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.1,
|
||||
)
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(script_args.model_name, num_labels=1, dtype=torch.bfloat16)
|
||||
model = get_peft_model(model, peft_config)
|
||||
model.print_trainable_parameters()
|
||||
|
||||
# Need to do this for gpt2, because it doesn't have an official pad token.
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
model.config.pad_token_id = tokenizer.eos_token_id
|
||||
model.config.use_cache = not script_args.gradient_checkpointing
|
||||
num_proc = 24 # Can adjust to be higher if you have more processors.
|
||||
original_columns = train_dataset.column_names
|
||||
|
||||
|
||||
# Turn the dataset into pairs of question + answers, where text_j is the preferred question + answer and text_k is the other.
|
||||
# Then tokenize the dataset.
|
||||
def preprocess_function(examples):
|
||||
new_examples = {
|
||||
"input_ids_j": [],
|
||||
"attention_mask_j": [],
|
||||
"input_ids_k": [],
|
||||
"attention_mask_k": [],
|
||||
}
|
||||
for question, response_j, response_k in zip(examples["question"], examples["response_j"], examples["response_k"]):
|
||||
tokenized_j = tokenizer("Question: " + question + "\n\nAnswer: " + response_j, truncation=True)
|
||||
tokenized_k = tokenizer("Question: " + question + "\n\nAnswer: " + response_k, truncation=True)
|
||||
|
||||
new_examples["input_ids_j"].append(tokenized_j["input_ids"])
|
||||
new_examples["attention_mask_j"].append(tokenized_j["attention_mask"])
|
||||
new_examples["input_ids_k"].append(tokenized_k["input_ids"])
|
||||
new_examples["attention_mask_k"].append(tokenized_k["attention_mask"])
|
||||
|
||||
return new_examples
|
||||
|
||||
|
||||
# preprocess the dataset and filter out QAs that are longer than script_args.max_length
|
||||
train_dataset = train_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=num_proc,
|
||||
remove_columns=original_columns,
|
||||
)
|
||||
train_dataset = train_dataset.filter(
|
||||
lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length,
|
||||
num_proc=num_proc,
|
||||
)
|
||||
|
||||
eval_dataset = eval_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=num_proc,
|
||||
remove_columns=original_columns,
|
||||
)
|
||||
eval_dataset = eval_dataset.filter(
|
||||
lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length,
|
||||
num_proc=num_proc,
|
||||
)
|
||||
|
||||
|
||||
# We need to define a special data collator that batches the data in our j vs k format.
|
||||
@dataclass
|
||||
class RewardDataCollatorWithPadding:
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
return_tensors: str = "pt"
|
||||
|
||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
features_j = []
|
||||
features_k = []
|
||||
for feature in features:
|
||||
features_j.append(
|
||||
{
|
||||
"input_ids": feature["input_ids_j"],
|
||||
"attention_mask": feature["attention_mask_j"],
|
||||
}
|
||||
)
|
||||
features_k.append(
|
||||
{
|
||||
"input_ids": feature["input_ids_k"],
|
||||
"attention_mask": feature["attention_mask_k"],
|
||||
}
|
||||
)
|
||||
batch_j = self.tokenizer.pad(
|
||||
features_j,
|
||||
padding=self.padding,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=self.return_tensors,
|
||||
)
|
||||
batch_k = self.tokenizer.pad(
|
||||
features_k,
|
||||
padding=self.padding,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=self.return_tensors,
|
||||
)
|
||||
batch = {
|
||||
"input_ids_j": batch_j["input_ids"],
|
||||
"attention_mask_j": batch_j["attention_mask"],
|
||||
"input_ids_k": batch_k["input_ids"],
|
||||
"attention_mask_k": batch_k["attention_mask"],
|
||||
"return_loss": True,
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
# Define the metric that we'll use for validation.
|
||||
accuracy = evaluate.load("accuracy")
|
||||
|
||||
|
||||
def compute_metrics(eval_pred):
|
||||
predictions, _ = eval_pred
|
||||
# Here, predictions is rewards_j and rewards_k.
|
||||
# We want to see how much of the time rewards_j > rewards_k.
|
||||
predictions = np.argmax(predictions, axis=0)
|
||||
labels = np.zeros(predictions.shape)
|
||||
return accuracy.compute(predictions=predictions, references=labels)
|
||||
|
||||
|
||||
class RewardTrainer(Trainer):
|
||||
# Define how to compute the reward loss. We use the InstructGPT pairwise logloss: https://huggingface.co/papers/2203.02155
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0]
|
||||
rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0]
|
||||
loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean()
|
||||
if return_outputs:
|
||||
return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k}
|
||||
return loss
|
||||
|
||||
|
||||
# Train the model, woohoo.
|
||||
trainer = RewardTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
compute_metrics=compute_metrics,
|
||||
data_collator=RewardDataCollatorWithPadding(tokenizer=tokenizer),
|
||||
)
|
||||
|
||||
|
||||
if script_args.eval_first_step:
|
||||
|
||||
class EvaluateFirstStepCallback(TrainerCallback):
|
||||
def on_step_end(self, args, state, control, **kwargs):
|
||||
if state.global_step == 1:
|
||||
control.should_evaluate = True
|
||||
|
||||
trainer.add_callback(EvaluateFirstStepCallback())
|
||||
|
||||
trainer.train(script_args.resume_from_checkpoint)
|
||||
|
||||
print("Saving last checkpoint of the model")
|
||||
model.save_pretrained(output_name + "_peft_last_checkpoint")
|
@ -1,270 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import Adafactor, AutoTokenizer, HfArgumentParser, pipeline, set_seed
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
|
||||
from trl.core import LengthSampler
|
||||
|
||||
|
||||
tqdm.pandas()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The name of the Causal LM model we wish to fine-tune with PPO
|
||||
"""
|
||||
|
||||
# NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode
|
||||
# models like gpt-neo* models are more suitable.
|
||||
model_name: Optional[str] = field(default="", metadata={"help": "the model name"})
|
||||
tokenizer_name: Optional[str] = field(default="", metadata={"help": "the tokenizer name"})
|
||||
reward_model_name: Optional[str] = field(default="", metadata={"help": "the reward model name"})
|
||||
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
|
||||
learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
|
||||
output_max_length: Optional[int] = field(default=128, metadata={"help": "maximum length for generation"})
|
||||
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
|
||||
batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
|
||||
ppo_epochs: Optional[int] = field(default=4, metadata={"help": "the number of ppo epochs"})
|
||||
gradient_accumulation_steps: Optional[int] = field(
|
||||
default=4, metadata={"help": "the number of gradient accumulation steps"}
|
||||
)
|
||||
adafactor: Optional[bool] = field(default=False, metadata={"help": "whether to use the adafactor optimizer"})
|
||||
early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"})
|
||||
target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"})
|
||||
reward_baseline: Optional[float] = field(
|
||||
default=0.0,
|
||||
metadata={"help": "a baseline value that is subtracted from the reward"},
|
||||
)
|
||||
batched_gen: Optional[bool] = field(default=False, metadata={"help": "whether to use the batched text gen"})
|
||||
save_freq: Optional[int] = field(default=None, metadata={"help": "n steps to save the model"})
|
||||
output_dir: Optional[str] = field(default="runs/", metadata={"help": "n steps to save the model"})
|
||||
seed: Optional[int] = field(default=0, metadata={"help": "the seed"})
|
||||
steps: Optional[int] = field(default=20000, metadata={"help": "number of training steps"})
|
||||
init_kl_coef: Optional[float] = field(
|
||||
default=0.2,
|
||||
metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},
|
||||
)
|
||||
|
||||
adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
|
||||
load_in_8bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 8bit"})
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0]
|
||||
reward_model_name = script_args.reward_model_name
|
||||
dataset_name = "lvwerra/stack-exchange-paired"
|
||||
config = PPOConfig(
|
||||
steps=script_args.steps,
|
||||
model_name=script_args.model_name,
|
||||
learning_rate=script_args.learning_rate,
|
||||
log_with=script_args.log_with,
|
||||
batch_size=script_args.batch_size,
|
||||
mini_batch_size=script_args.mini_batch_size,
|
||||
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
||||
optimize_device_cache=True,
|
||||
early_stopping=script_args.early_stopping,
|
||||
target_kl=script_args.target_kl,
|
||||
ppo_epochs=script_args.ppo_epochs,
|
||||
seed=script_args.seed,
|
||||
init_kl_coef=script_args.init_kl_coef,
|
||||
adap_kl_ctrl=script_args.adap_kl_ctrl,
|
||||
)
|
||||
|
||||
train_dataset = load_dataset(
|
||||
"lvwerra/stack-exchange-paired", data_dir="data/rl", split="train", verification_mode="no_checks"
|
||||
)
|
||||
train_dataset = train_dataset.select(range(100000))
|
||||
original_columns = train_dataset.column_names
|
||||
|
||||
# We then define the arguments to pass to the sentiment analysis pipeline.
|
||||
# We set `return_all_scores` to True to get the sentiment score for each token.
|
||||
sent_kwargs = {
|
||||
"return_all_scores": True,
|
||||
"function_to_apply": "none",
|
||||
"batch_size": 16,
|
||||
"truncation": True,
|
||||
}
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name)
|
||||
# GPT-2 tokenizer has a pad token, but it is not eos_token by default. We need to set it to eos_token.
|
||||
# only for this model.
|
||||
|
||||
if getattr(tokenizer, "pad_token", None) is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
|
||||
# Below is an example function to build the dataset. In our case, we use the IMDB dataset
|
||||
# from the `datasets` library. One should customize this function to train the model on
|
||||
# its own dataset.
|
||||
def build_dataset(
|
||||
tokenizer,
|
||||
dataset_name="lvwerra/stack-exchange-paired",
|
||||
):
|
||||
"""
|
||||
Build dataset for training. This builds the dataset from `load_dataset`, one should
|
||||
customize this function to train the model on its own dataset.
|
||||
|
||||
Args:
|
||||
tokenizer (`transformers.PreTrainedTokenizer`):
|
||||
The tokenizer used for the model.
|
||||
dataset_name (`str`):
|
||||
The name of the dataset to be loaded.
|
||||
|
||||
Returns:
|
||||
dataloader (`torch.utils.data.DataLoader`):
|
||||
The dataloader for the dataset.
|
||||
"""
|
||||
|
||||
num_proc = 24
|
||||
|
||||
def preprocess_function(examples):
|
||||
new_examples = {
|
||||
"query": [],
|
||||
"input_ids": [],
|
||||
}
|
||||
for question in examples["question"]:
|
||||
query = "Question: " + question + "\n\nAnswer: "
|
||||
tokenized_question = tokenizer(query, truncation=True)
|
||||
new_examples["query"].append(query)
|
||||
new_examples["input_ids"].append(tokenized_question["input_ids"])
|
||||
|
||||
return new_examples
|
||||
|
||||
ds = train_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=num_proc,
|
||||
remove_columns=original_columns,
|
||||
)
|
||||
ds = ds.filter(lambda x: len(x["input_ids"]) < 512, batched=False, num_proc=num_proc)
|
||||
|
||||
ds.set_format(type="torch")
|
||||
return ds
|
||||
|
||||
|
||||
# We retrieve the dataloader by calling the `build_dataset` function.
|
||||
dataset = build_dataset(tokenizer)
|
||||
|
||||
|
||||
def collator(data):
|
||||
return {key: [d[key] for d in data] for key in data[0]}
|
||||
|
||||
|
||||
# set seed before initializing value head for deterministic eval
|
||||
set_seed(config.seed)
|
||||
|
||||
# Now let's build the model, the reference model, and the tokenizer.
|
||||
current_device = Accelerator().local_process_index
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
config.model_name,
|
||||
load_in_8bit=script_args.load_in_8bit,
|
||||
device_map={"": current_device},
|
||||
peft_config=lora_config,
|
||||
)
|
||||
|
||||
optimizer = None
|
||||
if script_args.adafactor:
|
||||
optimizer = Adafactor(
|
||||
filter(lambda p: p.requires_grad, model.parameters()),
|
||||
scale_parameter=False,
|
||||
relative_step=False,
|
||||
warmup_init=False,
|
||||
lr=config.learning_rate,
|
||||
)
|
||||
# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
|
||||
ppo_trainer = PPOTrainer(
|
||||
config,
|
||||
model,
|
||||
ref_model=None,
|
||||
tokenizer=tokenizer,
|
||||
dataset=dataset,
|
||||
data_collator=collator,
|
||||
optimizer=optimizer,
|
||||
)
|
||||
|
||||
# We then build the sentiment analysis pipeline using our reward model, passing the
|
||||
# model name and the sentiment analysis pipeline arguments. Let's also make sure to
|
||||
# set the device to the same device as the PPOTrainer.
|
||||
device = ppo_trainer.accelerator.device
|
||||
if ppo_trainer.accelerator.num_processes == 1:
|
||||
device = 0 if torch.cuda.is_available() else "cpu" # to avoid a ` pipeline` bug
|
||||
sentiment_pipe = pipeline(
|
||||
"sentiment-analysis",
|
||||
model=reward_model_name,
|
||||
device_map={"": current_device},
|
||||
model_kwargs={"load_in_8bit": script_args.load_in_8bit},
|
||||
tokenizer=tokenizer,
|
||||
return_token_type_ids=False,
|
||||
)
|
||||
|
||||
if sentiment_pipe.model.config.pad_token_id is None:
|
||||
sentiment_pipe.model.config.pad_token_id = sentiment_pipe.model.config.eos_token_id
|
||||
# We then define the arguments to pass to the `generate` function. These arguments
|
||||
# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
|
||||
# the `generate` function of the trained model.
|
||||
generation_kwargs = {
|
||||
# "min_length": -1,
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": tokenizer.pad_token_id,
|
||||
"eos_token_id": 100_000,
|
||||
}
|
||||
output_min_length = 32
|
||||
output_max_length = script_args.output_max_length
|
||||
output_length_sampler = LengthSampler(output_min_length, output_max_length)
|
||||
|
||||
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
|
||||
if epoch >= config.total_ppo_epochs:
|
||||
break
|
||||
|
||||
question_tensors = batch["input_ids"]
|
||||
|
||||
response_tensors = ppo_trainer.generate(
|
||||
question_tensors,
|
||||
return_prompt=False,
|
||||
length_sampler=output_length_sampler,
|
||||
**generation_kwargs,
|
||||
)
|
||||
batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
|
||||
|
||||
# Compute reward score (using the sentiment analysis pipeline)
|
||||
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
|
||||
pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
|
||||
rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs]
|
||||
|
||||
# Run PPO step
|
||||
stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
|
||||
ppo_trainer.log_stats(stats, batch, rewards)
|
||||
|
||||
if script_args.save_freq and epoch and epoch % script_args.save_freq == 0:
|
||||
ppo_trainer.save_pretrained(script_args.output_dir + f"step_{epoch}")
|
@ -1,222 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from accelerate import Accelerator
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, logging, set_seed
|
||||
|
||||
from trl import SFTTrainer
|
||||
from trl.trainer import ConstantLengthDataset
|
||||
|
||||
|
||||
"""
|
||||
Fine-Tune Llama-7b on SE paired dataset
|
||||
"""
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_path", type=str, default="")
|
||||
parser.add_argument("--dataset_name", type=str, default="lvwerra/stack-exchange-paired")
|
||||
parser.add_argument("--subset", type=str, default="data/finetune")
|
||||
parser.add_argument("--split", type=str, default="train")
|
||||
parser.add_argument("--size_valid_set", type=int, default=4000)
|
||||
parser.add_argument("--streaming", action="store_true")
|
||||
parser.add_argument("--shuffle_buffer", type=int, default=5000)
|
||||
|
||||
parser.add_argument("--seq_length", type=int, default=1024)
|
||||
parser.add_argument("--max_steps", type=int, default=10000)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--eos_token_id", type=int, default=49152)
|
||||
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-4)
|
||||
parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
|
||||
parser.add_argument("--num_warmup_steps", type=int, default=100)
|
||||
parser.add_argument("--weight_decay", type=float, default=0.05)
|
||||
|
||||
parser.add_argument("--local_rank", type=int, default=0)
|
||||
parser.add_argument("--fp16", action="store_true", default=False)
|
||||
parser.add_argument("--bf16", action="store_true", default=False)
|
||||
parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--num_workers", type=int, default=None)
|
||||
parser.add_argument("--output_dir", type=str, default="./checkpoints")
|
||||
parser.add_argument("--log_freq", default=1, type=int)
|
||||
parser.add_argument("--eval_freq", default=1000, type=int)
|
||||
parser.add_argument("--save_freq", default=1000, type=int)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def chars_token_ratio(dataset, tokenizer, nb_examples=400):
|
||||
"""
|
||||
Estimate the average number of characters per token in the dataset.
|
||||
"""
|
||||
total_characters, total_tokens = 0, 0
|
||||
for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
|
||||
text = prepare_sample_text(example)
|
||||
total_characters += len(text)
|
||||
if tokenizer.is_fast:
|
||||
total_tokens += len(tokenizer(text).tokens())
|
||||
else:
|
||||
total_tokens += len(tokenizer.tokenize(text))
|
||||
|
||||
return total_characters / total_tokens
|
||||
|
||||
|
||||
def print_trainable_parameters(model):
|
||||
"""
|
||||
Prints the number of trainable parameters in the model.
|
||||
"""
|
||||
trainable_params = 0
|
||||
all_param = 0
|
||||
for _, param in model.named_parameters():
|
||||
all_param += param.numel()
|
||||
if param.requires_grad:
|
||||
trainable_params += param.numel()
|
||||
print(
|
||||
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
|
||||
)
|
||||
|
||||
|
||||
def prepare_sample_text(example):
|
||||
"""Prepare the text from a sample of the dataset."""
|
||||
text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}"
|
||||
return text
|
||||
|
||||
|
||||
def create_datasets(tokenizer, args):
|
||||
dataset = load_dataset(
|
||||
args.dataset_name,
|
||||
data_dir=args.subset,
|
||||
split=args.split,
|
||||
use_auth_token=True,
|
||||
num_proc=args.num_workers if not args.streaming else None,
|
||||
streaming=args.streaming,
|
||||
)
|
||||
if args.streaming:
|
||||
print("Loading the dataset in streaming mode")
|
||||
valid_data = dataset.take(args.size_valid_set)
|
||||
train_data = dataset.skip(args.size_valid_set)
|
||||
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
|
||||
else:
|
||||
dataset = dataset.train_test_split(test_size=0.005, seed=args.seed)
|
||||
train_data = dataset["train"]
|
||||
valid_data = dataset["test"]
|
||||
print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
|
||||
|
||||
chars_per_token = chars_token_ratio(train_data, tokenizer)
|
||||
print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
|
||||
|
||||
train_dataset = ConstantLengthDataset(
|
||||
tokenizer,
|
||||
train_data,
|
||||
formatting_func=prepare_sample_text,
|
||||
infinite=True,
|
||||
seq_length=args.seq_length,
|
||||
chars_per_token=chars_per_token,
|
||||
)
|
||||
valid_dataset = ConstantLengthDataset(
|
||||
tokenizer,
|
||||
valid_data,
|
||||
formatting_func=prepare_sample_text,
|
||||
infinite=False,
|
||||
seq_length=args.seq_length,
|
||||
chars_per_token=chars_per_token,
|
||||
)
|
||||
return train_dataset, valid_dataset
|
||||
|
||||
|
||||
def run_training(args, train_data, val_data):
|
||||
print("Loading the model")
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
train_data.start_iteration = 0
|
||||
|
||||
print("Starting main loop")
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir=args.output_dir,
|
||||
dataloader_drop_last=True,
|
||||
eval_strategy="steps",
|
||||
max_steps=args.max_steps,
|
||||
eval_steps=args.eval_freq,
|
||||
save_steps=args.save_freq,
|
||||
logging_steps=args.log_freq,
|
||||
per_device_train_batch_size=args.batch_size,
|
||||
per_device_eval_batch_size=args.batch_size,
|
||||
learning_rate=args.learning_rate,
|
||||
lr_scheduler_type=args.lr_scheduler_type,
|
||||
warmup_steps=args.num_warmup_steps,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
gradient_checkpointing=args.gradient_checkpointing,
|
||||
fp16=args.fp16,
|
||||
bf16=args.bf16,
|
||||
weight_decay=args.weight_decay,
|
||||
run_name="llama-7b-finetuned",
|
||||
report_to="wandb",
|
||||
ddp_find_unused_parameters=False,
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_path, load_in_8bit=True, device_map={"": Accelerator().process_index}
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_data,
|
||||
eval_dataset=val_data,
|
||||
peft_config=lora_config,
|
||||
packing=True,
|
||||
)
|
||||
|
||||
print_trainable_parameters(trainer.model)
|
||||
|
||||
print("Training...")
|
||||
trainer.train()
|
||||
|
||||
print("Saving last checkpoint of the model")
|
||||
trainer.model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))
|
||||
|
||||
|
||||
def main(args):
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
|
||||
train_dataset, eval_dataset = create_datasets(tokenizer, args)
|
||||
run_training(args, train_dataset, eval_dataset)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
assert args.model_path != "", "Please provide the llama model path"
|
||||
|
||||
set_seed(args.seed)
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
logging.set_verbosity_error()
|
||||
|
||||
main(args)
|
@ -1,78 +0,0 @@
# DPO pipeline for the creation of StackLlaMa 2: a Stack Exchange llama-v2-7b model

## Prerequisites

Install all the dependencies listed in `requirements.txt`:

```shell
pip install -U -r requirements.txt
```

Since we will use `accelerate` for training, make sure to run:

```shell
accelerate config
```

## Training

There are two main steps to the DPO training process:

1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se:

```shell
accelerate launch examples/research_projects/stack_llama_2/scripts/sft_llama2.py \
    --output_dir="./sft" \
    --max_steps=500 \
    --save_steps=10 \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=1 \
    --gradient_accumulation_steps=2 \
    --gradient_checkpointing=False \
    --group_by_length=False \
    --learning_rate=1e-4 \
    --lr_scheduler_type="cosine" \
    --warmup_steps=100 \
    --weight_decay=0.05 \
    --optim="paged_adamw_32bit" \
    --bf16=True \
    --remove_unused_columns=False \
    --run_name="sft_llama2" \
    --report_to="wandb"
```

2. Run the DPO trainer using the model saved by the previous step:

```shell
accelerate launch examples/research_projects/stack_llama_2/scripts/dpo_llama2.py \
    --model_name_or_path="sft/final_checkpoint" \
    --output_dir="dpo"
```

## Merging the adapters

To merge the adapters into the base model, we can use the `merge_peft_adapter.py` helper script that comes with TRL:

```shell
python examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --base_model_name="meta-llama/Llama-2-7b-hf" --adapter_model_name="dpo/final_checkpoint/" --output_name="stack-llama-2"
```

which will also push the merged model to your Hugging Face Hub account.

## Running the model

We can load the DPO-trained LoRA adapters that were saved by the DPO training step via:

```python
import torch
from peft import AutoPeftModelForCausalLM


model = AutoPeftModelForCausalLM.from_pretrained(
    "dpo/final_checkpoint",
    low_cpu_mem_usage=True,
    dtype=torch.float16,
    load_in_4bit=True,
)

model.generate(...)
```
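If you prefer to perform the merge directly in Python instead of the helper script, a minimal sketch using PEFT's `merge_and_unload()` could look as follows; the checkpoint path and output directory are placeholders, and the adapter is loaded in fp16 rather than 4-bit to keep the merge straightforward:

```python
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

# Load the base model with the DPO-trained LoRA adapter attached
model = AutoPeftModelForCausalLM.from_pretrained(
    "dpo/final_checkpoint",  # adapter checkpoint written by the DPO step
    low_cpu_mem_usage=True,
    dtype=torch.float16,
)

# Fold the LoRA weights into the base weights so the model can be used without PEFT
merged_model = model.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Save the standalone merged model (directory name is a placeholder)
merged_model.save_pretrained("stack-llama-2-merged")
tokenizer.save_pretrained("stack-llama-2-merged")
```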
@ -1,252 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# 0. imports
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from datasets import Dataset, load_dataset
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed
|
||||
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
|
||||
|
||||
# Define and parse arguments.
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The arguments for the DPO training script.
|
||||
"""
|
||||
|
||||
# data parameters
|
||||
beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
|
||||
|
||||
# training parameters
|
||||
model_name_or_path: Optional[str] = field(
|
||||
default="../sft/results/final_checkpoint",
|
||||
metadata={"help": "the location of the SFT model name or path"},
|
||||
)
|
||||
learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"})
|
||||
lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
|
||||
warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
|
||||
weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
|
||||
optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})
|
||||
|
||||
per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "train batch size per device"})
|
||||
per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
|
||||
gradient_accumulation_steps: Optional[int] = field(
|
||||
default=4, metadata={"help": "the number of gradient accumulation steps"}
|
||||
)
|
||||
gradient_checkpointing: Optional[bool] = field(
|
||||
default=True, metadata={"help": "whether to use gradient checkpointing"}
|
||||
)
|
||||
|
||||
gradient_checkpointing_use_reentrant: Optional[bool] = field(
|
||||
default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"}
|
||||
)
|
||||
|
||||
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
|
||||
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
|
||||
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
|
||||
|
||||
max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
|
||||
max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
|
||||
max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
|
||||
logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
|
||||
save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"})
|
||||
eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})
|
||||
|
||||
output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
|
||||
log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})
|
||||
load_in_4bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 4bit"})
|
||||
model_dtype: Optional[str] = field(
|
||||
default="float16", metadata={"help": "model_dtype[float16, bfloat16, float] for loading."}
|
||||
)
|
||||
|
||||
# instrumentation
|
||||
report_to: Optional[str] = field(
|
||||
default="wandb",
|
||||
metadata={
|
||||
"help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
|
||||
'`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
|
||||
'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
|
||||
},
|
||||
)
|
||||
# debug argument for distributed training
|
||||
ignore_bias_buffers: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
|
||||
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
|
||||
},
|
||||
)
|
||||
seed: Optional[int] = field(
|
||||
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
|
||||
)
|
||||
|
||||
|
||||
def get_stack_exchange_paired(
|
||||
data_dir: str = "data/rl",
|
||||
cache_dir: Optional[str] = None,
|
||||
num_proc=24,
|
||||
) -> Dataset:
|
||||
"""Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.
|
||||
|
||||
The dataset is converted to a dictionary with the following structure:
|
||||
{
|
||||
'prompt': list[str],
|
||||
'chosen': list[str],
|
||||
'rejected': list[str],
|
||||
}
|
||||
|
||||
Prompts are structured as follows:
|
||||
"Question: " + <prompt> + "\n\nAnswer: "
|
||||
"""
|
||||
dataset = load_dataset(
|
||||
"lvwerra/stack-exchange-paired",
|
||||
split="train",
|
||||
cache_dir=cache_dir,
|
||||
data_dir=data_dir,
|
||||
verification_mode="no_checks",
|
||||
)
|
||||
original_columns = dataset.column_names
|
||||
|
||||
def return_prompt_and_responses(samples) -> dict[str, str]:
|
||||
return {
|
||||
"prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
|
||||
"chosen": samples["response_j"],
|
||||
"rejected": samples["response_k"],
|
||||
}
|
||||
|
||||
return dataset.map(
|
||||
return_prompt_and_responses,
|
||||
batched=True,
|
||||
num_proc=num_proc,
|
||||
remove_columns=original_columns,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
set_seed(script_args.seed)
|
||||
|
||||
# 1. load a pretrained model
|
||||
dtype = torch.float
|
||||
if script_args.model_dtype == "float16":
|
||||
dtype = torch.float16
|
||||
elif script_args.model_dtype == "bfloat16":
|
||||
dtype = torch.bfloat16
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
script_args.model_name_or_path,
|
||||
low_cpu_mem_usage=True,
|
||||
dtype=dtype,
|
||||
load_in_4bit=script_args.load_in_4bit,
|
||||
device_map={"": Accelerator().local_process_index},
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
if script_args.ignore_bias_buffers:
|
||||
# torch distributed hack
|
||||
model._ddp_params_and_buffers_to_ignore = [
|
||||
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
|
||||
]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# 2. Load the Stack-exchange paired dataset
|
||||
train_dataset = get_stack_exchange_paired(data_dir="data/rl")
|
||||
train_dataset = train_dataset.filter(
|
||||
lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
|
||||
and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
|
||||
num_proc=script_args.num_proc,
|
||||
)
|
||||
|
||||
# 3. Load evaluation dataset
|
||||
eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation")
|
||||
eval_dataset = eval_dataset.filter(
|
||||
lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
|
||||
and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
|
||||
num_proc=script_args.num_proc,
|
||||
)
|
||||
|
||||
# 4. initialize training arguments:
|
||||
training_args = DPOConfig(
|
||||
per_device_train_batch_size=script_args.per_device_train_batch_size,
|
||||
per_device_eval_batch_size=script_args.per_device_eval_batch_size,
|
||||
max_steps=script_args.max_steps,
|
||||
logging_steps=script_args.logging_steps,
|
||||
save_steps=script_args.save_steps,
|
||||
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
||||
gradient_checkpointing=script_args.gradient_checkpointing,
|
||||
learning_rate=script_args.learning_rate,
|
||||
eval_strategy="steps",
|
||||
eval_steps=script_args.eval_steps,
|
||||
output_dir=script_args.output_dir,
|
||||
report_to=script_args.report_to,
|
||||
lr_scheduler_type=script_args.lr_scheduler_type,
|
||||
warmup_steps=script_args.warmup_steps,
|
||||
optim=script_args.optimizer_type,
|
||||
bf16=True,
|
||||
remove_unused_columns=False,
|
||||
run_name="dpo_llama2",
|
||||
gradient_checkpointing_kwargs=dict(use_reentrant=script_args.gradient_checkpointing_use_reentrant),
|
||||
seed=script_args.seed,
|
||||
)
|
||||
|
||||
peft_config = LoraConfig(
|
||||
r=script_args.lora_r,
|
||||
lora_alpha=script_args.lora_alpha,
|
||||
lora_dropout=script_args.lora_dropout,
|
||||
target_modules=[
|
||||
"q_proj",
|
||||
"v_proj",
|
||||
"k_proj",
|
||||
"out_proj",
|
||||
"fc_in",
|
||||
"fc_out",
|
||||
"wte",
|
||||
],
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
# 5. initialize the DPO trainer
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
ref_model=None,
|
||||
args=training_args,
|
||||
beta=script_args.beta,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=tokenizer,
|
||||
peft_config=peft_config,
|
||||
max_prompt_length=script_args.max_prompt_length,
|
||||
max_length=script_args.max_length,
|
||||
)
|
||||
|
||||
# 6. train
|
||||
dpo_trainer.train()
|
||||
dpo_trainer.save_model(script_args.output_dir)
|
||||
|
||||
# 7. save
|
||||
output_dir = os.path.join(script_args.output_dir, "final_checkpoint")
|
||||
dpo_trainer.model.save_pretrained(output_dir)
|
@ -1,7 +0,0 @@
transformers
trl
peft
accelerate
datasets
bitsandbytes
wandb
@ -1,212 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Fine-Tune Llama2-7b on SE paired dataset
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from datasets import load_dataset
|
||||
from peft import AutoPeftModelForCausalLM, LoraConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
HfArgumentParser,
|
||||
is_torch_npu_available,
|
||||
is_torch_xpu_available,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
from trl.trainer import ConstantLengthDataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
|
||||
dataset_name: Optional[str] = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"})
|
||||
subset: Optional[str] = field(default="data/finetune", metadata={"help": "the subset to use"})
|
||||
split: Optional[str] = field(default="train", metadata={"help": "the split to use"})
|
||||
size_valid_set: Optional[int] = field(default=4000, metadata={"help": "the size of the validation set"})
|
||||
streaming: Optional[bool] = field(default=True, metadata={"help": "whether to stream the dataset"})
|
||||
shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"})
|
||||
seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"})
|
||||
num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
|
||||
use_bnb: Optional[bool] = field(default=True, metadata={"help": "whether to use BitsAndBytes"})
|
||||
|
||||
# LoraConfig
|
||||
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
|
||||
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
|
||||
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
|
||||
|
||||
|
||||
parser = HfArgumentParser((ScriptArguments, SFTConfig))
|
||||
script_args, training_args = parser.parse_args_into_dataclasses()
|
||||
peft_config = LoraConfig(
|
||||
r=script_args.lora_r,
|
||||
lora_alpha=script_args.lora_alpha,
|
||||
lora_dropout=script_args.lora_dropout,
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
if training_args.group_by_length and training_args.packing:
|
||||
raise ValueError("Cannot use both packing and group by length")
|
||||
|
||||
# `gradient_checkpointing` was True by default until `1f3314`, but it's actually not used.
|
||||
# `gradient_checkpointing=True` will cause `Variable._execution_engine.run_backward`.
|
||||
if training_args.gradient_checkpointing:
|
||||
raise ValueError("gradient_checkpointing not supported")
|
||||
|
||||
set_seed(training_args.seed)
|
||||
|
||||
|
||||
def chars_token_ratio(dataset, tokenizer, nb_examples=400):
|
||||
"""
|
||||
Estimate the average number of characters per token in the dataset.
|
||||
"""
|
||||
total_characters, total_tokens = 0, 0
|
||||
for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
|
||||
text = prepare_sample_text(example)
|
||||
total_characters += len(text)
|
||||
if tokenizer.is_fast:
|
||||
total_tokens += len(tokenizer(text).tokens())
|
||||
else:
|
||||
total_tokens += len(tokenizer.tokenize(text))
|
||||
|
||||
return total_characters / total_tokens
|
||||
|
||||
|
||||
def print_trainable_parameters(model):
|
||||
"""
|
||||
Prints the number of trainable parameters in the model.
|
||||
"""
|
||||
trainable_params = 0
|
||||
all_param = 0
|
||||
for _, param in model.named_parameters():
|
||||
all_param += param.numel()
|
||||
if param.requires_grad:
|
||||
trainable_params += param.numel()
|
||||
print(
|
||||
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
|
||||
)
|
||||
|
||||
|
||||
def prepare_sample_text(example):
|
||||
"""Prepare the text from a sample of the dataset."""
|
||||
text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}"
|
||||
return text
|
||||
|
||||
|
||||
def create_datasets(tokenizer, args, seed=None):
|
||||
dataset = load_dataset(
|
||||
args.dataset_name,
|
||||
data_dir=args.subset,
|
||||
split=args.split,
|
||||
use_auth_token=True,
|
||||
num_proc=args.num_workers if not args.streaming else None,
|
||||
streaming=args.streaming,
|
||||
)
|
||||
if args.streaming:
|
||||
print("Loading the dataset in streaming mode")
|
||||
valid_data = dataset.take(args.size_valid_set)
|
||||
train_data = dataset.skip(args.size_valid_set)
|
||||
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=seed)
|
||||
else:
|
||||
dataset = dataset.train_test_split(test_size=0.005, seed=seed)
|
||||
train_data = dataset["train"]
|
||||
valid_data = dataset["test"]
|
||||
print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
|
||||
|
||||
chars_per_token = chars_token_ratio(train_data, tokenizer)
|
||||
print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
|
||||
|
||||
train_dataset = ConstantLengthDataset(
|
||||
tokenizer,
|
||||
train_data,
|
||||
formatting_func=prepare_sample_text,
|
||||
infinite=True,
|
||||
seq_length=args.seq_length,
|
||||
chars_per_token=chars_per_token,
|
||||
)
|
||||
valid_dataset = ConstantLengthDataset(
|
||||
tokenizer,
|
||||
valid_data,
|
||||
formatting_func=prepare_sample_text,
|
||||
infinite=False,
|
||||
seq_length=args.seq_length,
|
||||
chars_per_token=chars_per_token,
|
||||
)
|
||||
return train_dataset, valid_dataset
|
||||
|
||||
|
||||
bnb_config = None
|
||||
if script_args.use_bnb:
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
script_args.model_name,
|
||||
quantization_config=bnb_config,
|
||||
device_map={"": Accelerator().local_process_index},
|
||||
trust_remote_code=True,
|
||||
use_auth_token=True,
|
||||
)
|
||||
base_model.config.use_cache = False
|
||||
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training
|
||||
|
||||
train_dataset, eval_dataset = create_datasets(tokenizer, script_args, seed=training_args.seed)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=base_model,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
peft_config=peft_config,
|
||||
max_length=None,
|
||||
formatting_func=prepare_sample_text,
|
||||
processing_class=tokenizer,
|
||||
args=training_args,
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(training_args.output_dir)
|
||||
|
||||
output_dir = os.path.join(training_args.output_dir, "final_checkpoint")
|
||||
trainer.model.save_pretrained(output_dir)
|
||||
|
||||
# Free memory for merging weights
|
||||
del base_model
|
||||
if is_torch_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
elif is_torch_npu_available():
|
||||
torch.npu.empty_cache()
|
||||
else:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", dtype=torch.bfloat16)
|
||||
model = model.merge_and_unload()
|
||||
|
||||
output_merged_dir = os.path.join(training_args.output_dir, "final_merged_checkpoint")
|
||||
model.save_pretrained(output_merged_dir, safe_serialization=True)
|
@ -1,7 +0,0 @@
# De-toxifying language models

To run this code, do the following:

```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file {CONFIG} examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py --log_with wandb
```
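To sanity-check a detoxified checkpoint afterwards, the `toxicity` measurement from the `evaluate` library can score a few generations directly; this is only an illustrative sketch (the example strings are placeholders) and is independent of the full evaluation script further below:

```python
import evaluate

# The "toxicity" measurement scores each text with a hate-speech classifier
toxicity = evaluate.load("toxicity", module_type="measurement")

# Placeholder generations; in practice these would come from the detoxified model
generations = [
    "I hope you have a wonderful day.",
    "Thanks for the detailed and helpful answer.",
]

scores = toxicity.compute(predictions=generations)["toxicity"]
print(f"mean toxicity: {sum(scores) / len(scores):.4f}")
```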
@ -1,146 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, is_torch_npu_available, is_torch_xpu_available
|
||||
|
||||
|
||||
toxicity = evaluate.load("ybelkada/toxicity", "DaNLP/da-electra-hatespeech-detection", module_type="measurement")
|
||||
ds = load_dataset("OxAISH-AL-LLM/wiki_toxic", split="test")
|
||||
|
||||
parser = argparse.ArgumentParser(description="Evaluate de-toxified models")
|
||||
parser.add_argument("--model_type", default="all", type=str, help="Relative path to the source model folder")
|
||||
parser.add_argument("--output_file", default="toxicity.csv", type=str, help="Relative path to the source model folder")
|
||||
parser.add_argument("--batch_size", default=64, type=int, help="Batch size")
|
||||
parser.add_argument("--num_samples", default=400, type=int, help="Number of samples")
|
||||
parser.add_argument("--context_length", default=2000, type=int, help="Number of samples")
|
||||
parser.add_argument("--max_new_tokens", default=30, type=int, help="Max new tokens for generation")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
if args.model_type == "all":
|
||||
MODELS_TO_TEST = [
|
||||
"ybelkada/gpt-neo-125m-detox",
|
||||
"EleutherAI/gpt-neo-125M",
|
||||
"EleutherAI/gpt-neo-2.7B",
|
||||
"ybelkada/gpt-neo-2.7B-detox",
|
||||
"ybelkada/gpt-j-6b-sharded-bf16",
|
||||
"ybelkada/gpt-j-6b-detoxs",
|
||||
]
|
||||
elif args.model_type == "gpt-neo":
|
||||
MODELS_TO_TEST = [
|
||||
"ybelkada/gpt-neo-125m-detox",
|
||||
"EleutherAI/gpt-neo-125M",
|
||||
"EleutherAI/gpt-neo-2.7B",
|
||||
"ybelkada/gpt-neo-2.7B-detox",
|
||||
]
|
||||
elif args.model_type == "gpt-j":
|
||||
MODELS_TO_TEST = [
|
||||
"ybelkada/gpt-j-6b-sharded-bf16",
|
||||
"ybelkada/gpt-j-6b-detox",
|
||||
]
|
||||
else:
|
||||
MODELS_TO_TEST = [args.model_type]
|
||||
NUM_SAMPLES = args.num_samples
|
||||
BATCH_SIZE = args.batch_size
|
||||
output_file = args.output_file
|
||||
max_new_tokens = args.max_new_tokens
|
||||
context_length = args.context_length
|
||||
if is_torch_xpu_available():
|
||||
device = torch.xpu.current_device()
|
||||
elif is_torch_npu_available():
|
||||
device = torch.npu.current_device()
|
||||
else:
|
||||
device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# consider only toxic prompts
|
||||
ds = ds.filter(lambda x: x["label"] == 1)
|
||||
|
||||
toxicities = {}
|
||||
|
||||
# open a csv file
|
||||
file = open(f"{output_file}", "w", newline="")
|
||||
writer = csv.writer(file)
|
||||
# add first rows
|
||||
writer.writerow(["model_id", "mean_toxicity", "std_toxicity"])
|
||||
|
||||
|
||||
for model_id in tqdm(MODELS_TO_TEST):
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, device_map={"": device}, dtype=torch.bfloat16)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.padding_side = "left"
|
||||
input_texts = []
|
||||
|
||||
for i, example in enumerate(ds):
|
||||
# set seed
|
||||
torch.manual_seed(42)
|
||||
|
||||
input_text = example["comment_text"]
|
||||
input_texts.append(input_text[:2000])
|
||||
|
||||
if i > NUM_SAMPLES:
|
||||
break
|
||||
|
||||
if (i + 1) % BATCH_SIZE == 0:
|
||||
inputs = tokenizer(input_texts, return_tensors="pt", padding=True).to(device)
|
||||
inputs.input_ids = inputs.input_ids[:context_length]
|
||||
inputs.attention_mask = inputs.attention_mask[:context_length]
|
||||
outputs = model.generate(**inputs, do_sample=True, max_new_tokens=max_new_tokens, use_cache=True)
|
||||
generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
generated_texts = [
|
||||
generated_text.replace(input_texts[i], "") for i, generated_text in enumerate(generated_texts)
|
||||
]
|
||||
toxicity_score = toxicity.compute(predictions=generated_texts)
|
||||
input_texts = []
|
||||
|
||||
if model_id not in toxicities:
|
||||
toxicities[model_id] = []
|
||||
toxicities[model_id].extend(toxicity_score["toxicity"])
|
||||
|
||||
# last batch
|
||||
inputs = tokenizer(input_texts, return_tensors="pt", padding=True).to(device)
|
||||
outputs = model.generate(**inputs, do_sample=True, max_new_tokens=30)
|
||||
generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
generated_texts = [generated_text.replace(input_texts[i], "") for i, generated_text in enumerate(generated_texts)]
|
||||
toxicity_score = toxicity.compute(predictions=generated_texts)
|
||||
toxicities[model_id].extend(toxicity_score["toxicity"])
|
||||
|
||||
# compute mean & std using np
|
||||
mean = np.mean(toxicities[model_id])
|
||||
std = np.std(toxicities[model_id])
|
||||
|
||||
# save to file
|
||||
writer.writerow([model_id, mean, std])
|
||||
|
||||
# print
|
||||
print(f"Model: {model_id} - Mean: {mean} - Std: {std}")
|
||||
|
||||
model = None
|
||||
if is_torch_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
elif is_torch_npu_available():
|
||||
torch.npu.empty_cache()
|
||||
else:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# close file
|
||||
file.close()
|
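For reference, assuming the evaluation script above is saved as `evaluate-toxicity.py` alongside the other toxicity scripts (the filename is an assumption, not shown in this diff), it can be launched with the `argparse` flags defined at the top of the file, for example:

```shell
python examples/research_projects/toxicity/scripts/evaluate-toxicity.py \
    --model_type gpt-neo \
    --output_file toxicity.csv \
    --batch_size 64 \
    --num_samples 400
```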
@ -1,245 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.optim import Adam
|
||||
from tqdm import tqdm
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
HfArgumentParser,
|
||||
RobertaForSequenceClassification,
|
||||
RobertaTokenizer,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, create_reference_model
|
||||
from trl.core import LengthSampler
|
||||
|
||||
|
||||
tqdm.pandas()
|
||||
|
||||
########################################################################
|
||||
# This is a fully working simple example to use trl with accelerate.
|
||||
#
|
||||
# This example fine-tunes a GPTJ model to generate less toxic contents
|
||||
# by using allenai/real-toxicity-prompts dataset. We use PPO
|
||||
# (proximal policy optimization) to optimize the model.
|
||||
# in any of the following settings (with the same script):
|
||||
# - single CPU or single GPU
|
||||
# - multi GPUS (using PyTorch distributed mode)
|
||||
# - multi GPUS (using DeepSpeed ZeRO-Offload stages 1 & 2)
|
||||
# - fp16 (mixed-precision) or fp32 (normal precision)
|
||||
#
|
||||
# To run it in each of these various modes, first initialize the accelerate
|
||||
# configuration with `accelerate config`
|
||||
#
|
||||
########################################################################
|
||||
|
||||
|
||||
# We first define the configuration of the experiment, defining the model, the dataset,
|
||||
# the training parameters, and the PPO parameters.
|
||||
# Check the default arguments in the `PPOConfig` class for more details.
|
||||
# If you want to log with tensorboard, add the kwarg
|
||||
# `project_kwargs={"logging_dir": PATH_TO_LOGS}` to the PPOConfig.
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The name of the Causal LM model we wish to fine-tune with PPO
|
||||
"""
|
||||
|
||||
# NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode
|
||||
# models like gpt-neo* models are more suitable.
|
||||
model_name: Optional[str] = field(default="ybelkada/gpt-j-6b-sharded-bf16", metadata={"help": "the model name"})
|
||||
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
|
||||
learning_rate: Optional[float] = field(default=(1.47e-5) * 2, metadata={"help": "the learning rate"})
|
||||
mini_batch_size: Optional[int] = field(default=4, metadata={"help": "the PPO minibatch size"})
|
||||
batch_size: Optional[int] = field(default=16, metadata={"help": "the batch size"})
|
||||
gradient_accumulation_steps: Optional[int] = field(
|
||||
default=1, metadata={"help": "the number of gradient accumulation steps"}
|
||||
)
|
||||
model_save_path: Optional[str] = field(
|
||||
default="./gpt-j-6B-detoxified-long-context-26-shl-1e4-final",
|
||||
metadata={"help": "the path to save the model"},
|
||||
)
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
config = PPOConfig(
|
||||
model_name=script_args.model_name,
|
||||
learning_rate=script_args.learning_rate,
|
||||
log_with=script_args.log_with,
|
||||
ppo_epochs=100,
|
||||
mini_batch_size=script_args.mini_batch_size,
|
||||
batch_size=script_args.batch_size,
|
||||
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
||||
)
|
||||
|
||||
|
||||
# Below is an example function to build the dataset. In our case, we use the IMDB dataset
|
||||
# from the `datasets` library. One should customize this function to train the model on
|
||||
# its own dataset.
|
||||
def build_dataset(
|
||||
config, dataset_name="allenai/real-toxicity-prompts", input_min_text_length=5, input_max_text_length=10
|
||||
):
|
||||
"""
|
||||
Build dataset for training. This builds the dataset from `load_dataset`, one should
|
||||
customize this function to train the model on its own dataset.
|
||||
|
||||
Args:
|
||||
config (`PPOConfig`):
|
||||
The configuration of the PPO training.
|
||||
dataset_name (`str`):
|
||||
The name of the dataset to be loaded.
|
||||
input_min_text_length (`int`, defaults to 5):
|
||||
The minimum length of the input text.
|
||||
input_max_text_length (`int`, defaults to 10):
|
||||
The maximum length of the input text.
|
||||
|
||||
Returns:
|
||||
dataloader (`torch.utils.data.DataLoader`):
|
||||
The dataloader for the dataset.
|
||||
"""
|
||||
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
ds = load_dataset(dataset_name, split="train")
|
||||
|
||||
def filter_fn(sample):
|
||||
toxicity = sample["prompt"]["toxicity"]
|
||||
return toxicity is not None and toxicity > 0.3
|
||||
|
||||
ds = ds.filter(filter_fn, batched=False)
|
||||
|
||||
input_size = LengthSampler(input_min_text_length, input_max_text_length)
|
||||
|
||||
def tokenize(sample):
|
||||
prompt = sample["prompt"]["text"]
|
||||
continuation = sample["continuation"]["text"]
|
||||
|
||||
sample["input_ids"] = tokenizer.encode(prompt + continuation)[: input_size()]
|
||||
sample["query"] = tokenizer.decode(sample["input_ids"])
|
||||
return sample
|
||||
|
||||
ds = ds.map(tokenize, batched=False)
|
||||
ds.set_format(type="torch")
|
||||
|
||||
ds = ds.train_test_split(test_size=0.2, shuffle=False)["train"]
|
||||
|
||||
return ds
|
||||
|
||||
|
||||
# We retrieve the dataloader by calling the `build_dataset` function.
|
||||
min_input_length = 30
|
||||
max_input_length = 40
|
||||
dataset = build_dataset(config, input_min_text_length=min_input_length, input_max_text_length=max_input_length)
|
||||
|
||||
|
||||
def collator(data):
|
||||
return {key: [d[key] for d in data] for key in data[0]}
|
||||
|
||||
|
||||
# set seed before initializing value head for deterministic eval
|
||||
set_seed(config.seed)
|
||||
|
||||
# Now let's build the model, the reference model, and the tokenizer. We first load the model
|
||||
# in bfloat16 to save memory using `transformers`.
|
||||
model = AutoModelForCausalLM.from_pretrained(config.model_name, dtype=torch.bfloat16)
|
||||
# And then we pass the loaded model to `AutoModelForCausalLMWithValueHead`.
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
|
||||
# We create a reference model by sharing 20 layers
|
||||
ref_model = create_reference_model(model, num_shared_layers=20)
|
||||
|
||||
# We make sure to use `Adam` optimizer on the model parameters that require gradients.
|
||||
optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)
|
||||
|
||||
# GPT-2 / GPT-J tokenizer has a pad token, but it is not eos_token by default. We need to set it to eos_token.
|
||||
# only for this model.
|
||||
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
|
||||
ppo_trainer = PPOTrainer(
|
||||
config,
|
||||
model,
|
||||
ref_model=ref_model,
|
||||
tokenizer=tokenizer,
|
||||
dataset=dataset,
|
||||
data_collator=collator,
|
||||
optimizer=optimizer,
|
||||
)
|
||||
|
||||
# We then build the reward pipeline, we will use the toxicity model to compute the reward.
|
||||
# We first load the toxicity model and tokenizer.
|
||||
toxicity_model_id = "facebook/roberta-hate-speech-dynabench-r4-target"
|
||||
toxicity_tokenizer = RobertaTokenizer.from_pretrained(toxicity_model_id)
|
||||
# We load the toxicity model in fp16 to save memory.
|
||||
toxicity_model = RobertaForSequenceClassification.from_pretrained(toxicity_model_id, dtype=torch.float16).to(
|
||||
ppo_trainer.accelerator.device
|
||||
)
|
||||
|
||||
|
||||
# We then define the arguments to pass to the `generate` function. These arguments
|
||||
# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
|
||||
# the `generate` function of the trained model.
|
||||
generation_kwargs = {
|
||||
"min_length": -1,
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": tokenizer.eos_token_id,
|
||||
}
|
||||
output_min_length = 20
|
||||
output_max_length = 30
|
||||
output_length_sampler = LengthSampler(output_min_length, output_max_length)
|
||||
|
||||
model_save_path = script_args.model_save_path
|
||||
|
||||
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
|
||||
query_tensors = batch["input_ids"]
|
||||
|
||||
# Get response from the policy model
|
||||
response_tensors = []
|
||||
for query in query_tensors:
|
||||
gen_len = output_length_sampler()
|
||||
generation_kwargs["max_new_tokens"] = gen_len
|
||||
response = ppo_trainer.generate(query, **generation_kwargs)
|
||||
response_tensors.append(response.squeeze()[-gen_len:])
|
||||
batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
|
||||
|
||||
# Compute sentiment score
|
||||
texts = batch["response"]
|
||||
toxicity_inputs = toxicity_tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(
|
||||
ppo_trainer.accelerator.device
|
||||
)
|
||||
logits = toxicity_model(**toxicity_inputs).logits.float()
|
||||
toxicity_labels = (logits[:, 0]).tolist()
|
||||
|
||||
rewards = [torch.tensor(output) for output in toxicity_labels]
|
||||
|
||||
# Run PPO step
|
||||
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
|
||||
ppo_trainer.log_stats(stats, batch, rewards)
|
||||
|
||||
# Save model every 100 epochs
|
||||
if epoch % 100 == 0:
|
||||
if ppo_trainer.accelerator.is_main_process:
|
||||
ppo_trainer.save_pretrained(model_save_path)
|
@ -70,8 +70,6 @@ import os

import torch
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify

from trl import (
    GRPOConfig,
@ -83,7 +81,7 @@ from trl import (
    get_peft_config,
    get_quantization_config,
)
from trl.rewards import think_format_reward
from trl.rewards import accuracy_reward, think_format_reward


# Enable logging in a Hugging Face Space
@ -149,54 +147,6 @@ if __name__ == "__main__":
    train_dataset = dataset["train"]
    eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None

    ################
    # Reward Function for Training
    ################
    def accuracy_reward(completions, solution: list[str], **kwargs):
        """Reward function that checks if the completion matches the ground truth.
        - If both gold and prediction are parseable → use math verification.
        - If not parseable → compare as normalized text.
        """
        rewards = []
        contents = [completion[0]["content"] for completion in completions]
        for content, sol in zip(contents, solution):
            try:
                gold_parsed = parse(sol, extraction_mode="first_match")
            except Exception:
                gold_parsed = []

            if len(gold_parsed) != 0:
                # Try parsing predicted answer too
                try:
                    answer_parsed = parse(
                        content,
                        extraction_config=[
                            LatexExtractionConfig(
                                normalization_config=NormalizationConfig(
                                    nits=False,
                                    malformed_operators=False,
                                    basic_latex=True,
                                    boxed="all",
                                    units=True,
                                ),
                                boxed_match_priority=0,
                                try_extract_without_anchor=False,
                            )
                        ],
                        extraction_mode="first_match",
                    )
                    reward = float(verify(gold_parsed, answer_parsed))
                except Exception as e:
                    print(f"verify failed: {e}, answer: {content}, gold: {sol}")
                    reward = None
            else:
                # fallback to text match
                reward = float(content.strip().lower() == sol.strip().lower())

            rewards.append(reward)

        return rewards

    ################
    # Training
    ################
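Across these example diffs the inline reward function is deleted in favour of the shared `trl.rewards.accuracy_reward`. A minimal sanity check of the imported function (it needs the optional math-verify dependency added to `pyproject.toml` further down; the expected values mirror the new tests in `tests/test_rewards.py` later in this diff):

```python
from trl.rewards import accuracy_reward  # requires the optional math-verify dependency

completions = [
    [{"content": r"\boxed{\frac{63}{400}}"}],  # matches the gold answer
    [{"content": r"\boxed{\frac{64}{400}}"}],  # does not match
]
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]

print(accuracy_reward(completions, solutions))  # expected: [1.0, 0.0]
```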
@ -57,8 +57,6 @@ import os
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
|
||||
from trl import (
|
||||
GRPOConfig,
|
||||
@ -70,7 +68,7 @@ from trl import (
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
from trl.rewards import think_format_reward
|
||||
from trl.rewards import accuracy_reward, think_format_reward
|
||||
|
||||
|
||||
# Enable logging in a Hugging Face Space
|
||||
@ -120,54 +118,6 @@ if __name__ == "__main__":
|
||||
train_dataset = train_dataset.remove_columns(["messages", "problem"])
|
||||
eval_dataset = eval_dataset.remove_columns(["messages", "problem"])
|
||||
|
||||
################
|
||||
# Reward Function for Training
|
||||
################
|
||||
def accuracy_reward(completions, solution: list[str], **kwargs):
|
||||
"""Reward function that checks if the completion matches the ground truth.
|
||||
- If both gold and prediction are parseable → use math verification.
|
||||
- If not parseable → compare as normalized text.
|
||||
"""
|
||||
rewards = []
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
for content, sol in zip(contents, solution):
|
||||
try:
|
||||
gold_parsed = parse(sol, extraction_mode="first_match")
|
||||
except Exception:
|
||||
gold_parsed = []
|
||||
|
||||
if len(gold_parsed) != 0:
|
||||
# Try parsing predicted answer too
|
||||
try:
|
||||
answer_parsed = parse(
|
||||
content,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
reward = float(verify(gold_parsed, answer_parsed))
|
||||
except Exception as e:
|
||||
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
|
||||
reward = None
|
||||
else:
|
||||
# fallback to text match
|
||||
reward = float(content.strip().lower() == sol.strip().lower())
|
||||
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
|
@ -57,8 +57,6 @@ import os
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
|
||||
from trl import (
|
||||
GRPOConfig,
|
||||
@ -70,7 +68,7 @@ from trl import (
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
from trl.rewards import think_format_reward
|
||||
from trl.rewards import accuracy_reward, think_format_reward
|
||||
|
||||
|
||||
# Enable logging in a Hugging Face Space
|
||||
@ -136,54 +134,6 @@ if __name__ == "__main__":
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
|
||||
|
||||
################
|
||||
# Reward Function for Training
|
||||
################
|
||||
def accuracy_reward(completions, solution: list[str], **kwargs):
|
||||
"""Reward function that checks if the completion matches the ground truth.
|
||||
- If both gold and prediction are parseable → use math verification.
|
||||
- If not parseable → compare as normalized text.
|
||||
"""
|
||||
rewards = []
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
for content, sol in zip(contents, solution):
|
||||
try:
|
||||
gold_parsed = parse(sol, extraction_mode="first_match")
|
||||
except Exception:
|
||||
gold_parsed = []
|
||||
|
||||
if len(gold_parsed) != 0:
|
||||
# Try parsing predicted answer too
|
||||
try:
|
||||
answer_parsed = parse(
|
||||
content,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
reward = float(verify(gold_parsed, answer_parsed))
|
||||
except Exception as e:
|
||||
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
|
||||
reward = None
|
||||
else:
|
||||
# fallback to text match
|
||||
reward = float(content.strip().lower() == sol.strip().lower())
|
||||
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
|
@ -87,8 +87,6 @@ import os
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import load_dataset
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
from transformers import AutoConfig, AutoProcessor, GenerationConfig
|
||||
|
||||
from trl import (
|
||||
@ -102,7 +100,7 @@ from trl import (
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
from trl.rewards import think_format_reward
|
||||
from trl.rewards import accuracy_reward, think_format_reward
|
||||
|
||||
|
||||
# Enable logging in a Hugging Face Space
|
||||
@ -192,54 +190,6 @@ if __name__ == "__main__":
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
|
||||
|
||||
################
|
||||
# Reward Function for Training (same as GRPO VLM)
|
||||
################
|
||||
def accuracy_reward(completions, solution: list[str], **kwargs):
|
||||
"""Reward function that checks if the completion matches the ground truth.
|
||||
- If both gold and prediction are parseable → use math verification.
|
||||
- If not parseable → compare as normalized text.
|
||||
"""
|
||||
rewards = []
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
for content, sol in zip(contents, solution):
|
||||
try:
|
||||
gold_parsed = parse(sol, extraction_mode="first_match")
|
||||
except Exception:
|
||||
gold_parsed = []
|
||||
|
||||
if len(gold_parsed) != 0:
|
||||
# Try parsing predicted answer too
|
||||
try:
|
||||
answer_parsed = parse(
|
||||
content,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
reward = float(verify(gold_parsed, answer_parsed))
|
||||
except Exception as e:
|
||||
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
|
||||
reward = None
|
||||
else:
|
||||
# fallback to text match
|
||||
reward = float(content.strip().lower() == sol.strip().lower())
|
||||
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
|
@ -33,12 +33,10 @@ import os
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
from peft import LoraConfig
|
||||
|
||||
from trl import RLOOConfig, RLOOTrainer
|
||||
from trl.rewards import think_format_reward
|
||||
from trl.rewards import accuracy_reward, think_format_reward
|
||||
|
||||
|
||||
# Enable logging in a Hugging Face Space
|
||||
@ -67,52 +65,6 @@ def main():
|
||||
train_dataset = train_dataset.map(make_conversation, remove_columns=["messages", "problem"])
|
||||
eval_dataset = eval_dataset.map(make_conversation, remove_columns=["messages", "problem"])
|
||||
|
||||
# Reward function for training
|
||||
def accuracy_reward(completions, solution: list[str], **kwargs):
|
||||
"""Reward function that checks if the completion matches the ground truth.
|
||||
- If both gold and prediction are parseable → use math verification.
|
||||
- If not parseable → compare as normalized text.
|
||||
"""
|
||||
rewards = []
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
for content, sol in zip(contents, solution):
|
||||
try:
|
||||
gold_parsed = parse(sol, extraction_mode="first_match")
|
||||
except Exception:
|
||||
gold_parsed = []
|
||||
|
||||
if len(gold_parsed) != 0:
|
||||
# Try parsing predicted answer too
|
||||
try:
|
||||
answer_parsed = parse(
|
||||
content,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
reward = float(verify(gold_parsed, answer_parsed))
|
||||
except Exception as e:
|
||||
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
|
||||
reward = None
|
||||
else:
|
||||
# fallback to text match
|
||||
reward = float(content.strip().lower() == sol.strip().lower())
|
||||
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
# Training
|
||||
training_args = RLOOConfig(
|
||||
output_dir="Qwen3-0.6B-RLOO",
|
||||
|
@ -70,8 +70,6 @@ import os
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
|
||||
from trl import (
|
||||
ModelConfig,
|
||||
@ -83,7 +81,7 @@ from trl import (
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
from trl.rewards import think_format_reward
|
||||
from trl.rewards import accuracy_reward, think_format_reward
|
||||
|
||||
|
||||
# Enable logging in a Hugging Face Space
|
||||
@ -149,54 +147,6 @@ if __name__ == "__main__":
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
|
||||
|
||||
################
|
||||
# Reward Function for Training
|
||||
################
|
||||
def accuracy_reward(completions, solution: list[str], **kwargs):
|
||||
"""Reward function that checks if the completion matches the ground truth.
|
||||
- If both gold and prediction are parseable → use math verification.
|
||||
- If not parseable → compare as normalized text.
|
||||
"""
|
||||
rewards = []
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
for content, sol in zip(contents, solution):
|
||||
try:
|
||||
gold_parsed = parse(sol, extraction_mode="first_match")
|
||||
except Exception:
|
||||
gold_parsed = []
|
||||
|
||||
if len(gold_parsed) != 0:
|
||||
# Try parsing predicted answer too
|
||||
try:
|
||||
answer_parsed = parse(
|
||||
content,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
reward = float(verify(gold_parsed, answer_parsed))
|
||||
except Exception as e:
|
||||
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
|
||||
reward = None
|
||||
else:
|
||||
# fallback to text match
|
||||
reward = float(content.strip().lower() == sol.strip().lower())
|
||||
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
|
@ -89,6 +89,9 @@ vlm = [
    "torchvision",
    "num2words==0.5.14"
]
math_verify = [
    "math-verify>=0.5.2",
]
dev = [
    # bco
    "scikit-learn",
@ -23,9 +23,8 @@ def cleanup_gpu():
    """
    Automatically cleanup GPU memory after each test.

    This fixture helps prevent CUDA out of memory errors when running tests in parallel
    with pytest-xdist by ensuring models and tensors are properly garbage collected
    and GPU memory caches are cleared between tests.
    This fixture helps prevent CUDA out of memory errors when running tests in parallel with pytest-xdist by ensuring
    models and tensors are properly garbage collected and GPU memory caches are cleared between tests.
    """
    yield
    # Cleanup after test
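The hunk above only reflows the fixture's docstring; its body is not part of this diff. For orientation, a cleanup fixture of this kind is typically a sketch along these lines (the exact body in `tests/conftest.py` may differ):

```python
import gc

import pytest
import torch


@pytest.fixture(autouse=True)
def cleanup_gpu():
    """Illustrative sketch of a GPU-cleanup fixture; not the exact TRL implementation."""
    yield
    # Cleanup after test: drop unreferenced objects and release cached GPU memory
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```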
@ -118,6 +118,7 @@ class TestGRPOTrainerSlow(TrlTestCase):
|
||||
max_completion_length=self.max_length,
|
||||
report_to="none",
|
||||
logging_strategy="no",
|
||||
loss_type="bnpo", # liger-kernel does not support "dapo" default; see https://github.com/linkedin/Liger-Kernel/issues/620
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
@ -328,11 +329,11 @@ class TestGRPOTrainerSlow(TrlTestCase):
|
||||
assert lora_params_changed, "No LoRA parameters were updated during training."
|
||||
|
||||
except torch.OutOfMemoryError as e:
|
||||
self.skipTest(f"Skipping VLM training test due to insufficient GPU memory: {e}")
|
||||
pytest.skip(f"Skipping VLM training test due to insufficient GPU memory: {e}")
|
||||
except Exception as e:
|
||||
# Check for other memory-related errors
|
||||
if any(keyword in str(e).lower() for keyword in ["memory", "cuda", "out of memory", "insufficient"]):
|
||||
self.skipTest(f"Skipping VLM training test due to hardware constraints: {e}")
|
||||
pytest.skip(f"Skipping VLM training test due to hardware constraints: {e}")
|
||||
else:
|
||||
raise
|
||||
|
||||
@ -473,11 +474,11 @@ class TestGRPOTrainerSlow(TrlTestCase):
|
||||
"decrease gpu memory",
|
||||
]
|
||||
):
|
||||
self.skipTest(f"Skipping vLLM colocate test due to hardware constraints: {e}")
|
||||
pytest.skip(f"Skipping vLLM colocate test due to hardware constraints: {e}")
|
||||
elif "KeyError" in str(e) and "RANK" in str(e):
|
||||
self.skipTest(f"Skipping vLLM colocate test due to environment setup issues: {e}")
|
||||
pytest.skip(f"Skipping vLLM colocate test due to environment setup issues: {e}")
|
||||
elif "ValueError" in str(e) and "memory" in str(e).lower():
|
||||
self.skipTest(f"Skipping vLLM colocate test due to memory constraints: {e}")
|
||||
pytest.skip(f"Skipping vLLM colocate test due to memory constraints: {e}")
|
||||
else:
|
||||
raise
|
||||
finally:
|
||||
@ -540,11 +541,11 @@ class TestGRPOTrainerSlow(TrlTestCase):
|
||||
"decrease gpu memory",
|
||||
]
|
||||
):
|
||||
self.skipTest(f"Skipping vLLM training test due to hardware constraints: {e}")
|
||||
pytest.skip(f"Skipping vLLM training test due to hardware constraints: {e}")
|
||||
elif "KeyError" in str(e) and "RANK" in str(e):
|
||||
self.skipTest(f"Skipping vLLM training test due to environment setup issues: {e}")
|
||||
pytest.skip(f"Skipping vLLM training test due to environment setup issues: {e}")
|
||||
elif "ValueError" in str(e) and "memory" in str(e).lower():
|
||||
self.skipTest(f"Skipping vLLM training test due to memory constraints: {e}")
|
||||
pytest.skip(f"Skipping vLLM training test due to memory constraints: {e}")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
@ -412,12 +412,12 @@ class TestSFTTrainerSlow(TrlTestCase):
            eval_dataset=self.eval_dataset,
        )

        # Register cleanup now that we have the trainer
        self.addCleanup(cleanup_liger_patches, trainer)

        trainer.train()

        release_memory(trainer.model, trainer)
        # Ensure cleanup of liger patches after the test
        try:
            trainer.train()
            release_memory(trainer.model, trainer)
        finally:
            cleanup_liger_patches(trainer)

    @parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
    @require_torch_accelerator
@ -396,6 +396,29 @@ class TestApplyChatTemplate(TrlTestCase):
        assert isinstance(result["label"], bool)
        assert result["label"] == example["label"]

    def test_apply_chat_template_with_chat_template_kwargs(self):
        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")

        example = {
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            # with this tokenizer, when you pass enable_thinking=False, it will add "<think>\n\n</think>\n\n"
            "chat_template_kwargs": {"enable_thinking": False},
        }
        result = apply_chat_template(example, tokenizer)

        # docstyle-ignore
        expected = textwrap.dedent("""\
            <|im_start|>user
            What color is the sky?<|im_end|>
            <|im_start|>assistant
            <think>

            </think>

            """)

        assert result["prompt"] == expected

    def test_apply_chat_template_with_tools(self):
        tokenizer = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2")

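The per-example `chat_template_kwargs` key exercised by this test is what the `trl/data_utils.py` hunks later in this diff forward to the tokenizer's chat template. Outside the test suite, a dataset row can opt out of thinking in the same way (model id reused from the test above; this is only a usage sketch):

```python
from transformers import AutoTokenizer

from trl.data_utils import apply_chat_template

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")
example = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}],
    "chat_template_kwargs": {"enable_thinking": False},  # forwarded to the chat template
}
print(apply_chat_template(example, tokenizer)["prompt"])
```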
@ -14,6 +14,7 @@

from typing import Callable

import pytest
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

@ -23,6 +24,7 @@ from trl.models.utils import ChatMlSpecialTokens, clone_chat_template, setup_cha
from .testing_utils import TrlTestCase


@pytest.mark.filterwarnings("ignore::FutureWarning")
class TestDatasetFormatting(TrlTestCase):
    def setup_method(self):
        self.llama_tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-MistralForCausalLM-0.1")
@ -33,12 +33,18 @@ from transformers import (
|
||||
from transformers.testing_utils import (
|
||||
get_device_properties,
|
||||
require_liger_kernel,
|
||||
require_torch_gpu_if_bnb_not_multi_backend_enabled,
|
||||
)
|
||||
|
||||
from trl import DPOConfig, DPOTrainer, FDivergenceType
|
||||
|
||||
from .testing_utils import TrlTestCase, require_bitsandbytes, require_no_wandb, require_peft, require_vision
|
||||
from .testing_utils import (
|
||||
TrlTestCase,
|
||||
require_bitsandbytes,
|
||||
require_no_wandb,
|
||||
require_peft,
|
||||
require_torch_gpu_if_bnb_not_multi_backend_enabled,
|
||||
require_vision,
|
||||
)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
@ -636,6 +642,7 @@ class TestDPOTrainer(TrlTestCase):
|
||||
def test_dpo_lora_bf16_autocast_llama(self):
|
||||
# Note this test only works on compute capability > 7 GPU devices
|
||||
from peft import LoraConfig
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
@ -649,7 +656,9 @@ class TestDPOTrainer(TrlTestCase):
|
||||
)
|
||||
|
||||
# lora model
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, quantization_config=BitsAndBytesConfig(load_in_4bit=True)
|
||||
)
|
||||
|
||||
training_args = DPOConfig(
|
||||
output_dir=self.tmp_dir,
|
||||
@ -719,6 +728,7 @@ class TestDPOTrainer(TrlTestCase):
|
||||
)
|
||||
def test_dpo_lora_bf16_autocast(self, loss_type, pre_compute, gen_during_eval):
|
||||
from peft import LoraConfig
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
@ -729,7 +739,9 @@ class TestDPOTrainer(TrlTestCase):
|
||||
)
|
||||
|
||||
# lora model
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_id, load_in_4bit=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_id, quantization_config=BitsAndBytesConfig(load_in_4bit=True)
|
||||
)
|
||||
|
||||
training_args = DPOConfig(
|
||||
output_dir=self.tmp_dir,
|
||||
@ -1410,6 +1422,7 @@ class TestDPOVisionTrainer(TrlTestCase):
|
||||
# ("trl-internal-testing/tiny-PaliGemmaForConditionalGeneration",),
|
||||
("trl-internal-testing/tiny-LlavaForConditionalGeneration",),
|
||||
("trl-internal-testing/tiny-LlavaNextForConditionalGeneration",),
|
||||
("trl-internal-testing/tiny-Gemma3ForConditionalGeneration",),
|
||||
]
|
||||
)
|
||||
def test_vdpo_trainer(self, model_id):
|
||||
|
@ -259,7 +259,7 @@ class TestGKDTrainer(TrlTestCase):

        # Ensure liger fused JSD path is enabled; if not, skip (runtime may lack system libs)
        if not getattr(trainer, "use_liger_gkd_loss", False):
            self.skipTest("Liger fused JSD not enabled at runtime; skipping fused-loss assertion")
            pytest.skip("Liger fused JSD not enabled at runtime; skipping fused-loss assertion")

        trainer.train()

@ -1471,47 +1471,6 @@ class TestGRPOTrainer(TrlTestCase):
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
|
||||
|
||||
@require_vision
|
||||
def test_training_vlm_and_prompt_truncation(self):
|
||||
# If not handled properly, prompt truncation may truncate image token
|
||||
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
|
||||
|
||||
def reward_func(completions, **kwargs):
|
||||
"""Reward function that rewards longer completions."""
|
||||
return [float(len(completion[0]["content"])) for completion in completions]
|
||||
|
||||
training_args = GRPOConfig(
|
||||
output_dir=self.tmp_dir,
|
||||
learning_rate=0.1, # increase the learning rate to speed up the test
|
||||
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
|
||||
num_generations=3, # reduce the number of generations to reduce memory usage
|
||||
max_completion_length=8, # reduce the completion length to reduce memory usage
|
||||
max_prompt_length=18,
|
||||
report_to="none",
|
||||
)
|
||||
trainer = GRPOTrainer(
|
||||
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
|
||||
reward_funcs=reward_func,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
trainer.train()
|
||||
|
||||
assert trainer.state.log_history[-1]["train_loss"] is not None
|
||||
|
||||
# Check that the params have changed
|
||||
# Because of the way the tiny models are initialized, the gradient does not flow properly through the
|
||||
# vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
|
||||
params_to_skip = ("model.visual.",)
|
||||
for n, param in previous_trainable_params.items():
|
||||
if n.startswith(params_to_skip):
|
||||
continue
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",),
|
||||
|
@ -61,7 +61,7 @@ class TestJudges(TrlTestCase):

    @require_llm_blender
    @pytest.mark.skipif(
        sys.version_info == (3, 13, 8), reason="Python 3.13.8 has a bug in inspect.BlockFinder (cpython GH-139783)"
        sys.version_info[:3] == (3, 13, 8), reason="Python 3.13.8 has a bug in inspect.BlockFinder (cpython GH-139783)"
    )
    def test_pair_rm_judge(self):
        judge = self.load_pair_rm_judge()
@ -73,7 +73,7 @@ class TestJudges(TrlTestCase):

    @require_llm_blender
    @pytest.mark.skipif(
        sys.version_info == (3, 13, 8), reason="Python 3.13.8 has a bug in inspect.BlockFinder (cpython GH-139783)"
        sys.version_info[:3] == (3, 13, 8), reason="Python 3.13.8 has a bug in inspect.BlockFinder (cpython GH-139783)"
    )
    def test_pair_rm_judge_return_scores(self):
        judge = self.load_pair_rm_judge()
@ -16,12 +16,11 @@ import os

import torch
from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_torch_gpu_if_bnb_not_multi_backend_enabled
from transformers.utils import is_peft_available

from trl import AutoModelForCausalLMWithValueHead

from .testing_utils import TrlTestCase, require_peft
from .testing_utils import TrlTestCase, require_peft, require_torch_gpu_if_bnb_not_multi_backend_enabled


if is_peft_available():
@ -102,9 +101,12 @@ class TestPeftModel(TrlTestCase):
        Simply creates a peft model and checks that it can be loaded.
        """
        from bitsandbytes.nn import Linear8bitLt
        from transformers import BitsAndBytesConfig

        trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(
            self.causal_lm_model_id, peft_config=self.lora_config, load_in_8bit=True
            self.causal_lm_model_id,
            peft_config=self.lora_config,
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        )
        # Check that the number of trainable parameters is correct
        nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
@ -112,7 +114,7 @@ class TestPeftModel(TrlTestCase):
        assert isinstance(trl_model.pretrained_model.model.model.layers[0].mlp.gate_proj, Linear8bitLt)

        causal_lm_model = AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id, load_in_8bit=True, device_map="auto"
            self.causal_lm_model_id, quantization_config=BitsAndBytesConfig(load_in_8bit=True), device_map="auto"
        )
        trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config)
        # Check that the number of trainable parameters is correct
@ -13,9 +13,9 @@
# limitations under the License.


from trl.rewards import get_soft_overlong_punishment, think_format_reward
from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward

from .testing_utils import TrlTestCase
from .testing_utils import TrlTestCase, require_math_latex


class TestThinkFormatReward(TrlTestCase):
@ -85,3 +85,60 @@ class TestSoftOverlongPunishmentReward:
        completion_ids = [[1] * 90]  # 90 is between 80 and 100
        rewards = reward_fn(completion_ids)
        assert round(abs(rewards[0] - -0.5), 4) == 0


class TestAccuracyReward:
    @require_math_latex
    def test_accuracy_reward_correct_answer(self):
        """Test accuracy_reward with a correct answer."""
        completion = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{63}{400}}"}]]
        solution = [r"\frac{63}{400}", "63/400"]
        rewards = accuracy_reward(completion, solution)
        assert rewards[0] == 1.0
        assert rewards[1] == 1.0

    @require_math_latex
    def test_accuracy_reward_wrong_answer(self):
        """Test accuracy_reward with an incorrect answer."""
        completion = [[{"content": r"\boxed{\frac{64}{400}}"}]]
        solution = [r"\frac{63}{400}"]
        rewards = accuracy_reward(completion, solution)
        assert rewards[0] == 0.0

    @require_math_latex
    def test_accuracy_reward_wrong_answer_no_latex(self):
        """Test accuracy_reward with an incorrect answer and gold solution with no latex."""
        completion = [[{"content": r"\boxed{3}"}]]
        solution = ["6"]
        rewards = accuracy_reward(completion, solution)
        assert rewards[0] == 0.0

    @require_math_latex
    def test_accuracy_reward_unparseable_gold(self):
        """Test accuracy_reward with an unparseable gold solution."""
        completion = [
            [{"content": "Answer is forty two."}],
            [{"content": "Some other content."}],
            [{"content": r"Answer is \boxed{42}."}],
            [{"content": r"Answer is \boxed{\mathbf{42}}."}],  # Make response bold
            [{"content": r"Answer is \boxed{\textbf{42}}."}],  # Different latex command for bold
            [{"content": r"Answer is \boxed{42}."}],
            [{"content": r"Answer is \boxed{42.3456}."}],
        ]
        solution = [
            "Answer is forty two.",
            "Answer is forty three.",
            "Answer is 42.0",  # Decimal point
            "Answer is 42 43 okay?",  # Extra space
            "Answer is 42",
            r"Answer is \n\boxed{42}",  # Newline in gold solution
            "Answer is 42.34560",  # Extra trailing zero
        ]
        rewards = accuracy_reward(completion, solution)
        assert rewards[0] == 1.0  # Should revert to exact text match
        assert rewards[1] == 0.0
        assert rewards[2] == 1.0
        assert rewards[3] == 1.0
        assert rewards[4] == 1.0
        assert rewards[5] == 1.0
        assert rewards[6] == 1.0  # Should ignore trailing zeros
@ -63,4 +63,3 @@ class TestRichProgressCallback(TrlTestCase):
        )

        trainer.train()
        trainer.train()
@ -1212,47 +1212,6 @@ class TestRLOOTrainer(TrlTestCase):
|
||||
elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer)
|
||||
assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."
|
||||
|
||||
@require_vision
|
||||
def test_training_vlm_and_prompt_truncation(self):
|
||||
# If not handled properly, prompt truncation may truncate image token
|
||||
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
|
||||
|
||||
def reward_func(completions, **kwargs):
|
||||
"""Reward function that rewards longer completions."""
|
||||
return [float(len(completion[0]["content"])) for completion in completions]
|
||||
|
||||
training_args = RLOOConfig(
|
||||
output_dir=self.tmp_dir,
|
||||
learning_rate=0.1, # increase the learning rate to speed up the test
|
||||
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
|
||||
num_generations=3, # reduce the number of generations to reduce memory usage
|
||||
max_completion_length=8, # reduce the completion length to reduce memory usage
|
||||
max_prompt_length=18,
|
||||
report_to="none",
|
||||
)
|
||||
trainer = RLOOTrainer(
|
||||
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
|
||||
reward_funcs=reward_func,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
trainer.train()
|
||||
|
||||
assert trainer.state.log_history[-1]["train_loss"] is not None
|
||||
|
||||
# Check that the params have changed
|
||||
# Because of the way the tiny models are initialized, the gradient does not flow properly through the
|
||||
# vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
|
||||
params_to_skip = ("model.visual.",)
|
||||
for n, param in previous_trainable_params.items():
|
||||
if n.startswith(params_to_skip):
|
||||
continue
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",),
|
||||
|
@ -1477,6 +1477,7 @@ class TestSFTTrainer(TrlTestCase):
    # To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs.
    @pytest.mark.slow
    @require_vision
    @pytest.mark.skip(reason="Model google/gemma-3n-E2B-it is gated and requires HF token")
    def test_train_vlm_gemma_3n(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen-image", "conversational_language_modeling", split="train")
@ -42,7 +42,6 @@ from trl.trainer.utils import (
|
||||
shuffle_sequence_dict,
|
||||
split_pixel_values_by_grid,
|
||||
split_tensor_dict,
|
||||
truncate_with_protected_tokens,
|
||||
unsplit_pixel_values_by_grid,
|
||||
)
|
||||
|
||||
@ -1009,84 +1008,6 @@ class TestSplitPixelValuesByGrid(TrlTestCase):
|
||||
assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]]))
|
||||
|
||||
|
||||
class TestTruncateWithProtectedTokens(TrlTestCase):
|
||||
def test_basic_example(self):
|
||||
"""Test the basic example from the problem description."""
|
||||
prompt_ids = [1, 2, 3, 4, 5]
|
||||
protected_tokens = [2, 3]
|
||||
target_length = 3
|
||||
|
||||
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
|
||||
|
||||
expected_ids = [2, 3, 5]
|
||||
assert new_ids == expected_ids
|
||||
|
||||
def test_no_truncation_needed(self):
|
||||
"""Test when target length equals current length."""
|
||||
prompt_ids = [1, 2, 3]
|
||||
protected_tokens = [2]
|
||||
target_length = 3
|
||||
|
||||
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
|
||||
|
||||
assert new_ids == prompt_ids
|
||||
|
||||
def test_no_protected_tokens(self):
|
||||
"""Test truncation with no protected tokens (normal right truncation)."""
|
||||
prompt_ids = [1, 2, 3, 4, 5]
|
||||
protected_tokens = []
|
||||
target_length = 3
|
||||
|
||||
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
|
||||
|
||||
expected_ids = [3, 4, 5] # Last 3 tokens
|
||||
assert new_ids == expected_ids
|
||||
|
||||
def test_all_tokens_protected(self):
|
||||
"""Test when all remaining tokens are protected."""
|
||||
prompt_ids = [1, 2, 3, 4, 5]
|
||||
protected_tokens = [3, 4, 5]
|
||||
target_length = 3
|
||||
|
||||
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
|
||||
|
||||
expected_ids = [3, 4, 5]
|
||||
assert new_ids == expected_ids
|
||||
|
||||
def test_too_many_protected_tokens(self):
|
||||
"""Test error when too many protected tokens for target length."""
|
||||
prompt_ids = [1, 2, 3, 4, 5]
|
||||
protected_tokens = [1, 2, 3, 4]
|
||||
target_length = 3
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
|
||||
|
||||
def test_single_batch_single_token(self):
|
||||
"""Test edge case with single batch and single token."""
|
||||
prompt_ids = [5]
|
||||
protected_tokens = [5]
|
||||
target_length = 1
|
||||
|
||||
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
|
||||
|
||||
assert new_ids == prompt_ids
|
||||
|
||||
def test_order_preservation(self):
|
||||
"""Test that relative order is preserved."""
|
||||
prompt_ids = [10, 2, 20, 3, 30, 40]
|
||||
protected_tokens = [2, 3]
|
||||
target_length = 4
|
||||
|
||||
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
|
||||
|
||||
# Should keep protected tokens 2, 3 and last 2 non-protected tokens 30, 40
|
||||
# Order should be: 2, 3, 30, 40 (maintaining original relative positions)
|
||||
expected_ids = [2, 3, 30, 40]
|
||||
|
||||
assert new_ids == expected_ids
|
||||
|
||||
|
||||
class TestUnsplitPixelValuesByGrid(TrlTestCase):
|
||||
def test_unsplit_correctly(self):
|
||||
pixel_values = [torch.randn(4, 5), torch.randn(2, 5)]
|
||||
|
@ -22,7 +22,7 @@ from transformers.testing_utils import require_torch_multi_accelerator, torch_de
|
||||
from trl.extras.vllm_client import VLLMClient
|
||||
from trl.scripts.vllm_serve import chunk_list
|
||||
|
||||
from .testing_utils import TrlTestCase, kill_process, require_3_accelerators
|
||||
from .testing_utils import TrlTestCase, kill_process, require_3_accelerators, require_vllm
|
||||
|
||||
|
||||
class TestChunkList(TrlTestCase):
|
||||
@ -53,6 +53,7 @@ class TestChunkList(TrlTestCase):
|
||||
|
||||
@pytest.mark.slow
|
||||
@require_torch_multi_accelerator
|
||||
@require_vllm
|
||||
class TestVLLMClientServer(TrlTestCase):
|
||||
model_id = "Qwen/Qwen2.5-1.5B"
|
||||
|
||||
@ -212,6 +213,7 @@ class TestVLLMClientServerBaseURL(TrlTestCase):
|
||||
|
||||
@pytest.mark.slow
|
||||
@require_3_accelerators
|
||||
@require_vllm
|
||||
class TestVLLMClientServerTP(TrlTestCase):
|
||||
model_id = "Qwen/Qwen2.5-1.5B"
|
||||
|
||||
@ -274,6 +276,7 @@ class TestVLLMClientServerTP(TrlTestCase):
|
||||
|
||||
@pytest.mark.slow
|
||||
@require_3_accelerators
|
||||
@require_vllm
|
||||
class TestVLLMClientServerDP(TrlTestCase):
|
||||
model_id = "Qwen/Qwen2.5-1.5B"
|
||||
|
||||
@ -336,6 +339,7 @@ class TestVLLMClientServerDP(TrlTestCase):
|
||||
|
||||
@pytest.mark.slow
|
||||
@require_torch_multi_accelerator
|
||||
@require_vllm
|
||||
class TestVLLMClientServerDeviceParameter(TrlTestCase):
|
||||
"""Test the device parameter functionality in init_communicator."""
|
||||
|
||||
|
@ -26,12 +26,19 @@ from transformers.testing_utils import torch_device
from transformers.utils import is_peft_available, is_rich_available, is_vision_available

from trl import BaseBinaryJudge, BasePairwiseJudge
from trl.import_utils import is_joblib_available, is_llm_blender_available, is_mergekit_available, is_vllm_available
from trl.import_utils import (
    is_joblib_available,
    is_llm_blender_available,
    is_math_verify_available,
    is_mergekit_available,
    is_vllm_available,
)


require_bitsandbytes = pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes")
require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml")
require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender")
require_math_latex = pytest.mark.skipif(not is_math_verify_available(), reason="test requires math_verify")
require_mergekit = pytest.mark.skipif(not is_mergekit_available(), reason="test requires mergekit")
require_peft = pytest.mark.skipif(not is_peft_available(), reason="test requires peft")
require_rich = pytest.mark.skipif(not is_rich_available(), reason="test requires rich")
@ -47,6 +54,21 @@ require_3_accelerators = pytest.mark.skipif(
)


def is_bitsandbytes_multi_backend_available() -> bool:
    if is_bitsandbytes_available():
        import bitsandbytes as bnb

        return "multi_backend" in getattr(bnb, "features", set())
    return False


# Function ported from transformers.testing_utils before transformers#41283
require_torch_gpu_if_bnb_not_multi_backend_enabled = pytest.mark.skipif(
    not is_bitsandbytes_multi_backend_available() and not torch_device == "cuda",
    reason="test requires bitsandbytes multi-backend enabled or 'cuda' torch device",
)


class RandomBinaryJudge(BaseBinaryJudge):
    """
    Random binary judge, for testing purposes.
@ -12,12 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import warnings
from importlib.metadata import PackageNotFoundError, version
from typing import TYPE_CHECKING

from .import_utils import _LazyModule


if sys.version_info[:2] == (3, 9):
    warnings.warn(
        (
            "Support for Python 3.9 will be dropped in the next release "
            "(after its end-of-life on October 31, 2025). "
            "Please upgrade to Python 3.10 or newer."
        ),
        category=FutureWarning,
        stacklevel=2,
    )


try:
    __version__ = version("trl")
except PackageNotFoundError:
@ -143,7 +143,13 @@ def apply_chat_template(

    # Apply the chat template to the whole conversation
    if "messages" in example:
        messages = tokenizer.apply_chat_template(example["messages"], tools=tools, tokenize=False, **template_kwargs)
        messages = tokenizer.apply_chat_template(
            example["messages"],
            tools=tools,
            tokenize=False,
            **example.get("chat_template_kwargs", {}),
            **template_kwargs,
        )

    # Apply the chat template to the prompt, adding the generation prompt
    if "prompt" in example:
@ -162,6 +168,7 @@ def apply_chat_template(
            continue_final_message=continue_final_message,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
            **example.get("chat_template_kwargs", {}),
            **template_kwargs,
        )

@ -169,7 +176,11 @@ def apply_chat_template(
|
||||
if "prompt" in example: # explicit prompt and prompt-completion case
|
||||
if "chosen" in example:
|
||||
prompt_chosen = tokenizer.apply_chat_template(
|
||||
example["prompt"] + example["chosen"], tools=tools, tokenize=False, **template_kwargs
|
||||
example["prompt"] + example["chosen"],
|
||||
tools=tools,
|
||||
tokenize=False,
|
||||
**example.get("chat_template_kwargs", {}),
|
||||
**template_kwargs,
|
||||
)
|
||||
# DeepSeek-R1 inserts a <tool_call> token when using `add_generation_prompt`, which can cause discrepancies
|
||||
# between the prompt alone and the combined prompt+completion. To ensure consistency, we extract the
|
||||
@ -179,24 +190,42 @@ def apply_chat_template(
|
||||
chosen = prompt_chosen[len(prompt) :]
|
||||
if "rejected" in example and "prompt" in example: # explicit prompt
|
||||
prompt_rejected = tokenizer.apply_chat_template(
|
||||
example["prompt"] + example["rejected"], tools=tools, tokenize=False, **template_kwargs
|
||||
example["prompt"] + example["rejected"],
|
||||
tools=tools,
|
||||
tokenize=False,
|
||||
**example.get("chat_template_kwargs", {}),
|
||||
**template_kwargs,
|
||||
)
|
||||
# Handle DeepSeek-R1 <tool_call> token, see the above comment for details
|
||||
prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_rejected)))
|
||||
rejected = prompt_rejected[len(prompt) :]
|
||||
if "completion" in example:
|
||||
prompt_completion = tokenizer.apply_chat_template(
|
||||
example["prompt"] + example["completion"], tools=tools, tokenize=False, **template_kwargs
|
||||
example["prompt"] + example["completion"],
|
||||
tools=tools,
|
||||
tokenize=False,
|
||||
**example.get("chat_template_kwargs", {}),
|
||||
**template_kwargs,
|
||||
)
|
||||
# Handle DeepSeek-R1 <tool_call> token, see the above comment for details
|
||||
prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_completion)))
|
||||
completion = prompt_completion[len(prompt) :]
|
||||
else: # implicit prompt case
|
||||
if "chosen" in example:
|
||||
chosen = tokenizer.apply_chat_template(example["chosen"], tools=tools, tokenize=False, **template_kwargs)
|
||||
chosen = tokenizer.apply_chat_template(
|
||||
example["chosen"],
|
||||
tools=tools,
|
||||
tokenize=False,
|
||||
**example.get("chat_template_kwargs", {}),
|
||||
**template_kwargs,
|
||||
)
|
||||
if "rejected" in example:
|
||||
rejected = tokenizer.apply_chat_template(
|
||||
example["rejected"], tools=tools, tokenize=False, **template_kwargs
|
||||
example["rejected"],
|
||||
tools=tools,
|
||||
tokenize=False,
|
||||
**example.get("chat_template_kwargs", {}),
|
||||
**template_kwargs,
|
||||
)
|
||||
|
||||
# Extract the completion by removing the prompt part from the prompt-completion string
|
||||
@ -239,8 +268,10 @@ def maybe_apply_chat_template(
|
||||
- Unpaired preference dataset: `"prompt"`, `"completion"`, and `"label"`.
|
||||
|
||||
For keys `"messages"`, `"prompt"`, `"chosen"`, `"rejected"`, and `"completion"`, the values are lists of
|
||||
messages, where each message is a dictionary with keys `"role"` and `"content"`.
|
||||
tokenizer (`PreTrainedTokenizerBase`):
|
||||
messages, where each message is a dictionary with keys `"role"` and `"content"`. Additionally, the example
|
||||
may contain a `"chat_template_kwargs"` key, which is a dictionary of additional keyword arguments to pass
|
||||
to the chat template renderer.
|
||||
tokenizer ([`~transformers.PreTrainedTokenizerBase`]):
|
||||
Tokenizer to apply the chat template with.
|
||||
tools (`list[Union[dict, Callable]]`, *optional*):
|
||||
A list of tools (callable functions) that will be accessible to the model. If the template does not support
|
||||
@ -297,7 +328,7 @@ def unpair_preference_dataset(
|
||||
Unpair a preference dataset.
|
||||
|
||||
Args:
|
||||
dataset (`Dataset` or `DatasetDict`):
|
||||
dataset ([`~datasets.Dataset`] or [`~datasets.DatasetDict`]):
|
||||
Preference dataset to unpair. The dataset must have columns `"chosen"`, `"rejected"` and optionally
|
||||
`"prompt"`.
|
||||
num_proc (`int`, *optional*):
|
||||
@ -306,7 +337,7 @@ def unpair_preference_dataset(
|
||||
Meaningful description to be displayed alongside with the progress bar while mapping examples.
|
||||
|
||||
Returns:
|
||||
`Dataset`: The unpaired preference dataset.
|
||||
[`~datasets.Dataset`]: The unpaired preference dataset.
|
||||
|
||||
Example:
|
||||
|
||||
@ -340,7 +371,7 @@ def maybe_unpair_preference_dataset(
|
||||
Unpair a preference dataset if it is paired.
|
||||
|
||||
Args:
|
||||
dataset (`Dataset` or `DatasetDict`):
|
||||
dataset ([`~datasets.Dataset`] or [`~datasets.DatasetDict`]):
|
||||
Preference dataset to unpair. The dataset must have columns `"chosen"`, `"rejected"` and optionally
|
||||
`"prompt"`.
|
||||
num_proc (`int`, *optional*):
|
||||
@ -349,7 +380,8 @@ def maybe_unpair_preference_dataset(
|
||||
Meaningful description to be displayed alongside with the progress bar while mapping examples.
|
||||
|
||||
Returns:
|
||||
`Dataset` or `DatasetDict`: The unpaired preference dataset if it was paired, otherwise the original dataset.
|
||||
[`~datasets.Dataset`] or [`~datasets.DatasetDict`]: The unpaired preference dataset if it was paired, otherwise
|
||||
the original dataset.
|
||||
|
||||
Example:
|
||||
|
||||
@ -442,7 +474,7 @@ def maybe_extract_prompt(example: dict[str, list]) -> dict[str, list]:
|
||||
'rejected': [{'role': 'assistant', 'content': 'It is green.'}]}
|
||||
```
|
||||
|
||||
Or, with the `map` method of `datasets.Dataset`:
|
||||
Or, with the `map` method of [`~datasets.Dataset`]:
|
||||
|
||||
```python
|
||||
>>> from trl import extract_prompt
|
||||
@ -633,7 +665,7 @@ def pack_dataset(
|
||||
Pack sequences in a dataset into chunks of size `seq_length`.
|
||||
|
||||
Args:
|
||||
dataset (`Dataset` or `DatasetDict`):
|
||||
dataset ([`~datasets.Dataset`] or [`~datasets.DatasetDict`]):
|
||||
Dataset to pack
|
||||
seq_length (`int`):
|
||||
Target sequence length to pack to.
|
||||
@ -648,8 +680,8 @@ def pack_dataset(
|
||||
Additional keyword arguments to pass to the dataset's map method when packing examples.
|
||||
|
||||
Returns:
|
||||
`Dataset` or `DatasetDict`: The dataset with packed sequences. The number of examples may decrease as sequences
|
||||
are combined.
|
||||
[`~datasets.Dataset`] or [`~datasets.DatasetDict`]: The dataset with packed sequences. The number of examples
|
||||
may decrease as sequences are combined.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@ -689,7 +721,7 @@ def truncate_dataset(
|
||||
Truncate sequences in a dataset to a specified `max_length`.
|
||||
|
||||
Args:
|
||||
dataset (`Dataset` or `DatasetDict`):
|
||||
dataset ([`~datasets.Dataset`] or [`~datasets.DatasetDict`]):
|
||||
Dataset to truncate.
|
||||
max_length (`int`):
|
||||
Maximum sequence length to truncate to.
|
||||
@ -697,7 +729,7 @@ def truncate_dataset(
|
||||
Additional keyword arguments to pass to the dataset's map method when truncating examples.
|
||||
|
||||
Returns:
|
||||
`Dataset` or `DatasetDict`: The dataset with truncated sequences.
|
||||
[`~datasets.Dataset`] or [`~datasets.DatasetDict`]: The dataset with truncated sequences.
|
||||
|
||||
Example:
|
||||
```python
|
||||
|
@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Any, Callable, Optional, Union

import torch
@ -42,8 +43,16 @@ class BestOfNSampler:
        generation_config ([`~transformers.GenerationConfig`], *optional*):
            Generation config passed to the underlying model's `generate` method. See
            [`~transformers.GenerationConfig`] for more details.

    <Deprecated version="0.24.0">

    `BestOfNSampler` is deprecated and will be removed in version 0.25.

    </Deprecated>
    """

    warnings.warn("`BestOfNSampler` is deprecated and will be removed in TRL 0.25.", FutureWarning, stacklevel=2)

    def __init__(
        self,
        model: PreTrainedModelWrapper,
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Callable, Literal, Optional
|
||||
|
||||
import datasets
|
||||
@ -41,7 +42,20 @@ def conversations_formatting_function(
|
||||
r"""
|
||||
return a callable function that takes in a "messages" dataset and returns a formatted dataset, based on the
|
||||
tokenizer apply chat template to the dataset along with the schema of the list of functions in the tools list.
|
||||
|
||||
<Deprecated version="0.24.0">
|
||||
|
||||
`conversations_formatting_function` is deprecated and will be removed in version 0.27. Please use
|
||||
`tokenizer.apply_chat_template()` directly instead.
|
||||
|
||||
</Deprecated>
|
||||
"""
|
||||
warnings.warn(
|
||||
"`conversations_formatting_function` is deprecated and will be removed in TRL 0.27. "
|
||||
"Please use `tokenizer.apply_chat_template()` directly instead.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
def format_dataset(examples):
|
||||
if isinstance(examples[messages_field][0], list):
|
||||
@ -61,7 +75,20 @@ def instructions_formatting_function(tokenizer: AutoTokenizer):
|
||||
r"""
|
||||
return a callable function that takes in an "instructions" dataset and returns a formatted dataset, based on the
|
||||
tokenizer apply chat template to the dataset
|
||||
|
||||
<Deprecated version="0.24.0">
|
||||
|
||||
`instructions_formatting_function` is deprecated and will be removed in version 0.27. Please use
|
||||
`tokenizer.apply_chat_template()` directly instead.
|
||||
|
||||
</Deprecated>
|
||||
"""
|
||||
warnings.warn(
|
||||
"`instructions_formatting_function` is deprecated and will be removed in TRL 0.27. "
|
||||
"Please use `tokenizer.apply_chat_template()` directly instead.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
def format_dataset(examples):
|
||||
if isinstance(examples["prompt"], list):
|
||||
@ -99,7 +126,21 @@ def get_formatting_func_from_dataset(
|
||||
|
||||
Returns:
|
||||
Callable: Formatting function if the dataset format is supported else None
|
||||
|
||||
<Deprecated version="0.24.0">
|
||||
|
||||
`get_formatting_func_from_dataset` is deprecated and will be removed in version 0.27. Please use
|
||||
`tokenizer.apply_chat_template()` directly instead.
|
||||
|
||||
</Deprecated>
|
||||
"""
|
||||
warnings.warn(
|
||||
"`get_formatting_func_from_dataset` is deprecated and will be removed in TRL 0.27. "
|
||||
"Please use `tokenizer.apply_chat_template()` directly instead.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if isinstance(dataset, Dataset):
|
||||
if "messages" in dataset.features:
|
||||
if dataset.features["messages"] == FORMAT_MAPPING["chatml"]:
|
||||
|
@ -182,6 +182,7 @@ class VLLMClient:
        top_k: int = -1,
        min_p: float = 0.0,
        max_tokens: int = 16,
        truncate_prompt_tokens: Optional[int] = None,
        guided_decoding_regex: Optional[str] = None,
        generation_kwargs: Optional[dict] = None,
    ) -> list[list[int]]:
@ -207,6 +208,10 @@ class VLLMClient:
                Minimum probability for sampling.
            max_tokens (`int`, *optional*, defaults to `16`):
                Maximum number of tokens to generate for each prompt.
            truncate_prompt_tokens (`int`, *optional*):
                If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use
                only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is
                disabled.
            guided_decoding_regex (`str`, *optional*):
                Regular expression to guide the decoding process.
            generation_kwargs (`dict`, *optional*):
@ -246,6 +251,7 @@ class VLLMClient:
                "top_k": top_k,
                "min_p": min_p,
                "max_tokens": max_tokens,
                "truncate_prompt_tokens": truncate_prompt_tokens,
                "guided_decoding_regex": guided_decoding_regex,
                "generation_kwargs": generation_kwargs or {},
            },
@ -31,6 +31,7 @@ _fastapi_available = _is_package_available("fastapi")
_joblib_available = _is_package_available("joblib")
_liger_kernel_available, _liger_kernel_version = _is_package_available("liger_kernel", return_version=True)
_llm_blender_available = _is_package_available("llm_blender")
_math_verify_available = _is_package_available("math_verify")
_mergekit_available = _is_package_available("mergekit")
_pydantic_available = _is_package_available("pydantic")
_requests_available = _is_package_available("requests")
@ -61,6 +62,10 @@ def is_llm_blender_available() -> bool:
    return _llm_blender_available


def is_math_verify_available() -> bool:
    return _math_verify_available


def is_mergekit_available() -> bool:
    return _mergekit_available

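Like the other `is_*_available` helpers, the new flag lets callers guard the optional dependency before touching it. A minimal guard, mirroring the parse/verify calls used by the example scripts earlier in this diff:

```python
from trl.import_utils import is_math_verify_available

if is_math_verify_available():
    from math_verify import parse, verify

    gold = parse(r"\frac{63}{400}", extraction_mode="first_match")
    pred = parse(r"\boxed{\frac{63}{400}}", extraction_mode="first_match")
    print(verify(gold, pred))
else:
    print("math-verify is not installed; install it via the trl[math_verify] extra")
```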
@ -264,7 +264,7 @@ def merge_models(config: MergeConfig, out_path: str):
    Merge two models using mergekit

    Args:
        config (`MergeConfig`): The merge configuration.
        config ([`MergeConfig`]): The merge configuration.
        out_path (`str`): The output path for the merged model.
    """
    if not is_mergekit_available():
@ -57,14 +57,17 @@ LAYER_PATTERNS = [
|
||||
|
||||
|
||||
class PreTrainedModelWrapper(nn.Module):
|
||||
r"""
|
||||
A wrapper class around a (`transformers.PreTrainedModel`) to be compatible with the (`~transformers.PreTrained`)
|
||||
class in order to keep some attributes and methods of the (`~transformers.PreTrainedModel`) class.
|
||||
"""
|
||||
Wrapper for a [`~transformers.PreTrainedModel`] implemented as a standard PyTorch [`torch.nn.Module`].
|
||||
|
||||
This class provides a compatibility layer that preserves the key attributes and methods of the original
|
||||
[`~transformers.PreTrainedModel`], while exposing a uniform interface consistent with PyTorch modules. It enables
|
||||
seamless integration of pretrained Transformer models into custom training, evaluation, or inference workflows.
|
||||
|
||||
Attributes:
|
||||
pretrained_model (`transformers.PreTrainedModel`):
|
||||
pretrained_model ([`~transformers.PreTrainedModel`]):
|
||||
The model to be wrapped.
|
||||
parent_class (`transformers.PreTrainedModel`):
|
||||
parent_class ([`~transformers.PreTrainedModel`]):
|
||||
The parent class of the model to be wrapped.
|
||||
supported_args (`list`):
|
||||
The list of arguments that are supported by the wrapper class.
|
||||
@ -111,19 +114,20 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
r"""
|
||||
Instantiates a new model from a pretrained model from `transformers`. The pretrained model is loaded using the
|
||||
`from_pretrained` method of the `transformers.PreTrainedModel` class. The arguments that are specific to the
|
||||
`transformers.PreTrainedModel` class are passed along this method and filtered out from the `kwargs` argument.
|
||||
`from_pretrained` method of the [`~transformers.PreTrainedModel`] class. The arguments that are specific to the
|
||||
[`~transformers.PreTrainedModel`] class are passed along this method and filtered out from the `kwargs`
|
||||
argument.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (`str` or `transformers.PreTrainedModel`):
|
||||
pretrained_model_name_or_path (`str` or [`~transformers.PreTrainedModel`]):
|
||||
The path to the pretrained model or its name.
|
||||
*model_args (`list`, *optional*)):
|
||||
*model_args (`list`, *optional*):
|
||||
Additional positional arguments passed along to the underlying model's `from_pretrained` method.
|
||||
**kwargs (`dict`, *optional*):
|
||||
Additional keyword arguments passed along to the underlying model's `from_pretrained` method. We also
|
||||
pre-process the kwargs to extract the arguments that are specific to the `transformers.PreTrainedModel`
|
||||
class and the arguments that are specific to trl models. The kwargs also support
|
||||
`prepare_model_for_kbit_training` arguments from `peft` library.
|
||||
pre-process the kwargs to extract the arguments that are specific to the
|
||||
[`~transformers.PreTrainedModel`] class and the arguments that are specific to trl models. The kwargs
|
||||
also support `prepare_model_for_kbit_training` arguments from `peft` library.
|
||||
"""
|
||||
if kwargs is not None:
|
||||
peft_config = kwargs.pop("peft_config", None)
|
||||
@ -149,8 +153,13 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
|
||||
current_device = cls._get_current_device()
|
||||
if isinstance(pretrained_model_name_or_path, str):
|
||||
is_loaded_in_8bit = pretrained_kwargs["load_in_8bit"] if "load_in_8bit" in pretrained_kwargs else False
|
||||
is_loaded_in_4bit = pretrained_kwargs["load_in_4bit"] if "load_in_4bit" in pretrained_kwargs else False
|
||||
quantization_config = pretrained_kwargs.get("quantization_config", None)
|
||||
if quantization_config is not None:
|
||||
is_loaded_in_8bit = getattr(quantization_config, "load_in_8bit", False)
|
||||
is_loaded_in_4bit = getattr(quantization_config, "load_in_4bit", False)
|
||||
else:
|
||||
is_loaded_in_8bit = pretrained_kwargs["load_in_8bit"] if "load_in_8bit" in pretrained_kwargs else False
|
||||
is_loaded_in_4bit = pretrained_kwargs["load_in_4bit"] if "load_in_4bit" in pretrained_kwargs else False
|
||||
else:
|
||||
is_loaded_in_8bit = getattr(pretrained_model_name_or_path, "is_loaded_in_8bit", False)
|
||||
is_loaded_in_4bit = getattr(pretrained_model_name_or_path, "is_loaded_in_4bit", False)
|
||||
@ -507,8 +516,8 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
def push_to_hub(self, *args, **kwargs):
|
||||
r"""
|
||||
Push the pretrained model to the hub. This method is a wrapper around
|
||||
`transformers.PreTrainedModel.push_to_hub`. Please refer to the documentation of
|
||||
`transformers.PreTrainedModel.push_to_hub` for more information.
|
||||
[`~transformers.PreTrainedModel.push_to_hub`]. Please refer to the documentation of
|
||||
[`~transformers.PreTrainedModel.push_to_hub`] for more information.
|
||||
|
||||
Args:
|
||||
*args (`list`, *optional*):
|
||||
@ -521,8 +530,8 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
def save_pretrained(self, *args, **kwargs):
|
||||
r"""
|
||||
Save the pretrained model to a directory. This method is a wrapper around
|
||||
`transformers.PreTrainedModel.save_pretrained`. Please refer to the documentation of
|
||||
`transformers.PreTrainedModel.save_pretrained` for more information.
|
||||
[`~transformers.PreTrainedModel.save_pretrained`]. Please refer to the documentation of
|
||||
[`~transformers.PreTrainedModel.save_pretrained`] for more information.
|
||||
|
||||
Args:
|
||||
*args (`list`, *optional*):
|
||||
@ -596,14 +605,14 @@ def create_reference_model(
|
||||
Creates a static reference copy of a model. Note that model will be in `.eval()` mode.
|
||||
|
||||
Args:
|
||||
model (`PreTrainedModelWrapper`): The model to be copied.
|
||||
model ([`PreTrainedModelWrapper`]): The model to be copied.
|
||||
num_shared_layers (`int`, *optional*):
|
||||
The number of initial layers that are shared between both models and kept frozen.
|
||||
pattern (`str`, *optional*): The shared layers are selected with a string pattern
|
||||
(e.g. "transformer.h.{layer}" for GPT2) and if a custom pattern is necessary it can be passed here.
|
||||
|
||||
Returns:
|
||||
`PreTrainedModelWrapper`
|
||||
[`PreTrainedModelWrapper`]
|
||||
"""
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError(
|
||||
@ -665,13 +674,13 @@ def create_reference_model(
|
||||
|
||||
|
||||
class GeometricMixtureWrapper(GenerationMixin):
|
||||
r"""
|
||||
"""
|
||||
Geometric Mixture generation wrapper that samples from the logits of two model's geometric mixture.
|
||||
|
||||
Args:
|
||||
model (`PreTrainedModel`): The model to be wrapped.
|
||||
ref_model (`PreTrainedModel`): The reference model.
|
||||
generation_config (`GenerationConfig`): The generation config.
|
||||
model ([`~transformers.PreTrainedModel`]): The model to be wrapped.
|
||||
ref_model ([`~transformers.PreTrainedModel`]): The reference model.
|
||||
generation_config ([`~transformers.GenerationConfig`]): The generation config.
|
||||
mixture_coef (`float`, *optional* - default: 0.5): The mixture coefficient.
|
||||
"""
|
||||
|
||||
|
@ -60,26 +60,27 @@ class ValueHead(nn.Module):
|
||||
|
||||
|
||||
class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
|
||||
r"""
|
||||
"""
|
||||
An autoregressive model with a value head in addition to the language model head. This class inherits from
|
||||
`~trl.PreTrainedModelWrapper` and wraps a `transformers.PreTrainedModel` class. The wrapper class supports classic
|
||||
[`PreTrainedModelWrapper`] and wraps a [`~transformers.PreTrainedModel`] class. The wrapper class supports classic
|
||||
functions such as `from_pretrained`, `push_to_hub` and `generate`. To call a method of the wrapped model, simply
|
||||
manipulate the `pretrained_model` attribute of this class.
|
||||
|
||||
Class attributes:
|
||||
- **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This
|
||||
- **transformers_parent_class** ([`~transformers.PreTrainedModel`]) -- The parent class of the wrapped model.
|
||||
This
|
||||
should be set to `transformers.AutoModelForCausalLM` for this class.
|
||||
- **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported
|
||||
by the `ValueHead` class. Currently, the supported args are:
|
||||
by the [`ValueHead`] class. Currently, the supported args are:
|
||||
- **summary_dropout_prob** (`float`, `optional`, defaults to `None`) -- The dropout probability for the
|
||||
`ValueHead` class.
|
||||
[`ValueHead`] class.
|
||||
- **v_head_initializer_range** (`float`, `optional`, defaults to `0.2`) -- The initializer range for the
|
||||
`ValueHead` if a specific initialization strategy is selected.
|
||||
[`ValueHead`] if a specific initialization strategy is selected.
|
||||
- **v_head_init_strategy** (`str`, `optional`, defaults to `None`) -- The initialization strategy for the
|
||||
`ValueHead`. Currently, the supported strategies are:
|
||||
- **`None`** -- Initializes the weights of the `ValueHead` with a random distribution. This is the
|
||||
[`ValueHead`]. Currently, the supported strategies are:
|
||||
- **`None`** -- Initializes the weights of the [`ValueHead`] with a random distribution. This is the
|
||||
default strategy.
|
||||
- **"normal"** -- Initializes the weights of the `ValueHead` with a normal distribution.
|
||||
- **"normal"** -- Initializes the weights of the [`ValueHead`] with a normal distribution.
|
||||
"""
|
||||
|
||||
transformers_parent_class = AutoModelForCausalLM
|
||||
@ -90,15 +91,15 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
|
||||
)
|
||||
|
||||
def __init__(self, pretrained_model, **kwargs):
|
||||
r"""
|
||||
"""
|
||||
Initializes the model.
|
||||
|
||||
Args:
|
||||
pretrained_model (`transformers.PreTrainedModel`):
|
||||
pretrained_model ([`~transformers.PreTrainedModel`]):
|
||||
The model to wrap. It should be a causal language model such as GPT2. or any model mapped inside the
|
||||
`AutoModelForCausalLM` class.
|
||||
kwargs (`dict`, `optional`):
|
||||
Additional keyword arguments, that are passed to the `ValueHead` class.
|
||||
Additional keyword arguments, that are passed to the [`ValueHead`] class.
|
||||
"""
|
||||
super().__init__(pretrained_model, **kwargs)
|
||||
v_head_kwargs, _, _ = self._split_kwargs(kwargs)
|
||||
@ -114,8 +115,8 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
|
||||
|
||||
Args:
|
||||
**kwargs (`dict`, `optional`):
|
||||
Additional keyword arguments, that are passed to the `ValueHead` class. These arguments can contain the
|
||||
`v_head_init_strategy` argument as well as the `v_head_initializer_range` argument.
|
||||
Additional keyword arguments, that are passed to the [`ValueHead`] class. These arguments can contain
|
||||
the `v_head_init_strategy` argument as well as the `v_head_initializer_range` argument.
|
||||
"""
|
||||
initializer_range = kwargs.pop("v_head_initializer_range", 0.2)
|
||||
# random init by default
|
||||
@ -263,18 +264,18 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
|
||||
|
||||
|
||||
class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):
|
||||
r"""
|
||||
"""
|
||||
A seq2seq model with a value head in addition to the language model head. This class inherits from
|
||||
`~trl.PreTrainedModelWrapper` and wraps a `transformers.PreTrainedModel` class. The wrapper class supports classic
|
||||
[`PreTrainedModelWrapper`] and wraps a [`~transformers.PreTrainedModel`] class. The wrapper class supports classic
|
||||
functions such as `from_pretrained` and `push_to_hub` and also provides some additional functionalities such as
|
||||
`generate`.
|
||||
|
||||
Args:
|
||||
pretrained_model (`transformers.PreTrainedModel`):
|
||||
pretrained_model ([`~transformers.PreTrainedModel`]):
|
||||
The model to wrap. It should be a causal language model such as GPT2. or any model mapped inside the
|
||||
`AutoModelForSeq2SeqLM` class.
|
||||
[`~transformers.AutoModelForSeq2SeqLM`] class.
|
||||
kwargs:
|
||||
Additional keyword arguments passed along to the `ValueHead` class.
|
||||
Additional keyword arguments passed along to the [`ValueHead`] class.
|
||||
"""
|
||||
|
||||
transformers_parent_class = AutoModelForSeq2SeqLM
|
||||
|
@ -102,21 +102,21 @@ def setup_chat_format(
|
||||
`tokenizer.chat_template` to `None`.
|
||||
|
||||
Args:
|
||||
model (`~transformers.PreTrainedModel`): The model to be modified.
|
||||
tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
|
||||
model ([`~transformers.PreTrainedModel`]): The model to be modified.
|
||||
tokenizer ([`~transformers.PreTrainedTokenizer`]): The tokenizer to be modified.
|
||||
format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
|
||||
resize_to_multiple_of (`int` or `None`): Number to resize the embedding layer to. Defaults to None.
|
||||
|
||||
Returns:
|
||||
model (`~transformers.PreTrainedModel`):
|
||||
model ([`~transformers.PreTrainedModel`]):
|
||||
The modified model.
|
||||
tokenizer (`~transformers.PreTrainedTokenizer`):
|
||||
tokenizer ([`~transformers.PreTrainedTokenizer`]):
|
||||
The modified tokenizer.
|
||||
"""
|
||||
warnings.warn(
|
||||
"The `setup_chat_format` function is deprecated and will be removed in version 0.26.0. Please use "
|
||||
"`clone_chat_template` instead.",
|
||||
DeprecationWarning,
|
||||
FutureWarning,
|
||||
)
|
||||
# check if model already had a chat template
|
||||
if tokenizer.chat_template is not None:
|
||||
@ -178,9 +178,9 @@ def clone_chat_template(
|
||||
the embedding dimensions.
|
||||
|
||||
Args:
|
||||
model (`PreTrainedModel`):
|
||||
model ([`~transformers.PreTrainedModel`]):
|
||||
Model to update.
|
||||
tokenizer (`PreTrainedTokenizer`):
|
||||
tokenizer ([`~transformers.PreTrainedTokenizer`]):
|
||||
Tokenizer to update.
|
||||
source_tokenizer_path (`str`):
|
||||
Path or identifier of the pretrained tokenizer to clone from.
|
||||
@ -189,9 +189,9 @@ def clone_chat_template(
|
||||
new vocabulary size to the nearest multiple of this value.
|
||||
|
||||
Returns:
|
||||
model (`PreTrainedModel`):
|
||||
model ([`~transformers.PreTrainedModel`]):
|
||||
Updated model with resized token embeddings and EOS token configured.
|
||||
tokenizer (`~transformers.PreTrainedTokenizer`):
|
||||
tokenizer ([`~transformers.PreTrainedTokenizer`]):
|
||||
Updated tokenizer with the chat template and special tokens applied.
|
||||
added_tokens (`list[int]`):
|
||||
List of tokens that were added to the tokenizer from the source tokenizer.
|
||||
@ -316,7 +316,7 @@ def unwrap_model_for_generation(
|
||||
Args:
|
||||
model (`Union[DistributedDataParallel, DeepSpeedEngine]`):
|
||||
Model to be unwrapped.
|
||||
accelerator (`~accelerate.Accelerator`):
|
||||
accelerator ([`~accelerate.Accelerator`]):
|
||||
Accelerator instance managing the model.
|
||||
gather_deepspeed3_params (`bool`, *optional*, defaults to `True`):
|
||||
Whether to gather weights for DeepSpeed ZeRO Stage 3 models. If `False`, skips parameter gathering, which
|
||||
|
@ -20,12 +20,14 @@ from ..import_utils import _LazyModule


_import_structure = {
    "accuracy_rewards": ["accuracy_reward"],
    "format_rewards": ["think_format_reward"],
    "other_rewards": ["get_soft_overlong_punishment"],
}


if TYPE_CHECKING:
    from .accuracy_rewards import accuracy_reward
    from .format_rewards import think_format_reward
    from .other_rewards import get_soft_overlong_punishment

93
trl/rewards/accuracy_rewards.py
Normal file
@ -0,0 +1,93 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from trl.import_utils import is_math_verify_available


if is_math_verify_available():
    from latex2sympy2_extended import NormalizationConfig
    from math_verify import LatexExtractionConfig, parse, verify


def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[Optional[float]]:
    r"""
    Reward function that checks if the completion is the same as the ground truth.
    - If both gold and prediction are parseable → use math verification.
    - If not parseable → compare as normalized text.

    Args:
        completions (`list[list[dict[str, str]]]`):
            List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary
            containing the key `"content"` with the value being the text of the completion.
        solution (`list[str]`):
            List of the raw-text solutions to the questions/problems/prompts.
        **kwargs:
            Additional keyword arguments. This function does not use them, but they are required in the function
            signature to ensure compatibility with trainers like [`GRPOTrainer`].
    Example:
    ```python
    >>> from trl.rewards import accuracy_reward

    >>> solution = [r"\frac{1}{3}", r"\frac{1}{3}"]
    >>> completion = [
    ...     [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{3}}"}],
    ...     [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{2}}"}],
    ... ]
    >>> accuracy_reward(completion, solution)
    [1.0, 0.0]
    ```
    """
    if not is_math_verify_available():
        raise ImportError("Please install the `math_verify` package to use accuracy_reward")

    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, solution):
        gold_parsed = parse(
            sol,
            extraction_mode="first_match",
        )
        if len(gold_parsed) != 0:
            # We require the answer to be provided in correct latex (no malformed operators)
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            boxed="all",
                            units=True,
                        ),
                        # Ensures that boxed is tried first
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )
            # Compute binary rewards if verifiable, `None` otherwise to skip this example
            try:
                reward = float(verify(gold_parsed, answer_parsed))
            except Exception:
                reward = None
        else:
            # If the gold solution is not parseable, fall back to comparing normalized text
            reward = float(content.strip().lower() == sol.strip().lower())
        rewards.append(reward)

    return rewards
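Because extra dataset columns are forwarded to reward functions as keyword arguments, `accuracy_reward` only needs a `"solution"` column next to the prompts. A hypothetical wiring into a GRPO run (dataset contents, column layout, and the model checkpoint are assumptions for illustration, not part of the diff):

```python
# Minimal sketch: plugging accuracy_reward into GRPOTrainer.
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from trl.rewards import accuracy_reward

train_dataset = Dataset.from_dict(
    {
        "prompt": ["What is 1/3 + 0? Put your answer in \\boxed{}."],
        "solution": [r"\frac{1}{3}"],  # forwarded to accuracy_reward as `solution`
    }
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",          # example checkpoint
    reward_funcs=accuracy_reward,
    args=GRPOConfig(output_dir="grpo-accuracy", max_completion_length=64),
    train_dataset=train_dataset,
)
trainer.train()
```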
@ -41,7 +41,7 @@ from trl import (
    get_dataset,
    get_peft_config,
)
from trl.rewards import get_soft_overlong_punishment, think_format_reward
from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward


logger = logging.get_logger(__name__)
@ -51,6 +51,7 @@ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")


reward_funcs_registry = {
    "accuracy_reward": accuracy_reward,
    "think_format_reward": think_format_reward,
    "get_soft_overlong_punishment": get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256),
}
@ -68,6 +69,7 @@ class GRPOScriptArguments(ScriptArguments):
        reward_funcs (`list[str]`, *optional*):
            Reward functions to use. Supported values are:

            - `"accuracy_reward"`
            - `"think_format_reward"`
            - `"get_soft_overlong_punishment"` (used values are `max_completion_len=1280`, `soft_punish_cache=256`)
            - any dotted import path (e.g., `'my_lib.rewards.custom_reward'`).
@ -83,7 +85,7 @@ class GRPOScriptArguments(ScriptArguments):
    reward_funcs: Optional[list[str]] = field(
        default=None,
        metadata={
            "help": "Reward functions to use. Supported values are: `think_format_reward`, "
            "help": "Reward functions to use. Supported values are: `accuracy_reward`, `think_format_reward`, "
            "`get_soft_overlong_punishment` (used values are `max_completion_len=1280`, `soft_punish_cache=256`), or "
            "any dotted import path (e.g., `'my_lib.rewards.custom_reward'`)."
        },

@ -41,7 +41,7 @@ from trl import (
    get_dataset,
    get_peft_config,
)
from trl.rewards import get_soft_overlong_punishment, think_format_reward
from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward


logger = logging.get_logger(__name__)
@ -51,6 +51,7 @@ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")


reward_funcs_registry = {
    "accuracy_reward": accuracy_reward,
    "think_format_reward": think_format_reward,
    "get_soft_overlong_punishment": get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256),
}
@ -68,6 +69,7 @@ class RLOOScriptArguments(ScriptArguments):
        reward_funcs (`list[str]`, *optional*):
            Reward functions to use. Supported values are:

            - `"accuracy_reward"`
            - `"think_format_reward"`
            - `"get_soft_overlong_punishment"` (used values are `max_completion_len=1280`, `soft_punish_cache=256`)
            - any dotted import path (e.g., `'my_lib.rewards.custom_reward'`).
@ -83,7 +85,7 @@ class RLOOScriptArguments(ScriptArguments):
    reward_funcs: Optional[list[str]] = field(
        default=None,
        metadata={
            "help": "Reward functions to use. Supported values are: `think_format_reward`, "
            "help": "Reward functions to use. Supported values are: `accuracy_reward`, `think_format_reward`, "
            "`get_soft_overlong_punishment` (used values are `max_completion_len=1280`, `soft_punish_cache=256`), or "
            "any dotted import path (e.g., `'my_lib.rewards.custom_reward'`)."
        },
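The help text above accepts either a registry name or "any dotted import path". A plausible resolution routine looks like the sketch below; this is only an illustration of the described behaviour, not the scripts' actual implementation:

```python
# Hypothetical lookup: registry names first, dotted import paths otherwise.
import importlib


def resolve_reward_func(name: str, registry: dict):
    if name in registry:
        return registry[name]
    module_path, _, attr = name.rpartition(".")
    return getattr(importlib.import_module(module_path), attr)
```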
@ -60,7 +60,8 @@ class DatasetConfig:
|
||||
Configuration for a dataset.
|
||||
|
||||
This class matches the signature of [`~datasets.load_dataset`] and the arguments are used directly in the
|
||||
`datasets.load_dataset` function. You can refer to the `datasets.load_dataset` documentation for more details.
|
||||
[`~datasets.load_dataset`] function. You can refer to the [`~datasets.load_dataset`] documentation for more
|
||||
details.
|
||||
|
||||
Parameters:
|
||||
path (`str`):
|
||||
@ -422,11 +423,11 @@ def get_dataset(mixture_config: DatasetMixtureConfig) -> DatasetDict:
|
||||
Load a mixture of datasets based on the configuration.
|
||||
|
||||
Args:
|
||||
mixture_config (`DatasetMixtureConfig`):
|
||||
mixture_config ([`DatasetMixtureConfig`]):
|
||||
Script arguments containing dataset configuration.
|
||||
|
||||
Returns:
|
||||
`DatasetDict`:
|
||||
[`~datasets.DatasetDict`]:
|
||||
Combined dataset(s) from the mixture configuration, with optional train/test split if `test_split_size` is
|
||||
set.
|
||||
|
||||
|
@ -495,6 +495,7 @@ def main(script_args: ScriptArguments):
        top_k: int = -1
        min_p: float = 0.0
        max_tokens: int = 16
        truncate_prompt_tokens: Optional[int] = None
        guided_decoding_regex: Optional[str] = None
        generation_kwargs: dict = field(default_factory=dict)

@ -525,6 +526,9 @@ def main(script_args: ScriptArguments):
            - `min_p` (`float`, *optional*, defaults to `0.0`): Minimum probability threshold for sampling.
            - `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each
              completion.
            - `truncate_prompt_tokens` (`int`, *optional*): If set to `-1`, will use the truncation size supported
              by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left
              truncation). If set to `None`, truncation is disabled.
            - `guided_decoding_regex` (`str`, *optional*): A regex pattern for guided decoding. If provided, the
              model will only generate tokens that match this regex pattern.
            - `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM
@ -575,6 +579,7 @@ def main(script_args: ScriptArguments):
            "top_k": request.top_k,
            "min_p": request.min_p,
            "max_tokens": request.max_tokens,
            "truncate_prompt_tokens": request.truncate_prompt_tokens,
            "guided_decoding": guided_decoding,
            "logprobs": 0,
        }
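With the new field, a client can ask the serving process to keep only the tail of an over-long prompt. A hypothetical request to a running `trl vllm-serve` instance (host, port, and the `/generate/` route are assumptions; only fields shown in the diff are used):

```python
# Sketch of a client request exercising truncate_prompt_tokens.
import requests

payload = {
    "prompts": ["A very long prompt that may exceed the context window ..."],
    "n": 1,
    "temperature": 0.7,
    "max_tokens": 64,
    "truncate_prompt_tokens": 512,  # keep only the last 512 prompt tokens (left truncation)
    "generation_kwargs": {},
}
response = requests.post("http://localhost:8000/generate/", json=payload)
print(response.json())
```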
@ -283,25 +283,25 @@ class BCOTrainer(BaseTrainer):
|
||||
Initialize BCOTrainer from [BCO](https://huggingface.co/papers/2404.04656) paper.
|
||||
|
||||
Args:
|
||||
model (`transformers.PreTrainedModel`):
|
||||
The model to train, preferably an `AutoModelForSequenceClassification`.
|
||||
ref_model (`PreTrainedModelWrapper`):
|
||||
model ([`~transformers.PreTrainedModel`]):
|
||||
The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`].
|
||||
ref_model ([`PreTrainedModelWrapper`]):
|
||||
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation
|
||||
and loss. If no reference model is provided, the trainer will create a reference model with the same
|
||||
architecture as the model to be optimized.
|
||||
args (`BCOConfig`):
|
||||
args ([`BCOConfig`]):
|
||||
The arguments to use for training.
|
||||
train_dataset (`datasets.Dataset`):
|
||||
train_dataset ([`~datasets.Dataset`]):
|
||||
The dataset to use for training.
|
||||
eval_dataset (`datasets.Dataset`):
|
||||
eval_dataset ([`~datasets.Dataset`]):
|
||||
The dataset to use for evaluation.
|
||||
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
|
||||
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
||||
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
||||
reuse the fine-tuned model.
|
||||
data_collator (`transformers.DataCollator`, *optional*):
|
||||
data_collator ([`~transformers.DataCollator`], *optional*):
|
||||
The data collator to use for training. If None is specified, the default data collator
|
||||
(`DPODataCollatorWithPadding`) will be used which will pad the sequences to the maximum length of the
|
||||
([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
|
||||
sequences in the batch, given a dataset of paired sequences.
|
||||
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
||||
The model initializer to use for training. If None is specified, the default model initializer will be
|
||||
|
@ -44,10 +44,12 @@ from .utils import log_table_to_comet_experiment
|
||||
|
||||
|
||||
if is_rich_available():
|
||||
from rich.columns import Columns
|
||||
from rich.console import Console, Group
|
||||
from rich.live import Live
|
||||
from rich.panel import Panel
|
||||
from rich.progress import Progress
|
||||
from rich.table import Table
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
@ -152,74 +154,105 @@ class RichProgressCallback(TrainerCallback):
|
||||
raise ImportError("RichProgressCallback requires the `rich` extra. To install, run `pip install rich`.")
|
||||
|
||||
self.training_bar = None
|
||||
self.prediction_bar = None
|
||||
|
||||
self.training_task_id = None
|
||||
self.prediction_task_id = None
|
||||
|
||||
self.evaluation_bar = None
|
||||
self.training_task = None
|
||||
self.evaluation_task = None
|
||||
self.rich_group = None
|
||||
self.rich_console = None
|
||||
|
||||
self.training_status = None
|
||||
self.current_step = None
|
||||
|
||||
def on_train_begin(self, args, state, control, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
self.training_bar = Progress()
|
||||
self.prediction_bar = Progress()
|
||||
if not state.is_world_process_zero:
|
||||
return
|
||||
|
||||
self.rich_console = Console()
|
||||
|
||||
self.training_status = self.rich_console.status("Nothing to log yet ...")
|
||||
|
||||
self.rich_group = Live(Panel(Group(self.training_bar, self.prediction_bar, self.training_status)))
|
||||
self.rich_group.start()
|
||||
|
||||
self.training_task_id = self.training_bar.add_task("[blue]Training the model", total=state.max_steps)
|
||||
self.current_step = 0
|
||||
self.training_bar = Progress()
|
||||
self.evaluation_bar = Progress()
|
||||
self.rich_console = Console()
|
||||
self.training_status = self.rich_console.status("Nothing to log yet ...")
|
||||
self.rich_group = Live(Panel(Group(self.training_bar, self.evaluation_bar, self.training_status)))
|
||||
self.rich_group.start()
|
||||
self.training_task = self.training_bar.add_task("[blue]Training ", total=state.max_steps)
|
||||
self.current_step = 0
|
||||
|
||||
def on_step_end(self, args, state, control, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
self.training_bar.update(self.training_task_id, advance=state.global_step - self.current_step, update=True)
|
||||
self.current_step = state.global_step
|
||||
if not state.is_world_process_zero:
|
||||
return
|
||||
|
||||
self.training_bar.update(self.training_task, advance=state.global_step - self.current_step, update=True)
|
||||
self.current_step = state.global_step
|
||||
|
||||
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
|
||||
if state.is_world_process_zero and has_length(eval_dataloader):
|
||||
if self.prediction_task_id is None:
|
||||
self.prediction_task_id = self.prediction_bar.add_task(
|
||||
"[blue]Predicting on the evaluation dataset", total=len(eval_dataloader)
|
||||
)
|
||||
self.prediction_bar.update(self.prediction_task_id, advance=1, update=True)
|
||||
if not state.is_world_process_zero:
|
||||
return
|
||||
|
||||
if has_length(eval_dataloader):
|
||||
if self.evaluation_task is None:
|
||||
self.evaluation_task = self.evaluation_bar.add_task("[blue]Evaluation", total=len(eval_dataloader))
|
||||
self.evaluation_bar.update(self.evaluation_task, advance=1, update=True)
|
||||
|
||||
def on_evaluate(self, args, state, control, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
if self.prediction_task_id is not None:
|
||||
self.prediction_bar.remove_task(self.prediction_task_id)
|
||||
self.prediction_task_id = None
|
||||
if not state.is_world_process_zero:
|
||||
return
|
||||
|
||||
if self.evaluation_task is not None:
|
||||
self.evaluation_bar.remove_task(self.evaluation_task)
|
||||
self.evaluation_task = None
|
||||
|
||||
def on_predict(self, args, state, control, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
if self.prediction_task_id is not None:
|
||||
self.prediction_bar.remove_task(self.prediction_task_id)
|
||||
self.prediction_task_id = None
|
||||
if not state.is_world_process_zero:
|
||||
return
|
||||
|
||||
if self.evaluation_task is not None:
|
||||
self.evaluation_bar.remove_task(self.evaluation_task)
|
||||
self.evaluation_task = None
|
||||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
if state.is_world_process_zero and self.training_bar is not None:
|
||||
_ = logs.pop("total_flos", None)
|
||||
self.training_status.update(f"[bold green]Status = {str(logs)}")
|
||||
if not (state.is_world_process_zero and self.training_bar):
|
||||
return
|
||||
|
||||
# Group keys by top-level prefix
|
||||
grouped_logs = {}
|
||||
for key, value in logs.items():
|
||||
parts = key.split("/")
|
||||
group = parts[0] if len(parts) > 1 else None
|
||||
subkey = "/".join(parts[1:]) if len(parts) > 1 else key
|
||||
grouped_logs.setdefault(group, {})[subkey] = value
|
||||
|
||||
# Create a table per group
|
||||
tables = []
|
||||
for group_name, metrics in grouped_logs.items():
|
||||
table = Table(
|
||||
title=f"[bold blue]{group_name}[/]" if group_name else None, header_style="bold magenta", box=None
|
||||
)
|
||||
table.add_column("Metric", justify="left", no_wrap=True)
|
||||
table.add_column("Value", justify="right")
|
||||
|
||||
for metric, val in metrics.items():
|
||||
formatted = f"{val:.3f}" if isinstance(val, (float, int)) else str(val)
|
||||
table.add_row(metric, formatted)
|
||||
|
||||
tables.append(Panel(table, border_style="cyan", padding=(0, 1)))
|
||||
|
||||
# Arrange tables in columns using Columns
|
||||
column_layout = Columns(tables, equal=False, expand=True)
|
||||
self.training_status.update(
|
||||
Panel(column_layout, title=f"[bold green]Step {state.global_step}[/bold green]", border_style="green")
|
||||
)
|
||||
|
||||
def on_train_end(self, args, state, control, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
self.rich_group.stop()
|
||||
if not state.is_world_process_zero:
|
||||
return
|
||||
|
||||
self.training_bar = None
|
||||
self.prediction_bar = None
|
||||
self.training_task_id = None
|
||||
self.prediction_task_id = None
|
||||
self.rich_group = None
|
||||
self.rich_console = None
|
||||
self.training_status = None
|
||||
self.current_step = None
|
||||
self.rich_group.stop()
|
||||
self.training_bar = None
|
||||
self.evaluation_bar = None
|
||||
self.training_task = None
|
||||
self.evaluation_task = None
|
||||
self.rich_group = None
|
||||
self.rich_console = None
|
||||
self.training_status = None
|
||||
self.current_step = None
|
||||
|
||||
|
||||
def _win_rate_completions_df(
|
||||
@ -251,14 +284,14 @@ class WinRateCallback(TrainerCallback):
|
||||
```
|
||||
|
||||
Args:
|
||||
judge (`BasePairwiseJudge`):
|
||||
judge ([`BasePairwiseJudge`]):
|
||||
The judge to use for comparing completions.
|
||||
trainer (`Trainer`):
|
||||
Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"`
|
||||
column containing the prompts for generating completions. If the `Trainer` has a reference model (via the
|
||||
`ref_model` attribute), it will use this reference model for generating the reference completions;
|
||||
otherwise, it defaults to using the initial model.
|
||||
generation_config (`GenerationConfig`, *optional*):
|
||||
generation_config ([`~transformers.GenerationConfig`], *optional*):
|
||||
The generation config to use for generating completions.
|
||||
num_prompts (`int`, *optional*):
|
||||
The number of prompts to generate completions for. If not provided, defaults to the number of examples in
|
||||
@ -439,7 +472,7 @@ class LogCompletionsCallback(TrainerCallback):
|
||||
trainer (`Trainer`):
|
||||
Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"`
|
||||
column containing the prompts for generating completions.
|
||||
generation_config (`GenerationConfig`, *optional*):
|
||||
generation_config ([`~transformers.GenerationConfig`], *optional*):
|
||||
The generation config to use for generating completions.
|
||||
num_prompts (`int`, *optional*):
|
||||
The number of prompts to generate completions for. If not provided, defaults to the number of examples in
|
||||
@ -569,7 +602,7 @@ class WeaveCallback(TrainerCallback):
|
||||
Dictionary mapping scorer names to scorer functions. If `None`, operates in tracing mode (predictions
|
||||
only). If provided, operates in evaluation mode (predictions + scores + summary). Scorer functions should
|
||||
have signature: `scorer(prompt: str, completion: str) -> Union[float, int]`
|
||||
generation_config (`GenerationConfig`, *optional*):
|
||||
generation_config ([`~transformers.GenerationConfig`], *optional*):
|
||||
Generation config to use for generating completions.
|
||||
num_prompts (`int` or `None`, *optional*):
|
||||
Number of prompts to generate completions for. If not provided, defaults to the number of examples in the
|
||||
|
@ -77,17 +77,17 @@ class CPOTrainer(BaseTrainer):
|
||||
Initialize CPOTrainer.
|
||||
|
||||
Args:
|
||||
model (`transformers.PreTrainedModel`):
|
||||
The model to train, preferably an `AutoModelForSequenceClassification`.
|
||||
args (`CPOConfig`):
|
||||
model ([`~transformers.PreTrainedModel`]):
|
||||
The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`].
|
||||
args ([`CPOConfig`]):
|
||||
The CPO config arguments to use for training.
|
||||
data_collator (`transformers.DataCollator`):
|
||||
data_collator ([`~transformers.DataCollator`]):
|
||||
The data collator to use for training. If None is specified, the default data collator
|
||||
(`DPODataCollatorWithPadding`) will be used which will pad the sequences to the maximum length of the
|
||||
([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
|
||||
sequences in the batch, given a dataset of paired sequences.
|
||||
train_dataset (`datasets.Dataset`):
|
||||
train_dataset ([`~datasets.Dataset`]):
|
||||
The dataset to use for training.
|
||||
eval_dataset (`datasets.Dataset`):
|
||||
eval_dataset ([`~datasets.Dataset`]):
|
||||
The dataset to use for evaluation.
|
||||
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
|
||||
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
||||
|
@ -177,6 +177,9 @@ class DataCollatorForPreference(DataCollatorMixin):
|
||||
if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]:
|
||||
output["ref_chosen_logps"] = ref_chosen_logps
|
||||
output["ref_rejected_logps"] = ref_rejected_logps
|
||||
if "token_type_ids" in examples[0]:
|
||||
token_type_ids = [torch.tensor(example["token_type_ids"]) for example in examples]
|
||||
output["token_type_ids"] = pad(token_type_ids, padding_value=0, padding_side="left")
|
||||
|
||||
return output
|
||||
|
||||
@ -197,13 +200,13 @@ class DPOTrainer(BaseTrainer):
|
||||
using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in
|
||||
`args.model_init_kwargs`.
|
||||
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
|
||||
ref_model (`PreTrainedModelWrapper`):
|
||||
ref_model ([`PreTrainedModelWrapper`]):
|
||||
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation
|
||||
and loss. If no reference model is provided, the trainer will create a reference model with the same
|
||||
architecture as the model to be optimized.
|
||||
args ([`DPOConfig`], *optional*):
|
||||
Configuration for this trainer. If `None`, a default configuration is used.
|
||||
data_collator (`DataCollator`, *optional*):
|
||||
data_collator ([`~transformers.DataCollator`], *optional*):
|
||||
Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
|
||||
Will default to [`DataCollatorForPreference`].
|
||||
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
|
||||
@ -689,7 +692,7 @@ class DPOTrainer(BaseTrainer):
|
||||
Args:
|
||||
features (`dict[str, str]`):
|
||||
Row of the dataset, should contain the keys `"prompt"`, `"chosen"`, and `"rejected"`.
|
||||
processing_class (`PreTrainedTokenizerBase`):
|
||||
processing_class ([`~transformers.PreTrainedTokenizerBase`]):
|
||||
Processing class used to process the data.
|
||||
max_prompt_length (`int` or `None`):
|
||||
Maximum length of the prompt sequence. If `None`, the prompt sequence is not truncated.
|
||||
@ -790,6 +793,8 @@ class DPOTrainer(BaseTrainer):
|
||||
output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0]
|
||||
if "image_sizes" in processed_features:
|
||||
output["image_sizes"] = processed_features["image_sizes"][0]
|
||||
if "token_type_ids" in processed_features:
|
||||
output["token_type_ids"] = processed_features["token_type_ids"][0]
|
||||
|
||||
return output
|
||||
|
||||
@ -804,6 +809,7 @@ class DPOTrainer(BaseTrainer):
|
||||
"chosen_input_ids",
|
||||
"rejected_input_ids",
|
||||
"image_sizes",
|
||||
"token_type_ids",
|
||||
"ref_chosen_logps",
|
||||
"ref_rejected_logps",
|
||||
]
|
||||
@ -991,6 +997,8 @@ class DPOTrainer(BaseTrainer):
|
||||
)
|
||||
if "image_sizes" in batch:
|
||||
output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0)
|
||||
if "token_type_ids" in batch:
|
||||
output["token_type_ids"] = torch.cat((batch["token_type_ids"], batch["token_type_ids"]))
|
||||
|
||||
# Concatenate the chosen and rejected completions
|
||||
max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
|
||||
@ -1516,6 +1524,9 @@ class DPOTrainer(BaseTrainer):
|
||||
# Concatenate the prompt and completion inputs
|
||||
input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
|
||||
attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1)
|
||||
if "token_type_ids" in concatenated_batch:
|
||||
prompt_token_type_ids = concatenated_batch["token_type_ids"]
|
||||
token_type_ids = pad_to_length(prompt_token_type_ids, input_ids.shape[1], 0)
|
||||
# Mask the prompt but not the completion for the loss
|
||||
loss_mask = torch.cat(
|
||||
(torch.zeros_like(prompt_attention_mask), completion_attention_mask),
|
||||
@ -1528,7 +1539,12 @@ class DPOTrainer(BaseTrainer):
|
||||
# Flush left to reduce the memory usage
|
||||
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
|
||||
# [0, x, x, x, 0, 0]] [x, x, x, 0]]
|
||||
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
|
||||
if "token_type_ids" in concatenated_batch:
|
||||
attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
|
||||
attention_mask, input_ids, loss_mask, token_type_ids
|
||||
)
|
||||
else:
|
||||
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
|
||||
attention_mask = attention_mask[:, : self.max_length]
|
||||
input_ids = input_ids[:, : self.max_length]
|
||||
loss_mask = loss_mask[:, : self.max_length]
|
||||
@ -1536,11 +1552,22 @@ class DPOTrainer(BaseTrainer):
|
||||
# Flush right before truncating left, then flush left
|
||||
# [[0, 0, x, x, x, x], -> [[0, 0, x, x],
|
||||
# [0, x, x, x, 0, 0]] [0, x, x, x]]
|
||||
attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
|
||||
if "token_type_ids" in concatenated_batch:
|
||||
attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
|
||||
attention_mask, input_ids, loss_mask, token_type_ids
|
||||
)
|
||||
token_type_ids = token_type_ids[:, -self.max_length :]
|
||||
else:
|
||||
attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
|
||||
input_ids = input_ids[:, -self.max_length :]
|
||||
attention_mask = attention_mask[:, -self.max_length :]
|
||||
loss_mask = loss_mask[:, -self.max_length :]
|
||||
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
|
||||
if "token_type_ids" in concatenated_batch:
|
||||
attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
|
||||
attention_mask, input_ids, loss_mask, token_type_ids
|
||||
)
|
||||
else:
|
||||
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
|
||||
@ -1550,7 +1577,15 @@ class DPOTrainer(BaseTrainer):
|
||||
# Flush left to reduce the memory usage
|
||||
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
|
||||
# [0, x, x, x, 0, 0]] [x, x, x, 0]]
|
||||
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
|
||||
if "token_type_ids" in concatenated_batch:
|
||||
attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
|
||||
attention_mask, input_ids, loss_mask, token_type_ids
|
||||
)
|
||||
else:
|
||||
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
|
||||
|
||||
if "token_type_ids" in concatenated_batch:
|
||||
model_kwargs["token_type_ids"] = token_type_ids
|
||||
|
||||
if self.use_logits_to_keep:
|
||||
# Compute logits_to_keep based on loss_mask pattern:
|
||||
|
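The hunks above extend every `flush_left` / `flush_right` call so that `token_type_ids` are shifted together with `attention_mask`, `input_ids`, and `loss_mask`. As a rough illustration of what `flush_left` does, following the `[[0, 0, x, x, x, x], [0, x, x, x, 0, 0]] -> [[x, x, x, x], [x, x, x, 0]]` diagrams in the diff, here is a minimal re-implementation; the real `trl.trainer.utils.flush_left` may differ in details such as empty-row handling:

```python
# Illustrative re-implementation of the flush-left behaviour described in the diff comments.
import torch


def flush_left(mask: torch.Tensor, *tensors: torch.Tensor):
    # Shift every row left so its first attended position lands in column 0,
    # then drop trailing columns that are padding for all rows.
    first = mask.long().argmax(dim=1)  # index of the first 1 in each row
    idx = (torch.arange(mask.size(1), device=mask.device) + first[:, None]) % mask.size(1)
    mask = torch.gather(mask, 1, idx)
    shifted = [torch.gather(t, 1, idx) for t in tensors]
    keep = int(mask.sum(dim=1).max())  # longest row after the shift
    return (mask[:, :keep], *(t[:, :keep] for t in shifted))


attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1], [0, 1, 1, 1, 0, 0]])
input_ids = torch.tensor([[9, 9, 5, 6, 7, 8], [9, 1, 2, 3, 9, 9]])
attention_mask, input_ids = flush_left(attention_mask, input_ids)
print(attention_mask)  # tensor([[1, 1, 1, 1], [1, 1, 1, 0]])
```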
@ -92,10 +92,10 @@ class GRPOConfig(TrainingArguments):
|
||||
cache_implementation (`str`, *optional*):
|
||||
Implementation of the cache method for faster generation when `use_vllm` is set to `False`.
|
||||
generation_kwargs (`dict[str, Any]`, *optional*):
|
||||
Additional keyword arguments to pass to `GenerationConfig` (if using transformers) or `SamplingParams` (if
|
||||
using vLLM) when sampling completions. This can be used to further customize the generation behavior, such
|
||||
as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict with the other generation
|
||||
parameters (like `min_p`, `top_p`, etc.), they will override them.
|
||||
Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or
|
||||
`SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the
|
||||
generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict
|
||||
with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them.
|
||||
|
||||
> Parameters that control generation acceleration powered by vLLM
|
||||
|
||||
|
@ -14,7 +14,6 @@
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import textwrap
|
||||
from collections import defaultdict, deque
|
||||
from contextlib import nullcontext
|
||||
@ -71,7 +70,6 @@ from .utils import (
|
||||
shuffle_sequence_dict,
|
||||
split_pixel_values_by_grid,
|
||||
split_tensor_dict,
|
||||
truncate_with_protected_tokens,
|
||||
unsplit_pixel_values_by_grid,
|
||||
)
|
||||
|
||||
@ -176,7 +174,7 @@ class GRPOTrainer(BaseTrainer):
|
||||
processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A
|
||||
padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token,
|
||||
`tokenizer.eos_token` will be used as the default.
|
||||
reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*):
|
||||
reward_processing_classes ([`~transformers.PreTrainedTokenizerBase`] or `list[PreTrainedTokenizerBase]`, *optional*):
|
||||
Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
|
||||
|
||||
- A single processing class: Used when `reward_funcs` contains only one reward function.
|
||||
@ -275,7 +273,7 @@ class GRPOTrainer(BaseTrainer):
|
||||
|
||||
# Processing class
|
||||
if processing_class is None:
|
||||
processing_class = AutoProcessor.from_pretrained(model.config._name_or_path)
|
||||
processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left")
|
||||
|
||||
# Handle pad token for processors or tokenizers
|
||||
if isinstance(processing_class, ProcessorMixin):
|
||||
@ -291,10 +289,6 @@ class GRPOTrainer(BaseTrainer):
|
||||
self.pad_token = tokenizer.pad_token
|
||||
self.pad_token_id = tokenizer.pad_token_id
|
||||
self.eos_token_id = tokenizer.eos_token_id
|
||||
self.image_token = getattr(processing_class, "image_token", None)
|
||||
self.image_token_id = getattr(processing_class, "image_token_id", None)
|
||||
self.vision_start_token_id = getattr(model.config, "vision_start_token_id", None)
|
||||
self.vision_end_token_id = getattr(model.config, "vision_end_token_id", None)
|
||||
|
||||
# Reward functions
|
||||
if not isinstance(reward_funcs, list):
|
||||
@ -1092,58 +1086,12 @@ class GRPOTrainer(BaseTrainer):
|
||||
maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
|
||||
]
|
||||
|
||||
prompt_inputs = self.processing_class(
|
||||
text=prompts_text,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
padding_side="left",
|
||||
add_special_tokens=False,
|
||||
**kwargs,
|
||||
)
|
||||
prompt_inputs = super()._prepare_inputs(prompt_inputs)
|
||||
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
|
||||
|
||||
if self.max_prompt_length is not None:
|
||||
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
|
||||
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
|
||||
|
||||
# If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
|
||||
# Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
|
||||
# tokens are needed for generation.
|
||||
protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
|
||||
protected = [token for token in protected if token is not None]
|
||||
prompt_ids = [truncate_with_protected_tokens(ids, self.max_prompt_length, protected) for ids in prompt_ids]
|
||||
|
||||
prompts_text = self.processing_class.batch_decode(
|
||||
prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
|
||||
)
|
||||
|
||||
# The chat template sometimes inserts a single image token into the prompt text. However, when this text is
|
||||
# later tokenized, the single image token string is expanded into multiple image token IDs, depending on the
|
||||
# image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
|
||||
# collapse them back into a single token string to match the original chat template in case it originally
|
||||
# applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images
|
||||
# (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only
|
||||
# the vision_start_token_id (e.g. <start_of_image>).
|
||||
if self.image_token is not None:
|
||||
escaped_img_token = re.escape(self.image_token)
|
||||
# Search for the image token in the chat template
|
||||
if re.search(escaped_img_token, self.processing_class.chat_template):
|
||||
prompts_text = [
|
||||
re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
|
||||
]
|
||||
else:
|
||||
# If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id
|
||||
if self.vision_end_token_id is not None:
|
||||
escaped_eoi_token = re.escape(
|
||||
self.processing_class.tokenizer.decode([self.vision_end_token_id])
|
||||
)
|
||||
prompts_text = [
|
||||
re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
|
||||
]
|
||||
else:
|
||||
# If vision_end_token_id is None, just remove the image tokens
|
||||
prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]
|
||||
if images is not None:
|
||||
prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
|
||||
prompt_inputs = super()._prepare_inputs(prompt_inputs)
|
||||
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
|
||||
else:
|
||||
forward_kwargs = {}
|
||||
|
||||
# Generate completions using either vLLM or regular generation
|
||||
if self.use_vllm:
|
||||
@ -1185,6 +1133,7 @@ class GRPOTrainer(BaseTrainer):
|
||||
top_k=-1 if self.top_k is None else self.top_k,
|
||||
min_p=0.0 if self.min_p is None else self.min_p,
|
||||
max_tokens=self.max_completion_length,
|
||||
truncate_prompt_tokens=self.max_prompt_length,
|
||||
guided_decoding_regex=self.guided_decoding_regex,
|
||||
generation_kwargs=self.args.generation_kwargs,
|
||||
)
|
||||
@ -1223,6 +1172,7 @@ class GRPOTrainer(BaseTrainer):
|
||||
"top_k": -1 if self.top_k is None else self.top_k,
|
||||
"min_p": 0.0 if self.min_p is None else self.min_p,
|
||||
"max_tokens": self.max_completion_length,
|
||||
"truncate_prompt_tokens": self.max_prompt_length,
|
||||
"guided_decoding": guided_decoding,
|
||||
"logprobs": 0, # only return the logprob of the generated token
|
||||
}
|
||||
@ -1319,7 +1269,17 @@ class GRPOTrainer(BaseTrainer):

        else:
            # Regular generation path
            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
            generate_inputs = self.processing_class(
                text=prompts_text,
                return_tensors="pt",
                padding=True,
                padding_side="left",
                max_length=self.max_prompt_length,
                truncation=True,
                add_special_tokens=False,
                **kwargs,
            )
            generate_inputs = super()._prepare_inputs(generate_inputs)

            with (
                profiling_context(self, "transformers.generate"),
@ -1330,15 +1290,11 @@ class GRPOTrainer(BaseTrainer):
                FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
            ):
                prompt_completion_ids = unwrapped_model.generate(
                    input_ids=prompt_ids,
                    attention_mask=prompt_mask,
                    **forward_kwargs,
                    generation_config=self.generation_config,
                    disable_compile=True,
                    **generate_inputs, generation_config=self.generation_config, disable_compile=True
                )
            # Compute prompt length and extract completion ids
            prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"]
            prompt_length = prompt_ids.size(1)
            prompt_ids = prompt_completion_ids[:, :prompt_length]
            completion_ids = prompt_completion_ids[:, prompt_length:]

            # Mask everything after the first EOS token
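Taken together with loading the processor via `AutoProcessor.from_pretrained(..., truncation_side="left")` earlier in the diff, this refactor drops the manual `truncate_with_protected_tokens` pass and lets the tokenizer enforce `max_length=self.max_prompt_length` at encode time. A minimal standalone sketch of that behaviour (the checkpoint name is only an example):

```python
# Left-sided truncation keeps the *end* of an over-long prompt, which is what
# max_prompt_length is meant to preserve.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", truncation_side="left")
enc = tokenizer("a very long prompt " * 100, truncation=True, max_length=32, return_tensors="pt")
print(enc["input_ids"].shape)  # torch.Size([1, 32]) -> only the last 32 prompt tokens remain
```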
@ -279,25 +279,25 @@ class KTOTrainer(BaseTrainer):
|
||||
Initialize KTOTrainer.
|
||||
|
||||
Args:
|
||||
model (`transformers.PreTrainedModel`):
|
||||
The model to train, preferably an `AutoModelForSequenceClassification`.
|
||||
ref_model (`PreTrainedModelWrapper`):
|
||||
model ([`~transformers.PreTrainedModel`]):
|
||||
The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`].
|
||||
ref_model ([`PreTrainedModelWrapper`]):
|
||||
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation
|
||||
and loss. If no reference model is provided, the trainer will create a reference model with the same
|
||||
architecture as the model to be optimized.
|
||||
args (`KTOConfig`):
|
||||
args ([`KTOConfig`]):
|
||||
The arguments to use for training.
|
||||
train_dataset (`datasets.Dataset`):
|
||||
train_dataset ([`~datasets.Dataset`]):
|
||||
The dataset to use for training.
|
||||
eval_dataset (`datasets.Dataset`):
|
||||
eval_dataset ([`~datasets.Dataset`]):
|
||||
The dataset to use for evaluation.
|
||||
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
|
||||
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
||||
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
||||
reuse the fine-tuned model.
|
||||
data_collator (`transformers.DataCollator`, *optional*):
|
||||
data_collator ([`~transformers.DataCollator`], *optional*):
|
||||
The data collator to use for training. If None is specified, the default data collator
|
||||
(`DPODataCollatorWithPadding`) will be used which will pad the sequences to the maximum length of the
|
||||
([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
|
||||
sequences in the batch, given a dataset of paired sequences.
|
||||
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
||||
The model initializer to use for training. If None is specified, the default model initializer will be
|
||||
|
@ -193,7 +193,7 @@ class ModelConfig:
        if self.torch_dtype and not self.dtype:
            warnings.warn(
                "`torch_dtype` is deprecated and will be removed in version 0.27.0, please use `dtype` instead.",
                DeprecationWarning,
                FutureWarning,
            )
            self.dtype = self.torch_dtype
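The switch from `DeprecationWarning` to `FutureWarning` is not cosmetic: Python's default warning filters hide `DeprecationWarning` when it is raised from library code, while `FutureWarning` is always shown, so users who still set `torch_dtype` actually see the notice. A one-line illustration:

```python
# Shown by default; the equivalent DeprecationWarning from library code would normally be filtered out.
import warnings

warnings.warn("`torch_dtype` is deprecated and will be removed in version 0.27.0, please use `dtype` instead.", FutureWarning)
```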
@ -58,25 +58,26 @@ class NashMDTrainer(OnlineDPOTrainer):
It is implemented as a subclass of [`OnlineDPOTrainer`].

Args:
model (`transformers.PreTrainedModel`):
model ([`~transformers.PreTrainedModel`]):
The model to train, preferably an `AutoModelForCausalLM`.
ref_model (`PreTrainedModelWrapper`):
ref_model ([`PreTrainedModelWrapper`]):
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation
and loss. If no reference model is provided, the trainer will create a reference model with the same
architecture as the model to be optimized.
reward_funcs (`transformers.PreTrainedModel`):
The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
judge (`BasePairwiseJudge`):
reward_funcs ([`~transformers.PreTrainedModel`]):
The reward model to score completions with, preferably an
[`~transformers.AutoModelForSequenceClassification`].
judge ([`BasePairwiseJudge`]):
The judge to use for pairwise comparison of model completions.
args (`NashMDConfig`):
args ([`NashMDConfig`]):
The NashMD config arguments to use for training.
data_collator (`transformers.DataCollator`):
data_collator ([`~transformers.DataCollator`]):
The data collator to use for training. If None is specified, the default data collator
(`DPODataCollatorWithPadding`) will be used which will pad the sequences to the maximum length of the
([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
sequences in the batch, given a dataset of paired sequences.
train_dataset (`datasets.Dataset`):
train_dataset ([`~datasets.Dataset`]):
The dataset to use for training.
eval_dataset (`datasets.Dataset`):
eval_dataset ([`~datasets.Dataset`]):
The dataset to use for evaluation.
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
Processing class used to process the data. If provided, will be used to automatically process the inputs
@ -95,10 +95,10 @@ class OnlineDPOConfig(TrainingArguments):
cache_implementation (`str`, *optional*):
Implementation of the cache method for faster generation when `use_vllm` is set to `False`.
generation_kwargs (`dict[str, Any]`, *optional*):
Additional keyword arguments to pass to `GenerationConfig` (if using transformers) or `SamplingParams` (if
using vLLM) when sampling completions. This can be used to further customize the generation behavior, such
as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict with the other generation
parameters (like `min_p`, `top_p`, etc.), they will override them.
Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or
`SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the
generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict
with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them.

> Parameters that control generation acceleration powered by vLLM
@ -412,3 +412,10 @@ class OnlineDPOConfig(TrainingArguments):

if hasattr(self.beta, "__len__") and len(self.beta) == 1:
self.beta = self.beta[0]

if self.max_new_tokens >= self.max_length:
warnings.warn(
f"The configuration has `max_new_tokens` ({self.max_new_tokens}) >= `max_length` ({self.max_length}). "
"This will cause prompts to be truncated or completely removed in the forward pass. "
"To preserve prompts, ensure e.g. `max_length > max_new_tokens + 512`. ",
)
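The new `__post_init__` check warns when the completion budget crowds out the prompt. A hedged usage sketch: `max_new_tokens` and `max_length` are the fields shown in the hunk, while `output_dir` is assumed to be inherited from `TrainingArguments`.

from trl import OnlineDPOConfig

# Keep comfortable headroom between max_length and max_new_tokens so prompts are not
# truncated in the forward pass (the hunk above suggests max_length > max_new_tokens + 512).
config = OnlineDPOConfig(
    output_dir="online-dpo-demo",
    max_new_tokens=256,
    max_length=1024,
)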
@ -57,8 +57,13 @@ from ..data_utils import apply_chat_template, is_conversational, maybe_apply_cha
from ..extras.profiling import profiling_context
from ..extras.vllm_client import VLLMClient
from ..import_utils import is_vllm_available
from ..models import create_reference_model, prepare_peft_model
from ..models.utils import unwrap_model_for_generation
from ..models import (
create_reference_model,
prepare_deepspeed,
prepare_fsdp,
prepare_peft_model,
unwrap_model_for_generation,
)
from .base_trainer import BaseTrainer
from .judges import BasePairwiseJudge
from .online_dpo_config import OnlineDPOConfig

@ -69,7 +74,6 @@ from .utils import (
empty_cache,
ensure_master_addr_port,
pad,
prepare_deepspeed,
truncate_right,
)
@ -113,10 +117,10 @@ class OnlineDPOTrainer(BaseTrainer):
using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in
`args.model_init_kwargs`.
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
ref_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
ref_model ([`~transformers.PreTrainedModel`] or `torch.nn.Module` or `None`):
The reference model to use for training. If None is specified, the reference model will be created from the
model.
judge (`BasePairwiseJudge`):
judge ([`BasePairwiseJudge`]):
The judge to use for pairwise comparison of model completions.
reward_funcs (`Union[RewardFunc, list[RewardFunc]]`, *optional*):
Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward

@ -127,11 +131,11 @@ class OnlineDPOTrainer(BaseTrainer):
- A list of reward functions: Must all be of compatible types.

Note: Only one of `judge`, or `reward_funcs` should be provided.
args (`OnlineDPOConfig`):
args ([`OnlineDPOConfig`]):
The online DPO config arguments to use for training.
data_collator (`transformers.DataCollator`):
data_collator ([`~transformers.DataCollator`]):
The data collator to use for training. If None is specified, the default data collator
(`DPODataCollatorWithPadding`) will be used which will pad the sequences to the maximum length of the
([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
sequences in the batch, given a dataset of paired sequences.
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
The dataset to use for training.

@ -141,7 +145,7 @@ class OnlineDPOTrainer(BaseTrainer):
Processing class used to process the data. If provided, will be used to automatically process the inputs
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
reuse the fine-tuned model.
reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*):
reward_processing_classes ([`~transformers.PreTrainedTokenizerBase`] or `list[PreTrainedTokenizerBase]`, *optional*):
Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:

- A single processing class: Used when `reward_funcs` contains only one reward function.
@ -330,7 +334,7 @@ class OnlineDPOTrainer(BaseTrainer):
logger.warning(
"The `missing_eos_penalty` parameter is deprecated when used with the deprecated `reward_model` parameter. "
"Please use `reward_funcs` instead of `reward_model` to continue using this feature.",
DeprecationWarning,
FutureWarning,
stacklevel=2,
)
else:
@ -588,24 +592,20 @@ class OnlineDPOTrainer(BaseTrainer):
generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
self.generation_config = GenerationConfig(**generation_kwargs)

if self.is_deepspeed_enabled:
if self.ref_model is not None:
self.ref_model = prepare_deepspeed(
self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
)
# Prepare reward function models for DeepSpeed
if self.reward_funcs is not None:
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, PreTrainedModel):
if self.ref_model is not None:
if self.is_deepspeed_enabled:
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
elif self.is_fsdp_enabled:
self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
if self.reward_funcs is not None:
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, PreTrainedModel):
if self.is_deepspeed_enabled:
self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator)
else:
if self.ref_model is not None:
self.ref_model = self.ref_model.to(self.accelerator.device)
# Prepare reward function models for FSDP/regular training
if self.reward_funcs is not None:
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, PreTrainedModel):
# Set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp
else:
# set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp
self.reward_funcs[i] = self.accelerator.prepare_model(
reward_func, evaluation_mode=True, device_placement=True
)
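The refactor above funnels the reference model through one backend-aware preparation step instead of a DeepSpeed-only branch. A minimal sketch of that dispatch, treating `prepare_deepspeed` / `prepare_fsdp` as opaque helpers with the `(model, accelerator)` signature used in the hunk:

def prepare_frozen_model(model, accelerator, *, is_deepspeed_enabled, is_fsdp_enabled,
                         prepare_deepspeed, prepare_fsdp):
    # Sketch only: wrap an eval-only model (reference or reward model) for the active backend.
    if model is None:
        return None
    if is_deepspeed_enabled:
        return prepare_deepspeed(model, accelerator)
    if is_fsdp_enabled:
        return prepare_fsdp(model, accelerator)
    # No sharding backend: let accelerate handle device placement and keep the model in eval mode.
    return accelerator.prepare_model(model, evaluation_mode=True)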
@ -833,8 +833,10 @@ class OnlineDPOTrainer(BaseTrainer):

def _generate_vllm_colocate(self, prompts, images=None):
"""Generate completions using vLLM colocate mode"""
# Update model weights if needed
self._move_model_to_vllm()
# Update model weights if needed - only after gradient accumulation completes
if self.state.global_step != self._last_loaded_step:
self._move_model_to_vllm()
self._last_loaded_step = self.state.global_step

# Apply chat template if conversational
if is_conversational({"prompt": prompts[0]}):
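The change above stops re-uploading policy weights to the colocated vLLM engine on every generation call; weights now move only when the global step has advanced, that is, after gradient accumulation has produced an optimizer step. A generic sketch of that gate (the `sync_fn` callable is hypothetical):

class WeightSyncGate:
    """Push weights to an inference engine only when training has actually stepped (sketch)."""

    def __init__(self, sync_fn):
        self._sync_fn = sync_fn          # hypothetical callable that copies the latest weights
        self._last_loaded_step = -1      # no weights pushed yet

    def maybe_sync(self, global_step: int) -> bool:
        if global_step != self._last_loaded_step:
            self._sync_fn()
            self._last_loaded_step = global_step
            return True
        return False                     # still accumulating gradients; reuse cached weights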
@ -1234,10 +1236,12 @@ class OnlineDPOTrainer(BaseTrainer):
# Get the logprobs of the completions from the model
output = model(prompt_completion_ids, **model_kwargs)

# There is 1 offset, because the model predict the next token
# There is 1 offset, because the model predicts the next token
prompt_len = prompt_ids.size(1)
start_idx = prompt_len - 1 if prompt_len > 0 else 0
logits = output.logits[:, start_idx:-1]
# Only slice off the last logit when we have a prompt, otherwise we need all logits
end_idx = -1 if prompt_len > 0 else None
logits = output.logits[:, start_idx:end_idx]

# Take the completion tokens logprob
logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
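The fix above keeps every logit when the prompt is empty. Below is a self-contained sketch of the offset itself: the logit at position t scores the token at position t+1, so completion logprobs start at `prompt_len - 1` and the final logit is dropped only when a prompt exists (assumes only `torch`; shapes are illustrative):

import torch

batch, prompt_len, completion_len, vocab = 2, 4, 3, 11
logits = torch.randn(batch, prompt_len + completion_len, vocab)   # model output over prompt+completion
completion_ids = torch.randint(vocab, (batch, completion_len))

# Logits at index t predict token t+1, hence the -1 offset. Drop the last position only
# when a prompt exists; with an empty prompt every logit is needed.
start_idx = prompt_len - 1 if prompt_len > 0 else 0
end_idx = -1 if prompt_len > 0 else None
completion_logits = logits[:, start_idx:end_idx]                  # (batch, completion_len, vocab)

logprobs = torch.take_along_dim(
    completion_logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2
).squeeze(-1)
assert logprobs.shape == (batch, completion_len)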
@ -81,17 +81,17 @@ class ORPOTrainer(BaseTrainer):
Initialize ORPOTrainer.

Args:
model (`transformers.PreTrainedModel`):
The model to train, preferably an `AutoModelForSequenceClassification`.
args (`ORPOConfig`):
model ([`~transformers.PreTrainedModel`]):
The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`].
args ([`ORPOConfig`]):
The ORPO config arguments to use for training.
data_collator (`transformers.DataCollator`):
data_collator ([`~transformers.DataCollator`]):
The data collator to use for training. If None is specified, the default data collator
(`DPODataCollatorWithPadding`) will be used which will pad the sequences to the maximum length of the
([`DPODataCollatorWithPadding`]) will be used which will pad the sequences to the maximum length of the
sequences in the batch, given a dataset of paired sequences.
train_dataset (`datasets.Dataset`):
train_dataset ([`~datasets.Dataset`]):
The dataset to use for training.
eval_dataset (`datasets.Dataset`):
eval_dataset ([`~datasets.Dataset`]):
The dataset to use for evaluation.
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
Processing class used to process the data. If provided, will be used to automatically process the inputs
@ -51,17 +51,17 @@ class PRMTrainer(BaseTrainer):
Initialize PRMTrainer.

Args:
model (`transformers.PreTrainedModel`):
model ([`~transformers.PreTrainedModel`]):
The model to train, preferably an `AutoModelForTokenClassification`.
args (`PRMConfig`):
args ([`PRMConfig`]):
The arguments to use for training.
data_collator (`transformers.DataCollator`):
data_collator ([`~transformers.DataCollator`]):
The data collator to use for training. If None is specified, the default data collator
(`DataCollatorForTokenClassification`) will be used which will pad the sequences to the maximum length of
the sequences in the batch, given a dataset of paired sequences.
train_dataset (`datasets.Dataset`):
([`~transformers.DataCollatorForTokenClassification`]) will be used which will pad the sequences to the
maximum length of the sequences in the batch, given a dataset of paired sequences.
train_dataset ([`~datasets.Dataset`]):
The dataset to use for training.
eval_dataset (`datasets.Dataset`):
eval_dataset ([`~datasets.Dataset`]):
The dataset to use for evaluation.
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
Processing class used to process the data. If provided, will be used to automatically process the inputs

@ -219,7 +219,7 @@ class PRMTrainer(BaseTrainer):
Args:
features (`dict[str, str]`):
Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
tokenizer (`PreTrainedTokenizerBase`):
tokenizer ([`~transformers.PreTrainedTokenizerBase`]):
Tokenizer used to process the data.
step_separator (`str`):
Separator between steps in the completion.
@ -93,10 +93,10 @@ class RLOOConfig(TrainingArguments):
cache_implementation (`str`, *optional*):
Implementation of the cache method for faster generation when `use_vllm` is set to `False`.
generation_kwargs (`dict[str, Any]`, *optional*):
Additional keyword arguments to pass to `GenerationConfig` (if using transformers) or `SamplingParams` (if
using vLLM) when sampling completions. This can be used to further customize the generation behavior, such
as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict with the other generation
parameters (like `min_p`, `top_p`, etc.), they will override them.
Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or
`SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the
generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict
with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them.

> Parameters that control generation acceleration powered by vLLM
@ -14,7 +14,6 @@

import inspect
import os
import re
import textwrap
import warnings
from collections import defaultdict, deque

@ -71,7 +70,6 @@ from .utils import (
shuffle_sequence_dict,
split_pixel_values_by_grid,
split_tensor_dict,
truncate_with_protected_tokens,
unsplit_pixel_values_by_grid,
)
@ -173,7 +171,7 @@ class RLOOTrainer(BaseTrainer):
processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A
padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token,
`tokenizer.eos_token` will be used as the default.
reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*):
reward_processing_classes ([`~transformers.PreTrainedTokenizerBase`] or `list[PreTrainedTokenizerBase]`, *optional*):
Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:

- A single processing class: Used when `reward_funcs` contains only one reward function.

@ -394,7 +392,7 @@ class RLOOTrainer(BaseTrainer):

# Processing class
if processing_class is None:
processing_class = AutoProcessor.from_pretrained(model.config._name_or_path)
processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left")

# Handle pad token for processors or tokenizers
if isinstance(processing_class, ProcessorMixin):

@ -410,10 +408,6 @@ class RLOOTrainer(BaseTrainer):
self.pad_token = tokenizer.pad_token
self.pad_token_id = tokenizer.pad_token_id
self.eos_token_id = tokenizer.eos_token_id
self.image_token = getattr(processing_class, "image_token", None)
self.image_token_id = getattr(processing_class, "image_token_id", None)
self.vision_start_token_id = getattr(model.config, "vision_start_token_id", None)
self.vision_end_token_id = getattr(model.config, "vision_end_token_id", None)

# Reward functions
if not isinstance(reward_funcs, list):
@ -1088,58 +1082,12 @@ class RLOOTrainer(BaseTrainer):
maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
]

prompt_inputs = self.processing_class(
text=prompts_text,
return_tensors="pt",
padding=True,
padding_side="left",
add_special_tokens=False,
**kwargs,
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}

if self.max_prompt_length is not None:
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]

# If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
# Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
# tokens are needed for generation.
protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
protected = [token for token in protected if token is not None]
prompt_ids = [truncate_with_protected_tokens(ids, self.max_prompt_length, protected) for ids in prompt_ids]

prompts_text = self.processing_class.batch_decode(
prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)

# The chat template sometimes inserts a single image token into the prompt text. However, when this text is
# later tokenized, the single image token string is expanded into multiple image token IDs, depending on the
# image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
# collapse them back into a single token string to match the original chat template in case it originally
# applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images
# (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only
# the vision_start_token_id (e.g. <start_of_image>).
if self.image_token is not None:
escaped_img_token = re.escape(self.image_token)
# Search for the image token in the chat template
if re.search(escaped_img_token, self.processing_class.chat_template):
prompts_text = [
re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
]
else:
# If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id
if self.vision_end_token_id is not None:
escaped_eoi_token = re.escape(
self.processing_class.tokenizer.decode([self.vision_end_token_id])
)
prompts_text = [
re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
]
else:
# If vision_end_token_id is None, just remove the image tokens
prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]
if images is not None:
prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
else:
forward_kwargs = {}

# Generate completions using either vLLM or regular generation
if self.use_vllm:
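The removed block above collapsed runs of expanded image-token strings after detokenization, or stripped them entirely when the chat template only marks image boundaries. A standalone sketch of the regex trick with hypothetical `<image>` / `<end_of_image>` strings:

import re

image_token = "<image>"          # hypothetical placeholder string used by a chat template
end_of_image = "<end_of_image>"  # hypothetical vision-end marker
escaped_img = re.escape(image_token)

decoded = "Describe this: <image><image><image><end_of_image> please."

# Template uses the image token itself: collapse the expanded run back to a single placeholder.
collapsed = re.sub(rf"({escaped_img})+", image_token, decoded)

# Template only marks where an image starts: drop the expanded tokens and the end marker.
stripped = re.sub(rf"({escaped_img})+{re.escape(end_of_image)}", "", decoded)

print(collapsed)  # single <image> followed by <end_of_image>
print(stripped)   # image run and end marker removed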
@ -1181,6 +1129,7 @@ class RLOOTrainer(BaseTrainer):
top_k=-1 if self.top_k is None else self.top_k,
min_p=0.0 if self.min_p is None else self.min_p,
max_tokens=self.max_completion_length,
truncate_prompt_tokens=self.max_prompt_length,
guided_decoding_regex=self.guided_decoding_regex,
generation_kwargs=self.args.generation_kwargs,
)

@ -1218,6 +1167,7 @@ class RLOOTrainer(BaseTrainer):
"top_k": -1 if self.top_k is None else self.top_k,
"min_p": 0.0 if self.min_p is None else self.min_p,
"max_tokens": self.max_completion_length,
"truncate_prompt_tokens": self.max_prompt_length,
"guided_decoding": guided_decoding,
}
if self.args.generation_kwargs is not None:
@ -1305,7 +1255,17 @@ class RLOOTrainer(BaseTrainer):

else:
# Regular generation path
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
generate_inputs = self.processing_class(
text=prompts_text,
return_tensors="pt",
padding=True,
padding_side="left",
max_length=self.max_prompt_length,
truncation=True,
add_special_tokens=False,
**kwargs,
)
generate_inputs = super()._prepare_inputs(generate_inputs)

with (
profiling_context(self, "transformers.generate"),

@ -1316,15 +1276,11 @@ class RLOOTrainer(BaseTrainer):
FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
):
prompt_completion_ids = unwrapped_model.generate(
input_ids=prompt_ids,
attention_mask=prompt_mask,
**forward_kwargs,
generation_config=self.generation_config,
disable_compile=True,
**generate_inputs, generation_config=self.generation_config, disable_compile=True
)
# Compute prompt length and extract completion ids
prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"]
prompt_length = prompt_ids.size(1)
prompt_ids = prompt_completion_ids[:, :prompt_length]
completion_ids = prompt_completion_ids[:, prompt_length:]

# Mask everything after the first EOS token
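The new generation path above leans on the tokenizer for left-side truncation and padding (the processor is now loaded with `truncation_side="left"`), then recovers completions by slicing at the shared, padded prompt length. A minimal sketch with a tiny causal LM (the checkpoint name is only an example):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "sshleifer/tiny-gpt2"  # example checkpoint; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(checkpoint, truncation_side="left", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(checkpoint)

prompts = ["a very long prompt whose tail is what matters for generation", "short prompt"]
inputs = tokenizer(
    prompts, return_tensors="pt", padding=True, truncation=True, max_length=8, add_special_tokens=False
)

with torch.no_grad():
    prompt_completion_ids = model.generate(
        **inputs, max_new_tokens=5, do_sample=False, pad_token_id=tokenizer.pad_token_id
    )

# With left padding every row shares the same prompt width, so the split point is just size(1).
prompt_length = inputs["input_ids"].size(1)
completion_ids = prompt_completion_ids[:, prompt_length:]
print(tokenizer.batch_decode(completion_ids, skip_special_tokens=True))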
@ -273,9 +273,9 @@ class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
Additional keys may be present depending on the processor, such as `"image_grid_thw"`.

Args:
processor (`ProcessorMixin`):
The processor used to tokenize text and process images. It must be a subclass of `ProcessorMixin` and
include a `tokenizer` with a defined `pad_token_id`.
processor ([`~transformers.ProcessorMixin`]):
The processor used to tokenize text and process images. It must be a subclass of
[`~transformers.ProcessorMixin`] and include a `tokenizer` with a defined `pad_token_id`.
max_length (`int` or `None`, optional, defaults to `None`):
Maximum sequence length for input tokens. If `None`, no truncation is applied.
completion_only_loss (`bool`, *optional*, defaults to `False`):
@ -226,7 +226,7 @@ class RewardDataCollatorWithPadding:
`trl.trainer.reward_trainer.DataCollatorForPreference` instead.

Args:
tokenizer (`PreTrainedTokenizerBase`):
tokenizer ([`~transformers.PreTrainedTokenizerBase`]):
The tokenizer used for encoding the data.
padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`):
padding_strategy to pass to the tokenizer.

@ -245,7 +245,7 @@ class RewardDataCollatorWithPadding:
warnings.warn(
"The `RewardDataCollatorWithPadding` is deprecated and will be removed in version 0.27.0. Please use "
"`trl.trainer.reward_trainer.DataCollatorForPreference` instead.",
DeprecationWarning,
FutureWarning,
)
super().__init__(*args, **kwargs)
@ -1111,7 +1111,7 @@ def generate(
The tensor containing the input queries.
pad_token_id (`int`):
The token ID representing the pad token.
generation_config (`GenerationConfig`):
generation_config ([`~transformers.GenerationConfig`]):
The configuration for the generation process.

Returns:

@ -1263,7 +1263,7 @@ def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenize
Args:
inputs (`torch.Tensor`):
The input tensor to be decoded.
tokenizer (`transformers.PreTrainedTokenizerBase`):
tokenizer ([`~transformers.PreTrainedTokenizerBase`]):
The tokenizer used to decode the input tensor.

Returns:

@ -1273,7 +1273,7 @@ def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenize
warnings.warn(
"The function `decode_and_strip_padding` is deprecated and will be removed in a version 0.25.0. If you want "
"to keep using it, please copy the code into your codebase and use it from there.",
DeprecationWarning,
FutureWarning,
)
decoded = tokenizer.batch_decode(inputs, skip_special_tokens=False)
return [d.replace(tokenizer.pad_token, "") for d in decoded]

@ -1294,7 +1294,7 @@ def generate_model_card(
comet_url: Optional[str] = None,
) -> ModelCard:
"""
Generate a `ModelCard` from a template.
Generate a [`~huggingface_hub.ModelCard`] from a template.

Args:
base_model (`str` or `None`):

@ -1323,7 +1323,7 @@ def generate_model_card(
ArXiv paper ID as `YYMM.NNNNN`.

Returns:
`ModelCard`:
[`~huggingface_hub.ModelCard`]:
A ModelCard object.
"""
card_data = ModelCardData(

@ -1377,7 +1377,7 @@ def log_table_to_comet_experiment(name: str, table: pd.DataFrame) -> None:
Args:
name (`str`):
Table name.
table (`pd.DataFrame`):
table (`pandas.DataFrame`):
The Pandas DataFrame containing the table to log.
"""
if not is_comet_available():
@ -1925,47 +1925,6 @@ def unsplit_pixel_values_by_grid(batch: dict[str, Union[torch.Tensor, list[torch
return batch


def truncate_with_protected_tokens(ids: list[int], target_length: int, protected_tokens: list[int]) -> list[int]:
"""
Truncate list to target length while preserving protected tokens.

Args:
ids (`list[int]`):
Input sequence of token IDs.
target_length (`int`):
Desired length of the output sequence.
protected_tokens (`list[int]`):
List of token IDs that should be preserved in the output.

Returns:
`list[int]`: Truncated sequence.

Raises:
`ValueError`: If `len(protected_tokens ∩ seq) > target_length`.
"""
protected_set = set(protected_tokens)

# Count protected tokens
num_protected = sum(1 for t in ids if t in protected_set)
if num_protected > target_length:
raise ValueError(
f"target_length ({target_length}) is too small for the protected tokens ({num_protected} tokens). "
f"Please increase target length to at least {num_protected} or disable truncation."
)
num_non_protected_needed = target_length - num_protected
result = []

# Iterate backward to select all protected tokens and rightmost non-protected tokens
for t in reversed(ids):
if t in protected_set:
result.append(t)
elif num_non_protected_needed > 0:
result.append(t)
num_non_protected_needed -= 1
# Reverse to restore original order
return result[::-1]


TListOrMapping = TypeVar("TListOrMapping", list, Mapping)
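The helper removed above (superseded by tokenizer-side truncation in the RLOO hunks) kept protected tokens while trimming from the left. A compact re-statement of its contract as a sketch, with a small worked example (token values are illustrative):

def keep_with_protected(ids, target_length, protected_tokens):
    # Sketch of the removed helper's behaviour, not the TRL implementation itself.
    protected = set(protected_tokens)
    budget = target_length - sum(t in protected for t in ids)
    out = []
    for t in reversed(ids):          # walk from the right so the prompt tail survives
        if t in protected:
            out.append(t)
        elif budget > 0:
            out.append(t)
            budget -= 1
    return out[::-1]                 # restore the original order

ids = [7, 101, 8, 9, 101, 10, 11]    # 101 plays the role of a protected (e.g. image) token id
assert keep_with_protected(ids, 4, [101]) == [101, 101, 10, 11]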