
RLOO Trainer

Overview

TRL supports the RLOO Trainer for training language models, as described in the paper Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs by Arash Ahmadian, Chris Cremer, Matthias Gallé, Marzieh Fadaee, Julia Kreutzer, Ahmet Üstün and Sara Hooker.

The abstract from the paper is the following:

AI alignment in the shape of Reinforcement Learning from Human Feedback (RLHF) is increasingly treated as a crucial ingredient for high performance large language models. Proximal Policy Optimization (PPO) has been positioned by recent literature as the canonical method for the RL part of RLHF. However, it involves both high computational cost and sensitive hyperparameter tuning. We posit that most of the motivational principles that led to the development of PPO are less of a practical concern in RLHF and advocate for a less computationally expensive method that preserves and even increases performance. We revisit the formulation of alignment from human preferences in the context of RL. Keeping simplicity as a guiding principle, we show that many components of PPO are unnecessary in an RLHF context and that far simpler REINFORCE-style optimization variants outperform both PPO and newly proposed “RL-free” methods such as DPO and RAFT. Our work suggests that careful adaptation to LLMs alignment characteristics enables benefiting from online RL optimization at low cost.

This post-training method was contributed by Costa Huang and later refactored by Shirin Yamani.

Quick start

This example demonstrates how to train a model using the RLOO method. We train a Qwen 0.5B Instruct model with the prompts from the UltraFeedback prompts dataset (trl-lib/ultrafeedback-prompt).

Below is the script to train the model.

# train_rloo.py
from datasets import load_dataset
from trl import RLOOConfig, RLOOTrainer

dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

# Dummy reward function for demonstration purposes
def reward_num_unique_letters(completions, **kwargs):
    """Reward function that rewards completions with more unique letters."""
    completion_contents = [completion[0]["content"] for completion in completions]
    return [float(len(set(content))) for content in completion_contents]

training_args = RLOOConfig(output_dir="Qwen2-0.5B-RLOO")
trainer = RLOOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_num_unique_letters,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

Execute the script using the following command:

accelerate launch train_rloo.py

Looking deeper into the RLOO method

RLOO is an online learning algorithm, meaning it improves iteratively by using data generated by the model itself during training. The intuition behind the RLOO objective is to maximize the advantage of the generated completions while ensuring that the model remains close to the reference policy. To understand how RLOO works, it can be broken down into four main steps: generating completions, computing the reward, computing the advantage, and computing the loss.


Generating completions

At each training step, we sample a batch of prompts and generate a set of \( G \) completions for each prompt (denoted as \( o_i \)).
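
In [RLOOConfig], the group size \( G \) corresponds to the num_generations parameter. A minimal sketch with illustrative values (not recommendations):

from trl import RLOOConfig

# num_generations sets the group size G (completions sampled per prompt);
# max_completion_length caps the length of each completion. Values are illustrative.
training_args = RLOOConfig(
    output_dir="Qwen2-0.5B-RLOO",
    num_generations=8,
    max_completion_length=256,
)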

Computing the reward

In RLOO, the reward consists of two components: the reward provided by the reward model (or reward function) and a KL penalty that discourages the policy from deviating too far from a fixed reference policy.

  1. For each of the \( G \) generated sequences \( o_i = (o_{i,1}, \dots, o_{i,T}) \) conditioned on a query \( q \), we compute a scalar reward using a reward model \( R(o_i, q) \).
  2. Concurrently, we estimate the KL divergence between the current policy \( \pi_\theta \) and the fixed reference policy \( \pi_{\text{ref}} \) over the sequence. The KL estimate for sequence \( o_i \) is:

\mathbb{D}_{\mathrm{KL}}\!\left[\pi_\theta\|\pi_{\mathrm{ref}}\right] = \sum_{t=1}^T \log \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\mathrm{ref}}(o_{i,t} \mid q, o_{i,<t})}.

The final reward assigned to sequence \( o_i \) is then:


r_i = R(o_i, q) - \beta \, \mathbb{D}_{\mathrm{KL}}\!\left[\pi_\theta \|\pi_{\mathrm{ref}}\right],

where \( \beta > 0 \) controls the strength of the KL penalty.
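
As a concrete illustration, the final reward for one sampled sequence can be computed as in the following NumPy sketch (the per-token log-probabilities and the value of beta are assumed inputs; beta corresponds to the beta parameter in [RLOOConfig], and this is not the trainer's internal code):

import numpy as np

def kl_penalized_reward(reward, logprobs_policy, logprobs_ref, beta=0.05):
    """r_i = R(o_i, q) - beta * sum_t [log pi_theta(o_t) - log pi_ref(o_t)]."""
    kl_estimate = np.sum(np.asarray(logprobs_policy) - np.asarray(logprobs_ref))
    return reward - beta * kl_estimate

# Example with a 3-token completion (numbers are made up)
print(kl_penalized_reward(1.0, [-0.1, -0.5, -0.2], [-0.2, -0.4, -0.3], beta=0.05))  # ≈ 0.995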

Tip

In a purely online setting (num_iterations = 1, default), the data are generated by the current policy. In this case, the KL penalty is computed directly using the current policy.

In the more general setting (e.g., multiple gradient steps per batch), the data are instead generated by an earlier snapshot \( \pi_{\text{old}} \). To keep the penalty consistent with the sampling distribution, the KL is defined with respect to this policy:


\mathbb{D}_{\mathrm{KL}}\!\left[\pi_{\text{old}} \,\|\, \pi_{\text{ref}}\right].

Equivalently, for a sampled sequence \( o_i \), the Monte Carlo estimate is


\mathbb{D}_{\mathrm{KL}}\!\left[\pi_{\text{old}} \|\pi_{\mathrm{ref}}\right] = \sum_{t=1}^T \log \frac{\pi_{\text{old}}(o_{i,t} \mid q, o_{i,<t})}{\pi_{\mathrm{ref}}(o_{i,t} \mid q, o_{i,<t})}.

Computing the advantage

Once the rewards for each completion have been computed, we calculate a baseline for each sample as the average reward of all the other completions generated for the same prompt, excluding the sample itself. This leave-one-out baseline reduces the variance of the policy gradient estimate. The advantage for each completion is then the difference between its own reward and this baseline.

Formally, for a group of \( G \) completions, the baseline for completion \( o_i \) is:


b_i = \frac{1}{G-1} \sum_{j \neq i} r_j

and then the advantage for each completion is computed as the difference between its reward and the baseline:


A_i = r_i - b_i
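
The leave-one-out baseline is cheap to compute from the group's rewards. A minimal NumPy sketch of this computation (for illustration only, not the trainer's internal code):

import numpy as np

def rloo_advantages(rewards):
    """Leave-one-out advantages: A_i = r_i - mean(r_j for j != i)."""
    rewards = np.asarray(rewards, dtype=np.float64)
    G = len(rewards)
    baselines = (rewards.sum() - rewards) / (G - 1)  # leave-one-out mean for each sample
    return rewards - baselines

# Example: rewards of G = 4 completions of the same prompt
print(rloo_advantages([1.0, 0.0, 0.5, 1.5]))  # [ 0.333..., -1.0, -0.333..., 1.0 ]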

Computing the loss

The REINFORCE loss is simply defined as:


\mathcal{L}_{\text{RLOO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \hat{A}_i \, \log \pi_\theta(o_i \mid q)

In practice, performing multiple gradient steps on the same batch makes the actions effectively off-policy relative to the current parameters. To correct for this, we introduce the importance sampling ratio. To prevent excessively large updates when the policy changes between sampling and gradient steps, we clip this ratio:


\mathcal{L}_{\text{RLOO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \min \left( \frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_\text{old}}(o_i \mid q)} \hat{A}_i, \, \text{clip}\left(\frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_\text{old}}(o_i \mid q)}, 1-\epsilon, 1+\epsilon\right) \hat{A}_i \right)

In a fully online, single-step setting (default), \( \frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_\text{old}}(o_i \mid q)} = 1 \) and this reduces to standard REINFORCE.
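
A sequence-level PyTorch sketch of this clipped objective (per-sequence log-probabilities and advantages are assumed inputs; epsilon corresponds to the epsilon parameter in [RLOOConfig], and the value shown is illustrative):

import torch

def rloo_loss(logprobs, old_logprobs, advantages, epsilon=0.2):
    """Clipped REINFORCE-style loss over a group of sampled sequences."""
    ratio = torch.exp(logprobs - old_logprobs)  # pi_theta(o_i|q) / pi_old(o_i|q)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
    return -torch.min(unclipped, clipped).mean()

When logprobs and old_logprobs coincide (a single optimization step per generation), the ratio is exactly 1 and the loss reduces to the REINFORCE form above.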

Logged metrics

While training and evaluating, we record the following reward metrics:

  • num_tokens: The total number of tokens processed so far, including both prompts and completions.
  • completions/mean_length: The average length of generated completions.
  • completions/min_length: The minimum length of generated completions.
  • completions/max_length: The maximum length of generated completions.
  • completions/mean_terminated_length: The average length of generated completions that terminate with EOS.
  • completions/min_terminated_length: The minimum length of generated completions that terminate with EOS.
  • completions/max_terminated_length: The maximum length of generated completions that terminate with EOS.
  • completions/clipped_ratio: The ratio of truncated (clipped) completions.
  • reward/{reward_func_name}/mean: The average reward from a specific reward function.
  • reward/{reward_func_name}/std: The standard deviation of the reward from a specific reward function.
  • reward: The overall average reward after applying reward weights.
  • reward_std: The standard deviation of rewards after applying reward weights. This is the average of the per-group standard deviations.
  • frac_reward_zero_std: The fraction of samples in the generation batch with a reward std of zero, implying there is little diversity for that prompt (all answers are correct or incorrect).
  • entropy: Average entropy of token predictions across generated completions. (If mask_truncated_completions=True, tokens of masked-out truncated sequences are excluded.)
  • kl: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if beta is nonzero.
  • clip_ratio/region_mean: The ratio of sequence probabilities where the RLOO objective is clipped to stay within the trust region:

\text{clip}\left( r_{i}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \qquad r_{i}(\theta) = \frac{\pi_\theta(o_{i} \mid q)}{\pi_{\theta_{\text{old}}}(o_{i} \mid q)}\,.
A higher value means more samples are clipped, which constrains how much the policy \( \pi_\theta \) can change.
  • clip_ratio/low_mean: The average ratio of sequence probabilities that were clipped on the lower bound of the trust region: \(r_{i}(\theta) < 1 - \epsilon_\mathrm{low}\)
  • clip_ratio/low_min: The minimum ratio of sequence probabilities that were clipped on the lower bound of the trust region: \(r_{i}(\theta) < 1 - \epsilon_\mathrm{low}\)
  • clip_ratio/high_mean: The average ratio of sequence probabilities that were clipped on the upper bound of the trust region: \(r_{i}(\theta) > 1 + \epsilon_\mathrm{high}\)
  • clip_ratio/high_max: The maximum ratio of sequence probabilities that were clipped on the upper bound of the trust region: \(r_{i}(\theta) > 1 + \epsilon_\mathrm{high}\).

Customization

Speed up training with vLLM-powered generation

Generation is often the main bottleneck when training with online methods. To accelerate generation, you can use vLLM, a high-throughput, low-latency inference engine for LLMs. To enable it, first install the package with

pip install trl[vllm]

We support two ways of using vLLM during training: server mode and colocate mode.

🔌 Option 1: Server mode

In this mode, vLLM runs in a separate process (and on separate GPUs) and communicates with the trainer via HTTP. This is ideal if you have dedicated GPUs for inference.

  1. Start the vLLM server:

    trl vllm-serve --model <model_name>
    
  2. Enable server mode in your training script:

    from trl import RLOOConfig
    
    training_args = RLOOConfig(
        ...,
        use_vllm=True,
        vllm_mode="server",  # default value, can be omitted
    )
    

Warning

Make sure that the server is using different GPUs than the trainer, otherwise you may run into NCCL errors. You can specify the GPUs to use with the CUDA_VISIBLE_DEVICES environment variable.
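
For example, on a single 8-GPU node you could dedicate one GPU to the vLLM server and the remaining seven to training (the split below is illustrative):

# GPU 7 for vLLM generation
CUDA_VISIBLE_DEVICES=7 trl vllm-serve --model <model_name>

# GPUs 0-6 for training, launched from a separate shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 accelerate launch --num_processes 7 train_rloo.py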

🧩 Option 2: Colocate mode

In this mode, vLLM runs inside the trainer process and shares GPU memory with the training model. This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs.

from trl import RLOOConfig

training_args = RLOOConfig(
    ...,
    use_vllm=True,
    vllm_mode="colocate",
)

Tip

Depending on the model size and the overall GPU memory requirements for training, you may need to adjust the vllm_gpu_memory_utilization parameter in [RLOOConfig] to avoid underutilization or out-of-memory errors.

We provide an HF Space to help estimate the recommended GPU memory utilization based on your model configuration and experiment settings, and to obtain a vllm_gpu_memory_utilization recommendation.

If the recommended value does not work in your environment, we suggest adding a small buffer (e.g., +0.05 or +0.1) to the recommended value to ensure stability.

If you still run into out-of-memory errors, set vllm_enable_sleep_mode to True so that the vLLM parameters and cache are offloaded during the optimization step. For more information, see Reducing Memory Usage with vLLM Sleep Mode.
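
For example, a colocate-mode configuration combining these options might look like the following sketch (the memory utilization value is illustrative; use the Space's recommendation for your setup):

from trl import RLOOConfig

training_args = RLOOConfig(
    ...,
    use_vllm=True,
    vllm_mode="colocate",
    vllm_gpu_memory_utilization=0.55,  # illustrative; follow the Space's recommendation
    vllm_enable_sleep_mode=True,       # offload vLLM weights and cache during optimization
)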

Tip

By default, RLOO uses MASTER_ADDR=localhost and MASTER_PORT=12345 for vLLM, but you can override these values by setting the environment variables accordingly.
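
For example (the address and port below are placeholders):

MASTER_ADDR=10.0.0.5 MASTER_PORT=29500 accelerate launch train_rloo.py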

For more information, see Speeding up training with vLLM.

RLOO at scale: train a 70B+ Model on multiple nodes

When training large models like Qwen2.5-72B, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include:

  • DeepSpeed ZeRO Stage 3: ZeRO leverages data parallelism to distribute model states (weights, gradients, optimizer states) across multiple GPUs and CPUs, reducing memory and compute requirements on each device. Since large models cannot fit on a single GPU, using ZeRO Stage 3 is required for training such models. For more details, see DeepSpeed Integration.
  • Accelerate: Accelerate is a library that simplifies distributed training across multiple GPUs and nodes. It provides a simple API to launch distributed training and handles the complexities of distributed training, such as data parallelism, gradient accumulation, and distributed data loading. For more details, see Distributing Training.
  • vLLM: See the previous section on how to use vLLM to speed up generation.

Below is an example SLURM script to train a 70B model with RLOO on multiple nodes. This script trains a model on 4 nodes and uses the 5th node for vLLM-powered generation.

#!/bin/bash
#SBATCH --nodes=5
#SBATCH --gres=gpu:8

# Get the list of allocated nodes
NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))

# Assign the first 4 nodes for training and the 5th node for vLLM
TRAIN_NODES="${NODELIST[@]:0:4}"  # Nodes 0, 1, 2, 3 for training
VLLM_NODE="${NODELIST[4]}"  # Node 4 for vLLM

# Run training on the first 4 nodes (Group 1)
srun --nodes=4 --ntasks=4 --nodelist="${NODELIST[@]:0:4}" accelerate launch \
     --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
     --num_processes 32 \
     --num_machines 4 \
     --main_process_ip ${NODELIST[0]} \
     --machine_rank $SLURM_PROCID \
     --rdzv_backend c10d \
     train_rloo.py \
     --vllm_server_host $VLLM_NODE &

# Run vLLM server on the 5th node (Group 2)
srun --nodes=1 --ntasks=1 --nodelist="${NODELIST[4]}" trl vllm-serve --model Qwen/Qwen2.5-72B --tensor_parallel_size 8 &

wait

The corresponding training script, train_rloo.py:

import argparse

from datasets import load_dataset
from trl import RLOOTrainer, RLOOConfig

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vllm_server_host", type=str, default="", help="The server IP")
    args = parser.parse_args()

    # Example dataset from TLDR
    dataset = load_dataset("trl-lib/tldr", split="train")

    # Dummy reward function: count the number of unique characters in the completions
    def reward_num_unique_chars(completions, **kwargs):
        return [float(len(set(c))) for c in completions]

    training_args = RLOOConfig(
        output_dir="Qwen2.5-72B-RLOO",
        per_device_train_batch_size=4,
        bf16=True,
        gradient_checkpointing=True,
        use_vllm=True,
        vllm_server_host=args.vllm_server_host.replace("ip-", "").replace("-", "."),  # from ip-X-X-X-X to X.X.X.X
    )

    trainer = RLOOTrainer(model="Qwen/Qwen2.5-72B", args=training_args, reward_funcs=reward_num_unique_chars, train_dataset=dataset)
    trainer.train()

if __name__=="__main__":
    main()

Using a custom reward function

The [RLOOTrainer] supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements:

  1. Input arguments:

    • The function must accept the following as keyword arguments:

      • prompts (contains the prompts),
      • completions (contains the generated completions),
      • completions_ids (contains the tokenized completions),
      • trainer_state ([~transformers.TrainerState]): The current state of the trainer. This can be used to implement dynamic reward functions, such as curriculum learning, where the reward is adjusted based on the training progress.
      • All column names (except prompt) that the dataset may have. For example, if the dataset contains a column named ground_truth, the function will be called with ground_truth as a keyword argument.

      The easiest way to comply with this requirement is to use **kwargs in the function signature.

    • Depending on the dataset format, the inputs will vary: for standard (prompt-only) datasets, prompts and completions are plain strings, while for conversational datasets they are lists of message dictionaries (see the examples below).

  2. Return value: The function must return a list of floats. Each float represents the reward corresponding to a single completion.

Example 1: Reward longer completions

Below is an example of a reward function for a standard format that rewards longer completions:

def reward_func(completions_ids, **kwargs):
    """Reward function that assigns higher scores to longer completions (in terms of token count)."""
    return [float(len(ids)) for ids in completions_ids]

You can test it as follows:

>>> prompts = ["The sky is", "The sun is"]  # not used in the reward function, but the trainer will pass it
>>> completions = [" blue.", " in the sky."]  # not used in the reward function, but the trainer will pass it
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]]
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
[2.0, 4.0]

Example 1.1: Reward longer completions (based on the number of characters)

Same as the previous example, but this time the reward function is based on the number of characters instead of tokens.

def reward_func(completions, **kwargs):
    """Reward function that assigns higher scores to longer completions (in terms of character count)."""
    return [float(len(completion)) for completion in completions]

You can test it as follows:

>>> prompts = ["The sky is", "The sun is"]
>>> completions = [" blue.", " in the sky."]
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]]  # not used in the reward function, but the trainer will pass it
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
[6.0, 12.0]

Example 2: Reward completions with a specific format

Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the format reward function used in the paper DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning. It is designed for a conversational format, where prompts and completions consist of structured messages.

import re

def format_reward_func(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]

You can test this function as follows:

>>> prompts = [
...     [{"role": "assistant", "content": "What is the result of (1 + 2) * 4?"}],
...     [{"role": "assistant", "content": "What is the result of (3 + 1) * 2?"}],
... ]
>>> completions = [
...     [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
...     [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
... ]
>>> format_reward_func(prompts=prompts, completions=completions)
[1.0, 0.0]

Example 3: Reward completions based on a reference

Below is an example of a reward function that checks if the completion is correct. This example is inspired by the accuracy reward function used in the paper DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning. This example is designed for standard format, where the dataset contains a column named ground_truth.

import re

def reward_func(completions, ground_truth, **kwargs):
    # Regular expression to capture content inside \boxed{}
    matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions]
    contents = [match.group(1) if match else "" for match in matches]
    # Reward 1 if the content is the same as the ground truth, 0 otherwise
    return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]

You can test this function as follows:

>>> prompts = ["Problem: Solve the equation $2x + 3 = 7$. Solution:", "Problem: Solve the equation $3x - 5 = 10$."]
>>> completions = [r" The solution is \boxed{2}.", r" The solution is \boxed{6}."]
>>> ground_truth = ["2", "5"]
>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth)
[1.0, 0.0]

Example 4: Multi-task reward functions

Below is an example of using multiple reward functions in the [RLOOTrainer]. In this example, we define two task-specific reward functions: math_reward_func and coding_reward_func. The math_reward_func rewards math problems based on their correctness, while the coding_reward_func rewards coding problems based on whether the solution works.

from datasets import Dataset
from trl import RLOOTrainer

# Define a dataset that contains both math and coding problems
dataset = Dataset.from_list(
    [
        {"prompt": "What is 2+2?", "task": "math"},
        {"prompt": "Write a function that returns the sum of two numbers.", "task": "code"},
        {"prompt": "What is 3*4?", "task": "math"},
        {"prompt": "Write a function that returns the product of two numbers.", "task": "code"},
    ]
)

# Math-specific reward function
def math_reward_func(prompts, completions, task, **kwargs):
    rewards = []
    for prompt, completion, t in zip(prompts, completions, task):
        if t == "math":
            # Calculate math-specific reward
            correct = check_math_solution(prompt, completion)
            reward = 1.0 if correct else -1.0
            rewards.append(reward)
        else:
            # Return None for non-math tasks
            rewards.append(None)
    return rewards

# Coding-specific reward function
def coding_reward_func(prompts, completions, task, **kwargs):
    rewards = []
    for prompt, completion, t in zip(prompts, completions, task):
        if t == "coding":
            # Calculate coding-specific reward
            works = test_code_solution(prompt, completion)
            reward = 1.0 if works else -1.0
            rewards.append(reward)
        else:
            # Return None for non-coding tasks
            rewards.append(None)
    return rewards

# Use both task-specific reward functions
trainer = RLOOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=[math_reward_func, coding_reward_func],
    train_dataset=dataset,
)

trainer.train()

In this example, the math_reward_func and coding_reward_func are designed to work with a mixed dataset that contains both math and coding problems. The task column in the dataset is used to determine which reward function to apply to each problem. If there is no relevant reward function for a sample in the dataset, the reward function will return None, and the [RLOOTrainer] will continue with the valid functions and tasks. This allows the [RLOOTrainer] to handle multiple reward functions with different applicability.

Note that the [RLOOTrainer] will ignore the None rewards returned by the reward functions and only consider the rewards returned by the relevant functions. This ensures that the model is trained on the relevant tasks and ignores the tasks for which there is no relevant reward function.

Passing the reward function to the trainer

To use your custom reward function, pass it to the [RLOOTrainer] as follows:

from trl import RLOOTrainer

trainer = RLOOTrainer(
    reward_funcs=reward_func,
    ...,
)

If you have multiple reward functions, you can pass them as a list:

from trl import RLOOTrainer

trainer = RLOOTrainer(
    reward_funcs=[reward_func1, reward_func2],
    ...,
)

and the reward will be computed as the sum of the rewards from each function, or the weighted sum if reward_weights is provided in the config.
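
For example, with two reward functions and illustrative weights:

from trl import RLOOConfig, RLOOTrainer

training_args = RLOOConfig(
    ...,
    reward_weights=[0.7, 0.3],  # applied to reward_func1 and reward_func2, in order
)
trainer = RLOOTrainer(
    reward_funcs=[reward_func1, reward_func2],
    args=training_args,
    ...,
)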

Note that [RLOOTrainer] supports multiple reward functions of different types. See the parameters documentation for more details.

Vision-Language Model (VLM) Training

RLOO supports training Vision-Language Models (VLMs) on multimodal datasets containing both text and images.

Supported Models

Tested with:

  • Gemma3 — e.g., google/gemma-3-4b-it
  • LLaVA-NeXT — e.g., llava-hf/llava-v1.6-mistral-7b-hf
  • Qwen2-VL — e.g., Qwen/Qwen2-VL-2B-Instruct
  • Qwen2.5-VL — e.g., Qwen/Qwen2.5-VL-3B-Instruct
  • SmolVLM2 — e.g., HuggingFaceTB/SmolVLM2-2.2B-Instruct

Tip

Compatibility with all VLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes.

Quick Start

Use rloo_vlm.py to fine-tune a VLM. Example command for training on lmms-lab/multimodal-open-r1-8k-verified:

accelerate launch \
  --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
  examples/scripts/rloo_vlm.py \
  --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
  --output_dir rloo-Qwen2.5-VL-3B-Instruct \
  --learning_rate 1e-5 \
  --gradient_checkpointing \
  --dtype bfloat16 \
  --max_prompt_length 2048 \
  --max_completion_length 1024 \
  --use_vllm \
  --vllm_mode colocate \
  --use_peft \
  --lora_target_modules q_proj v_proj \
  --log_completions

Configuration Tips

Tip

For VLMs, truncating may remove image tokens, leading to errors during training. To avoid this, set max_length=None in the [RLOOConfig]. This allows the model to process the full sequence length without truncating image tokens.

RLOOConfig(max_length=None, ...)

Only use max_length when you've verified that truncation won't remove image tokens for the entire dataset.

  • Use LoRA on vision-language projection layers
  • Enable 4-bit quantization to reduce memory usage
  • VLMs are memory-intensive — start with smaller batch sizes
  • Most models are compatible with vLLM (server and colocate modes)

Dataset Format

Each training sample should include:

  • prompt: Text formatted via the processor's chat template
  • image/images: PIL Image or list of PIL Images

The trainer automatically handles image-to-tensor conversion via the model's image processor.
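
A single example might be built as in the following sketch (the model checkpoint, message structure, and placeholder image are assumptions for illustration):

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")

# Format the text part with the processor's chat template and attach a PIL image.
messages = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe this image."}]}
]
example = {
    "prompt": processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False),
    "image": Image.new("RGB", (224, 224), color="white"),  # stand-in for a real image
}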

RLOOTrainer

autodoc RLOOTrainer - train - save_model - push_to_hub

RLOOConfig

autodoc RLOOConfig

References

  1. RLOO Paper
  2. Paper - Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs
  3. Paper - REINFORCE++: A Simple and Efficient Approach for Aligning Large Language Models
  4. Blog Post - Putting RL back in RLHF
  5. Blog Post - Unraveling RLHF and Its Variants: Progress and Practical Engineering Insights
  6. YouTube - RLOO: A Cost-Efficient Optimization for Learning from Human Feedback in LLMs

Migration Guide from the old implementation (0.21 and below)

With the release of version 0.22.0, we have revamped the [RLOOTrainer] to be more aligned with other online trainers in the library, like [GRPOTrainer]. This new implementation introduces several changes to the configuration parameters and overall structure of the trainer. Below is a summary of the key changes for [RLOOConfig]:

| TRL ≤ 0.21.x | TRL ≥ 0.22.0 |
| --- | --- |
| rloo_k | renamed to num_generations |
| cliprange | renamed to epsilon |
| kl_coef | renamed to beta |
| exp_name | renamed to run_name. Use run_name = f"{exp_name}__{seed}__{int(time.time())}" to replicate the old behavior |
| normalize_reward | renamed to normalize_advantages. Note: it always normalized advantages (despite the old name) |
| num_ppo_epochs | renamed to num_iterations (default: 1) |
| token_level_kl | removed; KL is now computed only at the sequence level |
| dataset_num_proc | removed; it was unused |
| num_mini_batches | renamed to steps_per_generation |
| total_episodes | use max_steps = total_episodes / gradient_accumulation_steps instead |
| local_rollout_forward_batch_size | removed; now automatically set to per_device_train_batch_size (or per_device_eval_batch_size during evaluation) |
| num_sample_generations | removed; use logging_steps to control generation logging frequency |
| response_length | renamed to max_completion_length (default: 256) |
| stop_token | removed |
| stop_token_id | removed; use processing_class.eos_token_id instead |
| missing_eos_penalty | removed; replicate with a custom reward function checking if eos_token_id is in completion_ids |
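
For instance, the old missing_eos_penalty behavior can be approximated with a custom reward function along the lines of the following sketch, using the completions_ids keyword argument described in the custom reward function section (the EOS token id and penalty value are illustrative assumptions):

def missing_eos_reward(completions_ids, **kwargs):
    """Sketch: penalize completions that never emit the EOS token."""
    EOS_TOKEN_ID = 151645  # placeholder; use your tokenizer's eos_token_id
    return [0.0 if EOS_TOKEN_ID in ids else -1.0 for ids in completions_ids]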

Below is a summary of the key changes for [RLOOTrainer]:

| TRL ≤ 0.21.x | TRL ≥ 0.22.0 |
| --- | --- |
| config | renamed to args |
| reward_model | renamed to reward_funcs, which now supports both reward models and custom reward functions |
| policy | renamed to model |
| ref_policy | removed; the reference model is now created automatically from model |
| data_collator | removed |