Merge branch 'main' into py3.14

Quentin Gallouédec
2025-10-07 09:02:52 -06:00
committed by GitHub
5 changed files with 148 additions and 14 deletions

View File

@@ -7,6 +7,44 @@
TRL is a full-stack library that provides a set of tools for training transformer language models with methods like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO), Reward Modeling, and more.
The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers).
Below is the current list of TRL trainers, organized by method type (⚡️ = vLLM support).
<div style="display: flex; justify-content: space-between; width: 100%; gap: 2rem;">
<div style="flex: 1; min-width: 0;">
**Online methods**
- [`GRPOTrainer`] ⚡️
- [`RLOOTrainer`] ⚡️
- [`OnlineDPOTrainer`] ⚡️
- [`NashMDTrainer`] ⚡️
- [`XPOTrainer`] ⚡️
- [`PPOTrainer`]
**Reward modeling**
- [`PRMTrainer`]
- [`RewardTrainer`]
</div>
<div style="flex: 1; min-width: 0;">
**Offline methods**
- [`SFTTrainer`]
- [`DPOTrainer`]
- [`ORPOTrainer`]
- [`BCOTrainer`]
- [`CPOTrainer`]
- [`KTOTrainer`]
**Knowledge distillation**
- [`GKDTrainer`]
</div>
</div>
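All of the trainers above share a similar high-level API. As a quick orientation, here is a minimal sketch of supervised fine-tuning with [`SFTTrainer`]; the model id and dataset are illustrative choices, not part of this diff:

```python
# Minimal SFT sketch; the model id and dataset are illustrative choices.
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",           # resolved to a causal LM via transformers
    args=SFTConfig(output_dir="sft-demo"),
    train_dataset=dataset,
)
trainer.train()
```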
## 🎉 What's New
**✨ OpenAI GPT OSS Support**: TRL now fully supports fine-tuning the latest [OpenAI GPT OSS models](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4)! Check out the:

View File

@@ -42,7 +42,7 @@ from trl import SFTTrainer, SFTConfig
dataset = load_dataset("open-thoughts/OpenThoughts-114k", split="train")
peft_config = LoraConfig(lora_r=256, lora_alpha=16, lora_target_modules="all-linear")
peft_config = LoraConfig(r=256, lora_alpha=16, target_modules="all-linear")
training_args = SFTConfig(
learning_rate=2e-4,
@@ -245,9 +245,9 @@ def strip_reasoning_accuracy_reward(completions, **kwargs):
...
peft_config = LoraConfig(
lora_r=1,
r=1,
lora_alpha=32,
lora_target_modules="all-linear"
target_modules="all-linear"
)
training_args = GRPOConfig(
@@ -419,7 +419,7 @@ The blog post defines the ideal dataset size for LoRA to match full fine-tuning
### 3. *"FullFT and high-rank LoRAs have similar learning curves"*
Counterintuitively, the blog post recommends using similar learning rates to full fine-tuning. In the TRL script, we could use `--learning_rate` to set the learning rate. The \\( \frac{1}{r} \\) scaling in LoRA makes the optimal learning rate approximately rank-independent.
Counterintuitively, the blog post recommends using a higher learning rate than for full fine-tuning. In the table above, we used 1.0e-5 for LoRA and 1.0e-6 for full fine-tuning. In the TRL script, we could use `--learning_rate` to set the learning rate. The \\( \frac{1}{r} \\) scaling in LoRA makes the optimal learning rate approximately rank-independent.
![learning rate](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lora_without_regret/2.png)
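To make the \\( \frac{1}{r} \\) scaling concrete, here is a small illustrative sketch of a LoRA update (the shapes, `alpha` value, and function name are hypothetical, not code from this PR): because the adapter output is multiplied by \\( \frac{\alpha}{r} \\), increasing the rank does not increase the update magnitude, so a learning rate that works at one rank tends to transfer to another.

```python
# Illustrative LoRA forward delta showing the alpha/r scaling; all names and shapes are hypothetical.
import torch

d_model, r, alpha = 1024, 32, 32
A = torch.randn(r, d_model) * 0.01  # small random init
B = torch.zeros(d_model, r)         # zero init, so the adapter starts as a no-op

def lora_delta(x: torch.Tensor) -> torch.Tensor:
    # Effective update: (alpha / r) * x @ A.T @ B.T
    # The 1/r factor keeps the update scale roughly constant as r changes,
    # which is why the optimal learning rate is approximately rank-independent.
    return (alpha / r) * (x @ A.T @ B.T)

print(lora_delta(torch.randn(1, d_model)).shape)  # torch.Size([1, 1024])
```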

View File

@@ -155,7 +155,6 @@ def init_weights_tiny_model(model):
for model_id, config_class, model_class, suffix in [
("bigscience/bloomz-560m", BloomConfig, BloomForCausalLM, None),
("CohereForAI/aya-expanse-8b", CohereConfig, CohereForCausalLM, None),
("databricks/dbrx-instruct", DbrxConfig, DbrxForCausalLM, None),
("deepseek-ai/DeepSeek-R1", DeepseekV3Config, DeepseekV3ForCausalLM, None),
# It's important to have R1-0528 as it doesn't have the same chat template
("deepseek-ai/DeepSeek-R1-0528", DeepseekV3Config, DeepseekV3ForCausalLM, "0528"),
@@ -209,6 +208,17 @@ for model_id, config_class, model_class, suffix in [
init_weights_tiny_model(model)
push_to_hub(model, tokenizer, "tiny", suffix)
# Special case for databricks/dbrx-instruct as it requires specific changes in the config
model_id = "databricks/dbrx-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
config = DbrxConfig.from_pretrained(model_id, n_layers=2, n_heads=16, d_model=24)
# transformers mistakenly ignores ffn_config keys when loading from pretrained. We need to set them manually after
# loading the config
config.ffn_config.ffn_hidden_size = 24
config.ffn_config.hidden_size = 24
model = DbrxForCausalLM(config).to(dtype=torch.bfloat16)
init_weights_tiny_model(model)
push_to_hub(model, tokenizer, "tiny")
# Two slightly bigger models, required for vLLM testing
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-32B-Instruct")

View File

@@ -16,6 +16,8 @@ import gc
import pytest
import torch
import transformers
from packaging import version
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, GenerationConfig
@@ -63,6 +65,12 @@ class BaseTester:
Test if the v-head is added to the model successfully
"""
for model_name in self.all_model_names:
if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
transformers.__version__
) < version.parse("4.58.0.dev0"):
# DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
continue
model = self.trl_model_class.from_pretrained(model_name)
assert hasattr(model, "v_head")
@@ -71,6 +79,12 @@ class BaseTester:
Test if the v-head has the correct shape
"""
for model_name in self.all_model_names:
if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
transformers.__version__
) < version.parse("4.58.0.dev0"):
# DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
continue
model = self.trl_model_class.from_pretrained(model_name)
assert model.v_head.summary.weight.shape[0] == 1
@@ -80,6 +94,12 @@ class BaseTester:
than zeros by default.
"""
for model_name in self.all_model_names:
if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
transformers.__version__
) < version.parse("4.58.0.dev0"):
# DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
continue
model = self.trl_model_class.from_pretrained(model_name)
assert not torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias))
@@ -89,6 +109,12 @@ class BaseTester:
`from_pretrained`.
"""
for model_name in self.all_model_names:
if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
transformers.__version__
) < version.parse("4.58.0.dev0"):
# DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
continue
pretrained_model = self.transformers_model_class.from_pretrained(model_name)
model = self.trl_model_class.from_pretrained(pretrained_model)
assert hasattr(model, "v_head")
@@ -99,6 +125,12 @@ class BaseTester:
additional modules (e.g. v_head)
"""
for model_name in self.all_model_names:
if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
transformers.__version__
) < version.parse("4.58.0.dev0"):
# DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
continue
model = self.trl_model_class.from_pretrained(model_name)
model.save_pretrained(self.tmp_dir)
@@ -114,6 +146,12 @@ class BaseTester:
Test if the model can be saved and loaded from a directory and get the same weights - sharded case
"""
for model_name in self.all_model_names:
if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
transformers.__version__
) < version.parse("4.58.0.dev0"):
# DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
continue
model = self.trl_model_class.from_pretrained(model_name)
model.save_pretrained(self.tmp_dir)
@@ -129,6 +167,12 @@ class BaseTester:
Test if the model can be saved and loaded using transformers and get the same weights - sharded case
"""
for model_name in self.all_model_names:
if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
transformers.__version__
) < version.parse("4.58.0.dev0"):
# DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
continue
transformers_model = self.trl_model_class.transformers_parent_class.from_pretrained(model_name)
trl_model = self.trl_model_class.from_pretrained(model_name)
@@ -150,6 +194,12 @@ class BaseTester:
of the super class to check if the weights are the same.
"""
for model_name in self.all_model_names:
if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
transformers.__version__
) < version.parse("4.58.0.dev0"):
# DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
continue
transformers_model = self.trl_model_class.transformers_parent_class.from_pretrained(model_name)
trl_model = self.trl_model_class.from_pretrained(model_name)
@@ -200,6 +250,12 @@ class TestCausalLMValueHeadModel(BaseTester.VHeadModelTester, TrlTestCase):
EXPECTED_OUTPUT_SIZE = 3
for model_name in self.all_model_names:
if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
transformers.__version__
) < version.parse("4.58.0.dev0"):
# DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
continue
model = self.trl_model_class.from_pretrained(model_name).to(self.device)
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device)
outputs = model(input_ids)
@@ -213,6 +269,12 @@ class TestCausalLMValueHeadModel(BaseTester.VHeadModelTester, TrlTestCase):
Test that adding `summary_dropout_prob` to the pretrained model's config propagates it to the v_head
"""
for model_name in self.all_model_names:
if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
transformers.__version__
) < version.parse("4.58.0.dev0"):
# DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
continue
pretrained_model = self.transformers_model_class.from_pretrained(model_name)
pretrained_model.config.summary_dropout_prob = 0.5
model = self.trl_model_class.from_pretrained(pretrained_model)
@@ -225,6 +287,11 @@ class TestCausalLMValueHeadModel(BaseTester.VHeadModelTester, TrlTestCase):
Test that passing `summary_dropout_prob` as a keyword argument to `from_pretrained` propagates it to the v_head
"""
for model_name in self.all_model_names:
if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
transformers.__version__
) < version.parse("4.58.0.dev0"):
# DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
continue
v_head_kwargs = {"summary_dropout_prob": 0.5}
model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs)
@@ -242,6 +309,12 @@ class TestCausalLMValueHeadModel(BaseTester.VHeadModelTester, TrlTestCase):
r"""
Test if `generate` works for every model
"""
if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
transformers.__version__
) < version.parse("4.58.0.dev0"):
# DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
pytest.xfail("DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version")
generation_config = GenerationConfig(max_new_tokens=9)
model = self.trl_model_class.from_pretrained(model_name).to(self.device)
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device)
@@ -256,6 +329,12 @@ class TestCausalLMValueHeadModel(BaseTester.VHeadModelTester, TrlTestCase):
run a dummy forward pass without any issue.
"""
for model_name in self.all_model_names:
if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
transformers.__version__
) < version.parse("4.58.0.dev0"):
# DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
continue
trl_model = self.trl_model_class.from_pretrained(model_name, dtype=torch.bfloat16).to(self.device)
lm_head_namings = ["lm_head", "embed_out", "output_layer"]
@@ -276,6 +355,12 @@ class TestCausalLMValueHeadModel(BaseTester.VHeadModelTester, TrlTestCase):
@pytest.mark.skip(reason="This test needs to be run manually due to HF token issue.")
def test_push_to_hub(self):
for model_name in self.all_model_names:
if model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
transformers.__version__
) < version.parse("4.58.0.dev0"):
# DbrxConfig generated after 4.58.0 isn't compatible with modeling code before this version
continue
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
if "sharded" in model_name:
model.push_to_hub(model_name + "-ppo", use_auth_token=True, max_shard_size="1MB")
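The version guard above is repeated verbatim in every affected test. Purely as a hedged sketch (the helper name is hypothetical and not part of this PR), the check could be centralized in a small predicate and called at the top of each loop, e.g. `if is_incompatible_dbrx(model_name): continue` (or `pytest.xfail(...)` in the parameterized test):

```python
# Hypothetical helper that centralizes the repeated Dbrx/transformers version guard.
import transformers
from packaging import version


def is_incompatible_dbrx(model_name: str) -> bool:
    """True when the tiny Dbrx checkpoint cannot be loaded by the installed transformers version."""
    return model_name == "trl-internal-testing/tiny-DbrxForCausalLM" and version.parse(
        transformers.__version__
    ) < version.parse("4.58.0.dev0")
```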

View File

@@ -938,6 +938,7 @@ class SFTTrainer(BaseTrainer):
prompt_ids = processing_class.apply_chat_template(
example["prompt"],
tokenize=True,
add_generation_prompt=True,
tools=example.get("tools"),
**example.get("chat_template_kwargs", {}),
)
@@ -975,7 +976,7 @@ class SFTTrainer(BaseTrainer):
"token handling. Verify that the tokenizer is processing text consistently."
)
# Create a completion mask
# Create completion mask
completion_mask = [0] * len(prompt_ids) + [1] * (len(prompt_completion_ids) - len(prompt_ids))
output["input_ids"] = prompt_completion_ids
output["completion_mask"] = completion_mask
@@ -995,17 +996,17 @@
# Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists
# even for single examples, while for LLMs it returns lists of ints.
processed = {k: v[0] if isinstance(v[0], list) else v for k, v in processed.items()}
if "assistant_masks" in processed and 1 not in processed["assistant_masks"]:
raise RuntimeError(
"You're using `assistant_only_loss=True`, but at least one example has no "
"assistant tokens. This usually means the tokenizer's chat template doesn't "
"generate assistant masks — it may be missing the `{% generation %}` keyword. Please "
"check the template and ensure it's correctly configured to support assistant "
"masking."
)
output = {k: processed[k] for k in ("input_ids", "assistant_masks") if k in processed}
else:
output = {"input_ids": processing_class(text=example[dataset_text_field])["input_ids"]}
if "assistant_masks" in output and 1 not in output["assistant_masks"]:
raise RuntimeError(
"You're using `assistant_only_loss=True`, but at least one example has no assistant "
"tokens. This usually means the tokenizer's chat template doesn't generate assistant "
"masks — it may be missing the `{% generation %}` keyword. Please check the template and "
"ensure it's correctly configured to support assistant masking."
)
return output
dataset = dataset.map(
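The `assistant_only_loss` error above refers to the `{% generation %}` keyword supported by transformers chat templates: tokens rendered between `{% generation %}` and `{% endgeneration %}` are the ones flagged in `assistant_masks`. As an illustration only (the surrounding template text and role markers are hypothetical, not TRL's), a template fragment that produces a non-empty mask looks like this:

```python
# Illustrative chat-template fragment; the {% generation %} block is the real mechanism,
# the role markers around it are hypothetical.
# apply_chat_template(..., return_assistant_tokens_mask=True) uses these markers to build assistant_masks.
chat_template_fragment = (
    "{% for message in messages %}"
    "{% if message['role'] == 'assistant' %}"
    "<|assistant|>{% generation %}{{ message['content'] }}{% endgeneration %}"
    "{% else %}"
    "<|user|>{{ message['content'] }}"
    "{% endif %}"
    "{% endfor %}"
)
```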