[rollout] refactor: rename "clip" mode back to "mask" mode (#3750)

# Rollout Importance Sampling Framework

related to https://github.com/volcengine/verl/pull/3694

## Summary

This PR introduces a comprehensive **Rollout Importance Sampling (IS)**
framework to correct distribution mismatch between data-collecting
(rollout) and training policies, a critical factor for ensuring stable
and efficient model training in RL fine-tuning.

This work is motivated by the analysis in our blog post, [When Speed
Kills Stability: Demystifying RL Collapse from the Inference-Training
Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda).
If you find this implementation useful in your research, please consider
citing:

```bibtex
@misc{liu-li-2025,
  title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch},
  url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda},
  author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen},
  year = {2025},
  month = {September},
}
```

---

## Problem Statement

When using different policies for rollout generation (e.g., vLLM with
BFloat16) and training (e.g., FSDP with FP32), distribution mismatch
occurs, leading to:
- Biased gradient estimates
- Training instability and collapse
- Reduced sample efficiency
- Poor convergence properties

This framework addresses these issues through principled importance
sampling correction.

---

## Key Features & Improvements

### 1. **Flexible Aggregation Levels**
Three methods for calculating IS weights:
- **`token`**: Per-token importance ratios
- **`sequence`**: Product of per-token ratios
- **`geometric`**: Geometric mean of ratios

### 2. **Advanced Bounding Modes**
Two strategies to control weight variance:
- **`truncate`** (TIS): Caps weights at upper threshold only, preserving
gradients
- **`mask`** (MIS): Zeros out weights outside bounds, more aggressive
filtering

### 3. **Comprehensive Diagnostics**
Detailed metrics to monitor distribution mismatch and training health:

**Rollout IS Metrics** (automatically prefixed with `mismatch/`):
- Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean`
- Distribution statistics: `rollout_is_p25`, `rollout_is_p50`,
`rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`,
`rollout_is_min`, `rollout_is_std`
- Diagnostics: `rollout_is_veto_fraction`,
`rollout_is_catastrophic_token_fraction`, `rollout_is_masked_fraction`
(mask mode)
- Sequence-level statistics (for sequence/geometric modes):
`rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`,
`rollout_is_seq_min`, etc.

**Mismatch Metrics** (computed efficiently within IS weight
computation):
- KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3
estimator for stability)
- Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`,
`mismatch_ppl_ratio`
- Log perplexity statistics: `mismatch_log_ppl_diff`,
`mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`,
`mismatch_log_ppl_diff_min`

### 4. **Outlier Mitigation**
- **Veto mechanism**: Automatically discards samples with catastrophic
importance weights (per-token ratios below threshold)
- Prevents gradient corruption from extreme outliers
- Configurable threshold (default: 1e-4)

### 5. **Numerical Stability**
- All core computations in **log-space** to prevent underflow/overflow
- Carefully designed clamping and bounding to maintain numerical
precision
- Safe handling of edge cases (zero probabilities, extreme ratios)

### 6. **Memory Efficiency**
- Optimized computation to minimize CUDA memory usage
- Efficient metric aggregation without large intermediate tensors
- Suitable for large-scale distributed training

### 7. **Metrics-Only Mode**
- Compute and monitor mismatch metrics **without** applying IS weights
- Useful for:
  - Understanding distribution mismatch before intervention
  - Deciding whether IS correction is needed
  - A/B testing IS impact
- Controlled by `algorithm.rollout_is` flag (independent of weight
computation)

### 8. **Universal PPO Support**
- Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov,
KL-Cov, geo_mean
- Consistent interface across different policy loss functions
- Automatic weight application when enabled

---

## API and Configuration Changes

### Migration from Legacy TIS

####  **Before (REMOVED)**
```yaml
# Old TIS configuration - NO LONGER SUPPORTED
actor_rollout_ref:
  actor:
    tis_imp_ratio_cap: 2.0  # Removed from actor config
```

The legacy implementation:
- Only supported token-level truncation
- No metrics tracking
- Lacked numerical stability
- Limited configurability

####  **After (New Framework)**

Configuration moved to `algorithm` section for better organization:

```yaml
algorithm:
  # Main on/off switch: null = disabled, float = enabled
  rollout_is_threshold: 2.0

  # Control weight application (independent of metrics computation)
  rollout_is: true  # true = apply weights, false = metrics only

  # Optional: lower threshold (defaults to 1/upper if null)
  rollout_is_threshold_lower: null

  # Aggregation level: "token", "sequence", or "geometric"
  rollout_is_level: token

  # Bounding mode: "truncate" or "mask"
  rollout_is_mode: truncate

  # Veto threshold for catastrophic outliers (null = disabled)
  rollout_is_veto_threshold: 1e-4

# REQUIRED: Enable log probability calculation
actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

### Configuration Examples

**1. Token-level truncation (recommended starting point)**
```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate
```

**2. Sequence-level masking (more aggressive)**
```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: sequence
  rollout_is_mode: mask
```

**3. Metrics-only mode (monitoring without correction)**
```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: false  # Compute metrics but don't apply weights
  rollout_is_level: token
  rollout_is_mode: truncate
```

**Example script:** `bash
examples/rollout_importance_sampling/run_with_rollout_is.sh`

---

## Code Changes Overview

### New Files (4 files, 1,442 lines)

1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines)
   - Core implementation of IS weight computation
   - Three aggregation levels: token, sequence, geometric
   - Two bounding modes: truncate, mask
   - Veto mechanism for outlier detection
   - Comprehensive metrics computation (IS + mismatch)
   - All computations in log-space for numerical stability
   - Memory-efficient design

2. **`docs/advance/rollout_is_migration.md`** (642 lines)
   - Comprehensive migration guide from legacy TIS
   - Detailed explanation of all configuration options
   - Recommended threshold ranges for each aggregation level
   - Troubleshooting guide and best practices
   - Metrics interpretation guide

3. **`examples/rollout_importance_sampling/README.md`** (242 lines)
   - Quick start guide with working examples
   - Configuration templates for common scenarios
   - Threshold tuning guidelines
   - Metrics monitoring instructions

4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99
lines)
   - Complete working example script
   - Demonstrates token-level and sequence-level configurations
   - Ready to run with minimal modifications

### Modified Core Files (9 files)

1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed)
   - Removed legacy TIS logic (`tis_imp_ratio_cap`)
   - Added `rollout_is_weights` parameter to all policy loss functions
   - Unified IS weight application interface across all PPO variants:
     - `compute_policy_loss_vanilla`
     - `compute_policy_loss_gspo`
     - `compute_policy_loss_gpg`
     - `compute_policy_loss_clip_cov`
     - `compute_policy_loss_kl_cov`
     - `compute_policy_loss_geo_mean`
   - Special handling for `geo_mean` (sequence-level aggregation)

2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added)
   - New method: `compute_rollout_importance_weights_and_add_to_batch()`
   - Centralized IS computation (once per batch, on driver)
- Conditional weight distribution to workers based on
`algorithm.rollout_is`
   - Metrics collection and aggregation
   - Integration with existing training loop

3. **`verl/trainer/config/algorithm.py`** (+18 lines)
   - Added 6 new Rollout IS parameters:
     - `rollout_is_threshold` (main on/off switch)
     - `rollout_is` (weight application control)
     - `rollout_is_threshold_lower`
     - `rollout_is_level`
     - `rollout_is_mode`
     - `rollout_is_veto_threshold`
   - Comprehensive docstrings explaining each parameter

4. **`verl/workers/config/actor.py`** (-1 line)
   - Removed deprecated `tis_imp_ratio_cap` parameter

5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed)
   - Updated to use new `rollout_is_weights` parameter
   - Removed legacy TIS logic

6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed)
   - Updated to use new `rollout_is_weights` parameter
   - Removed legacy TIS logic

7. **Configuration Files** (4 files updated)
   - `verl/trainer/config/ppo_trainer.yaml`
   - `verl/trainer/config/ppo_megatron_trainer.yaml`
   - `verl/trainer/config/_generated_ppo_trainer.yaml`
   - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml`
- Added default Rollout IS configuration section with explanatory
comments

### Testing (2 files, 530 lines)

1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines)
   - Unit tests for `mismatch_helper.py`
   - Coverage for all aggregation levels (token, sequence, geometric)
   - Coverage for all bounding modes (truncate, mask)
   - Veto mechanism tests
   - Edge case handling (zeros, extremes, empty sequences)
   - Numerical stability verification
   - Metrics correctness validation

2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines)
   - Integration tests with PPO training loop
   - End-to-end workflow validation
   - Batch processing tests
   - Configuration validation
   - Metrics collection verification
   - Compatibility with distributed training

### Updated Recipes (2 files)

1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines)
   - Updated imports to use new framework

2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed)
   - Migrated from legacy TIS to new Rollout IS configuration
   - Updated documentation and comments

### Documentation Updates (2 files)

1. **`docs/examples/config.rst`** (~22 lines changed)
   - Updated configuration examples
   - Added Rollout IS section

2. **`docs/index.rst`** (+1 line)
   - Added link to Rollout IS migration guide

---

## Implementation Highlights

### Centralized Architecture

The new design follows a clean separation of concerns:

```
ray_trainer.py (driver)
    └─> compute_rollout_importance_weights_and_add_to_batch()
         └─> mismatch_helper.compute_rollout_importance_weights()
              ├─> Computes IS weights (token/sequence/geometric)
              ├─> Applies bounding (truncate/mask)
              ├─> Veto mechanism for outliers
              ├─> Computes IS metrics
              └─> Computes mismatch metrics (KL, PPL)
    └─> Conditionally adds weights to batch (if rollout_is=True)
    └─> Distributes batch to workers

actor workers (dp_actor, megatron_actor)
    └─> Receive batch with rollout_is_weights (if enabled)
    └─> Pass weights to policy loss function

core_algos.py
    └─> All policy loss functions accept rollout_is_weights
    └─> Apply weights if provided: pg_losses *= rollout_is_weights
```

### Key Design Decisions

1. **Centralized Computation**: IS weights computed once on driver, not
per worker
   - Reduces redundant computation
   - Ensures consistency across workers
   - Simplifies debugging and metrics collection

2. **Configuration in Algorithm**: Moved from actor config to algorithm
config
- Better conceptual organization (algorithm-level concern, not
worker-level)
   - Easier to manage and validate
   - Consistent with other algorithm parameters

3. **Two-Level Control**:
   - `rollout_is_threshold`: Enables/disables entire system (null = off)
- `rollout_is`: Controls weight application (true = apply, false =
metrics only)
   - Allows flexible monitoring and gradual rollout

4. **Metrics Consolidation**: Mismatch metrics computed within IS weight
computation
   - Eliminates duplicate computation
   - Reduces memory overhead
   - Maintains metric accuracy

5. **Universal PPO Support**: Single interface for all PPO variants
   - Minimal code changes required
   - Consistent behavior across algorithms
   - Easy to add new variants

---

## Migration Guide

### For Users of Legacy TIS

**Step 1: Update your configuration file**

```yaml
# OLD (remove this)
actor_rollout_ref:
  actor:
    tis_imp_ratio_cap: 2.0

# NEW (add this)
algorithm:
  rollout_is_threshold: 2.0  # Use same value as old tis_imp_ratio_cap
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate

# REQUIRED (add if not present)
actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

**Step 2: Monitor metrics**

The first time you run with the new configuration, check these metrics:
- `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size
- `mismatch/rollout_is_veto_fraction`: Should be < 5%
- `mismatch/rollout_is_mean`: Should be close to 1.0

**Step 3: Tune if needed**

If effective sample size is too low:
- Increase `rollout_is_threshold`
- Try `rollout_is_mode: mask` with appropriate lower bound
- Consider `rollout_is_level: sequence` for more aggressive correction

For detailed guidance, see `docs/advance/rollout_is_migration.md`.

### For New Users

Start with recommended defaults:

```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate

actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

Run the example script to see it in action:
```bash
bash examples/rollout_importance_sampling/run_with_rollout_is.sh
```

---

## Testing

### Unit Tests
- **289 lines** of comprehensive unit tests in `test_rollout_is.py`
- Covers all aggregation levels, bounding modes, and edge cases
- Validates numerical stability and correctness
- Fast execution (~1-2 seconds)

### Integration Tests
- **241 lines** of integration tests in `test_rollout_is_integration.py`
- End-to-end workflow with PPO training loop
- Distributed training compatibility
- Metrics collection validation
- Moderate execution time (~10-20 seconds)

### Running Tests
```bash
# Run all Rollout IS tests
pytest tests/trainer/ppo/test_rollout_is.py -v
pytest tests/trainer/ppo/test_rollout_is_integration.py -v

# Run specific test
pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v
```

---

## Metrics Reference

### Rollout IS Metrics (all prefixed with `mismatch/`)

| Metric | Description | Ideal Range |
|--------|-------------|-------------|
| `rollout_is_eff_sample_size` | Effective number of samples after IS |
> 80% of batch |
| `rollout_is_mean` | Mean IS weight | ~1.0 |
| `rollout_is_std` | Standard deviation of IS weights | Low variance |
| `rollout_is_p25` | 25th percentile | ~0.8-1.0 |
| `rollout_is_p50` | Median IS weight | ~1.0 |
| `rollout_is_p75` | 75th percentile | ~1.0-1.2 |
| `rollout_is_p95` | 95th percentile | < threshold |
| `rollout_is_p99` | 99th percentile | < threshold |
| `rollout_is_max` | Maximum weight | ≤ threshold |
| `rollout_is_min` | Minimum weight | ≥ lower threshold (mask mode) |
| `rollout_is_veto_fraction` | % sequences vetoed | < 5% |
| `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | <
1% |
| `rollout_is_masked_fraction` | % tokens masked (mask mode) | Variable
|

### Mismatch Metrics (all prefixed with `mismatch/`)

| Metric | Description | What It Means |
|--------|-------------|---------------|
| `mismatch_kl` | Forward KL divergence | Distribution difference
(rollout vs training) |
| `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small
divergences |
| `mismatch_training_ppl` | Training policy perplexity | Prediction
difficulty of training policy |
| `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction
difficulty of rollout policy |
| `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative
prediction difficulty |
| `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level
PPL mismatch |
| `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude
of mismatch |
| `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case
mismatch |
| `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case
mismatch |
| `mismatch_training_log_ppl` | Log of training PPL | Log-scale training
perplexity |
| `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout
perplexity |

---

## Performance Impact

### Memory
- Minimal overhead: ~1-2% increase in peak memory usage
- Efficient log-space computation
- No large intermediate tensors

### Computation
- Negligible impact on training speed: < 1% overhead
- Centralized computation on driver (no per-worker redundancy)
- Optimized tensor operations

### Training Stability
- Significant improvement in stability when distribution mismatch exists
- Faster convergence in many scenarios
- Reduced risk of training collapse

---

## Breaking Changes

> [!IMPORTANT]
> This PR contains **BREAKING CHANGES** to the configuration API.

### Removed
- `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported

### Migration Required
All users of the legacy TIS implementation must update their
configuration files. See the migration guide above or
`docs/advance/rollout_is_migration.md` for detailed instructions.

### Backward Compatibility
- No backward compatibility with legacy TIS
- Configuration files with `tis_imp_ratio_cap` will raise validation
errors
- Affected recipes have been updated in this PR

---

## Pre-Submission Checklist

- [x] Search for similar PRs:
[https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling)
- [x] Format PR title as `[{modules}] {type}: {description}` (checked by
CI)
- **Suggested title:** `[BREAKING][rollout, trainer, algo] feat:
implement comprehensive Rollout Importance Sampling framework`
- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md)
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting)
- [x] Add/update
[documentation](https://github.com/volcengine/verl/tree/main/docs) (3
new docs, 2 updated)
- [x] Add unit and integration tests (530 lines of tests)
- [x] Once PR is ready for CI, send message in `ci-request` channel

---

## References

- **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse
from the Inference-Training
Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda)
- **Migration guide:** `docs/advance/rollout_is_migration.md`
- **Examples:** `examples/rollout_importance_sampling/`
- **Tests:** `tests/trainer/ppo/test_rollout_is*.py`
This commit is contained in:
Yingru Li
2025-10-14 02:06:36 +08:00
committed by GitHub
parent 21271aabb9
commit 5d378b5f95
10 changed files with 45 additions and 45 deletions

View File

@ -55,7 +55,7 @@ actor_rollout_ref:
The new implementation: The new implementation:
- ✅ Three aggregation levels: token, sequence, geometric - ✅ Three aggregation levels: token, sequence, geometric
- ✅ Two bounding modes: truncate, clip - ✅ Two bounding modes: truncate, mask
- ✅ Dual threshold support (upper/lower) - ✅ Dual threshold support (upper/lower)
- ✅ Veto mechanism for catastrophic outliers - ✅ Veto mechanism for catastrophic outliers
- ✅ 30+ comprehensive metrics - ✅ 30+ comprehensive metrics
@ -150,7 +150,7 @@ Aggregation level for IS weights:
### `algorithm.rollout_is_mode` (str) ### `algorithm.rollout_is_mode` (str)
Bounding mode: Bounding mode:
- `"truncate"`: Cap weights at upper threshold only - `"truncate"`: Cap weights at upper threshold only
- `"clip"`: Zero out weights outside [lower, upper] - `"mask"`: Zero out weights outside [lower, upper]
### `algorithm.rollout_is_veto_threshold` (float) ### `algorithm.rollout_is_veto_threshold` (float)
Per-token veto threshold. If any token ratio < this, entire sequence is rejected. Per-token veto threshold. If any token ratio < this, entire sequence is rejected.
@ -199,7 +199,7 @@ All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appear
- **`rollout_is_min`**: Minimum IS weight observed - **`rollout_is_min`**: Minimum IS weight observed
- Shows the most underweighted token/sequence - Shows the most underweighted token/sequence
- **`rollout_is_max`**: Maximum IS weight observed (before clipping) - **`rollout_is_max`**: Maximum IS weight observed (before truncation/masking)
- Shows the most overweighted token/sequence - Shows the most overweighted token/sequence
- Compare with `rollout_is_threshold` to see truncation impact - Compare with `rollout_is_threshold` to see truncation impact
@ -235,11 +235,11 @@ All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appear
#### **Threshold Exceedance Metrics** #### **Threshold Exceedance Metrics**
- **`rollout_is_ratio_fraction_high`**: Fraction of weights exceeding upper threshold - **`rollout_is_ratio_fraction_high`**: Fraction of weights exceeding upper threshold
- Shows how often truncation/clipping occurs on high end - Shows how often truncation/masking occurs on high end
- **Ideal value**: < 0.1 (most weights within bounds) - **Ideal value**: < 0.1 (most weights within bounds)
- **`rollout_is_ratio_fraction_low`**: Fraction of weights below lower threshold - **`rollout_is_ratio_fraction_low`**: Fraction of weights below lower threshold
- Shows how often clipping occurs on low end (clip mode only) - Shows how often masking occurs on low end (mask mode only)
- **Ideal value**: < 0.1 - **Ideal value**: < 0.1
#### **Sequence-Level Metrics** (for sequence/geometric modes) #### **Sequence-Level Metrics** (for sequence/geometric modes)
@ -261,14 +261,14 @@ All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appear
- **`rollout_is_seq_fraction_low`**: Fraction of sequences below lower threshold - **`rollout_is_seq_fraction_low`**: Fraction of sequences below lower threshold
#### **Clipping Metrics** (clip mode only) #### **Masking Metrics** (mask mode only)
- **`rollout_is_clipped_fraction`**: Fraction of tokens clipped (set to zero) - **`rollout_is_masked_fraction`**: Fraction of tokens masked (set to zero)
- **Ideal value**: < 0.1 - **Ideal value**: < 0.1
- **Warning**: > 0.3 means losing too much data - **Warning**: > 0.3 means losing too much data
- **`rollout_is_seq_clipped_fraction`**: Fraction of sequences with at least one clipped token - **`rollout_is_seq_masked_fraction`**: Fraction of sequences with at least one masked token
- Shows sequence-level impact of clipping - Shows sequence-level impact of masking
#### **Distribution Mismatch Metrics** (Training vs Rollout Policy) #### **Distribution Mismatch Metrics** (Training vs Rollout Policy)
@ -456,14 +456,14 @@ algorithm:
rollout_is_mode: truncate rollout_is_mode: truncate
``` ```
### Example 3: Geometric Mean with Clip ### Example 3: Geometric Mean with Mask
```yaml ```yaml
algorithm: algorithm:
rollout_is_threshold: 1.0002 rollout_is_threshold: 1.0002
rollout_is: true rollout_is: true
rollout_is_threshold_lower: 0.9998 rollout_is_threshold_lower: 0.9998
rollout_is_level: geometric rollout_is_level: geometric
rollout_is_mode: clip rollout_is_mode: mask
``` ```
### Example 4: Asymmetric Thresholds ### Example 4: Asymmetric Thresholds
@ -473,7 +473,7 @@ algorithm:
rollout_is: true rollout_is: true
rollout_is_threshold_lower: 0.8 rollout_is_threshold_lower: 0.8
rollout_is_level: token rollout_is_level: token
rollout_is_mode: clip rollout_is_mode: mask
``` ```
## Troubleshooting ## Troubleshooting

View File

@ -123,7 +123,7 @@ Actor/Rollout/Reference Policy
rollout_is_threshold: null # Upper threshold for IS weights (null to disable) rollout_is_threshold: null # Upper threshold for IS weights (null to disable)
rollout_is_threshold_lower: null # Lower threshold (null = auto 1/upper) rollout_is_threshold_lower: null # Lower threshold (null = auto 1/upper)
rollout_is_level: token # Aggregation: token/sequence/geometric rollout_is_level: token # Aggregation: token/sequence/geometric
rollout_is_mode: truncate # Bounding: truncate/clip rollout_is_mode: truncate # Bounding: truncate/mask
rollout_is_veto_threshold: 1e-4 # Catastrophic outlier threshold rollout_is_veto_threshold: 1e-4 # Catastrophic outlier threshold
use_torch_compile: True # False to disable torch compile use_torch_compile: True # False to disable torch compile
kl_loss_coef: 0.001 # for grpo kl_loss_coef: 0.001 # for grpo
@ -527,7 +527,7 @@ Algorithm
- ``rollout_is_threshold``: Upper threshold for IS weights. Set to ``null`` to disable IS completely. - ``rollout_is_threshold``: Upper threshold for IS weights. Set to ``null`` to disable IS completely.
- ``rollout_is_threshold_lower``: Lower threshold for IS weights. If ``null``, defaults to reciprocal of upper (1/upper). - ``rollout_is_threshold_lower``: Lower threshold for IS weights. If ``null``, defaults to reciprocal of upper (1/upper).
- ``rollout_is_level``: Aggregation level: ``token`` (biased), ``sequence`` (unbiased), or ``geometric`` (experimental). - ``rollout_is_level``: Aggregation level: ``token`` (biased), ``sequence`` (unbiased), or ``geometric`` (experimental).
- ``rollout_is_mode``: Bounding mode: ``truncate`` (cap upper only) or ``clip`` (zero outside bounds). - ``rollout_is_mode``: Bounding mode: ``truncate`` (cap upper only) or ``mask`` (zero outside bounds).
- ``rollout_is_veto_threshold``: Per-token veto threshold for catastrophic outliers. Default is 1e-4. - ``rollout_is_veto_threshold``: Per-token veto threshold for catastrophic outliers. Default is 1e-4.
Note: Rollout IS requires setting ``actor_rollout_ref.rollout.calculate_log_probs=True``. Note: Rollout IS requires setting ``actor_rollout_ref.rollout.calculate_log_probs=True``.

View File

@ -86,7 +86,7 @@ algorithm:
rollout_is_mode: truncate rollout_is_mode: truncate
``` ```
### Example 3: Geometric Mean with Clip ### Example 3: Geometric Mean with Mask
```yaml ```yaml
algorithm: algorithm:
@ -94,7 +94,7 @@ algorithm:
rollout_is: true rollout_is: true
rollout_is_threshold_lower: 0.9998 rollout_is_threshold_lower: 0.9998
rollout_is_level: geometric rollout_is_level: geometric
rollout_is_mode: clip rollout_is_mode: mask
rollout_is_veto_threshold: 1e-4 rollout_is_veto_threshold: 1e-4
``` ```
@ -118,7 +118,7 @@ algorithm:
rollout_is: true rollout_is: true
rollout_is_threshold_lower: 0.8 rollout_is_threshold_lower: 0.8
rollout_is_level: token rollout_is_level: token
rollout_is_mode: clip rollout_is_mode: mask
``` ```
## Monitoring Metrics ## Monitoring Metrics
@ -183,9 +183,9 @@ These metrics help diagnose the distribution mismatch between rollout and traini
2. Verify rollout_log_probs are correctly passed 2. Verify rollout_log_probs are correctly passed
3. Check for systematic bias in rollout vs training 3. Check for systematic bias in rollout vs training
### Issue: Too Much Data Discarded (Clip Mode) ### Issue: Too Much Data Discarded (Mask Mode)
**Symptoms**: `rollout_is_clipped_fraction` > 0.5 **Symptoms**: `rollout_is_masked_fraction` > 0.5
**Solutions**: **Solutions**:
1. Widen thresholds 1. Widen thresholds

View File

@ -21,7 +21,7 @@ rollout_is_threshold_lower=null
# Aggregation level: token | sequence | geometric (experimental) # Aggregation level: token | sequence | geometric (experimental)
rollout_is_level=token rollout_is_level=token
# Bounding mode: truncate (cap upper) | clip (zero outside bounds) # Bounding mode: truncate (cap upper) | mask (zero outside bounds)
rollout_is_mode=truncate rollout_is_mode=truncate
# Catastrophic outlier veto threshold # Catastrophic outlier veto threshold

View File

@ -97,14 +97,14 @@ def test_basic_rollout_is():
rollout_log_prob=rollout_log_prob, rollout_log_prob=rollout_log_prob,
response_mask=eos_mask, response_mask=eos_mask,
rollout_is_level="geometric", rollout_is_level="geometric",
rollout_is_mode="clip", rollout_is_mode="mask",
rollout_is_threshold=1.5, rollout_is_threshold=1.5,
rollout_is_threshold_lower=0.5, rollout_is_threshold_lower=0.5,
rollout_is_veto_threshold=1e-4, rollout_is_veto_threshold=1e-4,
) )
print(f" Mean weight: {metrics_geo['mismatch/rollout_is_mean']:.4f}") print(f" Mean weight: {metrics_geo['mismatch/rollout_is_mean']:.4f}")
print(f" Clipped fraction: {metrics_geo['mismatch/rollout_is_clipped_fraction']:.4f}") print(f" Masked fraction: {metrics_geo['mismatch/rollout_is_masked_fraction']:.4f}")
print(" ✓ Geometric mean mode passed") print(" ✓ Geometric mean mode passed")
# Test veto mechanism # Test veto mechanism

View File

@ -132,8 +132,8 @@ class TestRolloutISIntegration:
assert "mismatch/rollout_is_mean" in metrics assert "mismatch/rollout_is_mean" in metrics
def test_both_bounding_modes(self, sample_data): def test_both_bounding_modes(self, sample_data):
"""Test both truncate and clip modes.""" """Test both truncate and mask modes."""
modes = ["truncate", "clip"] modes = ["truncate", "mask"]
for mode in modes: for mode in modes:
_, metrics = compute_rollout_importance_weights( _, metrics = compute_rollout_importance_weights(

View File

@ -77,7 +77,7 @@ class AlgoConfig(BaseConfig):
float value = enabled (compute weights and metrics). This is the main on/off switch. float value = enabled (compute weights and metrics). This is the main on/off switch.
rollout_is_threshold_lower (Optional[float]): Lower threshold for IS weights. If None, defaults to 1/upper. rollout_is_threshold_lower (Optional[float]): Lower threshold for IS weights. If None, defaults to 1/upper.
rollout_is_level (str): Aggregation level: "token", "sequence", or "geometric". rollout_is_level (str): Aggregation level: "token", "sequence", or "geometric".
rollout_is_mode (str): Bounding mode: "truncate" (cap upper only) or "clip" (zero outside bounds). rollout_is_mode (str): Bounding mode: "truncate" (cap upper only) or "mask" (zero outside bounds).
rollout_is_veto_threshold (float): Per-token veto threshold for catastrophic outliers. rollout_is_veto_threshold (float): Per-token veto threshold for catastrophic outliers.
rollout_is (bool): Whether to apply IS weights to policy loss. True = apply weights, rollout_is (bool): Whether to apply IS weights to policy loss. True = apply weights,
False = compute metrics only (useful for monitoring before enabling correction). Default: False. False = compute metrics only (useful for monitoring before enabling correction). Default: False.

View File

@ -84,7 +84,7 @@ algorithm:
# Aggregation level: "token" (biased), "sequence" (unbiased), "geometric" (experimental) # Aggregation level: "token" (biased), "sequence" (unbiased), "geometric" (experimental)
rollout_is_level: token rollout_is_level: token
# Bounding mode: "truncate" (cap upper only), "clip" (zero outside bounds) # Bounding mode: "truncate" (cap upper only), "mask" (zero outside bounds)
rollout_is_mode: truncate rollout_is_mode: truncate
# Per-token veto threshold for catastrophic outliers # Per-token veto threshold for catastrophic outliers

View File

@ -124,7 +124,7 @@ algorithm:
# Aggregation level: "token" (biased), "sequence" (unbiased), "geometric" (experimental) # Aggregation level: "token" (biased), "sequence" (unbiased), "geometric" (experimental)
rollout_is_level: token rollout_is_level: token
# Bounding mode: "truncate" (cap upper only), "clip" (zero outside bounds) # Bounding mode: "truncate" (cap upper only), "mask" (zero outside bounds)
rollout_is_mode: truncate rollout_is_mode: truncate
# Per-token veto threshold for catastrophic outliers # Per-token veto threshold for catastrophic outliers

View File

@ -20,7 +20,7 @@ training policy (e.g., FSDP FP32).
Key Features: Key Features:
1. Three aggregation levels: token, sequence, geometric 1. Three aggregation levels: token, sequence, geometric
2. Two handling modes: truncate (TIS), clip (CIS) 2. Two handling modes: truncate (TIS), mask (MIS)
3. Per-token veto mechanism for catastrophic outliers 3. Per-token veto mechanism for catastrophic outliers
4. Memory-efficient computation to prevent CUDA OOM 4. Memory-efficient computation to prevent CUDA OOM
5. Comprehensive metrics tracking 5. Comprehensive metrics tracking
@ -77,9 +77,9 @@ def compute_rollout_importance_weights(
- "geometric": Geometric mean of ratios (experimental) - "geometric": Geometric mean of ratios (experimental)
rollout_is_mode: How to handle weights exceeding threshold: rollout_is_mode: How to handle weights exceeding threshold:
- "truncate": Cap weights at upper_threshold only (TIS) - "truncate": Cap weights at upper_threshold only (TIS)
- "clip": Zero out weights outside [lower_threshold, upper_threshold] (CIS) - "mask": Zero out weights outside [lower_threshold, upper_threshold] (MIS)
rollout_is_threshold: Upper threshold for IS weights rollout_is_threshold: Upper threshold for IS weights
rollout_is_threshold_lower: Lower threshold for IS weights (clip mode only; if None, defaults to 1/upper) rollout_is_threshold_lower: Lower threshold for IS weights (mask mode only; if None, defaults to 1/upper)
rollout_is_veto_threshold: Per-token veto threshold. If any token ratio < this, zero entire sequence. rollout_is_veto_threshold: Per-token veto threshold. If any token ratio < this, zero entire sequence.
If None, veto mechanism is disabled. If None, veto mechanism is disabled.
@ -179,32 +179,32 @@ def compute_rollout_importance_weights(
SAFETY_BOUND=SAFETY_BOUND, SAFETY_BOUND=SAFETY_BOUND,
) )
# Step 3: Apply truncation or clipping based on mode # Step 3: Apply truncation or masking based on mode
if rollout_is_mode == "truncate": if rollout_is_mode == "truncate":
# Truncated IS (TIS): only cap upper bound to prevent overweighting # Truncated IS (TIS): only cap upper bound to prevent overweighting
rollout_is_weights = rollout_is_weights.clamp(max=upper_threshold) rollout_is_weights = rollout_is_weights.clamp(max=upper_threshold)
elif rollout_is_mode == "clip": elif rollout_is_mode == "mask":
# Clipped IS (CIS): zero out weights outside [lower_threshold, upper_threshold] # Masked IS (MIS): zero out weights outside [lower_threshold, upper_threshold]
clip_mask = (rollout_is_weights >= lower_threshold) & (rollout_is_weights <= upper_threshold) mask = (rollout_is_weights >= lower_threshold) & (rollout_is_weights <= upper_threshold)
clip_mask = clip_mask.float() mask = mask.float()
# Track CIS-specific metrics # Track MIS-specific metrics
metrics["rollout_is_clipped_fraction"] = verl_F.masked_mean(1 - clip_mask, response_mask) metrics["rollout_is_masked_fraction"] = verl_F.masked_mean(1 - mask, response_mask)
# Sequence-level clipping fraction # Sequence-level masking fraction
if rollout_is_level in ["sequence", "geometric"]: if rollout_is_level in ["sequence", "geometric"]:
# All tokens in a sequence have the same weight, so reuse clip_mask # All tokens in a sequence have the same weight, so reuse mask
metrics["rollout_is_seq_clipped_fraction"] = (1 - clip_mask[:, 0]).mean() metrics["rollout_is_seq_masked_fraction"] = (1 - mask[:, 0]).mean()
else: else:
# Check if any token in each sequence is clipped # Check if any token in each sequence is masked
seq_has_clipped = verl_F.masked_sum(1 - clip_mask, response_mask, axis=-1) > 0 seq_has_masked = verl_F.masked_sum(1 - mask, response_mask, axis=-1) > 0
metrics["rollout_is_seq_clipped_fraction"] = seq_has_clipped.float().mean() metrics["rollout_is_seq_masked_fraction"] = seq_has_masked.float().mean()
rollout_is_weights = rollout_is_weights * clip_mask rollout_is_weights = rollout_is_weights * mask
else: else:
raise ValueError(f"Invalid rollout_is_mode: {rollout_is_mode}. Must be 'truncate' or 'clip'.") raise ValueError(f"Invalid rollout_is_mode: {rollout_is_mode}. Must be 'truncate' or 'mask'.")
# Apply veto mask AFTER all thresholding # Apply veto mask AFTER all thresholding
# This zeros out entire sequences that have any catastrophic token # This zeros out entire sequences that have any catastrophic token