Files
verl/examples
Yingru Li 21271aabb9 [BREAKING][rollout, trainer, algo] feat: comprehensive rollout importance sampling implementation (#3694)
# Rollout Importance Sampling Framework

## Summary

This PR introduces a comprehensive **Rollout Importance Sampling (IS)**
framework to correct distribution mismatch between data-collecting
(rollout) and training policies, a critical factor for ensuring stable
and efficient model training in RL fine-tuning.

This work is motivated by the analysis in our blog post, [When Speed
Kills Stability: Demystifying RL Collapse from the Inference-Training
Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda).
If you find this implementation useful in your research, please consider
citing:

```bibtex
@misc{liu-li-2025,
  title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch},
  url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda},
  author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen},
  year = {2025},
  month = {September},
}
```

---

## Problem Statement

When using different policies for rollout generation (e.g., vLLM with
BFloat16) and training (e.g., FSDP with FP32), distribution mismatch
occurs, leading to:
- Biased gradient estimates
- Training instability and collapse
- Reduced sample efficiency
- Poor convergence properties

This framework addresses these issues through principled importance
sampling correction.

---

## Key Features & Improvements

### 1. **Flexible Aggregation Levels**
Three methods for calculating IS weights:
- **`token`**: Per-token importance ratios
- **`sequence`**: Product of per-token ratios
- **`geometric`**: Geometric mean of ratios

### 2. **Advanced Bounding Modes**
Two strategies to control weight variance:
- **`truncate`** (TIS): Caps weights at upper threshold only, preserving
gradients
- **`clip`** (CIS): Zeros out weights outside bounds, more aggressive
filtering

### 3. **Comprehensive Diagnostics**
Detailed metrics to monitor distribution mismatch and training health:

**Rollout IS Metrics** (automatically prefixed with `mismatch/`):
- Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean`
- Distribution statistics: `rollout_is_p25`, `rollout_is_p50`,
`rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`,
`rollout_is_min`, `rollout_is_std`
- Diagnostics: `rollout_is_veto_fraction`,
`rollout_is_catastrophic_token_fraction`, `rollout_is_clipped_fraction`
(clip mode)
- Sequence-level statistics (for sequence/geometric modes):
`rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`,
`rollout_is_seq_min`, etc.

**Mismatch Metrics** (computed efficiently within IS weight
computation):
- KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3
estimator for stability)
- Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`,
`mismatch_ppl_ratio`
- Log perplexity statistics: `mismatch_log_ppl_diff`,
`mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`,
`mismatch_log_ppl_diff_min`

### 4. **Outlier Mitigation**
- **Veto mechanism**: Automatically discards samples with catastrophic
importance weights (per-token ratios below threshold)
- Prevents gradient corruption from extreme outliers
- Configurable threshold (default: 1e-4)

### 5. **Numerical Stability**
- All core computations in **log-space** to prevent underflow/overflow
- Carefully designed clipping and bounding to maintain numerical
precision
- Safe handling of edge cases (zero probabilities, extreme ratios)

### 6. **Memory Efficiency**
- Optimized computation to minimize CUDA memory usage
- Efficient metric aggregation without large intermediate tensors
- Suitable for large-scale distributed training

### 7. **Metrics-Only Mode**
- Compute and monitor mismatch metrics **without** applying IS weights
- Useful for:
  - Understanding distribution mismatch before intervention
  - Deciding whether IS correction is needed
  - A/B testing IS impact
- Controlled by `algorithm.rollout_is` flag (independent of weight
computation)

### 8. **Universal PPO Support**
- Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov,
KL-Cov, geo_mean
- Consistent interface across different policy loss functions
- Automatic weight application when enabled

---

## API and Configuration Changes

### Migration from Legacy TIS

####  **Before (REMOVED)**
```yaml
# Old TIS configuration - NO LONGER SUPPORTED
actor_rollout_ref:
  actor:
    tis_imp_ratio_cap: 2.0  # Removed from actor config
```

The legacy implementation:
- Only supported token-level truncation
- No metrics tracking
- Lacked numerical stability
- Limited configurability

####  **After (New Framework)**

Configuration moved to `algorithm` section for better organization:

```yaml
algorithm:
  # Main on/off switch: null = disabled, float = enabled
  rollout_is_threshold: 2.0

  # Control weight application (independent of metrics computation)
  rollout_is: true  # true = apply weights, false = metrics only

  # Optional: lower threshold (defaults to 1/upper if null)
  rollout_is_threshold_lower: null

  # Aggregation level: "token", "sequence", or "geometric"
  rollout_is_level: token

  # Bounding mode: "truncate" or "clip"
  rollout_is_mode: truncate

  # Veto threshold for catastrophic outliers (null = disabled)
  rollout_is_veto_threshold: 1e-4

# REQUIRED: Enable log probability calculation
actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

### Configuration Examples

**1. Token-level truncation (recommended starting point)**
```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate
```

**2. Sequence-level clipping (more aggressive)**
```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: sequence
  rollout_is_mode: clip
```

**3. Metrics-only mode (monitoring without correction)**
```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: false  # Compute metrics but don't apply weights
  rollout_is_level: token
  rollout_is_mode: truncate
```

**Example script:** `bash
examples/rollout_importance_sampling/run_with_rollout_is.sh`

---

## Code Changes Overview

### New Files (4 files, 1,442 lines)

1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines)
   - Core implementation of IS weight computation
   - Three aggregation levels: token, sequence, geometric
   - Two bounding modes: truncate, clip
   - Veto mechanism for outlier detection
   - Comprehensive metrics computation (IS + mismatch)
   - All computations in log-space for numerical stability
   - Memory-efficient design

2. **`docs/advance/rollout_is_migration.md`** (642 lines)
   - Comprehensive migration guide from legacy TIS
   - Detailed explanation of all configuration options
   - Recommended threshold ranges for each aggregation level
   - Troubleshooting guide and best practices
   - Metrics interpretation guide

3. **`examples/rollout_importance_sampling/README.md`** (242 lines)
   - Quick start guide with working examples
   - Configuration templates for common scenarios
   - Threshold tuning guidelines
   - Metrics monitoring instructions

4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99
lines)
   - Complete working example script
   - Demonstrates token-level and sequence-level configurations
   - Ready to run with minimal modifications

### Modified Core Files (9 files)

1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed)
   - Removed legacy TIS logic (`tis_imp_ratio_cap`)
   - Added `rollout_is_weights` parameter to all policy loss functions
   - Unified IS weight application interface across all PPO variants:
     - `compute_policy_loss_vanilla`
     - `compute_policy_loss_gspo`
     - `compute_policy_loss_gpg`
     - `compute_policy_loss_clip_cov`
     - `compute_policy_loss_kl_cov`
     - `compute_policy_loss_geo_mean`
   - Special handling for `geo_mean` (sequence-level aggregation)

2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added)
   - New method: `compute_rollout_importance_weights_and_add_to_batch()`
   - Centralized IS computation (once per batch, on driver)
- Conditional weight distribution to workers based on
`algorithm.rollout_is`
   - Metrics collection and aggregation
   - Integration with existing training loop

3. **`verl/trainer/config/algorithm.py`** (+18 lines)
   - Added 6 new Rollout IS parameters:
     - `rollout_is_threshold` (main on/off switch)
     - `rollout_is` (weight application control)
     - `rollout_is_threshold_lower`
     - `rollout_is_level`
     - `rollout_is_mode`
     - `rollout_is_veto_threshold`
   - Comprehensive docstrings explaining each parameter

4. **`verl/workers/config/actor.py`** (-1 line)
   - Removed deprecated `tis_imp_ratio_cap` parameter

5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed)
   - Updated to use new `rollout_is_weights` parameter
   - Removed legacy TIS logic

6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed)
   - Updated to use new `rollout_is_weights` parameter
   - Removed legacy TIS logic

7. **Configuration Files** (4 files updated)
   - `verl/trainer/config/ppo_trainer.yaml`
   - `verl/trainer/config/ppo_megatron_trainer.yaml`
   - `verl/trainer/config/_generated_ppo_trainer.yaml`
   - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml`
- Added default Rollout IS configuration section with explanatory
comments

### Testing (2 files, 530 lines)

1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines)
   - Unit tests for `mismatch_helper.py`
   - Coverage for all aggregation levels (token, sequence, geometric)
   - Coverage for all bounding modes (truncate, clip)
   - Veto mechanism tests
   - Edge case handling (zeros, extremes, empty sequences)
   - Numerical stability verification
   - Metrics correctness validation

2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines)
   - Integration tests with PPO training loop
   - End-to-end workflow validation
   - Batch processing tests
   - Configuration validation
   - Metrics collection verification
   - Compatibility with distributed training

### Updated Recipes (2 files)

1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines)
   - Updated imports to use new framework

2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed)
   - Migrated from legacy TIS to new Rollout IS configuration
   - Updated documentation and comments

### Documentation Updates (2 files)

1. **`docs/examples/config.rst`** (~22 lines changed)
   - Updated configuration examples
   - Added Rollout IS section

2. **`docs/index.rst`** (+1 line)
   - Added link to Rollout IS migration guide

---

## Implementation Highlights

### Centralized Architecture

The new design follows a clean separation of concerns:

```
ray_trainer.py (driver)
    └─> compute_rollout_importance_weights_and_add_to_batch()
         └─> mismatch_helper.compute_rollout_importance_weights()
              ├─> Computes IS weights (token/sequence/geometric)
              ├─> Applies bounding (truncate/clip)
              ├─> Veto mechanism for outliers
              ├─> Computes IS metrics
              └─> Computes mismatch metrics (KL, PPL)
    └─> Conditionally adds weights to batch (if rollout_is=True)
    └─> Distributes batch to workers

actor workers (dp_actor, megatron_actor)
    └─> Receive batch with rollout_is_weights (if enabled)
    └─> Pass weights to policy loss function

core_algos.py
    └─> All policy loss functions accept rollout_is_weights
    └─> Apply weights if provided: pg_losses *= rollout_is_weights
```

### Key Design Decisions

1. **Centralized Computation**: IS weights computed once on driver, not
per worker
   - Reduces redundant computation
   - Ensures consistency across workers
   - Simplifies debugging and metrics collection

2. **Configuration in Algorithm**: Moved from actor config to algorithm
config
- Better conceptual organization (algorithm-level concern, not
worker-level)
   - Easier to manage and validate
   - Consistent with other algorithm parameters

3. **Two-Level Control**:
   - `rollout_is_threshold`: Enables/disables entire system (null = off)
- `rollout_is`: Controls weight application (true = apply, false =
metrics only)
   - Allows flexible monitoring and gradual rollout

4. **Metrics Consolidation**: Mismatch metrics computed within IS weight
computation
   - Eliminates duplicate computation
   - Reduces memory overhead
   - Maintains metric accuracy

5. **Universal PPO Support**: Single interface for all PPO variants
   - Minimal code changes required
   - Consistent behavior across algorithms
   - Easy to add new variants

---

## Migration Guide

### For Users of Legacy TIS

**Step 1: Update your configuration file**

```yaml
# OLD (remove this)
actor_rollout_ref:
  actor:
    tis_imp_ratio_cap: 2.0

# NEW (add this)
algorithm:
  rollout_is_threshold: 2.0  # Use same value as old tis_imp_ratio_cap
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate

# REQUIRED (add if not present)
actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

**Step 2: Monitor metrics**

The first time you run with the new configuration, check these metrics:
- `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size
- `mismatch/rollout_is_veto_fraction`: Should be < 5%
- `mismatch/rollout_is_mean`: Should be close to 1.0

**Step 3: Tune if needed**

If effective sample size is too low:
- Increase `rollout_is_threshold`
- Try `rollout_is_mode: clip` with appropriate lower bound
- Consider `rollout_is_level: sequence` for more aggressive correction

For detailed guidance, see `docs/advance/rollout_is_migration.md`.

### For New Users

Start with recommended defaults:

```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate

actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

Run the example script to see it in action:
```bash
bash examples/rollout_importance_sampling/run_with_rollout_is.sh
```

---

## Testing

### Unit Tests
- **289 lines** of comprehensive unit tests in `test_rollout_is.py`
- Covers all aggregation levels, bounding modes, and edge cases
- Validates numerical stability and correctness
- Fast execution (~1-2 seconds)

### Integration Tests
- **241 lines** of integration tests in `test_rollout_is_integration.py`
- End-to-end workflow with PPO training loop
- Distributed training compatibility
- Metrics collection validation
- Moderate execution time (~10-20 seconds)

### Running Tests
```bash
# Run all Rollout IS tests
pytest tests/trainer/ppo/test_rollout_is.py -v
pytest tests/trainer/ppo/test_rollout_is_integration.py -v

# Run specific test
pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v
```

---

## Metrics Reference

### Rollout IS Metrics (all prefixed with `mismatch/`)

| Metric | Description | Ideal Range |
|--------|-------------|-------------|
| `rollout_is_eff_sample_size` | Effective number of samples after IS |
> 80% of batch |
| `rollout_is_mean` | Mean IS weight | ~1.0 |
| `rollout_is_std` | Standard deviation of IS weights | Low variance |
| `rollout_is_p25` | 25th percentile | ~0.8-1.0 |
| `rollout_is_p50` | Median IS weight | ~1.0 |
| `rollout_is_p75` | 75th percentile | ~1.0-1.2 |
| `rollout_is_p95` | 95th percentile | < threshold |
| `rollout_is_p99` | 99th percentile | < threshold |
| `rollout_is_max` | Maximum weight | ≤ threshold |
| `rollout_is_min` | Minimum weight | ≥ lower threshold (clip mode) |
| `rollout_is_veto_fraction` | % sequences vetoed | < 5% |
| `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | <
1% |
| `rollout_is_clipped_fraction` | % tokens clipped (clip mode) |
Variable |

### Mismatch Metrics (all prefixed with `mismatch/`)

| Metric | Description | What It Means |
|--------|-------------|---------------|
| `mismatch_kl` | Forward KL divergence | Distribution difference
(rollout vs training) |
| `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small
divergences |
| `mismatch_training_ppl` | Training policy perplexity | Prediction
difficulty of training policy |
| `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction
difficulty of rollout policy |
| `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative
prediction difficulty |
| `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level
PPL mismatch |
| `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude
of mismatch |
| `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case
mismatch |
| `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case
mismatch |
| `mismatch_training_log_ppl` | Log of training PPL | Log-scale training
perplexity |
| `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout
perplexity |

---

## Performance Impact

### Memory
- Minimal overhead: ~1-2% increase in peak memory usage
- Efficient log-space computation
- No large intermediate tensors

### Computation
- Negligible impact on training speed: < 1% overhead
- Centralized computation on driver (no per-worker redundancy)
- Optimized tensor operations

### Training Stability
- Significant improvement in stability when distribution mismatch exists
- Faster convergence in many scenarios
- Reduced risk of training collapse

---

## Breaking Changes

> [!IMPORTANT]
> This PR contains **BREAKING CHANGES** to the configuration API.

### Removed
- `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported

### Migration Required
All users of the legacy TIS implementation must update their
configuration files. See the migration guide above or
`docs/advance/rollout_is_migration.md` for detailed instructions.

### Backward Compatibility
- No backward compatibility with legacy TIS
- Configuration files with `tis_imp_ratio_cap` will raise validation
errors
- Affected recipes have been updated in this PR

---

## Pre-Submission Checklist

- [x] Search for similar PRs:
[https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling)
- [x] Format PR title as `[{modules}] {type}: {description}` (checked by
CI)
- **Suggested title:** `[BREAKING][rollout, trainer, algo] feat:
implement comprehensive Rollout Importance Sampling framework`
- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md)
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting)
- [x] Add/update
[documentation](https://github.com/volcengine/verl/tree/main/docs) (3
new docs, 2 updated)
- [x] Add unit and integration tests (530 lines of tests)
- [x] Once PR is ready for CI, send message in `ci-request` channel

---

## References

- **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse
from the Inference-Training
Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda)
- **Migration guide:** `docs/advance/rollout_is_migration.md`
- **Examples:** `examples/rollout_importance_sampling/`
- **Tests:** `tests/trainer/ppo/test_rollout_is*.py`

---------

Co-authored-by: Yan Bai <bayan@nvidia.com>
2025-10-13 17:05:29 +08:00
..