mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
[rollout] refactor: rename "clip" mode back to "mask" mode (#3750)
# Rollout Importance Sampling Framework related to https://github.com/volcengine/verl/pull/3694 ## Summary This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning. This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = {September}, } ``` --- ## Problem Statement When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to: - Biased gradient estimates - Training instability and collapse - Reduced sample efficiency - Poor convergence properties This framework addresses these issues through principled importance sampling correction. --- ## Key Features & Improvements ### 1. **Flexible Aggregation Levels** Three methods for calculating IS weights: - **`token`**: Per-token importance ratios - **`sequence`**: Product of per-token ratios - **`geometric`**: Geometric mean of ratios ### 2. **Advanced Bounding Modes** Two strategies to control weight variance: - **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients - **`mask`** (MIS): Zeros out weights outside bounds, more aggressive filtering ### 3. **Comprehensive Diagnostics** Detailed metrics to monitor distribution mismatch and training health: **Rollout IS Metrics** (automatically prefixed with `mismatch/`): - Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean` - Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std` - Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_masked_fraction` (mask mode) - Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc. **Mismatch Metrics** (computed efficiently within IS weight computation): - KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability) - Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio` - Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min` ### 4. **Outlier Mitigation** - **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold) - Prevents gradient corruption from extreme outliers - Configurable threshold (default: 1e-4) ### 5. **Numerical Stability** - All core computations in **log-space** to prevent underflow/overflow - Carefully designed clamping and bounding to maintain numerical precision - Safe handling of edge cases (zero probabilities, extreme ratios) ### 6. **Memory Efficiency** - Optimized computation to minimize CUDA memory usage - Efficient metric aggregation without large intermediate tensors - Suitable for large-scale distributed training ### 7. **Metrics-Only Mode** - Compute and monitor mismatch metrics **without** applying IS weights - Useful for: - Understanding distribution mismatch before intervention - Deciding whether IS correction is needed - A/B testing IS impact - Controlled by `algorithm.rollout_is` flag (independent of weight computation) ### 8. **Universal PPO Support** - Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean - Consistent interface across different policy loss functions - Automatic weight application when enabled --- ## API and Configuration Changes ### Migration from Legacy TIS #### ❌ **Before (REMOVED)** ```yaml # Old TIS configuration - NO LONGER SUPPORTED actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # Removed from actor config ``` The legacy implementation: - Only supported token-level truncation - No metrics tracking - Lacked numerical stability - Limited configurability #### ✅ **After (New Framework)** Configuration moved to `algorithm` section for better organization: ```yaml algorithm: # Main on/off switch: null = disabled, float = enabled rollout_is_threshold: 2.0 # Control weight application (independent of metrics computation) rollout_is: true # true = apply weights, false = metrics only # Optional: lower threshold (defaults to 1/upper if null) rollout_is_threshold_lower: null # Aggregation level: "token", "sequence", or "geometric" rollout_is_level: token # Bounding mode: "truncate" or "mask" rollout_is_mode: truncate # Veto threshold for catastrophic outliers (null = disabled) rollout_is_veto_threshold: 1e-4 # REQUIRED: Enable log probability calculation actor_rollout_ref: rollout: calculate_log_probs: true ``` ### Configuration Examples **1. Token-level truncation (recommended starting point)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate ``` **2. Sequence-level masking (more aggressive)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: sequence rollout_is_mode: mask ``` **3. Metrics-only mode (monitoring without correction)** ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # Compute metrics but don't apply weights rollout_is_level: token rollout_is_mode: truncate ``` **Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh` --- ## Code Changes Overview ### New Files (4 files, 1,442 lines) 1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines) - Core implementation of IS weight computation - Three aggregation levels: token, sequence, geometric - Two bounding modes: truncate, mask - Veto mechanism for outlier detection - Comprehensive metrics computation (IS + mismatch) - All computations in log-space for numerical stability - Memory-efficient design 2. **`docs/advance/rollout_is_migration.md`** (642 lines) - Comprehensive migration guide from legacy TIS - Detailed explanation of all configuration options - Recommended threshold ranges for each aggregation level - Troubleshooting guide and best practices - Metrics interpretation guide 3. **`examples/rollout_importance_sampling/README.md`** (242 lines) - Quick start guide with working examples - Configuration templates for common scenarios - Threshold tuning guidelines - Metrics monitoring instructions 4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines) - Complete working example script - Demonstrates token-level and sequence-level configurations - Ready to run with minimal modifications ### Modified Core Files (9 files) 1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed) - Removed legacy TIS logic (`tis_imp_ratio_cap`) - Added `rollout_is_weights` parameter to all policy loss functions - Unified IS weight application interface across all PPO variants: - `compute_policy_loss_vanilla` - `compute_policy_loss_gspo` - `compute_policy_loss_gpg` - `compute_policy_loss_clip_cov` - `compute_policy_loss_kl_cov` - `compute_policy_loss_geo_mean` - Special handling for `geo_mean` (sequence-level aggregation) 2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added) - New method: `compute_rollout_importance_weights_and_add_to_batch()` - Centralized IS computation (once per batch, on driver) - Conditional weight distribution to workers based on `algorithm.rollout_is` - Metrics collection and aggregation - Integration with existing training loop 3. **`verl/trainer/config/algorithm.py`** (+18 lines) - Added 6 new Rollout IS parameters: - `rollout_is_threshold` (main on/off switch) - `rollout_is` (weight application control) - `rollout_is_threshold_lower` - `rollout_is_level` - `rollout_is_mode` - `rollout_is_veto_threshold` - Comprehensive docstrings explaining each parameter 4. **`verl/workers/config/actor.py`** (-1 line) - Removed deprecated `tis_imp_ratio_cap` parameter 5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed) - Updated to use new `rollout_is_weights` parameter - Removed legacy TIS logic 7. **Configuration Files** (4 files updated) - `verl/trainer/config/ppo_trainer.yaml` - `verl/trainer/config/ppo_megatron_trainer.yaml` - `verl/trainer/config/_generated_ppo_trainer.yaml` - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml` - Added default Rollout IS configuration section with explanatory comments ### Testing (2 files, 530 lines) 1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines) - Unit tests for `mismatch_helper.py` - Coverage for all aggregation levels (token, sequence, geometric) - Coverage for all bounding modes (truncate, mask) - Veto mechanism tests - Edge case handling (zeros, extremes, empty sequences) - Numerical stability verification - Metrics correctness validation 2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines) - Integration tests with PPO training loop - End-to-end workflow validation - Batch processing tests - Configuration validation - Metrics collection verification - Compatibility with distributed training ### Updated Recipes (2 files) 1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines) - Updated imports to use new framework 2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed) - Migrated from legacy TIS to new Rollout IS configuration - Updated documentation and comments ### Documentation Updates (2 files) 1. **`docs/examples/config.rst`** (~22 lines changed) - Updated configuration examples - Added Rollout IS section 2. **`docs/index.rst`** (+1 line) - Added link to Rollout IS migration guide --- ## Implementation Highlights ### Centralized Architecture The new design follows a clean separation of concerns: ``` ray_trainer.py (driver) └─> compute_rollout_importance_weights_and_add_to_batch() └─> mismatch_helper.compute_rollout_importance_weights() ├─> Computes IS weights (token/sequence/geometric) ├─> Applies bounding (truncate/mask) ├─> Veto mechanism for outliers ├─> Computes IS metrics └─> Computes mismatch metrics (KL, PPL) └─> Conditionally adds weights to batch (if rollout_is=True) └─> Distributes batch to workers actor workers (dp_actor, megatron_actor) └─> Receive batch with rollout_is_weights (if enabled) └─> Pass weights to policy loss function core_algos.py └─> All policy loss functions accept rollout_is_weights └─> Apply weights if provided: pg_losses *= rollout_is_weights ``` ### Key Design Decisions 1. **Centralized Computation**: IS weights computed once on driver, not per worker - Reduces redundant computation - Ensures consistency across workers - Simplifies debugging and metrics collection 2. **Configuration in Algorithm**: Moved from actor config to algorithm config - Better conceptual organization (algorithm-level concern, not worker-level) - Easier to manage and validate - Consistent with other algorithm parameters 3. **Two-Level Control**: - `rollout_is_threshold`: Enables/disables entire system (null = off) - `rollout_is`: Controls weight application (true = apply, false = metrics only) - Allows flexible monitoring and gradual rollout 4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation - Eliminates duplicate computation - Reduces memory overhead - Maintains metric accuracy 5. **Universal PPO Support**: Single interface for all PPO variants - Minimal code changes required - Consistent behavior across algorithms - Easy to add new variants --- ## Migration Guide ### For Users of Legacy TIS **Step 1: Update your configuration file** ```yaml # OLD (remove this) actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 # NEW (add this) algorithm: rollout_is_threshold: 2.0 # Use same value as old tis_imp_ratio_cap rollout_is: true rollout_is_level: token rollout_is_mode: truncate # REQUIRED (add if not present) actor_rollout_ref: rollout: calculate_log_probs: true ``` **Step 2: Monitor metrics** The first time you run with the new configuration, check these metrics: - `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size - `mismatch/rollout_is_veto_fraction`: Should be < 5% - `mismatch/rollout_is_mean`: Should be close to 1.0 **Step 3: Tune if needed** If effective sample size is too low: - Increase `rollout_is_threshold` - Try `rollout_is_mode: mask` with appropriate lower bound - Consider `rollout_is_level: sequence` for more aggressive correction For detailed guidance, see `docs/advance/rollout_is_migration.md`. ### For New Users Start with recommended defaults: ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true ``` Run the example script to see it in action: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` --- ## Testing ### Unit Tests - **289 lines** of comprehensive unit tests in `test_rollout_is.py` - Covers all aggregation levels, bounding modes, and edge cases - Validates numerical stability and correctness - Fast execution (~1-2 seconds) ### Integration Tests - **241 lines** of integration tests in `test_rollout_is_integration.py` - End-to-end workflow with PPO training loop - Distributed training compatibility - Metrics collection validation - Moderate execution time (~10-20 seconds) ### Running Tests ```bash # Run all Rollout IS tests pytest tests/trainer/ppo/test_rollout_is.py -v pytest tests/trainer/ppo/test_rollout_is_integration.py -v # Run specific test pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v ``` --- ## Metrics Reference ### Rollout IS Metrics (all prefixed with `mismatch/`) | Metric | Description | Ideal Range | |--------|-------------|-------------| | `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch | | `rollout_is_mean` | Mean IS weight | ~1.0 | | `rollout_is_std` | Standard deviation of IS weights | Low variance | | `rollout_is_p25` | 25th percentile | ~0.8-1.0 | | `rollout_is_p50` | Median IS weight | ~1.0 | | `rollout_is_p75` | 75th percentile | ~1.0-1.2 | | `rollout_is_p95` | 95th percentile | < threshold | | `rollout_is_p99` | 99th percentile | < threshold | | `rollout_is_max` | Maximum weight | ≤ threshold | | `rollout_is_min` | Minimum weight | ≥ lower threshold (mask mode) | | `rollout_is_veto_fraction` | % sequences vetoed | < 5% | | `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% | | `rollout_is_masked_fraction` | % tokens masked (mask mode) | Variable | ### Mismatch Metrics (all prefixed with `mismatch/`) | Metric | Description | What It Means | |--------|-------------|---------------| | `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) | | `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences | | `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy | | `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy | | `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty | | `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch | | `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch | | `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch | | `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch | | `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity | | `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity | --- ## Performance Impact ### Memory - Minimal overhead: ~1-2% increase in peak memory usage - Efficient log-space computation - No large intermediate tensors ### Computation - Negligible impact on training speed: < 1% overhead - Centralized computation on driver (no per-worker redundancy) - Optimized tensor operations ### Training Stability - Significant improvement in stability when distribution mismatch exists - Faster convergence in many scenarios - Reduced risk of training collapse --- ## Breaking Changes > [!IMPORTANT] > This PR contains **BREAKING CHANGES** to the configuration API. ### Removed - `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported ### Migration Required All users of the legacy TIS implementation must update their configuration files. See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions. ### Backward Compatibility - No backward compatibility with legacy TIS - Configuration files with `tis_imp_ratio_cap` will raise validation errors - Affected recipes have been updated in this PR --- ## Pre-Submission Checklist - [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling) - [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI) - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework` - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md) - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting) - [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated) - [x] Add unit and integration tests (530 lines of tests) - [x] Once PR is ready for CI, send message in `ci-request` channel --- ## References - **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) - **Migration guide:** `docs/advance/rollout_is_migration.md` - **Examples:** `examples/rollout_importance_sampling/` - **Tests:** `tests/trainer/ppo/test_rollout_is*.py`
This commit is contained in:
@ -55,7 +55,7 @@ actor_rollout_ref:
|
||||
|
||||
The new implementation:
|
||||
- ✅ Three aggregation levels: token, sequence, geometric
|
||||
- ✅ Two bounding modes: truncate, clip
|
||||
- ✅ Two bounding modes: truncate, mask
|
||||
- ✅ Dual threshold support (upper/lower)
|
||||
- ✅ Veto mechanism for catastrophic outliers
|
||||
- ✅ 30+ comprehensive metrics
|
||||
@ -150,7 +150,7 @@ Aggregation level for IS weights:
|
||||
### `algorithm.rollout_is_mode` (str)
|
||||
Bounding mode:
|
||||
- `"truncate"`: Cap weights at upper threshold only
|
||||
- `"clip"`: Zero out weights outside [lower, upper]
|
||||
- `"mask"`: Zero out weights outside [lower, upper]
|
||||
|
||||
### `algorithm.rollout_is_veto_threshold` (float)
|
||||
Per-token veto threshold. If any token ratio < this, entire sequence is rejected.
|
||||
@ -199,7 +199,7 @@ All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appear
|
||||
- **`rollout_is_min`**: Minimum IS weight observed
|
||||
- Shows the most underweighted token/sequence
|
||||
|
||||
- **`rollout_is_max`**: Maximum IS weight observed (before clipping)
|
||||
- **`rollout_is_max`**: Maximum IS weight observed (before truncation/masking)
|
||||
- Shows the most overweighted token/sequence
|
||||
- Compare with `rollout_is_threshold` to see truncation impact
|
||||
|
||||
@ -235,11 +235,11 @@ All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appear
|
||||
#### **Threshold Exceedance Metrics**
|
||||
|
||||
- **`rollout_is_ratio_fraction_high`**: Fraction of weights exceeding upper threshold
|
||||
- Shows how often truncation/clipping occurs on high end
|
||||
- Shows how often truncation/masking occurs on high end
|
||||
- **Ideal value**: < 0.1 (most weights within bounds)
|
||||
|
||||
- **`rollout_is_ratio_fraction_low`**: Fraction of weights below lower threshold
|
||||
- Shows how often clipping occurs on low end (clip mode only)
|
||||
- Shows how often masking occurs on low end (mask mode only)
|
||||
- **Ideal value**: < 0.1
|
||||
|
||||
#### **Sequence-Level Metrics** (for sequence/geometric modes)
|
||||
@ -261,14 +261,14 @@ All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appear
|
||||
|
||||
- **`rollout_is_seq_fraction_low`**: Fraction of sequences below lower threshold
|
||||
|
||||
#### **Clipping Metrics** (clip mode only)
|
||||
#### **Masking Metrics** (mask mode only)
|
||||
|
||||
- **`rollout_is_clipped_fraction`**: Fraction of tokens clipped (set to zero)
|
||||
- **`rollout_is_masked_fraction`**: Fraction of tokens masked (set to zero)
|
||||
- **Ideal value**: < 0.1
|
||||
- **Warning**: > 0.3 means losing too much data
|
||||
|
||||
- **`rollout_is_seq_clipped_fraction`**: Fraction of sequences with at least one clipped token
|
||||
- Shows sequence-level impact of clipping
|
||||
- **`rollout_is_seq_masked_fraction`**: Fraction of sequences with at least one masked token
|
||||
- Shows sequence-level impact of masking
|
||||
|
||||
#### **Distribution Mismatch Metrics** (Training vs Rollout Policy)
|
||||
|
||||
@ -456,14 +456,14 @@ algorithm:
|
||||
rollout_is_mode: truncate
|
||||
```
|
||||
|
||||
### Example 3: Geometric Mean with Clip
|
||||
### Example 3: Geometric Mean with Mask
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 1.0002
|
||||
rollout_is: true
|
||||
rollout_is_threshold_lower: 0.9998
|
||||
rollout_is_level: geometric
|
||||
rollout_is_mode: clip
|
||||
rollout_is_mode: mask
|
||||
```
|
||||
|
||||
### Example 4: Asymmetric Thresholds
|
||||
@ -473,7 +473,7 @@ algorithm:
|
||||
rollout_is: true
|
||||
rollout_is_threshold_lower: 0.8
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: clip
|
||||
rollout_is_mode: mask
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
@ -123,7 +123,7 @@ Actor/Rollout/Reference Policy
|
||||
rollout_is_threshold: null # Upper threshold for IS weights (null to disable)
|
||||
rollout_is_threshold_lower: null # Lower threshold (null = auto 1/upper)
|
||||
rollout_is_level: token # Aggregation: token/sequence/geometric
|
||||
rollout_is_mode: truncate # Bounding: truncate/clip
|
||||
rollout_is_mode: truncate # Bounding: truncate/mask
|
||||
rollout_is_veto_threshold: 1e-4 # Catastrophic outlier threshold
|
||||
use_torch_compile: True # False to disable torch compile
|
||||
kl_loss_coef: 0.001 # for grpo
|
||||
@ -527,7 +527,7 @@ Algorithm
|
||||
- ``rollout_is_threshold``: Upper threshold for IS weights. Set to ``null`` to disable IS completely.
|
||||
- ``rollout_is_threshold_lower``: Lower threshold for IS weights. If ``null``, defaults to reciprocal of upper (1/upper).
|
||||
- ``rollout_is_level``: Aggregation level: ``token`` (biased), ``sequence`` (unbiased), or ``geometric`` (experimental).
|
||||
- ``rollout_is_mode``: Bounding mode: ``truncate`` (cap upper only) or ``clip`` (zero outside bounds).
|
||||
- ``rollout_is_mode``: Bounding mode: ``truncate`` (cap upper only) or ``mask`` (zero outside bounds).
|
||||
- ``rollout_is_veto_threshold``: Per-token veto threshold for catastrophic outliers. Default is 1e-4.
|
||||
Note: Rollout IS requires setting ``actor_rollout_ref.rollout.calculate_log_probs=True``.
|
||||
|
||||
|
@ -86,7 +86,7 @@ algorithm:
|
||||
rollout_is_mode: truncate
|
||||
```
|
||||
|
||||
### Example 3: Geometric Mean with Clip
|
||||
### Example 3: Geometric Mean with Mask
|
||||
|
||||
```yaml
|
||||
algorithm:
|
||||
@ -94,7 +94,7 @@ algorithm:
|
||||
rollout_is: true
|
||||
rollout_is_threshold_lower: 0.9998
|
||||
rollout_is_level: geometric
|
||||
rollout_is_mode: clip
|
||||
rollout_is_mode: mask
|
||||
rollout_is_veto_threshold: 1e-4
|
||||
```
|
||||
|
||||
@ -118,7 +118,7 @@ algorithm:
|
||||
rollout_is: true
|
||||
rollout_is_threshold_lower: 0.8
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: clip
|
||||
rollout_is_mode: mask
|
||||
```
|
||||
|
||||
## Monitoring Metrics
|
||||
@ -183,9 +183,9 @@ These metrics help diagnose the distribution mismatch between rollout and traini
|
||||
2. Verify rollout_log_probs are correctly passed
|
||||
3. Check for systematic bias in rollout vs training
|
||||
|
||||
### Issue: Too Much Data Discarded (Clip Mode)
|
||||
### Issue: Too Much Data Discarded (Mask Mode)
|
||||
|
||||
**Symptoms**: `rollout_is_clipped_fraction` > 0.5
|
||||
**Symptoms**: `rollout_is_masked_fraction` > 0.5
|
||||
|
||||
**Solutions**:
|
||||
1. Widen thresholds
|
||||
|
@ -21,7 +21,7 @@ rollout_is_threshold_lower=null
|
||||
# Aggregation level: token | sequence | geometric (experimental)
|
||||
rollout_is_level=token
|
||||
|
||||
# Bounding mode: truncate (cap upper) | clip (zero outside bounds)
|
||||
# Bounding mode: truncate (cap upper) | mask (zero outside bounds)
|
||||
rollout_is_mode=truncate
|
||||
|
||||
# Catastrophic outlier veto threshold
|
||||
|
@ -97,14 +97,14 @@ def test_basic_rollout_is():
|
||||
rollout_log_prob=rollout_log_prob,
|
||||
response_mask=eos_mask,
|
||||
rollout_is_level="geometric",
|
||||
rollout_is_mode="clip",
|
||||
rollout_is_mode="mask",
|
||||
rollout_is_threshold=1.5,
|
||||
rollout_is_threshold_lower=0.5,
|
||||
rollout_is_veto_threshold=1e-4,
|
||||
)
|
||||
|
||||
print(f" Mean weight: {metrics_geo['mismatch/rollout_is_mean']:.4f}")
|
||||
print(f" Clipped fraction: {metrics_geo['mismatch/rollout_is_clipped_fraction']:.4f}")
|
||||
print(f" Masked fraction: {metrics_geo['mismatch/rollout_is_masked_fraction']:.4f}")
|
||||
print(" ✓ Geometric mean mode passed")
|
||||
|
||||
# Test veto mechanism
|
||||
|
@ -132,8 +132,8 @@ class TestRolloutISIntegration:
|
||||
assert "mismatch/rollout_is_mean" in metrics
|
||||
|
||||
def test_both_bounding_modes(self, sample_data):
|
||||
"""Test both truncate and clip modes."""
|
||||
modes = ["truncate", "clip"]
|
||||
"""Test both truncate and mask modes."""
|
||||
modes = ["truncate", "mask"]
|
||||
|
||||
for mode in modes:
|
||||
_, metrics = compute_rollout_importance_weights(
|
||||
|
@ -77,7 +77,7 @@ class AlgoConfig(BaseConfig):
|
||||
float value = enabled (compute weights and metrics). This is the main on/off switch.
|
||||
rollout_is_threshold_lower (Optional[float]): Lower threshold for IS weights. If None, defaults to 1/upper.
|
||||
rollout_is_level (str): Aggregation level: "token", "sequence", or "geometric".
|
||||
rollout_is_mode (str): Bounding mode: "truncate" (cap upper only) or "clip" (zero outside bounds).
|
||||
rollout_is_mode (str): Bounding mode: "truncate" (cap upper only) or "mask" (zero outside bounds).
|
||||
rollout_is_veto_threshold (float): Per-token veto threshold for catastrophic outliers.
|
||||
rollout_is (bool): Whether to apply IS weights to policy loss. True = apply weights,
|
||||
False = compute metrics only (useful for monitoring before enabling correction). Default: False.
|
||||
|
@ -84,7 +84,7 @@ algorithm:
|
||||
# Aggregation level: "token" (biased), "sequence" (unbiased), "geometric" (experimental)
|
||||
rollout_is_level: token
|
||||
|
||||
# Bounding mode: "truncate" (cap upper only), "clip" (zero outside bounds)
|
||||
# Bounding mode: "truncate" (cap upper only), "mask" (zero outside bounds)
|
||||
rollout_is_mode: truncate
|
||||
|
||||
# Per-token veto threshold for catastrophic outliers
|
||||
|
@ -124,7 +124,7 @@ algorithm:
|
||||
# Aggregation level: "token" (biased), "sequence" (unbiased), "geometric" (experimental)
|
||||
rollout_is_level: token
|
||||
|
||||
# Bounding mode: "truncate" (cap upper only), "clip" (zero outside bounds)
|
||||
# Bounding mode: "truncate" (cap upper only), "mask" (zero outside bounds)
|
||||
rollout_is_mode: truncate
|
||||
|
||||
# Per-token veto threshold for catastrophic outliers
|
||||
|
@ -20,7 +20,7 @@ training policy (e.g., FSDP FP32).
|
||||
|
||||
Key Features:
|
||||
1. Three aggregation levels: token, sequence, geometric
|
||||
2. Two handling modes: truncate (TIS), clip (CIS)
|
||||
2. Two handling modes: truncate (TIS), mask (MIS)
|
||||
3. Per-token veto mechanism for catastrophic outliers
|
||||
4. Memory-efficient computation to prevent CUDA OOM
|
||||
5. Comprehensive metrics tracking
|
||||
@ -77,9 +77,9 @@ def compute_rollout_importance_weights(
|
||||
- "geometric": Geometric mean of ratios (experimental)
|
||||
rollout_is_mode: How to handle weights exceeding threshold:
|
||||
- "truncate": Cap weights at upper_threshold only (TIS)
|
||||
- "clip": Zero out weights outside [lower_threshold, upper_threshold] (CIS)
|
||||
- "mask": Zero out weights outside [lower_threshold, upper_threshold] (MIS)
|
||||
rollout_is_threshold: Upper threshold for IS weights
|
||||
rollout_is_threshold_lower: Lower threshold for IS weights (clip mode only; if None, defaults to 1/upper)
|
||||
rollout_is_threshold_lower: Lower threshold for IS weights (mask mode only; if None, defaults to 1/upper)
|
||||
rollout_is_veto_threshold: Per-token veto threshold. If any token ratio < this, zero entire sequence.
|
||||
If None, veto mechanism is disabled.
|
||||
|
||||
@ -179,32 +179,32 @@ def compute_rollout_importance_weights(
|
||||
SAFETY_BOUND=SAFETY_BOUND,
|
||||
)
|
||||
|
||||
# Step 3: Apply truncation or clipping based on mode
|
||||
# Step 3: Apply truncation or masking based on mode
|
||||
if rollout_is_mode == "truncate":
|
||||
# Truncated IS (TIS): only cap upper bound to prevent overweighting
|
||||
rollout_is_weights = rollout_is_weights.clamp(max=upper_threshold)
|
||||
|
||||
elif rollout_is_mode == "clip":
|
||||
# Clipped IS (CIS): zero out weights outside [lower_threshold, upper_threshold]
|
||||
clip_mask = (rollout_is_weights >= lower_threshold) & (rollout_is_weights <= upper_threshold)
|
||||
clip_mask = clip_mask.float()
|
||||
elif rollout_is_mode == "mask":
|
||||
# Masked IS (MIS): zero out weights outside [lower_threshold, upper_threshold]
|
||||
mask = (rollout_is_weights >= lower_threshold) & (rollout_is_weights <= upper_threshold)
|
||||
mask = mask.float()
|
||||
|
||||
# Track CIS-specific metrics
|
||||
metrics["rollout_is_clipped_fraction"] = verl_F.masked_mean(1 - clip_mask, response_mask)
|
||||
# Track MIS-specific metrics
|
||||
metrics["rollout_is_masked_fraction"] = verl_F.masked_mean(1 - mask, response_mask)
|
||||
|
||||
# Sequence-level clipping fraction
|
||||
# Sequence-level masking fraction
|
||||
if rollout_is_level in ["sequence", "geometric"]:
|
||||
# All tokens in a sequence have the same weight, so reuse clip_mask
|
||||
metrics["rollout_is_seq_clipped_fraction"] = (1 - clip_mask[:, 0]).mean()
|
||||
# All tokens in a sequence have the same weight, so reuse mask
|
||||
metrics["rollout_is_seq_masked_fraction"] = (1 - mask[:, 0]).mean()
|
||||
else:
|
||||
# Check if any token in each sequence is clipped
|
||||
seq_has_clipped = verl_F.masked_sum(1 - clip_mask, response_mask, axis=-1) > 0
|
||||
metrics["rollout_is_seq_clipped_fraction"] = seq_has_clipped.float().mean()
|
||||
# Check if any token in each sequence is masked
|
||||
seq_has_masked = verl_F.masked_sum(1 - mask, response_mask, axis=-1) > 0
|
||||
metrics["rollout_is_seq_masked_fraction"] = seq_has_masked.float().mean()
|
||||
|
||||
rollout_is_weights = rollout_is_weights * clip_mask
|
||||
rollout_is_weights = rollout_is_weights * mask
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid rollout_is_mode: {rollout_is_mode}. Must be 'truncate' or 'clip'.")
|
||||
raise ValueError(f"Invalid rollout_is_mode: {rollout_is_mode}. Must be 'truncate' or 'mask'.")
|
||||
|
||||
# Apply veto mask AFTER all thresholding
|
||||
# This zeros out entire sequences that have any catastrophic token
|
||||
|
Reference in New Issue
Block a user