# Rollout Importance Sampling Framework

related to https://github.com/volcengine/verl/pull/3694

## Summary

This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning.

This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing:

```bibtex
@misc{liu-li-2025,
  title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch},
  url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda},
  author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen},
  year = {2025},
  month = {September},
}
```

---

## Problem Statement

When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to:

- Biased gradient estimates
- Training instability and collapse
- Reduced sample efficiency
- Poor convergence properties

This framework addresses these issues through principled importance sampling correction.

---

## Key Features & Improvements

### 1. **Flexible Aggregation Levels**

Three methods for calculating IS weights:

- **`token`**: Per-token importance ratios
- **`sequence`**: Product of per-token ratios
- **`geometric`**: Geometric mean of ratios

### 2. **Advanced Bounding Modes**

Two strategies to control weight variance:

- **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients
- **`mask`** (MIS): Zeros out weights outside bounds, more aggressive filtering

### 3. **Comprehensive Diagnostics**

Detailed metrics to monitor distribution mismatch and training health:

**Rollout IS Metrics** (automatically prefixed with `mismatch/`):

- Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean`
- Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std`
- Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_masked_fraction` (mask mode)
- Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc.

**Mismatch Metrics** (computed efficiently within IS weight computation):

- KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability)
- Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio`
- Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min`

### 4. **Outlier Mitigation**

- **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold)
- Prevents gradient corruption from extreme outliers
- Configurable threshold (default: 1e-4)

### 5. **Numerical Stability**

- All core computations in **log-space** to prevent underflow/overflow
- Carefully designed clamping and bounding to maintain numerical precision
- Safe handling of edge cases (zero probabilities, extreme ratios)

### 6. **Memory Efficiency**

- Optimized computation to minimize CUDA memory usage
- Efficient metric aggregation without large intermediate tensors
- Suitable for large-scale distributed training

### 7. **Metrics-Only Mode**

- Compute and monitor mismatch metrics **without** applying IS weights
- Useful for:
  - Understanding distribution mismatch before intervention
  - Deciding whether IS correction is needed
  - A/B testing IS impact
- Controlled by `algorithm.rollout_is` flag (independent of weight computation)

### 8. **Universal PPO Support**

- Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean
- Consistent interface across different policy loss functions
- Automatic weight application when enabled

---

## API and Configuration Changes

### Migration from Legacy TIS

#### ❌ **Before (REMOVED)**

```yaml
# Old TIS configuration - NO LONGER SUPPORTED
actor_rollout_ref:
  actor:
    tis_imp_ratio_cap: 2.0  # Removed from actor config
```

The legacy implementation:

- Only supported token-level truncation
- No metrics tracking
- Lacked numerical stability
- Limited configurability

#### ✅ **After (New Framework)**

Configuration moved to `algorithm` section for better organization:

```yaml
algorithm:
  # Main on/off switch: null = disabled, float = enabled
  rollout_is_threshold: 2.0

  # Control weight application (independent of metrics computation)
  rollout_is: true  # true = apply weights, false = metrics only

  # Optional: lower threshold (defaults to 1/upper if null)
  rollout_is_threshold_lower: null

  # Aggregation level: "token", "sequence", or "geometric"
  rollout_is_level: token

  # Bounding mode: "truncate" or "mask"
  rollout_is_mode: truncate

  # Veto threshold for catastrophic outliers (null = disabled)
  rollout_is_veto_threshold: 1e-4

# REQUIRED: Enable log probability calculation
actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

### Configuration Examples

**1. Token-level truncation (recommended starting point)**

```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate
```

**2. Sequence-level masking (more aggressive)**

```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: sequence
  rollout_is_mode: mask
```

**3. Metrics-only mode (monitoring without correction)**

```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: false  # Compute metrics but don't apply weights
  rollout_is_level: token
  rollout_is_mode: truncate
```

**Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh`

---

## Code Changes Overview

### New Files (4 files, 1,442 lines)

1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines)
   - Core implementation of IS weight computation
   - Three aggregation levels: token, sequence, geometric
   - Two bounding modes: truncate, mask
   - Veto mechanism for outlier detection
   - Comprehensive metrics computation (IS + mismatch)
   - All computations in log-space for numerical stability
   - Memory-efficient design

2. **`docs/advance/rollout_is_migration.md`** (642 lines)
   - Comprehensive migration guide from legacy TIS
   - Detailed explanation of all configuration options
   - Recommended threshold ranges for each aggregation level
   - Troubleshooting guide and best practices
   - Metrics interpretation guide

3. **`examples/rollout_importance_sampling/README.md`** (242 lines)
   - Quick start guide with working examples
   - Configuration templates for common scenarios
   - Threshold tuning guidelines
   - Metrics monitoring instructions

4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines)
   - Complete working example script
   - Demonstrates token-level and sequence-level configurations
   - Ready to run with minimal modifications

### Modified Core Files (9 files)

1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed)
   - Removed legacy TIS logic (`tis_imp_ratio_cap`)
   - Added `rollout_is_weights` parameter to all policy loss functions
   - Unified IS weight application interface across all PPO variants:
     - `compute_policy_loss_vanilla`
     - `compute_policy_loss_gspo`
     - `compute_policy_loss_gpg`
     - `compute_policy_loss_clip_cov`
     - `compute_policy_loss_kl_cov`
     - `compute_policy_loss_geo_mean`
   - Special handling for `geo_mean` (sequence-level aggregation)

2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added)
   - New method: `compute_rollout_importance_weights_and_add_to_batch()`
   - Centralized IS computation (once per batch, on driver)
   - Conditional weight distribution to workers based on `algorithm.rollout_is`
   - Metrics collection and aggregation
   - Integration with existing training loop

3. **`verl/trainer/config/algorithm.py`** (+18 lines)
   - Added 6 new Rollout IS parameters:
     - `rollout_is_threshold` (main on/off switch)
     - `rollout_is` (weight application control)
     - `rollout_is_threshold_lower`
     - `rollout_is_level`
     - `rollout_is_mode`
     - `rollout_is_veto_threshold`
   - Comprehensive docstrings explaining each parameter

4. **`verl/workers/config/actor.py`** (-1 line)
   - Removed deprecated `tis_imp_ratio_cap` parameter

5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed)
   - Updated to use new `rollout_is_weights` parameter
   - Removed legacy TIS logic

6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed)
   - Updated to use new `rollout_is_weights` parameter
   - Removed legacy TIS logic

7. **Configuration Files** (4 files updated)
   - `verl/trainer/config/ppo_trainer.yaml`
   - `verl/trainer/config/ppo_megatron_trainer.yaml`
   - `verl/trainer/config/_generated_ppo_trainer.yaml`
   - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml`
   - Added default Rollout IS configuration section with explanatory comments

### Testing (2 files, 530 lines)

1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines)
   - Unit tests for `mismatch_helper.py`
   - Coverage for all aggregation levels (token, sequence, geometric)
   - Coverage for all bounding modes (truncate, mask)
   - Veto mechanism tests
   - Edge case handling (zeros, extremes, empty sequences)
   - Numerical stability verification
   - Metrics correctness validation

2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines)
   - Integration tests with PPO training loop
   - End-to-end workflow validation
   - Batch processing tests
   - Configuration validation
   - Metrics collection verification
   - Compatibility with distributed training

### Updated Recipes (2 files)

1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines)
   - Updated imports to use new framework

2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed)
   - Migrated from legacy TIS to new Rollout IS configuration
   - Updated documentation and comments

### Documentation Updates (2 files)

1. **`docs/examples/config.rst`** (~22 lines changed)
   - Updated configuration examples
   - Added Rollout IS section

2. **`docs/index.rst`** (+1 line)
   - Added link to Rollout IS migration guide

---

## Implementation Highlights

### Centralized Architecture

The new design follows a clean separation of concerns:

```
ray_trainer.py (driver)
  └─> compute_rollout_importance_weights_and_add_to_batch()
       └─> mismatch_helper.compute_rollout_importance_weights()
            ├─> Computes IS weights (token/sequence/geometric)
            ├─> Applies bounding (truncate/mask)
            ├─> Veto mechanism for outliers
            ├─> Computes IS metrics
            └─> Computes mismatch metrics (KL, PPL)
       └─> Conditionally adds weights to batch (if rollout_is=True)
       └─> Distributes batch to workers

actor workers (dp_actor, megatron_actor)
  └─> Receive batch with rollout_is_weights (if enabled)
  └─> Pass weights to policy loss function

core_algos.py
  └─> All policy loss functions accept rollout_is_weights
  └─> Apply weights if provided: pg_losses *= rollout_is_weights
```

### Key Design Decisions

1. **Centralized Computation**: IS weights computed once on driver, not per worker
   - Reduces redundant computation
   - Ensures consistency across workers
   - Simplifies debugging and metrics collection

2. **Configuration in Algorithm**: Moved from actor config to algorithm config
   - Better conceptual organization (algorithm-level concern, not worker-level)
   - Easier to manage and validate
   - Consistent with other algorithm parameters

3. **Two-Level Control**:
   - `rollout_is_threshold`: Enables/disables entire system (null = off)
   - `rollout_is`: Controls weight application (true = apply, false = metrics only)
   - Allows flexible monitoring and gradual rollout

4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation
   - Eliminates duplicate computation
   - Reduces memory overhead
   - Maintains metric accuracy

5. **Universal PPO Support**: Single interface for all PPO variants
   - Minimal code changes required
   - Consistent behavior across algorithms
   - Easy to add new variants

---

## Migration Guide

### For Users of Legacy TIS

**Step 1: Update your configuration file**

```yaml
# OLD (remove this)
actor_rollout_ref:
  actor:
    tis_imp_ratio_cap: 2.0

# NEW (add this)
algorithm:
  rollout_is_threshold: 2.0  # Use same value as old tis_imp_ratio_cap
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate

# REQUIRED (add if not present)
actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

**Step 2: Monitor metrics**

The first time you run with the new configuration, check these metrics:

- `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size
- `mismatch/rollout_is_veto_fraction`: Should be < 5%
- `mismatch/rollout_is_mean`: Should be close to 1.0

**Step 3: Tune if needed**

If effective sample size is too low:

- Increase `rollout_is_threshold`
- Try `rollout_is_mode: mask` with appropriate lower bound
- Consider `rollout_is_level: sequence` for more aggressive correction

For detailed guidance, see `docs/advance/rollout_is_migration.md`.
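
To make the aggregation levels, bounding modes, and the effective-sample-size check from Step 2 concrete, here is a minimal pure-Python sketch for a single sequence. It is not the code in `mismatch_helper.py`: the function names, the ±20 log-space clamp, and the normalized Kish formula used for effective sample size are all assumptions chosen for illustration, and the actual implementation operates on batched tensors.

```python
import math


def rollout_is_weights(train_logp, rollout_logp, level="token",
                       mode="truncate", upper=2.0, lower=None):
    """Hypothetical helper: per-token IS weights for ONE sequence.

    train_logp / rollout_logp hold the log-probabilities of the sampled
    tokens under the training and rollout policies, respectively.
    """
    if not train_logp:
        return []
    if lower is None:
        lower = 1.0 / upper  # documented default: lower threshold = 1/upper

    # Work in log-space: log(pi_train / pi_rollout) per token.
    log_ratios = [t - r for t, r in zip(train_logp, rollout_logp)]
    n = len(log_ratios)

    if level == "token":
        log_w = log_ratios                    # per-token ratios
    elif level == "sequence":
        log_w = [sum(log_ratios)] * n         # product of ratios
    elif level == "geometric":
        log_w = [sum(log_ratios) / n] * n     # geometric mean of ratios
    else:
        raise ValueError(f"unknown level: {level}")

    # Assumed safety clamp before exponentiating (value chosen arbitrarily).
    weights = [math.exp(max(min(lw, 20.0), -20.0)) for lw in log_w]

    if mode == "truncate":
        # Cap at the upper threshold only.
        return [min(w, upper) for w in weights]
    if mode == "mask":
        # Zero out anything outside [lower, upper].
        return [w if lower <= w <= upper else 0.0 for w in weights]
    raise ValueError(f"unknown mode: {mode}")


def effective_sample_size_fraction(weights):
    """Assumed definition: Kish ESS divided by the number of weights."""
    s = sum(weights)
    s2 = sum(w * w for w in weights)
    return (s * s) / (len(weights) * s2) if s2 > 0 else 0.0


if __name__ == "__main__":
    train = [-1.0, -2.0]
    rollout = [-1.0, -1.0]
    for level in ("token", "sequence", "geometric"):
        for mode in ("truncate", "mask"):
            w = rollout_is_weights(train, rollout, level=level, mode=mode)
            print(level, mode, w, effective_sample_size_fraction(w))
```

Under this definition a fraction near 1.0 means the weights are close to uniform, while a low fraction means a few tokens dominate, which is the situation where Step 3 suggests adjusting the threshold, mode, or level.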
### For New Users

Start with recommended defaults:

```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate

actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

Run the example script to see it in action:

```bash
bash examples/rollout_importance_sampling/run_with_rollout_is.sh
```

---

## Testing

### Unit Tests

- **289 lines** of comprehensive unit tests in `test_rollout_is.py`
- Covers all aggregation levels, bounding modes, and edge cases
- Validates numerical stability and correctness
- Fast execution (~1-2 seconds)

### Integration Tests

- **241 lines** of integration tests in `test_rollout_is_integration.py`
- End-to-end workflow with PPO training loop
- Distributed training compatibility
- Metrics collection validation
- Moderate execution time (~10-20 seconds)

### Running Tests

```bash
# Run all Rollout IS tests
pytest tests/trainer/ppo/test_rollout_is.py -v
pytest tests/trainer/ppo/test_rollout_is_integration.py -v

# Run specific test
pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v
```

---

## Metrics Reference

### Rollout IS Metrics (all prefixed with `mismatch/`)

| Metric | Description | Ideal Range |
|--------|-------------|-------------|
| `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch |
| `rollout_is_mean` | Mean IS weight | ~1.0 |
| `rollout_is_std` | Standard deviation of IS weights | Low variance |
| `rollout_is_p25` | 25th percentile | ~0.8-1.0 |
| `rollout_is_p50` | Median IS weight | ~1.0 |
| `rollout_is_p75` | 75th percentile | ~1.0-1.2 |
| `rollout_is_p95` | 95th percentile | < threshold |
| `rollout_is_p99` | 99th percentile | < threshold |
| `rollout_is_max` | Maximum weight | ≤ threshold |
| `rollout_is_min` | Minimum weight | ≥ lower threshold (mask mode) |
| `rollout_is_veto_fraction` | % sequences vetoed | < 5% |
| `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% |
| `rollout_is_masked_fraction` | % tokens masked (mask mode) | Variable |

### Mismatch Metrics (all prefixed with `mismatch/`)

| Metric | Description | What It Means |
|--------|-------------|---------------|
| `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) |
| `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences |
| `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy |
| `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy |
| `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty |
| `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch |
| `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch |
| `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch |
| `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch |
| `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity |
| `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity |

---

## Performance Impact

### Memory

- Minimal overhead: ~1-2% increase in peak memory usage
- Efficient log-space computation
- No large intermediate tensors

### Computation

- Negligible impact on training speed: < 1% overhead
- Centralized computation on driver (no per-worker redundancy)
- Optimized tensor operations

### Training Stability

- Significant improvement in stability when distribution mismatch exists
- Faster convergence in many scenarios
- Reduced risk of training collapse

---

## Breaking Changes

> [!IMPORTANT]
> This PR contains **BREAKING CHANGES** to the configuration API.

### Removed

- `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported

### Migration Required

All users of the legacy TIS implementation must update their configuration files.
See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions.

### Backward Compatibility

- No backward compatibility with legacy TIS
- Configuration files with `tis_imp_ratio_cap` will raise validation errors
- Affected recipes have been updated in this PR

---

## Pre-Submission Checklist

- [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling)
- [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI)
  - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework`
- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md)
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting)
- [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated)
- [x] Add unit and integration tests (530 lines of tests)
- [x] Once PR is ready for CI, send message in `ci-request` channel

---

## References

- **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda)
- **Migration guide:** `docs/advance/rollout_is_migration.md`
- **Examples:** `examples/rollout_importance_sampling/`
- **Tests:** `tests/trainer/ppo/test_rollout_is*.py`
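
For readers who want to sanity-check the KL and perplexity entries in the Metrics Reference, the following is a minimal sketch of how such quantities can be estimated from per-token log-probabilities of rollout-sampled tokens. It is an illustration under assumptions, not the code in `mismatch_helper.py`: the function name is hypothetical, everything is averaged over the tokens of one sequence, and the direction of the KL (rollout relative to training) is assumed from the "rollout vs training" description in the table. The K3 term uses the standard non-negative estimator `(r - 1) - log r` with `r = pi_train / pi_rollout`.

```python
import math


def mismatch_metrics(train_logp, rollout_logp):
    """Hypothetical helper: token-averaged mismatch estimates for one sequence.

    Both inputs are log-probabilities of the SAME sampled tokens, which are
    assumed to have been drawn from the rollout policy.
    """
    n = len(train_logp)
    if n == 0:
        raise ValueError("need at least one token")

    # log r = log pi_train - log pi_rollout, per token.
    log_ratio = [t - r for t, r in zip(train_logp, rollout_logp)]

    # Direct Monte Carlo KL estimate: E_rollout[log pi_rollout - log pi_train].
    kl = sum(-lr for lr in log_ratio) / n

    # K3 estimator: E_rollout[(r - 1) - log r]; each term is >= 0.
    k3 = sum(math.exp(lr) - lr - 1.0 for lr in log_ratio) / n

    # Perplexity = exp(mean negative log-likelihood).
    train_log_ppl = -sum(train_logp) / n
    rollout_log_ppl = -sum(rollout_logp) / n

    return {
        "mismatch_kl": kl,
        "mismatch_k3_kl": k3,
        "mismatch_training_ppl": math.exp(train_log_ppl),
        "mismatch_rollout_ppl": math.exp(rollout_log_ppl),
        "mismatch_ppl_ratio": math.exp(train_log_ppl - rollout_log_ppl),
        "mismatch_log_ppl_diff": train_log_ppl - rollout_log_ppl,
    }


if __name__ == "__main__":
    print(mismatch_metrics([-2.0, -2.0], [-1.0, -1.0]))
```

When the two policies assign identical log-probabilities, both KL estimates are zero and the perplexity ratio is one; the direct estimate can be negative on a finite sample, whereas the K3 estimate cannot.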
verl: Volcano Engine Reinforcement Learning for LLMs
verl is a flexible, efficient and production-ready RL training library for large language models (LLMs).
verl is the open-source version of HybridFlow: A Flexible and Efficient RLHF Framework paper.
verl is flexible and easy to use with:
- Easy extension of diverse RL algorithms: The hybrid-controller programming model enables flexible representation and efficient execution of complex post-training dataflows. Build RL dataflows such as GRPO, PPO in a few lines of code.
- Seamless integration of existing LLM infra with modular APIs: Decouples computation and data dependencies, enabling seamless integration with existing LLM frameworks, such as FSDP, Megatron-LM, vLLM, SGLang, etc.
- Flexible device mapping: Supports various placement of models onto different sets of GPUs for efficient resource utilization and scalability across different cluster sizes.
- Ready integration with popular HuggingFace models
verl is fast with:
- State-of-the-art throughput: SOTA LLM training and inference engine integrations and SOTA RL throughput.
- Efficient actor model resharding with 3D-HybridEngine: Eliminates memory redundancy and significantly reduces communication overhead during transitions between training and generation phases.
News
- [2025/08] verl is presented in the PyTorch Expert Exchange Webinar. Slides available.
- [2025/07] The ReTool recipe is fully open sourced. Blog
- [2025/07] The first verl meetup will be held at ICML Vancouver on July 16th! Please join us if you are at ICML! (onsite only)
- [2025/06] verl with Megatron backend enables large MoE models such as DeepSeek-671B and Qwen3-235B.
- [2025/03] DAPO is the open-sourced SOTA RL algorithm that achieves 50 points on AIME 2024 based on the Qwen2.5-32B pre-trained model, surpassing the previous SOTA achieved by DeepSeek's GRPO (DeepSeek-R1-Zero-Qwen-32B). DAPO's training is fully powered by verl and the reproduction code is available in recipe/dapo now.
- [2025/04] [Seed-Thinking-v1.5](https://github.com/ByteDance-Seed/Seed-Thinking-v1.5/blob/main/seed-thinking-v1.5.pdf) tech report is released! Trained with verl, Seed-Thinking-v1.5 achieves 86.7 on AIME 2024, 55.0 on Codeforces and 77.3 on GPQA, demonstrating excellent reasoning abilities in STEM and coding. Beyond reasoning tasks, the method demonstrates notable generalization across diverse domains.
- [2025/07] verl keynote at [AWS AI Hours Singapore](https://pages.awscloud.com/aws-ai-hours-sg.html#agenda) on 7/8, verl & verl-agent project updates at [Agent for SWE meetup](https://lu.ma/e498qhsi) by LF AI & Data Singapore on 7/11.
- [2025/06] verl team will provide latest project updates at [PyTorch Day China](https://www.lfasiallc.com/pytorch-day-china/) on June 7th. Meet our dev team in Beijing!
- [2025/04] [VAPO](https://arxiv.org/pdf/2504.05118) (value-based augmented PPO) paper covers our latest RL method for reasoning models. Trained from Qwen-32B-base model, VAPO achieves 60.4 on AIME 2024, outperforming DAPO-32B.
- [2025/05] [PF-PPO](https://arxiv.org/abs/2409.06957), accepted to ICML 2025, is now supported in verl! PF-PPO enhances policy learning efficiency and robustness by filtering potentially noisy reward signals and reusing high-quality experiences via a replay buffer.
- [2025/04] We will give a tutorial about latest post-training techniques and programming guide for verl at [ICLR 2025 Expo](https://iclr.cc/virtual/2025/calendar?filter_events=Expo+Talk+Panel&filter_rooms=), [SCI-FM workshop](https://open-foundation-model.github.io/) and [LMSys afterparty](https://lu.ma/d23nyynm). Talk materials available [here](https://github.com/eric-haibin-lin/verl-community/tree/main/iclr25).
- [2025/03] verl v0.3.0.post1 is released! See [release note](https://github.com/volcengine/verl/releases/) for details. It achieves [~1.4x speedup](https://tongyx361.github.io/blogs/posts/verl-intro/#/verl-flexible-and-efficient-rl-for-llms) compared to prev versions.
- [2025/05] verl will be presented at [A2M Shanghai](https://a2m.msup.com.cn/home/?aid=4488&city=shanghai) on 5/16 - 5/17.
- [2025/05] verl will be presented at [GOSIM x PyTorch Day 2025](https://paris2025.gosim.org/). See you in Paris!
- [2025/03] We introduced the programming model of verl at the [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg) and [verl intro and updates](https://github.com/eric-haibin-lin/verl-community/blob/main/slides/verl-lmsys-meetup.pdf) at the [SGLang-LMSYS Org Meetup](https://lu.ma/ntjrr7ig) in Sunnyvale mid-March.
- [2025/03] We will present verl(HybridFlow) at EuroSys 2025. See you in Rotterdam!
- [2025/02] verl v0.2.0.post2 is released!
- [2025/02] We presented verl in the Bytedance/NVIDIA/Anyscale Ray Meetup. See you in San Jose!
- [2025/01] [Doubao-1.5-pro](https://team.doubao.com/zh/special/doubao_1_5_pro) is released with SOTA-level performance on LLM & VLM. The RL scaling preview model is trained using verl, reaching OpenAI O1-level performance on math benchmarks (70.0 pass@1 on AIME).
- [2024/12] verl is presented at Ray Forward 2024. Slides available here
- [2024/12] The team presented Post-training LLMs: From Algorithms to Infrastructure at NeurIPS 2024. Slides and video available.
- [2024/10] verl is presented at Ray Summit. Youtube video available.
- [2024/08] HybridFlow (verl) is accepted to EuroSys 2025.
Key Features
- FSDP, FSDP2 and Megatron-LM for training.
- vLLM, SGLang and HF Transformers for rollout generation.
- Compatible with Hugging Face Transformers and Modelscope Hub: Qwen-3, Qwen-2.5, Llama3.1, Gemma2, DeepSeek-LLM, etc
- Supervised fine-tuning.
- Reinforcement learning with PPO, GRPO, GSPO, ReMax, REINFORCE++, RLOO, PRIME, DAPO, DrGRPO, KL_Cov & Clip_Cov etc.
- Support model-based reward and function-based reward (verifiable reward) for math, coding, etc
- Support vision-language models (VLMs) and multi-modal RL with Qwen2.5-vl, Kimi-VL
- Multi-turn with tool calling
- LLM alignment recipes such as Self-play preference optimization (SPPO)
- Flash attention 2, sequence packing, sequence parallelism support via DeepSpeed Ulysses, LoRA, Liger-kernel.
- Scales up to 671B models and hundreds of GPUs with expert parallelism
- Multi-gpu LoRA RL support to save memory.
- Experiment tracking with wandb, swanlab, mlflow and tensorboard.
Upcoming Features and Changes
- Q3 Roadmap https://github.com/volcengine/verl/issues/2388
- DeepSeek 671b optimizations with Megatron https://github.com/volcengine/verl/issues/1033
- Multi-turn rollout and tools using optimizations https://github.com/volcengine/verl/issues/1882
- Agent integration
- Async and off-policy architecture https://github.com/volcengine/verl/pull/2231
- List of breaking changes since v0.4 https://github.com/volcengine/verl/discussions/2270
Getting Started
Quickstart:
- Installation
- Quickstart
- Programming Guide & Tech Talk (in Chinese)
- PPO in verl
- GRPO in verl
Running a PPO example step-by-step:
- Prepare Data for Post-Training
- Implement Reward Function for Dataset
- PPO Example Architecture
- Config Explanation
Reproducible algorithm baselines:
For code explanation and advance usage (extension):
- PPO Trainer and Workers
- Advanced Usage and Extension
Blogs from the community
- When Reasoning Models Break Tokenization: The Hidden Complexity of Multiturn Training
- verl deployment on AWS SageMaker
- verl x SGLang Multi-turn Code Walkthrough
- Optimizing SGLang Memory Usage in verl
- SGLang, verl, OpenBMB and Tsinghua University: Pioneering End-to-End Multi-Turn RLHF
- Reinforcement Learning from Human Feedback on AMD GPUs with verl and ROCm Integration
- veMLP x verl :玩转强化学习训练
- 使用 verl 进行 GRPO 分布式强化学习训练最佳实践
- HybridFlow verl 原文浅析
- 最高提升 20 倍吞吐量!豆包大模型团队发布全新 RLHF 框架,现已开源!
Performance Tuning Guide
Performance is essential for on-policy RL algorithms. We have written a detailed performance tuning guide to help you optimize performance.
Upgrade to vLLM >= v0.8.2
verl now supports vLLM>=0.8.2 when using FSDP as the training backend. Please refer to this document for the installation guide and more information. Please avoid vllm 0.7.x, which contains bugs that may lead to OOMs and unexpected errors.
Use Latest SGLang
SGLang is fully supported with verl, and SGLang RL Group is working extensively on building unique features, including multi-turn agentic RL, VLM RLHF, server-based RL, and partial rollout. Please refer to this document for the installation guide and more information.
Upgrade to FSDP2
verl is fully embracing FSDP2! FSDP2 is recommended by the torch distributed team, provides better throughput and memory usage, and is composable with other features (e.g. torch.compile). To enable FSDP2, simply use verl main and set the following options:
actor_rollout_ref.ref.strategy=fsdp2
actor_rollout_ref.actor.strategy=fsdp2
critic.strategy=fsdp2
reward_model.strategy=fsdp2
Furthermore, FSDP2 CPU offloading is compatible with gradient accumulation. You can turn it on to save memory with actor_rollout_ref.actor.fsdp_config.offload_policy=True. For more details, see https://github.com/volcengine/verl/pull/1026
AMD Support (ROCm Kernel)
verl now supports FSDP as the training engine (Megatron support coming soon) and integrates with both vLLM and SGLang as inference engines. Please refer to this document for the installation guide and more information, and this document for the vLLM performance tuning for ROCm.
Citation and acknowledgement
If you find the project helpful, please cite:
- HybridFlow: A Flexible and Efficient RLHF Framework
- A Framework for Training Large Language Models for Code Generation via Proximal Policy Optimization
@article{sheng2024hybridflow,
title = {HybridFlow: A Flexible and Efficient RLHF Framework},
author = {Guangming Sheng and Chi Zhang and Zilingfeng Ye and Xibin Wu and Wang Zhang and Ru Zhang and Yanghua Peng and Haibin Lin and Chuan Wu},
year = {2024},
journal = {arXiv preprint arXiv: 2409.19256}
}
verl is inspired by the design of Nemo-Aligner, Deepspeed-chat and OpenRLHF. The project is adopted and contributed by Bytedance, Anyscale, LMSys.org, Alibaba Qwen team, Shanghai AI Lab, Tsinghua University, UC Berkeley, UCLA, UIUC, University of Hong Kong, ke.com, All Hands AI, ModelBest, JD AI Lab, Microsoft Research, StepFun, Amazon, LinkedIn, Meituan, Camel-AI, OpenManus, Xiaomi, NVIDIA research, Baichuan, RedNote, SwissAI, Moonshot AI (Kimi), Baidu, Snowflake, Skywork.ai, JetBrains, IceSword Lab, and many more.
Awesome work using verl
- TinyZero: a reproduction of DeepSeek R1 Zero recipe for reasoning tasks
- SkyThought: RL training for Sky-T1-7B by NovaSky AI team.
- simpleRL-reason: SimpleRL-Zoo: Investigating and Taming Zero Reinforcement Learning for Open Base Models in the Wild
- Easy-R1: Multi-modal RL training framework
- OpenManus-RL: LLM Agents RL tuning framework for multiple agent environments.
- rllm: async RL training with verl-pipeline
- RAGEN: a general-purpose reasoning agent training framework
- Search-R1: RL with reasoning and searching (tool-call) interleaved LLMs
- ReSearch: Learning to Reason with Search for LLMs via Reinforcement Learning
- Skywork-OR1: Skywork open reasoner series
- ToRL: Scaling tool-integrated RL
- Absolute Zero Reasoner: A no human curated data self-play framework for reasoning
- verl-agent: A scalable training framework for long-horizon LLM/VLM agents, along with a new algorithm GiGPO
- RL-Factory: An easy and efficient RL post-training framework for Agentic Learning
- ReTool: ReTool: reinforcement learning for strategic tool use in LLMs. Code release is in progress...
- verl-tool: A unified and easy-to-extend tool-agent training framework based on verl
- PRIME: Process reinforcement through implicit rewards
- MemAgent: MemAgent: Reshaping Long-Context LLM with Multi-Conv RL based Memory Agent
- POLARIS: A Post-training recipe for scaling RL on Advanced Reasoning models
- GUI-R1: GUI-R1: A Generalist R1-style Vision-Language Action Model For GUI Agents
- DeepRetrieval: RL Training of Search Agent with Search/Retrieval Outcome
- Code-R1: Reproducing R1 for Code with Reliable Rewards
- DeepResearcher: Scaling deep research via reinforcement learning in real-world environments
- VAGEN: Training VLM agents with multi-turn reinforcement learning
- RM-R1: RL training of reasoning reward models
- LUFFY: Learning to Reason under Off-Policy Guidance
- DeepMath: DeepMath-103K data and series models for math reasoning
- PACS: Implicit Actor Critic Coupling via a Supervised Learning Framework for RLVR
- Entropy Mechanism of RL: The Entropy Mechanism of Reinforcement Learning for Large Language Model Reasoning
- LLaSA-TTS-GRPO: TTS fine-tuning with GRPO optimization based on LLASA models
- PF-PPO: Policy Filtration for PPO based on the reliability of reward signals for more efficient and robust RLHF.
- RACRO: Build multi-modal reasoning models via decoupling it into query-conditioned captioning and text-only reasoning
- Agent Lightning: A flexible and extensible framework that enables seamless agent optimization for any existing agent framework.
- VTool-R1: VLMs Learn to Think with Images via Reinforcement Learning on Multimodal Tool Use.
- Kimina-Prover-RL: Training pipeline for formal theorem proving, based on a paradigm inspired by DeepSeek-R1.
- RL-PLUS: Countering Capability Boundary Collapse of LLMs in Reinforcement Learning with Hybrid-policy Optimization.
- rStar2-Agent: Using reinforcement learning with multi-step tool-calling for math tasks, rStar2-Agent-14B reaches frontier-level math reasoning in just 510 RL training steps
- Vision-SR1: Self-Rewarding Vision-Language Model via Reasoning Decomposition
- SimpleVLA-RL: SimpleVLA-RL: A Simple yet Effective Vision-Language Action Model for Reinforcement Learning
- Table-R1: Table-R1: Inference-Time Scaling for Table Reasoning
and many more awesome work listed in recipe.
Contribution Guide
About ByteDance Seed Team
Founded in 2023, ByteDance Seed Team is dedicated to crafting the industry's most advanced AI foundation models. The team aspires to become a world-class research team and make significant contributions to the advancement of science and society. You can get to know Bytedance Seed better through the following channels👇
We are HIRING! Send us an email if you are interested in internship/FTE opportunities in RL for agents.