Compare commits

...

5 Commits

Author SHA1 Message Date
21271aabb9 [BREAKING][rollout, trainer, algo] feat: comprehensive rollout importance sampling implementation (#3694)
# Rollout Importance Sampling Framework

## Summary

This PR introduces a comprehensive **Rollout Importance Sampling (IS)**
framework to correct distribution mismatch between data-collecting
(rollout) and training policies, a critical factor for ensuring stable
and efficient model training in RL fine-tuning.

This work is motivated by the analysis in our blog post, [When Speed
Kills Stability: Demystifying RL Collapse from the Inference-Training
Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda).
If you find this implementation useful in your research, please consider
citing:

```bibtex
@misc{liu-li-2025,
  title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch},
  url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda},
  author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen},
  year = {2025},
  month = {September},
}
```

---

## Problem Statement

When using different policies for rollout generation (e.g., vLLM with
BFloat16) and training (e.g., FSDP with FP32), distribution mismatch
occurs, leading to:
- Biased gradient estimates
- Training instability and collapse
- Reduced sample efficiency
- Poor convergence properties

This framework addresses these issues through principled importance
sampling correction.

---

## Key Features & Improvements

### 1. **Flexible Aggregation Levels**
Three methods for calculating IS weights (see the sketch after this list):
- **`token`**: Per-token importance ratios
- **`sequence`**: Product of per-token ratios
- **`geometric`**: Geometric mean of ratios
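
To make the levels concrete, here is a minimal log-space sketch (illustrative only; tensor names such as `train_logprob` and `rollout_logprob` are placeholders, not verl's API):

```python
import torch

def aggregate_is_weights(train_logprob, rollout_logprob, response_mask, level="token"):
    """Illustrative aggregation of IS weights; all tensors shaped (batch, seq_len)."""
    # log(pi_train / pi_rollout); masked tokens contribute a neutral log-ratio of 0
    log_ratio = (train_logprob - rollout_logprob) * response_mask
    if level == "token":
        # One weight per token
        return torch.exp(log_ratio)
    if level == "sequence":
        # Product of per-token ratios = exp(sum of log-ratios), broadcast to every token
        return torch.exp(log_ratio.sum(dim=-1, keepdim=True)).expand_as(log_ratio)
    if level == "geometric":
        # Geometric mean = exp(mean of log-ratios over valid tokens)
        lengths = response_mask.sum(dim=-1, keepdim=True).clamp(min=1)
        return torch.exp(log_ratio.sum(dim=-1, keepdim=True) / lengths).expand_as(log_ratio)
    raise ValueError(f"unknown level: {level}")
```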

### 2. **Advanced Bounding Modes**
Two strategies to control weight variance (see the sketch after this list):
- **`truncate`** (TIS): Caps weights at upper threshold only, preserving
gradients
- **`clip`** (CIS): Zeros out weights outside bounds, more aggressive
filtering
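
A minimal sketch of the two bounding strategies, assuming `weights` is a tensor of IS weights and `upper`/`lower` are the configured thresholds (illustrative, not the framework's exact code):

```python
import torch

def bound_is_weights(weights, upper, lower=None, mode="truncate"):
    """Illustrative bounding of IS weights."""
    if lower is None:
        lower = 1.0 / upper  # auto-reciprocal default
    if mode == "truncate":
        # TIS: cap at the upper threshold only; small weights are left untouched
        return torch.clamp(weights, max=upper)
    if mode == "clip":
        # CIS: zero out anything outside [lower, upper]
        inside = (weights >= lower) & (weights <= upper)
        return weights * inside.to(weights.dtype)
    raise ValueError(f"unknown mode: {mode}")
```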

### 3. **Comprehensive Diagnostics**
Detailed metrics to monitor distribution mismatch and training health:

**Rollout IS Metrics** (automatically prefixed with `mismatch/`):
- Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean`
- Distribution statistics: `rollout_is_p25`, `rollout_is_p50`,
`rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`,
`rollout_is_min`, `rollout_is_std`
- Diagnostics: `rollout_is_veto_fraction`,
`rollout_is_catastrophic_token_fraction`, `rollout_is_clipped_fraction`
(clip mode)
- Sequence-level statistics (for sequence/geometric modes):
`rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`,
`rollout_is_seq_min`, etc.

**Mismatch Metrics** (computed efficiently within IS weight
computation):
- KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3
estimator for stability)
- Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`,
`mismatch_ppl_ratio`
- Log perplexity statistics: `mismatch_log_ppl_diff`,
`mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`,
`mismatch_log_ppl_diff_min`

### 4. **Outlier Mitigation**
- **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold); see the sketch after this list
- Prevents gradient corruption from extreme outliers
- Configurable threshold (default: 1e-4)
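
A minimal sketch of the veto rule, assuming `token_ratios` holds per-token importance ratios (illustrative only):

```python
import torch

def apply_veto(weights, token_ratios, response_mask, veto_threshold=1e-4):
    """Zero out every token of a sequence that contains a catastrophic ratio."""
    catastrophic = (token_ratios < veto_threshold) & response_mask.bool()
    seq_has_catastrophic = catastrophic.any(dim=-1, keepdim=True)  # (batch, 1)
    # Discard the whole sequence so a single extreme token cannot corrupt the gradient
    return weights * (~seq_has_catastrophic).to(weights.dtype)
```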

### 5. **Numerical Stability**
- All core computations in **log-space** to prevent underflow/overflow
- Carefully designed clipping and bounding to maintain numerical
precision
- Safe handling of edge cases (zero probabilities, extreme ratios)

### 6. **Memory Efficiency**
- Optimized computation to minimize CUDA memory usage
- Efficient metric aggregation without large intermediate tensors
- Suitable for large-scale distributed training

### 7. **Metrics-Only Mode**
- Compute and monitor mismatch metrics **without** applying IS weights
- Useful for:
  - Understanding distribution mismatch before intervention
  - Deciding whether IS correction is needed
  - A/B testing IS impact
- Controlled by `algorithm.rollout_is` flag (independent of weight
computation)

### 8. **Universal PPO Support**
- Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov,
KL-Cov, geo_mean
- Consistent interface across different policy loss functions
- Automatic weight application when enabled (see the sketch below)
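
A simplified sketch of how a loss function can consume the optional weights; the real functions in `core_algos.py` take more arguments, this only illustrates the `pg_losses *= rollout_is_weights` step:

```python
import torch

def compute_policy_loss_with_is(pg_losses, response_mask, rollout_is_weights=None):
    """Apply optional rollout IS weights to per-token policy-gradient losses."""
    if rollout_is_weights is not None:
        pg_losses = pg_losses * rollout_is_weights  # importance-weighted per-token loss
    # Masked mean over valid response tokens
    return (pg_losses * response_mask).sum() / response_mask.sum().clamp(min=1)
```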

---

## API and Configuration Changes

### Migration from Legacy TIS

#### **Before (REMOVED)**
```yaml
# Old TIS configuration - NO LONGER SUPPORTED
actor_rollout_ref:
  actor:
    tis_imp_ratio_cap: 2.0  # Removed from actor config
```

The legacy implementation:
- Only supported token-level truncation
- No metrics tracking
- Lacked numerical stability
- Limited configurability

#### **After (New Framework)**

Configuration moved to `algorithm` section for better organization:

```yaml
algorithm:
  # Main on/off switch: null = disabled, float = enabled
  rollout_is_threshold: 2.0

  # Control weight application (independent of metrics computation)
  rollout_is: true  # true = apply weights, false = metrics only

  # Optional: lower threshold (defaults to 1/upper if null)
  rollout_is_threshold_lower: null

  # Aggregation level: "token", "sequence", or "geometric"
  rollout_is_level: token

  # Bounding mode: "truncate" or "clip"
  rollout_is_mode: truncate

  # Veto threshold for catastrophic outliers (null = disabled)
  rollout_is_veto_threshold: 1e-4

# REQUIRED: Enable log probability calculation
actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

### Configuration Examples

**1. Token-level truncation (recommended starting point)**
```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate
```

**2. Sequence-level clipping (more aggressive)**
```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: sequence
  rollout_is_mode: clip
```

**3. Metrics-only mode (monitoring without correction)**
```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: false  # Compute metrics but don't apply weights
  rollout_is_level: token
  rollout_is_mode: truncate
```

**Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh`

---

## Code Changes Overview

### New Files (4 files, 1,442 lines)

1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines)
   - Core implementation of IS weight computation
   - Three aggregation levels: token, sequence, geometric
   - Two bounding modes: truncate, clip
   - Veto mechanism for outlier detection
   - Comprehensive metrics computation (IS + mismatch)
   - All computations in log-space for numerical stability
   - Memory-efficient design

2. **`docs/advance/rollout_is_migration.md`** (642 lines)
   - Comprehensive migration guide from legacy TIS
   - Detailed explanation of all configuration options
   - Recommended threshold ranges for each aggregation level
   - Troubleshooting guide and best practices
   - Metrics interpretation guide

3. **`examples/rollout_importance_sampling/README.md`** (242 lines)
   - Quick start guide with working examples
   - Configuration templates for common scenarios
   - Threshold tuning guidelines
   - Metrics monitoring instructions

4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99
lines)
   - Complete working example script
   - Demonstrates token-level and sequence-level configurations
   - Ready to run with minimal modifications

### Modified Core Files (9 files)

1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed)
   - Removed legacy TIS logic (`tis_imp_ratio_cap`)
   - Added `rollout_is_weights` parameter to all policy loss functions
   - Unified IS weight application interface across all PPO variants:
     - `compute_policy_loss_vanilla`
     - `compute_policy_loss_gspo`
     - `compute_policy_loss_gpg`
     - `compute_policy_loss_clip_cov`
     - `compute_policy_loss_kl_cov`
     - `compute_policy_loss_geo_mean`
   - Special handling for `geo_mean` (sequence-level aggregation)

2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added)
   - New method: `compute_rollout_importance_weights_and_add_to_batch()`
   - Centralized IS computation (once per batch, on driver)
   - Conditional weight distribution to workers based on `algorithm.rollout_is`
   - Metrics collection and aggregation
   - Integration with existing training loop

3. **`verl/trainer/config/algorithm.py`** (+18 lines)
   - Added 6 new Rollout IS parameters:
     - `rollout_is_threshold` (main on/off switch)
     - `rollout_is` (weight application control)
     - `rollout_is_threshold_lower`
     - `rollout_is_level`
     - `rollout_is_mode`
     - `rollout_is_veto_threshold`
   - Comprehensive docstrings explaining each parameter

4. **`verl/workers/config/actor.py`** (-1 line)
   - Removed deprecated `tis_imp_ratio_cap` parameter

5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed)
   - Updated to use new `rollout_is_weights` parameter
   - Removed legacy TIS logic

6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed)
   - Updated to use new `rollout_is_weights` parameter
   - Removed legacy TIS logic

7. **Configuration Files** (4 files updated)
   - `verl/trainer/config/ppo_trainer.yaml`
   - `verl/trainer/config/ppo_megatron_trainer.yaml`
   - `verl/trainer/config/_generated_ppo_trainer.yaml`
   - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml`
   - Added default Rollout IS configuration section with explanatory comments

### Testing (2 files, 530 lines)

1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines)
   - Unit tests for `mismatch_helper.py`
   - Coverage for all aggregation levels (token, sequence, geometric)
   - Coverage for all bounding modes (truncate, clip)
   - Veto mechanism tests
   - Edge case handling (zeros, extremes, empty sequences)
   - Numerical stability verification
   - Metrics correctness validation

2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines)
   - Integration tests with PPO training loop
   - End-to-end workflow validation
   - Batch processing tests
   - Configuration validation
   - Metrics collection verification
   - Compatibility with distributed training

### Updated Recipes (2 files)

1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines)
   - Updated imports to use new framework

2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed)
   - Migrated from legacy TIS to new Rollout IS configuration
   - Updated documentation and comments

### Documentation Updates (2 files)

1. **`docs/examples/config.rst`** (~22 lines changed)
   - Updated configuration examples
   - Added Rollout IS section

2. **`docs/index.rst`** (+1 line)
   - Added link to Rollout IS migration guide

---

## Implementation Highlights

### Centralized Architecture

The new design follows a clean separation of concerns:

```
ray_trainer.py (driver)
    └─> compute_rollout_importance_weights_and_add_to_batch()
         └─> mismatch_helper.compute_rollout_importance_weights()
              ├─> Computes IS weights (token/sequence/geometric)
              ├─> Applies bounding (truncate/clip)
              ├─> Veto mechanism for outliers
              ├─> Computes IS metrics
              └─> Computes mismatch metrics (KL, PPL)
    └─> Conditionally adds weights to batch (if rollout_is=True)
    └─> Distributes batch to workers

actor workers (dp_actor, megatron_actor)
    └─> Receive batch with rollout_is_weights (if enabled)
    └─> Pass weights to policy loss function

core_algos.py
    └─> All policy loss functions accept rollout_is_weights
    └─> Apply weights if provided: pg_losses *= rollout_is_weights
```
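
The driver-side hook can be pictured roughly as follows (a schematic sketch; the batch keys and config attributes are placeholders, while the `compute_rollout_importance_weights` call mirrors the helper's documented arguments):

```python
def compute_rollout_importance_weights_and_add_to_batch(self, batch, metrics):
    """Schematic: compute IS weights once on the driver, log metrics, optionally attach weights."""
    from verl.trainer.ppo.mismatch_helper import compute_rollout_importance_weights

    weights_proto, is_metrics = compute_rollout_importance_weights(
        old_log_prob=batch["old_log_probs"],          # placeholder key: training-policy log-probs
        rollout_log_prob=batch["rollout_log_probs"],  # placeholder key: rollout-policy log-probs
        response_mask=batch["response_mask"],
        rollout_is_level=self.config.algorithm.rollout_is_level,
        rollout_is_mode=self.config.algorithm.rollout_is_mode,
        rollout_is_threshold=self.config.algorithm.rollout_is_threshold,
        rollout_is_veto_threshold=self.config.algorithm.rollout_is_veto_threshold,
    )
    metrics.update(is_metrics)  # every key already carries the "mismatch/" prefix

    if self.config.algorithm.rollout_is:
        # Only ship weights to the workers when correction is enabled
        batch["rollout_is_weights"] = weights_proto.batch["rollout_is_weights"]
    return batch, metrics
```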

### Key Design Decisions

1. **Centralized Computation**: IS weights computed once on driver, not
per worker
   - Reduces redundant computation
   - Ensures consistency across workers
   - Simplifies debugging and metrics collection

2. **Configuration in Algorithm**: Moved from actor config to algorithm
config
   - Better conceptual organization (algorithm-level concern, not worker-level)
   - Easier to manage and validate
   - Consistent with other algorithm parameters

3. **Two-Level Control**:
   - `rollout_is_threshold`: Enables/disables entire system (null = off)
   - `rollout_is`: Controls weight application (true = apply, false = metrics only)
   - Allows flexible monitoring and gradual rollout

4. **Metrics Consolidation**: Mismatch metrics computed within IS weight
computation
   - Eliminates duplicate computation
   - Reduces memory overhead
   - Maintains metric accuracy

5. **Universal PPO Support**: Single interface for all PPO variants
   - Minimal code changes required
   - Consistent behavior across algorithms
   - Easy to add new variants

---

## Migration Guide

### For Users of Legacy TIS

**Step 1: Update your configuration file**

```yaml
# OLD (remove this)
actor_rollout_ref:
  actor:
    tis_imp_ratio_cap: 2.0

# NEW (add this)
algorithm:
  rollout_is_threshold: 2.0  # Use same value as old tis_imp_ratio_cap
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate

# REQUIRED (add if not present)
actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

**Step 2: Monitor metrics**

The first time you run with the new configuration, check these metrics:
- `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size
- `mismatch/rollout_is_veto_fraction`: Should be < 5%
- `mismatch/rollout_is_mean`: Should be close to 1.0

**Step 3: Tune if needed**

If effective sample size is too low:
- Increase `rollout_is_threshold`
- Try `rollout_is_mode: clip` with appropriate lower bound
- Consider `rollout_is_level: sequence` for more aggressive correction

For detailed guidance, see `docs/advance/rollout_is_migration.md`.

### For New Users

Start with recommended defaults:

```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate

actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

Run the example script to see it in action:
```bash
bash examples/rollout_importance_sampling/run_with_rollout_is.sh
```

---

## Testing

### Unit Tests
- **289 lines** of comprehensive unit tests in `test_rollout_is.py`
- Covers all aggregation levels, bounding modes, and edge cases
- Validates numerical stability and correctness
- Fast execution (~1-2 seconds)

### Integration Tests
- **241 lines** of integration tests in `test_rollout_is_integration.py`
- End-to-end workflow with PPO training loop
- Distributed training compatibility
- Metrics collection validation
- Moderate execution time (~10-20 seconds)

### Running Tests
```bash
# Run all Rollout IS tests
pytest tests/trainer/ppo/test_rollout_is.py -v
pytest tests/trainer/ppo/test_rollout_is_integration.py -v

# Run specific test
pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v
```

---

## Metrics Reference

### Rollout IS Metrics (all prefixed with `mismatch/`)

| Metric | Description | Ideal Range |
|--------|-------------|-------------|
| `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch |
| `rollout_is_mean` | Mean IS weight | ~1.0 |
| `rollout_is_std` | Standard deviation of IS weights | Low variance |
| `rollout_is_p25` | 25th percentile | ~0.8-1.0 |
| `rollout_is_p50` | Median IS weight | ~1.0 |
| `rollout_is_p75` | 75th percentile | ~1.0-1.2 |
| `rollout_is_p95` | 95th percentile | < threshold |
| `rollout_is_p99` | 99th percentile | < threshold |
| `rollout_is_max` | Maximum weight | ≤ threshold |
| `rollout_is_min` | Minimum weight | ≥ lower threshold (clip mode) |
| `rollout_is_veto_fraction` | % sequences vetoed | < 5% |
| `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% |
| `rollout_is_clipped_fraction` | % tokens clipped (clip mode) | Variable |

### Mismatch Metrics (all prefixed with `mismatch/`)

| Metric | Description | What It Means |
|--------|-------------|---------------|
| `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) |
| `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences |
| `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy |
| `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy |
| `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty |
| `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch |
| `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch |
| `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch |
| `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch |
| `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity |
| `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity |

---

## Performance Impact

### Memory
- Minimal overhead: ~1-2% increase in peak memory usage
- Efficient log-space computation
- No large intermediate tensors

### Computation
- Negligible impact on training speed: < 1% overhead
- Centralized computation on driver (no per-worker redundancy)
- Optimized tensor operations

### Training Stability
- Significant improvement in stability when distribution mismatch exists
- Faster convergence in many scenarios
- Reduced risk of training collapse

---

## Breaking Changes

> [!IMPORTANT]
> This PR contains **BREAKING CHANGES** to the configuration API.

### Removed
- `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported

### Migration Required
All users of the legacy TIS implementation must update their
configuration files. See the migration guide above or
`docs/advance/rollout_is_migration.md` for detailed instructions.

### Backward Compatibility
- No backward compatibility with legacy TIS
- Configuration files with `tis_imp_ratio_cap` will raise validation
errors
- Affected recipes have been updated in this PR

---

## Pre-Submission Checklist

- [x] Search for similar PRs:
[https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling)
- [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI)
  - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework`
- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md)
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting)
- [x] Add/update
[documentation](https://github.com/volcengine/verl/tree/main/docs) (3
new docs, 2 updated)
- [x] Add unit and integration tests (530 lines of tests)
- [x] Once PR is ready for CI, send message in `ci-request` channel

---

## References

- **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse
from the Inference-Training
Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda)
- **Migration guide:** `docs/advance/rollout_is_migration.md`
- **Examples:** `examples/rollout_importance_sampling/`
- **Tests:** `tests/trainer/ppo/test_rollout_is*.py`

---------

Co-authored-by: Yan Bai <bayan@nvidia.com>
2025-10-13 17:05:29 +08:00
7f27789961 [fsdp,doc] refactor: rename warmup_style@FSDPOptimizerConfig -> lr_scheduler_type (#3739)
### What does this PR do?

> Rename `warmup_style` in FSDPOptimizerConfig to `lr_scheduler_type` to
align with the Hugging Face Trainer API.

The following pull request refactored the optimizer; however, the naming issue persists:
https://github.com/volcengine/verl/pull/3656
### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: weiqi.li <weiqi.li@bytedance.com>
2025-10-13 15:58:59 +08:00
e9ee6b39c6 [model] fix: qwen3vl models shape mismatch error with SP (#3735) 2025-10-13 13:09:10 +08:00
9d4554b931 [model] fix: qwen3vl training stuck with mixed text-image data (#3734) 2025-10-13 13:08:13 +08:00
71cf69e7ad [ci] feat: increase sft e2e time (#3738)
### What does this PR do?

- As title

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
2025-10-13 11:29:39 +08:00
39 changed files with 2340 additions and 86 deletions

View File

@ -91,7 +91,7 @@ jobs:
e2e_sft:
needs: setup
runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"]
timeout-minutes: 25 # Increase this timeout value as needed
timeout-minutes: 30 # Increase this timeout value as needed
env:
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}

View File

@ -0,0 +1,642 @@
# Rollout Importance Sampling - Migration Guide
Last updated: 10/11/2025.
This document provides a comprehensive overview of the Rollout Importance Sampling (IS) implementation merged from aiic_verl into verl.
## References
- **When Speed Kills Stability**: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
- **Off-policy RL**: https://fengyao.notion.site/off-policy-rl
## Overview
Rollout Importance Sampling corrects for distribution mismatch between:
- **Rollout policy**: e.g., vLLM with BFloat16
- **Training policy**: e.g., FSDP with FP32
This mismatch can lead to biased gradient estimates and unstable training. Rollout IS applies importance sampling weights to correct these biases.
## What Changed
### **Removed (Old Implementation)**
```yaml
# Old TIS configuration (REMOVED)
actor:
  tis_imp_ratio_cap: 2.0  # ❌ No longer supported
```
The old implementation:
- Only supported token-level truncate mode
- Had no metrics tracking
- Lacked numerical stability safeguards
- No configurability for different scenarios
### **Added (New Implementation)**
```yaml
# New Rollout IS configuration (all in algorithm config)
algorithm:
  # Main control: set threshold to enable (null = disabled)
  rollout_is_threshold: 2.0
  # Whether to apply weights to loss (default: false = metrics only)
  rollout_is: true
  rollout_is_threshold_lower: null  # Auto-reciprocal
  rollout_is_level: token
  rollout_is_mode: truncate
  rollout_is_veto_threshold: 1e-4

# REQUIRED: Enable log prob calculation
actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```
The new implementation:
- ✅ Three aggregation levels: token, sequence, geometric
- ✅ Two bounding modes: truncate, clip
- ✅ Dual threshold support (upper/lower)
- ✅ Veto mechanism for catastrophic outliers
- ✅ 30+ comprehensive metrics
- ✅ Log-space computation for numerical stability
- ✅ Memory-efficient implementation
## Files Modified
### **Core Implementation**
1. **NEW**: `verl/trainer/ppo/mismatch_helper.py`
- Contains `compute_rollout_importance_weights()` - main function
- Contains `compute_is_metrics()` - comprehensive metrics
2. **MODIFIED**: `verl/trainer/ppo/core_algos.py` (lines 962-991)
- Replaced old TIS implementation (lines 962-967)
- Added new rollout IS with metrics support
3. **MODIFIED**: `verl/workers/actor/dp_actor.py`
- Updated to use `rollout_is_threshold` instead of `tis_imp_ratio_cap`
- Collects and logs all rollout IS metrics
### **Configuration Files**
4. **MODIFIED**: `verl/trainer/config/algorithm.py` (lines 95-100)
- Added 6 new rollout IS parameters to `AlgoConfig`
5. **MODIFIED**: `verl/workers/config/actor.py` (lines 110-115)
- Added 6 new rollout IS parameters to `ActorConfig`
6. **MODIFIED**: `verl/trainer/config/actor/actor.yaml` (lines 77-89)
- Added rollout IS configuration section
7. **MODIFIED**: `verl/trainer/config/ppo_trainer.yaml` (lines 116-133)
- Added rollout IS to algorithm config
### **Documentation**
8. **MODIFIED**: `docs/examples/config.rst`
- Updated actor config with rollout IS parameters
- Updated algorithm config with rollout IS parameters
- Added detailed parameter descriptions
### **Example Scripts**
9. **MODIFIED**: `recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`
- Updated from `tis_imp_ratio_cap` to rollout IS parameters
- Added comprehensive comments
10. **NEW**: `examples/rollout_importance_sampling/README.md`
- Comprehensive guide with usage patterns
- Troubleshooting section
- Performance considerations
11. **NEW**: `examples/rollout_importance_sampling/run_with_rollout_is.sh`
- Basic example with token-level truncate
### **Tests**
12. **NEW**: `tests/trainer/ppo/test_rollout_is.py`
- Unit tests for rollout IS functionality
13. **NEW**: `tests/trainer/ppo/test_rollout_is_integration.py`
- Integration tests with PPO
## Configuration Parameters
### `algorithm.rollout_is_threshold` (float or null)
**Main on/off switch.** Upper threshold for IS weights.
- `null` = disabled (no computation, no metrics)
- `float` value (e.g., 2.0) = enabled (compute weights and metrics)
### `algorithm.rollout_is` (bool)
Whether to apply IS weights to policy loss. Default: `False`
- `true` = apply weights to loss (full IS correction)
- `false` = compute metrics only (useful for monitoring before enabling)
**Recommended threshold ranges:**
- Token level: 1.5 - 5.0
- Sequence level: 2.0 - 10.0
- Geometric level: 1.0002 - 1.001
### `algorithm.rollout_is_threshold_lower` (float or null)
Lower threshold for IS weights. If `null`, defaults to 1/upper (reciprocal).
### `algorithm.rollout_is_level` (str)
Aggregation level for IS weights:
- `"token"`: Per-token ratios
- `"sequence"`: Product of ratios
- `"geometric"`: Geometric mean (experimental)
### `algorithm.rollout_is_mode` (str)
Bounding mode:
- `"truncate"`: Cap weights at upper threshold only
- `"clip"`: Zero out weights outside [lower, upper]
### `algorithm.rollout_is_veto_threshold` (float)
Per-token veto threshold. If any token ratio < this, entire sequence is rejected.
Default: `1e-4` (ratio 10,000x off)
## Migration Steps
### Step 1: Update Your Configuration
**Before (Old):**
```yaml
actor_rollout_ref:
  actor:
    tis_imp_ratio_cap: 2.0
  rollout:
    calculate_log_probs: true
```
**After (New):**
```yaml
algorithm:
  rollout_is_threshold: 2.0  # Main control
  rollout_is: true           # Apply to loss (default: false)
  rollout_is_level: token
  rollout_is_mode: truncate

actor_rollout_ref:
  rollout:
    calculate_log_probs: true  # Still required!
```
### Step 2: Monitor New Metrics
All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appears as `mismatch/rollout_is_mean` in logs.
#### **Core IS Weight Metrics**
- **`rollout_is_mean`**: Mean importance sampling weight across all valid tokens
- **Ideal value**: Close to 1.0 (indicates minimal distribution mismatch)
- **Warning**: < 0.5 or > 2.0 suggests significant policy mismatch
- **`rollout_is_std`**: Standard deviation of IS weights
- **Ideal value**: < 0.5 for stable training
- **Warning**: > 1.0 indicates high variance, may need tighter thresholds
- **`rollout_is_min`**: Minimum IS weight observed
- Shows the most underweighted token/sequence
- **`rollout_is_max`**: Maximum IS weight observed (before clipping)
- Shows the most overweighted token/sequence
- Compare with `rollout_is_threshold` to see truncation impact
#### **Percentile Metrics**
- **`rollout_is_p25`**: 25th percentile of IS weights
- **`rollout_is_p50`**: Median IS weight (50th percentile)
- Should be close to `rollout_is_mean` if distribution is symmetric
- **`rollout_is_p75`**: 75th percentile of IS weights
- **`rollout_is_p95`**: 95th percentile of IS weights
- Use to detect outliers
- **`rollout_is_p99`**: 99th percentile of IS weights
- Should be close to `rollout_is_threshold` if truncation is working
#### **Effective Sample Size**
- **`rollout_is_eff_sample_size`**: Effective sample size after IS weighting
- **Formula**: `1 / mean(weights²)` where weights are normalized
- **Range**: 0.0 to 1.0 (as fraction of original batch)
- **Ideal value**: > 0.5 (retaining at least 50% effective samples)
- **Warning**: < 0.3 means high variance, losing too many effective samples
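
A minimal sketch of that computation, assuming `weights` holds per-token IS weights (illustrative only):

```python
import torch

def effective_sample_size(weights, response_mask):
    """ESS as a fraction of the batch: 1 / mean(w_norm**2), with weights normalized to mean 1."""
    w = weights[response_mask.bool()]
    w_norm = w / w.mean().clamp(min=1e-8)
    return 1.0 / w_norm.pow(2).mean().clamp(min=1e-8)
```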
#### **Veto Mechanism Metrics**
- **`rollout_is_veto_fraction`**: Fraction of sequences rejected by veto mechanism
- **Ideal value**: < 0.05 (less than 5% vetoed)
- **Warning**: > 0.1 suggests policies are too different or numerical issues
- **`rollout_is_catastrophic_token_fraction`**: Fraction of tokens below veto threshold
- Identifies problematic tokens before sequence-level veto
- **Warning**: > 0.01 indicates widespread distribution issues
#### **Threshold Exceedance Metrics**
- **`rollout_is_ratio_fraction_high`**: Fraction of weights exceeding upper threshold
- Shows how often truncation/clipping occurs on high end
- **Ideal value**: < 0.1 (most weights within bounds)
- **`rollout_is_ratio_fraction_low`**: Fraction of weights below lower threshold
- Shows how often clipping occurs on low end (clip mode only)
- **Ideal value**: < 0.1
#### **Sequence-Level Metrics** (for sequence/geometric modes)
- **`rollout_is_seq_mean`**: Mean IS weight at sequence level
- Should match `rollout_is_mean` for sequence-level aggregation
- **`rollout_is_seq_std`**: Standard deviation of sequence-level IS weights
- **`rollout_is_seq_min`**: Minimum sequence-level IS weight
- **`rollout_is_seq_max`**: Maximum sequence-level IS weight
- **`rollout_is_seq_max_deviation`**: Maximum absolute deviation from 1.0 at sequence level
- **Ideal value**: < 1.0
- Shows worst-case sequence mismatch
- **`rollout_is_seq_fraction_high`**: Fraction of sequences exceeding upper threshold
- **`rollout_is_seq_fraction_low`**: Fraction of sequences below lower threshold
#### **Clipping Metrics** (clip mode only)
- **`rollout_is_clipped_fraction`**: Fraction of tokens clipped (set to zero)
- **Ideal value**: < 0.1
- **Warning**: > 0.3 means losing too much data
- **`rollout_is_seq_clipped_fraction`**: Fraction of sequences with at least one clipped token
- Shows sequence-level impact of clipping
#### **Distribution Mismatch Metrics** (Training vs Rollout Policy)
- **`mismatch_training_ppl`**: Perplexity of training policy (e.g., FSDP FP32)
- **Formula**: `exp(-mean(log_probs))`
- Lower is better (model is more confident)
- **`mismatch_rollout_ppl`**: Perplexity of rollout policy (e.g., vLLM BF16)
- Should be close to `mismatch_training_ppl` if policies match well
- **`mismatch_ppl_ratio`**: Ratio of training PPL to rollout PPL
- **Formula**: `exp(mean(log(training_ppl / rollout_ppl)))`
- **Ideal value**: Close to 1.0
- **Meaning**: > 1.0 means training is less confident than rollout
- **`mismatch_training_log_ppl`**: Log perplexity of training policy
- Useful for identifying trends (linear scale)
- **`mismatch_rollout_log_ppl`**: Log perplexity of rollout policy
- **`mismatch_log_ppl_diff`**: Mean difference in log perplexities
- **Formula**: `mean(log_ppl_rollout - log_ppl_training)`
- **Ideal value**: Close to 0.0
- Sign indicates which policy is more confident
- **`mismatch_log_ppl_abs_diff`**: Mean absolute log perplexity difference
- Magnitude of mismatch regardless of direction
- **`mismatch_log_ppl_diff_max`**: Maximum log perplexity difference across sequences
- Identifies worst-case sequence
- **`mismatch_log_ppl_diff_min`**: Minimum log perplexity difference across sequences
- **`mismatch_kl`**: KL divergence KL(π_rollout || π_training)
- **Formula**: `mean(log_prob_rollout - log_prob_training)`
- **Ideal value**: Close to 0.0 (policies match)
- **Warning**: > 0.1 indicates significant mismatch
- **Note**: Can be negative (rollout is less confident)
- **`mismatch_k3_kl`**: K3 KL estimator
- **Formula**: `mean(exp(log_ratio) - log_ratio - 1)`
- More stable for small KL values
- Always non-negative
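
Following the formulas above, both estimators can be computed from the two log-probability tensors; a small illustrative sketch:

```python
import torch

def mismatch_kl_metrics(rollout_log_prob, train_log_prob, response_mask):
    """mismatch_kl and mismatch_k3_kl over valid tokens, per the formulas above."""
    mask = response_mask.bool()
    log_ratio = (rollout_log_prob - train_log_prob)[mask]     # log_prob_rollout - log_prob_training
    kl = log_ratio.mean()                                      # mismatch_kl (can be negative)
    k3_kl = (torch.exp(log_ratio) - log_ratio - 1).mean()      # mismatch_k3_kl, always non-negative
    return {"mismatch_kl": kl.item(), "mismatch_k3_kl": k3_kl.item()}
```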
#### **Example: Accessing Metrics in Code**
```python
# Metrics are returned from compute_rollout_importance_weights
from verl.trainer.ppo.mismatch_helper import compute_rollout_importance_weights
weights_proto, metrics = compute_rollout_importance_weights(
    old_log_prob=training_log_probs,       # from training policy
    rollout_log_prob=rollout_log_probs,    # from rollout policy
    response_mask=response_mask,
    rollout_is_level="token",
    rollout_is_mode="truncate",
    rollout_is_threshold=2.0,
    rollout_is_veto_threshold=1e-4,
)

# All metrics have 'mismatch/' prefix
print(f"Mean IS weight: {metrics['mismatch/rollout_is_mean']:.3f}")
print(f"Effective sample size: {metrics['mismatch/rollout_is_eff_sample_size']:.3f}")
print(f"Veto fraction: {metrics['mismatch/rollout_is_veto_fraction']:.3f}")
print(f"KL divergence: {metrics['mismatch/mismatch_kl']:.3f}")

# Check for warning conditions
if metrics['mismatch/rollout_is_mean'] < 0.5 or metrics['mismatch/rollout_is_mean'] > 2.0:
    print("⚠️ Warning: Mean IS weight far from 1.0, significant policy mismatch detected")
if metrics['mismatch/rollout_is_eff_sample_size'] < 0.3:
    print("⚠️ Warning: Low effective sample size, high variance in IS weights")
if metrics['mismatch/rollout_is_veto_fraction'] > 0.1:
    print("⚠️ Warning: High veto fraction, policies may be too different")
```
#### **Example: Monitoring Metrics During Training**
```python
# In your training loop
for epoch in range(num_epochs):
    for batch_idx, batch in enumerate(dataloader):
        # ... rollout phase ...

        # Compute IS weights and get metrics
        weights_proto, metrics = compute_rollout_importance_weights(
            old_log_prob=batch.old_log_prob,
            rollout_log_prob=batch.rollout_log_prob,
            response_mask=batch.response_mask,
            rollout_is_level=config.rollout_is_level,
            rollout_is_mode=config.rollout_is_mode,
            rollout_is_threshold=config.rollout_is_threshold,
            rollout_is_veto_threshold=config.rollout_is_veto_threshold,
        )

        # Log to tensorboard/wandb
        for metric_name, metric_value in metrics.items():
            logger.log_scalar(metric_name, metric_value, step=global_step)

        # Use IS weights in training
        is_weights = weights_proto.batch["rollout_is_weights"]
        # ... apply weights to policy gradient ...
```
#### **Example: Conditional Alerting Based on Metrics**
```python
def check_rollout_is_health(metrics, config):
    """Check if rollout IS metrics indicate healthy training."""
    warnings = []

    # Check mean IS weight
    mean_weight = metrics['mismatch/rollout_is_mean']
    if mean_weight < 0.5 or mean_weight > 2.0:
        warnings.append(f"Mean IS weight {mean_weight:.3f} is far from 1.0")

    # Check effective sample size
    ess = metrics['mismatch/rollout_is_eff_sample_size']
    if ess < 0.3:
        warnings.append(f"Effective sample size {ess:.3f} is too low")

    # Check veto fraction
    veto_frac = metrics['mismatch/rollout_is_veto_fraction']
    if veto_frac > 0.1:
        warnings.append(f"Veto fraction {veto_frac:.3f} is too high")

    # Check variance
    std = metrics['mismatch/rollout_is_std']
    if std > 1.0:
        warnings.append(f"IS weight std {std:.3f} is too high")

    # Check KL divergence
    kl = metrics['mismatch/mismatch_kl']
    if abs(kl) > 0.1:
        warnings.append(f"KL divergence {kl:.3f} indicates significant mismatch")

    if warnings:
        print("⚠️ Rollout IS Health Warnings:")
        for warning in warnings:
            print(f"  - {warning}")
        return False
    else:
        print("✅ Rollout IS metrics look healthy")
        return True

# Use in training
_, metrics = compute_rollout_importance_weights(...)
is_healthy = check_rollout_is_health(metrics, config)
if not is_healthy:
    # Consider adjusting config or investigating issues
    print("Consider:")
    print("  - Tightening rollout_is_threshold")
    print("  - Switching to geometric aggregation level")
    print("  - Checking if rollout and training policies are too different")
```
### Step 3: Test Your Training
Start with the basic token-level truncate configuration:
```bash
bash examples/rollout_importance_sampling/run_with_rollout_is.sh
```
Monitor metrics for 1-2 epochs before adjusting parameters.
## Configuration Examples
### Example 1: Full IS Correction
```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true  # Apply weights to loss
  rollout_is_level: token
  rollout_is_mode: truncate
```
### Example 2: Metrics Only (Monitoring Mode)
```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: false  # Compute metrics, don't apply weights
  rollout_is_level: token
  rollout_is_mode: truncate
```
### Example 3: Geometric Mean with Clip
```yaml
algorithm:
  rollout_is_threshold: 1.0002
  rollout_is: true
  rollout_is_threshold_lower: 0.9998
  rollout_is_level: geometric
  rollout_is_mode: clip
```
### Example 4: Asymmetric Thresholds
```yaml
algorithm:
  rollout_is_threshold: 5.0
  rollout_is: true
  rollout_is_threshold_lower: 0.8
  rollout_is_level: token
  rollout_is_mode: clip
```
## Troubleshooting
### Issue: High variance in IS weights
**Symptoms:** `rollout_is_std` > 1.0, `rollout_is_eff_sample_size` < 0.3
**Solutions:**
1. Switch from `sequence` to `geometric` level
2. Tighten thresholds
3. Verify rollout and training aren't too different
### Issue: Too many sequences vetoed
**Symptoms:** `rollout_is_veto_fraction` > 0.1
**Solutions:**
1. Relax veto threshold: `rollout_is_veto_threshold: 1e-3`
2. Check for numerical issues in log prob computation
3. Verify policies aren't completely different
### Issue: Mean IS weight far from 1.0
**Symptoms:** `rollout_is_mean` < 0.5 or > 2.0
**Solutions:**
1. Verify `calculate_log_probs=True` is set
2. Check rollout_log_probs are correctly passed
3. Check for systematic bias
### Debugging: Visualizing Metrics
**Example: Plot IS weight distribution**
```python
import matplotlib.pyplot as plt
import numpy as np
def plot_is_metrics(metrics_history):
    """Plot rollout IS metrics over training steps."""
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Plot 1: Mean IS weight over time
    axes[0, 0].plot(metrics_history['mismatch/rollout_is_mean'])
    axes[0, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
    axes[0, 0].set_title('Mean IS Weight')
    axes[0, 0].set_xlabel('Step')
    axes[0, 0].legend()

    # Plot 2: Effective sample size
    axes[0, 1].plot(metrics_history['mismatch/rollout_is_eff_sample_size'])
    axes[0, 1].axhline(y=0.5, color='g', linestyle='--', label='Good')
    axes[0, 1].axhline(y=0.3, color='r', linestyle='--', label='Warning')
    axes[0, 1].set_title('Effective Sample Size')
    axes[0, 1].set_xlabel('Step')
    axes[0, 1].legend()

    # Plot 3: Veto fraction
    axes[0, 2].plot(metrics_history['mismatch/rollout_is_veto_fraction'])
    axes[0, 2].axhline(y=0.1, color='r', linestyle='--', label='Warning')
    axes[0, 2].set_title('Veto Fraction')
    axes[0, 2].set_xlabel('Step')
    axes[0, 2].legend()

    # Plot 4: IS weight distribution (latest step)
    latest_idx = -1
    percentiles = [25, 50, 75, 95, 99]
    values = [metrics_history[f'mismatch/rollout_is_p{p}'][latest_idx] for p in percentiles]
    axes[1, 0].bar([f'p{p}' for p in percentiles], values)
    axes[1, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
    axes[1, 0].set_title('IS Weight Percentiles (Latest)')
    axes[1, 0].legend()

    # Plot 5: KL divergence over time
    axes[1, 1].plot(metrics_history['mismatch/mismatch_kl'], label='KL')
    axes[1, 1].plot(metrics_history['mismatch/mismatch_k3_kl'], label='K3 KL')
    axes[1, 1].axhline(y=0, color='g', linestyle='--', alpha=0.3)
    axes[1, 1].set_title('KL Divergence')
    axes[1, 1].set_xlabel('Step')
    axes[1, 1].legend()

    # Plot 6: PPL ratio over time
    axes[1, 2].plot(metrics_history['mismatch/mismatch_ppl_ratio'])
    axes[1, 2].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
    axes[1, 2].set_title('PPL Ratio (Training/Rollout)')
    axes[1, 2].set_xlabel('Step')
    axes[1, 2].legend()

    plt.tight_layout()
    plt.savefig('rollout_is_metrics.png', dpi=150)
    print("Saved plot to rollout_is_metrics.png")
```
**Example: Metric collection during training**
```python
# Collect metrics over time
metrics_history = {
    'mismatch/rollout_is_mean': [],
    'mismatch/rollout_is_eff_sample_size': [],
    'mismatch/rollout_is_veto_fraction': [],
    'mismatch/rollout_is_p25': [],
    'mismatch/rollout_is_p50': [],
    'mismatch/rollout_is_p75': [],
    'mismatch/rollout_is_p95': [],
    'mismatch/rollout_is_p99': [],
    'mismatch/mismatch_kl': [],
    'mismatch/mismatch_k3_kl': [],
    'mismatch/mismatch_ppl_ratio': [],
}

# In training loop
for step in range(num_steps):
    # ... compute IS weights ...
    _, metrics = compute_rollout_importance_weights(...)

    # Store metrics
    for key in metrics_history.keys():
        if key in metrics:
            metrics_history[key].append(metrics[key])

    # Plot every 100 steps
    if step % 100 == 0:
        plot_is_metrics(metrics_history)
```
## Performance Impact
- **Memory overhead**: ~1% of model memory
- **Computational overhead**: 1-3% depending on level
- **Training stability**: Significantly improved when mismatch exists
## Backward Compatibility
**The old `tis_imp_ratio_cap` parameter is completely removed.** There is no backward compatibility mode.
All scripts and configurations must be updated to use the new rollout IS parameters.
## Testing
Run the test suite to verify everything works:
```bash
# Basic unit tests
python tests/trainer/ppo/test_rollout_is.py
# Integration tests (if pytest is available)
pytest tests/trainer/ppo/test_rollout_is_integration.py -v
```
Expected output: All tests pass ✓
## Additional Resources
- **Implementation**: `verl/trainer/ppo/mismatch_helper.py`
- **Examples**: `examples/rollout_importance_sampling/`
- **DAPO Example**: `recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`
## Summary
The new Rollout Importance Sampling implementation provides:
- ✅ More robust handling of distribution mismatch
- ✅ Better numerical stability
- ✅ Comprehensive metrics for monitoring
- ✅ Flexibility for different scenarios
- ✅ Memory-efficient computation
Migration is straightforward: replace `tis_imp_ratio_cap` with the new `rollout_is_*` parameters in the `algorithm` config section.

View File

@ -118,7 +118,13 @@ Actor/Rollout/Reference Policy
clip_ratio: 0.2
entropy_coeff: 0.0
use_kl_loss: False # True for GRPO
tis_imp_ratio_cap: -1 # set to positive values for Truncated Importance Sampling (requires setting `rollout.calculate_log_probs` as True)
# Rollout Importance Sampling (corrects distribution mismatch between rollout and training)
rollout_is: False # Enable IS correction
rollout_is_threshold: null # Upper threshold for IS weights (null to disable)
rollout_is_threshold_lower: null # Lower threshold (null = auto 1/upper)
rollout_is_level: token # Aggregation: token/sequence/geometric
rollout_is_mode: truncate # Bounding: truncate/clip
rollout_is_veto_threshold: 1e-4 # Catastrophic outlier threshold
use_torch_compile: True # False to disable torch compile
kl_loss_coef: 0.001 # for grpo
kl_loss_type: low_var_kl # for grpo
@ -132,7 +138,7 @@ Actor/Rollout/Reference Policy
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: 0.0 # only used with cosine lr scheduler, default to 0.0
num_cycles: 0.5 # only used with cosine lr scheduler, default to 0.5
warmup_style: constant # select from constant/cosine
lr_scheduler_type: constant # select from constant/cosine
total_training_steps: -1 # must be override by program
fsdp_config:
wrap_policy:
@ -415,7 +421,7 @@ ____________________________________________________
Notice that there are some differences in APIs between Megatron optimizer and FSDP optimizer.
- Megatron optimizer scheduler names the period after lr_warmup as lr_decay_steps, so the ``warmup_style`` actually means the style of lr decay after warmup.
- Megatron optimizer scheduler names the period after lr_warmup as lr_decay_steps, so the ``lr_scheduler_type`` actually means the style of lr decay after warmup.
- Megatron optimizer also support weight decay decay mechanism
- ``use_checkpoint_opt_param_scheduler`` determines whether to use the checkpoint optimizer parameter scheduler. If set to True, the optimizer parameter scheduler will be saved in the checkpoint and loaded from the checkpoint during resuming training.
@ -498,6 +504,13 @@ Algorithm
kl_coef: 0.005
horizon: 10000
target_kl: 0.1
# Rollout Importance Sampling
rollout_is: False
rollout_is_threshold: null
rollout_is_threshold_lower: null
rollout_is_level: token
rollout_is_mode: truncate
rollout_is_veto_threshold: 1e-4
- ``gamma``: discount factor
- ``lam``: Trade-off between bias and variance in the GAE estimator
@ -510,6 +523,13 @@ Algorithm
- ``kl_coef``: The (initial) coefficient of in-reward kl_penalty. Default is 0.001.
- ``type``: 'fixed' for FixedKLController and 'adaptive' for AdaptiveKLController.
- ``horizon`` and ``target_kl``: See source code of AdaptiveKLController for details.
- ``rollout_is``: Whether to enable rollout importance sampling correction. Default is False.
- ``rollout_is_threshold``: Upper threshold for IS weights. Set to ``null`` to disable IS completely.
- ``rollout_is_threshold_lower``: Lower threshold for IS weights. If ``null``, defaults to reciprocal of upper (1/upper).
- ``rollout_is_level``: Aggregation level: ``token`` (biased), ``sequence`` (unbiased), or ``geometric`` (experimental).
- ``rollout_is_mode``: Bounding mode: ``truncate`` (cap upper only) or ``clip`` (zero outside bounds).
- ``rollout_is_veto_threshold``: Per-token veto threshold for catastrophic outliers. Default is 1e-4.
Note: Rollout IS requires setting ``actor_rollout_ref.rollout.calculate_log_probs=True``.
Trainer
~~~~~~~

View File

@ -121,6 +121,7 @@ verl is fast with:
examples/sandbox_fusion_example
advance/rollout_trace.rst
advance/rollout_skip.rst
advance/rollout_is_migration.md
advance/one_step_off
advance/agent_loop

View File

@ -0,0 +1,242 @@
# Rollout Importance Sampling (IS) Examples
This directory contains examples and documentation for using Rollout Importance Sampling to correct distribution mismatch between rollout and training policies.
**References:**
- When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
- Off-policy RL: https://fengyao.notion.site/off-policy-rl
## Overview
Rollout Importance Sampling corrects for distribution mismatch when:
1. **Rollout generation** uses one policy (e.g., vLLM with BFloat16)
2. **Training** uses another policy (e.g., FSDP with FP32)
3. This mismatch leads to biased gradient estimates
## Quick Start
### Basic Configuration
```yaml
algorithm:
  # Main control: set threshold to enable (null = disabled)
  rollout_is_threshold: 2.0
  # Whether to apply weights to policy loss (true) or just compute metrics (false)
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate

# IMPORTANT: Must enable log prob calculation
actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```
### Running the Example
```bash
# Basic example with token-level truncate
bash examples/rollout_importance_sampling/run_with_rollout_is.sh
```
## Configuration Options
### Aggregation Levels (`rollout_is_level`)
| Level | Properties | Threshold Range |
|-------|-----------|-----------------|
| **token** | Per-token | 1.5 - 5.0 |
| **sequence** | Per-sequence | 2.0 - 10.0 |
| **geometric** | Geometric mean | 1.0002 - 1.001 |
### Bounding Modes (`rollout_is_mode`)
| Mode | Behavior |
|------|----------|
| **truncate** | Cap weights at upper threshold only |
| **clip** | Zero out weights outside [lower, upper] |
### Key Parameters
- `rollout_is_threshold`: Upper threshold for IS weights (null = disabled, float = enabled). **Main on/off switch.**
- `rollout_is`: Whether to apply weights to loss (true) or just compute metrics (false). Default: false.
- `rollout_is_threshold_lower`: Lower threshold (null = auto 1/upper)
- `rollout_is_veto_threshold`: Catastrophic outlier threshold (default: 1e-4)
## Configuration Examples
### Example 1: Full IS Correction (Apply Weights)
```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true  # Apply to loss
  rollout_is_level: token
  rollout_is_mode: truncate
  rollout_is_veto_threshold: 1e-4
```
### Example 2: Metrics Only (No Weight Application)
```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: false  # Compute metrics only, don't apply to loss
  rollout_is_level: token
  rollout_is_mode: truncate
```
### Example 3: Geometric Mean with Clip
```yaml
algorithm:
  rollout_is_threshold: 1.0002
  rollout_is: true
  rollout_is_threshold_lower: 0.9998
  rollout_is_level: geometric
  rollout_is_mode: clip
  rollout_is_veto_threshold: 1e-4
```
### Example 4: Sequence-level with Truncate
```yaml
algorithm:
  rollout_is_threshold: 5.0
  rollout_is: true
  rollout_is_threshold_lower: null  # Auto-reciprocal: 0.2
  rollout_is_level: sequence
  rollout_is_mode: truncate
  rollout_is_veto_threshold: 1e-4
```
### Example 5: Asymmetric Thresholds
```yaml
algorithm:
  rollout_is_threshold: 5.0
  rollout_is: true
  rollout_is_threshold_lower: 0.8
  rollout_is_level: token
  rollout_is_mode: clip
```
## Monitoring Metrics
Key metrics to watch (all prefixed with `mismatch/` in logs):
### Health Indicators
- `rollout_is_mean`: Mean IS weight across sequences
- `rollout_is_eff_sample_size`: Effective sample size after weighting
- `rollout_is_veto_fraction`: Fraction of sequences vetoed
### Distribution Metrics
- `rollout_is_max`, `rollout_is_min`: Weight extremes
- `rollout_is_std`: Standard deviation
- `rollout_is_p50`, `rollout_is_p95`, `rollout_is_p99`: Percentiles
### Diagnostic Metrics
- `rollout_is_ratio_fraction_high`: Fraction exceeding upper threshold
- `rollout_is_ratio_fraction_low`: Fraction below lower threshold
- `rollout_is_catastrophic_token_fraction`: Catastrophic tokens detected
### Mismatch Metrics (Training vs Rollout Policy)
These metrics help diagnose the distribution mismatch between rollout and training policies:
**Perplexity Metrics:**
- `mismatch_training_ppl`: Perplexity of training policy
- `mismatch_rollout_ppl`: Perplexity of rollout policy
- `mismatch_ppl_ratio`: Ratio of training PPL to rollout PPL
- `mismatch_log_ppl_diff`: Log perplexity difference
**KL Divergence Metrics:**
- `mismatch_kl`: KL divergence KL(π_rollout || π_training)
- `mismatch_k3_kl`: K3 KL estimator
## Troubleshooting
### Issue: High Variance in IS Weights
**Symptoms**: `rollout_is_std` > 1.0, `rollout_is_eff_sample_size` < 0.3
**Solutions**:
1. Switch from `sequence` to `geometric` level
2. Tighten thresholds
3. Check if rollout and training are too different
### Issue: Too Many Sequences Vetoed
**Symptoms**: `rollout_is_veto_fraction` > 0.1
**Solutions**:
1. Relax veto threshold: `rollout_is_veto_threshold: 1e-3`
2. Check for numerical issues in log prob computation
3. Verify rollout and training policies aren't completely different
### Issue: Mean IS Weight Far from 1.0
**Symptoms**: `rollout_is_mean` < 0.5 or > 2.0
**Solutions**:
1. Check that `calculate_log_probs=True` is set
2. Verify rollout_log_probs are correctly passed
3. Check for systematic bias in rollout vs training
### Issue: Too Much Data Discarded (Clip Mode)
**Symptoms**: `rollout_is_clipped_fraction` > 0.5
**Solutions**:
1. Widen thresholds
2. Switch to `truncate` mode
3. Use `geometric` level for better stability
## Performance Considerations
### Memory Usage
- Rollout IS adds minimal memory overhead (~1% of model memory)
- Log-space computation prevents numerical overflow
### Computational Cost
- Token-level: ~1-2% overhead
- Sequence-level: ~2-3% overhead
- Geometric: ~2-3% overhead
## Advanced Topics
### Dual Thresholds
Specify both upper and lower explicitly:
```yaml
rollout_is_threshold: 2.0 # Upper
rollout_is_threshold_lower: 0.5 # Lower, set explicitly (need not equal 1/upper)
```
Or use auto-reciprocal:
```yaml
rollout_is_threshold: 2.0 # Upper = 2.0, Lower = 0.5 (auto)
rollout_is_threshold_lower: null
```
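
A one-line sketch of this fallback (illustrative, not the exact verl code):

```python
def resolve_lower_threshold(upper, lower=None):
    """If no lower threshold is given, fall back to the reciprocal of the upper one."""
    return lower if lower is not None else 1.0 / upper

resolve_lower_threshold(2.0)       # -> 0.5 (auto-reciprocal)
resolve_lower_threshold(2.0, 0.8)  # -> 0.8 (explicit)
```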
### Veto Mechanism
The veto mechanism zeros out entire sequences containing catastrophic outliers:
- If any token has ratio < `rollout_is_veto_threshold`, the entire sequence is rejected
- This prevents extreme outliers from dominating training
- Default threshold: 1e-4 (ratio 10,000x off)
- Set to `null` to disable: `rollout_is_veto_threshold: null`
## Examples
See the script in this directory:
- `run_with_rollout_is.sh`: Basic example with token-level truncate mode
## References
- Implementation: `verl/trainer/ppo/mismatch_helper.py`
- Core algorithm: `verl/trainer/ppo/core_algos.py`
- Paper: "Your Efficient RL Framework Secretly Brings You Off-Policy RL Training"

View File

@ -0,0 +1,99 @@
#!/usr/bin/env bash
# Example: Basic PPO training with Rollout Importance Sampling
# This demonstrates the standard setup for correcting distribution mismatch
set -xeuo pipefail
# ==============================================================================
# Rollout Importance Sampling Configuration
# ==============================================================================
# Main control: Upper threshold for IS weights (null = disabled, float = enabled)
rollout_is_threshold=2.0
# Whether to apply IS weights to policy loss
# true = apply weights to loss, false = compute metrics only
rollout_is=true
# Lower threshold (null = auto-reciprocal, i.e., 1/upper = 0.5)
rollout_is_threshold_lower=null
# Aggregation level: token | sequence | geometric (experimental)
rollout_is_level=token
# Bounding mode: truncate (cap upper) | clip (zero outside bounds)
rollout_is_mode=truncate
# Catastrophic outlier veto threshold
rollout_is_veto_threshold=1e-4
# ==============================================================================
# Model and Data Configuration
# ==============================================================================
MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen2.5-7B"}
TRAIN_FILE=${TRAIN_FILE:-"data/train.parquet"}
TEST_FILE=${TEST_FILE:-"data/test.parquet"}
max_prompt_length=512
max_response_length=1024
# ==============================================================================
# Training Configuration
# ==============================================================================
train_batch_size=128
ppo_mini_batch_size=32
ppo_epochs=1
learning_rate=5e-7
# ==============================================================================
# Algorithm Configuration
# ==============================================================================
adv_estimator=gae
gamma=1.0
lam=0.95
# ==============================================================================
# Launch Training
# ==============================================================================
python3 -m verl.trainer.main_ppo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_batch_size} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.gamma=${gamma} \
algorithm.lam=${lam} \
algorithm.rollout_is=${rollout_is} \
algorithm.rollout_is_threshold=${rollout_is_threshold} \
algorithm.rollout_is_threshold_lower=${rollout_is_threshold_lower} \
algorithm.rollout_is_level=${rollout_is_level} \
algorithm.rollout_is_mode=${rollout_is_mode} \
algorithm.rollout_is_veto_threshold=${rollout_is_veto_threshold} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=${learning_rate} \
actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \
actor_rollout_ref.actor.ppo_epochs=${ppo_epochs} \
actor_rollout_ref.rollout.calculate_log_probs=True \
actor_rollout_ref.rollout.name=vllm \
trainer.logger='["console","wandb"]' \
trainer.project_name="rollout_is_example" \
trainer.experiment_name="basic_token_truncate" \
trainer.total_epochs=10
echo "Training completed!"
echo ""
echo "Rollout IS Configuration:"
echo " - Threshold: ${rollout_is_threshold}"
echo " - Apply to loss: ${rollout_is}"
echo " - Level: ${rollout_is_level}"
echo " - Mode: ${rollout_is_mode}"
echo ""
echo "Monitor these key metrics in wandb:"
echo " - mismatch/rollout_is_mean (should be ~1.0)"
echo " - mismatch/rollout_is_eff_sample_size (should be >0.5)"
echo " - mismatch/rollout_is_veto_fraction (should be <0.1)"

View File

@ -51,7 +51,7 @@ actor_rollout_ref:
lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
lr_scheduler_type: constant # select from constant/cosine
total_training_steps: -1 # must be override by program
fsdp_config:
wrap_policy:
@ -105,7 +105,7 @@ critic:
lr: 1e-5
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
lr_scheduler_type: constant # select from constant/cosine
total_training_steps: -1 # must be override by program
model:
path: ~/models/deepseek-llm-7b-chat

View File

@ -304,6 +304,11 @@ class RayDAPOTrainer(RayPPOTrainer):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
# Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer)
batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
# IS and mismatch metrics already have mismatch/ prefix
metrics.update(is_metrics)
with marked_timer("adv", timing_raw, "brown"):
# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)

View File

@ -1,8 +1,13 @@
#!/usr/bin/env bash
set -xeuo pipefail
# Rollout Importance Sampling Example
# References:
# - When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
# - Off-policy RL: https://fengyao.notion.site/off-policy-rl
project_name='DAPO'
exp_name='DAPO-Qwen2.5-32B-TIS' # Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl
exp_name='DAPO-Qwen2.5-32B-RolloutIS' # Rollout Importance Sampling
adv_estimator=grpo
@ -10,7 +15,14 @@ use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
tis_imp_ratio_cap=2.0
# Rollout Importance Sampling parameters (matches original TIS with threshold=2)
rollout_is=True
rollout_is_threshold=2.0
rollout_is_threshold_lower=null # No lower bound (original TIS behavior)
rollout_is_level=token # token-level (original TIS behavior)
rollout_is_mode=truncate # truncate mode (original TIS behavior)
rollout_is_veto_threshold=null # No veto (original TIS behavior)
clip_ratio_low=0.2
clip_ratio_high=0.28
@ -58,14 +70,17 @@ offload=True
gen_tp=4
# Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl
# Please note that server mode(agent loop) hasn't return rollout_log_probs for now.
# so currently, server mode is not supported for TIS.
# To turn on TIS, you need to set the following parameters. Note 2.0 is a hyper-parameter and can be tuned.
# actor_rollout_ref.actor.tis_imp_ratio_cap=2.0
# actor_rollout_ref.rollout.calculate_log_probs=True
# Rollout Importance Sampling (corrects distribution mismatch between rollout and training)
#
# Please note that server mode (agent loop) does not return rollout_log_probs yet,
# so server mode is currently not supported for Rollout IS.
#
# Rollout IS parameters (configured at top of script):
# algorithm.rollout_is=True
# algorithm.rollout_is_threshold=2.0 # Upper threshold (can be tuned)
# algorithm.rollout_is_level=token # Aggregation level
# algorithm.rollout_is_mode=truncate # Bounding mode
# actor_rollout_ref.rollout.calculate_log_probs=True # Required!
ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
--working-dir "${WORKING_DIR}" \
@ -109,7 +124,12 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.tis_imp_ratio_cap=${tis_imp_ratio_cap} \
algorithm.rollout_is=${rollout_is} \
algorithm.rollout_is_threshold=${rollout_is_threshold} \
algorithm.rollout_is_threshold_lower=${rollout_is_threshold_lower} \
algorithm.rollout_is_level=${rollout_is_level} \
algorithm.rollout_is_mode=${rollout_is_mode} \
algorithm.rollout_is_veto_threshold=${rollout_is_veto_threshold} \
actor_rollout_ref.rollout.calculate_log_probs=True \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \

View File

@ -103,7 +103,7 @@ HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.weight_decay=0 \
actor_rollout_ref.actor.optim.warmup_style=constant \
actor_rollout_ref.actor.optim.lr_scheduler_type=constant \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \

View File

@ -100,7 +100,7 @@ HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.weight_decay=0 \
actor_rollout_ref.actor.optim.warmup_style=constant \
actor_rollout_ref.actor.optim.lr_scheduler_type=constant \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \

View File

@ -99,7 +99,7 @@ HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.weight_decay=0 \
actor_rollout_ref.actor.optim.warmup_style=constant \
actor_rollout_ref.actor.optim.lr_scheduler_type=constant \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \

View File

@ -103,7 +103,7 @@ HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.weight_decay=0 \
actor_rollout_ref.actor.optim.warmup_style=constant \
actor_rollout_ref.actor.optim.lr_scheduler_type=constant \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \

View File

@ -99,7 +99,7 @@ HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.weight_decay=0 \
actor_rollout_ref.actor.optim.warmup_style=constant \
actor_rollout_ref.actor.optim.lr_scheduler_type=constant \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \

View File

@ -577,6 +577,11 @@ class OneStepOffRayTrainer(RayPPOTrainer):
else:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
# Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer)
batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
# IS and mismatch metrics already have mismatch/ prefix
metrics.update(is_metrics)
# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get(

View File

@ -48,7 +48,8 @@ reward_model:
lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null
warmup_style: constant
warmup_style: null # deprecated
lr_scheduler_type: constant
total_training_steps: -1 # must be overridden by program
weight_decay: 0.
grad_clip: 10.0

View File

@ -42,7 +42,7 @@ FSDP_ENGINE_CONFIG="\
optim.betas="[0.9,0.95]" \
optim.clip_grad=1.0 \
optim.min_lr_ratio=0.1 \
optim.warmup_style=cosine \
optim.lr_scheduler_type=cosine \
engine.ulysses_sequence_parallel_size=${SP_SIZE} \
engine.strategy=${FSDP_STRATEGY} \
engine.fsdp_size=${FSDP_SIZE}"

View File

@ -301,8 +301,8 @@ actor_rollout_ref:
# Number of cosine cycles in LR schedule
num_cycles: 0.5
# LR warmup style: "constant" or "cosine"
warmup_style: constant
# LR scheduler type: "constant" or "cosine"
lr_scheduler_type: constant
# Total training steps (must be overridden at runtime)
total_training_steps: -1
@ -605,8 +605,8 @@ critic:
# Minimum LR ratio for cosine schedule
min_lr_ratio: 0.0
# LR warmup style: "constant" or "cosine"
warmup_style: constant
# LR scheduler type: "constant" or "cosine"
lr_scheduler_type: constant
# Total training steps (must be overridden at runtime)
total_training_steps: -1

View File

@ -0,0 +1,289 @@
#!/usr/bin/env python3
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Quick Sanity Test for Rollout Importance Sampling
This is a standalone test script that can be run without pytest to quickly verify
the rollout IS implementation is working correctly. For comprehensive integration
tests, see: tests/trainer/ppo/test_rollout_is_integration.py
Usage:
python test_rollout_is.py
This tests:
- Basic rollout IS functionality (3 levels, 2 modes)
- Metrics completeness (32 total: 21 IS + 11 mismatch metrics)
- Veto mechanism
- Edge cases
"""
import torch
from verl.trainer.ppo.mismatch_helper import compute_mismatch_metrics, compute_rollout_importance_weights
def test_basic_rollout_is():
"""Test basic rollout IS functionality."""
print("Testing basic rollout IS functionality...")
# Create test data
batch_size, seq_length = 4, 10
device = "cuda" if torch.cuda.is_available() else "cpu"
# Create slightly different log probs (simulating BF16 vs FP32 mismatch)
old_log_prob = torch.randn(batch_size, seq_length, device=device)
rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.1
eos_mask = torch.ones(batch_size, seq_length, device=device)
# Test token-level truncate mode (equivalent to old TIS)
print("\n1. Testing token-level truncate mode...")
weights_proto, metrics = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
rollout_is_veto_threshold=1e-4,
)
weights = weights_proto.batch["rollout_is_weights"]
print(f" Weights shape: {weights.shape}")
print(f" Mean weight: {metrics['mismatch/rollout_is_mean']:.4f}")
print(f" Max weight: {metrics['mismatch/rollout_is_max']:.4f}")
print(f" Min weight: {metrics['mismatch/rollout_is_min']:.4f}")
print(f" Veto fraction: {metrics['mismatch/rollout_is_veto_fraction']:.4f}")
assert weights.shape == old_log_prob.shape
assert weights.max() <= 2.0, "Weights should be capped at threshold"
print(" ✓ Token-level truncate mode passed")
# Test sequence-level mode
print("\n2. Testing sequence-level mode...")
weights_seq_proto, metrics_seq = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
rollout_is_level="sequence",
rollout_is_mode="truncate",
rollout_is_threshold=5.0,
rollout_is_veto_threshold=1e-4,
)
weights_seq = weights_seq_proto.batch["rollout_is_weights"]
print(f" Mean weight: {metrics_seq['mismatch/rollout_is_mean']:.4f}")
print(f" Effective sample size: {metrics_seq['mismatch/rollout_is_eff_sample_size']:.4f}")
# Check that all tokens in a sequence have the same weight
for i in range(batch_size):
seq_weights = weights_seq[i, eos_mask[i].bool()]
assert torch.allclose(seq_weights, seq_weights[0]), "All tokens in sequence should have same weight"
print(" ✓ Sequence-level mode passed")
# Test geometric mean mode
print("\n3. Testing geometric mean mode...")
weights_geo_proto, metrics_geo = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
rollout_is_level="geometric",
rollout_is_mode="clip",
rollout_is_threshold=1.5,
rollout_is_threshold_lower=0.5,
rollout_is_veto_threshold=1e-4,
)
print(f" Mean weight: {metrics_geo['mismatch/rollout_is_mean']:.4f}")
print(f" Clipped fraction: {metrics_geo['mismatch/rollout_is_clipped_fraction']:.4f}")
print(" ✓ Geometric mean mode passed")
# Test veto mechanism
print("\n4. Testing veto mechanism...")
# Create data with catastrophic outliers
old_log_prob_veto = torch.randn(2, 5, device=device)
rollout_log_prob_veto = old_log_prob_veto.clone()
# Make one token have catastrophically low ratio
rollout_log_prob_veto[0, 2] = old_log_prob_veto[0, 2] + 15.0 # ratio ~= 3e-7
eos_mask_veto = torch.ones(2, 5, device=device)
weights_veto_proto, metrics_veto = compute_rollout_importance_weights(
old_log_prob=old_log_prob_veto,
rollout_log_prob=rollout_log_prob_veto,
response_mask=eos_mask_veto,
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
rollout_is_veto_threshold=1e-4,
)
weights_veto = weights_veto_proto.batch["rollout_is_weights"]
print(f" Veto fraction: {metrics_veto['mismatch/rollout_is_veto_fraction']:.4f}")
# Check that the sequence with catastrophic token has all weights zeroed
assert weights_veto[0].sum() == 0, "Sequence with catastrophic token should be vetoed"
assert weights_veto[1].sum() > 0, "Normal sequence should not be vetoed"
print(" ✓ Veto mechanism passed")
# Test disabled IS (threshold=None)
print("\n5. Testing disabled IS...")
weights_disabled, metrics_disabled = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
rollout_is_threshold=None,
)
assert weights_disabled is None, "Should return None when threshold is None"
assert len(metrics_disabled) == 0, "Should return empty metrics when disabled"
print(" ✓ Disabled IS passed")
print("\n✓ All tests passed!")
def test_metrics_completeness():
"""Test that all expected metrics are returned."""
print("\nTesting metrics completeness...")
batch_size, seq_length = 3, 8
device = "cuda" if torch.cuda.is_available() else "cpu"
old_log_prob = torch.randn(batch_size, seq_length, device=device)
rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.2
eos_mask = torch.ones(batch_size, seq_length, device=device)
_, metrics = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.5,
)
# Expected IS metrics
expected_is_metrics = [
"mismatch/rollout_is_mean",
"mismatch/rollout_is_max",
"mismatch/rollout_is_min",
"mismatch/rollout_is_std",
"mismatch/rollout_is_eff_sample_size",
"mismatch/rollout_is_veto_fraction",
"mismatch/rollout_is_catastrophic_token_fraction",
"mismatch/rollout_is_ratio_fraction_high",
"mismatch/rollout_is_ratio_fraction_low",
"mismatch/rollout_is_p25",
"mismatch/rollout_is_p50",
"mismatch/rollout_is_p75",
"mismatch/rollout_is_p95",
"mismatch/rollout_is_p99",
]
# Expected mismatch/diagnostic metrics (also included now)
expected_mismatch_metrics = [
"mismatch/mismatch_training_ppl",
"mismatch/mismatch_training_log_ppl",
"mismatch/mismatch_kl",
"mismatch/mismatch_k3_kl",
"mismatch/mismatch_rollout_ppl",
"mismatch/mismatch_rollout_log_ppl",
"mismatch/mismatch_log_ppl_diff",
"mismatch/mismatch_log_ppl_abs_diff",
"mismatch/mismatch_log_ppl_diff_max",
"mismatch/mismatch_log_ppl_diff_min",
"mismatch/mismatch_ppl_ratio",
]
expected_metrics = expected_is_metrics + expected_mismatch_metrics
missing_metrics = [m for m in expected_metrics if m not in metrics]
if missing_metrics:
print(f" ✗ Missing metrics: {missing_metrics}")
return False
print(f" ✓ All {len(expected_metrics)} expected metrics present")
print(f" Total metrics returned: {len(metrics)}")
return True
def test_mismatch_metrics():
"""Test mismatch metrics computation."""
print("\nTesting mismatch metrics computation...")
batch_size, seq_length = 4, 12
device = "cuda" if torch.cuda.is_available() else "cpu"
# Create test data with some mismatch
old_log_prob = torch.randn(batch_size, seq_length, device=device) - 2.0 # training policy
rollout_log_prob = torch.randn(batch_size, seq_length, device=device) - 1.5 # rollout policy (more confident)
response_mask = torch.ones(batch_size, seq_length, device=device)
# Test with rollout log probs
metrics = compute_mismatch_metrics(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=response_mask,
)
expected_metrics = [
"mismatch_training_ppl",
"mismatch_training_log_ppl",
"mismatch_kl",
"mismatch_k3_kl",
"mismatch_rollout_ppl",
"mismatch_rollout_log_ppl",
"mismatch_log_ppl_diff",
"mismatch_log_ppl_abs_diff",
"mismatch_log_ppl_diff_max",
"mismatch_log_ppl_diff_min",
"mismatch_ppl_ratio",
]
for metric in expected_metrics:
assert metric in metrics, f"Missing metric: {metric}"
print(f" Training PPL: {metrics['mismatch_training_ppl']:.4f}")
print(f" Rollout PPL: {metrics['mismatch_rollout_ppl']:.4f}")
print(f" KL divergence: {metrics['mismatch_kl']:.6f}")
print(f" K3 KL: {metrics['mismatch_k3_kl']:.6f}")
print(f" PPL ratio: {metrics['mismatch_ppl_ratio']:.4f}")
print(f" ✓ All {len(expected_metrics)} mismatch metrics present")
# Test without rollout log probs
metrics_no_rollout = compute_mismatch_metrics(
old_log_prob=old_log_prob,
rollout_log_prob=None,
response_mask=response_mask,
)
assert "mismatch_training_ppl" in metrics_no_rollout
assert "mismatch_rollout_ppl" not in metrics_no_rollout
print(" ✓ Mismatch metrics work without rollout log probs")
if __name__ == "__main__":
print("=" * 60)
print("Rollout Importance Sampling Test Suite")
print("=" * 60)
try:
test_basic_rollout_is()
test_metrics_completeness()
test_mismatch_metrics()
print("\n" + "=" * 60)
print("ALL TESTS PASSED ✓")
print("=" * 60)
except Exception as e:
print(f"\n✗ Test failed with error: {e}")
import traceback
traceback.print_exc()
exit(1)

View File

@ -0,0 +1,241 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for Rollout Importance Sampling."""
import pytest
import torch
from verl.trainer.ppo.core_algos import compute_policy_loss_vanilla
from verl.trainer.ppo.mismatch_helper import compute_mismatch_metrics, compute_rollout_importance_weights
from verl.workers.config.actor import ActorConfig
class TestRolloutISIntegration:
"""Integration tests for Rollout IS with PPO."""
@pytest.fixture
def sample_data(self):
"""Create sample training data."""
batch_size, seq_length = 4, 16
device = "cuda" if torch.cuda.is_available() else "cpu"
return {
"old_log_prob": torch.randn(batch_size, seq_length, device=device),
"log_prob": torch.randn(batch_size, seq_length, device=device),
"rollout_log_prob": torch.randn(batch_size, seq_length, device=device),
"advantages": torch.randn(batch_size, seq_length, device=device),
"response_mask": torch.ones(batch_size, seq_length, device=device),
}
@pytest.fixture
def config_with_rollout_is(self):
"""Create config for policy loss computation.
Note: rollout_is config has been moved to algorithm config.
This config only needs fields used by policy loss (clip_ratio, etc).
"""
config = ActorConfig(
strategy="fsdp",
rollout_n=1,
ppo_micro_batch_size=2,
clip_ratio=0.2,
)
return config
def test_policy_loss_with_rollout_is(self, sample_data, config_with_rollout_is):
"""Test that policy loss computation works with rollout IS weights.
Note: In production, IS weights are computed centrally in the trainer
(before advantage computation) and passed to policy loss.
This test simulates that workflow.
"""
# First compute IS weights (as trainer would do centrally)
rollout_is_weights_proto, _ = compute_rollout_importance_weights(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
rollout_is_veto_threshold=1e-4,
)
rollout_is_weights = rollout_is_weights_proto.batch["rollout_is_weights"]
# Policy loss function receives pre-computed IS weights
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss_vanilla(
old_log_prob=sample_data["old_log_prob"],
log_prob=sample_data["log_prob"],
advantages=sample_data["advantages"],
response_mask=sample_data["response_mask"],
loss_agg_mode="token-mean",
config=config_with_rollout_is,
rollout_is_weights=rollout_is_weights,
)
# Check loss is valid
assert isinstance(pg_loss, torch.Tensor)
assert pg_loss.ndim == 0 # Scalar
assert not torch.isnan(pg_loss)
assert not torch.isinf(pg_loss)
def test_rollout_is_weights_computation(self, sample_data):
"""Test rollout IS weights and metrics computation."""
weights_proto, metrics = compute_rollout_importance_weights(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
rollout_is_veto_threshold=1e-4,
)
# Check weights
from verl.protocol import DataProto
assert isinstance(weights_proto, DataProto)
weights = weights_proto.batch["rollout_is_weights"]
assert isinstance(weights, torch.Tensor)
assert weights.shape == sample_data["old_log_prob"].shape
# Check metrics are returned
assert isinstance(metrics, dict)
assert len(metrics) > 0
assert "mismatch/rollout_is_mean" in metrics
def test_all_aggregation_levels(self, sample_data):
"""Test all three aggregation levels."""
levels = ["token", "sequence", "geometric"]
for level in levels:
_, metrics = compute_rollout_importance_weights(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
rollout_is_level=level,
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
)
assert "mismatch/rollout_is_mean" in metrics
def test_both_bounding_modes(self, sample_data):
"""Test both truncate and clip modes."""
modes = ["truncate", "clip"]
for mode in modes:
_, metrics = compute_rollout_importance_weights(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
rollout_is_level="token",
rollout_is_mode=mode,
rollout_is_threshold=2.0,
rollout_is_threshold_lower=0.5,
)
assert "mismatch/rollout_is_mean" in metrics
def test_mismatch_metrics(self, sample_data):
"""Test mismatch diagnostic metrics computation."""
metrics = compute_mismatch_metrics(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
)
# Check key metrics are present
assert "mismatch_training_ppl" in metrics
assert "mismatch_rollout_ppl" in metrics
assert "mismatch_kl" in metrics
assert isinstance(metrics["mismatch_kl"], float)
def test_veto_mechanism(self):
"""Test veto mechanism with catastrophic outliers."""
batch_size, seq_length = 2, 5
device = "cuda" if torch.cuda.is_available() else "cpu"
old_log_prob = torch.randn(batch_size, seq_length, device=device)
rollout_log_prob = old_log_prob.clone()
# Create catastrophic outlier in first sequence
rollout_log_prob[0, 2] += 15.0 # Makes ratio ~3e-7
response_mask = torch.ones(batch_size, seq_length, device=device)
_, metrics = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=response_mask,
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
rollout_is_veto_threshold=1e-4,
)
# Should have vetoed one sequence
assert metrics["mismatch/rollout_is_veto_fraction"] > 0
assert metrics["mismatch/rollout_is_veto_fraction"] <= 1.0
def test_metrics_only_mode(self, sample_data, config_with_rollout_is):
"""Test metrics-only mode: compute IS weights/metrics but don't apply to loss.
This tests the use case where rollout_is_threshold is set (enables computation)
but rollout_is=False (disables weight application to policy loss).
"""
# Compute IS weights (as trainer would do)
rollout_is_weights_proto, is_metrics = compute_rollout_importance_weights(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
)
# Metrics should be computed
assert len(is_metrics) > 0
assert "mismatch/rollout_is_mean" in is_metrics
# In metrics-only mode, we compute loss WITHOUT applying weights
# (simulating rollout_is=False)
pg_loss_no_weights, _, _, _ = compute_policy_loss_vanilla(
old_log_prob=sample_data["old_log_prob"],
log_prob=sample_data["log_prob"],
advantages=sample_data["advantages"],
response_mask=sample_data["response_mask"],
loss_agg_mode="token-mean",
config=config_with_rollout_is,
rollout_is_weights=None, # Don't apply weights
)
# Compare to loss WITH weights (rollout_is=True)
rollout_is_weights = rollout_is_weights_proto.batch["rollout_is_weights"]
pg_loss_with_weights, _, _, _ = compute_policy_loss_vanilla(
old_log_prob=sample_data["old_log_prob"],
log_prob=sample_data["log_prob"],
advantages=sample_data["advantages"],
response_mask=sample_data["response_mask"],
loss_agg_mode="token-mean",
config=config_with_rollout_is,
rollout_is_weights=rollout_is_weights,
)
# Losses should be different (weights have an effect)
assert not torch.allclose(pg_loss_no_weights, pg_loss_with_weights)
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])

View File

@ -21,15 +21,24 @@ class TestFSDPOptimizerConfigCPU:
def test_default_configuration(self):
config = FSDPOptimizerConfig(lr=0.1)
assert config.min_lr_ratio is None
assert config.warmup_style == "constant"
assert config.lr_scheduler_type == "constant"
assert config.num_cycles == 0.5
@pytest.mark.parametrize("warmup_style", ["constant", "cosine"])
def test_valid_warmup_styles(self, warmup_style):
config = FSDPOptimizerConfig(warmup_style=warmup_style, lr=0.1)
assert config.warmup_style == warmup_style
@pytest.mark.parametrize("lr_scheduler_type", ["constant", "cosine"])
def test_valid_lr_scheduler_types(self, lr_scheduler_type):
config = FSDPOptimizerConfig(lr_scheduler_type=lr_scheduler_type, lr=0.1)
assert config.lr_scheduler_type == lr_scheduler_type
def test_invalid_warmup_style(self):
@pytest.mark.parametrize("warmup_style", ["constant", "cosine"])
def test_valid_warmup_style_types(self, warmup_style):
config = FSDPOptimizerConfig(warmup_style=warmup_style, lr=0.1)
assert config.lr_scheduler_type == warmup_style
def test_invalid_lr_scheduler_type(self):
with pytest.raises((ValueError, AssertionError)):
FSDPOptimizerConfig(lr_scheduler_type="invalid_style", lr=0.1)
def test_invalid_warmup_style_type(self):
with pytest.raises((ValueError, AssertionError)):
FSDPOptimizerConfig(warmup_style="invalid_style", lr=0.1)

View File

@ -127,6 +127,8 @@ def patch_vlm_for_ulysses_input_slicing(model_class: type):
def ulysses_wrapped_decoder_forward(self, *args, **kwargs):
inputs_embeds = kwargs.get("inputs_embeds")
position_ids = kwargs.get("position_ids")
visual_pos_masks = kwargs.get("visual_pos_masks")
deepstack_visual_embeds = kwargs.get("deepstack_visual_embeds")
call_kwargs = kwargs.copy()
current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
@ -139,6 +141,43 @@ def patch_vlm_for_ulysses_input_slicing(model_class: type):
if slice_now:
call_kwargs["inputs_embeds"] = slice_input_tensor(inputs_embeds, dim=1, padding=False)
call_kwargs["position_ids"] = slice_input_tensor(position_ids, dim=-1, padding=False)
# Also slice visual_pos_masks and deepstack_visual_embeds for Qwen3 VL models
if visual_pos_masks is not None:
original_visual_mask = visual_pos_masks
sliced_visual_mask = slice_input_tensor(visual_pos_masks, dim=1, padding=False)
call_kwargs["visual_pos_masks"] = sliced_visual_mask
if deepstack_visual_embeds is not None:
sliced_embeds = []
num_visual_before = original_visual_mask.sum().item()
num_visual_in_shard = sliced_visual_mask.sum().item()
if num_visual_in_shard > 0 and num_visual_before > 0:
# Calculate which visual embeddings belong to this shard
# We need to find the offset of visual tokens in this shard
from verl.utils.ulysses import get_ulysses_sequence_parallel_rank
rank = get_ulysses_sequence_parallel_rank()
seq_len = original_visual_mask.shape[1]
local_seq_len = seq_len // current_ulysses_sp_size
start_idx = rank * local_seq_len
end_idx = start_idx + local_seq_len
# Get total visual tokens before and up to the end of the shard's sequence slice
# This correctly handles batches by summing across all samples
visual_start = original_visual_mask[:, :start_idx].sum().item() if start_idx > 0 else 0
visual_end = original_visual_mask[:, :end_idx].sum().item()
# Slice each tensor in deepstack_visual_embeds
for embed in deepstack_visual_embeds:
sliced_embeds.append(embed[visual_start:visual_end])
else:
# No visual tokens in this shard, create empty tensors to maintain gradient flow
for embed in deepstack_visual_embeds:
sliced_embeds.append(embed[:0])
call_kwargs["deepstack_visual_embeds"] = sliced_embeds
self._needs_initial_slice = False
try:
return original_forward(self, *args, **call_kwargs)
@ -290,9 +329,7 @@ def apply_monkey_patch(
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention,
)
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2VLFlashAttention2 as Qwen2VLAttention,
)
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention
if use_remove_padding or ulysses_sp_size > 1:
from verl.models.transformers.qwen2_vl import qwen2_vl_attn_forward

View File

@ -209,8 +209,10 @@ def _get_input_embeds(
patch_dim = config.in_channels * config.temporal_patch_size * config.patch_size**2
pixel_values = torch.zeros((16, patch_dim), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device)
image_embeds, _ = model.visual(pixel_values, grid_thw=image_grid_thw)
image_embeds, dummy_deepstack_image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
inputs_embeds += 0.0 * image_embeds.mean()
for emb in dummy_deepstack_image_embeds or []:
inputs_embeds += 0.0 * emb.mean()
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)

View File

@ -75,7 +75,6 @@ actor_rollout_ref:
clip_ratio_c: 3.0
loss_agg_mode: token-mean
entropy_coeff: 0
tis_imp_ratio_cap: -1
use_kl_loss: false
use_torch_compile: true
kl_loss_coef: 0.001
@ -484,6 +483,12 @@ algorithm:
pf_ppo:
reweight_method: pow
weight_pow: 2.0
rollout_is_threshold: null
rollout_is_threshold_lower: null
rollout_is_level: token
rollout_is_mode: truncate
rollout_is_veto_threshold: 0.0001
rollout_is: false
trainer:
balance_batch: true
total_epochs: 30

View File

@ -18,7 +18,8 @@ actor_rollout_ref:
clip_grad: 1.0
min_lr_ratio: 0.0
num_cycles: 0.5
warmup_style: constant
lr_scheduler_type: constant
warmup_style: null
fsdp_config:
_target_: verl.workers.config.FSDPEngineConfig
wrap_policy:
@ -59,7 +60,6 @@ actor_rollout_ref:
clip_ratio_c: 3.0
loss_agg_mode: token-mean
entropy_coeff: 0
tis_imp_ratio_cap: -1
use_kl_loss: false
use_torch_compile: true
kl_loss_coef: 0.001
@ -315,7 +315,8 @@ critic:
clip_grad: 1.0
min_lr_ratio: 0.0
num_cycles: 0.5
warmup_style: constant
lr_scheduler_type: constant
warmup_style: null
model:
fsdp_config:
_target_: verl.workers.config.FSDPEngineConfig
@ -462,6 +463,12 @@ algorithm:
pf_ppo:
reweight_method: pow
weight_pow: 2.0
rollout_is_threshold: null
rollout_is_threshold_lower: null
rollout_is_level: token
rollout_is_mode: truncate
rollout_is_veto_threshold: 0.0001
rollout_is: false
trainer:
balance_batch: true
total_epochs: 30

View File

@ -74,10 +74,6 @@ loss_agg_mode: token-mean
# Entropy regularization coefficient in PPO loss
entropy_coeff: 0
# Truncated Importance Sampling (TIS): https://fengyao.notion.site/off-policy-rl
# the truncation value C of truncated Importance Sampling (-1 for disable TIS)
tis_imp_ratio_cap: -1
# Whether to use KL loss instead of KL reward penalty. True for GRPO
use_kl_loss: false

View File

@ -73,6 +73,14 @@ class AlgoConfig(BaseConfig):
use_pf_ppo (bool): Whether to enable preference feedback PPO.
pf_ppo (dict[str, Any]): Preference feedback PPO settings.
filter_groups (Optional[FilterGroupsConfig]): Filter groups configuration, used in DAPO and Entropy
rollout_is_threshold (Optional[float]): Upper threshold for IS weights. null = disabled,
float value = enabled (compute weights and metrics). This is the main on/off switch.
rollout_is_threshold_lower (Optional[float]): Lower threshold for IS weights. If None, defaults to 1/upper.
rollout_is_level (str): Aggregation level: "token", "sequence", or "geometric".
rollout_is_mode (str): Bounding mode: "truncate" (cap upper only) or "clip" (zero outside bounds).
rollout_is_veto_threshold (float): Per-token veto threshold for catastrophic outliers.
rollout_is (bool): Whether to apply IS weights to policy loss. True = apply weights,
False = compute metrics only (useful for monitoring before enabling correction). Default: False.
"""
gamma: float = 1.0
@ -85,3 +93,13 @@ class AlgoConfig(BaseConfig):
use_pf_ppo: bool = False
pf_ppo: dict[str, Any] = field(default_factory=dict)
filter_groups: Optional[FilterGroupsConfig] = None
# Rollout Importance Sampling (replaces legacy tis_imp_ratio_cap)
# Controls computation of IS weights and mismatch metrics
rollout_is_threshold: Optional[float] = None # null = disabled, float = enabled
rollout_is_threshold_lower: Optional[float] = None
rollout_is_level: str = "token"
rollout_is_mode: str = "truncate"
rollout_is_veto_threshold: Optional[float] = 1e-4
# Controls whether to apply IS weights to policy loss (only if rollout_is_threshold is set)
# True = apply weights to loss, False = compute metrics only (no weight application)
rollout_is: bool = False

View File

@ -28,6 +28,8 @@ min_lr_ratio: 0.0
# Number of cosine cycles in LR schedule
num_cycles: 0.5
# LR warmup style: "constant" or "cosine"
warmup_style: constant
# LR scheduler type: "constant" or "cosine"
lr_scheduler_type: constant
# deprecated
warmup_style: null

View File

@ -73,6 +73,28 @@ algorithm:
reweight_method: pow # ["pow", "max_min", "max_random"]
weight_pow: 2.0
# Rollout Importance Sampling: corrects distribution mismatch between rollout and training policies
# Main control: Upper threshold for IS weights (null = disabled, float = enabled)
# When enabled, computes IS weights and mismatch metrics (KL, PPL, etc.)
rollout_is_threshold: null
# Lower threshold for IS weights (null = auto-reciprocal of upper)
rollout_is_threshold_lower: null
# Aggregation level: "token" (biased), "sequence" (unbiased), "geometric" (experimental)
rollout_is_level: token
# Bounding mode: "truncate" (cap upper only), "clip" (zero outside bounds)
rollout_is_mode: truncate
# Per-token veto threshold for catastrophic outliers
rollout_is_veto_threshold: 1e-4
# Whether to apply IS weights to policy loss
# true = apply weights to loss, false = compute metrics only (no weight application)
# Useful for monitoring mismatch before enabling correction
rollout_is: false
trainer:
balance_batch: True
total_epochs: 30

View File

@ -113,6 +113,28 @@ algorithm:
# Power used for weight scaling in "pow" method
weight_pow: 2.0
# Rollout Importance Sampling: corrects distribution mismatch between rollout and training policies
# Main control: Upper threshold for IS weights (null = disabled, float = enabled)
# When enabled, computes IS weights and mismatch metrics (KL, PPL, etc.)
rollout_is_threshold: null
# Lower threshold for IS weights (null = auto-reciprocal of upper)
rollout_is_threshold_lower: null
# Aggregation level: "token" (biased), "sequence" (unbiased), "geometric" (experimental)
rollout_is_level: token
# Bounding mode: "truncate" (cap upper only), "clip" (zero outside bounds)
rollout_is_mode: truncate
# Per-token veto threshold for catastrophic outliers
rollout_is_veto_threshold: 1e-4
# Whether to apply IS weights to policy loss
# true = apply weights to loss, false = compute metrics only (no weight application)
# Useful for monitoring mismatch before enabling correction
rollout_is: false
# config for the trainer
trainer:

View File

@ -881,7 +881,7 @@ def compute_policy_loss(
return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
@register_policy_loss("vanilla")
@register_policy_loss("vanilla") # type: ignore[arg-type]
def compute_policy_loss_vanilla(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
@ -889,7 +889,7 @@ def compute_policy_loss_vanilla(
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for PPO.
@ -959,11 +959,9 @@ def compute_policy_loss_vanilla(
pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
if config.tis_imp_ratio_cap > 0 and rollout_log_probs is not None:
# Apply truncated importance sampling -> https://fengyao.notion.site/off-policy-rl
tis_imp_ratio = torch.exp(old_log_prob - rollout_log_probs)
tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap)
pg_losses = pg_losses * tis_imp_ratio
# Apply rollout importance sampling weights if provided
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
@ -978,7 +976,7 @@ def compute_policy_loss_gspo(
response_mask: torch.Tensor,
loss_agg_mode: str = "seq-mean-token-mean",
config: Optional[DictConfig | ActorConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for GSPO.
@ -1024,6 +1022,10 @@ def compute_policy_loss_gspo(
pg_losses2 = -advantages * torch.clamp(seq_importance_ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
pg_losses = torch.maximum(pg_losses1, pg_losses2)
# Apply rollout importance sampling weights if provided
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights
# for GSPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean)
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode="seq-mean-token-mean")
@ -1044,7 +1046,7 @@ def compute_policy_loss_gpg(
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Adapted from
https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495
@ -1061,6 +1063,10 @@ def compute_policy_loss_gpg(
"""
pg_losses = -log_prob * advantages
# Apply rollout importance sampling weights if provided
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
return pg_loss, torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
@ -1073,7 +1079,7 @@ def compute_policy_loss_clip_cov(
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for Clip-Cov.
@ -1155,6 +1161,11 @@ def compute_policy_loss_clip_cov(
pg_clipfrac = verl_F.masked_mean((corr == 0).float(), response_mask)
pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr
# Apply rollout importance sampling weights if provided
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0)
@ -1168,7 +1179,7 @@ def compute_policy_loss_kl_cov(
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for Clip-Cov.
@ -1227,6 +1238,10 @@ def compute_policy_loss_kl_cov(
large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]
]
# Apply rollout importance sampling weights if provided
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0)
@ -1240,7 +1255,7 @@ def compute_policy_loss_geo_mean(
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for GMPO.
@ -1293,6 +1308,17 @@ def compute_policy_loss_geo_mean(
# otherwise, below would be not consistent with the paper
advantage = (advantages * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8)
pg_losses = -advantage * ratio
# Apply rollout importance sampling weights if provided
# For geo_mean, IS weights are 2D (batch_size, seq_length) and need to be aggregated to sequence level
if rollout_is_weights is not None:
# Aggregate token-level weights to sequence level using geometric mean for consistency
# Note: rollout_is_weights is always 2D regardless of rollout_is_level
seq_is_weights = torch.exp(
(torch.log(rollout_is_weights + 1e-10) * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8)
)
pg_losses = pg_losses * seq_is_weights
pg_loss = torch.mean(pg_losses)
# higher: ratio is too large that need clamp to clip_high (when adv > 0)

View File

@ -0,0 +1,459 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Rollout Importance Sampling (IS) Helper Module
This module handles importance sampling weight computation for correcting
distribution mismatch between rollout policy (e.g., vLLM BFloat16) and
training policy (e.g., FSDP FP32).
Key Features:
1. Three aggregation levels: token, sequence, geometric
2. Two handling modes: truncate (TIS), clip (CIS)
3. Per-token veto mechanism for catastrophic outliers
4. Memory-efficient computation to prevent CUDA OOM
5. Comprehensive metrics tracking
Usage Notes:
- compute_rollout_importance_weights() computes both IS weights and mismatch metrics
- Used in ray_trainer.py via compute_rollout_importance_weights_and_add_to_batch()
- Also used in dp_actor.py for distributed worker computations
- compute_mismatch_metrics() is called internally by compute_rollout_importance_weights()
References:
- When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
- Off-policy RL: https://fengyao.notion.site/off-policy-rl
"""
from typing import Any, Optional
import torch
import verl.utils.torch_functional as verl_F
from verl.protocol import DataProto
def compute_rollout_importance_weights(
old_log_prob: torch.Tensor,
rollout_log_prob: torch.Tensor,
response_mask: torch.Tensor,
rollout_is_level: str = "token",
rollout_is_mode: str = "truncate",
rollout_is_threshold: Optional[float] = None,
rollout_is_threshold_lower: Optional[float] = None,
rollout_is_veto_threshold: Optional[float] = 1e-4,
) -> tuple[Optional[DataProto], dict[str, Any]]:
"""Compute importance sampling weights and metrics for rollout-training mismatch correction.
This function handles the computation of importance sampling (IS) weights to correct
for the distribution mismatch between rollout policy and training policy.
Reference:
When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
Memory-efficient implementation that prevents CUDA OOM by:
- Using log-space computation where possible
- Applying safety bounds to prevent numerical overflow
- Computing metrics without creating huge intermediate tensors
Args:
old_log_prob: Log probabilities from training policy (e.g., FSDP), shape (batch_size, seq_length)
rollout_log_prob: Log probabilities from rollout policy (e.g., vLLM), shape (batch_size, seq_length)
response_mask: Mask for valid tokens, shape (batch_size, seq_length)
rollout_is_level: Level of IS aggregation:
- "token": Per-token ratios (biased)
- "sequence": Product of ratios (unbiased)
- "geometric": Geometric mean of ratios (experimental)
rollout_is_mode: How to handle weights exceeding threshold:
- "truncate": Cap weights at upper_threshold only (TIS)
- "clip": Zero out weights outside [lower_threshold, upper_threshold] (CIS)
rollout_is_threshold: Upper threshold for IS weights
rollout_is_threshold_lower: Lower threshold for IS weights (clip mode only; if None, defaults to 1/upper)
rollout_is_veto_threshold: Per-token veto threshold. If any token ratio < this, zero entire sequence.
If None, veto mechanism is disabled.
Returns:
Tuple of (weights_proto, metrics) where:
weights_proto: DataProto containing IS weights with key "rollout_is_weights",
shape (batch_size, seq_length). Returns None if rollout_is_threshold is None.
metrics: Dictionary of IS statistics and mismatch metrics (KL, PPL, etc.),
all converted to scalars and prefixed with "mismatch/"
"""
if rollout_is_threshold is None:
return None, {}
# Parse thresholds: if lower not specified, use 1/upper (reciprocal)
upper_threshold = rollout_is_threshold
if rollout_is_threshold_lower is not None:
lower_threshold = rollout_is_threshold_lower
else:
# Default: lower = 1/upper (reciprocal)
lower_threshold = 1.0 / upper_threshold
# Step 1: Compute raw importance weights based on the specified level
log_ratio = old_log_prob - rollout_log_prob
# Pre-compute log thresholds
device = old_log_prob.device
log_threshold_upper = torch.log(torch.tensor(upper_threshold, device=device))
log_threshold_lower = torch.log(torch.tensor(lower_threshold, device=device))
# Safety bound to prevent numerical overflow (exp(20) ≈ 485M)
SAFETY_BOUND = 20.0
# Store unclamped values in log-space for accurate metrics
if rollout_is_level == "token":
# Token-level IS: π_train(a|s) / π_rollout(a|s) per token
log_ratio_for_metrics = log_ratio
# Apply safety bound to prevent overflow
log_ratio_safe = torch.clamp(log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND)
rollout_is_weights = torch.exp(log_ratio_safe)
elif rollout_is_level == "sequence":
# Sequence-level IS: π_train(y|x) / π_rollout(y|x) for entire sequence
# Product of token ratios: exp(Σ log(π_train/π_rollout))
log_ratio_sum = verl_F.masked_sum(log_ratio, response_mask, axis=-1).unsqueeze(-1)
log_ratio_for_metrics = log_ratio_sum # Store for metrics
# Apply safety bound to prevent overflow
log_ratio_sum_safe = torch.clamp(log_ratio_sum, min=-SAFETY_BOUND, max=SAFETY_BOUND)
rollout_is_weights = torch.exp(log_ratio_sum_safe).expand_as(old_log_prob)
elif rollout_is_level == "geometric":
# Geometric mean IS: (∏ π_train/π_rollout)^(1/T)
# Equivalent to exp(mean(log(π_train/π_rollout)))
log_ratio_mean = verl_F.masked_mean(log_ratio, response_mask, axis=-1).unsqueeze(-1)
log_ratio_for_metrics = log_ratio_mean # Store for metrics
# Geometric mean rarely explodes due to averaging, but apply safety bound anyway
log_ratio_mean_safe = torch.clamp(log_ratio_mean, min=-SAFETY_BOUND, max=SAFETY_BOUND)
rollout_is_weights = torch.exp(log_ratio_mean_safe).expand_as(old_log_prob)
else:
raise ValueError(f"Invalid rollout_is_level: {rollout_is_level}. Must be 'token', 'sequence', or 'geometric'.")
# Step 1.5: Apply per-token veto check in log space (memory efficient)
if rollout_is_veto_threshold is not None:
log_veto_threshold = torch.log(torch.tensor(rollout_is_veto_threshold, device=device))
# Check if any token ratio is below veto threshold (in log space)
# log(π_train/π_rollout) < log(veto_threshold) ⟺ π_train/π_rollout < veto_threshold
catastrophic_tokens = (log_ratio < log_veto_threshold) & response_mask.bool()
# For each sequence, check if it has any catastrophic token
# Use broadcasting instead of expand_as to save memory
has_catastrophic = catastrophic_tokens.any(dim=-1, keepdim=True)
# Create veto mask: 0 if sequence has catastrophic token, 1 otherwise
veto_mask = (~has_catastrophic).float()
else:
# No veto mechanism
catastrophic_tokens = torch.zeros_like(response_mask, dtype=torch.bool)
has_catastrophic = torch.zeros((old_log_prob.size(0), 1), dtype=torch.bool, device=device)
veto_mask = torch.ones((old_log_prob.size(0), 1), dtype=torch.float32, device=device)
# Step 2: Compute comprehensive metrics
metrics = compute_is_metrics(
rollout_is_weights=rollout_is_weights,
log_ratio_for_metrics=log_ratio_for_metrics,
response_mask=response_mask,
rollout_is_level=rollout_is_level,
rollout_is_threshold=upper_threshold,
rollout_is_threshold_lower=lower_threshold,
log_threshold_upper=log_threshold_upper,
log_threshold_lower=log_threshold_lower,
has_catastrophic=has_catastrophic,
catastrophic_tokens=catastrophic_tokens,
SAFETY_BOUND=SAFETY_BOUND,
)
# Step 3: Apply truncation or clipping based on mode
if rollout_is_mode == "truncate":
# Truncated IS (TIS): only cap upper bound to prevent overweighting
rollout_is_weights = rollout_is_weights.clamp(max=upper_threshold)
elif rollout_is_mode == "clip":
# Clipped IS (CIS): zero out weights outside [lower_threshold, upper_threshold]
clip_mask = (rollout_is_weights >= lower_threshold) & (rollout_is_weights <= upper_threshold)
clip_mask = clip_mask.float()
# Track CIS-specific metrics
metrics["rollout_is_clipped_fraction"] = verl_F.masked_mean(1 - clip_mask, response_mask)
# Sequence-level clipping fraction
if rollout_is_level in ["sequence", "geometric"]:
# All tokens in a sequence have the same weight, so reuse clip_mask
metrics["rollout_is_seq_clipped_fraction"] = (1 - clip_mask[:, 0]).mean()
else:
# Check if any token in each sequence is clipped
seq_has_clipped = verl_F.masked_sum(1 - clip_mask, response_mask, axis=-1) > 0
metrics["rollout_is_seq_clipped_fraction"] = seq_has_clipped.float().mean()
rollout_is_weights = rollout_is_weights * clip_mask
else:
raise ValueError(f"Invalid rollout_is_mode: {rollout_is_mode}. Must be 'truncate' or 'clip'.")
# Apply veto mask AFTER all thresholding
# This zeros out entire sequences that have any catastrophic token
rollout_is_weights = rollout_is_weights * veto_mask
# Apply response_mask to ensure weights are 0 where mask is 0
rollout_is_weights = rollout_is_weights * response_mask
# Wrap in DataProto for consistency with worker methods
rollout_is_weights_proto = DataProto.from_dict(tensors={"rollout_is_weights": rollout_is_weights})
# Compute mismatch metrics (KL, PPL, etc.) and merge with IS metrics
mismatch_metrics = compute_mismatch_metrics(
old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=response_mask
)
metrics.update(mismatch_metrics)
# Convert all tensor metrics to scalars for logging
# Note: No need to detach since old_log_prob and rollout_log_prob are computed with torch.no_grad()
metrics_scalar = {}
for key, value in metrics.items():
if isinstance(value, torch.Tensor):
metrics_scalar[f"mismatch/{key}"] = value.item()
else:
metrics_scalar[f"mismatch/{key}"] = value
return rollout_is_weights_proto, metrics_scalar
def compute_is_metrics(
rollout_is_weights: torch.Tensor,
log_ratio_for_metrics: torch.Tensor,
response_mask: torch.Tensor,
rollout_is_level: str,
rollout_is_threshold: float,
rollout_is_threshold_lower: float,
log_threshold_upper: torch.Tensor,
log_threshold_lower: torch.Tensor,
has_catastrophic: torch.Tensor,
catastrophic_tokens: torch.Tensor,
SAFETY_BOUND: float,
) -> dict[str, Any]:
"""Compute comprehensive metrics for importance sampling weights.
Reference:
When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
This function computes metrics using a mix of true unclamped values (for max/min/fractions
in sequence/geometric mode via log-space) and safety-clamped values (for mean/std/ESS)
to balance accuracy with numerical stability and avoid overflow.
"""
# Validate that we have at least one valid sample
assert response_mask.any(), "Expected at least one valid sample in response_mask"
metrics = {}
device = rollout_is_weights.device
# Track veto statistics
metrics["rollout_is_veto_fraction"] = has_catastrophic.float().mean()
metrics["rollout_is_catastrophic_token_fraction"] = verl_F.masked_mean(catastrophic_tokens.float(), response_mask)
# Compute metrics based on IS level
if rollout_is_level in ["sequence", "geometric"]:
# For sequence/geometric, compute true statistics from log-space
# This reflects the actual distribution before clamping
# True max/min in log space
log_max = log_ratio_for_metrics.max()
log_min = log_ratio_for_metrics.min()
# Convert to regular space with safety bound
metrics["rollout_is_max"] = torch.exp(torch.clamp(log_max, max=SAFETY_BOUND))
metrics["rollout_is_min"] = torch.exp(log_min)
# Mean uses clamped weights to avoid overflow
metrics["rollout_is_mean"] = verl_F.masked_mean(rollout_is_weights, response_mask)
# Compute fraction exceeding threshold in log space (accurate)
exceeds_upper = log_ratio_for_metrics > log_threshold_upper
below_lower = log_ratio_for_metrics < log_threshold_lower
if rollout_is_level == "sequence":
# For sequence level, all tokens in a sequence have the same weight
metrics["rollout_is_ratio_fraction_high"] = exceeds_upper.float().mean()
metrics["rollout_is_ratio_fraction_low"] = below_lower.float().mean()
else: # geometric
# Need to expand to match token dimensions
exceeds_upper_expanded = exceeds_upper.expand_as(response_mask)
below_lower_expanded = below_lower.expand_as(response_mask)
metrics["rollout_is_ratio_fraction_high"] = verl_F.masked_mean(
exceeds_upper_expanded.float(), response_mask
)
metrics["rollout_is_ratio_fraction_low"] = verl_F.masked_mean(below_lower_expanded.float(), response_mask)
else:
# Token-level: compute directly from weights
metrics["rollout_is_mean"] = verl_F.masked_mean(rollout_is_weights, response_mask)
# Fraction exceeding thresholds
rollout_is_above_threshold = rollout_is_weights > rollout_is_threshold
rollout_is_below_threshold = rollout_is_weights < rollout_is_threshold_lower
metrics["rollout_is_ratio_fraction_high"] = verl_F.masked_mean(
rollout_is_above_threshold.float(), response_mask
)
metrics["rollout_is_ratio_fraction_low"] = verl_F.masked_mean(rollout_is_below_threshold.float(), response_mask)
# Max/min for token level
mask_bool = response_mask.bool()
metrics["rollout_is_max"] = rollout_is_weights.masked_fill(~mask_bool, float("-inf")).max()
metrics["rollout_is_min"] = rollout_is_weights.masked_fill(~mask_bool, float("inf")).min()
# Compute standard deviation using clamped weights to avoid overflow
mask_count = response_mask.sum()
if mask_count > 1:
# Use clamped weights for variance to avoid squaring huge values
weights_for_std = rollout_is_weights.clamp(min=rollout_is_threshold_lower, max=rollout_is_threshold)
# Use mean from clamped weights for consistency
mean_clamped = verl_F.masked_mean(weights_for_std, response_mask)
rollout_is_var = verl_F.masked_mean(weights_for_std.square(), response_mask) - mean_clamped.square()
metrics["rollout_is_std"] = torch.sqrt(torch.clamp(rollout_is_var, min=0.0))
else:
metrics["rollout_is_std"] = torch.tensor(0.0, device=device)
# Effective sample size (use clamped weights to avoid overflow)
weights_for_ess = rollout_is_weights.clamp(min=rollout_is_threshold_lower, max=rollout_is_threshold)
mean_for_ess = verl_F.masked_mean(weights_for_ess, response_mask)
is_weights_normalized = weights_for_ess / (mean_for_ess + 1e-8)
metrics["rollout_is_eff_sample_size"] = 1.0 / verl_F.masked_mean(is_weights_normalized.square(), response_mask)
# Per-sequence breakdown metrics
if rollout_is_weights.dim() > 1:
# Compute mean IS weight per sequence
seq_mean_weights = verl_F.masked_mean(rollout_is_weights, response_mask, axis=-1)
# Per-sequence statistics
metrics["rollout_is_seq_mean"] = seq_mean_weights.mean()
metrics["rollout_is_seq_std"] = (
seq_mean_weights.std() if seq_mean_weights.numel() > 1 else torch.tensor(0.0, device=device)
)
metrics["rollout_is_seq_max"] = seq_mean_weights.max()
metrics["rollout_is_seq_min"] = seq_mean_weights.min()
# Identify most problematic sequences
seq_deviation = (seq_mean_weights - 1.0).abs()
metrics["rollout_is_seq_max_deviation"] = seq_deviation.max()
# Fraction of sequences with high IS weights
metrics["rollout_is_seq_fraction_high"] = (seq_mean_weights > rollout_is_threshold).float().mean()
metrics["rollout_is_seq_fraction_low"] = (seq_mean_weights < rollout_is_threshold_lower).float().mean()
# Percentile metrics for better distribution understanding
# Get all valid IS weights
flat_weights = rollout_is_weights[response_mask.bool()]
# Compute key percentiles (guaranteed to have elements due to assertion at function start)
assert flat_weights.numel() > 0, "flat_weights should not be empty"
metrics["rollout_is_p25"] = torch.quantile(flat_weights, 0.25)
metrics["rollout_is_p50"] = torch.quantile(flat_weights, 0.50) # median
metrics["rollout_is_p75"] = torch.quantile(flat_weights, 0.75)
metrics["rollout_is_p95"] = torch.quantile(flat_weights, 0.95)
metrics["rollout_is_p99"] = torch.quantile(flat_weights, 0.99)
return metrics
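The effective-sample-size metric computed above is the fraction 1 / E[(w / mean(w))^2], which is 1.0 when all weights are equal and shrinks toward 0 as a few tokens dominate. A minimal sketch (toy tensors, no masking):

```python
# Illustrative only: effective sample size as a fraction in (0, 1].
import torch

def eff_sample_size(weights: torch.Tensor) -> torch.Tensor:
    normalized = weights / weights.mean()
    return 1.0 / normalized.square().mean()

uniform = torch.ones(8)                    # perfectly matched policies
skewed = torch.tensor([8.0] + [0.1] * 7)   # one token dominates the weight mass

print(eff_sample_size(uniform).item())  # 1.0   -> every token contributes equally
print(eff_sample_size(skewed).item())   # ~0.15 -> most tokens are effectively wasted
```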
def compute_mismatch_metrics(
old_log_prob: torch.Tensor,
rollout_log_prob: Optional[torch.Tensor],
response_mask: torch.Tensor,
) -> dict[str, Any]:
"""Compute training-inference mismatch metrics (helper function).
This helper function operates on raw tensors and is used internally by:
- compute_rollout_importance_weights() in this module (automatically included)
- Tests (test_rollout_is.py, test_rollout_is_integration.py)
These metrics help diagnose the mismatch between the rollout policy (e.g., vLLM)
and the training policy (e.g., FSDP), which can cause training instability.
Key metrics:
- mismatch_kl: Direct KL divergence estimator KL(π_rollout || π_training)
- mismatch_k3_kl: K3 KL estimator for stability (more stable for small KL)
- mismatch_training_ppl: Perplexity of the training policy
- mismatch_rollout_ppl: Perplexity of the rollout policy
- mismatch_log_ppl_diff: Difference in per-sequence log perplexities
- mismatch_ppl_ratio: Ratio of training PPL to rollout PPL (averaged per sequence)
Args:
old_log_prob: Log probabilities from training policy, shape (batch_size, seq_length)
rollout_log_prob: Log probabilities from rollout policy, shape (batch_size, seq_length)
response_mask: Mask for valid tokens, shape (batch_size, seq_length)
Returns:
Dictionary of mismatch metrics (without prefix)
Reference:
- When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
"""
# Validate that we have at least one valid token
assert response_mask.any(), "Expected at least one valid token in response_mask"
metrics = {}
# 1. Training policy perplexity (always available)
# Formula: exp(-1/|T| * Σ log π_training(y_t|y_<t))
# where |T| is the number of tokens generated by the model
mean_log_prob_training = verl_F.masked_mean(old_log_prob, response_mask, axis=-1) # (batch_size,)
training_ppl = torch.exp(-mean_log_prob_training).mean() # Batch mean of per-sequence PPL
metrics["mismatch_training_ppl"] = training_ppl.detach().item()
# Also log log-ppl for easier analysis (avoids exponential scale)
metrics["mismatch_training_log_ppl"] = (-mean_log_prob_training).mean().detach().item()
# 2. Compute rollout mismatch metrics (only if rollout_log_probs available)
if rollout_log_prob is not None:
# 2a. mismatch_kl: Direct estimator for KL(π_rollout || π_training)
# This is the standard KL divergence: E[log(π_rollout) - log(π_training)]
# A positive value means the rollout policy assigns higher probability to the sampled tokens than the training policy does
metrics["mismatch_kl"] = verl_F.masked_mean(rollout_log_prob - old_log_prob, response_mask).detach().item()
# 2b. mismatch_k3_kl: K3 estimator for KL(π_rollout || π_training)
# More stable for small KL values using: E[exp(log_ratio) - log_ratio - 1]
# Formula: KL ≈ E[r - log(r) - 1] where r = π_training/π_rollout
log_ratio = old_log_prob - rollout_log_prob
mismatch_k3_kl_matrix = torch.exp(log_ratio) - log_ratio - 1
metrics["mismatch_k3_kl"] = verl_F.masked_mean(mismatch_k3_kl_matrix, response_mask).detach().item()
# 2c. Rollout policy perplexity
mean_log_prob_rollout = verl_F.masked_mean(rollout_log_prob, response_mask, axis=-1) # (batch_size,)
rollout_ppl = torch.exp(-mean_log_prob_rollout).mean() # Batch mean of per-sequence PPL
metrics["mismatch_rollout_ppl"] = rollout_ppl.detach().item()
metrics["mismatch_rollout_log_ppl"] = (-mean_log_prob_rollout).mean().detach().item()
# 2d. Log PPL difference (sequence-level perplexity difference)
# log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training
# Since ppl = exp(-log_prob), we have:
# log(ppl_ratio) = log(training_ppl/rollout_ppl) = log_ppl_diff
# Positive value means training assigns lower probability (higher PPL) than rollout
log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training
metrics["mismatch_log_ppl_diff"] = log_ppl_diff.mean().detach().item()
metrics["mismatch_log_ppl_abs_diff"] = log_ppl_diff.abs().mean().detach().item()
metrics["mismatch_log_ppl_diff_max"] = log_ppl_diff.max().detach().item()
metrics["mismatch_log_ppl_diff_min"] = log_ppl_diff.min().detach().item()
# 2e. PPL ratio (how much higher is training PPL vs rollout PPL)
# IMPORTANT: Compute per-sequence ratio first, then average
# For numerical stability, compute in log space using log_ppl_diff
# Note: log_ppl_diff = log(ppl_ratio), so ppl_ratio = exp(log_ppl_diff)
# This is the inverse of geometric IS: ppl_ratio_i = 1 / geometric_is_i for each sequence
ppl_ratio = torch.exp(log_ppl_diff).mean() # mean(exp(log_ppl_diff)) = mean(ppl_ratio_i)
metrics["mismatch_ppl_ratio"] = ppl_ratio.detach().item()
return metrics
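A minimal usage sketch of the helper, assuming it is importable as shown in the trainer diff below; tensor shapes and threshold values are illustrative, not defaults:

```python
import torch
from verl.trainer.ppo.mismatch_helper import compute_rollout_importance_weights

batch, seq = 2, 4
old_log_prob = torch.randn(batch, seq).clamp(max=0)               # training-policy log-probs
rollout_log_prob = old_log_prob + 0.01 * torch.randn(batch, seq)  # slightly mismatched rollout
response_mask = torch.ones(batch, seq)

weights_proto, metrics = compute_rollout_importance_weights(
    old_log_prob=old_log_prob,
    rollout_log_prob=rollout_log_prob,
    response_mask=response_mask,
    rollout_is_level="token",        # or "sequence" / "geometric"
    rollout_is_mode="truncate",      # or "clip"
    rollout_is_threshold=2.0,        # upper bound (illustrative)
    rollout_is_threshold_lower=0.5,  # lower bound (used by "clip")
    rollout_is_veto_threshold=1e-4,  # per-token catastrophic-ratio veto
)
print(weights_proto.batch["rollout_is_weights"].shape)              # (2, 4)
print(metrics["mismatch/rollout_is_mean"], metrics["mismatch/mismatch_kl"])
```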


@ -49,6 +49,7 @@ from verl.trainer.ppo.metric_utils import (
compute_timing_metrics,
process_validation_metrics,
)
from verl.trainer.ppo.mismatch_helper import compute_rollout_importance_weights
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi
@ -918,6 +919,49 @@ class RayPPOTrainer:
)
metrics.update(global_balance_stats)
def compute_rollout_importance_weights_and_add_to_batch(self, batch: DataProto) -> tuple[DataProto, dict]:
"""Compute rollout importance sampling weights and mismatch metrics, conditionally add weights to batch.
This method computes IS weights to correct for distribution mismatch between
rollout policy and training policy. It always computes metrics when enabled, but
only adds weights to batch if algorithm.rollout_is is True.
Args:
batch: DataProto containing old_log_probs, rollout_log_probs, response_mask
Returns:
Tuple of (updated_batch, metrics) where:
- updated_batch: Batch with rollout_is_weights added (if rollout_is=True)
- metrics: Dictionary of IS and mismatch metrics (all with mismatch/ prefix)
"""
# Compute rollout IS weights if enabled and data is available
# rollout_is_threshold is the main on/off switch
if self.config.algorithm.rollout_is_threshold is not None and "rollout_log_probs" in batch.batch:
rollout_is_weights, rollout_is_metrics = compute_rollout_importance_weights(
old_log_prob=batch.batch["old_log_probs"],
rollout_log_prob=batch.batch["rollout_log_probs"],
response_mask=batch.batch["response_mask"],
rollout_is_level=self.config.algorithm.rollout_is_level,
rollout_is_mode=self.config.algorithm.rollout_is_mode,
rollout_is_threshold=self.config.algorithm.rollout_is_threshold,
rollout_is_threshold_lower=self.config.algorithm.rollout_is_threshold_lower,
rollout_is_veto_threshold=self.config.algorithm.rollout_is_veto_threshold,
)
# Control: Should we apply weights to policy loss?
# True = add weights to batch (actor will apply them)
# False = don't add weights (metrics only, no loss modification)
apply_weights = self.config.algorithm.get("rollout_is", False)
if apply_weights:
# Add IS weights to batch for distribution to workers
batch = batch.union(rollout_is_weights)
return batch, rollout_is_metrics
# Return unchanged batch and empty metrics if IS is disabled
return batch, {}
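A small sketch of the two-switch behaviour this method implements; `SimpleNamespace` stands in for the real `self.config.algorithm` object and the values are illustrative:

```python
from types import SimpleNamespace

algorithm = SimpleNamespace(
    rollout_is_threshold=2.0,   # not None -> IS weights and mismatch metrics are computed
    rollout_is=False,           # False    -> metrics only; weights are NOT added to the batch
)

compute = algorithm.rollout_is_threshold is not None
apply_to_loss = compute and getattr(algorithm, "rollout_is", False)
print(compute, apply_to_loss)   # True False -> "monitor-only" mode
```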
def fit(self):
"""
The training loop of PPO.
@ -1107,6 +1151,13 @@ class RayPPOTrainer:
else:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
# Compute rollout importance sampling weights centrally (once per batch)
# This corrects for mismatch between rollout policy and training policy
# Also computes mismatch metrics (KL, PPL, etc.)
batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
# IS and mismatch metrics already have mismatch/ prefix
metrics.update(is_metrics)
# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get(
"norm_adv_by_std_in_grpo", True
@ -1205,6 +1256,7 @@ class RayPPOTrainer:
# TODO: implement actual tflpo and theoretical tflpo
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
# Note: mismatch metrics (KL, PPL, etc.) are collected at line 1179 after advantage computation
# this is experimental and may be changed/removed in the future in favor of a general-purpose one
if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):


@ -373,13 +373,10 @@ class DataParallelPPOActor(BasePPOActor):
]
if self.config.use_kl_loss:
select_keys.append("ref_log_prob")
if self.config.tis_imp_ratio_cap > 0:
assert "rollout_log_probs" in data.batch.keys(), (
"Truncated Importance Sampling (TIS) requires to configure "
"`actor_rollout_ref.rollout.calculate_log_probs=True` "
"and is not currently supported in Server mode (agent loop)."
)
select_keys.append("rollout_log_probs")
# Include pre-computed IS weights if present in batch
# Weights are computed centrally in trainer and added to batch when algorithm.rollout_is=True
if "rollout_is_weights" in data.batch.keys():
select_keys.append("rollout_is_weights")
has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
@ -412,7 +409,6 @@ class DataParallelPPOActor(BasePPOActor):
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
response_mask = model_inputs["response_mask"]
old_log_prob = model_inputs["old_log_probs"]
rollout_log_probs = model_inputs["rollout_log_probs"] if self.config.tis_imp_ratio_cap > 0 else None
advantages = model_inputs["advantages"]
entropy_coeff = self.config.entropy_coeff
@ -438,9 +434,21 @@ class DataParallelPPOActor(BasePPOActor):
loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
# vanilla -> verl.trainer.ppo.core_algos.compute_policy_loss_vanilla
# Extract pre-computed rollout importance sampling weights if present
# Weights are computed centrally in trainer and added when algorithm.rollout_is=True
rollout_is_weights = model_inputs.get("rollout_is_weights", None)
# NOTE: Both mismatch diagnostic metrics (PPL, KL, etc.) and IS weight metrics
# are computed centrally in ray_trainer.py for consistency and efficiency.
# This ensures metrics are computed uniformly across all batches at the trainer level
# and avoids redundant computation across workers and micro-batches.
# gpg -> verl.trainer.ppo.core_algos.compute_policy_loss_gpg
# clip_cov -> verl.trainer.ppo.core_algos.compute_policy_loss_clip_cov
policy_loss_fn = get_policy_loss_fn(loss_mode)
# Compute policy loss (all functions return 4 values)
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
old_log_prob=old_log_prob,
log_prob=log_prob,
@ -448,7 +456,7 @@ class DataParallelPPOActor(BasePPOActor):
response_mask=response_mask,
loss_agg_mode=loss_agg_mode,
config=self.config,
rollout_log_probs=rollout_log_probs,
rollout_is_weights=rollout_is_weights,
)
if entropy_coeff != 0:


@ -316,6 +316,10 @@ class MegatronPPOActor(BasePPOActor):
]
if self.config.use_kl_loss:
select_keys.append("ref_log_prob")
# Include pre-computed IS weights if present in batch
# Weights are computed centrally in trainer and added to batch when algorithm.rollout_is=True
if "rollout_is_weights" in data.batch.keys():
select_keys.append("rollout_is_weights")
self.has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
if self.has_multi_modal_inputs:
data = data.select(select_keys, ["multi_modal_inputs"])
@ -419,7 +423,6 @@ class MegatronPPOActor(BasePPOActor):
response_length = responses.size(1)
response_mask = data["response_mask"].to(bool)
loss_agg_mode = self.config.loss_agg_mode
# compute policy loss
log_prob = output["log_probs"][:, -response_length - 1 : -1].contiguous()
ret_entropy = None
@ -434,6 +437,15 @@ class MegatronPPOActor(BasePPOActor):
loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
policy_loss_fn = get_policy_loss_fn(loss_mode)
# Extract pre-computed rollout importance sampling weights if present
# Weights are computed centrally in trainer and added when algorithm.rollout_is=True
rollout_is_weights = data.get("rollout_is_weights", None)
# NOTE: Both mismatch diagnostic metrics (PPL, KL, etc.) and IS weight metrics
# are computed centrally in ray_trainer.py for consistency and efficiency.
# This ensures metrics are computed uniformly across all batches at the trainer level
# and avoids redundant computation across workers and micro-batches.
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
old_log_prob=old_log_prob,
log_prob=log_prob,
@ -441,6 +453,7 @@ class MegatronPPOActor(BasePPOActor):
response_mask=response_mask,
loss_agg_mode=loss_agg_mode,
config=self.config,
rollout_is_weights=rollout_is_weights,
)
stats.update(


@ -106,7 +106,6 @@ class ActorConfig(BaseConfig):
clip_ratio_c: float = 3.0
loss_agg_mode: str = "token-mean"
entropy_coeff: float = 0
tis_imp_ratio_cap: float = -1
use_kl_loss: bool = False
use_torch_compile: bool = True
kl_loss_coef: float = 0.001


@ -60,16 +60,27 @@ class FSDPOptimizerConfig(OptimizerConfig):
Args:
lr (float): Learning rate.
min_lr_ratio (Optional[float]): Minimum LR ratio for cosine schedule.
warmup_style (str): LR warmup style: "constant" or "cosine".
lr_scheduler_type (str): LR scheduler type: "constant" or "cosine".
num_cycles (float): Number of cosine cycles in LR schedule.
"""
_mutable_fields = OptimizerConfig._mutable_fields.copy()
_mutable_fields.add("lr_scheduler_type")
min_lr_ratio: Optional[float] = None
warmup_style: str = "constant"
# deprecate warmup_style
warmup_style: Optional[str] = None
lr_scheduler_type: str = "constant"
num_cycles: float = 0.5
def __post_init__(self):
assert self.warmup_style in ["constant", "cosine"]
if self.warmup_style is not None:
assert self.warmup_style in ["constant", "cosine"]
warnings.warn(
"`warmup_style` is deprecated, use `lr_scheduler_type` instead.", DeprecationWarning, stacklevel=2
)
self.lr_scheduler_type = self.warmup_style
assert self.lr_scheduler_type in ["constant", "cosine"]
return super().__post_init__()
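A hedged sketch of the deprecation shim above. The import path and the minimal constructor arguments are assumptions for illustration only, not the exact API surface:

```python
import warnings
from verl.workers.config import FSDPOptimizerConfig  # hypothetical import path

# Constructing the config with only `lr` and the legacy field is an assumption;
# the real config may require additional OptimizerConfig fields.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cfg = FSDPOptimizerConfig(lr=1e-6, warmup_style="cosine")  # deprecated field

print(cfg.lr_scheduler_type)  # "cosine" -- copied from the deprecated warmup_style
print(caught[0].category)     # DeprecationWarning
```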


@ -370,7 +370,7 @@ class FSDPEngine(BaseEngine):
total_steps = optim_config.total_training_steps
num_warmup_steps = optim_config.lr_warmup_steps
warmup_style = optim_config.warmup_style
lr_scheduler_type = optim_config.lr_scheduler_type
min_lr_ratio = optim_config.min_lr_ratio
num_cycles = optim_config.num_cycles
if num_warmup_steps <= 0:
@ -380,9 +380,9 @@ class FSDPEngine(BaseEngine):
if self.rank == 0:
print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
if warmup_style == "constant":
if lr_scheduler_type == "constant":
lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=num_warmup_steps)
elif warmup_style == "cosine":
elif lr_scheduler_type == "cosine":
lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=num_warmup_steps,
@ -391,7 +391,7 @@ class FSDPEngine(BaseEngine):
num_cycles=num_cycles,
)
else:
raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")
return lr_scheduler
def _build_model_optimizer(self):


@ -529,7 +529,7 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
total_steps = optim_config.get("total_training_steps", 0)
num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1))
warmup_style = optim_config.get("warmup_style", "constant")
lr_scheduler_type = optim_config.get("lr_scheduler_type", "constant")
min_lr_ratio = optim_config.get("min_lr_ratio", 0.0)
num_cycles = optim_config.get("num_cycles", 0.5)
if num_warmup_steps < 0:
@ -539,11 +539,11 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
if self.rank == 0:
print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
if warmup_style == "constant":
if lr_scheduler_type == "constant":
actor_lr_scheduler = get_constant_schedule_with_warmup(
optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps
)
elif warmup_style == "cosine":
elif lr_scheduler_type == "cosine":
actor_lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=actor_optimizer,
num_warmup_steps=num_warmup_steps,
@ -552,7 +552,7 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
num_cycles=num_cycles,
)
else:
raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")
log_gpu_memory_usage(f"After {role} optimizer init", logger=logger)
else:
@ -1386,7 +1386,8 @@ class CriticWorker(Worker, DistProfilerExtension):
total_steps = config.optim.get("total_training_steps", 0)
num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1))
warmup_style = config.optim.get("warmup_style", "constant")
lr_scheduler_type = config.optim.get("lr_scheduler_type", "constant")
if num_warmup_steps < 0:
num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0)
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
@ -1396,11 +1397,11 @@ class CriticWorker(Worker, DistProfilerExtension):
from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup
if warmup_style == "constant":
if lr_scheduler_type == "constant":
critic_lr_scheduler = get_constant_schedule_with_warmup(
optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps
)
elif warmup_style == "cosine":
elif lr_scheduler_type == "cosine":
min_lr_ratio = config.optim.get("min_lr_ratio", 0.0)
num_cycles = config.optim.get("num_cycles", 0.5)
critic_lr_scheduler = get_cosine_schedule_with_warmup(
@ -1411,7 +1412,7 @@ class CriticWorker(Worker, DistProfilerExtension):
num_cycles=num_cycles,
)
else:
raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")
return critic_module, critic_optimizer, critic_lr_scheduler