verl/examples/ppo_trainer
Feng Yao b8dc5377c6 [BREAKING][vllm, fsdp] feat: add Rollout-Training Mismatch Fix -- Truncated importance sampling (#2953)
### What does this PR do?

Support [vLLM-FSDP off-policy importance sampling
correction](https://fengyao.notion.site/off-policy-rl) using Truncated
Importance Sampling (TIS):

<img width="859" height="382" alt="TIS"
src="https://github.com/user-attachments/assets/adc8f797-aa14-4b29-b265-a682c281d08e"
/>
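
At a high level, the fix reweights the per-token policy-gradient loss by a truncated importance ratio between the training policy (log-probs recomputed under FSDP) and the behavior policy (log-probs returned by vLLM). The sketch below is illustrative only; apart from the `behav_imp_weight_cap` config added in this PR, all names are hypothetical and do not mirror the actual code.

```python
import torch

def apply_tis(pg_loss, train_logprobs, rollout_logprobs, response_mask, cap=10.0):
    """Illustrative TIS sketch: reweight the policy-gradient loss by a truncated
    importance ratio between the training policy and the rollout (behavior) policy."""
    # w = pi_train(a|s) / pi_rollout(a|s), computed from per-token log-probs and
    # detached so it only rescales the gradient.
    behav_imp_weight = torch.exp(train_logprobs - rollout_logprobs).detach()
    # Truncated Importance Sampling: cap the weight at C
    # (exposed in this PR as `actor.behav_imp_weight_cap`).
    behav_imp_weight = behav_imp_weight.clamp(max=cap)
    # Apply only on valid response tokens.
    return pg_loss * behav_imp_weight * response_mask
```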




### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
    - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```bash
python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=gae \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.train_batch_size=1024 \
    data.max_prompt_length=1024 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
    critic.optim.lr=1e-5 \
    critic.model.use_remove_padding=True \
    critic.model.path=Qwen/Qwen2.5-32B-Instruct \
    critic.model.enable_gradient_checkpointing=False \
    critic.ppo_micro_batch_size_per_gpu=8 \
    critic.model.fsdp_config.param_offload=False \
    critic.model.fsdp_config.optimizer_offload=False \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger='["console","wandb"]' \
    trainer.project_name='verl_example' \
    trainer.experiment_name='Qwen2.5-32B-Instruct_function_rm' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=4 \
    trainer.save_freq=20 \
    trainer.test_freq=10 \
    trainer.total_epochs=15 \
    actor_rollout_ref.rollout.calculate_log_probs=True \
    +actor_rollout_ref.actor.behav_imp_weight_cap=10.0
    # calculate_log_probs=True makes the rollout return the behavior-policy log-probs;
    # behav_imp_weight_cap sets the truncation threshold C used by TIS
```
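Note: the leading `+` in the last override is Hydra syntax for adding a key that is not present in the default config; `calculate_log_probs=True` tells the rollout to return the behavior-policy log-probs that the TIS correction consumes.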

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: Narsil-Dinghuai Zhang 张鼎怀 <dinghuai233@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: LiyuanLucasLiu <llychinalz@gmail.com>
2025-08-26 14:06:07 -07:00

Proximal Policy Optimization (PPO)

Proximal Policy Optimization (PPO) is a family of policy gradient methods for reinforcement learning, proposed by OpenAI in 2017. PPO strikes a balance between simplicity, stability, and performance, making it one of the most widely used algorithms in modern RL applications, including large-scale language model fine-tuning.

Traditional policy gradient methods like REINFORCE or Vanilla Policy Gradient suffer from:

  • High variance and sample inefficiency.
  • Instability due to large policy updates.

PPO addresses these problems using a clipped surrogate objective that avoids overly large updates without requiring second-order derivatives.
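
For reference, the clipped surrogate objective from the PPO paper is

```latex
L^{\mathrm{CLIP}}(\theta)
  = \hat{\mathbb{E}}_t\!\left[\min\!\left(r_t(\theta)\,\hat{A}_t,\;
    \operatorname{clip}\!\left(r_t(\theta),\,1-\epsilon,\,1+\epsilon\right)\hat{A}_t\right)\right],
\qquad
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)},
```

where ε is the clip range (actor_rollout_ref.actor.clip_ratio below) and Â_t is the advantage estimate.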

For more technical details regarding PPO, we suggest reading the introduction in the OpenAI Spinning Up tutorial and the paper Proximal Policy Optimization Algorithms.

Key Components

  • Actor-Critic Architecture: PPO requires both an actor model (policy) and a critic model (value function). This differs from other algorithms like GRPO and RLOO that don't require a critic model.

  • Generalized Advantage Estimation (GAE): PPO uses GAE for computing advantage values, which helps reduce variance in policy gradient estimates while maintaining low bias (a minimal sketch follows this list).

  • Clipped Surrogate Objective: The core of PPO is implemented through the clipped surrogate objective function that limits policy updates.
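
For the GAE component above, here is a minimal, single-trajectory sketch (illustrative only, not verl's implementation; gamma and lam correspond to algorithm.gamma and algorithm.lam in the configuration below):

```python
import torch

def compute_gae(rewards, values, gamma=1.0, lam=0.95):
    """Illustrative GAE over one trajectory of length T.
    values[t] is V(s_t); the value after the final step is treated as 0."""
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    last_gae = 0.0
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else 0.0
        delta = rewards[t] + gamma * next_value - values[t]   # TD residual
        last_gae = delta + gamma * lam * last_gae             # discounted sum of residuals
        advantages[t] = last_gae
    returns = advantages + values                             # critic regression targets
    return advantages, returns
```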

Configuration

Note that all configs containing micro_batch_size set the maximum sample or token count per forward/backward pass to avoid GPU out-of-memory errors; their values do not affect algorithmic or convergence behavior.

Most critic configs are similar to those of actors. Note that the critic model is omitted from the figure below.

[Figure: batch size configuration overview (critic model omitted)]

  • data.train_batch_size: The global batch size of prompts used to generate a set of sampled trajectories/rollouts. The number of responses/trajectories is data.train_batch_size * actor_rollout_ref.rollout.n (a worked example follows this list)

  • actor_rollout_ref.actor.ppo_mini_batch_size: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO actor updates. The ppo_mini_batch_size is a global size across all workers

  • critic.ppo_mini_batch_size: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO critic updates. The ppo_mini_batch_size is a global size across all workers

  • actor_rollout_ref.actor.clip_ratio: The PPO clip range ε. Defaults to 0.2

  • actor_rollout_ref.actor.ppo_epochs: Number of epochs for PPO updates on one set of sampled trajectories for actor

  • critic.ppo_epochs: Number of epochs for PPO updates on one set of sampled trajectories for critic. Defaults to actor_rollout_ref.actor.ppo_epochs

  • algorithm.gamma: discount factor

  • algorithm.lam: The lambda term that trades off between bias and variance in the GAE estimator

  • algorithm.adv_estimator: Supports gae, grpo, reinforce_plus_plus, reinforce_plus_plus_baseline, rloo
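
To make the batch hierarchy concrete, here is a small worked example with hypothetical numbers (the GPU count and batch sizes are chosen only for illustration):

```python
train_batch_size = 1024       # data.train_batch_size: prompts per rollout step
rollout_n = 1                 # actor_rollout_ref.rollout.n: responses sampled per prompt
ppo_mini_batch_size = 256     # global mini-batch per actor/critic gradient update
micro_batch_size_per_gpu = 8  # samples per GPU per forward/backward pass
n_gpus = 8                    # hypothetical world size

trajectories = train_batch_size * rollout_n                   # 1024 responses per step
updates_per_ppo_epoch = trajectories // ppo_mini_batch_size   # 4 gradient updates
grad_accum_steps = ppo_mini_batch_size // (micro_batch_size_per_gpu * n_gpus)  # 4 accumulation steps
```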

Advanced Extensions

KL Divergence Control

verl provides options to prevent the policy from diverging too far from a reference policy. Two mechanisms are available: the in-reward KL penalty and the KL loss. For more technical details, see Training language models to follow instructions with human feedback.

Options to use KL loss for KL divergence control:

  • actor_rollout_ref.actor.use_kl_loss: Whether to use a KL loss in the actor. When enabled, KL is not applied in the reward function. Default is False.

  • actor_rollout_ref.actor.kl_loss_coef: The coefficient of kl loss. Default is 0.001.

  • actor_rollout_ref.actor.kl_loss_type: How to calculate the KL divergence between the actor and the reference policy. Supports kl (k1), abs, mse (k2), low_var_kl (k3), and full. Appending "+" (e.g., 'k1+', 'k3+') applies a straight-through trick that uses k2 for the gradient (an unbiased gradient estimate) regardless of which estimator is used for the KL value (see https://github.com/volcengine/verl/pull/2953#issuecomment-3162113848 for more details). See this blog post for a detailed analysis: http://joschu.net/blog/kl-approx.html

Options to use KL penalty in the reward:

  • algorithm.use_kl_in_reward: Whether to enable in-reward kl penalty. Default is False.

  • algorithm.kl_penalty: Supports kl (k1), abs, mse (k2), low_var_kl (k3), and full. This defines how the KL divergence between the actor and the reference policy is calculated; for the specific options, refer to kl_penalty in core_algos.py (a short sketch of the k1/k2/k3 estimators follows this list). See this blog post for a detailed analysis: http://joschu.net/blog/kl-approx.html

  • algorithm.kl_ctrl.kl_coef: The (initial) coefficient of in-reward kl_penalty. Default is 0.001.

  • algorithm.kl_ctrl.type: 'fixed' for FixedKLController and 'adaptive' for AdaptiveKLController.

  • algorithm.kl_ctrl.horizon: See source code of AdaptiveKLController for details.

  • algorithm.kl_ctrl.target_kl: See source code of AdaptiveKLController for details.
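
For reference, a short sketch of the k1/k2/k3 estimators discussed in the blog post above, written in terms of per-token log-probs (illustrative; refer to kl_penalty in core_algos.py for verl's actual implementation and for the full option):

```python
import torch

def kl_estimate(logprob, ref_logprob, kind="low_var_kl"):
    """Per-token estimators of KL(pi_actor || pi_ref) from samples drawn from the actor."""
    log_ratio = logprob - ref_logprob              # log pi_actor - log pi_ref
    if kind == "kl":                               # k1: unbiased but high variance
        return log_ratio
    if kind == "abs":
        return log_ratio.abs()
    if kind == "mse":                              # k2: 0.5 * (log ratio)^2
        return 0.5 * log_ratio.square()
    if kind == "low_var_kl":                       # k3: non-negative, low variance
        return torch.expm1(-log_ratio) + log_ratio
    raise ValueError(f"unknown estimator: {kind}")
```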

Dual-clip PPO

Dual-Clip PPO extends the clipped objective by applying an additional lower bound when the advantage is negative, so that the term ratio * advantage, even for a very large ratio, cannot fall below a specified bound (the objective is reproduced below).

[Figure: dual-clip PPO objective]

  • actor_rollout_ref.actor.clip_ratio_c: The lower-bound constant c used by dual-clip PPO. Defaults to 3.0
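
For reference, the dual-clip objective keeps the standard clipped term when the advantage is non-negative and additionally bounds it from below by c * Â_t (with c = clip_ratio_c > 1) when the advantage is negative:

```latex
L^{\mathrm{dual}}_t(\theta) =
\begin{cases}
\max\!\bigl(\min\!\bigl(r_t(\theta)\hat{A}_t,\;
  \operatorname{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\bigr),\; c\,\hat{A}_t\bigr), & \hat{A}_t < 0,\\[4pt]
\min\!\bigl(r_t(\theta)\hat{A}_t,\;
  \operatorname{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\bigr), & \hat{A}_t \ge 0.
\end{cases}
```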

Reference Example

Qwen2.5 training log and commands: link

bash run_gemma.sh \
  trainer.n_gpus_per_node=1 \
  actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
  trainer.logger=console \
  critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
  actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
  data.train_batch_size=256 \
  actor_rollout_ref.actor.ppo_mini_batch_size=64 \
  actor_rollout_ref.actor.ppo_micro_batch_size=2 \
  critic.ppo_micro_batch_size=2

Reference performance with verl v0.2:

| Model | Method | Score | Link |
|---|---|---|---|
| Qwen/Qwen2.5-0.5B-Instruct | pretrained model | 36.4 | Qwen Blog |
| Qwen/Qwen2.5-0.5B-Instruct | PPO | 56.7 | PPO Command and Logs |