mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
Compare commits
5 Commits
7ddb9b29f0
...
21271aabb9
Author | SHA1 | Date | |
---|---|---|---|
21271aabb9 | |||
7f27789961 | |||
e9ee6b39c6 | |||
9d4554b931 | |||
71cf69e7ad |
2
.github/workflows/e2e_sft.yml
vendored
2
.github/workflows/e2e_sft.yml
vendored
@ -91,7 +91,7 @@ jobs:
|
||||
e2e_sft:
|
||||
needs: setup
|
||||
runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"]
|
||||
timeout-minutes: 25 # Increase this timeout value as needed
|
||||
timeout-minutes: 30 # Increase this timeout value as needed
|
||||
env:
|
||||
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
|
||||
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
|
||||
|
642
docs/advance/rollout_is_migration.md
Normal file
642
docs/advance/rollout_is_migration.md
Normal file
@ -0,0 +1,642 @@
|
||||
# Rollout Importance Sampling - Migration Guide
|
||||
|
||||
Last updated: 10/11/2025.
|
||||
|
||||
This document provides a comprehensive overview of the Rollout Importance Sampling (IS) implementation merged from aiic_verl into verl.
|
||||
|
||||
## References
|
||||
|
||||
- **When Speed Kills Stability**: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
|
||||
- **Off-policy RL**: https://fengyao.notion.site/off-policy-rl
|
||||
|
||||
## Overview
|
||||
|
||||
Rollout Importance Sampling corrects for distribution mismatch between:
|
||||
- **Rollout policy**: e.g., vLLM with BFloat16
|
||||
- **Training policy**: e.g., FSDP with FP32
|
||||
|
||||
This mismatch can lead to biased gradient estimates and unstable training. Rollout IS applies importance sampling weights to correct these biases.
|
||||
|
||||
## What Changed
|
||||
|
||||
### **Removed (Old Implementation)**
|
||||
|
||||
```yaml
|
||||
# Old TIS configuration (REMOVED)
|
||||
actor:
|
||||
tis_imp_ratio_cap: 2.0 # ❌ No longer supported
|
||||
```
|
||||
|
||||
The old implementation:
|
||||
- Only supported token-level truncate mode
|
||||
- Had no metrics tracking
|
||||
- Lacked numerical stability safeguards
|
||||
- No configurability for different scenarios
|
||||
|
||||
### **Added (New Implementation)**
|
||||
|
||||
```yaml
|
||||
# New Rollout IS configuration (all in algorithm config)
|
||||
algorithm:
|
||||
# Main control: set threshold to enable (null = disabled)
|
||||
rollout_is_threshold: 2.0
|
||||
# Whether to apply weights to loss (default: false = metrics only)
|
||||
rollout_is: true
|
||||
rollout_is_threshold_lower: null # Auto-reciprocal
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: truncate
|
||||
rollout_is_veto_threshold: 1e-4
|
||||
|
||||
# REQUIRED: Enable log prob calculation
|
||||
actor_rollout_ref:
|
||||
rollout:
|
||||
calculate_log_probs: true
|
||||
```
|
||||
|
||||
The new implementation:
|
||||
- ✅ Three aggregation levels: token, sequence, geometric
|
||||
- ✅ Two bounding modes: truncate, clip
|
||||
- ✅ Dual threshold support (upper/lower)
|
||||
- ✅ Veto mechanism for catastrophic outliers
|
||||
- ✅ 30+ comprehensive metrics
|
||||
- ✅ Log-space computation for numerical stability
|
||||
- ✅ Memory-efficient implementation
|
||||
|
||||
## Files Modified
|
||||
|
||||
### **Core Implementation**
|
||||
|
||||
1. **NEW**: `verl/trainer/ppo/mismatch_helper.py`
|
||||
- Contains `compute_rollout_importance_weights()` - main function
|
||||
- Contains `compute_is_metrics()` - comprehensive metrics
|
||||
|
||||
2. **MODIFIED**: `verl/trainer/ppo/core_algos.py` (lines 962-991)
|
||||
- Replaced old TIS implementation (lines 962-967)
|
||||
- Added new rollout IS with metrics support
|
||||
|
||||
3. **MODIFIED**: `verl/workers/actor/dp_actor.py`
|
||||
- Updated to use `rollout_is_threshold` instead of `tis_imp_ratio_cap`
|
||||
- Collects and logs all rollout IS metrics
|
||||
|
||||
### **Configuration Files**
|
||||
|
||||
4. **MODIFIED**: `verl/trainer/config/algorithm.py` (lines 95-100)
|
||||
- Added 6 new rollout IS parameters to `AlgoConfig`
|
||||
|
||||
5. **MODIFIED**: `verl/workers/config/actor.py` (lines 110-115)
|
||||
- Added 6 new rollout IS parameters to `ActorConfig`
|
||||
|
||||
6. **MODIFIED**: `verl/trainer/config/actor/actor.yaml` (lines 77-89)
|
||||
- Added rollout IS configuration section
|
||||
|
||||
7. **MODIFIED**: `verl/trainer/config/ppo_trainer.yaml` (lines 116-133)
|
||||
- Added rollout IS to algorithm config
|
||||
|
||||
### **Documentation**
|
||||
|
||||
8. **MODIFIED**: `docs/examples/config.rst`
|
||||
- Updated actor config with rollout IS parameters
|
||||
- Updated algorithm config with rollout IS parameters
|
||||
- Added detailed parameter descriptions
|
||||
|
||||
### **Example Scripts**
|
||||
|
||||
9. **MODIFIED**: `recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`
|
||||
- Updated from `tis_imp_ratio_cap` to rollout IS parameters
|
||||
- Added comprehensive comments
|
||||
|
||||
10. **NEW**: `examples/rollout_importance_sampling/README.md`
|
||||
- Comprehensive guide with usage patterns
|
||||
- Troubleshooting section
|
||||
- Performance considerations
|
||||
|
||||
11. **NEW**: `examples/rollout_importance_sampling/run_with_rollout_is.sh`
|
||||
- Basic example with token-level truncate
|
||||
|
||||
### **Tests**
|
||||
|
||||
12. **NEW**: `tests/trainer/ppo/test_rollout_is.py`
|
||||
- Unit tests for rollout IS functionality
|
||||
|
||||
13. **NEW**: `tests/trainer/ppo/test_rollout_is_integration.py`
|
||||
- Integration tests with PPO
|
||||
|
||||
## Configuration Parameters
|
||||
|
||||
### `algorithm.rollout_is_threshold` (float or null)
|
||||
**Main on/off switch.** Upper threshold for IS weights.
|
||||
- `null` = disabled (no computation, no metrics)
|
||||
- `float` value (e.g., 2.0) = enabled (compute weights and metrics)
|
||||
|
||||
### `algorithm.rollout_is` (bool)
|
||||
Whether to apply IS weights to policy loss. Default: `False`
|
||||
- `true` = apply weights to loss (full IS correction)
|
||||
- `false` = compute metrics only (useful for monitoring before enabling)
|
||||
|
||||
**Recommended threshold ranges:**
|
||||
- Token level: 1.5 - 5.0
|
||||
- Sequence level: 2.0 - 10.0
|
||||
- Geometric level: 1.0002 - 1.001
|
||||
|
||||
### `algorithm.rollout_is_threshold_lower` (float or null)
|
||||
Lower threshold for IS weights. If `null`, defaults to 1/upper (reciprocal).
|
||||
|
||||
### `algorithm.rollout_is_level` (str)
|
||||
Aggregation level for IS weights:
|
||||
- `"token"`: Per-token ratios
|
||||
- `"sequence"`: Product of ratios
|
||||
- `"geometric"`: Geometric mean (experimental)
|
||||
|
||||
### `algorithm.rollout_is_mode` (str)
|
||||
Bounding mode:
|
||||
- `"truncate"`: Cap weights at upper threshold only
|
||||
- `"clip"`: Zero out weights outside [lower, upper]
|
||||
|
||||
### `algorithm.rollout_is_veto_threshold` (float)
|
||||
Per-token veto threshold. If any token ratio < this, entire sequence is rejected.
|
||||
Default: `1e-4` (ratio 10,000x off)
|
||||
|
||||
## Migration Steps
|
||||
|
||||
### Step 1: Update Your Configuration
|
||||
|
||||
**Before (Old):**
|
||||
```yaml
|
||||
actor_rollout_ref:
|
||||
actor:
|
||||
tis_imp_ratio_cap: 2.0
|
||||
rollout:
|
||||
calculate_log_probs: true
|
||||
```
|
||||
|
||||
**After (New):**
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 2.0 # Main control
|
||||
rollout_is: true # Apply to loss (default: false)
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: truncate
|
||||
|
||||
actor_rollout_ref:
|
||||
rollout:
|
||||
calculate_log_probs: true # Still required!
|
||||
```
|
||||
|
||||
### Step 2: Monitor New Metrics
|
||||
|
||||
All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appears as `mismatch/rollout_is_mean` in logs.
|
||||
|
||||
#### **Core IS Weight Metrics**
|
||||
|
||||
- **`rollout_is_mean`**: Mean importance sampling weight across all valid tokens
|
||||
- **Ideal value**: Close to 1.0 (indicates minimal distribution mismatch)
|
||||
- **Warning**: < 0.5 or > 2.0 suggests significant policy mismatch
|
||||
|
||||
- **`rollout_is_std`**: Standard deviation of IS weights
|
||||
- **Ideal value**: < 0.5 for stable training
|
||||
- **Warning**: > 1.0 indicates high variance, may need tighter thresholds
|
||||
|
||||
- **`rollout_is_min`**: Minimum IS weight observed
|
||||
- Shows the most underweighted token/sequence
|
||||
|
||||
- **`rollout_is_max`**: Maximum IS weight observed (before clipping)
|
||||
- Shows the most overweighted token/sequence
|
||||
- Compare with `rollout_is_threshold` to see truncation impact
|
||||
|
||||
#### **Percentile Metrics**
|
||||
|
||||
- **`rollout_is_p25`**: 25th percentile of IS weights
|
||||
- **`rollout_is_p50`**: Median IS weight (50th percentile)
|
||||
- Should be close to `rollout_is_mean` if distribution is symmetric
|
||||
- **`rollout_is_p75`**: 75th percentile of IS weights
|
||||
- **`rollout_is_p95`**: 95th percentile of IS weights
|
||||
- Use to detect outliers
|
||||
- **`rollout_is_p99`**: 99th percentile of IS weights
|
||||
- Should be close to `rollout_is_threshold` if truncation is working
|
||||
|
||||
#### **Effective Sample Size**
|
||||
|
||||
- **`rollout_is_eff_sample_size`**: Effective sample size after IS weighting
|
||||
- **Formula**: `1 / mean(weights²)` where weights are normalized
|
||||
- **Range**: 0.0 to 1.0 (as fraction of original batch)
|
||||
- **Ideal value**: > 0.5 (retaining at least 50% effective samples)
|
||||
- **Warning**: < 0.3 means high variance, losing too many effective samples
|
||||
|
||||
#### **Veto Mechanism Metrics**
|
||||
|
||||
- **`rollout_is_veto_fraction`**: Fraction of sequences rejected by veto mechanism
|
||||
- **Ideal value**: < 0.05 (less than 5% vetoed)
|
||||
- **Warning**: > 0.1 suggests policies are too different or numerical issues
|
||||
|
||||
- **`rollout_is_catastrophic_token_fraction`**: Fraction of tokens below veto threshold
|
||||
- Identifies problematic tokens before sequence-level veto
|
||||
- **Warning**: > 0.01 indicates widespread distribution issues
|
||||
|
||||
#### **Threshold Exceedance Metrics**
|
||||
|
||||
- **`rollout_is_ratio_fraction_high`**: Fraction of weights exceeding upper threshold
|
||||
- Shows how often truncation/clipping occurs on high end
|
||||
- **Ideal value**: < 0.1 (most weights within bounds)
|
||||
|
||||
- **`rollout_is_ratio_fraction_low`**: Fraction of weights below lower threshold
|
||||
- Shows how often clipping occurs on low end (clip mode only)
|
||||
- **Ideal value**: < 0.1
|
||||
|
||||
#### **Sequence-Level Metrics** (for sequence/geometric modes)
|
||||
|
||||
- **`rollout_is_seq_mean`**: Mean IS weight at sequence level
|
||||
- Should match `rollout_is_mean` for sequence-level aggregation
|
||||
|
||||
- **`rollout_is_seq_std`**: Standard deviation of sequence-level IS weights
|
||||
|
||||
- **`rollout_is_seq_min`**: Minimum sequence-level IS weight
|
||||
|
||||
- **`rollout_is_seq_max`**: Maximum sequence-level IS weight
|
||||
|
||||
- **`rollout_is_seq_max_deviation`**: Maximum absolute deviation from 1.0 at sequence level
|
||||
- **Ideal value**: < 1.0
|
||||
- Shows worst-case sequence mismatch
|
||||
|
||||
- **`rollout_is_seq_fraction_high`**: Fraction of sequences exceeding upper threshold
|
||||
|
||||
- **`rollout_is_seq_fraction_low`**: Fraction of sequences below lower threshold
|
||||
|
||||
#### **Clipping Metrics** (clip mode only)
|
||||
|
||||
- **`rollout_is_clipped_fraction`**: Fraction of tokens clipped (set to zero)
|
||||
- **Ideal value**: < 0.1
|
||||
- **Warning**: > 0.3 means losing too much data
|
||||
|
||||
- **`rollout_is_seq_clipped_fraction`**: Fraction of sequences with at least one clipped token
|
||||
- Shows sequence-level impact of clipping
|
||||
|
||||
#### **Distribution Mismatch Metrics** (Training vs Rollout Policy)
|
||||
|
||||
- **`mismatch_training_ppl`**: Perplexity of training policy (e.g., FSDP FP32)
|
||||
- **Formula**: `exp(-mean(log_probs))`
|
||||
- Lower is better (model is more confident)
|
||||
|
||||
- **`mismatch_rollout_ppl`**: Perplexity of rollout policy (e.g., vLLM BF16)
|
||||
- Should be close to `mismatch_training_ppl` if policies match well
|
||||
|
||||
- **`mismatch_ppl_ratio`**: Ratio of training PPL to rollout PPL
|
||||
- **Formula**: `exp(mean(log(training_ppl / rollout_ppl)))`
|
||||
- **Ideal value**: Close to 1.0
|
||||
- **Meaning**: > 1.0 means training is less confident than rollout
|
||||
|
||||
- **`mismatch_training_log_ppl`**: Log perplexity of training policy
|
||||
- Useful for identifying trends (linear scale)
|
||||
|
||||
- **`mismatch_rollout_log_ppl`**: Log perplexity of rollout policy
|
||||
|
||||
- **`mismatch_log_ppl_diff`**: Mean difference in log perplexities
|
||||
- **Formula**: `mean(log_ppl_rollout - log_ppl_training)`
|
||||
- **Ideal value**: Close to 0.0
|
||||
- Sign indicates which policy is more confident
|
||||
|
||||
- **`mismatch_log_ppl_abs_diff`**: Mean absolute log perplexity difference
|
||||
- Magnitude of mismatch regardless of direction
|
||||
|
||||
- **`mismatch_log_ppl_diff_max`**: Maximum log perplexity difference across sequences
|
||||
- Identifies worst-case sequence
|
||||
|
||||
- **`mismatch_log_ppl_diff_min`**: Minimum log perplexity difference across sequences
|
||||
|
||||
- **`mismatch_kl`**: KL divergence KL(π_rollout || π_training)
|
||||
- **Formula**: `mean(log_prob_rollout - log_prob_training)`
|
||||
- **Ideal value**: Close to 0.0 (policies match)
|
||||
- **Warning**: > 0.1 indicates significant mismatch
|
||||
- **Note**: Can be negative (rollout is less confident)
|
||||
|
||||
- **`mismatch_k3_kl`**: K3 KL estimator
|
||||
- **Formula**: `mean(exp(log_ratio) - log_ratio - 1)`
|
||||
- More stable for small KL values
|
||||
- Always non-negative
|
||||
|
||||
#### **Example: Accessing Metrics in Code**
|
||||
|
||||
```python
|
||||
# Metrics are returned from compute_rollout_importance_weights
|
||||
from verl.trainer.ppo.mismatch_helper import compute_rollout_importance_weights
|
||||
|
||||
weights_proto, metrics = compute_rollout_importance_weights(
|
||||
old_log_prob=training_log_probs, # from training policy
|
||||
rollout_log_prob=rollout_log_probs, # from rollout policy
|
||||
response_mask=response_mask,
|
||||
rollout_is_level="token",
|
||||
rollout_is_mode="truncate",
|
||||
rollout_is_threshold=2.0,
|
||||
rollout_is_veto_threshold=1e-4,
|
||||
)
|
||||
|
||||
# All metrics have 'mismatch/' prefix
|
||||
print(f"Mean IS weight: {metrics['mismatch/rollout_is_mean']:.3f}")
|
||||
print(f"Effective sample size: {metrics['mismatch/rollout_is_eff_sample_size']:.3f}")
|
||||
print(f"Veto fraction: {metrics['mismatch/rollout_is_veto_fraction']:.3f}")
|
||||
print(f"KL divergence: {metrics['mismatch/mismatch_kl']:.3f}")
|
||||
|
||||
# Check for warning conditions
|
||||
if metrics['mismatch/rollout_is_mean'] < 0.5 or metrics['mismatch/rollout_is_mean'] > 2.0:
|
||||
print("⚠️ Warning: Mean IS weight far from 1.0, significant policy mismatch detected")
|
||||
|
||||
if metrics['mismatch/rollout_is_eff_sample_size'] < 0.3:
|
||||
print("⚠️ Warning: Low effective sample size, high variance in IS weights")
|
||||
|
||||
if metrics['mismatch/rollout_is_veto_fraction'] > 0.1:
|
||||
print("⚠️ Warning: High veto fraction, policies may be too different")
|
||||
```
|
||||
|
||||
#### **Example: Monitoring Metrics During Training**
|
||||
|
||||
```python
|
||||
# In your training loop
|
||||
for epoch in range(num_epochs):
|
||||
for batch_idx, batch in enumerate(dataloader):
|
||||
# ... rollout phase ...
|
||||
|
||||
# Compute IS weights and get metrics
|
||||
weights_proto, metrics = compute_rollout_importance_weights(
|
||||
old_log_prob=batch.old_log_prob,
|
||||
rollout_log_prob=batch.rollout_log_prob,
|
||||
response_mask=batch.response_mask,
|
||||
rollout_is_level=config.rollout_is_level,
|
||||
rollout_is_mode=config.rollout_is_mode,
|
||||
rollout_is_threshold=config.rollout_is_threshold,
|
||||
rollout_is_veto_threshold=config.rollout_is_veto_threshold,
|
||||
)
|
||||
|
||||
# Log to tensorboard/wandb
|
||||
for metric_name, metric_value in metrics.items():
|
||||
logger.log_scalar(metric_name, metric_value, step=global_step)
|
||||
|
||||
# Use IS weights in training
|
||||
is_weights = weights_proto.batch["rollout_is_weights"]
|
||||
# ... apply weights to policy gradient ...
|
||||
```
|
||||
|
||||
#### **Example: Conditional Alerting Based on Metrics**
|
||||
|
||||
```python
|
||||
def check_rollout_is_health(metrics, config):
|
||||
"""Check if rollout IS metrics indicate healthy training."""
|
||||
warnings = []
|
||||
|
||||
# Check mean IS weight
|
||||
mean_weight = metrics['mismatch/rollout_is_mean']
|
||||
if mean_weight < 0.5 or mean_weight > 2.0:
|
||||
warnings.append(f"Mean IS weight {mean_weight:.3f} is far from 1.0")
|
||||
|
||||
# Check effective sample size
|
||||
ess = metrics['mismatch/rollout_is_eff_sample_size']
|
||||
if ess < 0.3:
|
||||
warnings.append(f"Effective sample size {ess:.3f} is too low")
|
||||
|
||||
# Check veto fraction
|
||||
veto_frac = metrics['mismatch/rollout_is_veto_fraction']
|
||||
if veto_frac > 0.1:
|
||||
warnings.append(f"Veto fraction {veto_frac:.3f} is too high")
|
||||
|
||||
# Check variance
|
||||
std = metrics['mismatch/rollout_is_std']
|
||||
if std > 1.0:
|
||||
warnings.append(f"IS weight std {std:.3f} is too high")
|
||||
|
||||
# Check KL divergence
|
||||
kl = metrics['mismatch/mismatch_kl']
|
||||
if abs(kl) > 0.1:
|
||||
warnings.append(f"KL divergence {kl:.3f} indicates significant mismatch")
|
||||
|
||||
if warnings:
|
||||
print("⚠️ Rollout IS Health Warnings:")
|
||||
for warning in warnings:
|
||||
print(f" - {warning}")
|
||||
return False
|
||||
else:
|
||||
print("✅ Rollout IS metrics look healthy")
|
||||
return True
|
||||
|
||||
# Use in training
|
||||
_, metrics = compute_rollout_importance_weights(...)
|
||||
is_healthy = check_rollout_is_health(metrics, config)
|
||||
|
||||
if not is_healthy:
|
||||
# Consider adjusting config or investigating issues
|
||||
print("Consider:")
|
||||
print(" - Tightening rollout_is_threshold")
|
||||
print(" - Switching to geometric aggregation level")
|
||||
print(" - Checking if rollout and training policies are too different")
|
||||
```
|
||||
|
||||
### Step 3: Test Your Training
|
||||
|
||||
Start with the basic token-level truncate configuration:
|
||||
```bash
|
||||
bash examples/rollout_importance_sampling/run_with_rollout_is.sh
|
||||
```
|
||||
|
||||
Monitor metrics for 1-2 epochs before adjusting parameters.
|
||||
|
||||
## Configuration Examples
|
||||
|
||||
### Example 1: Full IS Correction
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 2.0
|
||||
rollout_is: true # Apply weights to loss
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: truncate
|
||||
```
|
||||
|
||||
### Example 2: Metrics Only (Monitoring Mode)
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 2.0
|
||||
rollout_is: false # Compute metrics, don't apply weights
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: truncate
|
||||
```
|
||||
|
||||
### Example 3: Geometric Mean with Clip
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 1.0002
|
||||
rollout_is: true
|
||||
rollout_is_threshold_lower: 0.9998
|
||||
rollout_is_level: geometric
|
||||
rollout_is_mode: clip
|
||||
```
|
||||
|
||||
### Example 4: Asymmetric Thresholds
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 5.0
|
||||
rollout_is: true
|
||||
rollout_is_threshold_lower: 0.8
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: clip
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: High variance in IS weights
|
||||
**Symptoms:** `rollout_is_std` > 1.0, `rollout_is_eff_sample_size` < 0.3
|
||||
|
||||
**Solutions:**
|
||||
1. Switch from `sequence` to `geometric` level
|
||||
2. Tighten thresholds
|
||||
3. Verify rollout and training aren't too different
|
||||
|
||||
### Issue: Too many sequences vetoed
|
||||
**Symptoms:** `rollout_is_veto_fraction` > 0.1
|
||||
|
||||
**Solutions:**
|
||||
1. Relax veto threshold: `rollout_is_veto_threshold: 1e-3`
|
||||
2. Check for numerical issues in log prob computation
|
||||
3. Verify policies aren't completely different
|
||||
|
||||
### Issue: Mean IS weight far from 1.0
|
||||
**Symptoms:** `rollout_is_mean` < 0.5 or > 2.0
|
||||
|
||||
**Solutions:**
|
||||
1. Verify `calculate_log_probs=True` is set
|
||||
2. Check rollout_log_probs are correctly passed
|
||||
3. Check for systematic bias
|
||||
|
||||
### Debugging: Visualizing Metrics
|
||||
|
||||
**Example: Plot IS weight distribution**
|
||||
|
||||
```python
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
def plot_is_metrics(metrics_history):
|
||||
"""Plot rollout IS metrics over training steps."""
|
||||
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
|
||||
|
||||
# Plot 1: Mean IS weight over time
|
||||
axes[0, 0].plot(metrics_history['mismatch/rollout_is_mean'])
|
||||
axes[0, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
|
||||
axes[0, 0].set_title('Mean IS Weight')
|
||||
axes[0, 0].set_xlabel('Step')
|
||||
axes[0, 0].legend()
|
||||
|
||||
# Plot 2: Effective sample size
|
||||
axes[0, 1].plot(metrics_history['mismatch/rollout_is_eff_sample_size'])
|
||||
axes[0, 1].axhline(y=0.5, color='g', linestyle='--', label='Good')
|
||||
axes[0, 1].axhline(y=0.3, color='r', linestyle='--', label='Warning')
|
||||
axes[0, 1].set_title('Effective Sample Size')
|
||||
axes[0, 1].set_xlabel('Step')
|
||||
axes[0, 1].legend()
|
||||
|
||||
# Plot 3: Veto fraction
|
||||
axes[0, 2].plot(metrics_history['mismatch/rollout_is_veto_fraction'])
|
||||
axes[0, 2].axhline(y=0.1, color='r', linestyle='--', label='Warning')
|
||||
axes[0, 2].set_title('Veto Fraction')
|
||||
axes[0, 2].set_xlabel('Step')
|
||||
axes[0, 2].legend()
|
||||
|
||||
# Plot 4: IS weight distribution (latest step)
|
||||
latest_idx = -1
|
||||
percentiles = [25, 50, 75, 95, 99]
|
||||
values = [metrics_history[f'mismatch/rollout_is_p{p}'][latest_idx] for p in percentiles]
|
||||
axes[1, 0].bar([f'p{p}' for p in percentiles], values)
|
||||
axes[1, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
|
||||
axes[1, 0].set_title('IS Weight Percentiles (Latest)')
|
||||
axes[1, 0].legend()
|
||||
|
||||
# Plot 5: KL divergence over time
|
||||
axes[1, 1].plot(metrics_history['mismatch/mismatch_kl'], label='KL')
|
||||
axes[1, 1].plot(metrics_history['mismatch/mismatch_k3_kl'], label='K3 KL')
|
||||
axes[1, 1].axhline(y=0, color='g', linestyle='--', alpha=0.3)
|
||||
axes[1, 1].set_title('KL Divergence')
|
||||
axes[1, 1].set_xlabel('Step')
|
||||
axes[1, 1].legend()
|
||||
|
||||
# Plot 6: PPL ratio over time
|
||||
axes[1, 2].plot(metrics_history['mismatch/mismatch_ppl_ratio'])
|
||||
axes[1, 2].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
|
||||
axes[1, 2].set_title('PPL Ratio (Training/Rollout)')
|
||||
axes[1, 2].set_xlabel('Step')
|
||||
axes[1, 2].legend()
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig('rollout_is_metrics.png', dpi=150)
|
||||
print("Saved plot to rollout_is_metrics.png")
|
||||
```
|
||||
|
||||
**Example: Metric collection during training**
|
||||
|
||||
```python
|
||||
# Collect metrics over time
|
||||
metrics_history = {
|
||||
'mismatch/rollout_is_mean': [],
|
||||
'mismatch/rollout_is_eff_sample_size': [],
|
||||
'mismatch/rollout_is_veto_fraction': [],
|
||||
'mismatch/rollout_is_p25': [],
|
||||
'mismatch/rollout_is_p50': [],
|
||||
'mismatch/rollout_is_p75': [],
|
||||
'mismatch/rollout_is_p95': [],
|
||||
'mismatch/rollout_is_p99': [],
|
||||
'mismatch/mismatch_kl': [],
|
||||
'mismatch/mismatch_k3_kl': [],
|
||||
'mismatch/mismatch_ppl_ratio': [],
|
||||
}
|
||||
|
||||
# In training loop
|
||||
for step in range(num_steps):
|
||||
# ... compute IS weights ...
|
||||
_, metrics = compute_rollout_importance_weights(...)
|
||||
|
||||
# Store metrics
|
||||
for key in metrics_history.keys():
|
||||
if key in metrics:
|
||||
metrics_history[key].append(metrics[key])
|
||||
|
||||
# Plot every 100 steps
|
||||
if step % 100 == 0:
|
||||
plot_is_metrics(metrics_history)
|
||||
```
|
||||
|
||||
## Performance Impact
|
||||
|
||||
- **Memory overhead**: ~1% of model memory
|
||||
- **Computational overhead**: 1-3% depending on level
|
||||
- **Training stability**: Significantly improved when mismatch exists
|
||||
|
||||
## Backward Compatibility
|
||||
|
||||
**The old `tis_imp_ratio_cap` parameter is completely removed.** There is no backward compatibility mode.
|
||||
|
||||
All scripts and configurations must be updated to use the new rollout IS parameters.
|
||||
|
||||
## Testing
|
||||
|
||||
Run the test suite to verify everything works:
|
||||
|
||||
```bash
|
||||
# Basic unit tests
|
||||
python test_rollout_is.py
|
||||
|
||||
# Integration tests (if pytest is available)
|
||||
pytest tests/trainer/ppo/test_rollout_is_integration.py -v
|
||||
```
|
||||
|
||||
Expected output: All tests pass ✓
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- **Implementation**: `verl/trainer/ppo/mismatch_helper.py`
|
||||
- **Examples**: `examples/rollout_importance_sampling/`
|
||||
- **DAPO Example**: `recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`
|
||||
|
||||
## Summary
|
||||
|
||||
The new Rollout Importance Sampling implementation provides:
|
||||
- ✅ More robust handling of distribution mismatch
|
||||
- ✅ Better numerical stability
|
||||
- ✅ Comprehensive metrics for monitoring
|
||||
- ✅ Flexibility for different scenarios
|
||||
- ✅ Memory-efficient computation
|
||||
|
||||
Migration is straightforward: replace `tis_imp_ratio_cap` with the new `rollout_is_*` parameters in the `algorithm` config section.
|
@ -118,7 +118,13 @@ Actor/Rollout/Reference Policy
|
||||
clip_ratio: 0.2
|
||||
entropy_coeff: 0.0
|
||||
use_kl_loss: False # True for GRPO
|
||||
tis_imp_ratio_cap: -1 # set to positive values for Truncated Importance Sampling (requires setting `rollout.calculate_log_probs` as True)
|
||||
# Rollout Importance Sampling (corrects distribution mismatch between rollout and training)
|
||||
rollout_is: False # Enable IS correction
|
||||
rollout_is_threshold: null # Upper threshold for IS weights (null to disable)
|
||||
rollout_is_threshold_lower: null # Lower threshold (null = auto 1/upper)
|
||||
rollout_is_level: token # Aggregation: token/sequence/geometric
|
||||
rollout_is_mode: truncate # Bounding: truncate/clip
|
||||
rollout_is_veto_threshold: 1e-4 # Catastrophic outlier threshold
|
||||
use_torch_compile: True # False to disable torch compile
|
||||
kl_loss_coef: 0.001 # for grpo
|
||||
kl_loss_type: low_var_kl # for grpo
|
||||
@ -132,7 +138,7 @@ Actor/Rollout/Reference Policy
|
||||
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
|
||||
min_lr_ratio: 0.0 # only used with cosine lr scheduler, default to 0.0
|
||||
num_cycles: 0.5 # only used with cosine lr scheduler, default to 0.5
|
||||
warmup_style: constant # select from constant/cosine
|
||||
lr_scheduler_type: constant # select from constant/cosine
|
||||
total_training_steps: -1 # must be override by program
|
||||
fsdp_config:
|
||||
wrap_policy:
|
||||
@ -415,7 +421,7 @@ ____________________________________________________
|
||||
|
||||
Notice that there are some differences in APIs between Megatron optimizer and FSDP optimizer.
|
||||
|
||||
- Megatron optimizer scheduler names the period after lr_warmup as lr_decay_steps, so the ``warmup_style`` actually means the style of lr decay after warmup.
|
||||
- Megatron optimizer scheduler names the period after lr_warmup as lr_decay_steps, so the ``lr_scheduler_type`` actually means the style of lr decay after warmup.
|
||||
- Megatron optimizer also support weight decay decay mechanism
|
||||
- ``use_checkpoint_opt_param_scheduler`` determines whether to use the checkpoint optimizer parameter scheduler. If set to True, the optimizer parameter scheduler will be saved in the checkpoint and loaded from the checkpoint during resuming training.
|
||||
|
||||
@ -498,6 +504,13 @@ Algorithm
|
||||
kl_coef: 0.005
|
||||
horizon: 10000
|
||||
target_kl: 0.1
|
||||
# Rollout Importance Sampling
|
||||
rollout_is: False
|
||||
rollout_is_threshold: null
|
||||
rollout_is_threshold_lower: null
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: truncate
|
||||
rollout_is_veto_threshold: 1e-4
|
||||
|
||||
- ``gamma``: discount factor
|
||||
- ``lam``: Trade-off between bias and variance in the GAE estimator
|
||||
@ -510,6 +523,13 @@ Algorithm
|
||||
- ``kl_coef``: The (initial) coefficient of in-reward kl_penalty. Default is 0.001.
|
||||
- ``type``: 'fixed' for FixedKLController and 'adaptive' for AdaptiveKLController.
|
||||
- ``horizon`` and ``target_kl``: See source code of AdaptiveKLController for details.
|
||||
- ``rollout_is``: Whether to enable rollout importance sampling correction. Default is False.
|
||||
- ``rollout_is_threshold``: Upper threshold for IS weights. Set to ``null`` to disable IS completely.
|
||||
- ``rollout_is_threshold_lower``: Lower threshold for IS weights. If ``null``, defaults to reciprocal of upper (1/upper).
|
||||
- ``rollout_is_level``: Aggregation level: ``token`` (biased), ``sequence`` (unbiased), or ``geometric`` (experimental).
|
||||
- ``rollout_is_mode``: Bounding mode: ``truncate`` (cap upper only) or ``clip`` (zero outside bounds).
|
||||
- ``rollout_is_veto_threshold``: Per-token veto threshold for catastrophic outliers. Default is 1e-4.
|
||||
Note: Rollout IS requires setting ``actor_rollout_ref.rollout.calculate_log_probs=True``.
|
||||
|
||||
Trainer
|
||||
~~~~~~~
|
||||
|
@ -121,6 +121,7 @@ verl is fast with:
|
||||
examples/sandbox_fusion_example
|
||||
advance/rollout_trace.rst
|
||||
advance/rollout_skip.rst
|
||||
advance/rollout_is_migration.md
|
||||
advance/one_step_off
|
||||
advance/agent_loop
|
||||
|
||||
|
242
examples/rollout_importance_sampling/README.md
Normal file
242
examples/rollout_importance_sampling/README.md
Normal file
@ -0,0 +1,242 @@
|
||||
# Rollout Importance Sampling (IS) Examples
|
||||
|
||||
This directory contains examples and documentation for using Rollout Importance Sampling to correct distribution mismatch between rollout and training policies.
|
||||
|
||||
**References:**
|
||||
- When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
|
||||
- Off-policy RL: https://fengyao.notion.site/off-policy-rl
|
||||
|
||||
## Overview
|
||||
|
||||
Rollout Importance Sampling corrects for distribution mismatch when:
|
||||
1. **Rollout generation** uses one policy (e.g., vLLM with BFloat16)
|
||||
2. **Training** uses another policy (e.g., FSDP with FP32)
|
||||
3. This mismatch leads to biased gradient estimates
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
algorithm:
|
||||
# Main control: set threshold to enable (null = disabled)
|
||||
rollout_is_threshold: 2.0
|
||||
# Whether to apply weights to policy loss (true) or just compute metrics (false)
|
||||
rollout_is: true
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: truncate
|
||||
|
||||
# IMPORTANT: Must enable log prob calculation
|
||||
actor_rollout_ref:
|
||||
rollout:
|
||||
calculate_log_probs: true
|
||||
```
|
||||
|
||||
### Running the Example
|
||||
|
||||
```bash
|
||||
# Basic example with token-level truncate
|
||||
bash examples/rollout_importance_sampling/run_with_rollout_is.sh
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Aggregation Levels (`rollout_is_level`)
|
||||
|
||||
| Level | Properties | Threshold Range |
|
||||
|-------|-----------|-----------------|
|
||||
| **token** | Per-token | 1.5 - 5.0 |
|
||||
| **sequence** | Per-sequence | 2.0 - 10.0 |
|
||||
| **geometric** | Geometric mean | 1.0002 - 1.001 |
|
||||
|
||||
### Bounding Modes (`rollout_is_mode`)
|
||||
|
||||
| Mode | Behavior |
|
||||
|------|----------|
|
||||
| **truncate** | Cap weights at upper threshold only |
|
||||
| **clip** | Zero out weights outside [lower, upper] |
|
||||
|
||||
### Key Parameters
|
||||
|
||||
- `rollout_is_threshold`: Upper threshold for IS weights (null = disabled, float = enabled). **Main on/off switch.**
|
||||
- `rollout_is`: Whether to apply weights to loss (true) or just compute metrics (false). Default: false.
|
||||
- `rollout_is_threshold_lower`: Lower threshold (null = auto 1/upper)
|
||||
- `rollout_is_veto_threshold`: Catastrophic outlier threshold (default: 1e-4)
|
||||
|
||||
## Configuration Examples
|
||||
|
||||
### Example 1: Full IS Correction (Apply Weights)
|
||||
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 2.0
|
||||
rollout_is: true # Apply to loss
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: truncate
|
||||
rollout_is_veto_threshold: 1e-4
|
||||
```
|
||||
|
||||
### Example 2: Metrics Only (No Weight Application)
|
||||
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 2.0
|
||||
rollout_is: false # Compute metrics only, don't apply to loss
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: truncate
|
||||
```
|
||||
|
||||
### Example 3: Geometric Mean with Clip
|
||||
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 1.0002
|
||||
rollout_is: true
|
||||
rollout_is_threshold_lower: 0.9998
|
||||
rollout_is_level: geometric
|
||||
rollout_is_mode: clip
|
||||
rollout_is_veto_threshold: 1e-4
|
||||
```
|
||||
|
||||
### Example 4: Sequence-level with Truncate
|
||||
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 5.0
|
||||
rollout_is: true
|
||||
rollout_is_threshold_lower: null # Auto-reciprocal: 0.2
|
||||
rollout_is_level: sequence
|
||||
rollout_is_mode: truncate
|
||||
rollout_is_veto_threshold: 1e-4
|
||||
```
|
||||
|
||||
### Example 5: Asymmetric Thresholds
|
||||
|
||||
```yaml
|
||||
algorithm:
|
||||
rollout_is_threshold: 5.0
|
||||
rollout_is: true
|
||||
rollout_is_threshold_lower: 0.8
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: clip
|
||||
```
|
||||
|
||||
## Monitoring Metrics
|
||||
|
||||
Key metrics to watch (all prefixed with `mismatch/` in logs):
|
||||
|
||||
### Health Indicators
|
||||
- `rollout_is_mean`: Mean IS weight across sequences
|
||||
- `rollout_is_eff_sample_size`: Effective sample size after weighting
|
||||
- `rollout_is_veto_fraction`: Fraction of sequences vetoed
|
||||
|
||||
### Distribution Metrics
|
||||
- `rollout_is_max`, `rollout_is_min`: Weight extremes
|
||||
- `rollout_is_std`: Standard deviation
|
||||
- `rollout_is_p50`, `rollout_is_p95`, `rollout_is_p99`: Percentiles
|
||||
|
||||
### Diagnostic Metrics
|
||||
- `rollout_is_ratio_fraction_high`: Fraction exceeding upper threshold
|
||||
- `rollout_is_ratio_fraction_low`: Fraction below lower threshold
|
||||
- `rollout_is_catastrophic_token_fraction`: Catastrophic tokens detected
|
||||
|
||||
### Mismatch Metrics (Training vs Rollout Policy)
|
||||
|
||||
These metrics help diagnose the distribution mismatch between rollout and training policies:
|
||||
|
||||
**Perplexity Metrics:**
|
||||
- `mismatch_training_ppl`: Perplexity of training policy
|
||||
- `mismatch_rollout_ppl`: Perplexity of rollout policy
|
||||
- `mismatch_ppl_ratio`: Ratio of training PPL to rollout PPL
|
||||
- `mismatch_log_ppl_diff`: Log perplexity difference
|
||||
|
||||
**KL Divergence Metrics:**
|
||||
- `mismatch_kl`: KL divergence KL(π_rollout || π_training)
|
||||
- `mismatch_k3_kl`: K3 KL estimator
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: High Variance in IS Weights
|
||||
|
||||
**Symptoms**: `rollout_is_std` > 1.0, `rollout_is_eff_sample_size` < 0.3
|
||||
|
||||
**Solutions**:
|
||||
1. Switch from `sequence` to `geometric` level
|
||||
2. Tighten thresholds
|
||||
3. Check if rollout and training are too different
|
||||
|
||||
### Issue: Too Many Sequences Vetoed
|
||||
|
||||
**Symptoms**: `rollout_is_veto_fraction` > 0.1
|
||||
|
||||
**Solutions**:
|
||||
1. Relax veto threshold: `rollout_is_veto_threshold: 1e-3`
|
||||
2. Check for numerical issues in log prob computation
|
||||
3. Verify rollout and training policies aren't completely different
|
||||
|
||||
### Issue: Mean IS Weight Far from 1.0
|
||||
|
||||
**Symptoms**: `rollout_is_mean` < 0.5 or > 2.0
|
||||
|
||||
**Solutions**:
|
||||
1. Check that `calculate_log_probs=True` is set
|
||||
2. Verify rollout_log_probs are correctly passed
|
||||
3. Check for systematic bias in rollout vs training
|
||||
|
||||
### Issue: Too Much Data Discarded (Clip Mode)
|
||||
|
||||
**Symptoms**: `rollout_is_clipped_fraction` > 0.5
|
||||
|
||||
**Solutions**:
|
||||
1. Widen thresholds
|
||||
2. Switch to `truncate` mode
|
||||
3. Use `geometric` level for better stability
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### Memory Usage
|
||||
- Rollout IS adds minimal memory overhead (~1% of model memory)
|
||||
- Log-space computation prevents numerical overflow
|
||||
|
||||
### Computational Cost
|
||||
- Token-level: ~1-2% overhead
|
||||
- Sequence-level: ~2-3% overhead
|
||||
- Geometric: ~2-3% overhead
|
||||
|
||||
## Advanced Topics
|
||||
|
||||
### Dual Thresholds
|
||||
|
||||
Specify both upper and lower explicitly:
|
||||
|
||||
```yaml
|
||||
rollout_is_threshold: 2.0 # Upper
|
||||
rollout_is_threshold_lower: 0.5 # Lower (not 1/2.0 = 0.5)
|
||||
```
|
||||
|
||||
Or use auto-reciprocal:
|
||||
|
||||
```yaml
|
||||
rollout_is_threshold: 2.0 # Upper = 2.0, Lower = 0.5 (auto)
|
||||
rollout_is_threshold_lower: null
|
||||
```
|
||||
|
||||
### Veto Mechanism
|
||||
|
||||
The veto mechanism zeros out entire sequences containing catastrophic outliers:
|
||||
|
||||
- If any token has ratio < `rollout_is_veto_threshold`, the entire sequence is rejected
|
||||
- This prevents extreme outliers from dominating training
|
||||
- Default threshold: 1e-4 (ratio 10,000x off)
|
||||
- Set to `null` to disable: `rollout_is_veto_threshold: null`
|
||||
|
||||
## Examples
|
||||
|
||||
See the script in this directory:
|
||||
- `run_with_rollout_is.sh`: Basic example with token-level truncate mode
|
||||
|
||||
## References
|
||||
|
||||
- Implementation: `verl/trainer/ppo/mismatch_helper.py`
|
||||
- Core algorithm: `verl/trainer/ppo/core_algos.py`
|
||||
- Paper: "Your Efficient RL Framework Secretly Brings You Off-Policy RL Training"
|
99
examples/rollout_importance_sampling/run_with_rollout_is.sh
Executable file
99
examples/rollout_importance_sampling/run_with_rollout_is.sh
Executable file
@ -0,0 +1,99 @@
|
||||
#!/usr/bin/env bash
|
||||
# Example: Basic PPO training with Rollout Importance Sampling
|
||||
# This demonstrates the standard setup for correcting distribution mismatch
|
||||
|
||||
set -xeuo pipefail
|
||||
|
||||
# ==============================================================================
|
||||
# Rollout Importance Sampling Configuration
|
||||
# ==============================================================================
|
||||
|
||||
# Main control: Upper threshold for IS weights (null = disabled, float = enabled)
|
||||
rollout_is_threshold=2.0
|
||||
|
||||
# Whether to apply IS weights to policy loss
|
||||
# true = apply weights to loss, false = compute metrics only
|
||||
rollout_is=true
|
||||
|
||||
# Lower threshold (null = auto-reciprocal, i.e., 1/upper = 0.5)
|
||||
rollout_is_threshold_lower=null
|
||||
|
||||
# Aggregation level: token | sequence | geometric (experimental)
|
||||
rollout_is_level=token
|
||||
|
||||
# Bounding mode: truncate (cap upper) | clip (zero outside bounds)
|
||||
rollout_is_mode=truncate
|
||||
|
||||
# Catastrophic outlier veto threshold
|
||||
rollout_is_veto_threshold=1e-4
|
||||
|
||||
# ==============================================================================
|
||||
# Model and Data Configuration
|
||||
# ==============================================================================
|
||||
|
||||
MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen2.5-7B"}
|
||||
TRAIN_FILE=${TRAIN_FILE:-"data/train.parquet"}
|
||||
TEST_FILE=${TEST_FILE:-"data/test.parquet"}
|
||||
|
||||
max_prompt_length=512
|
||||
max_response_length=1024
|
||||
|
||||
# ==============================================================================
|
||||
# Training Configuration
|
||||
# ==============================================================================
|
||||
|
||||
train_batch_size=128
|
||||
ppo_mini_batch_size=32
|
||||
ppo_epochs=1
|
||||
learning_rate=5e-7
|
||||
|
||||
# ==============================================================================
|
||||
# Algorithm Configuration
|
||||
# ==============================================================================
|
||||
|
||||
adv_estimator=gae
|
||||
gamma=1.0
|
||||
lam=0.95
|
||||
|
||||
# ==============================================================================
|
||||
# Launch Training
|
||||
# ==============================================================================
|
||||
|
||||
python3 -m verl.trainer.main_ppo \
|
||||
data.train_files="${TRAIN_FILE}" \
|
||||
data.val_files="${TEST_FILE}" \
|
||||
data.max_prompt_length=${max_prompt_length} \
|
||||
data.max_response_length=${max_response_length} \
|
||||
data.train_batch_size=${train_batch_size} \
|
||||
algorithm.adv_estimator=${adv_estimator} \
|
||||
algorithm.gamma=${gamma} \
|
||||
algorithm.lam=${lam} \
|
||||
algorithm.rollout_is=${rollout_is} \
|
||||
algorithm.rollout_is_threshold=${rollout_is_threshold} \
|
||||
algorithm.rollout_is_threshold_lower=${rollout_is_threshold_lower} \
|
||||
algorithm.rollout_is_level=${rollout_is_level} \
|
||||
algorithm.rollout_is_mode=${rollout_is_mode} \
|
||||
algorithm.rollout_is_veto_threshold=${rollout_is_veto_threshold} \
|
||||
actor_rollout_ref.model.path="${MODEL_PATH}" \
|
||||
actor_rollout_ref.actor.optim.lr=${learning_rate} \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \
|
||||
actor_rollout_ref.actor.ppo_epochs=${ppo_epochs} \
|
||||
actor_rollout_ref.rollout.calculate_log_probs=True \
|
||||
actor_rollout_ref.rollout.name=vllm \
|
||||
trainer.logger='["console","wandb"]' \
|
||||
trainer.project_name="rollout_is_example" \
|
||||
trainer.experiment_name="basic_token_truncate" \
|
||||
trainer.total_epochs=10
|
||||
|
||||
echo "Training completed!"
|
||||
echo ""
|
||||
echo "Rollout IS Configuration:"
|
||||
echo " - Threshold: ${rollout_is_threshold}"
|
||||
echo " - Apply to loss: ${rollout_is}"
|
||||
echo " - Level: ${rollout_is_level}"
|
||||
echo " - Mode: ${rollout_is_mode}"
|
||||
echo ""
|
||||
echo "Monitor these key metrics in wandb:"
|
||||
echo " - mismatch/rollout_is_mean (should be ~1.0)"
|
||||
echo " - mismatch/rollout_is_eff_sample_size (should be >0.5)"
|
||||
echo " - mismatch/rollout_is_veto_fraction (should be <0.1)"
|
@ -51,7 +51,7 @@ actor_rollout_ref:
|
||||
lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.
|
||||
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
|
||||
min_lr_ratio: null # only useful for warmup with cosine
|
||||
warmup_style: constant # select from constant/cosine
|
||||
lr_scheduler_type: constant # select from constant/cosine
|
||||
total_training_steps: -1 # must be override by program
|
||||
fsdp_config:
|
||||
wrap_policy:
|
||||
@ -105,7 +105,7 @@ critic:
|
||||
lr: 1e-5
|
||||
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
|
||||
min_lr_ratio: null # only useful for warmup with cosine
|
||||
warmup_style: constant # select from constant/cosine
|
||||
lr_scheduler_type: constant # select from constant/cosine
|
||||
total_training_steps: -1 # must be override by program
|
||||
model:
|
||||
path: ~/models/deepseek-llm-7b-chat
|
||||
|
@ -304,6 +304,11 @@ class RayDAPOTrainer(RayPPOTrainer):
|
||||
values = self.critic_wg.compute_values(batch)
|
||||
batch = batch.union(values)
|
||||
|
||||
# Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer)
|
||||
batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
|
||||
# IS and mismatch metrics already have mismatch/ prefix
|
||||
metrics.update(is_metrics)
|
||||
|
||||
with marked_timer("adv", timing_raw, "brown"):
|
||||
# compute advantages, executed on the driver process
|
||||
norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
|
||||
|
@ -1,8 +1,13 @@
|
||||
#!/usr/bin/env bash
|
||||
set -xeuo pipefail
|
||||
|
||||
# Rollout Importance Sampling Example
|
||||
# References:
|
||||
# - When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
|
||||
# - Off-policy RL: https://fengyao.notion.site/off-policy-rl
|
||||
|
||||
project_name='DAPO'
|
||||
exp_name='DAPO-Qwen2.5-32B-TIS' # Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl
|
||||
exp_name='DAPO-Qwen2.5-32B-RolloutIS' # Rollout Importance Sampling
|
||||
|
||||
adv_estimator=grpo
|
||||
|
||||
@ -10,7 +15,14 @@ use_kl_in_reward=False
|
||||
kl_coef=0.0
|
||||
use_kl_loss=False
|
||||
kl_loss_coef=0.0
|
||||
tis_imp_ratio_cap=2.0
|
||||
|
||||
# Rollout Importance Sampling parameters (matches original TIS with threshold=2)
|
||||
rollout_is=True
|
||||
rollout_is_threshold=2.0
|
||||
rollout_is_threshold_lower=null # No lower bound (original TIS behavior)
|
||||
rollout_is_level=token # token-level (original TIS behavior)
|
||||
rollout_is_mode=truncate # truncate mode (original TIS behavior)
|
||||
rollout_is_veto_threshold=null # No veto (original TIS behavior)
|
||||
|
||||
clip_ratio_low=0.2
|
||||
clip_ratio_high=0.28
|
||||
@ -58,14 +70,17 @@ offload=True
|
||||
gen_tp=4
|
||||
|
||||
|
||||
# Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl
|
||||
|
||||
# Please note that server mode(agent loop) hasn't return rollout_log_probs for now.
|
||||
# so currently, server mode is not supported for TIS.
|
||||
|
||||
# To turn on TIS, you need to set the following parameters. Note 2.0 is a hyper-parameter and can be tuned.
|
||||
# actor_rollout_ref.actor.tis_imp_ratio_cap=2.0
|
||||
# actor_rollout_ref.rollout.calculate_log_probs=True
|
||||
# Rollout Importance Sampling (corrects distribution mismatch between rollout and training)
|
||||
#
|
||||
# Please note that server mode (agent loop) hasn't returned rollout_log_probs for now,
|
||||
# so currently server mode is not supported for Rollout IS.
|
||||
#
|
||||
# Rollout IS parameters (configured at top of script):
|
||||
# algorithm.rollout_is=True
|
||||
# algorithm.rollout_is_threshold=2.0 # Upper threshold (can be tuned)
|
||||
# algorithm.rollout_is_level=token # Aggregation level
|
||||
# algorithm.rollout_is_mode=truncate # Bounding mode
|
||||
# actor_rollout_ref.rollout.calculate_log_probs=True # Required!
|
||||
|
||||
ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
|
||||
--working-dir "${WORKING_DIR}" \
|
||||
@ -109,7 +124,12 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
|
||||
actor_rollout_ref.actor.grad_clip=1.0 \
|
||||
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
|
||||
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
|
||||
actor_rollout_ref.actor.tis_imp_ratio_cap=${tis_imp_ratio_cap} \
|
||||
algorithm.rollout_is=${rollout_is} \
|
||||
algorithm.rollout_is_threshold=${rollout_is_threshold} \
|
||||
algorithm.rollout_is_threshold_lower=${rollout_is_threshold_lower} \
|
||||
algorithm.rollout_is_level=${rollout_is_level} \
|
||||
algorithm.rollout_is_mode=${rollout_is_mode} \
|
||||
algorithm.rollout_is_veto_threshold=${rollout_is_veto_threshold} \
|
||||
actor_rollout_ref.rollout.calculate_log_probs=True \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
|
||||
|
@ -103,7 +103,7 @@ HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
|
||||
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.actor.optim.weight_decay=0 \
|
||||
actor_rollout_ref.actor.optim.warmup_style=constant \
|
||||
actor_rollout_ref.actor.optim.lr_scheduler_type=constant \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
|
||||
|
@ -100,7 +100,7 @@ HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
|
||||
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.actor.optim.weight_decay=0 \
|
||||
actor_rollout_ref.actor.optim.warmup_style=constant \
|
||||
actor_rollout_ref.actor.optim.lr_scheduler_type=constant \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
|
||||
|
@ -99,7 +99,7 @@ HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
|
||||
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.actor.optim.weight_decay=0 \
|
||||
actor_rollout_ref.actor.optim.warmup_style=constant \
|
||||
actor_rollout_ref.actor.optim.lr_scheduler_type=constant \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
|
||||
|
@ -103,7 +103,7 @@ HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
|
||||
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.actor.optim.weight_decay=0 \
|
||||
actor_rollout_ref.actor.optim.warmup_style=constant \
|
||||
actor_rollout_ref.actor.optim.lr_scheduler_type=constant \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
|
||||
|
@ -99,7 +99,7 @@ HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
|
||||
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.actor.optim.weight_decay=0 \
|
||||
actor_rollout_ref.actor.optim.warmup_style=constant \
|
||||
actor_rollout_ref.actor.optim.lr_scheduler_type=constant \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
|
||||
|
@ -577,6 +577,11 @@ class OneStepOffRayTrainer(RayPPOTrainer):
|
||||
else:
|
||||
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
|
||||
|
||||
# Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer)
|
||||
batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
|
||||
# IS and mismatch metrics already have mismatch/ prefix
|
||||
metrics.update(is_metrics)
|
||||
|
||||
# compute advantages, executed on the driver process
|
||||
|
||||
norm_adv_by_std_in_grpo = self.config.algorithm.get(
|
||||
|
@ -48,7 +48,8 @@ reward_model:
|
||||
lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.
|
||||
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
|
||||
min_lr_ratio: null
|
||||
warmup_style: constant
|
||||
warmup_style: null # deprecated
|
||||
lr_scheduler_type: constant
|
||||
total_training_steps: -1 # must be overridden by program
|
||||
weight_decay: 0.
|
||||
grad_clip: 10.0
|
||||
|
@ -42,7 +42,7 @@ FSDP_ENGINE_CONFIG="\
|
||||
optim.betas="[0.9,0.95]" \
|
||||
optim.clip_grad=1.0 \
|
||||
optim.min_lr_ratio=0.1 \
|
||||
optim.warmup_style=cosine \
|
||||
optim.lr_scheduler_type=cosine \
|
||||
engine.ulysses_sequence_parallel_size=${SP_SIZE} \
|
||||
engine.strategy=${FSDP_STRATEGY} \
|
||||
engine.fsdp_size=${FSDP_SIZE}"
|
||||
|
@ -301,8 +301,8 @@ actor_rollout_ref:
|
||||
# Number of cosine cycles in LR schedule
|
||||
num_cycles: 0.5
|
||||
|
||||
# LR warmup style: "constant" or "cosine"
|
||||
warmup_style: constant
|
||||
# LR scheduler type: "constant" or "cosine"
|
||||
lr_scheduler_type: constant
|
||||
|
||||
# Total training steps (must be overridden at runtime)
|
||||
total_training_steps: -1
|
||||
@ -605,8 +605,8 @@ critic:
|
||||
# Minimum LR ratio for cosine schedule
|
||||
min_lr_ratio: 0.0
|
||||
|
||||
# LR warmup style: "constant" or "cosine"
|
||||
warmup_style: constant
|
||||
# LR scheduler type: "constant" or "cosine"
|
||||
lr_scheduler_type: constant
|
||||
|
||||
# Total training steps (must be overridden at runtime)
|
||||
total_training_steps: -1
|
||||
|
289
tests/trainer/ppo/test_rollout_is.py
Normal file
289
tests/trainer/ppo/test_rollout_is.py
Normal file
@ -0,0 +1,289 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Quick Sanity Test for Rollout Importance Sampling
|
||||
|
||||
This is a standalone test script that can be run without pytest to quickly verify
|
||||
the rollout IS implementation is working correctly. For comprehensive integration
|
||||
tests, see: tests/trainer/ppo/test_rollout_is_integration.py
|
||||
|
||||
Usage:
|
||||
python test_rollout_is.py
|
||||
|
||||
This tests:
|
||||
- Basic rollout IS functionality (3 levels, 2 modes)
|
||||
- Metrics completeness (32 total: 21 IS + 11 mismatch metrics)
|
||||
- Veto mechanism
|
||||
- Edge cases
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from verl.trainer.ppo.mismatch_helper import compute_mismatch_metrics, compute_rollout_importance_weights
|
||||
|
||||
|
||||
def test_basic_rollout_is():
|
||||
"""Test basic rollout IS functionality."""
|
||||
print("Testing basic rollout IS functionality...")
|
||||
|
||||
# Create test data
|
||||
batch_size, seq_length = 4, 10
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# Create slightly different log probs (simulating BF16 vs FP32 mismatch)
|
||||
old_log_prob = torch.randn(batch_size, seq_length, device=device)
|
||||
rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.1
|
||||
eos_mask = torch.ones(batch_size, seq_length, device=device)
|
||||
|
||||
# Test token-level truncate mode (equivalent to old TIS)
|
||||
print("\n1. Testing token-level truncate mode...")
|
||||
weights_proto, metrics = compute_rollout_importance_weights(
|
||||
old_log_prob=old_log_prob,
|
||||
rollout_log_prob=rollout_log_prob,
|
||||
response_mask=eos_mask,
|
||||
rollout_is_level="token",
|
||||
rollout_is_mode="truncate",
|
||||
rollout_is_threshold=2.0,
|
||||
rollout_is_veto_threshold=1e-4,
|
||||
)
|
||||
|
||||
weights = weights_proto.batch["rollout_is_weights"]
|
||||
print(f" Weights shape: {weights.shape}")
|
||||
print(f" Mean weight: {metrics['mismatch/rollout_is_mean']:.4f}")
|
||||
print(f" Max weight: {metrics['mismatch/rollout_is_max']:.4f}")
|
||||
print(f" Min weight: {metrics['mismatch/rollout_is_min']:.4f}")
|
||||
print(f" Veto fraction: {metrics['mismatch/rollout_is_veto_fraction']:.4f}")
|
||||
assert weights.shape == old_log_prob.shape
|
||||
assert weights.max() <= 2.0, "Weights should be capped at threshold"
|
||||
print(" ✓ Token-level truncate mode passed")
|
||||
|
||||
# Test sequence-level mode
|
||||
print("\n2. Testing sequence-level mode...")
|
||||
weights_seq_proto, metrics_seq = compute_rollout_importance_weights(
|
||||
old_log_prob=old_log_prob,
|
||||
rollout_log_prob=rollout_log_prob,
|
||||
response_mask=eos_mask,
|
||||
rollout_is_level="sequence",
|
||||
rollout_is_mode="truncate",
|
||||
rollout_is_threshold=5.0,
|
||||
rollout_is_veto_threshold=1e-4,
|
||||
)
|
||||
|
||||
weights_seq = weights_seq_proto.batch["rollout_is_weights"]
|
||||
print(f" Mean weight: {metrics_seq['mismatch/rollout_is_mean']:.4f}")
|
||||
print(f" Effective sample size: {metrics_seq['mismatch/rollout_is_eff_sample_size']:.4f}")
|
||||
# Check that all tokens in a sequence have the same weight
|
||||
for i in range(batch_size):
|
||||
seq_weights = weights_seq[i, eos_mask[i].bool()]
|
||||
assert torch.allclose(seq_weights, seq_weights[0]), "All tokens in sequence should have same weight"
|
||||
print(" ✓ Sequence-level mode passed")
|
||||
|
||||
# Test geometric mean mode
|
||||
print("\n3. Testing geometric mean mode...")
|
||||
weights_geo_proto, metrics_geo = compute_rollout_importance_weights(
|
||||
old_log_prob=old_log_prob,
|
||||
rollout_log_prob=rollout_log_prob,
|
||||
response_mask=eos_mask,
|
||||
rollout_is_level="geometric",
|
||||
rollout_is_mode="clip",
|
||||
rollout_is_threshold=1.5,
|
||||
rollout_is_threshold_lower=0.5,
|
||||
rollout_is_veto_threshold=1e-4,
|
||||
)
|
||||
|
||||
print(f" Mean weight: {metrics_geo['mismatch/rollout_is_mean']:.4f}")
|
||||
print(f" Clipped fraction: {metrics_geo['mismatch/rollout_is_clipped_fraction']:.4f}")
|
||||
print(" ✓ Geometric mean mode passed")
|
||||
|
||||
# Test veto mechanism
|
||||
print("\n4. Testing veto mechanism...")
|
||||
# Create data with catastrophic outliers
|
||||
old_log_prob_veto = torch.randn(2, 5, device=device)
|
||||
rollout_log_prob_veto = old_log_prob_veto.clone()
|
||||
# Make one token have catastrophically low ratio
|
||||
rollout_log_prob_veto[0, 2] = old_log_prob_veto[0, 2] + 15.0 # ratio ~= 3e-7
|
||||
eos_mask_veto = torch.ones(2, 5, device=device)
|
||||
|
||||
weights_veto_proto, metrics_veto = compute_rollout_importance_weights(
|
||||
old_log_prob=old_log_prob_veto,
|
||||
rollout_log_prob=rollout_log_prob_veto,
|
||||
response_mask=eos_mask_veto,
|
||||
rollout_is_level="token",
|
||||
rollout_is_mode="truncate",
|
||||
rollout_is_threshold=2.0,
|
||||
rollout_is_veto_threshold=1e-4,
|
||||
)
|
||||
|
||||
weights_veto = weights_veto_proto.batch["rollout_is_weights"]
|
||||
print(f" Veto fraction: {metrics_veto['mismatch/rollout_is_veto_fraction']:.4f}")
|
||||
# Check that the sequence with catastrophic token has all weights zeroed
|
||||
assert weights_veto[0].sum() == 0, "Sequence with catastrophic token should be vetoed"
|
||||
assert weights_veto[1].sum() > 0, "Normal sequence should not be vetoed"
|
||||
print(" ✓ Veto mechanism passed")
|
||||
|
||||
# Test disabled IS (threshold=None)
|
||||
print("\n5. Testing disabled IS...")
|
||||
weights_disabled, metrics_disabled = compute_rollout_importance_weights(
|
||||
old_log_prob=old_log_prob,
|
||||
rollout_log_prob=rollout_log_prob,
|
||||
response_mask=eos_mask,
|
||||
rollout_is_threshold=None,
|
||||
)
|
||||
|
||||
assert weights_disabled is None, "Should return None when threshold is None"
|
||||
assert len(metrics_disabled) == 0, "Should return empty metrics when disabled"
|
||||
print(" ✓ Disabled IS passed")
|
||||
|
||||
print("\n✓ All tests passed!")
|
||||
|
||||
|
||||
def test_metrics_completeness():
|
||||
"""Test that all expected metrics are returned."""
|
||||
print("\nTesting metrics completeness...")
|
||||
|
||||
batch_size, seq_length = 3, 8
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
old_log_prob = torch.randn(batch_size, seq_length, device=device)
|
||||
rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.2
|
||||
eos_mask = torch.ones(batch_size, seq_length, device=device)
|
||||
|
||||
_, metrics = compute_rollout_importance_weights(
|
||||
old_log_prob=old_log_prob,
|
||||
rollout_log_prob=rollout_log_prob,
|
||||
response_mask=eos_mask,
|
||||
rollout_is_level="token",
|
||||
rollout_is_mode="truncate",
|
||||
rollout_is_threshold=2.5,
|
||||
)
|
||||
|
||||
# Expected IS metrics
|
||||
expected_is_metrics = [
|
||||
"mismatch/rollout_is_mean",
|
||||
"mismatch/rollout_is_max",
|
||||
"mismatch/rollout_is_min",
|
||||
"mismatch/rollout_is_std",
|
||||
"mismatch/rollout_is_eff_sample_size",
|
||||
"mismatch/rollout_is_veto_fraction",
|
||||
"mismatch/rollout_is_catastrophic_token_fraction",
|
||||
"mismatch/rollout_is_ratio_fraction_high",
|
||||
"mismatch/rollout_is_ratio_fraction_low",
|
||||
"mismatch/rollout_is_p25",
|
||||
"mismatch/rollout_is_p50",
|
||||
"mismatch/rollout_is_p75",
|
||||
"mismatch/rollout_is_p95",
|
||||
"mismatch/rollout_is_p99",
|
||||
]
|
||||
|
||||
# Expected mismatch/diagnostic metrics (also included now)
|
||||
expected_mismatch_metrics = [
|
||||
"mismatch/mismatch_training_ppl",
|
||||
"mismatch/mismatch_training_log_ppl",
|
||||
"mismatch/mismatch_kl",
|
||||
"mismatch/mismatch_k3_kl",
|
||||
"mismatch/mismatch_rollout_ppl",
|
||||
"mismatch/mismatch_rollout_log_ppl",
|
||||
"mismatch/mismatch_log_ppl_diff",
|
||||
"mismatch/mismatch_log_ppl_abs_diff",
|
||||
"mismatch/mismatch_log_ppl_diff_max",
|
||||
"mismatch/mismatch_log_ppl_diff_min",
|
||||
"mismatch/mismatch_ppl_ratio",
|
||||
]
|
||||
|
||||
expected_metrics = expected_is_metrics + expected_mismatch_metrics
|
||||
|
||||
missing_metrics = [m for m in expected_metrics if m not in metrics]
|
||||
if missing_metrics:
|
||||
print(f" ✗ Missing metrics: {missing_metrics}")
|
||||
return False
|
||||
|
||||
print(f" ✓ All {len(expected_metrics)} expected metrics present")
|
||||
print(f" Total metrics returned: {len(metrics)}")
|
||||
return True
|
||||
|
||||
|
||||
def test_mismatch_metrics():
|
||||
"""Test mismatch metrics computation."""
|
||||
print("\nTesting mismatch metrics computation...")
|
||||
|
||||
batch_size, seq_length = 4, 12
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# Create test data with some mismatch
|
||||
old_log_prob = torch.randn(batch_size, seq_length, device=device) - 2.0 # training policy
|
||||
rollout_log_prob = torch.randn(batch_size, seq_length, device=device) - 1.5 # rollout policy (more confident)
|
||||
response_mask = torch.ones(batch_size, seq_length, device=device)
|
||||
|
||||
# Test with rollout log probs
|
||||
metrics = compute_mismatch_metrics(
|
||||
old_log_prob=old_log_prob,
|
||||
rollout_log_prob=rollout_log_prob,
|
||||
response_mask=response_mask,
|
||||
)
|
||||
|
||||
expected_metrics = [
|
||||
"mismatch_training_ppl",
|
||||
"mismatch_training_log_ppl",
|
||||
"mismatch_kl",
|
||||
"mismatch_k3_kl",
|
||||
"mismatch_rollout_ppl",
|
||||
"mismatch_rollout_log_ppl",
|
||||
"mismatch_log_ppl_diff",
|
||||
"mismatch_log_ppl_abs_diff",
|
||||
"mismatch_log_ppl_diff_max",
|
||||
"mismatch_log_ppl_diff_min",
|
||||
"mismatch_ppl_ratio",
|
||||
]
|
||||
|
||||
for metric in expected_metrics:
|
||||
assert metric in metrics, f"Missing metric: {metric}"
|
||||
|
||||
print(f" Training PPL: {metrics['mismatch_training_ppl']:.4f}")
|
||||
print(f" Rollout PPL: {metrics['mismatch_rollout_ppl']:.4f}")
|
||||
print(f" KL divergence: {metrics['mismatch_kl']:.6f}")
|
||||
print(f" K3 KL: {metrics['mismatch_k3_kl']:.6f}")
|
||||
print(f" PPL ratio: {metrics['mismatch_ppl_ratio']:.4f}")
|
||||
print(f" ✓ All {len(expected_metrics)} mismatch metrics present")
|
||||
|
||||
# Test without rollout log probs
|
||||
metrics_no_rollout = compute_mismatch_metrics(
|
||||
old_log_prob=old_log_prob,
|
||||
rollout_log_prob=None,
|
||||
response_mask=response_mask,
|
||||
)
|
||||
|
||||
assert "mismatch_training_ppl" in metrics_no_rollout
|
||||
assert "mismatch_rollout_ppl" not in metrics_no_rollout
|
||||
print(" ✓ Mismatch metrics work without rollout log probs")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("Rollout Importance Sampling Test Suite")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
test_basic_rollout_is()
|
||||
test_metrics_completeness()
|
||||
test_mismatch_metrics()
|
||||
print("\n" + "=" * 60)
|
||||
print("ALL TESTS PASSED ✓")
|
||||
print("=" * 60)
|
||||
except Exception as e:
|
||||
print(f"\n✗ Test failed with error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
exit(1)
|
241
tests/trainer/ppo/test_rollout_is_integration.py
Normal file
241
tests/trainer/ppo/test_rollout_is_integration.py
Normal file
@ -0,0 +1,241 @@
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Integration tests for Rollout Importance Sampling."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from verl.trainer.ppo.core_algos import compute_policy_loss_vanilla
|
||||
from verl.trainer.ppo.mismatch_helper import compute_mismatch_metrics, compute_rollout_importance_weights
|
||||
from verl.workers.config.actor import ActorConfig
|
||||
|
||||
|
||||
class TestRolloutISIntegration:
|
||||
"""Integration tests for Rollout IS with PPO."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data(self):
|
||||
"""Create sample training data."""
|
||||
batch_size, seq_length = 4, 16
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
return {
|
||||
"old_log_prob": torch.randn(batch_size, seq_length, device=device),
|
||||
"log_prob": torch.randn(batch_size, seq_length, device=device),
|
||||
"rollout_log_prob": torch.randn(batch_size, seq_length, device=device),
|
||||
"advantages": torch.randn(batch_size, seq_length, device=device),
|
||||
"response_mask": torch.ones(batch_size, seq_length, device=device),
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def config_with_rollout_is(self):
|
||||
"""Create config for policy loss computation.
|
||||
|
||||
Note: rollout_is config has been moved to algorithm config.
|
||||
This config only needs fields used by policy loss (clip_ratio, etc).
|
||||
"""
|
||||
config = ActorConfig(
|
||||
strategy="fsdp",
|
||||
rollout_n=1,
|
||||
ppo_micro_batch_size=2,
|
||||
clip_ratio=0.2,
|
||||
)
|
||||
return config
|
||||
|
||||
def test_policy_loss_with_rollout_is(self, sample_data, config_with_rollout_is):
|
||||
"""Test that policy loss computation works with rollout IS weights.
|
||||
|
||||
Note: In production, IS weights are computed centrally in the trainer
|
||||
(before advantage computation) and passed to policy loss.
|
||||
This test simulates that workflow.
|
||||
"""
|
||||
# First compute IS weights (as trainer would do centrally)
|
||||
rollout_is_weights_proto, _ = compute_rollout_importance_weights(
|
||||
old_log_prob=sample_data["old_log_prob"],
|
||||
rollout_log_prob=sample_data["rollout_log_prob"],
|
||||
response_mask=sample_data["response_mask"],
|
||||
rollout_is_level="token",
|
||||
rollout_is_mode="truncate",
|
||||
rollout_is_threshold=2.0,
|
||||
rollout_is_veto_threshold=1e-4,
|
||||
)
|
||||
|
||||
rollout_is_weights = rollout_is_weights_proto.batch["rollout_is_weights"]
|
||||
|
||||
# Policy loss function receives pre-computed IS weights
|
||||
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss_vanilla(
|
||||
old_log_prob=sample_data["old_log_prob"],
|
||||
log_prob=sample_data["log_prob"],
|
||||
advantages=sample_data["advantages"],
|
||||
response_mask=sample_data["response_mask"],
|
||||
loss_agg_mode="token-mean",
|
||||
config=config_with_rollout_is,
|
||||
rollout_is_weights=rollout_is_weights,
|
||||
)
|
||||
|
||||
# Check loss is valid
|
||||
assert isinstance(pg_loss, torch.Tensor)
|
||||
assert pg_loss.ndim == 0 # Scalar
|
||||
assert not torch.isnan(pg_loss)
|
||||
assert not torch.isinf(pg_loss)
|
||||
|
||||
def test_rollout_is_weights_computation(self, sample_data):
|
||||
"""Test rollout IS weights and metrics computation."""
|
||||
weights_proto, metrics = compute_rollout_importance_weights(
|
||||
old_log_prob=sample_data["old_log_prob"],
|
||||
rollout_log_prob=sample_data["rollout_log_prob"],
|
||||
response_mask=sample_data["response_mask"],
|
||||
rollout_is_level="token",
|
||||
rollout_is_mode="truncate",
|
||||
rollout_is_threshold=2.0,
|
||||
rollout_is_veto_threshold=1e-4,
|
||||
)
|
||||
|
||||
# Check weights
|
||||
from verl.protocol import DataProto
|
||||
|
||||
assert isinstance(weights_proto, DataProto)
|
||||
weights = weights_proto.batch["rollout_is_weights"]
|
||||
assert isinstance(weights, torch.Tensor)
|
||||
assert weights.shape == sample_data["old_log_prob"].shape
|
||||
|
||||
# Check metrics are returned
|
||||
assert isinstance(metrics, dict)
|
||||
assert len(metrics) > 0
|
||||
assert "mismatch/rollout_is_mean" in metrics
|
||||
|
||||
def test_all_aggregation_levels(self, sample_data):
|
||||
"""Test all three aggregation levels."""
|
||||
levels = ["token", "sequence", "geometric"]
|
||||
|
||||
for level in levels:
|
||||
_, metrics = compute_rollout_importance_weights(
|
||||
old_log_prob=sample_data["old_log_prob"],
|
||||
rollout_log_prob=sample_data["rollout_log_prob"],
|
||||
response_mask=sample_data["response_mask"],
|
||||
rollout_is_level=level,
|
||||
rollout_is_mode="truncate",
|
||||
rollout_is_threshold=2.0,
|
||||
)
|
||||
|
||||
assert "mismatch/rollout_is_mean" in metrics
|
||||
|
||||
def test_both_bounding_modes(self, sample_data):
|
||||
"""Test both truncate and clip modes."""
|
||||
modes = ["truncate", "clip"]
|
||||
|
||||
for mode in modes:
|
||||
_, metrics = compute_rollout_importance_weights(
|
||||
old_log_prob=sample_data["old_log_prob"],
|
||||
rollout_log_prob=sample_data["rollout_log_prob"],
|
||||
response_mask=sample_data["response_mask"],
|
||||
rollout_is_level="token",
|
||||
rollout_is_mode=mode,
|
||||
rollout_is_threshold=2.0,
|
||||
rollout_is_threshold_lower=0.5,
|
||||
)
|
||||
|
||||
assert "mismatch/rollout_is_mean" in metrics
|
||||
|
||||
def test_mismatch_metrics(self, sample_data):
|
||||
"""Test mismatch diagnostic metrics computation."""
|
||||
metrics = compute_mismatch_metrics(
|
||||
old_log_prob=sample_data["old_log_prob"],
|
||||
rollout_log_prob=sample_data["rollout_log_prob"],
|
||||
response_mask=sample_data["response_mask"],
|
||||
)
|
||||
|
||||
# Check key metrics are present
|
||||
assert "mismatch_training_ppl" in metrics
|
||||
assert "mismatch_rollout_ppl" in metrics
|
||||
assert "mismatch_kl" in metrics
|
||||
assert isinstance(metrics["mismatch_kl"], float)
|
||||
|
||||
def test_veto_mechanism(self):
|
||||
"""Test veto mechanism with catastrophic outliers."""
|
||||
batch_size, seq_length = 2, 5
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
old_log_prob = torch.randn(batch_size, seq_length, device=device)
|
||||
rollout_log_prob = old_log_prob.clone()
|
||||
|
||||
# Create catastrophic outlier in first sequence
|
||||
rollout_log_prob[0, 2] += 15.0 # Makes ratio ~3e-7
|
||||
|
||||
response_mask = torch.ones(batch_size, seq_length, device=device)
|
||||
|
||||
_, metrics = compute_rollout_importance_weights(
|
||||
old_log_prob=old_log_prob,
|
||||
rollout_log_prob=rollout_log_prob,
|
||||
response_mask=response_mask,
|
||||
rollout_is_level="token",
|
||||
rollout_is_mode="truncate",
|
||||
rollout_is_threshold=2.0,
|
||||
rollout_is_veto_threshold=1e-4,
|
||||
)
|
||||
|
||||
# Should have vetoed one sequence
|
||||
assert metrics["mismatch/rollout_is_veto_fraction"] > 0
|
||||
assert metrics["mismatch/rollout_is_veto_fraction"] <= 1.0
|
||||
|
||||
def test_metrics_only_mode(self, sample_data, config_with_rollout_is):
|
||||
"""Test metrics-only mode: compute IS weights/metrics but don't apply to loss.
|
||||
|
||||
This tests the use case where rollout_is_threshold is set (enables computation)
|
||||
but rollout_is=False (disables weight application to policy loss).
|
||||
"""
|
||||
# Compute IS weights (as trainer would do)
|
||||
rollout_is_weights_proto, is_metrics = compute_rollout_importance_weights(
|
||||
old_log_prob=sample_data["old_log_prob"],
|
||||
rollout_log_prob=sample_data["rollout_log_prob"],
|
||||
response_mask=sample_data["response_mask"],
|
||||
rollout_is_level="token",
|
||||
rollout_is_mode="truncate",
|
||||
rollout_is_threshold=2.0,
|
||||
)
|
||||
|
||||
# Metrics should be computed
|
||||
assert len(is_metrics) > 0
|
||||
assert "mismatch/rollout_is_mean" in is_metrics
|
||||
|
||||
# In metrics-only mode, we compute loss WITHOUT applying weights
|
||||
# (simulating rollout_is=False)
|
||||
pg_loss_no_weights, _, _, _ = compute_policy_loss_vanilla(
|
||||
old_log_prob=sample_data["old_log_prob"],
|
||||
log_prob=sample_data["log_prob"],
|
||||
advantages=sample_data["advantages"],
|
||||
response_mask=sample_data["response_mask"],
|
||||
loss_agg_mode="token-mean",
|
||||
config=config_with_rollout_is,
|
||||
rollout_is_weights=None, # Don't apply weights
|
||||
)
|
||||
|
||||
# Compare to loss WITH weights (rollout_is=True)
|
||||
rollout_is_weights = rollout_is_weights_proto.batch["rollout_is_weights"]
|
||||
pg_loss_with_weights, _, _, _ = compute_policy_loss_vanilla(
|
||||
old_log_prob=sample_data["old_log_prob"],
|
||||
log_prob=sample_data["log_prob"],
|
||||
advantages=sample_data["advantages"],
|
||||
response_mask=sample_data["response_mask"],
|
||||
loss_agg_mode="token-mean",
|
||||
config=config_with_rollout_is,
|
||||
rollout_is_weights=rollout_is_weights,
|
||||
)
|
||||
|
||||
# Losses should be different (weights have an effect)
|
||||
assert not torch.allclose(pg_loss_no_weights, pg_loss_with_weights)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
@ -21,15 +21,24 @@ class TestFSDPOptimizerConfigCPU:
|
||||
def test_default_configuration(self):
|
||||
config = FSDPOptimizerConfig(lr=0.1)
|
||||
assert config.min_lr_ratio is None
|
||||
assert config.warmup_style == "constant"
|
||||
assert config.lr_scheduler_type == "constant"
|
||||
assert config.num_cycles == 0.5
|
||||
|
||||
@pytest.mark.parametrize("warmup_style", ["constant", "cosine"])
|
||||
def test_valid_warmup_styles(self, warmup_style):
|
||||
config = FSDPOptimizerConfig(warmup_style=warmup_style, lr=0.1)
|
||||
assert config.warmup_style == warmup_style
|
||||
@pytest.mark.parametrize("lr_scheduler_type", ["constant", "cosine"])
|
||||
def test_valid_lr_scheduler_types(self, lr_scheduler_type):
|
||||
config = FSDPOptimizerConfig(lr_scheduler_type=lr_scheduler_type, lr=0.1)
|
||||
assert config.lr_scheduler_type == lr_scheduler_type
|
||||
|
||||
def test_invalid_warmup_style(self):
|
||||
@pytest.mark.parametrize("warmup_style", ["constant", "cosine"])
|
||||
def test_valid_warmup_style_types(self, warmup_style):
|
||||
config = FSDPOptimizerConfig(warmup_style=warmup_style, lr=0.1)
|
||||
assert config.lr_scheduler_type == warmup_style
|
||||
|
||||
def test_invalid_lr_scheduler_type(self):
|
||||
with pytest.raises((ValueError, AssertionError)):
|
||||
FSDPOptimizerConfig(lr_scheduler_type="invalid_style", lr=0.1)
|
||||
|
||||
def test_invalid_warmup_style_type(self):
|
||||
with pytest.raises((ValueError, AssertionError)):
|
||||
FSDPOptimizerConfig(warmup_style="invalid_style", lr=0.1)
|
||||
|
||||
|
@ -127,6 +127,8 @@ def patch_vlm_for_ulysses_input_slicing(model_class: type):
|
||||
def ulysses_wrapped_decoder_forward(self, *args, **kwargs):
|
||||
inputs_embeds = kwargs.get("inputs_embeds")
|
||||
position_ids = kwargs.get("position_ids")
|
||||
visual_pos_masks = kwargs.get("visual_pos_masks")
|
||||
deepstack_visual_embeds = kwargs.get("deepstack_visual_embeds")
|
||||
call_kwargs = kwargs.copy()
|
||||
|
||||
current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
|
||||
@ -139,6 +141,43 @@ def patch_vlm_for_ulysses_input_slicing(model_class: type):
|
||||
if slice_now:
|
||||
call_kwargs["inputs_embeds"] = slice_input_tensor(inputs_embeds, dim=1, padding=False)
|
||||
call_kwargs["position_ids"] = slice_input_tensor(position_ids, dim=-1, padding=False)
|
||||
# Also slice visual_pos_masks and deepstack_visual_embeds for Qwen3 VL models
|
||||
if visual_pos_masks is not None:
|
||||
original_visual_mask = visual_pos_masks
|
||||
sliced_visual_mask = slice_input_tensor(visual_pos_masks, dim=1, padding=False)
|
||||
call_kwargs["visual_pos_masks"] = sliced_visual_mask
|
||||
|
||||
if deepstack_visual_embeds is not None:
|
||||
sliced_embeds = []
|
||||
|
||||
num_visual_before = original_visual_mask.sum().item()
|
||||
num_visual_in_shard = sliced_visual_mask.sum().item()
|
||||
|
||||
if num_visual_in_shard > 0 and num_visual_before > 0:
|
||||
# Calculate which visual embeddings belong to this shard
|
||||
# We need to find the offset of visual tokens in this shard
|
||||
from verl.utils.ulysses import get_ulysses_sequence_parallel_rank
|
||||
|
||||
rank = get_ulysses_sequence_parallel_rank()
|
||||
seq_len = original_visual_mask.shape[1]
|
||||
local_seq_len = seq_len // current_ulysses_sp_size
|
||||
start_idx = rank * local_seq_len
|
||||
end_idx = start_idx + local_seq_len
|
||||
|
||||
# Get total visual tokens before and up to the end of the shard's sequence slice
|
||||
# This correctly handles batches by summing across all samples
|
||||
visual_start = original_visual_mask[:, :start_idx].sum().item() if start_idx > 0 else 0
|
||||
visual_end = original_visual_mask[:, :end_idx].sum().item()
|
||||
|
||||
# Slice each tensor in deepstack_visual_embeds
|
||||
for embed in deepstack_visual_embeds:
|
||||
sliced_embeds.append(embed[visual_start:visual_end])
|
||||
else:
|
||||
# No visual tokens in this shard, create empty tensors to maintain gradient flow
|
||||
for embed in deepstack_visual_embeds:
|
||||
sliced_embeds.append(embed[:0])
|
||||
call_kwargs["deepstack_visual_embeds"] = sliced_embeds
|
||||
|
||||
self._needs_initial_slice = False
|
||||
try:
|
||||
return original_forward(self, *args, **call_kwargs)
|
||||
@ -290,9 +329,7 @@ def apply_monkey_patch(
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||
Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention,
|
||||
)
|
||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
||||
Qwen2VLFlashAttention2 as Qwen2VLAttention,
|
||||
)
|
||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention
|
||||
|
||||
if use_remove_padding or ulysses_sp_size > 1:
|
||||
from verl.models.transformers.qwen2_vl import qwen2_vl_attn_forward
|
||||
|
@ -209,8 +209,10 @@ def _get_input_embeds(
|
||||
patch_dim = config.in_channels * config.temporal_patch_size * config.patch_size**2
|
||||
pixel_values = torch.zeros((16, patch_dim), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
|
||||
image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device)
|
||||
image_embeds, _ = model.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
image_embeds, dummy_deepstack_image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
inputs_embeds += 0.0 * image_embeds.mean()
|
||||
for emb in dummy_deepstack_image_embeds or []:
|
||||
inputs_embeds += 0.0 * emb.mean()
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
@ -75,7 +75,6 @@ actor_rollout_ref:
|
||||
clip_ratio_c: 3.0
|
||||
loss_agg_mode: token-mean
|
||||
entropy_coeff: 0
|
||||
tis_imp_ratio_cap: -1
|
||||
use_kl_loss: false
|
||||
use_torch_compile: true
|
||||
kl_loss_coef: 0.001
|
||||
@ -484,6 +483,12 @@ algorithm:
|
||||
pf_ppo:
|
||||
reweight_method: pow
|
||||
weight_pow: 2.0
|
||||
rollout_is_threshold: null
|
||||
rollout_is_threshold_lower: null
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: truncate
|
||||
rollout_is_veto_threshold: 0.0001
|
||||
rollout_is: false
|
||||
trainer:
|
||||
balance_batch: true
|
||||
total_epochs: 30
|
||||
|
@ -18,7 +18,8 @@ actor_rollout_ref:
|
||||
clip_grad: 1.0
|
||||
min_lr_ratio: 0.0
|
||||
num_cycles: 0.5
|
||||
warmup_style: constant
|
||||
lr_scheduler_type: constant
|
||||
warmup_style: null
|
||||
fsdp_config:
|
||||
_target_: verl.workers.config.FSDPEngineConfig
|
||||
wrap_policy:
|
||||
@ -59,7 +60,6 @@ actor_rollout_ref:
|
||||
clip_ratio_c: 3.0
|
||||
loss_agg_mode: token-mean
|
||||
entropy_coeff: 0
|
||||
tis_imp_ratio_cap: -1
|
||||
use_kl_loss: false
|
||||
use_torch_compile: true
|
||||
kl_loss_coef: 0.001
|
||||
@ -315,7 +315,8 @@ critic:
|
||||
clip_grad: 1.0
|
||||
min_lr_ratio: 0.0
|
||||
num_cycles: 0.5
|
||||
warmup_style: constant
|
||||
lr_scheduler_type: constant
|
||||
warmup_style: null
|
||||
model:
|
||||
fsdp_config:
|
||||
_target_: verl.workers.config.FSDPEngineConfig
|
||||
@ -462,6 +463,12 @@ algorithm:
|
||||
pf_ppo:
|
||||
reweight_method: pow
|
||||
weight_pow: 2.0
|
||||
rollout_is_threshold: null
|
||||
rollout_is_threshold_lower: null
|
||||
rollout_is_level: token
|
||||
rollout_is_mode: truncate
|
||||
rollout_is_veto_threshold: 0.0001
|
||||
rollout_is: false
|
||||
trainer:
|
||||
balance_batch: true
|
||||
total_epochs: 30
|
||||
|
@ -74,10 +74,6 @@ loss_agg_mode: token-mean
|
||||
# Entropy regularization coefficient in PPO loss
|
||||
entropy_coeff: 0
|
||||
|
||||
# Truncated Importance Sampling (TIS): https://fengyao.notion.site/off-policy-rl
|
||||
# the truncation value C of truncated Importance Sampling (-1 for disable TIS)
|
||||
tis_imp_ratio_cap: -1
|
||||
|
||||
# Whether to use KL loss instead of KL reward penalty. True for GRPO
|
||||
use_kl_loss: false
|
||||
|
||||
|
@ -73,6 +73,14 @@ class AlgoConfig(BaseConfig):
|
||||
use_pf_ppo (bool): Whether to enable preference feedback PPO.
|
||||
pf_ppo (dict[str, Any]): Preference feedback PPO settings.
|
||||
filter_groups (Optional[FilterGroupsConfig]): Filter groups configuration, used in DAPO and Entropy
|
||||
rollout_is_threshold (Optional[float]): Upper threshold for IS weights. null = disabled,
|
||||
float value = enabled (compute weights and metrics). This is the main on/off switch.
|
||||
rollout_is_threshold_lower (Optional[float]): Lower threshold for IS weights. If None, defaults to 1/upper.
|
||||
rollout_is_level (str): Aggregation level: "token", "sequence", or "geometric".
|
||||
rollout_is_mode (str): Bounding mode: "truncate" (cap upper only) or "clip" (zero outside bounds).
|
||||
rollout_is_veto_threshold (float): Per-token veto threshold for catastrophic outliers.
|
||||
rollout_is (bool): Whether to apply IS weights to policy loss. True = apply weights,
|
||||
False = compute metrics only (useful for monitoring before enabling correction). Default: False.
|
||||
"""
|
||||
|
||||
gamma: float = 1.0
|
||||
@ -85,3 +93,13 @@ class AlgoConfig(BaseConfig):
|
||||
use_pf_ppo: bool = False
|
||||
pf_ppo: dict[str, Any] = field(default_factory=dict)
|
||||
filter_groups: Optional[FilterGroupsConfig] = None
|
||||
# Rollout Importance Sampling (replaces legacy tis_imp_ratio_cap)
|
||||
# Controls computation of IS weights and mismatch metrics
|
||||
rollout_is_threshold: Optional[float] = None # null = disabled, float = enabled
|
||||
rollout_is_threshold_lower: Optional[float] = None
|
||||
rollout_is_level: str = "token"
|
||||
rollout_is_mode: str = "truncate"
|
||||
rollout_is_veto_threshold: Optional[float] = 1e-4
|
||||
# Controls whether to apply IS weights to policy loss (only if rollout_is_threshold is set)
|
||||
# True = apply weights to loss, False = compute metrics only (no weight application)
|
||||
rollout_is: bool = False
|
||||
|
@ -28,6 +28,8 @@ min_lr_ratio: 0.0
|
||||
# Number of cosine cycles in LR schedule
|
||||
num_cycles: 0.5
|
||||
|
||||
# LR warmup style: "constant" or "cosine"
|
||||
warmup_style: constant
|
||||
# LR scheduler type: "constant" or "cosine"
|
||||
lr_scheduler_type: constant
|
||||
|
||||
# deprecated
|
||||
warmup_style: null
|
||||
|
@ -73,6 +73,28 @@ algorithm:
|
||||
reweight_method: pow # ["pow", "max_min", "max_random"]
|
||||
weight_pow: 2.0
|
||||
|
||||
# Rollout Importance Sampling: corrects distribution mismatch between rollout and training policies
|
||||
# Main control: Upper threshold for IS weights (null = disabled, float = enabled)
|
||||
# When enabled, computes IS weights and mismatch metrics (KL, PPL, etc.)
|
||||
rollout_is_threshold: null
|
||||
|
||||
# Lower threshold for IS weights (null = auto-reciprocal of upper)
|
||||
rollout_is_threshold_lower: null
|
||||
|
||||
# Aggregation level: "token" (biased), "sequence" (unbiased), "geometric" (experimental)
|
||||
rollout_is_level: token
|
||||
|
||||
# Bounding mode: "truncate" (cap upper only), "clip" (zero outside bounds)
|
||||
rollout_is_mode: truncate
|
||||
|
||||
# Per-token veto threshold for catastrophic outliers
|
||||
rollout_is_veto_threshold: 1e-4
|
||||
|
||||
# Whether to apply IS weights to policy loss
|
||||
# true = apply weights to loss, false = compute metrics only (no weight application)
|
||||
# Useful for monitoring mismatch before enabling correction
|
||||
rollout_is: false
|
||||
|
||||
trainer:
|
||||
balance_batch: True
|
||||
total_epochs: 30
|
||||
|
@ -113,6 +113,28 @@ algorithm:
|
||||
# Power used for weight scaling in "pow" method
|
||||
weight_pow: 2.0
|
||||
|
||||
# Rollout Importance Sampling: corrects distribution mismatch between rollout and training policies
|
||||
# Main control: Upper threshold for IS weights (null = disabled, float = enabled)
|
||||
# When enabled, computes IS weights and mismatch metrics (KL, PPL, etc.)
|
||||
rollout_is_threshold: null
|
||||
|
||||
# Lower threshold for IS weights (null = auto-reciprocal of upper)
|
||||
rollout_is_threshold_lower: null
|
||||
|
||||
# Aggregation level: "token" (biased), "sequence" (unbiased), "geometric" (experimental)
|
||||
rollout_is_level: token
|
||||
|
||||
# Bounding mode: "truncate" (cap upper only), "clip" (zero outside bounds)
|
||||
rollout_is_mode: truncate
|
||||
|
||||
# Per-token veto threshold for catastrophic outliers
|
||||
rollout_is_veto_threshold: 1e-4
|
||||
|
||||
# Whether to apply IS weights to policy loss
|
||||
# true = apply weights to loss, false = compute metrics only (no weight application)
|
||||
# Useful for monitoring mismatch before enabling correction
|
||||
rollout_is: false
|
||||
|
||||
# config for the trainer
|
||||
trainer:
|
||||
|
||||
|
@ -881,7 +881,7 @@ def compute_policy_loss(
|
||||
return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
|
||||
|
||||
|
||||
@register_policy_loss("vanilla")
|
||||
@register_policy_loss("vanilla") # type: ignore[arg-type]
|
||||
def compute_policy_loss_vanilla(
|
||||
old_log_prob: torch.Tensor,
|
||||
log_prob: torch.Tensor,
|
||||
@ -889,7 +889,7 @@ def compute_policy_loss_vanilla(
|
||||
response_mask: torch.Tensor,
|
||||
loss_agg_mode: str = "token-mean",
|
||||
config: Optional[DictConfig | AlgoConfig] = None,
|
||||
rollout_log_probs: torch.Tensor | None = None,
|
||||
rollout_is_weights: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Compute the clipped policy objective and related metrics for PPO.
|
||||
@ -959,11 +959,9 @@ def compute_policy_loss_vanilla(
|
||||
|
||||
pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
|
||||
|
||||
if config.tis_imp_ratio_cap > 0 and rollout_log_probs is not None:
|
||||
# Apply truncated importance sampling -> https://fengyao.notion.site/off-policy-rl
|
||||
tis_imp_ratio = torch.exp(old_log_prob - rollout_log_probs)
|
||||
tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap)
|
||||
pg_losses = pg_losses * tis_imp_ratio
|
||||
# Apply rollout importance sampling weights if provided
|
||||
if rollout_is_weights is not None:
|
||||
pg_losses = pg_losses * rollout_is_weights
|
||||
|
||||
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
|
||||
|
||||
@ -978,7 +976,7 @@ def compute_policy_loss_gspo(
|
||||
response_mask: torch.Tensor,
|
||||
loss_agg_mode: str = "seq-mean-token-mean",
|
||||
config: Optional[DictConfig | ActorConfig] = None,
|
||||
rollout_log_probs: torch.Tensor | None = None,
|
||||
rollout_is_weights: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Compute the clipped policy objective and related metrics for GSPO.
|
||||
@ -1024,6 +1022,10 @@ def compute_policy_loss_gspo(
|
||||
pg_losses2 = -advantages * torch.clamp(seq_importance_ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
|
||||
pg_losses = torch.maximum(pg_losses1, pg_losses2)
|
||||
|
||||
# Apply rollout importance sampling weights if provided
|
||||
if rollout_is_weights is not None:
|
||||
pg_losses = pg_losses * rollout_is_weights
|
||||
|
||||
# for GSPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean)
|
||||
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode="seq-mean-token-mean")
|
||||
|
||||
@ -1044,7 +1046,7 @@ def compute_policy_loss_gpg(
|
||||
response_mask: torch.Tensor,
|
||||
loss_agg_mode: str = "token-mean",
|
||||
config: Optional[DictConfig | AlgoConfig] = None,
|
||||
rollout_log_probs: torch.Tensor | None = None,
|
||||
rollout_is_weights: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Adapted from
|
||||
https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495
|
||||
@ -1061,6 +1063,10 @@ def compute_policy_loss_gpg(
|
||||
"""
|
||||
pg_losses = -log_prob * advantages
|
||||
|
||||
# Apply rollout importance sampling weights if provided
|
||||
if rollout_is_weights is not None:
|
||||
pg_losses = pg_losses * rollout_is_weights
|
||||
|
||||
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
|
||||
return pg_loss, torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
|
||||
|
||||
@ -1073,7 +1079,7 @@ def compute_policy_loss_clip_cov(
|
||||
response_mask: torch.Tensor,
|
||||
loss_agg_mode: str = "token-mean",
|
||||
config: Optional[DictConfig | AlgoConfig] = None,
|
||||
rollout_log_probs: torch.Tensor | None = None,
|
||||
rollout_is_weights: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Compute the clipped policy objective and related metrics for Clip-Cov.
|
||||
@ -1155,6 +1161,11 @@ def compute_policy_loss_clip_cov(
|
||||
pg_clipfrac = verl_F.masked_mean((corr == 0).float(), response_mask)
|
||||
|
||||
pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr
|
||||
|
||||
# Apply rollout importance sampling weights if provided
|
||||
if rollout_is_weights is not None:
|
||||
pg_losses = pg_losses * rollout_is_weights
|
||||
|
||||
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
|
||||
|
||||
return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0)
|
||||
@ -1168,7 +1179,7 @@ def compute_policy_loss_kl_cov(
|
||||
response_mask: torch.Tensor,
|
||||
loss_agg_mode: str = "token-mean",
|
||||
config: Optional[DictConfig | AlgoConfig] = None,
|
||||
rollout_log_probs: torch.Tensor | None = None,
|
||||
rollout_is_weights: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Compute the clipped policy objective and related metrics for Clip-Cov.
|
||||
@ -1227,6 +1238,10 @@ def compute_policy_loss_kl_cov(
|
||||
large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]
|
||||
]
|
||||
|
||||
# Apply rollout importance sampling weights if provided
|
||||
if rollout_is_weights is not None:
|
||||
pg_losses = pg_losses * rollout_is_weights
|
||||
|
||||
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
|
||||
|
||||
return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0)
|
||||
@ -1240,7 +1255,7 @@ def compute_policy_loss_geo_mean(
|
||||
response_mask: torch.Tensor,
|
||||
loss_agg_mode: str = "token-mean",
|
||||
config: Optional[DictConfig | AlgoConfig] = None,
|
||||
rollout_log_probs: torch.Tensor | None = None,
|
||||
rollout_is_weights: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Compute the clipped policy objective and related metrics for GMPO.
|
||||
@ -1293,6 +1308,17 @@ def compute_policy_loss_geo_mean(
|
||||
# otherwise, below would be not consistent with the paper
|
||||
advantage = (advantages * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8)
|
||||
pg_losses = -advantage * ratio
|
||||
|
||||
# Apply rollout importance sampling weights if provided
|
||||
# For geo_mean, IS weights are 2D (batch_size, seq_length) and need to be aggregated to sequence level
|
||||
if rollout_is_weights is not None:
|
||||
# Aggregate token-level weights to sequence level using geometric mean for consistency
|
||||
# Note: rollout_is_weights is always 2D regardless of rollout_is_level
|
||||
seq_is_weights = torch.exp(
|
||||
(torch.log(rollout_is_weights + 1e-10) * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8)
|
||||
)
|
||||
pg_losses = pg_losses * seq_is_weights
|
||||
|
||||
pg_loss = torch.mean(pg_losses)
|
||||
|
||||
# higher: ratio is too large that need clamp to clip_high (when adv > 0)
|
||||
|
459
verl/trainer/ppo/mismatch_helper.py
Normal file
459
verl/trainer/ppo/mismatch_helper.py
Normal file
@ -0,0 +1,459 @@
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Rollout Importance Sampling (IS) Helper Module
|
||||
|
||||
This module handles importance sampling weight computation for correcting
|
||||
distribution mismatch between rollout policy (e.g., vLLM BFloat16) and
|
||||
training policy (e.g., FSDP FP32).
|
||||
|
||||
Key Features:
|
||||
1. Three aggregation levels: token, sequence, geometric
|
||||
2. Two handling modes: truncate (TIS), clip (CIS)
|
||||
3. Per-token veto mechanism for catastrophic outliers
|
||||
4. Memory-efficient computation to prevent CUDA OOM
|
||||
5. Comprehensive metrics tracking
|
||||
|
||||
Usage Notes:
|
||||
- compute_rollout_importance_weights() computes both IS weights and mismatch metrics
|
||||
- Used in ray_trainer.py via compute_rollout_importance_weights_and_add_to_batch()
|
||||
- Also used in dp_actor.py for distributed worker computations
|
||||
- compute_mismatch_metrics() is called internally by compute_rollout_importance_weights()
|
||||
|
||||
References:
|
||||
- When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
|
||||
- Off-policy RL: https://fengyao.notion.site/off-policy-rl
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import verl.utils.torch_functional as verl_F
|
||||
from verl.protocol import DataProto
|
||||
|
||||
|
||||
def compute_rollout_importance_weights(
|
||||
old_log_prob: torch.Tensor,
|
||||
rollout_log_prob: torch.Tensor,
|
||||
response_mask: torch.Tensor,
|
||||
rollout_is_level: str = "token",
|
||||
rollout_is_mode: str = "truncate",
|
||||
rollout_is_threshold: Optional[float] = None,
|
||||
rollout_is_threshold_lower: Optional[float] = None,
|
||||
rollout_is_veto_threshold: Optional[float] = 1e-4,
|
||||
) -> tuple[Optional[DataProto], dict[str, Any]]:
|
||||
"""Compute importance sampling weights and metrics for rollout-training mismatch correction.
|
||||
|
||||
This function handles the computation of importance sampling (IS) weights to correct
|
||||
for the distribution mismatch between rollout policy and training policy.
|
||||
|
||||
Reference:
|
||||
When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
|
||||
|
||||
Memory-efficient implementation that prevents CUDA OOM by:
|
||||
- Using log-space computation where possible
|
||||
- Applying safety bounds to prevent numerical overflow
|
||||
- Computing metrics without creating huge intermediate tensors
|
||||
|
||||
Args:
|
||||
old_log_prob: Log probabilities from training policy (e.g., FSDP), shape (batch_size, seq_length)
|
||||
rollout_log_prob: Log probabilities from rollout policy (e.g., vLLM), shape (batch_size, seq_length)
|
||||
response_mask: Mask for valid tokens, shape (batch_size, seq_length)
|
||||
rollout_is_level: Level of IS aggregation:
|
||||
- "token": Per-token ratios (biased)
|
||||
- "sequence": Product of ratios (unbiased)
|
||||
- "geometric": Geometric mean of ratios (experimental)
|
||||
rollout_is_mode: How to handle weights exceeding threshold:
|
||||
- "truncate": Cap weights at upper_threshold only (TIS)
|
||||
- "clip": Zero out weights outside [lower_threshold, upper_threshold] (CIS)
|
||||
rollout_is_threshold: Upper threshold for IS weights
|
||||
rollout_is_threshold_lower: Lower threshold for IS weights (clip mode only; if None, defaults to 1/upper)
|
||||
rollout_is_veto_threshold: Per-token veto threshold. If any token ratio < this, zero entire sequence.
|
||||
If None, veto mechanism is disabled.
|
||||
|
||||
Returns:
|
||||
Tuple of (weights_proto, metrics) where:
|
||||
weights_proto: DataProto containing IS weights with key "rollout_is_weights",
|
||||
shape (batch_size, seq_length). Returns None if rollout_is_threshold is None.
|
||||
metrics: Dictionary of IS statistics and mismatch metrics (KL, PPL, etc.),
|
||||
all converted to scalars and prefixed with "mismatch/"
|
||||
"""
|
||||
if rollout_is_threshold is None:
|
||||
return None, {}
|
||||
|
||||
# Parse thresholds: if lower not specified, use 1/upper (reciprocal)
|
||||
upper_threshold = rollout_is_threshold
|
||||
if rollout_is_threshold_lower is not None:
|
||||
lower_threshold = rollout_is_threshold_lower
|
||||
else:
|
||||
# Default: lower = 1/upper (reciprocal)
|
||||
lower_threshold = 1.0 / upper_threshold
|
||||
|
||||
# Step 1: Compute raw importance weights based on the specified level
|
||||
log_ratio = old_log_prob - rollout_log_prob
|
||||
|
||||
# Pre-compute log thresholds
|
||||
device = old_log_prob.device
|
||||
log_threshold_upper = torch.log(torch.tensor(upper_threshold, device=device))
|
||||
log_threshold_lower = torch.log(torch.tensor(lower_threshold, device=device))
|
||||
|
||||
# Safety bound to prevent numerical overflow (exp(20) ≈ 485M)
|
||||
SAFETY_BOUND = 20.0
|
||||
|
||||
# Store unclamped values in log-space for accurate metrics
|
||||
if rollout_is_level == "token":
|
||||
# Token-level IS: π_train(a|s) / π_rollout(a|s) per token
|
||||
log_ratio_for_metrics = log_ratio
|
||||
|
||||
# Apply safety bound to prevent overflow
|
||||
log_ratio_safe = torch.clamp(log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND)
|
||||
rollout_is_weights = torch.exp(log_ratio_safe)
|
||||
|
||||
elif rollout_is_level == "sequence":
|
||||
# Sequence-level IS: π_train(y|x) / π_rollout(y|x) for entire sequence
|
||||
# Product of token ratios: exp(Σ log(π_train/π_rollout))
|
||||
log_ratio_sum = verl_F.masked_sum(log_ratio, response_mask, axis=-1).unsqueeze(-1)
|
||||
log_ratio_for_metrics = log_ratio_sum # Store for metrics
|
||||
|
||||
# Apply safety bound to prevent overflow
|
||||
log_ratio_sum_safe = torch.clamp(log_ratio_sum, min=-SAFETY_BOUND, max=SAFETY_BOUND)
|
||||
rollout_is_weights = torch.exp(log_ratio_sum_safe).expand_as(old_log_prob)
|
||||
|
||||
elif rollout_is_level == "geometric":
|
||||
# Geometric mean IS: (∏ π_train/π_rollout)^(1/T)
|
||||
# Equivalent to exp(mean(log(π_train/π_rollout)))
|
||||
log_ratio_mean = verl_F.masked_mean(log_ratio, response_mask, axis=-1).unsqueeze(-1)
|
||||
log_ratio_for_metrics = log_ratio_mean # Store for metrics
|
||||
|
||||
# Geometric mean rarely explodes due to averaging, but apply safety bound anyway
|
||||
log_ratio_mean_safe = torch.clamp(log_ratio_mean, min=-SAFETY_BOUND, max=SAFETY_BOUND)
|
||||
rollout_is_weights = torch.exp(log_ratio_mean_safe).expand_as(old_log_prob)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid rollout_is_level: {rollout_is_level}. Must be 'token', 'sequence', or 'geometric'.")
|
||||
|
||||
# Step 1.5: Apply per-token veto check in log space (memory efficient)
|
||||
if rollout_is_veto_threshold is not None:
|
||||
log_veto_threshold = torch.log(torch.tensor(rollout_is_veto_threshold, device=device))
|
||||
|
||||
# Check if any token ratio is below veto threshold (in log space)
|
||||
# log(π_train/π_rollout) < log(veto_threshold) ⟺ π_train/π_rollout < veto_threshold
|
||||
catastrophic_tokens = (log_ratio < log_veto_threshold) & response_mask.bool()
|
||||
|
||||
# For each sequence, check if it has any catastrophic token
|
||||
# Use broadcasting instead of expand_as to save memory
|
||||
has_catastrophic = catastrophic_tokens.any(dim=-1, keepdim=True)
|
||||
|
||||
# Create veto mask: 0 if sequence has catastrophic token, 1 otherwise
|
||||
veto_mask = (~has_catastrophic).float()
|
||||
else:
|
||||
# No veto mechanism
|
||||
catastrophic_tokens = torch.zeros_like(response_mask, dtype=torch.bool)
|
||||
has_catastrophic = torch.zeros((old_log_prob.size(0), 1), dtype=torch.bool, device=device)
|
||||
veto_mask = torch.ones((old_log_prob.size(0), 1), dtype=torch.float32, device=device)
|
||||
|
||||
# Step 2: Compute comprehensive metrics
|
||||
metrics = compute_is_metrics(
|
||||
rollout_is_weights=rollout_is_weights,
|
||||
log_ratio_for_metrics=log_ratio_for_metrics,
|
||||
response_mask=response_mask,
|
||||
rollout_is_level=rollout_is_level,
|
||||
rollout_is_threshold=upper_threshold,
|
||||
rollout_is_threshold_lower=lower_threshold,
|
||||
log_threshold_upper=log_threshold_upper,
|
||||
log_threshold_lower=log_threshold_lower,
|
||||
has_catastrophic=has_catastrophic,
|
||||
catastrophic_tokens=catastrophic_tokens,
|
||||
SAFETY_BOUND=SAFETY_BOUND,
|
||||
)
|
||||
|
||||
# Step 3: Apply truncation or clipping based on mode
|
||||
if rollout_is_mode == "truncate":
|
||||
# Truncated IS (TIS): only cap upper bound to prevent overweighting
|
||||
rollout_is_weights = rollout_is_weights.clamp(max=upper_threshold)
|
||||
|
||||
elif rollout_is_mode == "clip":
|
||||
# Clipped IS (CIS): zero out weights outside [lower_threshold, upper_threshold]
|
||||
clip_mask = (rollout_is_weights >= lower_threshold) & (rollout_is_weights <= upper_threshold)
|
||||
clip_mask = clip_mask.float()
|
||||
|
||||
# Track CIS-specific metrics
|
||||
metrics["rollout_is_clipped_fraction"] = verl_F.masked_mean(1 - clip_mask, response_mask)
|
||||
|
||||
# Sequence-level clipping fraction
|
||||
if rollout_is_level in ["sequence", "geometric"]:
|
||||
# All tokens in a sequence have the same weight, so reuse clip_mask
|
||||
metrics["rollout_is_seq_clipped_fraction"] = (1 - clip_mask[:, 0]).mean()
|
||||
else:
|
||||
# Check if any token in each sequence is clipped
|
||||
seq_has_clipped = verl_F.masked_sum(1 - clip_mask, response_mask, axis=-1) > 0
|
||||
metrics["rollout_is_seq_clipped_fraction"] = seq_has_clipped.float().mean()
|
||||
|
||||
rollout_is_weights = rollout_is_weights * clip_mask
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid rollout_is_mode: {rollout_is_mode}. Must be 'truncate' or 'clip'.")
|
||||
|
||||
# Apply veto mask AFTER all thresholding
|
||||
# This zeros out entire sequences that have any catastrophic token
|
||||
rollout_is_weights = rollout_is_weights * veto_mask
|
||||
|
||||
# Apply response_mask to ensure weights are 0 where mask is 0
|
||||
rollout_is_weights = rollout_is_weights * response_mask
|
||||
|
||||
# Wrap in DataProto for consistency with worker methods
|
||||
rollout_is_weights_proto = DataProto.from_dict(tensors={"rollout_is_weights": rollout_is_weights})
|
||||
|
||||
# Compute mismatch metrics (KL, PPL, etc.) and merge with IS metrics
|
||||
mismatch_metrics = compute_mismatch_metrics(
|
||||
old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=response_mask
|
||||
)
|
||||
metrics.update(mismatch_metrics)
|
||||
|
||||
# Convert all tensor metrics to scalars for logging
|
||||
# Note: No need to detach since old_log_prob and rollout_log_prob are computed with torch.no_grad()
|
||||
metrics_scalar = {}
|
||||
for key, value in metrics.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
metrics_scalar[f"mismatch/{key}"] = value.item()
|
||||
else:
|
||||
metrics_scalar[f"mismatch/{key}"] = value
|
||||
|
||||
return rollout_is_weights_proto, metrics_scalar
|
||||
|
||||
|
||||
def compute_is_metrics(
|
||||
rollout_is_weights: torch.Tensor,
|
||||
log_ratio_for_metrics: torch.Tensor,
|
||||
response_mask: torch.Tensor,
|
||||
rollout_is_level: str,
|
||||
rollout_is_threshold: float,
|
||||
rollout_is_threshold_lower: float,
|
||||
log_threshold_upper: torch.Tensor,
|
||||
log_threshold_lower: torch.Tensor,
|
||||
has_catastrophic: torch.Tensor,
|
||||
catastrophic_tokens: torch.Tensor,
|
||||
SAFETY_BOUND: float,
|
||||
) -> dict[str, Any]:
|
||||
"""Compute comprehensive metrics for importance sampling weights.
|
||||
|
||||
Reference:
|
||||
When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
|
||||
|
||||
This function computes metrics using a mix of true unclamped values (for max/min/fractions
|
||||
in sequence/geometric mode via log-space) and safety-clamped values (for mean/std/ESS)
|
||||
to balance accuracy with numerical stability and avoid overflow.
|
||||
"""
|
||||
# Validate that we have at least one valid sample
|
||||
assert response_mask.any(), "Expected at least one valid sample in response_mask"
|
||||
|
||||
metrics = {}
|
||||
device = rollout_is_weights.device
|
||||
|
||||
# Track veto statistics
|
||||
metrics["rollout_is_veto_fraction"] = has_catastrophic.float().mean()
|
||||
metrics["rollout_is_catastrophic_token_fraction"] = verl_F.masked_mean(catastrophic_tokens.float(), response_mask)
|
||||
|
||||
# Compute metrics based on IS level
|
||||
if rollout_is_level in ["sequence", "geometric"]:
|
||||
# For sequence/geometric, compute true statistics from log-space
|
||||
# This reflects the actual distribution before clamping
|
||||
|
||||
# True max/min in log space
|
||||
log_max = log_ratio_for_metrics.max()
|
||||
log_min = log_ratio_for_metrics.min()
|
||||
|
||||
# Convert to regular space with safety bound
|
||||
metrics["rollout_is_max"] = torch.exp(torch.clamp(log_max, max=SAFETY_BOUND))
|
||||
metrics["rollout_is_min"] = torch.exp(log_min)
|
||||
|
||||
# Mean uses clamped weights to avoid overflow
|
||||
metrics["rollout_is_mean"] = verl_F.masked_mean(rollout_is_weights, response_mask)
|
||||
|
||||
# Compute fraction exceeding threshold in log space (accurate)
|
||||
exceeds_upper = log_ratio_for_metrics > log_threshold_upper
|
||||
below_lower = log_ratio_for_metrics < log_threshold_lower
|
||||
|
||||
if rollout_is_level == "sequence":
|
||||
# For sequence level, all tokens in a sequence have the same weight
|
||||
metrics["rollout_is_ratio_fraction_high"] = exceeds_upper.float().mean()
|
||||
metrics["rollout_is_ratio_fraction_low"] = below_lower.float().mean()
|
||||
else: # geometric
|
||||
# Need to expand to match token dimensions
|
||||
exceeds_upper_expanded = exceeds_upper.expand_as(response_mask)
|
||||
below_lower_expanded = below_lower.expand_as(response_mask)
|
||||
metrics["rollout_is_ratio_fraction_high"] = verl_F.masked_mean(
|
||||
exceeds_upper_expanded.float(), response_mask
|
||||
)
|
||||
metrics["rollout_is_ratio_fraction_low"] = verl_F.masked_mean(below_lower_expanded.float(), response_mask)
|
||||
|
||||
else:
|
||||
# Token-level: compute directly from weights
|
||||
metrics["rollout_is_mean"] = verl_F.masked_mean(rollout_is_weights, response_mask)
|
||||
|
||||
# Fraction exceeding thresholds
|
||||
rollout_is_above_threshold = rollout_is_weights > rollout_is_threshold
|
||||
rollout_is_below_threshold = rollout_is_weights < rollout_is_threshold_lower
|
||||
metrics["rollout_is_ratio_fraction_high"] = verl_F.masked_mean(
|
||||
rollout_is_above_threshold.float(), response_mask
|
||||
)
|
||||
metrics["rollout_is_ratio_fraction_low"] = verl_F.masked_mean(rollout_is_below_threshold.float(), response_mask)
|
||||
|
||||
# Max/min for token level
|
||||
mask_bool = response_mask.bool()
|
||||
metrics["rollout_is_max"] = rollout_is_weights.masked_fill(~mask_bool, float("-inf")).max()
|
||||
metrics["rollout_is_min"] = rollout_is_weights.masked_fill(~mask_bool, float("inf")).min()
|
||||
|
||||
# Compute standard deviation using clamped weights to avoid overflow
|
||||
mask_count = response_mask.sum()
|
||||
if mask_count > 1:
|
||||
# Use clamped weights for variance to avoid squaring huge values
|
||||
weights_for_std = rollout_is_weights.clamp(min=rollout_is_threshold_lower, max=rollout_is_threshold)
|
||||
# Use mean from clamped weights for consistency
|
||||
mean_clamped = verl_F.masked_mean(weights_for_std, response_mask)
|
||||
rollout_is_var = verl_F.masked_mean(weights_for_std.square(), response_mask) - mean_clamped.square()
|
||||
metrics["rollout_is_std"] = torch.sqrt(torch.clamp(rollout_is_var, min=0.0))
|
||||
else:
|
||||
metrics["rollout_is_std"] = torch.tensor(0.0, device=device)
|
||||
|
||||
# Effective sample size (use clamped weights to avoid overflow)
|
||||
weights_for_ess = rollout_is_weights.clamp(min=rollout_is_threshold_lower, max=rollout_is_threshold)
|
||||
mean_for_ess = verl_F.masked_mean(weights_for_ess, response_mask)
|
||||
is_weights_normalized = weights_for_ess / (mean_for_ess + 1e-8)
|
||||
metrics["rollout_is_eff_sample_size"] = 1.0 / verl_F.masked_mean(is_weights_normalized.square(), response_mask)
|
||||
|
||||
# Per-sequence breakdown metrics
|
||||
if rollout_is_weights.dim() > 1:
|
||||
# Compute mean IS weight per sequence
|
||||
seq_mean_weights = verl_F.masked_mean(rollout_is_weights, response_mask, axis=-1)
|
||||
|
||||
# Per-sequence statistics
|
||||
metrics["rollout_is_seq_mean"] = seq_mean_weights.mean()
|
||||
metrics["rollout_is_seq_std"] = (
|
||||
seq_mean_weights.std() if seq_mean_weights.numel() > 1 else torch.tensor(0.0, device=device)
|
||||
)
|
||||
metrics["rollout_is_seq_max"] = seq_mean_weights.max()
|
||||
metrics["rollout_is_seq_min"] = seq_mean_weights.min()
|
||||
|
||||
# Identify most problematic sequences
|
||||
seq_deviation = (seq_mean_weights - 1.0).abs()
|
||||
metrics["rollout_is_seq_max_deviation"] = seq_deviation.max()
|
||||
|
||||
# Fraction of sequences with high IS weights
|
||||
metrics["rollout_is_seq_fraction_high"] = (seq_mean_weights > rollout_is_threshold).float().mean()
|
||||
metrics["rollout_is_seq_fraction_low"] = (seq_mean_weights < rollout_is_threshold_lower).float().mean()
|
||||
|
||||
# Percentile metrics for better distribution understanding
|
||||
# Get all valid IS weights
|
||||
flat_weights = rollout_is_weights[response_mask.bool()]
|
||||
# Compute key percentiles (guaranteed to have elements due to assertion at function start)
|
||||
assert flat_weights.numel() > 0, "flat_weights should not be empty"
|
||||
metrics["rollout_is_p25"] = torch.quantile(flat_weights, 0.25)
|
||||
metrics["rollout_is_p50"] = torch.quantile(flat_weights, 0.50) # median
|
||||
metrics["rollout_is_p75"] = torch.quantile(flat_weights, 0.75)
|
||||
metrics["rollout_is_p95"] = torch.quantile(flat_weights, 0.95)
|
||||
metrics["rollout_is_p99"] = torch.quantile(flat_weights, 0.99)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def compute_mismatch_metrics(
|
||||
old_log_prob: torch.Tensor,
|
||||
rollout_log_prob: Optional[torch.Tensor],
|
||||
response_mask: torch.Tensor,
|
||||
) -> dict[str, Any]:
|
||||
"""Compute training-inference mismatch metrics (helper function).
|
||||
|
||||
This helper function operates on raw tensors and is used internally by:
|
||||
- compute_rollout_importance_weights() in this module (automatically included)
|
||||
- Tests (test_rollout_is.py, test_rollout_is_integration.py)
|
||||
|
||||
These metrics help diagnose the mismatch between the rollout policy (e.g., vLLM)
|
||||
and the training policy (e.g., FSDP), which can cause training instability.
|
||||
|
||||
Key metrics:
|
||||
- mismatch_kl: Direct KL divergence estimator KL(π_rollout || π_training)
|
||||
- mismatch_k3_kl: K3 KL estimator for stability (more stable for small KL)
|
||||
- training_ppl: Perplexity of training policy
|
||||
- rollout_ppl: Perplexity of rollout policy
|
||||
- log_ppl_diff: Difference in log perplexities
|
||||
- ppl_ratio: Ratio of training PPL to rollout PPL
|
||||
|
||||
Args:
|
||||
old_log_prob: Log probabilities from training policy, shape (batch_size, seq_length)
|
||||
rollout_log_prob: Log probabilities from rollout policy, shape (batch_size, seq_length)
|
||||
response_mask: Mask for valid tokens, shape (batch_size, seq_length)
|
||||
|
||||
Returns:
|
||||
Dictionary of mismatch metrics (without prefix)
|
||||
|
||||
Reference:
|
||||
- When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
|
||||
"""
|
||||
# Validate that we have at least one valid token
|
||||
assert response_mask.any(), "Expected at least one valid token in response_mask"
|
||||
|
||||
metrics = {}
|
||||
|
||||
# 1. Training policy perplexity (always available)
|
||||
# Formula: exp(-1/|T| * Σ log π_training(y_t|y_<t))
|
||||
# where |T| is the number of tokens generated by the model
|
||||
mean_log_prob_training = verl_F.masked_mean(old_log_prob, response_mask, axis=-1) # (batch_size,)
|
||||
training_ppl = torch.exp(-mean_log_prob_training).mean() # Batch mean of per-sequence PPL
|
||||
metrics["mismatch_training_ppl"] = training_ppl.detach().item()
|
||||
|
||||
# Also log log-ppl for easier analysis (avoids exponential scale)
|
||||
metrics["mismatch_training_log_ppl"] = (-mean_log_prob_training).mean().detach().item()
|
||||
|
||||
# 2. Compute rollout mismatch metrics (only if rollout_log_probs available)
|
||||
if rollout_log_prob is not None:
|
||||
# 2a. mismatch_kl: Direct estimator for KL(π_rollout || π_training)
|
||||
# This is the standard KL divergence: E[log(π_rollout) - log(π_training)]
|
||||
# Positive value means rollout policy is more confident than training policy
|
||||
metrics["mismatch_kl"] = verl_F.masked_mean(rollout_log_prob - old_log_prob, response_mask).detach().item()
|
||||
|
||||
# 2b. mismatch_k3_kl: K3 estimator for KL(π_rollout || π_training)
|
||||
# More stable for small KL values using: E[exp(log_ratio) - log_ratio - 1]
|
||||
# Formula: KL ≈ E[r - log(r) - 1] where r = π_training/π_rollout
|
||||
log_ratio = old_log_prob - rollout_log_prob
|
||||
mismatch_k3_kl_matrix = torch.exp(log_ratio) - log_ratio - 1
|
||||
metrics["mismatch_k3_kl"] = verl_F.masked_mean(mismatch_k3_kl_matrix, response_mask).detach().item()
|
||||
|
||||
# 2c. Rollout policy perplexity
|
||||
mean_log_prob_rollout = verl_F.masked_mean(rollout_log_prob, response_mask, axis=-1) # (batch_size,)
|
||||
rollout_ppl = torch.exp(-mean_log_prob_rollout).mean() # Batch mean of per-sequence PPL
|
||||
metrics["mismatch_rollout_ppl"] = rollout_ppl.detach().item()
|
||||
metrics["mismatch_rollout_log_ppl"] = (-mean_log_prob_rollout).mean().detach().item()
|
||||
|
||||
# 2d. Log PPL difference (sequence-level perplexity difference)
|
||||
# log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training
|
||||
# Since ppl = exp(-log_prob), we have:
|
||||
# log(ppl_ratio) = log(training_ppl/rollout_ppl) = log_ppl_diff
|
||||
# Positive value means training assigns lower probability (higher PPL) than rollout
|
||||
log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training
|
||||
metrics["mismatch_log_ppl_diff"] = log_ppl_diff.mean().detach().item()
|
||||
metrics["mismatch_log_ppl_abs_diff"] = log_ppl_diff.abs().mean().detach().item()
|
||||
metrics["mismatch_log_ppl_diff_max"] = log_ppl_diff.max().detach().item()
|
||||
metrics["mismatch_log_ppl_diff_min"] = log_ppl_diff.min().detach().item()
|
||||
|
||||
# 2e. PPL ratio (how much higher is training PPL vs rollout PPL)
|
||||
# IMPORTANT: Compute per-sequence ratio first, then average
|
||||
# For numerical stability, compute in log space using log_ppl_diff
|
||||
# Note: log_ppl_diff = log(ppl_ratio), so ppl_ratio = exp(log_ppl_diff)
|
||||
# This is the inverse of geometric IS: ppl_ratio_i = 1 / geometric_is_i for each sequence
|
||||
ppl_ratio = torch.exp(log_ppl_diff).mean() # mean(exp(log_ppl_diff)) = mean(ppl_ratio_i)
|
||||
metrics["mismatch_ppl_ratio"] = ppl_ratio.detach().item()
|
||||
|
||||
return metrics
|
@ -49,6 +49,7 @@ from verl.trainer.ppo.metric_utils import (
|
||||
compute_timing_metrics,
|
||||
process_validation_metrics,
|
||||
)
|
||||
from verl.trainer.ppo.mismatch_helper import compute_rollout_importance_weights
|
||||
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
|
||||
from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model
|
||||
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi
|
||||
@ -918,6 +919,49 @@ class RayPPOTrainer:
|
||||
)
|
||||
metrics.update(global_balance_stats)
|
||||
|
||||
def compute_rollout_importance_weights_and_add_to_batch(self, batch: DataProto) -> tuple[DataProto, dict]:
|
||||
"""Compute rollout importance sampling weights and mismatch metrics, conditionally add weights to batch.
|
||||
|
||||
This method computes IS weights to correct for distribution mismatch between
|
||||
rollout policy and training policy. It always computes metrics when enabled, but
|
||||
only adds weights to batch if algorithm.rollout_is is True.
|
||||
|
||||
Args:
|
||||
batch: DataProto containing old_log_probs, rollout_log_probs, response_mask
|
||||
|
||||
Returns:
|
||||
Tuple of (updated_batch, metrics) where:
|
||||
- updated_batch: Batch with rollout_is_weights added (if rollout_is=True)
|
||||
- metrics: Dictionary of IS and mismatch metrics (all with mismatch/ prefix)
|
||||
"""
|
||||
# Compute rollout IS weights if enabled and data is available
|
||||
# rollout_is_threshold is the main on/off switch
|
||||
if self.config.algorithm.rollout_is_threshold is not None and "rollout_log_probs" in batch.batch:
|
||||
rollout_is_weights, rollout_is_metrics = compute_rollout_importance_weights(
|
||||
old_log_prob=batch.batch["old_log_probs"],
|
||||
rollout_log_prob=batch.batch["rollout_log_probs"],
|
||||
response_mask=batch.batch["response_mask"],
|
||||
rollout_is_level=self.config.algorithm.rollout_is_level,
|
||||
rollout_is_mode=self.config.algorithm.rollout_is_mode,
|
||||
rollout_is_threshold=self.config.algorithm.rollout_is_threshold,
|
||||
rollout_is_threshold_lower=self.config.algorithm.rollout_is_threshold_lower,
|
||||
rollout_is_veto_threshold=self.config.algorithm.rollout_is_veto_threshold,
|
||||
)
|
||||
|
||||
# Control: Should we apply weights to policy loss?
|
||||
# True = add weights to batch (actor will apply them)
|
||||
# False = don't add weights (metrics only, no loss modification)
|
||||
apply_weights = self.config.algorithm.get("rollout_is", False)
|
||||
|
||||
if apply_weights:
|
||||
# Add IS weights to batch for distribution to workers
|
||||
batch = batch.union(rollout_is_weights)
|
||||
|
||||
return batch, rollout_is_metrics
|
||||
|
||||
# Return unchanged batch and empty metrics if IS is disabled
|
||||
return batch, {}
|
||||
|
||||
def fit(self):
|
||||
"""
|
||||
The training loop of PPO.
|
||||
@ -1107,6 +1151,13 @@ class RayPPOTrainer:
|
||||
else:
|
||||
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
|
||||
|
||||
# Compute rollout importance sampling weights centrally (once per batch)
|
||||
# This corrects for mismatch between rollout policy and training policy
|
||||
# Also computes mismatch metrics (KL, PPL, etc.)
|
||||
batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
|
||||
# IS and mismatch metrics already have mismatch/ prefix
|
||||
metrics.update(is_metrics)
|
||||
|
||||
# compute advantages, executed on the driver process
|
||||
norm_adv_by_std_in_grpo = self.config.algorithm.get(
|
||||
"norm_adv_by_std_in_grpo", True
|
||||
@ -1205,6 +1256,7 @@ class RayPPOTrainer:
|
||||
# TODO: implement actual tflpo and theoretical tflpo
|
||||
n_gpus = self.resource_pool_manager.get_n_gpus()
|
||||
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
|
||||
# Note: mismatch metrics (KL, PPL, etc.) are collected at line 1179 after advantage computation
|
||||
|
||||
# this is experimental and may be changed/removed in the future in favor of a general-purpose one
|
||||
if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):
|
||||
|
@ -373,13 +373,10 @@ class DataParallelPPOActor(BasePPOActor):
|
||||
]
|
||||
if self.config.use_kl_loss:
|
||||
select_keys.append("ref_log_prob")
|
||||
if self.config.tis_imp_ratio_cap > 0:
|
||||
assert "rollout_log_probs" in data.batch.keys(), (
|
||||
"Truncated Importance Sampling (TIS) requires to configure "
|
||||
"`actor_rollout_ref.rollout.calculate_log_probs=True` "
|
||||
"and is not currently supported in Server mode (agent loop)."
|
||||
)
|
||||
select_keys.append("rollout_log_probs")
|
||||
# Include pre-computed IS weights if present in batch
|
||||
# Weights are computed centrally in trainer and added to batch when algorithm.rollout_is=True
|
||||
if "rollout_is_weights" in data.batch.keys():
|
||||
select_keys.append("rollout_is_weights")
|
||||
|
||||
has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
|
||||
non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
|
||||
@ -412,7 +409,6 @@ class DataParallelPPOActor(BasePPOActor):
|
||||
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
|
||||
response_mask = model_inputs["response_mask"]
|
||||
old_log_prob = model_inputs["old_log_probs"]
|
||||
rollout_log_probs = model_inputs["rollout_log_probs"] if self.config.tis_imp_ratio_cap > 0 else None
|
||||
advantages = model_inputs["advantages"]
|
||||
|
||||
entropy_coeff = self.config.entropy_coeff
|
||||
@ -438,9 +434,21 @@ class DataParallelPPOActor(BasePPOActor):
|
||||
|
||||
loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
|
||||
# vanilla -> verl.trainer.ppo.core_algos.compute_policy_loss_vanilla
|
||||
|
||||
# Extract pre-computed rollout importance sampling weights if present
|
||||
# Weights are computed centrally in trainer and added when algorithm.rollout_is=True
|
||||
rollout_is_weights = model_inputs.get("rollout_is_weights", None)
|
||||
|
||||
# NOTE: Both mismatch diagnostic metrics (PPL, KL, etc.) and IS weight metrics
|
||||
# are computed centrally in ray_trainer.py for consistency and efficiency.
|
||||
# This ensures metrics are computed uniformly across all batches at the trainer level
|
||||
# and avoids redundant computation across workers and micro-batches.
|
||||
|
||||
# gpg -> verl.trainer.ppo.core_algos.compute_policy_loss_gpg
|
||||
# clip_cov -> verl.trainer.ppo.core_algos.compute_policy_loss_clip_cov
|
||||
policy_loss_fn = get_policy_loss_fn(loss_mode)
|
||||
|
||||
# Compute policy loss (all functions return 4 values)
|
||||
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
|
||||
old_log_prob=old_log_prob,
|
||||
log_prob=log_prob,
|
||||
@ -448,7 +456,7 @@ class DataParallelPPOActor(BasePPOActor):
|
||||
response_mask=response_mask,
|
||||
loss_agg_mode=loss_agg_mode,
|
||||
config=self.config,
|
||||
rollout_log_probs=rollout_log_probs,
|
||||
rollout_is_weights=rollout_is_weights,
|
||||
)
|
||||
|
||||
if entropy_coeff != 0:
|
||||
|
@ -316,6 +316,10 @@ class MegatronPPOActor(BasePPOActor):
|
||||
]
|
||||
if self.config.use_kl_loss:
|
||||
select_keys.append("ref_log_prob")
|
||||
# Include pre-computed IS weights if present in batch
|
||||
# Weights are computed centrally in trainer and added to batch when algorithm.rollout_is=True
|
||||
if "rollout_is_weights" in data.batch.keys():
|
||||
select_keys.append("rollout_is_weights")
|
||||
self.has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
|
||||
if self.has_multi_modal_inputs:
|
||||
data = data.select(select_keys, ["multi_modal_inputs"])
|
||||
@ -419,7 +423,6 @@ class MegatronPPOActor(BasePPOActor):
|
||||
response_length = responses.size(1)
|
||||
response_mask = data["response_mask"].to(bool)
|
||||
loss_agg_mode = self.config.loss_agg_mode
|
||||
|
||||
# compute policy loss
|
||||
log_prob = output["log_probs"][:, -response_length - 1 : -1].contiguous()
|
||||
ret_entropy = None
|
||||
@ -434,6 +437,15 @@ class MegatronPPOActor(BasePPOActor):
|
||||
loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
|
||||
|
||||
policy_loss_fn = get_policy_loss_fn(loss_mode)
|
||||
|
||||
# Extract pre-computed rollout importance sampling weights if present
|
||||
# Weights are computed centrally in trainer and added when algorithm.rollout_is=True
|
||||
rollout_is_weights = data.get("rollout_is_weights", None)
|
||||
|
||||
# NOTE: Both mismatch diagnostic metrics (PPL, KL, etc.) and IS weight metrics
|
||||
# are computed centrally in ray_trainer.py for consistency and efficiency.
|
||||
# This ensures metrics are computed uniformly across all batches at the trainer level
|
||||
# and avoids redundant computation across workers and micro-batches.
|
||||
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
|
||||
old_log_prob=old_log_prob,
|
||||
log_prob=log_prob,
|
||||
@ -441,6 +453,7 @@ class MegatronPPOActor(BasePPOActor):
|
||||
response_mask=response_mask,
|
||||
loss_agg_mode=loss_agg_mode,
|
||||
config=self.config,
|
||||
rollout_is_weights=rollout_is_weights,
|
||||
)
|
||||
|
||||
stats.update(
|
||||
|
@ -106,7 +106,6 @@ class ActorConfig(BaseConfig):
|
||||
clip_ratio_c: float = 3.0
|
||||
loss_agg_mode: str = "token-mean"
|
||||
entropy_coeff: float = 0
|
||||
tis_imp_ratio_cap: float = -1
|
||||
use_kl_loss: bool = False
|
||||
use_torch_compile: bool = True
|
||||
kl_loss_coef: float = 0.001
|
||||
|
@ -60,16 +60,27 @@ class FSDPOptimizerConfig(OptimizerConfig):
|
||||
Args:
|
||||
lr (float): Learning rate.
|
||||
min_lr_ratio (Optional[float]): Minimum LR ratio for cosine schedule.
|
||||
warmup_style (str): LR warmup style: "constant" or "cosine".
|
||||
lr_scheduler_type (str): LR scheduler type: "constant" or "cosine".
|
||||
num_cycles (float): Number of cosine cycles in LR schedule.
|
||||
"""
|
||||
|
||||
_mutable_fields = OptimizerConfig._mutable_fields.copy()
|
||||
_mutable_fields.add("lr_scheduler_type")
|
||||
|
||||
min_lr_ratio: Optional[float] = None
|
||||
warmup_style: str = "constant"
|
||||
# deprecate warmup_style
|
||||
warmup_style: Optional[str] = None
|
||||
lr_scheduler_type: str = "constant"
|
||||
num_cycles: float = 0.5
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.warmup_style in ["constant", "cosine"]
|
||||
if self.warmup_style is not None:
|
||||
assert self.warmup_style in ["constant", "cosine"]
|
||||
warnings.warn(
|
||||
"`warmup_style` is deprecated, use `lr_scheduler_type` instead.", DeprecationWarning, stacklevel=2
|
||||
)
|
||||
self.lr_scheduler_type = self.warmup_style
|
||||
assert self.lr_scheduler_type in ["constant", "cosine"]
|
||||
return super().__post_init__()
|
||||
|
||||
|
||||
|
@ -370,7 +370,7 @@ class FSDPEngine(BaseEngine):
|
||||
|
||||
total_steps = optim_config.total_training_steps
|
||||
num_warmup_steps = optim_config.lr_warmup_steps
|
||||
warmup_style = optim_config.warmup_style
|
||||
lr_scheduler_type = optim_config.lr_scheduler_type
|
||||
min_lr_ratio = optim_config.min_lr_ratio
|
||||
num_cycles = optim_config.num_cycles
|
||||
if num_warmup_steps <= 0:
|
||||
@ -380,9 +380,9 @@ class FSDPEngine(BaseEngine):
|
||||
if self.rank == 0:
|
||||
print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
|
||||
|
||||
if warmup_style == "constant":
|
||||
if lr_scheduler_type == "constant":
|
||||
lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=num_warmup_steps)
|
||||
elif warmup_style == "cosine":
|
||||
elif lr_scheduler_type == "cosine":
|
||||
lr_scheduler = get_cosine_schedule_with_warmup(
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
@ -391,7 +391,7 @@ class FSDPEngine(BaseEngine):
|
||||
num_cycles=num_cycles,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
|
||||
raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")
|
||||
return lr_scheduler
|
||||
|
||||
def _build_model_optimizer(self):
|
||||
|
@ -529,7 +529,7 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
|
||||
|
||||
total_steps = optim_config.get("total_training_steps", 0)
|
||||
num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1))
|
||||
warmup_style = optim_config.get("warmup_style", "constant")
|
||||
lr_scheduler_type = optim_config.get("lr_scheduler_type", "constant")
|
||||
min_lr_ratio = optim_config.get("min_lr_ratio", 0.0)
|
||||
num_cycles = optim_config.get("num_cycles", 0.5)
|
||||
if num_warmup_steps < 0:
|
||||
@ -539,11 +539,11 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
|
||||
if self.rank == 0:
|
||||
print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
|
||||
|
||||
if warmup_style == "constant":
|
||||
if lr_scheduler_type == "constant":
|
||||
actor_lr_scheduler = get_constant_schedule_with_warmup(
|
||||
optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps
|
||||
)
|
||||
elif warmup_style == "cosine":
|
||||
elif lr_scheduler_type == "cosine":
|
||||
actor_lr_scheduler = get_cosine_schedule_with_warmup(
|
||||
optimizer=actor_optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
@ -552,7 +552,7 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
|
||||
num_cycles=num_cycles,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
|
||||
raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")
|
||||
|
||||
log_gpu_memory_usage(f"After {role} optimizer init", logger=logger)
|
||||
else:
|
||||
@ -1386,7 +1386,8 @@ class CriticWorker(Worker, DistProfilerExtension):
|
||||
|
||||
total_steps = config.optim.get("total_training_steps", 0)
|
||||
num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1))
|
||||
warmup_style = config.optim.get("warmup_style", "constant")
|
||||
|
||||
lr_scheduler_type = config.optim.get("lr_scheduler_type", "constant")
|
||||
if num_warmup_steps < 0:
|
||||
num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0)
|
||||
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
|
||||
@ -1396,11 +1397,11 @@ class CriticWorker(Worker, DistProfilerExtension):
|
||||
|
||||
from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup
|
||||
|
||||
if warmup_style == "constant":
|
||||
if lr_scheduler_type == "constant":
|
||||
critic_lr_scheduler = get_constant_schedule_with_warmup(
|
||||
optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps
|
||||
)
|
||||
elif warmup_style == "cosine":
|
||||
elif lr_scheduler_type == "cosine":
|
||||
min_lr_ratio = config.optim.get("min_lr_ratio", 0.0)
|
||||
num_cycles = config.optim.get("num_cycles", 0.5)
|
||||
critic_lr_scheduler = get_cosine_schedule_with_warmup(
|
||||
@ -1411,7 +1412,7 @@ class CriticWorker(Worker, DistProfilerExtension):
|
||||
num_cycles=num_cycles,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
|
||||
raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")
|
||||
|
||||
return critic_module, critic_optimizer, critic_lr_scheduler
|
||||
|
||||
|
Reference in New Issue
Block a user